# Provenance (residue of the Hugging Face file viewer this file was copied from):
# temp_ss/src/fbmc_metric.py, uploaded by LJYAI ("upload src"), commit 2c44909 (verified).
#!/usr/bin/env python3
"""Estimate Fisher-Barycentric Merge Cost (FBMC) for adjacent layers."""
import argparse
import csv
import json
import os
from typing import Dict, List, Optional, Tuple
import torch
try:
from datasets import load_dataset
except Exception: # pragma: no cover - optional dependency
load_dataset = None
try:
from transformers import AutoModelForCausalLM, AutoTokenizer
except Exception as exc: # pragma: no cover - fail early with clear error
raise SystemExit("transformers is required: pip install transformers") from exc
def parse_args() -> argparse.Namespace:
    """Build the FBMC command-line interface and parse ``sys.argv``.

    Returns:
        The parsed namespace. The repeatable options (``--dataset``,
        ``--dataset_config``, ``--text``) are lists, empty when unused.
    """
    cli = argparse.ArgumentParser(
        description="Compute FBMC for adjacent layers of a Hugging Face causal LM."
    )
    add = cli.add_argument

    # Model selection.
    add("--model", required=True, help="HF model id or local path")

    # Calibration data sources.
    add(
        "--dataset",
        action="append",
        default=[],
        help="HF dataset name (repeatable). Optional if using --text or --text_file.",
    )
    add(
        "--dataset_config",
        action="append",
        default=[],
        help="Optional dataset config (repeatable or single shared config).",
    )
    add(
        "--dataset_split",
        default="train",
        help="Dataset split to use (default: train)",
    )
    add(
        "--dataset_text_field",
        default=None,
        help="Text field in dataset (default: auto-detect, applies to all datasets)",
    )
    add(
        "--text",
        action="append",
        default=[],
        help="Inline text samples (can pass multiple)",
    )
    add(
        "--text_file",
        default=None,
        help="Path to a text file for calibration data",
    )

    # Tokenisation and batching.
    add(
        "--num_samples",
        type=int,
        default=128,
        help="Number of token sequences to use",
    )
    add("--seq_len", type=int, default=256, help="Sequence length")
    add("--batch_size", type=int, default=2, help="Batch size")

    # Execution environment.
    add(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device for model + compute",
    )
    add(
        "--dtype",
        default="auto",
        choices=["auto", "float32", "float16", "bfloat16"],
        help="Model dtype",
    )
    add(
        "--layer_path",
        default=None,
        help="Override layer attribute path (e.g., model.layers)",
    )

    # FBMC computation.
    add(
        "--fisher_mode",
        default="tensor",
        choices=["tensor", "param"],
        help="Fisher approximation granularity",
    )
    add("--eps", type=float, default=1e-8, help="Stability epsilon")

    # Outputs and reproducibility.
    add("--output", default=None, help="Optional JSON output path")
    add("--output_csv", default=None, help="Optional CSV output path")
    add("--seed", type=int, default=0, help="Random seed")
    add(
        "--trust_remote_code",
        action="store_true",
        help="Allow custom model code from hub",
    )

    return cli.parse_args()
def resolve_attr(root: object, path: str) -> Optional[object]:
    """Follow a dotted attribute ``path`` starting at ``root``.

    Returns the final attribute's value, or ``None`` as soon as any segment is
    missing (an empty segment, e.g. from ``"a."`` or ``""``, counts as missing).
    An attribute whose value is itself ``None`` is also reported as ``None``.
    """
    head, separator, rest = path.partition(".")
    if not hasattr(root, head):
        return None
    child = getattr(root, head)
    # Recurse only while there is another dotted segment to resolve.
    return resolve_attr(child, rest) if separator else child
