#!/usr/bin/env python3
"""
AuraMind REST API Server
Production-ready API for AuraMind smartphone deployment
"""

from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, List, Dict
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import uvicorn
import logging
import time
from datetime import datetime
import os

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Request/Response models
class ChatRequest(BaseModel):
    message: str
    mode: str = "Assistant"  # "Therapist" or "Assistant"
    max_tokens: int = 200
    temperature: float = 0.7


class ChatResponse(BaseModel):
    response: str
    mode: str
    inference_time_ms: float
    timestamp: str


class ModelInfo(BaseModel):
    variant: str
    memory_usage: str
    inference_speed: str
    status: str


# Initialize FastAPI app
app = FastAPI(
    title="AuraMind API",
    description="Smartphone-optimized dual-mode AI companion API",
    version="1.0.0"
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure appropriately for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model variables
tokenizer = None
model = None
model_variant = None


def load_model(variant: str = "270m"):
    """Load AuraMind model"""
    global tokenizer, model, model_variant

    try:
        logger.info(f"Loading AuraMind {variant}...")

        model_name = "zail-ai/Auramind"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True
        )
        model.eval()
        model_variant = variant

        logger.info(f"✅ AuraMind {variant} loaded successfully")

    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise


@app.on_event("startup")
async def startup_event():
    """Initialize model on startup"""
    variant = os.getenv("MODEL_VARIANT", "270m")
    load_model(variant)


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "variant": model_variant,
        "timestamp": datetime.now().isoformat()
    }


@app.get("/model/info", response_model=ModelInfo)
async def get_model_info():
    """Get model information"""
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    variant_configs = {
        "270m": {"memory": "~680MB RAM", "speed": "100-300ms"},
        "180m": {"memory": "~450MB RAM", "speed": "80-200ms"},
        "90m": {"memory": "~225MB RAM", "speed": "50-150ms"}
    }

    config = variant_configs.get(model_variant, {"memory": "Unknown", "speed": "Unknown"})

    return ModelInfo(
        variant=model_variant,
        memory_usage=config["memory"],
        inference_speed=config["speed"],
        status="ready"
    )


@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Generate chat response"""
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    if request.mode not in ["Therapist", "Assistant"]:
        raise HTTPException(status_code=400, detail="Mode must be 'Therapist' or 'Assistant'")

    try:
        start_time = time.time()

        # Format prompt
        prompt = f"<|start_of_turn|>user\n[{request.mode} Mode] {request.message}<|end_of_turn|>\n<|start_of_turn|>model\n"

        # Tokenize and move inputs to the model's device
        # (required when device_map places the model on GPU)
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512
        )
        inputs = inputs.to(model.device)

        # Generate
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.1,
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens; splitting the full decode on the
        # turn marker is unreliable once skip_special_tokens strips it
        generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

        inference_time = (time.time() - start_time) * 1000

        return ChatResponse(
            response=response,
            mode=request.mode,
            inference_time_ms=round(inference_time, 2),
            timestamp=datetime.now().isoformat()
        )

    except Exception as e:
        logger.error(f"Error generating response: {e}")
        raise HTTPException(status_code=500, detail="Failed to generate response")


@app.post("/chat/batch")
async def chat_batch(requests: List[ChatRequest]):
    """Process multiple chat requests"""
    if len(requests) > 10:  # Limit batch size
        raise HTTPException(status_code=400, detail="Batch size limited to 10 requests")

    responses = []
    for req in requests:
        response = await chat(req)
        responses.append(response)

    return {"responses": responses}


if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.getenv("PORT", 8000)),
        workers=1  # Single worker for model consistency
    )
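
# --- Example client usage (illustrative sketch, not part of the server) ---
# Assumes the server is running locally on the default port 8000; the JSON
# fields mirror the ChatRequest model above. Run from a separate process:
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/chat",
#       json={"message": "I feel overwhelmed today", "mode": "Therapist"},
#       timeout=60,
#   )
#   print(resp.json()["response"])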