| |
| """Estimate Fisher-Barycentric Merge Cost (FBMC) for adjacent layers.""" |
|
|
| import argparse |
| import csv |
| import json |
| import os |
| from typing import Dict, List, Optional, Tuple |
|
|
| import torch |
|
|
| try: |
| from datasets import load_dataset |
| except Exception: |
| load_dataset = None |
|
|
| try: |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| except Exception as exc: |
| raise SystemExit("transformers is required: pip install transformers") from exc |
|
|
|
|
def _positive_int(value: str) -> int:
    """Argparse ``type=`` converter that accepts only integers >= 1.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not an integer or is < 1;
            argparse turns this into a usage error (exit code 2).
    """
    try:
        parsed = int(value)
    except ValueError as exc:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from exc
    if parsed < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {parsed}")
    return parsed


def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    ``--num_samples``, ``--seq_len`` and ``--batch_size`` are validated as
    positive integers here. Zero or negative values would otherwise only fail
    much later, after the model has been loaded: empty token chunks, a
    DataLoader error, or no calibration sequences at all.
    """
    parser = argparse.ArgumentParser(
        description="Compute FBMC for adjacent layers of a Hugging Face causal LM."
    )
    parser.add_argument("--model", required=True, help="HF model id or local path")
    parser.add_argument(
        "--dataset",
        action="append",
        default=[],
        help=(
            "HF dataset name (repeatable). Optional if using --text or --text_file."
        ),
    )
    parser.add_argument(
        "--dataset_config",
        action="append",
        default=[],
        help="Optional dataset config (repeatable or single shared config).",
    )
    parser.add_argument(
        "--dataset_split",
        default="train",
        help="Dataset split to use (default: train)",
    )
    parser.add_argument(
        "--dataset_text_field",
        default=None,
        help="Text field in dataset (default: auto-detect, applies to all datasets)",
    )
    parser.add_argument(
        "--text",
        action="append",
        default=[],
        help="Inline text samples (can pass multiple)",
    )
    parser.add_argument(
        "--text_file",
        default=None,
        help="Path to a text file for calibration data",
    )
    parser.add_argument(
        "--num_samples",
        type=_positive_int,
        default=128,
        help="Number of token sequences to use",
    )
    parser.add_argument(
        "--seq_len", type=_positive_int, default=256, help="Sequence length"
    )
    parser.add_argument(
        "--batch_size", type=_positive_int, default=2, help="Batch size"
    )
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device for model + compute",
    )
    parser.add_argument(
        "--dtype",
        default="auto",
        choices=["auto", "float32", "float16", "bfloat16"],
        help="Model dtype",
    )
    parser.add_argument(
        "--layer_path",
        default=None,
        help="Override layer attribute path (e.g., model.layers)",
    )
    parser.add_argument(
        "--fisher_mode",
        default="tensor",
        choices=["tensor", "param"],
        help="Fisher approximation granularity",
    )
    parser.add_argument("--eps", type=float, default=1e-8, help="Stability epsilon")
    parser.add_argument(
        "--output",
        default=None,
        help="Optional JSON output path",
    )
    parser.add_argument(
        "--output_csv",
        default=None,
        help="Optional CSV output path",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help="Allow custom model code from hub",
    )
    return parser.parse_args()
|
|
|
|
def resolve_attr(root: object, path: str) -> Optional[object]:
    """Walk a dotted attribute path (e.g. ``"model.layers"``) starting at ``root``.

    Returns the object at the end of the path, or ``None`` as soon as any
    segment is not an attribute of the current object.
    """
    absent = object()
    node = root
    for segment in path.split("."):
        node = getattr(node, segment, absent)
        if node is absent:
            return None
    return node
|
|
|
|
def find_layers(model, layer_path: Optional[str]) -> List[torch.nn.Module]:
    """Return the model's stack of transformer blocks as a list.

    Args:
        model: The loaded model object.
        layer_path: Optional dotted attribute path (e.g. ``"model.layers"``).
            When given, it is used exclusively. Otherwise a fixed list of
            common Hugging Face layouts is tried in order, and the first path
            that resolves to an iterable wins.

    Raises:
        ValueError: if ``layer_path`` is not found or does not resolve to an
            iterable, or if no candidate path resolves to an iterable.
    """
    if layer_path:
        layers = resolve_attr(model, layer_path)
        if layers is None:
            raise ValueError(f"layer_path '{layer_path}' not found on model")
        try:
            return list(layers)
        except TypeError as exc:
            # Typically the path points at one module instead of the container
            # (e.g. ModuleList) that holds the blocks.
            raise ValueError(
                f"layer_path '{layer_path}' resolved to a non-iterable "
                f"{type(layers).__name__}; point it at the container of layers"
            ) from exc

    # Common attribute layouts for decoder stacks in HF causal LMs.
    candidate_paths = (
        "model.layers",
        "model.decoder.layers",
        "transformer.h",
        "transformer.blocks",
        "gpt_neox.layers",
        "layers",
    )
    for path in candidate_paths:
        layers = resolve_attr(model, path)
        if layers is None:
            continue
        try:
            return list(layers)
        except TypeError:
            # Attribute exists but is not a container; keep looking.
            continue
    raise ValueError(
        "Could not locate transformer layers. Pass --layer_path explicitly."
    )
|
|
|
|
def guess_text_field(dataset) -> str:
    """Pick the dataset column to read calibration text from.

    Preference order:
      1. a column literally named ``"text"``;
      2. the first column whose declared feature dtype is a string type
         (``"string"`` / ``"large_string"``), when ``dataset.features`` is
         available;
      3. the first column.

    Column names come from ``dataset.column_names``, falling back to the keys
    of ``dataset.features``. Returns ``"text"`` when no names can be found,
    including when ``features`` is ``None``.
    """
    features = getattr(dataset, "features", None)
    names = list(getattr(dataset, "column_names", None) or [])
    if not names and features:
        names = list(features.keys())
    if not names:
        return "text"
    if "text" in names:
        return "text"
    if features:
        lookup = features.get if hasattr(features, "get") else (lambda _name: None)
        for name in names:
            # Skip numeric/label columns: load_texts keeps only str values, so
            # choosing one of those would silently yield no text at all.
            if getattr(lookup(name), "dtype", None) in ("string", "large_string"):
                return name
    return names[0]
|
|
|
|
def _normalize_config(config: Optional[str]) -> Optional[str]:
    """Turn CLI placeholders for "no config" into ``None``.

    ``"none"``, ``"null"`` and ``"-"`` (case-insensitive, surrounding
    whitespace ignored) map to ``None``. Any other value is returned unchanged,
    without stripping.
    """
    placeholder = config is not None and config.strip().lower() in {"none", "null", "-"}
    return None if (config is None or placeholder) else config
|
|
|
|
def _expand_dataset_configs(
    datasets: List[str], configs: List[str]
) -> List[Optional[str]]:
    """Align ``--dataset_config`` values with ``--dataset`` names.

    No configs gives ``None`` for every dataset. A single config shared by
    several datasets is broadcast to all of them. Otherwise the counts must
    match exactly, and each config is normalized via ``_normalize_config``.

    Raises:
        SystemExit: when the config count fits none of the rules above.
    """
    count = len(datasets)
    if not configs:
        return [None for _ in range(count)]
    if len(configs) == 1 and count > 1:
        shared = _normalize_config(configs[0])
        return [shared] * count
    if len(configs) != count:
        raise SystemExit(
            "Provide zero, one, or matching-count --dataset_config values."
        )
    return list(map(_normalize_config, configs))
|
|
|
|
def _sample_dataset_rows(
    dataset, target: int, seed: int
) -> List[Dict[str, object]]:
    """Return up to ``target`` rows from ``dataset``, shuffled when supported.

    Args:
        dataset: Any object with an optional ``shuffle(seed=...)`` method. If
            it has both ``__len__`` and ``select``, rows are taken with
            ``select(range(n))``. Otherwise it is simply iterated, which covers
            plain sequences and streaming/iterable datasets.
        target: Maximum number of rows to return; ``<= 0`` returns ``[]``.
        seed: Seed forwarded to ``shuffle``.
    """
    import itertools

    if target <= 0:
        return []
    # Shuffling is deliberately best-effort: objects without a usable
    # shuffle() (lists, iterators, some dataset types) keep their order.
    try:
        dataset = dataset.shuffle(seed=seed)
    except Exception:
        pass

    if hasattr(dataset, "__len__") and hasattr(dataset, "select"):
        limit = min(target, len(dataset))
        return list(dataset.select(range(limit)))

    # Unsized or non-selectable inputs: stop after `target` rows without
    # materializing the rest.
    return list(itertools.islice(dataset, target))
|
|
|
|
def load_texts(args: argparse.Namespace) -> List[str]:
    """Collect raw calibration strings from every configured source.

    Sources are appended in this order:
      1. non-blank lines of ``args.text_file`` (each stripped);
      2. non-empty ``args.text`` values;
      3. rows from each ``args.dataset``.

    For datasets, ``args.num_samples`` rows are split as evenly as possible
    across datasets; the first ``num_samples % len(datasets)`` datasets get one
    extra row. Dataset ``i`` is sampled with seed ``args.seed + i``. Only
    non-blank string values of the text field are kept, where the field is
    ``args.dataset_text_field`` or an auto-detected column. Dataset loading
    code from the hub is trusted only when ``args.trust_remote_code`` is set.

    Raises:
        SystemExit: if datasets are requested but the ``datasets`` package is
            unavailable, or the ``--dataset_config`` count does not fit.
    """
    texts: List[str] = []
    if args.text_file:
        with open(args.text_file, "r", encoding="utf-8") as handle:
            texts.extend(line.strip() for line in handle if line.strip())
    if args.text:
        texts.extend(t for t in args.text if t)

    if not args.dataset:
        return texts
    if load_dataset is None:
        raise SystemExit("datasets is required for --dataset")

    datasets = list(args.dataset)
    configs = _expand_dataset_configs(datasets, list(args.dataset_config))
    base, remainder = divmod(args.num_samples, len(datasets))
    # Previously trust_remote_code=True was hard-coded, which executed dataset
    # loading scripts even without --trust_remote_code. Only opt in when the
    # user asked; otherwise leave the library default in place.
    trust_remote_code = bool(getattr(args, "trust_remote_code", False))

    for idx, (dataset_name, config) in enumerate(zip(datasets, configs)):
        target = base + (1 if idx < remainder else 0)
        load_kwargs: Dict[str, object] = {"split": args.dataset_split}
        if trust_remote_code:
            load_kwargs["trust_remote_code"] = True
        dataset = load_dataset(dataset_name, config, **load_kwargs)
        rows = _sample_dataset_rows(dataset, target, args.seed + idx)
        text_field = args.dataset_text_field or guess_text_field(dataset)
        for row in rows:
            value = row.get(text_field, None) if isinstance(row, dict) else None
            if isinstance(value, str) and value.strip():
                texts.append(value)

    return texts
|
|
|
|
def build_token_chunks(
    texts: List[str], tokenizer, seq_len: int, num_samples: int
) -> List[torch.Tensor]:
    """Pack tokenized texts into fixed-length ``torch.long`` sequences.

    Each text is encoded with ``add_special_tokens=False``. The token ids of
    consecutive texts are concatenated with no separator and cut into chunks
    of exactly ``seq_len`` tokens. Packing stops once ``num_samples`` chunks
    exist; any trailing partial chunk is dropped.

    Raises:
        ValueError: if ``seq_len`` is not positive. Previously this silently
            produced empty chunks.
    """
    if seq_len <= 0:
        raise ValueError(f"seq_len must be positive, got {seq_len}")
    chunks: List[torch.Tensor] = []
    if num_samples <= 0:
        return chunks

    buffer: List[int] = []
    for text in texts:
        ids = tokenizer.encode(text, add_special_tokens=False)
        if not ids:
            continue
        buffer.extend(ids)
        start = 0
        while len(buffer) - start >= seq_len and len(chunks) < num_samples:
            chunks.append(torch.tensor(buffer[start:start + seq_len], dtype=torch.long))
            start += seq_len
        if len(chunks) >= num_samples:
            break
        # Trim consumed tokens once per text rather than re-slicing the
        # remaining buffer for every chunk (quadratic on long texts).
        del buffer[:start]
    return chunks
|
|
|
|
def get_dtype(dtype: str):
    """Translate the ``--dtype`` CLI choice into a ``torch.dtype``.

    ``"auto"`` maps to ``None``, which the caller passes straight through as
    ``torch_dtype``. ``"float16"`` and ``"bfloat16"`` map to the matching
    torch dtypes. Anything else, including ``"float32"``, maps to
    ``torch.float32``.
    """
    named = {
        "auto": None,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    return named.get(dtype, torch.float32)
|
|
|
|
def compute_fisher(
    model,
    layers: List[torch.nn.Module],
    dataloader,
    fisher_mode: str,
    device: str,
) -> Tuple[List[Dict[str, object]], int, List[Dict[str, int]]]:
    """Accumulate squared gradients for every parameter of ``layers``.

    Every model parameter is frozen first, then gradients are re-enabled only
    for ``layers``. This requires_grad change is left in place on return.
    For each batch the model runs in eval mode with ``labels=input_ids``,
    ``outputs.loss`` is back-propagated, and the elementwise squared gradient
    is added to a running sum. In ``"param"`` mode the sum is a float32 CPU
    tensor per parameter; in any other mode it is a Python float summed over
    all elements. Because the whole-batch loss is differentiated, this is a
    batch-level empirical-Fisher approximation, not a per-example one.

    Returns:
        ``(fisher_sums, num_batches, param_numels)``: per-layer dicts of raw
        (un-normalized) sums keyed by parameter name, the number of batches
        processed, and per-layer dicts of parameter element counts.

    Raises:
        RuntimeError: if ``dataloader`` yields no batches.
    """
    model.requires_grad_(False)
    for layer in layers:
        layer.requires_grad_(True)

    per_element = fisher_mode == "param"
    # Snapshot (name, param) pairs once; requires_grad is not changed below,
    # so re-walking named_parameters() every batch is unnecessary.
    tracked = [
        [(name, p) for name, p in layer.named_parameters() if p.requires_grad]
        for layer in layers
    ]

    def _zero_accumulator(p: torch.nn.Parameter) -> object:
        if per_element:
            return torch.zeros_like(p, dtype=torch.float32, device="cpu")
        return 0.0

    fisher_sums: List[Dict[str, object]] = [
        {name: _zero_accumulator(p) for name, p in pairs} for pairs in tracked
    ]
    param_numels: List[Dict[str, int]] = [
        {name: p.numel() for name, p in pairs} for pairs in tracked
    ]

    model.eval()
    num_batches = 0
    for batch in dataloader:
        input_ids = batch[0].to(device)
        loss = model(input_ids=input_ids, labels=input_ids).loss
        loss.backward()
        for sums, pairs in zip(fisher_sums, tracked):
            for name, p in pairs:
                grad = p.grad
                if grad is None:
                    continue
                squared = grad.detach().float().pow(2)
                if per_element:
                    sums[name] += squared.cpu()
                else:
                    sums[name] += float(squared.sum().item())
        model.zero_grad(set_to_none=True)
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError("No batches processed; check dataset or text inputs.")
    return fisher_sums, num_batches, param_numels
|
|
|
|
def compute_fbmc_costs(
    layers: List[torch.nn.Module],
    fisher_sums: List[Dict[str, object]],
    num_batches: int,
    param_numels: List[Dict[str, int]],
    fisher_mode: str,
    eps: float,
) -> List[Dict[str, object]]:
    """Score the Fisher-weighted cost of merging each adjacent layer pair.

    Only parameter names present in both layer ``i`` and layer ``i + 1``, with
    identical shapes, contribute. The contribution of each such parameter is
    ``0.5 * F_i * F_j / (F_i + F_j + eps) * (theta_i - theta_j) ** 2``, summed
    over elements.

    In ``"param"`` mode, ``F`` is the per-element Fisher sum divided by
    ``num_batches``. Otherwise ``F`` is a scalar: the scalar sum divided by
    ``num_batches * numel``. In that scalar mode a parameter whose
    denominator is exactly zero counts as matched but adds no cost.

    Returns:
        One dict per adjacent pair with keys ``layer_i``, ``layer_j``,
        ``fbmc``, ``matched_params`` and ``skipped_params``. Skipped params are
        names in layer ``i`` that are missing from layer ``i + 1`` or have a
        different shape there.
    """
    named = [dict(layer.named_parameters()) for layer in layers]
    results: List[Dict[str, object]] = []
    for left, (params_left, params_right) in enumerate(zip(named, named[1:])):
        right = left + 1
        total = 0.0
        matched = 0
        skipped = 0
        for name, theta_i in params_left.items():
            theta_j = params_right.get(name)
            if theta_j is None or theta_j.shape != theta_i.shape:
                skipped += 1
                continue
            matched += 1
            if fisher_mode == "param":
                f_i = fisher_sums[left][name] / num_batches
                f_j = fisher_sums[right][name] / num_batches
                delta = theta_i.detach().float().cpu() - theta_j.detach().float().cpu()
                weight = f_i * f_j / (f_i + f_j + eps)
                total += 0.5 * float((weight * delta * delta).sum().item())
                continue
            f_i = fisher_sums[left][name] / (num_batches * param_numels[left][name])
            f_j = fisher_sums[right][name] / (num_batches * param_numels[right][name])
            denom = f_i + f_j + eps
            if denom == 0:
                continue
            squared = (theta_i.detach().float() - theta_j.detach().float()).pow(2)
            total += 0.5 * (f_i * f_j / denom) * float(squared.sum().item())
        results.append(
            {
                "layer_i": left,
                "layer_j": right,
                "fbmc": total,
                "matched_params": matched,
                "skipped_params": skipped,
            }
        )
    return results
|
|
|
|
def _format_pair(item: Dict[str, object]) -> str:
    """Render one adjacent-pair result as a single human-readable line."""
    return (
        f"layers {item['layer_i']} & {item['layer_j']} -> "
        f"fbmc={item['fbmc']:.6e} "
        f"(matched={item['matched_params']}, skipped={item['skipped_params']})"
    )


def _write_json(path: str, payload: Dict[str, object]) -> None:
    """Write ``payload`` to ``path`` as indented JSON, creating parent dirs."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2)