def find_layers(model, layer_path: Optional[str]) -> List[torch.nn.Module]:
    """Return the model's transformer blocks as a list.

    Args:
        model: The loaded model.
        layer_path: Dotted attribute path to the layer container. When given,
            it is used exclusively; otherwise a fixed list of common container
            paths is tried in order and the first iterable match wins.

    Raises:
        ValueError: if ``layer_path`` does not resolve, or no known container
            is found on the model.
    """
    if layer_path:
        explicit = resolve_attr(model, layer_path)
        if explicit is None:
            raise ValueError(f"layer_path '{layer_path}' not found on model")
        return list(explicit)

    # Common decoder-only layer containers, tried in order. Add more if needed.
    known_containers = (
        "model.layers",  # LLaMA, Mistral, Qwen2, Gemma
        "model.decoder.layers",  # OPT
        "transformer.h",  # GPT-2, GPT-J, Bloom, Falcon
        "transformer.blocks",  # MPT
        "gpt_neox.layers",  # GPT-NeoX
        "layers",  # fallback
    )
    for candidate in known_containers:
        container = resolve_attr(model, candidate)
        if container is None:
            continue
        try:
            return list(container)
        except TypeError:
            # Resolved to something that is not iterable; keep searching.
            pass
    raise ValueError(
        "Could not locate transformer layers. Pass --layer_path explicitly."
    )
def guess_text_field(dataset) -> str:
    """Pick the dataset column that most likely holds raw text.

    Column names are taken from ``dataset.column_names`` when that attribute is
    present and non-empty, otherwise from the keys of ``dataset.features``.
    A column literally named ``"text"`` is preferred; failing that, the first
    column is returned. When neither attribute yields any names (missing,
    ``None`` or empty), ``"text"`` is returned as a last resort.

    Args:
        dataset: A Hugging Face ``Dataset``/``IterableDataset`` or any object
            exposing ``column_names`` and/or ``features``.

    Returns:
        The chosen column name.
    """
    column_names = getattr(dataset, "column_names", None)
    if column_names:
        return "text" if "text" in column_names else column_names[0]

    # ``features`` may exist but be None (e.g. a streaming dataset whose schema
    # is not known up front); treat that like "no columns" instead of crashing.
    features = getattr(dataset, "features", None)
    names = list(features.keys()) if features else []
    if "text" in names:
        return "text"
    return names[0] if names else "text"
def _normalize_config(config: Optional[str]) -> Optional[str]:
if config is None:
return None
if config.strip().lower() in {"none", "null", "-"}:
return None
return config
def _expand_dataset_configs(
datasets: List[str], configs: List[str]
) -> List[Optional[str]]:
if not configs:
return [None] * len(datasets)
if len(configs) == 1 and len(datasets) > 1:
return [_normalize_config(configs[0])] * len(datasets)
if len(configs) != len(datasets):
raise SystemExit(
"Provide zero, one, or matching-count --dataset_config values."
)
return [_normalize_config(cfg) for cfg in configs]
def _sample_dataset_rows(
dataset, target: int, seed: int
) -> List[Dict[str, object]]:
if target <= 0:
return []
try:
dataset = dataset.shuffle(seed=seed)
except Exception:
pass
if hasattr(dataset, "__len__"):
limit = min(target, len(dataset))
dataset = dataset.select(range(limit))
return [row for row in dataset]
# IterableDataset fallback.
rows = []
for row in dataset:
rows.append(row)
if len(rows) >= target:
break
return rows
def load_texts(args: argparse.Namespace) -> List[str]:
    """Collect calibration texts from every source given on the command line.

    Sources are read in this order:

    1. ``--text_file``: one sample per non-blank line, stripped.
    2. ``--text``: every non-empty inline value.
    3. each ``--dataset``: ``--num_samples`` rows are split as evenly as
       possible across datasets (earlier datasets absorb the remainder). Only
       non-blank string values of the text field are kept.

    Hub-hosted dataset loading scripts are only allowed to run when
    ``--trust_remote_code`` was passed. This is the same opt-in the model
    loader honours. Previously ``trust_remote_code=True`` was hard-coded, so
    arbitrary remote code ran without consent.

    Raises:
        SystemExit: if ``--dataset`` is used without the ``datasets`` package,
            or the ``--dataset_config`` count does not fit the dataset count.
    """
    texts: List[str] = []
    if args.text_file:
        with open(args.text_file, "r", encoding="utf-8") as handle:
            texts.extend(line.strip() for line in handle if line.strip())
    if args.text:
        texts.extend(t for t in args.text if t)
    if not args.dataset:
        return texts

    if load_dataset is None:
        raise SystemExit("datasets is required for --dataset")
    datasets = list(args.dataset)
    configs = _expand_dataset_configs(datasets, list(args.dataset_config))
    base, remainder = divmod(args.num_samples, len(datasets))
    # Only forward the kwarg when the user opted in, so library defaults apply
    # otherwise (script-based datasets then fail with a clear message).
    load_kwargs = (
        {"trust_remote_code": True}
        if getattr(args, "trust_remote_code", False)
        else {}
    )
    for idx, (dataset_name, config) in enumerate(zip(datasets, configs)):
        target = base + (1 if idx < remainder else 0)
        dataset = load_dataset(
            dataset_name,
            config,
            split=args.dataset_split,
            **load_kwargs,
        )
        # Offset the seed per dataset so each gets an independent shuffle.
        rows = _sample_dataset_rows(dataset, target, args.seed + idx)
        text_field = args.dataset_text_field or guess_text_field(dataset)
        for row in rows:
            value = row.get(text_field) if isinstance(row, dict) else None
            if isinstance(value, str) and value.strip():
                texts.append(value)
    return texts
