File size: 4,825 Bytes
2c44909 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 | #!/usr/bin/env python3
"""Utilities to reconstruct models from progressive pruning cycles."""
import json
import os
from typing import Optional
import torch
try:
from transformers import AutoModelForCausalLM, PretrainedConfig
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
except Exception as exc: # pragma: no cover - fail early with clear error
raise SystemExit("transformers is required: pip install transformers") from exc
from fuse_layers_model import (
decrement_config,
drop_layer,
find_layer_container,
get_dtype,
normalize_config,
)
def load_progressive_metadata(output_dir: str) -> dict:
    """Read and return the JSON object stored in ``progressive_metadata.json``.

    Args:
        output_dir: Directory expected to contain ``progressive_metadata.json``.

    Returns:
        The decoded JSON content of that file.

    Raises:
        FileNotFoundError: If the metadata file does not exist in ``output_dir``.
    """
    metadata_file = os.path.join(output_dir, "progressive_metadata.json")
    if os.path.exists(metadata_file):
        with open(metadata_file, "r", encoding="utf-8") as fh:
            return json.load(fh)
    raise FileNotFoundError(f"Missing progressive metadata at {metadata_file}")
def load_normalized_config(model_path: str, trust_remote_code: bool):
    """Build the config at ``model_path`` after reconciling ``layer_types``.

    The raw config dict is fetched with ``PretrainedConfig.get_config_dict``.
    When ``num_hidden_layers`` is a non-negative int and ``layer_types`` is a
    list whose length differs from it, ``layer_types`` is replaced by its first
    ``num_hidden_layers`` entries. A list that is shorter is only copied; it is
    not padded. The config class is then looked up in ``CONFIG_MAPPING`` by
    ``model_type``, instantiated via ``from_dict`` (forwarding the unused kwargs
    returned by ``get_config_dict``) and passed through ``normalize_config``.

    Args:
        model_path: Local path (or id) understood by ``get_config_dict``.
        trust_remote_code: Forwarded to ``get_config_dict``.

    Returns:
        The instantiated, normalized config object.

    Raises:
        KeyError: If the config dict has no ``model_type``. The
            ``CONFIG_MAPPING`` lookup may also fail for unregistered types.
    """
    raw, extra_kwargs = PretrainedConfig.get_config_dict(
        model_path,
        trust_remote_code=trust_remote_code,
    )
    n_layers = raw.get("num_hidden_layers")
    types_list = raw.get("layer_types")
    # NOTE(review): presumably pruned checkpoints carry a decremented
    # num_hidden_layers but a stale layer_types list; trim it to match.
    needs_trim = (
        isinstance(n_layers, int)
        and n_layers >= 0
        and isinstance(types_list, list)
        and len(types_list) != n_layers
    )
    if needs_trim:
        raw["layer_types"] = list(types_list[:n_layers])
    cfg_cls = CONFIG_MAPPING[raw["model_type"]]
    cfg = cfg_cls.from_dict(raw, **extra_kwargs)
    normalize_config(cfg)
    return cfg
def load_causal_lm(
    model_path_or_id: str,
    *,
    torch_dtype,
    trust_remote_code: bool,
    **kwargs,
) -> torch.nn.Module:
    """Load a causal LM, normalizing its config first when it is a local checkout.

    If ``model_path_or_id`` is an existing directory containing ``config.json``,
    the config is built with ``load_normalized_config`` and handed to
    ``AutoModelForCausalLM.from_pretrained``. Otherwise ``config=None`` is
    passed and the default config loading of ``from_pretrained`` applies.

    Args:
        model_path_or_id: Local directory or hub identifier.
        torch_dtype: Forwarded to ``from_pretrained``.
        trust_remote_code: Forwarded to both the config loader and ``from_pretrained``.
        **kwargs: Extra keyword arguments forwarded to ``from_pretrained``.

    Returns:
        The loaded model.
    """
    config_file = os.path.join(model_path_or_id, "config.json")
    has_local_config = os.path.isdir(model_path_or_id) and os.path.isfile(config_file)
    normalized = (
        load_normalized_config(model_path_or_id, trust_remote_code)
        if has_local_config
        else None
    )
    return AutoModelForCausalLM.from_pretrained(
        model_path_or_id,
        config=normalized,
        torch_dtype=torch_dtype,
        trust_remote_code=trust_remote_code,
        **kwargs,
    )
