#!/usr/bin/env python3
"""Estimate Fisher-Barycentric Merge Cost (FBMC) for adjacent layers.

Loads a Hugging Face causal LM, builds calibration token sequences from
datasets / inline text / a text file, accumulates a diagonal (or
per-tensor scalar) empirical Fisher from squared gradients of the LM
loss, and scores each adjacent layer pair with a Fisher-weighted
squared-difference cost. Lower cost suggests the pair is cheaper to merge.
"""
import argparse
import csv
import json
import os
from typing import Dict, List, Optional, Tuple

import torch

try:
    from datasets import load_dataset
except Exception:  # pragma: no cover - optional dependency
    load_dataset = None

try:
    from transformers import AutoModelForCausalLM, AutoTokenizer
except Exception as exc:  # pragma: no cover - fail early with clear error
    raise SystemExit("transformers is required: pip install transformers") from exc


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the FBMC estimator."""
    parser = argparse.ArgumentParser(
        description="Compute FBMC for adjacent layers of a Hugging Face causal LM."
    )
    parser.add_argument("--model", required=True, help="HF model id or local path")
    parser.add_argument(
        "--dataset",
        action="append",
        default=[],
        help=(
            "HF dataset name (repeatable). Optional if using --text or --text_file."
        ),
    )
    parser.add_argument(
        "--dataset_config",
        action="append",
        default=[],
        help="Optional dataset config (repeatable or single shared config).",
    )
    parser.add_argument(
        "--dataset_split",
        default="train",
        help="Dataset split to use (default: train)",
    )
    parser.add_argument(
        "--dataset_text_field",
        default=None,
        help="Text field in dataset (default: auto-detect, applies to all datasets)",
    )
    parser.add_argument(
        "--text",
        action="append",
        default=[],
        help="Inline text samples (can pass multiple)",
    )
    parser.add_argument(
        "--text_file",
        default=None,
        help="Path to a text file for calibration data",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=128,
        help="Number of token sequences to use",
    )
    parser.add_argument("--seq_len", type=int, default=256, help="Sequence length")
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size")
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device for model + compute",
    )
    parser.add_argument(
        "--dtype",
        default="auto",
        choices=["auto", "float32", "float16", "bfloat16"],
        help="Model dtype",
    )
    parser.add_argument(
        "--layer_path",
        default=None,
        help="Override layer attribute path (e.g., model.layers)",
    )
    parser.add_argument(
        "--fisher_mode",
        default="tensor",
        choices=["tensor", "param"],
        help="Fisher approximation granularity",
    )
    parser.add_argument("--eps", type=float, default=1e-8, help="Stability epsilon")
    parser.add_argument(
        "--output",
        default=None,
        help="Optional JSON output path",
    )
    parser.add_argument(
        "--output_csv",
        default=None,
        help="Optional CSV output path",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help="Allow custom model code from hub",
    )
    return parser.parse_args()


def resolve_attr(root: object, path: str) -> Optional[object]:
    """Walk a dotted attribute path from *root*; return None if any hop is missing."""
    cur = root
    for part in path.split("."):
        if not hasattr(cur, part):
            return None
        cur = getattr(cur, part)
    return cur


def find_layers(model, layer_path: Optional[str]) -> List[torch.nn.Module]:
    """Locate the list of transformer decoder layers on *model*.

    Tries an explicit --layer_path first, then a set of well-known
    attribute paths for common decoder-only architectures.

    Raises:
        ValueError: if no layer container can be found.
    """
    if layer_path:
        layers = resolve_attr(model, layer_path)
        if layers is None:
            raise ValueError(f"layer_path '{layer_path}' not found on model")
        return list(layers)
    # Common decoder-only layer containers. Add more if needed.
    candidate_paths = [
        "model.layers",  # LLaMA, Mistral, Qwen2, Gemma
        "model.decoder.layers",  # OPT
        "transformer.h",  # GPT-2, GPT-J, Bloom, Falcon
        "transformer.blocks",  # MPT
        "gpt_neox.layers",  # GPT-NeoX
        "layers",  # fallback
    ]
    for path in candidate_paths:
        layers = resolve_attr(model, path)
        if layers is not None:
            try:
                return list(layers)
            except TypeError:
                # Attribute exists but is not iterable; keep searching.
                continue
    raise ValueError(
        "Could not locate transformer layers. Pass --layer_path explicitly."
    )


def guess_text_field(dataset) -> str:
    """Best-effort guess of the text column name for a (possibly iterable) dataset."""
    if hasattr(dataset, "column_names") and dataset.column_names:
        if "text" in dataset.column_names:
            return "text"
        return dataset.column_names[0]
    if hasattr(dataset, "features"):
        # IterableDataset exposes features instead of column_names.
        names = list(dataset.features.keys())
        if "text" in names:
            return "text"
        if names:
            return names[0]
    return "text"


def _normalize_config(config: Optional[str]) -> Optional[str]:
    """Map placeholder config strings ("none", "null", "-") to None."""
    if config is None:
        return None
    if config.strip().lower() in {"none", "null", "-"}:
        return None
    return config


def _expand_dataset_configs(
    datasets: List[str], configs: List[str]
) -> List[Optional[str]]:
    """Align --dataset_config values with --dataset values.

    Zero configs -> all None; one config -> shared by every dataset;
    otherwise counts must match exactly.
    """
    if not configs:
        return [None] * len(datasets)
    if len(configs) == 1 and len(datasets) > 1:
        return [_normalize_config(configs[0])] * len(datasets)
    if len(configs) != len(datasets):
        raise SystemExit(
            "Provide zero, one, or matching-count --dataset_config values."
        )
    return [_normalize_config(cfg) for cfg in configs]


def _sample_dataset_rows(
    dataset, target: int, seed: int
) -> List[Dict[str, object]]:
    """Draw up to *target* rows from a map-style or iterable HF dataset."""
    if target <= 0:
        return []
    try:
        dataset = dataset.shuffle(seed=seed)
    except Exception:
        # Some datasets (e.g. certain streaming setups) cannot shuffle; best effort.
        pass
    if hasattr(dataset, "__len__"):
        limit = min(target, len(dataset))
        dataset = dataset.select(range(limit))
        return [row for row in dataset]
    # IterableDataset fallback: take rows until the target is reached.
    rows = []
    for row in dataset:
        rows.append(row)
        if len(rows) >= target:
            break
    return rows


def load_texts(args: argparse.Namespace) -> List[str]:
    """Collect calibration texts from --text_file, --text, and --dataset sources.

    Dataset rows are split evenly across the requested datasets (earlier
    datasets absorb the remainder). Empty strings are dropped.
    """
    texts: List[str] = []
    if args.text_file:
        with open(args.text_file, "r", encoding="utf-8") as handle:
            texts.extend([line.strip() for line in handle if line.strip()])
    if args.text:
        texts.extend([t for t in args.text if t])
    if args.dataset:
        if load_dataset is None:
            raise SystemExit("datasets is required for --dataset")
        datasets = list(args.dataset)
        configs = _expand_dataset_configs(datasets, list(args.dataset_config))
        num_datasets = len(datasets)
        base = args.num_samples // num_datasets
        remainder = args.num_samples % num_datasets
        for idx, (dataset_name, config) in enumerate(zip(datasets, configs)):
            target = base + (1 if idx < remainder else 0)
            # Fix: honor the CLI opt-in instead of always trusting remote
            # dataset code (was hard-coded to True).
            dataset = load_dataset(
                dataset_name,
                config,
                split=args.dataset_split,
                trust_remote_code=args.trust_remote_code,
            )
            rows = _sample_dataset_rows(dataset, target, args.seed + idx)
            text_field = args.dataset_text_field or guess_text_field(dataset)
            for row in rows:
                value = row.get(text_field, None) if isinstance(row, dict) else None
                if isinstance(value, str) and value.strip():
                    texts.append(value)
    return texts


def build_token_chunks(
    texts: List[str], tokenizer, seq_len: int, num_samples: int
) -> List[torch.Tensor]:
    """Pack tokenized texts into up to *num_samples* fixed-length token chunks.

    Texts are concatenated into a rolling buffer so short texts combine
    into full sequences; any trailing partial chunk is discarded.
    """
    chunks: List[torch.Tensor] = []
    buffer: List[int] = []
    for text in texts:
        ids = tokenizer.encode(text, add_special_tokens=False)
        if not ids:
            continue
        buffer.extend(ids)
        while len(buffer) >= seq_len and len(chunks) < num_samples:
            chunk = buffer[:seq_len]
            buffer = buffer[seq_len:]
            chunks.append(torch.tensor(chunk, dtype=torch.long))
        if len(chunks) >= num_samples:
            break
    return chunks


def get_dtype(dtype: str):
    """Translate the --dtype choice into a torch dtype (None means 'auto')."""
    if dtype == "auto":
        return None
    if dtype == "float16":
        return torch.float16
    if dtype == "bfloat16":
        return torch.bfloat16
    return torch.float32


def compute_fisher(
    model,
    layers: List[torch.nn.Module],
    dataloader,
    fisher_mode: str,
    device: str,
) -> Tuple[List[Dict[str, object]], int, List[Dict[str, int]]]:
    """Accumulate empirical Fisher information for each layer's parameters.

    Runs the LM loss over *dataloader* and sums squared gradients. In
    "param" mode each entry is a CPU float32 tensor shaped like the
    parameter (diagonal Fisher); in "tensor" mode it is a scalar sum per
    parameter tensor.

    Returns:
        (per-layer {param name -> fisher sum}, number of batches processed,
         per-layer {param name -> numel}).

    Raises:
        RuntimeError: if the dataloader yields no batches.
    """
    # Only compute grads for layer params.
    for param in model.parameters():
        param.requires_grad_(False)
    for layer in layers:
        for param in layer.parameters():
            param.requires_grad_(True)

    fisher_sums: List[Dict[str, object]] = []
    param_numels: List[Dict[str, int]] = []
    for layer in layers:
        layer_sums: Dict[str, object] = {}
        layer_numels: Dict[str, int] = {}
        for name, param in layer.named_parameters():
            if not param.requires_grad:
                continue
            if fisher_mode == "param":
                layer_sums[name] = torch.zeros_like(
                    param, dtype=torch.float32, device="cpu"
                )
            else:
                layer_sums[name] = 0.0
            layer_numels[name] = param.numel()
        fisher_sums.append(layer_sums)
        param_numels.append(layer_numels)

    num_batches = 0
    model.eval()
    for batch in dataloader:
        input_ids = batch[0].to(device)
        # Causal LM loss with labels == inputs (standard next-token objective).
        outputs = model(input_ids=input_ids, labels=input_ids)
        loss = outputs.loss
        loss.backward()
        for layer_idx, layer in enumerate(layers):
            layer_sums = fisher_sums[layer_idx]
            for name, param in layer.named_parameters():
                if not param.requires_grad:
                    continue
                if param.grad is None:
                    continue
                grad_sq = param.grad.detach().float().pow(2)
                if fisher_mode == "param":
                    layer_sums[name] += grad_sq.cpu()
                else:
                    layer_sums[name] += float(grad_sq.sum().item())
        model.zero_grad(set_to_none=True)
        num_batches += 1
    if num_batches == 0:
        raise RuntimeError("No batches processed; check dataset or text inputs.")
    return fisher_sums, num_batches, param_numels


def compute_fbmc_costs(
    layers: List[torch.nn.Module],
    fisher_sums: List[Dict[str, object]],
    num_batches: int,
    param_numels: List[Dict[str, int]],
    fisher_mode: str,
    eps: float,
) -> List[Dict[str, object]]:
    """Score every adjacent layer pair with the Fisher-weighted merge cost.

    For matching parameter tensors (same name and shape) in layers i and
    i+1 the cost term is 0.5 * (F_i * F_j / (F_i + F_j + eps)) * (w_i - w_j)^2,
    i.e. a harmonic-mean Fisher weighting of the squared weight gap.
    Parameters without a same-shaped counterpart are counted as skipped.
    """
    layer_params: List[Dict[str, torch.nn.Parameter]] = []
    for layer in layers:
        layer_params.append({name: param for name, param in layer.named_parameters()})

    results: List[Dict[str, object]] = []
    for idx in range(len(layers) - 1):
        cost = 0.0
        matched = 0
        skipped = 0
        params_i = layer_params[idx]
        params_j = layer_params[idx + 1]
        for name, param_i in params_i.items():
            param_j = params_j.get(name)
            if param_j is None or param_j.shape != param_i.shape:
                skipped += 1
                continue
            matched += 1
            if fisher_mode == "param":
                # Diagonal Fisher: average the per-element sums over batches.
                fisher_i = fisher_sums[idx][name] / num_batches
                fisher_j = fisher_sums[idx + 1][name] / num_batches
                diff = (param_i.detach().float().cpu() - param_j.detach().float().cpu())
                denom = fisher_i + fisher_j + eps
                term = (fisher_i * fisher_j / denom) * diff * diff
                cost += 0.5 * float(term.sum().item())
            else:
                # Scalar Fisher: mean squared-gradient per element of the tensor.
                fisher_i = fisher_sums[idx][name] / (
                    num_batches * param_numels[idx][name]
                )
                fisher_j = fisher_sums[idx + 1][name] / (
                    num_batches * param_numels[idx + 1][name]
                )
                denom = fisher_i + fisher_j + eps
                if denom == 0:
                    # Possible only when eps == 0 and both Fishers vanish.
                    continue
                diff_sq = (
                    param_i.detach().float() - param_j.detach().float()
                ).pow(2)
                cost += 0.5 * (fisher_i * fisher_j / denom) * float(
                    diff_sq.sum().item()
                )
        results.append(
            {
                "layer_i": idx,
                "layer_j": idx + 1,
                "fbmc": cost,
                "matched_params": matched,
                "skipped_params": skipped,
            }
        )
    return results


def main() -> None:
    """CLI entry point: load model + data, compute Fisher, report FBMC per pair."""
    args = parse_args()
    torch.manual_seed(args.seed)
    dtype = get_dtype(args.dtype)
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=dtype,
        trust_remote_code=args.trust_remote_code,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.model, trust_remote_code=args.trust_remote_code
    )
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
    layers = find_layers(model, args.layer_path)
    if len(layers) < 2:
        raise SystemExit("Model has fewer than 2 layers; cannot compute FBMC.")

    texts = load_texts(args)
    if not texts:
        raise SystemExit(
            "No calibration text found. Provide --dataset, --text, or --text_file."
        )
    chunks = build_token_chunks(texts, tokenizer, args.seq_len, args.num_samples)
    if not chunks:
        raise SystemExit("Not enough text to build token sequences.")

    dataset = torch.utils.data.TensorDataset(torch.stack(chunks))
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, shuffle=False
    )
    model.to(args.device)
    fisher_sums, num_batches, param_numels = compute_fisher(
        model,
        layers,
        dataloader,
        fisher_mode=args.fisher_mode,
        device=args.device,
    )
    costs = compute_fbmc_costs(
        layers,
        fisher_sums,
        num_batches,
        param_numels,
        fisher_mode=args.fisher_mode,
        eps=args.eps,
    )
    costs_sorted = sorted(costs, key=lambda x: x["fbmc"])
    best = costs_sorted[0]

    print("FBMC results (layer order):")
    for item in costs:
        print(
            f"layers {item['layer_i']} & {item['layer_j']} -> "
            f"fbmc={item['fbmc']:.6e} "
            f"(matched={item['matched_params']}, skipped={item['skipped_params']})"
        )
    print("\nFBMC results (lowest cost first):")
    for item in costs_sorted:
        print(
            f"layers {item['layer_i']} & {item['layer_j']} -> "
            f"fbmc={item['fbmc']:.6e} "
            f"(matched={item['matched_params']}, skipped={item['skipped_params']})"
        )
    print(
        f"\nBest pair: layers {best['layer_i']} & {best['layer_j']} "
        f"(fbmc={best['fbmc']:.6e})"
    )

    if args.output:
        payload = {
            "model": args.model,
            "num_layers": len(layers),
            "fisher_mode": args.fisher_mode,
            "num_batches": num_batches,
            "num_sequences": len(chunks),
            "seq_len": args.seq_len,
            "best_pair": best,
            "pairs": costs_sorted,
        }
        os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
        with open(args.output, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, indent=2)
        print(f"\nWrote results to {args.output}")
    if args.output_csv:
        os.makedirs(os.path.dirname(args.output_csv) or ".", exist_ok=True)
        with open(args.output_csv, "w", encoding="utf-8", newline="") as handle:
            writer = csv.DictWriter(
                handle,
                fieldnames=[
                    "layer_i",
                    "layer_j",
                    "fbmc",
                    "matched_params",
                    "skipped_params",
                ],
            )
            writer.writeheader()
            for item in costs_sorted:
                writer.writerow(
                    {
                        "layer_i": item["layer_i"],
                        "layer_j": item["layer_j"],
                        "fbmc": item["fbmc"],
                        "matched_params": item["matched_params"],
                        "skipped_params": item["skipped_params"],
                    }
                )
        print(f"Wrote CSV results to {args.output_csv}")


if __name__ == "__main__":
    main()