| """ |
| metrics.py |
| |
| Utility classes defining a Metrics container and multiple Trackers to enable model/stage-specific logging to various |
| endpoints (e.g., JSONL local logs, Weights & Biases). |
| """ |
|
|
| from typing import Tuple |
| import re |
| import json |
| import numpy as np |
| import torch |
|
|
| from accelerate.logging import get_logger |
|
|
| logger = get_logger(__name__) |
|
|
|
|
| |
| |
|
|
| |
|
|
|
|
def normalize_dotlist_args(args):
    """
    Normalize CLI-style tokens into dotlist entries.

    Converts ['--x.y', 'val'] -> ['x.y=val'] and bare flags ['--flag'] ->
    ['flag=true']. Tokens already containing '=' (e.g. '--x.y=val') pass
    through with the leading dashes stripped. Tokens that neither start
    with '--' nor follow a flag are intentionally dropped.

    Args:
        args: list of raw command-line tokens.

    Returns:
        List[str]: entries of the form 'key=value'.
    """
    normalized = []
    i = 0
    while i < len(args):
        arg = args[i]
        if arg.startswith("--"):
            key = arg.lstrip("-")
            if "=" in key:
                normalized.append(key)
            elif i + 1 < len(args) and not args[i + 1].startswith("--"):
                normalized.append(f"{key}={args[i + 1]}")
                i += 1  # consume the value token
            else:
                normalized.append(f"{key}=true")
        # non-flag tokens that were not consumed as values are ignored
        i += 1
    return normalized
|
|
|
|
def build_param_lr_groups(model, cfg):
    """
    Build optimizer parameter groups based on cfg.trainer.learning_rate.

    Supports a different learning rate per submodule (keyed by dotted
    attribute path); all remaining trainable parameters fall into a "base"
    group. Parameters of modules listed in cfg.trainer.freeze_modules are
    excluded from every group.

    Args:
        model: nn.Module model object.
        cfg: config object; requires the cfg.trainer.learning_rate dict
            (key "base" is the default lr, other keys are module paths).

    Returns:
        List[Dict]: param_groups usable with torch.optim optimizers.
    """
    lr_cfg = cfg.trainer.learning_rate
    base_lr = lr_cfg.get("base", 1e-4)

    freeze_modules = cfg.trainer.get("freeze_modules", "")
    if not isinstance(freeze_modules, str):
        freeze_modules = ""
    freeze_patterns = [p.strip() for p in freeze_modules.split(",") if p.strip()]

    used_params = set()
    frozen_params = set()
    param_groups = []

    # Collect ids of every parameter under a frozen module path.
    for freeze_path in freeze_patterns:
        module = model
        try:
            for attr in freeze_path.split("."):
                module = getattr(module, attr)
            frozen_params.update(id(p) for p in module.parameters())
        except AttributeError:
            print(f"⚠️ freeze module path does not exist: {freeze_path}")
            continue

    # One parameter group per module path with an explicit learning rate.
    for module_name, lr in lr_cfg.items():
        if module_name == "base":
            continue

        module = model
        try:
            for attr in module_name.split("."):
                module = getattr(module, attr)

            params = [p for p in module.parameters() if id(p) not in frozen_params]
            if params:
                param_groups.append({"params": params, "lr": lr, "name": module_name})
                used_params.update(id(p) for p in params)
        except AttributeError:
            # BUG FIX: the original built `ReferenceError(...)` without raising
            # it — a no-op that silently ignored bad paths. Warn instead,
            # matching the freeze-path handling above.
            print(f"⚠️ module path `{module_name}` not found in vla")

    # Everything not grouped or frozen goes into the base group.
    other_params = [p for p in model.parameters() if id(p) not in used_params and id(p) not in frozen_params]
    if other_params:
        param_groups.append({"params": other_params, "lr": base_lr, "name": "base"})

    return param_groups
|
|
|
|
| import torch.distributed as dist |
|
|
|
|
| def _is_main_process_dist() -> bool: |
| return (not dist.is_initialized()) or dist.get_rank() == 0 |
|
|
|
|
def only_main_process(func):
    """
    Decorator: run the wrapped function only in the main process (rank 0).

    On non-main ranks the call is a no-op and returns None.
    """
    import functools

    # BUG FIX: use functools.wraps so the wrapper keeps the wrapped
    # function's __name__/__doc__ (the original clobbered them).
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if dist.is_initialized() and dist.get_rank() != 0:
            return None
        return func(*args, **kwargs)

    return wrapper
