hermanda committed on
Commit
84f4ae2
·
verified ·
1 Parent(s): aab94a3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +316 -0
app.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from datetime import datetime
3
+ from concurrent.futures import ThreadPoolExecutor
4
+ from typing import Optional
5
+ from io import BytesIO
6
+
7
+ import requests
8
+ import gradio as gr
9
+ from PIL import Image, ImageDraw, ImageFont
10
+ from pydantic import BaseModel
11
+ from openai import OpenAI
12
+
13
+
14
class MemeData(BaseModel):
    """
    Pydantic model for the structured meme data parsed from the language model.

    Used as the ``response_format`` for the structured-output chat completion,
    so the field names here are the keys the model is asked to produce.

    Attributes:
        image_prompt (str): A textual prompt to generate an image.
        top_text (str): The top text to place on the generated meme.
        bottom_text (Optional[str]): The bottom text to place on the meme,
            if the model chooses to include one; defaults to ``None``.
    """

    image_prompt: str
    top_text: str
    bottom_text: Optional[str] = None
27
+
28
+
29
def get_meme(meme_seed: str, client: OpenAI) -> MemeData:
    """
    Ask the language model for meme data (image prompt, top text, and an
    optional bottom text) based on the provided meme seed.

    Args:
        meme_seed (str): A general idea or backstory for meme generation.
        client (OpenAI): An instance of the OpenAI client (with an API key).

    Returns:
        MemeData: Parsed meme information containing the image prompt,
            top text, and optional bottom text.
    """

    system_message = {
        "role": "system",
        "content": (
            "You are the best meme lord in the world. You are also highly "
            "creative, funny and clever."
        ),
    }
    user_message = {
        "role": "user",
        "content": (
            "Generate me a meme prompt for text2image generation "
            "(image_prompt) and the top text (top_text) and bottom text "
            "if you want to also include the bottom text (bottom_text). "
            f"The meme will be based on the following: {meme_seed}"
        ),
    }

    # Structured-output endpoint: the reply is parsed directly into MemeData.
    completion = client.beta.chat.completions.parse(
        model="gpt-4o",
        messages=[system_message, user_message],
        response_format=MemeData,
    )
    return completion.choices[0].message.parsed
67
+
68
+
69
def generate_image(image_prompt: str, client: OpenAI) -> str:
    """
    Generate an image URL based on the provided text prompt using DALL-E 3.

    Args:
        image_prompt (str): Text prompt describing the desired image.
        client (OpenAI): An instance of the OpenAI client (with an API key).

    Returns:
        str: A URL pointing to the generated image.
    """

    # Single 1024x1024 standard-quality render; the API returns a hosted URL.
    result = client.images.generate(
        prompt=image_prompt,
        model="dall-e-3",
        n=1,
        size="1024x1024",
        quality="standard",
    )
    return result.data[0].url
89
+
90
+
91
def generate_meme(
    image_url: str, top_text: str, bottom_text: Optional[str] = None
) -> Image.Image:
    """
    Generate a meme by placing the provided top and bottom text onto the image
    from the given URL.

    Args:
        image_url (str): A URL to the generated image.
        top_text (str): Text to be placed at the top of the image.
        bottom_text (Optional[str]): Text to be placed at the bottom of the
            image. ``None`` is treated as no bottom text.

    Returns:
        Image.Image: A PIL Image object with the meme text drawn.

    Raises:
        requests.HTTPError: If the image download fails with an HTTP error.
        requests.RequestException: For other download failures (e.g. timeout).
    """

    if bottom_text is None:
        bottom_text = ""

    try:
        # Bound the download so a stalled server cannot hang a worker thread
        # forever (the original call had no timeout).
        response = requests.get(image_url, timeout=60)
        response.raise_for_status()
    except requests.HTTPError as http_err:
        print(f"HTTP error occurred: {http_err}")
        raise
    except Exception as err:
        print(f"Other error occurred: {err}")
        raise

    image = Image.open(BytesIO(response.content))
    draw = ImageDraw.Draw(image)
    width, height = image.size

    def fit_text(
        text: str,
        max_width: int,
        draw_obj: ImageDraw.Draw,
        font_obj: ImageFont.ImageFont,
    ) -> list:
        """
        Split text into multiple lines ensuring that each line
        fits within the specified max_width.

        Args:
            text (str): The text to split.
            max_width (int): Maximum allowed width for the text.
            draw_obj (ImageDraw.Draw): PIL drawing object.
            font_obj (ImageFont.ImageFont): Font object for text measurement.

        Returns:
            list: A list of lines that fit within the max_width.
        """
        lines = []
        line = ""
        for word in text.split():
            test_line = f"{line} {word}".strip()
            # textlength is the modern Pillow API; getsize is the legacy
            # fallback for older Pillow versions.
            test_width = (
                draw_obj.textlength(test_line, font=font_obj)
                if hasattr(draw_obj, "textlength")
                else font_obj.getsize(test_line)[0]
            )
            if test_width <= max_width:
                line = test_line
            else:
                # Only flush a non-empty line: when the very first word is
                # already wider than max_width, the original code appended a
                # spurious empty line here, inflating the line count.
                if line:
                    lines.append(line)
                line = word
        if line:
            lines.append(line)
        return lines

    # Cap each text band at a fifth of the image and start the font at half
    # that band height, shrinking until both bands fit.
    max_height = height // 5
    font_size = int(max_height / 2)

    while True:
        font = ImageFont.load_default(font_size)
        top_lines = fit_text(top_text, width - 20, draw, font)
        bottom_lines = fit_text(bottom_text, width - 20, draw, font)
        top_text_height = len(top_lines) * font_size
        bottom_text_height = len(bottom_lines) * font_size
        if top_text_height <= max_height and bottom_text_height <= max_height:
            break
        # Guard against an endless shrink loop (and an invalid font size of
        # zero) when the text cannot fit even at the minimum size.
        if font_size <= 1:
            break
        font_size -= 1

    top_y_position = 20
    bottom_y_position = height - bottom_text_height - 20

    for i, line in enumerate(top_lines):
        draw.text(
            (10, top_y_position + i * font_size),
            line,
            font=font,
            fill="white",
            stroke_width=2,
            stroke_fill="black",
        )

    for i, line in enumerate(bottom_lines):
        draw.text(
            (10, bottom_y_position + i * font_size),
            line,
            font=font,
            fill="white",
            stroke_width=2,
            stroke_fill="black",
        )

    return image
