| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
|
|
| |
|
|
| import torch |
| from transformers import ( |
| AutoModelForCausalLM, |
| AutoTokenizer, |
| TrainingArguments, |
| Trainer, |
| DataCollatorForLanguageModeling, |
| BitsAndBytesConfig, |
| TrainerCallback |
| ) |
| from peft import ( |
| LoraConfig, |
| get_peft_model, |
| prepare_model_for_kbit_training, |
| TaskType |
| ) |
| from datasets import load_dataset, Dataset |
| import os |
| from typing import Dict, List, Optional |
| import numpy as np |
| from tqdm import tqdm |
| import json |
| import gc |
| import warnings |
| import wandb |
| from datetime import datetime |
|
|
# Suppress ALL warnings globally — transformers/bitsandbytes are very noisy,
# but note this also hides potentially useful deprecation/user warnings.
warnings.filterwarnings('ignore')
|
|
class LFMCounselorFineTuner:
    """Fine-tune a Liquid Foundation Model (LFM) for Japanese counseling dialogue.

    Applies LoRA adapters (via PEFT) on top of an optionally 4-bit quantized
    base model, trains with the HF ``Trainer`` on JSONL dialogue data, and
    logs the run to Weights & Biases when available.
    """

    def __init__(self, model_name: str = "LiquidAI/LFM2-2.6B", use_4bit: bool = True):
        """
        Initialize the fine-tuner for LFM models

        Args:
            model_name: Name of the base model
            use_4bit: Whether to use 4-bit quantization for memory efficiency
        """
        self.model_name = model_name
        self.use_4bit = use_4bit
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        print(f"Using device: {self.device}")
        gpu_memory = 0
        if torch.cuda.is_available():
            gpu_name = torch.cuda.get_device_name(0)
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
            print(f"GPU: {gpu_name}")
            print(f"GPU Memory: {gpu_memory:.2f} GB")

        # WandB is optional: if login / network / quota fails we disable
        # reporting and continue, instead of aborting the whole run.
        try:
            run_name = f"lfm-counselor-{datetime.now().strftime('%Y%m%d-%H%M%S')}"

            wandb.init(
                project="liquid-counselor-hackathon",
                name=run_name,
                config={
                    "model_name": model_name,
                    "use_4bit_quantization": use_4bit,
                    "device": str(self.device),
                    "gpu": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU",
                    "gpu_memory_gb": gpu_memory,
                    "framework": "transformers",
                    "peft_method": "LoRA",
                    "task": "japanese_counseling",
                    "dataset": "KokoroChat"
                },
                tags=["counseling", "japanese", "lfm", "finetune", "hackathon"]
            )
            print(f"✅ WandB initialized: {wandb.run.name}")
            print(f"📊 View run at: {wandb.run.get_url()}")
            self.wandb_enabled = True
        except Exception as e:
            print(f"⚠️ WandB initialization failed: {e}")
            print("Continuing without WandB logging...")
            self.wandb_enabled = False
            # Also tell HF Trainer not to try reporting to wandb.
            os.environ["WANDB_DISABLED"] = "true"

    def setup_model_and_tokenizer(self):
        """Setup model with quantization and LoRA"""

        print("Loading tokenizer...")

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        except Exception:
            # Narrowed from a bare `except:` so Ctrl-C / SystemExit still propagate.
            print("Using fallback tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

        # Guarantee pad/eos tokens exist; pad falls back to eos, which is the
        # common causal-LM convention.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        if self.tokenizer.eos_token is None:
            self.tokenizer.eos_token = "</s>"
            self.tokenizer.pad_token = "</s>"

        self.tokenizer.padding_side = "right"

        # Optional 4-bit NF4 quantization (QLoRA-style) to fit larger models
        # into limited GPU memory.
        if self.use_4bit:
            print("Setting up 4-bit quantization...")
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True
            )
        else:
            bnb_config = None

        print(f"Loading model: {self.model_name}...")
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                quantization_config=bnb_config,
                device_map="auto",
                trust_remote_code=True,
                torch_dtype=torch.float16
            )
        except Exception as e:
            # Quantized load can fail (e.g. missing bitsandbytes / CPU-only);
            # retry without quantization as a best-effort fallback.
            print(f"Error loading model: {e}")
            print("Attempting to load without quantization...")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                device_map="auto",
                trust_remote_code=True,
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True
            )

        # Trade compute for memory during backprop.
        if hasattr(self.model, 'gradient_checkpointing_enable'):
            self.model.gradient_checkpointing_enable()

        if self.use_4bit:
            print("Preparing model for 4-bit training...")
            self.model = prepare_model_for_kbit_training(self.model)

        print("Applying LoRA configuration...")

        # Discover which Linear layers to adapt, since LFM layer names may
        # differ from the usual LLaMA-style projections.
        target_modules = self.find_target_modules()

        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=target_modules,
            lora_dropout=0.05,
            bias="none",
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False
        )

        self.model = get_peft_model(self.model, lora_config)

        # Count trainable vs total parameters for reporting.
        trainable_params = 0
        all_params = 0
        for _, param in self.model.named_parameters():
            all_params += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()

        trainable_percentage = 100 * trainable_params / all_params if all_params > 0 else 0

        print(f"Trainable parameters: {trainable_params:,} / {all_params:,} ({trainable_percentage:.2f}%)")

        if self.wandb_enabled:
            wandb.config.update({
                "lora_r": lora_config.r,
                "lora_alpha": lora_config.lora_alpha,
                "lora_dropout": lora_config.lora_dropout,
                "lora_target_modules": target_modules,
                "total_parameters": all_params,
                "trainable_parameters": trainable_params,
                "trainable_percentage": trainable_percentage
            })

        self.model.print_trainable_parameters()

    def find_target_modules(self):
        """Find linear modules to apply LoRA to"""
        # Collect the leaf names of every nn.Linear in the model.
        target_modules = []
        for name, module in self.model.named_modules():
            if isinstance(module, torch.nn.Linear):
                names = name.split('.')
                if len(names) > 0:
                    target_modules.append(names[-1])

        target_modules = list(set(target_modules))

        # Names commonly used for attention/MLP projections across
        # architectures (LLaMA-style, BERT-style, GPT-style).
        common_targets = ["q_proj", "v_proj", "k_proj", "o_proj",
                          "gate_proj", "up_proj", "down_proj",
                          "fc1", "fc2", "query", "key", "value", "dense"]

        final_targets = [t for t in target_modules if any(ct in t.lower() for ct in common_targets)]

        # If nothing matched the known names, fall back to an arbitrary
        # subset of discovered Linear layers rather than adapting nothing.
        if not final_targets:
            final_targets = target_modules[:6]

        print(f"LoRA target modules: {final_targets}")
        return final_targets if final_targets else ["q_proj", "v_proj"]

    def load_and_process_datasets(self, data_path: str):
        """Load and process datasets without multiprocessing issues"""

        print(f"Loading datasets from {data_path}...")

        # Expected JSONL schema per line: {"text": str, "score": num, "topic": str}
        # (score/topic are optional and default to 0 / 'Unknown').
        train_texts = []
        train_scores = []
        train_topics = []

        with open(f'{data_path}/train.jsonl', 'r', encoding='utf-8') as f:
            for line in tqdm(f, desc="Loading training data"):
                data = json.loads(line)
                train_texts.append(data['text'])
                train_scores.append(data.get('score', 0))
                train_topics.append(data.get('topic', 'Unknown'))

        val_texts = []
        val_scores = []
        val_topics = []

        with open(f'{data_path}/validation.jsonl', 'r', encoding='utf-8') as f:
            for line in tqdm(f, desc="Loading validation data"):
                data = json.loads(line)
                val_texts.append(data['text'])
                val_scores.append(data.get('score', 0))
                val_topics.append(data.get('topic', 'Unknown'))

        print(f"Loaded {len(train_texts)} training examples")
        print(f"Loaded {len(val_texts)} validation examples")

        if self.wandb_enabled:
            # Summary statistics of the counseling-quality scores.
            train_score_stats = {
                "train_examples": len(train_texts),
                "train_avg_score": float(np.mean(train_scores)),
                "train_min_score": float(np.min(train_scores)),
                "train_max_score": float(np.max(train_scores)),
                "train_std_score": float(np.std(train_scores))
            }

            val_score_stats = {
                "val_examples": len(val_texts),
                "val_avg_score": float(np.mean(val_scores)),
                "val_min_score": float(np.min(val_scores)),
                "val_max_score": float(np.max(val_scores)),
                "val_std_score": float(np.std(val_scores))
            }

            wandb.config.update(train_score_stats)
            wandb.config.update(val_score_stats)

            wandb.log({
                "train_score_distribution": wandb.Histogram(train_scores),
                "val_score_distribution": wandb.Histogram(val_scores)
            })

            # Topic frequency chart (top 20 only, to keep the plot readable).
            train_topic_counts = {}
            for topic in train_topics:
                train_topic_counts[topic] = train_topic_counts.get(topic, 0) + 1

            if len(train_topic_counts) > 0:
                top_topics = sorted(train_topic_counts.items(), key=lambda x: x[1], reverse=True)[:20]
                wandb.log({
                    "topic_distribution": wandb.plot.bar(
                        wandb.Table(data=[[k, v] for k, v in top_topics],
                                    columns=["Topic", "Count"]),
                        "Topic", "Count", title="Training Topic Distribution (Top 20)"
                    )
                })

        print("Tokenizing training dataset...")
        train_encodings = self.tokenize_texts(train_texts)

        print("Tokenizing validation dataset...")
        val_encodings = self.tokenize_texts(val_texts)

        self.train_dataset = Dataset.from_dict(train_encodings)
        self.val_dataset = Dataset.from_dict(val_encodings)

        self.train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
        self.val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

        # Free the large intermediate Python lists before training starts.
        del train_texts, val_texts, train_encodings, val_encodings
        gc.collect()

    def tokenize_texts(self, texts: List[str], batch_size: int = 100):
        """Tokenize texts in batches to avoid memory issues"""
        all_input_ids = []
        all_attention_masks = []

        for i in tqdm(range(0, len(texts), batch_size), desc="Tokenizing"):
            batch_texts = texts[i:i + batch_size]

            encodings = self.tokenizer(
                batch_texts,
                truncation=True,
                padding='max_length',
                max_length=512,
                return_tensors='pt'
            )

            all_input_ids.extend(encodings['input_ids'].tolist())
            all_attention_masks.extend(encodings['attention_mask'].tolist())

        # Causal-LM objective: labels mirror input_ids. NOTE(review): this is a
        # shallow copy (inner lists shared) and labels still contain pad-token
        # ids; the DataCollatorForLanguageModeling used downstream handles
        # label masking — verify if the collator is ever removed.
        labels = all_input_ids.copy()

        return {
            'input_ids': all_input_ids,
            'attention_mask': all_attention_masks,
            'labels': labels
        }

    def setup_training_args(self, output_dir: str = "./counselor_model_2b"):
        """Setup training arguments optimized for counseling task"""

        print("Setting up training arguments...")

        # Pick batch size / accumulation from available GPU memory so the
        # effective batch size stays 16 in every configuration.
        if torch.cuda.is_available():
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
            if gpu_memory < 16:
                batch_size = 1
                gradient_accumulation = 16
            elif gpu_memory < 24:
                batch_size = 2
                gradient_accumulation = 8
            else:
                batch_size = 4
                gradient_accumulation = 4
        else:
            batch_size = 1
            gradient_accumulation = 16

        print(f"Using batch_size={batch_size}, gradient_accumulation={gradient_accumulation}")

        if self.wandb_enabled:
            wandb.config.update({
                "batch_size": batch_size,
                "gradient_accumulation_steps": gradient_accumulation,
                "effective_batch_size": batch_size * gradient_accumulation,
                "num_epochs": 3,
                "learning_rate": 5e-5,
                "warmup_steps": 100,
                "weight_decay": 0.01,
                "max_grad_norm": 1.0,
                "lr_scheduler": "linear",
                "optimizer": "adamw_torch",
                "fp16": True,
                "max_length": 512
            })

        report_to = "wandb" if self.wandb_enabled else "none"

        self.training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=3,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation,
            gradient_checkpointing=True,
            warmup_steps=100,
            learning_rate=5e-5,
            fp16=True,
            logging_steps=50,
            logging_first_step=True,
            eval_strategy="steps",
            eval_steps=200,
            save_strategy="steps",
            save_steps=400,
            save_total_limit=2,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            report_to=report_to,
            run_name=wandb.run.name if self.wandb_enabled and wandb.run else "local_run",
            push_to_hub=False,
            optim="adamw_torch",
            lr_scheduler_type="linear",
            weight_decay=0.01,
            max_grad_norm=1.0,
            remove_unused_columns=False,
            label_names=["labels"],
            # Single-process data loading avoids multiprocessing/tokenizer
            # fork issues seen on some platforms.
            dataloader_num_workers=0,
            dataloader_pin_memory=False,
        )

    def train(self):
        """Execute training, save the final model, and run a final evaluation.

        Returns:
            The HF ``Trainer`` instance after training completes.

        Raises:
            Re-raises any exception from training after attempting to save an
            emergency checkpoint.
        """

        print("Initializing trainer...")

        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
            pad_to_multiple_of=8
        )

        class CustomMetricsCallback(TrainerCallback):
            """Adds perplexity (exp of loss) to each logged metrics dict."""

            # wandb_enabled is passed explicitly instead of being injected as
            # an attribute after construction (previous fragile pattern).
            def __init__(self, wandb_enabled: bool):
                self.wandb_enabled = wandb_enabled

            def on_log(self, args, state, control, logs=None, **kwargs):
                if logs and self.wandb_enabled:
                    if "loss" in logs:
                        logs["perplexity"] = np.exp(logs["loss"])
                    if "eval_loss" in logs:
                        logs["eval_perplexity"] = np.exp(logs["eval_loss"])
                return control

        custom_callback = CustomMetricsCallback(self.wandb_enabled)

        try:
            trainer = Trainer(
                model=self.model,
                args=self.training_args,
                train_dataset=self.train_dataset,
                eval_dataset=self.val_dataset,
                data_collator=data_collator,
                tokenizer=self.tokenizer,
                callbacks=[custom_callback] if self.wandb_enabled else [],
            )

            # Approximate optimizer-step count (floor division; guarded so a
            # tiny dataset never reports 0 steps).
            total_steps = max(
                1,
                len(self.train_dataset)
                // (self.training_args.per_device_train_batch_size * self.training_args.gradient_accumulation_steps)
                * self.training_args.num_train_epochs,
            )

            print("="*50)
            print("Starting fine-tuning...")
            print(f"Total training samples: {len(self.train_dataset)}")
            print(f"Total validation samples: {len(self.val_dataset)}")
            print(f"Total training steps: {total_steps}")
            print("="*50)

            if self.wandb_enabled:
                wandb.log({"training_status": "started", "total_steps": total_steps})

            train_result = trainer.train()

            print("\nSaving fine-tuned model...")
            trainer.save_model(f"{self.training_args.output_dir}/final_model_2b")
            self.tokenizer.save_pretrained(f"{self.training_args.output_dir}/final_model_2b")

            with open(f"{self.training_args.output_dir}/training_metrics.json", 'w') as f:
                json.dump(train_result.metrics, f, indent=2)

            print("\nRunning final evaluation...")
            eval_results = trainer.evaluate()

            with open(f"{self.training_args.output_dir}/eval_metrics.json", 'w') as f:
                json.dump(eval_results, f, indent=2)

            if self.wandb_enabled:
                wandb.run.summary.update({
                    "final_train_loss": train_result.metrics.get("train_loss", 0),
                    "final_eval_loss": eval_results.get("eval_loss", 0),
                    "final_eval_perplexity": np.exp(eval_results.get("eval_loss", 0)),
                    "total_training_time": train_result.metrics.get("train_runtime", 0),
                    "training_samples_per_second": train_result.metrics.get("train_samples_per_second", 0),
                    "training_status": "completed"
                })

                summary_table = wandb.Table(
                    columns=["Metric", "Value"],
                    data=[
                        ["Final Training Loss", f"{train_result.metrics.get('train_loss', 0):.4f}"],
                        ["Final Eval Loss", f"{eval_results.get('eval_loss', 0):.4f}"],
                        ["Final Perplexity", f"{np.exp(eval_results.get('eval_loss', 0)):.2f}"],
                        ["Training Time (seconds)", f"{train_result.metrics.get('train_runtime', 0):.0f}"],
                        ["Training Samples/Second", f"{train_result.metrics.get('train_samples_per_second', 0):.2f}"]
                    ]
                )
                wandb.log({"training_summary": summary_table})

                # Upload the saved adapter/model as a WandB artifact
                # (best-effort; failures are non-fatal).
                try:
                    artifact = wandb.Artifact(
                        name=f"counselor-model-{wandb.run.id}",
                        type="model",
                        description="Fine-tuned Japanese counseling model",
                        metadata={
                            "base_model": self.model_name,
                            "final_loss": float(eval_results.get("eval_loss", 0)),
                            "final_perplexity": float(np.exp(eval_results.get("eval_loss", 0))),
                            "dataset": "KokoroChat"
                        }
                    )
                    artifact.add_dir(f"{self.training_args.output_dir}/final_model_2b")
                    wandb.log_artifact(artifact)
                except Exception as e:
                    print(f"Warning: Could not save model artifact: {e}")

            print("\n" + "="*50)
            print("✅ Training completed successfully!")
            print(f"📁 Model saved to: {self.training_args.output_dir}/final_model_2b")
            print(f"📉 Final eval loss: {eval_results.get('eval_loss', 0):.4f}")
            print(f"📊 Final perplexity: {np.exp(eval_results.get('eval_loss', 0)):.2f}")
            if self.wandb_enabled and wandb.run:
                print(f"🔗 View results at: {wandb.run.get_url()}")
            print("="*50)

            return trainer

        except Exception as e:
            print(f"❌ Error during training: {e}")

            if self.wandb_enabled:
                wandb.run.summary["training_status"] = "failed"
                wandb.run.summary["error"] = str(e)

            print("Attempting to save checkpoint...")

            try:
                self.model.save_pretrained(f"{self.training_args.output_dir}/checkpoint_emergency")
                self.tokenizer.save_pretrained(f"{self.training_args.output_dir}/checkpoint_emergency")
                print(f"💾 Emergency checkpoint saved to: {self.training_args.output_dir}/checkpoint_emergency")
            except Exception:
                # Narrowed from a bare `except:`.
                print("❌ Could not save emergency checkpoint")

            # Bare `raise` preserves the original traceback (was `raise e`).
            raise
        finally:
            if self.wandb_enabled:
                wandb.finish()
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def test_model(model_path: str, tokenizer_path: str):
    """Smoke-test a fine-tuned model by generating counselor replies.

    Loads the tokenizer (with GPT-2 fallback) and the model — either as a
    PEFT adapter over its base model or as a regular checkpoint — then prints
    a generated response for a few sample Japanese client utterances.

    Args:
        model_path: Directory containing the fine-tuned model or adapter.
        tokenizer_path: Directory containing the saved tokenizer.

    Fix vs previous version: the generation loop was duplicated verbatim and
    ran every test case twice; the duplicate has been removed.
    """

    print("\n" + "="*50)
    print("Testing fine-tuned model...")
    print("="*50)

    from peft import PeftModel, PeftConfig
    import os

    # --- Tokenizer (best-effort, with GPT-2 fallback) ---
    try:
        if os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")):
            print(f"Loading tokenizer from {tokenizer_path}")
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
        else:
            print(f"Tokenizer not found at {tokenizer_path}, using base model tokenizer")
            tokenizer = AutoTokenizer.from_pretrained("gpt2")
    except Exception as e:
        print(f"Error loading tokenizer: {e}")
        print("Using fallback GPT-2 tokenizer")
        tokenizer = AutoTokenizer.from_pretrained("gpt2")

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # --- Model: PEFT adapter (adapter_config.json present) or plain model ---
    try:
        adapter_config_path = os.path.join(model_path, "adapter_config.json")
        if os.path.exists(adapter_config_path):
            print("Loading as PEFT model...")
            config = PeftConfig.from_pretrained(model_path)
            base_model = AutoModelForCausalLM.from_pretrained(
                config.base_model_name_or_path,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True
            )
            model = PeftModel.from_pretrained(base_model, model_path)
        else:
            print("Loading as regular model...")
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.float16,
                device_map="auto",
                local_files_only=True,
                trust_remote_code=True
            )
    except Exception as e:
        print(f"Error loading model: {e}")
        raise

    model.eval()

    # Sample Japanese client utterances for a quick qualitative check.
    test_cases = [
        "こんにちは。最近ストレスを感じています。",
        "仕事がうまくいかなくて悩んでいます。",
        "人間関係で困っています。どうすればいいでしょうか。"
    ]

    print("Sample conversations:")
    print("-" * 50)

    for test_input in test_cases:
        inputs = tokenizer(test_input, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.cuda() if torch.cuda.is_available() else v for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.1,
                do_sample=True,
                top_p=0.9,
                pad_token_id=tokenizer.pad_token_id
            )

        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # NOTE(review): assumes the decoded text starts with the literal
        # prompt; true for most tokenizers, but verify for this one.
        response = response[len(test_input):].strip()

        print(f"Client: {test_input}")
        print(f"Counselor: {response[:200]}...")
        print("-" * 50)

    print("="*50)
