| | import math |
| |
|
| | import torch |
| |
|
| | from src.transformer_video import WanDiscreteVideoTransformer |
| |
|
| |
|
| | def _available_device(): |
| | return "cuda" if torch.cuda.is_available() else "cpu" |
| |
|
| |
|
def test_wan_discrete_video_transformer_forward_and_shapes():
    """
    Smoke test for WanDiscreteVideoTransformer:

    - build a tiny model configuration
    - run one no-grad forward pass on random pseudo-video tokens plus random
      text embeddings
    - verify the logits shape, the parameter count and (when CUDA is present)
      basic memory accounting
    """
    device = _available_device()

    # Tiny geometry so the test stays fast.
    n_codes = 128
    n_frames, frame_h, frame_w = 2, 16, 16

    model = WanDiscreteVideoTransformer(
        codebook_size=n_codes,
        vocab_size=n_codes + 1,  # one slot beyond the codebook (presumably a mask token — unverified here)
        num_frames=n_frames,
        height=frame_h,
        width=frame_w,
        in_dim=32,
        dim=64,
        ffn_dim=128,
        freq_dim=32,
        text_dim=64,
        out_dim=32,
        num_heads=4,
        num_layers=2,
    ).to(device)
    model.eval()

    bsz = 2

    # Random discrete video token ids in [0, n_codes).
    token_ids = torch.randint(
        low=0,
        high=n_codes,
        size=(bsz, n_frames, frame_h, frame_w),
        dtype=torch.long,
        device=device,
    )

    # Random text conditioning; width taken from the backbone so the test
    # tracks the model's actual text embedding dimension.
    txt_len = 8
    text_states = torch.randn(
        bsz, txt_len, model.backbone.text_dim, device=device
    )

    # One diffusion timestep per batch element.
    step_ids = torch.randint(
        low=0, high=1000, size=(bsz,), dtype=torch.long, device=device
    )

    # Snapshot allocator state before the forward pass (CUDA only).
    if device == "cuda":
        torch.cuda.reset_peak_memory_stats()
        mem_start = torch.cuda.memory_allocated()
    else:
        mem_start = 0

    with torch.no_grad():
        logits = model(
            tokens=token_ids,
            timesteps=step_ids,
            encoder_hidden_states=text_states,
            y=None,
        )

    if device == "cuda":
        mem_end = torch.cuda.memory_allocated()
        mem_peak = torch.cuda.max_memory_allocated()
    else:
        mem_end = mem_start
        mem_peak = mem_start

    # Expected layout: (batch, codebook, frames, h_patches, w_patches).
    # NOTE(review): dims 2-4 assume the model does not patch along time
    # (patch_size[0] == 1) — confirm against the model implementation.
    assert logits.shape[0] == bsz
    assert logits.shape[1] == n_codes
    assert logits.shape[2] == n_frames

    expected_h = frame_h // model.backbone.patch_size[1]
    expected_w = frame_w // model.backbone.patch_size[2]
    assert logits.shape[3] == expected_h
    assert logits.shape[4] == expected_w

    # Sanity-check the parameter count.
    total_params = sum(p.numel() for p in model.parameters())
    assert total_params > 0
    assert math.isfinite(float(total_params))

    # Allocator bookkeeping must be monotone across the forward pass.
    if device == "cuda":
        assert mem_peak >= mem_end >= mem_start
| |
|
| |
|
| |
|