|
|
|
|
| from torchvision.ops import box_iou |
| from PIL import Image |
|
|
|
|
def resize_images(images, target_size=(224, 224)):
    """
    Recursively resize every image inside a (possibly nested) list.

    :param images: a single PIL image, or an arbitrarily nested list of them.
    :param target_size: target (width, height) after resizing.
    :return: resized image(s), preserving the original nesting.
    :raises ValueError: for anything that is neither an image nor a list.
    """
    if isinstance(images, Image.Image):
        resized = images.resize(target_size)
        return resized
    if isinstance(images, list):
        return [resize_images(entry, target_size) for entry in images]
    raise ValueError("Unsupported image type or structure.")
|
|
|
|
class TrainerUtils:
    @staticmethod
    def freeze_backbones(model, freeze_modules=""):
        """
        Freeze the submodules named by a comma-separated list of relative
        module paths (no recursive submodule-name matching is performed).

        For example freeze_modules="qwen_vl_interface, action_model.net"
        freezes model.qwen_vl_interface and model.action_model.net.

        Args:
            model: nn.Module model object.
            freeze_modules: comma-separated relative module paths.

        Returns:
            model: the same nn.Module, with the matched submodules frozen.
        """
        frozen = []
        if freeze_modules and isinstance(freeze_modules, str):
            patterns = [p.strip() for p in freeze_modules.split(",") if p.strip()]

            for path in patterns:
                attrs = path.split(".")
                module = model
                try:
                    for attr in attrs:
                        module = getattr(module, attr)

                    for param in module.parameters():
                        param.requires_grad = False
                    frozen.append(path)
                except AttributeError:
                    # Tolerate bad paths: warn and keep processing the rest.
                    print(f"⚠️ module path does not exist, cannot freeze: {path}")
                    continue

        if _is_main_process_dist():
            # FIX: old message claimed "re pattern" matching, but the paths
            # are plain attribute lookups, not regular expressions.
            print(f"🔒 Frozen modules: {frozen}")
        return model

    @staticmethod
    def print_trainable_parameters(model):
        """
        Print the total and trainable parameter counts of the model.

        Only the main process prints; non-main processes return None.

        :param model: PyTorch model instance.
        :return: (num_params, num_trainable_params) on the main process,
            otherwise None.
        """
        if not _is_main_process_dist():
            return
        print("📊 model parameter statistics:")
        num_params = sum(p.numel() for p in model.parameters())
        num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(
            f"# Parameters (in millions): {num_params / 10**6:.3f} Total, {num_trainable_params / 10**6:.3f} Trainable"
        )
        return num_params, num_trainable_params

    @staticmethod
    def load_pretrained_backbones(model, checkpoint_path=None, reload_modules=None):
        """
        Load a checkpoint into the model.

        - If reload_modules is set (comma-separated dotted module paths),
          only the matching sub-state-dicts are loaded, strictly, per module.
        - Otherwise the whole checkpoint is loaded with strict=False.

        Supports both .safetensors files and torch.load-able files.

        Args:
            model: nn.Module to load into.
            checkpoint_path: path to the checkpoint file; no-op when falsy.
            reload_modules: optional comma-separated dotted module paths.

        Returns:
            model: the (possibly updated) nn.Module.

        Raises:
            RuntimeError: when reading the checkpoint or the full-model
                load fails.
        """
        if not checkpoint_path:
            # BUG FIX: this used to `return []`, clobbering the model for
            # callers using `model = TrainerUtils.load_pretrained_backbones(...)`.
            return model
        if _is_main_process_dist():
            print(f"📦 loading checkpoint: {checkpoint_path}")
        try:
            if _is_safetensors_path(checkpoint_path):
                from safetensors.torch import load_file

                checkpoint = load_file(checkpoint_path)
            else:
                checkpoint = torch.load(checkpoint_path, map_location="cpu")
        except Exception as e:
            raise RuntimeError(f"❌ loading checkpoint failed: {e}")

        loaded_modules = []

        if reload_modules:
            module_paths = [p.strip() for p in reload_modules.split(",") if p.strip()]
            for path in module_paths:
                # FIX: use a fresh name instead of shadowing the
                # `reload_modules` parameter as the original did.
                attrs = path.split(".")
                module = model
                try:
                    for module_name in attrs:
                        module = getattr(module, module_name)
                    prefix = path + "."
                    sub_state_dict = {k[len(prefix):]: v for k, v in checkpoint.items() if k.startswith(prefix)}
                    if sub_state_dict:
                        module.load_state_dict(sub_state_dict, strict=True)
                        if _is_main_process_dist():
                            print(f"✅ parameters loaded to module '{path}'")
                        loaded_modules.append(path)
                    else:
                        print(f"⚠️ parameters not found in checkpoint '{path}'")
                except AttributeError:
                    print(f"❌ cannot find module path: {path}")
        else:
            try:
                model.load_state_dict(checkpoint, strict=False)
                if _is_main_process_dist():
                    print("✅ loaded <full_model> model parameters")
                loaded_modules = ["<full_model>"]
            except Exception as e:
                raise RuntimeError(f"❌ loading full model failed: {e}")
        return model

    @staticmethod
    def print_freeze_status(model):
        """
        Print the frozen/trainable status of every named parameter.

        :param model: PyTorch model instance.
        """
        for name, param in model.named_parameters():
            status = "Frozen" if not param.requires_grad else "Trainable"
            print(f"{name:60s} | {status}")

    @staticmethod
    def setup_distributed_training(accelerator, *components):
        """
        Prepare components for distributed training via Accelerate.

        :param accelerator: Accelerator instance.
        :param components: any number of components (model, optimizer,
            dataloader, ...).
        :return: the prepared components, in the same order as the input.
        """
        prepared_components = accelerator.prepare(*components)
        return prepared_components

    def save_full_checkpoint(self, completed_steps, checkpoint_dir, output_dir):
        """Save full training state (prepared components + RNG) for resume,
        plus a standalone model weights file for deployment.

        The standalone file format is controlled by ``self.config.trainer.save_format``
        (``"pt"`` or ``"safetensors"``). Defaults to ``"pt"`` when unset.

        Must be called after accelerator.prepare().

        Args:
            completed_steps: Current training step count.
            checkpoint_dir: Directory to save checkpoints (e.g. results/<run_id>/checkpoints).
            output_dir: Top-level run directory for summary.jsonl and config.
        """
        from pathlib import Path

        save_format = getattr(self.config.trainer, "save_format", "pt")

        # Full resumable state (model, optimizer, RNG, ...) goes in steps_N/.
        state_dir = os.path.join(checkpoint_dir, f"steps_{completed_steps}")
        use_safe = save_format == "safetensors"
        self.accelerator.save_state(state_dir, safe_serialization=use_safe)

        if self.accelerator.is_main_process:
            import json as _json

            # Standalone (possibly gathered) weights file for deployment.
            state_dict = self.accelerator.get_state_dict(self.model)
            if state_dict is not None:
                if save_format == "safetensors":
                    from safetensors.torch import save_file

                    weights_path = os.path.join(
                        checkpoint_dir, f"steps_{completed_steps}_model.safetensors"
                    )
                    save_file(state_dict, weights_path)
                else:
                    weights_path = os.path.join(
                        checkpoint_dir, f"steps_{completed_steps}_pytorch_model.pt"
                    )
                    torch.save(state_dict, weights_path)

            # Append a one-line progress record for quick inspection.
            summary_data = {"steps": completed_steps}
            with open(os.path.join(output_dir, "summary.jsonl"), "a") as f:
                f.write(_json.dumps(summary_data) + "\n")

            self.accelerator.print(f"✅ Checkpoint saved at {state_dir}")

            # Persist the effective (accessed) config next to the run output.
            from starVLA.training.trainer_utils.config_tracker import AccessTrackedConfig

            if isinstance(self.config, AccessTrackedConfig):
                self.config.save_accessed_config(
                    Path(output_dir) / "config.yaml",
                    use_original_values=False,
                )

        self.accelerator.wait_for_everyone()

    def resume_from_full_checkpoint(self, checkpoint_dir):
        """Load full training state from an accelerator state directory.

        Must be called **after** accelerator.prepare() (DeepSpeed requirement).

        Args:
            checkpoint_dir: Path to a steps_N/ directory containing full state.

        Returns:
            int: The completed_steps parsed from directory name (steps_N), or 0.
        """
        self.accelerator.load_state(checkpoint_dir)
        self.accelerator.print(f"Resumed full training state from: {checkpoint_dir}")

        dir_name = os.path.basename(checkpoint_dir)
        match = re.match(r"^steps_(\d+)$", dir_name)
        return int(match.group(1)) if match else 0

    @staticmethod
    def euclidean_distance(predicted: np.ndarray, ground_truth: np.ndarray) -> float:
        """Euclidean (L2) norm of the prediction error."""
        return np.linalg.norm(predicted - ground_truth)

    @staticmethod
    def _reset_dataloader(dataloader, epoch_counter):
        """Safely reset a dataloader iterator for a new epoch.

        Increments the epoch counter and, when the sampler supports it,
        calls sampler.set_epoch so distributed samplers reshuffle.

        Returns:
            (iterator, epoch_counter): fresh iterator and the bumped counter.
        """
        epoch_counter += 1

        if hasattr(dataloader, "sampler") and callable(getattr(dataloader.sampler, "set_epoch", None)):
            dataloader.sampler.set_epoch(epoch_counter)

        return iter(dataloader), epoch_counter

    @staticmethod
    def compute_grad_angle_with_stats(grads_a: list[torch.Tensor], grads_v: list[torch.Tensor]) -> Tuple[float, float]:
        """
        Compute per-row angles (degrees) between two gradient tensors and
        summarize them.

        Only the first tensor of each list is used, sliced to [:32, :7]
        (presumably batch x action-dim — TODO confirm with the caller).

        Args:
            grads_a, grads_v: gradient tensor lists for the same parameters.

        Returns:
            mean_angle_deg: mean angle in degrees.
            angle_std_deg: standard deviation of the angles. NOTE: the
                original named this "variance" but computes sqrt(var),
                i.e. the standard deviation; the value is unchanged here.
        """
        angle_degs = []

        # Hard-coded slice: first 32 rows / first 7 columns of the first grad.
        grads_action = grads_a[0][:32, :7]
        grads_vl = grads_v[0][:32, :7]
        for g_a, g_v in zip(grads_action, grads_vl):
            dot = torch.sum(g_a * g_v)
            norm_a_sq = torch.sum(g_a * g_a)
            norm_v_sq = torch.sum(g_v * g_v)

            # Epsilon keeps the norms finite for all-zero gradient rows.
            norm_a = torch.sqrt(norm_a_sq + 1e-16)
            norm_v = torch.sqrt(norm_v_sq + 1e-16)

            cos_sim = (dot / (norm_a * norm_v)).clamp(-1.0, 1.0)
            angle_rad = torch.acos(cos_sim)
            angle_deg = angle_rad * (180.0 / torch.pi)

            angle_degs.append(angle_deg.item())

        angle_degs_tensor = torch.tensor(angle_degs)
        mean_angle_deg = torch.mean(angle_degs_tensor).item()
        angle_std_deg = torch.sqrt(torch.var(angle_degs_tensor)).item()

        return mean_angle_deg, angle_std_deg

    @staticmethod
    def pcgrad_project(grads_a: list[torch.Tensor], grads_v: list[torch.Tensor]) -> list[torch.Tensor]:
        """
        Apply PCGrad projection to grads_v to suppress negative transfer
        against grads_a.

        If the (flattened) dot product of the two gradient groups is < 0:
            grads_v <- grads_v - (dot / ||grads_a||^2) * grads_a

        Returns:
            The (possibly projected) grads_v list.
        """
        dot, norm_a_sq = 0.0, 0.0
        for g_a, g_v in zip(grads_a, grads_v):
            dot += torch.sum(g_a * g_v)
            norm_a_sq += torch.sum(g_a * g_a)

        if dot < 0:
            # Epsilon guards against division by a zero-norm gradient.
            coeff = dot / (norm_a_sq + 1e-6)
            grads_v = [g_v - coeff * g_a for g_a, g_v in zip(grads_a, grads_v)]

        return grads_v

    @staticmethod
    def l1_distance(predicted: np.ndarray, ground_truth: np.ndarray) -> float:
        """Sum of absolute errors (L1 distance).

        NOTE: despite the old 'Mean Absolute Error' label, this is a sum,
        not a mean; callers divide by the element count themselves.
        """
        return np.sum(np.abs(predicted - ground_truth))

    @staticmethod
    def eval_qwenpi(qwenpi, dataloader, num_batches=20):
        """
        Evaluate a QwenQFormerDiT model: bbox IoU and action distance.

        Args:
            qwenpi: QwenQFormerDiT model instance.
            dataloader: data loader yielding dicts with "image", "lang",
                "action" and "solution" keys.
            num_batches: maximum number of batches to evaluate.

        Returns:
            dict: {"iou_scores": [...], "average_action_distance": float}.
        """
        iou_scores = []
        action_distances = []
        count = 0

        dataset_iter = iter(dataloader)
        while count < num_batches:
            try:
                batch_samples = next(dataset_iter)
                count += 1
            except StopIteration:
                break

            images = [example["image"] for example in batch_samples]
            instructions = [example["lang"] for example in batch_samples]
            actions = [example["action"] for example in batch_samples]
            solutions = [example["solution"] for example in batch_samples]

            predicted_solutions, normalized_actions = qwenpi.predict_action_withCoT(
                images=images, instructions=instructions, use_ddim=False, num_ddim_steps=20
            )

            # Parse the model's CoT text output into solution dicts.
            parsed_solutions = []
            for solution in predicted_solutions:
                parsed_solution = TrainerUtils.extract_json_from_string(solution)
                parsed_solutions.append(parsed_solution)

            # Pick/place bbox IoU per sample.
            for pred_dict, gt_dict in zip(parsed_solutions, solutions):
                pred_pick_bbox = torch.tensor(pred_dict["pick"]["bbox_2d"], dtype=torch.float32).unsqueeze(0)
                gt_pick_bbox = torch.tensor(gt_dict["pick"]["bbox_2d"], dtype=torch.float32).unsqueeze(0)
                pred_place_bbox = torch.tensor(pred_dict["place"]["bbox_2d"], dtype=torch.float32).unsqueeze(0)
                gt_place_bbox = torch.tensor(gt_dict["place"]["bbox_2d"], dtype=torch.float32).unsqueeze(0)

                pick_iou = box_iou(pred_pick_bbox, gt_pick_bbox).item()
                place_iou = box_iou(pred_place_bbox, gt_place_bbox).item()

                iou_scores.append({"pick_iou": pick_iou, "place_iou": place_iou})

            # Per-element action distance (total L2 / number of scalars).
            actions = np.array(actions)
            num_pots = np.prod(actions.shape)
            action_distance = TrainerUtils.euclidean_distance(normalized_actions, actions)
            average_action_distance = action_distance / num_pots
            action_distances.append(average_action_distance)

        avg_action_distance = np.mean(action_distances)
        return {"iou_scores": iou_scores, "average_action_distance": avg_action_distance}

    @staticmethod
    def extract_json_from_string(input_string):
        """
        Extract the JSON part of a string and parse it into a dictionary.

        Uses a greedy match from the first '{' to the last '}', so nested
        objects are captured whole.

        Args:
            input_string (str): string possibly containing extra characters.

        Returns:
            dict | None: the parsed dictionary, or None when no valid JSON
                is found (a message is printed in that case).
        """
        json_match = re.search(r"{.*}", input_string, re.DOTALL)
        if json_match:
            json_str = json_match.group(0)
            try:
                return json.loads(json_str)
            except json.JSONDecodeError as e:
                print(f"JSON decode failed: {e}")
                return None
        else:
            print("No valid JSON part found")
            return None

    def _get_latest_checkpoint(self, checkpoint_dir):
        """Find the latest checkpoint in the directory based on step number.

        Supports both new directory format (steps_N/) and legacy file format
        (steps_N_pytorch_model.pt). Prefers new directory format when both exist
        at the same step.

        Returns:
            (path, completed_steps): latest checkpoint path and its step
                count, or (None, 0) when nothing is found.
        """
        if not os.path.exists(checkpoint_dir):
            self.accelerator.print(f"No checkpoint directory found at {checkpoint_dir}")
            return None, 0

        checkpoints_with_steps = []

        for entry in os.listdir(checkpoint_dir):
            full_path = os.path.join(checkpoint_dir, entry)

            # New format: a steps_N/ directory holding the full state.
            dir_match = re.match(r"^steps_(\d+)$", entry)
            if dir_match and os.path.isdir(full_path):
                step = int(dir_match.group(1))
                checkpoints_with_steps.append((full_path, step, "dir"))
                continue

            # Legacy format: a standalone weights file.
            file_match = re.match(r"^steps_(\d+)_(?:pytorch_model\.pt|model\.safetensors)$", entry)
            if file_match and os.path.isfile(full_path):
                step = int(file_match.group(1))
                checkpoints_with_steps.append((full_path, step, "file"))

        if not checkpoints_with_steps:
            self.accelerator.print(f"No checkpoints found in {checkpoint_dir}")
            return None, 0

        # Sort by step, then prefer "dir" over "file" at the same step.
        type_priority = {"file": 0, "dir": 1}
        checkpoints_with_steps.sort(key=lambda x: (x[1], type_priority[x[2]]))
        latest_path, completed_steps, fmt = checkpoints_with_steps[-1]

        self.accelerator.print(f"Latest checkpoint found: {latest_path} (format={fmt})")
        return latest_path, completed_steps
|
|
| import os |
|
|
|
|
def is_main_process():
    """Return True when the RANK env var is absent or equal to 0."""
    return int(os.environ.get("RANK", 0)) == 0
|
|
|
|
| def _is_safetensors_path(path): |
| """Check if a path refers to a safetensors file.""" |
| return str(path).endswith(".safetensors") |
|
|