# Spaces: Sleeping
import base64
import io
import json
import os
import tempfile
from io import BytesIO

import matplotlib.pyplot as plt
import numpy as np
import requests
import streamlit as st
from dotenv import load_dotenv
from PIL import Image, ImageDraw, ImageFont
from streamlit_drawable_canvas import st_canvas
def load_api_key():
    """Return the Stability AI API key, or ``None`` if none is configured.

    Lookup order:
    1. the ``IMAGE_GEN_API_KEY`` environment variable (a local ``.env`` file is
       loaded first; ``load_dotenv`` does not override variables already set);
    2. the ``IMAGE_GEN_API_KEY`` entry of ``st.secrets`` -- consulted only while
       the user has not entered a key in the sidebar this session;
    3. the key typed into the sidebar (``st.session_state.api_key``).

    Secrets are looked up by *name*; never place a key value in this source.
    """
    load_dotenv()
    api_key = os.getenv("IMAGE_GEN_API_KEY")

    if not api_key and "api_key" not in st.session_state:
        try:
            api_key = st.secrets["IMAGE_GEN_API_KEY"]
        except Exception:
            # Best effort: no secrets file configured, or the entry is absent.
            api_key = None

    if not api_key and "api_key" in st.session_state:
        api_key = st.session_state.api_key
    return api_key
# Exact "#RRGGBB" (upper-case) -> English name, used to describe colors in the
# image-generation prompt. Covers every quick-color preset offered in the UI.
_COLOR_NAMES = {
    "#FFFFFF": "white",
    "#000000": "black",
    "#FF0000": "red",
    "#00FF00": "green",
    "#0000FF": "blue",
    "#FFFF00": "yellow",
    "#FF00FF": "magenta",
    "#FFC0CB": "pink",
    "#00FFFF": "cyan",
    "#FFA500": "orange",
    "#800080": "purple",
    "#A52A2A": "brown",
    "#808080": "gray",
}


def color_name(hex_color):
    """Convert a hex color string to a descriptive English name.

    Matching is exact and case-insensitive, and surrounding whitespace is
    ignored. This is a simplified lookup; a real app might use a color-naming
    library.

    Args:
        hex_color: color such as ``"#ff0000"``.

    Returns:
        The color's name, or ``"colored"`` for unknown values and non-strings.
    """
    if not isinstance(hex_color, str):
        return "colored"
    return _COLOR_NAMES.get(hex_color.strip().upper(), "colored")
_STABILITY_TEXT_TO_IMAGE_URL = (
    "https://api.stability.ai/v1/generation/stable-diffusion-xl-1024-v1-0/text-to-image"
)


def _build_prompt(user_text, user_text_color, tshirt_color,
                  drawing_position, text_position, drawing_size, text_size):
    """Describe the T-shirt design in words for the text-to-image prompt."""
    # Vertical placement buckets (only the y coordinate matters for the prompt).
    drawing_y = drawing_position[1]
    if drawing_y < 0.4:
        drawing_pos = "upper"
    elif drawing_y > 0.6:
        drawing_pos = "lower"
    else:
        drawing_pos = "centered"

    text_y = text_position[1]
    if text_y < 0.4:
        text_pos = "top"
    elif text_y < 0.6:
        text_pos = "middle"
    else:
        text_pos = "bottom"

    # Size buckets.
    if drawing_size < 0.3:
        drawing_size_desc = "small"
    elif drawing_size > 0.5:
        drawing_size_desc = "large"
    else:
        drawing_size_desc = "medium-sized"

    if text_size < 0.3:
        text_size_desc = "small"
    elif text_size > 0.6:
        text_size_desc = "large"
    else:
        text_size_desc = "medium"

    elements = []
    # Only describe text that actually exists; an empty quote confuses the model.
    if user_text and user_text.strip():
        elements.append(
            f'{text_size_desc.capitalize()} text that says: "{user_text}" in '
            f"{color_name(user_text_color)} color, positioned at the {text_pos} of the shirt."
        )
    elements.append(
        f"A {drawing_size_desc} custom graphic design, positioned in the "
        f"{drawing_pos} part of the shirt."
    )
    numbered = "\n".join(f"{i}. {element}" for i, element in enumerate(elements, start=1))

    return (
        f"Create a realistic photograph of a {color_name(tshirt_color)} t-shirt "
        "with a custom design.\n"
        "The t-shirt has the following elements:\n"
        f"{numbered}\n"
        "Make it look like a professional product photo on a plain background."
    )


