#!/usr/bin/env python
# coding: utf-8
# In[ ]:
from datetime import datetime
# Model Configuration
# Choose ONE of the following options:
# Option 1: 3B model from HuggingFace (RECOMMENDED - no compatibility issues)
# MODEL_NAME = 'unsloth/Qwen2.5-3B-Instruct'
# Option 2: 7B model from HuggingFace (larger, slower, needs more VRAM)
# MODEL_NAME = 'unsloth/Qwen2.5-7B-Instruct'
# Option 3: 14B model from HuggingFace (larger, slower, needs more VRAM)
# model_name = "Qwen/Qwen2.5-14B-Instruct"
MODEL_NAME = '/home/moein_salimi/PLLMS/unsloth-Qwen2.5-14B-Instruct-bnb-4bit'
LOAD_IN_4BIT = True
LOAD_IN_8BIT = False
USE_VLLM = False
LORA_RANK = 64
LORA_ALPHA = 64
LORA_DROPOUT = 0.05
GPU_MEMORY_UTILIZATION = 1.0
MAX_SEQ_LENGTH = 4096
MAX_PROMPT_LENGTH = 2048
MAX_COMPLETION_LENGTH = MAX_SEQ_LENGTH - MAX_PROMPT_LENGTH
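# With the defaults above, the 4096-token budget splits into 2048 prompt tokens and 2048 completion tokens.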
RESUME_FROM_CHECKPOINT = False
PREVIOUS_RUN_DIR = 'dt11.15.23:13_e20_unsloth_Qwen2.5_3B_Instruct_unsloth_bnb_4bit_bnb_4bit_lr1e-05_t0.7_ε0.2_r64_b16'
RUN_DESC = "SFT_Implementation" # MODIFIED: Added default description for SFT
CUDA_VISIBLE_DEVICES = "0"
# Training Configuration
LEARNING_RATE = 5e-6
ADAM_BETA1 = 0.9
ADAM_BETA2 = 0.99
WEIGHT_DECAY = 0.3
WARMUP_STEPS = 10
LR_SCHEDULER_TYPE = "cosine"
OPTIM = "adamw_torch"
# Validation Configuration
EVAL_STEPS = 512 # Run validation every N steps (validation also runs at the end of each epoch)
SAVE_STEPS = 512 #TODO: Adjust this (eval too)
LOG_VALIDATION = True # Whether to log validation metrics
LOG_TRAIN_EVERY = 1 # Save training log every N completions (not every step)
# Training Loop Settings
PER_DEVICE_TRAIN_BATCH_SIZE = 4
PER_DEVICE_EVAL_BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 4 # try 2, 4 and 8
# [MODIFIED: NUM_GENERATIONS is specific to GRPO/RL. SFT does not generate samples during training.]
# NUM_GENERATIONS = 8
MAX_GRAD_NORM = 0.1
TEMPERATURE = 0.0
NUM_TRAIN_EPOCHS = 6
# DATASET CONFIGURATION
# Set to True to use both UniADILR and COPA.
# Set to False to train/validate on UniADILR ONLY.
MIXED_DATA = True
# Data Configuration
ERROR_LOG_PATH = "error_log.log"
TRAINING_LOG_PATH = "training_log.json"
VALIDATION_LOG_PATH = "validation_log.json"
VALIDATION_METRICS_PATH = "val_metrics.json"
# System Prompt for Abductive Reasoning
SYSTEM_PROMPT_UniADILR = """
You are an expert in logical reasoning and abductive inference. Your task is to identify which sentences from a given context provide the necessary evidence to support or explain a hypothesis.
You will be provided with:
1. A Context containing multiple numbered sentences (sent1, sent2, sent3, etc.)
2. A Hypothesis that needs to be supported or explained
Your goal is to identify which sentence(s) from the context, when combined, provide the logical foundation for the hypothesis through abductive reasoning.
## Instructions:
1. Carefully read all sentences in the context
2. Analyze the hypothesis
3. Identify which sentences, when combined, best explain or support the hypothesis
4. Consider both direct evidence and logical connections
## Output Format:
You MUST provide your answer in the following format:
<think>
[Explain your thought process: why you selected these particular sentences and how they support the hypothesis]
</think>
<answer>
[Sentence numbers only, comma-separated. For example: 5, 13 or 2, 7, 9]
</answer>
CRITICAL: The answer section must contain ONLY the sentence numbers separated by commas. Do not include the word "sent" or any other text.
""".strip()
SYSTEM_PROMPT_balanced_copa_cause_only = """
You are an expert in logical reasoning and abductive inference. Your task is to determine which of two given choices represents the most plausible cause for a given premise.
You will be provided with:
1. A Premise describing a situation or event
2. Two Choices (Choice 1 and Choice 2)
Your goal is to select the choice that best explains WHY the premise happened - identifying the root cause that led to the described situation.
## Instructions:
1. Carefully read the premise
2. Evaluate both choices as potential causes
3. Consider common sense, real-world knowledge, and typical causal relationships when making your decision
4. Select the choice that represents the most plausible and direct cause
## Output Format:
You MUST provide your answer in the following format:
<think>
[Explain your thought process: why one choice should be selected over the other, analyzing the causal relationship between the premise and the choices]
</think>
<answer>
[Either "1" or "2" - just the number, nothing else]
</answer>
CRITICAL: The answer section must contain ONLY the number 1 or 2. Do not include any other text, explanation, or punctuation.
""".strip()
# Random State Configuration
RANDOM_STATE = 3407
TORCH_SEED = 42
NUMPY_SEED = 42
# Environment Configuration
WANDB_DISABLED = "true"
# In[ ]:
import os
import sys
# -----------------------------------------------------------------------------
# 1. PLATFORM DETECTION & PATH SETUP
# -----------------------------------------------------------------------------
# # 1. Check if running on Kaggle
# IS_KAGGLE = os.path.exists('/kaggle/input/abduction-dataset')
# # 2. Check if running on Colab
# IS_COLAB = False
# if not IS_KAGGLE:
# try:
# from google.colab import drive
# IS_COLAB = True
# except ImportError:
# IS_COLAB = False
# print(f"🌍 Environment Detected: {'Kaggle' if IS_KAGGLE else 'Colab' if IS_COLAB else 'Local/Other'}")
# if IS_KAGGLE:
# # --- KAGGLE SETUP ---
# BASE_DATA_DIR = "/kaggle/input/abduction-dataset"
# BASE_OUTPUT_DIR = "/kaggle/working"
# # # Unsloth installation for Kaggle
# !pip install unsloth
# !pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth
# !pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth_zoo
# # --- KAGGLE SETUP ---
# # # Kaggle input directory
# # BASE_DATA_DIR = "/kaggle/input/abduction-dataset"
# # # Kaggle output working directory
# # BASE_OUTPUT_DIR = "/kaggle/working"
# # # Unsloth installation for Kaggle
# # !pip install "unsloth[kaggle-new] @ git+https://github.com/unslothai/unsloth.git"
# # !pip install --no-deps packaging ninja einops flash-attn xformers trl peft accelerate bitsandbytes
# elif IS_COLAB:
# # --- COLAB SETUP ---
# drive.mount('/content/drive')
# # Dataset is uploaded to the root or specific folder
# BASE_DATA_DIR = "./dataset"
# # Output to Google Drive
# BASE_OUTPUT_DIR = "/content/drive/My Drive/Abductive_SFT_Results"
# # Unsloth installation for Colab
# !pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
# !pip install --no-deps xformers trl peft accelerate bitsandbytes
# else:
# --- LOCAL SETUP ---
print("💻 Running Locally")
BASE_DATA_DIR = "./dataset"
BASE_OUTPUT_DIR = "./results_sft_14b"
# Verify installation
import torch
print(f"\n🔥 PyTorch version: {torch.__version__}")
print(f"🎮 CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"🎮 CUDA version: {torch.version.cuda}")
os.makedirs(BASE_OUTPUT_DIR, exist_ok=True)
print(f"\n📂 Data Directory: {BASE_DATA_DIR}")
print(f"💾 Output Directory: {BASE_OUTPUT_DIR}")
#=======================================================================
# Output Configuration
def get_run_name():
"""Generate run name based on configuration"""
model_name = MODEL_NAME.split("/")[-1].replace("-", "_")
if LOAD_IN_8BIT:
model_name += "_8bit"
elif LOAD_IN_4BIT:
model_name += "_bnb_4bit"
now = datetime.now()
# MODIFIED: Removed EPSILON from name and added SFT tag
name = f"SFT_dt{now.strftime('%m.%d.%H:%M')}_e{NUM_TRAIN_EPOCHS}_{model_name}_lr{LEARNING_RATE}_t{TEMPERATURE}_r{LORA_RANK}_b{PER_DEVICE_TRAIN_BATCH_SIZE}"
if RUN_DESC:
name += f"_{RUN_DESC}"
return name
def get_results_dir(run_name=None):
"""Get results directory path based on environment"""
if run_name is None:
run_name = get_run_name()
if RESUME_FROM_CHECKPOINT:
run_name = PREVIOUS_RUN_DIR
# MODIFIED: Use the dynamic base path we set up earlier
return os.path.join(BASE_OUTPUT_DIR, run_name)
# In[3]:
# Environment setup and configuration
import os
import sys
import warnings
warnings.filterwarnings('ignore')
import random
import numpy as np
import torch
# Add current directory to path for imports
sys.path.append('.')
# Set random seeds for reproducibility
random.seed(RANDOM_STATE)
np.random.seed(NUMPY_SEED)
torch.manual_seed(TORCH_SEED)
torch.cuda.manual_seed_all(TORCH_SEED)
print(f"🎲 Random seeds set:")
print(f" Python: {RANDOM_STATE}")
print(f" NumPy: {NUMPY_SEED}")
print(f" PyTorch: {TORCH_SEED}")
# Set environment variables
os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_VISIBLE_DEVICES
os.environ["WANDB_DISABLED"] = WANDB_DISABLED
# MODIFIED: Changed title to reflect SFT pipeline
print("\n🔧 Abductive Reasoning SFT Training Pipeline")
print("=" * 50)
print(f"Configuration loaded:")
print(f" 📦 Model: {MODEL_NAME}")
print(f" 🎯 Batch size: {PER_DEVICE_TRAIN_BATCH_SIZE}")
print(f" 🏃 Epochs: {NUM_TRAIN_EPOCHS}")
print(f" 📈 Learning rate: {LEARNING_RATE}")
# Note: Temperature is primarily used here for validation/evaluation generation,
# as SFT training typically uses standard Cross Entropy loss.
print(f" 🌡️ Temperature: {TEMPERATURE}")
print(f" 🎮 GPU: {CUDA_VISIBLE_DEVICES}")
# In[4]:
# 1. Check the environment variable (what the OS tells the process)
print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
# 2. Check what PyTorch actually sees
print(f"Number of visible GPUs: {torch.cuda.device_count()}")
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
print(f" Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB")
# In[5]:
# Diagnostic: inspect torch internals
import torch, os, sys
print("torch.__version__:", torch.__version__)
print("torch.cuda.is_available():", torch.cuda.is_available())
print("torch.version.cuda:", torch.version.cuda)
try:
print("has attr 'UnsupportedMutationAliasingException':",
hasattr(torch._subclasses.fake_tensor, "UnsupportedMutationAliasingException"))
except Exception as e:
print("Checking attribute failed:", type(e).__name__, e)
# In[6]:
# get_ipython().run_cell_magic('capture', '', 'if USE_VLLM:\n !pip install vllm\n import vllm\n')
# In[7]:
# Import required libraries
import torch
import json
import re
import time
from datasets import Dataset
from unsloth import FastLanguageModel
from trl import SFTTrainer
from transformers import TrainerCallback, TrainingArguments
import matplotlib.pyplot as plt
print("🔍 System Check:")
print("=" * 30)
# Check GPU setup
print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
print(f"Number of visible GPUs: {torch.cuda.device_count()}")
if torch.cuda.is_available():
print(f"✅ GPU Available")
print(f" Current device: {torch.cuda.current_device()}")
print(f" GPU name: {torch.cuda.get_device_name(0)}")
print(f" GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
print("❌ No GPU available!")
print(f"✅ PyTorch version: {torch.__version__}")
# In[8]:
import json
from datasets import Dataset
import os # Make sure os is imported
print("\n📂 Loading Pre-Split Data and Transforming")
print("=" * 40)
LIMITED_VAL = False
LIMITED_TRAIN = False
NUMBER_OF_TRAIN_SAMPLES = 100
NUMBER_OF_VAL_SAMPLES = 10
# MODIFIED: Construct paths dynamically
train_path = os.path.join(BASE_DATA_DIR, 'train_split.json')
val_path = os.path.join(BASE_DATA_DIR, 'val_split.json')
# Load the raw splits from JSON files
print(f"Loading train split from: {train_path}")
with open(train_path, 'r', encoding='utf-8') as f:
raw_train_data = json.load(f)
# FILTERING LOGIC
if MIXED_DATA:
train_data = raw_train_data
print(f" ✅ MIXED_DATA=True: Using all {len(train_data)} samples (UniADILR + COPA)")
else:
# Filter for UniADILR only (entries containing 'context')
train_data = [item for item in raw_train_data if 'context' in item]
print(f" ⚠️ MIXED_DATA=False: Filtered for UniADILR only.")
print(f" 📉 Training samples reduced from {len(raw_train_data)} to {len(train_data)}")
if LIMITED_TRAIN:
train_data = train_data[:NUMBER_OF_TRAIN_SAMPLES]
print(f" ⚠️ Limited training set to {len(train_data)} samples")
print(f"Loading validation split from: {val_path}")
with open(val_path, 'r', encoding='utf-8') as f:
raw_val_data = json.load(f)
# FILTERING LOGIC
if MIXED_DATA:
val_data = raw_val_data
print(f" ✅ MIXED_DATA=True: Using all {len(val_data)} validation samples")
else:
val_data = [item for item in raw_val_data if 'context' in item]
print(f" 📉 Validation samples reduced from {len(raw_val_data)} to {len(val_data)}")
if LIMITED_VAL:
val_data = val_data[:NUMBER_OF_VAL_SAMPLES]
print(f" ⚠️ Limited validation set to {len(val_data)} samples")
def transform_to_prompt_format(example, record_id):
"""
Transform the original JSONL format to the required prompt format.
Handles both UniADILR and balanced_copa_cause_only datasets.
"""
dataset_name = example.get('datasetName', '')
# MODIFIED: Define variables to hold the formatted parts for SFT
system_prompt_content = ""
user_content = ""
assistant_content = "" # For SFT, we need the target response
if dataset_name == 'UniADILR':
# Build the context string for UniADILR
context_lines = []
for key, value in example['context'].items():
context_lines.append(f"{key}: {value}")
context_str = "\n".join(context_lines)
# Create the user prompt for UniADILR
user_content = f"""Context:
{context_str}
Hypothesis:
{example['hypothesis']}
Based on the context and hypothesis above, identify which sentence(s) provide the necessary evidence for the hypothesis."""
system_prompt_content = SYSTEM_PROMPT_UniADILR
# Parse ground truth for UniADILR to format the answer
# The ground truth in example['proof'] is like "sent1 & sent2 -> hypothesis"
proof_str = example['proof']
if '->' in proof_str:
proof_str = proof_str.split('->')[0]
numbers = re.findall(r'sent(\d+)', proof_str)
# Format: "1, 2"
formatted_answer = ", ".join(sorted(numbers, key=int))
# MODIFIED: Construct the target assistant response
# Since we don't have 'thought' traces in the dataset, we provide a generic or empty thought block
# to teach the model the structure, or just the answer if that's preferred.
# Here we mimic the requested format.
assistant_content = f"""<think>
The hypothesis requires evidence from the provided context. I will identify the sentences that support this claim.
</think>
<answer>
{formatted_answer}
</answer>"""
ground_truth = json.dumps(example['proof'])
elif dataset_name == 'balanced_copa_cause_only':
# Create the user prompt for COPA
user_content = f"""Premise: {example['premise']}
Question: {example['question']}
Choice 1: {example['choice1']}
Choice 2: {example['choice2']}
Which choice is the most plausible cause for the premise?"""
system_prompt_content = SYSTEM_PROMPT_balanced_copa_cause_only
# For COPA, the ground truth is the label (0 or 1 -> 1 or 2)
correct_choice = str(example['label'] + 1)
ground_truth = correct_choice
# MODIFIED: Construct the target assistant response for SFT
assistant_content = f"""<think>
I need to identify the most plausible cause for the premise among the two choices.
</think>
<answer>
{correct_choice}
</answer>"""
else:
raise ValueError(f"Unknown dataset name: {dataset_name}")
# Create the prompt structure
# MODIFIED: For SFT, the "prompt" field usually contains the full conversation history including the assistant's turn
# However, trl's SFTTrainer often handles formatting. We will stick to the standard list of dicts format.
prompt = [
{
"role": "system",
"content": system_prompt_content
},
{
"role": "user",
"content": user_content
},
# MODIFIED: Added the assistant turn for SFT training
{
"role": "assistant",
"content": assistant_content
}
]
# Return the transformed example
return {
"prompt": prompt, # Contains System, User, AND Assistant
"record_id": record_id,
"ground_truth": ground_truth,
"reasoning_type": example.get('reasoning_type', 'abduction'),
"dataset_name": dataset_name
}
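# Illustrative sanity check on a hypothetical UniADILR record (not taken from the dataset):
# it verifies the three-turn [system, user, assistant] structure and the "1, 2" answer format.
_demo_record = {
    'datasetName': 'UniADILR',
    'context': {'sent1': 'All metals conduct electricity.', 'sent2': 'Copper is a metal.'},
    'hypothesis': 'Copper conducts electricity.',
    'proof': 'sent1 & sent2 -> hypothesis',
}
_demo = transform_to_prompt_format(_demo_record, record_id=-1)
assert [m['role'] for m in _demo['prompt']] == ['system', 'user', 'assistant']
assert '1, 2' in _demo['prompt'][2]['content']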
# Transform each split
print("\nTransforming train data to prompt format...")
train_transformed = []
for idx, example in enumerate(train_data):
train_transformed.append(transform_to_prompt_format(example, record_id=idx))
print("Transforming validation data to prompt format...")
val_transformed = []
for idx, example in enumerate(val_data):
val_transformed.append(transform_to_prompt_format(example, record_id=idx))
# print("Transforming test data to prompt format...")
# test_transformed = []
# for idx, example in enumerate(test_data):
# test_transformed.append(transform_to_prompt_format(example, record_id=idx))
print(f"✅ Transformed all splits")
# Convert to HuggingFace datasets
print("\nConverting to HuggingFace datasets...")
train_ds = Dataset.from_list(train_transformed)
val_ds = Dataset.from_list(val_transformed)
# test_ds = Dataset.from_list(test_transformed)
# Display the first training example to verify format
print("\n" + "="*80)
print("🔍 FIRST TRAINING EXAMPLE (to verify system prompt)")
print("="*80)
first_example = train_ds[0]
print(f"\n📋 Example keys: {list(first_example.keys())}")
print(f"\n🆔 Record ID: {first_example.get('record_id', 'N/A')}")
print("\n💬 PROMPT STRUCTURE:")
print("-" * 80)
for i, msg in enumerate(first_example['prompt']):
role = msg.get('role', 'unknown')
content = msg.get('content', '')
print(f"\n[Message {i+1}] Role: {role.upper()}")
print("-" * 40)
# Show first 500 characters of content to avoid overwhelming output
if len(content) > 500:
print(f"{content[:500]}...")
print(f"\n... (Content truncated - total length: {len(content)} characters)")
else:
print(content)
print("-" * 40)
# Log the prompt structure to a file
log_file = './prompt_structure_log.txt'
with open(log_file, 'w', encoding='utf-8') as f:
for i, msg in enumerate(first_example['prompt']):
role = msg.get('role', 'unknown')
content = msg.get('content', '')
f.write(f"\n[Message {i+1}] Role: {role.upper()}\n")
f.write("-" * 40 + "\n")
f.write(content + "\n")
f.write("-" * 40 + "\n")
print(f"✅ Prompt structure logged to: {log_file}")
print("\n" + "="*80)
# MODIFIED: Commented out statistics for test_ds since it is not loaded
# total = len(train_ds) + len(val_ds) + len(test_ds)
print(f"\n✅ Datasets loaded, transformed, and ready!")
print(f"\n📈 Dataset Statistics:")
# print(f" Total samples: {total:,}")
print(f" Training samples: {len(train_ds):,}")
print(f" Validation samples: {len(val_ds):,}")
# print(f" Test samples: {len(test_ds):,} ({len(test_ds)/total*100:.0f}%)")
# In[9]:
# Verify loaded datasets
print("\n🛠️ Verifying Loaded Datasets")
print("=" * 35)
# Calculate prompt statistics from loaded datasets
prompt_lengths = []
full_lengths = [] # MODIFIED: New list to track total length (System + User + Assistant) for SFT
for ds in [train_ds, val_ds]:
for example in ds:
# Extract user prompt length from the prompt field
current_full_length = 0
user_found = False
for msg in example['prompt']:
content = msg.get('content', '')
current_full_length += len(content)
if isinstance(msg, dict) and msg.get('role') == 'user':
prompt_lengths.append(len(content))
user_found = True
if user_found:
full_lengths.append(current_full_length)
print(f"✅ Datasets ready for training!")
print(f" Total prompts: {len(prompt_lengths):,}")
print(f" Max user prompt length: {max(prompt_lengths)} characters")
print(f" Average user prompt length: {sum(prompt_lengths)/len(prompt_lengths):.0f} characters")
# MODIFIED: Added stats for full length
print(f" Max full length (est): {max(full_lengths)} characters")
print(f" Average full length (est): {sum(full_lengths)/len(full_lengths):.0f} characters")
print(f"\n Sample keys in training data: {list(train_ds[0].keys())}")
# Show example of answer field
print(f"\n📋 Example answer from first training sample:")
print(f" answer: {train_ds[0]['ground_truth']}")
print(f" answer type: {type(train_ds[0]['ground_truth'])}")
# Show a snippet of the user prompt for context
print(f"\n📝 Example user prompt (first 200 chars):")
for msg in train_ds[0]['prompt']:
if isinstance(msg, dict) and msg.get('role') == 'user':
user_content = msg.get('content', '')
print(f" {user_content[:200]}...")
break
# In[10]:
from unsloth import FastLanguageModel, is_bfloat16_supported
from huggingface_hub import HfApi
import os
from tqdm.auto import tqdm
import time
start_time = time.time()
def format_bytes(bytes_value):
"""Convert bytes to human-readable format"""
for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
if bytes_value < 1024.0:
return f"{bytes_value:.2f} {unit}"
bytes_value /= 1024.0
return f"{bytes_value:.2f} PB"
def get_model_size(model_name):
"""Try to get model size from HuggingFace Hub"""
try:
api = HfApi()
model_info = api.model_info(model_name, files_metadata=True) # file sizes are only populated with files_metadata=True
# Sum up all file sizes
total_size = sum(file.size for file in model_info.siblings if file.size)
return total_size
except Exception:
# Local paths and offline runs land here; the size is simply unknown
return None
# Configure download settings
print("🔧 Configuring Hugging Face Hub download settings...")
print("=" * 60)
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "240"
print("✓ Download timeout: 240 seconds per chunk")
print("✓ Using default retry settings")
print()
# Get model size info
print("📊 Fetching model information...")
model_size = get_model_size(MODEL_NAME)
if model_size:
print(f"✓ Model size: {format_bytes(model_size)}")
print(f"✓ Estimated download time: ~{model_size / (10 * 1024 * 1024):.0f} seconds (at 10 MB/s)")
else:
print("⚠ Could not determine model size")
print()
# Load model with progress tracking
print("🤖 Model Setup")
print("=" * 60)
print(f"📦 Model: {MODEL_NAME}")
print(f"🔢 Max sequence length: {MAX_SEQ_LENGTH}")
print(f"⚙️ Quantization: {'4-bit' if LOAD_IN_4BIT else '8-bit' if LOAD_IN_8BIT else 'None'}")
print(f"🚀 Fast inference (vLLM): {USE_VLLM}")
print(f"💾 GPU memory utilization: {GPU_MEMORY_UTILIZATION}")
print()
print("⏳ Downloading and loading model...")
print(" (This may take several minutes depending on your connection)")
print()
download_start = time.time()
# Create a simple progress indicator (note: nothing calls progress.update() during from_pretrained, so this is only a best-effort placeholder)
class ProgressCallback:
def __init__(self):
self.last_print = time.time()
self.dots = 0
def update(self):
current = time.time()
if current - self.last_print > 2: # Print every 2 seconds
self.dots = (self.dots + 1) % 4
elapsed = current - download_start
print(f"\r Downloading{'.' * (self.dots + 1)}{' ' * (3 - self.dots)} " +
f"[{elapsed:.0f}s elapsed]", end='', flush=True)
self.last_print = current
progress = ProgressCallback()
try:
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=MODEL_NAME,
max_seq_length=MAX_SEQ_LENGTH,
max_length=MAX_SEQ_LENGTH,
load_in_4bit=LOAD_IN_4BIT,
load_in_8bit=LOAD_IN_8BIT,
fast_inference=USE_VLLM,
max_lora_rank=LORA_RANK,
gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
)
print("\r" + " " * 80 + "\r", end='') # Clear progress line
download_time = time.time() - download_start
print(f"✅ Model downloaded and loaded successfully!")
print(f"⏱️ Total time: {download_time:.1f}s ({download_time/60:.1f} minutes)")
if model_size:
avg_speed = model_size / download_time
print(f"📈 Average speed: {format_bytes(avg_speed)}/s")
print()
except Exception as e:
print(f"\n❌ Error loading model: {e}")
raise
# Configure tokenizer
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
print("✓ Configured pad token")
print()
# Apply LoRA
print("🔧 Applying LoRA configuration...")
print("-" * 60)
lora_start = time.time()
model = FastLanguageModel.get_peft_model(
model,
r=LORA_RANK,
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
],
lora_alpha=LORA_ALPHA,
# MODIFIED: Added dropout for regularization
lora_dropout=LORA_DROPOUT, # Set to 0.05 or 0.1 to prevent overfitting
use_gradient_checkpointing="unsloth",
random_state=RANDOM_STATE
)
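# Note: with LORA_ALPHA == LORA_RANK (64/64), the effective LoRA scaling factor alpha/r is 1.0.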
lora_time = time.time() - lora_start
print(f"✅ LoRA configured successfully! ({lora_time:.1f}s)")
print()
# Model statistics
print("📊 Model Statistics")
print("=" * 60)
print(f"🎯 LoRA Configuration:")
print(f" • Rank (r): {LORA_RANK}")
print(f" • Alpha: {LORA_ALPHA}")
print(f" • Target modules: 7 (q, k, v, o, gate, up, down projections)")
print()
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
frozen_params = total_params - trainable_params
print(f"🔢 Parameters:")
print(f" • Total: {total_params:,}")
print(f" • Trainable: {trainable_params:,} ({100*trainable_params/total_params:.2f}%)")
print(f" • Frozen: {frozen_params:,} ({100*frozen_params/total_params:.2f}%)")
print()
total_setup_time = time.time() - start_time
print(f"⏱️ Total Setup Time: {total_setup_time:.1f}s ({total_setup_time/60:.1f} minutes)")
print(f" • Model download/load: {download_time:.1f}s")
print(f" • LoRA configuration: {lora_time:.1f}s")
print("=" * 60)
print("✨ Ready to train!")
# In[ ]:
import logging
# MODIFIED: Changed title to reflect SFT context (Metrics instead of Reward)
# Setup evaluation metrics and output directories
print("\n🎯 Evaluation Metrics & Output Setup")
print("=" * 30)
# Create run name and get deterministic results directory (no renaming later)
run_name = get_run_name()
results_dir = get_results_dir(run_name)
# Create the directory structure first
os.makedirs(results_dir, exist_ok=True)
os.makedirs(os.path.join(results_dir, "checkpoint"), exist_ok=True)
logging.basicConfig(
filename=os.path.join(results_dir, ERROR_LOG_PATH),
level=logging.WARNING,
format='%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s() - %(message)s'
)
print(f"📁 Results directory: {results_dir}")
print(f"🏷️ Run name: {run_name}")
# Define deterministic metric function (formerly reward function)
def extract_sentence_numbers(text, datasetName):
"""Extract sentence numbers from model output if UniADILR, if not no need to do anything.
Looks for content within <answer> tags and extracts comma-separated numbers.
Returns a set of integers.
"""
# Try to find answer tags
answer_match = re.search(r'<answer>\s*([^<]+?)\s*</answer>', text, re.IGNORECASE | re.DOTALL)
if answer_match:
answer_content = answer_match.group(1)
else:
# If no tags found, use the entire text
answer_content = "" # if it's empty the reward/score will be 0
if datasetName == 'UniADILR':
# Extract all numbers from the answer content
numbers = re.findall(r'\b(\d+)\b', answer_content)
return set(int(n) for n in numbers)
elif datasetName == 'balanced_copa_cause_only':
if answer_content == '':
return -123 # sentinel: no answer found; never matches a valid COPA label
return int(answer_content)
else:
try:
return int(answer_content)
except:
return -1
def parse_proof(proof_str, datasetName):
"""
If datasetName is 'UniADILR', parse the ground-truth proof string and extract the sentence numbers.
Example: 'sent5 & sent13 -> hypothesis' returns {5, 13}
If datasetName is 'balanced_copa_cause_only', just return the answer number as an int.
Example: '2' returns 2
"""
if datasetName == 'UniADILR':
# Extract sentence numbers from proof (before '->')
if '->' in proof_str:
proof_str = proof_str.split('->')[0]
numbers = re.findall(r'sent(\d+)', proof_str)
return set(int(n) for n in numbers)
elif datasetName == 'balanced_copa_cause_only':
return int(proof_str)
else:
return int(proof_str)
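# Illustrative checks of the two parsers above, on hypothetical strings that follow the
# formats described in their docstrings:
assert parse_proof('sent5 & sent13 -> hypothesis', 'UniADILR') == {5, 13}
assert parse_proof('2', 'balanced_copa_cause_only') == 2
assert extract_sentence_numbers("<think>...</think>\n<answer>\n5, 13\n</answer>", 'UniADILR') == {5, 13}
assert extract_sentence_numbers("<answer>2</answer>", 'balanced_copa_cause_only') == 2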
class AbductiveRewardFunction:
"""
MODIFIED: In SFT, this class acts as a Metrics Calculator for Validation.
It calculates exact-match accuracy, which is logged as 'reward' to stay compatible
with the existing visualization tools.
"""
def __init__(self, dataset, tokenizer, output_path, log_every=50):
self.dataset = dataset # Keep for validation only
self.tokenizer = tokenizer
self.output_path = output_path
self.current_epoch = 1
self.training_log = []
self.step_losses = []
self.log_every = log_every
print("🛠️ Building prompt-to-[ground_truth, datasetName] lookup table for evaluation...")
self.lookup_table = {}
missing_ground_truths = 0
flag = False
for record in self.dataset:
# We must apply the chat template exactly as the trainer will.
# In SFT formatting (Part 2), prompt[1] is the user message.
if not flag:
print(f"prompt before apply chat template 1: {record['prompt'][1]['content']}")
prompt_text = record['prompt'][1]['content']
datasetName = record['dataset_name']
if not flag:
print(f"prompt_text: {prompt_text}")
flag = True
ground_truth = record.get('ground_truth', '')
if ground_truth:
# If multiple records have the exact same prompt, this will overwrite.
# This is usually fine if the ground_truth is also the same.
self.lookup_table[prompt_text] = [ground_truth, datasetName]
else:
missing_ground_truths += 1
print(f"✅ Lookup table built. Contains {len(self.lookup_table)} entries.")
if missing_ground_truths > 0:
print(f" ⚠️ Warning: {missing_ground_truths} records in the dataset were missing a 'ground_truths' field.")
def set_epoch(self, epoch):
self.current_epoch = epoch
def record_loss(self, step, loss):
self.step_losses.append({"step": step, "loss": loss})
def __call__(self, completions, prompts, **kwargs):
"""
Calculate accuracy (reward) using the pre-computed lookup table.
Used during validation generation steps.
Args:
completions: List of generated text strings for each prompt in the batch.
prompts: List of the formatted input text strings (or list of dicts).
"""
rewards = []
# The `prompts` and `completions` are flattened lists
for i, (prompt_text, completion_text) in enumerate(zip(prompts, completions)):
try:
# prompt_text[0]['content'] ==> system prompt content
# prompt_text[1]['content'] ==> user prompt content
# MODIFIED: SFT prompts include assistant response, but lookup uses user content
user_content = prompt_text[1]['content']
content_of_look_up_table = self.lookup_table.get(user_content)
if content_of_look_up_table is None:
# Fallback or error logging
logging.warning(f"Prompt not found in lookup table. User content snippet: {user_content[:50]}...")
rewards.append(0.0)
continue
ground_truth_proof = content_of_look_up_table[0]
datasetName = content_of_look_up_table[1]
ground_truth = parse_proof(ground_truth_proof, datasetName)
# Extract predicted sentence numbers from the model's completion
# Note: completion_text might be a list or string depending on vllm vs standard generation
# If it's the standard generation list from Part 5 callback:
actual_text = completion_text
if isinstance(completion_text, list):
actual_text = completion_text[0]['content'] if isinstance(completion_text[0], dict) else completion_text[0]
predicted = extract_sentence_numbers(actual_text, datasetName)
# Calculate reward (1.0 if exact match, 0.0 otherwise)
reward = 1.0 if predicted == ground_truth else 0.0
rewards.append(reward)
if datasetName == 'UniADILR':
log_entry = {
'epoch': self.current_epoch,
'batch_idx': i,
'dataset_name': datasetName,
'input': prompt_text,
'ground_truth': sorted(list(ground_truth)),
'predicted': sorted(list(predicted)),
'reward': reward,
'completion': completion_text,
}
elif datasetName == 'balanced_copa_cause_only':
log_entry = {
'epoch': self.current_epoch,
'batch_idx': i,
'dataset_name': datasetName,
'input': prompt_text,
'ground_truth': ground_truth,
'predicted': predicted,
'reward': reward,
'completion': completion_text,
}
else:
log_entry = {
'epoch': self.current_epoch,
'batch_idx': i,
'dataset_name': datasetName,
'input': prompt_text,
'ground_truth': ground_truth,
'predicted': predicted,
'reward': reward,
'completion': completion_text,
}
self.training_log.append(log_entry)
except Exception as e:
logging.exception(f"Error calculating metric for item {i}: {e}")
rewards.append(0.0)
# Save log periodically
if len(self.training_log) > 0 and len(self.training_log) % self.log_every == 0:
try:
with open(self.output_path, 'w', encoding='utf-8') as f:
json.dump(self.training_log, f, ensure_ascii=False, indent=2)
recent_rewards = [r['reward'] for r in self.training_log[-self.log_every:]]
avg_reward = sum(recent_rewards) / len(recent_rewards) if recent_rewards else 0.0
print(f" 💾 Saved {len(self.training_log)} completions log | Recent avg accuracy: {avg_reward:.3f}")
except Exception as e:
logging.warning(f"Failed to save validation log: {e}")
return rewards
def evaluate_batch(self, completions, record_ids, validation_dataset=None):
"""Evaluate a batch of completions against ground truth.
Args:
completions: List of model outputs
record_ids: List of indices into the dataset
validation_dataset: Optional validation dataset
Returns:
List of dicts with reward, predicted, ground_truth, etc.
"""
results = []
# 1. Determine which dataset to use for evaluation.
dataset_to_use = validation_dataset if validation_dataset is not None else self.dataset
# 2. Fetch the specific records from the dataset using the provided record_ids.
try:
records = [dataset_to_use[i] for i in record_ids]
except (IndexError, TypeError) as e:
logging.error(f"Failed to fetch records for evaluation using record_ids. Error: {e}")
return []
for idx, (completion, record) in enumerate(zip(completions, records)):
try:
ground_truth_numbers = record.get('ground_truth', '')
datasetName = record.get('dataset_name', '')
ground_truth = parse_proof(ground_truth_numbers, datasetName)
# Extract predicted sentence numbers
predicted = extract_sentence_numbers(completion, datasetName)
# Calculate reward (Accuracy for SFT)
reward = 1.0 if predicted == ground_truth else 0.0
# Extract input for logging
input_prompt = record.get('prompt', [])
user_content = ""
for msg in input_prompt:
if isinstance(msg, dict) and msg.get('role') == 'user':
user_content = msg.get('content', '')
break
# Construct result dict
result_entry = {
'reward': reward,
'predicted': sorted(list(predicted)) if isinstance(predicted, set) else predicted,
'ground_truth': sorted(list(ground_truth)) if isinstance(ground_truth, set) else ground_truth,
'completion': completion,
'input': user_content,
'dataset_name': datasetName,
}
results.append(result_entry)
# Log entry (saved separately by validation callback)
log_entry = {
'epoch': self.current_epoch,
'record_id': record.get('record_id', idx),
'dataset_name': datasetName,
'input': user_content,
'ground_truth': result_entry['ground_truth'],
'predicted': result_entry['predicted'],
'reward': reward,
'completion': completion,
}
self.training_log.append(log_entry)
except Exception as e:
logging.exception(f"Error evaluating completion {idx}: {e}")
results.append({
'reward': 0.0,
'predicted': [],
'ground_truth': [],
'completion': completion,
'input': '',
'dataset_name': record.get('dataset_name', ''),
})
return results
# Create evaluation metrics object (formerly reward function)
# MODIFIED: We reuse the class for SFT validation logging
reward_fn = AbductiveRewardFunction(
dataset=train_ds,
tokenizer=tokenizer,
output_path=os.path.join(results_dir, TRAINING_LOG_PATH),
log_every=LOG_TRAIN_EVERY
)
reward_fn.__name__ = "AbductiveValidationScorer"
print(f"✅ Validation Metrics configured")
print(f" Type: Exact match (order-independent)")
print(f" Output file: {TRAINING_LOG_PATH}")
print(f" Log frequency: Every {LOG_TRAIN_EVERY} completions")
# In[ ]:
from unsloth import is_bfloat16_supported
from collections import defaultdict # <--- FIXED: Added this missing import
# Training configuration
print("\n⚙️ Training Configuration")
print("=" * 30)
# MODIFIED: Switched from GRPOConfig to TrainingArguments for SFT
training_args = TrainingArguments(
learning_rate=LEARNING_RATE,
adam_beta1=ADAM_BETA1,
adam_beta2=ADAM_BETA2,
weight_decay=WEIGHT_DECAY,
warmup_steps=WARMUP_STEPS,
lr_scheduler_type=LR_SCHEDULER_TYPE,
dataloader_num_workers = 0,
optim=OPTIM,
logging_steps=1,
# MODIFIED: Save Strategy Configuration
save_strategy="epoch", # Save a checkpoint at the end of every epoch
save_total_limit=None, # Keep all checkpoints (set to a number like 5 to limit disk usage)
load_best_model_at_end=False, # We handle "Best Model" manually in the callback based on Accuracy
per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
# [MODIFIED: The following parameters are GRPO-specific and removed for SFT]
# num_generations=NUM_GENERATIONS,
# max_prompt_length=MAX_PROMPT_LENGTH,
# max_completion_length=MAX_COMPLETION_LENGTH,
# temperature=TEMPERATURE,
# epsilon=EPSILON,
# beta=BETA,
num_train_epochs=NUM_TRAIN_EPOCHS,
# save_steps=SAVE_STEPS,
max_grad_norm=MAX_GRAD_NORM,
report_to=["tensorboard"],
run_name=None,
output_dir=os.path.join(results_dir, "checkpoint"),
# MODIFIED: Added precision settings explicitly for Unsloth SFT
fp16 = not is_bfloat16_supported(),
bf16 = is_bfloat16_supported(),
)
print(f"Training Parameters:")
print(f" Learning rate: {LEARNING_RATE}")
print(f" Batch size: {PER_DEVICE_TRAIN_BATCH_SIZE}")
print(f" Epochs: {NUM_TRAIN_EPOCHS:,}")
print(f" Save every: {SAVE_STEPS} steps")
print(f" Max grad norm: {MAX_GRAD_NORM}")
# print(f" Temperature: {TEMPERATURE}") # Not used in SFT training
print(f" Warmup steps: {WARMUP_STEPS}")
print(f" Weight decay: {WEIGHT_DECAY}")
# In[13]:
from transformers import DataCollatorWithPadding
from tqdm.auto import tqdm # Added tqdm for progress tracking
from collections import defaultdict # <--- FIXED: Added this missing import
# MODIFIED: Handle VLLM import gracefully. If not installed, define a simple config class.
try:
from vllm import SamplingParams
except ImportError:
class SamplingParams:
def __init__(self, temperature, top_p, max_tokens):
self.temperature = temperature
self.top_p = top_p
self.max_tokens = max_tokens
print("\n🔄 Setting up Training Callbacks with Validation")
print("=" * 45)
# MODIFIED: Set temperature to 0.0 for deterministic greedy decoding during validation
sampling_params = SamplingParams(
temperature=0.0,
top_p=1.0, # Set top_p to 1.0 when using greedy search
max_tokens=MAX_COMPLETION_LENGTH,
)
class EnhancedEpochCallback(TrainerCallback):
"""
Custom callback to log epoch progress, manage metrics, and handle validation.
- Logs start and end of each epoch.
- Records step losses for the metric function.
- Triggers validation at the end of each epoch and after every EVAL_STEPS steps.
"""
def __init__(self, reward_fn, val_dataset, results_dir, use_vllm=False, eval_interval=EVAL_STEPS):
self.reward_fn = reward_fn
self.val_dataset = val_dataset
self.step_count = 0
self.start_time = None
self.validation_metrics = {}
self.results_dir = results_dir
self.trainer = None
self.formatted_inputs = None
self.use_vllm = use_vllm
self.data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
self.eval_interval = eval_interval
# MODIFIED: Track best model
self.best_accuracy = -1.0
def on_train_begin(self, args, state, control, **kwargs):
self.start_time = time.time()
print(f"🚀 Training started at {time.strftime('%Y-%m-%d %H:%M:%S')}")
# MODIFIED: For SFT Validation, we must PREPARE inputs by removing the Assistant's answer.
# The dataset 'prompt' field contains [System, User, Assistant].
# We only want to feed [System, User] to the model for generation.
print(" Preparing validation prompts (stripping assistant answers)...")
val_prompts_input_only = []
for conversation in self.val_dataset['prompt']:
# Keep only system and user messages
input_msgs = [msg for msg in conversation if msg['role'] != 'assistant']
val_prompts_input_only.append(input_msgs)
self.formatted_inputs = self.trainer.processing_class.apply_chat_template(
val_prompts_input_only, # Use the input-only prompts
tokenize=False,
add_generation_prompt=True
)
def on_epoch_begin(self, args, state, control, **kwargs):
epoch_idx = int(state.epoch) + 1 # Convert to 1-indexed
self.reward_fn.set_epoch(epoch_idx)
print(f"\n📍 Starting epoch {epoch_idx}")
def on_step_end(self, args, state, control, **kwargs):
current_loss = 'N/A'
if state.log_history:
current_loss = state.log_history[-1].get("loss", 'N/A')
if current_loss != 'N/A':
self.reward_fn.record_loss(state.log_history[-1]['step'], current_loss)
self.step_count += 1
if self.step_count % 50 == 0:
elapsed = time.time() - self.start_time
steps_per_sec = self.step_count / elapsed
print(f" Step {self.step_count} | Loss: {current_loss} | Speed: {steps_per_sec:.2f} steps/s")
if (
LOG_VALIDATION
and self.eval_interval
and (self.step_count % self.eval_interval == 0)
and self.trainer
):
self.evaluate_validation(
self.trainer.model,
self.trainer.processing_class,
state.global_step,
)
def evaluate_validation(self, model, tokenizer, step, epoch_override=None):
print(f"\n🔍 Validation at step {step}:")
try:
val_rewards = []
validation_log = []
# NEW: Dictionary to track accuracy per dataset
dataset_stats = defaultdict(list)
batch_size = PER_DEVICE_EVAL_BATCH_SIZE
# Added tqdm for progress tracking
total_samples = len(self.val_dataset)
batch_iterator = range(0, total_samples, batch_size)
# Ensure prompts are prepared if called before training
if self.formatted_inputs is None:
val_prompts_input_only = []
for conversation in self.val_dataset['prompt']:
input_msgs = [msg for msg in conversation if msg['role'] != 'assistant']
val_prompts_input_only.append(input_msgs)
self.formatted_inputs = tokenizer.apply_chat_template(
val_prompts_input_only, tokenize=False, add_generation_prompt=True
)
with torch.no_grad():
for batch_num in tqdm(batch_iterator, desc=" Generating & Evaluating", unit="batch", leave=False):
FastLanguageModel.for_inference(model)
batch = self.formatted_inputs[batch_num:batch_num + batch_size]
if self.use_vllm:
outputs = model.fast_generate(
batch,
lora_request=None,
sampling_params=sampling_params,
)
completions = [o.outputs[0].text.strip() for o in outputs]
else:
batch_encodings = tokenizer(batch, return_tensors="pt", padding=True).to(model.device)
# --- MODIFIED: Handle Greedy Decoding Correctly for HuggingFace ---
gen_kwargs = {
"max_new_tokens": sampling_params.max_tokens,
}
# If temperature is 0 (or very small), force greedy decoding by setting do_sample=False
if sampling_params.temperature < 1e-5:
gen_kwargs["do_sample"] = False
else:
gen_kwargs["do_sample"] = True
gen_kwargs["temperature"] = sampling_params.temperature
gen_kwargs["top_p"] = sampling_params.top_p
outputs = model.generate(**batch_encodings, **gen_kwargs)
prompt_lengths = batch_encodings["input_ids"].shape[1]
generated_tokens = outputs[:, prompt_lengths:]
completions = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
batch_indices = list(range(batch_num, batch_num + len(completions)))
results = self.reward_fn.evaluate_batch(completions, batch_indices, validation_dataset=self.val_dataset)
for batch_idx, result in enumerate(results):
reward = result["reward"]
ds_name = result["dataset_name"]
val_rewards.append(reward)
dataset_stats[ds_name].append(reward) # Track per dataset
validation_log.append({
"record_id": self.val_dataset['record_id'][batch_num + batch_idx],
"dataset_name": ds_name,
"input": result.get("input", ""),
"ground_truth": result["ground_truth"],
"predicted": result["predicted"],
"reward": reward, # In SFT this is Accuracy (1.0 or 0.0)
"completion": result["completion"],
})
FastLanguageModel.for_training(model)
if val_rewards:
avg_val_reward = sum(val_rewards) / len(val_rewards)
# NEW: Print detailed breakdown
print(f" 📊 Overall Accuracy: {avg_val_reward:.4f}")
print(" 📈 Breakdown by Dataset:")
for name, scores in dataset_stats.items():
avg_score = sum(scores) / len(scores)
print(f" • {name}: {avg_score:.4f} (n={len(scores)})")
# MODIFIED: Save Best Model Logic
if avg_val_reward > self.best_accuracy:
self.best_accuracy = avg_val_reward
print(f" 🌟 New Best Accuracy: {self.best_accuracy:.4f}")
if self.trainer:
best_model_path = os.path.join(self.results_dir, "best_model")
print(f" 💾 Saving best model to: {best_model_path}")
self.trainer.save_model(best_model_path)
# Also save tokenizer for convenience
if tokenizer:
tokenizer.save_pretrained(best_model_path)
# Determine epoch key (handle pre-training case where trainer.state might not exist)
if epoch_override is not None:
epoch_key = str(epoch_override)
elif self.trainer:
epoch_key = str(int(self.trainer.state.epoch))
else:
epoch_key = "0"
self.validation_metrics[epoch_key] = {
'avg_reward': avg_val_reward,
'num_samples': len(val_rewards),
'breakdown': {k: sum(v)/len(v) for k, v in dataset_stats.items()}
}
# Save logs (append if existing)
val_log_path = os.path.join(self.results_dir, VALIDATION_LOG_PATH)
existing_data = {}
if os.path.exists(val_log_path):
with open(val_log_path, "r", encoding="utf-8") as f:
existing_data = json.load(f)
existing_data[epoch_key] = validation_log
with open(val_log_path, "w", encoding="utf-8") as f:
json.dump(existing_data, f, ensure_ascii=False, indent=2)
# Save metrics
val_metrics_path = os.path.join(self.results_dir, VALIDATION_METRICS_PATH)
all_metrics = {}
if os.path.exists(val_metrics_path):
with open(val_metrics_path, "r", encoding="utf-8") as f:
all_metrics = json.load(f)
all_metrics[epoch_key] = self.validation_metrics[epoch_key]
with open(val_metrics_path, "w", encoding="utf-8") as f:
json.dump(all_metrics, f, ensure_ascii=False, indent=2)
# Save training log from reward fn
try:
with open(self.reward_fn.output_path, 'w', encoding='utf-8') as f:
json.dump(self.reward_fn.training_log, f, ensure_ascii=False, indent=2)
except Exception as e:
logging.warning(f"Failed to save training log: {e}")
else:
logging.warning(f"⚠️ No validation rewards computed. Step: {step}")
except Exception as e:
logging.exception(f"❌ Validation error: {e}")
def on_epoch_end(self, args, state, control, **kwargs):
completed_epoch_idx = int(state.epoch)
print(f"✅ Completed epoch {completed_epoch_idx}")
# Trigger validation at the end of the epoch
if LOG_VALIDATION:
if self.trainer:
# We use state.global_step to be consistent with Hugging Face's tracking
self.evaluate_validation(self.trainer.model, self.trainer.processing_class, state.global_step)
else:
logging.warning("⚠️ No trainer assigned; cannot evaluate validation.")
def on_save(self, args, state, control, **kwargs):
print(f"💾 Checkpoint saved at step {state.global_step}")
# Initialize callback
enhanced_callback = EnhancedEpochCallback(
reward_fn=reward_fn,
val_dataset=val_ds,
results_dir=results_dir,
use_vllm=USE_VLLM,
# eval_interval=EVAL_STEPS, # note: the constructor default is already EVAL_STEPS, so step-based validation runs in addition to epoch-end validation
)
print("✅ Enhanced callbacks configured:")
print(" - Epoch management")
print(" - Progress tracking with loss")
print(" - Validation evaluation (Assistant answers stripped for generation)")
print(" - Validation JSON logging")
print(" - Checkpoint notifications")
print(f" - Validation every {EVAL_STEPS} steps")
# In[14]:
# Create trainer with enhanced validation
print("\n🏗️ Creating Trainer with Validation")
print("=" * 35)
# Formatting function for Unsloth's SFTTrainer
# MUST return a list of strings (even for single example)
def formatting_prompts_func(example):
"""
Format prompts for SFT training.
Unsloth calls this with a SINGLE example during validation check,
and with batched examples during training.
Must ALWAYS return a list of strings.
"""
convos = example["prompt"]
texts = []
# Safety check for empty data
if not convos:
return []
# Detect if this is a single example or batch
# Single example: example["prompt"] = [{"role": "system", ...}, {"role": "user", ...}, ...]
# Batch: example["prompt"] = [[{...}, {...}], [{...}, {...}], ...]
if isinstance(convos[0], dict):
# Single conversation - convos IS the conversation
text = tokenizer.apply_chat_template(convos, tokenize=False, add_generation_prompt=False)
texts.append(text)
else:
# Batch of conversations - convos is a list of conversations
for convo in convos:
text = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
texts.append(text)
return texts # Always return list of strings
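# Minimal illustrative check of the two call shapes described in the docstring, using a
# hypothetical two-turn conversation (the loaded tokenizer supplies the chat template):
_single = {"prompt": [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}]}
_batched = {"prompt": [_single["prompt"], _single["prompt"]]}
assert len(formatting_prompts_func(_single)) == 1
assert len(formatting_prompts_func(_batched)) == 2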
try:
# SFTTrainer setup - same for both platforms now
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=train_ds,
eval_dataset=val_ds,
formatting_func=formatting_prompts_func,
args=training_args,
packing=False,
max_seq_length=MAX_SEQ_LENGTH,
)
enhanced_callback.trainer = trainer
trainer.add_callback(enhanced_callback)
print("✅ Trainer created successfully!")
print(f" Platform: Local")
print(f" Model: {type(model).__name__}")
print(f" Training samples: {len(train_ds):,}")
print(f" Validation samples: {len(val_ds):,}")
print(f" Callbacks: {len(trainer.callback_handler.callbacks)}")
except Exception as e:
logging.exception(f"❌ Failed to create trainer: {e}")
raise
print(f"\n📋 Training Summary:")
print(f" Total training epochs: {NUM_TRAIN_EPOCHS}")
print(f" Effective batch size: {PER_DEVICE_TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")
print(f" Gradient accumulation: {GRADIENT_ACCUMULATION_STEPS}")
print(f" Output directory: {results_dir}")
# In[ ]:
import sys
from datetime import datetime
import signal
# Set up console + file logging for the training loop.
# NOTE: logging.basicConfig() is a no-op when the root logger was already configured
# (it was, in the error-log cell above); pass force=True (Python 3.8+) if this
# configuration should actually take effect.
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('training_log.log'),
logging.StreamHandler(sys.stdout)
]
)
# Add a custom callback for better progress tracking
from transformers import TrainerCallback
import math
class DetailedProgressCallback(TrainerCallback):
def __init__(self):
self.start_time = time.time()
self.step_times = []
self.last_log_time = time.time()
def on_step_begin(self, args, state, control, **kwargs):
"""Called at the beginning of each training step"""
current_time = time.time()
# Log every 10 steps or every 30 seconds, whichever comes first
if state.global_step % 10 == 0 or (current_time - self.last_log_time) > 30:
elapsed = current_time - self.start_time
steps_per_sec = state.global_step / elapsed if elapsed > 0 else 0
# Calculate ETA
remaining_steps = state.max_steps - state.global_step
eta_seconds = remaining_steps / steps_per_sec if steps_per_sec > 0 else 0
eta_str = time.strftime('%H:%M:%S', time.gmtime(eta_seconds))
progress_pct = (state.global_step / state.max_steps) * 100
print(f"\r⏳ Step {state.global_step}/{state.max_steps} ({progress_pct:.1f}%) | "
f"Speed: {steps_per_sec:.2f} steps/s | ETA: {eta_str} | "
f"Epoch: {state.epoch:.1f}", end='', flush=True)
self.last_log_time = current_time
def on_log(self, args, state, control, logs=None, **kwargs):
"""Called when logging occurs"""
if logs:
print() # New line after progress bar
log_items = []
for k, v in logs.items():
if k == 'epoch': continue
if isinstance(v, float):
if 'learning_rate' in k or 'lr' in k:
# MODIFIED: Use scientific notation for Learning Rate (e.g. 1.00e-05)
val_str = f"{v:.2e}"
else:
# Standard 4 decimals for loss/grad_norm
val_str = f"{v:.4f}"
else:
val_str = f"{v}"
log_items.append(f"{k}: {val_str}")
log_str = " | ".join(log_items)
print(f"📊 {log_str}")
logging.info(log_str)
def on_epoch_end(self, args, state, control, **kwargs):
"""Called at the end of each epoch"""
print() # New line
elapsed = time.time() - self.start_time
print(f"\n✅ Epoch {int(state.epoch)} completed | "
f"Total time: {elapsed/60:.1f}m | "
f"Steps: {state.global_step}/{state.max_steps}")
logging.info(f"Epoch {int(state.epoch)} completed")
def on_train_begin(self, args, state, control, **kwargs):
"""Called at the start of training"""
print(f"\n🎯 SFT Training will run for {state.max_steps} steps")
print(f"📝 Logging every {args.logging_steps} steps")
print(f"💾 Saving checkpoints every {args.save_steps} steps")
print("-" * 70)
logging.info("SFT Training started")
# Add progress callback to trainer
progress_callback = DetailedProgressCallback()
trainer.add_callback(progress_callback)
# Handle keyboard interrupts gracefully
def signal_handler(sig, frame):
print("\n⚠️ Interrupt signal received. Saving progress...")
logging.warning("Training interrupted by user")
trainer.save_model(os.path.join(results_dir, "checkpoint", "interrupted"))
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
# Start training with enhanced logging
print("\n🚀 Starting SFT Training")
print("=" * 70)
print(f"⏰ Start time: {time.strftime('%Y-%m-%d %H:%M:%S')}")
print(f"🏷️ Run name: {run_name}")
print(f"📁 Output directory: {results_dir}")
print(f"🔍 Logs will be saved to: training_log.log")
print("-" * 70)
# Verify logging is working
logging.info(f"Starting SFT training run: {run_name}")
logging.info(f"Output directory: {results_dir}")
logging.info(f"Training config: epochs={NUM_TRAIN_EPOCHS}, batch_size={PER_DEVICE_TRAIN_BATCH_SIZE}")
training_start_time = time.time()
last_checkpoint_time = training_start_time
try:
# Verify trainer is set up correctly
print("🔍 Verifying trainer configuration...")
print(f" • Total training steps: {trainer.args.max_steps}")
print(f" • Steps per epoch: {len(trainer.get_train_dataloader())}")
print(f" • Logging interval: {trainer.args.logging_steps} steps")
print(f" • Save interval: {trainer.args.save_steps} steps")
print()
# NEW: Run Baseline Evaluation
print("\n🔍 Running Baseline Evaluation (Pre-training)...")
print(" This measures zero-shot performance before any updates.")
print("=" * 60, end='')
enhanced_callback.evaluate_validation(model, tokenizer, step=0, epoch_override=0)
print("=" * 60)
# Force immediate logging
sys.stdout.flush()
logging.info("Calling trainer.train()...")
# Start the training process
print("🎬 Initiating training loop...\n")
trainer.train(resume_from_checkpoint=RESUME_FROM_CHECKPOINT)
training_end_time = time.time()
training_duration = training_end_time - training_start_time
print("\n" + "="*70)
print("🎉 SFT TRAINING COMPLETED SUCCESSFULLY!")
print("="*70)
print(f"⏱️ Duration: {training_duration/3600:.2f} hours ({training_duration/60:.1f} minutes)")
print(f"📈 Average time per epoch: {training_duration/NUM_TRAIN_EPOCHS/60:.2f} minutes")
print(f"🏁 Completed at: {time.strftime('%Y-%m-%d %H:%M:%S')}")
logging.info(f"SFT Training completed successfully in {training_duration/3600:.2f} hours")
except KeyboardInterrupt:
print("\n\n⚠️ Training interrupted by user")
logging.warning("Training interrupted by user (KeyboardInterrupt)")
print("💾 Saving current progress...")
except Exception as e:
print(f"\n\n❌ Training failed with error!")
print(f"Error type: {type(e).__name__}")
print(f"Error message: {str(e)}")
print("\n📋 Full traceback:")
logging.exception(f"Training failed with error: {e}")
import traceback
traceback.print_exc()
raise
finally:
training_end_time = time.time()
actual_duration = training_end_time - training_start_time
print("\n" + "="*70)
print("🔄 Cleanup and saving...")
print("="*70)
# Always try to save the current state
try:
# Save final training log (contains Validation Metrics from Part 4)
if reward_fn and hasattr(reward_fn, 'training_log') and reward_fn.training_log:
try:
log_path = os.path.join(results_dir, "training_rewards.json")
with open(log_path, 'w', encoding='utf-8') as f:
json.dump(reward_fn.training_log, f, ensure_ascii=False, indent=2)
print(f"✅ Metric/Validation log saved: {len(reward_fn.training_log)} entries")
logging.info(f"Saved metric log with {len(reward_fn.training_log)} entries")
except Exception as e:
print(f"⚠️ Failed to save metric log: {e}")
logging.warning(f"Failed to save metric log: {e}")
# Save final model
final_model_path = os.path.join(results_dir, "checkpoint", "final_model")
os.makedirs(final_model_path, exist_ok=True)
trainer.save_model(final_model_path)
print(f"✅ Model saved to: {final_model_path}")
logging.info(f"Final model saved to: {final_model_path}")
print(f"\n⏱️ Total elapsed time: {actual_duration/60:.1f} minutes")
print("="*70)
except Exception as e:
print(f"⚠️ Error during cleanup: {e}")
logging.exception("Error during cleanup")
# In[16]:
# Optional: Visualize training progress
print("\n📊 Training Visualization")
print("=" * 30)
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
# 1. Plot Validation Accuracy (Metrics)
val_metrics_path = os.path.join(results_dir, VALIDATION_METRICS_PATH)
if os.path.exists(val_metrics_path):
with open(val_metrics_path, 'r') as f:
val_metrics = json.load(f)
epochs = [float(k) for k in val_metrics.keys()]
# MODIFIED: Label as Accuracy for SFT
accuracies = [v['avg_reward'] for v in val_metrics.values()]
plt.figure(figsize=(10, 6))
plt.plot(epochs, accuracies, marker='o', linewidth=2, markersize=8, color='#2ca02c') # Green for accuracy
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Validation Accuracy (Exact Match)', fontsize=12)
plt.title('SFT Validation Performance', fontsize=14)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plot_path = os.path.join(results_dir, 'validation_accuracy.png')
plt.savefig(plot_path, dpi=300)
print(f"✅ Validation accuracy plot saved to: {plot_path}")
plt.show()
else:
print("⚠️ No validation metrics found")
# 2. MODIFIED: Plot Training Loss (Crucial for SFT)
# We extract this directly from the trainer's state
if hasattr(trainer, 'state') and trainer.state.log_history:
log_history = trainer.state.log_history
# Extract loss values
steps = []
losses = []
for entry in log_history:
if 'loss' in entry and 'step' in entry:
steps.append(entry['step'])
losses.append(entry['loss'])
if steps:
plt.figure(figsize=(10, 6))
plt.plot(steps, losses, linewidth=2, color='#1f77b4') # Blue for loss
plt.xlabel('Training Steps', fontsize=12)
plt.ylabel('Cross Entropy Loss', fontsize=12)
plt.title('SFT Training Loss', fontsize=14)
plt.grid(True, alpha=0.3)
plt.tight_layout()
loss_plot_path = os.path.join(results_dir, 'training_loss.png')
plt.savefig(loss_plot_path, dpi=300)
print(f"✅ Training loss plot saved to: {loss_plot_path}")
plt.show()
else:
print("⚠️ No loss history found in trainer state")
# 3. MODIFIED: Plot Accuracy by Dataset Type (If detailed logs exist)
# This helps compare performance on UniADILR vs COPA
training_log_path = os.path.join(results_dir, TRAINING_LOG_PATH)
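# Assumed log schema (inferred from the columns used below): a list of per-example dicts,
# each with at least a "reward" field (per-example accuracy) and a "dataset_name" field, e.g.
#   [{"dataset_name": "UniADILR", "reward": 1.0, ...}, {"dataset_name": "COPA", "reward": 0.0, ...}]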
if os.path.exists(training_log_path):
try:
with open(training_log_path, 'r') as f:
logs = json.load(f)
# Convert to DataFrame for easier analysis
df = pd.DataFrame(logs)
if 'dataset_name' in df.columns and 'reward' in df.columns:
# Calculate average accuracy per dataset type
accuracy_by_dataset = df.groupby('dataset_name')['reward'].mean().reset_index()
accuracy_by_dataset.columns = ['Dataset', 'Accuracy']
plt.figure(figsize=(10, 6))
sns.barplot(data=accuracy_by_dataset, x='Dataset', y='Accuracy', palette='viridis')
plt.ylim(0, 1.0)
plt.title('Overall Accuracy by Dataset Type', fontsize=14)
plt.ylabel('Average Accuracy', fontsize=12)
plt.tight_layout()
breakdown_path = os.path.join(results_dir, 'accuracy_by_dataset.png')
plt.savefig(breakdown_path, dpi=300)
print(f"✅ Dataset breakdown plot saved to: {breakdown_path}")
plt.show()
# Print text summary
print("\n📊 Performance Breakdown:")
for index, row in accuracy_by_dataset.iterrows():
print(f" • {row['Dataset']}: {row['Accuracy']*100:.2f}%")
except Exception as e:
print(f"⚠️ Could not create dataset breakdown: {e}")
# In[17]:
import os
import json
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from collections import defaultdict  # needed by plot_sft_transitions below
print("\n📊 Advanced Visualization (Adapted for SFT)")
print("=" * 60)
plots_dir = os.path.join(results_dir, "plots")
os.makedirs(plots_dir, exist_ok=True)
# Load Data
val_metrics_path = os.path.join(results_dir, VALIDATION_METRICS_PATH)
training_log_path = os.path.join(results_dir, TRAINING_LOG_PATH)
val_log_path = os.path.join(results_dir, VALIDATION_LOG_PATH)
# 1. Detailed Accuracy per Dataset over Epochs
if os.path.exists(val_metrics_path):
with open(val_metrics_path, 'r') as f:
metrics = json.load(f)
epochs = sorted([int(k) for k in metrics.keys()])
# Prepare data for plotting
plot_data = []
for ep in epochs:
ep_key = str(ep)
if 'breakdown' in metrics[ep_key]:
for ds_name, score in metrics[ep_key]['breakdown'].items():
plot_data.append({'Epoch': ep, 'Accuracy': score, 'Dataset': ds_name})
# Add overall
plot_data.append({'Epoch': ep, 'Accuracy': metrics[ep_key]['avg_reward'], 'Dataset': 'Overall'})
if plot_data:
df_acc = pd.DataFrame(plot_data)
plt.figure(figsize=(10, 6))
sns.lineplot(data=df_acc, x='Epoch', y='Accuracy', hue='Dataset', marker='o', linewidth=2)
plt.title('Validation Accuracy per Dataset over Epochs')
plt.ylabel('Accuracy (Exact Match)')
plt.grid(True, alpha=0.3)
plt.ylim(0, 1.05)
plt.tight_layout()
plt.savefig(os.path.join(plots_dir, "accuracy_breakdown.png"), dpi=300)
print("✅ Saved per-dataset accuracy plot")
plt.close()
# 2. Reward (Accuracy) Transitions (Adapted from GRPO code)
# Tracks which validation examples were "solved" or "unsolved" over time
def plot_sft_transitions(val_log_path, output_dir):
if not os.path.exists(val_log_path): return
with open(val_log_path, 'r') as f:
val_data = json.load(f)
# Organize data: record_id -> {epoch -> correct(1)/incorrect(0)}
history = defaultdict(dict)
all_epochs = sorted([int(k) for k in val_data.keys()])
for ep_str, records in val_data.items():
epoch = int(ep_str)
for rec in records:
rid = rec.get('record_id')
# Assuming reward is 1.0 or 0.0
history[rid][epoch] = 1 if rec['reward'] > 0.5 else 0
# Calculate Gained vs Lost between the first and last evaluated epoch
# Gained: wrong in the first evaluated epoch -> right in the last epoch
# Lost: right in the first evaluated epoch -> wrong in the last epoch
first_ep = all_epochs[0]
last_ep = all_epochs[-1]
gained = 0
lost = 0
stable_correct = 0
stable_wrong = 0
for rid, eps in history.items():
if first_ep in eps and last_ep in eps:
start = eps[first_ep]
end = eps[last_ep]
if start == 0 and end == 1: gained += 1
elif start == 1 and end == 0: lost += 1
elif start == 1 and end == 1: stable_correct += 1
elif start == 0 and end == 0: stable_wrong += 1
# Plot
categories = ['Gained (Learned)', 'Lost (Forgetting)', 'Stable Correct', 'Stable Wrong']
values = [gained, lost, stable_correct, stable_wrong]
colors = ['#2ca02c', '#d62728', '#1f77b4', 'gray']
plt.figure(figsize=(8, 6))
bars = plt.bar(categories, values, color=colors)
plt.bar_label(bars)
plt.title(f'Learning Dynamics (Epoch {first_ep} vs {last_ep})')
plt.ylabel('Number of Validation Samples')
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, "learning_transitions.png"), dpi=300)
print("✅ Saved learning transitions plot")
plt.close()
plot_sft_transitions(val_log_path, plots_dir)
# 3. Training Loss (Smoothed)
if hasattr(trainer, 'state') and trainer.state.log_history:
log_history = trainer.state.log_history
steps = []
losses = []
for x in log_history:
if 'loss' in x:
steps.append(x['step'])
losses.append(x['loss'])
if steps:
plt.figure(figsize=(10, 5))
# Raw loss
plt.plot(steps, losses, alpha=0.3, color='blue', label='Raw Loss')
# Smoothed loss
if len(losses) > 10:
avg_loss = pd.Series(losses).rolling(window=10).mean()
plt.plot(steps, avg_loss, color='blue', linewidth=2, label='Smoothed (MA-10)')
plt.title('Training Loss Curve')
plt.xlabel('Steps')
plt.ylabel('Cross Entropy Loss')
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig(os.path.join(plots_dir, "training_loss_detailed.png"), dpi=300)
print("✅ Saved detailed training loss plot")
plt.close()
print(f"📊 All plots saved to: {plots_dir}")
# In[18]:
import os
import json
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from collections import defaultdict
print("\n📊 Generating Advanced Plots (Per Epoch Comparison & Transitions)")
print("=" * 60)
# Ensure plots directory exists
plots_dir = os.path.join(results_dir, "plots")
os.makedirs(plots_dir, exist_ok=True)
# Load the validation log
val_log_path = os.path.join(results_dir, VALIDATION_LOG_PATH)
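# Assumed record schema (inferred from how the records are read below): the log maps each
# epoch key (a string) to a list of per-example records, e.g.
#   {"1": [{"record_id": "...", "reward": 1.0, "dataset_name": "UniADILR"}, ...], "2": [...]}
# where "reward" is the exact-match score (1.0 / 0.0) produced during validation.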
if os.path.exists(val_log_path):
with open(val_log_path, 'r', encoding='utf-8') as f:
val_data = json.load(f)
# 1. ACCURACY COMPARISON PER EPOCH (Breakdown)
# =========================================================
metrics_data = []
# Sort epochs numerically
epochs = sorted([int(k) for k in val_data.keys()])
for ep in epochs:
ep_str = str(ep)
records = val_data[ep_str]
# Calculate per-dataset accuracy for this epoch
ds_scores = defaultdict(list)
all_scores = []
for rec in records:
score = 1.0 if rec['reward'] > 0.5 else 0.0 # Binarize reward if needed
ds_name = rec.get('dataset_name', 'Unknown')
ds_scores[ds_name].append(score)
all_scores.append(score)
# Add to plot data
metrics_data.append({
'Epoch': ep,
'Dataset': 'Overall',
'Accuracy': sum(all_scores)/len(all_scores) if all_scores else 0
})
for name, scores in ds_scores.items():
metrics_data.append({
'Epoch': ep,
'Dataset': name,
'Accuracy': sum(scores)/len(scores) if scores else 0
})
# Plot Accuracy
if metrics_data:
df_metrics = pd.DataFrame(metrics_data)
plt.figure(figsize=(10, 6))
# Use seaborn lineplot to show trends per dataset
sns.lineplot(data=df_metrics, x='Epoch', y='Accuracy', hue='Dataset', style='Dataset', markers=True, dashes=False, linewidth=2.5)
plt.title('Validation Accuracy Comparison per Epoch', fontsize=14)
plt.ylabel('Accuracy (Exact Match)', fontsize=12)
plt.xlabel('Epoch', fontsize=12)
plt.grid(True, alpha=0.3)
plt.ylim(-0.05, 1.05)
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.savefig(os.path.join(plots_dir, "accuracy_comparison_per_epoch.png"), dpi=300)
print("✅ Saved accuracy comparison plot")
plt.show()
# 2. TRANSITIONS PLOT (Learning Dynamics)
# =========================================================
# Measures: Gained (Wrong->Right), Lost (Right->Wrong), Stable Correct, Stable Wrong
# Compares First Epoch vs Last Epoch
if len(epochs) >= 2:
first_ep = str(epochs[0]) # Usually 0 or 1
last_ep = str(epochs[-1])
# Build map: record_id -> {epoch -> correct?}
history = defaultdict(dict)
for ep in [first_ep, last_ep]:
for rec in val_data[ep]:
rid = rec.get('record_id')
# Mark as correct when reward > 0.5 (rewards are 1.0 / 0.0 exact-match scores)
is_correct = 1 if rec['reward'] > 0.5 else 0
history[rid][ep] = is_correct
# Categorize
transitions = {
'Gained (Learned)': 0,
'Lost (Forgot)': 0,
'Stable Correct': 0,
'Stable Wrong': 0
}
for rid, eps in history.items():
if first_ep in eps and last_ep in eps:
start = eps[first_ep]
end = eps[last_ep]
if start == 0 and end == 1: transitions['Gained (Learned)'] += 1
elif start == 1 and end == 0: transitions['Lost (Forgot)'] += 1
elif start == 1 and end == 1: transitions['Stable Correct'] += 1
elif start == 0 and end == 0: transitions['Stable Wrong'] += 1
# Plot Transitions
plt.figure(figsize=(9, 6))
# Colors: Green, Red, Blue, Gray
colors = ['#2ca02c', '#d62728', '#1f77b4', '#7f7f7f']
bars = plt.bar(transitions.keys(), transitions.values(), color=colors, alpha=0.8)
# Add count labels
for bar in bars:
height = bar.get_height()
plt.text(bar.get_x() + bar.get_width()/2., height,
f'{int(height)}',
ha='center', va='bottom')
plt.title(f'Learning Transitions (Epoch {first_ep} vs {last_ep})', fontsize=14)
plt.ylabel('Number of Samples', fontsize=12)
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.join(plots_dir, "learning_transitions.png"), dpi=300)
print("✅ Saved learning transitions plot")
plt.show()
else:
print("⚠️ Need at least 2 epochs to plot transitions.")
else:
print(f"❌ Validation log not found at: {val_log_path}")
# In[19]:
# Optional: Compare Training Loss vs Validation Loss
# Note: This requires evaluation to be set to run during training (e.g. eval_steps)
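# A minimal sketch of the relevant settings (field names taken from the Hugging Face
# TrainingArguments / SFTConfig API; adapt to wherever the trainer config is built above):
#   eval_strategy="steps",          # "evaluation_strategy" in older transformers versions
#   eval_steps=EVAL_STEPS,          # evaluate every EVAL_STEPS optimizer steps
#   per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
# Each evaluation pass then appends an entry containing "eval_loss" to
# trainer.state.log_history, which is what the plot below looks for.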
if hasattr(trainer, 'state') and trainer.state.log_history:
log_hist = trainer.state.log_history
# Extract training loss
train_steps = [x['step'] for x in log_hist if 'loss' in x]
train_loss = [x['loss'] for x in log_hist if 'loss' in x]
# Extract validation loss (if available in logs)
val_steps = [x['step'] for x in log_hist if 'eval_loss' in x]
val_loss = [x['eval_loss'] for x in log_hist if 'eval_loss' in x]
if train_steps and val_steps:
plt.figure(figsize=(10, 6))
plt.plot(train_steps, train_loss, label='Training Loss', color='blue', alpha=0.6)
plt.plot(val_steps, val_loss, label='Validation Loss', color='red', marker='o')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.title('Training vs Validation Loss')
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig(os.path.join(results_dir, "loss_comparison.png"))
print("✅ Saved loss comparison plot")
plt.show()