| """ |
| Fine-tuning Script for LFM2-2.6B with Complete Dialogue History |
| Following KokoroChat methodology - uses entire conversation context |
| Filename: finetune_lfm_complete_history.py |
| """ |
|
|
| import torch |
| from transformers import ( |
| AutoModelForCausalLM, |
| AutoTokenizer, |
| TrainingArguments, |
| Trainer, |
| DataCollatorForLanguageModeling, |
| BitsAndBytesConfig, |
| TrainerCallback |
| ) |
| from peft import ( |
| LoraConfig, |
| get_peft_model, |
| prepare_model_for_kbit_training, |
| TaskType, |
| PeftModel, |
| PeftConfig |
| ) |
| from datasets import load_dataset, Dataset |
| import os |
| from typing import Dict, List, Optional |
| import numpy as np |
| from tqdm import tqdm |
| import json |
| import gc |
| import warnings |
| import wandb |
| from datetime import datetime |
|
|
| warnings.filterwarnings('ignore') |
|
|
| |
# Enable TF32 tensor cores for matmul and cuDNN kernels (Ampere+ GPUs only):
# large speedup for bf16/fp32 training at negligible precision cost.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
|
|
class LFMKokoroChatFineTuner:
    """Fine-tune LFM-family causal LMs on counseling dialogues that carry the
    complete conversation history (KokoroChat methodology).

    Typical usage::

        tuner = LFMKokoroChatFineTuner(model_name, use_4bit, max_seq_length)
        tuner.setup_model_and_tokenizer()
        tuner.load_and_process_datasets(data_path)
        tuner.setup_training_args(output_dir)
        tuner.train()
    """

    def __init__(
        self,
        model_name: str = "LiquidAI/LFM2-2.6B",
        use_4bit: bool = False,
        max_seq_length: int = 2048
    ):
        """
        Initialize the fine-tuner for LFM models with complete dialogue history support

        Args:
            model_name: Name of the base model
            use_4bit: Whether to use 4-bit quantization
            max_seq_length: Maximum sequence length for complete dialogues
        """
        self.model_name = model_name
        self.use_4bit = use_4bit
        self.max_seq_length = max_seq_length
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        print("=" * 80)
        print("🚀 LFM Fine-tuning with Complete Dialogue History (KokoroChat Method)")
        print("=" * 80)
        print(f"Model: {model_name}")
        print(f"Device: {self.device}")
        print(f"Max sequence length: {max_seq_length}")

        if torch.cuda.is_available():
            num_gpus = torch.cuda.device_count()
            print(f"Number of GPUs: {num_gpus}")
            for i in range(num_gpus):
                gpu_name = torch.cuda.get_device_name(i)
                gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1e9
                print(f"  GPU {i}: {gpu_name} ({gpu_memory:.2f} GB)")

        self.init_wandb()

    def init_wandb(self):
        """Initialize WandB for experiment tracking.

        Sets ``self.wandb_enabled``; on failure, disables WandB via the
        WANDB_DISABLED env var so the Trainer does not try to report to it.
        """
        try:
            run_name = f"lfm-kokoro-complete-{datetime.now().strftime('%Y%m%d-%H%M%S')}"

            wandb.init(
                project="lfm-kokoro-complete-history",
                name=run_name,
                config={
                    "model_name": self.model_name,
                    "use_4bit_quantization": self.use_4bit,
                    "max_seq_length": self.max_seq_length,
                    "device": str(self.device),
                    "num_gpus": torch.cuda.device_count() if torch.cuda.is_available() else 0,
                    "methodology": "Complete dialogue history (KokoroChat)",
                    "framework": "transformers + peft",
                    "task": "japanese_counseling"
                },
                tags=["counseling", "japanese", "lfm", "complete-history", "kokoro"]
            )

            print(f"✅ WandB initialized: {wandb.run.name}")
            print(f"📊 View run at: {wandb.run.get_url()}")
            self.wandb_enabled = True

        except Exception as e:
            print(f"⚠️ WandB initialization failed: {e}")
            self.wandb_enabled = False
            os.environ["WANDB_DISABLED"] = "true"

    def setup_model_and_tokenizer(self):
        """Load tokenizer and base model, then wrap the model with LoRA.

        Sets ``self.tokenizer`` and ``self.model`` (a PEFT-wrapped model).
        Optionally applies 4-bit NF4 quantization when ``self.use_4bit``.
        """
        print("\n📚 Setting up model and tokenizer...")

        print("Loading tokenizer...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )
        except Exception:  # FIX: was a bare except (swallowed SystemExit/KeyboardInterrupt)
            print("Using fallback tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

        # Ensure pad/eos tokens exist; many causal-LM tokenizers ship without a pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        if self.tokenizer.eos_token is None:
            self.tokenizer.eos_token = "</s>"
            self.tokenizer.pad_token = "</s>"

        # FIX: was "left". Left padding is for batched *generation*; during
        # training it shifts real tokens to later position indices, corrupting
        # positional encodings. Right padding keeps positions 0..len-1 correct,
        # and pad tokens are masked out of the loss anyway.
        self.tokenizer.padding_side = "right"

        if self.use_4bit:
            print("Setting up 4-bit quantization...")
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True
            )
        else:
            bnb_config = None

        print(f"Loading model: {self.model_name}...")
        model_kwargs = {
            "trust_remote_code": True,
            "torch_dtype": torch.bfloat16,
            "device_map": "auto",
        }

        if bnb_config:
            model_kwargs["quantization_config"] = bnb_config

        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                **model_kwargs
            )
        except Exception as e:
            # device_map="auto" requires accelerate; retry with a plain load.
            print(f"Error loading model: {e}")
            print("Attempting without device_map...")
            model_kwargs.pop("device_map", None)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                **model_kwargs
            )
            self.model = self.model.to(self.device)

        if hasattr(self.model, 'gradient_checkpointing_enable'):
            self.model.gradient_checkpointing_enable()
            # KV-cache is incompatible with gradient checkpointing during training.
            if hasattr(self.model, "config"):
                self.model.config.use_cache = False

        if self.use_4bit:
            print("Preparing model for 4-bit training...")
            self.model = prepare_model_for_kbit_training(self.model)

        print("Applying LoRA configuration...")

        target_modules = self.find_target_modules()

        lora_config = LoraConfig(
            r=64,
            lora_alpha=128,
            target_modules=target_modules,
            lora_dropout=0.05,
            bias="none",
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False
        )

        self.model = get_peft_model(self.model, lora_config)

        # Count trainable vs total parameters for reporting.
        trainable_params = 0
        all_params = 0
        for _, param in self.model.named_parameters():
            all_params += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()

        trainable_percentage = 100 * trainable_params / all_params if all_params > 0 else 0

        print(f"Trainable parameters: {trainable_params:,} / {all_params:,} ({trainable_percentage:.2f}%)")

        if self.wandb_enabled:
            wandb.config.update({
                "lora_r": lora_config.r,
                "lora_alpha": lora_config.lora_alpha,
                "lora_dropout": lora_config.lora_dropout,
                "lora_target_modules": target_modules,
                "total_parameters": all_params,
                "trainable_parameters": trainable_params,
                "trainable_percentage": trainable_percentage
            })

        self.model.print_trainable_parameters()

    def find_target_modules(self):
        """Find linear-module leaf names to apply LoRA to.

        Returns:
            De-duplicated list of module leaf names (e.g. ``q_proj``) that look
            like attention/MLP projections; falls back to the common attention
            projections when nothing matches.
        """
        target_modules = set()
        for name, module in self.model.named_modules():
            if isinstance(module, torch.nn.Linear):
                # Keep only the leaf name ("model.layers.0.q_proj" -> "q_proj").
                target_modules.add(name.split('.')[-1])

        common_targets = ["q_proj", "v_proj", "k_proj", "o_proj",
                          "gate_proj", "up_proj", "down_proj",
                          "fc1", "fc2", "query", "key", "value", "dense"]

        # Restrict to names that look like attention/MLP projections so we do
        # not attach adapters to e.g. the LM head.
        final_targets = [t for t in target_modules if any(ct in t.lower() for ct in common_targets)]

        if not final_targets:
            final_targets = ["q_proj", "v_proj", "k_proj", "o_proj"]

        print(f"LoRA target modules: {final_targets}")
        return final_targets

    def load_and_process_datasets(self, data_path: str):
        """
        Load and process datasets with complete dialogue history.

        Expects ``train.jsonl`` / ``val.jsonl`` under ``data_path`` with at
        least a ``text`` field per line (``history_length``, ``score`` and
        ``topic`` are optional). Sets ``self.train_dataset`` and
        ``self.val_dataset`` as tokenized, torch-formatted Datasets.
        """
        print(f"\n📚 Loading datasets from {data_path}...")

        # Optional pre-computed statistics written by the preprocessing step.
        stats_file = os.path.join(data_path, 'dataset_stats.json')
        if os.path.exists(stats_file):
            with open(stats_file, 'r', encoding='utf-8') as f:
                stats = json.load(f)
            print("Dataset statistics:")
            print(f"  Average dialogue history: {stats['dialogue_history_stats']['mean_length']:.1f} turns")
            print(f"  Max dialogue history: {stats['dialogue_history_stats']['max_length']} turns")
            print(f"  Median dialogue history: {stats['dialogue_history_stats']['median_length']:.1f} turns")

        def _read_jsonl(path, desc):
            # One example per line; missing metadata falls back to defaults.
            records = []
            with open(path, 'r', encoding='utf-8') as f:
                for line in tqdm(f, desc=desc):
                    item = json.loads(line)
                    records.append({
                        'text': item['text'],
                        'history_length': item.get('history_length', 0),
                        'score': item.get('score', 100),
                        'topic': item.get('topic', 'general')
                    })
            return records

        train_data = _read_jsonl(os.path.join(data_path, 'train.jsonl'), "Loading training data")
        val_data = _read_jsonl(os.path.join(data_path, 'val.jsonl'), "Loading validation data")

        print(f"Loaded {len(train_data)} training examples")
        print(f"Loaded {len(val_data)} validation examples")

        train_history_lengths = [d['history_length'] for d in train_data]
        val_history_lengths = [d['history_length'] for d in val_data]

        print(f"\nDialogue history length distribution:")
        # FIX: max() on an empty list raised ValueError for empty splits.
        print(f"  Training - Mean: {np.mean(train_history_lengths):.1f}, Max: {max(train_history_lengths, default=0)}")
        print(f"  Validation - Mean: {np.mean(val_history_lengths):.1f}, Max: {max(val_history_lengths, default=0)}")

        if self.wandb_enabled:
            wandb.config.update({
                "train_examples": len(train_data),
                "val_examples": len(val_data),
                "avg_train_history_length": float(np.mean(train_history_lengths)),
                "max_train_history_length": int(max(train_history_lengths, default=0)),
                "avg_val_history_length": float(np.mean(val_history_lengths)),
                "max_val_history_length": int(max(val_history_lengths, default=0))
            })

            wandb.log({
                "train_history_distribution": wandb.Histogram(train_history_lengths),
                "val_history_distribution": wandb.Histogram(val_history_lengths)
            })

        print("\nTokenizing datasets with complete dialogue history...")
        print(f"Using max sequence length: {self.max_seq_length}")

        train_texts = [d['text'] for d in train_data]
        val_texts = [d['text'] for d in val_data]

        train_encodings = self.tokenize_texts(train_texts, desc="Tokenizing training data")
        val_encodings = self.tokenize_texts(val_texts, desc="Tokenizing validation data")

        self.train_dataset = Dataset.from_dict(train_encodings)
        self.val_dataset = Dataset.from_dict(val_encodings)

        self.train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
        self.val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

        # Drop the raw Python lists early; tokenized copies can be large.
        del train_texts, val_texts, train_encodings, val_encodings, train_data, val_data
        gc.collect()

        print("✅ Datasets loaded and tokenized")

    def tokenize_texts(self, texts: List[str], batch_size: int = 50, desc: str = "Tokenizing"):
        """
        Tokenize texts in batches, padding/truncating to ``self.max_seq_length``.

        Args:
            texts: Raw training strings (full dialogue + target response).
            batch_size: Number of texts tokenized per call.
            desc: tqdm progress-bar label.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels`` lists.
            Labels equal the input ids with padding positions set to -100 so
            pad tokens never contribute to the LM loss.
        """
        all_input_ids = []
        all_attention_masks = []

        for i in tqdm(range(0, len(texts), batch_size), desc=desc):
            batch_texts = texts[i:i + batch_size]

            encodings = self.tokenizer(
                batch_texts,
                truncation=True,
                padding='max_length',
                max_length=self.max_seq_length,
                return_tensors='pt'
            )

            all_input_ids.extend(encodings['input_ids'].tolist())
            all_attention_masks.extend(encodings['attention_mask'].tolist())

        # FIX: previously ``labels = all_input_ids.copy()`` — a shallow copy
        # (labels shared the inner lists) that also left pad tokens unmasked.
        # Mask pads with -100, the ignore_index used by the HF loss.
        labels = [
            [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
            for ids, attn in zip(all_input_ids, all_attention_masks)
        ]

        return {
            'input_ids': all_input_ids,
            'attention_mask': all_attention_masks,
            'labels': labels
        }

    def setup_training_args(self, output_dir: str = "./lfm_kokoro_complete"):
        """Setup training arguments optimized for complete dialogue history.

        Batch size / gradient accumulation are scaled heuristically from GPU
        memory and sequence length; sets ``self.training_args``.
        """
        print("\n⚙️ Setting up training arguments...")

        if torch.cuda.is_available():
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
            num_gpus = torch.cuda.device_count()

            # Long sequences dominate activation memory; shrink batch accordingly.
            if self.max_seq_length >= 2048:
                if gpu_memory >= 80:
                    batch_size = 4
                    gradient_accumulation = 4
                elif gpu_memory >= 40:
                    batch_size = 2
                    gradient_accumulation = 8
                else:
                    batch_size = 1
                    gradient_accumulation = 16
            else:
                batch_size = 8
                gradient_accumulation = 2

            # Keep the effective batch roughly constant across GPU counts.
            if num_gpus > 1:
                batch_size = batch_size * num_gpus
                gradient_accumulation = max(1, gradient_accumulation // num_gpus)
        else:
            batch_size = 1
            gradient_accumulation = 32

        print(f"Batch configuration:")
        print(f"  Per device batch size: {batch_size}")
        print(f"  Gradient accumulation steps: {gradient_accumulation}")
        print(f"  Effective batch size: {batch_size * gradient_accumulation}")

        if self.wandb_enabled:
            wandb.config.update({
                "batch_size": batch_size,
                "gradient_accumulation_steps": gradient_accumulation,
                "effective_batch_size": batch_size * gradient_accumulation,
                "num_epochs": 3,
                "learning_rate": 2e-4,
                "warmup_ratio": 0.1,
                "weight_decay": 0.01,
                "max_grad_norm": 1.0,
                "lr_scheduler": "cosine",
                "optimizer": "adamw_torch"
            })

        self.training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=3,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation,
            gradient_checkpointing=True,
            warmup_ratio=0.1,
            learning_rate=2e-4,
            bf16=True,
            tf32=True,
            logging_steps=10,
            logging_first_step=True,
            eval_strategy="steps",
            eval_steps=100,
            save_strategy="steps",
            save_steps=200,
            save_total_limit=3,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            report_to="wandb" if self.wandb_enabled else "none",
            run_name=wandb.run.name if self.wandb_enabled and wandb.run else "local_run",
            optim="adamw_torch",
            lr_scheduler_type="cosine",
            weight_decay=0.01,
            max_grad_norm=1.0,
            remove_unused_columns=False,
            label_names=["labels"],
            dataloader_num_workers=4,
            dataloader_pin_memory=True,
            ddp_find_unused_parameters=False if torch.cuda.device_count() > 1 else None,
        )

    def train(self):
        """Execute training with complete dialogue history.

        Returns:
            The HF ``Trainer`` after a successful run. Saves the final model,
            metrics JSONs, and (when enabled) logs a WandB model artifact.
            On failure an emergency checkpoint is attempted and the exception
            re-raised.
        """
        print("\n🎯 Starting training with complete dialogue history...")

        # mlm=False -> causal-LM collation; the collator re-derives labels
        # from input_ids and masks pad tokens with -100.
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
            pad_to_multiple_of=8
        )

        class MetricsCallback(TrainerCallback):
            """Adds perplexity (= exp(loss)) to every logged metrics dict."""

            def __init__(self, wandb_enabled):
                self.wandb_enabled = wandb_enabled

            def on_log(self, args, state, control, logs=None, **kwargs):
                if logs and self.wandb_enabled:
                    if "loss" in logs:
                        logs["perplexity"] = np.exp(logs["loss"])
                    if "eval_loss" in logs:
                        logs["eval_perplexity"] = np.exp(logs["eval_loss"])

                    wandb.log(logs, step=state.global_step)

                return control

        trainer = Trainer(
            model=self.model,
            args=self.training_args,
            train_dataset=self.train_dataset,
            eval_dataset=self.val_dataset,
            data_collator=data_collator,
            tokenizer=self.tokenizer,
            callbacks=[MetricsCallback(self.wandb_enabled)] if self.wandb_enabled else [],
        )

        # NOTE: per-process estimate; does not divide by the number of GPUs.
        # FIX: guard against a zero denominator/step count on tiny datasets.
        steps_per_epoch = max(1, len(self.train_dataset) // max(
            1,
            self.training_args.per_device_train_batch_size *
            self.training_args.gradient_accumulation_steps
        ))
        total_steps = steps_per_epoch * self.training_args.num_train_epochs

        print("=" * 60)
        print("Training Information:")
        print(f"  Total training samples: {len(self.train_dataset)}")
        print(f"  Total validation samples: {len(self.val_dataset)}")
        print(f"  Total training steps: {total_steps}")
        print(f"  Max sequence length: {self.max_seq_length}")
        print("=" * 60)

        if self.wandb_enabled:
            wandb.log({
                "training_status": "started",
                "total_steps": total_steps,
                "max_seq_length": self.max_seq_length
            })

        try:
            print("\n🚀 Training started...")
            train_result = trainer.train()

            print("\n💾 Saving fine-tuned model...")
            final_model_path = os.path.join(self.training_args.output_dir, "final_model")
            trainer.save_model(final_model_path)
            self.tokenizer.save_pretrained(final_model_path)

            with open(os.path.join(self.training_args.output_dir, "training_metrics.json"), 'w') as f:
                json.dump(train_result.metrics, f, indent=2)

            print("\n📊 Running final evaluation...")
            eval_results = trainer.evaluate()

            with open(os.path.join(self.training_args.output_dir, "eval_metrics.json"), 'w') as f:
                json.dump(eval_results, f, indent=2)

            if self.wandb_enabled:
                wandb.run.summary.update({
                    "final_train_loss": train_result.metrics.get("train_loss", 0),
                    "final_eval_loss": eval_results.get("eval_loss", 0),
                    "final_eval_perplexity": np.exp(eval_results.get("eval_loss", 0)),
                    "total_training_time": train_result.metrics.get("train_runtime", 0),
                    "training_samples_per_second": train_result.metrics.get("train_samples_per_second", 0),
                    "training_status": "completed"
                })

                artifact = wandb.Artifact(
                    name=f"kokoro-model-complete-{wandb.run.id}",
                    type="model",
                    description="LFM model fine-tuned with complete dialogue history",
                    metadata={
                        "base_model": self.model_name,
                        "final_loss": float(eval_results.get("eval_loss", 0)),
                        "final_perplexity": float(np.exp(eval_results.get("eval_loss", 0))),
                        "max_seq_length": self.max_seq_length,
                        "methodology": "Complete dialogue history (KokoroChat)"
                    }
                )
                artifact.add_dir(final_model_path)
                wandb.log_artifact(artifact)

            print("\n" + "=" * 60)
            print("✅ Training completed successfully!")
            print(f"📁 Model saved to: {final_model_path}")
            print(f"📉 Final eval loss: {eval_results.get('eval_loss', 0):.4f}")
            print(f"📊 Final perplexity: {np.exp(eval_results.get('eval_loss', 0)):.2f}")
            if self.wandb_enabled and wandb.run:
                print(f"🔗 View results at: {wandb.run.get_url()}")
            print("=" * 60)

            return trainer

        except Exception as e:
            print(f"❌ Error during training: {e}")

            if self.wandb_enabled:
                wandb.run.summary["training_status"] = "failed"
                wandb.run.summary["error"] = str(e)

            # Best-effort: persist the adapter so partial progress survives.
            try:
                emergency_path = os.path.join(self.training_args.output_dir, "emergency_checkpoint")
                self.model.save_pretrained(emergency_path)
                self.tokenizer.save_pretrained(emergency_path)
                print(f"💾 Emergency checkpoint saved to: {emergency_path}")
            except Exception:  # FIX: was a bare except
                print("❌ Could not save emergency checkpoint")

            raise e

        finally:
            if self.wandb_enabled:
                wandb.finish()
