File size: 3,573 Bytes
ceb3792
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gradio as gr
from transformers import GPTNeoForCausalLM, AutoTokenizer
import torch

# Module-level model initialization: runs once at import time so the
# Gradio callbacks below can reuse the same model/tokenizer instances.
try:
    # Load the GPT-Neo model and tokenizer
    # (first run downloads the checkpoint from the Hugging Face Hub
    # into the local cache; subsequent runs load from disk).
    model_name = "EleutherAI/gpt-neo-1.3B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = GPTNeoForCausalLM.from_pretrained(model_name)
    
    # Set device: prefer GPU when available, otherwise fall back to CPU.
    # `device` is read later by generate_text() to place input tensors.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
except Exception as e:
    # Loading can fail for many reasons (network, disk, OOM); surface the
    # cause and re-raise so the app does not start in a broken state.
    print(f"Error loading model: {e}")
    raise

def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9):
    """
    Generate a text continuation of ``prompt`` with the GPT-Neo model.

    Parameters
    ----------
    prompt : str
        Seed text; must be non-empty after stripping whitespace.
    max_length : int or float, default 100
        Total token budget (prompt + generated tokens). Gradio sliders
        deliver floats, so the value is cast to ``int`` before use.
    temperature : float, default 0.7
        Sampling temperature; higher values yield more random output.
    top_p : float, default 0.9
        Nucleus-sampling probability mass.

    Returns
    -------
    str
        The decoded generation (prompt included), or a human-readable
        error message — the Gradio UI displays whichever string is
        returned, so this function never raises to the caller.
    """
    try:
        # Reject empty or whitespace-only prompts up front.
        if not prompt or not prompt.strip():
            return "Error: Please enter a prompt."

        # Gradio Slider values arrive as floats; generate() expects an int.
        max_length = int(max_length)

        # Tokenize and move the input tensor to the model's device.
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

        # If the prompt alone already meets the budget, generation would
        # emit no new tokens; guarantee room for at least one.
        if input_ids.shape[1] >= max_length:
            max_length = input_ids.shape[1] + 1

        # Inference only — disable gradient tracking.
        with torch.no_grad():
            output = model.generate(
                input_ids,
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                # GPT-Neo has no pad token; reuse EOS to silence warnings.
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode the full sequence (prompt + continuation) as plain text.
        return tokenizer.decode(output[0], skip_special_tokens=True)

    except RuntimeError as e:
        # Only label genuine CUDA/CPU out-of-memory failures as such;
        # other RuntimeErrors get the generic message below.
        if "out of memory" in str(e).lower():
            return f"Memory Error: {str(e)}. Try reducing max_length."
        return f"Error generating text: {str(e)}"
    except Exception as e:
        return f"Error generating text: {str(e)}"

# Create Gradio interface.
# NOTE: in gr.Blocks, components render in creation order, so the
# statement order below *is* the page layout — do not reorder.
with gr.Blocks(title="GPT-Neo Text Generation") as demo:
    gr.Markdown("# GPT-Neo 1.3B Text Generation")
    gr.Markdown("Generate creative text using the EleutherAI GPT-Neo 1.3B model")
    
    # Two-column layout: controls on the left, output on the right.
    with gr.Row():
        with gr.Column():
            # Free-form prompt entry.
            prompt_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="Start typing your prompt...",
                lines=3
            )
            
            with gr.Row():
                # Total token budget (prompt + generation); mirrors the
                # max_length default of generate_text().
                max_length_slider = gr.Slider(
                    minimum=10,
                    maximum=200,
                    value=100,
                    step=10,
                    label="Max Length"
                )
            
            with gr.Row():
                # Sampling randomness; matches generate_text()'s default.
                temperature_slider = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
                
                # Nucleus-sampling mass; matches generate_text()'s default.
                top_p_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top P"
                )
            
            generate_button = gr.Button("Generate Text", variant="primary")
        
        with gr.Column():
            # Read-only box showing generate_text()'s returned string
            # (either the generation or an error message).
            output_text = gr.Textbox(
                label="Generated Text",
                lines=10,
                interactive=False
            )
    
    # Connect button click to generation function.
    # Inputs are passed positionally, so this list must match
    # generate_text's (prompt, max_length, temperature, top_p) order.
    generate_button.click(
        fn=generate_text,
        inputs=[prompt_input, max_length_slider, temperature_slider, top_p_slider],
        outputs=output_text
    )
    
    # Add examples (clicking one fills the prompt box only; sliders
    # keep their current values).
    gr.Examples(
        examples=[
            ["Once upon a time"],
            ["The future of AI is"],
            ["In a galaxy far away"],
            ["Machine learning is"],
        ],
        inputs=prompt_input,
        label="Example Prompts"
    )

# Start the Gradio web server only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()