# Commit a62ec9c ("Update app.py") by Unknownaut (verified).
import threading
from typing import Literal

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# The ASGI application; endpoints below register on it via decorators.
app = FastAPI()

# Model identifiers handed to ``from_pretrained`` for each selectable model.
ROBERTA_MODEL = "Unknownaut/entity-level-framing-news-roberta"
BERT_MODEL = "Unknownaut/entity-level-framing-news-bert"

# Human-readable class names, indexed by the classifier's argmax class id.
# NOTE(review): presumably this order must match the label ids used when the
# models were trained - confirm against the model configs.
labels = ["Legitimate", "Aggressor", "Defensive", "Neutral"]

# Single-slot model cache: the loaded model, its tokenizer, and the choice
# string they were loaded for. Everything starts out empty.
_current_model = _current_tokenizer = _current_model_name = None
class RequestData(BaseModel):
    """JSON request body accepted by ``POST /predict``."""

    # Text in which the entity appears.
    sentence: str
    # The entity whose framing is classified; /predict passes it to the
    # tokenizer as the second text of a (sentence, entity) pair.
    entity: str
    # Which classifier to run. Restricting the field to the supported names
    # lets FastAPI reject anything else with a 422 validation error, instead of
    # the value reaching load_model and surfacing as an unhandled 500.
    model: Literal["RoBERTa", "BERT"]
# Guards the single-slot model cache (_current_model / _current_tokenizer /
# _current_model_name). FastAPI executes plain ``def`` endpoints in a worker
# thread pool, so several requests can call load_model concurrently.
_model_lock = threading.Lock()


def load_model(model_choice):
    """Return ``(model, tokenizer)`` for ``model_choice``, loading on demand.

    Only one model is cached at a time. Asking for the model that is already
    cached returns it directly. Asking for the other one loads it and replaces
    the cache. If loading fails, the previous cache entry is left untouched.

    Args:
        model_choice: ``"RoBERTa"`` or ``"BERT"``.

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been put in eval mode.

    Raises:
        ValueError: if ``model_choice`` is not a supported model name.
    """
    global _current_model, _current_tokenizer, _current_model_name

    repo_ids = {"RoBERTa": ROBERTA_MODEL, "BERT": BERT_MODEL}
    # Validate before touching the cache. Otherwise None would match the
    # cache's initial "nothing loaded" name and return (None, None).
    # TypeError covers unhashable inputs, which the original also rejected.
    try:
        repo_id = repo_ids[model_choice]
    except (KeyError, TypeError):
        raise ValueError("Invalid model") from None

    # Check, load and read the cache under one lock. This ensures a caller
    # never gets one model paired with the other model's tokenizer, and that
    # two requests never load the same model at once.
    with _model_lock:
        if _current_model_name != model_choice:
            tokenizer = AutoTokenizer.from_pretrained(repo_id)
            model = AutoModelForSequenceClassification.from_pretrained(repo_id)
            model.eval()  # inference behaviour (e.g. dropout disabled)
            _current_model = model
            _current_tokenizer = tokenizer
            _current_model_name = model_choice
        return _current_model, _current_tokenizer
@app.get("/")
def health():
    """Answer ``GET /`` with a fixed status payload (basic health check)."""
    return dict(status="ok")
@app.post("/predict")
def predict(data: RequestData):
    """Classify the framing of ``data.entity`` within ``data.sentence``.

    Args:
        data: Request body carrying the sentence, the entity and the name of
            the model to use (``"RoBERTa"`` or ``"BERT"``).

    Returns:
        ``{"label": name}``, where ``name`` is the entry of ``labels`` at the
        highest-scoring class id.

    Raises:
        HTTPException: 400 if ``data.model`` is not a supported model name.
    """
    # Reject unsupported model names as a client error. Without this check,
    # load_model's ValueError propagates and FastAPI answers 500. The check is
    # explicit rather than catching ValueError, so genuine model-loading
    # failures are not misreported as bad requests.
    if data.model not in ("RoBERTa", "BERT"):
        raise HTTPException(status_code=400, detail="Invalid model")

    model, tokenizer = load_model(data.model)

    # Encode (sentence, entity) as a text pair; truncation trims the pair to at
    # most 160 tokens.
    inputs = tokenizer(
        data.sentence,
        data.entity,
        return_tensors="pt",
        truncation=True,
        max_length=160,
    )
    with torch.inference_mode():
        logits = model(**inputs).logits
    # Batch of one pair: take the argmax class id of that single row.
    pred = torch.argmax(logits, dim=1).item()
    return {"label": labels[pred]}