File size: 18,169 Bytes
8e9a70d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
588e364
 
 
 
 
 
8e9a70d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
588e364
 
 
 
 
 
8e9a70d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
from dataclasses import asdict, dataclass
from pathlib import Path
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Dict, Tuple
import timm

# Optional dependency: degrade gracefully when huggingface_hub is missing so
# the core model classes below still import; only Hub download/upload breaks.
try:
    from huggingface_hub import HfApi, PyTorchModelHubMixin, hf_hub_download
    from huggingface_hub.utils import EntryNotFoundError
except ImportError:  # pragma: no cover - only used when huggingface_hub is unavailable
    HfApi = None  # type: ignore[assignment]
    PyTorchModelHubMixin = object  # type: ignore[assignment,misc]
    hf_hub_download = None  # type: ignore[assignment]
    EntryNotFoundError = FileNotFoundError  # type: ignore[assignment]

class ENNBasis(nn.Module):
    """Learned rank-r spectral basis W = Q @ diag(exp(log_lambda)) @ P^T.

    Q (d_out x r) and P (d_in x r) are kept approximately orthonormal via a
    soft penalty (:meth:`ortho_penalty`) plus an occasional hard QR
    retraction (:meth:`_qr_retract_`).

    Args:
        d_in: input dimension (columns of the reconstructed weight).
        d_out: output dimension (rows of the reconstructed weight).
        r: rank of the basis; must satisfy r <= min(d_in, d_out).
        ortho_lambda: scale of the soft orthogonality penalty.

    Raises:
        ValueError: if ``r`` exceeds ``min(d_in, d_out)``.
    """

    def __init__(self, d_in: int, d_out: int, r: int, ortho_lambda: float = 1e-3):
        super().__init__()
        # Raise instead of assert: asserts are stripped under `python -O`.
        if r > min(d_in, d_out):
            raise ValueError(f"r={r} must be <= min(d_in={d_in}, d_out={d_out})")
        self.d_in, self.d_out, self.r = d_in, d_out, r
        self.ortho_lambda = ortho_lambda

        Q = torch.empty(d_out, r)
        P = torch.empty(d_in,  r)
        nn.init.orthogonal_(Q)
        nn.init.orthogonal_(P)
        self.Q = nn.Parameter(Q)                # left directions (d_out x r)
        self.P = nn.Parameter(P)                # right directions (d_in x r)
        # Log-spectrum; exp() keeps the reconstructed singular values positive.
        self.log_lambda = nn.Parameter(torch.zeros(r))

    @torch.no_grad()
    def _qr_retract_(self):
        """Hard re-orthonormalization of Q and P in place via reduced QR."""
        qQ, _ = torch.linalg.qr(self.Q, mode='reduced')
        qP, _ = torch.linalg.qr(self.P, mode='reduced')
        self.Q.copy_(qQ); self.P.copy_(qP)

    def ortho_penalty(self) -> torch.Tensor:
        """Return ortho_lambda * (||Q^T Q - I||_F^2 + ||P^T P - I||_F^2)."""
        It = torch.eye(self.r, device=self.Q.device, dtype=self.Q.dtype)
        t1 = (self.Q.T @ self.Q - It).pow(2).sum()
        t2 = (self.P.T @ self.P - It).pow(2).sum()
        return self.ortho_lambda * (t1 + t2)

    def reconstruct_weight(self) -> torch.Tensor:
        """Return the dense (d_out x d_in) weight Q @ diag(exp(log_lambda)) @ P^T."""
        lam = torch.diag_embed(self.log_lambda.exp())
        return self.Q @ lam @ self.P.T

    def project_out(self, h: torch.Tensor) -> torch.Tensor:
        """Project token states h of shape (B, T, d) onto the r basis directions.

        NOTE(review): projects with Q (the d_out-side basis), so this assumes
        h's last dim equals d_out — true in this file where d_in == d_out ==
        d_model; confirm before using with a non-square basis.
        """
        return torch.einsum('dr,btd->btr', self.Q, h)

