subbunanepalli committed
Commit 5baf551 · verified · 1 Parent(s): 9fe40c4

Upload app.py

Files changed (1)
  1. app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import RobertaTokenizer
from models.roberta_model import RobertaMultiOutputModel
from config import TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, DEVICE
from dataset_utils import load_label_encoders
import numpy as np
import os

app = FastAPI()

# Load the model and tokenizer
model_path = "saved_models/ROBERTA_model.pth"  # Adjust if different
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Load label encoders
label_encoders = load_label_encoders()
num_classes = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]

# Initialize model and load weights
model = RobertaMultiOutputModel(num_classes).to(DEVICE)
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
model.eval()

# Request format
class PredictionRequest(BaseModel):
    sanction_context: str

# Root health check
@app.get("/")
async def root():
    return {"status": "healthy", "message": "RoBERTa API is running"}

@app.get("/health")
async def health_check():
    return {"status": "healthy"}

# Prediction endpoint
@app.post("/predict")
async def predict(request: PredictionRequest):
    try:
        # Tokenize the input text
        inputs = tokenizer(
            request.sanction_context,
            padding='max_length',
            truncation=True,
            max_length=MAX_LEN,
            return_tensors="pt"
        )

        # Move inputs to device
        input_ids = inputs['input_ids'].to(DEVICE)
        attention_mask = inputs['attention_mask'].to(DEVICE)

        # Get model predictions (one logit tensor per label column)
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            probabilities = [torch.softmax(output, dim=1).cpu().numpy() for output in outputs]
            predictions = [np.argmax(prob, axis=1) for prob in probabilities]

        # Format the response: decoded label plus per-class probabilities
        response = {}
        for col, pred, prob in zip(LABEL_COLUMNS, predictions, probabilities):
            decoded_pred = label_encoders[col].inverse_transform(pred)[0]
            response[col] = {
                "prediction": decoded_pred,
                "probabilities": {
                    label: float(prob[0][j])
                    for j, label in enumerate(label_encoders[col].classes_)
                }
            }

        return response

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# For local or Spaces deployment
if __name__ == "__main__":
    import uvicorn
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
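
RobertaMultiOutputModel is imported from models/roberta_model but is not included in this commit. A minimal sketch of what app.py appears to assume — a shared RoBERTa encoder with one linear classification head per label column, returning a list of logit tensors that /predict can zip against LABEL_COLUMNS — might look like this (the class internals below are assumptions, not the author's code):

# Hypothetical sketch of models/roberta_model.py (not part of this commit).
import torch.nn as nn
from transformers import RobertaModel

class RobertaMultiOutputModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.roberta = RobertaModel.from_pretrained('roberta-base')
        self.dropout = nn.Dropout(0.3)  # assumed dropout rate
        # One classification head per label column
        self.heads = nn.ModuleList(
            [nn.Linear(self.roberta.config.hidden_size, n) for n in num_classes]
        )

    def forward(self, input_ids, attention_mask):
        out = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(out.pooler_output)
        # Return one logit tensor per label column, matching how /predict
        # iterates over the model outputs
        return [head(pooled) for head in self.heads]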
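
config and dataset_utils are likewise imported but not shown here. A hypothetical sketch of the constants and helper the app expects — every name, value, and path below is a placeholder, not the project's actual configuration:

# Hypothetical sketch of config.py / dataset_utils.py (not part of this commit).
import torch
import joblib

TEXT_COLUMN = "sanction_context"        # assumed input text column name
LABEL_COLUMNS = ["label_a", "label_b"]  # placeholder label column names
MAX_LEN = 256                           # assumed tokenizer max length
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_label_encoders(path="saved_models/label_encoders.pkl"):
    # Assumed to return {column_name: fitted sklearn LabelEncoder},
    # saved during training at a path like the one above
    return joblib.load(path)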
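
Once the server is running (port 7860 by default, per the __main__ block), /predict accepts a JSON body with a sanction_context field and returns, for each label column, the decoded prediction and per-class probabilities. A minimal client call, assuming a local deployment:

# Example request against a locally running instance (URL and input text are illustrative).
import requests

resp = requests.post(
    "http://localhost:7860/predict",  # use the Space URL when deployed
    json={"sanction_context": "Example sanction text to classify"},
)
print(resp.json())  # {label_column: {"prediction": ..., "probabilities": {...}}, ...}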