# -*- coding: utf-8 -*-
"""FinalProject.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1_wYfP0IRdb9fpc2zvbg8IqdXGx1dTo7X

Fashion outfit recommender: embeds a subset of a DeepFashion dataset with
CLIP, retrieves the dataset items most similar to an uploaded image, and
generates a new outfit design from the upload with Stable Diffusion img2img.
"""

from datasets import load_dataset
from PIL import Image, ImageChops
from transformers import CLIPProcessor, CLIPModel
from sklearn.metrics.pairwise import cosine_similarity
import torch
import numpy as np
import gradio as gr
from diffusers import StableDiffusionImg2ImgPipeline

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load CLIP model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load dataset
dataset = load_dataset("lirus18/deepfashion", split="train")


def _embed_images(images):
    """Return CLIP image embeddings for a list of PIL images.

    The result is a numpy array with one row per input image.
    """
    inputs = processor(
        images=[im.convert("RGB") for im in images], return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        return model.get_image_features(**inputs).cpu().numpy()


def _embed_image(image):
    """Return the CLIP embedding of a single PIL image as a one-row 2-D array."""
    return _embed_images([image])


# Embed a subset of dataset images.
# Invariant: image_vectors[k] is the embedding of dataset[image_indices[k]].
N = 500
EMBED_BATCH_SIZE = 32  # images per CLIP forward pass while building the index
image_indices = list(range(min(N, len(dataset))))  # never index past the dataset
image_vectors = np.vstack([
    _embed_images(
        [dataset[i]["image"] for i in image_indices[start:start + EMBED_BATCH_SIZE]]
    )
    for start in range(0, len(image_indices), EMBED_BATCH_SIZE)
])


# Similarity search
def find_similar(user_image, top_k=3, exclude_index=None):
    """Return the ``top_k`` indexed dataset images most similar to ``user_image``.

    Args:
        user_image: PIL image used as the query.
        top_k: number of neighbours to return (0 returns an empty list).
        exclude_index: optional position in ``image_vectors`` / ``image_indices``
            to leave out of the results, e.g. an exact duplicate of the query.

    Returns:
        A tuple ``(images, query_vec)``. ``images`` is a list of PIL images
        ordered from most to least similar. ``query_vec`` is the query's CLIP
        embedding (one row), which the caller reuses to score generated images.
    """
    query_vec = _embed_image(user_image)
    sims = cosine_similarity(query_vec, image_vectors)[0]
    # Rank in descending similarity. Slicing the front (rather than
    # `[-top_k:]`) is also correct when top_k == 0.
    order = np.argsort(sims)[::-1]
    if exclude_index is not None:
        # Drop the excluded item outright so it can never fill a result slot.
        order = order[order != exclude_index]
    top_idx = order[:top_k]
    return [dataset[image_indices[i]]["image"] for i in top_idx], query_vec


# Load Stable Diffusion
# NOTE(review): the "runwayml/stable-diffusion-v1-5" repo id may no longer be
# available on the Hugging Face Hub; confirm it still resolves.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    low_cpu_mem_usage=True,
).to(device)
pipe.enable_attention_slicing()


def generate_outfits(input_image, n=1, strength=0.7, guidance_scale=7.5):
    """Generate ``n`` outfit designs from ``input_image`` with img2img.

    Args:
        input_image: PIL image used as the starting point. It is resized to
            512x512, so its aspect ratio is not preserved.
        n: number of images to generate (one pipeline call each).
        strength: img2img strength passed to the pipeline; higher values
            depart further from the input image.
        guidance_scale: classifier-free guidance weight passed to the pipeline.

    Returns:
        A list of ``n`` generated PIL images.
    """
    prompt = "fashion outfit design inspired by the clothing item"
    init_image = input_image.resize((512, 512))
    return [
        pipe(
            prompt=prompt,
            image=init_image,
            strength=strength,
            guidance_scale=guidance_scale,
        ).images[0]
        for _ in range(n)
    ]


def _find_exact_duplicate(image):
    """Return the position in ``image_indices`` of an indexed dataset image
    that is pixel-identical to ``image`` (an RGB PIL image), or None."""
    for pos, idx in enumerate(image_indices):
        candidate = dataset[idx]["image"]
        # ImageChops.difference accepts images of different sizes and compares
        # only their overlapping region, so require equal sizes first. This
        # also skips the pixel diff for almost every candidate.
        if candidate.size != image.size:
            continue
        if ImageChops.difference(candidate.convert("RGB"), image).getbbox() is None:
            return pos
    return None


# Main function
def recommend_from_upload(uploaded_image):
    """Build the five images shown in the UI for one uploaded clothing image.

    Returns:
        ``[uploaded_image, similar_1, similar_2, similar_3, best_generated]``,
        where ``best_generated`` is the generated design whose CLIP embedding
        is closest to the upload.

    Raises:
        gr.Error: if no image was uploaded.
    """
    if uploaded_image is None:
        raise gr.Error("Please upload a clothing image first.")
    uploaded_image = uploaded_image.convert("RGB")

    # If the upload is an exact copy of an indexed dataset image, exclude it
    # so the "similar items" are not just the upload itself.
    closest_idx = _find_exact_duplicate(uploaded_image)

    # Find similar items
    similar_imgs, query_vec = find_similar(
        uploaded_image, top_k=3, exclude_index=closest_idx
    )

    # Generate a single new outfit design
    generated_imgs = generate_outfits(uploaded_image, n=1)

    # Select the generated image most similar (in CLIP space) to the upload;
    # ties keep the earliest image, and an empty list yields None.
    best_img = max(
        generated_imgs,
        key=lambda img: cosine_similarity(query_vec, _embed_image(img))[0][0],
        default=None,
    )

    return [uploaded_image] + similar_imgs + [best_img]


# Example paths
example_paths = [
    ["example1.jpg"],
    ["example2.jpg"],
    ["example3.jpg"],
    ["example4.jpg"],
    ["example5.jpg"],
]

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 👗 Fashion Outfit Recommender")
    gr.Markdown("Upload a clothing image to get 3 similar items from the dataset and 1 AI-generated outfit design.")

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload a clothing item")
    generate_btn = gr.Button("Generate Recommendations")

    with gr.Row():
        output1 = gr.Image(label="Your Input", height=512, width=384)
        output2 = gr.Image(label="Similar Item 1", height=512, width=384)
        output3 = gr.Image(label="Similar Item 2", height=512, width=384)
        output4 = gr.Image(label="Similar Item 3", height=512, width=384)
        output5 = gr.Image(label="AI-Generated Outfit", height=512, width=384)

    examples = gr.Examples(
        examples=example_paths,
        inputs=image_input,
        label="Try an Example",
    )

    generate_btn.click(
        fn=recommend_from_upload,
        inputs=image_input,
        outputs=[output1, output2, output3, output4, output5],
    )

if __name__ == "__main__":
    demo.launch()