bobs24 committed on
Commit
68f893f
·
0 Parent(s):

track product with LFS

Browse files
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ data/product_data.csv filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ env/
2
+ __pycache__/
3
+ *.pyc
4
+ *.pyo
5
+ *.pyd
6
+ *.DS_Store
7
+ data/embeddings.npy
8
+ data/image_urls.pkl
Dockerfile ADDED
File without changes
README.md ADDED
File without changes
data/product_data.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5fd4ad76f4783518e02888845949d8eb4ecffe69860266d3fc576d2d27cadca4
3
+ size 30552758
main.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import uvicorn
import numpy as np
import pandas as pd
import pickle
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from io import BytesIO
from model.feature_extractor import FeatureExtractor
from utils.faiss_index import FaissIndex

import os
# Work around the "duplicate OpenMP runtime" abort that can occur when both
# torch and faiss load libomp (common on macOS / conda environments).
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

app = FastAPI()

# Load model and data once at import time; these module-level globals are
# shared by every request the FastAPI app handles.
# `embeddings` and `image_urls` are produced by precompute_embeddings.py and
# are expected to be row-aligned (embedding i belongs to image_urls[i]).
embeddings = np.load("data/embeddings.npy")
with open("data/image_urls.pkl", "rb") as f:
    image_urls = pickle.load(f)
product_data = pd.read_csv("data/product_data.csv")

# Feature extractor for incoming query images, plus a FAISS index over the
# precomputed catalogue embeddings for nearest-neighbour search.
fe = FeatureExtractor()
index = FaissIndex(dim=embeddings.shape[1])
index.build(embeddings, image_urls)
26
+
27
@app.post("/recommend")
async def recommend(file: UploadFile = File(...), threshold: float = 0.8, k: int = 100):
    """Recommend visually similar products for an uploaded image.

    Args:
        file: uploaded product photo (any PIL-readable format).
        threshold: minimum cosine similarity for a FAISS hit to be kept.
        k: number of nearest neighbours to retrieve before filtering.

    Returns:
        {"recommendations": [...]} with at most 15 entries, a 404 JSON message
        when nothing clears the threshold, or a 500 JSON error on failure.
    """
    try:
        image = Image.open(BytesIO(await file.read())).convert("RGB")
        user_emb = fe.extract(image)
        results = index.search(user_emb, threshold=threshold, k=k)

        if not results:
            return JSONResponse({"message": "No similar products found"}, status_code=404)

        # Build one O(n) lookup table instead of re-filtering the whole
        # DataFrame for every candidate URL (the old code did that twice per
        # result). drop_duplicates keeps the first row per IMAGE, matching
        # the previous `.values[0]` behaviour on duplicate URLs.
        meta = product_data.drop_duplicates(subset="IMAGE").set_index("IMAGE")

        def _field(url, column):
            # Metadata value for `url`, or None when the URL or column is
            # unknown to the CSV.
            if column not in meta.columns or url not in meta.index:
                return None
            return meta.at[url, column]

        # Treat the best hit as "the uploaded product" and fetch its metadata
        # so near-identical variants can be filtered out below.
        input_url = results[0][0]
        input_group_id = _field(input_url, "GROUP_ID")
        input_product_name = _field(input_url, "PRODUCT_NAME")

        # Filtering logic: with no usable GROUP_ID, drop candidates sharing
        # the query's product name; otherwise drop candidates from the same
        # group (likely colour/size variants of the query item).
        filtered = []
        for url, sim in results:
            group_id = _field(url, "GROUP_ID")
            product_name = _field(url, "PRODUCT_NAME")

            if input_group_id is None or input_group_id == 0:
                if product_name != input_product_name:
                    filtered.append((url, sim))
            elif group_id != input_group_id:
                filtered.append((url, sim))

        # De-duplicate by product name, keeping the highest-similarity hit
        # (results arrive sorted best-first from the index).
        seen = set()
        final = []
        for url, sim in filtered:
            product_name = _field(url, "PRODUCT_NAME")
            if product_name and product_name not in seen:
                seen.add(product_name)
                # BUGFIX: the old code indexed BRAND_NAME without checking the
                # row existed, raising IndexError (-> HTTP 500) for URLs that
                # are absent from the CSV. _field returns None in that case.
                brand_name = _field(url, "BRAND_NAME")
                if brand_name is None:
                    brand_name = "Unknown"
                final.append({
                    "brand_name": brand_name,
                    "product_name": product_name,
                    "image_url": url,
                    "similarity_score": float(f"{sim:.4f}"),
                })

        return {"recommendations": final[:15]}

    except Exception as e:
        # Boundary handler: surface any failure as a JSON 500 rather than a
        # raw traceback to the client.
        return JSONResponse({"error": str(e)}, status_code=500)
