ganeshkonapalli commited on
Commit
d7db76e
·
verified ·
1 Parent(s): 486b53c

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -0
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ import torch
4
+ from transformers import RobertaTokenizer
5
+ from config import DEVICE, MAX_LEN, LABEL_COLUMNS, ROBERTA_MODEL_NAME, MODEL_SAVE_DIR, LABEL_ENCODERS_PATH
6
+ from models.roberta_model import RobertaMultiOutputModel
7
+ from dataset_utils import load_label_encoders
8
+ import numpy as np
9
+
10
+ app = FastAPI()
11
+
12
+ # Load label encoders
13
+ label_encoders = load_label_encoders()
14
+
15
+ # Load tokenizer and model
16
+ tokenizer = RobertaTokenizer.from_pretrained(ROBERTA_MODEL_NAME)
17
+ num_labels = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
18
+ model = RobertaMultiOutputModel(num_labels)
19
+ model.load_state_dict(torch.load(MODEL_SAVE_DIR + "ROBERTA_model.pth", map_location=DEVICE))
20
+ model.to(DEVICE)
21
+ model.eval()
22
+
23
+ class RequestText(BaseModel):
24
+ text: str
25
+
26
+ @app.post("/predict")
27
+ def predict_labels(request: RequestText):
28
+ try:
29
+ inputs = tokenizer(
30
+ request.text,
31
+ padding='max_length',
32
+ truncation=True,
33
+ max_length=MAX_LEN,
34
+ return_tensors="pt"
35
+ )
36
+ input_ids = inputs['input_ids'].to(DEVICE)
37
+ attention_mask = inputs['attention_mask'].to(DEVICE)
38
+
39
+ with torch.no_grad():
40
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask)
41
+ predictions = [torch.argmax(logits, dim=1).cpu().numpy()[0] for logits in outputs]
42
+
43
+ decoded_predictions = {
44
+ label: label_encoders[label].inverse_transform([pred])[0]
45
+ for label, pred in zip(LABEL_COLUMNS, predictions)
46
+ }
47
+ return decoded_predictions
48
+ except Exception as e:
49
+ raise HTTPException(status_code=500, detail=str(e))