subbu123456 commited on
Commit
b50a487
·
verified ·
1 Parent(s): a2b7acc

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -77
app.py DELETED
@@ -1,77 +0,0 @@
1
- from fastapi import FastAPI, HTTPException
2
- from pydantic import BaseModel
3
- import torch
4
- from transformers import BertTokenizer
5
- from models.bert_model import BertMultiOutputModel
6
- from config import TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, DEVICE
7
- from dataset_utils import load_label_encoders
8
- import numpy as np
9
- import os
10
-
11
- app = FastAPI()
12
-
13
- # Load the model and tokenizer
14
- model_path = "BERT_model.pth"
15
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
16
- model = BertMultiOutputModel([len(load_label_encoders()[col].classes_) for col in LABEL_COLUMNS]).to(DEVICE)
17
- model.load_state_dict(torch.load(model_path, map_location=DEVICE))
18
- model.eval()
19
-
20
- class PredictionRequest(BaseModel):
21
- sanction_context: str
22
-
23
- @app.get("/")
24
- async def root():
25
- return {"status": "healthy", "message": "BERT API is running"}
26
-
27
- @app.get("/health")
28
- async def health_check():
29
- return {"status": "healthy"}
30
-
31
- @app.post("/predict")
32
- async def predict(request: PredictionRequest):
33
- try:
34
- # Tokenize the input text
35
- inputs = tokenizer(
36
- request.sanction_context,
37
- padding='max_length',
38
- truncation=True,
39
- max_length=MAX_LEN,
40
- return_tensors="pt"
41
- )
42
-
43
- # Move inputs to device
44
- input_ids = inputs['input_ids'].to(DEVICE)
45
- attention_mask = inputs['attention_mask'].to(DEVICE)
46
-
47
- # Get predictions
48
- with torch.no_grad():
49
- outputs = model(input_ids=input_ids, attention_mask=attention_mask)
50
- probabilities = [torch.softmax(output, dim=1).cpu().numpy() for output in outputs]
51
- predictions = [np.argmax(prob, axis=1) for prob in probabilities]
52
-
53
- # Load label encoders to decode predictions
54
- label_encoders = load_label_encoders()
55
-
56
- # Format response
57
- response = {}
58
- for i, (col, pred, prob) in enumerate(zip(LABEL_COLUMNS, predictions, probabilities)):
59
- decoded_pred = label_encoders[col].inverse_transform(pred)[0]
60
- response[col] = {
61
- "prediction": decoded_pred,
62
- "probabilities": {
63
- label: float(prob[0][j])
64
- for j, label in enumerate(label_encoders[col].classes_)
65
- }
66
- }
67
-
68
- return response
69
-
70
- except Exception as e:
71
- raise HTTPException(status_code=500, detail=str(e))
72
-
73
- if __name__ == "__main__":
74
- import uvicorn
75
- # For Hugging Face Spaces, we need to use port 7860
76
- port = int(os.environ.get("PORT", 7860))
77
- uvicorn.run(app, host="0.0.0.0", port=port)