#!/usr/bin/env python3
"""
MicroForge End-to-End Test Suite
Validates all modules work correctly on CPU.
"""
import torch
import time
import sys
import os
# Add parent to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
def test_vae():
    """Exercise every MicroForge VAE preset on CPU.

    For each configuration: run a full forward pass, check the
    encode/decode round-trip, and verify the KL term is finite.
    """
    from microforge.vae import MicroForgeVAE

    print("=" * 60)
    print("TEST: MicroForge VAE")
    print("=" * 60)

    for config in ('tiny', 'small', 'base'):
        vae = MicroForgeVAE(config=config)
        params = sum(p.numel() for p in vae.parameters())

        # Full forward pass: reconstruction plus posterior statistics.
        x = torch.randn(1, 3, 256, 256)
        x_recon, mu, logvar = vae(x)
        assert x_recon.shape == x.shape, f"Recon shape mismatch: {x_recon.shape} vs {x.shape}"
        assert not torch.isnan(mu).any(), "NaN in mu"
        assert not torch.isnan(logvar).any(), "NaN in logvar"

        # Encode then decode must round-trip the image shape.
        z = vae.get_latent(x)
        x_dec = vae.decode(z)
        assert x_dec.shape == x.shape

        # KL divergence must be a finite scalar.
        kl = MicroForgeVAE.kl_loss(mu, logvar)
        assert not torch.isnan(kl), "NaN in KL loss"

        print(f" [{config:>5}] PASS | params={params:,} | latent={mu.shape} | KL={kl.item():.2f}")
    print()
def test_backbone():
    """Test all backbone configurations.

    Runs one denoising forward pass per preset and checks the velocity
    prediction for correct shape and NaN-freeness, reporting latency.
    """
    from microforge.backbone import MicroForgeBackbone

    print("=" * 60)
    print("TEST: MicroForge Backbone")
    print("=" * 60)

    for config in ('tiny', 'small', 'base'):
        # The 'tiny' preset pairs with a 16-channel latent space; the
        # larger presets use 32 channels.
        lc = 16 if config == 'tiny' else 32
        backbone = MicroForgeBackbone(latent_channels=lc, config=config)
        params = sum(p.numel() for p in backbone.parameters())

        z = torch.randn(1, lc, 8, 8)
        t = torch.rand(1)
        text_emb = torch.randn(1, 10, 768)
        text_pooled = torch.randn(1, 768)

        # perf_counter is monotonic and high-resolution; time.time() can
        # jump under system clock adjustments and is too coarse for a
        # millisecond latency measurement.
        start = time.perf_counter()
        v = backbone(z, t, text_emb, text_pooled)
        elapsed = (time.perf_counter() - start) * 1000
        assert v.shape == z.shape, f"Output shape mismatch: {v.shape} vs {z.shape}"
        assert not torch.isnan(v).any(), "NaN in velocity prediction"
        print(f" [{config:>5}] PASS | params={params:,} | latency={elapsed:.0f}ms")
    print()
def test_planner():
    """Smoke-test the Recurrent Latent Planner: init, forward, self-conditioning."""
    from microforge.planner import RecurrentLatentPlanner

    print("=" * 60)
    print("TEST: Recurrent Latent Planner")
    print("=" * 60)

    planner = RecurrentLatentPlanner(
        num_plan_tokens=32, dim=384, text_dim=768, latent_channels=32
    )
    params = sum(p.numel() for p in planner.parameters())

    # Plan initialization from pooled text embeddings.
    text_pooled = torch.randn(2, 768)
    plan = planner.initialize_plan(text_pooled, batch_size=2)
    assert plan.shape == (2, 32, 384), f"Plan shape: {plan.shape}"

    # One recurrent step over flattened 8x8 latent tokens.
    img_tokens = torch.randn(2, 64, 32)
    t_emb = torch.randn(2, 384)
    plan_out, output = planner(img_tokens, plan, t_emb)
    assert plan_out.shape == (2, 32, 384)
    assert output.shape == (2, 32, 768)  # projected up to text_dim
    assert not torch.isnan(plan_out).any()
    assert not torch.isnan(output).any()

    # Self-conditioning: re-initialize from the previous plan state.
    plan_next = planner.initialize_plan(text_pooled, 2, prev_plan=plan_out)
    assert plan_next.shape == plan.shape

    print(f" PASS | params={params:,} | plan_state={planner.get_plan_size_bytes()} bytes")
    print()
