# GPT-Neo-TextGen / app.py
# Author: Sachin21112004
# Commit ceb3792 (verified): Create app.py
import gradio as gr
from transformers import GPTNeoForCausalLM, AutoTokenizer
import torch
def _load_gpt_neo(checkpoint):
    """Load the tokenizer and causal-LM weights for *checkpoint*.

    The model is moved to CUDA when ``torch.cuda.is_available()`` reports a
    GPU; otherwise it stays on the CPU.

    Returns:
        A ``(tokenizer, model, device)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(checkpoint)
    lm = GPTNeoForCausalLM.from_pretrained(checkpoint)
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    lm.to(target)
    return tok, lm, target


# Load the GPT-Neo checkpoint once, at import time. Any failure is reported
# and then re-raised, because the app cannot serve requests without a model.
try:
    model_name = "EleutherAI/gpt-neo-1.3B"
    tokenizer, model, device = _load_gpt_neo(model_name)
except Exception as e:
    print(f"Error loading model: {e}")
    raise
def _is_out_of_memory(err):
    """Return True if *err* looks like a PyTorch out-of-memory failure.

    ``torch.cuda.OutOfMemoryError`` (a ``RuntimeError`` subclass) only exists
    in newer PyTorch releases, and CPU/MPS allocation failures surface as plain
    ``RuntimeError``s, so the message text is checked as a fallback.
    """
    oom_type = getattr(torch.cuda, "OutOfMemoryError", None)
    if oom_type is not None and isinstance(err, oom_type):
        return True
    message = str(err)
    return "out of memory" in message or "can't allocate memory" in message


def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9):
    """Generate a continuation of *prompt* with the GPT-Neo model.

    Failures are returned as strings rather than raised, so the Gradio output
    box always shows a readable message.

    Args:
        prompt: Text to continue. ``None``, empty, or whitespace-only input is
            rejected.
        max_length: Maximum number of *new* tokens to generate. Prompt tokens
            are not counted, so long prompts still produce a continuation.
            Gradio sliders may deliver a float, so the value is cast to ``int``.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cumulative-probability threshold.

    Returns:
        The decoded prompt followed by the sampled continuation, or an
        ``"Error: ..."``, ``"Memory Error: ..."`` or
        ``"Error generating text: ..."`` message.
    """
    try:
        if not prompt or not prompt.strip():
            return "Error: Please enter a prompt."
        # Tokenize to input_ids *and* attention_mask. GPT-Neo has no pad token
        # (pad_token_id is set to eos below), so generate() cannot infer the
        # mask on its own.
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            output = model.generate(
                **inputs,
                # max_length would include the prompt tokens, leaving no room
                # for new text whenever the prompt is at least that long.
                max_new_tokens=int(max_length),
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )
        return tokenizer.decode(output[0], skip_special_tokens=True)
    except RuntimeError as e:
        if _is_out_of_memory(e):
            # Release cached allocator blocks so the next request can fit.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return f"Memory Error: {e}. Try reducing max_length."
        # Other RuntimeErrors (shape/device mismatches, ...) are not memory
        # problems; don't tell the user to shrink max_length for them.
        return f"Error generating text: {e}"
    except Exception as e:
        return f"Error generating text: {e}"
# Create Gradio interface.
# Layout: one row holding two columns -- the prompt and sampling controls on
# the left, the generated text on the right.
with gr.Blocks(title="GPT-Neo Text Generation") as demo:
    gr.Markdown("# GPT-Neo 1.3B Text Generation")
    gr.Markdown("Generate creative text using the EleutherAI GPT-Neo 1.3B model")
    with gr.Row():
        # Left column: user inputs.
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="Start typing your prompt...",
                lines=3
            )
            with gr.Row():
                # Passed to generate_text as its max_length argument; the
                # initial value (100) equals that parameter's default.
                max_length_slider = gr.Slider(
                    minimum=10,
                    maximum=200,
                    value=100,
                    step=10,
                    label="Max Length"
                )
            with gr.Row():
                # Sampling controls; initial values (0.7, 0.9) equal
                # generate_text's temperature/top_p defaults.
                temperature_slider = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
                top_p_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top P"
                )
            generate_button = gr.Button("Generate Text", variant="primary")
        # Right column: the output box is not user-editable (interactive=False).
        with gr.Column():
            output_text = gr.Textbox(
                label="Generated Text",
                lines=10,
                interactive=False
            )
    # Connect button click to generation function.
    # The order of `inputs` must match generate_text's positional parameters
    # (prompt, max_length, temperature, top_p).
    generate_button.click(
        fn=generate_text,
        inputs=[prompt_input, max_length_slider, temperature_slider, top_p_slider],
        outputs=output_text
    )
    # Add examples. No fn is attached, so (per gr.Examples' documented
    # behaviour) clicking one only fills the prompt box; generation still
    # runs via the button.
    gr.Examples(
        examples=[
            ["Once upon a time"],
            ["The future of AI is"],
            ["In a galaxy far away"],
            ["Machine learning is"],
        ],
        inputs=prompt_input,
        label="Example Prompts"
    )
def _main():
    """Serve the ``demo`` Gradio app with its default launch settings."""
    demo.launch()


if __name__ == "__main__":
    _main()