""" Weight Loading and Saving Utilities for SAM3 MLX Handles: - Loading converted MLX weights from .npz files - Saving model weights - Weight name mapping between PyTorch and MLX """ import mlx.core as mx import numpy as np from pathlib import Path from typing import Dict, Any, Optional import json def map_pytorch_to_mlx_name(pytorch_name: str) -> str: """ Map PyTorch parameter names to MLX parameter names PyTorch uses different naming conventions: - weight/bias instead of MLX's weight/bias - Different module paths Args: pytorch_name: PyTorch parameter name Returns: MLX parameter name """ # Direct mappings name = pytorch_name # Vision encoder mappings name = name.replace("image_encoder.", "vision_encoder.") name = name.replace("trunk.", "") # Attention mappings name = name.replace("attn.qkv.", "attn.qkv.") # Layer norm mappings (PyTorch uses weight/bias, MLX uses scale/bias) # Actually MLX LayerNorm uses weight/bias too, so no change needed # Prompt encoder mappings name = name.replace("prompt_encoder.point_embeddings", "prompt_encoder.point_embeddings") # Mask decoder mappings name = name.replace("mask_decoder.transformer.", "mask_decoder.transformer.") name = name.replace("mask_decoder.output_upscaling.", "mask_decoder.output_upscaling.") return name def load_weights( model: Any, weights_path: str, strict: bool = False, verbose: bool = True, ) -> Any: """ Load MLX weights from .npz file into model Args: model: SAM3MLX model instance weights_path: Path to .npz weights file strict: If True, all parameters must match exactly verbose: Print loading statistics Returns: Model with loaded weights """ weights_path = Path(weights_path) if not weights_path.exists(): raise FileNotFoundError(f"Weights file not found: {weights_path}") if verbose: print(f"📥 Loading weights from {weights_path.name}") # Load numpy arrays weights_np = np.load(weights_path) # Get model parameter tree model_params = model.parameters() model_param_names = set(_flatten_params(model_params).keys()) # Convert and load weights loaded_count = 0 skipped_count = 0 missing_params = set(model_param_names) for param_name in weights_np.files: # Map PyTorch name to MLX name mlx_name = map_pytorch_to_mlx_name(param_name) # Check if parameter exists in model if mlx_name in model_param_names: # Convert to MLX array param_data = mx.array(weights_np[param_name]) # Set parameter in model _set_param(model, mlx_name, param_data) loaded_count += 1 missing_params.discard(mlx_name) else: skipped_count += 1 if verbose and strict: print(f" ⚠️ Skipped: {param_name} (not found in model)") if verbose: print(f"✅ Loaded {loaded_count} parameters") if skipped_count > 0: print(f" ⏭️ Skipped {skipped_count} parameters") if len(missing_params) > 0: print(f" ❌ Missing {len(missing_params)} parameters in checkpoint") if strict: for param in list(missing_params)[:10]: # Show first 10 print(f" - {param}") if strict and len(missing_params) > 0: raise ValueError( f"Missing {len(missing_params)} parameters in checkpoint. " "Use strict=False to load partial weights." 
def save_weights(
    model: Any,
    weights_path: str,
    verbose: bool = True,
) -> None:
    """
    Save model weights to .npz file

    Args:
        model: SAM3MLX model instance
        weights_path: Path to save .npz weights file
        verbose: Print saving statistics
    """
    weights_path = Path(weights_path)
    weights_path.parent.mkdir(parents=True, exist_ok=True)

    if verbose:
        print(f"💾 Saving weights to {weights_path.name}")

    # Get model parameters
    model_params = _flatten_params(model.parameters())

    # Convert to numpy
    weights_np = {}
    for name, param in model_params.items():
        weights_np[name] = np.array(param)

    # Save
    np.savez(weights_path, **weights_np)

    if verbose:
        file_size_mb = weights_path.stat().st_size / (1024 * 1024)
        print(f"✅ Saved {len(weights_np)} parameters ({file_size_mb:.2f} MB)")


def _flatten_params(params: Dict, prefix: str = "", sep: str = ".") -> Dict[str, mx.array]:
    """
    Flatten nested parameter dictionary

    Args:
        params: Nested parameter dict
        prefix: Current prefix for parameter names
        sep: Separator for parameter names

    Returns:
        Flattened dict of {name: array}
    """
    flat = {}
    for key, value in params.items():
        full_key = f"{prefix}{sep}{key}" if prefix else key

        if isinstance(value, dict):
            # Recurse into nested dict
            flat.update(_flatten_params(value, full_key, sep))
        elif isinstance(value, mx.array):
            # Leaf parameter
            flat[full_key] = value
        elif isinstance(value, list):
            # List of parameters (e.g., from nn.Sequential)
            for i, item in enumerate(value):
                if isinstance(item, dict):
                    flat.update(_flatten_params(item, f"{full_key}.{i}", sep))
                elif isinstance(item, mx.array):
                    flat[f"{full_key}.{i}"] = item

    return flat


def _set_param(model: Any, param_name: str, value: mx.array) -> None:
    """
    Set a parameter in the model by dotted name

    Args:
        model: Model instance
        param_name: Dotted parameter name
            (e.g., "vision_encoder.patch_embed.proj.weight")
        value: Parameter value
    """
    parts = param_name.split(".")
    obj = model

    # Navigate to the parent object
    for part in parts[:-1]:
        if part.isdigit():
            # List index (e.g., a block inside a layer list)
            obj = obj[int(part)]
        elif hasattr(obj, part):
            obj = getattr(obj, part)
        else:
            raise AttributeError(f"Cannot find {part} in {type(obj)}")

    # Set the final attribute
    final_attr = parts[-1]
    if hasattr(obj, final_attr):
        setattr(obj, final_attr, value)
    else:
        raise AttributeError(f"Cannot set {final_attr} in {type(obj)}")


def load_config(config_path: str) -> Dict[str, Any]:
    """
    Load model configuration from JSON file

    Args:
        config_path: Path to config JSON file

    Returns:
        Configuration dictionary
    """
    config_path = Path(config_path)
    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found: {config_path}")

    with open(config_path) as f:
        config = json.load(f)

    return config


def save_config(config: Dict[str, Any], config_path: str) -> None:
    """
    Save model configuration to JSON file

    Args:
        config: Configuration dictionary
        config_path: Path to save config JSON file
    """
    config_path = Path(config_path)
    config_path.parent.mkdir(parents=True, exist_ok=True)

    with open(config_path, 'w') as f:
        json.dump(config, f, indent=2)
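

# Smoke test: a save/load round trip on a toy module. A minimal sketch to
# exercise _flatten_params, save_weights, and load_weights; TinyModel and
# "demo_weights.npz" are illustrative names, not part of the SAM3 pipeline.
if __name__ == "__main__":
    import mlx.nn as nn

    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4)

    model = TinyModel()
    save_weights(model, "demo_weights.npz")

    # Parameter names ("proj.weight", "proj.bias") pass through
    # map_pytorch_to_mlx_name unchanged, so strict loading succeeds.
    model = load_weights(model, "demo_weights.npz", strict=True)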