Spaces:
No application file
No application file
Upload 11 files
Browse files- .gitignore +16 -0
- api_server.py +1377 -0
- check.py +365 -0
- fix.py +270 -0
- gradio_app.py +259 -0
- miss.py +421 -0
- path.py +141 -0
- requirements.txt +28 -0
- synonyms.py +853 -365
- train_products.py +421 -0
- validation_data.py +310 -0
.gitignore
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.py[cod]
|
| 3 |
+
.Python
|
| 4 |
+
venv/
|
| 5 |
+
env/
|
| 6 |
+
.vscode/
|
| 7 |
+
.idea/
|
| 8 |
+
.DS_Store
|
| 9 |
+
*.bin
|
| 10 |
+
*.safetensors
|
| 11 |
+
*.log
|
| 12 |
+
cache/*.faiss
|
| 13 |
+
cache/*.npy
|
| 14 |
+
!cache/metadata.pkl
|
| 15 |
+
!cache/model_info.json
|
| 16 |
+
!cache/cross_store_synonyms.pkl
|
api_server.py
ADDED
|
@@ -0,0 +1,1377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ο»Ώ# """
|
| 2 |
+
# π― COMPLETE API SERVER - Matches Cross-Store Training System
|
| 3 |
+
# =============================================================
|
| 4 |
+
# β
Works with cross-store synonyms (washing machine = laundry machine)
|
| 5 |
+
# β
Uses auto-tags from training
|
| 6 |
+
# β
Single model (fast predictions)
|
| 7 |
+
# β
Guaranteed category_id match
|
| 8 |
+
# β
Real-time classification
|
| 9 |
+
# """
|
| 10 |
+
|
| 11 |
+
# from flask import Flask, request, jsonify, render_template_string
|
| 12 |
+
# from sentence_transformers import SentenceTransformer
|
| 13 |
+
# import faiss
|
| 14 |
+
# import pickle
|
| 15 |
+
# import numpy as np
|
| 16 |
+
# from pathlib import Path
|
| 17 |
+
# import time
|
| 18 |
+
# import re
|
| 19 |
+
|
| 20 |
+
# app = Flask(__name__)
|
| 21 |
+
|
| 22 |
+
# # ============================================================================
|
| 23 |
+
# # GLOBAL VARIABLES
|
| 24 |
+
# # ============================================================================
|
| 25 |
+
|
| 26 |
+
# CACHE_DIR = Path('cache')
|
| 27 |
+
|
| 28 |
+
# # Model
|
| 29 |
+
# encoder = None
|
| 30 |
+
# faiss_index = None
|
| 31 |
+
# metadata = []
|
| 32 |
+
# cross_store_synonyms = {}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# # ============================================================================
|
| 36 |
+
# # CROSS-STORE SYNONYM DATABASE (Same as training)
|
| 37 |
+
# # ============================================================================
|
| 38 |
+
|
| 39 |
+
# def build_cross_store_synonyms():
|
| 40 |
+
# """Build cross-store synonym database"""
|
| 41 |
+
# synonyms = {
|
| 42 |
+
# # Appliances
|
| 43 |
+
# 'washing machine': {'laundry machine', 'washer', 'clothes washer', 'washing appliance'},
|
| 44 |
+
# 'laundry machine': {'washing machine', 'washer', 'clothes washer'},
|
| 45 |
+
# 'dryer': {'drying machine', 'clothes dryer', 'tumble dryer'},
|
| 46 |
+
# 'refrigerator': {'fridge', 'cooler', 'ice box', 'cooling appliance'},
|
| 47 |
+
# 'dishwasher': {'dish washer', 'dish cleaning machine'},
|
| 48 |
+
# 'microwave': {'microwave oven', 'micro wave'},
|
| 49 |
+
# 'vacuum': {'vacuum cleaner', 'hoover', 'vac'},
|
| 50 |
+
|
| 51 |
+
# # Electronics
|
| 52 |
+
# 'tv': {'television', 'telly', 'smart tv', 'display'},
|
| 53 |
+
# 'laptop': {'notebook', 'portable computer', 'laptop computer'},
|
| 54 |
+
# 'mobile': {'phone', 'cell phone', 'smartphone', 'cellphone'},
|
| 55 |
+
# 'tablet': {'ipad', 'tab', 'tablet computer'},
|
| 56 |
+
# 'headphones': {'headset', 'earphones', 'earbuds', 'ear buds'},
|
| 57 |
+
# 'speaker': {'audio speaker', 'sound system', 'speakers'},
|
| 58 |
+
|
| 59 |
+
# # Furniture
|
| 60 |
+
# 'sofa': {'couch', 'settee', 'divan'},
|
| 61 |
+
# 'wardrobe': {'closet', 'armoire', 'cupboard'},
|
| 62 |
+
# 'drawer': {'chest of drawers', 'dresser'},
|
| 63 |
+
|
| 64 |
+
# # Clothing
|
| 65 |
+
# 'pants': {'trousers', 'slacks', 'bottoms'},
|
| 66 |
+
# 'sweater': {'jumper', 'pullover', 'sweatshirt'},
|
| 67 |
+
# 'sneakers': {'trainers', 'tennis shoes', 'running shoes'},
|
| 68 |
+
# 'jacket': {'coat', 'blazer', 'outerwear'},
|
| 69 |
+
|
| 70 |
+
# # Kitchen
|
| 71 |
+
# 'cooker': {'stove', 'range', 'cooking range'},
|
| 72 |
+
# 'blender': {'mixer', 'food processor', 'liquidizer'},
|
| 73 |
+
# 'kettle': {'electric kettle', 'water boiler'},
|
| 74 |
+
|
| 75 |
+
# # Baby/Kids
|
| 76 |
+
# 'stroller': {'pram', 'pushchair', 'buggy', 'baby carriage'},
|
| 77 |
+
# 'diaper': {'nappy', 'nappies'},
|
| 78 |
+
# 'pacifier': {'dummy', 'soother'},
|
| 79 |
+
|
| 80 |
+
# # Tools
|
| 81 |
+
# 'wrench': {'spanner', 'adjustable wrench'},
|
| 82 |
+
# 'flashlight': {'torch', 'flash light'},
|
| 83 |
+
# 'screwdriver': {'screw driver'},
|
| 84 |
+
|
| 85 |
+
# # Home
|
| 86 |
+
# 'tap': {'faucet', 'water tap'},
|
| 87 |
+
# 'bin': {'trash can', 'garbage can', 'waste bin'},
|
| 88 |
+
# 'curtain': {'drape', 'window covering'},
|
| 89 |
+
|
| 90 |
+
# # Crafts/Office
|
| 91 |
+
# 'guillotine': {'paper cutter', 'paper trimmer', 'blade cutter'},
|
| 92 |
+
# 'trimmer': {'cutter', 'cutting tool', 'edge cutter'},
|
| 93 |
+
# 'stapler': {'stapling machine', 'staple gun'},
|
| 94 |
+
|
| 95 |
+
# # Books/Media
|
| 96 |
+
# 'magazine': {'periodical', 'journal', 'publication'},
|
| 97 |
+
# 'comic': {'comic book', 'graphic novel', 'manga'},
|
| 98 |
+
# 'ebook': {'e-book', 'digital book', 'electronic book'},
|
| 99 |
+
|
| 100 |
+
# # General
|
| 101 |
+
# 'kids': {'children', 'child', 'childrens', 'youth', 'junior'},
|
| 102 |
+
# 'women': {'womens', 'ladies', 'female', 'lady'},
|
| 103 |
+
# 'men': {'mens', 'male', 'gentleman'},
|
| 104 |
+
# 'baby': {'infant', 'newborn', 'toddler'},
|
| 105 |
+
# }
|
| 106 |
+
|
| 107 |
+
# # Build bidirectional mapping
|
| 108 |
+
# expanded = {}
|
| 109 |
+
# for term, syns in synonyms.items():
|
| 110 |
+
# expanded[term] = syns.copy()
|
| 111 |
+
# for syn in syns:
|
| 112 |
+
# if syn not in expanded:
|
| 113 |
+
# expanded[syn] = set()
|
| 114 |
+
# expanded[syn].add(term)
|
| 115 |
+
# expanded[syn].update(syns - {syn})
|
| 116 |
+
|
| 117 |
+
# return expanded
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# # ============================================================================
|
| 121 |
+
# # HELPER FUNCTIONS
|
| 122 |
+
# # ============================================================================
|
| 123 |
+
|
| 124 |
+
# def clean_text(text):
|
| 125 |
+
# """Clean and normalize text"""
|
| 126 |
+
# if not text:
|
| 127 |
+
# return ""
|
| 128 |
+
# text = str(text).lower()
|
| 129 |
+
# text = re.sub(r'[^\w\s-]', ' ', text)
|
| 130 |
+
# text = re.sub(r'\s+', ' ', text).strip()
|
| 131 |
+
# return text
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# def extract_cross_store_terms(text):
|
| 135 |
+
# """Extract terms with cross-store variations"""
|
| 136 |
+
# cleaned = clean_text(text)
|
| 137 |
+
# words = cleaned.split()
|
| 138 |
+
|
| 139 |
+
# all_terms = set()
|
| 140 |
+
# all_terms.add(cleaned) # Full text
|
| 141 |
+
|
| 142 |
+
# # Single words
|
| 143 |
+
# for word in words:
|
| 144 |
+
# if len(word) > 2:
|
| 145 |
+
# all_terms.add(word)
|
| 146 |
+
# # Add cross-store synonyms
|
| 147 |
+
# if word in cross_store_synonyms:
|
| 148 |
+
# all_terms.update(cross_store_synonyms[word])
|
| 149 |
+
|
| 150 |
+
# # 2-word phrases
|
| 151 |
+
# for i in range(len(words) - 1):
|
| 152 |
+
# if len(words[i]) > 2 and len(words[i+1]) > 2:
|
| 153 |
+
# phrase = f"{words[i]} {words[i+1]}"
|
| 154 |
+
# all_terms.add(phrase)
|
| 155 |
+
# if phrase in cross_store_synonyms:
|
| 156 |
+
# all_terms.update(cross_store_synonyms[phrase])
|
| 157 |
+
|
| 158 |
+
# # 3-word phrases
|
| 159 |
+
# if len(words) >= 3:
|
| 160 |
+
# for i in range(len(words) - 2):
|
| 161 |
+
# if all(len(w) > 2 for w in words[i:i+3]):
|
| 162 |
+
# phrase = f"{words[i]} {words[i+1]} {words[i+2]}"
|
| 163 |
+
# all_terms.add(phrase)
|
| 164 |
+
|
| 165 |
+
# return list(all_terms)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
# def build_enhanced_query(title, description=""):
|
| 169 |
+
# """Build enhanced query with cross-store intelligence"""
|
| 170 |
+
# # Extract terms with variations
|
| 171 |
+
# all_terms = extract_cross_store_terms(f"{title} {description}")
|
| 172 |
+
|
| 173 |
+
# # Clean product terms
|
| 174 |
+
# product_terms = [t for t in clean_text(f"{title} {description}").split() if len(t) > 2]
|
| 175 |
+
|
| 176 |
+
# # Build query
|
| 177 |
+
# # Emphasize original + all variations
|
| 178 |
+
# product_text = ' '.join(product_terms)
|
| 179 |
+
# variations_text = ' '.join(all_terms[:30]) # Top 30 variations
|
| 180 |
+
|
| 181 |
+
# # Repeat for emphasis
|
| 182 |
+
# emphasized = ' '.join([product_text] * 3)
|
| 183 |
+
|
| 184 |
+
# query = f"{emphasized} {variations_text} {title} {description}"
|
| 185 |
+
|
| 186 |
+
# return query, all_terms[:20]
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# def encode_query(text):
|
| 190 |
+
# """Encode query using the trained model"""
|
| 191 |
+
# embedding = encoder.encode(
|
| 192 |
+
# text,
|
| 193 |
+
# convert_to_numpy=True,
|
| 194 |
+
# normalize_embeddings=True
|
| 195 |
+
# )
|
| 196 |
+
|
| 197 |
+
# if embedding.ndim == 1:
|
| 198 |
+
# embedding = embedding.reshape(1, -1)
|
| 199 |
+
|
| 200 |
+
# return embedding.astype('float32')
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
# def classify_product(title, description="", top_k=5):
|
| 204 |
+
# """
|
| 205 |
+
# Classify product using trained system
|
| 206 |
+
# Returns: category_id, category_path, confidence, and alternatives
|
| 207 |
+
# """
|
| 208 |
+
# start_time = time.time()
|
| 209 |
+
|
| 210 |
+
# # Step 1: Build enhanced query with cross-store synonyms
|
| 211 |
+
# query, matched_terms = build_enhanced_query(title, description)
|
| 212 |
+
|
| 213 |
+
# # Step 2: Encode query
|
| 214 |
+
# query_embedding = encode_query(query)
|
| 215 |
+
|
| 216 |
+
# # Step 3: Search FAISS index
|
| 217 |
+
# distances, indices = faiss_index.search(query_embedding, top_k)
|
| 218 |
+
|
| 219 |
+
# # Step 4: Get results
|
| 220 |
+
# results = []
|
| 221 |
+
# for i in range(len(indices[0])):
|
| 222 |
+
# idx = indices[0][i]
|
| 223 |
+
# if idx < len(metadata):
|
| 224 |
+
# meta = metadata[idx]
|
| 225 |
+
# confidence = float(distances[0][i]) * 100
|
| 226 |
+
|
| 227 |
+
# # Get final product name
|
| 228 |
+
# levels = meta.get('levels', [])
|
| 229 |
+
# final_product = levels[-1] if levels else meta['category_path'].split('/')[-1]
|
| 230 |
+
|
| 231 |
+
# results.append({
|
| 232 |
+
# 'rank': i + 1,
|
| 233 |
+
# 'category_id': meta['category_id'],
|
| 234 |
+
# 'category_path': meta['category_path'],
|
| 235 |
+
# 'final_product': final_product,
|
| 236 |
+
# 'confidence': round(confidence, 2),
|
| 237 |
+
# 'depth': meta.get('depth', 0)
|
| 238 |
+
# })
|
| 239 |
+
|
| 240 |
+
# # Best result
|
| 241 |
+
# best = results[0] if results else None
|
| 242 |
+
|
| 243 |
+
# if not best:
|
| 244 |
+
# return {
|
| 245 |
+
# 'error': 'No results found',
|
| 246 |
+
# 'product': title
|
| 247 |
+
# }
|
| 248 |
+
|
| 249 |
+
# # Confidence level
|
| 250 |
+
# conf_pct = best['confidence']
|
| 251 |
+
# if conf_pct >= 90:
|
| 252 |
+
# conf_level = "EXCELLENT"
|
| 253 |
+
# elif conf_pct >= 85:
|
| 254 |
+
# conf_level = "VERY HIGH"
|
| 255 |
+
# elif conf_pct >= 80:
|
| 256 |
+
# conf_level = "HIGH"
|
| 257 |
+
# elif conf_pct >= 75:
|
| 258 |
+
# conf_level = "GOOD"
|
| 259 |
+
# elif conf_pct >= 70:
|
| 260 |
+
# conf_level = "MEDIUM"
|
| 261 |
+
# else:
|
| 262 |
+
# conf_level = "LOW"
|
| 263 |
+
|
| 264 |
+
# processing_time = (time.time() - start_time) * 1000
|
| 265 |
+
|
| 266 |
+
# return {
|
| 267 |
+
# 'product': title,
|
| 268 |
+
# 'category_id': best['category_id'],
|
| 269 |
+
# 'category_path': best['category_path'],
|
| 270 |
+
# 'final_product': best['final_product'],
|
| 271 |
+
# 'confidence': f"{conf_level} ({conf_pct:.2f}%)",
|
| 272 |
+
# 'confidence_percent': conf_pct,
|
| 273 |
+
# 'depth': best['depth'],
|
| 274 |
+
# 'matched_terms': matched_terms,
|
| 275 |
+
# 'top_5_results': results,
|
| 276 |
+
# 'processing_time_ms': round(processing_time, 2)
|
| 277 |
+
# }
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
# # ============================================================================
|
| 281 |
+
# # SERVER INITIALIZATION
|
| 282 |
+
# # ============================================================================
|
| 283 |
+
|
| 284 |
+
# def load_server():
|
| 285 |
+
# """Load all trained data"""
|
| 286 |
+
# global encoder, faiss_index, metadata, cross_store_synonyms
|
| 287 |
+
|
| 288 |
+
# print("\n" + "="*80)
|
| 289 |
+
# print("π LOADING TRAINED MODEL")
|
| 290 |
+
# print("="*80 + "\n")
|
| 291 |
+
|
| 292 |
+
# # Load model
|
| 293 |
+
# print("π₯ Loading sentence transformer...")
|
| 294 |
+
# encoder = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
|
| 295 |
+
# print("β
Model loaded\n")
|
| 296 |
+
|
| 297 |
+
# # Load FAISS index
|
| 298 |
+
# print("π₯ Loading FAISS index...")
|
| 299 |
+
# index_path = CACHE_DIR / 'main_index.faiss'
|
| 300 |
+
# if not index_path.exists():
|
| 301 |
+
# raise FileNotFoundError(f"FAISS index not found: {index_path}\nPlease run training first!")
|
| 302 |
+
# faiss_index = faiss.read_index(str(index_path))
|
| 303 |
+
# print(f"β
Index loaded ({faiss_index.ntotal:,} vectors)\n")
|
| 304 |
+
|
| 305 |
+
# # Load metadata
|
| 306 |
+
# print("π₯ Loading metadata...")
|
| 307 |
+
# meta_path = CACHE_DIR / 'metadata.pkl'
|
| 308 |
+
# if not meta_path.exists():
|
| 309 |
+
# raise FileNotFoundError(f"Metadata not found: {meta_path}\nPlease run training first!")
|
| 310 |
+
# with open(meta_path, 'rb') as f:
|
| 311 |
+
# metadata = pickle.load(f)
|
| 312 |
+
# print(f"β
Metadata loaded ({len(metadata):,} categories)\n")
|
| 313 |
+
|
| 314 |
+
# # Load cross-store synonyms
|
| 315 |
+
# print("π₯ Loading cross-store synonyms...")
|
| 316 |
+
# syn_path = CACHE_DIR / 'cross_store_synonyms.pkl'
|
| 317 |
+
# if syn_path.exists():
|
| 318 |
+
# with open(syn_path, 'rb') as f:
|
| 319 |
+
# cross_store_synonyms = pickle.load(f)
|
| 320 |
+
# print(f"β
Cross-store synonyms loaded ({len(cross_store_synonyms)} terms)\n")
|
| 321 |
+
# else:
|
| 322 |
+
# print("β οΈ Cross-store synonyms not found, building default set...")
|
| 323 |
+
# cross_store_synonyms = build_cross_store_synonyms()
|
| 324 |
+
# print(f"β
Built {len(cross_store_synonyms)} synonym mappings\n")
|
| 325 |
+
|
| 326 |
+
# print("="*80)
|
| 327 |
+
# print("β
SERVER READY!")
|
| 328 |
+
# print("="*80 + "\n")
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
# # ============================================================================
|
| 332 |
+
# # HTML INTERFACE
|
| 333 |
+
# # ============================================================================
|
| 334 |
+
|
| 335 |
+
# HTML_TEMPLATE = """
|
| 336 |
+
# <!DOCTYPE html>
|
| 337 |
+
# <html>
|
| 338 |
+
# <head>
|
| 339 |
+
# <title>π― Product Category Classifier</title>
|
| 340 |
+
# <meta charset="UTF-8">
|
| 341 |
+
# <meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 342 |
+
# <style>
|
| 343 |
+
# * { margin: 0; padding: 0; box-sizing: border-box; }
|
| 344 |
+
# body {
|
| 345 |
+
# font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
| 346 |
+
# background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 347 |
+
# min-height: 100vh;
|
| 348 |
+
# padding: 20px;
|
| 349 |
+
# }
|
| 350 |
+
# .container { max-width: 1200px; margin: 0 auto; }
|
| 351 |
+
# .header {
|
| 352 |
+
# text-align: center;
|
| 353 |
+
# color: white;
|
| 354 |
+
# margin-bottom: 30px;
|
| 355 |
+
# }
|
| 356 |
+
# .header h1 { font-size: 2.5em; margin-bottom: 10px; }
|
| 357 |
+
# .badge {
|
| 358 |
+
# background: rgba(255,255,255,0.2);
|
| 359 |
+
# padding: 8px 20px;
|
| 360 |
+
# border-radius: 20px;
|
| 361 |
+
# display: inline-block;
|
| 362 |
+
# margin: 5px;
|
| 363 |
+
# font-size: 0.9em;
|
| 364 |
+
# }
|
| 365 |
+
# .card {
|
| 366 |
+
# background: white;
|
| 367 |
+
# border-radius: 20px;
|
| 368 |
+
# padding: 30px;
|
| 369 |
+
# box-shadow: 0 10px 40px rgba(0,0,0,0.2);
|
| 370 |
+
# }
|
| 371 |
+
# .success-box {
|
| 372 |
+
# background: #d4edda;
|
| 373 |
+
# padding: 15px;
|
| 374 |
+
# border-radius: 8px;
|
| 375 |
+
# margin-bottom: 20px;
|
| 376 |
+
# border-left: 4px solid #28a745;
|
| 377 |
+
# color: #155724;
|
| 378 |
+
# }
|
| 379 |
+
# .form-group { margin-bottom: 20px; }
|
| 380 |
+
# label {
|
| 381 |
+
# display: block;
|
| 382 |
+
# font-weight: 600;
|
| 383 |
+
# margin-bottom: 8px;
|
| 384 |
+
# color: #333;
|
| 385 |
+
# }
|
| 386 |
+
# input, textarea {
|
| 387 |
+
# width: 100%;
|
| 388 |
+
# padding: 12px;
|
| 389 |
+
# border: 2px solid #e0e0e0;
|
| 390 |
+
# border-radius: 8px;
|
| 391 |
+
# font-size: 1em;
|
| 392 |
+
# }
|
| 393 |
+
# input:focus, textarea:focus {
|
| 394 |
+
# outline: none;
|
| 395 |
+
# border-color: #667eea;
|
| 396 |
+
# }
|
| 397 |
+
# textarea { min-height: 80px; resize: vertical; }
|
| 398 |
+
# button {
|
| 399 |
+
# width: 100%;
|
| 400 |
+
# padding: 15px;
|
| 401 |
+
# background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 402 |
+
# color: white;
|
| 403 |
+
# border: none;
|
| 404 |
+
# border-radius: 10px;
|
| 405 |
+
# font-size: 1.1em;
|
| 406 |
+
# cursor: pointer;
|
| 407 |
+
# font-weight: 600;
|
| 408 |
+
# transition: transform 0.2s;
|
| 409 |
+
# }
|
| 410 |
+
# button:hover { transform: translateY(-2px); }
|
| 411 |
+
# .results { display: none; margin-top: 20px; }
|
| 412 |
+
# .results.show { display: block; animation: fadeIn 0.5s; }
|
| 413 |
+
# @keyframes fadeIn {
|
| 414 |
+
# from { opacity: 0; transform: translateY(10px); }
|
| 415 |
+
# to { opacity: 1; transform: translateY(0); }
|
| 416 |
+
# }
|
| 417 |
+
# .section {
|
| 418 |
+
# background: #f8f9fa;
|
| 419 |
+
# padding: 20px;
|
| 420 |
+
# border-radius: 12px;
|
| 421 |
+
# margin-bottom: 15px;
|
| 422 |
+
# border-left: 4px solid #667eea;
|
| 423 |
+
# }
|
| 424 |
+
# .section h3 { color: #667eea; margin-bottom: 12px; }
|
| 425 |
+
# .result-item {
|
| 426 |
+
# background: white;
|
| 427 |
+
# padding: 15px;
|
| 428 |
+
# border-radius: 8px;
|
| 429 |
+
# margin-bottom: 10px;
|
| 430 |
+
# border-left: 3px solid #667eea;
|
| 431 |
+
# }
|
| 432 |
+
# .tag {
|
| 433 |
+
# display: inline-block;
|
| 434 |
+
# background: #667eea;
|
| 435 |
+
# color: white;
|
| 436 |
+
# padding: 6px 12px;
|
| 437 |
+
# border-radius: 15px;
|
| 438 |
+
# margin: 3px;
|
| 439 |
+
# font-size: 0.9em;
|
| 440 |
+
# }
|
| 441 |
+
# .conf-excellent { background: #4caf50; }
|
| 442 |
+
# .conf-very { background: #8bc34a; }
|
| 443 |
+
# .conf-high { background: #cddc39; color: #333; }
|
| 444 |
+
# .conf-good { background: #ff9800; }
|
| 445 |
+
# .conf-medium { background: #ff5722; }
|
| 446 |
+
# .conf-low { background: #9e9e9e; }
|
| 447 |
+
# .loading { display: none; text-align: center; padding: 20px; }
|
| 448 |
+
# .loading.show { display: block; }
|
| 449 |
+
# .spinner {
|
| 450 |
+
# border: 4px solid #f3f3f3;
|
| 451 |
+
# border-top: 4px solid #667eea;
|
| 452 |
+
# border-radius: 50%;
|
| 453 |
+
# width: 40px;
|
| 454 |
+
# height: 40px;
|
| 455 |
+
# animation: spin 1s linear infinite;
|
| 456 |
+
# margin: 0 auto;
|
| 457 |
+
# }
|
| 458 |
+
# @keyframes spin {
|
| 459 |
+
# 0% { transform: rotate(0deg); }
|
| 460 |
+
# 100% { transform: rotate(360deg); }
|
| 461 |
+
# }
|
| 462 |
+
# </style>
|
| 463 |
+
# </head>
|
| 464 |
+
# <body>
|
| 465 |
+
# <div class="container">
|
| 466 |
+
# <div class="header">
|
| 467 |
+
# <h1>π― Product Category Classifier</h1>
|
| 468 |
+
# <div class="badge">Cross-Store Intelligence</div>
|
| 469 |
+
# <div class="badge">Auto-Tag Support</div>
|
| 470 |
+
# <div class="badge">Real-Time</div>
|
| 471 |
+
# </div>
|
| 472 |
+
|
| 473 |
+
# <div class="card">
|
| 474 |
+
# <div class="success-box">
|
| 475 |
+
# <strong>β
Cross-Store Synonyms Active!</strong><br>
|
| 476 |
+
# Understands: washing machine = laundry machine | tv = television | kids = children
|
| 477 |
+
# </div>
|
| 478 |
+
|
| 479 |
+
# <div class="form-group">
|
| 480 |
+
# <label>Product Title *</label>
|
| 481 |
+
# <input type="text" id="title" placeholder="e.g., Washing Machine or Laundry Machine" />
|
| 482 |
+
# </div>
|
| 483 |
+
|
| 484 |
+
# <div class="form-group">
|
| 485 |
+
# <label>Description (Optional)</label>
|
| 486 |
+
# <textarea id="desc" placeholder="Additional details..."></textarea>
|
| 487 |
+
# </div>
|
| 488 |
+
|
| 489 |
+
# <button onclick="classify()">π― Classify Product</button>
|
| 490 |
+
|
| 491 |
+
# <div class="loading" id="loading">
|
| 492 |
+
# <div class="spinner"></div>
|
| 493 |
+
# <p style="margin-top: 10px; color: #666;">Analyzing...</p>
|
| 494 |
+
# </div>
|
| 495 |
+
|
| 496 |
+
# <div class="results" id="results">
|
| 497 |
+
# <div class="section">
|
| 498 |
+
# <h3>β
Best Match</h3>
|
| 499 |
+
# <div class="result-item">
|
| 500 |
+
# <div style="margin-bottom: 10px;">
|
| 501 |
+
# <strong>Product:</strong> <span id="product"></span>
|
| 502 |
+
# </div>
|
| 503 |
+
# <div style="margin-bottom: 10px;">
|
| 504 |
+
# <strong>Category ID:</strong>
|
| 505 |
+
# <span id="catId" style="font-size: 1.2em; color: #28a745; font-weight: bold;"></span>
|
| 506 |
+
# </div>
|
| 507 |
+
# <div style="margin-bottom: 10px;">
|
| 508 |
+
# <strong>Final Product:</strong> <span id="finalProd" style="font-weight: 600;"></span>
|
| 509 |
+
# </div>
|
| 510 |
+
# <div style="margin-bottom: 10px;">
|
| 511 |
+
# <strong>Full Path:</strong><br>
|
| 512 |
+
# <span id="path" style="color: #666; font-size: 0.95em;"></span>
|
| 513 |
+
# </div>
|
| 514 |
+
# <div style="margin-bottom: 10px;">
|
| 515 |
+
# <strong>Confidence:</strong>
|
| 516 |
+
# <span id="confidence" class="tag"></span>
|
| 517 |
+
# </div>
|
| 518 |
+
# <div style="font-size: 0.9em; color: #666;">
|
| 519 |
+
# <strong>Depth:</strong> <span id="depth"></span> levels |
|
| 520 |
+
# <strong>Time:</strong> <span id="time"></span>ms
|
| 521 |
+
# </div>
|
| 522 |
+
# </div>
|
| 523 |
+
# </div>
|
| 524 |
+
|
| 525 |
+
# <div class="section">
|
| 526 |
+
# <h3>π Matched Terms (Cross-Store Variations)</h3>
|
| 527 |
+
# <div id="matchedTerms"></div>
|
| 528 |
+
# </div>
|
| 529 |
+
|
| 530 |
+
# <div class="section">
|
| 531 |
+
# <h3>π Top 5 Alternative Matches</h3>
|
| 532 |
+
# <div id="alternatives"></div>
|
| 533 |
+
# </div>
|
| 534 |
+
# </div>
|
| 535 |
+
# </div>
|
| 536 |
+
# </div>
|
| 537 |
+
|
| 538 |
+
# <script>
|
| 539 |
+
# async function classify() {
|
| 540 |
+
# const title = document.getElementById('title').value.trim();
|
| 541 |
+
# const desc = document.getElementById('desc').value.trim();
|
| 542 |
+
|
| 543 |
+
# if (!title) {
|
| 544 |
+
# alert('Please enter a product title');
|
| 545 |
+
# return;
|
| 546 |
+
# }
|
| 547 |
+
|
| 548 |
+
# document.getElementById('loading').classList.add('show');
|
| 549 |
+
# document.getElementById('results').classList.remove('show');
|
| 550 |
+
|
| 551 |
+
# try {
|
| 552 |
+
# const response = await fetch('/classify', {
|
| 553 |
+
# method: 'POST',
|
| 554 |
+
# headers: { 'Content-Type': 'application/json' },
|
| 555 |
+
# body: JSON.stringify({ title, description: desc })
|
| 556 |
+
# });
|
| 557 |
+
|
| 558 |
+
# if (!response.ok) throw new Error('Classification failed');
|
| 559 |
+
|
| 560 |
+
# const data = await response.json();
|
| 561 |
+
# displayResults(data);
|
| 562 |
+
# } catch (error) {
|
| 563 |
+
# alert('Error: ' + error.message);
|
| 564 |
+
# } finally {
|
| 565 |
+
# document.getElementById('loading').classList.remove('show');
|
| 566 |
+
# }
|
| 567 |
+
# }
|
| 568 |
+
|
| 569 |
+
# function displayResults(data) {
|
| 570 |
+
# document.getElementById('results').classList.add('show');
|
| 571 |
+
|
| 572 |
+
# document.getElementById('product').textContent = data.product;
|
| 573 |
+
# document.getElementById('catId').textContent = data.category_id;
|
| 574 |
+
# document.getElementById('finalProd').textContent = data.final_product;
|
| 575 |
+
# document.getElementById('path').textContent = data.category_path;
|
| 576 |
+
# document.getElementById('depth').textContent = data.depth;
|
| 577 |
+
# document.getElementById('time').textContent = data.processing_time_ms;
|
| 578 |
+
|
| 579 |
+
# const conf = document.getElementById('confidence');
|
| 580 |
+
# conf.textContent = data.confidence;
|
| 581 |
+
# const confClass = data.confidence.split(' ')[0].toLowerCase().replace('_', '-');
|
| 582 |
+
# conf.className = 'tag conf-' + confClass;
|
| 583 |
+
|
| 584 |
+
# const matchedHtml = data.matched_terms.map(t => `<span class="tag">${t}</span>`).join('');
|
| 585 |
+
# document.getElementById('matchedTerms').innerHTML = matchedHtml;
|
| 586 |
+
|
| 587 |
+
# let altHtml = '';
|
| 588 |
+
# data.top_5_results.forEach((item, i) => {
|
| 589 |
+
# const cls = i === 0 ? 'style="background: #e8f5e9;"' : '';
|
| 590 |
+
# altHtml += `
|
| 591 |
+
# <div class="result-item" ${cls}>
|
| 592 |
+
# <strong>${item.rank}.</strong> ${item.final_product}
|
| 593 |
+
# <span class="tag" style="background: #999;">${item.confidence}%</span>
|
| 594 |
+
# <div style="font-size: 0.85em; color: #666; margin-top: 5px;">
|
| 595 |
+
# ID: ${item.category_id}
|
| 596 |
+
# </div>
|
| 597 |
+
# </div>
|
| 598 |
+
# `;
|
| 599 |
+
# });
|
| 600 |
+
# document.getElementById('alternatives').innerHTML = altHtml;
|
| 601 |
+
# }
|
| 602 |
+
|
| 603 |
+
# document.getElementById('title').addEventListener('keypress', function(e) {
|
| 604 |
+
# if (e.key === 'Enter') classify();
|
| 605 |
+
# });
|
| 606 |
+
# </script>
|
| 607 |
+
# </body>
|
| 608 |
+
# </html>
|
| 609 |
+
# """
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
# # ============================================================================
|
| 613 |
+
# # FLASK ROUTES
|
| 614 |
+
# # ============================================================================
|
| 615 |
+
|
| 616 |
+
# @app.route('/')
|
| 617 |
+
# def index():
|
| 618 |
+
# """Serve the web interface"""
|
| 619 |
+
# return render_template_string(HTML_TEMPLATE)
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
# @app.route('/classify', methods=['POST'])
|
| 623 |
+
# def classify_route():
|
| 624 |
+
# """API endpoint for classification"""
|
| 625 |
+
# data = request.json
|
| 626 |
+
# title = data.get('title', '').strip()
|
| 627 |
+
# description = data.get('description', '').strip()
|
| 628 |
+
|
| 629 |
+
# if not title:
|
| 630 |
+
# return jsonify({'error': 'Title required'}), 400
|
| 631 |
+
|
| 632 |
+
# try:
|
| 633 |
+
# result = classify_product(title, description)
|
| 634 |
+
# return jsonify(result)
|
| 635 |
+
# except Exception as e:
|
| 636 |
+
# print(f"Error: {e}")
|
| 637 |
+
# return jsonify({'error': str(e)}), 500
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
# @app.route('/health')
|
| 641 |
+
# def health():
|
| 642 |
+
# """Health check endpoint"""
|
| 643 |
+
# return jsonify({
|
| 644 |
+
# 'status': 'healthy',
|
| 645 |
+
# 'categories': len(metadata),
|
| 646 |
+
# 'cross_store_synonyms': len(cross_store_synonyms),
|
| 647 |
+
# 'model': 'all-mpnet-base-v2'
|
| 648 |
+
# })
|
| 649 |
+
|
| 650 |
+
|
| 651 |
+
# # ============================================================================
|
| 652 |
+
# # MAIN
|
| 653 |
+
# # ============================================================================
|
| 654 |
+
|
| 655 |
+
# if __name__ == '__main__':
|
| 656 |
+
# try:
|
| 657 |
+
# load_server()
|
| 658 |
+
|
| 659 |
+
# print("\nπ Server starting...")
|
| 660 |
+
# print(" URL: http://localhost:5000")
|
| 661 |
+
# print(" Press CTRL+C to stop\n")
|
| 662 |
+
|
| 663 |
+
# app.run(host='0.0.0.0', port=5000, debug=False)
|
| 664 |
+
|
| 665 |
+
# except FileNotFoundError as e:
|
| 666 |
+
# print(f"\nβ ERROR: {e}")
|
| 667 |
+
# print("\nπ‘ Solution: Run training first:")
|
| 668 |
+
# print(" python train.py data/category_id_path_only.csv\n")
|
| 669 |
+
# except Exception as e:
|
| 670 |
+
# print(f"\nβ UNEXPECTED ERROR: {e}\n")
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
#!/usr/bin/env python3
|
| 677 |
+
"""
|
| 678 |
+
API Server for product category classification
|
| 679 |
+
Merged UI + classification logic
|
| 680 |
+
Model: intfloat/e5-base-v2 (must match training)
|
| 681 |
+
|
| 682 |
+
Usage:
|
| 683 |
+
python api_server.py
|
| 684 |
+
|
| 685 |
+
Requirements:
|
| 686 |
+
pip install flask sentence-transformers faiss-cpu numpy pickle5
|
| 687 |
+
|
| 688 |
+
Files expected in cache/:
|
| 689 |
+
- main_index.faiss
|
| 690 |
+
- metadata.pkl
|
| 691 |
+
- cross_store_synonyms.pkl (optional)
|
| 692 |
+
|
| 693 |
+
"""
|
| 694 |
+
|
| 695 |
+
from flask import Flask, request, jsonify, render_template_string
|
| 696 |
+
from sentence_transformers import SentenceTransformer
|
| 697 |
+
import faiss
|
| 698 |
+
import pickle
|
| 699 |
+
import numpy as np
|
| 700 |
+
from pathlib import Path
|
| 701 |
+
import time
|
| 702 |
+
import re
|
| 703 |
+
import os
|
| 704 |
+
from typing import List
|
| 705 |
+
|
| 706 |
+
# ============================================================================
|
| 707 |
+
# CONFIG
|
| 708 |
+
# ============================================================================
|
| 709 |
+
|
| 710 |
+
CACHE_DIR = Path('cache')
|
| 711 |
+
MODEL_NAME = 'intfloat/e5-base-v2' # <-- MUST match the model used during training
|
| 712 |
+
FAISS_INDEX_PATH = CACHE_DIR / 'main_index.faiss'
|
| 713 |
+
METADATA_PATH = CACHE_DIR / 'metadata.pkl'
|
| 714 |
+
SYN_PATH = CACHE_DIR / 'cross_store_synonyms.pkl'
|
| 715 |
+
|
| 716 |
+
# Server globals
|
| 717 |
+
encoder = None
|
| 718 |
+
faiss_index = None
|
| 719 |
+
metadata = []
|
| 720 |
+
cross_store_synonyms = {}
|
| 721 |
+
|
| 722 |
+
# ============================================================================
|
| 723 |
+
# CROSS-STORE SYNONYM FALLBACK
|
| 724 |
+
# ============================================================================
|
| 725 |
+
|
| 726 |
+
def build_cross_store_synonyms():
    """Construct the fallback bidirectional synonym table.

    Only used when no trained ``cross_store_synonyms.pkl`` exists in the cache;
    ``load_server()`` prefers the trained file.  Returns a dict mapping every
    term (seed keys and their variants alike) to the set of its known
    variations.
    """
    seed_map = {
        'washing machine': {'laundry machine', 'washer', 'clothes washer', 'washing appliance'},
        'laundry machine': {'washing machine', 'washer', 'clothes washer'},
        'dryer': {'drying machine', 'clothes dryer', 'tumble dryer'},
        'refrigerator': {'fridge', 'cooler', 'ice box', 'cooling appliance'},
        'dishwasher': {'dish washer', 'dish cleaning machine'},
        'microwave': {'microwave oven', 'micro wave'},
        'vacuum': {'vacuum cleaner', 'hoover', 'vac'},
        'tv': {'television', 'telly', 'smart tv', 'display'},
        'laptop': {'notebook', 'portable computer', 'laptop computer'},
        'mobile': {'phone', 'cell phone', 'smartphone', 'cellphone'},
        'tablet': {'ipad', 'tab', 'tablet computer'},
        'headphones': {'headset', 'earphones', 'earbuds', 'ear buds'},
        'speaker': {'audio speaker', 'sound system', 'speakers'},
        'sofa': {'couch', 'settee', 'divan'},
        'wardrobe': {'closet', 'armoire', 'cupboard'},
        'drawer': {'chest of drawers', 'dresser'},
        'pants': {'trousers', 'slacks', 'bottoms'},
        'sweater': {'jumper', 'pullover', 'sweatshirt'},
        'sneakers': {'trainers', 'tennis shoes', 'running shoes'},
        'jacket': {'coat', 'blazer', 'outerwear'},
        'cooker': {'stove', 'range', 'cooking range'},
        'blender': {'mixer', 'food processor', 'liquidizer'},
        'kettle': {'electric kettle', 'water boiler'},
        'stroller': {'pram', 'pushchair', 'buggy', 'baby carriage'},
        'diaper': {'nappy', 'nappies'},
        'pacifier': {'dummy', 'soother'},
        'wrench': {'spanner', 'adjustable wrench'},
        'flashlight': {'torch', 'flash light'},
        'screwdriver': {'screw driver'},
        'tap': {'faucet', 'water tap'},
        'bin': {'trash can', 'garbage can', 'waste bin'},
        'curtain': {'drape', 'window covering'},
        'guillotine': {'paper cutter', 'paper trimmer', 'blade cutter'},
        'trimmer': {'cutter', 'cutting tool', 'edge cutter'},
        'stapler': {'stapling machine', 'staple gun'},
        'magazine': {'periodical', 'journal', 'publication'},
        'comic': {'comic book', 'graphic novel', 'manga'},
        'ebook': {'e-book', 'digital book', 'electronic book'},
        'kids': {'children', 'child', 'childrens', 'youth', 'junior'},
        'women': {'womens', 'ladies', 'female', 'lady'},
        'men': {'mens', 'male', 'gentleman'},
        'baby': {'infant', 'newborn', 'toddler'},
    }

    table = {}
    for base, variants in seed_map.items():
        # A seed key's entry is (re)assigned wholesale — this mirrors the
        # original behaviour where a later seed key overwrites any reverse
        # links accumulated for it while processing earlier keys.
        table[base] = set(variants)
        for variant in variants:
            bucket = table.setdefault(variant, set())
            bucket.add(base)
            # Link each variant to its siblings as well (minus itself).
            bucket.update(variants - {variant})
    return table
|
| 786 |
+
|
| 787 |
+
# ============================================================================
|
| 788 |
+
# TEXT CLEANING / QUERY BUILDING
|
| 789 |
+
# ============================================================================
|
| 790 |
+
|
| 791 |
+
def clean_text(text: str) -> str:
    """Normalize *text* for matching: lowercase, strip punctuation (keeping
    word characters, whitespace and dashes), and collapse runs of whitespace.

    Falsy input (None, "", 0) yields the empty string.
    """
    if not text:
        return ""
    lowered = str(text).lower()
    # Replace anything that is not a word char, whitespace or dash with space.
    no_punct = re.sub(r"[^\w\s-]", " ", lowered)
    # Collapse internal whitespace and trim the ends.
    return re.sub(r"\s+", " ", no_punct).strip()
|
| 799 |
+
|
| 800 |
+
|
| 801 |
+
def extract_cross_store_terms(text: str) -> List[str]:
    """Expand *text* into a list of candidate search terms.

    Collects: the full cleaned text, every word longer than 2 chars, every
    2-word and 3-word phrase of such words, plus any cross-store synonyms
    registered for the words/bigrams in the module-level
    ``cross_store_synonyms`` table.  Order of the returned list is arbitrary
    (set-backed).
    """
    cleaned = clean_text(text)
    tokens = cleaned.split()

    terms = {cleaned}  # always include the whole cleaned string

    # Individual words (plus their synonyms).
    for tok in tokens:
        if len(tok) > 2:
            terms.add(tok)
            if tok in cross_store_synonyms:
                terms.update(cross_store_synonyms[tok])

    # Adjacent 2-word phrases (plus their synonyms).
    for first, second in zip(tokens, tokens[1:]):
        if len(first) > 2 and len(second) > 2:
            bigram = f"{first} {second}"
            terms.add(bigram)
            if bigram in cross_store_synonyms:
                terms.update(cross_store_synonyms[bigram])

    # Adjacent 3-word phrases (no synonym lookup for trigrams).
    if len(tokens) >= 3:
        for window in zip(tokens, tokens[1:], tokens[2:]):
            if all(len(w) > 2 for w in window):
                terms.add(" ".join(window))

    return list(terms)
|
| 831 |
+
|
| 832 |
+
def build_enhanced_query(title, description="", max_synonyms=10):
    """Build the embedding query text for a product.

    The cleaned title is repeated three times to dominate the embedding, then
    up to *max_synonyms* cross-store variations are appended.  Returns a tuple
    ``(query_text, matched_terms)`` where ``matched_terms`` is capped at 20
    entries for UI display.
    """
    clean_title = clean_text(title)
    clean_desc = clean_text(description)

    # Cross-store variations derived from title + description together.
    variations = extract_cross_store_terms(f"{clean_title} {clean_desc}")

    # Weight the original title 3x, then mix in the top variations.
    parts = [clean_title, clean_title, clean_title]
    parts.extend(variations[:max_synonyms])
    query = " ".join(parts)

    return query, variations[:20]
|
| 844 |
+
|
| 845 |
+
# ============================================================================
|
| 846 |
+
# ENCODER / FAISS
|
| 847 |
+
# ============================================================================
|
| 848 |
+
|
| 849 |
+
def encode_query(text: str) -> np.ndarray:
    """Embed *text* with the global ``encoder``.

    Returns a normalized float32 matrix shaped ``(1, dim)`` (FAISS expects a
    2-D batch even for a single query).
    """
    vec = encoder.encode(text, convert_to_numpy=True, normalize_embeddings=True)
    # Promote a 1-D vector to a single-row matrix; 2-D passes through.
    return np.atleast_2d(vec).astype('float32')
|
| 854 |
+
|
| 855 |
+
def classify_product(title, description="", top_k=5):
|
| 856 |
+
"""Classify product using e5-base embeddings with cross-store optimization"""
|
| 857 |
+
start_time = time.time()
|
| 858 |
+
|
| 859 |
+
# Step 1: Build enhanced query
|
| 860 |
+
query_text, matched_terms = build_enhanced_query(title, description)
|
| 861 |
+
|
| 862 |
+
# Step 2: Encode query
|
| 863 |
+
query_embedding = encoder.encode(
|
| 864 |
+
query_text,
|
| 865 |
+
convert_to_numpy=True,
|
| 866 |
+
normalize_embeddings=True
|
| 867 |
+
).astype('float32')
|
| 868 |
+
|
| 869 |
+
if query_embedding.ndim == 1:
|
| 870 |
+
query_embedding = query_embedding.reshape(1, -1)
|
| 871 |
+
|
| 872 |
+
# Step 3: FAISS search
|
| 873 |
+
distances, indices = faiss_index.search(query_embedding, top_k)
|
| 874 |
+
|
| 875 |
+
results = []
|
| 876 |
+
for i, idx in enumerate(indices[0]):
|
| 877 |
+
if idx >= len(metadata):
|
| 878 |
+
continue
|
| 879 |
+
meta = metadata[idx]
|
| 880 |
+
# Convert FAISS distance to cosine similarity
|
| 881 |
+
similarity = 1 - distances[0][i]
|
| 882 |
+
confidence_pct = float(similarity) * 100
|
| 883 |
+
|
| 884 |
+
final_product = meta.get('levels', [])[-1] if meta.get('levels') else meta['category_path'].split('/')[-1]
|
| 885 |
+
|
| 886 |
+
results.append({
|
| 887 |
+
'rank': i + 1,
|
| 888 |
+
'category_id': meta['category_id'],
|
| 889 |
+
'category_path': meta['category_path'],
|
| 890 |
+
'final_product': final_product,
|
| 891 |
+
'confidence': round(confidence_pct, 2),
|
| 892 |
+
'depth': meta.get('depth', 0)
|
| 893 |
+
})
|
| 894 |
+
|
| 895 |
+
if not results:
|
| 896 |
+
return {'error': 'No results found', 'product': title}
|
| 897 |
+
|
| 898 |
+
# Pick best match
|
| 899 |
+
best = results[0]
|
| 900 |
+
conf_pct = best['confidence']
|
| 901 |
+
if conf_pct >= 90:
|
| 902 |
+
conf_level = "EXCELLENT"
|
| 903 |
+
elif conf_pct >= 85:
|
| 904 |
+
conf_level = "VERY HIGH"
|
| 905 |
+
elif conf_pct >= 80:
|
| 906 |
+
conf_level = "HIGH"
|
| 907 |
+
elif conf_pct >= 75:
|
| 908 |
+
conf_level = "GOOD"
|
| 909 |
+
elif conf_pct >= 70:
|
| 910 |
+
conf_level = "MEDIUM"
|
| 911 |
+
else:
|
| 912 |
+
conf_level = "LOW"
|
| 913 |
+
|
| 914 |
+
processing_time = (time.time() - start_time) * 1000
|
| 915 |
+
|
| 916 |
+
return {
|
| 917 |
+
'product': title,
|
| 918 |
+
'category_id': best['category_id'],
|
| 919 |
+
'category_path': best['category_path'],
|
| 920 |
+
'final_product': best['final_product'],
|
| 921 |
+
'confidence': f"{conf_level} ({conf_pct:.2f}%)",
|
| 922 |
+
'confidence_percent': conf_pct,
|
| 923 |
+
'depth': best['depth'],
|
| 924 |
+
'matched_terms': matched_terms,
|
| 925 |
+
'top_5_results': results,
|
| 926 |
+
'processing_time_ms': round(processing_time, 2)
|
| 927 |
+
}
|
| 928 |
+
# FAISS returns squared L2 distances or inner product depending on index type.
|
| 929 |
+
# We'll treat lower distance as better. We convert to a 0-100-ish confidence by
|
| 930 |
+
# using a simple heuristic: score = 100 - normalized_distance*100 (clamped).
|
| 931 |
+
|
| 932 |
+
# Determine a normalization constant: use mean of top distance if available
|
| 933 |
+
flat_dist = distances[0]
|
| 934 |
+
max_d = float(np.max(flat_dist)) if flat_dist.size else 1.0
|
| 935 |
+
min_d = float(np.min(flat_dist)) if flat_dist.size else 0.0
|
| 936 |
+
range_d = max(1e-6, max_d - min_d)
|
| 937 |
+
|
| 938 |
+
for i, idx in enumerate(indices[0]):
|
| 939 |
+
if idx < 0 or idx >= len(metadata):
|
| 940 |
+
continue
|
| 941 |
+
meta = metadata[idx]
|
| 942 |
+
raw_d = float(distances[0][i])
|
| 943 |
+
# normalize and invert to make higher -> better
|
| 944 |
+
norm = (raw_d - min_d) / range_d
|
| 945 |
+
conf = max(0.0, min(100.0, 100.0 * (1.0 - norm)))
|
| 946 |
+
|
| 947 |
+
levels = meta.get('levels') or []
|
| 948 |
+
final_product = levels[-1] if levels else meta.get('category_path', '').split('/')[-1]
|
| 949 |
+
|
| 950 |
+
results.append({
|
| 951 |
+
'rank': i + 1,
|
| 952 |
+
'category_id': meta.get('category_id'),
|
| 953 |
+
'category_path': meta.get('category_path'),
|
| 954 |
+
'final_product': final_product,
|
| 955 |
+
'confidence': round(conf, 2),
|
| 956 |
+
'depth': meta.get('depth', 0)
|
| 957 |
+
})
|
| 958 |
+
|
| 959 |
+
if not results:
|
| 960 |
+
return {
|
| 961 |
+
'error': 'No results found',
|
| 962 |
+
'product': title
|
| 963 |
+
}
|
| 964 |
+
|
| 965 |
+
best = results[0]
|
| 966 |
+
conf_pct = best['confidence']
|
| 967 |
+
if conf_pct >= 90:
|
| 968 |
+
conf_level = "EXCELLENT"
|
| 969 |
+
elif conf_pct >= 85:
|
| 970 |
+
conf_level = "VERY HIGH"
|
| 971 |
+
elif conf_pct >= 80:
|
| 972 |
+
conf_level = "HIGH"
|
| 973 |
+
elif conf_pct >= 75:
|
| 974 |
+
conf_level = "GOOD"
|
| 975 |
+
elif conf_pct >= 70:
|
| 976 |
+
conf_level = "MEDIUM"
|
| 977 |
+
else:
|
| 978 |
+
conf_level = "LOW"
|
| 979 |
+
|
| 980 |
+
processing_time = (time.time() - start_time) * 1000.0
|
| 981 |
+
|
| 982 |
+
return {
|
| 983 |
+
'product': title,
|
| 984 |
+
'category_id': best['category_id'],
|
| 985 |
+
'category_path': best['category_path'],
|
| 986 |
+
'final_product': best['final_product'],
|
| 987 |
+
'confidence': f"{conf_level} ({conf_pct:.2f}%)",
|
| 988 |
+
'confidence_percent': conf_pct,
|
| 989 |
+
'depth': best['depth'],
|
| 990 |
+
'matched_terms': matched_terms,
|
| 991 |
+
'top_5_results': results,
|
| 992 |
+
'processing_time_ms': round(processing_time, 2)
|
| 993 |
+
}
|
| 994 |
+
|
| 995 |
+
# ============================================================================
|
| 996 |
+
# SERVER LOAD
|
| 997 |
+
# ============================================================================
|
| 998 |
+
|
| 999 |
+
def load_server():
    """Populate the module globals from the trained artifacts in ``cache/``.

    Loads, in order: the sentence-transformer encoder (``MODEL_NAME`` must
    match the model used at training time), the FAISS index, the per-vector
    category metadata, and the cross-store synonym table (falling back to the
    built-in defaults when no trained pickle exists).

    Raises:
        FileNotFoundError: if the FAISS index or metadata pickle is missing.

    NOTE(review): the status-glyph strings below ('π...', 'β...') are mojibake
    from extraction — presumably emoji in the repo copy; confirm before editing.
    """
    global encoder, faiss_index, metadata, cross_store_synonyms

    print('\n' + '=' * 80)
    print('π LOADING TRAINED MODEL')
    print('=' * 80 + '\n')

    # Encoder: must be the same model the index embeddings were produced with.
    print('π₯ Loading sentence transformer...')
    encoder = SentenceTransformer(MODEL_NAME)
    print('β Model loaded\n')

    # FAISS index built by the training script.
    print('π₯ Loading FAISS index...')
    if not FAISS_INDEX_PATH.exists():
        raise FileNotFoundError(f"FAISS index not found: {FAISS_INDEX_PATH}\nPlease run training first!")
    faiss_index = faiss.read_index(str(FAISS_INDEX_PATH))
    print(f"β Index loaded ({faiss_index.ntotal:,} vectors)\n")

    # Category metadata: one entry per index vector (looked up by row id).
    print('π₯ Loading metadata...')
    if not METADATA_PATH.exists():
        raise FileNotFoundError(f"Metadata not found: {METADATA_PATH}\nPlease run training first!")
    with open(METADATA_PATH, 'rb') as f:
        metadata = pickle.load(f)
    print(f"β Metadata loaded ({len(metadata):,} categories)\n")

    # Synonyms: trained table takes precedence over the hard-coded fallback.
    print('π₯ Loading cross-store synonyms...')
    if SYN_PATH.exists():
        with open(SYN_PATH, 'rb') as f:
            cross_store_synonyms = pickle.load(f)
        print(f"β Cross-store synonyms loaded ({len(cross_store_synonyms)} terms)\n")
    else:
        print('β οΈ Cross-store synonyms not found, building default set...')
        cross_store_synonyms = build_cross_store_synonyms()
        print(f"β Built {len(cross_store_synonyms)} synonym mappings\n")

    print('=' * 80)
    print('β SERVER READY!')
    print('=' * 80 + '\n')
|
| 1040 |
+
|
| 1041 |
+
# ============================================================================
|
| 1042 |
+
# HTML TEMPLATE (same as provided)
|
| 1043 |
+
# ============================================================================
|
| 1044 |
+
|
| 1045 |
+
HTML_TEMPLATE = r"""
|
| 1046 |
+
<!DOCTYPE html>
|
| 1047 |
+
<html>
|
| 1048 |
+
<head>
|
| 1049 |
+
<title>π― Product Category Classifier</title>
|
| 1050 |
+
<meta charset="UTF-8">
|
| 1051 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 1052 |
+
<style>
|
| 1053 |
+
* { margin: 0; padding: 0; box-sizing: border-box; }
|
| 1054 |
+
body {
|
| 1055 |
+
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
| 1056 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 1057 |
+
min-height: 100vh;
|
| 1058 |
+
padding: 20px;
|
| 1059 |
+
}
|
| 1060 |
+
.container { max-width: 1200px; margin: 0 auto; }
|
| 1061 |
+
.header {
|
| 1062 |
+
text-align: center;
|
| 1063 |
+
color: white;
|
| 1064 |
+
margin-bottom: 30px;
|
| 1065 |
+
}
|
| 1066 |
+
.header h1 { font-size: 2.5em; margin-bottom: 10px; }
|
| 1067 |
+
.badge {
|
| 1068 |
+
background: rgba(255,255,255,0.2);
|
| 1069 |
+
padding: 8px 20px;
|
| 1070 |
+
border-radius: 20px;
|
| 1071 |
+
display: inline-block;
|
| 1072 |
+
margin: 5px;
|
| 1073 |
+
font-size: 0.9em;
|
| 1074 |
+
}
|
| 1075 |
+
.card {
|
| 1076 |
+
background: white;
|
| 1077 |
+
border-radius: 20px;
|
| 1078 |
+
padding: 30px;
|
| 1079 |
+
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
|
| 1080 |
+
}
|
| 1081 |
+
.success-box {
|
| 1082 |
+
background: #d4edda;
|
| 1083 |
+
padding: 15px;
|
| 1084 |
+
border-radius: 8px;
|
| 1085 |
+
margin-bottom: 20px;
|
| 1086 |
+
border-left: 4px solid #28a745;
|
| 1087 |
+
color: #155724;
|
| 1088 |
+
}
|
| 1089 |
+
.form-group { margin-bottom: 20px; }
|
| 1090 |
+
label {
|
| 1091 |
+
display: block;
|
| 1092 |
+
font-weight: 600;
|
| 1093 |
+
margin-bottom: 8px;
|
| 1094 |
+
color: #333;
|
| 1095 |
+
}
|
| 1096 |
+
input, textarea {
|
| 1097 |
+
width: 100%;
|
| 1098 |
+
padding: 12px;
|
| 1099 |
+
border: 2px solid #e0e0e0;
|
| 1100 |
+
border-radius: 8px;
|
| 1101 |
+
font-size: 1em;
|
| 1102 |
+
}
|
| 1103 |
+
input:focus, textarea:focus {
|
| 1104 |
+
outline: none;
|
| 1105 |
+
border-color: #667eea;
|
| 1106 |
+
}
|
| 1107 |
+
textarea { min-height: 80px; resize: vertical; }
|
| 1108 |
+
button {
|
| 1109 |
+
width: 100%;
|
| 1110 |
+
padding: 15px;
|
| 1111 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 1112 |
+
color: white;
|
| 1113 |
+
border: none;
|
| 1114 |
+
border-radius: 10px;
|
| 1115 |
+
font-size: 1.1em;
|
| 1116 |
+
cursor: pointer;
|
| 1117 |
+
font-weight: 600;
|
| 1118 |
+
transition: transform 0.2s;
|
| 1119 |
+
}
|
| 1120 |
+
button:hover { transform: translateY(-2px); }
|
| 1121 |
+
.results { display: none; margin-top: 20px; }
|
| 1122 |
+
.results.show { display: block; animation: fadeIn 0.5s; }
|
| 1123 |
+
@keyframes fadeIn {
|
| 1124 |
+
from { opacity: 0; transform: translateY(10px); }
|
| 1125 |
+
to { opacity: 1; transform: translateY(0); }
|
| 1126 |
+
}
|
| 1127 |
+
.section {
|
| 1128 |
+
background: #f8f9fa;
|
| 1129 |
+
padding: 20px;
|
| 1130 |
+
border-radius: 12px;
|
| 1131 |
+
margin-bottom: 15px;
|
| 1132 |
+
border-left: 4px solid #667eea;
|
| 1133 |
+
}
|
| 1134 |
+
.section h3 { color: #667eea; margin-bottom: 12px; }
|
| 1135 |
+
.result-item {
|
| 1136 |
+
background: white;
|
| 1137 |
+
padding: 15px;
|
| 1138 |
+
border-radius: 8px;
|
| 1139 |
+
margin-bottom: 10px;
|
| 1140 |
+
border-left: 3px solid #667eea;
|
| 1141 |
+
}
|
| 1142 |
+
.tag {
|
| 1143 |
+
display: inline-block;
|
| 1144 |
+
background: #667eea;
|
| 1145 |
+
color: white;
|
| 1146 |
+
padding: 6px 12px;
|
| 1147 |
+
border-radius: 15px;
|
| 1148 |
+
margin: 3px;
|
| 1149 |
+
font-size: 0.9em;
|
| 1150 |
+
}
|
| 1151 |
+
.conf-excellent { background: #4caf50; }
|
| 1152 |
+
.conf-very { background: #8bc34a; }
|
| 1153 |
+
.conf-high { background: #cddc39; color: #333; }
|
| 1154 |
+
.conf-good { background: #ff9800; }
|
| 1155 |
+
.conf-medium { background: #ff5722; }
|
| 1156 |
+
.conf-low { background: #9e9e9e; }
|
| 1157 |
+
.loading { display: none; text-align: center; padding: 20px; }
|
| 1158 |
+
.loading.show { display: block; }
|
| 1159 |
+
.spinner {
|
| 1160 |
+
border: 4px solid #f3f3f3;
|
| 1161 |
+
border-top: 4px solid #667eea;
|
| 1162 |
+
border-radius: 50%;
|
| 1163 |
+
width: 40px;
|
| 1164 |
+
height: 40px;
|
| 1165 |
+
animation: spin 1s linear infinite;
|
| 1166 |
+
margin: 0 auto;
|
| 1167 |
+
}
|
| 1168 |
+
@keyframes spin {
|
| 1169 |
+
0% { transform: rotate(0deg); }
|
| 1170 |
+
100% { transform: rotate(360deg); }
|
| 1171 |
+
}
|
| 1172 |
+
</style>
|
| 1173 |
+
</head>
|
| 1174 |
+
<body>
|
| 1175 |
+
<div class="container">
|
| 1176 |
+
<div class="header">
|
| 1177 |
+
<h1>π― Product Category Classifier</h1>
|
| 1178 |
+
<div class="badge">Cross-Store Intelligence</div>
|
| 1179 |
+
<div class="badge">Auto-Tag Support</div>
|
| 1180 |
+
<div class="badge">Real-Time</div>
|
| 1181 |
+
</div>
|
| 1182 |
+
|
| 1183 |
+
<div class="card">
|
| 1184 |
+
<div class="success-box">
|
| 1185 |
+
<strong>β
Cross-Store Synonyms Active!</strong><br>
|
| 1186 |
+
Understands: washing machine = laundry machine | tv = television | kids = children
|
| 1187 |
+
</div>
|
| 1188 |
+
|
| 1189 |
+
<div class="form-group">
|
| 1190 |
+
<label>Product Title *</label>
|
| 1191 |
+
<input type="text" id="title" placeholder="e.g., Washing Machine or Laundry Machine" />
|
| 1192 |
+
</div>
|
| 1193 |
+
|
| 1194 |
+
<div class="form-group">
|
| 1195 |
+
<label>Description (Optional)</label>
|
| 1196 |
+
<textarea id="desc" placeholder="Additional details..."></textarea>
|
| 1197 |
+
</div>
|
| 1198 |
+
|
| 1199 |
+
<button onclick="classify()">π― Classify Product</button>
|
| 1200 |
+
|
| 1201 |
+
<div class="loading" id="loading">
|
| 1202 |
+
<div class="spinner"></div>
|
| 1203 |
+
<p style="margin-top: 10px; color: #666;">Analyzing...</p>
|
| 1204 |
+
</div>
|
| 1205 |
+
|
| 1206 |
+
<div class="results" id="results">
|
| 1207 |
+
<div class="section">
|
| 1208 |
+
<h3>β
Best Match</h3>
|
| 1209 |
+
<div class="result-item">
|
| 1210 |
+
<div style="margin-bottom: 10px;">
|
| 1211 |
+
<strong>Product:</strong> <span id="product"></span>
|
| 1212 |
+
</div>
|
| 1213 |
+
<div style="margin-bottom: 10px;">
|
| 1214 |
+
<strong>Category ID:</strong>
|
| 1215 |
+
<span id="catId" style="font-size: 1.2em; color: #28a745; font-weight: bold;"></span>
|
| 1216 |
+
</div>
|
| 1217 |
+
<div style="margin-bottom: 10px;">
|
| 1218 |
+
<strong>Final Product:</strong> <span id="finalProd" style="font-weight: 600;"></span>
|
| 1219 |
+
</div>
|
| 1220 |
+
<div style="margin-bottom: 10px;">
|
| 1221 |
+
<strong>Full Path:</strong><br>
|
| 1222 |
+
<span id="path" style="color: #666; font-size: 0.95em;"></span>
|
| 1223 |
+
</div>
|
| 1224 |
+
<div style="margin-bottom: 10px;">
|
| 1225 |
+
<strong>Confidence:</strong>
|
| 1226 |
+
<span id="confidence" class="tag"></span>
|
| 1227 |
+
</div>
|
| 1228 |
+
<div style="font-size: 0.9em; color: #666;">
|
| 1229 |
+
<strong>Depth:</strong> <span id="depth"></span> levels |
|
| 1230 |
+
<strong>Time:</strong> <span id="time"></span>ms
|
| 1231 |
+
</div>
|
| 1232 |
+
</div>
|
| 1233 |
+
</div>
|
| 1234 |
+
|
| 1235 |
+
<div class="section">
|
| 1236 |
+
<h3>π Matched Terms (Cross-Store Variations)</h3>
|
| 1237 |
+
<div id="matchedTerms"></div>
|
| 1238 |
+
</div>
|
| 1239 |
+
|
| 1240 |
+
<div class="section">
|
| 1241 |
+
<h3>π Top 5 Alternative Matches</h3>
|
| 1242 |
+
<div id="alternatives"></div>
|
| 1243 |
+
</div>
|
| 1244 |
+
</div>
|
| 1245 |
+
</div>
|
| 1246 |
+
</div>
|
| 1247 |
+
|
| 1248 |
+
<script>
|
| 1249 |
+
async function classify() {
|
| 1250 |
+
const title = document.getElementById('title').value.trim();
|
| 1251 |
+
const desc = document.getElementById('desc').value.trim();
|
| 1252 |
+
|
| 1253 |
+
if (!title) {
|
| 1254 |
+
alert('Please enter a product title');
|
| 1255 |
+
return;
|
| 1256 |
+
}
|
| 1257 |
+
|
| 1258 |
+
document.getElementById('loading').classList.add('show');
|
| 1259 |
+
document.getElementById('results').classList.remove('show');
|
| 1260 |
+
|
| 1261 |
+
try {
|
| 1262 |
+
const response = await fetch('/classify', {
|
| 1263 |
+
method: 'POST',
|
| 1264 |
+
headers: { 'Content-Type': 'application/json' },
|
| 1265 |
+
body: JSON.stringify({ title, description: desc })
|
| 1266 |
+
});
|
| 1267 |
+
|
| 1268 |
+
if (!response.ok) throw new Error('Classification failed');
|
| 1269 |
+
|
| 1270 |
+
const data = await response.json();
|
| 1271 |
+
displayResults(data);
|
| 1272 |
+
} catch (error) {
|
| 1273 |
+
alert('Error: ' + error.message);
|
| 1274 |
+
} finally {
|
| 1275 |
+
document.getElementById('loading').classList.remove('show');
|
| 1276 |
+
}
|
| 1277 |
+
}
|
| 1278 |
+
|
| 1279 |
+
function displayResults(data) {
|
| 1280 |
+
document.getElementById('results').classList.add('show');
|
| 1281 |
+
|
| 1282 |
+
document.getElementById('product').textContent = data.product;
|
| 1283 |
+
document.getElementById('catId').textContent = data.category_id;
|
| 1284 |
+
document.getElementById('finalProd').textContent = data.final_product;
|
| 1285 |
+
document.getElementById('path').textContent = data.category_path;
|
| 1286 |
+
document.getElementById('depth').textContent = data.depth;
|
| 1287 |
+
document.getElementById('time').textContent = data.processing_time_ms;
|
| 1288 |
+
|
| 1289 |
+
const conf = document.getElementById('confidence');
|
| 1290 |
+
conf.textContent = data.confidence;
|
| 1291 |
+
const confClass = data.confidence.split(' ')[0].toLowerCase().replace('_', '-');
|
| 1292 |
+
conf.className = 'tag conf-' + confClass;
|
| 1293 |
+
|
| 1294 |
+
const matchedHtml = data.matched_terms.map(t => `<span class="tag">${t}</span>`).join('');
|
| 1295 |
+
document.getElementById('matchedTerms').innerHTML = matchedHtml;
|
| 1296 |
+
|
| 1297 |
+
let altHtml = '';
|
| 1298 |
+
data.top_5_results.forEach((item, i) => {
|
| 1299 |
+
const cls = i === 0 ? 'style="background: #e8f5e9;"' : '';
|
| 1300 |
+
altHtml += `
|
| 1301 |
+
<div class="result-item" ${cls}>
|
| 1302 |
+
<strong>${item.rank}.</strong> ${item.final_product}
|
| 1303 |
+
<span class="tag" style="background: #999;">${item.confidence}%</span>
|
| 1304 |
+
<div style="font-size: 0.85em; color: #666; margin-top: 5px;">
|
| 1305 |
+
ID: ${item.category_id}
|
| 1306 |
+
</div>
|
| 1307 |
+
</div>
|
| 1308 |
+
`;
|
| 1309 |
+
});
|
| 1310 |
+
document.getElementById('alternatives').innerHTML = altHtml;
|
| 1311 |
+
}
|
| 1312 |
+
|
| 1313 |
+
document.getElementById('title').addEventListener('keypress', function(e) {
|
| 1314 |
+
if (e.key === 'Enter') classify();
|
| 1315 |
+
});
|
| 1316 |
+
</script>
|
| 1317 |
+
</body>
|
| 1318 |
+
</html>
|
| 1319 |
+
"""
|
| 1320 |
+
|
| 1321 |
+
# ============================================================================
|
| 1322 |
+
# FLASK APP
|
| 1323 |
+
# ============================================================================
|
| 1324 |
+
|
| 1325 |
+
app = Flask(__name__)


@app.route('/')
def index():
    """Serve the single-page web UI (template is inlined in HTML_TEMPLATE)."""
    return render_template_string(HTML_TEMPLATE)
|
| 1331 |
+
|
| 1332 |
+
|
| 1333 |
+
@app.route('/classify', methods=['POST'])
def classify_route():
    """POST /classify — JSON body ``{title, description?}``.

    Returns the classification result as JSON, 400 when the title is missing,
    or 500 (with the exception message) on an internal failure.
    """
    payload = request.get_json(force=True)
    title = payload.get('title', '').strip()
    description = payload.get('description', '').strip()

    # Guard: a non-empty title is mandatory.
    if not title:
        return jsonify({'error': 'Title required'}), 400

    try:
        return jsonify(classify_product(title, description))
    except Exception as exc:
        app.logger.exception('Classification error')
        return jsonify({'error': str(exc)}), 500
|
| 1348 |
+
|
| 1349 |
+
|
| 1350 |
+
@app.route('/health')
def health():
    """Liveness probe reporting basic stats about the loaded model."""
    payload = {
        'status': 'healthy',
        'categories': len(metadata),
        'cross_store_synonyms': len(cross_store_synonyms),
        'model': MODEL_NAME,
    }
    return jsonify(payload)
|
| 1358 |
+
|
| 1359 |
+
|
| 1360 |
+
# ============================================================================
|
| 1361 |
+
# MAIN
|
| 1362 |
+
# ============================================================================
|
| 1363 |
+
|
| 1364 |
+
# Script entry point: load the FAISS index / metadata, then serve the UI.
if __name__ == '__main__':
    try:
        load_server()
        print('\nπ Server starting...')
        print('  URL: http://localhost:5000')
        print('  Press CTRL+C to stop\n')
        # Recommended: run with a production server like gunicorn for production use
        app.run(host='0.0.0.0', port=5000, debug=False)
    except FileNotFoundError as e:
        # Raised by load_server() when cache artifacts are absent.
        print(f"\nβ ERROR: {e}")
        print('\nπ‘ Solution: Run training first to create FAISS index and metadata')
    except Exception as e:
        print(f"\nβ UNEXPECTED ERROR: {e}\n")
|
| 1377 |
+
|
check.py
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
π§ DIAGNOSTIC AND FIX TOOL
|
| 3 |
+
===========================
|
| 4 |
+
Analyzes your trained model and fixes common issues causing low confidence.
|
| 5 |
+
|
| 6 |
+
Issues it detects and fixes:
|
| 7 |
+
1. Column name mismatches (Category_ID vs category_id)
|
| 8 |
+
2. Missing or corrupted tags.json
|
| 9 |
+
3. Wrong metadata format in cache
|
| 10 |
+
4. FAISS index mismatch
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
python diagnose_and_fix.py
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import pickle
|
| 17 |
+
import json
|
| 18 |
+
import pandas as pd
|
| 19 |
+
import numpy as np
|
| 20 |
+
import faiss
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from sentence_transformers import SentenceTransformer
|
| 23 |
+
import sys
|
| 24 |
+
|
| 25 |
+
def check_cache_files():
    """Report which expected cache artifacts exist.

    Returns a list of issue strings, one per missing *required* file;
    missing optional files are only reported on stdout.
    """
    cache_dir = Path('cache')

    print("\n" + "=" * 80)
    print("π STEP 1: CHECKING CACHE FILES")
    print("=" * 80 + "\n")

    required = ['main_index.faiss', 'metadata.pkl', 'model_info.json']
    optional = ['parent_embeddings.pkl', 'calibrator.pkl', 'cross_store_synonyms.pkl']

    issues = []

    print("Required files:")
    for name in required:
        path = cache_dir / name
        if path.exists():
            size_mb = path.stat().st_size / (1024 * 1024)
            print(f"  β {name} ({size_mb:.2f} MB)")
        else:
            print(f"  β {name} - MISSING")
            issues.append(f"Missing required file: {name}")

    print("\nOptional files:")
    for name in optional:
        path = cache_dir / name
        if path.exists():
            size_mb = path.stat().st_size / (1024 * 1024)
            print(f"  β {name} ({size_mb:.2f} MB)")
        else:
            print(f"  β οΈ {name} - not found")

    return issues
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def check_csv_format():
    """Validate that data/category_only_path.csv exists with the expected columns.

    Accepts either the uppercase (Category_ID, Category_path) or lowercase
    (category_id, category_path) naming convention.

    Returns a list of issue strings (empty when the CSV looks fine).
    """
    print("\n" + "=" * 80)
    print("π STEP 2: CHECKING CSV FORMAT")
    print("=" * 80 + "\n")

    csv_path = Path('data/category_only_path.csv')

    if not csv_path.exists():
        print("β CSV not found at: data/category_only_path.csv")
        return ["CSV file not found"]

    try:
        # FIX: read the file once. The original parsed it twice — a 5-row
        # preview plus a second full read used only for the row count.
        df = pd.read_csv(csv_path)

        print(f"Columns found: {list(df.columns)}")
        print(f"Total rows: {len(df):,}")

        print("\nFirst 3 rows:")
        print(df.head(3).to_string())

        cols = set(df.columns)
        if {'Category_ID', 'Category_path'} <= cols:
            print("\nβ Column format: Uppercase (Category_ID, Category_path)")
            return []
        elif {'category_id', 'category_path'} <= cols:
            print("\nβ Column format: Lowercase (category_id, category_path)")
            return []
        else:
            print("\nβ Unexpected column names!")
            return ["CSV has wrong column names"]

    except Exception as e:
        print(f"\nβ Error reading CSV: {e}")
        return [f"CSV read error: {e}"]
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def check_metadata():
    """Inspect cache/metadata.pkl and spot-check its entries.

    Returns a list of problem descriptions; empty means the first 100
    entries all carry both mandatory keys.
    """
    print("\n" + "=" * 80)
    print("π STEP 3: CHECKING METADATA FORMAT")
    print("=" * 80 + "\n")

    meta_path = Path('cache/metadata.pkl')

    if not meta_path.exists():
        print("β Metadata file not found")
        return ["Metadata missing"]

    try:
        with open(meta_path, 'rb') as fh:
            metadata = pickle.load(fh)

        print(f"Metadata entries: {len(metadata):,}")

        if not metadata:
            print("β Metadata is empty!")
            return ["Empty metadata"]

        first = metadata[0]
        print(f"\nSample entry:")
        print(f"  Keys: {list(first.keys())}")
        print(f"  category_id: {first.get('category_id', 'MISSING')}")
        print(f"  category_path: {first.get('category_path', 'MISSING')[:50]}...")

        # Only the first 100 entries are sampled — full scan is unnecessary
        # for a format check and keeps this fast on large caches.
        problems = []
        for i, entry in enumerate(metadata[:100]):
            for key in ('category_id', 'category_path'):
                if key not in entry:
                    problems.append(f"Entry {i}: missing {key}")

        if problems:
            print(f"\nβ Found {len(problems)} entries with missing fields")
            return problems[:5]  # cap the report at the first five
        print("\nβ All entries have required fields")
        return []

    except Exception as e:
        print(f"β Error reading metadata: {e}")
        return [f"Metadata error: {e}"]
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def check_faiss_index():
    """Verify the FAISS index loads and its vector count matches the metadata.

    Returns a list of issue strings (empty when index and metadata agree).
    """
    print("\n" + "=" * 80)
    print("π STEP 4: CHECKING FAISS INDEX")
    print("=" * 80 + "\n")

    index_path = Path('cache/main_index.faiss')
    meta_path = Path('cache/metadata.pkl')

    if not index_path.exists():
        print("β FAISS index not found")
        return ["FAISS index missing"]

    try:
        index = faiss.read_index(str(index_path))
        print(f"FAISS index vectors: {index.ntotal:,}")
        print(f"Dimension: {index.d}")

        with open(meta_path, 'rb') as fh:
            metadata = pickle.load(fh)
        print(f"Metadata entries: {len(metadata):,}")

        # Each vector must map to exactly one metadata record, otherwise
        # search hits resolve to the wrong (or a missing) category.
        if index.ntotal == len(metadata):
            print("\nβ FAISS and metadata counts match")
            return []

        print(f"\nβ MISMATCH!")
        print(f"  FAISS has {index.ntotal:,} vectors")
        print(f"  Metadata has {len(metadata):,} entries")
        return ["FAISS-metadata count mismatch"]

    except Exception as e:
        print(f"β Error: {e}")
        return [f"FAISS error: {e}"]
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def check_tags_json():
    """Check that data/tags.json exists and carries enough tags per category.

    Returns a list of issue strings; empty when tags look healthy
    (average of at least 10 tags per category).
    """
    print("\n" + "=" * 80)
    print("π STEP 5: CHECKING TAGS.JSON")
    print("=" * 80 + "\n")

    tags_path = Path('data/tags.json')

    if not tags_path.exists():
        print("β οΈ tags.json not found - this will reduce accuracy!")
        print("  Expected location: data/tags.json")
        return ["tags.json missing"]

    try:
        with open(tags_path, 'r') as fh:
            tags = json.load(fh)

        print(f"Tags for {len(tags):,} categories")

        if not tags:
            print("β tags.json is empty!")
            return ["Empty tags.json"]

        sample_key = next(iter(tags))
        sample_tags = tags[sample_key]
        print(f"\nSample category: {sample_key}")
        print(f"Tags ({len(sample_tags)}): {', '.join(sample_tags[:5])}...")

        # Sparse tag lists degrade retrieval quality, so flag them.
        counts = [len(v) for v in tags.values() if isinstance(v, list)]
        avg_tags = sum(counts) / len(counts) if counts else 0
        print(f"\nAverage tags per category: {avg_tags:.1f}")

        if avg_tags < 10:
            print("β οΈ Very few tags - this will reduce accuracy")
            return ["Too few tags per category"]
        print("β Tags look good")
        return []

    except Exception as e:
        print(f"β Error: {e}")
        return [f"tags.json error: {e}"]
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def test_prediction():
    """Run one end-to-end sample query through the encoder and FAISS index.

    Returns a list of issue strings derived from the best similarity score
    (empty when confidence looks healthy, i.e. >= 0.5).
    """
    print("\n" + "=" * 80)
    print("π STEP 6: TESTING PREDICTION")
    print("=" * 80 + "\n")

    try:
        print("Loading model...")
        encoder = SentenceTransformer('intfloat/e5-base-v2')

        print("Loading FAISS index...")
        index = faiss.read_index('cache/main_index.faiss')

        print("Loading metadata...")
        with open('cache/metadata.pkl', 'rb') as f:
            metadata = pickle.load(f)

        # e5 models expect a "query: " prefix on search queries.
        test_query = "query: built in dishwasher"
        print(f"\nTest query: \"{test_query}\"")
        print("Encoding...")

        query_emb = encoder.encode(test_query, convert_to_numpy=True, normalize_embeddings=True)
        if query_emb.ndim == 1:
            query_emb = query_emb.reshape(1, -1)

        print("Searching...")
        # FIX: never ask for more neighbours than the index holds; FAISS pads
        # missing results with index -1, which silently became metadata[-1].
        k = min(5, index.ntotal)
        distances, indices = index.search(query_emb.astype('float32'), k)

        print(f"\nTop {k} results:")
        for i in range(k):
            idx = int(indices[0][i])
            if idx < 0:
                continue  # padded "no result" slot
            score = distances[0][i]
            meta = metadata[idx]
            print(f"\n{i+1}. Score: {score:.4f}")
            print(f"  ID: {meta.get('category_id', 'N/A')}")
            print(f"  Path: {meta.get('category_path', 'N/A')[:60]}...")

        best_score = float(distances[0][0])

        if best_score < 0.3:
            print(f"\nβ VERY LOW CONFIDENCE: {best_score:.4f}")
            print("  This indicates a serious problem with training!")
            return ["Very low prediction scores"]
        elif best_score < 0.5:
            print(f"\nβ οΈ LOW CONFIDENCE: {best_score:.4f}")
            print("  Model needs improvement")
            return ["Low prediction scores"]
        else:
            print(f"\nβ GOOD CONFIDENCE: {best_score:.4f}")
            return []

    except Exception as e:
        print(f"\nβ Prediction test failed: {e}")
        import traceback
        traceback.print_exc()
        return [f"Prediction error: {e}"]
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def generate_fix_commands(all_issues):
    """Print suggested remediation steps for the collected issue list.

    Purely informational — prints to stdout and returns None.
    """
    print("\n" + "=" * 80)
    print("π§ RECOMMENDED FIXES")
    print("=" * 80 + "\n")

    if not all_issues:
        print("β No critical issues found!")
        print("\nIf you're still experiencing low confidence:")
        print("  1. Make sure you're using tags.json")
        print("  2. Check if validation.csv is being used for calibration")
        print("  3. Verify CSV has correct column names")
        return

    print("Issues found:")
    for number, issue in enumerate(all_issues, 1):
        print(f"  {number}. {issue}")

    print("\n" + "=" * 80)
    print("FIX STEPS:")
    print("=" * 80 + "\n")

    # Any "missing"/"mismatch"/"low" issue means the cached artifacts are
    # unusable and the model must be rebuilt from scratch.
    trigger_words = ('missing', 'mismatch', 'low')
    retrain = any(
        any(word in issue.lower() for word in trigger_words)
        for issue in all_issues
    )

    if retrain:
        print("π RE-TRAINING REQUIRED")
        print("\nRun these commands in order:\n")
        print("# Step 1: Generate tags (if missing)")
        print("python generate_hybrid_tags.py data/category_only_path.csv data/tags.json")
        print()
        print("# Step 2: Generate validation data (for calibration)")
        print("python create_validation_data.py auto data/category_only_path.csv 200")
        print()
        print("# Step 3: Train with ALL fixes")
        print("python train_fixed_v2.py data/category_only_path.csv data/tags.json data/validation.csv")
        print()
    else:
        print("β No retraining needed - minor issues only")
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def main():
    """Run every diagnostic step in order and print a final summary."""
    print("\n" + "=" * 80)
    print("π§ DIAGNOSTIC AND FIX TOOL")
    print("=" * 80)
    print("\nThis tool will analyze your model and identify issues\n")

    # Each check returns a (possibly empty) list of issue strings.
    checks = (
        check_cache_files,
        check_csv_format,
        check_metadata,
        check_faiss_index,
        check_tags_json,
        test_prediction,
    )

    all_issues = []
    for run_check in checks:
        all_issues.extend(run_check())

    generate_fix_commands(all_issues)

    print("\n" + "=" * 80)
    print("π DIAGNOSIS COMPLETE")
    print("=" * 80)
    print(f"\nTotal issues found: {len(all_issues)}")
    print("\n")
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
# Allow running the diagnostics directly: `python check.py`
if __name__ == "__main__":
    main()
|
fix.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
π§ AUTOMATIC EMBEDDING & INDEX FIXER
|
| 3 |
+
====================================
|
| 4 |
+
Fixes common issues causing low confidence scores
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python fix_embeddings.py normalize # Fix normalization
|
| 8 |
+
python fix_embeddings.py rebuild-index # Rebuild FAISS
|
| 9 |
+
python fix_embeddings.py full-fix # Do everything
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import faiss
|
| 14 |
+
import pickle
|
| 15 |
+
import sys
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
import warnings
|
| 19 |
+
warnings.filterwarnings('ignore')
|
| 20 |
+
|
| 21 |
+
class EmbeddingFixer:
    """Repairs cached embeddings and the FAISS index built from them.

    Targets the two most common causes of low similarity scores:
    un-normalized embedding vectors and an index built with a metric
    other than inner product.
    """

    def __init__(self, cache_dir='cache'):
        # Directory holding embeddings.npy / main_index.faiss / metadata.pkl.
        self.cache_dir = Path(cache_dir)

    def banner(self, text):
        """Print a section header for console output."""
        print("\n" + "=" * 80)
        print(f"π§ {text}")
        print("=" * 80 + "\n")

    def backup_files(self):
        """Copy current artifacts into cache/backup before modifying them."""
        self.banner("CREATING BACKUPS")

        backup_dir = self.cache_dir / 'backup'
        backup_dir.mkdir(exist_ok=True)

        import shutil  # hoisted: was re-imported on every loop iteration

        for filename in ['embeddings.npy', 'main_index.faiss', 'metadata.pkl']:
            src = self.cache_dir / filename
            if src.exists():
                shutil.copy2(src, backup_dir / filename)
                # FIX: report which file was backed up (message previously
                # printed a placeholder instead of the filename).
                print(f"β Backed up: {filename}")

        print(f"\nπ Backups saved to: {backup_dir}")

    def normalize_embeddings(self):
        """L2-normalize embeddings.npy in place.

        Returns True on success, False when the file is missing.
        """
        self.banner("NORMALIZING EMBEDDINGS")

        emb_path = self.cache_dir / 'embeddings.npy'
        if not emb_path.exists():
            print("β embeddings.npy not found!")
            return False

        print("Loading embeddings...")
        embeddings = np.load(emb_path)
        print(f"Original shape: {embeddings.shape}")

        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        print(f"Mean norm before: {norms.mean():.6f}")
        print(f"Std norm before: {norms.std():.6f}")

        print("\nNormalizing...")
        # Epsilon avoids division by zero for degenerate all-zero rows.
        normalized = embeddings / (norms + 1e-8)

        norms_after = np.linalg.norm(normalized, axis=1)
        print(f"Mean norm after: {norms_after.mean():.6f}")
        print(f"Std norm after: {norms_after.std():.6f}")

        np.save(emb_path, normalized.astype('float32'))
        print(f"\nβ Saved normalized embeddings: {emb_path}")
        return True

    def rebuild_faiss_index(self):
        """Rebuild main_index.faiss with the inner-product metric.

        Inner product on unit-length vectors equals cosine similarity,
        which is what the retrieval code expects. Returns True on success.
        """
        self.banner("REBUILDING FAISS INDEX")

        emb_path = self.cache_dir / 'embeddings.npy'
        if not emb_path.exists():
            print("β embeddings.npy not found!")
            return False

        print("Loading embeddings...")
        embeddings = np.load(emb_path).astype('float32')
        print(f"Shape: {embeddings.shape}")

        # Defensive re-check: index quality depends on unit-length vectors.
        norms = np.linalg.norm(embeddings, axis=1)
        if abs(norms.mean() - 1.0) > 0.01:
            print("β οΈ Embeddings not normalized, normalizing now...")
            embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
            np.save(emb_path, embeddings)

        dimension = embeddings.shape[1]
        print(f"\nBuilding FAISS index...")
        print(f"  Dimension: {dimension}")
        print(f"  Vectors: {len(embeddings):,}")
        print(f"  Metric: INNER_PRODUCT")

        index = faiss.IndexFlatIP(dimension)
        print("\nAdding vectors...")
        index.add(embeddings)

        index_path = self.cache_dir / 'main_index.faiss'
        faiss.write_index(index, str(index_path))
        print(f"\nβ Saved FAISS index: {index_path}")
        print(f"  Total vectors: {index.ntotal:,}")
        return True

    def verify_fixes(self):
        """Sanity-check normalization, index metric, and a self-match search.

        Returns True when all checks pass.
        """
        self.banner("VERIFYING FIXES")

        try:
            embeddings = np.load(self.cache_dir / 'embeddings.npy')
            norms = np.linalg.norm(embeddings, axis=1)

            print("π Embeddings:")
            print(f"  Mean norm: {norms.mean():.6f}")
            print(f"  Std norm: {norms.std():.6f}")

            if abs(norms.mean() - 1.0) < 0.01 and norms.std() < 0.01:
                print("  β Properly normalized")
            else:
                print("  β Still not normalized properly")
                return False

            index = faiss.read_index(str(self.cache_dir / 'main_index.faiss'))
            print(f"\nπ FAISS Index:")
            print(f"  Vectors: {index.ntotal:,}")
            print(f"  Dimension: {index.d}")

            if index.metric_type == faiss.METRIC_INNER_PRODUCT:
                print("  β Using INNER_PRODUCT")
            else:
                print(f"  β Wrong metric: {index.metric_type}")
                return False

            # A vector searched against the index should match itself ~1.0.
            print("\nπ Testing search...")
            query = embeddings[0:1]
            distances, indices = index.search(query, 5)
            print(f"  Top result index: {indices[0][0]}")
            print(f"  Top result score: {distances[0][0]:.6f}")

            if distances[0][0] > 0.95:
                print("  β Search working correctly")
            else:
                print("  β οΈ Unexpected similarity score")

            print("\nβ ALL CHECKS PASSED!")
            return True

        except Exception as e:
            print(f"\nβ Verification failed: {e}")
            return False

    def full_fix(self):
        """Run backup, normalization, index rebuild, and verification in order.

        Returns True only when every step succeeds.
        """
        self.banner("RUNNING FULL FIX")

        print("This will:")
        print("1. Backup existing files")
        print("2. Normalize embeddings")
        print("3. Rebuild FAISS index")
        print("4. Verify fixes")

        print("\nStarting in 3 seconds...")
        import time
        time.sleep(3)  # grace period to let the user Ctrl+C

        self.backup_files()

        if not self.normalize_embeddings():
            print("\nβ Failed to normalize embeddings")
            return False

        if not self.rebuild_faiss_index():
            print("\nβ Failed to rebuild index")
            return False

        if not self.verify_fixes():
            print("\nβ Fixes did not work properly")
            return False

        print("\n" + "=" * 80)
        print("β ALL FIXES COMPLETED SUCCESSFULLY!")
        print("=" * 80)
        print("\nNext steps:")
        print("1. Restart your API server: python api_server.py")
        print("2. Test classification with a known category")
        print("3. Check confidence scores")
        print("\nIf issues persist, run diagnostics:")
        print("  python diagnose_and_fix.py")
        print("=" * 80 + "\n")

        return True
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def main():
    """CLI entry point: dispatch the requested fixer command.

    Exits with status 1 on missing or unrecognized commands.
    """
    if len(sys.argv) < 2:
        print("\n" + "=" * 80)
        print("π§ EMBEDDING & INDEX FIXER")
        print("=" * 80)
        print("\nUsage:")
        print("  python fix_embeddings.py normalize      # Fix normalization only")
        print("  python fix_embeddings.py rebuild-index  # Rebuild FAISS index")
        print("  python fix_embeddings.py full-fix       # Do everything (recommended)")
        print("\nExample:")
        print("  python fix_embeddings.py full-fix")
        print("=" * 80 + "\n")
        sys.exit(1)

    command = sys.argv[1].lower()
    fixer = EmbeddingFixer()

    # Dispatch table keeps the command -> steps mapping in one place.
    actions = {
        'normalize': (fixer.backup_files, fixer.normalize_embeddings, fixer.verify_fixes),
        'rebuild-index': (fixer.backup_files, fixer.rebuild_faiss_index, fixer.verify_fixes),
        'full-fix': (fixer.full_fix,),
    }

    steps = actions.get(command)
    if steps is None:
        print(f"β Unknown command: {command}")
        print("Use: normalize, rebuild-index, or full-fix")
        sys.exit(1)

    for step in steps:
        step()
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
# CLI entry point: `python fix.py <normalize|rebuild-index|full-fix>`
if __name__ == "__main__":
    main()
|
gradio_app.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Gradio App for Product Category Classification
|
| 4 |
+
Model: intfloat/e5-base-v2 (must match training)
|
| 5 |
+
Requires: pip install gradio sentence-transformers faiss-cpu numpy pickle5
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import gradio as gr
|
| 9 |
+
from sentence_transformers import SentenceTransformer
|
| 10 |
+
import faiss
|
| 11 |
+
import pickle
|
| 12 |
+
import numpy as np
|
| 13 |
+
import re
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
import time
|
| 16 |
+
|
| 17 |
+
# ====================================================================
|
| 18 |
+
# CONFIG
|
| 19 |
+
# ====================================================================
|
| 20 |
+
CACHE_DIR = Path("cache")
|
| 21 |
+
MODEL_NAME = "intfloat/e5-base-v2"
|
| 22 |
+
FAISS_INDEX_PATH = CACHE_DIR / "main_index.faiss"
|
| 23 |
+
METADATA_PATH = CACHE_DIR / "metadata.pkl"
|
| 24 |
+
SYN_PATH = CACHE_DIR / "cross_store_synonyms.pkl"
|
| 25 |
+
|
| 26 |
+
encoder = None
|
| 27 |
+
faiss_index = None
|
| 28 |
+
metadata = []
|
| 29 |
+
cross_store_synonyms = {}
|
| 30 |
+
|
| 31 |
+
# ====================================================================
|
| 32 |
+
# UTILITIES
|
| 33 |
+
# ====================================================================
|
| 34 |
+
def clean_text(text: str) -> str:
    """Normalize free text for matching.

    Lowercases, replaces everything except word characters, whitespace and
    hyphens with spaces, then collapses whitespace runs to single spaces.
    Falsy input yields an empty string.
    """
    if not text:
        return ""
    lowered = str(text).lower()
    no_punct = re.sub(r"[^\w\s-]", " ", lowered)
    collapsed = re.sub(r"\s+", " ", no_punct)
    return collapsed.strip()
|
| 41 |
+
|
| 42 |
+
def build_cross_store_synonyms():
    """Build a bidirectional synonym map for cross-store product terminology.

    Starting from a hand-curated one-way mapping (term -> synonym set), this
    expands it so every synonym also maps back to its term and to its sibling
    synonyms.

    Returns:
        dict[str, set[str]]: term -> set of related terms (never containing
        the key itself unless it appears as a sibling elsewhere).
    """
    synonyms = {
        'washing machine': {'laundry machine', 'washer', 'clothes washer', 'washing appliance'},
        'laundry machine': {'washing machine', 'washer', 'clothes washer'},
        'dryer': {'drying machine', 'clothes dryer', 'tumble dryer'},
        'refrigerator': {'fridge', 'cooler', 'ice box', 'cooling appliance'},
        'dishwasher': {'dish washer', 'dish cleaning machine'},
        'microwave': {'microwave oven', 'micro wave'},
        'vacuum': {'vacuum cleaner', 'hoover', 'vac'},
        'tv': {'television', 'telly', 'smart tv', 'display'},
        'laptop': {'notebook', 'portable computer', 'laptop computer'},
        'mobile': {'phone', 'cell phone', 'smartphone', 'cellphone'},
        'tablet': {'ipad', 'tab', 'tablet computer'},
        'headphones': {'headset', 'earphones', 'earbuds', 'ear buds'},
        'speaker': {'audio speaker', 'sound system', 'speakers'},
        'sofa': {'couch', 'settee', 'divan'},
        'wardrobe': {'closet', 'armoire', 'cupboard'},
        'drawer': {'chest of drawers', 'dresser'},
        'pants': {'trousers', 'slacks', 'bottoms'},
        'sweater': {'jumper', 'pullover', 'sweatshirt'},
        'sneakers': {'trainers', 'tennis shoes', 'running shoes'},
        'jacket': {'coat', 'blazer', 'outerwear'},
        'cooker': {'stove', 'range', 'cooking range'},
        'blender': {'mixer', 'food processor', 'liquidizer'},
        'kettle': {'electric kettle', 'water boiler'},
        'stroller': {'pram', 'pushchair', 'buggy', 'baby carriage'},
        'diaper': {'nappy', 'nappies'},
        'pacifier': {'dummy', 'soother'},
        'wrench': {'spanner', 'adjustable wrench'},
        'flashlight': {'torch', 'flash light'},
        'screwdriver': {'screw driver'},
        'tap': {'faucet', 'water tap'},
        'bin': {'trash can', 'garbage can', 'waste bin'},
        'curtain': {'drape', 'window covering'},
        'guillotine': {'paper cutter', 'paper trimmer', 'blade cutter'},
        'trimmer': {'cutter', 'cutting tool', 'edge cutter'},
        'stapler': {'stapling machine', 'staple gun'},
        'magazine': {'periodical', 'journal', 'publication'},
        'comic': {'comic book', 'graphic novel', 'manga'},
        'ebook': {'e-book', 'digital book', 'electronic book'},
        'kids': {'children', 'child', 'childrens', 'youth', 'junior'},
        'women': {'womens', 'ladies', 'female', 'lady'},
        'men': {'mens', 'male', 'gentleman'},
        'baby': {'infant', 'newborn', 'toddler'},
    }

    expanded = {}
    for term, syns in synonyms.items():
        # BUGFIX: use setdefault+update instead of direct assignment so that
        # reverse links accumulated in earlier iterations are merged rather
        # than clobbered. Example: 'laundry machine' is both a key here and a
        # synonym of 'washing machine'; direct assignment discarded the
        # already-collected sibling 'washing appliance'.
        expanded.setdefault(term, set()).update(syns)
        for syn in syns:
            entry = expanded.setdefault(syn, set())
            entry.add(term)                 # reverse link back to the key
            entry.update(syns - {syn})      # sibling synonyms, minus itself
    return expanded
|
| 97 |
+
|
| 98 |
+
def extract_cross_store_terms(text: str):
    """Expand *text* into a list of matchable terms.

    Produces: the cleaned full text, individual words longer than 2 chars,
    bigrams and trigrams, plus any synonyms known for words or bigrams in
    the module-level ``cross_store_synonyms`` map. Order of the returned
    list is unspecified (set-derived).
    """
    cleaned = clean_text(text)
    tokens = cleaned.split()
    terms = {cleaned}

    # Single tokens (skip very short ones) and their synonym expansions.
    for token in tokens:
        if len(token) > 2:
            terms.add(token)
        if token in cross_store_synonyms:
            terms.update(cross_store_synonyms[token])

    # Bigrams, with synonym expansion when the phrase is known.
    for left, right in zip(tokens, tokens[1:]):
        bigram = f"{left} {right}"
        terms.add(bigram)
        terms.update(cross_store_synonyms.get(bigram, ()))

    # Trigrams are added verbatim (no synonym lookup for them).
    for a, b, c in zip(tokens, tokens[1:], tokens[2:]):
        terms.add(f"{a} {b} {c}")

    return list(terms)
|
| 117 |
+
|
| 118 |
+
def build_enhanced_query(title, description="", max_synonyms=10):
    """Compose the retrieval query string for a product.

    The cleaned title is repeated three times (to dominate the embedding)
    and the first *max_synonyms* expanded terms are appended.

    Returns:
        tuple[str, list[str]]: (query text, up to 20 expansion terms for
        display/debugging).
    """
    clean_title = clean_text(title)
    clean_desc = clean_text(description)
    expansions = extract_cross_store_terms(f"{clean_title} {clean_desc}")
    query = ' '.join([clean_title] * 3 + expansions[:max_synonyms])
    return query, expansions[:20]
|
| 124 |
+
|
| 125 |
+
def encode_query(text: str):
    """Encode *text* into a (1, dim) float32 embedding.

    Uses the module-level ``encoder`` with L2 normalisation so that inner
    product equals cosine similarity; a 1-D result is promoted to 2-D for
    FAISS.
    """
    vector = encoder.encode(text, convert_to_numpy=True, normalize_embeddings=True)
    return np.atleast_2d(vector).astype('float32')
|
| 130 |
+
|
| 131 |
+
# ====================================================================
|
| 132 |
+
# CLASSIFICATION
|
| 133 |
+
# ====================================================================
|
| 134 |
+
def classify_product(title, description="", top_k=5):
    """Classify a product against the FAISS category index.

    Args:
        title: product title (drives the query).
        description: optional product description.
        top_k: number of nearest categories to retrieve.

    Returns:
        dict with the best match (id, path, leaf name), a human-readable
        confidence label, matched expansion terms, all top-k candidates and
        the processing time in ms — or ``{'error': ..., 'product': ...}``
        when nothing usable came back.
    """
    t0 = time.time()
    query_text, matched_terms = build_enhanced_query(title, description)
    query_vec = encode_query(query_text)
    distances, indices = faiss_index.search(query_vec, top_k)

    candidates = []
    for rank, idx in enumerate(indices[0]):
        # Guard against stale indices pointing past the metadata list.
        if idx >= len(metadata):
            continue
        meta = metadata[idx]
        # NOTE(review): treats the FAISS score as a distance (lower = better).
        # If the index was built with METRIC_INNER_PRODUCT on normalised
        # vectors, the raw score already IS the similarity — confirm against
        # the training pipeline.
        score_pct = float(1 - distances[0][rank]) * 100
        levels = meta.get('levels')
        leaf = levels[-1] if levels else meta['category_path'].split('/')[-1]
        candidates.append({
            'rank': rank + 1,
            'category_id': str(meta['category_id']),
            'category_path': meta['category_path'],
            'final_product': leaf,
            'confidence': round(score_pct, 2),
            'depth': meta.get('depth', 0),
        })

    if not candidates:
        return {'error': 'No results found', 'product': title}

    best = candidates[0]
    conf_pct = best['confidence']
    # Map the numeric score to a coarse label (first threshold that holds).
    label = "LOW"
    for cutoff, name in ((90, "EXCELLENT"), (85, "VERY HIGH"), (80, "HIGH"),
                         (75, "GOOD"), (70, "MEDIUM")):
        if conf_pct >= cutoff:
            label = name
            break

    elapsed_ms = (time.time() - t0) * 1000

    return {
        'product': title,
        'category_id': best['category_id'],
        'category_path': best['category_path'],
        'final_product': best['final_product'],
        'confidence': f"{label} ({conf_pct:.2f}%)",
        'confidence_percent': conf_pct,
        'depth': best['depth'],
        'matched_terms': matched_terms,
        'top_5_results': candidates,
        'processing_time_ms': round(elapsed_ms, 2),
    }
|
| 192 |
+
|
| 193 |
+
# ====================================================================
|
| 194 |
+
# LOAD MODEL & INDEX
|
| 195 |
+
# ====================================================================
|
| 196 |
+
def load_model():
    """Populate the module-level runtime state.

    Loads, in order: the sentence-transformer encoder, the FAISS index, the
    category metadata pickle, and the cross-store synonym map (from cache
    when present, otherwise rebuilt from the built-in defaults). Progress is
    logged to stdout.
    """
    global encoder, faiss_index, metadata, cross_store_synonyms

    print("Loading sentence-transformer model...")
    encoder = SentenceTransformer(MODEL_NAME)
    print("Model loaded.")

    print("Loading FAISS index...")
    faiss_index = faiss.read_index(str(FAISS_INDEX_PATH))
    print(f"FAISS index loaded: {faiss_index.ntotal} vectors.")

    print("Loading metadata...")
    with open(METADATA_PATH, 'rb') as fh:
        metadata = pickle.load(fh)
    print(f"Metadata loaded: {len(metadata)} categories.")

    print("Loading cross-store synonyms...")
    if SYN_PATH.exists():
        # Prefer the cached map so it stays in sync with training artifacts.
        with open(SYN_PATH, 'rb') as fh:
            cross_store_synonyms = pickle.load(fh)
        print(f"Loaded {len(cross_store_synonyms)} synonyms from file.")
    else:
        cross_store_synonyms = build_cross_store_synonyms()
        print(f"Built {len(cross_store_synonyms)} default synonyms.")
|
| 219 |
+
|
| 220 |
+
# ====================================================================
|
| 221 |
+
# GRADIO FUNCTION
|
| 222 |
+
# ====================================================================
|
| 223 |
+
def classify_gradio(title, description=""):
    """Adapter between ``classify_product`` and the Gradio output widgets.

    Returns five strings: predicted product, category path, confidence
    label, matched terms (comma-joined) and a newline-terminated listing of
    the top-5 alternatives.
    """
    result = classify_product(title, description)

    terms = result.get('matched_terms')
    matched_terms = ', '.join(terms) if terms else ''

    alternatives = [
        f"{alt['rank']}. {alt['final_product']} (ID: {alt['category_id']}, Confidence: {alt['confidence']}%)"
        for alt in result.get('top_5_results', [])
    ]
    top5_text = ''.join(line + "\n" for line in alternatives)

    return (
        str(result.get('final_product', '')),
        str(result.get('category_path', '')),
        str(result.get('confidence', '')),
        matched_terms,
        top5_text,
    )
|
| 233 |
+
|
| 234 |
+
# ====================================================================
|
| 235 |
+
# MAIN GRADIO APP
|
| 236 |
+
# ====================================================================
|
| 237 |
+
def main():
    """Load all artifacts, then serve the Gradio classification UI."""
    load_model()
    interface = gr.Interface(
        fn=classify_gradio,
        inputs=[
            gr.Textbox(label="Product Title"),
            gr.Textbox(label="Description"),
        ],
        outputs=[
            gr.Textbox(label="Predicted Product"),
            gr.Textbox(label="Category Path"),
            gr.Textbox(label="Confidence"),
            gr.Textbox(label="Matched Terms"),
            gr.Textbox(label="Top 5 Alternatives"),
        ],
        title="π― Product Category Classifier",
        description="Classify products with full cross-store synonyms and embeddings",
    )
    # share=True exposes a temporary public gradio.live URL.
    interface.launch(share=True)


if __name__ == "__main__":
    main()
|
miss.py
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
π¬ ADVANCED MODEL DIAGNOSTICS & AUTOMATIC FIXES
|
| 3 |
+
===============================================
|
| 4 |
+
Diagnoses and fixes common issues causing low confidence/accuracy
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python diagnose_and_fix.py
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
import pickle
|
| 13 |
+
import json
|
| 14 |
+
import faiss
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from sentence_transformers import SentenceTransformer
|
| 17 |
+
from collections import defaultdict, Counter
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
import warnings
|
| 20 |
+
warnings.filterwarnings('ignore')
|
| 21 |
+
|
| 22 |
+
class ModelDiagnostics:
    """Health checks for the cached embedding model / FAISS index.

    Each ``check_*`` / ``test_*`` / ``analyze_*`` method prints its findings
    and appends problem records to ``self.issues``; ``generate_report`` and
    ``suggest_fixes`` summarise them. Expected cache layout (relative to
    ``cache_dir``): ``embeddings.npy``, ``main_index.faiss``,
    ``metadata.pkl``; ``data_dir`` holds the source category CSV.

    NOTE(review): console strings below contain characters that appear
    mojibake-garbled in this view (likely emoji) — confirm against the
    original file encoding.
    """

    def __init__(self, cache_dir: str = 'cache', data_dir: str = 'data') -> None:
        self.cache_dir = Path(cache_dir)
        self.data_dir = Path(data_dir)
        # Accumulated findings: dicts with keys 'type' ('CRITICAL'/'WARNING'),
        # 'issue', 'details', 'fix'.
        self.issues = []
        # NOTE(review): never written to anywhere in this class — confirm
        # whether it is intended for future use before removing.
        self.fixes_applied = []

    def banner(self, text: str) -> None:
        """Print a framed section header to stdout."""
        print("\n" + "="*80)
        print(f"π {text}")
        print("="*80 + "\n")

    def check_embedding_normalization(self) -> bool:
        """Check if embeddings are properly normalized.

        Loads ``embeddings.npy`` and verifies row L2 norms are ~1.0
        (mean within 0.01 of 1.0 and std below 0.01). Returns True when
        normalized, False on failure or load error; records a CRITICAL
        issue when not normalized.
        """
        self.banner("CHECKING EMBEDDING NORMALIZATION")

        try:
            embeddings = np.load(self.cache_dir / 'embeddings.npy')

            # Check norms (row-wise L2).
            norms = np.linalg.norm(embeddings, axis=1)

            print(f"π Embedding Statistics:")
            print(f" Shape: {embeddings.shape}")
            print(f" Mean norm: {norms.mean():.6f}")
            print(f" Std norm: {norms.std():.6f}")
            print(f" Min norm: {norms.min():.6f}")
            print(f" Max norm: {norms.max():.6f}")

            # Should be ~1.0 if normalized
            if abs(norms.mean() - 1.0) > 0.01 or norms.std() > 0.01:
                self.issues.append({
                    'type': 'CRITICAL',
                    'issue': 'Embeddings not normalized',
                    'details': f'Mean norm: {norms.mean():.6f} (should be ~1.0)',
                    'fix': 'Re-normalize embeddings'
                })
                print(" β ISSUE: Embeddings are NOT normalized!")
                print(" This causes incorrect similarity scores")
                return False
            else:
                print(" β Embeddings properly normalized")
                return True

        except Exception as e:
            print(f" β Error: {e}")
            return False

    def check_faiss_metric(self) -> bool:
        """Check FAISS index metric type.

        For normalized vectors the index should use METRIC_INNER_PRODUCT
        (inner product == cosine similarity). Records a CRITICAL issue if
        METRIC_L2 is found. Returns True only for INNER_PRODUCT.
        """
        self.banner("CHECKING FAISS INDEX METRIC")

        try:
            index = faiss.read_index(str(self.cache_dir / 'main_index.faiss'))

            metric = index.metric_type

            print(f"π FAISS Index:")
            print(f" Vectors: {index.ntotal:,}")
            print(f" Dimension: {index.d}")
            print(f" Metric type: {metric}")

            if metric == faiss.METRIC_INNER_PRODUCT:
                print(" β Using INNER_PRODUCT (correct for normalized vectors)")
                return True
            elif metric == faiss.METRIC_L2:
                self.issues.append({
                    'type': 'CRITICAL',
                    'issue': 'Wrong FAISS metric',
                    'details': 'Using L2 distance instead of inner product',
                    'fix': 'Rebuild index with METRIC_INNER_PRODUCT'
                })
                print(" β ISSUE: Using L2 distance!")
                print(" Should use INNER_PRODUCT for normalized vectors")
                return False
            else:
                print(f" β οΈ Unknown metric: {metric}")
                return False

        except Exception as e:
            print(f" β Error: {e}")
            return False

    def check_text_weighting(self) -> bool:
        """Check if text is properly weighted.

        Inspects the first metadata entry for an 'auto_tags' field and
        records a WARNING when tags are missing. Returns True unless the
        metadata pickle cannot be read.
        """
        self.banner("CHECKING TEXT CONSTRUCTION")

        try:
            with open(self.cache_dir / 'metadata.pkl', 'rb') as f:
                metadata = pickle.load(f)

            # Analyze a sample — only entry 0 is examined; assumes all
            # entries share the same schema (TODO confirm).
            sample = metadata[0]

            print(f"π Sample Category:")
            print(f" ID: {sample.get('category_id')}")
            print(f" Path: {sample.get('category_path')}")
            print(f" Depth: {sample.get('depth')}")
            print(f" Levels: {sample.get('levels')}")

            # Check if we have tags
            if 'auto_tags' in sample and sample['auto_tags']:
                print(f" Tags: {len(sample['auto_tags'])} tags")
                print(f" Sample tags: {sample['auto_tags'][:5]}")
                print(" β Auto-tags present")
            else:
                self.issues.append({
                    'type': 'WARNING',
                    'issue': 'Missing auto-tags',
                    'details': 'Categories lack auto-generated tags',
                    'fix': 'Generate tags from category paths'
                })
                print(" β οΈ No auto-tags found")

            return True

        except Exception as e:
            print(f" β Error: {e}")
            return False

    def test_predictions(self, num_samples: int = 100) -> bool:
        """Test prediction accuracy on random samples.

        Samples rows from the first CSV in ``data_dir``, queries the index
        with "query: <leaf category>" and measures top-1 accuracy over the
        top-10 retrieved candidates. Records CRITICAL below 70% accuracy,
        WARNING below 85%. Returns False on error or accuracy < 70%.
        """
        self.banner("TESTING PREDICTION ACCURACY")

        try:
            # Load model
            print("Loading model and index...")
            encoder = SentenceTransformer('intfloat/e5-base-v2')
            index = faiss.read_index(str(self.cache_dir / 'main_index.faiss'))

            with open(self.cache_dir / 'metadata.pkl', 'rb') as f:
                metadata = pickle.load(f)

            # Load CSV
            csv_files = list(self.data_dir.glob('*.csv'))
            if not csv_files:
                print(" β No CSV files found in data/")
                return False

            df = pd.read_csv(csv_files[0])

            # Sample categories
            samples = df.sample(min(num_samples, len(df)))

            correct = 0
            confidence_scores = []
            rank_positions = []

            print(f"Testing {len(samples)} random categories...\n")

            for idx, row in tqdm(samples.iterrows(), total=len(samples)):
                cat_id = str(row.iloc[0])  # First column — assumes column 0 is
                                           # the category id (TODO confirm CSV schema)
                cat_path = str(row.iloc[1])  # Second column

                # Get leaf category (final product)
                leaf = cat_path.split('/')[-1].strip()

                # Build query — "query:" prefix is the e5 model convention.
                query = f"query: {leaf}"

                # Encode
                query_emb = encoder.encode(query, normalize_embeddings=True)
                query_emb = query_emb.reshape(1, -1).astype('float32')

                # Search
                distances, indices = index.search(query_emb, 10)

                # Check if correct category is in top results.
                # NOTE(review): inner `idx` shadows the outer DataFrame index
                # variable; harmless here since the outer `idx` is unused
                # after this point, but worth renaming.
                found_rank = None
                for rank, idx in enumerate(indices[0]):
                    pred_id = str(metadata[idx]['category_id'])
                    if pred_id == cat_id:
                        found_rank = rank + 1
                        correct += 1
                        confidence_scores.append(float(distances[0][rank]))
                        break

                if found_rank:
                    rank_positions.append(found_rank)
                else:
                    rank_positions.append(11)  # Not in top 10

            # Calculate metrics
            accuracy = (correct / len(samples)) * 100
            avg_confidence = np.mean(confidence_scores) if confidence_scores else 0

            print(f"\nπ Results:")
            print(f" Accuracy (Top-1): {accuracy:.2f}%")
            print(f" Correct predictions: {correct}/{len(samples)}")
            print(f" Average confidence: {avg_confidence:.4f}")

            if confidence_scores:
                print(f" Min confidence: {min(confidence_scores):.4f}")
                print(f" Max confidence: {max(confidence_scores):.4f}")

            # Rank distribution
            rank_counts = Counter(rank_positions)
            print(f"\n Rank Distribution:")
            for rank in sorted(rank_counts.keys())[:5]:
                count = rank_counts[rank]
                pct = (count / len(samples)) * 100
                print(f" Rank {rank}: {count} ({pct:.1f}%)")

            if accuracy < 70:
                self.issues.append({
                    'type': 'CRITICAL',
                    'issue': 'Low prediction accuracy',
                    'details': f'Only {accuracy:.1f}% accuracy',
                    'fix': 'Retrain with better text weighting'
                })
                print(f"\n β ISSUE: Low accuracy ({accuracy:.1f}%)")
                return False
            elif accuracy < 85:
                self.issues.append({
                    'type': 'WARNING',
                    'issue': 'Moderate accuracy',
                    'details': f'Accuracy: {accuracy:.1f}%',
                    'fix': 'Consider retraining with optimizations'
                })
                print(f"\n β οΈ Moderate accuracy ({accuracy:.1f}%)")
                return True
            else:
                print(f"\n β Good accuracy ({accuracy:.1f}%)")
                return True

        except Exception as e:
            print(f" β Error: {e}")
            import traceback
            traceback.print_exc()
            return False

    def analyze_category_distribution(self) -> bool:
        """Analyze category depth and structure.

        Prints depth statistics from metadata and records a WARNING when
        the depth range exceeds 5 levels. Returns True unless loading fails.
        """
        self.banner("ANALYZING CATEGORY STRUCTURE")

        try:
            with open(self.cache_dir / 'metadata.pkl', 'rb') as f:
                metadata = pickle.load(f)

            depths = [m.get('depth', 0) for m in metadata]

            print(f"π Category Structure:")
            print(f" Total categories: {len(metadata):,}")
            print(f" Average depth: {np.mean(depths):.2f}")
            print(f" Min depth: {min(depths)}")
            print(f" Max depth: {max(depths)}")

            # Depth distribution (only the first 8 depth values are shown)
            depth_counts = Counter(depths)
            print(f"\n Depth Distribution:")
            for depth in sorted(depth_counts.keys())[:8]:
                count = depth_counts[depth]
                pct = (count / len(metadata)) * 100
                print(f" Depth {depth}: {count:,} ({pct:.1f}%)")

            # Check for imbalance
            if max(depths) - min(depths) > 5:
                self.issues.append({
                    'type': 'WARNING',
                    'issue': 'Large depth variation',
                    'details': f'Depth ranges from {min(depths)} to {max(depths)}',
                    'fix': 'Consider depth-based weighting'
                })
                print(f"\n β οΈ Large depth variation detected")

            return True

        except Exception as e:
            print(f" β Error: {e}")
            return False

    def check_duplicate_embeddings(self) -> bool:
        """Check for duplicate or near-duplicate embeddings.

        Draws up to 1000 random embeddings and counts pairs whose dot
        product exceeds 0.99 (== cosine similarity only if the embeddings
        are normalized — see check_embedding_normalization). Records a
        WARNING above a 5%-of-sample duplicate count.
        """
        self.banner("CHECKING FOR DUPLICATE EMBEDDINGS")

        try:
            embeddings = np.load(self.cache_dir / 'embeddings.npy')

            # Sample check (checking all would be too slow)
            sample_size = min(1000, len(embeddings))
            sample_indices = np.random.choice(len(embeddings), sample_size, replace=False)
            sample_embs = embeddings[sample_indices]

            # Compute pairwise similarities
            similarities = np.dot(sample_embs, sample_embs.T)

            # Count very high similarities (excluding diagonal)
            np.fill_diagonal(similarities, 0)
            high_sim = (similarities > 0.99).sum() // 2  # Divide by 2 for symmetry

            print(f"π Duplicate Check (sample of {sample_size}):")
            print(f" Very similar pairs (>0.99): {high_sim}")

            if high_sim > sample_size * 0.05:  # >5% duplicates
                self.issues.append({
                    'type': 'WARNING',
                    'issue': 'Many duplicate embeddings',
                    'details': f'{high_sim} pairs with >0.99 similarity',
                    'fix': 'Check for duplicate categories or improve text diversity'
                })
                print(f" β οΈ Many near-duplicates detected")
                return False
            else:
                print(f" β Low duplicate rate")
                return True

        except Exception as e:
            print(f" β Error: {e}")
            return False

    def generate_report(self) -> None:
        """Print all accumulated issues grouped by severity."""
        self.banner("DIAGNOSTIC REPORT")

        if not self.issues:
            print("β NO ISSUES FOUND!")
            print("\nYour model appears to be properly configured.")
            return

        # Group by severity.
        # NOTE(review): local `warnings` shadows the imported `warnings`
        # module inside this method; harmless here but worth renaming.
        critical = [i for i in self.issues if i['type'] == 'CRITICAL']
        warnings = [i for i in self.issues if i['type'] == 'WARNING']

        if critical:
            print("π΄ CRITICAL ISSUES:")
            for i, issue in enumerate(critical, 1):
                print(f"\n{i}. {issue['issue']}")
                print(f" Details: {issue['details']}")
                print(f" Fix: {issue['fix']}")

        if warnings:
            print("\nπ‘ WARNINGS:")
            for i, issue in enumerate(warnings, 1):
                print(f"\n{i}. {issue['issue']}")
                print(f" Details: {issue['details']}")
                print(f" Fix: {issue['fix']}")

        print(f"\nπ Summary:")
        print(f" Critical issues: {len(critical)}")
        print(f" Warnings: {len(warnings)}")

    def suggest_fixes(self) -> None:
        """Suggest fix commands based on issues found.

        Matches issue titles by substring and prints the corresponding
        command for each class of problem.
        """
        self.banner("RECOMMENDED FIXES")

        if not self.issues:
            print("β No fixes needed!")
            return

        print("Run these commands to fix issues:\n")

        # Check for critical issues
        critical = [i for i in self.issues if i['type'] == 'CRITICAL']

        if any('normalization' in i['issue'].lower() for i in critical):
            print("1οΈβ£ Fix embedding normalization:")
            print(" python fix_embeddings.py normalize")
            print()

        if any('faiss' in i['issue'].lower() for i in critical):
            print("2οΈβ£ Rebuild FAISS index with correct metric:")
            print(" python fix_embeddings.py rebuild-index")
            print()

        if any('accuracy' in i['issue'].lower() for i in critical):
            print("3οΈβ£ Retrain with improved settings:")
            print(" python train_fixed_v2.py data/categories.csv data/tags.json")
            print()

        if any('tags' in i['issue'].lower() for i in self.issues):
            print("4οΈβ£ Generate missing tags:")
            print(" python generate_tags.py data/categories.csv")
            print()

    def run_full_diagnostics(self) -> None:
        """Run all diagnostic checks, then print the report and fixes."""
        print("\n" + "="*80)
        print("π¬ COMPREHENSIVE MODEL DIAGNOSTICS")
        print("="*80)

        # Run all checks (each one prints its own section and records issues)
        self.check_embedding_normalization()
        self.check_faiss_metric()
        self.check_text_weighting()
        self.analyze_category_distribution()
        self.check_duplicate_embeddings()
        self.test_predictions(num_samples=50)

        # Generate report
        self.generate_report()
        self.suggest_fixes()

        print("\n" + "="*80)
        print("π― DIAGNOSTICS COMPLETE")
        print("="*80 + "\n")
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
if __name__ == "__main__":
    # Run the complete diagnostic suite when invoked as a script.
    ModelDiagnostics().run_full_diagnostics()
|
path.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import json
|
| 5 |
+
import re
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class HybridTagsGenerator:
|
| 10 |
+
|
| 11 |
+
def __init__(self):
|
| 12 |
+
# Search intent patterns (E5 likes real text)
|
| 13 |
+
self.search_intents = [
|
| 14 |
+
"buy {item}",
|
| 15 |
+
"best {item}",
|
| 16 |
+
"{item} reviews",
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
def clean(self, text):
|
| 20 |
+
text = str(text).lower()
|
| 21 |
+
text = re.sub(r"[^\w\s-]", " ", text)
|
| 22 |
+
text = re.sub(r"\s+", " ", text).strip()
|
| 23 |
+
return text
|
| 24 |
+
|
| 25 |
+
# -------------------------------------------------------
|
| 26 |
+
# 1. Hierarchical tag boosting
|
| 27 |
+
# -------------------------------------------------------
|
| 28 |
+
def make_hierarchy_tags(self, path):
|
| 29 |
+
levels = [l.strip() for l in path.split("/") if l.strip()]
|
| 30 |
+
tags = []
|
| 31 |
+
|
| 32 |
+
# Strong full-path signal
|
| 33 |
+
full = " ".join(self.clean(l) for l in levels)
|
| 34 |
+
tags.extend([full] * 8) # <-- Strong boost
|
| 35 |
+
|
| 36 |
+
# Progressive hierarchy
|
| 37 |
+
for i in range(1, len(levels) + 1):
|
| 38 |
+
seg = " ".join(self.clean(l) for l in levels[:i])
|
| 39 |
+
tags.append(seg)
|
| 40 |
+
|
| 41 |
+
# Parent-child reinforcement
|
| 42 |
+
if len(levels) >= 2:
|
| 43 |
+
parent = self.clean(levels[-2])
|
| 44 |
+
child = self.clean(levels[-1])
|
| 45 |
+
|
| 46 |
+
tags.extend([
|
| 47 |
+
f"{parent} {child}",
|
| 48 |
+
f"{child} {parent}",
|
| 49 |
+
f"{child} in {parent}",
|
| 50 |
+
f"{child} category {parent}"
|
| 51 |
+
])
|
| 52 |
+
|
| 53 |
+
return tags
|
| 54 |
+
|
| 55 |
+
# -------------------------------------------------------
|
| 56 |
+
# 2. Extract key terms and word combos
|
| 57 |
+
# -------------------------------------------------------
|
| 58 |
+
def extract_terms(self, path):
|
| 59 |
+
levels = [l.strip() for l in path.split("/") if l.strip()]
|
| 60 |
+
terms = []
|
| 61 |
+
|
| 62 |
+
for level in levels:
|
| 63 |
+
cleaned = self.clean(level)
|
| 64 |
+
if cleaned not in terms:
|
| 65 |
+
terms.append(cleaned)
|
| 66 |
+
|
| 67 |
+
words = [w for w in cleaned.split() if len(w) > 3]
|
| 68 |
+
terms.extend(words)
|
| 69 |
+
|
| 70 |
+
# bigrams for leaf and parent
|
| 71 |
+
if level in levels[-2:]:
|
| 72 |
+
for i in range(len(words) - 1):
|
| 73 |
+
terms.append(f"{words[i]} {words[i+1]}")
|
| 74 |
+
|
| 75 |
+
# Remove duplicates, keep order
|
| 76 |
+
return list(dict.fromkeys(terms))
|
| 77 |
+
|
| 78 |
+
# -------------------------------------------------------
|
| 79 |
+
# 3. Build final tag list for ONE category
|
| 80 |
+
# -------------------------------------------------------
|
| 81 |
+
def build_tags(self, category_id, category_path):
    """Assemble the final deduplicated tag list (capped at 50) for one category."""
    # Hierarchy tags carry the strongest signal.
    tags = list(self.make_hierarchy_tags(category_path))

    # Up to 15 key terms extracted from the path.
    tags.extend(self.extract_terms(category_path)[:15])

    # Search-intent phrasings built around the leaf level.
    # NOTE(review): leaf is taken without stripping a trailing '/' —
    # assumes paths never end with '/'; confirm against input data.
    leaf = self.clean(category_path.split("/")[-1])
    tags.extend(pattern.format(item=leaf) for pattern in self.search_intents[:2])

    # Clean, dedupe (order-preserving) and drop phrases longer than 6 words.
    seen = set()
    final = []
    for tag in tags:
        cleaned = self.clean(tag)
        if cleaned and cleaned not in seen and len(cleaned.split()) <= 6:
            seen.add(cleaned)
            final.append(cleaned)

    return final[:50]
|
| 107 |
+
|
| 108 |
+
# -------------------------------------------------------
|
| 109 |
+
# 4. Generate tags.json for entire CSV
|
| 110 |
+
# -------------------------------------------------------
|
| 111 |
+
def generate_tags_json(self, csv_path, output="tags.json"):
    """Read a categories CSV and write {Category_ID: [tags, ...]} to JSON.

    The CSV must contain 'Category_ID' and 'Category_path' columns; rows
    with a missing path are skipped. Returns the generated mapping.
    """
    df = pd.read_csv(csv_path, dtype=str)

    if "Category_ID" not in df.columns or "Category_path" not in df.columns:
        raise ValueError("CSV must contain Category_ID, Category_path columns")

    df = df.dropna(subset=["Category_path"])

    tags_dict = {}
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Building tags"):
        category_id = str(row["Category_ID"])
        category_path = str(row["Category_path"])
        tags_dict[category_id] = self.build_tags(category_id, category_path)

    with open(output, "w", encoding="utf-8") as f:
        json.dump(tags_dict, f, indent=2)

    print(f"✅ DONE: {output} saved.")
    return tags_dict
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage: python build_tags_json.py <categories.csv>")
        # Exit non-zero: a usage error must not report success to the shell.
        # (Original called sys.exit() with no argument, which exits 0.)
        sys.exit(1)

    csv_file = sys.argv[1]
    gen = HybridTagsGenerator()
    gen.generate_tags_json(csv_file, "tags.json")
|
requirements.txt
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# sentence-transformers==3.3.1
|
| 2 |
+
# torch==2.5.1
|
| 3 |
+
# transformers==4.46.3
|
| 4 |
+
# faiss-gpu==1.9.0.post1
|
| 5 |
+
# pandas==2.2.3
|
| 6 |
+
# numpy==2.0.2
|
| 7 |
+
# fastapi==0.115.6
|
| 8 |
+
# uvicorn==0.32.1
|
| 9 |
+
# gunicorn==23.0.0
|
| 10 |
+
# pydantic==2.10.3
|
| 11 |
+
# joblib==1.4.2
|
| 12 |
+
# psutil==6.1.0
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
sentence-transformers==3.3.1
|
| 16 |
+
torch==2.5.1
|
| 17 |
+
transformers==4.46.3
|
| 18 |
+
faiss-cpu==1.9.0
|
| 19 |
+
pandas==2.2.3
|
| 20 |
+
numpy==2.0.2
|
| 21 |
+
fastapi==0.115.6
|
| 22 |
+
uvicorn==0.32.1
|
| 23 |
+
gunicorn==23.0.0
|
| 24 |
+
pydantic==2.10.3
|
| 25 |
+
joblib==1.4.2
|
| 26 |
+
psutil==6.1.0
|
| 27 |
+
nltk>=3.8.1
|
| 28 |
+
# Note: faiss-gpu is commented out to avoid compatibility issues on systems without a compatible GPU.
|
synonyms.py
CHANGED
|
@@ -1,366 +1,854 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
β
Uses e5-base-v2 (
|
| 6 |
-
β
|
| 7 |
-
β
|
| 8 |
-
|
| 9 |
-
Usage:
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
import pickle
|
| 15 |
-
from pathlib import Path
|
| 16 |
-
import json
|
| 17 |
-
from collections import defaultdict
|
| 18 |
-
from tqdm import tqdm
|
| 19 |
-
import warnings
|
| 20 |
-
import sys
|
| 21 |
-
import os
|
| 22 |
-
|
| 23 |
-
warnings.filterwarnings('ignore')
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
main()
|
|
|
|
| 1 |
+
# """
|
| 2 |
+
# π€ FIXED AI-POWERED SYNONYM MANAGER
|
| 3 |
+
# ====================================
|
| 4 |
+
# β
Windows + NVIDIA GPU optimized
|
| 5 |
+
# β
Uses e5-base-v2 (lower memory)
|
| 6 |
+
# β
Proper error handling
|
| 7 |
+
# β
Progress tracking
|
| 8 |
+
|
| 9 |
+
# Usage:
|
| 10 |
+
# python synonym_manager_fixed.py autobuild data/category_id_path_only.csv
|
| 11 |
+
# python synonym_manager_fixed.py autobuild data/category_id_path_only.csv --fast
|
| 12 |
+
# """
|
| 13 |
+
|
| 14 |
+
# import pickle
|
| 15 |
+
# from pathlib import Path
|
| 16 |
+
# import json
|
| 17 |
+
# from collections import defaultdict
|
| 18 |
+
# from tqdm import tqdm
|
| 19 |
+
# import warnings
|
| 20 |
+
# import sys
|
| 21 |
+
# import os
|
| 22 |
+
|
| 23 |
+
# warnings.filterwarnings('ignore')
|
| 24 |
+
|
| 25 |
+
# # Fix CUDA issues on Windows
|
| 26 |
+
# os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
|
| 27 |
+
|
| 28 |
+
# try:
|
| 29 |
+
# from nltk.corpus import wordnet
|
| 30 |
+
# from nltk import download as nltk_download
|
| 31 |
+
# WORDNET_AVAILABLE = True
|
| 32 |
+
# except ImportError:
|
| 33 |
+
# WORDNET_AVAILABLE = False
|
| 34 |
+
# print("β οΈ NLTK not available. Install with: pip install nltk")
|
| 35 |
+
|
| 36 |
+
# try:
|
| 37 |
+
# from sentence_transformers import SentenceTransformer, util
|
| 38 |
+
# import torch
|
| 39 |
+
# TRANSFORMERS_AVAILABLE = True
|
| 40 |
+
# except ImportError:
|
| 41 |
+
# TRANSFORMERS_AVAILABLE = False
|
| 42 |
+
# print("β οΈ SentenceTransformers not available.")
|
| 43 |
+
# print(" Install with: pip install sentence-transformers torch")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# class FixedAISynonymManager:
|
| 47 |
+
# """Fixed AI-powered synonym manager for Windows + NVIDIA GPU"""
|
| 48 |
+
|
| 49 |
+
# def __init__(self, cache_dir='cache', tags_file='data/tags.json', fast_mode=False):
|
| 50 |
+
# self.cache_dir = Path(cache_dir)
|
| 51 |
+
# self.synonyms_file = self.cache_dir / 'cross_store_synonyms.pkl'
|
| 52 |
+
# self.tags_file = Path(tags_file)
|
| 53 |
+
# self.synonyms = {}
|
| 54 |
+
# self.tags_data = {}
|
| 55 |
+
# self.model = None
|
| 56 |
+
# self.device = "cpu"
|
| 57 |
+
# self.fast_mode = fast_mode
|
| 58 |
+
|
| 59 |
+
# # Create cache directory
|
| 60 |
+
# self.cache_dir.mkdir(parents=True, exist_ok=True)
|
| 61 |
+
|
| 62 |
+
# # Load existing data
|
| 63 |
+
# self.load_tags()
|
| 64 |
+
# if self.synonyms_file.exists():
|
| 65 |
+
# self.load_synonyms()
|
| 66 |
+
# else:
|
| 67 |
+
# print("π No existing synonyms file. Will create new one.")
|
| 68 |
+
|
| 69 |
+
# def load_tags(self):
|
| 70 |
+
# """Load domain-specific tags (optional)"""
|
| 71 |
+
# if self.tags_file.exists():
|
| 72 |
+
# try:
|
| 73 |
+
# with open(self.tags_file, 'r', encoding='utf-8') as f:
|
| 74 |
+
# self.tags_data = json.load(f)
|
| 75 |
+
# print(f"β
Loaded {len(self.tags_data)} tag entries")
|
| 76 |
+
# return True
|
| 77 |
+
# except Exception as e:
|
| 78 |
+
# print(f"β οΈ Could not load tags.json: {e}")
|
| 79 |
+
# else:
|
| 80 |
+
# print(f"βΉοΈ tags.json not found (optional)")
|
| 81 |
+
# return False
|
| 82 |
+
|
| 83 |
+
# def load_synonyms(self):
|
| 84 |
+
# """Load existing synonyms with format conversion"""
|
| 85 |
+
# try:
|
| 86 |
+
# with open(self.synonyms_file, 'rb') as f:
|
| 87 |
+
# loaded = pickle.load(f)
|
| 88 |
+
|
| 89 |
+
# # Handle different formats
|
| 90 |
+
# if not loaded:
|
| 91 |
+
# self.synonyms = {}
|
| 92 |
+
# return
|
| 93 |
+
|
| 94 |
+
# # Check format
|
| 95 |
+
# first_val = next(iter(loaded.values()))
|
| 96 |
+
|
| 97 |
+
# if isinstance(first_val, list):
|
| 98 |
+
# if first_val and isinstance(first_val[0], tuple):
|
| 99 |
+
# # New format: [(syn, conf, src), ...]
|
| 100 |
+
# self.synonyms = loaded
|
| 101 |
+
# print(f"β
Loaded {len(self.synonyms)} synonym entries (new format)")
|
| 102 |
+
# elif first_val and isinstance(first_val[0], str):
|
| 103 |
+
# # Legacy format: [syn1, syn2, ...]
|
| 104 |
+
# self.synonyms = {
|
| 105 |
+
# k: [(v, 0.8, 'legacy') for v in vals]
|
| 106 |
+
# for k, vals in loaded.items()
|
| 107 |
+
# }
|
| 108 |
+
# print(f"β
Converted {len(self.synonyms)} legacy synonym entries")
|
| 109 |
+
# elif isinstance(first_val, set):
|
| 110 |
+
# # Set format
|
| 111 |
+
# self.synonyms = {
|
| 112 |
+
# k: [(v, 0.8, 'legacy') for v in vals]
|
| 113 |
+
# for k, vals in loaded.items()
|
| 114 |
+
# }
|
| 115 |
+
# print(f"β
Converted {len(self.synonyms)} set-based entries")
|
| 116 |
+
# else:
|
| 117 |
+
# self.synonyms = {}
|
| 118 |
+
# print(f"β οΈ Unknown synonym format")
|
| 119 |
+
|
| 120 |
+
# except Exception as e:
|
| 121 |
+
# print(f"β Error loading synonyms: {e}")
|
| 122 |
+
# self.synonyms = {}
|
| 123 |
+
|
| 124 |
+
# def save_synonyms(self):
|
| 125 |
+
# """Save synonyms in both formats"""
|
| 126 |
+
# try:
|
| 127 |
+
# # Save binary format
|
| 128 |
+
# with open(self.synonyms_file, 'wb') as f:
|
| 129 |
+
# pickle.dump(self.synonyms, f)
|
| 130 |
+
|
| 131 |
+
# # Save readable JSON
|
| 132 |
+
# json_file = self.cache_dir / 'synonyms_readable.json'
|
| 133 |
+
# readable = {}
|
| 134 |
+
# for term, syns in self.synonyms.items():
|
| 135 |
+
# readable[term] = [
|
| 136 |
+
# {'synonym': syn, 'confidence': float(conf), 'source': src}
|
| 137 |
+
# for syn, conf, src in syns
|
| 138 |
+
# ]
|
| 139 |
+
|
| 140 |
+
# with open(json_file, 'w', encoding='utf-8') as f:
|
| 141 |
+
# json.dump(readable, f, indent=2, ensure_ascii=False)
|
| 142 |
+
|
| 143 |
+
# print(f"\nβ
Saved {len(self.synonyms)} synonym entries")
|
| 144 |
+
# print(f" π Binary: {self.synonyms_file}")
|
| 145 |
+
# print(f" π JSON: {json_file}")
|
| 146 |
+
# return True
|
| 147 |
+
# except Exception as e:
|
| 148 |
+
# print(f"β Error saving synonyms: {e}")
|
| 149 |
+
# return False
|
| 150 |
+
|
| 151 |
+
# def load_transformer_model(self):
|
| 152 |
+
# """Load e5-base-v2 model with GPU support"""
|
| 153 |
+
# if not TRANSFORMERS_AVAILABLE:
|
| 154 |
+
# print("β SentenceTransformers not installed!")
|
| 155 |
+
# return False
|
| 156 |
+
|
| 157 |
+
# # Check for CUDA
|
| 158 |
+
# self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 159 |
+
|
| 160 |
+
# if self.device == "cuda":
|
| 161 |
+
# print(f"π₯ NVIDIA GPU detected!")
|
| 162 |
+
# try:
|
| 163 |
+
# gpu_name = torch.cuda.get_device_name(0)
|
| 164 |
+
# vram_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
|
| 165 |
+
# print(f" GPU: {gpu_name}")
|
| 166 |
+
# print(f" VRAM: {vram_gb:.1f} GB")
|
| 167 |
+
# except:
|
| 168 |
+
# pass
|
| 169 |
+
# else:
|
| 170 |
+
# print("π» Using CPU (slower)")
|
| 171 |
+
|
| 172 |
+
# # Use e5-base-v2 for better memory efficiency
|
| 173 |
+
# model_name = "intfloat/e5-base-v2"
|
| 174 |
+
# print(f"\nπ€ Loading model: {model_name}")
|
| 175 |
+
|
| 176 |
+
# try:
|
| 177 |
+
# self.model = SentenceTransformer(model_name, device=self.device)
|
| 178 |
+
# self.model.max_seq_length = 256
|
| 179 |
+
|
| 180 |
+
# # Use FP16 on GPU for speed
|
| 181 |
+
# if self.device == "cuda":
|
| 182 |
+
# self.model = self.model.half()
|
| 183 |
+
# print("β‘ Enabled FP16 precision")
|
| 184 |
+
|
| 185 |
+
# print("β
Model loaded successfully\n")
|
| 186 |
+
# return True
|
| 187 |
+
# except Exception as e:
|
| 188 |
+
# print(f"β Failed to load model: {e}")
|
| 189 |
+
# return False
|
| 190 |
+
|
| 191 |
+
# def get_wordnet_synonyms(self, word, limit=10):
|
| 192 |
+
# """Get WordNet synonyms"""
|
| 193 |
+
# if self.fast_mode or not WORDNET_AVAILABLE:
|
| 194 |
+
# return []
|
| 195 |
+
|
| 196 |
+
# try:
|
| 197 |
+
# # Ensure WordNet is downloaded
|
| 198 |
+
# try:
|
| 199 |
+
# wordnet.synsets('test')
|
| 200 |
+
# except:
|
| 201 |
+
# print("π₯ Downloading WordNet data...")
|
| 202 |
+
# nltk_download('wordnet', quiet=True)
|
| 203 |
+
# nltk_download('omw-1.4', quiet=True)
|
| 204 |
+
|
| 205 |
+
# synonyms = []
|
| 206 |
+
# word_clean = word.lower().replace(' ', '_')
|
| 207 |
+
|
| 208 |
+
# for syn in wordnet.synsets(word_clean):
|
| 209 |
+
# for lemma in syn.lemmas():
|
| 210 |
+
# synonym = lemma.name().replace('_', ' ').lower()
|
| 211 |
+
# if synonym != word.lower() and len(synonym) > 2:
|
| 212 |
+
# confidence = 0.75 # Fixed confidence for WordNet
|
| 213 |
+
# synonyms.append((synonym, confidence, 'wordnet'))
|
| 214 |
+
# if len(synonyms) >= limit:
|
| 215 |
+
# break
|
| 216 |
+
# if len(synonyms) >= limit:
|
| 217 |
+
# break
|
| 218 |
+
|
| 219 |
+
# return synonyms[:limit]
|
| 220 |
+
# except Exception:
|
| 221 |
+
# return []
|
| 222 |
+
|
| 223 |
+
# def get_semantic_synonyms(self, term, candidate_pool, threshold=0.70, limit=15):
|
| 224 |
+
# """Get semantic synonyms using embeddings"""
|
| 225 |
+
# if not self.model or not candidate_pool:
|
| 226 |
+
# return []
|
| 227 |
+
|
| 228 |
+
# try:
|
| 229 |
+
# # E5 model requires query/passage prefixes
|
| 230 |
+
# query = f"query: {term}"
|
| 231 |
+
# candidates_prefixed = [f"passage: {c}" for c in candidate_pool]
|
| 232 |
+
|
| 233 |
+
# # Encode query
|
| 234 |
+
# term_emb = self.model.encode(
|
| 235 |
+
# query,
|
| 236 |
+
# convert_to_tensor=True,
|
| 237 |
+
# show_progress_bar=False
|
| 238 |
+
# )
|
| 239 |
+
|
| 240 |
+
# # Encode candidates in batches
|
| 241 |
+
# batch_size = 32 if self.device == "cuda" else 8
|
| 242 |
+
# all_embeddings = []
|
| 243 |
+
|
| 244 |
+
# for i in range(0, len(candidates_prefixed), batch_size):
|
| 245 |
+
# batch = candidates_prefixed[i:i + batch_size]
|
| 246 |
+
# emb = self.model.encode(
|
| 247 |
+
# batch,
|
| 248 |
+
# convert_to_tensor=True,
|
| 249 |
+
# show_progress_bar=False
|
| 250 |
+
# )
|
| 251 |
+
# all_embeddings.append(emb)
|
| 252 |
+
|
| 253 |
+
# # Concatenate all embeddings
|
| 254 |
+
# candidate_embs = torch.cat(all_embeddings, dim=0)
|
| 255 |
+
|
| 256 |
+
# # Calculate cosine similarity
|
| 257 |
+
# scores = util.cos_sim(term_emb, candidate_embs)[0]
|
| 258 |
+
|
| 259 |
+
# # Filter by threshold
|
| 260 |
+
# synonyms = []
|
| 261 |
+
# for candidate, score in zip(candidate_pool, scores):
|
| 262 |
+
# score_val = float(score)
|
| 263 |
+
# if score_val > threshold and candidate.lower() != term.lower():
|
| 264 |
+
# # Scale confidence between 0.6 and 0.95
|
| 265 |
+
# confidence = 0.60 + (score_val - threshold) * 0.35 / (1 - threshold)
|
| 266 |
+
# synonyms.append((candidate, confidence, 'semantic'))
|
| 267 |
+
|
| 268 |
+
# # Sort by confidence
|
| 269 |
+
# synonyms.sort(key=lambda x: x[1], reverse=True)
|
| 270 |
+
# return synonyms[:limit]
|
| 271 |
+
|
| 272 |
+
# except Exception as e:
|
| 273 |
+
# print(f"β οΈ Semantic error: {e}")
|
| 274 |
+
# return []
|
| 275 |
+
|
| 276 |
+
# def auto_generate_synonyms(self, term, candidate_pool=None,
|
| 277 |
+
# semantic_threshold=0.70, silent=False):
|
| 278 |
+
# """Generate synonyms from multiple sources"""
|
| 279 |
+
# all_synonyms = []
|
| 280 |
+
|
| 281 |
+
# if not silent:
|
| 282 |
+
# print(f"\nπ Finding synonyms for: '{term}'")
|
| 283 |
+
|
| 284 |
+
# # Source 1: WordNet
|
| 285 |
+
# if WORDNET_AVAILABLE and not self.fast_mode:
|
| 286 |
+
# wn_syns = self.get_wordnet_synonyms(term, limit=10)
|
| 287 |
+
# all_synonyms.extend(wn_syns)
|
| 288 |
+
|
| 289 |
+
# # Source 2: Semantic similarity
|
| 290 |
+
# if candidate_pool and self.model:
|
| 291 |
+
# sem_syns = self.get_semantic_synonyms(
|
| 292 |
+
# term, candidate_pool,
|
| 293 |
+
# threshold=semantic_threshold,
|
| 294 |
+
# limit=15
|
| 295 |
+
# )
|
| 296 |
+
# all_synonyms.extend(sem_syns)
|
| 297 |
+
|
| 298 |
+
# # Deduplicate (keep highest confidence)
|
| 299 |
+
# synonym_map = {}
|
| 300 |
+
# for syn, conf, source in all_synonyms:
|
| 301 |
+
# syn_lower = syn.lower()
|
| 302 |
+
# if syn_lower not in synonym_map or conf > synonym_map[syn_lower][1]:
|
| 303 |
+
# synonym_map[syn_lower] = (syn, conf, source)
|
| 304 |
+
|
| 305 |
+
# final_synonyms = sorted(
|
| 306 |
+
# synonym_map.values(),
|
| 307 |
+
# key=lambda x: x[1],
|
| 308 |
+
# reverse=True
|
| 309 |
+
# )
|
| 310 |
+
|
| 311 |
+
# return final_synonyms
|
| 312 |
+
|
| 313 |
+
# def add_synonym_group(self, term, synonyms_with_confidence):
|
| 314 |
+
# """Add synonym group"""
|
| 315 |
+
# term_lower = term.lower()
|
| 316 |
+
# if term_lower not in self.synonyms:
|
| 317 |
+
# self.synonyms[term_lower] = []
|
| 318 |
+
|
| 319 |
+
# for syn, conf, src in synonyms_with_confidence:
|
| 320 |
+
# # Check if already exists
|
| 321 |
+
# if not any(s[0].lower() == syn.lower() for s in self.synonyms[term_lower]):
|
| 322 |
+
# self.synonyms[term_lower].append((syn, conf, src))
|
| 323 |
+
|
| 324 |
+
# def extract_terms_from_categories(self, csv_path, min_frequency=2):
|
| 325 |
+
# """Extract terms from category CSV"""
|
| 326 |
+
# print(f"\nπ Extracting terms from: {csv_path}")
|
| 327 |
+
|
| 328 |
+
# try:
|
| 329 |
+
# import pandas as pd
|
| 330 |
+
|
| 331 |
+
# # Read CSV
|
| 332 |
+
# df = pd.read_csv(csv_path)
|
| 333 |
+
|
| 334 |
+
# # Find path column (usually second column)
|
| 335 |
+
# path_col = df.columns[1] if len(df.columns) > 1 else df.columns[0]
|
| 336 |
+
# paths = df[path_col].dropna().astype(str)
|
| 337 |
+
|
| 338 |
+
# print(f" Processing {len(paths):,} category paths...")
|
| 339 |
+
|
| 340 |
+
# term_freq = defaultdict(int)
|
| 341 |
+
|
| 342 |
+
# for path in tqdm(paths, desc="Analyzing paths"):
|
| 343 |
+
# levels = path.split('/')
|
| 344 |
+
|
| 345 |
+
# for level in levels:
|
| 346 |
+
# words = level.lower().split()
|
| 347 |
+
|
| 348 |
+
# # Single words
|
| 349 |
+
# for word in words:
|
| 350 |
+
# if len(word) > 2 and word.isalpha():
|
| 351 |
+
# term_freq[word] += 1
|
| 352 |
+
|
| 353 |
+
# # Two-word phrases
|
| 354 |
+
# for i in range(len(words) - 1):
|
| 355 |
+
# if len(words[i]) > 2 and len(words[i+1]) > 2:
|
| 356 |
+
# phrase = f"{words[i]} {words[i+1]}"
|
| 357 |
+
# if phrase.replace(' ', '').isalpha():
|
| 358 |
+
# term_freq[phrase] += 1
|
| 359 |
+
|
| 360 |
+
# # Filter by frequency
|
| 361 |
+
# candidates = [
|
| 362 |
+
# term for term, freq in term_freq.items()
|
| 363 |
+
# if freq >= min_frequency
|
| 364 |
+
# ]
|
| 365 |
+
|
| 366 |
+
# print(f"β
Extracted {len(candidates):,} terms (min frequency: {min_frequency})")
|
| 367 |
+
# return candidates, term_freq
|
| 368 |
+
|
| 369 |
+
# except Exception as e:
|
| 370 |
+
# print(f"β Error extracting terms: {e}")
|
| 371 |
+
# import traceback
|
| 372 |
+
# traceback.print_exc()
|
| 373 |
+
# return [], {}
|
| 374 |
+
|
| 375 |
+
# def auto_build_from_categories(self, csv_path, top_terms=1000,
|
| 376 |
+
# semantic_threshold=0.70):
|
| 377 |
+
# """Auto-build synonym database from categories"""
|
| 378 |
+
# print("\n" + "="*80)
|
| 379 |
+
# print("π AUTO-BUILD SYNONYM DATABASE")
|
| 380 |
+
# print("="*80)
|
| 381 |
+
|
| 382 |
+
# # Load model
|
| 383 |
+
# if not self.load_transformer_model():
|
| 384 |
+
# print("\nβ οΈ Continuing with WordNet only (limited coverage)")
|
| 385 |
+
|
| 386 |
+
# # Extract terms
|
| 387 |
+
# all_terms, term_freq = self.extract_terms_from_categories(csv_path)
|
| 388 |
+
# if not all_terms:
|
| 389 |
+
# print("β No terms extracted")
|
| 390 |
+
# return False
|
| 391 |
+
|
| 392 |
+
# # Select top terms
|
| 393 |
+
# print(f"\nπ― Selecting top {top_terms} terms...")
|
| 394 |
+
# top_frequent = sorted(
|
| 395 |
+
# term_freq.items(),
|
| 396 |
+
# key=lambda x: x[1],
|
| 397 |
+
# reverse=True
|
| 398 |
+
# )[:top_terms]
|
| 399 |
+
# terms_to_process = [term for term, _ in top_frequent]
|
| 400 |
+
|
| 401 |
+
# print(f"β
Selected {len(terms_to_process)} terms")
|
| 402 |
+
# print(f"π Top 10: {', '.join(terms_to_process[:10])}")
|
| 403 |
+
# print(f"\nπ Generating synonyms (threshold={semantic_threshold})...\n")
|
| 404 |
+
|
| 405 |
+
# # Process terms
|
| 406 |
+
# stats = {
|
| 407 |
+
# 'processed': 0,
|
| 408 |
+
# 'synonyms': 0,
|
| 409 |
+
# 'high_conf': 0
|
| 410 |
+
# }
|
| 411 |
+
|
| 412 |
+
# for term in tqdm(terms_to_process, desc="Processing"):
|
| 413 |
+
# # Skip if already has enough synonyms
|
| 414 |
+
# if term in self.synonyms and len(self.synonyms[term]) >= 10:
|
| 415 |
+
# continue
|
| 416 |
+
|
| 417 |
+
# # Generate synonyms
|
| 418 |
+
# syns = self.auto_generate_synonyms(
|
| 419 |
+
# term,
|
| 420 |
+
# candidate_pool=all_terms,
|
| 421 |
+
# semantic_threshold=semantic_threshold,
|
| 422 |
+
# silent=True
|
| 423 |
+
# )
|
| 424 |
+
|
| 425 |
+
# if syns:
|
| 426 |
+
# self.add_synonym_group(term, syns)
|
| 427 |
+
# stats['processed'] += 1
|
| 428 |
+
# stats['synonyms'] += len(syns)
|
| 429 |
+
# stats['high_conf'] += sum(1 for _, c, _ in syns if c >= 0.8)
|
| 430 |
+
|
| 431 |
+
# # Print stats
|
| 432 |
+
# print(f"\nβ
Processed: {stats['processed']:,} terms")
|
| 433 |
+
# print(f"β
Total synonyms: {stats['synonyms']:,}")
|
| 434 |
+
# print(f"β
High confidence (β₯0.8): {stats['high_conf']:,}")
|
| 435 |
+
|
| 436 |
+
# # Save
|
| 437 |
+
# self.save_synonyms()
|
| 438 |
+
|
| 439 |
+
# print("\nπ AUTO-BUILD COMPLETE!\n")
|
| 440 |
+
# return True
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
# def main():
|
| 444 |
+
# """Main entry point"""
|
| 445 |
+
# print("\n" + "="*80)
|
| 446 |
+
# print("π€ AI-POWERED SYNONYM MANAGER (Windows + NVIDIA GPU)")
|
| 447 |
+
# print("="*80 + "\n")
|
| 448 |
+
|
| 449 |
+
# # Parse arguments
|
| 450 |
+
# fast_mode = '--fast' in sys.argv
|
| 451 |
+
|
| 452 |
+
# if len(sys.argv) < 2:
|
| 453 |
+
# print("Usage:")
|
| 454 |
+
# print(" python synonym_manager_fixed.py autobuild <csv_file>")
|
| 455 |
+
# print(" python synonym_manager_fixed.py autobuild <csv_file> --fast")
|
| 456 |
+
# print("\nExample:")
|
| 457 |
+
# print(" python synonym_manager_fixed.py autobuild data/category_id_path_only.csv")
|
| 458 |
+
# return
|
| 459 |
+
|
| 460 |
+
# command = sys.argv[1].lower()
|
| 461 |
+
|
| 462 |
+
# if command == 'autobuild':
|
| 463 |
+
# if len(sys.argv) < 3:
|
| 464 |
+
# print("β CSV file path required")
|
| 465 |
+
# return
|
| 466 |
+
|
| 467 |
+
# csv_path = sys.argv[2]
|
| 468 |
+
|
| 469 |
+
# if not Path(csv_path).exists():
|
| 470 |
+
# print(f"β File not found: {csv_path}")
|
| 471 |
+
# return
|
| 472 |
+
|
| 473 |
+
# # Initialize manager
|
| 474 |
+
# manager = FixedAISynonymManager(fast_mode=fast_mode)
|
| 475 |
+
|
| 476 |
+
# # Run auto-build
|
| 477 |
+
# manager.auto_build_from_categories(csv_path, top_terms=1000)
|
| 478 |
+
|
| 479 |
+
# else:
|
| 480 |
+
# print(f"β Unknown command: {command}")
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
# if __name__ == "__main__":
|
| 484 |
+
# main()
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
#for cache2
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
"""
|
| 491 |
+
π€ AI-POWERED SYNONYM MANAGER (Fixed for Windows + GPU)
|
| 492 |
+
========================================================
|
| 493 |
+
β
Uses e5-base-v2 (768D, memory-efficient)
|
| 494 |
+
β
Windows + NVIDIA GPU optimized
|
| 495 |
+
β
Generates cross-store synonyms automatically
|
| 496 |
+
|
| 497 |
+
Usage:
|
| 498 |
+
python synonym_manager_fixed.py autobuild data/category_id_path_only.csv
|
| 499 |
+
python synonym_manager_fixed.py autobuild data/category_id_path_only.csv --fast
|
| 500 |
+
"""
|
| 501 |
+
|
| 502 |
+
import pickle
|
| 503 |
+
from pathlib import Path
|
| 504 |
+
import json
|
| 505 |
+
from collections import defaultdict
|
| 506 |
+
from tqdm import tqdm
|
| 507 |
+
import warnings
|
| 508 |
+
import sys
|
| 509 |
+
import os
|
| 510 |
+
|
| 511 |
+
warnings.filterwarnings('ignore')
|
| 512 |
+
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
|
| 513 |
+
|
| 514 |
+
try:
|
| 515 |
+
from nltk.corpus import wordnet
|
| 516 |
+
from nltk import download as nltk_download
|
| 517 |
+
WORDNET_AVAILABLE = True
|
| 518 |
+
except ImportError:
|
| 519 |
+
WORDNET_AVAILABLE = False
|
| 520 |
+
|
| 521 |
+
try:
|
| 522 |
+
from sentence_transformers import SentenceTransformer, util
|
| 523 |
+
import torch
|
| 524 |
+
TRANSFORMERS_AVAILABLE = True
|
| 525 |
+
except ImportError:
|
| 526 |
+
TRANSFORMERS_AVAILABLE = False
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
class SynonymManager:
|
| 530 |
+
"""AI-powered synonym manager"""
|
| 531 |
+
|
| 532 |
+
def __init__(self, cache_dir='cache', fast_mode=False):
    """Set up cache paths and eagerly load any previously saved synonyms."""
    self.cache_dir = Path(cache_dir)                                  # root for cache artifacts
    self.synonyms_file = self.cache_dir / 'cross_store_synonyms.pkl'  # pickled synonym store
    self.synonyms = {}        # term -> [(synonym, confidence, source), ...]
    self.model = None         # sentence-transformer, loaded lazily
    self.device = "cpu"       # may become "cuda" in load_transformer_model()
    self.fast_mode = fast_mode

    self.cache_dir.mkdir(parents=True, exist_ok=True)

    if self.synonyms_file.exists():
        self.load_synonyms()
|
| 544 |
+
|
| 545 |
+
def load_synonyms(self):
    """Load the pickled synonym store, upgrading legacy list/set formats.

    New format maps term -> [(synonym, confidence, source), ...]; legacy
    formats map term -> [synonym, ...] or term -> {synonym, ...} and are
    wrapped with a default 0.8 confidence and 'legacy' source tag.
    """
    try:
        with open(self.synonyms_file, 'rb') as f:
            loaded = pickle.load(f)

        def _upgrade(mapping):
            # Bare-string entries get a fixed confidence and source marker.
            return {
                key: [(syn, 0.8, 'legacy') for syn in syns]
                for key, syns in mapping.items()
            }

        if loaded and list(loaded.values()):
            sample = next(iter(loaded.values()))

            if isinstance(sample, list) and sample:
                if isinstance(sample[0], tuple):
                    self.synonyms = loaded          # already new format
                else:
                    self.synonyms = _upgrade(loaded)
            elif isinstance(sample, set):
                self.synonyms = _upgrade(loaded)

        print(f"✅ Loaded {len(self.synonyms):,} synonym entries")
    except Exception as e:
        print(f"❌ Error loading synonyms: {e}")
        self.synonyms = {}
|
| 566 |
+
|
| 567 |
+
def save_synonyms(self):
|
| 568 |
+
"""Save synonyms"""
|
| 569 |
+
try:
|
| 570 |
+
with open(self.synonyms_file, 'wb') as f:
|
| 571 |
+
pickle.dump(self.synonyms, f)
|
| 572 |
+
|
| 573 |
+
json_file = self.cache_dir / 'synonyms_readable.json'
|
| 574 |
+
readable = {
|
| 575 |
+
term: [
|
| 576 |
+
{'synonym': syn, 'confidence': conf, 'source': src}
|
| 577 |
+
for syn, conf, src in syns
|
| 578 |
+
]
|
| 579 |
+
for term, syns in self.synonyms.items()
|
| 580 |
+
}
|
| 581 |
+
with open(json_file, 'w', encoding='utf-8') as f:
|
| 582 |
+
json.dump(readable, f, indent=2, ensure_ascii=False)
|
| 583 |
+
|
| 584 |
+
print(f"β
Saved {len(self.synonyms):,} synonym entries")
|
| 585 |
+
return True
|
| 586 |
+
except Exception as e:
|
| 587 |
+
print(f"β Error saving synonyms: {e}")
|
| 588 |
+
return False
|
| 589 |
+
|
| 590 |
+
def load_transformer_model(self):
|
| 591 |
+
"""Load e5-base-v2 model"""
|
| 592 |
+
if not TRANSFORMERS_AVAILABLE:
|
| 593 |
+
print("β SentenceTransformers not installed!")
|
| 594 |
+
return False
|
| 595 |
+
|
| 596 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 597 |
+
|
| 598 |
+
if self.device == "cuda":
|
| 599 |
+
print(f"π₯ NVIDIA GPU detected!")
|
| 600 |
+
|
| 601 |
+
model_name = "intfloat/e5-base-v2"
|
| 602 |
+
print(f"\nπ€ Loading {model_name}...")
|
| 603 |
+
|
| 604 |
+
try:
|
| 605 |
+
self.model = SentenceTransformer(model_name, device=self.device)
|
| 606 |
+
|
| 607 |
+
if self.device == "cuda":
|
| 608 |
+
self.model = self.model.half()
|
| 609 |
+
print("β‘ Enabled FP16 precision")
|
| 610 |
+
|
| 611 |
+
print("β
Model loaded\n")
|
| 612 |
+
return True
|
| 613 |
+
except Exception as e:
|
| 614 |
+
print(f"β Failed to load model: {e}")
|
| 615 |
+
return False
|
| 616 |
+
|
| 617 |
+
def get_wordnet_synonyms(self, word, limit=10):
|
| 618 |
+
"""Get WordNet synonyms"""
|
| 619 |
+
if self.fast_mode or not WORDNET_AVAILABLE:
|
| 620 |
+
return []
|
| 621 |
+
|
| 622 |
+
try:
|
| 623 |
+
try:
|
| 624 |
+
wordnet.synsets('test')
|
| 625 |
+
except:
|
| 626 |
+
nltk_download('wordnet', quiet=True)
|
| 627 |
+
nltk_download('omw-1.4', quiet=True)
|
| 628 |
+
|
| 629 |
+
synonyms = []
|
| 630 |
+
word_clean = word.lower().replace(' ', '_')
|
| 631 |
+
|
| 632 |
+
for syn in wordnet.synsets(word_clean):
|
| 633 |
+
for lemma in syn.lemmas():
|
| 634 |
+
synonym = lemma.name().replace('_', ' ').lower()
|
| 635 |
+
if synonym != word.lower() and len(synonym) > 2:
|
| 636 |
+
confidence = 0.75
|
| 637 |
+
synonyms.append((synonym, confidence, 'wordnet'))
|
| 638 |
+
if len(synonyms) >= limit:
|
| 639 |
+
break
|
| 640 |
+
if len(synonyms) >= limit:
|
| 641 |
+
break
|
| 642 |
+
|
| 643 |
+
return synonyms[:limit]
|
| 644 |
+
except Exception:
|
| 645 |
+
return []
|
| 646 |
+
|
| 647 |
+
def get_semantic_synonyms(self, term, candidate_pool, threshold=0.70, limit=15):
|
| 648 |
+
"""Get semantic synonyms using E5"""
|
| 649 |
+
if not self.model or not candidate_pool:
|
| 650 |
+
return []
|
| 651 |
+
|
| 652 |
+
try:
|
| 653 |
+
query = f"query: {term}"
|
| 654 |
+
candidates_prefixed = [f"passage: {c}" for c in candidate_pool]
|
| 655 |
+
|
| 656 |
+
term_emb = self.model.encode(query, convert_to_tensor=True, show_progress_bar=False)
|
| 657 |
+
|
| 658 |
+
batch_size = 32 if self.device == "cuda" else 8
|
| 659 |
+
all_embeddings = []
|
| 660 |
+
|
| 661 |
+
for i in range(0, len(candidates_prefixed), batch_size):
|
| 662 |
+
batch = candidates_prefixed[i:i + batch_size]
|
| 663 |
+
emb = self.model.encode(batch, convert_to_tensor=True, show_progress_bar=False)
|
| 664 |
+
all_embeddings.append(emb)
|
| 665 |
+
|
| 666 |
+
candidate_embs = torch.cat(all_embeddings, dim=0)
|
| 667 |
+
scores = util.cos_sim(term_emb, candidate_embs)[0]
|
| 668 |
+
|
| 669 |
+
synonyms = []
|
| 670 |
+
for candidate, score in zip(candidate_pool, scores):
|
| 671 |
+
score_val = float(score)
|
| 672 |
+
if score_val > threshold and candidate.lower() != term.lower():
|
| 673 |
+
confidence = 0.60 + (score_val - threshold) * 0.35 / (1 - threshold)
|
| 674 |
+
synonyms.append((candidate, confidence, 'semantic'))
|
| 675 |
+
|
| 676 |
+
synonyms.sort(key=lambda x: x[1], reverse=True)
|
| 677 |
+
return synonyms[:limit]
|
| 678 |
+
|
| 679 |
+
except Exception as e:
|
| 680 |
+
print(f"β οΈ Semantic error: {e}")
|
| 681 |
+
return []
|
| 682 |
+
|
| 683 |
+
def auto_generate_synonyms(self, term, candidate_pool=None, semantic_threshold=0.70, silent=False):
|
| 684 |
+
"""Generate synonyms from multiple sources"""
|
| 685 |
+
all_synonyms = []
|
| 686 |
+
|
| 687 |
+
if not silent:
|
| 688 |
+
print(f"\nπ Finding synonyms for: '{term}'")
|
| 689 |
+
|
| 690 |
+
if WORDNET_AVAILABLE and not self.fast_mode:
|
| 691 |
+
wn_syns = self.get_wordnet_synonyms(term, limit=10)
|
| 692 |
+
all_synonyms.extend(wn_syns)
|
| 693 |
+
|
| 694 |
+
if candidate_pool and self.model:
|
| 695 |
+
sem_syns = self.get_semantic_synonyms(
|
| 696 |
+
term, candidate_pool,
|
| 697 |
+
threshold=semantic_threshold,
|
| 698 |
+
limit=15
|
| 699 |
+
)
|
| 700 |
+
all_synonyms.extend(sem_syns)
|
| 701 |
+
|
| 702 |
+
synonym_map = {}
|
| 703 |
+
for syn, conf, source in all_synonyms:
|
| 704 |
+
syn_lower = syn.lower()
|
| 705 |
+
if syn_lower not in synonym_map or conf > synonym_map[syn_lower][1]:
|
| 706 |
+
synonym_map[syn_lower] = (syn, conf, source)
|
| 707 |
+
|
| 708 |
+
final_synonyms = sorted(synonym_map.values(), key=lambda x: x[1], reverse=True)
|
| 709 |
+
return final_synonyms
|
| 710 |
+
|
| 711 |
+
def add_synonym_group(self, term, synonyms_with_confidence):
|
| 712 |
+
"""Add synonym group"""
|
| 713 |
+
term_lower = term.lower()
|
| 714 |
+
if term_lower not in self.synonyms:
|
| 715 |
+
self.synonyms[term_lower] = []
|
| 716 |
+
|
| 717 |
+
for syn, conf, src in synonyms_with_confidence:
|
| 718 |
+
if not any(s[0].lower() == syn.lower() for s in self.synonyms[term_lower]):
|
| 719 |
+
self.synonyms[term_lower].append((syn, conf, src))
|
| 720 |
+
|
| 721 |
+
def extract_terms_from_categories(self, csv_path, min_frequency=2):
|
| 722 |
+
"""Extract terms from category CSV"""
|
| 723 |
+
print(f"\nπ Extracting terms from: {csv_path}")
|
| 724 |
+
|
| 725 |
+
try:
|
| 726 |
+
import pandas as pd
|
| 727 |
+
|
| 728 |
+
df = pd.read_csv(csv_path)
|
| 729 |
+
path_col = df.columns[1] if len(df.columns) > 1 else df.columns[0]
|
| 730 |
+
paths = df[path_col].dropna().astype(str)
|
| 731 |
+
|
| 732 |
+
print(f" Processing {len(paths):,} category paths...")
|
| 733 |
+
|
| 734 |
+
term_freq = defaultdict(int)
|
| 735 |
+
|
| 736 |
+
for path in tqdm(paths, desc="Analyzing paths"):
|
| 737 |
+
levels = path.split('/')
|
| 738 |
+
|
| 739 |
+
for level in levels:
|
| 740 |
+
words = level.lower().split()
|
| 741 |
+
|
| 742 |
+
for word in words:
|
| 743 |
+
if len(word) > 2 and word.isalpha():
|
| 744 |
+
term_freq[word] += 1
|
| 745 |
+
|
| 746 |
+
for i in range(len(words) - 1):
|
| 747 |
+
if len(words[i]) > 2 and len(words[i+1]) > 2:
|
| 748 |
+
phrase = f"{words[i]} {words[i+1]}"
|
| 749 |
+
if phrase.replace(' ', '').isalpha():
|
| 750 |
+
term_freq[phrase] += 1
|
| 751 |
+
|
| 752 |
+
candidates = [
|
| 753 |
+
term for term, freq in term_freq.items()
|
| 754 |
+
if freq >= min_frequency
|
| 755 |
+
]
|
| 756 |
+
|
| 757 |
+
print(f"β
Extracted {len(candidates):,} terms (min frequency: {min_frequency})")
|
| 758 |
+
return candidates, term_freq
|
| 759 |
+
|
| 760 |
+
except Exception as e:
|
| 761 |
+
print(f"β Error extracting terms: {e}")
|
| 762 |
+
import traceback
|
| 763 |
+
traceback.print_exc()
|
| 764 |
+
return [], {}
|
| 765 |
+
|
| 766 |
+
def auto_build_from_categories(self, csv_path, top_terms=1000, semantic_threshold=0.70):
|
| 767 |
+
"""Auto-build synonym database"""
|
| 768 |
+
print("\n" + "="*80)
|
| 769 |
+
print("π AUTO-BUILD SYNONYM DATABASE")
|
| 770 |
+
print("="*80)
|
| 771 |
+
|
| 772 |
+
if not self.load_transformer_model():
|
| 773 |
+
print("\nβ οΈ Continuing with WordNet only")
|
| 774 |
+
|
| 775 |
+
all_terms, term_freq = self.extract_terms_from_categories(csv_path)
|
| 776 |
+
if not all_terms:
|
| 777 |
+
print("β No terms extracted")
|
| 778 |
+
return False
|
| 779 |
+
|
| 780 |
+
print(f"\nπ― Selecting top {top_terms} terms...")
|
| 781 |
+
top_frequent = sorted(term_freq.items(), key=lambda x: x[1], reverse=True)[:top_terms]
|
| 782 |
+
terms_to_process = [term for term, _ in top_frequent]
|
| 783 |
+
|
| 784 |
+
print(f"β
Selected {len(terms_to_process)} terms")
|
| 785 |
+
print(f"π Top 10: {', '.join(terms_to_process[:10])}")
|
| 786 |
+
print(f"\nπ Generating synonyms (threshold={semantic_threshold})...\n")
|
| 787 |
+
|
| 788 |
+
stats = {'processed': 0, 'synonyms': 0, 'high_conf': 0}
|
| 789 |
+
|
| 790 |
+
for term in tqdm(terms_to_process, desc="Processing"):
|
| 791 |
+
if term in self.synonyms and len(self.synonyms[term]) >= 10:
|
| 792 |
+
continue
|
| 793 |
+
|
| 794 |
+
syns = self.auto_generate_synonyms(
|
| 795 |
+
term,
|
| 796 |
+
candidate_pool=all_terms,
|
| 797 |
+
semantic_threshold=semantic_threshold,
|
| 798 |
+
silent=True
|
| 799 |
+
)
|
| 800 |
+
|
| 801 |
+
if syns:
|
| 802 |
+
self.add_synonym_group(term, syns)
|
| 803 |
+
stats['processed'] += 1
|
| 804 |
+
stats['synonyms'] += len(syns)
|
| 805 |
+
stats['high_conf'] += sum(1 for _, c, _ in syns if c >= 0.8)
|
| 806 |
+
|
| 807 |
+
print(f"\nβ
Processed: {stats['processed']:,} terms")
|
| 808 |
+
print(f"β
Total synonyms: {stats['synonyms']:,}")
|
| 809 |
+
print(f"β
High confidence (β₯0.8): {stats['high_conf']:,}")
|
| 810 |
+
|
| 811 |
+
self.save_synonyms()
|
| 812 |
+
|
| 813 |
+
print("\nπ AUTO-BUILD COMPLETE!\n")
|
| 814 |
+
return True
|
| 815 |
+
|
| 816 |
+
|
| 817 |
+
def main():
    """CLI entry point: parse argv and run the requested sub-command."""
    banner = "=" * 80
    print("\n" + banner)
    print("π€ AI-POWERED SYNONYM MANAGER")
    print(banner + "\n")

    use_fast = '--fast' in sys.argv

    # No sub-command at all: show usage and bail out.
    if len(sys.argv) < 2:
        print("Usage:")
        print(" python synonym_manager_fixed.py autobuild <csv_file>")
        print(" python synonym_manager_fixed.py autobuild <csv_file> --fast")
        print("\nExample:")
        print(" python synonym_manager_fixed.py autobuild data/category_id_path_only.csv")
        return

    cmd = sys.argv[1].lower()

    # Guard clauses replace the original nested if/else dispatch.
    if cmd != 'autobuild':
        print(f"β Unknown command: {cmd}")
        return

    if len(sys.argv) < 3:
        print("β CSV file path required")
        return

    source_csv = sys.argv[2]
    if not Path(source_csv).exists():
        print(f"β File not found: {source_csv}")
        return

    mgr = SynonymManager(fast_mode=use_fast)
    mgr.auto_build_from_categories(source_csv, top_terms=1000)


if __name__ == "__main__":
    main()
|
train_products.py
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
train.py
|
| 4 |
+
Build normalized embeddings + FAISS index for category catalog,
|
| 5 |
+
build parent embeddings, save synonyms from tags.json and optionally
|
| 6 |
+
train a LightGBM classifier and a simple confidence calibrator.
|
| 7 |
+
|
| 8 |
+
Assumptions / Files:
|
| 9 |
+
- categories CSV: category_only_path.csv (Category_ID,Category_path,Final_Category)
|
| 10 |
+
- optional: data/tags.json (map category_id -> list of phrases)
|
| 11 |
+
- optional: validation.csv (columns: product_title,category_id) used for calibrator / classifier
|
| 12 |
+
|
| 13 |
+
Outputs to ./cache:
|
| 14 |
+
- main_index.faiss
|
| 15 |
+
- metadata.pkl
|
| 16 |
+
- parent_embeddings.pkl
|
| 17 |
+
- cross_store_synonyms.pkl
|
| 18 |
+
- model_info.json
|
| 19 |
+
- calibrator.pkl (if validation exists)
|
| 20 |
+
- classifier.pkl (if --train-classifier used)
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
import json
|
| 25 |
+
import os
|
| 26 |
+
import pickle
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
from typing import List, Dict
|
| 29 |
+
|
| 30 |
+
import numpy as np
|
| 31 |
+
import pandas as pd
|
| 32 |
+
from tqdm import tqdm
|
| 33 |
+
|
| 34 |
+
# sentence-transformers + faiss
|
| 35 |
+
from sentence_transformers import SentenceTransformer
|
| 36 |
+
import faiss
|
| 37 |
+
|
| 38 |
+
# sklearn for calibrator and simple preprocessing
|
| 39 |
+
from sklearn.linear_model import LogisticRegression
|
| 40 |
+
from sklearn.preprocessing import StandardScaler
|
| 41 |
+
from sklearn.model_selection import train_test_split
|
| 42 |
+
|
| 43 |
+
# optional LightGBM (install if you plan to train classifier)
|
| 44 |
+
try:
|
| 45 |
+
import importlib
|
| 46 |
+
lgb = importlib.import_module("lightgbm")
|
| 47 |
+
LGB_AVAILABLE = True
|
| 48 |
+
except Exception:
|
| 49 |
+
lgb = None
|
| 50 |
+
LGB_AVAILABLE = False
|
| 51 |
+
|
| 52 |
+
CACHE_DIR = Path("cache")
|
| 53 |
+
CACHE_DIR.mkdir(exist_ok=True, parents=True)
|
| 54 |
+
|
| 55 |
+
DEFAULT_BATCH_SIZE_CPU = 256
|
| 56 |
+
DEFAULT_BATCH_SIZE_GPU = 16
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def normalize_path_sep(path: str) -> str:
    """Canonicalize a category path to ' > '-separated trimmed segments.

    '/' separators are treated as '>'; empty segments are dropped.
    Non-string input yields the empty string.
    """
    if not isinstance(path, str):
        return ""
    # Unify both separator styles, then rebuild from the non-empty pieces.
    unified = path.strip().replace("/", " > ")
    segments = (piece.strip() for piece in unified.split(">"))
    return " > ".join(seg for seg in segments if seg)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def path_to_levels(path: str) -> List[str]:
    """Split a raw category path into its ordered, cleaned hierarchy levels."""
    canonical = normalize_path_sep(path)
    levels = []
    for segment in canonical.split(" > "):
        cleaned = segment.strip()
        if cleaned:
            levels.append(cleaned)
    return levels
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def safe_pickle_save(obj, p: Path):
    """Serialize ``obj`` to file ``p`` with pickle."""
    p.write_bytes(pickle.dumps(obj))
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def build_encoder(model_name: str, use_cuda: bool):
    """Instantiate the sentence-transformer encoder on CPU or GPU.

    On GPU, the model is cast to FP16 (best effort) to conserve VRAM.
    """
    target = "cuda" if use_cuda else "cpu"
    print(f"Loading encoder: {model_name} on {target}")
    encoder = SentenceTransformer(model_name, device=target)
    if not use_cuda:
        return encoder
    try:
        import torch  # verify torch is importable before the FP16 cast
        encoder = encoder.half()
        print("Using FP16 on GPU to conserve VRAM.")
    except Exception:
        # Best effort only — fall back to full precision silently.
        pass
    return encoder
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def encode_texts(model: SentenceTransformer, texts: List[str], use_cuda: bool) -> np.ndarray:
    """Encode ``texts`` into L2-normalized float32 embeddings, in batches.

    Batch size depends on the device (small on GPU to fit limited VRAM).
    Returns a (len(texts), dim) array.
    """
    step = DEFAULT_BATCH_SIZE_GPU if use_cuda else DEFAULT_BATCH_SIZE_CPU
    print(f"Encoding {len(texts):,} texts in batches of {step} ...")
    chunks = []
    for start in tqdm(range(0, len(texts), step)):
        vecs = model.encode(texts[start:start + step], convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=False)
        # A single text comes back 1-D; keep everything 2-D for vstack.
        if vecs.ndim == 1:
            vecs = vecs.reshape(1, -1)
        chunks.append(vecs.astype("float32"))
    embeddings = np.vstack(chunks)
    print("Final embeddings shape:", embeddings.shape)
    return embeddings
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def build_faiss_index(np_emb: np.ndarray, use_gpu: bool = False):
    """Build a flat inner-product FAISS index over the given embeddings.

    Vectors are expected to be L2-normalized, so inner product == cosine.
    GPU conversion is best effort; on failure the CPU index is used.
    """
    dim = np_emb.shape[1]
    print(f"Building IndexFlatIP (d={dim}) on {'GPU' if use_gpu else 'CPU'}")
    index = faiss.IndexFlatIP(dim)
    if use_gpu:
        try:
            gpu_res = faiss.StandardGpuResources()
            index = faiss.index_cpu_to_gpu(gpu_res, 0, index)
            print("Converted FAISS index to GPU")
        except Exception as e:
            print("GPU conversion failed; using CPU index:", e)
    index.add(np_emb)
    print("Index ntotal:", index.ntotal)
    return index
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def make_parent_embeddings(metadata: List[Dict], embeddings: np.ndarray) -> Dict[str, np.ndarray]:
    """Average child embeddings for every parent path prefix.

    For each category, every proper prefix of its ``levels`` list (joined with
    " > ") is a parent; that parent's embedding is the L2-normalized mean of
    all its descendants' embeddings. Used for hierarchical boosting during
    inference.

    Fix: removed the unused ``from numpy.linalg import norm`` import — the
    code computes norms via ``np.linalg.norm``.

    Args:
        metadata: One dict per category; reads the ``levels`` key (list of
            path components). Entries without it contribute nothing.
        embeddings: (len(metadata), dim) array aligned row-for-row with
            ``metadata``; assumed float-valued.

    Returns:
        Mapping parent path -> normalized float32 mean embedding. Empty when
        no category is deeper than one level.
    """
    dim = embeddings.shape[1]
    sum_map: Dict[str, np.ndarray] = {}
    count_map: Dict[str, int] = {}

    for i, meta in enumerate(metadata):
        levels = meta.get("levels", [])
        # Every proper prefix of the path is a parent of this category.
        for depth in range(1, len(levels)):
            parent = " > ".join(levels[:depth])
            if not parent:
                continue
            if parent not in sum_map:
                sum_map[parent] = np.zeros(dim, dtype="float32")
                count_map[parent] = 0
            sum_map[parent] += embeddings[i]
            count_map[parent] += 1

    # Average, then L2-normalize (epsilon guards against a zero vector).
    final: Dict[str, np.ndarray] = {}
    for parent, total in sum_map.items():
        avg = total / float(count_map.get(parent, 1))
        final[parent] = (avg / (np.linalg.norm(avg) + 1e-12)).astype("float32")
    return final
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def load_tags_json(path: Path) -> Dict[str, List[str]]:
    """Read tags.json mapping category_id -> phrase list.

    Keys and phrases are coerced to ``str``. A missing file or any
    parse/shape error yields an empty dict (errors are printed, not raised).
    """
    if not path.exists():
        return {}
    try:
        with open(path, "r", encoding="utf-8") as f:
            raw = json.load(f)
        # Coerce everything to strings so downstream code can rely on types.
        return {str(cid): [str(phrase) for phrase in phrases] for cid, phrases in raw.items()}
    except Exception as e:
        print("Failed to load tags.json:", e)
        return {}
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def train_calibrator(encoder, metadata, faiss_index, val_path: Path, model_name: str, use_cuda: bool):
    """
    Build a simple calibrator mapping raw cosine similarity of (product -> true category emb)
    to a probability. Uses sklearn LogisticRegression on one feature (raw_score).
    Expects validation.csv with columns product_title,category_id

    Positives are (product, true-category) cosine scores; for each positive,
    up to 2 random categories are sampled as negatives.

    Returns a {"calibrator": LogisticRegression, "scaler": StandardScaler}
    dict, or None when the CSV is malformed or yields no usable examples.

    NOTE(review): ``faiss_index`` and ``model_name`` are accepted but never
    used in this body — presumably kept for interface symmetry; confirm.
    NOTE(review): requires metadata entries to carry a temporary
    "_embedding" key (see attach_embeddings_to_metadata); rows whose
    category lacks it are silently skipped.
    """
    print("Training calibrator using:", val_path)
    df = pd.read_csv(val_path, dtype=str, keep_default_na=False)
    if "product_title" not in df.columns or "category_id" not in df.columns:
        print("validation.csv must have 'product_title' and 'category_id' columns. Skipping calibrator.")
        return None

    examples = []
    labels = []
    # Build a mapping category_id -> embedding (from metadata)
    id_to_idx = {m["category_id"]: i for i, m in enumerate(metadata)}

    # prepare product embeddings in batches
    titles = df["product_title"].astype(str).tolist()
    prod_embs = encode_texts(encoder, [f"query: {t}" for t in titles], use_cuda=use_cuda)

    # NOTE(review): prod_embs[i] assumes df's index is positional 0..n-1 —
    # true for a fresh read_csv, but would break on a filtered frame.
    for i, row in df.iterrows():
        cid = str(row["category_id"]).strip()
        if cid not in id_to_idx:
            # not in catalog, skip sample
            continue
        cat_idx = id_to_idx[cid]
        cat_emb = metadata[cat_idx].get("_embedding")  # we will attach embeddings later temporarily
        if cat_emb is None:
            continue
        q_emb = prod_embs[i].reshape(1, -1).astype("float32")
        raw = float(np.dot(q_emb, cat_emb.reshape(-1, 1))[0][0])  # cosine because normalized
        # positive
        examples.append([raw])
        labels.append(1)

        # generate few negatives by sampling other categories
        # sample up to 2 random negatives
        negs = 2
        for _ in range(negs):
            # NOTE(review): importing random inside the inner loop is
            # harmless (cached) but would normally live at module scope.
            import random
            rand_idx = random.randrange(len(metadata))
            if rand_idx == cat_idx:
                # skipped rather than resampled, so "up to" 2 negatives
                continue
            neg_emb = metadata[rand_idx].get("_embedding")
            if neg_emb is None:
                continue
            raw_neg = float(np.dot(q_emb, neg_emb.reshape(-1, 1))[0][0])
            examples.append([raw_neg])
            labels.append(0)

    if not examples:
        print("No examples for calibrator (maybe category ids mismatch). Skipping.")
        return None

    # Single-feature logistic regression on the (standardized) raw cosine.
    X = np.array(examples, dtype="float32")
    y = np.array(labels, dtype="int8")
    scaler = StandardScaler()
    Xs = scaler.fit_transform(X)
    clf = LogisticRegression(max_iter=200)
    clf.fit(Xs, y)
    print("Calibrator trained (logistic regression on raw cosine).")
    return {"calibrator": clf, "scaler": scaler}
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def attach_embeddings_to_metadata(metadata: List[Dict], embeddings: np.ndarray):
    """Stash each row's embedding on its metadata dict under '_embedding'.

    Temporary: call detach_embeddings_from_metadata before pickling.
    """
    for position, entry in enumerate(metadata):
        entry["_embedding"] = embeddings[position]
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def detach_embeddings_from_metadata(metadata: List[Dict]):
    """Drop the temporary '_embedding' key from every metadata dict."""
    for entry in metadata:
        entry.pop("_embedding", None)
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def main():
    """End-to-end training pipeline.

    Reads the category CSV, encodes category texts, builds parent embeddings
    and a FAISS index, then writes all artifacts to ./cache. Optionally trains
    a cosine->probability calibrator (if validation.csv exists) and a LightGBM
    classifier (--train-classifier).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--csv", required=True, help="categories CSV (Category_ID,Category_path,Final_Category)")
    parser.add_argument("--model", default="intfloat/e5-base-v2", help="embedding model")
    parser.add_argument("--gpu", action="store_true", help="use GPU for encoding if available (careful with 4GB)")
    parser.add_argument("--clean-cache", action="store_true", help="delete other cache files after build")
    parser.add_argument("--train-classifier", action="store_true", help="train LightGBM classifier on validation.csv (optional)")
    parser.add_argument("--validation", default="data/validation.csv", help="validation CSV used for calibrator / classifier")
    parser.add_argument("--tags", default="data/tags.json", help="tags.json path (optional)")
    args = parser.parse_args()

    csv_path = Path(args.csv)
    if not csv_path.exists():
        raise SystemExit("CSV not found: " + str(csv_path))

    print("Reading CSV:", csv_path)
    # dtype=str + keep_default_na=False keeps IDs as literal strings.
    df = pd.read_csv(csv_path, dtype=str, keep_default_na=False)
    if df.shape[1] < 2:
        raise SystemExit("CSV must have at least 2 columns: Category_ID, Category_path")

    # columns: positional — first is the ID, second is the path.
    cols = list(df.columns)
    cid_col, path_col = cols[0], cols[1]
    print("Using columns:", cid_col, path_col)

    # Build per-category metadata and the canonical text to embed.
    metadata = []
    texts_for_encoding = []
    for idx, row in df.iterrows():
        cid = str(row[cid_col]).strip()
        raw_path = str(row[path_col]).strip()
        norm_path = normalize_path_sep(raw_path)
        levels = path_to_levels(norm_path)
        final = levels[-1] if levels else norm_path or cid
        # include both path and final in canonical text to encode
        text = f"category: {norm_path}. leaf: {final}."
        metadata.append({
            "category_id": cid,
            "category_path": norm_path,
            "final": final,
            "levels": levels,
            "depth": len(levels)
        })
        texts_for_encoding.append(text)

    print(f"Prepared {len(metadata):,} metadata entries")

    # encoder
    use_cuda = args.gpu
    encoder = build_encoder(args.model, use_cuda=use_cuda)

    # encode categories
    cat_embeddings = encode_texts(encoder, texts_for_encoding, use_cuda=use_cuda)

    # Attach embeddings temporarily for calibrator builder
    attach_embeddings_to_metadata(metadata, cat_embeddings)

    # parent embeddings
    parent_emb = make_parent_embeddings(metadata, cat_embeddings)
    print(f"Built {len(parent_emb):,} parent embeddings")

    # Build CPU FAISS index (IP on normalized vectors -> cosine)
    index = build_faiss_index(cat_embeddings, use_gpu=False)

    # save index (FAISS CPU index)
    faiss_path = CACHE_DIR / "main_index.faiss"
    faiss.write_index(index, str(faiss_path))
    print("Saved FAISS index:", faiss_path)

    # save metadata (we will strip embeddings before saving to reduce pickle size)
    detach_embeddings_from_metadata(metadata)
    meta_path = CACHE_DIR / "metadata.pkl"
    safe_pickle_save(metadata, meta_path)
    print("Saved metadata:", meta_path)

    # save parent embeddings
    parent_path = CACHE_DIR / "parent_embeddings.pkl"
    safe_pickle_save(parent_emb, parent_path)
    print("Saved parent embeddings:", parent_path)

    # save model_info
    info = {
        "model_name": args.model,
        "num_categories": len(metadata),
        "embedding_dim": cat_embeddings.shape[1]
    }
    with open(CACHE_DIR / "model_info.json", "w", encoding="utf-8") as f:
        json.dump(info, f, indent=2)
    print("Saved model_info.json")

    # store tags.json -> cross_store_synonyms (just preserve structure)
    tags = load_tags_json(Path(args.tags))
    if tags:
        syn_p = CACHE_DIR / "cross_store_synonyms.pkl"
        safe_pickle_save(tags, syn_p)
        print("Saved cross_store_synonyms.pkl from tags.json (size: %d)" % len(tags))

    # calibrator: use validation.csv if exists
    val_path = Path(args.validation)
    calibrator_obj = None
    if val_path.exists():
        # we need embeddings attached again for calibrator training
        attach_embeddings_to_metadata(metadata, cat_embeddings)
        calibrator_obj = train_calibrator(encoder, metadata, index, val_path, args.model, use_cuda=use_cuda)
        detach_embeddings_from_metadata(metadata)
        if calibrator_obj:
            safe_pickle_save(calibrator_obj, CACHE_DIR / "calibrator.pkl")
            print("Saved calibrator.pkl")

    # optional LightGBM classifier
    if args.train_classifier:
        if not LGB_AVAILABLE:
            print("LightGBM not available. Install lightgbm to train classifier.")
        else:
            val_path2 = Path(args.validation)
            if not val_path2.exists():
                print("validation.csv required to train classifier. Skipping classifier training.")
            else:
                # create training set from validation.csv
                dfv = pd.read_csv(val_path2, dtype=str, keep_default_na=False)
                if "product_title" not in dfv.columns or "category_id" not in dfv.columns:
                    print("validation.csv must contain product_title and category_id. Skipping classifier.")
                else:
                    # encode product titles
                    prod_texts = [f"query: {t}" for t in dfv["product_title"].astype(str).tolist()]
                    prod_embs = encode_texts(encoder, prod_texts, use_cuda=use_cuda)
                    # map category ids to numeric labels
                    cat_to_label = {m["category_id"]: i for i, m in enumerate(metadata)}
                    labels = []
                    rows = []
                    for i, row in dfv.iterrows():
                        cid = row["category_id"]
                        if cid not in cat_to_label:
                            continue
                        labels.append(cat_to_label[cid])
                        rows.append(prod_embs[i])
                    if len(rows) < 50:
                        print("Not enough training rows for classifier. Need >=50. Skipping.")
                    else:
                        X = np.vstack(rows)
                        y = np.array(labels, dtype=np.int32)
                        X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.15, random_state=42, stratify=y)
                        lgb_train = lgb.Dataset(X_train, label=y_train)
                        lgb_eval = lgb.Dataset(X_val, label=y_val, reference=lgb_train)
                        params = {
                            "objective": "multiclass",
                            "num_class": int(max(y) + 1),
                            "metric": "multi_logloss",
                            "verbosity": -1,
                            "num_threads": 4,
                            "learning_rate": 0.1,
                            "num_leaves": 31
                        }
                        print("Training LightGBM classifier (may take time)...")
                        # NOTE(review): early_stopping_rounds as a lgb.train kwarg
                        # was removed in LightGBM 4.x (use callbacks=[lgb.early_stopping(30)]).
                        # Confirm the pinned lightgbm version supports this call.
                        gbm = lgb.train(params, lgb_train, valid_sets=[lgb_train, lgb_eval], early_stopping_rounds=30, num_boost_round=500)
                        # save classifier and mapping
                        clf_path = CACHE_DIR / "classifier.pkl"
                        safe_pickle_save({"model": gbm, "cat_to_label": cat_to_label, "label_to_cat": {v: k for k, v in cat_to_label.items()}}, clf_path)
                        print("Saved classifier.pkl")

    # cleanup if asked
    if args.clean_cache:
        # NOTE(review): "classifier.pkl" is never added to `keep`, so
        # --clean-cache deletes a classifier trained in the same run — confirm intended.
        keep = {"main_index.faiss", "metadata.pkl", "model_info.json", "parent_embeddings.pkl", "cross_store_synonyms.pkl"}
        if calibrator_obj:
            keep.add("calibrator.pkl")
        # remove everything else in cache
        removed = []
        for p in CACHE_DIR.iterdir():
            if p.name in keep:
                continue
            try:
                p.unlink()
                removed.append(p.name)
            except Exception:
                # best effort — a locked/open file just stays behind
                pass
        if removed:
            print("Removed cache files:", removed)

    print("DONE. Index + data saved to cache/")

if __name__ == "__main__":
    main()
|
validation_data.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
π VALIDATION DATA CREATOR
|
| 3 |
+
===========================
|
| 4 |
+
Helper script to create validation CSV for confidence calibration.
|
| 5 |
+
|
| 6 |
+
Two modes:
|
| 7 |
+
1. Sample from existing categories (automated)
|
| 8 |
+
2. Manual entry (interactive)
|
| 9 |
+
|
| 10 |
+
Output format:
|
| 11 |
+
product_title,true_category_id
|
| 12 |
+
"Oxygen Sensor Tool",12345
|
| 13 |
+
"Hydraulic Oil Additive",67890
|
| 14 |
+
|
| 15 |
+
Usage:
|
| 16 |
+
# Automated sampling:
|
| 17 |
+
python create_validation_data.py auto data/category_id_path_only.csv
|
| 18 |
+
|
| 19 |
+
# Manual entry:
|
| 20 |
+
python create_validation_data.py manual
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import pandas as pd
|
| 24 |
+
import sys
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
import random
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def sample_from_categories(csv_path, num_samples=100, output_file='data/validation.csv'):
    """
    Auto-generate a validation CSV by sampling categories and synthesizing
    product titles from the tail of each category path.

    Args:
        csv_path: CSV whose first two columns are (category_id, category_path).
        num_samples: Maximum number of validation rows to generate.
        output_file: Destination CSV path (parent directories are created).

    Returns:
        bool: True on success, False if the input CSV is malformed.
    """
    print("\n" + "="*80)
    print("π AUTO-GENERATING VALIDATION DATA")
    print("="*80 + "\n")

    # Load categories
    print(f"Loading: {csv_path}")
    df = pd.read_csv(csv_path)

    if len(df.columns) < 2:
        print("β CSV must have at least 2 columns (category_id, category_path)")
        return False

    # Normalize the first two column names; extra columns are kept untouched.
    df.columns = ['category_id', 'category_path'] + list(df.columns[2:])
    df = df.dropna(subset=['category_path'])

    print(f"β Loaded {len(df):,} categories\n")

    # Deterministic category sample.
    sample_size = min(num_samples, len(df))
    sampled = df.sample(n=sample_size, random_state=42)

    print(f"π Generating {sample_size} validation entries...\n")

    # FIX: use a locally seeded RNG for title variation. Previously this used
    # the unseeded global `random.choice`, so generated titles differed between
    # runs even though the category sample above is pinned with random_state=42.
    rng = random.Random(42)

    validation_data = []

    for _, row in sampled.iterrows():
        cat_id = str(row['category_id'])
        cat_path = str(row['category_path'])

        # Use the last 1-3 path levels as the base title. levels[-3:] covers
        # the original >=3 / >=2 / else branching with one slice.
        levels = cat_path.split('/')
        title = ' '.join(levels[-3:]).strip()

        # Add some variation
        variations = [
            title,
            f"{title} kit",
            f"{title} tool",
            f"{title} set",
            f"professional {title}",
            f"{title} replacement",
        ]

        product_title = rng.choice(variations)

        validation_data.append({
            'product_title': product_title,
            'true_category_id': cat_id
        })

    # Create DataFrame
    val_df = pd.DataFrame(validation_data)

    # Save (create parent directory if needed)
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    val_df.to_csv(output_path, index=False)

    print(f"β Created validation file: {output_path}")
    print(f" Entries: {len(val_df):,}")

    # Show samples
    print("\nπ Sample entries:")
    for i, row in val_df.head(5).iterrows():
        print(f" {i+1}. \"{row['product_title']}\" β {row['true_category_id']}")

    print("\n" + "="*80)
    print("β VALIDATION DATA CREATED!")
    print("="*80)
    print(f"\nNext step: Train with calibration")
    print(f" python train_fixed_v2.py data/category_id_path_only.csv data/tags.json {output_path}")
    print("="*80 + "\n")

    return True
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def manual_entry(output_file='data/validation_manual.csv'):
    """
    Interactively collect (product_title, category_id) pairs from stdin and
    write them to a validation CSV. Entry stops on CTRL+C.

    Args:
        output_file: Destination CSV path (parent directories are created).

    Returns:
        bool: True if at least one entry was saved, False otherwise.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("π MANUAL VALIDATION DATA ENTRY")
    print(banner)
    print("\nEnter product titles and their correct category IDs.")
    print("Press CTRL+C when done.\n")

    entries = []

    try:
        while True:
            print(f"\n--- Entry #{len(entries) + 1} ---")

            # Reject blank titles and re-prompt.
            title = input("Product title: ").strip()
            if not title:
                print("β οΈ Title cannot be empty")
                continue

            # Reject blank category IDs and re-prompt.
            cat_id = input("Category ID: ").strip()
            if not cat_id:
                print("β οΈ Category ID cannot be empty")
                continue

            entries.append({
                'product_title': title,
                'true_category_id': cat_id
            })
            print(f"β Added: \"{title}\" β {cat_id}")
    except KeyboardInterrupt:
        # CTRL+C is the normal "done" signal, not an error.
        print("\n\nπ Entry complete!")

    if not entries:
        print("β No entries created")
        return False

    frame = pd.DataFrame(entries)

    destination = Path(output_file)
    destination.parent.mkdir(parents=True, exist_ok=True)
    frame.to_csv(destination, index=False)

    print(f"\nβ Created validation file: {destination}")
    print(f" Entries: {len(frame):,}")

    print("\n" + banner)
    print("β VALIDATION DATA CREATED!")
    print(banner)
    print(f"\nNext step: Train with calibration")
    print(f" python train_fixed_v2.py data/category_id_path_only.csv data/tags.json {destination}")
    print(banner + "\n")

    return True
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def verify_validation_file(validation_csv, categories_csv):
    """
    Check that every row of a validation CSV references a known category ID.

    Args:
        validation_csv: CSV with columns product_title, true_category_id.
        categories_csv: CSV whose first two columns are
            (category_id, category_path).

    Returns:
        bool: True when every entry references a valid category ID; False when
        any entry is invalid or the validation CSV lacks required columns.
    """
    print("\n" + "="*80)
    print("π VERIFYING VALIDATION DATA")
    print("="*80 + "\n")

    # Load validation data
    print(f"Loading validation: {validation_csv}")
    val_df = pd.read_csv(validation_csv)

    if 'product_title' not in val_df.columns or 'true_category_id' not in val_df.columns:
        print("β Validation CSV must have: product_title, true_category_id")
        return False

    print(f"β Loaded {len(val_df):,} validation entries\n")

    # Load categories
    print(f"Loading categories: {categories_csv}")
    cat_df = pd.read_csv(categories_csv)
    cat_df.columns = ['category_id', 'category_path'] + list(cat_df.columns[2:])

    # IDs are compared as strings so int/str columns match consistently.
    valid_ids = set(cat_df['category_id'].astype(str))
    print(f"β Loaded {len(valid_ids):,} valid category IDs\n")

    # Verify: vectorized membership test instead of a per-row Python loop —
    # identical result, much faster on large validation files.
    print("Checking validation entries...")
    invalid_mask = ~val_df['true_category_id'].astype(str).isin(valid_ids)
    invalid_count = int(invalid_mask.sum())

    # Still report each offending row individually.
    for _, row in val_df[invalid_mask].iterrows():
        cat_id = str(row['true_category_id'])
        title = row['product_title']
        print(f"β Invalid ID: {cat_id} for \"{title}\"")

    if invalid_count == 0:
        print("β All validation entries are valid!")
    else:
        print(f"\nβ οΈ Found {invalid_count} invalid entries")

    # Summary
    print("\n" + "="*80)
    print("π VALIDATION DATA SUMMARY")
    print("="*80)
    print(f"Total entries: {len(val_df):,}")
    print(f"Valid entries: {len(val_df) - invalid_count:,}")
    print(f"Invalid entries: {invalid_count}")
    print("="*80 + "\n")

    return invalid_count == 0
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def main():
    """CLI dispatcher for the `auto`, `manual`, and `verify` sub-commands."""
    print("\n" + "="*80)
    print("π VALIDATION DATA CREATOR")
    print("="*80 + "\n")

    argv = sys.argv

    # No sub-command given: show usage and exit.
    if len(argv) < 2:
        print("Usage:")
        print(" python create_validation_data.py auto <csv_path> [num_samples] [output_file]")
        print(" python create_validation_data.py manual [output_file]")
        print(" python create_validation_data.py verify <validation_csv> <categories_csv>")
        print("\nExamples:")
        print(" # Auto-generate 100 samples:")
        print(" python create_validation_data.py auto data/category_id_path_only.csv")
        print()
        print(" # Auto-generate 200 samples:")
        print(" python create_validation_data.py auto data/category_id_path_only.csv 200")
        print()
        print(" # Manual entry:")
        print(" python create_validation_data.py manual")
        print()
        print(" # Verify validation file:")
        print(" python create_validation_data.py verify data/validation.csv data/category_id_path_only.csv")
        print()
        return

    mode = argv[1].lower()

    if mode == 'auto':
        # auto <csv_path> [num_samples] [output_file]
        if len(argv) < 3:
            print("β CSV path required for auto mode")
            print(" python create_validation_data.py auto data/category_id_path_only.csv")
            return

        csv_path = argv[2]
        num_samples = int(argv[3]) if len(argv) > 3 else 100
        output_file = argv[4] if len(argv) > 4 else 'data/validation.csv'

        if not Path(csv_path).exists():
            print(f"β File not found: {csv_path}")
            return

        sample_from_categories(csv_path, num_samples, output_file)

    elif mode == 'manual':
        # manual [output_file]
        manual_entry(argv[2] if len(argv) > 2 else 'data/validation_manual.csv')

    elif mode == 'verify':
        # verify <validation_csv> <categories_csv>
        if len(argv) < 4:
            print("β Both validation CSV and categories CSV required")
            print(" python create_validation_data.py verify data/validation.csv data/category_id_path_only.csv")
            return

        validation_csv = argv[2]
        categories_csv = argv[3]

        # Check inputs in the same order the original did.
        for candidate in (validation_csv, categories_csv):
            if not Path(candidate).exists():
                print(f"β File not found: {candidate}")
                return

        verify_validation_file(validation_csv, categories_csv)

    else:
        print(f"β Unknown mode: {mode}")
        print(" Use: auto, manual, or verify")
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
# Script entry point: run the CLI only when executed directly, so the module
# can be imported without side effects.
if __name__ == "__main__":
    main()
|