itayitay123 commited on
Commit
cbd13a7
·
verified ·
1 Parent(s): eabc0b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -46
app.py CHANGED
@@ -14,7 +14,6 @@ from sklearn.metrics.pairwise import cosine_similarity
14
  import torch
15
  import numpy as np
16
  import gradio as gr
17
- from diffusers import StableDiffusionImg2ImgPipeline
18
 
19
  # Device setup
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -29,7 +28,7 @@ dataset = load_dataset("lirus18/deepfashion", split="train")
29
  # Embed a subset of dataset images
30
  image_vectors = []
31
  image_indices = []
32
- N = 500
33
 
34
  for i in range(N):
35
  img = dataset[i]['image'].convert("RGB")
@@ -52,33 +51,13 @@ def find_similar(user_image, top_k=3, exclude_index=None):
52
  sims[exclude_index] = -1
53
 
54
  top_idx = sims.argsort()[-top_k:][::-1]
55
- return [dataset[image_indices[i]]['image'] for i in top_idx], query_vec
56
 
57
- # Load Stable Diffusion
58
- pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
59
- "runwayml/stable-diffusion-v1-5",
60
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
61
- low_cpu_mem_usage=True
62
- ).to(device)
63
- pipe.enable_attention_slicing()
64
-
65
- # Generate outfits (2 only)
66
- def generate_outfits(input_image, n=2):
67
- prompt = "fashion outfit design inspired by the clothing item"
68
- init_image = input_image.resize((512, 512))
69
- generated_images = []
70
-
71
- for _ in range(n):
72
- result = pipe(prompt=prompt, image=init_image, strength=0.7, guidance_scale=7.5)
73
- generated_images.append(result.images[0])
74
-
75
- return generated_images
76
-
77
- # Main recommendation function
78
  def recommend_from_upload(uploaded_image):
79
  uploaded_image = uploaded_image.convert("RGB")
80
 
81
- # Check if the uploaded image exists in dataset
82
  closest_idx = None
83
  for i in range(len(image_indices)):
84
  dataset_image = dataset[image_indices[i]]['image'].convert("RGB")
@@ -86,27 +65,15 @@ def recommend_from_upload(uploaded_image):
86
  closest_idx = i
87
  break
88
 
89
- # Get 3 similar items + embedding
90
- similar_imgs, query_vec = find_similar(uploaded_image, top_k=3, exclude_index=closest_idx)
91
-
92
- # Generate 2 synthetic outfits
93
- generated_imgs = generate_outfits(uploaded_image, n=2)
94
 
95
- # Select most similar generated image
96
- best_score = -1
97
- best_generated_img = None
98
- for img in generated_imgs:
99
- inputs = processor(images=img, return_tensors="pt").to(device)
100
- with torch.no_grad():
101
- emb = model.get_image_features(**inputs).cpu().numpy()
102
- sim = cosine_similarity(query_vec, emb)[0][0]
103
- if sim > best_score:
104
- best_score = sim
105
- best_generated_img = img
106
 
107
- return [uploaded_image] + similar_imgs + [best_generated_img]
108
 
109
- # Example image paths (must exist in root folder)
110
  example_paths = [
111
  ["example1.jpg"],
112
  ["example2.jpg"],
@@ -115,10 +82,10 @@ example_paths = [
115
  ["example5.jpg"]
116
  ]
117
 
118
- # Gradio Interface with button
119
  with gr.Blocks() as demo:
120
  gr.Markdown("## 👗 Fashion Outfit Recommender")
121
- gr.Markdown("Upload a clothing image to get 3 similar items from the dataset and 1 AI-generated outfit design.")
122
 
123
  with gr.Row():
124
  image_input = gr.Image(type="pil", label="Upload a clothing item")
@@ -130,7 +97,7 @@ with gr.Blocks() as demo:
130
  output2 = gr.Image(label="Similar Item 1")
131
  output3 = gr.Image(label="Similar Item 2")
132
  output4 = gr.Image(label="Similar Item 3")
133
- output5 = gr.Image(label="Best AI-Generated Outfit")
134
 
135
  examples = gr.Examples(
136
  examples=example_paths,
@@ -145,3 +112,4 @@ with gr.Blocks() as demo:
145
  if __name__ == "__main__":
146
  demo.launch()
147
 
 
 
14
  import torch
15
  import numpy as np
16
  import gradio as gr
 
17
 
18
  # Device setup
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
28
  # Embed a subset of dataset images
29
  image_vectors = []
30
  image_indices = []
31
+ N = 500 # You can increase this later if needed
32
 
33
  for i in range(N):
34
  img = dataset[i]['image'].convert("RGB")
 
51
  sims[exclude_index] = -1
52
 
53
  top_idx = sims.argsort()[-top_k:][::-1]
54
+ return [dataset[image_indices[i]]['image'] for i in top_idx]
55
 
56
+ # Main function (no generation, just simulation)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def recommend_from_upload(uploaded_image):
58
  uploaded_image = uploaded_image.convert("RGB")
59
 
60
+ # Check if uploaded image is already in dataset
61
  closest_idx = None
62
  for i in range(len(image_indices)):
63
  dataset_image = dataset[image_indices[i]]['image'].convert("RGB")
 
65
  closest_idx = i
66
  break
67
 
68
+ # Get 3 similar images
69
+ similar_imgs = find_similar(uploaded_image, top_k=3, exclude_index=closest_idx)
 
 
 
70
 
71
+ # Use a gray image as the "best generated outfit"
72
+ placeholder_img = Image.new("RGB", (512, 512), color="gray")
 
 
 
 
 
 
 
 
 
73
 
74
+ return [uploaded_image] + similar_imgs + [placeholder_img]
75
 
76
+ # Example image paths (must be in root folder)
77
  example_paths = [
78
  ["example1.jpg"],
79
  ["example2.jpg"],
 
82
  ["example5.jpg"]
83
  ]
84
 
85
+ # Gradio interface with button
86
  with gr.Blocks() as demo:
87
  gr.Markdown("## 👗 Fashion Outfit Recommender")
88
+ gr.Markdown("Upload a clothing image to get 3 similar items from the dataset and 1 simulated AI-generated outfit.")
89
 
90
  with gr.Row():
91
  image_input = gr.Image(type="pil", label="Upload a clothing item")
 
97
  output2 = gr.Image(label="Similar Item 1")
98
  output3 = gr.Image(label="Similar Item 2")
99
  output4 = gr.Image(label="Similar Item 3")
100
+ output5 = gr.Image(label="Simulated AI-Generated Outfit")
101
 
102
  examples = gr.Examples(
103
  examples=example_paths,
 
112
  if __name__ == "__main__":
113
  demo.launch()
114
 
115
+