import math

import torch

from src.transformer_video import WanDiscreteVideoTransformer


def _available_device():
    """Return ``"cuda"`` when torch can see a CUDA device, otherwise ``"cpu"``."""
    return "cuda" if torch.cuda.is_available() else "cpu"


def test_wan_discrete_video_transformer_forward_and_shapes():
    """
    Basic smoke test:
    - build a tiny WanDiscreteVideoTransformer
    - run a forward pass with random pseudo-video tokens + random text
    - check output shape, finiteness, parameter count and (if CUDA present)
      memory usage
    """
    # Fixed seed so weight init and random inputs are reproducible when the
    # test fails.
    torch.manual_seed(0)
    device = _available_device()
    on_cuda = device == "cuda"

    # small config to keep the test lightweight
    codebook_size = 128
    vocab_size = codebook_size + 1  # reserve one for mask if needed later
    num_frames = 2
    height = 16
    width = 16

    model = WanDiscreteVideoTransformer(
        codebook_size=codebook_size,
        vocab_size=vocab_size,
        num_frames=num_frames,
        height=height,
        width=width,
        # shrink Wan backbone for the unit test
        in_dim=32,
        dim=64,
        ffn_dim=128,
        freq_dim=32,
        text_dim=64,
        out_dim=32,
        num_heads=4,
        num_layers=2,
    ).to(device)
    model.eval()

    batch_size = 2

    # pseudo-video tokens from 2D VQ-VAE on frames: [B, F, H, W]
    tokens = torch.randint(
        low=0,
        high=codebook_size,
        size=(batch_size, num_frames, height, width),
        dtype=torch.long,
        device=device,
    )

    # text: [B, L, C_text]
    text_seq_len = 8
    encoder_hidden_states = torch.randn(
        batch_size, text_seq_len, model.backbone.text_dim, device=device
    )

    # timesteps: [B]
    timesteps = torch.randint(
        low=0, high=1000, size=(batch_size,), dtype=torch.long, device=device
    )

    # Memory is only tracked on CUDA; reset the peak counter so
    # max_memory_allocated() reflects this forward pass only.
    if on_cuda:
        torch.cuda.reset_peak_memory_stats()
        mem_before = torch.cuda.memory_allocated()

    with torch.no_grad():
        logits = model(
            tokens=tokens,
            timesteps=timesteps,
            encoder_hidden_states=encoder_hidden_states,
            y=None,
        )

    if on_cuda:
        mem_after = torch.cuda.memory_allocated()
        peak_mem = torch.cuda.max_memory_allocated()

    # Expected logits layout: [B, codebook_size, F_out, H_out, W_out].
    # Each video axis is divided by the backbone's patch size (default
    # (1, 2, 2)); comparing the full tuple also checks the tensor rank.
    patch_size = model.backbone.patch_size
    expected_shape = (
        batch_size,
        codebook_size,
        num_frames // patch_size[0],
        height // patch_size[1],
        width // patch_size[2],
    )
    assert tuple(logits.shape) == expected_shape, (
        f"unexpected logits shape {tuple(logits.shape)}, "
        f"expected {expected_shape}"
    )
    assert torch.isfinite(logits).all(), "logits contain NaN or Inf"

    # parameter count sanity check (just ensure it's > 0 and finite)
    num_params = sum(p.numel() for p in model.parameters())
    assert num_params > 0
    assert math.isfinite(float(num_params))

    # memory sanity check: on CUDA the forward pass must allocate > 0 bytes
    # (at minimum the logits tensor), and the peak bounds the final usage.
    if on_cuda:
        assert peak_mem >= mem_after >= mem_before
        assert peak_mem > mem_before, "forward pass allocated no CUDA memory"