Roei Zavida committed on
Commit
76b76c9
·
unverified ·
1 Parent(s): d50c3a2

Initial Commit

Browse files
Files changed (2) hide show
  1. app.py +196 -125
  2. requirements.txt +4 -6
app.py CHANGED
@@ -1,61 +1,75 @@
1
- import gradio as gr
2
- import numpy as np
3
  import random
 
4
 
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
- import torch
8
-
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
-
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
16
-
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
19
-
20
- MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
22
-
23
-
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
25
- def infer(
26
- prompt,
27
- negative_prompt,
28
- seed,
29
- randomize_seed,
30
- width,
31
- height,
32
- guidance_scale,
33
- num_inference_steps,
34
  progress=gr.Progress(track_tqdm=True),
35
  ):
36
- if randomize_seed:
37
- seed = random.randint(0, MAX_SEED)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- generator = torch.Generator().manual_seed(seed)
40
-
41
- image = pipe(
42
- prompt=prompt,
43
- negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
- width=width,
47
- height=height,
48
- generator=generator,
49
- ).images[0]
50
-
51
- return image, seed
52
-
53
-
54
- examples = [
55
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
58
- ]
59
 
60
  css = """
61
  #col-container {
@@ -64,90 +78,147 @@ css = """
64
  }
65
  """
66
 
67
- with gr.Blocks(css=css) as demo:
68
- with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
70
-
71
- with gr.Row():
72
- prompt = gr.Text(
73
- label="Prompt",
74
- show_label=False,
75
- max_lines=1,
76
- placeholder="Enter your prompt",
77
- container=False,
78
- )
79
-
80
- run_button = gr.Button("Run", scale=0, variant="primary")
81
-
82
- result = gr.Image(label="Result", show_label=False)
83
-
84
- with gr.Accordion("Advanced Settings", open=False):
85
- negative_prompt = gr.Text(
86
- label="Negative prompt",
87
- max_lines=1,
88
- placeholder="Enter a negative prompt",
89
- visible=False,
90
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
- seed = gr.Slider(
93
- label="Seed",
94
- minimum=0,
95
- maximum=MAX_SEED,
96
- step=1,
97
- value=0,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  )
 
 
99
 
100
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
101
-
102
  with gr.Row():
103
- width = gr.Slider(
104
- label="Width",
105
- minimum=256,
106
- maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
109
- )
110
-
111
- height = gr.Slider(
112
- label="Height",
113
- minimum=256,
114
- maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
117
- )
118
-
119
- with gr.Row():
120
- guidance_scale = gr.Slider(
121
- label="Guidance scale",
122
- minimum=0.0,
123
- maximum=10.0,
124
- step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
126
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- num_inference_steps = gr.Slider(
129
- label="Number of inference steps",
130
- minimum=1,
131
- maximum=50,
132
- step=1,
133
- value=2, # Replace with defaults that work for your model
134
- )
135
 
136
- gr.Examples(examples=examples, inputs=[prompt])
137
  gr.on(
138
  triggers=[run_button.click, prompt.submit],
139
- fn=infer,
140
  inputs=[
141
  prompt,
142
- negative_prompt,
143
- seed,
144
- randomize_seed,
145
- width,
146
- height,
147
- guidance_scale,
148
- num_inference_steps,
149
  ],
150
- outputs=[result, seed],
151
  )
152
 
153
  if __name__ == "__main__":
 
1
+ import base64
2
+ import io
3
  import random
4
+ import time
5
 
6
+ import gradio as gr
7
+ import openai
8
+ import requests
9
+ from PIL import Image
10
+
11
+
12
def generate_image(
    prompt: str,
    api_key: str,
    base_url: str,
    model: str,
    size: str,
    quality: str,
    style: str,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an image via an OpenAI-compatible Images API.

    Sends ``prompt`` to ``{base_url}``'s ``images.generate`` endpoint and
    returns the result as a ``PIL.Image``. Supports both response shapes the
    API may return: a hosted ``url`` (downloaded) or inline ``b64_json``
    (decoded).

    Args:
        prompt: Text description of the desired image (required).
        api_key: API key for the endpoint (required).
        base_url: API base URL, e.g. ``https://api.openai.com/v1`` (required).
        model: Model name, e.g. ``dall-e-3`` (required).
        size: Image size string accepted by the API, e.g. ``1024x1024``.
        quality: ``standard`` or ``hd``.
        style: ``vivid`` or ``natural``.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        PIL.Image.Image: The generated image.

    Raises:
        gr.Error: On missing inputs, API failure, network failure, or an
            unexpected/empty response.
    """
    # Validate required parameters up front so the user gets a clear message
    # instead of a confusing downstream API error.
    if not prompt:
        raise gr.Error("Please provide a prompt for the image generation")
    if not api_key:
        raise gr.Error("API key is required")
    if not base_url:
        raise gr.Error("Base URL is required")
    if not model:
        raise gr.Error("Model name is required")

    try:
        # Initialize client (fast operation)
        progress(0.1, desc="Initializing client...")
        client = openai.OpenAI(api_key=api_key, base_url=base_url)

        # Generate image (slowest operation)
        progress(0.2, desc="Sending request to API...")
        response = client.images.generate(model=model, prompt=prompt, size=size, quality=quality, style=style, n=1)
        progress(0.6, desc="Generating image... This may take 10-30 seconds")

        # Guard against a missing/empty data list before indexing [0];
        # hasattr alone does not protect against `data == []`.
        if not getattr(response, "data", None):
            raise gr.Error("No image data received from the API")

        first = response.data[0]
        if getattr(first, "url", None):
            # Hosted-image response: download it. The timeout prevents a
            # stalled download from hanging the worker indefinitely.
            image_url = first.url
            progress(0.8, desc="Downloading generated image...")
            image_response = requests.get(image_url, timeout=60)
            if image_response.status_code != 200:
                raise gr.Error("Failed to download the generated image")

            progress(0.9, desc="Processing image...")
            img = Image.open(io.BytesIO(image_response.content))
            progress(1.0, desc="Complete!")
            return img

        elif getattr(first, "b64_json", None):
            # Inline base64 response: decode it directly.
            progress(0.8, desc="Decoding base64 image...")
            img_data = base64.b64decode(first.b64_json)

            progress(0.9, desc="Processing image...")
            img = Image.open(io.BytesIO(img_data))
            progress(1.0, desc="Complete!")
            return img
        else:
            raise gr.Error("No image data received from the API")

    except gr.Error:
        # Re-raise our own user-facing errors untouched; without this clause
        # the blanket handler below would re-wrap them (gr.Error subclasses
        # Exception) and hide the specific message.
        raise
    except openai.APIError as e:
        raise gr.Error(f"OpenAI API error: {str(e)}")
    except requests.RequestException as e:
        raise gr.Error(f"Network error: {str(e)}")
    except Exception as e:
        raise gr.Error(f"An unexpected error occurred: {str(e)}")
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  css = """
75
  #col-container {
 
78
  }
79
  """
