File size: 4,783 Bytes
97ade79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""
TemporalMesh Transformer — Inference Script
Full greedy / top-p / top-k text generation with exit gate analysis.
"""

import torch
import torch.nn.functional as F
from tmt.model.config import TMTConfig
from tmt.model.model import TMTModel


def load_model(checkpoint_path: str = None, config: TMTConfig = None) -> TMTModel:
    """
    Build a TMTModel in eval mode, optionally restoring saved weights.

    Args:
        checkpoint_path: Path to a checkpoint readable by ``torch.load`` that
            stores the weights under the ``"model_state"`` key. Any falsy
            value (``None``, ``""``) skips loading, leaving the model with
            whatever weights ``TMTModel(config)`` initialises.
        config: Model configuration. When ``None`` a built-in default is used
            (12 layers, d_model=512, 8 heads, max_seq_len=256).

    Returns:
        The constructed model after ``eval()`` has been called on it.

    Raises:
        KeyError: If the checkpoint has no ``"model_state"`` entry.
    """
    cfg = config if config is not None else TMTConfig(
        vocab_size=50258,
        d_model=512,
        n_heads=8,
        n_layers=12,
        graph_k=8,
        exit_threshold=0.85,
        memory_anchors=16,
        max_seq_len=256,
    )
    net = TMTModel(cfg)

    if checkpoint_path:
        # map_location="cpu" remaps stored tensors to CPU, so a checkpoint
        # saved from a GPU run can be opened on a machine without one.
        # NOTE(review): depending on the installed torch version, torch.load
        # may unpickle arbitrary objects — only open trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        weights = checkpoint["model_state"]
        net.load_state_dict(weights)

    net.eval()
    return net


def _filter_logits(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    """
    Mask next-token logits that fall outside the top-k / nucleus (top-p) sets.

    Args:
        logits: (B, V) next-token logits. The input is not modified.
        top_k:  Keep only the k highest-scoring tokens per row; <= 0 disables.
        top_p:  Keep the shortest highest-probability prefix whose mass
                reaches top_p; >= 1.0 disables.

    Returns:
        A tensor shaped like ``logits`` with rejected entries set to -inf.
    """
    if top_k > 0:
        # Clamp k: torch.topk raises if asked for more entries than exist.
        k = min(top_k, logits.size(-1))
        kth_best = torch.topk(logits, k).values[:, -1:]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))

    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        # Probability mass ranked strictly ahead of each token. A token is
        # dropped only when the tokens before it already exceed top_p, so the
        # token that crosses the threshold is kept. This exclusive sum already
        # plays the role of the usual "shift the mask right by one" step, so
        # no additional shift is applied (doing both keeps one token too many).
        mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        remove = mass_before > top_p
        # Always keep the most likely token, even for top_p <= 0.
        remove[:, 0] = False
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        # Undo the sort so each value lands back on its vocabulary index.
        logits = logits.scatter(1, sorted_idx, sorted_logits)

    return logits


@torch.no_grad()
def generate(
    model: TMTModel,
    input_ids: torch.Tensor,
    max_new_tokens: int = 64,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95,
    do_sample: bool = True,
) -> dict:
    """
    Generate tokens autoregressively. Returns generated ids + exit analysis.

    Each step re-runs the model on the whole sequence so far, filters the
    last position's logits with top-k / top-p, then either samples or takes
    the argmax. Generation stops after ``max_new_tokens`` steps or once the
    sequence length reaches ``model.config.max_seq_len``.

    Args:
        model:          Model to decode with; inputs are moved to its device.
        input_ids:      (B, S) prompt token ids.
        max_new_tokens: Upper bound on the number of decoding steps.
        temperature:    Logit divisor; must be strictly positive.
        top_k:          Top-k cutoff (<= 0 disables; values above the
                        vocabulary size are clamped).
        top_p:          Nucleus cutoff (>= 1.0 disables).
        do_sample:      Sample from the filtered distribution when True,
                        otherwise pick the argmax.

    Returns:
        dict with keys
          "generated_ids":    prompt plus new tokens,
          "new_tokens":       only the newly generated tokens,
          "exit_stats":       one dict per step with per-gate "exit_rates"
                              and "avg_confidence" lists,
          "avg_compute_used": mean over steps of the mean per-gate exit rate,
                              rounded to 3 decimals (0.0 if no step ran).
                              NOTE(review): this is an average *exit rate*;
                              whether that equals compute used or compute
                              saved depends on the model's exit-mask
                              semantics — confirm.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        # Dividing by zero / a negative value would give inf/NaN or inverted
        # logits and fail (or mislead) much later inside sampling.
        raise ValueError(f"temperature must be > 0, got {temperature}")

    device = next(model.parameters()).device
    input_ids = input_ids.to(device)
    generated = input_ids.clone()
    all_exit_stats = []

    for _ in range(max_new_tokens):
        output = model(generated)
        logits = output.logits[:, -1, :] / temperature  # (B, V)
        logits = _filter_logits(logits, top_k, top_p)

        if do_sample:
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            next_token = logits.argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=1)

        # capture exit stats for this step
        all_exit_stats.append({
            "exit_rates": [m.float().mean().item() for m in output.exit_masks],
            "avg_confidence": [c.mean().item() for c in output.confidences],
        })

        # stop at max_seq_len
        if generated.shape[1] >= model.config.max_seq_len:
            break

    # Average only over steps that reported at least one exit gate, so zero
    # decoding steps (or a model without gates) cannot divide by zero.
    step_means = [
        sum(step["exit_rates"]) / len(step["exit_rates"])
        for step in all_exit_stats
        if step["exit_rates"]
    ]
    avg_compute = sum(step_means) / len(step_means) if step_means else 0.0

    return {
        "generated_ids":    generated,
        "new_tokens":       generated[:, input_ids.shape[1]:],
        "exit_stats":       all_exit_stats,
        "avg_compute_used": round(avg_compute, 3),
    }


