# main.py — renamed from app.py to match the Dockerfile (commit c76dd2a, repo laminou/main).
"""
AI API Server - FastAPI + Mistral-7B
Production-ready API with streaming, authentication, and caching
"""
import asyncio
import hashlib
import hmac
import json
import logging
import os
from datetime import datetime
from functools import lru_cache
from typing import Optional, Dict, Any, AsyncGenerator

from fastapi import FastAPI, HTTPException, Header, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse
from pydantic import BaseModel, Field, validator
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import uvicorn
# ============================================================================
# LOGGING CONFIGURATION
# ============================================================================
# Configure root logging at INFO level, mirrored to both stderr and a local
# "api.log" file so container output and on-disk logs stay in sync.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler("api.log")
    ]
)
# Module-level logger shared by every class and route handler in this file.
logger = logging.getLogger(__name__)
# ============================================================================
# CONFIGURATION
# ============================================================================
class Config:
    """Application configuration, read from environment variables with defaults."""
    # Hugging Face model id to load.
    MODEL_NAME = os.getenv("MODEL_NAME", "mistralai/Mistral-7B-Instruct-v0.2")
    # Shared secret expected in the X-API-Key header.
    # NOTE(review): the default is a placeholder — deployments must set API_KEY.
    API_KEY = os.getenv("API_KEY", "your-secret-api-key-here")
    # Maximum tokenized input length passed to the tokenizer (truncation limit).
    MAX_LENGTH = int(os.getenv("MAX_LENGTH", "2048"))
    # Default sampling temperature, used when a request omits one.
    TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
    # Nucleus-sampling cutoff passed to model.generate().
    TOP_P = float(os.getenv("TOP_P", "0.95"))
    # Maximum number of entries held by the in-memory response cache.
    CACHE_SIZE = int(os.getenv("CACHE_SIZE", "100"))
    # Bind port/address for uvicorn when run as a script.
    PORT = int(os.getenv("PORT", "7860"))
    HOST = os.getenv("HOST", "0.0.0.0")
    # Quantization config for 4-bit loading (optimized for free hardware)
    QUANTIZATION_CONFIG = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4"
    )

# Singleton configuration instance used throughout the module.
config = Config()
# ============================================================================
# PYDANTIC MODELS
# ============================================================================
class ChatRequest(BaseModel):
    """Request body for POST /api/chat."""
    # User message; length-bounded to keep prompts within model context limits.
    prompt: str = Field(..., min_length=1, max_length=4000, description="User prompt")
    # Language code selecting the system prompt (validated below).
    language: str = Field(default="en", description="Response language (en, pt, es)")
    # Sampling temperature; None means "use the server default".
    temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
    # Maximum tokens to generate; None means "use the server default".
    max_tokens: Optional[int] = Field(default=None, ge=1, le=4096)
    # When True the endpoint answers with a Server-Sent Events stream.
    stream: bool = Field(default=True, description="Enable streaming response")

    # NOTE(review): `validator` is the Pydantic v1 API (deprecated in v2) —
    # fine if the project pins pydantic<2, otherwise migrate to field_validator.
    @validator("language")
    def validate_language(cls, v):
        """Reject language codes that have no configured system prompt."""
        allowed = ["en", "pt", "es", "fr", "de", "it", "ja", "zh"]
        if v not in allowed:
            raise ValueError(f"Language must be one of {allowed}")
        return v
class ChatResponse(BaseModel):
    """Response body for non-streaming /api/chat calls."""
    response: str  # generated (or cached) model output
    language: str  # language echoed back from the request
    model: str  # model name used for generation
    timestamp: str  # ISO-8601 server timestamp
    cached: bool = False  # True when the answer was served from the cache
class HealthResponse(BaseModel):
    """Response body for GET /health."""
    status: str  # always "healthy" when the process is serving requests
    model_loaded: bool  # whether ModelManager.load_model() has succeeded
    timestamp: str  # ISO-8601 server timestamp
