vesakkivignesh commited on
Commit
d7d8e66
·
0 Parent(s):

Initial clean Space commit

Browse files
Files changed (4) hide show
  1. .gitignore +1 -0
  2. bdh.py +171 -0
  3. requirements.txt +3 -0
  4. train.py +126 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ input.txt
bdh.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Pathway Technology, Inc.
2
+
3
+ import dataclasses
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+
11
@dataclasses.dataclass
class BDHConfig:
    """Hyperparameters for the BDH model.

    Defaults describe a small byte-level model; ``vocab_size=256`` covers
    every possible raw byte, so no tokenizer is needed.
    """

    n_layer: int = 6  # number of times BDH.forward applies its (weight-shared) block
    n_embd: int = 256  # residual-stream width D
    dropout: float = 0.1  # dropout applied to the gated sparse activations
    n_head: int = 4  # attention heads
    mlp_internal_dim_multiplier: int = 128  # per-head latent dim N = multiplier * D // n_head
    vocab_size: int = 256  # byte-level vocabulary
19
+
20
+
21
def get_freqs(n, theta, dtype):
    """Return ``n`` rotary frequencies, expressed in cycles ("turns") per step.

    Positions are snapped down to even indices so that each consecutive
    (even, odd) channel pair shares one frequency; the familiar
    ``1 / theta**(k/n)`` spectrum is then divided by ``2*pi`` because
    callers track phases in turns rather than radians.
    """
    channel = torch.arange(0, n, 1, dtype=dtype)
    # Snap each channel index down to the nearest even value (pairing step).
    channel = (channel / 2).floor() * 2
    return 1.0 / (theta ** (channel / n)) / (2 * math.pi)