class AdapterExpert(nn.Module):
    """Bottleneck adapter expert: down-project, GELU, up-project (no biases)."""

    def __init__(self, d_model, bottleneck=192):
        super().__init__()
        self.down = nn.Linear(d_model, bottleneck, bias=False)
        self.up   = nn.Linear(bottleneck, d_model, bias=False)
        self.act  = nn.GELU()

    def forward(self, x):
        # Squeeze to the bottleneck, nonlinearity, expand back to d_model.
        hidden = self.act(self.down(x))
        return self.up(hidden)

class EigenRouter(nn.Module):
    """Token router driven by spectral energy in a learned ENN basis.

    Each token is projected onto ``r`` basis directions; its normalized
    squared coefficients ("energy") are combined with per-direction gains
    and per-expert masks to form expert logits.
    """

    def __init__(self, d_model: int, r: int, n_experts: int, tau: float = 1.0, topk: int = 0,
                 ortho_lambda: float = 1e-3):
        super().__init__()
        self.n_experts, self.topk, self.tau = n_experts, topk, tau
        self.basis = ENNBasis(d_in=d_model, d_out=d_model, r=r, ortho_lambda=ortho_lambda)
        self.gamma  = nn.Parameter(torch.ones(r))              # per-direction gain
        self.masks  = nn.Parameter(torch.randn(n_experts, r))  # expert affinity per direction
        self.bias   = nn.Parameter(torch.zeros(n_experts))

    def forward(self, h: torch.Tensor):
        """Route tokens h (B, T, D); return (probs, topk_vals, topk_idx, ortho_reg)."""
        if self.training:
            # Keep the basis orthonormal during training; no-op at eval time.
            self.basis._qr_retract_()
        coeffs = self.basis.project_out(h)
        energy = coeffs.pow(2)
        energy = energy / (energy.sum(dim=-1, keepdim=True) + 1e-6)
        # Softmax over dim 0 makes each basis direction a distribution over experts.
        mask_dist = torch.softmax(self.masks, dim=0)
        logits = torch.einsum('btr,r,er->bte', energy, self.gamma, mask_dist) + self.bias
        probs = F.softmax(logits / self.tau, dim=-1)
        ortho = self.basis.ortho_penalty()
        if not (self.topk and self.topk < self.n_experts):
            # Dense (soft) routing: no sparse selection.
            return probs, None, None, ortho
        vals, idx = torch.topk(probs, k=self.topk, dim=-1)
        return probs, vals, idx, ortho

class MoEAdapterBranch(nn.Module):
    """Residual mixture-of-adapters branch gated by an EigenRouter.

    In "soft" mode every expert's output is blended by its routing
    probability; in sparse mode only the top-k experts per token contribute.
    """

    def __init__(self, d_model: int, n_experts: int = 8, r: int = 128, bottleneck: int = 192,
                 tau: float = 1.0, router_mode: str = "soft", alpha: float = 1.0,
                 apply_to_patches_only: bool = True, ortho_lambda: float = 1e-3):
        super().__init__()
        # "soft" -> dense routing (topk=0); "top1" -> k=1; anything else -> k=2.
        if router_mode == "soft":
            topk = 0
        elif router_mode == "top1":
            topk = 1
        else:
            topk = 2
        self.router = EigenRouter(d_model, r, n_experts, tau, topk, ortho_lambda)
        self.experts = nn.ModuleList([AdapterExpert(d_model, bottleneck) for _ in range(n_experts)])
        self.alpha = nn.Parameter(torch.tensor(alpha, dtype=torch.float32))
        self.apply_to_patches_only = apply_to_patches_only

    def forward(self, x: torch.Tensor):
        """Apply the branch; optionally route only the patch tokens (skip CLS)."""
        split_cls = self.apply_to_patches_only and x.dim() == 3 and x.size(1) >= 2
        if not split_cls:
            return self._forward_tokens(x)
        cls_tok = x[:, :1, :]
        mixed, stats = self._forward_tokens(x[:, 1:, :])
        return torch.cat([cls_tok, mixed], dim=1), stats

    def _forward_tokens(self, h: torch.Tensor):
        """Residual MoE update for token states h; returns (h', stats dict)."""
        probs, vals, idx, ortho = self.router(h)
        entropy = (-(probs * (probs.clamp_min(1e-9)).log())).sum(-1).mean()
        stats = {"ortho_reg": ortho, "router_entropy": entropy}

        if idx is None:
            # Dense path: probability-weighted sum over all experts.
            mix = 0.0
            for e_id, expert in enumerate(self.experts):
                mix = mix + probs[..., e_id].unsqueeze(-1) * expert(h)
            return h + self.alpha * mix, stats

        # Sparse top-k path.
        K = idx.shape[-1]
        mix = torch.zeros_like(h)
        with torch.no_grad():
            # Expert usage histogram for monitoring (no gradient needed).
            counts = torch.bincount(idx.reshape(-1), minlength=len(self.experts))
            stats["assign_hist"] = counts.float() / counts.sum().clamp_min(1)
        for k in range(K):
            ek = idx[..., k]                 # chosen expert id per token
            wk = vals[..., k].unsqueeze(-1)  # its routing weight
            for e_id, expert in enumerate(self.experts):
                mask = (ek == e_id).unsqueeze(-1)
                if mask.any():
                    mix = mix + mask * wk * expert(h)
        return h + self.alpha * mix, stats


