# Source: AuraMind deployment/api_server.py
# Uploaded by ibrahim256 via huggingface_hub (commit 9665dce, verified)
#!/usr/bin/env python3
"""
AuraMind REST API Server
Production-ready API for AuraMind smartphone deployment
"""
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, List, Dict
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import uvicorn
import logging
import time
from datetime import datetime
import os
# Configure logging: INFO-level messages to the root handler (stderr by default)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)  # module-level logger, per stdlib convention
# Request/Response models
class ChatRequest(BaseModel):
    """Request body for POST /chat (and items of POST /chat/batch)."""
    message: str  # user utterance to respond to
    mode: str = "Assistant"  # "Therapist" or "Assistant" (validated in the handler)
    max_tokens: int = 200  # cap on newly generated tokens
    temperature: float = 0.7  # sampling temperature passed to generate()
class ChatResponse(BaseModel):
    """Response body returned by POST /chat."""
    response: str  # generated model text
    mode: str  # mode the reply was generated in (echoes the request)
    inference_time_ms: float  # wall-clock generation latency in milliseconds
    timestamp: str  # ISO-8601 time the response was produced
class ModelInfo(BaseModel):
    """Response body for GET /model/info."""
    variant: str  # loaded variant label, e.g. "270m"
    memory_usage: str  # human-readable RAM estimate for this variant
    inference_speed: str  # human-readable latency range for this variant
    status: str  # readiness indicator ("ready")
# Initialize FastAPI app
app = FastAPI(
    title="AuraMind API",
    description="Smartphone-optimized dual-mode AI companion API",
    version="1.0.0"
)
# Add CORS middleware — wide open for development
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure appropriately for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global model state, populated by load_model() during startup.
tokenizer = None  # transformers tokenizer for the loaded model (None until loaded)
model = None  # causal LM instance (None until startup completes)
model_variant = None  # variant label such as "270m" (None until loaded)
def load_model(variant: str = "270m") -> None:
    """Load AuraMind model.

    Populates the module-level ``tokenizer``, ``model`` and ``model_variant``
    globals; intended to be called once from the startup hook.

    Args:
        variant: Variant label recorded in ``model_variant`` (e.g. "270m").
            NOTE(review): the Hugging Face repo id below is hard-coded, so
            ``variant`` does not change which weights are actually loaded —
            confirm whether per-variant repos or revisions exist.

    Raises:
        Exception: re-raises any tokenizer/model loading failure after logging.
    """
    global tokenizer, model, model_variant
    try:
        logger.info(f"Loading AuraMind {variant}...")
        model_name = "zail-ai/Auramind"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # fp16 weights, automatic device placement (GPU when available),
        # and reduced peak host memory while loading checkpoint shards.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True
        )
        model.eval()  # inference mode: disables dropout etc.
        model_variant = variant
        logger.info(f"✅ AuraMind {variant} loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
@app.on_event("startup")
async def startup_event():
    """Load the model once at server start; variant comes from MODEL_VARIANT."""
    load_model(os.getenv("MODEL_VARIANT", "270m"))
@app.get("/health")
async def health_check():
    """Liveness probe: reports whether the model has finished loading."""
    is_loaded = model is not None
    return {
        "status": "healthy",
        "model_loaded": is_loaded,
        "variant": model_variant,
        "timestamp": datetime.now().isoformat(),
    }
@app.get("/model/info", response_model=ModelInfo)
async def get_model_info():
    """Return static memory/latency characteristics of the loaded variant.

    Raises HTTPException 503 when the model has not been loaded yet.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    # (memory footprint, latency range) per known variant label
    specs = {
        "270m": ("~680MB RAM", "100-300ms"),
        "180m": ("~450MB RAM", "80-200ms"),
        "90m": ("~225MB RAM", "50-150ms"),
    }
    memory, speed = specs.get(model_variant, ("Unknown", "Unknown"))
    return ModelInfo(
        variant=model_variant,
        memory_usage=memory,
        inference_speed=speed,
        status="ready",
    )
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Generate a single chat response.

    Validates the requested mode, builds a chat-turn prompt, samples a
    completion, and returns only the newly generated text with latency info.

    Raises:
        HTTPException 503: model/tokenizer not loaded yet.
        HTTPException 400: mode is neither "Therapist" nor "Assistant".
        HTTPException 500: generation failed.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if request.mode not in ["Therapist", "Assistant"]:
        raise HTTPException(status_code=400, detail="Mode must be 'Therapist' or 'Assistant'")
    try:
        start_time = time.time()
        # Format prompt using the model's turn template
        prompt = f"<|start_of_turn|>user\n[{request.mode} Mode] {request.message}<|end_of_turn|>\n<|start_of_turn|>model\n"
        # Tokenize (truncate long messages to the 512-token context window)
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512
        )
        # BUGFIX: with device_map="auto" the model may sit on GPU while the
        # tokenizer emits CPU tensors — move inputs to the model's device.
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        # Generate
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=request.max_tokens,
                # BUGFIX: temperature <= 0 with do_sample=True raises; clamp it.
                temperature=max(request.temperature, 1e-5),
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.1,
                pad_token_id=tokenizer.eos_token_id
            )
        # BUGFIX: skip_special_tokens=True strips "<|start_of_turn|>", so the
        # old split on that marker returned prompt + reply. Decode only the
        # tokens generated after the prompt instead.
        prompt_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        ).strip()
        inference_time = (time.time() - start_time) * 1000
        return ChatResponse(
            response=response,
            mode=request.mode,
            inference_time_ms=round(inference_time, 2),
            timestamp=datetime.now().isoformat()
        )
    except Exception as e:
        # logger.exception keeps the traceback; detail stays generic for clients
        logger.exception(f"Error generating response: {e}")
        raise HTTPException(status_code=500, detail="Failed to generate response")
@app.post("/chat/batch")
async def chat_batch(requests: List[ChatRequest]):
    """Process up to 10 chat requests sequentially in a single call."""
    if len(requests) > 10:  # keep batches small to bound request latency
        raise HTTPException(status_code=400, detail="Batch size limited to 10 requests")
    return {"responses": [await chat(req) for req in requests]}
if __name__ == "__main__":
    # Single worker keeps exactly one in-memory copy of the model weights.
    serve_port = int(os.getenv("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=serve_port, workers=1)