| # import gradio as gr | |
| # import subprocess | |
| # import os | |
| # import random | |
| # from PIL import Image | |
| # import shutil | |
| # import requests | |
| # # === Setup Paths === | |
| # BASE_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| # GEN_SCRIPT = os.path.join(BASE_DIR, "stylegan3", "gen_images.py") | |
| # OUTPUT_DIR = os.path.join(BASE_DIR, "outputs") | |
| # MODEL_PATH = os.path.join(BASE_DIR, "top_model.pkl") | |
| # SAVE_DIR = os.path.join(BASE_DIR, "saved_images") | |
| # os.makedirs(OUTPUT_DIR, exist_ok=True) | |
| # os.makedirs(SAVE_DIR, exist_ok=True) | |
| # # === Image Generation Function === | |
| # def generate_images(): | |
| # command = [ | |
| # "python", | |
| # GEN_SCRIPT, | |
| # f"--outdir={OUTPUT_DIR}", | |
| # "--trunc=1", | |
| # "--seeds=3-5,7,9,12-14,16-26,29,31,32,34,40,41", | |
| # f"--network={MODEL_PATH}" | |
| # ] | |
| # try: | |
| # subprocess.run(command, check=True, capture_output=True, text=True) | |
| # except subprocess.CalledProcessError as e: | |
| # return f"Error generating images:\n{e.stderr}" | |
| # # === Select Random Images from Output Folder === | |
| # def get_random_images(): | |
| # image_files = [f for f in os.listdir(OUTPUT_DIR) if f.endswith(".png")] | |
| # if len(image_files) < 10: | |
| # generate_images() | |
| # image_files = [f for f in os.listdir(OUTPUT_DIR) if f.endswith(".png")] | |
| # random_images = random.sample(image_files, min(10, len(image_files))) | |
| # image_paths = [os.path.join(OUTPUT_DIR, img) for img in random_images] | |
| # return image_paths | |
| # # === Send Image to Backend === | |
| # def send_to_backend(img_path, user_id): | |
| # if not user_id: | |
| # return "❌ user_id not found in URL." | |
| # if not img_path or not os.path.exists(img_path): | |
| # return "⚠️ No image selected or image not found." | |
| # try: | |
| # with open(img_path, 'rb') as f: | |
| # files = {'file': ('generated_image.png', f, 'image/png')} | |
| # # Your backend endpoint here | |
| # url = f" https://7da2-2409-4042-6e81-1806-de6-b8e5-836c-6b95.ngrok-free.app/images/upload/{user_id}" | |
| # response = requests.post(url, files=files) | |
| # if response.status_code == 201: | |
| # return "✅ Image uploaded and saved to database!" | |
| # else: | |
| # return f"❌ Upload failed: {response.status_code} - {response.text}" | |
| # except Exception as e: | |
| # return f"⚠️ Error: {str(e)}" | |
| # # === Gradio Interface === | |
| # with gr.Blocks() as demo: | |
| # gr.Markdown("# 🎨 AI-Generated Clothing Designs - Tops") | |
| # generate_button = gr.Button("Generate New Designs") | |
| # user_id_state = gr.State() | |
| # @demo.load(inputs=None, outputs=[user_id_state]) | |
| # def get_user_id(request: gr.Request): | |
| # return request.query_params.get("user_id", "") | |
| # image_components = [] | |
| # file_paths = [] | |
| # save_buttons = [] | |
| # outputs = [] | |
| # # Use 3 columns layout | |
| # for row_idx in range(4): # 4 rows (to cover 10 images) | |
| # with gr.Row(): | |
| # for col_idx in range(3): # 3 columns | |
| # i = row_idx * 3 + col_idx | |
| # if i >= 10: | |
| # break | |
| # with gr.Column(): | |
| # img = gr.Image(width=180, height=180, label=f"Design {i+1}") | |
| # image_components.append(img) | |
| # file_path = gr.Textbox(visible=False) | |
| # file_paths.append(file_path) | |
| # save_btn = gr.Button("💾 Save to DB") | |
| # save_buttons.append(save_btn) | |
| # output = gr.Textbox(label="Status", interactive=False) | |
| # outputs.append(output) | |
| # save_btn.click( | |
| # fn=send_to_backend, | |
| # inputs=[file_path, user_id_state], | |
| # outputs=output | |
| # ) | |
| # # Generate button logic | |
| # def generate_and_display_images(): | |
| # image_paths = get_random_images() | |
| # return image_paths + image_paths # One for display, one for hidden path tracking | |
| # generate_button.click( | |
| # fn=generate_and_display_images, | |
| # outputs=image_components + file_paths | |
| # ) | |
| # if __name__ == "__main__": | |
| # demo.launch() | |
import os
import pickle
import tempfile

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
# Run everything on CPU (the deployment target presumably has no GPU — confirm).
device = torch.device("cpu")

