zyzzyva commited on
Commit
3fe7988
·
1 Parent(s): 73bff3f

yeah we vibecoding

Browse files
Files changed (5) hide show
  1. .gitattributes +35 -35
  2. README.md +12 -12
  3. app.py +233 -233
  4. pgptlformer.py +453 -453
  5. requirements.txt +4 -4
.gitattributes CHANGED
@@ -1,35 +1,35 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,12 @@
1
- ---
2
- title: Attn Shift Demo
3
- emoji: 📉
4
- colorFrom: green
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 5.35.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: Attn Shift Demo
3
+ emoji: 📉
4
+ colorFrom: green
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 5.35.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,234 +1,234 @@
1
- import gradio as gr
2
- import torch
3
- import time
4
- import os
5
- from huggingface_hub import hf_hub_download
6
- import tiktoken
7
- import pgptlformer # Your model definition file
8
- import matplotlib.pyplot as plt
9
- import numpy as np
10
- from contextlib import nullcontext
11
-
12
- # --- Configuration ---
13
- DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
14
- DTYPE = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
15
- PTDTYPE = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[DTYPE]
16
- CTX = nullcontext() if DEVICE == 'cpu' else torch.amp.autocast(device_type=DEVICE, dtype=PTDTYPE)
17
- TORCH_COMPILE = False # Gradio instances can be slow, so compilation might timeout. Set to False for stability.
18
-
19
- # --- Model Loading ---
20
-
21
- @torch.no_grad()
22
- def load_model(repo_id, filename, config_override=None):
23
- """Loads a model from the Hugging Face Hub."""
24
- print(f"Loading model: {repo_id}/{filename}...")
25
- try:
26
- ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
27
- checkpoint = torch.load(ckpt_path, map_location=DEVICE)
28
-
29
- tformer_cfg = checkpoint['model_args']
30
- if config_override:
31
- tformer_cfg.update(config_override)
32
-
33
- model = pgptlformer.PGPT_Lformer(tformer_cfg)
34
- state_dict = checkpoint['model']
35
-
36
- # Clean up state dict if needed
37
- unwanted_prefix = '_orig_mod.'
38
- for k, v in list(state_dict.items()):
39
- if k.startswith(unwanted_prefix):
40
- state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
41
-
42
- model.load_state_dict(state_dict, strict=False) # Use strict=False for flexibility
43
- model.eval()
44
- model.to(DEVICE)
45
-
46
- if TORCH_COMPILE:
47
- model = torch.compile(model)
48
-
49
- print(f"Model {filename} loaded successfully.")
50
- return model, tformer_cfg
51
- except Exception as e:
52
- print(f"Error loading model {filename}: {e}")
53
- raise
54
-
55
- # Load both models once at the start
56
- try:
57
- # This is the baseline model from your portfolio
58
- BASELINE_MODEL, BASELINE_CFG = load_model(
59
- repo_id="SQCU/pgptlformer-tinystories",
60
- filename="state_step040500.pt"
61
- )
62
-
63
- # This is the shift-attn model. Note the config_override.
64
- SHIFT_ATTN_MODEL, SHIFT_ATTN_CFG = load_model(
65
- repo_id="SQCU/pgptlformer-tinystories",
66
- filename="re-pqt-rmsXrmsx2x2-ATTNII-791967c5-5c59-4a5f-a2c5-07772bcf65ab/state_step040500.pt",
67
- config_override={"attention_deux": True} # Crucial: This enables the shift-attn mechanism in your code
68
- )
69
- except Exception as e:
70
- # If loading fails, show an error in the Gradio app instead of crashing
71
- BASELINE_MODEL, SHIFT_ATTN_MODEL = None, None
72
- ERROR_MESSAGE = f"Failed to load models. Please check logs. Error: {e}"
73
-
74
-
75
- # --- Inference and Metrics ---
76
-
77
- ENC = tiktoken.get_encoding("gpt2")
78
- ENCODE = lambda s: ENC.encode(s, allowed_special={"<|endoftext|>"})
79
- DECODE = lambda l: ENC.decode(l)
80
-
81
- @torch.no_grad()
82
- def generate_and_measure(model, prompt_ids, max_new_tokens=50):
83
- """Runs inference and calculates metrics."""
84
- # Reset stats for this run
85
- if DEVICE == 'cuda':
86
- torch.cuda.reset_peak_memory_stats(DEVICE)
87
- torch.cuda.synchronize()
88
-
89
- start_time = time.time()
90
-
91
- # --- Generation Loop ---
92
- model_logits = []
93
- generated_ids = prompt_ids
94
- for _ in range(max_new_tokens):
95
- idx_cond = generated_ids if generated_ids.size(1) <= 1024 else generated_ids[:, -1024:]
96
- logits, _, _ = model(idx_cond, return_logits=True)
97
-
98
- final_logits = logits[:, -1, :]
99
- model_logits.append(final_logits) # Store logits for perplexity/sharpening calc
100
-
101
- probs = torch.nn.functional.softmax(final_logits, dim=-1)
102
- idx_next = torch.multinomial(probs, num_samples=1)
103
- generated_ids = torch.cat((generated_ids, idx_next), dim=1)
104
-
105
- if DEVICE == 'cuda':
106
- torch.cuda.synchronize()
107
- end_time = time.time()
108
-
109
- # --- Metrics Calculation ---
110
- # 1. Inference Speed (Tokens/Second)
111
- tokens_per_sec = max_new_tokens / (end_time - start_time)
112
-
113
- # 2. VRAM Usage (MB)
114
- vram_usage = torch.cuda.max_memory_allocated(DEVICE) / (1024**2) if DEVICE == 'cuda' else 0
115
-
116
- # 3. Pseudo-Perplexity
117
- all_logits = torch.cat(model_logits, dim=0)
118
- target_ids = generated_ids[0, -max_new_tokens:]
119
- cross_entropy = torch.nn.functional.cross_entropy(all_logits, target_ids)
120
- pseudo_perplexity = torch.exp(cross_entropy).item()
121
-
122
- # 4. Logit Sharpening (Average of max probability)
123
- avg_max_prob = torch.nn.functional.softmax(all_logits, dim=-1).max(dim=-1).values.mean().item()
124
-
125
- # --- Decode and Return ---
126
- output_text = DECODE(generated_ids[0].tolist())
127
-
128
- metrics = {
129
- 'Tokens/Sec': tokens_per_sec,
130
- 'VRAM (MB)': vram_usage,
131
- 'Perplexity': pseudo_perplexity,
132
- 'Logit Sharpening': avg_max_prob,
133
- }
134
-
135
- return output_text, metrics
136
-
137
- # --- Visualization ---
138
-
139
- def plot_radar_chart(baseline_metrics, shift_attn_metrics):
140
- """Creates a radar chart comparing the two models."""
141
- labels = list(baseline_metrics.keys())
142
- baseline_stats = list(baseline_metrics.values())
143
- shift_attn_stats = list(shift_attn_metrics.values())
144
-
145
- # Normalize stats for plotting. Higher is better for all metrics on the chart.
146
- # We will take the inverse of Perplexity and VRAM for a "higher is better" visualization.
147
- baseline_plot_stats = [
148
- baseline_stats[0], # Tokens/Sec (Higher is better)
149
- 1 / (baseline_stats[1] + 1e-6), # VRAM (Inverse)
150
- 1 / (baseline_stats[2] + 1e-6), # Perplexity (Inverse)
151
- baseline_stats[3] # Sharpening (Higher is better)
152
- ]
153
- shift_attn_plot_stats = [
154
- shift_attn_stats[0],
155
- 1 / (shift_attn_stats[1] + 1e-6),
156
- 1 / (shift_attn_stats[2] + 1e-6),
157
- shift_attn_stats[3]
158
- ]
159
-
160
- angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
161
-
162
- # Make the plot circular
163
- baseline_plot_stats += baseline_plot_stats[:1]
164
- shift_attn_plot_stats += shift_attn_plot_stats[:1]
165
- angles += angles[:1]
166
-
167
- fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
168
-
169
- # Helper function to find nice plot limits
170
- def get_max_val(*args):
171
- return max(max(lst) for lst in args if lst) * 1.2
172
-
173
- ax.set_ylim(0, get_max_val(baseline_plot_stats, shift_attn_plot_stats))
174
-
175
- # Plot labels
176
- ax.set_xticks(angles[:-1])
177
- ax.set_xticklabels(["Tokens/Sec\n(Higher is Better)", "1 / VRAM\n(Higher is Better)", "1 / Perplexity\n(Higher is Better)", "Logit Sharpening\n(Higher is Better)"])
178
-
179
- # Plot data
180
- ax.plot(angles, baseline_plot_stats, 'o-', linewidth=2, label="Baseline")
181
- ax.fill(angles, baseline_plot_stats, alpha=0.25)
182
- ax.plot(angles, shift_attn_plot_stats, 'o-', linewidth=2, label="Shift-Attn")
183
- ax.fill(angles, shift_attn_plot_stats, alpha=0.25)
184
-
185
- ax.set_title("Model Performance Comparison", size=20, color='gray', y=1.1)
186
- ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
187
-
188
- plt.tight_layout()
189
- return fig
190
-
191
- # --- Gradio Interface ---
192
-
193
- def run_comparison(prompt, max_new_tokens):
194
- if not BASELINE_MODEL or not SHIFT_ATTN_MODEL:
195
- raise gr.Error(ERROR_MESSAGE)
196
-
197
- input_ids = ENCODE(prompt)
198
- x = (torch.tensor(input_ids, dtype=torch.long, device=DEVICE)[None, ...])
199
-
200
- # Run both models
201
- baseline_text, baseline_metrics = generate_and_measure(BASELINE_MODEL, x, max_new_tokens)
202
- shift_attn_text, shift_attn_metrics = generate_and_measure(SHIFT_ATTN_MODEL, x, max_new_tokens)
203
-
204
- # Create plot
205
- chart = plot_radar_chart(baseline_metrics, shift_attn_metrics)
206
-
207
- return baseline_text, shift_attn_text, chart
208
-
209
- with gr.Blocks(theme=gr.themes.Base()) as demo:
210
- gr.Markdown("# `shift-attn`: A Live Demonstration")
211
- gr.Markdown(
212
- "This demo compares a baseline `pgptlformer` model against an identical model enhanced with the `shift-attn` mechanism (`attention_deux`). "
213
- "The radar chart visualizes key performance and efficiency metrics, where a larger area indicates a better overall model."
214
- )
215
- with gr.Row():
216
- with gr.Column(scale=1):
217
- prompt_input = gr.Textbox(label="Enter your prompt:", value="The quick brown fox")
218
- token_slider = gr.Slider(minimum=10, maximum=200, value=50, step=1, label="Max New Tokens")
219
- submit_btn = gr.Button("Compare Models", variant="primary")
220
- with gr.Column(scale=2):
221
- plot_output = gr.Plot(label="Performance Radar Chart")
222
-
223
- with gr.Row():
224
- baseline_output = gr.Textbox(label="Baseline Model Output", lines=8)
225
- shift_attn_output = gr.Textbox(label="Shift-Attn Model Output", lines=8)
226
-
227
- submit_btn.click(
228
- fn=run_comparison,
229
- inputs=[prompt_input, token_slider],
230
- outputs=[baseline_output, shift_attn_output, plot_output]
231
- )
232
-
233
- if __name__ == "__main__":
234
  demo.launch()
 
