#!/usr/bin/env python3
"""Utilities to reconstruct models from progressive pruning cycles."""
import json
import os
from typing import Optional
import torch

try:
    from transformers import AutoModelForCausalLM, PretrainedConfig
    from transformers.models.auto.configuration_auto import CONFIG_MAPPING
except Exception as exc:  # pragma: no cover - fail early with clear error
    raise SystemExit("transformers is required: pip install transformers") from exc

from fuse_layers_model import (
    decrement_config,
    drop_layer,
    find_layer_container,
    get_dtype,
    normalize_config,
)


def load_progressive_metadata(output_dir: str) -> dict:
    """Read progressive_metadata.json from a pruning run's output directory."""
    path = os.path.join(output_dir, "progressive_metadata.json")
    if not os.path.exists(path):
        raise FileNotFoundError(f"Missing progressive metadata at {path}")
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)
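

# NOTE: run-directory layout this module assumes (inferred from the fields it
# reads below, not a guaranteed contract):
#   <output_dir>/progressive_metadata.json      -> {"num_progressive": N, "layer_path": ...}
#   <output_dir>/cycle_<i>/cycle_metadata.json  -> {"layer_merged": idx, "fused_layer_state": ...}
#   <output_dir>/cycle_<i>/fused_layer.pt       -> state_dict of the fused layer
#   <output_dir>/cycle_<i>/full_model/          -> optional full checkpoint for that cycle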


def load_normalized_config(model_path: str, trust_remote_code: bool):
    """Load a checkpoint's config and normalize it before model construction."""
    config_dict, unused_kwargs = PretrainedConfig.get_config_dict(
        model_path,
        trust_remote_code=trust_remote_code,
    )
    # If layer_types is out of sync with num_hidden_layers (e.g. after layers
    # were pruned), truncate it to match so the config class accepts it.
    num_hidden_layers = config_dict.get("num_hidden_layers")
    layer_types = config_dict.get("layer_types")
    if (
        isinstance(num_hidden_layers, int)
        and num_hidden_layers >= 0
        and isinstance(layer_types, list)
        and len(layer_types) != num_hidden_layers
    ):
        config_dict["layer_types"] = list(layer_types[:num_hidden_layers])
    model_type = config_dict["model_type"]
    config_class = CONFIG_MAPPING[model_type]
    config = config_class.from_dict(config_dict, **unused_kwargs)
    normalize_config(config)
    return config


def load_causal_lm(
    model_path_or_id: str,
    *,
    torch_dtype,
    trust_remote_code: bool,
    **kwargs,
) -> torch.nn.Module:
    """Load a causal LM, normalizing the config first for local checkpoints."""
    config = None
    config_path = os.path.join(model_path_or_id, "config.json")
    if os.path.isdir(model_path_or_id) and os.path.isfile(config_path):
        config = load_normalized_config(model_path_or_id, trust_remote_code)
    return AutoModelForCausalLM.from_pretrained(
        model_path_or_id,
        config=config,
        torch_dtype=torch_dtype,
        trust_remote_code=trust_remote_code,
        **kwargs,
    )


def load_progressive_model(
    base_model_id: str,
    output_dir: str,
    cycle: Optional[int] = None,
    device: Optional[str] = None,
    dtype: str = "auto",
    trust_remote_code: bool = False,
    layer_path: Optional[str] = None,
) -> torch.nn.Module:
    """Reconstruct the model as it looked after `cycle` pruning cycles.

    If the requested cycle saved a full checkpoint it is loaded directly;
    otherwise the base model is loaded and each cycle's fused layer is
    replayed in order. cycle=0 returns the unpruned base model.
    """
    meta = load_progressive_metadata(output_dir)
    num_cycles = int(meta.get("num_progressive", 0))
    if cycle is None:
        cycle = num_cycles
    if cycle < 0 or cycle > num_cycles:
        raise ValueError(f"Cycle {cycle} is outside [0, {num_cycles}]")
    # Fast path: a fully materialised checkpoint exists for the requested cycle.
    if cycle > 0:
        full_model_dir = os.path.join(output_dir, f"cycle_{cycle}", "full_model")
        if os.path.isdir(full_model_dir):
            model = load_causal_lm(
                full_model_dir,
                torch_dtype=get_dtype(dtype),
                trust_remote_code=trust_remote_code,
            )
            if device:
                model.to(device)
            return model
    # Slow path: start from the base model and replay each recorded cycle.
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=get_dtype(dtype),
        trust_remote_code=trust_remote_code,
    )
    active_layer_path = layer_path or meta.get("layer_path")
    parent, name, container = find_layer_container(model, active_layer_path)
    for idx in range(1, cycle + 1):
        cycle_dir = os.path.join(output_dir, f"cycle_{idx}")
        cycle_meta_path = os.path.join(cycle_dir, "cycle_metadata.json")
        if not os.path.exists(cycle_meta_path):
            raise FileNotFoundError(f"Missing cycle metadata at {cycle_meta_path}")
        with open(cycle_meta_path, "r", encoding="utf-8") as handle:
            cycle_meta = json.load(handle)
        layer_idx = int(cycle_meta["layer_merged"])
        fused_state = cycle_meta.get("fused_layer_state", "fused_layer.pt")
        fused_state_path = os.path.join(cycle_dir, fused_state)
        if not os.path.exists(fused_state_path):
            raise FileNotFoundError(f"Missing fused layer at {fused_state_path}")
        layers = list(container)
        if layer_idx < 0 or layer_idx >= len(layers):
            raise ValueError(
                f"Cycle {idx} layer index {layer_idx} out of range for {len(layers)} layers"
            )
        # Load the fused weights into layer `layer_idx`, drop the layer above
        # it, and shrink the config by one hidden layer.
        state = torch.load(fused_state_path, map_location="cpu")
        layers[layer_idx].load_state_dict(state)
        new_container = drop_layer(container, layer_idx + 1)
        setattr(parent, name, new_container)
        decrement_config(model.config)
        container = new_container
    if device:
        model.to(device)
    return model
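

if __name__ == "__main__":
    # Minimal usage sketch. The model id and run directory are placeholders,
    # not artifacts shipped with this repository; point them at a real
    # pruning run before executing.
    model = load_progressive_model(
        "org/base-model",             # hypothetical base model id
        "runs/progressive_prune",     # hypothetical output_dir of a pruning run
        cycle=None,                   # None -> replay every recorded cycle
        device="cuda" if torch.cuda.is_available() else "cpu",
        dtype="auto",
    )
    print(f"Reconstructed model with {model.config.num_hidden_layers} layers")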