gakrchat1 / backend /main.py
extraplus's picture
Upload main.py
db60cfe verified
"""
FastAPI main application for GAKR AI Chatbot Platform
"""
import os
import sys
import json
import asyncio
import time
import re
from uuid import uuid4
from datetime import timedelta
from typing import Dict, List, Optional
from contextlib import asynccontextmanager
from pathlib import Path
# Add backend directory to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from fastapi import FastAPI, HTTPException, Depends, File, UploadFile, Form, Request
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field
# Import modules
from config import settings, DEFAULT_SYSTEM_PROMPT
from auth import (
authenticate_user, create_access_token, decode_token,
register_user, get_user, update_user_settings, get_user_settings,
change_password, ACCESS_TOKEN_EXPIRE_MINUTES
)
from database import (
get_or_create_user, create_conversation, get_conversations,
get_conversation, get_messages, add_message, delete_conversation,
update_conversation_title, get_user_stats, get_user_by_username,
append_to_last_assistant_message, append_to_assistant_message
)
from file_processor import process_file, format_files_for_prompt, save_uploaded_file
from model_manager import model_manager
from tool_client import tool_client
# Security
# Bearer-token dependency. It is created with auto_error=False, and
# get_current_user itself raises 401 when no credentials are supplied.
security = HTTPBearer(auto_error=False)
# Registry of in-flight chat streams, keyed by request_id. api_chat_stream
# stores a dict with the keys "event" (the stop event), "username" and
# "conversation_id"; api_chat_stop looks entries up here to signal a stop.
active_chat_stop_requests: Dict[str, dict] = {}
# Preambles a model tends to emit when asked to resume an interrupted answer
# ("Let's continue ...", "Continuing ...", "Sure ...").
#
# Every pattern is anchored at the start of the text and also swallows the
# punctuation/whitespace that follows the phrase. The ``\b`` after each phrase
# stops the pattern from biting into ordinary words that merely begin with it
# (without it, "Surely not" was truncated to "ly not"). A comma is accepted as
# trailing punctuation so "Sure, here" does not leave a dangling ", here".
CONTINUATION_PREFIX_PATTERNS = [
    re.compile(
        r"^\s*(?:let(?:'|\u2019)?s|let us)\s+continue"
        r"(?:\s+from\s+where\s+we\s+left\s+off)?\b[,.!:\-\s]*",
        re.IGNORECASE,
    ),
    re.compile(
        r"^\s*continuing(?:\s+from\s+where\s+we\s+left\s+off)?\b[,.!:\-\s]*",
        re.IGNORECASE,
    ),
    re.compile(r"^\s*sure\b[,.!:\-\s]*", re.IGNORECASE),
]
# Pydantic models
class UserRegister(BaseModel):
    """Request body for the registration endpoint."""
    # Username must be between 3 and 30 characters long.
    username: str = Field(..., min_length=3, max_length=30)
    # Password must be at least 6 characters long.
    password: str = Field(..., min_length=6)
class UserLogin(BaseModel):
    """Request body for the login endpoint (no length constraints applied)."""
    username: str
    password: str
class ChatRequest(BaseModel):
    """JSON-shaped chat request.

    NOTE(review): no route in the visible part of this file uses this model
    (the streaming endpoint reads form fields instead) - confirm it is still
    needed.
    """
    message: str
    # None means "no existing conversation was specified".
    conversation_id: Optional[int] = None
    # Sampling defaults are taken from the application settings.
    temperature: Optional[float] = settings.TEMPERATURE
    max_tokens: Optional[int] = settings.MAX_TOKENS
class SettingsUpdate(BaseModel):
    """Partial update of per-user settings; fields left as None are not changed."""
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    # Custom instructions are capped at 4000 characters.
    system_prompt: Optional[str] = Field(default=None, max_length=4000)
    theme: Optional[str] = None
class TitleUpdate(BaseModel):
    """Request body for renaming a conversation."""
    # Title must be between 1 and 100 characters long.
    title: str = Field(..., min_length=1, max_length=100)
def log_chat_status(
    stage: str,
    username: str,
    conversation_id: Optional[int],
    **details
):
    """Print one ``[CHAT]``-prefixed JSON line describing a chat pipeline stage.

    Args:
        stage: Name of the pipeline stage being reported.
        username: User the request belongs to (emitted under the "user" key).
        conversation_id: Conversation the request targets, if any.
        **details: Extra key/value pairs merged into the record after the
            fixed keys; values that are not JSON-serializable are rendered
            with ``str``.
    """
    record = dict(stage=stage, user=username, conversation_id=conversation_id)
    record.update(details)
    serialized = json.dumps(record, ensure_ascii=False, default=str)
    print(f"[CHAT] {serialized}")