77
+
78
+
79
if __name__ == "__main__":
    # Development entry point; reload=True auto-restarts on code changes
    # (use a dedicated server invocation in production).
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
model/feature_extractor.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.models as models
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ import numpy as np
6
+
7
class FeatureExtractor:
    """Extract L2-normalised 2048-d ResNet50 embeddings from PIL images."""

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Load an ImageNet-pretrained ResNet50.  The `pretrained=` flag has
        # been deprecated since torchvision 0.13 in favour of the weights
        # enum; fall back for older torchvision versions that lack it.
        try:
            resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        except AttributeError:
            resnet = models.resnet50(pretrained=True)
        # Drop the final fully connected classifier; the remaining stack ends
        # with global average pooling, producing a (1, 2048, 1, 1) feature map.
        self.model = torch.nn.Sequential(*list(resnet.children())[:-1])
        self.model.eval().to(self.device)

        # Standard ImageNet preprocessing (matches the pretrained weights).
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def extract(self, image: Image.Image) -> np.ndarray:
        """Return a unit-norm float vector of shape (2048,) for `image`.

        The caller is expected to pass an RGB PIL image.
        """
        image = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            features = self.model(image)
        features = features.squeeze().cpu().numpy()
        features = features.reshape(-1)  # flatten (2048,)

        # Normalise to a unit vector so inner products equal cosine
        # similarity (FaissIndex uses IndexFlatIP).
        norm = np.linalg.norm(features)
        if norm > 0:
            features = features / norm
        return features
precompute_embeddings.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import pickle
4
+ from tqdm import tqdm
5
+ from model.feature_extractor import FeatureExtractor
6
+ from utils.image_utils import load_image_from_url
7
+
8
def main():
    """Embed every product image listed in data/product_data.csv.

    Writes data/embeddings.npy (one row per successfully fetched image) and
    data/image_urls.pkl (the matching URLs, in the same order).
    """
    catalog = pd.read_csv("data/product_data.csv")
    extractor = FeatureExtractor()

    vectors = []
    kept_urls = []
    # tqdm wraps the iterable and shows a progress bar while downloading.
    for image_url in tqdm(catalog['IMAGE_URL'], desc="Extracting embeddings"):
        picture = load_image_from_url(image_url)
        if picture is None:
            # Download/decode failed; skip so vectors and URLs stay aligned.
            continue
        vectors.append(extractor.extract(picture))
        kept_urls.append(image_url)

    np.save("data/embeddings.npy", np.array(vectors))

    with open("data/image_urls.pkl", "wb") as f:
        pickle.dump(kept_urls, f)

    print(f"Saved {len(kept_urls)} embeddings and URLs.")


if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ torchvision
4
+ timm
5
+ faiss-cpu
6
+ pandas
7
+ Pillow
8
+ requests
9
+ tqdm
10
+ numpy
11
+ fastapi
12
+ uvicorn
streamlit_app.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
3
+
4
+ import streamlit as st
5
+ from model.feature_extractor import FeatureExtractor
6
+ from utils.faiss_index import FaissIndex
7
+ from PIL import Image
8
+ import pandas as pd
9
+ import numpy as np
10
+ import pickle
11
+ import streamlit.components.v1 as components
12
+
13
st.set_page_config(page_title="🛍️ Product Recommender", layout="wide")