|
|
def test_model_with_complete_history(model_path: str):
    """Smoke-test a fine-tuned model on prompts carrying full dialogue history.

    Args:
        model_path: Directory containing either a PEFT adapter
            (detected via ``adapter_config.json``) or a full model.
    """
    print("\n" + "=" * 60)
    print("🧪 Testing model with complete dialogue history")
    print("=" * 60)

    tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
    # FIX: generation below passes pad_token_id; guarantee one exists.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # A PEFT checkpoint stores only the adapter: rebuild it on the base model.
    adapter_config_path = os.path.join(model_path, "adapter_config.json")
    if os.path.exists(adapter_config_path):
        print("Loading as PEFT model...")
        config = PeftConfig.from_pretrained(model_path)
        base_model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )
        model = PeftModel.from_pretrained(base_model, model_path)
    else:
        print("Loading as regular model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            local_files_only=True,
            trust_remote_code=True
        )

    model.eval()

    test_cases = [
        {
            "history": "クライアント: こんにちは。最近ストレスを感じています。\nカウンセラー: こんにちは。ストレスを感じていらっしゃるのですね。どのような状況でストレスを感じることが多いですか?\n",
            "current": "クライアント: 仕事が忙しくて、休む時間がありません。"
        },
        {
            "history": "",
            "current": "クライアント: 人間関係で悩んでいます。"
        }
    ]

    print("Testing with complete dialogue history:\n")

    for i, test_case in enumerate(test_cases, 1):
        print(f"Test Case {i}:")
        print("-" * 40)

        # Build the instruction prompt; include the history section only
        # when there actually is prior dialogue.
        if test_case["history"]:
            prompt = f"""### Instruction:
あなたは専門的な訓練を受けた心理カウンセラーです。
以下の完全な対話履歴を踏まえて、カウンセラーとして適切な応答を生成してください。

### Dialogue History:
{test_case["history"]}{test_case["current"]}

### Response:
"""
        else:
            prompt = f"""### Instruction:
あなたは専門的な訓練を受けた心理カウンセラーです。

### Dialogue History:
{test_case["current"]}

### Response:
"""

        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
        inputs = {k: v.cuda() if torch.cuda.is_available() else v for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                # FIX: was temperature=0 with do_sample=True, which is invalid
                # (sampling requires a strictly positive temperature).
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                pad_token_id=tokenizer.pad_token_id
            )

        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        response = response.split("### Response:")[-1].strip() if "### Response:" in response else response

        # FIX: was split('\\n') — a literal backslash-n that never matches the
        # real newlines separating turns, so the count was always 1.
        # strip() drops the trailing newline so the turn count is exact.
        if test_case['history']:
            history_turns = len(test_case['history'].strip().split('\n'))
        else:
            history_turns = 0
        print("History Length: {} turns".format(history_turns))

        print(f"Current Input: {test_case['current']}")
        print(f"Generated Response: {response[:300]}...")
        print()

    print("=" * 60)
