"""
AuraMind REST API Server

Production-ready API for AuraMind smartphone deployment.
"""

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import uvicorn
import logging
import time
from datetime import datetime
import os

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ChatRequest(BaseModel):
    message: str
    mode: str = "Assistant"
    max_tokens: int = 200
    temperature: float = 0.7


class ChatResponse(BaseModel):
    response: str
    mode: str
    inference_time_ms: float
    timestamp: str


class ModelInfo(BaseModel):
    variant: str
    memory_usage: str
    inference_speed: str
    status: str
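
# Illustrative payloads for the schemas above (field values are made up):
#
#   POST /chat
#     {"message": "I feel anxious", "mode": "Therapist",
#      "max_tokens": 150, "temperature": 0.7}
#   200 OK
#     {"response": "...", "mode": "Therapist",
#      "inference_time_ms": 142.57, "timestamp": "2025-01-01T12:00:00"}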


app = FastAPI(
    title="AuraMind API",
    description="Smartphone-optimized dual-mode AI companion API",
    version="1.0.0"
)

# NOTE: browsers reject the combination of wildcard origins and
# allow_credentials=True; restrict allow_origins for production
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model state, populated by load_model() at startup
tokenizer = None
model = None
model_variant = None


def load_model(variant: str = "270m"):
    """Load the AuraMind model and tokenizer."""
    global tokenizer, model, model_variant

    try:
        logger.info(f"Loading AuraMind {variant}...")

        # A single Hugging Face repo is used regardless of variant; the variant
        # string is stored for reporting via /health and /model/info
        model_name = "zail-ai/Auramind"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True
        )

        model.eval()
        model_variant = variant

        logger.info(f"✅ AuraMind {variant} loaded successfully")

    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise


# NOTE: on_event is deprecated in recent FastAPI releases in favor of
# lifespan handlers
@app.on_event("startup")
async def startup_event():
    """Initialize the model on startup."""
    variant = os.getenv("MODEL_VARIANT", "270m")
    load_model(variant)


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "variant": model_variant,
        "timestamp": datetime.now().isoformat()
    }


@app.get("/model/info", response_model=ModelInfo)
async def get_model_info():
    """Get model information."""
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    variant_configs = {
        "270m": {"memory": "~680MB RAM", "speed": "100-300ms"},
        "180m": {"memory": "~450MB RAM", "speed": "80-200ms"},
        "90m": {"memory": "~225MB RAM", "speed": "50-150ms"}
    }

    config = variant_configs.get(model_variant, {"memory": "Unknown", "speed": "Unknown"})

    return ModelInfo(
        variant=model_variant,
        memory_usage=config["memory"],
        inference_speed=config["speed"],
        status="ready"
    )


@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Generate a chat response."""
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    if request.mode not in ["Therapist", "Assistant"]:
        raise HTTPException(status_code=400, detail="Mode must be 'Therapist' or 'Assistant'")

    try:
        start_time = time.time()

        # Turn-based prompt with the mode tag prepended to the user message
        prompt = f"<|start_of_turn|>user\n[{request.mode} Mode] {request.message}<|end_of_turn|>\n<|start_of_turn|>model\n"

        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512
        )
        # Move inputs to the device the model was dispatched to by device_map="auto"
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.1,
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens; splitting the full decode on the
        # turn marker is fragile because skip_special_tokens may strip the marker
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        inference_time = (time.time() - start_time) * 1000

        return ChatResponse(
            response=response,
            mode=request.mode,
            inference_time_ms=round(inference_time, 2),
            timestamp=datetime.now().isoformat()
        )

    except Exception as e:
        logger.error(f"Error generating response: {e}")
        raise HTTPException(status_code=500, detail="Failed to generate response")


@app.post("/chat/batch")
async def chat_batch(requests: List[ChatRequest]):
    """Process multiple chat requests."""
    if len(requests) > 10:
        raise HTTPException(status_code=400, detail="Batch size limited to 10 requests")

    # Requests are generated one at a time, so latency grows linearly with batch size
    responses = []
    for req in requests:
        response = await chat(req)
        responses.append(response)

    return {"responses": responses}


if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.getenv("PORT", 8000)),
        workers=1
    )