| | """ |
| | AI API Server - FastAPI + Mistral-7B |
| | Production-ready API with streaming, authentication, and caching |
| | """ |
| |
|
import asyncio
import hashlib
import json
import logging
import os
import secrets
from datetime import datetime
from functools import lru_cache
from typing import Optional, Dict, Any, AsyncGenerator

import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Header, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse
from pydantic import BaseModel, Field, validator
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
| |
|
| | |
| | |
| | |
| |
|
# Configure root logging once at import time: INFO level, mirrored to both
# stdout and a plain file "api.log" in the current working directory.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler("api.log")
    ]
)
# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
| |
|
| | |
| | |
| | |
| |
|
class Config:
    """Application configuration, read from environment variables at import time."""
    # Hugging Face model id to load.
    MODEL_NAME = os.getenv("MODEL_NAME", "mistralai/Mistral-7B-Instruct-v0.2")
    # NOTE(review): insecure placeholder default — a real deployment must set
    # the API_KEY environment variable, otherwise the published example key works.
    API_KEY = os.getenv("API_KEY", "your-secret-api-key-here")
    # Maximum prompt length (tokens) passed to the tokenizer's truncation.
    MAX_LENGTH = int(os.getenv("MAX_LENGTH", "2048"))
    # Default sampling parameters (can be overridden per request).
    TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
    TOP_P = float(os.getenv("TOP_P", "0.95"))
    # Max number of entries kept by the in-memory response cache.
    CACHE_SIZE = int(os.getenv("CACHE_SIZE", "100"))
    # Bind address for uvicorn.
    PORT = int(os.getenv("PORT", "7860"))
    HOST = os.getenv("HOST", "0.0.0.0")

    # 4-bit NF4 quantization with double quantization and fp16 compute.
    # Constructed at class-definition (import) time.
    QUANTIZATION_CONFIG = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4"
    )


# Shared configuration instance used by the rest of the module.
config = Config()
| |
|
| | |
| | |
| | |
| |
|
class ChatRequest(BaseModel):
    """Request model for chat endpoint"""
    # User prompt; length is validated to 1..4000 characters.
    prompt: str = Field(..., min_length=1, max_length=4000, description="User prompt")
    # Selects the system prompt; restricted by the validator below.
    language: str = Field(default="en", description="Response language (en, pt, es)")
    # None means "use the server default" (config.TEMPERATURE).
    temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
    # None means "use the server default" (512 tokens, see the chat endpoint).
    max_tokens: Optional[int] = Field(default=None, ge=1, le=4096)
    stream: bool = Field(default=True, description="Enable streaming response")

    # NOTE(review): pydantic v1-style validator; pydantic v2 renames this to
    # @field_validator — confirm which pydantic version the project pins.
    @validator("language")
    def validate_language(cls, v):
        """Reject languages that have no entry in SYSTEM_PROMPTS' allowed set."""
        allowed = ["en", "pt", "es", "fr", "de", "it", "ja", "zh"]
        if v not in allowed:
            raise ValueError(f"Language must be one of {allowed}")
        return v
| |
|
class ChatResponse(BaseModel):
    """Response model for chat endpoint"""
    # Generated text.
    response: str
    # Language the request asked for (echoed back).
    language: str
    # Model id that produced the response.
    model: str
    # ISO-8601 local timestamp of when the response was produced.
    timestamp: str
    # True when the response was served from the in-memory cache.
    cached: bool = False
| |
|
class HealthResponse(BaseModel):
    """Health check response"""
    # Always "healthy" when the endpoint answers at all.
    status: str
    # Whether load_model() has completed successfully.
    model_loaded: bool
    # ISO-8601 local timestamp of the check.
    timestamp: str