80
 
81
# Sample prompts shown in the gr.Examples gallery; one is also chosen at
# random (via random.choice) as the prompt textbox's initial value and
# placeholder elsewhere in this file.
examples = [
    "A serene lake surrounded by mountains at sunset",
    "A futuristic cityscape at night",
    "A watercolor painting of a blooming cherry tree",
    "A majestic lion resting on a rocky outcrop in the African savanna",
    "A cozy cottage nestled in a snowy forest during winter",
    "An astronaut floating in space with Earth in the background",
    "A bustling marketplace in a Moroccan city",
    "A vibrant coral reef teeming with marine life",
    "A steampunk-inspired robot tending a garden",
    "A minimalist abstract painting with bold colors",
    "A hyperrealistic close-up of a dew-covered spiderweb",
    "A fantasy landscape with floating islands and waterfalls",
    "A vintage photograph of a jazz band in a smoky club",
    "A serene beach with crystal-clear water and palm trees",
    "A vibrant street market in a Southeast Asian city",
    "A futuristic laboratory with advanced technology",
    "A dense jungle with exotic plants and animals",
    "A medieval castle on a hilltop overlooking a village",
    "A bustling coffee shop in a rainy city",
    "A peaceful Zen garden with carefully raked gravel",
    "A majestic dragon soaring through a stormy sky",
    "A bioluminescent forest at twilight",
    "A portrait of a wise old wizard with a long beard",
    "A group of penguins waddling across the Antarctic ice",
    "A stack of pancakes with syrup and berries",
    "A close-up of a blooming sunflower in a field",
    "A cityscape reflected in a rain puddle",
    "A cup of coffee with latte art",
    "A snowy mountain range under a starry sky",
    "A field of lavender in Provence, France",
    "A plate of sushi with various types of fish",
    "A hot air balloon floating over a valley",
    "A lighthouse on a rocky cliff overlooking the ocean",
    "A bowl of ramen with chopsticks",
    "A tropical beach with a hammock and palm trees",
    "A grand library with towering bookshelves",
    "A cobblestone street in a European village",
    "A waterfall cascading into a clear pool",
    "A field of tulips in the Netherlands",
    "A campfire under a starry night sky",
    "A slice of pizza with pepperoni and cheese",
    "A dense bamboo forest in Japan",
    "A plate of pasta with tomato sauce and basil",
    "A serene Japanese garden with a koi pond",
    "A vibrant carnival with colorful lights and rides",
    "A cozy fireplace in a log cabin",
    "A field of wildflowers in the spring",
    "A bustling train station in a major city",
    "A quiet countryside road with rolling hills",
    "A modern art museum with abstract sculptures",
]
133
 
