"""FastAPI app that serves a sentiment-classification model through an HTML form."""

import logging
from contextlib import asynccontextmanager
from typing import Tuple

import torch
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from transformers import AutoModelForSequenceClassification, AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face Hub identifier of the checkpoint loaded at startup.
MODEL_ID = "codeby-hp/FinetuneTinybert-SentimentClassification"
# Inputs longer than this many tokens are truncated by the tokenizer.
MAX_SEQUENCE_LENGTH = 512
# Class index -> label shown to the user; unknown indices render as "Unknown".
SENTIMENT_LABELS = {0: "Negative", 1: "Positive"}

# Populated by `lifespan` at application startup.
model = None
tokenizer = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the tokenizer and model on startup; log a message on shutdown.

    Any loading failure is logged with its traceback and re-raised so the
    application does not start without a model.
    """
    global model, tokenizer
    try:
        logger.info("Loading tokenizer from %s...", MODEL_ID)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        logger.info("Loading model from %s...", MODEL_ID)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
        model.to(device)
        model.eval()
        logger.info("Model loaded successfully on %s", device)
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Error loading model: %s", e)
        raise
    yield
    logger.info("Shutting down...")


app = FastAPI(title="Sentiment Analysis API", lifespan=lifespan)
templates = Jinja2Templates(directory="templates")


def _classify(text: str) -> Tuple[str, float]:
    """Return ``(label, confidence_percent)`` for a single piece of text.

    Raises:
        RuntimeError: If the model or tokenizer has not been loaded yet.
    """
    if model is None or tokenizer is None:
        # Fail with a clear message instead of "'NoneType' object is not callable".
        raise RuntimeError("Model is not loaded")

    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
        padding=True,
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        logits = model(**encoded).logits
        # One text was passed in, so row 0 is the only row of the result.
        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]

    predicted_class = int(torch.argmax(probabilities, dim=-1).item())
    confidence = probabilities[predicted_class].item()
    label = SENTIMENT_LABELS.get(predicted_class, "Unknown")
    return label, round(confidence * 100, 2)


@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render the home page."""
    return templates.TemplateResponse("index.html", {"request": request})


@app.post("/predict")
async def predict(request: Request, text: str = Form(...)):
    """Predict sentiment for the submitted text and render the result page.

    Blank input and inference failures are both reported through the
    template's ``error`` variable rather than an HTTP error status.
    """
    if not text.strip():
        return templates.TemplateResponse(
            "index.html",
            {"request": request, "error": "Please enter some text to analyze"},
        )

    try:
        # NOTE(review): inference runs synchronously inside this coroutine and
        # therefore occupies the event loop for its duration.
        sentiment, confidence = _classify(text)
    except Exception as e:
        # Keep details in the server log only; echoing raw exception text to
        # the client would disclose internals to untrusted users.
        logger.exception("Prediction error: %s", e)
        return templates.TemplateResponse(
            "index.html",
            {
                "request": request,
                "error": "An error occurred while analyzing the text. "
                "Please try again.",
            },
        )

    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            "text": text,
            "sentiment": sentiment,
            "confidence": confidence,
        },
    )


@app.get("/health")
async def health_check():
    """Report service status, whether the model is loaded, and the device."""
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "device": str(device),
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)