def ensure_db_user(username: str) -> dict:
    """Return the database row for *username*, creating it when it is missing.

    Raises:
        HTTPException: 500 when the row still cannot be read after the
            creation attempt.
    """
    row = get_user_by_username(username)
    if not row:
        # The auth store knows this user but the database does not: create the
        # row and read it back.
        get_or_create_user(username)
        row = get_user_by_username(username)
    if not row:
        raise HTTPException(status_code=500, detail="Failed to initialize user in database")
    return row
def strip_continuation_prefix(text: str) -> str:
    """Remove common continuation preambles from the start of model output.

    The patterns in ``CONTINUATION_PREFIX_PATTERNS`` are applied repeatedly
    until the text stops changing, so stacked preambles such as
    "Sure! Let's continue: ..." are removed completely regardless of the order
    in which the patterns are listed (a single ordered pass would strip only
    "Sure! " and leave "Let's continue: ..." behind).

    Args:
        text: Raw model output; ``None`` or empty is treated as "".

    Returns:
        The text with every leading preamble removed.
    """
    value = text or ""
    while True:
        stripped = value
        for pattern in CONTINUATION_PREFIX_PATTERNS:
            stripped = pattern.sub("", stripped, count=1)
        # Each pass either shortens the text or leaves it untouched, so the
        # loop always terminates.
        if stripped == value:
            return value
        value = stripped
def build_continuation_prompt(
    chat_history: List[dict],
    file_content: str,
    user_instructions: str,
    assistant_prefix: str
) -> str:
    """Assemble the prompt used to resume an interrupted assistant reply.

    The prompt consists of a SYSTEM section (default system prompt, the
    continuation rules and any custom user instructions), the optional file
    content, the user/assistant turns of the conversation, and finally an
    open ``ASSISTANT:`` turn that already contains *assistant_prefix*.

    Args:
        chat_history: Prior messages as dicts with "role" and "content";
            entries whose role is neither "user" nor "assistant" are skipped.
        file_content: Pre-formatted file text, or an empty value for none.
        user_instructions: Custom instructions, or an empty value for none.
        assistant_prefix: Text of the reply produced so far.

    Returns:
        The complete prompt string.
    """
    continuation_rules = (
        "## Continuation Mode\n"
        "- Continue an interrupted assistant response.\n"
        "- Output only the missing continuation text after the provided prefix.\n"
        "- Do not restart or repeat prior text.\n"
        "- Do not prepend phrases like 'Let's continue' or 'Sure'."
    )
    system_sections = [DEFAULT_SYSTEM_PROMPT.strip(), continuation_rules]
    if user_instructions:
        system_sections.append(f"## Custom Instructions\n{user_instructions.strip()}")
    joined_system = "\n\n".join(system_sections)

    lines = [f"SYSTEM: {joined_system}"]
    if file_content:
        lines.append(f"\n{file_content}")
    if chat_history:
        lines.append("\n--- Conversation History ---")
        for entry in chat_history:
            speaker = entry.get("role", "user")
            text = entry.get("content", "")
            if speaker == "user":
                lines.append(f"USER: {text}")
                continue
            if speaker == "assistant":
                lines.append(f"ASSISTANT: {text}")
    # Leave the assistant turn open so the model continues right after the prefix.
    lines.append(f"\nASSISTANT: {assistant_prefix}")
    return "\n".join(lines)
# Get current user from token
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Resolve the bearer token of the request to the authenticated user.

    Returns:
        The user record returned by ``get_user``.

    Raises:
        HTTPException: 401 when credentials are missing, the token cannot be
            decoded, it carries no "sub" claim, or the user is unknown.
    """
    if not credentials:
        raise HTTPException(status_code=401, detail="Not authenticated")

    claims = decode_token(credentials.credentials)
    if not claims:
        raise HTTPException(status_code=401, detail="Invalid or expired token")

    subject = claims.get("sub")
    if not subject:
        raise HTTPException(status_code=401, detail="Invalid token")

    account = get_user(subject)
    if not account:
        raise HTTPException(status_code=401, detail="User not found")

    # Make sure the database also has a row for this authenticated user.
    ensure_db_user(subject)
    return account
# Lifespan context manager
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: connect the model and tools, then tear them down.

    Startup never aborts on a failed model or tool initialisation; it only
    prints a warning so the server can still come up.
    """
    # ---- startup ----
    rule = "=" * 50
    print(rule)
    print(f"Starting {settings.APP_NAME}")
    print(rule)

    # Check NVIDIA API connectivity
    if model_manager.is_available:
        print(f"Checking NVIDIA API at {model_manager.nvidia_base_url}...")
        connected = model_manager.load_model()
        if connected:
            model_name = model_manager.get_model_info().get('model_name', 'unknown')
            print(f"NVIDIA API connected: model={model_name}")
        else:
            print(f"Warning: NVIDIA API not reachable: {model_manager.last_error}")
            print(f"NVIDIA Base URL: {model_manager.nvidia_base_url}")
            print("The server will start but model inference will be unavailable until the API is reachable.")
    else:
        print("Warning: NVIDIA_API_KEY not configured. Model inference disabled.")

    # Initialize web_search tool
    print("Initializing web_search tool...")
    if await tool_client.initialize():
        tool_names = ", ".join(tool_client.get_tool_names()) or "none"
        print(f"Tools ready: {len(tool_client.tools)} tool(s) [{tool_names}]")
    else:
        print(f"Tool init failed: {tool_client.init_error}")
        print("Server will run without tool support.")

    yield

    # ---- shutdown ----
    print("Shutting down...")
    await tool_client.shutdown()
    model_manager.unload_model()
