"""Display utilities for training configuration.

This module provides formatted console output for LtxTrainerConfig.
"""
|
|
from rich import box
from rich.console import Console
from rich.markup import escape
from rich.table import Table

from ltx_trainer.config import LtxTrainerConfig
|
|
|
|
def print_config(config: LtxTrainerConfig) -> None:
    """Print the training configuration as a sectioned Rich table.

    Collects ``(section title, [(key, value), ...])`` entries from ``config``
    and renders them to the console as one two-column table, with a blank row
    between sections. The LoRA section is added only when ``config.lora`` is
    truthy; the "Audio" and "First Frame Cond P" strategy rows are added only
    when the strategy object has the corresponding attribute.

    Free-form values that pass through ``fmt`` (paths, module lists, STG
    blocks) and the W&B project name are markup-escaped, so square brackets in
    them are displayed literally instead of being parsed as Rich style tags.

    Args:
        config: Trainer configuration to display.
    """
    # NOTE(review): the non-ASCII glyphs in the string literals below look like
    # mis-decoded UTF-8 (presumably icons / placeholder symbols). As written,
    # the True and False glyphs differ only by colour. They are reproduced
    # unchanged here -- TODO confirm the intended characters.
    unset = "[dim]β[/]"
    disabled = "[dim]Disabled[/]"

    def fmt(v: object, max_len: int = 55) -> str:
        """Format a config value as Rich markup for a table cell.

        ``None`` and empty lists/tuples become the dim placeholder, booleans
        become a coloured glyph, non-empty lists/tuples are comma-joined (not
        truncated), and anything else is ``str()``-ed and cut to ``max_len``
        characters with a trailing ``...``. Text taken from the value is
        escaped so it cannot be interpreted as markup.
        """
        if v is None:
            return unset
        if isinstance(v, bool):
            return "[green]β[/]" if v else "[dim]β[/]"
        if isinstance(v, (list, tuple)):
            if not v:
                return unset
            return escape(", ".join(str(x) for x in v))
        s = str(v)
        if len(s) > max_len:
            s = s[: max_len - 3] + "..."
        # Escape after truncating so the cut never lands inside an escape.
        return escape(s)

    cfg = config
    model = cfg.model
    opt = cfg.optimization
    val = cfg.validation
    accel = cfg.acceleration
    strategy = cfg.training_strategy

    sections: list[tuple[str, list[tuple[str, str]]]] = [
        (
            "π¬ Model",
            [
                ("Base", fmt(model.model_path)),
                # Test the path itself: fmt() returns a non-empty placeholder
                # for None, so `fmt(...) or "Built-in"` could never fall back.
                (
                    "Text Encoder",
                    fmt(model.text_encoder_path) if model.text_encoder_path else "[dim]Built-in[/]",
                ),
                ("Training Mode", f"[bold green]{model.training_mode.upper()}[/]"),
                ("Load Checkpoint", fmt(model.load_checkpoint) if model.load_checkpoint else unset),
            ],
        ),
    ]

    if cfg.lora:
        sections.append(
            (
                "π LoRA",
                [
                    ("Rank / Alpha", f"{cfg.lora.rank} / {cfg.lora.alpha}"),
                    ("Dropout", str(cfg.lora.dropout)),
                    ("Target Modules", fmt(cfg.lora.target_modules)),
                ],
            )
        )

    # Strategy rows beyond the name depend on which attributes the strategy has.
    strategy_items: list[tuple[str, str]] = [("Name", strategy.name)]
    if hasattr(strategy, "with_audio"):
        strategy_items.append(("Audio", fmt(strategy.with_audio)))
    if hasattr(strategy, "first_frame_conditioning_p"):
        strategy_items.append(("First Frame Cond P", str(strategy.first_frame_conditioning_p)))

    sections.append(("π― Strategy", strategy_items))

    sections.extend(
        [
            (
                "β‘ Optimization",
                [
                    ("Steps", f"[bold]{opt.steps:,}[/]"),
                    ("Learning Rate", f"{opt.learning_rate:.2e}"),
                    ("Batch Size", str(opt.batch_size)),
                    ("Grad Accumulation", str(opt.gradient_accumulation_steps)),
                    ("Optimizer", opt.optimizer_type),
                    ("Scheduler", opt.scheduler_type),
                    ("Max Grad Norm", str(opt.max_grad_norm)),
                    ("Grad Checkpointing", fmt(opt.enable_gradient_checkpointing)),
                ],
            ),
            (
                "π Acceleration",
                [
                    ("Mixed Precision", accel.mixed_precision_mode or unset),
                    ("Quantization", str(accel.quantization) if accel.quantization else unset),
                    ("Text Encoder 8bit", fmt(accel.load_text_encoder_in_8bit)),
                ],
            ),
            (
                "π₯ Validation",
                [
                    ("Prompts", f"{len(val.prompts)} prompt(s)" if val.prompts else unset),
                    ("Interval", f"Every {val.interval} steps" if val.interval else disabled),
                    ("Video Dims", f"{val.video_dims[0]}x{val.video_dims[1]}, {val.video_dims[2]} frames"),
                    ("Frame Rate", f"{val.frame_rate} fps"),
                    ("Inference Steps", str(val.inference_steps)),
                    ("CFG Scale", str(val.guidance_scale)),
                    (
                        "STG",
                        f"scale={val.stg_scale}; blocks={fmt(val.stg_blocks)}; mode={val.stg_mode}"
                        if val.stg_scale > 0
                        else disabled,
                    ),
                    ("Seed", str(val.seed)),
                ],
            ),
            (
                "π Data & Output",
                [
                    ("Dataset", fmt(cfg.data.preprocessed_data_root)),
                    ("Dataloader Workers", str(cfg.data.num_dataloader_workers)),
                    ("Output Dir", fmt(cfg.output_dir)),
                    ("Seed", str(cfg.seed)),
                ],
            ),
            (
                "π Integrations",
                [
                    (
                        "Checkpoints",
                        f"Every {cfg.checkpoints.interval} steps (keep {cfg.checkpoints.keep_last_n})"
                        if cfg.checkpoints.interval
                        else disabled,
                    ),
                    # The project name is free-form text, so escape it before
                    # it is handed to Rich as markup.
                    ("W&B", escape(f"{cfg.wandb.project}") if cfg.wandb.enabled else disabled),
                    ("HF Hub", cfg.hub.hub_model_id if cfg.hub.push_to_hub else disabled),
                ],
            ),
        ]
    )

    table = Table(
        title="[bold]βοΈ Training Configuration[/]",
        show_header=False,
        box=box.ROUNDED,
        border_style="bright_blue",
        padding=(0, 1),
        title_style="bold bright_blue",
    )
    table.add_column("Key", style="white", width=20)
    table.add_column("Value", style="cyan")

    for i, (section_title, items) in enumerate(sections):
        # Blank spacer row between sections, but not above the first one.
        if i > 0:
            table.add_row("", "")
        table.add_row(f"[bold yellow]{section_title}[/]", "")
        for key, value in items:
            table.add_row(f" {key}", value)

    console = Console()
    console.print()
    console.print(table)
    console.print()
|
|