Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import torch | |
# Models to compare: UI display label -> Hugging Face Hub repo id.
models = {
    "GPT From Scratch (benchaffe/shakespeare-gpt-mini)": "benchaffe/shakespeare-gpt-mini",
    "Fine-tuned DistilGPT2 (benchaffe/shakespeare-distilgpt2)": "benchaffe/shakespeare-distilgpt2",
}

# Use the GPU when torch can see one, otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Loaded objects, keyed by the same display labels as `models`.
model_objects = {}
tokenizer_objects = {}
for label, repo_id in models.items():
    tokenizer_objects[label] = AutoTokenizer.from_pretrained(repo_id)
    loaded_model = AutoModelForCausalLM.from_pretrained(repo_id).to(device)
    loaded_model.eval()  # inference only
    model_objects[label] = loaded_model
| # Generation function | |
| def generate_comparison(prompt, max_length, temperature, top_p, top_k, do_sample): | |
| results = {} | |
| for name in models: | |
| tokenizer = tokenizer_objects[name] | |
| model = model_objects[name] | |
| inputs = tokenizer(prompt, return_tensors="pt").to(device) | |
| output = model.generate( | |
| **inputs, | |
| max_length=max_length, | |
| temperature=temperature, | |
| top_p=top_p, | |
| top_k=top_k, | |
| do_sample=do_sample, | |
| pad_token_id=tokenizer.eos_token_id, | |
| ) | |
| generated_text = tokenizer.decode(output[0], skip_special_tokens=True) | |
| results[name] = generated_text | |
| return [results[name] for name in models] | |
# Example prompts offered in the "Select an Example Prompt" dropdown.
example_prompts = [
    "What light through yonder window breaks",
    "To be or not to be, that is the question",
    "Hark! Who goes there?",
    "My love is deep; the more I give to thee",
    "Thou art more lovely and more temperate",
]
with gr.Blocks(title="Shakespeare Model Comparison") as demo:
    gr.Markdown("# 🧠 Shakespeare GPT Model Comparison")
    gr.Markdown("Compare outputs from two models: one trained from scratch and one fine-tuned from DistilGPT2. Adjust generation parameters below.")

    with gr.Row():
        # Left column: prompt entry, example picker and generation settings.
        with gr.Column():
            prompt = gr.Textbox(label="Input Prompt", placeholder="Enter a Shakespearean prompt here...")
            example = gr.Dropdown(example_prompts, label="Select an Example Prompt")

            def _copy_example_to_prompt(selected):
                """Copy the chosen example into the prompt textbox."""
                return gr.update(value=selected)

            example.change(fn=_copy_example_to_prompt, inputs=example, outputs=prompt)

            max_length = gr.Slider(32, 256, value=80, step=8, label="Max Length")
            temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
            top_k = gr.Slider(0, 100, value=50, step=5, label="Top-k")
            do_sample = gr.Checkbox(value=True, label="Use Sampling (if unchecked = greedy decoding)")
            submit_btn = gr.Button("Generate")

        # Right column: one output textbox per model, labelled with its display name.
        with gr.Column():
            outputs = [gr.Textbox(label=model_label) for model_label in models]

    # Argument order must match generate_comparison's parameter order.
    generation_inputs = [prompt, max_length, temperature, top_p, top_k, do_sample]
    submit_btn.click(fn=generate_comparison, inputs=generation_inputs, outputs=outputs)

demo.launch()