def generate_tshirt_image(drawing_image, user_text, user_text_color, tshirt_color,
                          drawing_position, text_position, drawing_size, text_size,
                          timeout=120):
    """Render a photo-style T-shirt mock-up with Stability AI's SDXL text-to-image API.

    The design is described in words only: the endpoint accepts no image, so
    ``drawing_image`` is NOT uploaded. It is kept for interface compatibility
    and appears in the prompt only as "a custom graphic design".

    Args:
        drawing_image: the user's drawing (currently not sent; see above).
        user_text: text printed on the shirt; omitted from the prompt if blank.
        user_text_color: hex color of the text, e.g. ``"#000000"``.
        tshirt_color: hex color of the shirt.
        drawing_position: relative ``(x, y)`` of the drawing, 0-1 (only y is used).
        text_position: relative ``(x, y)`` of the text, 0-1 (only y is used).
        drawing_size: relative drawing size, 0-1.
        text_size: relative text size, 0-1.
        timeout: seconds to wait for the HTTP response before giving up.

    Returns:
        A PIL image on success, otherwise ``None``. Errors are reported in the UI
        via ``st.error``.
    """
    with st.spinner("Generating T-shirt design..."):
        try:
            api_key = load_api_key()
            if not api_key:
                st.error("API Key Missing. Please set it in the sidebar.")
                return None

            prompt = _build_prompt(user_text, user_text_color, tshirt_color,
                                   drawing_position, text_position, drawing_size, text_size)
            headers = {
                "Content-Type": "application/json",
                "Accept": "application/json",
                "Authorization": f"Bearer {api_key}",
            }
            payload = {
                "text_prompts": [{"text": prompt}],
                "cfg_scale": 7,
                "height": 1024,
                "width": 1024,
                "samples": 1,
                "steps": 30,
            }
            # Without a timeout a stalled connection would block the app forever.
            response = requests.post(_STABILITY_TEXT_TO_IMAGE_URL, headers=headers,
                                     json=payload, timeout=timeout)

            if response.status_code != 200:
                st.error(f"Failed to generate design: {response.text}")
                return None

            artifacts = response.json().get("artifacts") or []
            if not artifacts:
                st.error("Failed to generate design: the API returned no image.")
                return None
            image_data = base64.b64decode(artifacts[0]["base64"])
            return Image.open(BytesIO(image_data))
        except requests.Timeout:
            st.error(f"The image generation API did not respond within {timeout} seconds.")
            return None
        except Exception as e:
            # UI boundary: surface any other failure instead of crashing the page.
            st.error(f"An error occurred: {str(e)}")
            return None
def _load_font(size):
    """Return a TrueType font of ``size`` pixels, falling back to Pillow's default."""
    # "Arial" resolves on Windows/macOS; the others cover common Linux installs.
    for font_name in ("Arial", "arial.ttf", "Arial.ttf", "DejaVuSans.ttf"):
        try:
            return ImageFont.truetype(font_name, size)
        except OSError:
            continue
    try:
        # Pillow >= 10.1 can return a scalable default font of the requested size.
        return ImageFont.load_default(size=size)
    except TypeError:
        # Older Pillow: fixed-size bitmap font (ignores the size slider).
        return ImageFont.load_default()