| |
|
| | |
| | |
| | |
| |
|
# Per-language system prompts injected before the user prompt (see
# ModelManager.format_prompt). Keys must stay in sync with the allowed
# languages validated by ChatRequest.
SYSTEM_PROMPTS = {
    "en": "You are a helpful, respectful and honest AI assistant. Always answer as helpfully as possible, while being safe. If you don't know the answer, say so instead of making up information.",
    "pt": "Você é um assistente de IA útil, respeitoso e honesto. Sempre responda da forma mais útil possível, mantendo a segurança. Se não souber a resposta, diga isso ao invés de inventar informações.",
    "es": "Eres un asistente de IA útil, respetuoso y honesto. Siempre responde de la manera más útil posible, manteniendo la seguridad. Si no sabes la respuesta, dilo en lugar de inventar información.",
    "fr": "Vous êtes un assistant IA utile, respectueux et honnête. Répondez toujours de la manière la plus utile possible, tout en restant sûr. Si vous ne connaissez pas la réponse, dites-le au lieu d'inventer des informations.",
    "de": "Sie sind ein hilfreicher, respektvoller und ehrlicher KI-Assistent. Antworten Sie immer so hilfreich wie möglich und bleiben Sie dabei sicher. Wenn Sie die Antwort nicht wissen, sagen Sie es, anstatt Informationen zu erfinden.",
    "it": "Sei un assistente AI utile, rispettoso e onesto. Rispondi sempre nel modo più utile possibile, mantenendo la sicurezza. Se non conosci la risposta, dillo invece di inventare informazioni.",
    "ja": "あなたは親切で、礼儀正しく、正直なAIアシスタントです。常に安全を保ちながら、できるだけ役立つように答えてください。答えがわからない場合は、情報を作り上げるのではなく、そう言ってください。",
    "zh": "你是一个乐于助人、尊重他人且诚实的AI助手。在保持安全的同时,始终尽可能有帮助地回答。如果你不知道答案,请说出来,而不是编造信息。"
}
| |
|
| | |
| | |
| | |
| |
|
class ResponseCache:
    """Simple in-memory cache for generated responses.

    Keys are an MD5 digest of (prompt, language, temperature) — MD5 is used
    here only as a compact, non-cryptographic fingerprint. When the cache is
    full, the entry with the oldest insertion time is evicted (FIFO-by-time,
    not LRU: reads do not refresh an entry's timestamp).
    """

    def __init__(self, max_size: int = 100):
        # key -> (response text, insertion timestamp)
        self.cache: Dict[str, tuple[str, datetime]] = {}
        self.max_size = max_size
        # getLogger(__name__) returns the same logger object the module uses.
        self._log = logging.getLogger(__name__)
        self._log.info(f"Initialized cache with max size: {max_size}")

    def _generate_key(self, prompt: str, language: str, temperature: float) -> str:
        """Derive a deterministic cache key from the request parameters."""
        content = f"{prompt}:{language}:{temperature}"
        return hashlib.md5(content.encode()).hexdigest()

    def get(self, prompt: str, language: str, temperature: float) -> Optional[str]:
        """Return the cached response for these parameters, or None on a miss."""
        key = self._generate_key(prompt, language, temperature)
        entry = self.cache.get(key)  # single lookup instead of `in` + index
        if entry is not None:
            self._log.info(f"Cache HIT for key: {key[:8]}...")
            return entry[0]
        self._log.info(f"Cache MISS for key: {key[:8]}...")
        return None

    def set(self, prompt: str, language: str, temperature: float, response: str):
        """Store a response, evicting the oldest entry only when needed.

        Bug fix: the previous version evicted *before* checking whether the
        key already existed, so overwriting an existing entry in a full cache
        needlessly dropped an unrelated entry.
        """
        key = self._generate_key(prompt, language, temperature)
        if key not in self.cache and len(self.cache) >= self.max_size:
            # Evict the entry with the oldest insertion timestamp.
            oldest_key = min(self.cache, key=lambda k: self.cache[k][1])
            del self.cache[oldest_key]
            self._log.info(f"Cache full, removed oldest entry: {oldest_key[:8]}...")

        self.cache[key] = (response, datetime.now())
        self._log.info(f"Cached response for key: {key[:8]}...")