# ============================================================================
# SYSTEM PROMPTS (MULTI-LANGUAGE)
# ============================================================================
# Per-language system prompts injected into the [INST] template by
# ModelManager.format_prompt(); keys match ChatRequest's allowed languages,
# with "en" used as the fallback for unknown codes.
SYSTEM_PROMPTS = {
    "en": "You are a helpful, respectful and honest AI assistant. Always answer as helpfully as possible, while being safe. If you don't know the answer, say so instead of making up information.",
    "pt": "Você é um assistente de IA útil, respeitoso e honesto. Sempre responda da forma mais útil possível, mantendo a segurança. Se não souber a resposta, diga isso ao invés de inventar informações.",
    "es": "Eres un asistente de IA útil, respetuoso y honesto. Siempre responde de la manera más útil posible, manteniendo la seguridad. Si no sabes la respuesta, dilo en lugar de inventar información.",
    "fr": "Vous êtes un assistant IA utile, respectueux et honnête. Répondez toujours de la manière la plus utile possible, tout en restant sûr. Si vous ne connaissez pas la réponse, dites-le au lieu d'inventer des informations.",
    "de": "Sie sind ein hilfreicher, respektvoller und ehrlicher KI-Assistent. Antworten Sie immer so hilfreich wie möglich und bleiben Sie dabei sicher. Wenn Sie die Antwort nicht wissen, sagen Sie es, anstatt Informationen zu erfinden.",
    "it": "Sei un assistente AI utile, rispettoso e onesto. Rispondi sempre nel modo più utile possibile, mantenendo la sicurezza. Se non conosci la risposta, dillo invece di inventare informazioni.",
    "ja": "あなたは親切で、礼儀正しく、正直なAIアシスタントです。常に安全を保ちながら、できるだけ役立つように答えてください。答えがわからない場合は、情報を作り上げるのではなく、そう言ってください。",
    "zh": "你是一个乐于助人、尊重他人且诚实的AI助手。在保持安全的同时,始终尽可能有帮助地回答。如果你不知道答案,请说出来,而不是编造信息。"
}
# ============================================================================
# SIMPLE CACHE IMPLEMENTATION
# ============================================================================
class ResponseCache:
    """In-memory response cache keyed by (prompt, language, temperature).

    Entries are timestamped on insertion; once the cache reaches max_size,
    the entry with the oldest timestamp is evicted before a new one is stored.
    """

    def __init__(self, max_size: int = 100):
        # key -> (response text, insertion time)
        self.cache: Dict[str, tuple[str, datetime]] = {}
        self.max_size = max_size
        logger.info(f"Initialized cache with max size: {max_size}")

    def _generate_key(self, prompt: str, language: str, temperature: float) -> str:
        """Hash the request parameters into a fixed-size cache key."""
        return hashlib.md5(f"{prompt}:{language}:{temperature}".encode()).hexdigest()

    def get(self, prompt: str, language: str, temperature: float) -> Optional[str]:
        """Return the cached response for these parameters, or None on a miss."""
        key = self._generate_key(prompt, language, temperature)
        entry = self.cache.get(key)
        if entry is not None:
            logger.info(f"Cache HIT for key: {key[:8]}...")
            return entry[0]
        logger.info(f"Cache MISS for key: {key[:8]}...")
        return None

    def set(self, prompt: str, language: str, temperature: float, response: str):
        """Store a response, evicting the oldest entry when at capacity."""
        if len(self.cache) >= self.max_size:
            # Single pass over items() finds the entry with the oldest timestamp.
            oldest_key, _ = min(self.cache.items(), key=lambda item: item[1][1])
            del self.cache[oldest_key]
            logger.info(f"Cache full, removed oldest entry: {oldest_key[:8]}...")
        key = self._generate_key(prompt, language, temperature)
        self.cache[key] = (response, datetime.now())
        logger.info(f"Cached response for key: {key[:8]}...")
