# Author: lucvantien1211
# Commit message: "Upload src folder, which contains modules and scripts"
# Commit: b20701a (verified)
from pathlib import Path
import logging
from datetime import datetime
import time
import pandas as pd
import numpy as np
import random
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support
import torch
import torch.nn as nn
from torch.utils.data import WeightedRandomSampler
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts
from tqdm.auto import tqdm
from safetensors.torch import save_file
from src.plot_utils import plot_confusion_matrix
def set_seed(seed):
    """Seed every RNG the training code touches with the same ``seed``.

    Covers Python's ``random``, NumPy's global RNG, PyTorch's CPU generator
    and the generators of all CUDA devices, in that order.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
def split_train_val_paths(train_root, metadata_path, random_state=None, test_size=0.2):
    """Split the videos listed in a metadata CSV into stratified train/val path lists.

    The CSV must contain ``label`` and ``video_name`` columns; each video path is
    built as ``<train_root>/<label>/<video_name>``.

    Args:
        train_root: Root directory of the training videos.
        metadata_path: Path to the metadata CSV.
        random_state: Seed forwarded to ``sklearn.model_selection.train_test_split``.
        test_size: Share (float) or count (int) of samples put in the validation
            split. Defaults to 0.2, the value previously hard-coded.

    Returns:
        ``(train_paths, val_paths)`` — two lists of ``pathlib.Path``.
    """
    train_root = Path(train_root)
    metadata_df = pd.read_csv(metadata_path)
    X = metadata_df[["label", "video_name"]]
    # Stratify on the label so both splits keep the same class distribution.
    X_train, X_val = train_test_split(
        X, test_size=test_size, stratify=metadata_df["label"], random_state=random_state
    )

    def _to_paths(df):
        # str() so numeric labels/names (read_csv may parse them as ints) can be joined.
        return [
            train_root / str(label) / str(name)
            for label, name in zip(df["label"], df["video_name"])
        ]

    return _to_paths(X_train), _to_paths(X_val)
def create_balanced_sampler(dataset):
    """Return a sampler that draws every class with roughly equal probability.

    Each sample is weighted by the inverse frequency of its class in
    ``dataset.labels`` (non-negative integer class ids). As many samples as there
    are labels are drawn per epoch, with replacement.
    """
    labels = dataset.labels
    # Inverse class frequency: rarer classes receive proportionally larger weights.
    inverse_freq = 1.0 / np.bincount(labels)
    weights = torch.tensor([inverse_freq[y] for y in labels], dtype=torch.float32)
    return WeightedRandomSampler(weights=weights, num_samples=len(weights), replacement=True)
def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` in eval mode without gradient tracking.

    Args:
        model: Module whose output holds per-class scores along dim 1; the
            index of the max along that dim is taken as the prediction.
        dataloader: Iterable of dicts with ``"frames"`` and ``"labels"`` tensors.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(mean_batch_loss, metrics, preds, labels)``: ``metrics`` holds the
        macro-averaged ``precision``, ``recall`` and ``f1`` scaled to percent;
        ``preds`` and ``labels`` are flat lists of per-sample class ids.
    """
    model.eval()
    running_loss = 0
    predictions = []
    targets = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation"):
            frames = batch["frames"].to(device)
            labels = batch["labels"].to(device)
            outputs = model(frames)
            running_loss += criterion(outputs, labels).item()
            predicted = outputs.max(1)[1]
            predictions.extend(predicted.cpu().numpy())
            targets.extend(labels.cpu().numpy())

    precision, recall, f1, _ = precision_recall_fscore_support(
        targets, predictions, average="macro", zero_division=0
    )
    metrics = {
        name: value * 100
        for name, value in (("precision", precision), ("recall", recall), ("f1", f1))
    }
    return running_loss / len(dataloader), metrics, predictions, targets
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Module to train (switched to train mode here).
        dataloader: Iterable of dicts with ``"frames"`` and ``"labels"`` tensors.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        ``(mean_batch_loss, lr)``: the loss averaged over the batches processed
        and the learning rate of the optimizer's first param group.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    total_loss = 0.0
    num_batches = 0
    progress = tqdm(dataloader, desc="Training")
    model.train()
    for num_batches, batch in enumerate(progress, start=1):
        frames, labels = batch["frames"].to(device), batch["labels"].to(device)
        optimizer.zero_grad()
        outputs = model(frames)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # Running mean over the batches processed so far in this epoch.
        progress.set_postfix({"loss": f"{total_loss / num_batches:.4f}"})

    if num_batches == 0:
        raise ValueError("train_epoch received a dataloader that yielded no batches")

    # Standard torch optimizers do not modify the LR in step(), so one read after
    # the loop equals the last per-batch read and is always defined.
    lr = optimizer.param_groups[0]["lr"]
    return total_loss / num_batches, lr