def build_token_chunks(
    texts: List[str], tokenizer, seq_len: int, num_samples: int
) -> List[torch.Tensor]:
    """Pack tokenized texts into fixed-length ``torch.long`` sequences.

    Each text is encoded with ``add_special_tokens=False`` and appended to one
    running token stream. No separator token is inserted between texts. The
    stream is cut into consecutive, non-overlapping chunks of exactly
    ``seq_len`` tokens. Tokens left after the last full chunk are dropped.

    Args:
        texts: Raw calibration strings; texts that encode to nothing are skipped.
        tokenizer: Object with ``encode(text, add_special_tokens=False)``
            returning a list of token ids.
        seq_len: Tokens per chunk; must be positive.
        num_samples: Maximum number of chunks to return.

    Returns:
        Up to ``num_samples`` 1-D tensors of length ``seq_len``.

    Raises:
        ValueError: if ``seq_len`` is not positive. Previously such values
            silently produced empty or mis-sized chunks.
    """
    if seq_len <= 0:
        raise ValueError(f"seq_len must be a positive integer, got {seq_len}")
    chunks: List[torch.Tensor] = []
    if num_samples <= 0:
        return chunks
    buffer: List[int] = []
    for text in texts:
        ids = tokenizer.encode(text, add_special_tokens=False)
        if not ids:
            continue
        buffer.extend(ids)
        start = 0
        while len(buffer) - start >= seq_len and len(chunks) < num_samples:
            chunks.append(
                torch.tensor(buffer[start:start + seq_len], dtype=torch.long)
            )
            start += seq_len
        # Compact once per text instead of re-slicing the whole buffer for
        # every chunk, which was quadratic for long texts.
        del buffer[:start]
        if len(chunks) >= num_samples:
            break
    return chunks
def get_dtype(dtype: str):
    """Translate the ``--dtype`` CLI choice into a torch dtype.

    ``"auto"`` maps to ``None``, which leaves the choice to the model loader
    (the resulting dtype depends on the loader's default; confirm for your
    transformers version). ``"float16"`` and ``"bfloat16"`` map to their torch
    dtypes. Anything else, including ``"float32"``, yields ``torch.float32``.
    """
    if dtype == "auto":
        return None
    reduced_precision = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    return reduced_precision.get(dtype, torch.float32)