# Create FastAPI app
app = FastAPI(
    title=settings.APP_NAME,
    description="AI Chatbot Platform using NVIDIA API inference",
    version="1.0.0",
    lifespan=lifespan
)
# CORS middleware: allowed origins and the credentials policy come from
# settings; every method and header is allowed.
# NOTE(review): browsers reject credentialed requests when the allowed origin
# is the "*" wildcard - confirm CORS_ORIGINS lists explicit origins whenever
# CORS_ALLOW_CREDENTIALS is enabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.CORS_ORIGINS,
    allow_credentials=settings.CORS_ALLOW_CREDENTIALS,
    allow_methods=["*"],
    allow_headers=["*"],
)
# URL prefix used by every API route registered below.
API_BASE_ENDPOINT = settings.API_BASE_ENDPOINT
# Directory that holds the static frontend files.
frontend_dir = Path(settings.FRONTEND_DIR)
# The frontend is served only when enabled in settings AND the directory exists.
serve_frontend = settings.SERVE_FRONTEND and frontend_dir.exists()
# Routes
@app.get("/")
async def root():
    """Return the frontend index page, or service metadata in backend-only mode."""
    if not serve_frontend:
        service_info = {
            "service": settings.APP_NAME,
            "mode": "backend-only",
            "frontend_url": settings.FRONTEND_URL,
            "api_base_endpoint": API_BASE_ENDPOINT,
            "health_url": f"{API_BASE_ENDPOINT}/health",
            "docs_url": app.docs_url,
        }
        return service_info
    index_page = frontend_dir / "index.html"
    return FileResponse(index_page)
if serve_frontend:
    def _frontend_asset(subdir: str, file_path: str) -> FileResponse:
        """Return a file from ``frontend_dir/<subdir>`` without leaving that folder.

        ``file_path`` comes straight from the URL. Joining it blindly onto the
        frontend directory lets ``..`` segments or an absolute path (for
        example ``/css//etc/passwd``) read arbitrary files, so the resolved
        target must stay inside the asset folder.

        Raises:
            HTTPException: 404 when the path escapes the asset folder, is
                malformed, or does not name an existing regular file.
        """
        asset_root = (frontend_dir / subdir).resolve()
        try:
            target = (asset_root / file_path).resolve()
            # relative_to raises ValueError when target is outside asset_root.
            target.relative_to(asset_root)
        except (ValueError, OSError):
            raise HTTPException(status_code=404, detail="File not found") from None
        if not target.is_file():
            raise HTTPException(status_code=404, detail="File not found")
        return FileResponse(target)

    @app.get("/login")
    async def login_page():
        """Serve login page"""
        return FileResponse(frontend_dir / "login.html")

    @app.get("/register")
    async def register_page():
        """Serve register page"""
        return FileResponse(frontend_dir / "register.html")

    @app.get("/profile")
    async def profile_page():
        """Serve profile page"""
        return FileResponse(frontend_dir / "profile.html")

    @app.get("/css/{file_path:path}")
    async def serve_css(file_path: str):
        """Serve CSS files from the frontend's css folder (404 outside it)."""
        return _frontend_asset("css", file_path)

    @app.get("/js/{file_path:path}")
    async def serve_js(file_path: str):
        """Serve JS files from the frontend's js folder (404 outside it)."""
        return _frontend_asset("js", file_path)
# API Routes
# Auth endpoints
@app.post(f"{API_BASE_ENDPOINT}/auth/register")
async def api_register(user_data: UserRegister):
    """Create an account and return a bearer token for it.

    Raises:
        HTTPException: 400 when registration raises ``ValueError``; 500 for
            any other failure.
    """
    try:
        account = register_user(user_data.username, user_data.password)
        # Mirror the new account into the database.
        get_or_create_user(user_data.username)
        # Issue the access token for the fresh account.
        token_lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        token = create_access_token(
            data={"sub": account["username"]},
            expires_delta=token_lifetime
        )
        return {
            "success": True,
            "message": "User registered successfully",
            "access_token": token,
            "token_type": "bearer",
            "username": account["username"]
        }
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Registration failed: {str(e)}")
@app.post(f"{API_BASE_ENDPOINT}/auth/login")
async def api_login(user_data: UserLogin):
    """Check the credentials and return a bearer token.

    Raises:
        HTTPException: 401 when the username/password pair is rejected.
    """
    account = authenticate_user(user_data.username, user_data.password)
    if not account:
        raise HTTPException(status_code=401, detail="Invalid username or password")

    # Make sure the database has a row for this user.
    get_or_create_user(account["username"])

    token_lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    token = create_access_token(
        data={"sub": account["username"]},
        expires_delta=token_lifetime
    )
    return {
        "success": True,
        "access_token": token,
        "token_type": "bearer",
        "username": account["username"]
    }
