| |
| """Utilities to reconstruct models from progressive pruning cycles.""" |
|
|
| import json |
| import os |
| from typing import Optional |
|
|
| import torch |
|
|
| try: |
| from transformers import AutoModelForCausalLM, PretrainedConfig |
| from transformers.models.auto.configuration_auto import CONFIG_MAPPING |
| except Exception as exc: |
| raise SystemExit("transformers is required: pip install transformers") from exc |
|
|
| from fuse_layers_model import ( |
| decrement_config, |
| drop_layer, |
| find_layer_container, |
| get_dtype, |
| normalize_config, |
| ) |
|
|
|
|
def load_progressive_metadata(output_dir: str) -> dict:
    """Load the progressive-pruning run metadata for *output_dir*.

    Args:
        output_dir: Directory containing ``progressive_metadata.json``.

    Returns:
        The parsed metadata dictionary.

    Raises:
        FileNotFoundError: If the metadata file does not exist.
    """
    path = os.path.join(output_dir, "progressive_metadata.json")
    # EAFP: open directly instead of checking existence first, so there is
    # no window between the check and the open (TOCTOU race). The original
    # error message is preserved for callers that match on it.
    try:
        with open(path, "r", encoding="utf-8") as handle:
            return json.load(handle)
    except FileNotFoundError as exc:
        raise FileNotFoundError(f"Missing progressive metadata at {path}") from exc
|
|
|
|
def load_normalized_config(model_path: str, trust_remote_code: bool):
    """Build a model config from *model_path* with ``layer_types`` reconciled.

    Fetches the raw config dict, truncates ``layer_types`` so its length
    matches ``num_hidden_layers`` when the two disagree, instantiates the
    concrete config class for the model type, and applies the project's
    ``normalize_config`` pass before returning it.
    """
    raw, extra_kwargs = PretrainedConfig.get_config_dict(
        model_path,
        trust_remote_code=trust_remote_code,
    )

    n_layers = raw.get("num_hidden_layers")
    types = raw.get("layer_types")
    length_mismatch = (
        isinstance(n_layers, int)
        and n_layers >= 0
        and isinstance(types, list)
        and len(types) != n_layers
    )
    if length_mismatch:
        # Keep only the entries that correspond to retained layers.
        raw["layer_types"] = list(types[:n_layers])

    config_cls = CONFIG_MAPPING[raw["model_type"]]
    config = config_cls.from_dict(raw, **extra_kwargs)
    normalize_config(config)
    return config
|
|
|
|
def load_causal_lm(
    model_path_or_id: str,
    *,
    torch_dtype,
    trust_remote_code: bool,
    **kwargs,
) -> torch.nn.Module:
    """Load a causal LM, normalizing the config when loading from a local dir.

    When *model_path_or_id* is a local directory containing ``config.json``,
    the config is pre-loaded through ``load_normalized_config`` so the
    normalization pass runs before model construction; otherwise the config
    is resolved by ``from_pretrained`` itself.
    """
    normalized_config = None
    local_config = os.path.join(model_path_or_id, "config.json")
    if os.path.isdir(model_path_or_id) and os.path.isfile(local_config):
        normalized_config = load_normalized_config(model_path_or_id, trust_remote_code)

    return AutoModelForCausalLM.from_pretrained(
        model_path_or_id,
        config=normalized_config,
        torch_dtype=torch_dtype,
        trust_remote_code=trust_remote_code,
        **kwargs,
    )
