""" Fine-tuning Script for LFM2-2.6B with Complete Dialogue History Following KokoroChat methodology - uses entire conversation context Filename: finetune_lfm_complete_history.py """ import torch from transformers import ( AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling, BitsAndBytesConfig, TrainerCallback ) from peft import ( LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType, PeftModel, PeftConfig ) from datasets import load_dataset, Dataset import os from typing import Dict, List, Optional import numpy as np from tqdm import tqdm import json import gc import warnings import wandb from datetime import datetime warnings.filterwarnings('ignore') # Enable TF32 for H100 optimization torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True class LFMKokoroChatFineTuner: def __init__( self, model_name: str = "LiquidAI/LFM2-2.6B", use_4bit: bool = False, # H100 has enough memory max_seq_length: int = 2048 # Increased for complete dialogue history ): """ Initialize the fine-tuner for LFM models with complete dialogue history support Args: model_name: Name of the base model use_4bit: Whether to use 4-bit quantization max_seq_length: Maximum sequence length for complete dialogues """ self.model_name = model_name self.use_4bit = use_4bit self.max_seq_length = max_seq_length self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("="*80) print("๐Ÿš€ LFM Fine-tuning with Complete Dialogue History (KokoroChat Method)") print("="*80) print(f"Model: {model_name}") print(f"Device: {self.device}") print(f"Max sequence length: {max_seq_length}") # GPU information if torch.cuda.is_available(): num_gpus = torch.cuda.device_count() print(f"Number of GPUs: {num_gpus}") for i in range(num_gpus): gpu_name = torch.cuda.get_device_name(i) gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1e9 print(f" GPU {i}: {gpu_name} ({gpu_memory:.2f} GB)") # Initialize WandB self.init_wandb() def init_wandb(self): """Initialize WandB for experiment tracking""" try: run_name = f"lfm-kokoro-complete-{datetime.now().strftime('%Y%m%d-%H%M%S')}" wandb.init( project="lfm-kokoro-complete-history", name=run_name, config={ "model_name": self.model_name, "use_4bit_quantization": self.use_4bit, "max_seq_length": self.max_seq_length, "device": str(self.device), "num_gpus": torch.cuda.device_count() if torch.cuda.is_available() else 0, "methodology": "Complete dialogue history (KokoroChat)", "framework": "transformers + peft", "task": "japanese_counseling" }, tags=["counseling", "japanese", "lfm", "complete-history", "kokoro"] ) print(f"โœ… WandB initialized: {wandb.run.name}") print(f"๐Ÿ“Š View run at: {wandb.run.get_url()}") self.wandb_enabled = True except Exception as e: print(f"โš ๏ธ WandB initialization failed: {e}") self.wandb_enabled = False os.environ["WANDB_DISABLED"] = "true" def setup_model_and_tokenizer(self): """Setup model with quantization and LoRA""" print("\n๐Ÿ“š Setting up model and tokenizer...") # Load tokenizer print("Loading tokenizer...") try: self.tokenizer = AutoTokenizer.from_pretrained( self.model_name, trust_remote_code=True ) except: print("Using fallback tokenizer...") self.tokenizer = AutoTokenizer.from_pretrained("gpt2") # Set special tokens if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token if self.tokenizer.eos_token is None: self.tokenizer.eos_token = "" self.tokenizer.pad_token = "" self.tokenizer.padding_side = "left" # Important for batch generation # Quantization config if self.use_4bit: print("Setting up 4-bit quantization...") bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, # BF16 for H100 bnb_4bit_use_double_quant=True ) else: bnb_config = None # Load model print(f"Loading model: {self.model_name}...") model_kwargs = { "trust_remote_code": True, "torch_dtype": torch.bfloat16, # BF16 for H100 "device_map": "auto", } if bnb_config: model_kwargs["quantization_config"] = bnb_config try: self.model = AutoModelForCausalLM.from_pretrained( self.model_name, **model_kwargs ) except Exception as e: print(f"Error loading model: {e}") print("Attempting without device_map...") model_kwargs.pop("device_map", None) self.model = AutoModelForCausalLM.from_pretrained( self.model_name, **model_kwargs ) self.model = self.model.to(self.device) # Enable gradient checkpointing if hasattr(self.model, 'gradient_checkpointing_enable'): self.model.gradient_checkpointing_enable() # Prepare for k-bit training if using quantization if self.use_4bit: print("Preparing model for 4-bit training...") self.model = prepare_model_for_kbit_training(self.model) # LoRA configuration optimized for dialogue with complete history print("Applying LoRA configuration...") # Find target modules target_modules = self.find_target_modules() # Higher rank for complex dialogue understanding lora_config = LoraConfig( r=64, # Increased for better dialogue understanding lora_alpha=128, target_modules=target_modules, lora_dropout=0.05, bias="none", task_type=TaskType.CAUSAL_LM, inference_mode=False ) # Apply LoRA self.model = get_peft_model(self.model, lora_config) # Print trainable parameters trainable_params = 0 all_params = 0 for _, param in self.model.named_parameters(): all_params += param.numel() if param.requires_grad: trainable_params += param.numel() trainable_percentage = 100 * trainable_params / all_params if all_params > 0 else 0 print(f"Trainable parameters: {trainable_params:,} / {all_params:,} ({trainable_percentage:.2f}%)") # Log to WandB if self.wandb_enabled: wandb.config.update({ "lora_r": lora_config.r, "lora_alpha": lora_config.lora_alpha, "lora_dropout": lora_config.lora_dropout, "lora_target_modules": target_modules, "total_parameters": all_params, "trainable_parameters": trainable_params, "trainable_percentage": trainable_percentage }) self.model.print_trainable_parameters() def find_target_modules(self): """Find linear modules to apply LoRA to""" target_modules = [] for name, module in self.model.named_modules(): if isinstance(module, torch.nn.Linear): names = name.split('.') if len(names) > 0: target_modules.append(names[-1]) # Remove duplicates target_modules = list(set(target_modules)) # Common patterns for transformer models common_targets = ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "fc1", "fc2", "query", "key", "value", "dense"] # Filter to common targets final_targets = [t for t in target_modules if any(ct in t.lower() for ct in common_targets)] if not final_targets: # Fallback to specific modules for LFM final_targets = ["q_proj", "v_proj", "k_proj", "o_proj"] print(f"LoRA target modules: {final_targets}") return final_targets def load_and_process_datasets(self, data_path: str): """ Load and process datasets with complete dialogue history Handles the new data format with full conversation context """ print(f"\n๐Ÿ“š Loading datasets from {data_path}...") # Check for dataset statistics stats_file = os.path.join(data_path, 'dataset_stats.json') if os.path.exists(stats_file): with open(stats_file, 'r') as f: stats = json.load(f) print("Dataset statistics:") print(f" Average dialogue history: {stats['dialogue_history_stats']['mean_length']:.1f} turns") print(f" Max dialogue history: {stats['dialogue_history_stats']['max_length']} turns") print(f" Median dialogue history: {stats['dialogue_history_stats']['median_length']:.1f} turns") # Load datasets train_data = [] val_data = [] # Load training data train_file = os.path.join(data_path, 'train.jsonl') with open(train_file, 'r', encoding='utf-8') as f: for line in tqdm(f, desc="Loading training data"): item = json.loads(line) train_data.append({ 'text': item['text'], 'history_length': item.get('history_length', 0), 'score': item.get('score', 100), 'topic': item.get('topic', 'general') }) # Load validation data val_file = os.path.join(data_path, 'val.jsonl') with open(val_file, 'r', encoding='utf-8') as f: for line in tqdm(f, desc="Loading validation data"): item = json.loads(line) val_data.append({ 'text': item['text'], 'history_length': item.get('history_length', 0), 'score': item.get('score', 100), 'topic': item.get('topic', 'general') }) print(f"Loaded {len(train_data)} training examples") print(f"Loaded {len(val_data)} validation examples") # Analyze dialogue history lengths train_history_lengths = [d['history_length'] for d in train_data] val_history_lengths = [d['history_length'] for d in val_data] print(f"\nDialogue history length distribution:") print(f" Training - Mean: {np.mean(train_history_lengths):.1f}, Max: {max(train_history_lengths)}") print(f" Validation - Mean: {np.mean(val_history_lengths):.1f}, Max: {max(val_history_lengths)}") # Log to WandB if self.wandb_enabled: wandb.config.update({ "train_examples": len(train_data), "val_examples": len(val_data), "avg_train_history_length": float(np.mean(train_history_lengths)), "max_train_history_length": int(max(train_history_lengths)), "avg_val_history_length": float(np.mean(val_history_lengths)), "max_val_history_length": int(max(val_history_lengths)) }) # Log history length distribution wandb.log({ "train_history_distribution": wandb.Histogram(train_history_lengths), "val_history_distribution": wandb.Histogram(val_history_lengths) }) # Tokenize datasets print("\nTokenizing datasets with complete dialogue history...") print(f"Using max sequence length: {self.max_seq_length}") # Extract texts for tokenization train_texts = [d['text'] for d in train_data] val_texts = [d['text'] for d in val_data] # Tokenize with longer context for complete history train_encodings = self.tokenize_texts(train_texts, desc="Tokenizing training data") val_encodings = self.tokenize_texts(val_texts, desc="Tokenizing validation data") # Create datasets self.train_dataset = Dataset.from_dict(train_encodings) self.val_dataset = Dataset.from_dict(val_encodings) # Set format for PyTorch self.train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels']) self.val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels']) # Clean up memory del train_texts, val_texts, train_encodings, val_encodings, train_data, val_data gc.collect() print("โœ… Datasets loaded and tokenized") def tokenize_texts(self, texts: List[str], batch_size: int = 50, desc: str = "Tokenizing"): """ Tokenize texts in batches with support for longer sequences """ all_input_ids = [] all_attention_masks = [] # Process in smaller batches for long sequences for i in tqdm(range(0, len(texts), batch_size), desc=desc): batch_texts = texts[i:i + batch_size] # Tokenize batch with longer max length encodings = self.tokenizer( batch_texts, truncation=True, padding='max_length', max_length=self.max_seq_length, return_tensors='pt' ) # Convert to lists all_input_ids.extend(encodings['input_ids'].tolist()) all_attention_masks.extend(encodings['attention_mask'].tolist()) # Create labels (same as input_ids for causal LM) labels = all_input_ids.copy() return { 'input_ids': all_input_ids, 'attention_mask': all_attention_masks, 'labels': labels } def setup_training_args(self, output_dir: str = "./lfm_kokoro_complete"): """Setup training arguments optimized for complete dialogue history""" print("\nโš™๏ธ Setting up training arguments...") # Calculate batch sizes based on sequence length and GPU memory if torch.cuda.is_available(): gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9 num_gpus = torch.cuda.device_count() # Adjust batch size based on sequence length and GPU memory if self.max_seq_length >= 2048: if gpu_memory >= 80: # H100 80GB batch_size = 4 gradient_accumulation = 4 elif gpu_memory >= 40: batch_size = 2 gradient_accumulation = 8 else: batch_size = 1 gradient_accumulation = 16 else: batch_size = 8 gradient_accumulation = 2 # Adjust for multiple GPUs if num_gpus > 1: batch_size = batch_size * num_gpus gradient_accumulation = max(1, gradient_accumulation // num_gpus) else: batch_size = 1 gradient_accumulation = 32 print(f"Batch configuration:") print(f" Per device batch size: {batch_size}") print(f" Gradient accumulation steps: {gradient_accumulation}") print(f" Effective batch size: {batch_size * gradient_accumulation}") # Update WandB config if self.wandb_enabled: wandb.config.update({ "batch_size": batch_size, "gradient_accumulation_steps": gradient_accumulation, "effective_batch_size": batch_size * gradient_accumulation, "num_epochs": 3, "learning_rate": 2e-4, "warmup_ratio": 0.1, "weight_decay": 0.01, "max_grad_norm": 1.0, "lr_scheduler": "cosine", "optimizer": "adamw_torch" }) self.training_args = TrainingArguments( output_dir=output_dir, num_train_epochs=3, per_device_train_batch_size=batch_size, per_device_eval_batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation, gradient_checkpointing=True, warmup_ratio=0.1, learning_rate=2e-4, bf16=True, # Use BF16 for H100 tf32=True, # Enable TF32 for H100 logging_steps=10, logging_first_step=True, eval_strategy="steps", eval_steps=100, save_strategy="steps", save_steps=200, save_total_limit=3, load_best_model_at_end=True, metric_for_best_model="eval_loss", greater_is_better=False, report_to="wandb" if self.wandb_enabled else "none", run_name=wandb.run.name if self.wandb_enabled and wandb.run else "local_run", optim="adamw_torch", lr_scheduler_type="cosine", weight_decay=0.01, max_grad_norm=1.0, remove_unused_columns=False, label_names=["labels"], dataloader_num_workers=4, dataloader_pin_memory=True, ddp_find_unused_parameters=False if torch.cuda.device_count() > 1 else None, ) def train(self): """Execute training with complete dialogue history""" print("\n๐ŸŽฏ Starting training with complete dialogue history...") # Data collator data_collator = DataCollatorForLanguageModeling( tokenizer=self.tokenizer, mlm=False, pad_to_multiple_of=8 ) # Custom callback for metrics class MetricsCallback(TrainerCallback): def __init__(self, wandb_enabled): self.wandb_enabled = wandb_enabled def on_log(self, args, state, control, logs=None, **kwargs): if logs and self.wandb_enabled: # Add perplexity if "loss" in logs: logs["perplexity"] = np.exp(logs["loss"]) if "eval_loss" in logs: logs["eval_perplexity"] = np.exp(logs["eval_loss"]) # Log to WandB wandb.log(logs, step=state.global_step) return control # Initialize trainer trainer = Trainer( model=self.model, args=self.training_args, train_dataset=self.train_dataset, eval_dataset=self.val_dataset, data_collator=data_collator, tokenizer=self.tokenizer, callbacks=[MetricsCallback(self.wandb_enabled)] if self.wandb_enabled else [], ) # Calculate total steps total_steps = len(self.train_dataset) // ( self.training_args.per_device_train_batch_size * self.training_args.gradient_accumulation_steps ) * self.training_args.num_train_epochs print("="*60) print("Training Information:") print(f" Total training samples: {len(self.train_dataset)}") print(f" Total validation samples: {len(self.val_dataset)}") print(f" Total training steps: {total_steps}") print(f" Max sequence length: {self.max_seq_length}") print("="*60) # Log training start if self.wandb_enabled: wandb.log({ "training_status": "started", "total_steps": total_steps, "max_seq_length": self.max_seq_length }) try: # Train print("\n๐Ÿš€ Training started...") train_result = trainer.train() # Save model print("\n๐Ÿ’พ Saving fine-tuned model...") final_model_path = os.path.join(self.training_args.output_dir, "final_model") trainer.save_model(final_model_path) self.tokenizer.save_pretrained(final_model_path) # Save training metrics with open(os.path.join(self.training_args.output_dir, "training_metrics.json"), 'w') as f: json.dump(train_result.metrics, f, indent=2) # Final evaluation print("\n๐Ÿ“Š Running final evaluation...") eval_results = trainer.evaluate() # Save evaluation metrics with open(os.path.join(self.training_args.output_dir, "eval_metrics.json"), 'w') as f: json.dump(eval_results, f, indent=2) # Log final metrics if self.wandb_enabled: wandb.run.summary.update({ "final_train_loss": train_result.metrics.get("train_loss", 0), "final_eval_loss": eval_results.get("eval_loss", 0), "final_eval_perplexity": np.exp(eval_results.get("eval_loss", 0)), "total_training_time": train_result.metrics.get("train_runtime", 0), "training_samples_per_second": train_result.metrics.get("train_samples_per_second", 0), "training_status": "completed" }) # Save model artifact artifact = wandb.Artifact( name=f"kokoro-model-complete-{wandb.run.id}", type="model", description="LFM model fine-tuned with complete dialogue history", metadata={ "base_model": self.model_name, "final_loss": float(eval_results.get("eval_loss", 0)), "final_perplexity": float(np.exp(eval_results.get("eval_loss", 0))), "max_seq_length": self.max_seq_length, "methodology": "Complete dialogue history (KokoroChat)" } ) artifact.add_dir(final_model_path) wandb.log_artifact(artifact) print("\n" + "="*60) print("โœ… Training completed successfully!") print(f"๐Ÿ“ Model saved to: {final_model_path}") print(f"๐Ÿ“‰ Final eval loss: {eval_results.get('eval_loss', 0):.4f}") print(f"๐Ÿ“Š Final perplexity: {np.exp(eval_results.get('eval_loss', 0)):.2f}") if self.wandb_enabled and wandb.run: print(f"๐Ÿ”— View results at: {wandb.run.get_url()}") print("="*60) return trainer except Exception as e: print(f"โŒ Error during training: {e}") if self.wandb_enabled: wandb.run.summary["training_status"] = "failed" wandb.run.summary["error"] = str(e) # Save emergency checkpoint try: emergency_path = os.path.join(self.training_args.output_dir, "emergency_checkpoint") self.model.save_pretrained(emergency_path) self.tokenizer.save_pretrained(emergency_path) print(f"๐Ÿ’พ Emergency checkpoint saved to: {emergency_path}") except: print("โŒ Could not save emergency checkpoint") raise e finally: if self.wandb_enabled: wandb.finish() def test_model_with_complete_history(model_path: str): """Test the fine-tuned model with complete dialogue history examples""" print("\n" + "="*60) print("๐Ÿงช Testing model with complete dialogue history") print("="*60) # Load tokenizer and model tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True) # Check if it's a PEFT model adapter_config_path = os.path.join(model_path, "adapter_config.json") if os.path.exists(adapter_config_path): print("Loading as PEFT model...") config = PeftConfig.from_pretrained(model_path) base_model = AutoModelForCausalLM.from_pretrained( config.base_model_name_or_path, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True ) model = PeftModel.from_pretrained(base_model, model_path) else: print("Loading as regular model...") model = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype=torch.bfloat16, device_map="auto", local_files_only=True, trust_remote_code=True ) model.eval() # Test with dialogue history examples test_cases = [ { "history": "ใ‚ฏใƒฉใ‚คใ‚ขใƒณใƒˆ: ใ“ใ‚“ใซใกใฏใ€‚ๆœ€่ฟ‘ใ‚นใƒˆใƒฌใ‚นใ‚’ๆ„Ÿใ˜ใฆใ„ใพใ™ใ€‚\nใ‚ซใ‚ฆใƒณใ‚ปใƒฉใƒผ: ใ“ใ‚“ใซใกใฏใ€‚ใ‚นใƒˆใƒฌใ‚นใ‚’ๆ„Ÿใ˜ใฆใ„ใ‚‰ใฃใ—ใ‚ƒใ‚‹ใฎใงใ™ใญใ€‚ใฉใฎใ‚ˆใ†ใช็Šถๆณใงใ‚นใƒˆใƒฌใ‚นใ‚’ๆ„Ÿใ˜ใ‚‹ใ“ใจใŒๅคšใ„ใงใ™ใ‹๏ผŸ\n", "current": "ใ‚ฏใƒฉใ‚คใ‚ขใƒณใƒˆ: ไป•ไบ‹ใŒๅฟ™ใ—ใใฆใ€ไผ‘ใ‚€ๆ™‚้–“ใŒใ‚ใ‚Šใพใ›ใ‚“ใ€‚" }, { "history": "", "current": "ใ‚ฏใƒฉใ‚คใ‚ขใƒณใƒˆ: ไบบ้–“้–ขไฟ‚ใงๆ‚ฉใ‚“ใงใ„ใพใ™ใ€‚" } ] print("Testing with complete dialogue history:\n") for i, test_case in enumerate(test_cases, 1): print(f"Test Case {i}:") print("-" * 40) # Format input with complete history if test_case["history"]: prompt = f"""### Instruction: ใ‚ใชใŸใฏๅฐ‚้–€็š„ใช่จ“็ทดใ‚’ๅ—ใ‘ใŸๅฟƒ็†ใ‚ซใ‚ฆใƒณใ‚ปใƒฉใƒผใงใ™ใ€‚ ไปฅไธ‹ใฎๅฎŒๅ…จใชๅฏพ่ฉฑๅฑฅๆญดใ‚’่ธใพใˆใฆใ€ใ‚ซใ‚ฆใƒณใ‚ปใƒฉใƒผใจใ—ใฆ้ฉๅˆ‡ใชๅฟœ็ญ”ใ‚’็”Ÿๆˆใ—ใฆใใ ใ•ใ„ใ€‚ ### Dialogue History: {test_case["history"]}{test_case["current"]} ### Response: """ else: prompt = f"""### Instruction: ใ‚ใชใŸใฏๅฐ‚้–€็š„ใช่จ“็ทดใ‚’ๅ—ใ‘ใŸๅฟƒ็†ใ‚ซใ‚ฆใƒณใ‚ปใƒฉใƒผใงใ™ใ€‚ ### Dialogue History: {test_case["current"]} ### Response: """ # Generate response inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048) inputs = {k: v.cuda() if torch.cuda.is_available() else v for k, v in inputs.items()} with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=150, temperature=0, do_sample=True, top_p=0.9, pad_token_id=tokenizer.pad_token_id ) response = tokenizer.decode(outputs[0], skip_special_tokens=True) response = response.split("### Response:")[-1].strip() if "### Response:" in response else response # print(f"History Length: {len(test_case['history'].split('\\n')) if test_case['history'] else 0} turns") print("History Length: {} turns".format(len(test_case['history'].split('\\n')) if test_case['history'] else 0)) print(f"Current Input: {test_case['current']}") print(f"Generated Response: {response[:300]}...") print() print("="*60) # Main execution if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description='Fine-tune LFM model with complete dialogue history') parser.add_argument('--model_name', type=str, default='LiquidAI/LFM2-2.6B', help='Base model name') parser.add_argument('--data_path', type=str, default='./kokoro_processed_data', help='Path to processed data with complete dialogue history') parser.add_argument('--output_dir', type=str, default='./lfm_kokoro_complete', help='Output directory for fine-tuned model') parser.add_argument('--max_seq_length', type=int, default=2048, help='Maximum sequence length for complete dialogues') parser.add_argument('--use_4bit', action='store_true', help='Use 4-bit quantization') parser.add_argument('--test_only', action='store_true', help='Only test existing model') args = parser.parse_args() if args.test_only: # Test existing model test_model_with_complete_history( os.path.join(args.output_dir, "final_model") ) else: # Check CUDA availability if not torch.cuda.is_available(): print("โš ๏ธ Warning: CUDA is not available. Training will be slow.") response = input("Continue? (y/n): ") if response.lower() != 'y': exit() try: # Clear GPU cache if torch.cuda.is_available(): torch.cuda.empty_cache() # Initialize fine-tuner print(f"๐Ÿš€ Initializing fine-tuner for complete dialogue history") finetuner = LFMKokoroChatFineTuner( model_name=args.model_name, use_4bit=args.use_4bit, max_seq_length=args.max_seq_length ) # Setup model finetuner.setup_model_and_tokenizer() # Load datasets finetuner.load_and_process_datasets(args.data_path) # Setup training arguments finetuner.setup_training_args(args.output_dir) # Train trainer = finetuner.train() # Test the model print("\n๐Ÿงช Testing the fine-tuned model...") test_model_with_complete_history( os.path.join(args.output_dir, "final_model") ) print("\nโœ… Fine-tuning with complete dialogue history completed!") print(f"๐Ÿ“ Model saved to: {args.output_dir}/final_model") print("\n๐Ÿ“‹ Next steps:") print(f"1. Test more: python {__file__} --test_only --output_dir {args.output_dir}") print("2. Run benchmarking with complete history support") print("3. Deploy for production use") except KeyboardInterrupt: print("\n\nโš ๏ธ Training interrupted by user.") if wandb.run: wandb.finish() except Exception as e: print(f"\nโŒ Error: {e}") import traceback traceback.print_exc() if wandb.run: wandb.finish()