# Resolve the checkpoint next to this script instead of the current working
# directory, so the app starts correctly regardless of where it is launched.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(BASE_DIR, "top_model.pkl")

# Load the GAN generator (the 'G_ema' entry of the checkpoint dict).
# SECURITY NOTE: pickle.load can execute arbitrary code stored in the file —
# only load checkpoints from a trusted source. Unpickling a StyleGAN checkpoint
# presumably also needs the training repo's helper modules importable — TODO confirm.
with open(MODEL_PATH, "rb") as f:
    G = pickle.load(f)["G_ema"].eval().to(device)  # inference mode, on `device`

# Load CLIP (used to rank generated images against the text prompt).
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval().to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
| # def send_to_backend(img_path, user_id): | |
| # if not user_id: | |
| # return "❌ user_id not found in URL." | |
| # if not img_path or not os.path.exists(img_path): | |
| # return "⚠️ No image selected or image not found." | |
| # try: | |
| # with open(img_path, 'rb') as f: | |
| # files = {'file': ('generated_image.png', f, 'image/png')} | |
| # # Your backend endpoint here | |
| # url = f"https://e335-103-40-74-83.ngrok-free.app/images/upload/{user_id}" | |
| # response = requests.post(url, files=files) | |
| # if response.status_code == 201: | |
| # return "✅ Image uploaded and saved to database!" | |
| # else: | |
| # console.log({response.text}) | |
| # return f"❌ Upload failed: {response.status_code} - {response.text}" | |
| # except Exception as e: | |
| # return f"⚠️ Error: {str(e)}" | |
| import os | |
| import requests # Make sure you import this! | |
def send_to_backend(img_path, user_id, *, base_url="https://68be601de1e4.ngrok-free.app", timeout=30.0):
    """Upload a generated PNG to the backend's per-user image endpoint.

    Args:
        img_path: Path (on this server) of the PNG file to upload.
        user_id: Id read from the page URL's ``user_id`` query parameter; it
            becomes the last path segment of the upload URL.
        base_url: Backend origin (scheme + host). Defaults to the tunnel this
            app was written against.
        timeout: Seconds to wait for the backend before giving up, so a dead
            tunnel cannot block the Gradio worker indefinitely.

    Returns:
        A human-readable status string for the UI. Errors raised while
        building the request or uploading are caught and reported in that
        string rather than propagated.
    """
    # Function-scope import keeps this block independent of the file header.
    from urllib.parse import quote

    print(f"💡 [DEBUG] Sending image to backend | img_path={img_path}, user_id={user_id}")
    if not user_id:
        print("❌ [DEBUG] Missing user_id in URL.")
        return "❌ user_id not found in URL."
    # isfile (not exists): a directory path would otherwise fail later in open().
    if not img_path or not os.path.isfile(img_path):
        print("⚠️ [DEBUG] Image path invalid or does not exist.")
        return "⚠️ No image selected or image not found."
    try:
        # user_id comes from the browser's query string: percent-encode it so it
        # stays one path segment (no "/", "..", "?" or "#" injected into the URL).
        # The previous literal also began with a stray space (" https://..."),
        # which only worked because requests strips leading whitespace.
        url = f"{base_url.rstrip('/')}/images/upload/{quote(str(user_id), safe='')}"
        with open(img_path, "rb") as f:
            files = {"file": ("generated_image.png", f, "image/png")}
            print(f"🔁 [DEBUG] Sending POST to {url}")
            response = requests.post(url, files=files, timeout=timeout)
        print(f"📩 [DEBUG] Response: {response.status_code} - {response.text}")
        if response.status_code in (200, 201):
            return "✅ Image uploaded and saved to database!"
        return f"❌ Upload failed: {response.status_code} - {response.text}"
    except Exception as e:  # UI boundary: surface any failure as a status message
        print(f"⚠️ [ERROR] Exception during upload: {str(e)}")
        return f"⚠️ Error: {str(e)}"
