| from pathlib import Path | |
| import logging | |
| from datetime import datetime | |
| import time | |
| import pandas as pd | |
| import numpy as np | |
| import random | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.metrics import precision_recall_fscore_support | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import WeightedRandomSampler | |
| from torch.optim import AdamW | |
| from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts | |
| from tqdm.auto import tqdm | |
| from safetensors.torch import save_file | |
| from src.plot_utils import plot_confusion_matrix | |
| def set_seed(seed): | |
| random.seed(seed) | |
| np.random.seed(seed) | |
| torch.manual_seed(seed) | |
| torch.cuda.manual_seed_all(seed) | |
| def split_train_val_paths(train_root, metadata_path, random_state=None): | |
| train_root = Path(train_root) | |
| metadata_df = pd.read_csv(metadata_path) | |
| X = metadata_df[["label", "video_name"]] | |
| y = metadata_df["label"] | |
| X_train, X_val, _, _ = train_test_split( | |
| X, y, test_size=0.2, stratify=y, random_state=random_state | |
| ) | |
| train_paths = (train_root / X_train["label"] / X_train["video_name"]).to_list() | |
| val_paths = (train_root / X_val["label"] / X_val["video_name"]).to_list() | |
| return train_paths, val_paths | |
| def create_balanced_sampler(dataset): | |
| '''Create balanced sampler for imbalanced dataset''' | |
| all_labels = dataset.labels | |
| class_counts = np.bincount(all_labels) | |
| class_weights = 1.0 / class_counts | |
| sample_weights = [class_weights[label] for label in all_labels] | |
| sample_weights = torch.FloatTensor(sample_weights) | |
| sampler = WeightedRandomSampler( | |
| weights=sample_weights, | |
| num_samples=len(sample_weights), | |
| replacement=True | |
| ) | |
| return sampler | |
| def validate(model, dataloader, criterion, device): | |
| model.eval() | |
| total_loss, preds, labels_all = 0, [], [] | |
| with torch.no_grad(): | |
| for batch in tqdm(dataloader, desc="Validation"): | |
| frames, labels = batch["frames"].to(device), batch["labels"].to(device) | |
| outputs = model(frames) | |
| loss = criterion(outputs, labels) | |
| total_loss += loss.item() | |
| _, predicted = outputs.max(1) | |
| preds.extend(predicted.cpu().numpy()) | |
| labels_all.extend(labels.cpu().numpy()) | |
| precision, recall, f1, _ = precision_recall_fscore_support( | |
| labels_all, preds, average="macro", zero_division=0 | |
| ) | |
| return ( | |
| total_loss / len(dataloader), | |
| {"precision": precision*100, "recall": recall*100, "f1": f1*100}, | |
| preds, | |
| labels_all | |
| ) | |
| def train_epoch(model, dataloader, criterion, optimizer, device): | |
| total_loss = 0 | |
| progress = tqdm(dataloader, desc="Training") | |
| model.train() | |
| for batch in progress: | |
| frames, labels = batch["frames"].to(device), batch["labels"].to(device) | |
| optimizer.zero_grad() | |
| outputs = model(frames) | |
| loss = criterion(outputs, labels) | |
| loss.backward() | |
| optimizer.step() | |
| total_loss += loss.item() | |
| lr = optimizer.param_groups[0]["lr"] | |
| progress.set_postfix({"loss": f"{total_loss / (len(progress)+1e-9):.4f}"}) | |
| return total_loss / len(dataloader), lr | |
| def train_model( | |
| model, train_loader, val_loader, logger, | |
| num_epochs=10, lr=5e-4, device="cuda", | |
| early_stopping_patience=3, | |
| save_path="best_model.safetensors", | |
| validation_cm_path="validation_cm.png" | |
| ): | |
| model = model.to(device) | |
| criterion = nn.CrossEntropyLoss(label_smoothing=0.1) | |
| optimizer = AdamW( | |
| params=filter(lambda p: p.requires_grad, model.parameters()), | |
| lr=lr, | |
| weight_decay=1e-4 | |
| ) | |
| scheduler = CosineAnnealingWarmRestarts( | |
| optimizer, | |
| T_0=10, | |
| T_mult=2, | |
| eta_min=1e-6 | |
| ) | |
| train_losses = [] | |
| val_losses = [] | |
| precision_scores = [] | |
| recall_scores = [] | |
| f1_scores = [] | |
| learning_rates = [] | |
| best_f1 = 0.0 | |
| best_f1_epoch = 1 | |
| early_stopping_cnt = 0 | |
| start_time = time.time() | |
| for epoch in range(num_epochs): | |
| epoch_start = time.time() | |
| logger.info(f"===== Epoch {epoch+1}/{num_epochs} =====") | |
| train_loss, lr = train_epoch( | |
| model, train_loader, criterion, optimizer, device | |
| ) | |
| val_loss, val_metrics, preds, labels_all = validate( | |
| model, val_loader, criterion, device | |
| ) | |
| scheduler.step() | |
| epoch_time = time.time() - epoch_start | |
| train_losses.append(train_loss) | |
| val_losses.append(val_loss) | |
| precision_scores.append(val_metrics["precision"]) | |
| recall_scores.append(val_metrics["recall"]) | |
| f1_scores.append(val_metrics["f1"]) | |
| learning_rates.append(lr) | |
| logger.info( | |
| f"Train Loss: {train_loss:.4f} | " | |
| f"Val Loss: {val_loss:.4f} | " | |
| f"Val Precision: {val_metrics['precision']:.2f}% | " | |
| f"Val Recall: {val_metrics['recall']:.2f}% | " | |
| f"Val F1: {val_metrics['f1']:.2f}% | " | |
| f"LR: {lr:.6f} | " | |
| f"Time: {epoch_time:.2f}s" | |
| ) | |
| if val_metrics["f1"] > best_f1: | |
| label_mapping = train_loader.dataset.label2id | |
| best_f1 = val_metrics["f1"] | |
| best_f1_epoch = epoch + 1 | |
| early_stopping_cnt = 0 | |
| save_file(model.state_dict(), save_path) | |
| plot_confusion_matrix( | |
| labels_all, preds, | |
| labels=[v for k, v in sorted(label_mapping.items(), key=lambda x: x[1])], | |
| display_labels=[k for k, v in sorted(label_mapping.items(), key=lambda x: x[1])], | |
| top_k=10, | |
| figsize=(20, 24), | |
| normalize="true", | |
| save_path=validation_cm_path | |
| ) | |
| logger.info(f"✓ Best model saved with F1: {best_f1:.2f}%") | |
| logger.info(f"✓ Best validation results saved at: {validation_cm_path}") | |
| else: | |
| early_stopping_cnt += 1 | |
| if early_stopping_cnt == early_stopping_patience: | |
| logger.info( | |
| f"Early stopping triggered. Best macro F1: {best_f1:.2f}, " | |
| f"achieved on epoch {best_f1_epoch}" | |
| ) | |
| break | |
| total_time = time.time() - start_time | |
| logger.info("========== TRAINING END ==========") | |
| logger.info(f"Best F1: {best_f1:.2f}%") | |
| logger.info(f"Total Time: {total_time/60:.2f} minutes") | |
| return ( | |
| train_losses, val_losses, precision_scores, | |
| recall_scores, f1_scores, learning_rates | |
| ) | |
| def setup_logger(log_dir="logs"): | |
| Path(log_dir).mkdir(exist_ok=True) | |
| log_file = Path(log_dir) / f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log" | |
| logger = logging.getLogger("train_logger") | |
| logger.setLevel(logging.INFO) | |
| formatter = logging.Formatter( | |
| "%(asctime)s | %(levelname)s | %(message)s", | |
| "%Y-%m-%d %H:%M:%S" | |
| ) | |
| file_handler = logging.FileHandler(log_file) | |
| file_handler.setFormatter(formatter) | |
| console_handler = logging.StreamHandler() | |
| console_handler.setFormatter(formatter) | |
| logger.addHandler(file_handler) | |
| logger.addHandler(console_handler) | |
| return logger, log_file | |
| def seed_worker(worker_id): | |
| worker_seed = torch.initial_seed() % 2**32 | |
| np.random.seed(worker_seed) | |
| random.seed(worker_seed) |