def compute_fisher(
    model,
    layers: List[torch.nn.Module],
    dataloader,
    fisher_mode: str,
    device: str,
) -> Tuple[List[Dict[str, object]], int, List[Dict[str, int]]]:
    """Accumulate squared gradients of each layer's parameters over a dataset.

    Every model parameter is frozen, then gradients are re-enabled only for the
    parameters of ``layers``; they are left enabled on return. The model is put
    in eval mode. For each batch, ``batch[0]`` is moved to ``device`` and fed
    as both ``input_ids`` and ``labels``. The returned ``.loss`` is
    back-propagated, and the squared gradient (in float32) of every layer
    parameter is added to a running sum.

    Args:
        model: Module returning an object with ``.loss`` when called with
            ``input_ids`` and ``labels``.
        layers: The layer modules whose parameters are tracked.
        dataloader: Iterable of batches whose first element is the token ids.
        fisher_mode: ``"param"`` keeps a per-element float32 CPU tensor per
            parameter; any other value keeps one Python float per parameter.
        device: Device the input ids are moved to.

    Returns:
        ``(fisher_sums, num_batches, param_numels)``. Sums and element counts
        are per layer, keyed by parameter name.

    Raises:
        RuntimeError: if the dataloader yields no batches.
    """
    # Restrict autograd to the layer stack.
    model.requires_grad_(False)
    for layer in layers:
        layer.requires_grad_(True)

    per_element = fisher_mode == "param"
    fisher_sums: List[Dict[str, object]] = []
    param_numels: List[Dict[str, int]] = []
    for layer in layers:
        tracked = [(name, p) for name, p in layer.named_parameters() if p.requires_grad]
        if per_element:
            sums: Dict[str, object] = {
                name: torch.zeros_like(p, dtype=torch.float32, device="cpu")
                for name, p in tracked
            }
        else:
            sums = {name: 0.0 for name, _ in tracked}
        fisher_sums.append(sums)
        param_numels.append({name: p.numel() for name, p in tracked})

    model.eval()
    batches_seen = 0
    for batch in dataloader:
        token_ids = batch[0].to(device)
        loss = model(input_ids=token_ids, labels=token_ids).loss
        loss.backward()
        for layer, sums in zip(layers, fisher_sums):
            for name, p in layer.named_parameters():
                if not p.requires_grad or p.grad is None:
                    continue
                squared = p.grad.detach().float().pow(2)
                if per_element:
                    sums[name] += squared.cpu()
                else:
                    sums[name] += float(squared.sum().item())
        model.zero_grad(set_to_none=True)
        batches_seen += 1

    if batches_seen == 0:
        raise RuntimeError("No batches processed; check dataset or text inputs.")
    return fisher_sums, batches_seen, param_numels
def compute_fbmc_costs(
    layers: List[torch.nn.Module],
    fisher_sums: List[Dict[str, object]],
    num_batches: int,
    param_numels: List[Dict[str, int]],
    fisher_mode: str,
    eps: float,
) -> List[Dict[str, object]]:
    """Score each adjacent layer pair ``(idx, idx + 1)`` by its Fisher-weighted merge cost.

    Parameters present in both layers under the same name and shape are
    compared. Each one adds the following, summed over elements::

        0.5 * F_i * F_j / (F_i + F_j + eps) * (theta_i - theta_j) ** 2

    ``F`` depends on ``fisher_mode``:

    - ``"param"``: the per-element sum divided by ``num_batches``.
    - any other mode: one scalar per tensor, the sum divided by
      ``num_batches * param_numels``.

    Where the denominator is exactly zero (both Fisher values zero and
    ``eps == 0``), the contribution is zero in both modes. Previously
    ``"param"`` mode produced ``0/0 = NaN`` there and poisoned the pair's total.

    Args:
        layers: Ordered layer modules; parameters are matched by name.
        fisher_sums: Per-layer squared-gradient sums keyed by parameter name.
        num_batches: Number of batches the sums were accumulated over.
        param_numels: Per-layer element counts keyed by parameter name.
        fisher_mode: ``"param"`` for per-element Fisher, else per-tensor.
        eps: Added to the denominator for numerical stability.

    Returns:
        One dict per adjacent pair, in layer order, with keys ``layer_i``,
        ``layer_j``, ``fbmc``, ``matched_params`` and ``skipped_params``.
        A parameter is skipped when it is missing from the next layer or has
        a different shape there.
    """
    layer_params: List[Dict[str, torch.nn.Parameter]] = [
        dict(layer.named_parameters()) for layer in layers
    ]
    results: List[Dict[str, object]] = []
    for idx in range(len(layers) - 1):
        params_i = layer_params[idx]
        params_j = layer_params[idx + 1]
        cost = 0.0
        matched = 0
        skipped = 0
        for name, param_i in params_i.items():
            param_j = params_j.get(name)
            if param_j is None or param_j.shape != param_i.shape:
                skipped += 1
                continue
            matched += 1
            if fisher_mode == "param":
                fisher_i = fisher_sums[idx][name] / num_batches
                fisher_j = fisher_sums[idx + 1][name] / num_batches
                diff = param_i.detach().float().cpu() - param_j.detach().float().cpu()
                denom = fisher_i + fisher_j + eps
                # Zero the weight wherever the denominator vanishes, mirroring
                # the tensor-mode guard below; a safe divisor avoids 0/0 there.
                nonzero = denom != 0
                safe_denom = torch.where(nonzero, denom, torch.ones_like(denom))
                weight = torch.where(
                    nonzero,
                    fisher_i * fisher_j / safe_denom,
                    torch.zeros_like(denom),
                )
                cost += 0.5 * float((weight * diff * diff).sum().item())
            else:
                fisher_i = fisher_sums[idx][name] / (
                    num_batches * param_numels[idx][name]
                )
                fisher_j = fisher_sums[idx + 1][name] / (
                    num_batches * param_numels[idx + 1][name]
                )
                denom = fisher_i + fisher_j + eps
                if denom == 0:
                    continue
                diff_sq = (param_i.detach().float() - param_j.detach().float()).pow(2)
                cost += 0.5 * (fisher_i * fisher_j / denom) * float(
                    diff_sq.sum().item()
                )
        results.append(
            {
                "layer_i": idx,
                "layer_j": idx + 1,
                "fbmc": cost,
                "matched_params": matched,
                "skipped_params": skipped,
            }
        )
    return results
