"""Sentiment-analysis demo service.

At startup the fine-tuned checkpoint is downloaded from the Hugging Face Hub
(repo taken from the ``HF_REPO`` env var) and wrapped in a ``transformers``
sentiment pipeline. Endpoints:

* ``GET /``        - HTML form.
* ``POST /predict`` - classify the submitted ``text`` form field, re-render page.
* ``GET /health``  - JSON liveness / model-loaded probe.
"""

import html
import os
import threading
from contextlib import asynccontextmanager

from fastapi import FastAPI, Form, HTTPException
from fastapi.responses import HTMLResponse
from huggingface_hub import hf_hub_download
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

# Process-wide registry: "clf" -> loaded pipeline. Filled by ``lifespan``.
MODEL = {}

HF_REPO = os.getenv("HF_REPO", "amarshiv86/sentiment-analysis-imdb-model")

# Files fetched from HF_REPO; they all live under the repo's ``model/`` folder.
MODEL_FILES = (
    "model/config.json",
    "model/model.safetensors",
    "model/tokenizer.json",
    "model/tokenizer_config.json",
)
LOCAL_DIR = "/tmp/sentiment-model"

# ``predict`` runs in FastAPI's threadpool; one shared pipeline/tokenizer must
# not be driven by several threads at once, so inference is serialized.
_INFER_LOCK = threading.Lock()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Download the model files and build the pipeline before serving.

    The pipeline is stored in ``MODEL["clf"]`` and removed on shutdown, even
    if shutdown is triggered by an exception.
    """
    print(f"Downloading model from HF Hub: {HF_REPO} …")
    os.makedirs(LOCAL_DIR, exist_ok=True)
    for filename in MODEL_FILES:
        hf_hub_download(
            repo_id=HF_REPO,
            filename=filename,
            repo_type="model",
            local_dir=LOCAL_DIR,
        )

    model_dir = f"{LOCAL_DIR}/model"
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    # DistilBERT does not use token_type_ids — strip them to avoid TypeError
    tokenizer.model_input_names = ["input_ids", "attention_mask"]
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)

    MODEL["clf"] = pipeline(
        "sentiment-analysis",
        model=model,
        tokenizer=tokenizer,
        truncation=True,
        max_length=256,
    )
    print("Model loaded ✓")
    try:
        yield
    finally:
        MODEL.clear()


app = FastAPI(title="Sentiment Analysis API", lifespan=lifespan)

# Page template. Placeholders are substituted with ``str.replace`` (not
# ``str.format``) so the CSS braces need no escaping.
HTML = """<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>Sentiment Analyzer · AI Demo</title>
<style>
  body { font-family: system-ui, sans-serif; max-width: 640px; margin: 2rem auto; }
  .pos { color: #1a7f37; }
  .neg { color: #cf222e; }
  .bar { background: #eee; height: 8px; border-radius: 4px; }
  .bar > div { height: 100%; border-radius: 4px; background: currentColor; }
</style>
</head>
<body>
<h1>Sentiment Analyzer</h1>
<p>distilBERT · fine-tuned on IMDB · MLOps demo</p>
<form method="post" action="/predict">
  <input type="text" name="text" value="{text_value}" required>
  <button type="submit">Analyze</button>
</form>
{result_html}
</body>
</html>
"""


def _render_page(result_html: str = "", text_value: str = "") -> str:
    """Fill the page template.

    ``result_html`` is inserted verbatim (caller must build it safely);
    ``text_value`` must already be HTML-escaped. ``{result_html}`` is
    substituted first so user text can never inject a template placeholder.
    """
    return HTML.replace("{result_html}", result_html).replace("{text_value}", text_value)


@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve the empty form."""
    return _render_page()


@app.post("/predict", response_class=HTMLResponse)
def predict(text: str = Form(...)):
    """Classify ``text`` and re-render the page with the result.

    Declared as a plain ``def`` so FastAPI runs the CPU-bound inference in its
    threadpool instead of blocking the event loop.

    Raises:
        HTTPException: 503 if the model has not been loaded (yet).
    """
    clf = MODEL.get("clf")
    if clf is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    with _INFER_LOCK:
        result = clf(text)[0]

    label = result["label"]
    score = result["score"]
    pct = round(score * 100, 1)
    bar_width = round(score * 100)
    # NOTE(review): assumes the checkpoint's id2label uses "POSITIVE"/"NEGATIVE";
    # any other label string is rendered as negative.
    is_positive = label == "POSITIVE"
    css_class = "pos" if is_positive else "neg"
    icon = "😊" if is_positive else "😞"
    safe_label = html.escape(label, quote=True)

    result_html = f"""
<div class="result {css_class}">
  <div class="icon">{icon}</div>
  <div class="label">{safe_label}</div>
  <div class="meta">Model confidence · {pct}%</div>
  <div class="bar"><div style="width: {bar_width}%"></div></div>
  <div class="chips"><span>{safe_label}</span> <span>confidence {pct}%</span></div>
</div>
"""
    # html.escape also covers "&" and "'", which the hand-rolled replace chain missed.
    safe_text = html.escape(text, quote=True)
    return _render_page(result_html, safe_text)


@app.get("/health")
async def health():
    """Report liveness and whether the classifier is loaded."""
    return {"status": "ok", "model_loaded": "clf" in MODEL}