OpenTransformer committed d3d6c03 (verified) · Parent: 46d7788

Add tenstorrent/README.md — Tenstorrent N300s training port

Files changed (1): tenstorrent/README.md (+78 −0, added)
# Tenstorrent port notes

## What changed

- Added `--backend auto|cuda|tt|cpu`
- Added Tenstorrent runtime setup through TT-XLA / PJRT
- Training path uses XLA-style optimizer stepping on TT (`xm.optimizer_step`)
- Checkpoints are always saved with CPU tensors so they can move between CUDA and TT
- TT inference avoids dynamic KV-cache assumptions and uses a static-shape path for robustness
- Added TT tuning flags:
  - `--tt_dtype fp32|bf16`
  - `--tt_bfp8`
  - `--tt_weight_bfp8`
  - `--tt_optimization_level`
  - `--tt_trace`
  - `--tt_spmd` (experimental)

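A minimal sketch of how `--backend auto` resolution could work: prefer CUDA when a GPU is visible, fall back to Tenstorrent when the PJRT plugin is installed, else CPU. The function name `resolve_backend` and the `pjrt_plugin_tt` module name are assumptions for illustration, not taken from the script:

```python
import importlib.util

def resolve_backend(requested: str = "auto") -> str:
    """Hypothetical sketch of --backend auto resolution: CUDA, then TT, then CPU."""
    if requested != "auto":
        return requested
    # CUDA: only if torch is installed and actually sees a GPU
    if importlib.util.find_spec("torch") is not None:
        import torch
        if torch.cuda.is_available():
            return "cuda"
    # Tenstorrent: module name assumed from the pip package `pjrt-plugin-tt`
    if importlib.util.find_spec("pjrt_plugin_tt") is not None:
        return "tt"
    return "cpu"

print(resolve_backend("auto"))
```

An explicit `--backend cuda|tt|cpu` bypasses detection entirely, which is useful when both device stacks are installed on the same host.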
## Koyeb setup sketch

```bash
python3 -m venv .xla-venv
source .xla-venv/bin/activate
pip install pjrt-plugin-tt --extra-index-url https://pypi.eng.aws.tenstorrent.com/
pip install torch datasets transformers sentencepiece safetensors
```

## Training example

```bash
python n_tenstorrent_port.py train \
  --backend tt \
  --preset nano_3x \
  --steps 10000 \
  --batch_size 4 \
  --block 576 \
  --save_dir /workspace/ckpts_expansion_tt \
  --tt_dtype bf16 \
  --tt_optimization_level 1
```

## Warm-start from NVIDIA checkpoint and continue training on TT

```bash
python n_tenstorrent_port.py train \
  --backend tt \
  --preset nano_3x \
  --warmstart_from /workspace/ckpts_expansion/final.pt \
  --steps 10000 \
  --batch_size 4 \
  --block 576 \
  --save_dir /workspace/ckpts_tt_resume \
  --tt_dtype bf16
```

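Warm-starting across device families works because checkpoints are written with CPU tensors. A minimal sketch of that save/load convention, assuming a plain `state_dict`-style checkpoint (the `save_portable`/`load_portable` helpers and the `"model"` key are illustrative, not the script's actual names):

```python
import torch
import torch.nn as nn

def save_portable(model: nn.Module, path: str) -> None:
    # Move every tensor to CPU before saving so the file loads on CUDA, TT, or CPU hosts
    state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
    torch.save({"model": state}, path)

def load_portable(path: str) -> dict:
    # map_location="cpu" keeps loading independent of the device that wrote the checkpoint
    return torch.load(path, map_location="cpu")["model"]

model = nn.Linear(4, 4)
save_portable(model, "/tmp/ckpt_demo.pt")
state = load_portable("/tmp/ckpt_demo.pt")
```

The loaded CPU state dict can then be moved onto whichever backend `--backend` resolves to before training resumes.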
## Inference from NVIDIA-trained checkpoint on TT

```bash
python n_tenstorrent_port.py infer \
  --backend tt \
  --mode ar \
  --ckpt /workspace/ckpts_expansion/final.pt \
  --prompt "The capital of France is" \
  --max_new 64 \
  --tt_dtype bf16
```

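The static-shape inference path can be illustrated with a simple padding helper: every prompt is padded (or truncated) to a fixed block length so the compiled TT graph never sees a new shape and never needs a dynamically growing KV cache. The helper below is an assumed sketch of the idea, not code from the script:

```python
def pad_to_block(token_ids: list[int], block: int, pad_id: int = 0):
    """Pad or truncate token ids to a fixed block length.

    Returns the fixed-length ids plus an attention mask marking real tokens,
    so every compiled forward pass sees the same tensor shapes.
    """
    ids = token_ids[:block]
    mask = [1] * len(ids) + [0] * (block - len(ids))
    ids = ids + [pad_id] * (block - len(ids))
    return ids, mask

# A 3-token prompt padded into an 8-token static block
ids, mask = pad_to_block([101, 2023, 2003], block=8)
```

Trading wasted pad positions for shape stability avoids graph recompilation on every new prompt length.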
## Experimental two-chip attempt on N300

```bash
python n_tenstorrent_port.py train \
  --backend tt \
  --tt_spmd \
  --batch_size 8 \
  --block 576 \
  --steps 10000
```

Use the SPMD flag carefully. It is intentionally marked experimental in the script.