# NOTE: recovered from a Hugging Face Spaces page whose status banner read
# "Runtime error" (likely the missing `@spaces.GPU` decorator, fixed below).
import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Candidate checkpoints exposed in the UI: model id -> display metadata.
MODELS = {
    "facebook/opt-350m": {
        "name": "Facebook OPT-350M",
        "description": "350M parameter foundation model",
    },
    "EleutherAI/pythia-410m": {
        "name": "EleutherAI Pythia-410M",
        "description": "410M parameter foundation model",
    },
    "openai-community/gpt2-medium": {
        "name": "GPT-2 Medium",
        "description": "355M parameter original GPT-2",
    },
}

# Lazily populated cache for the most recently selected model/tokenizer
# (the model is loaded onto the GPU on first use).
current_model = None
current_tokenizer = None
current_model_name = None
# BUG FIX: on ZeroGPU Spaces the GPU-bound function MUST be decorated with
# `spaces.GPU`, otherwise the Space fails at runtime when it touches the GPU.
# The original imported `spaces` but never used it. Guarded so the module
# still imports in environments without the `spaces` package.
try:
    import spaces
    _gpu = spaces.GPU
except ImportError:  # NOTE(review): local fallback when not running on Spaces
    def _gpu(fn):
        return fn


@_gpu
def generate_text(model_name, prompt, max_tokens, temperature, top_p):
    """Generate text using the selected model with ZeroGPU.

    Args:
        model_name: Hugging Face model id (a key of ``MODELS``).
        prompt: Text for the model to continue.
        max_tokens: Maximum number of new tokens to sample.
        temperature: Sampling temperature (higher = more random).
        top_p: Nucleus-sampling probability mass.

    Returns:
        The decoded text (prompt plus continuation), or an ``"Error: ..."``
        string if loading or generation fails.
    """
    global current_model, current_tokenizer, current_model_name
    try:
        # (Re)load only when nothing is cached or the selection changed.
        if current_model is None or current_model_name != model_name:
            current_tokenizer = AutoTokenizer.from_pretrained(model_name)
            current_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
            )
            current_model_name = model_name

        inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
        with torch.no_grad():
            outputs = current_model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                # Small causal LMs often lack a dedicated pad token; reuse EOS.
                pad_token_id=current_tokenizer.eos_token_id,
            )
        return current_tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Surface the failure in the UI textbox instead of crashing the app.
        return f"Error: {str(e)}"
# --- Gradio interface ------------------------------------------------------
with gr.Blocks(title="GPT-2 Class Models Comparison", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# GPT-2 Class Models Comparison (350M-410M)")
    gr.Markdown("Test and compare small foundation models trained on general text")

    def update_model_info(model_name):
        """Return the display name and description of a model as markdown."""
        info = MODELS[model_name]
        return f"**{info['name']}**\n\n{info['description']}"

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(MODELS.keys()),
            value="facebook/opt-350m",
            label="Select Model",
            info="Choose which 350M-class model to test",
        )

    with gr.Row():
        with gr.Column():
            # BUG FIX: the original assigned `model_info.value` after creating
            # the component, which Gradio Blocks never renders — the initial
            # description stayed blank. Pass it at construction time instead.
            model_info = gr.Markdown(update_model_info("facebook/opt-350m"))

    # Refresh the description whenever the selection changes.
    model_dropdown.change(
        fn=update_model_info,
        inputs=[model_dropdown],
        outputs=[model_info],
    )

    prompt_box = gr.Textbox(
        label="Prompt",
        placeholder="Type your prompt here...",
        lines=5,
    )

    with gr.Row():
        max_tokens = gr.Slider(
            minimum=50,
            maximum=500,
            value=200,
            step=10,
            label="Max New Tokens",
        )
        temperature = gr.Slider(
            minimum=0.1,
            maximum=2.0,
            value=0.7,
            step=0.1,
            label="Temperature",
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.9,
            step=0.05,
            label="Top P",
        )

    generate_btn = gr.Button("Generate", variant="primary", size="lg")

    output_box = gr.Textbox(
        label="Generated Text",
        lines=10,
        interactive=False,
    )

    generate_btn.click(
        fn=generate_text,
        inputs=[model_dropdown, prompt_box, max_tokens, temperature, top_p],
        outputs=[output_box],
    )

    gr.Markdown("---")
    gr.Markdown("### Example Prompts:")
    gr.Markdown("""
    - General: "Once upon a time in a distant galaxy"
    - Technical: "The main advantages of Python programming are"
    - Medical: "Question: What are the indications for aspirin in cardiac patients?\nAnswer:"
    """)

demo.launch()