Spaces:
Sleeping
Sleeping
ibraheem15
Refactor Dockerfile for multi-stage builds and enhance main.py with CORS middleware and updated model
a1099ee
| # main.py | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline | |
| from contextlib import asynccontextmanager | |
| import torch | |
# --- MLOps Configuration ---
# Hugging Face checkpoint used for SMS spam classification.
HF_MODEL_NAME = "mrm8488/bert-tiny-finetuned-sms-spam-detection"
# Populated by the lifespan handler at startup; stays None until the model loads.
CLASSIFIER_PIPELINE = None
# Pydantic model describing the request body.
class Message(BaseModel):
    """Request payload: the raw text to be classified."""
    # The message text to run through the spam classifier.
    text: str
# --- LIFESPAN MANAGER (The Modern Fix) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the HF pipeline at startup and release it on shutdown.

    FastAPI's ``lifespan=`` argument must be an async context manager, so
    the ``@asynccontextmanager`` decorator (imported above but previously
    unused) is required — without it the app fails to start.
    """
    # STARTUP: Load the model.
    global CLASSIFIER_PIPELINE
    print(f"Loading model {HF_MODEL_NAME}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(HF_MODEL_NAME)
        CLASSIFIER_PIPELINE = pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            # GPU 0 when available, otherwise CPU (-1).
            device=0 if torch.cuda.is_available() else -1,
        )
        print("Model loaded successfully.")
    except Exception as e:
        # Best-effort startup: the app still serves, and /predict returns 503
        # while CLASSIFIER_PIPELINE remains None.
        print(f"Failed to load model: {e}")
    yield  # This point is where the app runs
    # SHUTDOWN: drop the pipeline reference so it can be garbage-collected.
    print("Shutting down model resources...")
    CLASSIFIER_PIPELINE = None
# Initialize the application, wiring the model lifecycle into FastAPI.
app = FastAPI(
    title="Spam Detection API",
    lifespan=lifespan,
)

# Open CORS policy: accept requests from any origin, method, and header.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# very permissive — confirm this is intended for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# NOTE(review): this handler was never registered with the app (the route
# decorator appears to have been lost) — the path "/health" is assumed; confirm.
@app.get("/health")
def health_check():
    """Liveness probe: reports service status and whether the model is loaded."""
    return {"status": "ok", "model_loaded": CLASSIFIER_PIPELINE is not None}
# NOTE(review): this handler was never registered with the app (the route
# decorator appears to have been lost) — the path "/predict" is assumed; confirm.
@app.post("/predict")
def predict_spam(item: Message):
    """Classify ``item.text`` as spam or ham.

    Returns a dict with the mapped label and the model's confidence score.
    Raises HTTP 503 while the model is not yet loaded, and HTTP 500
    (with the underlying error message) if inference fails.
    """
    if CLASSIFIER_PIPELINE is None:
        raise HTTPException(status_code=503, detail="Model is not ready.")
    try:
        print(f"Received text for prediction: {item.text}")
        # Truncate long inputs to the model's maximum sequence length.
        results = CLASSIFIER_PIPELINE(item.text, truncation=True, max_length=512)
        label = results[0]['label']
        score = results[0]['score']
        # This checkpoint emits LABEL_1 for spam and LABEL_0 for ham.
        output_label = "spam" if label == 'LABEL_1' else "ham"
        return {
            "prediction": output_label,
            "confidence_score": score,
        }
    except Exception as e:
        # Chain the original exception for easier debugging in server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e