from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from contextlib import asynccontextmanager
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForSeq2SeqLM
import time

# Shared store for the tokenizer and ONNX model, populated during app startup.
ml_models = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load the quantized ONNX model once at startup so every request reuses it.
    print("⚡ Loading Quantized Model into RAM...")
    model_path = "./tone_slider_model_quantized"

    try:
        ml_models["tokenizer"] = AutoTokenizer.from_pretrained(model_path)
        ml_models["model"] = ORTModelForSeq2SeqLM.from_pretrained(
            model_path,
            encoder_file_name="encoder_model_quantized.onnx",
            decoder_file_name="decoder_model_quantized.onnx",
            decoder_with_past_file_name="decoder_with_past_model_quantized.onnx",
            provider="CPUExecutionProvider",
        )
        print("✅ Model loaded and ready!")
    except Exception as e:
        print(f"❌ Failed to load model: {e}")

    yield  # The app serves requests here.

    # Shutdown: release the model and tokenizer.
    ml_models.clear()
    print("🛑 Model unloaded.")

app = FastAPI(lifespan=lifespan, title="Tone Slider API")

# Wide-open CORS so a local front end can call the API during development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class ToneRequest(BaseModel):
    text: str
    style: str


class ToneResponse(BaseModel):
    original: str
    transformed: str
    latency_ms: float

@app.post("/transform", response_model=ToneResponse)
async def transform_text(request: ToneRequest):
    if "model" not in ml_models:
        raise HTTPException(status_code=503, detail="Model not loaded")

    tokenizer = ml_models["tokenizer"]
    model = ml_models["model"]

    start_time = time.time()

    # The model expects a task prefix; anything other than "casual" maps to "formal".
    style_prefix = "casual: " if request.style.lower() == "casual" else "formal: "
    input_text = style_prefix + request.text

    inputs = tokenizer(input_text, return_tensors="pt")

    # Greedy decoding keeps the output deterministic and fast on CPU.
    outputs = model.generate(
        **inputs,
        max_length=64,
        do_sample=False,
    )

    result = tokenizer.decode(outputs[0], skip_special_tokens=True)

    duration = (time.time() - start_time) * 1000

    return ToneResponse(
        original=request.text,
        transformed=result,
        latency_ms=round(duration, 2),
    )
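
# Usage sketch (an illustration, not part of the service): assuming this file is
# saved as main.py, the API can be started with uvicorn and exercised with curl.
# The module name, host, and port below are assumptions; adjust them to your setup.
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
#   curl -X POST http://localhost:8000/transform \
#        -H "Content-Type: application/json" \
#        -d '{"text": "hey can u check this real quick", "style": "formal"}'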