ganeshkonapalli committed on
Commit
4584f44
·
verified ·
1 Parent(s): 4f4a05e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ import torch
4
+ from transformers import BertTokenizer
5
+ from models.bert_model import BertMultiOutputModel
6
+ from config import TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, DEVICE
7
+ from dataset_utils import load_label_encoders
8
+ import numpy as np
9
+ import os
10
+
11
app = FastAPI()

# Load the tokenizer and the fine-tuned model once, at import time, so that
# request handlers can reuse them.
model_path = "BERT_model.pth"
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Load the label encoders a single time rather than once per label column;
# only the number of classes per column is needed to size the model heads.
_startup_label_encoders = load_label_encoders()
num_classes_per_label = [
    len(_startup_label_encoders[col].classes_) for col in LABEL_COLUMNS
]
model = BertMultiOutputModel(num_classes_per_label).to(DEVICE)

# NOTE(review): torch.load unpickles the checkpoint, so model_path must only
# ever point at a trusted file.
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
# Switch to inference mode before serving requests.
model.eval()
19
+
20
class PredictionRequest(BaseModel):
    """Request body for the /predict endpoint."""

    # Text that is tokenized and fed to the model.
    sanction_context: str
22
+
23
@app.get("/")
async def root():
    """Return a static payload announcing that the API is running."""
    payload = dict(status="healthy", message="BERT API is running")
    return payload
26
+
27
@app.get("/health")
async def health_check():
    """Return a static healthy-status payload."""
    return dict(status="healthy")
30
+
31
# Label encoders are loaded on first use and then reused, so a request does
# not have to reload them every time.
_LABEL_ENCODERS_CACHE = None


def _get_label_encoders():
    """Return the label encoders, loading them only on the first call."""
    global _LABEL_ENCODERS_CACHE
    if _LABEL_ENCODERS_CACHE is None:
        _LABEL_ENCODERS_CACHE = load_label_encoders()
    return _LABEL_ENCODERS_CACHE


def _to_native(value):
    """Convert a NumPy scalar to the equivalent built-in Python value."""
    return value.item() if isinstance(value, np.generic) else value


@app.post("/predict")
async def predict(request: PredictionRequest):
    """Classify ``request.sanction_context`` for every column in LABEL_COLUMNS.

    Returns:
        A dict keyed by label column. Each value holds ``"prediction"`` (the
        decoded class with the highest probability) and ``"probabilities"``
        (a mapping from each class label to its softmax probability).

    Raises:
        HTTPException: with status 500 and the error text as detail when any
            step of tokenization, inference or decoding fails.
    """
    try:
        # Tokenize the input text into a fixed-length, single-row batch.
        inputs = tokenizer(
            request.sanction_context,
            padding='max_length',
            truncation=True,
            max_length=MAX_LEN,
            return_tensors="pt",
        )

        # Move inputs to the same device as the model.
        input_ids = inputs['input_ids'].to(DEVICE)
        attention_mask = inputs['attention_mask'].to(DEVICE)

        # Run inference without tracking gradients; the model yields one
        # output per label column.
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            probabilities = [
                torch.softmax(output, dim=1).cpu().numpy() for output in outputs
            ]
            predictions = [np.argmax(prob, axis=1) for prob in probabilities]

        label_encoders = _get_label_encoders()

        response = {}
        for col, pred, prob in zip(LABEL_COLUMNS, predictions, probabilities):
            encoder = label_encoders[col]
            # The decoded label and the class names may be NumPy scalars
            # (depending on how the encoders were fitted); convert them to
            # plain Python values so the response can be JSON-encoded.
            decoded_pred = _to_native(encoder.inverse_transform(pred)[0])
            response[col] = {
                "prediction": decoded_pred,
                "probabilities": {
                    str(label): float(score)
                    for label, score in zip(encoder.classes_, prob[0])
                },
            }

        return response

    except Exception as e:
        # Top-level boundary: surface the failure as an HTTP 500 and keep the
        # original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e
72
+
73
if __name__ == "__main__":
    import uvicorn

    # Hugging Face Spaces needs port 7860; a PORT environment variable, when
    # set, takes precedence.
    port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)