# distilbert-sa / main.py
# Uploaded by noootmj (commit 53a0393, "Upload 9 files")
import os
import requests
import zipfile
import torch
import logging
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# Setup logging
logging.basicConfig(level=logging.INFO)
# Model location
MODEL_DIR = "model"  # local directory the model archive is extracted into
MODEL_ZIP_PATH = os.path.join(MODEL_DIR, "model.zip")  # where the downloaded zip is written
MODEL_BLOB_URL = "https://brewtinkersa.blob.core.windows.net/models/models/model.zip"  # Azure Blob source of the zipped model
# Download and unzip the model at startup
def download_and_extract_model():
    """Fetch the zipped model from Azure Blob storage and extract it into MODEL_DIR.

    Skips the download when MODEL_DIR already contains a config.json,
    i.e. the model was extracted on a previous start.

    Raises:
        requests.HTTPError: if the blob download returns a non-2xx status.
    """
    os.makedirs(MODEL_DIR, exist_ok=True)  # exist_ok makes a separate existence check redundant
    if not os.path.exists(os.path.join(MODEL_DIR, "config.json")):
        logging.info("Downloading model from Azure Blob...")
        # Stream to disk in chunks so a large model zip never sits fully in
        # memory; the timeout prevents startup from hanging forever on a
        # stalled connection.
        with requests.get(MODEL_BLOB_URL, stream=True, timeout=300) as response:
            response.raise_for_status()  # fail loudly instead of unzipping an HTTP error page
            with open(MODEL_ZIP_PATH, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        with zipfile.ZipFile(MODEL_ZIP_PATH, 'r') as zip_ref:
            zip_ref.extractall(MODEL_DIR)
        logging.info("Model extracted.")
# Prepare model
# Module-level startup: make sure the model files exist locally, then load
# the tokenizer and classifier once so every request reuses the same objects.
download_and_extract_model()
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
model.eval()  # switch to inference behavior (e.g. disables dropout)
# FastAPI setup
app = FastAPI()
class RequestData(BaseModel):
    """Request body for the /predict endpoint."""
    # text: raw input string to classify
    text: str
@app.post("/predict")
def predict(request: RequestData):
    """Classify the sentiment of ``request.text``.

    Returns a dict with the raw class index (``prediction``) and its
    human-readable label (0=negative, 1=neutral, 2=positive).

    Raises:
        HTTPException: 500 on any failure during tokenization or inference.
    """
    try:
        inputs = tokenizer(request.text, return_tensors="pt", truncation=True, padding=True)
        # no_grad: skip autograd bookkeeping — faster, and avoids building a
        # fresh computation graph (and its memory) on every request.
        with torch.no_grad():
            outputs = model(**inputs)
        prediction = torch.argmax(outputs.logits, dim=1).item()
        labels = {0: "negative", 1: "neutral", 2: "positive"}
        return {"prediction": prediction, "label": labels.get(prediction, "unknown")}
    except Exception as e:
        # logging.exception records the full traceback, not just the message.
        logging.exception(f"Prediction failed: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error")