# Generate images
def generate_images(G, num_images=10, seed=None):  # Reduce for CPU performance
    """Sample ``num_images`` images from generator ``G``.

    Args:
        G: Generator exposing ``z_dim`` and callable as ``G(z, c)`` (the
            checkpoint's ``G_ema`` in this app). ``c=None`` is passed, so this
            presumably assumes an unconditional model — TODO confirm.
        num_images: Number of latent vectors to sample.
        seed: Optional int for a reproducible batch; ``None`` uses torch's
            global RNG, as before.

    Returns:
        ``(z, images)``: the ``(num_images, G.z_dim)`` latent batch and a uint8
        NumPy array of shape (N, H, W, C). The permute assumes ``G`` returns
        (N, C, H, W) values nominally in [-1, 1].
    """
    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    z = torch.randn(num_images, G.z_dim, generator=generator)
    with torch.no_grad():
        images = G(z, None)
    # Map [-1, 1] -> [0, 255] with round-to-nearest (x*127.5 + 128, then
    # truncate), the quantisation StyleGAN's gen_images.py uses. The former
    # (x + 1) * 127.5 + astype(uint8) truncated, shifting every pixel down by
    # half a level (e.g. 0.0 -> 127 instead of 128).
    images = (images * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return z, images.permute(0, 2, 3, 1).numpy()
# Rank images using CLIP
def rank_by_clip(images, prompt, top_k=3):  # Reduce top_k for speed
    """Return the ``top_k`` images most similar to ``prompt`` according to CLIP.

    Args:
        images: Iterable of image arrays accepted by ``Image.fromarray`` —
            in this app, the uint8 (H, W, C) frames from ``generate_images``.
        prompt: Text description to rank against.
        top_k: Maximum number of images to return (fewer if ``images`` is shorter).

    Returns:
        List of PIL images, best match first; ``[]`` if ``images`` is empty.
    """
    images_pil = [Image.fromarray(img) for img in images]
    if not images_pil:
        return []
    inputs = clip_processor(text=[prompt], images=images_pil, return_tensors="pt", padding=True)
    with torch.no_grad():
        image_features = clip_model.get_image_features(pixel_values=inputs["pixel_values"])
        text_features = clip_model.get_text_features(
            input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
        )
    # L2-normalise so the dot product below is cosine similarity.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    # (num_images, 1) -> (num_images,). Select the single text column instead of
    # a bare .squeeze(), which collapses a one-image batch to a 0-d tensor that
    # cannot be sliced below.
    similarity = (image_features @ text_features.T)[:, 0]
    top_indices = similarity.argsort(descending=True)[:top_k].tolist()
    return [images_pil[i] for i in top_indices]
# Gradio interface function
def generate_top_dresses(prompt, num_candidates=20, top_k=2):
    """Generate candidate designs and keep those CLIP ranks closest to ``prompt``.

    Args:
        prompt: Text description of the desired top.
        num_candidates: How many images to sample from the GAN before ranking.
        top_k: How many of the best-ranked images to keep.

    Returns:
        ``(top_images, file_paths)``: the selected PIL images (best first) and,
        in the same order, paths of PNG copies in the system temp directory.
        The files are created with ``delete=False`` and are not removed here.
    """
    _, images = generate_images(G, num_images=num_candidates)
    top_images = rank_by_clip(images, prompt, top_k=top_k)
    file_paths = []
    for img in top_images:
        # Write through the handle and let `with` close it. Taking `.name` from a
        # throwaway NamedTemporaryFile left that file object open (leaked handle).
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            img.save(tmp, format="PNG")
        file_paths.append(tmp.name)
    return top_images, file_paths
# Launch Gradio
# Build the UI: prompt box -> two generated designs, each with its own
# "Save to DB" button that uploads that design for the user in the page URL.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("""
# 👗 AI Top Generator
_Type in your dream outfit, and let the AI bring your fashion vision to life!_
Just describe and see how AI transforms your words into fashion.
""")
    with gr.Row():
        input_box = gr.Textbox(
            label="Describe your Design",
            placeholder="e.g., 'Black sleeveless crop top'",
            lines=2,
        )
    with gr.Row():
        submit_button = gr.Button("Generate Designs")

    # Per-session user id, filled on page load from the ?user_id=... query param.
    user_id_state = gr.State()

    def get_user_id(request: gr.Request):
        """Return the page URL's ``user_id`` query parameter, or "" if absent."""
        return request.query_params.get("user_id", "")

    # get_user_id was previously defined but never registered, so user_id_state
    # stayed None and every save failed with "user_id not found in URL".
    demo.load(fn=get_user_id, inputs=None, outputs=[user_id_state])

    image_components = []
    file_paths = []
    save_buttons = []
    outputs = []
    with gr.Row():
        for i in range(2):  # Only 2 images
            with gr.Column():
                img = gr.Image(width=180, height=180, label=f"Design {i+1}")
                image_components.append(img)
                # Server-side session state rather than a hidden Textbox: a hidden
                # Textbox's value comes from the client, which could substitute any
                # server path and have send_to_backend upload that file.
                file_path = gr.State()
                file_paths.append(file_path)
                save_btn = gr.Button("💾 Save to DB")
                save_buttons.append(save_btn)
                output = gr.Textbox(label="Status", interactive=False)
                outputs.append(output)
                save_btn.click(
                    fn=send_to_backend,
                    inputs=[file_path, user_id_state],
                    outputs=output,
                )

    examples = gr.Examples(
        examples=[
            ["Simple red V-neck top"],
            ["Simple blue round-neck top with short sleeves"],
        ],
        inputs=[input_box],
    )

    # Generate button logic
    def generate_and_display_images(prompt):
        """Return the images followed by their temp-file paths, in output order."""
        images, paths = generate_top_dresses(prompt)
        return images + paths

    submit_button.click(
        fn=generate_and_display_images,
        inputs=[input_box],
        outputs=image_components + file_paths,
    )

if __name__ == "__main__":
    demo.launch()