|
|
from InquirerPy import prompt |
|
|
from rich.console import Console |
|
|
from rich.table import Table |
|
|
from utils.device import get_system_info |
|
|
from utils.validators import detect_datasets |
|
|
import os |
|
|
import json |
|
|
from pathlib import Path |
|
|
import datetime |
|
|
|
|
|
# Shared rich console used for all wizard output in this module.
console = Console()
|
|
|
|
|
def display_system_summary():
    """Print a table of the detected system properties.

    Fetches system information via get_system_info() and renders each
    key/value pair; the "GPUs" entry is expanded into one name row and one
    memory row per detected device.
    """
    system_info = get_system_info()

    summary = Table(title="π₯οΈ System Detection Summary", show_lines=True)
    summary.add_column("Property", style="cyan", no_wrap=True)
    summary.add_column("Value", style="green")

    for prop, value in system_info.items():
        if prop != "GPUs":
            summary.add_row(prop, str(value))
            continue
        # Expand the GPU list: one name row and one memory row per device.
        # NOTE(review): assumes each entry is a dict with 'name'/'memory'
        # string values — confirm against utils.device.get_system_info.
        for idx, gpu in enumerate(value):
            summary.add_row(f"GPU {idx} Name", gpu['name'])
            summary.add_row(f"GPU {idx} Memory", gpu['memory'])

    console.print("\n")
    console.print(summary)
|
|
|
|
|
def get_available_models():
    """Get available models for LoRA training with auto-detection.

    Scans the local Hugging Face hub cache for already-downloaded models,
    merges them with a curated list of popular base models plus a
    manual-entry sentinel, and returns a sorted, de-duplicated list of
    choice strings.
    """
    # BUG FIX: the previous code appended the "models--" name *prefix* to the
    # hub path ("~/.cache/huggingface/hub/models--"), a directory that never
    # exists, so os.path.exists was always False and cached models were never
    # detected. Cached repos actually live directly under the hub directory
    # as folders named "models--<org>--<name>".
    hub_dir = os.path.expanduser("~/.cache/huggingface/hub")
    model_choices = []

    if os.path.isdir(hub_dir):
        for entry in os.listdir(hub_dir):
            if entry.startswith("models--"):
                # "models--org--name" -> "org/name"
                model_choices.append(entry[len("models--"):].replace("--", "/"))

    # Curated defaults, always offered regardless of the local cache state.
    model_choices += [
        "meta-llama/Meta-Llama-3-8B-Instruct",
        "meta-llama/Meta-Llama-3-70B-Instruct",
        "mistralai/Mistral-7B-Instruct-v0.1",
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "microsoft/Phi-2",
        "microsoft/Phi-3-mini-4k-instruct",
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "Qwen/Qwen1.5-0.5B",
        "Qwen/Qwen1.5-1.8B",
        "Qwen/Qwen1.5-7B",
        "google/gemma-2-2b-it",
        "google/gemma-2-9b-it",
        "manual-entry (custom path/repo)"
    ]

    return sorted(set(model_choices))
|
|
|
|
|
def get_available_datasets():
    """Get available datasets for LoRA training.

    Combines locally detected datasets with a fixed set of Hugging Face hub
    datasets. Returns (display_name, value) pairs where value carries a
    "local:" or "hf:" prefix recording the dataset's origin.
    """
    hub_options = [
        ("wikitext-2-raw-v1", "Hugging Face - WikiText-2 (Raw)"),
        ("wikitext-103-raw-v1", "Hugging Face - WikiText-103 (Raw)"),
        ("openwebtext", "Hugging Face - OpenWebText"),
        ("c4", "Hugging Face - C4 (Common Crawl)"),
        ("bookcorpus", "Hugging Face - BookCorpus"),
    ]

    # Local datasets come first, in whatever order detect_datasets() yields.
    combined = [
        (f"Local - {name}", f"local:{path}")
        for name, path in detect_datasets()
    ]
    combined.extend((label, f"hf:{ident}") for ident, label in hub_options)
    return combined
