|
|
""" |
|
|
ViT-Beatrix V5 - Contrarian Tower Collective |
|
|
============================================ |
|
|
|
|
|
Architecture using geofractal router infrastructure with pos/neg tower pairs. |
|
|
|
|
|
Key insights from V4 200-epoch run: |
|
|
- λ converged to 0.217 ≈ 1/5 (structure ~15% of routing) |
|
|
- patch_weight → -0.575 (emergent contrastive readout) |
|
|
- Model naturally learned to subtract common-mode signal |
|
|
|
|
|
V5 Design: |
|
|
- Explicit pos/neg tower pairs (what V4 learned implicitly) |
|
|
- WideRouter for parallel tower execution |
|
|
- Contrastive fusion: pos_output + α * neg_output (α learnable per pair; expected to turn negative)
|
|
- Cantor routing within each tower |
|
|
|
|
|
Geofractal infrastructure: |
|
|
- BaseTower: stages as nn.ModuleList |
|
|
- WideRouter: discover_towers(), wide_forward() |
|
|
- TorchComponent: for attention blocks |
|
|
- FusionComponent pattern for contrastive fusion |
|
|
|
|
|
COLAB SETUP: |
|
|
------------ |
|
|
# Install geofractal first: |
|
|
!pip uninstall -qy geofractal geometricvocab  # notebook shell magic; does not raise if the packages are absent
|
|
!pip install -q git+https://github.com/AbstractEyes/geofractal.git |
|
|
|
|
|
Copyright 2025 AbstractPhil |
|
|
Licensed under the Apache License, Version 2.0 |
|
|
""" |
|
|
|
|
|
import math |
|
|
from typing import Optional, Dict, List, Tuple |
|
|
from dataclasses import dataclass |
|
|
from datetime import datetime |
|
|
import os |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch import Tensor |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
from tqdm.auto import tqdm |
|
|
from huggingface_hub import HfApi, upload_folder |
|
|
|
|
|
|
|
|
from geofractal.router.base_tower import BaseTower |
|
|
from geofractal.router.wide_router import WideRouter |
|
|
from geofractal.router.components.torch_component import TorchComponent |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class BeatrixV5Config:
    """Hyper-parameters for a ViT-Beatrix V5 contrarian tower collective.

    Derived quantities (``num_patches``, ``head_dim``, ``num_towers``) are
    read-only properties, not constructor fields. ``to_dict()`` includes two of
    them for convenience; use ``from_dict()`` to rebuild a config from that
    dict, because it drops keys that are not dataclass fields.
    """

    # Patch embedding / transformer backbone
    image_size: int = 32
    patch_size: int = 4
    in_channels: int = 3
    embed_dim: int = 384
    depth: int = 6  # transformer blocks per tower
    num_heads: int = 6
    mlp_ratio: float = 4.0

    # Each pair is one normal (pos) tower plus one inverted-routing (neg) tower.
    num_tower_pairs: int = 2

    # Cantor routing
    cantor_levels: int = 5
    cantor_tau: float = 0.25
    routing_weight_init: float = 0.22
    learnable_routing_weight: bool = True
    num_wormholes: int = 8
    wormhole_temperature: float = 0.1

    # Contrastive fusion
    contrastive_alpha_init: float = 0.5

    # Regularization / classification head
    dropout: float = 0.1
    drop_path: float = 0.1
    num_classes: int = 100

    @property
    def num_patches(self) -> int:
        """Number of non-overlapping patches per image side, squared."""
        return (self.image_size // self.patch_size) ** 2

    @property
    def head_dim(self) -> int:
        """Per-head channel width."""
        return self.embed_dim // self.num_heads

    @property
    def num_towers(self) -> int:
        """Total towers (two per pos/neg pair)."""
        return self.num_tower_pairs * 2

    def to_dict(self) -> dict:
        """Serialize config to dict for checkpoint saving.

        Contains every dataclass field (in declaration order) followed by the
        derived ``num_patches`` and ``num_towers`` values.
        """
        import dataclasses

        data = dataclasses.asdict(self)
        data['num_patches'] = self.num_patches
        data['num_towers'] = self.num_towers
        return data

    @classmethod
    def from_dict(cls, data: dict) -> "BeatrixV5Config":
        """Rebuild a config from ``to_dict()`` output (or any mapping).

        Keys that are not dataclass fields (e.g. the derived ``num_patches`` /
        ``num_towers``) are ignored instead of raising ``TypeError``.
        """
        import dataclasses

        field_names = {f.name for f in dataclasses.fields(cls)}
        return cls(**{k: v for k, v in data.items() if k in field_names})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BeatrixStaircase(nn.Module):
    """Cantor-based branch path encoding.

    For each input value ``x`` (clamped into the open interval (0, 1)) and each
    level ``k = 1..levels``, computes ``(x * 3**k) mod 3`` and returns the index
    (0, 1 or 2) of the nearest of the centers 0.5 / 1.5 / 2.5. That is
    approximately the k-th base-3 digit of ``x``. The output has the input's
    shape with a trailing ``levels`` axis of long indices.

    The ``_alpha`` and ``level_weights`` buffers are registered (so they appear
    in ``state_dict``) but are not read by ``forward``.
    """

    def __init__(self, levels: int = 5, tau: float = 0.25, alpha: float = 0.5):
        super().__init__()
        self.levels = levels
        self.tau = tau

        self.register_buffer('centers', torch.tensor([0.5, 1.5, 2.5], dtype=torch.float32))
        self.register_buffer('_alpha', torch.tensor(alpha))

        exponents = torch.arange(1, levels + 1, dtype=torch.float32)
        self.register_buffer('scales', 3.0 ** exponents)
        self.register_buffer('level_weights', 0.5 ** exponents)

    def forward(self, x):
        in_shape = x.shape
        flat = x.clamp(1e-6, 1.0 - 1e-6).reshape(-1)

        # Per level: position within the current ternary cell, in [0, 3).
        digits = (flat[:, None] * self.scales) % 3
        sq_dist = (digits[..., None] - self.centers).pow(2)
        logits = -sq_dist / (self.tau + 1e-8)

        return logits.argmax(dim=-1).reshape(*in_shape, self.levels)
|
|
|
|
|
|
|
|
class HierarchicalRoutingBias(nn.Module):
    """Cantor-based routing bias for attention.

    Encodes ``num_positions`` evenly spaced points in [0, 1] into ``levels``
    branch indices with ``BeatrixStaircase``. The pairwise alignment between
    positions i and j is ``sum_l 0.5**l * [path_i[l] == path_j[l]]`` for
    l = 1..levels, with the diagonal forced to 0. ``forward`` adds
    ``routing_weight * alignment`` to the given content scores.

    ``routing_weight`` is an ``nn.Parameter`` when ``learnable_weight`` is True,
    otherwise a fixed buffer.
    """

    def __init__(
        self,
        num_positions: int,
        levels: int = 5,
        tau: float = 0.25,
        learnable_weight: bool = True,
        init_weight: float = 0.22,
    ):
        super().__init__()
        self.num_positions = num_positions
        self.levels = levels

        self.staircase = BeatrixStaircase(levels=levels, tau=tau)

        with torch.no_grad():
            paths = self.staircase(torch.linspace(0, 1, num_positions))
        self.register_buffer('branch_paths', paths)
        self.register_buffer('alignment_matrix', self._compute_alignment_matrix(paths))

        weight = torch.tensor(init_weight)
        if not learnable_weight:
            self.register_buffer('routing_weight', weight)
        else:
            self.routing_weight = nn.Parameter(weight)

    def _compute_alignment_matrix(self, paths):
        """Return the (P, P) level-weighted agreement matrix for (P, L) branch paths."""
        _, num_levels = paths.shape
        weights = 0.5 ** torch.arange(1, num_levels + 1, device=paths.device)
        same_branch = paths[None, :, :] == paths[:, None, :]
        scores = (same_branch.float() * weights).sum(dim=-1)
        # A position is never credited for aligning with itself.
        scores.fill_diagonal_(0)
        return scores

    def forward(self, content_scores):
        bias = self.routing_weight * self.alignment_matrix
        return content_scores + bias

    def get_structure_only_scores(self, batch_size: int, device: torch.device):
        """Alignment matrix broadcast to (batch_size, P, P); ``device`` is unused."""
        return self.alignment_matrix[None].expand(batch_size, -1, -1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DropPath(nn.Module):
    """Stochastic depth: drop an entire residual branch per sample.

    In training mode each sample (index along dim 0) is zeroed with probability
    ``drop_prob``. Survivors are scaled by ``1 / (1 - drop_prob)`` so the
    expected value is preserved. In eval mode, or when ``drop_prob == 0``, the
    module is the identity.
    """

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        if keep_prob <= 0.0:
            # Every sample is dropped. Rescaling by 1/keep_prob would give
            # inf * 0 = NaN, so return zeros directly.
            return torch.zeros_like(x)
        # One Bernoulli(keep_prob) draw per sample, broadcast over other dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        return x.div(keep_prob) * random_tensor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class WormholeAttention(nn.Module):
    """Attention with Cantor-based routing.

    Assumes the token sequence is ``[CLS, patch_1, ..., patch_P]`` with
    ``P == num_patches``.

    * The CLS token (position 0) attends densely over the whole sequence.
    * Each patch token attends only to ``num_wormholes`` other patches. The
      partners are picked by top-k over a routing score:

      - Content routing (non-final layers, or any layer without a
        ``routing_bias``): cosine similarity of learned route projections,
        plus the Cantor routing bias when one is given.
      - Structure-only routing (the final layer when ``routing_bias`` is
        given): the Cantor alignment matrix alone.

      The log of the softmax-normalized routing weights (floored at -10) is
      added to the patch attention logits as a prior.

    ``inverted=True`` negates the routing scores, so the layer selects the
    least-aligned partners (the contrarian/negative towers).
    """

    # Lower bound, in log space, of the routing prior added to patch logits.
    _LOG_PRIOR_FLOOR = -10.0

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_patches: int,
        num_wormholes: int = 8,
        temperature: float = 0.1,
        routing_bias: Optional["HierarchicalRoutingBias"] = None,
        dropout: float = 0.0,
        layer_idx: int = 0,
        num_layers: int = 6,
        inverted: bool = False,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.num_patches = num_patches
        # A patch never routes to itself, so at most P - 1 partners exist.
        self.num_wormholes = min(num_wormholes, num_patches - 1)
        self.temperature = temperature
        self.routing_bias = routing_bias
        self.layer_idx = layer_idx
        self.is_final_layer = (layer_idx == num_layers - 1)
        self.inverted = inverted
        # Structure-only routing needs a structure to route on. Without a
        # routing_bias the final layer falls back to content routing instead
        # of failing on a missing bias / route projection.
        self.structure_only_routing = self.is_final_layer and routing_bias is not None

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout)
        self.attn_drop = nn.Dropout(dropout)

        if not self.structure_only_routing:
            self.route_q = nn.Linear(dim, dim)
            self.route_k = nn.Linear(dim, dim)

    def _compute_routes(self, x):
        """Select the top-K routing partners of every patch token.

        Args:
            x: token sequence with CLS at position 0; assumed (B, 1 + P, dim).

        Returns:
            routes: (B, P, K) indices into the patch axis.
            weights: (B, P, K) softmax over the temperature-scaled top-K scores.
        """
        B, _, _ = x.shape
        P = self.num_patches
        K = self.num_wormholes

        if self.structure_only_routing:
            scores = self.routing_bias.get_structure_only_scores(B, x.device)
        else:
            x_patches = x[:, 1:, :]
            q = F.normalize(self.route_q(x_patches), dim=-1)
            k = F.normalize(self.route_k(x_patches), dim=-1)
            content_scores = torch.bmm(q, k.transpose(1, 2))
            if self.routing_bias is not None:
                scores = self.routing_bias(content_scores)
            else:
                scores = content_scores

        if self.inverted:
            scores = -scores

        # Exclude self-routing (applied after inversion so it always loses).
        mask = torch.eye(P, device=x.device, dtype=torch.bool)
        scores = scores.masked_fill(mask.unsqueeze(0), -1e9)

        topk_scores, routes = torch.topk(scores / self.temperature, K, dim=-1)
        weights = F.softmax(topk_scores, dim=-1)
        return routes, weights

    def _gather_wormhole(self, x, routes):
        """Gather routed rows: (B, H, P, D) with (B, P, K) routes -> (B, H, P, K, D)."""
        B, H, P, D = x.shape
        K = routes.shape[-1]

        x_flat = x.reshape(B * H, P, D)
        # Same routes for every head.
        routes_exp = routes.unsqueeze(1).expand(-1, H, -1, -1).reshape(B * H, P * K)
        routes_exp = routes_exp.unsqueeze(-1).expand(-1, -1, D)

        gathered = torch.gather(x_flat, 1, routes_exp)
        return gathered.view(B, H, P, K, D)

    def forward(self, x):
        B, S, D = x.shape
        H = self.num_heads
        head_dim = self.head_dim

        routes, route_weights = self._compute_routes(x)

        qkv = self.qkv(x).reshape(B, S, 3, H, head_dim).permute(2, 0, 3, 1, 4)
        Q, K_full, V = qkv.unbind(0)  # each (B, H, S, head_dim)

        # CLS token: dense attention over the full sequence.
        Q_cls = Q[:, :, :1, :]
        attn_cls = F.softmax(
            torch.einsum('bhqd,bhkd->bhqk', Q_cls, K_full) * self.scale,
            dim=-1
        )
        attn_cls = self.attn_drop(attn_cls)
        out_cls = torch.einsum('bhqk,bhkd->bhqd', attn_cls, V)

        # Patch tokens: sparse attention over routed partners only.
        Q_patches = Q[:, :, 1:, :]
        K_patches = K_full[:, :, 1:, :]
        V_patches = V[:, :, 1:, :]

        K_gathered = self._gather_wormhole(K_patches, routes)
        V_gathered = self._gather_wormhole(V_patches, routes)

        scores_patches = torch.einsum('bhpd,bhpkd->bhpk', Q_patches, K_gathered) * self.scale
        # Routing prior: log(weight) floored at _LOG_PRIOR_FLOOR. Flooring the
        # weight *before* the log yields the same values as log().clamp(), but
        # keeps the gradient finite when a weight underflows to exactly 0
        # (log().clamp() backpropagates 0 * (1/0) = NaN there).
        log_prior = route_weights.clamp_min(math.exp(self._LOG_PRIOR_FLOOR)).log()
        scores_patches = scores_patches + log_prior.unsqueeze(1)

        attn_patches = F.softmax(scores_patches, dim=-1)
        attn_patches = self.attn_drop(attn_patches)

        out_patches = torch.einsum('bhpk,bhpkd->bhpd', attn_patches, V_gathered)

        out = torch.cat([out_cls, out_patches], dim=2)
        out = out.transpose(1, 2).reshape(B, S, D)

        return self.proj_drop(self.proj(out))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BeatrixBlock(TorchComponent):
    """Pre-norm transformer block (wormhole attention + MLP) as a TorchComponent.

    Being a TorchComponent lets BeatrixTower register it as a stage.
    ``temperature`` is forwarded to ``WormholeAttention`` and scales the
    routing scores before top-k/softmax. It is appended last so existing
    positional calls are unaffected.
    """

    def __init__(
        self,
        name: str,
        dim: int,
        num_heads: int,
        num_patches: int,
        num_wormholes: int = 8,
        mlp_ratio: float = 4.0,
        routing_bias: Optional[HierarchicalRoutingBias] = None,
        dropout: float = 0.0,
        drop_path: float = 0.0,
        layer_idx: int = 0,
        num_layers: int = 6,
        inverted: bool = False,
        temperature: float = 0.1,
    ):
        super().__init__(name)

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WormholeAttention(
            dim=dim, num_heads=num_heads, num_patches=num_patches,
            num_wormholes=num_wormholes, temperature=temperature,
            routing_bias=routing_bias,
            dropout=dropout, layer_idx=layer_idx, num_layers=num_layers,
            inverted=inverted,
        )

        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, dim),
            nn.Dropout(dropout),
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class BeatrixTower(BaseTower):
    """
    Single tower using geofractal BaseTower infrastructure.

    Uses:
    - self.append() to add stages
    - self.attach() for named components
    - self.stages for iteration
    - self['name'] for component access

    Can be positive (normal) or negative (contrarian/inverted routing).

    All ``config.depth`` blocks share one HierarchicalRoutingBias (and hence one
    routing weight λ). Stochastic depth rises linearly from 0 to
    ``config.drop_path`` across the blocks, and ``config.wormhole_temperature``
    sets the routing temperature of every block.
    """

    def __init__(
        self,
        name: str,
        config: BeatrixV5Config,
        inverted: bool = False,
    ):
        super().__init__(name, strict=False)
        self.inverted = inverted
        self._config = config

        # One routing bias shared by every block in this tower.
        self.attach('routing_bias', HierarchicalRoutingBias(
            num_positions=config.num_patches,
            levels=config.cantor_levels,
            tau=config.cantor_tau,
            learnable_weight=config.learnable_routing_weight,
            init_weight=config.routing_weight_init,
        ))

        # Linearly increasing drop-path rate per block.
        dpr = torch.linspace(0, config.drop_path, config.depth).tolist()
        for i in range(config.depth):
            self.append(BeatrixBlock(
                name=f'{name}_block_{i}',
                dim=config.embed_dim,
                num_heads=config.num_heads,
                num_patches=config.num_patches,
                num_wormholes=config.num_wormholes,
                mlp_ratio=config.mlp_ratio,
                routing_bias=self['routing_bias'],
                dropout=config.dropout,
                drop_path=dpr[i],
                layer_idx=i,
                num_layers=config.depth,
                inverted=inverted,
                # Previously never wired through, so the config value was ignored.
                temperature=config.wormhole_temperature,
            ))

        self.attach('norm', nn.LayerNorm(config.embed_dim))

    def forward(self, x: Tensor) -> Tensor:
        """Process input and return opinion (CLS token)."""
        for stage in self.stages:
            x = stage(x)
        x = self['norm'](x)
        return x[:, 0]

    def get_routing_weight(self) -> float:
        """Current value of this tower's shared routing weight λ."""
        return self['routing_bias'].routing_weight.item()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ContrastiveFusion(TorchComponent):
    """
    Fuses pos/neg tower pairs via learned contrastive combination.

    For each pair i: fused_i = pos_i + alphas[i] * neg_i
    alphas start at ``alpha_init`` and are learnable. A negative learned value
    subtracts the inverted tower's opinion (common-mode rejection).

    This makes explicit what V4 learned implicitly with patch_weight.

    With more than one pair, the fused pairs are concatenated and projected
    back to ``dim`` by a linear layer. With exactly one pair, the fused pair
    is returned directly.
    """

    def __init__(
        self,
        name: str,
        num_pairs: int,
        dim: int,
        alpha_init: float = 0.5,
    ):
        super().__init__(name)
        self.num_pairs = num_pairs

        # One learnable mixing coefficient per pos/neg pair.
        self.alphas = nn.Parameter(torch.full((num_pairs,), alpha_init))

        # Mixes the per-pair results back to `dim` when there are several pairs.
        if num_pairs > 1:
            self.pair_fusion = nn.Linear(dim * num_pairs, dim)
        else:
            self.pair_fusion = None

    def forward(self, pos_opinions: List[Tensor], neg_opinions: List[Tensor]) -> Tensor:
        """
        Args:
            pos_opinions: List of [B, D] tensors from positive towers
            neg_opinions: List of [B, D] tensors from negative towers

        Returns:
            Fused output [B, D]

        Raises:
            ValueError: if either list does not hold exactly ``num_pairs`` tensors.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not (len(pos_opinions) == len(neg_opinions) == self.num_pairs):
            raise ValueError(
                f"expected {self.num_pairs} positive and {self.num_pairs} negative "
                f"opinions, got {len(pos_opinions)} and {len(neg_opinions)}"
            )

        fused_pairs = [
            pos + self.alphas[i] * neg
            for i, (pos, neg) in enumerate(zip(pos_opinions, neg_opinions))
        ]

        if self.pair_fusion is None:
            return fused_pairs[0]
        return self.pair_fusion(torch.cat(fused_pairs, dim=-1))

    def get_alphas(self) -> List[float]:
        """Current per-pair alpha values as plain floats."""
        return self.alphas.tolist()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EmbeddingParams(TorchComponent):
    """Owns the learnable CLS token and positional embedding.

    ``forward`` prepends the CLS token to a patch sequence of
    (B, num_patches, embed_dim) and adds the (1, 1 + num_patches, embed_dim)
    positional embedding. Both parameters start from a truncated normal with
    std 0.02.
    """

    def __init__(self, name: str, num_patches: int, embed_dim: int):
        super().__init__(name)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + num_patches, embed_dim))
        # Initialized in declaration order: CLS token first, then positions.
        for param in (self.cls_token, self.pos_embed):
            nn.init.trunc_normal_(param, std=0.02)

    def forward(self, x: Tensor) -> Tensor:
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        return torch.cat((cls, x), dim=1) + self.pos_embed
|
|
|
|
|
|
|
|
class BeatrixCollective(WideRouter):
    """
    WideRouter collective managing pos/neg tower pairs.

    Follows geofractal WideRouter pattern:
    1. super().__init__(name, auto_discover=True)
    2. attach towers with self.attach(name, tower)
    3. call self.discover_towers() AFTER attaching
    4. wide_forward(x) returns Dict[tower_name, output]

    "Individual towers don't need to be accurate.
    They need to see differently.
    The routing fabric triangulates truth from divergent viewpoints."

    Forward pipeline: conv patch embedding -> CLS token + positional embedding
    -> dropout -> every tower via wide_forward -> per-pair ContrastiveFusion ->
    linear classification head (returns logits over ``config.num_classes``).
    """

    def __init__(self, config: BeatrixV5Config):
        # Router base. The towers attached below are registered by discover_towers().
        super().__init__(name='beatrix_collective', auto_discover=True)
        self.config = config

        # Non-overlapping patch_size x patch_size convolution = patch embedding.
        self.attach('patch_embed', nn.Conv2d(
            config.in_channels, config.embed_dim,
            kernel_size=config.patch_size, stride=config.patch_size
        ))

        # Prepends the CLS token and adds positional embeddings.
        self.attach('embeddings', EmbeddingParams(
            'embeddings', config.num_patches, config.embed_dim
        ))
        self.attach('pos_drop', nn.Dropout(config.dropout))

        # One normal (pos_i) and one inverted-routing (neg_i) tower per pair.
        # forward() relies on these exact names.
        for i in range(config.num_tower_pairs):
            pos_name = f'pos_{i}'
            neg_name = f'neg_{i}'
            self.attach(pos_name, BeatrixTower(pos_name, config, inverted=False))
            self.attach(neg_name, BeatrixTower(neg_name, config, inverted=True))

        # Called after the towers are attached and before fusion/head are
        # attached. NOTE(review): presumably this ordering keeps fusion/head out
        # of the discovered tower set; confirm against WideRouter.discover_towers.
        self.discover_towers()

        # Combines each pos/neg pair's CLS opinions into a single vector.
        self.attach('fusion', ContrastiveFusion(
            name='contrastive_fusion',
            num_pairs=config.num_tower_pairs,
            dim=config.embed_dim,
            alpha_init=config.contrastive_alpha_init,
        ))

        # Classification head on the fused representation.
        self.attach('head', nn.Linear(config.embed_dim, config.num_classes))

        self._init_weights()

    def _init_weights(self) -> None:
        """Re-initialize every nn.Linear (trunc-normal std 0.02, zero bias) and
        nn.LayerNorm (unit weight, zero bias) reachable from this module.

        The Conv2d patch embedding, EmbeddingParams parameters, routing weights
        and fusion alphas are not touched here.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def _prepare_input(self, images: Tensor) -> Tensor:
        """Shared input preparation: patch embed + pos embed."""
        # Assumes images are (B, C, H, W); conv output is (B, D, H/p, W/p).
        x = self['patch_embed'](images)
        # Flatten the spatial grid into a token axis: (B, num_patches, D).
        x = x.flatten(2).transpose(1, 2)
        # Prepend CLS and add positional embeddings: (B, 1 + num_patches, D).
        x = self['embeddings'](x)
        x = self['pos_drop'](x)
        return x

    def forward(self, images: Tensor) -> Tensor:
        """Return class logits for a batch of images."""
        x = self._prepare_input(images)

        # Every tower sees the same embedded sequence; results are keyed by
        # tower name and each is that tower's normalized CLS output.
        opinions = self.wide_forward(x)

        # Re-pair the outputs so pos_i and neg_i line up at index i.
        pos_opinions = []
        neg_opinions = []
        for i in range(self.config.num_tower_pairs):
            pos_opinions.append(opinions[f'pos_{i}'])
            neg_opinions.append(opinions[f'neg_{i}'])

        fused = self['fusion'](pos_opinions, neg_opinions)

        return self['head'](fused)

    def get_diagnostics(self) -> Dict:
        """Get diagnostic info about tower states.

        Returns a dict with 'fusion_alphas' (list of per-pair alphas) and
        'tower_lambdas' (tower name -> routing weight λ).
        """
        diag = {
            'fusion_alphas': self['fusion'].get_alphas(),
            'tower_lambdas': {},
        }
        for name in self.tower_names:
            diag['tower_lambdas'][name] = self[name].get_routing_weight()
        return diag
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_beatrix_v5_small(num_classes=100, **kwargs) -> BeatrixCollective:
    """Small model: 2 tower pairs, 384 dim, 6 depth (6 heads, 8 wormholes).

    Any ``BeatrixV5Config`` field may be passed in ``kwargs``. It overrides the
    preset value instead of raising a duplicate-keyword ``TypeError``.
    """
    preset = dict(
        embed_dim=384,
        depth=6,
        num_heads=6,
        num_tower_pairs=2,
        num_wormholes=8,
    )
    preset.update(kwargs)
    config = BeatrixV5Config(num_classes=num_classes, **preset)
    return BeatrixCollective(config)