@app.post(f"{API_BASE_ENDPOINT}/auth/logout")
async def api_logout(current_user: dict = Depends(get_current_user)):
    """Logout user (client-side token removal).

    This handler only returns a success payload; it does not change any
    state itself, so discarding the token is left to the client.
    """
    return {"success": True, "message": "Logged out successfully"}
@app.get(f"{API_BASE_ENDPOINT}/auth/me")
async def api_me(current_user: dict = Depends(get_current_user)):
    """Return the authenticated user's name, creation time and settings."""
    username = current_user["username"]
    return {
        "username": username,
        "created_at": current_user.get("created_at"),
        "settings": get_user_settings(username)
    }
# User endpoints
@app.get(f"{API_BASE_ENDPOINT}/user/profile")
async def api_profile(current_user: dict = Depends(get_current_user)):
    """Return the user's profile: name, creation time, settings and usage stats.

    Raises:
        HTTPException: 500 (from ``ensure_db_user``) when the database row for
            the user cannot be created.
    """
    username = current_user["username"]
    # ensure_db_user either returns a row or raises, so no "missing user"
    # fallback is needed for the stats lookup.
    db_user = ensure_db_user(username)
    stats = get_user_stats(db_user["id"])
    return {
        "username": username,
        "created_at": current_user.get("created_at"),
        "settings": get_user_settings(username),
        "stats": stats
    }
@app.put(f"{API_BASE_ENDPOINT}/user/settings")
async def api_update_settings(
    settings_update: SettingsUpdate,
    current_user: dict = Depends(get_current_user)
):
    """Apply the provided settings fields and return the stored settings.

    Temperature is clamped to [0.0, 2.0] and max_tokens to
    [1, model generation limit]; fields left as None are not touched.

    Raises:
        HTTPException: 400 when the settings store reports a failed update.
    """
    username = current_user["username"]
    changes = {}

    requested_temperature = settings_update.temperature
    if requested_temperature is not None:
        changes["temperature"] = max(0.0, min(2.0, requested_temperature))

    requested_tokens = settings_update.max_tokens
    if requested_tokens is not None:
        token_limit = model_manager.get_max_generation_tokens_limit()
        changes["max_tokens"] = max(1, min(token_limit, requested_tokens))

    # The remaining fields are copied through unchanged when present.
    if settings_update.system_prompt is not None:
        changes["system_prompt"] = settings_update.system_prompt
    if settings_update.theme is not None:
        changes["theme"] = settings_update.theme

    if not update_user_settings(username, changes):
        raise HTTPException(status_code=400, detail="Failed to update settings")
    return {"success": True, "settings": get_user_settings(username)}
@app.post(f"{API_BASE_ENDPOINT}/user/change-password")
async def api_change_password(
    old_password: str = Form(...),
    new_password: str = Form(...),
    current_user: dict = Depends(get_current_user)
):
    """Replace the user's password after verifying the old one.

    Raises:
        HTTPException: 400 when the old password is rejected or when
            ``change_password`` raises ``ValueError``.
    """
    try:
        changed = change_password(current_user["username"], old_password, new_password)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    if not changed:
        raise HTTPException(status_code=400, detail="Invalid old password")
    return {"success": True, "message": "Password changed successfully"}
# Conversation endpoints
@app.get(f"{API_BASE_ENDPOINT}/conversations")
async def api_get_conversations(current_user: dict = Depends(get_current_user)):
    """List the conversations that belong to the authenticated user."""
    owner = ensure_db_user(current_user["username"])
    return {"conversations": get_conversations(owner["id"])}
@app.post(f"{API_BASE_ENDPOINT}/conversations")
async def api_create_conversation(current_user: dict = Depends(get_current_user)):
    """Start an empty conversation titled "New Chat" and return its id."""
    owner = ensure_db_user(current_user["username"])
    new_id = create_conversation(owner["id"], "New Chat")
    response = {"success": True, "conversation_id": new_id, "title": "New Chat"}
    return response
@app.get(f"{API_BASE_ENDPOINT}/conversations/{{conversation_id}}/messages")
async def api_get_messages(
    conversation_id: int,
    current_user: dict = Depends(get_current_user)
):
    """Return a conversation together with its messages.

    Raises:
        HTTPException: 404 when the conversation is not found for this user.
    """
    owner = ensure_db_user(current_user["username"])

    # The lookup is scoped to the owner, so foreign conversations look missing.
    record = get_conversation(conversation_id, owner["id"])
    if not record:
        raise HTTPException(status_code=404, detail="Conversation not found")

    return {"conversation": record, "messages": get_messages(conversation_id)}
