vigneshwar234 committed on
Commit
b78b4bc
·
verified ·
1 Parent(s): 0a26cdd

Add source: tmt/experiments/02_mesh_only.ipynb

Browse files
Files changed (1) hide show
  1. tmt/experiments/02_mesh_only.ipynb +52 -0
tmt/experiments/02_mesh_only.ipynb ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": ["# Experiment 02 — Ablation: Mesh Attention Only\n", "Trains TMT with only Innovation 1 (mesh attention) enabled, to isolate its contribution, and reports the resulting perplexity."]
7
+ },
8
+ {
9
+ "cell_type": "code",
10
+ "execution_count": null,
11
+ "metadata": {},
12
+ "outputs": [],
13
+ "source": [
14
+ "import torch\n",
15
+ "from tmt.model.config import TMTConfig\n",
16
+ "from tmt.model.model import TMTModel\n",
17
+ "from tmt.training.trainer import TMTTrainer, TrainConfig\n",
18
+ "from tmt.data.dataset import load_text_dataset\n",
19
+ "\n",
20
+ "# Disable innovations 2 and 3: decay_rate=0.0 and exit_threshold=1.1 (never exit early)\n",
21
+ "cfg = TMTConfig(\n",
22
+ " vocab_size=50258, d_model=512, n_heads=8, n_layers=6,\n",
23
+ " graph_k=8, decay_rate=0.0, exit_threshold=1.1, # never exit early\n",
24
+ " dual_stream=False, memory_anchors=0,\n",
25
+ " ffn_stream_dim=256,\n",
26
+ ")\n",
27
+ "print(cfg)"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": null,
33
+ "metadata": {},
34
+ "outputs": [],
35
+ "source": [
36
+ "loaders = load_text_dataset('wikitext-2', seq_len=256, batch_size=16)\n",
37
+ "train_cfg = TrainConfig(total_steps=2000, warmup_steps=200, use_wandb=False)\n",
38
+ "\n",
39
+ "trainer = TMTTrainer(cfg, train_cfg, loaders['train'], loaders.get('validation'))\n",
40
+ "trainer.train()\n",
41
+ "mesh_only_ppl = trainer.evaluate()\n",
42
+ "print(f'Mesh-only perplexity: {mesh_only_ppl:.2f}')"
43
+ ]
44
+ }
45
+ ],
46
+ "metadata": {
47
+ "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
48
+ "language_info": {"name": "python", "version": "3.10.0"}
49
+ },
50
+ "nbformat": 4,
51
+ "nbformat_minor": 4
52
+ }