# gpt / app.py
# Author: Your Name
# Commit 1d9f921: Add @spaces.GPU decorator for ZeroGPU support
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Model configurations: maps HF hub ids to the display metadata shown in the UI.
# Insertion order is meaningful — it drives the dropdown's choice order.
MODELS = {
    hub_id: {"name": display_name, "description": blurb}
    for hub_id, display_name, blurb in [
        ("facebook/opt-350m", "Facebook OPT-350M", "350M parameter foundation model"),
        ("EleutherAI/pythia-410m", "EleutherAI Pythia-410M", "410M parameter foundation model"),
        ("openai-community/gpt2-medium", "GPT-2 Medium", "355M parameter original GPT-2"),
    ]
}
# Module-level cache for the lazily loaded model/tokenizer pair.
# All three stay None until the first generate_text() call; they are swapped
# together whenever a different model is selected in the UI.
current_model = current_tokenizer = current_model_name = None
@spaces.GPU
def generate_text(model_name, prompt, max_tokens, temperature, top_p):
    """Generate a completion for *prompt* using the selected model on ZeroGPU.

    The model/tokenizer pair is cached in module globals so repeated calls
    with the same model skip the slow ``from_pretrained`` load.

    Args:
        model_name: HF hub id of the causal LM to load (a key of ``MODELS``).
        prompt: Input text to continue.
        max_tokens: Maximum number of newly generated tokens.
        temperature: Sampling temperature (UI restricts to > 0).
        top_p: Nucleus-sampling cutoff.

    Returns:
        The decoded generation (prompt included), or an ``"Error: ..."``
        string on failure — the Gradio UI shows either in the output box.
    """
    global current_model, current_tokenizer, current_model_name
    try:
        # (Re)load only when nothing is cached or a different model was picked.
        if current_model is None or current_model_name != model_name:
            # Free the previously cached model first so two fp16 checkpoints
            # never occupy GPU memory at once (avoids OOM when switching).
            if current_model is not None:
                del current_model
                current_model = None
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            current_tokenizer = AutoTokenizer.from_pretrained(model_name)
            current_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
            )
            current_model_name = model_name

        # Tokenize and move input ids (and attention mask) to the model's device.
        inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
        with torch.no_grad():
            outputs = current_model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                # GPT-style models have no pad token; reuse EOS to silence the warning.
                pad_token_id=current_tokenizer.eos_token_id,
            )
        # Decode the full sequence (prompt + continuation).
        response = current_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response
    except Exception as e:
        # Surface the failure to the UI instead of crashing the Space.
        return f"Error: {str(e)}"
# Gradio interface
# Builds the Blocks UI: model picker + live info panel, prompt/parameter
# controls, and the output box wired to the GPU-decorated generate_text().
with gr.Blocks(title="GPT-2 Class Models Comparison", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# GPT-2 Class Models Comparison (350M-410M)")
    gr.Markdown("Test and compare small foundation models trained on general text")
    with gr.Row():
        # Choices come straight from MODELS; the default matches the info
        # panel initialized below.
        model_dropdown = gr.Dropdown(
            choices=list(MODELS.keys()),
            value="facebook/opt-350m",
            label="Select Model",
            info="Choose which 350M-class model to test"
        )
    with gr.Row():
        with gr.Column():
            model_info = gr.Markdown("")

    def update_model_info(model_name):
        """Return the markdown blurb (bold name + description) for *model_name*."""
        info = MODELS[model_name]
        return f"**{info['name']}**\n\n{info['description']}"

    # Refresh the info panel whenever a different model is selected.
    model_dropdown.change(
        fn=update_model_info,
        inputs=[model_dropdown],
        outputs=[model_info]
    )
    # Initialize with default model info
    # NOTE(review): assigning .value after construction only affects the
    # initial render; passing the text to gr.Markdown(...) directly would be
    # equivalent — confirm against the Gradio version in use.
    model_info.value = update_model_info("facebook/opt-350m")
    prompt_box = gr.Textbox(
        label="Prompt",
        placeholder="Type your prompt here...",
        lines=5
    )
    with gr.Row():
        # Generation hyperparameters, forwarded verbatim to generate_text().
        max_tokens = gr.Slider(
            minimum=50,
            maximum=500,
            value=200,
            step=10,
            label="Max New Tokens"
        )
        temperature = gr.Slider(
            minimum=0.1,
            maximum=2.0,
            value=0.7,
            step=0.1,
            label="Temperature"
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.9,
            step=0.05,
            label="Top P"
        )
    generate_btn = gr.Button("Generate", variant="primary", size="lg")
    output_box = gr.Textbox(
        label="Generated Text",
        lines=10,
        interactive=False
    )
    # Wire the button: dropdown + prompt + sliders in, generated text out.
    generate_btn.click(
        fn=generate_text,
        inputs=[model_dropdown, prompt_box, max_tokens, temperature, top_p],
        outputs=[output_box]
    )
    gr.Markdown("---")
    gr.Markdown("### Example Prompts:")
    gr.Markdown("""
- General: "Once upon a time in a distant galaxy"
- Technical: "The main advantages of Python programming are"
- Medical: "Question: What are the indications for aspirin in cardiac patients?\nAnswer:"
""")

demo.launch()