# Author: Devaholic · commit 2addc01 ("First commit")
import time
from contextlib import asynccontextmanager
from typing import Annotated, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import pipeline
# ── Config ────────────────────────────────────────────────────────────────────
# Directory holding the fine-tuned checkpoint; both the model and the tokenizer
# are loaded from it at startup.
# NOTE(review): this path is relative, so it is resolved against the process's
# current working directory, not this file's location. Start the server from
# the project root, or make the path absolute.
MODEL_PATH = "./results/best_model"  # produced by train.py
# ── Lifespan (load model once on startup) ─────────────────────────────────────
# Process-wide registry of loaded ML objects. `lifespan` stores the inference
# pipeline under the "pipe" key at startup and empties the dict at shutdown;
# the route handlers look the pipeline up here on every request.
ml: dict = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the sentiment pipeline once at startup and release it at shutdown.

    FastAPI enters this context manager before serving requests and exits it
    on shutdown. The pipeline is stored in the module-level ``ml`` registry
    under the ``"pipe"`` key, where the route handlers look it up.

    Args:
        app: The FastAPI application (required by the lifespan protocol;
            not used here).
    """
    print(f"Loading model from {MODEL_PATH} ...")
    ml["pipe"] = pipeline(
        "text-classification",
        model=MODEL_PATH,
        tokenizer=MODEL_PATH,
        # Tokenizer options forwarded on every call: inputs are truncated to
        # at most 128 tokens instead of being rejected for length.
        truncation=True,
        max_length=128,
    )
    print("Model ready ✅")
    try:
        yield
    finally:
        # Without try/finally, an exception or cancellation thrown into the
        # generator at `yield` would skip this cleanup entirely.
        ml.clear()
# ── App ───────────────────────────────────────────────────────────────────────
# The ASGI application. `lifespan` manages loading and unloading the model;
# title / version / description only populate the generated OpenAPI docs.
app = FastAPI(
    lifespan=lifespan,
    title="Sentiment API",
    version="1.0.0",
    description="DistilBERT fine-tuned on SST-2 — binary sentiment classification",
)
# ── Schemas ───────────────────────────────────────────────────────────────────
class PredictRequest(BaseModel):
    # Request body for POST /predict: a single text to classify.
    # The 1–512 limit counts characters, not tokens; the pipeline separately
    # truncates the tokenized input.
    text: str = Field(
        ...,
        min_length=1,
        max_length=512,
        # `examples` (a list) replaces the deprecated `example=` keyword,
        # which Pydantic v2 treats as an unknown extra argument and warns about.
        examples=["This movie was absolutely fantastic!"],
    )
class PredictResponse(BaseModel):
    # Result for one classified text; returned by POST /predict and used for
    # each item of a batch response.
    text: str  # the input text, echoed back
    # Predicted class name taken from the model's config.
    # NOTE(review): expected "POSITIVE" | "NEGATIVE"; confirm train.py sets
    # id2label, otherwise generic names such as "LABEL_1" come back.
    label: str
    score: float  # confidence of the predicted label, 0–1, rounded to 4 decimals
    # Inference time in milliseconds (in batch responses, an averaged share of
    # the batch total rather than a per-item measurement).
    latency_ms: float
class BatchRequest(BaseModel):
    # Request body for POST /predict/batch: between 1 and 32 texts.
    # Each item carries the same 1–512 character limit as PredictRequest.text,
    # so the batch endpoint cannot be used to bypass the single-text checks
    # (e.g. by sending empty strings).
    texts: list[Annotated[str, Field(min_length=1, max_length=512)]] = Field(
        ...,
        min_length=1,
        max_length=32,
        # `examples` (a list of example values) replaces the deprecated
        # `example=` keyword.
        examples=[["Great film!", "Terrible waste of time."]],
    )
class BatchResponse(BaseModel):
    # Response body for POST /predict/batch.
    results: list[PredictResponse]  # one entry per input text, in request order
    total_latency_ms: float  # wall-clock time of the whole-batch pipeline call, in ms
# ── Routes ────────────────────────────────────────────────────────────────────
@app.get("/", tags=["health"])
def health():
    # Liveness probe: reports that the service is up and which model path it
    # was configured with. It does not inspect `ml`, so it says nothing about
    # whether the pipeline is currently loaded.
    payload = dict(status="ok", model=MODEL_PATH)
    return payload
@app.post("/predict", response_model=PredictResponse, tags=["inference"])
def predict(req: PredictRequest):
    # Classify one text and return its label, confidence and the time spent
    # inside the pipeline call.
    #
    # Raises HTTPException(503) when the pipeline is missing from `ml`
    # (e.g. after shutdown has cleared it).
    #
    # Single lookup: a separate `"pipe" in ml` test followed by `ml["pipe"]`
    # could hit KeyError (a 500) if shutdown clears `ml` in between.
    pipe = ml.get("pipe")
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    t0 = time.perf_counter()
    # For a single string the pipeline returns a list; take its only
    # {"label", "score"} entry.
    result = pipe(req.text)[0]
    latency_ms = (time.perf_counter() - t0) * 1000
    return PredictResponse(
        text=req.text,
        label=result["label"],
        score=round(result["score"], 4),
        latency_ms=round(latency_ms, 2),
    )
@app.post("/predict/batch", response_model=BatchResponse, tags=["inference"])
def predict_batch(req: BatchRequest):
    # Classify all texts of the request with one pipeline call.
    #
    # The pipeline is timed once for the whole batch. Each item's `latency_ms`
    # is that total divided evenly across items: an average share, not a
    # per-item measurement.
    #
    # Raises HTTPException(503) when the pipeline is missing from `ml`.
    #
    # Single lookup avoids the check-then-index race with shutdown's
    # `ml.clear()` (KeyError -> 500 instead of the intended 503).
    pipe = ml.get("pipe")
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    t0 = time.perf_counter()
    # A list input yields one {"label", "score"} dict per text; the zip below
    # pairs them with the inputs positionally.
    raw = pipe(req.texts)
    total_latency = (time.perf_counter() - t0) * 1000
    # Loop-invariant, so compute it once. BatchRequest enforces at least one
    # text, so the division is safe.
    per_item_latency = round(total_latency / len(req.texts), 2)
    results = [
        PredictResponse(
            text=text,
            label=r["label"],
            score=round(r["score"], 4),
            latency_ms=per_item_latency,
        )
        for text, r in zip(req.texts, raw)
    ]
    return BatchResponse(results=results, total_latency_ms=round(total_latency, 2))