@app.put(f"{API_BASE_ENDPOINT}/conversations/{{conversation_id}}/title")
async def api_update_title(
    conversation_id: int,
    title_update: TitleUpdate,
    current_user: dict = Depends(get_current_user)
):
    """Rename one of the user's conversations.

    Raises:
        HTTPException: 400 when the database reports that nothing was updated.
    """
    owner = ensure_db_user(current_user["username"])
    new_title = title_update.title
    renamed = update_conversation_title(conversation_id, owner["id"], new_title)
    if not renamed:
        raise HTTPException(status_code=400, detail="Failed to update title")
    return {"success": True, "title": title_update.title}
@app.delete(f"{API_BASE_ENDPOINT}/conversations/{{conversation_id}}")
async def api_delete_conversation(
    conversation_id: int,
    current_user: dict = Depends(get_current_user)
):
    """Delete one of the user's conversations.

    Raises:
        HTTPException: 400 when the database reports that nothing was deleted.
    """
    owner = ensure_db_user(current_user["username"])
    removed = delete_conversation(conversation_id, owner["id"])
    if not removed:
        raise HTTPException(status_code=400, detail="Failed to delete conversation")
    return {"success": True, "message": "Conversation deleted"}
# Chat endpoints
@app.post(f"{API_BASE_ENDPOINT}/chat/stop")
async def api_chat_stop(
    request_id: str = Form(...),
    current_user: dict = Depends(get_current_user)
):
    """Signal an in-flight chat stream to stop.

    Returns ``stopped: False`` when the request id is not registered (or has
    no stop event) and ``stopped: True`` once the stop event has been set.

    Raises:
        HTTPException: 403 when the stream was started by a different user.
    """
    username = current_user["username"]
    entry = active_chat_stop_requests.get(request_id)

    if not entry:
        return {"success": True, "request_id": request_id, "stopped": False}
    if entry.get("username") != username:
        raise HTTPException(status_code=403, detail="Not authorized to stop this request")

    signal = entry.get("event")
    if signal is None:
        return {"success": True, "request_id": request_id, "stopped": False}

    signal.set()
    log_chat_status(
        stage="stop_requested",
        username=username,
        conversation_id=entry.get("conversation_id"),
        request_id=request_id
    )
    return {"success": True, "request_id": request_id, "stopped": True}