1
+ import gradio as gr
2
+ import torch
3
+ import time
4
+ import os
5
+ from huggingface_hub import hf_hub_download
6
+ import tiktoken
7
+ import pgptlformer # Your model definition file
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ from contextlib import nullcontext
11
+
12
+ # --- Configuration ---
13
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
14
+ DTYPE = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
15
+ PTDTYPE = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[DTYPE]
16
+ CTX = nullcontext() if DEVICE == 'cpu' else torch.amp.autocast(device_type=DEVICE, dtype=PTDTYPE)
17
+ TORCH_COMPILE = False # Gradio instances can be slow, so compilation might timeout. Set to False for stability.
18
+
19
+ # --- Model Loading ---
20
+
21
+ @torch.no_grad()
22
+ def load_model(repo_id, filename, config_override=None):
23
+ """Loads a model from the Hugging Face Hub."""
24
+ print(f"Loading model: {repo_id}/{filename}...")
25
+ try:
26
+ ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
27
+ checkpoint = torch.load(ckpt_path, map_location=DEVICE)
28
+
29
+ tformer_cfg = checkpoint['model_args']
30
+ if config_override:
31
+ tformer_cfg.update(config_override)
32
+
33
+ model = pgptlformer.PGPT_Lformer(tformer_cfg)
34
+ state_dict = checkpoint['model']
35
+
36
+ # Clean up state dict if needed
37
+ unwanted_prefix = '_orig_mod.'
38
+ for k, v in list(state_dict.items()):
39
+ if k.startswith(unwanted_prefix):
40
+ state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
41
+
42
+ model.load_state_dict(state_dict, strict=False) # Use strict=False for flexibility
43
+ model.eval()
44
+ model.to(DEVICE)
45
+
46
+ if TORCH_COMPILE:
47
+ model = torch.compile(model)
48
+
49
+ print(f"Model {filename} loaded successfully.")
50
+ return model, tformer_cfg
51
+ except Exception as e:
52
+ print(f"Error loading model {filename}: {e}")
53
+ raise
54
+
55
+ # Load both models once at the start
56
+ try:
57
+ # This is the baseline model from your portfolio
58
+ BASELINE_MODEL, BASELINE_CFG = load_model(
59
+ repo_id="SQCU/pgptlformer-tinystories",
60
+ filename="state_step040500.pt"
61
+ )
62
+
63
+ # This is the shift-attn model. Note the config_override.
64
+ SHIFT_ATTN_MODEL, SHIFT_ATTN_CFG = load_model(
65
+ repo_id="SQCU/pgptlformer-tinystories",
66
+ filename="re-pqt-rmsXrmsx2x2-ATTNII-791967c5-5c59-4a5f-a2c5-07772bcf65ab/state_step040500.pt",
67
+ config_override={"attention_deux": True} # Crucial: This enables the shift-attn mechanism in your code
68
+ )
69
+ except Exception as e:
70
+ # If loading fails, show an error in the Gradio app instead of crashing
71
+ BASELINE_MODEL, SHIFT_ATTN_MODEL = None, None
72
+ ERROR_MESSAGE = f"Failed to load models. Please check logs. Error: {e}"
73
+
74
+
75
+ # --- Inference and Metrics ---
76
+
77
+ ENC = tiktoken.get_encoding("gpt2")
78
+ ENCODE = lambda s: ENC.encode(s, allowed_special={"<|endoftext|>"})
79
+ DECODE = lambda l: ENC.decode(l)
80
+
81
+ @torch.no_grad()
82
+ def generate_and_measure(model, prompt_ids, max_new_tokens=50):
83
+ """Runs inference and calculates metrics."""
84
+ # Reset stats for this run
85
+ if DEVICE == 'cuda':
86
+ torch.cuda.reset_peak_memory_stats(DEVICE)
87
+ torch.cuda.synchronize()
88
+
89
+ start_time = time.time()
90
+
91
+ # --- Generation Loop ---
92
+ model_logits = []
93
+ generated_ids = prompt_ids
94
+ for _ in range(max_new_tokens):
95
+ idx_cond = generated_ids if generated_ids.size(1) <= 1024 else generated_ids[:, -1024:]
96
+ logits, _, _ = model(idx_cond, return_logits=True)
97
+
98
+ final_logits = logits[:, -1, :]
99
+ model_logits.append(final_logits) # Store logits for perplexity/sharpening calc
100
+
101
+ probs = torch.nn.functional.softmax(final_logits, dim=-1)
102
+ idx_next = torch.multinomial(probs, num_samples=1)
103
+ generated_ids = torch.cat((generated_ids, idx_next), dim=1)
104
+
105
+ if DEVICE == 'cuda':
106
+ torch.cuda.synchronize()
107
+ end_time = time.time()
108
+
109
+ # --- Metrics Calculation ---
110
+ # 1. Inference Speed (Tokens/Second)
111
+ tokens_per_sec = max_new_tokens / (end_time - start_time)
112
+
113
+ # 2. VRAM Usage (MB)
114
+ vram_usage = torch.cuda.max_memory_allocated(DEVICE) / (1024**2) if DEVICE == 'cuda' else 0
115
+
116
+ # 3. Pseudo-Perplexity
117
+ all_logits = torch.cat(model_logits, dim=0)
118
+ target_ids = generated_ids[0, -max_new_tokens:]
119
+ cross_entropy = torch.nn.functional.cross_entropy(all_logits, target_ids)
120
+ pseudo_perplexity = torch.exp(cross_entropy).item()
121
+
122
+ # 4. Logit Sharpening (Average of max probability)
123
+ avg_max_prob = torch.nn.functional.softmax(all_logits, dim=-1).max(dim=-1).values.mean().item()
124
+
125
+ # --- Decode and Return ---
126
+ output_text = DECODE(generated_ids[0].tolist())
127
+
128
+ metrics = {
129
+ 'Tokens/Sec': tokens_per_sec,
130
+ 'VRAM (MB)': vram_usage,
131
+ 'Perplexity': pseudo_perplexity,
132
+ 'Logit Sharpening': avg_max_prob,
133
+ }
134
+
135
+ return output_text, metrics
136
+
137
+ # --- Visualization ---
138
+
139
+ def plot_radar_chart(baseline_metrics, shift_attn_metrics):
140
+ """Creates a radar chart comparing the two models."""
141
+ labels = list(baseline_metrics.keys())
142
+ baseline_stats = list(baseline_metrics.values())
143
+ shift_attn_stats = list(shift_attn_metrics.values())
144
+
145
+ # Normalize stats for plotting. Higher is better for all metrics on the chart.
146
+ # We will take the inverse of Perplexity and VRAM for a "higher is better" visualization.
147
+ baseline_plot_stats = [
148
+ baseline_stats[0], # Tokens/Sec (Higher is better)
149
+ 1 / (baseline_stats[1] + 1e-6), # VRAM (Inverse)
150
+ 1 / (baseline_stats[2] + 1e-6), # Perplexity (Inverse)
151
+ baseline_stats[3] # Sharpening (Higher is better)
152
+ ]
153
+ shift_attn_plot_stats = [
154
+ shift_attn_stats[0],
155
+ 1 / (shift_attn_stats[1] + 1e-6),
156
+ 1 / (shift_attn_stats[2] + 1e-6),
157
+ shift_attn_stats[3]
158
+ ]
159
+
160
+ angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
161
+
162
+ # Make the plot circular
163
+ baseline_plot_stats += baseline_plot_stats[:1]
164
+ shift_attn_plot_stats += shift_attn_plot_stats[:1]
165
+ angles += angles[:1]
166
+
167
+ fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
168
+
169
+ # Helper function to find nice plot limits
170
+ def get_max_val(*args):
171
+ return max(max(lst) for lst in args if lst) * 1.2
172
+
173
+ ax.set_ylim(0, get_max_val(baseline_plot_stats, shift_attn_plot_stats))
174
+
175
+ # Plot labels
176
+ ax.set_xticks(angles[:-1])
177
+ ax.set_xticklabels(["Tokens/Sec\n(Higher is Better)", "1 / VRAM\n(Higher is Better)", "1 / Perplexity\n(Higher is Better)", "Logit Sharpening\n(Higher is Better)"])
178
+
179
+ # Plot data
180
+ ax.plot(angles, baseline_plot_stats, 'o-', linewidth=2, label="Baseline")
181
+ ax.fill(angles, baseline_plot_stats, alpha=0.25)
182
+ ax.plot(angles, shift_attn_plot_stats, 'o-', linewidth=2, label="Shift-Attn")
183
+ ax.fill(angles, shift_attn_plot_stats, alpha=0.25)
184
+
185
+ ax.set_title("Model Performance Comparison", size=20, color='gray', y=1.1)
186
+ ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
187
+
188
+ plt.tight_layout()
189
+ return fig
190
+
191
+ # --- Gradio Interface ---
192
+
193
+ def run_comparison(prompt, max_new_tokens):
194
+ if not BASELINE_MODEL or not SHIFT_ATTN_MODEL:
195
+ raise gr.Error(ERROR_MESSAGE)
196
+
197
+ input_ids = ENCODE(prompt)
198
+ x = (torch.tensor(input_ids, dtype=torch.long, device=DEVICE)[None, ...])
199
+
200
+ # Run both models
201
+ baseline_text, baseline_metrics = generate_and_measure(BASELINE_MODEL, x, max_new_tokens)
202
+ shift_attn_text, shift_attn_metrics = generate_and_measure(SHIFT_ATTN_MODEL, x, max_new_tokens)
203
+
204
+ # Create plot
205
+ chart = plot_radar_chart(baseline_metrics, shift_attn_metrics)
206
+
207
+ return baseline_text, shift_attn_text, chart
208
+
209
+ with gr.Blocks(theme=gr.themes.Base()) as demo:
210
+ gr.Markdown("# `shift-attn`: A Live Demonstration")
211
+ gr.Markdown(
212
+ "This demo compares a baseline `pgptlformer` model against an identical model enhanced with the `shift-attn` mechanism (`attention_deux`). "
213
+ "The radar chart visualizes key performance and efficiency metrics, where a larger area indicates a better overall model."
214
+ )
215
+ with gr.Row():
216
+ with gr.Column(scale=1):
217
+ prompt_input = gr.Textbox(label="Enter your prompt:", value="The quick brown fox")
218
+ token_slider = gr.Slider(minimum=10, maximum=200, value=50, step=1, label="Max New Tokens")
219
+ submit_btn = gr.Button("Compare Models", variant="primary")
220
+ with gr.Column(scale=2):
221
+ plot_output = gr.Plot(label="Performance Radar Chart")
222
+
223
+ with gr.Row():
224
+ baseline_output = gr.Textbox(label="Baseline Model Output", lines=8)
225
+ shift_attn_output = gr.Textbox(label="Shift-Attn Model Output", lines=8)
226
+
227
+ submit_btn.click(
228
+ fn=run_comparison,
229
+ inputs=[prompt_input, token_slider],
230
+ outputs=[baseline_output, shift_attn_output, plot_output]
231
+ )
232
+
233
+ if __name__ == "__main__":
234
  demo.launch()
