import os
import time
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForSeq2SeqLM

# ==========================================
# 1. GLOBAL STATE & LIFESPAN
# ==========================================

# Module-level registry so the model stays resident in RAM between requests.
# Keys: "tokenizer", "model". Either both are present or the dict is empty.
ml_models = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the quantized seq2seq model on startup; free it on shutdown.

    The model directory defaults to ``./tone_slider_model_quantized`` but can
    be overridden via the ``MODEL_PATH`` environment variable (backward-
    compatible generalization of the previously hard-coded path).

    Startup is best-effort: on failure the server still comes up and the
    /transform endpoint answers 503 until the model is available.
    """
    # --- STARTUP LOGIC ---
    print("⚡ Loading Quantized Model into RAM...")
    model_path = os.getenv("MODEL_PATH", "./tone_slider_model_quantized")
    try:
        # Load Tokenizer
        ml_models["tokenizer"] = AutoTokenizer.from_pretrained(model_path)
        # Load ONNX Model (explicit filenames avoid optimum's
        # "multiple ONNX files found" warnings).
        ml_models["model"] = ORTModelForSeq2SeqLM.from_pretrained(
            model_path,
            encoder_file_name="encoder_model_quantized.onnx",
            decoder_file_name="decoder_model_quantized.onnx",
            decoder_with_past_file_name="decoder_with_past_model_quantized.onnx",
            provider="CPUExecutionProvider",
        )
        print("✅ Model loaded and ready!")
    except Exception as e:
        # BUGFIX: drop any partially-loaded artifacts (e.g. a tokenizer whose
        # model failed to load) so ml_models is always all-or-nothing and the
        # endpoint's `"model" not in ml_models` check tells the whole truth.
        ml_models.clear()
        print(f"❌ Failed to load model: {e}")

    yield

    # --- SHUTDOWN LOGIC ---
    ml_models.clear()
    print("🛑 Model unloaded.")


# ==========================================
# 2. APP SETUP
# ==========================================
app = FastAPI(lifespan=lifespan, title="Tone Slider API")

# Enable CORS (allows the frontend to talk to this API).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace "*" with your frontend URL
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ==========================================
# 3. DATA MODELS
# ==========================================
class ToneRequest(BaseModel):
    """Request payload for POST /transform."""

    text: str
    style: str  # Expecting "casual" or "formal"


class ToneResponse(BaseModel):
    """Response payload for POST /transform."""

    original: str
    transformed: str
    latency_ms: float


# ==========================================
# 4.
# THE ENDPOINT
# ==========================================
@app.post("/transform", response_model=ToneResponse)
async def transform_text(request: ToneRequest):
    """Rewrite ``request.text`` in the requested tone.

    Args:
        request: carries ``text`` (the sentence to rewrite) and ``style``
            (must be "casual" or "formal", case-insensitive).

    Returns:
        ToneResponse with the original text, the transformed text, and the
        inference latency in milliseconds.

    Raises:
        HTTPException 503: while the model has not been loaded.
        HTTPException 400: for an unrecognized ``style`` value.
    """
    if "model" not in ml_models:
        raise HTTPException(status_code=503, detail="Model not loaded")

    tokenizer = ml_models["tokenizer"]
    model = ml_models["model"]

    # perf_counter is a monotonic clock — unlike time.time(), it cannot jump
    # backwards/forwards with NTP adjustments, so latency is always sane.
    start_time = time.perf_counter()

    # 1. Preprocess Input
    # BUGFIX: the old code claimed to "validate" the style but silently mapped
    # every unknown value to "formal". Reject unknown styles explicitly.
    style = request.style.lower()
    if style not in ("casual", "formal"):
        raise HTTPException(
            status_code=400,
            detail="style must be 'casual' or 'formal'",
        )
    input_text = f"{style}: " + request.text

    # 2. Inference (greedy search is fastest for CPU)
    inputs = tokenizer(input_text, return_tensors="pt")
    outputs = model.generate(**inputs, max_length=64, do_sample=False)

    # 3. Decode
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)

    duration = (time.perf_counter() - start_time) * 1000
    return ToneResponse(
        original=request.text,
        transformed=result,
        latency_ms=round(duration, 2),
    )


# Run with: uvicorn main:app --reload