| """ |
| FastAPI main application for GAKR AI Chatbot Platform |
| """ |
| import os |
| import sys |
| import json |
| import asyncio |
| import time |
| import re |
| from uuid import uuid4 |
| from datetime import timedelta |
| from typing import Dict, List, Optional |
| from contextlib import asynccontextmanager |
| from pathlib import Path |
|
|
| |
# Put this file's directory first on sys.path so the sibling modules imported below
# (config, auth, database, ...) resolve regardless of the working directory.
sys.path.insert(0, os.path.split(os.path.abspath(__file__))[0])
|
|
| from fastapi import FastAPI, HTTPException, Depends, File, UploadFile, Form, Request |
| from fastapi.responses import StreamingResponse, JSONResponse, FileResponse |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials |
| from pydantic import BaseModel, Field |
|
|
| |
| from config import settings, DEFAULT_SYSTEM_PROMPT |
| from auth import ( |
| authenticate_user, create_access_token, decode_token, |
| register_user, get_user, update_user_settings, get_user_settings, |
| change_password, ACCESS_TOKEN_EXPIRE_MINUTES |
| ) |
| from database import ( |
| get_or_create_user, create_conversation, get_conversations, |
| get_conversation, get_messages, add_message, delete_conversation, |
| update_conversation_title, get_user_stats, get_user_by_username, |
| append_to_last_assistant_message, append_to_assistant_message |
| ) |
| from file_processor import process_file, format_files_for_prompt, save_uploaded_file |
| from model_manager import model_manager |
| from tool_client import tool_client |
|
|
| |
# Bearer-token extractor. auto_error=False makes a missing header yield None, so
# get_current_user can answer with its own 401 messages.
security = HTTPBearer(auto_error=False)
# request_id -> {"event": asyncio.Event, "username": str, "conversation_id": int}.
# Populated by /chat/stream for the lifetime of a stream and consulted by /chat/stop.
active_chat_stop_requests: Dict[str, dict] = dict()
# Filler that models tend to emit when asked to resume a reply. Every pattern is anchored
# at the start of the text. The `\b` ensures only whole words match, so "Surely ..." is
# left intact. The trailing separator class also swallows commas ("Sure, ...").
CONTINUATION_PREFIX_PATTERNS = [
    re.compile(
        r"^\s*(?:let(?:'|\u2019)?s|let us)\s+continue(?:\s+from\s+where\s+we\s+left\s+off)?\b[.,!:\-\s]*",
        re.IGNORECASE,
    ),
    re.compile(
        r"^\s*continuing(?:\s+from\s+where\s+we\s+left\s+off)?\b[.,!:\-\s]*",
        re.IGNORECASE,
    ),
    re.compile(r"^\s*sure\b[.,!:\-\s]*", re.IGNORECASE),
]
|
|
| |
# Registration payload: username of 3-30 characters, password of at least 6 characters.
class UserRegister(BaseModel):
    username: str = Field(..., max_length=30, min_length=3)
    password: str = Field(..., min_length=6)
|
|
# Login payload; both fields are required and checked by authenticate_user.
class UserLogin(BaseModel):
    username: str = Field(...)
    password: str = Field(...)
|
|
# JSON chat payload; generation parameters default to the server-wide settings.
class ChatRequest(BaseModel):
    message: str
    conversation_id: Optional[int] = Field(default=None)
    temperature: Optional[float] = Field(default=settings.TEMPERATURE)
    max_tokens: Optional[int] = Field(default=settings.MAX_TOKENS)
|
|
# Partial settings update: fields left as None are not changed. Range clamping for
# temperature/max_tokens is done by the settings endpoint, not here.
class SettingsUpdate(BaseModel):
    temperature: Optional[float] = Field(default=None)
    max_tokens: Optional[int] = Field(default=None)
    system_prompt: Optional[str] = Field(default=None, max_length=4000)
    theme: Optional[str] = Field(default=None)
|
|
# New conversation title: 1-100 characters.
class TitleUpdate(BaseModel):
    title: str = Field(..., max_length=100, min_length=1)
|
|
def log_chat_status(
    stage: str,
    username: str,
    conversation_id: Optional[int],
    **details
):
    """Print one chat-pipeline event to stdout as ``[CHAT] <json>``.

    The JSON object starts with ``stage``, ``user`` and ``conversation_id``, followed by any
    extra keyword ``details``. Values JSON cannot encode are rendered with ``str`` and
    non-ASCII text is kept as-is.
    """
    record = dict(stage=stage, user=username, conversation_id=conversation_id)
    record.update(details)
    encoded = json.dumps(record, ensure_ascii=False, default=str)
    print("[CHAT] " + encoded)
