Spaces:
Sleeping
Sleeping
| from fastapi import APIRouter, Form, Request,Depends,HTTPException,BackgroundTasks | |
| from service import ChatService | |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
| from request import RequestChat | |
| from fastapi.requests import Request | |
| from fastapi.responses import JSONResponse | |
| import asyncio | |
router = APIRouter()

# Running chat-generation tasks, keyed by chat_history_id, so that stop_task
# can cancel them. Held in this process's memory only.
task_registry: dict[str, asyncio.Task] = {}
# Stop events, keyed by chat_history_id; stop_task sets the event of a
# running chat to ask it to stop. Held in this process's memory only.
stop_events: dict[str, asyncio.Event] = {}
| import decode_token | |
| from models.Database_Entity import StopSignal | |
| from models.Database_Entity import ChatHistory, User | |
| from repository.MySQL import UserRepository | |
class JWTBearer(HTTPBearer):
    """Bearer-token dependency that yields the raw token string.

    Wraps FastAPI's ``HTTPBearer`` and returns only the credential part of the
    ``Authorization`` header rather than the ``HTTPAuthorizationCredentials``
    object, so handlers can pass the token straight to the JWT helpers.
    """

    def __init__(self, auto_error: bool = True):
        super().__init__(auto_error=auto_error)

    async def __call__(self, request: Request) -> str:
        """Return the bearer token carried by *request*.

        Raises:
            HTTPException: 401 when the parent returns no credentials, or when
                the authentication scheme is not ``Bearer``.
        """
        credentials: HTTPAuthorizationCredentials = await super().__call__(request)
        if not credentials:
            raise HTTPException(status_code=401, detail="Invalid authorization code.")
        # RFC 7235 section 2.1: auth-scheme names are case-insensitive, so
        # "bearer <token>" must be accepted exactly like "Bearer <token>".
        if credentials.scheme.lower() != "bearer":
            raise HTTPException(status_code=401, detail="Invalid authentication scheme.")
        return credentials.credentials


# Shared dependency instance; route handlers use it via Depends(jwt_bearer).
jwt_bearer = JWTBearer()
async def get_user_chat_history(token: str = Depends(jwt_bearer)):
    """Return the chat histories that ChatService holds for the token's user.

    The user id is extracted from the bearer token. Any failure is printed
    and answered with a generic HTTP 500 JSON response.
    """
    try:
        owner_id = decode_token.JwtService.extract_user_id(token)
        histories = await ChatService.get_user_chat_history(owner_id)
    except Exception as exc:
        print("Lỗi khi gọi user history chat:", exc)
        return JSONResponse(content={"error": "Internal server error"}, status_code=500)
    return histories
async def create_chat(token: str = Depends(jwt_bearer)):
    """Ask ChatService to create a new chat history for the token's user.

    Returns whatever ``ChatService.create_new_chat_history`` produces. Any
    failure is printed and answered with a generic HTTP 500 JSON response.
    """
    try:
        owner_id = decode_token.JwtService.extract_user_id(token)
        created = await ChatService.create_new_chat_history(owner_id)
    except Exception as exc:
        print("Lỗi khi gọi create new chat:", exc)
        return JSONResponse(content={"error": "Internal server error"}, status_code=500)
    else:
        return created
