File size: 18,169 Bytes
8e9a70d 588e364 8e9a70d 588e364 8e9a70d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 | from dataclasses import asdict, dataclass
from pathlib import Path
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Dict, Tuple
import timm
try:
from huggingface_hub import HfApi, PyTorchModelHubMixin, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
except ImportError: # pragma: no cover - only used when huggingface_hub is unavailable
HfApi = None # type: ignore[assignment]
PyTorchModelHubMixin = object # type: ignore[assignment,misc]
hf_hub_download = None # type: ignore[assignment]
EntryNotFoundError = FileNotFoundError # type: ignore[assignment]
class ENNBasis(nn.Module):
    """Learnable rank-``r`` factorisation ``W = Q @ diag(exp(log_lambda)) @ P.T``.

    ``Q`` (``d_out x r``) and ``P`` (``d_in x r``) are initialised with
    orthonormal columns. :meth:`_qr_retract_` re-orthonormalises them in place,
    and :meth:`ortho_penalty` softly penalises drift from orthonormality.

    Args:
        d_in: Number of rows of ``P``.
        d_out: Number of rows of ``Q`` (feature size accepted by ``project_out``).
        r: Rank; must not exceed ``min(d_in, d_out)``.
        ortho_lambda: Weight of the orthogonality penalty.

    Raises:
        ValueError: If ``r > min(d_in, d_out)``.
    """

    def __init__(self, d_in: int, d_out: int, r: int, ortho_lambda: float = 1e-3):
        super().__init__()
        # Explicit check instead of ``assert`` so it is not stripped by ``python -O``;
        # r orthonormal columns only exist when r fits in both dimensions.
        if r > min(d_in, d_out):
            raise ValueError(
                f"rank r={r} must not exceed min(d_in, d_out)={min(d_in, d_out)}"
            )
        self.d_in, self.d_out, self.r = d_in, d_out, r
        self.ortho_lambda = ortho_lambda
        Q = torch.empty(d_out, r)
        P = torch.empty(d_in, r)
        nn.init.orthogonal_(Q)
        nn.init.orthogonal_(P)
        self.Q = nn.Parameter(Q)
        self.P = nn.Parameter(P)
        # Diagonal scales are stored in log-space so exp() keeps them positive.
        self.log_lambda = nn.Parameter(torch.zeros(r))

    @staticmethod
    def _orthonormalize(m: torch.Tensor) -> torch.Tensor:
        """Return the reduced-QR ``Q`` factor of ``m`` with signs fixed so ``diag(R) >= 0``."""
        q, r = torch.linalg.qr(m, mode="reduced")
        # QR only determines each column up to sign. Without this correction an
        # already-orthonormal matrix can come back with flipped columns on every
        # call, which makes the parameter oscillate against the optimizer state.
        d = torch.diagonal(r)
        sign = torch.where(d < 0, -torch.ones_like(d), torch.ones_like(d))
        return q * sign

    @torch.no_grad()
    def _qr_retract_(self):
        """Re-orthonormalise ``Q`` and ``P`` in place without autograd tracking."""
        self.Q.copy_(self._orthonormalize(self.Q))
        self.P.copy_(self._orthonormalize(self.P))

    def ortho_penalty(self) -> torch.Tensor:
        """Return ``ortho_lambda * (||Q^T Q - I||_F^2 + ||P^T P - I||_F^2)``."""
        It = torch.eye(self.r, device=self.Q.device, dtype=self.Q.dtype)
        t1 = (self.Q.T @ self.Q - It).pow(2).sum()
        t2 = (self.P.T @ self.P - It).pow(2).sum()
        return self.ortho_lambda * (t1 + t2)

    def reconstruct_weight(self) -> torch.Tensor:
        """Return the dense ``(d_out, d_in)`` matrix ``Q @ diag(exp(log_lambda)) @ P.T``."""
        lam = torch.diag_embed(self.log_lambda.exp())
        return self.Q @ lam @ self.P.T

    def project_out(self, h: torch.Tensor) -> torch.Tensor:
        """Project features onto the columns of ``Q``: ``(..., d_out) -> (..., r)``.

        Any number of leading dimensions is accepted; for a 3-D input this is
        the same as ``einsum('dr,btd->btr', Q, h)``.
        """
        return h @ self.Q