def analyse_sequence(model: TMTModel, input_ids: torch.Tensor) -> None:
    """
    Run a single forward pass and print detailed exit gate analysis.

    Prints the logits shape, the edge count of the first graph in
    ``output.graph_edges``, the memory-state shape, and one table row per
    exit gate with the number of frozen tokens, the exit rate and the mean
    confidence.

    The exit rate is the frozen count divided by the number of entries in
    that gate's mask, so it stays within 0-100% for batched input as well
    (the frozen count itself is summed over the whole batch).

    Args:
        model:     Model to analyse; the input is moved to its device.
        input_ids: Token ids; dimension 1 is reported as the sequence length.
    """
    device = next(model.parameters()).device
    with torch.no_grad():
        output = model(input_ids.to(device))

    S = input_ids.shape[1]
    n_edges = output.graph_edges[0].shape[1]
    banner = "=" * 55
    divider = "-" * 46

    print(f"\n{banner}")
    print(f"  TMT Sequence Analysis  (seq_len={S})")
    print(f"{banner}")
    print(f"  Logits shape:  {output.logits.shape}")
    print(f"  Graph edges:   {n_edges} active edges")
    print(f"  Memory state:  {output.memory_state.shape}\n")
    print(f"  {'Layer':<8} {'Tokens frozen':>14} {'Exit rate':>12} {'Avg conf':>10}")
    print(f"  {divider}")

    total_frozen = 0
    total_slots = 0
    for i, (mask, conf) in enumerate(zip(output.exit_masks, output.confidences)):
        n_frozen = mask.sum().item()
        # Normalise by the mask's own size rather than seq_len alone, so a
        # batch of B sequences is not reported as up to B * 100%.
        n_slots = mask.numel()
        total_frozen += n_frozen
        total_slots += n_slots
        rate = n_frozen / n_slots if n_slots else 0.0
        avg_c = conf.mean().item()
        print(f"  {i+1:<8} {n_frozen:>14} {rate:>11.1%} {avg_c:>10.3f}")

    # Guard against empty input or a model that reports no exit gates.
    total_fraction = total_frozen / total_slots if total_slots else 0.0

    print(f"  {divider}")
    print(f"  Total compute fraction: {total_fraction:.1%} of max")
    print(f"  Active graph edges:     {n_edges}")
    print(f"{banner}\n")


if __name__ == "__main__":
    # Quick demo: build a small model without a checkpoint, analyse one
    # random prompt, then greedily decode a few tokens from it.
    print("Loading TMT-Small for quick demo...")
    small_config = TMTConfig(
        vocab_size=50258,
        d_model=256,
        n_heads=4,
        n_layers=6,
        graph_k=4,
        exit_threshold=0.80,
        memory_anchors=8,
        max_seq_len=128,
    )
    demo_model = load_model(config=small_config)
    n_params = sum(param.numel() for param in demo_model.parameters())
    print(f"Parameters: {n_params:,}")

    # One random 32-token prompt with ids drawn from [100, 50000).
    prompt_ids = torch.randint(low=100, high=50000, size=(1, 32))
    analyse_sequence(demo_model, prompt_ids)

    # do_sample=False selects argmax decoding.
    outcome = generate(demo_model, prompt_ids, max_new_tokens=16, do_sample=False)
    n_new = outcome["new_tokens"].shape[1]
    print(f"Generated {n_new} new tokens.")
    print(f"Avg compute used per step: {outcome['avg_compute_used']:.1%}")