vigneshwar234 committed on
Commit
a25d929
·
verified ·
1 Parent(s): f817d5c

Add source: tmt/experiments/03_full_tmt.ipynb

Browse files
Files changed (1) hide show
  1. tmt/experiments/03_full_tmt.ipynb +58 -0
tmt/experiments/03_full_tmt.ipynb ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": ["# Experiment 03 — Full TMT Training Run\n", "All three innovations active: Mesh Attention + Temporal Decay + Adaptive Depth Routing."]
7
+ },
8
+ {
9
+ "cell_type": "code",
10
+ "execution_count": null,
11
+ "metadata": {},
12
+ "outputs": [],
13
+ "source": [
14
+ "import torch\n",
15
+ "from tmt.model.config import TMTConfig\n",
16
+ "from tmt.training.trainer import TMTTrainer, TrainConfig\n",
17
+ "from tmt.data.dataset import load_text_dataset\n",
18
+ "\n",
19
+ "cfg = TMTConfig(\n",
20
+ " vocab_size=50258, d_model=512, n_heads=8, n_layers=12,\n",
21
+ " graph_k=8, decay_rate=0.1, exit_threshold=0.85,\n",
22
+ " dual_stream=True, memory_anchors=16, ffn_stream_dim=256,\n",
23
+ ")\n",
24
+ "print(cfg)\n",
25
+ "print(f'Device: {\"cuda\" if torch.cuda.is_available() else \"cpu\"}')"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "loaders = load_text_dataset('wikitext-2', seq_len=256, batch_size=16)\n",
35
+ "\n",
36
+ "train_cfg = TrainConfig(\n",
37
+ " total_steps=10_000,\n",
38
+ " warmup_steps=500,\n",
39
+ " eval_every=500,\n",
40
+ " save_every=1000,\n",
41
+ " use_wandb=False, # set True and login to wandb first\n",
42
+ ")\n",
43
+ "\n",
44
+ "trainer = TMTTrainer(cfg, train_cfg, loaders['train'], loaders.get('validation'))\n",
45
+ "trainer.train()\n",
46
+ "\n",
47
+ "tmt_ppl = trainer.evaluate()\n",
48
+ "print(f'Full TMT perplexity: {tmt_ppl:.2f}')"
49
+ ]
50
+ }
51
+ ],
52
+ "metadata": {
53
+ "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
54
+ "language_info": {"name": "python", "version": "3.10.0"}
55
+ },
56
+ "nbformat": 4,
57
+ "nbformat_minor": 4
58
+ }