198
+
199
def process_single_meme(meme_seed: str, client: OpenAI):
    """
    Process a single meme end to end: retrieve meme data from the language
    model, generate an image from its prompt, and overlay the top/bottom text.

    Args:
        meme_seed (str): A general idea or backstory for meme generation.
        client (OpenAI): An instance of the OpenAI client (with an API key).

    Returns:
        tuple: (image_prompt, top_text, bottom_text, PIL.Image object).
    """

    meme_data = get_meme(meme_seed, client)

    image_url = generate_image(meme_data.image_prompt, client)
    image = generate_meme(image_url, meme_data.top_text, meme_data.bottom_text)

    # Record details to stdout
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S%f")
    print(
        f"timestamp: {timestamp}, Meme Seed: {meme_seed}, "
        f"Image Prompt: {meme_data.image_prompt}, Top Text: {meme_data.top_text}, "
        f"Bottom Text: {meme_data.bottom_text}"
    )

    return meme_data.image_prompt, meme_data.top_text, meme_data.bottom_text, image
229
+
230
+
231
def process_multiple_memes(meme_seed: str, api_key: str):
    """
    Process four memes concurrently for the same seed using a thread pool.

    Args:
        meme_seed (str): A general idea or backstory for meme generation.
        api_key (str): The OpenAI API key to authenticate requests.

    Returns:
        list: A list of tuples, each containing
            (image_prompt, top_text, bottom_text, PIL.Image object).
    """

    client = OpenAI(api_key=api_key)
    # API calls are I/O-bound, so four threads overlap the waiting nicely.
    with ThreadPoolExecutor(max_workers=4) as executor:
        pending = [
            executor.submit(process_single_meme, meme_seed, client)
            for _ in range(4)
        ]
        return [task.result() for task in pending]
251
+
252
+
253
def gradio_app() -> gr.Blocks:
    """
    Build and return the Gradio Blocks application for generating memes.

    Returns:
        gr.Blocks: A Gradio Blocks app object that can be launched.
    """

    with gr.Blocks() as app:
        gr.Markdown("## Meme Generator")

        api_key_box = gr.Textbox(
            label="OpenAI API Key",
            placeholder="Enter your API key",
            type="password",
        )

        with gr.Row():
            with gr.Column(scale=1):
                seed_box = gr.Textbox(
                    label="Meme prompt",
                    value="Wait AI is generating memes now?",
                    placeholder="Enter a meme prompt",
                )
                run_button = gr.Button("Generate")

            with gr.Column(scale=3):
                # Four output slots arranged in a 2x2 grid.
                with gr.Row():
                    first_image = gr.Image(label="Image 1")
                    second_image = gr.Image(label="Image 2")
                with gr.Row():
                    third_image = gr.Image(label="Image 3")
                    fourth_image = gr.Image(label="Image 4")

        def display_memes(meme_seed: str, api_key: str):
            """
            Generate four memes for the same seed and return just the images
            for display in the four output slots.

            Args:
                meme_seed (str): A general idea or backstory for meme generation.
                api_key (str): The OpenAI API key for the client.

            Returns:
                list: A list of four PIL.Image objects generated by the process.
            """
            results = process_multiple_memes(meme_seed, api_key)
            return [image for _, _, _, image in results]

        run_button.click(
            fn=display_memes,
            inputs=[seed_box, api_key_box],
            outputs=[first_image, second_image, third_image, fourth_image],
        )

    return app
312
+
313
+
314
if __name__ == "__main__":
    # Build the Gradio UI and start the local web server.
    gradio_app().launch()