30
+
31
+
32
class Attention(torch.nn.Module):
    """Unnormalized, strictly-causal attention over rotary-encoded codes.

    BDH ties keys to queries (``forward`` asserts ``K is Q``) and uses raw
    score products with the diagonal excluded, so position ``t`` attends
    only to positions ``< t``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        heads = config.n_head
        width = config.n_embd
        latent = config.mlp_internal_dim_multiplier * width // heads
        # Rotation frequencies in turns; a Buffer follows the module across
        # devices without being treated as a trainable parameter.
        self.freqs = torch.nn.Buffer(
            get_freqs(latent, theta=2**16, dtype=torch.float32).view(1, 1, 1, latent)
        )

    @staticmethod
    def phases_cos_sin(phases):
        """Map phases given in turns to (cos, sin) of the angle in radians."""
        radians = (phases % 1) * (2 * math.pi)
        return torch.cos(radians), torch.sin(radians)

    @staticmethod
    def rope(phases, v):
        """Rotate consecutive channel pairs of ``v`` by ``phases`` (in turns)."""
        # Interleaved rotate-half: (x_even, x_odd) -> (-x_odd, x_even).
        rotated = torch.stack((-v[..., 1::2], v[..., ::2]), dim=-1).view(*v.size())
        cos, sin = Attention.phases_cos_sin(phases)
        return (v * cos).to(v.dtype) + (rotated * sin).to(v.dtype)

    def forward(self, Q, K, V):
        assert self.freqs.dtype == torch.float32
        assert K is Q  # BDH shares queries and keys
        T = Q.size(2)

        # Phase for every (position, channel): position index times frequency.
        positions = torch.arange(
            0,
            T,
            device=self.freqs.device,
            dtype=self.freqs.dtype,
        ).view(1, 1, -1, 1)
        QR = self.rope(positions * self.freqs, Q)
        KR = QR  # keys are queries, so their rotation is identical

        # Strictly-causal scores: tril(diagonal=-1) zeroes the diagonal too,
        # so a token never attends to itself.
        scores = (QR @ KR.mT).tril(diagonal=-1)
        return scores @ V
75
+
76
+
77
class BDH(nn.Module):
    """Byte-level BDH language model.

    One set of encoder/decoder parameters is reused for every one of the
    ``n_layer`` iterations in ``forward`` (the loop below indexes no
    per-layer weights), making the block weight-shared / recurrent.
    """

    def __init__(self, config: BDHConfig):
        super().__init__()
        assert config.vocab_size is not None
        self.config = config
        nh = config.n_head
        D = config.n_embd
        # Per-head sparse latent dimension.
        N = config.mlp_internal_dim_multiplier * D // nh
        # Maps concatenated sparse activations (nh * N) back to width D.
        self.decoder = nn.Parameter(torch.zeros((nh * N, D)).normal_(std=0.02))
        # Per-head projection D -> N used for the query/key sparse codes.
        self.encoder = nn.Parameter(torch.zeros((nh, D, N)).normal_(std=0.02))

        self.attn = Attention(config)

        # Parameter-free LayerNorm, reused at several points in forward.
        self.ln = nn.LayerNorm(D, elementwise_affine=False, bias=False)
        self.embed = nn.Embedding(config.vocab_size, D)
        self.drop = nn.Dropout(config.dropout)
        # Per-head projection D -> N for the value/gating path.
        self.encoder_v = nn.Parameter(torch.zeros((nh, D, N)).normal_(std=0.02))

        self.lm_head = nn.Parameter(
            torch.zeros((D, config.vocab_size)).normal_(std=0.02)
        )

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Small-normal init for Linear/Embedding submodules (raw nn.Parameter
        weights above are initialized inline and untouched here)."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Run the model on token indices ``idx`` of shape (B, T).

        Returns ``(logits, loss)`` where logits has shape (B, T, vocab_size)
        and loss is cross-entropy against ``targets`` (or None when targets
        is None).
        """
        C = self.config

        B, T = idx.size()
        D = C.n_embd
        nh = C.n_head
        N = D * C.mlp_internal_dim_multiplier // nh

        # Insert a singleton head axis so broadcasting against the (nh, D, N)
        # encoders produces per-head latents.
        x = self.embed(idx).unsqueeze(1)

        # actually helps with training
        x = self.ln(x)  # B, 1, T, D

        for level in range(C.n_layer):
            x_latent = x @ self.encoder

            # ReLU sparsifies the latent code used as both Q and K.
            x_sparse = F.relu(x_latent)  # B, nh, T, N

            yKV = self.attn(
                Q=x_sparse,
                K=x_sparse,
                V=x,
            )
            yKV = self.ln(yKV)

            # Attention output is re-encoded and gates the sparse code
            # elementwise (both paths are non-negative after ReLU).
            y_latent = yKV @ self.encoder_v
            y_sparse = F.relu(y_latent)
            xy_sparse = x_sparse * y_sparse  # B, nh, T, N

            xy_sparse = self.drop(xy_sparse)

            # Merge heads, then project back to the residual width.
            yMLP = (
                xy_sparse.transpose(1, 2).reshape(B, 1, T, N * nh) @ self.decoder
            )  # B, 1, T, D
            y = self.ln(yMLP)
            # Residual update, re-normalized before the next iteration.
            x = self.ln(x + y)

        logits = x.view(B, T, D) @ self.lm_head
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: int | None = None,
    ) -> torch.Tensor:
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        The full sequence is re-run each step (no KV cache) and is never
        cropped — there is no fixed context length to respect here.
        """
        for _ in range(max_new_tokens):
            idx_cond = idx
            logits, _ = self(idx_cond)
            # Only the last position's distribution matters for sampling.
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # Mask everything below the k-th largest logit.
                values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < values[:, [-1]]] = float("-inf")
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ numpy
3
+ requests
train.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Pathway Technology, Inc.
2
+
3
+ import os
4
+ from contextlib import nullcontext
5
+
6
+ import bdh
7
+ import numpy as np
8
+ import requests
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
# Device / mixed-precision setup: prefer CUDA, fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# On a Mac you can also try
# device=torch.device('mps')

dtype = (
    "bfloat16"
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else "float16"
)  # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
ptdtype = {
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
}[dtype]
# Autocast only on CUDA; on CPU the context is a no-op and math stays float32.
ctx = (
    torch.amp.autocast(device_type=device.type, dtype=ptdtype)
    if "cuda" in device.type
    else nullcontext()
)
# GradScaler is only active for float16 (bfloat16/float32 need no loss scaling).
scaler = torch.amp.GradScaler(device=device.type, enabled=(dtype == "float16"))
torch.manual_seed(1337)  # reproducible init and batch sampling
torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
print(f"Using device: {device} with dtype {dtype}")
37
+
38
+
39
# Configuration
BDH_CONFIG = bdh.BDHConfig()  # default model hyperparameters from bdh.py
BLOCK_SIZE = 512  # context length in bytes per training sample
BATCH_SIZE = 32  # sequences per optimizer step
MAX_ITERS = 3000  # total optimizer steps
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 0.1
LOG_FREQ = 100  # print the averaged loss every this many steps

# Dataset lives next to this script.
input_file_path = os.path.join(os.path.dirname(__file__), "input.txt")
49
+
50
+
51
def fetch_data():
    """Download the tiny Shakespeare dataset to ``input.txt`` if missing.

    Fails loudly on HTTP errors instead of silently caching an error page
    as the training corpus, and bounds the request with a timeout.
    """
    if not os.path.exists(input_file_path):
        data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
        response = requests.get(data_url, timeout=60)
        response.raise_for_status()  # raise on 4xx/5xx rather than writing the error body
        with open(input_file_path, "w", encoding="utf-8") as f:
            f.write(response.text)
57
+
58
+
59
def get_batch(split):
    """Sample a random (x, y) batch of byte sequences from input.txt.

    The file is treated as raw bytes via a read-only memmap; the first 90%
    is the train split, the remainder validation. ``y`` is ``x`` shifted
    one byte ahead (next-byte prediction targets).
    """
    raw = np.memmap(input_file_path, dtype=np.uint8, mode="r")
    split_at = int(0.9 * len(raw))
    raw = raw[:split_at] if split == "train" else raw[split_at:]
    starts = torch.randint(len(raw) - BLOCK_SIZE, (BATCH_SIZE,))
    x = torch.stack(
        [torch.from_numpy(raw[s : s + BLOCK_SIZE].astype(np.int64)) for s in starts]
    )
    y = torch.stack(
        [
            torch.from_numpy(raw[s + 1 : s + 1 + BLOCK_SIZE].astype(np.int64))
            for s in starts
        ]
    )
    if torch.cuda.is_available():
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x = x.pin_memory().to(device, non_blocking=True)
        y = y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y
84
+
85
+
86
def eval(model):
    # NOTE(review): shadows the builtin ``eval`` and is never called in this
    # script; it only switches the model to inference mode (disables dropout).
    model.eval()
88
+
89
+
90
+ if __name__ == "__main__":
91
+ fetch_data()
92
+
93
+ model = bdh.BDH(BDH_CONFIG).to(device)
94
+ model = torch.compile(model)
95
+ optimizer = torch.optim.AdamW(
96
+ model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
97
+ )
98
+
99
+ x, y = get_batch("train")
100
+
101
+ loss_acc = 0
102
+ loss_steps = 0
103
+ for step in range(MAX_ITERS):
104
+ with ctx:
105
+ logits, loss = model(x, y)
106
+ x, y = get_batch("train")
107
+ loss_acc += loss
108
+ loss_steps += 1
109
+ scaler.scale(loss).backward()
110
+ scaler.step(optimizer)
111
+ scaler.update()
112
+ optimizer.zero_grad()
113
+ if step % LOG_FREQ == 0:
114
+ print(f"Step: {step}/{MAX_ITERS} loss {loss_acc.item() / loss_steps:.3}")
115
+ loss_acc = 0
116
+ loss_steps = 0
117
+ print("Training done, now generating a sample ")
118
+ model.eval()
119
+ prompt = torch.tensor(
120
+ bytearray("To be or ", "utf-8"), dtype=torch.long, device=device
121
+ ).unsqueeze(0)
122
+ ret = model.generate(prompt, max_new_tokens=100, top_k=3)
123
+ ret_decoded = bytes(ret.to(torch.uint8).to("cpu").squeeze(0)).decode(
124
+ errors="backslashreplace"
125
+ )
126
+ print(ret_decoded)