# ============================================================================
# MODEL LOADING
# ============================================================================
class ModelManager:
    """Loads the tokenizer/model and runs inference (streaming and one-shot)."""

    def __init__(self):
        # Populated later by load_model() during app startup; /api/chat
        # returns 503 while self.model is still None.
        self.model = None
        self.tokenizer = None
        # Preferred device for tokenized inputs; GPU when available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Device: {self.device}")

    async def load_model(self):
        """Load the tokenizer and the 4-bit quantized model.

        Returns:
            bool: True on success, False on failure. Errors are logged rather
            than raised so the server can still start and report its state.
        """
        try:
            logger.info(f"Loading model: {config.MODEL_NAME}")
            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                config.MODEL_NAME,
                trust_remote_code=True
            )
            # Load model with 4-bit quantization
            self.model = AutoModelForCausalLM.from_pretrained(
                config.MODEL_NAME,
                quantization_config=config.QUANTIZATION_CONFIG,
                device_map="auto",
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
            logger.info("Model loaded successfully!")
            return True
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}", exc_info=True)
            return False

    def format_prompt(self, prompt: str, language: str) -> str:
        """Wrap the user prompt in the Mistral [INST] template with the
        per-language system prompt (falls back to English)."""
        system_prompt = SYSTEM_PROMPTS.get(language, SYSTEM_PROMPTS["en"])
        return f"<s>[INST] {system_prompt}\n\nUser: {prompt} [/INST]"

    async def generate_stream(
        self,
        prompt: str,
        language: str,
        temperature: float,
        max_tokens: int
    ) -> AsyncGenerator[str, None]:
        """Yield the response one decoded token at a time.

        On error, logs and yields a trailing "[Error: ...]" marker instead of
        raising, so the SSE stream can terminate gracefully.
        """
        try:
            formatted_prompt = self.format_prompt(prompt, language)
            # Tokenize input
            # NOTE(review): with device_map="auto" the model may be sharded;
            # moving inputs to self.device assumes the input embeddings live
            # there — confirm on multi-GPU setups.
            inputs = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=config.MAX_LENGTH
            ).to(self.device)
            # Generate with streaming
            # NOTE(review): calling generate() once per token re-runs the full
            # forward pass over the whole sequence every step (quadratic cost);
            # transformers' TextIteratorStreamer is the usual alternative.
            with torch.no_grad():
                for i in range(max_tokens):
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=1,
                        temperature=temperature,
                        top_p=config.TOP_P,
                        do_sample=True,
                        pad_token_id=self.tokenizer.eos_token_id
                    )
                    # Decode new token
                    new_token = self.tokenizer.decode(
                        outputs[0][-1:],
                        skip_special_tokens=True
                    )
                    # Check for end of sequence — break before yielding EOS itself.
                    if outputs[0][-1] == self.tokenizer.eos_token_id:
                        break
                    yield new_token
                    # Update inputs for next iteration
                    # NOTE(review): this drops the attention_mask from the
                    # original encoding; generate() then has to infer it —
                    # verify output quality is unaffected.
                    inputs = {"input_ids": outputs}
                    # Small delay to simulate realistic streaming
                    await asyncio.sleep(0.01)
        except Exception as e:
            logger.error(f"Generation error: {str(e)}", exc_info=True)
            yield f"\n\n[Error: {str(e)}]"

    async def generate(
        self,
        prompt: str,
        language: str,
        temperature: float,
        max_tokens: int
    ) -> str:
        """Generate the complete response in one pass (non-streaming).

        Raises:
            HTTPException: 500 when generation fails.
        """
        try:
            formatted_prompt = self.format_prompt(prompt, language)
            inputs = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=config.MAX_LENGTH
            ).to(self.device)
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    top_p=config.TOP_P,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )
            # Slice off the prompt tokens so only newly generated text is decoded.
            response = self.tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[1]:],
                skip_special_tokens=True
            )
            return response.strip()
        except Exception as e:
            logger.error(f"Generation error: {str(e)}", exc_info=True)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Generation failed: {str(e)}"
            )
# ============================================================================
# FASTAPI APPLICATION
# ============================================================================
# FastAPI application; interactive docs served at /docs and /redoc.
app = FastAPI(
    title="AI API - Mistral 7B",
    description="Production-ready AI API with streaming, authentication, and caching",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)
# CORS middleware
# NOTE(review): browsers reject allow_origins=["*"] combined with
# allow_credentials=True (the CORS spec forbids wildcard origins with
# credentials) — confirm whether credentialed cross-origin requests are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global instances shared by all route handlers (single-process deployment).
model_manager = ModelManager()
cache = ResponseCache(max_size=config.CACHE_SIZE)
# ============================================================================
# AUTHENTICATION
# ============================================================================
async def verify_api_key(x_api_key: str = Header(..., alias="X-API-Key")):
    """Validate the X-API-Key request header against the configured key.

    Args:
        x_api_key: Value of the X-API-Key header (required by FastAPI).

    Returns:
        The validated API key.

    Raises:
        HTTPException: 401 when the key does not match.
    """
    # FIX: use a constant-time comparison instead of `!=` so response timing
    # cannot leak how many leading characters of the key were correct.
    if not hmac.compare_digest(x_api_key, config.API_KEY):
        # Log only a short prefix so full (wrong) keys never land in the log file.
        logger.warning(f"Invalid API key attempt: {x_api_key[:8]}...")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key"
        )
    return x_api_key