@st.cache_resource
def load_resources():
    # Cached once per Streamlit process: the embedding matrix, its aligned
    # URL list, product metadata, the feature extractor, and a FAISS index
    # built over the embeddings. Script re-runs reuse the same objects.
    embeddings = np.load("data/embeddings.npy")
    with open("data/image_urls.pkl", "rb") as f:
        image_urls = pickle.load(f)
    product_data = pd.read_csv("data/product_data.csv")
    fe = FeatureExtractor()
    index = FaissIndex(dim=embeddings.shape[1])
    index.build(embeddings, image_urls)
    return fe, index, image_urls, product_data

fe, index, image_urls, product_data = load_resources()
27
+
28
st.title("🛍️ Product Image Recommender")

uploaded_file = st.file_uploader("Upload a product image", type=["jpg", "jpeg", "png"])

if uploaded_file:
    user_img = Image.open(uploaded_file).convert("RGB")
    st.image(user_img, caption="Uploaded Image", width=250)

    # Embed the uploaded image and fetch up to 100 catalogue neighbours with
    # cosine similarity >= 0.8 (FaissIndex returns them sorted best-first).
    user_emb = fe.extract(user_img)
    results = index.search(user_emb, threshold=0.8, k=100)

    if len(results) > 0:
        # Treat the best match as "the uploaded product" and look up its
        # metadata so near-identical variants can be filtered out below.
        # NOTE(review): rows are matched on the 'IMAGE' column here, while
        # precompute_embeddings.py iterates df['IMAGE_URL'] — confirm the CSV
        # has both columns (or the same URLs), otherwise these lookups come
        # back empty and every candidate passes the filters.
        input_image_url = results[0][0]

        # Get GROUP_ID of uploaded image
        input_group_id_series = product_data.loc[product_data['IMAGE'] == input_image_url, 'GROUP_ID']
        input_group_id = input_group_id_series.values[0] if not input_group_id_series.empty else None

        # Get PRODUCT_NAME of uploaded image
        input_product_name_series = product_data.loc[product_data['IMAGE'] == input_image_url, 'PRODUCT_NAME']
        input_product_name = input_product_name_series.values[0] if not input_product_name_series.empty else None

        # st.markdown(f"**GROUP_ID of uploaded image:** `{input_group_id}`")

        # Drop candidates that are effectively the same item as the query:
        # with no/zero GROUP_ID compare product names, otherwise group ids.
        filtered_results = []
        for url, sim in results:
            group_id_series = product_data.loc[product_data['IMAGE'] == url, 'GROUP_ID']
            group_id = group_id_series.values[0] if not group_id_series.empty else None

            product_series = product_data.loc[product_data['IMAGE'] == url, 'PRODUCT_NAME']
            product_name = product_series.values[0] if not product_series.empty else None

            # Rule: if GROUP_ID is None or 0, exclude same product name
            if (input_group_id is None or input_group_id == 0):
                if product_name != input_product_name:
                    filtered_results.append((url, sim))
            else:
                if group_id != input_group_id:
                    filtered_results.append((url, sim))

        # Keep only the first (highest-similarity) hit per product name.
        seen_products = set()
        deduped_results = []
        for url, sim in filtered_results:
            product_series = product_data.loc[product_data['IMAGE'] == url, 'PRODUCT_NAME']
            product_name = product_series.values[0] if not product_series.empty else None
            if product_name and product_name not in seen_products:
                seen_products.add(product_name)
                deduped_results.append((url, sim))

        # Show at most 15 cards in the carousel.
        top_results = deduped_results[:15]

        # Build one HTML card per recommendation; rendered through a raw HTML
        # component below so we get a horizontally scrolling carousel.
        cards_html = ""
        for url, sim in top_results:
            brand = product_data.loc[product_data['IMAGE'] == url, 'BRAND_NAME'].values
            product = product_data.loc[product_data['IMAGE'] == url, 'PRODUCT_NAME'].values
            brand_name = brand[0] if len(brand) > 0 else "Unknown Brand"
            product_name = product[0] if len(product) > 0 else "Unknown Product"
            cards_html += f"""
            <div class="card">
                <img src="{url}" alt="Product Image"/>
                <div class="info">
                    <h4>{brand_name}</h4>
                    <p>{product_name}</p>
                    <span>Similarity: {sim:.2f}</span>
                </div>
            </div>
            """

        # Inline CSS + card markup for the carousel; kept as one template so
        # components.html() can render it in a single sandboxed iframe.
        full_html = f"""
        <style>
        .carousel-wrapper {{
            overflow-x: auto;
            overflow-y: visible; /* allow vertical overflow if any */
            white-space: nowrap;
            padding: 20px 16px 40px 16px;
            height: auto;
            scroll-behavior: smooth;
        }}
        .carousel {{
            display: flex;
            gap: 10px;
            align-items: stretch; /* all cards same height */
        }}
        .card {{
            flex: 0 0 auto;
            width: 280px; /* 1.5x wider */
            /* no fixed height */
            border: 1px solid #ddd;
            border-radius: 14px;
            padding: 14px;
            background: #fff;
            box-shadow: 0 4px 12px rgba(0,0,0,0.1);
            text-align: center;
            box-sizing: border-box;
            transition: transform 0.2s ease-in-out;
            font-family: "Segoe UI", sans-serif;
        }}
        .card:hover {{
            transform: scale(1.04);
            box-shadow: 0 6px 16px rgba(0,0,0,0.12);
        }}
        .card img {{
            width: 100%;
            height: 300px; /* 1.5x taller */
            object-fit: cover;
            border-radius: 8px;
        }}
        .info h4 {{
            font-size: 20px;
            margin: 12px 0 6px;
            color: #222;
            white-space: normal;
        }}
        .info p {{
            font-size: 16px;
            margin: 0 0 8px;
            color: #555;
            white-space: normal;
        }}
        .info span {{
            font-size: 13px;
            color: #888;
        }}
        </style>

        <div class="carousel-wrapper">
            <div class="carousel">
                {cards_html}
            </div>
        </div>
        """

        st.subheader("🔍 Recommended Products")
        # Fixed-height iframe; horizontal scrolling is handled by the CSS
        # above rather than by the component itself.
        components.html(full_html, height=600, scrolling=False)

    else:
        st.info("✨ No visually similar items found — this might be a one-of-a-kind product!")
