# TemporalMesh-Transformer / inference.py
# Last commit: "Add inference.py" by vigneshwar234 (97ade79, verified)
"""
TemporalMesh Transformer — Inference Script
Full greedy / top-p / top-k text generation with exit gate analysis.
"""
from typing import Optional

import torch
import torch.nn.functional as F

from tmt.model.config import TMTConfig
from tmt.model.model import TMTModel
def load_model(
    checkpoint_path: Optional[str] = None,
    config: Optional[TMTConfig] = None,
) -> TMTModel:
    """
    Build a TMTModel and optionally restore its weights from a checkpoint.

    Args:
        checkpoint_path: path to a file readable by ``torch.load`` whose
            loaded object is indexed with ``"model_state"`` to obtain the
            state dict. If falsy (None or ""), no checkpoint is loaded.
        config: model configuration. When None, a default configuration is
            used (vocab_size=50258, d_model=512, n_heads=8, n_layers=12,
            graph_k=8, exit_threshold=0.85, memory_anchors=16,
            max_seq_len=256).

    Returns:
        The constructed model, with ``eval()`` applied.

    Raises:
        KeyError: presumably, if the loaded checkpoint has no "model_state"
            entry (it is accessed with ``ckpt["model_state"]``).
    """
    if config is None:
        # Default architecture used when the caller supplies no config.
        config = TMTConfig(
            vocab_size=50258, d_model=512, n_heads=8, n_layers=12,
            graph_k=8, exit_threshold=0.85, memory_anchors=16, max_seq_len=256,
        )
    model = TMTModel(config)
    if checkpoint_path:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code -- only load checkpoints from trusted sources.
        # Passing weights_only=True would be safer, but may reject checkpoints
        # that store non-tensor objects next to "model_state"; confirm the
        # checkpoint format before switching.
        ckpt = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(ckpt["model_state"])
    model.eval()
    return model
def _filter_logits(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    """
    Mask, in place, every logit outside the top-k / nucleus (top-p) set with -inf.

    Args:
        logits: (B, V) next-token logits; modified in place and returned.
        top_k: keep the ``top_k`` largest logits per row (0 disables). Values
            larger than V are clamped to V so ``torch.topk`` does not raise.
        top_p: keep the smallest prefix of tokens, sorted by probability,
            whose cumulative mass exceeds ``top_p`` (>= 1.0 disables). The
            most likely token is always kept.

    Returns:
        The same ``logits`` tensor.
    """
    if top_k > 0:
        k = min(top_k, logits.size(-1))
        kth_largest = torch.topk(logits, k).values[:, -1:]
        logits[logits < kth_largest] = -float("Inf")
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        # Probability mass of all strictly more-likely tokens. A token is
        # dropped only if the nucleus was already complete before it, so the
        # token that crosses the threshold is kept. This exclusive cumsum is
        # already the "shift right by one" of the usual recipe -- shifting
        # again would keep one extra token beyond the nucleus.
        mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        remove = mass_before > top_p
        remove[:, 0] = False
        sorted_logits[remove] = -float("Inf")
        logits.scatter_(1, sorted_idx, sorted_logits)
    return logits


@torch.no_grad()
def generate(
    model: "TMTModel",
    input_ids: torch.Tensor,
    max_new_tokens: int = 64,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95,
    do_sample: bool = True,
) -> dict:
    """
    Generate tokens autoregressively. Returns generated ids + exit analysis.

    The whole sequence is passed through ``model`` again at every step.
    Generation stops after ``max_new_tokens`` steps, or as soon as the
    sequence length reaches ``model.config.max_seq_len``.

    Args:
        model: module whose output exposes ``logits`` (indexed as
            ``[:, -1, :]``), ``exit_masks`` and ``confidences`` (iterables of
            tensors -- presumably one per layer; confirm against TMTModel).
        input_ids: prompt token ids, shape (B, S).
        max_new_tokens: upper bound on the number of tokens appended.
        temperature: softmax temperature for sampling; must be > 0 when
            ``do_sample`` is True. Ignored for greedy decoding.
        top_k: keep only the ``top_k`` highest logits when sampling
            (0 disables; values above the vocabulary size are clamped).
        top_p: nucleus-sampling threshold (>= 1.0 disables).
        do_sample: sample from the filtered distribution if True, otherwise
            take the argmax (greedy).

    Returns:
        dict with ``generated_ids`` (prompt + new tokens), ``new_tokens``,
        ``exit_stats`` (one dict per step with per-exit-mask ``exit_rates``
        and per-confidence ``avg_confidence``) and ``avg_compute_used``
        (mean exit rate over masks and steps, rounded to 3 places; 0.0 if
        no step ran).

    Raises:
        ValueError: if ``do_sample`` is True and ``temperature`` <= 0.
    """
    if do_sample and temperature <= 0:
        raise ValueError(f"temperature must be > 0 when sampling, got {temperature}")
    device = next(model.parameters()).device
    input_ids = input_ids.to(device)
    generated = input_ids.clone()
    max_len = model.config.max_seq_len
    all_exit_stats = []
    for _ in range(max_new_tokens):
        # Checked before the forward pass (not after appending) so a prompt
        # already at max_seq_len returns immediately instead of running the
        # model and growing past the limit. For shorter prompts the number
        # of generated tokens is unchanged.
        if generated.shape[1] >= max_len:
            break
        output = model(generated)
        logits = output.logits[:, -1, :]  # (B, V)
        if do_sample:
            # Division creates a new tensor, so the in-place filtering never
            # touches output.logits.
            logits = _filter_logits(logits / temperature, top_k, top_p)
            next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
        else:
            # Positive temperature scaling and top-k / top-p filtering never
            # change the argmax, so greedy decoding uses the raw logits.
            next_token = logits.argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=1)
        # capture exit stats for this step
        all_exit_stats.append({
            "exit_rates": [m.float().mean().item() for m in output.exit_masks],
            "avg_confidence": [c.mean().item() for c in output.confidences],
        })
    # NOTE(review): this averages exit-mask rates. analyse_sequence labels
    # the same masks "Tokens frozen", so this may measure compute *skipped*
    # rather than used -- confirm the mask semantics before relying on it.
    per_step = [
        sum(s["exit_rates"]) / len(s["exit_rates"])
        for s in all_exit_stats
        if s["exit_rates"]
    ]
    avg_compute = sum(per_step) / len(per_step) if per_step else 0.0
    return {
        "generated_ids": generated,
        "new_tokens": generated[:, input_ids.shape[1]:],
        "exit_stats": all_exit_stats,
        "avg_compute_used": round(avg_compute, 3),
    }
