SD/E2E-SD VAE to DINOv2 Bridge
Get DINOv2 features from SD VAE latents
| Item | Value |
|---|---|
| Input | SD/E2E-SD VAE latent |
| Input shape | [B, 4, 32, 32] |
| Output | DINOv2 patch-token features |
| Output shape | [B, 64, 768] |
| Patch grid | 8 × 8 |
| CLS token | Not included |
| DINO target family | DINOv2-base-style, 768-dim |
| Bridge body | Adapter + Transformer bridge |
| Current training dataset | kingsidharth/zangei-dit-stage-1-250k |
| Training rows | ~220k |
| Main checkpoint | checkpoints/best.pt |
| Latest checkpoint | checkpoints/latest.pt |
Architecture
Our design cleanly separates the modality-specific layers from the spatial processing body:
- Latent Adapter: A lightweight, VAE-specific convolutional stem that maps 4-channel VAE latents up to the bridge's working width.
- Bridge Backbone: A standard transformer body (width 768, depth 8) that remains VAE-agnostic.
- Token Head: A linear projection that maps transformer outputs to the expected DINO patch targets (e.g., 64 tokens of 768 dim).
Note: Because of this decoupled design, supporting a future model such as FLUX only requires swapping the Latent Adapter (to handle 16-channel latents); the learned bridge backbone can be frozen or reused.
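As a rough illustration of this split, the sketch below wires the three parts together in PyTorch. It is not the DinoBridgeV3 implementation: the names LatentAdapter, BridgeSketch, and with_new_adapter are hypothetical, and the two stride-2 convolutions used to reach the 8 × 8 grid, the learned positional embedding, and the use of nn.TransformerEncoder are all assumptions made for the example.

```python
import torch
import torch.nn as nn


class LatentAdapter(nn.Module):
    """VAE-specific stem (hypothetical): [B, in_ch, 32, 32] -> [B, 64, width]."""

    def __init__(self, in_ch=4, mid_ch=256, out_ch=512, width=768):
        super().__init__()
        # Assumption: two stride-2 convs downsample 32x32 -> 16x16 -> 8x8.
        self.stem = nn.Sequential(
            nn.Conv2d(in_ch, mid_ch, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
            nn.Conv2d(mid_ch, out_ch, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
        )
        self.proj = nn.Linear(out_ch, width)

    def forward(self, z):
        x = self.stem(z)                  # [B, out_ch, 8, 8]
        x = x.flatten(2).transpose(1, 2)  # [B, 64, out_ch]
        return self.proj(x)               # [B, 64, width]


class BridgeSketch(nn.Module):
    """Adapter -> VAE-agnostic transformer backbone -> linear token head."""

    def __init__(self, in_ch=4, tokens=64, width=768, depth=8, heads=12, target_dim=768):
        super().__init__()
        # Remember the VAE-independent settings so the backbone can be rebuilt later.
        self.hparams = dict(tokens=tokens, width=width, depth=depth,
                            heads=heads, target_dim=target_dim)
        self.adapter = LatentAdapter(in_ch=in_ch, width=width)
        self.pos = nn.Parameter(torch.zeros(1, tokens, width))
        layer = nn.TransformerEncoderLayer(
            d_model=width, nhead=heads, dim_feedforward=4 * width,
            dropout=0.0, activation="gelu", batch_first=True, norm_first=True,
        )
        self.backbone = nn.TransformerEncoder(layer, num_layers=depth,
                                              enable_nested_tensor=False)
        self.head = nn.Linear(width, target_dim)

    def forward(self, z):
        x = self.adapter(z) + self.pos
        return self.head(self.backbone(x))  # [B, tokens, target_dim]


def with_new_adapter(trained, in_ch):
    """Build a bridge for a different VAE, copying and freezing the trained backbone."""
    new = BridgeSketch(in_ch=in_ch, **trained.hparams)
    new.backbone.load_state_dict(trained.backbone.state_dict())
    for p in new.backbone.parameters():
        p.requires_grad_(False)
    return new
```

With this layout, `with_new_adapter(sd_bridge, in_ch=16)` would give a bridge whose adapter is freshly initialized for 16-channel latents while the backbone weights stay fixed. Whether the token head should also be reused is a design choice the sketch leaves open.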
Loss Function
Training uses a composite, geometry-aware loss function (bridge_loss) that prioritizes structural and directional alignment over raw magnitude matching:
- Cosine Loss (Weight: 1.0): 1.0 - cosine_similarity(pred, target). The primary term; it matches the semantic direction of the DINOv2 features.
- MSE Norm (Weight: 0.25): Standard MSE applied after L2-normalizing the predictions and targets.
- MSE Raw (Weight: 0.05): Standard MSE applied to the raw values. Keeps the scale grounded without letting magnitude differences dominate the gradients.
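The three terms above can be combined as in the following sketch. The function name bridge_loss_sketch and its argument names are placeholders; the notebook's bridge_loss may differ in reduction or other details, so treat this as an illustration of the weighting rather than the exact training code.

```python
import torch
import torch.nn.functional as F


def bridge_loss_sketch(pred, target, w_cos=1.0, w_norm=0.25, w_raw=0.05):
    """Weighted sum of cosine, normalized-MSE, and raw-MSE terms.

    pred, target: [B, 64, 768] patch-token features.
    """
    # Direction term: 1 - cosine similarity per token, averaged over all tokens.
    cos = 1.0 - F.cosine_similarity(pred, target, dim=-1).mean()
    # MSE between unit-length (L2-normalized) token vectors.
    mse_norm = F.mse_loss(F.normalize(pred, dim=-1), F.normalize(target, dim=-1))
    # MSE on raw values keeps the feature scale anchored.
    mse_raw = F.mse_loss(pred, target)
    return w_cos * cos + w_norm * mse_norm + w_raw * mse_raw
```

Because the raw-MSE weight is small, a prediction that points in the right direction but has the wrong magnitude is penalized only lightly, which matches the stated intent of the loss.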
Training History
| Epoch | Val loss | Val cosine ↑ | Val NMSE ↓ | Retrieval@1 ↑ | Retrieval@5 ↑ | Retrieval@10 ↑ |
|---|---|---|---|---|---|---|
| 1 | 0.401512 | 0.624136 | 0.610722 | 0.751818 | 0.905000 | 0.940455 |
| 2 | 0.335588 | 0.686514 | 0.526457 | 0.910000 | 0.979545 | 0.987727 |
| 3 | 0.303365 | 0.716936 | 0.483648 | 0.960000 | 0.990000 | 0.995000 |
| 4 | 0.282880 | 0.736246 | 0.455694 | 0.975455 | 0.995000 | 0.996818 |
| 5 | 0.268303 | 0.749953 | 0.434974 | 0.987273 | 0.996364 | 0.998182 |
| 6 | 0.258487 | 0.759174 | 0.420825 | 0.988182 | 0.996364 | 0.998636 |
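For readers who want to reproduce metrics of this kind, here is one plausible way to compute Retrieval@k and NMSE. The model card does not spell out the exact definitions, so both are assumptions: retrieval is done on mean-pooled patch tokens with cosine similarity, and NMSE is taken as squared error divided by target energy. The function names are placeholders.

```python
import torch
import torch.nn.functional as F


def retrieval_at_k(pred, target, ks=(1, 5, 10)):
    """Fraction of samples whose true target ranks in the top k candidates.

    pred, target: [N, T, D]; row i of pred should retrieve row i of target.
    Assumption: tokens are mean-pooled to one vector per sample.
    """
    p = F.normalize(pred.mean(dim=1), dim=-1)
    t = F.normalize(target.mean(dim=1), dim=-1)
    sim = p @ t.T                              # [N, N] cosine similarities
    true_sim = sim.diag().unsqueeze(1)         # similarity to the matching target
    # Rank = number of candidates that score strictly higher than the true match.
    rank = (sim > true_sim).sum(dim=1)
    return {k: (rank < k).float().mean().item() for k in ks}


def nmse(pred, target):
    """Squared error normalized by target energy (assumed definition)."""
    return ((pred - target).pow(2).sum() / target.pow(2).sum()).item()
```

Under these definitions a perfect predictor scores 1.0 on every Retrieval@k and 0.0 on NMSE, and Retrieval@k depends on how many candidates are in the pool, which is why results on a larger external set may differ.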
Current Quality Read
The bridge is learning as intended: validation loss, cosine, and NMSE improved at every epoch.
Strong signal: Retrieval@1 reached ~98.8% on the held-out validation subset.
This means the predicted features preserve enough image identity / semantic structure to retrieve the matching true DINO target among validation candidates.
However, raw patch cosine is still only ~0.759, so the bridge is not yet a perfect DINO replacement. It is already useful for ranking / retrieval-like proxy supervision, but should be improved before being treated as a high-fidelity DINO teacher.
Suggested target before production use as a serious DINO proxy:
- val cosine: 0.85+
- val NMSE: <0.25
- retrieval@1: remain >0.95 on larger external eval
How to Use
- Prepare Data: Pre-pack your SD latents ([N, 4, 32, 32]) and DINOv2 features ([N, 64, 768]) into memory-mappable .npy files. Ensure you are targeting the 8x8 patch grid, excluding the DINO CLS token.
- Configure: Update paths and training knobs in the @dataclass class CFG (Cell 3). This serves as the single source of truth for the run.
- Run All: The notebook will handle package installation, wandb logging, dataset splitting, and mixed-precision (AMP) training automatically.
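The data-preparation step can be sketched with NumPy as follows. The file names (latents.npy, dino_tokens.npy), the float16 storage dtype, the chunk size, and the random placeholder arrays are all assumptions for illustration; in a real run the placeholders would be replaced by your VAE encoder output and DINOv2 token output. The sketch also assumes the CLS token sits at index 0 of the DINOv2 token sequence.

```python
import tempfile
from pathlib import Path

import numpy as np

N = 8          # tiny placeholder; a real dataset would have far more rows
CHUNK = 4      # rows written per step
out_dir = Path(tempfile.mkdtemp())

# Pre-allocate .npy files on disk so rows can be written without holding all data in RAM.
latents = np.lib.format.open_memmap(
    str(out_dir / "latents.npy"), mode="w+", dtype=np.float16, shape=(N, 4, 32, 32)
)
targets = np.lib.format.open_memmap(
    str(out_dir / "dino_tokens.npy"), mode="w+", dtype=np.float16, shape=(N, 64, 768)
)

rng = np.random.default_rng(0)
for start in range(0, N, CHUNK):
    stop = min(start + CHUNK, N)
    b = stop - start
    latent_batch = rng.standard_normal((b, 4, 32, 32))  # placeholder for VAE latents
    dino_batch = rng.standard_normal((b, 65, 768))      # placeholder: CLS + 64 patch tokens
    latents[start:stop] = latent_batch
    targets[start:stop] = dino_batch[:, 1:]              # drop CLS, keep the 8x8 patch grid

latents.flush()
targets.flush()
del latents, targets

# Later (e.g. inside a Dataset), open the files lazily instead of loading them fully.
lat = np.load(out_dir / "latents.npy", mmap_mode="r")
tgt = np.load(out_dir / "dino_tokens.npy", mmap_mode="r")
```

Opening with `mmap_mode="r"` keeps memory use low because rows are read from disk only when indexed.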
Basic
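import torch

# DinoBridgeV3 must already be defined (e.g., by running the notebook's model cell).
ckpt = torch.load("checkpoints/best.pt", map_location="cpu")
# Use the "model" entry if the checkpoint is a dict wrapper; otherwise treat it as a bare state dict.
state_dict = ckpt["model"] if "model" in ckpt else ckpt

model = DinoBridgeV3(
    in_ch=4,
    target_tokens=64,
    target_dim=768,
    adapter_mid_channels=256,
    adapter_out_channels=512,
    adapter_depth=2,
    width=768,
    depth=8,
    heads=12,
    mlp_ratio=4.0,
    dropout=0.02,
)
model.load_state_dict(state_dict, strict=True)
model.eval().cuda()  # requires a CUDA device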
Advanced
import torch
import torch.nn.functional as F

# OUT_DIR, cfg, device, in_ch, TARGET_TOKENS, TARGET_DIM, val_ds, and DinoBridgeV3
# come from earlier notebook cells.
best_ckpt_path = OUT_DIR / "best.pt"
print(f"Loading checkpoint from: {best_ckpt_path}")

inference_model = DinoBridgeV3(
    in_ch=in_ch,
    target_tokens=TARGET_TOKENS,
    target_dim=TARGET_DIM,
    adapter_mid_channels=cfg.adapter_mid_channels,
    adapter_out_channels=cfg.adapter_out_channels,
    adapter_depth=cfg.adapter_depth,
    width=cfg.model_width,
    depth=cfg.model_depth,
    heads=cfg.model_heads,
    mlp_ratio=cfg.mlp_ratio,
    dropout=cfg.dropout,
).to(device)

if best_ckpt_path.exists():
    ckpt = torch.load(best_ckpt_path, map_location=device)
    inference_model.load_state_dict(ckpt["model"])
    print("Weights loaded successfully.")
else:
    print("Checkpoint not found. Make sure you have completed at least one training epoch.")

inference_model.eval()

# In practice, this would be the output from your VAE encoder: latent = vae.encode(image)
sample_latent, sample_target = val_ds[0]
sample_latent = sample_latent.unsqueeze(0).to(device).float()  # Add batch dim: [1, C, H, W]

# str(device) works whether `device` is a string or a torch.device.
use_amp = str(device).startswith("cuda") and cfg.amp
with torch.no_grad():
    with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
        pred_dino_features = inference_model(sample_latent)

print("\n--- Inference Results ---")
print("Input latent shape:", sample_latent.shape)
print("Predicted DINO features shape:", pred_dino_features.shape)
print("Ground truth DINO features shape:", sample_target.shape)

# Quick comparison to ground truth (per-token cosine similarity, averaged over tokens)
pred_norm = F.normalize(pred_dino_features[0].float(), dim=-1)
target_norm = F.normalize(sample_target.to(device).float(), dim=-1)
sim = F.cosine_similarity(pred_norm, target_norm, dim=-1).mean().item()
print(f"Average Cosine Similarity for this sample: {sim:.4f}")
Short HF model-card table
| Section | Value |
|---|---|
| Repo | kingsidharth/sd_vae_2_dino_v2_bridge |
| Task | VAE latent → DINOv2 feature prediction |
| Input | [B, 4, 32, 32] SD/E2E-SD latent |
| Output | [B, 64, 768] DINOv2 patch tokens |
| Best checkpoint | checkpoints/best.pt |
| Safe checkpoint after interrupt | checkpoints/epoch_006.pt |
| Latest checkpoint caveat | latest.pt may be incomplete if interrupted during final save |
| Best logged val cosine | 0.759174 |
| Best logged val NMSE | 0.420825 |
| Best logged Retrieval@1 | 0.988182 |
| Best logged Retrieval@5 | 0.996364 |
| Best logged Retrieval@10 | 0.998636 |
The training log shows validation improving consistently from epoch 1 to epoch 6: cosine rose from 0.624136 to 0.759174, NMSE fell from 0.610722 to 0.420825, and Retrieval@1 rose from 0.751818 to 0.988182. The run was interrupted during the final latest.pt save after epoch_006.pt had already been saved, so best.pt / epoch_006.pt are the safer checkpoints.
Model tree for kingsidharth/sd_vae_2_dino_v2_bridge
Base model
facebook/dinov2-small