# Hugging Face Space app (the Space's status page showed "Runtime error" when this file was captured).
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# "gpt2" is the smallest checkpoint and fits in the free tier;
# "gpt2-medium" also works when more memory is available.
MODEL_NAME = "gpt2"

tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)

# GPT-2 ships without a pad token; reuse EOS so generate() can pad.
tokenizer.pad_token = tokenizer.eos_token
def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9, top_k=50):
    """Generate a continuation of *prompt* with GPT-2.

    Args:
        prompt: Seed text to continue.
        max_length: Maximum number of new tokens to generate (the total
            sequence is additionally capped at 512 tokens).
        temperature: Softmax temperature; higher values sample more randomly.
        top_p: Nucleus-sampling cumulative-probability cutoff.
        top_k: Keep only the k most likely tokens at each sampling step.

    Returns:
        The newly generated text (prompt excluded), or an error-message
        string if generation fails.
    """
    try:
        # Encode the prompt, truncating to the model's usable window.
        inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True)
        prompt_len = inputs.shape[1]  # prompt length in *tokens*

        # Inference only — skip autograd bookkeeping.
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_length=min(max_length + prompt_len, 512),  # cap total length
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                num_return_sequences=1,
            )

        # Decode only the tokens generated *after* the prompt.  The previous
        # approach sliced the decoded string by len(prompt) characters, which
        # breaks whenever the tokenizer does not round-trip the prompt
        # byte-for-byte (BPE whitespace handling, truncation at 512 tokens),
        # cutting output mid-word or leaking prompt fragments.
        new_tokens = outputs[0][prompt_len:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    except Exception as e:
        # Surface the failure in the UI instead of crashing the Space.
        return f"Error generating text: {str(e)}"
# ---- Gradio UI -------------------------------------------------------------
with gr.Blocks(title="GPT-2 Text Generator") as demo:
    gr.Markdown("# GPT-2 Text Generation Server")
    gr.Markdown("Enter a prompt and generate text using GPT-2. Free tier optimized!")

    with gr.Row():
        # Left column: prompt entry plus sampling controls.
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your text prompt here...",
                lines=3,
            )
            with gr.Row():
                max_length = gr.Slider(minimum=10, maximum=200, value=100, step=10, label="Max Length")
                temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
            with gr.Row():
                top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p")
                top_k = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k")
            generate_btn = gr.Button("Generate Text", variant="primary")

        # Right column: model output.
        with gr.Column():
            output_text = gr.Textbox(
                label="Generated Text",
                lines=10,
                placeholder="Generated text will appear here...",
            )

    # Clickable starter prompts.
    gr.Examples(
        examples=[
            ["Once upon a time in a distant galaxy,"],
            ["The future of artificial intelligence is"],
            ["In the heart of the ancient forest,"],
            ["The detective walked into the room and noticed"],
        ],
        inputs=prompt_input,
    )

    # Wire the button to the generator; api_name exposes a stable
    # /predict endpoint for programmatic (gradio_client / REST) callers.
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_input, max_length, temperature, top_p, top_k],
        outputs=output_text,
        api_name="/predict",
    )

# Start the server only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()