namanpenguin committed
Commit b44c3a2 · verified · 1 Parent(s): b7c3869

Upload 8 files

Files changed (8)
  1. Dockerfile +43 -0
  2. app.py +149 -0
  3. config.py +69 -0
  4. dataset_utils.py +165 -0
  5. label_encoders.pkl +3 -0
  6. requirements.txt +11 -0
  7. train_utils.py +310 -0
  8. voting.py +152 -0
Dockerfile ADDED
@@ -0,0 +1,43 @@
+ # Use Python 3.9 as base image
+ FROM python:3.9-slim
+
+ # Set working directory
+ WORKDIR /app
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     build-essential \
+     curl \
+     software-properties-common \
+     git \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy requirements file
+ COPY requirements.txt .
+
+ # Install Python dependencies
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Create necessary directories with proper permissions
+ RUN mkdir -p /app/uploads \
+     /app/predictions \
+     && chmod -R 777 /app/uploads \
+     /app/predictions
+
+ # Copy the application code and utilities (voting.py, config.py, dataset_utils.py and
+ # label_encoders.pkl are included in this copy; Docker cannot COPY paths such as
+ # ../voting.py that lie outside the build context)
+ COPY . /app/
+
+ # Set environment variables
+ ENV PYTHONPATH=/app
+ ENV PYTHONUNBUFFERED=1
+ ENV PORT=7861
+
+ # Expose the port the app runs on
+ EXPOSE 7861
+
+ # Command to run the application
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,149 @@
+ from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File
+ from fastapi.responses import FileResponse
+ from pydantic import BaseModel
+ from typing import Optional, Dict, Any, List
+ import uvicorn
+ import torch
+ import logging
+ import os
+ import asyncio
+ import pandas as pd
+ from datetime import datetime
+ import shutil
+ from pathlib import Path
+ import numpy as np
+ import sys
+
+ # Add the directory two levels above this file's location to the Python path
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+
+ from voting import perform_voting_ensemble, save_predictions
+ from config import LABEL_COLUMNS, PREDICTIONS_SAVE_DIR
+ from dataset_utils import load_label_encoders
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ app = FastAPI(title="Ensemble Voting API")
+
+ # Create necessary directories
+ UPLOAD_DIR = Path("uploads")
+ PREDICTIONS_DIR = Path(PREDICTIONS_SAVE_DIR)
+ UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
+ PREDICTIONS_DIR.mkdir(parents=True, exist_ok=True)
+
+ class EnsembleConfig(BaseModel):
+     model_names: List[str]
+     weights: Optional[Dict[str, float]] = None
+
+ class EnsembleResponse(BaseModel):
+     message: str
+     metrics: Dict[str, Any]
+     predictions: List[Dict[str, Any]]
+
+ class PredictionData(BaseModel):
+     model_name: str
+     probabilities: List[List[float]]
+     true_labels: Optional[List[int]] = None
+
+ @app.get("/")
+ async def root():
+     return {"message": "Ensemble Voting API"}
+
+ @app.get("/health")
+ async def health_check():
+     return {"status": "healthy"}
+
+ @app.post("/ensemble/vote")
+ async def perform_ensemble(
+     config: EnsembleConfig
+ ):
+     """Perform ensemble voting using the specified models."""
+     try:
+         # Perform ensemble voting
+         ensemble_reports, true_labels, ensemble_predictions = perform_voting_ensemble(config.model_names)
+
+         # Load label encoders for decoding predictions
+         label_encoders = load_label_encoders()
+
+         # Format predictions with original labels
+         formatted_predictions = []
+         for i, (col, preds) in enumerate(zip(LABEL_COLUMNS, ensemble_predictions)):
+             if true_labels[i] is not None:
+                 label_encoder = label_encoders[col]
+                 true_labels_orig = label_encoder.inverse_transform(true_labels[i])
+                 pred_labels_orig = label_encoder.inverse_transform(preds)
+
+                 for true, pred in zip(true_labels_orig, pred_labels_orig):
+                     formatted_predictions.append({
+                         "field": col,
+                         "true_label": true,
+                         "predicted_label": pred
+                     })
+
+         return EnsembleResponse(
+             message="Ensemble voting completed successfully",
+             metrics=ensemble_reports,
+             predictions=formatted_predictions
+         )
+
+     except Exception as e:
+         logger.error(f"Ensemble voting failed: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Ensemble voting failed: {str(e)}")
+
+ @app.post("/ensemble/save-predictions")
+ async def save_model_predictions(
+     prediction_data: PredictionData
+ ):
+     """Save predictions from a model for later ensemble voting."""
+     try:
+         # Convert probabilities to numpy arrays
+         all_probs = [np.array(probs) for probs in prediction_data.probabilities]
+         true_labels = [np.array(prediction_data.true_labels) if prediction_data.true_labels else None]
+
+         # Save predictions
+         save_predictions(
+             prediction_data.model_name,
+             all_probs,
+             true_labels
+         )
+
+         return {
+             "message": f"Predictions saved successfully for model {prediction_data.model_name}",
+             "model_name": prediction_data.model_name
+         }
+
+     except Exception as e:
+         logger.error(f"Failed to save predictions: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Failed to save predictions: {str(e)}")
+
+ @app.get("/ensemble/available-models")
+ async def get_available_models():
+     """Get the list of models with saved predictions."""
+     try:
+         model_dirs = [d for d in os.listdir(PREDICTIONS_DIR)
+                       if os.path.isdir(os.path.join(PREDICTIONS_DIR, d))]
+
+         available_models = []
+         for model_name in model_dirs:
+             model_dir = os.path.join(PREDICTIONS_DIR, model_name)
+             has_all_files = all(
+                 os.path.exists(os.path.join(model_dir, f"{col}_probs.pkl"))
+                 for col in LABEL_COLUMNS
+             )
+             if has_all_files:
+                 available_models.append(model_name)
+
+         return {
+             "available_models": available_models,
+             "total_models": len(available_models)
+         }
+
+     except Exception as e:
+         logger.error(f"Failed to get available models: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Failed to get available models: {str(e)}")
+
+ if __name__ == "__main__":
+     port = int(os.environ.get("PORT", 7861))
+     uvicorn.run(app, host="0.0.0.0", port=port)
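For reference, a minimal client sketch against this API (not part of the uploaded files): it assumes the service is running locally on port 7861 and that the `requests` package is available; the model names it submits are simply whatever directories already exist under ./predictions/.

```python
import requests

BASE_URL = "http://localhost:7861"  # assumed local deployment; adjust host/port as needed

# List the models that have saved probability files for every label column
available = requests.get(f"{BASE_URL}/ensemble/available-models").json()
print(available)  # {"available_models": [...], "total_models": N}

# Trigger soft voting over all available models
resp = requests.post(
    f"{BASE_URL}/ensemble/vote",
    json={"model_names": available["available_models"]},
)
resp.raise_for_status()
result = resp.json()
print(result["message"])
print(list(result["metrics"].keys()))  # one classification report per label column
```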
config.py ADDED
@@ -0,0 +1,69 @@
+ # config.py
+
+ import torch
+ import os
+
+ # --- Paths ---
+ # Adjust DATA_PATH to your actual data location
+ DATA_PATH = './data/synthetic_transactions_samples_5000.csv'
+ TOKENIZER_PATH = './tokenizer/'
+ LABEL_ENCODERS_PATH = './label_encoders.pkl'
+ MODEL_SAVE_DIR = './saved_models/'
+ PREDICTIONS_SAVE_DIR = './predictions/'  # To save predictions for voting ensemble
+
+ # --- Data Columns ---
+ TEXT_COLUMN = "Sanction_Context"
+ # Define all your target label columns
+ LABEL_COLUMNS = [
+     "Red_Flag_Reason",
+     "Maker_Action",
+     "Escalation_Level",
+     "Risk_Category",
+     "Risk_Drivers",
+     "Investigation_Outcome"
+ ]
+ # Example metadata columns. Add actual numerical/categorical metadata if available in your CSV.
+ # For now, it's an empty list. If you add metadata, ensure these columns exist and are numeric or can be encoded.
+ METADATA_COLUMNS = []  # e.g., ["Risk_Score", "Transaction_Amount"]
+
+ # --- Model Hyperparameters ---
+ MAX_LEN = 128         # Maximum sequence length for transformer tokenizers
+ BATCH_SIZE = 16       # Batch size for training and evaluation
+ LEARNING_RATE = 2e-5  # Learning rate for AdamW optimizer
+ NUM_EPOCHS = 3        # Number of training epochs. Adjust based on convergence.
+ DROPOUT_RATE = 0.3    # Dropout rate for regularization
+
+ # --- Device Configuration ---
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # --- Specific Model Configurations ---
+ BERT_MODEL_NAME = 'bert-base-uncased'
+ ROBERTA_MODEL_NAME = 'roberta-base'
+ DEBERTA_MODEL_NAME = 'microsoft/deberta-base'
+
+ # TF-IDF
+ TFIDF_MAX_FEATURES = 5000  # Max features for TF-IDF vectorizer
+
+ # --- Field-Specific Strategy (Conceptual) ---
+ # This dictionary provides conceptual strategies for enhancing specific fields.
+ # Actual implementation requires adapting the models (e.g., custom loss functions, metadata integration).
+ FIELD_STRATEGIES = {
+     "Maker_Action": {
+         "loss": "focal_loss",  # Requires custom Focal Loss implementation
+         "enhancements": ["action_templates", "context_prompt_tuning"]  # Advanced NLP concepts
+     },
+     "Risk_Category": {
+         "enhancements": ["numerical_metadata", "transaction_patterns"]  # Integrate METADATA_COLUMNS
+     },
+     "Escalation_Level": {
+         "enhancements": ["class_balancing", "policy_keyword_patterns"]  # Handled by class weights/metadata
+     },
+     "Investigation_Outcome": {
+         "type": "classification_or_generation"  # If generation, T5/BART would be needed.
+     }
+ }
+
+ # Ensure model save and predictions directories exist
+ os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
+ os.makedirs(PREDICTIONS_SAVE_DIR, exist_ok=True)
+ os.makedirs(TOKENIZER_PATH, exist_ok=True)
dataset_utils.py ADDED
@@ -0,0 +1,165 @@
+ # dataset_utils.py
+
+ import pandas as pd
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ from sklearn.preprocessing import LabelEncoder
+ from transformers import BertTokenizer, RobertaTokenizer, DebertaTokenizer
+ import pickle
+ import os
+
+ from config import TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, TOKENIZER_PATH, LABEL_ENCODERS_PATH, METADATA_COLUMNS
+
+ class ComplianceDataset(Dataset):
+     """
+     Custom Dataset class for handling text and multi-output labels for PyTorch models.
+     """
+     def __init__(self, texts, labels, tokenizer, max_len):
+         self.texts = texts
+         self.labels = labels
+         self.tokenizer = tokenizer
+         self.max_len = max_len
+
+     def __len__(self):
+         """Returns the total number of samples in the dataset."""
+         return len(self.texts)
+
+     def __getitem__(self, idx):
+         """
+         Retrieves a sample from the dataset at the given index.
+         Tokenizes the text and converts labels to a PyTorch tensor.
+         """
+         text = str(self.texts[idx])
+         # Tokenize the text, padding to max_length and truncating if longer.
+         # return_tensors="pt" ensures PyTorch tensors are returned.
+         inputs = self.tokenizer(
+             text,
+             padding='max_length',
+             truncation=True,
+             max_length=self.max_len,
+             return_tensors="pt"
+         )
+         # Squeeze removes the batch dimension (which is 1 here because we process one sample at a time)
+         inputs = {key: val.squeeze(0) for key, val in inputs.items()}
+         # Convert labels to a PyTorch long tensor
+         labels = torch.tensor(self.labels[idx], dtype=torch.long)
+         return inputs, labels
+
+ class ComplianceDatasetWithMetadata(Dataset):
+     """
+     Custom Dataset class for handling text, additional numerical metadata, and multi-output labels.
+     Used for hybrid models combining text and tabular features.
+     """
+     def __init__(self, texts, metadata, labels, tokenizer, max_len):
+         self.texts = texts
+         self.metadata = metadata  # Expects metadata as a NumPy array or list of lists
+         self.labels = labels
+         self.tokenizer = tokenizer
+         self.max_len = max_len
+
+     def __len__(self):
+         """Returns the total number of samples in the dataset."""
+         return len(self.texts)
+
+     def __getitem__(self, idx):
+         """
+         Retrieves a sample, its metadata, and labels from the dataset at the given index.
+         Tokenizes text, converts metadata and labels to PyTorch tensors.
+         """
+         text = str(self.texts[idx])
+         inputs = self.tokenizer(
+             text,
+             padding='max_length',
+             truncation=True,
+             max_length=self.max_len,
+             return_tensors="pt"
+         )
+         inputs = {key: val.squeeze(0) for key, val in inputs.items()}
+         # Convert metadata for the current sample to a float tensor
+         metadata = torch.tensor(self.metadata[idx], dtype=torch.float)
+         labels = torch.tensor(self.labels[idx], dtype=torch.long)
+         return inputs, metadata, labels
+
+ def load_and_preprocess_data(data_path):
+     """
+     Loads data from a CSV, fills missing values, and encodes categorical labels.
+     Also handles converting specified METADATA_COLUMNS to numeric.
+
+     Args:
+         data_path (str): Path to the CSV data file.
+
+     Returns:
+         tuple: A tuple containing:
+             - data (pd.DataFrame): The preprocessed DataFrame.
+             - label_encoders (dict): A dictionary of LabelEncoder objects for each label column.
+     """
+     data = pd.read_csv(data_path)
+     data.fillna("Unknown", inplace=True)  # Fill any missing text values with "Unknown"
+
+     # Convert metadata columns to numeric, coercing errors and filling NaNs with 0.
+     # This ensures metadata is suitable for neural networks.
+     for col in METADATA_COLUMNS:
+         if col in data.columns:
+             data[col] = pd.to_numeric(data[col], errors='coerce').fillna(0)  # Fill NaN with 0 or a suitable value
+
+     label_encoders = {col: LabelEncoder() for col in LABEL_COLUMNS}
+     for col in LABEL_COLUMNS:
+         # Fit and transform each label column using its respective LabelEncoder
+         data[col] = label_encoders[col].fit_transform(data[col])
+     return data, label_encoders
+
+ def get_tokenizer(model_name):
+     """
+     Returns the appropriate Hugging Face tokenizer based on the model name.
+
+     Args:
+         model_name (str): The name of the pre-trained model (e.g., 'bert-base-uncased').
+
+     Returns:
+         transformers.PreTrainedTokenizer: The initialized tokenizer.
+     """
+     # Check the more specific names first: "roberta" and "deberta" both contain the
+     # substring "bert", so testing "bert" first would mis-route them to BertTokenizer.
+     if "deberta" in model_name.lower():
+         return DebertaTokenizer.from_pretrained(model_name)
+     elif "roberta" in model_name.lower():
+         return RobertaTokenizer.from_pretrained(model_name)
+     elif "bert" in model_name.lower():
+         return BertTokenizer.from_pretrained(model_name)
+     else:
+         raise ValueError(f"Unsupported tokenizer for model: {model_name}")
+
+ def save_label_encoders(label_encoders):
+     """
+     Saves a dictionary of label encoders to a pickle file.
+     This is crucial for decoding predictions back to original labels.
+
+     Args:
+         label_encoders (dict): Dictionary of LabelEncoder objects.
+     """
+     with open(LABEL_ENCODERS_PATH, "wb") as f:
+         pickle.dump(label_encoders, f)
+     print(f"Label encoders saved to {LABEL_ENCODERS_PATH}")
+
+ def load_label_encoders():
+     """
+     Loads a dictionary of label encoders from a pickle file.
+
+     Returns:
+         dict: Loaded dictionary of LabelEncoder objects.
+     """
+     with open(LABEL_ENCODERS_PATH, "rb") as f:
+         label_encoders = pickle.load(f)
+     print(f"Label encoders loaded from {LABEL_ENCODERS_PATH}")
+     return label_encoders
+
+ def get_num_labels(label_encoders):
+     """
+     Returns a list containing the number of unique classes for each label column.
+     This list is used to define the output dimensions of the model's classification heads.
+
+     Args:
+         label_encoders (dict): Dictionary of LabelEncoder objects.
+
+     Returns:
+         list: A list of integers, where each integer is the number of classes for a label.
+     """
+     return [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
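A minimal wiring sketch (illustrative, not part of the upload) showing how these utilities combine with the constants from config.py; it assumes the CSV at DATA_PATH exists and that a downstream multi-output model consumes the resulting batches.

```python
from torch.utils.data import DataLoader

from config import DATA_PATH, BERT_MODEL_NAME, MAX_LEN, BATCH_SIZE, TEXT_COLUMN, LABEL_COLUMNS
from dataset_utils import (
    ComplianceDataset, get_num_labels, get_tokenizer,
    load_and_preprocess_data, save_label_encoders,
)

# Load the CSV, encode every label column, and persist the encoders for later decoding/voting
data, label_encoders = load_and_preprocess_data(DATA_PATH)
save_label_encoders(label_encoders)

tokenizer = get_tokenizer(BERT_MODEL_NAME)
texts = data[TEXT_COLUMN].tolist()
labels = data[LABEL_COLUMNS].values  # shape (num_samples, num_label_columns), integer-encoded

dataset = ComplianceDataset(texts, labels, tokenizer, MAX_LEN)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

num_labels = get_num_labels(label_encoders)  # class counts, one per label column (sizes the model heads)
inputs, batch_labels = next(iter(loader))    # inputs: dict of tensors, batch_labels: (BATCH_SIZE, num_fields)
```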
label_encoders.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c336fd07858af76d40c7200de1a769099abeec25d4f48b999351318680d4e4d6
+ size 2047
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ fastapi==0.104.1
+ uvicorn==0.24.0
+ pydantic==2.4.2
+ numpy==1.24.3
+ pandas==2.1.2
+ scikit-learn==1.3.2
+ python-multipart==0.0.6
+ python-jose==3.3.0
+ passlib==1.7.4
+ bcrypt==4.0.1
+ python-dotenv==1.0.0
+ # Also imported by config.py, dataset_utils.py and train_utils.py (left unpinned here)
+ torch
+ transformers
+ tqdm
+ joblib
train_utils.py ADDED
@@ -0,0 +1,310 @@
+ # train_utils.py
+
+ import torch
+ import torch.nn as nn
+ from torch.optim import AdamW
+ from sklearn.metrics import classification_report
+ from sklearn.utils.class_weight import compute_class_weight
+ import numpy as np
+ from tqdm import tqdm
+ import pandas as pd
+ import os
+ import joblib
+
+ from config import DEVICE, LABEL_COLUMNS, NUM_EPOCHS, LEARNING_RATE, MODEL_SAVE_DIR
+
+ def get_class_weights(data_df, field, label_encoder):
+     """
+     Computes balanced class weights for a given target field.
+     These weights can be used in the loss function to mitigate class imbalance.
+
+     Args:
+         data_df (pd.DataFrame): The DataFrame containing the original (unencoded) label data.
+         field (str): The name of the label column for which to compute weights.
+         label_encoder (sklearn.preprocessing.LabelEncoder): The label encoder fitted for this field.
+
+     Returns:
+         torch.Tensor: A tensor of class weights for the specified field.
+     """
+     # Get the original labels for the specified field
+     y = data_df[field].values
+     # Use label_encoder.transform directly - it will handle unseen labels
+     try:
+         y_encoded = label_encoder.transform(y)
+     except ValueError as e:
+         print(f"Warning: {e}")
+         print("Using only seen labels for class weights calculation")
+         # Filter out unseen labels
+         seen_labels = set(label_encoder.classes_)
+         y_filtered = [label for label in y if label in seen_labels]
+         y_encoded = label_encoder.transform(y_filtered)
+
+     # Ensure y_encoded is integer type
+     y_encoded = y_encoded.astype(int)
+
+     # Initialize counts for all possible classes
+     n_classes = len(label_encoder.classes_)
+     class_counts = np.zeros(n_classes, dtype=int)
+
+     # Count occurrences of each class
+     for i in range(n_classes):
+         class_counts[i] = np.sum(y_encoded == i)
+
+     # Calculate weights for all classes
+     total_samples = len(y_encoded)
+     class_weights = np.ones(n_classes)  # Default weight of 1 for unseen classes
+     seen_classes = class_counts > 0
+     if np.any(seen_classes):
+         class_weights[seen_classes] = total_samples / (np.sum(seen_classes) * class_counts[seen_classes])
+
+     return torch.tensor(class_weights, dtype=torch.float)
+
+ def initialize_criterions(data_df, label_encoders):
+     """
+     Initializes CrossEntropyLoss criteria for each label column, applying class weights.
+
+     Args:
+         data_df (pd.DataFrame): The original (unencoded) DataFrame. Used to compute class weights.
+         label_encoders (dict): Dictionary of LabelEncoder objects.
+
+     Returns:
+         dict: A dictionary where keys are label column names and values are
+               initialized `torch.nn.CrossEntropyLoss` objects.
+     """
+     field_criterions = {}
+     for field in LABEL_COLUMNS:
+         # Get class weights for the current field
+         weights = get_class_weights(data_df, field, label_encoders[field])
+         # Initialize CrossEntropyLoss with the computed weights and move to the device
+         field_criterions[field] = torch.nn.CrossEntropyLoss(weight=weights.to(DEVICE))
+     return field_criterions
+
+ def train_model(model, loader, optimizer, field_criterions, epoch):
+     """
+     Trains the given PyTorch model for one epoch.
+
+     Args:
+         model (torch.nn.Module): The model to train.
+         loader (torch.utils.data.DataLoader): DataLoader for training data.
+         optimizer (torch.optim.Optimizer): Optimizer for model parameters.
+         field_criterions (dict): Dictionary of loss functions for each label.
+         epoch (int): Current epoch number (for progress bar description).
+
+     Returns:
+         float: Average training loss for the epoch.
+     """
+     model.train()  # Set the model to training mode
+     total_loss = 0
+     # Use tqdm for a progress bar during training
+     tqdm_loader = tqdm(loader, desc=f"Epoch {epoch + 1} Training")
+
+     for batch in tqdm_loader:
+         # Unpack batch based on whether it contains metadata
+         if len(batch) == 2:  # Text-only models (inputs, labels)
+             inputs, labels = batch
+             input_ids = inputs['input_ids'].to(DEVICE)
+             attention_mask = inputs['attention_mask'].to(DEVICE)
+             labels = labels.to(DEVICE)
+             # Forward pass through the model
+             outputs = model(input_ids, attention_mask)
+         elif len(batch) == 3:  # Text + Metadata models (inputs, metadata, labels)
+             inputs, metadata, labels = batch
+             input_ids = inputs['input_ids'].to(DEVICE)
+             attention_mask = inputs['attention_mask'].to(DEVICE)
+             metadata = metadata.to(DEVICE)
+             labels = labels.to(DEVICE)
+             # Forward pass through the hybrid model
+             outputs = model(input_ids, attention_mask, metadata)
+         else:
+             raise ValueError("Unsupported batch format. Expected 2 or 3 items in batch.")
+
+         loss = 0
+         # Calculate total loss by summing loss for each label column
+         # `outputs` is a list of logits, one for each label column
+         for i, output_logits in enumerate(outputs):
+             # `labels[:, i]` gets the true labels for the i-th label column
+             # `field_criterions[LABEL_COLUMNS[i]]` selects the appropriate loss function
+             loss += field_criterions[LABEL_COLUMNS[i]](output_logits, labels[:, i])
+
+         optimizer.zero_grad()  # Clear previous gradients
+         loss.backward()        # Backpropagation
+         optimizer.step()       # Update model parameters
+         total_loss += loss.item()  # Accumulate loss
+         tqdm_loader.set_postfix(loss=loss.item())  # Update progress bar with current batch loss
+
+     return total_loss / len(loader)  # Return average loss for the epoch
+
+ def evaluate_model(model, loader):
+     """
+     Evaluates the given PyTorch model on a validation/test set.
+
+     Args:
+         model (torch.nn.Module): The model to evaluate.
+         loader (torch.utils.data.DataLoader): DataLoader for evaluation data.
+
+     Returns:
+         tuple: A tuple containing:
+             - reports (dict): Classification reports (dict format) for each label column.
+             - truths (list): List of true label arrays for each label column.
+             - predictions (list): List of predicted label arrays for each label column.
+     """
+     model.eval()  # Set the model to evaluation mode (disables dropout, batch norm updates, etc.)
+     # Initialize lists to store predictions and true labels for each output head
+     predictions = [[] for _ in range(len(LABEL_COLUMNS))]
+     truths = [[] for _ in range(len(LABEL_COLUMNS))]
+
+     with torch.no_grad():  # Disable gradient calculations during evaluation for efficiency
+         for batch in tqdm(loader, desc="Evaluation"):
+             if len(batch) == 2:
+                 inputs, labels = batch
+                 input_ids = inputs['input_ids'].to(DEVICE)
+                 attention_mask = inputs['attention_mask'].to(DEVICE)
+                 labels = labels.to(DEVICE)
+                 outputs = model(input_ids, attention_mask)
+             elif len(batch) == 3:
+                 inputs, metadata, labels = batch
+                 input_ids = inputs['input_ids'].to(DEVICE)
+                 attention_mask = inputs['attention_mask'].to(DEVICE)
+                 metadata = metadata.to(DEVICE)
+                 labels = labels.to(DEVICE)
+                 outputs = model(input_ids, attention_mask, metadata)
+             else:
+                 raise ValueError("Unsupported batch format.")
+
+             for i, output_logits in enumerate(outputs):
+                 # Get the predicted class by taking the argmax of the logits
+                 preds = torch.argmax(output_logits, dim=1).cpu().numpy()
+                 predictions[i].extend(preds)
+                 # Get the true labels for the current output head
+                 truths[i].extend(labels[:, i].cpu().numpy())
+
+     reports = {}
+     # Generate classification report for each label column
+     for i, col in enumerate(LABEL_COLUMNS):
+         try:
+             # `zero_division=0` handles cases where a class might have no true or predicted samples
+             reports[col] = classification_report(truths[i], predictions[i], output_dict=True, zero_division=0)
+         except ValueError:
+             # Handle cases where a label might not appear in the validation set,
+             # which could cause classification_report to fail.
+             print(f"Warning: Could not generate classification report for {col}. Skipping.")
+             reports[col] = {'accuracy': 0, 'weighted avg': {'precision': 0, 'recall': 0, 'f1-score': 0, 'support': 0}}
+     return reports, truths, predictions
+
+ def summarize_metrics(metrics):
+     """
+     Summarizes classification reports into a readable Pandas DataFrame.
+
+     Args:
+         metrics (dict): Dictionary of classification reports, as returned by `evaluate_model`.
+
+     Returns:
+         pd.DataFrame: A DataFrame summarizing precision, recall, f1-score, accuracy, and support for each field.
+     """
+     summary = []
+     for field, report in metrics.items():
+         # Safely get metrics, defaulting to 0 if not present (e.g., for empty reports)
+         precision = report['weighted avg']['precision'] if 'weighted avg' in report else 0
+         recall = report['weighted avg']['recall'] if 'weighted avg' in report else 0
+         f1 = report['weighted avg']['f1-score'] if 'weighted avg' in report else 0
+         support = report['weighted avg']['support'] if 'weighted avg' in report else 0
+         accuracy = report['accuracy'] if 'accuracy' in report else 0  # Accuracy is usually top-level
+         summary.append({
+             "Field": field,
+             "Precision": precision,
+             "Recall": recall,
+             "F1-Score": f1,
+             "Accuracy": accuracy,
+             "Support": support
+         })
+     return pd.DataFrame(summary)
+
+ def save_model(model, model_name, save_format='pth'):
+     """
+     Saves the state dictionary of a PyTorch model.
+
+     Args:
+         model (torch.nn.Module): The trained PyTorch model.
+         model_name (str): A descriptive name for the model (used for filename).
+         save_format (str): Format to save the model in ('pth' for PyTorch models, 'pickle' for traditional ML models).
+     """
+     # Construct the save path dynamically relative to the project root
+     if save_format == 'pth':
+         model_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}_model.pth")
+         torch.save(model.state_dict(), model_path)
+     elif save_format == 'pickle':
+         model_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}.pkl")
+         joblib.dump(model, model_path)
+     else:
+         raise ValueError(f"Unsupported save format: {save_format}")
+
+     print(f"Model saved to {model_path}")
+
+ def load_model_state(model, model_name, model_class, num_labels, metadata_dim=0):
+     """
+     Loads the state dictionary into a PyTorch model.
+
+     Args:
+         model (torch.nn.Module): An initialized model instance (architecture).
+         model_name (str): The name of the model to load.
+         model_class (class): The class of the model (e.g., BertMultiOutputModel).
+         num_labels (list): List of number of classes for each label.
+         metadata_dim (int): Dimensionality of metadata features, if applicable (default 0 for text-only).
+
+     Returns:
+         torch.nn.Module: The model with loaded state_dict, moved to the correct device, and set to eval mode.
+     """
+     model_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}_model.pth")
+     if not os.path.exists(model_path):
+         print(f"Warning: Model file not found at {model_path}. Returning a newly initialized model instance.")
+         # Re-initialize the model if not found, to ensure it has the correct architecture
+         if metadata_dim > 0:
+             return model_class(num_labels, metadata_dim=metadata_dim).to(DEVICE)
+         else:
+             return model_class(num_labels).to(DEVICE)
+
+     model.load_state_dict(torch.load(model_path, map_location=DEVICE))
+     model.to(DEVICE)
+     model.eval()  # Set to evaluation mode after loading
+     print(f"Model loaded from {model_path}")
+     return model
+
+ def predict_probabilities(model, loader):
+     """
+     Generates prediction probabilities for each label for a given model.
+     This is used for confidence scoring and feeding into a voting ensemble.
+
+     Args:
+         model (torch.nn.Module): The trained PyTorch model.
+         loader (torch.utils.data.DataLoader): DataLoader for the data to predict on.
+
+     Returns:
+         list: A list of lists of numpy arrays. Each inner list corresponds to a label column,
+               containing the softmax probabilities for each sample for that label.
+     """
+     model.eval()  # Set to evaluation mode
+     # List to store probabilities for each output head
+     all_probabilities = [[] for _ in range(len(LABEL_COLUMNS))]
+
+     with torch.no_grad():
+         for batch in tqdm(loader, desc="Predicting Probabilities"):
+             # Unpack batch, ignoring labels as we only need inputs
+             if len(batch) == 2:
+                 inputs, _ = batch
+                 input_ids = inputs['input_ids'].to(DEVICE)
+                 attention_mask = inputs['attention_mask'].to(DEVICE)
+                 outputs = model(input_ids, attention_mask)
+             elif len(batch) == 3:
+                 inputs, metadata, _ = batch
+                 input_ids = inputs['input_ids'].to(DEVICE)
+                 attention_mask = inputs['attention_mask'].to(DEVICE)
+                 metadata = metadata.to(DEVICE)
+                 outputs = model(input_ids, attention_mask, metadata)
+             else:
+                 raise ValueError("Unsupported batch format.")
+
+             for i, out_logits in enumerate(outputs):
+                 # Apply softmax to logits to get probabilities
+                 probs = torch.softmax(out_logits, dim=1).cpu().numpy()
+                 all_probabilities[i].extend(probs)
+     return all_probabilities
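An end-to-end training sketch (illustrative, not part of the upload): it assumes a multi-output model class such as the `BertMultiOutputModel` referenced in `load_model_state`'s docstring is defined elsewhere in the project, and that `label_encoders`, `num_labels`, `train_loader`, and `val_loader` were prepared as in the dataset_utils sketch above.

```python
import numpy as np
import pandas as pd
from torch.optim import AdamW

from config import DATA_PATH, DEVICE, LEARNING_RATE, NUM_EPOCHS
from train_utils import (
    evaluate_model, initialize_criterions, predict_probabilities,
    save_model, summarize_metrics, train_model,
)
from voting import save_predictions

model = BertMultiOutputModel(num_labels).to(DEVICE)  # hypothetical model class, defined outside these files
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

raw_data = pd.read_csv(DATA_PATH)  # original, unencoded labels, as initialize_criterions expects
field_criterions = initialize_criterions(raw_data, label_encoders)

for epoch in range(NUM_EPOCHS):
    avg_loss = train_model(model, train_loader, optimizer, field_criterions, epoch)
    print(f"Epoch {epoch + 1}: average training loss {avg_loss:.4f}")

reports, truths, preds = evaluate_model(model, val_loader)
print(summarize_metrics(reports))  # weighted precision/recall/F1 per label column

save_model(model, "BERT")  # writes ./saved_models/BERT_model.pth

# Persist per-field softmax probabilities so voting.py can ensemble this model later
all_probs = predict_probabilities(model, val_loader)
save_predictions(
    "BERT",
    [np.array(p) for p in all_probs],
    [np.array(t) for t in truths],
)
```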
voting.py ADDED
@@ -0,0 +1,152 @@
+ # voting.py
+
+ import numpy as np
+ import pandas as pd
+ from collections import defaultdict
+ from sklearn.metrics import classification_report
+ import os
+ import pickle
+
+ from config import LABEL_COLUMNS, PREDICTIONS_SAVE_DIR
+
+ def save_predictions(model_name, all_probs, true_labels):
+     """
+     Saves the prediction probabilities and true labels for each target field
+     from a specific model. This data is then used by the voting ensemble.
+
+     Args:
+         model_name (str): Unique identifier for the model (e.g., "BERT", "TF_IDF_LR").
+         all_probs (list): A list where each element is a NumPy array of probabilities
+                           for a corresponding label column (shape: num_samples, num_classes).
+         true_labels (list): A list where each element is a NumPy array of true labels
+                             for a corresponding label column (shape: num_samples,).
+     """
+     model_preds_dir = os.path.join(PREDICTIONS_SAVE_DIR, model_name)
+     os.makedirs(model_preds_dir, exist_ok=True)  # Ensure the model-specific directory exists
+
+     for i, col in enumerate(LABEL_COLUMNS):
+         # Define file paths for probabilities and true labels for the current field
+         prob_file = os.path.join(model_preds_dir, f"{col}_probs.pkl")
+         true_file = os.path.join(model_preds_dir, f"{col}_true.pkl")
+
+         # Save probabilities (list of arrays) and true labels (list of arrays)
+         with open(prob_file, 'wb') as f:
+             pickle.dump(all_probs[i], f)
+         with open(true_file, 'wb') as f:
+             pickle.dump(true_labels[i], f)
+     print(f"Predictions for {model_name} saved to {model_preds_dir}")
+
+ def load_predictions(model_name):
+     """
+     Loads saved prediction probabilities and true labels for a given model.
+
+     Args:
+         model_name (str): Unique identifier for the model.
+
+     Returns:
+         tuple: A tuple containing:
+             - all_probs (list): List of NumPy arrays of probabilities for each label column.
+             - true_labels (list): List of NumPy arrays of true labels for each label column.
+             Returns (None, None) if files are not found.
+     """
+     model_preds_dir = os.path.join(PREDICTIONS_SAVE_DIR, model_name)
+     all_probs = [[] for _ in range(len(LABEL_COLUMNS))]
+     true_labels = [[] for _ in range(len(LABEL_COLUMNS))]
+
+     found_all_files = True
+     for i, col in enumerate(LABEL_COLUMNS):
+         prob_file = os.path.join(model_preds_dir, f"{col}_probs.pkl")
+         true_file = os.path.join(model_preds_dir, f"{col}_true.pkl")
+         if os.path.exists(prob_file) and os.path.exists(true_file):
+             with open(prob_file, 'rb') as f:
+                 all_probs[i] = pickle.load(f)
+             with open(true_file, 'rb') as f:
+                 true_labels[i] = pickle.load(f)
+         else:
+             print(f"Warning: Prediction files not found for {model_name} - {col}. This model might be excluded for this label in ensemble.")
+             found_all_files = False  # Mark that not all files were found
+
+     if not found_all_files:
+         return None, None  # Indicate that this model's predictions couldn't be fully loaded
+
+     # Convert list of lists to list of numpy arrays if they were loaded as lists.
+     # This ensures consistency for stacking later.
+     all_probs = [np.array(p) for p in all_probs]
+     true_labels = [np.array(t) for t in true_labels]
+
+     return all_probs, true_labels
+
+ def perform_voting_ensemble(model_names_to_ensemble):
+     """
+     Performs a soft voting ensemble (averaging probabilities) for each label
+     across a list of specified models.
+
+     Args:
+         model_names_to_ensemble (list): A list of string names of the models
+                                         whose predictions should be ensembled.
+                                         These names should match the directory names
+                                         under `PREDICTIONS_SAVE_DIR`.
+
+     Returns:
+         tuple: A tuple containing:
+             - ensemble_reports (dict): Classification reports for the ensemble predictions.
+             - all_true_labels_for_ensemble (list): List of true labels used for evaluation.
+             - ensemble_predictions (list): List of predicted class indices from the ensemble.
+     """
+     print("\n--- Performing Voting Ensemble ---")
+     # defaultdict stores a list for each key, helpful when appending to potentially new keys
+     all_models_probs = defaultdict(list)  # Stores list of probability arrays per label for all models
+     # Initialize with empty lists; true labels for evaluation (should be consistent across models)
+     all_true_labels_for_ensemble = [None for _ in range(len(LABEL_COLUMNS))]
+
+     # Load probabilities from all specified models
+     for model_name in model_names_to_ensemble:
+         print(f"Loading predictions for {model_name}...")
+         probs_per_label, true_labels_per_label = load_predictions(model_name)
+
+         if probs_per_label is None:  # Skip this model if loading failed
+             continue
+
+         for i, col in enumerate(LABEL_COLUMNS):
+             if len(probs_per_label[i]) > 0:  # Ensure probabilities were actually loaded for this label
+                 all_models_probs[col].append(probs_per_label[i])
+                 if all_true_labels_for_ensemble[i] is None:  # Store true labels only once (they should be identical)
+                     all_true_labels_for_ensemble[i] = true_labels_per_label[i]
+
+     ensemble_predictions = [[] for _ in range(len(LABEL_COLUMNS))]
+     ensemble_reports = {}
+
+     for i, col in enumerate(LABEL_COLUMNS):
+         if not all_models_probs[col]:  # If no models provided predictions for this label
+             print(f"No valid predictions available for {col} to ensemble. Skipping.")
+             ensemble_reports[col] = {'accuracy': 0, 'weighted avg': {'precision': 0, 'recall': 0, 'f1-score': 0, 'support': 0}}
+             continue
+
+         # Stack probabilities for the current label from all models that had them.
+         # `stacked_probs` will have shape: (num_contributing_models, num_samples, num_classes)
+         stacked_probs = np.stack(all_models_probs[col], axis=0)
+
+         # Perform soft voting by summing probabilities across models.
+         # `summed_probs` will have shape: (num_samples, num_classes)
+         summed_probs = np.sum(stacked_probs, axis=0)
+
+         # Get the final predicted class by taking the argmax of the summed probabilities.
+         final_preds = np.argmax(summed_probs, axis=1)  # (num_samples,)
+
+         ensemble_predictions[i] = final_preds.tolist()
+
+         # Evaluate ensemble predictions
+         y_true_ensemble = all_true_labels_for_ensemble[i]
+         if y_true_ensemble is not None:  # Ensure true labels are available
+             try:
+                 report = classification_report(y_true_ensemble, final_preds, output_dict=True, zero_division=0)
+                 ensemble_reports[col] = report
+             except ValueError:
+                 print(f"Warning: Could not generate ensemble classification report for {col}. Skipping.")
+                 ensemble_reports[col] = {'accuracy': 0, 'weighted avg': {'precision': 0, 'recall': 0, 'f1-score': 0, 'support': 0}}
+         else:
+             print(f"Warning: True labels not found for {col}, cannot evaluate ensemble.")
+             ensemble_reports[col] = {'accuracy': 0, 'weighted avg': {'precision': 0, 'recall': 0, 'f1-score': 0, 'support': 0}}
+
+     return ensemble_reports, all_true_labels_for_ensemble, ensemble_predictions
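An offline ensembling sketch (illustrative, not part of the upload): it assumes at least two models (the names below are placeholders) have already written per-field probability and true-label files under ./predictions/ via save_predictions(), as in the training sketch above.

```python
from config import LABEL_COLUMNS
from dataset_utils import load_label_encoders
from train_utils import summarize_metrics
from voting import perform_voting_ensemble

reports, true_labels, preds = perform_voting_ensemble(["BERT", "ROBERTA"])  # placeholder model names
print(summarize_metrics(reports))  # weighted precision/recall/F1 per label column

# Decode the first field's ensemble predictions back to the original label strings
encoders = load_label_encoders()
first_field = LABEL_COLUMNS[0]  # "Red_Flag_Reason"
decoded = encoders[first_field].inverse_transform(preds[0])
print(decoded[:5])
```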