LJYAI committed on
Commit
2c44909
·
verified ·
1 Parent(s): 98b9392

upload src

Browse files
src/common_lm_data.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Shared LM dataset helpers for fair cross-method comparisons."""
3
+
4
+ from __future__ import annotations
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Dict, Iterable, Iterator, List, Optional, Tuple
8
+
9
+ import torch
10
+
11
+ try:
12
+ from datasets import load_dataset
13
+ from datasets import Dataset as HFDataset
14
+ except Exception: # pragma: no cover - optional dependency
15
+ load_dataset = None
16
+ HFDataset = None
17
+
18
+
19
+ def _normalize_config(config: Optional[str]) -> Optional[str]:
20
+ if config is None:
21
+ return None
22
+ if config.strip().lower() in {"none", "null", "-"}:
23
+ return None
24
+ return config
25
+
26
+
27
def guess_text_field(dataset) -> str:
    """Pick the column that most likely holds raw text.

    Resolution order:

    1. ``dataset.column_names`` when present and non-empty: ``"text"`` if
       listed, otherwise the first column.
    2. ``dataset.features`` when present and non-empty: ``"text"`` if it
       is a key, otherwise the first key.
    3. ``"text"`` as the last-resort default.

    Both attributes are read with a ``None`` default, so an object that
    exposes ``features = None`` falls through to the default instead of
    raising ``AttributeError`` on ``None.keys()``.
    """
    columns = getattr(dataset, "column_names", None)
    if columns:
        return "text" if "text" in columns else columns[0]

    features = getattr(dataset, "features", None)
    if features:
        names = list(features.keys())
        if "text" in names:
            return "text"
        if names:
            return names[0]

    return "text"
39
+
40
+
41
def normalize_dataset_name(name: str) -> str:
    """Return the canonical short name for a supported dataset alias.

    Matching ignores surrounding whitespace and letter case.

    Raises:
        ValueError: if ``name`` is not a known alias.
    """
    key = name.strip().lower()
    canonical = {
        "bookcorpus": "bookcorpus",
        "boockcorpus": "bookcorpus",  # misspelled alias, accepted on purpose
        "slimpajama": "slimpajama",
        "dkyoon/slimpajama-6b": "slimpajama",
    }.get(key)
    if canonical is None:
        raise ValueError(f"Unsupported dataset: {name}")
    return canonical
52
+
53
+
54
def resolve_dataset_spec(
    name: str,
    config: Optional[str] = None,
    split: str = "train",
) -> Tuple[str, Optional[str], str]:
    """Resolve a dataset alias into ``(hub_id, config, split)``.

    The alias is canonicalised with ``normalize_dataset_name``, the config
    is passed through ``_normalize_config``, and ``split`` is returned
    untouched.

    Raises:
        ValueError: if the alias is unsupported.
    """
    hub_ids = {
        "bookcorpus": "bookcorpus",
        "slimpajama": "DKYoon/SlimPajama-6B",
    }
    hub_id = hub_ids.get(normalize_dataset_name(name))
    if hub_id is None:
        raise ValueError(f"Unsupported dataset: {name}")
    return hub_id, _normalize_config(config), split
65
+
66
+
67
+ def _sample_dataset_rows(dataset, target: int, seed: int) -> List[Dict[str, object]]:
68
+ if target <= 0:
69
+ return []
70
+ try:
71
+ dataset = dataset.shuffle(seed=seed)
72
+ except Exception:
73
+ pass
74
+
75
+ if hasattr(dataset, "__len__"):
76
+ limit = min(target, len(dataset))
77
+ dataset = dataset.select(range(limit))
78
+ return [row for row in dataset]
79
+
80
+ rows = []
81
+ for row in dataset:
82
+ rows.append(row)
83
+ if len(rows) >= target:
84
+ break
85
+ return rows
86
+
87
+
88
+ def _iter_dataset_rows(dataset, seed: int) -> Iterator[Dict[str, object]]:
89
+ try:
90
+ dataset = dataset.shuffle(seed=seed)
91
+ except Exception:
92
+ pass
93
+ for row in dataset:
94
+ yield row
95
+
96
+
97
def load_named_texts(
    dataset_name: str,
    *,
    config: Optional[str] = None,
    split: str = "train",
    text_field: Optional[str] = None,
    num_samples: int = 0,
    seed: int = 0,
) -> List[str]:
    """Load a supported dataset and return its non-blank text entries.

    Args:
        dataset_name: Alias accepted by ``resolve_dataset_spec``.
        config: Optional dataset config; placeholder spellings mean none.
        split: Split to load.
        text_field: Column to read; guessed from the dataset when falsy.
        num_samples: When positive, rows are drawn with
            ``_sample_dataset_rows``; otherwise every row is scanned.
        seed: Seed used for sampling.

    Returns:
        The string values of ``text_field`` that are not blank.

    Raises:
        SystemExit: if the ``datasets`` package could not be imported.
    """
    if load_dataset is None:
        raise SystemExit("datasets is required for shared LM dataloaders")

    hub_id, hub_config, hub_split = resolve_dataset_spec(dataset_name, config, split)
    # NOTE(review): trust_remote_code=True is hard-coded for every dataset.
    dataset = load_dataset(
        hub_id,
        hub_config,
        split=hub_split,
        trust_remote_code=True,
    )

    if num_samples > 0:
        rows = _sample_dataset_rows(dataset, num_samples, seed)
    else:
        rows = dataset
    column = text_field or guess_text_field(dataset)

    # Rows that are not dicts contribute ``None`` and are filtered out.
    candidates = (
        row.get(column, None) if isinstance(row, dict) else None for row in rows
    )
    return [text for text in candidates if isinstance(text, str) and text.strip()]
125
+
126
+
127
def build_token_chunks_from_rows(
    rows: Iterable[Dict[str, object]],
    *,
    text_field: str,
    tokenizer,
    seq_len: int,
    num_sequences: int = 0,
    add_bos: bool = False,
    max_rows: int = 0,
) -> List[torch.Tensor]:
    """Tokenize rows and pack the token stream into fixed-length chunks.

    Token ids from consecutive rows are concatenated and cut into
    non-overlapping windows of ``seq_len`` tokens; a trailing partial
    window is discarded.

    Args:
        rows: Iterable of row dicts; non-dict rows and rows whose
            ``text_field`` is missing, non-string or blank are skipped.
        text_field: Key holding the text in each row.
        tokenizer: Object with ``encode(text, add_special_tokens=False)``
            and a ``bos_token_id`` attribute.
        seq_len: Tokens per chunk; must be positive.
        num_sequences: Stop after this many chunks; ``<= 0`` means no cap.
        add_bos: Prepend ``tokenizer.bos_token_id`` to every row when that
            id is not ``None``.
        max_rows: Inspect at most this many rows (skipped rows count);
            ``<= 0`` means no cap.

    Returns:
        A list of 1-D ``torch.long`` tensors, each of length ``seq_len``.

    Raises:
        ValueError: if ``seq_len`` is not positive. A non-positive length
            previously looped forever or produced empty chunks.
    """
    if seq_len <= 0:
        raise ValueError(f"seq_len must be positive, got {seq_len}")

    chunks: List[torch.Tensor] = []
    buffer: List[int] = []
    limit = None if num_sequences <= 0 else num_sequences

    for rows_seen, row in enumerate(rows):
        if max_rows > 0 and rows_seen >= max_rows:
            break

        value = row.get(text_field, None) if isinstance(row, dict) else None
        if not isinstance(value, str) or not value.strip():
            continue

        ids = tokenizer.encode(value, add_special_tokens=False)
        if add_bos and tokenizer.bos_token_id is not None:
            ids = [tokenizer.bos_token_id] + ids
        if not ids:
            continue

        buffer.extend(ids)

        # Walk the buffer with an offset and trim the consumed prefix once
        # per row, instead of re-slicing the whole buffer for every chunk.
        start = 0
        while len(buffer) - start >= seq_len and (limit is None or len(chunks) < limit):
            chunks.append(torch.tensor(buffer[start:start + seq_len], dtype=torch.long))
            start += seq_len
        if start:
            del buffer[:start]

        if limit is not None and len(chunks) >= limit:
            break

    return chunks
166
+
167
+
168
def collect_texts_from_rows(
    rows: Iterable[Dict[str, object]],
    *,
    text_field: str,
    tokenizer,
    target_tokens: int = 0,
    add_bos: bool = False,
    max_rows: int = 0,
) -> List[str]:
    """Gather non-blank texts until a row cap or token budget is reached.

    Args:
        rows: Iterable of row dicts; non-dict rows and rows with a missing,
            non-string or blank ``text_field`` are skipped.
        text_field: Key holding the text in each row.
        tokenizer: Only used when ``target_tokens > 0``; needs
            ``encode(text, add_special_tokens=False)`` and ``bos_token_id``.
        target_tokens: Stop once the accumulated token count reaches this
            value; ``<= 0`` disables token counting entirely.
        add_bos: Count one extra BOS token per text when the tokenizer
            defines ``bos_token_id``.
        max_rows: Inspect at most this many rows (skipped rows count);
            ``<= 0`` means no cap.

    Returns:
        The accepted texts in encounter order. The text that crosses the
        token budget is included.
    """
    collected: List[str] = []
    tokens_so_far = 0
    counting = target_tokens > 0

    for position, row in enumerate(rows):
        if max_rows > 0 and position >= max_rows:
            break

        text = row.get(text_field, None) if isinstance(row, dict) else None
        if not (isinstance(text, str) and text.strip()):
            continue

        collected.append(text)
        if not counting:
            continue

        token_ids = tokenizer.encode(text, add_special_tokens=False)
        if add_bos and tokenizer.bos_token_id is not None:
            token_ids = [tokenizer.bos_token_id] + token_ids
        tokens_so_far += len(token_ids)
        if tokens_so_far >= target_tokens:
            break

    return collected
200
+
201
+
202
def build_token_chunks(
    texts: Iterable[str],
    tokenizer,
    seq_len: int,
    num_sequences: int = 0,
    add_bos: bool = False,
) -> List[torch.Tensor]:
    """Tokenize texts and pack the token stream into fixed-length chunks.

    Token ids from consecutive texts are concatenated and cut into
    non-overlapping windows of ``seq_len`` tokens; a trailing partial
    window is discarded.

    Args:
        texts: Iterable of strings to tokenize.
        tokenizer: Object with ``encode(text, add_special_tokens=False)``
            and a ``bos_token_id`` attribute.
        seq_len: Tokens per chunk; must be positive.
        num_sequences: Stop after this many chunks; ``<= 0`` means no cap.
        add_bos: Prepend ``tokenizer.bos_token_id`` to every text when that
            id is not ``None``.

    Returns:
        A list of 1-D ``torch.long`` tensors, each of length ``seq_len``.

    Raises:
        ValueError: if ``seq_len`` is not positive. A non-positive length
            previously looped forever or produced empty chunks.
    """
    if seq_len <= 0:
        raise ValueError(f"seq_len must be positive, got {seq_len}")

    chunks: List[torch.Tensor] = []
    buffer: List[int] = []
    limit = None if num_sequences <= 0 else num_sequences

    for text in texts:
        ids = tokenizer.encode(text, add_special_tokens=False)
        if add_bos and tokenizer.bos_token_id is not None:
            ids = [tokenizer.bos_token_id] + ids
        if not ids:
            continue

        buffer.extend(ids)

        # Walk the buffer with an offset and trim the consumed prefix once
        # per text, instead of re-slicing the whole buffer for every chunk.
        start = 0
        while len(buffer) - start >= seq_len and (limit is None or len(chunks) < limit):
            chunks.append(torch.tensor(buffer[start:start + seq_len], dtype=torch.long))
            start += seq_len
        if start:
            del buffer[:start]

        if limit is not None and len(chunks) >= limit:
            break

    return chunks
229
+
230
+
231
class TokenChunkDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding causal-LM training dicts per chunk."""

    def __init__(self, chunks: List[torch.Tensor]) -> None:
        # The list is stored by reference, not copied.
        self.chunks = chunks

    def __len__(self) -> int:
        return len(self.chunks)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return ``input_ids``, an all-ones ``attention_mask`` and ``labels``.

        ``labels`` is a clone of the chunk, so it does not alias
        ``input_ids``.
        """
        tokens = self.chunks[idx]
        sample = {"input_ids": tokens}
        sample["attention_mask"] = torch.ones_like(tokens)
        sample["labels"] = tokens.clone()
        return sample
246
+
247
+
248
class TokenOnlyDataset(torch.utils.data.Dataset):
    """Map-style dataset that returns each token chunk as a bare tensor."""

    def __init__(self, chunks: List[torch.Tensor]) -> None:
        # The list is stored by reference, not copied.
        self.chunks = chunks

    def __len__(self) -> int:
        """Number of stored chunks."""
        return len(self.chunks)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return the stored chunk at ``idx`` (no copy is made)."""
        chunk = self.chunks[idx]
        return chunk
257
+
258
+
259
class TokenInputMaskDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``input_ids`` plus an all-ones mask."""

    def __init__(self, chunks: List[torch.Tensor]) -> None:
        # The list is stored by reference, not copied.
        self.chunks = chunks

    def __len__(self) -> int:
        return len(self.chunks)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return the chunk and a same-shaped mask of ones (no labels)."""
        tokens = self.chunks[idx]
        mask = torch.ones_like(tokens)
        return {"input_ids": tokens, "attention_mask": mask}
272
+
273
+
274
@dataclass
class SharedLMDataSpec:
    """Settings bundle consumed by the ``build_*`` helpers in this module."""

    # Dataset alias; resolved through ``resolve_dataset_spec``.
    dataset: str
    # Optional dataset config; "none" / "null" / "-" are treated as unset.
    config: Optional[str] = None
    # Split to load.
    split: str = "train"
    # Text column; guessed from the dataset when left unset.
    text_field: Optional[str] = None
    # Row cap. ``build_chunks`` applies it only when no sequence target is set.
    num_samples: int = 0
    # Tokens per chunk.
    seq_len: int = 2048
    # Number of chunks to build; <= 0 means no explicit chunk target.
    num_sequences: int = 0
    # Token budget; <= 0 disables it.
    target_tokens: int = 0
    # DataLoader settings, forwarded as-is.
    batch_size: int = 1
    shuffle: bool = False
    num_workers: int = 0
    # Seed used when shuffling dataset rows.
    seed: int = 0
    # Prepend the tokenizer's BOS id to each text when it defines one.
    add_bos: bool = False
289
+
290
+
291
def build_chunks(spec: SharedLMDataSpec, tokenizer) -> List[torch.Tensor]:
    """Load the dataset named by ``spec`` and pack it into token chunks.

    The chunk target is ``spec.num_sequences``, raised if necessary to the
    number of ``seq_len`` windows needed to cover ``spec.target_tokens``.
    ``spec.num_samples`` caps the rows read only when that target ends up
    non-positive.

    Raises:
        SystemExit: if the ``datasets`` package could not be imported.
    """
    if load_dataset is None:
        raise SystemExit("datasets is required for shared LM dataloaders")

    hub_id, hub_config, hub_split = resolve_dataset_spec(spec.dataset, spec.config, spec.split)
    # NOTE(review): trust_remote_code=True is hard-coded for every dataset.
    dataset = load_dataset(
        hub_id,
        hub_config,
        split=hub_split,
        trust_remote_code=True,
    )

    wanted_sequences = spec.num_sequences
    if spec.target_tokens > 0:
        # Ceiling division: windows needed to hold ``target_tokens`` tokens.
        windows_for_budget = (spec.target_tokens + spec.seq_len - 1) // spec.seq_len
        wanted_sequences = max(wanted_sequences, windows_for_budget)

    if wanted_sequences <= 0:
        row_cap = spec.num_samples
    else:
        row_cap = 0

    row_stream = _iter_dataset_rows(dataset, spec.seed)
    column = spec.text_field or guess_text_field(dataset)
    return build_token_chunks_from_rows(
        row_stream,
        text_field=column,
        tokenizer=tokenizer,
        seq_len=spec.seq_len,
        num_sequences=wanted_sequences,
        add_bos=spec.add_bos,
        max_rows=row_cap,
    )
321
+
322
+
323
def build_dataloader(spec: SharedLMDataSpec, tokenizer) -> torch.utils.data.DataLoader:
    """Build a DataLoader of ``input_ids`` / ``attention_mask`` / ``labels`` dicts.

    Chunks come from ``build_chunks``; batch size, shuffling and worker
    count are taken from ``spec``.
    """
    token_dataset = TokenChunkDataset(build_chunks(spec, tokenizer))
    loader_options = {
        "batch_size": spec.batch_size,
        "shuffle": spec.shuffle,
        "num_workers": spec.num_workers,
    }
    return torch.utils.data.DataLoader(token_dataset, **loader_options)
332
+
333
+
334
def build_text_dataloader(spec: SharedLMDataSpec, tokenizer) -> torch.utils.data.DataLoader:
    """Build a DataLoader over raw text strings.

    Rows are streamed via ``_iter_dataset_rows`` and collected with
    ``collect_texts_from_rows``, bounded by ``spec.num_samples`` rows and
    ``spec.target_tokens`` tokens. The loader is created with
    ``drop_last=True``, so a final partial batch is discarded.

    Raises:
        SystemExit: if the ``datasets`` package could not be imported.
    """
    if load_dataset is None:
        raise SystemExit("datasets is required for shared LM dataloaders")

    hub_id, hub_config, hub_split = resolve_dataset_spec(spec.dataset, spec.config, spec.split)
    # NOTE(review): trust_remote_code=True is hard-coded for every dataset.
    dataset = load_dataset(
        hub_id,
        hub_config,
        split=hub_split,
        trust_remote_code=True,
    )

    row_stream = _iter_dataset_rows(dataset, spec.seed)
    column = spec.text_field or guess_text_field(dataset)
    texts = collect_texts_from_rows(
        row_stream,
        text_field=column,
        tokenizer=tokenizer,
        target_tokens=spec.target_tokens,
        add_bos=spec.add_bos,
        max_rows=spec.num_samples,
    )

    loader_options = {
        "batch_size": spec.batch_size,
        "shuffle": spec.shuffle,
        "num_workers": spec.num_workers,
        "drop_last": True,
    }
    return torch.utils.data.DataLoader(texts, **loader_options)
363
+
364
+
365
def build_uidl_post_train_dataloader(
    spec: SharedLMDataSpec,
    tokenizer,
) -> torch.utils.data.DataLoader:
    """Build the UIDL post-training loader.

    This is the same ``TokenChunkDataset`` loader that ``build_dataloader``
    produces, so the call is delegated to it.
    """
    return build_dataloader(spec, tokenizer)
376
+
377
+
378
def build_uidl_similarity_dataloader(
    spec: SharedLMDataSpec,
    tokenizer,
) -> torch.utils.data.DataLoader:
    """Build a loader of ``input_ids`` / ``attention_mask`` dicts (no labels).

    Chunks come from ``build_chunks``; loader settings come from ``spec``.
    """
    masked_dataset = TokenInputMaskDataset(build_chunks(spec, tokenizer))
    loader_options = {
        "batch_size": spec.batch_size,
        "shuffle": spec.shuffle,
        "num_workers": spec.num_workers,
    }
    return torch.utils.data.DataLoader(masked_dataset, **loader_options)
389
+
390
+
391
def build_shortened_llm_dataloader(
    spec: SharedLMDataSpec,
    tokenizer,
) -> torch.utils.data.DataLoader:
    """Build a loader whose items are bare token tensors.

    Chunks come from ``build_chunks``; loader settings come from ``spec``.
    """
    tensor_dataset = TokenOnlyDataset(build_chunks(spec, tokenizer))
    loader_options = {
        "batch_size": spec.batch_size,
        "shuffle": spec.shuffle,
        "num_workers": spec.num_workers,
    }
    return torch.utils.data.DataLoader(tensor_dataset, **loader_options)
402
+
403
+
404
def build_shortened_llm_examples(spec: SharedLMDataSpec, tokenizer) -> torch.Tensor:
    """Stack all chunks from ``build_chunks`` into one 2-D tensor.

    Returns an empty ``(0, spec.seq_len)`` ``torch.long`` tensor when no
    chunk could be built.
    """
    chunks = build_chunks(spec, tokenizer)
    if chunks:
        return torch.stack(chunks, dim=0)
    return torch.empty((0, spec.seq_len), dtype=torch.long)
409
+
410
+
411
def build_llmpruner_examples(spec: SharedLMDataSpec, tokenizer) -> torch.Tensor:
    """Return the chunks from ``build_chunks`` stacked along a new first axis.

    With no chunks available, the result is an empty ``torch.long`` tensor
    of shape ``(0, spec.seq_len)``.
    """
    chunks = build_chunks(spec, tokenizer)
    no_examples = len(chunks) == 0
    if no_examples:
        return torch.empty((0, spec.seq_len), dtype=torch.long)
    return torch.stack(chunks, dim=0)
416
+
417
+
418
def build_replaceme_dataloader(
    spec: SharedLMDataSpec,
    tokenizer,
) -> torch.utils.data.DataLoader:
    """Return the raw-text loader built by ``build_text_dataloader``."""
    text_loader = build_text_dataloader(spec, tokenizer)
    return text_loader
423
+
424
+
425
def build_hf_causal_dataset(spec: SharedLMDataSpec, tokenizer):
    """Build a ``datasets.Dataset`` of chunks with masks and labels.

    Each chunk from ``build_chunks`` becomes one row holding ``input_ids``,
    an all-ones ``attention_mask`` of the same shape, and ``labels`` equal
    to ``input_ids``, all as plain Python lists.

    Raises:
        SystemExit: if the ``datasets`` package could not be imported.
    """
    if HFDataset is None:
        raise SystemExit("datasets is required for shared LM dataloaders")

    input_ids, attention_mask, labels = [], [], []
    for chunk in build_chunks(spec, tokenizer):
        input_ids.append(chunk.tolist())
        attention_mask.append(torch.ones_like(chunk).tolist())
        labels.append(chunk.tolist())

    return HFDataset.from_dict(
        {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
    )
src/convert_llmpruner_checkpoint.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import os
4
+ import sys
5
+
6
+ import torch
7
+
8
+
9
def ensure_llmpruner_on_path() -> None:
    """Put the vendored LLM-Pruner checkout at the front of ``sys.path``.

    The checkout is expected at ``<repo_root>/compare_model/LLM-Pruner``,
    where ``repo_root`` is the parent of this file's directory. Nothing
    happens if that directory is missing or already on the path.
    """
    this_file = os.path.abspath(__file__)
    repo_root = os.path.dirname(os.path.dirname(this_file))
    candidate = os.path.join(repo_root, "compare_model", "LLM-Pruner")
    if not os.path.isdir(candidate):
        return
    if candidate in sys.path:
        return
    sys.path.insert(0, candidate)
14
+
15
+
16
def load_llmpruner_checkpoint(path: str):
    """Load an LLM-Pruner checkpoint and return ``(model, tokenizer)``.

    The file must deserialize to a dict containing both a ``model`` and a
    ``tokenizer`` entry.

    Raises:
        SystemExit: if the loaded object does not have that layout.
    """
    # Presumably needed so unpickling can import LLM-Pruner's own classes.
    ensure_llmpruner_on_path()

    # NOTE(review): weights_only=False runs the full pickle machinery, which
    # can execute arbitrary code. Only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location="cpu", weights_only=False)

    required_keys = ("model", "tokenizer")
    well_formed = isinstance(checkpoint, dict) and all(
        key in checkpoint for key in required_keys
    )
    if not well_formed:
        raise SystemExit(
            "Expected an LLM-Pruner checkpoint dict with `model` and `tokenizer` entries."
        )
    return checkpoint["model"], checkpoint["tokenizer"]
24
+
25
+
26
def main() -> None:
    """CLI entry point: re-save an LLM-Pruner checkpoint in HF layout.

    Reads ``--input`` and ``--output_dir``, loads the checkpoint, creates
    the output directory if needed, calls ``save_pretrained`` on the model
    and then the tokenizer, and prints the output directory.
    """
    parser = argparse.ArgumentParser(
        description="Convert an LLM-Pruner .bin checkpoint to a Hugging Face save_pretrained directory."
    )
    parser.add_argument("--input", required=True, help="Path to LLM-Pruner pytorch_model.bin")
    parser.add_argument("--output_dir", required=True, help="Directory to write HF model artifacts")
    args = parser.parse_args()
    output_dir = args.output_dir

    model, tokenizer = load_llmpruner_checkpoint(args.input)
    os.makedirs(output_dir, exist_ok=True)
    for artifact in (model, tokenizer):
        artifact.save_pretrained(output_dir)
    print(output_dir)
39
+
40
+
41
+ if __name__ == "__main__":
42
+ main()
src/eval_ppl.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import csv
4
+ import json
5
+ import os
6
+ import sys
7
+ from typing import Iterable
8
+
9
+ import numpy as np
10
+ import torch
11
+ from datasets import load_dataset
12
+ from torch.utils.data import DataLoader, Dataset
13
+ from tqdm import tqdm
14
+ from transformers import AutoModelForCausalLM, AutoTokenizer
15
+
16
+
17
class IndexDataset(Dataset):
    """Map-style dataset that indexes into the first axis of one tensor."""

    def __init__(self, tensors: torch.Tensor):
        self.tensors = tensors

    def __len__(self) -> int:
        """Size of the first axis of the wrapped tensor."""
        return len(self.tensors)

    def __getitem__(self, index: int) -> torch.Tensor:
        """Return ``tensors[index]``."""
        row = self.tensors[index]
        return row
26
+
27
+
28
def get_dataset(name: str):
    """Load the train and evaluation splits of a supported benchmark.

    Supported names are ``"wikitext2"`` (evaluated on ``test``, text column
    ``"text"``) and ``"ptb"`` (evaluated on ``validation``, text column
    ``"sentence"``).

    Returns:
        ``(train_data, test_data, field_name)``.

    Raises:
        ValueError: for any other name.
    """
    # name -> (hub id, config, evaluation split, text column)
    sources = {
        "wikitext2": ("wikitext", "wikitext-2-raw-v1", "test", "text"),
        "ptb": ("ptb_text_only", "penn_treebank", "validation", "sentence"),
    }
    if name not in sources:
        raise ValueError(f"Unsupported dataset: {name}")

    hub_id, config, eval_split, field_name = sources[name]
    train_data = load_dataset(hub_id, config, split="train")
    test_data = load_dataset(hub_id, config, split=eval_split)
    return train_data, test_data, field_name
38
+
39
+
40
def process_data(samples, tokenizer, seq_len: int, field_name: str, add_bos_to_every: bool) -> IndexDataset:
    """Tokenize a whole split and cut it into fixed-length windows.

    All entries of ``samples[field_name]`` are joined with blank lines and
    tokenized in one call without special tokens. The token stream is then
    split into non-overlapping windows of ``seq_len`` tokens; a trailing
    partial window is dropped.

    BOS handling, when the tokenizer defines ``bos_token_id``:

    * ``add_bos_to_every=False``: one BOS is prepended to the stream
      before windowing, so windows stay ``seq_len`` long.
    * ``add_bos_to_every=True``: a BOS is prepended to every window after
      windowing, so each window is ``seq_len + 1`` long.

    Raises:
        ValueError: if ``seq_len`` is not positive, or the corpus is too
            short to fill one window. These cases previously surfaced as
            ``ZeroDivisionError`` or an opaque ``torch.stack`` error.
    """
    if seq_len <= 0:
        raise ValueError(f"seq_len must be positive, got {seq_len}")

    token_ids = tokenizer(
        "\n\n".join(samples[field_name]),
        return_tensors="pt",
        add_special_tokens=False,
    ).input_ids[0]
    bos_id = tokenizer.bos_token_id

    if not add_bos_to_every and bos_id is not None:
        token_ids = torch.cat((torch.tensor([bos_id], dtype=torch.long), token_ids), dim=0)

    num_windows = token_ids.numel() // seq_len
    if num_windows == 0:
        raise ValueError(
            f"Not enough tokens ({token_ids.numel()}) to build one sequence of length {seq_len}"
        )

    # Drop the ragged tail, then view the stream as (num_windows, seq_len).
    windows = token_ids[: num_windows * seq_len].reshape(num_windows, seq_len)
    if add_bos_to_every and bos_id is not None:
        bos_column = torch.full((num_windows, 1), bos_id, dtype=torch.long)
        windows = torch.cat((bos_column, windows), dim=1)

    return IndexDataset(tensors=windows)
59
+
60
+
61
def get_loader(name: str, tokenizer, seq_len: int, batch_size: int, add_bos_to_every: bool):
    """Build an unshuffled DataLoader over the evaluation split of ``name``.

    The training split returned by ``get_dataset`` is ignored.
    """
    splits = get_dataset(name)
    eval_split, field_name = splits[1], splits[2]
    windows = process_data(eval_split, tokenizer, seq_len, field_name, add_bos_to_every)
    return DataLoader(windows, batch_size=batch_size, shuffle=False)
65
+
66
+
67
@torch.no_grad()
def evaluate_ppl(model, test_loader, device: str) -> float:
    """Compute perplexity of ``model`` over every batch in ``test_loader``.

    Each batch is moved to ``device`` and fed to the model. Logits at
    position ``t`` are scored against the token at ``t + 1``, the
    unreduced cross-entropy of every position is collected, and the
    result is ``exp`` of the mean over all collected values.
    """
    criterion = torch.nn.CrossEntropyLoss(reduction="none")
    per_token_nll = []

    for batch in tqdm(test_loader, desc="Running PPL", dynamic_ncols=True):
        batch = batch.to(device)
        logits = model(batch).logits
        # Shift by one: drop the last prediction and the first target.
        predictions = logits[:, :-1, :].contiguous()
        targets = batch[:, 1:].contiguous()
        token_losses = criterion(
            predictions.reshape(-1, predictions.size(-1)),
            targets.view(-1),
        )
        per_token_nll.append(token_losses.cpu())

    mean_nll = torch.cat(per_token_nll, dim=-1).mean().item()
    return float(np.exp(mean_nll))
83
+
84
+
85
def resolve_dtype(args) -> torch.dtype:
    """Choose the torch dtype requested on the command line.

    Precedence: ``args.use_bfloat`` forces bfloat16; otherwise
    ``args.dtype``, then ``args.torch_dtype``, then ``"float16"``.

    Raises:
        ValueError: if the chosen name is not a supported dtype alias.
    """
    if args.use_bfloat:
        return torch.bfloat16

    requested = args.dtype
    if requested is None:
        requested = args.torch_dtype
    if requested is None:
        requested = "float16"

    aliases = {
        "float16": torch.float16,
        "fp16": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float32": torch.float32,
        "fp32": torch.float32,
    }
    resolved = aliases.get(requested)
    if resolved is None:
        raise ValueError(f"Unsupported dtype: {requested}")
    return resolved
104
+
105
+
106
def normalize_datasets(datasets: Iterable[str]) -> list[str]:
    """Return the names as a list, with ``"wikitext"`` mapped to ``"wikitext2"``."""
    return ["wikitext2" if name == "wikitext" else name for name in datasets]
111
+
112
+
113
def build_arg_parser() -> argparse.ArgumentParser:
    """Create the CLI parser for the perplexity evaluation script.

    ``--base_model`` / ``--model-path`` both populate ``model_path``;
    ``--max_seq_len`` / ``--seq-len`` both populate ``seq_len``.
    """
    dtype_choices = ["float16", "fp16", "bfloat16", "bf16", "float32", "fp32"]

    parser = argparse.ArgumentParser(description="Shared perplexity evaluation for abprune.")
    parser.add_argument("--base_model", "--model-path", dest="model_path", required=True)
    parser.add_argument("--output_dir", type=str, default=None)
    parser.add_argument("--dataset", nargs="+", default=["wikitext2", "ptb"])
    parser.add_argument("--max_seq_len", "--seq-len", dest="seq_len", type=int, default=1024)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--device", default="cuda")

    # Two spellings of the same setting; precedence is decided in resolve_dtype.
    for option in ("--dtype", "--torch_dtype"):
        parser.add_argument(option, default=None, choices=list(dtype_choices))

    for switch in (
        "--use_bfloat",
        "--add_bos_to_every",
        "--fix_decapoda_config",
        "--local_files_only",
    ):
        parser.add_argument(switch, action="store_true")

    return parser
136
+
137
+
138
def maybe_fix_decapoda_config(tokenizer, enabled: bool) -> None:
    """Optionally backfill missing BOS / PAD tokens from the EOS token.

    When ``enabled``, the tokenizer is patched in place:

    * ``bos_token`` is set to ``eos_token`` if there is no BOS id but
      there is an EOS id;
    * ``pad_token`` is set to ``eos_token`` if there is no pad token but
      there is an EOS token.

    Nothing is touched when ``enabled`` is false.
    """
    if enabled:
        needs_bos = tokenizer.bos_token_id is None and tokenizer.eos_token_id is not None
        if needs_bos:
            tokenizer.bos_token = tokenizer.eos_token
        if tokenizer.pad_token is None and tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
145
+
146
+
147
def ensure_llmpruner_on_path() -> None:
    """Put the vendored LLM-Pruner checkout at the front of ``sys.path``.

    The checkout is expected at ``<repo_root>/compare_model/LLM-Pruner``,
    where ``repo_root`` is the parent of this file's directory. Nothing
    happens if that directory is missing or already on the path.
    """
    this_file = os.path.abspath(__file__)
    repo_root = os.path.dirname(os.path.dirname(this_file))
    candidate = os.path.join(repo_root, "compare_model", "LLM-Pruner")
    already_listed = candidate in sys.path
    if os.path.isdir(candidate) and not already_listed:
        sys.path.insert(0, candidate)
152
+
153
+
154
def load_model_and_tokenizer(model_path: str, *, torch_dtype: torch.dtype, local_files_only: bool):
    """Load ``(model, tokenizer)`` from an HF path or an LLM-Pruner ``.bin``.

    An existing file ending in ``.bin`` is treated as an LLM-Pruner
    checkpoint: a pickled dict with ``model`` and ``tokenizer`` entries,
    loaded on CPU and cast to ``torch_dtype`` when one is given. Anything
    else goes through ``AutoTokenizer`` / ``AutoModelForCausalLM``.

    Raises:
        ValueError: if a ``.bin`` file does not hold the expected dict.
    """
    is_pruner_checkpoint = os.path.isfile(model_path) and model_path.endswith(".bin")

    if not is_pruner_checkpoint:
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            local_files_only=local_files_only,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
            local_files_only=local_files_only,
        )
        return model, tokenizer

    # Presumably needed so unpickling can import LLM-Pruner's own classes.
    ensure_llmpruner_on_path()
    # NOTE(review): weights_only=False runs the full pickle machinery, which
    # can execute arbitrary code. Only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)

    well_formed = isinstance(checkpoint, dict) and all(
        key in checkpoint for key in ("model", "tokenizer")
    )
    if not well_formed:
        raise ValueError(
            "Expected an LLM-Pruner checkpoint dict with `model` and `tokenizer` entries."
        )

    model, tokenizer = checkpoint["model"], checkpoint["tokenizer"]
    if torch_dtype is not None:
        model = model.to(dtype=torch_dtype)
    return model, tokenizer
178
+
179
+
180
def main() -> None:
    """CLI entry point: evaluate perplexity and report the results.

    Steps: parse arguments, load the model and tokenizer, make sure a pad
    token exists, evaluate every requested dataset, optionally write a
    one-row CSV (``ppl_bos.csv`` or ``ppl.csv``) into ``--output_dir``,
    and finally print a JSON summary on stdout.
    """
    args = build_arg_parser().parse_args()

    datasets = normalize_datasets(args.dataset)
    torch_dtype = resolve_dtype(args)

    model, tokenizer = load_model_and_tokenizer(
        args.model_path,
        torch_dtype=torch_dtype,
        local_files_only=args.local_files_only,
    )
    maybe_fix_decapoda_config(tokenizer, args.fix_decapoda_config)
    # Independent of the decapoda flag: fall back to EOS for padding.
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    model.eval()
    model.to(args.device)

    metrics = {}
    for dataset in datasets:
        loader = get_loader(
            dataset,
            tokenizer,
            seq_len=args.seq_len,
            batch_size=args.batch_size,
            add_bos_to_every=args.add_bos_to_every,
        )
        metrics[dataset] = evaluate_ppl(model, loader, args.device)
        print(f"PPL-{dataset}: {metrics[dataset]} | add_bos_to_every: {args.add_bos_to_every} | seq_len: {args.seq_len}")

    # Memory currently allocated on the CUDA device in MiB, if applicable.
    on_cuda = torch.cuda.is_available() and args.device.startswith("cuda")
    mem = torch.cuda.memory_allocated(args.device) / 1024 / 1024 if on_cuda else None

    param_count = int(sum(parameter.numel() for parameter in model.parameters()))
    result = {
        "model_path": os.path.abspath(args.model_path),
        "datasets": datasets,
        "seq_len": args.seq_len,
        "batch_size": args.batch_size,
        "device": args.device,
        "dtype": str(torch_dtype).replace("torch.", ""),
        "add_bos_to_every": args.add_bos_to_every,
        "metrics": metrics,
        "params": param_count,
        "mem_mib": mem,
    }

    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
        filename = "ppl_bos.csv" if args.add_bos_to_every else "ppl.csv"
        csv_path = os.path.join(args.output_dir, filename)
        header = [f"ppl_{dataset}" for dataset in datasets] + ["params", "mem"]
        values = [metrics[dataset] for dataset in datasets] + [result["params"], mem]
        with open(csv_path, "w", newline="", encoding="utf-8") as handle:
            writer = csv.writer(handle)
            writer.writerow(header)
            writer.writerow(values)

    print(json.dumps(result, ensure_ascii=True))
238
+
239
+
240
+ if __name__ == "__main__":
241
+ main()
src/fbmc_metric.py ADDED
@@ -0,0 +1,519 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Estimate Fisher-Barycentric Merge Cost (FBMC) for adjacent layers."""
3
+
4
+ import argparse
5
+ import csv
6
+ import json
7
+ import os
8
+ from typing import Dict, List, Optional, Tuple
9
+
10
+ import torch
11
+
12
+ try:
13
+ from datasets import load_dataset
14
+ except Exception: # pragma: no cover - optional dependency
15
+ load_dataset = None
16
+
17
+ try:
18
+ from transformers import AutoModelForCausalLM, AutoTokenizer
19
+ except Exception as exc: # pragma: no cover - fail early with clear error
20
+ raise SystemExit("transformers is required: pip install transformers") from exc
21
+
22
+
23
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line options of the FBMC script."""
    parser = argparse.ArgumentParser(
        description="Compute FBMC for adjacent layers of a Hugging Face causal LM."
    )
    add = parser.add_argument

    # Model.
    add("--model", required=True, help="HF model id or local path")

    # Calibration data sources.
    add(
        "--dataset",
        action="append",
        default=[],
        help=(
            "HF dataset name (repeatable). Optional if using --text or --text_file."
        ),
    )
    add(
        "--dataset_config",
        action="append",
        default=[],
        help="Optional dataset config (repeatable or single shared config).",
    )
    add("--dataset_split", default="train", help="Dataset split to use (default: train)")
    add(
        "--dataset_text_field",
        default=None,
        help="Text field in dataset (default: auto-detect, applies to all datasets)",
    )
    add("--text", action="append", default=[], help="Inline text samples (can pass multiple)")
    add("--text_file", default=None, help="Path to a text file for calibration data")

    # Sampling and batching.
    add("--num_samples", type=int, default=128, help="Number of token sequences to use")
    add("--seq_len", type=int, default=256, help="Sequence length")
    add("--batch_size", type=int, default=2, help="Batch size")

    # Execution.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    add("--device", default=default_device, help="Device for model + compute")
    add(
        "--dtype",
        default="auto",
        choices=["auto", "float32", "float16", "bfloat16"],
        help="Model dtype",
    )

    # Metric configuration.
    add("--layer_path", default=None, help="Override layer attribute path (e.g., model.layers)")
    add(
        "--fisher_mode",
        default="tensor",
        choices=["tensor", "param"],
        help="Fisher approximation granularity",
    )
    add("--eps", type=float, default=1e-8, help="Stability epsilon")

    # Outputs and miscellany.
    add("--output", default=None, help="Optional JSON output path")
    add("--output_csv", default=None, help="Optional CSV output path")
    add("--seed", type=int, default=0, help="Random seed")
    add("--trust_remote_code", action="store_true", help="Allow custom model code from hub")

    return parser.parse_args()
115
+
116
+
117
def resolve_attr(root: object, path: str) -> Optional[object]:
    """Follow a dotted attribute ``path`` starting at ``root``.

    Returns the object at the end of the path, or ``None`` as soon as one
    segment is missing.
    """
    missing = object()  # sentinel, so a stored ``None`` is not "missing"
    node = root
    for segment in path.split("."):
        node = getattr(node, segment, missing)
        if node is missing:
            return None
    return node
124
+
125
+
126
def find_layers(model, layer_path: Optional[str]) -> List[torch.nn.Module]:
    """Locate the model's stack of transformer layers and return it as a list.

    With ``layer_path`` set, only that dotted path is tried. Otherwise a
    fixed set of common container paths is probed in order, and the first
    one that exists and can be turned into a list wins.

    Raises:
        ValueError: if ``layer_path`` does not resolve, or no known
            container is found.
    """
    if layer_path:
        container = resolve_attr(model, layer_path)
        if container is None:
            raise ValueError(f"layer_path '{layer_path}' not found on model")
        return list(container)

    # Common decoder-only layer containers. Add more if needed.
    known_containers = (
        "model.layers",  # LLaMA, Mistral, Qwen2, Gemma
        "model.decoder.layers",  # OPT
        "transformer.h",  # GPT-2, GPT-J, Bloom, Falcon
        "transformer.blocks",  # MPT
        "gpt_neox.layers",  # GPT-NeoX
        "layers",  # fallback
    )
    for path in known_containers:
        container = resolve_attr(model, path)
        if container is None:
            continue
        try:
            return list(container)
        except TypeError:
            # Attribute exists but is not iterable; try the next path.
            continue

    raise ValueError(
        "Could not locate transformer layers. Pass --layer_path explicitly."
    )
152
+
153
+
154
def guess_text_field(dataset) -> str:
    """Pick the column that most likely holds raw text.

    Resolution order:

    1. ``dataset.column_names`` when present and non-empty: ``"text"`` if
       listed, otherwise the first column.
    2. ``dataset.features`` when present and non-empty: ``"text"`` if it
       is a key, otherwise the first key.
    3. ``"text"`` as the last-resort default.

    Both attributes are read with a ``None`` default, so an object that
    exposes ``features = None`` falls through to the default instead of
    raising ``AttributeError`` on ``None.keys()``.
    """
    columns = getattr(dataset, "column_names", None)
    if columns:
        return "text" if "text" in columns else columns[0]

    features = getattr(dataset, "features", None)
    if features:
        names = list(features.keys())
        if "text" in names:
            return "text"
        if names:
            return names[0]

    return "text"
166
+
167
+
168
+ def _normalize_config(config: Optional[str]) -> Optional[str]:
169
+ if config is None:
170
+ return None
171
+ if config.strip().lower() in {"none", "null", "-"}:
172
+ return None
173
+ return config
174
+
175
+
176
+ def _expand_dataset_configs(
177
+ datasets: List[str], configs: List[str]
178
+ ) -> List[Optional[str]]:
179
+ if not configs:
180
+ return [None] * len(datasets)
181
+ if len(configs) == 1 and len(datasets) > 1:
182
+ return [_normalize_config(configs[0])] * len(datasets)
183
+ if len(configs) != len(datasets):
184
+ raise SystemExit(
185
+ "Provide zero, one, or matching-count --dataset_config values."
186
+ )
187
+ return [_normalize_config(cfg) for cfg in configs]
188
+
189
+
190
def _sample_dataset_rows(
    dataset, target: int, seed: int
) -> List[Dict[str, object]]:
    """Return up to ``target`` rows from ``dataset`` after a best-effort shuffle.

    Args:
        dataset: An iterable of rows. If it offers ``shuffle(seed=...)`` it is
            shuffled first; if it is sized and offers ``select(indices)`` the
            first ``target`` rows are taken through ``select``. Anything else
            (plain lists, generators, objects without ``select``) is simply
            iterated until enough rows are collected.
        target: Maximum number of rows to return; ``<= 0`` yields ``[]``.
        seed: Seed forwarded to ``dataset.shuffle``.

    Returns:
        A list with at most ``target`` rows.
    """
    if target <= 0:
        return []

    # Shuffling is deliberately best-effort: inputs that cannot be shuffled
    # this way are sampled in their native order instead of failing.
    try:
        dataset = dataset.shuffle(seed=seed)
    except Exception:
        pass

    # Only take the ``select`` route when the object really supports it; a
    # sized container without ``select`` falls through to plain iteration.
    select = getattr(dataset, "select", None)
    if hasattr(dataset, "__len__") and callable(select):
        limit = min(target, len(dataset))
        return list(select(range(limit)))

    # Generic iterable fallback: stop as soon as enough rows are collected.
    rows: List[Dict[str, object]] = []
    for row in dataset:
        rows.append(row)
        if len(rows) >= target:
            break
    return rows
212
+
213
+
214
def load_texts(args: argparse.Namespace) -> List[str]:
    """Collect calibration texts from a file, inline strings and HF datasets.

    Sources are read in this order and concatenated:
      1. ``args.text_file``: one sample per non-blank line (stripped).
      2. ``args.text``: non-empty inline strings.
      3. ``args.dataset``: ``args.num_samples`` rows split as evenly as
         possible across the listed datasets (earlier datasets receive the
         remainder); only non-blank string values of the text field are kept.

    Args:
        args: Parsed CLI namespace. Reads ``text_file``, ``text``, ``dataset``
            and, when datasets are given, ``dataset_config``, ``num_samples``,
            ``dataset_split``, ``dataset_text_field``, ``seed`` and the
            optional ``trust_remote_code`` flag.

    Returns:
        The list of collected text samples (possibly empty).

    Raises:
        SystemExit: if datasets were requested but ``datasets`` is not
            installed, or the dataset/config counts do not line up.
    """
    texts: List[str] = []
    if args.text_file:
        with open(args.text_file, "r", encoding="utf-8") as handle:
            texts.extend(line.strip() for line in handle if line.strip())
    if args.text:
        texts.extend(t for t in args.text if t)

    if not args.dataset:
        return texts
    if load_dataset is None:
        raise SystemExit("datasets is required for --dataset")

    datasets = list(args.dataset)
    configs = _expand_dataset_configs(datasets, list(args.dataset_config or []))
    base, remainder = divmod(args.num_samples, len(datasets))

    # Security: dataset repositories may ship arbitrary loading code. Only run
    # it when the caller explicitly opted in, instead of always allowing it.
    trust_remote_code = bool(getattr(args, "trust_remote_code", False))

    for idx, (dataset_name, config) in enumerate(zip(datasets, configs)):
        target = base + (1 if idx < remainder else 0)
        dataset = load_dataset(
            dataset_name,
            config,
            split=args.dataset_split,
            trust_remote_code=trust_remote_code,
        )
        # A per-dataset seed offset keeps the sampling deterministic while
        # giving each dataset its own shuffle order.
        rows = _sample_dataset_rows(dataset, target, args.seed + idx)
        text_field = args.dataset_text_field or guess_text_field(dataset)
        for row in rows:
            value = row.get(text_field, None) if isinstance(row, dict) else None
            if isinstance(value, str) and value.strip():
                texts.append(value)

    return texts
248
+
249
+
250
def build_token_chunks(
    texts: List[str], tokenizer, seq_len: int, num_samples: int
) -> List[torch.Tensor]:
    """Pack tokenized texts into fixed-length token sequences.

    Token ids of consecutive texts are concatenated into one running buffer
    (no separators, no special tokens) and cut into chunks of exactly
    ``seq_len`` ids. Leftover ids shorter than ``seq_len`` are carried over to
    the next text and dropped at the end.

    Args:
        texts: Raw text samples, consumed in order.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``
            that returns a list of token ids.
        seq_len: Length of every chunk; must be positive.
        num_samples: Maximum number of chunks to produce.

    Returns:
        Up to ``num_samples`` 1-D ``torch.long`` tensors of length ``seq_len``.

    Raises:
        ValueError: if ``seq_len`` is not positive.
    """
    if seq_len <= 0:
        raise ValueError("seq_len must be a positive integer")

    chunks: List[torch.Tensor] = []
    if num_samples <= 0:
        return chunks

    buffer: List[int] = []
    for text in texts:
        ids = tokenizer.encode(text, add_special_tokens=False)
        if not ids:
            continue
        buffer.extend(ids)

        # Walk the buffer with an offset and trim it once per text, rather
        # than re-slicing the whole buffer for every emitted chunk.
        start = 0
        while len(buffer) - start >= seq_len and len(chunks) < num_samples:
            chunks.append(
                torch.tensor(buffer[start:start + seq_len], dtype=torch.long)
            )
            start += seq_len
        if len(chunks) >= num_samples:
            break
        del buffer[:start]
    return chunks
267
+
268
+
269
def get_dtype(dtype: str):
    """Map a dtype name to a torch dtype.

    ``"auto"`` maps to ``None``; ``"float16"`` and ``"bfloat16"`` map to the
    matching torch dtypes; every other name falls back to ``torch.float32``.
    """
    named = {
        "auto": None,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    return named.get(dtype, torch.float32)
277
+
278
+
279
def compute_fisher(
    model,
    layers: List[torch.nn.Module],
    dataloader,
    fisher_mode: str,
    device: str,
) -> Tuple[List[Dict[str, object]], int, List[Dict[str, int]]]:
    """Accumulate squared gradients of the LM loss for every layer parameter.

    For each batch the model is run with ``labels=input_ids`` and
    ``outputs.loss`` is back-propagated; the squared gradients of the
    parameters in ``layers`` are summed over batches.

    Args:
        model: Model called as ``model(input_ids=..., labels=...)`` whose
            output exposes ``.loss``. It is switched to eval mode.
        layers: Modules whose parameters receive Fisher estimates.
        dataloader: Iterable of batches; ``batch[0]`` holds the input ids.
        fisher_mode: ``"param"`` keeps one float32 CPU tensor of squared
            gradients per parameter; any other value keeps a single scalar
            (sum of squared gradients) per parameter.
        device: Device the input ids are moved to.

    Returns:
        ``(fisher_sums, num_batches, param_numels)`` where ``fisher_sums[i]``
        and ``param_numels[i]`` are dicts keyed by the parameter names of
        ``layers[i]``. The sums are NOT divided by ``num_batches``.

    Raises:
        RuntimeError: if ``dataloader`` yields no batches.
    """
    # Snapshot every requires_grad flag BEFORE touching anything so the
    # caller's model is handed back in the state it arrived in.
    tracked = list(model.parameters())
    for layer in layers:
        tracked.extend(layer.parameters())
    saved_flags = [(param, param.requires_grad) for param in tracked]

    try:
        # Only compute grads for layer params.
        for param in model.parameters():
            param.requires_grad_(False)
        for layer in layers:
            for param in layer.parameters():
                param.requires_grad_(True)

        fisher_sums: List[Dict[str, object]] = []
        param_numels: List[Dict[str, int]] = []
        for layer in layers:
            layer_sums: Dict[str, object] = {}
            layer_numels: Dict[str, int] = {}
            for name, param in layer.named_parameters():
                if not param.requires_grad:
                    continue
                if fisher_mode == "param":
                    layer_sums[name] = torch.zeros_like(
                        param, dtype=torch.float32, device="cpu"
                    )
                else:
                    layer_sums[name] = 0.0
                layer_numels[name] = param.numel()
            fisher_sums.append(layer_sums)
            param_numels.append(layer_numels)

        num_batches = 0
        model.eval()
        # Drop gradients left over from earlier work so they cannot leak into
        # the first batch's accumulation.
        model.zero_grad(set_to_none=True)
        for batch in dataloader:
            input_ids = batch[0].to(device)
            loss = model(input_ids=input_ids, labels=input_ids).loss
            loss.backward()
            for layer_sums, layer in zip(fisher_sums, layers):
                for name, param in layer.named_parameters():
                    if not param.requires_grad or param.grad is None:
                        continue
                    grad_sq = param.grad.detach().float().pow(2)
                    if fisher_mode == "param":
                        layer_sums[name] += grad_sq.cpu()
                    else:
                        layer_sums[name] += float(grad_sq.sum().item())
            model.zero_grad(set_to_none=True)
            num_batches += 1

        if num_batches == 0:
            raise RuntimeError(
                "No batches processed; check dataset or text inputs."
            )

        return fisher_sums, num_batches, param_numels
    finally:
        for param, flag in saved_flags:
            param.requires_grad_(flag)
337
+
338
+
339
def compute_fbmc_costs(
    layers: List[torch.nn.Module],
    fisher_sums: List[Dict[str, object]],
    num_batches: int,
    param_numels: List[Dict[str, int]],
    fisher_mode: str,
    eps: float,
) -> List[Dict[str, object]]:
    """Score every adjacent layer pair with a Fisher-weighted merge cost.

    For each pair ``(i, i + 1)`` and every parameter present in both layers
    under the same name and shape, the cost adds::

        0.5 * sum( F_i * F_j / (F_i + F_j + eps) * (theta_i - theta_j) ** 2 )

    where ``F`` is the accumulated Fisher sum divided by ``num_batches``
    (and, outside ``"param"`` mode, also by the parameter's element count).

    Args:
        layers: Layers in model order.
        fisher_sums: Per-layer dicts of accumulated squared gradients; tensors
            in ``"param"`` mode, scalars otherwise.
        num_batches: Number of batches behind ``fisher_sums``.
        param_numels: Per-layer dicts of parameter element counts.
        fisher_mode: ``"param"`` for element-wise Fisher, anything else for
            one scalar per parameter tensor.
        eps: Stabilizer added to the denominator.

    Returns:
        One dict per adjacent pair with keys ``layer_i``, ``layer_j``,
        ``fbmc``, ``matched_params`` and ``skipped_params``.
    """
    layer_params: List[Dict[str, torch.nn.Parameter]] = [
        dict(layer.named_parameters()) for layer in layers
    ]

    results: List[Dict[str, object]] = []
    for idx in range(len(layers) - 1):
        cost = 0.0
        matched = 0
        skipped = 0
        params_j = layer_params[idx + 1]
        for name, param_i in layer_params[idx].items():
            param_j = params_j.get(name)
            if param_j is None or param_j.shape != param_i.shape:
                skipped += 1
                continue
            matched += 1
            if fisher_mode == "param":
                fisher_i = fisher_sums[idx][name] / num_batches
                fisher_j = fisher_sums[idx + 1][name] / num_batches
                diff = param_i.detach().float().cpu() - param_j.detach().float().cpu()
                denom = fisher_i + fisher_j + eps
                # Where the denominator is exactly zero both Fisher values are
                # zero as well, so the term must be 0 rather than 0/0 = NaN.
                safe_denom = torch.where(
                    denom > 0, denom, torch.ones_like(denom)
                )
                term = (fisher_i * fisher_j / safe_denom) * diff * diff
                cost += 0.5 * float(term.sum().item())
            else:
                fisher_i = fisher_sums[idx][name] / (
                    num_batches * param_numels[idx][name]
                )
                fisher_j = fisher_sums[idx + 1][name] / (
                    num_batches * param_numels[idx + 1][name]
                )
                denom = fisher_i + fisher_j + eps
                if denom == 0:
                    continue
                # Bring the second tensor onto the first one's device so the
                # subtraction also works when layers live on different devices.
                diff_sq = (
                    param_i.detach().float()
                    - param_j.detach().float().to(param_i.device)
                ).pow(2)
                cost += 0.5 * (fisher_i * fisher_j / denom) * float(
                    diff_sq.sum().item()
                )
        results.append(
            {
                "layer_i": idx,
                "layer_j": idx + 1,
                "fbmc": cost,
                "matched_params": matched,
                "skipped_params": skipped,
            }
        )
    return results
397
+
398
+
399
def main() -> None:
    """Run the FBMC analysis end to end from the command line.

    Loads the model and tokenizer, gathers calibration text, estimates Fisher
    information for every layer, scores each adjacent layer pair, prints the
    results (in layer order and sorted by cost) and optionally writes them to
    JSON and/or CSV.
    """
    args = parse_args()
    torch.manual_seed(args.seed)

    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=get_dtype(args.dtype),
        trust_remote_code=args.trust_remote_code,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.model, trust_remote_code=args.trust_remote_code
    )
    # Fall back to EOS as the padding token when the tokenizer defines none.
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    layers = find_layers(model, args.layer_path)
    if len(layers) < 2:
        raise SystemExit("Model has fewer than 2 layers; cannot compute FBMC.")

    texts = load_texts(args)
    if not texts:
        raise SystemExit(
            "No calibration text found. Provide --dataset, --text, or --text_file."
        )

    chunks = build_token_chunks(texts, tokenizer, args.seq_len, args.num_samples)
    if not chunks:
        raise SystemExit("Not enough text to build token sequences.")

    dataloader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(torch.stack(chunks)),
        batch_size=args.batch_size,
        shuffle=False,
    )

    model.to(args.device)

    fisher_sums, num_batches, param_numels = compute_fisher(
        model,
        layers,
        dataloader,
        fisher_mode=args.fisher_mode,
        device=args.device,
    )
    costs = compute_fbmc_costs(
        layers,
        fisher_sums,
        num_batches,
        param_numels,
        fisher_mode=args.fisher_mode,
        eps=args.eps,
    )

    costs_sorted = sorted(costs, key=lambda pair: pair["fbmc"])
    best = costs_sorted[0]

    def describe(pair) -> str:
        # One report line per layer pair, shared by both listings below.
        return (
            f"layers {pair['layer_i']} & {pair['layer_j']} -> "
            f"fbmc={pair['fbmc']:.6e} "
            f"(matched={pair['matched_params']}, skipped={pair['skipped_params']})"
        )

    print("FBMC results (layer order):")
    for pair in costs:
        print(describe(pair))
    print("\nFBMC results (lowest cost first):")
    for pair in costs_sorted:
        print(describe(pair))
    print(
        f"\nBest pair: layers {best['layer_i']} & {best['layer_j']} "
        f"(fbmc={best['fbmc']:.6e})"
    )

    if args.output:
        payload = {
            "model": args.model,
            "num_layers": len(layers),
            "fisher_mode": args.fisher_mode,
            "num_batches": num_batches,
            "num_sequences": len(chunks),
            "seq_len": args.seq_len,
            "best_pair": best,
            "pairs": costs_sorted,
        }
        # ``or "."`` covers output paths without a directory component.
        os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
        with open(args.output, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, indent=2)
        print(f"\nWrote results to {args.output}")

    if args.output_csv:
        columns = [
            "layer_i",
            "layer_j",
            "fbmc",
            "matched_params",
            "skipped_params",
        ]
        os.makedirs(os.path.dirname(args.output_csv) or ".", exist_ok=True)
        with open(args.output_csv, "w", encoding="utf-8", newline="") as handle:
            writer = csv.DictWriter(handle, fieldnames=columns)
            writer.writeheader()
            for pair in costs_sorted:
                writer.writerow({column: pair[column] for column in columns})
        print(f"Wrote CSV results to {args.output_csv}")


if __name__ == "__main__":
    main()
src/fuse_layers.py ADDED
@@ -0,0 +1,2416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Fuse adjacent layers via attention head alignment + Fisher-barycentric merge."""
3
+
4
+ import argparse
5
+ import copy
6
+ import gc
7
+ import json
8
+ import os
9
+ import random
10
+ from dataclasses import dataclass
11
+ from typing import Dict, List, Optional, Set, Tuple
12
+
13
+ import torch
14
+
15
+ try:
16
+ import numpy as np
17
+ except Exception: # pragma: no cover - optional dependency
18
+ np = None
19
+
20
+ try:
21
+ from transformers import AutoModelForCausalLM, AutoTokenizer
22
+ except Exception as exc: # pragma: no cover - fail early with clear error
23
+ raise SystemExit("transformers is required: pip install transformers") from exc
24
+
25
+ try:
26
+ import ppl_eval
27
+ except Exception as exc: # pragma: no cover - optional dependency
28
+ raise SystemExit("ppl_eval.py is required (missing or invalid)") from exc
29
+
30
+ from fuse_layers_data import (
31
+ FixedSeqDataset,
32
+ build_token_chunks,
33
+ expand_dataset_configs,
34
+ load_instruction_records,
35
+ load_texts,
36
+ load_texts_from_datasets,
37
+ )
38
+ from common_lm_data import SharedLMDataSpec, build_chunks, build_dataloader
39
+ from fuse_layers_distill import (
40
+ commutator_precondition,
41
+ compute_fisher_gate_priors,
42
+ distill_reparam_merge,
43
+ lora_ce_finetune,
44
+ )
45
+ from fuse_layers_model import (
46
+ apply_norm_policy,
47
+ build_head_permutation,
48
+ clone_state_dict,
49
+ compute_fisher,
50
+ compute_head_means,
51
+ decrement_config,
52
+ drop_layer,
53
+ find_attention_module,
54
+ find_colon_modules,
55
+ find_layer_container,
56
+ get_dtype,
57
+ get_norm_pair,
58
+ merge_layers,
59
+ permute_attention_heads,
60
+ )
61
+ from fuse_layers_select import select_layer_auto
62
+ from progressive_loader import load_causal_lm, load_progressive_model
63
+
64
+
65
+ def parse_args() -> argparse.Namespace:
66
+ parser = argparse.ArgumentParser(
67
+ description="Fuse layer i and i+1 using head alignment + Fisher barycenter."
68
+ )
69
+ parser.add_argument("--model", required=True, help="HF model id or local path")
70
+ parser.add_argument(
71
+ "--model_cache_dir",
72
+ default=None,
73
+ help="Optional cache dir for model/tokenizer downloads",
74
+ )
75
+ parser.add_argument(
76
+ "--layer",
77
+ type=str,
78
+ default="auto",
79
+ help="Layer index i (int) or 'auto' to select via auto metric",
80
+ )
81
+ parser.add_argument(
82
+ "--selection_method",
83
+ choices=["dwce", "sequential"],
84
+ default="dwce",
85
+ help=(
86
+ "Pair selection policy for progressive pruning. "
87
+ "'dwce' uses downstream-weighted composition error; "
88
+ "'sequential' always takes the next available pair."
89
+ ),
90
+ )
91
+ parser.add_argument(
92
+ "--exclude_pairs",
93
+ "--exclude_layers",
94
+ nargs="*",
95
+ default=None,
96
+ dest="exclude_pairs",
97
+ help=(
98
+ "Exclude pair indices from consideration for any fusion. Indices refer to "
99
+ "pair start positions in [0..N-2]. Negative indices count from the end "
100
+ "(-1 = last pair, -2 = second last). Accepts space- or comma-separated ints. "
101
+ "Alias: --exclude_layers (deprecated)."
102
+ ),
103
+ )
104
+ parser.add_argument(
105
+ "--output_dir", required=True, help="Directory to write fused model"
106
+ )
107
+ parser.add_argument(
108
+ "--dataset",
109
+ action="append",
110
+ default=[],
111
+ help=(
112
+ "HF dataset name (repeatable). Optional if using --text or --text_file."
113
+ ),
114
+ )
115
+ parser.add_argument(
116
+ "--dataset_config",
117
+ action="append",
118
+ default=[],
119
+ help="Optional dataset config (repeatable or single shared config).",
120
+ )
121
+ parser.add_argument(
122
+ "--dataset_split",
123
+ default="train",
124
+ help="Dataset split to use (default: train)",
125
+ )
126
+ parser.add_argument(
127
+ "--dataset_text_field",
128
+ default=None,
129
+ help="Text field in dataset (default: auto-detect, applies to all datasets)",
130
+ )
131
+ parser.add_argument(
132
+ "--text",
133
+ action="append",
134
+ default=[],
135
+ help="Inline text samples (can pass multiple)",
136
+ )
137
+ parser.add_argument(
138
+ "--text_file",
139
+ default=None,
140
+ help="Path to a text file for calibration data",
141
+ )
142
+ parser.add_argument(
143
+ "--num_samples",
144
+ type=int,
145
+ default=128,
146
+ help="Number of token sequences to use",
147
+ )
148
+ parser.add_argument(
149
+ "--target_tokens",
150
+ type=int,
151
+ default=0,
152
+ help="Target token budget for common_lm_data-backed calibration/distillation (0 = disabled)",
153
+ )
154
+ parser.add_argument("--seq_len", type=int, default=256, help="Sequence length")
155
+ parser.add_argument("--batch_size", type=int, default=2, help="Batch size")
156
+ parser.add_argument(
157
+ "--device",
158
+ default="cuda" if torch.cuda.is_available() else "cpu",
159
+ help="Device for model + compute",
160
+ )
161
+ parser.add_argument(
162
+ "--dtype",
163
+ default="auto",
164
+ choices=["auto", "float32", "float16", "bfloat16"],
165
+ help="Model dtype",
166
+ )
167
+ parser.add_argument(
168
+ "--layer_path",
169
+ default=None,
170
+ help="Override layer attribute path (e.g., model.layers)",
171
+ )
172
+ parser.add_argument(
173
+ "--fisher_mode",
174
+ default="tensor",
175
+ choices=["tensor", "param"],
176
+ help="Fisher approximation granularity",
177
+ )
178
+ parser.add_argument(
179
+ "--no_head_permute",
180
+ action="store_true",
181
+ help=(
182
+ "Deprecated alias for --no_head_permute_merge. "
183
+ "Disables merge-stage head permutation only."
184
+ ),
185
+ )
186
+ parser.add_argument(
187
+ "--no_head_permute_merge",
188
+ action="store_true",
189
+ help="Disable attention head permutation alignment before merge",
190
+ )
191
+ parser.add_argument(
192
+ "--no_head_permute_select",
193
+ action="store_true",
194
+ help="Disable attention head permutation alignment during auto selection",
195
+ )
196
+ parser.add_argument("--eps", type=float, default=1e-8, help="Stability epsilon")
197
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
198
+ parser.add_argument(
199
+ "--trust_remote_code",
200
+ action="store_true",
201
+ help="Allow custom model code from hub",
202
+ )
203
+ parser.add_argument(
204
+ "--save_metadata",
205
+ action="store_true",
206
+ help="Backward-compatible no-op; metadata is always written.",
207
+ )
208
+ parser.add_argument(
209
+ "--skip_eval",
210
+ action="store_true",
211
+ help="Skip pre/post perplexity evaluation",
212
+ )
213
+ parser.add_argument(
214
+ "--eval_dataset",
215
+ action="append",
216
+ default=[],
217
+ help="Evaluation dataset name (repeatable). Defaults to wikitext.",
218
+ )
219
+ parser.add_argument(
220
+ "--eval_dataset_config",
221
+ action="append",
222
+ default=[],
223
+ help="Evaluation dataset config (repeatable or single shared config).",
224
+ )
225
+ parser.add_argument(
226
+ "--eval_split",
227
+ default="test",
228
+ help="Evaluation dataset split (default: test)",
229
+ )
230
+ parser.add_argument(
231
+ "--eval_text_field",
232
+ default=None,
233
+ help="Evaluation text field override (default: auto-detect)",
234
+ )
235
+ parser.add_argument(
236
+ "--eval_model_family",
237
+ type=str,
238
+ choices=["auto", "llama", "qwen"],
239
+ default="auto",
240
+ help="Model family for BOS handling during eval",
241
+ )
242
+ parser.add_argument(
243
+ "--eval_add_bos",
244
+ type=str,
245
+ choices=["auto", "always", "never"],
246
+ default="auto",
247
+ help="Whether to prepend BOS to each eval sample",
248
+ )
249
+ parser.add_argument(
250
+ "--eval_num_samples",
251
+ type=int,
252
+ default=0,
253
+ help="Number of token sequences per eval dataset (0 = all)",
254
+ )
255
+ parser.add_argument(
256
+ "--eval_seq_len",
257
+ type=int,
258
+ default=2048,
259
+ help="Sequence length for eval",
260
+ )
261
+ parser.add_argument(
262
+ "--eval_batch_size",
263
+ type=int,
264
+ default=None,
265
+ help="Batch size for eval (defaults to --batch_size)",
266
+ )
267
+ parser.add_argument(
268
+ "--eval_max_batches",
269
+ type=int,
270
+ default=None,
271
+ help="Optional max number of eval batches per dataset",
272
+ )
273
+ parser.add_argument(
274
+ "--eval_cache_dir",
275
+ default=None,
276
+ help="Optional datasets cache dir for eval",
277
+ )
278
+ parser.add_argument(
279
+ "--eval_num_workers",
280
+ type=int,
281
+ default=0,
282
+ help="Eval DataLoader workers",
283
+ )
284
+ parser.add_argument(
285
+ "--eval_device",
286
+ default=None,
287
+ help="Device for eval (defaults to --device)",
288
+ )
289
+ parser.add_argument(
290
+ "--skip_distill",
291
+ action="store_true",
292
+ help="Skip reparameterized distillation after head alignment/Fisher setup",
293
+ )
294
+ parser.add_argument(
295
+ "--distill_calib_samples",
296
+ type=int,
297
+ default=256,
298
+ help="Number of distillation sequences from calibration datasets",
299
+ )
300
+ parser.add_argument(
301
+ "--distill_inst_samples",
302
+ type=int,
303
+ default=0,
304
+ help="Number of distillation sequences from instruction dataset (0 = all)",
305
+ )
306
+ parser.add_argument(
307
+ "--distill_seq_len",
308
+ type=int,
309
+ default=512,
310
+ help="Sequence length for distillation",
311
+ )
312
+ parser.add_argument(
313
+ "--distill_batch_size",
314
+ type=int,
315
+ default=2,
316
+ help="Batch size for distillation",
317
+ )
318
+ parser.add_argument(
319
+ "--distill_epochs",
320
+ type=float,
321
+ default=1.0,
322
+ help="Number of distillation epochs (float allowed, e.g. 0.5)",
323
+ )
324
+ parser.add_argument(
325
+ "--distill_lr",
326
+ type=float,
327
+ default=1e-4,
328
+ help="Learning rate for distillation",
329
+ )
330
+ parser.add_argument(
331
+ "--distill_method",
332
+ choices=["reparam"],
333
+ default="reparam",
334
+ help="Distillation strategy (reparam only).",
335
+ )
336
+ parser.add_argument(
337
+ "--distill_kl_weight",
338
+ type=float,
339
+ default=1e-2,
340
+ help="Weight for KL loss on logits",
341
+ )
342
+ parser.add_argument(
343
+ "--distill_kl_temp",
344
+ type=float,
345
+ default=4.0,
346
+ help="Temperature for KL distillation on logits",
347
+ )
348
+ parser.add_argument(
349
+ "--distill_hidden_mse_weight",
350
+ type=float,
351
+ default=1.0,
352
+ help="Weight for hidden-state MSE in reparam distillation (0 disables it)",
353
+ )
354
+ parser.add_argument(
355
+ "--distill_attn_mse_weight",
356
+ type=float,
357
+ default=0.0,
358
+ help="Weight for auxiliary attention-output MSE in reparam distillation",
359
+ )
360
+ parser.add_argument(
361
+ "--distill_mlp_mse_weight",
362
+ type=float,
363
+ default=0.0,
364
+ help="Weight for auxiliary MLP-output MSE in reparam distillation",
365
+ )
366
+ parser.add_argument(
367
+ "--reparam_eta",
368
+ type=float,
369
+ default=1e-2,
370
+ help="Eta: ||lambda - lambda_gate||^2 regularizer weight for --distill_method reparam",
371
+ )
372
+ parser.add_argument(
373
+ "--reparam_gamma",
374
+ type=float,
375
+ default=1e-4,
376
+ help="Gamma: ||U - U0||^2 regularizer weight for --distill_method reparam",
377
+ )
378
+ parser.add_argument(
379
+ "--reparam_attn_reg_scale",
380
+ type=float,
381
+ default=1.0,
382
+ help="Relative scale applied to attention-parameter reparam regularizers",
383
+ )
384
+ parser.add_argument(
385
+ "--reparam_mlp_reg_scale",
386
+ type=float,
387
+ default=1.0,
388
+ help="Relative scale applied to MLP-parameter reparam regularizers",
389
+ )
390
+ parser.add_argument(
391
+ "--reparam_param_subset",
392
+ type=str,
393
+ choices=["all", "mlp", "attn"],
394
+ default="all",
395
+ help="Restrict reparam merge/recovery capacity to only this parameter family",
396
+ )
397
+ parser.add_argument(
398
+ "--norm_policy",
399
+ type=str,
400
+ choices=["hybrid", "merge_all", "copy_n1", "copy_n1_n2"],
401
+ default="hybrid",
402
+ help="Norm merge policy (default: hybrid)",
403
+ )
404
+ parser.add_argument(
405
+ "--distill_weight_decay",
406
+ type=float,
407
+ default=0.0,
408
+ help="Weight decay for distillation",
409
+ )
410
+ parser.add_argument(
411
+ "--distill_max_grad_norm",
412
+ type=float,
413
+ default=1.0,
414
+ help="Max grad norm for distillation",
415
+ )
416
+ parser.add_argument(
417
+ "--distill_grad_accum_steps",
418
+ type=int,
419
+ default=1,
420
+ help="Gradient accumulation steps for distillation",
421
+ )
422
+ parser.add_argument(
423
+ "--distill_log_steps",
424
+ type=int,
425
+ default=100,
426
+ help="Log distillation loss every N steps",
427
+ )
428
+ parser.add_argument(
429
+ "--distill_eval_every",
430
+ type=int,
431
+ default=0,
432
+ help="Evaluate PPL every N distill steps (0 = disable)",
433
+ )
434
+ parser.add_argument(
435
+ "--distill_eval_max_batches",
436
+ type=int,
437
+ default=None,
438
+ help="Max eval batches per dataset during distill (default: all)",
439
+ )
440
+ parser.add_argument(
441
+ "--distill_teacher_device",
442
+ default=None,
443
+ help="Device for teacher model during distillation (defaults to --device)",
444
+ )
445
+ parser.add_argument(
446
+ "--comm_enabled",
447
+ action="store_true",
448
+ help=(
449
+ "Enable commutator-style preconditioning before each progressive "
450
+ "cycle's fusion."
451
+ ),
452
+ )
453
+ parser.add_argument(
454
+ "--comm_include_cycle1",
455
+ action="store_true",
456
+ help="Run commutator preconditioning for cycle 1 as well (default: skip cycle 1).",
457
+ )
458
+ parser.add_argument(
459
+ "--comm_topk",
460
+ type=int,
461
+ default=1,
462
+ help="Top-K lowest-score pairs used as the commutator candidate set",
463
+ )
464
+ parser.add_argument(
465
+ "--comm_sample_eta",
466
+ type=float,
467
+ default=0.5,
468
+ help="Mixture weight between uniform and score-biased candidate sampling",
469
+ )
470
+ parser.add_argument(
471
+ "--comm_sample_dwce_scale",
472
+ type=float,
473
+ default=1.0,
474
+ help="Scale c in softmax(-c * score(i)) for commutator pair sampling",
475
+ )
476
+ parser.add_argument(
477
+ "--comm_temp",
478
+ type=float,
479
+ default=2.0,
480
+ help="Temperature for teacher-anchor KL in commutator preconditioning",
481
+ )
482
+ parser.add_argument(
483
+ "--comm_steps_ratio",
484
+ type=float,
485
+ default=0.1,
486
+ help="Run this fraction of distillation optimizer steps for commutator phase",
487
+ )
488
+ parser.add_argument(
489
+ "--comm_lr_scale",
490
+ type=float,
491
+ default=0.1,
492
+ help="Commutator LR = --distill_lr * this scale",
493
+ )
494
+ parser.add_argument(
495
+ "--comm_train_mode",
496
+ choices=["lora", "full"],
497
+ default="lora",
498
+ help=(
499
+ "Commutator trainable parameter mode: "
500
+ "'lora' updates LoRA adapters on sampled receiver layers; "
501
+ "'full' updates full receiver-layer weights."
502
+ ),
503
+ )
504
+ parser.add_argument(
505
+ "--comm_interaction_mode",
506
+ choices=["mse", "relative"],
507
+ default="relative",
508
+ help="Interaction loss form: plain MSE or relative MSE",
509
+ )
510
+ parser.add_argument(
511
+ "--comm_interaction_eps",
512
+ type=float,
513
+ default=1e-8,
514
+ help="Epsilon for relative commutator interaction normalization",
515
+ )
516
+ parser.add_argument(
517
+ "--comm_mu",
518
+ type=float,
519
+ default=None,
520
+ help=(
521
+ "Weight for interaction loss. Defaults to 0.1 for --comm_interaction_mode=mse "
522
+ "and 0.5 for --comm_interaction_mode=relative."
523
+ ),
524
+ )
525
+ parser.add_argument(
526
+ "--comm_mu_auto",
527
+ action="store_true",
528
+ help="Enable automatic mu scaling via gradient-norm balancing",
529
+ )
530
+ parser.add_argument(
531
+ "--comm_mu_auto_rho",
532
+ type=float,
533
+ default=0.1,
534
+ help="Target anchor-to-interaction gradient ratio constant for auto-mu",
535
+ )
536
+ parser.add_argument(
537
+ "--comm_mu_auto_eps",
538
+ type=float,
539
+ default=1e-8,
540
+ help="Numerical epsilon in auto-mu denominator",
541
+ )
542
+ parser.add_argument(
543
+ "--comm_log_steps",
544
+ type=int,
545
+ default=50,
546
+ help="Log commutator preconditioning loss every N optimizer steps",
547
+ )
548
+ parser.add_argument(
549
+ "--comm_skip_post_reselect",
550
+ action="store_true",
551
+ help=(
552
+ "Keep the pre-comm selected fusion pair and skip recomputing "
553
+ "selection after commutator preconditioning."
554
+ ),
555
+ )
556
+ parser.add_argument(
557
+ "--redistrib_teacher_source",
558
+ type=str,
559
+ choices=["base_model", "previous_cycle"],
560
+ default="base_model",
561
+ help=(
562
+ "Teacher source for commutator preconditioning teacher loading. "
563
+ "'base_model' uses --model for all cycles; "
564
+ "'previous_cycle' uses cycle-1 checkpoint (cycle 1 falls back to base_model)."
565
+ ),
566
+ )
567
+ parser.add_argument(
568
+ "--lora_epochs",
569
+ type=float,
570
+ default=1.0,
571
+ help="LoRA CE finetuning epochs after distill (0 = disable)",
572
+ )
573
+ parser.add_argument(
574
+ "--lora_rank",
575
+ type=int,
576
+ default=8,
577
+ help="LoRA rank (r)",
578
+ )
579
+ parser.add_argument(
580
+ "--lora_alpha",
581
+ type=float,
582
+ default=16.0,
583
+ help="LoRA alpha",
584
+ )
585
+ parser.add_argument(
586
+ "--lora_dropout",
587
+ type=float,
588
+ default=0.0,
589
+ help="LoRA dropout",
590
+ )
591
+ parser.add_argument(
592
+ "--lora_kl_enabled",
593
+ action="store_true",
594
+ help="Add KL regularization between pre/post LoRA logits",
595
+ )
596
+ parser.add_argument(
597
+ "--lora_kl_weight",
598
+ type=float,
599
+ default=1e-1,
600
+ help="KL weight for LoRA regularization",
601
+ )
602
+ parser.add_argument(
603
+ "--lora_kl_temp",
604
+ type=float,
605
+ default=4.0,
606
+ help="Temperature for LoRA KL regularization",
607
+ )
608
+ parser.add_argument(
609
+ "--lora_target_modules",
610
+ nargs="*",
611
+ default=[
612
+ "q_proj",
613
+ "k_proj",
614
+ "v_proj",
615
+ "o_proj",
616
+ "gate_proj",
617
+ "down_proj",
618
+ "up_proj",
619
+ ],
620
+ help="Module name suffixes to LoRA-wrap",
621
+ )
622
+ parser.add_argument(
623
+ "--lora_respect_exclude_pairs",
624
+ action="store_true",
625
+ help=(
626
+ "When attaching LoRA adapters, skip linear modules under layers touched by "
627
+ "--exclude_pairs (i and i+1 for each excluded pair)."
628
+ ),
629
+ )
630
+ parser.add_argument(
631
+ "--lora_lr",
632
+ type=float,
633
+ default=1e-4,
634
+ help="Learning rate for LoRA finetuning",
635
+ )
636
+ parser.add_argument(
637
+ "--lora_weight_decay",
638
+ type=float,
639
+ default=0.0,
640
+ help="Weight decay for LoRA finetuning",
641
+ )
642
+ parser.add_argument(
643
+ "--lora_max_grad_norm",
644
+ type=float,
645
+ default=1.0,
646
+ help="Max grad norm for LoRA finetuning",
647
+ )
648
+ parser.add_argument(
649
+ "--lora_grad_accum_steps",
650
+ type=int,
651
+ default=1,
652
+ help="Gradient accumulation steps for LoRA finetuning",
653
+ )
654
+ parser.add_argument(
655
+ "--lora_log_steps",
656
+ type=int,
657
+ default=100,
658
+ help="Log LoRA loss every N steps",
659
+ )
660
+ parser.add_argument(
661
+ "--lora_eval_every",
662
+ type=int,
663
+ default=0,
664
+ help="Evaluate PPL every N LoRA steps (0 = disable)",
665
+ )
666
+ parser.add_argument(
667
+ "--lora_eval_max_batches",
668
+ type=int,
669
+ default=None,
670
+ help="Max eval batches per dataset during LoRA (default: all)",
671
+ )
672
+ parser.add_argument(
673
+ "--instruction_dataset",
674
+ default=None,
675
+ help="HF dataset name for alpaca-style instruction data",
676
+ )
677
+ parser.add_argument(
678
+ "--instruction_config",
679
+ default=None,
680
+ help="Optional instruction dataset config",
681
+ )
682
+ parser.add_argument(
683
+ "--instruction_split",
684
+ default="train",
685
+ help="Instruction dataset split",
686
+ )
687
+ parser.add_argument(
688
+ "--instruction_field_instruction",
689
+ default="instruction",
690
+ help="Instruction field name",
691
+ )
692
+ parser.add_argument(
693
+ "--instruction_field_input",
694
+ default="input",
695
+ help="Optional input field name",
696
+ )
697
+ parser.add_argument(
698
+ "--instruction_field_output",
699
+ default="output",
700
+ help="Response/output field name",
701
+ )
702
+ parser.add_argument(
703
+ "--auto_max_batches",
704
+ type=int,
705
+ default=0,
706
+ help="Max calibration batches for auto selection scoring (0 = all)",
707
+ )
708
+ parser.add_argument(
709
+ "--auto_metric",
710
+ type=str,
711
+ choices=[
712
+ "dwce",
713
+ "cosine",
714
+ "hybrid",
715
+ "hybrid_cosine",
716
+ "hybrid_global_rel",
717
+ ],
718
+ default="dwce",
719
+ help=(
720
+ "Auto pair scoring metric. 'dwce' uses downstream-weighted composition error; "
721
+ "'cosine' uses average token-level cosine distance between adjacent layer outputs; "
722
+ "'hybrid'/'hybrid_cosine' use DWCE to shortlist then adjacent cosine for final scoring; "
723
+ "'hybrid_global_rel' uses DWCE to shortlist then reranks by the change in "
724
+ "pair-to-final-layer cosine relation after surrogate fusion."
725
+ ),
726
+ )
727
+ parser.add_argument(
728
+ "--auto_cosine_topk",
729
+ type=int,
730
+ default=3,
731
+ help="Top-K DWCE candidates to rescore with cosine in --auto_metric=hybrid",
732
+ )
733
+ parser.add_argument(
734
+ "--auto_norm",
735
+ type=str,
736
+ choices=["relative", "none"],
737
+ default="relative",
738
+ help="Normalization mode for DWCE scoring (ignored for cosine)",
739
+ )
740
+ parser.add_argument(
741
+ "--auto_dwce_mode",
742
+ type=str,
743
+ choices=["separate", "shared"],
744
+ default="separate",
745
+ help=(
746
+ "DWCE implementation for auto scoring. "
747
+ "'separate' runs distinct Fisher and DWCE backward passes; "
748
+ "'shared' reuses one backward pass and replays DWCE with cached gradients."
749
+ ),
750
+ )
751
+ parser.add_argument(
752
+ "--num_progressive",
753
+ type=int,
754
+ default=0,
755
+ help="Number of progressive fusions (>0 required)",
756
+ )
757
+ parser.add_argument(
758
+ "--resume_from_cycle",
759
+ type=int,
760
+ default=0,
761
+ help=(
762
+ "Resume from this completed cycle index. When > 0, --model should point "
763
+ "to the saved full model directory for that cycle."
764
+ ),
765
+ )
766
+ parser.add_argument(
767
+ "--save_full_model_cycles",
768
+ nargs="*",
769
+ default=[],
770
+ help=(
771
+ "Cycle indices whose full models should be saved. Requesting cycle c "
772
+ "also saves cycle c-1 automatically (c=1 saves only cycle 1)."
773
+ ),
774
+ )
775
+ return parser.parse_args()
776
+
777
+
778
def parse_exclude_pairs(exclude_raw: Optional[List[str]], num_pairs: int) -> List[int]:
    """Turn raw --exclude_pairs values into sorted, de-duplicated pair indices.

    Every raw item may hold several comma-separated integers. An index names
    the first layer of an adjacent pair (i, i+1). Negative indices count from
    the end (-1 = last pair). Indices that fall outside ``[0, num_pairs)``
    after normalization are dropped without an error.

    Args:
        exclude_raw: Raw command-line values, or None / empty for "no exclusions".
        num_pairs: Number of adjacent layer pairs in the current model.

    Returns:
        Ascending list of unique, in-range pair indices.

    Raises:
        SystemExit: If a non-empty token cannot be parsed as an integer.
    """
    if not exclude_raw:
        return []

    # Flatten "a,b" style items into individual stripped tokens.
    tokens = (
        token.strip()
        for raw in exclude_raw
        if raw is not None
        for token in str(raw).split(",")
    )

    kept = set()
    for token in tokens:
        if not token:
            continue
        try:
            value = int(token)
        except ValueError as exc:
            raise SystemExit("--exclude_pairs must contain integers.") from exc
        normalized = value + num_pairs if value < 0 else value
        if 0 <= normalized < num_pairs:
            kept.add(normalized)
    return sorted(kept)
803
+
804
+
805
def parse_cycle_list(raw_values: Optional[List[str]]) -> List[int]:
    """Parse --save_full_model_cycles values into a list of integers.

    Items may be comma-separated; blank tokens and None items are ignored.
    Order and duplicates are preserved exactly as given.

    Raises:
        SystemExit: If a non-empty token cannot be parsed as an integer.
    """
    if not raw_values:
        return []

    tokens = [
        token.strip()
        for raw in raw_values
        if raw is not None
        for token in str(raw).split(",")
    ]
    try:
        return [int(token) for token in tokens if token]
    except ValueError as exc:
        raise SystemExit(
            "--save_full_model_cycles must contain integers."
        ) from exc
823
+
824
+
825
def resolve_full_model_save_cycles(
    requested_cycles: List[int], num_progressive: int
) -> Set[int]:
    """Expand requested save cycles to include each one's predecessor.

    Requesting cycle ``c`` also schedules cycle ``c - 1`` (cycle 1 has no
    predecessor, so it only schedules itself).

    Raises:
        SystemExit: If any requested cycle is outside ``[1, num_progressive]``.
    """
    if any(cycle <= 0 or cycle > num_progressive for cycle in requested_cycles):
        raise SystemExit(
            "--save_full_model_cycles entries must be within [1, --num_progressive]."
        )

    scheduled: Set[int] = set(requested_cycles)
    scheduled.update(cycle - 1 for cycle in requested_cycles if cycle > 1)
    return scheduled
838
+
839
+
840
def load_resume_metadata(model_path: str) -> Optional[Dict[str, object]]:
    """Read ``resume_info.json`` from *model_path*, if present.

    Returns:
        The decoded JSON object when the file exists and holds a dict;
        otherwise None (file missing, or top-level JSON value is not an object).
    """
    meta_file = os.path.join(model_path, "resume_info.json")
    if os.path.exists(meta_file):
        with open(meta_file, "r", encoding="utf-8") as stream:
            payload = json.load(stream)
        if isinstance(payload, dict):
            return payload
    return None
847
+
848
+
849
def build_generator(seed: int) -> torch.Generator:
    """Return a CPU ``torch.Generator`` seeded with ``int(seed)``."""
    # manual_seed returns the generator itself, so the calls can be chained.
    return torch.Generator(device="cpu").manual_seed(int(seed))
853
+
854
+
855
def capture_rng_state() -> Dict[str, object]:
    """Snapshot the global RNG state of every random source in use.

    Always records Python's ``random`` and the torch CPU generator. NumPy is
    recorded only when the module-level ``np`` is available, and CUDA
    generator states only when CUDA is available.
    """
    snapshot: Dict[str, object] = {}
    snapshot["python_random_state"] = random.getstate()
    snapshot["torch_cpu_rng_state"] = torch.get_rng_state()
    if np is not None:
        snapshot["numpy_random_state"] = np.random.get_state()
    if torch.cuda.is_available():
        snapshot["torch_cuda_rng_state_all"] = torch.cuda.get_rng_state_all()
    return snapshot
865
+
866
+
867
def restore_rng_state(state: Dict[str, object]) -> None:
    """Re-apply RNG states previously produced by ``capture_rng_state``.

    Each entry is optional: missing (or None) entries are skipped. The NumPy
    state is applied only when the module-level ``np`` is available, and the
    CUDA states only when CUDA is available.
    """
    saved_python = state.get("python_random_state")
    saved_numpy = state.get("numpy_random_state")
    saved_torch_cpu = state.get("torch_cpu_rng_state")
    saved_torch_cuda = state.get("torch_cuda_rng_state_all")

    if saved_python is not None:
        random.setstate(saved_python)
    if np is not None and saved_numpy is not None:
        np.random.set_state(saved_numpy)
    if saved_torch_cpu is not None:
        torch.set_rng_state(saved_torch_cpu)
    if saved_torch_cuda is not None and torch.cuda.is_available():
        torch.cuda.set_rng_state_all(saved_torch_cuda)
883
+
884
+
885
def save_rng_state(path: str) -> None:
    """Serialize the current global RNG snapshot to *path* with ``torch.save``."""
    snapshot = capture_rng_state()
    torch.save(snapshot, path)
887
+
888
+
889
def load_rng_state(path: str) -> Optional[Dict[str, object]]:
    """Load an RNG snapshot from *path*.

    Returns:
        The stored dict, or None when the file does not exist or the stored
        object is not a dict.
    """
    if os.path.exists(path):
        # weights_only=False: the snapshot is loaded with full unpickling
        # rather than the restricted tensor-only loader.
        payload = torch.load(path, map_location="cpu", weights_only=False)
        if isinstance(payload, dict):
            return payload
    return None
894
+
895
+
896
def configure_reproducibility(seed: int) -> None:
    """Seed every RNG in use and switch torch to deterministic behaviour.

    Seeds Python ``random``, torch (CPU and, when available, all CUDA
    devices) and NumPy when the module-level ``np`` is available. Also sets
    cuDNN to deterministic / non-benchmark mode and asks torch for
    deterministic algorithms in warn-only mode.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    if np is not None:
        np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    cudnn = getattr(torch.backends, "cudnn", None)
    if cudnn is not None:
        cudnn.deterministic = True
        cudnn.benchmark = False

    # Guarded lookup: the call is skipped if this torch build lacks the API.
    enable_deterministic = getattr(torch, "use_deterministic_algorithms", None)
    if enable_deterministic is not None:
        enable_deterministic(True, warn_only=True)
908
+
909
+
910
def save_loader_generator_state(
    base_dir: str,
    *,
    distill_generator: Optional[torch.Generator] = None,
    lora_generator: Optional[torch.Generator] = None,
) -> None:
    """Persist the states of the data-loader generators under *base_dir*.

    Writes ``loader_generators.pt`` containing one entry per generator that
    was supplied. Nothing is written when both generators are None.
    """
    named_generators = (
        ("distill_generator_state", distill_generator),
        ("lora_generator_state", lora_generator),
    )
    payload: Dict[str, object] = {
        key: generator.get_state()
        for key, generator in named_generators
        if generator is not None
    }
    if not payload:
        return
    torch.save(payload, os.path.join(base_dir, "loader_generators.pt"))
923
+
924
+
925
def load_loader_generator_state(base_dir: str) -> Optional[Dict[str, object]]:
    """Load ``loader_generators.pt`` from *base_dir*.

    Returns:
        The stored dict, or None when the file is absent or does not hold a dict.
    """
    state_file = os.path.join(base_dir, "loader_generators.pt")
    if os.path.exists(state_file):
        payload = torch.load(state_file, map_location="cpu")
        if isinstance(payload, dict):
            return payload
    return None
931
+
932
+
933
def resolve_layer_idx(
    args: argparse.Namespace,
    model,
    layers: List[torch.nn.Module],
    dataloader,
    previous_scores,
    start_index: int,
    exclude_pairs: Set[int],
):
    """Pick the start index of the adjacent layer pair to fuse next.

    Three modes, chosen from ``args``:

    * ``args.layer`` is an integer string: use that pair index directly
      (negative values count from the last pair).
    * ``args.selection_method == "sequential"``: take the first non-excluded
      pair at or after ``start_index``.
    * otherwise: delegate to ``select_layer_auto``.

    Args:
        args: Parsed CLI namespace; ``layer`` defaults to "auto" and
            ``selection_method`` to "dwce" when absent.
        model: Model passed through to ``select_layer_auto``.
        layers: Current list of layers; there are ``len(layers) - 1`` pairs.
        dataloader: Calibration loader passed through to ``select_layer_auto``.
        previous_scores: Scores from an earlier call; returned untouched in
            the manual and sequential modes.
        start_index: First pair index considered by sequential selection.
        exclude_pairs: Pair indices that must not be selected.

    Returns:
        ``(layer_idx, scores, meta)`` where ``meta`` describes the selection.

    Raises:
        SystemExit: If ``--layer`` is not "auto"/an integer, is out of range,
            resolves to an excluded pair, or no eligible pair remains.
    """
    layer_arg = str(getattr(args, "layer", "auto")).strip().lower()
    selection_method = str(getattr(args, "selection_method", "dwce")).strip().lower()

    if layer_arg != "auto":
        try:
            layer_idx = int(layer_arg)
        except ValueError as exc:
            raise SystemExit("--layer must be 'auto' or an integer index") from exc
        num_pairs = max(len(layers) - 1, 0)
        if layer_idx < 0:
            layer_idx += num_pairs
        # Reject indices that do not name an existing pair (i, i+1). Without
        # this check an out-of-range value is returned as if it were valid,
        # and a still-negative one could silently index from the end.
        if not 0 <= layer_idx < num_pairs:
            raise SystemExit(
                f"--layer index {layer_arg} is out of range for "
                f"{num_pairs} adjacent layer pair(s)"
            )
        if layer_idx in exclude_pairs:
            raise SystemExit(f"--layer resolved to excluded pair index {layer_idx}")
        return layer_idx, previous_scores, {"method": "manual", "exclude_pairs": sorted(exclude_pairs)}

    if selection_method == "sequential":
        num_pairs = len(layers) - 1
        first_candidate = max(start_index, 0)
        for layer_idx in range(first_candidate, num_pairs):
            if layer_idx not in exclude_pairs:
                return layer_idx, previous_scores, {
                    "method": "sequential",
                    "start_index": first_candidate,
                    "exclude_pairs": sorted(exclude_pairs),
                }
        raise SystemExit("No eligible layer pairs remain after exclusions")

    layer_idx, dwce_scores, dwce_meta = select_layer_auto(
        model,
        layers,
        dataloader,
        args,
        previous_scores=previous_scores,
        start_index=start_index,
        exclude_pairs=exclude_pairs,
    )
    return layer_idx, dwce_scores, dwce_meta
978
+
979
+
980
@dataclass
class PreparedData:
    """Container for the loaders and bookkeeping returned by ``prepare_all_data``.

    Field order is part of the constructor signature; keep it stable.
    """

    # Calibration loader and the number of token sequences it was built from.
    calib_loader: torch.utils.data.DataLoader
    calib_num_sequences: int
    # Distillation loader, the generator driving its shuffling, and count
    # metadata; loader and generator are None when no distillation data was built.
    distill_loader: Optional[torch.utils.data.DataLoader]
    distill_generator: Optional[torch.Generator]
    distill_meta: Dict[str, object]
    # Same triple for the LoRA finetuning phase.
    lora_loader: Optional[torch.utils.data.DataLoader]
    lora_generator: Optional[torch.Generator]
    lora_meta: Dict[str, object]
    # Evaluation dataset names and their per-dataset configs.
    eval_datasets: List[str]
    eval_configs: List[Optional[str]]
    # Pre-built evaluation loaders, or None when they were not prepared.
    eval_dataloaders: Optional[Dict[str, torch.utils.data.DataLoader]]
993
+
994
+
995
def resolve_eval_datasets(args: argparse.Namespace) -> Tuple[List[str], List[Optional[str]]]:
    """Return the evaluation dataset names and their expanded configs.

    Falls back to ``["wikitext"]`` / ``["wikitext-2-raw-v1"]`` when the
    corresponding CLI values are empty, then lets ``ppl_eval`` expand the
    config list against the dataset list.
    """
    dataset_names = args.eval_dataset if args.eval_dataset else ["wikitext"]
    raw_configs = (
        args.eval_dataset_config
        if args.eval_dataset_config
        else ["wikitext-2-raw-v1"]
    )
    expanded_configs = ppl_eval._expand_dataset_configs(dataset_names, raw_configs)
    return dataset_names, expanded_configs
1000
+
1001
+
1002
def run_ppl_eval(
    model_id_or_path: str,
    eval_datasets: List[str],
    eval_configs: List[Optional[str]],
    args: argparse.Namespace,
    prepared_eval_dataloaders: Optional[Dict[str, torch.utils.data.DataLoader]] = None,
) -> Dict[str, float]:
    """Load a model from *model_id_or_path* and measure its perplexity.

    The model is loaded fresh, moved to ``args.eval_device`` (or
    ``args.device``), evaluated, and released again. When
    *prepared_eval_dataloaders* is given those loaders are used directly;
    otherwise a tokenizer is loaded from the same path and the datasets are
    built by ``ppl_eval``.

    Returns:
        The results produced by ``ppl_eval`` for the evaluated datasets.
    """
    target_device = args.eval_device or args.device
    eval_model = load_causal_lm(
        model_id_or_path,
        torch_dtype=get_dtype(args.dtype),
        trust_remote_code=args.trust_remote_code,
    )
    eval_model.to(target_device)

    if prepared_eval_dataloaders is None:
        batch_size = args.eval_batch_size or args.batch_size
        tok = AutoTokenizer.from_pretrained(
            model_id_or_path, trust_remote_code=args.trust_remote_code
        )
        # Fall back to EOS as the padding token when the tokenizer has none.
        if tok.pad_token is None and tok.eos_token is not None:
            tok.pad_token = tok.eos_token

        results = ppl_eval.evaluate_ppl_datasets(
            eval_model,
            tok,
            datasets=eval_datasets,
            configs=eval_configs,
            split=args.eval_split,
            text_field=args.eval_text_field,
            num_samples=args.eval_num_samples,
            seq_len=args.eval_seq_len,
            batch_size=batch_size,
            device=target_device,
            seed=args.seed,
            shuffle=False,
            model_family=args.eval_model_family,
            add_bos=args.eval_add_bos,
            max_batches=args.eval_max_batches,
            cache_dir=args.eval_cache_dir,
            num_workers=args.eval_num_workers,
        )
    else:
        results = ppl_eval.evaluate_ppl_dataloaders(
            eval_model,
            prepared_eval_dataloaders,
            target_device,
            max_batches=args.eval_max_batches,
        )

    # Drop the temporary model before returning so its memory can be reclaimed.
    del eval_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return results
1059
+
1060
+
1061
def build_calibration_dataloader(
    args: argparse.Namespace, tokenizer
) -> Tuple[List[str], List[torch.Tensor], torch.utils.data.DataLoader]:
    """Build the (unshuffled) calibration loader from datasets or raw text.

    When ``args.dataset`` is set, token chunks are built per dataset via
    ``build_chunks`` (seed offset by the dataset's position) and the returned
    text list is empty. Otherwise texts come from ``load_texts`` and are
    chunked with ``build_token_chunks``.

    Returns:
        ``(texts, chunks, dataloader)``; the loader yields
        ``(input_ids, attention_mask)`` with an all-ones mask.

    Raises:
        SystemExit: If no text is found or no token sequence could be built.
    """
    if args.dataset:
        texts: List[str] = []
        names = list(args.dataset)
        resolved_configs = expand_dataset_configs(names, list(args.dataset_config))
        chunks: List[torch.Tensor] = []
        for offset, (name, cfg) in enumerate(zip(names, resolved_configs)):
            chunks.extend(
                build_chunks(
                    SharedLMDataSpec(
                        dataset=name,
                        config=cfg,
                        split=args.dataset_split,
                        text_field=args.dataset_text_field,
                        seq_len=args.seq_len,
                        num_sequences=args.num_samples,
                        seed=args.seed + offset,
                    ),
                    tokenizer,
                )
            )
    else:
        texts = load_texts(args)
        if not texts:
            raise SystemExit(
                "No calibration text found. Provide --dataset, --text, or --text_file."
            )
        chunks = build_token_chunks(texts, tokenizer, args.seq_len, args.num_samples)

    if not chunks:
        raise SystemExit("Not enough text to build token sequences.")

    # Shared tail: stack chunks into one tensor and wrap it in a loader.
    token_ids = torch.stack(chunks)
    tensor_dataset = torch.utils.data.TensorDataset(
        token_ids, torch.ones_like(token_ids)
    )
    loader = torch.utils.data.DataLoader(
        tensor_dataset, batch_size=args.batch_size, shuffle=False
    )
    return texts, chunks, loader
1106
+
1107
+
1108
def prepare_distillation_data(
    args: argparse.Namespace, tokenizer, include_instruction: bool = True
) -> Tuple[Optional[torch.utils.data.DataLoader], Optional[torch.Generator], Dict[str, object]]:
    """Build the shuffled loader used for distillation / LoRA finetuning.

    Calibration data comes from one of two modes: a token budget
    (``args.target_tokens > 0`` with ``args.dataset`` set), which produces
    fixed-length token chunks, or raw texts loaded via
    ``load_texts_from_datasets``. Instruction records are appended when
    *include_instruction* is true.

    Returns:
        ``(loader, generator, meta)``. ``loader`` and ``generator`` are None
        when no data could be built; ``meta`` always carries the sequence counts.
    """
    # Only warns; instruction loading itself is attempted further down.
    if (
        include_instruction
        and args.distill_inst_samples != 0
        and not args.instruction_dataset
    ):
        print(
            "Warning: --distill_inst_samples > 0 but no --instruction_dataset "
            "provided; instruction distillation will be skipped."
        )

    calib_texts: List[str] = []
    calib_dataset = None
    if args.target_tokens > 0 and args.dataset:
        datasets = list(args.dataset)
        configs = expand_dataset_configs(datasets, list(args.dataset_config))
        # Split the token budget evenly; the first dataset absorbs the remainder.
        per_dataset = args.target_tokens // len(datasets)
        remainder = args.target_tokens % len(datasets)
        calib_chunks: List[torch.Tensor] = []
        for idx, (dataset_name, config) in enumerate(zip(datasets, configs)):
            dataset_tokens = per_dataset + (remainder if idx == 0 else 0)
            spec = SharedLMDataSpec(
                dataset=dataset_name,
                config=config,
                split=args.dataset_split,
                text_field=args.dataset_text_field,
                seq_len=args.distill_seq_len,
                target_tokens=dataset_tokens,
                seed=args.seed + 17 + idx,
            )
            calib_chunks.extend(build_chunks(spec, tokenizer))
        if calib_chunks:
            input_ids = torch.stack(calib_chunks)
            attention_mask = torch.ones_like(input_ids)
            calib_dataset = torch.utils.data.TensorDataset(input_ids, attention_mask)
    else:
        # NOTE(review): this branch also runs when args.dataset is empty;
        # list(args.dataset) would raise TypeError if it is None — confirm the
        # argparse default for --dataset.
        calib_texts = load_texts_from_datasets(
            datasets=list(args.dataset),
            configs=expand_dataset_configs(list(args.dataset), list(args.dataset_config)),
            split=args.dataset_split,
            text_field=args.dataset_text_field,
            num_samples=args.distill_calib_samples,
            seed=args.seed + 17,
        )
    inst_records = []
    if include_instruction:
        inst_records = load_instruction_records(args, args.distill_inst_samples)

    # Pre-tokenized chunks take precedence over raw calibration texts.
    distill_datasets = []
    if calib_dataset is not None:
        distill_datasets.append(calib_dataset)
    elif calib_texts:
        calib_records = [{"text": text} for text in calib_texts]
        distill_datasets.append(
            FixedSeqDataset(calib_records, tokenizer, args.distill_seq_len)
        )
    if inst_records:
        distill_datasets.append(
            FixedSeqDataset(inst_records, tokenizer, args.distill_seq_len)
        )

    distill_meta: Dict[str, object] = {
        "calib_texts": len(calib_texts),
        "calib_sequences": len(calib_dataset) if calib_dataset is not None else len(calib_texts),
        "inst_sequences": len(inst_records),
        "total_sequences": 0,
    }

    if not distill_datasets:
        return None, None, distill_meta
    if len(distill_datasets) == 1:
        distill_dataset = distill_datasets[0]
    else:
        distill_dataset = torch.utils.data.ConcatDataset(distill_datasets)

    distill_meta["total_sequences"] = len(distill_dataset)
    # Different seed offsets keep the instruction and non-instruction loaders
    # on distinct shuffle streams.
    distill_generator = build_generator(
        args.seed + 1000 + (1000000 if include_instruction else 0)
    )
    distill_loader = torch.utils.data.DataLoader(
        distill_dataset,
        batch_size=args.distill_batch_size,
        shuffle=True,
        generator=distill_generator,
    )
    return distill_loader, distill_generator, distill_meta
1196
+
1197
+
1198
def prepare_eval_dataloaders(
    args: argparse.Namespace,
    tokenizer,
    model: torch.nn.Module,
    eval_datasets: List[str],
    eval_configs: List[Optional[str]],
) -> Optional[Dict[str, torch.utils.data.DataLoader]]:
    """Pre-build perplexity evaluation loaders when any phase will need them.

    Loaders are needed when final evaluation is enabled, or when periodic
    evaluation is requested during distillation or LoRA finetuning.

    Returns:
        The loaders from ``ppl_eval.prepare_ppl_dataloaders``, or None when
        no evaluation will run.
    """
    if args.skip_eval:
        # Final eval is off; only periodic in-training eval can require loaders.
        distill_eval = (not args.skip_distill) and args.distill_eval_every
        if not distill_eval and not (args.lora_epochs > 0 and args.lora_eval_every):
            return None

    family = (
        ppl_eval._infer_model_family(model)
        if args.eval_model_family == "auto"
        else args.eval_model_family
    )

    return ppl_eval.prepare_ppl_dataloaders(
        tokenizer=tokenizer,
        datasets=eval_datasets,
        configs=eval_configs,
        split=args.eval_split,
        text_field=args.eval_text_field,
        num_samples=args.eval_num_samples,
        seq_len=args.eval_seq_len,
        batch_size=args.eval_batch_size or args.batch_size,
        seed=args.seed,
        shuffle=False,
        model_family=family,
        add_bos=args.eval_add_bos,
        cache_dir=args.eval_cache_dir,
        num_workers=args.eval_num_workers,
        model=model,
    )
1234
+
1235
+
1236
def prepare_all_data(
    args: argparse.Namespace,
    tokenizer,
    model: torch.nn.Module,
    eval_datasets: List[str],
    eval_configs: List[Optional[str]],
    loader_generator_state: Optional[Dict[str, object]] = None,
) -> PreparedData:
    """Build every loader the pipeline needs and bundle them in ``PreparedData``.

    The calibration loader is always built. The distillation loader is built
    when distillation is enabled or ``args.comm_enabled`` is set; the LoRA
    loader when ``args.lora_epochs > 0``. If *loader_generator_state* holds a
    saved state for a generator, that state is restored onto it.
    """

    def _empty_meta() -> Dict[str, object]:
        # Counts reported when a phase builds no data.
        return {
            "calib_texts": 0,
            "calib_sequences": 0,
            "inst_sequences": 0,
            "total_sequences": 0,
        }

    def _restore(generator: Optional[torch.Generator], key: str) -> None:
        # Re-apply a previously saved generator state, if one is available.
        if generator is None or loader_generator_state is None:
            return
        saved = loader_generator_state.get(key)
        if saved is not None:
            generator.set_state(saved)

    _, chunks, calib_loader = build_calibration_dataloader(args, tokenizer)
    calib_num_sequences = len(chunks)
    del chunks

    distill_loader, distill_generator, distill_meta = None, None, _empty_meta()
    lora_loader, lora_generator, lora_meta = None, None, _empty_meta()

    wants_distill_data = (not args.skip_distill) or bool(
        getattr(args, "comm_enabled", False)
    )
    if wants_distill_data:
        distill_loader, distill_generator, distill_meta = prepare_distillation_data(
            args, tokenizer, include_instruction=False
        )
        _restore(distill_generator, "distill_generator_state")

    if args.lora_epochs > 0:
        lora_loader, lora_generator, lora_meta = prepare_distillation_data(
            args, tokenizer, include_instruction=True
        )
        _restore(lora_generator, "lora_generator_state")

    eval_dataloaders = prepare_eval_dataloaders(
        args, tokenizer, model, eval_datasets, eval_configs
    )

    return PreparedData(
        calib_loader=calib_loader,
        calib_num_sequences=calib_num_sequences,
        distill_loader=distill_loader,
        distill_generator=distill_generator,
        distill_meta=distill_meta,
        lora_loader=lora_loader,
        lora_generator=lora_generator,
        lora_meta=lora_meta,
        eval_datasets=eval_datasets,
        eval_configs=eval_configs,
        eval_dataloaders=eval_dataloaders,
    )
1303
+
1304
+
1305
def evaluate_ppl_model(
    model: torch.nn.Module,
    tokenizer,
    eval_datasets: List[str],
    eval_configs: List[Optional[str]],
    args: argparse.Namespace,
    max_batches: Optional[int] = None,
    prepared_eval_dataloaders: Optional[Dict[str, torch.utils.data.DataLoader]] = None,
) -> Dict[str, float]:
    """Measure perplexity of an in-memory model, then restore its state.

    The model is switched to eval mode and, if needed, moved to
    ``args.eval_device`` (or ``args.device``). Afterwards its previous
    training mode and device are restored, even when evaluation raises, so a
    failed evaluation cannot leave the caller's model in eval mode or on the
    wrong device.

    Args:
        model: Model to evaluate in place.
        tokenizer: Used only when *prepared_eval_dataloaders* is None.
        eval_datasets: Dataset names for on-the-fly evaluation.
        eval_configs: Per-dataset configs for on-the-fly evaluation.
        args: Parsed CLI namespace providing the ``eval_*`` settings.
        max_batches: Per-call batch cap; falls back to ``args.eval_max_batches``.
        prepared_eval_dataloaders: Pre-built loaders to use instead of
            rebuilding the datasets.

    Returns:
        The results produced by ``ppl_eval`` for the evaluated datasets.
    """
    eval_device = args.eval_device or args.device
    was_training = model.training
    try:
        prev_device = next(model.parameters()).device
    except StopIteration:
        # Parameter-less model: treat it as already living on the eval device.
        prev_device = torch.device(eval_device)
    needs_move = str(prev_device) != eval_device

    model.eval()
    try:
        if needs_move:
            model.to(eval_device)

        batch_cap = max_batches if max_batches is not None else args.eval_max_batches
        if prepared_eval_dataloaders is not None:
            results = ppl_eval.evaluate_ppl_dataloaders(
                model,
                prepared_eval_dataloaders,
                eval_device,
                max_batches=batch_cap,
            )
        else:
            eval_batch_size = args.eval_batch_size or args.batch_size
            results = ppl_eval.evaluate_ppl_datasets(
                model,
                tokenizer,
                datasets=eval_datasets,
                configs=eval_configs,
                split=args.eval_split,
                text_field=args.eval_text_field,
                num_samples=args.eval_num_samples,
                seq_len=args.eval_seq_len,
                batch_size=eval_batch_size,
                device=eval_device,
                seed=args.seed,
                shuffle=False,
                model_family=args.eval_model_family,
                add_bos=args.eval_add_bos,
                max_batches=batch_cap,
                cache_dir=args.eval_cache_dir,
                num_workers=args.eval_num_workers,
            )
    finally:
        # Always hand the model back in the state the caller gave it to us.
        if was_training:
            model.train()
        if needs_move:
            model.to(prev_device)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return results
1362
+
1363
+
1364
def has_post_fusion_data(
    distill_loader: Optional[torch.utils.data.DataLoader],
    distill_meta: Optional[Dict[str, object]],
) -> bool:
    """Tell whether a post-fusion training loader actually holds sequences.

    True only when both the loader and its metadata exist and the metadata
    reports a positive ``total_sequences`` count.
    """
    if distill_loader is not None and distill_meta is not None:
        return distill_meta.get("total_sequences", 0) > 0
    return False
1371
+
1372
+
1373
def summarize_gate_lambdas(gates: Dict[str, torch.Tensor]) -> Dict[str, object]:
    """Summarize gate tensors as per-tensor means plus an element-weighted mean.

    Returns:
        For an empty mapping, only ``num_tensors`` and ``num_elements`` (both
        0). Otherwise also ``global_mean`` (None if every tensor is empty) and
        ``per_tensor_mean`` (None for an empty tensor).
    """
    if not gates:
        return {"num_tensors": 0, "num_elements": 0}

    running_sum = 0.0
    running_count = 0
    means: Dict[str, Optional[float]] = {}
    for key, tensor in gates.items():
        values = tensor.detach().float()
        if values.numel() != 0:
            means[key] = float(values.mean().item())
            running_sum += float(values.sum().item())
            running_count += int(values.numel())
        else:
            # An empty tensor has no mean and contributes nothing globally.
            means[key] = None

    if running_count == 0:
        overall = None
    else:
        overall = running_sum / float(running_count)

    return {
        "num_tensors": len(gates),
        "num_elements": running_count,
        "global_mean": overall,
        "per_tensor_mean": means,
    }
1396
+
1397
+
1398
def compute_path_bytes(path: str) -> int:
    """Return the size in bytes of a file, or of all files under a directory.

    For a directory the tree is walked recursively; symlinked files are
    ignored and files whose size cannot be read are skipped.
    """
    if os.path.isfile(path):
        return os.path.getsize(path)

    size = 0
    for dirpath, _dirnames, filenames in os.walk(path):
        for filename in filenames:
            candidate = os.path.join(dirpath, filename)
            if os.path.islink(candidate):
                continue
            try:
                candidate_size = os.path.getsize(candidate)
            except OSError:
                # Best effort: unreadable entries simply do not count.
                continue
            size += candidate_size
    return size
1413
+
1414
+
1415
def save_stage_checkpoint(
    model: torch.nn.Module,
    tokenizer,
    stage_dir: str,
    stage_name: str,
    ppl_results: Optional[Dict[str, float]],
) -> Dict[str, object]:
    """Save model + tokenizer to *stage_dir* and write ``stage_metrics.json``.

    Returns:
        The metadata dict that was written (stage name, path, on-disk size
        measured before the metrics file is written, and perplexity results).

    Raises:
        RuntimeError: If ``find_colon_modules`` reports any module names
            containing ':' before saving.
    """
    os.makedirs(stage_dir, exist_ok=True)

    offending = find_colon_modules(model)
    if offending:
        names = ", ".join(offending)
        raise RuntimeError(
            f"Unexpected module names with ':' detected before save: {names}."
        )

    model.save_pretrained(stage_dir)
    tokenizer.save_pretrained(stage_dir)

    stage_meta = {
        "stage": stage_name,
        "path": stage_dir,
        "weight_bytes": compute_path_bytes(stage_dir),
        "post_ppl": ppl_results,
    }
    metrics_path = os.path.join(stage_dir, "stage_metrics.json")
    with open(metrics_path, "w", encoding="utf-8") as stream:
        json.dump(stage_meta, stream, indent=2)
    return stage_meta
1446
+
1447
+
1448
def save_cycle_full_model(
    model: torch.nn.Module,
    tokenizer,
    cycle_dir: str,
    cycle_idx: int,
    args: argparse.Namespace,
    ppl_results: Optional[Dict[str, float]],
) -> Dict[str, object]:
    """Save a resumable full-model checkpoint under ``<cycle_dir>/full_model``.

    Besides the regular stage checkpoint, a ``resume_info.json`` is written
    that records the base model, cycle index, output directory, layer path
    and the file names used for the RNG and loader-generator state.

    Returns:
        The stage metadata, extended with a ``resume_info`` entry.
    """
    target_dir = os.path.join(cycle_dir, "full_model")
    stage_meta = save_stage_checkpoint(
        model=model,
        tokenizer=tokenizer,
        stage_dir=target_dir,
        stage_name=f"cycle_{cycle_idx}_full_model",
        ppl_results=ppl_results,
    )

    resume_meta = {
        "base_model": getattr(args, "base_model_id", args.model),
        "cycle": cycle_idx,
        "output_dir": args.output_dir,
        "layer_path": args.layer_path,
        "rng_state": "rng_state.pt",
        "loader_generators": "loader_generators.pt",
    }
    resume_path = os.path.join(target_dir, "resume_info.json")
    with open(resume_path, "w", encoding="utf-8") as stream:
        json.dump(resume_meta, stream, indent=2)

    stage_meta["resume_info"] = "resume_info.json"
    return stage_meta
1480
+
1481
+
1482
def run_lora_phase(
    model: torch.nn.Module,
    tokenizer,
    eval_datasets: List[str],
    eval_configs: List[Optional[str]],
    args: argparse.Namespace,
    lora_loader: Optional[torch.utils.data.DataLoader] = None,
    lora_meta: Optional[Dict[str, object]] = None,
    eval_dataloaders: Optional[Dict[str, torch.utils.data.DataLoader]] = None,
    cycle_idx: Optional[int] = None,
    num_cycles: Optional[int] = None,
) -> List[Dict[str, object]]:
    """Run LoRA cross-entropy finetuning if it is enabled and data exists.

    Returns:
        The evaluation history list handed to ``lora_ce_finetune``; it is
        empty when LoRA is disabled (``args.lora_epochs <= 0``) or no
        post-fusion sequences were built.
    """
    history: List[Dict[str, object]] = []
    if args.lora_epochs > 0:
        if has_post_fusion_data(lora_loader, lora_meta):
            lora_ce_finetune(
                model=model,
                dataloader=lora_loader,
                eval_tokenizer=tokenizer,
                eval_datasets=eval_datasets,
                eval_configs=eval_configs,
                eval_history=history,
                args=args,
                eval_dataloaders=eval_dataloaders,
                progressive_cycle=cycle_idx,
                progressive_total=num_cycles,
            )
        else:
            print("No post-fusion sequences built; skipping LoRA finetuning.")
    return history
1514
+
1515
+
1516
def run_progressive(
    args: argparse.Namespace,
    model: torch.nn.Module,
    tokenizer,
    prepared: PreparedData,
) -> None:
    """Run the progressive layer-fusion loop and write all of its artifacts.

    For every cycle from ``args.resume_from_cycle + 1`` through
    ``args.num_progressive`` the loop selects an adjacent layer pair,
    optionally runs commutator preconditioning, merges the pair (reparam
    distillation when distillation data is available, otherwise a Fisher
    merge), drops the second layer of the pair, evaluates perplexity unless
    ``args.skip_eval`` is set, and writes ``cycle_<n>/fused_layer.pt`` and
    ``cycle_<n>/cycle_metadata.json`` under ``args.output_dir``.

    After the loop it saves a pre-LoRA checkpoint next to the output
    directory, optionally runs one final LoRA finetune, saves the final
    checkpoint into ``args.output_dir`` and rewrites
    ``progressive_metadata.json``.

    NOTE(review): this block was recovered from a flattened diff rendering in
    which leading indentation was lost. Statement nesting was inferred from
    the logic; places where more than one nesting is plausible are marked
    inline and should be confirmed against the original file.

    Args:
        args: Parsed CLI namespace; read throughout, never reassigned here.
        model: Model being compressed. It is moved to ``args.device`` and its
            layer container and config are modified in place.
        tokenizer: Tokenizer forwarded to evaluation and checkpoint saving.
        prepared: Pre-built loaders and metadata. This function reads its
            ``eval_*``, ``calib_*``, ``distill_*`` and ``lora_*`` attributes.

    Raises:
        SystemExit: If ``--num_progressive`` exceeds the available pairs, the
            selected layer index is out of range, the model config has no
            ``hidden_size``/``n_embd``, or reparam distillation cannot proceed.
        RuntimeError: If comm preconditioning is enabled but no teacher model
            was returned.
    """
    eval_datasets = prepared.eval_datasets
    eval_configs = prepared.eval_configs

    dataloader = prepared.calib_loader
    num_sequences = prepared.calib_num_sequences

    model.to(args.device)

    os.makedirs(args.output_dir, exist_ok=True)
    progressive_meta_path = os.path.join(args.output_dir, "progressive_metadata.json")
    # When resuming, reuse the metadata of the earlier run (if it is a dict)
    # so already-completed cycles are carried over.
    existing_meta: Dict[str, object] = {}
    if args.resume_from_cycle > 0 and os.path.exists(progressive_meta_path):
        with open(progressive_meta_path, "r", encoding="utf-8") as handle:
            loaded_meta = json.load(handle)
        if isinstance(loaded_meta, dict):
            existing_meta = loaded_meta

    # Write a bootstrap metadata file up front; it is overwritten with the
    # full summary at the end of this function.
    bootstrap_meta = {
        "base_model": getattr(args, "base_model_id", args.model),
        "num_progressive": args.num_progressive,
        "layer_path": args.layer_path,
        "resume_from_cycle": args.resume_from_cycle,
        "save_full_model_cycles": sorted(args.full_model_save_cycles),
        "cycles": (
            existing_meta.get("cycles", [])
            if isinstance(existing_meta.get("cycles"), list)
            else []
        ),
    }
    with open(
        progressive_meta_path,
        "w",
        encoding="utf-8",
    ) as handle:
        json.dump(bootstrap_meta, handle, indent=2)

    # Baseline perplexity of the model as loaded, before any fusion this run.
    pre_eval = None
    if not args.skip_eval:
        pre_eval = evaluate_ppl_model(
            model,
            tokenizer,
            eval_datasets,
            eval_configs,
            args,
            prepared_eval_dataloaders=prepared.eval_dataloaders,
        )
        print("Pre-pruning perplexity:")
        for dataset_name, ppl in pre_eval.items():
            print(f"{dataset_name}: {ppl:.4f}")

    parent, name, container = find_layer_container(model, args.layer_path)
    layers = list(container)
    # A model with N layers has N - 1 adjacent pairs; cycles completed before
    # the resume point are added back because they already removed layers.
    if args.num_progressive > (len(layers) - 1 + args.resume_from_cycle):
        raise SystemExit(
            f"--num_progressive ({args.num_progressive}) exceeds available pairs "
            f"after resume offset ({len(layers) - 1 + args.resume_from_cycle})"
        )

    dwce_scores = None
    dwce_meta = None
    last_fused_idx = 0
    # Seed the per-cycle summaries with entries from the earlier run that are
    # at or before the resume cycle.
    cycle_summaries: List[Dict[str, object]] = []
    existing_cycles = existing_meta.get("cycles", [])
    if isinstance(existing_cycles, list):
        for entry in existing_cycles:
            if not isinstance(entry, dict):
                continue
            cycle_value = entry.get("cycle")
            if isinstance(cycle_value, int) and cycle_value <= args.resume_from_cycle:
                cycle_summaries.append(entry)

    # Teacher bookkeeping shared by the closures below: one slot for the comm
    # (preconditioning) teacher and one for an in-memory copy of the model as
    # it stood at the end of the previous cycle.
    comm_enabled = bool(getattr(args, "comm_enabled", False))
    comm_teacher_model = None
    comm_teacher_cycle: Optional[int] = None
    teacher_device = args.distill_teacher_device or args.device
    previous_cycle_teacher_model = None
    previous_cycle_teacher_cycle: Optional[int] = None

    def _release_comm_teacher() -> None:
        # Drop this closure's reference to the comm teacher and clear the CUDA cache.
        # NOTE(review): nesting of the resets under the `if` is inferred; confirm.
        nonlocal comm_teacher_model, comm_teacher_cycle
        if comm_teacher_model is not None:
            del comm_teacher_model
            comm_teacher_model = None
        comm_teacher_cycle = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _release_previous_cycle_teacher() -> None:
        # Drop the cached previous-cycle snapshot and clear the CUDA cache.
        # NOTE(review): nesting of the resets under the `if` is inferred; confirm.
        nonlocal previous_cycle_teacher_model, previous_cycle_teacher_cycle
        if previous_cycle_teacher_model is not None:
            del previous_cycle_teacher_model
            previous_cycle_teacher_model = None
        previous_cycle_teacher_cycle = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _snapshot_previous_cycle_teacher(cycle_idx: int) -> None:
        # Replace the cached snapshot with a frozen deep copy of the current
        # model, placed on the teacher device and tagged with `cycle_idx`.
        nonlocal previous_cycle_teacher_model, previous_cycle_teacher_cycle
        _release_previous_cycle_teacher()
        previous_cycle_teacher_model = copy.deepcopy(model)
        previous_cycle_teacher_model.to(teacher_device)
        previous_cycle_teacher_model.eval()
        previous_cycle_teacher_cycle = cycle_idx

    def _get_previous_cycle_teacher(
        cycle_idx: int,
    ) -> Tuple[Optional[torch.nn.Module], str, Optional[int]]:
        # Return (teacher, source label, teacher cycle) for distillation.
        # A `None` teacher with source "base_model" tells the caller to load
        # the base model itself.
        prev_cycle = cycle_idx - 1
        if prev_cycle <= 0:
            return None, "base_model", 0
        if (
            previous_cycle_teacher_model is not None
            and previous_cycle_teacher_cycle == prev_cycle
        ):
            return previous_cycle_teacher_model, "previous_cycle_memory", prev_cycle
        # No matching in-memory snapshot: rebuild the previous cycle's model
        # from the artifacts under args.output_dir.
        teacher_model = load_progressive_model(
            getattr(args, "base_model_id", args.model),
            args.output_dir,
            cycle=prev_cycle,
            device=teacher_device,
            dtype=args.dtype,
            trust_remote_code=args.trust_remote_code,
            layer_path=args.layer_path,
        )
        teacher_model.eval()
        return teacher_model, "previous_cycle_disk", prev_cycle

    def _get_comm_teacher(cycle_idx: int) -> Tuple[Optional[torch.nn.Module], str, Optional[int]]:
        # Return (teacher, source label, teacher cycle) for comm
        # preconditioning, reusing the cached comm teacher when it already
        # matches the requested source and cycle.
        nonlocal comm_teacher_model, comm_teacher_cycle
        if not comm_enabled:
            return None, "disabled", None

        source = str(getattr(args, "redistrib_teacher_source", "base_model"))
        if source == "base_model":
            # Fixed base teacher: loaded once and kept for all cycles.
            # NOTE(review): unlike the load_causal_lm calls elsewhere, this load
            # does not pass cache_dir=args.model_cache_dir — confirm intended.
            if comm_teacher_model is None:
                print(
                    "[comm] Loading fixed base teacher for anchor loss "
                    f"(device={teacher_device})."
                )
                comm_teacher_model = AutoModelForCausalLM.from_pretrained(
                    getattr(args, "base_model_id", args.model),
                    torch_dtype=get_dtype(args.dtype),
                    trust_remote_code=args.trust_remote_code,
                )
                comm_teacher_model.to(teacher_device)
                comm_teacher_model.eval()
                comm_teacher_cycle = 0
            return comm_teacher_model, "base_model", 0

        # Any other source value is handled as "previous cycle".
        prev_cycle = cycle_idx - 1
        if prev_cycle <= 0:
            # Cycle 1 has no previous checkpoint, so fall back to the base model.
            if comm_teacher_model is None or comm_teacher_cycle != 0:
                _release_comm_teacher()
                print(
                    "[comm] --redistrib_teacher_source=previous_cycle but cycle 1 "
                    "has no prior checkpoint; using base teacher."
                )
                comm_teacher_model = AutoModelForCausalLM.from_pretrained(
                    getattr(args, "base_model_id", args.model),
                    torch_dtype=get_dtype(args.dtype),
                    trust_remote_code=args.trust_remote_code,
                )
                comm_teacher_model.to(teacher_device)
                comm_teacher_model.eval()
                comm_teacher_cycle = 0
            return comm_teacher_model, "base_model", 0

        if (
            previous_cycle_teacher_model is not None
            and previous_cycle_teacher_cycle == prev_cycle
        ):
            # Alias the in-memory snapshot instead of loading a second copy.
            if comm_teacher_model is not previous_cycle_teacher_model:
                _release_comm_teacher()
                comm_teacher_model = previous_cycle_teacher_model
                comm_teacher_cycle = prev_cycle
            return comm_teacher_model, "previous_cycle_memory", prev_cycle

        if comm_teacher_model is None or comm_teacher_cycle != prev_cycle:
            _release_comm_teacher()
            print(
                "[comm] Loading teacher from previous cycle "
                f"{prev_cycle} (device={teacher_device})."
            )
            comm_teacher_model = load_progressive_model(
                getattr(args, "base_model_id", args.model),
                args.output_dir,
                cycle=prev_cycle,
                device=teacher_device,
                dtype=args.dtype,
                trust_remote_code=args.trust_remote_code,
                layer_path=args.layer_path,
            )
            comm_teacher_model.eval()
            comm_teacher_cycle = prev_cycle
        return comm_teacher_model, "previous_cycle_disk", prev_cycle

    # On resume the loaded model already is the end-of-cycle model, so it is
    # snapshotted as the teacher for the first cycle of this run.
    if args.resume_from_cycle > 0:
        _snapshot_previous_cycle_teacher(args.resume_from_cycle)

    start_cycle = args.resume_from_cycle + 1
    for cycle_idx in range(start_cycle, args.num_progressive + 1):
        print(f"[progressive] Cycle {cycle_idx}/{args.num_progressive}")
        # Comm preconditioning skips cycle 1 unless comm_include_cycle1 is set.
        run_comm = comm_enabled and (
            cycle_idx > 1 or bool(getattr(args, "comm_include_cycle1", False))
        )
        comm_stats: Dict[str, object] = {"enabled": False}
        comm_post_eval = None
        if run_comm:
            # Preconditioning updates model weights, so DWCE reuse is unreliable.
            start_index = 0
            reuse_scores = None
        else:
            start_index = last_fused_idx if cycle_idx > 1 else 0
            reuse_scores = dwce_scores
        exclude_pairs = set(parse_exclude_pairs(args.exclude_pairs, max(len(layers) - 1, 0)))
        # Pick the pair to fuse. resolve_layer_idx is defined elsewhere;
        # presumably it scores pairs with DWCE — confirm against its definition.
        layer_idx, dwce_scores, dwce_meta = resolve_layer_idx(
            args,
            model,
            layers,
            dataloader,
            reuse_scores,
            start_index,
            exclude_pairs,
        )

        if run_comm:
            dwce_scores_pre_comm = dwce_scores
            if prepared.calib_loader is None:
                print(
                    "[comm] Enabled but no calibration sequences were built; skipping."
                )
            else:
                (
                    comm_teacher_model_loaded,
                    comm_teacher_source,
                    comm_teacher_cycle_idx,
                ) = _get_comm_teacher(cycle_idx)
                if comm_teacher_model_loaded is None:
                    raise RuntimeError("comm_enabled but teacher model was not loaded.")
                comm_stats = commutator_precondition(
                    student_model=model,
                    student_layers=layers,
                    teacher_model=comm_teacher_model_loaded,
                    dataloader=prepared.calib_loader,
                    dwce_scores=dwce_scores_pre_comm,
                    exclude_pairs=exclude_pairs,
                    args=args,
                    progressive_cycle=cycle_idx,
                    progressive_total=args.num_progressive,
                )
                # NOTE(review): everything down to the post-comm reselection is
                # placed under this `if`; the flattened source does not show
                # whether the reselection block sits one level further out.
                if comm_stats.get("enabled"):
                    comm_stats["teacher_source"] = comm_teacher_source
                    comm_stats["teacher_cycle"] = comm_teacher_cycle_idx
                    comm_stats["dwce_scores_pre"] = dwce_scores_pre_comm
                    print(
                        "[comm] Done:"
                        f" opt_steps={comm_stats.get('opt_steps')}"
                        f" lr={comm_stats.get('lr')}"
                    )
                    if not args.skip_eval:
                        comm_post_eval = evaluate_ppl_model(
                            model,
                            tokenizer,
                            eval_datasets,
                            eval_configs,
                            args,
                            prepared_eval_dataloaders=prepared.eval_dataloaders,
                        )
                        comm_stats["post_ppl"] = comm_post_eval
                        print(f"[progressive] Cycle {cycle_idx} post-comm perplexity:")
                        for dataset_name, ppl in comm_post_eval.items():
                            print(f"{dataset_name}: {ppl:.4f}")

                    if bool(getattr(args, "comm_skip_post_reselect", False)):
                        comm_stats["post_selection_recomputed"] = False
                        comm_stats["selected_layer_post"] = int(layer_idx)
                        print(
                            "[comm] Keeping pre-comm DWCE pair selection for fusion."
                        )
                    else:
                        print(
                            "[comm] Recomputing DWCE after preconditioning for fusion selection."
                        )
                        # Fresh selection: no score reuse and start index 0.
                        layer_idx, dwce_scores, dwce_meta = resolve_layer_idx(
                            args,
                            model,
                            layers,
                            dataloader,
                            None,
                            0,
                            exclude_pairs,
                        )
                        comm_stats["post_selection_recomputed"] = True
                        comm_stats["selected_layer_post"] = int(layer_idx)

        # layer_idx + 1 must also be a valid layer, hence the `- 1`.
        if layer_idx < 0 or layer_idx >= len(layers) - 1:
            raise SystemExit("--layer must be in [0, num_layers-2]")

        num_layers_before = len(layers)
        layer_a = layers[layer_idx]
        layer_b = layers[layer_idx + 1]

        # Keep copies of layer_a's norm states before merging; they are handed
        # to apply_norm_policy after the merge.
        norm1_state = None
        norm2_state = None
        norm1, norm2, norm_names = get_norm_pair(layer_a)
        if norm1 is not None:
            norm1_state = clone_state_dict(norm1)
        if norm2 is not None:
            norm2_state = clone_state_dict(norm2)

        attn_a = find_attention_module(layer_a)
        attn_b = find_attention_module(layer_b)
        # Fall back to `n_embd` for configs that do not define `hidden_size`.
        hidden_size = getattr(model.config, "hidden_size", None)
        if hidden_size is None:
            hidden_size = getattr(model.config, "n_embd", None)
        if hidden_size is None:
            raise SystemExit("Model config missing hidden_size/n_embd")

        # Either flag name disables head permutation.
        no_head_permute_merge = bool(
            getattr(args, "no_head_permute_merge", False)
            or getattr(args, "no_head_permute", False)
        )
        if no_head_permute_merge:
            print("[fuse] Head permutation disabled; merging with original head order.")
        else:
            mean_a, mean_b, num_heads, num_kv_heads, head_dim = compute_head_means(
                model,
                attn_a,
                attn_b,
                dataloader,
                args.device,
                hidden_size,
            )

            perm = build_head_permutation(
                mean_a,
                mean_b,
                num_heads=num_heads,
                num_kv_heads=num_kv_heads,
                eps=args.eps,
            )
            # Only layer_b's attention module is permuted.
            permute_attention_heads(
                attn_b, perm, num_heads, num_kv_heads, head_dim=head_dim
            )

        fisher_sums, num_batches, param_numels = compute_fisher(
            model,
            layer_a,
            layer_b,
            dataloader,
            fisher_mode=args.fisher_mode,
            device=args.device,
        )

        distill_ready = has_post_fusion_data(
            prepared.distill_loader, prepared.distill_meta
        )
        # Defaults recorded in the cycle metadata; overwritten below when a
        # teacher is actually obtained.
        teacher_cycle = cycle_idx - 1
        teacher_source = "previous_cycle" if teacher_cycle > 0 else "base_model"

        merge_method = "fisher"
        distill_method = str(getattr(args, "distill_method", "reparam"))
        reparam_stats: Optional[Dict[str, object]] = None
        reparam_gate_summary: Optional[Dict[str, object]] = None

        # A teacher is only fetched when distillation will train for > 0 epochs.
        needs_teacher_for_reparam = (
            (not args.skip_distill)
            and distill_ready
            and float(args.distill_epochs) > 0.0
        )
        teacher_model = None
        teacher_parent = None
        teacher_layer_attr = None
        teacher_layers: Optional[List[torch.nn.Module]] = None
        teacher_from_cache = False
        if needs_teacher_for_reparam:
            teacher_model, teacher_source, teacher_cycle = _get_previous_cycle_teacher(
                cycle_idx
            )
            # A cached snapshot is shared state and must not be deleted below.
            teacher_from_cache = (
                teacher_source == "previous_cycle_memory"
                and teacher_model is previous_cycle_teacher_model
            )
            if teacher_model is None:
                teacher_model = load_causal_lm(
                    getattr(args, "base_model_id", args.model),
                    torch_dtype=get_dtype(args.dtype),
                    trust_remote_code=args.trust_remote_code,
                    cache_dir=args.model_cache_dir,
                )
                teacher_model.to(teacher_device)
                teacher_model.eval()
                teacher_source = "base_model"
                teacher_cycle = 0
            teacher_parent, teacher_layer_attr, teacher_container = find_layer_container(
                teacher_model, args.layer_path
            )
            teacher_layers = list(teacher_container)

        do_reparam = (
            (not args.skip_distill)
            and distill_ready
            and prepared.distill_loader is not None
        )
        if (not args.skip_distill) and not do_reparam:
            print("[reparam] No distillation sequences built; skipping reparam distill.")
        distill_post_eval = None
        if do_reparam:
            lambda_source = "fisher_prior"
            reparam_gate_targets: Dict[str, object] = compute_fisher_gate_priors(
                layer_a=layer_a,
                layer_b=layer_b,
                fisher_a=fisher_sums[0],
                fisher_b=fisher_sums[1],
                num_batches=num_batches,
                numels_a=param_numels[0],
                numels_b=param_numels[1],
                fisher_mode=args.fisher_mode,
                eps=float(args.eps),
            )
            if not reparam_gate_targets:
                raise SystemExit("[reparam] No mergeable parameters found; cannot continue.")
            if float(args.distill_epochs) > 0.0 and (
                teacher_model is None or teacher_layers is None
            ):
                raise SystemExit("--distill_method reparam requires a teacher model.")
            print(
                f"[reparam] Cycle {cycle_idx}: training U + gates for pair "
                f"{layer_idx}-{layer_idx + 1} (epochs={args.distill_epochs}, "
                f"hidden_mse_w={args.distill_hidden_mse_weight}, "
                f"attn_mse_w={args.distill_attn_mse_weight}, "
                f"mlp_mse_w={args.distill_mlp_mse_weight}, "
                f"eta={args.reparam_eta}, gamma={args.reparam_gamma}, "
                f"attn_reg_scale={args.reparam_attn_reg_scale}, "
                f"mlp_reg_scale={args.reparam_mlp_reg_scale}, "
                f"param_subset={args.reparam_param_subset}, "
                f"lambda_init={lambda_source})."
            )
            merged, final_gates, reparam_stats = distill_reparam_merge(
                student_model=model,
                student_parent=parent,
                student_layer_attr=name,
                student_layers=layers,
                teacher_model=teacher_model,
                teacher_parent=teacher_parent,
                teacher_layer_attr=teacher_layer_attr,
                teacher_layers=teacher_layers,
                layer_idx=layer_idx,
                gate_lambdas=reparam_gate_targets,
                dataloader=prepared.distill_loader,
                args=args,
                progressive_cycle=cycle_idx,
                progressive_total=args.num_progressive,
            )
            reparam_gate_summary = summarize_gate_lambdas(final_gates)
            merge_method = "reparam"
            if reparam_stats is not None:
                reparam_stats["lambda_init"] = lambda_source
        else:
            # No distillation: plain Fisher merge of the pair.
            merged = merge_layers(
                layer_a,
                layer_b,
                fisher_sums[0],
                fisher_sums[1],
                num_batches,
                param_numels[0],
                param_numels[1],
                fisher_mode=args.fisher_mode,
                eps=args.eps,
            )

        apply_norm_policy(
            layer_a,
            args.norm_policy,
            norm1_state,
            norm2_state,
            norm_names,
        )
        # Free a teacher that was loaded just for this cycle.
        # NOTE(review): which of the resets below sit under this `if` is
        # inferred; the choice only affects when references are dropped.
        if teacher_model is not None and not teacher_from_cache:
            del teacher_model
        teacher_model = None
        teacher_parent = None
        teacher_layer_attr = None
        teacher_layers = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Remove the second layer of the pair and install the new container.
        new_container = drop_layer(container, layer_idx + 1)
        setattr(parent, name, new_container)
        decrement_config(model.config)
        layers = list(new_container)

        lora_post_eval = None
        if (not args.skip_eval) and (not args.skip_distill) and do_reparam:
            distill_post_eval = evaluate_ppl_model(
                model,
                tokenizer,
                eval_datasets,
                eval_configs,
                args,
                prepared_eval_dataloaders=prepared.eval_dataloaders,
            )
            print(f"[progressive] Cycle {cycle_idx} post-distill perplexity:")
            for dataset_name, ppl in distill_post_eval.items():
                print(f"{dataset_name}: {ppl:.4f}")

        post_eval = None
        if not args.skip_eval:
            # Reuse the post-distill numbers instead of evaluating twice.
            if distill_post_eval is not None:
                post_eval = distill_post_eval
            else:
                post_eval = evaluate_ppl_model(
                    model,
                    tokenizer,
                    eval_datasets,
                    eval_configs,
                    args,
                    prepared_eval_dataloaders=prepared.eval_dataloaders,
                )
                # NOTE(review): the print below may belong one level out (run
                # for both branches); nesting was not recoverable — confirm.
                print(f"[progressive] Cycle {cycle_idx} perplexity:")
                for dataset_name, ppl in post_eval.items():
                    print(f"{dataset_name}: {ppl:.4f}")

        # Save the fused layer's weights; after the drop it sits at layer_idx.
        cycle_dir = os.path.join(args.output_dir, f"cycle_{cycle_idx}")
        os.makedirs(cycle_dir, exist_ok=True)
        fused_layer_file = "fused_layer.pt"
        fused_layer_path = os.path.join(cycle_dir, fused_layer_file)
        torch.save(layers[layer_idx].state_dict(), fused_layer_path)

        cycle_meta: Dict[str, object] = {
            "cycle": cycle_idx,
            "layer_merged": layer_idx,
            "num_layers_before": num_layers_before,
            "num_layers_after": num_layers_before - 1,
            "fused_layer_state": fused_layer_file,
            "dwce_score": dwce_scores[layer_idx] if dwce_scores else None,
            "dwce_scores": dwce_scores,
            "dwce_meta": dwce_meta,
            "fisher_num_batches": num_batches,
            "merge_method": merge_method,
            "merged_params": merged,
            "num_sequences": num_sequences,
            "teacher_source": teacher_source,
            "teacher_cycle": teacher_cycle,
            "eval": {
                "datasets": eval_datasets,
                "configs": eval_configs,
                "split": args.eval_split,
                "num_samples": args.eval_num_samples,
                "seq_len": args.eval_seq_len,
                "post_ppl": post_eval,
            },
            "comm": comm_stats,
            "distill": {
                "enabled": not args.skip_distill,
                "method": distill_method,
                "calib_samples": args.distill_calib_samples,
                "inst_samples": args.distill_inst_samples,
                "seq_len": args.distill_seq_len,
                "batch_size": args.distill_batch_size,
                "epochs": args.distill_epochs,
                "lr": args.distill_lr,
                "kl_weight": args.distill_kl_weight,
                "kl_temp": args.distill_kl_temp,
                "hidden_mse_weight": args.distill_hidden_mse_weight,
                "attn_mse_weight": args.distill_attn_mse_weight,
                "mlp_mse_weight": args.distill_mlp_mse_weight,
                "reparam_eta": args.reparam_eta,
                "reparam_gamma": args.reparam_gamma,
                "reparam_attn_reg_scale": args.reparam_attn_reg_scale,
                "reparam_mlp_reg_scale": args.reparam_mlp_reg_scale,
                "reparam_param_subset": args.reparam_param_subset,
                "reparam_stats": reparam_stats,
                "reparam_gate_summary": reparam_gate_summary,
                "post_ppl": distill_post_eval,
                "weight_decay": args.distill_weight_decay,
                "max_grad_norm": args.distill_max_grad_norm,
                "grad_accum_steps": args.distill_grad_accum_steps,
                "instruction_dataset": args.instruction_dataset,
                "instruction_config": args.instruction_config,
                "instruction_split": args.instruction_split,
            },
            # LoRA runs once after all cycles, so "post_ppl" is None here; the
            # final cycle's file is patched after the LoRA phase below.
            "lora": {
                "enabled": args.lora_epochs > 0,
                "seq_len": args.distill_seq_len,
                "batch_size": args.distill_batch_size,
                "epochs": args.lora_epochs,
                "rank": args.lora_rank,
                "alpha": args.lora_alpha,
                "dropout": args.lora_dropout,
                "target_modules": args.lora_target_modules,
                "respect_exclude_pairs": args.lora_respect_exclude_pairs,
                "kl_enabled": args.lora_kl_enabled,
                "kl_weight": args.lora_kl_weight,
                "kl_temp": args.lora_kl_temp,
                "post_ppl": lora_post_eval,
                "lr": args.lora_lr,
                "weight_decay": args.lora_weight_decay,
                "max_grad_norm": args.lora_max_grad_norm,
                "grad_accum_steps": args.lora_grad_accum_steps,
                "log_steps": args.lora_log_steps,
                "eval_every": args.lora_eval_every,
                "eval_max_batches": args.lora_eval_max_batches,
            },
            "norm_policy": args.norm_policy,
        }

        # Optionally save the whole model for this cycle.
        saved_full_model_dir = None
        if cycle_idx in args.full_model_save_cycles:
            cycle_meta["full_model_saved"] = True
            cycle_meta["full_model"] = save_cycle_full_model(
                model=model,
                tokenizer=tokenizer,
                cycle_dir=cycle_dir,
                cycle_idx=cycle_idx,
                args=args,
                ppl_results=post_eval,
            )
            saved_full_model_dir = os.path.join(cycle_dir, "full_model")
        else:
            cycle_meta["full_model_saved"] = False

        with open(
            os.path.join(cycle_dir, "cycle_metadata.json"),
            "w",
            encoding="utf-8",
        ) as handle:
            json.dump(cycle_meta, handle, indent=2)

        cycle_summaries.append(
            {
                "cycle": cycle_idx,
                "layer_merged": layer_idx,
                "dwce_score": dwce_scores[layer_idx] if dwce_scores else None,
                "comm_post_ppl": comm_post_eval,
                "distill_post_ppl": distill_post_eval,
                "lora_post_ppl": lora_post_eval,
                "post_ppl": post_eval,
                "cycle_dir": f"cycle_{cycle_idx}",
            }
        )

        last_fused_idx = layer_idx
        # The fused model becomes the teacher for the next cycle.
        _snapshot_previous_cycle_teacher(cycle_idx)

        # Re-resolve the container and trim the score list to the new number
        # of adjacent pairs (one fewer than the number of layers).
        parent, name, container = find_layer_container(model, args.layer_path)
        layers = list(container)
        if dwce_scores:
            dwce_scores = dwce_scores[: max(len(layers) - 1, 0)]

        # Encourage allocator to release cached blocks between cycles.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        gc.collect()

        # RNG and loader-generator state are stored next to a saved full model.
        if saved_full_model_dir is not None:
            save_rng_state(os.path.join(saved_full_model_dir, "rng_state.pt"))
            save_loader_generator_state(
                saved_full_model_dir,
                distill_generator=prepared.distill_generator,
                lora_generator=prepared.lora_generator,
            )

    # All cycles are done; the teachers are no longer needed.
    _release_comm_teacher()
    _release_previous_cycle_teacher()

    final_pre_lora_eval = cycle_summaries[-1]["post_ppl"] if cycle_summaries else None
    # The pre-LoRA checkpoint goes to a sibling directory of output_dir.
    final_pre_lora_dir = f"{os.path.abspath(args.output_dir.rstrip(os.sep))}_final_pre_lora_hf"
    final_pre_lora_meta = save_stage_checkpoint(
        model=model,
        tokenizer=tokenizer,
        stage_dir=final_pre_lora_dir,
        stage_name="final_pre_lora",
        ppl_results=final_pre_lora_eval,
    )

    # Optional final LoRA finetune after all pruning cycles.
    lora_eval_history: List[Dict[str, object]] = []
    lora_post_eval = None
    lora_ready = has_post_fusion_data(prepared.lora_loader, prepared.lora_meta)
    if args.lora_epochs > 0:
        if not lora_ready:
            print("No post-fusion sequences built; skipping LoRA finetuning.")
        else:
            print(
                f"[progressive] Running final LoRA finetuning (epochs={args.lora_epochs})."
            )
            lora_eval_history = run_lora_phase(
                model=model,
                tokenizer=tokenizer,
                eval_datasets=eval_datasets,
                eval_configs=eval_configs,
                args=args,
                lora_loader=prepared.lora_loader,
                lora_meta=prepared.lora_meta,
                eval_dataloaders=prepared.eval_dataloaders,
                cycle_idx=args.num_progressive,
                num_cycles=args.num_progressive,
            )

            if not args.skip_eval:
                lora_post_eval = evaluate_ppl_model(
                    model,
                    tokenizer,
                    eval_datasets,
                    eval_configs,
                    args,
                    prepared_eval_dataloaders=prepared.eval_dataloaders,
                )
                print("[progressive] Post-LoRA perplexity:")
                for dataset_name, ppl in lora_post_eval.items():
                    print(f"{dataset_name}: {ppl:.4f}")

            # Update final cycle metadata and summary with the post-LoRA PPL.
            if cycle_summaries:
                cycle_summaries[-1]["lora_post_ppl"] = lora_post_eval
                if lora_post_eval is not None:
                    cycle_summaries[-1]["post_ppl"] = lora_post_eval

            final_cycle_dir = os.path.join(
                args.output_dir, f"cycle_{args.num_progressive}"
            )
            final_cycle_meta_path = os.path.join(final_cycle_dir, "cycle_metadata.json")
            if os.path.exists(final_cycle_meta_path):
                with open(final_cycle_meta_path, "r", encoding="utf-8") as handle:
                    final_cycle_meta = json.load(handle)

                # Make sure there is a dict to write the LoRA results into.
                lora_meta_entry = final_cycle_meta.get("lora")
                if not isinstance(lora_meta_entry, dict):
                    lora_meta_entry = {}
                    final_cycle_meta["lora"] = lora_meta_entry
                lora_meta_entry["ran"] = True
                lora_meta_entry["post_ppl"] = lora_post_eval
                if lora_post_eval is not None and isinstance(
                    final_cycle_meta.get("eval"), dict
                ):
                    final_cycle_meta["eval"]["post_ppl"] = lora_post_eval

                if lora_eval_history:
                    lora_path = os.path.join(final_cycle_dir, "ppl_over_lora.json")
                    with open(lora_path, "w", encoding="utf-8") as handle:
                        json.dump(lora_eval_history, handle, indent=2)
                    lora_meta_entry["ppl_over_lora"] = "ppl_over_lora.json"

                with open(final_cycle_meta_path, "w", encoding="utf-8") as handle:
                    json.dump(final_cycle_meta, handle, indent=2)

    os.makedirs(args.output_dir, exist_ok=True)
    # The stage is named "final_post_lora" only when a post-LoRA perplexity
    # was measured; otherwise it is "final_model".
    final_post_lora_meta = save_stage_checkpoint(
        model=model,
        tokenizer=tokenizer,
        stage_dir=args.output_dir,
        stage_name="final_post_lora" if lora_post_eval is not None else "final_model",
        ppl_results=lora_post_eval,
    )

    progressive_meta = {
        "base_model": getattr(args, "base_model_id", args.model),
        "num_progressive": args.num_progressive,
        "layer_path": args.layer_path,
        "resume_from_cycle": args.resume_from_cycle,
        "save_full_model_cycles": sorted(args.full_model_save_cycles),
        "num_sequences": num_sequences,
        "seq_len": args.seq_len,
        "lora": {
            "enabled": args.lora_epochs > 0,
            "ran": args.lora_epochs > 0 and lora_ready,
            "seq_len": args.distill_seq_len,
            "batch_size": args.distill_batch_size,
            "epochs": args.lora_epochs,
            "rank": args.lora_rank,
            "alpha": args.lora_alpha,
            "dropout": args.lora_dropout,
            "target_modules": args.lora_target_modules,
            "respect_exclude_pairs": args.lora_respect_exclude_pairs,
            "kl_enabled": args.lora_kl_enabled,
            "kl_weight": args.lora_kl_weight,
            "kl_temp": args.lora_kl_temp,
            "post_ppl": lora_post_eval,
            "ppl_over_lora": (
                f"cycle_{args.num_progressive}/ppl_over_lora.json"
                if lora_eval_history
                else None
            ),
            "lr": args.lora_lr,
            "weight_decay": args.lora_weight_decay,
            "max_grad_norm": args.lora_max_grad_norm,
            "grad_accum_steps": args.lora_grad_accum_steps,
            "log_steps": args.lora_log_steps,
            "eval_every": args.lora_eval_every,
            "eval_max_batches": args.lora_eval_max_batches,
        },
        "artifacts": {
            "final_pre_lora": final_pre_lora_meta,
            "final_post_lora": final_post_lora_meta,
        },
        "eval": {
            "datasets": eval_datasets,
            "configs": eval_configs,
            "split": args.eval_split,
            "num_samples": args.eval_num_samples,
            "seq_len": args.eval_seq_len,
            "pre_ppl": pre_eval,
            "post_ppl": cycle_summaries[-1]["post_ppl"] if cycle_summaries else None,
        },
        "cycles": cycle_summaries,
        "final_num_layers": len(layers),
    }

    # Replace the bootstrap metadata written at the start with the full summary.
    with open(
        os.path.join(args.output_dir, "progressive_metadata.json"),
        "w",
        encoding="utf-8",
    ) as handle:
        json.dump(progressive_meta, handle, indent=2)

    print(
        f"[progressive] Completed {args.num_progressive} cycles. "
        f"Final model saved to {args.output_dir}."
    )
2343
+
2344
+
2345
def main() -> None:
    """CLI entry point: validate arguments, load everything, run the fusion loop.

    Steps, in order: parse and validate the cycle arguments, resolve which
    cycles save a full model, read the resume metadata when resuming, seed the
    run, load the model and tokenizer, restore saved RNG / loader-generator
    state when resuming, build all datasets, and hand over to
    ``run_progressive``.

    Raises:
        SystemExit: On invalid ``--num_progressive`` / ``--resume_from_cycle``
            values, or when resuming from a directory whose resume metadata is
            missing or records a different cycle.
    """
    args = parse_args()

    # Reject invalid cycle arguments before any expensive work starts.
    if args.num_progressive <= 0:
        raise SystemExit(
            "Single-cycle mode has been removed. Pass --num_progressive > 0."
        )
    if args.resume_from_cycle < 0:
        raise SystemExit("--resume_from_cycle must be >= 0.")
    if args.resume_from_cycle >= args.num_progressive:
        raise SystemExit("--resume_from_cycle must be smaller than --num_progressive.")

    requested_save_cycles = parse_cycle_list(args.save_full_model_cycles)
    args.full_model_save_cycles = resolve_full_model_save_cycles(
        requested_save_cycles,
        args.num_progressive,
    )

    # By default the base model is whatever --model names; a resume directory
    # may record the real base model in its metadata.
    args.base_model_id = args.model
    if args.resume_from_cycle > 0:
        resume_info = load_resume_metadata(args.model)
        if resume_info is None:
            raise SystemExit(
                "--resume_from_cycle requires --model to point to a saved cycle full model "
                "directory containing resume_info.json."
            )
        recorded_cycle = resume_info.get("cycle")
        cycle_mismatch = (
            recorded_cycle is not None
            and int(recorded_cycle) != args.resume_from_cycle
        )
        if cycle_mismatch:
            raise SystemExit(
                "resume_info.json cycle does not match --resume_from_cycle."
            )
        recorded_base = resume_info.get("base_model")
        if isinstance(recorded_base, str) and recorded_base:
            args.base_model_id = recorded_base
    configure_reproducibility(args.seed)

    eval_datasets, eval_configs = resolve_eval_datasets(args)
    torch_dtype = get_dtype(args.dtype)
    model = load_causal_lm(
        args.model,
        torch_dtype=torch_dtype,
        trust_remote_code=args.trust_remote_code,
        cache_dir=args.model_cache_dir,
    )

    # When resuming, restore the RNG state saved with the checkpoint (if any)
    # and fetch the saved data-loader generator state.
    loader_generator_state = None
    if args.resume_from_cycle > 0:
        saved_rng = load_rng_state(os.path.join(args.model, "rng_state.pt"))
        if saved_rng is not None:
            restore_rng_state(saved_rng)
        loader_generator_state = load_loader_generator_state(args.model)

    tokenizer = AutoTokenizer.from_pretrained(
        args.model,
        trust_remote_code=args.trust_remote_code,
        cache_dir=args.model_cache_dir,
    )
    print(model)

    # Use EOS as the padding token when the tokenizer defines none.
    needs_pad_token = tokenizer.pad_token is None and tokenizer.eos_token is not None
    if needs_pad_token:
        tokenizer.pad_token = tokenizer.eos_token

    # Turn off the KV cache (the original author's note reads "for llama?").
    model.config.use_cache = False

    prepared = prepare_all_data(
        args,
        tokenizer,
        model,
        eval_datasets,
        eval_configs,
        loader_generator_state=loader_generator_state,
    )
    run_progressive(args, model, tokenizer, prepared)
2413
+
2414
+
2415
# Run the CLI entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
src/fuse_layers_data.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Dataset and text helpers for fuse_layers."""
3
+
4
+ import argparse
5
+ from typing import Dict, List, Optional
6
+
7
+ import torch
8
+
9
+ try:
10
+ from datasets import load_dataset
11
+ except Exception: # pragma: no cover - optional dependency
12
+ load_dataset = None
13
+
14
+
15
def guess_text_field(dataset) -> str:
    """Pick the column of ``dataset`` that most likely holds raw text.

    ``dataset.column_names`` is consulted first; when it is missing or empty
    the keys of ``dataset.features`` are used instead. Within whichever
    listing applies, ``"text"`` wins if present, otherwise the first name is
    returned. ``"text"`` is also the fallback when no name can be found.
    """
    columns = getattr(dataset, "column_names", None)
    if columns:
        # A non-empty column listing is authoritative; features are ignored.
        return "text" if "text" in columns else columns[0]

    if hasattr(dataset, "features"):
        feature_names = list(dataset.features.keys())
        if "text" in feature_names:
            return "text"
        if feature_names:
            return feature_names[0]

    return "text"
27
+
28
+
29
+ def _normalize_config(config: Optional[str]) -> Optional[str]:
30
+ if config is None:
31
+ return None
32
+ if config.strip().lower() in {"none", "null", "-"}:
33
+ return None
34
+ return config
35
+
36
+
37
def expand_dataset_configs(
    datasets: List[str], configs: List[str]
) -> List[Optional[str]]:
    """Produce one normalized config entry per dataset.

    * no configs: every dataset gets ``None``;
    * exactly one config and several datasets: that config is shared;
    * otherwise the counts must match and configs pair up positionally.

    Each config is passed through ``_normalize_config``.

    Raises:
        SystemExit: when the number of configs is neither zero, one, nor
            equal to the number of datasets.
    """
    if not configs:
        return [None for _ in datasets]

    dataset_count = len(datasets)
    if len(configs) == 1 and dataset_count > 1:
        shared = _normalize_config(configs[0])
        return [shared] * dataset_count

    if len(configs) != dataset_count:
        raise SystemExit(
            "Provide zero, one, or matching-count --dataset_config values."
        )
    return list(map(_normalize_config, configs))
49
+
50
+
51
+ def _sample_dataset_rows(
52
+ dataset, target: int, seed: int
53
+ ) -> List[Dict[str, object]]:
54
+ if target <= 0:
55
+ return []
56
+ try:
57
+ dataset = dataset.shuffle(seed=seed)
58
+ except Exception:
59
+ pass
60
+
61
+ if hasattr(dataset, "__len__"):
62
+ limit = min(target, len(dataset))
63
+ dataset = dataset.select(range(limit))
64
+ return [row for row in dataset]
65
+
66
+ rows = []
67
+ for row in dataset:
68
+ rows.append(row)
69
+ if len(rows) >= target:
70
+ break
71
+ return rows
72
+
73
+
74
def load_texts(args: argparse.Namespace) -> List[str]:
    """Collect calibration texts from every source named on the command line.

    Sources are appended in this order:

    1. ``args.text_file``: one text per non-blank line, stripped;
    2. ``args.text``: literal strings, empty ones dropped;
    3. ``args.dataset``: rows sampled from each dataset, with
       ``args.num_samples`` split across datasets and per-dataset seeds
       derived from ``args.seed``.

    The dataset branch delegates to ``load_texts_from_datasets`` so both
    entry points share one sampling implementation.

    Raises:
        SystemExit: if ``--dataset`` is used without the ``datasets``
            package, or the config count does not match the datasets.
    """
    texts: List[str] = []
    if args.text_file:
        with open(args.text_file, "r", encoding="utf-8") as handle:
            stripped_lines = (line.strip() for line in handle)
            texts.extend(line for line in stripped_lines if line)
    if args.text:
        texts.extend(t for t in args.text if t)

    if args.dataset:
        # Checked here (before config expansion) so a missing dependency is
        # reported ahead of any config-count complaint.
        if load_dataset is None:
            raise SystemExit("datasets is required for --dataset")

        datasets = list(args.dataset)
        # dataset_config may be None when the option was never supplied.
        configs = expand_dataset_configs(datasets, list(args.dataset_config or []))
        texts.extend(
            load_texts_from_datasets(
                datasets,
                configs,
                args.dataset_split,
                args.dataset_text_field,
                args.num_samples,
                args.seed,
            )
        )

    return texts
108
+
109
+
110
def load_texts_from_datasets(
    datasets: List[str],
    configs: List[Optional[str]],
    split: str,
    text_field: Optional[str],
    num_samples: int,
    seed: int,
) -> List[str]:
    """Sample non-blank text values from several datasets.

    ``num_samples`` is divided evenly across ``datasets``; the first
    ``num_samples % len(datasets)`` datasets receive one extra row. Dataset
    ``i`` is sampled with seed ``seed + i``. The text column is
    ``text_field`` when given, otherwise it is guessed per dataset.

    Returns:
        String values whose stripped form is non-empty, in dataset order.

    Raises:
        SystemExit: if datasets were requested but the ``datasets`` package
            could not be imported.
    """
    if not datasets:
        return []
    if load_dataset is None:
        raise SystemExit("datasets is required for --dataset")

    even_share, extras = divmod(num_samples, len(datasets))
    collected: List[str] = []

    for position, (dataset_name, config) in enumerate(zip(datasets, configs)):
        quota = even_share + 1 if position < extras else even_share
        # NOTE(review): trust_remote_code=True executes dataset-provided
        # code; confirm every dataset passed here is trusted.
        dataset = load_dataset(
            dataset_name,
            config,
            split=split,
            trust_remote_code=True,
        )
        sampled = _sample_dataset_rows(dataset, quota, seed + position)
        field = text_field or guess_text_field(dataset)
        for row in sampled:
            if not isinstance(row, dict):
                continue
            candidate = row.get(field, None)
            if isinstance(candidate, str) and candidate.strip():
                collected.append(candidate)
    return collected
143
+
144
+
145
def format_alpaca_example(instruction: str, inp: str, output: str) -> str:
    """Render one example as an Alpaca-style plain-text prompt.

    The result is an ``### Instruction:`` section, an optional
    ``### Input:`` section (emitted only when ``inp`` is truthy), and a
    ``### Response:`` section. Sections are separated by a blank line and
    nothing follows ``output``.
    """
    sections = [f"### Instruction:\n{instruction}\n\n"]
    if inp:
        sections.append(f"### Input:\n{inp}\n\n")
    sections.append(f"### Response:\n{output}")
    return "".join(sections)
161
+
162
+
163
def build_alpaca_messages(
    instruction: str, inp: str, output: str
) -> List[Dict[str, str]]:
    """Express one example as a two-turn chat (user, then assistant).

    The user turn is ``instruction`` alone, or ``instruction`` followed by a
    blank line and an ``Input:`` section when ``inp`` is truthy. The
    assistant turn is ``output``.
    """
    user_content = f"{instruction}\n\nInput:\n{inp}" if inp else instruction
    user_turn = {"role": "user", "content": user_content}
    assistant_turn = {"role": "assistant", "content": output}
    return [user_turn, assistant_turn]
174
+
175
+
176
class FixedSeqDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding fixed-length ``(input_ids, attention_mask)``.

    Each record is tokenized on access. Records carrying ``"messages"`` go
    through the tokenizer's chat template when one is configured; all other
    records fall back to encoding ``record["text"]`` without special tokens.
    Sequences are truncated or right-padded to exactly ``seq_len`` tokens.
    """

    def __init__(self, records: List[Dict[str, object]], tokenizer, seq_len: int) -> None:
        self.records = records
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        # Padding id preference: pad token, then EOS token, then 0.
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = tokenizer.eos_token_id or 0
        self.pad_id = pad_id

    def __len__(self) -> int:
        return len(self.records)

    def _token_ids(self, record) -> List[int]:
        """Tokenize ``record`` and coerce the result to a plain list."""
        chat_template = getattr(self.tokenizer, "chat_template", None)
        use_chat = (
            "messages" in record
            and hasattr(self.tokenizer, "apply_chat_template")
            and chat_template
        )
        if use_chat:
            ids = self.tokenizer.apply_chat_template(
                record["messages"],
                tokenize=True,
                add_generation_prompt=False,
            )
        else:
            ids = self.tokenizer.encode(
                record.get("text", ""), add_special_tokens=False
            )

        # Transformers may hand back a BatchEncoding rather than a list.
        if hasattr(ids, "input_ids"):
            ids = ids.input_ids
        if isinstance(ids, torch.Tensor):
            return ids.tolist()
        return ids if isinstance(ids, list) else list(ids)

    def __getitem__(self, idx: int):
        """Return ``(ids, mask)`` long tensors for record ``idx``.

        ``mask`` is 1 on real tokens and 0 on padding.
        """
        ids = self._token_ids(self.records[idx])[: self.seq_len]
        real_len = len(ids)
        pad_len = self.seq_len - real_len  # <= 0 means nothing to pad

        ids = ids + [self.pad_id] * pad_len
        attn = [1] * real_len + [0] * pad_len
        return (
            torch.tensor(ids, dtype=torch.long),
            torch.tensor(attn, dtype=torch.long),
        )
225
+
226
+
227
def load_instruction_records(
    args: argparse.Namespace, num_samples: int
) -> List[Dict[str, object]]:
    """Load instruction-tuning rows as chat + plain-text training records.

    Args:
        args: Namespace providing ``instruction_dataset``,
            ``instruction_config``, ``instruction_split``, ``seed`` and the
            ``instruction_field_{instruction,input,output}`` column names.
        num_samples: Number of rows to sample; ``<= 0`` uses the whole split.

    Returns:
        One dict per usable row with ``"messages"`` (chat turns) and
        ``"text"`` (Alpaca-formatted prompt). Rows that are not dicts, or
        whose instruction or output is empty, are skipped. An empty list is
        returned when no instruction dataset is configured.

    Raises:
        SystemExit: if a dataset is configured but the ``datasets`` package
            could not be imported.
    """
    if not args.instruction_dataset:
        return []
    if load_dataset is None:
        raise SystemExit("datasets is required for instruction dataset")

    def _field_text(row: Dict[str, object], key: str) -> str:
        # A null column value must count as empty; str(None) would otherwise
        # inject the literal text "None" into the prompt.
        value = row.get(key, "")
        return "" if value is None else str(value).strip()

    # NOTE(review): trust_remote_code=True executes dataset-provided code;
    # confirm the configured dataset is trusted.
    dataset = load_dataset(
        args.instruction_dataset,
        _normalize_config(args.instruction_config),
        split=args.instruction_split,
        trust_remote_code=True,
    )
    if num_samples > 0:
        rows = _sample_dataset_rows(dataset, num_samples, args.seed)
    else:
        rows = dataset

    records: List[Dict[str, object]] = []
    for row in rows:
        if not isinstance(row, dict):
            continue
        instruction = _field_text(row, args.instruction_field_instruction)
        inp = _field_text(row, args.instruction_field_input)
        output = _field_text(row, args.instruction_field_output)
        if not instruction or not output:
            continue
        records.append(
            {
                "messages": build_alpaca_messages(instruction, inp, output),
                "text": format_alpaca_example(instruction, inp, output),
            }
        )
    return records
261
+
262
+
263
def build_token_chunks(
    texts: List[str], tokenizer, seq_len: int, num_samples: int
) -> List[torch.Tensor]:
    """Pack tokenized texts into contiguous chunks of exactly ``seq_len`` ids.

    Texts are encoded without special tokens and concatenated in order, so a
    chunk may span several texts. Tokens left over after the last full chunk
    are discarded.

    Args:
        texts: Source strings; texts that encode to nothing are skipped.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``.
        seq_len: Chunk length in tokens; must be positive.
        num_samples: Maximum number of chunks; ``<= 0`` means unlimited.

    Returns:
        A list of 1-D ``torch.long`` tensors, each of length ``seq_len``.

    Raises:
        ValueError: if ``seq_len`` is not positive (the packing loop could
            otherwise never drain its buffer).
    """
    if seq_len <= 0:
        raise ValueError(f"seq_len must be positive, got {seq_len}")

    chunks: List[torch.Tensor] = []
    buffer: List[int] = []
    limit = None if num_samples <= 0 else num_samples

    for text in texts:
        ids = tokenizer.encode(text, add_special_tokens=False)
        if not ids:
            continue
        buffer.extend(ids)

        # Walk the buffer by offset and trim once afterwards, instead of
        # re-copying the remainder for every emitted chunk.
        start = 0
        while len(buffer) - start >= seq_len and (
            limit is None or len(chunks) < limit
        ):
            chunk = buffer[start:start + seq_len]
            chunks.append(torch.tensor(chunk, dtype=torch.long))
            start += seq_len
        if start:
            del buffer[:start]

        if limit is not None and len(chunks) >= limit:
            break
    return chunks
src/fuse_layers_distill.py ADDED
@@ -0,0 +1,2018 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Distillation helpers for fuse_layers."""
3
+
4
+ import argparse
5
+ import itertools
6
+ import math
7
+ import os
8
+ from contextlib import contextmanager, nullcontext
9
+ from typing import Dict, List, Optional, Set, Tuple
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+
14
+ try:
15
+ import ppl_eval
16
+ except Exception as exc: # pragma: no cover - optional dependency
17
+ raise SystemExit("ppl_eval.py is required (missing or invalid)") from exc
18
+ try:
19
+ from tqdm import tqdm
20
+ except Exception: # pragma: no cover - optional dependency
21
+ tqdm = None
22
+
23
+ try:
24
+ from torch.func import functional_call as _functional_call
25
+ except Exception: # pragma: no cover - depends on torch version
26
+ try:
27
+ from torch.nn.utils.stateless import functional_call as _functional_call
28
+ except Exception: # pragma: no cover - depends on torch version
29
+ _functional_call = None
30
+
31
+ from fuse_layers_model import find_attention_module, find_mlp_module
32
+
33
+
34
+ def _tqdm_enabled() -> bool:
35
+ value = os.environ.get("DISABLE_TQDM", os.environ.get("TQDM_DISABLE", "0"))
36
+ return value.strip().lower() not in {"1", "true", "yes", "on"}
37
+
38
+
39
@contextmanager
def temporary_layers(parent: object, name: str, new_layers: torch.nn.Module):
    """Swap ``parent.<name>`` for ``new_layers`` for the duration of a block.

    The previous attribute value is put back on exit, including when the
    body raises.
    """
    previous_layers = getattr(parent, name)
    setattr(parent, name, new_layers)
    try:
        yield
    finally:
        # Always restore, so a failed forward pass cannot leave the parent
        # pointing at the temporary layer stack.
        setattr(parent, name, previous_layers)
47
+
48
+
49
@contextmanager
def temporary_norm(parent: object):
    """Replace ``parent.norm`` with an identity module inside the block.

    Parents without a ``norm`` attribute are left untouched. Otherwise the
    original ``norm`` is restored on exit, including when the body raises.
    """
    if not hasattr(parent, "norm"):
        # Nothing to bypass; behave as a no-op context.
        yield
        return

    saved_norm = parent.norm
    parent.norm = torch.nn.Identity()
    try:
        yield
    finally:
        parent.norm = saved_norm
60
+
61
+
62
def forward_truncated(
    parent: torch.nn.Module,
    layer_attr: str,
    layers: List[torch.nn.Module],
    upto: int,
    input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run ``parent`` with only the first ``upto`` layers installed.

    ``parent.<layer_attr>`` is temporarily replaced by
    ``ModuleList(layers[:upto])`` and ``parent.norm`` (when present) by an
    identity module; both are restored afterwards. The parent is called with
    ``use_cache=False``.

    Returns:
        ``outputs.last_hidden_state`` when the output exposes it, otherwise
        ``outputs[0]``.
    """
    prefix = torch.nn.ModuleList(layers[:upto])
    with temporary_layers(parent, layer_attr, prefix):
        with temporary_norm(parent):
            outputs = parent(
                input_ids=input_ids,
                attention_mask=attention_mask,
                use_cache=False,
            )
    return (
        outputs.last_hidden_state
        if hasattr(outputs, "last_hidden_state")
        else outputs[0]
    )
80
+
81
+
82
+ def _masked_hidden_mse(diff: torch.Tensor, attention_mask: torch.Tensor) -> Optional[torch.Tensor]:
83
+ diff_f = diff.float()
84
+ mask = attention_mask.to(device=diff.device, dtype=torch.float32)
85
+ denom = mask.sum() * diff_f.size(-1)
86
+ if denom.item() == 0:
87
+ return None
88
+ return (diff_f.pow(2) * mask.unsqueeze(-1)).sum() / denom
89
+
90
+
91
+ def _extract_hidden_like(output) -> Optional[torch.Tensor]:
92
+ if torch.is_tensor(output):
93
+ return output
94
+ if isinstance(output, (tuple, list)) and output:
95
+ first = output[0]
96
+ if torch.is_tensor(first):
97
+ return first
98
+ if hasattr(output, "last_hidden_state"):
99
+ hidden = getattr(output, "last_hidden_state")
100
+ if torch.is_tensor(hidden):
101
+ return hidden
102
+ return None
103
+
104
+
105
@contextmanager
def capture_module_output(module: torch.nn.Module):
    """Record the latest forward output of ``module`` while the block runs.

    Yields a dict whose ``"output"`` entry starts as ``None`` and is
    overwritten on every forward call with whatever
    ``_extract_hidden_like`` returns for that call's output. The forward
    hook is removed on exit, including when the body raises.
    """
    captured: Dict[str, Optional[torch.Tensor]] = {"output": None}

    def _record(_module, _inputs, output):
        captured["output"] = _extract_hidden_like(output)

    registration = module.register_forward_hook(_record)
    try:
        yield captured
    finally:
        registration.remove()
117
+
118
+
119
# Substrings that mark a parameter name as belonging to the attention family.
# _classify_param_family tests them with ``in`` against the lower-cased name,
# so these are fragments rather than exact names.
_ATTN_NAME_FRAGMENTS = (
    "self_attn.",
    "attn.",
    "attention.",
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "q_norm",
    "k_norm",
)
# Substrings that mark a parameter name as belonging to the MLP family.
# _classify_param_family checks these before the attention fragments, so a
# name matching both tables is classified as "mlp".
_MLP_NAME_FRAGMENTS = (
    "mlp.",
    "ffn.",
    "feed_forward",
    "feedforward",
    "gate_proj",
    "up_proj",
    "down_proj",
    "fc1",
    "fc2",
    "dense_h_to_4h",
    "dense_4h_to_h",
    "w1",
    "w2",
    "w3",
)
146
+
147
+
148
def _classify_param_family(name: str) -> str:
    """Classify a parameter name as ``"mlp"``, ``"attn"`` or ``"other"``.

    The name is lower-cased and searched for the known fragments. MLP
    fragments are tried first, so a name matching both families is reported
    as ``"mlp"``.
    """
    key = name.lower()
    ordered_families = (
        ("mlp", _MLP_NAME_FRAGMENTS),
        ("attn", _ATTN_NAME_FRAGMENTS),
    )
    for family, fragments in ordered_families:
        for fragment in fragments:
            if fragment in key:
                return family
    return "other"
155
+
156
+
157
+ def _family_reg_scale(family: str, attn_scale: float, mlp_scale: float) -> float:
158
+ if family == "attn":
159
+ return attn_scale
160
+ if family == "mlp":
161
+ return mlp_scale
162
+ return 1.0
163
+
164
+
165
+ def _subset_allows_param(name: str, subset: str) -> bool:
166
+ if subset == "all":
167
+ return True
168
+ return _classify_param_family(name) == subset
169
+
170
+
171
+ def _gate_logit_from_prior(prior: torch.Tensor) -> torch.Tensor:
172
+ # Stable logit: log(p) - log(1 - p).
173
+ return torch.log(prior) - torch.log1p(-prior)
174
+
175
+
176
+ def _build_gate_priors(
177
+ layer_a: torch.nn.Module,
178
+ layer_b: torch.nn.Module,
179
+ fisher_a: Dict[str, object],
180
+ fisher_b: Dict[str, object],
181
+ num_batches: int,
182
+ numels_a: Dict[str, int],
183
+ numels_b: Dict[str, int],
184
+ fisher_mode: str,
185
+ eps: float,
186
+ clamp_eps: float,
187
+ ) -> Dict[str, torch.Tensor]:
188
+ """Return lambda priors for parameters that can be merged."""
189
+ priors: Dict[str, torch.Tensor] = {}
190
+ params_b = {name: param for name, param in layer_b.named_parameters()}
191
+ for name, param_a in layer_a.named_parameters():
192
+ param_b = params_b.get(name)
193
+ if param_b is None or param_b.shape != param_a.shape:
194
+ continue
195
+ if fisher_mode == "param":
196
+ fa = fisher_a[name] / max(num_batches, 1)
197
+ fb = fisher_b[name] / max(num_batches, 1)
198
+ denom = fa + fb
199
+ if not isinstance(denom, torch.Tensor):
200
+ denom = torch.tensor(float(denom))
201
+ # If Fisher is uninformative, default to symmetric init.
202
+ prior = torch.where(
203
+ denom > eps,
204
+ fa / (denom + eps),
205
+ torch.full_like(denom, 0.5),
206
+ )
207
+ prior = prior.clamp(clamp_eps, 1.0 - clamp_eps)
208
+ priors[name] = prior
209
+ else:
210
+ fa = fisher_a[name] / (max(num_batches, 1) * numels_a[name])
211
+ fb = fisher_b[name] / (max(num_batches, 1) * numels_b[name])
212
+ denom = fa + fb
213
+ if denom <= eps:
214
+ prior_val = 0.5
215
+ else:
216
+ prior_val = float(fa / (denom + eps))
217
+ prior_val = min(max(prior_val, clamp_eps), 1.0 - clamp_eps)
218
+ priors[name] = torch.tensor(prior_val, dtype=torch.float32)
219
+ return priors
220
+
221
+
222
def compute_fisher_gate_priors(
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    fisher_a: Dict[str, object],
    fisher_b: Dict[str, object],
    num_batches: int,
    numels_a: Dict[str, int],
    numels_b: Dict[str, int],
    fisher_mode: str,
    eps: float,
    clamp_eps: float = 1e-4,
) -> Dict[str, torch.Tensor]:
    """Public entry point for Fisher-derived gate priors (lambda_prior).

    Forwards every argument unchanged to ``_build_gate_priors``; the only
    addition is the default of ``1e-4`` for ``clamp_eps``.
    """
    return _build_gate_priors(
        layer_a,
        layer_b,
        fisher_a,
        fisher_b,
        num_batches,
        numels_a,
        numels_b,
        fisher_mode,
        eps,
        clamp_eps,
    )
247
+
248
+
249
class ReparamMergedLayer(torch.nn.Module):
    """Virtual layer that blends two layers through a W0/U reparameterization.

    The weights of ``layer_a`` and ``layer_b`` are read detached and are not
    trained here. Per mergeable parameter, two tensors are trained:

    - a gate logit ``s`` with ``lambda = sigmoid(s)``
    - ``U``, initialized to ``U0 = (W_a - W_b) / 2``

    The forward pass runs ``layer_a`` with

        W_merge = W0 + (2 * lambda - 1) * U,   W0 = (W_a + W_b) / 2

    A parameter is mergeable when both layers hold it under the same name
    and shape and it passes the ``param_subset`` filter.
    """

    def __init__(
        self,
        layer_a: torch.nn.Module,
        layer_b: torch.nn.Module,
        gate_targets: Dict[str, object],
        param_subset: str = "all",
        clamp_eps: float = 1e-4,
    ) -> None:
        super().__init__()
        self.layer_a = layer_a
        self.layer_b = layer_b
        self.param_subset = param_subset
        # Original parameter name -> key used in the ParameterDicts
        # (ParameterDict keys may not contain dots).
        self._name_map: Dict[str, str] = {}

        self.gates = torch.nn.ParameterDict()
        self.u = torch.nn.ParameterDict()

        partner_params = dict(layer_b.named_parameters())
        # Gate tensors are created on the device of layer_a's first parameter.
        first_param = next(layer_a.parameters(), None)
        device = torch.device("cpu") if first_param is None else first_param.device

        for name, param_a in layer_a.named_parameters():
            param_b = partner_params.get(name)
            if param_b is None or param_b.shape != param_a.shape:
                continue
            if not _subset_allows_param(name, self.param_subset):
                continue

            prior = gate_targets.get(name)
            if prior is None:
                prior_t = torch.tensor(0.5, device=device, dtype=torch.float32)
            elif isinstance(prior, torch.Tensor):
                prior_t = prior.detach().to(device=device, dtype=torch.float32)
            else:
                prior_t = torch.tensor(float(prior), device=device, dtype=torch.float32)

            # Clamp away from 0/1 so the logit below stays finite.
            prior_t = prior_t.clamp(clamp_eps, 1.0 - clamp_eps)
            gate_init = _gate_logit_from_prior(prior_t)
            u_init = 0.5 * (param_a.detach().float() - param_b.detach().float())

            key = name.replace(".", "__")
            if key in self.gates:
                # Disambiguate names that collide after dot replacement.
                key = f"{key}_{len(self.gates)}"
            self._name_map[name] = key
            self.gates[key] = torch.nn.Parameter(gate_init)
            self.u[key] = torch.nn.Parameter(u_init)

    def __getattr__(self, name: str):
        # Delegate model-specific attributes (e.g. Qwen's `attention_type`) to
        # the wrapped layers so the parent forward doesn't break. layer_a is
        # consulted before layer_b; the original error is re-raised when
        # neither has the attribute.
        try:
            return super().__getattr__(name)
        except AttributeError as missing:
            for slot in ("layer_a", "layer_b"):
                try:
                    inner = super().__getattr__(slot)
                    if hasattr(inner, name):
                        return getattr(inner, name)
                except AttributeError:
                    continue
            raise missing

    def _safe_for(self, orig: str) -> Optional[str]:
        """Return the ParameterDict key for ``orig``, or ``None`` if unmerged."""
        return self._name_map.get(orig)

    def gate_lambdas(self) -> Dict[str, torch.Tensor]:
        """Return detached ``sigmoid(gate)`` values keyed by original name."""
        return {
            orig: torch.sigmoid(self.gates[key]).detach()
            for orig, key in self._name_map.items()
        }

    def _merged_params(self) -> Dict[str, torch.Tensor]:
        """Build the parameter mapping used for the functional forward call.

        Merged parameters are computed in float32 and cast back to the dtype
        of ``layer_a``'s parameter; all others are ``layer_a``'s own
        parameters, detached.
        """
        partner_params = dict(self.layer_b.named_parameters())
        merged_params: Dict[str, torch.Tensor] = {}

        for name, param_a in self.layer_a.named_parameters():
            param_b = partner_params.get(name)
            key = self._safe_for(name)
            if key is None or param_b is None or param_b.shape != param_a.shape:
                merged_params[name] = param_a.detach()
                continue

            lam = torch.sigmoid(self.gates[key]).to(dtype=torch.float32)
            delta = self.u[key].to(dtype=torch.float32)
            midpoint = 0.5 * (param_a.detach().float() + param_b.detach().float())
            blended = midpoint + (2.0 * lam - 1.0) * delta
            merged_params[name] = blended.to(dtype=param_a.dtype)
        return merged_params

    def forward(self, *args, **kwargs):
        """Run ``layer_a`` functionally with the merged parameters.

        Raises:
            RuntimeError: if no ``functional_call`` implementation could be
                imported from torch.
        """
        if _functional_call is None:
            raise RuntimeError(
                "Reparam distillation requires torch.func.functional_call"
            )

        return _functional_call(self.layer_a, self._merged_params(), args, kwargs)

    def materialize_into_layer_a(self) -> int:
        """Write the merged weights into ``layer_a`` in place.

        Returns:
            The number of parameters that were overwritten.
        """
        written = 0
        own_params = dict(self.layer_a.named_parameters())
        partner_params = dict(self.layer_b.named_parameters())
        with torch.no_grad():
            for orig, key in self._name_map.items():
                param_a = own_params.get(orig)
                param_b = partner_params.get(orig)
                if param_a is None or param_b is None or param_b.shape != param_a.shape:
                    continue
                lam = torch.sigmoid(self.gates[key]).to(device=param_a.device, dtype=torch.float32)
                delta = self.u[key].to(device=param_a.device, dtype=torch.float32)
                midpoint = 0.5 * (param_a.detach().float() + param_b.detach().float())
                blended = midpoint + (2.0 * lam - 1.0) * delta
                param_a.copy_(blended.to(dtype=param_a.dtype))
                written += 1
        return written
384
+
385
+
386
+ def distill_reparam_merge(
387
+ student_model: torch.nn.Module,
388
+ student_parent: object,
389
+ student_layer_attr: str,
390
+ student_layers: List[torch.nn.Module],
391
+ teacher_model: torch.nn.Module,
392
+ teacher_parent: object,
393
+ teacher_layer_attr: str,
394
+ teacher_layers: List[torch.nn.Module],
395
+ layer_idx: int,
396
+ gate_lambdas: Dict[str, object],
397
+ dataloader,
398
+ args: argparse.Namespace,
399
+ progressive_cycle: Optional[int] = None,
400
+ progressive_total: Optional[int] = None,
401
+ ) -> Tuple[int, Dict[str, torch.Tensor], Dict[str, object]]:
402
+ """Reparameterized distillation that materializes a fused layer into layer_a.
403
+
404
+ Trains U and gate logits s (lambda = sigmoid(s)) using:
405
+ - composition MSE + distill-KL
406
+ - eta * ||lambda - lambda_gate||^2 + gamma * ||U - U0||^2
407
+ """
408
+ total_epochs = float(args.distill_epochs)
409
+
410
+ hidden_mse_weight = float(getattr(args, "distill_hidden_mse_weight", 1.0))
411
+ if hidden_mse_weight < 0.0:
412
+ raise SystemExit("--distill_hidden_mse_weight must be >= 0")
413
+ attn_mse_weight = float(getattr(args, "distill_attn_mse_weight", 0.0))
414
+ if attn_mse_weight < 0.0:
415
+ raise SystemExit("--distill_attn_mse_weight must be >= 0")
416
+ mlp_mse_weight = float(getattr(args, "distill_mlp_mse_weight", 0.0))
417
+ if mlp_mse_weight < 0.0:
418
+ raise SystemExit("--distill_mlp_mse_weight must be >= 0")
419
+ param_subset = str(getattr(args, "reparam_param_subset", "all"))
420
+ if param_subset not in {"all", "mlp", "attn"}:
421
+ raise SystemExit("--reparam_param_subset must be one of: all, mlp, attn")
422
+
423
+ kl_weight = float(args.distill_kl_weight)
424
+ kl_temp = float(args.distill_kl_temp)
425
+ if kl_weight < 0.0:
426
+ raise SystemExit("--distill_kl_weight must be >= 0")
427
+ if kl_temp <= 0.0:
428
+ raise SystemExit("--distill_kl_temp must be > 0")
429
+
430
+ eta = float(getattr(args, "reparam_eta", 0.0))
431
+ gamma = float(getattr(args, "reparam_gamma", 0.0))
432
+ if eta < 0.0:
433
+ raise SystemExit("--reparam_eta must be >= 0")
434
+ if gamma < 0.0:
435
+ raise SystemExit("--reparam_gamma must be >= 0")
436
+ attn_reg_scale = float(getattr(args, "reparam_attn_reg_scale", 1.0))
437
+ mlp_reg_scale = float(getattr(args, "reparam_mlp_reg_scale", 1.0))
438
+ if attn_reg_scale < 0.0:
439
+ raise SystemExit("--reparam_attn_reg_scale must be >= 0")
440
+ if mlp_reg_scale < 0.0:
441
+ raise SystemExit("--reparam_mlp_reg_scale must be >= 0")
442
+ if (
443
+ total_epochs > 0.0
444
+ and hidden_mse_weight == 0.0
445
+ and attn_mse_weight == 0.0
446
+ and mlp_mse_weight == 0.0
447
+ and kl_weight == 0.0
448
+ and eta == 0.0
449
+ and gamma == 0.0
450
+ ):
451
+ raise SystemExit(
452
+ "Reparam distillation has no active loss terms. "
453
+ "Enable hidden/attention/MLP MSE, KL, or at least one reparam regularizer."
454
+ )
455
+
456
+ if not gate_lambdas:
457
+ raise SystemExit("Reparam distillation requires non-empty gate lambdas.")
458
+
459
+ layer_a = student_layers[layer_idx]
460
+ layer_b = student_layers[layer_idx + 1]
461
+
462
+ reparam_layer = ReparamMergedLayer(
463
+ layer_a,
464
+ layer_b,
465
+ gate_lambdas,
466
+ param_subset=param_subset,
467
+ clamp_eps=1e-4,
468
+ )
469
+ if not reparam_layer._name_map:
470
+ raise RuntimeError(
471
+ "No mergeable parameters found for reparam distillation under "
472
+ f"--reparam_param_subset={param_subset!r}."
473
+ )
474
+
475
+ teacher_attn = None
476
+ student_attn = None
477
+ if attn_mse_weight > 0.0:
478
+ try:
479
+ teacher_attn = find_attention_module(teacher_layers[layer_idx + 1])
480
+ student_attn = find_attention_module(reparam_layer.layer_a)
481
+ except ValueError as exc:
482
+ raise SystemExit(
483
+ "Attention-output preservation was requested but an attention module "
484
+ f"could not be resolved: {exc}"
485
+ ) from exc
486
+
487
+ teacher_mlp = None
488
+ student_mlp = None
489
+ if mlp_mse_weight > 0.0:
490
+ try:
491
+ teacher_mlp = find_mlp_module(teacher_layers[layer_idx + 1])
492
+ student_mlp = find_mlp_module(reparam_layer.layer_a)
493
+ except ValueError as exc:
494
+ raise SystemExit(
495
+ "MLP-output preservation was requested but an MLP module could not be "
496
+ f"resolved: {exc}"
497
+ ) from exc
498
+
499
+ # Virtual layer list: replace layer_a with reparam layer and remove layer_b.
500
+ virtual_layers = list(student_layers)
501
+ virtual_layers[layer_idx] = reparam_layer
502
+ del virtual_layers[layer_idx + 1]
503
+
504
+ # Only (U, s) are trainable.
505
+ for param in student_model.parameters():
506
+ param.requires_grad_(False)
507
+ for param in reparam_layer.gates.parameters():
508
+ param.requires_grad_(True)
509
+ for param in reparam_layer.u.parameters():
510
+ param.requires_grad_(True)
511
+
512
+ do_train = total_epochs > 0.0
513
+ if do_train:
514
+ teacher_model.eval()
515
+ student_model.train()
516
+
517
+ # Rough memory heads-up (esp. when --fisher_mode param makes per-element gates).
518
+ total_gate_elems = sum(int(p.numel()) for p in reparam_layer.gates.parameters())
519
+ total_u_elems = sum(int(p.numel()) for p in reparam_layer.u.parameters())
520
+ gate_mib = total_gate_elems * 4.0 / (1024.0 * 1024.0)
521
+ u_mib = total_u_elems * 4.0 / (1024.0 * 1024.0)
522
+ family_counts: Dict[str, int] = {"attn": 0, "mlp": 0, "other": 0}
523
+ for orig in reparam_layer._name_map:
524
+ family_counts[_classify_param_family(orig)] += 1
525
+ print(
526
+ f"[reparam] subset={param_subset} gates={len(reparam_layer.gates)} "
527
+ f"(attn={family_counts['attn']}, mlp={family_counts['mlp']}, other={family_counts['other']}) "
528
+ f"elems={total_gate_elems} (~{gate_mib:.1f} MiB), "
529
+ f"U_elems={total_u_elems} (~{u_mib:.1f} MiB; +optimizer state)"
530
+ )
531
+
532
+ optimizer = None
533
+ if do_train:
534
+ optimizer = torch.optim.AdamW(
535
+ [*reparam_layer.gates.parameters(), *reparam_layer.u.parameters()],
536
+ lr=float(args.distill_lr),
537
+ weight_decay=float(args.distill_weight_decay),
538
+ )
539
+
540
+ device_type = torch.device(args.device).type
541
+ amp_dtype = None
542
+ if args.dtype == "float16":
543
+ amp_dtype = torch.float16
544
+ elif args.dtype == "bfloat16":
545
+ amp_dtype = torch.bfloat16
546
+ use_amp = do_train and amp_dtype is not None and device_type == "cuda"
547
+ use_scaler = use_amp and amp_dtype == torch.float16
548
+ scaler = torch.cuda.amp.GradScaler() if use_scaler else None
549
+
550
+ full_epochs = int(total_epochs) if do_train else 0
551
+ fractional = (total_epochs - full_epochs) if do_train else 0.0
552
+ if fractional < 1e-8:
553
+ fractional = 0.0
554
+
555
+ epoch_plan = [(epoch_idx, None) for epoch_idx in range(full_epochs)]
556
+ if fractional > 0:
557
+ try:
558
+ batches_per_epoch = len(dataloader)
559
+ except TypeError as exc:
560
+ raise SystemExit(
561
+ "Fractional distill epochs require a dataloader with finite length."
562
+ ) from exc
563
+ if batches_per_epoch > 0:
564
+ frac_batches = int(round(fractional * batches_per_epoch))
565
+ if frac_batches <= 0:
566
+ frac_batches = 1
567
+ epoch_plan.append((full_epochs, frac_batches))
568
+
569
+ grad_accum = int(getattr(args, "distill_grad_accum_steps", 1))
570
+ if grad_accum <= 0:
571
+ raise SystemExit("--distill_grad_accum_steps must be >= 1")
572
+
573
+ log_steps = int(getattr(args, "distill_log_steps", 100))
574
+ max_grad_norm = getattr(args, "distill_max_grad_norm", 1.0)
575
+
576
+ params_a = {name: p for name, p in layer_a.named_parameters()}
577
+ params_b = {name: p for name, p in layer_b.named_parameters()}
578
+
579
+ step = 0
580
+ for epoch_idx, max_batches in epoch_plan:
581
+ if max_batches is None:
582
+ epoch_iter = dataloader
583
+ else:
584
+ epoch_iter = itertools.islice(dataloader, max_batches)
585
+ iterator = epoch_iter
586
+ if tqdm is not None and _tqdm_enabled():
587
+ if progressive_cycle is not None:
588
+ if progressive_total is not None:
589
+ desc = (
590
+ f"Reparam (cycle {progressive_cycle}/{progressive_total}, "
591
+ f"epoch {epoch_idx+1})"
592
+ )
593
+ else:
594
+ desc = f"Reparam (cycle {progressive_cycle}, epoch {epoch_idx+1})"
595
+ else:
596
+ desc = f"Reparam (epoch {epoch_idx+1})"
597
+ iterator = tqdm(epoch_iter, desc=desc, unit="batch", total=max_batches)
598
+
599
+ for batch in iterator:
600
+ input_ids = batch[0].to(args.device)
601
+ attention_mask = batch[1].to(args.device)
602
+ teacher_ids = input_ids.to(args.distill_teacher_device or args.device)
603
+ teacher_mask = attention_mask.to(args.distill_teacher_device or args.device)
604
+
605
+ teacher_depth = layer_idx + 2
606
+ student_depth = layer_idx + 1
607
+
608
+ autocast_ctx = (
609
+ torch.autocast(device_type=device_type, dtype=amp_dtype)
610
+ if use_amp
611
+ else nullcontext()
612
+ )
613
+ with autocast_ctx:
614
+ teacher_attn_ctx = (
615
+ capture_module_output(teacher_attn)
616
+ if teacher_attn is not None
617
+ else nullcontext({"output": None})
618
+ )
619
+ teacher_mlp_ctx = (
620
+ capture_module_output(teacher_mlp)
621
+ if teacher_mlp is not None
622
+ else nullcontext({"output": None})
623
+ )
624
+ with torch.no_grad():
625
+ with teacher_attn_ctx as teacher_attn_cache, teacher_mlp_ctx as teacher_mlp_cache:
626
+ teacher_hidden = forward_truncated(
627
+ teacher_parent,
628
+ teacher_layer_attr,
629
+ teacher_layers,
630
+ teacher_depth,
631
+ teacher_ids,
632
+ attention_mask=teacher_mask,
633
+ )
634
+
635
+ student_attn_ctx = (
636
+ capture_module_output(student_attn)
637
+ if student_attn is not None
638
+ else nullcontext({"output": None})
639
+ )
640
+ student_mlp_ctx = (
641
+ capture_module_output(student_mlp)
642
+ if student_mlp is not None
643
+ else nullcontext({"output": None})
644
+ )
645
+ with student_attn_ctx as student_attn_cache, student_mlp_ctx as student_mlp_cache:
646
+ student_hidden = forward_truncated(
647
+ student_parent,
648
+ student_layer_attr,
649
+ virtual_layers,
650
+ student_depth,
651
+ input_ids,
652
+ attention_mask=attention_mask,
653
+ )
654
+
655
+ if teacher_hidden.device != student_hidden.device:
656
+ teacher_hidden = teacher_hidden.to(student_hidden.device)
657
+
658
+ mse_loss = None
659
+ if hidden_mse_weight > 0.0:
660
+ diff = student_hidden - teacher_hidden
661
+ mse_loss = _masked_hidden_mse(diff, attention_mask)
662
+ if mse_loss is None:
663
+ continue
664
+
665
+ attn_aux_loss = None
666
+ if attn_mse_weight > 0.0:
667
+ teacher_attn_hidden = teacher_attn_cache.get("output")
668
+ student_attn_hidden = student_attn_cache.get("output")
669
+ if teacher_attn_hidden is None or student_attn_hidden is None:
670
+ raise RuntimeError(
671
+ "Attention-output preservation is enabled, but the forward "
672
+ "hook did not capture attention outputs."
673
+ )
674
+ if teacher_attn_hidden.device != student_attn_hidden.device:
675
+ teacher_attn_hidden = teacher_attn_hidden.to(student_attn_hidden.device)
676
+ attn_aux_loss = _masked_hidden_mse(
677
+ student_attn_hidden - teacher_attn_hidden,
678
+ attention_mask,
679
+ )
680
+ if attn_aux_loss is None:
681
+ continue
682
+
683
+ mlp_aux_loss = None
684
+ if mlp_mse_weight > 0.0:
685
+ teacher_mlp_hidden = teacher_mlp_cache.get("output")
686
+ student_mlp_hidden = student_mlp_cache.get("output")
687
+ if teacher_mlp_hidden is None or student_mlp_hidden is None:
688
+ raise RuntimeError(
689
+ "MLP-output preservation is enabled, but the forward hook "
690
+ "did not capture MLP outputs."
691
+ )
692
+ if teacher_mlp_hidden.device != student_mlp_hidden.device:
693
+ teacher_mlp_hidden = teacher_mlp_hidden.to(student_mlp_hidden.device)
694
+ mlp_aux_loss = _masked_hidden_mse(
695
+ student_mlp_hidden - teacher_mlp_hidden,
696
+ attention_mask,
697
+ )
698
+ if mlp_aux_loss is None:
699
+ continue
700
+
701
+ kl_loss = None
702
+ if kl_weight > 0.0:
703
+ with torch.no_grad():
704
+ teacher_outputs = teacher_model(
705
+ input_ids=teacher_ids,
706
+ attention_mask=teacher_mask,
707
+ use_cache=False,
708
+ )
709
+ teacher_logits = teacher_outputs.logits
710
+
711
+ virtual_container = torch.nn.ModuleList(virtual_layers)
712
+ with temporary_layers(
713
+ student_parent, student_layer_attr, virtual_container
714
+ ):
715
+ student_outputs = student_model(
716
+ input_ids=input_ids,
717
+ attention_mask=attention_mask,
718
+ use_cache=False,
719
+ )
720
+ student_logits = student_outputs.logits
721
+ if teacher_logits.device != student_logits.device:
722
+ teacher_logits = teacher_logits.to(student_logits.device)
723
+
724
+ shift_teacher_logits = teacher_logits[:, :-1, :].contiguous()
725
+ shift_student_logits = student_logits[:, :-1, :].contiguous()
726
+ shift_mask = attention_mask[:, 1:].contiguous()
727
+ log_p_t = F.log_softmax(shift_teacher_logits / kl_temp, dim=-1)
728
+ log_p_s = F.log_softmax(shift_student_logits / kl_temp, dim=-1)
729
+ p_t = log_p_t.exp()
730
+ kl_flat = (p_t * (log_p_t - log_p_s)).sum(dim=-1)
731
+ kl_denom = shift_mask.sum()
732
+ if kl_denom.item() == 0:
733
+ continue
734
+ kl_loss = (
735
+ kl_flat * shift_mask.to(kl_flat.dtype)
736
+ ).sum() / kl_denom
737
+
738
+ lambda_reg = None
739
+ if eta > 0.0:
740
+ reg_sum: Optional[torch.Tensor] = None
741
+ reg_elems = 0
742
+ for orig, safe in reparam_layer._name_map.items():
743
+ lam = torch.sigmoid(reparam_layer.gates[safe]).float()
744
+ target = gate_lambdas.get(orig)
745
+ if target is None:
746
+ target_t = 0.5
747
+ elif isinstance(target, torch.Tensor):
748
+ target_t = target.to(device=lam.device, dtype=lam.dtype)
749
+ else:
750
+ target_t = float(target)
751
+ diff_lam = lam - target_t
752
+ family = _classify_param_family(orig)
753
+ scale = _family_reg_scale(
754
+ family,
755
+ attn_scale=attn_reg_scale,
756
+ mlp_scale=mlp_reg_scale,
757
+ )
758
+ if scale <= 0.0:
759
+ continue
760
+ part = diff_lam.pow(2).sum() * scale
761
+ reg_sum = part if reg_sum is None else reg_sum + part
762
+ reg_elems += int(float(diff_lam.numel()) * scale)
763
+ if reg_elems > 0 and reg_sum is not None:
764
+ lambda_reg = reg_sum / float(reg_elems)
765
+
766
+ u_reg = None
767
+ if gamma > 0.0:
768
+ reg_sum: Optional[torch.Tensor] = None
769
+ reg_elems = 0
770
+ for orig, safe in reparam_layer._name_map.items():
771
+ u = reparam_layer.u[safe].float()
772
+ param_a = params_a.get(orig)
773
+ param_b = params_b.get(orig)
774
+ if param_a is None or param_b is None or param_b.shape != param_a.shape:
775
+ continue
776
+ u0 = 0.5 * (param_a.detach().float() - param_b.detach().float())
777
+ diff_u = u - u0
778
+ family = _classify_param_family(orig)
779
+ scale = _family_reg_scale(
780
+ family,
781
+ attn_scale=attn_reg_scale,
782
+ mlp_scale=mlp_reg_scale,
783
+ )
784
+ if scale <= 0.0:
785
+ continue
786
+ part = diff_u.pow(2).sum() * scale
787
+ reg_sum = part if reg_sum is None else reg_sum + part
788
+ reg_elems += int(float(diff_u.numel()) * scale)
789
+ if reg_elems > 0 and reg_sum is not None:
790
+ u_reg = reg_sum / float(reg_elems)
791
+
792
+ total_loss = None
793
+ if mse_loss is not None:
794
+ total_loss = hidden_mse_weight * mse_loss
795
+ if attn_aux_loss is not None:
796
+ total_loss = attn_mse_weight * attn_aux_loss if total_loss is None else total_loss + (attn_mse_weight * attn_aux_loss)
797
+ if mlp_aux_loss is not None:
798
+ total_loss = mlp_mse_weight * mlp_aux_loss if total_loss is None else total_loss + (mlp_mse_weight * mlp_aux_loss)
799
+ if kl_loss is not None:
800
+ total_loss = kl_weight * (kl_temp ** 2) * kl_loss if total_loss is None else total_loss + (kl_weight * (kl_temp ** 2) * kl_loss)
801
+ if lambda_reg is not None:
802
+ total_loss = eta * lambda_reg if total_loss is None else total_loss + (eta * lambda_reg)
803
+ if u_reg is not None:
804
+ total_loss = gamma * u_reg if total_loss is None else total_loss + (gamma * u_reg)
805
+ if total_loss is None:
806
+ continue
807
+
808
+ if grad_accum > 1:
809
+ total_loss = total_loss / grad_accum
810
+ if use_scaler:
811
+ scaler.scale(total_loss).backward()
812
+ else:
813
+ total_loss.backward()
814
+
815
+ if (step + 1) % grad_accum == 0:
816
+ if max_grad_norm is not None:
817
+ if use_scaler:
818
+ scaler.unscale_(optimizer)
819
+ torch.nn.utils.clip_grad_norm_(
820
+ [*reparam_layer.gates.parameters(), *reparam_layer.u.parameters()],
821
+ float(max_grad_norm),
822
+ )
823
+ if use_scaler:
824
+ scaler.step(optimizer)
825
+ scaler.update()
826
+ else:
827
+ optimizer.step()
828
+ optimizer.zero_grad(set_to_none=True)
829
+
830
+ if log_steps and (step == 0 or (step + 1) % log_steps == 0):
831
+ log_parts = [f"loss={total_loss.item():.6e}"]
832
+ if mse_loss is not None:
833
+ log_parts.append(f"mse={mse_loss.item():.6e}")
834
+ else:
835
+ log_parts.append("mse=disabled")
836
+ if attn_aux_loss is not None:
837
+ log_parts.append(f"attn_mse={attn_aux_loss.item():.6e}")
838
+ elif attn_mse_weight > 0.0:
839
+ log_parts.append("attn_mse=nan")
840
+ if mlp_aux_loss is not None:
841
+ log_parts.append(f"mlp_mse={mlp_aux_loss.item():.6e}")
842
+ elif mlp_mse_weight > 0.0:
843
+ log_parts.append("mlp_mse=nan")
844
+ if kl_loss is not None:
845
+ log_parts.append(f"kl={kl_loss.item():.6e}")
846
+ if lambda_reg is not None:
847
+ log_parts.append(f"lam_reg={lambda_reg.item():.6e}")
848
+ if u_reg is not None:
849
+ log_parts.append(f"u_reg={u_reg.item():.6e}")
850
+ print(
851
+ f"[reparam] epoch={epoch_idx+1} step={step+1} " + " ".join(log_parts)
852
+ )
853
+ step += 1
854
+
855
+ merged = reparam_layer.materialize_into_layer_a()
856
+ final_lambdas = reparam_layer.gate_lambdas()
857
+ stats: Dict[str, object] = {
858
+ "enabled": True,
859
+ "epochs": total_epochs,
860
+ "lr": float(args.distill_lr),
861
+ "hidden_mse_weight": hidden_mse_weight,
862
+ "attn_mse_weight": attn_mse_weight,
863
+ "mlp_mse_weight": mlp_mse_weight,
864
+ "eta": eta,
865
+ "gamma": gamma,
866
+ "attn_reg_scale": attn_reg_scale,
867
+ "mlp_reg_scale": mlp_reg_scale,
868
+ "param_subset": param_subset,
869
+ "num_gates": len(final_lambdas),
870
+ "num_attn_gates": family_counts["attn"],
871
+ "num_mlp_gates": family_counts["mlp"],
872
+ "num_other_gates": family_counts["other"],
873
+ }
874
+ return merged, final_lambdas, stats
875
+
876
+
877
class LoRALinear(torch.nn.Module):
    """Wrap a ``torch.nn.Linear`` with a trainable low-rank additive branch.

    While the adapter is enabled and not yet merged, the forward pass returns
    ``base(x) + lora_B(lora_A(dropout(x))) * (alpha / rank)``.  ``lora_B``
    starts at zero, so a freshly wrapped layer reproduces the base layer.
    """

    def __init__(
        self,
        base: torch.nn.Linear,
        rank: int,
        alpha: float,
        dropout: float,
    ) -> None:
        """Build the adapter around ``base``.

        Args:
            base: Linear layer to wrap; it is stored and used as-is.
            rank: Inner width of the low-rank branch; must be positive.
            alpha: Numerator of the branch scaling ``alpha / rank``.
            dropout: Dropout probability on the branch input; values <= 0
                install an identity instead.

        Raises:
            ValueError: If ``rank`` is not positive.
        """
        super().__init__()
        if rank <= 0:
            raise ValueError("LoRA rank must be positive")

        self.base = base
        self.rank = int(rank)
        self.alpha = float(alpha)
        self.scaling = self.alpha / float(self.rank)
        self.enabled = True
        self.dropout = (
            torch.nn.Dropout(dropout) if dropout > 0 else torch.nn.Identity()
        )

        # Down-projection then up-projection; neither carries a bias.
        self.lora_A = torch.nn.Linear(base.in_features, self.rank, bias=False)
        self.lora_B = torch.nn.Linear(self.rank, base.out_features, bias=False)
        torch.nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_B.weight)

        # Keep the adapter on the same device and dtype as the wrapped weight.
        placement = {"device": base.weight.device, "dtype": base.weight.dtype}
        for projection in (self.lora_A, self.lora_B):
            projection.to(**placement)
        self.merged = False

    def lora_parameters(self) -> List[torch.nn.Parameter]:
        """Return the adapter parameters (``lora_A`` first, then ``lora_B``)."""
        return list(self.lora_A.parameters()) + list(self.lora_B.parameters())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the base layer, plus the scaled low-rank branch when active."""
        output = self.base(x)
        if self.enabled and not self.merged:
            low_rank = self.lora_B(self.lora_A(self.dropout(x)))
            output = output + low_rank * self.scaling
        return output

    def merge(self) -> None:
        """Fold the low-rank update into the base weight in place (idempotent)."""
        if not self.merged:
            update = self.lora_B.weight @ self.lora_A.weight
            update = update.to(dtype=self.base.weight.dtype) * self.scaling
            self.base.weight.data.add_(update)
            self.merged = True
924
+
925
+
926
+ def _get_child_module(parent: torch.nn.Module, part: str) -> torch.nn.Module:
927
+ if isinstance(parent, (torch.nn.ModuleList, torch.nn.Sequential)) and part.isdigit():
928
+ return parent[int(part)]
929
+ if isinstance(parent, torch.nn.ModuleDict):
930
+ return parent[part]
931
+ return getattr(parent, part)
932
+
933
+
934
+ def _set_child_module(parent: torch.nn.Module, part: str, module: torch.nn.Module) -> None:
935
+ if isinstance(parent, (torch.nn.ModuleList, torch.nn.Sequential)) and part.isdigit():
936
+ parent[int(part)] = module
937
+ return
938
+ if isinstance(parent, torch.nn.ModuleDict):
939
+ parent[part] = module
940
+ return
941
+ setattr(parent, part, module)
942
+
943
+
944
+ def _resolve_parent_module(
945
+ root: torch.nn.Module, module_name: str
946
+ ) -> Optional[tuple]:
947
+ if not module_name:
948
+ return None
949
+ parts = module_name.split(".")
950
+ parent = root
951
+ for part in parts[:-1]:
952
+ parent = _get_child_module(parent, part)
953
+ return parent, parts[-1]
954
+
955
+
956
+ def _resolve_module_by_path(root: torch.nn.Module, module_path: str) -> Optional[torch.nn.Module]:
957
+ if not module_path:
958
+ return None
959
+ parts = [part for part in module_path.split(".") if part]
960
+ node = root
961
+ for part in parts:
962
+ try:
963
+ node = _get_child_module(node, part)
964
+ except Exception:
965
+ return None
966
+ return node
967
+
968
+
969
def _resolve_layer_container_for_lora(
    model: torch.nn.Module, layer_path: Optional[str]
) -> Tuple[Optional[str], Optional[object]]:
    """Find the iterable container holding the model's transformer layers.

    An explicit ``layer_path`` (anything but empty or ``"none"``) is tried
    first; if it does not resolve to an iterable object, a fixed list of
    well-known paths is probed in order.  Returns ``(path, container)`` for
    the first hit, or ``(None, None)`` when nothing matches.
    """

    def _iterable_at(path: str) -> Optional[object]:
        # A hit must both resolve and be materializable as a list.
        found = _resolve_module_by_path(model, path)
        if found is None:
            return None
        try:
            list(found)
        except TypeError:
            return None
        return found

    has_explicit_path = (
        isinstance(layer_path, str)
        and layer_path
        and layer_path.lower() != "none"
    )
    if has_explicit_path:
        container = _iterable_at(layer_path)
        if container is not None:
            return layer_path, container

    fallback_paths = (
        "model.layers",  # LLaMA, Mistral, Qwen2, Gemma
        "model.decoder.layers",  # OPT
        "transformer.h",  # GPT-2, GPT-J, Bloom, Falcon
        "transformer.blocks",  # MPT
        "gpt_neox.layers",  # GPT-NeoX
        "layers",  # fallback
    )
    for candidate in fallback_paths:
        container = _iterable_at(candidate)
        if container is not None:
            return candidate, container

    return None, None
1005
+
1006
+
1007
+ def _parse_exclude_pairs_local(raw_values, num_pairs: int) -> Set[int]:
1008
+ if not raw_values or num_pairs <= 0:
1009
+ return set()
1010
+ exclude: Set[int] = set()
1011
+ for item in raw_values:
1012
+ if item is None:
1013
+ continue
1014
+ for part in str(item).split(","):
1015
+ part = part.strip()
1016
+ if not part:
1017
+ continue
1018
+ try:
1019
+ idx = int(part)
1020
+ except ValueError as exc:
1021
+ raise SystemExit("--exclude_pairs must contain integers.") from exc
1022
+ if idx < 0:
1023
+ idx = num_pairs + idx
1024
+ if 0 <= idx < num_pairs:
1025
+ exclude.add(idx)
1026
+ return exclude
1027
+
1028
+
1029
+ def _extract_layer_index_from_module_name(
1030
+ module_name: str, layer_path: str
1031
+ ) -> Optional[int]:
1032
+ if not layer_path:
1033
+ return None
1034
+ prefix = f"{layer_path}."
1035
+ if not module_name.startswith(prefix):
1036
+ return None
1037
+ rest = module_name[len(prefix) :]
1038
+ if not rest:
1039
+ return None
1040
+ idx_text = rest.split(".", 1)[0]
1041
+ if not idx_text.isdigit():
1042
+ return None
1043
+ return int(idx_text)
1044
+
1045
+
1046
+ def _select_linear_modules_for_lora_targets(
1047
+ model: torch.nn.Module,
1048
+ args: argparse.Namespace,
1049
+ *,
1050
+ log_tag: str,
1051
+ ) -> Tuple[List[Tuple[str, torch.nn.Linear]], Optional[Set[str]], Set[int], Optional[str]]:
1052
+ raw_targets = getattr(args, "lora_target_modules", None)
1053
+ target_modules: Optional[Set[str]] = None
1054
+ if raw_targets:
1055
+ target_modules = {str(item) for item in raw_targets if str(item)}
1056
+
1057
+ exclude_layer_indices: Set[int] = set()
1058
+ resolved_layer_path: Optional[str] = None
1059
+ if bool(getattr(args, "lora_respect_exclude_pairs", False)):
1060
+ requested_layer_path = getattr(args, "layer_path", None)
1061
+ resolved_layer_path, layer_container = _resolve_layer_container_for_lora(
1062
+ model, requested_layer_path
1063
+ )
1064
+ if isinstance(layer_container, (torch.nn.ModuleList, list, tuple)):
1065
+ num_pairs = max(len(layer_container) - 1, 0)
1066
+ exclude_pairs = _parse_exclude_pairs_local(
1067
+ getattr(args, "exclude_pairs", None), num_pairs
1068
+ )
1069
+ for pair_idx in exclude_pairs:
1070
+ exclude_layer_indices.add(pair_idx)
1071
+ exclude_layer_indices.add(pair_idx + 1)
1072
+ else:
1073
+ print(
1074
+ f"[{log_tag}] Warning: --lora_respect_exclude_pairs enabled, but "
1075
+ f"could not resolve layer path '{requested_layer_path}'."
1076
+ )
1077
+
1078
+ linear_modules = [
1079
+ (name, module)
1080
+ for name, module in model.named_modules()
1081
+ if isinstance(module, torch.nn.Linear)
1082
+ and (target_modules is None or name.split(".")[-1] in target_modules)
1083
+ and (
1084
+ not exclude_layer_indices
1085
+ or _extract_layer_index_from_module_name(name, resolved_layer_path or "")
1086
+ not in exclude_layer_indices
1087
+ )
1088
+ ]
1089
+ return linear_modules, target_modules, exclude_layer_indices, resolved_layer_path
1090
+
1091
+
1092
def apply_lora_adapters(
    model: torch.nn.Module, args: argparse.Namespace
) -> List[LoRALinear]:
    """Wrap the selected Linear modules of ``model`` in ``LoRALinear`` adapters.

    After wrapping, every model parameter is frozen except the adapter
    parameters, and a one-line summary is printed.

    Returns:
        The adapter modules that were installed, in selection order.

    Raises:
        SystemExit: If ``args.lora_rank`` is not positive or no Linear module
            matches the selection settings.
    """
    if args.lora_rank <= 0:
        raise SystemExit("--lora_rank must be > 0 when --lora_epochs > 0")

    selection = _select_linear_modules_for_lora_targets(model, args, log_tag="lora")
    linear_modules, target_modules, exclude_layer_indices, _ = selection
    if not linear_modules:
        raise SystemExit(
            "No Linear modules found for LoRA adapters "
            "(check --lora_target_modules / --exclude_pairs / --lora_respect_exclude_pairs)."
        )

    lora_modules: List[LoRALinear] = []
    for qualified_name, linear in linear_modules:
        location = _resolve_parent_module(model, qualified_name)
        if location is None:
            # Empty name: the model itself is the Linear, nothing to swap.
            continue
        holder, attr = location
        adapter = LoRALinear(
            base=linear,
            rank=args.lora_rank,
            alpha=args.lora_alpha,
            dropout=args.lora_dropout,
        )
        _set_child_module(holder, attr, adapter)
        lora_modules.append(adapter)

    # Freeze everything, then re-enable gradients for the adapters only.
    for weight in model.parameters():
        weight.requires_grad_(False)
    for adapter in lora_modules:
        for weight in adapter.lora_parameters():
            weight.requires_grad_(True)

    total_params = 0
    trainable_params = 0
    for weight in model.parameters():
        count = weight.numel()
        total_params += count
        if weight.requires_grad:
            trainable_params += count
    percent = 100.0 * trainable_params / max(total_params, 1)

    target_note = (
        f" target={sorted(target_modules)}" if target_modules is not None else ""
    )
    exclude_note = (
        f" excluded_layers={sorted(exclude_layer_indices)}"
        if exclude_layer_indices
        else ""
    )
    print(
        "[lora] Applied adapters to "
        f"{len(lora_modules)} linear modules "
        f"({trainable_params}/{total_params} trainable, {percent:.4f}%)."
        f"{target_note}{exclude_note}"
    )
    return lora_modules
1143
+
1144
+
1145
def merge_lora_adapters(model: torch.nn.Module) -> None:
    """Fold every ``LoRALinear`` in ``model`` into its base layer and unwrap it.

    The adapter list is collected up front so the module tree can be edited
    while it is processed.  Each adapter is merged and then replaced in its
    parent by the underlying base ``Linear``.
    """
    adapters = [
        (qualified_name, candidate)
        for qualified_name, candidate in model.named_modules()
        if isinstance(candidate, LoRALinear)
    ]
    for qualified_name, adapter in adapters:
        adapter.merge()
        location = _resolve_parent_module(model, qualified_name)
        if location is not None:
            holder, attr = location
            _set_child_module(holder, attr, adapter.base)
1158
+
1159
+
1160
def set_lora_enabled(lora_modules: List[LoRALinear], enabled: bool) -> None:
    """Set the ``enabled`` flag on every adapter in ``lora_modules``."""
    for adapter in lora_modules:
        setattr(adapter, "enabled", enabled)
1163
+
1164
+
1165
def lora_ce_finetune(
    model: torch.nn.Module,
    dataloader,
    eval_tokenizer,
    eval_datasets: List[str],
    eval_configs: List[Optional[str]],
    eval_history: List[Dict[str, object]],
    args: argparse.Namespace,
    eval_dataloaders: Optional[Dict[str, object]] = None,
    progressive_cycle: Optional[int] = None,
    progressive_total: Optional[int] = None,
) -> None:
    """Fine-tune temporary LoRA adapters with a next-token loss, then merge them.

    The model is wrapped via ``apply_lora_adapters``, trained for
    ``args.lora_epochs`` epochs (a fractional remainder becomes a truncated
    final pass over ``dataloader``), and finally unwrapped via
    ``merge_lora_adapters``.  When KL regularisation is enabled, each batch
    is also run with the adapters disabled and
    ``kl_weight * kl_temp**2 * KL(no_adapter || adapter)`` is added to the
    masked cross-entropy.

    Args:
        model: Causal LM called with ``input_ids``, ``attention_mask`` and
            ``use_cache=False``; its output must expose ``.logits``.
        dataloader: Yields batches where ``batch[0]`` holds the input ids and
            ``batch[1]`` the attention mask.
        eval_tokenizer, eval_datasets, eval_configs: Forwarded to
            ``ppl_eval.evaluate_ppl_datasets`` for periodic evaluation when
            ``eval_dataloaders`` is not given.
        eval_history: Receives one ``{"step", "ppl"}`` record per evaluation.
        args: Parsed CLI namespace providing the ``lora_*``/``eval_*`` options.
        eval_dataloaders: Optional prebuilt evaluation dataloaders.
        progressive_cycle, progressive_total: Only used in the progress-bar
            description.

    Raises:
        SystemExit: On invalid KL settings, a non-positive
            ``lora_grad_accum_steps``, or a fractional epoch count with a
            dataloader that has no length.
    """
    total_epochs = float(args.lora_epochs)
    if total_epochs <= 0:
        return

    use_kl = bool(getattr(args, "lora_kl_enabled", False))
    kl_weight = float(getattr(args, "lora_kl_weight", 0.0))
    kl_temp = float(getattr(args, "lora_kl_temp", 1.0))
    if use_kl:
        if kl_weight < 0.0:
            raise SystemExit("--lora_kl_weight must be >= 0")
        if kl_temp <= 0.0:
            raise SystemExit("--lora_kl_temp must be > 0")
        if kl_weight == 0.0:
            use_kl = False

    # Validate before touching the model: a zero value would otherwise raise
    # ZeroDivisionError at the modulo below, after the adapters are installed.
    grad_accum = int(args.lora_grad_accum_steps)
    if grad_accum <= 0:
        raise SystemExit("--lora_grad_accum_steps must be >= 1")

    lora_modules = apply_lora_adapters(model, args)
    if not lora_modules:
        return

    model.train()

    lora_params = []
    for module in lora_modules:
        lora_params.extend(module.lora_parameters())

    optimizer = torch.optim.AdamW(
        lora_params,
        lr=args.lora_lr,
        weight_decay=args.lora_weight_decay,
    )

    device_type = torch.device(args.device).type
    amp_dtype = None
    if args.dtype == "float16":
        amp_dtype = torch.float16
    elif args.dtype == "bfloat16":
        amp_dtype = torch.bfloat16
    use_amp = amp_dtype is not None and device_type == "cuda"
    # Loss scaling is only applied for float16 autocast.
    use_scaler = use_amp and amp_dtype == torch.float16
    scaler = torch.cuda.amp.GradScaler() if use_scaler else None

    full_epochs = int(total_epochs)
    fractional = total_epochs - full_epochs
    if fractional < 1e-8:
        fractional = 0.0

    # Each entry is (epoch index, batch cap); None means a full pass.
    epoch_plan = [(epoch_idx, None) for epoch_idx in range(full_epochs)]
    if fractional > 0:
        try:
            batches_per_epoch = len(dataloader)
        except TypeError as exc:
            raise SystemExit(
                "Fractional lora epochs require a dataloader with finite length."
            ) from exc
        if batches_per_epoch > 0:
            frac_batches = int(round(fractional * batches_per_epoch))
            if frac_batches <= 0:
                frac_batches = 1
            epoch_plan.append((full_epochs, frac_batches))

    step = 0
    for epoch_idx, max_batches in epoch_plan:
        if max_batches is None:
            epoch_iter = dataloader
        else:
            epoch_iter = itertools.islice(dataloader, max_batches)
        iterator = epoch_iter
        if tqdm is not None and _tqdm_enabled():
            if progressive_cycle is not None:
                if progressive_total is not None:
                    desc = (
                        f"LoRA (cycle {progressive_cycle}/{progressive_total}, "
                        f"epoch {epoch_idx+1})"
                    )
                else:
                    desc = f"LoRA (cycle {progressive_cycle}, epoch {epoch_idx+1})"
            else:
                desc = f"LoRA (epoch {epoch_idx+1})"
            iterator = tqdm(
                epoch_iter,
                desc=desc,
                unit="batch",
                total=max_batches,
            )
        for batch in iterator:
            input_ids = batch[0].to(args.device)
            attention_mask = batch[1].to(args.device)
            autocast_ctx = (
                torch.autocast(device_type=device_type, dtype=amp_dtype)
                if use_amp
                else nullcontext()
            )
            with autocast_ctx:
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    use_cache=False,
                )
                logits = outputs.logits
                # Position t predicts token t + 1; the mask is shifted to match.
                shift_logits = logits[:, :-1, :].contiguous()
                shift_labels = input_ids[:, 1:].contiguous()
                shift_mask = attention_mask[:, 1:].contiguous()
                ce_flat = F.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    reduction="none",
                )
                ce_denom = shift_mask.sum()
                if ce_denom.item() == 0:
                    # Nothing to predict in this batch; skip it entirely.
                    continue
                ce_loss = (
                    ce_flat * shift_mask.view(-1).to(ce_flat.dtype)
                ).sum() / ce_denom
                kl_loss = None
                if use_kl:
                    # Reference distribution: same model with adapters off.
                    # try/finally guarantees the adapters are switched back
                    # on even if the reference forward pass raises.
                    set_lora_enabled(lora_modules, False)
                    try:
                        with torch.no_grad():
                            base_outputs = model(
                                input_ids=input_ids,
                                attention_mask=attention_mask,
                                use_cache=False,
                            )
                    finally:
                        set_lora_enabled(lora_modules, True)
                    base_logits = base_outputs.logits
                    if base_logits.device != shift_logits.device:
                        base_logits = base_logits.to(shift_logits.device)
                    shift_base_logits = base_logits[:, :-1, :].contiguous()
                    log_p_pre = F.log_softmax(shift_base_logits / kl_temp, dim=-1)
                    log_p_post = F.log_softmax(shift_logits / kl_temp, dim=-1)
                    p_pre = log_p_pre.exp()
                    kl_flat = (p_pre * (log_p_pre - log_p_post)).sum(dim=-1)
                    kl_loss = (
                        kl_flat * shift_mask.to(kl_flat.dtype)
                    ).sum() / ce_denom

                total_loss = ce_loss
                if kl_loss is not None:
                    total_loss = total_loss + (kl_weight * (kl_temp ** 2) * kl_loss)

            if grad_accum > 1:
                total_loss = total_loss / grad_accum
            if use_scaler:
                scaler.scale(total_loss).backward()
            else:
                total_loss.backward()

            if (step + 1) % grad_accum == 0:
                if args.lora_max_grad_norm is not None:
                    if use_scaler:
                        # Clip on true (unscaled) gradient magnitudes.
                        scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        lora_params,
                        args.lora_max_grad_norm,
                    )
                if use_scaler:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad(set_to_none=True)

            if args.lora_eval_every and (step + 1) % args.lora_eval_every == 0:
                prev_mode = model.training
                model.eval()
                eval_device = args.eval_device or args.device
                if eval_dataloaders is not None:
                    results = ppl_eval.evaluate_ppl_dataloaders(
                        model,
                        eval_dataloaders,
                        eval_device,
                        max_batches=args.lora_eval_max_batches,
                    )
                else:
                    results = ppl_eval.evaluate_ppl_datasets(
                        model,
                        eval_tokenizer,
                        datasets=eval_datasets,
                        configs=eval_configs,
                        split=args.eval_split,
                        text_field=args.eval_text_field,
                        num_samples=args.eval_num_samples,
                        seq_len=args.eval_seq_len,
                        batch_size=args.eval_batch_size or args.batch_size,
                        device=eval_device,
                        seed=args.seed,
                        shuffle=False,
                        model_family=args.eval_model_family,
                        add_bos=args.eval_add_bos,
                        max_batches=args.lora_eval_max_batches,
                        cache_dir=args.eval_cache_dir,
                        num_workers=args.eval_num_workers,
                    )
                eval_history.append({"step": step + 1, "ppl": results})
                print(f"[lora] eval step={step+1}: {results}")
                if prev_mode:
                    model.train()

            if args.lora_log_steps and (
                step == 0 or (step + 1) % args.lora_log_steps == 0
            ):
                # NOTE: the logged loss is already divided by grad_accum.
                log_parts = [f"loss={total_loss.item():.6f}"]
                if kl_loss is not None:
                    log_parts.append(f"kl={kl_loss.item():.6f}")
                print(
                    f"[lora] epoch={epoch_idx+1} step={step+1} "
                    + " ".join(log_parts)
                )
            step += 1

    merge_lora_adapters(model)
1387
+
1388
+
1389
+ def _masked_kl(
1390
+ logits_p: torch.Tensor,
1391
+ logits_q: torch.Tensor,
1392
+ attention_mask: torch.Tensor,
1393
+ temp: float,
1394
+ detach_p: bool = True,
1395
+ ) -> Optional[torch.Tensor]:
1396
+ shift_mask = attention_mask[:, 1:].contiguous()
1397
+ denom = shift_mask.sum()
1398
+ if denom.item() == 0:
1399
+ return None
1400
+
1401
+ p = logits_p[:, :-1, :].contiguous()
1402
+ q = logits_q[:, :-1, :].contiguous()
1403
+ if p.device != q.device:
1404
+ p = p.to(q.device)
1405
+
1406
+ # Keep dtype to avoid blowing up memory on large vocab models.
1407
+ log_p = F.log_softmax(p / temp, dim=-1)
1408
+ log_q = F.log_softmax(q / temp, dim=-1)
1409
+ if detach_p:
1410
+ log_p = log_p.detach()
1411
+ p_probs = log_p.exp()
1412
+ kl_flat = (p_probs * (log_p - log_q)).sum(dim=-1)
1413
+ return (kl_flat * shift_mask.to(kl_flat.dtype)).sum() / denom
1414
+
1415
+
1416
+ def _extract_hidden_tensor(output: object) -> Optional[torch.Tensor]:
1417
+ if isinstance(output, torch.Tensor):
1418
+ return output
1419
+ if isinstance(output, (tuple, list)) and output:
1420
+ first = output[0]
1421
+ if isinstance(first, torch.Tensor):
1422
+ return first
1423
+ return None
1424
+
1425
+
1426
+ def _grad_l2_norm(grads: List[Optional[torch.Tensor]]) -> float:
1427
+ total = 0.0
1428
+ for grad in grads:
1429
+ if grad is None:
1430
+ continue
1431
+ total += float(grad.detach().float().pow(2).sum().item())
1432
+ if total <= 0.0:
1433
+ return 0.0
1434
+ return float(math.sqrt(total))
1435
+
1436
+
1437
+ def _register_forward_pre_hook_with_optional_kwargs(layer, hook):
1438
+ try:
1439
+ handle = layer.register_forward_pre_hook(hook, with_kwargs=True)
1440
+ return handle
1441
+ except TypeError:
1442
+ def wrapper(module, inputs):
1443
+ return hook(module, inputs, None)
1444
+
1445
+ return layer.register_forward_pre_hook(wrapper)
1446
+
1447
+
1448
def commutator_precondition(
    student_model: torch.nn.Module,
    student_layers: List[torch.nn.Module],
    teacher_model: torch.nn.Module,
    dataloader,
    dwce_scores: Optional[List[float]],
    args: argparse.Namespace,
    exclude_pairs: Optional[Set[int]] = None,
    progressive_cycle: Optional[int] = None,
    progressive_total: Optional[int] = None,
) -> Dict[str, object]:
    """Run commutator-style preconditioning before pair fusion.

    Objective on each sampled pair i:
        L = T^2 * KL(p_teacher || p_student) + mu * L_interaction(i)

    Interaction loss is computed locally on block (i+1):
        r1 = B_{i+1}(h_{i+1}) - h_{i+1}
        r0 = B_{i+1}(h_i) - h_i
        L_interaction = ||r1-r0||^2 (or relative form).

    Here ``h_i`` is the first positional input of layer ``i`` and
    ``h_{i+1}`` the first positional input of layer ``i+1``, both captured
    with forward hooks during the student forward pass.

    Args:
        student_model: Model being trained; called with ``input_ids``,
            ``attention_mask`` and ``use_cache=False`` and expected to
            return an object with ``.logits``.
        student_layers: Ordered layer modules of the student; pair ``i``
            refers to ``(student_layers[i], student_layers[i + 1])``.
        teacher_model: Reference model, put in eval mode and run under
            ``torch.no_grad()``.
        dataloader: Must support ``len()``; each batch is indexed as
            ``batch[0]`` (input ids) and ``batch[1]`` (attention mask).
        dwce_scores: Optional per-pair scores. When present and at least
            ``num_pairs`` long, finite-scored pairs are ranked ascending and
            the scores bias pair sampling.
        args: Namespace holding the ``comm_*`` and ``distill_*`` options
            read below; ``args.device``, ``args.dtype`` and
            ``args.distill_max_grad_norm`` are accessed without defaults.
        exclude_pairs: Pair indices that must not be selected.
        progressive_cycle: Optional cycle number; offsets the sampling seed
            and is shown in the progress-bar label.
        progressive_total: Optional total cycle count for the label.

    Returns:
        ``{"enabled": False, ...}`` (with a ``"reason"`` on most paths) when
        nothing was run, otherwise a statistics dict with ``"enabled": True``,
        the effective hyper-parameters, pair-sampling statistics under
        ``"pair_selection"`` and, if any micro-batch was counted, running
        averages (``avg_loss``, ``avg_anchor``, ``avg_interaction``,
        ``avg_mu``).

    Raises:
        SystemExit: On an out-of-range ``comm_*`` option or when the
            dataloader has no length.

    Side effects visible in this function: ``requires_grad`` is switched off
    for every student parameter and back on only for the selected trainable
    parameters (not restored afterwards), the student is left in train mode
    and the teacher in eval mode, and in ``"lora"`` mode
    ``apply_lora_adapters`` / ``merge_lora_adapters`` are invoked on the
    student (their exact effect is defined elsewhere).
    """
    if not bool(getattr(args, "comm_enabled", False)):
        return {"enabled": False}
    if not student_layers or len(student_layers) < 2:
        return {"enabled": False, "reason": "need_at_least_2_layers"}

    # ---- Read options (all optional on ``args``; defaults given here). ----
    temp = float(getattr(args, "comm_temp", 2.0))
    steps_ratio = float(getattr(args, "comm_steps_ratio", 0.1))
    lr_scale = float(getattr(args, "comm_lr_scale", 0.1))
    sample_eta = float(getattr(args, "comm_sample_eta", 0.5))
    sample_dwce_scale = float(getattr(args, "comm_sample_dwce_scale", 1.0))
    top_k = int(getattr(args, "comm_topk", 1))
    interaction_mode = str(getattr(args, "comm_interaction_mode", "relative")).strip().lower()
    interaction_eps = float(getattr(args, "comm_interaction_eps", 1e-8))
    mu_cfg = getattr(args, "comm_mu", None)
    mu_auto = bool(getattr(args, "comm_mu_auto", False))
    mu_auto_rho = float(getattr(args, "comm_mu_auto_rho", 0.1))
    mu_auto_eps = float(getattr(args, "comm_mu_auto_eps", 1e-8))
    comm_train_mode = str(getattr(args, "comm_train_mode", "lora")).strip().lower()
    log_steps = int(getattr(args, "comm_log_steps", 50))

    # ---- Validate options; invalid values abort the program. ----
    if temp <= 0.0:
        raise SystemExit("--comm_temp must be > 0")
    if steps_ratio < 0.0:
        raise SystemExit("--comm_steps_ratio must be >= 0")
    if lr_scale <= 0.0:
        raise SystemExit("--comm_lr_scale must be > 0")
    if not (0.0 <= sample_eta <= 1.0):
        raise SystemExit("--comm_sample_eta must be in [0, 1]")
    if top_k <= 0:
        raise SystemExit("--comm_topk must be >= 1")
    if interaction_mode not in {"mse", "relative"}:
        raise SystemExit("--comm_interaction_mode must be one of: mse, relative")
    if comm_train_mode not in {"lora", "full"}:
        raise SystemExit("--comm_train_mode must be one of: lora, full")
    if interaction_eps <= 0.0:
        raise SystemExit("--comm_interaction_eps must be > 0")
    if mu_auto_rho < 0.0:
        raise SystemExit("--comm_mu_auto_rho must be >= 0")
    if mu_auto_eps <= 0.0:
        raise SystemExit("--comm_mu_auto_eps must be > 0")

    # Default interaction weight depends on the interaction mode.
    if mu_cfg is None:
        base_mu = 0.5 if interaction_mode == "relative" else 0.1
    else:
        base_mu = float(mu_cfg)
    if base_mu < 0.0:
        raise SystemExit("--comm_mu must be >= 0")

    # Non-positive distillation settings silently fall back to 1.
    distill_epochs = float(getattr(args, "distill_epochs", 1.0))
    if distill_epochs <= 0.0:
        distill_epochs = 1.0
    grad_accum = int(getattr(args, "distill_grad_accum_steps", 1))
    if grad_accum <= 0:
        grad_accum = 1

    try:
        batches_per_epoch = len(dataloader)
    except TypeError as exc:
        raise SystemExit(
            "Commutator preconditioning requires a finite-length distillation dataloader."
        ) from exc
    if batches_per_epoch <= 0:
        return {"enabled": False, "reason": "empty_dataloader"}

    # ---- Step budget: a fraction (steps_ratio) of the optimizer steps that
    # ``distill_epochs`` epochs over this dataloader would take; at least 1. ----
    full_epochs = int(distill_epochs)
    fractional = distill_epochs - full_epochs
    if fractional < 1e-8:
        fractional = 0.0
    total_batches = full_epochs * batches_per_epoch
    if fractional > 0.0:
        frac_batches = int(round(fractional * batches_per_epoch))
        if frac_batches <= 0:
            frac_batches = 1
        total_batches += frac_batches

    distill_opt_steps = int(math.ceil(total_batches / float(grad_accum)))
    target_opt_steps = int(round(steps_ratio * distill_opt_steps))
    if target_opt_steps <= 0:
        target_opt_steps = 1

    # ---- Candidate pair selection. ----
    num_pairs = max(len(student_layers) - 1, 0)
    # Only in-range integer indices are honoured as exclusions.
    exclude_set = {
        int(idx)
        for idx in (exclude_pairs or set())
        if isinstance(idx, int) and 0 <= int(idx) < num_pairs
    }
    allowed_pairs = [i for i in range(num_pairs) if i not in exclude_set]
    if not allowed_pairs:
        return {"enabled": False, "reason": "all_pairs_excluded"}

    # Rank by ascending score when usable scores exist; otherwise keep index order.
    ranked_pairs = list(allowed_pairs)
    if dwce_scores is not None and len(dwce_scores) >= num_pairs:
        finite_pairs = []
        for idx in allowed_pairs:
            value = float(dwce_scores[idx])
            if math.isfinite(value):
                finite_pairs.append(idx)
        if finite_pairs:
            ranked_pairs = sorted(finite_pairs, key=lambda i: float(dwce_scores[i]))
        else:
            ranked_pairs = list(allowed_pairs)
    candidate_pairs = ranked_pairs[: min(top_k, len(ranked_pairs))]
    if not candidate_pairs:
        return {"enabled": False, "reason": "no_candidate_pairs"}

    # ---- Collect trainable parameters.
    # layer_trainable_params[k] holds the parameters of student_layers[k] that
    # may be updated; trainable_params is the de-duplicated list handed to the
    # optimizer. ----
    layer_trainable_params: List[List[torch.nn.Parameter]] = []
    trainable_params: List[torch.nn.Parameter] = []
    if comm_train_mode == "lora":
        # LoRA comm preconditioning: update LoRA adapters on receiver layer (i+1).
        lora_modules = apply_lora_adapters(student_model, args)
        if not lora_modules:
            return {"enabled": False, "reason": "no_lora_modules"}

        # The optimizer receives every adapter parameter returned above,
        # de-duplicated by object identity.
        trainable_seen: Set[int] = set()
        for module in lora_modules:
            for param in module.lora_parameters():
                pid = id(param)
                if pid in trainable_seen:
                    continue
                trainable_seen.add(pid)
                trainable_params.append(param)

        # Per-layer view: adapter parameters found inside each student layer.
        for layer in student_layers:
            seen: Set[int] = set()
            params: List[torch.nn.Parameter] = []
            for module in layer.modules():
                if not isinstance(module, LoRALinear):
                    continue
                for param in module.lora_parameters():
                    pid = id(param)
                    if pid in seen:
                        continue
                    seen.add(pid)
                    params.append(param)
            layer_trainable_params.append(params)
    else:
        # Full-weight comm preconditioning: update full receiver-layer weights.
        for layer in student_layers:
            seen: Set[int] = set()
            params: List[torch.nn.Parameter] = []
            for param in layer.parameters():
                if not isinstance(param, torch.nn.Parameter):
                    continue
                pid = id(param)
                if pid in seen:
                    continue
                seen.add(pid)
                params.append(param)
            layer_trainable_params.append(params)

    # Drop candidates whose receiver layer (i + 1) has nothing to train.
    candidate_pairs = [
        i
        for i in candidate_pairs
        if (i + 1) < len(layer_trainable_params) and layer_trainable_params[i + 1]
    ]
    if not candidate_pairs:
        if comm_train_mode == "lora":
            merge_lora_adapters(student_model)
        return {"enabled": False, "reason": "no_trainable_receiver_layers"}

    if comm_train_mode == "full":
        # In full mode only the receiver layers of the surviving candidates
        # are optimized.
        trainable_seen: Set[int] = set()
        for pair_idx in candidate_pairs:
            for param in layer_trainable_params[pair_idx + 1]:
                pid = id(param)
                if pid in trainable_seen:
                    continue
                trainable_seen.add(pid)
                trainable_params.append(param)
        if not trainable_params:
            return {"enabled": False, "reason": "no_trainable_receiver_layers"}

    # Freeze non-comm params to reduce grad memory.
    # NOTE(review): nesting of this freeze step was reconstructed from a
    # whitespace-collapsed listing; it is placed at function level here —
    # confirm against the original file.
    for param in student_model.parameters():
        param.requires_grad_(False)
    for param in trainable_params:
        param.requires_grad_(True)

    if not trainable_params:
        if comm_train_mode == "lora":
            merge_lora_adapters(student_model)
        return {"enabled": False, "reason": "no_trainable_params"}

    # ---- Sampling distribution over candidates: uniform, optionally mixed
    # (weight sample_eta) with softmax(-sample_dwce_scale * score). ----
    candidate_probs = torch.full(
        (len(candidate_pairs),),
        1.0 / float(len(candidate_pairs)),
        dtype=torch.float32,
    )
    if dwce_scores is not None and len(dwce_scores) >= num_pairs and sample_eta > 0.0:
        score_vec = torch.tensor(
            [float(dwce_scores[i]) for i in candidate_pairs], dtype=torch.float32
        )
        # Replace NaN/inf so the softmax below stays finite.
        score_vec = torch.nan_to_num(score_vec, nan=1e9, posinf=1e9, neginf=-1e9)
        biased = torch.softmax(-float(sample_dwce_scale) * score_vec, dim=0)
        candidate_probs = (1.0 - sample_eta) * candidate_probs + sample_eta * biased
        candidate_probs = candidate_probs / candidate_probs.sum()

    # Expected sampling probability indexed by pair id (0.0 for non-candidates).
    probs_by_pair = [0.0 for _ in range(num_pairs)]
    for pos, pair_idx in enumerate(candidate_pairs):
        probs_by_pair[pair_idx] = float(candidate_probs[pos].item())

    # ---- Optimizer and mixed-precision setup. ----
    lr = float(getattr(args, "distill_lr", 1e-4)) * lr_scale
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=lr,
        weight_decay=float(getattr(args, "distill_weight_decay", 0.0)),
    )

    device_type = torch.device(args.device).type
    amp_dtype = None
    if args.dtype == "float16":
        amp_dtype = torch.float16
    elif args.dtype == "bfloat16":
        amp_dtype = torch.bfloat16
    # Autocast only on CUDA with a half dtype; gradient scaling only for float16.
    use_amp = amp_dtype is not None and device_type == "cuda"
    use_scaler = use_amp and amp_dtype == torch.float16
    scaler = torch.cuda.amp.GradScaler() if use_scaler else None

    teacher_device = next(teacher_model.parameters()).device
    teacher_model.eval()
    student_model.train()

    # Dedicated CPU generator so pair sampling is reproducible; the seed is
    # offset per progressive cycle.
    gen = torch.Generator(device="cpu")
    seed = int(getattr(args, "seed", 0))
    if progressive_cycle is not None:
        seed += int(progressive_cycle) * 100003
    gen.manual_seed(seed)

    # Running sums for logging / returned averages (per counted micro-batch).
    opt_step = 0
    total_loss_sum = 0.0
    anchor_sum = 0.0
    interaction_sum = 0.0
    mu_sum = 0.0
    counted = 0
    pair_counts = [0 for _ in range(num_pairs)]

    desc = "Comm"
    if progressive_cycle is not None:
        if progressive_total is not None:
            desc = f"Comm (cycle {progressive_cycle}/{progressive_total})"
        else:
            desc = f"Comm (cycle {progressive_cycle})"
    iterator = range(target_opt_steps)
    if tqdm is not None and _tqdm_enabled():
        iterator = tqdm(iterator, desc=desc, unit="step")

    data_iter = iter(dataloader)
    # The same context object is entered several times per micro-batch below.
    autocast_ctx = (
        torch.autocast(device_type=device_type, dtype=amp_dtype)
        if use_amp
        else nullcontext()
    )

    # ---- Training loop: one optimizer step per outer iteration, each made of
    # grad_accum counted micro-batches. ----
    for _ in iterator:
        optimizer.zero_grad(set_to_none=True)
        accum_done = 0
        while accum_done < grad_accum:
            # Restart the dataloader when it is exhausted.
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(dataloader)
                batch = next(data_iter)

            input_ids = batch[0].to(args.device)
            attention_mask = batch[1].to(args.device)
            # Sample which pair this micro-batch trains.
            sampled_pos = int(torch.multinomial(candidate_probs, 1, generator=gen).item())
            pair_idx = int(candidate_pairs[sampled_pos])
            pair_counts[pair_idx] += 1

            receiver_params = layer_trainable_params[pair_idx + 1]
            receiver_param_ids = {id(param) for param in receiver_params}

            teacher_ids = input_ids.to(teacher_device)
            teacher_mask = attention_mask.to(teacher_device)
            with torch.no_grad(), autocast_ctx:
                teacher_outputs = teacher_model(
                    input_ids=teacher_ids,
                    attention_mask=teacher_mask,
                    use_cache=False,
                )
                teacher_logits = teacher_outputs.logits

            # Filled by the hooks below during the student forward pass:
            #   h_l        first positional input of layer pair_idx
            #   h_lp1      first positional input of layer pair_idx + 1
            #   y1         hidden tensor extracted from layer pair_idx + 1's output
            #   recv_args / recv_kwargs   call arguments of layer pair_idx + 1
            capture: Dict[str, object] = {
                "h_l": None,
                "h_lp1": None,
                "y1": None,
                "recv_args": None,
                "recv_kwargs": None,
            }

            def _hook_l(_module, inputs, _output):
                if inputs and isinstance(inputs[0], torch.Tensor):
                    capture["h_l"] = inputs[0]

            def _hook_recv_pre(_module, inputs, kwargs):
                capture["recv_args"] = inputs
                capture["recv_kwargs"] = kwargs

            def _hook_recv(_module, inputs, output):
                if inputs and isinstance(inputs[0], torch.Tensor):
                    capture["h_lp1"] = inputs[0]
                capture["y1"] = _extract_hidden_tensor(output)

            handles: List[object] = [
                student_layers[pair_idx].register_forward_hook(_hook_l),
                _register_forward_pre_hook_with_optional_kwargs(
                    student_layers[pair_idx + 1], _hook_recv_pre
                ),
                student_layers[pair_idx + 1].register_forward_hook(_hook_recv),
            ]
            try:
                with autocast_ctx:
                    student_outputs = student_model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        use_cache=False,
                    )
                    student_logits = student_outputs.logits
            finally:
                # Hooks are always detached; removal errors are ignored.
                for handle in handles:
                    try:
                        handle.remove()
                    except Exception:
                        pass

            # NOTE(review): the extent of this autocast block was reconstructed
            # from a whitespace-collapsed listing (it is assumed to cover the
            # anchor and interaction losses) — confirm against the original file.
            with autocast_ctx:
                anchor_kl = _masked_kl(
                    teacher_logits,
                    student_logits,
                    attention_mask,
                    temp=temp,
                    detach_p=True,
                )
                if anchor_kl is None:
                    # Micro-batch is skipped without counting towards
                    # grad_accum (its pair_counts entry was already bumped).
                    continue
                anchor_loss = (temp ** 2) * anchor_kl

                interaction_loss = None
                h_l = capture.get("h_l")
                h_lp1 = capture.get("h_lp1")
                y1 = capture.get("y1")
                recv_args = capture.get("recv_args")
                recv_kwargs = capture.get("recv_kwargs")
                # The interaction term is only built when every capture succeeded.
                if (
                    isinstance(h_l, torch.Tensor)
                    and isinstance(h_lp1, torch.Tensor)
                    and isinstance(y1, torch.Tensor)
                    and isinstance(recv_args, tuple)
                    and len(recv_args) > 0
                    and isinstance(recv_args[0], torch.Tensor)
                ):
                    # Re-run the receiver layer with the same arguments, but
                    # with its hidden input replaced by the (detached) input
                    # of the previous layer.
                    call_args = list(recv_args)
                    first_hidden = call_args[0]
                    h_l_detached = h_l.detach().to(
                        device=first_hidden.device,
                        dtype=first_hidden.dtype,
                    )
                    call_args[0] = h_l_detached
                    call_kwargs = dict(recv_kwargs) if isinstance(recv_kwargs, dict) else {}

                    y0_raw = student_layers[pair_idx + 1](*tuple(call_args), **call_kwargs)
                    y0 = _extract_hidden_tensor(y0_raw)
                    if isinstance(y0, torch.Tensor):
                        if y0.device != y1.device:
                            y0 = y0.to(y1.device)
                        h_lp1_detached = h_lp1.detach().to(device=y1.device, dtype=y1.dtype)
                        h_l_for_res = h_l.detach().to(device=y0.device, dtype=y0.dtype)
                        # Residual updates of the receiver layer on its real
                        # input (r1) and on the swapped-in input (r0).
                        r1 = y1 - h_lp1_detached
                        r0 = y0 - h_l_for_res
                        mask = attention_mask.to(dtype=r1.dtype)
                        mask_sum = mask.sum()
                        if mask_sum.item() > 0:
                            if interaction_mode == "relative":
                                # Per-position ||r1-r0||^2 / (||r1||^2 + eps),
                                # averaged over unmasked positions.
                                num = (r1 - r0).float().pow(2).sum(dim=-1)
                                den = r1.float().pow(2).sum(dim=-1) + float(interaction_eps)
                                ratio = (num / den) * mask.to(num.dtype)
                                interaction_loss = ratio.sum() / (mask_sum + 1e-8)
                            else:
                                # Masked mean squared difference over all features.
                                denom = mask_sum * r1.size(-1)
                                if denom.item() > 0:
                                    interaction_loss = (
                                        (r1 - r0).pow(2) * mask.unsqueeze(-1)
                                    ).sum() / denom

            # Optional automatic weighting: mu = rho * ||grad anchor|| /
            # (||grad interaction|| + eps), measured on the receiver parameters;
            # falls back to base_mu when the result is unusable.
            mu_effective = float(base_mu)
            if (
                mu_auto
                and interaction_loss is not None
                and receiver_params
                and mu_auto_rho > 0.0
            ):
                anchor_grads = torch.autograd.grad(
                    anchor_loss,
                    receiver_params,
                    retain_graph=True,
                    allow_unused=True,
                )
                interaction_grads = torch.autograd.grad(
                    interaction_loss,
                    receiver_params,
                    retain_graph=True,
                    allow_unused=True,
                )
                anchor_norm = _grad_l2_norm(list(anchor_grads))
                interaction_norm = _grad_l2_norm(list(interaction_grads))
                if interaction_norm > 0.0:
                    mu_effective = float(
                        mu_auto_rho
                        * (anchor_norm / (interaction_norm + float(mu_auto_eps)))
                    )
                else:
                    mu_effective = float(base_mu)
                if not math.isfinite(mu_effective):
                    mu_effective = float(base_mu)

            total_loss = anchor_loss
            if interaction_loss is not None:
                total_loss = total_loss + (float(mu_effective) * interaction_loss)

            if grad_accum > 1:
                total_loss = total_loss / float(grad_accum)

            if use_scaler:
                scaler.scale(total_loss).backward()
            else:
                total_loss.backward()

            # Only the sampled receiver layer updates on this micro-batch.
            for param in trainable_params:
                if id(param) in receiver_param_ids:
                    continue
                if param.grad is not None:
                    if comm_train_mode == "lora":
                        param.grad.zero_()
                    else:
                        param.grad = None

            # Note: total_loss is accumulated after division by grad_accum.
            total_loss_sum += float(total_loss.detach().float().item())
            anchor_sum += float(anchor_loss.detach().float().item())
            if interaction_loss is not None:
                interaction_sum += float(interaction_loss.detach().float().item())
            mu_sum += float(mu_effective)
            counted += 1
            accum_done += 1

        if args.distill_max_grad_norm is not None:
            # Gradients must be unscaled before clipping when a scaler is used.
            if use_scaler:
                scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(
                trainable_params,
                float(args.distill_max_grad_norm),
            )

        if use_scaler:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()

        opt_step += 1
        # Log on the first step and then every log_steps steps (0 disables).
        if log_steps and (opt_step == 1 or opt_step % log_steps == 0):
            denom = max(counted, 1)
            print(
                f"[comm] step={opt_step}/{target_opt_steps} "
                f"loss={total_loss_sum/denom:.6f} "
                f"anchor={anchor_sum/denom:.6f} "
                f"int={interaction_sum/denom:.6f} "
                f"mu={mu_sum/denom:.6f}"
            )

    if comm_train_mode == "lora":
        merge_lora_adapters(student_model)

    # ---- Assemble the returned statistics. ----
    stats: Dict[str, object] = {
        "enabled": True,
        "train_mode": comm_train_mode,
        "opt_steps": int(target_opt_steps),
        "grad_accum_steps": int(grad_accum),
        "lr": float(lr),
        "temp": float(temp),
        "steps_ratio": float(steps_ratio),
        "lr_scale": float(lr_scale),
        "interaction_mode": interaction_mode,
        "interaction_eps": float(interaction_eps),
        "mu": float(base_mu),
        "mu_auto": bool(mu_auto),
        "mu_auto_rho": float(mu_auto_rho),
        "mu_auto_eps": float(mu_auto_eps),
        "sample_eta": float(sample_eta),
        "sample_dwce_scale": float(sample_dwce_scale),
        "topk": int(top_k),
        "candidate_pairs": [int(i) for i in candidate_pairs],
        "trainable_params": int(sum(int(param.numel()) for param in trainable_params)),
    }
    total_samples = int(sum(pair_counts))
    probs_list = [float(x) for x in probs_by_pair]
    # Observed sampling frequency per pair.
    freqs = (
        [float(c) / float(total_samples) for c in pair_counts]
        if total_samples > 0
        else [0.0 for _ in pair_counts]
    )
    # Up to ten most frequently sampled pairs (only those actually sampled).
    top_show = min(10, num_pairs)
    top_indices = sorted(range(num_pairs), key=lambda i: pair_counts[i], reverse=True)[:top_show]
    top_pairs = [
        {
            "pair": int(i),
            "count": int(pair_counts[i]),
            "freq": float(freqs[i]),
            "prob": float(probs_list[i]) if i < len(probs_list) else None,
        }
        for i in top_indices
        if pair_counts[i] > 0
    ]
    stats["pair_selection"] = {
        "num_pairs": int(num_pairs),
        "excluded_pairs": sorted(exclude_set),
        "candidate_pairs": [int(i) for i in candidate_pairs],
        "total_samples": total_samples,
        "unique_pairs": int(sum(1 for c in pair_counts if c > 0)),
        "counts": [int(c) for c in pair_counts],
        "freqs": freqs,
        "probs": probs_list,
        "top_pairs": top_pairs,
    }

    if total_samples > 0 and top_pairs:
        top_str = ", ".join(
            f"{entry['pair']}-{entry['pair'] + 1}: {entry['count']} "
            f"(obs={entry['freq']:.3f}, exp={entry['prob']:.3f})"
            for entry in top_pairs
            if entry.get("prob") is not None
        )
        # Fallback format when no entry carries an expected probability.
        if not top_str:
            top_str = ", ".join(
                f"{entry['pair']}-{entry['pair'] + 1}: {entry['count']} "
                f"(obs={entry['freq']:.3f})"
                for entry in top_pairs
            )
        print(
            f"[comm] Pair sampling stats: total={total_samples} "
            f"unique={stats['pair_selection']['unique_pairs']}/{num_pairs} "
            f"top={top_str}"
        )

    if counted > 0:
        stats["avg_loss"] = float(total_loss_sum / float(counted))
        stats["avg_anchor"] = float(anchor_sum / float(counted))
        stats["avg_interaction"] = float(interaction_sum / float(counted))
        stats["avg_mu"] = float(mu_sum / float(counted))
    return stats
src/fuse_layers_model.py ADDED
@@ -0,0 +1,595 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Model and layer helpers for fuse_layers."""
3
+
4
+ import os
5
+ from typing import Dict, List, Optional, Tuple
6
+
7
+ import torch
8
+
9
+ try:
10
+ from tqdm import tqdm
11
+ except Exception: # pragma: no cover - optional dependency
12
+ tqdm = None
13
+
14
+
15
def _tqdm_enabled() -> bool:
    """Report whether progress bars are allowed by the environment.

    ``DISABLE_TQDM`` is consulted first and ``TQDM_DISABLE`` is used only
    when it is absent. Bars are disabled when the chosen value, stripped and
    lower-cased, is one of ``1``, ``true``, ``yes`` or ``on``.
    """
    fallback = os.environ.get("TQDM_DISABLE", "0")
    flag = os.environ.get("DISABLE_TQDM", fallback).strip().lower()
    disabled = flag in {"1", "true", "yes", "on"}
    return not disabled
18
+
19
+
20
def get_dtype(dtype: str):
    """Translate a dtype option string into a ``torch`` dtype.

    ``"auto"`` maps to ``None``; ``"float16"`` and ``"bfloat16"`` map to the
    matching torch dtypes; every other value maps to ``torch.float32``.
    """
    named = (
        ("auto", None),
        ("float16", torch.float16),
        ("bfloat16", torch.bfloat16),
    )
    for label, resolved in named:
        if dtype == label:
            return resolved
    return torch.float32
28
+
29
+
30
def resolve_attr(root: object, path: str) -> Optional[object]:
    """Follow a dotted attribute ``path`` starting at ``root``.

    Returns the object reached at the end of the path, or ``None`` as soon
    as any segment is not an attribute of the object reached so far.
    """
    node = root
    segments = iter(path.split("."))
    while True:
        try:
            key = next(segments)
        except StopIteration:
            return node
        if hasattr(node, key):
            node = getattr(node, key)
        else:
            return None
37
+
38
+
39
def resolve_attr_with_parent(root: object, path: str) -> Tuple[object, str, object]:
    """Resolve a dotted ``path`` and also return the owner of its last segment.

    Returns ``(parent, name, value)`` where ``name`` is the final path
    segment, ``parent`` the object it was looked up on and ``value`` the
    attribute itself.

    Raises:
        ValueError: If any segment of the path is missing.
    """
    *ancestors, leaf = path.split(".")
    owner = root
    for step in ancestors:
        if hasattr(owner, step):
            owner = getattr(owner, step)
            continue
        raise ValueError(f"'{path}' not found on model")
    if hasattr(owner, leaf):
        return owner, leaf, getattr(owner, leaf)
    raise ValueError(f"'{path}' not found on model")
50
+
51
+
52
def find_layer_container(model, layer_path: Optional[str]) -> Tuple[object, str, object]:
    """Locate the attribute on ``model`` that holds its transformer layers.

    When ``layer_path`` is given (non-empty) it is resolved directly.
    Otherwise a fixed list of well-known dotted paths is probed in order and
    the first one that exists and can be turned into a list is used.

    Returns:
        ``(parent, name, container)`` as produced by
        ``resolve_attr_with_parent``.

    Raises:
        ValueError: If an explicit ``layer_path`` does not resolve, or no
            known path matches.
    """
    if layer_path:
        owner, attr_name, layers = resolve_attr_with_parent(model, layer_path)
        return owner, attr_name, layers

    known_layouts = (
        "model.layers",  # LLaMA, Mistral, Qwen2, Gemma
        "model.decoder.layers",  # OPT
        "transformer.h",  # GPT-2, GPT-J, Bloom, Falcon
        "transformer.blocks",  # MPT
        "gpt_neox.layers",  # GPT-NeoX
        "layers",  # fallback
    )
    for dotted in known_layouts:
        found = resolve_attr(model, dotted)
        if found is None:
            continue
        # Only accept attributes that can actually be iterated as layers.
        try:
            list(found)
        except TypeError:
            continue
        owner, attr_name, layers = resolve_attr_with_parent(model, dotted)
        return owner, attr_name, layers

    raise ValueError(
        "Could not locate transformer layers. Pass --layer_path explicitly."
    )
79
+
80
+
81
def find_attention_module(layer: torch.nn.Module) -> torch.nn.Module:
    """Return the attention sub-module of a transformer layer.

    The attributes ``self_attn``, ``attn`` and ``attention`` are tried in
    that order. Failing that, the first module yielded by
    ``layer.named_modules()`` that has all of ``q_proj``, ``k_proj``,
    ``v_proj`` and ``o_proj`` is returned.

    Raises:
        ValueError: If no candidate is found.
    """
    for attr in ("self_attn", "attn", "attention"):
        if hasattr(layer, attr):
            return getattr(layer, attr)

    projections = ("q_proj", "k_proj", "v_proj", "o_proj")
    for _, candidate in layer.named_modules():
        missing = [name for name in projections if not hasattr(candidate, name)]
        if not missing:
            return candidate
    raise ValueError("Could not find attention module with q_proj/k_proj/v_proj/o_proj")
94
+
95
+
96
def find_mlp_module(layer: torch.nn.Module) -> torch.nn.Module:
    """Return the feed-forward (MLP) sub-module of a transformer layer.

    Direct attributes are tried first, in the order ``mlp``,
    ``feed_forward``, ``feedforward``, ``ffn``, ``ff``. Failing that, the
    first module yielded by ``layer.named_modules()`` that exposes one of the
    known projection-name sets is returned.

    Raises:
        ValueError: If no candidate is found.
    """
    for attr in ("mlp", "feed_forward", "feedforward", "ffn", "ff"):
        if hasattr(layer, attr):
            return getattr(layer, attr)

    # Attribute sets that identify an MLP block in different architectures.
    signatures = (
        ("gate_proj", "up_proj", "down_proj"),
        ("fc1", "fc2"),
        ("dense_h_to_4h", "dense_4h_to_h"),
        ("w1", "w2"),
    )
    for _, candidate in layer.named_modules():
        if any(
            all(hasattr(candidate, name) for name in signature)
            for signature in signatures
        ):
            return candidate
    raise ValueError("Could not find MLP/FFN module on layer")
115
+
116
+
117
def get_head_info(
    attn: torch.nn.Module, hidden_size: int, config
) -> Tuple[int, int, int]:
    """Resolve ``(num_heads, num_key_value_heads, head_dim)`` for an attention module.

    Each quantity is looked up on ``attn`` first and then on ``config``
    (which may be ``None``), skipping attributes that are missing or set to
    ``None``. Whatever is still unknown is inferred:

    * ``num_heads`` from the ``q_proj`` output width and ``head_dim`` (the
      latter derived from ``k_proj`` and the KV-head count if necessary);
    * ``head_dim`` as ``hidden_size // num_heads``;
    * ``num_key_value_heads`` from the ``k_proj`` output width when that is
      a whole number of heads that divides ``num_heads`` evenly, otherwise
      it defaults to ``num_heads`` (plain multi-head attention).

    Args:
        attn: Attention module; may expose head-count attributes and
            ``q_proj`` / ``k_proj`` layers with a ``weight`` tensor.
        hidden_size: Model hidden width, used only to derive ``head_dim``.
        config: Optional model config object consulted as a fallback.

    Returns:
        ``(num_heads, num_key_value_heads, head_dim)`` as ints.

    Raises:
        ValueError: If the number of attention heads cannot be determined.
    """

    def _first_set(obj, names):
        # First attribute among ``names`` that exists on ``obj`` and is not None.
        if obj is None:
            return None
        for attr_name in names:
            value = getattr(obj, attr_name, None)
            if value is not None:
                return value
        return None

    def _out_features(proj_name):
        # Output width of a projection layer, or None when it is absent.
        proj = getattr(attn, proj_name, None)
        if proj is None:
            return None
        return proj.weight.shape[0]

    num_heads = _first_set(attn, ("num_heads", "num_attention_heads"))
    if num_heads is None:
        num_heads = _first_set(config, ("num_attention_heads", "num_heads", "n_head"))

    num_key_value_heads = _first_set(attn, ("num_key_value_heads", "num_kv_heads"))
    if num_key_value_heads is None:
        num_key_value_heads = _first_set(
            config, ("num_key_value_heads", "num_kv_heads", "n_head_kv")
        )

    head_dim = _first_set(attn, ("head_dim",))
    if head_dim is None:
        head_dim = _first_set(config, ("head_dim",))

    if num_heads is None:
        q_out = _out_features("q_proj")
        if q_out is not None:
            if head_dim is None and num_key_value_heads is not None:
                k_out = _out_features("k_proj")
                if k_out is not None:
                    head_dim = k_out // max(int(num_key_value_heads), 1)
            if head_dim is not None:
                num_heads = q_out // head_dim
        if num_heads is None:
            raise ValueError(
                "Attention module missing num_heads/num_attention_heads; "
                "pass --layer_path or add config overrides."
            )

    if head_dim is None:
        head_dim = hidden_size // int(num_heads)

    if num_key_value_heads is None:
        # Infer grouped-query layouts from k_proj *before* defaulting to
        # num_heads; previously the default was applied first, which made
        # this inference unreachable.
        inferred = None
        k_out = _out_features("k_proj")
        width = int(head_dim)
        if k_out is not None and width > 0 and k_out % width == 0:
            groups = k_out // width
            if groups > 0 and int(num_heads) % groups == 0:
                inferred = groups
        num_key_value_heads = inferred if inferred is not None else num_heads

    return int(num_heads), int(num_key_value_heads), int(head_dim)
170
+
171
+
172
def cosine_cost_matrix(
    a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
    """Return the pairwise cosine distance ``1 - cos(a_i, b_j)`` between rows.

    Rows are scaled by ``1 / (norm + eps)`` before the dot product, so an
    all-zero row has similarity 0 (cost 1) to everything instead of
    producing NaN. Both inputs are indexed as 2-D ``(rows, features)``
    tensors; the result has one row per row of ``a`` and one column per row
    of ``b``.
    """

    def _unit_rows(matrix: torch.Tensor) -> torch.Tensor:
        lengths = matrix.norm(dim=1, keepdim=True) + eps
        return matrix / lengths

    similarity = torch.matmul(_unit_rows(a), _unit_rows(b).t())
    return 1.0 - similarity
179
+
180
+
181
def hungarian(cost: torch.Tensor) -> List[int]:
    """Solve the square assignment problem (minimisation) with Kuhn-Munkres.

    Args:
        cost: 2-D tensor whose leading ``n x n`` block (``n = cost.size(0)``)
            holds the cost of assigning row ``i`` to column ``j``.

    Returns:
        A list ``assignment`` of length ``n`` where ``assignment[i]`` is the
        column matched to row ``i``.

    Raises:
        ValueError: If the search cannot make progress because the remaining
            costs are NaN or infinite (previously this looped forever).
    """
    n = cost.size(0)
    # Pull the matrix to Python floats once instead of calling ``.item()``
    # (and possibly synchronising with the device) inside the O(n^3) loop.
    rows = cost.detach().cpu().tolist()
    inf = float("inf")

    # 1-based potentials and matching, index 0 is a sentinel column.
    u = [0.0] * (n + 1)
    v = [0.0] * (n + 1)
    p = [0] * (n + 1)  # p[j]: row currently matched to column j (0 = none)
    way = [0] * (n + 1)  # way[j]: previous column on the augmenting path

    for i in range(1, n + 1):
        p[0] = i
        j0 = 0
        minv = [inf] * (n + 1)
        used = [False] * (n + 1)
        while True:
            used[j0] = True
            i0 = p[j0]
            row = rows[i0 - 1]
            delta = inf
            j1 = 0
            for j in range(1, n + 1):
                if used[j]:
                    continue
                cur = row[j - 1] - u[i0] - v[j]
                if cur < minv[j]:
                    minv[j] = cur
                    way[j] = j0
                if minv[j] < delta:
                    delta = minv[j]
                    j1 = j
            if j1 == 0:
                # With finite costs an unused column is always selectable;
                # reaching this point means NaN/inf entries blocked the search.
                raise ValueError(
                    "hungarian: cost matrix contains non-finite entries"
                )
            for j in range(0, n + 1):
                if used[j]:
                    u[p[j]] += delta
                    v[j] -= delta
                else:
                    minv[j] -= delta
            j0 = j1
            if p[j0] == 0:
                break
        # Flip the matching along the augmenting path back to the sentinel.
        while True:
            j1 = way[j0]
            p[j0] = p[j1]
            j0 = j1
            if j0 == 0:
                break

    assignment = [-1] * n
    for j in range(1, n + 1):
        if p[j] > 0:
            assignment[p[j] - 1] = j - 1
    return assignment
230
+
231
+
232
def compute_head_means(
    model,
    attn_i: torch.nn.Module,
    attn_j: torch.nn.Module,
    dataloader,
    device: str,
    hidden_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, int, int, int]:
    """Average the per-head inputs of ``o_proj`` for two attention modules.

    Forward hooks on ``attn_i.o_proj`` and ``attn_j.o_proj`` accumulate the
    sum of each head's slice of the 3-D hook input over all positions of all
    batches; the model is run in eval mode under ``torch.no_grad()`` using
    ``batch[0]`` as ``input_ids``.

    Args:
        model: Model containing both attention modules; ``model.config`` is
            passed to ``get_head_info``.
        attn_i: First attention module (must have ``o_proj``).
        attn_j: Second attention module (must have ``o_proj``).
        dataloader: Iterable of batches whose first element is the input ids.
        device: Device the input ids are moved to.
        hidden_size: Hidden width forwarded to ``get_head_info``.

    Returns:
        ``(mean_i, mean_j, num_heads, num_kv_heads, head_dim)`` where the
        means are float32 CPU tensors of shape ``(num_heads, head_dim)`` and
        the integers describe ``attn_i``.

    Raises:
        ValueError: If head count or head dimension differ between the modules.
        RuntimeError: If either hook never captured a usable input.
    """
    num_heads_i, num_kv_i, head_dim_i = get_head_info(attn_i, hidden_size, model.config)
    num_heads_j, num_kv_j, head_dim_j = get_head_info(attn_j, hidden_size, model.config)
    if num_heads_i != num_heads_j or head_dim_i != head_dim_j:
        raise ValueError("Head counts or head_dim differ between layers; cannot align")

    sums_i = torch.zeros(num_heads_i, head_dim_i, device="cpu")
    sums_j = torch.zeros(num_heads_j, head_dim_j, device="cpu")
    # One-element lists so the hook closures can update the counters in place.
    count_i = [0]
    count_j = [0]

    def make_hook(
        sums: torch.Tensor, count_ref: List[int], num_heads: int, head_dim: int
    ):
        def hook(_module, inputs, _output):
            # Nothing positional to inspect (e.g. keyword-only call): skip.
            if not inputs:
                return
            hidden = inputs[0].detach()
            if hidden.dim() != 3:
                return
            batch, seq, width = hidden.shape
            if width != num_heads * head_dim:
                return
            # reshape (not view) tolerates non-contiguous activations.
            reshaped = hidden.reshape(batch, seq, num_heads, head_dim)
            # Reduce in float32 so half-precision activations cannot overflow
            # or lose precision while summing over batch * seq positions.
            sums.add_(reshaped.sum(dim=(0, 1), dtype=torch.float32).cpu())
            count_ref[0] += batch * seq

        return hook

    hook_i = attn_i.o_proj.register_forward_hook(
        make_hook(sums_i, count_i, num_heads_i, head_dim_i)
    )
    hook_j = attn_j.o_proj.register_forward_hook(
        make_hook(sums_j, count_j, num_heads_j, head_dim_j)
    )

    try:
        model.eval()
        iterator = dataloader
        if tqdm is not None and _tqdm_enabled():
            iterator = tqdm(dataloader, desc="Head stats", unit="batch")
        with torch.no_grad():
            for batch in iterator:
                input_ids = batch[0].to(device)
                _ = model(input_ids=input_ids)
    finally:
        # Always detach the hooks, even if a forward pass raised.
        hook_i.remove()
        hook_j.remove()

    if count_i[0] == 0 or count_j[0] == 0:
        raise RuntimeError("Failed to capture head outputs; check attention modules.")

    mean_i = sums_i / count_i[0]
    mean_j = sums_j / count_j[0]
    return mean_i, mean_j, num_heads_i, num_kv_i, head_dim_i
291
+
292
+
293
def build_head_permutation(
    mean_i: torch.Tensor,
    mean_j: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    eps: float,
) -> List[int]:
    """Match heads of two layers group by group and return the permutation.

    Heads are split into ``num_kv_heads`` consecutive groups of equal size.
    Within each group the rows of ``mean_i`` are matched to the rows of
    ``mean_j`` by solving an assignment problem on their cosine cost matrix,
    so heads are only ever permuted inside their own group.

    Returns:
        A list ``perm`` of length ``num_heads`` where ``perm[h]`` is the head
        index matched to head ``h``.

    Raises:
        ValueError: If ``num_heads`` is not a multiple of ``num_kv_heads``.
    """
    heads_per_group, leftover = divmod(num_heads, num_kv_heads)
    if leftover != 0:
        raise ValueError("num_heads must be divisible by num_key_value_heads")

    perm = list(range(num_heads))
    for group in range(num_kv_heads):
        lo = group * heads_per_group
        hi = lo + heads_per_group
        group_cost = cosine_cost_matrix(mean_i[lo:hi], mean_j[lo:hi], eps=eps)
        for row, col in enumerate(hungarian(group_cost)):
            perm[lo + row] = lo + col
    return perm
313
+
314
+
315
def permute_attention_heads(
    attn: torch.nn.Module,
    perm: List[int],
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
) -> None:
    """Reorder the attention heads of ``attn`` in place according to ``perm``.

    ``q_proj`` rows (and bias, if any) are regrouped head by head and
    ``o_proj`` input columns are regrouped the same way. ``k_proj`` and
    ``v_proj`` are permuted only when ``num_kv_heads == num_heads``. All
    writes happen under ``torch.no_grad()``.

    Raises:
        ValueError: If a projection's width does not equal
            ``num_heads * head_dim``.
    """
    width = num_heads * head_dim

    def _rows_by_head(weight: torch.Tensor) -> torch.Tensor:
        # Reorder output rows in blocks of ``head_dim``.
        n_out, n_in = weight.shape
        if n_out != width:
            raise ValueError(
                "proj out_features ({}) != num_heads*head_dim ({})".format(
                    n_out, width
                )
            )
        return weight.view(num_heads, head_dim, n_in)[perm, :, :].reshape(n_out, n_in)

    def _cols_by_head(weight: torch.Tensor) -> torch.Tensor:
        # Reorder input columns in blocks of ``head_dim``.
        n_out, n_in = weight.shape
        if n_in != width:
            raise ValueError(
                "o_proj in_features ({} ) != num_heads*head_dim ({})".format(
                    n_in, width
                )
            )
        return weight.view(n_out, num_heads, head_dim)[:, perm, :].reshape(n_out, n_in)

    def _bias_by_head(bias: torch.Tensor) -> torch.Tensor:
        return bias.view(num_heads, head_dim)[perm, :].reshape(num_heads * head_dim)

    # Key/value projections are touched only when every query head has its
    # own KV head.
    row_projections = ["q_proj"]
    if num_kv_heads == num_heads:
        row_projections.extend(["k_proj", "v_proj"])

    with torch.no_grad():
        for proj_name in row_projections:
            proj = getattr(attn, proj_name)
            proj.weight.copy_(_rows_by_head(proj.weight))
            if proj.bias is not None:
                proj.bias.copy_(_bias_by_head(proj.bias))

        attn.o_proj.weight.copy_(_cols_by_head(attn.o_proj.weight))
370
+
371
def compute_fisher(
    model,
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    dataloader,
    fisher_mode: str,
    device: str,
) -> Tuple[List[Dict[str, object]], int, List[Dict[str, int]]]:
    """Accumulate squared loss gradients for the parameters of two layers.

    All model parameters are frozen except those of ``layer_a`` and
    ``layer_b``; this ``requires_grad`` state is left in place on return.
    For every batch the model is run with ``labels=input_ids`` and the squared
    gradients of the two layers are summed.

    Args:
        model: Callable as ``model(input_ids=..., labels=...)`` and returning
            an object with a ``loss`` attribute.
        layer_a: First layer whose gradients are accumulated.
        layer_b: Second layer whose gradients are accumulated.
        dataloader: Iterable of batches; ``batch[0]`` holds the input ids.
        fisher_mode: ``"param"`` keeps one float32 CPU tensor per parameter;
            any other value keeps one Python float per parameter.
        device: Device the input ids are moved to.

    Returns:
        ``(fisher_sums, num_batches, param_numels)`` where the first and last
        items are two-element lists (layer_a, layer_b) keyed by parameter name.
        The sums are not divided by ``num_batches``.

    Raises:
        RuntimeError: If the dataloader yields no batches.
    """
    for param in model.parameters():
        param.requires_grad_(False)
    for layer in (layer_a, layer_b):
        for param in layer.parameters():
            param.requires_grad_(True)

    fisher_sums: List[Dict[str, object]] = []
    param_numels: List[Dict[str, int]] = []
    for layer in (layer_a, layer_b):
        layer_sums: Dict[str, object] = {}
        layer_numels: Dict[str, int] = {}
        for name, param in layer.named_parameters():
            if not param.requires_grad:
                continue
            if fisher_mode == "param":
                # Kept on CPU so the accumulators do not occupy device memory.
                layer_sums[name] = torch.zeros_like(
                    param, dtype=torch.float32, device="cpu"
                )
            else:
                layer_sums[name] = 0.0
            layer_numels[name] = param.numel()
        fisher_sums.append(layer_sums)
        param_numels.append(layer_numels)

    num_batches = 0
    model.eval()
    # Gradients left over from earlier work would otherwise be added to the
    # first batch's gradients by backward().
    model.zero_grad(set_to_none=True)
    iterator = dataloader
    if tqdm is not None and _tqdm_enabled():
        iterator = tqdm(dataloader, desc="Fisher", unit="batch")
    try:
        for batch in iterator:
            input_ids = batch[0].to(device)
            outputs = model(input_ids=input_ids, labels=input_ids)
            loss = outputs.loss
            loss.backward()
            for layer_idx, layer in enumerate((layer_a, layer_b)):
                layer_sums = fisher_sums[layer_idx]
                for name, param in layer.named_parameters():
                    if not param.requires_grad or param.grad is None:
                        continue
                    grad_sq = param.grad.detach().float().pow(2)
                    if fisher_mode == "param":
                        layer_sums[name] += grad_sq.cpu()
                    else:
                        layer_sums[name] += float(grad_sq.sum().item())
            model.zero_grad(set_to_none=True)
            num_batches += 1
    finally:
        # Release gradient memory even when a batch fails midway.
        model.zero_grad(set_to_none=True)

    if num_batches == 0:
        raise RuntimeError("No batches processed; check dataset or text inputs.")

    return fisher_sums, num_batches, param_numels
430
+
431
+
432
def merge_layers(
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    fisher_a: Dict[str, object],
    fisher_b: Dict[str, object],
    num_batches: int,
    numels_a: Dict[str, int],
    numels_b: Dict[str, int],
    fisher_mode: str,
    eps: float,
) -> int:
    """Fisher-weighted merge of ``layer_b`` into ``layer_a`` (in place).

    Every parameter of ``layer_a`` that has a same-named, same-shaped
    counterpart in ``layer_b`` is replaced by
    ``lam * W_a + (1 - lam) * W_b`` with ``lam = F_a / (F_a + F_b)``.
    Wherever the combined Fisher mass does not exceed ``eps`` the plain
    average (``lam = 0.5``) is used, so entries that received no gradient are
    averaged rather than driven to zero.

    Args:
        layer_a: Destination layer; its parameters are overwritten.
        layer_b: Source layer; left unchanged.
        fisher_a: Accumulated squared gradients for ``layer_a`` by name.
        fisher_b: Accumulated squared gradients for ``layer_b`` by name.
        num_batches: Number of batches the sums were accumulated over.
        numels_a: Element count per parameter of ``layer_a`` (scalar mode).
        numels_b: Element count per parameter of ``layer_b`` (scalar mode).
        fisher_mode: ``"param"`` for per-element weights, otherwise one
            scalar weight per parameter tensor.
        eps: Threshold below which Fisher mass is treated as zero.

    Returns:
        Number of parameter tensors that were merged.
    """
    merged = 0
    params_b = dict(layer_b.named_parameters())
    with torch.no_grad():
        for name, param_a in layer_a.named_parameters():
            param_b = params_b.get(name)
            if param_b is None or param_b.shape != param_a.shape:
                continue
            value_a = param_a.float()
            value_b = param_b.float().to(param_a.device)
            if fisher_mode == "param":
                fa = fisher_a[name] / num_batches
                fb = fisher_b[name] / num_batches
                # Fisher tensors are accumulated on CPU to save VRAM; move to the
                # parameter device for the actual merge.
                if isinstance(fa, torch.Tensor) and fa.device != param_a.device:
                    fa = fa.to(param_a.device)
                if isinstance(fb, torch.Tensor) and fb.device != param_a.device:
                    fb = fb.to(param_a.device)
                denom = fa + fb
                if float(denom.mean().item()) <= eps:
                    lam = 0.5
                else:
                    informative = denom > eps
                    # Replace uninformative denominators by 1 so the division
                    # never produces NaN/inf; those entries get lam = 0.5.
                    safe_denom = torch.where(
                        informative, denom, torch.ones_like(denom)
                    )
                    lam = torch.where(
                        informative, fa / safe_denom, torch.full_like(denom, 0.5)
                    )
            else:
                fa = fisher_a[name] / (num_batches * numels_a[name])
                fb = fisher_b[name] / (num_batches * numels_b[name])
                denom = fa + fb
                lam = 0.5 if denom <= eps else fa / denom
            merged_param = lam * value_a + (1.0 - lam) * value_b
            param_a.copy_(merged_param.to(dtype=param_a.dtype))
            merged += 1
    return merged
480
+
481
+
482
def merge_layers_with_gates(
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    gates: Dict[str, torch.Tensor],
) -> int:
    """Blend ``layer_b`` into ``layer_a`` in place with per-parameter gates.

    For each parameter name present in ``gates`` the new value of the
    ``layer_a`` parameter is::

        W = lambda * W_a + (1 - lambda) * W_b

    A gate may be a scalar or anything that broadcasts against the parameter.
    Parameters without a gate, without a counterpart in ``layer_b``, or with a
    mismatching shape are skipped.

    Returns:
        Number of parameter tensors that were overwritten.
    """
    source_params = dict(layer_b.named_parameters())
    count = 0
    with torch.no_grad():
        for key, target in layer_a.named_parameters():
            weight = gates.get(key)
            if weight is None:
                continue
            other = source_params.get(key)
            if other is None or other.shape != target.shape:
                continue
            if not isinstance(weight, torch.Tensor):
                weight = torch.tensor(weight)
            if weight.device != target.device:
                weight = weight.to(target.device)
            blended = weight * target.float() + (1.0 - weight) * other.float()
            target.copy_(blended.to(dtype=target.dtype))
            count += 1
    return count
514
+
515
+
516
def drop_layer(container: object, index: int) -> object:
    """Remove the layer at ``index`` from ``container``.

    A ``torch.nn.ModuleList`` is not modified; a new ``ModuleList`` holding the
    remaining modules is returned. A plain ``list`` is modified in place and
    returned. Negative indices count from the end for both container kinds.

    Raises:
        TypeError: If ``container`` is neither a ``ModuleList`` nor a ``list``.
        IndexError: If ``index`` is out of range.
    """
    if not isinstance(container, (torch.nn.ModuleList, list)):
        raise TypeError("Layer container must be ModuleList or list")
    size = len(container)
    if not -size <= index < size:
        raise IndexError(
            "layer index {} out of range for container of size {}".format(index, size)
        )
    # Normalise so a negative index removes the same element in both branches.
    position = index % size
    if isinstance(container, torch.nn.ModuleList):
        return torch.nn.ModuleList(
            [layer for idx, layer in enumerate(container) if idx != position]
        )
    del container[position]
    return container
525
+
526
+
527
def decrement_config(config, dropped_index: Optional[int] = None) -> None:
    """Lower the layer count stored on ``config`` by one.

    The attributes ``num_hidden_layers``, ``n_layer`` and ``num_layers`` are
    each reduced by one when present and holding a positive int. Some config
    classes expose several of these names as aliases of one stored value; an
    attribute whose value already changed because an alias was updated is not
    decremented a second time.

    Args:
        config: Object carrying the layer-count attributes.
        dropped_index: Optional position of the removed layer. When given and
            ``config.layer_types`` is a list/tuple containing that position,
            that entry is removed so the remaining entries stay aligned with
            the remaining layers. Without it, ``layer_types`` is only
            truncated at the end by ``normalize_config``.
    """
    layer_attrs = ("num_hidden_layers", "n_layer", "num_layers")
    before = {
        attr: getattr(config, attr) for attr in layer_attrs if hasattr(config, attr)
    }
    for attr, value in before.items():
        if not isinstance(value, int) or value <= 0:
            continue
        if getattr(config, attr) != value:
            # Already lowered through an aliased attribute name.
            continue
        setattr(config, attr, value - 1)

    layer_types = getattr(config, "layer_types", None)
    if (
        dropped_index is not None
        and isinstance(layer_types, (list, tuple))
        and -len(layer_types) <= dropped_index < len(layer_types)
    ):
        remaining = list(layer_types)
        del remaining[dropped_index]
        config.layer_types = remaining
    normalize_config(config)


def normalize_config(config) -> None:
    """Truncate ``config.layer_types`` to ``config.num_hidden_layers`` entries.

    Nothing happens unless ``num_hidden_layers`` is a non-negative int,
    ``layer_types`` is a list or tuple, and their lengths differ. A
    ``layer_types`` that is too short is replaced by an unchanged list copy,
    since slicing cannot extend it.
    """
    num_hidden_layers = getattr(config, "num_hidden_layers", None)
    layer_types = getattr(config, "layer_types", None)
    if (
        isinstance(num_hidden_layers, int)
        and num_hidden_layers >= 0
        and isinstance(layer_types, (list, tuple))
        and len(layer_types) != num_hidden_layers
    ):
        config.layer_types = list(layer_types[:num_hidden_layers])
546
+
547
+
548
def find_colon_modules(module: torch.nn.Module) -> List[str]:
    """Return dotted paths of all descendant modules whose name contains ``:``.

    Direct children are inspected through ``module._modules``; a hit is listed
    before the hits found underneath it, and nested hits are prefixed with the
    names of their ancestors joined by ``.``.
    """
    hits: List[str] = []
    for child_name, child in module._modules.items():
        if ":" in child_name:
            hits.append(child_name)
        if not isinstance(child, torch.nn.Module):
            continue
        hits.extend(
            f"{child_name}.{nested}" for nested in find_colon_modules(child)
        )
    return hits
557
+
558
+
559
def get_norm_pair(
    layer: torch.nn.Module,
) -> Tuple[
    Optional[torch.nn.Module],
    Optional[torch.nn.Module],
    Tuple[Optional[str], Optional[str]],
]:
    """Look up the two normalisation sub-modules of ``layer`` by known names.

    A fixed list of attribute-name pairs is tried in order; the first pair for
    which ``layer`` has both attributes wins.

    Returns:
        ``(first_norm, second_norm, (first_name, second_name))``, or
        ``(None, None, (None, None))`` when no known pair is present.
    """
    known_pairs = (
        ("input_layernorm", "post_attention_layernorm"),
        ("ln_1", "ln_2"),
        ("norm1", "norm2"),
        ("norm_1", "norm_2"),
        ("layer_norm_1", "layer_norm_2"),
        ("self_attn_layer_norm", "final_layer_norm"),
    )
    for pair in known_pairs:
        first, second = pair
        if not (hasattr(layer, first) and hasattr(layer, second)):
            continue
        return getattr(layer, first), getattr(layer, second), (first, second)
    return None, None, (None, None)
578
+
579
+
580
def clone_state_dict(module: torch.nn.Module) -> Dict[str, torch.Tensor]:
    """Return a detached, independent copy of every entry in ``module.state_dict()``."""
    snapshot: Dict[str, torch.Tensor] = {}
    for key, tensor in module.state_dict().items():
        snapshot[key] = tensor.detach().clone()
    return snapshot
582
+
583
+
584
def apply_norm_policy(
    layer: torch.nn.Module,
    norm_policy: str,
    norm1_state: Optional[Dict[str, torch.Tensor]],
    norm2_state: Optional[Dict[str, torch.Tensor]],
    norm_names: Tuple[Optional[str], Optional[str]],
) -> None:
    """Load saved normalisation weights into ``layer`` according to a policy.

    * ``"copy_n1"`` and ``"hybrid"`` load ``norm1_state`` into the first norm.
    * ``"copy_n1_n2"`` loads ``norm1_state`` into the first norm and
      ``norm2_state`` into the second norm.
    * Any other policy leaves the layer unchanged.

    A state that is ``None``, or a norm module that cannot be found, is
    skipped silently.

    Args:
        layer: Layer whose norm sub-modules are updated in place.
        norm_policy: Policy name, see above.
        norm1_state: State dict for the first norm, or ``None``.
        norm2_state: State dict for the second norm, or ``None``.
        norm_names: Attribute names of the two norms on ``layer``. When both
            are ``None`` the modules are located with ``get_norm_pair``.
    """
    first_name, second_name = norm_names
    if first_name is None and second_name is None:
        norm1, norm2, _ = get_norm_pair(layer)
    else:
        norm1 = getattr(layer, first_name, None) if first_name else None
        norm2 = getattr(layer, second_name, None) if second_name else None

    # "copy_n1_n2" has to restore the first norm as well as the second one.
    if (
        norm_policy in {"copy_n1", "hybrid", "copy_n1_n2"}
        and norm1_state is not None
        and norm1 is not None
    ):
        norm1.load_state_dict(norm1_state)
    if norm_policy == "copy_n1_n2" and norm2_state is not None and norm2 is not None:
        norm2.load_state_dict(norm2_state)
src/fuse_layers_select.py ADDED
@@ -0,0 +1,1152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Automatic adjacent-pair selection via configurable scoring metrics."""
3
+
4
+ import copy
5
+ import math
6
+ from contextlib import contextmanager
7
+ from typing import Dict, List, Optional, Set, Tuple
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+ from fuse_layers_model import (
13
+ build_head_permutation,
14
+ compute_fisher,
15
+ compute_head_means,
16
+ find_attention_module,
17
+ find_layer_container,
18
+ merge_layers,
19
+ permute_attention_heads,
20
+ )
21
+
22
+ _DWCE_GRAD_CACHE_MAX_BYTES = 1 << 30
23
+
24
+
25
+ class _DwceGradCacheOverflow(RuntimeError):
26
+ """Raised when shared-backward DWCE caching exceeds the configured budget."""
27
+
28
+
29
+ def _get_hidden_size(model) -> int:
30
+ hidden_size = getattr(model.config, "hidden_size", None)
31
+ if hidden_size is None:
32
+ hidden_size = getattr(model.config, "n_embd", None)
33
+ if hidden_size is None:
34
+ raise SystemExit("Model config missing hidden_size/n_embd")
35
+ return int(hidden_size)
36
+
37
+
38
+ def _detach_arg(arg):
39
+ if torch.is_tensor(arg):
40
+ return arg.detach()
41
+ if isinstance(arg, (list, tuple)):
42
+ return type(arg)(_detach_arg(x) for x in arg)
43
+ if isinstance(arg, dict):
44
+ return {k: _detach_arg(v) for k, v in arg.items()}
45
+ return arg
46
+
47
+
48
+ def _register_forward_hook(layer, hook):
49
+ try:
50
+ def wrapper(module, inputs, kwargs, output):
51
+ return hook(module, inputs, output, kwargs)
52
+
53
+ handle = layer.register_forward_hook(wrapper, with_kwargs=True)
54
+ return handle, True
55
+ except TypeError:
56
+ def wrapper(module, inputs, output):
57
+ return hook(module, inputs, output, None)
58
+ handle = layer.register_forward_hook(wrapper)
59
+ return handle, False
60
+
61
+
62
+ @contextmanager
63
+ def _temporary_layers(parent: object, name: str, new_layers: object):
64
+ original = getattr(parent, name)
65
+ setattr(parent, name, new_layers)
66
+ try:
67
+ yield
68
+ finally:
69
+ setattr(parent, name, original)
70
+
71
+
72
+ def _extract_hidden(output):
73
+ if torch.is_tensor(output):
74
+ return output
75
+ if isinstance(output, (tuple, list)):
76
+ if output and all(torch.is_tensor(item) for item in output):
77
+ return output[0]
78
+ for item in output:
79
+ hidden = _extract_hidden(item)
80
+ if hidden is not None:
81
+ return hidden
82
+ return None
83
+ if isinstance(output, dict):
84
+ for key in ("hidden_states", "last_hidden_state", "hidden_state"):
85
+ if key in output:
86
+ value = output[key]
87
+ if isinstance(value, (tuple, list)) and value and all(
88
+ torch.is_tensor(item) for item in value
89
+ ):
90
+ return value[-1]
91
+ hidden = _extract_hidden(value)
92
+ if hidden is not None:
93
+ return hidden
94
+ for value in output.values():
95
+ hidden = _extract_hidden(value)
96
+ if hidden is not None:
97
+ return hidden
98
+ return None
99
+ for attr in ("hidden_states", "last_hidden_state"):
100
+ if hasattr(output, attr):
101
+ value = getattr(output, attr)
102
+ if isinstance(value, (tuple, list)) and value and all(
103
+ torch.is_tensor(item) for item in value
104
+ ):
105
+ return value[-1]
106
+ hidden = _extract_hidden(value)
107
+ if hidden is not None:
108
+ return hidden
109
+ return None
110
+
111
+
112
def _build_fused_layer_for_pair(
    model,
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    dataloader,
    device: str,
    fisher_mode: str,
    eps: float,
    hidden_size: int,
    enable_head_permute: bool = True,
) -> Tuple[torch.nn.Module, Dict[str, float]]:
    """Build a Fisher-merged copy of ``layer_a``/``layer_b`` without altering them.

    Steps:
      1. Optionally compute a head permutation for ``layer_b`` from head means.
      2. Deep-copy both layers and permute the copy of ``layer_b``.
      3. Temporarily permute the live ``layer_b``, run ``compute_fisher`` on the
         live layers, and undo the permutation afterwards.
      4. Merge the copies with ``merge_layers`` and derive scalar mixing
         coefficients with ``_compute_fuse_priors``.

    Returns:
        ``(fused_layer, fuse_priors)``: the merged copy in eval mode and one
        clamped mixing coefficient per parameter name.
    """
    attn_a = find_attention_module(layer_a)
    attn_b = find_attention_module(layer_b)
    perm = None
    inv_perm = None
    num_heads = None
    num_kv_heads = None
    head_dim = None
    if enable_head_permute:
        mean_a, mean_b, num_heads, num_kv_heads, head_dim = compute_head_means(
            model,
            attn_a,
            attn_b,
            dataloader,
            device,
            hidden_size,
        )

        perm = build_head_permutation(
            mean_a,
            mean_b,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            eps=eps,
        )

    layer_a_copy = copy.deepcopy(layer_a)
    layer_b_copy = copy.deepcopy(layer_b)
    attn_b_copy = find_attention_module(layer_b_copy)
    if perm is not None:
        permute_attention_heads(
            attn_b_copy, perm, num_heads, num_kv_heads, head_dim=head_dim
        )

        inv_perm = [0] * len(perm)
        for idx, mapped in enumerate(perm):
            inv_perm[mapped] = idx

    live_permuted = False
    try:
        if perm is not None:
            # Fisher statistics are gathered with layer_b in the permuted
            # layout so they line up with the permuted copy being merged.
            permute_attention_heads(
                attn_b, perm, num_heads, num_kv_heads, head_dim=head_dim
            )
            live_permuted = True
        fisher_sums, num_batches, param_numels = compute_fisher(
            model,
            layer_a,
            layer_b,
            dataloader,
            fisher_mode=fisher_mode,
            device=device,
        )
    finally:
        if live_permuted:
            permute_attention_heads(
                attn_b, inv_perm, num_heads, num_kv_heads, head_dim=head_dim
            )

    merge_layers(
        layer_a_copy,
        layer_b_copy,
        fisher_sums[0],
        fisher_sums[1],
        num_batches,
        param_numels[0],
        param_numels[1],
        fisher_mode=fisher_mode,
        eps=eps,
    )

    # Scalar mixing coefficients per parameter tensor; used by pressure redistribution
    # to simulate future fusions without running another Fisher pass.
    fuse_priors = _compute_fuse_priors(
        layer_a,
        layer_b,
        fisher_sums,
        num_batches,
        param_numels,
        fisher_mode,
        eps,
    )

    layer_a_copy.eval()
    return layer_a_copy, fuse_priors
227
+
228
+
229
+ def _init_fisher_accumulators(
230
+ layer_a: torch.nn.Module,
231
+ layer_b: torch.nn.Module,
232
+ fisher_mode: str,
233
+ device: str,
234
+ ) -> Tuple[List[Dict[str, object]], List[Dict[str, int]]]:
235
+ fisher_sums: List[Dict[str, object]] = []
236
+ param_numels: List[Dict[str, int]] = []
237
+ for layer in (layer_a, layer_b):
238
+ layer_sums: Dict[str, object] = {}
239
+ layer_numels: Dict[str, int] = {}
240
+ for name, param in layer.named_parameters():
241
+ if not param.requires_grad:
242
+ continue
243
+ if fisher_mode == "param":
244
+ layer_sums[name] = torch.zeros_like(
245
+ param, dtype=torch.float32, device="cpu"
246
+ )
247
+ else:
248
+ layer_sums[name] = torch.zeros((), dtype=torch.float32, device=device)
249
+ layer_numels[name] = param.numel()
250
+ fisher_sums.append(layer_sums)
251
+ param_numels.append(layer_numels)
252
+ return fisher_sums, param_numels
253
+
254
+
255
+ def _accumulate_fisher_from_grads(
256
+ layer: torch.nn.Module,
257
+ layer_sums: Dict[str, object],
258
+ fisher_mode: str,
259
+ ) -> None:
260
+ for name, param in layer.named_parameters():
261
+ if not param.requires_grad or param.grad is None:
262
+ continue
263
+ grad_sq = param.grad.detach().float().pow(2)
264
+ if fisher_mode == "param":
265
+ layer_sums[name] += grad_sq.cpu()
266
+ else:
267
+ layer_sums[name] += grad_sq.sum()
268
+
269
+
270
+ def _finalize_fisher_sums(
271
+ fisher_sums: List[Dict[str, object]],
272
+ fisher_mode: str,
273
+ ) -> List[Dict[str, object]]:
274
+ if fisher_mode == "param":
275
+ return fisher_sums
276
+
277
+ finalized: List[Dict[str, object]] = []
278
+ for layer_sums in fisher_sums:
279
+ finalized_layer: Dict[str, object] = {}
280
+ for name, value in layer_sums.items():
281
+ if isinstance(value, torch.Tensor):
282
+ finalized_layer[name] = float(value.detach().cpu().item())
283
+ else:
284
+ finalized_layer[name] = float(value)
285
+ finalized.append(finalized_layer)
286
+ return finalized
287
+
288
+
289
+ def _compute_fuse_priors(
290
+ layer_a: torch.nn.Module,
291
+ layer_b: torch.nn.Module,
292
+ fisher_sums: List[Dict[str, object]],
293
+ num_batches: int,
294
+ param_numels: List[Dict[str, int]],
295
+ fisher_mode: str,
296
+ eps: float,
297
+ ) -> Dict[str, float]:
298
+ fuse_priors: Dict[str, float] = {}
299
+ params_b = {name: param for name, param in layer_b.named_parameters()}
300
+ clamp_eps = 1e-4
301
+ for name, param_a in layer_a.named_parameters():
302
+ param_b = params_b.get(name)
303
+ if param_b is None or param_b.shape != param_a.shape:
304
+ continue
305
+ if fisher_mode == "param":
306
+ fa = fisher_sums[0][name] / max(num_batches, 1)
307
+ fb = fisher_sums[1][name] / max(num_batches, 1)
308
+ fa_val = float(fa.mean().item()) if isinstance(fa, torch.Tensor) else float(fa)
309
+ fb_val = float(fb.mean().item()) if isinstance(fb, torch.Tensor) else float(fb)
310
+ else:
311
+ fa_val = float(
312
+ fisher_sums[0][name]
313
+ / (max(num_batches, 1) * max(param_numels[0].get(name, 1), 1))
314
+ )
315
+ fb_val = float(
316
+ fisher_sums[1][name]
317
+ / (max(num_batches, 1) * max(param_numels[1].get(name, 1), 1))
318
+ )
319
+ denom = fa_val + fb_val
320
+ lam = 0.5 if denom <= eps else fa_val / (denom + eps)
321
+ fuse_priors[name] = min(max(lam, clamp_eps), 1.0 - clamp_eps)
322
+ return fuse_priors
323
+
324
+
325
def _score_dwce_with_shared_backward(
    model,
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    dataloader,
    device: str,
    fisher_mode: str,
    max_batches: int,
    eps: float,
    norm: str,
    hidden_size: int,
    enable_head_permute: bool = True,
) -> Tuple[float, Dict[str, object]]:
    """Score a layer pair with DWCE, sharing one backward pass with Fisher.

    Phase 1 runs forward/backward over the data once. Each backward pass
    supplies both the Fisher accumulators for the two layers and the squared
    gradient of the loss w.r.t. ``layer_b``'s output, which is cached per
    batch. The layers are then merged into a fused copy.

    Phase 2 replays the same number of batches without gradients, runs the
    fused layer on ``layer_a``'s inputs and compares its output against
    ``layer_b``'s output, weighted by the cached squared gradients.
    NOTE(review): phase 2 indexes the cache by batch position, so it assumes
    the dataloader yields the same batches in the same order on both passes.

    On exit from phase 1 every model parameter has ``requires_grad=True``.

    Args:
        max_batches: Upper bound on batches; a falsy value means no bound.
        norm: ``"relative"`` divides by the gradient-weighted teacher energy;
            any other value divides by the number of (unmasked) tokens.

    Returns:
        ``(score, meta)`` where ``meta`` carries batch/token counts, the norm
        mode, hook capabilities and the per-parameter fuse priors.

    Raises:
        _DwceGradCacheOverflow: If the cached squared gradients exceed
            ``_DWCE_GRAD_CACHE_MAX_BYTES``.
        RuntimeError: If no batch was processed or a hook captured nothing.
    """
    attn_a = find_attention_module(layer_a)
    attn_b = find_attention_module(layer_b)
    perm = None
    inv_perm = None
    num_heads = None
    num_kv_heads = None
    head_dim = None
    if enable_head_permute:
        mean_a, mean_b, num_heads, num_kv_heads, head_dim = compute_head_means(
            model,
            attn_a,
            attn_b,
            dataloader,
            device,
            hidden_size,
        )
        perm = build_head_permutation(
            mean_a,
            mean_b,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            eps=eps,
        )

    layer_a_copy = copy.deepcopy(layer_a)
    layer_b_copy = copy.deepcopy(layer_b)
    attn_b_copy = find_attention_module(layer_b_copy)
    if perm is not None:
        permute_attention_heads(
            attn_b_copy, perm, num_heads, num_kv_heads, head_dim=head_dim
        )

        inv_perm = [0] * len(perm)
        for idx, mapped in enumerate(perm):
            inv_perm[mapped] = idx

    cache: Dict[str, Optional[torch.Tensor]] = {"teacher": None}
    grad_sq_cache: List[torch.Tensor] = []
    cached_bytes = 0

    def hook_b(_module, _inputs, output, _kwargs=None):
        teacher_hidden = _extract_hidden(output)
        if teacher_hidden is None:
            raise RuntimeError("Failed to extract teacher hidden state output.")
        cache["teacher"] = teacher_hidden
        if teacher_hidden.requires_grad:
            # Non-leaf tensor: keep its gradient so it can be read after backward.
            teacher_hidden.retain_grad()
        return output

    num_batches = 0
    handle_b = None
    live_permuted = False
    # Everything that mutates the live model sits inside the try block so the
    # finally clause undoes it even when setup fails.
    try:
        for param in model.parameters():
            param.requires_grad_(False)
        for layer in (layer_a, layer_b):
            for param in layer.parameters():
                param.requires_grad_(True)
        fisher_sums, param_numels = _init_fisher_accumulators(
            layer_a, layer_b, fisher_mode, device
        )

        if perm is not None:
            permute_attention_heads(
                attn_b, perm, num_heads, num_kv_heads, head_dim=head_dim
            )
            live_permuted = True
        handle_b, _ = _register_forward_hook(layer_b, hook_b)

        model.eval()
        for batch_idx, batch in enumerate(dataloader):
            if max_batches and batch_idx >= max_batches:
                break
            cache["teacher"] = None
            input_ids = batch[0].to(device)
            attention_mask = batch[1].to(device) if len(batch) > 1 else None

            model.zero_grad(set_to_none=True)
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids,
            )
            outputs.loss.backward()

            teacher = cache["teacher"]
            grad = None if teacher is None else teacher.grad
            if teacher is None or grad is None:
                raise RuntimeError(
                    "Auto selection hooks failed to capture outputs/gradients. "
                    "Try updating PyTorch or run with --layer <index>."
                )
            # Square in float32 and cache as bfloat16: same 2 bytes per element
            # as float16, but with float32's exponent range, so small squared
            # gradients are not flushed to zero.
            grad_sq = (
                grad.detach().float().pow(2).to(device=device, dtype=torch.bfloat16)
            )
            cached_bytes += grad_sq.numel() * grad_sq.element_size()
            if cached_bytes > _DWCE_GRAD_CACHE_MAX_BYTES:
                raise _DwceGradCacheOverflow(
                    "DWCE grad cache exceeded device-memory budget during shared-backward scoring."
                )
            grad_sq_cache.append(grad_sq)
            _accumulate_fisher_from_grads(layer_a, fisher_sums[0], fisher_mode)
            _accumulate_fisher_from_grads(layer_b, fisher_sums[1], fisher_mode)
            model.zero_grad(set_to_none=True)
            num_batches += 1
    finally:
        if handle_b is not None:
            handle_b.remove()
        if live_permuted:
            permute_attention_heads(
                attn_b, inv_perm, num_heads, num_kv_heads, head_dim=head_dim
            )
        for param in model.parameters():
            param.requires_grad_(True)

    if num_batches == 0:
        raise RuntimeError("No batches processed; check dataset or text inputs.")

    fisher_sums = _finalize_fisher_sums(fisher_sums, fisher_mode)
    merge_layers(
        layer_a_copy,
        layer_b_copy,
        fisher_sums[0],
        fisher_sums[1],
        num_batches,
        param_numels[0],
        param_numels[1],
        fisher_mode=fisher_mode,
        eps=eps,
    )
    fuse_priors = _compute_fuse_priors(
        layer_a,
        layer_b,
        fisher_sums,
        num_batches,
        param_numels,
        fisher_mode,
        eps,
    )

    fused_layer = layer_a_copy
    fused_layer.eval()
    phase2_cache = {"teacher": None, "fused": None}

    def hook_a(_module, inputs, output, kwargs=None):
        # Feed layer_a's inputs to the fused layer; its output stands in for
        # the output of the layer_a -> layer_b pair.
        with torch.no_grad():
            detached_inputs = tuple(_detach_arg(arg) for arg in inputs)
            if kwargs:
                detached_kwargs = {k: _detach_arg(v) for k, v in kwargs.items()}
                fused_out = fused_layer(*detached_inputs, **detached_kwargs)
            else:
                fused_out = fused_layer(*detached_inputs)
            fused_hidden = _extract_hidden(fused_out)
            if fused_hidden is None:
                raise RuntimeError("Failed to extract fused hidden state output.")
            phase2_cache["fused"] = fused_hidden
        return output

    def hook_b_eval(_module, _inputs, output, _kwargs=None):
        teacher_hidden = _extract_hidden(output)
        if teacher_hidden is None:
            raise RuntimeError("Failed to extract teacher hidden state output.")
        phase2_cache["teacher"] = teacher_hidden
        return output

    handles = []
    supports_kwargs = True
    score_num = 0.0
    score_den = 0.0
    token_count = 0.0
    try:
        for layer, hook in ((layer_a, hook_a), (layer_b, hook_b_eval)):
            handle, has_kwargs = _register_forward_hook(layer, hook)
            handles.append(handle)
            supports_kwargs = supports_kwargs and has_kwargs

        model.eval()
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= num_batches:
                break
            phase2_cache["teacher"] = None
            phase2_cache["fused"] = None
            input_ids = batch[0].to(device)
            attention_mask = batch[1].to(device) if len(batch) > 1 else None

            with torch.no_grad():
                model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    use_cache=False,
                )

                teacher = phase2_cache["teacher"]
                fused = phase2_cache["fused"]
                if teacher is None or fused is None:
                    raise RuntimeError(
                        "Auto selection hooks failed to capture outputs during DWCE replay."
                    )
                grad_sq = grad_sq_cache[batch_idx].to(dtype=torch.float32)
                if attention_mask is not None:
                    mask = attention_mask.to(dtype=torch.float32).unsqueeze(-1)
                    batch_tokens = float(mask.sum().item())
                    grad_sq = grad_sq * mask
                else:
                    mask = None
                    batch_tokens = float(input_ids.numel())
                token_count += batch_tokens

                delta = fused - teacher
                if mask is not None:
                    delta = delta * mask
                score_num += (delta.float().pow(2) * grad_sq).sum().item()
                score_den += (teacher.float().pow(2) * grad_sq).sum().item()
    finally:
        for handle in handles:
            handle.remove()

    score = (
        score_num / (score_den + eps)
        if norm == "relative"
        else score_num / max(token_count, 1.0)
    )
    meta = {
        "num_batches": num_batches,
        "token_count": token_count,
        "norm": norm,
        "supports_kwargs": supports_kwargs,
        "fuse_priors": fuse_priors,
        "metric": "dwce",
        "dwce_mode": "shared",
    }
    return score, meta
558
+
559
+
560
def _compute_dwce_for_pair(
    model,
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    fused_layer: torch.nn.Module,
    dataloader,
    device: str,
    max_batches: int,
    eps: float,
    norm: str,
) -> Tuple[float, Dict[str, object]]:
    """Score how well ``fused_layer`` replaces the ``layer_a``/``layer_b`` pair.

    For each batch the model is run forward and backward with
    ``labels=input_ids``. A hook on ``layer_a`` feeds that layer's inputs to
    ``fused_layer`` (without gradients); a hook on ``layer_b`` captures its
    output and retains its gradient. The squared difference between fused and
    captured output is weighted by the squared gradient and summed, with
    padded positions masked out when an attention mask is present.

    Args:
        max_batches: Upper bound on batches; a falsy value means no bound.
        norm: ``"relative"`` divides by the gradient-weighted teacher energy
            plus ``eps``; any other value divides by the token count.

    Returns:
        ``(score, meta)`` with batch/token counts, the norm mode and whether
        both hooks received keyword arguments.

    Raises:
        RuntimeError: If a hook captured nothing, or the captured output does
            not take part in autograd.
    """
    cache = {"teacher": None, "fused": None}
    supports_kwargs = True

    def hook_a(_module, inputs, output, kwargs=None):
        with torch.no_grad():
            detached_inputs = tuple(_detach_arg(arg) for arg in inputs)
            if kwargs is not None and len(kwargs) > 0:
                detached_kwargs = {k: _detach_arg(v) for k, v in kwargs.items()}
                fused_out = fused_layer(*detached_inputs, **detached_kwargs)
            else:
                fused_out = fused_layer(*detached_inputs)
            fused_hidden = _extract_hidden(fused_out)
            if fused_hidden is None:
                raise RuntimeError("Failed to extract fused hidden state output.")
            cache["fused"] = fused_hidden
        return output

    def hook_b(_module, _inputs, output, _kwargs=None):
        teacher_hidden = _extract_hidden(output)
        if teacher_hidden is None:
            raise RuntimeError("Failed to extract teacher hidden state output.")
        cache["teacher"] = teacher_hidden
        if teacher_hidden.requires_grad:
            # Non-leaf tensor: keep its gradient so it can be read after backward.
            teacher_hidden.retain_grad()
        return output

    score_num = 0.0
    score_den = 0.0
    token_count = 0.0
    num_batches = 0

    handles = []
    # try/finally guarantees the hooks are detached even when a batch fails;
    # otherwise they would keep firing on later forward passes.
    try:
        for layer, hook in ((layer_a, hook_a), (layer_b, hook_b)):
            handle, has_kwargs = _register_forward_hook(layer, hook)
            handles.append(handle)
            supports_kwargs = supports_kwargs and has_kwargs

        model.eval()
        for batch_idx, batch in enumerate(dataloader):
            if max_batches and batch_idx >= max_batches:
                break
            cache["teacher"] = None
            cache["fused"] = None

            input_ids = batch[0].to(device)
            attention_mask = batch[1].to(device) if len(batch) > 1 else None

            model.zero_grad(set_to_none=True)
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids,
            )
            loss = outputs.loss
            loss.backward()

            teacher = cache["teacher"]
            fused = cache["fused"]
            # Checked first: a tensor outside autograd never gets a gradient,
            # so the generic check below would otherwise hide this cause.
            if teacher is not None and not teacher.requires_grad:
                raise RuntimeError(
                    "Teacher hidden state does not require grad. "
                    "Ensure model parameters require grad for DWCE."
                )
            grad = None if teacher is None else teacher.grad
            if teacher is None or fused is None or grad is None:
                raise RuntimeError(
                    "Auto selection hooks failed to capture outputs/gradients. "
                    "Try updating PyTorch or run with --layer <index>."
                )

            with torch.no_grad():
                if attention_mask is not None:
                    mask = attention_mask.to(dtype=torch.float32).unsqueeze(-1)
                    batch_tokens = float(mask.sum().item())
                else:
                    mask = None
                    batch_tokens = float(input_ids.numel())
                token_count += batch_tokens

                # float32 arithmetic so squaring small values does not
                # underflow when the model runs in half precision.
                teacher_f = teacher.float()
                delta = fused.float() - teacher_f
                grad_sq = grad.float().pow(2)
                if mask is not None:
                    delta = delta * mask
                    grad_sq = grad_sq * mask

                score_num += (delta.pow(2) * grad_sq).sum().item()
                score_den += (teacher_f.pow(2) * grad_sq).sum().item()
            num_batches += 1
    finally:
        for handle in handles:
            handle.remove()
        # Drop the gradients left by the last backward pass.
        model.zero_grad(set_to_none=True)

    if norm == "relative":
        score = score_num / (score_den + eps)
    else:
        denom = token_count if token_count > 0 else 1.0
        score = score_num / denom

    meta = {
        "num_batches": num_batches,
        "token_count": token_count,
        "norm": norm,
        "supports_kwargs": supports_kwargs,
    }
    return score, meta
674
+
675
+
676
def _compute_cosine_for_pair(
    model,
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    dataloader,
    device: str,
    max_batches: int,
    eps: float,
) -> Tuple[float, Dict[str, object]]:
    """Score a layer pair by the mean cosine distance between their outputs.

    Forward hooks capture the hidden state produced by ``layer_a`` and
    ``layer_b`` during a no-grad forward pass of ``model``. The distance
    ``1 - cosine_similarity(a, b)`` (cosine taken over the last dimension) is
    summed over all counted positions and divided by the token count.

    Args:
        model: Called as ``model(input_ids=..., attention_mask=...,
            use_cache=False)``. It is switched to eval mode and left there.
        layer_a: First layer of the pair; a forward hook is attached to it.
        layer_b: Second layer of the pair; a forward hook is attached to it.
        dataloader: Iterable of batches. ``batch[0]`` is used as input ids and,
            when the batch has more than one entry, ``batch[1]`` as the
            attention mask.
        device: Device the batch tensors are moved to.
        max_batches: Stop after this many batches; a falsy value consumes the
            whole dataloader.
        eps: Forwarded to ``F.cosine_similarity`` as ``eps``.

    Returns:
        ``(score, meta)`` where ``score`` is the summed distance divided by the
        token count (or by 1.0 when no tokens were counted) and ``meta`` holds
        ``num_batches``, ``token_count``, ``metric`` and ``supports_kwargs``.

    Raises:
        RuntimeError: If a hook cannot extract a hidden state from a layer
            output, or if either hook captured nothing for a batch.
    """
    cache = {"a": None, "b": None}

    def _make_capture_hook(slot: str):
        # Both layers need the same capture logic; only the cache slot differs.
        def _hook(_module, _inputs, output, _kwargs=None):
            hidden = _extract_hidden(output)
            if hidden is None:
                raise RuntimeError(f"Failed to extract layer_{slot} hidden state output.")
            cache[slot] = hidden
            return output

        return _hook

    score_sum = 0.0
    token_count = 0.0
    num_batches = 0
    handles = []

    # The hooks live on modules of the caller's model, so they must be removed
    # even when the forward pass or a hook raises; otherwise they keep firing
    # (and keep activations alive) for every later forward of the model.
    try:
        handle_a, has_kwargs_a = _register_forward_hook(layer_a, _make_capture_hook("a"))
        handles.append(handle_a)
        handle_b, has_kwargs_b = _register_forward_hook(layer_b, _make_capture_hook("b"))
        handles.append(handle_b)
        supports_kwargs = has_kwargs_a and has_kwargs_b

        model.eval()
        for batch_idx, batch in enumerate(dataloader):
            if max_batches and batch_idx >= max_batches:
                break
            # Reset so a hook that silently did not fire is detected below.
            cache["a"] = None
            cache["b"] = None

            input_ids = batch[0].to(device)
            attention_mask = batch[1].to(device) if len(batch) > 1 else None

            with torch.no_grad():
                model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    use_cache=False,
                )

            hidden_a = cache["a"]
            hidden_b = cache["b"]
            if hidden_a is None or hidden_b is None:
                raise RuntimeError(
                    "Auto selection hooks failed to capture outputs for cosine scoring."
                )

            with torch.no_grad():
                cos = F.cosine_similarity(hidden_a.float(), hidden_b.float(), dim=-1, eps=eps)
                distance = 1.0 - cos

                if attention_mask is not None:
                    # Masked-out positions contribute neither distance nor
                    # tokens. NOTE(review): assumes the mask broadcasts against
                    # the per-position distance tensor - confirm with callers.
                    mask = attention_mask.to(dtype=torch.float32)
                    batch_tokens = float(mask.sum().item())
                    distance = distance * mask
                else:
                    batch_tokens = float(distance.numel())

                token_count += batch_tokens
                score_sum += float(distance.sum().item())
                num_batches += 1
    finally:
        for handle in handles:
            handle.remove()

    denom = token_count if token_count > 0 else 1.0
    score = score_sum / denom
    meta = {
        "num_batches": num_batches,
        "token_count": token_count,
        "metric": "cosine",
        "supports_kwargs": supports_kwargs,
    }
    return score, meta
763
+
764
+
765
def _compute_global_rel_change_for_pair(
    model,
    layers: List[torch.nn.Module],
    pair_idx: int,
    dataloader,
    args,
    max_batches: int,
    eps: float,
) -> Tuple[float, Dict[str, object]]:
    """Score how much fusing a layer pair changes a pair-to-final relation.

    A fused replacement for ``layers[pair_idx]`` and ``layers[pair_idx + 1]``
    is built, then every batch is run twice: once through the unmodified model
    (the "teacher") and once with the pair temporarily swapped for the fused
    layer. For each position the cosine similarity between the teacher's
    pair output and the final hidden state is computed for both runs; the
    score is the mean absolute difference between the two similarities.

    Args:
        model: Model to evaluate; it is switched to eval mode and left there.
        layers: Current list of layers; must have the same length as the
            container located via ``find_layer_container``.
        pair_idx: Index of the first layer of the pair.
        dataloader: Iterable of batches. ``batch[0]`` is used as input ids and,
            when present, ``batch[1]`` as the attention mask.
        args: Namespace read for ``device``, ``fisher_mode`` and, optionally,
            ``no_head_permute_select`` and ``layer_path``.
        max_batches: Stop after this many batches; a falsy value consumes the
            whole dataloader.
        eps: Forwarded to the fused-layer builder and to
            ``F.cosine_similarity``.

    Returns:
        ``(score, meta)`` where ``score`` is the summed relation change divided
        by the token count (or by 1.0 when no tokens were counted) and ``meta``
        holds ``num_batches``, ``token_count``, ``metric``,
        ``supports_kwargs`` and ``fuse_priors``.

    Raises:
        RuntimeError: If the layer container length no longer matches
            ``layers``, or if hidden states cannot be captured.
        TypeError: If the layer container is neither a ``ModuleList`` nor a
            ``list``.
    """
    hidden_size = _get_hidden_size(model)
    head_permute_select = not bool(getattr(args, "no_head_permute_select", False))
    layer_a = layers[pair_idx]
    layer_b = layers[pair_idx + 1]
    fused_layer, fuse_priors = _build_fused_layer_for_pair(
        model,
        layer_a,
        layer_b,
        dataloader,
        device=args.device,
        fisher_mode=args.fisher_mode,
        eps=eps,
        hidden_size=hidden_size,
        enable_head_permute=head_permute_select,
    )
    fused_layer.to(args.device)
    fused_layer.eval()

    parent, name, container = find_layer_container(model, getattr(args, "layer_path", None))
    if len(list(container)) != len(layers):
        raise RuntimeError("Layer container changed during auto-selection; aborting rerank.")

    # Layer stack with the pair collapsed into the fused layer.
    virtual_layers = list(layers)
    virtual_layers[pair_idx] = fused_layer
    del virtual_layers[pair_idx + 1]
    if isinstance(container, torch.nn.ModuleList):
        virtual_container = torch.nn.ModuleList(virtual_layers)
    elif isinstance(container, list):
        virtual_container = virtual_layers
    else:
        raise TypeError("Layer container must be ModuleList or list")

    teacher_cache = {"pair": None, "final": None}

    def hook_pair(_module, _inputs, output, _kwargs=None):
        hidden = _extract_hidden(output)
        if hidden is None:
            raise RuntimeError("Failed to extract pair output for global relation rerank.")
        teacher_cache["pair"] = hidden
        return output

    # The hook sits on layer_b, which is absent from the virtual container, so
    # it only records the teacher's pair output.
    handle_pair, has_kwargs_pair = _register_forward_hook(layer_b, hook_pair)
    supports_kwargs = has_kwargs_pair

    score_sum = 0.0
    token_count = 0.0
    num_batches = 0

    try:
        model.eval()
        for batch_idx, batch in enumerate(dataloader):
            if max_batches and batch_idx >= max_batches:
                break

            teacher_cache["pair"] = None

            input_ids = batch[0].to(args.device)
            attention_mask = batch[1].to(args.device) if len(batch) > 1 else None

            with torch.no_grad():
                teacher_outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    use_cache=False,
                )
            teacher_hidden_states = getattr(teacher_outputs, "hidden_states", None)
            if not teacher_hidden_states:
                raise RuntimeError("Teacher forward did not return hidden_states.")
            teacher_final = teacher_hidden_states[-1]
            teacher_pair = teacher_cache["pair"]

            if teacher_pair is None or teacher_final is None:
                raise RuntimeError(
                    "Failed to capture teacher pair/final hidden states for global rerank."
                )

            with torch.no_grad(), _temporary_layers(parent, name, virtual_container):
                fused_outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    use_cache=False,
                )
            fused_hidden_states = getattr(fused_outputs, "hidden_states", None)
            if not fused_hidden_states:
                raise RuntimeError("Fused forward did not return hidden_states.")
            fused_final = fused_hidden_states[-1]

            if fused_final is None:
                raise RuntimeError("Failed to capture fused final hidden state for global rerank.")

            with torch.no_grad():
                teacher_pair_f = teacher_pair.float()
                teacher_rel = F.cosine_similarity(
                    teacher_pair_f, teacher_final.float(), dim=-1, eps=eps
                )
                fused_rel = F.cosine_similarity(
                    teacher_pair_f, fused_final.float(), dim=-1, eps=eps
                )
                rel_change = (teacher_rel - fused_rel).abs()

                if attention_mask is not None:
                    # NOTE(review): assumes the mask broadcasts against the
                    # per-position relation tensor - confirm with callers.
                    mask = attention_mask.to(dtype=torch.float32)
                    batch_tokens = float(mask.sum().item())
                    rel_change = rel_change * mask
                else:
                    batch_tokens = float(rel_change.numel())

                token_count += batch_tokens
                score_sum += float(rel_change.sum().item())
                num_batches += 1
    finally:
        # Always detach the hook from the caller's model, even on failure.
        handle_pair.remove()
        # The fused layer is also referenced by the virtual stack; all three
        # names must go before empty_cache() can actually release its memory.
        del virtual_container, virtual_layers, fused_layer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    denom = token_count if token_count > 0 else 1.0
    score = score_sum / denom
    meta = {
        "num_batches": num_batches,
        "token_count": token_count,
        "metric": "global_rel_change",
        "supports_kwargs": supports_kwargs,
        "fuse_priors": fuse_priors,
    }
    return score, meta
905
+
906
+
907
def select_layer_auto(
    model,
    layers: List[torch.nn.Module],
    dataloader,
    args,
    previous_scores: Optional[List[float]] = None,
    start_index: int = 0,
    exclude_pairs: Optional[Set[int]] = None,
) -> Tuple[int, List[float], Dict[str, object]]:
    """Pick the adjacent layer pair with the lowest fusion score.

    Pair ``i`` denotes ``layers[i]`` and ``layers[i + 1]``. Depending on
    ``args.auto_metric`` the pairs are scored with DWCE, with cosine distance,
    or with a hybrid scheme that ranks pairs by DWCE and re-scores only the
    best ``args.auto_cosine_topk`` of them with cosine distance
    (``hybrid`` / ``hybrid_cosine``) or global relation change
    (``hybrid_global_rel``). Pairs that are excluded, or that fall outside the
    hybrid shortlist, keep a score of ``inf``.

    Args:
        model: Model the layers belong to.
        layers: Ordered list of layers; at least two are required.
        dataloader: Calibration batches handed to the scoring helpers.
        args: Namespace read for ``auto_max_batches``, ``auto_norm``, ``eps``,
            ``device``, ``fisher_mode`` and, optionally, ``auto_metric``,
            ``auto_cosine_topk``, ``auto_dwce_mode`` and
            ``no_head_permute_select``.
        previous_scores: Scores from an earlier call. With the ``dwce`` metric
            the entries before ``start_index`` are reused instead of being
            recomputed; other metrics ignore them.
        start_index: First pair whose previous score is no longer valid;
            clamped to ``[0, num_pairs]``.
        exclude_pairs: Pair indices that must not be selected.

    Returns:
        ``(best_idx, scores, meta)``: the selected pair index, the per-pair
        score list, and a metadata dict with per-pair details.

    Raises:
        SystemExit: On fewer than two layers, an invalid metric or top-k, when
            every pair is excluded, or when no candidate has a finite score.
    """
    num_layers = len(layers)
    if num_layers < 2:
        raise SystemExit("Model must have at least 2 layers for auto selection.")

    hidden_size = _get_hidden_size(model)
    num_pairs = num_layers - 1
    scores: List[float] = [float("inf")] * num_pairs
    meta_per_pair: List[Optional[Dict[str, object]]] = [None] * num_pairs
    supports_kwargs_all = True
    head_permute_select = not bool(getattr(args, "no_head_permute_select", False))
    exclude_set: Set[int] = {
        int(idx)
        for idx in (exclude_pairs or set())
        if isinstance(idx, int) and 0 <= int(idx) < num_pairs
    }

    max_batches = args.auto_max_batches
    start_index = max(0, min(start_index, num_pairs))
    auto_metric = str(getattr(args, "auto_metric", "dwce")).strip().lower()
    if auto_metric == "hybrid":
        auto_metric = "hybrid_cosine"
    if auto_metric not in {
        "dwce",
        "cosine",
        "hybrid_cosine",
        "hybrid_global_rel",
    }:
        raise SystemExit(
            "--auto_metric must be one of: dwce, cosine, hybrid, "
            "hybrid_cosine, hybrid_global_rel"
        )
    auto_cosine_topk = int(getattr(args, "auto_cosine_topk", 3))
    if auto_cosine_topk <= 0:
        raise SystemExit("--auto_cosine_topk must be >= 1")
    print(
        f"[auto] metric={auto_metric}; using "
        f"{('all' if max_batches == 0 else max_batches)} batches "
        "from calibration samples."
    )

    def _finite_or_inf(value) -> float:
        # NaN compares false against everything, which makes min()/sorted()
        # order-dependent; rank every non-finite score as "worst" instead.
        value = float(value)
        return value if math.isfinite(value) else float("inf")

    def _mark_excluded(idx: int) -> None:
        scores[idx] = float("inf")
        meta_per_pair[idx] = {"excluded": True}
        print(f"[auto] skipped excluded pair {idx}-{idx+1}.")

    # Only plain DWCE scores are comparable across calls, so only they are reused.
    reuse_upto = 0
    allow_reuse = auto_metric == "dwce"
    if previous_scores:
        reuse_upto = min(start_index, len(previous_scores), num_pairs) if allow_reuse else 0
        for idx in range(reuse_upto):
            if idx in exclude_set:
                _mark_excluded(idx)
                continue
            scores[idx] = previous_scores[idx]
            meta_per_pair[idx] = {
                "num_batches": 0,
                "token_count": 0.0,
                "norm": args.auto_norm,
                "metric": auto_metric,
                "supports_kwargs": True,
                "reused": True,
            }
            print(f"[auto] reused pair {idx}-{idx+1}: {scores[idx]:.6e}")

    # Everything that was not reused has to be scored from scratch.
    pairs_to_score: List[int] = []
    for idx in range(reuse_upto, num_pairs):
        if idx in exclude_set:
            _mark_excluded(idx)
            continue
        pairs_to_score.append(idx)

    def _score_dwce_for_pair(idx: int) -> Tuple[float, Dict[str, object]]:
        print(f"[auto] building fused pair {idx}-{idx+1} for DWCE...")
        layer_a = layers[idx]
        layer_b = layers[idx + 1]
        dwce_mode = str(getattr(args, "auto_dwce_mode", "separate")).strip().lower()
        if dwce_mode == "shared":
            try:
                return _score_dwce_with_shared_backward(
                    model,
                    layer_a,
                    layer_b,
                    dataloader,
                    device=args.device,
                    fisher_mode=args.fisher_mode,
                    max_batches=max_batches,
                    eps=args.eps,
                    norm=args.auto_norm,
                    hidden_size=hidden_size,
                    enable_head_permute=head_permute_select,
                )
            except _DwceGradCacheOverflow:
                print(
                    "[auto] shared-backward DWCE cache exceeded budget; "
                    "falling back to separate mode."
                )
        fused, fuse_priors = _build_fused_layer_for_pair(
            model,
            layer_a,
            layer_b,
            dataloader,
            device=args.device,
            fisher_mode=args.fisher_mode,
            eps=args.eps,
            hidden_size=hidden_size,
            enable_head_permute=head_permute_select,
        )
        fused.to(args.device)
        fused.eval()
        # DWCE reads gradients of hidden states, so parameters must track grad.
        for param in model.parameters():
            param.requires_grad_(True)
        score, meta = _compute_dwce_for_pair(
            model,
            layer_a,
            layer_b,
            fused,
            dataloader,
            device=args.device,
            max_batches=max_batches,
            eps=args.eps,
            norm=args.auto_norm,
        )
        meta["fuse_priors"] = fuse_priors
        meta["metric"] = "dwce"
        del fused
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return score, meta

    def _score_cosine_for_pair(idx: int) -> Tuple[float, Dict[str, object]]:
        print(f"[auto] scoring cosine for pair {idx}-{idx+1}...")
        return _compute_cosine_for_pair(
            model,
            layers[idx],
            layers[idx + 1],
            dataloader,
            device=args.device,
            max_batches=max_batches,
            eps=args.eps,
        )

    def _score_global_rel_for_pair(idx: int) -> Tuple[float, Dict[str, object]]:
        print(f"[auto] scoring global relation change for pair {idx}-{idx+1}...")
        return _compute_global_rel_change_for_pair(
            model,
            layers,
            idx,
            dataloader,
            args=args,
            max_batches=max_batches,
            eps=args.eps,
        )

    if auto_metric in {"dwce", "cosine"}:
        for idx in pairs_to_score:
            if auto_metric == "dwce":
                score, meta = _score_dwce_for_pair(idx)
            else:
                score, meta = _score_cosine_for_pair(idx)
            supports_kwargs_all = supports_kwargs_all and meta.get("supports_kwargs", True)
            scores[idx] = score
            meta_per_pair[idx] = meta
            print(f"[auto] {auto_metric} pair {idx}-{idx+1}: {score:.6e}")
    else:
        # Hybrid: a DWCE pass ranks all pairs, then only the shortlist is
        # re-scored with the final metric.
        dwce_prefilter: Dict[int, float] = {}
        for idx in pairs_to_score:
            score, meta = _score_dwce_for_pair(idx)
            dwce_prefilter[idx] = score
            supports_kwargs_all = supports_kwargs_all and meta.get("supports_kwargs", True)
            meta_per_pair[idx] = {
                "prefilter_dwce": score,
                "dwce_meta": meta,
                "metric": "hybrid",
            }
            print(f"[auto] hybrid prefilter DWCE pair {idx}-{idx+1}: {score:.6e}")
        ranked = sorted(pairs_to_score, key=lambda i: _finite_or_inf(dwce_prefilter[i]))
        shortlist = ranked[: min(auto_cosine_topk, len(ranked))]
        print(f"[auto] hybrid shortlist (dwce top-{len(shortlist)}): {shortlist}")
        for idx in shortlist:
            if auto_metric == "hybrid_global_rel":
                score, rerank_meta = _score_global_rel_for_pair(idx)
                score_metric = "global_rel_change"
            else:
                score, rerank_meta = _score_cosine_for_pair(idx)
                score_metric = "cosine"
            supports_kwargs_all = supports_kwargs_all and rerank_meta.get(
                "supports_kwargs", True
            )
            scores[idx] = score
            pair_meta = meta_per_pair[idx] or {}
            pair_meta["rerank_meta"] = rerank_meta
            pair_meta["score_metric"] = score_metric
            meta_per_pair[idx] = pair_meta
            print(f"[auto] hybrid {score_metric} pair {idx}-{idx+1}: {score:.6e}")

    if not supports_kwargs_all:
        print(
            "[auto] Warning: forward hooks did not capture kwargs; "
            "fused-layer calls may be approximate."
        )

    print(f"[auto] score summary (metric={auto_metric}, norm={args.auto_norm}):")
    for idx, score in enumerate(scores):
        if idx in exclude_set:
            print(f"[auto] pair {idx}-{idx+1}: excluded")
        elif math.isfinite(float(score)):
            print(f"[auto] pair {idx}-{idx+1}: {score:.6e}")
        else:
            print(f"[auto] pair {idx}-{idx+1}: {score}")

    candidates = [i for i in range(num_pairs) if i not in exclude_set]
    if not candidates:
        raise SystemExit("All pairs are excluded; cannot auto-select a fusion layer.")
    # Select among finite scores only: a NaN in the first candidate would
    # otherwise win min() and abort the run although valid pairs exist.
    finite_candidates = [i for i in candidates if math.isfinite(float(scores[i]))]
    if not finite_candidates:
        raise SystemExit(
            "Auto selection failed: all candidate pairs have non-finite scores "
            "(check --exclude_pairs and data)."
        )
    best_idx = min(finite_candidates, key=lambda i: scores[i])
    best_score = float(scores[best_idx])
    print(f"[auto] Selected layer {best_idx} (score={best_score:.6e})")

    meta = {
        "per_pair": meta_per_pair,
        "supports_kwargs": supports_kwargs_all,
        "max_batches": max_batches,
        "norm": args.auto_norm,
        "metric": auto_metric,
        "cosine_topk": auto_cosine_topk,
        "start_index": start_index,
        "excluded_pairs": sorted(exclude_set),
    }
    return best_idx, scores, meta
src/loratune.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Centralized Alpaca LoRA finetuning for post-pruned models."""
3
+
4
+ import argparse
5
+ import itertools
6
+ import json
7
+ import os
8
+ from types import SimpleNamespace
9
+ from pathlib import Path
10
+
11
+ import torch
12
+ from contextlib import nullcontext
13
+
14
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig
15
+ from transformers.models.auto.configuration_auto import CONFIG_MAPPING
16
+ import ppl_eval
17
+
18
+ from fuse_layers_data import FixedSeqDataset, load_instruction_records
19
+ from fuse_layers_distill import LoRALinear, apply_lora_adapters, merge_lora_adapters
20
+
21
+ try:
22
+ from tqdm import tqdm
23
+ except Exception: # pragma: no cover
24
+ tqdm = None
25
+
26
+
27
def parse_args() -> argparse.Namespace:
    """Build the LoRA finetuning command-line parser and parse ``sys.argv``.

    Returns:
        The parsed namespace. ``--base_model`` and ``--output_dir`` are
        required; every other option has a default.
    """
    parser = argparse.ArgumentParser(description="Run centralized Alpaca LoRA finetuning.")
    add = parser.add_argument

    # --- model and runtime -------------------------------------------------
    add("--base_model", required=True, help="Path or HF model id to finetune")
    add("--output_dir", required=True, help="Directory to save merged model")
    add("--device", default="cuda", help="Training device")
    add(
        "--dtype",
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
        help="Model load/training dtype",
    )
    add("--trust_remote_code", action="store_true", help="Enable trust_remote_code")
    add("--seed", type=int, default=42, help="Random seed")

    # --- instruction data --------------------------------------------------
    add(
        "--instruction_dataset",
        default="yahma/alpaca-cleaned",
        help="HF dataset name for Alpaca-style instruction data",
    )
    add("--instruction_config", default=None, help="Optional dataset config")
    add("--instruction_split", default="train", help="Dataset split")
    # Each record field option defaults to the column of the same name.
    for column in ("instruction", "input", "output"):
        add(f"--instruction_field_{column}", default=column)

    # --- optimisation hyper-parameters and logging cadence ------------------
    typed_options = (
        ("--max_samples", int, 0, "Limit instruction samples (0 = all)"),
        ("--seq_len", int, 1024, "Training sequence length"),
        ("--batch_size", int, 64, "Global batch size"),
        ("--micro_batch_size", int, 4, "Per-step micro-batch size"),
        ("--epochs", float, 1.0, "Training epochs"),
        ("--learning_rate", float, 1e-4, "Learning rate"),
        ("--weight_decay", float, 0.0, "Weight decay"),
        ("--max_grad_norm", float, 1.0, "Gradient clipping norm"),
        ("--log_steps", int, 100, "Log every N optimizer steps"),
        (
            "--save_steps",
            int,
            200,
            "Save LoRA adapter checkpoints every N optimizer steps (0 = disable)",
        ),
    )
    for flag, caster, fallback, description in typed_options:
        add(flag, type=caster, default=fallback, help=description)

    # --- periodic Wikitext-2 perplexity evaluation --------------------------
    add(
        "--no_wikitext2_ppl_on_log",
        dest="wikitext2_ppl_on_log",
        action="store_false",
        default=True,
        help="Disable Wikitext-2 perplexity evaluation at loss log steps",
    )
    for flag, fallback in (
        ("--wikitext2_ppl_seq_len", 128),
        ("--wikitext2_ppl_batch_size", 8),
        ("--wikitext2_ppl_max_batches", None),
    ):
        add(flag, type=int, default=fallback)

    # --- LoRA adapter shape --------------------------------------------------
    add("--lora_rank", type=int, default=8, help="LoRA rank")
    add("--lora_alpha", type=float, default=16.0, help="LoRA alpha")
    add("--lora_dropout", type=float, default=0.0, help="LoRA dropout")
    default_targets = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"]
    add(
        "--lora_target_modules",
        nargs="*",
        default=default_targets,
        help="Linear module suffixes to LoRA-wrap",
    )

    return parser.parse_args()
88
+
89
+
90
def get_dtype(name: str) -> torch.dtype:
    """Translate a dtype name into the matching ``torch.dtype``.

    Args:
        name: One of ``"float32"``, ``"float16"`` or ``"bfloat16"``.

    Returns:
        The torch dtype with that name.

    Raises:
        KeyError: If ``name`` is not one of the supported names.
    """
    supported = ("float32", "float16", "bfloat16")
    if name not in supported:
        raise KeyError(name)
    # The supported names are exactly the attribute names on the torch module.
    return getattr(torch, name)
96
+
97
+
98
def seed_all(seed: int) -> None:
    """Seed the torch random number generators.

    Seeds the default torch generator and, when CUDA is available, the
    generators of all CUDA devices.

    Args:
        seed: Seed value passed to torch.
    """
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
102
+
103
+
104
def normalize_config(config):
    """Repair a model config in place so it is internally consistent.

    * When ``layer_types`` and ``num_hidden_layers`` are both set but disagree
      in length, ``layer_types`` is replaced by a list of its first
      ``num_hidden_layers`` entries.
    * When ``_attn_implementation`` is missing or ``None`` it is set to
      ``"eager"``.

    Args:
        config: Config object carrying the attributes above (all optional).

    Returns:
        The same ``config`` object, after modification.
    """
    per_layer_types = getattr(config, "layer_types", None)
    depth = getattr(config, "num_hidden_layers", None)

    both_present = per_layer_types is not None and depth is not None
    if both_present and len(per_layer_types) != depth:
        config.layer_types = list(per_layer_types[:depth])

    attn_impl = getattr(config, "_attn_implementation", None)
    if attn_impl is None:
        config._attn_implementation = "eager"

    return config
112
+
113
+
114
def load_normalized_config(base_model: str, trust_remote_code: bool):
    """Load a model config, repairing ``layer_types`` before it is instantiated.

    The raw config dict is fixed up first (``layer_types`` truncated to
    ``num_hidden_layers`` entries, ``_attn_implementation`` defaulted to
    ``"eager"``) and only then turned into a config object, so the config
    class is constructed from the already-repaired values.

    Args:
        base_model: Path or hub id passed to ``get_config_dict``.
        trust_remote_code: Forwarded to the transformers loaders.

    Returns:
        The config object after ``normalize_config`` has been applied.
    """
    config_dict, unused_kwargs = PretrainedConfig.get_config_dict(
        base_model, trust_remote_code=trust_remote_code
    )
    layer_types = config_dict.get("layer_types")
    num_hidden_layers = config_dict.get("num_hidden_layers")
    if (
        layer_types is not None
        and num_hidden_layers is not None
        and len(layer_types) != num_hidden_layers
    ):
        config_dict["layer_types"] = list(layer_types[:num_hidden_layers])
    if config_dict.get("_attn_implementation") is None:
        config_dict["_attn_implementation"] = "eager"

    try:
        config_class = CONFIG_MAPPING[config_dict["model_type"]]
    except KeyError:
        # Model types that are missing or not registered with transformers
        # (typically custom architectures loaded via trust_remote_code) have
        # no entry in CONFIG_MAPPING; let AutoConfig resolve them rather than
        # failing with a bare KeyError.
        config = AutoConfig.from_pretrained(base_model, trust_remote_code=trust_remote_code)
        return normalize_config(config)

    config = config_class.from_dict(config_dict, **unused_kwargs)
    return normalize_config(config)
126
+
127
+
128
def validate_local_model_dir(base_path: Path) -> None:
    """Abort when ``base_path`` is a directory without a usable HF model.

    Paths that do not exist or are not directories are accepted silently. An
    existing directory must contain ``config.json`` plus at least one of the
    recognised weight files.

    Args:
        base_path: Candidate local model directory.

    Raises:
        SystemExit: If the directory exists but lacks the config or weights.
    """
    if not (base_path.exists() and base_path.is_dir()):
        return

    weight_file_names = (
        "model.safetensors",
        "model.safetensors.index.json",
        "pytorch_model.bin",
        "pytorch_model.bin.index.json",
    )
    config_present = (base_path / "config.json").is_file()
    weights_present = False
    for file_name in weight_file_names:
        if (base_path / file_name).is_file():
            weights_present = True
            break

    if not (config_present and weights_present):
        raise SystemExit(
            "Local --base_model points to an incomplete HF model directory: "
            f"{base_path}. Expected at least config.json and model weights. "
            "Set --base_model/BASE_MODEL to a saved HF model directory."
        )
150
+
151
+
152
def load_base_artifacts(args: argparse.Namespace):
    """Load the model and tokenizer named by ``args.base_model``.

    Two sources are supported:

    * a single ``.bin`` file holding a dict with ``model`` and ``tokenizer``
      entries, loaded with ``torch.load``;
    * anything else, which is handed to the transformers ``from_pretrained``
      loaders after local-directory validation.

    In both cases a tokenizer without a pad token gets its EOS token (or,
    failing that, its UNK token) as pad token.

    Args:
        args: Namespace read for ``base_model``, ``trust_remote_code`` and
            ``dtype``.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        SystemExit: If a ``.bin`` checkpoint is not a dict with both entries.
    """

    def _with_pad_token(tok):
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token or tok.unk_token
        return tok

    source = Path(args.base_model)
    is_bundled_checkpoint = source.is_file() and source.suffix == ".bin"

    if not is_bundled_checkpoint:
        validate_local_model_dir(source)
        tokenizer = _with_pad_token(
            AutoTokenizer.from_pretrained(
                args.base_model, trust_remote_code=args.trust_remote_code
            )
        )
        config = load_normalized_config(
            args.base_model, trust_remote_code=args.trust_remote_code
        )
        model = AutoModelForCausalLM.from_pretrained(
            args.base_model,
            config=config,
            torch_dtype=get_dtype(args.dtype),
            trust_remote_code=args.trust_remote_code,
        )
        return model, tokenizer

    # NOTE(review): weights_only=False unpickles arbitrary objects; only point
    # --base_model at .bin checkpoints from a trusted source.
    bundle = torch.load(str(source), map_location="cpu", weights_only=False)
    is_valid_bundle = isinstance(bundle, dict) and "model" in bundle and "tokenizer" in bundle
    if not is_valid_bundle:
        raise SystemExit("Expected a .bin checkpoint dict with `model` and `tokenizer` entries.")
    return bundle["model"], _with_pad_token(bundle["tokenizer"])
176
+
177
+
178
def build_training_loader(args: argparse.Namespace, tokenizer) -> torch.utils.data.DataLoader:
    """Create the shuffled micro-batch loader over the instruction data.

    Args:
        args: Namespace read for ``max_samples``, ``seq_len`` and
            ``micro_batch_size``; it is also passed through to
            ``load_instruction_records``.
        tokenizer: Tokenizer handed to ``FixedSeqDataset``.

    Returns:
        A ``DataLoader`` yielding ``micro_batch_size`` items per batch.

    Raises:
        SystemExit: If no instruction records could be loaded.
    """
    # Non-positive limits are passed on as 0.
    sample_limit = max(args.max_samples, 0)
    records = load_instruction_records(args, sample_limit)
    if not records:
        raise SystemExit("No instruction records were loaded.")
    return torch.utils.data.DataLoader(
        FixedSeqDataset(records, tokenizer, args.seq_len),
        batch_size=args.micro_batch_size,
        shuffle=True,
    )
185
+
186
+
187
def save_lora_adapters(
    model: torch.nn.Module, args: argparse.Namespace, subdir: str = "lora_adapter"
) -> str:
    """Write the LoRA weights and their description below ``args.output_dir``.

    Two files are produced in ``<output_dir>/<subdir>``:
    ``adapter_model.bin`` (the ``lora_A`` / ``lora_B`` weights of every
    ``LoRALinear`` module, detached and moved to CPU) and
    ``adapter_config.json`` (run settings plus per-module metadata).

    Args:
        model: Model whose ``LoRALinear`` sub-modules are exported.
        args: Namespace read for ``output_dir``, ``base_model``, the LoRA
            settings, ``batch_size``, ``micro_batch_size`` and
            ``grad_accum_steps``.
        subdir: Directory name (relative to ``output_dir``) to write into.

    Returns:
        The adapter directory path.
    """
    adapter_dir = os.path.join(args.output_dir, subdir)
    os.makedirs(adapter_dir, exist_ok=True)

    lora_layers = [
        (qualified_name, layer)
        for qualified_name, layer in model.named_modules()
        if isinstance(layer, LoRALinear)
    ]

    adapter_modules = {
        qualified_name: {
            "rank": layer.rank,
            "alpha": layer.alpha,
            "scaling": layer.scaling,
            "dropout": getattr(layer.dropout, "p", 0.0),
            "base_layer_class": type(layer.base).__name__,
            "in_features": layer.base.in_features,
            "out_features": layer.base.out_features,
        }
        for qualified_name, layer in lora_layers
    }

    adapter_state = {}
    for qualified_name, layer in lora_layers:
        for factor in ("lora_A", "lora_B"):
            weight = getattr(layer, factor).weight
            adapter_state[f"{qualified_name}.{factor}.weight"] = weight.detach().cpu()

    torch.save(adapter_state, os.path.join(adapter_dir, "adapter_model.bin"))

    adapter_config = {
        "base_model": args.base_model,
        "lora_rank": args.lora_rank,
        "lora_alpha": args.lora_alpha,
        "lora_dropout": args.lora_dropout,
        "lora_target_modules": list(args.lora_target_modules),
        "batch_size": args.batch_size,
        "micro_batch_size": args.micro_batch_size,
        "grad_accum_steps": args.grad_accum_steps,
        "modules": adapter_modules,
    }
    config_path = os.path.join(adapter_dir, "adapter_config.json")
    with open(config_path, "w", encoding="utf-8") as handle:
        json.dump(adapter_config, handle, indent=2)
    return adapter_dir
228
+
229
+
230
def prepare_wikitext2_eval(args: argparse.Namespace, model, tokenizer):
    """Prepare the Wikitext-2 perplexity dataloaders, if evaluation is enabled.

    Args:
        args: Namespace read for ``wikitext2_ppl_on_log``,
            ``wikitext2_ppl_seq_len``, ``wikitext2_ppl_batch_size`` and
            ``seed``.
        model: Forwarded to ``ppl_eval.prepare_ppl_dataloaders``.
        tokenizer: Forwarded to ``ppl_eval.prepare_ppl_dataloaders``.

    Returns:
        ``None`` when ``args.wikitext2_ppl_on_log`` is falsy, otherwise
        whatever ``ppl_eval.prepare_ppl_dataloaders`` returns for the
        ``wikitext-2-raw-v1`` test split.
    """
    if args.wikitext2_ppl_on_log:
        loader_options = {
            "tokenizer": tokenizer,
            "datasets": ["wikitext"],
            "configs": ["wikitext-2-raw-v1"],
            "split": "test",
            "text_field": None,
            "num_samples": 0,
            "seq_len": args.wikitext2_ppl_seq_len,
            "batch_size": args.wikitext2_ppl_batch_size,
            "seed": args.seed,
            "shuffle": False,
            "model_family": "auto",
            "add_bos": "auto",
            "cache_dir": None,
            "num_workers": 0,
            "model": model,
        }
        return ppl_eval.prepare_ppl_dataloaders(**loader_options)
    return None
250
+
251
+
252
def train(model: torch.nn.Module, dataloader, args: argparse.Namespace, wikitext2_eval_dataloaders=None) -> dict:
    """Finetune LoRA adapters on ``model`` and merge them into the base weights.

    LoRA adapters are attached, trained with AdamW using gradient accumulation
    (``args.grad_accum_steps`` micro-batches per optimizer step), saved, and
    finally merged via ``merge_lora_adapters``. The loss is next-token
    cross-entropy averaged over the positions kept by the shifted attention
    mask.

    Args:
        model: Model to finetune; modified in place.
        dataloader: Yields batches where ``batch[0]`` is the input ids and
            ``batch[1]`` the attention mask.
        args: Namespace read for the LoRA settings, optimizer settings,
            ``device``, ``dtype``, ``epochs``, ``grad_accum_steps``,
            ``max_grad_norm``, ``log_steps``, ``save_steps``,
            ``wikitext2_ppl_max_batches`` and ``output_dir``.
        wikitext2_eval_dataloaders: When not ``None``, perplexity is evaluated
            on these at every logging step.

    Returns:
        Dict with ``adapter_dir``, ``optimizer_steps``, ``seen_batches``,
        ``last_loss`` (loss of the most recent micro-batch, or ``None``) and
        ``wikitext2_ppl_history``.
    """
    lora_args = SimpleNamespace(
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        lora_target_modules=args.lora_target_modules,
        lora_respect_exclude_pairs=False,
        layer_path=None,
        exclude_pairs=None,
    )
    lora_modules = apply_lora_adapters(model, lora_args)
    lora_params = [param for module in lora_modules for param in module.lora_parameters()]

    optimizer = torch.optim.AdamW(
        lora_params,
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
    )
    model.train()

    device_type = torch.device(args.device).type
    amp_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}.get(args.dtype)
    # Autocast is only used on CUDA; loss scaling is only needed for float16.
    use_amp = amp_dtype is not None and device_type == "cuda"
    use_scaler = use_amp and amp_dtype == torch.float16
    scaler = torch.cuda.amp.GradScaler() if use_scaler else None
    accum_steps = max(args.grad_accum_steps, 1)

    # One entry per pass over the data: None = full epoch, int = batch budget
    # for the trailing fractional epoch.
    full_epochs = int(args.epochs)
    fractional = args.epochs - full_epochs
    epoch_plan = [None] * full_epochs
    if fractional > 1e-8:
        frac_batches = max(1, int(round(fractional * len(dataloader))))
        epoch_plan.append(frac_batches)

    def _apply_optimizer_step() -> None:
        # Clip (on unscaled gradients), step, and reset the accumulated grads.
        if args.max_grad_norm is not None:
            if use_scaler:
                scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(lora_params, args.max_grad_norm)
        if use_scaler:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    optimizer.zero_grad(set_to_none=True)
    optimizer_step = 0
    seen_batches = 0
    last_loss = None
    ppl_history = []

    for epoch_idx, max_batches in enumerate(epoch_plan, start=1):
        iterator = dataloader if max_batches is None else itertools.islice(dataloader, max_batches)
        if tqdm is not None:
            iterator = tqdm(iterator, desc=f"LoRA epoch {epoch_idx}", unit="batch", total=max_batches)
        for batch in iterator:
            input_ids = batch[0].to(args.device)
            attention_mask = batch[1].to(args.device)

            autocast_ctx = (
                torch.autocast(device_type=device_type, dtype=amp_dtype)
                if use_amp
                else nullcontext()
            )
            with autocast_ctx:
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False)
                # Position t predicts token t + 1, so drop the last logit and
                # the first label/mask entry.
                logits = outputs.logits[:, :-1, :].contiguous()
                labels = input_ids[:, 1:].contiguous()
                mask = attention_mask[:, 1:].contiguous()
                ce_flat = torch.nn.functional.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    labels.view(-1),
                    reduction="none",
                )
                denom = mask.sum()
                if denom.item() == 0:
                    # Nothing to learn from a fully masked batch.
                    continue
                loss = (ce_flat * mask.reshape(-1).to(ce_flat.dtype)).sum() / denom

            last_loss = float(loss.detach().item())
            scaled_loss = loss / accum_steps
            if use_scaler:
                scaler.scale(scaled_loss).backward()
            else:
                scaled_loss.backward()

            seen_batches += 1
            if seen_batches % accum_steps != 0:
                continue

            _apply_optimizer_step()
            optimizer_step += 1

            if args.log_steps and optimizer_step % args.log_steps == 0:
                print(f"[loratune] step={optimizer_step} loss={last_loss:.6f}")
                if wikitext2_eval_dataloaders is not None:
                    prev_mode = model.training
                    model.eval()
                    ppl_results = ppl_eval.evaluate_ppl_dataloaders(
                        model,
                        wikitext2_eval_dataloaders,
                        args.device,
                        max_batches=args.wikitext2_ppl_max_batches,
                    )
                    ppl_history.append({"step": optimizer_step, "ppl": ppl_results})
                    print(f"[loratune] ppl step={optimizer_step} {ppl_results}")
                    if prev_mode:
                        model.train()

            if args.save_steps and optimizer_step % args.save_steps == 0:
                checkpoint_dir = save_lora_adapters(
                    model,
                    args,
                    subdir=os.path.join("checkpoints", f"step_{optimizer_step}"),
                )
                print(f"[loratune] saved adapter checkpoint to {checkpoint_dir}")

    # Gradients of a trailing, incomplete accumulation window would otherwise
    # be discarded; apply them so the last micro-batches still train the
    # adapters. They stay scaled by 1 / accum_steps, so this final step is
    # based on a smaller-magnitude gradient than a full window.
    pending_batches = seen_batches % accum_steps
    if pending_batches:
        _apply_optimizer_step()
        optimizer_step += 1
        print(
            f"[loratune] step={optimizer_step} applied "
            f"{pending_batches} trailing micro-batch(es)"
        )

    adapter_dir = save_lora_adapters(model, args)
    merge_lora_adapters(model)
    return {
        "adapter_dir": adapter_dir,
        "optimizer_steps": optimizer_step,
        "seen_batches": seen_batches,
        "last_loss": last_loss,
        "wikitext2_ppl_history": ppl_history,
    }
380
+
381
+
382
def main() -> None:
    """Run LoRA finetuning end to end and write the results to ``output_dir``.

    Parses the CLI, derives ``args.grad_accum_steps`` from the global and
    micro batch sizes, loads the model and tokenizer, trains, then saves the
    model, the tokenizer and a ``loratune_metrics.json`` summary.

    Raises:
        SystemExit: If the batch-size arguments are invalid.
    """
    args = parse_args()
    if args.batch_size < 1:
        raise SystemExit("--batch_size must be >= 1")
    if args.micro_batch_size < 1:
        raise SystemExit("--micro_batch_size must be >= 1")
    args.grad_accum_steps = args.batch_size // args.micro_batch_size
    if args.grad_accum_steps < 1:
        raise SystemExit("--batch_size must be >= --micro_batch_size")

    # Integer division silently shrinks the global batch when the sizes do not
    # divide evenly; tell the user what is actually being used.
    effective_batch_size = args.grad_accum_steps * args.micro_batch_size
    if effective_batch_size != args.batch_size:
        print(
            f"[loratune] Warning: --batch_size {args.batch_size} is not a multiple of "
            f"--micro_batch_size {args.micro_batch_size}; "
            f"effective batch size is {effective_batch_size}."
        )

    seed_all(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    model, tokenizer = load_base_artifacts(args)
    if args.dtype != "float32":
        model = model.to(get_dtype(args.dtype))
    model.to(args.device)

    dataloader = build_training_loader(args, tokenizer)
    wikitext2_eval_dataloaders = prepare_wikitext2_eval(args, model, tokenizer)
    metrics = train(model, dataloader, args, wikitext2_eval_dataloaders=wikitext2_eval_dataloaders)

    model.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    summary = {
        "base_model": args.base_model,
        "instruction_dataset": args.instruction_dataset,
        "seq_len": args.seq_len,
        "batch_size": args.batch_size,
        "micro_batch_size": args.micro_batch_size,
        "grad_accum_steps": args.grad_accum_steps,
        "epochs": args.epochs,
        "learning_rate": args.learning_rate,
        "save_steps": args.save_steps,
        "lora_rank": args.lora_rank,
        "lora_alpha": args.lora_alpha,
        "lora_dropout": args.lora_dropout,
        **metrics,
    }
    metrics_path = os.path.join(args.output_dir, "loratune_metrics.json")
    with open(metrics_path, "w", encoding="utf-8") as handle:
        json.dump(summary, handle, indent=2)


if __name__ == "__main__":
    main()
src/loratune_config.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Configuration helpers for centralized LoRA finetuning."""
3
+
4
+ from __future__ import annotations
5
+
6
+ from dataclasses import asdict, dataclass, field
7
+ from types import SimpleNamespace
8
+ from typing import Any, Dict, List, Optional
9
+
10
+
11
@dataclass
class LoRATuneConfig:
    """Structured config matching the current loratune.py CLI surface.

    ``grad_accum_steps`` is not a stored field; it is derived from
    ``batch_size`` and ``micro_batch_size`` on every access.
    """

    # Model / runtime settings.
    base_model: str = ""
    output_dir: str = ""
    device: str = "cuda"
    dtype: str = "bfloat16"
    trust_remote_code: bool = False
    seed: int = 42

    # Instruction dataset selection and training hyperparameters.
    instruction_dataset: str = "tatsu-lab/alpaca"
    instruction_config: Optional[str] = None
    instruction_split: str = "train"
    instruction_field_instruction: str = "instruction"
    instruction_field_input: str = "input"
    instruction_field_output: str = "output"
    max_samples: int = 0
    seq_len: int = 1024
    batch_size: int = 64
    micro_batch_size: int = 4
    epochs: float = 1.0
    learning_rate: float = 1e-4
    weight_decay: float = 0.0
    max_grad_norm: float = 1.0
    log_steps: int = 100

    # WikiText-2 perplexity settings.
    wikitext2_ppl_on_log: bool = True
    wikitext2_ppl_seq_len: int = 128
    wikitext2_ppl_batch_size: int = 8
    wikitext2_ppl_max_batches: Optional[int] = None

    # LoRA adapter settings.
    lora_rank: int = 8
    lora_alpha: float = 16.0
    lora_dropout: float = 0.0
    # default_factory gives every instance its own list.
    lora_target_modules: List[str] = field(
        default_factory=lambda: [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "down_proj",
            "up_proj",
        ]
    )

    @property
    def grad_accum_steps(self) -> int:
        """Return ``batch_size // micro_batch_size`` after validating both.

        Raises:
            ValueError: If either size is below 1 or ``batch_size`` is
                smaller than ``micro_batch_size``.
        """
        if self.batch_size < 1:
            raise ValueError("batch_size must be >= 1")
        if self.micro_batch_size < 1:
            raise ValueError("micro_batch_size must be >= 1")
        if self.batch_size < self.micro_batch_size:
            raise ValueError("batch_size must be >= micro_batch_size")
        return self.batch_size // self.micro_batch_size

    def validate(self) -> "LoRATuneConfig":
        """Check batch sizes and required paths; return ``self`` for chaining.

        Raises:
            ValueError: If the batch sizes are inconsistent, or
                ``base_model`` / ``output_dir`` is empty.
        """
        # Accessing the property runs the batch-size checks.
        _ = self.grad_accum_steps
        if not self.base_model:
            raise ValueError("base_model must be set")
        if not self.output_dir:
            raise ValueError("output_dir must be set")
        return self

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a dict, plus the derived ``grad_accum_steps``."""
        data = asdict(self)
        data["grad_accum_steps"] = self.grad_accum_steps
        return data

    def to_namespace(self) -> SimpleNamespace:
        """Return the ``to_dict()`` payload as an attribute-style namespace."""
        return SimpleNamespace(**self.to_dict())

    @classmethod
    def from_dict(cls, values: Dict[str, Any]) -> "LoRATuneConfig":
        """Build a config from a mapping of field names to values.

        The derived ``grad_accum_steps`` key written by ``to_dict()`` is
        ignored, so ``from_dict(cfg.to_dict())`` round-trips. Any other
        unknown key still raises ``TypeError``. ``values`` is not mutated.
        """
        init_kwargs = {
            key: value
            for key, value in values.items()
            if key != "grad_accum_steps"
        }
        return cls(**init_kwargs)
src/ppl_eval.py ADDED
@@ -0,0 +1,576 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Perplexity evaluation for causal LMs on HF datasets or provided text."""
3
+
4
+ import argparse
5
+ import json
6
+ import math
7
+ import os
8
+ from typing import Dict, Iterable, List, Optional
9
+
10
+ import torch
11
+
12
+ try:
13
+ from datasets import load_dataset
14
+ except Exception: # pragma: no cover - optional dependency
15
+ load_dataset = None
16
+
17
+ try:
18
+ from transformers import AutoModelForCausalLM, AutoTokenizer
19
+ except Exception as exc: # pragma: no cover - fail early with clear error
20
+ raise SystemExit("transformers is required: pip install transformers") from exc
21
+
22
+ try:
23
+ from tqdm import tqdm
24
+ except Exception: # pragma: no cover - optional dependency
25
+ tqdm = None
26
+
27
+
28
def _tqdm_enabled() -> bool:
    """Return False when the environment asks for progress bars to be off.

    ``DISABLE_TQDM`` takes precedence; ``TQDM_DISABLE`` is consulted only
    when ``DISABLE_TQDM`` is unset. Values "1", "true", "yes" and "on"
    (case-insensitive, surrounding whitespace ignored) disable the bar.
    """
    env = os.environ
    fallback = env.get("TQDM_DISABLE", "0")
    raw_flag = env.get("DISABLE_TQDM", fallback)
    disabled = raw_flag.strip().lower() in {"1", "true", "yes", "on"}
    return not disabled
31
+
32
+
33
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the perplexity evaluation script.

    Returns:
        The parsed namespace. ``--model`` is the only required option.
    """
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_table = [
        ("--model", dict(required=True, help="HF model id or local path")),
        (
            "--dataset",
            dict(action="append", default=[], help="HF dataset name (repeatable)."),
        ),
        (
            "--dataset_config",
            dict(
                action="append",
                default=[],
                help="Optional dataset config (repeatable or single shared config).",
            ),
        ),
        (
            "--dataset_split",
            dict(default="test", help="Dataset split to use (default: test)"),
        ),
        (
            "--dataset_text_field",
            dict(
                default=None,
                help="Text field in dataset (default: auto-detect, applies to all datasets)",
            ),
        ),
        (
            "--text",
            dict(
                action="append",
                default=[],
                help="Inline text samples (can pass multiple)",
            ),
        ),
        (
            "--text_file",
            dict(default=None, help="Path to a text file for evaluation data"),
        ),
        (
            "--num_samples",
            dict(
                type=int,
                default=0,
                help="Number of token sequences to use per dataset (0 = all)",
            ),
        ),
        ("--seq_len", dict(type=int, default=2048, help="Sequence length")),
        ("--batch_size", dict(type=int, default=2, help="Batch size")),
        (
            "--max_batches",
            dict(
                type=int,
                default=None,
                help="Optional max number of batches to evaluate per dataset",
            ),
        ),
        (
            "--model_family",
            dict(
                type=str,
                choices=["auto", "llama", "qwen"],
                default="auto",
                help="Model family for BOS handling",
            ),
        ),
        (
            "--add_bos",
            dict(
                type=str,
                choices=["auto", "always", "never"],
                default="auto",
                help="Whether to prepend BOS to each sample",
            ),
        ),
        (
            "--device",
            dict(default=default_device, help="Device for model + compute"),
        ),
        (
            "--dtype",
            dict(
                default="auto",
                choices=["auto", "float32", "float16", "bfloat16"],
                help="Model dtype",
            ),
        ),
        ("--seed", dict(type=int, default=0, help="Random seed for shuffling")),
        (
            "--shuffle",
            dict(action="store_true", help="Shuffle dataset before sampling"),
        ),
        ("--num_workers", dict(type=int, default=0, help="DataLoader workers")),
        (
            "--cache_dir",
            dict(default=None, help="Optional datasets cache directory"),
        ),
        (
            "--trust_remote_code",
            dict(action="store_true", help="Allow custom model code from hub"),
        ),
        ("--output", dict(default=None, help="Optional JSON output path")),
    ]

    parser = argparse.ArgumentParser(
        description="Compute perplexity for a causal LM on one or more datasets."
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()
144
+
145
+
146
def _normalize_config(config: Optional[str]) -> Optional[str]:
    """Map placeholder config names to ``None``.

    "none", "null" and "-" (case-insensitive, surrounding whitespace
    ignored) mean "no config"; any other string is returned unchanged.
    """
    placeholders = {"none", "null", "-"}
    if config is not None and config.strip().lower() not in placeholders:
        return config
    return None
152
+
153
+
154
def _expand_dataset_configs(
    datasets: List[str], configs: List[str]
) -> List[Optional[str]]:
    """Pair every dataset with a normalized config name.

    No configs yields ``None`` for every dataset; a single config is shared
    by all datasets; otherwise the counts must match one-to-one.

    Raises:
        SystemExit: If the number of configs matches none of those cases.
    """
    if not configs:
        return [None] * len(datasets)

    n_datasets = len(datasets)
    n_configs = len(configs)
    if n_configs == 1 and n_datasets > 1:
        shared = _normalize_config(configs[0])
        return [shared for _ in range(n_datasets)]
    if n_configs == n_datasets:
        return list(map(_normalize_config, configs))
    raise SystemExit(
        "Provide zero, one, or matching-count --dataset_config values."
    )
166
+
167
+
168
def guess_text_field(dataset) -> str:
    """Pick the column most likely to hold raw text.

    Prefers a column named "text", then the first entry of
    ``column_names``, then the first key of ``features``. Falls back to
    "text" when neither attribute yields a name.
    """
    columns = getattr(dataset, "column_names", None)
    if columns:
        return "text" if "text" in columns else columns[0]

    if hasattr(dataset, "features"):
        feature_names = list(dataset.features.keys())
        if "text" in feature_names:
            return "text"
        if feature_names:
            return feature_names[0]

    return "text"
180
+
181
+
182
def _infer_model_family(model) -> str:
    """Classify a model as "qwen", "llama" or "unknown" from its config.

    Looks for the family name (case-insensitive) in ``config.model_type``
    and in the entries of ``config.architectures``. "qwen" is checked
    before "llama".

    Args:
        model: Any object; a missing ``config`` attribute yields "unknown".

    Returns:
        "qwen", "llama" or "unknown".
    """
    config = getattr(model, "config", None)
    model_type = str(getattr(config, "model_type", "")).lower()

    # ``architectures`` may be present but set to None; treat that as empty
    # instead of failing while iterating it.
    architectures = getattr(config, "architectures", None) or []
    if isinstance(architectures, str):
        # A bare string would otherwise be iterated character by character.
        architectures = [architectures]
    arch_lower = " ".join(str(name).lower() for name in architectures)

    for family in ("qwen", "llama"):
        if family in model_type or family in arch_lower:
            return family
    return "unknown"
191
+
192
+
193
def _resolve_add_bos(setting: str, model_family: str, tokenizer) -> bool:
    """Decide whether a BOS token should be prepended to each sample.

    Resolution order: an explicit "always"/"never" setting, then the
    per-family default (llama: yes, qwen: no), then the tokenizer's
    ``add_bos_token`` attribute, then ``init_kwargs["add_bos_token"]``.
    Returns False when nothing applies.
    """
    explicit = {"always": True, "never": False}
    if setting in explicit:
        return explicit[setting]

    family_default = {"llama": True, "qwen": False}
    if model_family in family_default:
        return family_default[model_family]

    if hasattr(tokenizer, "add_bos_token"):
        return bool(tokenizer.add_bos_token)

    tokenizer_kwargs = getattr(tokenizer, "init_kwargs", None)
    if isinstance(tokenizer_kwargs, dict) and "add_bos_token" in tokenizer_kwargs:
        return bool(tokenizer_kwargs["add_bos_token"])
    return False
208
+
209
+
210
def build_token_chunks(
    texts: Iterable[str],
    tokenizer,
    seq_len: int,
    num_samples: int,
    add_bos: bool = False,
) -> List[torch.Tensor]:
    """Tokenize texts into one stream and cut it into fixed-length chunks.

    Token ids of consecutive texts are concatenated, so a chunk may span
    several texts. Leftover tokens that do not fill a whole chunk are
    dropped.

    Args:
        texts: Raw text samples, consumed in order.
        tokenizer: Object providing ``encode(text, add_special_tokens=False)``
            and a ``bos_token_id`` attribute.
        seq_len: Number of tokens per chunk; must be >= 1.
        num_samples: Maximum number of chunks. Zero or a negative value
            means "no limit", matching the CLI's "0 = all" wording.
        add_bos: Prepend ``tokenizer.bos_token_id`` to every text when that
            id is not None.

    Returns:
        A list of 1-D ``torch.long`` tensors, each of length ``seq_len``.

    Raises:
        ValueError: If ``seq_len`` is below 1 (which would otherwise emit
            empty chunks without consuming any tokens).
    """
    if seq_len < 1:
        raise ValueError("seq_len must be >= 1")

    limit = num_samples if num_samples > 0 else None
    bos_id = tokenizer.bos_token_id if add_bos else None

    chunks: List[torch.Tensor] = []
    buffer: List[int] = []
    for text in texts:
        ids = tokenizer.encode(text, add_special_tokens=False)
        if bos_id is not None:
            ids = [bos_id] + ids
        if not ids:
            continue
        buffer.extend(ids)

        # Walk the buffer with an offset and trim it once afterwards, rather
        # than re-slicing the whole remainder for every emitted chunk.
        start = 0
        while len(buffer) - start >= seq_len and (
            limit is None or len(chunks) < limit
        ):
            window = buffer[start : start + seq_len]
            chunks.append(torch.tensor(window, dtype=torch.long))
            start += seq_len
        if start:
            del buffer[:start]

        if limit is not None and len(chunks) >= limit:
            break
    return chunks
233
+
234
+
235
def get_dtype(dtype: str):
    """Translate a dtype name into a torch dtype.

    "auto" maps to ``None``; "float16" and "bfloat16" map to the matching
    torch dtype; every other name falls back to ``torch.float32``.
    """
    if dtype == "auto":
        return None
    half_precision = {"float16": torch.float16, "bfloat16": torch.bfloat16}
    return half_precision.get(dtype, torch.float32)
243
+
244
+
245
def compute_ppl(model, dataloader, device: str, max_batches: Optional[int]) -> float:
    """Compute token-level perplexity of ``model`` over ``dataloader``.

    Each batch is either a dict with an "input_ids" entry or a sequence
    whose first element holds the token ids. The model is put into eval
    mode and run without gradients. The summed next-token cross-entropy is
    divided by the number of predicted tokens and exponentiated.

    Args:
        model: Callable as ``model(input_ids=...)`` returning ``.logits``.
        dataloader: Iterable of batches.
        device: Device the token ids are moved to.
        max_batches: Stop after this many batches; ``None`` means no limit.

    Raises:
        RuntimeError: If no tokens were processed.
    """
    model.eval()
    total_nll = 0.0
    predicted_tokens = 0

    show_progress = tqdm is not None and _tqdm_enabled()
    batches = tqdm(dataloader, desc="PPL", unit="batch") if show_progress else dataloader

    with torch.no_grad():
        for batch_number, batch in enumerate(batches, start=1):
            tokens = batch["input_ids"] if isinstance(batch, dict) else batch[0]
            tokens = tokens.to(device)

            logits = model(input_ids=tokens).logits
            # Position t predicts token t + 1, so drop the last logit and
            # the first label.
            predictions = logits[:, :-1, :].contiguous()
            targets = tokens[:, 1:].contiguous()
            vocab_size = predictions.size(-1)
            batch_nll = torch.nn.functional.cross_entropy(
                predictions.view(-1, vocab_size),
                targets.view(-1),
                reduction="sum",
            )

            total_nll += float(batch_nll.item())
            predicted_tokens += targets.numel()
            if max_batches is not None and batch_number >= max_batches:
                break

    if predicted_tokens == 0:
        raise RuntimeError("No tokens processed; check evaluation inputs.")

    return math.exp(total_nll / predicted_tokens)
276
+
277
+
278
def _load_lm_dataset(
    tokenizer,
    dataset_name: str,
    config: Optional[str],
    split: str,
    text_field: Optional[str],
    seq_len: int,
    add_bos: bool,
    cache_dir: Optional[str],
    trust_remote_code: bool = True,
):
    """Load a HF dataset and pack it into fixed-length token sequences.

    Rows whose text field is not a non-empty string are dropped. The rest
    are tokenized without special tokens, optionally prefixed with BOS,
    concatenated within each map batch of 1000 rows and cut into
    ``seq_len``-sized pieces. Tokens that do not fill a whole piece at the
    end of a map batch are discarded.

    Args:
        tokenizer: Callable tokenizer exposing ``bos_token_id``.
        dataset_name: Dataset name passed to ``load_dataset``.
        config: Optional dataset config name.
        split: Dataset split to load.
        text_field: Column holding the text; auto-detected when falsy.
        seq_len: Length of every packed sequence.
        add_bos: Prepend BOS to every row when the tokenizer has a BOS id.
        cache_dir: Optional datasets cache directory.
        trust_remote_code: Forwarded to ``load_dataset``. Defaults to True
            to keep the previous behavior.

    Returns:
        A dataset with a single torch-formatted "input_ids" column.
    """
    # NOTE(security): with trust_remote_code=True, dataset loading scripts
    # from the hub may run locally. Pass False for untrusted datasets.
    dataset = load_dataset(
        dataset_name,
        config,
        split=split,
        trust_remote_code=trust_remote_code,
        cache_dir=cache_dir,
    )

    field = text_field or guess_text_field(dataset)

    def is_valid_text(example) -> bool:
        # Keep only rows with a non-blank string in the text column.
        value = example.get(field)
        return isinstance(value, str) and value.strip() != ""

    dataset = dataset.filter(is_valid_text, desc=f"filter-{dataset_name}")

    def tokenize_fn(examples):
        # Batched tokenization; BOS is added manually when requested.
        tokenized = tokenizer(
            examples[field],
            add_special_tokens=False,
            return_attention_mask=False,
        )
        if add_bos and tokenizer.bos_token_id is not None:
            tokenized["input_ids"] = [
                [tokenizer.bos_token_id] + ids for ids in tokenized["input_ids"]
            ]
        return tokenized

    tokenized = dataset.map(
        tokenize_fn,
        batched=True,
        remove_columns=dataset.column_names,
        desc=f"tokenize-{dataset_name}",
    )

    def group_texts(examples):
        # Concatenate the batch and split it into full seq_len windows.
        concatenated = []
        for ids in examples["input_ids"]:
            concatenated.extend(ids)
        total_length = (len(concatenated) // seq_len) * seq_len
        if total_length == 0:
            return {"input_ids": []}
        return {
            "input_ids": [
                concatenated[i : i + seq_len] for i in range(0, total_length, seq_len)
            ]
        }

    lm_dataset = tokenized.map(
        group_texts,
        batched=True,
        batch_size=1000,
        remove_columns=tokenized.column_names,
        desc=f"group-{dataset_name}",
    )
    lm_dataset.set_format(type="torch", columns=["input_ids"])
    return lm_dataset
345
+
346
+
347
def prepare_ppl_dataloaders(
    tokenizer,
    datasets: List[str],
    configs: List[Optional[str]],
    split: str,
    text_field: Optional[str],
    num_samples: int,
    seq_len: int,
    batch_size: int,
    seed: int,
    shuffle: bool,
    model_family: str = "auto",
    add_bos: str = "auto",
    cache_dir: Optional[str] = None,
    num_workers: int = 0,
    model=None,
) -> Dict[str, torch.utils.data.DataLoader]:
    """Build one evaluation DataLoader per (dataset, config) pair.

    Args:
        tokenizer: Tokenizer used to pack the datasets.
        datasets: Dataset names, paired positionally with ``configs``.
        configs: Config name (or ``None``) for each dataset.
        split: Dataset split to load.
        text_field: Text column override; auto-detected when falsy.
        num_samples: Keep at most this many sequences per dataset; 0 keeps all.
        seq_len: Length of every packed sequence.
        batch_size: DataLoader batch size.
        seed: Base shuffle seed; dataset ``idx`` uses ``seed + idx``.
        shuffle: Shuffle each dataset before sub-sampling.
        model_family: "auto", or an explicit family name for BOS handling.
        add_bos: "auto", "always" or "never".
        cache_dir: Optional datasets cache directory.
        num_workers: DataLoader worker count.
        model: Needed only to infer the family when ``model_family`` is "auto".

    Returns:
        Mapping from ``"name"`` or ``"name:config"`` to its DataLoader.

    Raises:
        SystemExit: If ``datasets`` is not installed, or ``model_family`` is
            "auto" and no model was given.
    """
    # Function-scope import so this block does not rely on a file-level one.
    import warnings

    if load_dataset is None:
        raise SystemExit("datasets is required for dataset evaluation")

    resolved_family = model_family
    if resolved_family == "auto":
        if model is None:
            raise SystemExit("model is required when model_family is 'auto'")
        resolved_family = _infer_model_family(model)
    use_bos = _resolve_add_bos(add_bos, resolved_family, tokenizer)
    if use_bos and tokenizer.bos_token_id is None:
        # Nothing to prepend when the tokenizer defines no BOS id.
        use_bos = False

    dataloaders: Dict[str, torch.utils.data.DataLoader] = {}
    for idx, (dataset_name, config) in enumerate(zip(datasets, configs)):
        lm_dataset = _load_lm_dataset(
            tokenizer=tokenizer,
            dataset_name=dataset_name,
            config=config,
            split=split,
            text_field=text_field,
            seq_len=seq_len,
            add_bos=use_bos,
            cache_dir=cache_dir,
        )
        if shuffle:
            try:
                lm_dataset = lm_dataset.shuffle(seed=seed + idx)
            except Exception as exc:
                # Still best-effort, but report it: an unshuffled dataset
                # changes which sequences the num_samples cut keeps.
                warnings.warn(
                    f"Could not shuffle {dataset_name!r}; "
                    f"evaluating in original order: {exc}",
                    RuntimeWarning,
                    stacklevel=2,
                )
        if num_samples and hasattr(lm_dataset, "__len__"):
            limit = min(num_samples, len(lm_dataset))
            lm_dataset = lm_dataset.select(range(limit))

        data_loader = torch.utils.data.DataLoader(
            lm_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )
        label = dataset_name if config is None else f"{dataset_name}:{config}"
        dataloaders[label] = data_loader

    return dataloaders
407
+
408
+
409
def evaluate_ppl_dataloaders(
    model,
    dataloaders: Dict[str, torch.utils.data.DataLoader],
    device: str,
    max_batches: Optional[int] = None,
) -> Dict[str, float]:
    """Compute perplexity for every labelled DataLoader.

    Returns a mapping from each label in ``dataloaders`` to the value that
    ``compute_ppl`` reports for it, in the same iteration order.
    """
    return {
        label: compute_ppl(model, loader, device, max_batches=max_batches)
        for label, loader in dataloaders.items()
    }
420
+
421
+
422
def evaluate_ppl_datasets(
    model,
    tokenizer,
    datasets: List[str],
    configs: List[Optional[str]],
    split: str,
    text_field: Optional[str],
    num_samples: int,
    seq_len: int,
    batch_size: int,
    device: str,
    seed: int,
    shuffle: bool,
    model_family: str = "auto",
    add_bos: str = "auto",
    max_batches: Optional[int] = None,
    cache_dir: Optional[str] = None,
    num_workers: int = 0,
) -> Dict[str, float]:
    """Load the given datasets and return the model's perplexity on each.

    Delegates loader construction to ``prepare_ppl_dataloaders`` and scoring
    to ``evaluate_ppl_dataloaders`` rather than repeating their logic. All
    datasets are therefore prepared before the first one is scored.

    Args:
        model: Causal LM to evaluate; also used to infer the model family
            when ``model_family`` is "auto".
        tokenizer: Tokenizer used to pack the datasets.
        datasets: Dataset names, paired positionally with ``configs``.
        configs: Config name (or ``None``) for each dataset.
        split: Dataset split to load.
        text_field: Text column override; auto-detected when falsy.
        num_samples: Keep at most this many sequences per dataset; 0 keeps all.
        seq_len: Length of every packed sequence.
        batch_size: DataLoader batch size.
        device: Device the batches are moved to.
        seed: Base shuffle seed.
        shuffle: Shuffle each dataset before sub-sampling.
        model_family: "auto", or an explicit family name for BOS handling.
        add_bos: "auto", "always" or "never".
        max_batches: Optional cap on evaluated batches per dataset.
        cache_dir: Optional datasets cache directory.
        num_workers: DataLoader worker count.

    Returns:
        Mapping from ``"name"`` or ``"name:config"`` to perplexity.

    Raises:
        SystemExit: If the ``datasets`` package is not installed.
    """
    dataloaders = prepare_ppl_dataloaders(
        tokenizer=tokenizer,
        datasets=datasets,
        configs=configs,
        split=split,
        text_field=text_field,
        num_samples=num_samples,
        seq_len=seq_len,
        batch_size=batch_size,
        seed=seed,
        shuffle=shuffle,
        model_family=model_family,
        add_bos=add_bos,
        cache_dir=cache_dir,
        num_workers=num_workers,
        model=model,
    )
    return evaluate_ppl_dataloaders(
        model, dataloaders, device, max_batches=max_batches
    )
482
+
483
+
484
def main() -> None:
    """Entry point: evaluate perplexity on datasets and/or custom text.

    Loads the model and tokenizer named by ``--model``, evaluates every
    requested HF dataset plus any inline/file text, prints the results and
    optionally writes them to the ``--output`` JSON file.

    Raises:
        SystemExit: If the custom text is too short to form one sequence,
            or no evaluation input produced a result.
    """
    args = parse_args()
    torch.manual_seed(args.seed)

    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=get_dtype(args.dtype),
        trust_remote_code=args.trust_remote_code,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.model, trust_remote_code=args.trust_remote_code
    )
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    model.to(args.device)

    # Resolve BOS handling once so datasets and custom text agree.
    family = args.model_family
    if family == "auto":
        family = _infer_model_family(model)
    prepend_bos = _resolve_add_bos(args.add_bos, family, tokenizer)
    if prepend_bos and tokenizer.bos_token_id is None:
        prepend_bos = False

    results: Dict[str, float] = {}

    if args.dataset:
        dataset_names = list(args.dataset)
        dataset_configs = _expand_dataset_configs(
            dataset_names, list(args.dataset_config)
        )
        dataset_results = evaluate_ppl_datasets(
            model,
            tokenizer,
            datasets=dataset_names,
            configs=dataset_configs,
            split=args.dataset_split,
            text_field=args.dataset_text_field,
            num_samples=args.num_samples,
            seq_len=args.seq_len,
            batch_size=args.batch_size,
            device=args.device,
            seed=args.seed,
            shuffle=args.shuffle,
            model_family=family,
            add_bos="always" if prepend_bos else "never",
            max_batches=args.max_batches,
            cache_dir=args.cache_dir,
            num_workers=args.num_workers,
        )
        results.update(dataset_results)

    # Collect custom text: non-blank file lines first, then inline samples.
    custom_texts: List[str] = []
    if args.text_file:
        with open(args.text_file, "r", encoding="utf-8") as handle:
            for line in handle:
                stripped = line.strip()
                if stripped:
                    custom_texts.append(stripped)
    if args.text:
        custom_texts.extend(sample for sample in args.text if sample)

    if custom_texts:
        # A large cap stands in for "all" when --num_samples is 0.
        max_chunks = args.num_samples if args.num_samples > 0 else 1_000_000
        chunks = build_token_chunks(
            custom_texts,
            tokenizer,
            args.seq_len,
            max_chunks,
            add_bos=prepend_bos,
        )
        if not chunks:
            raise SystemExit(
                "Not enough custom text to build token sequences. "
                "Provide more --text/--text_file content or reduce --seq_len."
            )
        custom_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(torch.stack(chunks)),
            batch_size=args.batch_size,
            shuffle=False,
        )
        results["custom"] = compute_ppl(
            model, custom_loader, args.device, max_batches=args.max_batches
        )

    if not results:
        raise SystemExit("Provide --dataset and/or --text/--text_file for evaluation")

    print("Perplexity results:")
    for name, ppl in results.items():
        print(f"{name}: {ppl:.4f}")

    if args.output:
        report = {"model": args.model, "results": results}
        with open(args.output, "w", encoding="utf-8") as handle:
            json.dump(report, handle, indent=2)
573
+
574
+
575
+ if __name__ == "__main__":
576
+ main()
src/ppl_eval_progressive.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Evaluate perplexity for a progressive-pruned model assembled from cycles."""
3
+
4
+ import argparse
5
+
6
+ import torch
7
+
8
+ try:
9
+ import ppl_eval
10
+ except Exception as exc: # pragma: no cover - optional dependency
11
+ raise SystemExit("ppl_eval.py is required (missing or invalid)") from exc
12
+
13
+ try:
14
+ from transformers import AutoTokenizer
15
+ except Exception as exc: # pragma: no cover - fail early with clear error
16
+ raise SystemExit("transformers is required: pip install transformers") from exc
17
+
18
+ from progressive_loader import load_progressive_model
19
+
20
+
21
def parse_args() -> argparse.Namespace:
    """Parse command-line options for progressive-model perplexity evaluation.

    Returns:
        The parsed namespace. ``--base_model`` and ``--progressive_dir`` are
        required; every other option has a default.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate PPL for a model reconstructed from progressive cycles."
    )
    add = parser.add_argument
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    # Model location.
    add("--base_model", required=True, help="Base HF model id or path")
    add("--progressive_dir", required=True, help="Output directory from progressive pruning")
    add("--cycle", type=int, default=None, help="Cycle to load (default: final)")

    # Evaluation data.
    add("--dataset", action="append", default=[], help="Evaluation dataset name (repeatable). Defaults to wikitext.")
    add("--dataset_config", action="append", default=[], help="Evaluation dataset config (repeatable or single shared config).")
    add("--dataset_split", default="test", help="Evaluation dataset split (default: test)")
    add("--dataset_text_field", default=None, help="Evaluation text field override (default: auto-detect)")
    add("--num_samples", type=int, default=0, help="Number of token sequences per dataset (0 = all)")
    add("--seq_len", type=int, default=2048, help="Sequence length for eval")
    add("--batch_size", type=int, default=4, help="Batch size for eval")

    # Runtime and BOS handling.
    add("--device", default=default_device, help="Device for eval")
    add("--seed", type=int, default=0, help="Random seed")
    add("--model_family", type=str, choices=["auto", "llama", "qwen"], default="auto", help="Model family for BOS handling")
    add("--add_bos", type=str, choices=["auto", "always", "never"], default="auto", help="Whether to prepend BOS to each sample")
    add("--max_batches", type=int, default=None, help="Optional max number of eval batches per dataset")
    add("--cache_dir", default=None, help="Optional datasets cache dir for eval")
    add("--num_workers", type=int, default=0, help="Eval DataLoader workers")
    add("--dtype", default="auto", choices=["auto", "float32", "float16", "bfloat16"], help="Model dtype")
    add("--trust_remote_code", action="store_true", help="Allow custom model code from hub")
    add("--layer_path", default=None, help="Override layer attribute path if needed")

    return parser.parse_args()
131
+
132
+
133
def main() -> None:
    """Entry point: report perplexity for a progressive-pruned model.

    Rebuilds the model for the requested cycle from ``--progressive_dir``,
    loads the tokenizer of ``--base_model``, evaluates every dataset and
    prints one "label: ppl" line per dataset.

    With no ``--dataset``, wikitext is evaluated. With no
    ``--dataset_config``, the wikitext-2 raw config is applied only to
    datasets whose name ends in "wikitext"; other datasets get no config.
    """
    args = parse_args()
    torch.manual_seed(args.seed)

    datasets = list(args.dataset) or ["wikitext"]
    if args.dataset_config:
        configs = ppl_eval._expand_dataset_configs(
            datasets, list(args.dataset_config)
        )
    else:
        # The built-in default config only makes sense for wikitext; do not
        # force it onto an unrelated dataset the user named explicitly.
        configs = [
            "wikitext-2-raw-v1" if name.rsplit("/", 1)[-1] == "wikitext" else None
            for name in datasets
        ]

    model = load_progressive_model(
        args.base_model,
        args.progressive_dir,
        cycle=args.cycle,
        device=args.device,
        dtype=args.dtype,
        trust_remote_code=args.trust_remote_code,
        layer_path=args.layer_path,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.base_model, trust_remote_code=args.trust_remote_code
    )
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    results = ppl_eval.evaluate_ppl_datasets(
        model,
        tokenizer,
        datasets=datasets,
        configs=configs,
        split=args.dataset_split,
        text_field=args.dataset_text_field,
        num_samples=args.num_samples,
        seq_len=args.seq_len,
        batch_size=args.batch_size,
        device=args.device,
        seed=args.seed,
        shuffle=False,
        model_family=args.model_family,
        add_bos=args.add_bos,
        max_batches=args.max_batches,
        cache_dir=args.cache_dir,
        num_workers=args.num_workers,
    )

    print("Perplexity results:")
    for name, ppl in results.items():
        print(f"{name}: {ppl:.4f}")
179
+
180
+
181
+ if __name__ == "__main__":
182
+ main()
src/print_progressive_ppl_csv.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Print progressive PPL stats as CSV from progressive_metadata.json.
3
+
4
+ Expected (current) metadata shape:
5
+ - data["eval"]["pre_ppl"]
6
+ - data["cycles"][i]["redistrib_post_ppl"] (optional; legacy key)
7
+ - data["cycles"][i]["comm_post_ppl"] (optional; current key)
8
+ - data["cycles"][i]["distill_post_ppl"]
9
+ - data["cycles"][i]["lora_post_ppl"] (typically only set on the last cycle)
10
+ - data["cycles"][i]["post_ppl"]
11
+ """
12
+
13
+ import argparse
14
+ import csv
15
+ import json
16
+ import os
17
+ import shlex
18
+ import sys
19
+ from typing import Any, List, Optional
20
+
21
+
22
def _cell(value: Any) -> str:
    """Render a metadata value as the text of a single CSV cell.

    ``None`` becomes an empty string. Dict values are joined with ";" in
    sorted-key order, list/tuple items are joined with ";", and anything
    else is passed through ``str``.
    """
    if value is None:
        return ""
    if isinstance(value, dict):
        # Joining an empty sequence yields "", which covers the empty dict.
        ordered_values = [str(value[key]) for key in sorted(value)]
        return ";".join(ordered_values)
    if isinstance(value, (list, tuple)):
        return ";".join(map(str, value))
    return str(value)
32
+
33
+
34
def _read_run_command_tokens(metadata_path: str) -> Optional[List[str]]:
    """Return the shell tokens of the command recorded in ``run_args.txt``.

    The file is looked up next to ``metadata_path``. The command is the
    line that follows the first line reading exactly "command:".

    Returns:
        The tokens produced by ``shlex.split``, or ``None`` when the file
        is missing or unreadable, has no usable command line, or that line
        cannot be tokenized.
    """
    directory = os.path.dirname(os.path.abspath(metadata_path))
    run_args_path = os.path.join(directory, "run_args.txt")
    if not os.path.exists(run_args_path):
        return None

    try:
        with open(run_args_path, "r", encoding="utf-8") as handle:
            stripped_lines = [line.strip() for line in handle.read().splitlines()]
    except OSError:
        return None

    if "command:" not in stripped_lines:
        return None
    command_pos = stripped_lines.index("command:") + 1
    if command_pos >= len(stripped_lines):
        return None
    command = stripped_lines[command_pos]
    if not command:
        return None

    try:
        return shlex.split(command)
    except ValueError:
        return None
60
+
61
+
62
def _parse_exclude_pairs_from_tokens(tokens: List[str]) -> Optional[List[int]]:
    """Extract the integers following ``--exclude_pairs`` / ``--exclude_layers``.

    Values are read from the first occurrence of either flag up to the next
    token starting with "--". Tokens that are not integers are skipped.

    Returns:
        The parsed integers (possibly empty), or ``None`` when neither flag
        is present.
    """
    flag_names = ("--exclude_pairs", "--exclude_layers")
    first_value = next(
        (pos + 1 for pos, token in enumerate(tokens) if token in flag_names),
        None,
    )
    if first_value is None:
        return None

    values: List[int] = []
    for token in tokens[first_value:]:
        if token.startswith("--"):
            break
        # Legacy bug: run_args.txt used to print "python" before every token.
        if token == "python":
            continue
        try:
            number = int(token)
        except ValueError:
            continue
        values.append(number)
    return values
83
+
84
+
85
def _normalize_excluded_pairs(raw: List[int], num_pairs: int) -> List[int]:
    """Resolve negative indices and return the valid ones, sorted and unique.

    A negative index counts from the end (``-1`` is ``num_pairs - 1``).
    Indices outside ``[0, num_pairs)`` after that adjustment are dropped.
    """
    resolved = {(idx + num_pairs if idx < 0 else idx) for idx in raw}
    return sorted(idx for idx in resolved if 0 <= idx < num_pairs)
93
+
94
+
95
def _read_excluded_pairs_from_cycle_meta(meta_dir: str, cycle_idx: int) -> Optional[List[int]]:
    """Read ``dwce_meta.excluded_pairs`` from a cycle's metadata file.

    The file read is ``<meta_dir>/cycle_<cycle_idx>/cycle_metadata.json``.

    Returns:
        The stored list when it is a list of ints; ``None`` when the file
        is missing, unreadable or not valid JSON, or the value has any
        other shape.
    """
    cycle_file = os.path.join(meta_dir, f"cycle_{cycle_idx}", "cycle_metadata.json")
    try:
        with open(cycle_file, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
    except (FileNotFoundError, json.JSONDecodeError, OSError):
        return None

    candidates = (payload.get("dwce_meta") or {}).get("excluded_pairs")
    if not isinstance(candidates, list):
        return None
    for item in candidates:
        if not isinstance(item, int):
            return None
    return candidates
108
+
109
+
110
def _num_pairs_for_cycle(data: dict, meta_dir: str, cycle_idx: int) -> Optional[int]:
    """Return the number of layer pairs available in a given cycle.

    Uses ``final_num_layers + num_progressive - cycle_idx`` when both
    top-level values are ints. Otherwise falls back to
    ``num_layers_before - 1`` from the cycle's ``cycle_metadata.json``.
    Results are clamped at zero; ``None`` means neither source was usable.
    """
    removed_layers = data.get("num_progressive")
    remaining_layers = data.get("final_num_layers")
    if isinstance(removed_layers, int) and isinstance(remaining_layers, int):
        return max(remaining_layers + removed_layers - cycle_idx, 0)

    cycle_file = os.path.join(meta_dir, f"cycle_{cycle_idx}", "cycle_metadata.json")
    try:
        with open(cycle_file, "r", encoding="utf-8") as handle:
            cycle_meta = json.load(handle)
    except (FileNotFoundError, json.JSONDecodeError, OSError):
        return None

    layers_before = cycle_meta.get("num_layers_before")
    if not isinstance(layers_before, int):
        return None
    return max(layers_before - 1, 0)
128
+
129
+
130
def main() -> None:
    """Entry point: print per-cycle perplexity values as CSV on stdout.

    Reads the ``progressive_metadata.json`` given on the command line and
    writes a header row, an optional "pre" row carrying the pre-pruning
    perplexity, and one row per cycle.

    Raises:
        SystemExit: If the metadata file is missing or is not valid JSON.
    """
    parser = argparse.ArgumentParser(
        description="Print progressive PPL values as CSV from progressive_metadata.json"
    )
    parser.add_argument("path", help="Path to progressive_metadata.json")
    metadata_path = parser.parse_args().path

    try:
        with open(metadata_path, "r", encoding="utf-8") as handle:
            data = json.load(handle)
    except FileNotFoundError as exc:
        raise SystemExit(f"File not found: {metadata_path}") from exc
    except json.JSONDecodeError as exc:
        raise SystemExit(f"Invalid JSON: {metadata_path}") from exc

    meta_dir = os.path.dirname(os.path.abspath(metadata_path))

    # Exclusions recorded on the original command line; used only when a
    # cycle's own metadata does not list them.
    run_tokens = _read_run_command_tokens(metadata_path)
    raw_exclude = None
    if run_tokens is not None:
        raw_exclude = _parse_exclude_pairs_from_tokens(run_tokens)

    header = [
        "cycle",
        "layer_merged",
        "layer_pair",
        "excluded_pairs",
        "redistrib_post_ppl",
        "distill_post_ppl",
        "lora_post_ppl",
        "post_ppl",
    ]
    writer = csv.writer(sys.stdout)
    writer.writerow(header)

    pre_ppl = data.get("eval", {}).get("pre_ppl")
    if pre_ppl is not None:
        writer.writerow(["pre", "", "", "", "", "", "", _cell(pre_ppl)])

    for cycle in data.get("cycles") or data.get("cycle_summaries") or []:
        cycle_idx = cycle.get("cycle", "")
        has_int_index = isinstance(cycle_idx, int)

        layer_merged = cycle.get("layer_merged")
        if isinstance(layer_merged, int):
            layer_pair = f"{layer_merged}-{layer_merged + 1}"
        else:
            layer_pair = ""

        excluded_pairs = _read_excluded_pairs_from_cycle_meta(
            meta_dir, cycle_idx if has_int_index else -1
        )
        if excluded_pairs is None and raw_exclude is not None and has_int_index:
            num_pairs = _num_pairs_for_cycle(data, meta_dir, cycle_idx)
            if isinstance(num_pairs, int):
                excluded_pairs = _normalize_excluded_pairs(raw_exclude, num_pairs)

        # "redistrib_post_ppl" is the legacy key; "comm_post_ppl" the current one.
        redistrib_post_ppl = cycle.get("redistrib_post_ppl")
        if redistrib_post_ppl is None:
            redistrib_post_ppl = cycle.get("comm_post_ppl")

        row = [
            cycle_idx,
            "" if layer_merged is None else layer_merged,
            layer_pair,
            _cell(excluded_pairs),
            _cell(redistrib_post_ppl),
        ]
        row.extend(
            _cell(cycle.get(key))
            for key in ("distill_post_ppl", "lora_post_ppl", "post_ppl")
        )
        writer.writerow(row)
200
+
201
+
202
+ if __name__ == "__main__":
203
+ main()
src/progressive_loader.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Utilities to reconstruct models from progressive pruning cycles."""
3
+
4
+ import json
5
+ import os
6
+ from typing import Optional
7
+
8
+ import torch
9
+
10
+ try:
11
+ from transformers import AutoModelForCausalLM, PretrainedConfig
12
+ from transformers.models.auto.configuration_auto import CONFIG_MAPPING
13
+ except Exception as exc: # pragma: no cover - fail early with clear error
14
+ raise SystemExit("transformers is required: pip install transformers") from exc
15
+
16
+ from fuse_layers_model import (
17
+ decrement_config,
18
+ drop_layer,
19
+ find_layer_container,
20
+ get_dtype,
21
+ normalize_config,
22
+ )
23
+
24
+
25
def load_progressive_metadata(output_dir: str) -> dict:
    """Read ``progressive_metadata.json`` from ``output_dir``.

    Args:
        output_dir: Directory expected to contain ``progressive_metadata.json``.

    Returns:
        The parsed JSON content of the metadata file.

    Raises:
        FileNotFoundError: If the metadata file does not exist.
    """
    metadata_file = os.path.join(output_dir, "progressive_metadata.json")
    if os.path.exists(metadata_file):
        with open(metadata_file, "r", encoding="utf-8") as stream:
            payload = json.load(stream)
        return payload
    raise FileNotFoundError(f"Missing progressive metadata at {metadata_file}")
31
+
32
+
33
def load_normalized_config(model_path: str, trust_remote_code: bool):
    """Load the config found at ``model_path`` and pass it through ``normalize_config``.

    The raw config dict is fetched with ``PretrainedConfig.get_config_dict``.
    When ``num_hidden_layers`` is a non-negative int and ``layer_types`` is a
    list whose length differs from it, ``layer_types`` is replaced by its first
    ``num_hidden_layers`` entries (a list that is already shorter is simply
    copied, since slicing cannot lengthen it). The config object is then built
    from the class registered in ``CONFIG_MAPPING`` for the dict's
    ``model_type``.

    Args:
        model_path: Location handed to ``PretrainedConfig.get_config_dict``.
        trust_remote_code: Forwarded to ``PretrainedConfig.get_config_dict``.

    Returns:
        The config instance after ``normalize_config`` has been applied to it.

    Raises:
        KeyError: If the config dict has no ``"model_type"`` entry.
    """
    raw_config, extra_kwargs = PretrainedConfig.get_config_dict(
        model_path,
        trust_remote_code=trust_remote_code,
    )

    layer_count = raw_config.get("num_hidden_layers")
    per_layer_types = raw_config.get("layer_types")

    count_is_usable = isinstance(layer_count, int) and layer_count >= 0
    if count_is_usable and isinstance(per_layer_types, list):
        if len(per_layer_types) != layer_count:
            # Keep only as many per-layer entries as the config declares layers.
            raw_config["layer_types"] = list(per_layer_types[:layer_count])

    config_cls = CONFIG_MAPPING[raw_config["model_type"]]
    config = config_cls.from_dict(raw_config, **extra_kwargs)
    normalize_config(config)
    return config
52
+
53
+
54
def load_causal_lm(
    model_path_or_id: str,
    *,
    torch_dtype,
    trust_remote_code: bool,
    **kwargs,
) -> torch.nn.Module:
    """Load a causal LM, using a normalized config for local checkpoints.

    When ``model_path_or_id`` is a directory containing ``config.json``, the
    config is loaded through ``load_normalized_config`` and passed explicitly
    to ``from_pretrained``. Otherwise ``config=None`` is passed.

    Args:
        model_path_or_id: Local checkpoint directory or model identifier.
        torch_dtype: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
        trust_remote_code: Forwarded to both the config loader and
            ``from_pretrained``.
        **kwargs: Extra keyword arguments forwarded to ``from_pretrained``.

    Returns:
        The model returned by ``AutoModelForCausalLM.from_pretrained``.
    """
    local_config_file = os.path.join(model_path_or_id, "config.json")
    has_local_config = os.path.isdir(model_path_or_id) and os.path.isfile(
        local_config_file
    )

    if has_local_config:
        normalized = load_normalized_config(model_path_or_id, trust_remote_code)
    else:
        normalized = None

    model = AutoModelForCausalLM.from_pretrained(
        model_path_or_id,
        config=normalized,
        torch_dtype=torch_dtype,
        trust_remote_code=trust_remote_code,
        **kwargs,
    )
    return model
72
+
73
+
74
def load_progressive_model(
    base_model_id: str,
    output_dir: str,
    cycle: Optional[int] = None,
    device: Optional[str] = None,
    dtype: str = "auto",
    trust_remote_code: bool = False,
    layer_path: Optional[str] = None,
) -> torch.nn.Module:
    """Reconstruct the model as it stood after ``cycle`` progressive cycles.

    If ``<output_dir>/cycle_<cycle>/full_model`` exists (and ``cycle > 0``), that
    checkpoint is loaded directly. Otherwise the base model is loaded and each
    cycle ``1..cycle`` is replayed: the saved fused layer state is loaded into
    layer ``layer_merged`` and layer ``layer_merged + 1`` is dropped.

    Args:
        base_model_id: Model passed to ``AutoModelForCausalLM.from_pretrained``
            when the cycles have to be replayed.
        output_dir: Directory holding ``progressive_metadata.json`` and the
            ``cycle_<n>`` subdirectories.
        cycle: Number of cycles to apply. ``None`` means all cycles recorded
            under ``num_progressive`` in the metadata; ``0`` yields the base
            model unchanged.
        device: If truthy, the model is moved there before returning.
        dtype: Passed through ``get_dtype`` to obtain the ``torch_dtype``.
        trust_remote_code: Forwarded to the model loaders.
        layer_path: Overrides the ``layer_path`` stored in the metadata.

    Returns:
        The reconstructed model.

    Raises:
        ValueError: If ``cycle`` is outside ``[0, num_progressive]``, or a
            cycle's ``layer_merged`` index has no following layer to drop.
        FileNotFoundError: If the progressive metadata, a cycle's metadata, or
            a cycle's fused layer state file is missing.
    """
    meta = load_progressive_metadata(output_dir)
    num_cycles = int(meta.get("num_progressive", 0))
    if cycle is None:
        cycle = num_cycles
    if cycle < 0 or cycle > num_cycles:
        raise ValueError(f"Cycle {cycle} is outside [0, {num_cycles}]")

    # Fast path: a complete checkpoint saved for this cycle needs no replay.
    if cycle > 0:
        full_model_dir = os.path.join(output_dir, f"cycle_{cycle}", "full_model")
        if os.path.isdir(full_model_dir):
            model = load_causal_lm(
                full_model_dir,
                torch_dtype=get_dtype(dtype),
                trust_remote_code=trust_remote_code,
            )
            if device:
                model.to(device)
            return model

    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=get_dtype(dtype),
        trust_remote_code=trust_remote_code,
    )
    active_layer_path = layer_path or meta.get("layer_path")
    parent, name, container = find_layer_container(model, active_layer_path)

    for idx in range(1, cycle + 1):
        cycle_dir = os.path.join(output_dir, f"cycle_{idx}")
        cycle_meta_path = os.path.join(cycle_dir, "cycle_metadata.json")
        if not os.path.exists(cycle_meta_path):
            raise FileNotFoundError(f"Missing cycle metadata at {cycle_meta_path}")
        with open(cycle_meta_path, "r", encoding="utf-8") as handle:
            cycle_meta = json.load(handle)

        layer_idx = int(cycle_meta["layer_merged"])
        fused_state = cycle_meta.get("fused_layer_state", "fused_layer.pt")
        fused_state_path = os.path.join(cycle_dir, fused_state)
        if not os.path.exists(fused_state_path):
            raise FileNotFoundError(f"Missing fused layer at {fused_state_path}")

        layers = list(container)
        # The fused weights go into layer_idx and layer_idx + 1 is removed, so
        # both indices must exist. Checking only layer_idx would let the last
        # layer through and hand drop_layer an out-of-range index.
        if layer_idx < 0 or layer_idx + 1 >= len(layers):
            raise ValueError(
                f"Cycle {idx} layer index {layer_idx} out of range for {len(layers)} layers "
                f"(layer {layer_idx + 1} must exist to be dropped)"
            )

        # NOTE(review): torch.load unpickles the file; only point this at
        # trusted run directories, or consider weights_only=True where the
        # installed torch version supports it.
        state = torch.load(fused_state_path, map_location="cpu")
        layers[layer_idx].load_state_dict(state)

        # Presumably drop_layer returns a container without the given layer and
        # decrement_config lowers the config's layer count to match - confirm
        # against fuse_layers_model.
        new_container = drop_layer(container, layer_idx + 1)
        setattr(parent, name, new_container)
        decrement_config(model.config)

        # Later cycles index into the already-shortened container.
        container = new_container

    if device:
        model.to(device)

    return model