Spaces:
Sleeping
Sleeping
zyzzyva commited on
Commit ·
3fe7988
1
Parent(s): 73bff3f
yeah we vibecoding
Browse files- .gitattributes +35 -35
- README.md +12 -12
- app.py +233 -233
- pgptlformer.py +453 -453
- requirements.txt +4 -4
.gitattributes
CHANGED
|
@@ -1,35 +1,35 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,12 +1,12 @@
|
|
| 1 |
-
---
|
| 2 |
-
title: Attn Shift Demo
|
| 3 |
-
emoji: 📉
|
| 4 |
-
colorFrom: green
|
| 5 |
-
colorTo: gray
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 5.35.0
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
---
|
| 11 |
-
|
| 12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Attn Shift Demo
|
| 3 |
+
emoji: 📉
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: gray
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.35.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
CHANGED
|
@@ -1,234 +1,234 @@
|
|
| 1 |
-
import gradio as gr
|
| 2 |
-
import torch
|
| 3 |
-
import time
|
| 4 |
-
import os
|
| 5 |
-
from huggingface_hub import hf_hub_download
|
| 6 |
-
import tiktoken
|
| 7 |
-
import pgptlformer # Your model definition file
|
| 8 |
-
import matplotlib.pyplot as plt
|
| 9 |
-
import numpy as np
|
| 10 |
-
from contextlib import nullcontext
|
| 11 |
-
|
| 12 |
-
# --- Configuration ---
|
| 13 |
-
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 14 |
-
DTYPE = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
|
| 15 |
-
PTDTYPE = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[DTYPE]
|
| 16 |
-
CTX = nullcontext() if DEVICE == 'cpu' else torch.amp.autocast(device_type=DEVICE, dtype=PTDTYPE)
|
| 17 |
-
TORCH_COMPILE = False # Gradio instances can be slow, so compilation might timeout. Set to False for stability.
|
| 18 |
-
|
| 19 |
-
# --- Model Loading ---
|
| 20 |
-
|
| 21 |
-
@torch.no_grad()
|
| 22 |
-
def load_model(repo_id, filename, config_override=None):
|
| 23 |
-
"""Loads a model from the Hugging Face Hub."""
|
| 24 |
-
print(f"Loading model: {repo_id}/{filename}...")
|
| 25 |
-
try:
|
| 26 |
-
ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
| 27 |
-
checkpoint = torch.load(ckpt_path, map_location=DEVICE)
|
| 28 |
-
|
| 29 |
-
tformer_cfg = checkpoint['model_args']
|
| 30 |
-
if config_override:
|
| 31 |
-
tformer_cfg.update(config_override)
|
| 32 |
-
|
| 33 |
-
model = pgptlformer.PGPT_Lformer(tformer_cfg)
|
| 34 |
-
state_dict = checkpoint['model']
|
| 35 |
-
|
| 36 |
-
# Clean up state dict if needed
|
| 37 |
-
unwanted_prefix = '_orig_mod.'
|
| 38 |
-
for k, v in list(state_dict.items()):
|
| 39 |
-
if k.startswith(unwanted_prefix):
|
| 40 |
-
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
|
| 41 |
-
|
| 42 |
-
model.load_state_dict(state_dict, strict=False) # Use strict=False for flexibility
|
| 43 |
-
model.eval()
|
| 44 |
-
model.to(DEVICE)
|
| 45 |
-
|
| 46 |
-
if TORCH_COMPILE:
|
| 47 |
-
model = torch.compile(model)
|
| 48 |
-
|
| 49 |
-
print(f"Model {filename} loaded successfully.")
|
| 50 |
-
return model, tformer_cfg
|
| 51 |
-
except Exception as e:
|
| 52 |
-
print(f"Error loading model {filename}: {e}")
|
| 53 |
-
raise
|
| 54 |
-
|
| 55 |
-
# Load both models once at the start
|
| 56 |
-
try:
|
| 57 |
-
# This is the baseline model from your portfolio
|
| 58 |
-
BASELINE_MODEL, BASELINE_CFG = load_model(
|
| 59 |
-
repo_id="SQCU/pgptlformer-tinystories",
|
| 60 |
-
filename="state_step040500.pt"
|
| 61 |
-
)
|
| 62 |
-
|
| 63 |
-
# This is the shift-attn model. Note the config_override.
|
| 64 |
-
SHIFT_ATTN_MODEL, SHIFT_ATTN_CFG = load_model(
|
| 65 |
-
repo_id="SQCU/pgptlformer-tinystories",
|
| 66 |
-
filename="re-pqt-rmsXrmsx2x2-ATTNII-791967c5-5c59-4a5f-a2c5-07772bcf65ab/state_step040500.pt",
|
| 67 |
-
config_override={"attention_deux": True} # Crucial: This enables the shift-attn mechanism in your code
|
| 68 |
-
)
|
| 69 |
-
except Exception as e:
|
| 70 |
-
# If loading fails, show an error in the Gradio app instead of crashing
|
| 71 |
-
BASELINE_MODEL, SHIFT_ATTN_MODEL = None, None
|
| 72 |
-
ERROR_MESSAGE = f"Failed to load models. Please check logs. Error: {e}"
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
# --- Inference and Metrics ---
|
| 76 |
-
|
| 77 |
-
ENC = tiktoken.get_encoding("gpt2")
|
| 78 |
-
ENCODE = lambda s: ENC.encode(s, allowed_special={"<|endoftext|>"})
|
| 79 |
-
DECODE = lambda l: ENC.decode(l)
|
| 80 |
-
|
| 81 |
-
@torch.no_grad()
|
| 82 |
-
def generate_and_measure(model, prompt_ids, max_new_tokens=50):
|
| 83 |
-
"""Runs inference and calculates metrics."""
|
| 84 |
-
# Reset stats for this run
|
| 85 |
-
if DEVICE == 'cuda':
|
| 86 |
-
torch.cuda.reset_peak_memory_stats(DEVICE)
|
| 87 |
-
torch.cuda.synchronize()
|
| 88 |
-
|
| 89 |
-
start_time = time.time()
|
| 90 |
-
|
| 91 |
-
# --- Generation Loop ---
|
| 92 |
-
model_logits = []
|
| 93 |
-
generated_ids = prompt_ids
|
| 94 |
-
for _ in range(max_new_tokens):
|
| 95 |
-
idx_cond = generated_ids if generated_ids.size(1) <= 1024 else generated_ids[:, -1024:]
|
| 96 |
-
logits, _, _ = model(idx_cond, return_logits=True)
|
| 97 |
-
|
| 98 |
-
final_logits = logits[:, -1, :]
|
| 99 |
-
model_logits.append(final_logits) # Store logits for perplexity/sharpening calc
|
| 100 |
-
|
| 101 |
-
probs = torch.nn.functional.softmax(final_logits, dim=-1)
|
| 102 |
-
idx_next = torch.multinomial(probs, num_samples=1)
|
| 103 |
-
generated_ids = torch.cat((generated_ids, idx_next), dim=1)
|
| 104 |
-
|
| 105 |
-
if DEVICE == 'cuda':
|
| 106 |
-
torch.cuda.synchronize()
|
| 107 |
-
end_time = time.time()
|
| 108 |
-
|
| 109 |
-
# --- Metrics Calculation ---
|
| 110 |
-
# 1. Inference Speed (Tokens/Second)
|
| 111 |
-
tokens_per_sec = max_new_tokens / (end_time - start_time)
|
| 112 |
-
|
| 113 |
-
# 2. VRAM Usage (MB)
|
| 114 |
-
vram_usage = torch.cuda.max_memory_allocated(DEVICE) / (1024**2) if DEVICE == 'cuda' else 0
|
| 115 |
-
|
| 116 |
-
# 3. Pseudo-Perplexity
|
| 117 |
-
all_logits = torch.cat(model_logits, dim=0)
|
| 118 |
-
target_ids = generated_ids[0, -max_new_tokens:]
|
| 119 |
-
cross_entropy = torch.nn.functional.cross_entropy(all_logits, target_ids)
|
| 120 |
-
pseudo_perplexity = torch.exp(cross_entropy).item()
|
| 121 |
-
|
| 122 |
-
# 4. Logit Sharpening (Average of max probability)
|
| 123 |
-
avg_max_prob = torch.nn.functional.softmax(all_logits, dim=-1).max(dim=-1).values.mean().item()
|
| 124 |
-
|
| 125 |
-
# --- Decode and Return ---
|
| 126 |
-
output_text = DECODE(generated_ids[0].tolist())
|
| 127 |
-
|
| 128 |
-
metrics = {
|
| 129 |
-
'Tokens/Sec': tokens_per_sec,
|
| 130 |
-
'VRAM (MB)': vram_usage,
|
| 131 |
-
'Perplexity': pseudo_perplexity,
|
| 132 |
-
'Logit Sharpening': avg_max_prob,
|
| 133 |
-
}
|
| 134 |
-
|
| 135 |
-
return output_text, metrics
|
| 136 |
-
|
| 137 |
-
# --- Visualization ---
|
| 138 |
-
|
| 139 |
-
def plot_radar_chart(baseline_metrics, shift_attn_metrics):
|
| 140 |
-
"""Creates a radar chart comparing the two models."""
|
| 141 |
-
labels = list(baseline_metrics.keys())
|
| 142 |
-
baseline_stats = list(baseline_metrics.values())
|
| 143 |
-
shift_attn_stats = list(shift_attn_metrics.values())
|
| 144 |
-
|
| 145 |
-
# Normalize stats for plotting. Higher is better for all metrics on the chart.
|
| 146 |
-
# We will take the inverse of Perplexity and VRAM for a "higher is better" visualization.
|
| 147 |
-
baseline_plot_stats = [
|
| 148 |
-
baseline_stats[0], # Tokens/Sec (Higher is better)
|
| 149 |
-
1 / (baseline_stats[1] + 1e-6), # VRAM (Inverse)
|
| 150 |
-
1 / (baseline_stats[2] + 1e-6), # Perplexity (Inverse)
|
| 151 |
-
baseline_stats[3] # Sharpening (Higher is better)
|
| 152 |
-
]
|
| 153 |
-
shift_attn_plot_stats = [
|
| 154 |
-
shift_attn_stats[0],
|
| 155 |
-
1 / (shift_attn_stats[1] + 1e-6),
|
| 156 |
-
1 / (shift_attn_stats[2] + 1e-6),
|
| 157 |
-
shift_attn_stats[3]
|
| 158 |
-
]
|
| 159 |
-
|
| 160 |
-
angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
|
| 161 |
-
|
| 162 |
-
# Make the plot circular
|
| 163 |
-
baseline_plot_stats += baseline_plot_stats[:1]
|
| 164 |
-
shift_attn_plot_stats += shift_attn_plot_stats[:1]
|
| 165 |
-
angles += angles[:1]
|
| 166 |
-
|
| 167 |
-
fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
|
| 168 |
-
|
| 169 |
-
# Helper function to find nice plot limits
|
| 170 |
-
def get_max_val(*args):
|
| 171 |
-
return max(max(lst) for lst in args if lst) * 1.2
|
| 172 |
-
|
| 173 |
-
ax.set_ylim(0, get_max_val(baseline_plot_stats, shift_attn_plot_stats))
|
| 174 |
-
|
| 175 |
-
# Plot labels
|
| 176 |
-
ax.set_xticks(angles[:-1])
|
| 177 |
-
ax.set_xticklabels(["Tokens/Sec\n(Higher is Better)", "1 / VRAM\n(Higher is Better)", "1 / Perplexity\n(Higher is Better)", "Logit Sharpening\n(Higher is Better)"])
|
| 178 |
-
|
| 179 |
-
# Plot data
|
| 180 |
-
ax.plot(angles, baseline_plot_stats, 'o-', linewidth=2, label="Baseline")
|
| 181 |
-
ax.fill(angles, baseline_plot_stats, alpha=0.25)
|
| 182 |
-
ax.plot(angles, shift_attn_plot_stats, 'o-', linewidth=2, label="Shift-Attn")
|
| 183 |
-
ax.fill(angles, shift_attn_plot_stats, alpha=0.25)
|
| 184 |
-
|
| 185 |
-
ax.set_title("Model Performance Comparison", size=20, color='gray', y=1.1)
|
| 186 |
-
ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
|
| 187 |
-
|
| 188 |
-
plt.tight_layout()
|
| 189 |
-
return fig
|
| 190 |
-
|
| 191 |
-
# --- Gradio Interface ---
|
| 192 |
-
|
| 193 |
-
def run_comparison(prompt, max_new_tokens):
|
| 194 |
-
if not BASELINE_MODEL or not SHIFT_ATTN_MODEL:
|
| 195 |
-
raise gr.Error(ERROR_MESSAGE)
|
| 196 |
-
|
| 197 |
-
input_ids = ENCODE(prompt)
|
| 198 |
-
x = (torch.tensor(input_ids, dtype=torch.long, device=DEVICE)[None, ...])
|
| 199 |
-
|
| 200 |
-
# Run both models
|
| 201 |
-
baseline_text, baseline_metrics = generate_and_measure(BASELINE_MODEL, x, max_new_tokens)
|
| 202 |
-
shift_attn_text, shift_attn_metrics = generate_and_measure(SHIFT_ATTN_MODEL, x, max_new_tokens)
|
| 203 |
-
|
| 204 |
-
# Create plot
|
| 205 |
-
chart = plot_radar_chart(baseline_metrics, shift_attn_metrics)
|
| 206 |
-
|
| 207 |
-
return baseline_text, shift_attn_text, chart
|
| 208 |
-
|
| 209 |
-
with gr.Blocks(theme=gr.themes.Base()) as demo:
|
| 210 |
-
gr.Markdown("# `shift-attn`: A Live Demonstration")
|
| 211 |
-
gr.Markdown(
|
| 212 |
-
"This demo compares a baseline `pgptlformer` model against an identical model enhanced with the `shift-attn` mechanism (`attention_deux`). "
|
| 213 |
-
"The radar chart visualizes key performance and efficiency metrics, where a larger area indicates a better overall model."
|
| 214 |
-
)
|
| 215 |
-
with gr.Row():
|
| 216 |
-
with gr.Column(scale=1):
|
| 217 |
-
prompt_input = gr.Textbox(label="Enter your prompt:", value="The quick brown fox")
|
| 218 |
-
token_slider = gr.Slider(minimum=10, maximum=200, value=50, step=1, label="Max New Tokens")
|
| 219 |
-
submit_btn = gr.Button("Compare Models", variant="primary")
|
| 220 |
-
with gr.Column(scale=2):
|
| 221 |
-
plot_output = gr.Plot(label="Performance Radar Chart")
|
| 222 |
-
|
| 223 |
-
with gr.Row():
|
| 224 |
-
baseline_output = gr.Textbox(label="Baseline Model Output", lines=8)
|
| 225 |
-
shift_attn_output = gr.Textbox(label="Shift-Attn Model Output", lines=8)
|
| 226 |
-
|
| 227 |
-
submit_btn.click(
|
| 228 |
-
fn=run_comparison,
|
| 229 |
-
inputs=[prompt_input, token_slider],
|
| 230 |
-
outputs=[baseline_output, shift_attn_output, plot_output]
|
| 231 |
-
)
|
| 232 |
-
|
| 233 |
-
if __name__ == "__main__":
|
| 234 |
demo.launch()
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import time
|
| 4 |
+
import os
|
| 5 |
+
from huggingface_hub import hf_hub_download
|
| 6 |
+
import tiktoken
|
| 7 |
+
import pgptlformer # Your model definition file
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import numpy as np
|
| 10 |
+
from contextlib import nullcontext
|
| 11 |
+
|
| 12 |
+
# --- Configuration ---
|
| 13 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 14 |
+
DTYPE = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
|
| 15 |
+
PTDTYPE = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[DTYPE]
|
| 16 |
+
CTX = nullcontext() if DEVICE == 'cpu' else torch.amp.autocast(device_type=DEVICE, dtype=PTDTYPE)
|
| 17 |
+
TORCH_COMPILE = False # Gradio instances can be slow, so compilation might timeout. Set to False for stability.
|
| 18 |
+
|
| 19 |
+
# --- Model Loading ---
|
| 20 |
+
|
| 21 |
+
@torch.no_grad()
|
| 22 |
+
def load_model(repo_id, filename, config_override=None):
|
| 23 |
+
"""Loads a model from the Hugging Face Hub."""
|
| 24 |
+
print(f"Loading model: {repo_id}/{filename}...")
|
| 25 |
+
try:
|
| 26 |
+
ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
| 27 |
+
checkpoint = torch.load(ckpt_path, map_location=DEVICE)
|
| 28 |
+
|
| 29 |
+
tformer_cfg = checkpoint['model_args']
|
| 30 |
+
if config_override:
|
| 31 |
+
tformer_cfg.update(config_override)
|
| 32 |
+
|
| 33 |
+
model = pgptlformer.PGPT_Lformer(tformer_cfg)
|
| 34 |
+
state_dict = checkpoint['model']
|
| 35 |
+
|
| 36 |
+
# Clean up state dict if needed
|
| 37 |
+
unwanted_prefix = '_orig_mod.'
|
| 38 |
+
for k, v in list(state_dict.items()):
|
| 39 |
+
if k.startswith(unwanted_prefix):
|
| 40 |
+
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
|
| 41 |
+
|
| 42 |
+
model.load_state_dict(state_dict, strict=False) # Use strict=False for flexibility
|
| 43 |
+
model.eval()
|
| 44 |
+
model.to(DEVICE)
|
| 45 |
+
|
| 46 |
+
if TORCH_COMPILE:
|
| 47 |
+
model = torch.compile(model)
|
| 48 |
+
|
| 49 |
+
print(f"Model {filename} loaded successfully.")
|
| 50 |
+
return model, tformer_cfg
|
| 51 |
+
except Exception as e:
|
| 52 |
+
print(f"Error loading model {filename}: {e}")
|
| 53 |
+
raise
|
| 54 |
+
|
| 55 |
+
# Load both models once at the start
|
| 56 |
+
try:
|
| 57 |
+
# This is the baseline model from your portfolio
|
| 58 |
+
BASELINE_MODEL, BASELINE_CFG = load_model(
|
| 59 |
+
repo_id="SQCU/pgptlformer-tinystories",
|
| 60 |
+
filename="state_step040500.pt"
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# This is the shift-attn model. Note the config_override.
|
| 64 |
+
SHIFT_ATTN_MODEL, SHIFT_ATTN_CFG = load_model(
|
| 65 |
+
repo_id="SQCU/pgptlformer-tinystories",
|
| 66 |
+
filename="re-pqt-rmsXrmsx2x2-ATTNII-791967c5-5c59-4a5f-a2c5-07772bcf65ab/state_step040500.pt",
|
| 67 |
+
config_override={"attention_deux": True} # Crucial: This enables the shift-attn mechanism in your code
|
| 68 |
+
)
|
| 69 |
+
except Exception as e:
|
| 70 |
+
# If loading fails, show an error in the Gradio app instead of crashing
|
| 71 |
+
BASELINE_MODEL, SHIFT_ATTN_MODEL = None, None
|
| 72 |
+
ERROR_MESSAGE = f"Failed to load models. Please check logs. Error: {e}"
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# --- Inference and Metrics ---
|
| 76 |
+
|
| 77 |
+
ENC = tiktoken.get_encoding("gpt2")
|
| 78 |
+
ENCODE = lambda s: ENC.encode(s, allowed_special={"<|endoftext|>"})
|
| 79 |
+
DECODE = lambda l: ENC.decode(l)
|
| 80 |
+
|
| 81 |
+
@torch.no_grad()
|
| 82 |
+
def generate_and_measure(model, prompt_ids, max_new_tokens=50):
|
| 83 |
+
"""Runs inference and calculates metrics."""
|
| 84 |
+
# Reset stats for this run
|
| 85 |
+
if DEVICE == 'cuda':
|
| 86 |
+
torch.cuda.reset_peak_memory_stats(DEVICE)
|
| 87 |
+
torch.cuda.synchronize()
|
| 88 |
+
|
| 89 |
+
start_time = time.time()
|
| 90 |
+
|
| 91 |
+
# --- Generation Loop ---
|
| 92 |
+
model_logits = []
|
| 93 |
+
generated_ids = prompt_ids
|
| 94 |
+
for _ in range(max_new_tokens):
|
| 95 |
+
idx_cond = generated_ids if generated_ids.size(1) <= 1024 else generated_ids[:, -1024:]
|
| 96 |
+
logits, _, _ = model(idx_cond, return_logits=True)
|
| 97 |
+
|
| 98 |
+
final_logits = logits[:, -1, :]
|
| 99 |
+
model_logits.append(final_logits) # Store logits for perplexity/sharpening calc
|
| 100 |
+
|
| 101 |
+
probs = torch.nn.functional.softmax(final_logits, dim=-1)
|
| 102 |
+
idx_next = torch.multinomial(probs, num_samples=1)
|
| 103 |
+
generated_ids = torch.cat((generated_ids, idx_next), dim=1)
|
| 104 |
+
|
| 105 |
+
if DEVICE == 'cuda':
|
| 106 |
+
torch.cuda.synchronize()
|
| 107 |
+
end_time = time.time()
|
| 108 |
+
|
| 109 |
+
# --- Metrics Calculation ---
|
| 110 |
+
# 1. Inference Speed (Tokens/Second)
|
| 111 |
+
tokens_per_sec = max_new_tokens / (end_time - start_time)
|
| 112 |
+
|
| 113 |
+
# 2. VRAM Usage (MB)
|
| 114 |
+
vram_usage = torch.cuda.max_memory_allocated(DEVICE) / (1024**2) if DEVICE == 'cuda' else 0
|
| 115 |
+
|
| 116 |
+
# 3. Pseudo-Perplexity
|
| 117 |
+
all_logits = torch.cat(model_logits, dim=0)
|
| 118 |
+
target_ids = generated_ids[0, -max_new_tokens:]
|
| 119 |
+
cross_entropy = torch.nn.functional.cross_entropy(all_logits, target_ids)
|
| 120 |
+
pseudo_perplexity = torch.exp(cross_entropy).item()
|
| 121 |
+
|
| 122 |
+
# 4. Logit Sharpening (Average of max probability)
|
| 123 |
+
avg_max_prob = torch.nn.functional.softmax(all_logits, dim=-1).max(dim=-1).values.mean().item()
|
| 124 |
+
|
| 125 |
+
# --- Decode and Return ---
|
| 126 |
+
output_text = DECODE(generated_ids[0].tolist())
|
| 127 |
+
|
| 128 |
+
metrics = {
|
| 129 |
+
'Tokens/Sec': tokens_per_sec,
|
| 130 |
+
'VRAM (MB)': vram_usage,
|
| 131 |
+
'Perplexity': pseudo_perplexity,
|
| 132 |
+
'Logit Sharpening': avg_max_prob,
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
return output_text, metrics
|
| 136 |
+
|
| 137 |
+
# --- Visualization ---
|
| 138 |
+
|
| 139 |
+
def plot_radar_chart(baseline_metrics, shift_attn_metrics):
|
| 140 |
+
"""Creates a radar chart comparing the two models."""
|
| 141 |
+
labels = list(baseline_metrics.keys())
|
| 142 |
+
baseline_stats = list(baseline_metrics.values())
|
| 143 |
+
shift_attn_stats = list(shift_attn_metrics.values())
|
| 144 |
+
|
| 145 |
+
# Normalize stats for plotting. Higher is better for all metrics on the chart.
|
| 146 |
+
# We will take the inverse of Perplexity and VRAM for a "higher is better" visualization.
|
| 147 |
+
baseline_plot_stats = [
|
| 148 |
+
baseline_stats[0], # Tokens/Sec (Higher is better)
|
| 149 |
+
1 / (baseline_stats[1] + 1e-6), # VRAM (Inverse)
|
| 150 |
+
1 / (baseline_stats[2] + 1e-6), # Perplexity (Inverse)
|
| 151 |
+
baseline_stats[3] # Sharpening (Higher is better)
|
| 152 |
+
]
|
| 153 |
+
shift_attn_plot_stats = [
|
| 154 |
+
shift_attn_stats[0],
|
| 155 |
+
1 / (shift_attn_stats[1] + 1e-6),
|
| 156 |
+
1 / (shift_attn_stats[2] + 1e-6),
|
| 157 |
+
shift_attn_stats[3]
|
| 158 |
+
]
|
| 159 |
+
|
| 160 |
+
angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
|
| 161 |
+
|
| 162 |
+
# Make the plot circular
|
| 163 |
+
baseline_plot_stats += baseline_plot_stats[:1]
|
| 164 |
+
shift_attn_plot_stats += shift_attn_plot_stats[:1]
|
| 165 |
+
angles += angles[:1]
|
| 166 |
+
|
| 167 |
+
fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
|
| 168 |
+
|
| 169 |
+
# Helper function to find nice plot limits
|
| 170 |
+
def get_max_val(*args):
|
| 171 |
+
return max(max(lst) for lst in args if lst) * 1.2
|
| 172 |
+
|
| 173 |
+
ax.set_ylim(0, get_max_val(baseline_plot_stats, shift_attn_plot_stats))
|
| 174 |
+
|
| 175 |
+
# Plot labels
|
| 176 |
+
ax.set_xticks(angles[:-1])
|
| 177 |
+
ax.set_xticklabels(["Tokens/Sec\n(Higher is Better)", "1 / VRAM\n(Higher is Better)", "1 / Perplexity\n(Higher is Better)", "Logit Sharpening\n(Higher is Better)"])
|
| 178 |
+
|
| 179 |
+
# Plot data
|
| 180 |
+
ax.plot(angles, baseline_plot_stats, 'o-', linewidth=2, label="Baseline")
|
| 181 |
+
ax.fill(angles, baseline_plot_stats, alpha=0.25)
|
| 182 |
+
ax.plot(angles, shift_attn_plot_stats, 'o-', linewidth=2, label="Shift-Attn")
|
| 183 |
+
ax.fill(angles, shift_attn_plot_stats, alpha=0.25)
|
| 184 |
+
|
| 185 |
+
ax.set_title("Model Performance Comparison", size=20, color='gray', y=1.1)
|
| 186 |
+
ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
|
| 187 |
+
|
| 188 |
+
plt.tight_layout()
|
| 189 |
+
return fig
|
| 190 |
+
|
| 191 |
+
# --- Gradio Interface ---
|
| 192 |
+
|
| 193 |
+
def run_comparison(prompt, max_new_tokens):
|
| 194 |
+
if not BASELINE_MODEL or not SHIFT_ATTN_MODEL:
|
| 195 |
+
raise gr.Error(ERROR_MESSAGE)
|
| 196 |
+
|
| 197 |
+
input_ids = ENCODE(prompt)
|
| 198 |
+
x = (torch.tensor(input_ids, dtype=torch.long, device=DEVICE)[None, ...])
|
| 199 |
+
|
| 200 |
+
# Run both models
|
| 201 |
+
baseline_text, baseline_metrics = generate_and_measure(BASELINE_MODEL, x, max_new_tokens)
|
| 202 |
+
shift_attn_text, shift_attn_metrics = generate_and_measure(SHIFT_ATTN_MODEL, x, max_new_tokens)
|
| 203 |
+
|
| 204 |
+
# Create plot
|
| 205 |
+
chart = plot_radar_chart(baseline_metrics, shift_attn_metrics)
|
| 206 |
+
|
| 207 |
+
return baseline_text, shift_attn_text, chart
|
| 208 |
+
|
| 209 |
+
with gr.Blocks(theme=gr.themes.Base()) as demo:
|
| 210 |
+
gr.Markdown("# `shift-attn`: A Live Demonstration")
|
| 211 |
+
gr.Markdown(
|
| 212 |
+
"This demo compares a baseline `pgptlformer` model against an identical model enhanced with the `shift-attn` mechanism (`attention_deux`). "
|
| 213 |
+
"The radar chart visualizes key performance and efficiency metrics, where a larger area indicates a better overall model."
|
| 214 |
+
)
|
| 215 |
+
with gr.Row():
|
| 216 |
+
with gr.Column(scale=1):
|
| 217 |
+
prompt_input = gr.Textbox(label="Enter your prompt:", value="The quick brown fox")
|
| 218 |
+
token_slider = gr.Slider(minimum=10, maximum=200, value=50, step=1, label="Max New Tokens")
|
| 219 |
+
submit_btn = gr.Button("Compare Models", variant="primary")
|
| 220 |
+
with gr.Column(scale=2):
|
| 221 |
+
plot_output = gr.Plot(label="Performance Radar Chart")
|
| 222 |
+
|
| 223 |
+
with gr.Row():
|
| 224 |
+
baseline_output = gr.Textbox(label="Baseline Model Output", lines=8)
|
| 225 |
+
shift_attn_output = gr.Textbox(label="Shift-Attn Model Output", lines=8)
|
| 226 |
+
|
| 227 |
+
submit_btn.click(
|
| 228 |
+
fn=run_comparison,
|
| 229 |
+
inputs=[prompt_input, token_slider],
|
| 230 |
+
outputs=[baseline_output, shift_attn_output, plot_output]
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
if __name__ == "__main__":
|
| 234 |
demo.launch()
|
pgptlformer.py
CHANGED
|
@@ -1,454 +1,454 @@
|
|
| 1 |
-
# Import necessary, revised, libraries
|
| 2 |
-
import torch
|
| 3 |
-
import torch.nn as nn
|
| 4 |
-
import torch.optim as optim
|
| 5 |
-
|
| 6 |
-
#dubious
|
| 7 |
-
from torch.utils.data import DataLoader, TensorDataset
|
| 8 |
-
|
| 9 |
-
#hehe
|
| 10 |
-
import math
|
| 11 |
-
|
| 12 |
-
### note: modded_nanogpt.py is an more full container for a transformer block structure
|
| 13 |
-
### it should specify an encoder ("embedder") and decoder for autoregress on tinystories.
|
| 14 |
-
###
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
### 2302.05442's qk-layernorm is layernorm without centering and biases omitted.
|
| 19 |
-
### this is not equivalent to applying rmsnorm to the lexical scope of layernorm,
|
| 20 |
-
### as rmsnorm (1910.07467) doesn't use the mean statistic to yield variance.
|
| 21 |
-
### profiling and benchmarking a p%-pvarnorm would be great further work!
|
| 22 |
-
###
|
| 23 |
-
### to reach the complete spec of 2302.05442,
|
| 24 |
-
### noncentered nonbiased norms must be applied to projected q&k
|
| 25 |
-
###
|
| 26 |
-
### candidate default: cfig =
|
| 27 |
-
### {"dim":768,"dim_head":128,"headcount":6,"ffmult":4,
|
| 28 |
-
### "lambda":False,"layerwisenorm":"layernorm","qknorm":"identitynorm"}
|
| 29 |
-
### candidate tinystories:
|
| 30 |
-
### {"dim":256,"dim_head":32,"headcount":8,"ffmult":4,
|
| 31 |
-
### "lambda":True,"layerwisenorm":"rmsnorm","qknorm":"identitynorm"}
|
| 32 |
-
###
|
| 33 |
-
### 2401.14489 suggests GEneral Matrix Multiplication dim alignment.
|
| 34 |
-
### basically vocabsize%64:=0.
|
| 35 |
-
### and (emb_dim/heads)%2^k:=0 , for some integer k.
|
| 36 |
-
### and (batchsize*sequence_len)%2^k:=0 , for some integer k.
|
| 37 |
-
### this places the smallest possible seqlen at 64@bf16 and 128@fp8
|
| 38 |
-
###
|
| 39 |
-
### ...
|
| 40 |
-
### the swiglu returns to bite us. the presence of that doubled swiggy matrix does something!
|
| 41 |
-
### specifically. uh.
|
| 42 |
-
### actually because we cranked up the swiggy_dim by 2x, it follows all of our scaling rules
|
| 43 |
-
### lmao, lol, lol, lmao, etcetera.
|
| 44 |
-
class vit22_tformer(nn.Module):
|
| 45 |
-
def __init__(self, config):
|
| 46 |
-
super().__init__()
|
| 47 |
-
#query_dim = config["query_dim"] #don't even think about cross_attention
|
| 48 |
-
self.dim = config["dim"]
|
| 49 |
-
self.dim_head = config["dim_head"]
|
| 50 |
-
self.heads = config["headcount"]
|
| 51 |
-
self.weighted_skipnet = config["lambda"]
|
| 52 |
-
self.denseproj_mul = config["ff_mult"]
|
| 53 |
-
#self.naive_causal = config["is_causal_llm"]
|
| 54 |
-
#...
|
| 55 |
-
#self.qknormalized_shape = [config["dim_head"],config["training_seqlen"],config["headcount"],config["dim_head"],]
|
| 56 |
-
self.qknormalized_shape = [config["headcount"],config["dim_head"]]
|
| 57 |
-
self.layerwisenorm = getnorm(config["layerwisenorm"],shape=self.dim)
|
| 58 |
-
self.projnorm = getnorm(config["qknorm"],shape=self.qknormalized_shape)
|
| 59 |
-
|
| 60 |
-
attn_inner_dim = self.dim_head * self.heads
|
| 61 |
-
self.denseproj_inner_dim = self.dim * self.denseproj_mul
|
| 62 |
-
|
| 63 |
-
if "rotary_embedding_base" in config.keys():
|
| 64 |
-
self.rotbase = config["rotary_embedding_base"]
|
| 65 |
-
else:
|
| 66 |
-
self.rotbase = 1000 # hehe
|
| 67 |
-
|
| 68 |
-
self.attention_II = None
|
| 69 |
-
if "attention_deux" in config.keys():
|
| 70 |
-
self.attention_II = True
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
self.rotary = rotarizer(self.dim_head, base=self.rotbase)
|
| 74 |
-
self.learnedlambda = nn.Parameter(torch.tensor(1.0)) #my beloved
|
| 75 |
-
self.fused_swiglu_dim = self.denseproj_inner_dim*2 #this is necessary so the swiglu's two projections can be applied as a single operation.
|
| 76 |
-
self.scale = self.dim_head**-0.5 #this is the 's' in 's'dpa! #exposed for cosine attention reasons!
|
| 77 |
-
self.l2normscale = None
|
| 78 |
-
if config["qknorm"] == "l2norm": #bootleg cosine attention by overloading the scale term in sdpa
|
| 79 |
-
self.l2normscale = nn.Parameter(torch.log(torch.tensor(config["training_seqlen"]**2)-torch.tensor(config["training_seqlen"])))
|
| 80 |
-
|
| 81 |
-
#...
|
| 82 |
-
self.queryproj = nn.Linear(in_features=self.dim, out_features=self.dim, bias=False)
|
| 83 |
-
self.keyproj = nn.Linear(in_features=self.dim, out_features=self.dim, bias=False)
|
| 84 |
-
self.valueproj = nn.Linear(in_features=self.dim, out_features=self.dim, bias=False)
|
| 85 |
-
self.attnoutproj = nn.Linear(in_features=attn_inner_dim, out_features=self.dim, bias=True)
|
| 86 |
-
|
| 87 |
-
if self.attention_II:
|
| 88 |
-
self.queryBproj = nn.Linear(in_features=self.dim, out_features=self.dim, bias=False)
|
| 89 |
-
self.keyBproj = nn.Linear(in_features=self.dim, out_features=self.dim, bias=False)
|
| 90 |
-
|
| 91 |
-
#dense ('mlp', 'feedforward', 'fully connected', ...) unit
|
| 92 |
-
self.fused_denseproj_in = nn.Linear(in_features=self.dim, out_features=self.fused_swiglu_dim, bias=True) #this is the vit22b part
|
| 93 |
-
self.dense_swiggy = swiglu() #this is kind of superfluous but this is pedagogical programming!
|
| 94 |
-
self.denseproj_out = nn.Linear(in_features=self.denseproj_inner_dim, out_features=self.dim, bias=True)
|
| 95 |
-
|
| 96 |
-
#[x]
|
| 97 |
-
def self_attn(self, x, bat_len, seq_len):
|
| 98 |
-
#norm -> {qkvproj -> qknorm{?}
|
| 99 |
-
#reshape_h_d -> attn -> reshape_d_h} -> attnoutproj
|
| 100 |
-
#project
|
| 101 |
-
query = self.queryproj(x)
|
| 102 |
-
key = self.keyproj(x)
|
| 103 |
-
value = self.valueproj(x)
|
| 104 |
-
|
| 105 |
-
if self.attention_II:
|
| 106 |
-
biasquery = self.queryBproj(x)
|
| 107 |
-
biaskey = self.keyBproj(x)
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
#reshape to bundled up matmul formme
|
| 111 |
-
#query = reshape_heads_dim(self.heads, query)
|
| 112 |
-
#key = reshape_heads_dim(self.heads, key)
|
| 113 |
-
#value = reshape_heads_dim(self.heads, value)
|
| 114 |
-
#alternate reshape for compatibility with modded-nanogpt roformer
|
| 115 |
-
query = query.view(bat_len, seq_len, self.heads, self.dim_head)
|
| 116 |
-
key = key.view(bat_len, seq_len, self.heads, self.dim_head)
|
| 117 |
-
value = value.view(bat_len, seq_len, self.heads, self.dim_head)
|
| 118 |
-
|
| 119 |
-
if self.attention_II:
|
| 120 |
-
biasquery = biasquery.view(bat_len, seq_len, self.heads, self.dim_head)
|
| 121 |
-
biaskey = biaskey.view(bat_len, seq_len, self.heads, self.dim_head)
|
| 122 |
-
|
| 123 |
-
#pos_emb suggested before qknorm re: kellerjordan re: @Grad62304977
|
| 124 |
-
#but we get an error for the x.ndim assertion if we run this after reshaping. whoopsie!
|
| 125 |
-
cos, sin = self.rotary(query) #our rotary unit does the shape detection from states
|
| 126 |
-
|
| 127 |
-
#qk*norm
|
| 128 |
-
query = self.projnorm(query)
|
| 129 |
-
key = self.projnorm(key)
|
| 130 |
-
|
| 131 |
-
if self.attention_II:
|
| 132 |
-
biasquery = self.projnorm(biasquery)
|
| 133 |
-
biaskey = self.projnorm(biaskey)
|
| 134 |
-
|
| 135 |
-
#rotary embed after qknorm as suggested etc.
|
| 136 |
-
query = apply_rotarizer_emb(query, cos, sin)
|
| 137 |
-
key = apply_rotarizer_emb(key, cos, sin)
|
| 138 |
-
|
| 139 |
-
if self.attention_II:
|
| 140 |
-
biasquery = apply_rotarizer_emb(biasquery, cos, sin)
|
| 141 |
-
biaskey = apply_rotarizer_emb(biaskey, cos, sin)
|
| 142 |
-
|
| 143 |
-
#laser-attn goes here
|
| 144 |
-
#...
|
| 145 |
-
|
| 146 |
-
#if we were here to explain attention instead of projections and norms,
|
| 147 |
-
#we would have written this in jax or a language that compiles well!
|
| 148 |
-
#instead, to benefit from flash attention 2, we want to use torch SDPA!
|
| 149 |
-
if self.l2normscale is not None:
|
| 150 |
-
y = self.l2normscale*nn.functional.scaled_dot_product_attention(query.transpose(1,2), key.transpose(1,2), value.transpose(1,2), scale=1, is_causal=True)
|
| 151 |
-
else:
|
| 152 |
-
y = nn.functional.scaled_dot_product_attention(query.transpose(1,2), key.transpose(1,2), value.transpose(1,2), scale=self.scale, is_causal=True)
|
| 153 |
-
|
| 154 |
-
if self.attention_II:
|
| 155 |
-
#REV1
|
| 156 |
-
dud = torch.ones_like(value, dtype=query.dtype, device=query.device)
|
| 157 |
-
y = y + scaled_dot_product_attn_bias( #~~attempt to reuse whatever efficient kernels we have already~~ nvm
|
| 158 |
-
biasquery.transpose(1,2) , biaskey.transpose(1,2) , dud.transpose(1,2),
|
| 159 |
-
scale=self.scale, is_causal=True
|
| 160 |
-
)
|
| 161 |
-
"""
|
| 162 |
-
#REV2
|
| 163 |
-
#attn_bias now sums the shift matrix within the attn_bias operation to our 'value' target.
|
| 164 |
-
y = scaled_dot_product_attn_bias( #~~attempt to reuse whatever efficient kernels we have already~~ nvm
|
| 165 |
-
biasquery.transpose(1,2), biaskey.transpose(1,2), y,
|
| 166 |
-
scale=self.scale, is_causal=True
|
| 167 |
-
)
|
| 168 |
-
"""
|
| 169 |
-
|
| 170 |
-
#reshape scalars from folded position to unfolded position so the ribosome can read the messenger headrna
|
| 171 |
-
#y = self.reshape_dim_heads(self.heads, y)
|
| 172 |
-
#alternate reshape scalars
|
| 173 |
-
y = y.transpose(1,2).contiguous().view_as(x) #thanks a bunch modded-nanogpt
|
| 174 |
-
|
| 175 |
-
#laser-attn unscale goes here
|
| 176 |
-
#...
|
| 177 |
-
|
| 178 |
-
return self.attnoutproj(y)
|
| 179 |
-
|
| 180 |
-
#[x]
|
| 181 |
-
def feedfor(self,x):
|
| 182 |
-
x = self.fused_denseproj_in(x)
|
| 183 |
-
x = self.dense_swiggy(x)
|
| 184 |
-
x = self.denseproj_out(x)
|
| 185 |
-
return x
|
| 186 |
-
|
| 187 |
-
#parallel forward from kingoflolz/mesh-transformer-jax/! check it out!!
|
| 188 |
-
# "discovered by Wang et al + EleutherAI from GPT-J fame"
|
| 189 |
-
def forward(self, h_states):
|
| 190 |
-
# in trad dialect: b->batch, n,i,j,k,l,m,f,a,o -> sequentiality dims, h->heads, d->embedding dim
|
| 191 |
-
bat_len, seq_len, emb_dim = h_states.size()
|
| 192 |
-
# ^ detritus from modded-nanogpt transpose implementation. profile later ig.
|
| 193 |
-
|
| 194 |
-
# highly traditional pre layernorm
|
| 195 |
-
inner_states = self.layerwisenorm(h_states)
|
| 196 |
-
|
| 197 |
-
#crunchy parts
|
| 198 |
-
attn_out = self.self_attn(inner_states, bat_len, seq_len)
|
| 199 |
-
dense_out = self.feedfor(inner_states)
|
| 200 |
-
if self.weighted_skipnet==True:
|
| 201 |
-
skip_out = h_states*self.learnedlambda
|
| 202 |
-
else:
|
| 203 |
-
skip_out = h_states
|
| 204 |
-
#output w/ unabstracted resnet
|
| 205 |
-
return skip_out + dense_out + attn_out
|
| 206 |
-
|
| 207 |
-
def getnorm(type, shape=None):
|
| 208 |
-
if type == "layernorm":
|
| 209 |
-
return nn.LayerNorm(shape, elementwise_affine=True, bias=True)
|
| 210 |
-
elif type == "layernorm-nobias":
|
| 211 |
-
return nn.LayerNorm(shape, elementwise_affine=True, bias=False) #???
|
| 212 |
-
elif type == "rmsnorm":
|
| 213 |
-
return nn.RMSNorm(shape, elementwise_affine=False)
|
| 214 |
-
elif type == "dynamic_shape_rmsnorm":
|
| 215 |
-
return dynamic_shape_rmsnorm()
|
| 216 |
-
elif type == "dynamic_shape_layernorm":
|
| 217 |
-
return dynamic_shape_layernorm()
|
| 218 |
-
elif type == "l2norm":
|
| 219 |
-
return l2norm() #un function
|
| 220 |
-
elif type == "identitynorm":
|
| 221 |
-
return identitynorm(shape)
|
| 222 |
-
else:
|
| 223 |
-
raise Exception("Not implemented")
|
| 224 |
-
|
| 225 |
-
class l2norm(nn.Module): #haha
|
| 226 |
-
def forward(self, inputter, **kwargs):
|
| 227 |
-
inputter = nn.functional.normalize(inputter, p=2, dim=-1)
|
| 228 |
-
return inputter
|
| 229 |
-
|
| 230 |
-
def identitynorm(row):
|
| 231 |
-
return nn.Identity(row)
|
| 232 |
-
|
| 233 |
-
#from `questions/76067020/`` lol
|
| 234 |
-
class dynamic_shape_rmsnorm(nn.Module):
|
| 235 |
-
def forward(self, inputter, **kwargs):
|
| 236 |
-
inputter = inputter.transpose(1,2) #rotate!
|
| 237 |
-
#i am so sorry haha
|
| 238 |
-
#normalized_shape seems to require adjacencies, i tried a few other things first.
|
| 239 |
-
#wait the notation in the paper suggests... [3:].
|
| 240 |
-
inner_shape = inputter.size()[3:]
|
| 241 |
-
|
| 242 |
-
nn.functional.rms_norm(inputter, normalized_shape=inner_shape, **kwargs)
|
| 243 |
-
inputter = inputter.transpose(1,2) #reverse rotate!
|
| 244 |
-
return inputter
|
| 245 |
-
|
| 246 |
-
class dynamic_shape_layernorm(nn.Module):
|
| 247 |
-
def forward(self, inputter, **kwargs):
|
| 248 |
-
inputter = inputter.transpose(1,2) #rotate!
|
| 249 |
-
#i am so sorry haha
|
| 250 |
-
#normalized_shape seems to require adjacencies, i tried a few other things first.
|
| 251 |
-
#wait the notation in the paper suggests... [3:].
|
| 252 |
-
inner_shape = inputter.size()[3:]
|
| 253 |
-
|
| 254 |
-
nn.functional.layer_norm(inputter, normalized_shape=inner_shape, **kwargs)
|
| 255 |
-
inputter = inputter.transpose(1,2) #reverse rotate!
|
| 256 |
-
return inputter
|
| 257 |
-
|
| 258 |
-
#we too are hitting that mfing noam shazeer https://arxiv.org/pdf/2002.05202
|
| 259 |
-
#if there was a self-gated ELU id want to use it instead though
|
| 260 |
-
class swiglu(nn.Module):
|
| 261 |
-
def forward(self, x):
|
| 262 |
-
x, gate = x.chunk(2, dim=-1)
|
| 263 |
-
return nn.functional.silu(gate) * x
|
| 264 |
-
|
| 265 |
-
#rippin this one from modded-nanogpt
|
| 266 |
-
class rotarizer(nn.Module):
|
| 267 |
-
def __init__(self, dim, base=1000): #shhh don't tell anyone about the rotemb base
|
| 268 |
-
super().__init__()
|
| 269 |
-
self.inv_freq = (base ** (torch.arange(0,dim,2).float() / dim))**-1
|
| 270 |
-
self.seq_len_cached = None
|
| 271 |
-
self.cos_cached = None
|
| 272 |
-
self.sin_cached = None
|
| 273 |
-
|
| 274 |
-
def forward(self, x):
|
| 275 |
-
seq_len = x.size()[1] #perform the surgical LENGTH,YOINKEMS,{}, b {n} h d
|
| 276 |
-
#using torch tensor.size()[idx] notation bc i think it is more explicit than shape[]
|
| 277 |
-
if seq_len != self.seq_len_cached:
|
| 278 |
-
self.seq_len_cached = seq_len
|
| 279 |
-
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
|
| 280 |
-
reg_freqs = torch.outer(t, self.inv_freq).to(x.device)
|
| 281 |
-
self.cos_cached = reg_freqs.cos().bfloat16()
|
| 282 |
-
self.sin_cached = reg_freqs.sin().bfloat16()
|
| 283 |
-
return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]
|
| 284 |
-
#yeah slay em with the list comprehensions, cited author 😒
|
| 285 |
-
|
| 286 |
-
def apply_rotarizer_emb(x, cos, sin):
|
| 287 |
-
#assert x.ndim == 4 # b n h d
|
| 288 |
-
d = x.size()[3]//2 # perform the superb DIVIDE,2,LENGTH,YOINKEMS,{}, b n h {d}
|
| 289 |
-
x1 = x[..., :d] #some kind of slicing mystery code
|
| 290 |
-
x2 = x[..., d:]
|
| 291 |
-
y1 = x1 * cos + x2 * sin
|
| 292 |
-
y2 = x1 * (-sin) + x2 * cos
|
| 293 |
-
return torch.cat([y1,y2], 3).type_as(x)
|
| 294 |
-
|
| 295 |
-
#alternate attention to retrieve a shift matrix instead of scale matrix.
|
| 296 |
-
#this will either break the first time it runs or make perfect sense whomstdve doubted it all along
|
| 297 |
-
#REVISION 1:
|
| 298 |
-
#"""
|
| 299 |
-
def scaled_dot_product_attn_bias(query, key, value, attn_mask=None, dropout_p=0.0,
|
| 300 |
-
is_causal=False, scale=None, enable_gqa=False):
|
| 301 |
-
#make sure you compile this or it will be slow! haha! it will be slow otherwise! haha!
|
| 302 |
-
L, S = query.size(-2), key.size(-2)
|
| 303 |
-
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
|
| 304 |
-
#inversion of normal masking since we're not softmaxing
|
| 305 |
-
attn_bias = torch.ones(L, S, dtype=query.dtype, device=query.device)
|
| 306 |
-
|
| 307 |
-
if is_causal: #sounds caus-tly to change haha heeehehee
|
| 308 |
-
assert attn_mask is None
|
| 309 |
-
temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
|
| 310 |
-
attn_bias.masked_fill_(temp_mask.logical_not(), float("0")) #0 not neginf
|
| 311 |
-
attn_bias.to(query.dtype)
|
| 312 |
-
|
| 313 |
-
if attn_mask is not None: #more boilerplate ty pytorch
|
| 314 |
-
if attn_mask.dtype == torch.bool:
|
| 315 |
-
attn_bias.masked_fill_(attn_mask.logical_not(), float("0")) #0 not neginf
|
| 316 |
-
else:
|
| 317 |
-
attn_bias *= attn_mask
|
| 318 |
-
|
| 319 |
-
if enable_gqa: #who can say what this does
|
| 320 |
-
key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
|
| 321 |
-
value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
|
| 322 |
-
|
| 323 |
-
attn_magnitude = torch.matmul(query, key.transpose(-2, -1)) * scale
|
| 324 |
-
attn_magnitude *= attn_bias #* to combine instead of +
|
| 325 |
-
#attn_magnitude = torch.softmax(attn_weight, dim=-1) we dont want this lol
|
| 326 |
-
attn_magnitude = torch.dropout(attn_magnitude, dropout_p, train=True)
|
| 327 |
-
return attn_magnitude @ value
|
| 328 |
-
#"""
|
| 329 |
-
#REVISION 2: this doesn't benefit from abstract syntactic similarity to torch sdpa. so we gut it!
|
| 330 |
-
#instead of creating a duds matrix of 1s to occupy the 'value' idx, we sum the shift-QK product directly
|
| 331 |
-
#uncompiled this maybe uses fewer ops; profile and find out.
|
| 332 |
-
"""
|
| 333 |
-
def scaled_dot_product_attn_bias(query, key, value, attn_mask=None, dropout_p=0.0,
|
| 334 |
-
is_causal=False, scale=None, enable_gqa=False):
|
| 335 |
-
#make sure you compile this or it will be slow! haha! it will be slow otherwise! haha!
|
| 336 |
-
L, S = query.size(-2), key.size(-2)
|
| 337 |
-
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
|
| 338 |
-
#inversion of normal masking since we're not softmaxing
|
| 339 |
-
attn_bias = torch.ones(L, S, dtype=query.dtype, device=query.device)
|
| 340 |
-
|
| 341 |
-
if is_causal: #sounds caus-tly to change haha heeehehee
|
| 342 |
-
assert attn_mask is None
|
| 343 |
-
temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
|
| 344 |
-
attn_bias.masked_fill_(temp_mask.logical_not(), float("0")) #0 not neginf
|
| 345 |
-
attn_bias.to(query.dtype)
|
| 346 |
-
|
| 347 |
-
if attn_mask is not None: #more boilerplate ty pytorch
|
| 348 |
-
if attn_mask.dtype == torch.bool:
|
| 349 |
-
attn_bias.masked_fill_(attn_mask.logical_not(), float("0")) #0 not neginf
|
| 350 |
-
else:
|
| 351 |
-
attn_bias *= attn_mask
|
| 352 |
-
|
| 353 |
-
if enable_gqa: #who can say what this does
|
| 354 |
-
key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
|
| 355 |
-
value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
|
| 356 |
-
|
| 357 |
-
attn_magnitude = torch.matmul(query, key.transpose(-2, -1)) * scale
|
| 358 |
-
attn_magnitude *= attn_bias #* to combine instead of +
|
| 359 |
-
#attn_magnitude = torch.softmax(attn_weight, dim=-1) we dont want this lol
|
| 360 |
-
attn_magnitude = torch.dropout(attn_magnitude, dropout_p, train=True)
|
| 361 |
-
#... broadcasting... if A generalmatmul B, and A has shape (N x 1), B has shape (m x p),
|
| 362 |
-
# 1 is prepended to A in torch broadcasting. then A matmul B. then prepend removed.
|
| 363 |
-
# inplace prepend 1: a.unsqueeze_(0).
|
| 364 |
-
#
|
| 365 |
-
#attn_mag : b h n h_d ...
|
| 366 |
-
#no it *wasn't*! it's b, h, n, n!
|
| 367 |
-
#sdpa output (our v input) without transpose is
|
| 368 |
-
# b h n h_d
|
| 369 |
-
#so maybe we need to transpose sdpa_out by (-2, -1)
|
| 370 |
-
#such that sdpa_out : b h h_d n, allowing
|
| 371 |
-
#torch bmm of mat1:...{n X n} & mat2:...{h_d X n} --> bmmout: ...{h_d X n}
|
| 372 |
-
#print(attn_magnitude.size())
|
| 373 |
-
#print(value.size())
|
| 374 |
-
#attn_magnitude.unsqueeze_(0)
|
| 375 |
-
#...
|
| 376 |
-
#wow okay this is tricky. we were using a ones row reduce.
|
| 377 |
-
#basically last night we were assembling a (b h n_1 n_2) shape through biasq and biask matmul.
|
| 378 |
-
#then we multiplied it by a ones of (b h n h_d),
|
| 379 |
-
#which reduces b h n (n) to b h n (h_d)...
|
| 380 |
-
#where h_d rows are copies of sum(n_2) at each n index in (b h n h_d).
|
| 381 |
-
#meaning attn_II_rev1 was assigning a headwise bias along the entire sequence.
|
| 382 |
-
#which itself would be chosen by the optimizer state transformation evolution of biasq and biask.
|
| 383 |
-
return torch.add(attn_magnitude, value.transpose(-2,-1))
|
| 384 |
-
"""
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
### states take format batch, sequence, embedding
|
| 388 |
-
### therefore
|
| 389 |
-
### batch_size, sequence_length, embedding_dim = h_states.shape
|
| 390 |
-
def reshape_heads_dim(heads, tensor):
|
| 391 |
-
bat_len, seq_len, emb_dim = tensor.size()
|
| 392 |
-
head_len = heads
|
| 393 |
-
# i think equivalent to traditional
|
| 394 |
-
# "b n (h d) -> b h n d"
|
| 395 |
-
tensor = tensor.reshape(bat_len , seq_len, head_len, emb_dim // head_len)
|
| 396 |
-
tensor = tensor.permute(0, 2, 1, 3).reshape(bat_len*head_len, seq_len, emb_dim // head_len)
|
| 397 |
-
return tensor
|
| 398 |
-
|
| 399 |
-
def reshape_dim_heads(heads, tensor):
|
| 400 |
-
bat_len, seq_len, emb_dim = tensor.size()
|
| 401 |
-
head_len = heads
|
| 402 |
-
# i think equivalent to traditional
|
| 403 |
-
# "b h n d -> b n (h d)"
|
| 404 |
-
tensor = tensor.reshape(bat_len // head_len, head_len, seq_len, emb_dim)
|
| 405 |
-
tensor = tensor.permute(0, 2, 1, 3).reshape(bat_len // head_len, seq_len, emb_dim*head_len)
|
| 406 |
-
return tensor
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
###
|
| 410 |
-
### modelwise config:
|
| 411 |
-
### {"vocab_size":8000, "num_layers":4}
|
| 412 |
-
###
|
| 413 |
-
class PGPT_Lformer(nn.Module):
|
| 414 |
-
def __init__(self,config):
|
| 415 |
-
super().__init__()
|
| 416 |
-
self.config = config
|
| 417 |
-
|
| 418 |
-
self.lambdaformer = nn.ModuleDict(dict(
|
| 419 |
-
what_the_embedder_doin = nn.Embedding(config["vocab_size"], config["dim"]),
|
| 420 |
-
blocks = nn.ModuleList([vit22_tformer(config) for _ in range(config["num_layers"])])
|
| 421 |
-
))
|
| 422 |
-
self.tokenpicker_head = nn.Linear(in_features=config["dim"], out_features=config["vocab_size"], bias=False)
|
| 423 |
-
self.tokenpicker_head.weight.data.zero_() #re: @Grad62304977
|
| 424 |
-
|
| 425 |
-
def forward(self, index, targets=None, return_logits=True, return_zloss=False):
|
| 426 |
-
x = self.lambdaformer.what_the_embedder_doin(index) # get token embeddings
|
| 427 |
-
x = nn.functional.rms_norm(x, (x.size(-1),)) #re: @Grad62304977
|
| 428 |
-
for decoder in self.lambdaformer.blocks:
|
| 429 |
-
x = decoder(x)
|
| 430 |
-
x = nn.functional.rms_norm(x, (x.size(-1),)) #re: @Grad62304977
|
| 431 |
-
|
| 432 |
-
if targets is not None:
|
| 433 |
-
#grab some losses woooo
|
| 434 |
-
logits = self.tokenpicker_head(x)
|
| 435 |
-
if return_zloss: #tracking https://arxiv.org/abs/2309.14322
|
| 436 |
-
z = torch.sum(torch.exp(logits)) #reduce: e^logit[j]
|
| 437 |
-
z_loss = torch.log(z)**2 #log and square Z. make sure to set a coefficient in trainer!
|
| 438 |
-
logits = 30 * torch.tanh(logits / 30) # @Grad62304977
|
| 439 |
-
logits = logits.float() # use tf32/fp32 for logits
|
| 440 |
-
loss = nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
|
| 441 |
-
else:
|
| 442 |
-
#kellerjordan optimi
|
| 443 |
-
logits = self.tokenpicker_head(x[:, [-1], :]) # re: kj: note: using list [-1] to preserve the time dim
|
| 444 |
-
logits = 30 * torch.tanh(logits / 30) # @Grad62304977
|
| 445 |
-
logits = logits.float() # use tf32/fp32 for logits
|
| 446 |
-
loss = None
|
| 447 |
-
|
| 448 |
-
#an appeal to performance is made:
|
| 449 |
-
if not return_logits:
|
| 450 |
-
logits = None
|
| 451 |
-
if not return_zloss:
|
| 452 |
-
z_loss = None
|
| 453 |
-
|
| 454 |
return logits, loss, z_loss
|
|
|
|
| 1 |
+
# Import necessary, revised, libraries
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
|
| 6 |
+
#dubious
|
| 7 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 8 |
+
|
| 9 |
+
#hehe
|
| 10 |
+
import math
|
| 11 |
+
|
| 12 |
+
### note: modded_nanogpt.py is an more full container for a transformer block structure
|
| 13 |
+
### it should specify an encoder ("embedder") and decoder for autoregress on tinystories.
|
| 14 |
+
###
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
### 2302.05442's qk-layernorm is layernorm without centering and biases omitted.
|
| 19 |
+
### this is not equivalent to applying rmsnorm to the lexical scope of layernorm,
|
| 20 |
+
### as rmsnorm (1910.07467) doesn't use the mean statistic to yield variance.
|
| 21 |
+
### profiling and benchmarking a p%-pvarnorm would be great further work!
|
| 22 |
+
###
|
| 23 |
+
### to reach the complete spec of 2302.05442,
|
| 24 |
+
### noncentered nonbiased norms must be applied to projected q&k
|
| 25 |
+
###
|
| 26 |
+
### candidate default: cfig =
|
| 27 |
+
### {"dim":768,"dim_head":128,"headcount":6,"ffmult":4,
|
| 28 |
+
### "lambda":False,"layerwisenorm":"layernorm","qknorm":"identitynorm"}
|
| 29 |
+
### candidate tinystories:
|
| 30 |
+
### {"dim":256,"dim_head":32,"headcount":8,"ffmult":4,
|
| 31 |
+
### "lambda":True,"layerwisenorm":"rmsnorm","qknorm":"identitynorm"}
|
| 32 |
+
###
|
| 33 |
+
### 2401.14489 suggests GEneral Matrix Multiplication dim alignment.
|
| 34 |
+
### basically vocabsize%64:=0.
|
| 35 |
+
### and (emb_dim/heads)%2^k:=0 , for some integer k.
|
| 36 |
+
### and (batchsize*sequence_len)%2^k:=0 , for some integer k.
|
| 37 |
+
### this places the smallest possible seqlen at 64@bf16 and 128@fp8
|
| 38 |
+
###
|
| 39 |
+
### ...
|
| 40 |
+
### the swiglu returns to bite us. the presence of that doubled swiggy matrix does something!
|
| 41 |
+
### specifically. uh.
|
| 42 |
+
### actually because we cranked up the swiggy_dim by 2x, it follows all of our scaling rules
|
| 43 |
+
### lmao, lol, lol, lmao, etcetera.
|
| 44 |
+
class vit22_tformer(nn.Module):
|
| 45 |
+
def __init__(self, config):
|
| 46 |
+
super().__init__()
|
| 47 |
+
#query_dim = config["query_dim"] #don't even think about cross_attention
|
| 48 |
+
self.dim = config["dim"]
|
| 49 |
+
self.dim_head = config["dim_head"]
|
| 50 |
+
self.heads = config["headcount"]
|
| 51 |
+
self.weighted_skipnet = config["lambda"]
|
| 52 |
+
self.denseproj_mul = config["ff_mult"]
|
| 53 |
+
#self.naive_causal = config["is_causal_llm"]
|
| 54 |
+
#...
|
| 55 |
+
#self.qknormalized_shape = [config["dim_head"],config["training_seqlen"],config["headcount"],config["dim_head"],]
|
| 56 |
+
self.qknormalized_shape = [config["headcount"],config["dim_head"]]
|
| 57 |
+
self.layerwisenorm = getnorm(config["layerwisenorm"],shape=self.dim)
|
| 58 |
+
self.projnorm = getnorm(config["qknorm"],shape=self.qknormalized_shape)
|
| 59 |
+
|
| 60 |
+
attn_inner_dim = self.dim_head * self.heads
|
| 61 |
+
self.denseproj_inner_dim = self.dim * self.denseproj_mul
|
| 62 |
+
|
| 63 |
+
if "rotary_embedding_base" in config.keys():
|
| 64 |
+
self.rotbase = config["rotary_embedding_base"]
|
| 65 |
+
else:
|
| 66 |
+
self.rotbase = 1000 # hehe
|
| 67 |
+
|
| 68 |
+
self.attention_II = None
|
| 69 |
+
if "attention_deux" in config.keys():
|
| 70 |
+
self.attention_II = True
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
self.rotary = rotarizer(self.dim_head, base=self.rotbase)
|
| 74 |
+
self.learnedlambda = nn.Parameter(torch.tensor(1.0)) #my beloved
|
| 75 |
+
self.fused_swiglu_dim = self.denseproj_inner_dim*2 #this is necessary so the swiglu's two projections can be applied as a single operation.
|
| 76 |
+
self.scale = self.dim_head**-0.5 #this is the 's' in 's'dpa! #exposed for cosine attention reasons!
|
| 77 |
+
self.l2normscale = None
|
| 78 |
+
if config["qknorm"] == "l2norm": #bootleg cosine attention by overloading the scale term in sdpa
|
| 79 |
+
self.l2normscale = nn.Parameter(torch.log(torch.tensor(config["training_seqlen"]**2)-torch.tensor(config["training_seqlen"])))
|
| 80 |
+
|
| 81 |
+
#...
|
| 82 |
+
self.queryproj = nn.Linear(in_features=self.dim, out_features=self.dim, bias=False)
|
| 83 |
+
self.keyproj = nn.Linear(in_features=self.dim, out_features=self.dim, bias=False)
|
| 84 |
+
self.valueproj = nn.Linear(in_features=self.dim, out_features=self.dim, bias=False)
|
| 85 |
+
self.attnoutproj = nn.Linear(in_features=attn_inner_dim, out_features=self.dim, bias=True)
|
| 86 |
+
|
| 87 |
+
if self.attention_II:
|
| 88 |
+
self.queryBproj = nn.Linear(in_features=self.dim, out_features=self.dim, bias=False)
|
| 89 |
+
self.keyBproj = nn.Linear(in_features=self.dim, out_features=self.dim, bias=False)
|
| 90 |
+
|
| 91 |
+
#dense ('mlp', 'feedforward', 'fully connected', ...) unit
|
| 92 |
+
self.fused_denseproj_in = nn.Linear(in_features=self.dim, out_features=self.fused_swiglu_dim, bias=True) #this is the vit22b part
|
| 93 |
+
self.dense_swiggy = swiglu() #this is kind of superfluous but this is pedagogical programming!
|
| 94 |
+
self.denseproj_out = nn.Linear(in_features=self.denseproj_inner_dim, out_features=self.dim, bias=True)
|
| 95 |
+
|
| 96 |
+
#[x]
|
| 97 |
+
def self_attn(self, x, bat_len, seq_len):
|
| 98 |
+
#norm -> {qkvproj -> qknorm{?}
|
| 99 |
+
#reshape_h_d -> attn -> reshape_d_h} -> attnoutproj
|
| 100 |
+
#project
|
| 101 |
+
query = self.queryproj(x)
|
| 102 |
+
key = self.keyproj(x)
|
| 103 |
+
value = self.valueproj(x)
|
| 104 |
+
|
| 105 |
+
if self.attention_II:
|
| 106 |
+
biasquery = self.queryBproj(x)
|
| 107 |
+
biaskey = self.keyBproj(x)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
#reshape to bundled up matmul formme
|
| 111 |
+
#query = reshape_heads_dim(self.heads, query)
|
| 112 |
+
#key = reshape_heads_dim(self.heads, key)
|
| 113 |
+
#value = reshape_heads_dim(self.heads, value)
|
| 114 |
+
#alternate reshape for compatibility with modded-nanogpt roformer
|
| 115 |
+
query = query.view(bat_len, seq_len, self.heads, self.dim_head)
|
| 116 |
+
key = key.view(bat_len, seq_len, self.heads, self.dim_head)
|
| 117 |
+
value = value.view(bat_len, seq_len, self.heads, self.dim_head)
|
| 118 |
+
|
| 119 |
+
if self.attention_II:
|
| 120 |
+
biasquery = biasquery.view(bat_len, seq_len, self.heads, self.dim_head)
|
| 121 |
+
biaskey = biaskey.view(bat_len, seq_len, self.heads, self.dim_head)
|
| 122 |
+
|
| 123 |
+
#pos_emb suggested before qknorm re: kellerjordan re: @Grad62304977
|
| 124 |
+
#but we get an error for the x.ndim assertion if we run this after reshaping. whoopsie!
|
| 125 |
+
cos, sin = self.rotary(query) #our rotary unit does the shape detection from states
|
| 126 |
+
|
| 127 |
+
#qk*norm
|
| 128 |
+
query = self.projnorm(query)
|
| 129 |
+
key = self.projnorm(key)
|
| 130 |
+
|
| 131 |
+
if self.attention_II:
|
| 132 |
+
biasquery = self.projnorm(biasquery)
|
| 133 |
+
biaskey = self.projnorm(biaskey)
|
| 134 |
+
|
| 135 |
+
#rotary embed after qknorm as suggested etc.
|
| 136 |
+
query = apply_rotarizer_emb(query, cos, sin)
|
| 137 |
+
key = apply_rotarizer_emb(key, cos, sin)
|
| 138 |
+
|
| 139 |
+
if self.attention_II:
|
| 140 |
+
biasquery = apply_rotarizer_emb(biasquery, cos, sin)
|
| 141 |
+
biaskey = apply_rotarizer_emb(biaskey, cos, sin)
|
| 142 |
+
|
| 143 |
+
#laser-attn goes here
|
| 144 |
+
#...
|
| 145 |
+
|
| 146 |
+
#if we were here to explain attention instead of projections and norms,
|
| 147 |
+
#we would have written this in jax or a language that compiles well!
|
| 148 |
+
#instead, to benefit from flash attention 2, we want to use torch SDPA!
|
| 149 |
+
if self.l2normscale is not None:
|
| 150 |
+
y = self.l2normscale*nn.functional.scaled_dot_product_attention(query.transpose(1,2), key.transpose(1,2), value.transpose(1,2), scale=1, is_causal=True)
|
| 151 |
+
else:
|
| 152 |
+
y = nn.functional.scaled_dot_product_attention(query.transpose(1,2), key.transpose(1,2), value.transpose(1,2), scale=self.scale, is_causal=True)
|
| 153 |
+
|
| 154 |
+
if self.attention_II:
|
| 155 |
+
#REV1
|
| 156 |
+
dud = torch.ones_like(value, dtype=query.dtype, device=query.device)
|
| 157 |
+
y = y + scaled_dot_product_attn_bias( #~~attempt to reuse whatever efficient kernels we have already~~ nvm
|
| 158 |
+
biasquery.transpose(1,2) , biaskey.transpose(1,2) , dud.transpose(1,2),
|
| 159 |
+
scale=self.scale, is_causal=True
|
| 160 |
+
)
|
| 161 |
+
"""
|
| 162 |
+
#REV2
|
| 163 |
+
#attn_bias now sums the shift matrix within the attn_bias operation to our 'value' target.
|
| 164 |
+
y = scaled_dot_product_attn_bias( #~~attempt to reuse whatever efficient kernels we have already~~ nvm
|
| 165 |
+
biasquery.transpose(1,2), biaskey.transpose(1,2), y,
|
| 166 |
+
scale=self.scale, is_causal=True
|
| 167 |
+
)
|
| 168 |
+
"""
|
| 169 |
+
|
| 170 |
+
#reshape scalars from folded position to unfolded position so the ribosome can read the messenger headrna
|
| 171 |
+
#y = self.reshape_dim_heads(self.heads, y)
|
| 172 |
+
#alternate reshape scalars
|
| 173 |
+
y = y.transpose(1,2).contiguous().view_as(x) #thanks a bunch modded-nanogpt
|
| 174 |
+
|
| 175 |
+
#laser-attn unscale goes here
|
| 176 |
+
#...
|
| 177 |
+
|
| 178 |
+
return self.attnoutproj(y)
|
| 179 |
+
|
| 180 |
+
#[x]
|
| 181 |
+
def feedfor(self,x):
|
| 182 |
+
x = self.fused_denseproj_in(x)
|
| 183 |
+
x = self.dense_swiggy(x)
|
| 184 |
+
x = self.denseproj_out(x)
|
| 185 |
+
return x
|
| 186 |
+
|
| 187 |
+
#parallel forward from kingoflolz/mesh-transformer-jax/! check it out!!
|
| 188 |
+
# "discovered by Wang et al + EleutherAI from GPT-J fame"
|
| 189 |
+
def forward(self, h_states):
|
| 190 |
+
# in trad dialect: b->batch, n,i,j,k,l,m,f,a,o -> sequentiality dims, h->heads, d->embedding dim
|
| 191 |
+
bat_len, seq_len, emb_dim = h_states.size()
|
| 192 |
+
# ^ detritus from modded-nanogpt transpose implementation. profile later ig.
|
| 193 |
+
|
| 194 |
+
# highly traditional pre layernorm
|
| 195 |
+
inner_states = self.layerwisenorm(h_states)
|
| 196 |
+
|
| 197 |
+
#crunchy parts
|
| 198 |
+
attn_out = self.self_attn(inner_states, bat_len, seq_len)
|
| 199 |
+
dense_out = self.feedfor(inner_states)
|
| 200 |
+
if self.weighted_skipnet==True:
|
| 201 |
+
skip_out = h_states*self.learnedlambda
|
| 202 |
+
else:
|
| 203 |
+
skip_out = h_states
|
| 204 |
+
#output w/ unabstracted resnet
|
| 205 |
+
return skip_out + dense_out + attn_out
|
| 206 |
+
|
| 207 |
+
def getnorm(type, shape=None):
|
| 208 |
+
if type == "layernorm":
|
| 209 |
+
return nn.LayerNorm(shape, elementwise_affine=True, bias=True)
|
| 210 |
+
elif type == "layernorm-nobias":
|
| 211 |
+
return nn.LayerNorm(shape, elementwise_affine=True, bias=False) #???
|
| 212 |
+
elif type == "rmsnorm":
|
| 213 |
+
return nn.RMSNorm(shape, elementwise_affine=False)
|
| 214 |
+
elif type == "dynamic_shape_rmsnorm":
|
| 215 |
+
return dynamic_shape_rmsnorm()
|
| 216 |
+
elif type == "dynamic_shape_layernorm":
|
| 217 |
+
return dynamic_shape_layernorm()
|
| 218 |
+
elif type == "l2norm":
|
| 219 |
+
return l2norm() #un function
|
| 220 |
+
elif type == "identitynorm":
|
| 221 |
+
return identitynorm(shape)
|
| 222 |
+
else:
|
| 223 |
+
raise Exception("Not implemented")
|
| 224 |
+
|
| 225 |
+
class l2norm(nn.Module): #haha
|
| 226 |
+
def forward(self, inputter, **kwargs):
|
| 227 |
+
inputter = nn.functional.normalize(inputter, p=2, dim=-1)
|
| 228 |
+
return inputter
|
| 229 |
+
|
| 230 |
+
def identitynorm(row):
|
| 231 |
+
return nn.Identity(row)
|
| 232 |
+
|
| 233 |
+
#from `questions/76067020/`` lol
|
| 234 |
+
class dynamic_shape_rmsnorm(nn.Module):
|
| 235 |
+
def forward(self, inputter, **kwargs):
|
| 236 |
+
inputter = inputter.transpose(1,2) #rotate!
|
| 237 |
+
#i am so sorry haha
|
| 238 |
+
#normalized_shape seems to require adjacencies, i tried a few other things first.
|
| 239 |
+
#wait the notation in the paper suggests... [3:].
|
| 240 |
+
inner_shape = inputter.size()[3:]
|
| 241 |
+
|
| 242 |
+
nn.functional.rms_norm(inputter, normalized_shape=inner_shape, **kwargs)
|
| 243 |
+
inputter = inputter.transpose(1,2) #reverse rotate!
|
| 244 |
+
return inputter
|
| 245 |
+
|
| 246 |
+
class dynamic_shape_layernorm(nn.Module):
|
| 247 |
+
def forward(self, inputter, **kwargs):
|
| 248 |
+
inputter = inputter.transpose(1,2) #rotate!
|
| 249 |
+
#i am so sorry haha
|
| 250 |
+
#normalized_shape seems to require adjacencies, i tried a few other things first.
|
| 251 |
+
#wait the notation in the paper suggests... [3:].
|
| 252 |
+
inner_shape = inputter.size()[3:]
|
| 253 |
+
|
| 254 |
+
nn.functional.layer_norm(inputter, normalized_shape=inner_shape, **kwargs)
|
| 255 |
+
inputter = inputter.transpose(1,2) #reverse rotate!
|
| 256 |
+
return inputter
|
| 257 |
+
|
| 258 |
+
#we too are hitting that mfing noam shazeer https://arxiv.org/pdf/2002.05202
|
| 259 |
+
#if there was a self-gated ELU id want to use it instead though
|
| 260 |
+
class swiglu(nn.Module):
|
| 261 |
+
def forward(self, x):
|
| 262 |
+
x, gate = x.chunk(2, dim=-1)
|
| 263 |
+
return nn.functional.silu(gate) * x
|
| 264 |
+
|
| 265 |
+
#rippin this one from modded-nanogpt
|
| 266 |
+
class rotarizer(nn.Module):
|
| 267 |
+
def __init__(self, dim, base=1000): #shhh don't tell anyone about the rotemb base
|
| 268 |
+
super().__init__()
|
| 269 |
+
self.inv_freq = (base ** (torch.arange(0,dim,2).float() / dim))**-1
|
| 270 |
+
self.seq_len_cached = None
|
| 271 |
+
self.cos_cached = None
|
| 272 |
+
self.sin_cached = None
|
| 273 |
+
|
| 274 |
+
def forward(self, x):
|
| 275 |
+
seq_len = x.size()[1] #perform the surgical LENGTH,YOINKEMS,{}, b {n} h d
|
| 276 |
+
#using torch tensor.size()[idx] notation bc i think it is more explicit than shape[]
|
| 277 |
+
if seq_len != self.seq_len_cached:
|
| 278 |
+
self.seq_len_cached = seq_len
|
| 279 |
+
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
|
| 280 |
+
reg_freqs = torch.outer(t, self.inv_freq).to(x.device)
|
| 281 |
+
self.cos_cached = reg_freqs.cos().bfloat16()
|
| 282 |
+
self.sin_cached = reg_freqs.sin().bfloat16()
|
| 283 |
+
return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]
|
| 284 |
+
#yeah slay em with the list comprehensions, cited author 😒
|
| 285 |
+
|
| 286 |
+
def apply_rotarizer_emb(x, cos, sin):
|
| 287 |
+
#assert x.ndim == 4 # b n h d
|
| 288 |
+
d = x.size()[3]//2 # perform the superb DIVIDE,2,LENGTH,YOINKEMS,{}, b n h {d}
|
| 289 |
+
x1 = x[..., :d] #some kind of slicing mystery code
|
| 290 |
+
x2 = x[..., d:]
|
| 291 |
+
y1 = x1 * cos + x2 * sin
|
| 292 |
+
y2 = x1 * (-sin) + x2 * cos
|
| 293 |
+
return torch.cat([y1,y2], 3).type_as(x)
|
| 294 |
+
|
| 295 |
+
#alternate attention to retrieve a shift matrix instead of scale matrix.
|
| 296 |
+
#this will either break the first time it runs or make perfect sense whomstdve doubted it all along
|
| 297 |
+
#REVISION 1:
|
| 298 |
+
#"""
|
| 299 |
+
def scaled_dot_product_attn_bias(query, key, value, attn_mask=None, dropout_p=0.0,
|
| 300 |
+
is_causal=False, scale=None, enable_gqa=False):
|
| 301 |
+
#make sure you compile this or it will be slow! haha! it will be slow otherwise! haha!
|
| 302 |
+
L, S = query.size(-2), key.size(-2)
|
| 303 |
+
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
|
| 304 |
+
#inversion of normal masking since we're not softmaxing
|
| 305 |
+
attn_bias = torch.ones(L, S, dtype=query.dtype, device=query.device)
|
| 306 |
+
|
| 307 |
+
if is_causal: #sounds caus-tly to change haha heeehehee
|
| 308 |
+
assert attn_mask is None
|
| 309 |
+
temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
|
| 310 |
+
attn_bias.masked_fill_(temp_mask.logical_not(), float("0")) #0 not neginf
|
| 311 |
+
attn_bias.to(query.dtype)
|
| 312 |
+
|
| 313 |
+
if attn_mask is not None: #more boilerplate ty pytorch
|
| 314 |
+
if attn_mask.dtype == torch.bool:
|
| 315 |
+
attn_bias.masked_fill_(attn_mask.logical_not(), float("0")) #0 not neginf
|
| 316 |
+
else:
|
| 317 |
+
attn_bias *= attn_mask
|
| 318 |
+
|
| 319 |
+
if enable_gqa: #who can say what this does
|
| 320 |
+
key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
|
| 321 |
+
value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
|
| 322 |
+
|
| 323 |
+
attn_magnitude = torch.matmul(query, key.transpose(-2, -1)) * scale
|
| 324 |
+
attn_magnitude *= attn_bias #* to combine instead of +
|
| 325 |
+
#attn_magnitude = torch.softmax(attn_weight, dim=-1) we dont want this lol
|
| 326 |
+
attn_magnitude = torch.dropout(attn_magnitude, dropout_p, train=True)
|
| 327 |
+
return attn_magnitude @ value
|
| 328 |
+
#"""
|
| 329 |
+
#REVISION 2: this doesn't benefit from abstract syntactic similarity to torch sdpa. so we gut it!
|
| 330 |
+
#instead of creating a duds matrix of 1s to occupy the 'value' idx, we sum the shift-QK product directly
|
| 331 |
+
#uncompiled this maybe uses fewer ops; profile and find out.
|
| 332 |
+
"""
|
| 333 |
+
def scaled_dot_product_attn_bias(query, key, value, attn_mask=None, dropout_p=0.0,
|
| 334 |
+
is_causal=False, scale=None, enable_gqa=False):
|
| 335 |
+
#make sure you compile this or it will be slow! haha! it will be slow otherwise! haha!
|
| 336 |
+
L, S = query.size(-2), key.size(-2)
|
| 337 |
+
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
|
| 338 |
+
#inversion of normal masking since we're not softmaxing
|
| 339 |
+
attn_bias = torch.ones(L, S, dtype=query.dtype, device=query.device)
|
| 340 |
+
|
| 341 |
+
if is_causal: #sounds caus-tly to change haha heeehehee
|
| 342 |
+
assert attn_mask is None
|
| 343 |
+
temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
|
| 344 |
+
attn_bias.masked_fill_(temp_mask.logical_not(), float("0")) #0 not neginf
|
| 345 |
+
attn_bias.to(query.dtype)
|
| 346 |
+
|
| 347 |
+
if attn_mask is not None: #more boilerplate ty pytorch
|
| 348 |
+
if attn_mask.dtype == torch.bool:
|
| 349 |
+
attn_bias.masked_fill_(attn_mask.logical_not(), float("0")) #0 not neginf
|
| 350 |
+
else:
|
| 351 |
+
attn_bias *= attn_mask
|
| 352 |
+
|
| 353 |
+
if enable_gqa: #who can say what this does
|
| 354 |
+
key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
|
| 355 |
+
value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
|
| 356 |
+
|
| 357 |
+
attn_magnitude = torch.matmul(query, key.transpose(-2, -1)) * scale
|
| 358 |
+
attn_magnitude *= attn_bias #* to combine instead of +
|
| 359 |
+
#attn_magnitude = torch.softmax(attn_weight, dim=-1) we dont want this lol
|
| 360 |
+
attn_magnitude = torch.dropout(attn_magnitude, dropout_p, train=True)
|
| 361 |
+
#... broadcasting... if A generalmatmul B, and A has shape (N x 1), B has shape (m x p),
|
| 362 |
+
# 1 is prepended to A in torch broadcasting. then A matmul B. then prepend removed.
|
| 363 |
+
# inplace prepend 1: a.unsqueeze_(0).
|
| 364 |
+
#
|
| 365 |
+
#attn_mag : b h n h_d ...
|
| 366 |
+
#no it *wasn't*! it's b, h, n, n!
|
| 367 |
+
#sdpa output (our v input) without transpose is
|
| 368 |
+
# b h n h_d
|
| 369 |
+
#so maybe we need to transpose sdpa_out by (-2, -1)
|
| 370 |
+
#such that sdpa_out : b h h_d n, allowing
|
| 371 |
+
#torch bmm of mat1:...{n X n} & mat2:...{h_d X n} --> bmmout: ...{h_d X n}
|
| 372 |
+
#print(attn_magnitude.size())
|
| 373 |
+
#print(value.size())
|
| 374 |
+
#attn_magnitude.unsqueeze_(0)
|
| 375 |
+
#...
|
| 376 |
+
#wow okay this is tricky. we were using a ones row reduce.
|
| 377 |
+
#basically last night we were assembling a (b h n_1 n_2) shape through biasq and biask matmul.
|
| 378 |
+
#then we multiplied it by a ones of (b h n h_d),
|
| 379 |
+
#which reduces b h n (n) to b h n (h_d)...
|
| 380 |
+
#where h_d rows are copies of sum(n_2) at each n index in (b h n h_d).
|
| 381 |
+
#meaning attn_II_rev1 was assigning a headwise bias along the entire sequence.
|
| 382 |
+
#which itself would be chosen by the optimizer state transformation evolution of biasq and biask.
|
| 383 |
+
return torch.add(attn_magnitude, value.transpose(-2,-1))
|
| 384 |
+
"""
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
### states take format batch, sequence, embedding
|
| 388 |
+
### therefore
|
| 389 |
+
### batch_size, sequence_length, embedding_dim = h_states.shape
|
| 390 |
+
def reshape_heads_dim(heads, tensor):
|
| 391 |
+
bat_len, seq_len, emb_dim = tensor.size()
|
| 392 |
+
head_len = heads
|
| 393 |
+
# i think equivalent to traditional
|
| 394 |
+
# "b n (h d) -> b h n d"
|
| 395 |
+
tensor = tensor.reshape(bat_len , seq_len, head_len, emb_dim // head_len)
|
| 396 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(bat_len*head_len, seq_len, emb_dim // head_len)
|
| 397 |
+
return tensor
|
| 398 |
+
|
| 399 |
+
def reshape_dim_heads(heads, tensor):
|
| 400 |
+
bat_len, seq_len, emb_dim = tensor.size()
|
| 401 |
+
head_len = heads
|
| 402 |
+
# i think equivalent to traditional
|
| 403 |
+
# "b h n d -> b n (h d)"
|
| 404 |
+
tensor = tensor.reshape(bat_len // head_len, head_len, seq_len, emb_dim)
|
| 405 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(bat_len // head_len, seq_len, emb_dim*head_len)
|
| 406 |
+
return tensor
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
###
|
| 410 |
+
### modelwise config:
|
| 411 |
+
### {"vocab_size":8000, "num_layers":4}
|
| 412 |
+
###
|
| 413 |
+
class PGPT_Lformer(nn.Module):
|
| 414 |
+
def __init__(self,config):
|
| 415 |
+
super().__init__()
|
| 416 |
+
self.config = config
|
| 417 |
+
|
| 418 |
+
self.lambdaformer = nn.ModuleDict(dict(
|
| 419 |
+
what_the_embedder_doin = nn.Embedding(config["vocab_size"], config["dim"]),
|
| 420 |
+
blocks = nn.ModuleList([vit22_tformer(config) for _ in range(config["num_layers"])])
|
| 421 |
+
))
|
| 422 |
+
self.tokenpicker_head = nn.Linear(in_features=config["dim"], out_features=config["vocab_size"], bias=False)
|
| 423 |
+
self.tokenpicker_head.weight.data.zero_() #re: @Grad62304977
|
| 424 |
+
|
| 425 |
+
def forward(self, index, targets=None, return_logits=True, return_zloss=False):
|
| 426 |
+
x = self.lambdaformer.what_the_embedder_doin(index) # get token embeddings
|
| 427 |
+
x = nn.functional.rms_norm(x, (x.size(-1),)) #re: @Grad62304977
|
| 428 |
+
for decoder in self.lambdaformer.blocks:
|
| 429 |
+
x = decoder(x)
|
| 430 |
+
x = nn.functional.rms_norm(x, (x.size(-1),)) #re: @Grad62304977
|
| 431 |
+
|
| 432 |
+
if targets is not None:
|
| 433 |
+
#grab some losses woooo
|
| 434 |
+
logits = self.tokenpicker_head(x)
|
| 435 |
+
if return_zloss: #tracking https://arxiv.org/abs/2309.14322
|
| 436 |
+
z = torch.sum(torch.exp(logits)) #reduce: e^logit[j]
|
| 437 |
+
z_loss = torch.log(z)**2 #log and square Z. make sure to set a coefficient in trainer!
|
| 438 |
+
logits = 30 * torch.tanh(logits / 30) # @Grad62304977
|
| 439 |
+
logits = logits.float() # use tf32/fp32 for logits
|
| 440 |
+
loss = nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
|
| 441 |
+
else:
|
| 442 |
+
#kellerjordan optimi
|
| 443 |
+
logits = self.tokenpicker_head(x[:, [-1], :]) # re: kj: note: using list [-1] to preserve the time dim
|
| 444 |
+
logits = 30 * torch.tanh(logits / 30) # @Grad62304977
|
| 445 |
+
logits = logits.float() # use tf32/fp32 for logits
|
| 446 |
+
loss = None
|
| 447 |
+
|
| 448 |
+
#an appeal to performance is made:
|
| 449 |
+
if not return_logits:
|
| 450 |
+
logits = None
|
| 451 |
+
if not return_zloss:
|
| 452 |
+
z_loss = None
|
| 453 |
+
|
| 454 |
return logits, loss, z_loss
|
requirements.txt
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
-
torch
|
| 2 |
-
tiktoken
|
| 3 |
-
huggingface_hub
|
| 4 |
-
gradio
|
| 5 |
matplotlib
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
tiktoken
|
| 3 |
+
huggingface_hub
|
| 4 |
+
gradio
|
| 5 |
matplotlib
|