import os
import sys
import json
import logging
import time
import pickle
import copy

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, Subset
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
from matplotlib import pyplot as plt

_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)
# The KerasStyleTokenizer class from stage 2 must be importable so the pickled
# tokenizer can be loaded and texts_to_sequences can be called on raw text.
from src.stage2_preprocessing import KerasStyleTokenizer

logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s")
logger = logging.getLogger("lstm_model")
# ── Architecture ──────────────────────────────────────
class SpatialDropout1D(nn.Module):
    def __init__(self, p=0.3):
        super().__init__()
        self.p = p

    def forward(self, x):
        if not self.training or self.p == 0:
            return x
        # x is (batch, seq_len, embed_dim); convert to (batch, embed_dim, seq_len)
        x = x.permute(0, 2, 1)
        # 1D spatial dropout is equivalent to 2D dropout with a trailing singleton dimension:
        # F.dropout2d drops entire channels, which here are the embedding dimensions
        x = x.unsqueeze(3)
        x = F.dropout2d(x, p=self.p, training=self.training)
        x = x.squeeze(3)
        return x.permute(0, 2, 1)
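
# Quick sanity check (illustrative only, not executed at import time): with p=0.5 in
# training mode, roughly half of the embedding channels are zeroed for the *entire*
# sequence, unlike nn.Dropout, which zeroes individual (timestep, dim) entries.
#
#   layer = SpatialDropout1D(p=0.5); layer.train()
#   out = layer(torch.randn(2, 10, 100))      # shape preserved: (2, 10, 100)
#   # (out == 0).all(dim=1) is True exactly for the dropped embedding channels
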
class BiLSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_matrix=None):
        super().__init__()
        # Embedding(vocab_size, 100), index 0 reserved for padding
        self.embedding = nn.Embedding(vocab_size, 100, padding_idx=0)
        if embedding_matrix is not None:
            self.embedding.weight.data.copy_(torch.from_numpy(embedding_matrix))
            self.embedding.weight.requires_grad = False  # keep pretrained GloVe vectors frozen
        self.spatial_drop = SpatialDropout1D(0.3)
        # Bi-LSTM(100 -> 128): outputs 256 features per timestep (forward + backward)
        self.lstm1 = nn.LSTM(100, 128, bidirectional=True, batch_first=True)
        # Bi-LSTM(256 -> 64): only its final hidden states are used
        self.lstm2 = nn.LSTM(256, 64, bidirectional=True, batch_first=True)
        # Linear(128, 64) + ReLU
        self.fc1 = nn.Linear(128, 64)
        self.dropout = nn.Dropout(0.4)
        # Linear(64, 1); the sigmoid is deferred to BCEWithLogitsLoss during training
        # and torch.sigmoid at inference
        self.fc2 = nn.Linear(64, 1)

    def forward(self, x):
        h = self.embedding(x)
        h = self.spatial_drop(h)
        h, _ = self.lstm1(h)
        # Equivalent of Keras return_sequences=False on the second LSTM:
        # keep only the final hidden state of each direction
        _, (h_n, _) = self.lstm2(h)
        # h_n shape for a single-layer Bi-LSTM: (2, batch, hidden_size)
        # Concatenate the forward and backward final states
        h_concat = torch.cat((h_n[-2, :, :], h_n[-1, :, :]), dim=1)  # shape: (batch, 128)
        out = F.relu(self.fc1(h_concat))
        out = self.dropout(out)
        logits = self.fc2(out)
        return logits.squeeze(1)
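
# Shape trace for one forward pass (illustrative, assuming maxlen=512):
#   x: (batch, 512) int64 token ids
#   embedding        -> (batch, 512, 100)
#   lstm1 (bi, 128)  -> (batch, 512, 256)
#   lstm2 (bi, 64)   -> final hidden states h_n: (2, batch, 64)
#   concat fwd/bwd   -> (batch, 128) -> fc1 + ReLU -> (batch, 64) -> fc2 -> (batch,) logits
# Deferring the sigmoid keeps the loss numerically stable compared with an in-model Sigmoid.
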
# ── Utilities ──────────────────────────────────────
def pad_sequences(sequences, maxlen=512, padding='post'):
    padded = np.zeros((len(sequences), maxlen), dtype=np.int64)
    for i, seq in enumerate(sequences):
        seq = seq[:maxlen]
        if len(seq) == 0:
            continue  # leave the row as all padding for empty sequences
        if padding == 'post':
            padded[i, :len(seq)] = seq
        else:
            padded[i, -len(seq):] = seq
    return padded
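
# Example (illustrative): pad_sequences([[5, 9, 2], [7]], maxlen=4) returns
#   [[5, 9, 2, 0],
#    [7, 0, 0, 0]]
# With padding='pre', the zeros are placed in front of each sequence instead.
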
def load_glove_embeddings(glove_path, word_index, embed_dim=100):
    logger.info(f"Loading GloVe embeddings from {glove_path}...")
    embeddings_index = {}
    with open(glove_path, "r", encoding="utf-8") as f:
        for line in f:
            values = line.split()
            word = values[0]
            coefs = np.asarray(values[1:], dtype='float32')
            embeddings_index[word] = coefs
    vocab_size = len(word_index) + 1  # +1 for the padding index 0
    embedding_matrix = np.zeros((vocab_size, embed_dim), dtype=np.float32)
    hits, misses = 0, 0
    for word, i in word_index.items():
        embedding_vector = embeddings_index.get(word)
        if embedding_vector is not None:
            embedding_matrix[i] = embedding_vector
            hits += 1
        else:
            misses += 1
    logger.info(f"GloVe mapped: {hits} hits, {misses} misses.")
    return embedding_matrix, vocab_size
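
# The expected GloVe file is the standard whitespace-separated text format
# (e.g. glove.6B.100d.txt), where each line is "<word> <100 float values>".
# Vocabulary words without a GloVe vector keep an all-zero row (counted as misses).
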
def plot_and_save_cm(y_true, y_pred, path):
    cm = confusion_matrix(y_true, (np.array(y_pred) > 0.5).astype(int))
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.matshow(cm, cmap=plt.cm.Blues, alpha=0.3)
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(x=j, y=i, s=cm[i, j], va='center', ha='center', size='xx-large')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title('Bi-LSTM Confusion Matrix')
    plt.tight_layout()
    plt.savefig(path)
    plt.close()
# ── Training Loop ──────────────────────────────────────
def train_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for x_batch, y_batch in loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        logits = model(x_batch)
        loss = criterion(logits, y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x_batch.size(0)
    return total_loss / len(loader.dataset)
@torch.no_grad()  # gradients are not needed during evaluation
def eval_model(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    for x_batch, y_batch in loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        logits = model(x_batch)
        loss = criterion(logits, y_batch)
        total_loss += loss.item() * x_batch.size(0)
        probas = torch.sigmoid(logits).cpu().numpy()
        all_preds.extend(probas)
    return total_loss / len(loader.dataset), np.array(all_preds)
def train_lstm_logic(cfg, splits_dir, save_dir, glove_path):
    os.makedirs(save_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")
    # Load the pre-split data and the fitted tokenizer
    train_df = pd.read_csv(os.path.join(splits_dir, "df_train.csv"))
    val_df = pd.read_csv(os.path.join(splits_dir, "df_val.csv"))
    y_train = train_df["binary_label"].values.astype(np.float32)
    y_val = val_df["binary_label"].values.astype(np.float32)
    with open(os.path.join(_PROJECT_ROOT, cfg["paths"]["models_dir"], "tokenizer.pkl"), "rb") as f:
        tokenizer = pickle.load(f)
    maxlen = cfg.get("preprocessing", {}).get("lstm_max_len", 512)
    batch_size = cfg.get("training", {}).get("lstm_batch_size", 64)
    epochs = cfg.get("training", {}).get("lstm_epochs", 10)
    logger.info("Transforming texts to padded sequences...")
    X_train_seq = tokenizer.texts_to_sequences(train_df["clean_text"].fillna(""))
    X_val_seq = tokenizer.texts_to_sequences(val_df["clean_text"].fillna(""))
    X_train_pad = pad_sequences(X_train_seq, maxlen=maxlen, padding='post')
    X_val_pad = pad_sequences(X_val_seq, maxlen=maxlen, padding='post')
    # Embedding matrix built from pretrained GloVe vectors
    emb_matrix, vocab_size = load_glove_embeddings(glove_path, tokenizer.word_index)
    # Class imbalance: weight positives by the negative/positive ratio in BCEWithLogitsLoss
    class_counts = np.bincount(y_train.astype(int))
    pos_weight = torch.tensor([class_counts[0] / class_counts[1]], dtype=torch.float32).to(device)
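    # Worked example (illustrative numbers): with 8,000 negatives and 2,000 positives,
    # pos_weight = 8000 / 2000 = 4.0, so each positive sample contributes 4x to the
    # loss, compensating for the class imbalance.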
    # Datasets
    train_tensor = TensorDataset(torch.from_numpy(X_train_pad).long(), torch.from_numpy(y_train))
    val_tensor = TensorDataset(torch.from_numpy(X_val_pad).long(), torch.from_numpy(y_val))
    val_loader = DataLoader(val_tensor, batch_size=batch_size, shuffle=False)

    # --- 5-Fold OOF Predictions ---
    logger.info("Starting 5-Fold OOF generation...")
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    oof_preds = np.zeros_like(y_train, dtype=np.float32)
    criterion_kfold = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    for fold, (t_idx, v_idx) in enumerate(skf.split(X_train_pad, y_train)):
        logger.info(f"OOF Fold {fold+1}/5")
        fold_train_ds = Subset(train_tensor, t_idx)
        fold_val_ds = Subset(train_tensor, v_idx)
        fold_train_loader = DataLoader(fold_train_ds, batch_size=batch_size, shuffle=True)
        fold_val_loader = DataLoader(fold_val_ds, batch_size=batch_size, shuffle=False)
        model = BiLSTMClassifier(vocab_size, emb_matrix).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=1, factor=0.5)
        best_val_loss = float('inf')
        patience_counter = 0
        best_weights = copy.deepcopy(model.state_dict())
        for ep in range(epochs):  # could be capped at 3-4 epochs per OOF fold to save time
            t_loss = train_epoch(model, fold_train_loader, optimizer, criterion_kfold, device)
            v_loss, v_preds = eval_model(model, fold_val_loader, criterion_kfold, device)
            scheduler.step(v_loss)
            if v_loss < best_val_loss:
                best_val_loss = v_loss
                best_weights = copy.deepcopy(model.state_dict())
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= 3:
                    break
        # Restore the best checkpoint before producing this fold's predictions
        model.load_state_dict(best_weights)
        _, fold_best_preds = eval_model(model, fold_val_loader, criterion_kfold, device)
        oof_preds[v_idx] = fold_best_preds
    np.save(os.path.join(save_dir, "lstm_oof.npy"), oof_preds)
    logger.info("Saved OOF predictions (lstm_oof.npy).")
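    # Note: each out-of-fold prediction comes from a model that never saw that row during
    # training, so lstm_oof.npy can serve as leak-free meta-features for a downstream
    # ensembling/stacking stage (assumed use; not performed in this script).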
    # --- Final Training on ALL Data ---
    logger.info("Starting final model training on full Train split...")
    train_loader = DataLoader(train_tensor, batch_size=batch_size, shuffle=True)
    model = BiLSTMClassifier(vocab_size, emb_matrix).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=1, factor=0.5)
    best_val_loss = float('inf')
    best_weights = copy.deepcopy(model.state_dict())
    patience_counter = 0
    for ep in range(epochs):
        t_loss = train_epoch(model, train_loader, optimizer, criterion_kfold, device)
        v_loss, v_preds = eval_model(model, val_loader, criterion_kfold, device)
        scheduler.step(v_loss)
        logger.info(f" Epoch {ep+1}/{epochs} | Train Loss: {t_loss:.4f} | Val Loss: {v_loss:.4f}")
        if v_loss < best_val_loss:
            best_val_loss = v_loss
            best_weights = copy.deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= 3:
                logger.info(" EarlyStopping triggered.")
                break
    model.load_state_dict(best_weights)
    torch.save(model.state_dict(), os.path.join(save_dir, "model.pt"))
    logger.info("Saved final LSTM weights.")

    # Evaluate on the validation split
    _, val_preds_probas = eval_model(model, val_loader, criterion_kfold, device)
    val_preds_binary = (val_preds_probas >= 0.5).astype(int)
    logger.info("Validation Classification Report:\n" + classification_report(y_val, val_preds_binary))
    roc_auc = roc_auc_score(y_val, val_preds_probas)
    logger.info(f"ROC-AUC: {roc_auc:.4f}")
    plot_and_save_cm(y_val, val_preds_probas, os.path.join(save_dir, "cm.png"))

    # Accuracy broken down by text-length bucket on the validation split
    bucket_acc = {}
    for b in ["short", "medium", "long"]:
        b_mask = (val_df["text_length_bucket"] == b).values
        if b_mask.sum() > 0:
            acc = (val_preds_binary[b_mask] == y_val[b_mask]).mean()
            bucket_acc[b] = acc
    metrics = {
        "roc_auc": float(roc_auc),
        "bucket_accuracy": {k: float(v) for k, v in bucket_acc.items()}
    }
    with open(os.path.join(save_dir, "metrics.json"), "w") as f:
        json.dump(metrics, f, indent=2)

if __name__ == "__main__":
    import yaml

    cfg_path = os.path.join(_PROJECT_ROOT, "config", "config.yaml")
    with open(cfg_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    s_dir = os.path.join(_PROJECT_ROOT, config["paths"]["splits_dir"])
    m_dir = os.path.join(_PROJECT_ROOT, config["paths"]["models_dir"], "lstm_model")
    g_path = os.path.join(_PROJECT_ROOT, config["paths"]["glove_path"])
    t0 = time.time()
    train_lstm_logic(config, s_dir, m_dir, g_path)
    print(f"Total time: {time.time() - t0:.2f}s")
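
# Sketch of the config.yaml keys this entry point reads (key names are taken from the
# lookups above; the example values are placeholders and depend on the project layout):
#
#   paths:
#     splits_dir: data/splits
#     models_dir: models
#     glove_path: data/glove.6B.100d.txt
#   preprocessing:
#     lstm_max_len: 512
#   training:
#     lstm_batch_size: 64
#     lstm_epochs: 10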