itayitay123 commited on
Commit
a142524
·
verified ·
1 Parent(s): 3832834

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -17
app.py CHANGED
@@ -14,6 +14,7 @@ from sklearn.metrics.pairwise import cosine_similarity
14
  import torch
15
  import numpy as np
16
  import gradio as gr
 
17
 
18
  # Device setup
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -28,7 +29,7 @@ dataset = load_dataset("lirus18/deepfashion", split="train")
28
  # Embed a subset of dataset images
29
  image_vectors = []
30
  image_indices = []
31
- N = 500 # You can increase this later if needed
32
 
33
  for i in range(N):
34
  img = dataset[i]['image'].convert("RGB")
@@ -51,13 +52,33 @@ def find_similar(user_image, top_k=3, exclude_index=None):
51
  sims[exclude_index] = -1
52
 
53
  top_idx = sims.argsort()[-top_k:][::-1]
54
- return [dataset[image_indices[i]]['image'] for i in top_idx]
55
 
56
- # Main function (no generation, just simulation)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def recommend_from_upload(uploaded_image):
58
  uploaded_image = uploaded_image.convert("RGB")
59
 
60
- # Check if uploaded image is already in dataset
61
  closest_idx = None
62
  for i in range(len(image_indices)):
63
  dataset_image = dataset[image_indices[i]]['image'].convert("RGB")
@@ -65,15 +86,27 @@ def recommend_from_upload(uploaded_image):
65
  closest_idx = i
66
  break
67
 
68
- # Get 3 similar images
69
- similar_imgs = find_similar(uploaded_image, top_k=3, exclude_index=closest_idx)
 
 
 
70
 
71
- # Use a gray image as the "best generated outfit"
72
- placeholder_img = Image.new("RGB", (512, 512), color="gray")
 
 
 
 
 
 
 
 
 
73
 
74
- return [uploaded_image] + similar_imgs + [placeholder_img]
75
 
76
- # Example image paths (must be in root folder)
77
  example_paths = [
78
  ["example1.jpg"],
79
  ["example2.jpg"],
@@ -82,10 +115,10 @@ example_paths = [
82
  ["example5.jpg"]
83
  ]
84
 
85
- # Gradio interface with button
86
  with gr.Blocks() as demo:
87
  gr.Markdown("## 👗 Fashion Outfit Recommender")
88
- gr.Markdown("Upload a clothing image to get 3 similar items from the dataset and 1 simulated AI-generated outfit.")
89
 
90
  with gr.Row():
91
  image_input = gr.Image(type="pil", label="Upload a clothing item")
@@ -93,11 +126,11 @@ with gr.Blocks() as demo:
93
  generate_btn = gr.Button("Generate Recommendations")
94
 
95
  with gr.Row():
96
- output1 = gr.Image(label="Your Input")
97
- output2 = gr.Image(label="Similar Item 1")
98
- output3 = gr.Image(label="Similar Item 2")
99
- output4 = gr.Image(label="Similar Item 3")
100
- output5 = gr.Image(label="Simulated AI-Generated Outfit")
101
 
102
  examples = gr.Examples(
103
  examples=example_paths,
@@ -113,3 +146,4 @@ if __name__ == "__main__":
113
  demo.launch()
114
 
115
 
 
 
14
  import torch
15
  import numpy as np
16
  import gradio as gr
17
+ from diffusers import StableDiffusionImg2ImgPipeline
18
 
19
  # Device setup
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
29
  # Embed a subset of dataset images
30
  image_vectors = []
31
  image_indices = []
32
+ N = 500
33
 
34
  for i in range(N):
35
  img = dataset[i]['image'].convert("RGB")
 
52
  sims[exclude_index] = -1
53
 
54
  top_idx = sims.argsort()[-top_k:][::-1]
55
+ return [dataset[image_indices[i]]['image'] for i in top_idx], query_vec
56
 
57
+ # Load Stable Diffusion
58
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
59
+ "runwayml/stable-diffusion-v1-5",
60
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
61
+ low_cpu_mem_usage=True
62
+ ).to(device)
63
+ pipe.enable_attention_slicing()
64
+
65
+ # Generate 10 images
66
+ def generate_outfits(input_image, n=10):
67
+ prompt = "fashion outfit design inspired by the clothing item"
68
+ init_image = input_image.resize((512, 512))
69
+ generated_images = []
70
+
71
+ for _ in range(n):
72
+ result = pipe(prompt=prompt, image=init_image, strength=0.7, guidance_scale=7.5)
73
+ generated_images.append(result.images[0])
74
+
75
+ return generated_images
76
+
77
+ # Main function
78
  def recommend_from_upload(uploaded_image):
79
  uploaded_image = uploaded_image.convert("RGB")
80
 
81
+ # Check for duplicates
82
  closest_idx = None
83
  for i in range(len(image_indices)):
84
  dataset_image = dataset[image_indices[i]]['image'].convert("RGB")
 
86
  closest_idx = i
87
  break
88
 
89
+ # Find similar items
90
+ similar_imgs, query_vec = find_similar(uploaded_image, top_k=3, exclude_index=closest_idx)
91
+
92
+ # Generate 10 new outfits
93
+ generated_imgs = generate_outfits(uploaded_image, n=10)
94
 
95
+ # Select best match
96
+ best_score = -1
97
+ best_img = None
98
+ for img in generated_imgs:
99
+ inputs = processor(images=img, return_tensors="pt").to(device)
100
+ with torch.no_grad():
101
+ emb = model.get_image_features(**inputs).cpu().numpy()
102
+ sim = cosine_similarity(query_vec, emb)[0][0]
103
+ if sim > best_score:
104
+ best_score = sim
105
+ best_img = img
106
 
107
+ return [uploaded_image] + similar_imgs + [best_img]
108
 
109
+ # Example paths
110
  example_paths = [
111
  ["example1.jpg"],
112
  ["example2.jpg"],
 
115
  ["example5.jpg"]
116
  ]
117
 
118
+ # Gradio UI
119
  with gr.Blocks() as demo:
120
  gr.Markdown("## 👗 Fashion Outfit Recommender")
121
+ gr.Markdown("Upload a clothing image to get 3 similar items from the dataset and 1 AI-generated outfit design.")
122
 
123
  with gr.Row():
124
  image_input = gr.Image(type="pil", label="Upload a clothing item")
 
126
  generate_btn = gr.Button("Generate Recommendations")
127
 
128
  with gr.Row():
129
+ output1 = gr.Image(label="Your Input", height=512, width=384)
130
+ output2 = gr.Image(label="Similar Item 1", height=512, width=384)
131
+ output3 = gr.Image(label="Similar Item 2", height=512, width=384)
132
+ output4 = gr.Image(label="Similar Item 3", height=512, width=384)
133
+ output5 = gr.Image(label="AI-Generated Outfit", height=512, width=384)
134
 
135
  examples = gr.Examples(
136
  examples=example_paths,
 
146
  demo.launch()
147
 
148
 
149
+