from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from contextlib import asynccontextmanager
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForSeq2SeqLM
import time


# ==========================================
# 1. GLOBAL STATE & LIFESPAN
# ==========================================
# We store the model in a global dictionary so it stays in memory
ml_models = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
    # --- STARTUP LOGIC ---
    print("⚡ Loading Quantized Model into RAM...")
    model_path = "./tone_slider_model_quantized"
    
    try:
        # Load Tokenizer
        ml_models["tokenizer"] = AutoTokenizer.from_pretrained(model_path)
        
        # Load ONNX Model (Explicit filenames to avoid warnings)
        ml_models["model"] = ORTModelForSeq2SeqLM.from_pretrained(
            model_path,
            encoder_file_name="encoder_model_quantized.onnx",
            decoder_file_name="decoder_model_quantized.onnx",
            decoder_with_past_file_name="decoder_with_past_model_quantized.onnx",
            provider="CPUExecutionProvider"
        )
        print("✅ Model loaded and ready!")
    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        
    yield
    
    # --- SHUTDOWN LOGIC ---
    ml_models.clear()
    print("🛑 Model unloaded.")

# ==========================================
# 2. APP SETUP
# ==========================================
app = FastAPI(lifespan=lifespan, title="Tone Slider API")

# Enable CORS (Allows your frontend to talk to this API)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], # In production, replace "*" with your frontend URL
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
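
# A locked-down production configuration might look like this (the origin is a
# hypothetical placeholder; substitute your real frontend URL):
#
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["https://your-frontend.example.com"],
#     allow_credentials=True,
#     allow_methods=["POST"],
#     allow_headers=["Content-Type"],
# )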

# ==========================================
# 3. DATA MODELS
# ==========================================
class ToneRequest(BaseModel):
    text: str
    style: str  # Expecting "casual" or "formal"

class ToneResponse(BaseModel):
    original: str
    transformed: str
    latency_ms: float

# ==========================================
# 4. THE ENDPOINT
# ==========================================
@app.post("/transform", response_model=ToneResponse)
async def transform_text(request: ToneRequest):
    if "model" not in ml_models:
        raise HTTPException(status_code=503, detail="Model not loaded")

    tokenizer = ml_models["tokenizer"]
    model = ml_models["model"]
    
    start_time = time.perf_counter()  # monotonic clock, better suited to latency timing

    # 1. Preprocess Input
    # Whitelist the style so arbitrary strings can't be injected into the
    # prompt; anything other than the two supported values is rejected.
    style = request.style.lower().strip()
    if style not in ("casual", "formal"):
        raise HTTPException(status_code=400, detail="style must be 'casual' or 'formal'")
    input_text = f"{style}: {request.text}"

    # 2. Inference (truncate overlong inputs to the tokenizer's model_max_length)
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True)

    # Greedy decoding (do_sample=False) is the cheapest strategy on CPU
    outputs = model.generate(
        **inputs,
        max_length=64,
        do_sample=False,
    )

    # 3. Decode
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    duration = (time.perf_counter() - start_time) * 1000

    return ToneResponse(
        original=request.text,
        transformed=result,
        latency_ms=round(duration, 2)
    )

# Run with: uvicorn main:app --reload
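#
# Dependencies (inferred from the imports above):
#   pip install fastapi uvicorn transformers "optimum[onnxruntime]"
#
# Example request once the server is up (hypothetical input; the transformed
# text depends on your fine-tuned model):
#   curl -X POST http://127.0.0.1:8000/transform \
#        -H "Content-Type: application/json" \
#        -d '{"text": "hey, what is up?", "style": "formal"}'
#
# Response shape: {"original": ..., "transformed": ..., "latency_ms": ...}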