from fastapi import APIRouter, Form, Request, Depends, HTTPException, BackgroundTasks
from service import ChatService
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from request import RequestChat
from fastapi.requests import Request
from fastapi.responses import JSONResponse
import asyncio

import decode_token
from models.Database_Entity import StopSignal
from models.Database_Entity import ChatHistory, User
from repository.MySQL import UserRepository
from bson import ObjectId

router = APIRouter()

# Running tasks, keyed by chat_history_id.
task_registry: dict[str, asyncio.Task] = {}
# Per-chat stop events, used to stop each running task.
stop_events: dict[str, asyncio.Event] = {}


class JWTBearer(HTTPBearer):
    """Bearer-token dependency that yields the raw token string.

    Header parsing is delegated to ``HTTPBearer``; this subclass additionally
    rejects any scheme other than the exact string ``"Bearer"`` and any empty
    credentials result with HTTP 401.
    """

    def __init__(self, auto_error: bool = True):
        super().__init__(auto_error=auto_error)

    async def __call__(self, request: Request):
        """Return the token part of the Authorization header.

        Raises:
            HTTPException: 401 when no credentials are produced or the scheme
                is not ``"Bearer"``.
        """
        credentials: HTTPAuthorizationCredentials = await super().__call__(request)
        if not credentials:
            raise HTTPException(status_code=401, detail="Invalid authorization code.")
        if credentials.scheme != "Bearer":
            raise HTTPException(status_code=401, detail="Invalid authentication scheme.")
        return credentials.credentials


jwt_bearer = JWTBearer()


def _server_error(message: str, exc: Exception) -> JSONResponse:
    """Print ``message`` followed by ``exc`` and build a generic 500 JSON response."""
    print(message, exc)
    return JSONResponse(content={"error": "Internal server error"}, status_code=500)


@router.get("/chat/user_history")
async def get_user_chat_history(token: str = Depends(jwt_bearer)):
    """Return the chat history list of the user identified by the JWT.

    Any failure is printed and reported as a generic 500 JSON response.
    """
    try:
        owner_id = decode_token.JwtService.extract_user_id(token)
        return await ChatService.get_user_chat_history(owner_id)
    except Exception as exc:
        return _server_error("Lỗi khi gọi user history chat:", exc)


@router.post("/new_chat/create/")
async def create_chat(token: str = Depends(jwt_bearer)):
    """Create a new chat history for the user identified by the JWT.

    Any failure is printed and reported as a generic 500 JSON response.
    """
    try:
        owner_id = decode_token.JwtService.extract_user_id(token)
        return await ChatService.create_new_chat_history(owner_id)
    except Exception as exc:
        return _server_error("Lỗi khi gọi create new chat:", exc)
from datetime import datetime
from typing import Any, Awaitable, Callable


def _reset_stop_signal(chat_id: str) -> None:
    """Clear the persisted stop flag of a chat before a new generation starts.

    Looks up the ChatHistory by ``chat_id``. If it exists, its StopSignal
    (created when missing) is saved with ``is_stopped=False`` and
    ``stopped_at=None``. If no chat matches, nothing is written.
    ``ObjectId(chat_id)`` may raise for a malformed id; callers handle that.
    """
    chat_history = ChatHistory.objects(pk=ObjectId(chat_id)).first()
    if not chat_history:
        return
    signal = StopSignal.objects(chat_history=chat_history).first()
    if not signal:
        signal = StopSignal(chat_history=chat_history)
    signal.is_stopped = False
    signal.stopped_at = None
    signal.save()


def _release_chat_run(chat_id: str, stop_event: asyncio.Event, task: "asyncio.Task | None") -> None:
    """Remove this run's registry entries for ``chat_id``.

    An entry is removed only if it is still the one registered by this run.
    A newer request for the same chat may already have replaced it, and that
    newer run must stay stoppable.
    """
    if stop_events.get(chat_id) is stop_event:
        del stop_events[chat_id]
    if task is not None and task_registry.get(chat_id) is task:
        del task_registry[chat_id]


