File size: 1,697 Bytes
d7db76e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import RobertaTokenizer

from config import DEVICE, MAX_LEN, LABEL_COLUMNS, ROBERTA_MODEL_NAME, MODEL_SAVE_DIR, LABEL_ENCODERS_PATH
from dataset_utils import load_label_encoders
from models.roberta_model import RobertaMultiOutputModel

app = FastAPI()

# Label encoders map each head's integer class index back to its original
# label string; loaded once and shared (read-only) across requests.
label_encoders = load_label_encoders()

# Tokenizer and model are loaded once at import time so each request only
# pays for inference, not model construction.
tokenizer = RobertaTokenizer.from_pretrained(ROBERTA_MODEL_NAME)
# One output head per label column; each head's size is that column's class count.
num_labels = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
model = RobertaMultiOutputModel(num_labels)
# os.path.join avoids the silent breakage of plain string concatenation when
# MODEL_SAVE_DIR does not end with a path separator.
model.load_state_dict(
    torch.load(os.path.join(MODEL_SAVE_DIR, "ROBERTA_model.pth"), map_location=DEVICE)
)
model.to(DEVICE)
model.eval()  # disable dropout/batch-norm updates for inference

class RequestText(BaseModel):
    """Request body for the /predict endpoint: the raw text to classify."""
    # Input document/sentence; tokenized and truncated to MAX_LEN downstream.
    text: str

@app.post("/predict")
def predict_labels(request: RequestText):
    """Run the multi-output model on the request text.

    Returns a mapping of label-column name -> decoded (human-readable) label,
    one entry per output head in LABEL_COLUMNS.

    Raises:
        HTTPException: 500 with the underlying error message if tokenization
            or inference fails.
    """
    try:
        # Tokenize to fixed-length tensors; truncation keeps inputs <= MAX_LEN.
        inputs = tokenizer(
            request.text,
            padding='max_length',
            truncation=True,
            max_length=MAX_LEN,
            return_tensors="pt"
        )
        input_ids = inputs['input_ids'].to(DEVICE)
        attention_mask = inputs['attention_mask'].to(DEVICE)

        with torch.no_grad():
            # One logits tensor per output head (single-item batch).
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            # .item() extracts the argmax class index as a plain Python int,
            # avoiding the numpy round-trip and guaranteeing JSON-serializable
            # values in the response.
            predictions = [torch.argmax(logits, dim=1).item() for logits in outputs]

        # Map each integer class index back to its original label string.
        return {
            label: label_encoders[label].inverse_transform([pred])[0]
            for label, pred in zip(LABEL_COLUMNS, predictions)
        }
    except HTTPException:
        # Let deliberately raised HTTP errors pass through unwrapped.
        raise
    except Exception as e:
        # Chain the cause so server logs keep the original traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e