# src/misc_utils.py
import datetime
import json
import os
import shutil

import torch
from transformers import TrainerCallback


def load_config(config_path: str = "config.json") -> dict:
    """Load a JSON config file and return its contents as a dict."""
    with open(config_path, "r") as f:
        return json.load(f)


def get_training_config(config: dict) -> dict:
    """Return the 'training' sub-section of the config."""
    return config["training"]


def get_model_config(config: dict) -> dict:
    """Return model-level keys, excluding training/generation/metadata keys."""
    return {
        k: v
        for k, v in config.items()
        if k not in ["training", "generation", "model_type", "architectures"]
    }


def get_generation_config(config: dict) -> dict:
    """Return the 'generation' sub-section, or an empty dict if absent."""
    return config.get("generation", {})


def clear_cache():
    """Release cached CUDA memory and enable cuDNN autotuning."""
    print("Clearing PyTorch and CUDA caches...")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        print("CUDA cache cleared")
    # Note: this flag does not clear a cache; it tells cuDNN to benchmark and
    # cache the fastest convolution algorithms for fixed input shapes.
    torch.backends.cudnn.benchmark = True
    print("PyTorch cache cleared")


def clear_datasets_cache():
    """Delete the Hugging Face `datasets` cache directory, if it exists."""
    # `datasets` exposes the cache location as a config constant rather than
    # a top-level `get_cache_directory()` function.
    from datasets import config as datasets_config

    try:
        cache_dir = datasets_config.HF_DATASETS_CACHE
        print(f"Clearing datasets cache at: {cache_dir}")
        if os.path.exists(cache_dir):
            shutil.rmtree(cache_dir)
        print("Datasets cache cleared")
    except Exception:
        print("Could not clear datasets cache (may not exist)")


class LossLoggerCallback(TrainerCallback):
    """Append train/eval losses to a tab-separated log file on each log event."""

    def __init__(self, log_file="training_losses.txt", with_timestamp=False):
        self.log_file = log_file
        self.with_timestamp = with_timestamp
        # Create the file up front and write the header row.
        with open(self.log_file, "w") as f:
            if self.with_timestamp:
                f.write("time\tstep\tloss\teval_loss\n")
            else:
                f.write("step\tloss\teval_loss\n")

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is None:
            return
        step = state.global_step
        # Missing values are written as empty fields so the columns stay aligned.
        loss = logs.get("loss")
        eval_loss = logs.get("eval_loss")
        loss_str = "" if loss is None else loss
        eval_str = "" if eval_loss is None else eval_loss
        with open(self.log_file, "a") as f:
            if self.with_timestamp:
                ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                f.write(f"{ts}\t{step}\t{loss_str}\t{eval_str}\n")
            else:
                f.write(f"{step}\t{loss_str}\t{eval_str}\n")


class CheckpointEvery10PercentCallback(TrainerCallback):
    """Save the model and tokenizer at every 10% mark of total training steps."""

    def __init__(self, save_dir, total_steps):
        self.save_dir = save_dir
        self.total_steps = total_steps
        # Precompute the global step at each 10% mark (10%, 20%, ..., 100%),
        # using integer arithmetic to avoid float rounding.
        self.checkpoint_intervals = [total_steps * i // 10 for i in range(1, 11)]
        self.saved_checkpoints = set()
        print(f"Checkpoint intervals: {self.checkpoint_intervals}")

    def on_step_end(self, args, state, control, **kwargs):
        current_step = state.global_step
        for checkpoint_step in self.checkpoint_intervals:
            if (
                current_step == checkpoint_step
                and checkpoint_step not in self.saved_checkpoints
            ):
                checkpoint_dir = os.path.join(
                    self.save_dir, f"checkpoint_10percent_{current_step}"
                )
                print(
                    f"Saving 10% progress checkpoint at step {current_step} "
                    f"to {checkpoint_dir}"
                )
                # The Trainer passes the tokenizer to callbacks as
                # 'processing_class' in recent transformers releases and as
                # 'tokenizer' in older ones. The Trainer itself is never in
                # kwargs, so optimizer/scheduler state is not saved here,
                # only model weights and tokenizer files.
                model = kwargs.get("model")
                tokenizer = kwargs.get("processing_class") or kwargs.get("tokenizer")
                if model is not None:
                    model.save_pretrained(checkpoint_dir)
                if tokenizer is not None:
                    tokenizer.save_pretrained(checkpoint_dir)
                self.saved_checkpoints.add(checkpoint_step)
                print(
                    f"Checkpoint saved at step {current_step} "
                    f"({current_step / self.total_steps * 100:.1f}% completion)"
                )
                break
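

# ---------------------------------------------------------------------------
# Usage sketch (an addition, not part of the original module): a minimal smoke
# test that drives LossLoggerCallback by hand with a synthetic TrainerState,
# so it runs without a model, dataset, or GPU. The file name "demo_losses.txt"
# and the loss values are illustrative placeholders. In a real run, both
# callbacks would instead be attached via
# transformers.Trainer(..., callbacks=[LossLoggerCallback(...),
# CheckpointEvery10PercentCallback(save_dir, total_steps)]).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import TrainerState

    state = TrainerState()
    logger = LossLoggerCallback("demo_losses.txt", with_timestamp=True)

    # Simulate a few training log events; args/control are unused by on_log.
    for step, loss in [(10, 2.31), (20, 1.97), (30, 1.84)]:
        state.global_step = step
        logger.on_log(None, state, None, logs={"loss": loss})

    # Simulate an evaluation log event at the final step.
    logger.on_log(None, state, None, logs={"eval_loss": 1.70})

    with open("demo_losses.txt") as f:
        print(f.read(), end="")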