pgptlformer.py CHANGED
@@ -1,454 +1,454 @@
1
- # Import necessary, revised, libraries
2
- import torch
3
- import torch.nn as nn
4
- import torch.optim as optim
5
-
6
- #dubious
7
- from torch.utils.data import DataLoader, TensorDataset
8
-
9
- #hehe
10
- import math
11
-
12
- ### note: modded_nanogpt.py is an more full container for a transformer block structure
13
- ### it should specify an encoder ("embedder") and decoder for autoregress on tinystories.
14
- ###
15
-
16
-
17
-
18
- ### 2302.05442's qk-layernorm is layernorm without centering and biases omitted.
19
- ### this is not equivalent to applying rmsnorm to the lexical scope of layernorm,
20
- ### as rmsnorm (1910.07467) doesn't use the mean statistic to yield variance.
21
- ### profiling and benchmarking a p%-pvarnorm would be great further work!
22
- ###
23
- ### to reach the complete spec of 2302.05442,
24
- ### noncentered nonbiased norms must be applied to projected q&k
25
- ###
26
- ### candidate default: cfig =
27
- ### {"dim":768,"dim_head":128,"headcount":6,"ffmult":4,
28
- ### "lambda":False,"layerwisenorm":"layernorm","qknorm":"identitynorm"}
29
- ### candidate tinystories:
30
- ### {"dim":256,"dim_head":32,"headcount":8,"ffmult":4,
31
- ### "lambda":True,"layerwisenorm":"rmsnorm","qknorm":"identitynorm"}
32
- ###
33
- ### 2401.14489 suggests GEneral Matrix Multiplication dim alignment.
34
- ### basically vocabsize%64:=0.
35
- ### and (emb_dim/heads)%2^k:=0 , for some integer k.
36
- ### and (batchsize*sequence_len)%2^k:=0 , for some integer k.
37
- ### this places the smallest possible seqlen at 64@bf16 and 128@fp8
38
- ###
39
- ### ...
40
- ### the swiglu returns to bite us. the presence of that doubled swiggy matrix does something!
41
- ### specifically. uh.
42
- ### actually because we cranked up the swiggy_dim by 2x, it follows all of our scaling rules
43
- ### lmao, lol, lol, lmao, etcetera.
44
- class vit22_tformer(nn.Module):
45
- def __init__(self, config):
46
- super().__init__()
47
- #query_dim = config["query_dim"] #don't even think about cross_attention
48
- self.dim = config["dim"]
49
- self.dim_head = config["dim_head"]
50
- self.heads = config["headcount"]
51
- self.weighted_skipnet = config["lambda"]
52
- self.denseproj_mul = config["ff_mult"]
53
- #self.naive_causal = config["is_causal_llm"]
54
- #...
55
- #self.qknormalized_shape = [config["dim_head"],config["training_seqlen"],config["headcount"],config["dim_head"],]
56
- self.qknormalized_shape = [config["headcount"],config["dim_head"]]
57
- self.layerwisenorm = getnorm(config["layerwisenorm"],shape=self.dim)
58
- self.projnorm = getnorm(config["qknorm"],shape=self.qknormalized_shape)
59
-
60
- attn_inner_dim = self.dim_head * self.heads
61
- self.denseproj_inner_dim = self.dim * self.denseproj_mul
62
-
63
- if "rotary_embedding_base" in config.keys():
64
- self.rotbase = config["rotary_embedding_base"]
65
- else:
66
- self.rotbase = 1000 # hehe
67
-
68
- self.attention_II = None
69
- if "attention_deux" in config.keys():
70
- self.attention_II = True
71
-
72
-
73
- self.rotary = rotarizer(self.dim_head, base=self.rotbase)
74
- self.learnedlambda = nn.Parameter(torch.tensor(1.0)) #my beloved
75
- self.fused_swiglu_dim = self.denseproj_inner_dim*2 #this is necessary so the swiglu's two projections can be applied as a single operation.
76
- self.scale = self.dim_head**-0.5 #this is the 's' in 's'dpa! #exposed for cosine attention reasons!
77
- self.l2normscale = None
78
- if config["qknorm"] == "l2norm": #bootleg cosine attention by overloading the scale term in sdpa
79
- self.l2normscale = nn.Parameter(torch.log(torch.tensor(config["training_seqlen"]**2)-torch.tensor(config["training_seqlen"])))
80
-
81
- #...
82
- self.queryproj = nn.Linear(in_features=self.dim, out_features=self.dim, bias=False)
83
- self.keyproj = nn.Linear(in_features=self.dim, out_features=self.dim, bias=False)
84
- self.valueproj = nn.Linear(in_features=self.dim, out_features=self.dim, bias=False)
85
- self.attnoutproj = nn.Linear(in_features=attn_inner_dim, out_features=self.dim, bias=True)
86
-
87
- if self.attention_II:
88
- self.queryBproj = nn.Linear(in_features=self.dim, out_features=self.dim, bias=False)
89
- self.keyBproj = nn.Linear(in_features=self.dim, out_features=self.dim, bias=False)
90
-
91
- #dense ('mlp', 'feedforward', 'fully connected', ...) unit
92
- self.fused_denseproj_in = nn.Linear(in_features=self.dim, out_features=self.fused_swiglu_dim, bias=True) #this is the vit22b part
93
- self.dense_swiggy = swiglu() #this is kind of superfluous but this is pedagogical programming!
94
- self.denseproj_out = nn.Linear(in_features=self.denseproj_inner_dim, out_features=self.dim, bias=True)
95
-
96
- #[x]
97
- def self_attn(self, x, bat_len, seq_len):
98
- #norm -> {qkvproj -> qknorm{?}
99
- #reshape_h_d -> attn -> reshape_d_h} -> attnoutproj
100
- #project
101
- query = self.queryproj(x)
102
- key = self.keyproj(x)
103
- value = self.valueproj(x)
104
-
105
- if self.attention_II:
106
- biasquery = self.queryBproj(x)
107
- biaskey = self.keyBproj(x)
108
-
109
-
110
- #reshape to bundled up matmul formme
111
- #query = reshape_heads_dim(self.heads, query)
112
- #key = reshape_heads_dim(self.heads, key)
113
- #value = reshape_heads_dim(self.heads, value)
114
- #alternate reshape for compatibility with modded-nanogpt roformer
115
- query = query.view(bat_len, seq_len, self.heads, self.dim_head)
116
- key = key.view(bat_len, seq_len, self.heads, self.dim_head)
117
- value = value.view(bat_len, seq_len, self.heads, self.dim_head)
118
-
119
- if self.attention_II:
120
- biasquery = biasquery.view(bat_len, seq_len, self.heads, self.dim_head)
121
- biaskey = biaskey.view(bat_len, seq_len, self.heads, self.dim_head)
122
-
123
- #pos_emb suggested before qknorm re: kellerjordan re: @Grad62304977
124
- #but we get an error for the x.ndim assertion if we run this after reshaping. whoopsie!
125
- cos, sin = self.rotary(query) #our rotary unit does the shape detection from states
126
-
127
- #qk*norm
128
- query = self.projnorm(query)
129
- key = self.projnorm(key)
130
-
131
- if self.attention_II:
132
- biasquery = self.projnorm(biasquery)
133
- biaskey = self.projnorm(biaskey)
134
-
135
- #rotary embed after qknorm as suggested etc.
136
- query = apply_rotarizer_emb(query, cos, sin)
137
- key = apply_rotarizer_emb(key, cos, sin)
138
-
139
- if self.attention_II:
140
- biasquery = apply_rotarizer_emb(biasquery, cos, sin)
141
- biaskey = apply_rotarizer_emb(biaskey, cos, sin)
142
-
143
- #laser-attn goes here
144
- #...
145
-
146
- #if we were here to explain attention instead of projections and norms,
147
- #we would have written this in jax or a language that compiles well!
148
- #instead, to benefit from flash attention 2, we want to use torch SDPA!
149
- if self.l2normscale is not None:
150
- y = self.l2normscale*nn.functional.scaled_dot_product_attention(query.transpose(1,2), key.transpose(1,2), value.transpose(1,2), scale=1, is_causal=True)
151
- else:
152
- y = nn.functional.scaled_dot_product_attention(query.transpose(1,2), key.transpose(1,2), value.transpose(1,2), scale=self.scale, is_causal=True)
153
-
154
- if self.attention_II:
155
- #REV1
156
- dud = torch.ones_like(value, dtype=query.dtype, device=query.device)
157
- y = y + scaled_dot_product_attn_bias( #~~attempt to reuse whatever efficient kernels we have already~~ nvm
158
- biasquery.transpose(1,2) , biaskey.transpose(1,2) , dud.transpose(1,2),
159
- scale=self.scale, is_causal=True
160
- )
161
- """
162
- #REV2
163
- #attn_bias now sums the shift matrix within the attn_bias operation to our 'value' target.
164
- y = scaled_dot_product_attn_bias( #~~attempt to reuse whatever efficient kernels we have already~~ nvm
165
- biasquery.transpose(1,2), biaskey.transpose(1,2), y,
166
- scale=self.scale, is_causal=True
167
- )
168
- """
169
-
170
- #reshape scalars from folded position to unfolded position so the ribosome can read the messenger headrna
171
- #y = self.reshape_dim_heads(self.heads, y)
172
- #alternate reshape scalars
173
- y = y.transpose(1,2).contiguous().view_as(x) #thanks a bunch modded-nanogpt
174
-
175
- #laser-attn unscale goes here
176
- #...
177
-
178
- return self.attnoutproj(y)
179
-
180
- #[x]
181
- def feedfor(self,x):
182
- x = self.fused_denseproj_in(x)
183
- x = self.dense_swiggy(x)
184
- x = self.denseproj_out(x)
185
- return x
186
-
187
- #parallel forward from kingoflolz/mesh-transformer-jax/! check it out!!
188
- # "discovered by Wang et al + EleutherAI from GPT-J fame"
189
- def forward(self, h_states):
190
- # in trad dialect: b->batch, n,i,j,k,l,m,f,a,o -> sequentiality dims, h->heads, d->embedding dim
191
- bat_len, seq_len, emb_dim = h_states.size()
192
- # ^ detritus from modded-nanogpt transpose implementation. profile later ig.
193
-
194
- # highly traditional pre layernorm
195
- inner_states = self.layerwisenorm(h_states)
196
-
197
- #crunchy parts
198
- attn_out = self.self_attn(inner_states, bat_len, seq_len)
199
- dense_out = self.feedfor(inner_states)
200
- if self.weighted_skipnet==True:
201
- skip_out = h_states*self.learnedlambda
202
- else:
203
- skip_out = h_states
204
- #output w/ unabstracted resnet
205
- return skip_out + dense_out + attn_out
206
-
207
- def getnorm(type, shape=None):
208
- if type == "layernorm":
209
- return nn.LayerNorm(shape, elementwise_affine=True, bias=True)
210
- elif type == "layernorm-nobias":
211
- return nn.LayerNorm(shape, elementwise_affine=True, bias=False) #???
212
- elif type == "rmsnorm":
213
- return nn.RMSNorm(shape, elementwise_affine=False)
214
- elif type == "dynamic_shape_rmsnorm":
215
- return dynamic_shape_rmsnorm()
216
- elif type == "dynamic_shape_layernorm":
217
- return dynamic_shape_layernorm()
218
- elif type == "l2norm":
219
- return l2norm() #un function
220
- elif type == "identitynorm":
221
- return identitynorm(shape)
222
- else:
223
- raise Exception("Not implemented")
224
-
225
- class l2norm(nn.Module): #haha
226
- def forward(self, inputter, **kwargs):
227
- inputter = nn.functional.normalize(inputter, p=2, dim=-1)
228
- return inputter
229
-
230
- def identitynorm(row):
231
- return nn.Identity(row)
232
-
233
- #from `questions/76067020/`` lol
234
- class dynamic_shape_rmsnorm(nn.Module):
235
- def forward(self, inputter, **kwargs):
236
- inputter = inputter.transpose(1,2) #rotate!
237
- #i am so sorry haha
238
- #normalized_shape seems to require adjacencies, i tried a few other things first.
239
- #wait the notation in the paper suggests... [3:].
240
- inner_shape = inputter.size()[3:]
241
-
242
- nn.functional.rms_norm(inputter, normalized_shape=inner_shape, **kwargs)
243
- inputter = inputter.transpose(1,2) #reverse rotate!
244
- return inputter
245
-
246
- class dynamic_shape_layernorm(nn.Module):
247
- def forward(self, inputter, **kwargs):
248
- inputter = inputter.transpose(1,2) #rotate!
249
- #i am so sorry haha
250
- #normalized_shape seems to require adjacencies, i tried a few other things first.
251
- #wait the notation in the paper suggests... [3:].
252
- inner_shape = inputter.size()[3:]
253
-
254
- nn.functional.layer_norm(inputter, normalized_shape=inner_shape, **kwargs)
255
- inputter = inputter.transpose(1,2) #reverse rotate!
256
- return inputter
257
-
258
- #we too are hitting that mfing noam shazeer https://arxiv.org/pdf/2002.05202
259
- #if there was a self-gated ELU id want to use it instead though
260
- class swiglu(nn.Module):
261
- def forward(self, x):
262
- x, gate = x.chunk(2, dim=-1)
263
- return nn.functional.silu(gate) * x
264
-
265
- #rippin this one from modded-nanogpt
266
- class rotarizer(nn.Module):
267
- def __init__(self, dim, base=1000): #shhh don't tell anyone about the rotemb base
268
- super().__init__()
269
- self.inv_freq = (base ** (torch.arange(0,dim,2).float() / dim))**-1
270
- self.seq_len_cached = None
271
- self.cos_cached = None
272
- self.sin_cached = None
273
-
274
- def forward(self, x):
275
- seq_len = x.size()[1] #perform the surgical LENGTH,YOINKEMS,{}, b {n} h d
276
- #using torch tensor.size()[idx] notation bc i think it is more explicit than shape[]
277
- if seq_len != self.seq_len_cached:
278
- self.seq_len_cached = seq_len
279
- t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
280
- reg_freqs = torch.outer(t, self.inv_freq).to(x.device)
281
- self.cos_cached = reg_freqs.cos().bfloat16()
282
- self.sin_cached = reg_freqs.sin().bfloat16()
283
- return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]
284
- #yeah slay em with the list comprehensions, cited author 😒
285
-
286
- def apply_rotarizer_emb(x, cos, sin):
287
- #assert x.ndim == 4 # b n h d
288
- d = x.size()[3]//2 # perform the superb DIVIDE,2,LENGTH,YOINKEMS,{}, b n h {d}
289
- x1 = x[..., :d] #some kind of slicing mystery code
290
- x2 = x[..., d:]
291
- y1 = x1 * cos + x2 * sin
292
- y2 = x1 * (-sin) + x2 * cos
293
- return torch.cat([y1,y2], 3).type_as(x)
294
-
295
- #alternate attention to retrieve a shift matrix instead of scale matrix.
296
- #this will either break the first time it runs or make perfect sense whomstdve doubted it all along
297
- #REVISION 1:
298
- #"""
299
- def scaled_dot_product_attn_bias(query, key, value, attn_mask=None, dropout_p=0.0,
300
- is_causal=False, scale=None, enable_gqa=False):
301
- #make sure you compile this or it will be slow! haha! it will be slow otherwise! haha!
302
- L, S = query.size(-2), key.size(-2)
303
- scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
304
- #inversion of normal masking since we're not softmaxing
305
- attn_bias = torch.ones(L, S, dtype=query.dtype, device=query.device)
306
-
307
- if is_causal: #sounds caus-tly to change haha heeehehee
308
- assert attn_mask is None
309
- temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
310
- attn_bias.masked_fill_(temp_mask.logical_not(), float("0")) #0 not neginf
311
- attn_bias.to(query.dtype)
312
-
313
- if attn_mask is not None: #more boilerplate ty pytorch
314
- if attn_mask.dtype == torch.bool:
315
- attn_bias.masked_fill_(attn_mask.logical_not(), float("0")) #0 not neginf
316
- else:
317
- attn_bias *= attn_mask
318
-
319
- if enable_gqa: #who can say what this does
320
- key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
321
- value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
322
-
323
- attn_magnitude = torch.matmul(query, key.transpose(-2, -1)) * scale
324
- attn_magnitude *= attn_bias #* to combine instead of +
325
- #attn_magnitude = torch.softmax(attn_weight, dim=-1) we dont want this lol
326
- attn_magnitude = torch.dropout(attn_magnitude, dropout_p, train=True)
327
- return attn_magnitude @ value
328
- #"""
329
- #REVISION 2: this doesn't benefit from abstract syntactic similarity to torch sdpa. so we gut it!
330
- #instead of creating a duds matrix of 1s to occupy the 'value' idx, we sum the shift-QK product directly
331
- #uncompiled this maybe uses fewer ops; profile and find out.
332
- """
333
- def scaled_dot_product_attn_bias(query, key, value, attn_mask=None, dropout_p=0.0,
334
- is_causal=False, scale=None, enable_gqa=False):
335
- #make sure you compile this or it will be slow! haha! it will be slow otherwise! haha!
336
- L, S = query.size(-2), key.size(-2)
337
- scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
338
- #inversion of normal masking since we're not softmaxing
339
- attn_bias = torch.ones(L, S, dtype=query.dtype, device=query.device)
340
-
341
- if is_causal: #sounds caus-tly to change haha heeehehee
342
- assert attn_mask is None
343
- temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
344
- attn_bias.masked_fill_(temp_mask.logical_not(), float("0")) #0 not neginf
345
- attn_bias.to(query.dtype)
346
-
347
- if attn_mask is not None: #more boilerplate ty pytorch
348
- if attn_mask.dtype == torch.bool:
349
- attn_bias.masked_fill_(attn_mask.logical_not(), float("0")) #0 not neginf
350
- else:
351
- attn_bias *= attn_mask
352
-
353
- if enable_gqa: #who can say what this does
354
- key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
355
- value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
356
-
357
- attn_magnitude = torch.matmul(query, key.transpose(-2, -1)) * scale
358
- attn_magnitude *= attn_bias #* to combine instead of +
359
- #attn_magnitude = torch.softmax(attn_weight, dim=-1) we dont want this lol
360
- attn_magnitude = torch.dropout(attn_magnitude, dropout_p, train=True)
361
- #... broadcasting... if A generalmatmul B, and A has shape (N x 1), B has shape (m x p),
362
- # 1 is prepended to A in torch broadcasting. then A matmul B. then prepend removed.
363
- # inplace prepend 1: a.unsqueeze_(0).
364
- #
365
- #attn_mag : b h n h_d ...
366
- #no it *wasn't*! it's b, h, n, n!
367
- #sdpa output (our v input) without transpose is
368
- # b h n h_d
369
- #so maybe we need to transpose sdpa_out by (-2, -1)
370
- #such that sdpa_out : b h h_d n, allowing
371
- #torch bmm of mat1:...{n X n} & mat2:...{h_d X n} --> bmmout: ...{h_d X n}
372
- #print(attn_magnitude.size())
373
- #print(value.size())
374
- #attn_magnitude.unsqueeze_(0)
375
- #...
376
- #wow okay this is tricky. we were using a ones row reduce.
377
- #basically last night we were assembling a (b h n_1 n_2) shape through biasq and biask matmul.
378
- #then we multiplied it by a ones of (b h n h_d),
379
- #which reduces b h n (n) to b h n (h_d)...
380
- #where h_d rows are copies of sum(n_2) at each n index in (b h n h_d).
381
- #meaning attn_II_rev1 was assigning a headwise bias along the entire sequence.
382
- #which itself would be chosen by the optimizer state transformation evolution of biasq and biask.
383
- return torch.add(attn_magnitude, value.transpose(-2,-1))
384
- """
385
-
386
-
387
- ### states take format batch, sequence, embedding
388
- ### therefore
389
- ### batch_size, sequence_length, embedding_dim = h_states.shape
390
- def reshape_heads_dim(heads, tensor):
391
- bat_len, seq_len, emb_dim = tensor.size()
392
- head_len = heads
393
- # i think equivalent to traditional
394
- # "b n (h d) -> b h n d"
395
- tensor = tensor.reshape(bat_len , seq_len, head_len, emb_dim // head_len)
396
- tensor = tensor.permute(0, 2, 1, 3).reshape(bat_len*head_len, seq_len, emb_dim // head_len)
397
- return tensor
398
-
399
- def reshape_dim_heads(heads, tensor):
400
- bat_len, seq_len, emb_dim = tensor.size()
401
- head_len = heads
402
- # i think equivalent to traditional
403
- # "b h n d -> b n (h d)"
404
- tensor = tensor.reshape(bat_len // head_len, head_len, seq_len, emb_dim)
405
- tensor = tensor.permute(0, 2, 1, 3).reshape(bat_len // head_len, seq_len, emb_dim*head_len)
406
- return tensor
407
-
408
-
409
- ###
410
- ### modelwise config:
411
- ### {"vocab_size":8000, "num_layers":4}
412
- ###
413
- class PGPT_Lformer(nn.Module):
414
- def __init__(self,config):
415
- super().__init__()
416
- self.config = config
417
-
418
- self.lambdaformer = nn.ModuleDict(dict(
419
- what_the_embedder_doin = nn.Embedding(config["vocab_size"], config["dim"]),
420
- blocks = nn.ModuleList([vit22_tformer(config) for _ in range(config["num_layers"])])
421
- ))
422
- self.tokenpicker_head = nn.Linear(in_features=config["dim"], out_features=config["vocab_size"], bias=False)
423
- self.tokenpicker_head.weight.data.zero_() #re: @Grad62304977
424
-
425
- def forward(self, index, targets=None, return_logits=True, return_zloss=False):
426
- x = self.lambdaformer.what_the_embedder_doin(index) # get token embeddings
427
- x = nn.functional.rms_norm(x, (x.size(-1),)) #re: @Grad62304977
428
- for decoder in self.lambdaformer.blocks:
429
- x = decoder(x)
430
- x = nn.functional.rms_norm(x, (x.size(-1),)) #re: @Grad62304977
431
-
432
- if targets is not None:
433
- #grab some losses woooo
434
- logits = self.tokenpicker_head(x)
435
- if return_zloss: #tracking https://arxiv.org/abs/2309.14322
436
- z = torch.sum(torch.exp(logits)) #reduce: e^logit[j]
437
- z_loss = torch.log(z)**2 #log and square Z. make sure to set a coefficient in trainer!
438
- logits = 30 * torch.tanh(logits / 30) # @Grad62304977
439
- logits = logits.float() # use tf32/fp32 for logits
440
- loss = nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
441
- else:
442
- #kellerjordan optimi
443
- logits = self.tokenpicker_head(x[:, [-1], :]) # re: kj: note: using list [-1] to preserve the time dim
444
- logits = 30 * torch.tanh(logits / 30) # @Grad62304977
445
- logits = logits.float() # use tf32/fp32 for logits
446
- loss = None
447
-
448
- #an appeal to performance is made:
449
- if not return_logits:
450
- logits = None
451
- if not return_zloss:
452
- z_loss = None
453
-
454
  return logits, loss, z_loss
 
1
+ # Import necessary, revised, libraries
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+
6
+ #dubious
7
+ from torch.utils.data import DataLoader, TensorDataset
8
+
9
+ #hehe
10
+ import math
11
+
12
+ ### note: modded_nanogpt.py is an more full container for a transformer block structure
13
+ ### it should specify an encoder ("embedder") and decoder for autoregress on tinystories.
14
+ ###
15
+
16
+
17
+
18
+ ### 2302.05442's qk-layernorm is layernorm without centering and biases omitted.
19
+ ### this is not equivalent to applying rmsnorm to the lexical scope of layernorm,
20
+ ### as rmsnorm (1910.07467) doesn't use the mean statistic to yield variance.
21
+ ### profiling and benchmarking a p%-pvarnorm would be great further work!
22
+ ###
23
+ ### to reach the complete spec of 2302.05442,
24
+ ### noncentered nonbiased norms must be applied to projected q&k
25
+ ###
26
+ ### candidate default: cfig =
27
+ ### {"dim":768,"dim_head":128,"headcount":6,"ffmult":4,
28
+ ### "lambda":False,"layerwisenorm":"layernorm","qknorm":"identitynorm"}
29
+ ### candidate tinystories:
30
+ ### {"dim":256,"dim_head":32,"headcount":8,"ffmult":4,
31
+ ### "lambda":True,"layerwisenorm":"rmsnorm","qknorm":"identitynorm"}
32
+ ###
33
+ ### 2401.14489 suggests GEneral Matrix Multiplication dim alignment.
34
+ ### basically vocabsize%64:=0.
35
+ ### and (emb_dim/heads)%2^k:=0 , for some integer k.
36
+ ### and (batchsize*sequence_len)%2^k:=0 , for some integer k.
37
+ ### this places the smallest possible seqlen at 64@bf16 and 128@fp8
38
+ ###
39
+ ### ...
40
+ ### the swiglu returns to bite us. the presence of that doubled swiggy matrix does something!
41
+ ### specifically. uh.
42
+ ### actually because we cranked up the swiggy_dim by 2x, it follows all of our scaling rules
43
+ ### lmao, lol, lol, lmao, etcetera.
44
+ class vit22_tformer(nn.Module):
45
+ def __init__(self, config):
46
+ super().__init__()
47
+ #query_dim = config["query_dim"] #don't even think about cross_attention
48
+ self.dim = config["dim"]
49
+ self.dim_head = config["dim_head"]
50
+ self.heads = config["headcount"]
51
+ self.weighted_skipnet = config["lambda"]
52
+ self.denseproj_mul = config["ff_mult"]
53
+ #self.naive_causal = config["is_causal_llm"]
54
+ #...
55
+ #self.qknormalized_shape = [config["dim_head"],config["training_seqlen"],config["headcount"],config["dim_head"],]
56
+ self.qknormalized_shape = [config["headcount"],config["dim_head"]]
57
+ self.layerwisenorm = getnorm(config["layerwisenorm"],shape=self.dim)
58
+ self.projnorm = getnorm(config["qknorm"],shape=self.qknormalized_shape)
59
+
60
+ attn_inner_dim = self.dim_head * self.heads
61
+ self.denseproj_inner_dim = self.dim * self.denseproj_mul
62
+
63
+ if "rotary_embedding_base" in config.keys():
64
+ self.rotbase = config["rotary_embedding_base"]
65
+ else:
66
+ self.rotbase = 1000 # hehe
67
+
68
+ self.attention_II = None
69
+ if "attention_deux" in config.keys():
70
+ self.attention_II = True
71
+
72
+
73
+ self.rotary = rotarizer(self.dim_head, base=self.rotbase)
74
+ self.learnedlambda = nn.Parameter(torch.tensor(1.0)) #my beloved
75
+ self.fused_swiglu_dim = self.denseproj_inner_dim*2 #this is necessary so the swiglu's two projections can be applied as a single operation.
76
+ self.scale = self.dim_head**-0.5 #this is the 's' in 's'dpa! #exposed for cosine attention reasons!
77
+ self.l2normscale = None
78
+ if config["qknorm"] == "l2norm": #bootleg cosine attention by overloading the scale term in sdpa
79
+ self.l2normscale = nn.Parameter(torch.log(torch.tensor(config["training_seqlen"]**2)-torch.tensor(config["training_seqlen"])))
80
+
81
+ #...
82
+ self.queryproj = nn.Linear(in_features=self.dim, out_features=self.dim, bias=False)
83
+ self.keyproj = nn.Linear(in_features=self.dim, out_features=self.dim, bias=False)
84
+ self.valueproj = nn.Linear(in_features=self.dim, out_features=self.dim, bias=False)
85
+ self.attnoutproj = nn.Linear(in_features=attn_inner_dim, out_features=self.dim, bias=True)
86
+
87
+ if self.attention_II:
88
+ self.queryBproj = nn.Linear(in_features=self.dim, out_features=self.dim, bias=False)
89
+ self.keyBproj = nn.Linear(in_features=self.dim, out_features=self.dim, bias=False)
90
+
91
+ #dense ('mlp', 'feedforward', 'fully connected', ...) unit
92
+ self.fused_denseproj_in = nn.Linear(in_features=self.dim, out_features=self.fused_swiglu_dim, bias=True) #this is the vit22b part
93
+ self.dense_swiggy = swiglu() #this is kind of superfluous but this is pedagogical programming!
94
+ self.denseproj_out = nn.Linear(in_features=self.denseproj_inner_dim, out_features=self.dim, bias=True)
95
+
96
+ #[x]
97
+ def self_attn(self, x, bat_len, seq_len):
98
+ #norm -> {qkvproj -> qknorm{?}
99
+ #reshape_h_d -> attn -> reshape_d_h} -> attnoutproj
100
+ #project
101
+ query = self.queryproj(x)
102
+ key = self.keyproj(x)
103
+ value = self.valueproj(x)
104
+
105
+ if self.attention_II:
106
+ biasquery = self.queryBproj(x)
107
+ biaskey = self.keyBproj(x)
108
+
109
+
110
+ #reshape to bundled up matmul formme
111
+ #query = reshape_heads_dim(self.heads, query)
112
+ #key = reshape_heads_dim(self.heads, key)
113
+ #value = reshape_heads_dim(self.heads, value)
114
+ #alternate reshape for compatibility with modded-nanogpt roformer
115
+ query = query.view(bat_len, seq_len, self.heads, self.dim_head)
116
+ key = key.view(bat_len, seq_len, self.heads, self.dim_head)
117
+ value = value.view(bat_len, seq_len, self.heads, self.dim_head)
118
+
119
+ if self.attention_II:
120
+ biasquery = biasquery.view(bat_len, seq_len, self.heads, self.dim_head)
121
+ biaskey = biaskey.view(bat_len, seq_len, self.heads, self.dim_head)
122
+
123
+ #pos_emb suggested before qknorm re: kellerjordan re: @Grad62304977
124
+ #but we get an error for the x.ndim assertion if we run this after reshaping. whoopsie!
125
+ cos, sin = self.rotary(query) #our rotary unit does the shape detection from states
126
+
127
+ #qk*norm
128
+ query = self.projnorm(query)
129
+ key = self.projnorm(key)
130
+
131
+ if self.attention_II:
132
+ biasquery = self.projnorm(biasquery)
133
+ biaskey = self.projnorm(biaskey)
134
+
135
+ #rotary embed after qknorm as suggested etc.
136
+ query = apply_rotarizer_emb(query, cos, sin)
137
+ key = apply_rotarizer_emb(key, cos, sin)
138
+
139
+ if self.attention_II:
140
+ biasquery = apply_rotarizer_emb(biasquery, cos, sin)
141
+ biaskey = apply_rotarizer_emb(biaskey, cos, sin)
142
+
143
+ #laser-attn goes here
144
+ #...
145
+
146
+ #if we were here to explain attention instead of projections and norms,
147
+ #we would have written this in jax or a language that compiles well!
148
+ #instead, to benefit from flash attention 2, we want to use torch SDPA!
149
+ if self.l2normscale is not None:
150
+ y = self.l2normscale*nn.functional.scaled_dot_product_attention(query.transpose(1,2), key.transpose(1,2), value.transpose(1,2), scale=1, is_causal=True)
151
+ else:
152
+ y = nn.functional.scaled_dot_product_attention(query.transpose(1,2), key.transpose(1,2), value.transpose(1,2), scale=self.scale, is_causal=True)
153
+
154
+ if self.attention_II:
155
+ #REV1
156
+ dud = torch.ones_like(value, dtype=query.dtype, device=query.device)
157
+ y = y + scaled_dot_product_attn_bias( #~~attempt to reuse whatever efficient kernels we have already~~ nvm
158
+ biasquery.transpose(1,2) , biaskey.transpose(1,2) , dud.transpose(1,2),
159
+ scale=self.scale, is_causal=True
160
+ )
161
+ """
162
+ #REV2
163
+ #attn_bias now sums the shift matrix within the attn_bias operation to our 'value' target.
164
+ y = scaled_dot_product_attn_bias( #~~attempt to reuse whatever efficient kernels we have already~~ nvm
165
+ biasquery.transpose(1,2), biaskey.transpose(1,2), y,
166
+ scale=self.scale, is_causal=True
167
+ )
168
+ """
169
+
170
+ #reshape scalars from folded position to unfolded position so the ribosome can read the messenger headrna
171
+ #y = self.reshape_dim_heads(self.heads, y)
172
+ #alternate reshape scalars
173
+ y = y.transpose(1,2).contiguous().view_as(x) #thanks a bunch modded-nanogpt
174
+
175
+ #laser-attn unscale goes here
176
+ #...
177
+
178
+ return self.attnoutproj(y)
179
+
180
+ #[x]
181
+ def feedfor(self,x):
182
+ x = self.fused_denseproj_in(x)
183
+ x = self.dense_swiggy(x)
184
+ x = self.denseproj_out(x)
185
+ return x
186
+
187
+ #parallel forward from kingoflolz/mesh-transformer-jax/! check it out!!
188
+ # "discovered by Wang et al + EleutherAI from GPT-J fame"
189
+ def forward(self, h_states):
190
+ # in trad dialect: b->batch, n,i,j,k,l,m,f,a,o -> sequentiality dims, h->heads, d->embedding dim
191
+ bat_len, seq_len, emb_dim = h_states.size()
192
+ # ^ detritus from modded-nanogpt transpose implementation. profile later ig.
193
+
194
+ # highly traditional pre layernorm
195
+ inner_states = self.layerwisenorm(h_states)
196
+
197
+ #crunchy parts
198
+ attn_out = self.self_attn(inner_states, bat_len, seq_len)
199
+ dense_out = self.feedfor(inner_states)
200
+ if self.weighted_skipnet==True:
201
+ skip_out = h_states*self.learnedlambda
202
+ else:
203
+ skip_out = h_states
204
+ #output w/ unabstracted resnet
205
+ return skip_out + dense_out + attn_out
206
+
207
+ def getnorm(type, shape=None):
208
+ if type == "layernorm":
209
+ return nn.LayerNorm(shape, elementwise_affine=True, bias=True)
210
+ elif type == "layernorm-nobias":
211
+ return nn.LayerNorm(shape, elementwise_affine=True, bias=False) #???
212
+ elif type == "rmsnorm":
213
+ return nn.RMSNorm(shape, elementwise_affine=False)
214
+ elif type == "dynamic_shape_rmsnorm":
215
+ return dynamic_shape_rmsnorm()
216
+ elif type == "dynamic_shape_layernorm":
217
+ return dynamic_shape_layernorm()
218
+ elif type == "l2norm":
219
+ return l2norm() #un function
220
+ elif type == "identitynorm":
221
+ return identitynorm(shape)
222
+ else:
223
+ raise Exception("Not implemented")
224
+
225
+ class l2norm(nn.Module): #haha
226
+ def forward(self, inputter, **kwargs):
227
+ inputter = nn.functional.normalize(inputter, p=2, dim=-1)
228
+ return inputter
229
+
230
+ def identitynorm(row):
231
+ return nn.Identity(row)
232
+
233
+ #from `questions/76067020/`` lol
234
+ class dynamic_shape_rmsnorm(nn.Module):
235
+ def forward(self, inputter, **kwargs):
236
+ inputter = inputter.transpose(1,2) #rotate!
237
+ #i am so sorry haha
238
+ #normalized_shape seems to require adjacencies, i tried a few other things first.
239
+ #wait the notation in the paper suggests... [3:].
240
+ inner_shape = inputter.size()[3:]
241
+
242
+ nn.functional.rms_norm(inputter, normalized_shape=inner_shape, **kwargs)
243
+ inputter = inputter.transpose(1,2) #reverse rotate!
244
+ return inputter
245
+
246
+ class dynamic_shape_layernorm(nn.Module):
247
+ def forward(self, inputter, **kwargs):
248
+ inputter = inputter.transpose(1,2) #rotate!
249
+ #i am so sorry haha
250
+ #normalized_shape seems to require adjacencies, i tried a few other things first.
251
+ #wait the notation in the paper suggests... [3:].
252
+ inner_shape = inputter.size()[3:]
253
+
254
+ nn.functional.layer_norm(inputter, normalized_shape=inner_shape, **kwargs)
255
+ inputter = inputter.transpose(1,2) #reverse rotate!
256
+ return inputter
257
+
258
+ #we too are hitting that mfing noam shazeer https://arxiv.org/pdf/2002.05202
259
+ #if there was a self-gated ELU id want to use it instead though
260
+ class swiglu(nn.Module):
261
+ def forward(self, x):
262
+ x, gate = x.chunk(2, dim=-1)
263
+ return nn.functional.silu(gate) * x
264
+
265
+ #rippin this one from modded-nanogpt
266
+ class rotarizer(nn.Module):
267
+ def __init__(self, dim, base=1000): #shhh don't tell anyone about the rotemb base
268
+ super().__init__()
269
+ self.inv_freq = (base ** (torch.arange(0,dim,2).float() / dim))**-1
270
+ self.seq_len_cached = None
271
+ self.cos_cached = None
272
+ self.sin_cached = None
273
+
274
+ def forward(self, x):
275
+ seq_len = x.size()[1] #perform the surgical LENGTH,YOINKEMS,{}, b {n} h d
276
+ #using torch tensor.size()[idx] notation bc i think it is more explicit than shape[]
277
+ if seq_len != self.seq_len_cached:
278
+ self.seq_len_cached = seq_len
279
+ t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
280
+ reg_freqs = torch.outer(t, self.inv_freq).to(x.device)
281
+ self.cos_cached = reg_freqs.cos().bfloat16()
282
+ self.sin_cached = reg_freqs.sin().bfloat16()
283
+ return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]
284
+ #yeah slay em with the list comprehensions, cited author 😒
285
+
286
+ def apply_rotarizer_emb(x, cos, sin):
287
+ #assert x.ndim == 4 # b n h d
288
+ d = x.size()[3]//2 # perform the superb DIVIDE,2,LENGTH,YOINKEMS,{}, b n h {d}
289
+ x1 = x[..., :d] #some kind of slicing mystery code
290
+ x2 = x[..., d:]
291
+ y1 = x1 * cos + x2 * sin
292
+ y2 = x1 * (-sin) + x2 * cos
293
+ return torch.cat([y1,y2], 3).type_as(x)
294
+
295
+ #alternate attention to retrieve a shift matrix instead of scale matrix.
296
+ #this will either break the first time it runs or make perfect sense whomstdve doubted it all along
297
+ #REVISION 1:
298
+ #"""
299
+ def scaled_dot_product_attn_bias(query, key, value, attn_mask=None, dropout_p=0.0,
300
+ is_causal=False, scale=None, enable_gqa=False):
301
+ #make sure you compile this or it will be slow! haha! it will be slow otherwise! haha!
302
+ L, S = query.size(-2), key.size(-2)
303
+ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
304
+ #inversion of normal masking since we're not softmaxing
305
+ attn_bias = torch.ones(L, S, dtype=query.dtype, device=query.device)
306
+
307
+ if is_causal: #sounds caus-tly to change haha heeehehee
308
+ assert attn_mask is None
309
+ temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
310
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("0")) #0 not neginf
311
+ attn_bias.to(query.dtype)
312
+
313
+ if attn_mask is not None: #more boilerplate ty pytorch
314
+ if attn_mask.dtype == torch.bool:
315
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("0")) #0 not neginf
316
+ else:
317
+ attn_bias *= attn_mask
318
+
319
+ if enable_gqa: #who can say what this does
320
+ key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
321
+ value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
322
+
323
+ attn_magnitude = torch.matmul(query, key.transpose(-2, -1)) * scale
324
+ attn_magnitude *= attn_bias #* to combine instead of +
325
+ #attn_magnitude = torch.softmax(attn_weight, dim=-1) we dont want this lol
326
+ attn_magnitude = torch.dropout(attn_magnitude, dropout_p, train=True)
327
+ return attn_magnitude @ value
328
+ #"""
329
+ #REVISION 2: this doesn't benefit from abstract syntactic similarity to torch sdpa. so we gut it!
330
+ #instead of creating a duds matrix of 1s to occupy the 'value' idx, we sum the shift-QK product directly
331
+ #uncompiled this maybe uses fewer ops; profile and find out.
332
+ """
333
+ def scaled_dot_product_attn_bias(query, key, value, attn_mask=None, dropout_p=0.0,
334
+ is_causal=False, scale=None, enable_gqa=False):
335
+ #make sure you compile this or it will be slow! haha! it will be slow otherwise! haha!
336
+ L, S = query.size(-2), key.size(-2)
337
+ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
338
+ #inversion of normal masking since we're not softmaxing
339
+ attn_bias = torch.ones(L, S, dtype=query.dtype, device=query.device)
340
+
341
+ if is_causal: #sounds caus-tly to change haha heeehehee
342
+ assert attn_mask is None
343
+ temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
344
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("0")) #0 not neginf
345
+ attn_bias.to(query.dtype)
346
+
347
+ if attn_mask is not None: #more boilerplate ty pytorch
348
+ if attn_mask.dtype == torch.bool:
349
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("0")) #0 not neginf
350
+ else:
351
+ attn_bias *= attn_mask
352
+
353
+ if enable_gqa: #who can say what this does
354
+ key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
355
+ value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
356
+
357
+ attn_magnitude = torch.matmul(query, key.transpose(-2, -1)) * scale
358
+ attn_magnitude *= attn_bias #* to combine instead of +
359
+ #attn_magnitude = torch.softmax(attn_weight, dim=-1) we dont want this lol
360
+ attn_magnitude = torch.dropout(attn_magnitude, dropout_p, train=True)
361
+ #... broadcasting... if A generalmatmul B, and A has shape (N x 1), B has shape (m x p),
362
+ # 1 is prepended to A in torch broadcasting. then A matmul B. then prepend removed.
363
+ # inplace prepend 1: a.unsqueeze_(0).
364
+ #
365
+ #attn_mag : b h n h_d ...
366
+ #no it *wasn't*! it's b, h, n, n!
367
+ #sdpa output (our v input) without transpose is
368
+ # b h n h_d
369
+ #so maybe we need to transpose sdpa_out by (-2, -1)
370
+ #such that sdpa_out : b h h_d n, allowing
371
+ #torch bmm of mat1:...{n X n} & mat2:...{h_d X n} --> bmmout: ...{h_d X n}
372
+ #print(attn_magnitude.size())
373
+ #print(value.size())
374
+ #attn_magnitude.unsqueeze_(0)
375
+ #...
376
+ #wow okay this is tricky. we were using a ones row reduce.
377
+ #basically last night we were assembling a (b h n_1 n_2) shape through biasq and biask matmul.
378
+ #then we multiplied it by a ones of (b h n h_d),
379
+ #which reduces b h n (n) to b h n (h_d)...
380
+ #where h_d rows are copies of sum(n_2) at each n index in (b h n h_d).
381
+ #meaning attn_II_rev1 was assigning a headwise bias along the entire sequence.
382
+ #which itself would be chosen by the optimizer state transformation evolution of biasq and biask.
383
+ return torch.add(attn_magnitude, value.transpose(-2,-1))
384
+ """
385
+
386
+
387
+ ### states take format batch, sequence, embedding
388
+ ### therefore
389
+ ### batch_size, sequence_length, embedding_dim = h_states.shape
390
+ def reshape_heads_dim(heads, tensor):
391
+ bat_len, seq_len, emb_dim = tensor.size()
392
+ head_len = heads
393
+ # i think equivalent to traditional
394
+ # "b n (h d) -> b h n d"
395
+ tensor = tensor.reshape(bat_len , seq_len, head_len, emb_dim // head_len)
396
+ tensor = tensor.permute(0, 2, 1, 3).reshape(bat_len*head_len, seq_len, emb_dim // head_len)
397
+ return tensor
398
+
399
+ def reshape_dim_heads(heads, tensor):
400
+ bat_len, seq_len, emb_dim = tensor.size()
401
+ head_len = heads
402
+ # i think equivalent to traditional
403
+ # "b h n d -> b n (h d)"
404
+ tensor = tensor.reshape(bat_len // head_len, head_len, seq_len, emb_dim)
405
+ tensor = tensor.permute(0, 2, 1, 3).reshape(bat_len // head_len, seq_len, emb_dim*head_len)
406
+ return tensor
407
+
408
+
409
+ ###
410
+ ### modelwise config:
411
+ ### {"vocab_size":8000, "num_layers":4}
412
+ ###
413
+ class PGPT_Lformer(nn.Module):
414
+ def __init__(self,config):
415
+ super().__init__()
416
+ self.config = config
417
+
418
+ self.lambdaformer = nn.ModuleDict(dict(
419
+ what_the_embedder_doin = nn.Embedding(config["vocab_size"], config["dim"]),
420
+ blocks = nn.ModuleList([vit22_tformer(config) for _ in range(config["num_layers"])])
421
+ ))
422
+ self.tokenpicker_head = nn.Linear(in_features=config["dim"], out_features=config["vocab_size"], bias=False)
423
+ self.tokenpicker_head.weight.data.zero_() #re: @Grad62304977
424
+
425
+ def forward(self, index, targets=None, return_logits=True, return_zloss=False):
426
+ x = self.lambdaformer.what_the_embedder_doin(index) # get token embeddings
427
+ x = nn.functional.rms_norm(x, (x.size(-1),)) #re: @Grad62304977
428
+ for decoder in self.lambdaformer.blocks:
429
+ x = decoder(x)
430
+ x = nn.functional.rms_norm(x, (x.size(-1),)) #re: @Grad62304977
431
+
432
+ if targets is not None:
433
+ #grab some losses woooo
434
+ logits = self.tokenpicker_head(x)
435
+ if return_zloss: #tracking https://arxiv.org/abs/2309.14322
436
+ z = torch.sum(torch.exp(logits)) #reduce: e^logit[j]
437
+ z_loss = torch.log(z)**2 #log and square Z. make sure to set a coefficient in trainer!
438
+ logits = 30 * torch.tanh(logits / 30) # @Grad62304977
439
+ logits = logits.float() # use tf32/fp32 for logits
440
+ loss = nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
441
+ else:
442
+ #kellerjordan optimi
443
+ logits = self.tokenpicker_head(x[:, [-1], :]) # re: kj: note: using list [-1] to preserve the time dim
444
+ logits = 30 * torch.tanh(logits / 30) # @Grad62304977
445
+ logits = logits.float() # use tf32/fp32 for logits
446
+ loss = None
447
+
448
+ #an appeal to performance is made:
449
+ if not return_logits:
450
+ logits = None
451
+ if not return_zloss:
452
+ z_loss = None
453
+
454
  return logits, loss, z_loss
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
- torch
2
- tiktoken
3
- huggingface_hub
4
- gradio
5
  matplotlib
 
1
+ torch
2
+ tiktoken
3
+ huggingface_hub
4
+ gradio
5
  matplotlib