| from bson import ObjectId | |
async def question(
    request: RequestChat.ChatWithServer,
    background_tasks: BackgroundTasks,
    token: str = Depends(jwt_bearer),
):
    """Answer ``request.user_input`` inside a task that ``stop_task`` can stop.

    Steps:
        1. Register a fresh ``asyncio.Event`` in ``stop_events`` under
           ``request.chat_history_id``.
        2. Reset the chat's persisted ``StopSignal`` when the chat exists.
        3. Run ``ChatService.chat_with_user`` as a task registered in
           ``task_registry``, and wait for its result.

    ``background_tasks`` is part of the route signature but is not used here.

    Returns:
        The result of ``ChatService.chat_with_user``;
        ``{"status": "cancelled"}`` if the task was cancelled;
        ``{"error": <message>}`` if the service raised;
        or an HTTP 500 JSON response if setup failed.
    """
    chat_id = request.chat_history_id
    stop_event = None
    try:
        user_id_token = decode_token.JwtService.extract_user_id(token)
        user_role = decode_token.JwtService.extract_user_role(token)
        stop_event = asyncio.Event()
        stop_events[chat_id] = stop_event
        print(f"[CREATE] stop_event id: {id(stop_event)} for chat_id: {chat_id}")

        # Reset the persisted stop flag for this chat (stop_task sets it to True).
        chat_history = ChatHistory.objects(pk=ObjectId(chat_id)).first()
        if chat_history:
            signal = StopSignal.objects(chat_history=chat_history).first()
            if not signal:
                signal = StopSignal(chat_history=chat_history)
            signal.is_stopped = False
            signal.stopped_at = None
            signal.save()

        async def run_chat():
            try:
                return await ChatService.chat_with_user(
                    request.user_input,
                    user_id_token,
                    request.language,
                    user_role,
                    token,
                    chat_id,
                    stop_event,
                )
            except asyncio.CancelledError:
                print(f"🛑 Task {chat_id} was cancelled by asyncio.")
                return {"status": "cancelled"}
            except Exception as e:
                print(f"❌ Lỗi trong task {chat_id}:", e)
                return {"error": str(e)}
            finally:
                # A newer request for the same chat may have replaced these
                # entries. Remove only the ones belonging to this run, so the
                # newer run stays stoppable.
                if stop_events.get(chat_id) is stop_event:
                    del stop_events[chat_id]
                if task_registry.get(chat_id) is asyncio.current_task():
                    del task_registry[chat_id]

        task = asyncio.create_task(run_chat())
        task_registry[chat_id] = task
        return await task
    except Exception as e:
        print("Lỗi khi chạy:", e)
        # Setup failed before the task could clean up after itself. Drop the
        # stop event registered above so it does not leak.
        if stop_event is not None and stop_events.get(chat_id) is stop_event:
            del stop_events[chat_id]
        return JSONResponse(content={"error": "Internal server error"}, status_code=500)
| from datetime import datetime | |
async def stop_task(chat_history_id: str, token: str = Depends(jwt_bearer)):
    """Stop the running answer generation for *chat_history_id*.

    Validation:
        - The token's user id must be a positive integer.
        - The user must exist in MySQL.
        - ``UserRepository.getChatHistory`` must find the chat for that user.
        - The user must exist in MongoDB.

    Actions:
        - Set the chat's in-memory stop event, if registered.
        - Cancel the chat's registered task, if any.
        - Persist ``is_stopped=True`` and a UTC timestamp in the chat's
          ``StopSignal``.

    Raises:
        HTTPException: 400 when the user id is not a positive integer, or when
            the user or chat is not found in MySQL.
    """
    raw_user_id = decode_token.JwtService.extract_user_id(token)
    try:
        user_id = int(raw_user_id)
    except (TypeError, ValueError):
        # Previously a non-numeric id escaped as an unhandled 500.
        raise HTTPException(status_code=400, detail="Invalid user_id: must be a positive integer") from None
    if user_id <= 0:
        raise HTTPException(status_code=400, detail="Invalid user_id: must be a positive integer")

    check = UserRepository.getUserByUserId(user_id)
    if check is None:
        raise HTTPException(status_code=400, detail="User not found or has been deleted in MySQL")
    check_history_id = UserRepository.getChatHistory(user_id, chat_history_id)
    if check_history_id is None:
        raise HTTPException(status_code=400, detail="Chat not found or has been deleted in MySQL")
    user = User.objects(user_id=user_id).first()
    if not user:
        # NOTE(review): returned with HTTP 200, unlike the 400s above; kept
        # as-is for client compatibility.
        return {"error": "User not found or has been deleted in MongoDB"}

    event = stop_events.get(chat_history_id)
    task = task_registry.get(chat_history_id)
    print(f"🚨 Đã vào stop-task với chat_history_id = {chat_history_id}")
    print(f"🔎 stop_event: {event}, task: {task}")
    # Signal the in-memory stop event first.
    if event:
        print(f"🛑 Setting stop_event for {chat_history_id}")
        event.set()
    # Then cancel the running task outright.
    if task:
        print(f"🔪 Cancelling task for {chat_history_id}")
        task.cancel()

    # Persist the stop state in MongoDB.
    chat_history = ChatHistory.objects(pk=ObjectId(chat_history_id)).first()
    if chat_history:
        signal = StopSignal.objects(chat_history=chat_history).first()
        if not signal:
            signal = StopSignal(chat_history=chat_history)
        signal.is_stopped = True
        signal.stopped_at = datetime.utcnow()
        signal.save()
    return {"message": f"Stop signal sent for chat_history_id {chat_history_id}"}
