upload src
Browse files- src/common_lm_data.py +435 -0
- src/convert_llmpruner_checkpoint.py +42 -0
- src/eval_ppl.py +241 -0
- src/fbmc_metric.py +519 -0
- src/fuse_layers.py +2416 -0
- src/fuse_layers_data.py +280 -0
- src/fuse_layers_distill.py +2018 -0
- src/fuse_layers_model.py +595 -0
- src/fuse_layers_select.py +1152 -0
- src/loratune.py +430 -0
- src/loratune_config.py +86 -0
- src/ppl_eval.py +576 -0
- src/ppl_eval_progressive.py +182 -0
- src/print_progressive_ppl_csv.py +203 -0
- src/progressive_loader.py +142 -0
src/common_lm_data.py
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Shared LM dataset helpers for fair cross-method comparisons."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Dict, Iterable, Iterator, List, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from datasets import load_dataset
|
| 13 |
+
from datasets import Dataset as HFDataset
|
| 14 |
+
except Exception: # pragma: no cover - optional dependency
|
| 15 |
+
load_dataset = None
|
| 16 |
+
HFDataset = None
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _normalize_config(config: Optional[str]) -> Optional[str]:
|
| 20 |
+
if config is None:
|
| 21 |
+
return None
|
| 22 |
+
if config.strip().lower() in {"none", "null", "-"}:
|
| 23 |
+
return None
|
| 24 |
+
return config
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def guess_text_field(dataset) -> str:
|
| 28 |
+
if hasattr(dataset, "column_names") and dataset.column_names:
|
| 29 |
+
if "text" in dataset.column_names:
|
| 30 |
+
return "text"
|
| 31 |
+
return dataset.column_names[0]
|
| 32 |
+
if hasattr(dataset, "features"):
|
| 33 |
+
names = list(dataset.features.keys())
|
| 34 |
+
if "text" in names:
|
| 35 |
+
return "text"
|
| 36 |
+
if names:
|
| 37 |
+
return names[0]
|
| 38 |
+
return "text"
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def normalize_dataset_name(name: str) -> str:
|
| 42 |
+
normalized = name.strip().lower()
|
| 43 |
+
aliases = {
|
| 44 |
+
"bookcorpus": "bookcorpus",
|
| 45 |
+
"boockcorpus": "bookcorpus",
|
| 46 |
+
"slimpajama": "slimpajama",
|
| 47 |
+
"dkyoon/slimpajama-6b": "slimpajama",
|
| 48 |
+
}
|
| 49 |
+
if normalized not in aliases:
|
| 50 |
+
raise ValueError(f"Unsupported dataset: {name}")
|
| 51 |
+
return aliases[normalized]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def resolve_dataset_spec(
|
| 55 |
+
name: str,
|
| 56 |
+
config: Optional[str] = None,
|
| 57 |
+
split: str = "train",
|
| 58 |
+
) -> Tuple[str, Optional[str], str]:
|
| 59 |
+
normalized = normalize_dataset_name(name)
|
| 60 |
+
if normalized == "bookcorpus":
|
| 61 |
+
return "bookcorpus", _normalize_config(config), split
|
| 62 |
+
if normalized == "slimpajama":
|
| 63 |
+
return "DKYoon/SlimPajama-6B", _normalize_config(config), split
|
| 64 |
+
raise ValueError(f"Unsupported dataset: {name}")
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _sample_dataset_rows(dataset, target: int, seed: int) -> List[Dict[str, object]]:
|
| 68 |
+
if target <= 0:
|
| 69 |
+
return []
|
| 70 |
+
try:
|
| 71 |
+
dataset = dataset.shuffle(seed=seed)
|
| 72 |
+
except Exception:
|
| 73 |
+
pass
|
| 74 |
+
|
| 75 |
+
if hasattr(dataset, "__len__"):
|
| 76 |
+
limit = min(target, len(dataset))
|
| 77 |
+
dataset = dataset.select(range(limit))
|
| 78 |
+
return [row for row in dataset]
|
| 79 |
+
|
| 80 |
+
rows = []
|
| 81 |
+
for row in dataset:
|
| 82 |
+
rows.append(row)
|
| 83 |
+
if len(rows) >= target:
|
| 84 |
+
break
|
| 85 |
+
return rows
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _iter_dataset_rows(dataset, seed: int) -> Iterator[Dict[str, object]]:
|
| 89 |
+
try:
|
| 90 |
+
dataset = dataset.shuffle(seed=seed)
|
| 91 |
+
except Exception:
|
| 92 |
+
pass
|
| 93 |
+
for row in dataset:
|
| 94 |
+
yield row
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def load_named_texts(
|
| 98 |
+
dataset_name: str,
|
| 99 |
+
*,
|
| 100 |
+
config: Optional[str] = None,
|
| 101 |
+
split: str = "train",
|
| 102 |
+
text_field: Optional[str] = None,
|
| 103 |
+
num_samples: int = 0,
|
| 104 |
+
seed: int = 0,
|
| 105 |
+
) -> List[str]:
|
| 106 |
+
if load_dataset is None:
|
| 107 |
+
raise SystemExit("datasets is required for shared LM dataloaders")
|
| 108 |
+
|
| 109 |
+
hf_name, hf_config, hf_split = resolve_dataset_spec(dataset_name, config, split)
|
| 110 |
+
dataset = load_dataset(
|
| 111 |
+
hf_name,
|
| 112 |
+
hf_config,
|
| 113 |
+
split=hf_split,
|
| 114 |
+
trust_remote_code=True,
|
| 115 |
+
)
|
| 116 |
+
rows = dataset if num_samples <= 0 else _sample_dataset_rows(dataset, num_samples, seed)
|
| 117 |
+
field = text_field or guess_text_field(dataset)
|
| 118 |
+
|
| 119 |
+
texts: List[str] = []
|
| 120 |
+
for row in rows:
|
| 121 |
+
value = row.get(field, None) if isinstance(row, dict) else None
|
| 122 |
+
if isinstance(value, str) and value.strip():
|
| 123 |
+
texts.append(value)
|
| 124 |
+
return texts
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def build_token_chunks_from_rows(
|
| 128 |
+
rows: Iterable[Dict[str, object]],
|
| 129 |
+
*,
|
| 130 |
+
text_field: str,
|
| 131 |
+
tokenizer,
|
| 132 |
+
seq_len: int,
|
| 133 |
+
num_sequences: int = 0,
|
| 134 |
+
add_bos: bool = False,
|
| 135 |
+
max_rows: int = 0,
|
| 136 |
+
) -> List[torch.Tensor]:
|
| 137 |
+
chunks: List[torch.Tensor] = []
|
| 138 |
+
buffer: List[int] = []
|
| 139 |
+
limit = None if num_sequences <= 0 else num_sequences
|
| 140 |
+
rows_seen = 0
|
| 141 |
+
|
| 142 |
+
for row in rows:
|
| 143 |
+
if max_rows > 0 and rows_seen >= max_rows:
|
| 144 |
+
break
|
| 145 |
+
rows_seen += 1
|
| 146 |
+
|
| 147 |
+
value = row.get(text_field, None) if isinstance(row, dict) else None
|
| 148 |
+
if not isinstance(value, str) or not value.strip():
|
| 149 |
+
continue
|
| 150 |
+
|
| 151 |
+
ids = tokenizer.encode(value, add_special_tokens=False)
|
| 152 |
+
if add_bos and tokenizer.bos_token_id is not None:
|
| 153 |
+
ids = [tokenizer.bos_token_id] + ids
|
| 154 |
+
if not ids:
|
| 155 |
+
continue
|
| 156 |
+
|
| 157 |
+
buffer.extend(ids)
|
| 158 |
+
while len(buffer) >= seq_len and (limit is None or len(chunks) < limit):
|
| 159 |
+
chunk = buffer[:seq_len]
|
| 160 |
+
buffer = buffer[seq_len:]
|
| 161 |
+
chunks.append(torch.tensor(chunk, dtype=torch.long))
|
| 162 |
+
if limit is not None and len(chunks) >= limit:
|
| 163 |
+
break
|
| 164 |
+
|
| 165 |
+
return chunks
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def collect_texts_from_rows(
|
| 169 |
+
rows: Iterable[Dict[str, object]],
|
| 170 |
+
*,
|
| 171 |
+
text_field: str,
|
| 172 |
+
tokenizer,
|
| 173 |
+
target_tokens: int = 0,
|
| 174 |
+
add_bos: bool = False,
|
| 175 |
+
max_rows: int = 0,
|
| 176 |
+
) -> List[str]:
|
| 177 |
+
texts: List[str] = []
|
| 178 |
+
token_count = 0
|
| 179 |
+
rows_seen = 0
|
| 180 |
+
|
| 181 |
+
for row in rows:
|
| 182 |
+
if max_rows > 0 and rows_seen >= max_rows:
|
| 183 |
+
break
|
| 184 |
+
rows_seen += 1
|
| 185 |
+
|
| 186 |
+
value = row.get(text_field, None) if isinstance(row, dict) else None
|
| 187 |
+
if not isinstance(value, str) or not value.strip():
|
| 188 |
+
continue
|
| 189 |
+
|
| 190 |
+
texts.append(value)
|
| 191 |
+
if target_tokens > 0:
|
| 192 |
+
ids = tokenizer.encode(value, add_special_tokens=False)
|
| 193 |
+
if add_bos and tokenizer.bos_token_id is not None:
|
| 194 |
+
ids = [tokenizer.bos_token_id] + ids
|
| 195 |
+
token_count += len(ids)
|
| 196 |
+
if token_count >= target_tokens:
|
| 197 |
+
break
|
| 198 |
+
|
| 199 |
+
return texts
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def build_token_chunks(
|
| 203 |
+
texts: Iterable[str],
|
| 204 |
+
tokenizer,
|
| 205 |
+
seq_len: int,
|
| 206 |
+
num_sequences: int = 0,
|
| 207 |
+
add_bos: bool = False,
|
| 208 |
+
) -> List[torch.Tensor]:
|
| 209 |
+
chunks: List[torch.Tensor] = []
|
| 210 |
+
buffer: List[int] = []
|
| 211 |
+
limit = None if num_sequences <= 0 else num_sequences
|
| 212 |
+
|
| 213 |
+
for text in texts:
|
| 214 |
+
ids = tokenizer.encode(text, add_special_tokens=False)
|
| 215 |
+
if add_bos and tokenizer.bos_token_id is not None:
|
| 216 |
+
ids = [tokenizer.bos_token_id] + ids
|
| 217 |
+
if not ids:
|
| 218 |
+
continue
|
| 219 |
+
|
| 220 |
+
buffer.extend(ids)
|
| 221 |
+
while len(buffer) >= seq_len and (limit is None or len(chunks) < limit):
|
| 222 |
+
chunk = buffer[:seq_len]
|
| 223 |
+
buffer = buffer[seq_len:]
|
| 224 |
+
chunks.append(torch.tensor(chunk, dtype=torch.long))
|
| 225 |
+
if limit is not None and len(chunks) >= limit:
|
| 226 |
+
break
|
| 227 |
+
|
| 228 |
+
return chunks
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
class TokenChunkDataset(torch.utils.data.Dataset):
|
| 232 |
+
def __init__(self, chunks: List[torch.Tensor]) -> None:
|
| 233 |
+
self.chunks = chunks
|
| 234 |
+
|
| 235 |
+
def __len__(self) -> int:
|
| 236 |
+
return len(self.chunks)
|
| 237 |
+
|
| 238 |
+
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
| 239 |
+
input_ids = self.chunks[idx]
|
| 240 |
+
attention_mask = torch.ones_like(input_ids)
|
| 241 |
+
return {
|
| 242 |
+
"input_ids": input_ids,
|
| 243 |
+
"attention_mask": attention_mask,
|
| 244 |
+
"labels": input_ids.clone(),
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class TokenOnlyDataset(torch.utils.data.Dataset):
|
| 249 |
+
def __init__(self, chunks: List[torch.Tensor]) -> None:
|
| 250 |
+
self.chunks = chunks
|
| 251 |
+
|
| 252 |
+
def __len__(self) -> int:
|
| 253 |
+
return len(self.chunks)
|
| 254 |
+
|
| 255 |
+
def __getitem__(self, idx: int) -> torch.Tensor:
|
| 256 |
+
return self.chunks[idx]
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class TokenInputMaskDataset(torch.utils.data.Dataset):
|
| 260 |
+
def __init__(self, chunks: List[torch.Tensor]) -> None:
|
| 261 |
+
self.chunks = chunks
|
| 262 |
+
|
| 263 |
+
def __len__(self) -> int:
|
| 264 |
+
return len(self.chunks)
|
| 265 |
+
|
| 266 |
+
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
| 267 |
+
input_ids = self.chunks[idx]
|
| 268 |
+
return {
|
| 269 |
+
"input_ids": input_ids,
|
| 270 |
+
"attention_mask": torch.ones_like(input_ids),
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
@dataclass
|
| 275 |
+
class SharedLMDataSpec:
|
| 276 |
+
dataset: str
|
| 277 |
+
config: Optional[str] = None
|
| 278 |
+
split: str = "train"
|
| 279 |
+
text_field: Optional[str] = None
|
| 280 |
+
num_samples: int = 0
|
| 281 |
+
seq_len: int = 2048
|
| 282 |
+
num_sequences: int = 0
|
| 283 |
+
target_tokens: int = 0
|
| 284 |
+
batch_size: int = 1
|
| 285 |
+
shuffle: bool = False
|
| 286 |
+
num_workers: int = 0
|
| 287 |
+
seed: int = 0
|
| 288 |
+
add_bos: bool = False
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def build_chunks(spec: SharedLMDataSpec, tokenizer) -> List[torch.Tensor]:
|
| 292 |
+
if load_dataset is None:
|
| 293 |
+
raise SystemExit("datasets is required for shared LM dataloaders")
|
| 294 |
+
|
| 295 |
+
hf_name, hf_config, hf_split = resolve_dataset_spec(spec.dataset, spec.config, spec.split)
|
| 296 |
+
dataset = load_dataset(
|
| 297 |
+
hf_name,
|
| 298 |
+
hf_config,
|
| 299 |
+
split=hf_split,
|
| 300 |
+
trust_remote_code=True,
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
target_sequences = spec.num_sequences
|
| 304 |
+
if spec.target_tokens > 0:
|
| 305 |
+
token_sequences = (spec.target_tokens + spec.seq_len - 1) // spec.seq_len
|
| 306 |
+
target_sequences = max(target_sequences, token_sequences)
|
| 307 |
+
row_limit = spec.num_samples if target_sequences <= 0 else 0
|
| 308 |
+
|
| 309 |
+
rows = _iter_dataset_rows(dataset, spec.seed)
|
| 310 |
+
text_field = spec.text_field or guess_text_field(dataset)
|
| 311 |
+
chunks = build_token_chunks_from_rows(
|
| 312 |
+
rows,
|
| 313 |
+
text_field=text_field,
|
| 314 |
+
tokenizer=tokenizer,
|
| 315 |
+
seq_len=spec.seq_len,
|
| 316 |
+
num_sequences=target_sequences,
|
| 317 |
+
add_bos=spec.add_bos,
|
| 318 |
+
max_rows=row_limit,
|
| 319 |
+
)
|
| 320 |
+
return chunks
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def build_dataloader(spec: SharedLMDataSpec, tokenizer) -> torch.utils.data.DataLoader:
|
| 324 |
+
chunks = build_chunks(spec, tokenizer)
|
| 325 |
+
dataset = TokenChunkDataset(chunks)
|
| 326 |
+
return torch.utils.data.DataLoader(
|
| 327 |
+
dataset,
|
| 328 |
+
batch_size=spec.batch_size,
|
| 329 |
+
shuffle=spec.shuffle,
|
| 330 |
+
num_workers=spec.num_workers,
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def build_text_dataloader(spec: SharedLMDataSpec, tokenizer) -> torch.utils.data.DataLoader:
|
| 335 |
+
if load_dataset is None:
|
| 336 |
+
raise SystemExit("datasets is required for shared LM dataloaders")
|
| 337 |
+
|
| 338 |
+
hf_name, hf_config, hf_split = resolve_dataset_spec(spec.dataset, spec.config, spec.split)
|
| 339 |
+
dataset = load_dataset(
|
| 340 |
+
hf_name,
|
| 341 |
+
hf_config,
|
| 342 |
+
split=hf_split,
|
| 343 |
+
trust_remote_code=True,
|
| 344 |
+
)
|
| 345 |
+
rows = _iter_dataset_rows(dataset, spec.seed)
|
| 346 |
+
text_field = spec.text_field or guess_text_field(dataset)
|
| 347 |
+
row_limit = spec.num_samples
|
| 348 |
+
texts = collect_texts_from_rows(
|
| 349 |
+
rows,
|
| 350 |
+
text_field=text_field,
|
| 351 |
+
tokenizer=tokenizer,
|
| 352 |
+
target_tokens=spec.target_tokens,
|
| 353 |
+
add_bos=spec.add_bos,
|
| 354 |
+
max_rows=row_limit,
|
| 355 |
+
)
|
| 356 |
+
return torch.utils.data.DataLoader(
|
| 357 |
+
texts,
|
| 358 |
+
batch_size=spec.batch_size,
|
| 359 |
+
shuffle=spec.shuffle,
|
| 360 |
+
num_workers=spec.num_workers,
|
| 361 |
+
drop_last=True,
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def build_uidl_post_train_dataloader(
|
| 366 |
+
spec: SharedLMDataSpec,
|
| 367 |
+
tokenizer,
|
| 368 |
+
) -> torch.utils.data.DataLoader:
|
| 369 |
+
dataset = TokenChunkDataset(build_chunks(spec, tokenizer))
|
| 370 |
+
return torch.utils.data.DataLoader(
|
| 371 |
+
dataset,
|
| 372 |
+
batch_size=spec.batch_size,
|
| 373 |
+
shuffle=spec.shuffle,
|
| 374 |
+
num_workers=spec.num_workers,
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def build_uidl_similarity_dataloader(
|
| 379 |
+
spec: SharedLMDataSpec,
|
| 380 |
+
tokenizer,
|
| 381 |
+
) -> torch.utils.data.DataLoader:
|
| 382 |
+
dataset = TokenInputMaskDataset(build_chunks(spec, tokenizer))
|
| 383 |
+
return torch.utils.data.DataLoader(
|
| 384 |
+
dataset,
|
| 385 |
+
batch_size=spec.batch_size,
|
| 386 |
+
shuffle=spec.shuffle,
|
| 387 |
+
num_workers=spec.num_workers,
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def build_shortened_llm_dataloader(
|
| 392 |
+
spec: SharedLMDataSpec,
|
| 393 |
+
tokenizer,
|
| 394 |
+
) -> torch.utils.data.DataLoader:
|
| 395 |
+
dataset = TokenOnlyDataset(build_chunks(spec, tokenizer))
|
| 396 |
+
return torch.utils.data.DataLoader(
|
| 397 |
+
dataset,
|
| 398 |
+
batch_size=spec.batch_size,
|
| 399 |
+
shuffle=spec.shuffle,
|
| 400 |
+
num_workers=spec.num_workers,
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def build_shortened_llm_examples(spec: SharedLMDataSpec, tokenizer) -> torch.Tensor:
|
| 405 |
+
chunks = build_chunks(spec, tokenizer)
|
| 406 |
+
if not chunks:
|
| 407 |
+
return torch.empty((0, spec.seq_len), dtype=torch.long)
|
| 408 |
+
return torch.stack(chunks, dim=0)
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def build_llmpruner_examples(spec: SharedLMDataSpec, tokenizer) -> torch.Tensor:
|
| 412 |
+
chunks = build_chunks(spec, tokenizer)
|
| 413 |
+
if not chunks:
|
| 414 |
+
return torch.empty((0, spec.seq_len), dtype=torch.long)
|
| 415 |
+
return torch.stack(chunks, dim=0)
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def build_replaceme_dataloader(
|
| 419 |
+
spec: SharedLMDataSpec,
|
| 420 |
+
tokenizer,
|
| 421 |
+
) -> torch.utils.data.DataLoader:
|
| 422 |
+
return build_text_dataloader(spec, tokenizer)
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def build_hf_causal_dataset(spec: SharedLMDataSpec, tokenizer):
|
| 426 |
+
if HFDataset is None:
|
| 427 |
+
raise SystemExit("datasets is required for shared LM dataloaders")
|
| 428 |
+
|
| 429 |
+
chunks = build_chunks(spec, tokenizer)
|
| 430 |
+
payload = {
|
| 431 |
+
"input_ids": [chunk.tolist() for chunk in chunks],
|
| 432 |
+
"attention_mask": [torch.ones_like(chunk).tolist() for chunk in chunks],
|
| 433 |
+
"labels": [chunk.tolist() for chunk in chunks],
|
| 434 |
+
}
|
| 435 |
+
return HFDataset.from_dict(payload)
|
src/convert_llmpruner_checkpoint.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def ensure_llmpruner_on_path() -> None:
|
| 10 |
+
repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 11 |
+
llmpruner_root = os.path.join(repo_root, "compare_model", "LLM-Pruner")
|
| 12 |
+
if os.path.isdir(llmpruner_root) and llmpruner_root not in sys.path:
|
| 13 |
+
sys.path.insert(0, llmpruner_root)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def load_llmpruner_checkpoint(path: str):
|
| 17 |
+
ensure_llmpruner_on_path()
|
| 18 |
+
checkpoint = torch.load(path, map_location="cpu", weights_only=False)
|
| 19 |
+
if not isinstance(checkpoint, dict) or "model" not in checkpoint or "tokenizer" not in checkpoint:
|
| 20 |
+
raise SystemExit(
|
| 21 |
+
"Expected an LLM-Pruner checkpoint dict with `model` and `tokenizer` entries."
|
| 22 |
+
)
|
| 23 |
+
return checkpoint["model"], checkpoint["tokenizer"]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def main() -> None:
|
| 27 |
+
parser = argparse.ArgumentParser(
|
| 28 |
+
description="Convert an LLM-Pruner .bin checkpoint to a Hugging Face save_pretrained directory."
|
| 29 |
+
)
|
| 30 |
+
parser.add_argument("--input", required=True, help="Path to LLM-Pruner pytorch_model.bin")
|
| 31 |
+
parser.add_argument("--output_dir", required=True, help="Directory to write HF model artifacts")
|
| 32 |
+
args = parser.parse_args()
|
| 33 |
+
|
| 34 |
+
model, tokenizer = load_llmpruner_checkpoint(args.input)
|
| 35 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 36 |
+
model.save_pretrained(args.output_dir)
|
| 37 |
+
tokenizer.save_pretrained(args.output_dir)
|
| 38 |
+
print(args.output_dir)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
if __name__ == "__main__":
|
| 42 |
+
main()
|
src/eval_ppl.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import argparse
|
| 3 |
+
import csv
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
from typing import Iterable
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from datasets import load_dataset
|
| 12 |
+
from torch.utils.data import DataLoader, Dataset
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class IndexDataset(Dataset):
|
| 18 |
+
def __init__(self, tensors: torch.Tensor):
|
| 19 |
+
self.tensors = tensors
|
| 20 |
+
|
| 21 |
+
def __getitem__(self, index: int) -> torch.Tensor:
|
| 22 |
+
return self.tensors[index]
|
| 23 |
+
|
| 24 |
+
def __len__(self) -> int:
|
| 25 |
+
return len(self.tensors)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def get_dataset(name: str):
|
| 29 |
+
if name == "wikitext2":
|
| 30 |
+
train_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
|
| 31 |
+
test_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
|
| 32 |
+
return train_data, test_data, "text"
|
| 33 |
+
if name == "ptb":
|
| 34 |
+
train_data = load_dataset("ptb_text_only", "penn_treebank", split="train")
|
| 35 |
+
test_data = load_dataset("ptb_text_only", "penn_treebank", split="validation")
|
| 36 |
+
return train_data, test_data, "sentence"
|
| 37 |
+
raise ValueError(f"Unsupported dataset: {name}")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def process_data(samples, tokenizer, seq_len: int, field_name: str, add_bos_to_every: bool) -> IndexDataset:
|
| 41 |
+
test_ids = tokenizer(
|
| 42 |
+
"\n\n".join(samples[field_name]),
|
| 43 |
+
return_tensors="pt",
|
| 44 |
+
add_special_tokens=False,
|
| 45 |
+
).input_ids[0]
|
| 46 |
+
|
| 47 |
+
if not add_bos_to_every and tokenizer.bos_token_id is not None:
|
| 48 |
+
test_ids = torch.cat((torch.LongTensor([tokenizer.bos_token_id]), test_ids), dim=0)
|
| 49 |
+
|
| 50 |
+
batches = []
|
| 51 |
+
num_samples = test_ids.numel() // seq_len
|
| 52 |
+
for index in range(num_samples):
|
| 53 |
+
batch = test_ids[(index * seq_len) : ((index + 1) * seq_len)]
|
| 54 |
+
if add_bos_to_every and tokenizer.bos_token_id is not None:
|
| 55 |
+
batch = torch.cat((torch.LongTensor([tokenizer.bos_token_id]), batch), dim=0)
|
| 56 |
+
batches.append(batch)
|
| 57 |
+
|
| 58 |
+
return IndexDataset(tensors=torch.stack(batches))
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def get_loader(name: str, tokenizer, seq_len: int, batch_size: int, add_bos_to_every: bool):
|
| 62 |
+
_, test_data, field_name = get_dataset(name)
|
| 63 |
+
dataset = process_data(test_data, tokenizer, seq_len, field_name, add_bos_to_every)
|
| 64 |
+
return DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@torch.no_grad()
|
| 68 |
+
def evaluate_ppl(model, test_loader, device: str) -> float:
|
| 69 |
+
nlls = []
|
| 70 |
+
for batch in tqdm(test_loader, desc="Running PPL", dynamic_ncols=True):
|
| 71 |
+
batch = batch.to(device)
|
| 72 |
+
outputs = model(batch)
|
| 73 |
+
shift_logits = outputs.logits[:, :-1, :].contiguous()
|
| 74 |
+
shift_labels = batch[:, 1:].contiguous()
|
| 75 |
+
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
| 76 |
+
loss = loss_fct(
|
| 77 |
+
shift_logits.reshape(-1, shift_logits.size(-1)),
|
| 78 |
+
shift_labels.view(-1),
|
| 79 |
+
)
|
| 80 |
+
nlls.append(loss.cpu())
|
| 81 |
+
|
| 82 |
+
return float(np.exp(torch.cat(nlls, dim=-1).mean().item()))
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def resolve_dtype(args) -> torch.dtype:
|
| 86 |
+
if args.use_bfloat:
|
| 87 |
+
return torch.bfloat16
|
| 88 |
+
|
| 89 |
+
dtype_name = args.dtype if args.dtype is not None else args.torch_dtype
|
| 90 |
+
if dtype_name is None:
|
| 91 |
+
dtype_name = "float16"
|
| 92 |
+
|
| 93 |
+
dtype_map = {
|
| 94 |
+
"float16": torch.float16,
|
| 95 |
+
"fp16": torch.float16,
|
| 96 |
+
"bfloat16": torch.bfloat16,
|
| 97 |
+
"bf16": torch.bfloat16,
|
| 98 |
+
"float32": torch.float32,
|
| 99 |
+
"fp32": torch.float32,
|
| 100 |
+
}
|
| 101 |
+
if dtype_name not in dtype_map:
|
| 102 |
+
raise ValueError(f"Unsupported dtype: {dtype_name}")
|
| 103 |
+
return dtype_map[dtype_name]
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def normalize_datasets(datasets: Iterable[str]) -> list[str]:
|
| 107 |
+
normalized = []
|
| 108 |
+
for dataset in datasets:
|
| 109 |
+
normalized.append("wikitext2" if dataset == "wikitext" else dataset)
|
| 110 |
+
return normalized
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def build_arg_parser() -> argparse.ArgumentParser:
|
| 114 |
+
parser = argparse.ArgumentParser(description="Shared perplexity evaluation for abprune.")
|
| 115 |
+
parser.add_argument("--base_model", "--model-path", dest="model_path", required=True)
|
| 116 |
+
parser.add_argument("--output_dir", type=str, default=None)
|
| 117 |
+
parser.add_argument("--dataset", nargs="+", default=["wikitext2", "ptb"])
|
| 118 |
+
parser.add_argument("--max_seq_len", "--seq-len", dest="seq_len", type=int, default=1024)
|
| 119 |
+
parser.add_argument("--batch_size", type=int, default=4)
|
| 120 |
+
parser.add_argument("--device", default="cuda")
|
| 121 |
+
parser.add_argument(
|
| 122 |
+
"--dtype",
|
| 123 |
+
default=None,
|
| 124 |
+
choices=["float16", "fp16", "bfloat16", "bf16", "float32", "fp32"],
|
| 125 |
+
)
|
| 126 |
+
parser.add_argument(
|
| 127 |
+
"--torch_dtype",
|
| 128 |
+
default=None,
|
| 129 |
+
choices=["float16", "fp16", "bfloat16", "bf16", "float32", "fp32"],
|
| 130 |
+
)
|
| 131 |
+
parser.add_argument("--use_bfloat", action="store_true")
|
| 132 |
+
parser.add_argument("--add_bos_to_every", action="store_true")
|
| 133 |
+
parser.add_argument("--fix_decapoda_config", action="store_true")
|
| 134 |
+
parser.add_argument("--local_files_only", action="store_true")
|
| 135 |
+
return parser
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def maybe_fix_decapoda_config(tokenizer, enabled: bool) -> None:
|
| 139 |
+
if not enabled:
|
| 140 |
+
return
|
| 141 |
+
if tokenizer.bos_token_id is None and tokenizer.eos_token_id is not None:
|
| 142 |
+
tokenizer.bos_token = tokenizer.eos_token
|
| 143 |
+
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
|
| 144 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def ensure_llmpruner_on_path() -> None:
|
| 148 |
+
repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 149 |
+
llmpruner_root = os.path.join(repo_root, "compare_model", "LLM-Pruner")
|
| 150 |
+
if os.path.isdir(llmpruner_root) and llmpruner_root not in sys.path:
|
| 151 |
+
sys.path.insert(0, llmpruner_root)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def load_model_and_tokenizer(model_path: str, *, torch_dtype: torch.dtype, local_files_only: bool):
|
| 155 |
+
if os.path.isfile(model_path) and model_path.endswith(".bin"):
|
| 156 |
+
ensure_llmpruner_on_path()
|
| 157 |
+
checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
|
| 158 |
+
if not isinstance(checkpoint, dict) or "model" not in checkpoint or "tokenizer" not in checkpoint:
|
| 159 |
+
raise ValueError(
|
| 160 |
+
"Expected an LLM-Pruner checkpoint dict with `model` and `tokenizer` entries."
|
| 161 |
+
)
|
| 162 |
+
model = checkpoint["model"]
|
| 163 |
+
tokenizer = checkpoint["tokenizer"]
|
| 164 |
+
if torch_dtype is not None:
|
| 165 |
+
model = model.to(dtype=torch_dtype)
|
| 166 |
+
return model, tokenizer
|
| 167 |
+
|
| 168 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 169 |
+
model_path,
|
| 170 |
+
local_files_only=local_files_only,
|
| 171 |
+
)
|
| 172 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 173 |
+
model_path,
|
| 174 |
+
torch_dtype=torch_dtype,
|
| 175 |
+
local_files_only=local_files_only,
|
| 176 |
+
)
|
| 177 |
+
return model, tokenizer
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def main() -> None:
|
| 181 |
+
parser = build_arg_parser()
|
| 182 |
+
args = parser.parse_args()
|
| 183 |
+
|
| 184 |
+
datasets = normalize_datasets(args.dataset)
|
| 185 |
+
torch_dtype = resolve_dtype(args)
|
| 186 |
+
|
| 187 |
+
model, tokenizer = load_model_and_tokenizer(
|
| 188 |
+
args.model_path,
|
| 189 |
+
torch_dtype=torch_dtype,
|
| 190 |
+
local_files_only=args.local_files_only,
|
| 191 |
+
)
|
| 192 |
+
maybe_fix_decapoda_config(tokenizer, args.fix_decapoda_config)
|
| 193 |
+
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
|
| 194 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 195 |
+
|
| 196 |
+
model.eval()
|
| 197 |
+
model.to(args.device)
|
| 198 |
+
|
| 199 |
+
metrics = {}
|
| 200 |
+
for dataset in datasets:
|
| 201 |
+
test_loader = get_loader(
|
| 202 |
+
dataset,
|
| 203 |
+
tokenizer,
|
| 204 |
+
seq_len=args.seq_len,
|
| 205 |
+
batch_size=args.batch_size,
|
| 206 |
+
add_bos_to_every=args.add_bos_to_every,
|
| 207 |
+
)
|
| 208 |
+
metrics[dataset] = evaluate_ppl(model, test_loader, args.device)
|
| 209 |
+
print(f"PPL-{dataset}: {metrics[dataset]} | add_bos_to_every: {args.add_bos_to_every} | seq_len: {args.seq_len}")
|
| 210 |
+
|
| 211 |
+
mem = None
|
| 212 |
+
if torch.cuda.is_available() and args.device.startswith("cuda"):
|
| 213 |
+
mem = torch.cuda.memory_allocated(args.device) / 1024 / 1024
|
| 214 |
+
|
| 215 |
+
result = {
|
| 216 |
+
"model_path": os.path.abspath(args.model_path),
|
| 217 |
+
"datasets": datasets,
|
| 218 |
+
"seq_len": args.seq_len,
|
| 219 |
+
"batch_size": args.batch_size,
|
| 220 |
+
"device": args.device,
|
| 221 |
+
"dtype": str(torch_dtype).replace("torch.", ""),
|
| 222 |
+
"add_bos_to_every": args.add_bos_to_every,
|
| 223 |
+
"metrics": metrics,
|
| 224 |
+
"params": int(sum(parameter.numel() for parameter in model.parameters())),
|
| 225 |
+
"mem_mib": mem,
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
if args.output_dir is not None:
|
| 229 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 230 |
+
filename = "ppl_bos.csv" if args.add_bos_to_every else "ppl.csv"
|
| 231 |
+
csv_path = os.path.join(args.output_dir, filename)
|
| 232 |
+
with open(csv_path, "w", newline="", encoding="utf-8") as handle:
|
| 233 |
+
writer = csv.writer(handle)
|
| 234 |
+
writer.writerow([*(f"ppl_{dataset}" for dataset in datasets), "params", "mem"])
|
| 235 |
+
writer.writerow([*(metrics[dataset] for dataset in datasets), result["params"], mem])
|
| 236 |
+
|
| 237 |
+
print(json.dumps(result, ensure_ascii=True))
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
if __name__ == "__main__":
|
| 241 |
+
main()
|
src/fbmc_metric.py
ADDED
|
@@ -0,0 +1,519 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Estimate Fisher-Barycentric Merge Cost (FBMC) for adjacent layers."""
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import csv
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from typing import Dict, List, Optional, Tuple
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
from datasets import load_dataset
|
| 14 |
+
except Exception: # pragma: no cover - optional dependency
|
| 15 |
+
load_dataset = None
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 19 |
+
except Exception as exc: # pragma: no cover - fail early with clear error
|
| 20 |
+
raise SystemExit("transformers is required: pip install transformers") from exc
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def parse_args() -> argparse.Namespace:
|
| 24 |
+
parser = argparse.ArgumentParser(
|
| 25 |
+
description="Compute FBMC for adjacent layers of a Hugging Face causal LM."
|
| 26 |
+
)
|
| 27 |
+
parser.add_argument("--model", required=True, help="HF model id or local path")
|
| 28 |
+
parser.add_argument(
|
| 29 |
+
"--dataset",
|
| 30 |
+
action="append",
|
| 31 |
+
default=[],
|
| 32 |
+
help=(
|
| 33 |
+
"HF dataset name (repeatable). Optional if using --text or --text_file."
|
| 34 |
+
),
|
| 35 |
+
)
|
| 36 |
+
parser.add_argument(
|
| 37 |
+
"--dataset_config",
|
| 38 |
+
action="append",
|
| 39 |
+
default=[],
|
| 40 |
+
help="Optional dataset config (repeatable or single shared config).",
|
| 41 |
+
)
|
| 42 |
+
parser.add_argument(
|
| 43 |
+
"--dataset_split",
|
| 44 |
+
default="train",
|
| 45 |
+
help="Dataset split to use (default: train)",
|
| 46 |
+
)
|
| 47 |
+
parser.add_argument(
|
| 48 |
+
"--dataset_text_field",
|
| 49 |
+
default=None,
|
| 50 |
+
help="Text field in dataset (default: auto-detect, applies to all datasets)",
|
| 51 |
+
)
|
| 52 |
+
parser.add_argument(
|
| 53 |
+
"--text",
|
| 54 |
+
action="append",
|
| 55 |
+
default=[],
|
| 56 |
+
help="Inline text samples (can pass multiple)",
|
| 57 |
+
)
|
| 58 |
+
parser.add_argument(
|
| 59 |
+
"--text_file",
|
| 60 |
+
default=None,
|
| 61 |
+
help="Path to a text file for calibration data",
|
| 62 |
+
)
|
| 63 |
+
parser.add_argument(
|
| 64 |
+
"--num_samples",
|
| 65 |
+
type=int,
|
| 66 |
+
default=128,
|
| 67 |
+
help="Number of token sequences to use",
|
| 68 |
+
)
|
| 69 |
+
parser.add_argument(
|
| 70 |
+
"--seq_len", type=int, default=256, help="Sequence length"
|
| 71 |
+
)
|
| 72 |
+
parser.add_argument(
|
| 73 |
+
"--batch_size", type=int, default=2, help="Batch size"
|
| 74 |
+
)
|
| 75 |
+
parser.add_argument(
|
| 76 |
+
"--device",
|
| 77 |
+
default="cuda" if torch.cuda.is_available() else "cpu",
|
| 78 |
+
help="Device for model + compute",
|
| 79 |
+
)
|
| 80 |
+
parser.add_argument(
|
| 81 |
+
"--dtype",
|
| 82 |
+
default="auto",
|
| 83 |
+
choices=["auto", "float32", "float16", "bfloat16"],
|
| 84 |
+
help="Model dtype",
|
| 85 |
+
)
|
| 86 |
+
parser.add_argument(
|
| 87 |
+
"--layer_path",
|
| 88 |
+
default=None,
|
| 89 |
+
help="Override layer attribute path (e.g., model.layers)",
|
| 90 |
+
)
|
| 91 |
+
parser.add_argument(
|
| 92 |
+
"--fisher_mode",
|
| 93 |
+
default="tensor",
|
| 94 |
+
choices=["tensor", "param"],
|
| 95 |
+
help="Fisher approximation granularity",
|
| 96 |
+
)
|
| 97 |
+
parser.add_argument("--eps", type=float, default=1e-8, help="Stability epsilon")
|
| 98 |
+
parser.add_argument(
|
| 99 |
+
"--output",
|
| 100 |
+
default=None,
|
| 101 |
+
help="Optional JSON output path",
|
| 102 |
+
)
|
| 103 |
+
parser.add_argument(
|
| 104 |
+
"--output_csv",
|
| 105 |
+
default=None,
|
| 106 |
+
help="Optional CSV output path",
|
| 107 |
+
)
|
| 108 |
+
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
| 109 |
+
parser.add_argument(
|
| 110 |
+
"--trust_remote_code",
|
| 111 |
+
action="store_true",
|
| 112 |
+
help="Allow custom model code from hub",
|
| 113 |
+
)
|
| 114 |
+
return parser.parse_args()
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def resolve_attr(root: object, path: str) -> Optional[object]:
|
| 118 |
+
cur = root
|
| 119 |
+
for part in path.split("."):
|
| 120 |
+
if not hasattr(cur, part):
|
| 121 |
+
return None
|
| 122 |
+
cur = getattr(cur, part)
|
| 123 |
+
return cur
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def find_layers(model, layer_path: Optional[str]) -> List[torch.nn.Module]:
|
| 127 |
+
if layer_path:
|
| 128 |
+
layers = resolve_attr(model, layer_path)
|
| 129 |
+
if layers is None:
|
| 130 |
+
raise ValueError(f"layer_path '{layer_path}' not found on model")
|
| 131 |
+
return list(layers)
|
| 132 |
+
|
| 133 |
+
# Common decoder-only layer containers. Add more if needed.
|
| 134 |
+
candidate_paths = [
|
| 135 |
+
"model.layers", # LLaMA, Mistral, Qwen2, Gemma
|
| 136 |
+
"model.decoder.layers", # OPT
|
| 137 |
+
"transformer.h", # GPT-2, GPT-J, Bloom, Falcon
|
| 138 |
+
"transformer.blocks", # MPT
|
| 139 |
+
"gpt_neox.layers", # GPT-NeoX
|
| 140 |
+
"layers", # fallback
|
| 141 |
+
]
|
| 142 |
+
for path in candidate_paths:
|
| 143 |
+
layers = resolve_attr(model, path)
|
| 144 |
+
if layers is not None:
|
| 145 |
+
try:
|
| 146 |
+
return list(layers)
|
| 147 |
+
except TypeError:
|
| 148 |
+
continue
|
| 149 |
+
raise ValueError(
|
| 150 |
+
"Could not locate transformer layers. Pass --layer_path explicitly."
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def guess_text_field(dataset) -> str:
|
| 155 |
+
if hasattr(dataset, "column_names") and dataset.column_names:
|
| 156 |
+
if "text" in dataset.column_names:
|
| 157 |
+
return "text"
|
| 158 |
+
return dataset.column_names[0]
|
| 159 |
+
if hasattr(dataset, "features"):
|
| 160 |
+
names = list(dataset.features.keys())
|
| 161 |
+
if "text" in names:
|
| 162 |
+
return "text"
|
| 163 |
+
if names:
|
| 164 |
+
return names[0]
|
| 165 |
+
return "text"
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def _normalize_config(config: Optional[str]) -> Optional[str]:
|
| 169 |
+
if config is None:
|
| 170 |
+
return None
|
| 171 |
+
if config.strip().lower() in {"none", "null", "-"}:
|
| 172 |
+
return None
|
| 173 |
+
return config
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _expand_dataset_configs(
|
| 177 |
+
datasets: List[str], configs: List[str]
|
| 178 |
+
) -> List[Optional[str]]:
|
| 179 |
+
if not configs:
|
| 180 |
+
return [None] * len(datasets)
|
| 181 |
+
if len(configs) == 1 and len(datasets) > 1:
|
| 182 |
+
return [_normalize_config(configs[0])] * len(datasets)
|
| 183 |
+
if len(configs) != len(datasets):
|
| 184 |
+
raise SystemExit(
|
| 185 |
+
"Provide zero, one, or matching-count --dataset_config values."
|
| 186 |
+
)
|
| 187 |
+
return [_normalize_config(cfg) for cfg in configs]
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def _sample_dataset_rows(
|
| 191 |
+
dataset, target: int, seed: int
|
| 192 |
+
) -> List[Dict[str, object]]:
|
| 193 |
+
if target <= 0:
|
| 194 |
+
return []
|
| 195 |
+
try:
|
| 196 |
+
dataset = dataset.shuffle(seed=seed)
|
| 197 |
+
except Exception:
|
| 198 |
+
pass
|
| 199 |
+
|
| 200 |
+
if hasattr(dataset, "__len__"):
|
| 201 |
+
limit = min(target, len(dataset))
|
| 202 |
+
dataset = dataset.select(range(limit))
|
| 203 |
+
return [row for row in dataset]
|
| 204 |
+
|
| 205 |
+
# IterableDataset fallback.
|
| 206 |
+
rows = []
|
| 207 |
+
for row in dataset:
|
| 208 |
+
rows.append(row)
|
| 209 |
+
if len(rows) >= target:
|
| 210 |
+
break
|
| 211 |
+
return rows
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def load_texts(args: argparse.Namespace) -> List[str]:
|
| 215 |
+
texts: List[str] = []
|
| 216 |
+
if args.text_file:
|
| 217 |
+
with open(args.text_file, "r", encoding="utf-8") as handle:
|
| 218 |
+
texts.extend([line.strip() for line in handle if line.strip()])
|
| 219 |
+
if args.text:
|
| 220 |
+
texts.extend([t for t in args.text if t])
|
| 221 |
+
|
| 222 |
+
if args.dataset:
|
| 223 |
+
if load_dataset is None:
|
| 224 |
+
raise SystemExit("datasets is required for --dataset")
|
| 225 |
+
|
| 226 |
+
datasets = list(args.dataset)
|
| 227 |
+
configs = _expand_dataset_configs(datasets, list(args.dataset_config))
|
| 228 |
+
num_datasets = len(datasets)
|
| 229 |
+
base = args.num_samples // num_datasets
|
| 230 |
+
remainder = args.num_samples % num_datasets
|
| 231 |
+
|
| 232 |
+
for idx, (dataset_name, config) in enumerate(zip(datasets, configs)):
|
| 233 |
+
target = base + (1 if idx < remainder else 0)
|
| 234 |
+
dataset = load_dataset(
|
| 235 |
+
dataset_name,
|
| 236 |
+
config,
|
| 237 |
+
split=args.dataset_split,
|
| 238 |
+
trust_remote_code=True,
|
| 239 |
+
)
|
| 240 |
+
rows = _sample_dataset_rows(dataset, target, args.seed + idx)
|
| 241 |
+
text_field = args.dataset_text_field or guess_text_field(dataset)
|
| 242 |
+
for row in rows:
|
| 243 |
+
value = row.get(text_field, None) if isinstance(row, dict) else None
|
| 244 |
+
if isinstance(value, str) and value.strip():
|
| 245 |
+
texts.append(value)
|
| 246 |
+
|
| 247 |
+
return texts
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def build_token_chunks(
|
| 251 |
+
texts: List[str], tokenizer, seq_len: int, num_samples: int
|
| 252 |
+
) -> List[torch.Tensor]:
|
| 253 |
+
chunks: List[torch.Tensor] = []
|
| 254 |
+
buffer: List[int] = []
|
| 255 |
+
for text in texts:
|
| 256 |
+
ids = tokenizer.encode(text, add_special_tokens=False)
|
| 257 |
+
if not ids:
|
| 258 |
+
continue
|
| 259 |
+
buffer.extend(ids)
|
| 260 |
+
while len(buffer) >= seq_len and len(chunks) < num_samples:
|
| 261 |
+
chunk = buffer[:seq_len]
|
| 262 |
+
buffer = buffer[seq_len:]
|
| 263 |
+
chunks.append(torch.tensor(chunk, dtype=torch.long))
|
| 264 |
+
if len(chunks) >= num_samples:
|
| 265 |
+
break
|
| 266 |
+
return chunks
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def get_dtype(dtype: str):
|
| 270 |
+
if dtype == "auto":
|
| 271 |
+
return None
|
| 272 |
+
if dtype == "float16":
|
| 273 |
+
return torch.float16
|
| 274 |
+
if dtype == "bfloat16":
|
| 275 |
+
return torch.bfloat16
|
| 276 |
+
return torch.float32
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def compute_fisher(
|
| 280 |
+
model,
|
| 281 |
+
layers: List[torch.nn.Module],
|
| 282 |
+
dataloader,
|
| 283 |
+
fisher_mode: str,
|
| 284 |
+
device: str,
|
| 285 |
+
) -> Tuple[List[Dict[str, object]], int, List[Dict[str, int]]]:
|
| 286 |
+
# Only compute grads for layer params.
|
| 287 |
+
for param in model.parameters():
|
| 288 |
+
param.requires_grad_(False)
|
| 289 |
+
for layer in layers:
|
| 290 |
+
for param in layer.parameters():
|
| 291 |
+
param.requires_grad_(True)
|
| 292 |
+
|
| 293 |
+
fisher_sums: List[Dict[str, object]] = []
|
| 294 |
+
param_numels: List[Dict[str, int]] = []
|
| 295 |
+
for layer in layers:
|
| 296 |
+
layer_sums: Dict[str, object] = {}
|
| 297 |
+
layer_numels: Dict[str, int] = {}
|
| 298 |
+
for name, param in layer.named_parameters():
|
| 299 |
+
if not param.requires_grad:
|
| 300 |
+
continue
|
| 301 |
+
if fisher_mode == "param":
|
| 302 |
+
layer_sums[name] = torch.zeros_like(
|
| 303 |
+
param, dtype=torch.float32, device="cpu"
|
| 304 |
+
)
|
| 305 |
+
else:
|
| 306 |
+
layer_sums[name] = 0.0
|
| 307 |
+
layer_numels[name] = param.numel()
|
| 308 |
+
fisher_sums.append(layer_sums)
|
| 309 |
+
param_numels.append(layer_numels)
|
| 310 |
+
|
| 311 |
+
num_batches = 0
|
| 312 |
+
model.eval()
|
| 313 |
+
for batch in dataloader:
|
| 314 |
+
input_ids = batch[0].to(device)
|
| 315 |
+
outputs = model(input_ids=input_ids, labels=input_ids)
|
| 316 |
+
loss = outputs.loss
|
| 317 |
+
loss.backward()
|
| 318 |
+
for layer_idx, layer in enumerate(layers):
|
| 319 |
+
layer_sums = fisher_sums[layer_idx]
|
| 320 |
+
for name, param in layer.named_parameters():
|
| 321 |
+
if not param.requires_grad:
|
| 322 |
+
continue
|
| 323 |
+
if param.grad is None:
|
| 324 |
+
continue
|
| 325 |
+
grad_sq = param.grad.detach().float().pow(2)
|
| 326 |
+
if fisher_mode == "param":
|
| 327 |
+
layer_sums[name] += grad_sq.cpu()
|
| 328 |
+
else:
|
| 329 |
+
layer_sums[name] += float(grad_sq.sum().item())
|
| 330 |
+
model.zero_grad(set_to_none=True)
|
| 331 |
+
num_batches += 1
|
| 332 |
+
|
| 333 |
+
if num_batches == 0:
|
| 334 |
+
raise RuntimeError("No batches processed; check dataset or text inputs.")
|
| 335 |
+
|
| 336 |
+
return fisher_sums, num_batches, param_numels
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def compute_fbmc_costs(
|
| 340 |
+
layers: List[torch.nn.Module],
|
| 341 |
+
fisher_sums: List[Dict[str, object]],
|
| 342 |
+
num_batches: int,
|
| 343 |
+
param_numels: List[Dict[str, int]],
|
| 344 |
+
fisher_mode: str,
|
| 345 |
+
eps: float,
|
| 346 |
+
) -> List[Dict[str, object]]:
|
| 347 |
+
layer_params: List[Dict[str, torch.nn.Parameter]] = []
|
| 348 |
+
for layer in layers:
|
| 349 |
+
layer_params.append({name: param for name, param in layer.named_parameters()})
|
| 350 |
+
|
| 351 |
+
results: List[Dict[str, object]] = []
|
| 352 |
+
for idx in range(len(layers) - 1):
|
| 353 |
+
cost = 0.0
|
| 354 |
+
matched = 0
|
| 355 |
+
skipped = 0
|
| 356 |
+
params_i = layer_params[idx]
|
| 357 |
+
params_j = layer_params[idx + 1]
|
| 358 |
+
for name, param_i in params_i.items():
|
| 359 |
+
param_j = params_j.get(name)
|
| 360 |
+
if param_j is None or param_j.shape != param_i.shape:
|
| 361 |
+
skipped += 1
|
| 362 |
+
continue
|
| 363 |
+
matched += 1
|
| 364 |
+
if fisher_mode == "param":
|
| 365 |
+
fisher_i = fisher_sums[idx][name] / num_batches
|
| 366 |
+
fisher_j = fisher_sums[idx + 1][name] / num_batches
|
| 367 |
+
diff = (param_i.detach().float().cpu() - param_j.detach().float().cpu())
|
| 368 |
+
denom = fisher_i + fisher_j + eps
|
| 369 |
+
term = (fisher_i * fisher_j / denom) * diff * diff
|
| 370 |
+
cost += 0.5 * float(term.sum().item())
|
| 371 |
+
else:
|
| 372 |
+
fisher_i = fisher_sums[idx][name] / (
|
| 373 |
+
num_batches * param_numels[idx][name]
|
| 374 |
+
)
|
| 375 |
+
fisher_j = fisher_sums[idx + 1][name] / (
|
| 376 |
+
num_batches * param_numels[idx + 1][name]
|
| 377 |
+
)
|
| 378 |
+
denom = fisher_i + fisher_j + eps
|
| 379 |
+
if denom == 0:
|
| 380 |
+
continue
|
| 381 |
+
diff_sq = (
|
| 382 |
+
param_i.detach().float() - param_j.detach().float()
|
| 383 |
+
).pow(2)
|
| 384 |
+
cost += 0.5 * (fisher_i * fisher_j / denom) * float(
|
| 385 |
+
diff_sq.sum().item()
|
| 386 |
+
)
|
| 387 |
+
results.append(
|
| 388 |
+
{
|
| 389 |
+
"layer_i": idx,
|
| 390 |
+
"layer_j": idx + 1,
|
| 391 |
+
"fbmc": cost,
|
| 392 |
+
"matched_params": matched,
|
| 393 |
+
"skipped_params": skipped,
|
| 394 |
+
}
|
| 395 |
+
)
|
| 396 |
+
return results
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def main() -> None:
|
| 400 |
+
args = parse_args()
|
| 401 |
+
torch.manual_seed(args.seed)
|
| 402 |
+
|
| 403 |
+
dtype = get_dtype(args.dtype)
|
| 404 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 405 |
+
args.model,
|
| 406 |
+
torch_dtype=dtype,
|
| 407 |
+
trust_remote_code=args.trust_remote_code,
|
| 408 |
+
)
|
| 409 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 410 |
+
args.model, trust_remote_code=args.trust_remote_code
|
| 411 |
+
)
|
| 412 |
+
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
|
| 413 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 414 |
+
|
| 415 |
+
layers = find_layers(model, args.layer_path)
|
| 416 |
+
if len(layers) < 2:
|
| 417 |
+
raise SystemExit("Model has fewer than 2 layers; cannot compute FBMC.")
|
| 418 |
+
|
| 419 |
+
texts = load_texts(args)
|
| 420 |
+
if not texts:
|
| 421 |
+
raise SystemExit(
|
| 422 |
+
"No calibration text found. Provide --dataset, --text, or --text_file."
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
chunks = build_token_chunks(texts, tokenizer, args.seq_len, args.num_samples)
|
| 426 |
+
if not chunks:
|
| 427 |
+
raise SystemExit("Not enough text to build token sequences.")
|
| 428 |
+
|
| 429 |
+
dataset = torch.utils.data.TensorDataset(torch.stack(chunks))
|
| 430 |
+
dataloader = torch.utils.data.DataLoader(
|
| 431 |
+
dataset, batch_size=args.batch_size, shuffle=False
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
model.to(args.device)
|
| 435 |
+
|
| 436 |
+
fisher_sums, num_batches, param_numels = compute_fisher(
|
| 437 |
+
model,
|
| 438 |
+
layers,
|
| 439 |
+
dataloader,
|
| 440 |
+
fisher_mode=args.fisher_mode,
|
| 441 |
+
device=args.device,
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
costs = compute_fbmc_costs(
|
| 445 |
+
layers,
|
| 446 |
+
fisher_sums,
|
| 447 |
+
num_batches,
|
| 448 |
+
param_numels,
|
| 449 |
+
fisher_mode=args.fisher_mode,
|
| 450 |
+
eps=args.eps,
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
costs_sorted = sorted(costs, key=lambda x: x["fbmc"])
|
| 454 |
+
best = costs_sorted[0]
|
| 455 |
+
|
| 456 |
+
print("FBMC results (layer order):")
|
| 457 |
+
for item in costs:
|
| 458 |
+
print(
|
| 459 |
+
f"layers {item['layer_i']} & {item['layer_j']} -> "
|
| 460 |
+
f"fbmc={item['fbmc']:.6e} "
|
| 461 |
+
f"(matched={item['matched_params']}, skipped={item['skipped_params']})"
|
| 462 |
+
)
|
| 463 |
+
print("\nFBMC results (lowest cost first):")
|
| 464 |
+
for item in costs_sorted:
|
| 465 |
+
print(
|
| 466 |
+
f"layers {item['layer_i']} & {item['layer_j']} -> "
|
| 467 |
+
f"fbmc={item['fbmc']:.6e} "
|
| 468 |
+
f"(matched={item['matched_params']}, skipped={item['skipped_params']})"
|
| 469 |
+
)
|
| 470 |
+
print(
|
| 471 |
+
f"\nBest pair: layers {best['layer_i']} & {best['layer_j']} "
|
| 472 |
+
f"(fbmc={best['fbmc']:.6e})"
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
if args.output:
|
| 476 |
+
payload = {
|
| 477 |
+
"model": args.model,
|
| 478 |
+
"num_layers": len(layers),
|
| 479 |
+
"fisher_mode": args.fisher_mode,
|
| 480 |
+
"num_batches": num_batches,
|
| 481 |
+
"num_sequences": len(chunks),
|
| 482 |
+
"seq_len": args.seq_len,
|
| 483 |
+
"best_pair": best,
|
| 484 |
+
"pairs": costs_sorted,
|
| 485 |
+
}
|
| 486 |
+
os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
|
| 487 |
+
with open(args.output, "w", encoding="utf-8") as handle:
|
| 488 |
+
json.dump(payload, handle, indent=2)
|
| 489 |
+
print(f"\nWrote results to {args.output}")
|
| 490 |
+
|
| 491 |
+
if args.output_csv:
|
| 492 |
+
os.makedirs(os.path.dirname(args.output_csv) or ".", exist_ok=True)
|
| 493 |
+
with open(args.output_csv, "w", encoding="utf-8", newline="") as handle:
|
| 494 |
+
writer = csv.DictWriter(
|
| 495 |
+
handle,
|
| 496 |
+
fieldnames=[
|
| 497 |
+
"layer_i",
|
| 498 |
+
"layer_j",
|
| 499 |
+
"fbmc",
|
| 500 |
+
"matched_params",
|
| 501 |
+
"skipped_params",
|
| 502 |
+
],
|
| 503 |
+
)
|
| 504 |
+
writer.writeheader()
|
| 505 |
+
for item in costs_sorted:
|
| 506 |
+
writer.writerow(
|
| 507 |
+
{
|
| 508 |
+
"layer_i": item["layer_i"],
|
| 509 |
+
"layer_j": item["layer_j"],
|
| 510 |
+
"fbmc": item["fbmc"],
|
| 511 |
+
"matched_params": item["matched_params"],
|
| 512 |
+
"skipped_params": item["skipped_params"],
|
| 513 |
+
}
|
| 514 |
+
)
|
| 515 |
+
print(f"Wrote CSV results to {args.output_csv}")
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
if __name__ == "__main__":
|
| 519 |
+
main()
|
src/fuse_layers.py
ADDED
|
@@ -0,0 +1,2416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Fuse adjacent layers via attention head alignment + Fisher-barycentric merge."""
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import copy
|
| 6 |
+
import gc
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
import random
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
from typing import Dict, List, Optional, Set, Tuple
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
import numpy as np
|
| 17 |
+
except Exception: # pragma: no cover - optional dependency
|
| 18 |
+
np = None
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 22 |
+
except Exception as exc: # pragma: no cover - fail early with clear error
|
| 23 |
+
raise SystemExit("transformers is required: pip install transformers") from exc
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
import ppl_eval
|
| 27 |
+
except Exception as exc: # pragma: no cover - optional dependency
|
| 28 |
+
raise SystemExit("ppl_eval.py is required (missing or invalid)") from exc
|
| 29 |
+
|
| 30 |
+
from fuse_layers_data import (
|
| 31 |
+
FixedSeqDataset,
|
| 32 |
+
build_token_chunks,
|
| 33 |
+
expand_dataset_configs,
|
| 34 |
+
load_instruction_records,
|
| 35 |
+
load_texts,
|
| 36 |
+
load_texts_from_datasets,
|
| 37 |
+
)
|
| 38 |
+
from common_lm_data import SharedLMDataSpec, build_chunks, build_dataloader
|
| 39 |
+
from fuse_layers_distill import (
|
| 40 |
+
commutator_precondition,
|
| 41 |
+
compute_fisher_gate_priors,
|
| 42 |
+
distill_reparam_merge,
|
| 43 |
+
lora_ce_finetune,
|
| 44 |
+
)
|
| 45 |
+
from fuse_layers_model import (
|
| 46 |
+
apply_norm_policy,
|
| 47 |
+
build_head_permutation,
|
| 48 |
+
clone_state_dict,
|
| 49 |
+
compute_fisher,
|
| 50 |
+
compute_head_means,
|
| 51 |
+
decrement_config,
|
| 52 |
+
drop_layer,
|
| 53 |
+
find_attention_module,
|
| 54 |
+
find_colon_modules,
|
| 55 |
+
find_layer_container,
|
| 56 |
+
get_dtype,
|
| 57 |
+
get_norm_pair,
|
| 58 |
+
merge_layers,
|
| 59 |
+
permute_attention_heads,
|
| 60 |
+
)
|
| 61 |
+
from fuse_layers_select import select_layer_auto
|
| 62 |
+
from progressive_loader import load_causal_lm, load_progressive_model
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def parse_args() -> argparse.Namespace:
|
| 66 |
+
parser = argparse.ArgumentParser(
|
| 67 |
+
description="Fuse layer i and i+1 using head alignment + Fisher barycenter."
|
| 68 |
+
)
|
| 69 |
+
parser.add_argument("--model", required=True, help="HF model id or local path")
|
| 70 |
+
parser.add_argument(
|
| 71 |
+
"--model_cache_dir",
|
| 72 |
+
default=None,
|
| 73 |
+
help="Optional cache dir for model/tokenizer downloads",
|
| 74 |
+
)
|
| 75 |
+
parser.add_argument(
|
| 76 |
+
"--layer",
|
| 77 |
+
type=str,
|
| 78 |
+
default="auto",
|
| 79 |
+
help="Layer index i (int) or 'auto' to select via auto metric",
|
| 80 |
+
)
|
| 81 |
+
parser.add_argument(
|
| 82 |
+
"--selection_method",
|
| 83 |
+
choices=["dwce", "sequential"],
|
| 84 |
+
default="dwce",
|
| 85 |
+
help=(
|
| 86 |
+
"Pair selection policy for progressive pruning. "
|
| 87 |
+
"'dwce' uses downstream-weighted composition error; "
|
| 88 |
+
"'sequential' always takes the next available pair."
|
| 89 |
+
),
|
| 90 |
+
)
|
| 91 |
+
parser.add_argument(
|
| 92 |
+
"--exclude_pairs",
|
| 93 |
+
"--exclude_layers",
|
| 94 |
+
nargs="*",
|
| 95 |
+
default=None,
|
| 96 |
+
dest="exclude_pairs",
|
| 97 |
+
help=(
|
| 98 |
+
"Exclude pair indices from consideration for any fusion. Indices refer to "
|
| 99 |
+
"pair start positions in [0..N-2]. Negative indices count from the end "
|
| 100 |
+
"(-1 = last pair, -2 = second last). Accepts space- or comma-separated ints. "
|
| 101 |
+
"Alias: --exclude_layers (deprecated)."
|
| 102 |
+
),
|
| 103 |
+
)
|
| 104 |
+
parser.add_argument(
|
| 105 |
+
"--output_dir", required=True, help="Directory to write fused model"
|
| 106 |
+
)
|
| 107 |
+
parser.add_argument(
|
| 108 |
+
"--dataset",
|
| 109 |
+
action="append",
|
| 110 |
+
default=[],
|
| 111 |
+
help=(
|
| 112 |
+
"HF dataset name (repeatable). Optional if using --text or --text_file."
|
| 113 |
+
),
|
| 114 |
+
)
|
| 115 |
+
parser.add_argument(
|
| 116 |
+
"--dataset_config",
|
| 117 |
+
action="append",
|
| 118 |
+
default=[],
|
| 119 |
+
help="Optional dataset config (repeatable or single shared config).",
|
| 120 |
+
)
|
| 121 |
+
parser.add_argument(
|
| 122 |
+
"--dataset_split",
|
| 123 |
+
default="train",
|
| 124 |
+
help="Dataset split to use (default: train)",
|
| 125 |
+
)
|
| 126 |
+
parser.add_argument(
|
| 127 |
+
"--dataset_text_field",
|
| 128 |
+
default=None,
|
| 129 |
+
help="Text field in dataset (default: auto-detect, applies to all datasets)",
|
| 130 |
+
)
|
| 131 |
+
parser.add_argument(
|
| 132 |
+
"--text",
|
| 133 |
+
action="append",
|
| 134 |
+
default=[],
|
| 135 |
+
help="Inline text samples (can pass multiple)",
|
| 136 |
+
)
|
| 137 |
+
parser.add_argument(
|
| 138 |
+
"--text_file",
|
| 139 |
+
default=None,
|
| 140 |
+
help="Path to a text file for calibration data",
|
| 141 |
+
)
|
| 142 |
+
parser.add_argument(
|
| 143 |
+
"--num_samples",
|
| 144 |
+
type=int,
|
| 145 |
+
default=128,
|
| 146 |
+
help="Number of token sequences to use",
|
| 147 |
+
)
|
| 148 |
+
parser.add_argument(
|
| 149 |
+
"--target_tokens",
|
| 150 |
+
type=int,
|
| 151 |
+
default=0,
|
| 152 |
+
help="Target token budget for common_lm_data-backed calibration/distillation (0 = disabled)",
|
| 153 |
+
)
|
| 154 |
+
parser.add_argument("--seq_len", type=int, default=256, help="Sequence length")
|
| 155 |
+
parser.add_argument("--batch_size", type=int, default=2, help="Batch size")
|
| 156 |
+
parser.add_argument(
|
| 157 |
+
"--device",
|
| 158 |
+
default="cuda" if torch.cuda.is_available() else "cpu",
|
| 159 |
+
help="Device for model + compute",
|
| 160 |
+
)
|
| 161 |
+
parser.add_argument(
|
| 162 |
+
"--dtype",
|
| 163 |
+
default="auto",
|
| 164 |
+
choices=["auto", "float32", "float16", "bfloat16"],
|
| 165 |
+
help="Model dtype",
|
| 166 |
+
)
|
| 167 |
+
parser.add_argument(
|
| 168 |
+
"--layer_path",
|
| 169 |
+
default=None,
|
| 170 |
+
help="Override layer attribute path (e.g., model.layers)",
|
| 171 |
+
)
|
| 172 |
+
parser.add_argument(
|
| 173 |
+
"--fisher_mode",
|
| 174 |
+
default="tensor",
|
| 175 |
+
choices=["tensor", "param"],
|
| 176 |
+
help="Fisher approximation granularity",
|
| 177 |
+
)
|
| 178 |
+
parser.add_argument(
|
| 179 |
+
"--no_head_permute",
|
| 180 |
+
action="store_true",
|
| 181 |
+
help=(
|
| 182 |
+
"Deprecated alias for --no_head_permute_merge. "
|
| 183 |
+
"Disables merge-stage head permutation only."
|
| 184 |
+
),
|
| 185 |
+
)
|
| 186 |
+
parser.add_argument(
|
| 187 |
+
"--no_head_permute_merge",
|
| 188 |
+
action="store_true",
|
| 189 |
+
help="Disable attention head permutation alignment before merge",
|
| 190 |
+
)
|
| 191 |
+
parser.add_argument(
|
| 192 |
+
"--no_head_permute_select",
|
| 193 |
+
action="store_true",
|
| 194 |
+
help="Disable attention head permutation alignment during auto selection",
|
| 195 |
+
)
|
| 196 |
+
parser.add_argument("--eps", type=float, default=1e-8, help="Stability epsilon")
|
| 197 |
+
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
| 198 |
+
parser.add_argument(
|
| 199 |
+
"--trust_remote_code",
|
| 200 |
+
action="store_true",
|
| 201 |
+
help="Allow custom model code from hub",
|
| 202 |
+
)
|
| 203 |
+
parser.add_argument(
|
| 204 |
+
"--save_metadata",
|
| 205 |
+
action="store_true",
|
| 206 |
+
help="Backward-compatible no-op; metadata is always written.",
|
| 207 |
+
)
|
| 208 |
+
parser.add_argument(
|
| 209 |
+
"--skip_eval",
|
| 210 |
+
action="store_true",
|
| 211 |
+
help="Skip pre/post perplexity evaluation",
|
| 212 |
+
)
|
| 213 |
+
parser.add_argument(
|
| 214 |
+
"--eval_dataset",
|
| 215 |
+
action="append",
|
| 216 |
+
default=[],
|
| 217 |
+
help="Evaluation dataset name (repeatable). Defaults to wikitext.",
|
| 218 |
+
)
|
| 219 |
+
parser.add_argument(
|
| 220 |
+
"--eval_dataset_config",
|
| 221 |
+
action="append",
|
| 222 |
+
default=[],
|
| 223 |
+
help="Evaluation dataset config (repeatable or single shared config).",
|
| 224 |
+
)
|
| 225 |
+
parser.add_argument(
|
| 226 |
+
"--eval_split",
|
| 227 |
+
default="test",
|
| 228 |
+
help="Evaluation dataset split (default: test)",
|
| 229 |
+
)
|
| 230 |
+
parser.add_argument(
|
| 231 |
+
"--eval_text_field",
|
| 232 |
+
default=None,
|
| 233 |
+
help="Evaluation text field override (default: auto-detect)",
|
| 234 |
+
)
|
| 235 |
+
parser.add_argument(
|
| 236 |
+
"--eval_model_family",
|
| 237 |
+
type=str,
|
| 238 |
+
choices=["auto", "llama", "qwen"],
|
| 239 |
+
default="auto",
|
| 240 |
+
help="Model family for BOS handling during eval",
|
| 241 |
+
)
|
| 242 |
+
parser.add_argument(
|
| 243 |
+
"--eval_add_bos",
|
| 244 |
+
type=str,
|
| 245 |
+
choices=["auto", "always", "never"],
|
| 246 |
+
default="auto",
|
| 247 |
+
help="Whether to prepend BOS to each eval sample",
|
| 248 |
+
)
|
| 249 |
+
parser.add_argument(
|
| 250 |
+
"--eval_num_samples",
|
| 251 |
+
type=int,
|
| 252 |
+
default=0,
|
| 253 |
+
help="Number of token sequences per eval dataset (0 = all)",
|
| 254 |
+
)
|
| 255 |
+
parser.add_argument(
|
| 256 |
+
"--eval_seq_len",
|
| 257 |
+
type=int,
|
| 258 |
+
default=2048,
|
| 259 |
+
help="Sequence length for eval",
|
| 260 |
+
)
|
| 261 |
+
parser.add_argument(
|
| 262 |
+
"--eval_batch_size",
|
| 263 |
+
type=int,
|
| 264 |
+
default=None,
|
| 265 |
+
help="Batch size for eval (defaults to --batch_size)",
|
| 266 |
+
)
|
| 267 |
+
parser.add_argument(
|
| 268 |
+
"--eval_max_batches",
|
| 269 |
+
type=int,
|
| 270 |
+
default=None,
|
| 271 |
+
help="Optional max number of eval batches per dataset",
|
| 272 |
+
)
|
| 273 |
+
parser.add_argument(
|
| 274 |
+
"--eval_cache_dir",
|
| 275 |
+
default=None,
|
| 276 |
+
help="Optional datasets cache dir for eval",
|
| 277 |
+
)
|
| 278 |
+
parser.add_argument(
|
| 279 |
+
"--eval_num_workers",
|
| 280 |
+
type=int,
|
| 281 |
+
default=0,
|
| 282 |
+
help="Eval DataLoader workers",
|
| 283 |
+
)
|
| 284 |
+
parser.add_argument(
|
| 285 |
+
"--eval_device",
|
| 286 |
+
default=None,
|
| 287 |
+
help="Device for eval (defaults to --device)",
|
| 288 |
+
)
|
| 289 |
+
parser.add_argument(
|
| 290 |
+
"--skip_distill",
|
| 291 |
+
action="store_true",
|
| 292 |
+
help="Skip reparameterized distillation after head alignment/Fisher setup",
|
| 293 |
+
)
|
| 294 |
+
parser.add_argument(
|
| 295 |
+
"--distill_calib_samples",
|
| 296 |
+
type=int,
|
| 297 |
+
default=256,
|
| 298 |
+
help="Number of distillation sequences from calibration datasets",
|
| 299 |
+
)
|
| 300 |
+
parser.add_argument(
|
| 301 |
+
"--distill_inst_samples",
|
| 302 |
+
type=int,
|
| 303 |
+
default=0,
|
| 304 |
+
help="Number of distillation sequences from instruction dataset (0 = all)",
|
| 305 |
+
)
|
| 306 |
+
parser.add_argument(
|
| 307 |
+
"--distill_seq_len",
|
| 308 |
+
type=int,
|
| 309 |
+
default=512,
|
| 310 |
+
help="Sequence length for distillation",
|
| 311 |
+
)
|
| 312 |
+
parser.add_argument(
|
| 313 |
+
"--distill_batch_size",
|
| 314 |
+
type=int,
|
| 315 |
+
default=2,
|
| 316 |
+
help="Batch size for distillation",
|
| 317 |
+
)
|
| 318 |
+
parser.add_argument(
|
| 319 |
+
"--distill_epochs",
|
| 320 |
+
type=float,
|
| 321 |
+
default=1.0,
|
| 322 |
+
help="Number of distillation epochs (float allowed, e.g. 0.5)",
|
| 323 |
+
)
|
| 324 |
+
parser.add_argument(
|
| 325 |
+
"--distill_lr",
|
| 326 |
+
type=float,
|
| 327 |
+
default=1e-4,
|
| 328 |
+
help="Learning rate for distillation",
|
| 329 |
+
)
|
| 330 |
+
parser.add_argument(
|
| 331 |
+
"--distill_method",
|
| 332 |
+
choices=["reparam"],
|
| 333 |
+
default="reparam",
|
| 334 |
+
help="Distillation strategy (reparam only).",
|
| 335 |
+
)
|
| 336 |
+
parser.add_argument(
|
| 337 |
+
"--distill_kl_weight",
|
| 338 |
+
type=float,
|
| 339 |
+
default=1e-2,
|
| 340 |
+
help="Weight for KL loss on logits",
|
| 341 |
+
)
|
| 342 |
+
parser.add_argument(
|
| 343 |
+
"--distill_kl_temp",
|
| 344 |
+
type=float,
|
| 345 |
+
default=4.0,
|
| 346 |
+
help="Temperature for KL distillation on logits",
|
| 347 |
+
)
|
| 348 |
+
parser.add_argument(
|
| 349 |
+
"--distill_hidden_mse_weight",
|
| 350 |
+
type=float,
|
| 351 |
+
default=1.0,
|
| 352 |
+
help="Weight for hidden-state MSE in reparam distillation (0 disables it)",
|
| 353 |
+
)
|
| 354 |
+
parser.add_argument(
|
| 355 |
+
"--distill_attn_mse_weight",
|
| 356 |
+
type=float,
|
| 357 |
+
default=0.0,
|
| 358 |
+
help="Weight for auxiliary attention-output MSE in reparam distillation",
|
| 359 |
+
)
|
| 360 |
+
parser.add_argument(
|
| 361 |
+
"--distill_mlp_mse_weight",
|
| 362 |
+
type=float,
|
| 363 |
+
default=0.0,
|
| 364 |
+
help="Weight for auxiliary MLP-output MSE in reparam distillation",
|
| 365 |
+
)
|
| 366 |
+
parser.add_argument(
|
| 367 |
+
"--reparam_eta",
|
| 368 |
+
type=float,
|
| 369 |
+
default=1e-2,
|
| 370 |
+
help="Eta: ||lambda - lambda_gate||^2 regularizer weight for --distill_method reparam",
|
| 371 |
+
)
|
| 372 |
+
parser.add_argument(
|
| 373 |
+
"--reparam_gamma",
|
| 374 |
+
type=float,
|
| 375 |
+
default=1e-4,
|
| 376 |
+
help="Gamma: ||U - U0||^2 regularizer weight for --distill_method reparam",
|
| 377 |
+
)
|
| 378 |
+
parser.add_argument(
|
| 379 |
+
"--reparam_attn_reg_scale",
|
| 380 |
+
type=float,
|
| 381 |
+
default=1.0,
|
| 382 |
+
help="Relative scale applied to attention-parameter reparam regularizers",
|
| 383 |
+
)
|
| 384 |
+
parser.add_argument(
|
| 385 |
+
"--reparam_mlp_reg_scale",
|
| 386 |
+
type=float,
|
| 387 |
+
default=1.0,
|
| 388 |
+
help="Relative scale applied to MLP-parameter reparam regularizers",
|
| 389 |
+
)
|
| 390 |
+
parser.add_argument(
|
| 391 |
+
"--reparam_param_subset",
|
| 392 |
+
type=str,
|
| 393 |
+
choices=["all", "mlp", "attn"],
|
| 394 |
+
default="all",
|
| 395 |
+
help="Restrict reparam merge/recovery capacity to only this parameter family",
|
| 396 |
+
)
|
| 397 |
+
parser.add_argument(
|
| 398 |
+
"--norm_policy",
|
| 399 |
+
type=str,
|
| 400 |
+
choices=["hybrid", "merge_all", "copy_n1", "copy_n1_n2"],
|
| 401 |
+
default="hybrid",
|
| 402 |
+
help="Norm merge policy (default: hybrid)",
|
| 403 |
+
)
|
| 404 |
+
parser.add_argument(
|
| 405 |
+
"--distill_weight_decay",
|
| 406 |
+
type=float,
|
| 407 |
+
default=0.0,
|
| 408 |
+
help="Weight decay for distillation",
|
| 409 |
+
)
|
| 410 |
+
parser.add_argument(
|
| 411 |
+
"--distill_max_grad_norm",
|
| 412 |
+
type=float,
|
| 413 |
+
default=1.0,
|
| 414 |
+
help="Max grad norm for distillation",
|
| 415 |
+
)
|
| 416 |
+
parser.add_argument(
|
| 417 |
+
"--distill_grad_accum_steps",
|
| 418 |
+
type=int,
|
| 419 |
+
default=1,
|
| 420 |
+
help="Gradient accumulation steps for distillation",
|
| 421 |
+
)
|
| 422 |
+
parser.add_argument(
|
| 423 |
+
"--distill_log_steps",
|
| 424 |
+
type=int,
|
| 425 |
+
default=100,
|
| 426 |
+
help="Log distillation loss every N steps",
|
| 427 |
+
)
|
| 428 |
+
parser.add_argument(
|
| 429 |
+
"--distill_eval_every",
|
| 430 |
+
type=int,
|
| 431 |
+
default=0,
|
| 432 |
+
help="Evaluate PPL every N distill steps (0 = disable)",
|
| 433 |
+
)
|
| 434 |
+
parser.add_argument(
|
| 435 |
+
"--distill_eval_max_batches",
|
| 436 |
+
type=int,
|
| 437 |
+
default=None,
|
| 438 |
+
help="Max eval batches per dataset during distill (default: all)",
|
| 439 |
+
)
|
| 440 |
+
parser.add_argument(
|
| 441 |
+
"--distill_teacher_device",
|
| 442 |
+
default=None,
|
| 443 |
+
help="Device for teacher model during distillation (defaults to --device)",
|
| 444 |
+
)
|
| 445 |
+
parser.add_argument(
|
| 446 |
+
"--comm_enabled",
|
| 447 |
+
action="store_true",
|
| 448 |
+
help=(
|
| 449 |
+
"Enable commutator-style preconditioning before each progressive "
|
| 450 |
+
"cycle's fusion."
|
| 451 |
+
),
|
| 452 |
+
)
|
| 453 |
+
parser.add_argument(
|
| 454 |
+
"--comm_include_cycle1",
|
| 455 |
+
action="store_true",
|
| 456 |
+
help="Run commutator preconditioning for cycle 1 as well (default: skip cycle 1).",
|
| 457 |
+
)
|
| 458 |
+
parser.add_argument(
|
| 459 |
+
"--comm_topk",
|
| 460 |
+
type=int,
|
| 461 |
+
default=1,
|
| 462 |
+
help="Top-K lowest-score pairs used as the commutator candidate set",
|
| 463 |
+
)
|
| 464 |
+
parser.add_argument(
|
| 465 |
+
"--comm_sample_eta",
|
| 466 |
+
type=float,
|
| 467 |
+
default=0.5,
|
| 468 |
+
help="Mixture weight between uniform and score-biased candidate sampling",
|
| 469 |
+
)
|
| 470 |
+
parser.add_argument(
|
| 471 |
+
"--comm_sample_dwce_scale",
|
| 472 |
+
type=float,
|
| 473 |
+
default=1.0,
|
| 474 |
+
help="Scale c in softmax(-c * score(i)) for commutator pair sampling",
|
| 475 |
+
)
|
| 476 |
+
parser.add_argument(
|
| 477 |
+
"--comm_temp",
|
| 478 |
+
type=float,
|
| 479 |
+
default=2.0,
|
| 480 |
+
help="Temperature for teacher-anchor KL in commutator preconditioning",
|
| 481 |
+
)
|
| 482 |
+
parser.add_argument(
|
| 483 |
+
"--comm_steps_ratio",
|
| 484 |
+
type=float,
|
| 485 |
+
default=0.1,
|
| 486 |
+
help="Run this fraction of distillation optimizer steps for commutator phase",
|
| 487 |
+
)
|
| 488 |
+
parser.add_argument(
|
| 489 |
+
"--comm_lr_scale",
|
| 490 |
+
type=float,
|
| 491 |
+
default=0.1,
|
| 492 |
+
help="Commutator LR = --distill_lr * this scale",
|
| 493 |
+
)
|
| 494 |
+
parser.add_argument(
|
| 495 |
+
"--comm_train_mode",
|
| 496 |
+
choices=["lora", "full"],
|
| 497 |
+
default="lora",
|
| 498 |
+
help=(
|
| 499 |
+
"Commutator trainable parameter mode: "
|
| 500 |
+
"'lora' updates LoRA adapters on sampled receiver layers; "
|
| 501 |
+
"'full' updates full receiver-layer weights."
|
| 502 |
+
),
|
| 503 |
+
)
|
| 504 |
+
parser.add_argument(
|
| 505 |
+
"--comm_interaction_mode",
|
| 506 |
+
choices=["mse", "relative"],
|
| 507 |
+
default="relative",
|
| 508 |
+
help="Interaction loss form: plain MSE or relative MSE",
|
| 509 |
+
)
|
| 510 |
+
parser.add_argument(
|
| 511 |
+
"--comm_interaction_eps",
|
| 512 |
+
type=float,
|
| 513 |
+
default=1e-8,
|
| 514 |
+
help="Epsilon for relative commutator interaction normalization",
|
| 515 |
+
)
|
| 516 |
+
parser.add_argument(
|
| 517 |
+
"--comm_mu",
|
| 518 |
+
type=float,
|
| 519 |
+
default=None,
|
| 520 |
+
help=(
|
| 521 |
+
"Weight for interaction loss. Defaults to 0.1 for --comm_interaction_mode=mse "
|
| 522 |
+
"and 0.5 for --comm_interaction_mode=relative."
|
| 523 |
+
),
|
| 524 |
+
)
|
| 525 |
+
parser.add_argument(
|
| 526 |
+
"--comm_mu_auto",
|
| 527 |
+
action="store_true",
|
| 528 |
+
help="Enable automatic mu scaling via gradient-norm balancing",
|
| 529 |
+
)
|
| 530 |
+
parser.add_argument(
|
| 531 |
+
"--comm_mu_auto_rho",
|
| 532 |
+
type=float,
|
| 533 |
+
default=0.1,
|
| 534 |
+
help="Target anchor-to-interaction gradient ratio constant for auto-mu",
|
| 535 |
+
)
|
| 536 |
+
parser.add_argument(
|
| 537 |
+
"--comm_mu_auto_eps",
|
| 538 |
+
type=float,
|
| 539 |
+
default=1e-8,
|
| 540 |
+
help="Numerical epsilon in auto-mu denominator",
|
| 541 |
+
)
|
| 542 |
+
parser.add_argument(
|
| 543 |
+
"--comm_log_steps",
|
| 544 |
+
type=int,
|
| 545 |
+
default=50,
|
| 546 |
+
help="Log commutator preconditioning loss every N optimizer steps",
|
| 547 |
+
)
|
| 548 |
+
parser.add_argument(
|
| 549 |
+
"--comm_skip_post_reselect",
|
| 550 |
+
action="store_true",
|
| 551 |
+
help=(
|
| 552 |
+
"Keep the pre-comm selected fusion pair and skip recomputing "
|
| 553 |
+
"selection after commutator preconditioning."
|
| 554 |
+
),
|
| 555 |
+
)
|
| 556 |
+
parser.add_argument(
|
| 557 |
+
"--redistrib_teacher_source",
|
| 558 |
+
type=str,
|
| 559 |
+
choices=["base_model", "previous_cycle"],
|
| 560 |
+
default="base_model",
|
| 561 |
+
help=(
|
| 562 |
+
"Teacher source for commutator preconditioning teacher loading. "
|
| 563 |
+
"'base_model' uses --model for all cycles; "
|
| 564 |
+
"'previous_cycle' uses cycle-1 checkpoint (cycle 1 falls back to base_model)."
|
| 565 |
+
),
|
| 566 |
+
)
|
| 567 |
+
parser.add_argument(
|
| 568 |
+
"--lora_epochs",
|
| 569 |
+
type=float,
|
| 570 |
+
default=1.0,
|
| 571 |
+
help="LoRA CE finetuning epochs after distill (0 = disable)",
|
| 572 |
+
)
|
| 573 |
+
parser.add_argument(
|
| 574 |
+
"--lora_rank",
|
| 575 |
+
type=int,
|
| 576 |
+
default=8,
|
| 577 |
+
help="LoRA rank (r)",
|
| 578 |
+
)
|
| 579 |
+
parser.add_argument(
|
| 580 |
+
"--lora_alpha",
|
| 581 |
+
type=float,
|
| 582 |
+
default=16.0,
|
| 583 |
+
help="LoRA alpha",
|
| 584 |
+
)
|
| 585 |
+
parser.add_argument(
|
| 586 |
+
"--lora_dropout",
|
| 587 |
+
type=float,
|
| 588 |
+
default=0.0,
|
| 589 |
+
help="LoRA dropout",
|
| 590 |
+
)
|
| 591 |
+
parser.add_argument(
|
| 592 |
+
"--lora_kl_enabled",
|
| 593 |
+
action="store_true",
|
| 594 |
+
help="Add KL regularization between pre/post LoRA logits",
|
| 595 |
+
)
|
| 596 |
+
parser.add_argument(
|
| 597 |
+
"--lora_kl_weight",
|
| 598 |
+
type=float,
|
| 599 |
+
default=1e-1,
|
| 600 |
+
help="KL weight for LoRA regularization",
|
| 601 |
+
)
|
| 602 |
+
parser.add_argument(
|
| 603 |
+
"--lora_kl_temp",
|
| 604 |
+
type=float,
|
| 605 |
+
default=4.0,
|
| 606 |
+
help="Temperature for LoRA KL regularization",
|
| 607 |
+
)
|
| 608 |
+
parser.add_argument(
|
| 609 |
+
"--lora_target_modules",
|
| 610 |
+
nargs="*",
|
| 611 |
+
default=[
|
| 612 |
+
"q_proj",
|
| 613 |
+
"k_proj",
|
| 614 |
+
"v_proj",
|
| 615 |
+
"o_proj",
|
| 616 |
+
"gate_proj",
|
| 617 |
+
"down_proj",
|
| 618 |
+
"up_proj",
|
| 619 |
+
],
|
| 620 |
+
help="Module name suffixes to LoRA-wrap",
|
| 621 |
+
)
|
| 622 |
+
parser.add_argument(
|
| 623 |
+
"--lora_respect_exclude_pairs",
|
| 624 |
+
action="store_true",
|
| 625 |
+
help=(
|
| 626 |
+
"When attaching LoRA adapters, skip linear modules under layers touched by "
|
| 627 |
+
"--exclude_pairs (i and i+1 for each excluded pair)."
|
| 628 |
+
),
|
| 629 |
+
)
|
| 630 |
+
parser.add_argument(
|
| 631 |
+
"--lora_lr",
|
| 632 |
+
type=float,
|
| 633 |
+
default=1e-4,
|
| 634 |
+
help="Learning rate for LoRA finetuning",
|
| 635 |
+
)
|
| 636 |
+
parser.add_argument(
|
| 637 |
+
"--lora_weight_decay",
|
| 638 |
+
type=float,
|
| 639 |
+
default=0.0,
|
| 640 |
+
help="Weight decay for LoRA finetuning",
|
| 641 |
+
)
|
| 642 |
+
parser.add_argument(
|
| 643 |
+
"--lora_max_grad_norm",
|
| 644 |
+
type=float,
|
| 645 |
+
default=1.0,
|
| 646 |
+
help="Max grad norm for LoRA finetuning",
|
| 647 |
+
)
|
| 648 |
+
parser.add_argument(
|
| 649 |
+
"--lora_grad_accum_steps",
|
| 650 |
+
type=int,
|
| 651 |
+
default=1,
|
| 652 |
+
help="Gradient accumulation steps for LoRA finetuning",
|
| 653 |
+
)
|
| 654 |
+
parser.add_argument(
|
| 655 |
+
"--lora_log_steps",
|
| 656 |
+
type=int,
|
| 657 |
+
default=100,
|
| 658 |
+
help="Log LoRA loss every N steps",
|
| 659 |
+
)
|
| 660 |
+
parser.add_argument(
|
| 661 |
+
"--lora_eval_every",
|
| 662 |
+
type=int,
|
| 663 |
+
default=0,
|
| 664 |
+
help="Evaluate PPL every N LoRA steps (0 = disable)",
|
| 665 |
+
)
|
| 666 |
+
parser.add_argument(
|
| 667 |
+
"--lora_eval_max_batches",
|
| 668 |
+
type=int,
|
| 669 |
+
default=None,
|
| 670 |
+
help="Max eval batches per dataset during LoRA (default: all)",
|
| 671 |
+
)
|
| 672 |
+
parser.add_argument(
|
| 673 |
+
"--instruction_dataset",
|
| 674 |
+
default=None,
|
| 675 |
+
help="HF dataset name for alpaca-style instruction data",
|
| 676 |
+
)
|
| 677 |
+
parser.add_argument(
|
| 678 |
+
"--instruction_config",
|
| 679 |
+
default=None,
|
| 680 |
+
help="Optional instruction dataset config",
|
| 681 |
+
)
|
| 682 |
+
parser.add_argument(
|
| 683 |
+
"--instruction_split",
|
| 684 |
+
default="train",
|
| 685 |
+
help="Instruction dataset split",
|
| 686 |
+
)
|
| 687 |
+
parser.add_argument(
|
| 688 |
+
"--instruction_field_instruction",
|
| 689 |
+
default="instruction",
|
| 690 |
+
help="Instruction field name",
|
| 691 |
+
)
|
| 692 |
+
parser.add_argument(
|
| 693 |
+
"--instruction_field_input",
|
| 694 |
+
default="input",
|
| 695 |
+
help="Optional input field name",
|
| 696 |
+
)
|
| 697 |
+
parser.add_argument(
|
| 698 |
+
"--instruction_field_output",
|
| 699 |
+
default="output",
|
| 700 |
+
help="Response/output field name",
|
| 701 |
+
)
|
| 702 |
+
parser.add_argument(
|
| 703 |
+
"--auto_max_batches",
|
| 704 |
+
type=int,
|
| 705 |
+
default=0,
|
| 706 |
+
help="Max calibration batches for auto selection scoring (0 = all)",
|
| 707 |
+
)
|
| 708 |
+
parser.add_argument(
|
| 709 |
+
"--auto_metric",
|
| 710 |
+
type=str,
|
| 711 |
+
choices=[
|
| 712 |
+
"dwce",
|
| 713 |
+
"cosine",
|
| 714 |
+
"hybrid",
|
| 715 |
+
"hybrid_cosine",
|
| 716 |
+
"hybrid_global_rel",
|
| 717 |
+
],
|
| 718 |
+
default="dwce",
|
| 719 |
+
help=(
|
| 720 |
+
"Auto pair scoring metric. 'dwce' uses downstream-weighted composition error; "
|
| 721 |
+
"'cosine' uses average token-level cosine distance between adjacent layer outputs; "
|
| 722 |
+
"'hybrid'/'hybrid_cosine' use DWCE to shortlist then adjacent cosine for final scoring; "
|
| 723 |
+
"'hybrid_global_rel' uses DWCE to shortlist then reranks by the change in "
|
| 724 |
+
"pair-to-final-layer cosine relation after surrogate fusion."
|
| 725 |
+
),
|
| 726 |
+
)
|
| 727 |
+
parser.add_argument(
|
| 728 |
+
"--auto_cosine_topk",
|
| 729 |
+
type=int,
|
| 730 |
+
default=3,
|
| 731 |
+
help="Top-K DWCE candidates to rescore with cosine in --auto_metric=hybrid",
|
| 732 |
+
)
|
| 733 |
+
parser.add_argument(
|
| 734 |
+
"--auto_norm",
|
| 735 |
+
type=str,
|
| 736 |
+
choices=["relative", "none"],
|
| 737 |
+
default="relative",
|
| 738 |
+
help="Normalization mode for DWCE scoring (ignored for cosine)",
|
| 739 |
+
)
|
| 740 |
+
parser.add_argument(
|
| 741 |
+
"--auto_dwce_mode",
|
| 742 |
+
type=str,
|
| 743 |
+
choices=["separate", "shared"],
|
| 744 |
+
default="separate",
|
| 745 |
+
help=(
|
| 746 |
+
"DWCE implementation for auto scoring. "
|
| 747 |
+
"'separate' runs distinct Fisher and DWCE backward passes; "
|
| 748 |
+
"'shared' reuses one backward pass and replays DWCE with cached gradients."
|
| 749 |
+
),
|
| 750 |
+
)
|
| 751 |
+
parser.add_argument(
|
| 752 |
+
"--num_progressive",
|
| 753 |
+
type=int,
|
| 754 |
+
default=0,
|
| 755 |
+
help="Number of progressive fusions (>0 required)",
|
| 756 |
+
)
|
| 757 |
+
parser.add_argument(
|
| 758 |
+
"--resume_from_cycle",
|
| 759 |
+
type=int,
|
| 760 |
+
default=0,
|
| 761 |
+
help=(
|
| 762 |
+
"Resume from this completed cycle index. When > 0, --model should point "
|
| 763 |
+
"to the saved full model directory for that cycle."
|
| 764 |
+
),
|
| 765 |
+
)
|
| 766 |
+
parser.add_argument(
|
| 767 |
+
"--save_full_model_cycles",
|
| 768 |
+
nargs="*",
|
| 769 |
+
default=[],
|
| 770 |
+
help=(
|
| 771 |
+
"Cycle indices whose full models should be saved. Requesting cycle c "
|
| 772 |
+
"also saves cycle c-1 automatically (c=1 saves only cycle 1)."
|
| 773 |
+
),
|
| 774 |
+
)
|
| 775 |
+
return parser.parse_args()
|
| 776 |
+
|
| 777 |
+
|
| 778 |
+
def parse_exclude_pairs(exclude_raw: Optional[List[str]], num_pairs: int) -> List[int]:
|
| 779 |
+
"""Parse --exclude_pairs into normalized pair indices for the current model.
|
| 780 |
+
|
| 781 |
+
Indices refer to the start of an adjacent pair (i, i+1) and must be in [0..N-2].
|
| 782 |
+
Negative indices count from the end (-1 = last pair).
|
| 783 |
+
"""
|
| 784 |
+
if not exclude_raw:
|
| 785 |
+
return []
|
| 786 |
+
exclude: List[int] = []
|
| 787 |
+
for item in exclude_raw:
|
| 788 |
+
if item is None:
|
| 789 |
+
continue
|
| 790 |
+
for part in str(item).split(","):
|
| 791 |
+
part = part.strip()
|
| 792 |
+
if not part:
|
| 793 |
+
continue
|
| 794 |
+
try:
|
| 795 |
+
idx = int(part)
|
| 796 |
+
except ValueError as exc:
|
| 797 |
+
raise SystemExit("--exclude_pairs must contain integers.") from exc
|
| 798 |
+
if idx < 0:
|
| 799 |
+
idx = num_pairs + idx
|
| 800 |
+
if 0 <= idx < num_pairs:
|
| 801 |
+
exclude.append(idx)
|
| 802 |
+
return sorted(set(exclude))
|
| 803 |
+
|
| 804 |
+
|
| 805 |
+
def parse_cycle_list(raw_values: Optional[List[str]]) -> List[int]:
|
| 806 |
+
if not raw_values:
|
| 807 |
+
return []
|
| 808 |
+
cycles: List[int] = []
|
| 809 |
+
for item in raw_values:
|
| 810 |
+
if item is None:
|
| 811 |
+
continue
|
| 812 |
+
for part in str(item).split(","):
|
| 813 |
+
part = part.strip()
|
| 814 |
+
if not part:
|
| 815 |
+
continue
|
| 816 |
+
try:
|
| 817 |
+
cycles.append(int(part))
|
| 818 |
+
except ValueError as exc:
|
| 819 |
+
raise SystemExit(
|
| 820 |
+
"--save_full_model_cycles must contain integers."
|
| 821 |
+
) from exc
|
| 822 |
+
return cycles
|
| 823 |
+
|
| 824 |
+
|
| 825 |
+
def resolve_full_model_save_cycles(
|
| 826 |
+
requested_cycles: List[int], num_progressive: int
|
| 827 |
+
) -> Set[int]:
|
| 828 |
+
resolved: Set[int] = set()
|
| 829 |
+
for cycle in requested_cycles:
|
| 830 |
+
if cycle <= 0 or cycle > num_progressive:
|
| 831 |
+
raise SystemExit(
|
| 832 |
+
"--save_full_model_cycles entries must be within [1, --num_progressive]."
|
| 833 |
+
)
|
| 834 |
+
resolved.add(cycle)
|
| 835 |
+
if cycle > 1:
|
| 836 |
+
resolved.add(cycle - 1)
|
| 837 |
+
return resolved
|
| 838 |
+
|
| 839 |
+
|
| 840 |
+
def load_resume_metadata(model_path: str) -> Optional[Dict[str, object]]:
|
| 841 |
+
resume_meta_path = os.path.join(model_path, "resume_info.json")
|
| 842 |
+
if not os.path.exists(resume_meta_path):
|
| 843 |
+
return None
|
| 844 |
+
with open(resume_meta_path, "r", encoding="utf-8") as handle:
|
| 845 |
+
loaded = json.load(handle)
|
| 846 |
+
return loaded if isinstance(loaded, dict) else None
|
| 847 |
+
|
| 848 |
+
|
| 849 |
+
def build_generator(seed: int) -> torch.Generator:
|
| 850 |
+
generator = torch.Generator(device="cpu")
|
| 851 |
+
generator.manual_seed(int(seed))
|
| 852 |
+
return generator
|
| 853 |
+
|
| 854 |
+
|
| 855 |
+
def capture_rng_state() -> Dict[str, object]:
|
| 856 |
+
state: Dict[str, object] = {
|
| 857 |
+
"python_random_state": random.getstate(),
|
| 858 |
+
"torch_cpu_rng_state": torch.get_rng_state(),
|
| 859 |
+
}
|
| 860 |
+
if np is not None:
|
| 861 |
+
state["numpy_random_state"] = np.random.get_state()
|
| 862 |
+
if torch.cuda.is_available():
|
| 863 |
+
state["torch_cuda_rng_state_all"] = torch.cuda.get_rng_state_all()
|
| 864 |
+
return state
|
| 865 |
+
|
| 866 |
+
|
| 867 |
+
def restore_rng_state(state: Dict[str, object]) -> None:
|
| 868 |
+
python_state = state.get("python_random_state")
|
| 869 |
+
if python_state is not None:
|
| 870 |
+
random.setstate(python_state)
|
| 871 |
+
|
| 872 |
+
numpy_state = state.get("numpy_random_state")
|
| 873 |
+
if numpy_state is not None and np is not None:
|
| 874 |
+
np.random.set_state(numpy_state)
|
| 875 |
+
|
| 876 |
+
torch_cpu_state = state.get("torch_cpu_rng_state")
|
| 877 |
+
if torch_cpu_state is not None:
|
| 878 |
+
torch.set_rng_state(torch_cpu_state)
|
| 879 |
+
|
| 880 |
+
torch_cuda_state = state.get("torch_cuda_rng_state_all")
|
| 881 |
+
if torch_cuda_state is not None and torch.cuda.is_available():
|
| 882 |
+
torch.cuda.set_rng_state_all(torch_cuda_state)
|
| 883 |
+
|
| 884 |
+
|
| 885 |
+
def save_rng_state(path: str) -> None:
|
| 886 |
+
torch.save(capture_rng_state(), path)
|
| 887 |
+
|
| 888 |
+
|
| 889 |
+
def load_rng_state(path: str) -> Optional[Dict[str, object]]:
|
| 890 |
+
if not os.path.exists(path):
|
| 891 |
+
return None
|
| 892 |
+
loaded = torch.load(path, map_location="cpu", weights_only=False)
|
| 893 |
+
return loaded if isinstance(loaded, dict) else None
|
| 894 |
+
|
| 895 |
+
|
| 896 |
+
def configure_reproducibility(seed: int) -> None:
|
| 897 |
+
random.seed(seed)
|
| 898 |
+
if np is not None:
|
| 899 |
+
np.random.seed(seed)
|
| 900 |
+
torch.manual_seed(seed)
|
| 901 |
+
if torch.cuda.is_available():
|
| 902 |
+
torch.cuda.manual_seed_all(seed)
|
| 903 |
+
if hasattr(torch.backends, "cudnn"):
|
| 904 |
+
torch.backends.cudnn.deterministic = True
|
| 905 |
+
torch.backends.cudnn.benchmark = False
|
| 906 |
+
if hasattr(torch, "use_deterministic_algorithms"):
|
| 907 |
+
torch.use_deterministic_algorithms(True, warn_only=True)
|
| 908 |
+
|
| 909 |
+
|
| 910 |
+
def save_loader_generator_state(
|
| 911 |
+
base_dir: str,
|
| 912 |
+
*,
|
| 913 |
+
distill_generator: Optional[torch.Generator] = None,
|
| 914 |
+
lora_generator: Optional[torch.Generator] = None,
|
| 915 |
+
) -> None:
|
| 916 |
+
state: Dict[str, object] = {}
|
| 917 |
+
if distill_generator is not None:
|
| 918 |
+
state["distill_generator_state"] = distill_generator.get_state()
|
| 919 |
+
if lora_generator is not None:
|
| 920 |
+
state["lora_generator_state"] = lora_generator.get_state()
|
| 921 |
+
if state:
|
| 922 |
+
torch.save(state, os.path.join(base_dir, "loader_generators.pt"))
|
| 923 |
+
|
| 924 |
+
|
| 925 |
+
def load_loader_generator_state(base_dir: str) -> Optional[Dict[str, object]]:
|
| 926 |
+
path = os.path.join(base_dir, "loader_generators.pt")
|
| 927 |
+
if not os.path.exists(path):
|
| 928 |
+
return None
|
| 929 |
+
loaded = torch.load(path, map_location="cpu")
|
| 930 |
+
return loaded if isinstance(loaded, dict) else None
|
| 931 |
+
|
| 932 |
+
|
| 933 |
+
def resolve_layer_idx(
|
| 934 |
+
args: argparse.Namespace,
|
| 935 |
+
model,
|
| 936 |
+
layers: List[torch.nn.Module],
|
| 937 |
+
dataloader,
|
| 938 |
+
previous_scores,
|
| 939 |
+
start_index: int,
|
| 940 |
+
exclude_pairs: Set[int],
|
| 941 |
+
):
|
| 942 |
+
layer_arg = str(getattr(args, "layer", "auto")).strip().lower()
|
| 943 |
+
selection_method = str(getattr(args, "selection_method", "dwce")).strip().lower()
|
| 944 |
+
|
| 945 |
+
if layer_arg != "auto":
|
| 946 |
+
try:
|
| 947 |
+
layer_idx = int(layer_arg)
|
| 948 |
+
except ValueError as exc:
|
| 949 |
+
raise SystemExit("--layer must be 'auto' or an integer index") from exc
|
| 950 |
+
num_pairs = max(len(layers) - 1, 0)
|
| 951 |
+
if layer_idx < 0:
|
| 952 |
+
layer_idx += num_pairs
|
| 953 |
+
if layer_idx in exclude_pairs:
|
| 954 |
+
raise SystemExit(f"--layer resolved to excluded pair index {layer_idx}")
|
| 955 |
+
return layer_idx, previous_scores, {"method": "manual", "exclude_pairs": sorted(exclude_pairs)}
|
| 956 |
+
|
| 957 |
+
if selection_method == "sequential":
|
| 958 |
+
num_pairs = len(layers) - 1
|
| 959 |
+
for layer_idx in range(max(start_index, 0), num_pairs):
|
| 960 |
+
if layer_idx not in exclude_pairs:
|
| 961 |
+
return layer_idx, previous_scores, {
|
| 962 |
+
"method": "sequential",
|
| 963 |
+
"start_index": max(start_index, 0),
|
| 964 |
+
"exclude_pairs": sorted(exclude_pairs),
|
| 965 |
+
}
|
| 966 |
+
raise SystemExit("No eligible layer pairs remain after exclusions")
|
| 967 |
+
|
| 968 |
+
layer_idx, dwce_scores, dwce_meta = select_layer_auto(
|
| 969 |
+
model,
|
| 970 |
+
layers,
|
| 971 |
+
dataloader,
|
| 972 |
+
args,
|
| 973 |
+
previous_scores=previous_scores,
|
| 974 |
+
start_index=start_index,
|
| 975 |
+
exclude_pairs=exclude_pairs,
|
| 976 |
+
)
|
| 977 |
+
return layer_idx, dwce_scores, dwce_meta
|
| 978 |
+
|
| 979 |
+
|
| 980 |
+
@dataclass
|
| 981 |
+
class PreparedData:
|
| 982 |
+
calib_loader: torch.utils.data.DataLoader
|
| 983 |
+
calib_num_sequences: int
|
| 984 |
+
distill_loader: Optional[torch.utils.data.DataLoader]
|
| 985 |
+
distill_generator: Optional[torch.Generator]
|
| 986 |
+
distill_meta: Dict[str, object]
|
| 987 |
+
lora_loader: Optional[torch.utils.data.DataLoader]
|
| 988 |
+
lora_generator: Optional[torch.Generator]
|
| 989 |
+
lora_meta: Dict[str, object]
|
| 990 |
+
eval_datasets: List[str]
|
| 991 |
+
eval_configs: List[Optional[str]]
|
| 992 |
+
eval_dataloaders: Optional[Dict[str, torch.utils.data.DataLoader]]
|
| 993 |
+
|
| 994 |
+
|
| 995 |
+
def resolve_eval_datasets(args: argparse.Namespace) -> Tuple[List[str], List[Optional[str]]]:
|
| 996 |
+
eval_datasets = args.eval_dataset or ["wikitext"]
|
| 997 |
+
eval_configs = args.eval_dataset_config or ["wikitext-2-raw-v1"]
|
| 998 |
+
eval_configs = ppl_eval._expand_dataset_configs(eval_datasets, eval_configs)
|
| 999 |
+
return eval_datasets, eval_configs
|
| 1000 |
+
|
| 1001 |
+
|
| 1002 |
+
def run_ppl_eval(
|
| 1003 |
+
model_id_or_path: str,
|
| 1004 |
+
eval_datasets: List[str],
|
| 1005 |
+
eval_configs: List[Optional[str]],
|
| 1006 |
+
args: argparse.Namespace,
|
| 1007 |
+
prepared_eval_dataloaders: Optional[Dict[str, torch.utils.data.DataLoader]] = None,
|
| 1008 |
+
) -> Dict[str, float]:
|
| 1009 |
+
eval_device = args.eval_device or args.device
|
| 1010 |
+
dtype = get_dtype(args.dtype)
|
| 1011 |
+
|
| 1012 |
+
eval_model = load_causal_lm(
|
| 1013 |
+
model_id_or_path,
|
| 1014 |
+
torch_dtype=dtype,
|
| 1015 |
+
trust_remote_code=args.trust_remote_code,
|
| 1016 |
+
)
|
| 1017 |
+
|
| 1018 |
+
eval_model.to(eval_device)
|
| 1019 |
+
if prepared_eval_dataloaders is not None:
|
| 1020 |
+
results = ppl_eval.evaluate_ppl_dataloaders(
|
| 1021 |
+
eval_model,
|
| 1022 |
+
prepared_eval_dataloaders,
|
| 1023 |
+
eval_device,
|
| 1024 |
+
max_batches=args.eval_max_batches,
|
| 1025 |
+
)
|
| 1026 |
+
else:
|
| 1027 |
+
eval_batch_size = args.eval_batch_size or args.batch_size
|
| 1028 |
+
eval_tokenizer = AutoTokenizer.from_pretrained(
|
| 1029 |
+
model_id_or_path, trust_remote_code=args.trust_remote_code
|
| 1030 |
+
)
|
| 1031 |
+
if eval_tokenizer.pad_token is None and eval_tokenizer.eos_token is not None:
|
| 1032 |
+
eval_tokenizer.pad_token = eval_tokenizer.eos_token
|
| 1033 |
+
|
| 1034 |
+
results = ppl_eval.evaluate_ppl_datasets(
|
| 1035 |
+
eval_model,
|
| 1036 |
+
eval_tokenizer,
|
| 1037 |
+
datasets=eval_datasets,
|
| 1038 |
+
configs=eval_configs,
|
| 1039 |
+
split=args.eval_split,
|
| 1040 |
+
text_field=args.eval_text_field,
|
| 1041 |
+
num_samples=args.eval_num_samples,
|
| 1042 |
+
seq_len=args.eval_seq_len,
|
| 1043 |
+
batch_size=eval_batch_size,
|
| 1044 |
+
device=eval_device,
|
| 1045 |
+
seed=args.seed,
|
| 1046 |
+
shuffle=False,
|
| 1047 |
+
model_family=args.eval_model_family,
|
| 1048 |
+
add_bos=args.eval_add_bos,
|
| 1049 |
+
max_batches=args.eval_max_batches,
|
| 1050 |
+
cache_dir=args.eval_cache_dir,
|
| 1051 |
+
num_workers=args.eval_num_workers,
|
| 1052 |
+
)
|
| 1053 |
+
|
| 1054 |
+
del eval_model
|
| 1055 |
+
if torch.cuda.is_available():
|
| 1056 |
+
torch.cuda.empty_cache()
|
| 1057 |
+
|
| 1058 |
+
return results
|
| 1059 |
+
|
| 1060 |
+
|
| 1061 |
+
def build_calibration_dataloader(
|
| 1062 |
+
args: argparse.Namespace, tokenizer
|
| 1063 |
+
) -> Tuple[List[str], List[torch.Tensor], torch.utils.data.DataLoader]:
|
| 1064 |
+
if args.dataset:
|
| 1065 |
+
datasets = list(args.dataset)
|
| 1066 |
+
configs = expand_dataset_configs(datasets, list(args.dataset_config))
|
| 1067 |
+
chunks: List[torch.Tensor] = []
|
| 1068 |
+
for idx, (dataset_name, config) in enumerate(zip(datasets, configs)):
|
| 1069 |
+
spec = SharedLMDataSpec(
|
| 1070 |
+
dataset=dataset_name,
|
| 1071 |
+
config=config,
|
| 1072 |
+
split=args.dataset_split,
|
| 1073 |
+
text_field=args.dataset_text_field,
|
| 1074 |
+
seq_len=args.seq_len,
|
| 1075 |
+
num_sequences=args.num_samples,
|
| 1076 |
+
seed=args.seed + idx,
|
| 1077 |
+
)
|
| 1078 |
+
chunks.extend(build_chunks(spec, tokenizer))
|
| 1079 |
+
if not chunks:
|
| 1080 |
+
raise SystemExit("Not enough text to build token sequences.")
|
| 1081 |
+
input_ids = torch.stack(chunks)
|
| 1082 |
+
attention_mask = torch.ones_like(input_ids)
|
| 1083 |
+
dataset = torch.utils.data.TensorDataset(input_ids, attention_mask)
|
| 1084 |
+
dataloader = torch.utils.data.DataLoader(
|
| 1085 |
+
dataset, batch_size=args.batch_size, shuffle=False
|
| 1086 |
+
)
|
| 1087 |
+
return [], chunks, dataloader
|
| 1088 |
+
|
| 1089 |
+
texts = load_texts(args)
|
| 1090 |
+
if not texts:
|
| 1091 |
+
raise SystemExit(
|
| 1092 |
+
"No calibration text found. Provide --dataset, --text, or --text_file."
|
| 1093 |
+
)
|
| 1094 |
+
|
| 1095 |
+
chunks = build_token_chunks(texts, tokenizer, args.seq_len, args.num_samples)
|
| 1096 |
+
if not chunks:
|
| 1097 |
+
raise SystemExit("Not enough text to build token sequences.")
|
| 1098 |
+
|
| 1099 |
+
input_ids = torch.stack(chunks)
|
| 1100 |
+
attention_mask = torch.ones_like(input_ids)
|
| 1101 |
+
dataset = torch.utils.data.TensorDataset(input_ids, attention_mask)
|
| 1102 |
+
dataloader = torch.utils.data.DataLoader(
|
| 1103 |
+
dataset, batch_size=args.batch_size, shuffle=False
|
| 1104 |
+
)
|
| 1105 |
+
return texts, chunks, dataloader
|
| 1106 |
+
|
| 1107 |
+
|
| 1108 |
+
def prepare_distillation_data(
|
| 1109 |
+
args: argparse.Namespace, tokenizer, include_instruction: bool = True
|
| 1110 |
+
) -> Tuple[Optional[torch.utils.data.DataLoader], Optional[torch.Generator], Dict[str, object]]:
|
| 1111 |
+
if (
|
| 1112 |
+
include_instruction
|
| 1113 |
+
and args.distill_inst_samples != 0
|
| 1114 |
+
and not args.instruction_dataset
|
| 1115 |
+
):
|
| 1116 |
+
print(
|
| 1117 |
+
"Warning: --distill_inst_samples > 0 but no --instruction_dataset "
|
| 1118 |
+
"provided; instruction distillation will be skipped."
|
| 1119 |
+
)
|
| 1120 |
+
|
| 1121 |
+
calib_texts: List[str] = []
|
| 1122 |
+
calib_dataset = None
|
| 1123 |
+
if args.target_tokens > 0 and args.dataset:
|
| 1124 |
+
datasets = list(args.dataset)
|
| 1125 |
+
configs = expand_dataset_configs(datasets, list(args.dataset_config))
|
| 1126 |
+
per_dataset = args.target_tokens // len(datasets)
|
| 1127 |
+
remainder = args.target_tokens % len(datasets)
|
| 1128 |
+
calib_chunks: List[torch.Tensor] = []
|
| 1129 |
+
for idx, (dataset_name, config) in enumerate(zip(datasets, configs)):
|
| 1130 |
+
dataset_tokens = per_dataset + (remainder if idx == 0 else 0)
|
| 1131 |
+
spec = SharedLMDataSpec(
|
| 1132 |
+
dataset=dataset_name,
|
| 1133 |
+
config=config,
|
| 1134 |
+
split=args.dataset_split,
|
| 1135 |
+
text_field=args.dataset_text_field,
|
| 1136 |
+
seq_len=args.distill_seq_len,
|
| 1137 |
+
target_tokens=dataset_tokens,
|
| 1138 |
+
seed=args.seed + 17 + idx,
|
| 1139 |
+
)
|
| 1140 |
+
calib_chunks.extend(build_chunks(spec, tokenizer))
|
| 1141 |
+
if calib_chunks:
|
| 1142 |
+
input_ids = torch.stack(calib_chunks)
|
| 1143 |
+
attention_mask = torch.ones_like(input_ids)
|
| 1144 |
+
calib_dataset = torch.utils.data.TensorDataset(input_ids, attention_mask)
|
| 1145 |
+
else:
|
| 1146 |
+
calib_texts = load_texts_from_datasets(
|
| 1147 |
+
datasets=list(args.dataset),
|
| 1148 |
+
configs=expand_dataset_configs(list(args.dataset), list(args.dataset_config)),
|
| 1149 |
+
split=args.dataset_split,
|
| 1150 |
+
text_field=args.dataset_text_field,
|
| 1151 |
+
num_samples=args.distill_calib_samples,
|
| 1152 |
+
seed=args.seed + 17,
|
| 1153 |
+
)
|
| 1154 |
+
inst_records = []
|
| 1155 |
+
if include_instruction:
|
| 1156 |
+
inst_records = load_instruction_records(args, args.distill_inst_samples)
|
| 1157 |
+
|
| 1158 |
+
distill_datasets = []
|
| 1159 |
+
if calib_dataset is not None:
|
| 1160 |
+
distill_datasets.append(calib_dataset)
|
| 1161 |
+
elif calib_texts:
|
| 1162 |
+
calib_records = [{"text": text} for text in calib_texts]
|
| 1163 |
+
distill_datasets.append(
|
| 1164 |
+
FixedSeqDataset(calib_records, tokenizer, args.distill_seq_len)
|
| 1165 |
+
)
|
| 1166 |
+
if inst_records:
|
| 1167 |
+
distill_datasets.append(
|
| 1168 |
+
FixedSeqDataset(inst_records, tokenizer, args.distill_seq_len)
|
| 1169 |
+
)
|
| 1170 |
+
|
| 1171 |
+
distill_meta: Dict[str, object] = {
|
| 1172 |
+
"calib_texts": len(calib_texts),
|
| 1173 |
+
"calib_sequences": len(calib_dataset) if calib_dataset is not None else len(calib_texts),
|
| 1174 |
+
"inst_sequences": len(inst_records),
|
| 1175 |
+
"total_sequences": 0,
|
| 1176 |
+
}
|
| 1177 |
+
|
| 1178 |
+
if not distill_datasets:
|
| 1179 |
+
return None, None, distill_meta
|
| 1180 |
+
if len(distill_datasets) == 1:
|
| 1181 |
+
distill_dataset = distill_datasets[0]
|
| 1182 |
+
else:
|
| 1183 |
+
distill_dataset = torch.utils.data.ConcatDataset(distill_datasets)
|
| 1184 |
+
|
| 1185 |
+
distill_meta["total_sequences"] = len(distill_dataset)
|
| 1186 |
+
distill_generator = build_generator(
|
| 1187 |
+
args.seed + 1000 + (1000000 if include_instruction else 0)
|
| 1188 |
+
)
|
| 1189 |
+
distill_loader = torch.utils.data.DataLoader(
|
| 1190 |
+
distill_dataset,
|
| 1191 |
+
batch_size=args.distill_batch_size,
|
| 1192 |
+
shuffle=True,
|
| 1193 |
+
generator=distill_generator,
|
| 1194 |
+
)
|
| 1195 |
+
return distill_loader, distill_generator, distill_meta
|
| 1196 |
+
|
| 1197 |
+
|
| 1198 |
+
def prepare_eval_dataloaders(
|
| 1199 |
+
args: argparse.Namespace,
|
| 1200 |
+
tokenizer,
|
| 1201 |
+
model: torch.nn.Module,
|
| 1202 |
+
eval_datasets: List[str],
|
| 1203 |
+
eval_configs: List[Optional[str]],
|
| 1204 |
+
) -> Optional[Dict[str, torch.utils.data.DataLoader]]:
|
| 1205 |
+
needs_eval = (not args.skip_eval) or (
|
| 1206 |
+
(not args.skip_distill and args.distill_eval_every)
|
| 1207 |
+
or (args.lora_epochs > 0 and args.lora_eval_every)
|
| 1208 |
+
)
|
| 1209 |
+
if not needs_eval:
|
| 1210 |
+
return None
|
| 1211 |
+
|
| 1212 |
+
eval_batch_size = args.eval_batch_size or args.batch_size
|
| 1213 |
+
resolved_family = args.eval_model_family
|
| 1214 |
+
if resolved_family == "auto":
|
| 1215 |
+
resolved_family = ppl_eval._infer_model_family(model)
|
| 1216 |
+
|
| 1217 |
+
return ppl_eval.prepare_ppl_dataloaders(
|
| 1218 |
+
tokenizer=tokenizer,
|
| 1219 |
+
datasets=eval_datasets,
|
| 1220 |
+
configs=eval_configs,
|
| 1221 |
+
split=args.eval_split,
|
| 1222 |
+
text_field=args.eval_text_field,
|
| 1223 |
+
num_samples=args.eval_num_samples,
|
| 1224 |
+
seq_len=args.eval_seq_len,
|
| 1225 |
+
batch_size=eval_batch_size,
|
| 1226 |
+
seed=args.seed,
|
| 1227 |
+
shuffle=False,
|
| 1228 |
+
model_family=resolved_family,
|
| 1229 |
+
add_bos=args.eval_add_bos,
|
| 1230 |
+
cache_dir=args.eval_cache_dir,
|
| 1231 |
+
num_workers=args.eval_num_workers,
|
| 1232 |
+
model=model,
|
| 1233 |
+
)
|
| 1234 |
+
|
| 1235 |
+
|
| 1236 |
+
def prepare_all_data(
|
| 1237 |
+
args: argparse.Namespace,
|
| 1238 |
+
tokenizer,
|
| 1239 |
+
model: torch.nn.Module,
|
| 1240 |
+
eval_datasets: List[str],
|
| 1241 |
+
eval_configs: List[Optional[str]],
|
| 1242 |
+
loader_generator_state: Optional[Dict[str, object]] = None,
|
| 1243 |
+
) -> PreparedData:
|
| 1244 |
+
texts, chunks, calib_loader = build_calibration_dataloader(args, tokenizer)
|
| 1245 |
+
calib_num_sequences = len(chunks)
|
| 1246 |
+
del texts
|
| 1247 |
+
del chunks
|
| 1248 |
+
|
| 1249 |
+
distill_loader = None
|
| 1250 |
+
distill_generator = None
|
| 1251 |
+
distill_meta: Dict[str, object] = {
|
| 1252 |
+
"calib_texts": 0,
|
| 1253 |
+
"calib_sequences": 0,
|
| 1254 |
+
"inst_sequences": 0,
|
| 1255 |
+
"total_sequences": 0,
|
| 1256 |
+
}
|
| 1257 |
+
lora_loader = None
|
| 1258 |
+
lora_generator = None
|
| 1259 |
+
lora_meta: Dict[str, object] = {
|
| 1260 |
+
"calib_texts": 0,
|
| 1261 |
+
"calib_sequences": 0,
|
| 1262 |
+
"inst_sequences": 0,
|
| 1263 |
+
"total_sequences": 0,
|
| 1264 |
+
}
|
| 1265 |
+
if (not args.skip_distill) or bool(getattr(args, "comm_enabled", False)):
|
| 1266 |
+
distill_loader, distill_generator, distill_meta = prepare_distillation_data(
|
| 1267 |
+
args, tokenizer, include_instruction=False
|
| 1268 |
+
)
|
| 1269 |
+
if (
|
| 1270 |
+
distill_generator is not None
|
| 1271 |
+
and loader_generator_state is not None
|
| 1272 |
+
and loader_generator_state.get("distill_generator_state") is not None
|
| 1273 |
+
):
|
| 1274 |
+
distill_generator.set_state(loader_generator_state["distill_generator_state"])
|
| 1275 |
+
if args.lora_epochs > 0:
|
| 1276 |
+
lora_loader, lora_generator, lora_meta = prepare_distillation_data(
|
| 1277 |
+
args, tokenizer, include_instruction=True
|
| 1278 |
+
)
|
| 1279 |
+
if (
|
| 1280 |
+
lora_generator is not None
|
| 1281 |
+
and loader_generator_state is not None
|
| 1282 |
+
and loader_generator_state.get("lora_generator_state") is not None
|
| 1283 |
+
):
|
| 1284 |
+
lora_generator.set_state(loader_generator_state["lora_generator_state"])
|
| 1285 |
+
|
| 1286 |
+
eval_dataloaders = prepare_eval_dataloaders(
|
| 1287 |
+
args, tokenizer, model, eval_datasets, eval_configs
|
| 1288 |
+
)
|
| 1289 |
+
|
| 1290 |
+
return PreparedData(
|
| 1291 |
+
calib_loader=calib_loader,
|
| 1292 |
+
calib_num_sequences=calib_num_sequences,
|
| 1293 |
+
distill_loader=distill_loader,
|
| 1294 |
+
distill_generator=distill_generator,
|
| 1295 |
+
distill_meta=distill_meta,
|
| 1296 |
+
lora_loader=lora_loader,
|
| 1297 |
+
lora_generator=lora_generator,
|
| 1298 |
+
lora_meta=lora_meta,
|
| 1299 |
+
eval_datasets=eval_datasets,
|
| 1300 |
+
eval_configs=eval_configs,
|
| 1301 |
+
eval_dataloaders=eval_dataloaders,
|
| 1302 |
+
)
|
| 1303 |
+
|
| 1304 |
+
|
| 1305 |
+
def evaluate_ppl_model(
|
| 1306 |
+
model: torch.nn.Module,
|
| 1307 |
+
tokenizer,
|
| 1308 |
+
eval_datasets: List[str],
|
| 1309 |
+
eval_configs: List[Optional[str]],
|
| 1310 |
+
args: argparse.Namespace,
|
| 1311 |
+
max_batches: Optional[int] = None,
|
| 1312 |
+
prepared_eval_dataloaders: Optional[Dict[str, torch.utils.data.DataLoader]] = None,
|
| 1313 |
+
) -> Dict[str, float]:
|
| 1314 |
+
eval_device = args.eval_device or args.device
|
| 1315 |
+
prev_mode = model.training
|
| 1316 |
+
try:
|
| 1317 |
+
prev_device = next(model.parameters()).device
|
| 1318 |
+
except StopIteration:
|
| 1319 |
+
prev_device = torch.device(eval_device)
|
| 1320 |
+
|
| 1321 |
+
model.eval()
|
| 1322 |
+
if str(prev_device) != eval_device:
|
| 1323 |
+
model.to(eval_device)
|
| 1324 |
+
|
| 1325 |
+
if prepared_eval_dataloaders is not None:
|
| 1326 |
+
results = ppl_eval.evaluate_ppl_dataloaders(
|
| 1327 |
+
model,
|
| 1328 |
+
prepared_eval_dataloaders,
|
| 1329 |
+
eval_device,
|
| 1330 |
+
max_batches=max_batches if max_batches is not None else args.eval_max_batches,
|
| 1331 |
+
)
|
| 1332 |
+
else:
|
| 1333 |
+
eval_batch_size = args.eval_batch_size or args.batch_size
|
| 1334 |
+
results = ppl_eval.evaluate_ppl_datasets(
|
| 1335 |
+
model,
|
| 1336 |
+
tokenizer,
|
| 1337 |
+
datasets=eval_datasets,
|
| 1338 |
+
configs=eval_configs,
|
| 1339 |
+
split=args.eval_split,
|
| 1340 |
+
text_field=args.eval_text_field,
|
| 1341 |
+
num_samples=args.eval_num_samples,
|
| 1342 |
+
seq_len=args.eval_seq_len,
|
| 1343 |
+
batch_size=eval_batch_size,
|
| 1344 |
+
device=eval_device,
|
| 1345 |
+
seed=args.seed,
|
| 1346 |
+
shuffle=False,
|
| 1347 |
+
model_family=args.eval_model_family,
|
| 1348 |
+
add_bos=args.eval_add_bos,
|
| 1349 |
+
max_batches=max_batches if max_batches is not None else args.eval_max_batches,
|
| 1350 |
+
cache_dir=args.eval_cache_dir,
|
| 1351 |
+
num_workers=args.eval_num_workers,
|
| 1352 |
+
)
|
| 1353 |
+
|
| 1354 |
+
if prev_mode:
|
| 1355 |
+
model.train()
|
| 1356 |
+
if str(prev_device) != eval_device:
|
| 1357 |
+
model.to(prev_device)
|
| 1358 |
+
if torch.cuda.is_available():
|
| 1359 |
+
torch.cuda.empty_cache()
|
| 1360 |
+
|
| 1361 |
+
return results
|
| 1362 |
+
|
| 1363 |
+
|
| 1364 |
+
def has_post_fusion_data(
|
| 1365 |
+
distill_loader: Optional[torch.utils.data.DataLoader],
|
| 1366 |
+
distill_meta: Optional[Dict[str, object]],
|
| 1367 |
+
) -> bool:
|
| 1368 |
+
if distill_loader is None or distill_meta is None:
|
| 1369 |
+
return False
|
| 1370 |
+
return distill_meta.get("total_sequences", 0) > 0
|
| 1371 |
+
|
| 1372 |
+
|
| 1373 |
+
def summarize_gate_lambdas(gates: Dict[str, torch.Tensor]) -> Dict[str, object]:
|
| 1374 |
+
if not gates:
|
| 1375 |
+
return {"num_tensors": 0, "num_elements": 0}
|
| 1376 |
+
|
| 1377 |
+
total_sum = 0.0
|
| 1378 |
+
total_elems = 0
|
| 1379 |
+
per_tensor_mean: Dict[str, Optional[float]] = {}
|
| 1380 |
+
for name, gate in gates.items():
|
| 1381 |
+
g = gate.detach().float()
|
| 1382 |
+
if g.numel() == 0:
|
| 1383 |
+
per_tensor_mean[name] = None
|
| 1384 |
+
continue
|
| 1385 |
+
per_tensor_mean[name] = float(g.mean().item())
|
| 1386 |
+
total_sum += float(g.sum().item())
|
| 1387 |
+
total_elems += int(g.numel())
|
| 1388 |
+
|
| 1389 |
+
global_mean = None if total_elems == 0 else total_sum / float(total_elems)
|
| 1390 |
+
return {
|
| 1391 |
+
"num_tensors": len(gates),
|
| 1392 |
+
"num_elements": total_elems,
|
| 1393 |
+
"global_mean": global_mean,
|
| 1394 |
+
"per_tensor_mean": per_tensor_mean,
|
| 1395 |
+
}
|
| 1396 |
+
|
| 1397 |
+
|
| 1398 |
+
def compute_path_bytes(path: str) -> int:
|
| 1399 |
+
if os.path.isfile(path):
|
| 1400 |
+
return os.path.getsize(path)
|
| 1401 |
+
|
| 1402 |
+
total = 0
|
| 1403 |
+
for root, _, files in os.walk(path):
|
| 1404 |
+
for name in files:
|
| 1405 |
+
file_path = os.path.join(root, name)
|
| 1406 |
+
if os.path.islink(file_path):
|
| 1407 |
+
continue
|
| 1408 |
+
try:
|
| 1409 |
+
total += os.path.getsize(file_path)
|
| 1410 |
+
except OSError:
|
| 1411 |
+
continue
|
| 1412 |
+
return total
|
| 1413 |
+
|
| 1414 |
+
|
| 1415 |
+
def save_stage_checkpoint(
|
| 1416 |
+
model: torch.nn.Module,
|
| 1417 |
+
tokenizer,
|
| 1418 |
+
stage_dir: str,
|
| 1419 |
+
stage_name: str,
|
| 1420 |
+
ppl_results: Optional[Dict[str, float]],
|
| 1421 |
+
) -> Dict[str, object]:
|
| 1422 |
+
os.makedirs(stage_dir, exist_ok=True)
|
| 1423 |
+
colon_modules = find_colon_modules(model)
|
| 1424 |
+
if colon_modules:
|
| 1425 |
+
raise RuntimeError(
|
| 1426 |
+
"Unexpected module names with ':' detected before save: "
|
| 1427 |
+
f"{', '.join(colon_modules)}."
|
| 1428 |
+
)
|
| 1429 |
+
|
| 1430 |
+
model.save_pretrained(stage_dir)
|
| 1431 |
+
tokenizer.save_pretrained(stage_dir)
|
| 1432 |
+
|
| 1433 |
+
stage_meta = {
|
| 1434 |
+
"stage": stage_name,
|
| 1435 |
+
"path": stage_dir,
|
| 1436 |
+
"weight_bytes": compute_path_bytes(stage_dir),
|
| 1437 |
+
"post_ppl": ppl_results,
|
| 1438 |
+
}
|
| 1439 |
+
with open(
|
| 1440 |
+
os.path.join(stage_dir, "stage_metrics.json"),
|
| 1441 |
+
"w",
|
| 1442 |
+
encoding="utf-8",
|
| 1443 |
+
) as handle:
|
| 1444 |
+
json.dump(stage_meta, handle, indent=2)
|
| 1445 |
+
return stage_meta
|
| 1446 |
+
|
| 1447 |
+
|
| 1448 |
+
def save_cycle_full_model(
|
| 1449 |
+
model: torch.nn.Module,
|
| 1450 |
+
tokenizer,
|
| 1451 |
+
cycle_dir: str,
|
| 1452 |
+
cycle_idx: int,
|
| 1453 |
+
args: argparse.Namespace,
|
| 1454 |
+
ppl_results: Optional[Dict[str, float]],
|
| 1455 |
+
) -> Dict[str, object]:
|
| 1456 |
+
full_model_dir = os.path.join(cycle_dir, "full_model")
|
| 1457 |
+
stage_meta = save_stage_checkpoint(
|
| 1458 |
+
model=model,
|
| 1459 |
+
tokenizer=tokenizer,
|
| 1460 |
+
stage_dir=full_model_dir,
|
| 1461 |
+
stage_name=f"cycle_{cycle_idx}_full_model",
|
| 1462 |
+
ppl_results=ppl_results,
|
| 1463 |
+
)
|
| 1464 |
+
resume_meta = {
|
| 1465 |
+
"base_model": getattr(args, "base_model_id", args.model),
|
| 1466 |
+
"cycle": cycle_idx,
|
| 1467 |
+
"output_dir": args.output_dir,
|
| 1468 |
+
"layer_path": args.layer_path,
|
| 1469 |
+
"rng_state": "rng_state.pt",
|
| 1470 |
+
"loader_generators": "loader_generators.pt",
|
| 1471 |
+
}
|
| 1472 |
+
with open(
|
| 1473 |
+
os.path.join(full_model_dir, "resume_info.json"),
|
| 1474 |
+
"w",
|
| 1475 |
+
encoding="utf-8",
|
| 1476 |
+
) as handle:
|
| 1477 |
+
json.dump(resume_meta, handle, indent=2)
|
| 1478 |
+
stage_meta["resume_info"] = "resume_info.json"
|
| 1479 |
+
return stage_meta
|
| 1480 |
+
|
| 1481 |
+
|
| 1482 |
+
def run_lora_phase(
|
| 1483 |
+
model: torch.nn.Module,
|
| 1484 |
+
tokenizer,
|
| 1485 |
+
eval_datasets: List[str],
|
| 1486 |
+
eval_configs: List[Optional[str]],
|
| 1487 |
+
args: argparse.Namespace,
|
| 1488 |
+
lora_loader: Optional[torch.utils.data.DataLoader] = None,
|
| 1489 |
+
lora_meta: Optional[Dict[str, object]] = None,
|
| 1490 |
+
eval_dataloaders: Optional[Dict[str, torch.utils.data.DataLoader]] = None,
|
| 1491 |
+
cycle_idx: Optional[int] = None,
|
| 1492 |
+
num_cycles: Optional[int] = None,
|
| 1493 |
+
) -> List[Dict[str, object]]:
|
| 1494 |
+
lora_eval_history: List[Dict[str, object]] = []
|
| 1495 |
+
if args.lora_epochs <= 0:
|
| 1496 |
+
return lora_eval_history
|
| 1497 |
+
if not has_post_fusion_data(lora_loader, lora_meta):
|
| 1498 |
+
print("No post-fusion sequences built; skipping LoRA finetuning.")
|
| 1499 |
+
return lora_eval_history
|
| 1500 |
+
|
| 1501 |
+
lora_ce_finetune(
|
| 1502 |
+
model=model,
|
| 1503 |
+
dataloader=lora_loader,
|
| 1504 |
+
eval_tokenizer=tokenizer,
|
| 1505 |
+
eval_datasets=eval_datasets,
|
| 1506 |
+
eval_configs=eval_configs,
|
| 1507 |
+
eval_history=lora_eval_history,
|
| 1508 |
+
args=args,
|
| 1509 |
+
eval_dataloaders=eval_dataloaders,
|
| 1510 |
+
progressive_cycle=cycle_idx,
|
| 1511 |
+
progressive_total=num_cycles,
|
| 1512 |
+
)
|
| 1513 |
+
return lora_eval_history
|
| 1514 |
+
|
| 1515 |
+
|
| 1516 |
+
def run_progressive(
|
| 1517 |
+
args: argparse.Namespace,
|
| 1518 |
+
model: torch.nn.Module,
|
| 1519 |
+
tokenizer,
|
| 1520 |
+
prepared: PreparedData,
|
| 1521 |
+
) -> None:
|
| 1522 |
+
eval_datasets = prepared.eval_datasets
|
| 1523 |
+
eval_configs = prepared.eval_configs
|
| 1524 |
+
|
| 1525 |
+
dataloader = prepared.calib_loader
|
| 1526 |
+
num_sequences = prepared.calib_num_sequences
|
| 1527 |
+
|
| 1528 |
+
model.to(args.device)
|
| 1529 |
+
|
| 1530 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 1531 |
+
progressive_meta_path = os.path.join(args.output_dir, "progressive_metadata.json")
|
| 1532 |
+
existing_meta: Dict[str, object] = {}
|
| 1533 |
+
if args.resume_from_cycle > 0 and os.path.exists(progressive_meta_path):
|
| 1534 |
+
with open(progressive_meta_path, "r", encoding="utf-8") as handle:
|
| 1535 |
+
loaded_meta = json.load(handle)
|
| 1536 |
+
if isinstance(loaded_meta, dict):
|
| 1537 |
+
existing_meta = loaded_meta
|
| 1538 |
+
|
| 1539 |
+
bootstrap_meta = {
|
| 1540 |
+
"base_model": getattr(args, "base_model_id", args.model),
|
| 1541 |
+
"num_progressive": args.num_progressive,
|
| 1542 |
+
"layer_path": args.layer_path,
|
| 1543 |
+
"resume_from_cycle": args.resume_from_cycle,
|
| 1544 |
+
"save_full_model_cycles": sorted(args.full_model_save_cycles),
|
| 1545 |
+
"cycles": (
|
| 1546 |
+
existing_meta.get("cycles", [])
|
| 1547 |
+
if isinstance(existing_meta.get("cycles"), list)
|
| 1548 |
+
else []
|
| 1549 |
+
),
|
| 1550 |
+
}
|
| 1551 |
+
with open(
|
| 1552 |
+
progressive_meta_path,
|
| 1553 |
+
"w",
|
| 1554 |
+
encoding="utf-8",
|
| 1555 |
+
) as handle:
|
| 1556 |
+
json.dump(bootstrap_meta, handle, indent=2)
|
| 1557 |
+
|
| 1558 |
+
pre_eval = None
|
| 1559 |
+
if not args.skip_eval:
|
| 1560 |
+
pre_eval = evaluate_ppl_model(
|
| 1561 |
+
model,
|
| 1562 |
+
tokenizer,
|
| 1563 |
+
eval_datasets,
|
| 1564 |
+
eval_configs,
|
| 1565 |
+
args,
|
| 1566 |
+
prepared_eval_dataloaders=prepared.eval_dataloaders,
|
| 1567 |
+
)
|
| 1568 |
+
print("Pre-pruning perplexity:")
|
| 1569 |
+
for dataset_name, ppl in pre_eval.items():
|
| 1570 |
+
print(f"{dataset_name}: {ppl:.4f}")
|
| 1571 |
+
|
| 1572 |
+
parent, name, container = find_layer_container(model, args.layer_path)
|
| 1573 |
+
layers = list(container)
|
| 1574 |
+
if args.num_progressive > (len(layers) - 1 + args.resume_from_cycle):
|
| 1575 |
+
raise SystemExit(
|
| 1576 |
+
f"--num_progressive ({args.num_progressive}) exceeds available pairs "
|
| 1577 |
+
f"after resume offset ({len(layers) - 1 + args.resume_from_cycle})"
|
| 1578 |
+
)
|
| 1579 |
+
|
| 1580 |
+
dwce_scores = None
|
| 1581 |
+
dwce_meta = None
|
| 1582 |
+
last_fused_idx = 0
|
| 1583 |
+
cycle_summaries: List[Dict[str, object]] = []
|
| 1584 |
+
existing_cycles = existing_meta.get("cycles", [])
|
| 1585 |
+
if isinstance(existing_cycles, list):
|
| 1586 |
+
for entry in existing_cycles:
|
| 1587 |
+
if not isinstance(entry, dict):
|
| 1588 |
+
continue
|
| 1589 |
+
cycle_value = entry.get("cycle")
|
| 1590 |
+
if isinstance(cycle_value, int) and cycle_value <= args.resume_from_cycle:
|
| 1591 |
+
cycle_summaries.append(entry)
|
| 1592 |
+
|
| 1593 |
+
comm_enabled = bool(getattr(args, "comm_enabled", False))
|
| 1594 |
+
comm_teacher_model = None
|
| 1595 |
+
comm_teacher_cycle: Optional[int] = None
|
| 1596 |
+
teacher_device = args.distill_teacher_device or args.device
|
| 1597 |
+
previous_cycle_teacher_model = None
|
| 1598 |
+
previous_cycle_teacher_cycle: Optional[int] = None
|
| 1599 |
+
|
| 1600 |
+
def _release_comm_teacher() -> None:
|
| 1601 |
+
nonlocal comm_teacher_model, comm_teacher_cycle
|
| 1602 |
+
if comm_teacher_model is not None:
|
| 1603 |
+
del comm_teacher_model
|
| 1604 |
+
comm_teacher_model = None
|
| 1605 |
+
comm_teacher_cycle = None
|
| 1606 |
+
if torch.cuda.is_available():
|
| 1607 |
+
torch.cuda.empty_cache()
|
| 1608 |
+
|
| 1609 |
+
def _release_previous_cycle_teacher() -> None:
|
| 1610 |
+
nonlocal previous_cycle_teacher_model, previous_cycle_teacher_cycle
|
| 1611 |
+
if previous_cycle_teacher_model is not None:
|
| 1612 |
+
del previous_cycle_teacher_model
|
| 1613 |
+
previous_cycle_teacher_model = None
|
| 1614 |
+
previous_cycle_teacher_cycle = None
|
| 1615 |
+
if torch.cuda.is_available():
|
| 1616 |
+
torch.cuda.empty_cache()
|
| 1617 |
+
|
| 1618 |
+
def _snapshot_previous_cycle_teacher(cycle_idx: int) -> None:
|
| 1619 |
+
nonlocal previous_cycle_teacher_model, previous_cycle_teacher_cycle
|
| 1620 |
+
_release_previous_cycle_teacher()
|
| 1621 |
+
previous_cycle_teacher_model = copy.deepcopy(model)
|
| 1622 |
+
previous_cycle_teacher_model.to(teacher_device)
|
| 1623 |
+
previous_cycle_teacher_model.eval()
|
| 1624 |
+
previous_cycle_teacher_cycle = cycle_idx
|
| 1625 |
+
|
| 1626 |
+
def _get_previous_cycle_teacher(
|
| 1627 |
+
cycle_idx: int,
|
| 1628 |
+
) -> Tuple[Optional[torch.nn.Module], str, Optional[int]]:
|
| 1629 |
+
prev_cycle = cycle_idx - 1
|
| 1630 |
+
if prev_cycle <= 0:
|
| 1631 |
+
return None, "base_model", 0
|
| 1632 |
+
if (
|
| 1633 |
+
previous_cycle_teacher_model is not None
|
| 1634 |
+
and previous_cycle_teacher_cycle == prev_cycle
|
| 1635 |
+
):
|
| 1636 |
+
return previous_cycle_teacher_model, "previous_cycle_memory", prev_cycle
|
| 1637 |
+
teacher_model = load_progressive_model(
|
| 1638 |
+
getattr(args, "base_model_id", args.model),
|
| 1639 |
+
args.output_dir,
|
| 1640 |
+
cycle=prev_cycle,
|
| 1641 |
+
device=teacher_device,
|
| 1642 |
+
dtype=args.dtype,
|
| 1643 |
+
trust_remote_code=args.trust_remote_code,
|
| 1644 |
+
layer_path=args.layer_path,
|
| 1645 |
+
)
|
| 1646 |
+
teacher_model.eval()
|
| 1647 |
+
return teacher_model, "previous_cycle_disk", prev_cycle
|
| 1648 |
+
|
| 1649 |
+
def _get_comm_teacher(cycle_idx: int) -> Tuple[Optional[torch.nn.Module], str, Optional[int]]:
|
| 1650 |
+
nonlocal comm_teacher_model, comm_teacher_cycle
|
| 1651 |
+
if not comm_enabled:
|
| 1652 |
+
return None, "disabled", None
|
| 1653 |
+
|
| 1654 |
+
source = str(getattr(args, "redistrib_teacher_source", "base_model"))
|
| 1655 |
+
if source == "base_model":
|
| 1656 |
+
if comm_teacher_model is None:
|
| 1657 |
+
print(
|
| 1658 |
+
"[comm] Loading fixed base teacher for anchor loss "
|
| 1659 |
+
f"(device={teacher_device})."
|
| 1660 |
+
)
|
| 1661 |
+
comm_teacher_model = AutoModelForCausalLM.from_pretrained(
|
| 1662 |
+
getattr(args, "base_model_id", args.model),
|
| 1663 |
+
torch_dtype=get_dtype(args.dtype),
|
| 1664 |
+
trust_remote_code=args.trust_remote_code,
|
| 1665 |
+
)
|
| 1666 |
+
comm_teacher_model.to(teacher_device)
|
| 1667 |
+
comm_teacher_model.eval()
|
| 1668 |
+
comm_teacher_cycle = 0
|
| 1669 |
+
return comm_teacher_model, "base_model", 0
|
| 1670 |
+
|
| 1671 |
+
prev_cycle = cycle_idx - 1
|
| 1672 |
+
if prev_cycle <= 0:
|
| 1673 |
+
if comm_teacher_model is None or comm_teacher_cycle != 0:
|
| 1674 |
+
_release_comm_teacher()
|
| 1675 |
+
print(
|
| 1676 |
+
"[comm] --redistrib_teacher_source=previous_cycle but cycle 1 "
|
| 1677 |
+
"has no prior checkpoint; using base teacher."
|
| 1678 |
+
)
|
| 1679 |
+
comm_teacher_model = AutoModelForCausalLM.from_pretrained(
|
| 1680 |
+
getattr(args, "base_model_id", args.model),
|
| 1681 |
+
torch_dtype=get_dtype(args.dtype),
|
| 1682 |
+
trust_remote_code=args.trust_remote_code,
|
| 1683 |
+
)
|
| 1684 |
+
comm_teacher_model.to(teacher_device)
|
| 1685 |
+
comm_teacher_model.eval()
|
| 1686 |
+
comm_teacher_cycle = 0
|
| 1687 |
+
return comm_teacher_model, "base_model", 0
|
| 1688 |
+
|
| 1689 |
+
if (
|
| 1690 |
+
previous_cycle_teacher_model is not None
|
| 1691 |
+
and previous_cycle_teacher_cycle == prev_cycle
|
| 1692 |
+
):
|
| 1693 |
+
if comm_teacher_model is not previous_cycle_teacher_model:
|
| 1694 |
+
_release_comm_teacher()
|
| 1695 |
+
comm_teacher_model = previous_cycle_teacher_model
|
| 1696 |
+
comm_teacher_cycle = prev_cycle
|
| 1697 |
+
return comm_teacher_model, "previous_cycle_memory", prev_cycle
|
| 1698 |
+
|
| 1699 |
+
if comm_teacher_model is None or comm_teacher_cycle != prev_cycle:
|
| 1700 |
+
_release_comm_teacher()
|
| 1701 |
+
print(
|
| 1702 |
+
"[comm] Loading teacher from previous cycle "
|
| 1703 |
+
f"{prev_cycle} (device={teacher_device})."
|
| 1704 |
+
)
|
| 1705 |
+
comm_teacher_model = load_progressive_model(
|
| 1706 |
+
getattr(args, "base_model_id", args.model),
|
| 1707 |
+
args.output_dir,
|
| 1708 |
+
cycle=prev_cycle,
|
| 1709 |
+
device=teacher_device,
|
| 1710 |
+
dtype=args.dtype,
|
| 1711 |
+
trust_remote_code=args.trust_remote_code,
|
| 1712 |
+
layer_path=args.layer_path,
|
| 1713 |
+
)
|
| 1714 |
+
comm_teacher_model.eval()
|
| 1715 |
+
comm_teacher_cycle = prev_cycle
|
| 1716 |
+
return comm_teacher_model, "previous_cycle_disk", prev_cycle
|
| 1717 |
+
|
| 1718 |
+
if args.resume_from_cycle > 0:
|
| 1719 |
+
_snapshot_previous_cycle_teacher(args.resume_from_cycle)
|
| 1720 |
+
|
| 1721 |
+
start_cycle = args.resume_from_cycle + 1
|
| 1722 |
+
for cycle_idx in range(start_cycle, args.num_progressive + 1):
|
| 1723 |
+
print(f"[progressive] Cycle {cycle_idx}/{args.num_progressive}")
|
| 1724 |
+
run_comm = comm_enabled and (
|
| 1725 |
+
cycle_idx > 1 or bool(getattr(args, "comm_include_cycle1", False))
|
| 1726 |
+
)
|
| 1727 |
+
comm_stats: Dict[str, object] = {"enabled": False}
|
| 1728 |
+
comm_post_eval = None
|
| 1729 |
+
if run_comm:
|
| 1730 |
+
# Preconditioning updates model weights, so DWCE reuse is unreliable.
|
| 1731 |
+
start_index = 0
|
| 1732 |
+
reuse_scores = None
|
| 1733 |
+
else:
|
| 1734 |
+
start_index = last_fused_idx if cycle_idx > 1 else 0
|
| 1735 |
+
reuse_scores = dwce_scores
|
| 1736 |
+
exclude_pairs = set(parse_exclude_pairs(args.exclude_pairs, max(len(layers) - 1, 0)))
|
| 1737 |
+
layer_idx, dwce_scores, dwce_meta = resolve_layer_idx(
|
| 1738 |
+
args,
|
| 1739 |
+
model,
|
| 1740 |
+
layers,
|
| 1741 |
+
dataloader,
|
| 1742 |
+
reuse_scores,
|
| 1743 |
+
start_index,
|
| 1744 |
+
exclude_pairs,
|
| 1745 |
+
)
|
| 1746 |
+
|
| 1747 |
+
if run_comm:
|
| 1748 |
+
dwce_scores_pre_comm = dwce_scores
|
| 1749 |
+
if prepared.calib_loader is None:
|
| 1750 |
+
print(
|
| 1751 |
+
"[comm] Enabled but no calibration sequences were built; skipping."
|
| 1752 |
+
)
|
| 1753 |
+
else:
|
| 1754 |
+
(
|
| 1755 |
+
comm_teacher_model_loaded,
|
| 1756 |
+
comm_teacher_source,
|
| 1757 |
+
comm_teacher_cycle_idx,
|
| 1758 |
+
) = _get_comm_teacher(cycle_idx)
|
| 1759 |
+
if comm_teacher_model_loaded is None:
|
| 1760 |
+
raise RuntimeError("comm_enabled but teacher model was not loaded.")
|
| 1761 |
+
comm_stats = commutator_precondition(
|
| 1762 |
+
student_model=model,
|
| 1763 |
+
student_layers=layers,
|
| 1764 |
+
teacher_model=comm_teacher_model_loaded,
|
| 1765 |
+
dataloader=prepared.calib_loader,
|
| 1766 |
+
dwce_scores=dwce_scores_pre_comm,
|
| 1767 |
+
exclude_pairs=exclude_pairs,
|
| 1768 |
+
args=args,
|
| 1769 |
+
progressive_cycle=cycle_idx,
|
| 1770 |
+
progressive_total=args.num_progressive,
|
| 1771 |
+
)
|
| 1772 |
+
if comm_stats.get("enabled"):
|
| 1773 |
+
comm_stats["teacher_source"] = comm_teacher_source
|
| 1774 |
+
comm_stats["teacher_cycle"] = comm_teacher_cycle_idx
|
| 1775 |
+
comm_stats["dwce_scores_pre"] = dwce_scores_pre_comm
|
| 1776 |
+
print(
|
| 1777 |
+
"[comm] Done:"
|
| 1778 |
+
f" opt_steps={comm_stats.get('opt_steps')}"
|
| 1779 |
+
f" lr={comm_stats.get('lr')}"
|
| 1780 |
+
)
|
| 1781 |
+
if not args.skip_eval:
|
| 1782 |
+
comm_post_eval = evaluate_ppl_model(
|
| 1783 |
+
model,
|
| 1784 |
+
tokenizer,
|
| 1785 |
+
eval_datasets,
|
| 1786 |
+
eval_configs,
|
| 1787 |
+
args,
|
| 1788 |
+
prepared_eval_dataloaders=prepared.eval_dataloaders,
|
| 1789 |
+
)
|
| 1790 |
+
comm_stats["post_ppl"] = comm_post_eval
|
| 1791 |
+
print(f"[progressive] Cycle {cycle_idx} post-comm perplexity:")
|
| 1792 |
+
for dataset_name, ppl in comm_post_eval.items():
|
| 1793 |
+
print(f"{dataset_name}: {ppl:.4f}")
|
| 1794 |
+
|
| 1795 |
+
if bool(getattr(args, "comm_skip_post_reselect", False)):
|
| 1796 |
+
comm_stats["post_selection_recomputed"] = False
|
| 1797 |
+
comm_stats["selected_layer_post"] = int(layer_idx)
|
| 1798 |
+
print(
|
| 1799 |
+
"[comm] Keeping pre-comm DWCE pair selection for fusion."
|
| 1800 |
+
)
|
| 1801 |
+
else:
|
| 1802 |
+
print(
|
| 1803 |
+
"[comm] Recomputing DWCE after preconditioning for fusion selection."
|
| 1804 |
+
)
|
| 1805 |
+
layer_idx, dwce_scores, dwce_meta = resolve_layer_idx(
|
| 1806 |
+
args,
|
| 1807 |
+
model,
|
| 1808 |
+
layers,
|
| 1809 |
+
dataloader,
|
| 1810 |
+
None,
|
| 1811 |
+
0,
|
| 1812 |
+
exclude_pairs,
|
| 1813 |
+
)
|
| 1814 |
+
comm_stats["post_selection_recomputed"] = True
|
| 1815 |
+
comm_stats["selected_layer_post"] = int(layer_idx)
|
| 1816 |
+
|
| 1817 |
+
if layer_idx < 0 or layer_idx >= len(layers) - 1:
|
| 1818 |
+
raise SystemExit("--layer must be in [0, num_layers-2]")
|
| 1819 |
+
|
| 1820 |
+
num_layers_before = len(layers)
|
| 1821 |
+
layer_a = layers[layer_idx]
|
| 1822 |
+
layer_b = layers[layer_idx + 1]
|
| 1823 |
+
|
| 1824 |
+
norm1_state = None
|
| 1825 |
+
norm2_state = None
|
| 1826 |
+
norm1, norm2, norm_names = get_norm_pair(layer_a)
|
| 1827 |
+
if norm1 is not None:
|
| 1828 |
+
norm1_state = clone_state_dict(norm1)
|
| 1829 |
+
if norm2 is not None:
|
| 1830 |
+
norm2_state = clone_state_dict(norm2)
|
| 1831 |
+
|
| 1832 |
+
attn_a = find_attention_module(layer_a)
|
| 1833 |
+
attn_b = find_attention_module(layer_b)
|
| 1834 |
+
hidden_size = getattr(model.config, "hidden_size", None)
|
| 1835 |
+
if hidden_size is None:
|
| 1836 |
+
hidden_size = getattr(model.config, "n_embd", None)
|
| 1837 |
+
if hidden_size is None:
|
| 1838 |
+
raise SystemExit("Model config missing hidden_size/n_embd")
|
| 1839 |
+
|
| 1840 |
+
no_head_permute_merge = bool(
|
| 1841 |
+
getattr(args, "no_head_permute_merge", False)
|
| 1842 |
+
or getattr(args, "no_head_permute", False)
|
| 1843 |
+
)
|
| 1844 |
+
if no_head_permute_merge:
|
| 1845 |
+
print("[fuse] Head permutation disabled; merging with original head order.")
|
| 1846 |
+
else:
|
| 1847 |
+
mean_a, mean_b, num_heads, num_kv_heads, head_dim = compute_head_means(
|
| 1848 |
+
model,
|
| 1849 |
+
attn_a,
|
| 1850 |
+
attn_b,
|
| 1851 |
+
dataloader,
|
| 1852 |
+
args.device,
|
| 1853 |
+
hidden_size,
|
| 1854 |
+
)
|
| 1855 |
+
|
| 1856 |
+
perm = build_head_permutation(
|
| 1857 |
+
mean_a,
|
| 1858 |
+
mean_b,
|
| 1859 |
+
num_heads=num_heads,
|
| 1860 |
+
num_kv_heads=num_kv_heads,
|
| 1861 |
+
eps=args.eps,
|
| 1862 |
+
)
|
| 1863 |
+
permute_attention_heads(
|
| 1864 |
+
attn_b, perm, num_heads, num_kv_heads, head_dim=head_dim
|
| 1865 |
+
)
|
| 1866 |
+
|
| 1867 |
+
fisher_sums, num_batches, param_numels = compute_fisher(
|
| 1868 |
+
model,
|
| 1869 |
+
layer_a,
|
| 1870 |
+
layer_b,
|
| 1871 |
+
dataloader,
|
| 1872 |
+
fisher_mode=args.fisher_mode,
|
| 1873 |
+
device=args.device,
|
| 1874 |
+
)
|
| 1875 |
+
|
| 1876 |
+
distill_ready = has_post_fusion_data(
|
| 1877 |
+
prepared.distill_loader, prepared.distill_meta
|
| 1878 |
+
)
|
| 1879 |
+
teacher_cycle = cycle_idx - 1
|
| 1880 |
+
teacher_source = "previous_cycle" if teacher_cycle > 0 else "base_model"
|
| 1881 |
+
|
| 1882 |
+
merge_method = "fisher"
|
| 1883 |
+
distill_method = str(getattr(args, "distill_method", "reparam"))
|
| 1884 |
+
reparam_stats: Optional[Dict[str, object]] = None
|
| 1885 |
+
reparam_gate_summary: Optional[Dict[str, object]] = None
|
| 1886 |
+
|
| 1887 |
+
needs_teacher_for_reparam = (
|
| 1888 |
+
(not args.skip_distill)
|
| 1889 |
+
and distill_ready
|
| 1890 |
+
and float(args.distill_epochs) > 0.0
|
| 1891 |
+
)
|
| 1892 |
+
teacher_model = None
|
| 1893 |
+
teacher_parent = None
|
| 1894 |
+
teacher_layer_attr = None
|
| 1895 |
+
teacher_layers: Optional[List[torch.nn.Module]] = None
|
| 1896 |
+
teacher_from_cache = False
|
| 1897 |
+
if needs_teacher_for_reparam:
|
| 1898 |
+
teacher_model, teacher_source, teacher_cycle = _get_previous_cycle_teacher(
|
| 1899 |
+
cycle_idx
|
| 1900 |
+
)
|
| 1901 |
+
teacher_from_cache = (
|
| 1902 |
+
teacher_source == "previous_cycle_memory"
|
| 1903 |
+
and teacher_model is previous_cycle_teacher_model
|
| 1904 |
+
)
|
| 1905 |
+
if teacher_model is None:
|
| 1906 |
+
teacher_model = load_causal_lm(
|
| 1907 |
+
getattr(args, "base_model_id", args.model),
|
| 1908 |
+
torch_dtype=get_dtype(args.dtype),
|
| 1909 |
+
trust_remote_code=args.trust_remote_code,
|
| 1910 |
+
cache_dir=args.model_cache_dir,
|
| 1911 |
+
)
|
| 1912 |
+
teacher_model.to(teacher_device)
|
| 1913 |
+
teacher_model.eval()
|
| 1914 |
+
teacher_source = "base_model"
|
| 1915 |
+
teacher_cycle = 0
|
| 1916 |
+
teacher_parent, teacher_layer_attr, teacher_container = find_layer_container(
|
| 1917 |
+
teacher_model, args.layer_path
|
| 1918 |
+
)
|
| 1919 |
+
teacher_layers = list(teacher_container)
|
| 1920 |
+
|
| 1921 |
+
do_reparam = (
|
| 1922 |
+
(not args.skip_distill)
|
| 1923 |
+
and distill_ready
|
| 1924 |
+
and prepared.distill_loader is not None
|
| 1925 |
+
)
|
| 1926 |
+
if (not args.skip_distill) and not do_reparam:
|
| 1927 |
+
print("[reparam] No distillation sequences built; skipping reparam distill.")
|
| 1928 |
+
distill_post_eval = None
|
| 1929 |
+
if do_reparam:
|
| 1930 |
+
lambda_source = "fisher_prior"
|
| 1931 |
+
reparam_gate_targets: Dict[str, object] = compute_fisher_gate_priors(
|
| 1932 |
+
layer_a=layer_a,
|
| 1933 |
+
layer_b=layer_b,
|
| 1934 |
+
fisher_a=fisher_sums[0],
|
| 1935 |
+
fisher_b=fisher_sums[1],
|
| 1936 |
+
num_batches=num_batches,
|
| 1937 |
+
numels_a=param_numels[0],
|
| 1938 |
+
numels_b=param_numels[1],
|
| 1939 |
+
fisher_mode=args.fisher_mode,
|
| 1940 |
+
eps=float(args.eps),
|
| 1941 |
+
)
|
| 1942 |
+
if not reparam_gate_targets:
|
| 1943 |
+
raise SystemExit("[reparam] No mergeable parameters found; cannot continue.")
|
| 1944 |
+
if float(args.distill_epochs) > 0.0 and (
|
| 1945 |
+
teacher_model is None or teacher_layers is None
|
| 1946 |
+
):
|
| 1947 |
+
raise SystemExit("--distill_method reparam requires a teacher model.")
|
| 1948 |
+
print(
|
| 1949 |
+
f"[reparam] Cycle {cycle_idx}: training U + gates for pair "
|
| 1950 |
+
f"{layer_idx}-{layer_idx + 1} (epochs={args.distill_epochs}, "
|
| 1951 |
+
f"hidden_mse_w={args.distill_hidden_mse_weight}, "
|
| 1952 |
+
f"attn_mse_w={args.distill_attn_mse_weight}, "
|
| 1953 |
+
f"mlp_mse_w={args.distill_mlp_mse_weight}, "
|
| 1954 |
+
f"eta={args.reparam_eta}, gamma={args.reparam_gamma}, "
|
| 1955 |
+
f"attn_reg_scale={args.reparam_attn_reg_scale}, "
|
| 1956 |
+
f"mlp_reg_scale={args.reparam_mlp_reg_scale}, "
|
| 1957 |
+
f"param_subset={args.reparam_param_subset}, "
|
| 1958 |
+
f"lambda_init={lambda_source})."
|
| 1959 |
+
)
|
| 1960 |
+
merged, final_gates, reparam_stats = distill_reparam_merge(
|
| 1961 |
+
student_model=model,
|
| 1962 |
+
student_parent=parent,
|
| 1963 |
+
student_layer_attr=name,
|
| 1964 |
+
student_layers=layers,
|
| 1965 |
+
teacher_model=teacher_model,
|
| 1966 |
+
teacher_parent=teacher_parent,
|
| 1967 |
+
teacher_layer_attr=teacher_layer_attr,
|
| 1968 |
+
teacher_layers=teacher_layers,
|
| 1969 |
+
layer_idx=layer_idx,
|
| 1970 |
+
gate_lambdas=reparam_gate_targets,
|
| 1971 |
+
dataloader=prepared.distill_loader,
|
| 1972 |
+
args=args,
|
| 1973 |
+
progressive_cycle=cycle_idx,
|
| 1974 |
+
progressive_total=args.num_progressive,
|
| 1975 |
+
)
|
| 1976 |
+
reparam_gate_summary = summarize_gate_lambdas(final_gates)
|
| 1977 |
+
merge_method = "reparam"
|
| 1978 |
+
if reparam_stats is not None:
|
| 1979 |
+
reparam_stats["lambda_init"] = lambda_source
|
| 1980 |
+
else:
|
| 1981 |
+
merged = merge_layers(
|
| 1982 |
+
layer_a,
|
| 1983 |
+
layer_b,
|
| 1984 |
+
fisher_sums[0],
|
| 1985 |
+
fisher_sums[1],
|
| 1986 |
+
num_batches,
|
| 1987 |
+
param_numels[0],
|
| 1988 |
+
param_numels[1],
|
| 1989 |
+
fisher_mode=args.fisher_mode,
|
| 1990 |
+
eps=args.eps,
|
| 1991 |
+
)
|
| 1992 |
+
|
| 1993 |
+
apply_norm_policy(
|
| 1994 |
+
layer_a,
|
| 1995 |
+
args.norm_policy,
|
| 1996 |
+
norm1_state,
|
| 1997 |
+
norm2_state,
|
| 1998 |
+
norm_names,
|
| 1999 |
+
)
|
| 2000 |
+
if teacher_model is not None and not teacher_from_cache:
|
| 2001 |
+
del teacher_model
|
| 2002 |
+
teacher_model = None
|
| 2003 |
+
teacher_parent = None
|
| 2004 |
+
teacher_layer_attr = None
|
| 2005 |
+
teacher_layers = None
|
| 2006 |
+
if torch.cuda.is_available():
|
| 2007 |
+
torch.cuda.empty_cache()
|
| 2008 |
+
|
| 2009 |
+
new_container = drop_layer(container, layer_idx + 1)
|
| 2010 |
+
setattr(parent, name, new_container)
|
| 2011 |
+
decrement_config(model.config)
|
| 2012 |
+
layers = list(new_container)
|
| 2013 |
+
|
| 2014 |
+
lora_post_eval = None
|
| 2015 |
+
if (not args.skip_eval) and (not args.skip_distill) and do_reparam:
|
| 2016 |
+
distill_post_eval = evaluate_ppl_model(
|
| 2017 |
+
model,
|
| 2018 |
+
tokenizer,
|
| 2019 |
+
eval_datasets,
|
| 2020 |
+
eval_configs,
|
| 2021 |
+
args,
|
| 2022 |
+
prepared_eval_dataloaders=prepared.eval_dataloaders,
|
| 2023 |
+
)
|
| 2024 |
+
print(f"[progressive] Cycle {cycle_idx} post-distill perplexity:")
|
| 2025 |
+
for dataset_name, ppl in distill_post_eval.items():
|
| 2026 |
+
print(f"{dataset_name}: {ppl:.4f}")
|
| 2027 |
+
|
| 2028 |
+
post_eval = None
|
| 2029 |
+
if not args.skip_eval:
|
| 2030 |
+
if distill_post_eval is not None:
|
| 2031 |
+
post_eval = distill_post_eval
|
| 2032 |
+
else:
|
| 2033 |
+
post_eval = evaluate_ppl_model(
|
| 2034 |
+
model,
|
| 2035 |
+
tokenizer,
|
| 2036 |
+
eval_datasets,
|
| 2037 |
+
eval_configs,
|
| 2038 |
+
args,
|
| 2039 |
+
prepared_eval_dataloaders=prepared.eval_dataloaders,
|
| 2040 |
+
)
|
| 2041 |
+
print(f"[progressive] Cycle {cycle_idx} perplexity:")
|
| 2042 |
+
for dataset_name, ppl in post_eval.items():
|
| 2043 |
+
print(f"{dataset_name}: {ppl:.4f}")
|
| 2044 |
+
|
| 2045 |
+
cycle_dir = os.path.join(args.output_dir, f"cycle_{cycle_idx}")
|
| 2046 |
+
os.makedirs(cycle_dir, exist_ok=True)
|
| 2047 |
+
fused_layer_file = "fused_layer.pt"
|
| 2048 |
+
fused_layer_path = os.path.join(cycle_dir, fused_layer_file)
|
| 2049 |
+
torch.save(layers[layer_idx].state_dict(), fused_layer_path)
|
| 2050 |
+
|
| 2051 |
+
cycle_meta: Dict[str, object] = {
|
| 2052 |
+
"cycle": cycle_idx,
|
| 2053 |
+
"layer_merged": layer_idx,
|
| 2054 |
+
"num_layers_before": num_layers_before,
|
| 2055 |
+
"num_layers_after": num_layers_before - 1,
|
| 2056 |
+
"fused_layer_state": fused_layer_file,
|
| 2057 |
+
"dwce_score": dwce_scores[layer_idx] if dwce_scores else None,
|
| 2058 |
+
"dwce_scores": dwce_scores,
|
| 2059 |
+
"dwce_meta": dwce_meta,
|
| 2060 |
+
"fisher_num_batches": num_batches,
|
| 2061 |
+
"merge_method": merge_method,
|
| 2062 |
+
"merged_params": merged,
|
| 2063 |
+
"num_sequences": num_sequences,
|
| 2064 |
+
"teacher_source": teacher_source,
|
| 2065 |
+
"teacher_cycle": teacher_cycle,
|
| 2066 |
+
"eval": {
|
| 2067 |
+
"datasets": eval_datasets,
|
| 2068 |
+
"configs": eval_configs,
|
| 2069 |
+
"split": args.eval_split,
|
| 2070 |
+
"num_samples": args.eval_num_samples,
|
| 2071 |
+
"seq_len": args.eval_seq_len,
|
| 2072 |
+
"post_ppl": post_eval,
|
| 2073 |
+
},
|
| 2074 |
+
"comm": comm_stats,
|
| 2075 |
+
"distill": {
|
| 2076 |
+
"enabled": not args.skip_distill,
|
| 2077 |
+
"method": distill_method,
|
| 2078 |
+
"calib_samples": args.distill_calib_samples,
|
| 2079 |
+
"inst_samples": args.distill_inst_samples,
|
| 2080 |
+
"seq_len": args.distill_seq_len,
|
| 2081 |
+
"batch_size": args.distill_batch_size,
|
| 2082 |
+
"epochs": args.distill_epochs,
|
| 2083 |
+
"lr": args.distill_lr,
|
| 2084 |
+
"kl_weight": args.distill_kl_weight,
|
| 2085 |
+
"kl_temp": args.distill_kl_temp,
|
| 2086 |
+
"hidden_mse_weight": args.distill_hidden_mse_weight,
|
| 2087 |
+
"attn_mse_weight": args.distill_attn_mse_weight,
|
| 2088 |
+
"mlp_mse_weight": args.distill_mlp_mse_weight,
|
| 2089 |
+
"reparam_eta": args.reparam_eta,
|
| 2090 |
+
"reparam_gamma": args.reparam_gamma,
|
| 2091 |
+
"reparam_attn_reg_scale": args.reparam_attn_reg_scale,
|
| 2092 |
+
"reparam_mlp_reg_scale": args.reparam_mlp_reg_scale,
|
| 2093 |
+
"reparam_param_subset": args.reparam_param_subset,
|
| 2094 |
+
"reparam_stats": reparam_stats,
|
| 2095 |
+
"reparam_gate_summary": reparam_gate_summary,
|
| 2096 |
+
"post_ppl": distill_post_eval,
|
| 2097 |
+
"weight_decay": args.distill_weight_decay,
|
| 2098 |
+
"max_grad_norm": args.distill_max_grad_norm,
|
| 2099 |
+
"grad_accum_steps": args.distill_grad_accum_steps,
|
| 2100 |
+
"instruction_dataset": args.instruction_dataset,
|
| 2101 |
+
"instruction_config": args.instruction_config,
|
| 2102 |
+
"instruction_split": args.instruction_split,
|
| 2103 |
+
},
|
| 2104 |
+
"lora": {
|
| 2105 |
+
"enabled": args.lora_epochs > 0,
|
| 2106 |
+
"seq_len": args.distill_seq_len,
|
| 2107 |
+
"batch_size": args.distill_batch_size,
|
| 2108 |
+
"epochs": args.lora_epochs,
|
| 2109 |
+
"rank": args.lora_rank,
|
| 2110 |
+
"alpha": args.lora_alpha,
|
| 2111 |
+
"dropout": args.lora_dropout,
|
| 2112 |
+
"target_modules": args.lora_target_modules,
|
| 2113 |
+
"respect_exclude_pairs": args.lora_respect_exclude_pairs,
|
| 2114 |
+
"kl_enabled": args.lora_kl_enabled,
|
| 2115 |
+
"kl_weight": args.lora_kl_weight,
|
| 2116 |
+
"kl_temp": args.lora_kl_temp,
|
| 2117 |
+
"post_ppl": lora_post_eval,
|
| 2118 |
+
"lr": args.lora_lr,
|
| 2119 |
+
"weight_decay": args.lora_weight_decay,
|
| 2120 |
+
"max_grad_norm": args.lora_max_grad_norm,
|
| 2121 |
+
"grad_accum_steps": args.lora_grad_accum_steps,
|
| 2122 |
+
"log_steps": args.lora_log_steps,
|
| 2123 |
+
"eval_every": args.lora_eval_every,
|
| 2124 |
+
"eval_max_batches": args.lora_eval_max_batches,
|
| 2125 |
+
},
|
| 2126 |
+
"norm_policy": args.norm_policy,
|
| 2127 |
+
}
|
| 2128 |
+
|
| 2129 |
+
saved_full_model_dir = None
|
| 2130 |
+
if cycle_idx in args.full_model_save_cycles:
|
| 2131 |
+
cycle_meta["full_model_saved"] = True
|
| 2132 |
+
cycle_meta["full_model"] = save_cycle_full_model(
|
| 2133 |
+
model=model,
|
| 2134 |
+
tokenizer=tokenizer,
|
| 2135 |
+
cycle_dir=cycle_dir,
|
| 2136 |
+
cycle_idx=cycle_idx,
|
| 2137 |
+
args=args,
|
| 2138 |
+
ppl_results=post_eval,
|
| 2139 |
+
)
|
| 2140 |
+
saved_full_model_dir = os.path.join(cycle_dir, "full_model")
|
| 2141 |
+
else:
|
| 2142 |
+
cycle_meta["full_model_saved"] = False
|
| 2143 |
+
|
| 2144 |
+
with open(
|
| 2145 |
+
os.path.join(cycle_dir, "cycle_metadata.json"),
|
| 2146 |
+
"w",
|
| 2147 |
+
encoding="utf-8",
|
| 2148 |
+
) as handle:
|
| 2149 |
+
json.dump(cycle_meta, handle, indent=2)
|
| 2150 |
+
|
| 2151 |
+
cycle_summaries.append(
|
| 2152 |
+
{
|
| 2153 |
+
"cycle": cycle_idx,
|
| 2154 |
+
"layer_merged": layer_idx,
|
| 2155 |
+
"dwce_score": dwce_scores[layer_idx] if dwce_scores else None,
|
| 2156 |
+
"comm_post_ppl": comm_post_eval,
|
| 2157 |
+
"distill_post_ppl": distill_post_eval,
|
| 2158 |
+
"lora_post_ppl": lora_post_eval,
|
| 2159 |
+
"post_ppl": post_eval,
|
| 2160 |
+
"cycle_dir": f"cycle_{cycle_idx}",
|
| 2161 |
+
}
|
| 2162 |
+
)
|
| 2163 |
+
|
| 2164 |
+
last_fused_idx = layer_idx
|
| 2165 |
+
_snapshot_previous_cycle_teacher(cycle_idx)
|
| 2166 |
+
|
| 2167 |
+
parent, name, container = find_layer_container(model, args.layer_path)
|
| 2168 |
+
layers = list(container)
|
| 2169 |
+
if dwce_scores:
|
| 2170 |
+
dwce_scores = dwce_scores[: max(len(layers) - 1, 0)]
|
| 2171 |
+
|
| 2172 |
+
# Encourage allocator to release cached blocks between cycles.
|
| 2173 |
+
if torch.cuda.is_available():
|
| 2174 |
+
torch.cuda.empty_cache()
|
| 2175 |
+
torch.cuda.ipc_collect()
|
| 2176 |
+
gc.collect()
|
| 2177 |
+
|
| 2178 |
+
if saved_full_model_dir is not None:
|
| 2179 |
+
save_rng_state(os.path.join(saved_full_model_dir, "rng_state.pt"))
|
| 2180 |
+
save_loader_generator_state(
|
| 2181 |
+
saved_full_model_dir,
|
| 2182 |
+
distill_generator=prepared.distill_generator,
|
| 2183 |
+
lora_generator=prepared.lora_generator,
|
| 2184 |
+
)
|
| 2185 |
+
|
| 2186 |
+
_release_comm_teacher()
|
| 2187 |
+
_release_previous_cycle_teacher()
|
| 2188 |
+
|
| 2189 |
+
final_pre_lora_eval = cycle_summaries[-1]["post_ppl"] if cycle_summaries else None
|
| 2190 |
+
final_pre_lora_dir = f"{os.path.abspath(args.output_dir.rstrip(os.sep))}_final_pre_lora_hf"
|
| 2191 |
+
final_pre_lora_meta = save_stage_checkpoint(
|
| 2192 |
+
model=model,
|
| 2193 |
+
tokenizer=tokenizer,
|
| 2194 |
+
stage_dir=final_pre_lora_dir,
|
| 2195 |
+
stage_name="final_pre_lora",
|
| 2196 |
+
ppl_results=final_pre_lora_eval,
|
| 2197 |
+
)
|
| 2198 |
+
|
| 2199 |
+
# Optional final LoRA finetune after all pruning cycles.
|
| 2200 |
+
lora_eval_history: List[Dict[str, object]] = []
|
| 2201 |
+
lora_post_eval = None
|
| 2202 |
+
lora_ready = has_post_fusion_data(prepared.lora_loader, prepared.lora_meta)
|
| 2203 |
+
if args.lora_epochs > 0:
|
| 2204 |
+
if not lora_ready:
|
| 2205 |
+
print("No post-fusion sequences built; skipping LoRA finetuning.")
|
| 2206 |
+
else:
|
| 2207 |
+
print(
|
| 2208 |
+
f"[progressive] Running final LoRA finetuning (epochs={args.lora_epochs})."
|
| 2209 |
+
)
|
| 2210 |
+
lora_eval_history = run_lora_phase(
|
| 2211 |
+
model=model,
|
| 2212 |
+
tokenizer=tokenizer,
|
| 2213 |
+
eval_datasets=eval_datasets,
|
| 2214 |
+
eval_configs=eval_configs,
|
| 2215 |
+
args=args,
|
| 2216 |
+
lora_loader=prepared.lora_loader,
|
| 2217 |
+
lora_meta=prepared.lora_meta,
|
| 2218 |
+
eval_dataloaders=prepared.eval_dataloaders,
|
| 2219 |
+
cycle_idx=args.num_progressive,
|
| 2220 |
+
num_cycles=args.num_progressive,
|
| 2221 |
+
)
|
| 2222 |
+
|
| 2223 |
+
if not args.skip_eval:
|
| 2224 |
+
lora_post_eval = evaluate_ppl_model(
|
| 2225 |
+
model,
|
| 2226 |
+
tokenizer,
|
| 2227 |
+
eval_datasets,
|
| 2228 |
+
eval_configs,
|
| 2229 |
+
args,
|
| 2230 |
+
prepared_eval_dataloaders=prepared.eval_dataloaders,
|
| 2231 |
+
)
|
| 2232 |
+
print("[progressive] Post-LoRA perplexity:")
|
| 2233 |
+
for dataset_name, ppl in lora_post_eval.items():
|
| 2234 |
+
print(f"{dataset_name}: {ppl:.4f}")
|
| 2235 |
+
|
| 2236 |
+
# Update final cycle metadata and summary with the post-LoRA PPL.
|
| 2237 |
+
if cycle_summaries:
|
| 2238 |
+
cycle_summaries[-1]["lora_post_ppl"] = lora_post_eval
|
| 2239 |
+
if lora_post_eval is not None:
|
| 2240 |
+
cycle_summaries[-1]["post_ppl"] = lora_post_eval
|
| 2241 |
+
|
| 2242 |
+
final_cycle_dir = os.path.join(
|
| 2243 |
+
args.output_dir, f"cycle_{args.num_progressive}"
|
| 2244 |
+
)
|
| 2245 |
+
final_cycle_meta_path = os.path.join(final_cycle_dir, "cycle_metadata.json")
|
| 2246 |
+
if os.path.exists(final_cycle_meta_path):
|
| 2247 |
+
with open(final_cycle_meta_path, "r", encoding="utf-8") as handle:
|
| 2248 |
+
final_cycle_meta = json.load(handle)
|
| 2249 |
+
|
| 2250 |
+
lora_meta_entry = final_cycle_meta.get("lora")
|
| 2251 |
+
if not isinstance(lora_meta_entry, dict):
|
| 2252 |
+
lora_meta_entry = {}
|
| 2253 |
+
final_cycle_meta["lora"] = lora_meta_entry
|
| 2254 |
+
lora_meta_entry["ran"] = True
|
| 2255 |
+
lora_meta_entry["post_ppl"] = lora_post_eval
|
| 2256 |
+
if lora_post_eval is not None and isinstance(
|
| 2257 |
+
final_cycle_meta.get("eval"), dict
|
| 2258 |
+
):
|
| 2259 |
+
final_cycle_meta["eval"]["post_ppl"] = lora_post_eval
|
| 2260 |
+
|
| 2261 |
+
if lora_eval_history:
|
| 2262 |
+
lora_path = os.path.join(final_cycle_dir, "ppl_over_lora.json")
|
| 2263 |
+
with open(lora_path, "w", encoding="utf-8") as handle:
|
| 2264 |
+
json.dump(lora_eval_history, handle, indent=2)
|
| 2265 |
+
lora_meta_entry["ppl_over_lora"] = "ppl_over_lora.json"
|
| 2266 |
+
|
| 2267 |
+
with open(final_cycle_meta_path, "w", encoding="utf-8") as handle:
|
| 2268 |
+
json.dump(final_cycle_meta, handle, indent=2)
|
| 2269 |
+
|
| 2270 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 2271 |
+
final_post_lora_meta = save_stage_checkpoint(
|
| 2272 |
+
model=model,
|
| 2273 |
+
tokenizer=tokenizer,
|
| 2274 |
+
stage_dir=args.output_dir,
|
| 2275 |
+
stage_name="final_post_lora" if lora_post_eval is not None else "final_model",
|
| 2276 |
+
ppl_results=lora_post_eval,
|
| 2277 |
+
)
|
| 2278 |
+
|
| 2279 |
+
progressive_meta = {
|
| 2280 |
+
"base_model": getattr(args, "base_model_id", args.model),
|
| 2281 |
+
"num_progressive": args.num_progressive,
|
| 2282 |
+
"layer_path": args.layer_path,
|
| 2283 |
+
"resume_from_cycle": args.resume_from_cycle,
|
| 2284 |
+
"save_full_model_cycles": sorted(args.full_model_save_cycles),
|
| 2285 |
+
"num_sequences": num_sequences,
|
| 2286 |
+
"seq_len": args.seq_len,
|
| 2287 |
+
"lora": {
|
| 2288 |
+
"enabled": args.lora_epochs > 0,
|
| 2289 |
+
"ran": args.lora_epochs > 0 and lora_ready,
|
| 2290 |
+
"seq_len": args.distill_seq_len,
|
| 2291 |
+
"batch_size": args.distill_batch_size,
|
| 2292 |
+
"epochs": args.lora_epochs,
|
| 2293 |
+
"rank": args.lora_rank,
|
| 2294 |
+
"alpha": args.lora_alpha,
|
| 2295 |
+
"dropout": args.lora_dropout,
|
| 2296 |
+
"target_modules": args.lora_target_modules,
|
| 2297 |
+
"respect_exclude_pairs": args.lora_respect_exclude_pairs,
|
| 2298 |
+
"kl_enabled": args.lora_kl_enabled,
|
| 2299 |
+
"kl_weight": args.lora_kl_weight,
|
| 2300 |
+
"kl_temp": args.lora_kl_temp,
|
| 2301 |
+
"post_ppl": lora_post_eval,
|
| 2302 |
+
"ppl_over_lora": (
|
| 2303 |
+
f"cycle_{args.num_progressive}/ppl_over_lora.json"
|
| 2304 |
+
if lora_eval_history
|
| 2305 |
+
else None
|
| 2306 |
+
),
|
| 2307 |
+
"lr": args.lora_lr,
|
| 2308 |
+
"weight_decay": args.lora_weight_decay,
|
| 2309 |
+
"max_grad_norm": args.lora_max_grad_norm,
|
| 2310 |
+
"grad_accum_steps": args.lora_grad_accum_steps,
|
| 2311 |
+
"log_steps": args.lora_log_steps,
|
| 2312 |
+
"eval_every": args.lora_eval_every,
|
| 2313 |
+
"eval_max_batches": args.lora_eval_max_batches,
|
| 2314 |
+
},
|
| 2315 |
+
"artifacts": {
|
| 2316 |
+
"final_pre_lora": final_pre_lora_meta,
|
| 2317 |
+
"final_post_lora": final_post_lora_meta,
|
| 2318 |
+
},
|
| 2319 |
+
"eval": {
|
| 2320 |
+
"datasets": eval_datasets,
|
| 2321 |
+
"configs": eval_configs,
|
| 2322 |
+
"split": args.eval_split,
|
| 2323 |
+
"num_samples": args.eval_num_samples,
|
| 2324 |
+
"seq_len": args.eval_seq_len,
|
| 2325 |
+
"pre_ppl": pre_eval,
|
| 2326 |
+
"post_ppl": cycle_summaries[-1]["post_ppl"] if cycle_summaries else None,
|
| 2327 |
+
},
|
| 2328 |
+
"cycles": cycle_summaries,
|
| 2329 |
+
"final_num_layers": len(layers),
|
| 2330 |
+
}
|
| 2331 |
+
|
| 2332 |
+
with open(
|
| 2333 |
+
os.path.join(args.output_dir, "progressive_metadata.json"),
|
| 2334 |
+
"w",
|
| 2335 |
+
encoding="utf-8",
|
| 2336 |
+
) as handle:
|
| 2337 |
+
json.dump(progressive_meta, handle, indent=2)
|
| 2338 |
+
|
| 2339 |
+
print(
|
| 2340 |
+
f"[progressive] Completed {args.num_progressive} cycles. "
|
| 2341 |
+
f"Final model saved to {args.output_dir}."
|
| 2342 |
+
)
|
| 2343 |
+
|
| 2344 |
+
|
| 2345 |
+
def main() -> None:
|
| 2346 |
+
args = parse_args()
|
| 2347 |
+
if args.num_progressive <= 0:
|
| 2348 |
+
raise SystemExit(
|
| 2349 |
+
"Single-cycle mode has been removed. Pass --num_progressive > 0."
|
| 2350 |
+
)
|
| 2351 |
+
if args.resume_from_cycle < 0:
|
| 2352 |
+
raise SystemExit("--resume_from_cycle must be >= 0.")
|
| 2353 |
+
if args.resume_from_cycle >= args.num_progressive:
|
| 2354 |
+
raise SystemExit("--resume_from_cycle must be smaller than --num_progressive.")
|
| 2355 |
+
args.full_model_save_cycles = resolve_full_model_save_cycles(
|
| 2356 |
+
parse_cycle_list(args.save_full_model_cycles),
|
| 2357 |
+
args.num_progressive,
|
| 2358 |
+
)
|
| 2359 |
+
args.base_model_id = args.model
|
| 2360 |
+
if args.resume_from_cycle > 0:
|
| 2361 |
+
resume_meta = load_resume_metadata(args.model)
|
| 2362 |
+
if resume_meta is None:
|
| 2363 |
+
raise SystemExit(
|
| 2364 |
+
"--resume_from_cycle requires --model to point to a saved cycle full model "
|
| 2365 |
+
"directory containing resume_info.json."
|
| 2366 |
+
)
|
| 2367 |
+
resume_cycle = resume_meta.get("cycle")
|
| 2368 |
+
if resume_cycle is not None and int(resume_cycle) != args.resume_from_cycle:
|
| 2369 |
+
raise SystemExit(
|
| 2370 |
+
"resume_info.json cycle does not match --resume_from_cycle."
|
| 2371 |
+
)
|
| 2372 |
+
base_model = resume_meta.get("base_model")
|
| 2373 |
+
if isinstance(base_model, str) and base_model:
|
| 2374 |
+
args.base_model_id = base_model
|
| 2375 |
+
configure_reproducibility(args.seed)
|
| 2376 |
+
|
| 2377 |
+
eval_datasets, eval_configs = resolve_eval_datasets(args)
|
| 2378 |
+
dtype = get_dtype(args.dtype)
|
| 2379 |
+
model = load_causal_lm(
|
| 2380 |
+
args.model,
|
| 2381 |
+
torch_dtype=dtype,
|
| 2382 |
+
trust_remote_code=args.trust_remote_code,
|
| 2383 |
+
cache_dir=args.model_cache_dir,
|
| 2384 |
+
)
|
| 2385 |
+
loader_generator_state = None
|
| 2386 |
+
if args.resume_from_cycle > 0:
|
| 2387 |
+
rng_state_path = os.path.join(args.model, "rng_state.pt")
|
| 2388 |
+
rng_state = load_rng_state(rng_state_path)
|
| 2389 |
+
if rng_state is not None:
|
| 2390 |
+
restore_rng_state(rng_state)
|
| 2391 |
+
loader_generator_state = load_loader_generator_state(args.model)
|
| 2392 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 2393 |
+
args.model,
|
| 2394 |
+
trust_remote_code=args.trust_remote_code,
|
| 2395 |
+
cache_dir=args.model_cache_dir,
|
| 2396 |
+
)
|
| 2397 |
+
print(model)
|
| 2398 |
+
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
|
| 2399 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 2400 |
+
|
| 2401 |
+
# for llama?
|
| 2402 |
+
model.config.use_cache = False
|
| 2403 |
+
|
| 2404 |
+
prepared = prepare_all_data(
|
| 2405 |
+
args,
|
| 2406 |
+
tokenizer,
|
| 2407 |
+
model,
|
| 2408 |
+
eval_datasets,
|
| 2409 |
+
eval_configs,
|
| 2410 |
+
loader_generator_state=loader_generator_state,
|
| 2411 |
+
)
|
| 2412 |
+
run_progressive(args, model, tokenizer, prepared)
|
| 2413 |
+
|
| 2414 |
+
|
| 2415 |
+
if __name__ == "__main__":
|
| 2416 |
+
main()
|
src/fuse_layers_data.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Dataset and text helpers for fuse_layers."""
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
from typing import Dict, List, Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
from datasets import load_dataset
|
| 11 |
+
except Exception: # pragma: no cover - optional dependency
|
| 12 |
+
load_dataset = None
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def guess_text_field(dataset) -> str:
|
| 16 |
+
if hasattr(dataset, "column_names") and dataset.column_names:
|
| 17 |
+
if "text" in dataset.column_names:
|
| 18 |
+
return "text"
|
| 19 |
+
return dataset.column_names[0]
|
| 20 |
+
if hasattr(dataset, "features"):
|
| 21 |
+
names = list(dataset.features.keys())
|
| 22 |
+
if "text" in names:
|
| 23 |
+
return "text"
|
| 24 |
+
if names:
|
| 25 |
+
return names[0]
|
| 26 |
+
return "text"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _normalize_config(config: Optional[str]) -> Optional[str]:
|
| 30 |
+
if config is None:
|
| 31 |
+
return None
|
| 32 |
+
if config.strip().lower() in {"none", "null", "-"}:
|
| 33 |
+
return None
|
| 34 |
+
return config
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def expand_dataset_configs(
|
| 38 |
+
datasets: List[str], configs: List[str]
|
| 39 |
+
) -> List[Optional[str]]:
|
| 40 |
+
if not configs:
|
| 41 |
+
return [None] * len(datasets)
|
| 42 |
+
if len(configs) == 1 and len(datasets) > 1:
|
| 43 |
+
return [_normalize_config(configs[0])] * len(datasets)
|
| 44 |
+
if len(configs) != len(datasets):
|
| 45 |
+
raise SystemExit(
|
| 46 |
+
"Provide zero, one, or matching-count --dataset_config values."
|
| 47 |
+
)
|
| 48 |
+
return [_normalize_config(cfg) for cfg in configs]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _sample_dataset_rows(
|
| 52 |
+
dataset, target: int, seed: int
|
| 53 |
+
) -> List[Dict[str, object]]:
|
| 54 |
+
if target <= 0:
|
| 55 |
+
return []
|
| 56 |
+
try:
|
| 57 |
+
dataset = dataset.shuffle(seed=seed)
|
| 58 |
+
except Exception:
|
| 59 |
+
pass
|
| 60 |
+
|
| 61 |
+
if hasattr(dataset, "__len__"):
|
| 62 |
+
limit = min(target, len(dataset))
|
| 63 |
+
dataset = dataset.select(range(limit))
|
| 64 |
+
return [row for row in dataset]
|
| 65 |
+
|
| 66 |
+
rows = []
|
| 67 |
+
for row in dataset:
|
| 68 |
+
rows.append(row)
|
| 69 |
+
if len(rows) >= target:
|
| 70 |
+
break
|
| 71 |
+
return rows
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def load_texts(args: argparse.Namespace) -> List[str]:
|
| 75 |
+
texts: List[str] = []
|
| 76 |
+
if args.text_file:
|
| 77 |
+
with open(args.text_file, "r", encoding="utf-8") as handle:
|
| 78 |
+
texts.extend([line.strip() for line in handle if line.strip()])
|
| 79 |
+
if args.text:
|
| 80 |
+
texts.extend([t for t in args.text if t])
|
| 81 |
+
|
| 82 |
+
if args.dataset:
|
| 83 |
+
if load_dataset is None:
|
| 84 |
+
raise SystemExit("datasets is required for --dataset")
|
| 85 |
+
|
| 86 |
+
datasets = list(args.dataset)
|
| 87 |
+
configs = expand_dataset_configs(datasets, list(args.dataset_config))
|
| 88 |
+
num_datasets = len(datasets)
|
| 89 |
+
base = args.num_samples // num_datasets
|
| 90 |
+
remainder = args.num_samples % num_datasets
|
| 91 |
+
|
| 92 |
+
for idx, (dataset_name, config) in enumerate(zip(datasets, configs)):
|
| 93 |
+
target = base + (1 if idx < remainder else 0)
|
| 94 |
+
dataset = load_dataset(
|
| 95 |
+
dataset_name,
|
| 96 |
+
config,
|
| 97 |
+
split=args.dataset_split,
|
| 98 |
+
trust_remote_code=True,
|
| 99 |
+
)
|
| 100 |
+
rows = _sample_dataset_rows(dataset, target, args.seed + idx)
|
| 101 |
+
text_field = args.dataset_text_field or guess_text_field(dataset)
|
| 102 |
+
for row in rows:
|
| 103 |
+
value = row.get(text_field, None) if isinstance(row, dict) else None
|
| 104 |
+
if isinstance(value, str) and value.strip():
|
| 105 |
+
texts.append(value)
|
| 106 |
+
|
| 107 |
+
return texts
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def load_texts_from_datasets(
|
| 111 |
+
datasets: List[str],
|
| 112 |
+
configs: List[Optional[str]],
|
| 113 |
+
split: str,
|
| 114 |
+
text_field: Optional[str],
|
| 115 |
+
num_samples: int,
|
| 116 |
+
seed: int,
|
| 117 |
+
) -> List[str]:
|
| 118 |
+
if not datasets:
|
| 119 |
+
return []
|
| 120 |
+
if load_dataset is None:
|
| 121 |
+
raise SystemExit("datasets is required for --dataset")
|
| 122 |
+
|
| 123 |
+
texts: List[str] = []
|
| 124 |
+
num_datasets = len(datasets)
|
| 125 |
+
base = num_samples // num_datasets
|
| 126 |
+
remainder = num_samples % num_datasets
|
| 127 |
+
|
| 128 |
+
for idx, (dataset_name, config) in enumerate(zip(datasets, configs)):
|
| 129 |
+
target = base + (1 if idx < remainder else 0)
|
| 130 |
+
dataset = load_dataset(
|
| 131 |
+
dataset_name,
|
| 132 |
+
config,
|
| 133 |
+
split=split,
|
| 134 |
+
trust_remote_code=True,
|
| 135 |
+
)
|
| 136 |
+
rows = _sample_dataset_rows(dataset, target, seed + idx)
|
| 137 |
+
field = text_field or guess_text_field(dataset)
|
| 138 |
+
for row in rows:
|
| 139 |
+
value = row.get(field, None) if isinstance(row, dict) else None
|
| 140 |
+
if isinstance(value, str) and value.strip():
|
| 141 |
+
texts.append(value)
|
| 142 |
+
return texts
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def format_alpaca_example(instruction: str, inp: str, output: str) -> str:
|
| 146 |
+
if inp:
|
| 147 |
+
return (
|
| 148 |
+
"### Instruction:\n"
|
| 149 |
+
f"{instruction}\n\n"
|
| 150 |
+
"### Input:\n"
|
| 151 |
+
f"{inp}\n\n"
|
| 152 |
+
"### Response:\n"
|
| 153 |
+
f"{output}"
|
| 154 |
+
)
|
| 155 |
+
return (
|
| 156 |
+
"### Instruction:\n"
|
| 157 |
+
f"{instruction}\n\n"
|
| 158 |
+
"### Response:\n"
|
| 159 |
+
f"{output}"
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def build_alpaca_messages(
|
| 164 |
+
instruction: str, inp: str, output: str
|
| 165 |
+
) -> List[Dict[str, str]]:
|
| 166 |
+
if inp:
|
| 167 |
+
user_content = f"{instruction}\n\nInput:\n{inp}"
|
| 168 |
+
else:
|
| 169 |
+
user_content = instruction
|
| 170 |
+
return [
|
| 171 |
+
{"role": "user", "content": user_content},
|
| 172 |
+
{"role": "assistant", "content": output},
|
| 173 |
+
]
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class FixedSeqDataset(torch.utils.data.Dataset):
|
| 177 |
+
def __init__(self, records: List[Dict[str, object]], tokenizer, seq_len: int) -> None:
|
| 178 |
+
self.records = records
|
| 179 |
+
self.tokenizer = tokenizer
|
| 180 |
+
self.seq_len = seq_len
|
| 181 |
+
self.pad_id = tokenizer.pad_token_id
|
| 182 |
+
if self.pad_id is None:
|
| 183 |
+
self.pad_id = tokenizer.eos_token_id or 0
|
| 184 |
+
|
| 185 |
+
def __len__(self) -> int:
|
| 186 |
+
return len(self.records)
|
| 187 |
+
|
| 188 |
+
def __getitem__(self, idx: int):
|
| 189 |
+
record = self.records[idx]
|
| 190 |
+
chat_template = getattr(self.tokenizer, "chat_template", None)
|
| 191 |
+
if (
|
| 192 |
+
"messages" in record
|
| 193 |
+
and hasattr(self.tokenizer, "apply_chat_template")
|
| 194 |
+
and chat_template
|
| 195 |
+
):
|
| 196 |
+
ids = self.tokenizer.apply_chat_template(
|
| 197 |
+
record["messages"],
|
| 198 |
+
tokenize=True,
|
| 199 |
+
add_generation_prompt=False,
|
| 200 |
+
)
|
| 201 |
+
else:
|
| 202 |
+
text = record.get("text", "")
|
| 203 |
+
ids = self.tokenizer.encode(text, add_special_tokens=False)
|
| 204 |
+
|
| 205 |
+
# Transformers may return a BatchEncoding here instead of a plain list.
|
| 206 |
+
if hasattr(ids, "input_ids"):
|
| 207 |
+
ids = ids.input_ids
|
| 208 |
+
if isinstance(ids, torch.Tensor):
|
| 209 |
+
ids = ids.tolist()
|
| 210 |
+
elif not isinstance(ids, list):
|
| 211 |
+
ids = list(ids)
|
| 212 |
+
|
| 213 |
+
if len(ids) > self.seq_len:
|
| 214 |
+
ids = ids[: self.seq_len]
|
| 215 |
+
attn = [1] * len(ids)
|
| 216 |
+
if len(ids) < self.seq_len:
|
| 217 |
+
pad_len = self.seq_len - len(ids)
|
| 218 |
+
ids = ids + [self.pad_id] * pad_len
|
| 219 |
+
attn = attn + [0] * pad_len
|
| 220 |
+
|
| 221 |
+
return (
|
| 222 |
+
torch.tensor(ids, dtype=torch.long),
|
| 223 |
+
torch.tensor(attn, dtype=torch.long),
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def load_instruction_records(
|
| 228 |
+
args: argparse.Namespace, num_samples: int
|
| 229 |
+
) -> List[Dict[str, object]]:
|
| 230 |
+
if not args.instruction_dataset:
|
| 231 |
+
return []
|
| 232 |
+
if load_dataset is None:
|
| 233 |
+
raise SystemExit("datasets is required for instruction dataset")
|
| 234 |
+
|
| 235 |
+
dataset = load_dataset(
|
| 236 |
+
args.instruction_dataset,
|
| 237 |
+
_normalize_config(args.instruction_config),
|
| 238 |
+
split=args.instruction_split,
|
| 239 |
+
trust_remote_code=True,
|
| 240 |
+
)
|
| 241 |
+
if num_samples > 0:
|
| 242 |
+
rows = _sample_dataset_rows(dataset, num_samples, args.seed)
|
| 243 |
+
else:
|
| 244 |
+
rows = dataset
|
| 245 |
+
records: List[Dict[str, object]] = []
|
| 246 |
+
for row in rows:
|
| 247 |
+
if not isinstance(row, dict):
|
| 248 |
+
continue
|
| 249 |
+
instruction = str(row.get(args.instruction_field_instruction, "")).strip()
|
| 250 |
+
inp = str(row.get(args.instruction_field_input, "")).strip()
|
| 251 |
+
output = str(row.get(args.instruction_field_output, "")).strip()
|
| 252 |
+
if not instruction or not output:
|
| 253 |
+
continue
|
| 254 |
+
records.append(
|
| 255 |
+
{
|
| 256 |
+
"messages": build_alpaca_messages(instruction, inp, output),
|
| 257 |
+
"text": format_alpaca_example(instruction, inp, output),
|
| 258 |
+
}
|
| 259 |
+
)
|
| 260 |
+
return records
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def build_token_chunks(
|
| 264 |
+
texts: List[str], tokenizer, seq_len: int, num_samples: int
|
| 265 |
+
) -> List[torch.Tensor]:
|
| 266 |
+
chunks: List[torch.Tensor] = []
|
| 267 |
+
buffer: List[int] = []
|
| 268 |
+
limit = None if num_samples <= 0 else num_samples
|
| 269 |
+
for text in texts:
|
| 270 |
+
ids = tokenizer.encode(text, add_special_tokens=False)
|
| 271 |
+
if not ids:
|
| 272 |
+
continue
|
| 273 |
+
buffer.extend(ids)
|
| 274 |
+
while len(buffer) >= seq_len and (limit is None or len(chunks) < limit):
|
| 275 |
+
chunk = buffer[:seq_len]
|
| 276 |
+
buffer = buffer[seq_len:]
|
| 277 |
+
chunks.append(torch.tensor(chunk, dtype=torch.long))
|
| 278 |
+
if limit is not None and len(chunks) >= limit:
|
| 279 |
+
break
|
| 280 |
+
return chunks
|
src/fuse_layers_distill.py
ADDED
|
@@ -0,0 +1,2018 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Distillation helpers for fuse_layers."""
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import itertools
|
| 6 |
+
import math
|
| 7 |
+
import os
|
| 8 |
+
from contextlib import contextmanager, nullcontext
|
| 9 |
+
from typing import Dict, List, Optional, Set, Tuple
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
import ppl_eval
|
| 16 |
+
except Exception as exc: # pragma: no cover - optional dependency
|
| 17 |
+
raise SystemExit("ppl_eval.py is required (missing or invalid)") from exc
|
| 18 |
+
try:
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
except Exception: # pragma: no cover - optional dependency
|
| 21 |
+
tqdm = None
|
| 22 |
+
|
| 23 |
+
try:
|
| 24 |
+
from torch.func import functional_call as _functional_call
|
| 25 |
+
except Exception: # pragma: no cover - depends on torch version
|
| 26 |
+
try:
|
| 27 |
+
from torch.nn.utils.stateless import functional_call as _functional_call
|
| 28 |
+
except Exception: # pragma: no cover - depends on torch version
|
| 29 |
+
_functional_call = None
|
| 30 |
+
|
| 31 |
+
from fuse_layers_model import find_attention_module, find_mlp_module
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _tqdm_enabled() -> bool:
|
| 35 |
+
value = os.environ.get("DISABLE_TQDM", os.environ.get("TQDM_DISABLE", "0"))
|
| 36 |
+
return value.strip().lower() not in {"1", "true", "yes", "on"}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@contextmanager
|
| 40 |
+
def temporary_layers(parent: object, name: str, new_layers: torch.nn.Module):
|
| 41 |
+
original = getattr(parent, name)
|
| 42 |
+
setattr(parent, name, new_layers)
|
| 43 |
+
try:
|
| 44 |
+
yield
|
| 45 |
+
finally:
|
| 46 |
+
setattr(parent, name, original)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@contextmanager
|
| 50 |
+
def temporary_norm(parent: object):
|
| 51 |
+
if hasattr(parent, "norm"):
|
| 52 |
+
original = getattr(parent, "norm")
|
| 53 |
+
setattr(parent, "norm", torch.nn.Identity())
|
| 54 |
+
try:
|
| 55 |
+
yield
|
| 56 |
+
finally:
|
| 57 |
+
setattr(parent, "norm", original)
|
| 58 |
+
else:
|
| 59 |
+
yield
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def forward_truncated(
|
| 63 |
+
parent: torch.nn.Module,
|
| 64 |
+
layer_attr: str,
|
| 65 |
+
layers: List[torch.nn.Module],
|
| 66 |
+
upto: int,
|
| 67 |
+
input_ids: torch.Tensor,
|
| 68 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 69 |
+
) -> torch.Tensor:
|
| 70 |
+
truncated = torch.nn.ModuleList(layers[:upto])
|
| 71 |
+
with temporary_layers(parent, layer_attr, truncated), temporary_norm(parent):
|
| 72 |
+
outputs = parent(
|
| 73 |
+
input_ids=input_ids,
|
| 74 |
+
attention_mask=attention_mask,
|
| 75 |
+
use_cache=False,
|
| 76 |
+
)
|
| 77 |
+
if hasattr(outputs, "last_hidden_state"):
|
| 78 |
+
return outputs.last_hidden_state
|
| 79 |
+
return outputs[0]
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _masked_hidden_mse(diff: torch.Tensor, attention_mask: torch.Tensor) -> Optional[torch.Tensor]:
|
| 83 |
+
diff_f = diff.float()
|
| 84 |
+
mask = attention_mask.to(device=diff.device, dtype=torch.float32)
|
| 85 |
+
denom = mask.sum() * diff_f.size(-1)
|
| 86 |
+
if denom.item() == 0:
|
| 87 |
+
return None
|
| 88 |
+
return (diff_f.pow(2) * mask.unsqueeze(-1)).sum() / denom
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _extract_hidden_like(output) -> Optional[torch.Tensor]:
|
| 92 |
+
if torch.is_tensor(output):
|
| 93 |
+
return output
|
| 94 |
+
if isinstance(output, (tuple, list)) and output:
|
| 95 |
+
first = output[0]
|
| 96 |
+
if torch.is_tensor(first):
|
| 97 |
+
return first
|
| 98 |
+
if hasattr(output, "last_hidden_state"):
|
| 99 |
+
hidden = getattr(output, "last_hidden_state")
|
| 100 |
+
if torch.is_tensor(hidden):
|
| 101 |
+
return hidden
|
| 102 |
+
return None
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@contextmanager
|
| 106 |
+
def capture_module_output(module: torch.nn.Module):
|
| 107 |
+
cache: Dict[str, Optional[torch.Tensor]] = {"output": None}
|
| 108 |
+
|
| 109 |
+
def hook(_module, _inputs, output):
|
| 110 |
+
cache["output"] = _extract_hidden_like(output)
|
| 111 |
+
|
| 112 |
+
handle = module.register_forward_hook(hook)
|
| 113 |
+
try:
|
| 114 |
+
yield cache
|
| 115 |
+
finally:
|
| 116 |
+
handle.remove()
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
_ATTN_NAME_FRAGMENTS = (
|
| 120 |
+
"self_attn.",
|
| 121 |
+
"attn.",
|
| 122 |
+
"attention.",
|
| 123 |
+
"q_proj",
|
| 124 |
+
"k_proj",
|
| 125 |
+
"v_proj",
|
| 126 |
+
"o_proj",
|
| 127 |
+
"q_norm",
|
| 128 |
+
"k_norm",
|
| 129 |
+
)
|
| 130 |
+
_MLP_NAME_FRAGMENTS = (
|
| 131 |
+
"mlp.",
|
| 132 |
+
"ffn.",
|
| 133 |
+
"feed_forward",
|
| 134 |
+
"feedforward",
|
| 135 |
+
"gate_proj",
|
| 136 |
+
"up_proj",
|
| 137 |
+
"down_proj",
|
| 138 |
+
"fc1",
|
| 139 |
+
"fc2",
|
| 140 |
+
"dense_h_to_4h",
|
| 141 |
+
"dense_4h_to_h",
|
| 142 |
+
"w1",
|
| 143 |
+
"w2",
|
| 144 |
+
"w3",
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def _classify_param_family(name: str) -> str:
|
| 149 |
+
lowered = name.lower()
|
| 150 |
+
if any(fragment in lowered for fragment in _MLP_NAME_FRAGMENTS):
|
| 151 |
+
return "mlp"
|
| 152 |
+
if any(fragment in lowered for fragment in _ATTN_NAME_FRAGMENTS):
|
| 153 |
+
return "attn"
|
| 154 |
+
return "other"
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def _family_reg_scale(family: str, attn_scale: float, mlp_scale: float) -> float:
|
| 158 |
+
if family == "attn":
|
| 159 |
+
return attn_scale
|
| 160 |
+
if family == "mlp":
|
| 161 |
+
return mlp_scale
|
| 162 |
+
return 1.0
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def _subset_allows_param(name: str, subset: str) -> bool:
|
| 166 |
+
if subset == "all":
|
| 167 |
+
return True
|
| 168 |
+
return _classify_param_family(name) == subset
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def _gate_logit_from_prior(prior: torch.Tensor) -> torch.Tensor:
|
| 172 |
+
# Stable logit: log(p) - log(1 - p).
|
| 173 |
+
return torch.log(prior) - torch.log1p(-prior)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _build_gate_priors(
|
| 177 |
+
layer_a: torch.nn.Module,
|
| 178 |
+
layer_b: torch.nn.Module,
|
| 179 |
+
fisher_a: Dict[str, object],
|
| 180 |
+
fisher_b: Dict[str, object],
|
| 181 |
+
num_batches: int,
|
| 182 |
+
numels_a: Dict[str, int],
|
| 183 |
+
numels_b: Dict[str, int],
|
| 184 |
+
fisher_mode: str,
|
| 185 |
+
eps: float,
|
| 186 |
+
clamp_eps: float,
|
| 187 |
+
) -> Dict[str, torch.Tensor]:
|
| 188 |
+
"""Return lambda priors for parameters that can be merged."""
|
| 189 |
+
priors: Dict[str, torch.Tensor] = {}
|
| 190 |
+
params_b = {name: param for name, param in layer_b.named_parameters()}
|
| 191 |
+
for name, param_a in layer_a.named_parameters():
|
| 192 |
+
param_b = params_b.get(name)
|
| 193 |
+
if param_b is None or param_b.shape != param_a.shape:
|
| 194 |
+
continue
|
| 195 |
+
if fisher_mode == "param":
|
| 196 |
+
fa = fisher_a[name] / max(num_batches, 1)
|
| 197 |
+
fb = fisher_b[name] / max(num_batches, 1)
|
| 198 |
+
denom = fa + fb
|
| 199 |
+
if not isinstance(denom, torch.Tensor):
|
| 200 |
+
denom = torch.tensor(float(denom))
|
| 201 |
+
# If Fisher is uninformative, default to symmetric init.
|
| 202 |
+
prior = torch.where(
|
| 203 |
+
denom > eps,
|
| 204 |
+
fa / (denom + eps),
|
| 205 |
+
torch.full_like(denom, 0.5),
|
| 206 |
+
)
|
| 207 |
+
prior = prior.clamp(clamp_eps, 1.0 - clamp_eps)
|
| 208 |
+
priors[name] = prior
|
| 209 |
+
else:
|
| 210 |
+
fa = fisher_a[name] / (max(num_batches, 1) * numels_a[name])
|
| 211 |
+
fb = fisher_b[name] / (max(num_batches, 1) * numels_b[name])
|
| 212 |
+
denom = fa + fb
|
| 213 |
+
if denom <= eps:
|
| 214 |
+
prior_val = 0.5
|
| 215 |
+
else:
|
| 216 |
+
prior_val = float(fa / (denom + eps))
|
| 217 |
+
prior_val = min(max(prior_val, clamp_eps), 1.0 - clamp_eps)
|
| 218 |
+
priors[name] = torch.tensor(prior_val, dtype=torch.float32)
|
| 219 |
+
return priors
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def compute_fisher_gate_priors(
|
| 223 |
+
layer_a: torch.nn.Module,
|
| 224 |
+
layer_b: torch.nn.Module,
|
| 225 |
+
fisher_a: Dict[str, object],
|
| 226 |
+
fisher_b: Dict[str, object],
|
| 227 |
+
num_batches: int,
|
| 228 |
+
numels_a: Dict[str, int],
|
| 229 |
+
numels_b: Dict[str, int],
|
| 230 |
+
fisher_mode: str,
|
| 231 |
+
eps: float,
|
| 232 |
+
clamp_eps: float = 1e-4,
|
| 233 |
+
) -> Dict[str, torch.Tensor]:
|
| 234 |
+
"""Compute Fisher prior gate lambdas (lambda_prior) for mergeable parameters."""
|
| 235 |
+
return _build_gate_priors(
|
| 236 |
+
layer_a=layer_a,
|
| 237 |
+
layer_b=layer_b,
|
| 238 |
+
fisher_a=fisher_a,
|
| 239 |
+
fisher_b=fisher_b,
|
| 240 |
+
num_batches=num_batches,
|
| 241 |
+
numels_a=numels_a,
|
| 242 |
+
numels_b=numels_b,
|
| 243 |
+
fisher_mode=fisher_mode,
|
| 244 |
+
eps=eps,
|
| 245 |
+
clamp_eps=clamp_eps,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class ReparamMergedLayer(torch.nn.Module):
|
| 250 |
+
"""Virtual layer that merges parameters via W0/U reparameterization.
|
| 251 |
+
|
| 252 |
+
Parameters of layer_a/layer_b are treated as frozen (detached). We train:
|
| 253 |
+
- gate logits s (lambda = sigmoid(s))
|
| 254 |
+
- U (initialized as U0 = (W_a - W_b) / 2)
|
| 255 |
+
|
| 256 |
+
Forward uses:
|
| 257 |
+
W_merge = W0 + (2 * lambda - 1) * U
|
| 258 |
+
where W0 = (W_a + W_b) / 2
|
| 259 |
+
"""
|
| 260 |
+
|
| 261 |
+
def __init__(
|
| 262 |
+
self,
|
| 263 |
+
layer_a: torch.nn.Module,
|
| 264 |
+
layer_b: torch.nn.Module,
|
| 265 |
+
gate_targets: Dict[str, object],
|
| 266 |
+
param_subset: str = "all",
|
| 267 |
+
clamp_eps: float = 1e-4,
|
| 268 |
+
) -> None:
|
| 269 |
+
super().__init__()
|
| 270 |
+
self.layer_a = layer_a
|
| 271 |
+
self.layer_b = layer_b
|
| 272 |
+
self.param_subset = param_subset
|
| 273 |
+
self._name_map: Dict[str, str] = {}
|
| 274 |
+
|
| 275 |
+
self.gates = torch.nn.ParameterDict()
|
| 276 |
+
self.u = torch.nn.ParameterDict()
|
| 277 |
+
|
| 278 |
+
params_b = {name: param for name, param in layer_b.named_parameters()}
|
| 279 |
+
try:
|
| 280 |
+
device = next(layer_a.parameters()).device
|
| 281 |
+
except StopIteration:
|
| 282 |
+
device = torch.device("cpu")
|
| 283 |
+
|
| 284 |
+
for name, param_a in layer_a.named_parameters():
|
| 285 |
+
param_b = params_b.get(name)
|
| 286 |
+
if param_b is None or param_b.shape != param_a.shape:
|
| 287 |
+
continue
|
| 288 |
+
if not _subset_allows_param(name, self.param_subset):
|
| 289 |
+
continue
|
| 290 |
+
|
| 291 |
+
target = gate_targets.get(name)
|
| 292 |
+
if target is None:
|
| 293 |
+
target_t = torch.tensor(0.5, device=device, dtype=torch.float32)
|
| 294 |
+
elif isinstance(target, torch.Tensor):
|
| 295 |
+
target_t = target.detach().to(device=device, dtype=torch.float32)
|
| 296 |
+
else:
|
| 297 |
+
target_t = torch.tensor(float(target), device=device, dtype=torch.float32)
|
| 298 |
+
|
| 299 |
+
target_t = target_t.clamp(clamp_eps, 1.0 - clamp_eps)
|
| 300 |
+
s0 = _gate_logit_from_prior(target_t)
|
| 301 |
+
u0 = 0.5 * (param_a.detach().float() - param_b.detach().float())
|
| 302 |
+
|
| 303 |
+
safe = name.replace(".", "__")
|
| 304 |
+
if safe in self.gates:
|
| 305 |
+
safe = f"{safe}_{len(self.gates)}"
|
| 306 |
+
self._name_map[name] = safe
|
| 307 |
+
self.gates[safe] = torch.nn.Parameter(s0)
|
| 308 |
+
self.u[safe] = torch.nn.Parameter(u0)
|
| 309 |
+
|
| 310 |
+
def __getattr__(self, name: str):
|
| 311 |
+
# Delegate model-specific attributes (e.g. Qwen's `attention_type`) to
|
| 312 |
+
# the underlying layer so the parent forward doesn't break.
|
| 313 |
+
try:
|
| 314 |
+
return super().__getattr__(name)
|
| 315 |
+
except AttributeError as exc:
|
| 316 |
+
try:
|
| 317 |
+
layer_a = super().__getattr__("layer_a")
|
| 318 |
+
if hasattr(layer_a, name):
|
| 319 |
+
return getattr(layer_a, name)
|
| 320 |
+
except AttributeError:
|
| 321 |
+
pass
|
| 322 |
+
try:
|
| 323 |
+
layer_b = super().__getattr__("layer_b")
|
| 324 |
+
if hasattr(layer_b, name):
|
| 325 |
+
return getattr(layer_b, name)
|
| 326 |
+
except AttributeError:
|
| 327 |
+
pass
|
| 328 |
+
raise exc
|
| 329 |
+
|
| 330 |
+
def _safe_for(self, orig: str) -> Optional[str]:
|
| 331 |
+
return self._name_map.get(orig)
|
| 332 |
+
|
| 333 |
+
def gate_lambdas(self) -> Dict[str, torch.Tensor]:
|
| 334 |
+
out: Dict[str, torch.Tensor] = {}
|
| 335 |
+
for orig, safe in self._name_map.items():
|
| 336 |
+
out[orig] = torch.sigmoid(self.gates[safe]).detach()
|
| 337 |
+
return out
|
| 338 |
+
|
| 339 |
+
def _merged_params(self) -> Dict[str, torch.Tensor]:
|
| 340 |
+
params_a = {name: p for name, p in self.layer_a.named_parameters()}
|
| 341 |
+
params_b = {name: p for name, p in self.layer_b.named_parameters()}
|
| 342 |
+
merged_params: Dict[str, torch.Tensor] = {}
|
| 343 |
+
|
| 344 |
+
for name, param_a in params_a.items():
|
| 345 |
+
param_b = params_b.get(name)
|
| 346 |
+
safe = self._safe_for(name)
|
| 347 |
+
if safe is None or param_b is None or param_b.shape != param_a.shape:
|
| 348 |
+
merged_params[name] = param_a.detach()
|
| 349 |
+
continue
|
| 350 |
+
|
| 351 |
+
lam = torch.sigmoid(self.gates[safe]).to(dtype=torch.float32)
|
| 352 |
+
u = self.u[safe].to(dtype=torch.float32)
|
| 353 |
+
w0 = 0.5 * (param_a.detach().float() + param_b.detach().float())
|
| 354 |
+
merged = w0 + (2.0 * lam - 1.0) * u
|
| 355 |
+
merged_params[name] = merged.to(dtype=param_a.dtype)
|
| 356 |
+
return merged_params
|
| 357 |
+
|
| 358 |
+
def forward(self, *args, **kwargs):
|
| 359 |
+
if _functional_call is None:
|
| 360 |
+
raise RuntimeError(
|
| 361 |
+
"Reparam distillation requires torch.func.functional_call"
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
merged_params = self._merged_params()
|
| 365 |
+
return _functional_call(self.layer_a, merged_params, args, kwargs)
|
| 366 |
+
|
| 367 |
+
def materialize_into_layer_a(self) -> int:
|
| 368 |
+
merged = 0
|
| 369 |
+
params_a = {name: p for name, p in self.layer_a.named_parameters()}
|
| 370 |
+
params_b = {name: p for name, p in self.layer_b.named_parameters()}
|
| 371 |
+
with torch.no_grad():
|
| 372 |
+
for orig, safe in self._name_map.items():
|
| 373 |
+
param_a = params_a.get(orig)
|
| 374 |
+
param_b = params_b.get(orig)
|
| 375 |
+
if param_a is None or param_b is None or param_b.shape != param_a.shape:
|
| 376 |
+
continue
|
| 377 |
+
lam = torch.sigmoid(self.gates[safe]).to(device=param_a.device, dtype=torch.float32)
|
| 378 |
+
u = self.u[safe].to(device=param_a.device, dtype=torch.float32)
|
| 379 |
+
w0 = 0.5 * (param_a.detach().float() + param_b.detach().float())
|
| 380 |
+
merged_param = w0 + (2.0 * lam - 1.0) * u
|
| 381 |
+
param_a.copy_(merged_param.to(dtype=param_a.dtype))
|
| 382 |
+
merged += 1
|
| 383 |
+
return merged
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def distill_reparam_merge(
|
| 387 |
+
student_model: torch.nn.Module,
|
| 388 |
+
student_parent: object,
|
| 389 |
+
student_layer_attr: str,
|
| 390 |
+
student_layers: List[torch.nn.Module],
|
| 391 |
+
teacher_model: torch.nn.Module,
|
| 392 |
+
teacher_parent: object,
|
| 393 |
+
teacher_layer_attr: str,
|
| 394 |
+
teacher_layers: List[torch.nn.Module],
|
| 395 |
+
layer_idx: int,
|
| 396 |
+
gate_lambdas: Dict[str, object],
|
| 397 |
+
dataloader,
|
| 398 |
+
args: argparse.Namespace,
|
| 399 |
+
progressive_cycle: Optional[int] = None,
|
| 400 |
+
progressive_total: Optional[int] = None,
|
| 401 |
+
) -> Tuple[int, Dict[str, torch.Tensor], Dict[str, object]]:
|
| 402 |
+
"""Reparameterized distillation that materializes a fused layer into layer_a.
|
| 403 |
+
|
| 404 |
+
Trains U and gate logits s (lambda = sigmoid(s)) using:
|
| 405 |
+
- composition MSE + distill-KL
|
| 406 |
+
- eta * ||lambda - lambda_gate||^2 + gamma * ||U - U0||^2
|
| 407 |
+
"""
|
| 408 |
+
total_epochs = float(args.distill_epochs)
|
| 409 |
+
|
| 410 |
+
hidden_mse_weight = float(getattr(args, "distill_hidden_mse_weight", 1.0))
|
| 411 |
+
if hidden_mse_weight < 0.0:
|
| 412 |
+
raise SystemExit("--distill_hidden_mse_weight must be >= 0")
|
| 413 |
+
attn_mse_weight = float(getattr(args, "distill_attn_mse_weight", 0.0))
|
| 414 |
+
if attn_mse_weight < 0.0:
|
| 415 |
+
raise SystemExit("--distill_attn_mse_weight must be >= 0")
|
| 416 |
+
mlp_mse_weight = float(getattr(args, "distill_mlp_mse_weight", 0.0))
|
| 417 |
+
if mlp_mse_weight < 0.0:
|
| 418 |
+
raise SystemExit("--distill_mlp_mse_weight must be >= 0")
|
| 419 |
+
param_subset = str(getattr(args, "reparam_param_subset", "all"))
|
| 420 |
+
if param_subset not in {"all", "mlp", "attn"}:
|
| 421 |
+
raise SystemExit("--reparam_param_subset must be one of: all, mlp, attn")
|
| 422 |
+
|
| 423 |
+
kl_weight = float(args.distill_kl_weight)
|
| 424 |
+
kl_temp = float(args.distill_kl_temp)
|
| 425 |
+
if kl_weight < 0.0:
|
| 426 |
+
raise SystemExit("--distill_kl_weight must be >= 0")
|
| 427 |
+
if kl_temp <= 0.0:
|
| 428 |
+
raise SystemExit("--distill_kl_temp must be > 0")
|
| 429 |
+
|
| 430 |
+
eta = float(getattr(args, "reparam_eta", 0.0))
|
| 431 |
+
gamma = float(getattr(args, "reparam_gamma", 0.0))
|
| 432 |
+
if eta < 0.0:
|
| 433 |
+
raise SystemExit("--reparam_eta must be >= 0")
|
| 434 |
+
if gamma < 0.0:
|
| 435 |
+
raise SystemExit("--reparam_gamma must be >= 0")
|
| 436 |
+
attn_reg_scale = float(getattr(args, "reparam_attn_reg_scale", 1.0))
|
| 437 |
+
mlp_reg_scale = float(getattr(args, "reparam_mlp_reg_scale", 1.0))
|
| 438 |
+
if attn_reg_scale < 0.0:
|
| 439 |
+
raise SystemExit("--reparam_attn_reg_scale must be >= 0")
|
| 440 |
+
if mlp_reg_scale < 0.0:
|
| 441 |
+
raise SystemExit("--reparam_mlp_reg_scale must be >= 0")
|
| 442 |
+
if (
|
| 443 |
+
total_epochs > 0.0
|
| 444 |
+
and hidden_mse_weight == 0.0
|
| 445 |
+
and attn_mse_weight == 0.0
|
| 446 |
+
and mlp_mse_weight == 0.0
|
| 447 |
+
and kl_weight == 0.0
|
| 448 |
+
and eta == 0.0
|
| 449 |
+
and gamma == 0.0
|
| 450 |
+
):
|
| 451 |
+
raise SystemExit(
|
| 452 |
+
"Reparam distillation has no active loss terms. "
|
| 453 |
+
"Enable hidden/attention/MLP MSE, KL, or at least one reparam regularizer."
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
if not gate_lambdas:
|
| 457 |
+
raise SystemExit("Reparam distillation requires non-empty gate lambdas.")
|
| 458 |
+
|
| 459 |
+
layer_a = student_layers[layer_idx]
|
| 460 |
+
layer_b = student_layers[layer_idx + 1]
|
| 461 |
+
|
| 462 |
+
reparam_layer = ReparamMergedLayer(
|
| 463 |
+
layer_a,
|
| 464 |
+
layer_b,
|
| 465 |
+
gate_lambdas,
|
| 466 |
+
param_subset=param_subset,
|
| 467 |
+
clamp_eps=1e-4,
|
| 468 |
+
)
|
| 469 |
+
if not reparam_layer._name_map:
|
| 470 |
+
raise RuntimeError(
|
| 471 |
+
"No mergeable parameters found for reparam distillation under "
|
| 472 |
+
f"--reparam_param_subset={param_subset!r}."
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
teacher_attn = None
|
| 476 |
+
student_attn = None
|
| 477 |
+
if attn_mse_weight > 0.0:
|
| 478 |
+
try:
|
| 479 |
+
teacher_attn = find_attention_module(teacher_layers[layer_idx + 1])
|
| 480 |
+
student_attn = find_attention_module(reparam_layer.layer_a)
|
| 481 |
+
except ValueError as exc:
|
| 482 |
+
raise SystemExit(
|
| 483 |
+
"Attention-output preservation was requested but an attention module "
|
| 484 |
+
f"could not be resolved: {exc}"
|
| 485 |
+
) from exc
|
| 486 |
+
|
| 487 |
+
teacher_mlp = None
|
| 488 |
+
student_mlp = None
|
| 489 |
+
if mlp_mse_weight > 0.0:
|
| 490 |
+
try:
|
| 491 |
+
teacher_mlp = find_mlp_module(teacher_layers[layer_idx + 1])
|
| 492 |
+
student_mlp = find_mlp_module(reparam_layer.layer_a)
|
| 493 |
+
except ValueError as exc:
|
| 494 |
+
raise SystemExit(
|
| 495 |
+
"MLP-output preservation was requested but an MLP module could not be "
|
| 496 |
+
f"resolved: {exc}"
|
| 497 |
+
) from exc
|
| 498 |
+
|
| 499 |
+
# Virtual layer list: replace layer_a with reparam layer and remove layer_b.
|
| 500 |
+
virtual_layers = list(student_layers)
|
| 501 |
+
virtual_layers[layer_idx] = reparam_layer
|
| 502 |
+
del virtual_layers[layer_idx + 1]
|
| 503 |
+
|
| 504 |
+
# Only (U, s) are trainable.
|
| 505 |
+
for param in student_model.parameters():
|
| 506 |
+
param.requires_grad_(False)
|
| 507 |
+
for param in reparam_layer.gates.parameters():
|
| 508 |
+
param.requires_grad_(True)
|
| 509 |
+
for param in reparam_layer.u.parameters():
|
| 510 |
+
param.requires_grad_(True)
|
| 511 |
+
|
| 512 |
+
do_train = total_epochs > 0.0
|
| 513 |
+
if do_train:
|
| 514 |
+
teacher_model.eval()
|
| 515 |
+
student_model.train()
|
| 516 |
+
|
| 517 |
+
# Rough memory heads-up (esp. when --fisher_mode param makes per-element gates).
|
| 518 |
+
total_gate_elems = sum(int(p.numel()) for p in reparam_layer.gates.parameters())
|
| 519 |
+
total_u_elems = sum(int(p.numel()) for p in reparam_layer.u.parameters())
|
| 520 |
+
gate_mib = total_gate_elems * 4.0 / (1024.0 * 1024.0)
|
| 521 |
+
u_mib = total_u_elems * 4.0 / (1024.0 * 1024.0)
|
| 522 |
+
family_counts: Dict[str, int] = {"attn": 0, "mlp": 0, "other": 0}
|
| 523 |
+
for orig in reparam_layer._name_map:
|
| 524 |
+
family_counts[_classify_param_family(orig)] += 1
|
| 525 |
+
print(
|
| 526 |
+
f"[reparam] subset={param_subset} gates={len(reparam_layer.gates)} "
|
| 527 |
+
f"(attn={family_counts['attn']}, mlp={family_counts['mlp']}, other={family_counts['other']}) "
|
| 528 |
+
f"elems={total_gate_elems} (~{gate_mib:.1f} MiB), "
|
| 529 |
+
f"U_elems={total_u_elems} (~{u_mib:.1f} MiB; +optimizer state)"
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
optimizer = None
|
| 533 |
+
if do_train:
|
| 534 |
+
optimizer = torch.optim.AdamW(
|
| 535 |
+
[*reparam_layer.gates.parameters(), *reparam_layer.u.parameters()],
|
| 536 |
+
lr=float(args.distill_lr),
|
| 537 |
+
weight_decay=float(args.distill_weight_decay),
|
| 538 |
+
)
|
| 539 |
+
|
| 540 |
+
device_type = torch.device(args.device).type
|
| 541 |
+
amp_dtype = None
|
| 542 |
+
if args.dtype == "float16":
|
| 543 |
+
amp_dtype = torch.float16
|
| 544 |
+
elif args.dtype == "bfloat16":
|
| 545 |
+
amp_dtype = torch.bfloat16
|
| 546 |
+
use_amp = do_train and amp_dtype is not None and device_type == "cuda"
|
| 547 |
+
use_scaler = use_amp and amp_dtype == torch.float16
|
| 548 |
+
scaler = torch.cuda.amp.GradScaler() if use_scaler else None
|
| 549 |
+
|
| 550 |
+
full_epochs = int(total_epochs) if do_train else 0
|
| 551 |
+
fractional = (total_epochs - full_epochs) if do_train else 0.0
|
| 552 |
+
if fractional < 1e-8:
|
| 553 |
+
fractional = 0.0
|
| 554 |
+
|
| 555 |
+
epoch_plan = [(epoch_idx, None) for epoch_idx in range(full_epochs)]
|
| 556 |
+
if fractional > 0:
|
| 557 |
+
try:
|
| 558 |
+
batches_per_epoch = len(dataloader)
|
| 559 |
+
except TypeError as exc:
|
| 560 |
+
raise SystemExit(
|
| 561 |
+
"Fractional distill epochs require a dataloader with finite length."
|
| 562 |
+
) from exc
|
| 563 |
+
if batches_per_epoch > 0:
|
| 564 |
+
frac_batches = int(round(fractional * batches_per_epoch))
|
| 565 |
+
if frac_batches <= 0:
|
| 566 |
+
frac_batches = 1
|
| 567 |
+
epoch_plan.append((full_epochs, frac_batches))
|
| 568 |
+
|
| 569 |
+
grad_accum = int(getattr(args, "distill_grad_accum_steps", 1))
|
| 570 |
+
if grad_accum <= 0:
|
| 571 |
+
raise SystemExit("--distill_grad_accum_steps must be >= 1")
|
| 572 |
+
|
| 573 |
+
log_steps = int(getattr(args, "distill_log_steps", 100))
|
| 574 |
+
max_grad_norm = getattr(args, "distill_max_grad_norm", 1.0)
|
| 575 |
+
|
| 576 |
+
params_a = {name: p for name, p in layer_a.named_parameters()}
|
| 577 |
+
params_b = {name: p for name, p in layer_b.named_parameters()}
|
| 578 |
+
|
| 579 |
+
step = 0
|
| 580 |
+
for epoch_idx, max_batches in epoch_plan:
|
| 581 |
+
if max_batches is None:
|
| 582 |
+
epoch_iter = dataloader
|
| 583 |
+
else:
|
| 584 |
+
epoch_iter = itertools.islice(dataloader, max_batches)
|
| 585 |
+
iterator = epoch_iter
|
| 586 |
+
if tqdm is not None and _tqdm_enabled():
|
| 587 |
+
if progressive_cycle is not None:
|
| 588 |
+
if progressive_total is not None:
|
| 589 |
+
desc = (
|
| 590 |
+
f"Reparam (cycle {progressive_cycle}/{progressive_total}, "
|
| 591 |
+
f"epoch {epoch_idx+1})"
|
| 592 |
+
)
|
| 593 |
+
else:
|
| 594 |
+
desc = f"Reparam (cycle {progressive_cycle}, epoch {epoch_idx+1})"
|
| 595 |
+
else:
|
| 596 |
+
desc = f"Reparam (epoch {epoch_idx+1})"
|
| 597 |
+
iterator = tqdm(epoch_iter, desc=desc, unit="batch", total=max_batches)
|
| 598 |
+
|
| 599 |
+
for batch in iterator:
|
| 600 |
+
input_ids = batch[0].to(args.device)
|
| 601 |
+
attention_mask = batch[1].to(args.device)
|
| 602 |
+
teacher_ids = input_ids.to(args.distill_teacher_device or args.device)
|
| 603 |
+
teacher_mask = attention_mask.to(args.distill_teacher_device or args.device)
|
| 604 |
+
|
| 605 |
+
teacher_depth = layer_idx + 2
|
| 606 |
+
student_depth = layer_idx + 1
|
| 607 |
+
|
| 608 |
+
autocast_ctx = (
|
| 609 |
+
torch.autocast(device_type=device_type, dtype=amp_dtype)
|
| 610 |
+
if use_amp
|
| 611 |
+
else nullcontext()
|
| 612 |
+
)
|
| 613 |
+
with autocast_ctx:
|
| 614 |
+
teacher_attn_ctx = (
|
| 615 |
+
capture_module_output(teacher_attn)
|
| 616 |
+
if teacher_attn is not None
|
| 617 |
+
else nullcontext({"output": None})
|
| 618 |
+
)
|
| 619 |
+
teacher_mlp_ctx = (
|
| 620 |
+
capture_module_output(teacher_mlp)
|
| 621 |
+
if teacher_mlp is not None
|
| 622 |
+
else nullcontext({"output": None})
|
| 623 |
+
)
|
| 624 |
+
with torch.no_grad():
|
| 625 |
+
with teacher_attn_ctx as teacher_attn_cache, teacher_mlp_ctx as teacher_mlp_cache:
|
| 626 |
+
teacher_hidden = forward_truncated(
|
| 627 |
+
teacher_parent,
|
| 628 |
+
teacher_layer_attr,
|
| 629 |
+
teacher_layers,
|
| 630 |
+
teacher_depth,
|
| 631 |
+
teacher_ids,
|
| 632 |
+
attention_mask=teacher_mask,
|
| 633 |
+
)
|
| 634 |
+
|
| 635 |
+
student_attn_ctx = (
|
| 636 |
+
capture_module_output(student_attn)
|
| 637 |
+
if student_attn is not None
|
| 638 |
+
else nullcontext({"output": None})
|
| 639 |
+
)
|
| 640 |
+
student_mlp_ctx = (
|
| 641 |
+
capture_module_output(student_mlp)
|
| 642 |
+
if student_mlp is not None
|
| 643 |
+
else nullcontext({"output": None})
|
| 644 |
+
)
|
| 645 |
+
with student_attn_ctx as student_attn_cache, student_mlp_ctx as student_mlp_cache:
|
| 646 |
+
student_hidden = forward_truncated(
|
| 647 |
+
student_parent,
|
| 648 |
+
student_layer_attr,
|
| 649 |
+
virtual_layers,
|
| 650 |
+
student_depth,
|
| 651 |
+
input_ids,
|
| 652 |
+
attention_mask=attention_mask,
|
| 653 |
+
)
|
| 654 |
+
|
| 655 |
+
if teacher_hidden.device != student_hidden.device:
|
| 656 |
+
teacher_hidden = teacher_hidden.to(student_hidden.device)
|
| 657 |
+
|
| 658 |
+
mse_loss = None
|
| 659 |
+
if hidden_mse_weight > 0.0:
|
| 660 |
+
diff = student_hidden - teacher_hidden
|
| 661 |
+
mse_loss = _masked_hidden_mse(diff, attention_mask)
|
| 662 |
+
if mse_loss is None:
|
| 663 |
+
continue
|
| 664 |
+
|
| 665 |
+
attn_aux_loss = None
|
| 666 |
+
if attn_mse_weight > 0.0:
|
| 667 |
+
teacher_attn_hidden = teacher_attn_cache.get("output")
|
| 668 |
+
student_attn_hidden = student_attn_cache.get("output")
|
| 669 |
+
if teacher_attn_hidden is None or student_attn_hidden is None:
|
| 670 |
+
raise RuntimeError(
|
| 671 |
+
"Attention-output preservation is enabled, but the forward "
|
| 672 |
+
"hook did not capture attention outputs."
|
| 673 |
+
)
|
| 674 |
+
if teacher_attn_hidden.device != student_attn_hidden.device:
|
| 675 |
+
teacher_attn_hidden = teacher_attn_hidden.to(student_attn_hidden.device)
|
| 676 |
+
attn_aux_loss = _masked_hidden_mse(
|
| 677 |
+
student_attn_hidden - teacher_attn_hidden,
|
| 678 |
+
attention_mask,
|
| 679 |
+
)
|
| 680 |
+
if attn_aux_loss is None:
|
| 681 |
+
continue
|
| 682 |
+
|
| 683 |
+
mlp_aux_loss = None
|
| 684 |
+
if mlp_mse_weight > 0.0:
|
| 685 |
+
teacher_mlp_hidden = teacher_mlp_cache.get("output")
|
| 686 |
+
student_mlp_hidden = student_mlp_cache.get("output")
|
| 687 |
+
if teacher_mlp_hidden is None or student_mlp_hidden is None:
|
| 688 |
+
raise RuntimeError(
|
| 689 |
+
"MLP-output preservation is enabled, but the forward hook "
|
| 690 |
+
"did not capture MLP outputs."
|
| 691 |
+
)
|
| 692 |
+
if teacher_mlp_hidden.device != student_mlp_hidden.device:
|
| 693 |
+
teacher_mlp_hidden = teacher_mlp_hidden.to(student_mlp_hidden.device)
|
| 694 |
+
mlp_aux_loss = _masked_hidden_mse(
|
| 695 |
+
student_mlp_hidden - teacher_mlp_hidden,
|
| 696 |
+
attention_mask,
|
| 697 |
+
)
|
| 698 |
+
if mlp_aux_loss is None:
|
| 699 |
+
continue
|
| 700 |
+
|
| 701 |
+
kl_loss = None
|
| 702 |
+
if kl_weight > 0.0:
|
| 703 |
+
with torch.no_grad():
|
| 704 |
+
teacher_outputs = teacher_model(
|
| 705 |
+
input_ids=teacher_ids,
|
| 706 |
+
attention_mask=teacher_mask,
|
| 707 |
+
use_cache=False,
|
| 708 |
+
)
|
| 709 |
+
teacher_logits = teacher_outputs.logits
|
| 710 |
+
|
| 711 |
+
virtual_container = torch.nn.ModuleList(virtual_layers)
|
| 712 |
+
with temporary_layers(
|
| 713 |
+
student_parent, student_layer_attr, virtual_container
|
| 714 |
+
):
|
| 715 |
+
student_outputs = student_model(
|
| 716 |
+
input_ids=input_ids,
|
| 717 |
+
attention_mask=attention_mask,
|
| 718 |
+
use_cache=False,
|
| 719 |
+
)
|
| 720 |
+
student_logits = student_outputs.logits
|
| 721 |
+
if teacher_logits.device != student_logits.device:
|
| 722 |
+
teacher_logits = teacher_logits.to(student_logits.device)
|
| 723 |
+
|
| 724 |
+
shift_teacher_logits = teacher_logits[:, :-1, :].contiguous()
|
| 725 |
+
shift_student_logits = student_logits[:, :-1, :].contiguous()
|
| 726 |
+
shift_mask = attention_mask[:, 1:].contiguous()
|
| 727 |
+
log_p_t = F.log_softmax(shift_teacher_logits / kl_temp, dim=-1)
|
| 728 |
+
log_p_s = F.log_softmax(shift_student_logits / kl_temp, dim=-1)
|
| 729 |
+
p_t = log_p_t.exp()
|
| 730 |
+
kl_flat = (p_t * (log_p_t - log_p_s)).sum(dim=-1)
|
| 731 |
+
kl_denom = shift_mask.sum()
|
| 732 |
+
if kl_denom.item() == 0:
|
| 733 |
+
continue
|
| 734 |
+
kl_loss = (
|
| 735 |
+
kl_flat * shift_mask.to(kl_flat.dtype)
|
| 736 |
+
).sum() / kl_denom
|
| 737 |
+
|
| 738 |
+
lambda_reg = None
|
| 739 |
+
if eta > 0.0:
|
| 740 |
+
reg_sum: Optional[torch.Tensor] = None
|
| 741 |
+
reg_elems = 0
|
| 742 |
+
for orig, safe in reparam_layer._name_map.items():
|
| 743 |
+
lam = torch.sigmoid(reparam_layer.gates[safe]).float()
|
| 744 |
+
target = gate_lambdas.get(orig)
|
| 745 |
+
if target is None:
|
| 746 |
+
target_t = 0.5
|
| 747 |
+
elif isinstance(target, torch.Tensor):
|
| 748 |
+
target_t = target.to(device=lam.device, dtype=lam.dtype)
|
| 749 |
+
else:
|
| 750 |
+
target_t = float(target)
|
| 751 |
+
diff_lam = lam - target_t
|
| 752 |
+
family = _classify_param_family(orig)
|
| 753 |
+
scale = _family_reg_scale(
|
| 754 |
+
family,
|
| 755 |
+
attn_scale=attn_reg_scale,
|
| 756 |
+
mlp_scale=mlp_reg_scale,
|
| 757 |
+
)
|
| 758 |
+
if scale <= 0.0:
|
| 759 |
+
continue
|
| 760 |
+
part = diff_lam.pow(2).sum() * scale
|
| 761 |
+
reg_sum = part if reg_sum is None else reg_sum + part
|
| 762 |
+
reg_elems += int(float(diff_lam.numel()) * scale)
|
| 763 |
+
if reg_elems > 0 and reg_sum is not None:
|
| 764 |
+
lambda_reg = reg_sum / float(reg_elems)
|
| 765 |
+
|
| 766 |
+
u_reg = None
|
| 767 |
+
if gamma > 0.0:
|
| 768 |
+
reg_sum: Optional[torch.Tensor] = None
|
| 769 |
+
reg_elems = 0
|
| 770 |
+
for orig, safe in reparam_layer._name_map.items():
|
| 771 |
+
u = reparam_layer.u[safe].float()
|
| 772 |
+
param_a = params_a.get(orig)
|
| 773 |
+
param_b = params_b.get(orig)
|
| 774 |
+
if param_a is None or param_b is None or param_b.shape != param_a.shape:
|
| 775 |
+
continue
|
| 776 |
+
u0 = 0.5 * (param_a.detach().float() - param_b.detach().float())
|
| 777 |
+
diff_u = u - u0
|
| 778 |
+
family = _classify_param_family(orig)
|
| 779 |
+
scale = _family_reg_scale(
|
| 780 |
+
family,
|
| 781 |
+
attn_scale=attn_reg_scale,
|
| 782 |
+
mlp_scale=mlp_reg_scale,
|
| 783 |
+
)
|
| 784 |
+
if scale <= 0.0:
|
| 785 |
+
continue
|
| 786 |
+
part = diff_u.pow(2).sum() * scale
|
| 787 |
+
reg_sum = part if reg_sum is None else reg_sum + part
|
| 788 |
+
reg_elems += int(float(diff_u.numel()) * scale)
|
| 789 |
+
if reg_elems > 0 and reg_sum is not None:
|
| 790 |
+
u_reg = reg_sum / float(reg_elems)
|
| 791 |
+
|
| 792 |
+
total_loss = None
|
| 793 |
+
if mse_loss is not None:
|
| 794 |
+
total_loss = hidden_mse_weight * mse_loss
|
| 795 |
+
if attn_aux_loss is not None:
|
| 796 |
+
total_loss = attn_mse_weight * attn_aux_loss if total_loss is None else total_loss + (attn_mse_weight * attn_aux_loss)
|
| 797 |
+
if mlp_aux_loss is not None:
|
| 798 |
+
total_loss = mlp_mse_weight * mlp_aux_loss if total_loss is None else total_loss + (mlp_mse_weight * mlp_aux_loss)
|
| 799 |
+
if kl_loss is not None:
|
| 800 |
+
total_loss = kl_weight * (kl_temp ** 2) * kl_loss if total_loss is None else total_loss + (kl_weight * (kl_temp ** 2) * kl_loss)
|
| 801 |
+
if lambda_reg is not None:
|
| 802 |
+
total_loss = eta * lambda_reg if total_loss is None else total_loss + (eta * lambda_reg)
|
| 803 |
+
if u_reg is not None:
|
| 804 |
+
total_loss = gamma * u_reg if total_loss is None else total_loss + (gamma * u_reg)
|
| 805 |
+
if total_loss is None:
|
| 806 |
+
continue
|
| 807 |
+
|
| 808 |
+
if grad_accum > 1:
|
| 809 |
+
total_loss = total_loss / grad_accum
|
| 810 |
+
if use_scaler:
|
| 811 |
+
scaler.scale(total_loss).backward()
|
| 812 |
+
else:
|
| 813 |
+
total_loss.backward()
|
| 814 |
+
|
| 815 |
+
if (step + 1) % grad_accum == 0:
|
| 816 |
+
if max_grad_norm is not None:
|
| 817 |
+
if use_scaler:
|
| 818 |
+
scaler.unscale_(optimizer)
|
| 819 |
+
torch.nn.utils.clip_grad_norm_(
|
| 820 |
+
[*reparam_layer.gates.parameters(), *reparam_layer.u.parameters()],
|
| 821 |
+
float(max_grad_norm),
|
| 822 |
+
)
|
| 823 |
+
if use_scaler:
|
| 824 |
+
scaler.step(optimizer)
|
| 825 |
+
scaler.update()
|
| 826 |
+
else:
|
| 827 |
+
optimizer.step()
|
| 828 |
+
optimizer.zero_grad(set_to_none=True)
|
| 829 |
+
|
| 830 |
+
if log_steps and (step == 0 or (step + 1) % log_steps == 0):
|
| 831 |
+
log_parts = [f"loss={total_loss.item():.6e}"]
|
| 832 |
+
if mse_loss is not None:
|
| 833 |
+
log_parts.append(f"mse={mse_loss.item():.6e}")
|
| 834 |
+
else:
|
| 835 |
+
log_parts.append("mse=disabled")
|
| 836 |
+
if attn_aux_loss is not None:
|
| 837 |
+
log_parts.append(f"attn_mse={attn_aux_loss.item():.6e}")
|
| 838 |
+
elif attn_mse_weight > 0.0:
|
| 839 |
+
log_parts.append("attn_mse=nan")
|
| 840 |
+
if mlp_aux_loss is not None:
|
| 841 |
+
log_parts.append(f"mlp_mse={mlp_aux_loss.item():.6e}")
|
| 842 |
+
elif mlp_mse_weight > 0.0:
|
| 843 |
+
log_parts.append("mlp_mse=nan")
|
| 844 |
+
if kl_loss is not None:
|
| 845 |
+
log_parts.append(f"kl={kl_loss.item():.6e}")
|
| 846 |
+
if lambda_reg is not None:
|
| 847 |
+
log_parts.append(f"lam_reg={lambda_reg.item():.6e}")
|
| 848 |
+
if u_reg is not None:
|
| 849 |
+
log_parts.append(f"u_reg={u_reg.item():.6e}")
|
| 850 |
+
print(
|
| 851 |
+
f"[reparam] epoch={epoch_idx+1} step={step+1} " + " ".join(log_parts)
|
| 852 |
+
)
|
| 853 |
+
step += 1
|
| 854 |
+
|
| 855 |
+
merged = reparam_layer.materialize_into_layer_a()
|
| 856 |
+
final_lambdas = reparam_layer.gate_lambdas()
|
| 857 |
+
stats: Dict[str, object] = {
|
| 858 |
+
"enabled": True,
|
| 859 |
+
"epochs": total_epochs,
|
| 860 |
+
"lr": float(args.distill_lr),
|
| 861 |
+
"hidden_mse_weight": hidden_mse_weight,
|
| 862 |
+
"attn_mse_weight": attn_mse_weight,
|
| 863 |
+
"mlp_mse_weight": mlp_mse_weight,
|
| 864 |
+
"eta": eta,
|
| 865 |
+
"gamma": gamma,
|
| 866 |
+
"attn_reg_scale": attn_reg_scale,
|
| 867 |
+
"mlp_reg_scale": mlp_reg_scale,
|
| 868 |
+
"param_subset": param_subset,
|
| 869 |
+
"num_gates": len(final_lambdas),
|
| 870 |
+
"num_attn_gates": family_counts["attn"],
|
| 871 |
+
"num_mlp_gates": family_counts["mlp"],
|
| 872 |
+
"num_other_gates": family_counts["other"],
|
| 873 |
+
}
|
| 874 |
+
return merged, final_lambdas, stats
|
| 875 |
+
|
| 876 |
+
|
| 877 |
+
class LoRALinear(torch.nn.Module):
|
| 878 |
+
def __init__(
|
| 879 |
+
self,
|
| 880 |
+
base: torch.nn.Linear,
|
| 881 |
+
rank: int,
|
| 882 |
+
alpha: float,
|
| 883 |
+
dropout: float,
|
| 884 |
+
) -> None:
|
| 885 |
+
super().__init__()
|
| 886 |
+
if rank <= 0:
|
| 887 |
+
raise ValueError("LoRA rank must be positive")
|
| 888 |
+
self.base = base
|
| 889 |
+
self.rank = int(rank)
|
| 890 |
+
self.alpha = float(alpha)
|
| 891 |
+
self.scaling = self.alpha / float(self.rank)
|
| 892 |
+
self.enabled = True
|
| 893 |
+
if dropout > 0:
|
| 894 |
+
self.dropout = torch.nn.Dropout(dropout)
|
| 895 |
+
else:
|
| 896 |
+
self.dropout = torch.nn.Identity()
|
| 897 |
+
|
| 898 |
+
self.lora_A = torch.nn.Linear(base.in_features, self.rank, bias=False)
|
| 899 |
+
self.lora_B = torch.nn.Linear(self.rank, base.out_features, bias=False)
|
| 900 |
+
torch.nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
|
| 901 |
+
torch.nn.init.zeros_(self.lora_B.weight)
|
| 902 |
+
|
| 903 |
+
self.lora_A.to(device=base.weight.device, dtype=base.weight.dtype)
|
| 904 |
+
self.lora_B.to(device=base.weight.device, dtype=base.weight.dtype)
|
| 905 |
+
self.merged = False
|
| 906 |
+
|
| 907 |
+
def lora_parameters(self) -> List[torch.nn.Parameter]:
|
| 908 |
+
return [*self.lora_A.parameters(), *self.lora_B.parameters()]
|
| 909 |
+
|
| 910 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 911 |
+
result = self.base(x)
|
| 912 |
+
if self.merged or not self.enabled:
|
| 913 |
+
return result
|
| 914 |
+
lora_out = self.lora_B(self.lora_A(self.dropout(x)))
|
| 915 |
+
return result + lora_out * self.scaling
|
| 916 |
+
|
| 917 |
+
def merge(self) -> None:
|
| 918 |
+
if self.merged:
|
| 919 |
+
return
|
| 920 |
+
delta = torch.matmul(self.lora_B.weight, self.lora_A.weight)
|
| 921 |
+
delta = delta.to(dtype=self.base.weight.dtype) * self.scaling
|
| 922 |
+
self.base.weight.data.add_(delta)
|
| 923 |
+
self.merged = True
|
| 924 |
+
|
| 925 |
+
|
| 926 |
+
def _get_child_module(parent: torch.nn.Module, part: str) -> torch.nn.Module:
|
| 927 |
+
if isinstance(parent, (torch.nn.ModuleList, torch.nn.Sequential)) and part.isdigit():
|
| 928 |
+
return parent[int(part)]
|
| 929 |
+
if isinstance(parent, torch.nn.ModuleDict):
|
| 930 |
+
return parent[part]
|
| 931 |
+
return getattr(parent, part)
|
| 932 |
+
|
| 933 |
+
|
| 934 |
+
def _set_child_module(parent: torch.nn.Module, part: str, module: torch.nn.Module) -> None:
|
| 935 |
+
if isinstance(parent, (torch.nn.ModuleList, torch.nn.Sequential)) and part.isdigit():
|
| 936 |
+
parent[int(part)] = module
|
| 937 |
+
return
|
| 938 |
+
if isinstance(parent, torch.nn.ModuleDict):
|
| 939 |
+
parent[part] = module
|
| 940 |
+
return
|
| 941 |
+
setattr(parent, part, module)
|
| 942 |
+
|
| 943 |
+
|
| 944 |
+
def _resolve_parent_module(
|
| 945 |
+
root: torch.nn.Module, module_name: str
|
| 946 |
+
) -> Optional[tuple]:
|
| 947 |
+
if not module_name:
|
| 948 |
+
return None
|
| 949 |
+
parts = module_name.split(".")
|
| 950 |
+
parent = root
|
| 951 |
+
for part in parts[:-1]:
|
| 952 |
+
parent = _get_child_module(parent, part)
|
| 953 |
+
return parent, parts[-1]
|
| 954 |
+
|
| 955 |
+
|
| 956 |
+
def _resolve_module_by_path(root: torch.nn.Module, module_path: str) -> Optional[torch.nn.Module]:
|
| 957 |
+
if not module_path:
|
| 958 |
+
return None
|
| 959 |
+
parts = [part for part in module_path.split(".") if part]
|
| 960 |
+
node = root
|
| 961 |
+
for part in parts:
|
| 962 |
+
try:
|
| 963 |
+
node = _get_child_module(node, part)
|
| 964 |
+
except Exception:
|
| 965 |
+
return None
|
| 966 |
+
return node
|
| 967 |
+
|
| 968 |
+
|
| 969 |
+
def _resolve_layer_container_for_lora(
|
| 970 |
+
model: torch.nn.Module, layer_path: Optional[str]
|
| 971 |
+
) -> Tuple[Optional[str], Optional[object]]:
|
| 972 |
+
"""Resolve transformer layer container with optional auto-detection.
|
| 973 |
+
|
| 974 |
+
Mirrors the candidate path strategy used elsewhere, so LoRA filtering can work
|
| 975 |
+
even when --layer_path is not provided.
|
| 976 |
+
"""
|
| 977 |
+
if isinstance(layer_path, str) and layer_path and layer_path.lower() != "none":
|
| 978 |
+
container = _resolve_module_by_path(model, layer_path)
|
| 979 |
+
if container is not None:
|
| 980 |
+
try:
|
| 981 |
+
list(container)
|
| 982 |
+
return layer_path, container
|
| 983 |
+
except TypeError:
|
| 984 |
+
pass
|
| 985 |
+
|
| 986 |
+
candidate_paths = [
|
| 987 |
+
"model.layers", # LLaMA, Mistral, Qwen2, Gemma
|
| 988 |
+
"model.decoder.layers", # OPT
|
| 989 |
+
"transformer.h", # GPT-2, GPT-J, Bloom, Falcon
|
| 990 |
+
"transformer.blocks", # MPT
|
| 991 |
+
"gpt_neox.layers", # GPT-NeoX
|
| 992 |
+
"layers", # fallback
|
| 993 |
+
]
|
| 994 |
+
for path in candidate_paths:
|
| 995 |
+
container = _resolve_module_by_path(model, path)
|
| 996 |
+
if container is None:
|
| 997 |
+
continue
|
| 998 |
+
try:
|
| 999 |
+
list(container)
|
| 1000 |
+
except TypeError:
|
| 1001 |
+
continue
|
| 1002 |
+
return path, container
|
| 1003 |
+
|
| 1004 |
+
return None, None
|
| 1005 |
+
|
| 1006 |
+
|
| 1007 |
+
def _parse_exclude_pairs_local(raw_values, num_pairs: int) -> Set[int]:
|
| 1008 |
+
if not raw_values or num_pairs <= 0:
|
| 1009 |
+
return set()
|
| 1010 |
+
exclude: Set[int] = set()
|
| 1011 |
+
for item in raw_values:
|
| 1012 |
+
if item is None:
|
| 1013 |
+
continue
|
| 1014 |
+
for part in str(item).split(","):
|
| 1015 |
+
part = part.strip()
|
| 1016 |
+
if not part:
|
| 1017 |
+
continue
|
| 1018 |
+
try:
|
| 1019 |
+
idx = int(part)
|
| 1020 |
+
except ValueError as exc:
|
| 1021 |
+
raise SystemExit("--exclude_pairs must contain integers.") from exc
|
| 1022 |
+
if idx < 0:
|
| 1023 |
+
idx = num_pairs + idx
|
| 1024 |
+
if 0 <= idx < num_pairs:
|
| 1025 |
+
exclude.add(idx)
|
| 1026 |
+
return exclude
|
| 1027 |
+
|
| 1028 |
+
|
| 1029 |
+
def _extract_layer_index_from_module_name(
|
| 1030 |
+
module_name: str, layer_path: str
|
| 1031 |
+
) -> Optional[int]:
|
| 1032 |
+
if not layer_path:
|
| 1033 |
+
return None
|
| 1034 |
+
prefix = f"{layer_path}."
|
| 1035 |
+
if not module_name.startswith(prefix):
|
| 1036 |
+
return None
|
| 1037 |
+
rest = module_name[len(prefix) :]
|
| 1038 |
+
if not rest:
|
| 1039 |
+
return None
|
| 1040 |
+
idx_text = rest.split(".", 1)[0]
|
| 1041 |
+
if not idx_text.isdigit():
|
| 1042 |
+
return None
|
| 1043 |
+
return int(idx_text)
|
| 1044 |
+
|
| 1045 |
+
|
| 1046 |
+
def _select_linear_modules_for_lora_targets(
|
| 1047 |
+
model: torch.nn.Module,
|
| 1048 |
+
args: argparse.Namespace,
|
| 1049 |
+
*,
|
| 1050 |
+
log_tag: str,
|
| 1051 |
+
) -> Tuple[List[Tuple[str, torch.nn.Linear]], Optional[Set[str]], Set[int], Optional[str]]:
|
| 1052 |
+
raw_targets = getattr(args, "lora_target_modules", None)
|
| 1053 |
+
target_modules: Optional[Set[str]] = None
|
| 1054 |
+
if raw_targets:
|
| 1055 |
+
target_modules = {str(item) for item in raw_targets if str(item)}
|
| 1056 |
+
|
| 1057 |
+
exclude_layer_indices: Set[int] = set()
|
| 1058 |
+
resolved_layer_path: Optional[str] = None
|
| 1059 |
+
if bool(getattr(args, "lora_respect_exclude_pairs", False)):
|
| 1060 |
+
requested_layer_path = getattr(args, "layer_path", None)
|
| 1061 |
+
resolved_layer_path, layer_container = _resolve_layer_container_for_lora(
|
| 1062 |
+
model, requested_layer_path
|
| 1063 |
+
)
|
| 1064 |
+
if isinstance(layer_container, (torch.nn.ModuleList, list, tuple)):
|
| 1065 |
+
num_pairs = max(len(layer_container) - 1, 0)
|
| 1066 |
+
exclude_pairs = _parse_exclude_pairs_local(
|
| 1067 |
+
getattr(args, "exclude_pairs", None), num_pairs
|
| 1068 |
+
)
|
| 1069 |
+
for pair_idx in exclude_pairs:
|
| 1070 |
+
exclude_layer_indices.add(pair_idx)
|
| 1071 |
+
exclude_layer_indices.add(pair_idx + 1)
|
| 1072 |
+
else:
|
| 1073 |
+
print(
|
| 1074 |
+
f"[{log_tag}] Warning: --lora_respect_exclude_pairs enabled, but "
|
| 1075 |
+
f"could not resolve layer path '{requested_layer_path}'."
|
| 1076 |
+
)
|
| 1077 |
+
|
| 1078 |
+
linear_modules = [
|
| 1079 |
+
(name, module)
|
| 1080 |
+
for name, module in model.named_modules()
|
| 1081 |
+
if isinstance(module, torch.nn.Linear)
|
| 1082 |
+
and (target_modules is None or name.split(".")[-1] in target_modules)
|
| 1083 |
+
and (
|
| 1084 |
+
not exclude_layer_indices
|
| 1085 |
+
or _extract_layer_index_from_module_name(name, resolved_layer_path or "")
|
| 1086 |
+
not in exclude_layer_indices
|
| 1087 |
+
)
|
| 1088 |
+
]
|
| 1089 |
+
return linear_modules, target_modules, exclude_layer_indices, resolved_layer_path
|
| 1090 |
+
|
| 1091 |
+
|
| 1092 |
+
def apply_lora_adapters(
|
| 1093 |
+
model: torch.nn.Module, args: argparse.Namespace
|
| 1094 |
+
) -> List[LoRALinear]:
|
| 1095 |
+
if args.lora_rank <= 0:
|
| 1096 |
+
raise SystemExit("--lora_rank must be > 0 when --lora_epochs > 0")
|
| 1097 |
+
linear_modules, target_modules, exclude_layer_indices, _ = (
|
| 1098 |
+
_select_linear_modules_for_lora_targets(model, args, log_tag="lora")
|
| 1099 |
+
)
|
| 1100 |
+
if not linear_modules:
|
| 1101 |
+
raise SystemExit(
|
| 1102 |
+
"No Linear modules found for LoRA adapters "
|
| 1103 |
+
"(check --lora_target_modules / --exclude_pairs / --lora_respect_exclude_pairs)."
|
| 1104 |
+
)
|
| 1105 |
+
|
| 1106 |
+
lora_modules: List[LoRALinear] = []
|
| 1107 |
+
for name, module in linear_modules:
|
| 1108 |
+
resolved = _resolve_parent_module(model, name)
|
| 1109 |
+
if resolved is None:
|
| 1110 |
+
continue
|
| 1111 |
+
parent, attr = resolved
|
| 1112 |
+
wrapped = LoRALinear(
|
| 1113 |
+
base=module,
|
| 1114 |
+
rank=args.lora_rank,
|
| 1115 |
+
alpha=args.lora_alpha,
|
| 1116 |
+
dropout=args.lora_dropout,
|
| 1117 |
+
)
|
| 1118 |
+
_set_child_module(parent, attr, wrapped)
|
| 1119 |
+
lora_modules.append(wrapped)
|
| 1120 |
+
|
| 1121 |
+
for param in model.parameters():
|
| 1122 |
+
param.requires_grad_(False)
|
| 1123 |
+
for lora_module in lora_modules:
|
| 1124 |
+
for param in lora_module.lora_parameters():
|
| 1125 |
+
param.requires_grad_(True)
|
| 1126 |
+
|
| 1127 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 1128 |
+
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 1129 |
+
percent = 100.0 * trainable_params / max(total_params, 1)
|
| 1130 |
+
target_note = ""
|
| 1131 |
+
if target_modules is not None:
|
| 1132 |
+
target_note = f" target={sorted(target_modules)}"
|
| 1133 |
+
exclude_note = ""
|
| 1134 |
+
if exclude_layer_indices:
|
| 1135 |
+
exclude_note = f" excluded_layers={sorted(exclude_layer_indices)}"
|
| 1136 |
+
print(
|
| 1137 |
+
"[lora] Applied adapters to "
|
| 1138 |
+
f"{len(lora_modules)} linear modules "
|
| 1139 |
+
f"({trainable_params}/{total_params} trainable, {percent:.4f}%)."
|
| 1140 |
+
f"{target_note}{exclude_note}"
|
| 1141 |
+
)
|
| 1142 |
+
return lora_modules
|
| 1143 |
+
|
| 1144 |
+
|
| 1145 |
+
def merge_lora_adapters(model: torch.nn.Module) -> None:
|
| 1146 |
+
lora_entries = [
|
| 1147 |
+
(name, module)
|
| 1148 |
+
for name, module in model.named_modules()
|
| 1149 |
+
if isinstance(module, LoRALinear)
|
| 1150 |
+
]
|
| 1151 |
+
for name, module in lora_entries:
|
| 1152 |
+
module.merge()
|
| 1153 |
+
resolved = _resolve_parent_module(model, name)
|
| 1154 |
+
if resolved is None:
|
| 1155 |
+
continue
|
| 1156 |
+
parent, attr = resolved
|
| 1157 |
+
_set_child_module(parent, attr, module.base)
|
| 1158 |
+
|
| 1159 |
+
|
| 1160 |
+
def set_lora_enabled(lora_modules: List[LoRALinear], enabled: bool) -> None:
|
| 1161 |
+
for module in lora_modules:
|
| 1162 |
+
module.enabled = enabled
|
| 1163 |
+
|
| 1164 |
+
|
| 1165 |
+
def lora_ce_finetune(
|
| 1166 |
+
model: torch.nn.Module,
|
| 1167 |
+
dataloader,
|
| 1168 |
+
eval_tokenizer,
|
| 1169 |
+
eval_datasets: List[str],
|
| 1170 |
+
eval_configs: List[Optional[str]],
|
| 1171 |
+
eval_history: List[Dict[str, object]],
|
| 1172 |
+
args: argparse.Namespace,
|
| 1173 |
+
eval_dataloaders: Optional[Dict[str, object]] = None,
|
| 1174 |
+
progressive_cycle: Optional[int] = None,
|
| 1175 |
+
progressive_total: Optional[int] = None,
|
| 1176 |
+
) -> None:
|
| 1177 |
+
total_epochs = float(args.lora_epochs)
|
| 1178 |
+
if total_epochs <= 0:
|
| 1179 |
+
return
|
| 1180 |
+
|
| 1181 |
+
use_kl = bool(getattr(args, "lora_kl_enabled", False))
|
| 1182 |
+
kl_weight = float(getattr(args, "lora_kl_weight", 0.0))
|
| 1183 |
+
kl_temp = float(getattr(args, "lora_kl_temp", 1.0))
|
| 1184 |
+
if use_kl:
|
| 1185 |
+
if kl_weight < 0.0:
|
| 1186 |
+
raise SystemExit("--lora_kl_weight must be >= 0")
|
| 1187 |
+
if kl_temp <= 0.0:
|
| 1188 |
+
raise SystemExit("--lora_kl_temp must be > 0")
|
| 1189 |
+
if kl_weight == 0.0:
|
| 1190 |
+
use_kl = False
|
| 1191 |
+
|
| 1192 |
+
lora_modules = apply_lora_adapters(model, args)
|
| 1193 |
+
if not lora_modules:
|
| 1194 |
+
return
|
| 1195 |
+
|
| 1196 |
+
model.train()
|
| 1197 |
+
|
| 1198 |
+
lora_params = []
|
| 1199 |
+
for module in lora_modules:
|
| 1200 |
+
lora_params.extend(module.lora_parameters())
|
| 1201 |
+
|
| 1202 |
+
optimizer = torch.optim.AdamW(
|
| 1203 |
+
lora_params,
|
| 1204 |
+
lr=args.lora_lr,
|
| 1205 |
+
weight_decay=args.lora_weight_decay,
|
| 1206 |
+
)
|
| 1207 |
+
|
| 1208 |
+
device_type = torch.device(args.device).type
|
| 1209 |
+
amp_dtype = None
|
| 1210 |
+
if args.dtype == "float16":
|
| 1211 |
+
amp_dtype = torch.float16
|
| 1212 |
+
elif args.dtype == "bfloat16":
|
| 1213 |
+
amp_dtype = torch.bfloat16
|
| 1214 |
+
use_amp = amp_dtype is not None and device_type == "cuda"
|
| 1215 |
+
use_scaler = use_amp and amp_dtype == torch.float16
|
| 1216 |
+
scaler = torch.cuda.amp.GradScaler() if use_scaler else None
|
| 1217 |
+
|
| 1218 |
+
full_epochs = int(total_epochs)
|
| 1219 |
+
fractional = total_epochs - full_epochs
|
| 1220 |
+
if fractional < 1e-8:
|
| 1221 |
+
fractional = 0.0
|
| 1222 |
+
|
| 1223 |
+
epoch_plan = [(epoch_idx, None) for epoch_idx in range(full_epochs)]
|
| 1224 |
+
if fractional > 0:
|
| 1225 |
+
try:
|
| 1226 |
+
batches_per_epoch = len(dataloader)
|
| 1227 |
+
except TypeError as exc:
|
| 1228 |
+
raise SystemExit(
|
| 1229 |
+
"Fractional lora epochs require a dataloader with finite length."
|
| 1230 |
+
) from exc
|
| 1231 |
+
if batches_per_epoch > 0:
|
| 1232 |
+
frac_batches = int(round(fractional * batches_per_epoch))
|
| 1233 |
+
if frac_batches <= 0:
|
| 1234 |
+
frac_batches = 1
|
| 1235 |
+
epoch_plan.append((full_epochs, frac_batches))
|
| 1236 |
+
|
| 1237 |
+
step = 0
|
| 1238 |
+
for epoch_idx, max_batches in epoch_plan:
|
| 1239 |
+
if max_batches is None:
|
| 1240 |
+
epoch_iter = dataloader
|
| 1241 |
+
else:
|
| 1242 |
+
epoch_iter = itertools.islice(dataloader, max_batches)
|
| 1243 |
+
iterator = epoch_iter
|
| 1244 |
+
if tqdm is not None and _tqdm_enabled():
|
| 1245 |
+
if progressive_cycle is not None:
|
| 1246 |
+
if progressive_total is not None:
|
| 1247 |
+
desc = (
|
| 1248 |
+
f"LoRA (cycle {progressive_cycle}/{progressive_total}, "
|
| 1249 |
+
f"epoch {epoch_idx+1})"
|
| 1250 |
+
)
|
| 1251 |
+
else:
|
| 1252 |
+
desc = f"LoRA (cycle {progressive_cycle}, epoch {epoch_idx+1})"
|
| 1253 |
+
else:
|
| 1254 |
+
desc = f"LoRA (epoch {epoch_idx+1})"
|
| 1255 |
+
iterator = tqdm(
|
| 1256 |
+
epoch_iter,
|
| 1257 |
+
desc=desc,
|
| 1258 |
+
unit="batch",
|
| 1259 |
+
total=max_batches,
|
| 1260 |
+
)
|
| 1261 |
+
for batch in iterator:
|
| 1262 |
+
input_ids = batch[0].to(args.device)
|
| 1263 |
+
attention_mask = batch[1].to(args.device)
|
| 1264 |
+
autocast_ctx = (
|
| 1265 |
+
torch.autocast(device_type=device_type, dtype=amp_dtype)
|
| 1266 |
+
if use_amp
|
| 1267 |
+
else nullcontext()
|
| 1268 |
+
)
|
| 1269 |
+
with autocast_ctx:
|
| 1270 |
+
outputs = model(
|
| 1271 |
+
input_ids=input_ids,
|
| 1272 |
+
attention_mask=attention_mask,
|
| 1273 |
+
use_cache=False,
|
| 1274 |
+
)
|
| 1275 |
+
logits = outputs.logits
|
| 1276 |
+
shift_logits = logits[:, :-1, :].contiguous()
|
| 1277 |
+
shift_labels = input_ids[:, 1:].contiguous()
|
| 1278 |
+
shift_mask = attention_mask[:, 1:].contiguous()
|
| 1279 |
+
ce_flat = F.cross_entropy(
|
| 1280 |
+
shift_logits.view(-1, shift_logits.size(-1)),
|
| 1281 |
+
shift_labels.view(-1),
|
| 1282 |
+
reduction="none",
|
| 1283 |
+
)
|
| 1284 |
+
ce_denom = shift_mask.sum()
|
| 1285 |
+
if ce_denom.item() == 0:
|
| 1286 |
+
continue
|
| 1287 |
+
ce_loss = (
|
| 1288 |
+
ce_flat * shift_mask.view(-1).to(ce_flat.dtype)
|
| 1289 |
+
).sum() / ce_denom
|
| 1290 |
+
kl_loss = None
|
| 1291 |
+
if use_kl:
|
| 1292 |
+
set_lora_enabled(lora_modules, False)
|
| 1293 |
+
with torch.no_grad():
|
| 1294 |
+
base_outputs = model(
|
| 1295 |
+
input_ids=input_ids,
|
| 1296 |
+
attention_mask=attention_mask,
|
| 1297 |
+
use_cache=False,
|
| 1298 |
+
)
|
| 1299 |
+
base_logits = base_outputs.logits
|
| 1300 |
+
set_lora_enabled(lora_modules, True)
|
| 1301 |
+
if base_logits.device != shift_logits.device:
|
| 1302 |
+
base_logits = base_logits.to(shift_logits.device)
|
| 1303 |
+
shift_base_logits = base_logits[:, :-1, :].contiguous()
|
| 1304 |
+
log_p_pre = F.log_softmax(shift_base_logits / kl_temp, dim=-1)
|
| 1305 |
+
log_p_post = F.log_softmax(shift_logits / kl_temp, dim=-1)
|
| 1306 |
+
p_pre = log_p_pre.exp()
|
| 1307 |
+
kl_flat = (p_pre * (log_p_pre - log_p_post)).sum(dim=-1)
|
| 1308 |
+
kl_loss = (
|
| 1309 |
+
kl_flat * shift_mask.to(kl_flat.dtype)
|
| 1310 |
+
).sum() / ce_denom
|
| 1311 |
+
|
| 1312 |
+
total_loss = ce_loss
|
| 1313 |
+
if kl_loss is not None:
|
| 1314 |
+
total_loss = total_loss + (kl_weight * (kl_temp ** 2) * kl_loss)
|
| 1315 |
+
|
| 1316 |
+
if args.lora_grad_accum_steps > 1:
|
| 1317 |
+
total_loss = total_loss / args.lora_grad_accum_steps
|
| 1318 |
+
if use_scaler:
|
| 1319 |
+
scaler.scale(total_loss).backward()
|
| 1320 |
+
else:
|
| 1321 |
+
total_loss.backward()
|
| 1322 |
+
|
| 1323 |
+
if (step + 1) % args.lora_grad_accum_steps == 0:
|
| 1324 |
+
if args.lora_max_grad_norm is not None:
|
| 1325 |
+
if use_scaler:
|
| 1326 |
+
scaler.unscale_(optimizer)
|
| 1327 |
+
torch.nn.utils.clip_grad_norm_(
|
| 1328 |
+
lora_params,
|
| 1329 |
+
args.lora_max_grad_norm,
|
| 1330 |
+
)
|
| 1331 |
+
if use_scaler:
|
| 1332 |
+
scaler.step(optimizer)
|
| 1333 |
+
scaler.update()
|
| 1334 |
+
else:
|
| 1335 |
+
optimizer.step()
|
| 1336 |
+
optimizer.zero_grad(set_to_none=True)
|
| 1337 |
+
|
| 1338 |
+
if args.lora_eval_every and (step + 1) % args.lora_eval_every == 0:
|
| 1339 |
+
prev_mode = model.training
|
| 1340 |
+
model.eval()
|
| 1341 |
+
eval_device = args.eval_device or args.device
|
| 1342 |
+
if eval_dataloaders is not None:
|
| 1343 |
+
results = ppl_eval.evaluate_ppl_dataloaders(
|
| 1344 |
+
model,
|
| 1345 |
+
eval_dataloaders,
|
| 1346 |
+
eval_device,
|
| 1347 |
+
max_batches=args.lora_eval_max_batches,
|
| 1348 |
+
)
|
| 1349 |
+
else:
|
| 1350 |
+
results = ppl_eval.evaluate_ppl_datasets(
|
| 1351 |
+
model,
|
| 1352 |
+
eval_tokenizer,
|
| 1353 |
+
datasets=eval_datasets,
|
| 1354 |
+
configs=eval_configs,
|
| 1355 |
+
split=args.eval_split,
|
| 1356 |
+
text_field=args.eval_text_field,
|
| 1357 |
+
num_samples=args.eval_num_samples,
|
| 1358 |
+
seq_len=args.eval_seq_len,
|
| 1359 |
+
batch_size=args.eval_batch_size or args.batch_size,
|
| 1360 |
+
device=eval_device,
|
| 1361 |
+
seed=args.seed,
|
| 1362 |
+
shuffle=False,
|
| 1363 |
+
model_family=args.eval_model_family,
|
| 1364 |
+
add_bos=args.eval_add_bos,
|
| 1365 |
+
max_batches=args.lora_eval_max_batches,
|
| 1366 |
+
cache_dir=args.eval_cache_dir,
|
| 1367 |
+
num_workers=args.eval_num_workers,
|
| 1368 |
+
)
|
| 1369 |
+
eval_history.append({"step": step + 1, "ppl": results})
|
| 1370 |
+
print(f"[lora] eval step={step+1}: {results}")
|
| 1371 |
+
if prev_mode:
|
| 1372 |
+
model.train()
|
| 1373 |
+
|
| 1374 |
+
if args.lora_log_steps and (
|
| 1375 |
+
step == 0 or (step + 1) % args.lora_log_steps == 0
|
| 1376 |
+
):
|
| 1377 |
+
log_parts = [f"loss={total_loss.item():.6f}"]
|
| 1378 |
+
if kl_loss is not None:
|
| 1379 |
+
log_parts.append(f"kl={kl_loss.item():.6f}")
|
| 1380 |
+
print(
|
| 1381 |
+
f"[lora] epoch={epoch_idx+1} step={step+1} "
|
| 1382 |
+
+ " ".join(log_parts)
|
| 1383 |
+
)
|
| 1384 |
+
step += 1
|
| 1385 |
+
|
| 1386 |
+
merge_lora_adapters(model)
|
| 1387 |
+
|
| 1388 |
+
|
| 1389 |
+
def _masked_kl(
|
| 1390 |
+
logits_p: torch.Tensor,
|
| 1391 |
+
logits_q: torch.Tensor,
|
| 1392 |
+
attention_mask: torch.Tensor,
|
| 1393 |
+
temp: float,
|
| 1394 |
+
detach_p: bool = True,
|
| 1395 |
+
) -> Optional[torch.Tensor]:
|
| 1396 |
+
shift_mask = attention_mask[:, 1:].contiguous()
|
| 1397 |
+
denom = shift_mask.sum()
|
| 1398 |
+
if denom.item() == 0:
|
| 1399 |
+
return None
|
| 1400 |
+
|
| 1401 |
+
p = logits_p[:, :-1, :].contiguous()
|
| 1402 |
+
q = logits_q[:, :-1, :].contiguous()
|
| 1403 |
+
if p.device != q.device:
|
| 1404 |
+
p = p.to(q.device)
|
| 1405 |
+
|
| 1406 |
+
# Keep dtype to avoid blowing up memory on large vocab models.
|
| 1407 |
+
log_p = F.log_softmax(p / temp, dim=-1)
|
| 1408 |
+
log_q = F.log_softmax(q / temp, dim=-1)
|
| 1409 |
+
if detach_p:
|
| 1410 |
+
log_p = log_p.detach()
|
| 1411 |
+
p_probs = log_p.exp()
|
| 1412 |
+
kl_flat = (p_probs * (log_p - log_q)).sum(dim=-1)
|
| 1413 |
+
return (kl_flat * shift_mask.to(kl_flat.dtype)).sum() / denom
|
| 1414 |
+
|
| 1415 |
+
|
| 1416 |
+
def _extract_hidden_tensor(output: object) -> Optional[torch.Tensor]:
|
| 1417 |
+
if isinstance(output, torch.Tensor):
|
| 1418 |
+
return output
|
| 1419 |
+
if isinstance(output, (tuple, list)) and output:
|
| 1420 |
+
first = output[0]
|
| 1421 |
+
if isinstance(first, torch.Tensor):
|
| 1422 |
+
return first
|
| 1423 |
+
return None
|
| 1424 |
+
|
| 1425 |
+
|
| 1426 |
+
def _grad_l2_norm(grads: List[Optional[torch.Tensor]]) -> float:
|
| 1427 |
+
total = 0.0
|
| 1428 |
+
for grad in grads:
|
| 1429 |
+
if grad is None:
|
| 1430 |
+
continue
|
| 1431 |
+
total += float(grad.detach().float().pow(2).sum().item())
|
| 1432 |
+
if total <= 0.0:
|
| 1433 |
+
return 0.0
|
| 1434 |
+
return float(math.sqrt(total))
|
| 1435 |
+
|
| 1436 |
+
|
| 1437 |
+
def _register_forward_pre_hook_with_optional_kwargs(layer, hook):
|
| 1438 |
+
try:
|
| 1439 |
+
handle = layer.register_forward_pre_hook(hook, with_kwargs=True)
|
| 1440 |
+
return handle
|
| 1441 |
+
except TypeError:
|
| 1442 |
+
def wrapper(module, inputs):
|
| 1443 |
+
return hook(module, inputs, None)
|
| 1444 |
+
|
| 1445 |
+
return layer.register_forward_pre_hook(wrapper)
|
| 1446 |
+
|
| 1447 |
+
|
| 1448 |
+
def commutator_precondition(
|
| 1449 |
+
student_model: torch.nn.Module,
|
| 1450 |
+
student_layers: List[torch.nn.Module],
|
| 1451 |
+
teacher_model: torch.nn.Module,
|
| 1452 |
+
dataloader,
|
| 1453 |
+
dwce_scores: Optional[List[float]],
|
| 1454 |
+
args: argparse.Namespace,
|
| 1455 |
+
exclude_pairs: Optional[Set[int]] = None,
|
| 1456 |
+
progressive_cycle: Optional[int] = None,
|
| 1457 |
+
progressive_total: Optional[int] = None,
|
| 1458 |
+
) -> Dict[str, object]:
|
| 1459 |
+
"""Run commutator-style preconditioning before pair fusion.
|
| 1460 |
+
|
| 1461 |
+
Objective on each sampled pair i:
|
| 1462 |
+
L = T^2 * KL(p_teacher || p_student) + mu * L_interaction(i)
|
| 1463 |
+
|
| 1464 |
+
Interaction loss is computed locally on block (i+1):
|
| 1465 |
+
r1 = B_{i+1}(h_{i+1}) - h_{i+1}
|
| 1466 |
+
r0 = B_{i+1}(h_i) - h_i
|
| 1467 |
+
L_interaction = ||r1-r0||^2 (or relative form).
|
| 1468 |
+
"""
|
| 1469 |
+
if not bool(getattr(args, "comm_enabled", False)):
|
| 1470 |
+
return {"enabled": False}
|
| 1471 |
+
if not student_layers or len(student_layers) < 2:
|
| 1472 |
+
return {"enabled": False, "reason": "need_at_least_2_layers"}
|
| 1473 |
+
|
| 1474 |
+
temp = float(getattr(args, "comm_temp", 2.0))
|
| 1475 |
+
steps_ratio = float(getattr(args, "comm_steps_ratio", 0.1))
|
| 1476 |
+
lr_scale = float(getattr(args, "comm_lr_scale", 0.1))
|
| 1477 |
+
sample_eta = float(getattr(args, "comm_sample_eta", 0.5))
|
| 1478 |
+
sample_dwce_scale = float(getattr(args, "comm_sample_dwce_scale", 1.0))
|
| 1479 |
+
top_k = int(getattr(args, "comm_topk", 1))
|
| 1480 |
+
interaction_mode = str(getattr(args, "comm_interaction_mode", "relative")).strip().lower()
|
| 1481 |
+
interaction_eps = float(getattr(args, "comm_interaction_eps", 1e-8))
|
| 1482 |
+
mu_cfg = getattr(args, "comm_mu", None)
|
| 1483 |
+
mu_auto = bool(getattr(args, "comm_mu_auto", False))
|
| 1484 |
+
mu_auto_rho = float(getattr(args, "comm_mu_auto_rho", 0.1))
|
| 1485 |
+
mu_auto_eps = float(getattr(args, "comm_mu_auto_eps", 1e-8))
|
| 1486 |
+
comm_train_mode = str(getattr(args, "comm_train_mode", "lora")).strip().lower()
|
| 1487 |
+
log_steps = int(getattr(args, "comm_log_steps", 50))
|
| 1488 |
+
|
| 1489 |
+
if temp <= 0.0:
|
| 1490 |
+
raise SystemExit("--comm_temp must be > 0")
|
| 1491 |
+
if steps_ratio < 0.0:
|
| 1492 |
+
raise SystemExit("--comm_steps_ratio must be >= 0")
|
| 1493 |
+
if lr_scale <= 0.0:
|
| 1494 |
+
raise SystemExit("--comm_lr_scale must be > 0")
|
| 1495 |
+
if not (0.0 <= sample_eta <= 1.0):
|
| 1496 |
+
raise SystemExit("--comm_sample_eta must be in [0, 1]")
|
| 1497 |
+
if top_k <= 0:
|
| 1498 |
+
raise SystemExit("--comm_topk must be >= 1")
|
| 1499 |
+
if interaction_mode not in {"mse", "relative"}:
|
| 1500 |
+
raise SystemExit("--comm_interaction_mode must be one of: mse, relative")
|
| 1501 |
+
if comm_train_mode not in {"lora", "full"}:
|
| 1502 |
+
raise SystemExit("--comm_train_mode must be one of: lora, full")
|
| 1503 |
+
if interaction_eps <= 0.0:
|
| 1504 |
+
raise SystemExit("--comm_interaction_eps must be > 0")
|
| 1505 |
+
if mu_auto_rho < 0.0:
|
| 1506 |
+
raise SystemExit("--comm_mu_auto_rho must be >= 0")
|
| 1507 |
+
if mu_auto_eps <= 0.0:
|
| 1508 |
+
raise SystemExit("--comm_mu_auto_eps must be > 0")
|
| 1509 |
+
|
| 1510 |
+
if mu_cfg is None:
|
| 1511 |
+
base_mu = 0.5 if interaction_mode == "relative" else 0.1
|
| 1512 |
+
else:
|
| 1513 |
+
base_mu = float(mu_cfg)
|
| 1514 |
+
if base_mu < 0.0:
|
| 1515 |
+
raise SystemExit("--comm_mu must be >= 0")
|
| 1516 |
+
|
| 1517 |
+
distill_epochs = float(getattr(args, "distill_epochs", 1.0))
|
| 1518 |
+
if distill_epochs <= 0.0:
|
| 1519 |
+
distill_epochs = 1.0
|
| 1520 |
+
grad_accum = int(getattr(args, "distill_grad_accum_steps", 1))
|
| 1521 |
+
if grad_accum <= 0:
|
| 1522 |
+
grad_accum = 1
|
| 1523 |
+
|
| 1524 |
+
try:
|
| 1525 |
+
batches_per_epoch = len(dataloader)
|
| 1526 |
+
except TypeError as exc:
|
| 1527 |
+
raise SystemExit(
|
| 1528 |
+
"Commutator preconditioning requires a finite-length distillation dataloader."
|
| 1529 |
+
) from exc
|
| 1530 |
+
if batches_per_epoch <= 0:
|
| 1531 |
+
return {"enabled": False, "reason": "empty_dataloader"}
|
| 1532 |
+
|
| 1533 |
+
full_epochs = int(distill_epochs)
|
| 1534 |
+
fractional = distill_epochs - full_epochs
|
| 1535 |
+
if fractional < 1e-8:
|
| 1536 |
+
fractional = 0.0
|
| 1537 |
+
total_batches = full_epochs * batches_per_epoch
|
| 1538 |
+
if fractional > 0.0:
|
| 1539 |
+
frac_batches = int(round(fractional * batches_per_epoch))
|
| 1540 |
+
if frac_batches <= 0:
|
| 1541 |
+
frac_batches = 1
|
| 1542 |
+
total_batches += frac_batches
|
| 1543 |
+
|
| 1544 |
+
distill_opt_steps = int(math.ceil(total_batches / float(grad_accum)))
|
| 1545 |
+
target_opt_steps = int(round(steps_ratio * distill_opt_steps))
|
| 1546 |
+
if target_opt_steps <= 0:
|
| 1547 |
+
target_opt_steps = 1
|
| 1548 |
+
|
| 1549 |
+
num_pairs = max(len(student_layers) - 1, 0)
|
| 1550 |
+
exclude_set = {
|
| 1551 |
+
int(idx)
|
| 1552 |
+
for idx in (exclude_pairs or set())
|
| 1553 |
+
if isinstance(idx, int) and 0 <= int(idx) < num_pairs
|
| 1554 |
+
}
|
| 1555 |
+
allowed_pairs = [i for i in range(num_pairs) if i not in exclude_set]
|
| 1556 |
+
if not allowed_pairs:
|
| 1557 |
+
return {"enabled": False, "reason": "all_pairs_excluded"}
|
| 1558 |
+
|
| 1559 |
+
ranked_pairs = list(allowed_pairs)
|
| 1560 |
+
if dwce_scores is not None and len(dwce_scores) >= num_pairs:
|
| 1561 |
+
finite_pairs = []
|
| 1562 |
+
for idx in allowed_pairs:
|
| 1563 |
+
value = float(dwce_scores[idx])
|
| 1564 |
+
if math.isfinite(value):
|
| 1565 |
+
finite_pairs.append(idx)
|
| 1566 |
+
if finite_pairs:
|
| 1567 |
+
ranked_pairs = sorted(finite_pairs, key=lambda i: float(dwce_scores[i]))
|
| 1568 |
+
else:
|
| 1569 |
+
ranked_pairs = list(allowed_pairs)
|
| 1570 |
+
candidate_pairs = ranked_pairs[: min(top_k, len(ranked_pairs))]
|
| 1571 |
+
if not candidate_pairs:
|
| 1572 |
+
return {"enabled": False, "reason": "no_candidate_pairs"}
|
| 1573 |
+
|
| 1574 |
+
layer_trainable_params: List[List[torch.nn.Parameter]] = []
|
| 1575 |
+
trainable_params: List[torch.nn.Parameter] = []
|
| 1576 |
+
if comm_train_mode == "lora":
|
| 1577 |
+
# LoRA comm preconditioning: update LoRA adapters on receiver layer (i+1).
|
| 1578 |
+
lora_modules = apply_lora_adapters(student_model, args)
|
| 1579 |
+
if not lora_modules:
|
| 1580 |
+
return {"enabled": False, "reason": "no_lora_modules"}
|
| 1581 |
+
|
| 1582 |
+
trainable_seen: Set[int] = set()
|
| 1583 |
+
for module in lora_modules:
|
| 1584 |
+
for param in module.lora_parameters():
|
| 1585 |
+
pid = id(param)
|
| 1586 |
+
if pid in trainable_seen:
|
| 1587 |
+
continue
|
| 1588 |
+
trainable_seen.add(pid)
|
| 1589 |
+
trainable_params.append(param)
|
| 1590 |
+
|
| 1591 |
+
for layer in student_layers:
|
| 1592 |
+
seen: Set[int] = set()
|
| 1593 |
+
params: List[torch.nn.Parameter] = []
|
| 1594 |
+
for module in layer.modules():
|
| 1595 |
+
if not isinstance(module, LoRALinear):
|
| 1596 |
+
continue
|
| 1597 |
+
for param in module.lora_parameters():
|
| 1598 |
+
pid = id(param)
|
| 1599 |
+
if pid in seen:
|
| 1600 |
+
continue
|
| 1601 |
+
seen.add(pid)
|
| 1602 |
+
params.append(param)
|
| 1603 |
+
layer_trainable_params.append(params)
|
| 1604 |
+
else:
|
| 1605 |
+
# Full-weight comm preconditioning: update full receiver-layer weights.
|
| 1606 |
+
for layer in student_layers:
|
| 1607 |
+
seen: Set[int] = set()
|
| 1608 |
+
params: List[torch.nn.Parameter] = []
|
| 1609 |
+
for param in layer.parameters():
|
| 1610 |
+
if not isinstance(param, torch.nn.Parameter):
|
| 1611 |
+
continue
|
| 1612 |
+
pid = id(param)
|
| 1613 |
+
if pid in seen:
|
| 1614 |
+
continue
|
| 1615 |
+
seen.add(pid)
|
| 1616 |
+
params.append(param)
|
| 1617 |
+
layer_trainable_params.append(params)
|
| 1618 |
+
|
| 1619 |
+
candidate_pairs = [
|
| 1620 |
+
i
|
| 1621 |
+
for i in candidate_pairs
|
| 1622 |
+
if (i + 1) < len(layer_trainable_params) and layer_trainable_params[i + 1]
|
| 1623 |
+
]
|
| 1624 |
+
if not candidate_pairs:
|
| 1625 |
+
if comm_train_mode == "lora":
|
| 1626 |
+
merge_lora_adapters(student_model)
|
| 1627 |
+
return {"enabled": False, "reason": "no_trainable_receiver_layers"}
|
| 1628 |
+
|
| 1629 |
+
if comm_train_mode == "full":
|
| 1630 |
+
trainable_seen: Set[int] = set()
|
| 1631 |
+
for pair_idx in candidate_pairs:
|
| 1632 |
+
for param in layer_trainable_params[pair_idx + 1]:
|
| 1633 |
+
pid = id(param)
|
| 1634 |
+
if pid in trainable_seen:
|
| 1635 |
+
continue
|
| 1636 |
+
trainable_seen.add(pid)
|
| 1637 |
+
trainable_params.append(param)
|
| 1638 |
+
if not trainable_params:
|
| 1639 |
+
return {"enabled": False, "reason": "no_trainable_receiver_layers"}
|
| 1640 |
+
|
| 1641 |
+
# Freeze non-comm params to reduce grad memory.
|
| 1642 |
+
for param in student_model.parameters():
|
| 1643 |
+
param.requires_grad_(False)
|
| 1644 |
+
for param in trainable_params:
|
| 1645 |
+
param.requires_grad_(True)
|
| 1646 |
+
|
| 1647 |
+
if not trainable_params:
|
| 1648 |
+
if comm_train_mode == "lora":
|
| 1649 |
+
merge_lora_adapters(student_model)
|
| 1650 |
+
return {"enabled": False, "reason": "no_trainable_params"}
|
| 1651 |
+
|
| 1652 |
+
candidate_probs = torch.full(
|
| 1653 |
+
(len(candidate_pairs),),
|
| 1654 |
+
1.0 / float(len(candidate_pairs)),
|
| 1655 |
+
dtype=torch.float32,
|
| 1656 |
+
)
|
| 1657 |
+
if dwce_scores is not None and len(dwce_scores) >= num_pairs and sample_eta > 0.0:
|
| 1658 |
+
score_vec = torch.tensor(
|
| 1659 |
+
[float(dwce_scores[i]) for i in candidate_pairs], dtype=torch.float32
|
| 1660 |
+
)
|
| 1661 |
+
score_vec = torch.nan_to_num(score_vec, nan=1e9, posinf=1e9, neginf=-1e9)
|
| 1662 |
+
biased = torch.softmax(-float(sample_dwce_scale) * score_vec, dim=0)
|
| 1663 |
+
candidate_probs = (1.0 - sample_eta) * candidate_probs + sample_eta * biased
|
| 1664 |
+
candidate_probs = candidate_probs / candidate_probs.sum()
|
| 1665 |
+
|
| 1666 |
+
probs_by_pair = [0.0 for _ in range(num_pairs)]
|
| 1667 |
+
for pos, pair_idx in enumerate(candidate_pairs):
|
| 1668 |
+
probs_by_pair[pair_idx] = float(candidate_probs[pos].item())
|
| 1669 |
+
|
| 1670 |
+
lr = float(getattr(args, "distill_lr", 1e-4)) * lr_scale
|
| 1671 |
+
optimizer = torch.optim.AdamW(
|
| 1672 |
+
trainable_params,
|
| 1673 |
+
lr=lr,
|
| 1674 |
+
weight_decay=float(getattr(args, "distill_weight_decay", 0.0)),
|
| 1675 |
+
)
|
| 1676 |
+
|
| 1677 |
+
device_type = torch.device(args.device).type
|
| 1678 |
+
amp_dtype = None
|
| 1679 |
+
if args.dtype == "float16":
|
| 1680 |
+
amp_dtype = torch.float16
|
| 1681 |
+
elif args.dtype == "bfloat16":
|
| 1682 |
+
amp_dtype = torch.bfloat16
|
| 1683 |
+
use_amp = amp_dtype is not None and device_type == "cuda"
|
| 1684 |
+
use_scaler = use_amp and amp_dtype == torch.float16
|
| 1685 |
+
scaler = torch.cuda.amp.GradScaler() if use_scaler else None
|
| 1686 |
+
|
| 1687 |
+
teacher_device = next(teacher_model.parameters()).device
|
| 1688 |
+
teacher_model.eval()
|
| 1689 |
+
student_model.train()
|
| 1690 |
+
|
| 1691 |
+
gen = torch.Generator(device="cpu")
|
| 1692 |
+
seed = int(getattr(args, "seed", 0))
|
| 1693 |
+
if progressive_cycle is not None:
|
| 1694 |
+
seed += int(progressive_cycle) * 100003
|
| 1695 |
+
gen.manual_seed(seed)
|
| 1696 |
+
|
| 1697 |
+
opt_step = 0
|
| 1698 |
+
total_loss_sum = 0.0
|
| 1699 |
+
anchor_sum = 0.0
|
| 1700 |
+
interaction_sum = 0.0
|
| 1701 |
+
mu_sum = 0.0
|
| 1702 |
+
counted = 0
|
| 1703 |
+
pair_counts = [0 for _ in range(num_pairs)]
|
| 1704 |
+
|
| 1705 |
+
desc = "Comm"
|
| 1706 |
+
if progressive_cycle is not None:
|
| 1707 |
+
if progressive_total is not None:
|
| 1708 |
+
desc = f"Comm (cycle {progressive_cycle}/{progressive_total})"
|
| 1709 |
+
else:
|
| 1710 |
+
desc = f"Comm (cycle {progressive_cycle})"
|
| 1711 |
+
iterator = range(target_opt_steps)
|
| 1712 |
+
if tqdm is not None and _tqdm_enabled():
|
| 1713 |
+
iterator = tqdm(iterator, desc=desc, unit="step")
|
| 1714 |
+
|
| 1715 |
+
data_iter = iter(dataloader)
|
| 1716 |
+
autocast_ctx = (
|
| 1717 |
+
torch.autocast(device_type=device_type, dtype=amp_dtype)
|
| 1718 |
+
if use_amp
|
| 1719 |
+
else nullcontext()
|
| 1720 |
+
)
|
| 1721 |
+
|
| 1722 |
+
for _ in iterator:
|
| 1723 |
+
optimizer.zero_grad(set_to_none=True)
|
| 1724 |
+
accum_done = 0
|
| 1725 |
+
while accum_done < grad_accum:
|
| 1726 |
+
try:
|
| 1727 |
+
batch = next(data_iter)
|
| 1728 |
+
except StopIteration:
|
| 1729 |
+
data_iter = iter(dataloader)
|
| 1730 |
+
batch = next(data_iter)
|
| 1731 |
+
|
| 1732 |
+
input_ids = batch[0].to(args.device)
|
| 1733 |
+
attention_mask = batch[1].to(args.device)
|
| 1734 |
+
sampled_pos = int(torch.multinomial(candidate_probs, 1, generator=gen).item())
|
| 1735 |
+
pair_idx = int(candidate_pairs[sampled_pos])
|
| 1736 |
+
pair_counts[pair_idx] += 1
|
| 1737 |
+
|
| 1738 |
+
receiver_params = layer_trainable_params[pair_idx + 1]
|
| 1739 |
+
receiver_param_ids = {id(param) for param in receiver_params}
|
| 1740 |
+
|
| 1741 |
+
teacher_ids = input_ids.to(teacher_device)
|
| 1742 |
+
teacher_mask = attention_mask.to(teacher_device)
|
| 1743 |
+
with torch.no_grad(), autocast_ctx:
|
| 1744 |
+
teacher_outputs = teacher_model(
|
| 1745 |
+
input_ids=teacher_ids,
|
| 1746 |
+
attention_mask=teacher_mask,
|
| 1747 |
+
use_cache=False,
|
| 1748 |
+
)
|
| 1749 |
+
teacher_logits = teacher_outputs.logits
|
| 1750 |
+
|
| 1751 |
+
capture: Dict[str, object] = {
|
| 1752 |
+
"h_l": None,
|
| 1753 |
+
"h_lp1": None,
|
| 1754 |
+
"y1": None,
|
| 1755 |
+
"recv_args": None,
|
| 1756 |
+
"recv_kwargs": None,
|
| 1757 |
+
}
|
| 1758 |
+
|
| 1759 |
+
def _hook_l(_module, inputs, _output):
|
| 1760 |
+
if inputs and isinstance(inputs[0], torch.Tensor):
|
| 1761 |
+
capture["h_l"] = inputs[0]
|
| 1762 |
+
|
| 1763 |
+
def _hook_recv_pre(_module, inputs, kwargs):
|
| 1764 |
+
capture["recv_args"] = inputs
|
| 1765 |
+
capture["recv_kwargs"] = kwargs
|
| 1766 |
+
|
| 1767 |
+
def _hook_recv(_module, inputs, output):
|
| 1768 |
+
if inputs and isinstance(inputs[0], torch.Tensor):
|
| 1769 |
+
capture["h_lp1"] = inputs[0]
|
| 1770 |
+
capture["y1"] = _extract_hidden_tensor(output)
|
| 1771 |
+
|
| 1772 |
+
handles: List[object] = [
|
| 1773 |
+
student_layers[pair_idx].register_forward_hook(_hook_l),
|
| 1774 |
+
_register_forward_pre_hook_with_optional_kwargs(
|
| 1775 |
+
student_layers[pair_idx + 1], _hook_recv_pre
|
| 1776 |
+
),
|
| 1777 |
+
student_layers[pair_idx + 1].register_forward_hook(_hook_recv),
|
| 1778 |
+
]
|
| 1779 |
+
try:
|
| 1780 |
+
with autocast_ctx:
|
| 1781 |
+
student_outputs = student_model(
|
| 1782 |
+
input_ids=input_ids,
|
| 1783 |
+
attention_mask=attention_mask,
|
| 1784 |
+
use_cache=False,
|
| 1785 |
+
)
|
| 1786 |
+
student_logits = student_outputs.logits
|
| 1787 |
+
finally:
|
| 1788 |
+
for handle in handles:
|
| 1789 |
+
try:
|
| 1790 |
+
handle.remove()
|
| 1791 |
+
except Exception:
|
| 1792 |
+
pass
|
| 1793 |
+
|
| 1794 |
+
with autocast_ctx:
|
| 1795 |
+
anchor_kl = _masked_kl(
|
| 1796 |
+
teacher_logits,
|
| 1797 |
+
student_logits,
|
| 1798 |
+
attention_mask,
|
| 1799 |
+
temp=temp,
|
| 1800 |
+
detach_p=True,
|
| 1801 |
+
)
|
| 1802 |
+
if anchor_kl is None:
|
| 1803 |
+
continue
|
| 1804 |
+
anchor_loss = (temp ** 2) * anchor_kl
|
| 1805 |
+
|
| 1806 |
+
interaction_loss = None
|
| 1807 |
+
h_l = capture.get("h_l")
|
| 1808 |
+
h_lp1 = capture.get("h_lp1")
|
| 1809 |
+
y1 = capture.get("y1")
|
| 1810 |
+
recv_args = capture.get("recv_args")
|
| 1811 |
+
recv_kwargs = capture.get("recv_kwargs")
|
| 1812 |
+
if (
|
| 1813 |
+
isinstance(h_l, torch.Tensor)
|
| 1814 |
+
and isinstance(h_lp1, torch.Tensor)
|
| 1815 |
+
and isinstance(y1, torch.Tensor)
|
| 1816 |
+
and isinstance(recv_args, tuple)
|
| 1817 |
+
and len(recv_args) > 0
|
| 1818 |
+
and isinstance(recv_args[0], torch.Tensor)
|
| 1819 |
+
):
|
| 1820 |
+
call_args = list(recv_args)
|
| 1821 |
+
first_hidden = call_args[0]
|
| 1822 |
+
h_l_detached = h_l.detach().to(
|
| 1823 |
+
device=first_hidden.device,
|
| 1824 |
+
dtype=first_hidden.dtype,
|
| 1825 |
+
)
|
| 1826 |
+
call_args[0] = h_l_detached
|
| 1827 |
+
call_kwargs = dict(recv_kwargs) if isinstance(recv_kwargs, dict) else {}
|
| 1828 |
+
|
| 1829 |
+
y0_raw = student_layers[pair_idx + 1](*tuple(call_args), **call_kwargs)
|
| 1830 |
+
y0 = _extract_hidden_tensor(y0_raw)
|
| 1831 |
+
if isinstance(y0, torch.Tensor):
|
| 1832 |
+
if y0.device != y1.device:
|
| 1833 |
+
y0 = y0.to(y1.device)
|
| 1834 |
+
h_lp1_detached = h_lp1.detach().to(device=y1.device, dtype=y1.dtype)
|
| 1835 |
+
h_l_for_res = h_l.detach().to(device=y0.device, dtype=y0.dtype)
|
| 1836 |
+
r1 = y1 - h_lp1_detached
|
| 1837 |
+
r0 = y0 - h_l_for_res
|
| 1838 |
+
mask = attention_mask.to(dtype=r1.dtype)
|
| 1839 |
+
mask_sum = mask.sum()
|
| 1840 |
+
if mask_sum.item() > 0:
|
| 1841 |
+
if interaction_mode == "relative":
|
| 1842 |
+
num = (r1 - r0).float().pow(2).sum(dim=-1)
|
| 1843 |
+
den = r1.float().pow(2).sum(dim=-1) + float(interaction_eps)
|
| 1844 |
+
ratio = (num / den) * mask.to(num.dtype)
|
| 1845 |
+
interaction_loss = ratio.sum() / (mask_sum + 1e-8)
|
| 1846 |
+
else:
|
| 1847 |
+
denom = mask_sum * r1.size(-1)
|
| 1848 |
+
if denom.item() > 0:
|
| 1849 |
+
interaction_loss = (
|
| 1850 |
+
(r1 - r0).pow(2) * mask.unsqueeze(-1)
|
| 1851 |
+
).sum() / denom
|
| 1852 |
+
|
| 1853 |
+
mu_effective = float(base_mu)
|
| 1854 |
+
if (
|
| 1855 |
+
mu_auto
|
| 1856 |
+
and interaction_loss is not None
|
| 1857 |
+
and receiver_params
|
| 1858 |
+
and mu_auto_rho > 0.0
|
| 1859 |
+
):
|
| 1860 |
+
anchor_grads = torch.autograd.grad(
|
| 1861 |
+
anchor_loss,
|
| 1862 |
+
receiver_params,
|
| 1863 |
+
retain_graph=True,
|
| 1864 |
+
allow_unused=True,
|
| 1865 |
+
)
|
| 1866 |
+
interaction_grads = torch.autograd.grad(
|
| 1867 |
+
interaction_loss,
|
| 1868 |
+
receiver_params,
|
| 1869 |
+
retain_graph=True,
|
| 1870 |
+
allow_unused=True,
|
| 1871 |
+
)
|
| 1872 |
+
anchor_norm = _grad_l2_norm(list(anchor_grads))
|
| 1873 |
+
interaction_norm = _grad_l2_norm(list(interaction_grads))
|
| 1874 |
+
if interaction_norm > 0.0:
|
| 1875 |
+
mu_effective = float(
|
| 1876 |
+
mu_auto_rho
|
| 1877 |
+
* (anchor_norm / (interaction_norm + float(mu_auto_eps)))
|
| 1878 |
+
)
|
| 1879 |
+
else:
|
| 1880 |
+
mu_effective = float(base_mu)
|
| 1881 |
+
if not math.isfinite(mu_effective):
|
| 1882 |
+
mu_effective = float(base_mu)
|
| 1883 |
+
|
| 1884 |
+
total_loss = anchor_loss
|
| 1885 |
+
if interaction_loss is not None:
|
| 1886 |
+
total_loss = total_loss + (float(mu_effective) * interaction_loss)
|
| 1887 |
+
|
| 1888 |
+
if grad_accum > 1:
|
| 1889 |
+
total_loss = total_loss / float(grad_accum)
|
| 1890 |
+
|
| 1891 |
+
if use_scaler:
|
| 1892 |
+
scaler.scale(total_loss).backward()
|
| 1893 |
+
else:
|
| 1894 |
+
total_loss.backward()
|
| 1895 |
+
|
| 1896 |
+
# Only the sampled receiver layer updates on this micro-batch.
|
| 1897 |
+
for param in trainable_params:
|
| 1898 |
+
if id(param) in receiver_param_ids:
|
| 1899 |
+
continue
|
| 1900 |
+
if param.grad is not None:
|
| 1901 |
+
if comm_train_mode == "lora":
|
| 1902 |
+
param.grad.zero_()
|
| 1903 |
+
else:
|
| 1904 |
+
param.grad = None
|
| 1905 |
+
|
| 1906 |
+
total_loss_sum += float(total_loss.detach().float().item())
|
| 1907 |
+
anchor_sum += float(anchor_loss.detach().float().item())
|
| 1908 |
+
if interaction_loss is not None:
|
| 1909 |
+
interaction_sum += float(interaction_loss.detach().float().item())
|
| 1910 |
+
mu_sum += float(mu_effective)
|
| 1911 |
+
counted += 1
|
| 1912 |
+
accum_done += 1
|
| 1913 |
+
|
| 1914 |
+
if args.distill_max_grad_norm is not None:
|
| 1915 |
+
if use_scaler:
|
| 1916 |
+
scaler.unscale_(optimizer)
|
| 1917 |
+
torch.nn.utils.clip_grad_norm_(
|
| 1918 |
+
trainable_params,
|
| 1919 |
+
float(args.distill_max_grad_norm),
|
| 1920 |
+
)
|
| 1921 |
+
|
| 1922 |
+
if use_scaler:
|
| 1923 |
+
scaler.step(optimizer)
|
| 1924 |
+
scaler.update()
|
| 1925 |
+
else:
|
| 1926 |
+
optimizer.step()
|
| 1927 |
+
|
| 1928 |
+
opt_step += 1
|
| 1929 |
+
if log_steps and (opt_step == 1 or opt_step % log_steps == 0):
|
| 1930 |
+
denom = max(counted, 1)
|
| 1931 |
+
print(
|
| 1932 |
+
f"[comm] step={opt_step}/{target_opt_steps} "
|
| 1933 |
+
f"loss={total_loss_sum/denom:.6f} "
|
| 1934 |
+
f"anchor={anchor_sum/denom:.6f} "
|
| 1935 |
+
f"int={interaction_sum/denom:.6f} "
|
| 1936 |
+
f"mu={mu_sum/denom:.6f}"
|
| 1937 |
+
)
|
| 1938 |
+
|
| 1939 |
+
if comm_train_mode == "lora":
|
| 1940 |
+
merge_lora_adapters(student_model)
|
| 1941 |
+
|
| 1942 |
+
stats: Dict[str, object] = {
|
| 1943 |
+
"enabled": True,
|
| 1944 |
+
"train_mode": comm_train_mode,
|
| 1945 |
+
"opt_steps": int(target_opt_steps),
|
| 1946 |
+
"grad_accum_steps": int(grad_accum),
|
| 1947 |
+
"lr": float(lr),
|
| 1948 |
+
"temp": float(temp),
|
| 1949 |
+
"steps_ratio": float(steps_ratio),
|
| 1950 |
+
"lr_scale": float(lr_scale),
|
| 1951 |
+
"interaction_mode": interaction_mode,
|
| 1952 |
+
"interaction_eps": float(interaction_eps),
|
| 1953 |
+
"mu": float(base_mu),
|
| 1954 |
+
"mu_auto": bool(mu_auto),
|
| 1955 |
+
"mu_auto_rho": float(mu_auto_rho),
|
| 1956 |
+
"mu_auto_eps": float(mu_auto_eps),
|
| 1957 |
+
"sample_eta": float(sample_eta),
|
| 1958 |
+
"sample_dwce_scale": float(sample_dwce_scale),
|
| 1959 |
+
"topk": int(top_k),
|
| 1960 |
+
"candidate_pairs": [int(i) for i in candidate_pairs],
|
| 1961 |
+
"trainable_params": int(sum(int(param.numel()) for param in trainable_params)),
|
| 1962 |
+
}
|
| 1963 |
+
total_samples = int(sum(pair_counts))
|
| 1964 |
+
probs_list = [float(x) for x in probs_by_pair]
|
| 1965 |
+
freqs = (
|
| 1966 |
+
[float(c) / float(total_samples) for c in pair_counts]
|
| 1967 |
+
if total_samples > 0
|
| 1968 |
+
else [0.0 for _ in pair_counts]
|
| 1969 |
+
)
|
| 1970 |
+
top_show = min(10, num_pairs)
|
| 1971 |
+
top_indices = sorted(range(num_pairs), key=lambda i: pair_counts[i], reverse=True)[:top_show]
|
| 1972 |
+
top_pairs = [
|
| 1973 |
+
{
|
| 1974 |
+
"pair": int(i),
|
| 1975 |
+
"count": int(pair_counts[i]),
|
| 1976 |
+
"freq": float(freqs[i]),
|
| 1977 |
+
"prob": float(probs_list[i]) if i < len(probs_list) else None,
|
| 1978 |
+
}
|
| 1979 |
+
for i in top_indices
|
| 1980 |
+
if pair_counts[i] > 0
|
| 1981 |
+
]
|
| 1982 |
+
stats["pair_selection"] = {
|
| 1983 |
+
"num_pairs": int(num_pairs),
|
| 1984 |
+
"excluded_pairs": sorted(exclude_set),
|
| 1985 |
+
"candidate_pairs": [int(i) for i in candidate_pairs],
|
| 1986 |
+
"total_samples": total_samples,
|
| 1987 |
+
"unique_pairs": int(sum(1 for c in pair_counts if c > 0)),
|
| 1988 |
+
"counts": [int(c) for c in pair_counts],
|
| 1989 |
+
"freqs": freqs,
|
| 1990 |
+
"probs": probs_list,
|
| 1991 |
+
"top_pairs": top_pairs,
|
| 1992 |
+
}
|
| 1993 |
+
|
| 1994 |
+
if total_samples > 0 and top_pairs:
|
| 1995 |
+
top_str = ", ".join(
|
| 1996 |
+
f"{entry['pair']}-{entry['pair'] + 1}: {entry['count']} "
|
| 1997 |
+
f"(obs={entry['freq']:.3f}, exp={entry['prob']:.3f})"
|
| 1998 |
+
for entry in top_pairs
|
| 1999 |
+
if entry.get("prob") is not None
|
| 2000 |
+
)
|
| 2001 |
+
if not top_str:
|
| 2002 |
+
top_str = ", ".join(
|
| 2003 |
+
f"{entry['pair']}-{entry['pair'] + 1}: {entry['count']} "
|
| 2004 |
+
f"(obs={entry['freq']:.3f})"
|
| 2005 |
+
for entry in top_pairs
|
| 2006 |
+
)
|
| 2007 |
+
print(
|
| 2008 |
+
f"[comm] Pair sampling stats: total={total_samples} "
|
| 2009 |
+
f"unique={stats['pair_selection']['unique_pairs']}/{num_pairs} "
|
| 2010 |
+
f"top={top_str}"
|
| 2011 |
+
)
|
| 2012 |
+
|
| 2013 |
+
if counted > 0:
|
| 2014 |
+
stats["avg_loss"] = float(total_loss_sum / float(counted))
|
| 2015 |
+
stats["avg_anchor"] = float(anchor_sum / float(counted))
|
| 2016 |
+
stats["avg_interaction"] = float(interaction_sum / float(counted))
|
| 2017 |
+
stats["avg_mu"] = float(mu_sum / float(counted))
|
| 2018 |
+
return stats
|
src/fuse_layers_model.py
ADDED
|
@@ -0,0 +1,595 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Model and layer helpers for fuse_layers."""
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
from typing import Dict, List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
except Exception: # pragma: no cover - optional dependency
|
| 12 |
+
tqdm = None
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _tqdm_enabled() -> bool:
|
| 16 |
+
value = os.environ.get("DISABLE_TQDM", os.environ.get("TQDM_DISABLE", "0"))
|
| 17 |
+
return value.strip().lower() not in {"1", "true", "yes", "on"}
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_dtype(dtype: str):
|
| 21 |
+
if dtype == "auto":
|
| 22 |
+
return None
|
| 23 |
+
if dtype == "float16":
|
| 24 |
+
return torch.float16
|
| 25 |
+
if dtype == "bfloat16":
|
| 26 |
+
return torch.bfloat16
|
| 27 |
+
return torch.float32
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def resolve_attr(root: object, path: str) -> Optional[object]:
|
| 31 |
+
cur = root
|
| 32 |
+
for part in path.split("."):
|
| 33 |
+
if not hasattr(cur, part):
|
| 34 |
+
return None
|
| 35 |
+
cur = getattr(cur, part)
|
| 36 |
+
return cur
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def resolve_attr_with_parent(root: object, path: str) -> Tuple[object, str, object]:
|
| 40 |
+
parts = path.split(".")
|
| 41 |
+
cur = root
|
| 42 |
+
for part in parts[:-1]:
|
| 43 |
+
if not hasattr(cur, part):
|
| 44 |
+
raise ValueError(f"'{path}' not found on model")
|
| 45 |
+
cur = getattr(cur, part)
|
| 46 |
+
name = parts[-1]
|
| 47 |
+
if not hasattr(cur, name):
|
| 48 |
+
raise ValueError(f"'{path}' not found on model")
|
| 49 |
+
return cur, name, getattr(cur, name)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def find_layer_container(model, layer_path: Optional[str]) -> Tuple[object, str, object]:
|
| 53 |
+
if layer_path:
|
| 54 |
+
parent, name, container = resolve_attr_with_parent(model, layer_path)
|
| 55 |
+
return parent, name, container
|
| 56 |
+
|
| 57 |
+
candidate_paths = [
|
| 58 |
+
"model.layers", # LLaMA, Mistral, Qwen2, Gemma
|
| 59 |
+
"model.decoder.layers", # OPT
|
| 60 |
+
"transformer.h", # GPT-2, GPT-J, Bloom, Falcon
|
| 61 |
+
"transformer.blocks", # MPT
|
| 62 |
+
"gpt_neox.layers", # GPT-NeoX
|
| 63 |
+
"layers", # fallback
|
| 64 |
+
]
|
| 65 |
+
for path in candidate_paths:
|
| 66 |
+
candidate = resolve_attr(model, path)
|
| 67 |
+
if candidate is None:
|
| 68 |
+
continue
|
| 69 |
+
try:
|
| 70 |
+
list(candidate)
|
| 71 |
+
except TypeError:
|
| 72 |
+
continue
|
| 73 |
+
parent, name, container = resolve_attr_with_parent(model, path)
|
| 74 |
+
return parent, name, container
|
| 75 |
+
|
| 76 |
+
raise ValueError(
|
| 77 |
+
"Could not locate transformer layers. Pass --layer_path explicitly."
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def find_attention_module(layer: torch.nn.Module) -> torch.nn.Module:
|
| 82 |
+
if hasattr(layer, "self_attn"):
|
| 83 |
+
return getattr(layer, "self_attn")
|
| 84 |
+
if hasattr(layer, "attn"):
|
| 85 |
+
return getattr(layer, "attn")
|
| 86 |
+
if hasattr(layer, "attention"):
|
| 87 |
+
return getattr(layer, "attention")
|
| 88 |
+
for _, module in layer.named_modules():
|
| 89 |
+
if all(
|
| 90 |
+
hasattr(module, attr) for attr in ("q_proj", "k_proj", "v_proj", "o_proj")
|
| 91 |
+
):
|
| 92 |
+
return module
|
| 93 |
+
raise ValueError("Could not find attention module with q_proj/k_proj/v_proj/o_proj")
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def find_mlp_module(layer: torch.nn.Module) -> torch.nn.Module:
|
| 97 |
+
if hasattr(layer, "mlp"):
|
| 98 |
+
return getattr(layer, "mlp")
|
| 99 |
+
for attr in ("feed_forward", "feedforward", "ffn", "ff"):
|
| 100 |
+
if hasattr(layer, attr):
|
| 101 |
+
return getattr(layer, attr)
|
| 102 |
+
for _, module in layer.named_modules():
|
| 103 |
+
if all(hasattr(module, attr) for attr in ("gate_proj", "up_proj", "down_proj")):
|
| 104 |
+
return module
|
| 105 |
+
if all(hasattr(module, attr) for attr in ("fc1", "fc2")):
|
| 106 |
+
return module
|
| 107 |
+
if all(
|
| 108 |
+
hasattr(module, attr)
|
| 109 |
+
for attr in ("dense_h_to_4h", "dense_4h_to_h")
|
| 110 |
+
):
|
| 111 |
+
return module
|
| 112 |
+
if all(hasattr(module, attr) for attr in ("w1", "w2")):
|
| 113 |
+
return module
|
| 114 |
+
raise ValueError("Could not find MLP/FFN module on layer")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def get_head_info(
|
| 118 |
+
attn: torch.nn.Module, hidden_size: int, config
|
| 119 |
+
) -> Tuple[int, int, int]:
|
| 120 |
+
num_heads = getattr(attn, "num_heads", None)
|
| 121 |
+
if num_heads is None:
|
| 122 |
+
num_heads = getattr(attn, "num_attention_heads", None)
|
| 123 |
+
if num_heads is None and config is not None:
|
| 124 |
+
num_heads = getattr(
|
| 125 |
+
config,
|
| 126 |
+
"num_attention_heads",
|
| 127 |
+
getattr(config, "num_heads", getattr(config, "n_head", None)),
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
num_key_value_heads = getattr(attn, "num_key_value_heads", None)
|
| 131 |
+
if num_key_value_heads is None:
|
| 132 |
+
num_key_value_heads = getattr(attn, "num_kv_heads", None)
|
| 133 |
+
if num_key_value_heads is None and config is not None:
|
| 134 |
+
num_key_value_heads = getattr(
|
| 135 |
+
config,
|
| 136 |
+
"num_key_value_heads",
|
| 137 |
+
getattr(config, "num_kv_heads", getattr(config, "n_head_kv", None)),
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
head_dim = getattr(attn, "head_dim", None)
|
| 141 |
+
if head_dim is None and config is not None:
|
| 142 |
+
head_dim = getattr(config, "head_dim", None)
|
| 143 |
+
|
| 144 |
+
if num_heads is None:
|
| 145 |
+
if hasattr(attn, "q_proj"):
|
| 146 |
+
q_out = attn.q_proj.weight.shape[0]
|
| 147 |
+
if head_dim is not None:
|
| 148 |
+
num_heads = q_out // head_dim
|
| 149 |
+
elif num_key_value_heads is not None and hasattr(attn, "k_proj"):
|
| 150 |
+
k_out = attn.k_proj.weight.shape[0]
|
| 151 |
+
head_dim = k_out // max(int(num_key_value_heads), 1)
|
| 152 |
+
num_heads = q_out // head_dim
|
| 153 |
+
if num_heads is None:
|
| 154 |
+
raise ValueError(
|
| 155 |
+
"Attention module missing num_heads/num_attention_heads; "
|
| 156 |
+
"pass --layer_path or add config overrides."
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
if num_key_value_heads is None:
|
| 160 |
+
num_key_value_heads = num_heads
|
| 161 |
+
|
| 162 |
+
if head_dim is None:
|
| 163 |
+
head_dim = hidden_size // int(num_heads)
|
| 164 |
+
|
| 165 |
+
if num_key_value_heads is None and hasattr(attn, "k_proj"):
|
| 166 |
+
k_out = attn.k_proj.weight.shape[0]
|
| 167 |
+
num_key_value_heads = k_out // int(head_dim)
|
| 168 |
+
|
| 169 |
+
return int(num_heads), int(num_key_value_heads), int(head_dim)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def cosine_cost_matrix(
|
| 173 |
+
a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8
|
| 174 |
+
) -> torch.Tensor:
|
| 175 |
+
a_norm = a / (a.norm(dim=1, keepdim=True) + eps)
|
| 176 |
+
b_norm = b / (b.norm(dim=1, keepdim=True) + eps)
|
| 177 |
+
sim = a_norm @ b_norm.t()
|
| 178 |
+
return 1.0 - sim
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def hungarian(cost: torch.Tensor) -> List[int]:
|
| 182 |
+
# Kuhn-Munkres for square cost matrix (minimization).
|
| 183 |
+
n = cost.size(0)
|
| 184 |
+
u = [0.0] * (n + 1)
|
| 185 |
+
v = [0.0] * (n + 1)
|
| 186 |
+
p = [0] * (n + 1)
|
| 187 |
+
way = [0] * (n + 1)
|
| 188 |
+
|
| 189 |
+
for i in range(1, n + 1):
|
| 190 |
+
p[0] = i
|
| 191 |
+
j0 = 0
|
| 192 |
+
minv = [float("inf")] * (n + 1)
|
| 193 |
+
used = [False] * (n + 1)
|
| 194 |
+
while True:
|
| 195 |
+
used[j0] = True
|
| 196 |
+
i0 = p[j0]
|
| 197 |
+
delta = float("inf")
|
| 198 |
+
j1 = 0
|
| 199 |
+
for j in range(1, n + 1):
|
| 200 |
+
if used[j]:
|
| 201 |
+
continue
|
| 202 |
+
cur = cost[i0 - 1, j - 1].item() - u[i0] - v[j]
|
| 203 |
+
if cur < minv[j]:
|
| 204 |
+
minv[j] = cur
|
| 205 |
+
way[j] = j0
|
| 206 |
+
if minv[j] < delta:
|
| 207 |
+
delta = minv[j]
|
| 208 |
+
j1 = j
|
| 209 |
+
for j in range(0, n + 1):
|
| 210 |
+
if used[j]:
|
| 211 |
+
u[p[j]] += delta
|
| 212 |
+
v[j] -= delta
|
| 213 |
+
else:
|
| 214 |
+
minv[j] -= delta
|
| 215 |
+
j0 = j1
|
| 216 |
+
if p[j0] == 0:
|
| 217 |
+
break
|
| 218 |
+
while True:
|
| 219 |
+
j1 = way[j0]
|
| 220 |
+
p[j0] = p[j1]
|
| 221 |
+
j0 = j1
|
| 222 |
+
if j0 == 0:
|
| 223 |
+
break
|
| 224 |
+
|
| 225 |
+
assignment = [-1] * n
|
| 226 |
+
for j in range(1, n + 1):
|
| 227 |
+
if p[j] > 0:
|
| 228 |
+
assignment[p[j] - 1] = j - 1
|
| 229 |
+
return assignment
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def compute_head_means(
|
| 233 |
+
model,
|
| 234 |
+
attn_i: torch.nn.Module,
|
| 235 |
+
attn_j: torch.nn.Module,
|
| 236 |
+
dataloader,
|
| 237 |
+
device: str,
|
| 238 |
+
hidden_size: int,
|
| 239 |
+
) -> Tuple[torch.Tensor, torch.Tensor, int, int, int]:
|
| 240 |
+
num_heads_i, num_kv_i, head_dim_i = get_head_info(attn_i, hidden_size, model.config)
|
| 241 |
+
num_heads_j, num_kv_j, head_dim_j = get_head_info(attn_j, hidden_size, model.config)
|
| 242 |
+
if num_heads_i != num_heads_j or head_dim_i != head_dim_j:
|
| 243 |
+
raise ValueError("Head counts or head_dim differ between layers; cannot align")
|
| 244 |
+
|
| 245 |
+
sums_i = torch.zeros(num_heads_i, head_dim_i, device="cpu")
|
| 246 |
+
sums_j = torch.zeros(num_heads_j, head_dim_j, device="cpu")
|
| 247 |
+
count_i = [0]
|
| 248 |
+
count_j = [0]
|
| 249 |
+
|
| 250 |
+
def make_hook(
|
| 251 |
+
sums: torch.Tensor, count_ref: List[int], num_heads: int, head_dim: int
|
| 252 |
+
):
|
| 253 |
+
def hook(_module, inputs, _output):
|
| 254 |
+
hidden = inputs[0].detach()
|
| 255 |
+
if hidden.dim() != 3:
|
| 256 |
+
return
|
| 257 |
+
batch, seq, width = hidden.shape
|
| 258 |
+
if width != num_heads * head_dim:
|
| 259 |
+
return
|
| 260 |
+
reshaped = hidden.view(batch, seq, num_heads, head_dim)
|
| 261 |
+
sums.add_(reshaped.sum(dim=(0, 1)).float().cpu())
|
| 262 |
+
count_ref[0] += batch * seq
|
| 263 |
+
|
| 264 |
+
return hook
|
| 265 |
+
|
| 266 |
+
hook_i = attn_i.o_proj.register_forward_hook(
|
| 267 |
+
make_hook(sums_i, count_i, num_heads_i, head_dim_i)
|
| 268 |
+
)
|
| 269 |
+
hook_j = attn_j.o_proj.register_forward_hook(
|
| 270 |
+
make_hook(sums_j, count_j, num_heads_j, head_dim_j)
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
model.eval()
|
| 274 |
+
iterator = dataloader
|
| 275 |
+
if tqdm is not None and _tqdm_enabled():
|
| 276 |
+
iterator = tqdm(dataloader, desc="Head stats", unit="batch")
|
| 277 |
+
with torch.no_grad():
|
| 278 |
+
for batch in iterator:
|
| 279 |
+
input_ids = batch[0].to(device)
|
| 280 |
+
_ = model(input_ids=input_ids)
|
| 281 |
+
|
| 282 |
+
hook_i.remove()
|
| 283 |
+
hook_j.remove()
|
| 284 |
+
|
| 285 |
+
if count_i[0] == 0 or count_j[0] == 0:
|
| 286 |
+
raise RuntimeError("Failed to capture head outputs; check attention modules.")
|
| 287 |
+
|
| 288 |
+
mean_i = sums_i / count_i[0]
|
| 289 |
+
mean_j = sums_j / count_j[0]
|
| 290 |
+
return mean_i, mean_j, num_heads_i, num_kv_i, head_dim_i
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def build_head_permutation(
|
| 294 |
+
mean_i: torch.Tensor,
|
| 295 |
+
mean_j: torch.Tensor,
|
| 296 |
+
num_heads: int,
|
| 297 |
+
num_kv_heads: int,
|
| 298 |
+
eps: float,
|
| 299 |
+
) -> List[int]:
|
| 300 |
+
group_size = num_heads // num_kv_heads
|
| 301 |
+
if group_size * num_kv_heads != num_heads:
|
| 302 |
+
raise ValueError("num_heads must be divisible by num_key_value_heads")
|
| 303 |
+
|
| 304 |
+
perm = list(range(num_heads))
|
| 305 |
+
for g in range(num_kv_heads):
|
| 306 |
+
start = g * group_size
|
| 307 |
+
end = start + group_size
|
| 308 |
+
cost = cosine_cost_matrix(mean_i[start:end], mean_j[start:end], eps=eps)
|
| 309 |
+
assignment = hungarian(cost)
|
| 310 |
+
for local_idx, match in enumerate(assignment):
|
| 311 |
+
perm[start + local_idx] = start + match
|
| 312 |
+
return perm
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def permute_attention_heads(
|
| 316 |
+
attn: torch.nn.Module,
|
| 317 |
+
perm: List[int],
|
| 318 |
+
num_heads: int,
|
| 319 |
+
num_kv_heads: int,
|
| 320 |
+
head_dim: int,
|
| 321 |
+
) -> None:
|
| 322 |
+
hidden_size = num_heads * head_dim
|
| 323 |
+
|
| 324 |
+
def permute_out_proj_weight(weight: torch.Tensor) -> torch.Tensor:
|
| 325 |
+
out_features, in_features = weight.shape
|
| 326 |
+
if in_features != hidden_size:
|
| 327 |
+
raise ValueError(
|
| 328 |
+
"o_proj in_features ({} ) != num_heads*head_dim ({})".format(
|
| 329 |
+
in_features, hidden_size
|
| 330 |
+
)
|
| 331 |
+
)
|
| 332 |
+
reshaped = weight.view(out_features, num_heads, head_dim)
|
| 333 |
+
reshaped = reshaped[:, perm, :]
|
| 334 |
+
return reshaped.reshape(out_features, in_features)
|
| 335 |
+
|
| 336 |
+
def permute_proj_weight(weight: torch.Tensor) -> torch.Tensor:
|
| 337 |
+
out_features, in_features = weight.shape
|
| 338 |
+
if out_features != hidden_size:
|
| 339 |
+
raise ValueError(
|
| 340 |
+
"proj out_features ({}) != num_heads*head_dim ({})".format(
|
| 341 |
+
out_features, hidden_size
|
| 342 |
+
)
|
| 343 |
+
)
|
| 344 |
+
reshaped = weight.view(num_heads, head_dim, in_features)
|
| 345 |
+
reshaped = reshaped[perm, :, :]
|
| 346 |
+
return reshaped.reshape(out_features, in_features)
|
| 347 |
+
|
| 348 |
+
def permute_proj_bias(bias: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
|
| 349 |
+
if bias is None:
|
| 350 |
+
return None
|
| 351 |
+
reshaped = bias.view(num_heads, head_dim)
|
| 352 |
+
reshaped = reshaped[perm, :]
|
| 353 |
+
return reshaped.reshape(num_heads * head_dim)
|
| 354 |
+
|
| 355 |
+
with torch.no_grad():
|
| 356 |
+
attn.q_proj.weight.copy_(permute_proj_weight(attn.q_proj.weight))
|
| 357 |
+
if attn.q_proj.bias is not None:
|
| 358 |
+
attn.q_proj.bias.copy_(permute_proj_bias(attn.q_proj.bias))
|
| 359 |
+
|
| 360 |
+
if num_kv_heads == num_heads:
|
| 361 |
+
attn.k_proj.weight.copy_(permute_proj_weight(attn.k_proj.weight))
|
| 362 |
+
if attn.k_proj.bias is not None:
|
| 363 |
+
attn.k_proj.bias.copy_(permute_proj_bias(attn.k_proj.bias))
|
| 364 |
+
attn.v_proj.weight.copy_(permute_proj_weight(attn.v_proj.weight))
|
| 365 |
+
if attn.v_proj.bias is not None:
|
| 366 |
+
attn.v_proj.bias.copy_(permute_proj_bias(attn.v_proj.bias))
|
| 367 |
+
|
| 368 |
+
attn.o_proj.weight.copy_(permute_out_proj_weight(attn.o_proj.weight))
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def compute_fisher(
|
| 372 |
+
model,
|
| 373 |
+
layer_a: torch.nn.Module,
|
| 374 |
+
layer_b: torch.nn.Module,
|
| 375 |
+
dataloader,
|
| 376 |
+
fisher_mode: str,
|
| 377 |
+
device: str,
|
| 378 |
+
) -> Tuple[List[Dict[str, object]], int, List[Dict[str, int]]]:
|
| 379 |
+
for param in model.parameters():
|
| 380 |
+
param.requires_grad_(False)
|
| 381 |
+
for layer in (layer_a, layer_b):
|
| 382 |
+
for param in layer.parameters():
|
| 383 |
+
param.requires_grad_(True)
|
| 384 |
+
|
| 385 |
+
fisher_sums: List[Dict[str, object]] = []
|
| 386 |
+
param_numels: List[Dict[str, int]] = []
|
| 387 |
+
for layer in (layer_a, layer_b):
|
| 388 |
+
layer_sums: Dict[str, object] = {}
|
| 389 |
+
layer_numels: Dict[str, int] = {}
|
| 390 |
+
for name, param in layer.named_parameters():
|
| 391 |
+
if not param.requires_grad:
|
| 392 |
+
continue
|
| 393 |
+
if fisher_mode == "param":
|
| 394 |
+
layer_sums[name] = torch.zeros_like(
|
| 395 |
+
param, dtype=torch.float32, device="cpu"
|
| 396 |
+
)
|
| 397 |
+
else:
|
| 398 |
+
layer_sums[name] = 0.0
|
| 399 |
+
layer_numels[name] = param.numel()
|
| 400 |
+
fisher_sums.append(layer_sums)
|
| 401 |
+
param_numels.append(layer_numels)
|
| 402 |
+
|
| 403 |
+
num_batches = 0
|
| 404 |
+
model.eval()
|
| 405 |
+
iterator = dataloader
|
| 406 |
+
if tqdm is not None and _tqdm_enabled():
|
| 407 |
+
iterator = tqdm(dataloader, desc="Fisher", unit="batch")
|
| 408 |
+
for batch in iterator:
|
| 409 |
+
input_ids = batch[0].to(device)
|
| 410 |
+
outputs = model(input_ids=input_ids, labels=input_ids)
|
| 411 |
+
loss = outputs.loss
|
| 412 |
+
loss.backward()
|
| 413 |
+
for layer_idx, layer in enumerate((layer_a, layer_b)):
|
| 414 |
+
layer_sums = fisher_sums[layer_idx]
|
| 415 |
+
for name, param in layer.named_parameters():
|
| 416 |
+
if not param.requires_grad or param.grad is None:
|
| 417 |
+
continue
|
| 418 |
+
grad_sq = param.grad.detach().float().pow(2)
|
| 419 |
+
if fisher_mode == "param":
|
| 420 |
+
layer_sums[name] += grad_sq.cpu()
|
| 421 |
+
else:
|
| 422 |
+
layer_sums[name] += float(grad_sq.sum().item())
|
| 423 |
+
model.zero_grad(set_to_none=True)
|
| 424 |
+
num_batches += 1
|
| 425 |
+
|
| 426 |
+
if num_batches == 0:
|
| 427 |
+
raise RuntimeError("No batches processed; check dataset or text inputs.")
|
| 428 |
+
|
| 429 |
+
return fisher_sums, num_batches, param_numels
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def merge_layers(
|
| 433 |
+
layer_a: torch.nn.Module,
|
| 434 |
+
layer_b: torch.nn.Module,
|
| 435 |
+
fisher_a: Dict[str, object],
|
| 436 |
+
fisher_b: Dict[str, object],
|
| 437 |
+
num_batches: int,
|
| 438 |
+
numels_a: Dict[str, int],
|
| 439 |
+
numels_b: Dict[str, int],
|
| 440 |
+
fisher_mode: str,
|
| 441 |
+
eps: float,
|
| 442 |
+
) -> int:
|
| 443 |
+
merged = 0
|
| 444 |
+
params_b = {name: param for name, param in layer_b.named_parameters()}
|
| 445 |
+
with torch.no_grad():
|
| 446 |
+
for name, param_a in layer_a.named_parameters():
|
| 447 |
+
param_b = params_b.get(name)
|
| 448 |
+
if param_b is None or param_b.shape != param_a.shape:
|
| 449 |
+
continue
|
| 450 |
+
if fisher_mode == "param":
|
| 451 |
+
fa = fisher_a[name] / num_batches
|
| 452 |
+
fb = fisher_b[name] / num_batches
|
| 453 |
+
# Fisher tensors are accumulated on CPU to save VRAM; move to the
|
| 454 |
+
# parameter device for the actual merge.
|
| 455 |
+
if isinstance(fa, torch.Tensor) and fa.device != param_a.device:
|
| 456 |
+
fa = fa.to(param_a.device)
|
| 457 |
+
if isinstance(fb, torch.Tensor) and fb.device != param_a.device:
|
| 458 |
+
fb = fb.to(param_a.device)
|
| 459 |
+
denom = fa + fb
|
| 460 |
+
denom_mean = float(denom.mean().item())
|
| 461 |
+
if denom_mean <= eps:
|
| 462 |
+
merged_param = 0.5 * (param_a.float() + param_b.float())
|
| 463 |
+
else:
|
| 464 |
+
merged_param = (fa * param_a.float() + fb * param_b.float()) / (
|
| 465 |
+
denom + eps
|
| 466 |
+
)
|
| 467 |
+
else:
|
| 468 |
+
fa = fisher_a[name] / (num_batches * numels_a[name])
|
| 469 |
+
fb = fisher_b[name] / (num_batches * numels_b[name])
|
| 470 |
+
denom = fa + fb
|
| 471 |
+
if denom <= eps:
|
| 472 |
+
merged_param = 0.5 * (param_a.float() + param_b.float())
|
| 473 |
+
else:
|
| 474 |
+
merged_param = (
|
| 475 |
+
fa * param_a.float() + fb * param_b.float()
|
| 476 |
+
) / (denom + eps)
|
| 477 |
+
param_a.copy_(merged_param.to(dtype=param_a.dtype))
|
| 478 |
+
merged += 1
|
| 479 |
+
return merged
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def merge_layers_with_gates(
|
| 483 |
+
layer_a: torch.nn.Module,
|
| 484 |
+
layer_b: torch.nn.Module,
|
| 485 |
+
gates: Dict[str, torch.Tensor],
|
| 486 |
+
) -> int:
|
| 487 |
+
"""Merge layer_b into layer_a using precomputed gates.
|
| 488 |
+
|
| 489 |
+
Each gate is a lambda in [0, 1] that mixes parameters as:
|
| 490 |
+
W = lambda * W_a + (1 - lambda) * W_b
|
| 491 |
+
|
| 492 |
+
Gate tensors may be scalars (per-tensor gating) or full tensors matching the
|
| 493 |
+
parameter shape (per-parameter gating).
|
| 494 |
+
"""
|
| 495 |
+
merged = 0
|
| 496 |
+
params_b = {name: param for name, param in layer_b.named_parameters()}
|
| 497 |
+
with torch.no_grad():
|
| 498 |
+
for name, param_a in layer_a.named_parameters():
|
| 499 |
+
gate = gates.get(name)
|
| 500 |
+
if gate is None:
|
| 501 |
+
continue
|
| 502 |
+
param_b = params_b.get(name)
|
| 503 |
+
if param_b is None or param_b.shape != param_a.shape:
|
| 504 |
+
continue
|
| 505 |
+
lam = gate
|
| 506 |
+
if not isinstance(lam, torch.Tensor):
|
| 507 |
+
lam = torch.tensor(lam)
|
| 508 |
+
if lam.device != param_a.device:
|
| 509 |
+
lam = lam.to(param_a.device)
|
| 510 |
+
merged_param = lam * param_a.float() + (1.0 - lam) * param_b.float()
|
| 511 |
+
param_a.copy_(merged_param.to(dtype=param_a.dtype))
|
| 512 |
+
merged += 1
|
| 513 |
+
return merged
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def drop_layer(container: object, index: int) -> object:
|
| 517 |
+
if isinstance(container, torch.nn.ModuleList):
|
| 518 |
+
return torch.nn.ModuleList(
|
| 519 |
+
[layer for idx, layer in enumerate(container) if idx != index]
|
| 520 |
+
)
|
| 521 |
+
if isinstance(container, list):
|
| 522 |
+
del container[index]
|
| 523 |
+
return container
|
| 524 |
+
raise TypeError("Layer container must be ModuleList or list")
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
def decrement_config(config) -> None:
|
| 528 |
+
for attr in ("num_hidden_layers", "n_layer", "num_layers"):
|
| 529 |
+
if hasattr(config, attr):
|
| 530 |
+
value = getattr(config, attr)
|
| 531 |
+
if isinstance(value, int) and value > 0:
|
| 532 |
+
setattr(config, attr, value - 1)
|
| 533 |
+
normalize_config(config)
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
def normalize_config(config) -> None:
|
| 537 |
+
num_hidden_layers = getattr(config, "num_hidden_layers", None)
|
| 538 |
+
layer_types = getattr(config, "layer_types", None)
|
| 539 |
+
if (
|
| 540 |
+
isinstance(num_hidden_layers, int)
|
| 541 |
+
and num_hidden_layers >= 0
|
| 542 |
+
and isinstance(layer_types, (list, tuple))
|
| 543 |
+
and len(layer_types) != num_hidden_layers
|
| 544 |
+
):
|
| 545 |
+
config.layer_types = list(layer_types[:num_hidden_layers])
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
def find_colon_modules(module: torch.nn.Module) -> List[str]:
|
| 549 |
+
found: List[str] = []
|
| 550 |
+
for name, child in module._modules.items():
|
| 551 |
+
if ":" in name:
|
| 552 |
+
found.append(name)
|
| 553 |
+
if isinstance(child, torch.nn.Module):
|
| 554 |
+
for sub in find_colon_modules(child):
|
| 555 |
+
found.append(f"{name}.{sub}")
|
| 556 |
+
return found
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def get_norm_pair(
|
| 560 |
+
layer: torch.nn.Module,
|
| 561 |
+
) -> Tuple[
|
| 562 |
+
Optional[torch.nn.Module],
|
| 563 |
+
Optional[torch.nn.Module],
|
| 564 |
+
Tuple[Optional[str], Optional[str]],
|
| 565 |
+
]:
|
| 566 |
+
candidates = [
|
| 567 |
+
("input_layernorm", "post_attention_layernorm"),
|
| 568 |
+
("ln_1", "ln_2"),
|
| 569 |
+
("norm1", "norm2"),
|
| 570 |
+
("norm_1", "norm_2"),
|
| 571 |
+
("layer_norm_1", "layer_norm_2"),
|
| 572 |
+
("self_attn_layer_norm", "final_layer_norm"),
|
| 573 |
+
]
|
| 574 |
+
for n1, n2 in candidates:
|
| 575 |
+
if hasattr(layer, n1) and hasattr(layer, n2):
|
| 576 |
+
return getattr(layer, n1), getattr(layer, n2), (n1, n2)
|
| 577 |
+
return None, None, (None, None)
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
def clone_state_dict(module: torch.nn.Module) -> Dict[str, torch.Tensor]:
|
| 581 |
+
return {k: v.detach().clone() for k, v in module.state_dict().items()}
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
def apply_norm_policy(
|
| 585 |
+
layer: torch.nn.Module,
|
| 586 |
+
norm_policy: str,
|
| 587 |
+
norm1_state: Optional[Dict[str, torch.Tensor]],
|
| 588 |
+
norm2_state: Optional[Dict[str, torch.Tensor]],
|
| 589 |
+
norm_names: Tuple[Optional[str], Optional[str]],
|
| 590 |
+
) -> None:
|
| 591 |
+
norm1, norm2, _ = get_norm_pair(layer)
|
| 592 |
+
if norm_policy in {"copy_n1", "hybrid"} and norm1_state is not None and norm1 is not None:
|
| 593 |
+
norm1.load_state_dict(norm1_state)
|
| 594 |
+
if norm_policy == "copy_n1_n2" and norm2_state is not None and norm2 is not None:
|
| 595 |
+
norm2.load_state_dict(norm2_state)
|
src/fuse_layers_select.py
ADDED
|
@@ -0,0 +1,1152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Automatic adjacent-pair selection via configurable scoring metrics."""
|
| 3 |
+
|
| 4 |
+
import copy
|
| 5 |
+
import math
|
| 6 |
+
from contextlib import contextmanager
|
| 7 |
+
from typing import Dict, List, Optional, Set, Tuple
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
from fuse_layers_model import (
|
| 13 |
+
build_head_permutation,
|
| 14 |
+
compute_fisher,
|
| 15 |
+
compute_head_means,
|
| 16 |
+
find_attention_module,
|
| 17 |
+
find_layer_container,
|
| 18 |
+
merge_layers,
|
| 19 |
+
permute_attention_heads,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
_DWCE_GRAD_CACHE_MAX_BYTES = 1 << 30
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class _DwceGradCacheOverflow(RuntimeError):
|
| 26 |
+
"""Raised when shared-backward DWCE caching exceeds the configured budget."""
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _get_hidden_size(model) -> int:
|
| 30 |
+
hidden_size = getattr(model.config, "hidden_size", None)
|
| 31 |
+
if hidden_size is None:
|
| 32 |
+
hidden_size = getattr(model.config, "n_embd", None)
|
| 33 |
+
if hidden_size is None:
|
| 34 |
+
raise SystemExit("Model config missing hidden_size/n_embd")
|
| 35 |
+
return int(hidden_size)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _detach_arg(arg):
|
| 39 |
+
if torch.is_tensor(arg):
|
| 40 |
+
return arg.detach()
|
| 41 |
+
if isinstance(arg, (list, tuple)):
|
| 42 |
+
return type(arg)(_detach_arg(x) for x in arg)
|
| 43 |
+
if isinstance(arg, dict):
|
| 44 |
+
return {k: _detach_arg(v) for k, v in arg.items()}
|
| 45 |
+
return arg
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _register_forward_hook(layer, hook):
|
| 49 |
+
try:
|
| 50 |
+
def wrapper(module, inputs, kwargs, output):
|
| 51 |
+
return hook(module, inputs, output, kwargs)
|
| 52 |
+
|
| 53 |
+
handle = layer.register_forward_hook(wrapper, with_kwargs=True)
|
| 54 |
+
return handle, True
|
| 55 |
+
except TypeError:
|
| 56 |
+
def wrapper(module, inputs, output):
|
| 57 |
+
return hook(module, inputs, output, None)
|
| 58 |
+
handle = layer.register_forward_hook(wrapper)
|
| 59 |
+
return handle, False
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@contextmanager
|
| 63 |
+
def _temporary_layers(parent: object, name: str, new_layers: object):
|
| 64 |
+
original = getattr(parent, name)
|
| 65 |
+
setattr(parent, name, new_layers)
|
| 66 |
+
try:
|
| 67 |
+
yield
|
| 68 |
+
finally:
|
| 69 |
+
setattr(parent, name, original)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _extract_hidden(output):
|
| 73 |
+
if torch.is_tensor(output):
|
| 74 |
+
return output
|
| 75 |
+
if isinstance(output, (tuple, list)):
|
| 76 |
+
if output and all(torch.is_tensor(item) for item in output):
|
| 77 |
+
return output[0]
|
| 78 |
+
for item in output:
|
| 79 |
+
hidden = _extract_hidden(item)
|
| 80 |
+
if hidden is not None:
|
| 81 |
+
return hidden
|
| 82 |
+
return None
|
| 83 |
+
if isinstance(output, dict):
|
| 84 |
+
for key in ("hidden_states", "last_hidden_state", "hidden_state"):
|
| 85 |
+
if key in output:
|
| 86 |
+
value = output[key]
|
| 87 |
+
if isinstance(value, (tuple, list)) and value and all(
|
| 88 |
+
torch.is_tensor(item) for item in value
|
| 89 |
+
):
|
| 90 |
+
return value[-1]
|
| 91 |
+
hidden = _extract_hidden(value)
|
| 92 |
+
if hidden is not None:
|
| 93 |
+
return hidden
|
| 94 |
+
for value in output.values():
|
| 95 |
+
hidden = _extract_hidden(value)
|
| 96 |
+
if hidden is not None:
|
| 97 |
+
return hidden
|
| 98 |
+
return None
|
| 99 |
+
for attr in ("hidden_states", "last_hidden_state"):
|
| 100 |
+
if hasattr(output, attr):
|
| 101 |
+
value = getattr(output, attr)
|
| 102 |
+
if isinstance(value, (tuple, list)) and value and all(
|
| 103 |
+
torch.is_tensor(item) for item in value
|
| 104 |
+
):
|
| 105 |
+
return value[-1]
|
| 106 |
+
hidden = _extract_hidden(value)
|
| 107 |
+
if hidden is not None:
|
| 108 |
+
return hidden
|
| 109 |
+
return None
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _build_fused_layer_for_pair(
|
| 113 |
+
model,
|
| 114 |
+
layer_a: torch.nn.Module,
|
| 115 |
+
layer_b: torch.nn.Module,
|
| 116 |
+
dataloader,
|
| 117 |
+
device: str,
|
| 118 |
+
fisher_mode: str,
|
| 119 |
+
eps: float,
|
| 120 |
+
hidden_size: int,
|
| 121 |
+
enable_head_permute: bool = True,
|
| 122 |
+
) -> Tuple[torch.nn.Module, Dict[str, float]]:
|
| 123 |
+
attn_a = find_attention_module(layer_a)
|
| 124 |
+
attn_b = find_attention_module(layer_b)
|
| 125 |
+
perm = None
|
| 126 |
+
inv_perm = None
|
| 127 |
+
num_heads = None
|
| 128 |
+
num_kv_heads = None
|
| 129 |
+
head_dim = None
|
| 130 |
+
if enable_head_permute:
|
| 131 |
+
mean_a, mean_b, num_heads, num_kv_heads, head_dim = compute_head_means(
|
| 132 |
+
model,
|
| 133 |
+
attn_a,
|
| 134 |
+
attn_b,
|
| 135 |
+
dataloader,
|
| 136 |
+
device,
|
| 137 |
+
hidden_size,
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
perm = build_head_permutation(
|
| 141 |
+
mean_a,
|
| 142 |
+
mean_b,
|
| 143 |
+
num_heads=num_heads,
|
| 144 |
+
num_kv_heads=num_kv_heads,
|
| 145 |
+
eps=eps,
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
layer_a_copy = copy.deepcopy(layer_a)
|
| 149 |
+
layer_b_copy = copy.deepcopy(layer_b)
|
| 150 |
+
attn_b_copy = find_attention_module(layer_b_copy)
|
| 151 |
+
if perm is not None:
|
| 152 |
+
permute_attention_heads(
|
| 153 |
+
attn_b_copy, perm, num_heads, num_kv_heads, head_dim=head_dim
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
inv_perm = [0] * len(perm)
|
| 157 |
+
for idx, mapped in enumerate(perm):
|
| 158 |
+
inv_perm[mapped] = idx
|
| 159 |
+
|
| 160 |
+
permute_attention_heads(attn_b, perm, num_heads, num_kv_heads, head_dim=head_dim)
|
| 161 |
+
try:
|
| 162 |
+
fisher_sums, num_batches, param_numels = compute_fisher(
|
| 163 |
+
model,
|
| 164 |
+
layer_a,
|
| 165 |
+
layer_b,
|
| 166 |
+
dataloader,
|
| 167 |
+
fisher_mode=fisher_mode,
|
| 168 |
+
device=device,
|
| 169 |
+
)
|
| 170 |
+
finally:
|
| 171 |
+
if inv_perm is not None:
|
| 172 |
+
permute_attention_heads(
|
| 173 |
+
attn_b, inv_perm, num_heads, num_kv_heads, head_dim=head_dim
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
merge_layers(
|
| 177 |
+
layer_a_copy,
|
| 178 |
+
layer_b_copy,
|
| 179 |
+
fisher_sums[0],
|
| 180 |
+
fisher_sums[1],
|
| 181 |
+
num_batches,
|
| 182 |
+
param_numels[0],
|
| 183 |
+
param_numels[1],
|
| 184 |
+
fisher_mode=fisher_mode,
|
| 185 |
+
eps=eps,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
# Scalar mixing coefficients per parameter tensor; used by pressure redistribution
|
| 189 |
+
# to simulate future fusions without running another Fisher pass.
|
| 190 |
+
fuse_priors: Dict[str, float] = {}
|
| 191 |
+
params_b = {name: param for name, param in layer_b.named_parameters()}
|
| 192 |
+
clamp_eps = 1e-4
|
| 193 |
+
for name, param_a in layer_a.named_parameters():
|
| 194 |
+
param_b = params_b.get(name)
|
| 195 |
+
if param_b is None or param_b.shape != param_a.shape:
|
| 196 |
+
continue
|
| 197 |
+
if fisher_mode == "param":
|
| 198 |
+
fa = fisher_sums[0][name] / max(num_batches, 1)
|
| 199 |
+
fb = fisher_sums[1][name] / max(num_batches, 1)
|
| 200 |
+
if isinstance(fa, torch.Tensor):
|
| 201 |
+
fa_val = float(fa.mean().item())
|
| 202 |
+
else:
|
| 203 |
+
fa_val = float(fa)
|
| 204 |
+
if isinstance(fb, torch.Tensor):
|
| 205 |
+
fb_val = float(fb.mean().item())
|
| 206 |
+
else:
|
| 207 |
+
fb_val = float(fb)
|
| 208 |
+
else:
|
| 209 |
+
fa_val = float(
|
| 210 |
+
fisher_sums[0][name]
|
| 211 |
+
/ (max(num_batches, 1) * max(param_numels[0].get(name, 1), 1))
|
| 212 |
+
)
|
| 213 |
+
fb_val = float(
|
| 214 |
+
fisher_sums[1][name]
|
| 215 |
+
/ (max(num_batches, 1) * max(param_numels[1].get(name, 1), 1))
|
| 216 |
+
)
|
| 217 |
+
denom = fa_val + fb_val
|
| 218 |
+
if denom <= eps:
|
| 219 |
+
lam = 0.5
|
| 220 |
+
else:
|
| 221 |
+
lam = fa_val / (denom + eps)
|
| 222 |
+
lam = min(max(lam, clamp_eps), 1.0 - clamp_eps)
|
| 223 |
+
fuse_priors[name] = lam
|
| 224 |
+
|
| 225 |
+
layer_a_copy.eval()
|
| 226 |
+
return layer_a_copy, fuse_priors
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def _init_fisher_accumulators(
|
| 230 |
+
layer_a: torch.nn.Module,
|
| 231 |
+
layer_b: torch.nn.Module,
|
| 232 |
+
fisher_mode: str,
|
| 233 |
+
device: str,
|
| 234 |
+
) -> Tuple[List[Dict[str, object]], List[Dict[str, int]]]:
|
| 235 |
+
fisher_sums: List[Dict[str, object]] = []
|
| 236 |
+
param_numels: List[Dict[str, int]] = []
|
| 237 |
+
for layer in (layer_a, layer_b):
|
| 238 |
+
layer_sums: Dict[str, object] = {}
|
| 239 |
+
layer_numels: Dict[str, int] = {}
|
| 240 |
+
for name, param in layer.named_parameters():
|
| 241 |
+
if not param.requires_grad:
|
| 242 |
+
continue
|
| 243 |
+
if fisher_mode == "param":
|
| 244 |
+
layer_sums[name] = torch.zeros_like(
|
| 245 |
+
param, dtype=torch.float32, device="cpu"
|
| 246 |
+
)
|
| 247 |
+
else:
|
| 248 |
+
layer_sums[name] = torch.zeros((), dtype=torch.float32, device=device)
|
| 249 |
+
layer_numels[name] = param.numel()
|
| 250 |
+
fisher_sums.append(layer_sums)
|
| 251 |
+
param_numels.append(layer_numels)
|
| 252 |
+
return fisher_sums, param_numels
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def _accumulate_fisher_from_grads(
|
| 256 |
+
layer: torch.nn.Module,
|
| 257 |
+
layer_sums: Dict[str, object],
|
| 258 |
+
fisher_mode: str,
|
| 259 |
+
) -> None:
|
| 260 |
+
for name, param in layer.named_parameters():
|
| 261 |
+
if not param.requires_grad or param.grad is None:
|
| 262 |
+
continue
|
| 263 |
+
grad_sq = param.grad.detach().float().pow(2)
|
| 264 |
+
if fisher_mode == "param":
|
| 265 |
+
layer_sums[name] += grad_sq.cpu()
|
| 266 |
+
else:
|
| 267 |
+
layer_sums[name] += grad_sq.sum()
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def _finalize_fisher_sums(
|
| 271 |
+
fisher_sums: List[Dict[str, object]],
|
| 272 |
+
fisher_mode: str,
|
| 273 |
+
) -> List[Dict[str, object]]:
|
| 274 |
+
if fisher_mode == "param":
|
| 275 |
+
return fisher_sums
|
| 276 |
+
|
| 277 |
+
finalized: List[Dict[str, object]] = []
|
| 278 |
+
for layer_sums in fisher_sums:
|
| 279 |
+
finalized_layer: Dict[str, object] = {}
|
| 280 |
+
for name, value in layer_sums.items():
|
| 281 |
+
if isinstance(value, torch.Tensor):
|
| 282 |
+
finalized_layer[name] = float(value.detach().cpu().item())
|
| 283 |
+
else:
|
| 284 |
+
finalized_layer[name] = float(value)
|
| 285 |
+
finalized.append(finalized_layer)
|
| 286 |
+
return finalized
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def _compute_fuse_priors(
|
| 290 |
+
layer_a: torch.nn.Module,
|
| 291 |
+
layer_b: torch.nn.Module,
|
| 292 |
+
fisher_sums: List[Dict[str, object]],
|
| 293 |
+
num_batches: int,
|
| 294 |
+
param_numels: List[Dict[str, int]],
|
| 295 |
+
fisher_mode: str,
|
| 296 |
+
eps: float,
|
| 297 |
+
) -> Dict[str, float]:
|
| 298 |
+
fuse_priors: Dict[str, float] = {}
|
| 299 |
+
params_b = {name: param for name, param in layer_b.named_parameters()}
|
| 300 |
+
clamp_eps = 1e-4
|
| 301 |
+
for name, param_a in layer_a.named_parameters():
|
| 302 |
+
param_b = params_b.get(name)
|
| 303 |
+
if param_b is None or param_b.shape != param_a.shape:
|
| 304 |
+
continue
|
| 305 |
+
if fisher_mode == "param":
|
| 306 |
+
fa = fisher_sums[0][name] / max(num_batches, 1)
|
| 307 |
+
fb = fisher_sums[1][name] / max(num_batches, 1)
|
| 308 |
+
fa_val = float(fa.mean().item()) if isinstance(fa, torch.Tensor) else float(fa)
|
| 309 |
+
fb_val = float(fb.mean().item()) if isinstance(fb, torch.Tensor) else float(fb)
|
| 310 |
+
else:
|
| 311 |
+
fa_val = float(
|
| 312 |
+
fisher_sums[0][name]
|
| 313 |
+
/ (max(num_batches, 1) * max(param_numels[0].get(name, 1), 1))
|
| 314 |
+
)
|
| 315 |
+
fb_val = float(
|
| 316 |
+
fisher_sums[1][name]
|
| 317 |
+
/ (max(num_batches, 1) * max(param_numels[1].get(name, 1), 1))
|
| 318 |
+
)
|
| 319 |
+
denom = fa_val + fb_val
|
| 320 |
+
lam = 0.5 if denom <= eps else fa_val / (denom + eps)
|
| 321 |
+
fuse_priors[name] = min(max(lam, clamp_eps), 1.0 - clamp_eps)
|
| 322 |
+
return fuse_priors
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def _score_dwce_with_shared_backward(
|
| 326 |
+
model,
|
| 327 |
+
layer_a: torch.nn.Module,
|
| 328 |
+
layer_b: torch.nn.Module,
|
| 329 |
+
dataloader,
|
| 330 |
+
device: str,
|
| 331 |
+
fisher_mode: str,
|
| 332 |
+
max_batches: int,
|
| 333 |
+
eps: float,
|
| 334 |
+
norm: str,
|
| 335 |
+
hidden_size: int,
|
| 336 |
+
enable_head_permute: bool = True,
|
| 337 |
+
) -> Tuple[float, Dict[str, object]]:
|
| 338 |
+
attn_a = find_attention_module(layer_a)
|
| 339 |
+
attn_b = find_attention_module(layer_b)
|
| 340 |
+
perm = None
|
| 341 |
+
inv_perm = None
|
| 342 |
+
num_heads = None
|
| 343 |
+
num_kv_heads = None
|
| 344 |
+
head_dim = None
|
| 345 |
+
if enable_head_permute:
|
| 346 |
+
mean_a, mean_b, num_heads, num_kv_heads, head_dim = compute_head_means(
|
| 347 |
+
model,
|
| 348 |
+
attn_a,
|
| 349 |
+
attn_b,
|
| 350 |
+
dataloader,
|
| 351 |
+
device,
|
| 352 |
+
hidden_size,
|
| 353 |
+
)
|
| 354 |
+
perm = build_head_permutation(
|
| 355 |
+
mean_a,
|
| 356 |
+
mean_b,
|
| 357 |
+
num_heads=num_heads,
|
| 358 |
+
num_kv_heads=num_kv_heads,
|
| 359 |
+
eps=eps,
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
layer_a_copy = copy.deepcopy(layer_a)
|
| 363 |
+
layer_b_copy = copy.deepcopy(layer_b)
|
| 364 |
+
attn_b_copy = find_attention_module(layer_b_copy)
|
| 365 |
+
if perm is not None:
|
| 366 |
+
permute_attention_heads(
|
| 367 |
+
attn_b_copy, perm, num_heads, num_kv_heads, head_dim=head_dim
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
inv_perm = [0] * len(perm)
|
| 371 |
+
for idx, mapped in enumerate(perm):
|
| 372 |
+
inv_perm[mapped] = idx
|
| 373 |
+
|
| 374 |
+
cache: Dict[str, Optional[torch.Tensor]] = {"teacher": None}
|
| 375 |
+
grad_sq_cache: List[torch.Tensor] = []
|
| 376 |
+
cached_bytes = 0
|
| 377 |
+
|
| 378 |
+
def hook_b(_module, _inputs, output, _kwargs=None):
|
| 379 |
+
teacher_hidden = _extract_hidden(output)
|
| 380 |
+
if teacher_hidden is None:
|
| 381 |
+
raise RuntimeError("Failed to extract teacher hidden state output.")
|
| 382 |
+
cache["teacher"] = teacher_hidden
|
| 383 |
+
if teacher_hidden.requires_grad:
|
| 384 |
+
teacher_hidden.retain_grad()
|
| 385 |
+
return output
|
| 386 |
+
|
| 387 |
+
handle_b, _ = _register_forward_hook(layer_b, hook_b)
|
| 388 |
+
for param in model.parameters():
|
| 389 |
+
param.requires_grad_(False)
|
| 390 |
+
for layer in (layer_a, layer_b):
|
| 391 |
+
for param in layer.parameters():
|
| 392 |
+
param.requires_grad_(True)
|
| 393 |
+
fisher_sums, param_numels = _init_fisher_accumulators(
|
| 394 |
+
layer_a, layer_b, fisher_mode, device
|
| 395 |
+
)
|
| 396 |
+
num_batches = 0
|
| 397 |
+
|
| 398 |
+
if perm is not None:
|
| 399 |
+
permute_attention_heads(attn_b, perm, num_heads, num_kv_heads, head_dim=head_dim)
|
| 400 |
+
try:
|
| 401 |
+
model.eval()
|
| 402 |
+
for batch_idx, batch in enumerate(dataloader):
|
| 403 |
+
if max_batches and batch_idx >= max_batches:
|
| 404 |
+
break
|
| 405 |
+
cache["teacher"] = None
|
| 406 |
+
input_ids = batch[0].to(device)
|
| 407 |
+
attention_mask = batch[1].to(device) if len(batch) > 1 else None
|
| 408 |
+
|
| 409 |
+
model.zero_grad(set_to_none=True)
|
| 410 |
+
outputs = model(
|
| 411 |
+
input_ids=input_ids,
|
| 412 |
+
attention_mask=attention_mask,
|
| 413 |
+
labels=input_ids,
|
| 414 |
+
)
|
| 415 |
+
outputs.loss.backward()
|
| 416 |
+
|
| 417 |
+
teacher = cache["teacher"]
|
| 418 |
+
grad = None if teacher is None else teacher.grad
|
| 419 |
+
if teacher is None or grad is None:
|
| 420 |
+
raise RuntimeError(
|
| 421 |
+
"Auto selection hooks failed to capture outputs/gradients. "
|
| 422 |
+
"Try updating PyTorch or run with --layer <index>."
|
| 423 |
+
)
|
| 424 |
+
grad_sq = grad.detach().pow(2).to(device=device, dtype=torch.float16)
|
| 425 |
+
cached_bytes += grad_sq.numel() * grad_sq.element_size()
|
| 426 |
+
if cached_bytes > _DWCE_GRAD_CACHE_MAX_BYTES:
|
| 427 |
+
raise _DwceGradCacheOverflow(
|
| 428 |
+
"DWCE grad cache exceeded device-memory budget during shared-backward scoring."
|
| 429 |
+
)
|
| 430 |
+
grad_sq_cache.append(grad_sq)
|
| 431 |
+
_accumulate_fisher_from_grads(layer_a, fisher_sums[0], fisher_mode)
|
| 432 |
+
_accumulate_fisher_from_grads(layer_b, fisher_sums[1], fisher_mode)
|
| 433 |
+
model.zero_grad(set_to_none=True)
|
| 434 |
+
num_batches += 1
|
| 435 |
+
finally:
|
| 436 |
+
handle_b.remove()
|
| 437 |
+
if inv_perm is not None:
|
| 438 |
+
permute_attention_heads(
|
| 439 |
+
attn_b, inv_perm, num_heads, num_kv_heads, head_dim=head_dim
|
| 440 |
+
)
|
| 441 |
+
for param in model.parameters():
|
| 442 |
+
param.requires_grad_(True)
|
| 443 |
+
|
| 444 |
+
if num_batches == 0:
|
| 445 |
+
raise RuntimeError("No batches processed; check dataset or text inputs.")
|
| 446 |
+
|
| 447 |
+
fisher_sums = _finalize_fisher_sums(fisher_sums, fisher_mode)
|
| 448 |
+
merge_layers(
|
| 449 |
+
layer_a_copy,
|
| 450 |
+
layer_b_copy,
|
| 451 |
+
fisher_sums[0],
|
| 452 |
+
fisher_sums[1],
|
| 453 |
+
num_batches,
|
| 454 |
+
param_numels[0],
|
| 455 |
+
param_numels[1],
|
| 456 |
+
fisher_mode=fisher_mode,
|
| 457 |
+
eps=eps,
|
| 458 |
+
)
|
| 459 |
+
fuse_priors = _compute_fuse_priors(
|
| 460 |
+
layer_a,
|
| 461 |
+
layer_b,
|
| 462 |
+
fisher_sums,
|
| 463 |
+
num_batches,
|
| 464 |
+
param_numels,
|
| 465 |
+
fisher_mode,
|
| 466 |
+
eps,
|
| 467 |
+
)
|
| 468 |
+
|
| 469 |
+
fused_layer = layer_a_copy
|
| 470 |
+
fused_layer.eval()
|
| 471 |
+
phase2_cache = {"teacher": None, "fused": None}
|
| 472 |
+
|
| 473 |
+
def hook_a(_module, inputs, output, kwargs=None):
|
| 474 |
+
with torch.no_grad():
|
| 475 |
+
detached_inputs = tuple(_detach_arg(arg) for arg in inputs)
|
| 476 |
+
if kwargs:
|
| 477 |
+
detached_kwargs = {k: _detach_arg(v) for k, v in kwargs.items()}
|
| 478 |
+
fused_out = fused_layer(*detached_inputs, **detached_kwargs)
|
| 479 |
+
else:
|
| 480 |
+
fused_out = fused_layer(*detached_inputs)
|
| 481 |
+
fused_hidden = _extract_hidden(fused_out)
|
| 482 |
+
if fused_hidden is None:
|
| 483 |
+
raise RuntimeError("Failed to extract fused hidden state output.")
|
| 484 |
+
phase2_cache["fused"] = fused_hidden
|
| 485 |
+
return output
|
| 486 |
+
|
| 487 |
+
def hook_b_eval(_module, _inputs, output, _kwargs=None):
|
| 488 |
+
teacher_hidden = _extract_hidden(output)
|
| 489 |
+
if teacher_hidden is None:
|
| 490 |
+
raise RuntimeError("Failed to extract teacher hidden state output.")
|
| 491 |
+
phase2_cache["teacher"] = teacher_hidden
|
| 492 |
+
return output
|
| 493 |
+
|
| 494 |
+
handle_a, has_kwargs_a = _register_forward_hook(layer_a, hook_a)
|
| 495 |
+
handle_b_eval, has_kwargs_b = _register_forward_hook(layer_b, hook_b_eval)
|
| 496 |
+
supports_kwargs = has_kwargs_a and has_kwargs_b
|
| 497 |
+
|
| 498 |
+
score_num = 0.0
|
| 499 |
+
score_den = 0.0
|
| 500 |
+
token_count = 0.0
|
| 501 |
+
try:
|
| 502 |
+
model.eval()
|
| 503 |
+
for batch_idx, batch in enumerate(dataloader):
|
| 504 |
+
if batch_idx >= num_batches:
|
| 505 |
+
break
|
| 506 |
+
phase2_cache["teacher"] = None
|
| 507 |
+
phase2_cache["fused"] = None
|
| 508 |
+
input_ids = batch[0].to(device)
|
| 509 |
+
attention_mask = batch[1].to(device) if len(batch) > 1 else None
|
| 510 |
+
|
| 511 |
+
with torch.no_grad():
|
| 512 |
+
model(
|
| 513 |
+
input_ids=input_ids,
|
| 514 |
+
attention_mask=attention_mask,
|
| 515 |
+
use_cache=False,
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
teacher = phase2_cache["teacher"]
|
| 519 |
+
fused = phase2_cache["fused"]
|
| 520 |
+
if teacher is None or fused is None:
|
| 521 |
+
raise RuntimeError(
|
| 522 |
+
"Auto selection hooks failed to capture outputs during DWCE replay."
|
| 523 |
+
)
|
| 524 |
+
grad_sq = grad_sq_cache[batch_idx].to(dtype=torch.float32)
|
| 525 |
+
if attention_mask is not None:
|
| 526 |
+
mask = attention_mask.to(dtype=torch.float32).unsqueeze(-1)
|
| 527 |
+
batch_tokens = float(mask.sum().item())
|
| 528 |
+
grad_sq = grad_sq * mask
|
| 529 |
+
else:
|
| 530 |
+
mask = None
|
| 531 |
+
batch_tokens = float(input_ids.numel())
|
| 532 |
+
token_count += batch_tokens
|
| 533 |
+
|
| 534 |
+
delta = fused - teacher
|
| 535 |
+
if mask is not None:
|
| 536 |
+
delta = delta * mask
|
| 537 |
+
score_num += (delta.float().pow(2) * grad_sq).sum().item()
|
| 538 |
+
score_den += (teacher.float().pow(2) * grad_sq).sum().item()
|
| 539 |
+
finally:
|
| 540 |
+
handle_a.remove()
|
| 541 |
+
handle_b_eval.remove()
|
| 542 |
+
|
| 543 |
+
score = (
|
| 544 |
+
score_num / (score_den + eps)
|
| 545 |
+
if norm == "relative"
|
| 546 |
+
else score_num / max(token_count, 1.0)
|
| 547 |
+
)
|
| 548 |
+
meta = {
|
| 549 |
+
"num_batches": num_batches,
|
| 550 |
+
"token_count": token_count,
|
| 551 |
+
"norm": norm,
|
| 552 |
+
"supports_kwargs": supports_kwargs,
|
| 553 |
+
"fuse_priors": fuse_priors,
|
| 554 |
+
"metric": "dwce",
|
| 555 |
+
"dwce_mode": "shared",
|
| 556 |
+
}
|
| 557 |
+
return score, meta
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
def _compute_dwce_for_pair(
|
| 561 |
+
model,
|
| 562 |
+
layer_a: torch.nn.Module,
|
| 563 |
+
layer_b: torch.nn.Module,
|
| 564 |
+
fused_layer: torch.nn.Module,
|
| 565 |
+
dataloader,
|
| 566 |
+
device: str,
|
| 567 |
+
max_batches: int,
|
| 568 |
+
eps: float,
|
| 569 |
+
norm: str,
|
| 570 |
+
) -> Tuple[float, Dict[str, object]]:
|
| 571 |
+
cache = {"teacher": None, "fused": None}
|
| 572 |
+
supports_kwargs = True
|
| 573 |
+
|
| 574 |
+
def hook_a(_module, inputs, output, kwargs=None):
|
| 575 |
+
with torch.no_grad():
|
| 576 |
+
detached_inputs = tuple(_detach_arg(arg) for arg in inputs)
|
| 577 |
+
if kwargs is not None and len(kwargs) > 0:
|
| 578 |
+
detached_kwargs = {k: _detach_arg(v) for k, v in kwargs.items()}
|
| 579 |
+
fused_out = fused_layer(*detached_inputs, **detached_kwargs)
|
| 580 |
+
else:
|
| 581 |
+
fused_out = fused_layer(*detached_inputs)
|
| 582 |
+
fused_hidden = _extract_hidden(fused_out)
|
| 583 |
+
if fused_hidden is None:
|
| 584 |
+
raise RuntimeError("Failed to extract fused hidden state output.")
|
| 585 |
+
cache["fused"] = fused_hidden
|
| 586 |
+
return output
|
| 587 |
+
|
| 588 |
+
def hook_b(_module, _inputs, output, _kwargs=None):
|
| 589 |
+
teacher_hidden = _extract_hidden(output)
|
| 590 |
+
if teacher_hidden is None:
|
| 591 |
+
raise RuntimeError("Failed to extract teacher hidden state output.")
|
| 592 |
+
cache["teacher"] = teacher_hidden
|
| 593 |
+
if teacher_hidden.requires_grad:
|
| 594 |
+
teacher_hidden.retain_grad()
|
| 595 |
+
return output
|
| 596 |
+
|
| 597 |
+
handle_a, has_kwargs_a = _register_forward_hook(layer_a, hook_a)
|
| 598 |
+
handle_b, has_kwargs_b = _register_forward_hook(layer_b, hook_b)
|
| 599 |
+
supports_kwargs = has_kwargs_a and has_kwargs_b
|
| 600 |
+
|
| 601 |
+
score_num = 0.0
|
| 602 |
+
score_den = 0.0
|
| 603 |
+
token_count = 0.0
|
| 604 |
+
num_batches = 0
|
| 605 |
+
|
| 606 |
+
model.eval()
|
| 607 |
+
for batch_idx, batch in enumerate(dataloader):
|
| 608 |
+
if max_batches and batch_idx >= max_batches:
|
| 609 |
+
break
|
| 610 |
+
cache["teacher"] = None
|
| 611 |
+
cache["fused"] = None
|
| 612 |
+
|
| 613 |
+
input_ids = batch[0].to(device)
|
| 614 |
+
attention_mask = batch[1].to(device) if len(batch) > 1 else None
|
| 615 |
+
|
| 616 |
+
model.zero_grad(set_to_none=True)
|
| 617 |
+
outputs = model(
|
| 618 |
+
input_ids=input_ids,
|
| 619 |
+
attention_mask=attention_mask,
|
| 620 |
+
labels=input_ids,
|
| 621 |
+
)
|
| 622 |
+
loss = outputs.loss
|
| 623 |
+
loss.backward()
|
| 624 |
+
|
| 625 |
+
teacher = cache["teacher"]
|
| 626 |
+
fused = cache["fused"]
|
| 627 |
+
grad = None if teacher is None else teacher.grad
|
| 628 |
+
if teacher is None or fused is None or grad is None:
|
| 629 |
+
raise RuntimeError(
|
| 630 |
+
"Auto selection hooks failed to capture outputs/gradients. "
|
| 631 |
+
"Try updating PyTorch or run with --layer <index>."
|
| 632 |
+
)
|
| 633 |
+
if not teacher.requires_grad:
|
| 634 |
+
raise RuntimeError(
|
| 635 |
+
"Teacher hidden state does not require grad. "
|
| 636 |
+
"Ensure model parameters require grad for DWCE."
|
| 637 |
+
)
|
| 638 |
+
|
| 639 |
+
with torch.no_grad():
|
| 640 |
+
if attention_mask is not None:
|
| 641 |
+
mask = attention_mask.to(dtype=torch.float32).unsqueeze(-1)
|
| 642 |
+
batch_tokens = float(mask.sum().item())
|
| 643 |
+
else:
|
| 644 |
+
mask = None
|
| 645 |
+
batch_tokens = float(input_ids.numel())
|
| 646 |
+
token_count += batch_tokens
|
| 647 |
+
|
| 648 |
+
delta = fused - teacher
|
| 649 |
+
grad_sq = grad.pow(2)
|
| 650 |
+
if mask is not None:
|
| 651 |
+
delta = delta * mask
|
| 652 |
+
grad_sq = grad_sq * mask
|
| 653 |
+
|
| 654 |
+
score_num += (delta.pow(2) * grad_sq).sum().item()
|
| 655 |
+
score_den += (teacher.pow(2) * grad_sq).sum().item()
|
| 656 |
+
num_batches += 1
|
| 657 |
+
|
| 658 |
+
handle_a.remove()
|
| 659 |
+
handle_b.remove()
|
| 660 |
+
|
| 661 |
+
if norm == "relative":
|
| 662 |
+
score = score_num / (score_den + eps)
|
| 663 |
+
else:
|
| 664 |
+
denom = token_count if token_count > 0 else 1.0
|
| 665 |
+
score = score_num / denom
|
| 666 |
+
|
| 667 |
+
meta = {
|
| 668 |
+
"num_batches": num_batches,
|
| 669 |
+
"token_count": token_count,
|
| 670 |
+
"norm": norm,
|
| 671 |
+
"supports_kwargs": supports_kwargs,
|
| 672 |
+
}
|
| 673 |
+
return score, meta
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
def _compute_cosine_for_pair(
|
| 677 |
+
model,
|
| 678 |
+
layer_a: torch.nn.Module,
|
| 679 |
+
layer_b: torch.nn.Module,
|
| 680 |
+
dataloader,
|
| 681 |
+
device: str,
|
| 682 |
+
max_batches: int,
|
| 683 |
+
eps: float,
|
| 684 |
+
) -> Tuple[float, Dict[str, object]]:
|
| 685 |
+
cache = {"a": None, "b": None}
|
| 686 |
+
supports_kwargs = True
|
| 687 |
+
|
| 688 |
+
def hook_a(_module, _inputs, output, _kwargs=None):
|
| 689 |
+
hidden = _extract_hidden(output)
|
| 690 |
+
if hidden is None:
|
| 691 |
+
raise RuntimeError("Failed to extract layer_a hidden state output.")
|
| 692 |
+
cache["a"] = hidden
|
| 693 |
+
return output
|
| 694 |
+
|
| 695 |
+
def hook_b(_module, _inputs, output, _kwargs=None):
|
| 696 |
+
hidden = _extract_hidden(output)
|
| 697 |
+
if hidden is None:
|
| 698 |
+
raise RuntimeError("Failed to extract layer_b hidden state output.")
|
| 699 |
+
cache["b"] = hidden
|
| 700 |
+
return output
|
| 701 |
+
|
| 702 |
+
handle_a, has_kwargs_a = _register_forward_hook(layer_a, hook_a)
|
| 703 |
+
handle_b, has_kwargs_b = _register_forward_hook(layer_b, hook_b)
|
| 704 |
+
supports_kwargs = has_kwargs_a and has_kwargs_b
|
| 705 |
+
|
| 706 |
+
score_sum = 0.0
|
| 707 |
+
token_count = 0.0
|
| 708 |
+
num_batches = 0
|
| 709 |
+
|
| 710 |
+
model.eval()
|
| 711 |
+
for batch_idx, batch in enumerate(dataloader):
|
| 712 |
+
if max_batches and batch_idx >= max_batches:
|
| 713 |
+
break
|
| 714 |
+
cache["a"] = None
|
| 715 |
+
cache["b"] = None
|
| 716 |
+
|
| 717 |
+
input_ids = batch[0].to(device)
|
| 718 |
+
attention_mask = batch[1].to(device) if len(batch) > 1 else None
|
| 719 |
+
|
| 720 |
+
with torch.no_grad():
|
| 721 |
+
model(
|
| 722 |
+
input_ids=input_ids,
|
| 723 |
+
attention_mask=attention_mask,
|
| 724 |
+
use_cache=False,
|
| 725 |
+
)
|
| 726 |
+
|
| 727 |
+
hidden_a = cache["a"]
|
| 728 |
+
hidden_b = cache["b"]
|
| 729 |
+
if hidden_a is None or hidden_b is None:
|
| 730 |
+
raise RuntimeError(
|
| 731 |
+
"Auto selection hooks failed to capture outputs for cosine scoring."
|
| 732 |
+
)
|
| 733 |
+
|
| 734 |
+
with torch.no_grad():
|
| 735 |
+
a = hidden_a.float()
|
| 736 |
+
b = hidden_b.float()
|
| 737 |
+
cos = F.cosine_similarity(a, b, dim=-1, eps=eps)
|
| 738 |
+
distance = 1.0 - cos
|
| 739 |
+
|
| 740 |
+
if attention_mask is not None:
|
| 741 |
+
mask = attention_mask.to(dtype=torch.float32)
|
| 742 |
+
batch_tokens = float(mask.sum().item())
|
| 743 |
+
distance = distance * mask
|
| 744 |
+
else:
|
| 745 |
+
batch_tokens = float(distance.numel())
|
| 746 |
+
|
| 747 |
+
token_count += batch_tokens
|
| 748 |
+
score_sum += float(distance.sum().item())
|
| 749 |
+
num_batches += 1
|
| 750 |
+
|
| 751 |
+
handle_a.remove()
|
| 752 |
+
handle_b.remove()
|
| 753 |
+
|
| 754 |
+
denom = token_count if token_count > 0 else 1.0
|
| 755 |
+
score = score_sum / denom
|
| 756 |
+
meta = {
|
| 757 |
+
"num_batches": num_batches,
|
| 758 |
+
"token_count": token_count,
|
| 759 |
+
"metric": "cosine",
|
| 760 |
+
"supports_kwargs": supports_kwargs,
|
| 761 |
+
}
|
| 762 |
+
return score, meta
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
def _compute_global_rel_change_for_pair(
|
| 766 |
+
model,
|
| 767 |
+
layers: List[torch.nn.Module],
|
| 768 |
+
pair_idx: int,
|
| 769 |
+
dataloader,
|
| 770 |
+
args,
|
| 771 |
+
max_batches: int,
|
| 772 |
+
eps: float,
|
| 773 |
+
) -> Tuple[float, Dict[str, object]]:
|
| 774 |
+
hidden_size = _get_hidden_size(model)
|
| 775 |
+
head_permute_select = not bool(getattr(args, "no_head_permute_select", False))
|
| 776 |
+
layer_a = layers[pair_idx]
|
| 777 |
+
layer_b = layers[pair_idx + 1]
|
| 778 |
+
fused_layer, fuse_priors = _build_fused_layer_for_pair(
|
| 779 |
+
model,
|
| 780 |
+
layer_a,
|
| 781 |
+
layer_b,
|
| 782 |
+
dataloader,
|
| 783 |
+
device=args.device,
|
| 784 |
+
fisher_mode=args.fisher_mode,
|
| 785 |
+
eps=eps,
|
| 786 |
+
hidden_size=hidden_size,
|
| 787 |
+
enable_head_permute=head_permute_select,
|
| 788 |
+
)
|
| 789 |
+
fused_layer.to(args.device)
|
| 790 |
+
fused_layer.eval()
|
| 791 |
+
|
| 792 |
+
parent, name, container = find_layer_container(model, getattr(args, "layer_path", None))
|
| 793 |
+
if len(list(container)) != len(layers):
|
| 794 |
+
raise RuntimeError("Layer container changed during auto-selection; aborting rerank.")
|
| 795 |
+
|
| 796 |
+
virtual_layers = list(layers)
|
| 797 |
+
virtual_layers[pair_idx] = fused_layer
|
| 798 |
+
del virtual_layers[pair_idx + 1]
|
| 799 |
+
if isinstance(container, torch.nn.ModuleList):
|
| 800 |
+
virtual_container = torch.nn.ModuleList(virtual_layers)
|
| 801 |
+
elif isinstance(container, list):
|
| 802 |
+
virtual_container = virtual_layers
|
| 803 |
+
else:
|
| 804 |
+
raise TypeError("Layer container must be ModuleList or list")
|
| 805 |
+
|
| 806 |
+
teacher_cache = {"pair": None, "final": None}
|
| 807 |
+
supports_kwargs = True
|
| 808 |
+
|
| 809 |
+
def hook_pair(_module, _inputs, output, _kwargs=None):
|
| 810 |
+
hidden = _extract_hidden(output)
|
| 811 |
+
if hidden is None:
|
| 812 |
+
raise RuntimeError("Failed to extract pair output for global relation rerank.")
|
| 813 |
+
teacher_cache["pair"] = hidden
|
| 814 |
+
return output
|
| 815 |
+
|
| 816 |
+
handle_pair, has_kwargs_pair = _register_forward_hook(layer_b, hook_pair)
|
| 817 |
+
supports_kwargs = supports_kwargs and has_kwargs_pair
|
| 818 |
+
|
| 819 |
+
score_sum = 0.0
|
| 820 |
+
token_count = 0.0
|
| 821 |
+
num_batches = 0
|
| 822 |
+
|
| 823 |
+
model.eval()
|
| 824 |
+
for batch_idx, batch in enumerate(dataloader):
|
| 825 |
+
if max_batches and batch_idx >= max_batches:
|
| 826 |
+
break
|
| 827 |
+
|
| 828 |
+
teacher_cache["pair"] = None
|
| 829 |
+
|
| 830 |
+
input_ids = batch[0].to(args.device)
|
| 831 |
+
attention_mask = batch[1].to(args.device) if len(batch) > 1 else None
|
| 832 |
+
|
| 833 |
+
with torch.no_grad():
|
| 834 |
+
teacher_outputs = model(
|
| 835 |
+
input_ids=input_ids,
|
| 836 |
+
attention_mask=attention_mask,
|
| 837 |
+
output_hidden_states=True,
|
| 838 |
+
use_cache=False,
|
| 839 |
+
)
|
| 840 |
+
teacher_hidden_states = getattr(teacher_outputs, "hidden_states", None)
|
| 841 |
+
if not teacher_hidden_states:
|
| 842 |
+
raise RuntimeError("Teacher forward did not return hidden_states.")
|
| 843 |
+
teacher_final = teacher_hidden_states[-1]
|
| 844 |
+
teacher_pair = teacher_cache["pair"]
|
| 845 |
+
|
| 846 |
+
if teacher_pair is None or teacher_final is None:
|
| 847 |
+
raise RuntimeError(
|
| 848 |
+
"Failed to capture teacher pair/final hidden states for global rerank."
|
| 849 |
+
)
|
| 850 |
+
|
| 851 |
+
with torch.no_grad(), _temporary_layers(parent, name, virtual_container):
|
| 852 |
+
fused_outputs = model(
|
| 853 |
+
input_ids=input_ids,
|
| 854 |
+
attention_mask=attention_mask,
|
| 855 |
+
output_hidden_states=True,
|
| 856 |
+
use_cache=False,
|
| 857 |
+
)
|
| 858 |
+
fused_hidden_states = getattr(fused_outputs, "hidden_states", None)
|
| 859 |
+
if not fused_hidden_states:
|
| 860 |
+
raise RuntimeError("Fused forward did not return hidden_states.")
|
| 861 |
+
fused_final = fused_hidden_states[-1]
|
| 862 |
+
|
| 863 |
+
if fused_final is None:
|
| 864 |
+
raise RuntimeError("Failed to capture fused final hidden state for global rerank.")
|
| 865 |
+
|
| 866 |
+
with torch.no_grad():
|
| 867 |
+
teacher_pair_f = teacher_pair.float()
|
| 868 |
+
teacher_final_f = teacher_final.float()
|
| 869 |
+
fused_final_f = fused_final.float()
|
| 870 |
+
|
| 871 |
+
teacher_rel = F.cosine_similarity(
|
| 872 |
+
teacher_pair_f, teacher_final_f, dim=-1, eps=eps
|
| 873 |
+
)
|
| 874 |
+
fused_rel = F.cosine_similarity(
|
| 875 |
+
teacher_pair_f, fused_final_f, dim=-1, eps=eps
|
| 876 |
+
)
|
| 877 |
+
rel_change = (teacher_rel - fused_rel).abs()
|
| 878 |
+
|
| 879 |
+
if attention_mask is not None:
|
| 880 |
+
mask = attention_mask.to(dtype=torch.float32)
|
| 881 |
+
batch_tokens = float(mask.sum().item())
|
| 882 |
+
rel_change = rel_change * mask
|
| 883 |
+
else:
|
| 884 |
+
batch_tokens = float(rel_change.numel())
|
| 885 |
+
|
| 886 |
+
token_count += batch_tokens
|
| 887 |
+
score_sum += float(rel_change.sum().item())
|
| 888 |
+
num_batches += 1
|
| 889 |
+
|
| 890 |
+
handle_pair.remove()
|
| 891 |
+
del fused_layer
|
| 892 |
+
if torch.cuda.is_available():
|
| 893 |
+
torch.cuda.empty_cache()
|
| 894 |
+
|
| 895 |
+
denom = token_count if token_count > 0 else 1.0
|
| 896 |
+
score = score_sum / denom
|
| 897 |
+
meta = {
|
| 898 |
+
"num_batches": num_batches,
|
| 899 |
+
"token_count": token_count,
|
| 900 |
+
"metric": "global_rel_change",
|
| 901 |
+
"supports_kwargs": supports_kwargs,
|
| 902 |
+
"fuse_priors": fuse_priors,
|
| 903 |
+
}
|
| 904 |
+
return score, meta
|
| 905 |
+
|
| 906 |
+
|
| 907 |
+
def select_layer_auto(
|
| 908 |
+
model,
|
| 909 |
+
layers: List[torch.nn.Module],
|
| 910 |
+
dataloader,
|
| 911 |
+
args,
|
| 912 |
+
previous_scores: Optional[List[float]] = None,
|
| 913 |
+
start_index: int = 0,
|
| 914 |
+
exclude_pairs: Optional[Set[int]] = None,
|
| 915 |
+
) -> Tuple[int, List[float], Dict[str, object]]:
|
| 916 |
+
num_layers = len(layers)
|
| 917 |
+
if num_layers < 2:
|
| 918 |
+
raise SystemExit("Model must have at least 2 layers for auto selection.")
|
| 919 |
+
|
| 920 |
+
hidden_size = _get_hidden_size(model)
|
| 921 |
+
num_pairs = num_layers - 1
|
| 922 |
+
scores: List[float] = [float("inf")] * num_pairs
|
| 923 |
+
meta_per_pair: List[Optional[Dict[str, object]]] = [None] * num_pairs
|
| 924 |
+
supports_kwargs_all = True
|
| 925 |
+
head_permute_select = not bool(getattr(args, "no_head_permute_select", False))
|
| 926 |
+
exclude_set: Set[int] = {
|
| 927 |
+
int(idx)
|
| 928 |
+
for idx in (exclude_pairs or set())
|
| 929 |
+
if isinstance(idx, int) and 0 <= int(idx) < num_pairs
|
| 930 |
+
}
|
| 931 |
+
|
| 932 |
+
max_batches = args.auto_max_batches
|
| 933 |
+
start_index = max(0, min(start_index, num_pairs))
|
| 934 |
+
auto_metric = str(getattr(args, "auto_metric", "dwce")).strip().lower()
|
| 935 |
+
if auto_metric == "hybrid":
|
| 936 |
+
auto_metric = "hybrid_cosine"
|
| 937 |
+
if auto_metric not in {
|
| 938 |
+
"dwce",
|
| 939 |
+
"cosine",
|
| 940 |
+
"hybrid_cosine",
|
| 941 |
+
"hybrid_global_rel",
|
| 942 |
+
}:
|
| 943 |
+
raise SystemExit(
|
| 944 |
+
"--auto_metric must be one of: dwce, cosine, hybrid, "
|
| 945 |
+
"hybrid_cosine, hybrid_global_rel"
|
| 946 |
+
)
|
| 947 |
+
auto_cosine_topk = int(getattr(args, "auto_cosine_topk", 3))
|
| 948 |
+
if auto_cosine_topk <= 0:
|
| 949 |
+
raise SystemExit("--auto_cosine_topk must be >= 1")
|
| 950 |
+
print(
|
| 951 |
+
f"[auto] metric={auto_metric}; using "
|
| 952 |
+
f"{('all' if max_batches == 0 else max_batches)} batches "
|
| 953 |
+
"from calibration samples."
|
| 954 |
+
)
|
| 955 |
+
|
| 956 |
+
reuse_upto = 0
|
| 957 |
+
allow_reuse = auto_metric == "dwce"
|
| 958 |
+
if previous_scores:
|
| 959 |
+
reuse_upto = min(start_index, len(previous_scores), num_pairs) if allow_reuse else 0
|
| 960 |
+
for idx in range(reuse_upto):
|
| 961 |
+
if idx in exclude_set:
|
| 962 |
+
scores[idx] = float("inf")
|
| 963 |
+
meta_per_pair[idx] = {"excluded": True}
|
| 964 |
+
print(f"[auto] skipped excluded pair {idx}-{idx+1}.")
|
| 965 |
+
continue
|
| 966 |
+
scores[idx] = previous_scores[idx]
|
| 967 |
+
meta_per_pair[idx] = (
|
| 968 |
+
{
|
| 969 |
+
"num_batches": 0,
|
| 970 |
+
"token_count": 0.0,
|
| 971 |
+
"norm": args.auto_norm,
|
| 972 |
+
"metric": auto_metric,
|
| 973 |
+
"supports_kwargs": True,
|
| 974 |
+
"reused": True,
|
| 975 |
+
}
|
| 976 |
+
)
|
| 977 |
+
print(f"[auto] reused pair {idx}-{idx+1}: {scores[idx]:.6e}")
|
| 978 |
+
|
| 979 |
+
compute_start = start_index if reuse_upto == start_index else reuse_upto
|
| 980 |
+
pairs_to_score: List[int] = []
|
| 981 |
+
for idx in range(compute_start, num_pairs):
|
| 982 |
+
if idx in exclude_set:
|
| 983 |
+
scores[idx] = float("inf")
|
| 984 |
+
meta_per_pair[idx] = {"excluded": True}
|
| 985 |
+
print(f"[auto] skipped excluded pair {idx}-{idx+1}.")
|
| 986 |
+
continue
|
| 987 |
+
pairs_to_score.append(idx)
|
| 988 |
+
|
| 989 |
+
def _score_dwce_for_pair(idx: int) -> Tuple[float, Dict[str, object]]:
|
| 990 |
+
print(f"[auto] building fused pair {idx}-{idx+1} for DWCE...")
|
| 991 |
+
layer_a = layers[idx]
|
| 992 |
+
layer_b = layers[idx + 1]
|
| 993 |
+
dwce_mode = str(getattr(args, "auto_dwce_mode", "separate")).strip().lower()
|
| 994 |
+
if dwce_mode == "shared":
|
| 995 |
+
try:
|
| 996 |
+
return _score_dwce_with_shared_backward(
|
| 997 |
+
model,
|
| 998 |
+
layer_a,
|
| 999 |
+
layer_b,
|
| 1000 |
+
dataloader,
|
| 1001 |
+
device=args.device,
|
| 1002 |
+
fisher_mode=args.fisher_mode,
|
| 1003 |
+
max_batches=max_batches,
|
| 1004 |
+
eps=args.eps,
|
| 1005 |
+
norm=args.auto_norm,
|
| 1006 |
+
hidden_size=hidden_size,
|
| 1007 |
+
enable_head_permute=head_permute_select,
|
| 1008 |
+
)
|
| 1009 |
+
except _DwceGradCacheOverflow:
|
| 1010 |
+
print(
|
| 1011 |
+
"[auto] shared-backward DWCE cache exceeded budget; "
|
| 1012 |
+
"falling back to separate mode."
|
| 1013 |
+
)
|
| 1014 |
+
fused, fuse_priors = _build_fused_layer_for_pair(
|
| 1015 |
+
model,
|
| 1016 |
+
layer_a,
|
| 1017 |
+
layer_b,
|
| 1018 |
+
dataloader,
|
| 1019 |
+
device=args.device,
|
| 1020 |
+
fisher_mode=args.fisher_mode,
|
| 1021 |
+
eps=args.eps,
|
| 1022 |
+
hidden_size=hidden_size,
|
| 1023 |
+
enable_head_permute=head_permute_select,
|
| 1024 |
+
)
|
| 1025 |
+
fused.to(args.device)
|
| 1026 |
+
fused.eval()
|
| 1027 |
+
for param in model.parameters():
|
| 1028 |
+
param.requires_grad_(True)
|
| 1029 |
+
score, meta = _compute_dwce_for_pair(
|
| 1030 |
+
model,
|
| 1031 |
+
layer_a,
|
| 1032 |
+
layer_b,
|
| 1033 |
+
fused,
|
| 1034 |
+
dataloader,
|
| 1035 |
+
device=args.device,
|
| 1036 |
+
max_batches=max_batches,
|
| 1037 |
+
eps=args.eps,
|
| 1038 |
+
norm=args.auto_norm,
|
| 1039 |
+
)
|
| 1040 |
+
meta["fuse_priors"] = fuse_priors
|
| 1041 |
+
meta["metric"] = "dwce"
|
| 1042 |
+
del fused
|
| 1043 |
+
if torch.cuda.is_available():
|
| 1044 |
+
torch.cuda.empty_cache()
|
| 1045 |
+
return score, meta
|
| 1046 |
+
|
| 1047 |
+
def _score_cosine_for_pair(idx: int) -> Tuple[float, Dict[str, object]]:
|
| 1048 |
+
print(f"[auto] scoring cosine for pair {idx}-{idx+1}...")
|
| 1049 |
+
layer_a = layers[idx]
|
| 1050 |
+
layer_b = layers[idx + 1]
|
| 1051 |
+
return _compute_cosine_for_pair(
|
| 1052 |
+
model,
|
| 1053 |
+
layer_a,
|
| 1054 |
+
layer_b,
|
| 1055 |
+
dataloader,
|
| 1056 |
+
device=args.device,
|
| 1057 |
+
max_batches=max_batches,
|
| 1058 |
+
eps=args.eps,
|
| 1059 |
+
)
|
| 1060 |
+
|
| 1061 |
+
def _score_global_rel_for_pair(idx: int) -> Tuple[float, Dict[str, object]]:
|
| 1062 |
+
print(f"[auto] scoring global relation change for pair {idx}-{idx+1}...")
|
| 1063 |
+
return _compute_global_rel_change_for_pair(
|
| 1064 |
+
model,
|
| 1065 |
+
layers,
|
| 1066 |
+
idx,
|
| 1067 |
+
dataloader,
|
| 1068 |
+
args=args,
|
| 1069 |
+
max_batches=max_batches,
|
| 1070 |
+
eps=args.eps,
|
| 1071 |
+
)
|
| 1072 |
+
|
| 1073 |
+
if auto_metric in {"dwce", "cosine"}:
|
| 1074 |
+
for idx in pairs_to_score:
|
| 1075 |
+
if auto_metric == "dwce":
|
| 1076 |
+
score, meta = _score_dwce_for_pair(idx)
|
| 1077 |
+
else:
|
| 1078 |
+
score, meta = _score_cosine_for_pair(idx)
|
| 1079 |
+
supports_kwargs_all = supports_kwargs_all and meta.get("supports_kwargs", True)
|
| 1080 |
+
scores[idx] = score
|
| 1081 |
+
meta_per_pair[idx] = meta
|
| 1082 |
+
print(f"[auto] {auto_metric} pair {idx}-{idx+1}: {score:.6e}")
|
| 1083 |
+
else:
|
| 1084 |
+
dwce_prefilter: Dict[int, float] = {}
|
| 1085 |
+
for idx in pairs_to_score:
|
| 1086 |
+
score, meta = _score_dwce_for_pair(idx)
|
| 1087 |
+
dwce_prefilter[idx] = score
|
| 1088 |
+
supports_kwargs_all = supports_kwargs_all and meta.get("supports_kwargs", True)
|
| 1089 |
+
meta_per_pair[idx] = {
|
| 1090 |
+
"prefilter_dwce": score,
|
| 1091 |
+
"dwce_meta": meta,
|
| 1092 |
+
"metric": "hybrid",
|
| 1093 |
+
}
|
| 1094 |
+
print(f"[auto] hybrid prefilter DWCE pair {idx}-{idx+1}: {score:.6e}")
|
| 1095 |
+
ranked = sorted(pairs_to_score, key=lambda i: float(dwce_prefilter[i]))
|
| 1096 |
+
shortlist = ranked[: min(auto_cosine_topk, len(ranked))]
|
| 1097 |
+
print(f"[auto] hybrid shortlist (dwce top-{len(shortlist)}): {shortlist}")
|
| 1098 |
+
for idx in shortlist:
|
| 1099 |
+
if auto_metric == "hybrid_global_rel":
|
| 1100 |
+
score, rerank_meta = _score_global_rel_for_pair(idx)
|
| 1101 |
+
score_metric = "global_rel_change"
|
| 1102 |
+
else:
|
| 1103 |
+
score, rerank_meta = _score_cosine_for_pair(idx)
|
| 1104 |
+
score_metric = "cosine"
|
| 1105 |
+
supports_kwargs_all = supports_kwargs_all and rerank_meta.get(
|
| 1106 |
+
"supports_kwargs", True
|
| 1107 |
+
)
|
| 1108 |
+
scores[idx] = score
|
| 1109 |
+
pair_meta = meta_per_pair[idx] or {}
|
| 1110 |
+
pair_meta["rerank_meta"] = rerank_meta
|
| 1111 |
+
pair_meta["score_metric"] = score_metric
|
| 1112 |
+
meta_per_pair[idx] = pair_meta
|
| 1113 |
+
print(f"[auto] hybrid {score_metric} pair {idx}-{idx+1}: {score:.6e}")
|
| 1114 |
+
|
| 1115 |
+
if not supports_kwargs_all:
|
| 1116 |
+
print(
|
| 1117 |
+
"[auto] Warning: forward hooks did not capture kwargs; "
|
| 1118 |
+
"fused-layer calls may be approximate."
|
| 1119 |
+
)
|
| 1120 |
+
|
| 1121 |
+
print(f"[auto] score summary (metric={auto_metric}, norm={args.auto_norm}):")
|
| 1122 |
+
for idx, score in enumerate(scores):
|
| 1123 |
+
if idx in exclude_set:
|
| 1124 |
+
print(f"[auto] pair {idx}-{idx+1}: excluded")
|
| 1125 |
+
elif math.isfinite(float(score)):
|
| 1126 |
+
print(f"[auto] pair {idx}-{idx+1}: {score:.6e}")
|
| 1127 |
+
else:
|
| 1128 |
+
print(f"[auto] pair {idx}-{idx+1}: {score}")
|
| 1129 |
+
|
| 1130 |
+
candidates = [i for i in range(num_pairs) if i not in exclude_set]
|
| 1131 |
+
if not candidates:
|
| 1132 |
+
raise SystemExit("All pairs are excluded; cannot auto-select a fusion layer.")
|
| 1133 |
+
best_idx = min(candidates, key=lambda i: scores[i])
|
| 1134 |
+
best_score = float(scores[best_idx])
|
| 1135 |
+
if not math.isfinite(best_score):
|
| 1136 |
+
raise SystemExit(
|
| 1137 |
+
"Auto selection failed: all candidate pairs have non-finite scores "
|
| 1138 |
+
"(check --exclude_pairs and data)."
|
| 1139 |
+
)
|
| 1140 |
+
print(f"[auto] Selected layer {best_idx} (score={best_score:.6e})")
|
| 1141 |
+
|
| 1142 |
+
meta = {
|
| 1143 |
+
"per_pair": meta_per_pair,
|
| 1144 |
+
"supports_kwargs": supports_kwargs_all,
|
| 1145 |
+
"max_batches": max_batches,
|
| 1146 |
+
"norm": args.auto_norm,
|
| 1147 |
+
"metric": auto_metric,
|
| 1148 |
+
"cosine_topk": auto_cosine_topk,
|
| 1149 |
+
"start_index": start_index,
|
| 1150 |
+
"excluded_pairs": sorted(exclude_set),
|
| 1151 |
+
}
|
| 1152 |
+
return best_idx, scores, meta
|
src/loratune.py
ADDED
|
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Centralized Alpaca LoRA finetuning for post-pruned models."""
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import itertools
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from types import SimpleNamespace
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from contextlib import nullcontext
|
| 13 |
+
|
| 14 |
+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig
|
| 15 |
+
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
|
| 16 |
+
import ppl_eval
|
| 17 |
+
|
| 18 |
+
from fuse_layers_data import FixedSeqDataset, load_instruction_records
|
| 19 |
+
from fuse_layers_distill import LoRALinear, apply_lora_adapters, merge_lora_adapters
|
| 20 |
+
|
| 21 |
+
try:
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
except Exception: # pragma: no cover
|
| 24 |
+
tqdm = None
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def parse_args() -> argparse.Namespace:
|
| 28 |
+
parser = argparse.ArgumentParser(description="Run centralized Alpaca LoRA finetuning.")
|
| 29 |
+
parser.add_argument("--base_model", required=True, help="Path or HF model id to finetune")
|
| 30 |
+
parser.add_argument("--output_dir", required=True, help="Directory to save merged model")
|
| 31 |
+
parser.add_argument("--device", default="cuda", help="Training device")
|
| 32 |
+
parser.add_argument(
|
| 33 |
+
"--dtype",
|
| 34 |
+
default="bfloat16",
|
| 35 |
+
choices=["float32", "float16", "bfloat16"],
|
| 36 |
+
help="Model load/training dtype",
|
| 37 |
+
)
|
| 38 |
+
parser.add_argument("--trust_remote_code", action="store_true", help="Enable trust_remote_code")
|
| 39 |
+
parser.add_argument("--seed", type=int, default=42, help="Random seed")
|
| 40 |
+
|
| 41 |
+
parser.add_argument(
|
| 42 |
+
"--instruction_dataset",
|
| 43 |
+
default="yahma/alpaca-cleaned",
|
| 44 |
+
help="HF dataset name for Alpaca-style instruction data",
|
| 45 |
+
)
|
| 46 |
+
parser.add_argument("--instruction_config", default=None, help="Optional dataset config")
|
| 47 |
+
parser.add_argument("--instruction_split", default="train", help="Dataset split")
|
| 48 |
+
parser.add_argument("--instruction_field_instruction", default="instruction")
|
| 49 |
+
parser.add_argument("--instruction_field_input", default="input")
|
| 50 |
+
parser.add_argument("--instruction_field_output", default="output")
|
| 51 |
+
parser.add_argument("--max_samples", type=int, default=0, help="Limit instruction samples (0 = all)")
|
| 52 |
+
parser.add_argument("--seq_len", type=int, default=1024, help="Training sequence length")
|
| 53 |
+
parser.add_argument("--batch_size", type=int, default=64, help="Global batch size")
|
| 54 |
+
parser.add_argument("--micro_batch_size", type=int, default=4, help="Per-step micro-batch size")
|
| 55 |
+
parser.add_argument("--epochs", type=float, default=1.0, help="Training epochs")
|
| 56 |
+
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
|
| 57 |
+
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay")
|
| 58 |
+
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Gradient clipping norm")
|
| 59 |
+
parser.add_argument("--log_steps", type=int, default=100, help="Log every N optimizer steps")
|
| 60 |
+
parser.add_argument(
|
| 61 |
+
"--save_steps",
|
| 62 |
+
type=int,
|
| 63 |
+
default=200,
|
| 64 |
+
help="Save LoRA adapter checkpoints every N optimizer steps (0 = disable)",
|
| 65 |
+
)
|
| 66 |
+
parser.add_argument(
|
| 67 |
+
"--no_wikitext2_ppl_on_log",
|
| 68 |
+
dest="wikitext2_ppl_on_log",
|
| 69 |
+
action="store_false",
|
| 70 |
+
help="Disable Wikitext-2 perplexity evaluation at loss log steps",
|
| 71 |
+
)
|
| 72 |
+
parser.set_defaults(wikitext2_ppl_on_log=True)
|
| 73 |
+
parser.add_argument("--wikitext2_ppl_seq_len", type=int, default=128)
|
| 74 |
+
parser.add_argument("--wikitext2_ppl_batch_size", type=int, default=8)
|
| 75 |
+
parser.add_argument("--wikitext2_ppl_max_batches", type=int, default=None)
|
| 76 |
+
|
| 77 |
+
parser.add_argument("--lora_rank", type=int, default=8, help="LoRA rank")
|
| 78 |
+
parser.add_argument("--lora_alpha", type=float, default=16.0, help="LoRA alpha")
|
| 79 |
+
parser.add_argument("--lora_dropout", type=float, default=0.0, help="LoRA dropout")
|
| 80 |
+
parser.add_argument(
|
| 81 |
+
"--lora_target_modules",
|
| 82 |
+
nargs="*",
|
| 83 |
+
default=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"],
|
| 84 |
+
help="Linear module suffixes to LoRA-wrap",
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
return parser.parse_args()
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def get_dtype(name: str) -> torch.dtype:
|
| 91 |
+
return {
|
| 92 |
+
"float32": torch.float32,
|
| 93 |
+
"float16": torch.float16,
|
| 94 |
+
"bfloat16": torch.bfloat16,
|
| 95 |
+
}[name]
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def seed_all(seed: int) -> None:
|
| 99 |
+
torch.manual_seed(seed)
|
| 100 |
+
if torch.cuda.is_available():
|
| 101 |
+
torch.cuda.manual_seed_all(seed)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def normalize_config(config):
|
| 105 |
+
layer_types = getattr(config, "layer_types", None)
|
| 106 |
+
num_hidden_layers = getattr(config, "num_hidden_layers", None)
|
| 107 |
+
if layer_types is not None and num_hidden_layers is not None and len(layer_types) != num_hidden_layers:
|
| 108 |
+
config.layer_types = list(layer_types[:num_hidden_layers])
|
| 109 |
+
if getattr(config, "_attn_implementation", None) is None:
|
| 110 |
+
config._attn_implementation = "eager"
|
| 111 |
+
return config
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def load_normalized_config(base_model: str, trust_remote_code: bool):
|
| 115 |
+
config_dict, unused_kwargs = PretrainedConfig.get_config_dict(base_model, trust_remote_code=trust_remote_code)
|
| 116 |
+
layer_types = config_dict.get("layer_types")
|
| 117 |
+
num_hidden_layers = config_dict.get("num_hidden_layers")
|
| 118 |
+
if layer_types is not None and num_hidden_layers is not None and len(layer_types) != num_hidden_layers:
|
| 119 |
+
config_dict["layer_types"] = list(layer_types[:num_hidden_layers])
|
| 120 |
+
if config_dict.get("_attn_implementation") is None:
|
| 121 |
+
config_dict["_attn_implementation"] = "eager"
|
| 122 |
+
model_type = config_dict["model_type"]
|
| 123 |
+
config_class = CONFIG_MAPPING[model_type]
|
| 124 |
+
config = config_class.from_dict(config_dict, **unused_kwargs)
|
| 125 |
+
return normalize_config(config)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def validate_local_model_dir(base_path: Path) -> None:
|
| 129 |
+
if not base_path.exists() or not base_path.is_dir():
|
| 130 |
+
return
|
| 131 |
+
|
| 132 |
+
has_config = (base_path / "config.json").is_file()
|
| 133 |
+
has_weights = any(
|
| 134 |
+
(base_path / name).is_file()
|
| 135 |
+
for name in (
|
| 136 |
+
"model.safetensors",
|
| 137 |
+
"model.safetensors.index.json",
|
| 138 |
+
"pytorch_model.bin",
|
| 139 |
+
"pytorch_model.bin.index.json",
|
| 140 |
+
)
|
| 141 |
+
)
|
| 142 |
+
if has_config and has_weights:
|
| 143 |
+
return
|
| 144 |
+
|
| 145 |
+
raise SystemExit(
|
| 146 |
+
"Local --base_model points to an incomplete HF model directory: "
|
| 147 |
+
f"{base_path}. Expected at least config.json and model weights. "
|
| 148 |
+
"Set --base_model/BASE_MODEL to a saved HF model directory."
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def load_base_artifacts(args: argparse.Namespace):
|
| 153 |
+
base_path = Path(args.base_model)
|
| 154 |
+
if base_path.is_file() and base_path.suffix == ".bin":
|
| 155 |
+
checkpoint = torch.load(str(base_path), map_location="cpu", weights_only=False)
|
| 156 |
+
if not isinstance(checkpoint, dict) or "model" not in checkpoint or "tokenizer" not in checkpoint:
|
| 157 |
+
raise SystemExit("Expected a .bin checkpoint dict with `model` and `tokenizer` entries.")
|
| 158 |
+
model = checkpoint["model"]
|
| 159 |
+
tokenizer = checkpoint["tokenizer"]
|
| 160 |
+
if tokenizer.pad_token is None:
|
| 161 |
+
tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
|
| 162 |
+
return model, tokenizer
|
| 163 |
+
|
| 164 |
+
validate_local_model_dir(base_path)
|
| 165 |
+
tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=args.trust_remote_code)
|
| 166 |
+
if tokenizer.pad_token is None:
|
| 167 |
+
tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
|
| 168 |
+
config = load_normalized_config(args.base_model, trust_remote_code=args.trust_remote_code)
|
| 169 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 170 |
+
args.base_model,
|
| 171 |
+
config=config,
|
| 172 |
+
torch_dtype=get_dtype(args.dtype),
|
| 173 |
+
trust_remote_code=args.trust_remote_code,
|
| 174 |
+
)
|
| 175 |
+
return model, tokenizer
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def build_training_loader(args: argparse.Namespace, tokenizer) -> torch.utils.data.DataLoader:
|
| 179 |
+
num_samples = args.max_samples if args.max_samples > 0 else 0
|
| 180 |
+
records = load_instruction_records(args, num_samples)
|
| 181 |
+
if not records:
|
| 182 |
+
raise SystemExit("No instruction records were loaded.")
|
| 183 |
+
dataset = FixedSeqDataset(records, tokenizer, args.seq_len)
|
| 184 |
+
return torch.utils.data.DataLoader(dataset, batch_size=args.micro_batch_size, shuffle=True)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def save_lora_adapters(
|
| 188 |
+
model: torch.nn.Module, args: argparse.Namespace, subdir: str = "lora_adapter"
|
| 189 |
+
) -> str:
|
| 190 |
+
adapter_dir = os.path.join(args.output_dir, subdir)
|
| 191 |
+
os.makedirs(adapter_dir, exist_ok=True)
|
| 192 |
+
|
| 193 |
+
adapter_state = {}
|
| 194 |
+
adapter_modules = {}
|
| 195 |
+
for module_name, module in model.named_modules():
|
| 196 |
+
if not isinstance(module, LoRALinear):
|
| 197 |
+
continue
|
| 198 |
+
adapter_modules[module_name] = {
|
| 199 |
+
"rank": module.rank,
|
| 200 |
+
"alpha": module.alpha,
|
| 201 |
+
"scaling": module.scaling,
|
| 202 |
+
"dropout": getattr(module.dropout, "p", 0.0),
|
| 203 |
+
"base_layer_class": type(module.base).__name__,
|
| 204 |
+
"in_features": module.base.in_features,
|
| 205 |
+
"out_features": module.base.out_features,
|
| 206 |
+
}
|
| 207 |
+
adapter_state[f"{module_name}.lora_A.weight"] = module.lora_A.weight.detach().cpu()
|
| 208 |
+
adapter_state[f"{module_name}.lora_B.weight"] = module.lora_B.weight.detach().cpu()
|
| 209 |
+
|
| 210 |
+
torch.save(adapter_state, os.path.join(adapter_dir, "adapter_model.bin"))
|
| 211 |
+
with open(os.path.join(adapter_dir, "adapter_config.json"), "w", encoding="utf-8") as handle:
|
| 212 |
+
json.dump(
|
| 213 |
+
{
|
| 214 |
+
"base_model": args.base_model,
|
| 215 |
+
"lora_rank": args.lora_rank,
|
| 216 |
+
"lora_alpha": args.lora_alpha,
|
| 217 |
+
"lora_dropout": args.lora_dropout,
|
| 218 |
+
"lora_target_modules": list(args.lora_target_modules),
|
| 219 |
+
"batch_size": args.batch_size,
|
| 220 |
+
"micro_batch_size": args.micro_batch_size,
|
| 221 |
+
"grad_accum_steps": args.grad_accum_steps,
|
| 222 |
+
"modules": adapter_modules,
|
| 223 |
+
},
|
| 224 |
+
handle,
|
| 225 |
+
indent=2,
|
| 226 |
+
)
|
| 227 |
+
return adapter_dir
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def prepare_wikitext2_eval(args: argparse.Namespace, model, tokenizer):
|
| 231 |
+
if not args.wikitext2_ppl_on_log:
|
| 232 |
+
return None
|
| 233 |
+
return ppl_eval.prepare_ppl_dataloaders(
|
| 234 |
+
tokenizer=tokenizer,
|
| 235 |
+
datasets=["wikitext"],
|
| 236 |
+
configs=["wikitext-2-raw-v1"],
|
| 237 |
+
split="test",
|
| 238 |
+
text_field=None,
|
| 239 |
+
num_samples=0,
|
| 240 |
+
seq_len=args.wikitext2_ppl_seq_len,
|
| 241 |
+
batch_size=args.wikitext2_ppl_batch_size,
|
| 242 |
+
seed=args.seed,
|
| 243 |
+
shuffle=False,
|
| 244 |
+
model_family="auto",
|
| 245 |
+
add_bos="auto",
|
| 246 |
+
cache_dir=None,
|
| 247 |
+
num_workers=0,
|
| 248 |
+
model=model,
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def train(model: torch.nn.Module, dataloader, args: argparse.Namespace, wikitext2_eval_dataloaders=None) -> dict:
|
| 253 |
+
lora_args = SimpleNamespace(
|
| 254 |
+
lora_rank=args.lora_rank,
|
| 255 |
+
lora_alpha=args.lora_alpha,
|
| 256 |
+
lora_dropout=args.lora_dropout,
|
| 257 |
+
lora_target_modules=args.lora_target_modules,
|
| 258 |
+
lora_respect_exclude_pairs=False,
|
| 259 |
+
layer_path=None,
|
| 260 |
+
exclude_pairs=None,
|
| 261 |
+
)
|
| 262 |
+
lora_modules = apply_lora_adapters(model, lora_args)
|
| 263 |
+
lora_params = [param for module in lora_modules for param in module.lora_parameters()]
|
| 264 |
+
|
| 265 |
+
optimizer = torch.optim.AdamW(
|
| 266 |
+
lora_params,
|
| 267 |
+
lr=args.learning_rate,
|
| 268 |
+
weight_decay=args.weight_decay,
|
| 269 |
+
)
|
| 270 |
+
model.train()
|
| 271 |
+
|
| 272 |
+
device = torch.device(args.device)
|
| 273 |
+
device_type = device.type
|
| 274 |
+
amp_dtype = None
|
| 275 |
+
if args.dtype == "float16":
|
| 276 |
+
amp_dtype = torch.float16
|
| 277 |
+
elif args.dtype == "bfloat16":
|
| 278 |
+
amp_dtype = torch.bfloat16
|
| 279 |
+
use_amp = amp_dtype is not None and device_type == "cuda"
|
| 280 |
+
use_scaler = use_amp and amp_dtype == torch.float16
|
| 281 |
+
scaler = torch.cuda.amp.GradScaler() if use_scaler else None
|
| 282 |
+
|
| 283 |
+
full_epochs = int(args.epochs)
|
| 284 |
+
fractional = args.epochs - full_epochs
|
| 285 |
+
epoch_plan = [None] * full_epochs
|
| 286 |
+
if fractional > 1e-8:
|
| 287 |
+
frac_batches = max(1, int(round(fractional * len(dataloader))))
|
| 288 |
+
epoch_plan.append(frac_batches)
|
| 289 |
+
|
| 290 |
+
optimizer.zero_grad(set_to_none=True)
|
| 291 |
+
optimizer_step = 0
|
| 292 |
+
seen_batches = 0
|
| 293 |
+
last_loss = None
|
| 294 |
+
ppl_history = []
|
| 295 |
+
|
| 296 |
+
for epoch_idx, max_batches in enumerate(epoch_plan, start=1):
|
| 297 |
+
iterator = dataloader if max_batches is None else itertools.islice(dataloader, max_batches)
|
| 298 |
+
if tqdm is not None:
|
| 299 |
+
iterator = tqdm(iterator, desc=f"LoRA epoch {epoch_idx}", unit="batch", total=max_batches)
|
| 300 |
+
for batch in iterator:
|
| 301 |
+
input_ids = batch[0].to(args.device)
|
| 302 |
+
attention_mask = batch[1].to(args.device)
|
| 303 |
+
|
| 304 |
+
autocast_ctx = (
|
| 305 |
+
torch.autocast(device_type=device_type, dtype=amp_dtype)
|
| 306 |
+
if use_amp
|
| 307 |
+
else nullcontext()
|
| 308 |
+
)
|
| 309 |
+
with autocast_ctx:
|
| 310 |
+
outputs = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False)
|
| 311 |
+
logits = outputs.logits[:, :-1, :].contiguous()
|
| 312 |
+
labels = input_ids[:, 1:].contiguous()
|
| 313 |
+
mask = attention_mask[:, 1:].contiguous()
|
| 314 |
+
ce_flat = torch.nn.functional.cross_entropy(
|
| 315 |
+
logits.view(-1, logits.size(-1)),
|
| 316 |
+
labels.view(-1),
|
| 317 |
+
reduction="none",
|
| 318 |
+
)
|
| 319 |
+
denom = mask.sum()
|
| 320 |
+
if denom.item() == 0:
|
| 321 |
+
continue
|
| 322 |
+
loss = (ce_flat * mask.reshape(-1).to(ce_flat.dtype)).sum() / denom
|
| 323 |
+
|
| 324 |
+
last_loss = float(loss.detach().item())
|
| 325 |
+
scaled_loss = loss / max(args.grad_accum_steps, 1)
|
| 326 |
+
if use_scaler:
|
| 327 |
+
scaler.scale(scaled_loss).backward()
|
| 328 |
+
else:
|
| 329 |
+
scaled_loss.backward()
|
| 330 |
+
|
| 331 |
+
seen_batches += 1
|
| 332 |
+
if seen_batches % max(args.grad_accum_steps, 1) != 0:
|
| 333 |
+
continue
|
| 334 |
+
|
| 335 |
+
if args.max_grad_norm is not None:
|
| 336 |
+
if use_scaler:
|
| 337 |
+
scaler.unscale_(optimizer)
|
| 338 |
+
torch.nn.utils.clip_grad_norm_(lora_params, args.max_grad_norm)
|
| 339 |
+
if use_scaler:
|
| 340 |
+
scaler.step(optimizer)
|
| 341 |
+
scaler.update()
|
| 342 |
+
else:
|
| 343 |
+
optimizer.step()
|
| 344 |
+
optimizer.zero_grad(set_to_none=True)
|
| 345 |
+
optimizer_step += 1
|
| 346 |
+
|
| 347 |
+
if args.log_steps and optimizer_step % args.log_steps == 0:
|
| 348 |
+
print(f"[loratune] step={optimizer_step} loss={last_loss:.6f}")
|
| 349 |
+
if wikitext2_eval_dataloaders is not None:
|
| 350 |
+
prev_mode = model.training
|
| 351 |
+
model.eval()
|
| 352 |
+
ppl_results = ppl_eval.evaluate_ppl_dataloaders(
|
| 353 |
+
model,
|
| 354 |
+
wikitext2_eval_dataloaders,
|
| 355 |
+
args.device,
|
| 356 |
+
max_batches=args.wikitext2_ppl_max_batches,
|
| 357 |
+
)
|
| 358 |
+
ppl_history.append({"step": optimizer_step, "ppl": ppl_results})
|
| 359 |
+
print(f"[loratune] ppl step={optimizer_step} {ppl_results}")
|
| 360 |
+
if prev_mode:
|
| 361 |
+
model.train()
|
| 362 |
+
|
| 363 |
+
if args.save_steps and optimizer_step % args.save_steps == 0:
|
| 364 |
+
checkpoint_dir = save_lora_adapters(
|
| 365 |
+
model,
|
| 366 |
+
args,
|
| 367 |
+
subdir=os.path.join("checkpoints", f"step_{optimizer_step}"),
|
| 368 |
+
)
|
| 369 |
+
print(f"[loratune] saved adapter checkpoint to {checkpoint_dir}")
|
| 370 |
+
|
| 371 |
+
adapter_dir = save_lora_adapters(model, args)
|
| 372 |
+
merge_lora_adapters(model)
|
| 373 |
+
return {
|
| 374 |
+
"adapter_dir": adapter_dir,
|
| 375 |
+
"optimizer_steps": optimizer_step,
|
| 376 |
+
"seen_batches": seen_batches,
|
| 377 |
+
"last_loss": last_loss,
|
| 378 |
+
"wikitext2_ppl_history": ppl_history,
|
| 379 |
+
}
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
def main() -> None:
|
| 383 |
+
args = parse_args()
|
| 384 |
+
if args.batch_size < 1:
|
| 385 |
+
raise SystemExit("--batch_size must be >= 1")
|
| 386 |
+
if args.micro_batch_size < 1:
|
| 387 |
+
raise SystemExit("--micro_batch_size must be >= 1")
|
| 388 |
+
args.grad_accum_steps = args.batch_size // args.micro_batch_size
|
| 389 |
+
if args.grad_accum_steps < 1:
|
| 390 |
+
raise SystemExit("--batch_size must be >= --micro_batch_size")
|
| 391 |
+
|
| 392 |
+
seed_all(args.seed)
|
| 393 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 394 |
+
|
| 395 |
+
model, tokenizer = load_base_artifacts(args)
|
| 396 |
+
if args.dtype != "float32":
|
| 397 |
+
model = model.to(get_dtype(args.dtype))
|
| 398 |
+
model.to(args.device)
|
| 399 |
+
|
| 400 |
+
dataloader = build_training_loader(args, tokenizer)
|
| 401 |
+
wikitext2_eval_dataloaders = prepare_wikitext2_eval(args, model, tokenizer)
|
| 402 |
+
metrics = train(model, dataloader, args, wikitext2_eval_dataloaders=wikitext2_eval_dataloaders)
|
| 403 |
+
|
| 404 |
+
model.save_pretrained(args.output_dir)
|
| 405 |
+
tokenizer.save_pretrained(args.output_dir)
|
| 406 |
+
|
| 407 |
+
with open(os.path.join(args.output_dir, "loratune_metrics.json"), "w", encoding="utf-8") as handle:
|
| 408 |
+
json.dump(
|
| 409 |
+
{
|
| 410 |
+
"base_model": args.base_model,
|
| 411 |
+
"instruction_dataset": args.instruction_dataset,
|
| 412 |
+
"seq_len": args.seq_len,
|
| 413 |
+
"batch_size": args.batch_size,
|
| 414 |
+
"micro_batch_size": args.micro_batch_size,
|
| 415 |
+
"grad_accum_steps": args.grad_accum_steps,
|
| 416 |
+
"epochs": args.epochs,
|
| 417 |
+
"learning_rate": args.learning_rate,
|
| 418 |
+
"save_steps": args.save_steps,
|
| 419 |
+
"lora_rank": args.lora_rank,
|
| 420 |
+
"lora_alpha": args.lora_alpha,
|
| 421 |
+
"lora_dropout": args.lora_dropout,
|
| 422 |
+
**metrics,
|
| 423 |
+
},
|
| 424 |
+
handle,
|
| 425 |
+
indent=2,
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
if __name__ == "__main__":
|
| 430 |
+
main()
|
src/loratune_config.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Configuration helpers for centralized LoRA finetuning."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
from dataclasses import asdict, dataclass, field
|
| 7 |
+
from types import SimpleNamespace
|
| 8 |
+
from typing import Any, Dict, List, Optional
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class LoRATuneConfig:
|
| 13 |
+
"""Structured config matching the current loratune.py CLI surface."""
|
| 14 |
+
|
| 15 |
+
base_model: str = ""
|
| 16 |
+
output_dir: str = ""
|
| 17 |
+
device: str = "cuda"
|
| 18 |
+
dtype: str = "bfloat16"
|
| 19 |
+
trust_remote_code: bool = False
|
| 20 |
+
seed: int = 42
|
| 21 |
+
|
| 22 |
+
instruction_dataset: str = "tatsu-lab/alpaca"
|
| 23 |
+
instruction_config: Optional[str] = None
|
| 24 |
+
instruction_split: str = "train"
|
| 25 |
+
instruction_field_instruction: str = "instruction"
|
| 26 |
+
instruction_field_input: str = "input"
|
| 27 |
+
instruction_field_output: str = "output"
|
| 28 |
+
max_samples: int = 0
|
| 29 |
+
seq_len: int = 1024
|
| 30 |
+
batch_size: int = 64
|
| 31 |
+
micro_batch_size: int = 4
|
| 32 |
+
epochs: float = 1.0
|
| 33 |
+
learning_rate: float = 1e-4
|
| 34 |
+
weight_decay: float = 0.0
|
| 35 |
+
max_grad_norm: float = 1.0
|
| 36 |
+
log_steps: int = 100
|
| 37 |
+
|
| 38 |
+
wikitext2_ppl_on_log: bool = True
|
| 39 |
+
wikitext2_ppl_seq_len: int = 128
|
| 40 |
+
wikitext2_ppl_batch_size: int = 8
|
| 41 |
+
wikitext2_ppl_max_batches: Optional[int] = None
|
| 42 |
+
|
| 43 |
+
lora_rank: int = 8
|
| 44 |
+
lora_alpha: float = 16.0
|
| 45 |
+
lora_dropout: float = 0.0
|
| 46 |
+
lora_target_modules: List[str] = field(
|
| 47 |
+
default_factory=lambda: [
|
| 48 |
+
"q_proj",
|
| 49 |
+
"k_proj",
|
| 50 |
+
"v_proj",
|
| 51 |
+
"o_proj",
|
| 52 |
+
"gate_proj",
|
| 53 |
+
"down_proj",
|
| 54 |
+
"up_proj",
|
| 55 |
+
]
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
@property
|
| 59 |
+
def grad_accum_steps(self) -> int:
|
| 60 |
+
if self.batch_size < 1:
|
| 61 |
+
raise ValueError("batch_size must be >= 1")
|
| 62 |
+
if self.micro_batch_size < 1:
|
| 63 |
+
raise ValueError("micro_batch_size must be >= 1")
|
| 64 |
+
if self.batch_size < self.micro_batch_size:
|
| 65 |
+
raise ValueError("batch_size must be >= micro_batch_size")
|
| 66 |
+
return self.batch_size // self.micro_batch_size
|
| 67 |
+
|
| 68 |
+
def validate(self) -> "LoRATuneConfig":
|
| 69 |
+
_ = self.grad_accum_steps
|
| 70 |
+
if not self.base_model:
|
| 71 |
+
raise ValueError("base_model must be set")
|
| 72 |
+
if not self.output_dir:
|
| 73 |
+
raise ValueError("output_dir must be set")
|
| 74 |
+
return self
|
| 75 |
+
|
| 76 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 77 |
+
data = asdict(self)
|
| 78 |
+
data["grad_accum_steps"] = self.grad_accum_steps
|
| 79 |
+
return data
|
| 80 |
+
|
| 81 |
+
def to_namespace(self) -> SimpleNamespace:
|
| 82 |
+
return SimpleNamespace(**self.to_dict())
|
| 83 |
+
|
| 84 |
+
@classmethod
|
| 85 |
+
def from_dict(cls, values: Dict[str, Any]) -> "LoRATuneConfig":
|
| 86 |
+
return cls(**values)
|
src/ppl_eval.py
ADDED
|
@@ -0,0 +1,576 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Perplexity evaluation for causal LMs on HF datasets or provided text."""
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import json
|
| 6 |
+
import math
|
| 7 |
+
import os
|
| 8 |
+
from typing import Dict, Iterable, List, Optional
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
from datasets import load_dataset
|
| 14 |
+
except Exception: # pragma: no cover - optional dependency
|
| 15 |
+
load_dataset = None
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 19 |
+
except Exception as exc: # pragma: no cover - fail early with clear error
|
| 20 |
+
raise SystemExit("transformers is required: pip install transformers") from exc
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
from tqdm import tqdm
|
| 24 |
+
except Exception: # pragma: no cover - optional dependency
|
| 25 |
+
tqdm = None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _tqdm_enabled() -> bool:
|
| 29 |
+
value = os.environ.get("DISABLE_TQDM", os.environ.get("TQDM_DISABLE", "0"))
|
| 30 |
+
return value.strip().lower() not in {"1", "true", "yes", "on"}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def parse_args() -> argparse.Namespace:
|
| 34 |
+
parser = argparse.ArgumentParser(
|
| 35 |
+
description="Compute perplexity for a causal LM on one or more datasets."
|
| 36 |
+
)
|
| 37 |
+
parser.add_argument("--model", required=True, help="HF model id or local path")
|
| 38 |
+
parser.add_argument(
|
| 39 |
+
"--dataset",
|
| 40 |
+
action="append",
|
| 41 |
+
default=[],
|
| 42 |
+
help="HF dataset name (repeatable).",
|
| 43 |
+
)
|
| 44 |
+
parser.add_argument(
|
| 45 |
+
"--dataset_config",
|
| 46 |
+
action="append",
|
| 47 |
+
default=[],
|
| 48 |
+
help="Optional dataset config (repeatable or single shared config).",
|
| 49 |
+
)
|
| 50 |
+
parser.add_argument(
|
| 51 |
+
"--dataset_split",
|
| 52 |
+
default="test",
|
| 53 |
+
help="Dataset split to use (default: test)",
|
| 54 |
+
)
|
| 55 |
+
parser.add_argument(
|
| 56 |
+
"--dataset_text_field",
|
| 57 |
+
default=None,
|
| 58 |
+
help="Text field in dataset (default: auto-detect, applies to all datasets)",
|
| 59 |
+
)
|
| 60 |
+
parser.add_argument(
|
| 61 |
+
"--text",
|
| 62 |
+
action="append",
|
| 63 |
+
default=[],
|
| 64 |
+
help="Inline text samples (can pass multiple)",
|
| 65 |
+
)
|
| 66 |
+
parser.add_argument(
|
| 67 |
+
"--text_file",
|
| 68 |
+
default=None,
|
| 69 |
+
help="Path to a text file for evaluation data",
|
| 70 |
+
)
|
| 71 |
+
parser.add_argument(
|
| 72 |
+
"--num_samples",
|
| 73 |
+
type=int,
|
| 74 |
+
default=0,
|
| 75 |
+
help="Number of token sequences to use per dataset (0 = all)",
|
| 76 |
+
)
|
| 77 |
+
parser.add_argument(
|
| 78 |
+
"--seq_len", type=int, default=2048, help="Sequence length"
|
| 79 |
+
)
|
| 80 |
+
parser.add_argument(
|
| 81 |
+
"--batch_size", type=int, default=2, help="Batch size"
|
| 82 |
+
)
|
| 83 |
+
parser.add_argument(
|
| 84 |
+
"--max_batches",
|
| 85 |
+
type=int,
|
| 86 |
+
default=None,
|
| 87 |
+
help="Optional max number of batches to evaluate per dataset",
|
| 88 |
+
)
|
| 89 |
+
parser.add_argument(
|
| 90 |
+
"--model_family",
|
| 91 |
+
type=str,
|
| 92 |
+
choices=["auto", "llama", "qwen"],
|
| 93 |
+
default="auto",
|
| 94 |
+
help="Model family for BOS handling",
|
| 95 |
+
)
|
| 96 |
+
parser.add_argument(
|
| 97 |
+
"--add_bos",
|
| 98 |
+
type=str,
|
| 99 |
+
choices=["auto", "always", "never"],
|
| 100 |
+
default="auto",
|
| 101 |
+
help="Whether to prepend BOS to each sample",
|
| 102 |
+
)
|
| 103 |
+
parser.add_argument(
|
| 104 |
+
"--device",
|
| 105 |
+
default="cuda" if torch.cuda.is_available() else "cpu",
|
| 106 |
+
help="Device for model + compute",
|
| 107 |
+
)
|
| 108 |
+
parser.add_argument(
|
| 109 |
+
"--dtype",
|
| 110 |
+
default="auto",
|
| 111 |
+
choices=["auto", "float32", "float16", "bfloat16"],
|
| 112 |
+
help="Model dtype",
|
| 113 |
+
)
|
| 114 |
+
parser.add_argument(
|
| 115 |
+
"--seed", type=int, default=0, help="Random seed for shuffling"
|
| 116 |
+
)
|
| 117 |
+
parser.add_argument(
|
| 118 |
+
"--shuffle",
|
| 119 |
+
action="store_true",
|
| 120 |
+
help="Shuffle dataset before sampling",
|
| 121 |
+
)
|
| 122 |
+
parser.add_argument(
|
| 123 |
+
"--num_workers",
|
| 124 |
+
type=int,
|
| 125 |
+
default=0,
|
| 126 |
+
help="DataLoader workers",
|
| 127 |
+
)
|
| 128 |
+
parser.add_argument(
|
| 129 |
+
"--cache_dir",
|
| 130 |
+
default=None,
|
| 131 |
+
help="Optional datasets cache directory",
|
| 132 |
+
)
|
| 133 |
+
parser.add_argument(
|
| 134 |
+
"--trust_remote_code",
|
| 135 |
+
action="store_true",
|
| 136 |
+
help="Allow custom model code from hub",
|
| 137 |
+
)
|
| 138 |
+
parser.add_argument(
|
| 139 |
+
"--output",
|
| 140 |
+
default=None,
|
| 141 |
+
help="Optional JSON output path",
|
| 142 |
+
)
|
| 143 |
+
return parser.parse_args()
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def _normalize_config(config: Optional[str]) -> Optional[str]:
|
| 147 |
+
if config is None:
|
| 148 |
+
return None
|
| 149 |
+
if config.strip().lower() in {"none", "null", "-"}:
|
| 150 |
+
return None
|
| 151 |
+
return config
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def _expand_dataset_configs(
|
| 155 |
+
datasets: List[str], configs: List[str]
|
| 156 |
+
) -> List[Optional[str]]:
|
| 157 |
+
if not configs:
|
| 158 |
+
return [None] * len(datasets)
|
| 159 |
+
if len(configs) == 1 and len(datasets) > 1:
|
| 160 |
+
return [_normalize_config(configs[0])] * len(datasets)
|
| 161 |
+
if len(configs) != len(datasets):
|
| 162 |
+
raise SystemExit(
|
| 163 |
+
"Provide zero, one, or matching-count --dataset_config values."
|
| 164 |
+
)
|
| 165 |
+
return [_normalize_config(cfg) for cfg in configs]
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def guess_text_field(dataset) -> str:
|
| 169 |
+
if hasattr(dataset, "column_names") and dataset.column_names:
|
| 170 |
+
if "text" in dataset.column_names:
|
| 171 |
+
return "text"
|
| 172 |
+
return dataset.column_names[0]
|
| 173 |
+
if hasattr(dataset, "features"):
|
| 174 |
+
names = list(dataset.features.keys())
|
| 175 |
+
if "text" in names:
|
| 176 |
+
return "text"
|
| 177 |
+
if names:
|
| 178 |
+
return names[0]
|
| 179 |
+
return "text"
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _infer_model_family(model) -> str:
|
| 183 |
+
model_type = str(getattr(getattr(model, "config", None), "model_type", "")).lower()
|
| 184 |
+
architectures = getattr(getattr(model, "config", None), "architectures", [])
|
| 185 |
+
arch_lower = " ".join(str(name).lower() for name in architectures)
|
| 186 |
+
if "qwen" in model_type or "qwen" in arch_lower:
|
| 187 |
+
return "qwen"
|
| 188 |
+
if "llama" in model_type or "llama" in arch_lower:
|
| 189 |
+
return "llama"
|
| 190 |
+
return "unknown"
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def _resolve_add_bos(setting: str, model_family: str, tokenizer) -> bool:
|
| 194 |
+
if setting == "always":
|
| 195 |
+
return True
|
| 196 |
+
if setting == "never":
|
| 197 |
+
return False
|
| 198 |
+
if model_family == "llama":
|
| 199 |
+
return True
|
| 200 |
+
if model_family == "qwen":
|
| 201 |
+
return False
|
| 202 |
+
if hasattr(tokenizer, "add_bos_token"):
|
| 203 |
+
return bool(getattr(tokenizer, "add_bos_token"))
|
| 204 |
+
init_kwargs = getattr(tokenizer, "init_kwargs", None)
|
| 205 |
+
if isinstance(init_kwargs, dict) and "add_bos_token" in init_kwargs:
|
| 206 |
+
return bool(init_kwargs["add_bos_token"])
|
| 207 |
+
return False
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def build_token_chunks(
|
| 211 |
+
texts: Iterable[str],
|
| 212 |
+
tokenizer,
|
| 213 |
+
seq_len: int,
|
| 214 |
+
num_samples: int,
|
| 215 |
+
add_bos: bool = False,
|
| 216 |
+
) -> List[torch.Tensor]:
|
| 217 |
+
chunks: List[torch.Tensor] = []
|
| 218 |
+
buffer: List[int] = []
|
| 219 |
+
for text in texts:
|
| 220 |
+
ids = tokenizer.encode(text, add_special_tokens=False)
|
| 221 |
+
if add_bos and tokenizer.bos_token_id is not None:
|
| 222 |
+
ids = [tokenizer.bos_token_id] + ids
|
| 223 |
+
if not ids:
|
| 224 |
+
continue
|
| 225 |
+
buffer.extend(ids)
|
| 226 |
+
while len(buffer) >= seq_len and len(chunks) < num_samples:
|
| 227 |
+
chunk = buffer[:seq_len]
|
| 228 |
+
buffer = buffer[seq_len:]
|
| 229 |
+
chunks.append(torch.tensor(chunk, dtype=torch.long))
|
| 230 |
+
if len(chunks) >= num_samples:
|
| 231 |
+
break
|
| 232 |
+
return chunks
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def get_dtype(dtype: str):
|
| 236 |
+
if dtype == "auto":
|
| 237 |
+
return None
|
| 238 |
+
if dtype == "float16":
|
| 239 |
+
return torch.float16
|
| 240 |
+
if dtype == "bfloat16":
|
| 241 |
+
return torch.bfloat16
|
| 242 |
+
return torch.float32
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def compute_ppl(model, dataloader, device: str, max_batches: Optional[int]) -> float:
|
| 246 |
+
model.eval()
|
| 247 |
+
nll_sum = 0.0
|
| 248 |
+
token_count = 0
|
| 249 |
+
iterator = dataloader
|
| 250 |
+
if tqdm is not None and _tqdm_enabled():
|
| 251 |
+
iterator = tqdm(dataloader, desc="PPL", unit="batch")
|
| 252 |
+
with torch.no_grad():
|
| 253 |
+
for step, batch in enumerate(iterator):
|
| 254 |
+
if isinstance(batch, dict):
|
| 255 |
+
input_ids = batch["input_ids"].to(device)
|
| 256 |
+
else:
|
| 257 |
+
input_ids = batch[0].to(device)
|
| 258 |
+
outputs = model(input_ids=input_ids)
|
| 259 |
+
logits = outputs.logits
|
| 260 |
+
shift_logits = logits[:, :-1, :].contiguous()
|
| 261 |
+
shift_labels = input_ids[:, 1:].contiguous()
|
| 262 |
+
loss = torch.nn.functional.cross_entropy(
|
| 263 |
+
shift_logits.view(-1, shift_logits.size(-1)),
|
| 264 |
+
shift_labels.view(-1),
|
| 265 |
+
reduction="sum",
|
| 266 |
+
)
|
| 267 |
+
nll_sum += float(loss.item())
|
| 268 |
+
token_count += shift_labels.numel()
|
| 269 |
+
if max_batches is not None and step + 1 >= max_batches:
|
| 270 |
+
break
|
| 271 |
+
|
| 272 |
+
if token_count == 0:
|
| 273 |
+
raise RuntimeError("No tokens processed; check evaluation inputs.")
|
| 274 |
+
|
| 275 |
+
return math.exp(nll_sum / token_count)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def _load_lm_dataset(
|
| 279 |
+
tokenizer,
|
| 280 |
+
dataset_name: str,
|
| 281 |
+
config: Optional[str],
|
| 282 |
+
split: str,
|
| 283 |
+
text_field: Optional[str],
|
| 284 |
+
seq_len: int,
|
| 285 |
+
add_bos: bool,
|
| 286 |
+
cache_dir: Optional[str],
|
| 287 |
+
):
|
| 288 |
+
dataset = load_dataset(
|
| 289 |
+
dataset_name,
|
| 290 |
+
config,
|
| 291 |
+
split=split,
|
| 292 |
+
trust_remote_code=True,
|
| 293 |
+
cache_dir=cache_dir,
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
field = text_field or guess_text_field(dataset)
|
| 297 |
+
|
| 298 |
+
def is_valid_text(example) -> bool:
|
| 299 |
+
value = example.get(field)
|
| 300 |
+
return isinstance(value, str) and value.strip() != ""
|
| 301 |
+
|
| 302 |
+
dataset = dataset.filter(is_valid_text, desc=f"filter-{dataset_name}")
|
| 303 |
+
|
| 304 |
+
def tokenize_fn(examples):
|
| 305 |
+
tokenized = tokenizer(
|
| 306 |
+
examples[field],
|
| 307 |
+
add_special_tokens=False,
|
| 308 |
+
return_attention_mask=False,
|
| 309 |
+
)
|
| 310 |
+
if add_bos and tokenizer.bos_token_id is not None:
|
| 311 |
+
tokenized["input_ids"] = [
|
| 312 |
+
[tokenizer.bos_token_id] + ids for ids in tokenized["input_ids"]
|
| 313 |
+
]
|
| 314 |
+
return tokenized
|
| 315 |
+
|
| 316 |
+
tokenized = dataset.map(
|
| 317 |
+
tokenize_fn,
|
| 318 |
+
batched=True,
|
| 319 |
+
remove_columns=dataset.column_names,
|
| 320 |
+
desc=f"tokenize-{dataset_name}",
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
def group_texts(examples):
|
| 324 |
+
concatenated = []
|
| 325 |
+
for ids in examples["input_ids"]:
|
| 326 |
+
concatenated.extend(ids)
|
| 327 |
+
total_length = (len(concatenated) // seq_len) * seq_len
|
| 328 |
+
if total_length == 0:
|
| 329 |
+
return {"input_ids": []}
|
| 330 |
+
return {
|
| 331 |
+
"input_ids": [
|
| 332 |
+
concatenated[i : i + seq_len] for i in range(0, total_length, seq_len)
|
| 333 |
+
]
|
| 334 |
+
}
|
| 335 |
+
|
| 336 |
+
lm_dataset = tokenized.map(
|
| 337 |
+
group_texts,
|
| 338 |
+
batched=True,
|
| 339 |
+
batch_size=1000,
|
| 340 |
+
remove_columns=tokenized.column_names,
|
| 341 |
+
desc=f"group-{dataset_name}",
|
| 342 |
+
)
|
| 343 |
+
lm_dataset.set_format(type="torch", columns=["input_ids"])
|
| 344 |
+
return lm_dataset
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def prepare_ppl_dataloaders(
|
| 348 |
+
tokenizer,
|
| 349 |
+
datasets: List[str],
|
| 350 |
+
configs: List[Optional[str]],
|
| 351 |
+
split: str,
|
| 352 |
+
text_field: Optional[str],
|
| 353 |
+
num_samples: int,
|
| 354 |
+
seq_len: int,
|
| 355 |
+
batch_size: int,
|
| 356 |
+
seed: int,
|
| 357 |
+
shuffle: bool,
|
| 358 |
+
model_family: str = "auto",
|
| 359 |
+
add_bos: str = "auto",
|
| 360 |
+
cache_dir: Optional[str] = None,
|
| 361 |
+
num_workers: int = 0,
|
| 362 |
+
model=None,
|
| 363 |
+
) -> Dict[str, torch.utils.data.DataLoader]:
|
| 364 |
+
if load_dataset is None:
|
| 365 |
+
raise SystemExit("datasets is required for dataset evaluation")
|
| 366 |
+
|
| 367 |
+
resolved_family = model_family
|
| 368 |
+
if resolved_family == "auto":
|
| 369 |
+
if model is None:
|
| 370 |
+
raise SystemExit("model is required when model_family is 'auto'")
|
| 371 |
+
resolved_family = _infer_model_family(model)
|
| 372 |
+
use_bos = _resolve_add_bos(add_bos, resolved_family, tokenizer)
|
| 373 |
+
if use_bos and tokenizer.bos_token_id is None:
|
| 374 |
+
use_bos = False
|
| 375 |
+
|
| 376 |
+
dataloaders: Dict[str, torch.utils.data.DataLoader] = {}
|
| 377 |
+
for idx, (dataset_name, config) in enumerate(zip(datasets, configs)):
|
| 378 |
+
lm_dataset = _load_lm_dataset(
|
| 379 |
+
tokenizer=tokenizer,
|
| 380 |
+
dataset_name=dataset_name,
|
| 381 |
+
config=config,
|
| 382 |
+
split=split,
|
| 383 |
+
text_field=text_field,
|
| 384 |
+
seq_len=seq_len,
|
| 385 |
+
add_bos=use_bos,
|
| 386 |
+
cache_dir=cache_dir,
|
| 387 |
+
)
|
| 388 |
+
if shuffle:
|
| 389 |
+
try:
|
| 390 |
+
lm_dataset = lm_dataset.shuffle(seed=seed + idx)
|
| 391 |
+
except Exception:
|
| 392 |
+
pass
|
| 393 |
+
if num_samples and hasattr(lm_dataset, "__len__"):
|
| 394 |
+
limit = min(num_samples, len(lm_dataset))
|
| 395 |
+
lm_dataset = lm_dataset.select(range(limit))
|
| 396 |
+
|
| 397 |
+
data_loader = torch.utils.data.DataLoader(
|
| 398 |
+
lm_dataset,
|
| 399 |
+
batch_size=batch_size,
|
| 400 |
+
shuffle=False,
|
| 401 |
+
num_workers=num_workers,
|
| 402 |
+
)
|
| 403 |
+
label = dataset_name if config is None else f"{dataset_name}:{config}"
|
| 404 |
+
dataloaders[label] = data_loader
|
| 405 |
+
|
| 406 |
+
return dataloaders
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
def evaluate_ppl_dataloaders(
|
| 410 |
+
model,
|
| 411 |
+
dataloaders: Dict[str, torch.utils.data.DataLoader],
|
| 412 |
+
device: str,
|
| 413 |
+
max_batches: Optional[int] = None,
|
| 414 |
+
) -> Dict[str, float]:
|
| 415 |
+
results: Dict[str, float] = {}
|
| 416 |
+
for label, data_loader in dataloaders.items():
|
| 417 |
+
ppl = compute_ppl(model, data_loader, device, max_batches=max_batches)
|
| 418 |
+
results[label] = ppl
|
| 419 |
+
return results
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
def evaluate_ppl_datasets(
|
| 423 |
+
model,
|
| 424 |
+
tokenizer,
|
| 425 |
+
datasets: List[str],
|
| 426 |
+
configs: List[Optional[str]],
|
| 427 |
+
split: str,
|
| 428 |
+
text_field: Optional[str],
|
| 429 |
+
num_samples: int,
|
| 430 |
+
seq_len: int,
|
| 431 |
+
batch_size: int,
|
| 432 |
+
device: str,
|
| 433 |
+
seed: int,
|
| 434 |
+
shuffle: bool,
|
| 435 |
+
model_family: str = "auto",
|
| 436 |
+
add_bos: str = "auto",
|
| 437 |
+
max_batches: Optional[int] = None,
|
| 438 |
+
cache_dir: Optional[str] = None,
|
| 439 |
+
num_workers: int = 0,
|
| 440 |
+
) -> Dict[str, float]:
|
| 441 |
+
if load_dataset is None:
|
| 442 |
+
raise SystemExit("datasets is required for dataset evaluation")
|
| 443 |
+
|
| 444 |
+
resolved_family = model_family
|
| 445 |
+
if resolved_family == "auto":
|
| 446 |
+
resolved_family = _infer_model_family(model)
|
| 447 |
+
use_bos = _resolve_add_bos(add_bos, resolved_family, tokenizer)
|
| 448 |
+
if use_bos and tokenizer.bos_token_id is None:
|
| 449 |
+
use_bos = False
|
| 450 |
+
|
| 451 |
+
results: Dict[str, float] = {}
|
| 452 |
+
for idx, (dataset_name, config) in enumerate(zip(datasets, configs)):
|
| 453 |
+
lm_dataset = _load_lm_dataset(
|
| 454 |
+
tokenizer=tokenizer,
|
| 455 |
+
dataset_name=dataset_name,
|
| 456 |
+
config=config,
|
| 457 |
+
split=split,
|
| 458 |
+
text_field=text_field,
|
| 459 |
+
seq_len=seq_len,
|
| 460 |
+
add_bos=use_bos,
|
| 461 |
+
cache_dir=cache_dir,
|
| 462 |
+
)
|
| 463 |
+
if shuffle:
|
| 464 |
+
try:
|
| 465 |
+
lm_dataset = lm_dataset.shuffle(seed=seed + idx)
|
| 466 |
+
except Exception:
|
| 467 |
+
pass
|
| 468 |
+
if num_samples and hasattr(lm_dataset, "__len__"):
|
| 469 |
+
limit = min(num_samples, len(lm_dataset))
|
| 470 |
+
lm_dataset = lm_dataset.select(range(limit))
|
| 471 |
+
|
| 472 |
+
data_loader = torch.utils.data.DataLoader(
|
| 473 |
+
lm_dataset,
|
| 474 |
+
batch_size=batch_size,
|
| 475 |
+
shuffle=False,
|
| 476 |
+
num_workers=num_workers,
|
| 477 |
+
)
|
| 478 |
+
label = dataset_name if config is None else f"{dataset_name}:{config}"
|
| 479 |
+
ppl = compute_ppl(model, data_loader, device, max_batches=max_batches)
|
| 480 |
+
results[label] = ppl
|
| 481 |
+
return results
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
def main() -> None:
|
| 485 |
+
args = parse_args()
|
| 486 |
+
torch.manual_seed(args.seed)
|
| 487 |
+
|
| 488 |
+
dtype = get_dtype(args.dtype)
|
| 489 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 490 |
+
args.model,
|
| 491 |
+
torch_dtype=dtype,
|
| 492 |
+
trust_remote_code=args.trust_remote_code,
|
| 493 |
+
)
|
| 494 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 495 |
+
args.model, trust_remote_code=args.trust_remote_code
|
| 496 |
+
)
|
| 497 |
+
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
|
| 498 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 499 |
+
|
| 500 |
+
model.to(args.device)
|
| 501 |
+
|
| 502 |
+
results: Dict[str, float] = {}
|
| 503 |
+
resolved_family = args.model_family
|
| 504 |
+
if resolved_family == "auto":
|
| 505 |
+
resolved_family = _infer_model_family(model)
|
| 506 |
+
use_bos = _resolve_add_bos(args.add_bos, resolved_family, tokenizer)
|
| 507 |
+
if use_bos and tokenizer.bos_token_id is None:
|
| 508 |
+
use_bos = False
|
| 509 |
+
|
| 510 |
+
if args.dataset:
|
| 511 |
+
datasets = list(args.dataset)
|
| 512 |
+
configs = _expand_dataset_configs(datasets, list(args.dataset_config))
|
| 513 |
+
results.update(
|
| 514 |
+
evaluate_ppl_datasets(
|
| 515 |
+
model,
|
| 516 |
+
tokenizer,
|
| 517 |
+
datasets=datasets,
|
| 518 |
+
configs=configs,
|
| 519 |
+
split=args.dataset_split,
|
| 520 |
+
text_field=args.dataset_text_field,
|
| 521 |
+
num_samples=args.num_samples,
|
| 522 |
+
seq_len=args.seq_len,
|
| 523 |
+
batch_size=args.batch_size,
|
| 524 |
+
device=args.device,
|
| 525 |
+
seed=args.seed,
|
| 526 |
+
shuffle=args.shuffle,
|
| 527 |
+
model_family=resolved_family,
|
| 528 |
+
add_bos="always" if use_bos else "never",
|
| 529 |
+
max_batches=args.max_batches,
|
| 530 |
+
cache_dir=args.cache_dir,
|
| 531 |
+
num_workers=args.num_workers,
|
| 532 |
+
)
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
if args.text_file or args.text:
|
| 536 |
+
custom_texts: List[str] = []
|
| 537 |
+
if args.text_file:
|
| 538 |
+
with open(args.text_file, "r", encoding="utf-8") as handle:
|
| 539 |
+
custom_texts.extend([line.strip() for line in handle if line.strip()])
|
| 540 |
+
if args.text:
|
| 541 |
+
custom_texts.extend([t for t in args.text if t])
|
| 542 |
+
if custom_texts:
|
| 543 |
+
chunks = build_token_chunks(
|
| 544 |
+
custom_texts,
|
| 545 |
+
tokenizer,
|
| 546 |
+
args.seq_len,
|
| 547 |
+
args.num_samples if args.num_samples > 0 else 1_000_000,
|
| 548 |
+
add_bos=use_bos,
|
| 549 |
+
)
|
| 550 |
+
if not chunks:
|
| 551 |
+
raise SystemExit(
|
| 552 |
+
"Not enough custom text to build token sequences. "
|
| 553 |
+
"Provide more --text/--text_file content or reduce --seq_len."
|
| 554 |
+
)
|
| 555 |
+
dataset = torch.utils.data.TensorDataset(torch.stack(chunks))
|
| 556 |
+
dataloader = torch.utils.data.DataLoader(
|
| 557 |
+
dataset, batch_size=args.batch_size, shuffle=False
|
| 558 |
+
)
|
| 559 |
+
results["custom"] = compute_ppl(
|
| 560 |
+
model, dataloader, args.device, max_batches=args.max_batches
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
if not results:
|
| 564 |
+
raise SystemExit("Provide --dataset and/or --text/--text_file for evaluation")
|
| 565 |
+
|
| 566 |
+
print("Perplexity results:")
|
| 567 |
+
for name, ppl in results.items():
|
| 568 |
+
print(f"{name}: {ppl:.4f}")
|
| 569 |
+
|
| 570 |
+
if args.output:
|
| 571 |
+
with open(args.output, "w", encoding="utf-8") as handle:
|
| 572 |
+
json.dump({"model": args.model, "results": results}, handle, indent=2)
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
if __name__ == "__main__":
|
| 576 |
+
main()
|
src/ppl_eval_progressive.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Evaluate perplexity for a progressive-pruned model assembled from cycles."""
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
import ppl_eval
|
| 10 |
+
except Exception as exc: # pragma: no cover - optional dependency
|
| 11 |
+
raise SystemExit("ppl_eval.py is required (missing or invalid)") from exc
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
from transformers import AutoTokenizer
|
| 15 |
+
except Exception as exc: # pragma: no cover - fail early with clear error
|
| 16 |
+
raise SystemExit("transformers is required: pip install transformers") from exc
|
| 17 |
+
|
| 18 |
+
from progressive_loader import load_progressive_model
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def parse_args() -> argparse.Namespace:
|
| 22 |
+
parser = argparse.ArgumentParser(
|
| 23 |
+
description="Evaluate PPL for a model reconstructed from progressive cycles."
|
| 24 |
+
)
|
| 25 |
+
parser.add_argument("--base_model", required=True, help="Base HF model id or path")
|
| 26 |
+
parser.add_argument(
|
| 27 |
+
"--progressive_dir",
|
| 28 |
+
required=True,
|
| 29 |
+
help="Output directory from progressive pruning",
|
| 30 |
+
)
|
| 31 |
+
parser.add_argument(
|
| 32 |
+
"--cycle",
|
| 33 |
+
type=int,
|
| 34 |
+
default=None,
|
| 35 |
+
help="Cycle to load (default: final)",
|
| 36 |
+
)
|
| 37 |
+
parser.add_argument(
|
| 38 |
+
"--dataset",
|
| 39 |
+
action="append",
|
| 40 |
+
default=[],
|
| 41 |
+
help="Evaluation dataset name (repeatable). Defaults to wikitext.",
|
| 42 |
+
)
|
| 43 |
+
parser.add_argument(
|
| 44 |
+
"--dataset_config",
|
| 45 |
+
action="append",
|
| 46 |
+
default=[],
|
| 47 |
+
help="Evaluation dataset config (repeatable or single shared config).",
|
| 48 |
+
)
|
| 49 |
+
parser.add_argument(
|
| 50 |
+
"--dataset_split",
|
| 51 |
+
default="test",
|
| 52 |
+
help="Evaluation dataset split (default: test)",
|
| 53 |
+
)
|
| 54 |
+
parser.add_argument(
|
| 55 |
+
"--dataset_text_field",
|
| 56 |
+
default=None,
|
| 57 |
+
help="Evaluation text field override (default: auto-detect)",
|
| 58 |
+
)
|
| 59 |
+
parser.add_argument(
|
| 60 |
+
"--num_samples",
|
| 61 |
+
type=int,
|
| 62 |
+
default=0,
|
| 63 |
+
help="Number of token sequences per dataset (0 = all)",
|
| 64 |
+
)
|
| 65 |
+
parser.add_argument(
|
| 66 |
+
"--seq_len",
|
| 67 |
+
type=int,
|
| 68 |
+
default=2048,
|
| 69 |
+
help="Sequence length for eval",
|
| 70 |
+
)
|
| 71 |
+
parser.add_argument(
|
| 72 |
+
"--batch_size",
|
| 73 |
+
type=int,
|
| 74 |
+
default=4,
|
| 75 |
+
help="Batch size for eval",
|
| 76 |
+
)
|
| 77 |
+
parser.add_argument(
|
| 78 |
+
"--device",
|
| 79 |
+
default="cuda" if torch.cuda.is_available() else "cpu",
|
| 80 |
+
help="Device for eval",
|
| 81 |
+
)
|
| 82 |
+
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
| 83 |
+
parser.add_argument(
|
| 84 |
+
"--model_family",
|
| 85 |
+
type=str,
|
| 86 |
+
choices=["auto", "llama", "qwen"],
|
| 87 |
+
default="auto",
|
| 88 |
+
help="Model family for BOS handling",
|
| 89 |
+
)
|
| 90 |
+
parser.add_argument(
|
| 91 |
+
"--add_bos",
|
| 92 |
+
type=str,
|
| 93 |
+
choices=["auto", "always", "never"],
|
| 94 |
+
default="auto",
|
| 95 |
+
help="Whether to prepend BOS to each sample",
|
| 96 |
+
)
|
| 97 |
+
parser.add_argument(
|
| 98 |
+
"--max_batches",
|
| 99 |
+
type=int,
|
| 100 |
+
default=None,
|
| 101 |
+
help="Optional max number of eval batches per dataset",
|
| 102 |
+
)
|
| 103 |
+
parser.add_argument(
|
| 104 |
+
"--cache_dir",
|
| 105 |
+
default=None,
|
| 106 |
+
help="Optional datasets cache dir for eval",
|
| 107 |
+
)
|
| 108 |
+
parser.add_argument(
|
| 109 |
+
"--num_workers",
|
| 110 |
+
type=int,
|
| 111 |
+
default=0,
|
| 112 |
+
help="Eval DataLoader workers",
|
| 113 |
+
)
|
| 114 |
+
parser.add_argument(
|
| 115 |
+
"--dtype",
|
| 116 |
+
default="auto",
|
| 117 |
+
choices=["auto", "float32", "float16", "bfloat16"],
|
| 118 |
+
help="Model dtype",
|
| 119 |
+
)
|
| 120 |
+
parser.add_argument(
|
| 121 |
+
"--trust_remote_code",
|
| 122 |
+
action="store_true",
|
| 123 |
+
help="Allow custom model code from hub",
|
| 124 |
+
)
|
| 125 |
+
parser.add_argument(
|
| 126 |
+
"--layer_path",
|
| 127 |
+
default=None,
|
| 128 |
+
help="Override layer attribute path if needed",
|
| 129 |
+
)
|
| 130 |
+
return parser.parse_args()
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def main() -> None:
|
| 134 |
+
args = parse_args()
|
| 135 |
+
torch.manual_seed(args.seed)
|
| 136 |
+
|
| 137 |
+
datasets = args.dataset or ["wikitext"]
|
| 138 |
+
configs = args.dataset_config or ["wikitext-2-raw-v1"]
|
| 139 |
+
configs = ppl_eval._expand_dataset_configs(datasets, configs)
|
| 140 |
+
|
| 141 |
+
model = load_progressive_model(
|
| 142 |
+
args.base_model,
|
| 143 |
+
args.progressive_dir,
|
| 144 |
+
cycle=args.cycle,
|
| 145 |
+
device=args.device,
|
| 146 |
+
dtype=args.dtype,
|
| 147 |
+
trust_remote_code=args.trust_remote_code,
|
| 148 |
+
layer_path=args.layer_path,
|
| 149 |
+
)
|
| 150 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 151 |
+
args.base_model, trust_remote_code=args.trust_remote_code
|
| 152 |
+
)
|
| 153 |
+
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
|
| 154 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 155 |
+
|
| 156 |
+
results = ppl_eval.evaluate_ppl_datasets(
|
| 157 |
+
model,
|
| 158 |
+
tokenizer,
|
| 159 |
+
datasets=datasets,
|
| 160 |
+
configs=configs,
|
| 161 |
+
split=args.dataset_split,
|
| 162 |
+
text_field=args.dataset_text_field,
|
| 163 |
+
num_samples=args.num_samples,
|
| 164 |
+
seq_len=args.seq_len,
|
| 165 |
+
batch_size=args.batch_size,
|
| 166 |
+
device=args.device,
|
| 167 |
+
seed=args.seed,
|
| 168 |
+
shuffle=False,
|
| 169 |
+
model_family=args.model_family,
|
| 170 |
+
add_bos=args.add_bos,
|
| 171 |
+
max_batches=args.max_batches,
|
| 172 |
+
cache_dir=args.cache_dir,
|
| 173 |
+
num_workers=args.num_workers,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
print("Perplexity results:")
|
| 177 |
+
for name, ppl in results.items():
|
| 178 |
+
print(f"{name}: {ppl:.4f}")
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
if __name__ == "__main__":
|
| 182 |
+
main()
|
src/print_progressive_ppl_csv.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Print progressive PPL stats as CSV from progressive_metadata.json.
|
| 3 |
+
|
| 4 |
+
Expected (current) metadata shape:
|
| 5 |
+
- data["eval"]["pre_ppl"]
|
| 6 |
+
- data["cycles"][i]["redistrib_post_ppl"] (optional; legacy key)
|
| 7 |
+
- data["cycles"][i]["comm_post_ppl"] (optional; current key)
|
| 8 |
+
- data["cycles"][i]["distill_post_ppl"]
|
| 9 |
+
- data["cycles"][i]["lora_post_ppl"] (typically only set on the last cycle)
|
| 10 |
+
- data["cycles"][i]["post_ppl"]
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import csv
|
| 15 |
+
import json
|
| 16 |
+
import os
|
| 17 |
+
import shlex
|
| 18 |
+
import sys
|
| 19 |
+
from typing import Any, List, Optional
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _cell(value: Any) -> str:
|
| 23 |
+
if value is None:
|
| 24 |
+
return ""
|
| 25 |
+
if isinstance(value, dict):
|
| 26 |
+
if not value:
|
| 27 |
+
return ""
|
| 28 |
+
return ";".join(str(value[key]) for key in sorted(value))
|
| 29 |
+
if isinstance(value, (list, tuple)):
|
| 30 |
+
return ";".join(str(item) for item in value)
|
| 31 |
+
return str(value)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _read_run_command_tokens(metadata_path: str) -> Optional[List[str]]:
|
| 35 |
+
meta_dir = os.path.dirname(os.path.abspath(metadata_path))
|
| 36 |
+
run_args_path = os.path.join(meta_dir, "run_args.txt")
|
| 37 |
+
if not os.path.exists(run_args_path):
|
| 38 |
+
return None
|
| 39 |
+
|
| 40 |
+
try:
|
| 41 |
+
with open(run_args_path, "r", encoding="utf-8") as handle:
|
| 42 |
+
lines = handle.read().splitlines()
|
| 43 |
+
except OSError:
|
| 44 |
+
return None
|
| 45 |
+
|
| 46 |
+
cmd_line = None
|
| 47 |
+
for idx, line in enumerate(lines):
|
| 48 |
+
if line.strip() == "command:":
|
| 49 |
+
if idx + 1 < len(lines):
|
| 50 |
+
cmd_line = lines[idx + 1].strip()
|
| 51 |
+
break
|
| 52 |
+
|
| 53 |
+
if not cmd_line:
|
| 54 |
+
return None
|
| 55 |
+
|
| 56 |
+
try:
|
| 57 |
+
return shlex.split(cmd_line)
|
| 58 |
+
except ValueError:
|
| 59 |
+
return None
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _parse_exclude_pairs_from_tokens(tokens: List[str]) -> Optional[List[int]]:
|
| 63 |
+
start = None
|
| 64 |
+
for idx, tok in enumerate(tokens):
|
| 65 |
+
if tok in ("--exclude_pairs", "--exclude_layers"):
|
| 66 |
+
start = idx + 1
|
| 67 |
+
break
|
| 68 |
+
if start is None:
|
| 69 |
+
return None
|
| 70 |
+
|
| 71 |
+
raw: List[int] = []
|
| 72 |
+
for tok in tokens[start:]:
|
| 73 |
+
if tok.startswith("--"):
|
| 74 |
+
break
|
| 75 |
+
# Legacy bug: run_args.txt used to print "python" before every token.
|
| 76 |
+
if tok == "python":
|
| 77 |
+
continue
|
| 78 |
+
try:
|
| 79 |
+
raw.append(int(tok))
|
| 80 |
+
except ValueError:
|
| 81 |
+
continue
|
| 82 |
+
return raw
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _normalize_excluded_pairs(raw: List[int], num_pairs: int) -> List[int]:
|
| 86 |
+
exclude: List[int] = []
|
| 87 |
+
for idx in raw:
|
| 88 |
+
if idx < 0:
|
| 89 |
+
idx = num_pairs + idx
|
| 90 |
+
if 0 <= idx < num_pairs:
|
| 91 |
+
exclude.append(idx)
|
| 92 |
+
return sorted(set(exclude))
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _read_excluded_pairs_from_cycle_meta(meta_dir: str, cycle_idx: int) -> Optional[List[int]]:
|
| 96 |
+
path = os.path.join(meta_dir, f"cycle_{cycle_idx}", "cycle_metadata.json")
|
| 97 |
+
try:
|
| 98 |
+
with open(path, "r", encoding="utf-8") as handle:
|
| 99 |
+
cycle_meta = json.load(handle)
|
| 100 |
+
except (FileNotFoundError, json.JSONDecodeError, OSError):
|
| 101 |
+
return None
|
| 102 |
+
|
| 103 |
+
dwce_meta = cycle_meta.get("dwce_meta") or {}
|
| 104 |
+
excluded = dwce_meta.get("excluded_pairs")
|
| 105 |
+
if isinstance(excluded, list) and all(isinstance(x, int) for x in excluded):
|
| 106 |
+
return excluded
|
| 107 |
+
return None
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _num_pairs_for_cycle(data: dict, meta_dir: str, cycle_idx: int) -> Optional[int]:
|
| 111 |
+
num_progressive = data.get("num_progressive")
|
| 112 |
+
final_num_layers = data.get("final_num_layers")
|
| 113 |
+
if isinstance(num_progressive, int) and isinstance(final_num_layers, int):
|
| 114 |
+
initial_layers = final_num_layers + num_progressive
|
| 115 |
+
return max(initial_layers - cycle_idx, 0)
|
| 116 |
+
|
| 117 |
+
cycle_meta_path = os.path.join(meta_dir, f"cycle_{cycle_idx}", "cycle_metadata.json")
|
| 118 |
+
try:
|
| 119 |
+
with open(cycle_meta_path, "r", encoding="utf-8") as handle:
|
| 120 |
+
cycle_meta = json.load(handle)
|
| 121 |
+
except (FileNotFoundError, json.JSONDecodeError, OSError):
|
| 122 |
+
return None
|
| 123 |
+
|
| 124 |
+
num_layers_before = cycle_meta.get("num_layers_before")
|
| 125 |
+
if isinstance(num_layers_before, int):
|
| 126 |
+
return max(num_layers_before - 1, 0)
|
| 127 |
+
return None
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def main() -> None:
|
| 131 |
+
parser = argparse.ArgumentParser(
|
| 132 |
+
description="Print progressive PPL values as CSV from progressive_metadata.json"
|
| 133 |
+
)
|
| 134 |
+
parser.add_argument("path", help="Path to progressive_metadata.json")
|
| 135 |
+
args = parser.parse_args()
|
| 136 |
+
|
| 137 |
+
try:
|
| 138 |
+
with open(args.path, "r", encoding="utf-8") as handle:
|
| 139 |
+
data = json.load(handle)
|
| 140 |
+
except FileNotFoundError as exc:
|
| 141 |
+
raise SystemExit(f"File not found: {args.path}") from exc
|
| 142 |
+
except json.JSONDecodeError as exc:
|
| 143 |
+
raise SystemExit(f"Invalid JSON: {args.path}") from exc
|
| 144 |
+
|
| 145 |
+
meta_dir = os.path.dirname(os.path.abspath(args.path))
|
| 146 |
+
run_tokens = _read_run_command_tokens(args.path)
|
| 147 |
+
raw_exclude = (
|
| 148 |
+
_parse_exclude_pairs_from_tokens(run_tokens) if run_tokens is not None else None
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
writer = csv.writer(sys.stdout)
|
| 152 |
+
writer.writerow(
|
| 153 |
+
[
|
| 154 |
+
"cycle",
|
| 155 |
+
"layer_merged",
|
| 156 |
+
"layer_pair",
|
| 157 |
+
"excluded_pairs",
|
| 158 |
+
"redistrib_post_ppl",
|
| 159 |
+
"distill_post_ppl",
|
| 160 |
+
"lora_post_ppl",
|
| 161 |
+
"post_ppl",
|
| 162 |
+
]
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
pre_ppl = data.get("eval", {}).get("pre_ppl")
|
| 166 |
+
if pre_ppl is not None:
|
| 167 |
+
writer.writerow(["pre", "", "", "", "", "", "", _cell(pre_ppl)])
|
| 168 |
+
|
| 169 |
+
cycles = data.get("cycles") or data.get("cycle_summaries") or []
|
| 170 |
+
for cycle in cycles:
|
| 171 |
+
cycle_idx = cycle.get("cycle", "")
|
| 172 |
+
layer_merged = cycle.get("layer_merged")
|
| 173 |
+
layer_pair = ""
|
| 174 |
+
if isinstance(layer_merged, int):
|
| 175 |
+
layer_pair = f"{layer_merged}-{layer_merged + 1}"
|
| 176 |
+
|
| 177 |
+
excluded_pairs = _read_excluded_pairs_from_cycle_meta(
|
| 178 |
+
meta_dir, cycle_idx if isinstance(cycle_idx, int) else -1
|
| 179 |
+
)
|
| 180 |
+
if excluded_pairs is None and raw_exclude is not None and isinstance(cycle_idx, int):
|
| 181 |
+
num_pairs = _num_pairs_for_cycle(data, meta_dir, cycle_idx)
|
| 182 |
+
if isinstance(num_pairs, int):
|
| 183 |
+
excluded_pairs = _normalize_excluded_pairs(raw_exclude, num_pairs)
|
| 184 |
+
redistrib_post_ppl = cycle.get("redistrib_post_ppl")
|
| 185 |
+
if redistrib_post_ppl is None:
|
| 186 |
+
redistrib_post_ppl = cycle.get("comm_post_ppl")
|
| 187 |
+
|
| 188 |
+
writer.writerow(
|
| 189 |
+
[
|
| 190 |
+
cycle_idx,
|
| 191 |
+
layer_merged if layer_merged is not None else "",
|
| 192 |
+
layer_pair,
|
| 193 |
+
_cell(excluded_pairs),
|
| 194 |
+
_cell(redistrib_post_ppl),
|
| 195 |
+
_cell(cycle.get("distill_post_ppl")),
|
| 196 |
+
_cell(cycle.get("lora_post_ppl")),
|
| 197 |
+
_cell(cycle.get("post_ppl")),
|
| 198 |
+
]
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
if __name__ == "__main__":
|
| 203 |
+
main()
|
src/progressive_loader.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Utilities to reconstruct models from progressive pruning cycles."""
|
| 3 |
+
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
from transformers import AutoModelForCausalLM, PretrainedConfig
|
| 12 |
+
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
|
| 13 |
+
except Exception as exc: # pragma: no cover - fail early with clear error
|
| 14 |
+
raise SystemExit("transformers is required: pip install transformers") from exc
|
| 15 |
+
|
| 16 |
+
from fuse_layers_model import (
|
| 17 |
+
decrement_config,
|
| 18 |
+
drop_layer,
|
| 19 |
+
find_layer_container,
|
| 20 |
+
get_dtype,
|
| 21 |
+
normalize_config,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def load_progressive_metadata(output_dir: str) -> dict:
|
| 26 |
+
path = os.path.join(output_dir, "progressive_metadata.json")
|
| 27 |
+
if not os.path.exists(path):
|
| 28 |
+
raise FileNotFoundError(f"Missing progressive metadata at {path}")
|
| 29 |
+
with open(path, "r", encoding="utf-8") as handle:
|
| 30 |
+
return json.load(handle)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def load_normalized_config(model_path: str, trust_remote_code: bool):
|
| 34 |
+
config_dict, unused_kwargs = PretrainedConfig.get_config_dict(
|
| 35 |
+
model_path,
|
| 36 |
+
trust_remote_code=trust_remote_code,
|
| 37 |
+
)
|
| 38 |
+
num_hidden_layers = config_dict.get("num_hidden_layers")
|
| 39 |
+
layer_types = config_dict.get("layer_types")
|
| 40 |
+
if (
|
| 41 |
+
isinstance(num_hidden_layers, int)
|
| 42 |
+
and num_hidden_layers >= 0
|
| 43 |
+
and isinstance(layer_types, list)
|
| 44 |
+
and len(layer_types) != num_hidden_layers
|
| 45 |
+
):
|
| 46 |
+
config_dict["layer_types"] = list(layer_types[:num_hidden_layers])
|
| 47 |
+
model_type = config_dict["model_type"]
|
| 48 |
+
config_class = CONFIG_MAPPING[model_type]
|
| 49 |
+
config = config_class.from_dict(config_dict, **unused_kwargs)
|
| 50 |
+
normalize_config(config)
|
| 51 |
+
return config
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def load_causal_lm(
|
| 55 |
+
model_path_or_id: str,
|
| 56 |
+
*,
|
| 57 |
+
torch_dtype,
|
| 58 |
+
trust_remote_code: bool,
|
| 59 |
+
**kwargs,
|
| 60 |
+
) -> torch.nn.Module:
|
| 61 |
+
config = None
|
| 62 |
+
config_path = os.path.join(model_path_or_id, "config.json")
|
| 63 |
+
if os.path.isdir(model_path_or_id) and os.path.isfile(config_path):
|
| 64 |
+
config = load_normalized_config(model_path_or_id, trust_remote_code)
|
| 65 |
+
return AutoModelForCausalLM.from_pretrained(
|
| 66 |
+
model_path_or_id,
|
| 67 |
+
config=config,
|
| 68 |
+
torch_dtype=torch_dtype,
|
| 69 |
+
trust_remote_code=trust_remote_code,
|
| 70 |
+
**kwargs,
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def load_progressive_model(
|
| 75 |
+
base_model_id: str,
|
| 76 |
+
output_dir: str,
|
| 77 |
+
cycle: Optional[int] = None,
|
| 78 |
+
device: Optional[str] = None,
|
| 79 |
+
dtype: str = "auto",
|
| 80 |
+
trust_remote_code: bool = False,
|
| 81 |
+
layer_path: Optional[str] = None,
|
| 82 |
+
) -> torch.nn.Module:
|
| 83 |
+
meta = load_progressive_metadata(output_dir)
|
| 84 |
+
num_cycles = int(meta.get("num_progressive", 0))
|
| 85 |
+
if cycle is None:
|
| 86 |
+
cycle = num_cycles
|
| 87 |
+
if cycle < 0 or cycle > num_cycles:
|
| 88 |
+
raise ValueError(f"Cycle {cycle} is outside [0, {num_cycles}]")
|
| 89 |
+
|
| 90 |
+
if cycle > 0:
|
| 91 |
+
full_model_dir = os.path.join(output_dir, f"cycle_{cycle}", "full_model")
|
| 92 |
+
if os.path.isdir(full_model_dir):
|
| 93 |
+
model = load_causal_lm(
|
| 94 |
+
full_model_dir,
|
| 95 |
+
torch_dtype=get_dtype(dtype),
|
| 96 |
+
trust_remote_code=trust_remote_code,
|
| 97 |
+
)
|
| 98 |
+
if device:
|
| 99 |
+
model.to(device)
|
| 100 |
+
return model
|
| 101 |
+
|
| 102 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 103 |
+
base_model_id,
|
| 104 |
+
torch_dtype=get_dtype(dtype),
|
| 105 |
+
trust_remote_code=trust_remote_code,
|
| 106 |
+
)
|
| 107 |
+
active_layer_path = layer_path or meta.get("layer_path")
|
| 108 |
+
parent, name, container = find_layer_container(model, active_layer_path)
|
| 109 |
+
|
| 110 |
+
for idx in range(1, cycle + 1):
|
| 111 |
+
cycle_dir = os.path.join(output_dir, f"cycle_{idx}")
|
| 112 |
+
cycle_meta_path = os.path.join(cycle_dir, "cycle_metadata.json")
|
| 113 |
+
if not os.path.exists(cycle_meta_path):
|
| 114 |
+
raise FileNotFoundError(f"Missing cycle metadata at {cycle_meta_path}")
|
| 115 |
+
with open(cycle_meta_path, "r", encoding="utf-8") as handle:
|
| 116 |
+
cycle_meta = json.load(handle)
|
| 117 |
+
|
| 118 |
+
layer_idx = int(cycle_meta["layer_merged"])
|
| 119 |
+
fused_state = cycle_meta.get("fused_layer_state", "fused_layer.pt")
|
| 120 |
+
fused_state_path = os.path.join(cycle_dir, fused_state)
|
| 121 |
+
if not os.path.exists(fused_state_path):
|
| 122 |
+
raise FileNotFoundError(f"Missing fused layer at {fused_state_path}")
|
| 123 |
+
|
| 124 |
+
layers = list(container)
|
| 125 |
+
if layer_idx < 0 or layer_idx >= len(layers):
|
| 126 |
+
raise ValueError(
|
| 127 |
+
f"Cycle {idx} layer index {layer_idx} out of range for {len(layers)} layers"
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
state = torch.load(fused_state_path, map_location="cpu")
|
| 131 |
+
layers[layer_idx].load_state_dict(state)
|
| 132 |
+
|
| 133 |
+
new_container = drop_layer(container, layer_idx + 1)
|
| 134 |
+
setattr(parent, name, new_container)
|
| 135 |
+
decrement_config(model.config)
|
| 136 |
+
|
| 137 |
+
container = new_container
|
| 138 |
+
|
| 139 |
+
if device:
|
| 140 |
+
model.to(device)
|
| 141 |
+
|
| 142 |
+
return model
|