# fastapi_app/app.py, author: codeby-hp; last commit a955d15 ("Update fastapi_app/app.py").
import logging
from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI, Request, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from transformers import AutoModelForSequenceClassification, AutoTokenizer
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Inference device: the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Both are filled in by ``lifespan`` at application startup; ``None`` until then.
model = tokenizer = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model on startup and release it on shutdown.

    Fills the module-level ``model`` and ``tokenizer`` globals via
    ``from_pretrained(model_id)``, then moves the model to ``device`` and puts
    it in evaluation mode. All of this happens before the application starts
    serving requests.

    Args:
        app: The application being started. FastAPI's lifespan protocol
            requires this argument, but it is not used here.

    Raises:
        Exception: Any error raised while loading is logged with its
            traceback and re-raised, so startup fails instead of serving an
            app without a model.
    """
    global model, tokenizer
    model_id = "codeby-hp/FinetuneTinybert-SentimentClassification"
    try:
        logger.info("Loading tokenizer from %s...", model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        logger.info("Loading model from %s...", model_id)
        model = AutoModelForSequenceClassification.from_pretrained(model_id)
        model.to(device)
        # Evaluation mode for inference (e.g. dropout disabled).
        model.eval()
    except Exception as e:
        # Unlike logger.error, logger.exception also records the traceback.
        logger.exception("Error loading model: %s", e)
        raise
    logger.info("Model loaded successfully on %s", device)

    try:
        yield
    finally:
        # The finally clause makes shutdown run even when the application
        # exits with an exception thrown in at the ``yield``.
        logger.info("Shutting down...")
        # Drop the references so the weights can be garbage-collected.
        model = None
        tokenizer = None
# ``lifespan`` loads the model before the first request is handled.
app = FastAPI(lifespan=lifespan, title="Sentiment Analysis API")

# NOTE(review): this relative path resolves against the process's working
# directory, not this file's directory.
templates = Jinja2Templates(directory="templates")
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render ``index.html`` with no prediction result and no error message."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
def _infer_sentiment(inputs):
    """Run the blocking forward pass and return ``(sentiment, confidence)``.

    Args:
        inputs: Tokenizer output with every tensor already on ``device``.

    Returns:
        ``sentiment`` is ``"Negative"``, ``"Positive"`` or ``"Unknown"``.
        ``confidence`` is the softmax probability of the predicted class for
        the first (only) item in the batch, in ``[0, 1]``.

    Meant to run in a worker thread. ``torch.no_grad`` is entered here rather
    than by the caller because grad mode is thread-local.
    """
    with torch.no_grad():
        logits = model(**inputs).logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    predicted_class = torch.argmax(probabilities, dim=-1).item()
    confidence = probabilities[0][predicted_class].item()
    sentiment_map = {0: "Negative", 1: "Positive"}
    return sentiment_map.get(predicted_class, "Unknown"), confidence


@app.post("/predict")
async def predict(request: Request, text: str = Form(...)):
    """Classify the sentiment of the submitted text and render the result.

    Args:
        request: The incoming request, passed through to the template.
        text: Raw text submitted by the HTML form.

    Returns:
        ``index.html`` rendered with one of two contexts:
        - on success: ``text``, ``sentiment`` and ``confidence``
          (a percentage rounded to two decimals);
        - on blank input, an unloaded model, or any inference failure:
          ``error``.
    """
    if not text.strip():
        return templates.TemplateResponse(
            "index.html",
            {"request": request, "error": "Please enter some text to analyze"},
        )
    if model is None or tokenizer is None:
        # Without this guard the failure surfaces as an opaque
        # "'NoneType' object is not callable".
        logger.error("Prediction requested but the model is not loaded")
        return templates.TemplateResponse(
            "index.html",
            {
                "request": request,
                "error": "The model is not loaded yet. Please try again later.",
            },
        )
    try:
        # Tokenization stays on the event loop rather than in worker threads:
        # HF fast tokenizers have been reported to raise "Already borrowed"
        # when called concurrently from several threads.
        inputs = tokenizer(
            text, return_tensors="pt", truncation=True, max_length=512, padding=True
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # The forward pass is the expensive synchronous part. Running it in
        # the threadpool keeps the event loop free to serve other requests.
        sentiment, confidence = await run_in_threadpool(_infer_sentiment, inputs)
    except Exception as e:
        # logger.exception keeps the traceback for debugging.
        logger.exception("Prediction error: %s", e)
        return templates.TemplateResponse(
            "index.html", {"request": request, "error": f"An error occurred: {str(e)}"}
        )
    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            "text": text,
            "sentiment": sentiment,
            "confidence": round(confidence * 100, 2),
        },
    )
@app.get("/health")
async def health_check():
    """Return a liveness payload with the model-loaded flag and device name."""
    loaded = model is not None
    return dict(status="healthy", model_loaded=loaded, device=str(device))
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this module is run directly.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")