def _write_csv(path: str, rows: List[Dict[str, object]]) -> None:
    """Write one CSV row per pair result in a fixed column order, creating parent dirs."""
    fieldnames = ["layer_i", "layer_j", "fbmc", "matched_params", "skipped_params"]
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        writer.writeheader()
        for item in rows:
            writer.writerow({key: item[key] for key in fieldnames})


def main() -> None:
    """Run the FBMC pipeline end to end from the command line.

    Steps: parse args; build calibration token chunks; load the model and
    locate its layers; accumulate Fisher sums; score every adjacent layer
    pair; print results in layer order and by ascending cost; optionally
    write JSON (``--output``) and CSV (``--output_csv``).

    Raises:
        SystemExit: if no calibration text is found, the text is too short for
            one sequence, or the model has fewer than two layers.
    """
    import math

    args = parse_args()
    torch.manual_seed(args.seed)

    # Build calibration data before loading the model, so that missing or
    # too-short text fails fast instead of after a potentially large
    # checkpoint load.
    tokenizer = AutoTokenizer.from_pretrained(
        args.model, trust_remote_code=args.trust_remote_code
    )
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    texts = load_texts(args)
    if not texts:
        raise SystemExit(
            "No calibration text found. Provide --dataset, --text, or --text_file."
        )
    chunks = build_token_chunks(texts, tokenizer, args.seq_len, args.num_samples)
    if not chunks:
        raise SystemExit("Not enough text to build token sequences.")
    dataloader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(torch.stack(chunks)),
        batch_size=args.batch_size,
        shuffle=False,
    )

    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=get_dtype(args.dtype),
        trust_remote_code=args.trust_remote_code,
    )
    layers = find_layers(model, args.layer_path)
    if len(layers) < 2:
        raise SystemExit("Model has fewer than 2 layers; cannot compute FBMC.")
    model.to(args.device)

    fisher_sums, num_batches, param_numels = compute_fisher(
        model,
        layers,
        dataloader,
        fisher_mode=args.fisher_mode,
        device=args.device,
    )
    costs = compute_fbmc_costs(
        layers,
        fisher_sums,
        num_batches,
        param_numels,
        fisher_mode=args.fisher_mode,
        eps=args.eps,
    )

    # NaN costs (e.g. from overflowing half-precision gradients) make a plain
    # sort's order arbitrary and could surface a NaN pair as "best"; push
    # them to the end instead.
    costs_sorted = sorted(
        costs, key=lambda item: (math.isnan(item["fbmc"]), item["fbmc"])
    )
    best = costs_sorted[0]

    print("FBMC results (layer order):")
    for item in costs:
        print(_format_pair(item))
    print("\nFBMC results (lowest cost first):")
    for item in costs_sorted:
        print(_format_pair(item))
    print(
        f"\nBest pair: layers {best['layer_i']} & {best['layer_j']} "
        f"(fbmc={best['fbmc']:.6e})"
    )

    if args.output:
        payload = {
            "model": args.model,
            "num_layers": len(layers),
            "fisher_mode": args.fisher_mode,
            "num_batches": num_batches,
            "num_sequences": len(chunks),
            "seq_len": args.seq_len,
            "best_pair": best,
            "pairs": costs_sorted,
        }
        _write_json(args.output, payload)
        print(f"\nWrote results to {args.output}")

    if args.output_csv:
        _write_csv(args.output_csv, costs_sorted)
        print(f"Wrote CSV results to {args.output_csv}")


if __name__ == "__main__":
    main()
|
|