| |
|
| | |
| | |
| | |
| |
|
class ModelManager:
    """Manages model loading and inference"""

    def __init__(self):
        # Populated lazily by load_model() during app startup.
        self.model = None
        self.tokenizer = None
        # Target device for tokenized inputs; the model's own placement is
        # decided by device_map="auto" at load time.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Device: {self.device}")

    async def load_model(self):
        """Load the tokenizer and the 4-bit quantized model.

        Returns:
            bool: True on success, False on failure. Errors are logged and
            swallowed so the server can still start without a model; callers
            must check the return value (see startup_event).
        """
        try:
            logger.info(f"Loading model: {config.MODEL_NAME}")

            # Tokenizer first; trust_remote_code allows custom model repos.
            self.tokenizer = AutoTokenizer.from_pretrained(
                config.MODEL_NAME,
                trust_remote_code=True
            )

            # 4-bit quantized weights, auto-sharded across available devices.
            self.model = AutoModelForCausalLM.from_pretrained(
                config.MODEL_NAME,
                quantization_config=config.QUANTIZATION_CONFIG,
                device_map="auto",
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )

            logger.info("Model loaded successfully!")
            return True

        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}", exc_info=True)
            return False

    def format_prompt(self, prompt: str, language: str) -> str:
        """Wrap the user prompt in Mistral-instruct [INST] tags together with
        the per-language system prompt (falls back to English)."""
        system_prompt = SYSTEM_PROMPTS.get(language, SYSTEM_PROMPTS["en"])
        return f"<s>[INST] {system_prompt}\n\nUser: {prompt} [/INST]"

    async def generate_stream(
        self,
        prompt: str,
        language: str,
        temperature: float,
        max_tokens: int
    ) -> AsyncGenerator[str, None]:
        """Generate a response token by token, yielding each decoded piece.

        Errors are streamed to the client as a trailing "[Error: ...]" text
        chunk rather than raised, so an in-progress stream ends gracefully.
        """
        try:
            formatted_prompt = self.format_prompt(prompt, language)

            # Tokenize once; prompts longer than MAX_LENGTH are truncated.
            inputs = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=config.MAX_LENGTH
            ).to(self.device)

            # Emit one token per generate() call so each can be yielded
            # immediately.
            # NOTE(review): each iteration re-runs generate() over the whole
            # sequence with max_new_tokens=1, which is O(n^2) in sequence
            # length — consider transformers' TextIteratorStreamer instead.
            with torch.no_grad():
                for i in range(max_tokens):
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=1,
                        temperature=temperature,
                        top_p=config.TOP_P,
                        do_sample=True,
                        pad_token_id=self.tokenizer.eos_token_id
                    )

                    # Decode only the newly appended token.
                    new_token = self.tokenizer.decode(
                        outputs[0][-1:],
                        skip_special_tokens=True
                    )

                    # Stop before yielding once EOS is produced.
                    if outputs[0][-1] == self.tokenizer.eos_token_id:
                        break

                    yield new_token

                    # Feed the grown sequence back in for the next step.
                    # NOTE(review): this drops the attention_mask from the
                    # original tokenizer output — presumably harmless for a
                    # single unpadded sequence, but confirm no padding warnings.
                    inputs = {"input_ids": outputs}

                    # Yield control to the event loop between tokens.
                    await asyncio.sleep(0.01)

        except Exception as e:
            logger.error(f"Generation error: {str(e)}", exc_info=True)
            yield f"\n\n[Error: {str(e)}]"

    async def generate(
        self,
        prompt: str,
        language: str,
        temperature: float,
        max_tokens: int
    ) -> str:
        """Generate the complete response in one call (non-streaming).

        Raises:
            HTTPException: 500 with the underlying error message on failure.
        """
        try:
            formatted_prompt = self.format_prompt(prompt, language)

            inputs = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=config.MAX_LENGTH
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    top_p=config.TOP_P,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            # Slice off the prompt tokens so only the completion is decoded.
            response = self.tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[1]:],
                skip_special_tokens=True
            )

            return response.strip()

        except Exception as e:
            logger.error(f"Generation error: {str(e)}", exc_info=True)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Generation failed: {str(e)}"
            )
| |
|
| | |
| | |
| | |
| |
|
# FastAPI application with interactive docs at /docs and /redoc.
app = FastAPI(
    title="AI API - Mistral 7B",
    description="Production-ready AI API with streaming, authentication, and caching",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# NOTE(review): per the CORS spec, browsers reject a wildcard origin when
# allow_credentials=True — either list explicit origins or drop credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Process-wide singletons: the model holder and the response cache.
model_manager = ModelManager()
cache = ResponseCache(max_size=config.CACHE_SIZE)
| |
|
| | |
| | |
| | |
| |
|
async def verify_api_key(x_api_key: str = Header(..., alias="X-API-Key")):
    """Verify the X-API-Key header against the configured key.

    Uses secrets.compare_digest for a constant-time comparison so the check
    does not leak key contents through a timing side channel (a plain `!=`
    short-circuits on the first differing byte).

    Raises:
        HTTPException: 401 when the key does not match.

    Returns:
        str: the validated key, unchanged.
    """
    if not secrets.compare_digest(x_api_key, config.API_KEY):
        logger.warning(f"Invalid API key attempt: {x_api_key[:8]}...")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key"
        )
    return x_api_key
