#!/usr/bin/env python3
"""
Wrapper around gr00t/experiment/launch_finetune.py that:
  1. Overrides config.model.model_name to a local backbone path
     (avoids downloading gated nvidia/Cosmos-Reason2-2B). The path defaults to
     the snapshot baked into this file and can be overridden by setting the
     GR00T_LOCAL_BACKBONE environment variable.
  2. Prints the checkpoint output directory clearly at start and end.
"""

import json
import os
import sys
from pathlib import Path

import tyro

# Make sure gr00t is importable
sys.path.insert(0, str(Path(__file__).parent.parent))

from gr00t.configs.base_config import get_default_config  # noqa: E402
from gr00t.configs.finetune_config import FinetuneConfig  # noqa: E402
from gr00t.experiment.experiment import run  # noqa: E402

_DEFAULT_LOCAL_BACKBONE = (
    "/lustre/fsw/portfolios/nvr/users/anchiehc/cache/huggingface/hub"
    "/models--Qwen--Qwen3-VL-2B-Instruct/snapshots/89644892e4d85e24eaac8bacfd4f463576704203"
)
# Backbone path assigned to config.model.model_name. An empty or unset
# GR00T_LOCAL_BACKBONE falls back to the default snapshot above.
LOCAL_BACKBONE = os.environ.get("GR00T_LOCAL_BACKBONE") or _DEFAULT_LOCAL_BACKBONE

_CHECKPOINT_PREFIX = "checkpoint-"
_RULE = "=" * 64


def load_modality_config(modality_config_path: str) -> None:
    """Import a Python modality-config file so its module-level code runs.

    The file's parent directory is prepended to ``sys.path`` and the file is
    imported by its stem; the resulting module object is not used here
    (presumably the file registers configs as an import side effect — confirm).

    Args:
        modality_config_path: Path to a ``.py`` file.

    Raises:
        FileNotFoundError: If the path is not an existing ``.py`` file.
    """
    import importlib

    path = Path(modality_config_path)
    # is_file() rather than exists(): a directory named "foo.py" must not pass.
    if not (path.is_file() and path.suffix == ".py"):
        raise FileNotFoundError(
            f"Modality config not found (expected an existing .py file): {modality_config_path}"
        )
    sys.path.insert(0, str(path.parent))
    # NOTE(review): if a module with this stem is already in sys.modules, it is
    # returned from the cache and the file is NOT executed again.
    importlib.import_module(path.stem)
    print(f"Loaded modality config: {path}")


def _checkpoint_step(path: Path) -> int:
    """Return the step number of a ``checkpoint-<step>`` directory, or -1 if non-numeric."""
    suffix = path.name[len(_CHECKPOINT_PREFIX):]
    return int(suffix) if suffix.isdecimal() else -1


def list_checkpoints(output_dir: Path) -> list[Path]:
    """Return the ``checkpoint-*`` subdirectories of *output_dir*, ordered by step.

    Ordering is numeric on the step suffix, so ``checkpoint-1000`` comes after
    ``checkpoint-500`` (plain string sorting would put it first). Names without
    a numeric suffix sort first, by name. Returns an empty list if
    *output_dir* does not exist, instead of raising from ``iterdir()``.
    """
    if not output_dir.is_dir():
        return []
    checkpoints = [
        entry
        for entry in output_dir.iterdir()
        if entry.is_dir() and entry.name.startswith(_CHECKPOINT_PREFIX)
    ]
    return sorted(checkpoints, key=lambda entry: (_checkpoint_step(entry), entry.name))


