File size: 4,825 Bytes
2c44909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#!/usr/bin/env python3
"""Utilities to reconstruct models from progressive pruning cycles."""

import json
import os
from typing import Optional

import torch

try:
    from transformers import AutoModelForCausalLM, PretrainedConfig
    from transformers.models.auto.configuration_auto import CONFIG_MAPPING
except Exception as exc:  # pragma: no cover - fail early with clear error
    raise SystemExit("transformers is required: pip install transformers") from exc

from fuse_layers_model import (
    decrement_config,
    drop_layer,
    find_layer_container,
    get_dtype,
    normalize_config,
)


def load_progressive_metadata(output_dir: str) -> dict:
    """Read ``progressive_metadata.json`` from *output_dir* and return its parsed content.

    Args:
        output_dir: Directory expected to contain ``progressive_metadata.json``.

    Returns:
        The decoded JSON document.

    Raises:
        FileNotFoundError: If the metadata file does not exist in *output_dir*.
    """
    metadata_file = os.path.join(output_dir, "progressive_metadata.json")
    if os.path.exists(metadata_file):
        with open(metadata_file, "r", encoding="utf-8") as stream:
            return json.loads(stream.read())
    raise FileNotFoundError(f"Missing progressive metadata at {metadata_file}")


def load_normalized_config(model_path: str, trust_remote_code: bool):
    """Build a config object for *model_path* and pass it through ``normalize_config``.

    The raw config dict is fetched with ``PretrainedConfig.get_config_dict``.
    If ``num_hidden_layers`` is a non-negative int and ``layer_types`` is a
    list whose length differs from it, ``layer_types`` is replaced by its
    first ``num_hidden_layers`` entries before the config class is
    instantiated.

    Args:
        model_path: Location handed to ``PretrainedConfig.get_config_dict``.
        trust_remote_code: Forwarded to ``PretrainedConfig.get_config_dict``.

    Returns:
        The config instance created by the class registered in
        ``CONFIG_MAPPING`` for the dict's ``model_type``, after
        ``normalize_config`` has been applied to it.
    """
    raw_config, extra_kwargs = PretrainedConfig.get_config_dict(
        model_path,
        trust_remote_code=trust_remote_code,
    )

    layer_count = raw_config.get("num_hidden_layers")
    per_layer_types = raw_config.get("layer_types")
    has_valid_count = isinstance(layer_count, int) and layer_count >= 0
    if has_valid_count and isinstance(per_layer_types, list):
        if len(per_layer_types) != layer_count:
            # Keep only as many per-layer entries as the declared layer count.
            raw_config["layer_types"] = list(per_layer_types[:layer_count])

    config = CONFIG_MAPPING[raw_config["model_type"]].from_dict(
        raw_config, **extra_kwargs
    )
    normalize_config(config)
    return config


def load_causal_lm(
    model_path_or_id: str,
    *,
    torch_dtype,
    trust_remote_code: bool,
    **kwargs,
) -> torch.nn.Module:
    """Load a causal LM, using a normalized config for local checkpoints.

    When *model_path_or_id* is a directory that contains ``config.json``, the
    config is built with ``load_normalized_config`` and passed explicitly;
    otherwise ``config=None`` is passed to ``from_pretrained``.

    Args:
        model_path_or_id: Local directory or model identifier.
        torch_dtype: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
        trust_remote_code: Forwarded to the config loader and to
            ``from_pretrained``.
        **kwargs: Extra keyword arguments forwarded to ``from_pretrained``.

    Returns:
        The model returned by ``AutoModelForCausalLM.from_pretrained``.
    """
    local_config_file = os.path.join(model_path_or_id, "config.json")
    is_local_checkpoint = os.path.isdir(model_path_or_id) and os.path.isfile(
        local_config_file
    )
    config = (
        load_normalized_config(model_path_or_id, trust_remote_code)
        if is_local_checkpoint
        else None
    )
    return AutoModelForCausalLM.from_pretrained(
        model_path_or_id,
        config=config,
        torch_dtype=torch_dtype,
        trust_remote_code=trust_remote_code,
        **kwargs,
    )