# ============================================================================
# STARTUP/SHUTDOWN EVENTS
# ============================================================================
@app.on_event("startup")
async def startup_event():
    """Load the model once at process startup.

    NOTE(review): @app.on_event is deprecated in recent FastAPI releases in
    favor of lifespan handlers — fine if the installed version is pinned.
    """
    logger.info("Starting AI API server...")
    success = await model_manager.load_model()
    if not success:
        # Keep serving anyway: /health reports model_loaded=False and
        # /api/chat answers 503 until the model is available.
        logger.error("Failed to load model, server may not function correctly")
@app.on_event("shutdown")
async def shutdown_event():
    """Release in-memory state when the server stops."""
    logger.info("Shutting down AI API server...")
    # Clear cache
    cache.cache.clear()
# ============================================================================
# ROUTES
# ============================================================================
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the single-page chat frontend.

    The page is a self-contained HTML/JS client (Tailwind via CDN) that calls
    POST /api/chat with the X-API-Key header and renders the SSE token stream.
    """
    # The HTML below is served verbatim; it is part of the runtime behavior.
    html_content = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>AI API - Mistral 7B</title>
<script src="https://cdn.tailwindcss.com"></script>
<style>
@keyframes fadeIn {
from { opacity: 0; transform: translateY(10px); }
to { opacity: 1; transform: translateY(0); }
}
.message {
animation: fadeIn 0.3s ease-out;
}
.typing-indicator {
display: inline-block;
}
.typing-indicator span {
display: inline-block;
width: 8px;
height: 8px;
border-radius: 50%;
background-color: #6366F1;
margin: 0 2px;
animation: typing 1.4s infinite;
}
.typing-indicator span:nth-child(2) {
animation-delay: 0.2s;
}
.typing-indicator span:nth-child(3) {
animation-delay: 0.4s;
}
@keyframes typing {
0%, 60%, 100% { transform: translateY(0); }
30% { transform: translateY(-10px); }
}
</style>
</head>
<body class="bg-gradient-to-br from-slate-50 to-slate-100 min-h-screen">
<div class="container mx-auto px-4 py-8 max-w-4xl">
<!-- Header -->
<div class="bg-white rounded-2xl shadow-lg p-6 mb-6">
<h1 class="text-3xl font-bold text-slate-800 mb-2">🤖 AI API - Mistral 7B</h1>
<p class="text-slate-600">Production-ready AI API with streaming responses</p>
</div>
<!-- API Key Section -->
<div class="bg-white rounded-2xl shadow-lg p-6 mb-6">
<label class="block text-sm font-semibold text-slate-700 mb-2">API Key</label>
<input
type="password"
id="apiKey"
placeholder="Enter your API key"
class="w-full px-4 py-3 border border-slate-300 rounded-lg focus:ring-2 focus:ring-indigo-500 focus:border-transparent outline-none"
/>
</div>
<!-- Chat Interface -->
<div class="bg-white rounded-2xl shadow-lg p-6 mb-6">
<div id="messages" class="space-y-4 mb-6 max-h-96 overflow-y-auto">
<div class="text-center text-slate-400 py-8">
Start a conversation by typing a message below
</div>
</div>
<!-- Input Area -->
<div class="flex gap-3">
<select id="language" class="px-4 py-3 border border-slate-300 rounded-lg focus:ring-2 focus:ring-indigo-500 outline-none">
<option value="en">English</option>
<option value="pt">Português</option>
<option value="es">Español</option>
<option value="fr">Français</option>
<option value="de">Deutsch</option>
<option value="it">Italiano</option>
</select>
<input
type="text"
id="prompt"
placeholder="Type your message..."
class="flex-1 px-4 py-3 border border-slate-300 rounded-lg focus:ring-2 focus:ring-indigo-500 focus:border-transparent outline-none"
/>
<button
onclick="sendMessage()"
id="sendBtn"
class="px-6 py-3 bg-indigo-600 text-white font-semibold rounded-lg hover:bg-indigo-700 transition-colors disabled:bg-slate-300 disabled:cursor-not-allowed"
>
Send
</button>
</div>
</div>
<!-- Documentation -->
<div class="bg-white rounded-2xl shadow-lg p-6">
<h2 class="text-xl font-bold text-slate-800 mb-4">📚 API Documentation</h2>
<div class="space-y-4">
<div>
<h3 class="font-semibold text-slate-700 mb-2">Endpoint</h3>
<code class="block bg-slate-100 p-3 rounded-lg text-sm">POST /api/chat</code>
</div>
<div>
<h3 class="font-semibold text-slate-700 mb-2">Headers</h3>
<code class="block bg-slate-100 p-3 rounded-lg text-sm">X-API-Key: your-api-key</code>
</div>
<div>
<h3 class="font-semibold text-slate-700 mb-2">Example (curl)</h3>
<pre class="bg-slate-100 p-3 rounded-lg text-sm overflow-x-auto"><code>curl -X POST "http://localhost:7860/api/chat" \\
-H "X-API-Key: your-secret-api-key-here" \\
-H "Content-Type: application/json" \\
-d '{
"prompt": "Explain quantum computing",
"language": "en",
"stream": false
}'</code></pre>
</div>
</div>
</div>
</div>
<script>
const messagesDiv = document.getElementById('messages');
const promptInput = document.getElementById('prompt');
const apiKeyInput = document.getElementById('apiKey');
const languageSelect = document.getElementById('language');
const sendBtn = document.getElementById('sendBtn');
// Load API key from localStorage
const savedApiKey = localStorage.getItem('apiKey');
if (savedApiKey) {
apiKeyInput.value = savedApiKey;
}
// Save API key on change
apiKeyInput.addEventListener('change', () => {
localStorage.setItem('apiKey', apiKeyInput.value);
});
// Send on Enter
promptInput.addEventListener('keypress', (e) => {
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
sendMessage();
}
});
function addMessage(content, isUser = false) {
if (messagesDiv.children[0]?.textContent.includes('Start a conversation')) {
messagesDiv.innerHTML = '';
}
const messageDiv = document.createElement('div');
messageDiv.className = `message flex ${isUser ? 'justify-end' : 'justify-start'}`;
const bubble = document.createElement('div');
bubble.className = `max-w-[70%] px-4 py-3 rounded-2xl ${
isUser
? 'bg-indigo-600 text-white'
: 'bg-slate-100 text-slate-800'
}`;
bubble.textContent = content;
messageDiv.appendChild(bubble);
messagesDiv.appendChild(messageDiv);
messagesDiv.scrollTop = messagesDiv.scrollHeight;
return bubble;
}
function addTypingIndicator() {
const messageDiv = document.createElement('div');
messageDiv.className = 'message flex justify-start';
messageDiv.id = 'typing-indicator';
const bubble = document.createElement('div');
bubble.className = 'max-w-[70%] px-4 py-3 rounded-2xl bg-slate-100';
bubble.innerHTML = '<div class="typing-indicator"><span></span><span></span><span></span></div>';
messageDiv.appendChild(bubble);
messagesDiv.appendChild(messageDiv);
messagesDiv.scrollTop = messagesDiv.scrollHeight;
}
function removeTypingIndicator() {
const indicator = document.getElementById('typing-indicator');
if (indicator) {
indicator.remove();
}
}
async function sendMessage() {
const prompt = promptInput.value.trim();
const apiKey = apiKeyInput.value.trim();
const language = languageSelect.value;
if (!prompt) return;
if (!apiKey) {
alert('Please enter your API key');
return;
}
// Add user message
addMessage(prompt, true);
promptInput.value = '';
// Disable send button
sendBtn.disabled = true;
addTypingIndicator();
try {
const response = await fetch('/api/chat', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-API-Key': apiKey
},
body: JSON.stringify({
prompt: prompt,
language: language,
stream: true
})
});
if (!response.ok) {
throw new Error(`HTTP ${response.status}: ${await response.text()}`);
}
removeTypingIndicator();
const bubble = addMessage('', false);
// Read stream
const reader = response.body.getReader();
const decoder = new TextDecoder();
let fullResponse = '';
while (true) {
const { done, value } = await reader.read();
if (done) break;
const chunk = decoder.decode(value);
const lines = chunk.split('\\n');
for (const line of lines) {
if (line.startsWith('data: ')) {
const data = line.slice(6);
if (data === '[DONE]') break;
try {
const json = JSON.parse(data);
if (json.token) {
fullResponse += json.token;
bubble.textContent = fullResponse;
messagesDiv.scrollTop = messagesDiv.scrollHeight;
}
} catch (e) {
console.error('Parse error:', e);
}
}
}
}
} catch (error) {
removeTypingIndicator();
addMessage(`Error: ${error.message}`, false);
} finally {
sendBtn.disabled = false;
promptInput.focus();
}
}
</script>
</body>
</html>
"""
    return HTMLResponse(content=html_content)
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Liveness/readiness probe: reports whether the model finished loading."""
    return HealthResponse(
        status="healthy",
        model_loaded=model_manager.model is not None,
        timestamp=datetime.now().isoformat()
    )
@app.post("/api/chat")
async def chat(
    request: ChatRequest,
    api_key: str = Header(..., alias="X-API-Key")
):
    """
    Chat endpoint with streaming support.

    Requires X-API-Key header for authentication.

    When request.stream is True, answers with a Server-Sent Events stream
    whose frames are `data: {"token": "..."}` JSON objects terminated by
    `data: [DONE]`; otherwise returns a single ChatResponse JSON body.

    Raises:
        HTTPException: 401 (bad key), 503 (model not loaded), 500 (failure).
    """
    # Verify API key (raises 401 on mismatch).
    await verify_api_key(api_key)
    # Reject early while the model is still loading (or failed to load).
    if model_manager.model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model not loaded yet, please try again later"
        )
    # FIX: compare against None instead of truthiness so an explicit
    # temperature of 0.0 is honored rather than silently replaced by the
    # server default (the field validates ge=0.0, so 0.0 is a legal value).
    temperature = config.TEMPERATURE if request.temperature is None else request.temperature
    max_tokens = 512 if request.max_tokens is None else request.max_tokens
    try:
        # Serve non-streaming requests from the cache when possible.
        if not request.stream:
            cached_response = cache.get(request.prompt, request.language, temperature)
            if cached_response:
                return ChatResponse(
                    response=cached_response,
                    language=request.language,
                    model=config.MODEL_NAME,
                    timestamp=datetime.now().isoformat(),
                    cached=True
                )
        # Streaming response
        if request.stream:
            async def generate():
                """SSE generator: one JSON frame per token, then [DONE]."""
                full_response = ""
                async for token in model_manager.generate_stream(
                    request.prompt,
                    request.language,
                    temperature,
                    max_tokens
                ):
                    full_response += token
                    # FIX: serialize the frame with json.dumps. The previous
                    # f-string emitted {'token': '...'} — single quotes are not
                    # valid JSON, so the frontend's JSON.parse rejected every
                    # frame, and unescaped quotes/newlines in a token could
                    # corrupt the SSE framing.
                    yield f"data: {json.dumps({'token': token})}\n\n"
                # Cache the fully assembled response for later non-stream hits.
                cache.set(request.prompt, request.language, temperature, full_response)
                yield "data: [DONE]\n\n"
            return StreamingResponse(
                generate(),
                media_type="text/event-stream"
            )
        # Non-streaming response
        else:
            response = await model_manager.generate(
                request.prompt,
                request.language,
                temperature,
                max_tokens
            )
            # Cache response
            cache.set(request.prompt, request.language, temperature, response)
            return ChatResponse(
                response=response,
                language=request.language,
                model=config.MODEL_NAME,
                timestamp=datetime.now().isoformat(),
                cached=False
            )
    except HTTPException:
        # FIX: let deliberate HTTP errors (e.g. the 500 raised by generate())
        # propagate with their original detail instead of being re-wrapped.
        raise
    except Exception as e:
        logger.error(f"Chat endpoint error: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e)
        )
# ============================================================================
# MAIN
# ============================================================================
if __name__ == "__main__":
    # Run uvicorn directly when executed as a script (e.g. inside the container).
    logger.info(f"Starting server on {config.HOST}:{config.PORT}")
    uvicorn.run(
        app,
        host=config.HOST,
        port=config.PORT,
        log_level="info"
    )