Spaces:
Sleeping
Sleeping
| from fastapi.responses import StreamingResponse | |
| import asyncio | |
| from fastapi import FastAPI, HTTPException, UploadFile, File, BackgroundTasks | |
| from pydantic import BaseModel | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import os | |
| import shutil | |
| # Load biến môi trường nếu có file .env (chạy local) | |
| env_path = os.path.join(os.path.dirname(__file__), "..", ".env") | |
| if os.path.exists(env_path): | |
| from dotenv import load_dotenv | |
| load_dotenv(env_path) | |
| from agent import create_chatbot_agent | |
| from rag_tool import get_list_files, process_new_file, DATA_DIR | |
| app = FastAPI(title="AI Staff Assistant API") | |
| # Cấu hình CORS để frontend gọi được | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| # Khởi tạo agent | |
| print("Đang khởi tạo Agentic Chatbot...") | |
| agent = create_chatbot_agent() | |
| print("Khởi tạo hoàn tất!") | |
| class ChatRequest(BaseModel): | |
| session_id : str | |
| message: str | |
| class ChatResponse(BaseModel): | |
| reply: str | |
| def extract_text(content): | |
| if content is None: | |
| return "" | |
| if isinstance(content, str): | |
| return content | |
| if isinstance(content, dict): | |
| return content.get("text", "") | |
| if isinstance(content, list): | |
| parts = [] | |
| for item in content: | |
| if isinstance(item, str): | |
| parts.append(item) | |
| elif isinstance(item, dict): | |
| # Gemini hay trả dạng: | |
| # {'type': 'text', 'text': '...', 'extras': {...}} | |
| if "text" in item: | |
| parts.append(item["text"]) | |
| return "".join(parts) | |
| return str(content) | |
| async def chat_stream_endpoint(request: ChatRequest): | |
| """ | |
| Endpoint trả về kết quả dạng Stream. | |
| Vừa hiển thị trạng thái gọi Tool, vừa stream từng chữ (token) của câu trả lời. | |
| """ | |
| async def event_stream(): | |
| config = {"configurable": {"thread_id": request.session_id}} | |
| try: | |
| # Sử dụng astream_events v2 để theo dõi chính xác hành vi của Agent | |
| async for event in agent.astream_events( | |
| {"messages": [("user", request.message)]}, | |
| config=config, | |
| version="v2" | |
| ): | |
| kind = event["event"] | |
| # --- TRƯỜNG HỢP 1: Bắt đầu gọi Tools --- | |
| if kind == "on_tool_start": | |
| tool_name = event["name"] | |
| if tool_name == "restaurant_documents_search": | |
| # Remove markdown '***' if present | |
| yield "⏳ Đang lật sổ tay công thức...\n\n" | |
| elif "sql" in tool_name.lower(): | |
| yield "⏳ Đang tra cứu dữ liệu kho hệ thống...\n\n" | |
| # --- TRƯỜNG HỢP 2: Bắt đầu Stream chữ từ LLM --- | |
| elif kind == "on_chat_model_stream": | |
| import re | |
| chunk = event["data"]["chunk"] | |
| # Kiểm tra: Chỉ stream nếu chunk có nội dung văn bản | |
| # và KHÔNG PHẢI là chuỗi JSON đang gọi hàm (tool_calls) | |
| if chunk.content and not chunk.tool_calls: | |
| # Chỉ lấy phần text nếu là dict | |
| text = extract_text(chunk.content) | |
| if text: | |
| yield text | |
| except Exception as e: | |
| import traceback | |
| print(f"Lỗi Stream: {e}") | |
| traceback.print_exc() | |
| # Thông báo riêng cho lỗi model quá tải (503) | |
| if "503" in str(e) or "high demand" in str(e) or "UNAVAILABLE" in str(e): | |
| yield "\n\n❌ Xin lỗi, mô hình AI (Gemini) đang quá tải hoặc bị giới hạn. Vui lòng thử lại sau vài phút." | |
| else: | |
| yield "\n\n❌ Xin lỗi, có lỗi xảy ra trong quá trình kết nối hoặc xử lý." | |
| # Trả về kết nối dạng luồng (Server-Sent Events / Stream) | |
| return StreamingResponse(event_stream(), media_type="text/plain") | |
| async def get_files(): | |
| """Lấy danh sách tài liệu RAG.""" | |
| try: | |
| files = get_list_files() | |
| return {"status": "success", "data": files} | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def upload_file(background_tasks: BackgroundTasks, file: UploadFile = File(...)): | |
| """Upload tài liệu mới và nhúng vào DB.""" | |
| try: | |
| if not os.path.exists(DATA_DIR): | |
| os.makedirs(DATA_DIR) | |
| file_location = os.path.join(DATA_DIR, file.filename) | |
| with open(file_location, "wb+") as file_object: | |
| shutil.copyfileobj(file.file, file_object) | |
| # Ném việc nặng (đọc file, embedding) xuống background | |
| background_tasks.add_task(process_new_file, file.filename) | |
| # Xử lý nhúng vào ChromaDB | |
| success = process_new_file(file.filename) | |
| if success: | |
| return {"status": "success", "message": f"Đã upload và huấn luyện thành công tài liệu {file.filename}"} | |
| else: | |
| raise HTTPException(status_code=500, detail="Không thể xử lý dữ liệu của file này.") | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| if __name__ == "__main__": | |
| import uvicorn | |
| # Trên Hugging Face Spaces mặc định yêu cầu dùng port 7860 | |
| uvicorn.run("main:app", host="0.0.0.0", port=7860) | |