FaizanShaikh1 commited on
Commit
4558ea9
·
verified ·
1 Parent(s): ec6edac

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +435 -0
app.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image, ImageDraw, ImageFont
3
+ import numpy as np
4
+ import io
5
+ import os
6
+ import requests
7
+ import base64
8
+ import json
9
+ from io import BytesIO
10
+ from dotenv import load_dotenv
11
+ import tempfile
12
+ from streamlit_drawable_canvas import st_canvas
13
+ import matplotlib.pyplot as plt
14
+
15
+ def load_api_key():
16
+ # Try to get from environment variable
17
+ api_key = os.getenv("IMAGE_GEN_API_KEY")
18
+
19
+ # If not found and not in session state, try to get from secrets
20
+ if not api_key and "api_key" not in st.session_state:
21
+ try:
22
+ api_key = st.secrets["sk-piXnT4vKFP1cSAHR3GxuNLirvgF97r0agK1vzve7KyE8ajSX"]
23
+ except:
24
+ api_key = None
25
+
26
+ # If still not found, use session state value if it exists
27
+ if not api_key and "api_key" in st.session_state:
28
+ api_key = st.session_state.api_key
29
+
30
+ return api_key
31
+
32
+ def color_name(hex_color):
33
+ """Convert hex color to a descriptive name"""
34
+ # This is a simplified version - in a real app, you might use a color naming library
35
+ color_map = {
36
+ "#FFFFFF": "white",
37
+ "#000000": "black",
38
+ "#FF0000": "red",
39
+ "#00FF00": "green",
40
+ "#0000FF": "blue",
41
+ "#FFFF00": "yellow",
42
+ "#FF00FF": "pink",
43
+ "#00FFFF": "cyan",
44
+ "#FFA500": "orange",
45
+ "#800080": "purple",
46
+ "#A52A2A": "brown",
47
+ "#808080": "gray"
48
+ }
49
+
50
+ # Try to find exact match
51
+ if hex_color.upper() in color_map:
52
+ return color_map[hex_color.upper()]
53
+
54
+ # If no exact match, return a generic description
55
+ return "colored"
56
+
57
+ def generate_tshirt_image(drawing_image, user_text, user_text_color, tshirt_color, drawing_position, text_position, drawing_size, text_size):
58
+ # Show loading message
59
+ with st.spinner("Generating T-shirt design..."):
60
+ try:
61
+ # Get the API key
62
+ api_key = load_api_key()
63
+ if not api_key:
64
+ st.error("API Key Missing. Please set it in the sidebar.")
65
+ return None
66
+
67
+ # Convert drawing to base64
68
+ buffered = BytesIO()
69
+ drawing_image.save(buffered, format="PNG")
70
+ drawing_base64 = base64.b64encode(buffered.getvalue()).decode()
71
+
72
+ # Determine the relative positions for the prompt
73
+ drawing_pos = "centered"
74
+ if drawing_position[1] < 0.4:
75
+ drawing_pos = "upper"
76
+ elif drawing_position[1] > 0.6:
77
+ drawing_pos = "lower"
78
+
79
+ text_pos = "bottom"
80
+ if text_position[1] < 0.4:
81
+ text_pos = "top"
82
+ elif text_position[1] < 0.6:
83
+ text_pos = "middle"
84
+
85
+ # Determine size descriptions
86
+ drawing_size_desc = "medium-sized"
87
+ if drawing_size < 0.3:
88
+ drawing_size_desc = "small"
89
+ elif drawing_size > 0.5:
90
+ drawing_size_desc = "large"
91
+
92
+ text_size_desc = "medium"
93
+ if text_size < 0.3:
94
+ text_size_desc = "small"
95
+ elif text_size > 0.6:
96
+ text_size_desc = "large"
97
+
98
+ # Create a prompt for the image generation API
99
+ prompt = f"""
100
+ Create a realistic photograph of a {color_name(tshirt_color)} t-shirt with a custom design.
101
+ The t-shirt has the following elements:
102
+ 1. {text_size_desc.capitalize()} text that says: "{user_text}" in {color_name(user_text_color)} color, positioned at the {text_pos} of the shirt.
103
+ 2. A {drawing_size_desc} custom graphic design, positioned in the {drawing_pos} part of the shirt.
104
+ Make it look like a professional product photo on a plain background.
105
+ """
106
+
107
+ # API endpoint
108
+ url = "https://api.stability.ai/v1/generation/stable-diffusion-xl-1024-v1-0/text-to-image"
109
+
110
+ # Request headers
111
+ headers = {
112
+ "Content-Type": "application/json",
113
+ "Accept": "application/json",
114
+ "Authorization": f"Bearer {api_key}"
115
+ }
116
+
117
+ # Request payload
118
+ payload = {
119
+ "text_prompts": [{"text": prompt}],
120
+ "cfg_scale": 7,
121
+ "height": 1024,
122
+ "width": 1024,
123
+ "samples": 1,
124
+ "steps": 30,
125
+ }
126
+
127
+ # Send request to image generation API
128
+ response = requests.post(url, headers=headers, json=payload)
129
+
130
+ if response.status_code == 200:
131
+ # Save the generated image
132
+ result = response.json()
133
+ image_data = base64.b64decode(result["artifacts"][0]["base64"])
134
+
135
+ # Return the image
136
+ return Image.open(BytesIO(image_data))
137
+ else:
138
+ st.error(f"Failed to generate design: {response.text}")
139
+ return None
140
+
141
+ except Exception as e:
142
+ st.error(f"An error occurred: {str(e)}")
143
+ return None
144
+
145
+ def create_tshirt_preview(drawing_image, user_text, user_text_color, tshirt_color, drawing_position, text_position, drawing_size, text_size):
146
+ # Create a base T-shirt template
147
+ # Size is relative to the size of the display
148
+ width, height = 400, 500
149
+ tshirt_template = Image.new("RGB", (width, height), color=tshirt_color)
150
+ draw = ImageDraw.Draw(tshirt_template)
151
+
152
+ # Draw T-shirt outline
153
+ # This is a simple shape
154
+ draw.polygon([(width*0.25, height*0.2), (width*0.75, height*0.2),
155
+ (width*0.8, height*0.3), (width*0.8, height*0.8),
156
+ (width*0.2, height*0.8), (width*0.2, height*0.3)], outline="black")
157
+
158
+ # Draw sleeves
159
+ draw.polygon([(width*0.75, height*0.2), (width*0.8, height*0.3),
160
+ (width*0.9, height*0.25), (width*0.8, height*0.1)], outline="black") # Right sleeve
161
+ draw.polygon([(width*0.25, height*0.2), (width*0.2, height*0.3),
162
+ (width*0.1, height*0.25), (width*0.2, height*0.1)], outline="black") # Left sleeve
163
+
164
+ # Place the user drawing at the specified position and size
165
+ if drawing_image is not None:
166
+ # Convert position from relative (0-1) to actual pixels
167
+ draw_x = int(drawing_position[0] * width)
168
+ draw_y = int(drawing_position[1] * height)
169
+
170
+ # Calculate actual drawing size based on relative size
171
+ actual_draw_size = int(drawing_size * min(width, height))
172
+
173
+ # Resize drawing to user-specified size
174
+ resized_drawing = drawing_image.resize((actual_draw_size, actual_draw_size), Image.LANCZOS)
175
+
176
+ # Calculate position to center the drawing at the specified point
177
+ paste_x = max(0, draw_x - actual_draw_size // 2)
178
+ paste_y = max(0, draw_y - actual_draw_size // 2)
179
+
180
+ # Make sure the drawing doesn't go outside the template
181
+ paste_x = min(paste_x, width - actual_draw_size)
182
+ paste_y = min(paste_y, height - actual_draw_size)
183
+
184
+ # Paste the drawing onto the template
185
+ tshirt_template.paste(resized_drawing, (paste_x, paste_y), resized_drawing)
186
+
187
+ # Add text if provided
188
+ if user_text:
189
+ try:
190
+ # Convert position from relative (0-1) to actual pixels
191
+ text_x = int(text_position[0] * width)
192
+ text_y = int(text_position[1] * height)
193
+
194
+ # Calculate actual text size based on relative size
195
+ actual_text_size = int(text_size * 40) # Max font size is 40
196
+
197
+ # Try to use a TrueType font if available
198
+ try:
199
+ font = ImageFont.truetype("Arial", actual_text_size)
200
+ except:
201
+ font = ImageFont.load_default()
202
+
203
+ # Draw text
204
+ draw.text((text_x, text_y), user_text, fill=user_text_color, font=font, anchor="mm")
205
+ except Exception as e:
206
+ st.warning(f"Error adding text: {str(e)}")
207
+
208
+ return tshirt_template
209
+
210
+ def main():
211
+ st.set_page_config(
212
+ page_title="T-Shirt Design App",
213
+ page_icon="👕",
214
+ layout="wide",
215
+ initial_sidebar_state="expanded",
216
+ )
217
+
218
+ st.title("T-Shirt Design Studio")
219
+ st.subheader("Create your custom t-shirt design")
220
+
221
+ # Initialize session state
222
+ if "drawing_image" not in st.session_state:
223
+ st.session_state.drawing_image = Image.new("RGBA", (500, 500), (255, 255, 255, 0))
224
+
225
+ if "user_text" not in st.session_state:
226
+ st.session_state.user_text = ""
227
+
228
+ if "user_text_color" not in st.session_state:
229
+ st.session_state.user_text_color = "#000000"
230
+
231
+ if "tshirt_color" not in st.session_state:
232
+ st.session_state.tshirt_color = "#FFFFFF"
233
+
234
+ if "drawing_position" not in st.session_state:
235
+ st.session_state.drawing_position = (0.5, 0.4) # Relative position (0-1)
236
+
237
+ if "text_position" not in st.session_state:
238
+ st.session_state.text_position = (0.5, 0.7) # Relative position (0-1)
239
+
240
+ if "drawing_size" not in st.session_state:
241
+ st.session_state.drawing_size = 0.4 # Relative size (0-1)
242
+
243
+ if "text_size" not in st.session_state:
244
+ st.session_state.text_size = 0.5 # Relative size (0-1)
245
+
246
+ if "generated_image" not in st.session_state:
247
+ st.session_state.generated_image = None
248
+
249
+ # Sidebar for API key
250
+ with st.sidebar:
251
+ st.header("API Settings")
252
+
253
+ api_key = load_api_key()
254
+ if not api_key:
255
+ st.session_state.api_key = st.text_input("Stability AI API Key", type="password")
256
+ if st.session_state.api_key:
257
+ st.success("API Key set!")
258
+ else:
259
+ st.warning("Please enter your Stability AI API Key to generate designs")
260
+
261
+ # Main app layout
262
+ col1, col2 = st.columns([3, 2])
263
+
264
+ # Column 1: Design controls
265
+ with col1:
266
+ st.header("Design Your T-Shirt")
267
+
268
+ # Create tabs for different design elements
269
+ tab1, tab2, tab3 = st.tabs(["Drawing", "Text", "T-Shirt"])
270
+
271
+ # Tab 1: Drawing tools
272
+ with tab1:
273
+ st.subheader("Drawing Canvas")
274
+
275
+ # Canvas settings
276
+ canvas_result = st_canvas(
277
+ fill_color="rgba(255, 255, 255, 0)",
278
+ stroke_width=st.slider("Brush size", 1, 30, 5),
279
+ stroke_color=st.color_picker("Drawing color", "#000000"),
280
+ background_color="rgba(255, 255, 255, 0)",
281
+ height=400,
282
+ width=400,
283
+ drawing_mode="freedraw",
284
+ key="canvas",
285
+ )
286
+
287
+ # Convert canvas result to image if available
288
+ if canvas_result.image_data is not None:
289
+ # Convert the numpy array to PIL Image
290
+ img_data = canvas_result.image_data
291
+ if img_data.shape[2] == 4: # Check if there's an alpha channel
292
+ # Create a PIL image from numpy array
293
+ pil_image = Image.fromarray(img_data)
294
+ st.session_state.drawing_image = pil_image
295
+
296
+ if st.button("Clear Drawing"):
297
+ # Reset the canvas by creating a new key
298
+ st.session_state.drawing_image = Image.new("RGBA", (500, 500), (255, 255, 255, 0))
299
+ st.experimental_rerun()
300
+
301
+ st.subheader("Drawing Position & Size")
302
+
303
+ # Sliders for drawing position
304
+ col_pos1, col_pos2 = st.columns(2)
305
+ with col_pos1:
306
+ draw_pos_x = st.slider("Horizontal Position", 0.1, 0.9, st.session_state.drawing_position[0], step=0.05)
307
+ with col_pos2:
308
+ draw_pos_y = st.slider("Vertical Position", 0.2, 0.8, st.session_state.drawing_position[1], step=0.05)
309
+
310
+ st.session_state.drawing_position = (draw_pos_x, draw_pos_y)
311
+
312
+ # Slider for drawing size
313
+ drawing_size = st.slider("Drawing Size", 0.1, 0.8, st.session_state.drawing_size, step=0.05)
314
+ st.session_state.drawing_size = drawing_size
315
+
316
+ # Tab 2: Text controls
317
+ with tab2:
318
+ st.subheader("Text Settings")
319
+
320
+ # Text input
321
+ user_text = st.text_input("Text on T-shirt", value=st.session_state.user_text)
322
+ st.session_state.user_text = user_text
323
+
324
+ # Text color
325
+ text_color = st.color_picker("Text Color", st.session_state.user_text_color)
326
+ st.session_state.user_text_color = text_color
327
+
328
+ st.subheader("Text Position & Size")
329
+
330
+ # Sliders for text position
331
+ col_txt1, col_txt2 = st.columns(2)
332
+ with col_txt1:
333
+ txt_pos_x = st.slider("Text Horizontal", 0.1, 0.9, st.session_state.text_position[0], step=0.05)
334
+ with col_txt2:
335
+ txt_pos_y = st.slider("Text Vertical", 0.2, 0.8, st.session_state.text_position[1], step=0.05)
336
+
337
+ st.session_state.text_position = (txt_pos_x, txt_pos_y)
338
+
339
+ # Slider for text size
340
+ text_size = st.slider("Text Size", 0.1, 1.0, st.session_state.text_size, step=0.05)
341
+ st.session_state.text_size = text_size
342
+
343
+ # Tab 3: T-shirt color
344
+ with tab3:
345
+ st.subheader("T-Shirt Color")
346
+
347
+ # T-shirt color picker
348
+ tshirt_color = st.color_picker("Choose T-shirt Color", st.session_state.tshirt_color)
349
+ st.session_state.tshirt_color = tshirt_color
350
+
351
+ # Preset colors
352
+ st.write("Quick colors:")
353
+
354
+ # Define color presets
355
+ color_presets = {
356
+ "White": "#FFFFFF",
357
+ "Black": "#000000",
358
+ "Red": "#FF0000",
359
+ "Blue": "#0000FF",
360
+ "Green": "#00FF00",
361
+ "Yellow": "#FFFF00",
362
+ "Purple": "#800080",
363
+ "Pink": "#FFC0CB",
364
+ "Gray": "#808080",
365
+ "Brown": "#A52A2A"
366
+ }
367
+
368
+ # Create color buttons in a grid
369
+ color_cols = st.columns(5)
370
+ for i, (name, color) in enumerate(color_presets.items()):
371
+ with color_cols[i % 5]:
372
+ if st.button(name, key=f"color_{name}"):
373
+ st.session_state.tshirt_color = color
374
+ st.experimental_rerun()
375
+
376
+ # Column 2: Preview and generate
377
+ with col2:
378
+ st.header("T-Shirt Preview")
379
+
380
+ # Create and display the t-shirt preview
381
+ preview_image = create_tshirt_preview(
382
+ st.session_state.drawing_image,
383
+ st.session_state.user_text,
384
+ st.session_state.user_text_color,
385
+ st.session_state.tshirt_color,
386
+ st.session_state.drawing_position,
387
+ st.session_state.text_position,
388
+ st.session_state.drawing_size,
389
+ st.session_state.text_size
390
+ )
391
+
392
+ st.image(preview_image, use_column_width=True)
393
+
394
+ # Generate button
395
+ if st.button("Generate T-Shirt Design", type="primary"):
396
+ if not load_api_key():
397
+ st.error("Please enter your Stability AI API Key in the sidebar first")
398
+ else:
399
+ generated_img = generate_tshirt_image(
400
+ st.session_state.drawing_image,
401
+ st.session_state.user_text,
402
+ st.session_state.user_text_color,
403
+ st.session_state.tshirt_color,
404
+ st.session_state.drawing_position,
405
+ st.session_state.text_position,
406
+ st.session_state.drawing_size,
407
+ st.session_state.text_size
408
+ )
409
+
410
+ if generated_img:
411
+ st.session_state.generated_image = generated_img
412
+
413
+ # Display generated image if available
414
+ if st.session_state.generated_image:
415
+ st.header("Generated Design")
416
+ st.image(st.session_state.generated_image, use_column_width=True)
417
+
418
+ # Download button for the generated image
419
+ buf = BytesIO()
420
+ st.session_state.generated_image.save(buf, format="PNG")
421
+ byte_im = buf.getvalue()
422
+
423
+ st.download_button(
424
+ label="Download Design",
425
+ data=byte_im,
426
+ file_name="tshirt_design.png",
427
+ mime="image/png",
428
+ )
429
+
430
+ # Footer
431
+ st.markdown("---")
432
+ st.markdown("T-Shirt Design Studio - Create your custom designs")
433
+
434
+ if __name__ == "__main__":
435
+ main()