def test_training():
    """Test the training pipeline.

    Checks the flow-matching scheduler (timestep range, noising shapes)
    and runs a few optimizer steps, verifying every loss stays finite.
    """
    from microforge.vae import MicroForgeVAE
    from microforge.backbone import MicroForgeBackbone
    from microforge.planner import RecurrentLatentPlanner
    from microforge.training import MicroForgeTrainer, FlowMatchingScheduler

    print("=" * 60)
    print("TEST: Training Pipeline")
    print("=" * 60)

    # VAE is frozen (eval) during flow-matching training.
    vae = MicroForgeVAE(config='tiny').eval()
    backbone = MicroForgeBackbone(latent_channels=16, config='tiny')
    planner = RecurrentLatentPlanner(num_plan_tokens=16, dim=256, text_dim=768, latent_channels=16)
    trainer = MicroForgeTrainer(vae, backbone, planner, lr=1e-4, use_ema=True)

    # Scheduler: timesteps must lie in [0, 1] and noising must preserve shape.
    scheduler = FlowMatchingScheduler()
    t = scheduler.sample_timesteps(4, torch.device('cpu'))
    assert t.min() >= 0 and t.max() <= 1, f"Timesteps out of range: {t}"
    z_0 = torch.randn(4, 16, 4, 4)
    noise = torch.randn_like(z_0)
    z_t, v_target = scheduler.add_noise(z_0, noise, t)
    assert z_t.shape == z_0.shape
    assert v_target.shape == z_0.shape

    # A few training steps on synthetic data.
    images = torch.randn(2, 3, 128, 128)
    text_emb = torch.randn(2, 10, 768)
    text_pooled = torch.randn(2, 768)
    losses = []
    for _ in range(5):  # loop index unused; only the step count matters
        step_losses = trainer.train_step(images, text_emb, text_pooled)
        losses.append(step_losses['flow'])
        assert not any(torch.isnan(torch.tensor(v)) for v in step_losses.values()), \
            f"NaN in losses: {step_losses}"

    print(f" 5 training steps: loss {losses[0]:.2f} -> {losses[-1]:.2f}")
    print(f" PASS")
    print()
def test_pipeline():
    """Test the end-to-end inference pipeline.

    Runs a minimal text2img generation (2 steps, no CFG) and checks the
    output shape and value range, then prints the parameter count and
    memory estimate.
    """
    from microforge.vae import MicroForgeVAE
    from microforge.backbone import MicroForgeBackbone
    from microforge.planner import RecurrentLatentPlanner
    from microforge.pipeline import MicroForgePipeline, SimpleTextEncoder

    print("=" * 60)
    print("TEST: End-to-End Pipeline")
    print("=" * 60)

    vae = MicroForgeVAE(config='tiny')
    backbone = MicroForgeBackbone(latent_channels=16, config='tiny')
    planner = RecurrentLatentPlanner(num_plan_tokens=16, dim=256, text_dim=768, latent_channels=16)
    text_enc = SimpleTextEncoder(embed_dim=768, num_layers=2)
    pipeline = MicroForgePipeline(vae, backbone, text_enc, planner, device='cpu')

    # text2img with a fixed seed so the run is reproducible.
    tokens = torch.randint(0, 8192, (1, 10))
    # perf_counter is the correct monotonic clock for elapsed-time
    # measurement; time.time() can move backwards under clock adjustment.
    start = time.perf_counter()
    images = pipeline.text2img(tokens, height=128, width=128, num_steps=2, cfg_scale=1.0, seed=42)
    t2i_time = time.perf_counter() - start
    assert images.shape == (1, 3, 128, 128), f"Wrong output shape: {images.shape}"
    assert images.min() >= -1 and images.max() <= 1, f"Range error: [{images.min()}, {images.max()}]"
    print(f" text2img: {images.shape} in {t2i_time:.2f}s | PASS")

    # Informational: total parameter count and memory estimate at 512px.
    params = pipeline.count_parameters()
    print(f" Total params: {params['total']:,}")
    mem = pipeline.get_memory_estimate(512, 512)
    print(f" Est. memory @512px: {mem['estimated_inference_mb']:.0f} MB")
    print(f" PASS")
    print()
def test_editing_pathway():
    """Verify the backbone accepts width-concatenated latents for editing.

    The editing pathway feeds an 8x16 latent (target concatenated with
    source along width) through the same backbone used for 8x8 generation.
    """
    from microforge.backbone import MicroForgeBackbone

    print("=" * 60)
    print("TEST: Editing Pathway (Spatial Concat)")
    print("=" * 60)

    backbone = MicroForgeBackbone(latent_channels=16, config='tiny')

    # Standard generation path: square 8x8 latent.
    z_gen = torch.randn(1, 16, 8, 8)
    t = torch.rand(1)
    text_emb = torch.randn(1, 5, 768)
    text_pooled = torch.randn(1, 768)
    v_gen = backbone(z_gen, t, text_emb, text_pooled)
    assert v_gen.shape == z_gen.shape, f"Gen output shape: {v_gen.shape}"

    # Editing path: width is doubled by concatenating target + source.
    z_edit = torch.randn(1, 16, 8, 16)
    v_edit = backbone(z_edit, t, text_emb, text_pooled)
    assert v_edit.shape == z_edit.shape, f"Edit output shape: {v_edit.shape}"

    # The left half of the edit output is the target-image velocity.
    v_target = v_edit[..., :8]
    assert v_target.shape == z_gen.shape

    print(f" Generation: {z_gen.shape} -> {v_gen.shape} | PASS")
    print(f" Editing: {z_edit.shape} -> {v_edit.shape} | PASS")
    print()
def main():
    """Run every MicroForge test suite in order."""
    print()
    print("🔨 MicroForge Architecture Test Suite")
    print("=" * 60)
    print()
    # Each suite prints its own banner and raises AssertionError on failure.
    for suite in (
        test_vae,
        test_backbone,
        test_planner,
        test_training,
        test_pipeline,
        test_editing_pathway,
    ):
        suite()
    print("=" * 60)
    print("✅ ALL TESTS PASSED")
    print("=" * 60)
# Script entry point.
if __name__ == "__main__":
    main()