def auto_get_module_keys(module, max_depth=0, prefix_list=None, current_depth=0, current_prefix=""):
"""
get all submodule keys of a module, support setting recursion depth and prefix list.
:param module: the module to traverse.
:param max_depth: the maximum recursion depth, default is 1.
:param prefix_list: only include modules with specified prefix, default is None means no restriction.
:param current_depth: the current recursion depth, internal use.
:param current_prefix: the current prefix, internal use.
:return: the list of module keys.
"""
if current_depth > max_depth:
return []
module_keys = []
for name, sub_module in module.named_children():
full_name = f"{current_prefix}.{name}" if current_prefix else name
if prefix_list is None or any(full_name.startswith(prefix) for prefix in prefix_list):
module_keys.append(full_name)
module_keys.extend(auto_get_module_keys(sub_module, max_depth, prefix_list, current_depth + 1, full_name))
return module_keys
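# Usage sketch (illustrative only; torchvision's resnet18 is just a stand-in for any torch.nn.Module):
#   import torchvision
#   model = torchvision.models.resnet18()
#   auto_get_module_keys(model, max_depth=0)
#   # -> ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', 'fc']
#   auto_get_module_keys(model, max_depth=1, prefix_list=["layer1"])
#   # -> ['layer1', 'layer1.0', 'layer1.1']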
def is_module_trainable(module):
"""
check if a module is trainable: if the module itself has parameters, then all its parameters require_grad must be True;
if the module itself has no parameters, then its trainability depends on its submodules.
"""
params = list(module.parameters(recurse=False))
if params:
return all(p.requires_grad for p in params)
else:
# for container modules with no direct parameters, consider them trainable (the final result depends on their submodules)
return True
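# Usage sketch (illustrative only):
#   import torch.nn as nn
#   layer = nn.Linear(4, 4)
#   layer.weight.requires_grad_(False); layer.bias.requires_grad_(False)
#   is_module_trainable(layer)            # -> False (all direct parameters frozen)
#   is_module_trainable(nn.Sequential())  # -> True (no direct parameters, treated as trainable)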
def auto_get_trainable_modules(module, prefix="", max_depth=None):
"""
recursively traverse the module, return the list of all trainable module names.
if all submodules of a module are trainable, then only return the name of the parent module, no longer recursively output the names of its submodules.
parameters:
- module: the module to traverse.
- prefix: the name prefix of the current module (internal use).
- max_depth: the maximum recursion depth, None means infinite recursion.
return:
a list of module names.
"""
# get all direct submodules of the current module
children = list(module.named_children())
# if the maximum depth is reached or there are no submodules, return the current module (if trainable and prefix is not empty)
if (max_depth is not None and max_depth <= 0) or not children:
return [prefix] if prefix and is_module_trainable(module) else []
child_keys = []
all_children_trainable = True
for name, child in children:
full_name = f"{prefix}.{name}" if prefix else name
# recursively get the trainable keys of the submodules
keys = auto_get_trainable_modules(child, full_name, None if max_depth is None else max_depth - 1)
        if not keys:
            # the recursion found nothing trainable below this child; only count the child itself
            # if it owns trainable parameters directly (a parameter-less container whose children
            # are all frozen must not be reported as trainable here)
            if list(child.parameters(recurse=False)) and is_module_trainable(child):
                keys = [full_name]
            else:
                all_children_trainable = False
else:
# if the submodule returns multiple names, it means that it cannot be merged
if len(keys) > 1:
all_children_trainable = False
child_keys.extend(keys)
# if the current module is trainable and all submodules are trainable, return the name of the current module
if is_module_trainable(module) and all_children_trainable and child_keys:
return [prefix] if prefix else child_keys
else:
return child_keys
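# Usage sketch (illustrative only): a frozen branch is skipped while the trainable one is reported.
#   import torch.nn as nn
#   model = nn.ModuleDict({"backbone": nn.Linear(4, 4), "head": nn.Linear(4, 2)})
#   model["backbone"].weight.requires_grad_(False); model["backbone"].bias.requires_grad_(False)
#   auto_get_trainable_modules(model)  # -> ['head']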
def print_freeze_status(self):
"""
for each top-level submodule, if all its parameters are in the same state (all frozen or all trainable), only print the top-level module.
if some top-level submodule has mixed parameter states (some frozen, some trainable), list the state of each parameter under the submodule.
"""
from collections import defaultdict
# collect the state of parameters under each top-level module
status_dict = defaultdict(lambda: {"Frozen": 0, "Trainable": 0, "params": []})
for full_name, param in self.named_parameters():
# full_name is like "qwen_vl_interface.model.layer.weight"
top_module = full_name.split(".", 1)[0] # get the top-level module name
state = "Frozen" if not param.requires_grad else "Trainable"
status_dict[top_module]["params"].append((full_name, state))
status_dict[top_module][state] += 1
print("=== module parameter freezing status ===")
for top_module, info in status_dict.items():
frozen_count = info["Frozen"]
trainable_count = info["Trainable"]
if frozen_count > 0 and trainable_count == 0:
# all frozen
print(f"{top_module:40s} | all Frozen ({frozen_count} parameters)")
elif trainable_count > 0 and frozen_count == 0:
# all trainable
print(f"{top_module:40s} | all Trainable ({trainable_count} parameters)")
else:
# mixed state, first print the module name summary, then list the state of each parameter
print(f"{top_module:40s} | mixed state → Frozen: {frozen_count}, Trainable: {trainable_count}")
for pname, pstate in info["params"]:
print(f" {pname:60s} | {pstate}")
print("=========================\n")
class Registry:
def __init__(self, name: str):
self.name = name
self._registry = {}
    def register(self, key: str):
        """Decorator: register a builder function or class under `key`."""
        def decorator(framework_class):
            if key in self._registry:
                # re-registration is allowed silently; re-enable the warning below to flag duplicates
                # print(ImportWarning(f"{key} already registered to {self.name}"))
                pass
            self._registry[key] = framework_class
            return framework_class
        return decorator
def __getitem__(self, key):
return self._registry[key]
def list(self):
"""
List currently registered keys; if with_values=True (not used here) return mapping {key: value_obj}.
Using class name as value is also intuitive, e.g., framework.__name__.
"""
return {k: v for k, v in self._registry.items()}
FRAMEWORK_REGISTRY = Registry("frameworks")
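# Usage sketch (illustrative only; `DummyFramework` and the key "dummy" are made-up placeholders):
#   @FRAMEWORK_REGISTRY.register("dummy")
#   class DummyFramework:
#       ...
#   FRAMEWORK_REGISTRY["dummy"]  # -> <class 'DummyFramework'>
#   FRAMEWORK_REGISTRY.list()    # -> {'dummy': <class 'DummyFramework'>}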
from starVLA.training.trainer_utils import initialize_overwatch
import os
import json
from pathlib import Path
from omegaconf import OmegaConf
# Initialize Overwatch =>> Wraps `logging.Logger`
overwatch = initialize_overwatch(__name__)
def read_mode_config(pretrained_checkpoint):
"""
Same as read_model_config (legacy duplicate kept for backward compatibility).
Args:
pretrained_checkpoint: Path to a .pt checkpoint file.
Returns:
tuple:
vla_cfg (dict)
norm_stats (dict)
"""
if os.path.isfile(pretrained_checkpoint):
overwatch.info(f"Loading from local checkpoint path `{(checkpoint_pt := Path(pretrained_checkpoint))}`")
# [Validate] Checkpoint Path should look like `.../<RUN_ID>/checkpoints/<CHECKPOINT_PATH>.pt|.safetensors`
assert checkpoint_pt.suffix in (".pt", ".safetensors"), \
f"Unsupported checkpoint suffix `{checkpoint_pt.suffix}`, expected `.pt` or `.safetensors`"
run_dir = checkpoint_pt.parents[1]
        # Get paths for `config.yaml` and `dataset_statistics.json` in the checkpoint's run directory
        config_yaml, dataset_statistics_json = run_dir / "config.yaml", run_dir / "dataset_statistics.json"
        assert config_yaml.exists(), f"Missing `config.yaml` for `{run_dir}`"
        assert dataset_statistics_json.exists(), f"Missing `dataset_statistics.json` for `{run_dir}`"
        # Load the VLA config (and corresponding base VLM `ModelConfig`) from `config.yaml`
try:
ocfg = OmegaConf.load(str(config_yaml))
global_cfg = OmegaConf.to_container(ocfg, resolve=True)
except Exception as e:
overwatch.error(f"❌ Failed to load YAML config `{config_yaml}`: {e}")
raise
# Load Dataset Statistics for Action Denormalization
with open(dataset_statistics_json, "r") as f:
norm_stats = json.load(f)
else:
overwatch.error(f"❌ Pretrained checkpoint `{pretrained_checkpoint}` does not exist.")
raise FileNotFoundError(f"Pretrained checkpoint `{pretrained_checkpoint}` does not exist.")
return global_cfg, norm_stats
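# Usage sketch (illustrative only; the path is a made-up example). The checkpoint is expected to live under
# `<RUN_DIR>/checkpoints/`, so that `config.yaml` and `dataset_statistics.json` are found in `<RUN_DIR>`:
#   global_cfg, norm_stats = read_mode_config("runs/example_run/checkpoints/step_1000.pt")
#   print(global_cfg.keys(), norm_stats.keys())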