zen-vton committed on
Commit
1fccc5c
Β·
verified Β·
1 Parent(s): ad9b761

Upload 11 files

Browse files
Files changed (11) hide show
  1. .gitignore +16 -0
  2. api_server.py +1377 -0
  3. check.py +365 -0
  4. fix.py +270 -0
  5. gradio_app.py +259 -0
  6. miss.py +421 -0
  7. path.py +141 -0
  8. requirements.txt +28 -0
  9. synonyms.py +853 -365
  10. train_products.py +421 -0
  11. validation_data.py +310 -0
.gitignore ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.py[cod]
3
+ .Python
4
+ venv/
5
+ env/
6
+ .vscode/
7
+ .idea/
8
+ .DS_Store
9
+ *.bin
10
+ *.safetensors
11
+ *.log
12
+ cache/*.faiss
13
+ cache/*.npy
14
+ !cache/metadata.pkl
15
+ !cache/model_info.json
16
+ !cache/cross_store_synonyms.pkl
api_server.py ADDED
@@ -0,0 +1,1377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ο»Ώ# """
2
+ # 🎯 COMPLETE API SERVER - Matches Cross-Store Training System
3
+ # =============================================================
4
+ # βœ… Works with cross-store synonyms (washing machine = laundry machine)
5
+ # βœ… Uses auto-tags from training
6
+ # βœ… Single model (fast predictions)
7
+ # βœ… Guaranteed category_id match
8
+ # βœ… Real-time classification
9
+ # """
10
+
11
+ # from flask import Flask, request, jsonify, render_template_string
12
+ # from sentence_transformers import SentenceTransformer
13
+ # import faiss
14
+ # import pickle
15
+ # import numpy as np
16
+ # from pathlib import Path
17
+ # import time
18
+ # import re
19
+
20
+ # app = Flask(__name__)
21
+
22
+ # # ============================================================================
23
+ # # GLOBAL VARIABLES
24
+ # # ============================================================================
25
+
26
+ # CACHE_DIR = Path('cache')
27
+
28
+ # # Model
29
+ # encoder = None
30
+ # faiss_index = None
31
+ # metadata = []
32
+ # cross_store_synonyms = {}
33
+
34
+
35
+ # # ============================================================================
36
+ # # CROSS-STORE SYNONYM DATABASE (Same as training)
37
+ # # ============================================================================
38
+
39
+ # def build_cross_store_synonyms():
40
+ # """Build cross-store synonym database"""
41
+ # synonyms = {
42
+ # # Appliances
43
+ # 'washing machine': {'laundry machine', 'washer', 'clothes washer', 'washing appliance'},
44
+ # 'laundry machine': {'washing machine', 'washer', 'clothes washer'},
45
+ # 'dryer': {'drying machine', 'clothes dryer', 'tumble dryer'},
46
+ # 'refrigerator': {'fridge', 'cooler', 'ice box', 'cooling appliance'},
47
+ # 'dishwasher': {'dish washer', 'dish cleaning machine'},
48
+ # 'microwave': {'microwave oven', 'micro wave'},
49
+ # 'vacuum': {'vacuum cleaner', 'hoover', 'vac'},
50
+
51
+ # # Electronics
52
+ # 'tv': {'television', 'telly', 'smart tv', 'display'},
53
+ # 'laptop': {'notebook', 'portable computer', 'laptop computer'},
54
+ # 'mobile': {'phone', 'cell phone', 'smartphone', 'cellphone'},
55
+ # 'tablet': {'ipad', 'tab', 'tablet computer'},
56
+ # 'headphones': {'headset', 'earphones', 'earbuds', 'ear buds'},
57
+ # 'speaker': {'audio speaker', 'sound system', 'speakers'},
58
+
59
+ # # Furniture
60
+ # 'sofa': {'couch', 'settee', 'divan'},
61
+ # 'wardrobe': {'closet', 'armoire', 'cupboard'},
62
+ # 'drawer': {'chest of drawers', 'dresser'},
63
+
64
+ # # Clothing
65
+ # 'pants': {'trousers', 'slacks', 'bottoms'},
66
+ # 'sweater': {'jumper', 'pullover', 'sweatshirt'},
67
+ # 'sneakers': {'trainers', 'tennis shoes', 'running shoes'},
68
+ # 'jacket': {'coat', 'blazer', 'outerwear'},
69
+
70
+ # # Kitchen
71
+ # 'cooker': {'stove', 'range', 'cooking range'},
72
+ # 'blender': {'mixer', 'food processor', 'liquidizer'},
73
+ # 'kettle': {'electric kettle', 'water boiler'},
74
+
75
+ # # Baby/Kids
76
+ # 'stroller': {'pram', 'pushchair', 'buggy', 'baby carriage'},
77
+ # 'diaper': {'nappy', 'nappies'},
78
+ # 'pacifier': {'dummy', 'soother'},
79
+
80
+ # # Tools
81
+ # 'wrench': {'spanner', 'adjustable wrench'},
82
+ # 'flashlight': {'torch', 'flash light'},
83
+ # 'screwdriver': {'screw driver'},
84
+
85
+ # # Home
86
+ # 'tap': {'faucet', 'water tap'},
87
+ # 'bin': {'trash can', 'garbage can', 'waste bin'},
88
+ # 'curtain': {'drape', 'window covering'},
89
+
90
+ # # Crafts/Office
91
+ # 'guillotine': {'paper cutter', 'paper trimmer', 'blade cutter'},
92
+ # 'trimmer': {'cutter', 'cutting tool', 'edge cutter'},
93
+ # 'stapler': {'stapling machine', 'staple gun'},
94
+
95
+ # # Books/Media
96
+ # 'magazine': {'periodical', 'journal', 'publication'},
97
+ # 'comic': {'comic book', 'graphic novel', 'manga'},
98
+ # 'ebook': {'e-book', 'digital book', 'electronic book'},
99
+
100
+ # # General
101
+ # 'kids': {'children', 'child', 'childrens', 'youth', 'junior'},
102
+ # 'women': {'womens', 'ladies', 'female', 'lady'},
103
+ # 'men': {'mens', 'male', 'gentleman'},
104
+ # 'baby': {'infant', 'newborn', 'toddler'},
105
+ # }
106
+
107
+ # # Build bidirectional mapping
108
+ # expanded = {}
109
+ # for term, syns in synonyms.items():
110
+ # expanded[term] = syns.copy()
111
+ # for syn in syns:
112
+ # if syn not in expanded:
113
+ # expanded[syn] = set()
114
+ # expanded[syn].add(term)
115
+ # expanded[syn].update(syns - {syn})
116
+
117
+ # return expanded
118
+
119
+
120
+ # # ============================================================================
121
+ # # HELPER FUNCTIONS
122
+ # # ============================================================================
123
+
124
+ # def clean_text(text):
125
+ # """Clean and normalize text"""
126
+ # if not text:
127
+ # return ""
128
+ # text = str(text).lower()
129
+ # text = re.sub(r'[^\w\s-]', ' ', text)
130
+ # text = re.sub(r'\s+', ' ', text).strip()
131
+ # return text
132
+
133
+
134
+ # def extract_cross_store_terms(text):
135
+ # """Extract terms with cross-store variations"""
136
+ # cleaned = clean_text(text)
137
+ # words = cleaned.split()
138
+
139
+ # all_terms = set()
140
+ # all_terms.add(cleaned) # Full text
141
+
142
+ # # Single words
143
+ # for word in words:
144
+ # if len(word) > 2:
145
+ # all_terms.add(word)
146
+ # # Add cross-store synonyms
147
+ # if word in cross_store_synonyms:
148
+ # all_terms.update(cross_store_synonyms[word])
149
+
150
+ # # 2-word phrases
151
+ # for i in range(len(words) - 1):
152
+ # if len(words[i]) > 2 and len(words[i+1]) > 2:
153
+ # phrase = f"{words[i]} {words[i+1]}"
154
+ # all_terms.add(phrase)
155
+ # if phrase in cross_store_synonyms:
156
+ # all_terms.update(cross_store_synonyms[phrase])
157
+
158
+ # # 3-word phrases
159
+ # if len(words) >= 3:
160
+ # for i in range(len(words) - 2):
161
+ # if all(len(w) > 2 for w in words[i:i+3]):
162
+ # phrase = f"{words[i]} {words[i+1]} {words[i+2]}"
163
+ # all_terms.add(phrase)
164
+
165
+ # return list(all_terms)
166
+
167
+
168
+ # def build_enhanced_query(title, description=""):
169
+ # """Build enhanced query with cross-store intelligence"""
170
+ # # Extract terms with variations
171
+ # all_terms = extract_cross_store_terms(f"{title} {description}")
172
+
173
+ # # Clean product terms
174
+ # product_terms = [t for t in clean_text(f"{title} {description}").split() if len(t) > 2]
175
+
176
+ # # Build query
177
+ # # Emphasize original + all variations
178
+ # product_text = ' '.join(product_terms)
179
+ # variations_text = ' '.join(all_terms[:30]) # Top 30 variations
180
+
181
+ # # Repeat for emphasis
182
+ # emphasized = ' '.join([product_text] * 3)
183
+
184
+ # query = f"{emphasized} {variations_text} {title} {description}"
185
+
186
+ # return query, all_terms[:20]
187
+
188
+
189
+ # def encode_query(text):
190
+ # """Encode query using the trained model"""
191
+ # embedding = encoder.encode(
192
+ # text,
193
+ # convert_to_numpy=True,
194
+ # normalize_embeddings=True
195
+ # )
196
+
197
+ # if embedding.ndim == 1:
198
+ # embedding = embedding.reshape(1, -1)
199
+
200
+ # return embedding.astype('float32')
201
+
202
+
203
+ # def classify_product(title, description="", top_k=5):
204
+ # """
205
+ # Classify product using trained system
206
+ # Returns: category_id, category_path, confidence, and alternatives
207
+ # """
208
+ # start_time = time.time()
209
+
210
+ # # Step 1: Build enhanced query with cross-store synonyms
211
+ # query, matched_terms = build_enhanced_query(title, description)
212
+
213
+ # # Step 2: Encode query
214
+ # query_embedding = encode_query(query)
215
+
216
+ # # Step 3: Search FAISS index
217
+ # distances, indices = faiss_index.search(query_embedding, top_k)
218
+
219
+ # # Step 4: Get results
220
+ # results = []
221
+ # for i in range(len(indices[0])):
222
+ # idx = indices[0][i]
223
+ # if idx < len(metadata):
224
+ # meta = metadata[idx]
225
+ # confidence = float(distances[0][i]) * 100
226
+
227
+ # # Get final product name
228
+ # levels = meta.get('levels', [])
229
+ # final_product = levels[-1] if levels else meta['category_path'].split('/')[-1]
230
+
231
+ # results.append({
232
+ # 'rank': i + 1,
233
+ # 'category_id': meta['category_id'],
234
+ # 'category_path': meta['category_path'],
235
+ # 'final_product': final_product,
236
+ # 'confidence': round(confidence, 2),
237
+ # 'depth': meta.get('depth', 0)
238
+ # })
239
+
240
+ # # Best result
241
+ # best = results[0] if results else None
242
+
243
+ # if not best:
244
+ # return {
245
+ # 'error': 'No results found',
246
+ # 'product': title
247
+ # }
248
+
249
+ # # Confidence level
250
+ # conf_pct = best['confidence']
251
+ # if conf_pct >= 90:
252
+ # conf_level = "EXCELLENT"
253
+ # elif conf_pct >= 85:
254
+ # conf_level = "VERY HIGH"
255
+ # elif conf_pct >= 80:
256
+ # conf_level = "HIGH"
257
+ # elif conf_pct >= 75:
258
+ # conf_level = "GOOD"
259
+ # elif conf_pct >= 70:
260
+ # conf_level = "MEDIUM"
261
+ # else:
262
+ # conf_level = "LOW"
263
+
264
+ # processing_time = (time.time() - start_time) * 1000
265
+
266
+ # return {
267
+ # 'product': title,
268
+ # 'category_id': best['category_id'],
269
+ # 'category_path': best['category_path'],
270
+ # 'final_product': best['final_product'],
271
+ # 'confidence': f"{conf_level} ({conf_pct:.2f}%)",
272
+ # 'confidence_percent': conf_pct,
273
+ # 'depth': best['depth'],
274
+ # 'matched_terms': matched_terms,
275
+ # 'top_5_results': results,
276
+ # 'processing_time_ms': round(processing_time, 2)
277
+ # }
278
+
279
+
280
+ # # ============================================================================
281
+ # # SERVER INITIALIZATION
282
+ # # ============================================================================
283
+
284
+ # def load_server():
285
+ # """Load all trained data"""
286
+ # global encoder, faiss_index, metadata, cross_store_synonyms
287
+
288
+ # print("\n" + "="*80)
289
+ # print("πŸ”„ LOADING TRAINED MODEL")
290
+ # print("="*80 + "\n")
291
+
292
+ # # Load model
293
+ # print("πŸ“₯ Loading sentence transformer...")
294
+ # encoder = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
295
+ # print("βœ… Model loaded\n")
296
+
297
+ # # Load FAISS index
298
+ # print("πŸ“₯ Loading FAISS index...")
299
+ # index_path = CACHE_DIR / 'main_index.faiss'
300
+ # if not index_path.exists():
301
+ # raise FileNotFoundError(f"FAISS index not found: {index_path}\nPlease run training first!")
302
+ # faiss_index = faiss.read_index(str(index_path))
303
+ # print(f"βœ… Index loaded ({faiss_index.ntotal:,} vectors)\n")
304
+
305
+ # # Load metadata
306
+ # print("πŸ“₯ Loading metadata...")
307
+ # meta_path = CACHE_DIR / 'metadata.pkl'
308
+ # if not meta_path.exists():
309
+ # raise FileNotFoundError(f"Metadata not found: {meta_path}\nPlease run training first!")
310
+ # with open(meta_path, 'rb') as f:
311
+ # metadata = pickle.load(f)
312
+ # print(f"βœ… Metadata loaded ({len(metadata):,} categories)\n")
313
+
314
+ # # Load cross-store synonyms
315
+ # print("πŸ“₯ Loading cross-store synonyms...")
316
+ # syn_path = CACHE_DIR / 'cross_store_synonyms.pkl'
317
+ # if syn_path.exists():
318
+ # with open(syn_path, 'rb') as f:
319
+ # cross_store_synonyms = pickle.load(f)
320
+ # print(f"βœ… Cross-store synonyms loaded ({len(cross_store_synonyms)} terms)\n")
321
+ # else:
322
+ # print("⚠️ Cross-store synonyms not found, building default set...")
323
+ # cross_store_synonyms = build_cross_store_synonyms()
324
+ # print(f"βœ… Built {len(cross_store_synonyms)} synonym mappings\n")
325
+
326
+ # print("="*80)
327
+ # print("βœ… SERVER READY!")
328
+ # print("="*80 + "\n")
329
+
330
+
331
+ # # ============================================================================
332
+ # # HTML INTERFACE
333
+ # # ============================================================================
334
+
335
+ # HTML_TEMPLATE = """
336
+ # <!DOCTYPE html>
337
+ # <html>
338
+ # <head>
339
+ # <title>🎯 Product Category Classifier</title>
340
+ # <meta charset="UTF-8">
341
+ # <meta name="viewport" content="width=device-width, initial-scale=1.0">
342
+ # <style>
343
+ # * { margin: 0; padding: 0; box-sizing: border-box; }
344
+ # body {
345
+ # font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
346
+ # background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
347
+ # min-height: 100vh;
348
+ # padding: 20px;
349
+ # }
350
+ # .container { max-width: 1200px; margin: 0 auto; }
351
+ # .header {
352
+ # text-align: center;
353
+ # color: white;
354
+ # margin-bottom: 30px;
355
+ # }
356
+ # .header h1 { font-size: 2.5em; margin-bottom: 10px; }
357
+ # .badge {
358
+ # background: rgba(255,255,255,0.2);
359
+ # padding: 8px 20px;
360
+ # border-radius: 20px;
361
+ # display: inline-block;
362
+ # margin: 5px;
363
+ # font-size: 0.9em;
364
+ # }
365
+ # .card {
366
+ # background: white;
367
+ # border-radius: 20px;
368
+ # padding: 30px;
369
+ # box-shadow: 0 10px 40px rgba(0,0,0,0.2);
370
+ # }
371
+ # .success-box {
372
+ # background: #d4edda;
373
+ # padding: 15px;
374
+ # border-radius: 8px;
375
+ # margin-bottom: 20px;
376
+ # border-left: 4px solid #28a745;
377
+ # color: #155724;
378
+ # }
379
+ # .form-group { margin-bottom: 20px; }
380
+ # label {
381
+ # display: block;
382
+ # font-weight: 600;
383
+ # margin-bottom: 8px;
384
+ # color: #333;
385
+ # }
386
+ # input, textarea {
387
+ # width: 100%;
388
+ # padding: 12px;
389
+ # border: 2px solid #e0e0e0;
390
+ # border-radius: 8px;
391
+ # font-size: 1em;
392
+ # }
393
+ # input:focus, textarea:focus {
394
+ # outline: none;
395
+ # border-color: #667eea;
396
+ # }
397
+ # textarea { min-height: 80px; resize: vertical; }
398
+ # button {
399
+ # width: 100%;
400
+ # padding: 15px;
401
+ # background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
402
+ # color: white;
403
+ # border: none;
404
+ # border-radius: 10px;
405
+ # font-size: 1.1em;
406
+ # cursor: pointer;
407
+ # font-weight: 600;
408
+ # transition: transform 0.2s;
409
+ # }
410
+ # button:hover { transform: translateY(-2px); }
411
+ # .results { display: none; margin-top: 20px; }
412
+ # .results.show { display: block; animation: fadeIn 0.5s; }
413
+ # @keyframes fadeIn {
414
+ # from { opacity: 0; transform: translateY(10px); }
415
+ # to { opacity: 1; transform: translateY(0); }
416
+ # }
417
+ # .section {
418
+ # background: #f8f9fa;
419
+ # padding: 20px;
420
+ # border-radius: 12px;
421
+ # margin-bottom: 15px;
422
+ # border-left: 4px solid #667eea;
423
+ # }
424
+ # .section h3 { color: #667eea; margin-bottom: 12px; }
425
+ # .result-item {
426
+ # background: white;
427
+ # padding: 15px;
428
+ # border-radius: 8px;
429
+ # margin-bottom: 10px;
430
+ # border-left: 3px solid #667eea;
431
+ # }
432
+ # .tag {
433
+ # display: inline-block;
434
+ # background: #667eea;
435
+ # color: white;
436
+ # padding: 6px 12px;
437
+ # border-radius: 15px;
438
+ # margin: 3px;
439
+ # font-size: 0.9em;
440
+ # }
441
+ # .conf-excellent { background: #4caf50; }
442
+ # .conf-very { background: #8bc34a; }
443
+ # .conf-high { background: #cddc39; color: #333; }
444
+ # .conf-good { background: #ff9800; }
445
+ # .conf-medium { background: #ff5722; }
446
+ # .conf-low { background: #9e9e9e; }
447
+ # .loading { display: none; text-align: center; padding: 20px; }
448
+ # .loading.show { display: block; }
449
+ # .spinner {
450
+ # border: 4px solid #f3f3f3;
451
+ # border-top: 4px solid #667eea;
452
+ # border-radius: 50%;
453
+ # width: 40px;
454
+ # height: 40px;
455
+ # animation: spin 1s linear infinite;
456
+ # margin: 0 auto;
457
+ # }
458
+ # @keyframes spin {
459
+ # 0% { transform: rotate(0deg); }
460
+ # 100% { transform: rotate(360deg); }
461
+ # }
462
+ # </style>
463
+ # </head>
464
+ # <body>
465
+ # <div class="container">
466
+ # <div class="header">
467
+ # <h1>🎯 Product Category Classifier</h1>
468
+ # <div class="badge">Cross-Store Intelligence</div>
469
+ # <div class="badge">Auto-Tag Support</div>
470
+ # <div class="badge">Real-Time</div>
471
+ # </div>
472
+
473
+ # <div class="card">
474
+ # <div class="success-box">
475
+ # <strong>βœ… Cross-Store Synonyms Active!</strong><br>
476
+ # Understands: washing machine = laundry machine | tv = television | kids = children
477
+ # </div>
478
+
479
+ # <div class="form-group">
480
+ # <label>Product Title *</label>
481
+ # <input type="text" id="title" placeholder="e.g., Washing Machine or Laundry Machine" />
482
+ # </div>
483
+
484
+ # <div class="form-group">
485
+ # <label>Description (Optional)</label>
486
+ # <textarea id="desc" placeholder="Additional details..."></textarea>
487
+ # </div>
488
+
489
+ # <button onclick="classify()">🎯 Classify Product</button>
490
+
491
+ # <div class="loading" id="loading">
492
+ # <div class="spinner"></div>
493
+ # <p style="margin-top: 10px; color: #666;">Analyzing...</p>
494
+ # </div>
495
+
496
+ # <div class="results" id="results">
497
+ # <div class="section">
498
+ # <h3>βœ… Best Match</h3>
499
+ # <div class="result-item">
500
+ # <div style="margin-bottom: 10px;">
501
+ # <strong>Product:</strong> <span id="product"></span>
502
+ # </div>
503
+ # <div style="margin-bottom: 10px;">
504
+ # <strong>Category ID:</strong>
505
+ # <span id="catId" style="font-size: 1.2em; color: #28a745; font-weight: bold;"></span>
506
+ # </div>
507
+ # <div style="margin-bottom: 10px;">
508
+ # <strong>Final Product:</strong> <span id="finalProd" style="font-weight: 600;"></span>
509
+ # </div>
510
+ # <div style="margin-bottom: 10px;">
511
+ # <strong>Full Path:</strong><br>
512
+ # <span id="path" style="color: #666; font-size: 0.95em;"></span>
513
+ # </div>
514
+ # <div style="margin-bottom: 10px;">
515
+ # <strong>Confidence:</strong>
516
+ # <span id="confidence" class="tag"></span>
517
+ # </div>
518
+ # <div style="font-size: 0.9em; color: #666;">
519
+ # <strong>Depth:</strong> <span id="depth"></span> levels |
520
+ # <strong>Time:</strong> <span id="time"></span>ms
521
+ # </div>
522
+ # </div>
523
+ # </div>
524
+
525
+ # <div class="section">
526
+ # <h3>πŸ”— Matched Terms (Cross-Store Variations)</h3>
527
+ # <div id="matchedTerms"></div>
528
+ # </div>
529
+
530
+ # <div class="section">
531
+ # <h3>πŸ“‹ Top 5 Alternative Matches</h3>
532
+ # <div id="alternatives"></div>
533
+ # </div>
534
+ # </div>
535
+ # </div>
536
+ # </div>
537
+
538
+ # <script>
539
+ # async function classify() {
540
+ # const title = document.getElementById('title').value.trim();
541
+ # const desc = document.getElementById('desc').value.trim();
542
+
543
+ # if (!title) {
544
+ # alert('Please enter a product title');
545
+ # return;
546
+ # }
547
+
548
+ # document.getElementById('loading').classList.add('show');
549
+ # document.getElementById('results').classList.remove('show');
550
+
551
+ # try {
552
+ # const response = await fetch('/classify', {
553
+ # method: 'POST',
554
+ # headers: { 'Content-Type': 'application/json' },
555
+ # body: JSON.stringify({ title, description: desc })
556
+ # });
557
+
558
+ # if (!response.ok) throw new Error('Classification failed');
559
+
560
+ # const data = await response.json();
561
+ # displayResults(data);
562
+ # } catch (error) {
563
+ # alert('Error: ' + error.message);
564
+ # } finally {
565
+ # document.getElementById('loading').classList.remove('show');
566
+ # }
567
+ # }
568
+
569
+ # function displayResults(data) {
570
+ # document.getElementById('results').classList.add('show');
571
+
572
+ # document.getElementById('product').textContent = data.product;
573
+ # document.getElementById('catId').textContent = data.category_id;
574
+ # document.getElementById('finalProd').textContent = data.final_product;
575
+ # document.getElementById('path').textContent = data.category_path;
576
+ # document.getElementById('depth').textContent = data.depth;
577
+ # document.getElementById('time').textContent = data.processing_time_ms;
578
+
579
+ # const conf = document.getElementById('confidence');
580
+ # conf.textContent = data.confidence;
581
+ # const confClass = data.confidence.split(' ')[0].toLowerCase().replace('_', '-');
582
+ # conf.className = 'tag conf-' + confClass;
583
+
584
+ # const matchedHtml = data.matched_terms.map(t => `<span class="tag">${t}</span>`).join('');
585
+ # document.getElementById('matchedTerms').innerHTML = matchedHtml;
586
+
587
+ # let altHtml = '';
588
+ # data.top_5_results.forEach((item, i) => {
589
+ # const cls = i === 0 ? 'style="background: #e8f5e9;"' : '';
590
+ # altHtml += `
591
+ # <div class="result-item" ${cls}>
592
+ # <strong>${item.rank}.</strong> ${item.final_product}
593
+ # <span class="tag" style="background: #999;">${item.confidence}%</span>
594
+ # <div style="font-size: 0.85em; color: #666; margin-top: 5px;">
595
+ # ID: ${item.category_id}
596
+ # </div>
597
+ # </div>
598
+ # `;
599
+ # });
600
+ # document.getElementById('alternatives').innerHTML = altHtml;
601
+ # }
602
+
603
+ # document.getElementById('title').addEventListener('keypress', function(e) {
604
+ # if (e.key === 'Enter') classify();
605
+ # });
606
+ # </script>
607
+ # </body>
608
+ # </html>
609
+ # """
610
+
611
+
612
+ # # ============================================================================
613
+ # # FLASK ROUTES
614
+ # # ============================================================================
615
+
616
+ # @app.route('/')
617
+ # def index():
618
+ # """Serve the web interface"""
619
+ # return render_template_string(HTML_TEMPLATE)
620
+
621
+
622
+ # @app.route('/classify', methods=['POST'])
623
+ # def classify_route():
624
+ # """API endpoint for classification"""
625
+ # data = request.json
626
+ # title = data.get('title', '').strip()
627
+ # description = data.get('description', '').strip()
628
+
629
+ # if not title:
630
+ # return jsonify({'error': 'Title required'}), 400
631
+
632
+ # try:
633
+ # result = classify_product(title, description)
634
+ # return jsonify(result)
635
+ # except Exception as e:
636
+ # print(f"Error: {e}")
637
+ # return jsonify({'error': str(e)}), 500
638
+
639
+
640
+ # @app.route('/health')
641
+ # def health():
642
+ # """Health check endpoint"""
643
+ # return jsonify({
644
+ # 'status': 'healthy',
645
+ # 'categories': len(metadata),
646
+ # 'cross_store_synonyms': len(cross_store_synonyms),
647
+ # 'model': 'all-mpnet-base-v2'
648
+ # })
649
+
650
+
651
+ # # ============================================================================
652
+ # # MAIN
653
+ # # ============================================================================
654
+
655
+ # if __name__ == '__main__':
656
+ # try:
657
+ # load_server()
658
+
659
+ # print("\n🌐 Server starting...")
660
+ # print(" URL: http://localhost:5000")
661
+ # print(" Press CTRL+C to stop\n")
662
+
663
+ # app.run(host='0.0.0.0', port=5000, debug=False)
664
+
665
+ # except FileNotFoundError as e:
666
+ # print(f"\n❌ ERROR: {e}")
667
+ # print("\nπŸ’‘ Solution: Run training first:")
668
+ # print(" python train.py data/category_id_path_only.csv\n")
669
+ # except Exception as e:
670
+ # print(f"\n❌ UNEXPECTED ERROR: {e}\n")
671
+
672
+
673
+
674
+
675
+
676
+ #!/usr/bin/env python3
677
+ """
678
+ API Server for product category classification
679
+ Merged UI + classification logic
680
+ Model: intfloat/e5-base-v2 (must match training)
681
+
682
+ Usage:
683
+ python api_server.py
684
+
685
+ Requirements:
686
+     pip install flask sentence-transformers faiss-cpu numpy
687
+
688
+ Files expected in cache/:
689
+ - main_index.faiss
690
+ - metadata.pkl
691
+ - cross_store_synonyms.pkl (optional)
692
+
693
+ """
694
+
695
+ from flask import Flask, request, jsonify, render_template_string
696
+ from sentence_transformers import SentenceTransformer
697
+ import faiss
698
+ import pickle
699
+ import numpy as np
700
+ from pathlib import Path
701
+ import time
702
+ import re
703
+ import os
704
+ from typing import List
705
+
706
+ # ============================================================================
707
+ # CONFIG
708
+ # ============================================================================
709
+
710
# Directory holding all training artifacts (FAISS index, metadata, synonyms).
CACHE_DIR = Path('cache')
MODEL_NAME = 'intfloat/e5-base-v2'  # <-- MUST match the model used during training
FAISS_INDEX_PATH = CACHE_DIR / 'main_index.faiss'
METADATA_PATH = CACHE_DIR / 'metadata.pkl'
SYN_PATH = CACHE_DIR / 'cross_store_synonyms.pkl'

# Server globals -- presumably populated once at startup by a load routine
# not visible in this chunk, then read by the request handlers; confirm.
encoder = None             # SentenceTransformer instance used by encode_query()
faiss_index = None         # FAISS index searched for category matches
metadata = []              # per-vector category metadata records
cross_store_synonyms = {}  # term -> set of equivalent terms
721
+
722
+ # ============================================================================
723
+ # CROSS-STORE SYNONYM FALLBACK
724
+ # ============================================================================
725
+
726
def build_cross_store_synonyms():
    """Build the default cross-store synonym mapping (fallback).

    If a trained ``cross_store_synonyms.pkl`` exists in the cache the server
    loads that file instead; this function is only used when no file exists.

    Returns:
        dict[str, set[str]]: term -> set of equivalent terms.  The mapping is
        symmetric (if ``b`` is a synonym of ``a`` then ``a`` is a synonym of
        ``b``) and all synonyms of a term are also linked to each other.
    """
    synonyms = {
        'washing machine': {'laundry machine', 'washer', 'clothes washer', 'washing appliance'},
        'laundry machine': {'washing machine', 'washer', 'clothes washer'},
        'dryer': {'drying machine', 'clothes dryer', 'tumble dryer'},
        'refrigerator': {'fridge', 'cooler', 'ice box', 'cooling appliance'},
        'dishwasher': {'dish washer', 'dish cleaning machine'},
        'microwave': {'microwave oven', 'micro wave'},
        'vacuum': {'vacuum cleaner', 'hoover', 'vac'},
        'tv': {'television', 'telly', 'smart tv', 'display'},
        'laptop': {'notebook', 'portable computer', 'laptop computer'},
        'mobile': {'phone', 'cell phone', 'smartphone', 'cellphone'},
        'tablet': {'ipad', 'tab', 'tablet computer'},
        'headphones': {'headset', 'earphones', 'earbuds', 'ear buds'},
        'speaker': {'audio speaker', 'sound system', 'speakers'},
        'sofa': {'couch', 'settee', 'divan'},
        'wardrobe': {'closet', 'armoire', 'cupboard'},
        'drawer': {'chest of drawers', 'dresser'},
        'pants': {'trousers', 'slacks', 'bottoms'},
        'sweater': {'jumper', 'pullover', 'sweatshirt'},
        'sneakers': {'trainers', 'tennis shoes', 'running shoes'},
        'jacket': {'coat', 'blazer', 'outerwear'},
        'cooker': {'stove', 'range', 'cooking range'},
        'blender': {'mixer', 'food processor', 'liquidizer'},
        'kettle': {'electric kettle', 'water boiler'},
        'stroller': {'pram', 'pushchair', 'buggy', 'baby carriage'},
        'diaper': {'nappy', 'nappies'},
        'pacifier': {'dummy', 'soother'},
        'wrench': {'spanner', 'adjustable wrench'},
        'flashlight': {'torch', 'flash light'},
        'screwdriver': {'screw driver'},
        'tap': {'faucet', 'water tap'},
        'bin': {'trash can', 'garbage can', 'waste bin'},
        'curtain': {'drape', 'window covering'},
        'guillotine': {'paper cutter', 'paper trimmer', 'blade cutter'},
        'trimmer': {'cutter', 'cutting tool', 'edge cutter'},
        'stapler': {'stapling machine', 'staple gun'},
        'magazine': {'periodical', 'journal', 'publication'},
        'comic': {'comic book', 'graphic novel', 'manga'},
        'ebook': {'e-book', 'digital book', 'electronic book'},
        'kids': {'children', 'child', 'childrens', 'youth', 'junior'},
        'women': {'womens', 'ladies', 'female', 'lady'},
        'men': {'mens', 'male', 'gentleman'},
        'baby': {'infant', 'newborn', 'toddler'},
    }

    expanded = {}
    for term, syns in synonyms.items():
        # BUG FIX: merge with setdefault/update instead of plain assignment.
        # Plain `expanded[term] = set(syns)` clobbered synonyms already
        # accumulated when this term appeared earlier as another entry's
        # synonym (e.g. 'laundry machine' lost 'washing appliance' picked up
        # while processing 'washing machine').
        expanded.setdefault(term, set()).update(syns)
        for syn in syns:
            expanded.setdefault(syn, set()).add(term)
            # Link sibling synonyms to each other as well.
            expanded[syn].update(syns - {syn})
    return expanded
786
+
787
+ # ============================================================================
788
+ # TEXT CLEANING / QUERY BUILDING
789
+ # ============================================================================
790
+
791
def clean_text(text: str) -> str:
    """Normalize *text*: lowercase it, replace punctuation (except dashes)
    with spaces, and collapse runs of whitespace into single spaces.

    Falsy input (``None``, ``""``, ``0``) yields the empty string.
    """
    if not text:
        return ""
    lowered = str(text).lower()
    # Drop everything that is not a word character, whitespace, or dash.
    no_punct = re.sub(r"[^\w\s-]", " ", lowered)
    # split()/join collapses and trims all whitespace in one pass.
    return " ".join(no_punct.split())
799
+
800
+
801
def extract_cross_store_terms(text: str) -> List[str]:
    """Collect candidate search terms from *text*.

    Included terms: the full cleaned text, every word longer than two
    characters, adjacent 2- and 3-word phrases of such words, plus any
    cross-store synonyms registered for a word or 2-word phrase.

    Returns:
        list[str]: the collected terms in arbitrary (set) order.
    """
    cleaned = clean_text(text)
    tokens = cleaned.split()

    terms = {cleaned}  # always keep the full cleaned text

    # Single words, expanded with their cross-store synonyms.
    for tok in tokens:
        if len(tok) > 2:
            terms.add(tok)
            terms.update(cross_store_synonyms.get(tok, ()))

    # Adjacent 2-word phrases, also expanded with synonyms.
    for left, right in zip(tokens, tokens[1:]):
        if len(left) > 2 and len(right) > 2:
            bigram = f"{left} {right}"
            terms.add(bigram)
            terms.update(cross_store_synonyms.get(bigram, ()))

    # Adjacent 3-word phrases (no synonym lookup, as in the 2-word case the
    # synonym table is only consulted for words and bigrams).
    for a, b, c in zip(tokens, tokens[1:], tokens[2:]):
        if len(a) > 2 and len(b) > 2 and len(c) > 2:
            terms.add(f"{a} {b} {c}")

    return list(terms)
831
+
832
def build_enhanced_query(title, description="", max_synonyms=10):
    """Build a retrieval query that emphasizes the title.

    The cleaned title is repeated 3x (to dominate the embedding), followed by
    up to *max_synonyms* cross-store variation terms extracted from the title
    and description.

    Returns (query_string, first_20_terms_for_display).
    """
    cleaned_title = clean_text(title)
    cleaned_desc = clean_text(description)

    # Cross-store variation terms from title + description
    variation_terms = extract_cross_store_terms(f"{cleaned_title} {cleaned_desc}")

    # Weight the original title 3x, then append the top synonyms
    query_parts = [cleaned_title, cleaned_title, cleaned_title]
    query_parts.extend(variation_terms[:max_synonyms])

    return ' '.join(query_parts), variation_terms[:20]
844
+
845
+ # ============================================================================
846
+ # ENCODER / FAISS
847
+ # ============================================================================
848
+
849
def encode_query(text: str) -> np.ndarray:
    """Encode *text* into a normalized float32 row vector for FAISS search."""
    vector = encoder.encode(text, convert_to_numpy=True, normalize_embeddings=True)
    # FAISS expects a 2-D (n, dim) matrix; promote a single vector to one row
    if vector.ndim == 1:
        vector = vector[np.newaxis, :]
    return vector.astype('float32')
854
+
855
+ def classify_product(title, description="", top_k=5):
856
+ """Classify product using e5-base embeddings with cross-store optimization"""
857
+ start_time = time.time()
858
+
859
+ # Step 1: Build enhanced query
860
+ query_text, matched_terms = build_enhanced_query(title, description)
861
+
862
+ # Step 2: Encode query
863
+ query_embedding = encoder.encode(
864
+ query_text,
865
+ convert_to_numpy=True,
866
+ normalize_embeddings=True
867
+ ).astype('float32')
868
+
869
+ if query_embedding.ndim == 1:
870
+ query_embedding = query_embedding.reshape(1, -1)
871
+
872
+ # Step 3: FAISS search
873
+ distances, indices = faiss_index.search(query_embedding, top_k)
874
+
875
+ results = []
876
+ for i, idx in enumerate(indices[0]):
877
+ if idx >= len(metadata):
878
+ continue
879
+ meta = metadata[idx]
880
+ # Convert FAISS distance to cosine similarity
881
+ similarity = 1 - distances[0][i]
882
+ confidence_pct = float(similarity) * 100
883
+
884
+ final_product = meta.get('levels', [])[-1] if meta.get('levels') else meta['category_path'].split('/')[-1]
885
+
886
+ results.append({
887
+ 'rank': i + 1,
888
+ 'category_id': meta['category_id'],
889
+ 'category_path': meta['category_path'],
890
+ 'final_product': final_product,
891
+ 'confidence': round(confidence_pct, 2),
892
+ 'depth': meta.get('depth', 0)
893
+ })
894
+
895
+ if not results:
896
+ return {'error': 'No results found', 'product': title}
897
+
898
+ # Pick best match
899
+ best = results[0]
900
+ conf_pct = best['confidence']
901
+ if conf_pct >= 90:
902
+ conf_level = "EXCELLENT"
903
+ elif conf_pct >= 85:
904
+ conf_level = "VERY HIGH"
905
+ elif conf_pct >= 80:
906
+ conf_level = "HIGH"
907
+ elif conf_pct >= 75:
908
+ conf_level = "GOOD"
909
+ elif conf_pct >= 70:
910
+ conf_level = "MEDIUM"
911
+ else:
912
+ conf_level = "LOW"
913
+
914
+ processing_time = (time.time() - start_time) * 1000
915
+
916
+ return {
917
+ 'product': title,
918
+ 'category_id': best['category_id'],
919
+ 'category_path': best['category_path'],
920
+ 'final_product': best['final_product'],
921
+ 'confidence': f"{conf_level} ({conf_pct:.2f}%)",
922
+ 'confidence_percent': conf_pct,
923
+ 'depth': best['depth'],
924
+ 'matched_terms': matched_terms,
925
+ 'top_5_results': results,
926
+ 'processing_time_ms': round(processing_time, 2)
927
+ }
928
+ # FAISS returns squared L2 distances or inner product depending on index type.
929
+ # We'll treat lower distance as better. We convert to a 0-100-ish confidence by
930
+ # using a simple heuristic: score = 100 - normalized_distance*100 (clamped).
931
+
932
+ # Determine a normalization constant: use mean of top distance if available
933
+ flat_dist = distances[0]
934
+ max_d = float(np.max(flat_dist)) if flat_dist.size else 1.0
935
+ min_d = float(np.min(flat_dist)) if flat_dist.size else 0.0
936
+ range_d = max(1e-6, max_d - min_d)
937
+
938
+ for i, idx in enumerate(indices[0]):
939
+ if idx < 0 or idx >= len(metadata):
940
+ continue
941
+ meta = metadata[idx]
942
+ raw_d = float(distances[0][i])
943
+ # normalize and invert to make higher -> better
944
+ norm = (raw_d - min_d) / range_d
945
+ conf = max(0.0, min(100.0, 100.0 * (1.0 - norm)))
946
+
947
+ levels = meta.get('levels') or []
948
+ final_product = levels[-1] if levels else meta.get('category_path', '').split('/')[-1]
949
+
950
+ results.append({
951
+ 'rank': i + 1,
952
+ 'category_id': meta.get('category_id'),
953
+ 'category_path': meta.get('category_path'),
954
+ 'final_product': final_product,
955
+ 'confidence': round(conf, 2),
956
+ 'depth': meta.get('depth', 0)
957
+ })
958
+
959
+ if not results:
960
+ return {
961
+ 'error': 'No results found',
962
+ 'product': title
963
+ }
964
+
965
+ best = results[0]
966
+ conf_pct = best['confidence']
967
+ if conf_pct >= 90:
968
+ conf_level = "EXCELLENT"
969
+ elif conf_pct >= 85:
970
+ conf_level = "VERY HIGH"
971
+ elif conf_pct >= 80:
972
+ conf_level = "HIGH"
973
+ elif conf_pct >= 75:
974
+ conf_level = "GOOD"
975
+ elif conf_pct >= 70:
976
+ conf_level = "MEDIUM"
977
+ else:
978
+ conf_level = "LOW"
979
+
980
+ processing_time = (time.time() - start_time) * 1000.0
981
+
982
+ return {
983
+ 'product': title,
984
+ 'category_id': best['category_id'],
985
+ 'category_path': best['category_path'],
986
+ 'final_product': best['final_product'],
987
+ 'confidence': f"{conf_level} ({conf_pct:.2f}%)",
988
+ 'confidence_percent': conf_pct,
989
+ 'depth': best['depth'],
990
+ 'matched_terms': matched_terms,
991
+ 'top_5_results': results,
992
+ 'processing_time_ms': round(processing_time, 2)
993
+ }
994
+
995
+ # ============================================================================
996
+ # SERVER LOAD
997
+ # ============================================================================
998
+
999
def load_server():
    """Load all model artifacts into module globals.

    Populates: encoder (sentence transformer), faiss_index, metadata
    (category list) and cross_store_synonyms (loaded from cache, or rebuilt
    from the built-in defaults when the cache file is absent).

    Raises FileNotFoundError when the FAISS index or metadata cache is
    missing (i.e. training has not been run).
    """
    global encoder, faiss_index, metadata, cross_store_synonyms

    print('\n' + '=' * 80)
    print('πŸ”„ LOADING TRAINED MODEL')
    print('=' * 80 + '\n')

    # Load encoder
    print('πŸ“₯ Loading sentence transformer...')
    encoder = SentenceTransformer(MODEL_NAME)
    print('βœ… Model loaded\n')

    # Load FAISS index (required — produced by the training step)
    print('πŸ“₯ Loading FAISS index...')
    if not FAISS_INDEX_PATH.exists():
        raise FileNotFoundError(f"FAISS index not found: {FAISS_INDEX_PATH}\nPlease run training first!")
    faiss_index = faiss.read_index(str(FAISS_INDEX_PATH))
    print(f"βœ… Index loaded ({faiss_index.ntotal:,} vectors)\n")

    # Load metadata (required — one entry per index vector)
    print('πŸ“₯ Loading metadata...')
    if not METADATA_PATH.exists():
        raise FileNotFoundError(f"Metadata not found: {METADATA_PATH}\nPlease run training first!")
    with open(METADATA_PATH, 'rb') as f:
        metadata = pickle.load(f)
    print(f"βœ… Metadata loaded ({len(metadata):,} categories)\n")

    # Load or build cross-store synonyms (optional — falls back to defaults)
    print('πŸ“₯ Loading cross-store synonyms...')
    if SYN_PATH.exists():
        with open(SYN_PATH, 'rb') as f:
            cross_store_synonyms = pickle.load(f)
        print(f"βœ… Cross-store synonyms loaded ({len(cross_store_synonyms)} terms)\n")
    else:
        print('⚠️ Cross-store synonyms not found, building default set...')
        cross_store_synonyms = build_cross_store_synonyms()
        print(f"βœ… Built {len(cross_store_synonyms)} synonym mappings\n")

    print('=' * 80)
    print('βœ… SERVER READY!')
    print('=' * 80 + '\n')
1040
+
1041
+ # ============================================================================
1042
+ # HTML TEMPLATE (same as provided)
1043
+ # ============================================================================
1044
+
1045
# Single-page UI served at "/" — posts to /classify and renders the JSON
# result client-side. Kept inline so the server is a single self-contained
# file; no templating placeholders are used (rendered verbatim).
HTML_TEMPLATE = r"""
<!DOCTYPE html>
<html>
<head>
<title>🎯 Product Category Classifier</title>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
padding: 20px;
}
.container { max-width: 1200px; margin: 0 auto; }
.header {
text-align: center;
color: white;
margin-bottom: 30px;
}
.header h1 { font-size: 2.5em; margin-bottom: 10px; }
.badge {
background: rgba(255,255,255,0.2);
padding: 8px 20px;
border-radius: 20px;
display: inline-block;
margin: 5px;
font-size: 0.9em;
}
.card {
background: white;
border-radius: 20px;
padding: 30px;
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
}
.success-box {
background: #d4edda;
padding: 15px;
border-radius: 8px;
margin-bottom: 20px;
border-left: 4px solid #28a745;
color: #155724;
}
.form-group { margin-bottom: 20px; }
label {
display: block;
font-weight: 600;
margin-bottom: 8px;
color: #333;
}
input, textarea {
width: 100%;
padding: 12px;
border: 2px solid #e0e0e0;
border-radius: 8px;
font-size: 1em;
}
input:focus, textarea:focus {
outline: none;
border-color: #667eea;
}
textarea { min-height: 80px; resize: vertical; }
button {
width: 100%;
padding: 15px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
border-radius: 10px;
font-size: 1.1em;
cursor: pointer;
font-weight: 600;
transition: transform 0.2s;
}
button:hover { transform: translateY(-2px); }
.results { display: none; margin-top: 20px; }
.results.show { display: block; animation: fadeIn 0.5s; }
@keyframes fadeIn {
from { opacity: 0; transform: translateY(10px); }
to { opacity: 1; transform: translateY(0); }
}
.section {
background: #f8f9fa;
padding: 20px;
border-radius: 12px;
margin-bottom: 15px;
border-left: 4px solid #667eea;
}
.section h3 { color: #667eea; margin-bottom: 12px; }
.result-item {
background: white;
padding: 15px;
border-radius: 8px;
margin-bottom: 10px;
border-left: 3px solid #667eea;
}
.tag {
display: inline-block;
background: #667eea;
color: white;
padding: 6px 12px;
border-radius: 15px;
margin: 3px;
font-size: 0.9em;
}
.conf-excellent { background: #4caf50; }
.conf-very { background: #8bc34a; }
.conf-high { background: #cddc39; color: #333; }
.conf-good { background: #ff9800; }
.conf-medium { background: #ff5722; }
.conf-low { background: #9e9e9e; }
.loading { display: none; text-align: center; padding: 20px; }
.loading.show { display: block; }
.spinner {
border: 4px solid #f3f3f3;
border-top: 4px solid #667eea;
border-radius: 50%;
width: 40px;
height: 40px;
animation: spin 1s linear infinite;
margin: 0 auto;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>🎯 Product Category Classifier</h1>
<div class="badge">Cross-Store Intelligence</div>
<div class="badge">Auto-Tag Support</div>
<div class="badge">Real-Time</div>
</div>

<div class="card">
<div class="success-box">
<strong>βœ… Cross-Store Synonyms Active!</strong><br>
Understands: washing machine = laundry machine | tv = television | kids = children
</div>

<div class="form-group">
<label>Product Title *</label>
<input type="text" id="title" placeholder="e.g., Washing Machine or Laundry Machine" />
</div>

<div class="form-group">
<label>Description (Optional)</label>
<textarea id="desc" placeholder="Additional details..."></textarea>
</div>

<button onclick="classify()">🎯 Classify Product</button>

<div class="loading" id="loading">
<div class="spinner"></div>
<p style="margin-top: 10px; color: #666;">Analyzing...</p>
</div>

<div class="results" id="results">
<div class="section">
<h3>βœ… Best Match</h3>
<div class="result-item">
<div style="margin-bottom: 10px;">
<strong>Product:</strong> <span id="product"></span>
</div>
<div style="margin-bottom: 10px;">
<strong>Category ID:</strong>
<span id="catId" style="font-size: 1.2em; color: #28a745; font-weight: bold;"></span>
</div>
<div style="margin-bottom: 10px;">
<strong>Final Product:</strong> <span id="finalProd" style="font-weight: 600;"></span>
</div>
<div style="margin-bottom: 10px;">
<strong>Full Path:</strong><br>
<span id="path" style="color: #666; font-size: 0.95em;"></span>
</div>
<div style="margin-bottom: 10px;">
<strong>Confidence:</strong>
<span id="confidence" class="tag"></span>
</div>
<div style="font-size: 0.9em; color: #666;">
<strong>Depth:</strong> <span id="depth"></span> levels |
<strong>Time:</strong> <span id="time"></span>ms
</div>
</div>
</div>

<div class="section">
<h3>πŸ”— Matched Terms (Cross-Store Variations)</h3>
<div id="matchedTerms"></div>
</div>

<div class="section">
<h3>πŸ“‹ Top 5 Alternative Matches</h3>
<div id="alternatives"></div>
</div>
</div>
</div>
</div>

<script>
async function classify() {
const title = document.getElementById('title').value.trim();
const desc = document.getElementById('desc').value.trim();

if (!title) {
alert('Please enter a product title');
return;
}

document.getElementById('loading').classList.add('show');
document.getElementById('results').classList.remove('show');

try {
const response = await fetch('/classify', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ title, description: desc })
});

if (!response.ok) throw new Error('Classification failed');

const data = await response.json();
displayResults(data);
} catch (error) {
alert('Error: ' + error.message);
} finally {
document.getElementById('loading').classList.remove('show');
}
}

function displayResults(data) {
document.getElementById('results').classList.add('show');

document.getElementById('product').textContent = data.product;
document.getElementById('catId').textContent = data.category_id;
document.getElementById('finalProd').textContent = data.final_product;
document.getElementById('path').textContent = data.category_path;
document.getElementById('depth').textContent = data.depth;
document.getElementById('time').textContent = data.processing_time_ms;

const conf = document.getElementById('confidence');
conf.textContent = data.confidence;
const confClass = data.confidence.split(' ')[0].toLowerCase().replace('_', '-');
conf.className = 'tag conf-' + confClass;

const matchedHtml = data.matched_terms.map(t => `<span class="tag">${t}</span>`).join('');
document.getElementById('matchedTerms').innerHTML = matchedHtml;

let altHtml = '';
data.top_5_results.forEach((item, i) => {
const cls = i === 0 ? 'style="background: #e8f5e9;"' : '';
altHtml += `
<div class="result-item" ${cls}>
<strong>${item.rank}.</strong> ${item.final_product}
<span class="tag" style="background: #999;">${item.confidence}%</span>
<div style="font-size: 0.85em; color: #666; margin-top: 5px;">
ID: ${item.category_id}
</div>
</div>
`;
});
document.getElementById('alternatives').innerHTML = altHtml;
}

document.getElementById('title').addEventListener('keypress', function(e) {
if (e.key === 'Enter') classify();
});
</script>
</body>
</html>
"""
1320
+
1321
+ # ============================================================================
1322
+ # FLASK APP
1323
+ # ============================================================================
1324
+
1325
# Flask application instance; routes are registered below via decorators
app = Flask(__name__)
1326
+
1327
+
1328
@app.route('/')
def index():
    """Serve the single-page classifier UI."""
    return render_template_string(HTML_TEMPLATE)
1331
+
1332
+
1333
@app.route('/classify', methods=['POST'])
def classify_route():
    """JSON API endpoint: classify {"title": ..., "description": ...}.

    Returns the classify_product() result dict, 400 when the title is
    missing/empty, or 500 (with the exception message) on failure.
    """
    data = request.get_json(force=True)
    # `or ''` also guards against explicit JSON nulls — .get()'s default is
    # not used for a present null, and None.strip() previously caused a 500
    title = (data.get('title') or '').strip()
    description = (data.get('description') or '').strip()

    if not title:
        return jsonify({'error': 'Title required'}), 400

    try:
        result = classify_product(title, description)
        return jsonify(result)
    except Exception as e:
        # Log full traceback server-side; return only the message to clients
        app.logger.exception('Classification error')
        return jsonify({'error': str(e)}), 500
1348
+
1349
+
1350
@app.route('/health')
def health():
    """Liveness probe: reports model name and loaded artifact sizes."""
    return jsonify({
        'status': 'healthy',
        'categories': len(metadata),
        'cross_store_synonyms': len(cross_store_synonyms),
        'model': MODEL_NAME
    })
1358
+
1359
+
1360
+ # ============================================================================
1361
+ # MAIN
1362
+ # ============================================================================
1363
+
1364
if __name__ == '__main__':
    try:
        # Load encoder/index/metadata into globals before serving requests
        load_server()
        print('\n🌐 Server starting...')
        print(' URL: http://localhost:5000')
        print(' Press CTRL+C to stop\n')
        # Recommended: run with a production server like gunicorn for production use
        app.run(host='0.0.0.0', port=5000, debug=False)
    except FileNotFoundError as e:
        # Raised by load_server() when FAISS index or metadata is missing
        print(f"\n❌ ERROR: {e}")
        print('\nπŸ’‘ Solution: Run training first to create FAISS index and metadata')
    except Exception as e:
        print(f"\n❌ UNEXPECTED ERROR: {e}\n")
1377
+
check.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ πŸ”§ DIAGNOSTIC AND FIX TOOL
3
+ ===========================
4
+ Analyzes your trained model and fixes common issues causing low confidence.
5
+
6
+ Issues it detects and fixes:
7
+ 1. Column name mismatches (Category_ID vs category_id)
8
+ 2. Missing or corrupted tags.json
9
+ 3. Wrong metadata format in cache
10
+ 4. FAISS index mismatch
11
+
12
+ Usage:
13
+ python diagnose_and_fix.py
14
+ """
15
+
16
+ import pickle
17
+ import json
18
+ import pandas as pd
19
+ import numpy as np
20
+ import faiss
21
+ from pathlib import Path
22
+ from sentence_transformers import SentenceTransformer
23
+ import sys
24
+
25
def check_cache_files():
    """Check what files exist in cache.

    Returns a list of issue strings — one per missing *required* file.
    Missing optional files only print a warning and are not reported as
    issues.
    """
    cache_dir = Path('cache')

    print("\n" + "="*80)
    print("πŸ” STEP 1: CHECKING CACHE FILES")
    print("="*80 + "\n")

    # Artifacts the server cannot start without
    required_files = {
        'main_index.faiss': cache_dir / 'main_index.faiss',
        'metadata.pkl': cache_dir / 'metadata.pkl',
        'model_info.json': cache_dir / 'model_info.json',
    }

    # Nice-to-have artifacts (accuracy/calibration helpers)
    optional_files = {
        'parent_embeddings.pkl': cache_dir / 'parent_embeddings.pkl',
        'calibrator.pkl': cache_dir / 'calibrator.pkl',
        'cross_store_synonyms.pkl': cache_dir / 'cross_store_synonyms.pkl',
    }

    issues = []

    print("Required files:")
    for name, path in required_files.items():
        if path.exists():
            size = path.stat().st_size / (1024 * 1024)  # MB
            print(f" βœ… {name} ({size:.2f} MB)")
        else:
            print(f" ❌ {name} - MISSING")
            issues.append(f"Missing required file: {name}")

    print("\nOptional files:")
    for name, path in optional_files.items():
        if path.exists():
            size = path.stat().st_size / (1024 * 1024)
            print(f" βœ… {name} ({size:.2f} MB)")
        else:
            print(f" ⚠️ {name} - not found")

    return issues
65
+
66
+
67
def check_csv_format():
    """Check CSV file format.

    Verifies data/category_only_path.csv exists and carries the category
    id/path columns in either casing convention.

    Returns a list of issue strings (empty when the file looks fine).
    """
    print("\n" + "="*80)
    print("πŸ” STEP 2: CHECKING CSV FORMAT")
    print("="*80 + "\n")

    csv_path = Path('data/category_only_path.csv')

    if not csv_path.exists():
        print("❌ CSV not found at: data/category_only_path.csv")
        return ["CSV file not found"]

    try:
        # Read once — the original read nrows=5 and then re-read the entire
        # file a second time just to count rows
        df = pd.read_csv(csv_path)

        print(f"Columns found: {list(df.columns)}")
        print(f"Total rows: {len(df):,}")

        print("\nFirst 3 rows:")
        print(df.head(3).to_string())

        # Accept either the uppercase or lowercase column naming convention
        if 'Category_ID' in df.columns and 'Category_path' in df.columns:
            print("\nβœ… Column format: Uppercase (Category_ID, Category_path)")
            return []
        elif 'category_id' in df.columns and 'category_path' in df.columns:
            print("\nβœ… Column format: Lowercase (category_id, category_path)")
            return []
        else:
            print("\n❌ Unexpected column names!")
            return ["CSV has wrong column names"]

    except Exception as e:
        print(f"\n❌ Error reading CSV: {e}")
        return [f"CSV read error: {e}"]
102
+
103
+
104
def check_metadata():
    """Check metadata format.

    Loads cache/metadata.pkl, prints a sample entry, and verifies the first
    100 entries carry 'category_id' and 'category_path'. Returns a list of
    issue strings (at most the first 5 field problems).
    """
    print("\n" + "="*80)
    print("πŸ” STEP 3: CHECKING METADATA FORMAT")
    print("="*80 + "\n")

    meta_path = Path('cache/metadata.pkl')

    if not meta_path.exists():
        print("❌ Metadata file not found")
        return ["Metadata missing"]

    try:
        with open(meta_path, 'rb') as f:
            metadata = pickle.load(f)

        print(f"Metadata entries: {len(metadata):,}")

        if metadata:
            sample = metadata[0]
            print(f"\nSample entry:")
            print(f" Keys: {list(sample.keys())}")
            print(f" category_id: {sample.get('category_id', 'MISSING')}")
            print(f" category_path: {sample.get('category_path', 'MISSING')[:50]}...")

            # Check if all entries have required fields
            # (only the first 100 are sampled to keep this fast)
            missing_fields = []
            for i, entry in enumerate(metadata[:100]):
                if 'category_id' not in entry:
                    missing_fields.append(f"Entry {i}: missing category_id")
                if 'category_path' not in entry:
                    missing_fields.append(f"Entry {i}: missing category_path")

            if missing_fields:
                print(f"\n❌ Found {len(missing_fields)} entries with missing fields")
                return missing_fields[:5]  # Return first 5
            else:
                print("\nβœ… All entries have required fields")
                return []
        else:
            print("❌ Metadata is empty!")
            return ["Empty metadata"]

    except Exception as e:
        print(f"❌ Error reading metadata: {e}")
        return [f"Metadata error: {e}"]
150
+
151
+
152
def check_faiss_index():
    """Check FAISS index.

    Confirms the index loads and that its vector count matches the number
    of metadata entries (a mismatch means search results map to the wrong
    categories). Returns a list of issue strings.
    """
    print("\n" + "="*80)
    print("πŸ” STEP 4: CHECKING FAISS INDEX")
    print("="*80 + "\n")

    index_path = Path('cache/main_index.faiss')
    meta_path = Path('cache/metadata.pkl')

    if not index_path.exists():
        print("❌ FAISS index not found")
        return ["FAISS index missing"]

    try:
        index = faiss.read_index(str(index_path))
        print(f"FAISS index vectors: {index.ntotal:,}")
        print(f"Dimension: {index.d}")

        with open(meta_path, 'rb') as f:
            metadata = pickle.load(f)

        print(f"Metadata entries: {len(metadata):,}")

        # Row i of the index must correspond to metadata[i]
        if index.ntotal != len(metadata):
            print(f"\n❌ MISMATCH!")
            print(f" FAISS has {index.ntotal:,} vectors")
            print(f" Metadata has {len(metadata):,} entries")
            return ["FAISS-metadata count mismatch"]
        else:
            print("\nβœ… FAISS and metadata counts match")
            return []

    except Exception as e:
        print(f"❌ Error: {e}")
        return [f"FAISS error: {e}"]
187
+
188
+
189
def check_tags_json():
    """Check tags.json.

    Verifies data/tags.json exists, is non-empty, and averages at least 10
    tags per category (fewer reduces retrieval accuracy). Returns a list of
    issue strings.
    """
    print("\n" + "="*80)
    print("πŸ” STEP 5: CHECKING TAGS.JSON")
    print("="*80 + "\n")

    tags_path = Path('data/tags.json')

    if not tags_path.exists():
        print("⚠️ tags.json not found - this will reduce accuracy!")
        print(" Expected location: data/tags.json")
        return ["tags.json missing"]

    try:
        with open(tags_path, 'r') as f:
            tags = json.load(f)

        print(f"Tags for {len(tags):,} categories")

        if tags:
            # Show one sample mapping for a quick visual sanity check
            # NOTE(review): assumes values are lists of strings — verify
            # against the generator script's output format
            sample_key = list(tags.keys())[0]
            sample_tags = tags[sample_key]

            print(f"\nSample category: {sample_key}")
            print(f"Tags ({len(sample_tags)}): {', '.join(sample_tags[:5])}...")

            # Check average tags per category
            tag_counts = [len(t) for t in tags.values() if isinstance(t, list)]
            avg_tags = sum(tag_counts) / len(tag_counts) if tag_counts else 0

            print(f"\nAverage tags per category: {avg_tags:.1f}")

            if avg_tags < 10:
                print("⚠️ Very few tags - this will reduce accuracy")
                return ["Too few tags per category"]
            else:
                print("βœ… Tags look good")
                return []
        else:
            print("❌ tags.json is empty!")
            return ["Empty tags.json"]

    except Exception as e:
        print(f"❌ Error: {e}")
        return [f"tags.json error: {e}"]
234
+
235
+
236
def test_prediction():
    """Run one end-to-end sample query against the trained index.

    Encodes a fixed test query, searches the FAISS index and prints the
    top-5 hits, then grades the best score. Returns a list of issue strings
    (empty when confidence looks healthy).
    """
    print("\n" + "="*80)
    print("πŸ” STEP 6: TESTING PREDICTION")
    print("="*80 + "\n")

    try:
        print("Loading model...")
        encoder = SentenceTransformer('intfloat/e5-base-v2')

        print("Loading FAISS index...")
        index = faiss.read_index('cache/main_index.faiss')

        print("Loading metadata...")
        with open('cache/metadata.pkl', 'rb') as f:
            metadata = pickle.load(f)

        # e5 models expect the "query: " prefix on search queries
        test_query = "query: built in dishwasher"

        print(f"\nTest query: \"{test_query}\"")
        print("Encoding...")

        query_emb = encoder.encode(test_query, convert_to_numpy=True, normalize_embeddings=True)
        if query_emb.ndim == 1:
            query_emb = query_emb.reshape(1, -1)

        print("Searching...")
        distances, indices = index.search(query_emb.astype('float32'), 5)

        print("\nTop 5 results:")
        for i, idx in enumerate(indices[0]):
            # Fixed: skip FAISS's -1 padding (index smaller than 5 vectors)
            # and out-of-range ids instead of wrongly printing metadata[-1]
            if idx < 0 or idx >= len(metadata):
                continue
            score = distances[0][i]
            meta = metadata[idx]

            print(f"\n{i+1}. Score: {score:.4f}")
            print(f" ID: {meta.get('category_id', 'N/A')}")
            print(f" Path: {meta.get('category_path', 'N/A')[:60]}...")

        best_score = float(distances[0][0])

        if best_score < 0.3:
            print(f"\n❌ VERY LOW CONFIDENCE: {best_score:.4f}")
            print(" This indicates a serious problem with training!")
            return ["Very low prediction scores"]
        elif best_score < 0.5:
            print(f"\n⚠️ LOW CONFIDENCE: {best_score:.4f}")
            print(" Model needs improvement")
            return ["Low prediction scores"]
        else:
            print(f"\nβœ… GOOD CONFIDENCE: {best_score:.4f}")
            return []

    except Exception as e:
        print(f"\n❌ Prediction test failed: {e}")
        import traceback
        traceback.print_exc()
        return [f"Prediction error: {e}"]
295
+
296
+
297
def generate_fix_commands(all_issues):
    """Generate commands to fix issues.

    Prints a summary of *all_issues* (list of strings from the check_*
    functions) and, when any issue implies broken/missing artifacts, the
    exact retraining command sequence. Purely informational — returns None.
    """
    print("\n" + "="*80)
    print("πŸ”§ RECOMMENDED FIXES")
    print("="*80 + "\n")

    if not all_issues:
        print("βœ… No critical issues found!")
        print("\nIf you're still experiencing low confidence:")
        print(" 1. Make sure you're using tags.json")
        print(" 2. Check if validation.csv is being used for calibration")
        print(" 3. Verify CSV has correct column names")
        return

    print("Issues found:")
    for i, issue in enumerate(all_issues, 1):
        print(f" {i}. {issue}")

    print("\n" + "="*80)
    print("FIX STEPS:")
    print("="*80 + "\n")

    # Keyword heuristic: "missing"/"mismatch"/"low" in any issue string
    # means the cached artifacts are unusable and retraining is required
    if any('missing' in issue.lower() or 'mismatch' in issue.lower() or 'low' in issue.lower() for issue in all_issues):
        print("πŸ”„ RE-TRAINING REQUIRED")
        print("\nRun these commands in order:\n")

        print("# Step 1: Generate tags (if missing)")
        print("python generate_hybrid_tags.py data/category_only_path.csv data/tags.json")
        print()

        print("# Step 2: Generate validation data (for calibration)")
        print("python create_validation_data.py auto data/category_only_path.csv 200")
        print()

        print("# Step 3: Train with ALL fixes")
        print("python train_fixed_v2.py data/category_only_path.csv data/tags.json data/validation.csv")
        print()
    else:
        print("βœ… No retraining needed - minor issues only")
336
+
337
+
338
def main():
    """Run every diagnostic check in order and print the fix summary."""
    print("\n" + "="*80)
    print("πŸ”§ DIAGNOSTIC AND FIX TOOL")
    print("="*80)
    print("\nThis tool will analyze your model and identify issues\n")

    # The check order matters for readable output (STEP 1..6 banners)
    checks = (
        check_cache_files,
        check_csv_format,
        check_metadata,
        check_faiss_index,
        check_tags_json,
        test_prediction,
    )

    all_issues = []
    for check in checks:
        all_issues.extend(check())

    # Turn the collected issues into actionable commands
    generate_fix_commands(all_issues)

    print("\n" + "="*80)
    print("πŸ“Š DIAGNOSIS COMPLETE")
    print("="*80)
    print(f"\nTotal issues found: {len(all_issues)}")
    print("\n")
362
+
363
+
364
# Script entry point: run the full diagnostic pass
if __name__ == "__main__":
    main()
fix.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ πŸ”§ AUTOMATIC EMBEDDING & INDEX FIXER
3
+ ====================================
4
+ Fixes common issues causing low confidence scores
5
+
6
+ Usage:
7
+ python fix_embeddings.py normalize # Fix normalization
8
+ python fix_embeddings.py rebuild-index # Rebuild FAISS
9
+ python fix_embeddings.py full-fix # Do everything
10
+ """
11
+
12
+ import numpy as np
13
+ import faiss
14
+ import pickle
15
+ import sys
16
+ from pathlib import Path
17
+ from tqdm import tqdm
18
+ import warnings
19
+ warnings.filterwarnings('ignore')
20
+
21
class EmbeddingFixer:
    """Repair cached embeddings and the FAISS index.

    Addresses the two common causes of low-confidence search scores:
    embedding vectors that are not unit-length, and an index built with a
    metric other than inner product. Destructive operations back up the
    originals into ``cache/backup`` first.
    """

    def __init__(self, cache_dir='cache'):
        # Directory holding embeddings.npy, main_index.faiss and metadata.pkl.
        self.cache_dir = Path(cache_dir)

    def banner(self, text):
        """Print a framed section header."""
        print("\n" + "="*80)
        print(f"πŸ”§ {text}")
        print("="*80 + "\n")

    def backup_files(self):
        """Copy the cache artifacts to cache/backup (missing files are skipped)."""
        self.banner("CREATING BACKUPS")

        backup_dir = self.cache_dir / 'backup'
        backup_dir.mkdir(exist_ok=True)

        files_to_backup = [
            'embeddings.npy',
            'main_index.faiss',
            'metadata.pkl'
        ]

        import shutil  # hoisted: was re-imported on every loop iteration

        for filename in files_to_backup:
            src = self.cache_dir / filename
            if src.exists():
                dst = backup_dir / filename
                shutil.copy2(src, dst)
                # BUG FIX: the original printed the literal "(unknown)"
                # instead of the name of the file that was backed up.
                print(f"βœ… Backed up: {filename}")

        print(f"\nπŸ“ Backups saved to: {backup_dir}")

    def normalize_embeddings(self):
        """L2-normalize embeddings.npy in place. Returns True on success."""
        self.banner("NORMALIZING EMBEDDINGS")

        emb_path = self.cache_dir / 'embeddings.npy'

        if not emb_path.exists():
            print("❌ embeddings.npy not found!")
            return False

        print("Loading embeddings...")
        embeddings = np.load(emb_path)

        print(f"Original shape: {embeddings.shape}")

        # keepdims=True so the norms broadcast row-wise during division.
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        print(f"Mean norm before: {norms.mean():.6f}")
        print(f"Std norm before: {norms.std():.6f}")

        print("\nNormalizing...")
        # Epsilon guards against division by zero for any all-zero row.
        embeddings_normalized = embeddings / (norms + 1e-8)

        norms_after = np.linalg.norm(embeddings_normalized, axis=1)
        print(f"Mean norm after: {norms_after.mean():.6f}")
        print(f"Std norm after: {norms_after.std():.6f}")

        # float32 is what FAISS expects and halves file size vs float64.
        output_path = self.cache_dir / 'embeddings.npy'
        np.save(output_path, embeddings_normalized.astype('float32'))
        print(f"\nβœ… Saved normalized embeddings: {output_path}")

        return True

    def rebuild_faiss_index(self):
        """Rebuild main_index.faiss with the inner-product metric."""
        self.banner("REBUILDING FAISS INDEX")

        emb_path = self.cache_dir / 'embeddings.npy'

        if not emb_path.exists():
            print("❌ embeddings.npy not found!")
            return False

        print("Loading embeddings...")
        embeddings = np.load(emb_path).astype('float32')

        print(f"Shape: {embeddings.shape}")

        # Inner product equals cosine similarity only for unit vectors,
        # so normalize here in case normalize_embeddings() was skipped.
        norms = np.linalg.norm(embeddings, axis=1)
        if abs(norms.mean() - 1.0) > 0.01:
            print("⚠️ Embeddings not normalized, normalizing now...")
            embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
            np.save(emb_path, embeddings)

        dimension = embeddings.shape[1]

        print(f"\nBuilding FAISS index...")
        print(f" Dimension: {dimension}")
        print(f" Vectors: {len(embeddings):,}")
        print(f" Metric: INNER_PRODUCT")

        # Flat (exact, non-quantized) index with inner-product scoring.
        index = faiss.IndexFlatIP(dimension)

        print("\nAdding vectors...")
        index.add(embeddings)

        index_path = self.cache_dir / 'main_index.faiss'
        faiss.write_index(index, str(index_path))

        print(f"\nβœ… Saved FAISS index: {index_path}")
        print(f" Total vectors: {index.ntotal:,}")

        return True

    def verify_fixes(self):
        """Sanity-check normalization, index metric, and a self-match search."""
        self.banner("VERIFYING FIXES")

        try:
            embeddings = np.load(self.cache_dir / 'embeddings.npy')
            norms = np.linalg.norm(embeddings, axis=1)

            print("πŸ“Š Embeddings:")
            print(f" Mean norm: {norms.mean():.6f}")
            print(f" Std norm: {norms.std():.6f}")

            if abs(norms.mean() - 1.0) < 0.01 and norms.std() < 0.01:
                print(" βœ… Properly normalized")
            else:
                print(" ❌ Still not normalized properly")
                return False

            index = faiss.read_index(str(self.cache_dir / 'main_index.faiss'))

            print(f"\nπŸ“Š FAISS Index:")
            print(f" Vectors: {index.ntotal:,}")
            print(f" Dimension: {index.d}")

            metric = index.metric_type
            if metric == faiss.METRIC_INNER_PRODUCT:
                print(" βœ… Using INNER_PRODUCT")
            else:
                print(f" ❌ Wrong metric: {metric}")
                return False

            # Searching with a stored vector should return itself with
            # near-perfect similarity (~1.0 for normalized vectors).
            print("\nπŸ” Testing search...")
            query = embeddings[0:1]
            distances, indices = index.search(query, 5)

            print(f" Top result index: {indices[0][0]}")
            print(f" Top result score: {distances[0][0]:.6f}")

            if distances[0][0] > 0.95:  # Should match itself almost perfectly
                print(" βœ… Search working correctly")
            else:
                print(" ⚠️ Unexpected similarity score")

            print("\nβœ… ALL CHECKS PASSED!")
            return True

        except Exception as e:
            print(f"\n❌ Verification failed: {e}")
            return False

    def full_fix(self):
        """Backup, normalize, rebuild and verify in one pass."""
        self.banner("RUNNING FULL FIX")

        print("This will:")
        print("1. Backup existing files")
        print("2. Normalize embeddings")
        print("3. Rebuild FAISS index")
        print("4. Verify fixes")

        # Give the user a moment to abort with Ctrl-C.
        print("\nStarting in 3 seconds...")
        import time
        time.sleep(3)

        self.backup_files()

        if not self.normalize_embeddings():
            print("\n❌ Failed to normalize embeddings")
            return False

        if not self.rebuild_faiss_index():
            print("\n❌ Failed to rebuild index")
            return False

        if not self.verify_fixes():
            print("\n❌ Fixes did not work properly")
            return False

        print("\n" + "="*80)
        print("βœ… ALL FIXES COMPLETED SUCCESSFULLY!")
        print("="*80)
        print("\nNext steps:")
        print("1. Restart your API server: python api_server.py")
        print("2. Test classification with a known category")
        print("3. Check confidence scores")
        print("\nIf issues persist, run diagnostics:")
        print(" python diagnose_and_fix.py")
        print("="*80 + "\n")

        return True
+
232
+
233
def main():
    """Parse the single CLI argument and dispatch to the matching fixer action."""
    if len(sys.argv) < 2:
        # No command given: print usage and bail out with a failure code.
        print("\n" + "="*80)
        print("πŸ”§ EMBEDDING & INDEX FIXER")
        print("="*80)
        print("\nUsage:")
        print(" python fix_embeddings.py normalize # Fix normalization only")
        print(" python fix_embeddings.py rebuild-index # Rebuild FAISS index")
        print(" python fix_embeddings.py full-fix # Do everything (recommended)")
        print("\nExample:")
        print(" python fix_embeddings.py full-fix")
        print("="*80 + "\n")
        sys.exit(1)

    command = sys.argv[1].lower()
    fixer = EmbeddingFixer()

    if command == 'full-fix':
        fixer.full_fix()
    elif command in ('normalize', 'rebuild-index'):
        # Both single-step commands share the backup/verify bracket.
        fixer.backup_files()
        if command == 'normalize':
            fixer.normalize_embeddings()
        else:
            fixer.rebuild_faiss_index()
        fixer.verify_fixes()
    else:
        print(f"❌ Unknown command: {command}")
        print("Use: normalize, rebuild-index, or full-fix")
        sys.exit(1)


if __name__ == "__main__":
    main()
gradio_app.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Gradio App for Product Category Classification
4
+ Model: intfloat/e5-base-v2 (must match training)
5
+ Requires: pip install gradio sentence-transformers faiss-cpu numpy pickle5
6
+ """
7
+
8
+ import gradio as gr
9
+ from sentence_transformers import SentenceTransformer
10
+ import faiss
11
+ import pickle
12
+ import numpy as np
13
+ import re
14
+ from pathlib import Path
15
+ import time
16
+
17
# ====================================================================
# CONFIG
# ====================================================================
CACHE_DIR = Path("cache")                           # root for cached artifacts
MODEL_NAME = "intfloat/e5-base-v2"                  # must match the training-time encoder
FAISS_INDEX_PATH = CACHE_DIR / "main_index.faiss"   # vector index over categories
METADATA_PATH = CACHE_DIR / "metadata.pkl"          # per-category metadata list
SYN_PATH = CACHE_DIR / "cross_store_synonyms.pkl"   # optional pre-built synonym map

# Module-level singletons; populated once by load_model().
encoder = None
faiss_index = None
metadata = []
cross_store_synonyms = {}
+
31
+ # ====================================================================
32
+ # UTILITIES
33
+ # ====================================================================
34
def clean_text(text: str) -> str:
    """Lowercase *text*, replace punctuation with spaces, collapse whitespace.

    Hyphens and word characters are kept; falsy input yields "".
    """
    if not text:
        return ""
    lowered = str(text).lower()
    depunct = re.sub(r"[^\w\s-]", " ", lowered)
    return re.sub(r"\s+", " ", depunct).strip()
41
+
42
def build_cross_store_synonyms():
    """Build a bidirectional synonym map: term -> set of equivalent terms.

    Each seed entry is expanded so every synonym also maps back to the seed
    term and to its sibling synonyms.

    BUG FIX: the original did ``expanded[term] = set(syns)`` for each seed,
    which clobbered any synonym set already accumulated when *term* had
    appeared earlier as another entry's synonym (e.g. 'laundry machine'
    lost 'washing appliance' inherited from 'washing machine'). Buckets are
    now merged instead of overwritten.
    """
    synonyms = {
        'washing machine': {'laundry machine', 'washer', 'clothes washer', 'washing appliance'},
        'laundry machine': {'washing machine', 'washer', 'clothes washer'},
        'dryer': {'drying machine', 'clothes dryer', 'tumble dryer'},
        'refrigerator': {'fridge', 'cooler', 'ice box', 'cooling appliance'},
        'dishwasher': {'dish washer', 'dish cleaning machine'},
        'microwave': {'microwave oven', 'micro wave'},
        'vacuum': {'vacuum cleaner', 'hoover', 'vac'},
        'tv': {'television', 'telly', 'smart tv', 'display'},
        'laptop': {'notebook', 'portable computer', 'laptop computer'},
        'mobile': {'phone', 'cell phone', 'smartphone', 'cellphone'},
        'tablet': {'ipad', 'tab', 'tablet computer'},
        'headphones': {'headset', 'earphones', 'earbuds', 'ear buds'},
        'speaker': {'audio speaker', 'sound system', 'speakers'},
        'sofa': {'couch', 'settee', 'divan'},
        'wardrobe': {'closet', 'armoire', 'cupboard'},
        'drawer': {'chest of drawers', 'dresser'},
        'pants': {'trousers', 'slacks', 'bottoms'},
        'sweater': {'jumper', 'pullover', 'sweatshirt'},
        'sneakers': {'trainers', 'tennis shoes', 'running shoes'},
        'jacket': {'coat', 'blazer', 'outerwear'},
        'cooker': {'stove', 'range', 'cooking range'},
        'blender': {'mixer', 'food processor', 'liquidizer'},
        'kettle': {'electric kettle', 'water boiler'},
        'stroller': {'pram', 'pushchair', 'buggy', 'baby carriage'},
        'diaper': {'nappy', 'nappies'},
        'pacifier': {'dummy', 'soother'},
        'wrench': {'spanner', 'adjustable wrench'},
        'flashlight': {'torch', 'flash light'},
        'screwdriver': {'screw driver'},
        'tap': {'faucet', 'water tap'},
        'bin': {'trash can', 'garbage can', 'waste bin'},
        'curtain': {'drape', 'window covering'},
        'guillotine': {'paper cutter', 'paper trimmer', 'blade cutter'},
        'trimmer': {'cutter', 'cutting tool', 'edge cutter'},
        'stapler': {'stapling machine', 'staple gun'},
        'magazine': {'periodical', 'journal', 'publication'},
        'comic': {'comic book', 'graphic novel', 'manga'},
        'ebook': {'e-book', 'digital book', 'electronic book'},
        'kids': {'children', 'child', 'childrens', 'youth', 'junior'},
        'women': {'womens', 'ladies', 'female', 'lady'},
        'men': {'mens', 'male', 'gentleman'},
        'baby': {'infant', 'newborn', 'toddler'},
    }

    expanded = {}
    for term, syns in synonyms.items():
        # Merge (not overwrite) so earlier reverse-links are preserved.
        expanded.setdefault(term, set()).update(syns)
        for syn in syns:
            bucket = expanded.setdefault(syn, set())
            bucket.add(term)                 # link back to the seed term
            bucket.update(syns - {syn})      # link to sibling synonyms
    return expanded
97
+
98
def extract_cross_store_terms(text: str):
    """Collect the cleaned text, its tokens, bigrams/trigrams and known synonyms."""
    cleaned = clean_text(text)
    tokens = cleaned.split()

    terms = {cleaned}

    # Single tokens plus their synonym expansions.
    # NOTE(review): the synonym lookup is gated by len > 2, so 2-letter
    # keys such as 'tv' are never expanded here — confirm intent.
    for token in tokens:
        if len(token) > 2:
            terms.add(token)
            if token in cross_store_synonyms:
                terms.update(cross_store_synonyms[token])

    # Bigrams, with synonym expansion when the phrase is a known key.
    for left, right in zip(tokens, tokens[1:]):
        phrase = f"{left} {right}"
        terms.add(phrase)
        if phrase in cross_store_synonyms:
            terms.update(cross_store_synonyms[phrase])

    # Trigrams are added verbatim (no synonym lookup at this length).
    for a, b, c in zip(tokens, tokens[1:], tokens[2:]):
        terms.add(f"{a} {b} {c}")

    return list(terms)
117
+
118
def build_enhanced_query(title, description="", max_synonyms=10):
    """Build the retrieval query: title weighted 3x plus synonym expansions.

    Returns (query_string, first_20_expansion_terms).
    """
    title_clean = clean_text(title)
    desc_clean = clean_text(description)
    expansions = extract_cross_store_terms(f"{title_clean} {desc_clean}")
    # Repeat the title to bias the embedding toward the product name.
    parts = [title_clean, title_clean, title_clean] + expansions[:max_synonyms]
    return ' '.join(parts), expansions[:20]
124
+
125
def encode_query(text: str):
    """Encode *text* into a (1, dim) float32 L2-normalized embedding."""
    vec = encoder.encode(text, convert_to_numpy=True, normalize_embeddings=True)
    # FAISS expects a 2-D float32 batch even for a single query.
    return np.atleast_2d(vec).astype('float32')
130
+
131
+ # ====================================================================
132
+ # CLASSIFICATION
133
+ # ====================================================================
134
def classify_product(title, description="", top_k=5):
    """Classify a product against the FAISS category index.

    Returns a dict with the best category, a confidence label, the matched
    synonym terms, the raw top-k candidates and the processing time.
    """
    start_time = time.time()
    query_text, matched_terms = build_enhanced_query(title, description)
    # NOTE(review): e5 models are documented to expect a "query: " prefix at
    # inference time (the diagnostics script searches with one) — confirm
    # how the index passages were encoded before adding it here.
    query_embedding = encode_query(query_text)
    distances, indices = faiss_index.search(query_embedding, top_k)

    results = []
    for i, idx in enumerate(indices[0]):
        if idx >= len(metadata):
            continue
        meta = metadata[idx]
        # BUG FIX: the index is an IndexFlatIP over L2-normalized vectors,
        # so the returned "distance" already IS the cosine similarity.
        # The old `1 - distance` inverted the confidence scale (good
        # matches reported low confidence).
        similarity = float(distances[0][i])
        confidence_pct = similarity * 100
        final_product = meta.get('levels', [])[-1] if meta.get('levels') else meta['category_path'].split('/')[-1]
        results.append({
            'rank': i+1,
            'category_id': str(meta['category_id']),
            'category_path': meta['category_path'],
            'final_product': final_product,
            'confidence': round(confidence_pct, 2),
            'depth': meta.get('depth', 0)
        })

    if not results:
        return {
            'error': 'No results found',
            'product': title
        }

    # Map the best score onto a human-readable confidence bucket.
    best = results[0]
    conf_pct = best['confidence']
    if conf_pct >= 90:
        conf_level = "EXCELLENT"
    elif conf_pct >= 85:
        conf_level = "VERY HIGH"
    elif conf_pct >= 80:
        conf_level = "HIGH"
    elif conf_pct >= 75:
        conf_level = "GOOD"
    elif conf_pct >= 70:
        conf_level = "MEDIUM"
    else:
        conf_level = "LOW"

    processing_time = (time.time() - start_time) * 1000

    return {
        'product': title,
        'category_id': best['category_id'],
        'category_path': best['category_path'],
        'final_product': best['final_product'],
        'confidence': f"{conf_level} ({conf_pct:.2f}%)",
        'confidence_percent': conf_pct,
        'depth': best['depth'],
        'matched_terms': matched_terms,
        'top_5_results': results,
        'processing_time_ms': round(processing_time, 2)
    }
192
+
193
+ # ====================================================================
194
+ # LOAD MODEL & INDEX
195
+ # ====================================================================
196
def load_model():
    """Populate the module-level encoder, index, metadata and synonym map."""
    global encoder, faiss_index, metadata, cross_store_synonyms

    print("Loading sentence-transformer model...")
    encoder = SentenceTransformer(MODEL_NAME)
    print("Model loaded.")

    print("Loading FAISS index...")
    faiss_index = faiss.read_index(str(FAISS_INDEX_PATH))
    print(f"FAISS index loaded: {faiss_index.ntotal} vectors.")

    print("Loading metadata...")
    metadata = pickle.loads(METADATA_PATH.read_bytes())
    print(f"Metadata loaded: {len(metadata)} categories.")

    # Prefer the pre-built synonym pickle; fall back to the in-code defaults.
    print("Loading cross-store synonyms...")
    if SYN_PATH.exists():
        cross_store_synonyms = pickle.loads(SYN_PATH.read_bytes())
        print(f"Loaded {len(cross_store_synonyms)} synonyms from file.")
    else:
        cross_store_synonyms = build_cross_store_synonyms()
        print(f"Built {len(cross_store_synonyms)} default synonyms.")
219
+
220
+ # ====================================================================
221
+ # GRADIO FUNCTION
222
+ # ====================================================================
223
def classify_gradio(title, description=""):
    """Adapter: map classify_product() output onto the five Gradio fields."""
    result = classify_product(title, description)

    matched = result.get('matched_terms')
    alternatives = ''.join(
        f"{item['rank']}. {item['final_product']} (ID: {item['category_id']}, Confidence: {item['confidence']}%)\n"
        for item in result.get('top_5_results', [])
    )

    return (
        str(result.get('final_product', '')),
        str(result.get('category_path', '')),
        str(result.get('confidence', '')),
        ', '.join(matched) if matched else '',
        alternatives,
    )
233
+
234
+ # ====================================================================
235
+ # MAIN GRADIO APP
236
+ # ====================================================================
237
def main():
    """Load all artifacts, then serve the Gradio classification UI."""
    load_model()

    input_widgets = [
        gr.Textbox(label="Product Title"),
        gr.Textbox(label="Description"),
    ]
    output_widgets = [
        gr.Textbox(label="Predicted Product"),
        gr.Textbox(label="Category Path"),
        gr.Textbox(label="Confidence"),
        gr.Textbox(label="Matched Terms"),
        gr.Textbox(label="Top 5 Alternatives"),
    ]

    iface = gr.Interface(
        fn=classify_gradio,
        inputs=input_widgets,
        outputs=output_widgets,
        title="🎯 Product Category Classifier",
        description="Classify products with full cross-store synonyms and embeddings",
    )
    # Launch with a public shareable link
    iface.launch(share=True)


if __name__ == "__main__":
    main()
miss.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ πŸ”¬ ADVANCED MODEL DIAGNOSTICS & AUTOMATIC FIXES
3
+ ===============================================
4
+ Diagnoses and fixes common issues causing low confidence/accuracy
5
+
6
+ Usage:
7
+ python diagnose_and_fix.py
8
+ """
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import pickle
13
+ import json
14
+ import faiss
15
+ from pathlib import Path
16
+ from sentence_transformers import SentenceTransformer
17
+ from collections import defaultdict, Counter
18
+ from tqdm import tqdm
19
+ import warnings
20
+ warnings.filterwarnings('ignore')
21
+
22
class ModelDiagnostics:
    """Run a battery of health checks against the cached model artifacts.

    Each check prints its findings and appends structured issue records
    (type/issue/details/fix) to ``self.issues``; a report and fix
    suggestions are produced at the end.
    """

    def __init__(self, cache_dir='cache', data_dir='data'):
        self.cache_dir = Path(cache_dir)   # embeddings / index / metadata
        self.data_dir = Path(data_dir)     # source CSVs
        self.issues = []                   # accumulated issue dicts
        self.fixes_applied = []            # reserved for future auto-fixes

    def banner(self, text):
        """Print a framed section header."""
        print("\n" + "="*80)
        print(f"πŸ” {text}")
        print("="*80 + "\n")

    def check_embedding_normalization(self):
        """Verify every embedding row has (near) unit L2 norm."""
        self.banner("CHECKING EMBEDDING NORMALIZATION")

        try:
            embeddings = np.load(self.cache_dir / 'embeddings.npy')

            norms = np.linalg.norm(embeddings, axis=1)

            print(f"πŸ“Š Embedding Statistics:")
            print(f" Shape: {embeddings.shape}")
            print(f" Mean norm: {norms.mean():.6f}")
            print(f" Std norm: {norms.std():.6f}")
            print(f" Min norm: {norms.min():.6f}")
            print(f" Max norm: {norms.max():.6f}")

            # Should be ~1.0 with near-zero spread if normalized.
            if abs(norms.mean() - 1.0) > 0.01 or norms.std() > 0.01:
                self.issues.append({
                    'type': 'CRITICAL',
                    'issue': 'Embeddings not normalized',
                    'details': f'Mean norm: {norms.mean():.6f} (should be ~1.0)',
                    'fix': 'Re-normalize embeddings'
                })
                print(" ❌ ISSUE: Embeddings are NOT normalized!")
                print(" This causes incorrect similarity scores")
                return False
            else:
                print(" βœ… Embeddings properly normalized")
                return True

        except Exception as e:
            print(f" ❌ Error: {e}")
            return False

    def check_faiss_metric(self):
        """Verify the index was built with METRIC_INNER_PRODUCT."""
        self.banner("CHECKING FAISS INDEX METRIC")

        try:
            index = faiss.read_index(str(self.cache_dir / 'main_index.faiss'))

            metric = index.metric_type

            print(f"πŸ“Š FAISS Index:")
            print(f" Vectors: {index.ntotal:,}")
            print(f" Dimension: {index.d}")
            print(f" Metric type: {metric}")

            if metric == faiss.METRIC_INNER_PRODUCT:
                print(" βœ… Using INNER_PRODUCT (correct for normalized vectors)")
                return True
            elif metric == faiss.METRIC_L2:
                self.issues.append({
                    'type': 'CRITICAL',
                    'issue': 'Wrong FAISS metric',
                    'details': 'Using L2 distance instead of inner product',
                    'fix': 'Rebuild index with METRIC_INNER_PRODUCT'
                })
                print(" ❌ ISSUE: Using L2 distance!")
                print(" Should use INNER_PRODUCT for normalized vectors")
                return False
            else:
                print(f" ⚠️ Unknown metric: {metric}")
                return False

        except Exception as e:
            print(f" ❌ Error: {e}")
            return False

    def check_text_weighting(self):
        """Spot-check one metadata record for auto-tags and structure."""
        self.banner("CHECKING TEXT CONSTRUCTION")

        try:
            with open(self.cache_dir / 'metadata.pkl', 'rb') as f:
                metadata = pickle.load(f)

            # Only the first record is inspected as a representative sample.
            sample = metadata[0]

            print(f"πŸ“Š Sample Category:")
            print(f" ID: {sample.get('category_id')}")
            print(f" Path: {sample.get('category_path')}")
            print(f" Depth: {sample.get('depth')}")
            print(f" Levels: {sample.get('levels')}")

            if 'auto_tags' in sample and sample['auto_tags']:
                print(f" Tags: {len(sample['auto_tags'])} tags")
                print(f" Sample tags: {sample['auto_tags'][:5]}")
                print(" βœ… Auto-tags present")
            else:
                self.issues.append({
                    'type': 'WARNING',
                    'issue': 'Missing auto-tags',
                    'details': 'Categories lack auto-generated tags',
                    'fix': 'Generate tags from category paths'
                })
                print(" ⚠️ No auto-tags found")

            return True

        except Exception as e:
            print(f" ❌ Error: {e}")
            return False

    def test_predictions(self, num_samples=100):
        """Query the index with leaf category names and measure top-1 recall."""
        self.banner("TESTING PREDICTION ACCURACY")

        try:
            print("Loading model and index...")
            encoder = SentenceTransformer('intfloat/e5-base-v2')
            index = faiss.read_index(str(self.cache_dir / 'main_index.faiss'))

            with open(self.cache_dir / 'metadata.pkl', 'rb') as f:
                metadata = pickle.load(f)

            csv_files = list(self.data_dir.glob('*.csv'))
            if not csv_files:
                print(" ❌ No CSV files found in data/")
                return False

            df = pd.read_csv(csv_files[0])

            # Random sample of categories to probe.
            samples = df.sample(min(num_samples, len(df)))

            correct = 0
            confidence_scores = []
            rank_positions = []

            print(f"Testing {len(samples)} random categories...\n")

            for row_idx, row in tqdm(samples.iterrows(), total=len(samples)):
                cat_id = str(row.iloc[0])   # First column
                cat_path = str(row.iloc[1])  # Second column

                # Leaf category (final product name) is used as the query.
                leaf = cat_path.split('/')[-1].strip()

                # e5 models expect the "query: " prefix at inference time.
                query = f"query: {leaf}"

                query_emb = encoder.encode(query, normalize_embeddings=True)
                query_emb = query_emb.reshape(1, -1).astype('float32')

                distances, indices = index.search(query_emb, 10)

                # BUG FIX: the inner loop variable was named `idx`,
                # shadowing the outer iterrows() index; renamed to `hit`.
                found_rank = None
                for rank, hit in enumerate(indices[0]):
                    pred_id = str(metadata[hit]['category_id'])
                    if pred_id == cat_id:
                        found_rank = rank + 1
                        correct += 1
                        confidence_scores.append(float(distances[0][rank]))
                        break

                if found_rank:
                    rank_positions.append(found_rank)
                else:
                    rank_positions.append(11)  # Not in top 10

            accuracy = (correct / len(samples)) * 100
            avg_confidence = np.mean(confidence_scores) if confidence_scores else 0

            print(f"\nπŸ“Š Results:")
            print(f" Accuracy (Top-1): {accuracy:.2f}%")
            print(f" Correct predictions: {correct}/{len(samples)}")
            print(f" Average confidence: {avg_confidence:.4f}")

            if confidence_scores:
                print(f" Min confidence: {min(confidence_scores):.4f}")
                print(f" Max confidence: {max(confidence_scores):.4f}")

            rank_counts = Counter(rank_positions)
            print(f"\n Rank Distribution:")
            for rank in sorted(rank_counts.keys())[:5]:
                count = rank_counts[rank]
                pct = (count / len(samples)) * 100
                print(f" Rank {rank}: {count} ({pct:.1f}%)")

            if accuracy < 70:
                self.issues.append({
                    'type': 'CRITICAL',
                    'issue': 'Low prediction accuracy',
                    'details': f'Only {accuracy:.1f}% accuracy',
                    'fix': 'Retrain with better text weighting'
                })
                print(f"\n ❌ ISSUE: Low accuracy ({accuracy:.1f}%)")
                return False
            elif accuracy < 85:
                self.issues.append({
                    'type': 'WARNING',
                    'issue': 'Moderate accuracy',
                    'details': f'Accuracy: {accuracy:.1f}%',
                    'fix': 'Consider retraining with optimizations'
                })
                print(f"\n ⚠️ Moderate accuracy ({accuracy:.1f}%)")
                return True
            else:
                print(f"\n βœ… Good accuracy ({accuracy:.1f}%)")
                return True

        except Exception as e:
            print(f" ❌ Error: {e}")
            import traceback
            traceback.print_exc()
            return False

    def analyze_category_distribution(self):
        """Summarize category depths and warn on large depth spread."""
        self.banner("ANALYZING CATEGORY STRUCTURE")

        try:
            with open(self.cache_dir / 'metadata.pkl', 'rb') as f:
                metadata = pickle.load(f)

            depths = [m.get('depth', 0) for m in metadata]

            print(f"πŸ“Š Category Structure:")
            print(f" Total categories: {len(metadata):,}")
            print(f" Average depth: {np.mean(depths):.2f}")
            print(f" Min depth: {min(depths)}")
            print(f" Max depth: {max(depths)}")

            depth_counts = Counter(depths)
            print(f"\n Depth Distribution:")
            for depth in sorted(depth_counts.keys())[:8]:
                count = depth_counts[depth]
                pct = (count / len(metadata)) * 100
                print(f" Depth {depth}: {count:,} ({pct:.1f}%)")

            # A wide depth spread can bias similarity toward shallow paths.
            if max(depths) - min(depths) > 5:
                self.issues.append({
                    'type': 'WARNING',
                    'issue': 'Large depth variation',
                    'details': f'Depth ranges from {min(depths)} to {max(depths)}',
                    'fix': 'Consider depth-based weighting'
                })
                print(f"\n ⚠️ Large depth variation detected")

            return True

        except Exception as e:
            print(f" ❌ Error: {e}")
            return False

    def check_duplicate_embeddings(self):
        """Sample embeddings and count near-identical pairs (>0.99 cosine)."""
        self.banner("CHECKING FOR DUPLICATE EMBEDDINGS")

        try:
            embeddings = np.load(self.cache_dir / 'embeddings.npy')

            # Sample check (checking all pairs would be O(n^2) and too slow).
            sample_size = min(1000, len(embeddings))
            sample_indices = np.random.choice(len(embeddings), sample_size, replace=False)
            sample_embs = embeddings[sample_indices]

            similarities = np.dot(sample_embs, sample_embs.T)

            # Zero the diagonal so self-similarity is not counted.
            np.fill_diagonal(similarities, 0)
            high_sim = (similarities > 0.99).sum() // 2  # matrix is symmetric

            print(f"πŸ“Š Duplicate Check (sample of {sample_size}):")
            print(f" Very similar pairs (>0.99): {high_sim}")

            if high_sim > sample_size * 0.05:  # >5% duplicates
                self.issues.append({
                    'type': 'WARNING',
                    'issue': 'Many duplicate embeddings',
                    'details': f'{high_sim} pairs with >0.99 similarity',
                    'fix': 'Check for duplicate categories or improve text diversity'
                })
                print(f" ⚠️ Many near-duplicates detected")
                return False
            else:
                print(f" βœ… Low duplicate rate")
                return True

        except Exception as e:
            print(f" ❌ Error: {e}")
            return False

    def generate_report(self):
        """Print all accumulated issues grouped by severity."""
        self.banner("DIAGNOSTIC REPORT")

        if not self.issues:
            print("βœ… NO ISSUES FOUND!")
            print("\nYour model appears to be properly configured.")
            return

        critical = [i for i in self.issues if i['type'] == 'CRITICAL']
        warnings = [i for i in self.issues if i['type'] == 'WARNING']

        if critical:
            print("πŸ”΄ CRITICAL ISSUES:")
            for i, issue in enumerate(critical, 1):
                print(f"\n{i}. {issue['issue']}")
                print(f" Details: {issue['details']}")
                print(f" Fix: {issue['fix']}")

        if warnings:
            print("\n🟑 WARNINGS:")
            for i, issue in enumerate(warnings, 1):
                print(f"\n{i}. {issue['issue']}")
                print(f" Details: {issue['details']}")
                print(f" Fix: {issue['fix']}")

        print(f"\nπŸ“Š Summary:")
        print(f" Critical issues: {len(critical)}")
        print(f" Warnings: {len(warnings)}")

    def suggest_fixes(self):
        """Print the shell commands matching the issues that were found."""
        self.banner("RECOMMENDED FIXES")

        if not self.issues:
            print("βœ… No fixes needed!")
            return

        print("Run these commands to fix issues:\n")

        critical = [i for i in self.issues if i['type'] == 'CRITICAL']

        if any('normalization' in i['issue'].lower() for i in critical):
            print("1️⃣ Fix embedding normalization:")
            print(" python fix_embeddings.py normalize")
            print()

        if any('faiss' in i['issue'].lower() for i in critical):
            print("2️⃣ Rebuild FAISS index with correct metric:")
            print(" python fix_embeddings.py rebuild-index")
            print()

        if any('accuracy' in i['issue'].lower() for i in critical):
            print("3️⃣ Retrain with improved settings:")
            print(" python train_fixed_v2.py data/categories.csv data/tags.json")
            print()

        if any('tags' in i['issue'].lower() for i in self.issues):
            print("4️⃣ Generate missing tags:")
            print(" python generate_tags.py data/categories.csv")
            print()

    def run_full_diagnostics(self):
        """Run every check in order, then the report and fix suggestions."""
        print("\n" + "="*80)
        print("πŸ”¬ COMPREHENSIVE MODEL DIAGNOSTICS")
        print("="*80)

        self.check_embedding_normalization()
        self.check_faiss_metric()
        self.check_text_weighting()
        self.analyze_category_distribution()
        self.check_duplicate_embeddings()
        self.test_predictions(num_samples=50)

        self.generate_report()
        self.suggest_fixes()

        print("\n" + "="*80)
        print("🎯 DIAGNOSTICS COMPLETE")
        print("="*80 + "\n")
417
+
418
+
419
if __name__ == "__main__":
    # Instantiate with default cache/data dirs and run the full suite.
    ModelDiagnostics().run_full_diagnostics()
path.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import pandas as pd
4
+ import json
5
+ import re
6
+ from tqdm import tqdm
7
+
8
+
9
class HybridTagsGenerator:
    """Build retrieval-friendly tag lists for hierarchical category paths.

    Tags for each category come from three sources:
      1. hierarchy tags (full path, progressive prefixes, parent/child pairs),
      2. key terms and bigrams extracted from the path text,
      3. templated search-intent phrases for the leaf level.
    """

    def __init__(self):
        # Search intent templates β€” E5-style encoders respond well to natural queries.
        self.search_intents = [
            "buy {item}",
            "best {item}",
            "{item} reviews",
        ]

    def clean(self, text):
        """Lowercase *text*, replace punctuation (except hyphens) with spaces, squeeze whitespace."""
        lowered = str(text).lower()
        no_punct = re.sub(r"[^\w\s-]", " ", lowered)
        return re.sub(r"\s+", " ", no_punct).strip()

    # -------------------------------------------------------
    # 1. Hierarchical tag boosting
    # -------------------------------------------------------
    def make_hierarchy_tags(self, path):
        """Return hierarchy-derived tags for a slash-separated category path.

        The full cleaned path is repeated 8x as a weighting signal.
        NOTE(review): build_tags() de-duplicates, so the 8x repetition only
        survives when this method is used on its own β€” confirm intent.
        """
        levels = [part.strip() for part in path.split("/") if part.strip()]
        cleaned_levels = [self.clean(part) for part in levels]

        # Strong full-path signal (repeated for emphasis).
        tags = [" ".join(cleaned_levels)] * 8

        # Progressive prefixes: "a", "a b", "a b c", ...
        tags += [
            " ".join(cleaned_levels[:depth])
            for depth in range(1, len(cleaned_levels) + 1)
        ]

        # Parent/child reinforcement phrases for the two deepest levels.
        if len(cleaned_levels) >= 2:
            parent, child = cleaned_levels[-2], cleaned_levels[-1]
            tags += [
                f"{parent} {child}",
                f"{child} {parent}",
                f"{child} in {parent}",
                f"{child} category {parent}",
            ]

        return tags

    # -------------------------------------------------------
    # 2. Extract key terms and word combos
    # -------------------------------------------------------
    def extract_terms(self, path):
        """Extract unique key terms: cleaned levels, words longer than 3 chars, leaf-area bigrams."""
        levels = [part.strip() for part in path.split("/") if part.strip()]
        collected = []

        for level in levels:
            cleaned = self.clean(level)
            if cleaned not in collected:
                collected.append(cleaned)

            long_words = [w for w in cleaned.split() if len(w) > 3]
            collected += long_words

            # Bigrams only for the leaf and its parent level.
            if level in levels[-2:]:
                collected += [f"{a} {b}" for a, b in zip(long_words, long_words[1:])]

        # De-duplicate while preserving first-seen order.
        return list(dict.fromkeys(collected))

    # -------------------------------------------------------
    # 3. Build final tag list for ONE category
    # -------------------------------------------------------
    def build_tags(self, category_id, category_path):
        """Assemble the final tag list: deduped, at most six words each, capped at 50.

        *category_id* is accepted for interface symmetry but not used here.
        """
        candidates = list(self.make_hierarchy_tags(category_path))

        # Key terms (capped at 15).
        candidates += self.extract_terms(category_path)[:15]

        # Search-intent phrases for the leaf level (first two templates only).
        leaf = self.clean(category_path.split("/")[-1])
        candidates += [tpl.format(item=leaf) for tpl in self.search_intents[:2]]

        # Clean + dedupe + drop over-long tags + cap at 50.
        seen = set()
        final = []
        for raw in candidates:
            tag = self.clean(raw)
            if tag and tag not in seen and len(tag.split()) <= 6:
                seen.add(tag)
                final.append(tag)

        return final[:50]

    # -------------------------------------------------------
    # 4. Generate tags.json for entire CSV
    # -------------------------------------------------------
    def generate_tags_json(self, csv_path, output="tags.json"):
        """Read a categories CSV and write a {Category_ID: [tags]} JSON file.

        The CSV must contain ``Category_ID`` and ``Category_path`` columns.
        Returns the generated mapping.
        """
        frame = pd.read_csv(csv_path, dtype=str)

        if "Category_ID" not in frame.columns or "Category_path" not in frame.columns:
            raise ValueError("CSV must contain Category_ID, Category_path columns")

        frame = frame.dropna(subset=["Category_path"])

        tags_dict = {}
        for _, record in tqdm(frame.iterrows(), total=len(frame), desc="Building tags"):
            key = str(record["Category_ID"])
            tags_dict[key] = self.build_tags(key, str(record["Category_path"]))

        with open(output, "w", encoding="utf-8") as handle:
            json.dump(tags_dict, handle, indent=2)

        print(f"βœ… DONE: {output} saved.")
        return tags_dict
+
132
+
133
if __name__ == "__main__":
    # CLI entry point: build tags.json from the categories CSV given on the command line.
    import sys
    if len(sys.argv) < 2:
        print("Usage: python build_tags_json.py <categories.csv>")
        sys.exit()

    csv_file = sys.argv[1]
    gen = HybridTagsGenerator()
    gen.generate_tags_json(csv_file, "tags.json")
requirements.txt ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # sentence-transformers==3.3.1
2
+ # torch==2.5.1
3
+ # transformers==4.46.3
4
+ # faiss-gpu==1.9.0.post1
5
+ # pandas==2.2.3
6
+ # numpy==2.0.2
7
+ # fastapi==0.115.6
8
+ # uvicorn==0.32.1
9
+ # gunicorn==23.0.0
10
+ # pydantic==2.10.3
11
+ # joblib==1.4.2
12
+ # psutil==6.1.0
13
+
14
+
15
+ sentence-transformers==3.3.1
16
+ torch==2.5.1
17
+ transformers==4.46.3
18
+ faiss-cpu==1.9.0
19
+ pandas==2.2.3
20
+ numpy==2.0.2
21
+ fastapi==0.115.6
22
+ uvicorn==0.32.1
23
+ gunicorn==23.0.0
24
+ pydantic==2.10.3
25
+ joblib==1.4.2
26
+ psutil==6.1.0
27
+ nltk>=3.8.1
28
+ # Note: faiss-gpu is commented out to avoid compatibility issues on systems without a compatible GPU.
synonyms.py CHANGED
@@ -1,366 +1,854 @@
1
-
2
- """
3
- πŸ€– AI-POWERED SYNONYM MANAGER (Fixed for Windows + GPU)
4
- ========================================================
5
- βœ… Uses e5-base-v2 (768D, memory-efficient)
6
- βœ… Windows + NVIDIA GPU optimized
7
- βœ… Generates cross-store synonyms automatically
8
-
9
- Usage:
10
- python synonym_manager_fixed.py autobuild data/category_id_path_only.csv
11
- python synonym_manager_fixed.py autobuild data/category_id_path_only.csv --fast
12
- """
13
-
14
- import pickle
15
- from pathlib import Path
16
- import json
17
- from collections import defaultdict
18
- from tqdm import tqdm
19
- import warnings
20
- import sys
21
- import os
22
-
23
- warnings.filterwarnings('ignore')
24
- os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
25
-
26
- try:
27
- from nltk.corpus import wordnet
28
- from nltk import download as nltk_download
29
- WORDNET_AVAILABLE = True
30
- except ImportError:
31
- WORDNET_AVAILABLE = False
32
-
33
- try:
34
- from sentence_transformers import SentenceTransformer, util
35
- import torch
36
- TRANSFORMERS_AVAILABLE = True
37
- except ImportError:
38
- TRANSFORMERS_AVAILABLE = False
39
-
40
-
41
class SynonymManager:
    """AI-powered synonym manager.

    Builds and persists a ``{term: [(synonym, confidence, source), ...]}``
    map using two sources: WordNet lemmas and semantic similarity from the
    intfloat/e5-base-v2 sentence-transformer model.
    """

    def __init__(self, cache_dir='cache', fast_mode=False):
        # fast_mode skips WordNet lookups; semantic search still runs if a model loads.
        self.cache_dir = Path(cache_dir)
        self.synonyms_file = self.cache_dir / 'cross_store_synonyms.pkl'
        self.synonyms = {}        # term -> list of (synonym, confidence, source)
        self.model = None         # SentenceTransformer, loaded lazily
        self.device = "cpu"
        self.fast_mode = fast_mode

        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Pick up any previously saved synonym database.
        if self.synonyms_file.exists():
            self.load_synonyms()

    def load_synonyms(self):
        """Load existing synonyms from the pickle cache.

        Accepts the current format (lists of (syn, conf, source) tuples) and
        legacy formats (lists or sets of plain strings), converting legacy
        entries with a fixed 0.8 confidence.
        NOTE(review): if the first entry's value is an empty list, no branch
        assigns self.synonyms and the loaded data is dropped β€” confirm intended.
        """
        try:
            with open(self.synonyms_file, 'rb') as f:
                loaded = pickle.load(f)

            if loaded and list(loaded.values()):
                # Inspect one value to detect which on-disk format this is.
                first_val = next(iter(loaded.values()))

                if isinstance(first_val, list) and first_val:
                    if isinstance(first_val[0], tuple):
                        # Already in the current (syn, conf, source) format.
                        self.synonyms = loaded
                    else:
                        # Legacy list-of-strings format.
                        self.synonyms = {k: [(v, 0.8, 'legacy') for v in vals] for k, vals in loaded.items()}
                elif isinstance(first_val, set):
                    # Legacy set-of-strings format.
                    self.synonyms = {k: [(v, 0.8, 'legacy') for v in vals] for k, vals in loaded.items()}

            print(f"βœ… Loaded {len(self.synonyms):,} synonym entries")
        except Exception as e:
            print(f"❌ Error loading synonyms: {e}")
            self.synonyms = {}

    def save_synonyms(self):
        """Persist synonyms as pickle plus a human-readable JSON mirror.

        Returns True on success, False on any I/O or serialization error.
        """
        try:
            with open(self.synonyms_file, 'wb') as f:
                pickle.dump(self.synonyms, f)

            # Readable companion file for manual inspection.
            json_file = self.cache_dir / 'synonyms_readable.json'
            readable = {
                term: [
                    {'synonym': syn, 'confidence': conf, 'source': src}
                    for syn, conf, src in syns
                ]
                for term, syns in self.synonyms.items()
            }
            with open(json_file, 'w', encoding='utf-8') as f:
                json.dump(readable, f, indent=2, ensure_ascii=False)

            print(f"βœ… Saved {len(self.synonyms):,} synonym entries")
            return True
        except Exception as e:
            print(f"❌ Error saving synonyms: {e}")
            return False

    def load_transformer_model(self):
        """Load the e5-base-v2 sentence-transformer onto GPU if available.

        Returns True when the model is ready, False otherwise.
        """
        if not TRANSFORMERS_AVAILABLE:
            print("❌ SentenceTransformers not installed!")
            return False

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        if self.device == "cuda":
            print(f"πŸ”₯ NVIDIA GPU detected!")

        model_name = "intfloat/e5-base-v2"
        print(f"\nπŸ€– Loading {model_name}...")

        try:
            self.model = SentenceTransformer(model_name, device=self.device)

            if self.device == "cuda":
                # FP16 halves VRAM use and speeds up encoding on GPU.
                self.model = self.model.half()
                print("⚑ Enabled FP16 precision")

            print("βœ… Model loaded\n")
            return True
        except Exception as e:
            print(f"❌ Failed to load model: {e}")
            return False

    def get_wordnet_synonyms(self, word, limit=10):
        """Return up to *limit* WordNet synonyms as (syn, 0.75, 'wordnet') tuples.

        Returns [] in fast mode, when NLTK is missing, or on any error.
        """
        if self.fast_mode or not WORDNET_AVAILABLE:
            return []

        try:
            # Probe for the corpus; download it on first use.
            try:
                wordnet.synsets('test')
            except:
                nltk_download('wordnet', quiet=True)
                nltk_download('omw-1.4', quiet=True)

            synonyms = []
            # WordNet stores multi-word lemmas with underscores.
            word_clean = word.lower().replace(' ', '_')

            for syn in wordnet.synsets(word_clean):
                for lemma in syn.lemmas():
                    synonym = lemma.name().replace('_', ' ').lower()
                    if synonym != word.lower() and len(synonym) > 2:
                        confidence = 0.75
                        synonyms.append((synonym, confidence, 'wordnet'))
                        if len(synonyms) >= limit:
                            break
                if len(synonyms) >= limit:
                    break

            return synonyms[:limit]
        except Exception:
            return []

    def get_semantic_synonyms(self, term, candidate_pool, threshold=0.70, limit=15):
        """Return candidates semantically similar to *term* via E5 embeddings.

        Confidence is rescaled from the cosine score into roughly [0.60, 0.95].
        Returns [] when no model is loaded, the pool is empty, or on error.
        """
        if not self.model or not candidate_pool:
            return []

        try:
            # E5 models require "query:"/"passage:" prefixes for asymmetric search.
            query = f"query: {term}"
            candidates_prefixed = [f"passage: {c}" for c in candidate_pool]

            term_emb = self.model.encode(query, convert_to_tensor=True, show_progress_bar=False)

            # Batch candidate encoding; larger batches are safe on GPU.
            batch_size = 32 if self.device == "cuda" else 8
            all_embeddings = []

            for i in range(0, len(candidates_prefixed), batch_size):
                batch = candidates_prefixed[i:i + batch_size]
                emb = self.model.encode(batch, convert_to_tensor=True, show_progress_bar=False)
                all_embeddings.append(emb)

            candidate_embs = torch.cat(all_embeddings, dim=0)
            scores = util.cos_sim(term_emb, candidate_embs)[0]

            synonyms = []
            for candidate, score in zip(candidate_pool, scores):
                score_val = float(score)
                if score_val > threshold and candidate.lower() != term.lower():
                    # Linear rescale: score==threshold -> 0.60, score==1.0 -> 0.95.
                    confidence = 0.60 + (score_val - threshold) * 0.35 / (1 - threshold)
                    synonyms.append((candidate, confidence, 'semantic'))

            synonyms.sort(key=lambda x: x[1], reverse=True)
            return synonyms[:limit]

        except Exception as e:
            print(f"⚠️ Semantic error: {e}")
            return []

    def auto_generate_synonyms(self, term, candidate_pool=None, semantic_threshold=0.70, silent=False):
        """Combine WordNet and semantic synonyms, deduped by highest confidence.

        Returns a list of (synonym, confidence, source) sorted by confidence desc.
        """
        all_synonyms = []

        if not silent:
            print(f"\nπŸ” Finding synonyms for: '{term}'")

        if WORDNET_AVAILABLE and not self.fast_mode:
            wn_syns = self.get_wordnet_synonyms(term, limit=10)
            all_synonyms.extend(wn_syns)

        if candidate_pool and self.model:
            sem_syns = self.get_semantic_synonyms(
                term, candidate_pool,
                threshold=semantic_threshold,
                limit=15
            )
            all_synonyms.extend(sem_syns)

        # Dedupe case-insensitively, keeping the highest-confidence entry.
        synonym_map = {}
        for syn, conf, source in all_synonyms:
            syn_lower = syn.lower()
            if syn_lower not in synonym_map or conf > synonym_map[syn_lower][1]:
                synonym_map[syn_lower] = (syn, conf, source)

        final_synonyms = sorted(synonym_map.values(), key=lambda x: x[1], reverse=True)
        return final_synonyms

    def add_synonym_group(self, term, synonyms_with_confidence):
        """Merge (syn, conf, source) tuples into the store, skipping existing synonyms."""
        term_lower = term.lower()
        if term_lower not in self.synonyms:
            self.synonyms[term_lower] = []

        for syn, conf, src in synonyms_with_confidence:
            # Skip case-insensitive duplicates already stored for this term.
            if not any(s[0].lower() == syn.lower() for s in self.synonyms[term_lower]):
                self.synonyms[term_lower].append((syn, conf, src))

    def extract_terms_from_categories(self, csv_path, min_frequency=2):
        """Mine candidate terms (words and two-word phrases) from a category CSV.

        Returns (candidates, term_freq) where candidates are the terms seen at
        least *min_frequency* times. Returns ([], {}) on error.
        Assumes the path column is the second CSV column β€” TODO confirm schema.
        """
        print(f"\nπŸ“‚ Extracting terms from: {csv_path}")

        try:
            import pandas as pd

            df = pd.read_csv(csv_path)
            path_col = df.columns[1] if len(df.columns) > 1 else df.columns[0]
            paths = df[path_col].dropna().astype(str)

            print(f" Processing {len(paths):,} category paths...")

            term_freq = defaultdict(int)

            for path in tqdm(paths, desc="Analyzing paths"):
                levels = path.split('/')

                for level in levels:
                    words = level.lower().split()

                    # Count single alphabetic words longer than 2 chars.
                    for word in words:
                        if len(word) > 2 and word.isalpha():
                            term_freq[word] += 1

                    # Count adjacent two-word phrases of alphabetic words.
                    for i in range(len(words) - 1):
                        if len(words[i]) > 2 and len(words[i+1]) > 2:
                            phrase = f"{words[i]} {words[i+1]}"
                            if phrase.replace(' ', '').isalpha():
                                term_freq[phrase] += 1

            candidates = [
                term for term, freq in term_freq.items()
                if freq >= min_frequency
            ]

            print(f"βœ… Extracted {len(candidates):,} terms (min frequency: {min_frequency})")
            return candidates, term_freq

        except Exception as e:
            print(f"❌ Error extracting terms: {e}")
            import traceback
            traceback.print_exc()
            return [], {}

    def auto_build_from_categories(self, csv_path, top_terms=1000, semantic_threshold=0.70):
        """End-to-end build: mine terms, generate synonyms for the top ones, save.

        Returns True on success, False if no terms could be extracted.
        """
        print("\n" + "="*80)
        print("πŸš€ AUTO-BUILD SYNONYM DATABASE")
        print("="*80)

        # Model load failure is non-fatal: WordNet alone still works.
        if not self.load_transformer_model():
            print("\n⚠️ Continuing with WordNet only")

        all_terms, term_freq = self.extract_terms_from_categories(csv_path)
        if not all_terms:
            print("❌ No terms extracted")
            return False

        # Process only the most frequent terms to bound runtime.
        print(f"\n🎯 Selecting top {top_terms} terms...")
        top_frequent = sorted(term_freq.items(), key=lambda x: x[1], reverse=True)[:top_terms]
        terms_to_process = [term for term, _ in top_frequent]

        print(f"βœ… Selected {len(terms_to_process)} terms")
        print(f"πŸ“Š Top 10: {', '.join(terms_to_process[:10])}")
        print(f"\nπŸ”„ Generating synonyms (threshold={semantic_threshold})...\n")

        stats = {'processed': 0, 'synonyms': 0, 'high_conf': 0}

        for term in tqdm(terms_to_process, desc="Processing"):
            # Skip terms that already have a reasonably full synonym list.
            if term in self.synonyms and len(self.synonyms[term]) >= 10:
                continue

            syns = self.auto_generate_synonyms(
                term,
                candidate_pool=all_terms,
                semantic_threshold=semantic_threshold,
                silent=True
            )

            if syns:
                self.add_synonym_group(term, syns)
                stats['processed'] += 1
                stats['synonyms'] += len(syns)
                stats['high_conf'] += sum(1 for _, c, _ in syns if c >= 0.8)

        print(f"\nβœ… Processed: {stats['processed']:,} terms")
        print(f"βœ… Total synonyms: {stats['synonyms']:,}")
        print(f"βœ… High confidence (β‰₯0.8): {stats['high_conf']:,}")

        self.save_synonyms()

        print("\nπŸŽ‰ AUTO-BUILD COMPLETE!\n")
        return True
327
-
328
-
329
def main():
    """CLI entry point: parse arguments and dispatch the requested command."""
    header = "=" * 80
    print("\n" + header)
    print("πŸ€– AI-POWERED SYNONYM MANAGER")
    print(header + "\n")

    # Optional flag anywhere on the command line.
    fast_mode = '--fast' in sys.argv

    if len(sys.argv) < 2:
        # No command given β€” show usage and bail out.
        for line in (
            "Usage:",
            " python synonym_manager_fixed.py autobuild <csv_file>",
            " python synonym_manager_fixed.py autobuild <csv_file> --fast",
            "\nExample:",
            " python synonym_manager_fixed.py autobuild data/category_id_path_only.csv",
        ):
            print(line)
        return

    command = sys.argv[1].lower()

    if command != 'autobuild':
        print(f"❌ Unknown command: {command}")
        return

    if len(sys.argv) < 3:
        print("❌ CSV file path required")
        return

    csv_path = sys.argv[2]
    if not Path(csv_path).exists():
        print(f"❌ File not found: {csv_path}")
        return

    manager = SynonymManager(fast_mode=fast_mode)
    manager.auto_build_from_categories(csv_path, top_terms=1000)
363
-
364
-
365
- if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  main()
 
1
+ # """
2
+ # πŸ€– FIXED AI-POWERED SYNONYM MANAGER
3
+ # ====================================
4
+ # βœ… Windows + NVIDIA GPU optimized
5
+ # βœ… Uses e5-base-v2 (lower memory)
6
+ # βœ… Proper error handling
7
+ # βœ… Progress tracking
8
+
9
+ # Usage:
10
+ # python synonym_manager_fixed.py autobuild data/category_id_path_only.csv
11
+ # python synonym_manager_fixed.py autobuild data/category_id_path_only.csv --fast
12
+ # """
13
+
14
+ # import pickle
15
+ # from pathlib import Path
16
+ # import json
17
+ # from collections import defaultdict
18
+ # from tqdm import tqdm
19
+ # import warnings
20
+ # import sys
21
+ # import os
22
+
23
+ # warnings.filterwarnings('ignore')
24
+
25
+ # # Fix CUDA issues on Windows
26
+ # os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
27
+
28
+ # try:
29
+ # from nltk.corpus import wordnet
30
+ # from nltk import download as nltk_download
31
+ # WORDNET_AVAILABLE = True
32
+ # except ImportError:
33
+ # WORDNET_AVAILABLE = False
34
+ # print("⚠️ NLTK not available. Install with: pip install nltk")
35
+
36
+ # try:
37
+ # from sentence_transformers import SentenceTransformer, util
38
+ # import torch
39
+ # TRANSFORMERS_AVAILABLE = True
40
+ # except ImportError:
41
+ # TRANSFORMERS_AVAILABLE = False
42
+ # print("⚠️ SentenceTransformers not available.")
43
+ # print(" Install with: pip install sentence-transformers torch")
44
+
45
+
46
+ # class FixedAISynonymManager:
47
+ # """Fixed AI-powered synonym manager for Windows + NVIDIA GPU"""
48
+
49
+ # def __init__(self, cache_dir='cache', tags_file='data/tags.json', fast_mode=False):
50
+ # self.cache_dir = Path(cache_dir)
51
+ # self.synonyms_file = self.cache_dir / 'cross_store_synonyms.pkl'
52
+ # self.tags_file = Path(tags_file)
53
+ # self.synonyms = {}
54
+ # self.tags_data = {}
55
+ # self.model = None
56
+ # self.device = "cpu"
57
+ # self.fast_mode = fast_mode
58
+
59
+ # # Create cache directory
60
+ # self.cache_dir.mkdir(parents=True, exist_ok=True)
61
+
62
+ # # Load existing data
63
+ # self.load_tags()
64
+ # if self.synonyms_file.exists():
65
+ # self.load_synonyms()
66
+ # else:
67
+ # print("πŸ“ No existing synonyms file. Will create new one.")
68
+
69
+ # def load_tags(self):
70
+ # """Load domain-specific tags (optional)"""
71
+ # if self.tags_file.exists():
72
+ # try:
73
+ # with open(self.tags_file, 'r', encoding='utf-8') as f:
74
+ # self.tags_data = json.load(f)
75
+ # print(f"βœ… Loaded {len(self.tags_data)} tag entries")
76
+ # return True
77
+ # except Exception as e:
78
+ # print(f"⚠️ Could not load tags.json: {e}")
79
+ # else:
80
+ # print(f"ℹ️ tags.json not found (optional)")
81
+ # return False
82
+
83
+ # def load_synonyms(self):
84
+ # """Load existing synonyms with format conversion"""
85
+ # try:
86
+ # with open(self.synonyms_file, 'rb') as f:
87
+ # loaded = pickle.load(f)
88
+
89
+ # # Handle different formats
90
+ # if not loaded:
91
+ # self.synonyms = {}
92
+ # return
93
+
94
+ # # Check format
95
+ # first_val = next(iter(loaded.values()))
96
+
97
+ # if isinstance(first_val, list):
98
+ # if first_val and isinstance(first_val[0], tuple):
99
+ # # New format: [(syn, conf, src), ...]
100
+ # self.synonyms = loaded
101
+ # print(f"βœ… Loaded {len(self.synonyms)} synonym entries (new format)")
102
+ # elif first_val and isinstance(first_val[0], str):
103
+ # # Legacy format: [syn1, syn2, ...]
104
+ # self.synonyms = {
105
+ # k: [(v, 0.8, 'legacy') for v in vals]
106
+ # for k, vals in loaded.items()
107
+ # }
108
+ # print(f"βœ… Converted {len(self.synonyms)} legacy synonym entries")
109
+ # elif isinstance(first_val, set):
110
+ # # Set format
111
+ # self.synonyms = {
112
+ # k: [(v, 0.8, 'legacy') for v in vals]
113
+ # for k, vals in loaded.items()
114
+ # }
115
+ # print(f"βœ… Converted {len(self.synonyms)} set-based entries")
116
+ # else:
117
+ # self.synonyms = {}
118
+ # print(f"⚠️ Unknown synonym format")
119
+
120
+ # except Exception as e:
121
+ # print(f"❌ Error loading synonyms: {e}")
122
+ # self.synonyms = {}
123
+
124
+ # def save_synonyms(self):
125
+ # """Save synonyms in both formats"""
126
+ # try:
127
+ # # Save binary format
128
+ # with open(self.synonyms_file, 'wb') as f:
129
+ # pickle.dump(self.synonyms, f)
130
+
131
+ # # Save readable JSON
132
+ # json_file = self.cache_dir / 'synonyms_readable.json'
133
+ # readable = {}
134
+ # for term, syns in self.synonyms.items():
135
+ # readable[term] = [
136
+ # {'synonym': syn, 'confidence': float(conf), 'source': src}
137
+ # for syn, conf, src in syns
138
+ # ]
139
+
140
+ # with open(json_file, 'w', encoding='utf-8') as f:
141
+ # json.dump(readable, f, indent=2, ensure_ascii=False)
142
+
143
+ # print(f"\nβœ… Saved {len(self.synonyms)} synonym entries")
144
+ # print(f" πŸ“ Binary: {self.synonyms_file}")
145
+ # print(f" πŸ“ JSON: {json_file}")
146
+ # return True
147
+ # except Exception as e:
148
+ # print(f"❌ Error saving synonyms: {e}")
149
+ # return False
150
+
151
+ # def load_transformer_model(self):
152
+ # """Load e5-base-v2 model with GPU support"""
153
+ # if not TRANSFORMERS_AVAILABLE:
154
+ # print("❌ SentenceTransformers not installed!")
155
+ # return False
156
+
157
+ # # Check for CUDA
158
+ # self.device = "cuda" if torch.cuda.is_available() else "cpu"
159
+
160
+ # if self.device == "cuda":
161
+ # print(f"πŸ”₯ NVIDIA GPU detected!")
162
+ # try:
163
+ # gpu_name = torch.cuda.get_device_name(0)
164
+ # vram_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
165
+ # print(f" GPU: {gpu_name}")
166
+ # print(f" VRAM: {vram_gb:.1f} GB")
167
+ # except:
168
+ # pass
169
+ # else:
170
+ # print("πŸ’» Using CPU (slower)")
171
+
172
+ # # Use e5-base-v2 for better memory efficiency
173
+ # model_name = "intfloat/e5-base-v2"
174
+ # print(f"\nπŸ€– Loading model: {model_name}")
175
+
176
+ # try:
177
+ # self.model = SentenceTransformer(model_name, device=self.device)
178
+ # self.model.max_seq_length = 256
179
+
180
+ # # Use FP16 on GPU for speed
181
+ # if self.device == "cuda":
182
+ # self.model = self.model.half()
183
+ # print("⚑ Enabled FP16 precision")
184
+
185
+ # print("βœ… Model loaded successfully\n")
186
+ # return True
187
+ # except Exception as e:
188
+ # print(f"❌ Failed to load model: {e}")
189
+ # return False
190
+
191
+ # def get_wordnet_synonyms(self, word, limit=10):
192
+ # """Get WordNet synonyms"""
193
+ # if self.fast_mode or not WORDNET_AVAILABLE:
194
+ # return []
195
+
196
+ # try:
197
+ # # Ensure WordNet is downloaded
198
+ # try:
199
+ # wordnet.synsets('test')
200
+ # except:
201
+ # print("πŸ“₯ Downloading WordNet data...")
202
+ # nltk_download('wordnet', quiet=True)
203
+ # nltk_download('omw-1.4', quiet=True)
204
+
205
+ # synonyms = []
206
+ # word_clean = word.lower().replace(' ', '_')
207
+
208
+ # for syn in wordnet.synsets(word_clean):
209
+ # for lemma in syn.lemmas():
210
+ # synonym = lemma.name().replace('_', ' ').lower()
211
+ # if synonym != word.lower() and len(synonym) > 2:
212
+ # confidence = 0.75 # Fixed confidence for WordNet
213
+ # synonyms.append((synonym, confidence, 'wordnet'))
214
+ # if len(synonyms) >= limit:
215
+ # break
216
+ # if len(synonyms) >= limit:
217
+ # break
218
+
219
+ # return synonyms[:limit]
220
+ # except Exception:
221
+ # return []
222
+
223
+ # def get_semantic_synonyms(self, term, candidate_pool, threshold=0.70, limit=15):
224
+ # """Get semantic synonyms using embeddings"""
225
+ # if not self.model or not candidate_pool:
226
+ # return []
227
+
228
+ # try:
229
+ # # E5 model requires query/passage prefixes
230
+ # query = f"query: {term}"
231
+ # candidates_prefixed = [f"passage: {c}" for c in candidate_pool]
232
+
233
+ # # Encode query
234
+ # term_emb = self.model.encode(
235
+ # query,
236
+ # convert_to_tensor=True,
237
+ # show_progress_bar=False
238
+ # )
239
+
240
+ # # Encode candidates in batches
241
+ # batch_size = 32 if self.device == "cuda" else 8
242
+ # all_embeddings = []
243
+
244
+ # for i in range(0, len(candidates_prefixed), batch_size):
245
+ # batch = candidates_prefixed[i:i + batch_size]
246
+ # emb = self.model.encode(
247
+ # batch,
248
+ # convert_to_tensor=True,
249
+ # show_progress_bar=False
250
+ # )
251
+ # all_embeddings.append(emb)
252
+
253
+ # # Concatenate all embeddings
254
+ # candidate_embs = torch.cat(all_embeddings, dim=0)
255
+
256
+ # # Calculate cosine similarity
257
+ # scores = util.cos_sim(term_emb, candidate_embs)[0]
258
+
259
+ # # Filter by threshold
260
+ # synonyms = []
261
+ # for candidate, score in zip(candidate_pool, scores):
262
+ # score_val = float(score)
263
+ # if score_val > threshold and candidate.lower() != term.lower():
264
+ # # Scale confidence between 0.6 and 0.95
265
+ # confidence = 0.60 + (score_val - threshold) * 0.35 / (1 - threshold)
266
+ # synonyms.append((candidate, confidence, 'semantic'))
267
+
268
+ # # Sort by confidence
269
+ # synonyms.sort(key=lambda x: x[1], reverse=True)
270
+ # return synonyms[:limit]
271
+
272
+ # except Exception as e:
273
+ # print(f"⚠️ Semantic error: {e}")
274
+ # return []
275
+
276
+ # def auto_generate_synonyms(self, term, candidate_pool=None,
277
+ # semantic_threshold=0.70, silent=False):
278
+ # """Generate synonyms from multiple sources"""
279
+ # all_synonyms = []
280
+
281
+ # if not silent:
282
+ # print(f"\nπŸ” Finding synonyms for: '{term}'")
283
+
284
+ # # Source 1: WordNet
285
+ # if WORDNET_AVAILABLE and not self.fast_mode:
286
+ # wn_syns = self.get_wordnet_synonyms(term, limit=10)
287
+ # all_synonyms.extend(wn_syns)
288
+
289
+ # # Source 2: Semantic similarity
290
+ # if candidate_pool and self.model:
291
+ # sem_syns = self.get_semantic_synonyms(
292
+ # term, candidate_pool,
293
+ # threshold=semantic_threshold,
294
+ # limit=15
295
+ # )
296
+ # all_synonyms.extend(sem_syns)
297
+
298
+ # # Deduplicate (keep highest confidence)
299
+ # synonym_map = {}
300
+ # for syn, conf, source in all_synonyms:
301
+ # syn_lower = syn.lower()
302
+ # if syn_lower not in synonym_map or conf > synonym_map[syn_lower][1]:
303
+ # synonym_map[syn_lower] = (syn, conf, source)
304
+
305
+ # final_synonyms = sorted(
306
+ # synonym_map.values(),
307
+ # key=lambda x: x[1],
308
+ # reverse=True
309
+ # )
310
+
311
+ # return final_synonyms
312
+
313
+ # def add_synonym_group(self, term, synonyms_with_confidence):
314
+ # """Add synonym group"""
315
+ # term_lower = term.lower()
316
+ # if term_lower not in self.synonyms:
317
+ # self.synonyms[term_lower] = []
318
+
319
+ # for syn, conf, src in synonyms_with_confidence:
320
+ # # Check if already exists
321
+ # if not any(s[0].lower() == syn.lower() for s in self.synonyms[term_lower]):
322
+ # self.synonyms[term_lower].append((syn, conf, src))
323
+
324
+ # def extract_terms_from_categories(self, csv_path, min_frequency=2):
325
+ # """Extract terms from category CSV"""
326
+ # print(f"\nπŸ“‚ Extracting terms from: {csv_path}")
327
+
328
+ # try:
329
+ # import pandas as pd
330
+
331
+ # # Read CSV
332
+ # df = pd.read_csv(csv_path)
333
+
334
+ # # Find path column (usually second column)
335
+ # path_col = df.columns[1] if len(df.columns) > 1 else df.columns[0]
336
+ # paths = df[path_col].dropna().astype(str)
337
+
338
+ # print(f" Processing {len(paths):,} category paths...")
339
+
340
+ # term_freq = defaultdict(int)
341
+
342
+ # for path in tqdm(paths, desc="Analyzing paths"):
343
+ # levels = path.split('/')
344
+
345
+ # for level in levels:
346
+ # words = level.lower().split()
347
+
348
+ # # Single words
349
+ # for word in words:
350
+ # if len(word) > 2 and word.isalpha():
351
+ # term_freq[word] += 1
352
+
353
+ # # Two-word phrases
354
+ # for i in range(len(words) - 1):
355
+ # if len(words[i]) > 2 and len(words[i+1]) > 2:
356
+ # phrase = f"{words[i]} {words[i+1]}"
357
+ # if phrase.replace(' ', '').isalpha():
358
+ # term_freq[phrase] += 1
359
+
360
+ # # Filter by frequency
361
+ # candidates = [
362
+ # term for term, freq in term_freq.items()
363
+ # if freq >= min_frequency
364
+ # ]
365
+
366
+ # print(f"βœ… Extracted {len(candidates):,} terms (min frequency: {min_frequency})")
367
+ # return candidates, term_freq
368
+
369
+ # except Exception as e:
370
+ # print(f"❌ Error extracting terms: {e}")
371
+ # import traceback
372
+ # traceback.print_exc()
373
+ # return [], {}
374
+
375
+ # def auto_build_from_categories(self, csv_path, top_terms=1000,
376
+ # semantic_threshold=0.70):
377
+ # """Auto-build synonym database from categories"""
378
+ # print("\n" + "="*80)
379
+ # print("πŸš€ AUTO-BUILD SYNONYM DATABASE")
380
+ # print("="*80)
381
+
382
+ # # Load model
383
+ # if not self.load_transformer_model():
384
+ # print("\n⚠️ Continuing with WordNet only (limited coverage)")
385
+
386
+ # # Extract terms
387
+ # all_terms, term_freq = self.extract_terms_from_categories(csv_path)
388
+ # if not all_terms:
389
+ # print("❌ No terms extracted")
390
+ # return False
391
+
392
+ # # Select top terms
393
+ # print(f"\n🎯 Selecting top {top_terms} terms...")
394
+ # top_frequent = sorted(
395
+ # term_freq.items(),
396
+ # key=lambda x: x[1],
397
+ # reverse=True
398
+ # )[:top_terms]
399
+ # terms_to_process = [term for term, _ in top_frequent]
400
+
401
+ # print(f"βœ… Selected {len(terms_to_process)} terms")
402
+ # print(f"πŸ“Š Top 10: {', '.join(terms_to_process[:10])}")
403
+ # print(f"\nπŸ”„ Generating synonyms (threshold={semantic_threshold})...\n")
404
+
405
+ # # Process terms
406
+ # stats = {
407
+ # 'processed': 0,
408
+ # 'synonyms': 0,
409
+ # 'high_conf': 0
410
+ # }
411
+
412
+ # for term in tqdm(terms_to_process, desc="Processing"):
413
+ # # Skip if already has enough synonyms
414
+ # if term in self.synonyms and len(self.synonyms[term]) >= 10:
415
+ # continue
416
+
417
+ # # Generate synonyms
418
+ # syns = self.auto_generate_synonyms(
419
+ # term,
420
+ # candidate_pool=all_terms,
421
+ # semantic_threshold=semantic_threshold,
422
+ # silent=True
423
+ # )
424
+
425
+ # if syns:
426
+ # self.add_synonym_group(term, syns)
427
+ # stats['processed'] += 1
428
+ # stats['synonyms'] += len(syns)
429
+ # stats['high_conf'] += sum(1 for _, c, _ in syns if c >= 0.8)
430
+
431
+ # # Print stats
432
+ # print(f"\nβœ… Processed: {stats['processed']:,} terms")
433
+ # print(f"βœ… Total synonyms: {stats['synonyms']:,}")
434
+ # print(f"βœ… High confidence (β‰₯0.8): {stats['high_conf']:,}")
435
+
436
+ # # Save
437
+ # self.save_synonyms()
438
+
439
+ # print("\nπŸŽ‰ AUTO-BUILD COMPLETE!\n")
440
+ # return True
441
+
442
+
443
+ # def main():
444
+ # """Main entry point"""
445
+ # print("\n" + "="*80)
446
+ # print("πŸ€– AI-POWERED SYNONYM MANAGER (Windows + NVIDIA GPU)")
447
+ # print("="*80 + "\n")
448
+
449
+ # # Parse arguments
450
+ # fast_mode = '--fast' in sys.argv
451
+
452
+ # if len(sys.argv) < 2:
453
+ # print("Usage:")
454
+ # print(" python synonym_manager_fixed.py autobuild <csv_file>")
455
+ # print(" python synonym_manager_fixed.py autobuild <csv_file> --fast")
456
+ # print("\nExample:")
457
+ # print(" python synonym_manager_fixed.py autobuild data/category_id_path_only.csv")
458
+ # return
459
+
460
+ # command = sys.argv[1].lower()
461
+
462
+ # if command == 'autobuild':
463
+ # if len(sys.argv) < 3:
464
+ # print("❌ CSV file path required")
465
+ # return
466
+
467
+ # csv_path = sys.argv[2]
468
+
469
+ # if not Path(csv_path).exists():
470
+ # print(f"❌ File not found: {csv_path}")
471
+ # return
472
+
473
+ # # Initialize manager
474
+ # manager = FixedAISynonymManager(fast_mode=fast_mode)
475
+
476
+ # # Run auto-build
477
+ # manager.auto_build_from_categories(csv_path, top_terms=1000)
478
+
479
+ # else:
480
+ # print(f"❌ Unknown command: {command}")
481
+
482
+
483
+ # if __name__ == "__main__":
484
+ # main()
485
+
486
+
487
+ #for cache2
488
+
489
+
490
+ """
491
+ πŸ€– AI-POWERED SYNONYM MANAGER (Fixed for Windows + GPU)
492
+ ========================================================
493
+ βœ… Uses e5-base-v2 (768D, memory-efficient)
494
+ βœ… Windows + NVIDIA GPU optimized
495
+ βœ… Generates cross-store synonyms automatically
496
+
497
+ Usage:
498
+ python synonym_manager_fixed.py autobuild data/category_id_path_only.csv
499
+ python synonym_manager_fixed.py autobuild data/category_id_path_only.csv --fast
500
+ """
501
+
502
+ import pickle
503
+ from pathlib import Path
504
+ import json
505
+ from collections import defaultdict
506
+ from tqdm import tqdm
507
+ import warnings
508
+ import sys
509
+ import os
510
+
511
# Silence noisy library warnings and allow duplicate OpenMP runtimes to
# coexist (a common workaround for the libiomp "already initialized" abort
# when torch and faiss are imported together on Windows).  Must run before
# the heavy ML imports below.
warnings.filterwarnings('ignore')
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
513
+
514
+ try:
515
+ from nltk.corpus import wordnet
516
+ from nltk import download as nltk_download
517
+ WORDNET_AVAILABLE = True
518
+ except ImportError:
519
+ WORDNET_AVAILABLE = False
520
+
521
+ try:
522
+ from sentence_transformers import SentenceTransformer, util
523
+ import torch
524
+ TRANSFORMERS_AVAILABLE = True
525
+ except ImportError:
526
+ TRANSFORMERS_AVAILABLE = False
527
+
528
+
529
class SynonymManager:
    """AI-powered synonym manager.

    Builds a cross-store synonym database from two sources:

    * WordNet lemma lookups (fixed confidence 0.75), and
    * semantic nearest-neighbours from an E5 sentence-transformer
      (confidence derived from cosine similarity).

    Synonyms are held in ``self.synonyms`` as
    ``{term_lowercase: [(synonym, confidence, source), ...]}`` and persisted
    to ``<cache_dir>/cross_store_synonyms.pkl`` plus a human-readable JSON.
    """

    def __init__(self, cache_dir='cache', fast_mode=False):
        """
        Args:
            cache_dir: directory holding the pickle/JSON synonym files.
            fast_mode: when True, skip the (slow) WordNet lookups entirely.
        """
        self.cache_dir = Path(cache_dir)
        self.synonyms_file = self.cache_dir / 'cross_store_synonyms.pkl'
        self.synonyms = {}      # term -> list of (synonym, confidence, source)
        self.model = None       # SentenceTransformer, set by load_transformer_model()
        self.device = "cpu"
        self.fast_mode = fast_mode

        self.cache_dir.mkdir(parents=True, exist_ok=True)

        if self.synonyms_file.exists():
            self.load_synonyms()

    def load_synonyms(self):
        """Load synonyms from disk, upgrading legacy on-disk formats.

        Older pickles stored plain lists or sets of strings; these are
        wrapped as ``(synonym, 0.8, 'legacy')`` tuples so downstream code
        can always rely on the (synonym, confidence, source) shape.
        On any error the in-memory store is reset to empty.
        """
        try:
            with open(self.synonyms_file, 'rb') as f:
                loaded = pickle.load(f)

            if loaded:
                # Inspect a single value to detect which on-disk format this is.
                first_val = next(iter(loaded.values()))

                if isinstance(first_val, list) and first_val:
                    if isinstance(first_val[0], tuple):
                        # Current format: values are already tuples.
                        self.synonyms = loaded
                    else:
                        self.synonyms = {k: [(v, 0.8, 'legacy') for v in vals] for k, vals in loaded.items()}
                elif isinstance(first_val, set):
                    self.synonyms = {k: [(v, 0.8, 'legacy') for v in vals] for k, vals in loaded.items()}

            print(f"βœ… Loaded {len(self.synonyms):,} synonym entries")
        except Exception as e:
            print(f"❌ Error loading synonyms: {e}")
            self.synonyms = {}

    def save_synonyms(self):
        """Persist synonyms as a pickle plus a readable JSON sidecar.

        Returns:
            True on success, False if either file failed to write.
        """
        try:
            with open(self.synonyms_file, 'wb') as f:
                pickle.dump(self.synonyms, f)

            # Mirror to JSON so humans can inspect/diff the database.
            json_file = self.cache_dir / 'synonyms_readable.json'
            readable = {
                term: [
                    {'synonym': syn, 'confidence': conf, 'source': src}
                    for syn, conf, src in syns
                ]
                for term, syns in self.synonyms.items()
            }
            with open(json_file, 'w', encoding='utf-8') as f:
                json.dump(readable, f, indent=2, ensure_ascii=False)

            print(f"βœ… Saved {len(self.synonyms):,} synonym entries")
            return True
        except Exception as e:
            print(f"❌ Error saving synonyms: {e}")
            return False

    def load_transformer_model(self):
        """Load the intfloat/e5-base-v2 encoder onto GPU if available.

        Uses FP16 weights on CUDA to conserve VRAM.

        Returns:
            True when the model is ready, False otherwise.
        """
        if not TRANSFORMERS_AVAILABLE:
            print("❌ SentenceTransformers not installed!")
            return False

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        if self.device == "cuda":
            print("πŸ”₯ NVIDIA GPU detected!")

        model_name = "intfloat/e5-base-v2"
        print(f"\nπŸ€– Loading {model_name}...")

        try:
            self.model = SentenceTransformer(model_name, device=self.device)

            if self.device == "cuda":
                self.model = self.model.half()
                print("⚑ Enabled FP16 precision")

            print("βœ… Model loaded\n")
            return True
        except Exception as e:
            print(f"❌ Failed to load model: {e}")
            return False

    def get_wordnet_synonyms(self, word, limit=10):
        """Collect up to *limit* WordNet lemma synonyms for *word*.

        Multi-word terms are joined with underscores for the synset lookup.
        Returns a list of (synonym, 0.75, 'wordnet') tuples; empty in
        fast mode, when WordNet is unavailable, or on any lookup error.
        """
        if self.fast_mode or not WORDNET_AVAILABLE:
            return []

        try:
            try:
                # Probe the corpus; NLTK raises LookupError when the
                # wordnet data has not been downloaded yet.
                wordnet.synsets('test')
            except LookupError:
                nltk_download('wordnet', quiet=True)
                nltk_download('omw-1.4', quiet=True)

            synonyms = []
            word_clean = word.lower().replace(' ', '_')

            for syn in wordnet.synsets(word_clean):
                for lemma in syn.lemmas():
                    synonym = lemma.name().replace('_', ' ').lower()
                    # Skip the word itself and very short lemmas.
                    if synonym != word.lower() and len(synonym) > 2:
                        confidence = 0.75
                        synonyms.append((synonym, confidence, 'wordnet'))
                        if len(synonyms) >= limit:
                            break
                if len(synonyms) >= limit:
                    break

            return synonyms[:limit]
        except Exception:
            return []

    def get_semantic_synonyms(self, term, candidate_pool, threshold=0.70, limit=15):
        """Rank *candidate_pool* terms by E5 cosine similarity to *term*.

        Uses the E5 "query:"/"passage:" prefix convention. Candidates above
        *threshold* are returned as (candidate, confidence, 'semantic')
        tuples, where confidence maps the similarity range
        [threshold, 1.0] linearly onto [0.60, 0.95].
        """
        if not self.model or not candidate_pool:
            return []

        try:
            query = f"query: {term}"
            candidates_prefixed = [f"passage: {c}" for c in candidate_pool]

            term_emb = self.model.encode(query, convert_to_tensor=True, show_progress_bar=False)

            # Smaller batches on CPU keep memory usage reasonable.
            batch_size = 32 if self.device == "cuda" else 8
            all_embeddings = []

            for i in range(0, len(candidates_prefixed), batch_size):
                batch = candidates_prefixed[i:i + batch_size]
                emb = self.model.encode(batch, convert_to_tensor=True, show_progress_bar=False)
                all_embeddings.append(emb)

            candidate_embs = torch.cat(all_embeddings, dim=0)
            scores = util.cos_sim(term_emb, candidate_embs)[0]

            synonyms = []
            for candidate, score in zip(candidate_pool, scores):
                score_val = float(score)
                if score_val > threshold and candidate.lower() != term.lower():
                    confidence = 0.60 + (score_val - threshold) * 0.35 / (1 - threshold)
                    synonyms.append((candidate, confidence, 'semantic'))

            synonyms.sort(key=lambda x: x[1], reverse=True)
            return synonyms[:limit]

        except Exception as e:
            print(f"⚠️ Semantic error: {e}")
            return []

    def auto_generate_synonyms(self, term, candidate_pool=None, semantic_threshold=0.70, silent=False):
        """Merge WordNet and semantic synonyms for *term*.

        Duplicates (case-insensitive) are collapsed, keeping the highest
        confidence entry; the result is sorted by confidence descending.
        """
        all_synonyms = []

        if not silent:
            print(f"\nπŸ” Finding synonyms for: '{term}'")

        if WORDNET_AVAILABLE and not self.fast_mode:
            wn_syns = self.get_wordnet_synonyms(term, limit=10)
            all_synonyms.extend(wn_syns)

        if candidate_pool and self.model:
            sem_syns = self.get_semantic_synonyms(
                term, candidate_pool,
                threshold=semantic_threshold,
                limit=15
            )
            all_synonyms.extend(sem_syns)

        # Deduplicate case-insensitively, preferring the higher confidence.
        synonym_map = {}
        for syn, conf, source in all_synonyms:
            syn_lower = syn.lower()
            if syn_lower not in synonym_map or conf > synonym_map[syn_lower][1]:
                synonym_map[syn_lower] = (syn, conf, source)

        final_synonyms = sorted(synonym_map.values(), key=lambda x: x[1], reverse=True)
        return final_synonyms

    def add_synonym_group(self, term, synonyms_with_confidence):
        """Append new (synonym, confidence, source) tuples under *term*.

        The term key is lowercased; synonyms already present for the term
        (case-insensitive match) are not added again.
        """
        term_lower = term.lower()
        if term_lower not in self.synonyms:
            self.synonyms[term_lower] = []

        for syn, conf, src in synonyms_with_confidence:
            if not any(s[0].lower() == syn.lower() for s in self.synonyms[term_lower]):
                self.synonyms[term_lower].append((syn, conf, src))

    def extract_terms_from_categories(self, csv_path, min_frequency=2):
        """Mine candidate terms (unigrams + bigrams) from a category CSV.

        Uses the second column as the category path when present,
        splitting each path on '/' into levels and each level into words.

        Returns:
            (candidates, term_freq): list of terms seen at least
            *min_frequency* times, and the full frequency dict.
            ([], {}) on any error.
        """
        print(f"\nπŸ“‚ Extracting terms from: {csv_path}")

        try:
            # Lazy import: pandas is only needed for this build step.
            import pandas as pd

            df = pd.read_csv(csv_path)
            path_col = df.columns[1] if len(df.columns) > 1 else df.columns[0]
            paths = df[path_col].dropna().astype(str)

            print(f" Processing {len(paths):,} category paths...")

            term_freq = defaultdict(int)

            for path in tqdm(paths, desc="Analyzing paths"):
                levels = path.split('/')

                for level in levels:
                    words = level.lower().split()

                    # Count single alphabetic words longer than 2 chars.
                    for word in words:
                        if len(word) > 2 and word.isalpha():
                            term_freq[word] += 1

                    # Count adjacent-word bigrams of the same quality.
                    for i in range(len(words) - 1):
                        if len(words[i]) > 2 and len(words[i+1]) > 2:
                            phrase = f"{words[i]} {words[i+1]}"
                            if phrase.replace(' ', '').isalpha():
                                term_freq[phrase] += 1

            candidates = [
                term for term, freq in term_freq.items()
                if freq >= min_frequency
            ]

            print(f"βœ… Extracted {len(candidates):,} terms (min frequency: {min_frequency})")
            return candidates, term_freq

        except Exception as e:
            print(f"❌ Error extracting terms: {e}")
            import traceback
            traceback.print_exc()
            return [], {}

    def auto_build_from_categories(self, csv_path, top_terms=1000, semantic_threshold=0.70):
        """Auto-build the synonym database from a category CSV.

        Loads the encoder (falling back to WordNet-only on failure),
        extracts terms, processes the *top_terms* most frequent ones, and
        saves the resulting database.

        Returns:
            True on completion, False when no terms could be extracted.
        """
        print("\n" + "="*80)
        print("πŸš€ AUTO-BUILD SYNONYM DATABASE")
        print("="*80)

        if not self.load_transformer_model():
            print("\n⚠️ Continuing with WordNet only")

        all_terms, term_freq = self.extract_terms_from_categories(csv_path)
        if not all_terms:
            print("❌ No terms extracted")
            return False

        print(f"\n🎯 Selecting top {top_terms} terms...")
        top_frequent = sorted(term_freq.items(), key=lambda x: x[1], reverse=True)[:top_terms]
        terms_to_process = [term for term, _ in top_frequent]

        print(f"βœ… Selected {len(terms_to_process)} terms")
        print(f"πŸ“Š Top 10: {', '.join(terms_to_process[:10])}")
        print(f"\nπŸ”„ Generating synonyms (threshold={semantic_threshold})...\n")

        stats = {'processed': 0, 'synonyms': 0, 'high_conf': 0}

        for term in tqdm(terms_to_process, desc="Processing"):
            # Skip terms that already have a healthy synonym list.
            if term in self.synonyms and len(self.synonyms[term]) >= 10:
                continue

            syns = self.auto_generate_synonyms(
                term,
                candidate_pool=all_terms,
                semantic_threshold=semantic_threshold,
                silent=True
            )

            if syns:
                self.add_synonym_group(term, syns)
                stats['processed'] += 1
                stats['synonyms'] += len(syns)
                stats['high_conf'] += sum(1 for _, c, _ in syns if c >= 0.8)

        print(f"\nβœ… Processed: {stats['processed']:,} terms")
        print(f"βœ… Total synonyms: {stats['synonyms']:,}")
        print(f"βœ… High confidence (β‰₯0.8): {stats['high_conf']:,}")

        self.save_synonyms()

        print("\nπŸŽ‰ AUTO-BUILD COMPLETE!\n")
        return True
815
+
816
+
817
def main():
    """Command-line entry point: parse argv and run the autobuild command."""
    banner = "=" * 80
    print("\n" + banner)
    print("πŸ€– AI-POWERED SYNONYM MANAGER")
    print(banner + "\n")

    fast_mode = '--fast' in sys.argv

    # No command given: show usage and bail out.
    if len(sys.argv) < 2:
        print("Usage:")
        print(" python synonym_manager_fixed.py autobuild <csv_file>")
        print(" python synonym_manager_fixed.py autobuild <csv_file> --fast")
        print("\nExample:")
        print(" python synonym_manager_fixed.py autobuild data/category_id_path_only.csv")
        return

    command = sys.argv[1].lower()

    # Guard clauses: unknown command, missing argument, missing file.
    if command != 'autobuild':
        print(f"❌ Unknown command: {command}")
        return

    if len(sys.argv) < 3:
        print("❌ CSV file path required")
        return

    csv_path = sys.argv[2]
    if not Path(csv_path).exists():
        print(f"❌ File not found: {csv_path}")
        return

    SynonymManager(fast_mode=fast_mode).auto_build_from_categories(csv_path, top_terms=1000)


if __name__ == "__main__":
    main()
train_products.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ train.py
4
+ Build normalized embeddings + FAISS index for category catalog,
5
+ build parent embeddings, save synonyms from tags.json and optionally
6
+ train a LightGBM classifier and a simple confidence calibrator.
7
+
8
+ Assumptions / Files:
9
+ - categories CSV: category_only_path.csv (Category_ID,Category_path,Final_Category)
10
+ - optional: data/tags.json (map category_id -> list of phrases)
11
+ - optional: validation.csv (columns: product_title,category_id) used for calibrator / classifier
12
+
13
+ Outputs to ./cache:
14
+ - main_index.faiss
15
+ - metadata.pkl
16
+ - parent_embeddings.pkl
17
+ - cross_store_synonyms.pkl
18
+ - model_info.json
19
+ - calibrator.pkl (if validation exists)
20
+ - classifier.pkl (if --train-classifier used)
21
+ """
22
+
23
+ import argparse
24
+ import json
25
+ import os
26
+ import pickle
27
+ from pathlib import Path
28
+ from typing import List, Dict
29
+
30
+ import numpy as np
31
+ import pandas as pd
32
+ from tqdm import tqdm
33
+
34
+ # sentence-transformers + faiss
35
+ from sentence_transformers import SentenceTransformer
36
+ import faiss
37
+
38
+ # sklearn for calibrator and simple preprocessing
39
+ from sklearn.linear_model import LogisticRegression
40
+ from sklearn.preprocessing import StandardScaler
41
+ from sklearn.model_selection import train_test_split
42
+
43
+ # optional LightGBM (install if you plan to train classifier)
44
+ try:
45
+ import importlib
46
+ lgb = importlib.import_module("lightgbm")
47
+ LGB_AVAILABLE = True
48
+ except Exception:
49
+ lgb = None
50
+ LGB_AVAILABLE = False
51
+
52
+ CACHE_DIR = Path("cache")
53
+ CACHE_DIR.mkdir(exist_ok=True, parents=True)
54
+
55
+ DEFAULT_BATCH_SIZE_CPU = 256
56
+ DEFAULT_BATCH_SIZE_GPU = 16
57
+
58
+
59
def normalize_path_sep(path: str) -> str:
    """Normalize a category path to ' > '-separated levels.

    Slashes are treated as separators, empty segments are dropped, and
    each level is stripped of surrounding whitespace. Non-string input
    yields an empty string.
    """
    if not isinstance(path, str):
        return ""
    with_arrows = path.strip().replace("/", " > ")
    segments = (part.strip() for part in with_arrows.split(">"))
    return " > ".join(segment for segment in segments if segment)
66
+
67
+
68
def path_to_levels(path: str) -> List[str]:
    """Split a raw category path into its cleaned, non-empty level names."""
    normalized = normalize_path_sep(path)
    stripped = (part.strip() for part in normalized.split(" > "))
    return [level for level in stripped if level]
71
+
72
+
73
def safe_pickle_save(obj, p: Path):
    """Serialize *obj* with pickle into the file at *p* (overwrites)."""
    fh = open(p, "wb")
    try:
        pickle.dump(obj, fh)
    finally:
        fh.close()
76
+
77
+
78
def build_encoder(model_name: str, use_cuda: bool):
    """Load a SentenceTransformer encoder, optionally on CUDA with FP16.

    The FP16 conversion is best-effort: any failure (e.g. torch missing)
    silently leaves the model in full precision.
    """
    device = "cuda" if use_cuda else "cpu"
    print(f"Loading encoder: {model_name} on {device}")
    encoder = SentenceTransformer(model_name, device=device)

    if not use_cuda:
        return encoder

    try:
        import torch  # needed so .half() has a backend to work with
        encoder = encoder.half()
        print("Using FP16 on GPU to conserve VRAM.")
    except Exception:
        pass
    return encoder
90
+
91
+
92
def encode_texts(model: SentenceTransformer, texts: List[str], use_cuda: bool) -> np.ndarray:
    """Embed *texts* batch-by-batch into L2-normalized float32 vectors.

    Batches are smaller on GPU to stay within limited VRAM; progress is
    reported via tqdm.
    """
    batch_size = DEFAULT_BATCH_SIZE_GPU if use_cuda else DEFAULT_BATCH_SIZE_CPU
    print(f"Encoding {len(texts):,} texts in batches of {batch_size} ...")
    chunks = []
    for start in tqdm(range(0, len(texts), batch_size)):
        vecs = model.encode(
            texts[start:start + batch_size],
            convert_to_numpy=True,
            normalize_embeddings=True,
            show_progress_bar=False,
        )
        # A single-text batch may come back 1-D; promote it to a row matrix.
        if vecs.ndim == 1:
            vecs = vecs.reshape(1, -1)
        chunks.append(vecs.astype("float32"))
    embeddings = np.vstack(chunks)
    print("Final embeddings shape:", embeddings.shape)
    return embeddings
105
+
106
+
107
def build_faiss_index(np_emb: np.ndarray, use_gpu: bool = False):
    """Build a FAISS inner-product index over already-normalized vectors.

    With normalized embeddings, inner product equals cosine similarity.
    GPU conversion is best-effort and falls back to CPU on failure.
    """
    dim = np_emb.shape[1]
    device_label = 'GPU' if use_gpu else 'CPU'
    print(f"Building IndexFlatIP (d={dim}) on {device_label}")
    index = faiss.IndexFlatIP(dim)

    if use_gpu:
        try:
            gpu_resources = faiss.StandardGpuResources()
            index = faiss.index_cpu_to_gpu(gpu_resources, 0, index)
            print("Converted FAISS index to GPU")
        except Exception as e:
            print("GPU conversion failed; using CPU index:", e)

    index.add(np_emb)
    print("Index ntotal:", index.ntotal)
    return index
121
+
122
+
123
def make_parent_embeddings(metadata: List[Dict], embeddings: np.ndarray) -> Dict[str, np.ndarray]:
    """
    Build an embedding for every internal (non-leaf) node of the category tree.

    Every strict prefix of each category path contributes a parent key
    ("A > B > C" contributes "A" and "A > B"); each parent's vector is the
    L2-normalized average of the embeddings of all leaf categories beneath
    it. Used for hierarchical boosting during inference.

    Args:
        metadata: per-category dicts, each holding a "levels" list of path parts.
        embeddings: (len(metadata), dim) matrix, row-aligned with metadata.

    Returns:
        Mapping of parent path string -> normalized float32 vector.
    """
    sum_map: Dict[str, np.ndarray] = {}
    count_map: Dict[str, int] = {}
    dim = embeddings.shape[1]

    for i, meta in enumerate(metadata):
        levels = meta.get("levels", [])
        # range(1, len(levels)) deliberately excludes the full (leaf) path.
        for depth in range(1, len(levels)):
            parent = " > ".join(levels[:depth])
            if not parent:
                continue
            if parent not in sum_map:
                sum_map[parent] = np.zeros(dim, dtype="float32")
                count_map[parent] = 0
            sum_map[parent] += embeddings[i]
            count_map[parent] += 1

    # Average each accumulated sum and L2-normalize it.
    final: Dict[str, np.ndarray] = {}
    for parent, vec in sum_map.items():
        avg = vec / float(count_map[parent])
        # Epsilon guards against a zero-norm average.
        nrm = np.linalg.norm(avg) + 1e-12
        final[parent] = (avg / nrm).astype("float32")
    return final
150
+
151
+
152
def load_tags_json(path: Path) -> Dict[str, List[str]]:
    """Read a tags.json mapping of category_id -> phrase list.

    Keys and all phrase values are coerced to plain strings. Returns an
    empty dict when the file is missing or cannot be parsed.
    """
    if not path.exists():
        return {}
    try:
        raw = json.loads(path.read_text(encoding="utf-8"))
        # Normalize: keys and every phrase become strings.
        return {str(key): [str(item) for item in values] for key, values in raw.items()}
    except Exception as e:
        print("Failed to load tags.json:", e)
        return {}
163
+
164
+
165
def train_calibrator(encoder, metadata, faiss_index, val_path: Path, model_name: str, use_cuda: bool):
    """
    Build a simple calibrator mapping raw cosine similarity of (product -> true category emb)
    to a probability. Uses sklearn LogisticRegression on one feature (raw_score).
    Expects validation.csv with columns product_title,category_id

    Args:
        encoder: SentenceTransformer used to embed product titles.
        metadata: category dicts; each must carry a temporary "_embedding"
            vector (attached by attach_embeddings_to_metadata) or the row
            is skipped.
        faiss_index: unused in the body; kept in the signature.
        val_path: path to the validation CSV.
        model_name: unused in the body; kept in the signature.
        use_cuda: forwarded to encode_texts for batch sizing.

    Returns:
        {"calibrator": LogisticRegression, "scaler": StandardScaler} on
        success, or None when the CSV lacks the required columns or yields
        no usable examples.
    """
    print("Training calibrator using:", val_path)
    df = pd.read_csv(val_path, dtype=str, keep_default_na=False)
    if "product_title" not in df.columns or "category_id" not in df.columns:
        print("validation.csv must have 'product_title' and 'category_id' columns. Skipping calibrator.")
        return None

    examples = []
    labels = []
    # Build a mapping category_id -> embedding (from metadata)
    id_to_idx = {m["category_id"]: i for i, m in enumerate(metadata)}

    # prepare product embeddings in batches
    titles = df["product_title"].astype(str).tolist()
    prod_embs = encode_texts(encoder, [f"query: {t}" for t in titles], use_cuda=use_cuda)

    # NOTE(review): prod_embs[i] assumes df has the default RangeIndex so
    # iterrows() index equals the positional row -- true for a fresh
    # read_csv, but would break if df were ever filtered beforehand.
    for i, row in df.iterrows():
        cid = str(row["category_id"]).strip()
        if cid not in id_to_idx:
            # not in catalog, skip sample
            continue
        cat_idx = id_to_idx[cid]
        cat_emb = metadata[cat_idx].get("_embedding")  # we will attach embeddings later temporarily
        if cat_emb is None:
            continue
        q_emb = prod_embs[i].reshape(1, -1).astype("float32")
        raw = float(np.dot(q_emb, cat_emb.reshape(-1, 1))[0][0])  # cosine because normalized
        # positive
        examples.append([raw])
        labels.append(1)

        # generate few negatives by sampling other categories
        # sample up to 2 random negatives
        negs = 2
        for _ in range(negs):
            # NOTE(review): import inside the loop is a repeated no-op after
            # the first call; hoisting it to module level would be cleaner.
            import random
            rand_idx = random.randrange(len(metadata))
            if rand_idx == cat_idx:
                continue
            neg_emb = metadata[rand_idx].get("_embedding")
            if neg_emb is None:
                continue
            raw_neg = float(np.dot(q_emb, neg_emb.reshape(-1, 1))[0][0])
            examples.append([raw_neg])
            labels.append(0)

    if not examples:
        print("No examples for calibrator (maybe category ids mismatch). Skipping.")
        return None

    # Fit a one-feature logistic regression on the standardized raw scores.
    X = np.array(examples, dtype="float32")
    y = np.array(labels, dtype="int8")
    scaler = StandardScaler()
    Xs = scaler.fit_transform(X)
    clf = LogisticRegression(max_iter=200)
    clf.fit(Xs, y)
    print("Calibrator trained (logistic regression on raw cosine).")
    return {"calibrator": clf, "scaler": scaler}
228
+
229
+
230
def attach_embeddings_to_metadata(metadata: List[Dict], embeddings: np.ndarray):
    """Stash each embedding row on its metadata dict under "_embedding".

    Rows are matched by position; metadata and embeddings must be aligned.
    """
    for idx in range(len(metadata)):
        metadata[idx]["_embedding"] = embeddings[idx]
233
+
234
+
235
def detach_embeddings_from_metadata(metadata: List[Dict]):
    """Strip the temporary "_embedding" key from every metadata dict."""
    for meta in metadata:
        meta.pop("_embedding", None)
239
+
240
+
241
def main():
    """End-to-end build: read categories CSV, embed, index, save artifacts.

    Pipeline:
      1. parse CLI args and load the category CSV (the first two columns
         are treated as category id and category path),
      2. encode one canonical text per category, build parent embeddings,
      3. write the FAISS index plus pickled metadata/parent embeddings and
         model_info.json to cache/,
      4. optionally save tags.json as cross_store_synonyms.pkl, train a
         confidence calibrator and a LightGBM classifier from
         validation.csv,
      5. optionally prune cache/ down to the core artifacts.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--csv", required=True, help="categories CSV (Category_ID,Category_path,Final_Category)")
    parser.add_argument("--model", default="intfloat/e5-base-v2", help="embedding model")
    parser.add_argument("--gpu", action="store_true", help="use GPU for encoding if available (careful with 4GB)")
    parser.add_argument("--clean-cache", action="store_true", help="delete other cache files after build")
    parser.add_argument("--train-classifier", action="store_true", help="train LightGBM classifier on validation.csv (optional)")
    parser.add_argument("--validation", default="data/validation.csv", help="validation CSV used for calibrator / classifier")
    parser.add_argument("--tags", default="data/tags.json", help="tags.json path (optional)")
    args = parser.parse_args()

    csv_path = Path(args.csv)
    if not csv_path.exists():
        raise SystemExit("CSV not found: " + str(csv_path))

    print("Reading CSV:", csv_path)
    df = pd.read_csv(csv_path, dtype=str, keep_default_na=False)
    if df.shape[1] < 2:
        raise SystemExit("CSV must have at least 2 columns: Category_ID, Category_path")

    # columns: positional -- first is the id, second is the path
    cols = list(df.columns)
    cid_col, path_col = cols[0], cols[1]
    print("Using columns:", cid_col, path_col)

    # Build one metadata record and one canonical encoding text per row.
    metadata = []
    texts_for_encoding = []
    for idx, row in df.iterrows():
        cid = str(row[cid_col]).strip()
        raw_path = str(row[path_col]).strip()
        norm_path = normalize_path_sep(raw_path)
        levels = path_to_levels(norm_path)
        final = levels[-1] if levels else norm_path or cid
        # include both path and final in canonical text to encode
        text = f"category: {norm_path}. leaf: {final}."
        metadata.append({
            "category_id": cid,
            "category_path": norm_path,
            "final": final,
            "levels": levels,
            "depth": len(levels)
        })
        texts_for_encoding.append(text)

    print(f"Prepared {len(metadata):,} metadata entries")

    # encoder
    use_cuda = args.gpu
    encoder = build_encoder(args.model, use_cuda=use_cuda)

    # encode categories (normalized float32 vectors)
    cat_embeddings = encode_texts(encoder, texts_for_encoding, use_cuda=use_cuda)

    # Attach embeddings temporarily for calibrator builder
    attach_embeddings_to_metadata(metadata, cat_embeddings)

    # parent embeddings (averaged per internal tree node)
    parent_emb = make_parent_embeddings(metadata, cat_embeddings)
    print(f"Built {len(parent_emb):,} parent embeddings")

    # Build CPU FAISS index (IP on normalized vectors -> cosine)
    index = build_faiss_index(cat_embeddings, use_gpu=False)

    # save index (FAISS CPU index)
    faiss_path = CACHE_DIR / "main_index.faiss"
    faiss.write_index(index, str(faiss_path))
    print("Saved FAISS index:", faiss_path)

    # save metadata (we will strip embeddings before saving to reduce pickle size)
    detach_embeddings_from_metadata(metadata)
    meta_path = CACHE_DIR / "metadata.pkl"
    safe_pickle_save(metadata, meta_path)
    print("Saved metadata:", meta_path)

    # save parent embeddings
    parent_path = CACHE_DIR / "parent_embeddings.pkl"
    safe_pickle_save(parent_emb, parent_path)
    print("Saved parent embeddings:", parent_path)

    # save model_info (consumed by the inference side to pick the encoder)
    info = {
        "model_name": args.model,
        "num_categories": len(metadata),
        "embedding_dim": cat_embeddings.shape[1]
    }
    with open(CACHE_DIR / "model_info.json", "w", encoding="utf-8") as f:
        json.dump(info, f, indent=2)
    print("Saved model_info.json")

    # store tags.json -> cross_store_synonyms (just preserve structure)
    tags = load_tags_json(Path(args.tags))
    if tags:
        syn_p = CACHE_DIR / "cross_store_synonyms.pkl"
        safe_pickle_save(tags, syn_p)
        print("Saved cross_store_synonyms.pkl from tags.json (size: %d)" % len(tags))

    # calibrator: use validation.csv if exists
    val_path = Path(args.validation)
    calibrator_obj = None
    if val_path.exists():
        # we need embeddings attached again for calibrator training
        attach_embeddings_to_metadata(metadata, cat_embeddings)
        calibrator_obj = train_calibrator(encoder, metadata, index, val_path, args.model, use_cuda=use_cuda)
        detach_embeddings_from_metadata(metadata)
        if calibrator_obj:
            safe_pickle_save(calibrator_obj, CACHE_DIR / "calibrator.pkl")
            print("Saved calibrator.pkl")

    # optional LightGBM classifier
    if args.train_classifier:
        if not LGB_AVAILABLE:
            print("LightGBM not available. Install lightgbm to train classifier.")
        else:
            val_path2 = Path(args.validation)
            if not val_path2.exists():
                print("validation.csv required to train classifier. Skipping classifier training.")
            else:
                # create training set from validation.csv
                dfv = pd.read_csv(val_path2, dtype=str, keep_default_na=False)
                if "product_title" not in dfv.columns or "category_id" not in dfv.columns:
                    print("validation.csv must contain product_title and category_id. Skipping classifier.")
                else:
                    # encode product titles
                    prod_texts = [f"query: {t}" for t in dfv["product_title"].astype(str).tolist()]
                    prod_embs = encode_texts(encoder, prod_texts, use_cuda=use_cuda)
                    # map category ids to numeric labels
                    cat_to_label = {m["category_id"]: i for i, m in enumerate(metadata)}
                    labels = []
                    rows = []
                    # NOTE(review): prod_embs[i] assumes dfv keeps its default
                    # RangeIndex (true for a fresh read_csv) -- confirm if
                    # this CSV is ever pre-filtered.
                    for i, row in dfv.iterrows():
                        cid = row["category_id"]
                        if cid not in cat_to_label:
                            continue
                        labels.append(cat_to_label[cid])
                        rows.append(prod_embs[i])
                    if len(rows) < 50:
                        print("Not enough training rows for classifier. Need >=50. Skipping.")
                    else:
                        X = np.vstack(rows)
                        y = np.array(labels, dtype=np.int32)
                        X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.15, random_state=42, stratify=y)
                        lgb_train = lgb.Dataset(X_train, label=y_train)
                        lgb_eval = lgb.Dataset(X_val, label=y_val, reference=lgb_train)
                        params = {
                            "objective": "multiclass",
                            "num_class": int(max(y) + 1),
                            "metric": "multi_logloss",
                            "verbosity": -1,
                            "num_threads": 4,
                            "learning_rate": 0.1,
                            "num_leaves": 31
                        }
                        print("Training LightGBM classifier (may take time)...")
                        # NOTE(review): the early_stopping_rounds keyword was
                        # removed in LightGBM 4.x (use callbacks=[lgb.early_stopping(30)]);
                        # confirm the installed version supports this call.
                        gbm = lgb.train(params, lgb_train, valid_sets=[lgb_train, lgb_eval], early_stopping_rounds=30, num_boost_round=500)
                        # save classifier and mapping
                        clf_path = CACHE_DIR / "classifier.pkl"
                        safe_pickle_save({"model": gbm, "cat_to_label": cat_to_label, "label_to_cat": {v: k for k, v in cat_to_label.items()}}, clf_path)
                        print("Saved classifier.pkl")

    # cleanup if asked: keep only the core artifacts in cache/
    if args.clean_cache:
        keep = {"main_index.faiss", "metadata.pkl", "model_info.json", "parent_embeddings.pkl", "cross_store_synonyms.pkl"}
        if calibrator_obj:
            keep.add("calibrator.pkl")
        # remove everything else in cache
        removed = []
        for p in CACHE_DIR.iterdir():
            if p.name in keep:
                continue
            try:
                p.unlink()
                removed.append(p.name)
            except Exception:
                pass
        if removed:
            print("Removed cache files:", removed)

    print("DONE. Index + data saved to cache/")

if __name__ == "__main__":
    main()
validation_data.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ πŸ“Š VALIDATION DATA CREATOR
3
+ ===========================
4
+ Helper script to create validation CSV for confidence calibration.
5
+
6
+ Two modes:
7
+ 1. Sample from existing categories (automated)
8
+ 2. Manual entry (interactive)
9
+
10
+ Output format:
11
+ product_title,true_category_id
12
+ "Oxygen Sensor Tool",12345
13
+ "Hydraulic Oil Additive",67890
14
+
15
+ Usage:
16
+ # Automated sampling:
17
+ python create_validation_data.py auto data/category_id_path_only.csv
18
+
19
+ # Manual entry:
20
+ python create_validation_data.py manual
21
+ """
22
+
23
+ import pandas as pd
24
+ import sys
25
+ from pathlib import Path
26
+ import random
27
+
28
+
29
def sample_from_categories(csv_path, num_samples=100, output_file='data/validation.csv'):
    """
    Automatically create validation data by sampling from categories
    and generating product titles based on category paths.

    Args:
        csv_path: CSV whose first two columns are category_id and category_path.
        num_samples: Maximum number of validation rows to generate
            (capped at the number of available categories).
        output_file: Destination CSV path; parent directories are created.

    Returns:
        True on success, False if the input CSV has fewer than 2 columns.
    """
    print("\n" + "="*80)
    print("πŸ“Š AUTO-GENERATING VALIDATION DATA")
    print("="*80 + "\n")

    # Load categories
    print(f"Loading: {csv_path}")
    df = pd.read_csv(csv_path)

    if len(df.columns) < 2:
        print("❌ CSV must have at least 2 columns (category_id, category_path)")
        return False

    df.columns = ['category_id', 'category_path'] + list(df.columns[2:])
    df = df.dropna(subset=['category_path'])

    print(f"βœ… Loaded {len(df):,} categories\n")

    # Sample categories (fixed seed so the same rows come back every run)
    sample_size = min(num_samples, len(df))
    sampled = df.sample(n=sample_size, random_state=42)

    print(f"πŸ“ Generating {sample_size} validation entries...\n")

    # FIX: the original used the unseeded global `random.choice`, so the
    # generated titles differed on every run even though row sampling was
    # seeded. A local seeded RNG makes the whole file reproducible without
    # touching global random state.
    rng = random.Random(42)

    validation_data = []

    # itertuples avoids the per-row Series construction cost of iterrows
    for row in sampled.itertuples(index=False):
        cat_id = str(row.category_id)
        cat_path = str(row.category_path)

        # Use the last (up to) 3 levels of the category path as the title;
        # slicing handles short paths, so no length branching is needed.
        levels = cat_path.split('/')
        title_parts = levels[-3:]

        # Clean and combine
        title = ' '.join(part.strip() for part in title_parts).strip()

        # Add some variation so titles don't all mirror the raw path
        variations = [
            title,
            f"{title} kit",
            f"{title} tool",
            f"{title} set",
            f"professional {title}",
            f"{title} replacement",
        ]

        validation_data.append({
            'product_title': rng.choice(variations),
            'true_category_id': cat_id,
        })

    # Create DataFrame
    val_df = pd.DataFrame(validation_data)

    # Save (create parent dirs if missing)
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    val_df.to_csv(output_path, index=False)

    print(f"βœ… Created validation file: {output_path}")
    print(f"   Entries: {len(val_df):,}")

    # Show samples
    print("\nπŸ“ Sample entries:")
    for i, row in val_df.head(5).iterrows():
        print(f"   {i+1}. \"{row['product_title']}\" β†’ {row['true_category_id']}")

    print("\n" + "="*80)
    print("βœ… VALIDATION DATA CREATED!")
    print("="*80)
    print(f"\nNext step: Train with calibration")
    print(f"   python train_fixed_v2.py data/category_id_path_only.csv data/tags.json {output_path}")
    print("="*80 + "\n")

    return True
119
+
120
+
121
def manual_entry(output_file='data/validation_manual.csv'):
    """
    Interactively collect (product title, category ID) pairs from stdin
    and save them as a validation CSV.

    Entry continues until the user presses CTRL+C; rows with an empty
    title or category ID are rejected and re-prompted.

    Returns:
        True if at least one entry was saved, False otherwise.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("πŸ“ MANUAL VALIDATION DATA ENTRY")
    print(banner)
    print("\nEnter product titles and their correct category IDs.")
    print("Press CTRL+C when done.\n")

    entries = []

    try:
        while True:
            print(f"\n--- Entry #{len(entries) + 1} ---")

            title = input("Product title: ").strip()
            if not title:
                print("⚠️ Title cannot be empty")
                continue

            cat_id = input("Category ID: ").strip()
            if not cat_id:
                print("⚠️ Category ID cannot be empty")
                continue

            entries.append({'product_title': title, 'true_category_id': cat_id})
            print(f"βœ… Added: \"{title}\" β†’ {cat_id}")
    except KeyboardInterrupt:
        # CTRL+C is the intended way to finish entry, not an error condition.
        print("\n\nπŸ“Š Entry complete!")

    if not entries:
        print("❌ No entries created")
        return False

    val_df = pd.DataFrame(entries)

    # Persist, creating parent directories on demand
    destination = Path(output_file)
    destination.parent.mkdir(parents=True, exist_ok=True)
    val_df.to_csv(destination, index=False)

    print(f"\nβœ… Created validation file: {destination}")
    print(f"   Entries: {len(val_df):,}")

    print("\n" + banner)
    print("βœ… VALIDATION DATA CREATED!")
    print(banner)
    print(f"\nNext step: Train with calibration")
    print(f"   python train_fixed_v2.py data/category_id_path_only.csv data/tags.json {destination}")
    print(banner + "\n")

    return True
181
+
182
+
183
def verify_validation_file(validation_csv, categories_csv):
    """
    Verify that every validation row references a known category ID.

    Args:
        validation_csv: CSV with columns product_title, true_category_id.
        categories_csv: CSV whose first two columns are category_id and
            category_path.

    Returns:
        True if all validation entries reference valid category IDs,
        False if any entry is invalid or either CSV is malformed.
    """
    print("\n" + "="*80)
    print("πŸ” VERIFYING VALIDATION DATA")
    print("="*80 + "\n")

    # Load validation data
    print(f"Loading validation: {validation_csv}")
    val_df = pd.read_csv(validation_csv)

    if 'product_title' not in val_df.columns or 'true_category_id' not in val_df.columns:
        print("❌ Validation CSV must have: product_title, true_category_id")
        return False

    print(f"βœ… Loaded {len(val_df):,} validation entries\n")

    # Load categories
    print(f"Loading categories: {categories_csv}")
    cat_df = pd.read_csv(categories_csv)

    # FIX: guard against a malformed categories file; the original crashed
    # on the column rename when fewer than 2 columns were present.
    if len(cat_df.columns) < 2:
        print("❌ Categories CSV must have at least 2 columns (category_id, category_path)")
        return False

    cat_df.columns = ['category_id', 'category_path'] + list(cat_df.columns[2:])

    valid_ids = set(cat_df['category_id'].astype(str))
    print(f"βœ… Loaded {len(valid_ids):,} valid category IDs\n")

    # Verify — vectorized membership test instead of a per-row Python loop;
    # only the offending rows are iterated for reporting.
    print("Checking validation entries...")
    invalid_mask = ~val_df['true_category_id'].astype(str).isin(valid_ids)
    invalid_count = int(invalid_mask.sum())

    for _, row in val_df[invalid_mask].iterrows():
        print(f"❌ Invalid ID: {row['true_category_id']} for \"{row['product_title']}\"")

    if invalid_count == 0:
        print("βœ… All validation entries are valid!")
    else:
        print(f"\n⚠️ Found {invalid_count} invalid entries")

    # Summary
    print("\n" + "="*80)
    print("πŸ“Š VALIDATION DATA SUMMARY")
    print("="*80)
    print(f"Total entries: {len(val_df):,}")
    print(f"Valid entries: {len(val_df) - invalid_count:,}")
    print(f"Invalid entries: {invalid_count}")
    print("="*80 + "\n")

    return invalid_count == 0
236
+
237
+
238
def _print_creator_usage():
    """Print CLI usage and worked examples for the validation data creator."""
    print("Usage:")
    print("   python create_validation_data.py auto <csv_path> [num_samples] [output_file]")
    print("   python create_validation_data.py manual [output_file]")
    print("   python create_validation_data.py verify <validation_csv> <categories_csv>")
    print("\nExamples:")
    print("   # Auto-generate 100 samples:")
    print("   python create_validation_data.py auto data/category_id_path_only.csv")
    print()
    print("   # Auto-generate 200 samples:")
    print("   python create_validation_data.py auto data/category_id_path_only.csv 200")
    print()
    print("   # Manual entry:")
    print("   python create_validation_data.py manual")
    print()
    print("   # Verify validation file:")
    print("   python create_validation_data.py verify data/validation.csv data/category_id_path_only.csv")
    print()


def main():
    """Entry point: dispatch to auto / manual / verify based on sys.argv."""
    banner = "=" * 80
    print("\n" + banner)
    print("πŸ“Š VALIDATION DATA CREATOR")
    print(banner + "\n")

    args = sys.argv[1:]
    if not args:
        _print_creator_usage()
        return

    mode = args[0].lower()

    if mode == 'auto':
        if len(args) < 2:
            print("❌ CSV path required for auto mode")
            print("   python create_validation_data.py auto data/category_id_path_only.csv")
            return

        csv_path = args[1]
        num_samples = int(args[2]) if len(args) > 2 else 100
        output_file = args[3] if len(args) > 3 else 'data/validation.csv'

        if not Path(csv_path).exists():
            print(f"❌ File not found: {csv_path}")
            return

        sample_from_categories(csv_path, num_samples, output_file)

    elif mode == 'manual':
        output_file = args[1] if len(args) > 1 else 'data/validation_manual.csv'
        manual_entry(output_file)

    elif mode == 'verify':
        if len(args) < 3:
            print("❌ Both validation CSV and categories CSV required")
            print("   python create_validation_data.py verify data/validation.csv data/category_id_path_only.csv")
            return

        validation_csv, categories_csv = args[1], args[2]

        # Report the first missing input (same order as before: validation,
        # then categories) and bail out.
        for required in (validation_csv, categories_csv):
            if not Path(required).exists():
                print(f"❌ File not found: {required}")
                return

        verify_validation_file(validation_csv, categories_csv)

    else:
        print(f"❌ Unknown mode: {mode}")
        print("   Use: auto, manual, or verify")
308
+
309
+ if __name__ == "__main__":
310
+ main()