Sachin21112004 committed on
Commit
ceb3792
·
verified ·
1 Parent(s): 6454aff

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -0
app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
from transformers import GPTNeoForCausalLM, AutoTokenizer
import torch

try:
    # Download/load the 1.3B-parameter GPT-Neo checkpoint once at startup.
    model_name = "EleutherAI/gpt-neo-1.3B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = GPTNeoForCausalLM.from_pretrained(model_name)

    # Prefer GPU when one is available; generate_text() moves its inputs
    # to this same device before calling the model.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
except Exception as e:
    # Log the failure, then abort startup — the app is useless without the model.
    print(f"Error loading model: {e}")
    raise
18
+
19
def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9):
    """Generate a continuation of *prompt* with the GPT-Neo model.

    Args:
        prompt: Seed text to continue. Must be non-empty after stripping
            whitespace.
        max_length: Total token budget for the output (prompt tokens
            included), forwarded to ``model.generate``.
        temperature: Softmax temperature; higher values increase randomness.
        top_p: Nucleus-sampling probability threshold.

    Returns:
        The decoded text (prompt included), or a human-readable error
        string — the Gradio UI displays whatever string is returned.
    """
    try:
        if not prompt or len(prompt.strip()) == 0:
            return "Error: Please enter a prompt."

        # Tokenize and move tensors to the model's device. Passing the
        # attention mask explicitly (instead of bare input_ids) avoids the
        # transformers warning about inferring the mask from pad tokens.
        encoded = tokenizer(prompt, return_tensors="pt").to(device)

        # Inference only — no gradients needed.
        with torch.no_grad():
            output = model.generate(
                encoded["input_ids"],
                attention_mask=encoded["attention_mask"],
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                # GPT-Neo has no pad token; reuse EOS to silence warnings.
                pad_token_id=tokenizer.eos_token_id,
            )

        # output[0] is the single generated sequence.
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        return generated_text

    except RuntimeError as e:
        # Only label genuine out-of-memory failures as memory errors;
        # previously every RuntimeError was reported as "Memory Error".
        if "out of memory" in str(e).lower():
            return f"Memory Error: {str(e)}. Try reducing max_length."
        return f"Error generating text: {str(e)}"
    except Exception as e:
        return f"Error generating text: {str(e)}"
49
+
50
# ---- Gradio interface ---------------------------------------------------
with gr.Blocks(title="GPT-Neo Text Generation") as demo:
    gr.Markdown("# GPT-Neo 1.3B Text Generation")
    gr.Markdown("Generate creative text using the EleutherAI GPT-Neo 1.3B model")

    with gr.Row():
        # Left column: prompt box, sampling controls, and the action button.
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="Start typing your prompt...",
                lines=3,
            )

            with gr.Row():
                max_length_slider = gr.Slider(
                    minimum=10,
                    maximum=200,
                    value=100,
                    step=10,
                    label="Max Length",
                )

            with gr.Row():
                temperature_slider = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                )

                top_p_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top P",
                )

            generate_button = gr.Button("Generate Text", variant="primary")

        # Right column: read-only output box for the generated text.
        with gr.Column():
            output_text = gr.Textbox(
                label="Generated Text",
                lines=10,
                interactive=False,
            )

    # Wire the button to the generation function.
    generate_button.click(
        fn=generate_text,
        inputs=[prompt_input, max_length_slider, temperature_slider, top_p_slider],
        outputs=output_text,
    )

    # Clickable starter prompts shown below the interface.
    gr.Examples(
        examples=[
            ["Once upon a time"],
            ["The future of AI is"],
            ["In a galaxy far away"],
            ["Machine learning is"],
        ],
        inputs=prompt_input,
        label="Example Prompts",
    )
116
+
117
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not when imported.
    demo.launch()