import gradio as gr
import torch
import time
from huggingface_hub import hf_hub_download
import tiktoken
import pgptlformer  # Your model definition file
import matplotlib.pyplot as plt
import numpy as np
from contextlib import nullcontext

# --- Configuration ---
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DTYPE = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
PTDTYPE = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[DTYPE]
CTX = nullcontext() if DEVICE == 'cpu' else torch.amp.autocast(device_type=DEVICE, dtype=PTDTYPE)
TORCH_COMPILE = False  # Compilation can exceed Gradio startup timeouts; keep False for stability.

# --- Model Loading ---
@torch.no_grad()
def load_model(repo_id, filename, config_override=None):
    """Loads a model checkpoint from the Hugging Face Hub."""
    print(f"Loading model: {repo_id}/{filename}...")
    try:
        ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
        checkpoint = torch.load(ckpt_path, map_location=DEVICE)
        tformer_cfg = checkpoint['model_args']
        if config_override:
            tformer_cfg.update(config_override)
        model = pgptlformer.PGPT_Lformer(tformer_cfg)
        state_dict = checkpoint['model']
        # Strip the torch.compile wrapper prefix if the checkpoint was saved from a compiled model.
        unwanted_prefix = '_orig_mod.'
        for k, v in list(state_dict.items()):
            if k.startswith(unwanted_prefix):
                state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
        model.load_state_dict(state_dict, strict=False)  # strict=False tolerates minor key mismatches
        model.eval()
        model.to(DEVICE)
        if TORCH_COMPILE:
            model = torch.compile(model)
        print(f"Model {filename} loaded successfully.")
        return model, tformer_cfg
    except Exception as e:
        print(f"Error loading model {filename}: {e}")
        raise

# Load both models once at startup.
try:
    # The baseline checkpoint lives in its run subdirectory on the Hub.
    BASELINE_MODEL, BASELINE_CFG = load_model(
        repo_id="SQCU/pgptlformer-tinystories",
        filename="re-pqt-rmsXrmsx2-70b91221-a39c-4824-a69c-48a034963529/state_step040500.pt",
    )
    # The shift-attn checkpoint; `attention_deux` enables the shift-attn mechanism.
    SHIFT_ATTN_MODEL, SHIFT_ATTN_CFG = load_model(
        repo_id="SQCU/pgptlformer-tinystories",
        filename="re-pqt-rmsXrmsx2x2-ATTNII-791967c5-5c59-4a5f-a2c5-07772bcf65ab/state_step040500.pt",
        config_override={"attention_deux": True},
    )
except Exception as e:
    # If loading fails, surface the error in the Gradio app instead of crashing.
    BASELINE_MODEL, SHIFT_ATTN_MODEL = None, None
    ERROR_MESSAGE = f"Failed to load models. Please check logs.\nError: {e}"
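
# NOTE (assumption): this demo is written against an assumed `pgptlformer` interface in
# which forward(idx, return_logits=True) returns a (logits, loss, aux) triple with logits
# shaped (batch, seq, vocab). Adjust the unpacking in the generation loop below if your
# model definition differs.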
Error: {e}" # --- Inference and Metrics --- ENC = tiktoken.get_encoding("gpt2") ENCODE = lambda s: ENC.encode(s, allowed_special={"<|endoftext|>"}) DECODE = lambda l: ENC.decode(l) @torch.no_grad() def generate_and_measure(model, prompt_ids, max_new_tokens=50): """Runs inference and calculates metrics.""" # Reset stats for this run if DEVICE == 'cuda': torch.cuda.reset_peak_memory_stats(DEVICE) torch.cuda.synchronize() start_time = time.time() # --- Generation Loop --- model_logits = [] generated_ids = prompt_ids for _ in range(max_new_tokens): idx_cond = generated_ids if generated_ids.size(1) <= 1024 else generated_ids[:, -1024:] logits, _, _ = model(idx_cond, return_logits=True) final_logits = logits[:, -1, :] model_logits.append(final_logits) # Store logits for perplexity/sharpening calc probs = torch.nn.functional.softmax(final_logits, dim=-1) idx_next = torch.multinomial(probs, num_samples=1) generated_ids = torch.cat((generated_ids, idx_next), dim=1) if DEVICE == 'cuda': torch.cuda.synchronize() end_time = time.time() # --- Metrics Calculation --- # 1. Inference Speed (Tokens/Second) tokens_per_sec = max_new_tokens / (end_time - start_time) # 2. VRAM Usage (MB) vram_usage = torch.cuda.max_memory_allocated(DEVICE) / (1024**2) if DEVICE == 'cuda' else 0 # 3. Pseudo-Perplexity all_logits = torch.cat(model_logits, dim=0) target_ids = generated_ids[0, -max_new_tokens:] cross_entropy = torch.nn.functional.cross_entropy(all_logits, target_ids) pseudo_perplexity = torch.exp(cross_entropy).item() # 4. Logit Sharpening (Average of max probability) avg_max_prob = torch.nn.functional.softmax(all_logits, dim=-1).max(dim=-1).values.mean().item() # --- Decode and Return --- output_text = DECODE(generated_ids[0].tolist()) metrics = { 'Tokens/Sec': tokens_per_sec, 'VRAM (MB)': vram_usage, 'Perplexity': pseudo_perplexity, 'Logit Sharpening': avg_max_prob, } return output_text, metrics # --- Visualization --- def plot_radar_chart(baseline_metrics, shift_attn_metrics): """Creates a radar chart comparing the two models.""" labels = list(baseline_metrics.keys()) baseline_stats = list(baseline_metrics.values()) shift_attn_stats = list(shift_attn_metrics.values()) # Normalize stats for plotting. Higher is better for all metrics on the chart. # We will take the inverse of Perplexity and VRAM for a "higher is better" visualization. 
    baseline_plot_stats = [
        baseline_stats[0],               # Tokens/Sec (higher is better)
        1 / (baseline_stats[1] + 1e-6),  # VRAM (inverted)
        1 / (baseline_stats[2] + 1e-6),  # Perplexity (inverted)
        baseline_stats[3],               # Sharpening (higher is better)
    ]
    shift_attn_plot_stats = [
        shift_attn_stats[0],
        1 / (shift_attn_stats[1] + 1e-6),
        1 / (shift_attn_stats[2] + 1e-6),
        shift_attn_stats[3],
    ]
    angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()

    # Close the polygon so the plot wraps around.
    baseline_plot_stats += baseline_plot_stats[:1]
    shift_attn_plot_stats += shift_attn_plot_stats[:1]
    angles += angles[:1]

    fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))

    # Helper to find a comfortable radial limit.
    def get_max_val(*args):
        return max(max(lst) for lst in args if lst) * 1.2

    ax.set_ylim(0, get_max_val(baseline_plot_stats, shift_attn_plot_stats))

    # Axis labels.
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels([
        "Tokens/Sec\n(Higher is Better)",
        "1 / VRAM\n(Higher is Better)",
        "1 / Perplexity\n(Higher is Better)",
        "Logit Sharpening\n(Higher is Better)",
    ])

    # Plot data.
    ax.plot(angles, baseline_plot_stats, 'o-', linewidth=2, label="Baseline")
    ax.fill(angles, baseline_plot_stats, alpha=0.25)
    ax.plot(angles, shift_attn_plot_stats, 'o-', linewidth=2, label="Shift-Attn")
    ax.fill(angles, shift_attn_plot_stats, alpha=0.25)

    ax.set_title("Model Performance Comparison", size=20, color='gray', y=1.1)
    ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
    plt.tight_layout()
    return fig

# --- Gradio Interface ---
def run_comparison(prompt, max_new_tokens):
    if not BASELINE_MODEL or not SHIFT_ATTN_MODEL:
        raise gr.Error(ERROR_MESSAGE)
    max_new_tokens = int(max_new_tokens)  # Gradio sliders emit floats; range() needs an int.
    input_ids = ENCODE(prompt)
    x = torch.tensor(input_ids, dtype=torch.long, device=DEVICE)[None, ...]

    # Run both models on the same prompt.
    baseline_text, baseline_metrics = generate_and_measure(BASELINE_MODEL, x, max_new_tokens)
    shift_attn_text, shift_attn_metrics = generate_and_measure(SHIFT_ATTN_MODEL, x, max_new_tokens)

    # Build the comparison chart.
    chart = plot_radar_chart(baseline_metrics, shift_attn_metrics)
    return baseline_text, shift_attn_text, chart

with gr.Blocks(theme=gr.themes.Base()) as demo:
    gr.Markdown("# `shift-attn`: A Live Demonstration")
    gr.Markdown(
        "This demo compares a baseline `pgptlformer` model against an identical model enhanced "
        "with the `shift-attn` mechanism (`attention_deux`). The radar chart visualizes key "
        "performance and efficiency metrics, where a larger area indicates a better overall model."
    )
    with gr.Row():
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(label="Enter your prompt:", value="The quick brown fox")
            token_slider = gr.Slider(minimum=10, maximum=200, value=50, step=1, label="Max New Tokens")
            submit_btn = gr.Button("Compare Models", variant="primary")
        with gr.Column(scale=2):
            plot_output = gr.Plot(label="Performance Radar Chart")
    with gr.Row():
        baseline_output = gr.Textbox(label="Baseline Model Output", lines=8)
        shift_attn_output = gr.Textbox(label="Shift-Attn Model Output", lines=8)

    submit_btn.click(
        fn=run_comparison,
        inputs=[prompt_input, token_slider],
        outputs=[baseline_output, shift_attn_output, plot_output],
    )

if __name__ == "__main__":
    demo.launch()
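
# Usage (a sketch; the filename `app.py` is an assumption):
#   pip install gradio torch tiktoken matplotlib numpy huggingface_hub
#   python app.py
# `pgptlformer.py` (the model definition imported above) must be on the Python path.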