NTThong0710
m
d5c2751
from fastapi.responses import StreamingResponse
import asyncio
from fastapi import FastAPI, HTTPException, UploadFile, File, BackgroundTasks
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
import os
import shutil
# Load biến môi trường nếu có file .env (chạy local)
env_path = os.path.join(os.path.dirname(__file__), "..", ".env")
if os.path.exists(env_path):
from dotenv import load_dotenv
load_dotenv(env_path)
from agent import create_chatbot_agent
from rag_tool import get_list_files, process_new_file, DATA_DIR
app = FastAPI(title="AI Staff Assistant API")
# Cấu hình CORS để frontend gọi được
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Khởi tạo agent
print("Đang khởi tạo Agentic Chatbot...")
agent = create_chatbot_agent()
print("Khởi tạo hoàn tất!")
class ChatRequest(BaseModel):
session_id : str
message: str
class ChatResponse(BaseModel):
reply: str
def extract_text(content):
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, dict):
return content.get("text", "")
if isinstance(content, list):
parts = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
# Gemini hay trả dạng:
# {'type': 'text', 'text': '...', 'extras': {...}}
if "text" in item:
parts.append(item["text"])
return "".join(parts)
return str(content)
@app.post("/chat/stream")
async def chat_stream_endpoint(request: ChatRequest):
"""
Endpoint trả về kết quả dạng Stream.
Vừa hiển thị trạng thái gọi Tool, vừa stream từng chữ (token) của câu trả lời.
"""
async def event_stream():
config = {"configurable": {"thread_id": request.session_id}}
try:
# Sử dụng astream_events v2 để theo dõi chính xác hành vi của Agent
async for event in agent.astream_events(
{"messages": [("user", request.message)]},
config=config,
version="v2"
):
kind = event["event"]
# --- TRƯỜNG HỢP 1: Bắt đầu gọi Tools ---
if kind == "on_tool_start":
tool_name = event["name"]
if tool_name == "restaurant_documents_search":
# Remove markdown '***' if present
yield "⏳ Đang lật sổ tay công thức...\n\n"
elif "sql" in tool_name.lower():
yield "⏳ Đang tra cứu dữ liệu kho hệ thống...\n\n"
# --- TRƯỜNG HỢP 2: Bắt đầu Stream chữ từ LLM ---
elif kind == "on_chat_model_stream":
import re
chunk = event["data"]["chunk"]
# Kiểm tra: Chỉ stream nếu chunk có nội dung văn bản
# và KHÔNG PHẢI là chuỗi JSON đang gọi hàm (tool_calls)
if chunk.content and not chunk.tool_calls:
# Chỉ lấy phần text nếu là dict
text = extract_text(chunk.content)
if text:
yield text
except Exception as e:
import traceback
print(f"Lỗi Stream: {e}")
traceback.print_exc()
# Thông báo riêng cho lỗi model quá tải (503)
if "503" in str(e) or "high demand" in str(e) or "UNAVAILABLE" in str(e):
yield "\n\n❌ Xin lỗi, mô hình AI (Gemini) đang quá tải hoặc bị giới hạn. Vui lòng thử lại sau vài phút."
else:
yield "\n\n❌ Xin lỗi, có lỗi xảy ra trong quá trình kết nối hoặc xử lý."
# Trả về kết nối dạng luồng (Server-Sent Events / Stream)
return StreamingResponse(event_stream(), media_type="text/plain")
@app.get("/files")
async def get_files():
"""Lấy danh sách tài liệu RAG."""
try:
files = get_list_files()
return {"status": "success", "data": files}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/upload")
async def upload_file(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
"""Upload tài liệu mới và nhúng vào DB."""
try:
if not os.path.exists(DATA_DIR):
os.makedirs(DATA_DIR)
file_location = os.path.join(DATA_DIR, file.filename)
with open(file_location, "wb+") as file_object:
shutil.copyfileobj(file.file, file_object)
# Ném việc nặng (đọc file, embedding) xuống background
background_tasks.add_task(process_new_file, file.filename)
# Xử lý nhúng vào ChromaDB
success = process_new_file(file.filename)
if success:
return {"status": "success", "message": f"Đã upload và huấn luyện thành công tài liệu {file.filename}"}
else:
raise HTTPException(status_code=500, detail="Không thể xử lý dữ liệu của file này.")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
import uvicorn
# Trên Hugging Face Spaces mặc định yêu cầu dùng port 7860
uvicorn.run("main:app", host="0.0.0.0", port=7860)