@app.post(f"{API_BASE_ENDPOINT}/chat/stream")
async def api_chat_stream(
    request: Request,
    message: str = Form(...),
    conversation_id: Optional[int] = Form(None),
    temperature: float = Form(settings.TEMPERATURE),
    max_tokens: int = Form(settings.MAX_TOKENS),
    request_id: Optional[str] = Form(None),
    persist_user_message: bool = Form(True),
    continuation_mode: bool = Form(False),
    continuation_prefix: Optional[str] = Form(None),
    continuation_message_id: Optional[int] = Form(None),
    files: List[UploadFile] = File(default=[]),
    current_user: dict = Depends(get_current_user)
):
    """Stream a chat response as server-sent events, with optional file uploads.

    Pipeline: resolve sampling settings, ensure the conversation exists,
    process uploads, load history, build the prompt (normal or continuation),
    persist the user message, then stream model chunks. Each chunk is sent as
    ``data: <json>``; the model's own "done"/"stopped" chunks are withheld and
    replaced by one terminal event carrying ``done``, ``assistant_message_id``
    and ``error`` (plus ``stopped`` when generation was interrupted).

    Args:
        request: Incoming request, polled to detect client disconnects.
        message: User message text.
        conversation_id: Existing conversation; a new one is created if falsy.
        temperature: Sampling temperature; when equal to the configured
            default, the user's stored setting is used instead.
        max_tokens: Generation budget; resolved like ``temperature`` and then
            clamped to the model's limit.
        request_id: Client-chosen id used by the stop endpoint; generated when
            missing or blank.
        persist_user_message: Whether to store ``message`` (forced off in
            continuation mode).
        continuation_mode: Resume an interrupted assistant reply.
        continuation_prefix: Text of the reply so far; defaults to the last
            non-empty assistant message of the conversation.
        continuation_message_id: Assistant message to append the continuation to.
        files: Uploaded files to include in the prompt.
        current_user: Authenticated user (dependency).

    Raises:
        HTTPException: 404 when ``conversation_id`` is not found for the user.
    """
    request_start = time.perf_counter()
    username = current_user["username"]
    # A blank/whitespace-only id would otherwise become the registry key "".
    request_id = (request_id or "").strip() or uuid4().hex

    # Resolve sampling settings: an explicit (non-default) form value wins,
    # otherwise fall back to the user's stored settings.
    user_settings = get_user_settings(username)
    temp = (
        temperature
        if temperature != settings.TEMPERATURE
        else user_settings.get("temperature", settings.TEMPERATURE)
    )
    tokens = (
        max_tokens
        if max_tokens != settings.MAX_TOKENS
        else user_settings.get("max_tokens", settings.MAX_TOKENS)
    )
    tokens = max(1, min(tokens, model_manager.get_max_generation_tokens_limit()))
    user_instructions = user_settings.get("system_prompt", "")
    log_chat_status(
        stage="request_received",
        username=username,
        conversation_id=conversation_id,
        request_id=request_id,
        user_query_chars=len(message),
        file_count=len([f for f in files if getattr(f, "filename", None)]),
        user_instruction_chars=len(user_instructions),
        requested_max_tokens=max_tokens,
        resolved_initial_max_tokens=tokens,
        temperature=temp,
        continuation_mode=continuation_mode
    )

    # Get or create database user
    db_user = ensure_db_user(username)

    # Create conversation if not provided
    if not conversation_id:
        conversation_id = create_conversation(db_user["id"], message[:50] + "..." if len(message) > 50 else message)
    else:
        # Verify conversation belongs to user
        conversation = get_conversation(conversation_id, db_user["id"])
        if not conversation:
            raise HTTPException(status_code=404, detail="Conversation not found")
    log_chat_status(
        stage="conversation_ready",
        username=username,
        conversation_id=conversation_id,
        request_id=request_id
    )

    # Process uploaded files
    file_results = []
    file_text_chars = 0
    saved_file_count = 0
    if files:
        for file in files:
            if file.filename:
                content = await file.read()
                saved_path = None
                save_error = None
                # Saving is best-effort: a failed save is recorded on the
                # result but does not stop the file from being processed.
                try:
                    saved_path = save_uploaded_file(content, file.filename)
                except Exception as exc:
                    save_error = str(exc)
                result = process_file(content, file.filename)
                result["saved_permanently"] = saved_path is not None
                if saved_path is not None:
                    result["saved_path"] = str(saved_path)
                    result["saved_filename"] = saved_path.name
                    saved_file_count += 1
                if save_error:
                    result["save_error"] = save_error
                file_results.append(result)
                file_text_chars += len(result.get("content", ""))
    log_chat_status(
        stage="files_processed",
        username=username,
        conversation_id=conversation_id,
        request_id=request_id,
        file_count=len(file_results),
        saved_file_count=saved_file_count,
        file_text_chars=file_text_chars
    )

    # Format file content for prompt
    file_content = format_files_for_prompt(file_results)

    # Get full chat history for this conversation (user + assistant only).
    # This snapshot is taken BEFORE the new user message is stored.
    all_messages = get_messages(conversation_id, limit=None)
    chat_history = [
        {"role": msg["role"], "content": msg["content"]}
        for msg in all_messages
        if msg.get("role") in {"user", "assistant"}
    ]
    history_chars = sum(len(msg.get("content", "")) for msg in chat_history)
    log_chat_status(
        stage="history_loaded",
        username=username,
        conversation_id=conversation_id,
        request_id=request_id,
        history_messages_total=len(chat_history),
        history_chars_total=history_chars
    )

    # Continuation: default the prefix to the last non-empty assistant reply.
    continuation_prefix_text = (continuation_prefix or "").strip()
    if continuation_mode and not continuation_prefix_text:
        for msg in reversed(chat_history):
            if msg.get("role") == "assistant" and msg.get("content", "").strip():
                continuation_prefix_text = msg["content"]
                break
    # Drop the trailing assistant message when it is the text being continued,
    # so it is not sent both as history and as the open assistant prefix.
    if continuation_mode and continuation_prefix_text and chat_history:
        last_msg = chat_history[-1]
        if (
            last_msg.get("role") == "assistant"
            and (
                last_msg.get("content", "") == continuation_prefix_text
                or last_msg.get("content", "").endswith(continuation_prefix_text)
                or continuation_prefix_text.endswith(last_msg.get("content", ""))
            )
        ):
            chat_history = chat_history[:-1]
    effective_continuation_mode = bool(continuation_mode and continuation_prefix_text)

    # Build prompt
    if effective_continuation_mode:
        prompt = build_continuation_prompt(
            chat_history=chat_history,
            file_content=file_content,
            user_instructions=user_instructions,
            assistant_prefix=continuation_prefix_text
        )
        prompt_meta = {
            "history_messages_total": len(chat_history),
            "history_messages_used": len(chat_history),
            "truncated": False
        }
    else:
        prompt = model_manager.build_prompt(
            query=message,
            history=chat_history,
            file_content=file_content,
            custom_instructions=user_instructions,
        )
        prompt_meta = model_manager.last_prompt_meta
    tokens = model_manager.resolve_max_tokens(prompt, tokens)
    log_chat_status(
        stage="prompt_built",
        username=username,
        conversation_id=conversation_id,
        request_id=request_id,
        default_system_prompt_chars=len(DEFAULT_SYSTEM_PROMPT),
        user_instruction_chars=len(user_instructions),
        user_query_chars=len(message),
        file_content_chars=len(file_content),
        full_prompt_chars=len(prompt),
        full_prompt_tokens=model_manager.count_tokens(prompt),
        history_messages_used=prompt_meta.get("history_messages_used"),
        history_messages_total=prompt_meta.get("history_messages_total"),
        prompt_truncated=prompt_meta.get("truncated"),
        generation_max_tokens=tokens,
        continuation_mode=effective_continuation_mode,
        continuation_prefix_chars=len(continuation_prefix_text)
    )

    persist_user_message = bool(persist_user_message and not effective_continuation_mode)
    # Save user message to database (skipped for internal continuation prompts)
    if persist_user_message:
        add_message(conversation_id, "user", message, file_results)
        # Update the conversation title only for the FIRST message. The
        # pre-insert history snapshot decides this; the previous check
        # (len(get_messages(limit=2)) <= 2) was presumably always true because
        # of the limit, which renamed the conversation on every message.
        if not all_messages:
            title = message[:50] + "..." if len(message) > 50 else message
            update_conversation_title(conversation_id, db_user["id"], title)

    # Headers shared by every SSE response produced below.
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Conversation-Id": str(conversation_id),
        "X-Request-Id": request_id,
        "X-Accel-Buffering": "no"
    }

    # Check if model is available
    if not model_manager.is_available:
        # Return error as SSE
        async def error_stream():
            runtime_error = model_manager.last_error or "NVIDIA API not available. Check NVIDIA_API_KEY configuration."
            log_chat_status(
                stage="model_unavailable",
                username=username,
                conversation_id=conversation_id,
                request_id=request_id
            )
            error_payload = json.dumps({'error': f'Model not available. {runtime_error}'})
            yield f"data: {error_payload}\n\n"
            yield f"data: {json.dumps({'done': True})}\n\n"
        return StreamingResponse(
            error_stream(),
            media_type="text/event-stream",
            headers=sse_headers
        )

    # Load model if needed
    if not model_manager.is_loaded:
        if not model_manager.load_model():
            async def error_stream():
                error_message = model_manager.last_error or "Failed to load model. Check server logs."
                log_chat_status(
                    stage="model_load_failed",
                    username=username,
                    conversation_id=conversation_id,
                    request_id=request_id,
                    error=error_message
                )
                yield f"data: {json.dumps({'error': error_message})}\n\n"
                yield f"data: {json.dumps({'done': True})}\n\n"
            return StreamingResponse(
                error_stream(),
                media_type="text/event-stream",
                headers=sse_headers
            )

    # Register the stream so the stop endpoint can find and signal it.
    stop_event = asyncio.Event()
    active_chat_stop_requests[request_id] = {
        "event": stop_event,
        "username": username,
        "conversation_id": conversation_id
    }

    async def monitor_disconnect():
        # Poll for a client disconnect and turn it into a stop signal.
        while not stop_event.is_set():
            if await request.is_disconnected():
                stop_event.set()
                log_chat_status(
                    stage="client_disconnected",
                    username=username,
                    conversation_id=conversation_id,
                    request_id=request_id
                )
                return
            await asyncio.sleep(0.05)

    # Generate streaming response
    async def generate():
        full_response = ""
        generation_start = time.perf_counter()
        disconnect_task = asyncio.create_task(monitor_disconnect())
        stopped_by_model = False
        generation_error = None
        log_chat_status(
            stage="generation_started",
            username=username,
            conversation_id=conversation_id,
            request_id=request_id,
            generation_max_tokens=tokens
        )
        try:
            async for chunk in model_manager.generate_stream(
                prompt=prompt,
                temperature=temp,
                max_tokens=tokens,
                stop_event=stop_event
            ):
                if stop_event.is_set():
                    break
                data = json.loads(chunk)
                if "token" in data:
                    full_response += data["token"]
                if "error" in data:
                    generation_error = data.get("error")
                    log_chat_status(
                        stage="generation_error",
                        username=username,
                        conversation_id=conversation_id,
                        request_id=request_id,
                        error=data.get("error")
                    )
                    yield f"data: {chunk}\n\n"
                    continue
                if data.get("stopped"):
                    stopped_by_model = True
                    continue
                if data.get("done"):
                    # Delay final done event until DB save is complete so we can
                    # include assistant_message_id in a single terminal event.
                    continue
                yield f"data: {chunk}\n\n"
                # Yield control so other tasks (e.g. the disconnect monitor) can run.
                await asyncio.sleep(0)

            # Save any assistant response that was produced
            assistant_message_id = None
            if full_response:
                response_to_store = full_response
                if effective_continuation_mode:
                    response_to_store = strip_continuation_prefix(response_to_store)
                if response_to_store:
                    if effective_continuation_mode:
                        # Preference order: append to the named message, then
                        # to the last assistant message, else store a new one.
                        updated = False
                        if continuation_message_id is not None:
                            updated = append_to_assistant_message(
                                conversation_id,
                                continuation_message_id,
                                response_to_store
                            )
                            if updated:
                                assistant_message_id = continuation_message_id
                        if not updated:
                            updated = append_to_last_assistant_message(conversation_id, response_to_store)
                            if updated:
                                latest_messages = get_messages(conversation_id, limit=None)
                                for msg in reversed(latest_messages):
                                    if msg.get("role") == "assistant":
                                        assistant_message_id = msg.get("id")
                                        break
                        if not updated:
                            assistant_message_id = add_message(conversation_id, "assistant", response_to_store)
                    else:
                        assistant_message_id = add_message(conversation_id, "assistant", response_to_store)

            generation_ms = int((time.perf_counter() - generation_start) * 1000)
            total_ms = int((time.perf_counter() - request_start) * 1000)
            stop_reached = stop_event.is_set() or stopped_by_model
            if stop_reached:
                log_chat_status(
                    stage="generation_stopped",
                    username=username,
                    conversation_id=conversation_id,
                    request_id=request_id,
                    response_chars=len(full_response),
                    generation_ms=generation_ms,
                    total_request_ms=total_ms
                )
                yield f"data: {json.dumps({'stopped': True, 'done': True, 'assistant_message_id': assistant_message_id, 'error': generation_error})}\n\n"
            else:
                log_chat_status(
                    stage="generation_completed",
                    username=username,
                    conversation_id=conversation_id,
                    request_id=request_id,
                    response_chars=len(full_response),
                    generation_ms=generation_ms,
                    total_request_ms=total_ms
                )
                yield f"data: {json.dumps({'done': True, 'assistant_message_id': assistant_message_id, 'error': generation_error})}\n\n"
        except asyncio.CancelledError:
            # Propagate cancellation, but make sure the model side is told to stop.
            stop_event.set()
            raise
        finally:
            disconnect_task.cancel()
            active_chat_stop_requests.pop(request_id, None)

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers=sse_headers
    )