|
|
def ensure_db_user(username: str) -> dict:
    """Return the chat-DB row for ``username``, creating it on first use.

    Looks the user up; if absent, calls ``get_or_create_user`` once and looks again.

    Raises:
        HTTPException(500): the row still cannot be found after creation.
    """
    for attempt in range(2):
        row = get_user_by_username(username)
        if row:
            return row
        if attempt == 0:
            get_or_create_user(username)
    raise HTTPException(status_code=500, detail="Failed to initialize user in database")
|
|
def strip_continuation_prefix(text: str, patterns: Optional[List["re.Pattern"]] = None) -> str:
    """Remove leading continuation filler ("Sure!", "Let's continue", ...) from model output.

    Args:
        text: Model output; ``None`` is treated as an empty string.
        patterns: Compiled start-anchored patterns to strip. Defaults to
            ``CONTINUATION_PREFIX_PATTERNS``.

    Returns:
        ``text`` with matching prefixes removed. Models often stack fillers
        ("Sure! Let's continue: ..."), so the patterns are re-applied until none of them
        changes the text. This terminates because each change removes at least one
        character.
    """
    active_patterns = CONTINUATION_PREFIX_PATTERNS if patterns is None else patterns
    value = text or ""
    while True:
        previous = value
        for pattern in active_patterns:
            value = pattern.sub("", value, count=1)
        if value == previous:
            return value
|
|
def build_continuation_prompt(
    chat_history: List[dict],
    file_content: str,
    user_instructions: str,
    assistant_prefix: str
) -> str:
    """Assemble a plain-text prompt asking the model to resume an interrupted reply.

    Sections, joined with newlines:
      1. ``SYSTEM:`` + default prompt, continuation rules and optional custom instructions
         (each separated by a blank line);
      2. the formatted file context, if any;
      3. a ``--- Conversation History ---`` transcript of user/assistant turns, if any
         (other roles are skipped);
      4. ``ASSISTANT: <assistant_prefix>`` so generation continues mid-reply.
    """
    continuation_rules = (
        "## Continuation Mode\n"
        "- Continue an interrupted assistant response.\n"
        "- Output only the missing continuation text after the provided prefix.\n"
        "- Do not restart or repeat prior text.\n"
        "- Do not prepend phrases like 'Let's continue' or 'Sure'."
    )
    sections = [DEFAULT_SYSTEM_PROMPT.strip(), continuation_rules]
    if user_instructions:
        sections.append(f"## Custom Instructions\n{user_instructions.strip()}")

    lines = ["SYSTEM: " + "\n\n".join(sections)]
    if file_content:
        lines.append(f"\n{file_content}")

    if chat_history:
        lines.append("\n--- Conversation History ---")
        for entry in chat_history:
            role = entry.get("role", "user")
            if role in ("user", "assistant"):
                lines.append(f"{role.upper()}: {entry.get('content', '')}")

    lines.append(f"\nASSISTANT: {assistant_prefix}")
    return "\n".join(lines)
|
|
| |
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency: resolve the bearer token to the auth user record.

    Also makes sure the user has a row in the chat database.

    Raises:
        HTTPException(401): no credentials, undecodable token, token without ``sub``,
            or unknown user (each with its own detail message).
    """
    def unauthorized(detail: str) -> HTTPException:
        return HTTPException(status_code=401, detail=detail)

    if not credentials:
        raise unauthorized("Not authenticated")

    claims = decode_token(credentials.credentials)
    if not claims:
        raise unauthorized("Invalid or expired token")

    subject = claims.get("sub")
    if not subject:
        raise unauthorized("Invalid token")

    account = get_user(subject)
    if not account:
        raise unauthorized("User not found")

    ensure_db_user(subject)
    return account
|
|
| |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: probe the model API and start tools, then clean up.

    Startup failures are reported but never abort startup; the server runs with inference
    and/or tools unavailable. Shutdown runs in ``finally`` blocks, so the tool client and
    the model are released even if the application exits with an error. The model is also
    unloaded if the tool shutdown itself raises.
    """
    print("=" * 50)
    print(f"Starting {settings.APP_NAME}")
    print("=" * 50)

    if model_manager.is_available:
        print(f"Checking NVIDIA API at {model_manager.nvidia_base_url}...")
        if model_manager.load_model():
            print(f"NVIDIA API connected: model={model_manager.get_model_info().get('model_name', 'unknown')}")
        else:
            print(f"Warning: NVIDIA API not reachable: {model_manager.last_error}")
            print(f"NVIDIA Base URL: {model_manager.nvidia_base_url}")
            print("The server will start but model inference will be unavailable until the API is reachable.")
    else:
        print("Warning: NVIDIA_API_KEY not configured. Model inference disabled.")

    print("Initializing web_search tool...")
    tools_ok = await tool_client.initialize()
    if tools_ok:
        tool_names = ", ".join(tool_client.get_tool_names()) or "none"
        print(f"Tools ready: {len(tool_client.tools)} tool(s) [{tool_names}]")
    else:
        print(f"Tool init failed: {tool_client.init_error}")
        print("Server will run without tool support.")

    try:
        yield
    finally:
        print("Shutting down...")
        try:
            await tool_client.shutdown()
        finally:
            model_manager.unload_model()