def _print_pairs(header: str, pairs: List[Dict[str, object]]) -> None:
    """Print ``header`` followed by one summary line per layer pair."""
    print(header)
    for pair in pairs:
        print(
            f"layers {pair['layer_i']} & {pair['layer_j']} -> "
            f"fbmc={pair['fbmc']:.6e} "
            f"(matched={pair['matched_params']}, skipped={pair['skipped_params']})"
        )


def main() -> None:
    """Command-line entry point.

    Steps:

    1. Load the model and tokenizer.
    2. Locate the layer stack.
    3. Build fixed-length token chunks from the calibration text.
    4. Accumulate squared gradients and score every adjacent layer pair.
    5. Print the pairs in layer order, then by ascending cost, then the best
       pair.
    6. Optionally write JSON (``--output``) and CSV (``--output_csv``) results,
       both sorted by ascending cost.

    Raises:
        SystemExit: with fewer than two layers, no calibration text, or too
            little text to fill a single sequence.
    """
    args = parse_args()
    torch.manual_seed(args.seed)

    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=get_dtype(args.dtype),
        trust_remote_code=args.trust_remote_code,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.model, trust_remote_code=args.trust_remote_code
    )
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    layers = find_layers(model, args.layer_path)
    if len(layers) < 2:
        raise SystemExit("Model has fewer than 2 layers; cannot compute FBMC.")

    texts = load_texts(args)
    if not texts:
        raise SystemExit(
            "No calibration text found. Provide --dataset, --text, or --text_file."
        )
    chunks = build_token_chunks(texts, tokenizer, args.seq_len, args.num_samples)
    if not chunks:
        raise SystemExit("Not enough text to build token sequences.")
    loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(torch.stack(chunks)),
        batch_size=args.batch_size,
        shuffle=False,
    )

    model.to(args.device)
    fisher_sums, num_batches, param_numels = compute_fisher(
        model,
        layers,
        loader,
        fisher_mode=args.fisher_mode,
        device=args.device,
    )
    costs = compute_fbmc_costs(
        layers,
        fisher_sums,
        num_batches,
        param_numels,
        fisher_mode=args.fisher_mode,
        eps=args.eps,
    )
    ranked = sorted(costs, key=lambda pair: pair["fbmc"])
    best = ranked[0]

    _print_pairs("FBMC results (layer order):", costs)
    _print_pairs("\nFBMC results (lowest cost first):", ranked)
    print(
        f"\nBest pair: layers {best['layer_i']} & {best['layer_j']} "
        f"(fbmc={best['fbmc']:.6e})"
    )

    if args.output:
        payload = {
            "model": args.model,
            "num_layers": len(layers),
            "fisher_mode": args.fisher_mode,
            "num_batches": num_batches,
            "num_sequences": len(chunks),
            "seq_len": args.seq_len,
            "best_pair": best,
            "pairs": ranked,
        }
        os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
        with open(args.output, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, indent=2)
        print(f"\nWrote results to {args.output}")

    if args.output_csv:
        os.makedirs(os.path.dirname(args.output_csv) or ".", exist_ok=True)
        csv_fields = [
            "layer_i",
            "layer_j",
            "fbmc",
            "matched_params",
            "skipped_params",
        ]
        with open(args.output_csv, "w", encoding="utf-8", newline="") as handle:
            writer = csv.DictWriter(handle, fieldnames=csv_fields)
            writer.writeheader()
            writer.writerows(
                {key: pair[key] for key in csv_fields} for pair in ranked
            )
        print(f"Wrote CSV results to {args.output_csv}")
if __name__ == "__main__":
    # Run the CLI only when executed directly, not when this module is imported.
    main()