def load_progressive_model(
    base_model_id: str,
    output_dir: str,
    cycle: Optional[int] = None,
    device: Optional[str] = None,
    dtype: str = "auto",
    trust_remote_code: bool = False,
    layer_path: Optional[str] = None,
) -> torch.nn.Module:
    """Rebuild the model as it stood after a given progressive pruning cycle.

    If ``<output_dir>/cycle_<cycle>/full_model`` exists (and ``cycle > 0``),
    that checkpoint is loaded directly. Otherwise the base model is loaded
    and cycles ``1..cycle`` are replayed in order: for each cycle the fused
    layer state is loaded into the layer at ``layer_merged``, the layer
    immediately after it is dropped, and the model config is decremented.

    Args:
        base_model_id: Identifier or path of the unpruned base model.
        output_dir: Directory holding ``progressive_metadata.json`` and the
            per-cycle ``cycle_<n>`` subdirectories.
        cycle: Cycle to reconstruct. ``None`` selects the ``num_progressive``
            value from the metadata; ``0`` yields the untouched base model.
        device: If truthy, the model is moved to this device before return.
        dtype: Dtype name resolved through ``get_dtype``.
        trust_remote_code: Forwarded to the model loaders.
        layer_path: Overrides the ``layer_path`` stored in the progressive
            metadata when locating the layer container.

    Returns:
        The reconstructed model.

    Raises:
        FileNotFoundError: If the progressive metadata, a cycle's metadata,
            or a cycle's fused layer file is missing.
        ValueError: If ``cycle`` is outside ``[0, num_progressive]``, or a
            cycle's layer index (or the following layer that must be
            dropped) does not exist in the current layer container.
    """
    meta = load_progressive_metadata(output_dir)
    num_cycles = int(meta.get("num_progressive", 0))
    if cycle is None:
        cycle = num_cycles
    if not 0 <= cycle <= num_cycles:
        raise ValueError(f"Cycle {cycle} is outside [0, {num_cycles}]")

    torch_dtype = get_dtype(dtype)

    # Fast path: a fully materialized checkpoint for the requested cycle.
    if cycle > 0:
        full_model_dir = os.path.join(output_dir, f"cycle_{cycle}", "full_model")
        if os.path.isdir(full_model_dir):
            model = load_causal_lm(
                full_model_dir,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )
            if device:
                model.to(device)
            return model

    # Slow path: start from the base model and replay every cycle in order.
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=torch_dtype,
        trust_remote_code=trust_remote_code,
    )
    active_layer_path = layer_path or meta.get("layer_path")
    parent, name, container = find_layer_container(model, active_layer_path)

    for idx in range(1, cycle + 1):
        cycle_dir = os.path.join(output_dir, f"cycle_{idx}")
        cycle_meta_path = os.path.join(cycle_dir, "cycle_metadata.json")
        if not os.path.exists(cycle_meta_path):
            raise FileNotFoundError(f"Missing cycle metadata at {cycle_meta_path}")
        with open(cycle_meta_path, "r", encoding="utf-8") as handle:
            cycle_meta = json.load(handle)

        layer_idx = int(cycle_meta["layer_merged"])
        fused_state = cycle_meta.get("fused_layer_state", "fused_layer.pt")
        fused_state_path = os.path.join(cycle_dir, fused_state)
        if not os.path.exists(fused_state_path):
            raise FileNotFoundError(f"Missing fused layer at {fused_state_path}")

        layers = list(container)
        if layer_idx < 0 or layer_idx >= len(layers):
            raise ValueError(
                f"Cycle {idx} layer index {layer_idx} out of range for {len(layers)} layers"
            )
        # The layer after the fused one is removed below, so it must exist
        # too; validate it here instead of passing a bad index to drop_layer.
        drop_idx = layer_idx + 1
        if drop_idx >= len(layers):
            raise ValueError(
                f"Cycle {idx} cannot drop layer {drop_idx}: "
                f"only {len(layers)} layers remain"
            )

        # The file is only ever used as a state dict, so restrict unpickling
        # to tensors and plain containers rather than arbitrary objects.
        state = torch.load(fused_state_path, map_location="cpu", weights_only=True)
        layers[layer_idx].load_state_dict(state)

        new_container = drop_layer(container, drop_idx)
        setattr(parent, name, new_container)
        decrement_config(model.config)

        # Later cycles index into the already-shrunk container.
        container = new_container

    if device:
        model.to(device)

    return model