|
|
import argparse
|
|
|
import os
|
|
|
from pathlib import Path
|
|
|
import torch
|
|
|
from torch.utils.data import Dataset, DataLoader
|
|
|
import numpy as np
|
|
|
from accelerate import Accelerator
|
|
|
from transformers import AutoModelForCausalLM, get_linear_schedule_with_warmup
|
|
|
from torch.optim import AdamW
|
|
|
from tqdm import tqdm
|
|
|
import gc
|
|
|
import traceback
|
|
|
import matplotlib.pyplot as plt
|
|
|
from anticipation.vocab import ANTICIPATE, AUTOREGRESS
|
|
|
|
|
|
|
|
|
def print_gpu_memory_stats():
    """Print allocated, reserved, and peak-allocated CUDA memory per GPU.

    Does nothing when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return
    mib = 1024**2
    for idx in range(torch.cuda.device_count()):
        print(f"GPU {idx} memory allocated: {torch.cuda.memory_allocated(idx) / mib:.2f} MB")
        print(f"GPU {idx} memory reserved: {torch.cuda.memory_reserved(idx) / mib:.2f} MB")
        print(f"GPU {idx} max memory allocated: {torch.cuda.max_memory_allocated(idx) / mib:.2f} MB")
|
|
|
|
|
|
|
|
|
def check_model_for_nans(model):
    """Return True if any parameter of *model* contains a NaN.

    Prints the name of the first offending parameter and stops scanning.
    """
    found_nan = False
    for name, param in model.named_parameters():
        if torch.isnan(param).any():
            print(f"NaN detected in parameter {name}")
            found_nan = True
            break
    return found_nan
|
|
|
|
|
|
|
|
|
# --- Module-level device detection (runs on import) --------------------------
# Chooses CUDA when available and prints a summary of each visible GPU.
# `device` is the module-wide default; main() may override it via
# `global device` when --force_cpu is passed.
if torch.cuda.is_available():
    device = torch.device("cuda")
    device_count = torch.cuda.device_count()
    print(f"✓ CUDA is available with {device_count} device(s)")
    for i in range(device_count):
        device_name = torch.cuda.get_device_name(i)
        print(f" Device {i}: {device_name}")
        props = torch.cuda.get_device_properties(i)
        print(f" - Total memory: {props.total_memory / 1024**3:.2f} GB")
        print(f" - CUDA capability: {props.major}.{props.minor}")
else:
    device = torch.device("cpu")
    print("✗ CUDA is not available! Training will be much slower on CPU.")


print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")
# NOTE(review): torch.version.cuda is None on CPU-only builds, so this can
# print "CUDA version: None".
print(f"CUDA version: {torch.version.cuda}")
|
|
|
|
|
|
class SequencePackedDataset(Dataset):
    """Dataset that packs several tokenized sequences into one fixed-length
    context window (separated by SEPARATOR tokens) to reduce padding waste.

    Exposes `packed_sequences` / `attention_masks` (parallel lists of
    LongTensors of length `context_length`), plus `total_packed` and
    `avg_sequences_per_pack` statistics.
    """

    def __init__(self, file_path, context_length=1024, max_packed_sequences=4):
        """Load data from tokenized file and implement sequence packing

        Args:
            file_path: Path to the tokenized data file
            context_length: Maximum context length (default 1024)
            max_packed_sequences: Maximum number of sequences to pack together (default 4)
        """
        from anticipation.vocab import SEPARATOR, AUTOREGRESS, ANTICIPATE

        # One whitespace-separated integer token sequence per line.
        individual_sequences = []
        with open(file_path, 'r') as f:
            for line in f:
                tokens = list(map(int, line.strip().split()))
                individual_sequences.append(tokens)

        print(f"Loaded {len(individual_sequences)} individual sequences")

        self.packed_sequences = []
        self.attention_masks = []
        self.total_packed = 0
        self.avg_sequences_per_pack = 0
        sequences_per_pack = []

        # Shuffle so each pack mixes sequences from across the corpus.
        import random
        random.shuffle(individual_sequences)

        current_packed = []     # token buffer for the pack being built
        current_positions = []  # (start, end) spans of each sequence in the pack

        def flush_pack():
            """Pad the current buffer to context_length, record it, and reset."""
            nonlocal current_packed, current_positions
            attention_mask = torch.zeros(context_length, dtype=torch.long)
            for start, end in current_positions:
                attention_mask[start:end] = 1
            # NOTE(review): padding positions are masked out of attention but
            # still carry SEPARATOR labels in __getitem__; consider labeling
            # them -100 so the LM loss ignores them -- confirm with training
            # objective before changing.
            if len(current_packed) < context_length:
                padding_length = context_length - len(current_packed)
                current_packed.extend([SEPARATOR] * padding_length)
            self.packed_sequences.append(torch.tensor(current_packed[:context_length], dtype=torch.long))
            self.attention_masks.append(attention_mask)
            sequences_per_pack.append(len(current_positions))
            self.total_packed += 1
            current_packed = []
            current_positions = []

        for sequence in individual_sequences:
            # Every sequence must start with a control flag token.
            control_flag = sequence[0]
            assert control_flag in [AUTOREGRESS, ANTICIPATE], f"Invalid control flag: {control_flag}"
            sequence_content = sequence[1:]

            # Flush when the next sequence (plus a 3-token separator) would
            # overflow the window, or when the pack already holds
            # max_packed_sequences sequences.
            # BUG FIX: the original compared len(sequences_per_pack) -- the
            # number of *completed packs* -- against max_packed_sequences;
            # the per-pack limit must count sequences in the *current* pack
            # (len(current_positions)).
            if current_packed and (len(current_packed) + 3 + len(sequence_content) > context_length or
                                   len(current_positions) >= max_packed_sequences):
                flush_pack()

            start_pos = len(current_packed)
            if current_packed:
                # Three SEPARATOR tokens delimit consecutive sequences.
                current_packed.extend([SEPARATOR, SEPARATOR, SEPARATOR])
                start_pos += 3

            # Re-emit the control flag followed by the sequence body.
            current_packed.append(control_flag)
            current_packed.extend(sequence_content)
            current_positions.append((start_pos, len(current_packed)))

        # Flush the final, partially filled pack.
        if current_packed:
            flush_pack()

        if sequences_per_pack:
            self.avg_sequences_per_pack = sum(sequences_per_pack) / len(sequences_per_pack)

        print(f"Created {len(self.packed_sequences)} packed sequences")
        print(f"Average sequences per pack: {self.avg_sequences_per_pack:.2f}")

    def __len__(self):
        return len(self.packed_sequences)

    def __getitem__(self, idx):
        # Labels mirror input_ids (standard causal-LM objective).
        return {
            "input_ids": self.packed_sequences[idx],
            "attention_mask": self.attention_masks[idx],
            "labels": self.packed_sequences[idx],
        }
|
|
|
|
|
|
def collate_packed_sequences(batch):
    """Collate function for packed sequences that includes attention masks"""
    # Stack each field across the batch into a single tensor.
    return {
        field: torch.stack([example[field] for example in batch])
        for field in ("input_ids", "attention_mask", "labels")
    }
|
|
|
|
|
|
def evaluate_model(model, dataloader, accelerator):
    """Calculate validation loss on a dataset.

    Returns the per-sample average of ``outputs.loss`` across the
    dataloader.  Leaves the model in eval mode; the caller must switch it
    back to train mode.  (``accelerator`` is currently unused here.)
    """
    model.eval()
    running_loss = 0
    seen_samples = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            num_in_batch = batch["input_ids"].size(0)
            # Weight each batch loss by its size so the mean is per-sample.
            running_loss += model(**batch).loss.item() * num_in_batch
            seen_samples += num_in_batch

    return running_loss / seen_samples
|
|
|
|
|
|
def plot_losses(train_losses, val_losses, validation_steps, output_dir):
    """
    Plot training and validation losses and save the figure

    Args:
        train_losses (list): Training loss history
        val_losses (list): Validation loss history
        validation_steps (list): Steps at which validation was performed
        output_dir (Path): Directory to save the plot
    """
    plt.figure(figsize=(10, 6))

    # Training losses are sampled once every 10 steps, so index them 1..N.
    train_x = list(range(1, len(train_losses) + 1))
    plt.plot(train_x, train_losses, label='Training Loss', alpha=0.7, color='blue')
    plt.plot(validation_steps, val_losses, label='Validation Loss',
             linestyle='--', marker='o', markersize=5, color='red')

    plt.xlabel('Steps (x10)')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Write the figure next to the checkpoint and release the figure memory.
    plot_path = output_dir / "loss_plot.png"
    plt.savefig(plot_path)
    plt.close()
    print(f"Loss plot saved to {plot_path}")
|
|
|
|
|
|
def main():
    """Fine-tune a causal LM on tokenized (optionally sequence-packed) data.

    Parses CLI arguments, builds train/validation dataloaders, loads the
    model, and runs a gradient-accumulated training loop with periodic
    validation, checkpointing, and loss plotting.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_file', type=Path, default=Path('./data/train.txt'))
    parser.add_argument('--val_file', type=Path, default=Path('./data/test.txt'))
    parser.add_argument('--model_name', type=str, default='stanford-crfm/music-small-800k')
    parser.add_argument('--output_dir', type=Path, default=Path('./fine_tuned'))
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--val_batch_size', type=int, default=16)
    parser.add_argument('--gradient_accumulation_steps', type=int, default=32)
    parser.add_argument('--learning_rate', type=float, default=3e-5)
    parser.add_argument('--max_steps', type=int, default=3500)
    parser.add_argument('--save_steps', type=int, default=500)
    parser.add_argument('--eval_steps', type=int, default=100)
    parser.add_argument('--warmup_steps', type=int, default=500)
    parser.add_argument('--force_cpu', action='store_true', help='Force CPU usage even if GPU is available')
    parser.add_argument('--reduce_memory', action='store_true', help='Use memory-saving techniques')
    parser.add_argument('--context_length', type=int, default=1024, help='Maximum context length')
    parser.add_argument('--max_packed_sequences', type=int, default=4,
                        help='Maximum number of sequences to pack together (set to 1 to disable packing)')
    args = parser.parse_args()

    # Override the module-level `device` chosen at import time if requested.
    global device
    if args.force_cpu:
        device = torch.device("cpu")
        print("Forcing CPU usage as requested")

    print(f"Effective batch size: {args.batch_size * args.gradient_accumulation_steps}")
    print(f"Final device confirmation: {device}")

    try:
        # bf16 mixed precision only when actually training on GPU.
        mixed_precision = 'bf16' if torch.cuda.is_available() and not args.force_cpu else 'no'
        print(f"Mixed precision mode: {mixed_precision}")

        accelerator = Accelerator(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            cpu=args.force_cpu,
            mixed_precision=mixed_precision,
        )

        os.makedirs(args.output_dir, exist_ok=True)

        print("Initial GPU memory stats:")
        print_gpu_memory_stats()

        # ---- Training dataset ------------------------------------------
        print(f"Loading training dataset from {args.data_file}...")
        if args.max_packed_sequences > 1:
            print(f"Using sequence packing with max {args.max_packed_sequences} sequences per pack")
            train_dataset = SequencePackedDataset(
                args.data_file,
                context_length=args.context_length,
                max_packed_sequences=args.max_packed_sequences
            )
            collate_fn_train = collate_packed_sequences
        else:
            print("Sequence packing disabled - using single sequences")

            from anticipation.vocab import SEPARATOR
            individual_sequences = []
            with open(args.data_file, 'r') as f:
                for line in f:
                    tokens = list(map(int, line.strip().split()))
                    individual_sequences.append(torch.tensor(tokens, dtype=torch.long))

            # Minimal dataset wrapper for the unpacked path.  Assumes every
            # line in the data file has the same token count (no padding is
            # applied here) -- TODO confirm with data preparation.
            class TokenizedDataset(Dataset):
                def __init__(self, sequences):
                    self.sequences = sequences
                    self.sequence_length = len(self.sequences[0]) if self.sequences else 0
                    print(f"Loaded {len(self.sequences)} sequences with length {self.sequence_length}")

                def __len__(self):
                    return len(self.sequences)

                def __getitem__(self, idx):
                    tokens = self.sequences[idx]
                    return {"input_ids": tokens, "labels": tokens}

            train_dataset = TokenizedDataset(individual_sequences)

            def collate_fn_train(batch):
                input_ids = torch.stack([item["input_ids"] for item in batch])
                labels = torch.stack([item["labels"] for item in batch])
                return {"input_ids": input_ids, "labels": labels}

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            collate_fn=collate_fn_train,
            pin_memory=torch.cuda.is_available() and not args.force_cpu,
            num_workers=0,
        )

        # ---- Validation dataset ----------------------------------------
        print(f"Loading validation dataset from {args.val_file}...")
        if args.max_packed_sequences > 1:
            val_dataset = SequencePackedDataset(
                args.val_file,
                context_length=args.context_length,
                max_packed_sequences=args.max_packed_sequences
            )
            collate_fn_val = collate_packed_sequences
        else:
            # TokenizedDataset / collate_fn_train exist because the training
            # branch above took the same `else` path.
            val_sequences = []
            with open(args.val_file, 'r') as f:
                for line in f:
                    tokens = list(map(int, line.strip().split()))
                    val_sequences.append(torch.tensor(tokens, dtype=torch.long))

            val_dataset = TokenizedDataset(val_sequences)
            collate_fn_val = collate_fn_train

        val_dataloader = DataLoader(
            val_dataset,
            batch_size=args.val_batch_size,
            shuffle=False,
            collate_fn=collate_fn_val,
            pin_memory=torch.cuda.is_available() and not args.force_cpu,
            num_workers=0,
        )

        # ---- Model ------------------------------------------------------
        print(f"Loading model {args.model_name}...")
        model_kwargs = {
            "trust_remote_code": True,
            "use_cache": False,
        }

        if args.reduce_memory and torch.cuda.is_available():
            print("Using memory reduction techniques...")
            model_kwargs.update({
                "torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32,
                "low_cpu_mem_usage": True,
            })

        try:
            model = AutoModelForCausalLM.from_pretrained(
                args.model_name,
                **model_kwargs
            )
        except Exception as e:
            # Fall back to a plain load if the extra kwargs are rejected.
            print(f"Error loading model with advanced options: {e}")
            print("Trying with basic options...")
            model = AutoModelForCausalLM.from_pretrained(
                args.model_name,
                trust_remote_code=True,
                use_cache=False
            )

        print("GPU memory after loading model:")
        print_gpu_memory_stats()

        model = model.to(device)
        print(f"Model moved to: {next(model.parameters()).device}")

        optimizer = AdamW(
            model.parameters(),
            lr=args.learning_rate,
            eps=1e-6,
            weight_decay=0.01,
            betas=(0.9, 0.999),
        )

        # Accelerator wraps model/optimizer/dataloaders for the chosen
        # device and mixed-precision mode.
        model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)
        val_dataloader = accelerator.prepare_data_loader(val_dataloader)
        print(f"After accelerator preparation, model device: {next(model.parameters()).device}")

        scheduler = get_linear_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=args.max_steps,
        )

        print("GPU memory before training:")
        print_gpu_memory_stats()

        # Anomaly detection off and cudnn autotune on for throughput.
        torch.autograd.set_detect_anomaly(False)

        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True

        if torch.cuda.is_available():
            print("Clearing CUDA cache before training")
            torch.cuda.empty_cache()
            torch.cuda.set_device(0)

        # ---- Training loop ----------------------------------------------
        print("Starting training...")
        model.train()
        completed_steps = 0
        step = 0  # NOTE(review): unused

        train_losses = []      # sampled once every 10 optimizer steps
        val_losses = []
        validation_steps = []  # x-coordinates (completed_steps // 10) for plotting

        progress_bar = tqdm(total=args.max_steps, desc="Training", disable=False)

        try:
            while completed_steps < args.max_steps:
                for batch in train_dataloader:
                    try:
                        with accelerator.accumulate(model):
                            outputs = model(**batch)
                            loss = outputs.loss

                            # Skip this micro-batch on non-finite loss.
                            if torch.isnan(loss).any() or torch.isinf(loss).any():
                                print(f"WARNING: NaN or Inf loss detected: {loss.item()}")
                                optimizer.zero_grad()
                                continue

                            accelerator.backward(loss)

                            # sync_gradients is True on the final micro-batch
                            # of each accumulation window.
                            if accelerator.sync_gradients:
                                accelerator.clip_grad_norm_(model.parameters(), max_norm=0.5)

                                # Scan gradients for NaNs before stepping.
                                has_nan_grads = False
                                for name, param in model.named_parameters():
                                    if param.grad is not None and torch.isnan(param.grad).any():
                                        print(f"NaN gradient detected in {name}")
                                        has_nan_grads = True
                                        break

                                if has_nan_grads:
                                    print("Skipping update due to NaN gradients")
                                    optimizer.zero_grad()
                                    continue

                                optimizer.step()
                                scheduler.step()
                                optimizer.zero_grad()

                                completed_steps += 1
                                progress_bar.update(1)

                                # Periodic loss logging / NaN parameter check.
                                if completed_steps % 10 == 0:
                                    train_losses.append(loss.item())
                                    print(f"Step: {completed_steps}/{args.max_steps}, Loss: {loss.item():.4f}, "
                                          f"LR: {scheduler.get_last_lr()[0]:.8e}")
                                    if check_model_for_nans(model):
                                        print("NaN parameters detected in model! Training may be unstable.")

                                if completed_steps % 100 == 0:
                                    print_gpu_memory_stats()

                                # Periodic validation.
                                if completed_steps % args.eval_steps == 0:
                                    print(f"\nRunning validation at step {completed_steps}...")
                                    val_loss = evaluate_model(model, val_dataloader, accelerator)
                                    validation_steps.append(completed_steps // 10)
                                    val_losses.append(val_loss)
                                    print(f"Validation Loss: {val_loss:.4f}")
                                    # evaluate_model left the model in eval mode.
                                    model.train()
                                    if torch.cuda.is_available():
                                        torch.cuda.empty_cache()
                                        gc.collect()

                                # Periodic checkpointing with loss history.
                                if completed_steps % args.save_steps == 0:
                                    checkpoint_dir = args.output_dir / f"checkpoint-{completed_steps}"
                                    os.makedirs(checkpoint_dir, exist_ok=True)
                                    unwrapped_model = accelerator.unwrap_model(model)
                                    unwrapped_model.save_pretrained(
                                        checkpoint_dir,
                                        is_main_process=accelerator.is_main_process,
                                        save_function=accelerator.save,
                                    )
                                    print(f"Saved checkpoint to {checkpoint_dir}")
                                    np.savez(
                                        checkpoint_dir / "losses.npz",
                                        train_losses=np.array(train_losses),
                                        val_losses=np.array(val_losses),
                                        validation_steps=np.array(validation_steps)
                                    )
                                    plot_losses(train_losses, val_losses, validation_steps, checkpoint_dir)
                                    if torch.cuda.is_available():
                                        torch.cuda.empty_cache()
                                        gc.collect()

                            # NOTE(review): zeroing gradients on non-sync
                            # micro-batches discards the gradients that
                            # accumulation is supposed to collect -- this
                            # looks like it defeats gradient accumulation;
                            # confirm intent before relying on it.
                            if not accelerator.sync_gradients:
                                optimizer.zero_grad()

                        if completed_steps >= args.max_steps:
                            break

                    except RuntimeError as e:
                        if "CUDA out of memory" in str(e):
                            print(f"CUDA OOM error! Current batch size: {args.batch_size}")
                            print(f"Current memory usage:")
                            print_gpu_memory_stats()
                            print("Consider reducing batch size or model size.")
                            print(f"Error details: {str(e)}")
                            raise
                        elif "nan" in str(e).lower() or "inf" in str(e).lower():
                            # Best-effort recovery: drop the batch and go on.
                            print(f"NaN/Inf error: {str(e)}")
                            print("Trying to recover by skipping this batch...")
                            optimizer.zero_grad()
                            continue
                        else:
                            print(f"Runtime error: {str(e)}")
                            print(traceback.format_exc())
                            raise

        except Exception as e:
            print(f"Error during training: {e}")
            print(traceback.format_exc())
            raise
        finally:
            progress_bar.close()

        # ---- Final validation and save ------------------------------------
        try:
            print("\nRunning final validation...")
            final_val_loss = evaluate_model(model, val_dataloader, accelerator)
            validation_steps.append(completed_steps // 10)
            val_losses.append(final_val_loss)
            print(f"Final validation Loss: {final_val_loss:.4f}")

            final_dir = args.output_dir / "final"
            os.makedirs(final_dir, exist_ok=True)
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(
                final_dir,
                is_main_process=accelerator.is_main_process,
                save_function=accelerator.save,
            )
            print(f"Saved final model to {final_dir}")

            np.savez(
                final_dir / "losses.npz",
                train_losses=np.array(train_losses),
                val_losses=np.array(val_losses),
                validation_steps=np.array(validation_steps)
            )

            plot_losses(train_losses, val_losses, validation_steps, final_dir)

        except Exception as save_error:
            print(f"Error saving final model or generating plot: {save_error}")

    except Exception as setup_error:
        print(f"Error in setup: {setup_error}")
        print(traceback.format_exc())
|
|
|
|
|
|
# Script entry point: run training only when executed directly.
if __name__ == "__main__":
    main()