# File size: 1,783 Bytes
import threading
from typing import Literal

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Hugging Face Hub repository ids of the two selectable classifiers.
ROBERTA_MODEL = "Unknownaut/entity-level-framing-news-roberta"
BERT_MODEL = "Unknownaut/entity-level-framing-news-bert"

# Class names looked up by predicted class id in /predict; the order is
# assumed to match the fine-tuned models' label ids (TODO confirm against
# the model configs).
labels = [
    "Legitimate",
    "Aggressor",
    "Defensive",
    "Neutral",
]

# Single-slot cache of the most recently loaded model, its tokenizer and
# the client-facing name it was loaded under; filled in by load_model.
_current_model = _current_tokenizer = _current_model_name = None

app = FastAPI()
class RequestData(BaseModel):
    """JSON body accepted by ``POST /predict``."""

    # Text passed to the tokenizer as the first sequence of the pair.
    sentence: str
    # Text passed to the tokenizer as the second sequence of the pair.
    entity: str
    # Which classifier to run. A Literal (instead of a bare ``str``) makes
    # pydantic reject any other value with a 422 before load_model is
    # reached; previously a typo ended as an unhandled ValueError (500).
    model: Literal["RoBERTa", "BERT"]
# Guards the single-slot model cache. FastAPI runs plain ``def`` endpoints
# in a worker threadpool, so concurrent /predict requests can enter
# load_model simultaneously. Without the lock, two requests for different
# models could interleave their writes to the three cache globals and leave
# a model paired with the other model's tokenizer, or load the same model
# twice.
_model_lock = threading.Lock()


def load_model(model_choice):
    """Return ``(model, tokenizer)`` for *model_choice*, loading it if needed.

    Only one model is cached at a time. Requesting a different choice than
    the cached one loads the new pair with ``from_pretrained`` and replaces
    the cached pair.

    Args:
        model_choice: ``"RoBERTa"`` or ``"BERT"`` (case-sensitive).

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been put in eval mode.

    Raises:
        ValueError: if *model_choice* is not a supported name.
    """
    global _current_model, _current_tokenizer, _current_model_name

    # Resolve the repository first so invalid input fails fast, before the
    # lock or the cache are touched.
    if model_choice == "RoBERTa":
        repo_id = ROBERTA_MODEL
    elif model_choice == "BERT":
        repo_id = BERT_MODEL
    else:
        raise ValueError("Invalid model")

    with _model_lock:
        # Reuse the cached pair when it is the one requested.
        if _current_model_name == model_choice:
            return _current_model, _current_tokenizer

        tokenizer = AutoTokenizer.from_pretrained(repo_id)
        model = AutoModelForSequenceClassification.from_pretrained(repo_id)
        model.eval()  # inference mode for layers such as dropout

        # Publish only after both parts loaded, so a failed load leaves the
        # previous cache intact; the lock keeps the three writes together.
        _current_model = model
        _current_tokenizer = tokenizer
        _current_model_name = model_choice
        return model, tokenizer
@app.get("/")
def health():
    """Liveness endpoint: report that the service is running.

    Always returns the same constant payload; it does not load any model.
    """
    status_payload = {"status": "ok"}
    return status_payload
@app.post("/predict")
def predict(data: RequestData):
    """Classify the framing of ``data.entity`` within ``data.sentence``.

    The sentence and entity are tokenized as a text pair (truncated to at
    most 160 tokens), run through the requested classifier without autograd
    tracking, and the arg-max class id is mapped to a name via ``labels``.

    Returns:
        ``{"label": <class name>}``.

    Raises:
        HTTPException: 422 when load_model rejects ``data.model``.
    """
    try:
        model, tokenizer = load_model(data.model)
    except ValueError as exc:
        # load_model raises ValueError for an unsupported model name; that
        # is a client error, not an opaque 500 Internal Server Error.
        raise HTTPException(status_code=422, detail=str(exc)) from exc

    inputs = tokenizer(
        data.sentence,
        data.entity,
        return_tensors="pt",
        truncation=True,
        max_length=160,
    )
    with torch.inference_mode():
        outputs = model(**inputs)
    # Single-example batch: argmax over the class dimension, then to int.
    pred = torch.argmax(outputs.logits, dim=1).item()
    return {"label": labels[pred]}