| |
|
| | |
| | |
| | |
| |
|
@app.on_event("startup")
async def startup_event():
    """Load the language model once when the server boots."""
    logger.info("Starting AI API server...")
    # load_model() logs its own failure details; we only record the outcome.
    if not await model_manager.load_model():
        logger.error("Failed to load model, server may not function correctly")
| |
|
@app.on_event("shutdown")
async def shutdown_event():
    """Release in-memory state when the server shuts down."""
    logger.info("Shutting down AI API server...")
    # Drop all cached responses; the model is reclaimed with the process.
    cache.cache.clear()
| |
|
| | |
| | |
| | |
| |
|
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the bundled single-page frontend.

    The page contains a Tailwind-styled chat UI that calls POST /api/chat
    with streaming enabled and parses each SSE "data:" payload as JSON,
    plus a short curl-based API reference.
    """
    # NOTE(review): in the inline JS below, the '[DONE]' sentinel only breaks
    # the inner per-line loop; the outer reader loop ends when the stream
    # closes — confirm that is the intended termination path.
    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>AI API - Mistral 7B</title>
        <script src="https://cdn.tailwindcss.com"></script>
        <style>
            @keyframes fadeIn {
                from { opacity: 0; transform: translateY(10px); }
                to { opacity: 1; transform: translateY(0); }
            }
            .message {
                animation: fadeIn 0.3s ease-out;
            }
            .typing-indicator {
                display: inline-block;
            }
            .typing-indicator span {
                display: inline-block;
                width: 8px;
                height: 8px;
                border-radius: 50%;
                background-color: #6366F1;
                margin: 0 2px;
                animation: typing 1.4s infinite;
            }
            .typing-indicator span:nth-child(2) {
                animation-delay: 0.2s;
            }
            .typing-indicator span:nth-child(3) {
                animation-delay: 0.4s;
            }
            @keyframes typing {
                0%, 60%, 100% { transform: translateY(0); }
                30% { transform: translateY(-10px); }
            }
        </style>
    </head>
    <body class="bg-gradient-to-br from-slate-50 to-slate-100 min-h-screen">
        <div class="container mx-auto px-4 py-8 max-w-4xl">
            <!-- Header -->
            <div class="bg-white rounded-2xl shadow-lg p-6 mb-6">
                <h1 class="text-3xl font-bold text-slate-800 mb-2">🤖 AI API - Mistral 7B</h1>
                <p class="text-slate-600">Production-ready AI API with streaming responses</p>
            </div>

            <!-- API Key Section -->
            <div class="bg-white rounded-2xl shadow-lg p-6 mb-6">
                <label class="block text-sm font-semibold text-slate-700 mb-2">API Key</label>
                <input
                    type="password"
                    id="apiKey"
                    placeholder="Enter your API key"
                    class="w-full px-4 py-3 border border-slate-300 rounded-lg focus:ring-2 focus:ring-indigo-500 focus:border-transparent outline-none"
                />
            </div>

            <!-- Chat Interface -->
            <div class="bg-white rounded-2xl shadow-lg p-6 mb-6">
                <div id="messages" class="space-y-4 mb-6 max-h-96 overflow-y-auto">
                    <div class="text-center text-slate-400 py-8">
                        Start a conversation by typing a message below
                    </div>
                </div>

                <!-- Input Area -->
                <div class="flex gap-3">
                    <select id="language" class="px-4 py-3 border border-slate-300 rounded-lg focus:ring-2 focus:ring-indigo-500 outline-none">
                        <option value="en">English</option>
                        <option value="pt">Português</option>
                        <option value="es">Español</option>
                        <option value="fr">Français</option>
                        <option value="de">Deutsch</option>
                        <option value="it">Italiano</option>
                    </select>
                    <input
                        type="text"
                        id="prompt"
                        placeholder="Type your message..."
                        class="flex-1 px-4 py-3 border border-slate-300 rounded-lg focus:ring-2 focus:ring-indigo-500 focus:border-transparent outline-none"
                    />
                    <button
                        onclick="sendMessage()"
                        id="sendBtn"
                        class="px-6 py-3 bg-indigo-600 text-white font-semibold rounded-lg hover:bg-indigo-700 transition-colors disabled:bg-slate-300 disabled:cursor-not-allowed"
                    >
                        Send
                    </button>
                </div>
            </div>

            <!-- Documentation -->
            <div class="bg-white rounded-2xl shadow-lg p-6">
                <h2 class="text-xl font-bold text-slate-800 mb-4">📚 API Documentation</h2>
                <div class="space-y-4">
                    <div>
                        <h3 class="font-semibold text-slate-700 mb-2">Endpoint</h3>
                        <code class="block bg-slate-100 p-3 rounded-lg text-sm">POST /api/chat</code>
                    </div>
                    <div>
                        <h3 class="font-semibold text-slate-700 mb-2">Headers</h3>
                        <code class="block bg-slate-100 p-3 rounded-lg text-sm">X-API-Key: your-api-key</code>
                    </div>
                    <div>
                        <h3 class="font-semibold text-slate-700 mb-2">Example (curl)</h3>
                        <pre class="bg-slate-100 p-3 rounded-lg text-sm overflow-x-auto"><code>curl -X POST "http://localhost:7860/api/chat" \\
  -H "X-API-Key: your-secret-api-key-here" \\
  -H "Content-Type: application/json" \\
  -d '{
    "prompt": "Explain quantum computing",
    "language": "en",
    "stream": false
  }'</code></pre>
                    </div>
                </div>
            </div>
        </div>

        <script>
            const messagesDiv = document.getElementById('messages');
            const promptInput = document.getElementById('prompt');
            const apiKeyInput = document.getElementById('apiKey');
            const languageSelect = document.getElementById('language');
            const sendBtn = document.getElementById('sendBtn');

            // Load API key from localStorage
            const savedApiKey = localStorage.getItem('apiKey');
            if (savedApiKey) {
                apiKeyInput.value = savedApiKey;
            }

            // Save API key on change
            apiKeyInput.addEventListener('change', () => {
                localStorage.setItem('apiKey', apiKeyInput.value);
            });

            // Send on Enter
            promptInput.addEventListener('keypress', (e) => {
                if (e.key === 'Enter' && !e.shiftKey) {
                    e.preventDefault();
                    sendMessage();
                }
            });

            function addMessage(content, isUser = false) {
                if (messagesDiv.children[0]?.textContent.includes('Start a conversation')) {
                    messagesDiv.innerHTML = '';
                }

                const messageDiv = document.createElement('div');
                messageDiv.className = `message flex ${isUser ? 'justify-end' : 'justify-start'}`;

                const bubble = document.createElement('div');
                bubble.className = `max-w-[70%] px-4 py-3 rounded-2xl ${
                    isUser
                        ? 'bg-indigo-600 text-white'
                        : 'bg-slate-100 text-slate-800'
                }`;
                bubble.textContent = content;

                messageDiv.appendChild(bubble);
                messagesDiv.appendChild(messageDiv);
                messagesDiv.scrollTop = messagesDiv.scrollHeight;

                return bubble;
            }

            function addTypingIndicator() {
                const messageDiv = document.createElement('div');
                messageDiv.className = 'message flex justify-start';
                messageDiv.id = 'typing-indicator';

                const bubble = document.createElement('div');
                bubble.className = 'max-w-[70%] px-4 py-3 rounded-2xl bg-slate-100';
                bubble.innerHTML = '<div class="typing-indicator"><span></span><span></span><span></span></div>';

                messageDiv.appendChild(bubble);
                messagesDiv.appendChild(messageDiv);
                messagesDiv.scrollTop = messagesDiv.scrollHeight;
            }

            function removeTypingIndicator() {
                const indicator = document.getElementById('typing-indicator');
                if (indicator) {
                    indicator.remove();
                }
            }

            async function sendMessage() {
                const prompt = promptInput.value.trim();
                const apiKey = apiKeyInput.value.trim();
                const language = languageSelect.value;

                if (!prompt) return;
                if (!apiKey) {
                    alert('Please enter your API key');
                    return;
                }

                // Add user message
                addMessage(prompt, true);
                promptInput.value = '';

                // Disable send button
                sendBtn.disabled = true;
                addTypingIndicator();

                try {
                    const response = await fetch('/api/chat', {
                        method: 'POST',
                        headers: {
                            'Content-Type': 'application/json',
                            'X-API-Key': apiKey
                        },
                        body: JSON.stringify({
                            prompt: prompt,
                            language: language,
                            stream: true
                        })
                    });

                    if (!response.ok) {
                        throw new Error(`HTTP ${response.status}: ${await response.text()}`);
                    }

                    removeTypingIndicator();
                    const bubble = addMessage('', false);

                    // Read stream
                    const reader = response.body.getReader();
                    const decoder = new TextDecoder();
                    let fullResponse = '';

                    while (true) {
                        const { done, value } = await reader.read();
                        if (done) break;

                        const chunk = decoder.decode(value);
                        const lines = chunk.split('\\n');

                        for (const line of lines) {
                            if (line.startsWith('data: ')) {
                                const data = line.slice(6);
                                if (data === '[DONE]') break;

                                try {
                                    const json = JSON.parse(data);
                                    if (json.token) {
                                        fullResponse += json.token;
                                        bubble.textContent = fullResponse;
                                        messagesDiv.scrollTop = messagesDiv.scrollHeight;
                                    }
                                } catch (e) {
                                    console.error('Parse error:', e);
                                }
                            }
                        }
                    }

                } catch (error) {
                    removeTypingIndicator();
                    addMessage(`Error: ${error.message}`, false);
                } finally {
                    sendBtn.disabled = false;
                    promptInput.focus();
                }
            }
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content)
| |
|
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report service liveness and whether the model finished loading."""
    loaded = model_manager.model is not None
    now = datetime.now().isoformat()
    return HealthResponse(status="healthy", model_loaded=loaded, timestamp=now)