134
# UI definition: settings column (persisted to browser storage) on the left,
# prompt + result on the right, and event wiring at the bottom.
with gr.Blocks(css=css, title="OpenAI Image Generator") as demo:
    gr.Markdown("# OpenAI Compatible Image Generator")
    gr.Markdown("Generate images using OpenAI's DALL-E or compatible APIs")

    # Initialize browser state with default values in a list.
    # Order: api_key, base_url, model, size, quality, style.
    _DEFAULT_SETTINGS = [
        "",  # api_key
        "https://api.openai.com/v1",  # base_url
        "dall-e-3",  # model
        "1024x1024",  # size
        "standard",  # quality
        "vivid",  # style
    ]
    settings_state = gr.BrowserState(list(_DEFAULT_SETTINGS))
    saved_message = gr.Markdown("✅ Settings saved", visible=False)

    with gr.Row():
        # Left column for settings
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            api_key = gr.Textbox(label="API Key", placeholder="Your OpenAI API key", type="password", value="")
            base_url = gr.Textbox(label="Base URL", placeholder="API base URL", value="https://api.openai.com/v1")
            model = gr.Textbox(label="Model", placeholder="dall-e-3", value="dall-e-3")
            size = gr.Dropdown(
                label="Size", choices=["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"], value="1024x1024"
            )
            quality = gr.Dropdown(label="Quality", choices=["standard", "hd"], value="standard")
            style = gr.Dropdown(label="Style", choices=["vivid", "natural"], value="vivid")

        # Right column for prompt and image
        with gr.Column(scale=2):
            with gr.Row():
                prompt = gr.Text(
                    label="Prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder=random.choice(examples),
                    value=random.choice(examples),
                    container=False,
                )
                run_button = gr.Button("Run", scale=0, variant="primary")

            result = gr.Image(label="Result", show_label=False)
            gr.Examples(examples=examples, inputs=[prompt])

    # Load settings from browser storage
    @demo.load(inputs=[settings_state], outputs=[api_key, base_url, model, size, quality, style])
    def load_from_local_storage(saved_values):
        # SECURITY: do not log saved_values — it contains the user's API key.
        print("Loading settings from local storage")
        # Tolerate a short/stale list from an older app version by padding
        # with defaults instead of raising IndexError on saved_values[i].
        values = list(saved_values or [])
        values += _DEFAULT_SETTINGS[len(values):]
        return (
            values[0],  # api_key
            values[1],  # base_url
            values[2],  # model
            values[3],  # size
            values[4],  # quality
            values[5],  # style
        )

    # Save settings to browser storage whenever any field changes
    @gr.on(
        inputs=[api_key, base_url, model, size, quality, style],
        outputs=[settings_state],
        triggers=[api_key.change, base_url.change, model.change, size.change, quality.change, style.change],
    )
    def save_to_local_storage(api_key, base_url, model, size, quality, style):
        return [api_key, base_url, model, size, quality, style]

    # Show saved message when settings change
    @gr.on(settings_state.change, outputs=[saved_message])
    def show_saved_message():
        timestamp = time.strftime("%I:%M:%S %p")
        return gr.Markdown(f"✅ Settings saved at {timestamp}", visible=True)

    # Main generation event: Run button or Enter in the prompt box
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=generate_image,
        inputs=[
            prompt,
            api_key,
            base_url,
            model,
            size,
            quality,
            style,
        ],
        outputs=result,
    )
223
 
224
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -1,6 +1,4 @@
1
- accelerate
2
- diffusers
3
- invisible_watermark
4
- torch
5
- transformers
6
- xformers
 
1
+ gradio
2
+ openai
3
+ requests
4
+ Pillow