"""FastAPI sentiment-inference service.

Downloads a zipped HuggingFace sequence-classification model from Azure Blob
storage on startup, extracts it, and serves predictions at POST /predict.
"""

import logging
import os
import zipfile

import requests
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Setup logging: module-level named logger instead of the root logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model location
MODEL_DIR = "model"
MODEL_ZIP_PATH = os.path.join(MODEL_DIR, "model.zip")
MODEL_BLOB_URL = "https://brewtinkersa.blob.core.windows.net/models/models/model.zip"

# Class-index -> label mapping returned by /predict.
LABELS = {0: "negative", 1: "neutral", 2: "positive"}


def download_and_extract_model():
    """Download the model zip from Azure Blob and extract it into MODEL_DIR.

    Skipped entirely when MODEL_DIR/config.json already exists (model cached
    from a previous run).

    Raises:
        requests.HTTPError: if the blob download returns an error status.
        requests.Timeout: if the download stalls past the timeout.
        zipfile.BadZipFile: if the downloaded archive is corrupt.
    """
    os.makedirs(MODEL_DIR, exist_ok=True)
    if os.path.exists(os.path.join(MODEL_DIR, "config.json")):
        return

    logger.info("Downloading model from Azure Blob...")
    # Stream to disk in chunks (model archives can be large) and fail loudly
    # on HTTP errors instead of silently zipping up an error page.
    response = requests.get(MODEL_BLOB_URL, stream=True, timeout=(10, 300))
    response.raise_for_status()
    with open(MODEL_ZIP_PATH, "wb") as f:
        for chunk in response.iter_content(chunk_size=1 << 20):
            f.write(chunk)

    with zipfile.ZipFile(MODEL_ZIP_PATH, 'r') as zip_ref:
        zip_ref.extractall(MODEL_DIR)
    logger.info("Model extracted.")


# Prepare model at import time (startup), matching the original behavior.
download_and_extract_model()
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
model.eval()

# FastAPI setup
app = FastAPI()


class RequestData(BaseModel):
    # Raw text to classify.
    text: str


@app.post("/predict")
def predict(request: RequestData):
    """Classify request.text and return the argmax class index plus label.

    Returns:
        dict with keys "prediction" (int class index) and "label"
        (one of "negative"/"neutral"/"positive", or "unknown").

    Raises:
        HTTPException: 500 on any inference failure.
    """
    try:
        inputs = tokenizer(
            request.text, return_tensors="pt", truncation=True, padding=True
        )
        # Inference only — no_grad() skips autograd bookkeeping, saving
        # memory and time per request.
        with torch.no_grad():
            outputs = model(**inputs)
        prediction = torch.argmax(outputs.logits, dim=1).item()
        return {"prediction": prediction, "label": LABELS.get(prediction, "unknown")}
    except Exception:
        # Log full traceback server-side; return an opaque 500 to the client.
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail="Internal Server Error")