""" metrics.py Utility classes defining a Metrics container and multiple Trackers to enable model/stage-specific logging to various endpoints (e.g., JSONL local logs, Weights & Biases). """ from typing import Tuple import re import json import numpy as np import torch from accelerate.logging import get_logger logger = get_logger(__name__) # === Define Tracker Interface === # # utils/cli_parser.py def normalize_dotlist_args(args): """ Convert ['--x.y', 'val'] and ['--flag'] → ['x.y=val', 'flag=true'] """ normalized = [] skip = False for i in range(len(args)): if skip: skip = False continue arg = args[i] if arg.startswith("--"): key = arg.lstrip("-") if "=" in key: normalized.append(key) elif i + 1 < len(args) and not args[i + 1].startswith("--"): normalized.append(f"{key}={args[i + 1]}") skip = True else: normalized.append(f"{key}=true") else: pass # skip orphaned values return normalized def build_param_lr_groups(model, cfg): """ build multiple param groups based on cfg.trainer.learning_rate. support specifying different learning rates for different modules, the rest use base. Args: vla: nn.Module model object cfg: config object, requires cfg.trainer.learning_rate dictionary Returns: List[Dict]: param_groups that can be used to build optimizer with torch.optim """ lr_cfg = cfg.trainer.learning_rate base_lr = lr_cfg.get("base", 1e-4) # default base learning rate freeze_modules = cfg.trainer.get("freeze_modules", "") if not isinstance(freeze_modules, str): freeze_modules = "" freeze_patterns = [p.strip() for p in freeze_modules.split(",") if p.strip()] used_params = set() frozen_params = set() param_groups = [] for freeze_path in freeze_patterns: module = model try: for attr in freeze_path.split("."): module = getattr(module, attr) frozen_params.update(id(p) for p in module.parameters()) except AttributeError: print(f"⚠️ freeze module path does not exist: {freeze_path}") continue for module_name, lr in lr_cfg.items(): if module_name == "base": continue # try to find the module under vla by module_name (support nested paths) module = model try: for attr in module_name.split("."): module = getattr(module, attr) # filter out frozen parameters params = [p for p in module.parameters() if id(p) not in frozen_params] if params: # only add param group if there are trainable parameters param_groups.append({"params": params, "lr": lr, "name": module_name}) used_params.update(id(p) for p in params) except AttributeError: ReferenceError(f"⚠️ module path `{module_name}` not found in vla") # assign base learning rate to the remaining unused parameters (exclude frozen ones) other_params = [p for p in model.parameters() if id(p) not in used_params and id(p) not in frozen_params] if other_params: param_groups.append({"params": other_params, "lr": base_lr, "name": "base"}) return param_groups import torch.distributed as dist def _is_main_process_dist() -> bool: return (not dist.is_initialized()) or dist.get_rank() == 0 def only_main_process(func): """ decorator: only run in main process (rank=0) """ def wrapper(*args, **kwargs): if dist.is_initialized() and dist.get_rank() != 0: return None # non-main process does not execute return func(*args, **kwargs) return wrapper from torchvision.ops import box_iou from PIL import Image def resize_images(images, target_size=(224, 224)): """ recursively resize all images in the nested list. :param images: nested list of images or single image. :param target_size: target size (width, height) after resizing. 

def resize_images(images, target_size=(224, 224)):
    """
    Recursively resize all images in a nested list.

    :param images: nested list of images, or a single image.
    :param target_size: target size (width, height) after resizing.
    :return: resized images, keeping the original nested structure.
    """
    if isinstance(images, Image.Image):  # a single PIL image
        return images.resize(target_size)
    elif isinstance(images, list):  # a list: recursively process each element
        return [resize_images(img, target_size) for img in images]
    else:
        raise ValueError("Unsupported image type or structure.")


class TrainerUtils:
    @staticmethod
    def freeze_backbones(model, freeze_modules=""):
        """
        Freeze the submodules given by a comma-separated list of relative module paths
        (no recursive submodule-name matching). The paths are read from
        config.trainer.freeze_modules; for example "qwen_vl_interface, action_model.net"
        freezes model.qwen_vl_interface and model.action_model.net.

        Args:
            model: nn.Module model object
            freeze_modules: comma-separated relative module paths (patterns)
        Returns:
            model: the same nn.Module with the selected parameters frozen
        """
        frozen = []
        print(f"freeze_modules: {freeze_modules}")
        if freeze_modules and isinstance(freeze_modules, str):
            # split and remove whitespace
            patterns = [p.strip() for p in freeze_modules.split(",") if p.strip()]
            for path in patterns:
                # split the "relative path" by dots, e.g. "action_model.net" → ["action_model", "net"]
                attrs = path.split(".")
                module = model
                try:
                    for attr in attrs:
                        module = getattr(module, attr)
                    # if the module is found, freeze it and all of its submodules' parameters
                    for param in module.parameters():
                        param.requires_grad = False
                    frozen.append(path)
                except AttributeError:
                    # if the attribute does not exist, skip and print a warning
                    print(f"⚠️ module path does not exist, cannot freeze: {path}")
                    continue
        # accelerator.wait_for_everyone()  # synchronize when distributed training
        if _is_main_process_dist():
            print(f"🔒 Frozen modules: {frozen}")
        return model

    @staticmethod
    def print_trainable_parameters(model):
        """
        Print the total number of parameters and trainable parameters of the model.

        :param model: PyTorch model instance
        :return: (num_params, num_trainable_params)
        """
        num_params = sum(p.numel() for p in model.parameters())
        num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        if _is_main_process_dist():
            print("📊 model parameter statistics:")
            print(
                f"# Parameters (in millions): {num_params / 10**6:.3f} Total, "
                f"{num_trainable_params / 10**6:.3f} Trainable"
            )
        return num_params, num_trainable_params
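
    # Usage sketch (module paths taken from the freeze_backbones docstring; whether they
    # exist depends on the actual model):
    #
    #   model = TrainerUtils.freeze_backbones(model, "qwen_vl_interface, action_model.net")
    #   TrainerUtils.print_trainable_parameters(model)
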

    @staticmethod
    def load_pretrained_backbones(model, checkpoint_path=None, reload_modules=None):
        """
        Load a checkpoint:
        - if reload_modules is set, load only the listed module paths
        - otherwise, load the entire state dict (overwriting the model)

        Returns:
            model: the model with the loaded parameters
        """
        if not checkpoint_path:
            return model
        if _is_main_process_dist():
            print(f"📦 loading checkpoint: {checkpoint_path}")
        try:
            if _is_safetensors_path(checkpoint_path):
                from safetensors.torch import load_file

                checkpoint = load_file(checkpoint_path)
            else:
                checkpoint = torch.load(checkpoint_path, map_location="cpu")
        except Exception as e:
            raise RuntimeError(f"❌ loading checkpoint failed: {e}")

        loaded_modules = []
        if reload_modules:
            # partial load
            module_paths = [p.strip() for p in reload_modules.split(",") if p.strip()]
            for path in module_paths:
                attrs = path.split(".")
                module = model
                try:
                    for module_name in attrs:
                        # walk down to the module to modify, level by level
                        module = getattr(module, module_name)
                    prefix = path + "."
                    sub_state_dict = {k[len(prefix):]: v for k, v in checkpoint.items() if k.startswith(prefix)}
                    if sub_state_dict:
                        module.load_state_dict(sub_state_dict, strict=True)
                        if _is_main_process_dist():
                            print(f"✅ parameters loaded to module '{path}'")
                        loaded_modules.append(path)
                    else:
                        print(f"⚠️ parameters not found in checkpoint '{path}'")
                except AttributeError:
                    print(f"❌ cannot find module path: {path}")
        else:
            # full load
            try:
                model.load_state_dict(checkpoint, strict=False)
                if _is_main_process_dist():
                    print("✅ loaded model parameters")
                loaded_modules = [""]
            except Exception as e:
                raise RuntimeError(f"❌ loading full model failed: {e}")
        return model

    @staticmethod
    def print_freeze_status(model):
        """
        Print the freeze status of each parameter in the model.

        :param model: PyTorch model instance
        """
        for name, param in model.named_parameters():
            status = "Frozen" if not param.requires_grad else "Trainable"
            print(f"{name:60s} | {status}")

    @staticmethod
    def setup_distributed_training(accelerator, *components):
        """
        Use the Accelerator to prepare components for distributed training.

        :param accelerator: Accelerate instance
        :param components: any number of components (model, optimizer, dataloader, ...)
        :return: prepared distributed components (in the same order as the input)
        """
        prepared_components = accelerator.prepare(*components)
        return prepared_components

    def save_full_checkpoint(self, completed_steps, checkpoint_dir, output_dir):
        """Save full training state (prepared components + RNG) for resume, plus a standalone
        model weights file for deployment.

        The standalone file format is controlled by ``self.config.trainer.save_format``
        (``"pt"`` or ``"safetensors"``). Defaults to ``"pt"`` when unset.
        Must be called after accelerator.prepare().

        Args:
            completed_steps: Current training step count.
            checkpoint_dir: Directory to save checkpoints (e.g. results//checkpoints).
            output_dir: Top-level run directory for summary.jsonl and config.
        """
        from pathlib import Path

        save_format = getattr(self.config.trainer, "save_format", "pt")

        # Save full accelerator state for all prepared components.
        state_dir = os.path.join(checkpoint_dir, f"steps_{completed_steps}")
        use_safe = save_format == "safetensors"
        self.accelerator.save_state(state_dir, safe_serialization=use_safe)

        # Save standalone weights & metadata (main process only)
        if self.accelerator.is_main_process:
            import json as _json

            # Save standalone model weights for deployment
            state_dict = self.accelerator.get_state_dict(self.model)
            if state_dict is not None:
                if save_format == "safetensors":
                    from safetensors.torch import save_file

                    weights_path = os.path.join(
                        checkpoint_dir, f"steps_{completed_steps}_model.safetensors"
                    )
                    save_file(state_dict, weights_path)
                else:
                    weights_path = os.path.join(
                        checkpoint_dir, f"steps_{completed_steps}_pytorch_model.pt"
                    )
                    torch.save(state_dict, weights_path)

            # Append to summary log
            summary_data = {"steps": completed_steps}
            with open(os.path.join(output_dir, "summary.jsonl"), "a") as f:
                f.write(_json.dumps(summary_data) + "\n")
            self.accelerator.print(f"✅ Checkpoint saved at {state_dir}")

            # Save accessed config if available
            from starVLA.training.trainer_utils.config_tracker import AccessTrackedConfig

            if isinstance(self.config, AccessTrackedConfig):
                self.config.save_accessed_config(
                    Path(output_dir) / "config.yaml",
                    use_original_values=False,
                )
        self.accelerator.wait_for_everyone()
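
    # Illustrative layout written by save_full_checkpoint (the file names inside the
    # accelerator state directory depend on the Accelerate/DeepSpeed configuration):
    #
    #   <checkpoint_dir>/steps_5000/                    full accelerator state for resume
    #   <checkpoint_dir>/steps_5000_pytorch_model.pt    standalone weights (save_format="pt")
    #   <checkpoint_dir>/steps_5000_model.safetensors   standalone weights (save_format="safetensors")
    #   <output_dir>/summary.jsonl                      one {"steps": N} line appended per save
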

    def resume_from_full_checkpoint(self, checkpoint_dir):
        """Load full training state from an accelerator state directory.

        Must be called **after** accelerator.prepare() (DeepSpeed requirement).

        Args:
            checkpoint_dir: Path to a steps_N/ directory containing full state.

        Returns:
            int: The completed_steps parsed from the directory name (steps_N), or 0.
        """
        self.accelerator.load_state(checkpoint_dir)
        self.accelerator.print(f"Resumed full training state from: {checkpoint_dir}")
        # Parse completed_steps from the directory name (e.g. "steps_5000")
        dir_name = os.path.basename(checkpoint_dir)
        match = re.match(r"^steps_(\d+)$", dir_name)
        return int(match.group(1)) if match else 0

    @staticmethod
    def euclidean_distance(predicted: np.ndarray, ground_truth: np.ndarray) -> float:
        return np.linalg.norm(predicted - ground_truth)

    @staticmethod
    def _reset_dataloader(dataloader, epoch_counter):
        """Safely reset a dataloader iterator."""
        # 1. update the epoch counter
        epoch_counter += 1
        # 2. set the new epoch (distributed core)
        if hasattr(dataloader, "sampler") and callable(getattr(dataloader.sampler, "set_epoch", None)):
            dataloader.sampler.set_epoch(epoch_counter)
        # 3. create a new iterator
        return iter(dataloader), epoch_counter

    @staticmethod
    def compute_grad_angle_with_stats(grads_a: list[torch.Tensor], grads_v: list[torch.Tensor]) -> Tuple[float, float]:
        """
        Compute the angle (in degrees) between two groups of gradient vectors via cosine
        similarity, and return the mean angle and its standard deviation.

        grads_a, grads_v: gradient Tensor lists corresponding to the same parameter list (interface_params)
        return:
            mean_angle_deg: mean angle (degrees)
            angle_std: standard deviation of the angles (degrees)
        """
        angle_degs = []
        # compute the cosine angle between corresponding gradient blocks
        # grads_a[0].shape == [1280, 3, 14, 14]
        # grads_1 = grads_a[0][0]  # [3, 14, 14]
        # grads_2 = grads_v[0][0]
        # grads_a = grads_1.view(-1, 3)  # reshape to [196, 3]
        # grads_v = grads_2.view(-1, 3)  # lang linear  # reshape to 14*14, 3
        # layer
        grads_action = grads_a[0]  # [2048, 11008]
        # only take the first 32 rows and 7 columns to avoid cosine-similarity degeneration
        # in high-dimensional space
        grads_action = grads_action[:32, :7]
        grads_vl = grads_v[0]  # [2048, 11008]
        grads_vl = grads_vl[:32, :7]
        for g_a, g_v in zip(grads_action, grads_vl):
            dot = torch.sum(g_a * g_v)
            norm_a_sq = torch.sum(g_a * g_a)
            norm_v_sq = torch.sum(g_v * g_v)
            # avoid division by zero
            norm_a = torch.sqrt(norm_a_sq + 1e-16)
            norm_v = torch.sqrt(norm_v_sq + 1e-16)
            cos_sim = (dot / (norm_a * norm_v)).clamp(-1.0, 1.0)
            angle_rad = torch.acos(cos_sim)
            angle_deg = angle_rad * (180.0 / torch.pi)
            angle_degs.append(angle_deg.item())
        # compute the mean angle and its standard deviation
        angle_degs_tensor = torch.tensor(angle_degs)
        mean_angle_deg = torch.mean(angle_degs_tensor).item()
        angle_std = torch.std(angle_degs_tensor).item()
        # accelerator.wait_for_everyone()
        return mean_angle_deg, angle_std

    @staticmethod
    def pcgrad_project(grads_a: list[torch.Tensor], grads_v: list[torch.Tensor]) -> list[torch.Tensor]:
        """
        Apply PCGrad projection to the second group of gradients grads_v to suppress
        negative transfer between grads_a and grads_v. If the dot product of the two
        gradient groups is < 0, then:
            grads_v <- grads_v - (dot / ||grads_a||^2) * grads_a
        Returns the new grads_v list.
        """
        # first compute dot and ||grads_a||^2
        dot, norm_a_sq = 0.0, 0.0
        for g_a, g_v in zip(grads_a, grads_v):
            dot += torch.sum(g_a * g_v)
            norm_a_sq += torch.sum(g_a * g_a)
        if dot < 0:
            coeff = dot / (norm_a_sq + 1e-6)
            # projection
            grads_v = [g_v - coeff * g_a for g_a, g_v in zip(grads_a, grads_v)]
        return grads_v
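
    # Worked example of the PCGrad projection above (tiny hypothetical 2-D gradients):
    #   g_a = [1, 0], g_v = [-1, 1]  ->  dot = -1 < 0, ||g_a||^2 = 1
    #   g_v' = g_v - (dot / ||g_a||^2) * g_a = [-1, 1] - (-1) * [1, 0] = [0, 1]
    # The component of g_v that conflicts with g_a is removed, so g_v' · g_a = 0.
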

    @staticmethod
    def l1_distance(predicted: np.ndarray, ground_truth: np.ndarray) -> float:
        """Sum of absolute errors (L1 distance) - a more intuitive error measure."""
        return np.sum(np.abs(predicted - ground_truth))

    @staticmethod
    def eval_qwenpi(qwenpi, dataloader, num_batches=20):
        """
        Evaluate a QwenQFormerDiT model, computing bounding-box IoU and action distance.

        Args:
            qwenpi: QwenQFormerDiT model instance.
            dataloader: data loader.
            num_batches: number of batches to evaluate.
        Returns:
            dict: IoU and action distance evaluation results.
        """
        iou_scores = []
        action_distances = []
        count = 0
        dataset_iter = iter(dataloader)
        while count < num_batches:
            try:
                batch_samples = next(dataset_iter)
                count += 1
            except StopIteration:
                break
            # extract data
            images = [example["image"] for example in batch_samples]
            instructions = [example["lang"] for example in batch_samples]
            actions = [example["action"] for example in batch_samples]
            solutions = [example["solution"] for example in batch_samples]
            # model prediction
            predicted_solutions, normalized_actions = qwenpi.predict_action_withCoT(
                images=images, instructions=instructions, use_ddim=False, num_ddim_steps=20
            )
            # parse the predicted CoT strings into dictionaries
            parsed_solutions = []
            for solution in predicted_solutions:
                parsed_solution = TrainerUtils.extract_json_from_string(solution)
                parsed_solutions.append(parsed_solution)
            # compute IoU
            for pred_dict, gt_dict in zip(parsed_solutions, solutions):
                if pred_dict is None:  # skip samples whose output could not be parsed as JSON
                    continue
                pred_pick_bbox = torch.tensor(pred_dict["pick"]["bbox_2d"], dtype=torch.float32).unsqueeze(0)
                gt_pick_bbox = torch.tensor(gt_dict["pick"]["bbox_2d"], dtype=torch.float32).unsqueeze(0)
                pred_place_bbox = torch.tensor(pred_dict["place"]["bbox_2d"], dtype=torch.float32).unsqueeze(0)
                gt_place_bbox = torch.tensor(gt_dict["place"]["bbox_2d"], dtype=torch.float32).unsqueeze(0)
                pick_iou = box_iou(pred_pick_bbox, gt_pick_bbox).item()
                place_iou = box_iou(pred_place_bbox, gt_place_bbox).item()
                iou_scores.append({"pick_iou": pick_iou, "place_iou": place_iou})
            # compute action distance
            actions = np.array(actions)  # convert to a numpy array
            num_points = np.prod(actions.shape)  # B * len * dim
            action_distance = TrainerUtils.euclidean_distance(normalized_actions, actions)
            average_action_distance = action_distance / num_points
            action_distances.append(average_action_distance)
        # summarize results
        avg_action_distance = np.mean(action_distances)
        return {"iou_scores": iou_scores, "average_action_distance": avg_action_distance}

    @staticmethod
    def extract_json_from_string(input_string):
        """
        Extract the valid JSON part of a string and convert it to a dictionary.

        Args:
            input_string (str): string containing extra characters around the JSON.
        Returns:
            dict: the extracted and parsed dictionary, or None on failure.
        """
        json_match = re.search(r"{.*}", input_string, re.DOTALL)
        if json_match:
            json_str = json_match.group(0)
            try:
                return json.loads(json_str)
            except json.JSONDecodeError as e:
                print(f"JSON decode failed: {e}")
                return None
        else:
            print("No valid JSON part found")
            return None
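
    # Example for extract_json_from_string (hypothetical model output; the pick/place schema
    # matches what eval_qwenpi expects):
    #
    #   s = 'Plan: {"pick": {"bbox_2d": [10, 20, 30, 40]}, "place": {"bbox_2d": [50, 60, 70, 80]}} done'
    #   TrainerUtils.extract_json_from_string(s)
    #   # -> {'pick': {'bbox_2d': [10, 20, 30, 40]}, 'place': {'bbox_2d': [50, 60, 70, 80]}}
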
""" if not os.path.exists(checkpoint_dir): self.accelerator.print(f"No checkpoint directory found at {checkpoint_dir}") return None, 0 checkpoints_with_steps = [] for entry in os.listdir(checkpoint_dir): full_path = os.path.join(checkpoint_dir, entry) # New format: steps_N/ directories (with training_state.json inside) dir_match = re.match(r"^steps_(\d+)$", entry) if dir_match and os.path.isdir(full_path): step = int(dir_match.group(1)) # Directory checkpoints contain full accelerator state for resume. checkpoints_with_steps.append((full_path, step, "dir")) continue # Weight-only files: steps_N_pytorch_model.pt or steps_N_model.safetensors file_match = re.match(r"^steps_(\d+)_(?:pytorch_model\.pt|model\.safetensors)$", entry) if file_match and os.path.isfile(full_path): step = int(file_match.group(1)) checkpoints_with_steps.append((full_path, step, "file")) if not checkpoints_with_steps: self.accelerator.print(f"No checkpoints found in {checkpoint_dir}") return None, 0 # Sort by step number, then by type priority (dir > file) so directory wins ties. type_priority = {"file": 0, "dir": 1} checkpoints_with_steps.sort(key=lambda x: (x[1], type_priority[x[2]])) latest_path, completed_steps, fmt = checkpoints_with_steps[-1] self.accelerator.print(f"Latest checkpoint found: {latest_path} (format={fmt})") return latest_path, completed_steps import os def is_main_process(): rank = int(os.environ.get("RANK", 0)) # if RANK is not set, default to 0 return rank == 0 def _is_safetensors_path(path): """Check if a path refers to a safetensors file.""" return str(path).endswith(".safetensors")