namanpenguin committed on
Commit 22539ef · verified · 1 Parent(s): 1f1b7a9

Upload 9 files
Dockerfile ADDED
@@ -0,0 +1,49 @@
+ # Use Python 3.9 as base image
+ FROM python:3.9-slim
+
+ # Set working directory
+ WORKDIR /app
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     build-essential \
+     curl \
+     software-properties-common \
+     git \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy requirements file
+ COPY requirements.txt .
+
+ # Install Python dependencies
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Create necessary directories with proper permissions
+ RUN mkdir -p /app/uploads \
+     /app/saved_models \
+     /app/predictions \
+     /app/tokenizer \
+     /app/cache \
+     && chmod -R 777 /app/uploads \
+     /app/saved_models \
+     /app/predictions \
+     /app/tokenizer \
+     /app/cache
+
+ # Copy the application code and utilities.
+ # dataset_utils.py, train_utils.py, config.py and label_encoders.pkl sit in the
+ # build context root, so the single COPY below already includes them; Docker
+ # cannot COPY paths such as ../file.py from outside the build context.
+ COPY . /app/
+
+ # Set environment variables
+ ENV PYTHONPATH=/app
+ ENV PYTHONUNBUFFERED=1
+ ENV PORT=7860
+
+ # Expose the port the app runs on
+ EXPOSE 7860
+
+ # Command to run the application
+ CMD ["python", "app.py"]
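Assuming requirements.txt sits in the build context, a typical cycle would be `docker build -t lgbm-api .` followed by `docker run -p 7860:7860 lgbm-api` (the image tag is illustrative); app.py then serves on port 7860.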
app.py ADDED
@@ -0,0 +1,283 @@
+ from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form
+ from fastapi.responses import FileResponse
+ from pydantic import BaseModel
+ from typing import Optional, Dict, Any, List
+ import uvicorn
+ import logging
+ import os
+ import pandas as pd
+ from datetime import datetime
+ import shutil
+ from pathlib import Path
+ import numpy as np
+ import json
+
+ # Import existing utilities
+ from dataset_utils import (
+     load_and_preprocess_data,
+     save_label_encoders,
+     load_label_encoders
+ )
+ from config import TEXT_COLUMN, LABEL_COLUMNS
+ # The module lives at models/tfidf_lgbm.py in this upload
+ from models.tfidf_lgbm import TfidfLightGBM
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ app = FastAPI(title="LGBM Compliance Predictor API")
+
+ UPLOAD_DIR = Path("uploads")
+ MODEL_SAVE_DIR = Path("saved_models")  # mirrors config.MODEL_SAVE_DIR, but as a Path
+ UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
+ MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)
+
+ training_status = {
+     "is_training": False,
+     "current_epoch": 0,
+     "total_epochs": 0,
+     "current_loss": 0.0,
+     "start_time": None,
+     "end_time": None,
+     "status": "idle",
+     "metrics": None
+ }
+
+ class TrainingConfig(BaseModel):
+     batch_size: int = 32
+     num_epochs: int = 1  # Not used for LGBM, but kept for API compatibility
+     random_state: int = 42
+
+ class TrainingResponse(BaseModel):
+     message: str
+     training_id: str
+     status: str
+     download_url: Optional[str] = None
+
+ class ValidationResponse(BaseModel):
+     message: str
+     metrics: Dict[str, Any]
+     predictions: List[Dict[str, Any]]
+
+ class TransactionData(BaseModel):
+     Transaction_Id: str
+     Message: str
+     Sanction_Context: str  # must be present: single predictions read TEXT_COLUMN from this model
+     # ... (other fields as needed) ...
+
+ class PredictionRequest(BaseModel):
+     transaction_data: TransactionData
+     model_name: str = "lgbm_models"  # Default to tfidf_lgbm if not specified
+
+ class BatchPredictionResponse(BaseModel):
+     message: str
+     predictions: List[Dict[str, Any]]
+     metrics: Optional[Dict[str, Any]] = None
+
+ @app.get("/")
+ async def root():
+     return {"message": "LGBM Compliance Predictor API"}
+
+ @app.get("/v1/lgbm/health")
+ async def health_check():
+     return {"status": "healthy"}
+
+ @app.get("/v1/lgbm/training-status")
+ async def get_training_status():
+     return training_status
+
+ @app.post("/v1/lgbm/train", response_model=TrainingResponse)
+ async def start_training(
+     background_tasks: BackgroundTasks,
+     config: str = Form(...),
+     file: UploadFile = File(...)
+ ):
+     if training_status["is_training"]:
+         raise HTTPException(status_code=400, detail="Training is already in progress")
+     if not file.filename.endswith('.csv'):
+         raise HTTPException(status_code=400, detail="Only CSV files are allowed")
+     try:
+         config_dict = json.loads(config)
+         training_config = TrainingConfig(**config_dict)
+     except Exception as e:
+         raise HTTPException(status_code=400, detail=f"Invalid config parameters: {str(e)}")
+     file_path = UPLOAD_DIR / file.filename
+     with file_path.open("wb") as buffer:
+         shutil.copyfileobj(file.file, buffer)
+     training_id = datetime.now().strftime("%Y%m%d_%H%M%S")
+     training_status.update({
+         "is_training": True,
+         "current_epoch": 0,
+         "total_epochs": 1,
+         "start_time": datetime.now().isoformat(),
+         "status": "starting"
+     })
+     background_tasks.add_task(train_model_task, training_config, str(file_path), training_id)
+     download_url = f"/v1/lgbm/download-model/{training_id}"
+     return TrainingResponse(
+         message="Training started successfully",
+         training_id=training_id,
+         status="started",
+         download_url=download_url
+     )
+
+ @app.post("/v1/lgbm/validate")
+ async def validate_model(
+     file: UploadFile = File(...),
+     model_name: str = "lgbm_models"
+ ):
+     if not file.filename.endswith('.csv'):
+         raise HTTPException(status_code=400, detail="Only CSV files are allowed")
+     file_path = UPLOAD_DIR / file.filename
+     try:
+         with file_path.open("wb") as buffer:
+             shutil.copyfileobj(file.file, buffer)
+         # NOTE: this re-fits encoders on the validation CSV; for strict consistency
+         # with training, load_label_encoders() would be the safer source.
+         data_df, label_encoders = load_and_preprocess_data(str(file_path))
+         model_path = MODEL_SAVE_DIR / f"{model_name}.pkl"
+         if not model_path.exists():
+             raise HTTPException(status_code=404, detail="LGBM model file not found")
+         model = TfidfLightGBM(label_encoders)
+         model.load_model(model_name)
+         X = data_df[TEXT_COLUMN]
+         y = data_df[LABEL_COLUMNS]
+         reports, y_true_list, y_pred_list = model.evaluate(X, y)
+         all_probs = model.predict_proba(X)
+         predictions = []
+         for i, col in enumerate(LABEL_COLUMNS):
+             label_encoder = label_encoders[col]
+             true_labels_orig = label_encoder.inverse_transform(y_true_list[i])
+             pred_labels_orig = label_encoder.inverse_transform(y_pred_list[i])
+             for true, pred, probs in zip(true_labels_orig, pred_labels_orig, all_probs[i]):
+                 class_probs = {label: float(prob) for label, prob in zip(label_encoder.classes_, probs)}
+                 predictions.append({
+                     "field": col,
+                     "true_label": true,
+                     "predicted_label": pred,
+                     "probabilities": class_probs
+                 })
+         return ValidationResponse(
+             message="Validation completed successfully",
+             metrics=reports,
+             predictions=predictions
+         )
+     except HTTPException:
+         raise  # preserve intended status codes (e.g. the 404 above) instead of mapping to 500
+     except Exception as e:
+         logger.error(f"Validation failed: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Validation failed: {str(e)}")
+     finally:
+         if os.path.exists(file_path):
+             os.remove(file_path)
+
+ @app.post("/v1/lgbm/predict")
+ async def predict(
+     request: Optional[PredictionRequest] = None,
+     file: UploadFile = File(None),
+     model_name: str = "lgbm_models"
+ ):
+     try:
+         model_path = MODEL_SAVE_DIR / f"{model_name}.pkl"
+         if not model_path.exists():
+             raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
+         label_encoders = load_label_encoders()
+         model = TfidfLightGBM(label_encoders)
+         model.load_model(model_name)
+         # Batch prediction
+         if file and file.filename:
+             if not file.filename.endswith('.csv'):
+                 raise HTTPException(status_code=400, detail="Only CSV files are allowed")
+             file_path = UPLOAD_DIR / file.filename
+             with file_path.open("wb") as buffer:
+                 shutil.copyfileobj(file.file, buffer)
+             try:
+                 data_df, _ = load_and_preprocess_data(str(file_path))
+                 X = data_df[TEXT_COLUMN]
+                 all_probabilities = model.predict_proba(X)
+                 predictions = []
+                 for i, row in data_df.iterrows():
+                     transaction_pred = {}
+                     for j, col in enumerate(LABEL_COLUMNS):
+                         probs = all_probabilities[j][i]
+                         pred = np.argmax(probs)
+                         decoded_pred = label_encoders[col].inverse_transform([pred])[0]
+                         class_probs = {label: float(probs[k]) for k, label in enumerate(label_encoders[col].classes_)}
+                         transaction_pred[col] = {
+                             "prediction": decoded_pred,
+                             "probabilities": class_probs
+                         }
+                     predictions.append({
+                         "transaction_id": row.get('Transaction_Id', f"transaction_{i}"),
+                         "predictions": transaction_pred
+                     })
+                 return BatchPredictionResponse(
+                     message="Batch prediction completed successfully",
+                     predictions=predictions
+                 )
+             finally:
+                 if os.path.exists(file_path):
+                     os.remove(file_path)
+         # Single prediction
+         elif request and request.transaction_data:
+             input_data = pd.DataFrame([request.transaction_data.dict()])
+             X = input_data[TEXT_COLUMN]
+             all_probabilities = model.predict_proba(X)
+             response = {}
+             for i, col in enumerate(LABEL_COLUMNS):
+                 probs = all_probabilities[i][0]
+                 pred = np.argmax(probs)
+                 decoded_pred = label_encoders[col].inverse_transform([pred])[0]
+                 class_probs = {label: float(probs[k]) for k, label in enumerate(label_encoders[col].classes_)}
+                 response[col] = {
+                     "prediction": decoded_pred,
+                     "probabilities": class_probs
+                 }
+             return response
+         else:
+             raise HTTPException(
+                 status_code=400,
+                 detail="Either provide a transaction in the request body or upload a CSV file"
+             )
+     except HTTPException:
+         raise
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.get("/v1/lgbm/download-model/{model_id}")
+ async def download_model(model_id: str):
+     model_path = MODEL_SAVE_DIR / f"{model_id}.pkl"
+     if not model_path.exists():
+         raise HTTPException(status_code=404, detail="Model not found")
+     return FileResponse(
+         path=model_path,
+         filename=f"lgbm_model_{model_id}.pkl",
+         media_type="application/octet-stream"
+     )
+
+ async def train_model_task(config: TrainingConfig, file_path: str, training_id: str):
+     try:
+         data_df_original, label_encoders = load_and_preprocess_data(file_path)
+         save_label_encoders(label_encoders)
+         X = data_df_original[TEXT_COLUMN]
+         y = data_df_original[LABEL_COLUMNS]
+         model = TfidfLightGBM(label_encoders)
+         model.train(X, y)
+         model.save_model(training_id)
+         training_status.update({
+             "is_training": False,
+             "end_time": datetime.now().isoformat(),
+             "status": "completed"
+         })
+     except Exception as e:
+         logger.error(f"Training failed: {str(e)}")
+         training_status.update({
+             "is_training": False,
+             "end_time": datetime.now().isoformat(),
+             "status": "failed",
+             "error": str(e)
+         })
+
+ if __name__ == "__main__":
+     port = int(os.environ.get("PORT", 7860))
+     uvicorn.run(app, host="0.0.0.0", port=port)
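A hedged client-side sketch against these routes (assumes the server on localhost:7860 and an illustrative CSV; note that /v1/lgbm/predict declares both a JSON body and an optional upload file, and FastAPI is strict about mixing JSON and multipart in one route, so the single-prediction branch may need its own endpoint in practice):

import requests  # client-side sketch; not part of this upload

BASE = "http://localhost:7860/v1/lgbm"

# Start training: config travels as a JSON string in a form field next to the CSV upload
with open("transactions.csv", "rb") as f:  # illustrative filename
    resp = requests.post(
        f"{BASE}/train",
        data={"config": '{"batch_size": 32, "random_state": 42}'},
        files={"file": ("transactions.csv", f, "text/csv")},
    )
print(resp.json())  # includes training_id and download_url

# Poll progress, then try a single prediction once status is "completed"
print(requests.get(f"{BASE}/training-status").json())
payload = {
    "transaction_data": {
        "Transaction_Id": "TX-001",  # illustrative values
        "Message": "Wire transfer flagged for review",
        "Sanction_Context": "Payment routed via a flagged intermediary",  # TEXT_COLUMN
    }
}
print(requests.post(f"{BASE}/predict", json=payload).json())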
config.py ADDED
@@ -0,0 +1,69 @@
+ # config.py
+
+ import torch
+ import os
+
+ # --- Paths ---
+ # Adjust DATA_PATH to your actual data location
+ DATA_PATH = './data/synthetic_transactions_samples_5000.csv'
+ TOKENIZER_PATH = './tokenizer/'
+ LABEL_ENCODERS_PATH = './label_encoders.pkl'
+ MODEL_SAVE_DIR = './saved_models/'
+ PREDICTIONS_SAVE_DIR = './predictions/'  # To save predictions for the voting ensemble
+
+ # --- Data Columns ---
+ TEXT_COLUMN = "Sanction_Context"
+ # Define all your target label columns
+ LABEL_COLUMNS = [
+     "Red_Flag_Reason",
+     "Maker_Action",
+     "Escalation_Level",
+     "Risk_Category",
+     "Risk_Drivers",
+     "Investigation_Outcome"
+ ]
+ # Example metadata columns. Add actual numerical/categorical metadata if available in your CSV.
+ # For now, it's an empty list. If you add metadata, ensure these columns exist and are numeric or can be encoded.
+ METADATA_COLUMNS = []  # e.g., ["Risk_Score", "Transaction_Amount"]
+
+ # --- Model Hyperparameters ---
+ MAX_LEN = 128          # Maximum sequence length for transformer tokenizers
+ BATCH_SIZE = 16        # Batch size for training and evaluation
+ LEARNING_RATE = 2e-5   # Learning rate for the AdamW optimizer
+ NUM_EPOCHS = 3         # Number of training epochs. Adjust based on convergence.
+ DROPOUT_RATE = 0.3     # Dropout rate for regularization
+
+ # --- Device Configuration ---
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # --- Specific Model Configurations ---
+ BERT_MODEL_NAME = 'bert-base-uncased'
+ ROBERTA_MODEL_NAME = 'roberta-base'
+ DEBERTA_MODEL_NAME = 'microsoft/deberta-base'
+
+ # TF-IDF
+ TFIDF_MAX_FEATURES = 5000  # Max features for the TF-IDF vectorizer
+
+ # --- Field-Specific Strategy (Conceptual) ---
+ # This dictionary provides conceptual strategies for enhancing specific fields.
+ # Actual implementation requires adapting the models (e.g., custom loss functions, metadata integration).
+ FIELD_STRATEGIES = {
+     "Maker_Action": {
+         "loss": "focal_loss",  # Requires a custom focal loss implementation
+         "enhancements": ["action_templates", "context_prompt_tuning"]  # Advanced NLP concepts
+     },
+     "Risk_Category": {
+         "enhancements": ["numerical_metadata", "transaction_patterns"]  # Integrate METADATA_COLUMNS
+     },
+     "Escalation_Level": {
+         "enhancements": ["class_balancing", "policy_keyword_patterns"]  # Handled by class weights/metadata
+     },
+     "Investigation_Outcome": {
+         "type": "classification_or_generation"  # If generation, T5/BART would be needed.
+     }
+ }
+
+ # Ensure model save and predictions directories exist
+ os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
+ os.makedirs(PREDICTIONS_SAVE_DIR, exist_ok=True)
+ os.makedirs(TOKENIZER_PATH, exist_ok=True)
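FIELD_STRATEGIES flags Maker_Action for a focal loss that this upload does not include. For reference, a minimal sketch of such a loss, written as a drop-in for torch.nn.CrossEntropyLoss (the gamma default is an assumption):

# Illustrative focal loss sketch; not the repo's implementation.
import torch
import torch.nn.functional as F

class FocalLoss(torch.nn.Module):
    def __init__(self, gamma: float = 2.0, weight: torch.Tensor = None):
        super().__init__()
        self.gamma = gamma    # focusing parameter; gamma=0 reduces to plain cross-entropy
        self.weight = weight  # optional per-class weights, as in CrossEntropyLoss

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Per-sample cross-entropy, then down-weight easy (high-confidence) examples
        ce = F.cross_entropy(logits, targets, weight=self.weight, reduction="none")
        pt = torch.exp(-ce)  # model's probability for the true class
        return ((1.0 - pt) ** self.gamma * ce).mean()

Swapping it into initialize_criterions in train_utils.py for the fields listed in FIELD_STRATEGIES would be the natural integration point.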
dataset_utils.py ADDED
@@ -0,0 +1,165 @@
+ # dataset_utils.py
+
+ import pandas as pd
+ import torch
+ from torch.utils.data import Dataset
+ from sklearn.preprocessing import LabelEncoder
+ from transformers import BertTokenizer, RobertaTokenizer, DebertaTokenizer
+ import pickle
+
+ from config import LABEL_COLUMNS, LABEL_ENCODERS_PATH, METADATA_COLUMNS
+
+ class ComplianceDataset(Dataset):
+     """
+     Custom Dataset class for handling text and multi-output labels for PyTorch models.
+     """
+     def __init__(self, texts, labels, tokenizer, max_len):
+         self.texts = texts
+         self.labels = labels
+         self.tokenizer = tokenizer
+         self.max_len = max_len
+
+     def __len__(self):
+         """Returns the total number of samples in the dataset."""
+         return len(self.texts)
+
+     def __getitem__(self, idx):
+         """
+         Retrieves a sample from the dataset at the given index.
+         Tokenizes the text and converts labels to a PyTorch tensor.
+         """
+         text = str(self.texts[idx])
+         # Tokenize the text, padding to max_length and truncating if longer.
+         # return_tensors="pt" ensures PyTorch tensors are returned.
+         inputs = self.tokenizer(
+             text,
+             padding='max_length',
+             truncation=True,
+             max_length=self.max_len,
+             return_tensors="pt"
+         )
+         # Squeeze removes the batch dimension (which is 1 here because we process one sample at a time)
+         inputs = {key: val.squeeze(0) for key, val in inputs.items()}
+         # Convert labels to a PyTorch long tensor
+         labels = torch.tensor(self.labels[idx], dtype=torch.long)
+         return inputs, labels
+
+ class ComplianceDatasetWithMetadata(Dataset):
+     """
+     Custom Dataset class for handling text, additional numerical metadata, and multi-output labels.
+     Used for hybrid models combining text and tabular features.
+     """
+     def __init__(self, texts, metadata, labels, tokenizer, max_len):
+         self.texts = texts
+         self.metadata = metadata  # Expects metadata as a NumPy array or list of lists
+         self.labels = labels
+         self.tokenizer = tokenizer
+         self.max_len = max_len
+
+     def __len__(self):
+         """Returns the total number of samples in the dataset."""
+         return len(self.texts)
+
+     def __getitem__(self, idx):
+         """
+         Retrieves a sample, its metadata, and labels from the dataset at the given index.
+         Tokenizes text, converts metadata and labels to PyTorch tensors.
+         """
+         text = str(self.texts[idx])
+         inputs = self.tokenizer(
+             text,
+             padding='max_length',
+             truncation=True,
+             max_length=self.max_len,
+             return_tensors="pt"
+         )
+         inputs = {key: val.squeeze(0) for key, val in inputs.items()}
+         # Convert metadata for the current sample to a float tensor
+         metadata = torch.tensor(self.metadata[idx], dtype=torch.float)
+         labels = torch.tensor(self.labels[idx], dtype=torch.long)
+         return inputs, metadata, labels
+
+ def load_and_preprocess_data(data_path):
+     """
+     Loads data from a CSV, fills missing values, and encodes categorical labels.
+     Also handles converting specified METADATA_COLUMNS to numeric.
+
+     Args:
+         data_path (str): Path to the CSV data file.
+
+     Returns:
+         tuple: A tuple containing:
+             - data (pd.DataFrame): The preprocessed DataFrame.
+             - label_encoders (dict): A dictionary of LabelEncoder objects for each label column.
+     """
+     data = pd.read_csv(data_path)
+     data.fillna("Unknown", inplace=True)  # Fill any missing text values with "Unknown"
+
+     # Convert metadata columns to numeric, coercing errors and filling NaNs with 0.
+     # This ensures metadata is suitable for neural networks.
+     for col in METADATA_COLUMNS:
+         if col in data.columns:
+             data[col] = pd.to_numeric(data[col], errors='coerce').fillna(0)
+
+     label_encoders = {col: LabelEncoder() for col in LABEL_COLUMNS}
+     for col in LABEL_COLUMNS:
+         # Fit and transform each label column using its respective LabelEncoder
+         data[col] = label_encoders[col].fit_transform(data[col])
+     return data, label_encoders
+
+ def get_tokenizer(model_name):
+     """
+     Returns the appropriate Hugging Face tokenizer based on the model name.
+
+     Args:
+         model_name (str): The name of the pre-trained model (e.g., 'bert-base-uncased').
+
+     Returns:
+         transformers.PreTrainedTokenizer: The initialized tokenizer.
+     """
+     # Check RoBERTa/DeBERTa first: both names contain the substring "bert",
+     # so testing "bert" first would misroute them to BertTokenizer.
+     if "roberta" in model_name.lower():
+         return RobertaTokenizer.from_pretrained(model_name)
+     elif "deberta" in model_name.lower():
+         return DebertaTokenizer.from_pretrained(model_name)
+     elif "bert" in model_name.lower():
+         return BertTokenizer.from_pretrained(model_name)
+     else:
+         raise ValueError(f"Unsupported tokenizer for model: {model_name}")
+
+ def save_label_encoders(label_encoders):
+     """
+     Saves a dictionary of label encoders to a pickle file.
+     This is crucial for decoding predictions back to original labels.
+
+     Args:
+         label_encoders (dict): Dictionary of LabelEncoder objects.
+     """
+     with open(LABEL_ENCODERS_PATH, "wb") as f:
+         pickle.dump(label_encoders, f)
+     print(f"Label encoders saved to {LABEL_ENCODERS_PATH}")
+
+ def load_label_encoders():
+     """
+     Loads a dictionary of label encoders from a pickle file.
+
+     Returns:
+         dict: Loaded dictionary of LabelEncoder objects.
+     """
+     with open(LABEL_ENCODERS_PATH, "rb") as f:
+         label_encoders = pickle.load(f)
+     print(f"Label encoders loaded from {LABEL_ENCODERS_PATH}")
+     return label_encoders
+
+ def get_num_labels(label_encoders):
+     """
+     Returns a list containing the number of unique classes for each label column.
+     This list is used to define the output dimensions of the model's classification heads.
+
+     Args:
+         label_encoders (dict): Dictionary of LabelEncoder objects.
+
+     Returns:
+         list: A list of integers, where each integer is the number of classes for a label.
+     """
+     return [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
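A minimal sketch of how these pieces compose (DATA_PATH and BERT_MODEL_NAME come from config.py; the printed shapes assume the default MAX_LEN of 128 and the six LABEL_COLUMNS):

# Illustrative wiring of dataset_utils; values come from config.py defaults.
from torch.utils.data import DataLoader
from dataset_utils import ComplianceDataset, load_and_preprocess_data, get_tokenizer
from config import TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, BATCH_SIZE, BERT_MODEL_NAME, DATA_PATH

data, label_encoders = load_and_preprocess_data(DATA_PATH)
dataset = ComplianceDataset(
    texts=data[TEXT_COLUMN].tolist(),
    labels=data[LABEL_COLUMNS].values,  # shape (n_samples, 6): one encoded id per label column
    tokenizer=get_tokenizer(BERT_MODEL_NAME),
    max_len=MAX_LEN,
)
inputs, labels = dataset[0]
print(inputs["input_ids"].shape, labels.shape)  # torch.Size([128]) torch.Size([6])
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)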
label_encoders.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c336fd07858af76d40c7200de1a769099abeec25d4f48b999351318680d4e4d6
+ size 2047
models/tfidf_lgbm.py ADDED
@@ -0,0 +1,163 @@
+ # models/tfidf_lgbm.py
+
+ from sklearn.feature_extraction.text import TfidfVectorizer
+ import lightgbm as lgb
+ from sklearn.pipeline import Pipeline
+ from sklearn.metrics import classification_report
+ import numpy as np
+ import joblib
+ import os
+
+ from config import LABEL_COLUMNS, TFIDF_MAX_FEATURES, MODEL_SAVE_DIR
+
+ class TfidfLightGBM:
+     """
+     TF-IDF based LightGBM model for multi-output classification.
+     It trains a separate LightGBM classifier for each target label
+     after converting text data into TF-IDF features.
+     """
+     def __init__(self, label_encoders):
+         """
+         Initializes the TfidfLightGBM model.
+
+         Args:
+             label_encoders (dict): A dictionary of LabelEncoder objects.
+         """
+         self.label_encoders = label_encoders
+         self.models = {}  # Stores the trained Pipeline for each label
+
+     def train(self, X_train_text, y_train_df):
+         """
+         Trains a TF-IDF + LightGBM pipeline for each label.
+
+         Args:
+             X_train_text (pd.Series): Training text data.
+             y_train_df (pd.DataFrame): DataFrame of training labels (encoded).
+         """
+         print("Training TF-IDF + LightGBM models...")
+         for col in LABEL_COLUMNS:
+             print(f"  Training for {col}...")
+             num_classes = len(self.label_encoders[col].classes_)
+             # Determine the LightGBM objective based on the number of classes
+             objective = 'multiclass' if num_classes > 2 else 'binary'
+             # The `num_class` parameter is required only for the 'multiclass' objective
+             num_class_param = {'num_class': num_classes} if num_classes > 2 else {}
+
+             pipeline = Pipeline([
+                 ('tfidf', TfidfVectorizer(max_features=TFIDF_MAX_FEATURES)),
+                 ('lgbm', lgb.LGBMClassifier(
+                     objective=objective,
+                     **num_class_param,  # Unpacked only when non-empty
+                     random_state=42,
+                     n_estimators=100
+                 ))
+             ])
+             # Fit the pipeline on the training data.
+             # LightGBM can address class imbalance with `is_unbalance=True` or
+             # `scale_pos_weight` for binary objectives; defaults are used here.
+             pipeline.fit(X_train_text, y_train_df[col])
+             self.models[col] = pipeline
+         print("TF-IDF + LightGBM training complete.")
+
+     def predict(self, X_test_text):
+         """
+         Makes class predictions for all labels.
+
+         Args:
+             X_test_text (pd.Series): Test text data.
+
+         Returns:
+             dict: A dictionary where keys are label names and values are NumPy arrays
+                   of predicted class indices.
+         """
+         predictions = {}
+         for col, model_pipeline in self.models.items():
+             predictions[col] = model_pipeline.predict(X_test_text)
+         return predictions
+
+     def predict_proba(self, X_test_text):
+         """
+         Returns prediction probabilities for each class for all labels.
+
+         Args:
+             X_test_text (pd.Series): Test text data.
+
+         Returns:
+             list: A list of NumPy arrays. Each array corresponds to a label column
+                   and contains the probability distribution over classes for each sample.
+         """
+         probabilities = []
+         for col in LABEL_COLUMNS:
+             if col in self.models:
+                 probabilities.append(self.models[col].predict_proba(X_test_text))
+             else:
+                 print(f"Warning: Model for {col} not found, cannot predict probabilities.")
+                 probabilities.append(np.array([]))
+         return probabilities
+
+     def evaluate(self, X_test_text, y_test_df):
+         """
+         Evaluates the models and returns classification reports.
+
+         Args:
+             X_test_text (pd.Series): Test text data.
+             y_test_df (pd.DataFrame): DataFrame of true test labels (encoded).
+
+         Returns:
+             tuple: A tuple containing:
+                 - reports (dict): Classification reports for each label column.
+                 - truths (list): List of true label arrays.
+                 - preds (list): List of predicted label arrays.
+         """
+         reports = {}
+         truths = [[] for _ in range(len(LABEL_COLUMNS))]
+         preds = [[] for _ in range(len(LABEL_COLUMNS))]
+
+         for i, col in enumerate(LABEL_COLUMNS):
+             if col in self.models:
+                 y_pred = self.models[col].predict(X_test_text)
+                 y_true = y_test_df[col].values
+                 try:
+                     report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
+                     reports[col] = report
+                 except ValueError:
+                     print(f"Warning: Could not generate classification report for {col}. Skipping.")
+                     reports[col] = {'accuracy': 0, 'weighted avg': {'precision': 0, 'recall': 0, 'f1-score': 0, 'support': 0}}
+                 truths[i].extend(y_true)
+                 preds[i].extend(y_pred)
+             else:
+                 print(f"Warning: Model for {col} not found for evaluation.")
+
+         return reports, truths, preds
+
+     def save_model(self, model_name="tfidf_lgbm", save_format='pickle'):
+         """
+         Saves the trained TF-IDF LightGBM models.
+
+         Args:
+             model_name (str): The base name for the saved model file.
+             save_format (str): Format to save the model in (default: 'pickle').
+         """
+         if save_format != 'pickle':
+             raise ValueError("TF-IDF models only support 'pickle' format")
+         save_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}.pkl")
+         joblib.dump(self.models, save_path)
+         print(f"TF-IDF LightGBM models saved to {save_path}")
+
+     def load_model(self, model_name="tfidf_lgbm"):
+         """
+         Loads trained TF-IDF LightGBM models from a file.
+
+         Args:
+             model_name (str): The base name of the model file to load.
+         """
+         load_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}.pkl")
+         if os.path.exists(load_path):
+             self.models = joblib.load(load_path)
+             print(f"TF-IDF LightGBM models loaded from {load_path}")
+         else:
+             print(f"Error: Model file not found at {load_path}. Initializing models as empty.")
+             self.models = {}
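A hedged end-to-end sketch of this class; the import path assumes the models/ location from this commit, and the train/test split is illustrative:

# Illustrative usage of TfidfLightGBM; split parameters are assumptions.
from sklearn.model_selection import train_test_split
from dataset_utils import load_and_preprocess_data, save_label_encoders
from models.tfidf_lgbm import TfidfLightGBM
from config import TEXT_COLUMN, LABEL_COLUMNS, DATA_PATH

data, label_encoders = load_and_preprocess_data(DATA_PATH)
save_label_encoders(label_encoders)  # so the API can decode predictions later
X_train, X_test, y_train, y_test = train_test_split(
    data[TEXT_COLUMN], data[LABEL_COLUMNS], test_size=0.2, random_state=42
)
model = TfidfLightGBM(label_encoders)
model.train(X_train, y_train)
reports, truths, preds = model.evaluate(X_test, y_test)
model.save_model("lgbm_models")  # written to saved_models/lgbm_models.pkl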
saved_models/lgbm_models.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:42a64b8b27153f5e28c27a9d5012170bb71a6ce4f19ce10e00fddf59c947ee4d
+ size 16550415
saved_models/tfidf_vectorizer.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a25009cce24ff90ce9329296fb1655a4d837ff6cfd5c17d73fd256b88a58d399
+ size 3724098
train_utils.py ADDED
@@ -0,0 +1,310 @@
+ # train_utils.py
+
+ import torch
+ from sklearn.metrics import classification_report
+ import numpy as np
+ from tqdm import tqdm
+ import pandas as pd
+ import os
+ import joblib
+
+ from config import DEVICE, LABEL_COLUMNS, MODEL_SAVE_DIR
+
+ def get_class_weights(data_df, field, label_encoder):
+     """
+     Computes balanced class weights for a given target field.
+     These weights can be used in the loss function to mitigate class imbalance.
+
+     Args:
+         data_df (pd.DataFrame): The DataFrame containing the original (unencoded) label data.
+         field (str): The name of the label column for which to compute weights.
+         label_encoder (sklearn.preprocessing.LabelEncoder): The label encoder fitted for this field.
+
+     Returns:
+         torch.Tensor: A tensor of class weights for the specified field.
+     """
+     # Get the original labels for the specified field
+     y = data_df[field].values
+     # transform raises ValueError on labels the encoder has never seen;
+     # fall back to filtering those labels out before computing weights.
+     try:
+         y_encoded = label_encoder.transform(y)
+     except ValueError as e:
+         print(f"Warning: {e}")
+         print("Using only seen labels for class weights calculation")
+         # Filter out unseen labels
+         seen_labels = set(label_encoder.classes_)
+         y_filtered = [label for label in y if label in seen_labels]
+         y_encoded = label_encoder.transform(y_filtered)
+
+     # Ensure y_encoded is integer type
+     y_encoded = y_encoded.astype(int)
+
+     # Initialize counts for all possible classes
+     n_classes = len(label_encoder.classes_)
+     class_counts = np.zeros(n_classes, dtype=int)
+
+     # Count occurrences of each class
+     for i in range(n_classes):
+         class_counts[i] = np.sum(y_encoded == i)
+
+     # Calculate weights for all classes: total / (n_seen_classes * count),
+     # the usual "balanced" heuristic restricted to classes that actually occur
+     total_samples = len(y_encoded)
+     class_weights = np.ones(n_classes)  # Default weight of 1 for unseen classes
+     seen_classes = class_counts > 0
+     if np.any(seen_classes):
+         class_weights[seen_classes] = total_samples / (np.sum(seen_classes) * class_counts[seen_classes])
+
+     return torch.tensor(class_weights, dtype=torch.float)
+
+ def initialize_criterions(data_df, label_encoders):
+     """
+     Initializes CrossEntropyLoss criteria for each label column, applying class weights.
+
+     Args:
+         data_df (pd.DataFrame): The original (unencoded) DataFrame, used to compute class weights.
+         label_encoders (dict): Dictionary of LabelEncoder objects.
+
+     Returns:
+         dict: A dictionary where keys are label column names and values are
+               initialized `torch.nn.CrossEntropyLoss` objects.
+     """
+     field_criterions = {}
+     for field in LABEL_COLUMNS:
+         # Get class weights for the current field
+         weights = get_class_weights(data_df, field, label_encoders[field])
+         # Initialize CrossEntropyLoss with the computed weights and move to the device
+         field_criterions[field] = torch.nn.CrossEntropyLoss(weight=weights.to(DEVICE))
+     return field_criterions
+
+ def train_model(model, loader, optimizer, field_criterions, epoch):
+     """
+     Trains the given PyTorch model for one epoch.
+
+     Args:
+         model (torch.nn.Module): The model to train.
+         loader (torch.utils.data.DataLoader): DataLoader for training data.
+         optimizer (torch.optim.Optimizer): Optimizer for model parameters.
+         field_criterions (dict): Dictionary of loss functions for each label.
+         epoch (int): Current epoch number (for the progress bar description).
+
+     Returns:
+         float: Average training loss for the epoch.
+     """
+     model.train()  # Set the model to training mode
+     total_loss = 0
+     # Use tqdm for a progress bar during training
+     tqdm_loader = tqdm(loader, desc=f"Epoch {epoch + 1} Training")
+
+     for batch in tqdm_loader:
+         # Unpack batch based on whether it contains metadata
+         if len(batch) == 2:  # Text-only models (inputs, labels)
+             inputs, labels = batch
+             input_ids = inputs['input_ids'].to(DEVICE)
+             attention_mask = inputs['attention_mask'].to(DEVICE)
+             labels = labels.to(DEVICE)
+             # Forward pass through the model
+             outputs = model(input_ids, attention_mask)
+         elif len(batch) == 3:  # Text + metadata models (inputs, metadata, labels)
+             inputs, metadata, labels = batch
+             input_ids = inputs['input_ids'].to(DEVICE)
+             attention_mask = inputs['attention_mask'].to(DEVICE)
+             metadata = metadata.to(DEVICE)
+             labels = labels.to(DEVICE)
+             # Forward pass through the hybrid model
+             outputs = model(input_ids, attention_mask, metadata)
+         else:
+             raise ValueError("Unsupported batch format. Expected 2 or 3 items in batch.")
+
+         loss = 0
+         # Calculate the total loss by summing the loss for each label column.
+         # `outputs` is a list of logits, one per label column.
+         for i, output_logits in enumerate(outputs):
+             # `labels[:, i]` gets the true labels for the i-th label column;
+             # `field_criterions[LABEL_COLUMNS[i]]` selects the matching loss function.
+             loss += field_criterions[LABEL_COLUMNS[i]](output_logits, labels[:, i])
+
+         optimizer.zero_grad()  # Clear previous gradients
+         loss.backward()        # Backpropagation
+         optimizer.step()       # Update model parameters
+         total_loss += loss.item()                  # Accumulate loss
+         tqdm_loader.set_postfix(loss=loss.item())  # Show current batch loss
+
+     return total_loss / len(loader)  # Average loss for the epoch
+
+ def evaluate_model(model, loader):
+     """
+     Evaluates the given PyTorch model on a validation/test set.
+
+     Args:
+         model (torch.nn.Module): The model to evaluate.
+         loader (torch.utils.data.DataLoader): DataLoader for evaluation data.
+
+     Returns:
+         tuple: A tuple containing:
+             - reports (dict): Classification reports (dict format) for each label column.
+             - truths (list): List of true label arrays for each label column.
+             - predictions (list): List of predicted label arrays for each label column.
+     """
+     model.eval()  # Evaluation mode (disables dropout, batch norm updates, etc.)
+     # Initialize lists to store predictions and true labels for each output head
+     predictions = [[] for _ in range(len(LABEL_COLUMNS))]
+     truths = [[] for _ in range(len(LABEL_COLUMNS))]
+
+     with torch.no_grad():  # Disable gradient calculations during evaluation for efficiency
+         for batch in tqdm(loader, desc="Evaluation"):
+             if len(batch) == 2:
+                 inputs, labels = batch
+                 input_ids = inputs['input_ids'].to(DEVICE)
+                 attention_mask = inputs['attention_mask'].to(DEVICE)
+                 labels = labels.to(DEVICE)
+                 outputs = model(input_ids, attention_mask)
+             elif len(batch) == 3:
+                 inputs, metadata, labels = batch
+                 input_ids = inputs['input_ids'].to(DEVICE)
+                 attention_mask = inputs['attention_mask'].to(DEVICE)
+                 metadata = metadata.to(DEVICE)
+                 labels = labels.to(DEVICE)
+                 outputs = model(input_ids, attention_mask, metadata)
+             else:
+                 raise ValueError("Unsupported batch format.")
+
+             for i, output_logits in enumerate(outputs):
+                 # The predicted class is the argmax of the logits
+                 preds = torch.argmax(output_logits, dim=1).cpu().numpy()
+                 predictions[i].extend(preds)
+                 # True labels for the current output head
+                 truths[i].extend(labels[:, i].cpu().numpy())
+
+     reports = {}
+     # Generate a classification report for each label column
+     for i, col in enumerate(LABEL_COLUMNS):
+         try:
+             # `zero_division=0` handles classes with no true or predicted samples
+             reports[col] = classification_report(truths[i], predictions[i], output_dict=True, zero_division=0)
+         except ValueError:
+             # A label might not appear in the validation set at all,
+             # which can cause classification_report to fail.
+             print(f"Warning: Could not generate classification report for {col}. Skipping.")
+             reports[col] = {'accuracy': 0, 'weighted avg': {'precision': 0, 'recall': 0, 'f1-score': 0, 'support': 0}}
+     return reports, truths, predictions
+
+ def summarize_metrics(metrics):
+     """
+     Summarizes classification reports into a readable Pandas DataFrame.
+
+     Args:
+         metrics (dict): Dictionary of classification reports, as returned by `evaluate_model`.
+
+     Returns:
+         pd.DataFrame: A DataFrame summarizing precision, recall, f1-score, accuracy, and support per field.
+     """
+     summary = []
+     for field, report in metrics.items():
+         # Safely get metrics, defaulting to 0 if not present (e.g., for empty reports)
+         precision = report['weighted avg']['precision'] if 'weighted avg' in report else 0
+         recall = report['weighted avg']['recall'] if 'weighted avg' in report else 0
+         f1 = report['weighted avg']['f1-score'] if 'weighted avg' in report else 0
+         support = report['weighted avg']['support'] if 'weighted avg' in report else 0
+         accuracy = report['accuracy'] if 'accuracy' in report else 0  # Accuracy is top-level
+         summary.append({
+             "Field": field,
+             "Precision": precision,
+             "Recall": recall,
+             "F1-Score": f1,
+             "Accuracy": accuracy,
+             "Support": support
+         })
+     return pd.DataFrame(summary)
+
+ def save_model(model, model_name, save_format='pth'):
+     """
+     Saves the state dictionary of a PyTorch model (or pickles a traditional ML model).
+
+     Args:
+         model (torch.nn.Module): The trained model.
+         model_name (str): A descriptive name for the model (used for the filename).
+         save_format (str): 'pth' for PyTorch state dicts, 'pickle' for traditional ML models.
+     """
+     # Construct the save path under MODEL_SAVE_DIR
+     if save_format == 'pth':
+         model_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}_model.pth")
+         torch.save(model.state_dict(), model_path)
+     elif save_format == 'pickle':
+         model_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}.pkl")
+         joblib.dump(model, model_path)
+     else:
+         raise ValueError(f"Unsupported save format: {save_format}")
+
+     print(f"Model saved to {model_path}")
+
+ def load_model_state(model, model_name, model_class, num_labels, metadata_dim=0):
+     """
+     Loads a saved state dictionary into a PyTorch model.
+
+     Args:
+         model (torch.nn.Module): An initialized model instance (architecture).
+         model_name (str): The name of the model to load.
+         model_class (class): The class of the model (e.g., BertMultiOutputModel).
+         num_labels (list): Number of classes for each label.
+         metadata_dim (int): Dimensionality of metadata features, if applicable (0 for text-only).
+
+     Returns:
+         torch.nn.Module: The model with the loaded state_dict, moved to the device, in eval mode.
+     """
+     model_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}_model.pth")
+     if not os.path.exists(model_path):
+         print(f"Warning: Model file not found at {model_path}. Returning a newly initialized model instance.")
+         # Re-initialize the model so it at least has the correct architecture
+         if metadata_dim > 0:
+             return model_class(num_labels, metadata_dim=metadata_dim).to(DEVICE)
+         else:
+             return model_class(num_labels).to(DEVICE)
+
+     model.load_state_dict(torch.load(model_path, map_location=DEVICE))
+     model.to(DEVICE)
+     model.eval()  # Set to evaluation mode after loading
+     print(f"Model loaded from {model_path}")
+     return model
+
+ def predict_probabilities(model, loader):
+     """
+     Generates prediction probabilities for each label for a given model.
+     Used for confidence scoring and for feeding a voting ensemble.
+
+     Args:
+         model (torch.nn.Module): The trained PyTorch model.
+         loader (torch.utils.data.DataLoader): DataLoader for the data to predict on.
+
+     Returns:
+         list: A list of lists of NumPy arrays. Each inner list corresponds to a label column,
+               containing the softmax probabilities for each sample for that label.
+     """
+     model.eval()  # Evaluation mode
+     # One list of probabilities per output head
+     all_probabilities = [[] for _ in range(len(LABEL_COLUMNS))]
+
+     with torch.no_grad():
+         for batch in tqdm(loader, desc="Predicting Probabilities"):
+             # Unpack the batch, ignoring labels since only inputs are needed
+             if len(batch) == 2:
+                 inputs, _ = batch
+                 input_ids = inputs['input_ids'].to(DEVICE)
+                 attention_mask = inputs['attention_mask'].to(DEVICE)
+                 outputs = model(input_ids, attention_mask)
+             elif len(batch) == 3:
+                 inputs, metadata, _ = batch
+                 input_ids = inputs['input_ids'].to(DEVICE)
+                 attention_mask = inputs['attention_mask'].to(DEVICE)
+                 metadata = metadata.to(DEVICE)
+                 outputs = model(input_ids, attention_mask, metadata)
+             else:
+                 raise ValueError("Unsupported batch format.")
+
+             for i, out_logits in enumerate(outputs):
+                 # Softmax over logits yields per-class probabilities
+                 probs = torch.softmax(out_logits, dim=1).cpu().numpy()
+                 all_probabilities[i].extend(probs)
+     return all_probabilities
+ return all_probabilities