| |
|
@app.post("/api/chat")
async def chat(
    request: ChatRequest,
    api_key: str = Header(..., alias="X-API-Key")
):
    """
    Chat endpoint with streaming support.

    Requires X-API-Key header for authentication.

    Returns either a text/event-stream (stream=True) whose "data:" events
    carry JSON payloads of the form {"token": "..."} and end with
    "data: [DONE]", or a ChatResponse JSON body (stream=False). Only
    non-streaming requests are answered from the cache, but completed
    streams are written back to it.

    Raises:
        HTTPException: 401 on a bad key, 503 while the model is still
        loading, 500 on generation failure.
    """
    await verify_api_key(api_key)

    if model_manager.model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model not loaded yet, please try again later"
        )

    # Explicit None checks: `request.temperature or config.TEMPERATURE`
    # would silently discard a legitimate client-supplied temperature of 0.0.
    temperature = request.temperature if request.temperature is not None else config.TEMPERATURE
    max_tokens = request.max_tokens if request.max_tokens is not None else 512

    try:
        # Serve non-streaming requests from the cache when possible.
        if not request.stream:
            cached_response = cache.get(request.prompt, request.language, temperature)
            if cached_response:
                return ChatResponse(
                    response=cached_response,
                    language=request.language,
                    model=config.MODEL_NAME,
                    timestamp=datetime.now().isoformat(),
                    cached=True
                )

        if request.stream:
            async def generate():
                full_response = ""
                async for token in model_manager.generate_stream(
                    request.prompt,
                    request.language,
                    temperature,
                    max_tokens
                ):
                    full_response += token
                    # json.dumps produces valid JSON (double quotes, escaped
                    # control characters); the bundled frontend JSON.parses
                    # each SSE payload, so the previous single-quoted
                    # f-string dict was unparseable and broke on tokens
                    # containing quotes or newlines.
                    yield f"data: {json.dumps({'token': token})}\n\n"

                # Cache the completed stream for future non-streaming hits.
                cache.set(request.prompt, request.language, temperature, full_response)
                yield "data: [DONE]\n\n"

            return StreamingResponse(
                generate(),
                media_type="text/event-stream"
            )

        else:
            response = await model_manager.generate(
                request.prompt,
                request.language,
                temperature,
                max_tokens
            )

            cache.set(request.prompt, request.language, temperature, response)

            return ChatResponse(
                response=response,
                language=request.language,
                model=config.MODEL_NAME,
                timestamp=datetime.now().isoformat(),
                cached=False
            )

    except HTTPException:
        # Preserve deliberate status codes (e.g. the 500 raised inside
        # ModelManager.generate) instead of re-wrapping them.
        raise
    except Exception as e:
        logger.error(f"Chat endpoint error: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e)
        )
| |
|
| | |
| | |
| | |
| |
|
if __name__ == "__main__":
    # Run the ASGI app directly when invoked as a script.
    logger.info(f"Starting server on {config.HOST}:{config.PORT}")
    uvicorn.run(app, host=config.HOST, port=config.PORT, log_level="info")
| |
|