subbunanepalli committed on
Commit
6993506
·
verified ·
1 Parent(s): 53eac32

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +83 -0
  2. config.py +61 -0
  3. dataset_utils.py +92 -0
  4. label_encoders.pkl +3 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""FastAPI inference service exposing a multi-output RoBERTa classifier."""

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import RobertaTokenizer
from models.roberta_model import RobertaMultiOutputModel
from config import TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, DEVICE
from dataset_utils import load_label_encoders
import numpy as np
import os

app = FastAPI()

# Load the model and tokenizer
model_path = "saved_models/ROBERTA_model.pth" # Adjust if different
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Load label encoders (fitted during training; maps class indices back to names)
label_encoders = load_label_encoders()
# One classification head per label column; head width = class count for that column.
num_classes = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]

# Initialize model and load weights
model = RobertaMultiOutputModel(num_classes).to(DEVICE)
# NOTE(review): torch.load of a repo-shipped checkpoint; unsafe on untrusted files.
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
model.eval()  # inference mode: disables dropout
26
# Request format
class PredictionRequest(BaseModel):
    """Request body for /predict: the free-text sanction context to classify."""
    sanction_context: str
29
+
30
# Root health check
@app.get("/")
async def root():
    """Liveness probe: confirms the API process is up and serving."""
    payload = {"status": "healthy", "message": "RoBERTa API is running"}
    return payload
34
+
35
@app.get("/health")
async def health_check():
    """Dedicated health endpoint for container/platform probes."""
    return dict(status="healthy")
38
+
39
# Prediction endpoint
@app.post("/predict")
async def predict(request: PredictionRequest):
    """Classify a sanction-context string across every label column.

    For each column in LABEL_COLUMNS the response contains the decoded
    top-1 label and the full per-class probability distribution.
    Any failure is surfaced as an HTTP 500 carrying the exception text.
    """
    try:
        # Tokenize the input text (pad/truncate to the training MAX_LEN)
        inputs = tokenizer(
            request.sanction_context,
            padding='max_length',
            truncation=True,
            max_length=MAX_LEN,
            return_tensors="pt"
        )

        # Move inputs to device
        input_ids = inputs['input_ids'].to(DEVICE)
        attention_mask = inputs['attention_mask'].to(DEVICE)

        # Get model predictions; no gradients needed at inference time
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            probabilities = [torch.softmax(output, dim=1).cpu().numpy() for output in outputs]
            predictions = [np.argmax(prob, axis=1) for prob in probabilities]

        # Format the response: decoded label + class distribution per column.
        # (enumerate removed — its index was never used)
        response = {}
        for col, pred, prob in zip(LABEL_COLUMNS, predictions, probabilities):
            decoded_pred = label_encoders[col].inverse_transform(pred)[0]
            response[col] = {
                "prediction": decoded_pred,
                "probabilities": {
                    label: float(prob[0][j])
                    for j, label in enumerate(label_encoders[col].classes_)
                }
            }

        return response

    except Exception as e:
        # Chain the original exception for proper tracebacks in server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
78
+
79
# For local or Spaces deployment
if __name__ == "__main__":
    import uvicorn

    # Spaces injects PORT; default to 7860 (the HF Spaces convention) locally.
    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
config.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Shared configuration: paths, dataset columns, hyperparameters, device."""

import torch
import os

# --- Paths ---
DATA_PATH = '/kaggle/input/synthesisss/synthetic_transactions_samples_5000.csv'
TOKENIZER_PATH = './tokenizer/'
LABEL_ENCODERS_PATH = './label_encoders.pkl'
MODEL_SAVE_DIR = './saved_models/'
PREDICTIONS_SAVE_DIR = './predictions/'

# --- Data Columns ---
TEXT_COLUMN = "Sanction_Context"
LABEL_COLUMNS = [
    "Red_Flag_Reason",
    "Maker_Action",
    "Escalation_Level",
    "Risk_Category",
    "Risk_Drivers",
    "Investigation_Outcome",
]
METADATA_COLUMNS = []

# --- Model Hyperparameters ---
MAX_LEN = 128
BATCH_SIZE = 16
LEARNING_RATE = 2e-5
NUM_EPOCHS = 3
DROPOUT_RATE = 0.3

# --- Device Configuration ---
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Model Names ---
BERT_MODEL_NAME = 'bert-base-uncased'
ROBERTA_MODEL_NAME = 'roberta-base'
DEBERTA_MODEL_NAME = 'microsoft/deberta-base'

# --- TF-IDF ---
TFIDF_MAX_FEATURES = 5000

# --- Optional Strategy Definitions ---
FIELD_STRATEGIES = {
    "Maker_Action": {
        "loss": "focal_loss",
        "enhancements": ["action_templates", "context_prompt_tuning"],
    },
    "Risk_Category": {
        "enhancements": ["numerical_metadata", "transaction_patterns"],
    },
    "Escalation_Level": {
        "enhancements": ["class_balancing", "policy_keyword_patterns"],
    },
    "Investigation_Outcome": {
        "type": "classification_or_generation",
    },
}

