subbunanepalli commited on
Commit
c44a5b8
·
verified ·
1 Parent(s): 4e668dc

Upload 6 files

Browse files
Files changed (5) hide show
  1. main.py +82 -39
  2. requirements.txt +5 -3
  3. special_tokens_map.json +1 -0
  4. tokenizer_config.json +1 -0
  5. vocab.txt +8 -0
main.py CHANGED
@@ -1,52 +1,95 @@
1
- from fastapi import FastAPI, UploadFile, File, HTTPException
2
- import pickle
3
- import pandas as pd
4
- from pydantic import BaseModel
5
- from typing import List
6
  import os
 
 
 
 
 
 
7
 
8
- app = FastAPI()
9
- MODEL_PATH = "tfidf_models.pkl"
10
-
11
- class TrainRequest(BaseModel):
12
- texts: List[str]
13
- labels: List[List[int]]
 
 
 
 
14
 
15
- class PredictRequest(BaseModel):
16
- texts: List[str]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- @app.post("/train")
19
- def train_model(request: TrainRequest):
20
- from sklearn.feature_extraction.text import TfidfVectorizer
21
- from sklearn.multioutput import MultiOutputClassifier
22
- from sklearn.linear_model import LogisticRegression
23
 
24
- if len(request.texts) != len(request.labels):
25
- raise HTTPException(status_code=400, detail="Texts and labels length mismatch")
 
 
 
 
 
 
 
 
26
 
27
- X = request.texts
28
- y = pd.DataFrame(request.labels)
 
 
 
29
 
30
- vectorizer = TfidfVectorizer()
31
- X_tfidf = vectorizer.fit_transform(X)
 
 
 
 
 
 
32
 
33
- classifier = MultiOutputClassifier(LogisticRegression(max_iter=1000))
34
- classifier.fit(X_tfidf, y)
35
 
36
- with open(MODEL_PATH, "wb") as f:
37
- pickle.dump((vectorizer, classifier), f)
38
 
39
- return {"message": "Model trained and saved successfully."}
 
 
40
 
41
  @app.post("/predict")
42
  def predict(request: PredictRequest):
43
- if not os.path.exists(MODEL_PATH):
44
- raise HTTPException(status_code=404, detail="Model not found. Train the model first.")
45
-
46
- with open(MODEL_PATH, "rb") as f:
47
- vectorizer, classifier = pickle.load(f)
48
-
49
- X_tfidf = vectorizer.transform(request.texts)
50
- predictions = classifier.predict(X_tfidf)
51
-
52
- return {"predictions": predictions.tolist()}
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import requests
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import BertTokenizer, BertModel
6
+ from fastapi import FastAPI
7
+ from pydantic import BaseModel
8
 
9
+ # Constants
10
+ LABEL_COLUMNS = [
11
+ 'Red_Flag_Reason', 'Maker_Action', 'Escalation_Level',
12
+ 'Risk_Category', 'Risk_Drivers', 'Investigation_Outcome'
13
+ ]
14
+ PRETRAINED_MODEL_NAME = 'bert-base-uncased'
15
+ MAX_LEN = 128
16
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+ MODEL_PATH = "/tmp/bert_model.pth"
18
+ FILE_ID = "1qqmBxbxM0CmxPGC4sqO6vLJAe-Kikiv4"
19
 
20
+ # Google Drive download logic
21
+ def download_from_google_drive(file_id, dest_path):
22
+ URL = "https://docs.google.com/uc?export=download"
23
+ session = requests.Session()
24
+ response = session.get(URL, params={'id': file_id}, stream=True)
25
+ def get_confirm_token(response):
26
+ for key, value in response.cookies.items():
27
+ if key.startswith('download_warning'):
28
+ return value
29
+ return None
30
+ token = get_confirm_token(response)
31
+ if token:
32
+ params = {'id': file_id, 'confirm': token}
33
+ response = session.get(URL, params=params, stream=True)
34
+ with open(dest_path, "wb") as f:
35
+ for chunk in response.iter_content(32768):
36
+ if chunk:
37
+ f.write(chunk)
38
 
39
+ if not os.path.exists(MODEL_PATH):
40
+ print("Downloading model from Google Drive...")
41
+ download_from_google_drive(FILE_ID, MODEL_PATH)
 
 
42
 
43
+ # Model Definition
44
+ class BertMultiOutput(nn.Module):
45
+ def __init__(self, num_labels_per_output):
46
+ super().__init__()
47
+ self.bert = BertModel.from_pretrained(PRETRAINED_MODEL_NAME)
48
+ self.dropout = nn.Dropout(0.3)
49
+ self.classifiers = nn.ModuleList([
50
+ nn.Linear(self.bert.config.hidden_size, n_labels)
51
+ for n_labels in num_labels_per_output
52
+ ])
53
 
54
+ def forward(self, input_ids, attention_mask):
55
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
56
+ pooled_output = self.dropout(outputs.pooler_output)
57
+ logits = [classifier(pooled_output) for classifier in self.classifiers]
58
+ return logits
59
 
60
+ # Load model and tokenizer
61
+ checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
62
+ label_encoders = checkpoint['label_encoders']
63
+ num_labels_list = [len(le.classes_) for le in label_encoders.values()]
64
+ model = BertMultiOutput(num_labels_list).to(DEVICE)
65
+ model.load_state_dict(checkpoint['model_state_dict'])
66
+ model.eval()
67
+ tokenizer = BertTokenizer.from_pretrained("bert_tokenizer/")
68
 
69
+ # FastAPI app
70
+ app = FastAPI()
71
 
72
+ class PredictRequest(BaseModel):
73
+ text: str
74
 
75
+ @app.get("/")
76
+ def root():
77
+ return {"message": "Multi-output BERT is ready!"}
78
 
79
  @app.post("/predict")
80
  def predict(request: PredictRequest):
81
+ inputs = tokenizer(
82
+ request.text,
83
+ truncation=True,
84
+ padding='max_length',
85
+ max_length=MAX_LEN,
86
+ return_tensors="pt"
87
+ ).to(DEVICE)
88
+ with torch.no_grad():
89
+ outputs = model(**inputs)
90
+ preds = [torch.argmax(output, dim=1).item() for output in outputs]
91
+ decoded = {
92
+ label: label_encoders[label].inverse_transform([pred])[0]
93
+ for label, pred in zip(LABEL_COLUMNS, preds)
94
+ }
95
+ return {"predictions": decoded}
requirements.txt CHANGED
@@ -1,5 +1,7 @@
1
  fastapi
2
  uvicorn
3
- scikit-learn
4
- pandas
5
- python-multipart
 
 
 
1
  fastapi
2
  uvicorn
3
+ transformers
4
+ torch
5
+ pydantic
6
+ requests
7
+ scikit-learn
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"cls_token": "[CLS]", "sep_token": "[SEP]"}
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"do_lower_case": true, "model_max_length": 512}
vocab.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ [PAD]
2
+ [UNK]
3
+ [CLS]
4
+ [SEP]
5
+ [MASK]
6
+ a
7
+ b
8
+ c