""" TemporalMesh Transformer — Inference Script Full greedy / top-p / top-k text generation with exit gate analysis. """ import torch import torch.nn.functional as F from tmt.model.config import TMTConfig from tmt.model.model import TMTModel def load_model(checkpoint_path: str = None, config: TMTConfig = None) -> TMTModel: if config is None: config = TMTConfig( vocab_size=50258, d_model=512, n_heads=8, n_layers=12, graph_k=8, exit_threshold=0.85, memory_anchors=16, max_seq_len=256, ) model = TMTModel(config) if checkpoint_path: ckpt = torch.load(checkpoint_path, map_location="cpu") model.load_state_dict(ckpt["model_state"]) model.eval() return model @torch.no_grad() def generate( model: TMTModel, input_ids: torch.Tensor, max_new_tokens: int = 64, temperature: float = 1.0, top_k: int = 50, top_p: float = 0.95, do_sample: bool = True, ) -> dict: """ Generate tokens autoregressively. Returns generated ids + exit analysis. """ device = next(model.parameters()).device input_ids = input_ids.to(device) generated = input_ids.clone() all_exit_stats = [] for _ in range(max_new_tokens): output = model(generated) logits = output.logits[:, -1, :] / temperature # (B, V) if top_k > 0: values, _ = torch.topk(logits, top_k) logits[logits < values[:, -1:]] = -float("Inf") if top_p < 1.0: sorted_logits, sorted_idx = torch.sort(logits, descending=True) cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) remove = cumulative - F.softmax(sorted_logits, dim=-1) > top_p remove[:, 1:] = remove[:, :-1].clone() remove[:, 0] = False sorted_logits[remove] = -float("Inf") logits.scatter_(1, sorted_idx, sorted_logits) probs = F.softmax(logits, dim=-1) next_token = ( torch.multinomial(probs, num_samples=1) if do_sample else logits.argmax(dim=-1, keepdim=True) ) generated = torch.cat([generated, next_token], dim=1) # capture exit stats for this step step_exit = { "exit_rates": [m.float().mean().item() for m in output.exit_masks], "avg_confidence": [c.mean().item() for c in output.confidences], } all_exit_stats.append(step_exit) # stop at max_seq_len if generated.shape[1] >= model.config.max_seq_len: break avg_compute = sum( sum(s["exit_rates"]) / len(s["exit_rates"]) for s in all_exit_stats ) / len(all_exit_stats) return { "generated_ids": generated, "new_tokens": generated[:, input_ids.shape[1]:], "exit_stats": all_exit_stats, "avg_compute_used": round(avg_compute, 3), } def analyse_sequence(model: TMTModel, input_ids: torch.Tensor) -> None: """ Run a single forward pass and print detailed exit gate analysis. """ device = next(model.parameters()).device with torch.no_grad(): output = model(input_ids.to(device)) S = input_ids.shape[1] print(f"\n{'='*55}") print(f" TMT Sequence Analysis (seq_len={S})") print(f"{'='*55}") print(f" Logits shape: {output.logits.shape}") print(f" Graph edges: {output.graph_edges[0].shape[1]} active edges") print(f" Memory state: {output.memory_state.shape}\n") print(f" {'Layer':<8} {'Tokens frozen':>14} {'Exit rate':>12} {'Avg conf':>10}") print(f" {'-'*46}") total_frozen = 0 for i, (mask, conf) in enumerate(zip(output.exit_masks, output.confidences)): n_frozen = mask.sum().item() total_frozen += n_frozen rate = n_frozen / S avg_c = conf.mean().item() print(f" {i+1:<8} {n_frozen:>14} {rate:>11.1%} {avg_c:>10.3f}") print(f" {'-'*46}") print(f" Total compute fraction: {total_frozen/(S*len(output.exit_masks)):.1%} of max") print(f" Active graph edges: {output.graph_edges[0].shape[1]}") print(f"{'='*55}\n") if __name__ == "__main__": print("Loading TMT-Small for quick demo...") cfg = TMTConfig( vocab_size=50258, d_model=256, n_heads=4, n_layers=6, graph_k=4, exit_threshold=0.80, memory_anchors=8, max_seq_len=128, ) model = load_model(config=cfg) print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}") ids = torch.randint(100, 50000, (1, 32)) analyse_sequence(model, ids) result = generate(model, ids, max_new_tokens=16, do_sample=False) print(f"Generated {result['new_tokens'].shape[1]} new tokens.") print(f"Avg compute used per step: {result['avg_compute_used']:.1%}")