def train_model(
    model, train_loader, val_loader, logger,
    num_epochs=10, lr=5e-4, device="cuda",
    early_stopping_patience=3,
    save_path="best_model.safetensors",
    validation_cm_path="validation_cm.png"
):
    """Train ``model`` and keep the checkpoint with the best validation macro-F1.

    Optimisation uses cross-entropy with label smoothing 0.1, AdamW (weight decay
    1e-4) over the parameters with ``requires_grad=True``, and a
    ``CosineAnnealingWarmRestarts`` schedule stepped once per epoch. After every
    epoch the model is validated. When macro-F1 improves, the weights are written
    to ``save_path`` with safetensors and a normalised confusion matrix is
    plotted to ``validation_cm_path``. The first epoch always checkpoints, so
    ``save_path`` exists even if F1 never rises above 0.

    Args:
        model: Module to train; moved to ``device``.
        train_loader: Training loader. ``train_loader.dataset.label2id`` maps
            display names to class ids for the confusion-matrix axes.
        val_loader: Validation loader.
        logger: Logger-like object; only ``.info`` is used.
        num_epochs: Maximum number of epochs.
        lr: Initial learning rate.
        device: Device for the model and batches.
        early_stopping_patience: Stop after this many consecutive epochs
            without F1 improvement; ``None`` disables early stopping.
        save_path: Destination of the best checkpoint (safetensors).
        validation_cm_path: Destination of the best-epoch confusion-matrix plot.

    Returns:
        Per-epoch lists ``(train_losses, val_losses, precision_scores,
        recall_scores, f1_scores, learning_rates)``; scores are percentages.
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = AdamW(
        params=filter(lambda p: p.requires_grad, model.parameters()),
        lr=lr,
        weight_decay=1e-4
    )
    # Stepped once per epoch below, so T_0 is measured in epochs.
    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=10,
        T_mult=2,
        eta_min=1e-6
    )

    train_losses = []
    val_losses = []
    precision_scores = []
    recall_scores = []
    f1_scores = []
    learning_rates = []

    best_f1 = 0.0
    best_f1_epoch = None  # stays None until the first checkpoint is written
    early_stopping_cnt = 0
    start_time = time.time()

    for epoch in range(num_epochs):
        epoch_start = time.time()
        logger.info(f"===== Epoch {epoch+1}/{num_epochs} =====")

        # Distinct name so the ``lr`` argument is not shadowed.
        train_loss, current_lr = train_epoch(
            model, train_loader, criterion, optimizer, device
        )
        val_loss, val_metrics, preds, labels_all = validate(
            model, val_loader, criterion, device
        )
        scheduler.step()
        epoch_time = time.time() - epoch_start

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        precision_scores.append(val_metrics["precision"])
        recall_scores.append(val_metrics["recall"])
        f1_scores.append(val_metrics["f1"])
        learning_rates.append(current_lr)

        logger.info(
            f"Train Loss: {train_loss:.4f} | "
            f"Val Loss: {val_loss:.4f} | "
            f"Val Precision: {val_metrics['precision']:.2f}% | "
            f"Val Recall: {val_metrics['recall']:.2f}% | "
            f"Val F1: {val_metrics['f1']:.2f}% | "
            f"LR: {current_lr:.6f} | "
            f"Time: {epoch_time:.2f}s"
        )

        if best_f1_epoch is None or val_metrics["f1"] > best_f1:
            best_f1 = val_metrics["f1"]
            best_f1_epoch = epoch + 1
            early_stopping_cnt = 0

            # safetensors rejects non-contiguous tensors, so pack them first.
            state = {k: v.contiguous() for k, v in model.state_dict().items()}
            save_file(state, save_path)

            # Order classes by id once; ids and display names must stay aligned.
            ordered = sorted(train_loader.dataset.label2id.items(), key=lambda kv: kv[1])
            plot_confusion_matrix(
                labels_all, preds,
                labels=[class_id for _, class_id in ordered],
                display_labels=[name for name, _ in ordered],
                top_k=10,
                figsize=(20, 24),
                normalize="true",
                save_path=validation_cm_path
            )
            logger.info(f"✓ Best model saved with F1: {best_f1:.2f}%")
            logger.info(f"✓ Best validation results saved at: {validation_cm_path}")
        else:
            early_stopping_cnt += 1
            if (
                early_stopping_patience is not None
                and early_stopping_cnt >= early_stopping_patience
            ):
                logger.info(
                    f"Early stopping triggered. Best macro F1: {best_f1:.2f}, "
                    f"achieved on epoch {best_f1_epoch}"
                )
                break

    total_time = time.time() - start_time
    logger.info("========== TRAINING END ==========")
    logger.info(f"Best F1: {best_f1:.2f}%")
    logger.info(f"Total Time: {total_time/60:.2f} minutes")
    return (
        train_losses, val_losses, precision_scores,
        recall_scores, f1_scores, learning_rates
    )
def setup_logger(log_dir="logs"):
    """Configure the shared ``"train_logger"`` to log INFO+ to the console and a file.

    The log file is named ``train_<YYYYmmdd_HHMMSS>.log`` inside ``log_dir``.
    Calling this again replaces the handlers already attached to
    ``"train_logger"`` instead of stacking new ones, so each message is emitted
    once per handler.

    Args:
        log_dir: Directory for the log file; created (with parents) if missing.

    Returns:
        ``(logger, log_file)`` where ``log_file`` is the ``Path`` of the new log file.
    """
    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"

    logger = logging.getLogger("train_logger")
    logger.setLevel(logging.INFO)

    # getLogger returns the same object on every call; drop (and close) handlers
    # from earlier calls so messages are not duplicated and files are released.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter(
        "%(asctime)s | %(levelname)s | %(message)s",
        "%Y-%m-%d %H:%M:%S"
    )
    # Explicit UTF-8: training logs contain non-ASCII text (e.g. "✓") that fails
    # to encode under legacy locale encodings such as cp1252.
    file_handler = logging.FileHandler(log_file, encoding="utf-8")
    file_handler.setFormatter(formatter)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)
    return logger, log_file
def seed_worker(worker_id):
    """Re-seed NumPy and ``random`` from PyTorch's initial seed.

    Presumably used as a ``DataLoader`` ``worker_init_fn`` (verify at call site).
    ``worker_id`` is not used. The torch seed is reduced modulo 2**32 because
    ``np.random.seed`` accepts only 32-bit seeds.
    """
    seed32 = torch.initial_seed() % (1 << 32)
    for reseed in (np.random.seed, random.seed):
        reseed(seed32)