subbu123456 commited on
Commit
d12f28d
·
verified ·
1 Parent(s): 048ee88

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +50 -0
main.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from transformers import BertTokenizerFast, BertModel
4
+ import torch
5
+ import torch.nn as nn
6
+ import os
7
+
8
+ # Define constants
9
+ MODEL_PATH = os.path.join(os.path.dirname(__file__), "model")
10
+ WEIGHTS_PATH = os.path.join(MODEL_PATH, "bert-multilabel-model.pth")
11
+ NUM_LABELS = 6 # Adjust based on your dataset
12
+
13
+ # Initialize FastAPI app
14
+ app = FastAPI()
15
+
16
+ # Load tokenizer from local directory
17
+ tokenizer = BertTokenizerFast.from_pretrained(MODEL_PATH)
18
+
19
+ # Define the BERT-based multi-label classifier
20
+ class BertMultiLabelClassifier(nn.Module):
21
+ def __init__(self):
22
+ super(BertMultiLabelClassifier, self).__init__()
23
+ self.bert = BertModel.from_pretrained(MODEL_PATH)
24
+ self.classifier = nn.Linear(self.bert.config.hidden_size, NUM_LABELS)
25
+
26
+ def forward(self, input_ids, attention_mask):
27
+ output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
28
+ cls_output = output.last_hidden_state[:, 0, :]
29
+ return self.classifier(cls_output)
30
+
31
+ # Load the model weights
32
+ model = BertMultiLabelClassifier()
33
+ model.load_state_dict(torch.load(WEIGHTS_PATH, map_location="cpu"))
34
+ model.eval()
35
+
36
+ # Input schema for prediction
37
+ class PredictRequest(BaseModel):
38
+ text: str
39
+
40
+ @app.get("/")
41
+ def read_root():
42
+ return {"message": "Multi-label BERT model is running!"}
43
+
44
+ @app.post("/predict")
45
+ def predict(request: PredictRequest):
46
+ inputs = tokenizer(request.text, return_tensors="pt", truncation=True, padding=True, max_length=512)
47
+ with torch.no_grad():
48
+ logits = model(**inputs)
49
+ probs = torch.sigmoid(logits).squeeze().tolist()
50
+ return {"probabilities": probs}