"""Gradio app: generate clothing-top designs with a StyleGAN generator, rank them with CLIP.

Flow: a text prompt -> ``NUM_CANDIDATES`` random samples from the GAN -> the
``NUM_DESIGNS`` samples most similar to the prompt under CLIP are displayed,
each with a "Save to DB" button that uploads the image to a backend for the
``user_id`` taken from the page's query string.
"""

import logging
import os
import pickle
import tempfile
from urllib.parse import quote

import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

logger = logging.getLogger(__name__)

# Force CPU usage for optimization.
device = torch.device("cpu")

# Base URL of the upload backend. Override with the BACKEND_URL environment
# variable instead of editing the code every time the ngrok tunnel changes.
BACKEND_URL = os.environ.get("BACKEND_URL", "https://68be601de1e4.ngrok-free.app")
# Seconds to wait for the backend before giving up, so a dead tunnel cannot
# hang the UI indefinitely.
UPLOAD_TIMEOUT_S = 30
# Status codes this app treats as a successful upload.
_UPLOAD_OK_STATUSES = frozenset({200, 201})

NUM_CANDIDATES = 20  # GAN samples drawn per request (kept small for CPU)
NUM_DESIGNS = 2  # best-ranked samples shown in the UI (one slot each)

# Load the GAN generator. SECURITY: pickle.load can execute arbitrary code;
# top_model.pkl must be a trusted, locally provided file, never user input.
with open("top_model.pkl", "rb") as f:
    G = pickle.load(f)["G_ema"].eval().to(device)  # eval mode, on CPU

# CLIP model and processor, used to score how well each image matches the prompt.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval().to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")


def send_to_backend(img_path, user_id):
    """Upload the PNG at *img_path* to the backend on behalf of *user_id*.

    Args:
        img_path: Path of a PNG written by ``generate_top_dresses`` (the value
            of one design slot's hidden textbox).
        user_id: The ``user_id`` query parameter of the page URL.

    Returns:
        A human-readable status message for the slot's "Status" textbox.
    """
    logger.info("Sending image to backend | img_path=%s, user_id=%s", img_path, user_id)
    if not user_id:
        logger.warning("Missing user_id in URL.")
        return "❌ user_id not found in URL."
    if not img_path or not os.path.exists(img_path):
        logger.warning("Image path invalid or does not exist: %r", img_path)
        return "⚠️ No image selected or image not found."

    # user_id comes from the page's query string (untrusted): percent-encode
    # it so it stays a single path segment ("/" and "../" cannot escape).
    url = f"{BACKEND_URL.rstrip('/')}/images/upload/{quote(str(user_id), safe='')}"
    try:
        with open(img_path, "rb") as fh:
            files = {"file": ("generated_image.png", fh, "image/png")}
            logger.info("Sending POST to %s", url)
            response = requests.post(url, files=files, timeout=UPLOAD_TIMEOUT_S)
    except (OSError, requests.RequestException) as e:
        logger.exception("Exception during upload")
        return f"⚠️ Error: {e}"

    logger.info("Response: %s - %s", response.status_code, response.text)
    if response.status_code in _UPLOAD_OK_STATUSES:
        return "✅ Image uploaded and saved to database!"
    return f"❌ Upload failed: {response.status_code} - {response.text}"


