# Hugging Face Spaces demo (page-status header removed; application code follows).
| import gradio as gr | |
| import torch | |
| import time | |
| import os | |
| from huggingface_hub import hf_hub_download | |
| import tiktoken | |
| import pgptlformer # Your model definition file | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from contextlib import nullcontext | |
# --- Configuration ---
# Select the compute device, matching autocast dtype, and AMP context up front.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Prefer bfloat16 on GPUs that support it; otherwise fall back to float16.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    DTYPE = 'bfloat16'
else:
    DTYPE = 'float16'

_DTYPE_MAP = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}
PTDTYPE = _DTYPE_MAP[DTYPE]

# CPU runs in full precision (no autocast); GPU runs under mixed precision.
if DEVICE == 'cpu':
    CTX = nullcontext()
else:
    CTX = torch.amp.autocast(device_type=DEVICE, dtype=PTDTYPE)

TORCH_COMPILE = False  # Gradio instances can be slow, so compilation might timeout. Set to False for stability.
# --- Model Loading ---
def load_model(repo_id, filename, config_override=None):
    """Download a checkpoint from the Hugging Face Hub and build the model.

    Args:
        repo_id: Hub repository id, e.g. "SQCU/pgptlformer-tinystories".
        filename: Path of the checkpoint file within that repository.
        config_override: Optional dict merged over the checkpoint's
            'model_args' before the model is constructed.

    Returns:
        (model, tformer_cfg): the model in eval mode on DEVICE, and the
        (possibly overridden) config dict it was built from.

    Raises:
        Re-raises any download / deserialization / construction error
        after printing it, so the caller can decide how to degrade.
    """
    # FIX: the progress/error messages contained a garbled "(unknown)"
    # placeholder instead of interpolating the actual checkpoint filename.
    print(f"Loading model: {repo_id}/{filename}...")
    try:
        ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
        checkpoint = torch.load(ckpt_path, map_location=DEVICE)
        tformer_cfg = checkpoint['model_args']
        if config_override:
            tformer_cfg.update(config_override)
        model = pgptlformer.PGPT_Lformer(tformer_cfg)
        state_dict = checkpoint['model']
        # Checkpoints saved from a torch.compile()'d model prefix every key
        # with '_orig_mod.'; strip it so keys match a freshly-built module.
        unwanted_prefix = '_orig_mod.'
        for k in list(state_dict.keys()):  # values were unused in the loop
            if k.startswith(unwanted_prefix):
                state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
        model.load_state_dict(state_dict, strict=False)  # Use strict=False for flexibility
        model.eval()
        model.to(DEVICE)
        if TORCH_COMPILE:
            model = torch.compile(model)
        print(f"Model {filename} loaded successfully.")
        return model, tformer_cfg
    except Exception as e:
        print(f"Error loading model {filename}: {e}")
        raise
# Load both models once at the start.
# Pre-initialize every global this block can produce so that no name is left
# unbound: previously ERROR_MESSAGE only existed on the failure path, and
# BASELINE_CFG / SHIFT_ATTN_CFG only existed on the success path.
ERROR_MESSAGE = None
try:
    # FIX #1: Add the correct subdirectory for the baseline model
    BASELINE_MODEL, BASELINE_CFG = load_model(
        repo_id="SQCU/pgptlformer-tinystories",
        filename="re-pqt-rmsXrmsx2-70b91221-a39c-4824-a69c-48a034963529/state_step040500.pt"
    )
    # FIX #2: The shift-attn model already had the directory, but ensure it's correct
    SHIFT_ATTN_MODEL, SHIFT_ATTN_CFG = load_model(
        repo_id="SQCU/pgptlformer-tinystories",
        filename="re-pqt-rmsXrmsx2x2-ATTNII-791967c5-5c59-4a5f-a2c5-07772bcf65ab/state_step040500.pt",
        config_override={"attention_deux": True}
    )
except Exception as e:
    # If loading fails, show an error in the Gradio app instead of crashing
    BASELINE_MODEL, SHIFT_ATTN_MODEL = None, None
    BASELINE_CFG, SHIFT_ATTN_CFG = None, None
    ERROR_MESSAGE = f"Failed to load models. Please check logs. Error: {e}"
