# Hugging Face Spaces demo: GPT-Neo 1.3B text generation with Gradio.
import gradio as gr
import torch
from transformers import GPTNeoForCausalLM, AutoTokenizer

try:
    # Load the GPT-Neo model and tokenizer once at module import so every
    # request reuses them. NOTE: first run downloads ~5 GB of weights.
    model_name = "EleutherAI/gpt-neo-1.3B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = GPTNeoForCausalLM.from_pretrained(model_name)

    # Prefer GPU when available; fall back to CPU otherwise.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
except Exception as e:
    # Surface the failure in the Space logs, then abort startup — the app
    # is useless without the model.
    print(f"Error loading model: {e}")
    raise
def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9):
    """
    Generate a continuation of *prompt* with the module-level GPT-Neo model.

    Args:
        prompt: Seed text to continue; must be non-empty/non-whitespace.
        max_length: Total token budget (prompt tokens + generated tokens).
        temperature: Softmax temperature for sampling (> 0).
        top_p: Nucleus-sampling probability mass in (0, 1].

    Returns:
        The decoded text (prompt included), or a human-readable error
        string — this function never raises, so the Gradio UI always
        receives something displayable.
    """
    try:
        # Guard clause: reject empty / whitespace-only prompts.
        if not prompt or not prompt.strip():
            return "Error: Please enter a prompt."

        # Tokenize and also build the attention mask; generate() warns
        # about (and may mishandle) inputs when pad_token == eos_token
        # and no mask is supplied.
        encoded = tokenizer(prompt, return_tensors="pt").to(device)

        # Sample a continuation; no gradients needed for inference.
        with torch.no_grad():
            output = model.generate(
                encoded["input_ids"],
                attention_mask=encoded["attention_mask"],
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        return tokenizer.decode(output[0], skip_special_tokens=True)
    except RuntimeError as e:
        # CUDA OOM surfaces as RuntimeError, but so do other runtime
        # failures — only label genuine OOMs as memory errors.
        if "out of memory" in str(e).lower():
            return f"Memory Error: {str(e)}. Try reducing max_length."
        return f"Error generating text: {str(e)}"
    except Exception as e:
        return f"Error generating text: {str(e)}"
# Build the Gradio UI: prompt and sampling controls on the left,
# generated text on the right, example prompts underneath.
with gr.Blocks(title="GPT-Neo Text Generation") as demo:
    gr.Markdown("# GPT-Neo 1.3B Text Generation")
    gr.Markdown("Generate creative text using the EleutherAI GPT-Neo 1.3B model")

    with gr.Row():
        with gr.Column():
            prompt_box = gr.Textbox(
                label="Enter your prompt",
                placeholder="Start typing your prompt...",
                lines=3,
            )
            with gr.Row():
                length_ctl = gr.Slider(
                    minimum=10,
                    maximum=200,
                    value=100,
                    step=10,
                    label="Max Length",
                )
            with gr.Row():
                temp_ctl = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                )
                top_p_ctl = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top P",
                )
            run_btn = gr.Button("Generate Text", variant="primary")
        with gr.Column():
            result_box = gr.Textbox(
                label="Generated Text",
                lines=10,
                interactive=False,
            )

    # Wire the button to the generation function.
    run_btn.click(
        fn=generate_text,
        inputs=[prompt_box, length_ctl, temp_ctl, top_p_ctl],
        outputs=result_box,
    )

    # Clickable starter prompts that fill the prompt box.
    gr.Examples(
        examples=[
            ["Once upon a time"],
            ["The future of AI is"],
            ["In a galaxy far away"],
            ["Machine learning is"],
        ],
        inputs=prompt_box,
        label="Example Prompts",
    )
# Script entry point: start the Gradio server (used by Spaces and local runs).
if __name__ == "__main__":
    demo.launch()