"""Training utilities: parameter counting, safetensors checkpointing, and TensorBoard logging."""

import time
from pathlib import Path

import torch
from safetensors.torch import load_file, save_file
from tensorboardX import SummaryWriter

def count_parameters_layerwise(model):
    """Print a per-layer breakdown of trainable parameters and return the total."""
    total_params = 0
    layer_params = {}

    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue

        param_count = parameter.numel()
        layer_params[name] = param_count
        total_params += param_count

    print("\nModel Parameter Summary:")
    print("-" * 60)
    for name, count in layer_params.items():
        print(f"{name}: {count:,} parameters")
    print("-" * 60)
    print(f"Total Trainable Parameters: {total_params:,}\n")

    return total_params

def save_checkpoint(model, filename="checkpoint.safetensors"):
    # Unwrap torch.compile() wrappers so the saved keys match the plain model.
    if hasattr(model, '_orig_mod'):
        model = model._orig_mod

    # Save as safetensors; fall back to .pt if the state dict can't be serialized.
    try:
        save_file(model.state_dict(), filename)
    except Exception:
        torch.save(model.state_dict(), filename.replace('.safetensors', '.pt'))

def load_checkpoint(model, filename="checkpoint.safetensors"):
    # Unwrap torch.compile() wrappers so the loaded keys match the plain model.
    if hasattr(model, '_orig_mod'):
        model = model._orig_mod

    # Prefer the safetensors file; fall back to the legacy .pt checkpoint.
    try:
        model_state = load_file(filename)
    except Exception:
        model_state = torch.load(filename.replace('.safetensors', '.pt'), weights_only=True)
    model.load_state_dict(model_state)

class TBLogger:
    """Thin wrapper around tensorboardX's SummaryWriter for scalar, histogram, and gradient logging."""

    def __init__(self, log_dir='logs/current_run', flush_secs=10, enable_grad_logging=True):
        Path(log_dir).mkdir(parents=True, exist_ok=True)
        self.writer = SummaryWriter(log_dir, flush_secs=flush_secs)
        self.enable_grad_logging = enable_grad_logging
        self.start_time = time.time()

    def log(self, metrics, step=None, model=None, prefix='', grad_checking=False):
        for name, value in metrics.items():
            full_name = f"{prefix}{name}" if prefix else name

            # Scalars (Python numbers or tensors) go to add_scalar;
            # numeric lists/tuples are logged as histograms.
            if isinstance(value, (int, float)):
                self.writer.add_scalar(full_name, value, step)
            elif isinstance(value, torch.Tensor):
                self.writer.add_scalar(full_name, value.item(), step)
            elif isinstance(value, (list, tuple)) and len(value) > 0:
                if all(isinstance(x, (int, float)) for x in value):
                    self.writer.add_histogram(full_name, torch.tensor(value), step)

        if self.enable_grad_logging and model is not None:
            self._log_gradients(model, step, grad_checking)

    def _log_gradients(self, model, step, grad_checking):
        total_norm = 0.0
        for name, param in model.named_parameters():
            if param.grad is None:
                continue

            # grad_checking only gates the nan/inf sanity checks; norms are
            # always logged for parameters that have gradients.
            if grad_checking:
                if torch.isnan(param.grad).any():
                    print(f"Warning: Found nan in gradients for layer: {name}")
                    continue
                if torch.isinf(param.grad).any():
                    print(f"Warning: Found inf in gradients for layer: {name}")
                    continue

            param_norm = param.grad.detach().norm(2)
            self.writer.add_scalar(f"gradients/{name}_norm", param_norm, step)
            total_norm += param_norm.item() ** 2

        if total_norm > 0:
            self.writer.add_scalar("gradients/total_norm", total_norm ** 0.5, step)

    def close(self):
        self.writer.close()
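

# ----------------------------------------------------------------------
# Minimal usage sketch. Everything below is illustrative only: the toy
# nn.Linear model, the synthetic batch, and the file/log-dir names are
# assumptions, not part of the utilities above.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(16, 1)
    count_parameters_layerwise(model)

    logger = TBLogger(log_dir="logs/demo_run")
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for step in range(10):
        x = torch.randn(32, 16)
        y = torch.randn(32, 1)

        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()

        # Log the scalar loss plus per-layer gradient norms for this step.
        logger.log({"loss": loss}, step=step, model=model, prefix="train/")

    save_checkpoint(model, "demo_checkpoint.safetensors")
    load_checkpoint(model, "demo_checkpoint.safetensors")
    logger.close()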