# (non-code extraction residue removed: file-size line, commit hash, and a line-number gutter)
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from contextlib import asynccontextmanager
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForSeq2SeqLM
import time
# ==========================================
# 1. GLOBAL STATE & LIFESPAN
# ==========================================
# Process-wide registry for the loaded tokenizer and ONNX model. Kept at
# module level so the objects stay in RAM for the life of the process:
# populated by lifespan() at startup, cleared by it at shutdown.
# Keys used: "tokenizer", "model".
ml_models: dict = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the quantized model at startup, drop it at shutdown.

    On startup the tokenizer and the ONNX seq2seq model are read from disk
    and stashed in the module-level ``ml_models`` dict. A load failure is
    logged but does not abort the app; the /transform endpoint answers 503
    until a model is present. After the ``yield`` (shutdown) the registry
    is emptied so the objects can be garbage-collected.
    """
    # ---------- startup ----------
    print("⚡ Loading Quantized Model into RAM...")
    model_dir = "./tone_slider_model_quantized"
    try:
        ml_models["tokenizer"] = AutoTokenizer.from_pretrained(model_dir)
        # Explicit ONNX filenames so optimum doesn't warn while resolving them;
        # CPU provider matches the quantized-for-CPU artifacts.
        ml_models["model"] = ORTModelForSeq2SeqLM.from_pretrained(
            model_dir,
            encoder_file_name="encoder_model_quantized.onnx",
            decoder_file_name="decoder_model_quantized.onnx",
            decoder_with_past_file_name="decoder_with_past_model_quantized.onnx",
            provider="CPUExecutionProvider",
        )
        print("✅ Model loaded and ready!")
    except Exception as e:
        print(f"❌ Failed to load model: {e}")
    yield
    # ---------- shutdown ----------
    ml_models.clear()
    print("🛑 Model unloaded.")
# ==========================================
# 2. APP SETUP
# ==========================================
app = FastAPI(lifespan=lifespan, title="Tone Slider API")
# CORS middleware so a browser frontend served from another origin can call us.
# NOTE(review): browsers reject the combination of a wildcard origin and
# allow_credentials=True — if the frontend sends cookies/auth, the wildcard
# must be replaced with the explicit origin; confirm with the frontend team.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace "*" with your frontend URL
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ==========================================
# 3. DATA MODELS
# ==========================================
class ToneRequest(BaseModel):
    """Request body for POST /transform."""
    text: str   # the text to be rewritten
    style: str  # "casual" or "formal" (case-insensitive); any other value is treated as formal
class ToneResponse(BaseModel):
    """Response body for POST /transform."""
    original: str      # the input text, echoed back unchanged
    transformed: str   # the model's rewrite of the input
    latency_ms: float  # server-side processing time, milliseconds (rounded to 2 dp)
# ==========================================
# 4. THE ENDPOINT
# ==========================================
@app.post("/transform", response_model=ToneResponse)
def transform_text(request: ToneRequest):
    """Rewrite ``request.text`` in the requested tone ("casual" or "formal").

    Declared as a plain ``def`` (not ``async def``) on purpose: tokenization
    and ``model.generate`` are blocking CPU work, and a sync endpoint is run
    by FastAPI in its worker-thread pool instead of stalling the event loop.

    Returns:
        ToneResponse with the original text, the transformed text, and the
        server-side latency in milliseconds.

    Raises:
        HTTPException: 503 if the model was not loaded at startup.
    """
    if "model" not in ml_models:
        raise HTTPException(status_code=503, detail="Model not loaded")
    tokenizer = ml_models["tokenizer"]
    model = ml_models["model"]
    # perf_counter is monotonic — wall-clock adjustments can't skew the latency.
    start_time = time.perf_counter()

    # 1. Build the task prefix. Anything other than "casual" (case-insensitive)
    # falls back to "formal". The prefix is a fixed literal, so arbitrary
    # client-supplied style strings never reach the model input verbatim.
    style_prefix = "casual: " if request.style.lower() == "casual" else "formal: "
    input_text = style_prefix + request.text

    # 2. Inference — greedy search (do_sample=False) is the fastest on CPU.
    inputs = tokenizer(input_text, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_length=64,
        do_sample=False,
    )

    # 3. Decode the single generated sequence and report elapsed time in ms.
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    duration = (time.perf_counter() - start_time) * 1000
    return ToneResponse(
        original=request.text,
        transformed=result,
        latency_ms=round(duration, 2),
    )
# Run with: uvicorn main:app --reload |