class AdapterExpert(nn.Module):
    """Bias-free bottleneck MLP: ``Linear(d_model->bottleneck) -> GELU -> Linear(bottleneck->d_model)``."""

    def __init__(self, d_model, bottleneck=192):
        super().__init__()
        # Attribute names and creation order are kept so state_dict keys
        # (``down.weight``, ``up.weight``) and RNG consumption stay the same.
        self.down = nn.Linear(d_model, bottleneck, bias=False)
        self.up = nn.Linear(bottleneck, d_model, bias=False)
        self.act = nn.GELU()

    def forward(self, x):
        hidden = self.down(x)
        hidden = self.act(hidden)
        return self.up(hidden)
class EigenRouter(nn.Module):
    """Score experts from each token's energy in a learned ``r``-dim subspace.

    Tokens are projected onto ``basis.Q``. The squared projections, normalised
    per token, are mixed into per-expert logits through ``gamma`` and a
    column-softmaxed ``masks`` matrix, plus a per-expert ``bias``.

    ``forward`` returns ``(probs, topk_vals, topk_idx, ortho_penalty)``; the
    top-k entries are ``None`` unless ``0 < topk < n_experts``.
    """

    def __init__(self, d_model: int, r: int, n_experts: int, tau: float = 1.0, topk: int = 0,
                 ortho_lambda: float = 1e-3):
        super().__init__()
        self.n_experts = n_experts
        self.topk = topk
        self.tau = tau
        # Square (d_model x d_model) basis; its Q factor drives the routing logits.
        self.basis = ENNBasis(d_in=d_model, d_out=d_model, r=r, ortho_lambda=ortho_lambda)
        self.gamma = nn.Parameter(torch.ones(r))
        self.masks = nn.Parameter(torch.randn(n_experts, r))
        self.bias = nn.Parameter(torch.zeros(n_experts))

    def forward(self, h: torch.Tensor):
        if self.training:
            # Re-orthonormalise the basis on every training-mode forward.
            self.basis._qr_retract_()
        energy = self.basis.project_out(h).pow(2)
        # Per-token normalisation; the epsilon guards all-zero projections.
        energy = energy / (energy.sum(dim=-1, keepdim=True) + 1e-6)
        # Softmax over dim 0: each subspace direction spreads unit mass over the experts.
        expert_affinity = torch.softmax(self.masks, dim=0)
        logits = torch.einsum('btr,r,er->bte', energy, self.gamma, expert_affinity) + self.bias
        probs = F.softmax(logits / self.tau, dim=-1)
        penalty = self.basis.ortho_penalty()
        use_topk = bool(self.topk) and self.topk < self.n_experts
        if not use_topk:
            return probs, None, None, penalty
        top_vals, top_idx = torch.topk(probs, k=self.topk, dim=-1)
        return probs, top_vals, top_idx, penalty