@dataclass
class MoEConfig:
    """Hyper-parameters for EigenMoE adapter branches and backbone freezing."""
    experts: int = 8            # number of adapter experts per branch
    r: int = 128                # rank of the ENN routing basis
    bottleneck: int = 192       # adapter hidden width
    tau: float = 1.0            # router softmax temperature
    router_mode: str = "soft"   # "soft" | "top1" | other -> top-2 (see MoEAdapterBranch)
    alpha: float = 1.0          # initial residual scale of each branch
    blocks: str = "last6"       # which ViT blocks get a branch (see _parse_block_indices)
    apply_to_patches_only: bool = True   # leave the CLS token untouched
    ortho_lambda: float = 1e-3  # weight of the basis orthogonality penalty
    freeze_backbone: bool = True          # freeze all ViT parameters
    unfreeze_layernorm: bool = False      # re-enable grads for LayerNorms only

def _parse_block_indices(n_blocks: int, spec: str) -> List[int]:
    if spec == "all":   return list(range(n_blocks))
    if spec == "last6": return list(range(max(0, n_blocks - 6), n_blocks))
    if spec == "last4": return list(range(max(0, n_blocks - 4), n_blocks))
    return [i for i in map(int, spec.split(",")) if 0 <= i < n_blocks]

class EigenMoE(nn.Module):
    """A timm ViT with MoE adapter branches inserted after selected blocks.

    The backbone can be frozen (cfg.freeze_backbone) so that only the
    branches — and optionally the LayerNorms — receive gradients.
    """

    def __init__(self, vit: nn.Module, cfg: MoEConfig):
        super().__init__()
        self.vit, self.cfg = vit, cfg

        if cfg.freeze_backbone:
            for p in self.vit.parameters():
                p.requires_grad = False
        if cfg.unfreeze_layernorm:
            # Applied after the freeze so LayerNorms end up trainable either way.
            for m in self.vit.modules():
                if isinstance(m, nn.LayerNorm):
                    for p in m.parameters():
                        p.requires_grad = True

        # Token width: timm ViTs expose embed_dim; otherwise fall back to the
        # first block's LayerNorm shape.
        d_model = getattr(self.vit, "embed_dim", None)
        if d_model is None:
            d_model = self.vit.blocks[0].norm1.normalized_shape[0]
        n_blocks = len(self.vit.blocks)
        self.block_ids = _parse_block_indices(n_blocks, cfg.blocks)

        # One adapter branch per selected block, keyed by the block index
        # (ModuleDict keys must be strings).
        self.branches = nn.ModuleDict()
        for i in self.block_ids:
            self.branches[str(i)] = MoEAdapterBranch(
                d_model=d_model,
                n_experts=cfg.experts,
                r=cfg.r,
                bottleneck=cfg.bottleneck,
                tau=cfg.tau,
                router_mode=cfg.router_mode,
                alpha=cfg.alpha,
                apply_to_patches_only=cfg.apply_to_patches_only,
                ortho_lambda=cfg.ortho_lambda,
            )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (logits, aux) where aux is the summed ortho penalty of all branches.

        NOTE(review): re-implements the timm ViT forward pass (patch_embed,
        cls/dist tokens, pos_embed, blocks, head) — assumes a classic ViT
        layout without dynamic position-embedding resizing; confirm for
        other backbones.
        """
        vit = self.vit
        B = x.shape[0]
        x = vit.patch_embed(x)

        cls = vit.cls_token.expand(B, -1, -1)
        if getattr(vit, "dist_token", None) is not None:
            # DeiT-style models carry an extra distillation token.
            dist = vit.dist_token.expand(B, -1, -1)
            x = torch.cat([cls, dist, x], dim=1)
        else:
            x = torch.cat([cls, x], dim=1)

        if getattr(vit, "pos_embed", None) is not None:
            x = x + vit.pos_embed
        x = vit.pos_drop(x)

        aux_losses = []
        for i, blk in enumerate(vit.blocks):
            x = blk(x)
            key = str(i)
            if key in self.branches:
                # Adapter branch runs after its block; collect its ortho penalty.
                x, stats = self.branches[key](x)
                aux_losses.append(stats["ortho_reg"])

        x = vit.norm(x)
        if hasattr(vit, "forward_head"):
            logits = vit.forward_head(x, pre_logits=False)
        else:
            # Older timm models: classify from the CLS token directly.
            logits = vit.head(x[:, 0])
        aux = torch.stack(aux_losses).sum() if aux_losses else logits.new_zeros(())
        return logits, aux

    def trainable_parameters(self):
        """Yield only parameters with requires_grad=True (branches, unfrozen LNs)."""
        for p in self.parameters():
            if p.requires_grad: yield p

def build(
    vit: str = "vit_base_patch16_224",
    num_classes: int = 1000,
    pretrained: bool = True,
    cfg: Optional[MoEConfig] = None,
) -> EigenMoE:
    """Create a timm ViT backbone and wrap it in an EigenMoE.

    Args:
        vit: timm model name for the backbone.
        num_classes: classifier head size.
        pretrained: load timm pretrained backbone weights.
        cfg: MoE configuration; defaults to MoEConfig() when omitted.
    """
    backbone = timm.create_model(vit, pretrained=pretrained, num_classes=num_classes)
    moe_cfg = cfg if cfg is not None else MoEConfig()
    return EigenMoE(backbone, moe_cfg)


# Published Hub checkpoint filenames, keyed by timm backbone name.
DEFAULT_HUB_CHECKPOINTS = {
    "vit_base_patch16_224": "eigen_moe_vit_base_patch16_224_imagenet1k.pth",
    "vit_large_patch16_224.augreg_in21k_ft_in1k": "eigen_moe_vit_large_patch16_224.augreg_in21k_ft_in1k_imagenet1k.pth",
    "vit_huge_patch14_224_in21k": "eigen_moe_vit_huge_patch14_224_in21k_imagenet1k.pth",
}


def default_hub_checkpoint_filename(vit_model_name: str) -> Optional[str]:
    """Return the published checkpoint filename for a backbone, or None if unknown."""
    try:
        return DEFAULT_HUB_CHECKPOINTS[vit_model_name]
    except KeyError:
        return None


def _clean_state_dict(raw_checkpoint: Dict) -> Dict[str, torch.Tensor]:
    if not isinstance(raw_checkpoint, dict):
        raise TypeError(f"Expected checkpoint to be a dict, got {type(raw_checkpoint)}")

    for key in ("state_dict", "model_state_dict", "model"):
        if key in raw_checkpoint and isinstance(raw_checkpoint[key], dict):
            raw_checkpoint = raw_checkpoint[key]
            break

    cleaned = {}
    for key, value in raw_checkpoint.items():
        if not isinstance(key, str) or not torch.is_tensor(value):
            continue
        if key.startswith("module."):
            key = key[len("module."):]
        cleaned[key] = value
    if not cleaned:
        raise ValueError("No tensor weights were found in checkpoint.")
    return cleaned


class HFEigenMoE(nn.Module, PyTorchModelHubMixin):
    """Hugging Face Hub wrapper for EigenMoE checkpoints.

    Stores the build arguments (backbone name, class count, MoE config dict)
    so the mixin can serialize them, and delegates inference to the wrapped
    EigenMoE in ``self.model``.
    """

    def __init__(
        self,
        vit_model_name: str = "vit_base_patch16_224",
        num_classes: int = 1000,
        backbone_pretrained: bool = False,
        moe_config: Optional[Dict] = None,
    ):
        super().__init__()
        # moe_config stays a plain dict (JSON-serializable for the Hub);
        # round-tripping through MoEConfig validates the keys.
        cfg = MoEConfig(**(moe_config or {}))
        self.vit_model_name = vit_model_name
        self.num_classes = num_classes
        self.backbone_pretrained = backbone_pretrained
        self.moe_config = asdict(cfg)
        self.model = build(
            vit=vit_model_name,
            num_classes=num_classes,
            pretrained=backbone_pretrained,
            cfg=cfg,
        )

    def forward(self, pixel_values: torch.Tensor, return_aux: bool = False):
        """Run the wrapped model; optionally also return the aux (ortho) loss."""
        logits, aux = self.model(pixel_values)
        if return_aux:
            return logits, aux
        return logits

    def load_checkpoint(
        self,
        checkpoint_path: str,
        map_location: str = "cpu",
        strict: bool = True,
    ):
        """Load a local checkpoint file into this wrapper (or the inner model).

        SECURITY: weights_only=False unpickles arbitrary objects — only load
        checkpoints from trusted sources.
        """
        checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False)
        state_dict = _clean_state_dict(checkpoint)
        return self._load_state_dict_flexible(state_dict, strict=strict)

    def _load_state_dict_flexible(self, state_dict: Dict[str, torch.Tensor], strict: bool = True):
        """Try loading into the wrapper; on failure, retry on self.model.

        Accepts checkpoints saved either from HFEigenMoE (keys under "model.")
        or directly from an EigenMoE instance.
        """
        try:
            return self.load_state_dict(state_dict, strict=strict)
        except RuntimeError as wrapper_err:
            try:
                return self.model.load_state_dict(state_dict, strict=strict)
            except RuntimeError as inner_err:
                # Surface both failures so the caller can see which keys mismatched.
                raise RuntimeError(
                    "Failed to load checkpoint into both wrapper and inner EigenMoE model.\n"
                    f"Wrapper error: {wrapper_err}\n"
                    f"Inner model error: {inner_err}"
                ) from inner_err

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[str],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: Optional[bool],
        local_files_only: bool,
        token: Optional[str],
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
        """Hub-mixin hook: build the model from config kwargs, then load weights.

        ``checkpoint_filename`` may be passed in model_kwargs to pin a specific
        weight file in the repo; it is popped before constructing the model.
        """
        checkpoint_filename = model_kwargs.pop("checkpoint_filename", None)
        model = cls(**model_kwargs)

        checkpoint_path = cls._resolve_checkpoint_path(
            model_id=model_id,
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            token=token,
            checkpoint_filename=checkpoint_filename,
            vit_model_name=model.vit_model_name,
        )

        if checkpoint_path.endswith(".safetensors"):
            # Lazy import: safetensors is only required for this format.
            from safetensors.torch import load_file

            state_dict = load_file(checkpoint_path, device=map_location)
        else:
            # SECURITY: weights_only=False unpickles arbitrary objects — only
            # load checkpoints from trusted repos.
            raw = torch.load(checkpoint_path, map_location=map_location, weights_only=False)
            state_dict = _clean_state_dict(raw)

        model._load_state_dict_flexible(state_dict, strict=strict)
        return model

    @classmethod
    def _resolve_checkpoint_path(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[str],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: Optional[bool],
        local_files_only: bool,
        token: Optional[str],
        checkpoint_filename: Optional[str],
        vit_model_name: str,
    ) -> str:
        """Dispatch to local-directory or remote-Hub checkpoint resolution."""
        if os.path.isdir(model_id):
            return cls._resolve_local_checkpoint(model_id, checkpoint_filename, vit_model_name)
        return cls._resolve_remote_checkpoint(
            model_id=model_id,
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            token=token,
            checkpoint_filename=checkpoint_filename,
            vit_model_name=vit_model_name,
        )

    @staticmethod
    def _resolve_local_checkpoint(
        model_dir: str,
        checkpoint_filename: Optional[str],
        vit_model_name: str,
    ) -> str:
        """Find a weight file inside a local directory.

        Order: explicit checkpoint_filename, then the standard Hub names,
        then the backbone's default name, finally any *.pth file.

        Raises:
            FileNotFoundError: when no candidate file exists.
        """
        base = Path(model_dir)
        if checkpoint_filename:
            candidates = [checkpoint_filename]
        else:
            candidates = ["model.safetensors", "pytorch_model.bin"]
            default_name = default_hub_checkpoint_filename(vit_model_name)
            if default_name:
                candidates.append(default_name)

        for filename in candidates:
            path = base / filename
            if path.exists():
                return str(path)

        # Last resort: first *.pth in alphabetical order.
        pth_files = sorted(base.glob("*.pth"))
        if pth_files:
            return str(pth_files[0])

        raise FileNotFoundError(
            f"Could not find a checkpoint in local directory: {model_dir}. "
            f"Tried {candidates} and '*.pth'."
        )

    @staticmethod
    def _resolve_remote_checkpoint(
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[str],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: Optional[bool],
        local_files_only: bool,
        token: Optional[str],
        checkpoint_filename: Optional[str],
        vit_model_name: str,
    ) -> str:
        """Download a weight file from the Hub, trying candidates in order.

        Order: explicit checkpoint_filename, standard Hub names, the
        backbone's default name; if all 404, fall back to listing the repo
        and taking the first file with a weight-like extension.

        Raises:
            ImportError: if huggingface_hub is not installed.
            FileNotFoundError: when no compatible file exists in the repo.
        """
        if hf_hub_download is None:
            raise ImportError("huggingface_hub is required to download checkpoints from the Hub.")

        if checkpoint_filename:
            candidates = [checkpoint_filename]
        else:
            candidates = ["model.safetensors", "pytorch_model.bin"]
            default_name = default_hub_checkpoint_filename(vit_model_name)
            if default_name:
                candidates.append(default_name)

        # De-duplicate while preserving the priority order.
        seen = set()
        unique_candidates = []
        for name in candidates:
            if name not in seen:
                seen.add(name)
                unique_candidates.append(name)

        for filename in unique_candidates:
            try:
                return hf_hub_download(
                    repo_id=model_id,
                    filename=filename,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            except EntryNotFoundError:
                # Candidate not in the repo — try the next one.
                continue

        if HfApi is not None:
            # Fallback: scan the repo for any weight-like file.
            api = HfApi(token=token)
            repo_files = api.list_repo_files(repo_id=model_id, revision=revision)
            weight_files = [name for name in repo_files if name.endswith((".pth", ".pt", ".bin", ".safetensors"))]
            if weight_files:
                return hf_hub_download(
                    repo_id=model_id,
                    filename=weight_files[0],
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )

        raise FileNotFoundError(
            f"No compatible checkpoint found in Hub repo '{model_id}'. "
            f"Tried {unique_candidates} and a fallback scan for *.pth/*.pt/*.bin/*.safetensors."
        )