def analyse_sequence(model: "TMTModel", input_ids: torch.Tensor) -> None:
    """
    Run a single forward pass and print detailed exit gate analysis.

    Prints the output shapes, the number of graph edges, and one row per
    exit mask. Each row shows the number of set mask entries ("Tokens
    frozen"), the fraction of mask entries that are set, and the mean
    confidence.

    Args:
        model: module whose output exposes ``logits``, ``graph_edges``
            (indexed ``[0].shape[1]``), ``memory_state``, ``exit_masks`` and
            ``confidences``.
        input_ids: token ids, shape (B, S); moved to the model's device.
    """
    device = next(model.parameters()).device
    with torch.no_grad():
        output = model(input_ids.to(device))
    S = input_ids.shape[1]
    print(f"\n{'='*55}")
    print(f" TMT Sequence Analysis (seq_len={S})")
    print(f"{'='*55}")
    print(f" Logits shape: {output.logits.shape}")
    print(f" Graph edges: {output.graph_edges[0].shape[1]} active edges")
    print(f" Memory state: {output.memory_state.shape}\n")
    print(f" {'Layer':<8} {'Tokens frozen':>14} {'Exit rate':>12} {'Avg conf':>10}")
    print(f" {'-'*46}")
    total_frozen = 0
    total_entries = 0
    for i, (mask, conf) in enumerate(zip(output.exit_masks, output.confidences)):
        n_frozen = mask.sum().item()
        n_entries = mask.numel()
        total_frozen += n_frozen
        total_entries += n_entries
        # Normalise by every mask entry (not just S): mask.sum() counts over
        # the whole batch, so dividing by S alone gives rates above 100% when
        # B > 1. This matches generate()'s mask.float().mean().
        rate = n_frozen / n_entries if n_entries else 0.0
        avg_c = conf.mean().item()
        print(f" {i+1:<8} {n_frozen:>14} {rate:>11.1%} {avg_c:>10.3f}")
    print(f" {'-'*46}")
    # NOTE(review): this is the fraction of *frozen* entries, although it is
    # labelled "compute fraction" -- confirm the intended meaning.
    frozen_fraction = total_frozen / total_entries if total_entries else 0.0
    print(f" Total compute fraction: {frozen_fraction:.1%} of max")
    print(f" Active graph edges: {output.graph_edges[0].shape[1]}")
    print(f"{'='*55}\n")
if __name__ == "__main__":
    # Quick smoke demo: build a small model (no checkpoint is loaded),
    # analyse one random prompt, then decode greedily from it.
    print("Loading TMT-Small for quick demo...")
    small_cfg = TMTConfig(
        vocab_size=50258,
        d_model=256,
        n_heads=4,
        n_layers=6,
        graph_k=4,
        exit_threshold=0.80,
        memory_anchors=8,
        max_seq_len=128,
    )
    demo_model = load_model(config=small_cfg)
    param_count = sum(param.numel() for param in demo_model.parameters())
    print(f"Parameters: {param_count:,}")
    # One prompt of 32 random token ids in [100, 50000); differs on every run.
    prompt = torch.randint(100, 50000, (1, 32))
    analyse_sequence(demo_model, prompt)
    outcome = generate(demo_model, prompt, max_new_tokens=16, do_sample=False)
    n_new = outcome["new_tokens"].shape[1]
    print(f"Generated {n_new} new tokens.")
    print(f"Avg compute used per step: {outcome['avg_compute_used']:.1%}")