def create_tshirt_preview(drawing_image, user_text, user_text_color, tshirt_color,
                          drawing_position, text_position, drawing_size, text_size):
    """Compose a quick local 400x500 preview of the T-shirt design.

    Args:
        drawing_image: PIL image of the user's drawing, or ``None`` to skip it.
        user_text: text to draw; skipped when empty.
        user_text_color: fill color for the text (e.g. ``"#000000"``).
        tshirt_color: background color of the template.
        drawing_position: relative ``(x, y)`` centre of the drawing, 0-1.
        text_position: relative ``(x, y)`` centre of the text, 0-1.
        drawing_size: drawing side length relative to the template's short side.
        text_size: font size relative to a 40 px maximum.

    Returns:
        A new RGB PIL image. Text-rendering problems are reported with
        ``st.warning`` and the preview is returned without the text.
    """
    width, height = 400, 500
    tshirt_template = Image.new("RGB", (width, height), color=tshirt_color)
    draw = ImageDraw.Draw(tshirt_template)

    # Simple T-shirt body outline.
    draw.polygon([(width * 0.25, height * 0.2), (width * 0.75, height * 0.2),
                  (width * 0.8, height * 0.3), (width * 0.8, height * 0.8),
                  (width * 0.2, height * 0.8), (width * 0.2, height * 0.3)], outline="black")
    # Sleeves.
    draw.polygon([(width * 0.75, height * 0.2), (width * 0.8, height * 0.3),
                  (width * 0.9, height * 0.25), (width * 0.8, height * 0.1)], outline="black")  # Right sleeve
    draw.polygon([(width * 0.25, height * 0.2), (width * 0.2, height * 0.3),
                  (width * 0.1, height * 0.25), (width * 0.2, height * 0.1)], outline="black")  # Left sleeve

    if drawing_image is not None:
        # Relative centre -> pixels.
        draw_x = int(drawing_position[0] * width)
        draw_y = int(drawing_position[1] * height)
        short_side = min(width, height)
        # resize() rejects 0 px, and anything larger than the template cannot fit.
        side = max(1, min(int(drawing_size * short_side), short_side))
        # paste() with a mask needs an alpha channel; convert RGB/L drawings first.
        resized_drawing = drawing_image.convert("RGBA").resize((side, side), Image.LANCZOS)
        # Centre on the requested point, then clamp so the drawing stays inside.
        paste_x = min(max(0, draw_x - side // 2), width - side)
        paste_y = min(max(0, draw_y - side // 2), height - side)
        tshirt_template.paste(resized_drawing, (paste_x, paste_y), resized_drawing)

    if user_text:
        try:
            text_x = int(text_position[0] * width)
            text_y = int(text_position[1] * height)
            # Max font size is 40 px; at least 1 px so font loading cannot fail on 0.
            font = _load_font(max(1, int(text_size * 40)))
            draw.text((text_x, text_y), user_text, fill=user_text_color, font=font, anchor="mm")
        except Exception as e:
            st.warning(f"Error adding text: {str(e)}")

    return tshirt_template
def _blank_drawing():
    """Return an empty, fully transparent 500x500 RGBA drawing."""
    return Image.new("RGBA", (500, 500), (255, 255, 255, 0))


def _rerun():
    """Rerun the script on both new (``st.rerun``) and old (``st.experimental_rerun``) Streamlit."""
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def _init_session_state():
    """Seed every design setting in ``st.session_state`` on the first run only."""
    defaults = {
        "user_text": "",
        "user_text_color": "#000000",
        "tshirt_color": "#FFFFFF",
        "drawing_position": (0.5, 0.4),  # Relative position (0-1)
        "text_position": (0.5, 0.7),     # Relative position (0-1)
        "drawing_size": 0.4,             # Relative size (0-1)
        "text_size": 0.5,                # Relative size (0-1)
        "generated_image": None,
        # Bumped by "Clear Drawing" so the canvas widget gets a fresh key.
        "canvas_version": 0,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
    if "drawing_image" not in st.session_state:
        st.session_state.drawing_image = _blank_drawing()


def _design_args():
    """Return the current design settings in the order both renderers expect."""
    state = st.session_state
    return (
        state.drawing_image,
        state.user_text,
        state.user_text_color,
        state.tshirt_color,
        state.drawing_position,
        state.text_position,
        state.drawing_size,
        state.text_size,
    )


def _render_sidebar():
    """Ask for an API key in the sidebar when none is configured."""
    with st.sidebar:
        st.header("API Settings")
        if load_api_key():
            return
        st.session_state.api_key = st.text_input("Stability AI API Key", type="password")
        if st.session_state.api_key:
            st.success("API Key set!")
        else:
            st.warning("Please enter your Stability AI API Key to generate designs")


def _render_drawing_tab():
    """Drawing canvas plus drawing position/size sliders."""
    st.subheader("Drawing Canvas")
    canvas_result = st_canvas(
        fill_color="rgba(255, 255, 255, 0)",
        stroke_width=st.slider("Brush size", 1, 30, 5),
        stroke_color=st.color_picker("Drawing color", "#000000"),
        background_color="rgba(255, 255, 255, 0)",
        height=400,
        width=400,
        drawing_mode="freedraw",
        key=f"canvas_{st.session_state.canvas_version}",
    )
    img_data = canvas_result.image_data
    # Keep the drawing only when the canvas returns 4 channels (alpha included).
    if img_data is not None and img_data.shape[2] == 4:
        st.session_state.drawing_image = Image.fromarray(img_data)

    if st.button("Clear Drawing"):
        # A new widget key makes st_canvas mount an empty canvas; with the old
        # key its strokes would be replayed and overwrite the reset image.
        st.session_state.canvas_version += 1
        st.session_state.drawing_image = _blank_drawing()
        _rerun()

    st.subheader("Drawing Position & Size")
    col_x, col_y = st.columns(2)
    with col_x:
        draw_pos_x = st.slider("Horizontal Position", 0.1, 0.9,
                               st.session_state.drawing_position[0], step=0.05)
    with col_y:
        draw_pos_y = st.slider("Vertical Position", 0.2, 0.8,
                               st.session_state.drawing_position[1], step=0.05)
    st.session_state.drawing_position = (draw_pos_x, draw_pos_y)
    st.session_state.drawing_size = st.slider("Drawing Size", 0.1, 0.8,
                                              st.session_state.drawing_size, step=0.05)


def _render_text_tab():
    """Text content, color, position and size controls."""
    st.subheader("Text Settings")
    st.session_state.user_text = st.text_input("Text on T-shirt", value=st.session_state.user_text)
    st.session_state.user_text_color = st.color_picker("Text Color", st.session_state.user_text_color)

    st.subheader("Text Position & Size")
    col_x, col_y = st.columns(2)
    with col_x:
        txt_pos_x = st.slider("Text Horizontal", 0.1, 0.9,
                              st.session_state.text_position[0], step=0.05)
    with col_y:
        txt_pos_y = st.slider("Text Vertical", 0.2, 0.8,
                              st.session_state.text_position[1], step=0.05)
    st.session_state.text_position = (txt_pos_x, txt_pos_y)
    st.session_state.text_size = st.slider("Text Size", 0.1, 1.0,
                                           st.session_state.text_size, step=0.05)


def _render_tshirt_tab():
    """T-shirt color picker plus one-click color presets."""
    st.subheader("T-Shirt Color")
    st.session_state.tshirt_color = st.color_picker("Choose T-shirt Color",
                                                    st.session_state.tshirt_color)

    st.write("Quick colors:")
    color_presets = {
        "White": "#FFFFFF",
        "Black": "#000000",
        "Red": "#FF0000",
        "Blue": "#0000FF",
        "Green": "#00FF00",
        "Yellow": "#FFFF00",
        "Purple": "#800080",
        "Pink": "#FFC0CB",
        "Gray": "#808080",
        "Brown": "#A52A2A",
    }
    color_cols = st.columns(5)
    for i, (name, color) in enumerate(color_presets.items()):
        with color_cols[i % 5]:
            if st.button(name, key=f"color_{name}"):
                st.session_state.tshirt_color = color
                _rerun()


def _render_preview_column():
    """Local preview, the Generate button, and the generated-image download."""
    st.header("T-Shirt Preview")
    st.image(create_tshirt_preview(*_design_args()), use_column_width=True)

    if st.button("Generate T-Shirt Design", type="primary"):
        if not load_api_key():
            st.error("Please enter your Stability AI API Key in the sidebar first")
        else:
            generated_img = generate_tshirt_image(*_design_args())
            if generated_img is not None:
                st.session_state.generated_image = generated_img

    generated = st.session_state.generated_image
    if generated is not None:
        st.header("Generated Design")
        st.image(generated, use_column_width=True)
        buf = BytesIO()
        generated.save(buf, format="PNG")
        st.download_button(
            label="Download Design",
            data=buf.getvalue(),
            file_name="tshirt_design.png",
            mime="image/png",
        )


def main():
    """Build the T-Shirt Design Studio page.

    Widget values are mirrored into ``st.session_state`` so the preview and
    the generation request always see the latest design settings.
    """
    st.set_page_config(
        page_title="T-Shirt Design App",
        page_icon="👕",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    st.title("T-Shirt Design Studio")
    st.subheader("Create your custom t-shirt design")

    _init_session_state()
    _render_sidebar()

    col1, col2 = st.columns([3, 2])
    with col1:
        st.header("Design Your T-Shirt")
        tab1, tab2, tab3 = st.tabs(["Drawing", "Text", "T-Shirt"])
        with tab1:
            _render_drawing_tab()
        with tab2:
            _render_text_tab()
        with tab3:
            _render_tshirt_tab()
    with col2:
        _render_preview_column()

    st.markdown("---")
    st.markdown("T-Shirt Design Studio - Create your custom designs")
# Build the page only when this file is executed as the main script;
# importing the module does not launch the UI.
if __name__ == "__main__":
    main()