# Model info endpoint
@app.get(f"{API_BASE_ENDPOINT}/model/info")
async def api_model_info():
    """Get model information.

    Returns whatever ``model_manager.get_model_info()`` reports; this route
    declares no authentication dependency.
    """
    return model_manager.get_model_info()
# Health check
@app.get(f"{API_BASE_ENDPOINT}/health")
async def health_check():
    """Report service health plus model and tool availability flags."""
    report = {"status": "healthy"}
    # Model-related flags.
    report["model_available"] = model_manager.is_available
    report["model_loaded"] = model_manager.is_loaded
    report["nvidia_api_configured"] = bool(model_manager.nvidia_api_key)
    # Tool-related flags.
    report["tools_available"] = tool_client.is_available
    report["tools"] = tool_client.get_tool_names()
    return report
# Error handlers
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render an ``HTTPException`` as ``{"error": true, "detail": ...}``.

    Headers attached to the exception are forwarded with the response;
    without this, overriding the default handler would silently drop them.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={"error": True, "detail": exc.detail},
        headers=getattr(exc, "headers", None),
    )
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Convert any otherwise unhandled exception into a JSON 500 response.

    NOTE(review): ``str(exc)`` is returned to the client verbatim, which may
    expose internal details - confirm this is acceptable outside debugging.
    """
    return JSONResponse(
        status_code=500,
        content={"error": True, "detail": str(exc)}
    )
def build_startup_banner(host: str, port: int, min_inner_width: int = 62) -> str:
    """Return the boxed startup banner shown when the module is run directly.

    The box is generated instead of hand-drawn so that its borders use the
    intended box-drawing characters and the right edge stays aligned for any
    host/port length.

    Args:
        host: Host name or address shown in the URL line.
        port: Port shown in the URL line.
        min_inner_width: Minimum number of characters between the side borders.

    Returns:
        The banner text, starting and ending with a newline.
    """
    title = "GAKR AI Chatbot Platform"
    url_row = f"   Local URL: http://{host}:{port}"
    # Widen the box when the URL would not fit the default width.
    inner = max(min_inner_width, len(title) + 2, len(url_row) + 1)
    rows = ["", title.center(inner), "", url_row, ""]
    boxed = [f"║{row.ljust(inner)}║" for row in rows]
    top = "╔" + "═" * inner + "╗"
    bottom = "╚" + "═" * inner + "╝"
    return "\n" + "\n".join([top, *boxed, bottom]) + "\n"


if __name__ == "__main__":
    import uvicorn
    print(build_startup_banner(settings.HOST, settings.PORT))
    uvicorn.run(
        "main:app",
        host=settings.HOST,
        port=settings.PORT,
        reload=settings.DEBUG,
        log_level="info"
    )