|
|
|
|
|
|
|
|
def create_beatrix_v5_base(num_classes=100, **kwargs) -> BeatrixCollective:
    """Base model: 2 tower pairs, 512 dim, 8 depth (8 heads, 12 wormholes).

    Any ``BeatrixV5Config`` field may be passed in ``kwargs``. It overrides the
    preset value instead of raising a duplicate-keyword ``TypeError``.
    """
    preset = dict(
        embed_dim=512,
        depth=8,
        num_heads=8,
        num_tower_pairs=2,
        num_wormholes=12,
    )
    preset.update(kwargs)
    config = BeatrixV5Config(num_classes=num_classes, **preset)
    return BeatrixCollective(config)
|
|
|
|
|
|
|
|
def create_beatrix_v5_wide(num_classes=100, **kwargs) -> BeatrixCollective:
    """Wide model: 8 tower pairs, 512 dim, 2 depth (8 heads, 32 wormholes, patch 4).

    Any ``BeatrixV5Config`` field may be passed in ``kwargs``. It overrides the
    preset value instead of raising a duplicate-keyword ``TypeError``.
    """
    preset = dict(
        embed_dim=512,
        depth=2,
        num_heads=8,
        num_tower_pairs=8,
        num_wormholes=32,
        patch_size=4,
    )
    preset.update(kwargs)
    config = BeatrixV5Config(num_classes=num_classes, **preset)
    return BeatrixCollective(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CosineWarmupScheduler:
    """Per-epoch learning-rate schedule: linear warmup, then cosine decay.

    ``step(epoch)`` computes the LR for the 0-based ``epoch``, writes it into
    every optimizer param group, and returns it. The warmup ramps linearly and
    reaches ``base_lr`` at epoch ``warmup_epochs - 1``. Cosine decay then runs
    from ``base_lr`` at ``warmup_epochs`` down to ``min_lr`` at ``total_epochs``
    and stays at ``min_lr`` afterwards.
    """

    def __init__(self, optimizer, warmup_epochs, total_epochs, min_lr=1e-6, base_lr=1e-3):
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.min_lr = min_lr
        self.base_lr = base_lr

    def step(self, epoch):
        if epoch < self.warmup_epochs:
            lr = self.base_lr * (epoch + 1) / self.warmup_epochs
        else:
            decay_epochs = self.total_epochs - self.warmup_epochs
            # No decay phase (total <= warmup): treat the schedule as finished
            # instead of dividing by zero.
            progress = (epoch - self.warmup_epochs) / decay_epochs if decay_epochs > 0 else 1.0
            # Clamp so epochs beyond total_epochs stay at min_lr rather than
            # the cosine climbing back toward base_lr.
            progress = min(progress, 1.0)
            lr = self.min_lr + 0.5 * (self.base_lr - self.min_lr) * (1 + math.cos(math.pi * progress))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr
|
|
|
|
|
|
|
|
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one training epoch over ``loader``.

    Each step does forward, backward, global grad-norm clipping at 1.0, and an
    optimizer step.

    Returns:
        (mean per-batch loss, accuracy in percent). Both are 0.0 when
        ``loader`` yields no batches.
    """
    model.train()
    total_loss, correct, total, num_batches = 0.0, 0, 0, 0

    pbar = tqdm(loader, desc='Train', leave=False)
    for inputs, targets in pbar:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Read the loss once: every .item() is a device->host sync.
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        pbar.set_postfix({
            'loss': f'{batch_loss:.3f}',
            'acc': f'{100.*correct/total:.1f}%'
        })

    # Count batches instead of len(loader): same for a DataLoader, but also
    # works for iterables without __len__ and avoids 0/0 on empty input.
    mean_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = 100. * correct / total if total else 0.0
    return mean_loss, accuracy
|
|
|
|
|
|
|
|
@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` in eval mode without gradients.

    Returns:
        (mean per-batch loss, accuracy in percent). Both are 0.0 when
        ``loader`` yields no batches.
    """
    model.eval()
    total_loss, correct, total, num_batches = 0.0, 0, 0, 0

    pbar = tqdm(loader, desc='Eval', leave=False)
    for inputs, targets in pbar:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        total_loss += loss.item()
        num_batches += 1
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        pbar.set_postfix({'acc': f'{100.*correct/total:.1f}%'})

    # Guard against 0/0 for an empty loader (see train_epoch).
    mean_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = 100. * correct / total if total else 0.0
    return mean_loss, accuracy
|
|
|
|
|
|
|
|
def print_diagnostics(epoch: int, model: BeatrixCollective):
    """Print a boxed report of fusion alphas and per-tower routing weights (λ).

    Towers whose name starts with 'pos' are labelled POS and all others NEG.
    """
    info = model.get_diagnostics()

    print(f"\n ┌─ DIAGNOSTICS (Epoch {epoch}) ─────────────────────────────────────")
    print(f" │ Fusion alphas (expect negative): {info['fusion_alphas']}")
    print(" │ Tower routing weights (λ):")
    for tower, weight in info['tower_lambdas'].items():
        label = "NEG"
        if tower.startswith('pos'):
            label = "POS"
        print(f" │ {tower} ({label}): {weight:.4f}")
    print(" └───────────────────────────────────────────────────────────────")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Train a Beatrix V5 collective on CIFAR-100 and publish checkpoints.

    Side effects:
    - downloads CIFAR-100 into ./data;
    - writes TensorBoard logs and checkpoints under ``checkpoints/<run_name>``;
    - saves the best-so-far model locally whenever test accuracy improves;
    - every ``CHECKPOINT_INTERVAL`` epochs, saves an interval checkpoint and a
      README and uploads the run folder to the Hugging Face repo ``HF_REPO``.
      A final upload is done at the end only if the last epoch was not already
      an interval epoch.

    Returns:
        (model_raw, history): the uncompiled model and the per-epoch history.
    """
    import torchvision
    import torchvision.transforms as transforms

    # Run configuration
    MODEL_TYPE = 'wide'
    EPOCHS = 100
    BASE_LR = 1e-3
    WARMUP_EPOCHS = 10
    BATCH_SIZE = 128

    print("=" * 70)
    print("ViT-Beatrix V5 - CONTRARIAN TOWER COLLECTIVE")
    print("=" * 70)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nDevice: {device}")
    print(f"Model type: {MODEL_TYPE}")

    # Model
    if MODEL_TYPE == 'small':
        model = create_beatrix_v5_small()
    elif MODEL_TYPE == 'base':
        model = create_beatrix_v5_base()
    elif MODEL_TYPE == 'wide':
        model = create_beatrix_v5_wide()
    else:
        raise ValueError(f"Unknown model type: {MODEL_TYPE}")

    model = model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")
    print(f"Towers: {model.tower_names}")

    print("\nPreparing and compiling model...")
    torch.set_float32_matmul_precision('high')
    # Keep the uncompiled module for state_dict(), config and diagnostics.
    model_raw = model
    model = model.prepare_and_compile()
    print("✓ Model compiled")

    # Data
    cifar100_mean = (0.5071, 0.4867, 0.4408)
    cifar100_std = (0.2675, 0.2565, 0.2761)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),
    ])

    print("\nLoading CIFAR-100...")
    trainset = torchvision.datasets.CIFAR100(
        root='./data', train=True, download=True, transform=transform_train
    )
    testset = torchvision.datasets.CIFAR100(
        root='./data', train=False, download=True, transform=transform_test
    )
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True
    )
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True
    )

    # Optimization
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=BASE_LR, weight_decay=0.05, betas=(0.9, 0.999)
    )
    scheduler = CosineWarmupScheduler(
        optimizer, warmup_epochs=WARMUP_EPOCHS, total_epochs=EPOCHS,
        min_lr=1e-6, base_lr=BASE_LR
    )

    # Checkpointing / logging
    HF_REPO = "AbstractPhil/vit-beatrix-contrarian"
    CHECKPOINT_INTERVAL = 10

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_name = f"v5_{MODEL_TYPE}_{timestamp}"

    checkpoint_dir = f"checkpoints/{run_name}"
    os.makedirs(checkpoint_dir, exist_ok=True)

    tb_dir = f"{checkpoint_dir}/tensorboard"
    writer = SummaryWriter(tb_dir)

    writer.add_text("config/model_type", MODEL_TYPE)
    writer.add_text("config/num_towers", str(len(model_raw.tower_names)))
    writer.add_text("config/total_params", f"{total_params:,}")

    # Ensure the HF repo exists. Failures are reported but never abort training.
    hf_api = HfApi()
    try:
        hf_api.repo_info(repo_id=HF_REPO, repo_type="model")
        print(f"✓ HF repo exists: {HF_REPO}")
    except Exception:
        print(f"Creating HF repo: {HF_REPO}")
        try:
            hf_api.create_repo(repo_id=HF_REPO, repo_type="model", exist_ok=True)
            print(f"✓ Created HF repo: {HF_REPO}")
        except Exception as e:
            print(f"⚠️ Could not create repo: {e}")

    def _checkpoint_payload(epoch, model_raw, history, diag, test_acc):
        """Contents of a .pth checkpoint (shared by best and interval saves)."""
        return {
            'epoch': epoch,
            'model_state_dict': model_raw.state_dict(),
            'config': model_raw.config.to_dict(),
            'test_acc': test_acc,
            'history': history,
            'diagnostics': diag,
            'run_name': run_name,
            'timestamp': timestamp,
        }

    def save_best_locally(epoch, model_raw, history, diag, test_acc):
        """Save best checkpoint locally (no upload)."""
        ckpt_path = f"{checkpoint_dir}/{run_name}_best.pth"
        torch.save(_checkpoint_payload(epoch, model_raw, history, diag, test_acc), ckpt_path)
        print(f" 💾 Saved best locally: {run_name}_best.pth")

    def save_interval_and_upload(epoch, model_raw, history, diag, test_acc):
        """Save interval checkpoint and upload everything to HuggingFace."""
        ckpt_name = f"{run_name}_e{epoch+1}.pth"
        ckpt_path = f"{checkpoint_dir}/{ckpt_name}"
        torch.save(_checkpoint_payload(epoch, model_raw, history, diag, test_acc), ckpt_path)
        print(f" 💾 Saved interval: {ckpt_name}")

        # The saved config also contains derived keys (num_patches, num_towers)
        # that are not BeatrixV5Config fields, so the usage snippet filters them.
        readme_content = f"""# ViT-Beatrix V5 Contrarian Tower Collective

## Run: {run_name}

### Model Configuration
- **Type**: {MODEL_TYPE}
- **Total Parameters**: {total_params:,}
- **Towers**: {len(model_raw.tower_names)} ({model_raw.config.num_tower_pairs} pos/neg pairs)
- **Embed Dim**: {model_raw.config.embed_dim}
- **Depth**: {model_raw.config.depth} layers per tower

### Training Progress (Epoch {epoch+1})
- **Test Accuracy**: {test_acc:.2f}%
- **Best Accuracy**: {best_acc:.2f}%

### Files
- `{run_name}_best.pth` - Best checkpoint
- `{run_name}_e*.pth` - Interval checkpoints
- `tensorboard/` - Training metrics

### Usage
```python
import dataclasses
import torch
from vit_beatrix_v5_contrarian import BeatrixCollective, BeatrixV5Config

ckpt = torch.load("{run_name}_best.pth", map_location="cpu")
# ckpt['config'] also stores derived values (num_patches, num_towers);
# keep only the constructor fields.
field_names = {{f.name for f in dataclasses.fields(BeatrixV5Config)}}
config = BeatrixV5Config(**{{k: v for k, v in ckpt['config'].items() if k in field_names}})
model = BeatrixCollective(config)
model.load_state_dict(ckpt['model_state_dict'])
```
"""
        with open(f"{checkpoint_dir}/README.md", 'w') as f:
            f.write(readme_content)

        # Upload failures are reported but must not stop training.
        try:
            upload_folder(
                folder_path=checkpoint_dir,
                repo_id=HF_REPO,
                path_in_repo=run_name,
                repo_type="model",
            )
            print(f" ☁️ Uploaded to {HF_REPO}/{run_name}")
        except Exception as e:
            print(f" ⚠️ Upload failed: {e}")

    history = {
        'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': [],
        'fusion_alphas': [], 'tower_lambdas': [],
    }

    print("\n" + "=" * 70)
    print(f"Starting Training ({EPOCHS} epochs)")
    print(f"Run: {run_name}")
    print(f"Checkpoints: {checkpoint_dir}")
    print(f"HuggingFace: {HF_REPO}")
    print("=" * 70)

    # Read by save_interval_and_upload (closure) for the README.
    best_acc = 0

    try:
        epoch_pbar = tqdm(range(EPOCHS), desc='Training')
        for epoch in epoch_pbar:
            lr = scheduler.step(epoch)

            train_loss, train_acc = train_epoch(model, trainloader, criterion, optimizer, device)
            test_loss, test_acc = evaluate(model, testloader, criterion, device)

            diag = model_raw.get_diagnostics()
            gap = train_acc - test_acc

            epoch_pbar.set_postfix({
                'test': f'{test_acc:.1f}%',
                'gap': f'{gap:.1f}%',
                'α': f'{diag["fusion_alphas"][0]:.2f}'
            })

            writer.add_scalar('Loss/train', train_loss, epoch)
            writer.add_scalar('Loss/test', test_loss, epoch)
            writer.add_scalar('Accuracy/train', train_acc, epoch)
            writer.add_scalar('Accuracy/test', test_acc, epoch)
            writer.add_scalar('Accuracy/gap', gap, epoch)
            writer.add_scalar('LR', lr, epoch)
            for i, alpha in enumerate(diag['fusion_alphas']):
                writer.add_scalar(f'Fusion/alpha_{i}', alpha, epoch)
            for name, lam in diag['tower_lambdas'].items():
                writer.add_scalar(f'Lambda/{name}', lam, epoch)

            print(f"\nEpoch {epoch+1}/{EPOCHS} | LR: {lr:.6f}")
            print(f" Train: {train_acc:.2f}% (loss={train_loss:.4f})")
            print(f" Test: {test_acc:.2f}% (loss={test_loss:.4f}) | Gap: {gap:.2f}%")
            print(f" Fusion α: {diag['fusion_alphas'][:4]}{'...' if len(diag['fusion_alphas']) > 4 else ''}")

            if (epoch + 1) % 10 == 0 or test_acc > best_acc:
                print_diagnostics(epoch + 1, model_raw)

            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['test_loss'].append(test_loss)
            history['test_acc'].append(test_acc)
            history['fusion_alphas'].append(diag['fusion_alphas'])
            history['tower_lambdas'].append(diag['tower_lambdas'])

            if test_acc > best_acc:
                best_acc = test_acc
                print(f" ★ New best: {best_acc:.2f}%")
                save_best_locally(epoch, model_raw, history, diag, test_acc)

            if (epoch + 1) % CHECKPOINT_INTERVAL == 0:
                save_interval_and_upload(epoch, model_raw, history, diag, test_acc)

        # When EPOCHS is a multiple of CHECKPOINT_INTERVAL, the last iteration
        # already saved and uploaded an identical checkpoint. With EPOCHS == 0
        # there is nothing to save (and `diag`/`test_acc` were never bound).
        if EPOCHS > 0 and EPOCHS % CHECKPOINT_INTERVAL != 0:
            save_interval_and_upload(EPOCHS - 1, model_raw, history, diag, test_acc)
    finally:
        # Flush TensorBoard events even if training is interrupted.
        writer.close()

    print("\n" + "=" * 70)
    print("Training Complete!")
    print(f"Best accuracy: {best_acc:.2f}%")
    print(f"Checkpoints: {checkpoint_dir}")
    print("=" * 70)

    print_diagnostics(EPOCHS, model_raw)

    return model_raw, history
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: run the CIFAR-100 training job.
    main()