from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import pickle
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel


class ModelClass(nn.Module):
    """Multi-task text classifier: a shared RoBERTa encoder with one
    linear classification head per task.

    Args:
        num_labels_per_task: number of classes for each task, in the
            same order as the label-encoder columns.
    """

    def __init__(self, num_labels_per_task):
        super().__init__()
        # Shared encoder; weights are immediately overwritten by the
        # fine-tuned state dict loaded below, so the pretrained download
        # only supplies the architecture/config.
        self.encoder = AutoModel.from_pretrained("roberta-base")
        hidden_size = self.encoder.config.hidden_size
        # One independent linear head per task.
        self.classifiers = nn.ModuleList([
            nn.Linear(hidden_size, num_labels)
            for num_labels in num_labels_per_task
        ])

    def forward(self, input_ids, attention_mask):
        """Return a list of per-task logit tensors, one per classifier head."""
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # pooler_output: [CLS]-position representation passed through the
        # pooler dense+tanh layer (present on RoBERTa models by default).
        pooled_output = outputs.pooler_output
        return [clf(pooled_output) for clf in self.classifiers]


# --- Module-level model loading (runs once at import/startup) ---------------
# SECURITY: pickle.load executes arbitrary code embedded in the file.
# Only load bundles from a trusted source; consider safetensors / torch.save
# with weights_only=True for the state dict instead.
with open("roberta_model.pkl", "rb") as f:
    bundle = pickle.load(f)

tokenizer = bundle["tokenizer"]
label_encoders = bundle["label_encoders"]       # dict: column name -> sklearn LabelEncoder
model_state_dict = bundle["model_state_dict"]

# Task order is the label-encoder insertion order; heads are built to match.
label_columns = list(label_encoders.keys())
num_labels_per_task = [len(le.classes_) for le in label_encoders.values()]

model = ModelClass(num_labels_per_task)
model.load_state_dict(model_state_dict)
model.eval()  # disable dropout etc. for deterministic inference

app = FastAPI()


class Request(BaseModel):
    """Request body for /predict: the raw text to classify."""
    text: str


@app.post("/predict")
def predict(req: Request):
    """Classify `req.text` for every task.

    Returns:
        {"predictions": {column_name: decoded_label, ...}} with one entry
        per label column in the loaded bundle.

    Raises:
        HTTPException(500): on any failure during tokenization, inference,
            or label decoding. NOTE(review): the raw exception string is
            returned to the client — consider logging server-side and
            returning a generic message in production.
    """
    try:
        inputs = tokenizer(
            req.text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,  # RoBERTa's maximum sequence length
        )
        with torch.no_grad():
            logits = model(**inputs)
        # argmax over the class dimension of each task's logits;
        # batch size is 1, so .item() extracts the scalar class index.
        preds = [torch.argmax(logit, dim=1).item() for logit in logits]
        # Map each integer class index back to its original string label.
        decoded = {
            col: label_encoders[col].inverse_transform([pred])[0]
            for col, pred in zip(label_columns, preds)
        }
        return {"predictions": decoded}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))