|
|
import logging
from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI, Request, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
|
|
|
|
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Both are filled in by ``lifespan`` when the application starts; they stay
# ``None`` until then.
model = None
tokenizer = None

# Run inference on CUDA whenever torch can see a GPU, otherwise on the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _use_cuda else "cpu")
|
|
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the tokenizer and model on startup and release them on shutdown.

    Populates the module-level ``model`` and ``tokenizer`` globals (model moved
    to ``device`` and switched to eval mode) before the application starts
    serving requests. After the application stops, both globals are reset to
    ``None`` and, on CUDA, the allocator cache is emptied.

    Args:
        app: The FastAPI application (required by the lifespan protocol; unused).

    Raises:
        Exception: Any error raised while loading the tokenizer or model is
            logged with its traceback and re-raised so that startup aborts.
    """
    global model, tokenizer

    model_id = "codeby-hp/FinetuneTinybert-SentimentClassification"

    try:
        logger.info(f"Loading tokenizer from {model_id}...")
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        logger.info(f"Loading model from {model_id}...")
        model = AutoModelForSequenceClassification.from_pretrained(model_id)
        model.to(device)
        model.eval()

        logger.info(f"Model loaded successfully on {device}")
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception(f"Error loading model: {e}")
        raise

    try:
        yield
    finally:
        # Runs even when the application exits with an error, so the model is
        # always released rather than only on a clean shutdown.
        logger.info("Shutting down...")
        model = None
        tokenizer = None
        if device.type == "cuda":
            # Return cached blocks held by the now-unreferenced model to the driver.
            torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
# Relative path: resolved against the process's current working directory.
templates = Jinja2Templates(directory="templates")

# ``lifespan`` loads the model before the first request is served.
app = FastAPI(lifespan=lifespan, title="Sentiment Analysis API")
|
|
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render ``index.html`` with no prediction or error in the context."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
|
|
|
|
|
|
|
|
# Maps the classifier's output index to the label shown to the user.
_SENTIMENT_LABELS = {0: "Negative", 1: "Positive"}


def _classify(inputs):
    """Run the model on tokenized ``inputs``; return ``(predicted_class, confidence)``.

    Executed in a worker thread by ``predict``. PyTorch's grad mode is
    thread-local, so ``torch.no_grad()`` must be entered here, inside the
    worker, rather than around the call site on the event-loop thread.
    ``confidence`` is the softmax probability of the predicted class for the
    first (and only) item in the batch.
    """
    with torch.no_grad():
        logits = model(**inputs).logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        predicted_class = torch.argmax(probabilities, dim=-1).item()
        confidence = probabilities[0][predicted_class].item()
    return predicted_class, confidence


@app.post("/predict")
async def predict(request: Request, text: str = Form(...)):
    """Classify the sentiment of the submitted text and render the result.

    Args:
        request: Incoming request, passed through to the template context.
        text: Form field holding the text to classify.

    Returns:
        ``index.html`` rendered with ``text``, ``sentiment`` and ``confidence``
        (a percentage rounded to two decimals), or with an ``error`` message
        when the input is blank or inference fails.
    """
    if not text.strip():
        return templates.TemplateResponse(
            "index.html",
            {"request": request, "error": "Please enter some text to analyze"},
        )

    try:
        # Tokenization is cheap and stays on the event-loop thread, so the
        # shared fast tokenizer is never called from several worker threads at
        # once (Hugging Face fast tokenizers are known to fail with
        # "Already borrowed" under concurrent use).
        inputs = tokenizer(
            text, return_tensors="pt", truncation=True, max_length=512, padding=True
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # The forward pass is CPU/GPU-bound. Running it inline in this
        # coroutine would block the event loop, stalling every other request
        # (including /health) until it finished.
        predicted_class, confidence = await run_in_threadpool(_classify, inputs)
    except Exception:
        # Keep the full traceback in the server log, but do not echo internal
        # exception details back to the client.
        logger.exception("Prediction error")
        return templates.TemplateResponse(
            "index.html",
            {
                "request": request,
                "error": "An error occurred while analyzing the text. Please try again.",
            },
        )

    sentiment = _SENTIMENT_LABELS.get(predicted_class, "Unknown")
    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            "text": text,
            "sentiment": sentiment,
            "confidence": round(confidence * 100, 2),
        },
    )
|
|
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Report liveness, whether the model global is set, and the torch device."""
    loaded = model is not None
    return dict(status="healthy", model_loaded=loaded, device=str(device))
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # uvicorn is imported only here, so importing this module (e.g. from an
    # external ASGI server) does not import it.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")
|
|
|