|
|
|
|
|
def generate_output_directory(model_name, dataset_name):
    """Generate a meaningful output directory name.

    Builds "out_lora_<model>_<dataset>_<timestamp>" where "/" and ":" in the
    model and dataset names are replaced by underscores so the result is
    filesystem-safe.
    """
    def _safe(text):
        # "/" would create nested paths; ":" breaks Windows drive syntax.
        return text.replace("/", "_").replace(":", "_")

    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    return f"out_lora_{_safe(model_name)}_{_safe(dataset_name)}_{stamp}"
|
|
|
|
|
def get_lora_presets():
    """Get LoRA configuration presets.

    Each preset carries a display name and description plus the r, alpha and
    dropout values applied when it is chosen. "Custom Configuration" holds
    placeholder values that the wizard later overwrites with user input.
    """
    def _preset(name, description, r, alpha, dropout):
        return {
            "name": name,
            "description": description,
            "r": r,
            "alpha": alpha,
            "dropout": dropout,
        }

    return [
        _preset("Efficient (r=8, Ξ±=16)", "Fast training, lower memory usage", 8, 16, 0.05),
        _preset("Balanced (r=16, Ξ±=32)", "Good balance of performance and speed", 16, 32, 0.05),
        _preset("High Quality (r=32, Ξ±=64)", "Better performance, more parameters", 32, 64, 0.1),
        _preset("Custom Configuration", "Set your own LoRA parameters", 16, 32, 0.05),
    ]