utils/faiss_index.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import faiss
2
+ import numpy as np
3
+
4
class FaissIndex:
    """Thin wrapper around a flat (exact) inner-product FAISS index.

    Embeddings are L2-normalised before insertion, so inner product equals
    cosine similarity.
    """

    def __init__(self, dim):
        # Exact, non-approximate inner-product search over `dim`-d vectors.
        self.index = faiss.IndexFlatIP(dim)
        # Parallel list mapping FAISS row position -> image URL/id.
        self.image_map = []

    def build(self, embeddings, image_ids):
        """Add `embeddings` (N x dim array) and their `image_ids` to the index."""
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        # Epsilon guards against division by zero for all-zero rows.
        normalized_embeddings = embeddings / (norms + 1e-10)
        self.index.add(normalized_embeddings.astype('float32'))
        self.image_map = image_ids

    def search(self, query_vector, threshold=0.8, k=50):
        """Return [(image_id, similarity), ...] sorted best-first.

        Only hits with similarity >= threshold are kept; at most k results.
        """
        query_norm = np.linalg.norm(query_vector)
        if query_norm > 0:
            query_vector = query_vector / query_norm
        query = np.array([query_vector]).astype('float32')
        similarities, indices = self.index.search(query, k)

        results = []
        for i, sim in zip(indices[0], similarities[0]):
            # BUGFIX: when k exceeds the number of indexed vectors, FAISS pads
            # the result with index -1; the old code silently mapped that to
            # the LAST element of image_map. Skip padding entries.
            if i < 0:
                continue
            if sim >= threshold:
                results.append((self.image_map[i], sim))
        return sorted(results, key=lambda x: -x[1])
utils/image_utils.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from io import BytesIO
3
+ from PIL import Image
4
+
5
+ def load_image_from_url(url):
6
+ try:
7
+ response = requests.get(url, timeout=10)
8
+ response.raise_for_status()
9
+ image = Image.open(BytesIO(response.content)).convert('RGB')
10
+ return image
11
+ except Exception as e:
12
+ print(f"Failed to load image from {url}: {e}")
13
+ return None