async def _run_stoppable(
    chat_id: str,
    start_chat: Callable[[asyncio.Event], Awaitable[Any]],
) -> Any:
    """Run ``start_chat(stop_event)`` as a registered, cancellable task.

    The fresh stop event and the task are stored in ``stop_events`` and
    ``task_registry`` under ``chat_id`` so that ``/stop-task`` can reach them.
    Both entries are released when the task finishes, however it ends.

    Returns:
        The awaited result of ``start_chat``. If the task is cancelled, returns
        ``{"status": "cancelled"}``. If ``start_chat`` raises, returns
        ``{"error": str(exc)}``.
    """
    stop_event = asyncio.Event()
    stop_events[chat_id] = stop_event
    print(f"[CREATE] stop_event id: {id(stop_event)} for chat_id: {chat_id}")

    async def run_chat():
        try:
            return await start_chat(stop_event)
        except asyncio.CancelledError:
            # Deliberately converted into a normal result, so the request
            # handler awaiting this task gets a payload instead of an error.
            print(f"🛑 Task {chat_id} was cancelled by asyncio.")
            return {"status": "cancelled"}
        except Exception as e:
            print(f"❌ Lỗi trong task {chat_id}:", e)
            return {"error": str(e)}
        finally:
            _release_chat_run(chat_id, stop_event, asyncio.current_task())

    # create_task only schedules run_chat; it cannot reach its `finally`
    # before the task is registered on the next line.
    task = asyncio.create_task(run_chat())
    task_registry[chat_id] = task
    return await task


@router.post("/question")
async def question(
    request: RequestChat.ChatWithServer,
    background_tasks: BackgroundTasks,  # unused; kept so the endpoint signature is unchanged
    token: str = Depends(jwt_bearer)
):
    """Answer a user question in a chat, as a task that /stop-task can stop.

    Resets the chat's persisted stop flag, then runs
    ``ChatService.chat_with_user`` with a per-request stop event.
    Unexpected errors outside the chat task are printed and reported as a
    generic 500 JSON response.
    """
    try:
        user_id_token = decode_token.JwtService.extract_user_id(token)
        user_role = decode_token.JwtService.extract_user_role(token)
        chat_id = request.chat_history_id

        # Reset before registering, so a failure here cannot leak a stop event.
        _reset_stop_signal(chat_id)

        return await _run_stoppable(
            chat_id,
            lambda stop_event: ChatService.chat_with_user(
                request.user_input,
                user_id_token,
                request.language,
                user_role,
                token,
                chat_id,
                stop_event,
            ),
        )
    except Exception as e:
        print("Lỗi khi chạy:", e)
        return JSONResponse(content={"error": "Internal server error"}, status_code=500)


