"""FastAPI service for entity-level framing classification of news sentences.

A request supplies a sentence, a target entity inside it, and which
fine-tuned checkpoint ("RoBERTa" or "BERT") to use; the response is the
predicted framing label.
"""

import threading
from typing import Literal

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

app = FastAPI()

ROBERTA_MODEL = "Unknownaut/entity-level-framing-news-roberta"
BERT_MODEL = "Unknownaut/entity-level-framing-news-bert"

# Client-facing model choice -> Hugging Face Hub repo id.
MODEL_REPOS = {
    "RoBERTa": ROBERTA_MODEL,
    "BERT": BERT_MODEL,
}

# Label names indexed by predicted class id.
# NOTE(review): assumed to match the class-index order both checkpoints were
# trained with — confirm against each model's config (id2label).
labels = ["Legitimate", "Aggressor", "Defensive", "Neutral"]

# Single-slot cache: only the most recently requested model is kept in memory.
_current_model = None
_current_tokenizer = None
_current_model_name = None
# FastAPI executes plain `def` endpoints in a thread pool, so several requests
# may call load_model concurrently. The lock keeps the three cache globals
# consistent with each other and prevents loading the same checkpoint twice.
_model_lock = threading.Lock()


class RequestData(BaseModel):
    """JSON body of a ``/predict`` request."""

    sentence: str
    entity: str
    # Restricting to the known choices makes FastAPI reject unknown values
    # with a 422 validation error instead of an unhandled 500.
    model: Literal["RoBERTa", "BERT"]


def load_model(model_choice):
    """Return ``(model, tokenizer)`` for *model_choice*, loading it on demand.

    If the cached model matches *model_choice* it is reused; otherwise the
    requested checkpoint is loaded with ``from_pretrained``, switched to eval
    mode, and replaces the cached one.

    Args:
        model_choice: ``"RoBERTa"`` or ``"BERT"`` (a key of ``MODEL_REPOS``).

    Returns:
        Tuple of (sequence-classification model, tokenizer).

    Raises:
        ValueError: if *model_choice* is not a known model name.
    """
    global _current_model, _current_tokenizer, _current_model_name

    repo_id = MODEL_REPOS.get(model_choice)
    if repo_id is None:
        raise ValueError("Invalid model")

    with _model_lock:
        # Reuse if already loaded.
        if _current_model_name == model_choice:
            return _current_model, _current_tokenizer

        tokenizer = AutoTokenizer.from_pretrained(repo_id)
        model = AutoModelForSequenceClassification.from_pretrained(repo_id)
        model.eval()  # inference only: disable dropout etc.

        # Publish the new pair only after both loaded successfully, so a
        # failed load leaves the previous cache intact.
        _current_model = model
        _current_tokenizer = tokenizer
        _current_model_name = model_choice
        return model, tokenizer


@app.get("/")
def health():
    """Return a static status payload, usable as a health check."""
    return {"status": "ok"}


@app.post("/predict")
def predict(data: RequestData):
    """Predict the framing label of ``data.entity`` within ``data.sentence``.

    Returns:
        ``{"label": <one of labels>}``.
    """
    model, tokenizer = load_model(data.model)

    # Sentence and entity are encoded together as a text pair, truncated to
    # at most 160 tokens in total.
    inputs = tokenizer(
        data.sentence,
        data.entity,
        return_tensors="pt",
        truncation=True,
        max_length=160,
    )

    with torch.inference_mode():
        outputs = model(**inputs)

    # Batch of one: take the highest-scoring class for the single input.
    pred = torch.argmax(outputs.logits, dim=1).item()
    return {"label": labels[pred]}