import gradio as gr
import torch
import time
import os
from huggingface_hub import hf_hub_download
import tiktoken
import pgptlformer # Your model definition file
import matplotlib.pyplot as plt
import numpy as np
from contextlib import nullcontext
# --- Configuration ---
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DTYPE = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
PTDTYPE = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[DTYPE]
CTX = nullcontext() if DEVICE == 'cpu' else torch.amp.autocast(device_type=DEVICE, dtype=PTDTYPE)
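# Note: the autocast context is a no-op on CPU; on CUDA, bfloat16 is preferred
# when supported because it keeps float32's exponent range and is less prone
# to overflow than float16.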
TORCH_COMPILE = False  # torch.compile warm-up can exceed hosted-demo startup timeouts; keep False for stability.
# --- Model Loading ---
@torch.no_grad()
def load_model(repo_id, filename, config_override=None):
    """Loads a model checkpoint from the Hugging Face Hub."""
    print(f"Loading model: {repo_id}/{filename}...")
    try:
        ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
        checkpoint = torch.load(ckpt_path, map_location=DEVICE)
        tformer_cfg = checkpoint['model_args']
        if config_override:
            tformer_cfg.update(config_override)
        model = pgptlformer.PGPT_Lformer(tformer_cfg)
        state_dict = checkpoint['model']
        # Strip the '_orig_mod.' prefix that torch.compile adds to checkpointed keys
        unwanted_prefix = '_orig_mod.'
        for k in list(state_dict.keys()):
            if k.startswith(unwanted_prefix):
                state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
        model.load_state_dict(state_dict, strict=False)  # strict=False tolerates minor key mismatches
        model.eval()
        model.to(DEVICE)
        if TORCH_COMPILE:
            model = torch.compile(model)
        print(f"Model {filename} loaded successfully.")
        return model, tformer_cfg
    except Exception as e:
        print(f"Error loading model {filename}: {e}")
        raise
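# Expected checkpoint layout, inferred from the keys read above (an assumption
# about the training script's save format, not a documented contract):
#   {'model_args': <config dict>, 'model': <state_dict>, ...}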
# Load both models once at the start
try:
    # Baseline model: standard attention, trained on TinyStories
    BASELINE_MODEL, BASELINE_CFG = load_model(
        repo_id="SQCU/pgptlformer-tinystories",
        filename="re-pqt-rmsXrmsx2-70b91221-a39c-4824-a69c-48a034963529/state_step040500.pt"
    )
    # Shift-attn model: same architecture with the attention_deux mechanism enabled
    SHIFT_ATTN_MODEL, SHIFT_ATTN_CFG = load_model(
        repo_id="SQCU/pgptlformer-tinystories",
        filename="re-pqt-rmsXrmsx2x2-ATTNII-791967c5-5c59-4a5f-a2c5-07772bcf65ab/state_step040500.pt",
        config_override={"attention_deux": True}
    )
except Exception as e:
    # If loading fails, surface the error in the Gradio app instead of crashing
    BASELINE_MODEL, SHIFT_ATTN_MODEL = None, None
    ERROR_MESSAGE = f"Failed to load models. Please check logs. Error: {e}"
# --- Inference and Metrics ---
ENC = tiktoken.get_encoding("gpt2")
ENCODE = lambda s: ENC.encode(s, allowed_special={"<|endoftext|>"})
DECODE = lambda l: ENC.decode(l)
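# GPT-2 BPE round-trip: DECODE(ENCODE(s)) == s for ordinary text, and
# allowed_special lets prompts contain the <|endoftext|> document separator
# without tiktoken raising an error.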
@torch.no_grad()
def generate_and_measure(model, prompt_ids, max_new_tokens=50):
    """Runs autoregressive sampling and collects performance metrics."""
    # Reset peak-memory stats so VRAM is measured for this run only
    if DEVICE == 'cuda':
        torch.cuda.reset_peak_memory_stats(DEVICE)
        torch.cuda.synchronize()
    start_time = time.time()
    # --- Generation loop ---
    model_logits = []
    generated_ids = prompt_ids
    for _ in range(max_new_tokens):
        # Crop to the last 1024 tokens (the assumed block size of the model)
        idx_cond = generated_ids if generated_ids.size(1) <= 1024 else generated_ids[:, -1024:]
        with CTX:  # run the forward pass under the mixed-precision context
            logits, _, _ = model(idx_cond, return_logits=True)
        final_logits = logits[:, -1, :]
        model_logits.append(final_logits)  # Keep logits for the perplexity/sharpening metrics
        probs = torch.nn.functional.softmax(final_logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        generated_ids = torch.cat((generated_ids, idx_next), dim=1)
    if DEVICE == 'cuda':
        torch.cuda.synchronize()
    end_time = time.time()
    # --- Metrics ---
    # 1. Inference speed (tokens/second)
    tokens_per_sec = max_new_tokens / (end_time - start_time)
    # 2. Peak VRAM usage (MB)
    vram_usage = torch.cuda.max_memory_allocated(DEVICE) / (1024**2) if DEVICE == 'cuda' else 0
    # 3. Pseudo-perplexity over the generated continuation (see note below);
    #    cast to float32 for a numerically stable loss under autocast
    all_logits = torch.cat(model_logits, dim=0).float()
    target_ids = generated_ids[0, -max_new_tokens:]
    cross_entropy = torch.nn.functional.cross_entropy(all_logits, target_ids)
    pseudo_perplexity = torch.exp(cross_entropy).item()
    # 4. Logit sharpening (mean top-1 probability across steps)
    avg_max_prob = torch.nn.functional.softmax(all_logits, dim=-1).max(dim=-1).values.mean().item()
    # --- Decode and return ---
    output_text = DECODE(generated_ids[0].tolist())
    metrics = {
        'Tokens/Sec': tokens_per_sec,
        'VRAM (MB)': vram_usage,
        'Perplexity': pseudo_perplexity,
        'Logit Sharpening': avg_max_prob,
    }
    return output_text, metrics
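# Metric notes:
# - 'Perplexity' is exp(mean cross-entropy) of each step's logits against the
#   token the model itself sampled at that step, i.e. a self-perplexity over
#   the generated continuation, not perplexity on held-out text.
# - 'Logit Sharpening' is the mean top-1 probability per step; values near 1.0
#   mean sampling is close to deterministic.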
# --- Visualization ---
def plot_radar_chart(baseline_metrics, shift_attn_metrics):
    """Creates a radar chart comparing the two models."""
    labels = list(baseline_metrics.keys())
    baseline_stats = list(baseline_metrics.values())
    shift_attn_stats = list(shift_attn_metrics.values())
    # Convert everything to a "higher is better" scale: invert VRAM and Perplexity.
    baseline_plot_stats = [
        baseline_stats[0],               # Tokens/Sec (higher is better)
        1 / (baseline_stats[1] + 1e-6),  # 1 / VRAM
        1 / (baseline_stats[2] + 1e-6),  # 1 / Perplexity
        baseline_stats[3],               # Logit Sharpening (higher is better)
    ]
    shift_attn_plot_stats = [
        shift_attn_stats[0],
        1 / (shift_attn_stats[1] + 1e-6),
        1 / (shift_attn_stats[2] + 1e-6),
        shift_attn_stats[3],
    ]
    # Normalize each axis by the larger of the two models' values so metrics
    # with wildly different scales (tokens/sec vs. 1/VRAM) share one radial
    # scale; the better model scores 1.0 on each axis.
    for i in range(len(labels)):
        axis_max = max(baseline_plot_stats[i], shift_attn_plot_stats[i], 1e-9)
        baseline_plot_stats[i] /= axis_max
        shift_attn_plot_stats[i] /= axis_max
    angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
    # Close the polygons so the plot is circular
    baseline_plot_stats += baseline_plot_stats[:1]
    shift_attn_plot_stats += shift_attn_plot_stats[:1]
    angles += angles[:1]
    fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
    ax.set_ylim(0, 1.1)
    # Axis labels
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels([
        "Tokens/Sec\n(Higher is Better)",
        "1 / VRAM\n(Higher is Better)",
        "1 / Perplexity\n(Higher is Better)",
        "Logit Sharpening\n(Higher is Better)",
    ])
    # Plot data
    ax.plot(angles, baseline_plot_stats, 'o-', linewidth=2, label="Baseline")
    ax.fill(angles, baseline_plot_stats, alpha=0.25)
    ax.plot(angles, shift_attn_plot_stats, 'o-', linewidth=2, label="Shift-Attn")
    ax.fill(angles, shift_attn_plot_stats, alpha=0.25)
    ax.set_title("Model Performance Comparison", size=20, color='gray', y=1.1)
    ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
    plt.tight_layout()
    return fig
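# Quick offline sanity check for the chart (hypothetical numbers, not
# measurements from either model):
#   fig = plot_radar_chart(
#       {'Tokens/Sec': 120.0, 'VRAM (MB)': 900.0, 'Perplexity': 8.5, 'Logit Sharpening': 0.42},
#       {'Tokens/Sec': 110.0, 'VRAM (MB)': 850.0, 'Perplexity': 7.9, 'Logit Sharpening': 0.47},
#   )
#   fig.savefig('radar_smoke.png')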
# --- Gradio Interface ---
def run_comparison(prompt, max_new_tokens):
    if BASELINE_MODEL is None or SHIFT_ATTN_MODEL is None:
        raise gr.Error(ERROR_MESSAGE)
    if not prompt.strip():
        raise gr.Error("Please enter a non-empty prompt.")
    max_new_tokens = int(max_new_tokens)  # sliders may deliver floats
    input_ids = ENCODE(prompt)
    x = torch.tensor(input_ids, dtype=torch.long, device=DEVICE)[None, ...]
    # Run both models on the same prompt
    baseline_text, baseline_metrics = generate_and_measure(BASELINE_MODEL, x, max_new_tokens)
    shift_attn_text, shift_attn_metrics = generate_and_measure(SHIFT_ATTN_MODEL, x, max_new_tokens)
    # Build the comparison chart
    chart = plot_radar_chart(baseline_metrics, shift_attn_metrics)
    return baseline_text, shift_attn_text, chart
with gr.Blocks(theme=gr.themes.Base()) as demo:
    gr.Markdown("# `shift-attn`: A Live Demonstration")
    gr.Markdown(
        "This demo compares a baseline `pgptlformer` model against an identical model enhanced with the `shift-attn` mechanism (`attention_deux`). "
        "The radar chart visualizes key performance and efficiency metrics, where a larger area indicates a better overall model."
    )
    with gr.Row():
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(label="Enter your prompt:", value="The quick brown fox")
            token_slider = gr.Slider(minimum=10, maximum=200, value=50, step=1, label="Max New Tokens")
            submit_btn = gr.Button("Compare Models", variant="primary")
        with gr.Column(scale=2):
            plot_output = gr.Plot(label="Performance Radar Chart")
    with gr.Row():
        baseline_output = gr.Textbox(label="Baseline Model Output", lines=8)
        shift_attn_output = gr.Textbox(label="Shift-Attn Model Output", lines=8)
    submit_btn.click(
        fn=run_comparison,
        inputs=[prompt_input, token_slider],
        outputs=[baseline_output, shift_attn_output, plot_output]
    )

if __name__ == "__main__":
    demo.launch()
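# To run locally (assumes the pgptlformer module is importable from this
# directory; a CUDA GPU is optional but recommended):
#   pip install gradio torch tiktoken huggingface_hub matplotlib numpy
#   python app.py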