#!/usr/bin/env python3
"""Utilities to reconstruct models from progressive pruning cycles."""

import json
import os
from typing import Optional

import torch

try:
    from transformers import AutoModelForCausalLM, PretrainedConfig
    from transformers.models.auto.configuration_auto import CONFIG_MAPPING
except Exception as exc:  # pragma: no cover - fail early with clear error
    raise SystemExit("transformers is required: pip install transformers") from exc

from fuse_layers_model import (
    decrement_config,
    drop_layer,
    find_layer_container,
    get_dtype,
    normalize_config,
)


def load_progressive_metadata(output_dir: str) -> dict:
    """Load the run-level metadata written by the progressive pruning run.

    Args:
        output_dir: Directory holding ``progressive_metadata.json``.

    Returns:
        The parsed metadata dictionary.

    Raises:
        FileNotFoundError: If the metadata file does not exist.
    """
    path = os.path.join(output_dir, "progressive_metadata.json")
    if not os.path.exists(path):
        raise FileNotFoundError(f"Missing progressive metadata at {path}")
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


def load_normalized_config(model_path: str, trust_remote_code: bool):
    """Build a normalized config for a local checkpoint directory.

    A checkpoint saved mid-pruning can carry more ``layer_types`` entries
    than ``num_hidden_layers``; the surplus entries are truncated so the
    config class accepts the dict.

    Args:
        model_path: Local path containing ``config.json``.
        trust_remote_code: Forwarded to the transformers config loader.

    Returns:
        A config instance that has been passed through ``normalize_config``.
    """
    config_dict, unused_kwargs = PretrainedConfig.get_config_dict(
        model_path,
        trust_remote_code=trust_remote_code,
    )
    num_hidden_layers = config_dict.get("num_hidden_layers")
    layer_types = config_dict.get("layer_types")
    if (
        isinstance(num_hidden_layers, int)
        and num_hidden_layers >= 0
        and isinstance(layer_types, list)
        and len(layer_types) != num_hidden_layers
    ):
        # NOTE(review): truncation only repairs the "too many entries" case;
        # a layer_types list SHORTER than num_hidden_layers is left mismatched
        # — presumably that never occurs after pruning, confirm upstream.
        config_dict["layer_types"] = list(layer_types[:num_hidden_layers])
    model_type = config_dict["model_type"]
    config_class = CONFIG_MAPPING[model_type]
    config = config_class.from_dict(config_dict, **unused_kwargs)
    normalize_config(config)
    return config


def load_causal_lm(
    model_path_or_id: str,
    *,
    torch_dtype,
    trust_remote_code: bool,
    **kwargs,
) -> torch.nn.Module:
    """Load a causal LM, normalizing the config first for local checkpoints.

    For a local directory with a ``config.json``, the config is repaired via
    ``load_normalized_config`` before loading; hub ids load unmodified.

    Args:
        model_path_or_id: Local checkpoint directory or hub model id.
        torch_dtype: dtype forwarded to ``from_pretrained``.
        trust_remote_code: Forwarded to transformers loaders.
        **kwargs: Extra arguments forwarded to ``from_pretrained``.

    Returns:
        The loaded model.
    """
    config = None
    config_path = os.path.join(model_path_or_id, "config.json")
    if os.path.isdir(model_path_or_id) and os.path.isfile(config_path):
        config = load_normalized_config(model_path_or_id, trust_remote_code)
    return AutoModelForCausalLM.from_pretrained(
        model_path_or_id,
        config=config,
        torch_dtype=torch_dtype,
        trust_remote_code=trust_remote_code,
        **kwargs,
    )


def load_progressive_model(
    base_model_id: str,
    output_dir: str,
    cycle: Optional[int] = None,
    device: Optional[str] = None,
    dtype: str = "auto",
    trust_remote_code: bool = False,
    layer_path: Optional[str] = None,
) -> torch.nn.Module:
    """Reconstruct the model as it looked after *cycle* pruning cycles.

    Prefers a saved ``full_model`` snapshot of the requested cycle; otherwise
    replays every cycle on top of *base_model_id*, loading the fused layer
    weights and dropping the merged-away layer for each cycle in order.

    Args:
        base_model_id: Hub id or path of the unpruned base model.
        output_dir: Directory produced by the progressive pruning run.
        cycle: Cycle to reconstruct; defaults to the final recorded cycle.
        device: Optional device to move the finished model to.
        dtype: dtype spec understood by ``get_dtype`` (default ``"auto"``).
        trust_remote_code: Forwarded to transformers loaders.
        layer_path: Dotted path to the layer container; defaults to the value
            recorded in the run metadata.

    Returns:
        The reconstructed model.

    Raises:
        ValueError: If *cycle* is outside ``[0, num_progressive]`` or a cycle
            references a layer index out of range.
        FileNotFoundError: If required metadata or weight files are missing.
    """
    meta = load_progressive_metadata(output_dir)
    num_cycles = int(meta.get("num_progressive", 0))
    if cycle is None:
        cycle = num_cycles
    if cycle < 0 or cycle > num_cycles:
        raise ValueError(f"Cycle {cycle} is outside [0, {num_cycles}]")

    # Fast path: a fully saved snapshot of the requested cycle, if present.
    if cycle > 0:
        full_model_dir = os.path.join(output_dir, f"cycle_{cycle}", "full_model")
        if os.path.isdir(full_model_dir):
            model = load_causal_lm(
                full_model_dir,
                torch_dtype=get_dtype(dtype),
                trust_remote_code=trust_remote_code,
            )
            if device:
                model.to(device)
            return model

    # Slow path: start from the base model and replay each cycle in order.
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=get_dtype(dtype),
        trust_remote_code=trust_remote_code,
    )
    active_layer_path = layer_path or meta.get("layer_path")
    parent, name, container = find_layer_container(model, active_layer_path)
    for idx in range(1, cycle + 1):
        cycle_dir = os.path.join(output_dir, f"cycle_{idx}")
        cycle_meta_path = os.path.join(cycle_dir, "cycle_metadata.json")
        if not os.path.exists(cycle_meta_path):
            raise FileNotFoundError(f"Missing cycle metadata at {cycle_meta_path}")
        with open(cycle_meta_path, "r", encoding="utf-8") as handle:
            cycle_meta = json.load(handle)
        layer_idx = int(cycle_meta["layer_merged"])
        fused_state = cycle_meta.get("fused_layer_state", "fused_layer.pt")
        fused_state_path = os.path.join(cycle_dir, fused_state)
        if not os.path.exists(fused_state_path):
            raise FileNotFoundError(f"Missing fused layer at {fused_state_path}")
        layers = list(container)
        if layer_idx < 0 or layer_idx >= len(layers):
            raise ValueError(
                f"Cycle {idx} layer index {layer_idx} out of range for {len(layers)} layers"
            )
        # weights_only=True: the file holds a plain state_dict of tensors, so
        # refuse arbitrary pickled objects (PyTorch's recommended safe default
        # for loading checkpoints that may come from untrusted storage).
        state = torch.load(fused_state_path, map_location="cpu", weights_only=True)
        layers[layer_idx].load_state_dict(state)
        # NOTE(review): dropping layer_idx + 1 assumes the merged layer is
        # never the last one; a metadata entry pointing at the final layer
        # would push drop_layer out of range — confirm against the pruner.
        new_container = drop_layer(container, layer_idx + 1)
        setattr(parent, name, new_container)
        decrement_config(model.config)
        container = new_container
    if device:
        model.to(device)
    return model