# roberta / main.py
# (source: Hugging Face repo page — "ganeshkonapalli's picture",
#  "Upload 8 files", commit 46f994e verified — kept here as a comment
#  because the scraped page header is not valid Python)
from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import RobertaTokenizer
from app.roberta_model import RobertaMultiOutputModel
from app.dataset_utils import load_label_encoders
from app.config import MAX_LEN, LABEL_COLUMNS, MODEL_SAVE_DIR, LABEL_ENCODERS_PATH, TOKENIZER_PATH
# ASGI application instance; routes are registered on it below.
app = FastAPI()
class InputText(BaseModel):
    """Request body for the /predict endpoint."""
    # Raw sanction text that is tokenized and fed to the model.
    sanction_context: str
# --- Module-level setup: load artifacts once at import time ---------------
# Fitted sklearn-style LabelEncoders, one per output label column.
# NOTE(review): assumes each value exposes .classes_ and .inverse_transform
# (sklearn LabelEncoder API) — confirm against load_label_encoders.
label_encoders = load_label_encoders(LABEL_ENCODERS_PATH)
# Number of classes for each output head, in LABEL_COLUMNS order.
num_classes_per_label = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
model = RobertaMultiOutputModel(num_classes_per_label)
# Weights are loaded onto CPU; no GPU is assumed at serving time.
model.load_state_dict(torch.load(f"{MODEL_SAVE_DIR}/ROBERTA_model.pth", map_location="cpu"))
model.eval()  # disable dropout/batch-norm updates for inference
tokenizer = RobertaTokenizer.from_pretrained(TOKENIZER_PATH)
@app.post("/predict")
def predict(input_text: InputText):
    """Classify a sanction text with the multi-output RoBERTa model.

    Tokenizes the request text, runs one forward pass with gradients
    disabled, takes the argmax of every output head, and maps each
    predicted class index back to its original string label.

    Returns a JSON object: {"predictions": {label_name: label_value, ...}}.
    """
    encoded = tokenizer(
        input_text.sanction_context,
        padding='max_length',
        truncation=True,
        max_length=MAX_LEN,
        return_tensors="pt",
    )

    # Inference only — no autograd bookkeeping needed.
    with torch.no_grad():
        head_logits = model(**encoded)

    # One class index per output head (batch size is 1).
    pred_ids = [torch.argmax(logits, dim=1).item() for logits in head_logits]

    # Map each index back through its label encoder to the original string.
    decoded = {
        name: label_encoders[name].inverse_transform([idx])[0]
        for name, idx in zip(LABEL_COLUMNS, pred_ids)
    }
    return {"predictions": decoded}