|
|
import os
|
|
|
import requests
|
|
|
import zipfile
|
|
|
import torch
|
|
|
import logging
|
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
|
from fastapi import FastAPI, HTTPException
|
|
|
from pydantic import BaseModel
|
|
|
|
|
|
# Root-logger setup: INFO-level logs for download/extraction progress.
logging.basicConfig(level=logging.INFO)


# Local directory the model archive is downloaded to and extracted into.
MODEL_DIR = "model"

# On-disk path of the downloaded zip archive (inside MODEL_DIR).
MODEL_ZIP_PATH = os.path.join(MODEL_DIR, "model.zip")

# Azure Blob Storage URL serving the zipped model artifacts.
MODEL_BLOB_URL = "https://brewtinkersa.blob.core.windows.net/models/models/model.zip"
|
|
|
|
|
|
def download_and_extract_model():
    """Ensure the model files exist locally, downloading from Azure Blob if missing.

    Idempotent: once ``config.json`` is present in ``MODEL_DIR`` the function
    returns immediately. Otherwise it downloads ``MODEL_BLOB_URL`` to
    ``MODEL_ZIP_PATH`` and extracts the archive into ``MODEL_DIR``.

    Raises:
        requests.HTTPError: if the blob download returns an error status.
        requests.Timeout / requests.ConnectionError: on network failure.
        zipfile.BadZipFile: if the downloaded file is not a valid archive.
    """
    # exist_ok=True makes a separate os.path.exists() pre-check redundant.
    os.makedirs(MODEL_DIR, exist_ok=True)

    # config.json marks a previously completed download + extraction.
    if os.path.exists(os.path.join(MODEL_DIR, "config.json")):
        return

    logging.info("Downloading model from Azure Blob...")
    # stream=True avoids buffering the whole (potentially large) archive in
    # memory; timeout prevents the app from hanging forever at startup;
    # raise_for_status() keeps an HTTP error page from being saved as the zip.
    response = requests.get(MODEL_BLOB_URL, stream=True, timeout=300)
    response.raise_for_status()
    with open(MODEL_ZIP_PATH, "wb") as f:
        for chunk in response.iter_content(chunk_size=1 << 20):  # 1 MiB chunks
            f.write(chunk)

    with zipfile.ZipFile(MODEL_ZIP_PATH, 'r') as zip_ref:
        zip_ref.extractall(MODEL_DIR)
    logging.info("Model extracted.")
|
|
|
|
|
|
# Fetch/extract the model at import time so the service fails fast on startup
# if the artifacts are unavailable.
download_and_extract_model()

# Load tokenizer and classification model from the extracted local directory.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)

# Inference-only service: switch off dropout / training-mode layers.
model.eval()


app = FastAPI()
|
|
|
class RequestData(BaseModel):
    """Request body for the /predict endpoint."""

    # Raw input text to run sentiment classification on.
    text: str
|
|
|
@app.post("/predict")
def predict(request: RequestData):
    """Classify the sentiment of ``request.text``.

    Returns:
        dict: ``{"prediction": <class index>, "label": <"negative"|"neutral"|"positive">}``.

    Raises:
        HTTPException: 500 on any tokenization/inference failure.
    """
    try:
        inputs = tokenizer(request.text, return_tensors="pt", truncation=True, padding=True)
        # Inference only: no_grad() skips building the autograd graph, saving
        # memory and time on every request.
        with torch.no_grad():
            outputs = model(**inputs)
        prediction = torch.argmax(outputs.logits, dim=1).item()
        # Class-index -> human-readable label mapping for this 3-way model.
        labels = {0: "negative", 1: "neutral", 2: "positive"}
        return {"prediction": prediction, "label": labels.get(prediction, "unknown")}
    except Exception as e:
        # logging.exception records the full traceback, not just the message.
        logging.exception("Prediction failed: %s", e)
        # Chain the cause so debuggers see the original error.
        raise HTTPException(status_code=500, detail="Internal Server Error") from e
|
|
|
|