async def regenerate(request: RequestChat.Regenerate, token: str = Depends(jwt_bearer)):
    """Regenerate an answer inside a task that ``stop_task`` can stop.

    Mirrors ``question``:
        1. Register a fresh ``asyncio.Event`` in ``stop_events`` under
           ``request.chat_id``.
        2. Reset the chat's persisted ``StopSignal`` when the chat exists.
        3. Run ``ChatService.regenerate`` as a task registered in
           ``task_registry``, and wait for its result.

    Returns:
        The result of ``ChatService.regenerate``;
        ``{"status": "cancelled"}`` if the task was cancelled;
        ``{"error": <message>}`` if the service raised;
        or an HTTP 500 JSON response if setup failed.
    """
    chat_id = request.chat_id
    stop_event = None
    try:
        user_id_token = decode_token.JwtService.extract_user_id(token)
        user_role = decode_token.JwtService.extract_user_role(token)
        stop_event = asyncio.Event()
        stop_events[chat_id] = stop_event
        print(f"[CREATE] stop_event id: {id(stop_event)} for chat_id: {chat_id}")

        # Reset the persisted stop flag for this chat (stop_task sets it to True).
        chat_history = ChatHistory.objects(pk=ObjectId(chat_id)).first()
        if chat_history:
            signal = StopSignal.objects(chat_history=chat_history).first()
            if not signal:
                signal = StopSignal(chat_history=chat_history)
            signal.is_stopped = False
            signal.stopped_at = None
            signal.save()

        async def run_chat():
            try:
                return await ChatService.regenerate(
                    request.question_new,
                    user_id_token,
                    request.languages,
                    user_role,
                    token,
                    request.chat_id,
                    stop_event,
                )
            except asyncio.CancelledError:
                print(f"🛑 Task {chat_id} was cancelled by asyncio.")
                return {"status": "cancelled"}
            except Exception as e:
                print(f"❌ Lỗi trong task {chat_id}:", e)
                return {"error": str(e)}
            finally:
                # Previously missing: without this the registries grew on
                # every regenerate. Remove only this run's entries, in case a
                # newer request for the same chat replaced them.
                if stop_events.get(chat_id) is stop_event:
                    del stop_events[chat_id]
                if task_registry.get(chat_id) is asyncio.current_task():
                    del task_registry[chat_id]

        task = asyncio.create_task(run_chat())
        task_registry[chat_id] = task
        return await task
    except Exception as e:
        print("Lỗi khi gọi regenerate:", e)
        # Setup failed before the task could clean up after itself. Drop the
        # stop event registered above so it does not leak.
        if stop_event is not None and stop_events.get(chat_id) is stop_event:
            del stop_events[chat_id]
        return JSONResponse(content={"error": "Internal server error"}, status_code=500)
async def update_chat_name(request: RequestChat.UpdateNameChat, token: str = Depends(jwt_bearer)):
    """Rename chat ``request.chat_id`` to ``request.name_chat`` via ChatService.

    The acting user id comes from the bearer token. Any failure is printed
    and answered with a generic HTTP 500 JSON response.
    """
    try:
        owner_id = decode_token.JwtService.extract_user_id(token)
        renamed = await ChatService.update_chat_name(request.chat_id, request.name_chat, owner_id)
    except Exception as exc:
        print("Lỗi khi gọi update:", exc)
        return JSONResponse(content={"error": "Internal server error"}, status_code=500)
    return renamed
async def delete_chat(request: RequestChat.DeleteChatRequest, token: str = Depends(jwt_bearer)):
    """Soft-delete chat ``request.chat_id`` via ``ChatService.soft_delete_chat``.

    The acting user id comes from the bearer token. Any failure is printed
    and answered with a generic HTTP 500 JSON response.
    """
    try:
        owner_id = decode_token.JwtService.extract_user_id(token)
        outcome = await ChatService.soft_delete_chat(request.chat_id, owner_id)
    except Exception as exc:
        print("Lỗi khi gọi deleted:", exc)
        return JSONResponse(content={"error": "Internal server error"}, status_code=500)
    else:
        return outcome
async def get_chat_details(chat_id: str, token: str = Depends(jwt_bearer)):
    """Return ChatService's details for chat *chat_id* and the token's user.

    Any failure is printed and answered with a generic HTTP 500 JSON response.
    """
    try:
        owner_id = decode_token.JwtService.extract_user_id(token)
        details = await ChatService.get_chat_details(chat_id, owner_id)
    except Exception as exc:
        print("Lỗi khi gọi list_detail_chat:", exc)
        return JSONResponse(content={"error": "Internal server error"}, status_code=500)
    return details