# --- Inference and Metrics ---
# GPT-2 BPE tokenizer; <|endoftext|> is whitelisted so prompts may contain it.
ENC = tiktoken.get_encoding("gpt2")


# FIX (idiom, PEP 8 E731): named lambdas replaced with def functions; the
# public names and behavior are unchanged.
def ENCODE(s):
    """Encode a string to GPT-2 token ids (allows the <|endoftext|> special)."""
    return ENC.encode(s, allowed_special={"<|endoftext|>"})


def DECODE(l):
    """Decode a list of GPT-2 token ids back to a string."""
    return ENC.decode(l)
def generate_and_measure(model, prompt_ids, max_new_tokens=50, max_context=1024):
    """Autoregressively sample from `model` and collect performance metrics.

    Args:
        model: Callable as model(idx, return_logits=True) -> (logits, _, _)
            with logits shaped (batch, seq, vocab) — assumed from the call
            pattern below; confirm against pgptlformer.PGPT_Lformer.
        prompt_ids: LongTensor of shape (1, prompt_len) on DEVICE.
        max_new_tokens: Number of tokens to sample.
        max_context: Context-window cap; input is cropped to the last
            `max_context` tokens each step (previously hard-coded to 1024).

    Returns:
        (output_text, metrics) where metrics maps 'Tokens/Sec', 'VRAM (MB)',
        'Perplexity', and 'Logit Sharpening' to floats.
    """
    # Reset stats for this run
    if DEVICE == 'cuda':
        torch.cuda.reset_peak_memory_stats(DEVICE)
        torch.cuda.synchronize()
    start_time = time.time()
    # --- Generation Loop ---
    # FIX: run under torch.no_grad(); the original tracked gradients across
    # the entire loop, wasting memory and inflating the VRAM metric.
    model_logits = []
    generated_ids = prompt_ids
    with torch.no_grad():
        for _ in range(max_new_tokens):
            idx_cond = generated_ids if generated_ids.size(1) <= max_context else generated_ids[:, -max_context:]
            logits, _, _ = model(idx_cond, return_logits=True)
            final_logits = logits[:, -1, :]
            model_logits.append(final_logits)  # Store logits for perplexity/sharpening calc
            probs = torch.nn.functional.softmax(final_logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            generated_ids = torch.cat((generated_ids, idx_next), dim=1)
    if DEVICE == 'cuda':
        torch.cuda.synchronize()
    end_time = time.time()
    # --- Metrics Calculation ---
    # 1. Inference Speed (Tokens/Second); guard against a zero-length interval.
    tokens_per_sec = max_new_tokens / max(end_time - start_time, 1e-9)
    # 2. VRAM Usage (MB)
    vram_usage = torch.cuda.max_memory_allocated(DEVICE) / (1024**2) if DEVICE == 'cuda' else 0
    # 3. Pseudo-perplexity: exp of the cross-entropy between each step's
    # logits and the token actually sampled at that step (a self-consistency
    # score, not true held-out perplexity).
    all_logits = torch.cat(model_logits, dim=0)
    target_ids = generated_ids[0, -max_new_tokens:]
    cross_entropy = torch.nn.functional.cross_entropy(all_logits, target_ids)
    pseudo_perplexity = torch.exp(cross_entropy).item()
    # 4. Logit Sharpening (Average of max probability)
    avg_max_prob = torch.nn.functional.softmax(all_logits, dim=-1).max(dim=-1).values.mean().item()
    # --- Decode and Return ---
    output_text = DECODE(generated_ids[0].tolist())
    metrics = {
        'Tokens/Sec': tokens_per_sec,
        'VRAM (MB)': vram_usage,
        'Perplexity': pseudo_perplexity,
        'Logit Sharpening': avg_max_prob,
    }
    return output_text, metrics
# --- Visualization ---
def plot_radar_chart(baseline_metrics, shift_attn_metrics):
    """Creates a radar chart comparing the two models."""
    labels = list(baseline_metrics.keys())

    # Transform raw metrics so every axis reads "larger is better":
    # Tokens/Sec and Sharpening pass through; VRAM and Perplexity are inverted.
    def to_axes(values):
        speed, vram, ppl, sharp = values
        return [speed, 1 / (vram + 1e-6), 1 / (ppl + 1e-6), sharp]

    base_pts = to_axes(list(baseline_metrics.values()))
    shift_pts = to_axes(list(shift_attn_metrics.values()))

    angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()

    # Close each polygon by repeating its first vertex.
    base_pts = base_pts + base_pts[:1]
    shift_pts = shift_pts + shift_pts[:1]
    angles = angles + angles[:1]

    fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))

    # Radial limit: 20% headroom above the largest plotted value.
    ax.set_ylim(0, max(max(base_pts), max(shift_pts)) * 1.2)

    # Axis labels
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(["Tokens/Sec\n(Higher is Better)", "1 / VRAM\n(Higher is Better)", "1 / Perplexity\n(Higher is Better)", "Logit Sharpening\n(Higher is Better)"])

    # Draw both series in the same order as before (colors stay stable).
    for pts, series_name in ((base_pts, "Baseline"), (shift_pts, "Shift-Attn")):
        ax.plot(angles, pts, 'o-', linewidth=2, label=series_name)
        ax.fill(angles, pts, alpha=0.25)

    ax.set_title("Model Performance Comparison", size=20, color='gray', y=1.1)
    ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
    plt.tight_layout()
    return fig
