import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Model registry: display name -> Hugging Face Hub repo id.
models = {
    "GPT From Scratch (benchaffe/shakespeare-gpt-mini)": "benchaffe/shakespeare-gpt-mini",
    "Fine-tuned DistilGPT2 (benchaffe/shakespeare-distilgpt2)": "benchaffe/shakespeare-distilgpt2",
}

model_objects = {}
tokenizer_objects = {}
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load every model/tokenizer pair once at startup so each request is fast.
for name, path in models.items():
    tokenizer = AutoTokenizer.from_pretrained(path)
    model = AutoModelForCausalLM.from_pretrained(path).to(device)
    model.eval()  # inference mode: disables dropout etc.
    model_objects[name] = model
    tokenizer_objects[name] = tokenizer


def generate_comparison(prompt, max_length, temperature, top_p, top_k, do_sample):
    """Generate a continuation of *prompt* with every registered model.

    Args:
        prompt: Input text to continue.
        max_length: Total sequence length (prompt + continuation) in tokens.
        temperature: Softmax temperature (only effective when sampling).
        top_p: Nucleus-sampling cumulative-probability cutoff.
        top_k: Top-k sampling cutoff (0 disables the filter).
        do_sample: If False, greedy decoding is used instead of sampling.

    Returns:
        A list of generated strings, one per model, in ``models`` iteration
        order — the same order as the output textboxes in the UI.
    """
    results = {}
    for name in models:
        tokenizer = tokenizer_objects[name]
        model = model_objects[name]
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        # inference_mode: skip autograd bookkeeping (less memory, faster).
        with torch.inference_mode():
            output = model.generate(
                **inputs,
                max_length=int(max_length),  # Gradio sliders may deliver floats
                temperature=float(temperature),
                top_p=float(top_p),
                top_k=int(top_k),  # generate() requires an int here
                do_sample=bool(do_sample),
                pad_token_id=tokenizer.eos_token_id,  # GPT-2 family has no pad token
            )
        results[name] = tokenizer.decode(output[0], skip_special_tokens=True)
    return [results[name] for name in models]


# Prompts offered in the example dropdown.
example_prompts = [
    "What light through yonder window breaks",
    "To be or not to be, that is the question",
    "Hark! Who goes there?",
    "My love is deep; the more I give to thee",
    "Thou art more lovely and more temperate",
]

with gr.Blocks(title="Shakespeare Model Comparison") as demo:
    gr.Markdown("# 🧠 Shakespeare GPT Model Comparison")
    # NOTE: the original split this literal across a raw newline, which is a
    # SyntaxError in a plain Python string; implicit concatenation fixes it.
    gr.Markdown(
        "Compare outputs from two models: one trained from scratch and one "
        "fine-tuned from DistilGPT2. Adjust generation parameters below."
    )

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Input Prompt", placeholder="Enter a Shakespearean prompt here...")
            example = gr.Dropdown(example_prompts, label="Select an Example Prompt")
            # Copy the chosen example into the prompt textbox.
            example.change(fn=lambda e: gr.update(value=e), inputs=example, outputs=prompt)

            max_length = gr.Slider(32, 256, value=80, step=8, label="Max Length")
            temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
            top_k = gr.Slider(0, 100, value=50, step=5, label="Top-k")
            do_sample = gr.Checkbox(value=True, label="Use Sampling (if unchecked = greedy decoding)")
            submit_btn = gr.Button("Generate")

        with gr.Column():
            # One output textbox per model, in registry order (matches the
            # list returned by generate_comparison).
            outputs = [gr.Textbox(label=name) for name in models]

    submit_btn.click(
        fn=generate_comparison,
        inputs=[prompt, max_length, temperature, top_p, top_k, do_sample],
        outputs=outputs,
    )

# Guard the launch so the module can be imported (e.g. by tests) without
# starting a server; `python app.py` / HF Spaces still run it as __main__.
if __name__ == "__main__":
    demo.launch()