def _build_config(ft_config: FinetuneConfig):
    """Translate the CLI ``FinetuneConfig`` into the full experiment config.

    Side effects: replaces ``ft_config.embodiment_tag`` with its resolved
    ``EmbodimentTag`` and, if ``modality_config_path`` is set, imports that file.
    """
    from gr00t.data.embodiment_tags import EmbodimentTag

    ft_config.embodiment_tag = EmbodimentTag.resolve(ft_config.embodiment_tag)
    embodiment_tag = ft_config.embodiment_tag.value

    if ft_config.modality_config_path is not None:
        load_modality_config(ft_config.modality_config_path)

    config = get_default_config().load_dict(
        {
            "data": {
                "download_cache": False,
                "datasets": [
                    {
                        "dataset_paths": [ft_config.dataset_path],
                        "mix_ratio": 1.0,
                        "embodiment_tag": embodiment_tag,
                    }
                ],
            }
        }
    )
    config.load_config_path = None

    # --- model ---
    config.model.tune_llm = ft_config.tune_llm
    config.model.tune_visual = ft_config.tune_visual
    config.model.tune_projector = ft_config.tune_projector
    config.model.tune_diffusion_model = ft_config.tune_diffusion_model
    config.model.state_dropout_prob = ft_config.state_dropout_prob
    config.model.random_rotation_angle = ft_config.random_rotation_angle
    config.model.color_jitter_params = ft_config.color_jitter_params
    if ft_config.extra_augmentation_config:
        config.model.extra_augmentation_config = json.loads(ft_config.extra_augmentation_config)
    else:
        config.model.extra_augmentation_config = None
    config.model.load_bf16 = False
    config.model.reproject_vision = False
    # Use local Qwen3-VL-2B-Instruct as backbone (same architecture as Cosmos-Reason2-2B)
    # Actual backbone weights come from the GR00T-N1.7-3B checkpoint safetensors.
    config.model.model_name = LOCAL_BACKBONE
    config.model.backbone_trainable_params_fp32 = True
    config.model.use_relative_action = True

    # --- training ---
    config.training.experiment_name = ft_config.experiment_name
    config.training.start_from_checkpoint = ft_config.base_model_path
    config.training.optim = "adamw_torch"
    config.training.global_batch_size = ft_config.global_batch_size
    config.training.dataloader_num_workers = ft_config.dataloader_num_workers
    config.training.learning_rate = ft_config.learning_rate
    config.training.gradient_accumulation_steps = ft_config.gradient_accumulation_steps
    config.training.output_dir = ft_config.output_dir
    config.training.save_steps = ft_config.save_steps
    config.training.save_total_limit = ft_config.save_total_limit
    config.training.num_gpus = ft_config.num_gpus
    config.training.use_wandb = ft_config.use_wandb
    config.training.max_steps = ft_config.max_steps
    config.training.weight_decay = ft_config.weight_decay
    config.training.warmup_ratio = ft_config.warmup_ratio
    config.training.wandb_project = ft_config.wandb_project
    config.training.save_only_model = ft_config.save_only_model
    config.training.skip_weight_loading = ft_config.skip_weight_loading

    # --- data ---
    config.data.shard_size = ft_config.shard_size
    config.data.episode_sampling_rate = ft_config.episode_sampling_rate
    config.data.num_shards_per_epoch = ft_config.num_shards_per_epoch
    # Use pyav backend: fast, pure-Python, no FFmpeg 7.0 soname requirement
    config.data.video_backend = "pyav"
    return config


def _print_start_banner(ft_config: FinetuneConfig, output_dir: Path) -> None:
    """Print where checkpoints will be written before training starts."""
    print(_RULE)
    print(" CHECKPOINT OUTPUT DIRECTORY:")
    print(f" {output_dir}")
    print(f" Checkpoints saved every {ft_config.save_steps} steps")
    print(f" Max checkpoints kept: {ft_config.save_total_limit}")
    print(" Final checkpoint will be at:")
    print(f" {output_dir}/{_CHECKPOINT_PREFIX}<step>/")
    print(_RULE)


def _print_end_banner(output_dir: Path) -> None:
    """Print the output directory and every checkpoint found in it, by step."""
    print(_RULE)
    print(" Training complete.")
    print(f" Checkpoint saved at: {output_dir}")
    checkpoints = list_checkpoints(output_dir)
    if not checkpoints:
        print(f" (no {_CHECKPOINT_PREFIX}* directories found)")
    for ckpt in checkpoints:
        print(f" {ckpt}")
    print(_RULE)


def main() -> None:
    """Parse the CLI, build the experiment config, train, and report checkpoints."""
    # NOTE(review): loguru reads LOGURU_LEVEL when loguru itself is first
    # imported; if the gr00t imports above already pulled in loguru, this
    # default may not take effect — confirm.
    os.environ.setdefault("LOGURU_LEVEL", "INFO")

    ft_config = tyro.cli(FinetuneConfig)
    config = _build_config(ft_config)

    output_dir = Path(ft_config.output_dir).resolve()
    _print_start_banner(ft_config, output_dir)
    run(config)
    _print_end_banner(output_dir)


if __name__ == "__main__":
    main()