# ChemQ3MTP-base / ChemQ3MTP / misc_utils.py
# Uploaded by gbyuvd ("Upload base files", revision 379e2d8, verified)
# src/misc_utils.py
import json
import torch
import shutil
import os
import datetime
from transformers import TrainerCallback
def load_config(config_path: str = "config.json") -> dict:
    """Load a JSON configuration file and return it as a dict.

    Args:
        config_path: Path to the JSON config file (defaults to "config.json").

    Returns:
        The parsed configuration dictionary.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Explicit encoding avoids locale-dependent decoding of the config file.
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)
def get_training_config(config: dict) -> dict:
    """Return the "training" section of a loaded config.

    Raises KeyError if the section is missing.
    """
    training_section = config["training"]
    return training_section
def get_model_config(config: dict) -> dict:
    """Return a copy of the config with the non-model sections stripped out."""
    excluded_keys = ("training", "generation", "model_type", "architectures")
    model_config = {}
    for key, value in config.items():
        if key not in excluded_keys:
            model_config[key] = value
    return model_config
def get_generation_config(config: dict) -> dict:
    """Return the "generation" section of the config, or {} if absent."""
    try:
        return config["generation"]
    except KeyError:
        return {}
def clear_cache():
    """Free cached CUDA memory (if a GPU is present) and enable cuDNN autotuning."""
    print("Clearing PyTorch and CUDA caches...")
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        # Release cached allocator blocks back to the driver and wait for
        # any in-flight kernels to finish.
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        print("CUDA cache cleared")
    # NOTE(review): this is a cuDNN autotune flag, not a cache clear — it is
    # set unconditionally (even on CPU-only runs), matching original behavior.
    torch.backends.cudnn.benchmark = True
    print("PyTorch cache cleared")
def clear_datasets_cache():
    """Best-effort removal of the on-disk HuggingFace `datasets` cache.

    Never raises: if `datasets` is missing, the helper is unavailable, or the
    directory cannot be removed, a message is printed and the function returns.
    """
    try:
        # The import lives inside the try so a missing/changed `datasets`
        # API is handled by the best-effort guard instead of crashing.
        # NOTE(review): `get_cache_directory` is not a stable public name in
        # all `datasets` versions — confirm against the pinned version.
        from datasets import get_cache_directory
        cache_dir = get_cache_directory()
        print(f"Clearing datasets cache at: {cache_dir}")
        if os.path.exists(cache_dir):
            shutil.rmtree(cache_dir)
        print("Datasets cache cleared")
    except (ImportError, AttributeError, OSError) as e:
        # Narrow except: a bare `except:` would also swallow SystemExit and
        # KeyboardInterrupt.
        print(f"Could not clear datasets cache (may not exist): {e}")
class LossLoggerCallback(TrainerCallback):
    """Appends per-step train/eval losses to a tab-separated log file.

    The log file is truncated and given a header row at construction time;
    each subsequent ``on_log`` event appends one row.
    """

    def __init__(self, log_file="training_losses.txt", with_timestamp=False):
        self.log_file = log_file
        self.with_timestamp = with_timestamp
        # Start a fresh log with the column header matching the row layout.
        if self.with_timestamp:
            header = "time\tstep\tloss\teval_loss\n"
        else:
            header = "step\tstep\tloss\teval_loss\n".replace("step\tstep", "step")
        with open(self.log_file, "w") as f:
            f.write(header)

    def on_log(self, args, state, control, logs=None, **kwargs):
        # No payload from the trainer means nothing to record.
        if logs is None:
            return
        train_loss = logs.get("loss")
        eval_loss = logs.get("eval_loss")
        # Missing values are written as empty cells, mirroring the header.
        columns = [
            str(state.global_step),
            "" if train_loss is None else str(train_loss),
            "" if eval_loss is None else str(eval_loss),
        ]
        if self.with_timestamp:
            stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            columns.insert(0, stamp)
        with open(self.log_file, "a") as f:
            f.write("\t".join(columns) + "\n")
class CheckpointEvery10PercentCallback(TrainerCallback):
    """Saves model/tokenizer checkpoints at every 10% of total training steps.

    Checkpoints land in ``{save_dir}/checkpoint_10percent_{step}``; each 10%
    mark is saved at most once per training run.
    """

    def __init__(self, save_dir, total_steps):
        """
        Args:
            save_dir: Directory under which checkpoint subfolders are created.
            total_steps: Total number of optimizer steps planned for the run.
        """
        self.save_dir = save_dir
        self.total_steps = total_steps
        # Integer arithmetic avoids the float-rounding drift of
        # int(total_steps * i * 0.1) when computing the 10% marks.
        self.checkpoint_intervals = [total_steps * i // 10 for i in range(1, 11)]
        # Set copy for O(1) membership tests on every training step
        # (small total_steps can produce duplicate marks; the set dedups them).
        self._interval_set = set(self.checkpoint_intervals)
        self.saved_checkpoints = set()
        print(f"Checkpoint intervals: {self.checkpoint_intervals}")

    def on_step_end(self, args, state, control, **kwargs):
        """Save a checkpoint when the current global step hits a 10% mark."""
        current_step = state.global_step
        # Guard clause: not a 10% mark, or this mark was already saved.
        if current_step not in self._interval_set or current_step in self.saved_checkpoints:
            return
        checkpoint_dir = f"{self.save_dir}/checkpoint_10percent_{current_step}"
        print(f"Saving 10% progress checkpoint at step {current_step} to {checkpoint_dir}")
        model = kwargs.get('model')
        # NOTE(review): newer transformers passes the tokenizer to callbacks
        # under 'processing_class'; older versions used 'tokenizer' — confirm
        # against the pinned transformers version.
        tokenizer = kwargs.get('processing_class')
        if model is not None:
            model.save_pretrained(checkpoint_dir)
        if tokenizer is not None:
            tokenizer.save_pretrained(checkpoint_dir)
        # NOTE(review): Trainer does not normally pass itself in callback
        # kwargs, so this is likely dead code — kept for custom trainers.
        trainer = kwargs.get('trainer')
        if hasattr(trainer, 'save_state'):
            trainer.save_state()
        self.saved_checkpoints.add(current_step)
        print(f"Checkpoint saved at step {current_step} ({current_step/self.total_steps*100:.1f}% completion)")