Spam-Detection / main.py
ibraheem15
Refactor Dockerfile for multi-stage builds and enhance main.py with CORS middleware and updated model
a1099ee
# main.py
import os
import traceback
from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
# --- MLOps Configuration ---
# Model identifier handed to `from_pretrained` for both tokenizer and model.
HF_MODEL_NAME: str = "mrm8488/bert-tiny-finetuned-sms-spam-detection"

# Shared inference pipeline: assigned by the lifespan handler after a
# successful load, and None before that (or after shutdown / a failed load).
CLASSIFIER_PIPELINE = None
# Pydantic model for request body
class Message(BaseModel):
    """JSON body accepted by ``POST /predict``."""

    text: str  # the message content to classify
# --- LIFESPAN MANAGER (The Modern Fix) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the spam classifier at startup and drop it at shutdown.

    Loading is best-effort: if it fails, the error and its traceback are
    printed and the application still starts with ``CLASSIFIER_PIPELINE``
    left as ``None``, so the health check reports ``model_loaded: false``
    and ``/predict`` answers 503 instead of the process crashing.

    Args:
        app: The FastAPI application; required by the lifespan protocol but
            not used here.
    """
    global CLASSIFIER_PIPELINE
    print(f"Loading model {HF_MODEL_NAME}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(HF_MODEL_NAME)
        CLASSIFIER_PIPELINE = pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            # First CUDA device when available, otherwise CPU (-1).
            device=0 if torch.cuda.is_available() else -1,
        )
        print("Model loaded successfully.")
    except Exception as e:
        # Deliberately not re-raised (see docstring); keep the traceback so a
        # failed download / incompatible checkpoint can actually be diagnosed.
        print(f"Failed to load model: {e}")
        traceback.print_exc()
    try:
        yield  # The application serves requests while suspended here.
    finally:
        # `finally` guarantees cleanup even when the lifespan is torn down by
        # an exception or cancellation rather than a clean shutdown.
        print("Shutting down model resources...")
        CLASSIFIER_PIPELINE = None
# Initialize App with Lifespan
app = FastAPI(
    title="Spam Detection API",
    lifespan=lifespan,
)

# Browser origins allowed by CORS, as a comma-separated env var.
# Defaults to "*" (any origin), matching the previous behaviour.
_CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ALLOW_ORIGINS,
    # Browsers reject a literal "*" on credentialed requests, so Starlette
    # echoes the caller's origin instead -- which would let *any* site send
    # credentialed requests. Only allow credentials for pinned origins.
    allow_credentials="*" not in _CORS_ALLOW_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
def health_check():
    """Liveness probe that also reports whether the classifier is loaded."""
    model_ready = CLASSIFIER_PIPELINE is not None
    return {"status": "ok", "model_loaded": model_ready}
@app.post("/predict")
def predict_spam(item: Message):
    """Classify ``item.text`` as ``"spam"`` or ``"ham"``.

    Input longer than 512 tokens is truncated before inference.

    Returns:
        ``{"prediction": "spam" | "ham", "confidence_score": <score of the
        top label as returned by the pipeline>}``.

    Raises:
        HTTPException: 503 while the model is not loaded; 500 if inference
            fails (details are printed server-side, not sent to the client).
    """
    # Read the shared global once: the lifespan shutdown resets it to None,
    # and a second read after the readiness check could observe that.
    classifier = CLASSIFIER_PIPELINE
    if classifier is None:
        raise HTTPException(status_code=503, detail="Model is not ready.")
    # Log only the size, never the content: user messages may hold personal data.
    print(f"Received text for prediction ({len(item.text)} chars)")
    try:
        results = classifier(item.text, truncation=True, max_length=512)
        top = results[0]
        label = top["label"]
        score = top["score"]
    except Exception as e:
        # Keep the diagnostic server-side; internal error text is not
        # something to expose to API clients.
        traceback.print_exc()
        raise HTTPException(status_code=500, detail="Prediction failed.") from e
    # This checkpoint uses the generic LABEL_0/LABEL_1 names; index 1 is spam.
    output_label = "spam" if label == "LABEL_1" else "ham"
    return {
        "prediction": output_label,
        "confidence_score": score,
    }