|
|
| |
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Fine-tune LFM model for counseling')
    parser.add_argument('--model_name', type=str, default='LiquidAI/LFM2-2.6B',
                        help='Base model name')
    parser.add_argument('--data_path', type=str, default='./processed_data_score80',
                        help='Path to processed data')
    parser.add_argument('--output_dir', type=str, default='./counselor_model_2b',
                        help='Output directory for fine-tuned model')
    parser.add_argument('--use_4bit', action='store_true', default=False,
                        help='Use 4-bit quantization')
    parser.add_argument('--wandb_api_key', type=str, default=None,
                        help='WandB API key (optional, can use wandb login instead)')
    parser.add_argument('--test_only', action='store_true',
                        help='Only test existing model')

    args = parser.parse_args()

    # Allow passing the WandB key on the command line instead of `wandb login`.
    if args.wandb_api_key:
        os.environ["WANDB_API_KEY"] = args.wandb_api_key

    if args.test_only:
        test_model(
            f"{args.output_dir}/final_model_2b",
            f"{args.output_dir}/final_model_2b"
        )
    else:
        # CPU-only training is technically possible but impractically slow;
        # ask for confirmation before proceeding.
        if not torch.cuda.is_available():
            print("⚠️ Warning: CUDA is not available. Training will be very slow on CPU.")
            print("It's highly recommended to use a GPU for training.")
            response = input("Do you want to continue anyway? (y/n): ")
            if response.lower() != 'y':
                exit()

        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            print(f"🚀 Initializing fine-tuner with model: {args.model_name}")
            finetuner = LFMCounselorFineTuner(
                model_name=args.model_name,
                use_4bit=args.use_4bit
            )

            print("\n🔧 Setting up model and tokenizer...")
            finetuner.setup_model_and_tokenizer()

            print("\n📚 Loading and processing datasets...")
            finetuner.load_and_process_datasets(args.data_path)

            print("\n⚙️ Setting up training arguments...")
            finetuner.setup_training_args(args.output_dir)

            trainer = finetuner.train()

            # FIX: train() saves to `final_model_2b` (not `final_model_2b_v2`);
            # the old paths pointed at a directory that never exists.
            print("\n🧪 Testing the fine-tuned model...")
            test_model(
                f"{args.output_dir}/final_model_2b",
                f"{args.output_dir}/final_model_2b"
            )

            print("\n✅ Fine-tuning completed successfully!")
            print(f"📁 Model saved to: {args.output_dir}/final_model_2b")
            print("\n📋 Next steps:")
            print("1. Test more: python finetune_lfm.py --test_only")
            print("2. Run benchmarking: python benchmark_model.py")
            print("3. Optimize for mobile: python optimize_for_mobile.py")

        except KeyboardInterrupt:
            print("\n\n⚠️ Training interrupted by user.")
            print("Partial model may be saved in checkpoints.")
            if wandb.run:
                wandb.finish()
        except Exception as e:
            print(f"\n❌ Error during fine-tuning: {e}")
            import traceback
            traceback.print_exc()
            if wandb.run:
                wandb.finish()
|
|