def load_progressive_model(
    base_model_id: str,
    output_dir: str,
    cycle: Optional[int] = None,
    device: Optional[str] = None,
    dtype: str = "auto",
    trust_remote_code: bool = False,
    layer_path: Optional[str] = None,
) -> torch.nn.Module:
    """Return the model as it stood after progressive pruning cycle ``cycle``.

    If ``output_dir/cycle_<cycle>/full_model`` is a directory (and ``cycle > 0``),
    that snapshot is loaded with ``load_causal_lm``. Otherwise the model is
    rebuilt from ``base_model_id`` by replaying cycles ``1..cycle``; see
    ``_replay_cycles``.

    Args:
        base_model_id: Hub id or local path of the unpruned model.
        output_dir: Directory holding ``progressive_metadata.json`` and the
            ``cycle_<n>`` sub-directories.
        cycle: Cycle to reconstruct. ``None`` selects ``num_progressive`` from the
            metadata (default 0); ``0`` yields the base model.
        device: If truthy, the model is moved there with ``model.to``.
        dtype: Converted with ``get_dtype`` and passed as ``torch_dtype``.
        trust_remote_code: Forwarded to the model loaders.
        layer_path: Path of the layer container; falls back to the
            ``layer_path`` entry of the progressive metadata.

    Returns:
        The loaded or reconstructed model.

    Raises:
        ValueError: If ``cycle`` is outside ``[0, num_progressive]``, or a cycle's
            ``layer_merged`` index does not leave room for the following layer.
        FileNotFoundError: If progressive metadata, cycle metadata or a fused
            layer file is missing.
    """
    meta = load_progressive_metadata(output_dir)
    num_cycles = int(meta.get("num_progressive", 0))
    if cycle is None:
        cycle = num_cycles
    if cycle < 0 or cycle > num_cycles:
        raise ValueError(f"Cycle {cycle} is outside [0, {num_cycles}]")

    torch_dtype = get_dtype(dtype)
    full_model_dir = os.path.join(output_dir, f"cycle_{cycle}", "full_model")
    if cycle > 0 and os.path.isdir(full_model_dir):
        # A full snapshot of this cycle was saved; no replay needed.
        model = load_causal_lm(
            full_model_dir,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
        )
        _replay_cycles(model, output_dir, cycle, layer_path or meta.get("layer_path"))

    if device:
        model.to(device)
    return model


def _replay_cycles(model, output_dir: str, cycle: int, layer_path: Optional[str]) -> None:
    """Apply the recorded pruning cycles ``1..cycle`` to ``model`` in place.

    For each cycle, the saved fused-layer state is loaded into layer
    ``layer_merged``. ``drop_layer`` is then asked to remove index
    ``layer_merged + 1``, the new container is set back on its parent, and
    ``decrement_config`` is applied to ``model.config``. Each cycle's index is
    interpreted against the container as reduced by the previous cycles.
    """
    parent, name, container = find_layer_container(model, layer_path)
    for idx in range(1, cycle + 1):
        layer_idx, fused_state_path = _read_cycle_record(output_dir, idx)
        layers = list(container)
        # Both ``layer_idx`` (receives the fused weights) and ``layer_idx + 1``
        # (handed to drop_layer) must exist. Checking before load_state_dict
        # keeps a bad record from leaving the model half-modified.
        if layer_idx < 0 or layer_idx + 1 >= len(layers):
            raise ValueError(
                f"Cycle {idx} layer index {layer_idx} out of range for {len(layers)} layers"
                f" (layers {layer_idx} and {layer_idx + 1} must both exist)"
            )
        # NOTE(review): torch.load unpickles the file. Only load fused layers from
        # trusted output directories, or consider weights_only=True (torch >= 1.13).
        state = torch.load(fused_state_path, map_location="cpu")
        layers[layer_idx].load_state_dict(state)
        container = drop_layer(container, layer_idx + 1)
        setattr(parent, name, container)
        decrement_config(model.config)


def _read_cycle_record(output_dir: str, idx: int):
    """Return ``(layer_merged, fused_state_path)`` for cycle ``idx``.

    Reads ``cycle_<idx>/cycle_metadata.json`` and resolves the fused-layer file
    named by its ``fused_layer_state`` entry (default ``fused_layer.pt``).

    Raises:
        FileNotFoundError: If the cycle metadata or the fused-layer file is missing.
    """
    cycle_dir = os.path.join(output_dir, f"cycle_{idx}")
    cycle_meta_path = os.path.join(cycle_dir, "cycle_metadata.json")
    if not os.path.exists(cycle_meta_path):
        raise FileNotFoundError(f"Missing cycle metadata at {cycle_meta_path}")
    with open(cycle_meta_path, "r", encoding="utf-8") as handle:
        cycle_meta = json.load(handle)
    layer_idx = int(cycle_meta["layer_merged"])
    fused_state = cycle_meta.get("fused_layer_state", "fused_layer.pt")
    fused_state_path = os.path.join(cycle_dir, fused_state)
    if not os.path.exists(fused_state_path):
        raise FileNotFoundError(f"Missing fused layer at {fused_state_path}")
    return layer_idx, fused_state_path
|