# -*- coding: utf-8 -*-
"""FinalProject.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1_wYfP0IRdb9fpc2zvbg8IqdXGx1dTo7X
"""
from datasets import load_dataset
from PIL import Image, ImageChops
from transformers import CLIPProcessor, CLIPModel
from sklearn.metrics.pairwise import cosine_similarity
import torch
import numpy as np
import gradio as gr
from diffusers import StableDiffusionImg2ImgPipeline

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load CLIP model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load dataset
dataset = load_dataset("lirus18/deepfashion", split="train")

# Embed a subset of images
image_vectors = []
image_indices = []
N = 500

for i in range(N):
    img = dataset[i]['image'].convert("RGB")
    inputs = processor(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        emb = model.get_image_features(**inputs)
    image_vectors.append(emb.cpu().numpy().squeeze())
    image_indices.append(i)

image_vectors = np.array(image_vectors)
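
# Optional sketch (not part of the original app): the per-image loop above is slow,
# especially on CPU, because each image goes through CLIP alone. CLIPProcessor also
# accepts a list of images, so a batched variant like the one below could cut startup
# time; the batch size of 32 is an assumption. The function is defined but not called.
def embed_images_batched(batch_size=32):
    vectors = []
    for start in range(0, N, batch_size):
        batch = [dataset[j]['image'].convert("RGB") for j in range(start, min(start + batch_size, N))]
        inputs = processor(images=batch, return_tensors="pt").to(device)
        with torch.no_grad():
            feats = model.get_image_features(**inputs)
        vectors.append(feats.cpu().numpy())
    return np.concatenate(vectors, axis=0)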

# Find similar images
def find_similar(user_image, top_k=3, exclude_index=None):
    """Return the top_k dataset images closest to user_image in CLIP embedding space."""
    inputs = processor(images=user_image.convert("RGB"), return_tensors="pt").to(device)
    with torch.no_grad():
        query_vec = model.get_image_features(**inputs).cpu().numpy()
    sims = cosine_similarity(query_vec, image_vectors)[0]
    if exclude_index is not None:
        sims[exclude_index] = -1  # suppress an exact dataset match so it is not recommended back
    top_idx = sims.argsort()[-top_k:][::-1]
    return [dataset[image_indices[i]]['image'] for i in top_idx]

# Load Stable Diffusion
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    low_cpu_mem_usage=True
).to(device)
pipe.enable_attention_slicing()

def generate_outfits(input_image, n=10):
    """Generate n img2img variations of the uploaded clothing item."""
    prompt = "fashion outfit design inspired by the clothing item"
    init_image = input_image.resize((512, 512))
    generated_images = []
    for _ in range(n):
        result = pipe(prompt=prompt, image=init_image, strength=0.7, guidance_scale=7.5)
        generated_images.append(result.images[0])
    return generated_images
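
# Optional sketch (not used by the app): passing a seeded torch.Generator to the
# diffusers pipeline makes the generated outfits reproducible across runs; the base
# seed of 42 is an arbitrary assumption.
def generate_outfits_seeded(input_image, n=10, seed=42):
    init_image = input_image.resize((512, 512))
    images = []
    for k in range(n):
        generator = torch.Generator(device=device).manual_seed(seed + k)
        result = pipe(
            prompt="fashion outfit design inspired by the clothing item",
            image=init_image,
            strength=0.7,
            guidance_scale=7.5,
            generator=generator,
        )
        images.append(result.images[0])
    return images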

def recommend_from_upload(uploaded_image):
    """Return the upload, its 3 nearest dataset items, and 10 generated outfits."""
    uploaded_image = uploaded_image.convert("RGB")

    # If the upload is pixel-identical to a dataset image, remember its index so
    # the similarity search does not simply return the same item.
    closest_idx = None
    for i in range(len(image_indices)):
        dataset_image = dataset[image_indices[i]]['image'].convert("RGB")
        if dataset_image.size == uploaded_image.size and \
                ImageChops.difference(dataset_image, uploaded_image).getbbox() is None:
            closest_idx = i
            break

    similar_imgs = find_similar(uploaded_image, top_k=3, exclude_index=closest_idx)
    generated_imgs = generate_outfits(uploaded_image, n=10)

    # One return value per output component: the input image, the three similar
    # items, and the list of generated images for the gallery.
    return uploaded_image, similar_imgs[0], similar_imgs[1], similar_imgs[2], generated_imgs

# 5 clickable example images (must be uploaded to the repo)
example_paths = [
    ["fashion_examples/example1.jpg"],
    ["fashion_examples/example2.jpg"],
    ["fashion_examples/example3.jpg"],
    ["fashion_examples/example4.jpg"],
    ["fashion_examples/example5.jpg"],
]

# Gradio Interface
demo = gr.Interface(
    fn=recommend_from_upload,
    inputs=gr.Image(type="pil", label="Upload a clothing item"),
    outputs=[
        gr.Image(label="Your Input"),
        gr.Image(label="Similar Item 1"),
        gr.Image(label="Similar Item 2"),
        gr.Image(label="Similar Item 3"),
        # .style(grid=...) has been removed from Gradio; the grid layout is now
        # configured directly on the Gallery component.
        gr.Gallery(label="AI-Generated Outfits (x10)", columns=5, height="auto"),
    ],
    title="👗 Fashion Outfit Recommender",
    description="Upload a clothing image to get 3 similar items from the dataset and 10 AI-generated outfit designs.",
    examples=example_paths
)

if __name__ == "__main__":
    demo.launch()