|
|
| |
# The ASGI application; startup/shutdown work lives in `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title=settings.APP_NAME,
    version="1.0.0",
    description="AI Chatbot Platform using NVIDIA API inference",
)

# CORS origins and credential policy come from configuration; all methods/headers allowed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=settings.CORS_ALLOW_CREDENTIALS,
    allow_origins=settings.CORS_ORIGINS,
)

# Prefix shared by every JSON API route below.
API_BASE_ENDPOINT = settings.API_BASE_ENDPOINT
frontend_dir = Path(settings.FRONTEND_DIR)
# Static pages are served only when enabled in settings AND the directory actually exists.
serve_frontend = settings.SERVE_FRONTEND and frontend_dir.exists()
|
|
| |
|
|
@app.get("/")
async def root():
    """Serve ``index.html`` when the frontend is enabled; otherwise describe the backend."""
    if not serve_frontend:
        return {
            "service": settings.APP_NAME,
            "mode": "backend-only",
            "frontend_url": settings.FRONTEND_URL,
            "api_base_endpoint": API_BASE_ENDPOINT,
            "health_url": f"{API_BASE_ENDPOINT}/health",
            "docs_url": app.docs_url,
        }
    return FileResponse(frontend_dir / "index.html")
|
|
def _frontend_asset_response(subdir: str, file_path: str) -> FileResponse:
    """Serve ``frontend_dir/subdir/file_path`` without letting the path escape ``subdir``.

    The ``{file_path:path}`` route parameter may contain ``..`` segments or be absolute,
    so the target is resolved and must lie inside the resolved asset directory.

    Raises:
        HTTPException(404): the path escapes the asset directory or is not a regular file.
    """
    base_dir = (frontend_dir / subdir).resolve()
    target = (base_dir / file_path).resolve()
    try:
        target.relative_to(base_dir)
    except ValueError:
        raise HTTPException(status_code=404, detail="File not found") from None
    if not target.is_file():
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(target)


if serve_frontend:
    @app.get("/login")
    async def login_page():
        """Serve login page"""
        return FileResponse(frontend_dir / "login.html")

    @app.get("/register")
    async def register_page():
        """Serve register page"""
        return FileResponse(frontend_dir / "register.html")

    @app.get("/profile")
    async def profile_page():
        """Serve profile page"""
        return FileResponse(frontend_dir / "profile.html")

    @app.get("/css/{file_path:path}")
    async def serve_css(file_path: str):
        """Serve a file from ``frontend_dir/css`` (path-traversal safe)."""
        return _frontend_asset_response("css", file_path)

    @app.get("/js/{file_path:path}")
    async def serve_js(file_path: str):
        """Serve a file from ``frontend_dir/js`` (path-traversal safe)."""
        return _frontend_asset_response("js", file_path)
