import threading

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
|
app = FastAPI()

# Model identifiers passed to ``from_pretrained`` by load_model(); the request
# field ``model`` selects between them ("RoBERTa" or "BERT").
ROBERTA_MODEL = "Unknownaut/entity-level-framing-news-roberta"
BERT_MODEL = "Unknownaut/entity-level-framing-news-bert"

# Label names indexed by the class id that predict() obtains via argmax.
# NOTE(review): assumes both checkpoints share this id order — confirm against
# each model's ``config.id2label``.
labels = ["Legitimate", "Aggressor", "Defensive", "Neutral"]

# Single-slot cache written by load_model(): the most recently loaded model,
# its tokenizer, and the choice name they were loaded for.
_current_model = _current_tokenizer = _current_model_name = None
|
|
|
|
class RequestData(BaseModel):
    """JSON body accepted by ``POST /predict``."""

    # Text that contains the entity.
    sentence: str
    # Entity whose framing within ``sentence`` is classified.
    entity: str
    # Classifier to use; load_model() accepts "RoBERTa" or "BERT".
    model: str
|
|
|
|
# Supported model choices mapped to the identifier passed to ``from_pretrained``.
_MODEL_REPOS = {"RoBERTa": ROBERTA_MODEL, "BERT": BERT_MODEL}

# Guards the check-and-swap of the three cache globals below.
_model_lock = threading.Lock()


def load_model(model_choice):
    """Return ``(model, tokenizer)`` for ``model_choice``, loading it on demand.

    Only the most recently requested choice is cached. Asking for a different
    choice loads it from ``from_pretrained`` and replaces the cached pair.

    Args:
        model_choice: ``"RoBERTa"`` or ``"BERT"``.

    Returns:
        A ``(model, tokenizer)`` tuple. The model has been put in eval mode.

    Raises:
        ValueError: If ``model_choice`` is not a supported choice.
    """
    global _current_model, _current_tokenizer, _current_model_name

    # Validate first: previously load_model(None) matched the initial
    # ``_current_model_name = None`` and returned (None, None).
    if model_choice not in _MODEL_REPOS:
        raise ValueError("Invalid model")

    # FastAPI runs sync endpoints in a threadpool. Without the lock, one
    # request could pass the name check while another thread is midway
    # through swapping the globals and get back the wrong model. Two threads
    # could also load the same weights concurrently.
    with _model_lock:
        if _current_model_name == model_choice:
            return _current_model, _current_tokenizer

        repo = _MODEL_REPOS[model_choice]
        tokenizer = AutoTokenizer.from_pretrained(repo)
        model = AutoModelForSequenceClassification.from_pretrained(repo)
        model.eval()

        _current_model = model
        _current_tokenizer = tokenizer
        _current_model_name = model_choice
        return model, tokenizer
|
|
|
|
@app.get("/")
def health():
    """Return a fixed OK payload on ``GET /``.

    No model is loaded or checked here.
    """
    payload = dict(status="ok")
    return payload
|
|
|
|
@app.post("/predict")
def predict(data: RequestData):
    """Classify the framing of ``data.entity`` within ``data.sentence``.

    The sentence and entity are tokenized as a text pair and scored by the
    model named in ``data.model``. The highest-scoring class is returned.

    Returns:
        ``{"label": name}``, where ``name`` is taken from ``labels``.

    Raises:
        HTTPException: 400 if ``data.model`` is not a supported model name.
    """
    try:
        model, tokenizer = load_model(data.model)
    except ValueError as exc:
        # An unknown model name is a client error. Previously the ValueError
        # escaped the endpoint and surfaced as an HTTP 500.
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    # Encode (sentence, entity) as a sequence pair, at most 160 tokens total.
    inputs = tokenizer(
        data.sentence,
        data.entity,
        return_tensors="pt",
        truncation=True,
        max_length=160,
    )

    with torch.inference_mode():
        outputs = model(**inputs)
        pred = torch.argmax(outputs.logits, dim=1).item()

    return {"label": labels[pred]}