def generate_images(G, num_images=10):
    """Sample *num_images* random images from generator *G*.

    Args:
        G: The loaded generator; must expose ``z_dim`` and be callable as
            ``G(z, c)``.
        num_images: Batch size to sample (kept small for CPU performance).

    Returns:
        ``(z, images)``: the latent batch and a uint8 NumPy array. Assuming G
        returns NCHW in [-1, 1], the array has shape ``(num_images, H, W, C)``.

    Raises:
        ValueError: If *num_images* is not positive.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")
    z = torch.randn(num_images, G.z_dim, device=device)
    c = None  # no class labels — NOTE(review): assumes an unconditional model (c_dim == 0)
    with torch.no_grad():
        images = G(z, c)
    # NCHW -> NHWC and [-1, 1] -> [0, 255]. The +128 (instead of +127.5)
    # makes the uint8 cast round to nearest rather than truncate.
    images = (images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255)
    return z, images.cpu().numpy().astype(np.uint8)


def rank_by_clip(images, prompt, top_k=3):
    """Return the *top_k* images most similar to *prompt* under CLIP.

    Args:
        images: Sequence of uint8 image arrays (as returned by ``generate_images``).
        prompt: Free-text description to rank against.
        top_k: How many images to keep (small for speed); at most len(images)
            are returned.

    Returns:
        List of PIL images, best match first (empty if *images* is empty).
    """
    images_pil = [Image.fromarray(img) for img in images]
    if not images_pil:
        return []
    # truncation=True: CLIP's text encoder has a fixed maximum sequence length
    # (77 tokens for this checkpoint); longer prompts would otherwise raise.
    inputs = clip_processor(
        text=[prompt],
        images=images_pil,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    with torch.no_grad():
        image_features = clip_model.get_image_features(pixel_values=inputs["pixel_values"])
        text_features = clip_model.get_text_features(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    # Cosine similarity of each image with the single prompt. Take column 0
    # rather than squeeze() so a one-image batch still yields a 1-D tensor.
    similarity = (image_features @ text_features.T)[:, 0]
    top_indices = similarity.argsort(descending=True)[:top_k].tolist()
    return [images_pil[i] for i in top_indices]


def generate_top_dresses(prompt):
    """Generate candidates, keep the best ``NUM_DESIGNS`` for *prompt*, save them as PNGs.

    Returns:
        ``(top_images, file_paths)``: PIL images for display and the paths of
        the PNG copies that the "Save to DB" buttons upload.
    """
    _, images = generate_images(G, num_images=NUM_CANDIDATES)
    top_images = rank_by_clip(images, prompt, top_k=NUM_DESIGNS)
    file_paths = []
    for img in top_images:
        # Write through the handle and let `with` close it; the old
        # `NamedTemporaryFile(...).name` leaked an open handle per image.
        # NOTE(review): these files are never deleted; consider a cleanup policy.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            img.save(tmp, format="PNG")
            file_paths.append(tmp.name)
    return top_images, file_paths


# Build the Gradio UI.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown(
        """
        # 👗 AI Top Generator
        _Type in your dream outfit, and let the AI bring your fashion vision to life!_
        Just describe and see how AI transforms your words into fashion.
        """
    )
    with gr.Row():
        input_box = gr.Textbox(
            label="Describe your Design",
            placeholder="e.g., 'Black sleeveless crop top'",
            lines=2,
        )
    with gr.Row():
        submit_button = gr.Button("Generate Designs")

    # Filled on page load from the URL query string (?user_id=...).
    user_id_state = gr.State()

    @demo.load(inputs=None, outputs=[user_id_state])
    def get_user_id(request: gr.Request):
        """Return the page URL's ``user_id`` query parameter ("" if absent)."""
        return request.query_params.get("user_id", "")

    image_components = []
    file_paths = []  # hidden textboxes holding each slot's PNG path
    save_buttons = []
    outputs = []

    with gr.Row():
        for i in range(NUM_DESIGNS):
            with gr.Column():
                img = gr.Image(width=180, height=180, label=f"Design {i+1}")
                image_components.append(img)
                file_path = gr.Textbox(visible=False)
                file_paths.append(file_path)
                save_btn = gr.Button("💾 Save to DB")
                save_buttons.append(save_btn)
                output = gr.Textbox(label="Status", interactive=False)
                outputs.append(output)
                # The components are bound at registration time, so each
                # button uploads its own slot's path.
                save_btn.click(
                    fn=send_to_backend,
                    inputs=[file_path, user_id_state],
                    outputs=output,
                )

    examples = gr.Examples(
        examples=[
            ["Simple red V-neck top"],
            ["Simple blue round-neck top with short sleeves"],
        ],
        inputs=[input_box],
    )

    # Generate button logic: fill the image slots, then (in the same order)
    # the hidden path textboxes used by the save buttons.
    def generate_and_display_images(prompt):
        images, paths = generate_top_dresses(prompt)
        # Pad so the number of values always matches the outputs list.
        missing = NUM_DESIGNS - len(images)
        return images + [None] * missing + paths + [None] * missing

    submit_button.click(
        fn=generate_and_display_images,
        inputs=[input_box],
        outputs=image_components + file_paths,
    )


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    demo.launch()