File size: 3,833 Bytes
b53df71
 
 
 
 
 
 
 
69c19c8
b53df71
1d60bf1
b53df71
44aca10
b53df71
44aca10
 
1d60bf1
b53df71
1d60bf1
 
 
69c19c8
1d60bf1
b53df71
 
1d60bf1
 
b53df71
69c19c8
b53df71
 
69c19c8
b53df71
 
1d60bf1
 
b53df71
1d60bf1
 
44aca10
b53df71
 
 
1d60bf1
b53df71
 
 
 
 
 
 
69c19c8
b53df71
 
 
 
69c19c8
b53df71
 
 
69c19c8
b53df71
69c19c8
b53df71
69c19c8
b53df71
 
69c19c8
 
 
 
 
 
 
b53df71
 
 
 
 
 
 
 
 
 
 
69c19c8
 
 
 
 
 
 
 
 
 
 
 
b53df71
44aca10
b53df71
 
 
 
 
 
 
 
69c19c8
b53df71
 
69c19c8
 
b53df71
 
1d60bf1
 
69c19c8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# -*- coding: utf-8 -*-
"""FinalProject.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1_wYfP0IRdb9fpc2zvbg8IqdXGx1dTo7X
"""

from datasets import load_dataset
from PIL import Image, ImageChops
from transformers import CLIPProcessor, CLIPModel
from sklearn.metrics.pairwise import cosine_similarity
import torch
import numpy as np
import gradio as gr
from diffusers import StableDiffusionImg2ImgPipeline

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load CLIP model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load dataset
dataset = load_dataset("lirus18/deepfashion", split="train")

# Embed a subset of images
image_vectors = []
image_indices = []
N = 500

for i in range(N):
    img = dataset[i]['image'].convert("RGB")
    inputs = processor(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        emb = model.get_image_features(**inputs)
    image_vectors.append(emb.cpu().numpy().squeeze())
    image_indices.append(i)

image_vectors = np.array(image_vectors)

def find_similar(user_image, top_k=3, exclude_index=None):
    """Return the top_k dataset images most similar to user_image.

    Similarity is cosine similarity between CLIP image embeddings.
    exclude_index, when given, is a position in image_vectors whose
    score is suppressed so the query's own dataset copy is never returned.
    """
    batch = processor(images=user_image.convert("RGB"), return_tensors="pt").to(device)
    with torch.no_grad():
        query = model.get_image_features(**batch).cpu().numpy()

    scores = cosine_similarity(query, image_vectors)[0]
    if exclude_index is not None:
        # Cosine similarity is bounded below by -1, so this entry can
        # never reach the top ranks.
        scores[exclude_index] = -1

    ranked = scores.argsort()[::-1][:top_k]
    return [dataset[image_indices[j]]['image'] for j in ranked]

# Stable Diffusion img2img pipeline used to generate outfit variations.
# fp16 halves VRAM on GPU; CPU execution requires fp32.
_sd_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=_sd_dtype,
    low_cpu_mem_usage=True,
).to(device)
# Trade a little speed for a much smaller attention memory footprint.
pipe.enable_attention_slicing()

def generate_outfits(input_image, n=10,
                     prompt="fashion outfit design inspired by the clothing item"):
    """Generate n outfit variations of input_image via Stable Diffusion img2img.

    Args:
        input_image: PIL image of a clothing item (any mode).
        n: number of images to generate.
        prompt: text prompt steering the generation (defaults to the
            original hard-coded prompt, so existing callers are unaffected).

    Returns:
        List of n generated PIL images.
    """
    # SD v1.5 expects a 3-channel 512x512 input. Convert defensively:
    # direct callers may pass RGBA/P-mode uploads, which would otherwise
    # fail inside the pipeline.
    init_image = input_image.convert("RGB").resize((512, 512))

    generated_images = []
    for _ in range(n):
        # strength=0.7 keeps the source garment's overall structure while
        # allowing a visible redesign; guidance_scale=7.5 is the SD default.
        result = pipe(prompt=prompt, image=init_image, strength=0.7, guidance_scale=7.5)
        generated_images.append(result.images[0])

    return generated_images

def recommend_from_upload(uploaded_image):
    """Gradio handler: echo the upload, find 3 similar items, generate 10 outfits.

    Returns a 5-element list matching the interface's output components:
    [input image, similar 1, similar 2, similar 3, list of generated images].
    """
    uploaded_image = uploaded_image.convert("RGB")

    # If the upload is an exact copy of an embedded dataset image, remember
    # its position so find_similar() does not recommend the query itself.
    closest_idx = None
    for i in range(len(image_indices)):
        dataset_image = dataset[image_indices[i]]['image'].convert("RGB")
        # ImageChops.difference() raises ValueError on size mismatch, so
        # only pixel-compare images of identical dimensions.
        if dataset_image.size != uploaded_image.size:
            continue
        if ImageChops.difference(dataset_image, uploaded_image).getbbox() is None:
            closest_idx = i
            break

    similar_imgs = find_similar(uploaded_image, top_k=3, exclude_index=closest_idx)
    generated_imgs = generate_outfits(uploaded_image, n=10)

    # The Gallery component takes the 10 generated images as a single list
    # element; returning a flat 14-item list would not match the 5 declared
    # output components and make Gradio error out.
    return [uploaded_image] + similar_imgs + [generated_imgs]

# 5 clickable example images (must be uploaded to the repo); each entry is a
# one-element list because the interface takes a single input component.
example_paths = [[f"fashion_examples/example{idx}.jpg"] for idx in range(1, 6)]

# Gradio Interface: one image in, four images plus a gallery out.
demo = gr.Interface(
    fn=recommend_from_upload,
    inputs=gr.Image(type="pil", label="Upload a clothing item"),
    outputs=[
        gr.Image(label="Your Input"),
        gr.Image(label="Similar Item 1"),
        gr.Image(label="Similar Item 2"),
        gr.Image(label="Similar Item 3"),
        # Gallery.style() was deprecated in Gradio 3.4x and removed in 4.x;
        # layout options are constructor keyword arguments now.
        gr.Gallery(label="AI-Generated Outfits (x10)", columns=5, height="auto"),
    ],
    title="👗 Fashion Outfit Recommender",
    description="Upload a clothing image to get 3 similar items from the dataset and 10 AI-generated outfit designs.",
    examples=example_paths
)

if __name__ == "__main__":
    demo.launch()