|
|
|
|
def load_progressive_model(
    base_model_id: str,
    output_dir: str,
    cycle: Optional[int] = None,
    device: Optional[str] = None,
    dtype: str = "auto",
    trust_remote_code: bool = False,
    layer_path: Optional[str] = None,
) -> torch.nn.Module:
    """Reconstruct the model as it existed after a given pruning cycle.

    Two paths: if the requested cycle saved a full model snapshot, load it
    directly; otherwise load the base model and replay cycles 1..cycle, each
    time loading a fused layer state into layer ``layer_merged`` and dropping
    the layer after it.

    Args:
        base_model_id: Hub id or path of the unpruned base model.
        output_dir: Run directory containing ``progressive_metadata.json``
            and per-cycle ``cycle_<i>`` subdirectories.
        cycle: Cycle to reconstruct; ``None`` selects the final cycle.
        device: Optional device string to move the model to before returning.
        dtype: Dtype spec passed through the project's ``get_dtype`` helper.
        trust_remote_code: Forwarded to ``transformers`` loading calls.
        layer_path: Dotted path to the layer container; overrides the value
            recorded in the metadata when provided.

    Returns:
        The reconstructed ``torch.nn.Module``.

    Raises:
        ValueError: If ``cycle`` is outside ``[0, num_progressive]`` or a
            recorded layer index is out of range.
        FileNotFoundError: If metadata or per-cycle artifacts are missing.
    """
    meta = load_progressive_metadata(output_dir)
    num_cycles = int(meta.get("num_progressive", 0))
    if cycle is None:
        # Default to the fully-pruned (final) state.
        cycle = num_cycles
    if cycle < 0 or cycle > num_cycles:
        raise ValueError(f"Cycle {cycle} is outside [0, {num_cycles}]")

    # Fast path: if this cycle saved a full model snapshot, load it directly
    # instead of replaying the merge history from the base model.
    if cycle > 0:
        full_model_dir = os.path.join(output_dir, f"cycle_{cycle}", "full_model")
        if os.path.isdir(full_model_dir):
            model = load_causal_lm(
                full_model_dir,
                torch_dtype=get_dtype(dtype),
                trust_remote_code=trust_remote_code,
            )
            if device:
                model.to(device)
            return model

    # Slow path: start from the unpruned base model and replay each cycle.
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=get_dtype(dtype),
        trust_remote_code=trust_remote_code,
    )
    # Explicit layer_path argument wins over the recorded metadata value.
    active_layer_path = layer_path or meta.get("layer_path")
    parent, name, container = find_layer_container(model, active_layer_path)

    for idx in range(1, cycle + 1):
        cycle_dir = os.path.join(output_dir, f"cycle_{idx}")
        cycle_meta_path = os.path.join(cycle_dir, "cycle_metadata.json")
        if not os.path.exists(cycle_meta_path):
            raise FileNotFoundError(f"Missing cycle metadata at {cycle_meta_path}")
        with open(cycle_meta_path, "r", encoding="utf-8") as handle:
            cycle_meta = json.load(handle)

        # Index of the layer that absorbed its successor in this cycle.
        # NOTE(review): indices are presumably relative to the already-pruned
        # container from the previous iteration — confirm against the writer.
        layer_idx = int(cycle_meta["layer_merged"])
        fused_state = cycle_meta.get("fused_layer_state", "fused_layer.pt")
        fused_state_path = os.path.join(cycle_dir, fused_state)
        if not os.path.exists(fused_state_path):
            raise FileNotFoundError(f"Missing fused layer at {fused_state_path}")

        layers = list(container)
        if layer_idx < 0 or layer_idx >= len(layers):
            raise ValueError(
                f"Cycle {idx} layer index {layer_idx} out of range for {len(layers)} layers"
            )

        # NOTE(review): torch.load without weights_only=True will unpickle
        # arbitrary objects; consider weights_only=True for untrusted files.
        state = torch.load(fused_state_path, map_location="cpu")
        # Overwrite layer_idx in place with the fused weights; the entries of
        # `layers` alias the modules inside `container`, so this mutates the model.
        layers[layer_idx].load_state_dict(state)

        # Remove the layer that was merged into layer_idx.
        # NOTE(review): when layer_idx == len(layers) - 1, layer_idx + 1 is past
        # the end — depends on drop_layer's handling; verify it is never hit.
        new_container = drop_layer(container, layer_idx + 1)
        setattr(parent, name, new_container)
        # Keep config fields (e.g. num_hidden_layers) in sync with the model.
        decrement_config(model.config)

        container = new_container

    if device:
        model.to(device)

    return model
|
|