# Every output directory must exist before training/inference writes to it.
for _required_dir in (MODEL_SAVE_DIR, PREDICTIONS_SAVE_DIR, TOKENIZER_PATH):
    os.makedirs(_required_dir, exist_ok=True)
dataset_utils.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import torch
3
+ from torch.utils.data import Dataset, DataLoader
4
+ from sklearn.preprocessing import LabelEncoder
5
+ from transformers import BertTokenizer, RobertaTokenizer, DebertaTokenizer
6
+ import pickle
7
+ import os
8
+
9
+ from config import TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, TOKENIZER_PATH, LABEL_ENCODERS_PATH, METADATA_COLUMNS
10
+
11
class ComplianceDataset(Dataset):
    """Text-only dataset: tokenizes each example lazily on access and
    pairs it with its multi-task label vector (one id per label column)."""

    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoded = self.tokenizer(
            str(self.texts[idx]),
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )
        # Drop the batch dimension the tokenizer adds for a single example.
        squeezed = {name: tensor.squeeze(0) for name, tensor in encoded.items()}
        target = torch.tensor(self.labels[idx], dtype=torch.long)
        return squeezed, target
33
+
34
class ComplianceDatasetWithMetadata(Dataset):
    """Dataset variant that additionally yields a float tensor of numeric
    metadata features alongside the tokenized text and label vector."""

    def __init__(self, texts, metadata, labels, tokenizer, max_len):
        self.texts = texts
        self.metadata = metadata
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoded = self.tokenizer(
            str(self.texts[idx]),
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )
        # Remove the singleton batch axis added by return_tensors="pt".
        squeezed = {name: tensor.squeeze(0) for name, tensor in encoded.items()}
        extra = torch.tensor(self.metadata[idx], dtype=torch.float)
        target = torch.tensor(self.labels[idx], dtype=torch.long)
        return squeezed, extra, target
58
+
59
def load_and_preprocess_data(data_path):
    """Read the CSV at *data_path*, fill missing values, coerce metadata
    columns to numeric, and integer-encode every label column.

    Returns the transformed DataFrame and a dict of fitted LabelEncoders
    keyed by label column name.
    """
    data = pd.read_csv(data_path)
    data.fillna("Unknown", inplace=True)

    # Metadata columns must be numeric; anything unparseable becomes 0.
    for col in METADATA_COLUMNS:
        if col in data.columns:
            data[col] = pd.to_numeric(data[col], errors='coerce').fillna(0)

    # Fit one encoder per label column and replace labels with integer ids.
    label_encoders = {}
    for col in LABEL_COLUMNS:
        encoder = LabelEncoder()
        data[col] = encoder.fit_transform(data[col])
        label_encoders[col] = encoder
    return data, label_encoders
71
+
72
def get_tokenizer(model_name):
    """Return the tokenizer matching *model_name*.

    Bug fix: the original tested ``"bert" in model_name`` FIRST, but both
    "roberta-base" and "microsoft/deberta-base" contain the substring
    "bert", so every supported model fell into the BertTokenizer branch.
    Check the more specific names before the generic "bert".

    Raises:
        ValueError: if no supported tokenizer family matches the name.
    """
    lowered = model_name.lower()
    if "roberta" in lowered:
        return RobertaTokenizer.from_pretrained(model_name)
    elif "deberta" in lowered:
        return DebertaTokenizer.from_pretrained(model_name)
    elif "bert" in lowered:
        return BertTokenizer.from_pretrained(model_name)
    else:
        raise ValueError(f"Unsupported tokenizer for model: {model_name}")
81
+
82
def save_label_encoders(label_encoders):
    """Pickle the dict of fitted label encoders to LABEL_ENCODERS_PATH."""
    with open(LABEL_ENCODERS_PATH, "wb") as handle:
        pickle.dump(label_encoders, handle)
    print(f"Label encoders saved to {LABEL_ENCODERS_PATH}")
86
+
87
def load_label_encoders():
    """Unpickle and return the label encoders from LABEL_ENCODERS_PATH.

    NOTE(review): pickle.load executes arbitrary code from the file; this
    is safe only because the .pkl ships with the repo — never point it at
    untrusted input.
    """
    with open(LABEL_ENCODERS_PATH, "rb") as handle:
        return pickle.load(handle)
90
+
91
def get_num_labels(label_encoders):
    """Return the class count for each label column, in LABEL_COLUMNS order."""
    counts = []
    for column in LABEL_COLUMNS:
        counts.append(len(label_encoders[column].classes_))
    return counts
label_encoders.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be834abbaaa80f915d0a0015f541a17ae6fda5c75d9485cb23c6a7b7bb7b7c97
3
+ size 2047