|
|
|
|
|
def run():
    """Run the LoRA training wizard.

    Interactively walks the user through setup mode, base model, dataset,
    LoRA preset, and training-parameter selection; writes the resulting
    configuration (lora_config.json) and an executable reproduce.sh into a
    freshly created output directory.

    Returns:
        dict: the final configuration when the user confirms training, or
        None when no datasets are found or the user cancels at the end.
    """
    console.print("\n[bold magenta]π§ Single-GPU LoRA Training Setup[/bold magenta]")

    # Step 1: quick-start vs. fully custom flow. The first word of the chosen
    # label ("Quick"/"Custom") becomes the lowercase mode key used below.
    questions = [
        {
            "type": "list",
            "name": "setup_mode",
            "message": "Choose Setup Mode:",
            "choices": [
                "Quick Start β Recommended settings for most users",
                "Custom Setup β Full control over all parameters"
            ],
        }
    ]

    answers = prompt(questions)
    setup_mode = answers.get("setup_mode").split(" ")[0].lower()

    console.print(f"\n[green]β You selected:[/green] [yellow]{answers.get('setup_mode')}[/yellow]")

    # Show detected hardware so the user can sanity-check before configuring.
    display_system_summary()

    # Step 2: base model selection (cached models + curated defaults).
    console.print("\n[bold blue]π§  Model Selection[/bold blue]")
    available_models = get_available_models()

    model_question = [
        {
            "type": "list",
            "name": "base_model",
            "message": "Choose Base Model for LoRA Training:",
            "choices": available_models
        }
    ]

    model_answer = prompt(model_question)
    selected_model = model_answer.get("base_model")

    # Sentinel choice: let the user type an arbitrary HF repo or local path.
    if selected_model == "manual-entry (custom path/repo)":
        manual_input = prompt([
            {
                "type": "input",
                "name": "custom_model",
                "message": "Enter Hugging Face repo or local model path:",
                "validate": lambda x: len(x.strip()) > 0
            }
        ])
        selected_model = manual_input.get("custom_model")

    console.print(f"\n[green]β Selected model:[/green] [yellow]{selected_model}[/yellow]")

    # Step 3: dataset selection; abort early when nothing is available.
    console.print("\n[bold blue]π Dataset Selection[/bold blue]")
    available_datasets = get_available_datasets()

    if not available_datasets:
        console.print("[bold red]β οΈ No datasets found! Please ensure you have datasets available.[/bold red]")
        return None

    dataset_question = [
        {
            "type": "list",
            "name": "dataset",
            "message": "Choose Dataset for Training:",
            "choices": [name for name, _ in available_datasets]
        }
    ]

    dataset_answer = prompt(dataset_question)
    selected_dataset_display = dataset_answer.get("dataset")

    # Map the chosen display name back to its "local:"/"hf:" value string.
    # Display names are assumed unique within available_datasets.
    selected_dataset = None
    for name, path in available_datasets:
        if name == selected_dataset_display:
            selected_dataset = path
            break

    console.print(f"\n[green]β Selected dataset:[/green] [yellow]{selected_dataset_display}[/yellow]")

    # Derive a timestamped, filesystem-safe output directory name.
    output_dir = generate_output_directory(selected_model, selected_dataset_display)
    console.print(f"\n[green]π Output directory:[/green] [yellow]{output_dir}[/yellow]")

    # Step 4: LoRA preset selection; the preset dict is recovered from the
    # "name - description" display string by matching on the name part.
    console.print("\n[bold blue]βοΈ LoRA Configuration[/bold blue]")
    lora_presets = get_lora_presets()

    lora_question = [
        {
            "type": "list",
            "name": "lora_preset",
            "message": "Choose LoRA Configuration:",
            "choices": [f"{preset['name']} - {preset['description']}" for preset in lora_presets]
        }
    ]

    lora_answer = prompt(lora_question)
    selected_preset = lora_answer.get("lora_preset").split(" - ")[0]

    selected_lora_config = None
    for preset in lora_presets:
        if preset['name'] == selected_preset:
            selected_lora_config = preset
            break

    console.print(f"\n[green]β LoRA config:[/green] [yellow]{selected_preset}[/yellow]")

    # Step 5: training hyper-parameters. Custom mode prompts for each value
    # (kept as strings here; converted to int/float when final_config is
    # built); quick-start mode uses the same defaults without prompting.
    if setup_mode == "custom":
        console.print("\n[bold blue]π― Training Parameters[/bold blue]")

        param_questions = [
            {
                "type": "input",
                "name": "max_steps",
                "message": "Maximum training steps:",
                "default": "1000",
                "validate": lambda x: x.isdigit() and int(x) > 0
            },
            {
                "type": "input",
                "name": "batch_size",
                "message": "Per-device batch size:",
                "default": "4",
                "validate": lambda x: x.isdigit() and int(x) > 0
            },
            {
                "type": "input",
                "name": "grad_accum",
                "message": "Gradient accumulation steps:",
                "default": "4",
                "validate": lambda x: x.isdigit() and int(x) > 0
            },
            {
                "type": "input",
                "name": "learning_rate",
                "message": "Learning rate:",
                "default": "2e-4",
                "validate": lambda x: float(x) > 0
            },
            {
                "type": "input",
                "name": "block_size",
                "message": "Block size for text grouping:",
                "default": "512",
                "validate": lambda x: x.isdigit() and int(x) > 0
            }
        ]

        param_answers = prompt(param_questions)
    else:
        param_answers = {
            "max_steps": "1000",
            "batch_size": "4",
            "grad_accum": "4",
            "learning_rate": "2e-4",
            "block_size": "512"
        }

    # If the "Custom Configuration" preset was picked, prompt for raw values
    # and overwrite the preset's placeholder r/alpha/dropout in place.
    if selected_preset == "Custom Configuration":
        console.print("\n[bold blue]π§ Custom LoRA Parameters[/bold blue]")

        custom_lora_questions = [
            {
                "type": "input",
                "name": "lora_r",
                "message": "LoRA rank (r):",
                "default": "16",
                "validate": lambda x: x.isdigit() and int(x) > 0
            },
            {
                "type": "input",
                "name": "lora_alpha",
                "message": "LoRA alpha:",
                "default": "32",
                "validate": lambda x: x.isdigit() and int(x) > 0
            },
            {
                "type": "input",
                "name": "lora_dropout",
                "message": "LoRA dropout:",
                "default": "0.05",
                "validate": lambda x: 0 <= float(x) <= 1
            }
        ]

        custom_lora_answers = prompt(custom_lora_questions)
        selected_lora_config.update({
            "r": int(custom_lora_answers["lora_r"]),
            "alpha": int(custom_lora_answers["lora_alpha"]),
            "dropout": float(custom_lora_answers["lora_dropout"])
        })

    # Translate the "local:<path>" / "hf:<id>" value into the (dataset_name,
    # dataset_config) pair passed to the training CLI.
    # NOTE(review): every HF dataset maps to dataset_name "wikitext"
    # regardless of which hub dataset was chosen — presumably the trainer
    # keys the loader off dataset_config; confirm against
    # cli/train_lora_single.py.
    if selected_dataset.startswith("local:"):
        dataset_name = "jsonl"
        dataset_config = selected_dataset[6:]
    elif selected_dataset.startswith("hf:"):
        dataset_name = "wikitext"
        dataset_config = selected_dataset[3:]
    else:
        dataset_name = "wikitext"
        dataset_config = selected_dataset

    # Assemble the final configuration returned to the caller and saved to
    # disk below; numeric fields are converted from their string answers.
    final_config = {
        "setup_mode": setup_mode,
        "base_model": selected_model,
        "dataset_name": dataset_name,
        "dataset_config": dataset_config,
        "dataset_display": selected_dataset_display,
        "output_dir": output_dir,
        "lora_config": selected_lora_config,
        "max_steps": int(param_answers["max_steps"]),
        "batch_size": int(param_answers["batch_size"]),
        "grad_accum": int(param_answers["grad_accum"]),
        "learning_rate": float(param_answers["learning_rate"]),
        "block_size": int(param_answers["block_size"]),
        "timestamp": datetime.datetime.now().isoformat()
    }

    # Render a summary table for a final visual check before writing files.
    console.print("\n[bold cyan]π Configuration Summary[/bold cyan]")
    summary_table = Table(show_header=True, header_style="bold magenta")
    summary_table.add_column("Parameter", style="cyan")
    summary_table.add_column("Value", style="green")

    summary_table.add_row("Model", selected_model)
    summary_table.add_row("Dataset", selected_dataset_display)
    summary_table.add_row("Output Directory", output_dir)
    summary_table.add_row("LoRA Rank (r)", str(selected_lora_config["r"]))
    summary_table.add_row("LoRA Alpha", str(selected_lora_config["alpha"]))
    summary_table.add_row("LoRA Dropout", str(selected_lora_config["dropout"]))
    summary_table.add_row("Max Steps", str(final_config["max_steps"]))
    summary_table.add_row("Batch Size", str(final_config["batch_size"]))
    summary_table.add_row("Grad Accumulation", str(final_config["grad_accum"]))
    summary_table.add_row("Learning Rate", str(final_config["learning_rate"]))
    summary_table.add_row("Block Size", str(final_config["block_size"]))

    console.print(summary_table)

    # Create the output directory and persist the configuration as JSON.
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    config_path = Path(output_dir) / "lora_config.json"
    with open(config_path, "w") as f:
        json.dump(final_config, f, indent=2)

    console.print(f"\n[bold green]β Configuration saved to:[/bold green] [cyan]{config_path}[/cyan]")

    # Emit an executable shell script that re-runs this exact configuration.
    reproduce_script = f"""#!/bin/bash
# Re-run this exact LoRA training config
cd {Path.cwd()}
python3 cli/train_lora_single.py \\
--model {selected_model} \\
--output-dir {output_dir} \\
--max-steps {final_config["max_steps"]} \\
--batch-size {final_config["batch_size"]} \\
--grad-accum {final_config["grad_accum"]} \\
--learning-rate {final_config["learning_rate"]} \\
--block-size {final_config["block_size"]} \\
--lora-r {selected_lora_config["r"]} \\
--lora-alpha {selected_lora_config["alpha"]} \\
--lora-dropout {selected_lora_config["dropout"]} \\
--dataset {dataset_name} \\
--dataset-config {dataset_config}
"""

    reproduce_path = Path(output_dir) / "reproduce.sh"
    with open(reproduce_path, "w") as f:
        f.write(reproduce_script)
    reproduce_path.chmod(0o755)  # make the script executable

    console.print(f"[bold green]β Reproduction script saved to:[/bold green] [cyan]{reproduce_path}[/cyan]")

    # Final confirmation gate before handing the config to the trainer.
    final_prompt = prompt([
        {
            "type": "confirm",
            "name": "confirm_training",
            "message": "π Start LoRA training now?",
            "default": True
        }
    ])

    if not final_prompt["confirm_training"]:
        console.print("[bold yellow]β Training cancelled.[/bold yellow]")
        return None
    else:
        console.print("[bold green]π Starting LoRA training...[/bold green]")
        return final_config
|
|
|
|
|
# Allow running the wizard directly as a script.
if __name__ == "__main__":
    run()
|
|
|
|
|
|