| |
| |
|
|
| |
|
|
|
|
| from datetime import datetime |
|
|
| |
| |
|
|
| |
| |
|
|
| |
| |
|
|
| |
| |
|
|
| MODEL_NAME = '/home/moein_salimi/PLLMS/unsloth-Qwen2.5-14B-Instruct-bnb-4bit' |
|
|
|
|
| LOAD_IN_4BIT = True |
| LOAD_IN_8BIT = False |
| USE_VLLM = False |
| LORA_RANK = 64 |
| LORA_ALPHA = 64 |
| LORA_DROPOUT = 0.05 |
| GPU_MEMORY_UTILIZATION = 1.0 |
| MAX_SEQ_LENGTH = 4096 |
| MAX_PROMPT_LENGTH = 2048 |
| MAX_COMPLETION_LENGTH = MAX_SEQ_LENGTH - MAX_PROMPT_LENGTH |
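| # With the values above, MAX_SEQ_LENGTH - MAX_PROMPT_LENGTH = 4096 - 2048 = 2048 tokens are left for the generated completion. |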
|
|
| RESUME_FROM_CHECKPOINT = False |
| PREVIOUS_RUN_DIR = 'dt11.15.23:13_e20_unsloth_Qwen2.5_3B_Instruct_unsloth_bnb_4bit_bnb_4bit_lr1e-05_t0.7_ε0.2_r64_b16' |
|
|
| RUN_DESC = "SFT_Implementation" |
| CUDA_VISIBLE_DEVICES = "0" |
|
|
| |
| LEARNING_RATE = 5e-6 |
| ADAM_BETA1 = 0.9 |
| ADAM_BETA2 = 0.99 |
| WEIGHT_DECAY = 0.3 |
| WARMUP_STEPS = 10 |
| LR_SCHEDULER_TYPE = "cosine" |
| OPTIM = "adamw_torch" |
|
|
|
|
| |
| EVAL_STEPS = 512 |
| SAVE_STEPS = 512 |
| LOG_VALIDATION = True |
| LOG_TRAIN_EVERY = 1 |
|
|
| |
| PER_DEVICE_TRAIN_BATCH_SIZE = 4 |
| PER_DEVICE_EVAL_BATCH_SIZE = 1 |
| GRADIENT_ACCUMULATION_STEPS = 4 |
|
|
| |
| |
|
|
| MAX_GRAD_NORM = 0.1 |
| TEMPERATURE = 0.0 |
| NUM_TRAIN_EPOCHS = 6 |
|
|
|
|
| |
| |
| |
| MIXED_DATA = True |
|
|
| |
| ERROR_LOG_PATH = "error_log.log" |
| TRAINING_LOG_PATH = "training_log.json" |
| VALIDATION_LOG_PATH = "validation_log.json" |
| VALIDATION_METRICS_PATH = "val_metrics.json" |
|
|
| |
| SYSTEM_PROMPT_UniADILR = """ |
| |
| You are an expert in logical reasoning and abductive inference. Your task is to identify which sentences from a given context provide the necessary evidence to support or explain a hypothesis. |
| |
| You will be provided with: |
| 1. A Context containing multiple numbered sentences (sent1, sent2, sent3, etc.) |
| 2. A Hypothesis that needs to be supported or explained |
| |
| Your goal is to identify which sentence(s) from the context, when combined, provide the logical foundation for the hypothesis through abductive reasoning. |
| |
| ## Instructions: |
| 1. Carefully read all sentences in the context |
| 2. Analyze the hypothesis |
| 3. Identify which sentences, when combined, best explain or support the hypothesis |
| 4. Consider both direct evidence and logical connections |
| |
| ## Output Format: |
| You MUST provide your answer in the following format: |
| |
| <think> |
| [Explain your thought process: why you selected these particular sentences and how they support the hypothesis] |
| </think> |
| |
| <answer> |
| [Sentence numbers only, comma-separated. For example: 5, 13 or 2, 7, 9] |
| </answer> |
| |
| CRITICAL: The answer section must contain ONLY the sentence numbers separated by commas. Do not include the word "sent" or any other text. |
| """.strip() |
|
|
|
|
| SYSTEM_PROMPT_balanced_copa_cause_only = """ |
| |
| You are an expert in logical reasoning and abductive inference. Your task is to determine which of two given choices represents the most plausible cause for a given premise. |
| |
| You will be provided with: |
| 1. A Premise describing a situation or event |
| 2. Two Choices (Choice 1 and Choice 2) |
| |
| Your goal is to select the choice that best explains WHY the premise happened - identifying the root cause that led to the described situation. |
| |
| ## Instructions: |
| 1. Carefully read the premise |
| 2. Evaluate both choices as potential causes |
| 3. Consider common sense, real-world knowledge, and typical causal relationships when making your decision |
| 4. Select the choice that represents the most plausible and direct cause |
| |
| ## Output Format: |
| You MUST provide your answer in the following format: |
| |
| <think> |
| [Explain your thought process: why one choice is a more plausible cause than the other, analyzing the causal relationship with the premise] |
| </think> |
| |
| <answer> |
| [Either "1" or "2" - just the number, nothing else] |
| </answer> |
| |
| CRITICAL: The answer section must contain ONLY the number 1 or 2. Do not include any other text, explanation, or punctuation. |
| """.strip() |
|
|
| |
| RANDOM_STATE = 3407 |
| TORCH_SEED = 42 |
| NUMPY_SEED = 42 |
|
|
| |
| WANDB_DISABLED = "true" |
|
|
|
|
| |
|
|
|
|
| import os |
| import sys |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
|
|
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| print("💻 Running Locally") |
| BASE_DATA_DIR = "./dataset" |
| BASE_OUTPUT_DIR = "./results_sft_14b" |
|
|
| |
| import torch |
| print(f"\n🔥 PyTorch version: {torch.__version__}") |
| print(f"🎮 CUDA available: {torch.cuda.is_available()}") |
| if torch.cuda.is_available(): |
| print(f"🎮 CUDA version: {torch.version.cuda}") |
|
|
| os.makedirs(BASE_OUTPUT_DIR, exist_ok=True) |
| print(f"\n📂 Data Directory: {BASE_DATA_DIR}") |
| print(f"💾 Output Directory: {BASE_OUTPUT_DIR}") |
|
|
|
|
| |
|
|
| |
| def get_run_name(): |
| """Generate run name based on configuration""" |
| model_name = MODEL_NAME.split("/")[-1].replace("-", "_") |
| if LOAD_IN_8BIT: |
| model_name += "_8bit" |
| elif LOAD_IN_4BIT: |
| model_name += "_bnb_4bit" |
| now = datetime.now() |
|
|
| |
| name = f"SFT_dt{now.strftime('%m.%d.%H:%M')}_e{NUM_TRAIN_EPOCHS}_{model_name}_lr{LEARNING_RATE}_t{TEMPERATURE}_r{LORA_RANK}_b{PER_DEVICE_TRAIN_BATCH_SIZE}" |
|
|
| if RUN_DESC: |
| name += f"_{RUN_DESC}" |
| return name |
|
|
|
|
| def get_results_dir(run_name=None): |
| """Get results directory path based on environment""" |
| if run_name is None: |
| run_name = get_run_name() |
| if RESUME_FROM_CHECKPOINT: |
| run_name = PREVIOUS_RUN_DIR |
|
|
| |
| return os.path.join(BASE_OUTPUT_DIR, run_name) |
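| # Illustrative shape of the resulting directory for the current config (the timestamp part |
| # depends on when the run starts): |
| #   ./results_sft_14b/SFT_dt<MM.DD.HH:MM>_e6_unsloth_Qwen2.5_14B_Instruct_bnb_4bit_bnb_4bit_lr5e-06_t0.0_r64_b4_SFT_Implementation |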
|
|
|
|
| |
|
|
|
|
| |
| import os |
| import sys |
| import warnings |
| warnings.filterwarnings('ignore') |
| import random |
| import numpy as np |
| import torch |
|
|
| |
| sys.path.append('.') |
|
|
| |
| random.seed(RANDOM_STATE) |
| np.random.seed(NUMPY_SEED) |
| torch.manual_seed(TORCH_SEED) |
| torch.cuda.manual_seed_all(TORCH_SEED) |
|
|
| print(f"🎲 Random seeds set:") |
| print(f" Python: {RANDOM_STATE}") |
| print(f" NumPy: {NUMPY_SEED}") |
| print(f" PyTorch: {TORCH_SEED}") |
|
|
| |
| os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_VISIBLE_DEVICES |
| os.environ["WANDB_DISABLED"] = WANDB_DISABLED |
|
|
| |
| print("\n🔧 Abductive Reasoning SFT Training Pipeline") |
| print("=" * 50) |
| print(f"Configuration loaded:") |
| print(f" 📦 Model: {MODEL_NAME}") |
| print(f" 🎯 Batch size: {PER_DEVICE_TRAIN_BATCH_SIZE}") |
| print(f" 🏃 Epochs: {NUM_TRAIN_EPOCHS}") |
| print(f" 📈 Learning rate: {LEARNING_RATE}") |
| |
| |
| print(f" 🌡️ Temperature: {TEMPERATURE}") |
| print(f" 🎮 GPU: {CUDA_VISIBLE_DEVICES}") |
|
|
|
|
| |
|
|
|
|
| |
| print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}") |
|
|
| |
| print(f"Number of visible GPUs: {torch.cuda.device_count()}") |
|
|
| if torch.cuda.is_available(): |
| for i in range(torch.cuda.device_count()): |
| print(f"GPU {i}: {torch.cuda.get_device_name(i)}") |
| print(f" Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB") |
|
|
|
|
| |
|
|
|
|
| |
| import torch, os, sys |
| print("torch.__version__:", torch.__version__) |
| print("torch.cuda.is_available():", torch.cuda.is_available()) |
| print("torch.version.cuda:", torch.version.cuda) |
| try: |
| print("has attr 'UnsupportedMutationAliasingException':", |
| hasattr(torch._subclasses.fake_tensor, "UnsupportedMutationAliasingException")) |
| except Exception as e: |
| print("Checking attribute failed:", type(e).__name__, e) |
|
|
|
|
| |
|
|
|
|
| |
|
|
|
|
| |
|
|
|
|
| |
| import torch |
| import json |
| import re |
| import time |
| from datasets import Dataset |
| from unsloth import FastLanguageModel |
|
|
| from trl import SFTTrainer |
| from transformers import TrainerCallback, TrainingArguments |
| import matplotlib.pyplot as plt |
|
|
| print("🔍 System Check:") |
| print("=" * 30) |
|
|
| |
| print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}") |
| print(f"Number of visible GPUs: {torch.cuda.device_count()}") |
|
|
| if torch.cuda.is_available(): |
| print(f"✅ GPU Available") |
| print(f" Current device: {torch.cuda.current_device()}") |
| print(f" GPU name: {torch.cuda.get_device_name(0)}") |
| print(f" GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB") |
| else: |
| print("❌ No GPU available!") |
|
|
| print(f"✅ PyTorch version: {torch.__version__}") |
|
|
|
|
| |
|
|
|
|
| import json |
| from datasets import Dataset |
| import os |
|
|
| print("\n📂 Loading Pre-Split Data and Transforming") |
| print("=" * 40) |
|
|
| LIMITED_VAL = False |
| LIMITED_TRAIN = False |
| NUMBER_OF_TRAIN_SAMPLES = 100 |
| NUMBER_OF_VAL_SAMPLES = 10 |
|
|
| |
| train_path = os.path.join(BASE_DATA_DIR, 'train_split.json') |
| val_path = os.path.join(BASE_DATA_DIR, 'val_split.json') |
|
|
| |
| print(f"Loading train split from: {train_path}") |
| with open(train_path, 'r', encoding='utf-8') as f: |
| raw_train_data = json.load(f) |
|
|
| |
| if MIXED_DATA: |
| train_data = raw_train_data |
| print(f" ✅ MIXED_DATA=True: Using all {len(train_data)} samples (UniADILR + COPA)") |
| else: |
| |
| train_data = [item for item in raw_train_data if 'context' in item] |
| print(f" ⚠️ MIXED_DATA=False: Filtered for UniADILR only.") |
| print(f" 📉 Training samples reduced from {len(raw_train_data)} to {len(train_data)}") |
|
|
| if LIMITED_TRAIN: |
| train_data = train_data[:NUMBER_OF_TRAIN_SAMPLES] |
| print(f" ⚠️ Limited training set to {len(train_data)} samples") |
|
|
| print(f"Loading validation split from: {val_path}") |
| with open(val_path, 'r', encoding='utf-8') as f: |
| raw_val_data = json.load(f) |
|
|
| |
| if MIXED_DATA: |
| val_data = raw_val_data |
| print(f" ✅ MIXED_DATA=True: Using all {len(val_data)} validation samples") |
| else: |
| val_data = [item for item in raw_val_data if 'context' in item] |
| print(f" 📉 Validation samples reduced from {len(raw_val_data)} to {len(val_data)}") |
|
|
| if LIMITED_VAL: |
| val_data = val_data[:NUMBER_OF_VAL_SAMPLES] |
| print(f" ⚠️ Limited validation set to {len(val_data)} samples") |
|
|
|
|
| def transform_to_prompt_format(example, record_id): |
| """ |
| Transform the original JSONL format to the required prompt format. |
| Handles both UniADILR and balanced_copa_cause_only datasets. |
| """ |
| dataset_name = example.get('datasetName', '') |
|
|
| |
| system_prompt_content = "" |
| user_content = "" |
| assistant_content = "" |
|
|
| if dataset_name == 'UniADILR': |
| |
| context_lines = [] |
| for key, value in example['context'].items(): |
| context_lines.append(f"{key}: {value}") |
| context_str = "\n".join(context_lines) |
|
|
| |
| user_content = f"""Context: |
| {context_str} |
| |
| Hypothesis: |
| {example['hypothesis']} |
| |
| Based on the context and hypothesis above, identify which sentence(s) provide the necessary evidence for the hypothesis.""" |
|
|
| system_prompt_content = SYSTEM_PROMPT_UniADILR |
|
|
| |
| |
| proof_str = example['proof'] |
| if '->' in proof_str: |
| proof_str = proof_str.split('->')[0] |
| numbers = re.findall(r'sent(\d+)', proof_str) |
| |
| formatted_answer = ", ".join(sorted(numbers, key=int)) |
|
|
| |
| |
| |
| |
| assistant_content = f"""<think> |
| The hypothesis requires evidence from the provided context. I will identify the sentences that support this claim. |
| </think> |
| |
| <answer> |
| {formatted_answer} |
| </answer>""" |
|
|
| ground_truth = json.dumps(example['proof']) |
|
|
| elif dataset_name == 'balanced_copa_cause_only': |
| |
| user_content = f"""Premise: {example['premise']} |
| |
| Question: {example['question']} |
| |
| Choice 1: {example['choice1']} |
| Choice 2: {example['choice2']} |
| |
| Which choice is the most plausible cause for the premise?""" |
|
|
| system_prompt_content = SYSTEM_PROMPT_balanced_copa_cause_only |
|
|
| |
| correct_choice = str(example['label'] + 1) |
| ground_truth = correct_choice |
|
|
| |
| assistant_content = f"""<think> |
| I need to identify the most plausible cause for the premise among the two choices. |
| </think> |
| |
| <answer> |
| {correct_choice} |
| </answer>""" |
|
|
| else: |
| raise ValueError(f"Unknown dataset name: {dataset_name}") |
|
|
| |
| |
| |
| prompt = [ |
| { |
| "role": "system", |
| "content": system_prompt_content |
| }, |
| { |
| "role": "user", |
| "content": user_content |
| }, |
| |
| { |
| "role": "assistant", |
| "content": assistant_content |
| } |
| ] |
|
|
| |
| return { |
| "prompt": prompt, |
| "record_id": record_id, |
| "ground_truth": ground_truth, |
| "reasoning_type": example.get('reasoning_type', 'abduction'), |
| "dataset_name": dataset_name |
| } |
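| # Sketch of what the transform produces for a hypothetical UniADILR record (illustrative only): |
| #   input:  {'datasetName': 'UniADILR', 'context': {'sent1': '...', 'sent2': '...'}, |
| #            'hypothesis': '...', 'proof': 'sent1 & sent2 -> hypothesis'} |
| #   output: {'prompt': [system, user, assistant messages], 'record_id': 0, |
| #            'ground_truth': '"sent1 & sent2 -> hypothesis"', 'reasoning_type': 'abduction', |
| #            'dataset_name': 'UniADILR'} |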
|
|
| |
| print("\nTransforming train data to prompt format...") |
| train_transformed = [] |
| for idx, example in enumerate(train_data): |
| train_transformed.append(transform_to_prompt_format(example, record_id=idx)) |
|
|
| print("Transforming validation data to prompt format...") |
| val_transformed = [] |
| for idx, example in enumerate(val_data): |
| val_transformed.append(transform_to_prompt_format(example, record_id=idx)) |
|
|
| |
| |
| |
| |
|
|
| print(f"✅ Transformed all splits") |
|
|
| |
| print("\nConverting to HuggingFace datasets...") |
| train_ds = Dataset.from_list(train_transformed) |
| val_ds = Dataset.from_list(val_transformed) |
| |
|
|
| |
| print("\n" + "="*80) |
| print("🔍 FIRST TRAINING EXAMPLE (to verify system prompt)") |
| print("="*80) |
| first_example = train_ds[0] |
| print(f"\n📋 Example keys: {list(first_example.keys())}") |
| print(f"\n🆔 Record ID: {first_example.get('record_id', 'N/A')}") |
| print("\n💬 PROMPT STRUCTURE:") |
| print("-" * 80) |
| for i, msg in enumerate(first_example['prompt']): |
| role = msg.get('role', 'unknown') |
| content = msg.get('content', '') |
| print(f"\n[Message {i+1}] Role: {role.upper()}") |
| print("-" * 40) |
| |
| if len(content) > 500: |
| print(f"{content[:500]}...") |
| print(f"\n... (Content truncated - total length: {len(content)} characters)") |
| else: |
| print(content) |
| print("-" * 40) |
|
|
| |
| log_file = './prompt_structure_log.txt' |
| with open(log_file, 'w', encoding='utf-8') as f: |
| for i, msg in enumerate(first_example['prompt']): |
| role = msg.get('role', 'unknown') |
| content = msg.get('content', '') |
| f.write(f"\n[Message {i+1}] Role: {role.upper()}\n") |
| f.write("-" * 40 + "\n") |
| f.write(content + "\n") |
| f.write("-" * 40 + "\n") |
|
|
| print(f"✅ Prompt structure logged to: {log_file}") |
|
|
|
|
| print("\n" + "="*80) |
|
|
|
|
| |
| |
| print(f"\n✅ Datasets loaded, transformed, and ready!") |
| print(f"\n📈 Dataset Statistics:") |
| |
| print(f" Training samples: {len(train_ds):,}") |
| print(f" Validation samples: {len(val_ds):,}") |
| |
|
|
|
|
| |
|
|
|
|
| |
| print("\n🛠️ Verifying Loaded Datasets") |
| print("=" * 35) |
|
|
| |
| prompt_lengths = [] |
| full_lengths = [] |
| for ds in [train_ds, val_ds]: |
| for example in ds: |
| |
| current_full_length = 0 |
| user_found = False |
|
|
| for msg in example['prompt']: |
| content = msg.get('content', '') |
| current_full_length += len(content) |
|
|
| if isinstance(msg, dict) and msg.get('role') == 'user': |
| prompt_lengths.append(len(content)) |
| user_found = True |
|
|
| if user_found: |
| full_lengths.append(current_full_length) |
|
|
| print(f"✅ Datasets ready for training!") |
| print(f" Total prompts: {len(prompt_lengths):,}") |
| print(f" Max user prompt length: {max(prompt_lengths)} characters") |
| print(f" Average user prompt length: {sum(prompt_lengths)/len(prompt_lengths):.0f} characters") |
| |
| print(f" Max full length (est): {max(full_lengths)} characters") |
| print(f" Average full length (est): {sum(full_lengths)/len(full_lengths):.0f} characters") |
|
|
| print(f"\n Sample keys in training data: {list(train_ds[0].keys())}") |
|
|
| |
| print(f"\n📋 Example answer from first training sample:") |
| print(f" answer: {train_ds[0]['ground_truth']}") |
| print(f" answer type: {type(train_ds[0]['ground_truth'])}") |
|
|
| |
| print(f"\n📝 Example user prompt (first 200 chars):") |
| for msg in train_ds[0]['prompt']: |
| if isinstance(msg, dict) and msg.get('role') == 'user': |
| user_content = msg.get('content', '') |
| print(f" {user_content[:200]}...") |
| break |
|
|
|
|
| |
|
|
|
|
| from unsloth import FastLanguageModel, is_bfloat16_supported |
| from huggingface_hub import HfApi |
| import os |
| from tqdm.auto import tqdm |
| import time |
|
|
| start_time = time.time() |
|
|
| def format_bytes(bytes_value): |
| """Convert bytes to human-readable format""" |
| for unit in ['B', 'KB', 'MB', 'GB', 'TB']: |
| if bytes_value < 1024.0: |
| return f"{bytes_value:.2f} {unit}" |
| bytes_value /= 1024.0 |
| return f"{bytes_value:.2f} PB" |
|
|
| def get_model_size(model_name): |
| """Try to get model size from HuggingFace Hub""" |
| try: |
| api = HfApi() |
| model_info = api.model_info(model_name) |
| |
| total_size = sum(file.size for file in model_info.siblings if file.size) |
| return total_size |
| except: |
| return None |
|
|
| |
| print("🔧 Configuring Hugging Face Hub download settings...") |
| print("=" * 60) |
| os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "240" |
| print("✓ Download timeout: 240 seconds per chunk") |
| print("✓ Using default retry settings") |
| print() |
|
|
| |
| print("📊 Fetching model information...") |
| model_size = get_model_size(MODEL_NAME) |
| if model_size: |
| print(f"✓ Model size: {format_bytes(model_size)}") |
| print(f"✓ Estimated download time: ~{model_size / (10 * 1024 * 1024):.0f} seconds (at 10 MB/s)") |
| else: |
| print("⚠ Could not determine model size") |
| print() |
|
|
| |
| print("🤖 Model Setup") |
| print("=" * 60) |
| print(f"📦 Model: {MODEL_NAME}") |
| print(f"🔢 Max sequence length: {MAX_SEQ_LENGTH}") |
| print(f"⚙️ Quantization: {'4-bit' if LOAD_IN_4BIT else '8-bit' if LOAD_IN_8BIT else 'None'}") |
| print(f"🚀 Fast inference (vLLM): {USE_VLLM}") |
| print(f"💾 GPU memory utilization: {GPU_MEMORY_UTILIZATION}") |
| print() |
|
|
| print("⏳ Downloading and loading model...") |
| print(" (This may take several minutes depending on your connection)") |
| print() |
|
|
| download_start = time.time() |
|
|
| |
| class ProgressCallback: |
| def __init__(self): |
| self.last_print = time.time() |
| self.dots = 0 |
|
|
| def update(self): |
| current = time.time() |
| if current - self.last_print > 2: |
| self.dots = (self.dots + 1) % 4 |
| elapsed = current - download_start |
| print(f"\r Downloading{'.' * (self.dots + 1)}{' ' * (3 - self.dots)} " + |
| f"[{elapsed:.0f}s elapsed]", end='', flush=True) |
| self.last_print = current |
|
|
| progress = ProgressCallback() |
|
|
| try: |
| model, tokenizer = FastLanguageModel.from_pretrained( |
| model_name=MODEL_NAME, |
| max_seq_length=MAX_SEQ_LENGTH, |
| max_length=MAX_SEQ_LENGTH, |
| load_in_4bit=LOAD_IN_4BIT, |
| load_in_8bit=LOAD_IN_8BIT, |
| fast_inference=USE_VLLM, |
| max_lora_rank=LORA_RANK, |
| gpu_memory_utilization=GPU_MEMORY_UTILIZATION, |
| ) |
| print("\r" + " " * 80 + "\r", end='') |
|
|
| download_time = time.time() - download_start |
| print(f"✅ Model downloaded and loaded successfully!") |
| print(f"⏱️ Total time: {download_time:.1f}s ({download_time/60:.1f} minutes)") |
|
|
| if model_size: |
| avg_speed = model_size / download_time |
| print(f"📈 Average speed: {format_bytes(avg_speed)}/s") |
| print() |
|
|
| except Exception as e: |
| print(f"\n❌ Error loading model: {e}") |
| raise |
|
|
| |
| if tokenizer.pad_token is None: |
| tokenizer.pad_token = tokenizer.eos_token |
| print("✓ Configured pad token") |
| print() |
|
|
| |
| print("🔧 Applying LoRA configuration...") |
| print("-" * 60) |
|
|
| lora_start = time.time() |
|
|
| model = FastLanguageModel.get_peft_model( |
| model, |
| r=LORA_RANK, |
| target_modules=[ |
| "q_proj", "k_proj", "v_proj", "o_proj", |
| "gate_proj", "up_proj", "down_proj", |
| ], |
| lora_alpha=LORA_ALPHA, |
|
|
| |
| lora_dropout=LORA_DROPOUT, |
|
|
| use_gradient_checkpointing="unsloth", |
| random_state=RANDOM_STATE |
| ) |
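| # Note: with LORA_ALPHA == LORA_RANK (64/64), the LoRA scaling factor alpha/r is 1.0, |
| # i.e. adapter updates are applied without additional up- or down-scaling. |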
|
|
| lora_time = time.time() - lora_start |
|
|
| print(f"✅ LoRA configured successfully! ({lora_time:.1f}s)") |
| print() |
|
|
| |
| print("📊 Model Statistics") |
| print("=" * 60) |
| print(f"🎯 LoRA Configuration:") |
| print(f" • Rank (r): {LORA_RANK}") |
| print(f" • Alpha: {LORA_ALPHA}") |
| print(f" • Target modules: 7 (q, k, v, o, gate, up, down projections)") |
| print() |
|
|
| trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) |
| total_params = sum(p.numel() for p in model.parameters()) |
| frozen_params = total_params - trainable_params |
|
|
| print(f"🔢 Parameters:") |
| print(f" • Total: {total_params:,}") |
| print(f" • Trainable: {trainable_params:,} ({100*trainable_params/total_params:.2f}%)") |
| print(f" • Frozen: {frozen_params:,} ({100*frozen_params/total_params:.2f}%)") |
| print() |
|
|
| total_setup_time = time.time() - start_time |
| print(f"⏱️ Total Setup Time: {total_setup_time:.1f}s ({total_setup_time/60:.1f} minutes)") |
| print(f" • Model download/load: {download_time:.1f}s") |
| print(f" • LoRA configuration: {lora_time:.1f}s") |
| print("=" * 60) |
| print("✨ Ready to train!") |
|
|
|
|
| |
|
|
|
|
| import logging |
|
|
| |
| |
| print("\n🎯 Evaluation Metrics & Output Setup") |
| print("=" * 30) |
|
|
| |
| run_name = get_run_name() |
| results_dir = get_results_dir(run_name) |
|
|
| |
| os.makedirs(results_dir, exist_ok=True) |
| os.makedirs(os.path.join(results_dir, "checkpoint"), exist_ok=True) |
|
|
| logging.basicConfig( |
| filename=os.path.join(results_dir, ERROR_LOG_PATH), |
| level=logging.WARNING, |
| format='%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s() - %(message)s' |
| ) |
|
|
| print(f"📁 Results directory: {results_dir}") |
| print(f"🏷️ Run name: {run_name}") |
|
|
| |
| def extract_sentence_numbers(text, datasetName): |
| """Extract sentence numbers from model output if UniADILR, if not no need to do anything. |
| |
| Looks for content within <answer> tags and extracts comma-separated numbers. |
| Returns a set of integers. |
| """ |
| |
| answer_match = re.search(r'<answer>\s*([^<]+?)\s*</answer>', text, re.IGNORECASE | re.DOTALL) |
|
|
| if answer_match: |
| answer_content = answer_match.group(1) |
| else: |
| |
| answer_content = "" |
|
|
| if datasetName == 'UniADILR': |
| |
| numbers = re.findall(r'\b(\d+)\b', answer_content) |
|
|
| return set(int(n) for n in numbers) |
|     elif datasetName == 'balanced_copa_cause_only': |
|         if answer_content == '': |
|             return -123 |
|         try: |
|             return int(answer_content.strip()) |
|         except ValueError: |
|             # Non-numeric content inside <answer> counts as an invalid prediction |
|             return -123 |
| else: |
| try: |
| return int(answer_content) |
| except: |
| return -1 |
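| # Quick self-check of the parser on hypothetical strings (illustrative; not dataset records): |
| assert extract_sentence_numbers("<answer>5, 13</answer>", "UniADILR") == {5, 13} |
| assert extract_sentence_numbers("<answer>2</answer>", "balanced_copa_cause_only") == 2 |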
|
|
| def parse_proof(proof_str, datasetName): |
| """ |
| If datasetName is UniADILR, |
| Parse ground truth proof string to extract sentence numbers. |
| |
| Example: 'sent5 & sent13 -> hypothesis' returns {5, 13} |
| |
| if datasetName is 'copa' just return the answer number |
| |
| Example: 2 returns 2 |
| """ |
| if datasetName == 'UniADILR': |
| |
| if '->' in proof_str: |
| proof_str = proof_str.split('->')[0] |
|
|
| numbers = re.findall(r'sent(\d+)', proof_str) |
| return set(int(n) for n in numbers) |
| elif datasetName == 'balanced_copa_cause_only': |
| return int(proof_str) |
| else: |
| return int(proof_str) |
|
|
| class AbductiveRewardFunction: |
| """ |
| MODIFIED: In SFT, this class acts as a Metrics Calculator for Validation. |
|     It computes exact-match accuracy, which is logged as 'reward' to stay compatible |
|     with the existing visualization tools. |
| """ |
|
|
| def __init__(self, dataset, tokenizer, output_path, log_every=50): |
| self.dataset = dataset |
| self.tokenizer = tokenizer |
| self.output_path = output_path |
| self.current_epoch = 1 |
| self.training_log = [] |
| self.step_losses = [] |
| self.log_every = log_every |
|
|
| print("🛠️ Building prompt-to-[ground_truth, datasetName] lookup table for evaluation...") |
| self.lookup_table = {} |
| missing_ground_truths = 0 |
| flag = False |
| for record in self.dataset: |
| |
| |
| if not flag: |
| print(f"prompt before apply chat template 1: {record['prompt'][1]['content']}") |
|
|
| prompt_text = record['prompt'][1]['content'] |
| datasetName = record['dataset_name'] |
| if not flag: |
| print(f"prompt_text: {prompt_text}") |
| flag = True |
| ground_truth = record.get('ground_truth', '') |
| if ground_truth: |
| |
| |
| self.lookup_table[prompt_text] = [ground_truth, datasetName] |
| else: |
| missing_ground_truths += 1 |
|
|
| print(f"✅ Lookup table built. Contains {len(self.lookup_table)} entries.") |
| if missing_ground_truths > 0: |
| print(f" ⚠️ Warning: {missing_ground_truths} records in the dataset were missing a 'ground_truths' field.") |
|
|
| def set_epoch(self, epoch): |
| self.current_epoch = epoch |
|
|
| def record_loss(self, step, loss): |
| self.step_losses.append({"step": step, "loss": loss}) |
|
|
| def __call__(self, completions, prompts, **kwargs): |
| """ |
| Calculate accuracy (reward) using the pre-computed lookup table. |
| Used during validation generation steps. |
| |
| Args: |
| completions: List of generated text strings for each prompt in the batch. |
| prompts: List of the formatted input text strings (or list of dicts). |
| """ |
| rewards = [] |
|
|
| |
| for i, (prompt_text, completion_text) in enumerate(zip(prompts, completions)): |
| try: |
| |
| |
| |
| user_content = prompt_text[1]['content'] |
|
|
| content_of_look_up_table = self.lookup_table.get(user_content) |
|
|
| if content_of_look_up_table is None: |
| |
| logging.warning(f"Prompt not found in lookup table. User content snippet: {user_content[:50]}...") |
| rewards.append(0.0) |
| continue |
|
|
| ground_truth_proof = content_of_look_up_table[0] |
| datasetName = content_of_look_up_table[1] |
|
|
| ground_truth = parse_proof(ground_truth_proof, datasetName) |
|
|
| |
| |
| |
| actual_text = completion_text |
| if isinstance(completion_text, list): |
| actual_text = completion_text[0]['content'] if isinstance(completion_text[0], dict) else completion_text[0] |
|
|
| predicted = extract_sentence_numbers(actual_text, datasetName) |
|
|
| |
| reward = 1.0 if predicted == ground_truth else 0.0 |
| rewards.append(reward) |
|
|
| if datasetName == 'UniADILR': |
| log_entry = { |
| 'epoch': self.current_epoch, |
| 'batch_idx': i, |
| 'dataset_name': datasetName, |
| 'input': prompt_text, |
| 'ground_truth': sorted(list(ground_truth)), |
| 'predicted': sorted(list(predicted)), |
| 'reward': reward, |
| 'completion': completion_text, |
| } |
| else: |
|     # balanced_copa_cause_only and any other dataset log the raw answers as-is |
|     log_entry = { |
|         'epoch': self.current_epoch, |
|         'batch_idx': i, |
|         'dataset_name': datasetName, |
|         'input': prompt_text, |
|         'ground_truth': ground_truth, |
|         'predicted': predicted, |
|         'reward': reward, |
|         'completion': completion_text, |
|     } |
| self.training_log.append(log_entry) |
|
|
| except Exception as e: |
| logging.exception(f"Error calculating metric for item {i}: {e}") |
| rewards.append(0.0) |
|
|
| |
| if len(self.training_log) > 0 and len(self.training_log) % self.log_every == 0: |
| try: |
| with open(self.output_path, 'w', encoding='utf-8') as f: |
| json.dump(self.training_log, f, ensure_ascii=False, indent=2) |
|
|
| recent_rewards = [r['reward'] for r in self.training_log[-self.log_every:]] |
| avg_reward = sum(recent_rewards) / len(recent_rewards) if recent_rewards else 0.0 |
| print(f" 💾 Saved {len(self.training_log)} completions log | Recent avg accuracy: {avg_reward:.3f}") |
| except Exception as e: |
| logging.warning(f"Failed to save validation log: {e}") |
|
|
| return rewards |
|
|
| def evaluate_batch(self, completions, record_ids, validation_dataset=None): |
| """Evaluate a batch of completions against ground truth. |
| |
| Args: |
| completions: List of model outputs |
| record_ids: List of indices into the dataset |
| validation_dataset: Optional validation dataset |
| |
| Returns: |
| List of dicts with reward, predicted, ground_truth, etc. |
| """ |
| results = [] |
|
|
| |
| dataset_to_use = validation_dataset if validation_dataset is not None else self.dataset |
|
|
| |
| try: |
| records = [dataset_to_use[i] for i in record_ids] |
| except (IndexError, TypeError) as e: |
| logging.error(f"Failed to fetch records for evaluation using record_ids. Error: {e}") |
| return [] |
|
|
| for idx, (completion, record) in enumerate(zip(completions, records)): |
| try: |
| ground_truth_numbers = record.get('ground_truth', '') |
| datasetName = record.get('dataset_name', '') |
| ground_truth = parse_proof(ground_truth_numbers, datasetName) |
|
|
| |
| predicted = extract_sentence_numbers(completion, datasetName) |
|
|
| |
| reward = 1.0 if predicted == ground_truth else 0.0 |
|
|
| |
| input_prompt = record.get('prompt', []) |
| user_content = "" |
| for msg in input_prompt: |
| if isinstance(msg, dict) and msg.get('role') == 'user': |
| user_content = msg.get('content', '') |
| break |
|
|
| |
| result_entry = { |
| 'reward': reward, |
| 'predicted': sorted(list(predicted)) if isinstance(predicted, set) else predicted, |
| 'ground_truth': sorted(list(ground_truth)) if isinstance(ground_truth, set) else ground_truth, |
| 'completion': completion, |
| 'input': user_content, |
| 'dataset_name': datasetName, |
| } |
| results.append(result_entry) |
|
|
| |
| log_entry = { |
| 'epoch': self.current_epoch, |
| 'record_id': record.get('record_id', idx), |
| 'dataset_name': datasetName, |
| 'input': user_content, |
| 'ground_truth': result_entry['ground_truth'], |
| 'predicted': result_entry['predicted'], |
| 'reward': reward, |
| 'completion': completion, |
| } |
|
|
| self.training_log.append(log_entry) |
|
|
| except Exception as e: |
| logging.exception(f"Error evaluating completion {idx}: {e}") |
| results.append({ |
| 'reward': 0.0, |
| 'predicted': [], |
| 'ground_truth': [], |
| 'completion': completion, |
| 'input': '', |
| 'dataset_name': datasetName, |
| }) |
|
|
| return results |
|
|
| |
| |
| reward_fn = AbductiveRewardFunction( |
| dataset=train_ds, |
| tokenizer=tokenizer, |
| output_path=os.path.join(results_dir, TRAINING_LOG_PATH), |
| log_every=LOG_TRAIN_EVERY |
| ) |
| reward_fn.__name__ = "AbductiveValidationScorer" |
|
|
| print(f"✅ Validation Metrics configured") |
| print(f" Type: Exact match (order-independent)") |
| print(f" Output file: {TRAINING_LOG_PATH}") |
| print(f" Log frequency: Every {LOG_TRAIN_EVERY} completions") |
|
|
|
|
| |
|
|
|
|
| from unsloth import is_bfloat16_supported |
| from collections import defaultdict |
|
|
| |
| print("\n⚙️ Training Configuration") |
| print("=" * 30) |
|
|
| |
| training_args = TrainingArguments( |
| learning_rate=LEARNING_RATE, |
| adam_beta1=ADAM_BETA1, |
| adam_beta2=ADAM_BETA2, |
| weight_decay=WEIGHT_DECAY, |
| warmup_steps=WARMUP_STEPS, |
| lr_scheduler_type=LR_SCHEDULER_TYPE, |
| dataloader_num_workers = 0, |
| optim=OPTIM, |
| logging_steps=1, |
|
|
| |
| save_strategy="epoch", |
| save_total_limit=None, |
| load_best_model_at_end=False, |
| |
| per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE, |
| gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| num_train_epochs=NUM_TRAIN_EPOCHS, |
| |
| max_grad_norm=MAX_GRAD_NORM, |
| report_to=["tensorboard"], |
| run_name=None, |
| output_dir=os.path.join(results_dir, "checkpoint"), |
|
|
| |
| fp16 = not is_bfloat16_supported(), |
| bf16 = is_bfloat16_supported(), |
| ) |
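| # Because save_strategy="epoch", the SAVE_STEPS constant defined at the top is not used by the |
| # Trainer itself; checkpoints are written once per epoch into output_dir. |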
|
|
| print(f"Training Parameters:") |
| print(f" Learning rate: {LEARNING_RATE}") |
| print(f" Batch size: {PER_DEVICE_TRAIN_BATCH_SIZE}") |
| print(f" Epochs: {NUM_TRAIN_EPOCHS:,}") |
| print(f" Save every: {SAVE_STEPS} steps") |
| print(f" Max grad norm: {MAX_GRAD_NORM}") |
| |
| print(f" Warmup steps: {WARMUP_STEPS}") |
| print(f" Weight decay: {WEIGHT_DECAY}") |
|
|
|
|
| |
|
|
|
|
| from transformers import DataCollatorWithPadding |
| from tqdm.auto import tqdm |
| from collections import defaultdict |
|
|
| |
| try: |
| from vllm import SamplingParams |
| except ImportError: |
| class SamplingParams: |
| def __init__(self, temperature, top_p, max_tokens): |
| self.temperature = temperature |
| self.top_p = top_p |
| self.max_tokens = max_tokens |
|
|
| print("\n🔄 Setting up Training Callbacks with Validation") |
| print("=" * 45) |
|
|
| |
| sampling_params = SamplingParams( |
| temperature=TEMPERATURE,  # 0.0 is treated as greedy decoding in the generation branch below |
| top_p=1.0, |
| max_tokens=MAX_COMPLETION_LENGTH, |
| ) |
|
|
| class EnhancedEpochCallback(TrainerCallback): |
| """ |
| Custom callback to log epoch progress, manage metrics, and handle validation. |
| - Logs start and end of each epoch. |
| - Records step losses for the metric function. |
| - Triggers validation at the end of each epoch and after every EVAL_STEPS steps. |
| """ |
| def __init__(self, reward_fn, val_dataset, results_dir, use_vllm=False, eval_interval=EVAL_STEPS): |
| self.reward_fn = reward_fn |
| self.val_dataset = val_dataset |
| self.step_count = 0 |
| self.start_time = None |
| self.validation_metrics = {} |
| self.results_dir = results_dir |
| self.trainer = None |
| self.formatted_inputs = None |
| self.use_vllm = use_vllm |
| self.data_collator = DataCollatorWithPadding(tokenizer=tokenizer) |
| self.eval_interval = eval_interval |
| |
| |
| self.best_accuracy = -1.0 |
|
|
| def on_train_begin(self, args, state, control, **kwargs): |
| self.start_time = time.time() |
| print(f"🚀 Training started at {time.strftime('%Y-%m-%d %H:%M:%S')}") |
| |
| |
| |
| |
| print(" Preparing validation prompts (stripping assistant answers)...") |
| val_prompts_input_only = [] |
| for conversation in self.val_dataset['prompt']: |
| |
| input_msgs = [msg for msg in conversation if msg['role'] != 'assistant'] |
| val_prompts_input_only.append(input_msgs) |
|
|
| self.formatted_inputs = self.trainer.processing_class.apply_chat_template( |
| val_prompts_input_only, |
| tokenize=False, |
| add_generation_prompt=True |
| ) |
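|         # apply_chat_template(..., add_generation_prompt=True) returns plain strings ending with |
|         # the assistant header (e.g. "<|im_start|>assistant\n" for ChatML-style templates such as |
|         # Qwen's), so generation starts exactly where the answer should begin. |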
|
|
| def on_epoch_begin(self, args, state, control, **kwargs): |
| epoch_idx = int(state.epoch) + 1 |
| self.reward_fn.set_epoch(epoch_idx) |
| print(f"\n📍 Starting epoch {epoch_idx}") |
|
|
| def on_step_end(self, args, state, control, **kwargs): |
| current_loss = 'N/A' |
| if state.log_history: |
| current_loss = state.log_history[-1].get("loss", 'N/A') |
| if current_loss != 'N/A': |
| self.reward_fn.record_loss(state.log_history[-1]['step'], current_loss) |
| |
| self.step_count += 1 |
| if self.step_count % 50 == 0: |
| elapsed = time.time() - self.start_time |
| steps_per_sec = self.step_count / elapsed |
| print(f" Step {self.step_count} | Loss: {current_loss} | Speed: {steps_per_sec:.2f} steps/s") |
|
|
| if ( |
| LOG_VALIDATION |
| and self.eval_interval |
| and (self.step_count % self.eval_interval == 0) |
| and self.trainer |
| ): |
| self.evaluate_validation( |
| self.trainer.model, |
| self.trainer.processing_class, |
| state.global_step, |
| ) |
|
|
| def evaluate_validation(self, model, tokenizer, step, epoch_override=None): |
| print(f"\n🔍 Validation at step {step}:") |
|
|
| try: |
| val_rewards = [] |
| validation_log = [] |
| |
| |
| dataset_stats = defaultdict(list) |
| |
| batch_size = PER_DEVICE_EVAL_BATCH_SIZE |
| |
| |
| total_samples = len(self.val_dataset) |
| batch_iterator = range(0, total_samples, batch_size) |
|
|
| |
| if self.formatted_inputs is None: |
| val_prompts_input_only = [] |
| for conversation in self.val_dataset['prompt']: |
| input_msgs = [msg for msg in conversation if msg['role'] != 'assistant'] |
| val_prompts_input_only.append(input_msgs) |
| self.formatted_inputs = tokenizer.apply_chat_template( |
| val_prompts_input_only, tokenize=False, add_generation_prompt=True |
| ) |
|
|
| with torch.no_grad(): |
| for batch_num in tqdm(batch_iterator, desc=" Generating & Evaluating", unit="batch", leave=False): |
| FastLanguageModel.for_inference(model) |
| batch = self.formatted_inputs[batch_num:batch_num + batch_size] |
|
|
| if self.use_vllm: |
| outputs = model.fast_generate( |
| batch, |
| lora_request=None, |
| sampling_params=sampling_params, |
| ) |
| completions = [o.outputs[0].text.strip() for o in outputs] |
| else: |
| batch_encodings = tokenizer(batch, return_tensors="pt", padding=True).to(model.device) |
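|                     # Note: PER_DEVICE_EVAL_BATCH_SIZE is 1 here, so right-padding is harmless; for |
|                     # larger eval batches, decoder-only generation would need tokenizer.padding_side = "left". |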
| |
| |
| gen_kwargs = { |
| "max_new_tokens": sampling_params.max_tokens, |
| } |
| |
| |
| if sampling_params.temperature < 1e-5: |
| gen_kwargs["do_sample"] = False |
| else: |
| gen_kwargs["do_sample"] = True |
| gen_kwargs["temperature"] = sampling_params.temperature |
| gen_kwargs["top_p"] = sampling_params.top_p |
|
|
| outputs = model.generate(**batch_encodings, **gen_kwargs) |
| |
| prompt_lengths = batch_encodings["input_ids"].shape[1] |
| generated_tokens = outputs[:, prompt_lengths:] |
| completions = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) |
|
|
| batch_indices = list(range(batch_num, batch_num + len(completions))) |
| results = self.reward_fn.evaluate_batch(completions, batch_indices, validation_dataset=self.val_dataset) |
|
|
| for batch_idx, result in enumerate(results): |
| reward = result["reward"] |
| ds_name = result["dataset_name"] |
| |
| val_rewards.append(reward) |
| dataset_stats[ds_name].append(reward) |
| |
| validation_log.append({ |
| "record_id": self.val_dataset['record_id'][batch_num + batch_idx], |
| "dataset_name": ds_name, |
| "input": result.get("input", ""), |
| "ground_truth": result["ground_truth"], |
| "predicted": result["predicted"], |
| "reward": reward, |
| "completion": result["completion"], |
| }) |
|
|
| FastLanguageModel.for_training(model) |
|
|
| if val_rewards: |
| avg_val_reward = sum(val_rewards) / len(val_rewards) |
|
|
| |
| print(f" 📊 Overall Accuracy: {avg_val_reward:.4f}") |
| print(" 📈 Breakdown by Dataset:") |
| for name, scores in dataset_stats.items(): |
| avg_score = sum(scores) / len(scores) |
| print(f" • {name}: {avg_score:.4f} (n={len(scores)})") |
|
|
| |
| if avg_val_reward > self.best_accuracy: |
| self.best_accuracy = avg_val_reward |
| print(f" 🌟 New Best Accuracy: {self.best_accuracy:.4f}") |
| |
| if self.trainer: |
| best_model_path = os.path.join(self.results_dir, "best_model") |
| print(f" 💾 Saving best model to: {best_model_path}") |
| self.trainer.save_model(best_model_path) |
| |
| if tokenizer: |
| tokenizer.save_pretrained(best_model_path) |
|
|
| |
| if epoch_override is not None: |
| epoch_key = str(epoch_override) |
| elif self.trainer: |
| epoch_key = str(round(self.trainer.state.epoch))  # round, not int(): state.epoch can sit just below an integer at epoch boundaries |
| else: |
| epoch_key = "0" |
|
|
| self.validation_metrics[epoch_key] = { |
| 'avg_reward': avg_val_reward, |
| 'num_samples': len(val_rewards), |
| 'breakdown': {k: sum(v)/len(v) for k, v in dataset_stats.items()} |
| } |
|
|
| |
| val_log_path = os.path.join(self.results_dir, VALIDATION_LOG_PATH) |
| existing_data = {} |
| if os.path.exists(val_log_path): |
| with open(val_log_path, "r", encoding="utf-8") as f: |
| existing_data = json.load(f) |
|
|
| existing_data[epoch_key] = validation_log |
| with open(val_log_path, "w", encoding="utf-8") as f: |
| json.dump(existing_data, f, ensure_ascii=False, indent=2) |
|
|
| |
| val_metrics_path = os.path.join(self.results_dir, VALIDATION_METRICS_PATH) |
| all_metrics = {} |
| if os.path.exists(val_metrics_path): |
| with open(val_metrics_path, "r", encoding="utf-8") as f: |
| all_metrics = json.load(f) |
|
|
| all_metrics[epoch_key] = self.validation_metrics[epoch_key] |
| with open(val_metrics_path, "w", encoding="utf-8") as f: |
| json.dump(all_metrics, f, ensure_ascii=False, indent=2) |
|
|
| |
| try: |
| with open(self.reward_fn.output_path, 'w', encoding='utf-8') as f: |
| json.dump(self.reward_fn.training_log, f, ensure_ascii=False, indent=2) |
| except Exception as e: |
| logging.warning(f"Failed to save training log: {e}") |
| else: |
| logging.warning(f"⚠️ No validation rewards computed. Step: {step}") |
|
|
| except Exception as e: |
| logging.exception(f"❌ Validation error: {e}") |
|
|
| def on_epoch_end(self, args, state, control, **kwargs): |
| completed_epoch_idx = int(state.epoch) |
| print(f"✅ Completed epoch {completed_epoch_idx}") |
|
|
| |
| if LOG_VALIDATION: |
| if self.trainer: |
| |
| self.evaluate_validation(self.trainer.model, self.trainer.processing_class, state.global_step) |
| else: |
| logging.warning("⚠️ No trainer assigned; cannot evaluate validation.") |
|
|
| def on_save(self, args, state, control, **kwargs): |
| print(f"💾 Checkpoint saved at step {state.global_step}") |
|
|
| |
| enhanced_callback = EnhancedEpochCallback( |
| reward_fn=reward_fn, |
| val_dataset=val_ds, |
| results_dir=results_dir, |
| use_vllm=USE_VLLM, |
| |
| ) |
|
|
| print("✅ Enhanced callbacks configured:") |
| print(" - Epoch management") |
| print(" - Progress tracking with loss") |
| print(" - Validation evaluation (Assistant answers stripped for generation)") |
| print(" - Validation JSON logging") |
| print(" - Checkpoint notifications") |
| print(f" - Validation every {EVAL_STEPS} steps") |
|
|
|
|
| |
|
|
|
|
| |
| print("\n🏗️ Creating Trainer with Validation") |
| print("=" * 35) |
|
|
| |
| |
| def formatting_prompts_func(example): |
| """ |
| Format prompts for SFT training. |
| |
| Unsloth calls this with a SINGLE example during validation check, |
| and with batched examples during training. |
| Must ALWAYS return a list of strings. |
| """ |
| convos = example["prompt"] |
| texts = [] |
|
|
| |
| if not convos: |
| return [] |
|
|
| |
| |
| |
| |
| if isinstance(convos[0], dict): |
| |
| text = tokenizer.apply_chat_template(convos, tokenize=False, add_generation_prompt=False) |
| texts.append(text) |
| else: |
| |
| for convo in convos: |
| text = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) |
| texts.append(text) |
|
|
| return texts |
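| # The returned strings are full chat transcripts rendered by the tokenizer's chat template. |
| # For ChatML-style models such as Qwen2.5, each text roughly looks like (illustrative): |
| #   <|im_start|>system\n{system prompt}<|im_end|>\n<|im_start|>user\n{task}<|im_end|>\n |
| #   <|im_start|>assistant\n<think>...</think>\n\n<answer>...</answer><|im_end|>\n |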
|
|
| try: |
| |
| trainer = SFTTrainer( |
| model=model, |
| tokenizer=tokenizer, |
| train_dataset=train_ds, |
| eval_dataset=val_ds, |
| formatting_func=formatting_prompts_func, |
| args=training_args, |
| packing=False, |
| max_seq_length=MAX_SEQ_LENGTH, |
| ) |
|
|
| enhanced_callback.trainer = trainer |
| trainer.add_callback(enhanced_callback) |
|
|
| print("✅ Trainer created successfully!") |
| print(f" Platform: Local") |
| print(f" Model: {type(model).__name__}") |
| print(f" Training samples: {len(train_ds):,}") |
| print(f" Validation samples: {len(val_ds):,}") |
| print(f" Callbacks: {len(trainer.callback_handler.callbacks)}") |
|
|
| except Exception as e: |
| logging.exception(f"❌ Failed to create trainer: {e}") |
| raise |
|
|
| print(f"\n📋 Training Summary:") |
| print(f" Total training epochs: {NUM_TRAIN_EPOCHS}") |
| print(f" Effective batch size: {PER_DEVICE_TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}") |
| print(f" Gradient accumulation: {GRADIENT_ACCUMULATION_STEPS}") |
| print(f" Output directory: {results_dir}") |
|
|
|
|
| |
|
|
|
|
| import sys |
| from datetime import datetime |
| import signal |
|
|
| |
| # force=True is required here: logging.basicConfig was already called earlier (error log), |
| # and without it this second configuration would be silently ignored. |
| logging.basicConfig( |
|     level=logging.INFO, |
|     format='%(asctime)s - %(levelname)s - %(message)s', |
|     handlers=[ |
|         logging.FileHandler('training_log.log'), |
|         logging.StreamHandler(sys.stdout) |
|     ], |
|     force=True, |
| ) |
|
|
| |
| from transformers import TrainerCallback |
| import math |
|
|
| class DetailedProgressCallback(TrainerCallback): |
| def __init__(self): |
| self.start_time = time.time() |
| self.step_times = [] |
| self.last_log_time = time.time() |
|
|
| def on_step_begin(self, args, state, control, **kwargs): |
| """Called at the beginning of each training step""" |
| current_time = time.time() |
| |
| if state.global_step % 10 == 0 or (current_time - self.last_log_time) > 30: |
| elapsed = current_time - self.start_time |
| steps_per_sec = state.global_step / elapsed if elapsed > 0 else 0 |
|
|
| |
| remaining_steps = state.max_steps - state.global_step |
| eta_seconds = remaining_steps / steps_per_sec if steps_per_sec > 0 else 0 |
| eta_str = time.strftime('%H:%M:%S', time.gmtime(eta_seconds)) |
|
|
| progress_pct = (state.global_step / state.max_steps) * 100 |
|
|
| print(f"\r⏳ Step {state.global_step}/{state.max_steps} ({progress_pct:.1f}%) | " |
| f"Speed: {steps_per_sec:.2f} steps/s | ETA: {eta_str} | " |
| f"Epoch: {state.epoch:.1f}", end='', flush=True) |
|
|
| self.last_log_time = current_time |
|
|
| def on_log(self, args, state, control, logs=None, **kwargs): |
| """Called when logging occurs""" |
| if logs: |
| print() |
| log_items = [] |
| for k, v in logs.items(): |
| if k == 'epoch': continue |
| if isinstance(v, float): |
| if 'learning_rate' in k or 'lr' in k: |
| |
| val_str = f"{v:.2e}" |
| else: |
| |
| val_str = f"{v:.4f}" |
| else: |
| val_str = f"{v}" |
| log_items.append(f"{k}: {val_str}") |
|
|
| log_str = " | ".join(log_items) |
| print(f"📊 {log_str}") |
| logging.info(log_str) |
|
|
| def on_epoch_end(self, args, state, control, **kwargs): |
| """Called at the end of each epoch""" |
| print() |
| elapsed = time.time() - self.start_time |
| print(f"\n✅ Epoch {int(state.epoch)} completed | " |
| f"Total time: {elapsed/60:.1f}m | " |
| f"Steps: {state.global_step}/{state.max_steps}") |
| logging.info(f"Epoch {int(state.epoch)} completed") |
|
|
| def on_train_begin(self, args, state, control, **kwargs): |
| """Called at the start of training""" |
| print(f"\n🎯 SFT Training will run for {state.max_steps} steps") |
| print(f"📝 Logging every {args.logging_steps} steps") |
| print(f"💾 Saving checkpoints every {args.save_steps} steps") |
| print("-" * 70) |
| logging.info("SFT Training started") |
|
|
| |
| progress_callback = DetailedProgressCallback() |
| trainer.add_callback(progress_callback) |
|
|
| |
| def signal_handler(sig, frame): |
| print("\n⚠️ Interrupt signal received. Saving progress...") |
| logging.warning("Training interrupted by user") |
| trainer.save_model(os.path.join(results_dir, "checkpoint", "interrupted")) |
| sys.exit(0) |
|
|
| signal.signal(signal.SIGINT, signal_handler) |
|
|
| |
| print("\n🚀 Starting SFT Training") |
| print("=" * 70) |
| print(f"⏰ Start time: {time.strftime('%Y-%m-%d %H:%M:%S')}") |
| print(f"🏷️ Run name: {run_name}") |
| print(f"📁 Output directory: {results_dir}") |
| print(f"🔍 Logs will be saved to: training_log.log") |
| print("-" * 70) |
|
|
| |
| logging.info(f"Starting SFT training run: {run_name}") |
| logging.info(f"Output directory: {results_dir}") |
| logging.info(f"Training config: epochs={NUM_TRAIN_EPOCHS}, batch_size={PER_DEVICE_TRAIN_BATCH_SIZE}") |
|
|
| training_start_time = time.time() |
| last_checkpoint_time = training_start_time |
|
|
| try: |
| |
| print("🔍 Verifying trainer configuration...") |
| print(f" • Total training steps: {trainer.args.max_steps}") |
| print(f" • Steps per epoch: {len(trainer.get_train_dataloader())}") |
| print(f" • Logging interval: {trainer.args.logging_steps} steps") |
| print(f" • Save interval: {trainer.args.save_steps} steps") |
| print() |
|
|
| |
| print("\n🔍 Running Baseline Evaluation (Pre-training)...") |
| print(" This measures zero-shot performance before any updates.") |
| print("=" * 60, end='') |
| enhanced_callback.evaluate_validation(model, tokenizer, step=0, epoch_override=0) |
| print("=" * 60) |
|
|
| |
| sys.stdout.flush() |
| logging.info("Calling trainer.train()...") |
|
|
| |
| print("🎬 Initiating training loop...\n") |
| trainer.train(resume_from_checkpoint=RESUME_FROM_CHECKPOINT) |
|
|
| training_end_time = time.time() |
| training_duration = training_end_time - training_start_time |
|
|
| print("\n" + "="*70) |
| print("🎉 SFT TRAINING COMPLETED SUCCESSFULLY!") |
| print("="*70) |
| print(f"⏱️ Duration: {training_duration/3600:.2f} hours ({training_duration/60:.1f} minutes)") |
| print(f"📈 Average time per epoch: {training_duration/NUM_TRAIN_EPOCHS/60:.2f} minutes") |
| print(f"🏁 Completed at: {time.strftime('%Y-%m-%d %H:%M:%S')}") |
| logging.info(f"SFT Training completed successfully in {training_duration/3600:.2f} hours") |
|
|
| except KeyboardInterrupt: |
| print("\n\n⚠️ Training interrupted by user") |
| logging.warning("Training interrupted by user (KeyboardInterrupt)") |
| print("💾 Saving current progress...") |
|
|
| except Exception as e: |
| print(f"\n\n❌ Training failed with error!") |
| print(f"Error type: {type(e).__name__}") |
| print(f"Error message: {str(e)}") |
| print("\n📋 Full traceback:") |
| logging.exception(f"Training failed with error: {e}") |
| import traceback |
| traceback.print_exc() |
| raise |
|
|
| finally: |
| training_end_time = time.time() |
| actual_duration = training_end_time - training_start_time |
|
|
| print("\n" + "="*70) |
| print("🔄 Cleanup and saving...") |
| print("="*70) |
|
|
| |
| try: |
| |
| if reward_fn and hasattr(reward_fn, 'training_log') and reward_fn.training_log: |
| try: |
| log_path = os.path.join(results_dir, "training_rewards.json") |
| with open(log_path, 'w', encoding='utf-8') as f: |
| json.dump(reward_fn.training_log, f, ensure_ascii=False, indent=2) |
| print(f"✅ Metric/Validation log saved: {len(reward_fn.training_log)} entries") |
| logging.info(f"Saved metric log with {len(reward_fn.training_log)} entries") |
| except Exception as e: |
| print(f"⚠️ Failed to save metric log: {e}") |
| logging.warning(f"Failed to save metric log: {e}") |
|
|
| |
| final_model_path = os.path.join(results_dir, "checkpoint", "final_model") |
| os.makedirs(final_model_path, exist_ok=True) |
| trainer.save_model(final_model_path) |
| print(f"✅ Model saved to: {final_model_path}") |
| logging.info(f"Final model saved to: {final_model_path}") |
|
|
| print(f"\n⏱️ Total elapsed time: {actual_duration/60:.1f} minutes") |
| print("="*70) |
|
|
| except Exception as e: |
| print(f"⚠️ Error during cleanup: {e}") |
| logging.exception("Error during cleanup") |
|
|
|
|
| |
|
|
|
|
| |
| print("\n📊 Training Visualization") |
| print("=" * 30) |
|
|
| import matplotlib.pyplot as plt |
| import pandas as pd |
| import seaborn as sns |
|
|
| |
| val_metrics_path = os.path.join(results_dir, VALIDATION_METRICS_PATH) |
| if os.path.exists(val_metrics_path): |
| with open(val_metrics_path, 'r') as f: |
| val_metrics = json.load(f) |
|
|
| epochs = [float(k) for k in val_metrics.keys()] |
| |
| accuracies = [v['avg_reward'] for v in val_metrics.values()] |
|
|
| plt.figure(figsize=(10, 6)) |
| plt.plot(epochs, accuracies, marker='o', linewidth=2, markersize=8, color='#2ca02c') |
| plt.xlabel('Epoch', fontsize=12) |
| plt.ylabel('Validation Accuracy (Exact Match)', fontsize=12) |
| plt.title('SFT Validation Performance', fontsize=14) |
| plt.grid(True, alpha=0.3) |
| plt.tight_layout() |
|
|
| plot_path = os.path.join(results_dir, 'validation_accuracy.png') |
| plt.savefig(plot_path, dpi=300) |
| print(f"✅ Validation accuracy plot saved to: {plot_path}") |
| plt.show() |
| else: |
| print("⚠️ No validation metrics found") |
|
|
| |
| |
| if hasattr(trainer, 'state') and trainer.state.log_history: |
| log_history = trainer.state.log_history |
|
|
| |
| steps = [] |
| losses = [] |
| for entry in log_history: |
| if 'loss' in entry and 'step' in entry: |
| steps.append(entry['step']) |
| losses.append(entry['loss']) |
|
|
| if steps: |
| plt.figure(figsize=(10, 6)) |
| plt.plot(steps, losses, linewidth=2, color='#1f77b4') |
| plt.xlabel('Training Steps', fontsize=12) |
| plt.ylabel('Cross Entropy Loss', fontsize=12) |
| plt.title('SFT Training Loss', fontsize=14) |
| plt.grid(True, alpha=0.3) |
| plt.tight_layout() |
|
|
| loss_plot_path = os.path.join(results_dir, 'training_loss.png') |
| plt.savefig(loss_plot_path, dpi=300) |
| print(f"✅ Training loss plot saved to: {loss_plot_path}") |
| plt.show() |
| else: |
| print("⚠️ No loss history found in trainer state") |
|
|
| |
| |
| training_log_path = os.path.join(results_dir, TRAINING_LOG_PATH) |
| if os.path.exists(training_log_path): |
| try: |
| with open(training_log_path, 'r') as f: |
| logs = json.load(f) |
|
|
| |
| df = pd.DataFrame(logs) |
|
|
| if 'dataset_name' in df.columns and 'reward' in df.columns: |
| |
| accuracy_by_dataset = df.groupby('dataset_name')['reward'].mean().reset_index() |
| accuracy_by_dataset.columns = ['Dataset', 'Accuracy'] |
|
|
| plt.figure(figsize=(10, 6)) |
| sns.barplot(data=accuracy_by_dataset, x='Dataset', y='Accuracy', palette='viridis') |
| plt.ylim(0, 1.0) |
| plt.title('Overall Accuracy by Dataset Type', fontsize=14) |
| plt.ylabel('Average Accuracy', fontsize=12) |
| plt.tight_layout() |
|
|
| breakdown_path = os.path.join(results_dir, 'accuracy_by_dataset.png') |
| plt.savefig(breakdown_path, dpi=300) |
| print(f"✅ Dataset breakdown plot saved to: {breakdown_path}") |
| plt.show() |
|
|
| |
| print("\n📊 Performance Breakdown:") |
| for index, row in accuracy_by_dataset.iterrows(): |
| print(f" • {row['Dataset']}: {row['Accuracy']*100:.2f}%") |
|
|
| except Exception as e: |
| print(f"⚠️ Could not create dataset breakdown: {e}") |
|
|
|
|
| |
|
|
|
|
| import os |
| import json |
| import matplotlib.pyplot as plt |
| import seaborn as sns |
| import pandas as pd |
| import numpy as np |
|
|
| print("\n📊 Advanced Visualization (Adapted for SFT)") |
| print("=" * 60) |
|
|
| plots_dir = os.path.join(results_dir, "plots") |
| os.makedirs(plots_dir, exist_ok=True) |
|
|
| |
| val_metrics_path = os.path.join(results_dir, VALIDATION_METRICS_PATH) |
| training_log_path = os.path.join(results_dir, TRAINING_LOG_PATH) |
| val_log_path = os.path.join(results_dir, VALIDATION_LOG_PATH) |
|
|
| |
| if os.path.exists(val_metrics_path): |
| with open(val_metrics_path, 'r') as f: |
| metrics = json.load(f) |
|
|
| epochs = sorted([int(k) for k in metrics.keys()]) |
|
|
| |
| plot_data = [] |
| for ep in epochs: |
| ep_key = str(ep) |
| if 'breakdown' in metrics[ep_key]: |
| for ds_name, score in metrics[ep_key]['breakdown'].items(): |
| plot_data.append({'Epoch': ep, 'Accuracy': score, 'Dataset': ds_name}) |
| |
| plot_data.append({'Epoch': ep, 'Accuracy': metrics[ep_key]['avg_reward'], 'Dataset': 'Overall'}) |
|
|
| if plot_data: |
| df_acc = pd.DataFrame(plot_data) |
| plt.figure(figsize=(10, 6)) |
| sns.lineplot(data=df_acc, x='Epoch', y='Accuracy', hue='Dataset', marker='o', linewidth=2) |
| plt.title('Validation Accuracy per Dataset over Epochs') |
| plt.ylabel('Accuracy (Exact Match)') |
| plt.grid(True, alpha=0.3) |
| plt.ylim(0, 1.05) |
| plt.tight_layout() |
| plt.savefig(os.path.join(plots_dir, "accuracy_breakdown.png"), dpi=300) |
| print("✅ Saved per-dataset accuracy plot") |
| plt.close() |
|
|
| |
| |
| def plot_sft_transitions(val_log_path, output_dir): |
| if not os.path.exists(val_log_path): return |
|
|
| with open(val_log_path, 'r') as f: |
| val_data = json.load(f) |
|
|
| |
| history = defaultdict(dict) |
| all_epochs = sorted([int(k) for k in val_data.keys()]) |
|
|
| for ep_str, records in val_data.items(): |
| epoch = int(ep_str) |
| for rec in records: |
| rid = rec.get('record_id') |
| |
| history[rid][epoch] = 1 if rec['reward'] > 0.5 else 0 |
|
|
| |
| |
| |
|     if len(all_epochs) < 2: |
|         print("⚠️ Need at least two evaluated epochs to plot transitions.") |
|         return |
|     first_ep = all_epochs[0] |
|     last_ep = all_epochs[-1] |
|
|
| gained = 0 |
| lost = 0 |
| stable_correct = 0 |
| stable_wrong = 0 |
|
|
| for rid, eps in history.items(): |
| if first_ep in eps and last_ep in eps: |
| start = eps[first_ep] |
| end = eps[last_ep] |
|
|
| if start == 0 and end == 1: gained += 1 |
| elif start == 1 and end == 0: lost += 1 |
| elif start == 1 and end == 1: stable_correct += 1 |
| elif start == 0 and end == 0: stable_wrong += 1 |
|
|
| |
| categories = ['Gained (Learned)', 'Lost (Forgetting)', 'Stable Correct', 'Stable Wrong'] |
| values = [gained, lost, stable_correct, stable_wrong] |
| colors = ['#2ca02c', '#d62728', '#1f77b4', 'gray'] |
|
|
| plt.figure(figsize=(8, 6)) |
| bars = plt.bar(categories, values, color=colors) |
| plt.bar_label(bars) |
| plt.title(f'Learning Dynamics (Epoch {first_ep} vs {last_ep})') |
| plt.ylabel('Number of Validation Samples') |
| plt.grid(axis='y', alpha=0.3) |
| plt.tight_layout() |
| plt.savefig(os.path.join(output_dir, "learning_transitions.png"), dpi=300) |
| print("✅ Saved learning transitions plot") |
| plt.close() |
|
|
| plot_sft_transitions(val_log_path, plots_dir) |
|
|
| |
| if hasattr(trainer, 'state') and trainer.state.log_history: |
| log_history = trainer.state.log_history |
| steps = [] |
| losses = [] |
| for x in log_history: |
| if 'loss' in x: |
| steps.append(x['step']) |
| losses.append(x['loss']) |
|
|
| if steps: |
| plt.figure(figsize=(10, 5)) |
| |
| plt.plot(steps, losses, alpha=0.3, color='blue', label='Raw Loss') |
| |
| if len(losses) > 10: |
| avg_loss = pd.Series(losses).rolling(window=10).mean() |
| plt.plot(steps, avg_loss, color='blue', linewidth=2, label='Smoothed (MA-10)') |
|
|
| plt.title('Training Loss Curve') |
| plt.xlabel('Steps') |
| plt.ylabel('Cross Entropy Loss') |
| plt.legend() |
| plt.grid(True, alpha=0.3) |
| plt.savefig(os.path.join(plots_dir, "training_loss_detailed.png"), dpi=300) |
| print("✅ Saved detailed training loss plot") |
| plt.close() |
|
|
| print(f"📊 All plots saved to: {plots_dir}") |
|
|
|
|
| |
|
|
|
|
| import os |
| import json |
| import matplotlib.pyplot as plt |
| import seaborn as sns |
| import pandas as pd |
| from collections import defaultdict |
|
|
| print("\n📊 Generating Advanced Plots (Per Epoch Comparison & Transitions)") |
| print("=" * 60) |
|
|
| |
| plots_dir = os.path.join(results_dir, "plots") |
| os.makedirs(plots_dir, exist_ok=True) |
|
|
| |
| val_log_path = os.path.join(results_dir, VALIDATION_LOG_PATH) |
|
|
| if os.path.exists(val_log_path): |
| with open(val_log_path, 'r', encoding='utf-8') as f: |
| val_data = json.load(f) |
|
|
| |
| |
| metrics_data = [] |
| |
| |
| epochs = sorted([int(k) for k in val_data.keys()]) |
| |
| for ep in epochs: |
| ep_str = str(ep) |
| records = val_data[ep_str] |
| |
| |
| ds_scores = defaultdict(list) |
| all_scores = [] |
| |
| for rec in records: |
| score = 1.0 if rec['reward'] > 0.5 else 0.0 |
| ds_name = rec.get('dataset_name', 'Unknown') |
| ds_scores[ds_name].append(score) |
| all_scores.append(score) |
| |
| |
| metrics_data.append({ |
| 'Epoch': ep, |
| 'Dataset': 'Overall', |
| 'Accuracy': sum(all_scores)/len(all_scores) if all_scores else 0 |
| }) |
| |
| for name, scores in ds_scores.items(): |
| metrics_data.append({ |
| 'Epoch': ep, |
| 'Dataset': name, |
| 'Accuracy': sum(scores)/len(scores) if scores else 0 |
| }) |
|
|
| |
| if metrics_data: |
| df_metrics = pd.DataFrame(metrics_data) |
| plt.figure(figsize=(10, 6)) |
| |
| sns.lineplot(data=df_metrics, x='Epoch', y='Accuracy', hue='Dataset', style='Dataset', markers=True, dashes=False, linewidth=2.5) |
| plt.title('Validation Accuracy Comparison per Epoch', fontsize=14) |
| plt.ylabel('Accuracy (Exact Match)', fontsize=12) |
| plt.xlabel('Epoch', fontsize=12) |
| plt.grid(True, alpha=0.3) |
| plt.ylim(-0.05, 1.05) |
| plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left') |
| plt.tight_layout() |
| plt.savefig(os.path.join(plots_dir, "accuracy_comparison_per_epoch.png"), dpi=300) |
| print("✅ Saved accuracy comparison plot") |
| plt.show() |
|
|
| |
| |
| |
| |
| |
| if len(epochs) >= 2: |
| first_ep = str(epochs[0]) |
| last_ep = str(epochs[-1]) |
| |
| |
| history = defaultdict(dict) |
| for ep in [first_ep, last_ep]: |
| for rec in val_data[ep]: |
| rid = rec.get('record_id') |
| |
| is_correct = 1 if rec['reward'] > 0.5 else 0 |
| history[rid][ep] = is_correct |
| |
| |
| transitions = { |
| 'Gained (Learned)': 0, |
| 'Lost (Forgot)': 0, |
| 'Stable Correct': 0, |
| 'Stable Wrong': 0 |
| } |
| |
| for rid, eps in history.items(): |
| if first_ep in eps and last_ep in eps: |
| start = eps[first_ep] |
| end = eps[last_ep] |
| |
| if start == 0 and end == 1: transitions['Gained (Learned)'] += 1 |
| elif start == 1 and end == 0: transitions['Lost (Forgot)'] += 1 |
| elif start == 1 and end == 1: transitions['Stable Correct'] += 1 |
| elif start == 0 and end == 0: transitions['Stable Wrong'] += 1 |
|
|
| |
| plt.figure(figsize=(9, 6)) |
| |
| colors = ['#2ca02c', '#d62728', '#1f77b4', '#7f7f7f'] |
| bars = plt.bar(transitions.keys(), transitions.values(), color=colors, alpha=0.8) |
| |
| |
| for bar in bars: |
| height = bar.get_height() |
| plt.text(bar.get_x() + bar.get_width()/2., height, |
| f'{int(height)}', |
| ha='center', va='bottom') |
| |
| plt.title(f'Learning Transitions (Epoch {first_ep} → {last_ep})', fontsize=14) |
| plt.ylabel('Number of Samples', fontsize=12) |
| plt.grid(axis='y', alpha=0.3) |
| plt.tight_layout() |
| plt.savefig(os.path.join(plots_dir, "learning_transitions.png"), dpi=300) |
| print("✅ Saved learning transitions plot") |
| plt.show() |
| else: |
| print("⚠️ Need at least 2 epochs to plot transitions.") |
|
|
| else: |
| print(f"❌ Validation log not found at: {val_log_path}") |
|
|
|
|
| |
|
|
|
|
| |
| |
| if hasattr(trainer, 'state') and trainer.state.log_history: |
| log_hist = trainer.state.log_history |
| |
| |
| train_steps = [x['step'] for x in log_hist if 'loss' in x] |
| train_loss = [x['loss'] for x in log_hist if 'loss' in x] |
| |
| |
| val_steps = [x['step'] for x in log_hist if 'eval_loss' in x] |
| val_loss = [x['eval_loss'] for x in log_hist if 'eval_loss' in x] |
|
|
| if train_steps and val_steps: |
| plt.figure(figsize=(10, 6)) |
| plt.plot(train_steps, train_loss, label='Training Loss', color='blue', alpha=0.6) |
| plt.plot(val_steps, val_loss, label='Validation Loss', color='red', marker='o') |
| plt.xlabel('Steps') |
| plt.ylabel('Loss') |
| plt.title('Training vs Validation Loss') |
| plt.legend() |
| plt.grid(True, alpha=0.3) |
| plt.savefig(os.path.join(results_dir, "loss_comparison.png")) |
| print("✅ Saved loss comparison plot") |
| plt.show() |
|
|
|
|