class MoEAdapterBranch(nn.Module):
    """Residual mixture of :class:`AdapterExpert` modules gated by an :class:`EigenRouter`.

    Computes ``h + alpha * sum_e w_e * expert_e(h)`` per token. In ``"soft"``
    mode every expert contributes with its router probability. In ``"top<k>"``
    mode only each token's top-k experts contribute, weighted by their
    (un-renormalised) router probabilities.

    Args:
        d_model: Token feature size.
        n_experts: Number of experts.
        r: Router subspace rank.
        bottleneck: Hidden width of each expert.
        tau: Router softmax temperature.
        router_mode: ``"soft"`` or ``"top<k>"`` (e.g. ``"top1"``, ``"top2"``).
        alpha: Initial value of the learnable residual scale.
        apply_to_patches_only: For 3-D ``(B, T, D)`` inputs with ``T >= 2``,
            token 0 bypasses the branch unchanged.
        ortho_lambda: Weight of the router basis orthogonality penalty.

    Raises:
        ValueError: If ``router_mode`` is not recognised.
    """

    def __init__(self, d_model: int, n_experts: int = 8, r: int = 128, bottleneck: int = 192,
                 tau: float = 1.0, router_mode: str = "soft", alpha: float = 1.0,
                 apply_to_patches_only: bool = True, ortho_lambda: float = 1e-3):
        super().__init__()
        topk = self._parse_router_mode(router_mode)
        self.router = EigenRouter(d_model, r, n_experts, tau, topk, ortho_lambda)
        self.experts = nn.ModuleList([AdapterExpert(d_model, bottleneck) for _ in range(n_experts)])
        # Learnable scale on the mixed expert output before the residual add.
        self.alpha = nn.Parameter(torch.tensor(alpha, dtype=torch.float32))
        self.apply_to_patches_only = apply_to_patches_only

    @staticmethod
    def _parse_router_mode(router_mode: str) -> int:
        """Map ``"soft"`` to 0 (dense routing) and ``"top<k>"`` to ``k >= 1``."""
        if router_mode == "soft":
            return 0
        suffix = router_mode[3:] if router_mode.startswith("top") else ""
        if suffix.isdigit() and int(suffix) >= 1:
            return int(suffix)
        # Previously any unknown string silently meant top-2; fail loudly instead.
        raise ValueError(
            f"Unknown router_mode {router_mode!r}; expected 'soft' or 'top<k>' (e.g. 'top1', 'top2')."
        )

    def forward(self, x: torch.Tensor):
        """Return ``(output, stats)``; token 0 is passed through when configured."""
        if self.apply_to_patches_only and x.dim() == 3 and x.size(1) >= 2:
            cls_tok, patches = x[:, :1, :], x[:, 1:, :]
            y, stats = self._forward_tokens(patches)
            return torch.cat([cls_tok, y], dim=1), stats
        return self._forward_tokens(x)

    def _forward_tokens(self, h: torch.Tensor):
        """Route and mix every token of ``h`` (shape ``(..., D)``).

        ``stats`` holds ``ortho_reg`` (router basis penalty), ``router_entropy``
        (mean per-token entropy of the routing probabilities) and, in top-k
        mode, ``assign_hist`` (fraction of top-k slots assigned to each expert).
        """
        # Routing is per token, so route on a flat (1, N, D) view. This lets
        # inputs of any rank reach the router, which expects 3-D input.
        tokens = h.reshape(-1, h.shape[-1])
        probs, vals, idx, ortho = self.router(tokens.unsqueeze(0))
        probs = probs.squeeze(0)
        stats = {
            "ortho_reg": ortho,
            "router_entropy": (-(probs * probs.clamp_min(1e-9).log())).sum(-1).mean(),
        }
        if idx is None:
            out = torch.zeros_like(tokens)
            for e_id, expert in enumerate(self.experts):
                out = out + probs[:, e_id].unsqueeze(-1) * expert(tokens)
        else:
            idx, vals = idx.squeeze(0), vals.squeeze(0)
            with torch.no_grad():
                counts = torch.bincount(idx.reshape(-1), minlength=len(self.experts))
                stats["assign_hist"] = counts.float() / counts.sum().clamp_min(1)
            out = self._combine_topk(tokens, vals, idx)
        return h + self.alpha * out.reshape(h.shape), stats

    def _combine_topk(self, tokens: torch.Tensor, vals: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        """Sum ``vals[n, k] * expert_{idx[n, k]}(tokens[n])`` over ``k`` for each token ``n``.

        Each expert runs only on the tokens routed to it. Previously every
        expert ran on all tokens once per top-k slot.
        """
        out = torch.zeros_like(tokens)
        for e_id, expert in enumerate(self.experts):
            # torch.topk yields distinct indices per row, so each token appears
            # at most once in ``rows`` for a given expert.
            rows, slots = (idx == e_id).nonzero(as_tuple=True)
            if rows.numel() == 0:
                continue
            contrib = vals[rows, slots].unsqueeze(-1) * expert(tokens[rows])
            out = out.index_add(0, rows, contrib.to(out.dtype))
        return out
@dataclass
class MoEConfig:
    """Hyper-parameters for the adapter branches that EigenMoE attaches to a ViT.

    Field order defines the dataclass's positional constructor; keep it stable.
    """
    experts: int = 8  # number of AdapterExpert modules per branch
    r: int = 128  # rank of each router's learned subspace
    bottleneck: int = 192  # hidden width of each adapter expert
    tau: float = 1.0  # router softmax temperature (logits are divided by it)
    router_mode: str = "soft"  # "soft" = dense mixture; "top1"/"top2" = sparse top-k routing
    alpha: float = 1.0  # initial value of each branch's learnable residual scale
    blocks: str = "last6"  # which blocks get a branch, e.g. "all", "last6", "last4" or "0,5,11"
    apply_to_patches_only: bool = True  # for 3-D inputs, leave token 0 untouched by the branch
    ortho_lambda: float = 1e-3  # weight of the router basis orthogonality penalty
    freeze_backbone: bool = True  # set requires_grad=False on every backbone parameter
    unfreeze_layernorm: bool = False  # re-enable grads for the backbone's nn.LayerNorm parameters
def _parse_block_indices(n_blocks: int, spec: str) -> List[int]:
if spec == "all": return list(range(n_blocks))
if spec == "last6": return list(range(max(0, n_blocks - 6), n_blocks))
if spec == "last4": return list(range(max(0, n_blocks - 4), n_blocks))
return [i for i in map(int, spec.split(",")) if 0 <= i < n_blocks]
class EigenMoE(nn.Module):
    """A ViT backbone with a :class:`MoEAdapterBranch` after selected blocks.

    Args:
        vit: Backbone exposing ``patch_embed``, ``blocks``, ``norm`` and either
            ``forward_head`` or ``head`` (timm ``VisionTransformer``-style).
        cfg: Branch hyper-parameters and backbone freezing policy.
    """

    def __init__(self, vit: nn.Module, cfg: MoEConfig):
        super().__init__()
        self.vit, self.cfg = vit, cfg
        if cfg.freeze_backbone:
            for p in self.vit.parameters():
                p.requires_grad = False
        if cfg.unfreeze_layernorm:
            # Keep the backbone's LayerNorm parameters trainable.
            for m in self.vit.modules():
                if isinstance(m, nn.LayerNorm):
                    for p in m.parameters():
                        p.requires_grad = True
        d_model = getattr(self.vit, "embed_dim", None)
        if d_model is None:
            # Fall back to the feature size of the first block's norm1.
            d_model = self.vit.blocks[0].norm1.normalized_shape[0]
        n_blocks = len(self.vit.blocks)
        self.block_ids = _parse_block_indices(n_blocks, cfg.blocks)
        self.branches = nn.ModuleDict()
        for i in self.block_ids:
            self.branches[str(i)] = MoEAdapterBranch(
                d_model=d_model,
                n_experts=cfg.experts,
                r=cfg.r,
                bottleneck=cfg.bottleneck,
                tau=cfg.tau,
                router_mode=cfg.router_mode,
                alpha=cfg.alpha,
                apply_to_patches_only=cfg.apply_to_patches_only,
                ortho_lambda=cfg.ortho_lambda,
            )

    def _embed(self, x: torch.Tensor) -> torch.Tensor:
        """Patch-embed images and add prefix tokens and position embeddings."""
        vit = self.vit
        x = vit.patch_embed(x)
        if hasattr(vit, "_pos_embed") and getattr(vit, "dist_token", None) is None:
            # NOTE(review): ``_pos_embed`` is a private timm hook. In recent timm
            # it concatenates prefix tokens, adds (possibly resized) position
            # embeddings and applies pos_drop. That covers no_embed_class,
            # class_token=False and register-token models, which the manual path
            # below mishandles. patch_drop/norm_pre mirror timm's forward_features.
            x = vit._pos_embed(x)
            for name in ("patch_drop", "norm_pre"):
                layer = getattr(vit, name, None)
                if layer is not None:
                    x = layer(x)
            return x
        # Explicit path: older timm versions and distilled (dist_token) models.
        B = x.shape[0]
        cls = vit.cls_token.expand(B, -1, -1)
        if getattr(vit, "dist_token", None) is not None:
            dist = vit.dist_token.expand(B, -1, -1)
            x = torch.cat([cls, dist, x], dim=1)
        else:
            x = torch.cat([cls, x], dim=1)
        if getattr(vit, "pos_embed", None) is not None:
            x = x + vit.pos_embed
        return vit.pos_drop(x)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(logits, aux)``, where ``aux`` sums every branch's ``ortho_reg``.

        ``aux`` is a zero scalar when no branch is attached.
        """
        vit = self.vit
        x = self._embed(x)
        aux_losses = []
        for i, blk in enumerate(vit.blocks):
            x = blk(x)
            key = str(i)
            if key in self.branches:
                x, stats = self.branches[key](x)
                aux_losses.append(stats["ortho_reg"])
        x = vit.norm(x)
        if hasattr(vit, "forward_head"):
            logits = vit.forward_head(x, pre_logits=False)
        else:
            logits = vit.head(x[:, 0])
        aux = torch.stack(aux_losses).sum() if aux_losses else logits.new_zeros(())
        return logits, aux

    def trainable_parameters(self):
        """Yield every parameter with ``requires_grad=True``."""
        for p in self.parameters():
            if p.requires_grad:
                yield p
def build(
    vit: str = "vit_base_patch16_224",
    num_classes: int = 1000,
    pretrained: bool = True,
    cfg: Optional[MoEConfig] = None,
) -> EigenMoE:
    """Create a timm ViT by name and wrap it in :class:`EigenMoE`.

    Args:
        vit: timm model name passed to ``timm.create_model``.
        num_classes: Number of output classes for the classification head.
        pretrained: Forwarded to ``timm.create_model``.
        cfg: MoE configuration; a default ``MoEConfig()`` is used when ``None``.
    """
    backbone = timm.create_model(vit, pretrained=pretrained, num_classes=num_classes)
    return EigenMoE(backbone, cfg if cfg is not None else MoEConfig())
# Known checkpoint filenames, keyed by the timm backbone name they were trained on.
DEFAULT_HUB_CHECKPOINTS = {
    "vit_base_patch16_224": "eigen_moe_vit_base_patch16_224_imagenet1k.pth",
    "vit_large_patch16_224.augreg_in21k_ft_in1k": "eigen_moe_vit_large_patch16_224.augreg_in21k_ft_in1k_imagenet1k.pth",
    "vit_huge_patch14_224_in21k": "eigen_moe_vit_huge_patch14_224_in21k_imagenet1k.pth",
}


def default_hub_checkpoint_filename(vit_model_name: str) -> Optional[str]:
    """Return the known checkpoint filename for ``vit_model_name``, or ``None`` if unknown."""
    if vit_model_name in DEFAULT_HUB_CHECKPOINTS:
        return DEFAULT_HUB_CHECKPOINTS[vit_model_name]
    return None
def _clean_state_dict(raw_checkpoint: Dict) -> Dict[str, torch.Tensor]:
if not isinstance(raw_checkpoint, dict):
raise TypeError(f"Expected checkpoint to be a dict, got {type(raw_checkpoint)}")
for key in ("state_dict", "model_state_dict", "model"):
if key in raw_checkpoint and isinstance(raw_checkpoint[key], dict):
raw_checkpoint = raw_checkpoint[key]
break
cleaned = {}
for key, value in raw_checkpoint.items():
if not isinstance(key, str) or not torch.is_tensor(value):
continue
if key.startswith("module."):
key = key[len("module."):]
cleaned[key] = value
if not cleaned:
raise ValueError("No tensor weights were found in checkpoint.")
return cleaned
class HFEigenMoE(nn.Module, PyTorchModelHubMixin):
    """Hugging Face Hub wrapper for EigenMoE checkpoints.

    Builds an :class:`EigenMoE` (``self.model``) from a timm backbone name and a
    ``MoEConfig`` given as a plain dict. Weights are loaded from a local file
    with :meth:`load_checkpoint`, or from a local directory / Hub repo via
    ``from_pretrained``. Both wrapper-style (``model.*``) and bare EigenMoE
    state-dict keys are accepted.
    """

    def __init__(
        self,
        vit_model_name: str = "vit_base_patch16_224",
        num_classes: int = 1000,
        backbone_pretrained: bool = False,
        moe_config: Optional[Dict] = None,
    ):
        super().__init__()
        cfg = MoEConfig(**(moe_config or {}))
        self.vit_model_name = vit_model_name
        self.num_classes = num_classes
        self.backbone_pretrained = backbone_pretrained
        # Kept as a plain dict (presumably so it serialises with the Hub config).
        self.moe_config = asdict(cfg)
        self.model = build(
            vit=vit_model_name,
            num_classes=num_classes,
            pretrained=backbone_pretrained,
            cfg=cfg,
        )

    def forward(self, pixel_values: torch.Tensor, return_aux: bool = False):
        """Return logits, or ``(logits, aux)`` when ``return_aux`` is true."""
        logits, aux = self.model(pixel_values)
        if return_aux:
            return logits, aux
        return logits

    def load_checkpoint(
        self,
        checkpoint_path: str,
        map_location: str = "cpu",
        strict: bool = True,
    ):
        """Load weights from a ``.safetensors`` or ``torch.save`` checkpoint file.

        Returns the ``load_state_dict`` result (missing / unexpected keys).
        """
        state_dict = self._read_state_dict(checkpoint_path, map_location)
        return self._load_state_dict_flexible(state_dict, strict=strict)

    @staticmethod
    def _read_state_dict(checkpoint_path: str, map_location: str = "cpu") -> Dict[str, torch.Tensor]:
        """Read a checkpoint file (``.safetensors`` or pickle) into a cleaned tensor dict."""
        path = str(checkpoint_path)
        if path.endswith(".safetensors"):
            from safetensors.torch import load_file
            raw = load_file(path, device=map_location)
        else:
            # NOTE(security): weights_only=False unpickles arbitrary Python objects,
            # so a malicious checkpoint (e.g. from an untrusted Hub repo) can run
            # code on load. Only load checkpoints from sources you trust.
            raw = torch.load(path, map_location=map_location, weights_only=False)
        return _clean_state_dict(raw)

    def _load_state_dict_flexible(self, state_dict: Dict[str, torch.Tensor], strict: bool = True):
        """Load ``state_dict`` into the wrapper or the inner model, whichever its keys match.

        With ``strict=False``, ``load_state_dict`` does not raise on missing or
        unexpected keys. Trying the wrapper first would therefore "succeed" on
        a bare EigenMoE checkpoint while loading nothing. Instead, the target
        with the larger key overlap is tried first. The other target is only
        tried if that load raises.
        """
        provided = set(state_dict)
        wrapper_hits = len(provided & set(self.state_dict()))
        inner_hits = len(provided & set(self.model.state_dict()))
        candidates = [("Wrapper", self), ("Inner model", self.model)]
        if inner_hits > wrapper_hits:
            candidates.reverse()
        if not strict and max(wrapper_hits, inner_hits) == 0:
            import warnings
            warnings.warn(
                "Checkpoint keys match neither the HFEigenMoE wrapper nor the inner "
                "EigenMoE model; no weights will be loaded.",
                stacklevel=2,
            )
        errors = {}
        for label, target in candidates:
            try:
                return target.load_state_dict(state_dict, strict=strict)
            except RuntimeError as err:
                errors[label] = err
        raise RuntimeError(
            "Failed to load checkpoint into both wrapper and inner EigenMoE model.\n"
            f"Wrapper error: {errors['Wrapper']}\n"
            f"Inner model error: {errors['Inner model']}"
        ) from errors["Inner model"]

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[str],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: Optional[bool],
        local_files_only: bool,
        token: Optional[str],
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
        """Build the wrapper from ``model_kwargs`` and load the resolved checkpoint into it."""
        checkpoint_filename = model_kwargs.pop("checkpoint_filename", None)
        model = cls(**model_kwargs)
        checkpoint_path = cls._resolve_checkpoint_path(
            model_id=model_id,
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            token=token,
            checkpoint_filename=checkpoint_filename,
            vit_model_name=model.vit_model_name,
        )
        state_dict = cls._read_state_dict(checkpoint_path, map_location)
        model._load_state_dict_flexible(state_dict, strict=strict)
        return model

    @classmethod
    def _resolve_checkpoint_path(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[str],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: Optional[bool],
        local_files_only: bool,
        token: Optional[str],
        checkpoint_filename: Optional[str],
        vit_model_name: str,
    ) -> str:
        """Resolve ``model_id`` as a local directory if one exists, else as a Hub repo id."""
        if os.path.isdir(model_id):
            return cls._resolve_local_checkpoint(model_id, checkpoint_filename, vit_model_name)
        return cls._resolve_remote_checkpoint(
            model_id=model_id,
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            token=token,
            checkpoint_filename=checkpoint_filename,
            vit_model_name=vit_model_name,
        )

    @staticmethod
    def _candidate_filenames(checkpoint_filename: Optional[str], vit_model_name: str) -> List[str]:
        """Return the de-duplicated checkpoint filenames to try, in priority order."""
        if checkpoint_filename:
            return [checkpoint_filename]
        candidates = ["model.safetensors", "pytorch_model.bin"]
        default_name = default_hub_checkpoint_filename(vit_model_name)
        if default_name and default_name not in candidates:
            candidates.append(default_name)
        return candidates

    @staticmethod
    def _resolve_local_checkpoint(
        model_dir: str,
        checkpoint_filename: Optional[str],
        vit_model_name: str,
    ) -> str:
        """Find a checkpoint in ``model_dir``: known names first, then the first ``*.pth``."""
        base = Path(model_dir)
        candidates = HFEigenMoE._candidate_filenames(checkpoint_filename, vit_model_name)
        for filename in candidates:
            path = base / filename
            if path.exists():
                return str(path)
        pth_files = sorted(base.glob("*.pth"))
        if pth_files:
            return str(pth_files[0])
        raise FileNotFoundError(
            f"Could not find a checkpoint in local directory: {model_dir}. "
            f"Tried {candidates} and '*.pth'."
        )

    @staticmethod
    def _resolve_remote_checkpoint(
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[str],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: Optional[bool],
        local_files_only: bool,
        token: Optional[str],
        checkpoint_filename: Optional[str],
        vit_model_name: str,
    ) -> str:
        """Download the first available candidate checkpoint from Hub repo ``model_id``.

        If no known filename exists, the repo listing is scanned for any weight
        file. The scan is skipped when ``local_files_only`` is set, because
        listing repo files requires the network.
        """
        if hf_hub_download is None:
            raise ImportError("huggingface_hub is required to download checkpoints from the Hub.")
        unique_candidates = HFEigenMoE._candidate_filenames(checkpoint_filename, vit_model_name)
        download_kwargs = dict(
            repo_id=model_id,
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            token=token,
            local_files_only=local_files_only,
        )
        for filename in unique_candidates:
            try:
                return hf_hub_download(filename=filename, **download_kwargs)
            except EntryNotFoundError:
                continue
        if HfApi is not None and not local_files_only:
            api = HfApi(token=token)
            repo_files = api.list_repo_files(repo_id=model_id, revision=revision)
            weight_files = [
                name for name in repo_files
                if name.endswith((".pth", ".pt", ".bin", ".safetensors"))
            ]
            if weight_files:
                return hf_hub_download(filename=weight_files[0], **download_kwargs)
        raise FileNotFoundError(
            f"No compatible checkpoint found in Hub repo '{model_id}'. "
            f"Tried {unique_candidates} and a fallback scan for *.pth/*.pt/*.bin/*.safetensors."
        )
|