|
|
| |
|
|
| |
@app.post(f"{API_BASE_ENDPOINT}/auth/register")
async def api_register(user_data: UserRegister):
    """Register a new account and return a bearer token for it.

    The account is also mirrored into the chat database so conversation endpoints
    work immediately.

    Raises:
        HTTPException(400): ``register_user`` rejected the data with ``ValueError``.
        HTTPException(500): any other failure during registration.
    """
    try:
        user = register_user(user_data.username, user_data.password)
        # The returned DB id is not needed here; only the row's existence matters.
        get_or_create_user(user_data.username)

        access_token = create_access_token(
            data={"sub": user["username"]},
            expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
        )

        return {
            "success": True,
            "message": "User registered successfully",
            "access_token": access_token,
            "token_type": "bearer",
            "username": user["username"],
        }
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Registration failed: {e}") from e
|
|
@app.post(f"{API_BASE_ENDPOINT}/auth/login")
async def api_login(user_data: UserLogin):
    """Verify credentials and issue a bearer token.

    Raises:
        HTTPException(401): the username/password pair is not valid.
    """
    user = authenticate_user(user_data.username, user_data.password)
    if not user:
        raise HTTPException(status_code=401, detail="Invalid username or password")

    name = user["username"]
    # Make sure the chat DB has a row for this account before any chat calls.
    get_or_create_user(name)

    token = create_access_token(
        data={"sub": name},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {
        "success": True,
        "access_token": token,
        "token_type": "bearer",
        "username": name,
    }
|
|
@app.post(f"{API_BASE_ENDPOINT}/auth/logout")
async def api_logout(current_user: dict = Depends(get_current_user)):
    """Acknowledge a logout for an authenticated caller.

    Nothing is revoked server-side; the client is expected to discard its token.
    """
    return dict(success=True, message="Logged out successfully")
|
|
@app.get(f"{API_BASE_ENDPOINT}/auth/me")
async def api_me(current_user: dict = Depends(get_current_user)):
    """Return the authenticated user's name, creation time and saved settings."""
    name = current_user["username"]
    return {
        "username": name,
        "created_at": current_user.get("created_at"),
        "settings": get_user_settings(name),
    }
|
|
| |
@app.get(f"{API_BASE_ENDPOINT}/user/profile")
async def api_profile(current_user: dict = Depends(get_current_user)):
    """Return the user's profile: identity, saved settings and conversation/message stats."""
    name = current_user["username"]

    row = ensure_db_user(name)
    if row:
        usage = get_user_stats(row["id"])
    else:
        usage = {"total_conversations": 0, "total_messages": 0}

    return {
        "username": name,
        "created_at": current_user.get("created_at"),
        "settings": get_user_settings(name),
        "stats": usage,
    }
|
|
@app.put(f"{API_BASE_ENDPOINT}/user/settings")
async def api_update_settings(
    settings_update: SettingsUpdate,
    current_user: dict = Depends(get_current_user)
):
    """Store every non-None field of ``settings_update`` for the current user.

    ``temperature`` is clamped to [0.0, 2.0] and ``max_tokens`` to [1, model limit];
    ``system_prompt`` and ``theme`` are stored unchanged.

    Raises:
        HTTPException(400): the settings store reported failure.
    """
    username = current_user["username"]
    changes = {}

    if settings_update.temperature is not None:
        changes["temperature"] = max(0.0, min(2.0, settings_update.temperature))

    requested_tokens = settings_update.max_tokens
    if requested_tokens is not None:
        ceiling = model_manager.get_max_generation_tokens_limit()
        changes["max_tokens"] = max(1, min(ceiling, requested_tokens))

    for field_name in ("system_prompt", "theme"):
        value = getattr(settings_update, field_name)
        if value is not None:
            changes[field_name] = value

    if not update_user_settings(username, changes):
        raise HTTPException(status_code=400, detail="Failed to update settings")
    return {"success": True, "settings": get_user_settings(username)}
|
|
@app.post(f"{API_BASE_ENDPOINT}/user/change-password")
async def api_change_password(
    old_password: str = Form(...),
    new_password: str = Form(...),
    current_user: dict = Depends(get_current_user)
):
    """Change the current user's password (form-encoded old/new password).

    Raises:
        HTTPException(400): ``change_password`` raised ``ValueError`` (its message is
            returned) or reported that the old password did not match.
    """
    try:
        changed = change_password(current_user["username"], old_password, new_password)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    if not changed:
        raise HTTPException(status_code=400, detail="Invalid old password")
    return {"success": True, "message": "Password changed successfully"}
|
|
| |
@app.get(f"{API_BASE_ENDPOINT}/conversations")
async def api_get_conversations(current_user: dict = Depends(get_current_user)):
    """List the conversations owned by the current user."""
    owner = ensure_db_user(current_user["username"])
    return {"conversations": get_conversations(owner["id"])}
|
|
@app.post(f"{API_BASE_ENDPOINT}/conversations")
async def api_create_conversation(current_user: dict = Depends(get_current_user)):
    """Create an empty conversation titled "New Chat" and return its id."""
    owner = ensure_db_user(current_user["username"])
    placeholder_title = "New Chat"
    new_id = create_conversation(owner["id"], placeholder_title)
    return {
        "success": True,
        "conversation_id": new_id,
        "title": placeholder_title,
    }
|
|
@app.get(f"{API_BASE_ENDPOINT}/conversations/{{conversation_id}}/messages")
async def api_get_messages(
    conversation_id: int,
    current_user: dict = Depends(get_current_user)
):
    """Return a conversation and its messages, provided the current user owns it.

    Raises:
        HTTPException(404): no conversation with this id belongs to the user.
    """
    owner = ensure_db_user(current_user["username"])

    conversation = get_conversation(conversation_id, owner["id"])
    if not conversation:
        raise HTTPException(status_code=404, detail="Conversation not found")

    return {"conversation": conversation, "messages": get_messages(conversation_id)}
|
|
@app.put(f"{API_BASE_ENDPOINT}/conversations/{{conversation_id}}/title")
async def api_update_title(
    conversation_id: int,
    title_update: TitleUpdate,
    current_user: dict = Depends(get_current_user)
):
    """Rename one of the current user's conversations.

    Raises:
        HTTPException(400): the database layer reported that nothing was updated.
    """
    owner = ensure_db_user(current_user["username"])
    new_title = title_update.title

    if not update_conversation_title(conversation_id, owner["id"], new_title):
        raise HTTPException(status_code=400, detail="Failed to update title")
    return {"success": True, "title": new_title}
|
|
@app.delete(f"{API_BASE_ENDPOINT}/conversations/{{conversation_id}}")
async def api_delete_conversation(
    conversation_id: int,
    current_user: dict = Depends(get_current_user)
):
    """Delete one of the current user's conversations.

    Raises:
        HTTPException(400): the database layer reported that nothing was deleted.
    """
    owner = ensure_db_user(current_user["username"])

    if not delete_conversation(conversation_id, owner["id"]):
        raise HTTPException(status_code=400, detail="Failed to delete conversation")
    return {"success": True, "message": "Conversation deleted"}
|
|
| |
@app.post(f"{API_BASE_ENDPOINT}/chat/stop")
async def api_chat_stop(
    request_id: str = Form(...),
    current_user: dict = Depends(get_current_user)
):
    """Signal an in-flight chat stream (registered under ``request_id``) to stop.

    Unknown ids, or entries without a stop event, yield ``"stopped": False``.

    Raises:
        HTTPException(403): the stream belongs to a different user.
    """
    username = current_user["username"]
    entry = active_chat_stop_requests.get(request_id)
    if not entry:
        return {"success": True, "request_id": request_id, "stopped": False}

    if entry.get("username") != username:
        raise HTTPException(status_code=403, detail="Not authorized to stop this request")

    event = entry.get("event")
    if event is None:
        return {"success": True, "request_id": request_id, "stopped": False}

    event.set()
    log_chat_status(
        stage="stop_requested",
        username=username,
        conversation_id=entry.get("conversation_id"),
        request_id=request_id
    )
    return {"success": True, "request_id": request_id, "stopped": True}
|
|
def _conversation_title(message: str) -> str:
    """Derive a conversation title from a user message: first 50 chars, ellipsized."""
    return message[:50] + "..." if len(message) > 50 else message


def _sse_headers(conversation_id: int, request_id: str) -> Dict[str, str]:
    """Response headers shared by every chat event stream."""
    return {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Conversation-Id": str(conversation_id),
        "X-Request-Id": request_id,
        # Ask nginx-style reverse proxies not to buffer the event stream.
        "X-Accel-Buffering": "no",
    }


def _sse_error_response(
    *,
    stage: str,
    username: str,
    conversation_id: int,
    request_id: str,
    client_error: str,
    **log_details,
) -> StreamingResponse:
    """Return an event stream that logs ``stage`` and then emits an error event and ``done``."""

    async def error_stream():
        log_chat_status(
            stage=stage,
            username=username,
            conversation_id=conversation_id,
            request_id=request_id,
            **log_details,
        )
        yield f"data: {json.dumps({'error': client_error})}\n\n"
        yield f"data: {json.dumps({'done': True})}\n\n"

    return StreamingResponse(
        error_stream(),
        media_type="text/event-stream",
        headers=_sse_headers(conversation_id, request_id),
    )


async def _process_uploaded_files(files: List[UploadFile]):
    """Save and extract text from every upload that has a filename.

    Returns:
        ``(file_results, file_text_chars, saved_file_count)``. Saving is best-effort:
        a failed save is recorded on the result as ``save_error`` and the file is still
        processed so its content reaches the prompt.
    """
    file_results = []
    file_text_chars = 0
    saved_file_count = 0
    for upload in files or []:
        if not upload.filename:
            continue
        content = await upload.read()
        saved_path = None
        save_error = None
        try:
            saved_path = save_uploaded_file(content, upload.filename)
        except Exception as exc:
            save_error = str(exc)

        result = process_file(content, upload.filename)
        result["saved_permanently"] = saved_path is not None
        if saved_path is not None:
            result["saved_path"] = str(saved_path)
            result["saved_filename"] = saved_path.name
            saved_file_count += 1
        if save_error:
            result["save_error"] = save_error
        file_results.append(result)
        file_text_chars += len(result.get("content") or "")
    return file_results, file_text_chars, saved_file_count


def _persist_assistant_response(
    conversation_id: int,
    response_text: str,
    continuation_mode: bool,
    continuation_message_id: Optional[int],
) -> Optional[int]:
    """Store generated text and return the id of the assistant message that holds it.

    In continuation mode the text is appended, trying in order:
    ``continuation_message_id``, then the latest assistant message, then a new message.
    Otherwise a new assistant message is always created.
    """
    if not continuation_mode:
        return add_message(conversation_id, "assistant", response_text)

    if continuation_message_id is not None and append_to_assistant_message(
        conversation_id, continuation_message_id, response_text
    ):
        return continuation_message_id

    if append_to_last_assistant_message(conversation_id, response_text):
        for msg in reversed(get_messages(conversation_id, limit=None)):
            if msg.get("role") == "assistant":
                return msg.get("id")
        return None

    return add_message(conversation_id, "assistant", response_text)


@app.post(f"{API_BASE_ENDPOINT}/chat/stream")
async def api_chat_stream(
    request: Request,
    message: str = Form(...),
    conversation_id: Optional[int] = Form(None),
    temperature: float = Form(settings.TEMPERATURE),
    max_tokens: int = Form(settings.MAX_TOKENS),
    request_id: Optional[str] = Form(None),
    persist_user_message: bool = Form(True),
    continuation_mode: bool = Form(False),
    continuation_prefix: Optional[str] = Form(None),
    continuation_message_id: Optional[int] = Form(None),
    files: List[UploadFile] = File(default=[]),
    current_user: dict = Depends(get_current_user)
):
    """Stream a chat completion as server-sent events (``data: {...}`` lines).

    Pipeline:
      * Resolve generation settings. Form values equal to the server defaults defer to
        the user's saved settings. Temperature is clamped to [0.0, 2.0] (as in the
        settings endpoint) and max_tokens to [1, model limit].
      * Load or create the conversation and process uploads.
      * Build a normal prompt, or a continuation prompt that resumes an interrupted
        assistant reply.
      * Optionally persist the user message; the first message titles the conversation.
      * Stream model chunks. The reply is stored afterwards, and a final ``done`` event
        carries ``assistant_message_id``.

    The stream stops on ``/chat/stop`` (using the id in ``X-Request-Id``) or when the
    client disconnects.

    Raises:
        HTTPException(404): ``conversation_id`` does not belong to the user.
    """
    request_start = time.perf_counter()
    username = current_user["username"]
    # A missing or whitespace-only request_id gets a fresh random id.
    request_id = (request_id or "").strip() or uuid4().hex

    user_settings = get_user_settings(username)
    temp = temperature
    if temperature == settings.TEMPERATURE and user_settings.get("temperature") is not None:
        temp = user_settings["temperature"]
    temp = max(0.0, min(2.0, temp))

    tokens = max_tokens
    if max_tokens == settings.MAX_TOKENS and user_settings.get("max_tokens") is not None:
        tokens = user_settings["max_tokens"]
    tokens = max(1, min(tokens, model_manager.get_max_generation_tokens_limit()))
    user_instructions = user_settings.get("system_prompt") or ""
    log_chat_status(
        stage="request_received",
        username=username,
        conversation_id=conversation_id,
        request_id=request_id,
        user_query_chars=len(message),
        file_count=len([f for f in files if getattr(f, "filename", None)]),
        user_instruction_chars=len(user_instructions),
        requested_max_tokens=max_tokens,
        resolved_initial_max_tokens=tokens,
        temperature=temp,
        continuation_mode=continuation_mode
    )

    db_user = ensure_db_user(username)
    if not conversation_id:
        conversation_id = create_conversation(db_user["id"], _conversation_title(message))
    elif not get_conversation(conversation_id, db_user["id"]):
        raise HTTPException(status_code=404, detail="Conversation not found")
    log_chat_status(
        stage="conversation_ready",
        username=username,
        conversation_id=conversation_id,
        request_id=request_id
    )

    file_results, file_text_chars, saved_file_count = await _process_uploaded_files(files)
    if files:
        log_chat_status(
            stage="files_processed",
            username=username,
            conversation_id=conversation_id,
            request_id=request_id,
            file_count=len(file_results),
            saved_file_count=saved_file_count,
            file_text_chars=file_text_chars
        )
    file_content = format_files_for_prompt(file_results)

    all_messages = get_messages(conversation_id, limit=None)
    chat_history = [
        {"role": msg["role"], "content": msg["content"] or ""}
        for msg in all_messages
        if msg.get("role") in {"user", "assistant"}
    ]
    history_chars = sum(len(msg["content"]) for msg in chat_history)
    log_chat_status(
        stage="history_loaded",
        username=username,
        conversation_id=conversation_id,
        request_id=request_id,
        history_messages_total=len(chat_history),
        history_chars_total=history_chars
    )

    continuation_prefix_text = (continuation_prefix or "").strip()
    if continuation_mode and not continuation_prefix_text:
        # No prefix supplied: resume the most recent non-empty assistant reply.
        for msg in reversed(chat_history):
            if msg.get("role") == "assistant" and msg.get("content", "").strip():
                continuation_prefix_text = msg["content"]
                break

    if continuation_mode and continuation_prefix_text and chat_history:
        last_msg = chat_history[-1]
        last_content = last_msg.get("content", "")
        # The prefix is replayed after "ASSISTANT:", so drop the overlapping history copy.
        if last_msg.get("role") == "assistant" and (
            last_content.endswith(continuation_prefix_text)
            or continuation_prefix_text.endswith(last_content)
        ):
            chat_history = chat_history[:-1]

    effective_continuation_mode = bool(continuation_mode and continuation_prefix_text)

    if effective_continuation_mode:
        prompt = build_continuation_prompt(
            chat_history=chat_history,
            file_content=file_content,
            user_instructions=user_instructions,
            assistant_prefix=continuation_prefix_text
        )
        # build_continuation_prompt uses the whole history without truncation.
        prompt_meta = {
            "history_messages_total": len(chat_history),
            "history_messages_used": len(chat_history),
            "truncated": False
        }
    else:
        prompt = model_manager.build_prompt(
            query=message,
            history=chat_history,
            file_content=file_content,
            custom_instructions=user_instructions,
        )
        prompt_meta = model_manager.last_prompt_meta
    # Fit the generation budget to the prompt that will actually be sent, in both modes.
    tokens = model_manager.resolve_max_tokens(prompt, tokens)
    log_chat_status(
        stage="prompt_built",
        username=username,
        conversation_id=conversation_id,
        request_id=request_id,
        default_system_prompt_chars=len(DEFAULT_SYSTEM_PROMPT),
        user_instruction_chars=len(user_instructions),
        user_query_chars=len(message),
        file_content_chars=len(file_content),
        full_prompt_chars=len(prompt),
        full_prompt_tokens=model_manager.count_tokens(prompt),
        history_messages_used=prompt_meta.get("history_messages_used"),
        history_messages_total=prompt_meta.get("history_messages_total"),
        prompt_truncated=prompt_meta.get("truncated"),
        generation_max_tokens=tokens,
        continuation_mode=effective_continuation_mode,
        continuation_prefix_chars=len(continuation_prefix_text)
    )

    persist_user_message = bool(persist_user_message and not effective_continuation_mode)
    if persist_user_message:
        add_message(conversation_id, "user", message, file_results)
        # Title the conversation from its opening message only. Fetch one more row than
        # the threshold; with limit=2 the `<= 2` check would always pass and every
        # message would overwrite the title.
        if len(get_messages(conversation_id, limit=3)) <= 2:
            update_conversation_title(conversation_id, db_user["id"], _conversation_title(message))

    if not model_manager.is_available:
        runtime_error = model_manager.last_error or "NVIDIA API not available. Check NVIDIA_API_KEY configuration."
        return _sse_error_response(
            stage="model_unavailable",
            username=username,
            conversation_id=conversation_id,
            request_id=request_id,
            client_error=f"Model not available. {runtime_error}",
        )

    if not model_manager.is_loaded and not model_manager.load_model():
        error_message = model_manager.last_error or "Failed to load model. Check server logs."
        return _sse_error_response(
            stage="model_load_failed",
            username=username,
            conversation_id=conversation_id,
            request_id=request_id,
            client_error=error_message,
            error=error_message,
        )

    # A client-chosen id must never clobber another in-flight stream's stop handle.
    if request_id in active_chat_stop_requests:
        request_id = uuid4().hex
    stop_event = asyncio.Event()
    active_chat_stop_requests[request_id] = {
        "event": stop_event,
        "username": username,
        "conversation_id": conversation_id
    }

    async def monitor_disconnect():
        # Poll for a client disconnect and translate it into the shared stop signal.
        while not stop_event.is_set():
            if await request.is_disconnected():
                stop_event.set()
                log_chat_status(
                    stage="client_disconnected",
                    username=username,
                    conversation_id=conversation_id,
                    request_id=request_id
                )
                return
            await asyncio.sleep(0.05)

    async def generate():
        response_parts: List[str] = []
        generation_start = time.perf_counter()
        disconnect_task = asyncio.create_task(monitor_disconnect())
        stopped_by_model = False
        generation_error = None
        log_chat_status(
            stage="generation_started",
            username=username,
            conversation_id=conversation_id,
            request_id=request_id,
            generation_max_tokens=tokens
        )

        try:
            async for chunk in model_manager.generate_stream(
                prompt=prompt,
                temperature=temp,
                max_tokens=tokens,
                stop_event=stop_event
            ):
                if stop_event.is_set():
                    break

                data = json.loads(chunk)
                if "token" in data:
                    response_parts.append(data["token"])
                if "error" in data:
                    generation_error = data.get("error")
                    log_chat_status(
                        stage="generation_error",
                        username=username,
                        conversation_id=conversation_id,
                        request_id=request_id,
                        error=generation_error
                    )
                    yield f"data: {chunk}\n\n"
                    continue
                if data.get("stopped"):
                    stopped_by_model = True
                    continue
                if data.get("done"):
                    # The terminal "done" event is emitted below, after the reply is stored.
                    continue

                yield f"data: {chunk}\n\n"
                # Give the disconnect monitor and stop requests a chance to run.
                await asyncio.sleep(0)

            full_response = "".join(response_parts)
            response_to_store = full_response
            if full_response and effective_continuation_mode:
                response_to_store = strip_continuation_prefix(full_response)

            assistant_message_id = None
            if response_to_store:
                assistant_message_id = _persist_assistant_response(
                    conversation_id,
                    response_to_store,
                    effective_continuation_mode,
                    continuation_message_id,
                )

            generation_ms = int((time.perf_counter() - generation_start) * 1000)
            total_ms = int((time.perf_counter() - request_start) * 1000)

            if stop_event.is_set() or stopped_by_model:
                log_chat_status(
                    stage="generation_stopped",
                    username=username,
                    conversation_id=conversation_id,
                    request_id=request_id,
                    response_chars=len(full_response),
                    generation_ms=generation_ms,
                    total_request_ms=total_ms
                )
                yield f"data: {json.dumps({'stopped': True, 'done': True, 'assistant_message_id': assistant_message_id, 'error': generation_error})}\n\n"
            else:
                log_chat_status(
                    stage="generation_completed",
                    username=username,
                    conversation_id=conversation_id,
                    request_id=request_id,
                    response_chars=len(full_response),
                    generation_ms=generation_ms,
                    total_request_ms=total_ms
                )
                yield f"data: {json.dumps({'done': True, 'assistant_message_id': assistant_message_id, 'error': generation_error})}\n\n"
        except asyncio.CancelledError:
            # Cancelled mid-stream (e.g. client gone): tell the model stream to stop too.
            stop_event.set()
            raise
        except Exception as exc:
            # Unexpected failure (bad chunk, DB error): log it and still terminate the
            # stream with a `done` event so the client does not hang.
            log_chat_status(
                stage="generation_failed",
                username=username,
                conversation_id=conversation_id,
                request_id=request_id,
                error=repr(exc)
            )
            yield f"data: {json.dumps({'error': 'Internal error while generating the response.', 'done': True})}\n\n"
        finally:
            disconnect_task.cancel()
            active_chat_stop_requests.pop(request_id, None)

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers=_sse_headers(conversation_id, request_id),
    )
