"""Gradio demo comparing small GPT-2 class foundation models (350M-410M).

Intended for Hugging Face Spaces with ZeroGPU: the selected model is loaded
lazily on first use and cached in module globals, so repeated generations
with the same model skip the slow ``from_pretrained`` call.
"""

import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Model configurations: Hugging Face repo id -> display metadata for the UI.
MODELS = {
    "facebook/opt-350m": {
        "name": "Facebook OPT-350M",
        "description": "350M parameter foundation model",
    },
    "EleutherAI/pythia-410m": {
        "name": "EleutherAI Pythia-410M",
        "description": "410M parameter foundation model",
    },
    "openai-community/gpt2-medium": {
        "name": "GPT-2 Medium",
        "description": "355M parameter original GPT-2",
    },
}

# Global cache for the currently loaded model/tokenizer pair.  Only one model
# is kept in memory at a time; selecting a different model replaces it.
current_model = None
current_tokenizer = None
current_model_name = None


@spaces.GPU
def generate_text(model_name, prompt, max_tokens, temperature, top_p):
    """Generate text with the selected model on a ZeroGPU worker.

    Args:
        model_name: Hugging Face repo id (a key of ``MODELS``).
        prompt: Input text to continue.
        max_tokens: Maximum number of new tokens to sample.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The decoded text (prompt plus continuation), or an ``"Error: ..."``
        string if loading or generation fails — errors are surfaced in the
        output textbox rather than raised.
    """
    global current_model, current_tokenizer, current_model_name

    # Robustness: don't tokenize/generate from an empty prompt.
    if not prompt or not prompt.strip():
        return "Error: please enter a prompt."

    try:
        # Lazy-load: only call from_pretrained when nothing is cached yet or
        # the user switched to a different model.
        if current_model is None or current_model_name != model_name:
            current_tokenizer = AutoTokenizer.from_pretrained(model_name)
            current_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
            )
            current_model_name = model_name

        inputs = current_tokenizer(prompt, return_tensors="pt").to(
            current_model.device
        )

        with torch.no_grad():
            outputs = current_model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                # GPT-2 style tokenizers have no pad token; reuse EOS.
                pad_token_id=current_tokenizer.eos_token_id,
            )

        return current_tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:  # show any load/generate failure in the UI
        return f"Error: {str(e)}"


def update_model_info(model_name):
    """Return the Markdown blurb (name + description) for *model_name*."""
    info = MODELS[model_name]
    return f"**{info['name']}**\n\n{info['description']}"


# --------------------------------------------------------------- Gradio UI
with gr.Blocks(title="GPT-2 Class Models Comparison", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# GPT-2 Class Models Comparison (350M-410M)")
    gr.Markdown("Test and compare small foundation models trained on general text")

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(MODELS.keys()),
            value="facebook/opt-350m",
            label="Select Model",
            info="Choose which 350M-class model to test",
        )

    with gr.Row():
        with gr.Column():
            # BUG FIX: the original created an empty Markdown and assigned
            # ``model_info.value`` after construction, which does not render
            # in the UI.  Initialize the component with the default model's
            # blurb instead.
            model_info = gr.Markdown(update_model_info("facebook/opt-350m"))

    # Refresh the blurb whenever a different model is selected.
    model_dropdown.change(
        fn=update_model_info,
        inputs=[model_dropdown],
        outputs=[model_info],
    )

    prompt_box = gr.Textbox(
        label="Prompt",
        placeholder="Type your prompt here...",
        lines=5,
    )

    with gr.Row():
        max_tokens = gr.Slider(
            minimum=50,
            maximum=500,
            value=200,
            step=10,
            label="Max New Tokens",
        )
        temperature = gr.Slider(
            minimum=0.1,
            maximum=2.0,
            value=0.7,
            step=0.1,
            label="Temperature",
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.9,
            step=0.05,
            label="Top P",
        )

    generate_btn = gr.Button("Generate", variant="primary", size="lg")

    output_box = gr.Textbox(
        label="Generated Text",
        lines=10,
        interactive=False,
    )

    generate_btn.click(
        fn=generate_text,
        inputs=[model_dropdown, prompt_box, max_tokens, temperature, top_p],
        outputs=[output_box],
    )

    gr.Markdown("---")
    gr.Markdown("### Example Prompts:")
    gr.Markdown("""
    - General: "Once upon a time in a distant galaxy"
    - Technical: "The main advantages of Python programming are"
    - Medical: "Question: What are the indications for aspirin in cardiac patients?\nAnswer:"
    """)


# Guard the launch so importing this module (e.g. in tests) has no side effect.
if __name__ == "__main__":
    demo.launch()