import datetime
import json
import os
import shutil

import torch
from transformers import TrainerCallback


def load_config(config_path: str = "config.json") -> dict:
    """Load the full JSON config from disk."""
    with open(config_path, "r") as f:
        return json.load(f)


def get_training_config(config: dict) -> dict:
    """Return the "training" section of the config."""
    return config["training"]


def get_model_config(config: dict) -> dict:
    """Return the model-architecture keys of the config, dropping the
    training/generation sections and metadata fields."""
    return {k: v for k, v in config.items()
            if k not in ["training", "generation", "model_type", "architectures"]}


def get_generation_config(config: dict) -> dict:
    """Return the "generation" section of the config, or an empty dict if absent."""
    return config.get("generation", {})
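
# The helpers above assume a single config.json that combines a Hugging Face
# model config with extra "training" and "generation" sections, e.g.
# (hypothetical values, not shipped with this module):
#
#   {
#     "model_type": "llama",
#     "architectures": ["LlamaForCausalLM"],
#     "hidden_size": 2048,
#     "training": {"learning_rate": 2e-5, "num_train_epochs": 3},
#     "generation": {"max_new_tokens": 128, "temperature": 0.7}
#   }
#
# get_model_config() strips "training", "generation", "model_type" and
# "architectures" so that only architecture hyperparameters remain.

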
def clear_cache():
    """Release cached GPU memory and enable the cuDNN autotuner."""
    print("Clearing PyTorch and CUDA caches...")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        print("CUDA cache cleared")
    # Note: this does not clear a cache; it enables cuDNN autotuning, which
    # speeds up training when input shapes are fixed.
    torch.backends.cudnn.benchmark = True
    print("PyTorch cache cleared")


def clear_datasets_cache():
    """Delete the Hugging Face datasets cache directory, if present."""
    try:
        # The datasets library exposes its default cache location via datasets.config.
        from datasets import config as datasets_config

        cache_dir = datasets_config.HF_DATASETS_CACHE
        print(f"Clearing datasets cache at: {cache_dir}")
        if os.path.exists(cache_dir):
            shutil.rmtree(cache_dir)
            print("Datasets cache cleared")
    except Exception:
        print("Could not clear datasets cache (may not exist)")


class LossLoggerCallback(TrainerCallback):
    """Logs train/eval loss to a tab-separated text file at every logging step."""

    def __init__(self, log_file="training_losses.txt", with_timestamp=False):
        self.log_file = log_file
        self.with_timestamp = with_timestamp
        # Start a fresh log file with a header row.
        with open(self.log_file, "w") as f:
            if self.with_timestamp:
                f.write("time\tstep\tloss\teval_loss\n")
            else:
                f.write("step\tloss\teval_loss\n")

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is None:
            return
        step = state.global_step
        loss = logs.get("loss")
        eval_loss = logs.get("eval_loss")

        with open(self.log_file, "a") as f:
            if self.with_timestamp:
                ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                f.write(f"{ts}\t{step}\t{loss if loss is not None else ''}\t{eval_loss if eval_loss is not None else ''}\n")
            else:
                f.write(f"{step}\t{loss if loss is not None else ''}\t{eval_loss if eval_loss is not None else ''}\n")


class CheckpointEvery10PercentCallback(TrainerCallback):
    """Saves the model and tokenizer every 10% of the total training steps."""

    def __init__(self, save_dir, total_steps):
        self.save_dir = save_dir
        self.total_steps = total_steps
        # Precompute the step numbers at 10%, 20%, ..., 100% of training.
        self.checkpoint_intervals = []
        for i in range(1, 11):
            checkpoint_step = int(total_steps * i * 0.1)
            self.checkpoint_intervals.append(checkpoint_step)
        self.saved_checkpoints = set()
        print(f"Checkpoint intervals: {self.checkpoint_intervals}")

    def on_step_end(self, args, state, control, **kwargs):
        current_step = state.global_step
        for checkpoint_step in self.checkpoint_intervals:
            if current_step == checkpoint_step and checkpoint_step not in self.saved_checkpoints:
                checkpoint_dir = f"{self.save_dir}/checkpoint_10percent_{current_step}"
                print(f"Saving 10% progress checkpoint at step {current_step} to {checkpoint_dir}")

                model = kwargs.get('model')
                # Newer transformers versions pass the tokenizer to callbacks as
                # 'processing_class'; older ones pass it as 'tokenizer'.
                tokenizer = kwargs.get('processing_class') or kwargs.get('tokenizer')

                if model is not None:
                    model.save_pretrained(checkpoint_dir)
                if tokenizer is not None:
                    tokenizer.save_pretrained(checkpoint_dir)

                # The Trainer itself is not passed to callbacks by default, so this
                # branch only runs if a 'trainer' kwarg is wired in manually.
                if hasattr(kwargs.get('trainer'), 'save_state'):
                    kwargs['trainer'].save_state()

                self.saved_checkpoints.add(checkpoint_step)
                print(f"Checkpoint saved at step {current_step} ({current_step/self.total_steps*100:.1f}% completion)")
                break
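

# Usage sketch: how these callbacks are meant to be wired into a transformers
# Trainer. This is illustrative only; `model` and `train_dataset` are assumed
# to be created elsewhere and are not provided by this module.
#
#   from transformers import Trainer, TrainingArguments
#
#   config = load_config("config.json")
#   args = TrainingArguments(output_dir="out", **get_training_config(config))
#   trainer = Trainer(
#       model=model,
#       args=args,
#       train_dataset=train_dataset,
#       callbacks=[
#           LossLoggerCallback("training_losses.txt", with_timestamp=True),
#           # total_steps should match the run's total optimizer steps.
#           CheckpointEvery10PercentCallback("out", total_steps=1000),
#       ],
#   )
#   trainer.train()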