anzheCheng commited on
Commit
8e9a70d
·
verified ·
1 Parent(s): 1f1ec1c

Upload eigen_moe.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. eigen_moe.py +483 -0
eigen_moe.py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import asdict, dataclass
2
+ from pathlib import Path
3
+ import os
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from typing import List, Optional, Dict, Tuple
8
+ import timm
9
+
10
+ try:
11
+ from huggingface_hub import HfApi, PyTorchModelHubMixin, hf_hub_download
12
+ from huggingface_hub.utils import EntryNotFoundError
13
+ except ImportError: # pragma: no cover - only used when huggingface_hub is unavailable
14
+ HfApi = None # type: ignore[assignment]
15
+ PyTorchModelHubMixin = object # type: ignore[assignment,misc]
16
+ hf_hub_download = None # type: ignore[assignment]
17
+ EntryNotFoundError = FileNotFoundError # type: ignore[assignment]
18
+
19
+ class ENNBasis(nn.Module):
20
+ def __init__(self, d_in: int, d_out: int, r: int, ortho_lambda: float = 1e-3):
21
+ super().__init__()
22
+ assert r <= min(d_in, d_out)
23
+ self.d_in, self.d_out, self.r = d_in, d_out, r
24
+ self.ortho_lambda = ortho_lambda
25
+
26
+ Q = torch.empty(d_out, r)
27
+ P = torch.empty(d_in, r)
28
+ nn.init.orthogonal_(Q)
29
+ nn.init.orthogonal_(P)
30
+ self.Q = nn.Parameter(Q)
31
+ self.P = nn.Parameter(P)
32
+ self.log_lambda = nn.Parameter(torch.zeros(r))
33
+
34
+ @torch.no_grad()
35
+ def _qr_retract_(self):
36
+ qQ, _ = torch.linalg.qr(self.Q, mode='reduced')
37
+ qP, _ = torch.linalg.qr(self.P, mode='reduced')
38
+ self.Q.copy_(qQ); self.P.copy_(qP)
39
+
40
+ def ortho_penalty(self) -> torch.Tensor:
41
+ It = torch.eye(self.r, device=self.Q.device, dtype=self.Q.dtype)
42
+ t1 = (self.Q.T @ self.Q - It).pow(2).sum()
43
+ t2 = (self.P.T @ self.P - It).pow(2).sum()
44
+ return self.ortho_lambda * (t1 + t2)
45
+
46
+ def reconstruct_weight(self) -> torch.Tensor:
47
+ lam = torch.diag_embed(self.log_lambda.exp())
48
+ return self.Q @ lam @ self.P.T
49
+
50
+ def project_out(self, h: torch.Tensor) -> torch.Tensor:
51
+ return torch.einsum('dr,btd->btr', self.Q, h)
52
+
53
+ class AdapterExpert(nn.Module):
54
+ def __init__(self, d_model, bottleneck=192):
55
+ super().__init__()
56
+ self.down = nn.Linear(d_model, bottleneck, bias=False)
57
+ self.up = nn.Linear(bottleneck, d_model, bias=False)
58
+ self.act = nn.GELU()
59
+ def forward(self, x): return self.up(self.act(self.down(x)))
60
+
61
+ class EigenRouter(nn.Module):
62
+ def __init__(self, d_model: int, r: int, n_experts: int, tau: float = 1.0, topk: int = 0,
63
+ ortho_lambda: float = 1e-3):
64
+ super().__init__()
65
+ self.n_experts, self.topk, self.tau = n_experts, topk, tau
66
+ self.basis = ENNBasis(d_in=d_model, d_out=d_model, r=r, ortho_lambda=ortho_lambda)
67
+ self.gamma = nn.Parameter(torch.ones(r))
68
+ self.masks = nn.Parameter(torch.randn(n_experts, r))
69
+ self.bias = nn.Parameter(torch.zeros(n_experts))
70
+
71
+ def forward(self, h: torch.Tensor):
72
+ if self.training: self.basis._qr_retract_()
73
+ z = self.basis.project_out(h)
74
+ e = z.pow(2)
75
+ e = e / (e.sum(dim=-1, keepdim=True) + 1e-6)
76
+ m = torch.softmax(self.masks, dim=0)
77
+ logits = torch.einsum('btr,r,er->bte', e, self.gamma, m) + self.bias
78
+ probs = F.softmax(logits / self.tau, dim=-1)
79
+ ortho = self.basis.ortho_penalty()
80
+ if self.topk and self.topk < self.n_experts:
81
+ vals, idx = torch.topk(probs, k=self.topk, dim=-1)
82
+ return probs, vals, idx, ortho
83
+ return probs, None, None, ortho
84
+
85
+ class MoEAdapterBranch(nn.Module):
86
+ def __init__(self, d_model: int, n_experts: int = 8, r: int = 128, bottleneck: int = 192,
87
+ tau: float = 1.0, router_mode: str = "soft", alpha: float = 1.0,
88
+ apply_to_patches_only: bool = True, ortho_lambda: float = 1e-3):
89
+ super().__init__()
90
+ topk = 0 if router_mode == "soft" else (1 if router_mode == "top1" else 2)
91
+ self.router = EigenRouter(d_model, r, n_experts, tau, topk, ortho_lambda)
92
+ self.experts = nn.ModuleList([AdapterExpert(d_model, bottleneck) for _ in range(n_experts)])
93
+ self.alpha = nn.Parameter(torch.tensor(alpha, dtype=torch.float32))
94
+ self.apply_to_patches_only = apply_to_patches_only
95
+
96
+ def forward(self, x: torch.Tensor):
97
+ if self.apply_to_patches_only and x.dim() == 3 and x.size(1) >= 2:
98
+ cls_tok, patches = x[:, :1, :], x[:, 1:, :]
99
+ y, stats = self._forward_tokens(patches)
100
+ return torch.cat([cls_tok, y], dim=1), stats
101
+ else:
102
+ return self._forward_tokens(x)
103
+
104
+ def _forward_tokens(self, h: torch.Tensor):
105
+ probs, vals, idx, ortho = self.router(h)
106
+ stats = {"ortho_reg": ortho, "router_entropy": (-(probs * (probs.clamp_min(1e-9)).log())).sum(-1).mean()}
107
+ if idx is None:
108
+ out = 0.0
109
+ for e_id, expert in enumerate(self.experts):
110
+ out = out + probs[..., e_id].unsqueeze(-1) * expert(h)
111
+ return h + self.alpha * out, stats
112
+ B, T, D = h.shape; K = idx.shape[-1]
113
+ out = torch.zeros_like(h)
114
+ with torch.no_grad():
115
+ flat_idx = idx.reshape(-1, K)
116
+ counts = torch.bincount(flat_idx.reshape(-1), minlength=len(self.experts))
117
+ stats["assign_hist"] = counts.float() / counts.sum().clamp_min(1)
118
+ for k in range(K):
119
+ ek = idx[..., k]
120
+ wk = vals[..., k].unsqueeze(-1)
121
+ for e_id, expert in enumerate(self.experts):
122
+ mask = (ek == e_id).unsqueeze(-1)
123
+ if mask.any(): out = out + mask * wk * expert(h)
124
+ return h + self.alpha * out, stats
125
+
126
+
127
+ @dataclass
128
+ class MoEConfig:
129
+ experts: int = 8
130
+ r: int = 128
131
+ bottleneck: int = 192
132
+ tau: float = 1.0
133
+ router_mode: str = "soft"
134
+ alpha: float = 1.0
135
+ blocks: str = "last6"
136
+ apply_to_patches_only: bool = True
137
+ ortho_lambda: float = 1e-3
138
+ freeze_backbone: bool = True
139
+ unfreeze_layernorm: bool = False
140
+
141
+ def _parse_block_indices(n_blocks: int, spec: str) -> List[int]:
142
+ if spec == "all": return list(range(n_blocks))
143
+ if spec == "last6": return list(range(max(0, n_blocks - 6), n_blocks))
144
+ if spec == "last4": return list(range(max(0, n_blocks - 4), n_blocks))
145
+ return [i for i in map(int, spec.split(",")) if 0 <= i < n_blocks]
146
+
147
+ class EigenMoE(nn.Module):
148
+ def __init__(self, vit: nn.Module, cfg: MoEConfig):
149
+ super().__init__()
150
+ self.vit, self.cfg = vit, cfg
151
+
152
+ if cfg.freeze_backbone:
153
+ for p in self.vit.parameters():
154
+ p.requires_grad = False
155
+ if cfg.unfreeze_layernorm:
156
+ for m in self.vit.modules():
157
+ if isinstance(m, nn.LayerNorm):
158
+ for p in m.parameters():
159
+ p.requires_grad = True
160
+
161
+ d_model = getattr(self.vit, "embed_dim", None)
162
+ if d_model is None:
163
+ d_model = self.vit.blocks[0].norm1.normalized_shape[0]
164
+ n_blocks = len(self.vit.blocks)
165
+ self.block_ids = _parse_block_indices(n_blocks, cfg.blocks)
166
+
167
+ self.branches = nn.ModuleDict()
168
+ for i in self.block_ids:
169
+ self.branches[str(i)] = MoEAdapterBranch(
170
+ d_model=d_model,
171
+ n_experts=cfg.experts,
172
+ r=cfg.r,
173
+ bottleneck=cfg.bottleneck,
174
+ tau=cfg.tau,
175
+ router_mode=cfg.router_mode,
176
+ alpha=cfg.alpha,
177
+ apply_to_patches_only=cfg.apply_to_patches_only,
178
+ ortho_lambda=cfg.ortho_lambda,
179
+ )
180
+
181
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
182
+ vit = self.vit
183
+ B = x.shape[0]
184
+ x = vit.patch_embed(x)
185
+
186
+ cls = vit.cls_token.expand(B, -1, -1)
187
+ if getattr(vit, "dist_token", None) is not None:
188
+ dist = vit.dist_token.expand(B, -1, -1)
189
+ x = torch.cat([cls, dist, x], dim=1)
190
+ else:
191
+ x = torch.cat([cls, x], dim=1)
192
+
193
+ if getattr(vit, "pos_embed", None) is not None:
194
+ x = x + vit.pos_embed
195
+ x = vit.pos_drop(x)
196
+
197
+ aux_losses = []
198
+ for i, blk in enumerate(vit.blocks):
199
+ x = blk(x)
200
+ key = str(i)
201
+ if key in self.branches:
202
+ x, stats = self.branches[key](x)
203
+ aux_losses.append(stats["ortho_reg"])
204
+
205
+ x = vit.norm(x)
206
+ if hasattr(vit, "forward_head"):
207
+ logits = vit.forward_head(x, pre_logits=False)
208
+ else:
209
+ logits = vit.head(x[:, 0])
210
+ aux = torch.stack(aux_losses).sum() if aux_losses else logits.new_zeros(())
211
+ return logits, aux
212
+
213
+ def trainable_parameters(self):
214
+ for p in self.parameters():
215
+ if p.requires_grad: yield p
216
+
217
+ def build(
218
+ vit: str = "vit_base_patch16_224",
219
+ num_classes: int = 1000,
220
+ pretrained: bool = True,
221
+ cfg: Optional[MoEConfig] = None,
222
+ ) -> EigenMoE:
223
+ vit = timm.create_model(vit, pretrained=pretrained, num_classes=num_classes)
224
+ if cfg is None:
225
+ cfg = MoEConfig()
226
+ return EigenMoE(vit, cfg)
227
+
228
+
229
+ DEFAULT_HUB_CHECKPOINTS = {
230
+ "vit_base_patch16_224": "eigen_moe_vit_base_patch16_224_imagenet1k.pth",
231
+ "vit_large_patch16_224.augreg_in21k_ft_in1k": "eigen_moe_vit_large_patch16_224.augreg_in21k_ft_in1k_imagenet1k.pth",
232
+ "vit_huge_patch14_224_in21k": "eigen_moe_vit_huge_patch14_224_in21k_imagenet1k.pth",
233
+ }
234
+
235
+
236
+ def default_hub_checkpoint_filename(vit_model_name: str) -> Optional[str]:
237
+ return DEFAULT_HUB_CHECKPOINTS.get(vit_model_name)
238
+
239
+
240
+ def _clean_state_dict(raw_checkpoint: Dict) -> Dict[str, torch.Tensor]:
241
+ if not isinstance(raw_checkpoint, dict):
242
+ raise TypeError(f"Expected checkpoint to be a dict, got {type(raw_checkpoint)}")
243
+
244
+ for key in ("state_dict", "model_state_dict", "model"):
245
+ if key in raw_checkpoint and isinstance(raw_checkpoint[key], dict):
246
+ raw_checkpoint = raw_checkpoint[key]
247
+ break
248
+
249
+ cleaned = {}
250
+ for key, value in raw_checkpoint.items():
251
+ if not isinstance(key, str) or not torch.is_tensor(value):
252
+ continue
253
+ if key.startswith("module."):
254
+ key = key[len("module."):]
255
+ cleaned[key] = value
256
+ if not cleaned:
257
+ raise ValueError("No tensor weights were found in checkpoint.")
258
+ return cleaned
259
+
260
+
261
+ class HFEigenMoE(nn.Module, PyTorchModelHubMixin):
262
+ """Hugging Face Hub wrapper for EigenMoE checkpoints."""
263
+
264
+ def __init__(
265
+ self,
266
+ vit_model_name: str = "vit_base_patch16_224",
267
+ num_classes: int = 1000,
268
+ backbone_pretrained: bool = False,
269
+ moe_config: Optional[Dict] = None,
270
+ ):
271
+ super().__init__()
272
+ cfg = MoEConfig(**(moe_config or {}))
273
+ self.vit_model_name = vit_model_name
274
+ self.num_classes = num_classes
275
+ self.backbone_pretrained = backbone_pretrained
276
+ self.moe_config = asdict(cfg)
277
+ self.model = build(
278
+ vit=vit_model_name,
279
+ num_classes=num_classes,
280
+ pretrained=backbone_pretrained,
281
+ cfg=cfg,
282
+ )
283
+
284
+ def forward(self, pixel_values: torch.Tensor, return_aux: bool = False):
285
+ logits, aux = self.model(pixel_values)
286
+ if return_aux:
287
+ return logits, aux
288
+ return logits
289
+
290
+ def load_checkpoint(
291
+ self,
292
+ checkpoint_path: str,
293
+ map_location: str = "cpu",
294
+ strict: bool = True,
295
+ ):
296
+ checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False)
297
+ state_dict = _clean_state_dict(checkpoint)
298
+ return self._load_state_dict_flexible(state_dict, strict=strict)
299
+
300
+ def _load_state_dict_flexible(self, state_dict: Dict[str, torch.Tensor], strict: bool = True):
301
+ try:
302
+ return self.load_state_dict(state_dict, strict=strict)
303
+ except RuntimeError as wrapper_err:
304
+ try:
305
+ return self.model.load_state_dict(state_dict, strict=strict)
306
+ except RuntimeError as inner_err:
307
+ raise RuntimeError(
308
+ "Failed to load checkpoint into both wrapper and inner EigenMoE model.\n"
309
+ f"Wrapper error: {wrapper_err}\n"
310
+ f"Inner model error: {inner_err}"
311
+ ) from inner_err
312
+
313
+ @classmethod
314
+ def _from_pretrained(
315
+ cls,
316
+ *,
317
+ model_id: str,
318
+ revision: Optional[str],
319
+ cache_dir: Optional[str],
320
+ force_download: bool,
321
+ proxies: Optional[Dict],
322
+ resume_download: Optional[bool],
323
+ local_files_only: bool,
324
+ token: Optional[str],
325
+ map_location: str = "cpu",
326
+ strict: bool = False,
327
+ **model_kwargs,
328
+ ):
329
+ checkpoint_filename = model_kwargs.pop("checkpoint_filename", None)
330
+ model = cls(**model_kwargs)
331
+
332
+ checkpoint_path = cls._resolve_checkpoint_path(
333
+ model_id=model_id,
334
+ revision=revision,
335
+ cache_dir=cache_dir,
336
+ force_download=force_download,
337
+ proxies=proxies,
338
+ resume_download=resume_download,
339
+ local_files_only=local_files_only,
340
+ token=token,
341
+ checkpoint_filename=checkpoint_filename,
342
+ vit_model_name=model.vit_model_name,
343
+ )
344
+
345
+ if checkpoint_path.endswith(".safetensors"):
346
+ from safetensors.torch import load_file
347
+
348
+ state_dict = load_file(checkpoint_path, device=map_location)
349
+ else:
350
+ raw = torch.load(checkpoint_path, map_location=map_location, weights_only=False)
351
+ state_dict = _clean_state_dict(raw)
352
+
353
+ model._load_state_dict_flexible(state_dict, strict=strict)
354
+ return model
355
+
356
+ @classmethod
357
+ def _resolve_checkpoint_path(
358
+ cls,
359
+ *,
360
+ model_id: str,
361
+ revision: Optional[str],
362
+ cache_dir: Optional[str],
363
+ force_download: bool,
364
+ proxies: Optional[Dict],
365
+ resume_download: Optional[bool],
366
+ local_files_only: bool,
367
+ token: Optional[str],
368
+ checkpoint_filename: Optional[str],
369
+ vit_model_name: str,
370
+ ) -> str:
371
+ if os.path.isdir(model_id):
372
+ return cls._resolve_local_checkpoint(model_id, checkpoint_filename, vit_model_name)
373
+ return cls._resolve_remote_checkpoint(
374
+ model_id=model_id,
375
+ revision=revision,
376
+ cache_dir=cache_dir,
377
+ force_download=force_download,
378
+ proxies=proxies,
379
+ resume_download=resume_download,
380
+ local_files_only=local_files_only,
381
+ token=token,
382
+ checkpoint_filename=checkpoint_filename,
383
+ vit_model_name=vit_model_name,
384
+ )
385
+
386
+ @staticmethod
387
+ def _resolve_local_checkpoint(
388
+ model_dir: str,
389
+ checkpoint_filename: Optional[str],
390
+ vit_model_name: str,
391
+ ) -> str:
392
+ base = Path(model_dir)
393
+ candidates = []
394
+ if checkpoint_filename:
395
+ candidates.append(checkpoint_filename)
396
+ default_name = default_hub_checkpoint_filename(vit_model_name)
397
+ if default_name:
398
+ candidates.append(default_name)
399
+ candidates.extend(["model.safetensors", "pytorch_model.bin"])
400
+
401
+ for filename in candidates:
402
+ path = base / filename
403
+ if path.exists():
404
+ return str(path)
405
+
406
+ pth_files = sorted(base.glob("*.pth"))
407
+ if pth_files:
408
+ return str(pth_files[0])
409
+
410
+ raise FileNotFoundError(
411
+ f"Could not find a checkpoint in local directory: {model_dir}. "
412
+ f"Tried {candidates} and '*.pth'."
413
+ )
414
+
415
+ @staticmethod
416
+ def _resolve_remote_checkpoint(
417
+ *,
418
+ model_id: str,
419
+ revision: Optional[str],
420
+ cache_dir: Optional[str],
421
+ force_download: bool,
422
+ proxies: Optional[Dict],
423
+ resume_download: Optional[bool],
424
+ local_files_only: bool,
425
+ token: Optional[str],
426
+ checkpoint_filename: Optional[str],
427
+ vit_model_name: str,
428
+ ) -> str:
429
+ if hf_hub_download is None:
430
+ raise ImportError("huggingface_hub is required to download checkpoints from the Hub.")
431
+
432
+ candidates = []
433
+ if checkpoint_filename:
434
+ candidates.append(checkpoint_filename)
435
+ default_name = default_hub_checkpoint_filename(vit_model_name)
436
+ if default_name:
437
+ candidates.append(default_name)
438
+ candidates.extend(["model.safetensors", "pytorch_model.bin"])
439
+
440
+ seen = set()
441
+ unique_candidates = []
442
+ for name in candidates:
443
+ if name not in seen:
444
+ seen.add(name)
445
+ unique_candidates.append(name)
446
+
447
+ for filename in unique_candidates:
448
+ try:
449
+ return hf_hub_download(
450
+ repo_id=model_id,
451
+ filename=filename,
452
+ revision=revision,
453
+ cache_dir=cache_dir,
454
+ force_download=force_download,
455
+ proxies=proxies,
456
+ resume_download=resume_download,
457
+ token=token,
458
+ local_files_only=local_files_only,
459
+ )
460
+ except EntryNotFoundError:
461
+ continue
462
+
463
+ if HfApi is not None:
464
+ api = HfApi(token=token)
465
+ repo_files = api.list_repo_files(repo_id=model_id, revision=revision)
466
+ weight_files = [name for name in repo_files if name.endswith((".pth", ".pt", ".bin", ".safetensors"))]
467
+ if weight_files:
468
+ return hf_hub_download(
469
+ repo_id=model_id,
470
+ filename=weight_files[0],
471
+ revision=revision,
472
+ cache_dir=cache_dir,
473
+ force_download=force_download,
474
+ proxies=proxies,
475
+ resume_download=resume_download,
476
+ token=token,
477
+ local_files_only=local_files_only,
478
+ )
479
+
480
+ raise FileNotFoundError(
481
+ f"No compatible checkpoint found in Hub repo '{model_id}'. "
482
+ f"Tried {unique_candidates} and a fallback scan for *.pth/*.pt/*.bin/*.safetensors."
483
+ )