vigneshwar234 committed on
Commit
e131f80
·
verified ·
1 Parent(s): 3ee6731

Add source: tests/test_forward.py

Browse files
Files changed (1) hide show
  1. tests/test_forward.py +95 -0
tests/test_forward.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ test_forward.py — full end-to-end forward pass smoke tests.
3
+
4
+ Run: pytest tests/test_forward.py -v
5
+ """
6
+ import torch
7
+ import pytest
8
+
9
+ from tmt.model.config import TMTConfig
10
+ from tmt.model.model import TMTModel, TMTOutput
11
+ from tmt.training.loss import compute_loss
12
+
13
+
14
# Batch size (B) and sequence length (S) used to build inputs and to check
# output shapes in the tests below.
B = 2
S = 32

# Model configuration shared by the ``model`` fixture and the shape assertions.
CFG = TMTConfig(
    vocab_size=1000,
    d_model=64,
    n_heads=4,
    n_layers=3,
    max_seq_len=64,
    graph_k=4,
    ffn_stream_dim=32,
    memory_anchors=4,
)
25
+
26
+
27
@pytest.fixture
def model():
    """Build a ``TMTModel`` from the module-level ``CFG``."""
    tmt_model = TMTModel(CFG)
    return tmt_model
30
+
31
+
32
@pytest.fixture
def input_ids():
    """Return random integer token ids of shape ``(B, S)`` below ``CFG.vocab_size``."""
    batch_shape = (B, S)
    return torch.randint(0, CFG.vocab_size, batch_shape)
35
+
36
+
37
def test_full_forward_shape(model, input_ids):
    """The forward pass returns a ``TMTOutput`` whose logits are ``(B, S, vocab_size)``."""
    result = model(input_ids)
    assert isinstance(result, TMTOutput)

    expected_shape = (B, S, CFG.vocab_size)
    assert result.logits.shape == expected_shape
41
+
42
+
43
def test_output_has_all_fields(model, input_ids):
    """The auxiliary output fields are present and have the expected sizes."""
    result = model(input_ids)

    # One exit mask and one confidence entry is expected per layer.
    n_layers = CFG.n_layers
    assert len(result.exit_masks) == n_layers
    assert len(result.confidences) == n_layers

    # graph_edges must unpack into exactly two items; only the index tensor
    # is inspected here.
    edge_index, _edge_weight = result.graph_edges
    assert edge_index.shape[0] == 2

    expected_memory_shape = (CFG.memory_anchors, CFG.d_model)
    assert result.memory_state.shape == expected_memory_shape

    expected_decay_shape = (B, S, CFG.d_model)
    assert result.decay_scalars.shape == expected_decay_shape
51
+
52
+
53
def test_loss_computable(model, input_ids):
    """The total loss is strictly positive and is neither NaN nor infinite."""
    # Shift by one position: every token but the last is an input, and every
    # token but the first is the target to predict.
    inputs, targets = input_ids[:, :-1], input_ids[:, 1:]
    result = model(inputs)

    loss_terms = compute_loss(result.logits, targets, result.confidences)
    total, _ce, _gate = loss_terms

    assert total.item() > 0
    assert not torch.isnan(total)
    assert not torch.isinf(total)
62
+
63
+
64
def test_backward_pass(model, input_ids):
    """Backpropagating the total loss yields gradients, none of which contain NaN.

    The previous version only inspected parameters that already had a
    gradient, so it passed vacuously when ``backward()`` populated no
    gradients at all. This version first asserts that at least one trainable
    parameter received a gradient.
    """
    # Shift by one position to form next-token inputs and targets.
    inputs, targets = input_ids[:, :-1], input_ids[:, 1:]
    result = model(inputs)
    total, _, _ = compute_loss(result.logits, targets, result.confidences)
    total.backward()

    # Collect gradients of trainable parameters; parameters with no gradient
    # are skipped, exactly as before.
    grads = {
        name: param.grad
        for name, param in model.named_parameters()
        if param.requires_grad and param.grad is not None
    }

    # Guard against the vacuous pass: some gradient must have flowed.
    assert grads, "backward() produced no gradients for any parameter"

    for name, grad in grads.items():
        assert not torch.isnan(grad).any(), f"NaN grad in {name}"
74
+
75
+
76
def test_exit_mask_monotone(model, input_ids):
    """A position marked as exited at one layer stays marked at the next layer."""
    masks = model(input_ids).exit_masks
    # Walk consecutive (layer i-1, layer i) pairs of masks.
    for earlier, later in zip(masks, masks[1:]):
        # Pick the later mask's entries at positions the earlier mask marks as
        # exited; every one of them must still be set.
        carried_over = later[earlier]
        assert carried_over.all(), "Exit mask is not monotonically set"
84
+
85
+
86
def test_no_nan_in_logits(model, input_ids):
    """The logits contain neither NaN nor infinite entries."""
    logits = model(input_ids).logits
    has_nan = torch.isnan(logits).any()
    assert not has_nan
    has_inf = torch.isinf(logits).any()
    assert not has_inf
90
+
91
+
92
def test_model_repr(model):
    """``repr(model)`` contains both ``"TMTModel"`` and ``"params"``."""
    text = repr(model)
    for expected_fragment in ("TMTModel", "params"):
        assert expected_fragment in text