#!/usr/bin/env python3
"""
MicroForge End-to-End Test Suite
Validates all modules work correctly on CPU.
"""

import math
import os
import sys
import time

import torch

# Make the microforge package importable (this file's directory is the repo root)
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))


def test_vae():
    """Test all VAE configurations."""
    from microforge.vae import MicroForgeVAE

    print("=" * 60)
    print("TEST: MicroForge VAE")
    print("=" * 60)

    for config in ['tiny', 'small', 'base']:
        vae = MicroForgeVAE(config=config)
        params = sum(p.numel() for p in vae.parameters())

        # Test forward pass
        x = torch.randn(1, 3, 256, 256)
        x_recon, mu, logvar = vae(x)

        assert x_recon.shape == x.shape, f"Recon shape mismatch: {x_recon.shape} vs {x.shape}"
        assert not torch.isnan(mu).any(), "NaN in mu"
        assert not torch.isnan(logvar).any(), "NaN in logvar"

        # Test encode/decode
        z = vae.get_latent(x)
        x_dec = vae.decode(z)
        assert x_dec.shape == x.shape

        # Test KL loss
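        # Closed form for a diagonal Gaussian vs. N(0, I):
        # KL = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))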
        kl = MicroForgeVAE.kl_loss(mu, logvar)
        assert not torch.isnan(kl), "NaN in KL loss"

        print(f"  [{config:>5}] PASS | params={params:,} | latent={mu.shape} | KL={kl.item():.2f}")

    print()


def test_backbone():
    """Test all backbone configurations."""
    from microforge.backbone import MicroForgeBackbone

    print("=" * 60)
    print("TEST: MicroForge Backbone")
    print("=" * 60)

    for config in ['tiny', 'small', 'base']:
        lc = 16 if config == 'tiny' else 32
        backbone = MicroForgeBackbone(latent_channels=lc, config=config)
        params = sum(p.numel() for p in backbone.parameters())

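        # Dummy inputs: an 8x8 latent map, a flow-matching time in [0, 1),
        # and 768-dim text embeddings (sequence + pooled), matching the
        # CLIP-style text dim used throughout these tests.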
        z = torch.randn(1, lc, 8, 8)
        t = torch.rand(1)
        text_emb = torch.randn(1, 10, 768)
        text_pooled = torch.randn(1, 768)

        start = time.time()
        v = backbone(z, t, text_emb, text_pooled)
        elapsed = (time.time() - start) * 1000

        assert v.shape == z.shape, f"Output shape mismatch: {v.shape} vs {z.shape}"
        assert not torch.isnan(v).any(), "NaN in velocity prediction"

        print(f"  [{config:>5}] PASS | params={params:,} | latency={elapsed:.0f}ms")

    print()


def test_planner():
    """Test Recurrent Latent Planner."""
    from microforge.planner import RecurrentLatentPlanner

    print("=" * 60)
    print("TEST: Recurrent Latent Planner")
    print("=" * 60)

    planner = RecurrentLatentPlanner(
        num_plan_tokens=32, dim=384, text_dim=768, latent_channels=32
    )
    params = sum(p.numel() for p in planner.parameters())

    # Test initialization
    text_pooled = torch.randn(2, 768)
    plan = planner.initialize_plan(text_pooled, batch_size=2)
    assert plan.shape == (2, 32, 384), f"Plan shape: {plan.shape}"

    # Test forward
    img_tokens = torch.randn(2, 64, 32)  # 8x8 latent flattened
    t_emb = torch.randn(2, 384)
    plan_out, output = planner(img_tokens, plan, t_emb)

    assert plan_out.shape == (2, 32, 384)
    assert output.shape == (2, 32, 768)  # Projected to text_dim
    assert not torch.isnan(plan_out).any()
    assert not torch.isnan(output).any()

    # Test self-conditioning
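    # Passing prev_plan back in should seed the next plan from the previous
    # step's state (self-conditioning) instead of re-deriving it from text alone.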
    plan_next = planner.initialize_plan(text_pooled, batch_size=2, prev_plan=plan_out)
    assert plan_next.shape == plan.shape

    print(f"  PASS | params={params:,} | plan_state={planner.get_plan_size_bytes()} bytes")
    print()


def test_training():
    """Test training loop."""
    from microforge.vae import MicroForgeVAE
    from microforge.backbone import MicroForgeBackbone
    from microforge.planner import RecurrentLatentPlanner
    from microforge.training import MicroForgeTrainer, FlowMatchingScheduler

    print("=" * 60)
    print("TEST: Training Pipeline")
    print("=" * 60)

    vae = MicroForgeVAE(config='tiny').eval()
    backbone = MicroForgeBackbone(latent_channels=16, config='tiny')
    planner = RecurrentLatentPlanner(num_plan_tokens=16, dim=256, text_dim=768, latent_channels=16)

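    # use_ema=True also exercises the EMA path: a smoothed shadow copy of the
    # weights, typically the one used for sampling.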
    trainer = MicroForgeTrainer(vae, backbone, planner, lr=1e-4, use_ema=True)

    # Test flow matching scheduler
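    # Assumed convention: rectified flow interpolates z_t = (1 - t) * z_0 + t * noise
    # and regresses the constant velocity v = noise - z_0; the exact formulation
    # lives in the scheduler.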
    scheduler = FlowMatchingScheduler()
    t = scheduler.sample_timesteps(4, torch.device('cpu'))
    assert t.min() >= 0 and t.max() <= 1, f"Timesteps out of range: {t}"

    z_0 = torch.randn(4, 16, 4, 4)
    noise = torch.randn_like(z_0)
    z_t, v_target = scheduler.add_noise(z_0, noise, t)
    assert z_t.shape == z_0.shape
    assert v_target.shape == z_0.shape

    # Test training steps
    images = torch.randn(2, 3, 128, 128)
    text_emb = torch.randn(2, 10, 768)
    text_pooled = torch.randn(2, 768)

    losses = []
    for i in range(5):
        step_losses = trainer.train_step(images, text_emb, text_pooled)
        losses.append(step_losses['flow'])
        assert not any(math.isnan(v) for v in step_losses.values()), \
            f"NaN in losses: {step_losses}"

    print(f"  5 training steps: loss {losses[0]:.2f} -> {losses[-1]:.2f}")
    print("  PASS")
    print()


def test_pipeline():
    """Test end-to-end inference pipeline."""
    from microforge.vae import MicroForgeVAE
    from microforge.backbone import MicroForgeBackbone
    from microforge.planner import RecurrentLatentPlanner
    from microforge.pipeline import MicroForgePipeline, SimpleTextEncoder

    print("=" * 60)
    print("TEST: End-to-End Pipeline")
    print("=" * 60)

    vae = MicroForgeVAE(config='tiny')
    backbone = MicroForgeBackbone(latent_channels=16, config='tiny')
    planner = RecurrentLatentPlanner(num_plan_tokens=16, dim=256, text_dim=768, latent_channels=16)
    text_enc = SimpleTextEncoder(embed_dim=768, num_layers=2)

    pipeline = MicroForgePipeline(vae, backbone, text_enc, planner, device='cpu')

    # Test text2img
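    # num_steps=2 keeps this CPU smoke test fast; cfg_scale=1.0 should collapse
    # classifier-free guidance to the plain conditional prediction.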
    tokens = torch.randint(0, 8192, (1, 10))
    start = time.time()
    images = pipeline.text2img(tokens, height=128, width=128, num_steps=2, cfg_scale=1.0, seed=42)
    t2i_time = time.time() - start

    assert images.shape == (1, 3, 128, 128), f"Wrong output shape: {images.shape}"
    assert images.min() >= -1 and images.max() <= 1, f"Range error: [{images.min()}, {images.max()}]"

    print(f"  text2img: {images.shape} in {t2i_time:.2f}s | PASS")

    # Test parameter count
    params = pipeline.count_parameters()
    print(f"  Total params: {params['total']:,}")

    # Test memory estimate
    mem = pipeline.get_memory_estimate(512, 512)
    print(f"  Est. memory @512px: {mem['estimated_inference_mb']:.0f} MB")

    print("  PASS")
    print()


def test_editing_pathway():
    """Test that editing pathway works (spatial concat)."""
    from microforge.backbone import MicroForgeBackbone

    print("=" * 60)
    print("TEST: Editing Pathway (Spatial Concat)")
    print("=" * 60)

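    # Editing reuses the generation backbone on a wider latent, so the backbone
    # must tolerate variable spatial sizes (presumably via convolutions or
    # resolution-agnostic attention).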
    backbone = MicroForgeBackbone(latent_channels=16, config='tiny')

    # Standard generation: 8x8 latent
    z_gen = torch.randn(1, 16, 8, 8)
    t = torch.rand(1)
    text_emb = torch.randn(1, 5, 768)
    text_pooled = torch.randn(1, 768)

    v_gen = backbone(z_gen, t, text_emb, text_pooled)
    assert v_gen.shape == z_gen.shape, f"Gen output shape: {v_gen.shape}"

    # Editing: 8x16 latent (width-concat target + source)
    z_edit = torch.randn(1, 16, 8, 16)  # Doubled width
    v_edit = backbone(z_edit, t, text_emb, text_pooled)
    assert v_edit.shape == z_edit.shape, f"Edit output shape: {v_edit.shape}"

    # Extract target velocity (left half)
    v_target = v_edit[..., :8]
    assert v_target.shape == z_gen.shape

    print(f"  Generation: {z_gen.shape} -> {v_gen.shape} | PASS")
    print(f"  Editing:    {z_edit.shape} -> {v_edit.shape} | PASS")
    print()


def main():
    print()
    print("🔨 MicroForge Architecture Test Suite")
    print("=" * 60)
    print()

    test_vae()
    test_backbone()
    test_planner()
    test_training()
    test_pipeline()
    test_editing_pathway()

    print("=" * 60)
    print("✅ ALL TESTS PASSED")
    print("=" * 60)


if __name__ == "__main__":
    main()