# --- Gradio Interface ---
def run_comparison(prompt, max_new_tokens):
    """Gradio callback: run both models on `prompt` and compare them.

    Args:
        prompt: User-entered prompt text.
        max_new_tokens: Token budget from the UI slider.

    Returns:
        (baseline_text, shift_attn_text, radar_chart_figure).

    Raises:
        gr.Error: if model loading failed at startup.
    """
    # FIX: compare against None explicitly instead of relying on the model
    # objects' truthiness.
    if BASELINE_MODEL is None or SHIFT_ATTN_MODEL is None:
        raise gr.Error(ERROR_MESSAGE)
    # Slider values can arrive as floats; the sampling loop needs an int.
    max_new_tokens = int(max_new_tokens)
    input_ids = ENCODE(prompt)
    x = torch.tensor(input_ids, dtype=torch.long, device=DEVICE)[None, ...]
    # Run both models on the identical prompt tensor
    baseline_text, baseline_metrics = generate_and_measure(BASELINE_MODEL, x, max_new_tokens)
    shift_attn_text, shift_attn_metrics = generate_and_measure(SHIFT_ATTN_MODEL, x, max_new_tokens)
    # Create plot
    chart = plot_radar_chart(baseline_metrics, shift_attn_metrics)
    return baseline_text, shift_attn_text, chart
with gr.Blocks(theme=gr.themes.Base()) as demo:
    gr.Markdown("# `shift-attn`: A Live Demonstration")
    gr.Markdown(
        "This demo compares a baseline `pgptlformer` model against an identical model enhanced with the `shift-attn` mechanism (`attention_deux`). "
        "The radar chart visualizes key performance and efficiency metrics, where a larger area indicates a better overall model."
    )
    # Left column: user controls. Right column: the radar chart.
    with gr.Row():
        with gr.Column(scale=1):
            prompt_box = gr.Textbox(label="Enter your prompt:", value="The quick brown fox")
            tokens_slider = gr.Slider(minimum=10, maximum=200, value=50, step=1, label="Max New Tokens")
            compare_button = gr.Button("Compare Models", variant="primary")
        with gr.Column(scale=2):
            radar_plot = gr.Plot(label="Performance Radar Chart")
    # Side-by-side generated text from the two models.
    with gr.Row():
        baseline_box = gr.Textbox(label="Baseline Model Output", lines=8)
        shift_attn_box = gr.Textbox(label="Shift-Attn Model Output", lines=8)
    # Wire the button to the comparison callback.
    compare_button.click(
        fn=run_comparison,
        inputs=[prompt_box, tokens_slider],
        outputs=[baseline_box, shift_attn_box, radar_plot],
    )

if __name__ == "__main__":
    demo.launch()