@router.post("/stop-task/{chat_history_id}")
async def stop_task(chat_history_id: str, token: str = Depends(jwt_bearer)):
    """Stop the running generation for a chat and persist the stop flag.

    Validates the user id from the JWT, checks that the user exists in
    MySQL, and checks that ``UserRepository.getChatHistory(user_id,
    chat_history_id)`` finds the chat (presumably an ownership check; not
    verified here). It also requires the user to exist in MongoDB. It then
    sets the in-memory stop event, cancels the registered task (if any), and
    saves ``is_stopped=True`` with a UTC timestamp on the chat's StopSignal.

    Raises:
        HTTPException: 400 for an invalid user id, an unknown user, or an
            unknown chat in MySQL.
    """
    try:
        user_id = int(decode_token.JwtService.extract_user_id(token))
    except (TypeError, ValueError):
        # Previously a non-numeric claim surfaced as an unhandled 500.
        raise HTTPException(status_code=400, detail="Invalid user_id: must be a positive integer") from None
    if user_id <= 0:
        raise HTTPException(status_code=400, detail="Invalid user_id: must be a positive integer")

    check = UserRepository.getUserByUserId(user_id)
    if check is None:
        raise HTTPException(status_code=400, detail="User not found or has been deleted in MySQL")
    check_history_id = UserRepository.getChatHistory(user_id, chat_history_id)
    if check_history_id is None:
        raise HTTPException(status_code=400, detail="Chat not found or has been deleted in MySQL")
    user = User.objects(user_id=user_id).first()
    if not user:
        return {"error": "User not found or has been deleted in MongoDB"}

    event = stop_events.get(chat_history_id)
    task = task_registry.get(chat_history_id)
    print(f"🚨 Đã vào stop-task với chat_history_id = {chat_history_id}")
    print(f"🔎 stop_event: {event}, task: {task}")

    # Signal the in-memory stop event.
    if event:
        print(f"🛑 Setting stop_event for {chat_history_id}")
        event.set()
    # Cancel the running task.
    if task:
        print(f"🔪 Cancelling task for {chat_history_id}")
        task.cancel()

    # Persist the stop state in MongoDB.
    chat_history = ChatHistory.objects(pk=ObjectId(chat_history_id)).first()
    if chat_history:
        signal = StopSignal.objects(chat_history=chat_history).first()
        if not signal:
            signal = StopSignal(chat_history=chat_history)
        signal.is_stopped = True
        signal.stopped_at = datetime.utcnow()
        signal.save()
    return {"message": f"Stop signal sent for chat_history_id {chat_history_id}"}


@router.put("/regenerate")
async def regenerate(request: RequestChat.Regenerate, token: str = Depends(jwt_bearer)):
    """Regenerate an answer, as a task that /stop-task can stop.

    Same lifecycle as ``question``: reset the persisted stop flag, then run
    ``ChatService.regenerate`` with a per-request stop event. The registry
    entries are now released when the task ends; previously they leaked.
    """
    try:
        user_id_token = decode_token.JwtService.extract_user_id(token)
        user_role = decode_token.JwtService.extract_user_role(token)
        chat_id = request.chat_id

        _reset_stop_signal(chat_id)

        return await _run_stoppable(
            chat_id,
            lambda stop_event: ChatService.regenerate(
                request.question_new,
                user_id_token,
                request.languages,
                user_role,
                token,
                request.chat_id,
                stop_event,
            ),
        )
    except Exception as e:
        print("Lỗi khi gọi regenerate:", e)
        return JSONResponse(content={"error": "Internal server error"}, status_code=500)


@router.put("/update")
async def update_chat_name(request: RequestChat.UpdateNameChat, token: str = Depends(jwt_bearer)):
    """Rename a chat for the user identified by the JWT; errors become a 500 JSON response."""
    try:
        user_id_token = decode_token.JwtService.extract_user_id(token)
        updated_chat = await ChatService.update_chat_name(request.chat_id, request.name_chat, user_id_token)
        return updated_chat
    except Exception as e:
        print("Lỗi khi gọi update:", e)
        return JSONResponse(content={"error": "Internal server error"}, status_code=500)


@router.delete("/delete")
async def delete_chat(request: RequestChat.DeleteChatRequest, token: str = Depends(jwt_bearer)):
    """Soft-delete a chat for the user identified by the JWT; errors become a 500 JSON response."""
    try:
        user_id_token = decode_token.JwtService.extract_user_id(token)
        deleted_chat = await ChatService.soft_delete_chat(request.chat_id, user_id_token)
        return deleted_chat
    except Exception as e:
        print("Lỗi khi gọi deleted:", e)
        return JSONResponse(content={"error": "Internal server error"}, status_code=500)


@router.get("/list_detail_chat/{chat_id}")
async def get_chat_details(chat_id: str, token: str = Depends(jwt_bearer)):
    """Return the message details of a chat for the JWT user; errors become a 500 JSON response."""
    try:
        user_id_token = decode_token.JwtService.extract_user_id(token)
        return await ChatService.get_chat_details(chat_id, user_id_token)
    except Exception as e:
        print("Lỗi khi gọi list_detail_chat:", e)
        return JSONResponse(content={"error": "Internal server error"}, status_code=500)