|
|
| |
@app.get(f"{API_BASE_ENDPOINT}/model/info")
async def api_model_info():
    """Expose the model manager's description of the configured model."""
    info = model_manager.get_model_info()
    return info
|
|
| |
@app.get(f"{API_BASE_ENDPOINT}/health")
async def health_check():
    """Liveness probe plus model and tool availability flags."""
    return dict(
        status="healthy",
        model_available=model_manager.is_available,
        model_loaded=model_manager.is_loaded,
        nvidia_api_configured=bool(model_manager.nvidia_api_key),
        tools_available=tool_client.is_available,
        tools=tool_client.get_tool_names(),
    )
|
|
| |
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render an HTTPException as ``{"error": true, "detail": ...}`` with its status code.

    Headers attached to the exception (e.g. ``WWW-Authenticate``) are forwarded, as
    FastAPI's default handler does.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={"error": True, "detail": exc.detail},
        headers=getattr(exc, "headers", None),
    )
|
|
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Convert an unhandled exception into a JSON 500 response.

    The traceback is always printed server-side. The raw exception text is returned to
    the client only when ``settings.DEBUG`` is on, so internal error details are not
    exposed in production.
    """
    import traceback

    print(f"[ERROR] Unhandled exception on {request.method} {request.url.path}: {exc!r}")
    traceback.print_exception(type(exc), exc, exc.__traceback__)
    detail = str(exc) if settings.DEBUG else "Internal server error"
    return JSONResponse(
        status_code=500,
        content={"error": True, "detail": detail}
    )
|
|
if __name__ == "__main__":
    import uvicorn

    # ASCII-only banner so it prints on any console encoding; the right border is padded
    # to line up whatever the host/port length.
    banner_width = 62
    banner_rows = [
        "",
        "GAKR AI Chatbot Platform",
        "",
        f"Local URL: http://{settings.HOST}:{settings.PORT}",
        "",
    ]
    border = "+" + "-" * banner_width + "+"
    body = "\n".join(f"|   {row}".ljust(banner_width + 1) + "|" for row in banner_rows)
    print(f"\n{border}\n{body}\n{border}\n")

    uvicorn.run(
        "main:app",
        host=settings.HOST,
        port=settings.PORT,
        reload=settings.DEBUG,
        log_level="info"
    )
|
|