|
| |
if __name__ == "__main__":
    import argparse

    # CLI: choose base model, data/output locations, quantization, and
    # whether to only test an existing checkpoint.
    parser = argparse.ArgumentParser(description='Fine-tune LFM model with complete dialogue history')
    parser.add_argument('--model_name', type=str, default='LiquidAI/LFM2-2.6B',
                        help='Base model name')
    parser.add_argument('--data_path', type=str, default='./kokoro_processed_data',
                        help='Path to processed data with complete dialogue history')
    parser.add_argument('--output_dir', type=str, default='./lfm_kokoro_complete',
                        help='Output directory for fine-tuned model')
    parser.add_argument('--max_seq_length', type=int, default=2048,
                        help='Maximum sequence length for complete dialogues')
    parser.add_argument('--use_4bit', action='store_true',
                        help='Use 4-bit quantization')
    parser.add_argument('--test_only', action='store_true',
                        help='Only test existing model')
    args = parser.parse_args()

    final_model_dir = os.path.join(args.output_dir, "final_model")

    if args.test_only:
        # Skip training entirely; just exercise the saved checkpoint.
        test_model_with_complete_history(final_model_dir)
    else:
        # Warn (and let the user bail out) when running on CPU.
        if not torch.cuda.is_available():
            print("⚠️ Warning: CUDA is not available. Training will be slow.")
            answer = input("Continue? (y/n): ")
            if answer.lower() != 'y':
                exit()

        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            print("🚀 Initializing fine-tuner for complete dialogue history")
            tuner = LFMKokoroChatFineTuner(
                model_name=args.model_name,
                use_4bit=args.use_4bit,
                max_seq_length=args.max_seq_length,
            )

            # Full pipeline: model -> data -> args -> train.
            tuner.setup_model_and_tokenizer()
            tuner.load_and_process_datasets(args.data_path)
            tuner.setup_training_args(args.output_dir)
            trainer = tuner.train()

            print("\n🧪 Testing the fine-tuned model...")
            test_model_with_complete_history(final_model_dir)

            print("\n✅ Fine-tuning with complete dialogue history completed!")
            print(f"📁 Model saved to: {args.output_dir}/final_model")
            print("\n📋 Next steps:")
            print(f"1. Test more: python {__file__} --test_only --output_dir {args.output_dir}")
            print("2. Run benchmarking with complete history support")
            print("3. Deploy for production use")

        except KeyboardInterrupt:
            print("\n\n⚠️ Training interrupted by user.")
            if wandb.run:
                wandb.finish()
        except Exception as e:
            print(f"\n❌ Error: {e}")
            import traceback
            traceback.print_exc()
            if wandb.run:
                wandb.finish()
|
|