Spaces:
Sleeping
Sleeping
Clean codebase and add demographic enhancements
Browse files- Enhanced demographic handling for zero interactions
- Improved content-based filtering with aggregated history
- Added comprehensive UI support for new user scenarios
- Cleaned up analysis files and redundant components
- Updated documentation and project structure
- DEEP_ARCHITECTURE.md +549 -0
- README.md +25 -27
- analyze_recommendations.py +0 -543
- api/main.py +60 -88
- api_2phase.py +0 -521
- api_joint.py +0 -522
- datasets/interactions.csv +0 -0
- datasets/items.csv +0 -0
- datasets/users.csv +0 -0
- frontend/src/App.css +62 -0
- frontend/src/App.js +207 -33
- src/data_generation/generate_demographics.py +292 -0
- src/inference/enhanced_recommendation_engine.py +0 -303
- src/inference/enhanced_recommendation_engine_128d.py +0 -499
- src/inference/recommendation_engine.py +562 -17
- src/models/enhanced_two_tower.py +0 -574
- src/models/improved_two_tower.py +0 -545
- src/models/user_tower.py +33 -0
- src/preprocessing/optimized_dataset_creator.py +0 -111
- src/preprocessing/user_data_preparation.py +72 -0
- src/training/curriculum_trainer.py +0 -341
- src/training/fast_joint_training.py +0 -268
- src/training/improved_joint_training.py +0 -462
- src/training/optimized_joint_training.py +0 -439
- src/utils/real_user_selector.py +51 -11
- train_improved_model.py +0 -111
DEEP_ARCHITECTURE.md
ADDED
|
@@ -0,0 +1,549 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Deep Architecture Documentation - RecSys-HP
|
| 2 |
+
|
| 3 |
+
## 🏗️ Complete System Architecture Overview
|
| 4 |
+
|
| 5 |
+
```mermaid
|
| 6 |
+
graph TB
|
| 7 |
+
subgraph "Data Layer"
|
| 8 |
+
D1[items.csv<br/>15K+ products]
|
| 9 |
+
D2[users.csv<br/>Enhanced demographics]
|
| 10 |
+
D3[interactions.csv<br/>User-item interactions]
|
| 11 |
+
D4[Artifacts<br/>Trained models & indices]
|
| 12 |
+
end
|
| 13 |
+
|
| 14 |
+
subgraph "ML Pipeline"
|
| 15 |
+
P1[Data Preprocessing]
|
| 16 |
+
P2[Item Tower Pre-training]
|
| 17 |
+
P3[FAISS Index Creation]
|
| 18 |
+
P4[Joint Training]
|
| 19 |
+
P5[User Tower Training]
|
| 20 |
+
end
|
| 21 |
+
|
| 22 |
+
subgraph "Inference Layer"
|
| 23 |
+
I1[Recommendation Engine<br/>Category-Boosted Algorithm]
|
| 24 |
+
I2[FAISS Similarity Search]
|
| 25 |
+
I3[Real User Selection]
|
| 26 |
+
I4[Hybrid Scoring]
|
| 27 |
+
end
|
| 28 |
+
|
| 29 |
+
subgraph "API Layer"
|
| 30 |
+
A1[FastAPI Server<br/>Port 8000]
|
| 31 |
+
A2[Recommendation Endpoints]
|
| 32 |
+
A3[User Management]
|
| 33 |
+
A4[Item Retrieval]
|
| 34 |
+
end
|
| 35 |
+
|
| 36 |
+
subgraph "Frontend Layer"
|
| 37 |
+
F1[React.js Application]
|
| 38 |
+
F2[Interactive UI Components]
|
| 39 |
+
F3[Real-time Analytics]
|
| 40 |
+
F4[User Profile Management]
|
| 41 |
+
end
|
| 42 |
+
|
| 43 |
+
D1 --> P1
|
| 44 |
+
D2 --> P1
|
| 45 |
+
D3 --> P1
|
| 46 |
+
P1 --> P2
|
| 47 |
+
P2 --> P3
|
| 48 |
+
P3 --> P4
|
| 49 |
+
P4 --> P5
|
| 50 |
+
P5 --> D4
|
| 51 |
+
D4 --> I1
|
| 52 |
+
I1 --> I2
|
| 53 |
+
I2 --> I3
|
| 54 |
+
I3 --> I4
|
| 55 |
+
I4 --> A1
|
| 56 |
+
A1 --> A2
|
| 57 |
+
A2 --> A3
|
| 58 |
+
A3 --> A4
|
| 59 |
+
A4 --> F1
|
| 60 |
+
F1 --> F2
|
| 61 |
+
F2 --> F3
|
| 62 |
+
F3 --> F4
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
---
|
| 66 |
+
|
| 67 |
+
## 📁 Project Structure
|
| 68 |
+
|
| 69 |
+
```
|
| 70 |
+
RecSys-HP/
|
| 71 |
+
βββ ποΈ Data Layer
|
| 72 |
+
β βββ datasets/
|
| 73 |
+
β β βββ items.csv # 15K+ product catalog
|
| 74 |
+
β β βββ users.csv # Enhanced user demographics (7 features)
|
| 75 |
+
β β βββ interactions.csv # User-item interaction history
|
| 76 |
+
β β βββ users_enhanced.csv # Backup with enhanced features
|
| 77 |
+
β βββ src/artifacts/ # Trained models and indices
|
| 78 |
+
β βββ item_embeddings.npy # Pre-trained item vectors (128D)
|
| 79 |
+
β βββ faiss_item_index.bin # FAISS similarity index
|
| 80 |
+
β βββ faiss_metadata.pkl # Item metadata mapping
|
| 81 |
+
β βββ vocabularies.pkl # Categorical encoders
|
| 82 |
+
β βββ *.weights.* files # TensorFlow model weights
|
| 83 |
+
β
|
| 84 |
+
βββ π§ ML Pipeline
|
| 85 |
+
β βββ src/preprocessing/
|
| 86 |
+
β β βββ data_loader.py # Data loading and preprocessing
|
| 87 |
+
β β βββ user_data_preparation.py # User feature engineering
|
| 88 |
+
β βββ src/training/
|
| 89 |
+
β β βββ item_pretraining.py # Item tower pre-training
|
| 90 |
+
β β βββ joint_training.py # Two-tower joint training
|
| 91 |
+
β βββ src/models/
|
| 92 |
+
β β βββ item_tower.py # Item embedding model (TensorFlow)
|
| 93 |
+
β β βββ user_tower.py # User embedding model (TensorFlow)
|
| 94 |
+
β βββ Training Scripts
|
| 95 |
+
β βββ run_training_pipeline.py # Complete pipeline executor
|
| 96 |
+
β βββ run_2phase_training.py # 2-phase training approach
|
| 97 |
+
β βββ run_joint_training.py # Joint training approach
|
| 98 |
+
β
|
| 99 |
+
βββ π Inference Layer
|
| 100 |
+
β βββ src/inference/
|
| 101 |
+
β β βββ recommendation_engine.py # Core recommendation algorithms
|
| 102 |
+
β β βββ faiss_index.py # FAISS index management
|
| 103 |
+
β βββ src/utils/
|
| 104 |
+
β β βββ real_user_selector.py # Real user data selection
|
| 105 |
+
β βββ src/data_generation/
|
| 106 |
+
β βββ generate_demographics.py # Synthetic user generation
|
| 107 |
+
β
|
| 108 |
+
βββ π API Layer
|
| 109 |
+
β βββ api/
|
| 110 |
+
β βββ main.py # FastAPI server with all endpoints
|
| 111 |
+
β
|
| 112 |
+
βββ π¨ Frontend Layer
|
| 113 |
+
β βββ frontend/
|
| 114 |
+
β βββ src/
|
| 115 |
+
β β βββ App.js # Main React application
|
| 116 |
+
β β βββ index.js # Entry point
|
| 117 |
+
β βββ public/ # Static assets
|
| 118 |
+
β βββ build/ # Production build
|
| 119 |
+
β βββ package.json # Dependencies
|
| 120 |
+
β
|
| 121 |
+
βββ π§ͺ Testing & Analysis
|
| 122 |
+
βββ test_category_boosted.py # Basic algorithm testing
|
| 123 |
+
βββ test_enhanced_category_boosted.py # Advanced subcategory testing
|
| 124 |
+
βββ deep_analyze_category_boosted.py # Comprehensive analysis tool
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
---
|
| 128 |
+
|
| 129 |
+
## 🔄 Data Flow Architecture
|
| 130 |
+
|
| 131 |
+
### 1. Training Pipeline Flow
|
| 132 |
+
```mermaid
|
| 133 |
+
sequenceDiagram
|
| 134 |
+
participant D as Data Files
|
| 135 |
+
participant P as Preprocessing
|
| 136 |
+
participant IT as Item Tower
|
| 137 |
+
participant F as FAISS Index
|
| 138 |
+
participant JT as Joint Training
|
| 139 |
+
participant UT as User Tower
|
| 140 |
+
participant A as Artifacts
|
| 141 |
+
|
| 142 |
+
D->>P: Load datasets (items, users, interactions)
|
| 143 |
+
P->>IT: Preprocessed item features
|
| 144 |
+
IT->>IT: Pre-train item embeddings (128D)
|
| 145 |
+
IT->>F: Generate item vectors
|
| 146 |
+
F->>F: Build FAISS similarity index
|
| 147 |
+
IT->>JT: Pre-trained item tower
|
| 148 |
+
P->>JT: User features (7 demographics)
|
| 149 |
+
JT->>UT: Train user tower
|
| 150 |
+
JT->>A: Save trained models
|
| 151 |
+
F->>A: Save FAISS index
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
### 2. Inference Pipeline Flow
|
| 155 |
+
```mermaid
|
| 156 |
+
sequenceDiagram
|
| 157 |
+
participant U as User Request
|
| 158 |
+
participant API as FastAPI
|
| 159 |
+
participant RE as Recommendation Engine
|
| 160 |
+
participant F as FAISS Search
|
| 161 |
+
participant CB as Category Boosted
|
| 162 |
+
participant R as Response
|
| 163 |
+
|
| 164 |
+
U->>API: POST /recommendations
|
| 165 |
+
API->>RE: User profile + preferences
|
| 166 |
+
RE->>F: Query item embeddings
|
| 167 |
+
F->>RE: Similar items (k*10 wide search)
|
| 168 |
+
RE->>CB: Apply category-boosted algorithm
|
| 169 |
+
CB->>CB: 50% from user categories + proportional distribution
|
| 170 |
+
CB->>RE: Balanced recommendations
|
| 171 |
+
RE->>API: Scored & ranked items
|
| 172 |
+
API->>R: JSON response with recommendations
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
---
|
| 176 |
+
|
| 177 |
+
## 🧠 Machine Learning Architecture
|
| 178 |
+
|
| 179 |
+
### Two-Tower Architecture
|
| 180 |
+
```
|
| 181 |
+
βββββββββββββββββββββββββββ βββββββββββββββββββββββββββ
|
| 182 |
+
β ITEM TOWER β β USER TOWER β
|
| 183 |
+
β β β β
|
| 184 |
+
β βββββββββββββββββββββββ β β βββββββββββββββββββββββ β
|
| 185 |
+
β β Item Features β β β β Demographic Featuresβ β
|
| 186 |
+
β β β β β β β β
|
| 187 |
+
β β β’ product_id β β β β β’ age (normalized) β β
|
| 188 |
+
β β β’ category_code β β β β β’ gender (encoded) β β
|
| 189 |
+
β β β’ brand β β β β β’ income (binned) β β
|
| 190 |
+
β β β’ price (log) β β β β β’ profession β β
|
| 191 |
+
β β β β β β β’ location β β
|
| 192 |
+
β βββββββββββββββββββββββ β β β β’ education_level β β
|
| 193 |
+
β β β β β β’ marital_status β β
|
| 194 |
+
β βΌ β β βββββββββββββββββββββββ β
|
| 195 |
+
β βββββββββββββββββββββββ β β β β
|
| 196 |
+
β β Dense Layers β β β βΌ β
|
| 197 |
+
β β β β β βββββββββββββββββββββββ β
|
| 198 |
+
β β β’ Dense(256, ReLU) β β β β Dense Layers β β
|
| 199 |
+
β β β’ Dropout(0.3) β β β β β β
|
| 200 |
+
β β β’ Dense(128, ReLU) β β β β β’ Dense(128, ReLU) β β
|
| 201 |
+
β β β’ L2 Regularization β β β β β’ Dropout(0.2) β β
|
| 202 |
+
β β β β β β β’ Dense(64, ReLU) β β
|
| 203 |
+
β βββββββββββββββββββββββ β β β β’ L2 Regularization β β
|
| 204 |
+
β β β β βββββββββββββββββββββββ β
|
| 205 |
+
β βΌ β β β β
|
| 206 |
+
β βββββββββββββββββββββββ β β βΌ β
|
| 207 |
+
β β Item Embedding β β β βββββββββββββββββββββββ β
|
| 208 |
+
β β (128D) β β β β User Embedding β β
|
| 209 |
+
β β β β β β (64D) β β
|
| 210 |
+
β βββββββββββββββββββββββ β β βββββββββββββββββββββββ β
|
| 211 |
+
βββββββββββββββββββββββββββ βββββββββββββββββββββββββββ
|
| 212 |
+
β β
|
| 213 |
+
ββββββββββββββββ¬ββββββββββββββββ
|
| 214 |
+
βΌ
|
| 215 |
+
βββββββββββββββββββββββββββ
|
| 216 |
+
β Dot Product β
|
| 217 |
+
β Similarity Score β
|
| 218 |
+
β β
|
| 219 |
+
β similarity = user_emb β
|
| 220 |
+
β Β· item_emb β
|
| 221 |
+
βββββββββββββββββββββββββββ
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
### Category-Boosted Algorithm
|
| 225 |
+
```
|
| 226 |
+
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 227 |
+
β CATEGORY-BOOSTED RECOMMENDATION FLOW β
|
| 228 |
+
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ€
|
| 229 |
+
β β
|
| 230 |
+
β 1. USER INTERACTION ANALYSIS β
|
| 231 |
+
β βββββββββββββββββββββββββββββββββββββββ β
|
| 232 |
+
β β interaction_history: [1001, 2003, β β
|
| 233 |
+
β β 3045, 1099] β β
|
| 234 |
+
β βββββββββββββββββββββββββββββββββββββββ β
|
| 235 |
+
β β β
|
| 236 |
+
β βΌ β
|
| 237 |
+
β βββββββββββββββββββββββββββββββββββββββ β
|
| 238 |
+
β β Extract 2-level subcategories: β β
|
| 239 |
+
β β β’ computers.components: 40% β β
|
| 240 |
+
β β β’ electronics.audio: 35% β β
|
| 241 |
+
β β β’ computers.peripherals: 25% β β
|
| 242 |
+
β βββββββββββββββββββββββββββββββββββββββ β
|
| 243 |
+
β β
|
| 244 |
+
β 2. WIDE SIMILARITY SEARCH β
|
| 245 |
+
β βββββββββββββββββββββββββββββββββββββββ β
|
| 246 |
+
β β FAISS.search(user_embedding, β β
|
| 247 |
+
β β k = requested * 10) β β
|
| 248 |
+
β β β β
|
| 249 |
+
β β Returns: ~1000 similar items β β
|
| 250 |
+
β βββββββββββββββββββββββββββββββββββββββ β
|
| 251 |
+
β β β
|
| 252 |
+
β βΌ β
|
| 253 |
+
β 3. CATEGORY ORGANIZATION β
|
| 254 |
+
β βββββββββββββββββββββββββββββββββββββββ β
|
| 255 |
+
β β Group by subcategories: β β
|
| 256 |
+
β β β β
|
| 257 |
+
β β computers.components: [1001, 1099, β β
|
| 258 |
+
β β 1203, ...] β β
|
| 259 |
+
β β electronics.audio: [2003, 2156, β β
|
| 260 |
+
β β 2089, ...] β β
|
| 261 |
+
β β computers.peripherals: [3045, 3201, β β
|
| 262 |
+
β β 3078, ...] β β
|
| 263 |
+
β β other_categories: [4001, 5002, ...] β β
|
| 264 |
+
β βββββββββββββββββββββββββββββββββββββββ β
|
| 265 |
+
β β β
|
| 266 |
+
β βΌ β
|
| 267 |
+
β 4. PROPORTIONAL ALLOCATION β
|
| 268 |
+
β βββββββββββββββββββββββββββββββββββββββ β
|
| 269 |
+
β β Target: 50% from user categories β β
|
| 270 |
+
β β (50 items for 100 recommendations) β β
|
| 271 |
+
β β β β
|
| 272 |
+
β β computers.components: 40% β 20 itemsβ β
|
| 273 |
+
β β electronics.audio: 35% β 18 items β β
|
| 274 |
+
β β computers.peripherals: 25% β 12 itemsβ β
|
| 275 |
+
β β β β
|
| 276 |
+
β β Remaining 50 items: diverse mix β β
|
| 277 |
+
β βββββββββββββββββββββββββββββββββββββββ β
|
| 278 |
+
β β β
|
| 279 |
+
β βΌ β
|
| 280 |
+
β 5. FINAL RECOMMENDATION SET β
|
| 281 |
+
β βββββββββββββββββββββββββββββββββββββββ β
|
| 282 |
+
β β β’ 50 items from user's categories β β
|
| 283 |
+
β β (proportionally distributed) β β
|
| 284 |
+
β β β’ 50 items for exploration β β
|
| 285 |
+
β β β’ All ranked by similarity score β β
|
| 286 |
+
β β β’ Ensures category diversity β β
|
| 287 |
+
β βββββββββββββββββββββββββββββββββββββββ β
|
| 288 |
+
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 289 |
+
```
|
| 290 |
+
|
| 291 |
+
---
|
| 292 |
+
|
| 293 |
+
## 🌐 API Architecture
|
| 294 |
+
|
| 295 |
+
### FastAPI Server Endpoints
|
| 296 |
+
```python
|
| 297 |
+
# Core Recommendation Endpoints
|
| 298 |
+
POST /recommendations # Main recommendation engine
|
| 299 |
+
GET /real-users # Fetch real user profiles
|
| 300 |
+
GET /items/{item_id} # Get item details
|
| 301 |
+
GET /dataset-summary # Dataset statistics
|
| 302 |
+
|
| 303 |
+
# Algorithm-Specific Endpoints
|
| 304 |
+
POST /recommendations/hybrid # Hybrid collaborative + content
|
| 305 |
+
POST /recommendations/collaborative # Pure collaborative filtering
|
| 306 |
+
POST /recommendations/content # Aggregated history content-based recommendations
|
| 307 |
+
POST /recommendations/category_boosted # Category-boosted algorithm
|
| 308 |
+
|
| 309 |
+
# Utility Endpoints
|
| 310 |
+
GET / # Health check
|
| 311 |
+
GET /sample-items # Random item samples
|
| 312 |
+
POST /generate-interactions # Synthetic interaction generation
|
| 313 |
+
```
|
| 314 |
+
|
| 315 |
+
### Request/Response Flow
|
| 316 |
+
```mermaid
|
| 317 |
+
graph LR
|
| 318 |
+
subgraph "Request Processing"
|
| 319 |
+
A[User Request] --> B[Validation]
|
| 320 |
+
B --> C[Feature Engineering]
|
| 321 |
+
C --> D[Model Inference]
|
| 322 |
+
end
|
| 323 |
+
|
| 324 |
+
subgraph "Recommendation Engine"
|
| 325 |
+
D --> E[FAISS Search]
|
| 326 |
+
E --> F[Category Analysis]
|
| 327 |
+
F --> G[Score Calculation]
|
| 328 |
+
G --> H[Ranking & Filtering]
|
| 329 |
+
end
|
| 330 |
+
|
| 331 |
+
subgraph "Response Generation"
|
| 332 |
+
H --> I[Item Enrichment]
|
| 333 |
+
I --> J[Metadata Addition]
|
| 334 |
+
J --> K[JSON Response]
|
| 335 |
+
end
|
| 336 |
+
```
|
| 337 |
+
|
| 338 |
+
---
|
| 339 |
+
|
| 340 |
+
## 🎨 Frontend Architecture
|
| 341 |
+
|
| 342 |
+
### React.js Component Structure
|
| 343 |
+
```
|
| 344 |
+
App.js (Main Container)
|
| 345 |
+
βββ User Profile Management
|
| 346 |
+
β βββ Demographics Form
|
| 347 |
+
β βββ Real User Selection
|
| 348 |
+
β βββ Interaction History Display
|
| 349 |
+
βββ Recommendation Controls
|
| 350 |
+
β βββ Algorithm Selection
|
| 351 |
+
β βββ Count Configuration
|
| 352 |
+
β βββ Weight Adjustment
|
| 353 |
+
βββ Results Display
|
| 354 |
+
β βββ Recommendation Cards
|
| 355 |
+
β βββ Category Analytics
|
| 356 |
+
β βββ Pagination Controls
|
| 357 |
+
β βββ Similar Items View
|
| 358 |
+
βββ Analysis Components
|
| 359 |
+
βββ Category Interest Graphs
|
| 360 |
+
βββ Interaction Patterns
|
| 361 |
+
βββ Performance Metrics
|
| 362 |
+
```
|
| 363 |
+
|
| 364 |
+
### State Management
|
| 365 |
+
```javascript
|
| 366 |
+
const [userProfile, setUserProfile] = useState({
|
| 367 |
+
age: 30,
|
| 368 |
+
gender: 'male',
|
| 369 |
+
income: 50000,
|
| 370 |
+
profession: 'Technology',
|
| 371 |
+
location: 'Urban',
|
| 372 |
+
education_level: "Bachelor's",
|
| 373 |
+
marital_status: 'Single',
|
| 374 |
+
interaction_history: []
|
| 375 |
+
});
|
| 376 |
+
|
| 377 |
+
const [recommendationType, setRecommendationType] = useState('category_boosted');
|
| 378 |
+
const [recommendations, setRecommendations] = useState([]);
|
| 379 |
+
const [realUsers, setRealUsers] = useState([]);
|
| 380 |
+
const [datasetSummary, setDatasetSummary] = useState(null);
|
| 381 |
+
```
|
| 382 |
+
|
| 383 |
+
---
|
| 384 |
+
|
| 385 |
+
## 🔍 Algorithm Deep Dive
|
| 386 |
+
|
| 387 |
+
### 1. Hybrid Recommendation
|
| 388 |
+
- **Collaborative Filtering**: User-item interaction patterns
|
| 389 |
+
- **Aggregated Content-Based**: User's complete interaction history aggregated into single embedding
|
| 390 |
+
- **Weight Balance**: Configurable collaborative weight (default: 0.7)
|
| 391 |
+
|
| 392 |
+
### 1.5. Aggregated History Content-Based Filtering
|
| 393 |
+
- **Revolutionary Approach**: Aggregates user's entire interaction history instead of single-item similarity
|
| 394 |
+
- **Aggregation Methods**:
|
| 395 |
+
- **Weighted Mean**: `weights = exp(linspace(-1, 0, len(history)))` (recent interactions weighted higher)
|
| 396 |
+
- **Simple Mean**: Equal weighting of all interaction embeddings
|
| 397 |
+
- **Max Pooling**: Element-wise maximum across all embeddings
|
| 398 |
+
- **Process Flow**:
|
| 399 |
+
1. **Embedding Extraction**: Get 128D vectors for each item in user's history
|
| 400 |
+
2. **Aggregation**: Apply selected aggregation method (weighted_mean by default)
|
| 401 |
+
3. **Normalization**: L2-normalize the aggregated embedding
|
| 402 |
+
4. **ANN Search**: Direct FAISS similarity search using aggregated user profile
|
| 403 |
+
5. **Filtering**: Remove already-interacted items from results
|
| 404 |
+
- **Benefits**: Captures complete user preference profile, more robust than single-item seed
|
| 405 |
+
|
| 406 |
+
### 2. Category-Boosted Algorithm
|
| 407 |
+
- **Step 1**: Analyze user's subcategory preferences (2-level depth)
|
| 408 |
+
- **Step 2**: Wide FAISS search (k Γ 10 multiplier)
|
| 409 |
+
- **Step 3**: Category organization and candidate grouping
|
| 410 |
+
- **Step 4**: Proportional allocation (50% from user categories)
|
| 411 |
+
- **Step 5**: Exploration items filling (remaining 50%)
|
| 412 |
+
|
| 413 |
+
### 3. FAISS Integration
|
| 414 |
+
- **Index Type**: Flat L2 similarity search
|
| 415 |
+
- **Vector Dimension**: 128D item embeddings
|
| 416 |
+
- **Search Strategy**: Wide retrieval + post-processing
|
| 417 |
+
- **Metadata**: Item-to-index mapping via pickle files
|
| 418 |
+
|
| 419 |
+
---
|
| 420 |
+
|
| 421 |
+
## 📈 Performance Characteristics
|
| 422 |
+
|
| 423 |
+
### Scalability Metrics
|
| 424 |
+
- **Items**: 15K+ products supported
|
| 425 |
+
- **Users**: Unlimited (stateless design)
|
| 426 |
+
- **Recommendations**: 1-1000 per request
|
| 427 |
+
- **Response Time**: <2s for 100 recommendations
|
| 428 |
+
- **Memory Usage**: ~500MB for full model + index
|
| 429 |
+
|
| 430 |
+
### Algorithm Performance
|
| 431 |
+
- **Category Matching**: β₯50% from user's categories
|
| 432 |
+
- **Diversity Score**: Balanced exploration vs exploitation
|
| 433 |
+
- **Cold Start**: Handles new users via demographic features
|
| 434 |
+
- **Subcategory Precision**: 2-level category matching
|
| 435 |
+
|
| 436 |
+
---
|
| 437 |
+
|
| 438 |
+
## 🚀 Deployment Architecture
|
| 439 |
+
|
| 440 |
+
### Development Environment
|
| 441 |
+
```bash
|
| 442 |
+
# Backend (FastAPI)
|
| 443 |
+
cd /api && python main.py
|
| 444 |
+
|
| 445 |
+
# Frontend (React)
|
| 446 |
+
cd frontend && npm start
|
| 447 |
+
|
| 448 |
+
# Training Pipeline
|
| 449 |
+
python run_training_pipeline.py
|
| 450 |
+
```
|
| 451 |
+
|
| 452 |
+
### Production Considerations
|
| 453 |
+
- **Containerization**: Docker support for API + Frontend
|
| 454 |
+
- **Database**: PostgreSQL for production user/item storage
|
| 455 |
+
- **Caching**: Redis for recommendation caching
|
| 456 |
+
- **Load Balancing**: Nginx for multiple API instances
|
| 457 |
+
- **Monitoring**: Prometheus + Grafana for metrics
|
| 458 |
+
|
| 459 |
+
---
|
| 460 |
+
|
| 461 |
+
## 🔧 Configuration & Customization
|
| 462 |
+
|
| 463 |
+
### Model Configuration
|
| 464 |
+
```python
|
| 465 |
+
# Item Tower
|
| 466 |
+
ITEM_EMBEDDING_DIM = 128
|
| 467 |
+
ITEM_HIDDEN_LAYERS = [256, 128]
|
| 468 |
+
ITEM_DROPOUT_RATE = 0.3
|
| 469 |
+
|
| 470 |
+
# User Tower
|
| 471 |
+
USER_EMBEDDING_DIM = 64
|
| 472 |
+
USER_HIDDEN_LAYERS = [128, 64]
|
| 473 |
+
USER_DROPOUT_RATE = 0.2
|
| 474 |
+
|
| 475 |
+
# Training
|
| 476 |
+
BATCH_SIZE = 512
|
| 477 |
+
LEARNING_RATE = 0.001
|
| 478 |
+
EPOCHS = 100
|
| 479 |
+
VALIDATION_SPLIT = 0.2
|
| 480 |
+
```
|
| 481 |
+
|
| 482 |
+
### Algorithm Parameters
|
| 483 |
+
```python
|
| 484 |
+
# Category-Boosted
|
| 485 |
+
WIDE_SEARCH_MULTIPLIER = 10
|
| 486 |
+
USER_CATEGORY_PERCENTAGE = 0.5
|
| 487 |
+
SUBCATEGORY_LEVELS = 2
|
| 488 |
+
MIN_INTERACTION_THRESHOLD = 5
|
| 489 |
+
|
| 490 |
+
# FAISS
|
| 491 |
+
INDEX_TYPE = "Flat"
|
| 492 |
+
SIMILARITY_METRIC = "L2"
|
| 493 |
+
SEARCH_PARAMS = {"nprobe": 10}
|
| 494 |
+
|
| 495 |
+
# Aggregated Content-Based
|
| 496 |
+
AGGREGATION_METHOD = "weighted_mean" # "mean", "weighted_mean", "max"
|
| 497 |
+
TEMPORAL_DECAY_ALPHA = 1.0 # Controls recency weighting strength
|
| 498 |
+
HISTORY_LIMIT = 50 # Max items to consider for aggregation
|
| 499 |
+
```
|
| 500 |
+
|
| 501 |
+
---
|
| 502 |
+
|
| 503 |
+
## 🧪 Testing Framework
|
| 504 |
+
|
| 505 |
+
### Test Coverage
|
| 506 |
+
- **Unit Tests**: Individual algorithm components
|
| 507 |
+
- **Integration Tests**: End-to-end recommendation flow
|
| 508 |
+
- **Performance Tests**: Latency and throughput benchmarks
|
| 509 |
+
- **Accuracy Tests**: Category matching validation
|
| 510 |
+
|
| 511 |
+
### Analysis Tools
|
| 512 |
+
- `test_category_boosted.py`: Basic algorithm validation
|
| 513 |
+
- `test_enhanced_category_boosted.py`: Advanced subcategory testing
|
| 514 |
+
- `deep_analyze_category_boosted.py`: Comprehensive performance analysis
|
| 515 |
+
- **`analyze_recommendation_alignment.py`**: **NEW** - Multi-algorithm alignment analysis
|
| 516 |
+
- Tests all 4 algorithms (collaborative, content, hybrid, category_boosted)
|
| 517 |
+
- Category alignment scoring and coverage analysis
|
| 518 |
+
- Diversity vs relevance trade-off analysis
|
| 519 |
+
- User-specific algorithm performance comparison
|
| 520 |
+
- Generates comprehensive visualizations and reports
|
| 521 |
+
|
| 522 |
+
### Algorithm Comparison Metrics
|
| 523 |
+
- **Top-Level Alignment**: % of recommendations matching user's preferred categories
|
| 524 |
+
- **Subcategory Precision**: 2-level category matching accuracy
|
| 525 |
+
- **Coverage Score**: % of user's categories represented in recommendations
|
| 526 |
+
- **Diversity Score**: Shannon entropy of recommendation categories
|
| 527 |
+
- **Performance by Scale**: Algorithm behavior across 10-100+ recommendations
|
| 528 |
+
|
| 529 |
+
---
|
| 530 |
+
|
| 531 |
+
## π Future Enhancements
|
| 532 |
+
|
| 533 |
+
### Planned Features
|
| 534 |
+
1. **Real-time Learning**: Online model updates
|
| 535 |
+
2. **A/B Testing**: Algorithm comparison framework
|
| 536 |
+
3. **Explainability**: Recommendation reasoning
|
| 537 |
+
4. **Multi-objective**: Balancing relevance, diversity, novelty
|
| 538 |
+
5. **Graph Neural Networks**: Advanced relationship modeling
|
| 539 |
+
|
| 540 |
+
### Technical Debt
|
| 541 |
+
- [ ] Add comprehensive error handling
|
| 542 |
+
- [ ] Implement request caching
|
| 543 |
+
- [ ] Add model versioning
|
| 544 |
+
- [ ] Create automated testing pipeline
|
| 545 |
+
- [ ] Add performance monitoring
|
| 546 |
+
|
| 547 |
+
---
|
| 548 |
+
|
| 549 |
+
This deep architecture documentation provides a comprehensive view of the RecSys-HP recommendation system, covering all layers from data storage to user interface, with detailed technical specifications and implementation details.
|
README.md
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
# Advanced Two-Tower Recommendation System
|
| 2 |
|
| 3 |
-
A production-ready recommendation system implementation using TensorFlow Recommenders with an enhanced two-tower architecture. This system provides personalized item recommendations through collaborative filtering, content-based filtering, and hybrid approaches, featuring categorical demographics, curriculum learning, and advanced training strategies.
|
| 4 |
|
| 5 |
## π― Project Overview
|
| 6 |
|
|
@@ -11,7 +11,7 @@ This recommendation system addresses the challenge of providing personalized ite
|
|
| 11 |
- **π§ Enhanced Two-Tower Architecture**: 128D embeddings with temperature scaling and contrastive learning
|
| 12 |
- **π Curriculum Learning**: Progressive training strategy for improved convergence
|
| 13 |
- **β‘ Real-time Inference**: Sub-100ms recommendation serving with FAISS indexing
|
| 14 |
-
- **π Multi-strategy Recommendations**: Collaborative, content-based, and hybrid approaches
|
| 15 |
- **πͺ Category-Aware Boosting**: Enhanced personalization through user preference alignment
|
| 16 |
- **π Interactive Similar Items**: Click-to-explore with 60/40 category-balanced discovery
|
| 17 |
- **π Comprehensive Analysis**: Quality metrics and performance evaluation tools
|
|
@@ -70,11 +70,22 @@ The system implements a sophisticated two-tower neural network architecture opti
|
|
| 70 |
- **Stage 3**: Complex cases (long history) - 67th+ percentile
|
| 71 |
- **Adaptive Learning Rates**: Decrease as stages progress for stability
|
| 72 |
|
| 73 |
-
### 5.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
- **Enhanced Hybrid Recommendations**: Category boosting based on user preferences
|
| 75 |
- **Category Alignment Analysis**: Measures personalization effectiveness
|
| 76 |
- **Diversity Controls**: Balanced category representation in recommendations
|
| 77 |
-
- **
|
|
|
|
| 78 |
|
| 79 |
## π Project Structure
|
| 80 |
|
|
@@ -134,10 +145,8 @@ RecSys-HP/
|
|
| 134 |
βββ π Training Scripts # Multiple training approaches
|
| 135 |
β βββ run_training_pipeline.py # Main training orchestration
|
| 136 |
β βββ run_2phase_training.py # 2-phase training approach
|
| 137 |
-
β
|
| 138 |
-
β βββ train_improved_model.py # Enhanced model training
|
| 139 |
β
|
| 140 |
-
βββ π analyze_recommendations.py # Recommendation quality analysis
|
| 141 |
βββ π requirements.txt # Python dependencies
|
| 142 |
```
|
| 143 |
|
|
@@ -478,24 +487,11 @@ python api_2phase.py
|
|
| 478 |
python api_joint.py
|
| 479 |
```
|
| 480 |
|
| 481 |
-
###
|
| 482 |
-
|
| 483 |
```bash
|
| 484 |
-
#
|
| 485 |
-
python analyze_recommendations.py
|
| 486 |
-
# β Generates recommendation_analysis_report.md + plots
|
| 487 |
-
|
| 488 |
-
# Test individual engines
|
| 489 |
-
python -m src.inference.enhanced_recommendation_engine_128d # 128D enhanced
|
| 490 |
-
python -m src.inference.enhanced_recommendation_engine # Standard enhanced
|
| 491 |
-
python -m src.inference.recommendation_engine # Basic engine
|
| 492 |
-
|
| 493 |
-
# Real user data utilities
|
| 494 |
python -m src.utils.real_user_selector # Demo real user extraction
|
| 495 |
-
|
| 496 |
-
# Data processing utilities
|
| 497 |
-
python -m src.preprocessing.data_loader
|
| 498 |
-
python -m src.preprocessing.optimized_dataset_creator
|
| 499 |
```
|
| 500 |
|
| 501 |
### π§ͺ Frontend Development
|
|
@@ -667,7 +663,8 @@ RecSys-HP/
|
|
| 667 |
β
**Enhanced Architecture**: 128D embeddings, temperature scaling, contrastive learning
|
| 668 |
β
**Curriculum Learning**: Progressive training for better convergence
|
| 669 |
β
**Category-Aware Recommendations**: Intelligent personalization with diversity
|
| 670 |
-
β
**
|
|
|
|
| 671 |
β
**Production Ready**: Scalable API with enhanced frontend features
|
| 672 |
|
| 673 |
**π Ready to deliver next-generation personalized recommendations!**
|
|
@@ -687,11 +684,12 @@ This project provides multiple training strategies:
|
|
| 687 |
- **2-Phase API** (`api_2phase.py`) - Specialized for 2-phase training
|
| 688 |
- **Joint API** (`api_joint.py`) - Optimized for joint training approach
|
| 689 |
|
| 690 |
-
##
|
| 691 |
|
| 692 |
-
- **
|
|
|
|
| 693 |
|
| 694 |
-
##
|
| 695 |
|
| 696 |
### Frontend Development
|
| 697 |
```bash
|
|
|
|
| 1 |
# Advanced Two-Tower Recommendation System
|
| 2 |
|
| 3 |
+
A production-ready recommendation system implementation using TensorFlow Recommenders with an enhanced two-tower architecture. This system provides personalized item recommendations through collaborative filtering, **aggregated history content-based filtering**, and hybrid approaches, featuring categorical demographics, curriculum learning, and advanced training strategies.
|
| 4 |
|
| 5 |
## π― Project Overview
|
| 6 |
|
|
|
|
| 11 |
- **π§ Enhanced Two-Tower Architecture**: 128D embeddings with temperature scaling and contrastive learning
|
| 12 |
- **π Curriculum Learning**: Progressive training strategy for improved convergence
|
| 13 |
- **β‘ Real-time Inference**: Sub-100ms recommendation serving with FAISS indexing
|
| 14 |
+
- **π Multi-strategy Recommendations**: Collaborative, **aggregated history content-based**, and hybrid approaches
|
| 15 |
- **πͺ Category-Aware Boosting**: Enhanced personalization through user preference alignment
|
| 16 |
- **π Interactive Similar Items**: Click-to-explore with 60/40 category-balanced discovery
|
| 17 |
- **π Comprehensive Analysis**: Quality metrics and performance evaluation tools
|
|
|
|
| 70 |
- **Stage 3**: Complex cases (long history) - 67th+ percentile
|
| 71 |
- **Adaptive Learning Rates**: Decrease as stages progress for stability
|
| 72 |
|
| 73 |
+
### 5. Aggregated History Content-Based Filtering π
|
| 74 |
+
- **Revolutionary Approach**: Uses aggregated user interaction history instead of single-item similarity
|
| 75 |
+
- **Multiple Aggregation Methods**:
|
| 76 |
+
- **Weighted Mean**: Recent interactions weighted higher (exponential decay)
|
| 77 |
+
- **Simple Mean**: Equal weighting of all interactions
|
| 78 |
+
- **Max Pooling**: Element-wise maximum of embeddings
|
| 79 |
+
- **ANN Search**: Direct similarity search using FAISS with aggregated user profile
|
| 80 |
+
- **Enhanced Personalization**: Captures complete user preference profile, not just recent item
|
| 81 |
+
- **Category-Aware**: Analyzes user's full category distribution for balanced recommendations
|
| 82 |
+
|
| 83 |
+
### 6. Category-Aware Recommendation Engine πͺ
|
| 84 |
- **Enhanced Hybrid Recommendations**: Category boosting based on user preferences
|
| 85 |
- **Category Alignment Analysis**: Measures personalization effectiveness
|
| 86 |
- **Diversity Controls**: Balanced category representation in recommendations
|
| 87 |
+
- **Subcategory Precision**: 2-level category matching (e.g., "computers.components")
|
| 88 |
+
- **Comprehensive Analysis Tools**: Multi-algorithm comparison and alignment scoring
|
| 89 |
|
| 90 |
## π Project Structure
|
| 91 |
|
|
|
|
| 145 |
βββ π Training Scripts # Multiple training approaches
|
| 146 |
β βββ run_training_pipeline.py # Main training orchestration
|
| 147 |
β βββ run_2phase_training.py # 2-phase training approach
|
| 148 |
+
β βββ run_joint_training.py # Joint training approach
|
|
|
|
| 149 |
β
|
|
|
|
| 150 |
βββ π requirements.txt # Python dependencies
|
| 151 |
```
|
| 152 |
|
|
|
|
| 487 |
python api_joint.py
|
| 488 |
```
|
| 489 |
|
| 490 |
+
### π§ System Testing
|
|
|
|
| 491 |
```bash
|
| 492 |
+
# Test core system components
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 493 |
python -m src.utils.real_user_selector # Demo real user extraction
|
| 494 |
+
python -m src.preprocessing.data_loader # Verify data loading
|
|
|
|
|
|
|
|
|
|
| 495 |
```
|
| 496 |
|
| 497 |
### π§ͺ Frontend Development
|
|
|
|
| 663 |
β
**Enhanced Architecture**: 128D embeddings, temperature scaling, contrastive learning
|
| 664 |
β
**Curriculum Learning**: Progressive training for better convergence
|
| 665 |
β
**Category-Aware Recommendations**: Intelligent personalization with diversity
|
| 666 |
+
β
**Aggregated Content-Based Filtering**: Revolutionary user history aggregation approach
|
| 667 |
+
β
**Enhanced Demographic Support**: Improved cold-start user handling
|
| 668 |
β
**Production Ready**: Scalable API with enhanced frontend features
|
| 669 |
|
| 670 |
**π Ready to deliver next-generation personalized recommendations!**
|
|
|
|
| 684 |
- **2-Phase API** (`api_2phase.py`) - Specialized for 2-phase training
|
| 685 |
- **Joint API** (`api_joint.py`) - Optimized for joint training approach
|
| 686 |
|
| 687 |
+
## π§ Development Tools
|
| 688 |
|
| 689 |
+
- **Real User Selection** (`src.utils.real_user_selector`) - Extract real user profiles for testing
|
| 690 |
+
- **Data Loading Utilities** (`src.preprocessing.data_loader`) - Dataset loading and validation
|
| 691 |
|
| 692 |
+
## π§ͺ Development & Testing
|
| 693 |
|
| 694 |
### Frontend Development
|
| 695 |
```bash
|
analyze_recommendations.py
DELETED
|
@@ -1,543 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""
|
| 3 |
-
Recommendation Analysis Script
|
| 4 |
-
|
| 5 |
-
This script compares recommendations from both training approaches:
|
| 6 |
-
1. 2-phase training (pre-trained item tower + joint fine-tuning)
|
| 7 |
-
2. Single joint training (end-to-end optimization)
|
| 8 |
-
|
| 9 |
-
It analyzes:
|
| 10 |
-
- Category alignment between user interactions and recommendations
|
| 11 |
-
- Diversity of recommended categories
|
| 12 |
-
- Overlap between the two approaches
|
| 13 |
-
- Performance on real users
|
| 14 |
-
|
| 15 |
-
Usage:
|
| 16 |
-
python analyze_recommendations.py
|
| 17 |
-
"""
|
| 18 |
-
|
| 19 |
-
import os
|
| 20 |
-
import sys
|
| 21 |
-
import numpy as np
|
| 22 |
-
import pandas as pd
|
| 23 |
-
from collections import defaultdict, Counter
|
| 24 |
-
from typing import Dict, List, Tuple
|
| 25 |
-
import matplotlib.pyplot as plt
|
| 26 |
-
import seaborn as sns
|
| 27 |
-
|
| 28 |
-
# Add src to path
|
| 29 |
-
sys.path.append('src')
|
| 30 |
-
|
| 31 |
-
from src.inference.recommendation_engine import RecommendationEngine
|
| 32 |
-
from src.utils.real_user_selector import RealUserSelector
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
class RecommendationAnalyzer:
|
| 36 |
-
"""Analyzer for comparing different recommendation approaches."""
|
| 37 |
-
|
| 38 |
-
def __init__(self):
|
| 39 |
-
self.recommendation_engine = None
|
| 40 |
-
self.real_user_selector = None
|
| 41 |
-
self.items_df = None
|
| 42 |
-
self.setup_engines()
|
| 43 |
-
|
| 44 |
-
def setup_engines(self):
|
| 45 |
-
"""Setup recommendation engines and data."""
|
| 46 |
-
print("Loading recommendation engines...")
|
| 47 |
-
|
| 48 |
-
try:
|
| 49 |
-
# Load recommendation engine (assumes trained model artifacts exist)
|
| 50 |
-
self.recommendation_engine = RecommendationEngine()
|
| 51 |
-
print("β
Recommendation engine loaded")
|
| 52 |
-
except Exception as e:
|
| 53 |
-
print(f"β Error loading recommendation engine: {e}")
|
| 54 |
-
return
|
| 55 |
-
|
| 56 |
-
try:
|
| 57 |
-
# Load real user selector
|
| 58 |
-
self.real_user_selector = RealUserSelector()
|
| 59 |
-
print("β
Real user selector loaded")
|
| 60 |
-
except Exception as e:
|
| 61 |
-
print(f"β Error loading real user selector: {e}")
|
| 62 |
-
|
| 63 |
-
# Load items data for category analysis
|
| 64 |
-
self.items_df = pd.read_csv("datasets/items.csv")
|
| 65 |
-
print(f"β
Loaded {len(self.items_df)} items")
|
| 66 |
-
|
| 67 |
-
def get_item_categories(self, item_ids: List[int]) -> List[str]:
|
| 68 |
-
"""Get category codes for given item IDs."""
|
| 69 |
-
categories = []
|
| 70 |
-
for item_id in item_ids:
|
| 71 |
-
item_row = self.items_df[self.items_df['product_id'] == item_id]
|
| 72 |
-
if len(item_row) > 0:
|
| 73 |
-
categories.append(item_row.iloc[0]['category_code'])
|
| 74 |
-
else:
|
| 75 |
-
categories.append('unknown')
|
| 76 |
-
return categories
|
| 77 |
-
|
| 78 |
-
def analyze_user_recommendations(self,
|
| 79 |
-
user_profile: Dict,
|
| 80 |
-
recommendation_types: List[str] = None) -> Dict:
|
| 81 |
-
"""Analyze recommendations for a single user across different approaches."""
|
| 82 |
-
|
| 83 |
-
if recommendation_types is None:
|
| 84 |
-
recommendation_types = ['collaborative', 'hybrid', 'content']
|
| 85 |
-
|
| 86 |
-
results = {
|
| 87 |
-
'user_profile': user_profile,
|
| 88 |
-
'interaction_categories': [],
|
| 89 |
-
'recommendations': {},
|
| 90 |
-
'category_analysis': {}
|
| 91 |
-
}
|
| 92 |
-
|
| 93 |
-
# Get categories from user's interaction history
|
| 94 |
-
if user_profile['interaction_history']:
|
| 95 |
-
results['interaction_categories'] = self.get_item_categories(
|
| 96 |
-
user_profile['interaction_history']
|
| 97 |
-
)
|
| 98 |
-
|
| 99 |
-
# Get recommendations for each type
|
| 100 |
-
for rec_type in recommendation_types:
|
| 101 |
-
try:
|
| 102 |
-
if rec_type == 'collaborative':
|
| 103 |
-
recs = self.recommendation_engine.recommend_items_collaborative(
|
| 104 |
-
age=user_profile['age'],
|
| 105 |
-
gender=user_profile['gender'],
|
| 106 |
-
income=user_profile['income'],
|
| 107 |
-
interaction_history=user_profile['interaction_history'],
|
| 108 |
-
k=10
|
| 109 |
-
)
|
| 110 |
-
elif rec_type == 'hybrid':
|
| 111 |
-
recs = self.recommendation_engine.recommend_items_hybrid(
|
| 112 |
-
age=user_profile['age'],
|
| 113 |
-
gender=user_profile['gender'],
|
| 114 |
-
income=user_profile['income'],
|
| 115 |
-
interaction_history=user_profile['interaction_history'],
|
| 116 |
-
k=10
|
| 117 |
-
)
|
| 118 |
-
elif rec_type == 'content' and user_profile['interaction_history']:
|
| 119 |
-
recs = self.recommendation_engine.recommend_items_content_based(
|
| 120 |
-
seed_item_id=user_profile['interaction_history'][-1],
|
| 121 |
-
k=10
|
| 122 |
-
)
|
| 123 |
-
else:
|
| 124 |
-
continue
|
| 125 |
-
|
| 126 |
-
# Extract item IDs and categories
|
| 127 |
-
item_ids = [item_id for item_id, score, info in recs]
|
| 128 |
-
rec_categories = self.get_item_categories(item_ids)
|
| 129 |
-
|
| 130 |
-
results['recommendations'][rec_type] = {
|
| 131 |
-
'items': recs,
|
| 132 |
-
'item_ids': item_ids,
|
| 133 |
-
'categories': rec_categories,
|
| 134 |
-
'scores': [score for item_id, score, info in recs]
|
| 135 |
-
}
|
| 136 |
-
|
| 137 |
-
# Analyze category alignment
|
| 138 |
-
results['category_analysis'][rec_type] = self.analyze_category_alignment(
|
| 139 |
-
results['interaction_categories'],
|
| 140 |
-
rec_categories
|
| 141 |
-
)
|
| 142 |
-
|
| 143 |
-
except Exception as e:
|
| 144 |
-
print(f"Error generating {rec_type} recommendations: {e}")
|
| 145 |
-
|
| 146 |
-
return results
|
| 147 |
-
|
| 148 |
-
def analyze_category_alignment(self,
|
| 149 |
-
interaction_categories: List[str],
|
| 150 |
-
recommendation_categories: List[str]) -> Dict:
|
| 151 |
-
"""Analyze alignment between interaction and recommendation categories."""
|
| 152 |
-
|
| 153 |
-
if not interaction_categories:
|
| 154 |
-
return {
|
| 155 |
-
'overlap_ratio': 0.0,
|
| 156 |
-
'unique_interaction_categories': 0,
|
| 157 |
-
'unique_recommendation_categories': len(set(recommendation_categories)),
|
| 158 |
-
'common_categories': [],
|
| 159 |
-
'category_distribution': Counter(recommendation_categories)
|
| 160 |
-
}
|
| 161 |
-
|
| 162 |
-
interaction_set = set(interaction_categories)
|
| 163 |
-
recommendation_set = set(recommendation_categories)
|
| 164 |
-
|
| 165 |
-
common_categories = interaction_set.intersection(recommendation_set)
|
| 166 |
-
overlap_ratio = len(common_categories) / len(interaction_set) if interaction_set else 0.0
|
| 167 |
-
|
| 168 |
-
return {
|
| 169 |
-
'overlap_ratio': overlap_ratio,
|
| 170 |
-
'unique_interaction_categories': len(interaction_set),
|
| 171 |
-
'unique_recommendation_categories': len(recommendation_set),
|
| 172 |
-
'common_categories': list(common_categories),
|
| 173 |
-
'category_distribution': Counter(recommendation_categories),
|
| 174 |
-
'interaction_category_distribution': Counter(interaction_categories)
|
| 175 |
-
}
|
| 176 |
-
|
| 177 |
-
def compare_recommendation_approaches(self,
|
| 178 |
-
users_sample: List[Dict],
|
| 179 |
-
approaches: List[str] = None) -> Dict:
|
| 180 |
-
"""Compare different recommendation approaches across multiple users."""
|
| 181 |
-
|
| 182 |
-
if approaches is None:
|
| 183 |
-
approaches = ['collaborative', 'hybrid', 'content']
|
| 184 |
-
|
| 185 |
-
comparison_results = {
|
| 186 |
-
'approach_stats': defaultdict(list),
|
| 187 |
-
'cross_approach_analysis': {},
|
| 188 |
-
'user_results': []
|
| 189 |
-
}
|
| 190 |
-
|
| 191 |
-
print(f"Analyzing {len(users_sample)} users across {len(approaches)} approaches...")
|
| 192 |
-
|
| 193 |
-
for i, user in enumerate(users_sample):
|
| 194 |
-
print(f"Analyzing user {i+1}/{len(users_sample)}...")
|
| 195 |
-
|
| 196 |
-
user_results = self.analyze_user_recommendations(user, approaches)
|
| 197 |
-
comparison_results['user_results'].append(user_results)
|
| 198 |
-
|
| 199 |
-
# Aggregate stats by approach
|
| 200 |
-
for approach in approaches:
|
| 201 |
-
if approach in user_results['category_analysis']:
|
| 202 |
-
analysis = user_results['category_analysis'][approach]
|
| 203 |
-
comparison_results['approach_stats'][approach].append({
|
| 204 |
-
'overlap_ratio': analysis['overlap_ratio'],
|
| 205 |
-
'unique_rec_categories': analysis['unique_recommendation_categories'],
|
| 206 |
-
'common_categories_count': len(analysis['common_categories'])
|
| 207 |
-
})
|
| 208 |
-
|
| 209 |
-
# Calculate aggregate statistics
|
| 210 |
-
for approach in approaches:
|
| 211 |
-
stats = comparison_results['approach_stats'][approach]
|
| 212 |
-
if stats:
|
| 213 |
-
comparison_results['approach_stats'][approach] = {
|
| 214 |
-
'avg_overlap_ratio': np.mean([s['overlap_ratio'] for s in stats]),
|
| 215 |
-
'std_overlap_ratio': np.std([s['overlap_ratio'] for s in stats]),
|
| 216 |
-
'avg_unique_categories': np.mean([s['unique_rec_categories'] for s in stats]),
|
| 217 |
-
'avg_common_categories': np.mean([s['common_categories_count'] for s in stats]),
|
| 218 |
-
'total_users': len(stats)
|
| 219 |
-
}
|
| 220 |
-
|
| 221 |
-
# Cross-approach analysis
|
| 222 |
-
comparison_results['cross_approach_analysis'] = self.cross_approach_analysis(
|
| 223 |
-
comparison_results['user_results'], approaches
|
| 224 |
-
)
|
| 225 |
-
|
| 226 |
-
return comparison_results
|
| 227 |
-
|
| 228 |
-
def cross_approach_analysis(self, user_results: List[Dict], approaches: List[str]) -> Dict:
|
| 229 |
-
"""Analyze similarities and differences between approaches."""
|
| 230 |
-
|
| 231 |
-
cross_analysis = {
|
| 232 |
-
'item_overlap': defaultdict(dict),
|
| 233 |
-
'category_overlap': defaultdict(dict),
|
| 234 |
-
'score_correlation': defaultdict(dict)
|
| 235 |
-
}
|
| 236 |
-
|
| 237 |
-
for user_result in user_results:
|
| 238 |
-
recommendations = user_result['recommendations']
|
| 239 |
-
|
| 240 |
-
# Compare each pair of approaches
|
| 241 |
-
for i, approach1 in enumerate(approaches):
|
| 242 |
-
for approach2 in approaches[i+1:]:
|
| 243 |
-
if approach1 in recommendations and approach2 in recommendations:
|
| 244 |
-
|
| 245 |
-
# Item overlap
|
| 246 |
-
items1 = set(recommendations[approach1]['item_ids'])
|
| 247 |
-
items2 = set(recommendations[approach2]['item_ids'])
|
| 248 |
-
item_overlap_ratio = len(items1.intersection(items2)) / len(items1.union(items2))
|
| 249 |
-
|
| 250 |
-
# Category overlap
|
| 251 |
-
cats1 = set(recommendations[approach1]['categories'])
|
| 252 |
-
cats2 = set(recommendations[approach2]['categories'])
|
| 253 |
-
cat_overlap_ratio = len(cats1.intersection(cats2)) / len(cats1.union(cats2)) if cats1.union(cats2) else 0
|
| 254 |
-
|
| 255 |
-
# Store results
|
| 256 |
-
pair_key = f"{approach1}_vs_{approach2}"
|
| 257 |
-
if pair_key not in cross_analysis['item_overlap']:
|
| 258 |
-
cross_analysis['item_overlap'][pair_key] = []
|
| 259 |
-
cross_analysis['category_overlap'][pair_key] = []
|
| 260 |
-
|
| 261 |
-
cross_analysis['item_overlap'][pair_key].append(item_overlap_ratio)
|
| 262 |
-
cross_analysis['category_overlap'][pair_key].append(cat_overlap_ratio)
|
| 263 |
-
|
| 264 |
-
# Calculate averages
|
| 265 |
-
for pair_key in cross_analysis['item_overlap']:
|
| 266 |
-
cross_analysis['item_overlap'][pair_key] = {
|
| 267 |
-
'avg': np.mean(cross_analysis['item_overlap'][pair_key]),
|
| 268 |
-
'std': np.std(cross_analysis['item_overlap'][pair_key])
|
| 269 |
-
}
|
| 270 |
-
cross_analysis['category_overlap'][pair_key] = {
|
| 271 |
-
'avg': np.mean(cross_analysis['category_overlap'][pair_key]),
|
| 272 |
-
'std': np.std(cross_analysis['category_overlap'][pair_key])
|
| 273 |
-
}
|
| 274 |
-
|
| 275 |
-
return cross_analysis
|
| 276 |
-
|
| 277 |
-
def generate_report(self, comparison_results: Dict, output_file: str = "recommendation_analysis_report.md"):
|
| 278 |
-
"""Generate a comprehensive analysis report."""
|
| 279 |
-
|
| 280 |
-
report = []
|
| 281 |
-
report.append("# Recommendation System Analysis Report")
|
| 282 |
-
report.append(f"Generated: {pd.Timestamp.now()}")
|
| 283 |
-
report.append("")
|
| 284 |
-
|
| 285 |
-
# Overall Statistics
|
| 286 |
-
report.append("## Overall Statistics")
|
| 287 |
-
report.append("")
|
| 288 |
-
|
| 289 |
-
for approach, stats in comparison_results['approach_stats'].items():
|
| 290 |
-
if isinstance(stats, dict):
|
| 291 |
-
report.append(f"### {approach.title()} Recommendations")
|
| 292 |
-
report.append(f"- **Average Category Overlap**: {stats['avg_overlap_ratio']:.3f} Β± {stats['std_overlap_ratio']:.3f}")
|
| 293 |
-
report.append(f"- **Average Unique Categories per User**: {stats['avg_unique_categories']:.1f}")
|
| 294 |
-
report.append(f"- **Average Common Categories**: {stats['avg_common_categories']:.1f}")
|
| 295 |
-
report.append(f"- **Users Analyzed**: {stats['total_users']}")
|
| 296 |
-
report.append("")
|
| 297 |
-
|
| 298 |
-
# Cross-Approach Analysis
|
| 299 |
-
report.append("## Cross-Approach Comparison")
|
| 300 |
-
report.append("")
|
| 301 |
-
|
| 302 |
-
cross_analysis = comparison_results['cross_approach_analysis']
|
| 303 |
-
|
| 304 |
-
report.append("### Item Overlap Between Approaches")
|
| 305 |
-
for pair, overlap_stats in cross_analysis['item_overlap'].items():
|
| 306 |
-
report.append(f"- **{pair.replace('_', ' ').title()}**: {overlap_stats['avg']:.3f} Β± {overlap_stats['std']:.3f}")
|
| 307 |
-
report.append("")
|
| 308 |
-
|
| 309 |
-
report.append("### Category Overlap Between Approaches")
|
| 310 |
-
for pair, overlap_stats in cross_analysis['category_overlap'].items():
|
| 311 |
-
report.append(f"- **{pair.replace('_', ' ').title()}**: {overlap_stats['avg']:.3f} Β± {overlap_stats['std']:.3f}")
|
| 312 |
-
report.append("")
|
| 313 |
-
|
| 314 |
-
# Category Alignment Analysis
|
| 315 |
-
report.append("## Category Alignment Analysis")
|
| 316 |
-
report.append("")
|
| 317 |
-
report.append("Category alignment measures how well recommendations match the categories")
|
| 318 |
-
report.append("of items users have previously interacted with.")
|
| 319 |
-
report.append("")
|
| 320 |
-
|
| 321 |
-
# Find best performing approach
|
| 322 |
-
best_approach = max(
|
| 323 |
-
comparison_results['approach_stats'].keys(),
|
| 324 |
-
key=lambda k: comparison_results['approach_stats'][k]['avg_overlap_ratio']
|
| 325 |
-
if isinstance(comparison_results['approach_stats'][k], dict) else 0
|
| 326 |
-
)
|
| 327 |
-
|
| 328 |
-
report.append(f"**Best Category Alignment**: {best_approach.title()} approach")
|
| 329 |
-
report.append("")
|
| 330 |
-
|
| 331 |
-
# Recommendations
|
| 332 |
-
report.append("## Key Findings & Recommendations")
|
| 333 |
-
report.append("")
|
| 334 |
-
|
| 335 |
-
# Analyze overlap ratios to provide insights
|
| 336 |
-
overlap_ratios = {
|
| 337 |
-
k: v['avg_overlap_ratio'] for k, v in comparison_results['approach_stats'].items()
|
| 338 |
-
if isinstance(v, dict)
|
| 339 |
-
}
|
| 340 |
-
|
| 341 |
-
if overlap_ratios:
|
| 342 |
-
avg_overlap = np.mean(list(overlap_ratios.values()))
|
| 343 |
-
if avg_overlap > 0.5:
|
| 344 |
-
report.append("β
**Strong Category Alignment**: Recommendations show good alignment with user interaction patterns.")
|
| 345 |
-
elif avg_overlap > 0.3:
|
| 346 |
-
report.append("β οΈ **Moderate Category Alignment**: Some alignment present but room for improvement.")
|
| 347 |
-
else:
|
| 348 |
-
report.append("β **Weak Category Alignment**: Recommendations may be too diverse or not well-aligned with user preferences.")
|
| 349 |
-
|
| 350 |
-
report.append("")
|
| 351 |
-
|
| 352 |
-
# Compare approaches
|
| 353 |
-
if len(overlap_ratios) > 1:
|
| 354 |
-
sorted_approaches = sorted(overlap_ratios.items(), key=lambda x: x[1], reverse=True)
|
| 355 |
-
report.append("### Approach Rankings (by category alignment):")
|
| 356 |
-
for i, (approach, ratio) in enumerate(sorted_approaches, 1):
|
| 357 |
-
report.append(f"{i}. **{approach.title()}**: {ratio:.3f}")
|
| 358 |
-
report.append("")
|
| 359 |
-
|
| 360 |
-
# Write report
|
| 361 |
-
with open(output_file, 'w') as f:
|
| 362 |
-
f.write('\n'.join(report))
|
| 363 |
-
|
| 364 |
-
print(f"β
Analysis report saved to: {output_file}")
|
| 365 |
-
return '\n'.join(report)
|
| 366 |
-
|
| 367 |
-
def visualize_results(self, comparison_results: Dict, save_plots: bool = True):
|
| 368 |
-
"""Create visualizations for the analysis results."""
|
| 369 |
-
|
| 370 |
-
# Set up plotting style
|
| 371 |
-
plt.style.use('default')
|
| 372 |
-
sns.set_palette("husl")
|
| 373 |
-
|
| 374 |
-
# Create figure with subplots
|
| 375 |
-
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
|
| 376 |
-
fig.suptitle('Recommendation System Analysis', fontsize=16, fontweight='bold')
|
| 377 |
-
|
| 378 |
-
# 1. Category Overlap by Approach
|
| 379 |
-
ax1 = axes[0, 0]
|
| 380 |
-
approaches = []
|
| 381 |
-
overlap_means = []
|
| 382 |
-
overlap_stds = []
|
| 383 |
-
|
| 384 |
-
for approach, stats in comparison_results['approach_stats'].items():
|
| 385 |
-
if isinstance(stats, dict):
|
| 386 |
-
approaches.append(approach.title())
|
| 387 |
-
overlap_means.append(stats['avg_overlap_ratio'])
|
| 388 |
-
overlap_stds.append(stats['std_overlap_ratio'])
|
| 389 |
-
|
| 390 |
-
bars1 = ax1.bar(approaches, overlap_means, yerr=overlap_stds, capsize=5, alpha=0.7)
|
| 391 |
-
ax1.set_title('Average Category Overlap by Approach')
|
| 392 |
-
ax1.set_ylabel('Category Overlap Ratio')
|
| 393 |
-
ax1.set_ylim(0, 1)
|
| 394 |
-
|
| 395 |
-
# Add value labels on bars
|
| 396 |
-
for bar, mean in zip(bars1, overlap_means):
|
| 397 |
-
height = bar.get_height()
|
| 398 |
-
ax1.text(bar.get_x() + bar.get_width()/2., height + 0.01,
|
| 399 |
-
f'{mean:.3f}', ha='center', va='bottom')
|
| 400 |
-
|
| 401 |
-
# 2. Cross-Approach Item Overlap
|
| 402 |
-
ax2 = axes[0, 1]
|
| 403 |
-
cross_analysis = comparison_results['cross_approach_analysis']
|
| 404 |
-
|
| 405 |
-
pair_names = []
|
| 406 |
-
item_overlaps = []
|
| 407 |
-
|
| 408 |
-
for pair, overlap_stats in cross_analysis['item_overlap'].items():
|
| 409 |
-
pair_names.append(pair.replace('_vs_', ' vs ').title())
|
| 410 |
-
item_overlaps.append(overlap_stats['avg'])
|
| 411 |
-
|
| 412 |
-
if pair_names:
|
| 413 |
-
bars2 = ax2.bar(pair_names, item_overlaps, alpha=0.7, color='coral')
|
| 414 |
-
ax2.set_title('Item Overlap Between Approaches')
|
| 415 |
-
ax2.set_ylabel('Item Overlap Ratio')
|
| 416 |
-
ax2.set_ylim(0, 1)
|
| 417 |
-
plt.setp(ax2.get_xticklabels(), rotation=45, ha='right')
|
| 418 |
-
|
| 419 |
-
# Add value labels
|
| 420 |
-
for bar, overlap in zip(bars2, item_overlaps):
|
| 421 |
-
height = bar.get_height()
|
| 422 |
-
ax2.text(bar.get_x() + bar.get_width()/2., height + 0.01,
|
| 423 |
-
f'{overlap:.3f}', ha='center', va='bottom')
|
| 424 |
-
|
| 425 |
-
# 3. Category Diversity
|
| 426 |
-
ax3 = axes[1, 0]
|
| 427 |
-
unique_categories = []
|
| 428 |
-
for approach, stats in comparison_results['approach_stats'].items():
|
| 429 |
-
if isinstance(stats, dict):
|
| 430 |
-
unique_categories.append(stats['avg_unique_categories'])
|
| 431 |
-
|
| 432 |
-
bars3 = ax3.bar(approaches, unique_categories, alpha=0.7, color='lightgreen')
|
| 433 |
-
ax3.set_title('Average Unique Categories per Recommendation')
|
| 434 |
-
ax3.set_ylabel('Number of Unique Categories')
|
| 435 |
-
|
| 436 |
-
for bar, cats in zip(bars3, unique_categories):
|
| 437 |
-
height = bar.get_height()
|
| 438 |
-
ax3.text(bar.get_x() + bar.get_width()/2., height + 0.1,
|
| 439 |
-
f'{cats:.1f}', ha='center', va='bottom')
|
| 440 |
-
|
| 441 |
-
# 4. Category vs Item Overlap Comparison
|
| 442 |
-
ax4 = axes[1, 1]
|
| 443 |
-
|
| 444 |
-
if cross_analysis['item_overlap'] and cross_analysis['category_overlap']:
|
| 445 |
-
pairs = list(cross_analysis['item_overlap'].keys())
|
| 446 |
-
item_overlaps = [cross_analysis['item_overlap'][p]['avg'] for p in pairs]
|
| 447 |
-
cat_overlaps = [cross_analysis['category_overlap'][p]['avg'] for p in pairs]
|
| 448 |
-
|
| 449 |
-
x = np.arange(len(pairs))
|
| 450 |
-
width = 0.35
|
| 451 |
-
|
| 452 |
-
bars4a = ax4.bar(x - width/2, item_overlaps, width, label='Item Overlap', alpha=0.7)
|
| 453 |
-
bars4b = ax4.bar(x + width/2, cat_overlaps, width, label='Category Overlap', alpha=0.7)
|
| 454 |
-
|
| 455 |
-
ax4.set_title('Item vs Category Overlap Between Approaches')
|
| 456 |
-
ax4.set_ylabel('Overlap Ratio')
|
| 457 |
-
ax4.set_xticks(x)
|
| 458 |
-
ax4.set_xticklabels([p.replace('_vs_', ' vs ') for p in pairs], rotation=45, ha='right')
|
| 459 |
-
ax4.legend()
|
| 460 |
-
ax4.set_ylim(0, 1)
|
| 461 |
-
|
| 462 |
-
plt.tight_layout()
|
| 463 |
-
|
| 464 |
-
if save_plots:
|
| 465 |
-
plt.savefig('recommendation_analysis_plots.png', dpi=300, bbox_inches='tight')
|
| 466 |
-
print("β
Plots saved to: recommendation_analysis_plots.png")
|
| 467 |
-
|
| 468 |
-
plt.show()
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
def main():
|
| 472 |
-
"""Main function to run the recommendation analysis."""
|
| 473 |
-
|
| 474 |
-
print("π Starting Recommendation Analysis...")
|
| 475 |
-
print("=" * 50)
|
| 476 |
-
|
| 477 |
-
# Initialize analyzer
|
| 478 |
-
analyzer = RecommendationAnalyzer()
|
| 479 |
-
|
| 480 |
-
if analyzer.recommendation_engine is None:
|
| 481 |
-
print("β Cannot proceed without recommendation engine. Please ensure model is trained.")
|
| 482 |
-
return
|
| 483 |
-
|
| 484 |
-
# Get sample of real users for analysis
|
| 485 |
-
print("Getting real user sample...")
|
| 486 |
-
try:
|
| 487 |
-
real_users = analyzer.real_user_selector.get_real_users(n=20, min_interactions=3)
|
| 488 |
-
print(f"β
Loaded {len(real_users)} real users for analysis")
|
| 489 |
-
except Exception as e:
|
| 490 |
-
print(f"β Error loading real users: {e}")
|
| 491 |
-
# Fallback to synthetic users
|
| 492 |
-
real_users = [
|
| 493 |
-
{
|
| 494 |
-
'age': 32, 'gender': 'male', 'income': 75000,
|
| 495 |
-
'interaction_history': [1000978, 1001588, 1001618, 1002039]
|
| 496 |
-
},
|
| 497 |
-
{
|
| 498 |
-
'age': 28, 'gender': 'female', 'income': 45000,
|
| 499 |
-
'interaction_history': [1003456, 1004567, 1005678]
|
| 500 |
-
},
|
| 501 |
-
{
|
| 502 |
-
'age': 45, 'gender': 'male', 'income': 85000,
|
| 503 |
-
'interaction_history': [1006789, 1007890, 1008901, 1009012, 1010123]
|
| 504 |
-
}
|
| 505 |
-
]
|
| 506 |
-
print(f"Using {len(real_users)} synthetic users for analysis")
|
| 507 |
-
|
| 508 |
-
# Run comprehensive analysis
|
| 509 |
-
print("Running recommendation analysis...")
|
| 510 |
-
approaches = ['collaborative', 'hybrid', 'content']
|
| 511 |
-
|
| 512 |
-
comparison_results = analyzer.compare_recommendation_approaches(
|
| 513 |
-
users_sample=real_users,
|
| 514 |
-
approaches=approaches
|
| 515 |
-
)
|
| 516 |
-
|
| 517 |
-
# Generate report
|
| 518 |
-
print("Generating analysis report...")
|
| 519 |
-
report = analyzer.generate_report(comparison_results)
|
| 520 |
-
|
| 521 |
-
# Create visualizations
|
| 522 |
-
print("Creating visualizations...")
|
| 523 |
-
try:
|
| 524 |
-
analyzer.visualize_results(comparison_results, save_plots=True)
|
| 525 |
-
except Exception as e:
|
| 526 |
-
print(f"Warning: Could not create visualizations: {e}")
|
| 527 |
-
|
| 528 |
-
# Print summary
|
| 529 |
-
print("\n" + "=" * 50)
|
| 530 |
-
print("π ANALYSIS SUMMARY")
|
| 531 |
-
print("=" * 50)
|
| 532 |
-
|
| 533 |
-
for approach, stats in comparison_results['approach_stats'].items():
|
| 534 |
-
if isinstance(stats, dict):
|
| 535 |
-
print(f"{approach.title()}: {stats['avg_overlap_ratio']:.3f} avg category overlap")
|
| 536 |
-
|
| 537 |
-
print(f"\nβ
Analysis complete! Check:")
|
| 538 |
-
print(" π recommendation_analysis_report.md")
|
| 539 |
-
print(" π recommendation_analysis_plots.png")
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
if __name__ == "__main__":
|
| 543 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
api/main.py
CHANGED
|
@@ -33,7 +33,6 @@ app.add_middleware(
|
|
| 33 |
|
| 34 |
# Global instances
|
| 35 |
recommendation_engine = None
|
| 36 |
-
enhanced_recommendation_engine = None
|
| 37 |
real_user_selector = None
|
| 38 |
|
| 39 |
|
|
@@ -75,13 +74,17 @@ class UserProfile(BaseModel):
|
|
| 75 |
age: int
|
| 76 |
gender: str # "male" or "female"
|
| 77 |
income: float
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
interaction_history: Optional[List[int]] = []
|
| 79 |
|
| 80 |
|
| 81 |
class RecommendationRequest(BaseModel):
|
| 82 |
user_profile: UserProfile
|
| 83 |
num_recommendations: int = 10
|
| 84 |
-
recommendation_type: str = "hybrid" # "collaborative", "content"
|
| 85 |
collaborative_weight: Optional[float] = 0.7
|
| 86 |
category_boost: Optional[float] = 1.5 # For enhanced recommendations
|
| 87 |
enable_category_boost: Optional[bool] = True
|
|
@@ -132,6 +135,10 @@ class RealUserProfile(BaseModel):
|
|
| 132 |
age: int
|
| 133 |
gender: str
|
| 134 |
income: int
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
interaction_history: List[int]
|
| 136 |
interaction_stats: Dict[str, int]
|
| 137 |
interaction_pattern: str
|
|
@@ -169,39 +176,24 @@ class EnrichedBehavioralPatternsResponse(BaseModel):
|
|
| 169 |
|
| 170 |
@app.on_event("startup")
|
| 171 |
async def startup_event():
|
| 172 |
-
"""Initialize the recommendation
|
| 173 |
-
global recommendation_engine,
|
| 174 |
|
| 175 |
try:
|
| 176 |
-
print("Loading recommendation engine...")
|
| 177 |
recommendation_engine = RecommendationEngine()
|
| 178 |
-
print("Recommendation engine loaded successfully!")
|
|
|
|
| 179 |
except Exception as e:
|
| 180 |
-
print(f"Error loading recommendation engine: {e}")
|
| 181 |
recommendation_engine = None
|
| 182 |
|
| 183 |
-
try:
|
| 184 |
-
print("Loading enhanced recommendation engine...")
|
| 185 |
-
# Try enhanced 128D engine first, fallback to regular enhanced
|
| 186 |
-
try:
|
| 187 |
-
from src.inference.enhanced_recommendation_engine_128d import Enhanced128DRecommendationEngine
|
| 188 |
-
enhanced_recommendation_engine = Enhanced128DRecommendationEngine()
|
| 189 |
-
print("β
Using Enhanced 128D Recommendation Engine")
|
| 190 |
-
except:
|
| 191 |
-
from src.inference.enhanced_recommendation_engine import EnhancedRecommendationEngine
|
| 192 |
-
enhanced_recommendation_engine = EnhancedRecommendationEngine()
|
| 193 |
-
print("β οΈ Using fallback Enhanced Recommendation Engine")
|
| 194 |
-
print("Enhanced recommendation engine loaded successfully!")
|
| 195 |
-
except Exception as e:
|
| 196 |
-
print(f"Error loading enhanced recommendation engine: {e}")
|
| 197 |
-
enhanced_recommendation_engine = None
|
| 198 |
-
|
| 199 |
try:
|
| 200 |
print("Loading real user selector...")
|
| 201 |
real_user_selector = RealUserSelector()
|
| 202 |
-
print("Real user selector loaded successfully!")
|
| 203 |
except Exception as e:
|
| 204 |
-
print(f"Error loading real user selector: {e}")
|
| 205 |
real_user_selector = None
|
| 206 |
|
| 207 |
|
|
@@ -211,7 +203,12 @@ async def root():
|
|
| 211 |
return {
|
| 212 |
"message": "Two-Tower Recommendation API",
|
| 213 |
"version": "1.0.0",
|
| 214 |
-
"status": "active" if recommendation_engine is not None else "initialization_failed"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
}
|
| 216 |
|
| 217 |
|
|
@@ -220,7 +217,13 @@ async def health_check():
|
|
| 220 |
"""Health check endpoint."""
|
| 221 |
return {
|
| 222 |
"status": "healthy" if recommendation_engine is not None else "unhealthy",
|
| 223 |
-
"engine_loaded": recommendation_engine is not None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
}
|
| 225 |
|
| 226 |
|
|
@@ -378,6 +381,10 @@ async def get_recommendations(request: RecommendationRequest):
|
|
| 378 |
age=user_profile.age,
|
| 379 |
gender=user_profile.gender,
|
| 380 |
income=user_profile.income,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
interaction_history=filtered_interaction_history,
|
| 382 |
k=request.num_recommendations * 2 # Get more to allow for filtering
|
| 383 |
)
|
|
@@ -390,11 +397,11 @@ async def get_recommendations(request: RecommendationRequest):
|
|
| 390 |
(f" in category '{request.selected_category}'" if request.selected_category else "")
|
| 391 |
)
|
| 392 |
|
| 393 |
-
# Use
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
)
|
| 399 |
|
| 400 |
elif request.recommendation_type == "hybrid":
|
|
@@ -402,79 +409,40 @@ async def get_recommendations(request: RecommendationRequest):
|
|
| 402 |
age=user_profile.age,
|
| 403 |
gender=user_profile.gender,
|
| 404 |
income=user_profile.income,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
interaction_history=filtered_interaction_history,
|
| 406 |
k=request.num_recommendations * 2, # Get more to allow for filtering
|
| 407 |
collaborative_weight=request.collaborative_weight
|
| 408 |
)
|
| 409 |
|
| 410 |
-
elif request.recommendation_type == "
|
| 411 |
-
|
| 412 |
-
raise HTTPException(status_code=503, detail="Enhanced recommendation engine not available")
|
| 413 |
-
|
| 414 |
-
# Check if it's the 128D engine or fallback
|
| 415 |
-
if hasattr(enhanced_recommendation_engine, 'recommend_items_enhanced'):
|
| 416 |
-
# 128D Enhanced engine
|
| 417 |
-
recommendations = enhanced_recommendation_engine.recommend_items_enhanced(
|
| 418 |
-
age=user_profile.age,
|
| 419 |
-
gender=user_profile.gender,
|
| 420 |
-
income=user_profile.income,
|
| 421 |
-
interaction_history=filtered_interaction_history,
|
| 422 |
-
k=request.num_recommendations * 2, # Get more to allow for filtering
|
| 423 |
-
diversity_weight=0.3 if request.enable_diversity else 0.0,
|
| 424 |
-
category_boost=request.category_boost if request.enable_category_boost else 1.0
|
| 425 |
-
)
|
| 426 |
-
else:
|
| 427 |
-
# Fallback enhanced engine
|
| 428 |
-
recommendations = enhanced_recommendation_engine.recommend_items_enhanced_hybrid(
|
| 429 |
-
age=user_profile.age,
|
| 430 |
-
gender=user_profile.gender,
|
| 431 |
-
income=user_profile.income,
|
| 432 |
-
interaction_history=filtered_interaction_history,
|
| 433 |
-
k=request.num_recommendations * 2, # Get more to allow for filtering
|
| 434 |
-
collaborative_weight=request.collaborative_weight,
|
| 435 |
-
category_boost=request.category_boost,
|
| 436 |
-
enable_category_boost=request.enable_category_boost,
|
| 437 |
-
enable_diversity=request.enable_diversity
|
| 438 |
-
)
|
| 439 |
-
|
| 440 |
-
elif request.recommendation_type == "enhanced_128d":
|
| 441 |
-
if enhanced_recommendation_engine is None or not hasattr(enhanced_recommendation_engine, 'recommend_items_enhanced'):
|
| 442 |
-
raise HTTPException(status_code=503, detail="Enhanced 128D recommendation engine not available")
|
| 443 |
-
|
| 444 |
-
recommendations = enhanced_recommendation_engine.recommend_items_enhanced(
|
| 445 |
age=user_profile.age,
|
| 446 |
gender=user_profile.gender,
|
| 447 |
income=user_profile.income,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 448 |
interaction_history=filtered_interaction_history,
|
| 449 |
-
k=request.num_recommendations * 2
|
| 450 |
-
diversity_weight=0.3 if request.enable_diversity else 0.0,
|
| 451 |
-
category_boost=request.category_boost if request.enable_category_boost else 1.0
|
| 452 |
-
)
|
| 453 |
-
|
| 454 |
-
elif request.recommendation_type == "category_focused":
|
| 455 |
-
if enhanced_recommendation_engine is None:
|
| 456 |
-
raise HTTPException(status_code=503, detail="Enhanced recommendation engine not available")
|
| 457 |
-
|
| 458 |
-
recommendations = enhanced_recommendation_engine.recommend_items_category_focused(
|
| 459 |
-
age=user_profile.age,
|
| 460 |
-
gender=user_profile.gender,
|
| 461 |
-
income=user_profile.income,
|
| 462 |
-
interaction_history=filtered_interaction_history,
|
| 463 |
-
k=request.num_recommendations * 2, # Get more to allow for filtering
|
| 464 |
-
focus_percentage=0.8
|
| 465 |
)
|
| 466 |
|
| 467 |
else:
|
| 468 |
raise HTTPException(
|
| 469 |
status_code=400,
|
| 470 |
-
detail="Invalid recommendation_type. Must be 'collaborative', 'content', 'hybrid',
|
| 471 |
)
|
| 472 |
|
| 473 |
# Apply category filtering to final recommendations if needed
|
| 474 |
if request.selected_category:
|
| 475 |
recommendations = filter_recommendations_by_category(recommendations, request.selected_category)
|
| 476 |
-
|
| 477 |
-
|
|
|
|
| 478 |
|
| 479 |
# Format response
|
| 480 |
formatted_recommendations = []
|
|
@@ -543,8 +511,12 @@ async def predict_user_item_rating(request: RatingPredictionRequest):
|
|
| 543 |
age=user_profile.age,
|
| 544 |
gender=user_profile.gender,
|
| 545 |
income=user_profile.income,
|
| 546 |
-
|
| 547 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 548 |
)
|
| 549 |
|
| 550 |
item_info = recommendation_engine._get_item_info(request.item_id)
|
|
|
|
| 33 |
|
| 34 |
# Global instances
|
| 35 |
recommendation_engine = None
|
|
|
|
| 36 |
real_user_selector = None
|
| 37 |
|
| 38 |
|
|
|
|
| 74 |
age: int
|
| 75 |
gender: str # "male" or "female"
|
| 76 |
income: float
|
| 77 |
+
profession: Optional[str] = "Other"
|
| 78 |
+
location: Optional[str] = "Urban"
|
| 79 |
+
education_level: Optional[str] = "High School"
|
| 80 |
+
marital_status: Optional[str] = "Single"
|
| 81 |
interaction_history: Optional[List[int]] = []
|
| 82 |
|
| 83 |
|
| 84 |
class RecommendationRequest(BaseModel):
|
| 85 |
user_profile: UserProfile
|
| 86 |
num_recommendations: int = 10
|
| 87 |
+
recommendation_type: str = "hybrid" # "collaborative", "content" (aggregated history), "hybrid"
|
| 88 |
collaborative_weight: Optional[float] = 0.7
|
| 89 |
category_boost: Optional[float] = 1.5 # For enhanced recommendations
|
| 90 |
enable_category_boost: Optional[bool] = True
|
|
|
|
| 135 |
age: int
|
| 136 |
gender: str
|
| 137 |
income: int
|
| 138 |
+
profession: Optional[str] = None
|
| 139 |
+
location: Optional[str] = None
|
| 140 |
+
education_level: Optional[str] = None
|
| 141 |
+
marital_status: Optional[str] = None
|
| 142 |
interaction_history: List[int]
|
| 143 |
interaction_stats: Dict[str, int]
|
| 144 |
interaction_pattern: str
|
|
|
|
| 176 |
|
| 177 |
@app.on_event("startup")
|
| 178 |
async def startup_event():
|
| 179 |
+
"""Initialize the recommendation engine and real user selector on startup."""
|
| 180 |
+
global recommendation_engine, real_user_selector
|
| 181 |
|
| 182 |
try:
|
| 183 |
+
print("Loading recommendation engine with enhanced demographics...")
|
| 184 |
recommendation_engine = RecommendationEngine()
|
| 185 |
+
print("β
Recommendation engine loaded successfully!")
|
| 186 |
+
print(" Supports 7 demographic features: age, gender, income, profession, location, education, marital_status")
|
| 187 |
except Exception as e:
|
| 188 |
+
print(f"β Error loading recommendation engine: {e}")
|
| 189 |
recommendation_engine = None
|
| 190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
try:
|
| 192 |
print("Loading real user selector...")
|
| 193 |
real_user_selector = RealUserSelector()
|
| 194 |
+
print("β
Real user selector loaded successfully!")
|
| 195 |
except Exception as e:
|
| 196 |
+
print(f"β Error loading real user selector: {e}")
|
| 197 |
real_user_selector = None
|
| 198 |
|
| 199 |
|
|
|
|
| 203 |
return {
|
| 204 |
"message": "Two-Tower Recommendation API",
|
| 205 |
"version": "1.0.0",
|
| 206 |
+
"status": "active" if recommendation_engine is not None else "initialization_failed",
|
| 207 |
+
"enhanced_demographics": True,
|
| 208 |
+
"supported_demographics": [
|
| 209 |
+
"age", "gender", "income", "profession",
|
| 210 |
+
"location", "education_level", "marital_status"
|
| 211 |
+
]
|
| 212 |
}
|
| 213 |
|
| 214 |
|
|
|
|
| 217 |
"""Health check endpoint."""
|
| 218 |
return {
|
| 219 |
"status": "healthy" if recommendation_engine is not None else "unhealthy",
|
| 220 |
+
"engine_loaded": recommendation_engine is not None,
|
| 221 |
+
"enhanced_demographics": True,
|
| 222 |
+
"demographic_features": 7,
|
| 223 |
+
"supported_demographics": [
|
| 224 |
+
"age", "gender", "income", "profession",
|
| 225 |
+
"location", "education_level", "marital_status"
|
| 226 |
+
]
|
| 227 |
}
|
| 228 |
|
| 229 |
|
|
|
|
| 381 |
age=user_profile.age,
|
| 382 |
gender=user_profile.gender,
|
| 383 |
income=user_profile.income,
|
| 384 |
+
profession=user_profile.profession or "Other",
|
| 385 |
+
location=user_profile.location or "Urban",
|
| 386 |
+
education_level=user_profile.education_level or "High School",
|
| 387 |
+
marital_status=user_profile.marital_status or "Single",
|
| 388 |
interaction_history=filtered_interaction_history,
|
| 389 |
k=request.num_recommendations * 2 # Get more to allow for filtering
|
| 390 |
)
|
|
|
|
| 397 |
(f" in category '{request.selected_category}'" if request.selected_category else "")
|
| 398 |
)
|
| 399 |
|
| 400 |
+
# Use aggregated interaction history for content-based recommendations
|
| 401 |
+
recommendations = recommendation_engine.recommend_items_content_based_from_history(
|
| 402 |
+
interaction_history=filtered_interaction_history,
|
| 403 |
+
k=request.num_recommendations * 2, # Get more to allow for filtering
|
| 404 |
+
aggregation_method="weighted_mean"
|
| 405 |
)
|
| 406 |
|
| 407 |
elif request.recommendation_type == "hybrid":
|
|
|
|
| 409 |
age=user_profile.age,
|
| 410 |
gender=user_profile.gender,
|
| 411 |
income=user_profile.income,
|
| 412 |
+
profession=user_profile.profession or "Other",
|
| 413 |
+
location=user_profile.location or "Urban",
|
| 414 |
+
education_level=user_profile.education_level or "High School",
|
| 415 |
+
marital_status=user_profile.marital_status or "Single",
|
| 416 |
interaction_history=filtered_interaction_history,
|
| 417 |
k=request.num_recommendations * 2, # Get more to allow for filtering
|
| 418 |
collaborative_weight=request.collaborative_weight
|
| 419 |
)
|
| 420 |
|
| 421 |
+
elif request.recommendation_type == "category_boosted":
|
| 422 |
+
recommendations = recommendation_engine.recommend_items_category_boosted(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
age=user_profile.age,
|
| 424 |
gender=user_profile.gender,
|
| 425 |
income=user_profile.income,
|
| 426 |
+
profession=user_profile.profession or "Other",
|
| 427 |
+
location=user_profile.location or "Urban",
|
| 428 |
+
education_level=user_profile.education_level or "High School",
|
| 429 |
+
marital_status=user_profile.marital_status or "Single",
|
| 430 |
interaction_history=filtered_interaction_history,
|
| 431 |
+
k=request.num_recommendations * 2 # Get more to allow for filtering
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
)
|
| 433 |
|
| 434 |
else:
|
| 435 |
raise HTTPException(
|
| 436 |
status_code=400,
|
| 437 |
+
detail="Invalid recommendation_type. Must be 'collaborative', 'content', 'hybrid', or 'category_boosted'"
|
| 438 |
)
|
| 439 |
|
| 440 |
# Apply category filtering to final recommendations if needed
|
| 441 |
if request.selected_category:
|
| 442 |
recommendations = filter_recommendations_by_category(recommendations, request.selected_category)
|
| 443 |
+
|
| 444 |
+
# Always limit to requested number of recommendations
|
| 445 |
+
recommendations = recommendations[:request.num_recommendations]
|
| 446 |
|
| 447 |
# Format response
|
| 448 |
formatted_recommendations = []
|
|
|
|
| 511 |
age=user_profile.age,
|
| 512 |
gender=user_profile.gender,
|
| 513 |
income=user_profile.income,
|
| 514 |
+
item_id=request.item_id,
|
| 515 |
+
profession=user_profile.profession or "Other",
|
| 516 |
+
location=user_profile.location or "Urban",
|
| 517 |
+
education_level=user_profile.education_level or "High School",
|
| 518 |
+
marital_status=user_profile.marital_status or "Single",
|
| 519 |
+
interaction_history=user_profile.interaction_history
|
| 520 |
)
|
| 521 |
|
| 522 |
item_info = recommendation_engine._get_item_info(request.item_id)
|
api_2phase.py
DELETED
|
@@ -1,521 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""
|
| 3 |
-
API for 2-Phase Trained Recommendation System
|
| 4 |
-
|
| 5 |
-
This API serves recommendations from a model trained using the 2-phase approach:
|
| 6 |
-
1. Pre-trained item tower
|
| 7 |
-
2. Joint training with fine-tuned item tower
|
| 8 |
-
|
| 9 |
-
Usage:
|
| 10 |
-
python api_2phase.py
|
| 11 |
-
|
| 12 |
-
Then access: http://localhost:8000
|
| 13 |
-
"""
|
| 14 |
-
|
| 15 |
-
from fastapi import FastAPI, HTTPException
|
| 16 |
-
from fastapi.middleware.cors import CORSMiddleware
|
| 17 |
-
from pydantic import BaseModel
|
| 18 |
-
from typing import List, Optional, Dict, Any
|
| 19 |
-
import uvicorn
|
| 20 |
-
import os
|
| 21 |
-
import sys
|
| 22 |
-
import pandas as pd
|
| 23 |
-
|
| 24 |
-
# Add src to path for imports and set working directory
|
| 25 |
-
parent_dir = os.path.dirname(__file__)
|
| 26 |
-
sys.path.append(parent_dir)
|
| 27 |
-
os.chdir(parent_dir) # Change to project root directory
|
| 28 |
-
|
| 29 |
-
from src.inference.recommendation_engine import RecommendationEngine
|
| 30 |
-
from src.utils.real_user_selector import RealUserSelector
|
| 31 |
-
|
| 32 |
-
# Initialize FastAPI app
|
| 33 |
-
app = FastAPI(
|
| 34 |
-
title="Two-Tower Recommendation API (2-Phase Training)",
|
| 35 |
-
description="API for serving recommendations using a two-tower architecture trained with 2-phase approach",
|
| 36 |
-
version="1.0.0-2phase"
|
| 37 |
-
)
|
| 38 |
-
|
| 39 |
-
# Add CORS middleware
|
| 40 |
-
app.add_middleware(
|
| 41 |
-
CORSMiddleware,
|
| 42 |
-
allow_origins=["*"], # Configure appropriately for production
|
| 43 |
-
allow_credentials=True,
|
| 44 |
-
allow_methods=["*"],
|
| 45 |
-
allow_headers=["*"],
|
| 46 |
-
)
|
| 47 |
-
|
| 48 |
-
# Global instances
|
| 49 |
-
recommendation_engine = None
|
| 50 |
-
enhanced_recommendation_engine = None
|
| 51 |
-
real_user_selector = None
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
# Pydantic models for request/response
|
| 55 |
-
class UserProfile(BaseModel):
|
| 56 |
-
age: int
|
| 57 |
-
gender: str # "male" or "female"
|
| 58 |
-
income: float
|
| 59 |
-
interaction_history: Optional[List[int]] = []
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
class RecommendationRequest(BaseModel):
|
| 63 |
-
user_profile: UserProfile
|
| 64 |
-
num_recommendations: int = 10
|
| 65 |
-
recommendation_type: str = "hybrid" # "collaborative", "content", "hybrid", "enhanced", "enhanced_128d", "category_focused"
|
| 66 |
-
collaborative_weight: Optional[float] = 0.7
|
| 67 |
-
category_boost: Optional[float] = 1.5 # For enhanced recommendations
|
| 68 |
-
enable_category_boost: Optional[bool] = True
|
| 69 |
-
enable_diversity: Optional[bool] = True
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
class ItemSimilarityRequest(BaseModel):
|
| 73 |
-
item_id: int
|
| 74 |
-
num_recommendations: int = 10
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
class RatingPredictionRequest(BaseModel):
|
| 78 |
-
user_profile: UserProfile
|
| 79 |
-
item_id: int
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
class ItemInfo(BaseModel):
|
| 83 |
-
product_id: int
|
| 84 |
-
category_id: int
|
| 85 |
-
category_code: str
|
| 86 |
-
brand: str
|
| 87 |
-
price: float
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
class RecommendationResponse(BaseModel):
|
| 91 |
-
item_id: int
|
| 92 |
-
score: float
|
| 93 |
-
item_info: ItemInfo
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
class RecommendationsResponse(BaseModel):
|
| 97 |
-
recommendations: List[RecommendationResponse]
|
| 98 |
-
user_profile: UserProfile
|
| 99 |
-
recommendation_type: str
|
| 100 |
-
total_count: int
|
| 101 |
-
training_approach: str = "2-phase"
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
class RatingPredictionResponse(BaseModel):
|
| 105 |
-
user_profile: UserProfile
|
| 106 |
-
item_id: int
|
| 107 |
-
predicted_rating: float
|
| 108 |
-
item_info: ItemInfo
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
class RealUserProfile(BaseModel):
|
| 112 |
-
user_id: int
|
| 113 |
-
age: int
|
| 114 |
-
gender: str
|
| 115 |
-
income: int
|
| 116 |
-
interaction_history: List[int]
|
| 117 |
-
interaction_stats: Dict[str, int]
|
| 118 |
-
interaction_pattern: str
|
| 119 |
-
summary: str
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
class RealUsersResponse(BaseModel):
|
| 123 |
-
users: List[RealUserProfile]
|
| 124 |
-
total_count: int
|
| 125 |
-
dataset_summary: Dict[str, Any]
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
@app.on_event("startup")
|
| 129 |
-
async def startup_event():
|
| 130 |
-
"""Initialize the recommendation engines and real user selector on startup."""
|
| 131 |
-
global recommendation_engine, enhanced_recommendation_engine, real_user_selector
|
| 132 |
-
|
| 133 |
-
print("π Starting 2-Phase Training API...")
|
| 134 |
-
print(" Training approach: Pre-trained item tower + Joint fine-tuning")
|
| 135 |
-
|
| 136 |
-
try:
|
| 137 |
-
print("Loading 2-phase trained recommendation engine...")
|
| 138 |
-
recommendation_engine = RecommendationEngine()
|
| 139 |
-
print("β
2-phase recommendation engine loaded successfully!")
|
| 140 |
-
except Exception as e:
|
| 141 |
-
print(f"β Error loading recommendation engine: {e}")
|
| 142 |
-
recommendation_engine = None
|
| 143 |
-
|
| 144 |
-
try:
|
| 145 |
-
print("Loading enhanced recommendation engine...")
|
| 146 |
-
# Try enhanced 128D engine first, fallback to regular enhanced
|
| 147 |
-
try:
|
| 148 |
-
from src.inference.enhanced_recommendation_engine_128d import Enhanced128DRecommendationEngine
|
| 149 |
-
enhanced_recommendation_engine = Enhanced128DRecommendationEngine()
|
| 150 |
-
print("β
Using Enhanced 128D Recommendation Engine")
|
| 151 |
-
except:
|
| 152 |
-
from src.inference.enhanced_recommendation_engine import EnhancedRecommendationEngine
|
| 153 |
-
enhanced_recommendation_engine = EnhancedRecommendationEngine()
|
| 154 |
-
print("β οΈ Using fallback Enhanced Recommendation Engine")
|
| 155 |
-
print("Enhanced recommendation engine loaded successfully!")
|
| 156 |
-
except Exception as e:
|
| 157 |
-
print(f"Error loading enhanced recommendation engine: {e}")
|
| 158 |
-
enhanced_recommendation_engine = None
|
| 159 |
-
|
| 160 |
-
try:
|
| 161 |
-
print("Loading real user selector...")
|
| 162 |
-
real_user_selector = RealUserSelector()
|
| 163 |
-
print("Real user selector loaded successfully!")
|
| 164 |
-
except Exception as e:
|
| 165 |
-
print(f"Error loading real user selector: {e}")
|
| 166 |
-
real_user_selector = None
|
| 167 |
-
|
| 168 |
-
print("π― 2-Phase API ready to serve recommendations!")
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
@app.get("/")
|
| 172 |
-
async def root():
|
| 173 |
-
"""Root endpoint with API information."""
|
| 174 |
-
return {
|
| 175 |
-
"message": "Two-Tower Recommendation API (2-Phase Training)",
|
| 176 |
-
"version": "1.0.0-2phase",
|
| 177 |
-
"training_approach": "2-phase (pre-trained item tower + joint fine-tuning)",
|
| 178 |
-
"status": "active" if recommendation_engine is not None else "initialization_failed"
|
| 179 |
-
}
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
@app.get("/health")
|
| 183 |
-
async def health_check():
|
| 184 |
-
"""Health check endpoint."""
|
| 185 |
-
return {
|
| 186 |
-
"status": "healthy" if recommendation_engine is not None else "unhealthy",
|
| 187 |
-
"engine_loaded": recommendation_engine is not None,
|
| 188 |
-
"training_approach": "2-phase"
|
| 189 |
-
}
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
@app.get("/model-info")
|
| 193 |
-
async def model_info():
|
| 194 |
-
"""Get information about the loaded model."""
|
| 195 |
-
if recommendation_engine is None:
|
| 196 |
-
raise HTTPException(status_code=503, detail="Recommendation engine not available")
|
| 197 |
-
|
| 198 |
-
return {
|
| 199 |
-
"training_approach": "2-phase",
|
| 200 |
-
"description": "Pre-trained item tower followed by joint training with user tower",
|
| 201 |
-
"phases": [
|
| 202 |
-
"Phase 1: Item tower pre-training on item features only",
|
| 203 |
-
"Phase 2: Joint training of user tower + fine-tuning pre-trained item tower"
|
| 204 |
-
],
|
| 205 |
-
"embedding_dimension": 128,
|
| 206 |
-
"item_vocab_size": len(recommendation_engine.data_processor.item_vocab) if recommendation_engine.data_processor else "unknown",
|
| 207 |
-
"artifacts_loaded": {
|
| 208 |
-
"item_tower_pretrained": "src/artifacts/item_tower_weights",
|
| 209 |
-
"item_tower_finetuned": "src/artifacts/item_tower_weights_finetuned_best",
|
| 210 |
-
"user_tower": "src/artifacts/user_tower_weights_best",
|
| 211 |
-
"rating_model": "src/artifacts/rating_model_weights_best"
|
| 212 |
-
}
|
| 213 |
-
}
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
@app.get("/real-users", response_model=RealUsersResponse)
|
| 217 |
-
async def get_real_users(count: int = 100, min_interactions: int = 5):
|
| 218 |
-
"""Get real user profiles with genuine interaction histories."""
|
| 219 |
-
|
| 220 |
-
if real_user_selector is None:
|
| 221 |
-
raise HTTPException(status_code=503, detail="Real user selector not available")
|
| 222 |
-
|
| 223 |
-
try:
|
| 224 |
-
# Get real user profiles
|
| 225 |
-
real_users = real_user_selector.get_real_users(n=count, min_interactions=min_interactions)
|
| 226 |
-
|
| 227 |
-
# Get dataset summary
|
| 228 |
-
dataset_summary = real_user_selector.get_dataset_summary()
|
| 229 |
-
|
| 230 |
-
# Format users for response
|
| 231 |
-
formatted_users = []
|
| 232 |
-
for user in real_users:
|
| 233 |
-
formatted_users.append(RealUserProfile(**user))
|
| 234 |
-
|
| 235 |
-
return RealUsersResponse(
|
| 236 |
-
users=formatted_users,
|
| 237 |
-
total_count=len(formatted_users),
|
| 238 |
-
dataset_summary=dataset_summary
|
| 239 |
-
)
|
| 240 |
-
|
| 241 |
-
except Exception as e:
|
| 242 |
-
raise HTTPException(status_code=500, detail=f"Error retrieving real users: {str(e)}")
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
@app.get("/real-users/{user_id}")
|
| 246 |
-
async def get_real_user_details(user_id: int):
|
| 247 |
-
"""Get detailed interaction breakdown for a specific real user."""
|
| 248 |
-
|
| 249 |
-
if real_user_selector is None:
|
| 250 |
-
raise HTTPException(status_code=503, detail="Real user selector not available")
|
| 251 |
-
|
| 252 |
-
try:
|
| 253 |
-
user_details = real_user_selector.get_user_interaction_details(user_id)
|
| 254 |
-
|
| 255 |
-
if "error" in user_details:
|
| 256 |
-
raise HTTPException(status_code=404, detail=user_details["error"])
|
| 257 |
-
|
| 258 |
-
return user_details
|
| 259 |
-
|
| 260 |
-
except Exception as e:
|
| 261 |
-
raise HTTPException(status_code=500, detail=f"Error retrieving user details: {str(e)}")
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
@app.get("/dataset-summary")
|
| 265 |
-
async def get_dataset_summary():
|
| 266 |
-
"""Get summary statistics of the real dataset."""
|
| 267 |
-
|
| 268 |
-
if real_user_selector is None:
|
| 269 |
-
raise HTTPException(status_code=503, detail="Real user selector not available")
|
| 270 |
-
|
| 271 |
-
try:
|
| 272 |
-
return real_user_selector.get_dataset_summary()
|
| 273 |
-
|
| 274 |
-
except Exception as e:
|
| 275 |
-
raise HTTPException(status_code=500, detail=f"Error retrieving dataset summary: {str(e)}")
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
@app.post("/recommendations", response_model=RecommendationsResponse)
|
| 279 |
-
async def get_recommendations(request: RecommendationRequest):
|
| 280 |
-
"""Get item recommendations for a user."""
|
| 281 |
-
|
| 282 |
-
if recommendation_engine is None:
|
| 283 |
-
raise HTTPException(status_code=503, detail="Recommendation engine not available")
|
| 284 |
-
|
| 285 |
-
try:
|
| 286 |
-
user_profile = request.user_profile
|
| 287 |
-
|
| 288 |
-
# Generate recommendations based on type
|
| 289 |
-
if request.recommendation_type == "collaborative":
|
| 290 |
-
recommendations = recommendation_engine.recommend_items_collaborative(
|
| 291 |
-
age=user_profile.age,
|
| 292 |
-
gender=user_profile.gender,
|
| 293 |
-
income=user_profile.income,
|
| 294 |
-
interaction_history=user_profile.interaction_history,
|
| 295 |
-
k=request.num_recommendations
|
| 296 |
-
)
|
| 297 |
-
|
| 298 |
-
elif request.recommendation_type == "content":
|
| 299 |
-
if not user_profile.interaction_history:
|
| 300 |
-
raise HTTPException(
|
| 301 |
-
status_code=400,
|
| 302 |
-
detail="Content-based recommendations require interaction history"
|
| 303 |
-
)
|
| 304 |
-
|
| 305 |
-
# Use most recent interaction as seed
|
| 306 |
-
seed_item = user_profile.interaction_history[-1]
|
| 307 |
-
recommendations = recommendation_engine.recommend_items_content_based(
|
| 308 |
-
seed_item_id=seed_item,
|
| 309 |
-
k=request.num_recommendations
|
| 310 |
-
)
|
| 311 |
-
|
| 312 |
-
elif request.recommendation_type == "hybrid":
|
| 313 |
-
recommendations = recommendation_engine.recommend_items_hybrid(
|
| 314 |
-
age=user_profile.age,
|
| 315 |
-
gender=user_profile.gender,
|
| 316 |
-
income=user_profile.income,
|
| 317 |
-
interaction_history=user_profile.interaction_history,
|
| 318 |
-
k=request.num_recommendations,
|
| 319 |
-
collaborative_weight=request.collaborative_weight
|
| 320 |
-
)
|
| 321 |
-
|
| 322 |
-
elif request.recommendation_type == "enhanced":
|
| 323 |
-
if enhanced_recommendation_engine is None:
|
| 324 |
-
raise HTTPException(status_code=503, detail="Enhanced recommendation engine not available")
|
| 325 |
-
|
| 326 |
-
# Check if it's the 128D engine or fallback
|
| 327 |
-
if hasattr(enhanced_recommendation_engine, 'recommend_items_enhanced'):
|
| 328 |
-
# 128D Enhanced engine
|
| 329 |
-
recommendations = enhanced_recommendation_engine.recommend_items_enhanced(
|
| 330 |
-
age=user_profile.age,
|
| 331 |
-
gender=user_profile.gender,
|
| 332 |
-
income=user_profile.income,
|
| 333 |
-
interaction_history=user_profile.interaction_history,
|
| 334 |
-
k=request.num_recommendations,
|
| 335 |
-
diversity_weight=0.3 if request.enable_diversity else 0.0,
|
| 336 |
-
category_boost=request.category_boost if request.enable_category_boost else 1.0
|
| 337 |
-
)
|
| 338 |
-
else:
|
| 339 |
-
# Fallback enhanced engine
|
| 340 |
-
recommendations = enhanced_recommendation_engine.recommend_items_enhanced_hybrid(
|
| 341 |
-
age=user_profile.age,
|
| 342 |
-
gender=user_profile.gender,
|
| 343 |
-
income=user_profile.income,
|
| 344 |
-
interaction_history=user_profile.interaction_history,
|
| 345 |
-
k=request.num_recommendations,
|
| 346 |
-
collaborative_weight=request.collaborative_weight,
|
| 347 |
-
category_boost=request.category_boost,
|
| 348 |
-
enable_category_boost=request.enable_category_boost,
|
| 349 |
-
enable_diversity=request.enable_diversity
|
| 350 |
-
)
|
| 351 |
-
|
| 352 |
-
elif request.recommendation_type == "enhanced_128d":
|
| 353 |
-
if enhanced_recommendation_engine is None or not hasattr(enhanced_recommendation_engine, 'recommend_items_enhanced'):
|
| 354 |
-
raise HTTPException(status_code=503, detail="Enhanced 128D recommendation engine not available")
|
| 355 |
-
|
| 356 |
-
recommendations = enhanced_recommendation_engine.recommend_items_enhanced(
|
| 357 |
-
age=user_profile.age,
|
| 358 |
-
gender=user_profile.gender,
|
| 359 |
-
income=user_profile.income,
|
| 360 |
-
interaction_history=user_profile.interaction_history,
|
| 361 |
-
k=request.num_recommendations,
|
| 362 |
-
diversity_weight=0.3 if request.enable_diversity else 0.0,
|
| 363 |
-
category_boost=request.category_boost if request.enable_category_boost else 1.0
|
| 364 |
-
)
|
| 365 |
-
|
| 366 |
-
elif request.recommendation_type == "category_focused":
|
| 367 |
-
if enhanced_recommendation_engine is None:
|
| 368 |
-
raise HTTPException(status_code=503, detail="Enhanced recommendation engine not available")
|
| 369 |
-
|
| 370 |
-
recommendations = enhanced_recommendation_engine.recommend_items_category_focused(
|
| 371 |
-
age=user_profile.age,
|
| 372 |
-
gender=user_profile.gender,
|
| 373 |
-
income=user_profile.income,
|
| 374 |
-
interaction_history=user_profile.interaction_history,
|
| 375 |
-
k=request.num_recommendations,
|
| 376 |
-
focus_percentage=0.8
|
| 377 |
-
)
|
| 378 |
-
|
| 379 |
-
else:
|
| 380 |
-
raise HTTPException(
|
| 381 |
-
status_code=400,
|
| 382 |
-
detail="Invalid recommendation_type. Must be 'collaborative', 'content', 'hybrid', 'enhanced', 'enhanced_128d', or 'category_focused'"
|
| 383 |
-
)
|
| 384 |
-
|
| 385 |
-
# Format response
|
| 386 |
-
formatted_recommendations = []
|
| 387 |
-
for item_id, score, item_info in recommendations:
|
| 388 |
-
formatted_recommendations.append(
|
| 389 |
-
RecommendationResponse(
|
| 390 |
-
item_id=item_id,
|
| 391 |
-
score=score,
|
| 392 |
-
item_info=ItemInfo(**item_info)
|
| 393 |
-
)
|
| 394 |
-
)
|
| 395 |
-
|
| 396 |
-
return RecommendationsResponse(
|
| 397 |
-
recommendations=formatted_recommendations,
|
| 398 |
-
user_profile=user_profile,
|
| 399 |
-
recommendation_type=request.recommendation_type,
|
| 400 |
-
total_count=len(formatted_recommendations),
|
| 401 |
-
training_approach="2-phase"
|
| 402 |
-
)
|
| 403 |
-
|
| 404 |
-
except Exception as e:
|
| 405 |
-
raise HTTPException(status_code=500, detail=f"Error generating recommendations: {str(e)}")
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
@app.post("/item-similarity", response_model=List[RecommendationResponse])
|
| 409 |
-
async def get_similar_items(request: ItemSimilarityRequest):
|
| 410 |
-
"""Get items similar to a given item."""
|
| 411 |
-
|
| 412 |
-
if recommendation_engine is None:
|
| 413 |
-
raise HTTPException(status_code=503, detail="Recommendation engine not available")
|
| 414 |
-
|
| 415 |
-
try:
|
| 416 |
-
recommendations = recommendation_engine.recommend_items_content_based(
|
| 417 |
-
seed_item_id=request.item_id,
|
| 418 |
-
k=request.num_recommendations
|
| 419 |
-
)
|
| 420 |
-
|
| 421 |
-
formatted_recommendations = []
|
| 422 |
-
for item_id, score, item_info in recommendations:
|
| 423 |
-
formatted_recommendations.append(
|
| 424 |
-
RecommendationResponse(
|
| 425 |
-
item_id=item_id,
|
| 426 |
-
score=score,
|
| 427 |
-
item_info=ItemInfo(**item_info)
|
| 428 |
-
)
|
| 429 |
-
)
|
| 430 |
-
|
| 431 |
-
return formatted_recommendations
|
| 432 |
-
|
| 433 |
-
except Exception as e:
|
| 434 |
-
raise HTTPException(status_code=500, detail=f"Error finding similar items: {str(e)}")
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
@app.post("/predict-rating", response_model=RatingPredictionResponse)
|
| 438 |
-
async def predict_user_item_rating(request: RatingPredictionRequest):
|
| 439 |
-
"""Predict rating for a user-item pair."""
|
| 440 |
-
|
| 441 |
-
if recommendation_engine is None:
|
| 442 |
-
raise HTTPException(status_code=503, detail="Recommendation engine not available")
|
| 443 |
-
|
| 444 |
-
try:
|
| 445 |
-
user_profile = request.user_profile
|
| 446 |
-
|
| 447 |
-
predicted_rating = recommendation_engine.predict_rating(
|
| 448 |
-
age=user_profile.age,
|
| 449 |
-
gender=user_profile.gender,
|
| 450 |
-
income=user_profile.income,
|
| 451 |
-
interaction_history=user_profile.interaction_history,
|
| 452 |
-
item_id=request.item_id
|
| 453 |
-
)
|
| 454 |
-
|
| 455 |
-
item_info = recommendation_engine._get_item_info(request.item_id)
|
| 456 |
-
|
| 457 |
-
return RatingPredictionResponse(
|
| 458 |
-
user_profile=user_profile,
|
| 459 |
-
item_id=request.item_id,
|
| 460 |
-
predicted_rating=predicted_rating,
|
| 461 |
-
item_info=ItemInfo(**item_info)
|
| 462 |
-
)
|
| 463 |
-
|
| 464 |
-
except Exception as e:
|
| 465 |
-
raise HTTPException(status_code=500, detail=f"Error predicting rating: {str(e)}")
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
@app.get("/items/{item_id}", response_model=ItemInfo)
|
| 469 |
-
async def get_item_info(item_id: int):
|
| 470 |
-
"""Get information about a specific item."""
|
| 471 |
-
|
| 472 |
-
if recommendation_engine is None:
|
| 473 |
-
raise HTTPException(status_code=503, detail="Recommendation engine not available")
|
| 474 |
-
|
| 475 |
-
try:
|
| 476 |
-
item_info = recommendation_engine._get_item_info(item_id)
|
| 477 |
-
return ItemInfo(**item_info)
|
| 478 |
-
|
| 479 |
-
except Exception as e:
|
| 480 |
-
raise HTTPException(status_code=500, detail=f"Error retrieving item info: {str(e)}")
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
@app.get("/items")
|
| 484 |
-
async def get_sample_items(limit: int = 20):
|
| 485 |
-
"""Get a sample of items for testing."""
|
| 486 |
-
|
| 487 |
-
if recommendation_engine is None:
|
| 488 |
-
raise HTTPException(status_code=503, detail="Recommendation engine not available")
|
| 489 |
-
|
| 490 |
-
try:
|
| 491 |
-
# Get sample items from the dataframe
|
| 492 |
-
sample_items = recommendation_engine.items_df.sample(n=min(limit, len(recommendation_engine.items_df)))
|
| 493 |
-
|
| 494 |
-
items = []
|
| 495 |
-
for _, row in sample_items.iterrows():
|
| 496 |
-
items.append({
|
| 497 |
-
"product_id": int(row['product_id']),
|
| 498 |
-
"category_id": int(row['category_id']),
|
| 499 |
-
"category_code": str(row['category_code']),
|
| 500 |
-
"brand": str(row['brand']) if pd.notna(row['brand']) else 'Unknown',
|
| 501 |
-
"price": float(row['price'])
|
| 502 |
-
})
|
| 503 |
-
|
| 504 |
-
return {"items": items, "total": len(items), "training_approach": "2-phase"}
|
| 505 |
-
|
| 506 |
-
except Exception as e:
|
| 507 |
-
raise HTTPException(status_code=500, detail=f"Error retrieving sample items: {str(e)}")
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
if __name__ == "__main__":
|
| 511 |
-
print("π Starting 2-Phase Training Recommendation API...")
|
| 512 |
-
print("π Training approach: Pre-trained item tower + Joint fine-tuning")
|
| 513 |
-
print("π Server will be available at: http://localhost:8000")
|
| 514 |
-
print("π API docs at: http://localhost:8000/docs")
|
| 515 |
-
|
| 516 |
-
uvicorn.run(
|
| 517 |
-
"api_2phase:app",
|
| 518 |
-
host="0.0.0.0",
|
| 519 |
-
port=8000,
|
| 520 |
-
reload=True
|
| 521 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
api_joint.py
DELETED
|
@@ -1,522 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""
|
| 3 |
-
API for Single Joint Trained Recommendation System
|
| 4 |
-
|
| 5 |
-
This API serves recommendations from a model trained using the single joint approach:
|
| 6 |
-
- Both user and item towers trained simultaneously from scratch
|
| 7 |
-
- End-to-end optimization without pre-training phases
|
| 8 |
-
|
| 9 |
-
Usage:
|
| 10 |
-
python api_joint.py
|
| 11 |
-
|
| 12 |
-
Then access: http://localhost:8000
|
| 13 |
-
"""
|
| 14 |
-
|
| 15 |
-
from fastapi import FastAPI, HTTPException
|
| 16 |
-
from fastapi.middleware.cors import CORSMiddleware
|
| 17 |
-
from pydantic import BaseModel
|
| 18 |
-
from typing import List, Optional, Dict, Any
|
| 19 |
-
import uvicorn
|
| 20 |
-
import os
|
| 21 |
-
import sys
|
| 22 |
-
import pandas as pd
|
| 23 |
-
|
| 24 |
-
# Add src to path for imports and set working directory
|
| 25 |
-
parent_dir = os.path.dirname(__file__)
|
| 26 |
-
sys.path.append(parent_dir)
|
| 27 |
-
os.chdir(parent_dir) # Change to project root directory
|
| 28 |
-
|
| 29 |
-
from src.inference.recommendation_engine import RecommendationEngine
|
| 30 |
-
from src.utils.real_user_selector import RealUserSelector
|
| 31 |
-
|
| 32 |
-
# Initialize FastAPI app
|
| 33 |
-
app = FastAPI(
|
| 34 |
-
title="Two-Tower Recommendation API (Single Joint Training)",
|
| 35 |
-
description="API for serving recommendations using a two-tower architecture trained with single joint approach",
|
| 36 |
-
version="1.0.0-joint"
|
| 37 |
-
)
|
| 38 |
-
|
| 39 |
-
# Add CORS middleware
|
| 40 |
-
app.add_middleware(
|
| 41 |
-
CORSMiddleware,
|
| 42 |
-
allow_origins=["*"], # Configure appropriately for production
|
| 43 |
-
allow_credentials=True,
|
| 44 |
-
allow_methods=["*"],
|
| 45 |
-
allow_headers=["*"],
|
| 46 |
-
)
|
| 47 |
-
|
| 48 |
-
# Global instances
|
| 49 |
-
recommendation_engine = None
|
| 50 |
-
enhanced_recommendation_engine = None
|
| 51 |
-
real_user_selector = None
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
# Pydantic models for request/response
|
| 55 |
-
class UserProfile(BaseModel):
|
| 56 |
-
age: int
|
| 57 |
-
gender: str # "male" or "female"
|
| 58 |
-
income: float
|
| 59 |
-
interaction_history: Optional[List[int]] = []
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
class RecommendationRequest(BaseModel):
|
| 63 |
-
user_profile: UserProfile
|
| 64 |
-
num_recommendations: int = 10
|
| 65 |
-
recommendation_type: str = "hybrid" # "collaborative", "content", "hybrid", "enhanced", "enhanced_128d", "category_focused"
|
| 66 |
-
collaborative_weight: Optional[float] = 0.7
|
| 67 |
-
category_boost: Optional[float] = 1.5 # For enhanced recommendations
|
| 68 |
-
enable_category_boost: Optional[bool] = True
|
| 69 |
-
enable_diversity: Optional[bool] = True
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
class ItemSimilarityRequest(BaseModel):
|
| 73 |
-
item_id: int
|
| 74 |
-
num_recommendations: int = 10
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
class RatingPredictionRequest(BaseModel):
|
| 78 |
-
user_profile: UserProfile
|
| 79 |
-
item_id: int
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
class ItemInfo(BaseModel):
|
| 83 |
-
product_id: int
|
| 84 |
-
category_id: int
|
| 85 |
-
category_code: str
|
| 86 |
-
brand: str
|
| 87 |
-
price: float
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
class RecommendationResponse(BaseModel):
|
| 91 |
-
item_id: int
|
| 92 |
-
score: float
|
| 93 |
-
item_info: ItemInfo
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
class RecommendationsResponse(BaseModel):
|
| 97 |
-
recommendations: List[RecommendationResponse]
|
| 98 |
-
user_profile: UserProfile
|
| 99 |
-
recommendation_type: str
|
| 100 |
-
total_count: int
|
| 101 |
-
training_approach: str = "single-joint"
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
class RatingPredictionResponse(BaseModel):
|
| 105 |
-
user_profile: UserProfile
|
| 106 |
-
item_id: int
|
| 107 |
-
predicted_rating: float
|
| 108 |
-
item_info: ItemInfo
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
class RealUserProfile(BaseModel):
|
| 112 |
-
user_id: int
|
| 113 |
-
age: int
|
| 114 |
-
gender: str
|
| 115 |
-
income: int
|
| 116 |
-
interaction_history: List[int]
|
| 117 |
-
interaction_stats: Dict[str, int]
|
| 118 |
-
interaction_pattern: str
|
| 119 |
-
summary: str
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
class RealUsersResponse(BaseModel):
|
| 123 |
-
users: List[RealUserProfile]
|
| 124 |
-
total_count: int
|
| 125 |
-
dataset_summary: Dict[str, Any]
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
@app.on_event("startup")
|
| 129 |
-
async def startup_event():
|
| 130 |
-
"""Initialize the recommendation engines and real user selector on startup."""
|
| 131 |
-
global recommendation_engine, enhanced_recommendation_engine, real_user_selector
|
| 132 |
-
|
| 133 |
-
print("π Starting Single Joint Training API...")
|
| 134 |
-
print(" Training approach: End-to-end joint optimization from scratch")
|
| 135 |
-
|
| 136 |
-
try:
|
| 137 |
-
print("Loading single joint trained recommendation engine...")
|
| 138 |
-
recommendation_engine = RecommendationEngine()
|
| 139 |
-
print("β
Single joint recommendation engine loaded successfully!")
|
| 140 |
-
except Exception as e:
|
| 141 |
-
print(f"β Error loading recommendation engine: {e}")
|
| 142 |
-
recommendation_engine = None
|
| 143 |
-
|
| 144 |
-
try:
|
| 145 |
-
print("Loading enhanced recommendation engine...")
|
| 146 |
-
# Try enhanced 128D engine first, fallback to regular enhanced
|
| 147 |
-
try:
|
| 148 |
-
from src.inference.enhanced_recommendation_engine_128d import Enhanced128DRecommendationEngine
|
| 149 |
-
enhanced_recommendation_engine = Enhanced128DRecommendationEngine()
|
| 150 |
-
print("β
Using Enhanced 128D Recommendation Engine")
|
| 151 |
-
except:
|
| 152 |
-
from src.inference.enhanced_recommendation_engine import EnhancedRecommendationEngine
|
| 153 |
-
enhanced_recommendation_engine = EnhancedRecommendationEngine()
|
| 154 |
-
print("β οΈ Using fallback Enhanced Recommendation Engine")
|
| 155 |
-
print("Enhanced recommendation engine loaded successfully!")
|
| 156 |
-
except Exception as e:
|
| 157 |
-
print(f"Error loading enhanced recommendation engine: {e}")
|
| 158 |
-
enhanced_recommendation_engine = None
|
| 159 |
-
|
| 160 |
-
try:
|
| 161 |
-
print("Loading real user selector...")
|
| 162 |
-
real_user_selector = RealUserSelector()
|
| 163 |
-
print("Real user selector loaded successfully!")
|
| 164 |
-
except Exception as e:
|
| 165 |
-
print(f"Error loading real user selector: {e}")
|
| 166 |
-
real_user_selector = None
|
| 167 |
-
|
| 168 |
-
print("π― Single Joint API ready to serve recommendations!")
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
@app.get("/")
|
| 172 |
-
async def root():
|
| 173 |
-
"""Root endpoint with API information."""
|
| 174 |
-
return {
|
| 175 |
-
"message": "Two-Tower Recommendation API (Single Joint Training)",
|
| 176 |
-
"version": "1.0.0-joint",
|
| 177 |
-
"training_approach": "single-joint (end-to-end optimization from scratch)",
|
| 178 |
-
"status": "active" if recommendation_engine is not None else "initialization_failed"
|
| 179 |
-
}
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
@app.get("/health")
|
| 183 |
-
async def health_check():
|
| 184 |
-
"""Health check endpoint."""
|
| 185 |
-
return {
|
| 186 |
-
"status": "healthy" if recommendation_engine is not None else "unhealthy",
|
| 187 |
-
"engine_loaded": recommendation_engine is not None,
|
| 188 |
-
"training_approach": "single-joint"
|
| 189 |
-
}
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
@app.get("/model-info")
|
| 193 |
-
async def model_info():
|
| 194 |
-
"""Get information about the loaded model."""
|
| 195 |
-
if recommendation_engine is None:
|
| 196 |
-
raise HTTPException(status_code=503, detail="Recommendation engine not available")
|
| 197 |
-
|
| 198 |
-
return {
|
| 199 |
-
"training_approach": "single-joint",
|
| 200 |
-
"description": "User and item towers trained simultaneously from scratch",
|
| 201 |
-
"advantages": [
|
| 202 |
-
"End-to-end optimization for better task alignment",
|
| 203 |
-
"No pre-training phase required",
|
| 204 |
-
"Faster overall training pipeline",
|
| 205 |
-
"Direct optimization for recommendation objectives"
|
| 206 |
-
],
|
| 207 |
-
"embedding_dimension": 128,
|
| 208 |
-
"item_vocab_size": len(recommendation_engine.data_processor.item_vocab) if recommendation_engine.data_processor else "unknown",
|
| 209 |
-
"artifacts_loaded": {
|
| 210 |
-
"user_tower": "src/artifacts/user_tower_weights_best",
|
| 211 |
-
"item_tower_joint": "src/artifacts/item_tower_weights_finetuned_best",
|
| 212 |
-
"rating_model": "src/artifacts/rating_model_weights_best"
|
| 213 |
-
}
|
| 214 |
-
}
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
@app.get("/real-users", response_model=RealUsersResponse)
|
| 218 |
-
async def get_real_users(count: int = 100, min_interactions: int = 5):
|
| 219 |
-
"""Get real user profiles with genuine interaction histories."""
|
| 220 |
-
|
| 221 |
-
if real_user_selector is None:
|
| 222 |
-
raise HTTPException(status_code=503, detail="Real user selector not available")
|
| 223 |
-
|
| 224 |
-
try:
|
| 225 |
-
# Get real user profiles
|
| 226 |
-
real_users = real_user_selector.get_real_users(n=count, min_interactions=min_interactions)
|
| 227 |
-
|
| 228 |
-
# Get dataset summary
|
| 229 |
-
dataset_summary = real_user_selector.get_dataset_summary()
|
| 230 |
-
|
| 231 |
-
# Format users for response
|
| 232 |
-
formatted_users = []
|
| 233 |
-
for user in real_users:
|
| 234 |
-
formatted_users.append(RealUserProfile(**user))
|
| 235 |
-
|
| 236 |
-
return RealUsersResponse(
|
| 237 |
-
users=formatted_users,
|
| 238 |
-
total_count=len(formatted_users),
|
| 239 |
-
dataset_summary=dataset_summary
|
| 240 |
-
)
|
| 241 |
-
|
| 242 |
-
except Exception as e:
|
| 243 |
-
raise HTTPException(status_code=500, detail=f"Error retrieving real users: {str(e)}")
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
@app.get("/real-users/{user_id}")
|
| 247 |
-
async def get_real_user_details(user_id: int):
|
| 248 |
-
"""Get detailed interaction breakdown for a specific real user."""
|
| 249 |
-
|
| 250 |
-
if real_user_selector is None:
|
| 251 |
-
raise HTTPException(status_code=503, detail="Real user selector not available")
|
| 252 |
-
|
| 253 |
-
try:
|
| 254 |
-
user_details = real_user_selector.get_user_interaction_details(user_id)
|
| 255 |
-
|
| 256 |
-
if "error" in user_details:
|
| 257 |
-
raise HTTPException(status_code=404, detail=user_details["error"])
|
| 258 |
-
|
| 259 |
-
return user_details
|
| 260 |
-
|
| 261 |
-
except Exception as e:
|
| 262 |
-
raise HTTPException(status_code=500, detail=f"Error retrieving user details: {str(e)}")
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
@app.get("/dataset-summary")
|
| 266 |
-
async def get_dataset_summary():
|
| 267 |
-
"""Get summary statistics of the real dataset."""
|
| 268 |
-
|
| 269 |
-
if real_user_selector is None:
|
| 270 |
-
raise HTTPException(status_code=503, detail="Real user selector not available")
|
| 271 |
-
|
| 272 |
-
try:
|
| 273 |
-
return real_user_selector.get_dataset_summary()
|
| 274 |
-
|
| 275 |
-
except Exception as e:
|
| 276 |
-
raise HTTPException(status_code=500, detail=f"Error retrieving dataset summary: {str(e)}")
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
@app.post("/recommendations", response_model=RecommendationsResponse)
|
| 280 |
-
async def get_recommendations(request: RecommendationRequest):
|
| 281 |
-
"""Get item recommendations for a user."""
|
| 282 |
-
|
| 283 |
-
if recommendation_engine is None:
|
| 284 |
-
raise HTTPException(status_code=503, detail="Recommendation engine not available")
|
| 285 |
-
|
| 286 |
-
try:
|
| 287 |
-
user_profile = request.user_profile
|
| 288 |
-
|
| 289 |
-
# Generate recommendations based on type
|
| 290 |
-
if request.recommendation_type == "collaborative":
|
| 291 |
-
recommendations = recommendation_engine.recommend_items_collaborative(
|
| 292 |
-
age=user_profile.age,
|
| 293 |
-
gender=user_profile.gender,
|
| 294 |
-
income=user_profile.income,
|
| 295 |
-
interaction_history=user_profile.interaction_history,
|
| 296 |
-
k=request.num_recommendations
|
| 297 |
-
)
|
| 298 |
-
|
| 299 |
-
elif request.recommendation_type == "content":
|
| 300 |
-
if not user_profile.interaction_history:
|
| 301 |
-
raise HTTPException(
|
| 302 |
-
status_code=400,
|
| 303 |
-
detail="Content-based recommendations require interaction history"
|
| 304 |
-
)
|
| 305 |
-
|
| 306 |
-
# Use most recent interaction as seed
|
| 307 |
-
seed_item = user_profile.interaction_history[-1]
|
| 308 |
-
recommendations = recommendation_engine.recommend_items_content_based(
|
| 309 |
-
seed_item_id=seed_item,
|
| 310 |
-
k=request.num_recommendations
|
| 311 |
-
)
|
| 312 |
-
|
| 313 |
-
elif request.recommendation_type == "hybrid":
|
| 314 |
-
recommendations = recommendation_engine.recommend_items_hybrid(
|
| 315 |
-
age=user_profile.age,
|
| 316 |
-
gender=user_profile.gender,
|
| 317 |
-
income=user_profile.income,
|
| 318 |
-
interaction_history=user_profile.interaction_history,
|
| 319 |
-
k=request.num_recommendations,
|
| 320 |
-
collaborative_weight=request.collaborative_weight
|
| 321 |
-
)
|
| 322 |
-
|
| 323 |
-
elif request.recommendation_type == "enhanced":
|
| 324 |
-
if enhanced_recommendation_engine is None:
|
| 325 |
-
raise HTTPException(status_code=503, detail="Enhanced recommendation engine not available")
|
| 326 |
-
|
| 327 |
-
# Check if it's the 128D engine or fallback
|
| 328 |
-
if hasattr(enhanced_recommendation_engine, 'recommend_items_enhanced'):
|
| 329 |
-
# 128D Enhanced engine
|
| 330 |
-
recommendations = enhanced_recommendation_engine.recommend_items_enhanced(
|
| 331 |
-
age=user_profile.age,
|
| 332 |
-
gender=user_profile.gender,
|
| 333 |
-
income=user_profile.income,
|
| 334 |
-
interaction_history=user_profile.interaction_history,
|
| 335 |
-
k=request.num_recommendations,
|
| 336 |
-
diversity_weight=0.3 if request.enable_diversity else 0.0,
|
| 337 |
-
category_boost=request.category_boost if request.enable_category_boost else 1.0
|
| 338 |
-
)
|
| 339 |
-
else:
|
| 340 |
-
# Fallback enhanced engine
|
| 341 |
-
recommendations = enhanced_recommendation_engine.recommend_items_enhanced_hybrid(
|
| 342 |
-
age=user_profile.age,
|
| 343 |
-
gender=user_profile.gender,
|
| 344 |
-
income=user_profile.income,
|
| 345 |
-
interaction_history=user_profile.interaction_history,
|
| 346 |
-
k=request.num_recommendations,
|
| 347 |
-
collaborative_weight=request.collaborative_weight,
|
| 348 |
-
category_boost=request.category_boost,
|
| 349 |
-
enable_category_boost=request.enable_category_boost,
|
| 350 |
-
enable_diversity=request.enable_diversity
|
| 351 |
-
)
|
| 352 |
-
|
| 353 |
-
elif request.recommendation_type == "enhanced_128d":
|
| 354 |
-
if enhanced_recommendation_engine is None or not hasattr(enhanced_recommendation_engine, 'recommend_items_enhanced'):
|
| 355 |
-
raise HTTPException(status_code=503, detail="Enhanced 128D recommendation engine not available")
|
| 356 |
-
|
| 357 |
-
recommendations = enhanced_recommendation_engine.recommend_items_enhanced(
|
| 358 |
-
age=user_profile.age,
|
| 359 |
-
gender=user_profile.gender,
|
| 360 |
-
income=user_profile.income,
|
| 361 |
-
interaction_history=user_profile.interaction_history,
|
| 362 |
-
k=request.num_recommendations,
|
| 363 |
-
diversity_weight=0.3 if request.enable_diversity else 0.0,
|
| 364 |
-
category_boost=request.category_boost if request.enable_category_boost else 1.0
|
| 365 |
-
)
|
| 366 |
-
|
| 367 |
-
elif request.recommendation_type == "category_focused":
|
| 368 |
-
if enhanced_recommendation_engine is None:
|
| 369 |
-
raise HTTPException(status_code=503, detail="Enhanced recommendation engine not available")
|
| 370 |
-
|
| 371 |
-
recommendations = enhanced_recommendation_engine.recommend_items_category_focused(
|
| 372 |
-
age=user_profile.age,
|
| 373 |
-
gender=user_profile.gender,
|
| 374 |
-
income=user_profile.income,
|
| 375 |
-
interaction_history=user_profile.interaction_history,
|
| 376 |
-
k=request.num_recommendations,
|
| 377 |
-
focus_percentage=0.8
|
| 378 |
-
)
|
| 379 |
-
|
| 380 |
-
else:
|
| 381 |
-
raise HTTPException(
|
| 382 |
-
status_code=400,
|
| 383 |
-
detail="Invalid recommendation_type. Must be 'collaborative', 'content', 'hybrid', 'enhanced', 'enhanced_128d', or 'category_focused'"
|
| 384 |
-
)
|
| 385 |
-
|
| 386 |
-
# Format response
|
| 387 |
-
formatted_recommendations = []
|
| 388 |
-
for item_id, score, item_info in recommendations:
|
| 389 |
-
formatted_recommendations.append(
|
| 390 |
-
RecommendationResponse(
|
| 391 |
-
item_id=item_id,
|
| 392 |
-
score=score,
|
| 393 |
-
item_info=ItemInfo(**item_info)
|
| 394 |
-
)
|
| 395 |
-
)
|
| 396 |
-
|
| 397 |
-
return RecommendationsResponse(
|
| 398 |
-
recommendations=formatted_recommendations,
|
| 399 |
-
user_profile=user_profile,
|
| 400 |
-
recommendation_type=request.recommendation_type,
|
| 401 |
-
total_count=len(formatted_recommendations),
|
| 402 |
-
training_approach="single-joint"
|
| 403 |
-
)
|
| 404 |
-
|
| 405 |
-
except Exception as e:
|
| 406 |
-
raise HTTPException(status_code=500, detail=f"Error generating recommendations: {str(e)}")
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
@app.post("/item-similarity", response_model=List[RecommendationResponse])
|
| 410 |
-
async def get_similar_items(request: ItemSimilarityRequest):
|
| 411 |
-
"""Get items similar to a given item."""
|
| 412 |
-
|
| 413 |
-
if recommendation_engine is None:
|
| 414 |
-
raise HTTPException(status_code=503, detail="Recommendation engine not available")
|
| 415 |
-
|
| 416 |
-
try:
|
| 417 |
-
recommendations = recommendation_engine.recommend_items_content_based(
|
| 418 |
-
seed_item_id=request.item_id,
|
| 419 |
-
k=request.num_recommendations
|
| 420 |
-
)
|
| 421 |
-
|
| 422 |
-
formatted_recommendations = []
|
| 423 |
-
for item_id, score, item_info in recommendations:
|
| 424 |
-
formatted_recommendations.append(
|
| 425 |
-
RecommendationResponse(
|
| 426 |
-
item_id=item_id,
|
| 427 |
-
score=score,
|
| 428 |
-
item_info=ItemInfo(**item_info)
|
| 429 |
-
)
|
| 430 |
-
)
|
| 431 |
-
|
| 432 |
-
return formatted_recommendations
|
| 433 |
-
|
| 434 |
-
except Exception as e:
|
| 435 |
-
raise HTTPException(status_code=500, detail=f"Error finding similar items: {str(e)}")
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
@app.post("/predict-rating", response_model=RatingPredictionResponse)
|
| 439 |
-
async def predict_user_item_rating(request: RatingPredictionRequest):
|
| 440 |
-
"""Predict rating for a user-item pair."""
|
| 441 |
-
|
| 442 |
-
if recommendation_engine is None:
|
| 443 |
-
raise HTTPException(status_code=503, detail="Recommendation engine not available")
|
| 444 |
-
|
| 445 |
-
try:
|
| 446 |
-
user_profile = request.user_profile
|
| 447 |
-
|
| 448 |
-
predicted_rating = recommendation_engine.predict_rating(
|
| 449 |
-
age=user_profile.age,
|
| 450 |
-
gender=user_profile.gender,
|
| 451 |
-
income=user_profile.income,
|
| 452 |
-
interaction_history=user_profile.interaction_history,
|
| 453 |
-
item_id=request.item_id
|
| 454 |
-
)
|
| 455 |
-
|
| 456 |
-
item_info = recommendation_engine._get_item_info(request.item_id)
|
| 457 |
-
|
| 458 |
-
return RatingPredictionResponse(
|
| 459 |
-
user_profile=user_profile,
|
| 460 |
-
item_id=request.item_id,
|
| 461 |
-
predicted_rating=predicted_rating,
|
| 462 |
-
item_info=ItemInfo(**item_info)
|
| 463 |
-
)
|
| 464 |
-
|
| 465 |
-
except Exception as e:
|
| 466 |
-
raise HTTPException(status_code=500, detail=f"Error predicting rating: {str(e)}")
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
@app.get("/items/{item_id}", response_model=ItemInfo)
|
| 470 |
-
async def get_item_info(item_id: int):
|
| 471 |
-
"""Get information about a specific item."""
|
| 472 |
-
|
| 473 |
-
if recommendation_engine is None:
|
| 474 |
-
raise HTTPException(status_code=503, detail="Recommendation engine not available")
|
| 475 |
-
|
| 476 |
-
try:
|
| 477 |
-
item_info = recommendation_engine._get_item_info(item_id)
|
| 478 |
-
return ItemInfo(**item_info)
|
| 479 |
-
|
| 480 |
-
except Exception as e:
|
| 481 |
-
raise HTTPException(status_code=500, detail=f"Error retrieving item info: {str(e)}")
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
@app.get("/items")
|
| 485 |
-
async def get_sample_items(limit: int = 20):
|
| 486 |
-
"""Get a sample of items for testing."""
|
| 487 |
-
|
| 488 |
-
if recommendation_engine is None:
|
| 489 |
-
raise HTTPException(status_code=503, detail="Recommendation engine not available")
|
| 490 |
-
|
| 491 |
-
try:
|
| 492 |
-
# Get sample items from the dataframe
|
| 493 |
-
sample_items = recommendation_engine.items_df.sample(n=min(limit, len(recommendation_engine.items_df)))
|
| 494 |
-
|
| 495 |
-
items = []
|
| 496 |
-
for _, row in sample_items.iterrows():
|
| 497 |
-
items.append({
|
| 498 |
-
"product_id": int(row['product_id']),
|
| 499 |
-
"category_id": int(row['category_id']),
|
| 500 |
-
"category_code": str(row['category_code']),
|
| 501 |
-
"brand": str(row['brand']) if pd.notna(row['brand']) else 'Unknown',
|
| 502 |
-
"price": float(row['price'])
|
| 503 |
-
})
|
| 504 |
-
|
| 505 |
-
return {"items": items, "total": len(items), "training_approach": "single-joint"}
|
| 506 |
-
|
| 507 |
-
except Exception as e:
|
| 508 |
-
raise HTTPException(status_code=500, detail=f"Error retrieving sample items: {str(e)}")
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
if __name__ == "__main__":
|
| 512 |
-
print("π Starting Single Joint Training Recommendation API...")
|
| 513 |
-
print("β‘ Training approach: End-to-end joint optimization from scratch")
|
| 514 |
-
print("π Server will be available at: http://localhost:8000")
|
| 515 |
-
print("π API docs at: http://localhost:8000/docs")
|
| 516 |
-
|
| 517 |
-
uvicorn.run(
|
| 518 |
-
"api_joint:app",
|
| 519 |
-
host="0.0.0.0",
|
| 520 |
-
port=8000,
|
| 521 |
-
reload=True
|
| 522 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
datasets/interactions.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
datasets/items.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
datasets/users.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
frontend/src/App.css
CHANGED
|
@@ -3,6 +3,50 @@
|
|
| 3 |
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', sans-serif;
|
| 4 |
}
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
/* Performance Monitoring Widget */
|
| 7 |
.performance-widget {
|
| 8 |
position: fixed;
|
|
@@ -1220,6 +1264,24 @@
|
|
| 1220 |
font-size: 12px;
|
| 1221 |
}
|
| 1222 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1223 |
.pattern-summary {
|
| 1224 |
display: flex;
|
| 1225 |
gap: 20px;
|
|
|
|
| 3 |
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', sans-serif;
|
| 4 |
}
|
| 5 |
|
| 6 |
+
/* Enhanced Demographics Styling */
|
| 7 |
+
.demographic-features {
|
| 8 |
+
background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%);
|
| 9 |
+
padding: 20px;
|
| 10 |
+
border-radius: 10px;
|
| 11 |
+
border: 2px solid #dee2e6;
|
| 12 |
+
margin: 15px 0;
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
.demographic-features .form-group {
|
| 16 |
+
position: relative;
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
.demographic-features .form-group label {
|
| 20 |
+
font-weight: 600;
|
| 21 |
+
color: #495057;
|
| 22 |
+
font-size: 14px;
|
| 23 |
+
margin-bottom: 8px;
|
| 24 |
+
display: block;
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
.demographic-features .form-group select {
|
| 28 |
+
width: 100%;
|
| 29 |
+
padding: 10px 12px;
|
| 30 |
+
border: 2px solid #ced4da;
|
| 31 |
+
border-radius: 6px;
|
| 32 |
+
font-size: 14px;
|
| 33 |
+
background: white;
|
| 34 |
+
color: #495057;
|
| 35 |
+
transition: all 0.2s ease;
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
.demographic-features .form-group select:focus {
|
| 39 |
+
outline: none;
|
| 40 |
+
border-color: #007bff;
|
| 41 |
+
box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.25);
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
.demographic-features .form-group select:disabled {
|
| 45 |
+
background-color: #f8f9fa;
|
| 46 |
+
color: #6c757d;
|
| 47 |
+
cursor: not-allowed;
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
/* Performance Monitoring Widget */
|
| 51 |
.performance-widget {
|
| 52 |
position: fixed;
|
|
|
|
| 1264 |
font-size: 12px;
|
| 1265 |
}
|
| 1266 |
|
| 1267 |
+
.pattern-btn.new-user-pattern {
|
| 1268 |
+
border-color: #6c757d;
|
| 1269 |
+
color: #6c757d;
|
| 1270 |
+
background: #f8f9fa;
|
| 1271 |
+
}
|
| 1272 |
+
|
| 1273 |
+
.pattern-btn.new-user-pattern:hover {
|
| 1274 |
+
background: #6c757d;
|
| 1275 |
+
color: white;
|
| 1276 |
+
box-shadow: 0 4px 8px rgba(108, 117, 125, 0.3);
|
| 1277 |
+
}
|
| 1278 |
+
|
| 1279 |
+
.pattern-btn.new-user-pattern.active {
|
| 1280 |
+
background: #6c757d;
|
| 1281 |
+
color: white;
|
| 1282 |
+
box-shadow: 0 4px 8px rgba(108, 117, 125, 0.3);
|
| 1283 |
+
}
|
| 1284 |
+
|
| 1285 |
.pattern-summary {
|
| 1286 |
display: flex;
|
| 1287 |
gap: 20px;
|
frontend/src/App.js
CHANGED
|
@@ -6,6 +6,7 @@ const API_BASE_URL = process.env.REACT_APP_API_URL || 'http://localhost:8000';
|
|
| 6 |
|
| 7 |
// Interaction patterns with realistic ratios
|
| 8 |
const INTERACTION_PATTERNS = [
|
|
|
|
| 9 |
{ name: 'Light Browsing', views: 15, carts: 2, purchases: 0 },
|
| 10 |
{ name: 'Window Shopping', views: 25, carts: 5, purchases: 1 },
|
| 11 |
{ name: 'Serious Shopper', views: 35, carts: 8, purchases: 3 },
|
|
@@ -18,11 +19,15 @@ function App() {
|
|
| 18 |
age: 30,
|
| 19 |
gender: 'male',
|
| 20 |
income: 50000,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
interaction_history: []
|
| 22 |
});
|
| 23 |
|
| 24 |
const [recommendationType, setRecommendationType] = useState('hybrid');
|
| 25 |
-
const [numRecommendations, setNumRecommendations] = useState(
|
| 26 |
const [collaborativeWeight, setCollaborativeWeight] = useState(0.7);
|
| 27 |
|
| 28 |
const [recommendations, setRecommendations] = useState([]);
|
|
@@ -301,6 +306,10 @@ function App() {
|
|
| 301 |
age: user.age,
|
| 302 |
gender: user.gender,
|
| 303 |
income: user.income,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
interaction_history: user.interaction_history.slice(0, 50) // Limit to 50 items
|
| 305 |
});
|
| 306 |
// Clear any synthetic interactions and expanded states
|
|
@@ -443,7 +452,20 @@ function App() {
|
|
| 443 |
|
| 444 |
const handlePatternSelect = (pattern) => {
|
| 445 |
setSelectedPattern(pattern);
|
| 446 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 447 |
};
|
| 448 |
|
| 449 |
const toggleInteractionExpand = (interactionId) => {
|
|
@@ -472,6 +494,46 @@ function App() {
|
|
| 472 |
|
| 473 |
const counts = getInteractionCounts();
|
| 474 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
// Calculate category percentages from user interactions
|
| 476 |
const getCategoryPercentages = () => {
|
| 477 |
console.log('getCategoryPercentages called:', {
|
|
@@ -504,10 +566,7 @@ function App() {
|
|
| 504 |
console.log('Enriched behavioral pattern results:', { categoryCounts, totalInteractions });
|
| 505 |
|
| 506 |
if (totalInteractions > 0) {
|
| 507 |
-
const categoryPercentages =
|
| 508 |
-
Object.keys(categoryCounts).forEach(category => {
|
| 509 |
-
categoryPercentages[category] = ((categoryCounts[category] / totalInteractions) * 100).toFixed(1);
|
| 510 |
-
});
|
| 511 |
console.log('Returning enriched behavioral pattern percentages:', categoryPercentages);
|
| 512 |
return categoryPercentages;
|
| 513 |
}
|
|
@@ -522,20 +581,15 @@ function App() {
|
|
| 522 |
|
| 523 |
interactions.forEach(interaction => {
|
| 524 |
console.log('Processing interaction:', interaction);
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
totalInteractions++;
|
| 529 |
-
}
|
| 530 |
});
|
| 531 |
|
| 532 |
console.log('Synthetic interaction results:', { categoryCounts, totalInteractions });
|
| 533 |
|
| 534 |
if (totalInteractions > 0) {
|
| 535 |
-
const categoryPercentages =
|
| 536 |
-
Object.keys(categoryCounts).forEach(category => {
|
| 537 |
-
categoryPercentages[category] = ((categoryCounts[category] / totalInteractions) * 100).toFixed(1);
|
| 538 |
-
});
|
| 539 |
console.log('Returning synthetic percentages:', categoryPercentages);
|
| 540 |
return categoryPercentages;
|
| 541 |
}
|
|
@@ -554,11 +608,7 @@ function App() {
|
|
| 554 |
totalInteractions++;
|
| 555 |
});
|
| 556 |
|
| 557 |
-
const categoryPercentages =
|
| 558 |
-
Object.keys(categoryCounts).forEach(category => {
|
| 559 |
-
categoryPercentages[category] = ((categoryCounts[category] / totalInteractions) * 100).toFixed(1);
|
| 560 |
-
});
|
| 561 |
-
|
| 562 |
return categoryPercentages;
|
| 563 |
}
|
| 564 |
|
|
@@ -708,7 +758,13 @@ function App() {
|
|
| 708 |
<div className="real-user-stats">
|
| 709 |
<div className="user-stat">
|
| 710 |
<span className="stat-label">Demographics:</span>
|
| 711 |
-
<span className="stat-value">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 712 |
</div>
|
| 713 |
<div className="user-stat">
|
| 714 |
<span className="stat-label">Behavior Pattern:</span>
|
|
@@ -847,6 +903,77 @@ function App() {
|
|
| 847 |
/>
|
| 848 |
</div>
|
| 849 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 850 |
</div>
|
| 851 |
|
| 852 |
{/* Random Behavioral Patterns for Custom Users */}
|
|
@@ -979,7 +1106,7 @@ function App() {
|
|
| 979 |
<div className="category-percentages">
|
| 980 |
{Object.entries(categoryPercentages)
|
| 981 |
.sort((a, b) => parseFloat(b[1]) - parseFloat(a[1]))
|
| 982 |
-
.slice(0,
|
| 983 |
.map(([category, percentage]) => (
|
| 984 |
<div key={category} className="category-item">
|
| 985 |
<div className="category-bar-container">
|
|
@@ -1110,7 +1237,7 @@ function App() {
|
|
| 1110 |
)}
|
| 1111 |
|
| 1112 |
{/* Category Analysis for Custom Users */}
|
| 1113 |
-
{(selectedBehavioralPattern || interactions.length > 0 || userProfile.interaction_history.length > 0) && (
|
| 1114 |
<div
|
| 1115 |
key={`category-analysis-${interactions.length}-${selectedBehavioralPattern?.id || 'none'}-${sampleItems.length}`}
|
| 1116 |
className="category-analysis"
|
|
@@ -1128,7 +1255,7 @@ function App() {
|
|
| 1128 |
{Object.keys(categoryPercentages).length > 0 ? (
|
| 1129 |
Object.entries(categoryPercentages)
|
| 1130 |
.sort((a, b) => parseFloat(b[1]) - parseFloat(a[1]))
|
| 1131 |
-
.slice(0,
|
| 1132 |
.map(([category, percentage]) => (
|
| 1133 |
<div key={category} className="category-item">
|
| 1134 |
<div className="category-bar-container">
|
|
@@ -1141,6 +1268,23 @@ function App() {
|
|
| 1141 |
<span className="category-percent">{percentage}%</span>
|
| 1142 |
</div>
|
| 1143 |
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1144 |
) : (
|
| 1145 |
<div className="category-loading">
|
| 1146 |
<p>Processing interaction categories...</p>
|
|
@@ -1325,18 +1469,24 @@ function App() {
|
|
| 1325 |
)}
|
| 1326 |
|
| 1327 |
<h3>Synthetic Interaction Patterns</h3>
|
| 1328 |
-
<p>Generate realistic user behavior patterns with proportional view, cart, and purchase events</p>
|
| 1329 |
|
| 1330 |
<div className="pattern-buttons">
|
| 1331 |
{INTERACTION_PATTERNS.map((pattern, index) => (
|
| 1332 |
<button
|
| 1333 |
key={index}
|
| 1334 |
-
className={`pattern-btn ${selectedPattern?.name === pattern.name ? 'active' : ''}`}
|
| 1335 |
onClick={() => handlePatternSelect(pattern)}
|
| 1336 |
>
|
| 1337 |
{pattern.name}
|
| 1338 |
<br />
|
| 1339 |
-
<small>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1340 |
</button>
|
| 1341 |
))}
|
| 1342 |
<button
|
|
@@ -1347,6 +1497,25 @@ function App() {
|
|
| 1347 |
Clear All
|
| 1348 |
</button>
|
| 1349 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1350 |
</>
|
| 1351 |
)}
|
| 1352 |
|
|
@@ -1498,11 +1667,10 @@ function App() {
|
|
| 1498 |
value={recommendationType}
|
| 1499 |
onChange={(e) => setRecommendationType(e.target.value)}
|
| 1500 |
>
|
| 1501 |
-
<option value="hybrid">Hybrid</option>
|
| 1502 |
-
<option value="enhanced">π― Enhanced Hybrid (Category-Aware)</option>
|
| 1503 |
-
<option value="category_focused">π― Category Focused (80% Match)</option>
|
| 1504 |
<option value="collaborative">Collaborative Filtering</option>
|
| 1505 |
<option value="content">Content-Based</option>
|
|
|
|
| 1506 |
</select>
|
| 1507 |
</div>
|
| 1508 |
|
|
@@ -1521,7 +1689,7 @@ function App() {
|
|
| 1521 |
</select>
|
| 1522 |
</div>
|
| 1523 |
|
| 1524 |
-
{
|
| 1525 |
<div className="form-group">
|
| 1526 |
<label htmlFor="collabWeight">Collaborative Weight:</label>
|
| 1527 |
<input
|
|
@@ -1548,7 +1716,13 @@ function App() {
|
|
| 1548 |
|
| 1549 |
{recommendationType === 'content' && userProfile.interaction_history.length === 0 && (
|
| 1550 |
<p style={{color: '#dc3545', marginTop: '10px', fontSize: '14px'}}>
|
| 1551 |
-
Content-based recommendations require interaction history. Please select
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1552 |
</p>
|
| 1553 |
)}
|
| 1554 |
</div>
|
|
@@ -1567,7 +1741,7 @@ function App() {
|
|
| 1567 |
|
| 1568 |
<div className="stats">
|
| 1569 |
<strong>User Profile:</strong> {userProfile.age}yr {userProfile.gender},
|
| 1570 |
-
${userProfile.income.toLocaleString()} income
|
| 1571 |
{selectedCategory && (
|
| 1572 |
<span> | <strong>Category Filter:</strong> <span className="category-filter-display">{selectedCategory.replace(/\./g, ' > ')}</span></span>
|
| 1573 |
)}
|
|
|
|
| 6 |
|
| 7 |
// Interaction patterns with realistic ratios
|
| 8 |
const INTERACTION_PATTERNS = [
|
| 9 |
+
{ name: 'New User (No History)', views: 0, carts: 0, purchases: 0, isNewUser: true },
|
| 10 |
{ name: 'Light Browsing', views: 15, carts: 2, purchases: 0 },
|
| 11 |
{ name: 'Window Shopping', views: 25, carts: 5, purchases: 1 },
|
| 12 |
{ name: 'Serious Shopper', views: 35, carts: 8, purchases: 3 },
|
|
|
|
| 19 |
age: 30,
|
| 20 |
gender: 'male',
|
| 21 |
income: 50000,
|
| 22 |
+
profession: 'Technology',
|
| 23 |
+
location: 'Urban',
|
| 24 |
+
education_level: "Bachelor's",
|
| 25 |
+
marital_status: 'Single',
|
| 26 |
interaction_history: []
|
| 27 |
});
|
| 28 |
|
| 29 |
const [recommendationType, setRecommendationType] = useState('hybrid');
|
| 30 |
+
const [numRecommendations, setNumRecommendations] = useState(10);
|
| 31 |
const [collaborativeWeight, setCollaborativeWeight] = useState(0.7);
|
| 32 |
|
| 33 |
const [recommendations, setRecommendations] = useState([]);
|
|
|
|
| 306 |
age: user.age,
|
| 307 |
gender: user.gender,
|
| 308 |
income: user.income,
|
| 309 |
+
profession: user.profession || 'Other',
|
| 310 |
+
location: user.location || 'Urban',
|
| 311 |
+
education_level: user.education_level || 'High School',
|
| 312 |
+
marital_status: user.marital_status || 'Single',
|
| 313 |
interaction_history: user.interaction_history.slice(0, 50) // Limit to 50 items
|
| 314 |
});
|
| 315 |
// Clear any synthetic interactions and expanded states
|
|
|
|
| 452 |
|
| 453 |
const handlePatternSelect = (pattern) => {
|
| 454 |
setSelectedPattern(pattern);
|
| 455 |
+
|
| 456 |
+
// Handle New User (zero interactions) pattern specially
|
| 457 |
+
if (pattern.isNewUser) {
|
| 458 |
+
// Clear all interactions and interaction history for new user
|
| 459 |
+
setInteractions([]);
|
| 460 |
+
setUserProfile(prev => ({
|
| 461 |
+
...prev,
|
| 462 |
+
interaction_history: []
|
| 463 |
+
}));
|
| 464 |
+
console.log('Selected New User pattern - cleared all interactions');
|
| 465 |
+
} else {
|
| 466 |
+
// Generate realistic interactions for other patterns
|
| 467 |
+
generateRealisticInteractions(pattern);
|
| 468 |
+
}
|
| 469 |
};
|
| 470 |
|
| 471 |
const toggleInteractionExpand = (interactionId) => {
|
|
|
|
| 494 |
|
| 495 |
const counts = getInteractionCounts();
|
| 496 |
|
| 497 |
+
// Utility function to normalize percentages to sum to exactly 100%
|
| 498 |
+
const normalizePercentages = (categoryCounts, totalInteractions) => {
|
| 499 |
+
if (totalInteractions === 0) return {};
|
| 500 |
+
|
| 501 |
+
const categories = Object.keys(categoryCounts);
|
| 502 |
+
if (categories.length === 0) return {};
|
| 503 |
+
|
| 504 |
+
// Calculate raw percentages
|
| 505 |
+
const rawPercentages = {};
|
| 506 |
+
categories.forEach(category => {
|
| 507 |
+
rawPercentages[category] = (categoryCounts[category] / totalInteractions) * 100;
|
| 508 |
+
});
|
| 509 |
+
|
| 510 |
+
// Round all percentages to 1 decimal place
|
| 511 |
+
const roundedPercentages = {};
|
| 512 |
+
let totalRounded = 0;
|
| 513 |
+
categories.forEach(category => {
|
| 514 |
+
roundedPercentages[category] = Math.round(rawPercentages[category] * 10) / 10;
|
| 515 |
+
totalRounded += roundedPercentages[category];
|
| 516 |
+
});
|
| 517 |
+
|
| 518 |
+
// Adjust the largest category to make total exactly 100%
|
| 519 |
+
const difference = 100.0 - totalRounded;
|
| 520 |
+
if (Math.abs(difference) > 0.01) {
|
| 521 |
+
// Find category with largest raw percentage
|
| 522 |
+
const largestCategory = categories.reduce((max, category) =>
|
| 523 |
+
rawPercentages[category] > rawPercentages[max] ? category : max
|
| 524 |
+
);
|
| 525 |
+
roundedPercentages[largestCategory] = Math.round((roundedPercentages[largestCategory] + difference) * 10) / 10;
|
| 526 |
+
}
|
| 527 |
+
|
| 528 |
+
// Convert to string with 1 decimal place
|
| 529 |
+
const normalizedPercentages = {};
|
| 530 |
+
categories.forEach(category => {
|
| 531 |
+
normalizedPercentages[category] = roundedPercentages[category].toFixed(1);
|
| 532 |
+
});
|
| 533 |
+
|
| 534 |
+
return normalizedPercentages;
|
| 535 |
+
};
|
| 536 |
+
|
| 537 |
// Calculate category percentages from user interactions
|
| 538 |
const getCategoryPercentages = () => {
|
| 539 |
console.log('getCategoryPercentages called:', {
|
|
|
|
| 566 |
console.log('Enriched behavioral pattern results:', { categoryCounts, totalInteractions });
|
| 567 |
|
| 568 |
if (totalInteractions > 0) {
|
| 569 |
+
const categoryPercentages = normalizePercentages(categoryCounts, totalInteractions);
|
|
|
|
|
|
|
|
|
|
| 570 |
console.log('Returning enriched behavioral pattern percentages:', categoryPercentages);
|
| 571 |
return categoryPercentages;
|
| 572 |
}
|
|
|
|
| 581 |
|
| 582 |
interactions.forEach(interaction => {
|
| 583 |
console.log('Processing interaction:', interaction);
|
| 584 |
+
const category = interaction.category_code || interaction.category || 'Unknown';
|
| 585 |
+
categoryCounts[category] = (categoryCounts[category] || 0) + 1;
|
| 586 |
+
totalInteractions++;
|
|
|
|
|
|
|
| 587 |
});
|
| 588 |
|
| 589 |
console.log('Synthetic interaction results:', { categoryCounts, totalInteractions });
|
| 590 |
|
| 591 |
if (totalInteractions > 0) {
|
| 592 |
+
const categoryPercentages = normalizePercentages(categoryCounts, totalInteractions);
|
|
|
|
|
|
|
|
|
|
| 593 |
console.log('Returning synthetic percentages:', categoryPercentages);
|
| 594 |
return categoryPercentages;
|
| 595 |
}
|
|
|
|
| 608 |
totalInteractions++;
|
| 609 |
});
|
| 610 |
|
| 611 |
+
const categoryPercentages = normalizePercentages(categoryCounts, totalInteractions);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 612 |
return categoryPercentages;
|
| 613 |
}
|
| 614 |
|
|
|
|
| 758 |
<div className="real-user-stats">
|
| 759 |
<div className="user-stat">
|
| 760 |
<span className="stat-label">Demographics:</span>
|
| 761 |
+
<span className="stat-value">
|
| 762 |
+
{selectedRealUser.age}yr {selectedRealUser.gender}, ${selectedRealUser.income.toLocaleString()}
|
| 763 |
+
{selectedRealUser.profession && ` | ${selectedRealUser.profession}`}
|
| 764 |
+
{selectedRealUser.location && ` | ${selectedRealUser.location}`}
|
| 765 |
+
{selectedRealUser.education_level && ` | ${selectedRealUser.education_level}`}
|
| 766 |
+
{selectedRealUser.marital_status && ` | ${selectedRealUser.marital_status}`}
|
| 767 |
+
</span>
|
| 768 |
</div>
|
| 769 |
<div className="user-stat">
|
| 770 |
<span className="stat-label">Behavior Pattern:</span>
|
|
|
|
| 903 |
/>
|
| 904 |
</div>
|
| 905 |
</div>
|
| 906 |
+
|
| 907 |
+
{/* New Demographic Features */}
|
| 908 |
+
<div className="form-row demographic-features">
|
| 909 |
+
<div className="form-group">
|
| 910 |
+
<label htmlFor="profession">Profession:</label>
|
| 911 |
+
<select
|
| 912 |
+
id="profession"
|
| 913 |
+
value={userProfile.profession}
|
| 914 |
+
onChange={(e) => handleProfileChange('profession', e.target.value)}
|
| 915 |
+
disabled={useRealUsers && selectedRealUser}
|
| 916 |
+
style={{backgroundColor: useRealUsers && selectedRealUser ? '#f5f5f5' : 'white'}}
|
| 917 |
+
>
|
| 918 |
+
<option value="Technology">Technology</option>
|
| 919 |
+
<option value="Healthcare">Healthcare</option>
|
| 920 |
+
<option value="Education">Education</option>
|
| 921 |
+
<option value="Finance">Finance</option>
|
| 922 |
+
<option value="Retail">Retail</option>
|
| 923 |
+
<option value="Manufacturing">Manufacturing</option>
|
| 924 |
+
<option value="Services">Services</option>
|
| 925 |
+
<option value="Other">Other</option>
|
| 926 |
+
</select>
|
| 927 |
+
</div>
|
| 928 |
+
|
| 929 |
+
<div className="form-group">
|
| 930 |
+
<label htmlFor="location">Location:</label>
|
| 931 |
+
<select
|
| 932 |
+
id="location"
|
| 933 |
+
value={userProfile.location}
|
| 934 |
+
onChange={(e) => handleProfileChange('location', e.target.value)}
|
| 935 |
+
disabled={useRealUsers && selectedRealUser}
|
| 936 |
+
style={{backgroundColor: useRealUsers && selectedRealUser ? '#f5f5f5' : 'white'}}
|
| 937 |
+
>
|
| 938 |
+
<option value="Urban">Urban</option>
|
| 939 |
+
<option value="Suburban">Suburban</option>
|
| 940 |
+
<option value="Rural">Rural</option>
|
| 941 |
+
</select>
|
| 942 |
+
</div>
|
| 943 |
+
|
| 944 |
+
<div className="form-group">
|
| 945 |
+
<label htmlFor="education_level">Education Level:</label>
|
| 946 |
+
<select
|
| 947 |
+
id="education_level"
|
| 948 |
+
value={userProfile.education_level}
|
| 949 |
+
onChange={(e) => handleProfileChange('education_level', e.target.value)}
|
| 950 |
+
disabled={useRealUsers && selectedRealUser}
|
| 951 |
+
style={{backgroundColor: useRealUsers && selectedRealUser ? '#f5f5f5' : 'white'}}
|
| 952 |
+
>
|
| 953 |
+
<option value="High School">High School</option>
|
| 954 |
+
<option value="Some College">Some College</option>
|
| 955 |
+
<option value="Bachelor's">Bachelor's</option>
|
| 956 |
+
<option value="Master's">Master's</option>
|
| 957 |
+
<option value="PhD+">PhD+</option>
|
| 958 |
+
</select>
|
| 959 |
+
</div>
|
| 960 |
+
|
| 961 |
+
<div className="form-group">
|
| 962 |
+
<label htmlFor="marital_status">Marital Status:</label>
|
| 963 |
+
<select
|
| 964 |
+
id="marital_status"
|
| 965 |
+
value={userProfile.marital_status}
|
| 966 |
+
onChange={(e) => handleProfileChange('marital_status', e.target.value)}
|
| 967 |
+
disabled={useRealUsers && selectedRealUser}
|
| 968 |
+
style={{backgroundColor: useRealUsers && selectedRealUser ? '#f5f5f5' : 'white'}}
|
| 969 |
+
>
|
| 970 |
+
<option value="Single">Single</option>
|
| 971 |
+
<option value="Married">Married</option>
|
| 972 |
+
<option value="Divorced">Divorced</option>
|
| 973 |
+
<option value="Widowed">Widowed</option>
|
| 974 |
+
</select>
|
| 975 |
+
</div>
|
| 976 |
+
</div>
|
| 977 |
</div>
|
| 978 |
|
| 979 |
{/* Random Behavioral Patterns for Custom Users */}
|
|
|
|
| 1106 |
<div className="category-percentages">
|
| 1107 |
{Object.entries(categoryPercentages)
|
| 1108 |
.sort((a, b) => parseFloat(b[1]) - parseFloat(a[1]))
|
| 1109 |
+
.slice(0, 10)
|
| 1110 |
.map(([category, percentage]) => (
|
| 1111 |
<div key={category} className="category-item">
|
| 1112 |
<div className="category-bar-container">
|
|
|
|
| 1237 |
)}
|
| 1238 |
|
| 1239 |
{/* Category Analysis for Custom Users */}
|
| 1240 |
+
{(selectedBehavioralPattern || interactions.length > 0 || userProfile.interaction_history.length > 0 || (recommendations.length > 0 && selectedPattern?.isNewUser)) && (
|
| 1241 |
<div
|
| 1242 |
key={`category-analysis-${interactions.length}-${selectedBehavioralPattern?.id || 'none'}-${sampleItems.length}`}
|
| 1243 |
className="category-analysis"
|
|
|
|
| 1255 |
{Object.keys(categoryPercentages).length > 0 ? (
|
| 1256 |
Object.entries(categoryPercentages)
|
| 1257 |
.sort((a, b) => parseFloat(b[1]) - parseFloat(a[1]))
|
| 1258 |
+
.slice(0, 10)
|
| 1259 |
.map(([category, percentage]) => (
|
| 1260 |
<div key={category} className="category-item">
|
| 1261 |
<div className="category-bar-container">
|
|
|
|
| 1268 |
<span className="category-percent">{percentage}%</span>
|
| 1269 |
</div>
|
| 1270 |
))
|
| 1271 |
+
) : selectedPattern?.isNewUser ? (
|
| 1272 |
+
<div className="new-user-category-message">
|
| 1273 |
+
<div style={{
|
| 1274 |
+
padding: '20px',
|
| 1275 |
+
backgroundColor: '#f8f9fa',
|
| 1276 |
+
border: '2px dashed #6c757d',
|
| 1277 |
+
borderRadius: '8px',
|
| 1278 |
+
textAlign: 'center',
|
| 1279 |
+
color: '#495057'
|
| 1280 |
+
}}>
|
| 1281 |
+
<h6 style={{margin: '0 0 8px 0', color: '#343a40'}}>π New User - No History</h6>
|
| 1282 |
+
<p style={{margin: '0', fontSize: '14px'}}>
|
| 1283 |
+
No category preferences yet.<br />
|
| 1284 |
+
Recommendations are based on demographics only.
|
| 1285 |
+
</p>
|
| 1286 |
+
</div>
|
| 1287 |
+
</div>
|
| 1288 |
) : (
|
| 1289 |
<div className="category-loading">
|
| 1290 |
<p>Processing interaction categories...</p>
|
|
|
|
| 1469 |
)}
|
| 1470 |
|
| 1471 |
<h3>Synthetic Interaction Patterns</h3>
|
| 1472 |
+
<p>Generate realistic user behavior patterns with proportional view, cart, and purchase events. Choose "New User" to test cold-start scenarios.</p>
|
| 1473 |
|
| 1474 |
<div className="pattern-buttons">
|
| 1475 |
{INTERACTION_PATTERNS.map((pattern, index) => (
|
| 1476 |
<button
|
| 1477 |
key={index}
|
| 1478 |
+
className={`pattern-btn ${selectedPattern?.name === pattern.name ? 'active' : ''} ${pattern.isNewUser ? 'new-user-pattern' : ''}`}
|
| 1479 |
onClick={() => handlePatternSelect(pattern)}
|
| 1480 |
>
|
| 1481 |
{pattern.name}
|
| 1482 |
<br />
|
| 1483 |
+
<small>
|
| 1484 |
+
{pattern.isNewUser ? (
|
| 1485 |
+
<span style={{fontStyle: 'italic', color: '#6c757d'}}>Cold Start User</span>
|
| 1486 |
+
) : (
|
| 1487 |
+
`${pattern.views}V β’ ${pattern.carts}C β’ ${pattern.purchases}P`
|
| 1488 |
+
)}
|
| 1489 |
+
</small>
|
| 1490 |
</button>
|
| 1491 |
))}
|
| 1492 |
<button
|
|
|
|
| 1497 |
Clear All
|
| 1498 |
</button>
|
| 1499 |
</div>
|
| 1500 |
+
|
| 1501 |
+
{/* Show informational message for New User pattern */}
|
| 1502 |
+
{selectedPattern?.isNewUser && (
|
| 1503 |
+
<div style={{
|
| 1504 |
+
backgroundColor: '#e3f2fd',
|
| 1505 |
+
border: '1px solid #90caf9',
|
| 1506 |
+
borderRadius: '8px',
|
| 1507 |
+
padding: '15px',
|
| 1508 |
+
margin: '15px 0',
|
| 1509 |
+
color: '#1565c0'
|
| 1510 |
+
}}>
|
| 1511 |
+
<h4 style={{margin: '0 0 10px 0', color: '#0d47a1'}}>π New User (Cold Start) Selected</h4>
|
| 1512 |
+
<p style={{margin: '0', fontSize: '14px', lineHeight: '1.4'}}>
|
| 1513 |
+
Testing cold-start scenario with no interaction history.
|
| 1514 |
+
<br /><strong>Compatible algorithms:</strong> Collaborative β
, Hybrid β
(demographics-based)
|
| 1515 |
+
<br /><strong>Incompatible:</strong> Content-based β, Category-boosted β (require history)
|
| 1516 |
+
</p>
|
| 1517 |
+
</div>
|
| 1518 |
+
)}
|
| 1519 |
</>
|
| 1520 |
)}
|
| 1521 |
|
|
|
|
| 1667 |
value={recommendationType}
|
| 1668 |
onChange={(e) => setRecommendationType(e.target.value)}
|
| 1669 |
>
|
| 1670 |
+
<option value="hybrid">Hybrid (Recommended)</option>
|
|
|
|
|
|
|
| 1671 |
<option value="collaborative">Collaborative Filtering</option>
|
| 1672 |
<option value="content">Content-Based</option>
|
| 1673 |
+
<option value="category_boosted">π Category Boosted (50% from user categories)</option>
|
| 1674 |
</select>
|
| 1675 |
</div>
|
| 1676 |
|
|
|
|
| 1689 |
</select>
|
| 1690 |
</div>
|
| 1691 |
|
| 1692 |
+
{recommendationType === 'hybrid' && (
|
| 1693 |
<div className="form-group">
|
| 1694 |
<label htmlFor="collabWeight">Collaborative Weight:</label>
|
| 1695 |
<input
|
|
|
|
| 1716 |
|
| 1717 |
{recommendationType === 'content' && userProfile.interaction_history.length === 0 && (
|
| 1718 |
<p style={{color: '#dc3545', marginTop: '10px', fontSize: '14px'}}>
|
| 1719 |
+
β οΈ Content-based recommendations require interaction history. Please select a pattern with interactions above, or choose 'Collaborative' or 'Hybrid' for new users.
|
| 1720 |
+
</p>
|
| 1721 |
+
)}
|
| 1722 |
+
|
| 1723 |
+
{recommendationType === 'category_boosted' && userProfile.interaction_history.length === 0 && (
|
| 1724 |
+
<p style={{color: '#dc3545', marginTop: '10px', fontSize: '14px'}}>
|
| 1725 |
+
β οΈ Category-boosted recommendations require interaction history to analyze preferences. Please select a pattern with interactions above, or choose 'Collaborative' for new users.
|
| 1726 |
</p>
|
| 1727 |
)}
|
| 1728 |
</div>
|
|
|
|
| 1741 |
|
| 1742 |
<div className="stats">
|
| 1743 |
<strong>User Profile:</strong> {userProfile.age}yr {userProfile.gender},
|
| 1744 |
+
${userProfile.income.toLocaleString()} income, {userProfile.profession}, {userProfile.location}, {userProfile.education_level}, {userProfile.marital_status}
|
| 1745 |
{selectedCategory && (
|
| 1746 |
<span> | <strong>Category Filter:</strong> <span className="category-filter-display">{selectedCategory.replace(/\./g, ' > ')}</span></span>
|
| 1747 |
)}
|
src/data_generation/generate_demographics.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
from typing import Dict, List
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
class DemographicDataGenerator:
|
| 7 |
+
"""Generate realistic categorical demographic data correlating with existing age/income."""
|
| 8 |
+
|
| 9 |
+
def __init__(self, seed: int = 42):
|
| 10 |
+
np.random.seed(seed)
|
| 11 |
+
|
| 12 |
+
# Define categorical mappings
|
| 13 |
+
self.profession_categories = [
|
| 14 |
+
"Technology", "Healthcare", "Education", "Finance",
|
| 15 |
+
"Retail", "Manufacturing", "Services", "Other"
|
| 16 |
+
]
|
| 17 |
+
|
| 18 |
+
self.location_categories = ["Urban", "Suburban", "Rural"]
|
| 19 |
+
|
| 20 |
+
self.education_categories = [
|
| 21 |
+
"High School", "Some College", "Bachelor's", "Master's", "PhD+"
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
self.marital_categories = ["Single", "Married", "Divorced", "Widowed"]
|
| 25 |
+
|
| 26 |
+
def generate_profession(self, age: int, income: float, gender: str) -> str:
|
| 27 |
+
"""Generate profession based on age, income, and gender correlations."""
|
| 28 |
+
|
| 29 |
+
# Age-based profession probabilities
|
| 30 |
+
if age < 25:
|
| 31 |
+
# Young adults - more likely in retail, services, some tech
|
| 32 |
+
probs = [0.15, 0.10, 0.08, 0.05, 0.25, 0.10, 0.20, 0.07]
|
| 33 |
+
elif age < 35:
|
| 34 |
+
# Early career - tech, healthcare, finance growth
|
| 35 |
+
probs = [0.25, 0.15, 0.10, 0.15, 0.12, 0.08, 0.10, 0.05]
|
| 36 |
+
elif age < 50:
|
| 37 |
+
# Mid career - established in all fields
|
| 38 |
+
probs = [0.20, 0.18, 0.15, 0.18, 0.08, 0.12, 0.07, 0.02]
|
| 39 |
+
else:
|
| 40 |
+
# Senior career - more in education, healthcare, services
|
| 41 |
+
probs = [0.15, 0.20, 0.20, 0.15, 0.05, 0.15, 0.08, 0.02]
|
| 42 |
+
|
| 43 |
+
# Income adjustments
|
| 44 |
+
if income > 90000: # High income
|
| 45 |
+
# Boost tech, finance, healthcare
|
| 46 |
+
probs[0] *= 1.5 # Technology
|
| 47 |
+
probs[3] *= 1.5 # Finance
|
| 48 |
+
probs[1] *= 1.3 # Healthcare
|
| 49 |
+
probs[4] *= 0.5 # Retail
|
| 50 |
+
probs[6] *= 0.7 # Services
|
| 51 |
+
elif income < 40000: # Lower income
|
| 52 |
+
# Boost retail, services, manufacturing
|
| 53 |
+
probs[4] *= 2.0 # Retail
|
| 54 |
+
probs[6] *= 1.8 # Services
|
| 55 |
+
probs[5] *= 1.5 # Manufacturing
|
| 56 |
+
probs[0] *= 0.3 # Technology
|
| 57 |
+
probs[3] *= 0.3 # Finance
|
| 58 |
+
|
| 59 |
+
# Normalize probabilities
|
| 60 |
+
probs = np.array(probs)
|
| 61 |
+
probs = probs / np.sum(probs)
|
| 62 |
+
|
| 63 |
+
return np.random.choice(self.profession_categories, p=probs)
|
| 64 |
+
|
| 65 |
+
def generate_location(self, income: float, profession: str) -> str:
|
| 66 |
+
"""Generate location based on income and profession."""
|
| 67 |
+
|
| 68 |
+
# Base probabilities (roughly US distribution)
|
| 69 |
+
probs = [0.62, 0.27, 0.11] # Urban, Suburban, Rural
|
| 70 |
+
|
| 71 |
+
# Income adjustments
|
| 72 |
+
if income > 80000:
|
| 73 |
+
# Higher income -> more suburban
|
| 74 |
+
probs = [0.45, 0.45, 0.10]
|
| 75 |
+
elif income < 35000:
|
| 76 |
+
# Lower income -> more urban/rural
|
| 77 |
+
probs = [0.70, 0.15, 0.15]
|
| 78 |
+
|
| 79 |
+
# Profession adjustments
|
| 80 |
+
if profession in ["Technology", "Finance"]:
|
| 81 |
+
# Tech/Finance -> more urban
|
| 82 |
+
probs[0] *= 1.4
|
| 83 |
+
probs[2] *= 0.5
|
| 84 |
+
elif profession in ["Manufacturing", "Other"]:
|
| 85 |
+
# Manufacturing -> more rural/suburban
|
| 86 |
+
probs[1] *= 1.3
|
| 87 |
+
probs[2] *= 1.5
|
| 88 |
+
probs[0] *= 0.7
|
| 89 |
+
|
| 90 |
+
# Normalize
|
| 91 |
+
probs = np.array(probs)
|
| 92 |
+
probs = probs / np.sum(probs)
|
| 93 |
+
|
| 94 |
+
return np.random.choice(self.location_categories, p=probs)
|
| 95 |
+
|
| 96 |
+
def generate_education_level(self, age: int, income: float, profession: str) -> str:
|
| 97 |
+
"""Generate education level based on age, income, and profession."""
|
| 98 |
+
|
| 99 |
+
# Base probabilities (roughly US distribution)
|
| 100 |
+
probs = [0.27, 0.20, 0.33, 0.13, 0.07] # HS, Some College, Bachelor's, Master's, PhD+
|
| 101 |
+
|
| 102 |
+
# Age adjustments (older generations had less college access)
|
| 103 |
+
if age > 55:
|
| 104 |
+
probs = [0.40, 0.25, 0.25, 0.08, 0.02]
|
| 105 |
+
elif age > 40:
|
| 106 |
+
probs = [0.32, 0.23, 0.30, 0.12, 0.03]
|
| 107 |
+
elif age < 30:
|
| 108 |
+
# Younger generation has more education
|
| 109 |
+
probs = [0.20, 0.15, 0.40, 0.18, 0.07]
|
| 110 |
+
|
| 111 |
+
# Income adjustments
|
| 112 |
+
if income > 100000:
|
| 113 |
+
# High income -> more advanced degrees
|
| 114 |
+
probs = [0.10, 0.10, 0.35, 0.30, 0.15]
|
| 115 |
+
elif income > 70000:
|
| 116 |
+
# Good income -> more bachelor's/master's
|
| 117 |
+
probs = [0.15, 0.15, 0.45, 0.20, 0.05]
|
| 118 |
+
elif income < 40000:
|
| 119 |
+
# Lower income -> less higher education
|
| 120 |
+
probs = [0.45, 0.30, 0.20, 0.04, 0.01]
|
| 121 |
+
|
| 122 |
+
# Profession adjustments
|
| 123 |
+
if profession in ["Technology", "Healthcare", "Finance"]:
|
| 124 |
+
# Professional fields -> more degrees
|
| 125 |
+
probs = [0.05, 0.10, 0.40, 0.30, 0.15]
|
| 126 |
+
elif profession == "Education":
|
| 127 |
+
# Education -> even more advanced degrees
|
| 128 |
+
probs = [0.02, 0.05, 0.25, 0.45, 0.23]
|
| 129 |
+
elif profession in ["Retail", "Services", "Manufacturing"]:
|
| 130 |
+
# Service industries -> less higher education
|
| 131 |
+
probs = [0.40, 0.25, 0.25, 0.08, 0.02]
|
| 132 |
+
|
| 133 |
+
# Normalize
|
| 134 |
+
probs = np.array(probs)
|
| 135 |
+
probs = probs / np.sum(probs)
|
| 136 |
+
|
| 137 |
+
return np.random.choice(self.education_categories, p=probs)
|
| 138 |
+
|
| 139 |
+
def generate_marital_status(self, age: int, gender: str) -> str:
|
| 140 |
+
"""Generate marital status based on age and gender."""
|
| 141 |
+
|
| 142 |
+
# Age-based probabilities
|
| 143 |
+
if age < 25:
|
| 144 |
+
probs = [0.85, 0.13, 0.02, 0.00] # Single, Married, Divorced, Widowed
|
| 145 |
+
elif age < 35:
|
| 146 |
+
probs = [0.45, 0.50, 0.05, 0.00]
|
| 147 |
+
elif age < 50:
|
| 148 |
+
probs = [0.15, 0.70, 0.14, 0.01]
|
| 149 |
+
elif age < 65:
|
| 150 |
+
probs = [0.10, 0.65, 0.20, 0.05]
|
| 151 |
+
else:
|
| 152 |
+
probs = [0.08, 0.55, 0.15, 0.22]
|
| 153 |
+
|
| 154 |
+
# Gender adjustments (women tend to be widowed more often in older ages)
|
| 155 |
+
if age > 65 and gender == 'female':
|
| 156 |
+
probs[3] *= 2.0 # More widowed women
|
| 157 |
+
probs[1] *= 0.8 # Fewer married
|
| 158 |
+
|
| 159 |
+
# Normalize
|
| 160 |
+
probs = np.array(probs)
|
| 161 |
+
probs = probs / np.sum(probs)
|
| 162 |
+
|
| 163 |
+
return np.random.choice(self.marital_categories, p=probs)
|
| 164 |
+
|
| 165 |
+
def generate_user_demographics(self, users_df: pd.DataFrame) -> pd.DataFrame:
|
| 166 |
+
"""Generate all demographic features for all users."""
|
| 167 |
+
|
| 168 |
+
print(f"Generating demographic data for {len(users_df)} users...")
|
| 169 |
+
|
| 170 |
+
# Create a copy to avoid modifying original
|
| 171 |
+
enhanced_users = users_df.copy()
|
| 172 |
+
|
| 173 |
+
# Generate each demographic feature
|
| 174 |
+
professions = []
|
| 175 |
+
locations = []
|
| 176 |
+
education_levels = []
|
| 177 |
+
marital_statuses = []
|
| 178 |
+
|
| 179 |
+
for idx, row in users_df.iterrows():
|
| 180 |
+
age = row['age']
|
| 181 |
+
income = row['income']
|
| 182 |
+
gender = row['gender']
|
| 183 |
+
|
| 184 |
+
# Generate profession first as it influences other features
|
| 185 |
+
profession = self.generate_profession(age, income, gender)
|
| 186 |
+
professions.append(profession)
|
| 187 |
+
|
| 188 |
+
# Generate location based on income and profession
|
| 189 |
+
location = self.generate_location(income, profession)
|
| 190 |
+
locations.append(location)
|
| 191 |
+
|
| 192 |
+
# Generate education based on age, income, and profession
|
| 193 |
+
education = self.generate_education_level(age, income, profession)
|
| 194 |
+
education_levels.append(education)
|
| 195 |
+
|
| 196 |
+
# Generate marital status based on age and gender
|
| 197 |
+
marital_status = self.generate_marital_status(age, gender)
|
| 198 |
+
marital_statuses.append(marital_status)
|
| 199 |
+
|
| 200 |
+
# Add new columns
|
| 201 |
+
enhanced_users['profession'] = professions
|
| 202 |
+
enhanced_users['location'] = locations
|
| 203 |
+
enhanced_users['education_level'] = education_levels
|
| 204 |
+
enhanced_users['marital_status'] = marital_statuses
|
| 205 |
+
|
| 206 |
+
return enhanced_users
|
| 207 |
+
|
| 208 |
+
def print_demographic_statistics(self, users_df: pd.DataFrame):
|
| 209 |
+
"""Print statistics about the generated demographics."""
|
| 210 |
+
|
| 211 |
+
print("\n=== Demographic Statistics ===")
|
| 212 |
+
|
| 213 |
+
# Profession distribution
|
| 214 |
+
print(f"\nProfession Distribution:")
|
| 215 |
+
prof_counts = users_df['profession'].value_counts()
|
| 216 |
+
for prof, count in prof_counts.items():
|
| 217 |
+
pct = (count / len(users_df)) * 100
|
| 218 |
+
print(f" {prof}: {count:,} ({pct:.1f}%)")
|
| 219 |
+
|
| 220 |
+
# Location distribution
|
| 221 |
+
print(f"\nLocation Distribution:")
|
| 222 |
+
loc_counts = users_df['location'].value_counts()
|
| 223 |
+
for loc, count in loc_counts.items():
|
| 224 |
+
pct = (count / len(users_df)) * 100
|
| 225 |
+
print(f" {loc}: {count:,} ({pct:.1f}%)")
|
| 226 |
+
|
| 227 |
+
# Education distribution
|
| 228 |
+
print(f"\nEducation Level Distribution:")
|
| 229 |
+
edu_counts = users_df['education_level'].value_counts()
|
| 230 |
+
for edu, count in edu_counts.items():
|
| 231 |
+
pct = (count / len(users_df)) * 100
|
| 232 |
+
print(f" {edu}: {count:,} ({pct:.1f}%)")
|
| 233 |
+
|
| 234 |
+
# Marital status distribution
|
| 235 |
+
print(f"\nMarital Status Distribution:")
|
| 236 |
+
marital_counts = users_df['marital_status'].value_counts()
|
| 237 |
+
for status, count in marital_counts.items():
|
| 238 |
+
pct = (count / len(users_df)) * 100
|
| 239 |
+
print(f" {status}: {count:,} ({pct:.1f}%)")
|
| 240 |
+
|
| 241 |
+
print(f"\nTotal users: {len(users_df):,}")
|
| 242 |
+
|
| 243 |
+
# Cross-tabulations to show correlations
|
| 244 |
+
print(f"\n=== Key Correlations ===")
|
| 245 |
+
|
| 246 |
+
# High income professions
|
| 247 |
+
high_income = users_df[users_df['income'] > 80000]
|
| 248 |
+
print(f"\nTop professions for high income (>${80000:,}+):")
|
| 249 |
+
high_income_prof = high_income['profession'].value_counts(normalize=True) * 100
|
| 250 |
+
for prof, pct in high_income_prof.head().items():
|
| 251 |
+
print(f" {prof}: {pct:.1f}%")
|
| 252 |
+
|
| 253 |
+
# Education by profession
|
| 254 |
+
print(f"\nEducation levels in Technology:")
|
| 255 |
+
tech_edu = users_df[users_df['profession'] == 'Technology']['education_level'].value_counts(normalize=True) * 100
|
| 256 |
+
for edu, pct in tech_edu.items():
|
| 257 |
+
print(f" {edu}: {pct:.1f}%")
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def main():
|
| 261 |
+
"""Main function to generate and save enhanced demographic data."""
|
| 262 |
+
|
| 263 |
+
# Load existing users data
|
| 264 |
+
users_path = "datasets/users.csv"
|
| 265 |
+
if not os.path.exists(users_path):
|
| 266 |
+
print(f"Error: {users_path} not found!")
|
| 267 |
+
return
|
| 268 |
+
|
| 269 |
+
print(f"Loading users data from {users_path}")
|
| 270 |
+
users_df = pd.read_csv(users_path)
|
| 271 |
+
|
| 272 |
+
print(f"Original data shape: {users_df.shape}")
|
| 273 |
+
print(f"Original columns: {list(users_df.columns)}")
|
| 274 |
+
|
| 275 |
+
# Generate demographic data
|
| 276 |
+
generator = DemographicDataGenerator(seed=42)
|
| 277 |
+
enhanced_users = generator.generate_user_demographics(users_df)
|
| 278 |
+
|
| 279 |
+
# Print statistics
|
| 280 |
+
generator.print_demographic_statistics(enhanced_users)
|
| 281 |
+
|
| 282 |
+
# Save enhanced data
|
| 283 |
+
output_path = "datasets/users_enhanced.csv"
|
| 284 |
+
enhanced_users.to_csv(output_path, index=False)
|
| 285 |
+
print(f"\nEnhanced users data saved to {output_path}")
|
| 286 |
+
|
| 287 |
+
print(f"Enhanced data shape: {enhanced_users.shape}")
|
| 288 |
+
print(f"New columns: {list(enhanced_users.columns)}")
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
if __name__ == "__main__":
|
| 292 |
+
main()
|
src/inference/enhanced_recommendation_engine.py
DELETED
|
@@ -1,303 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""
|
| 3 |
-
Enhanced recommendation engine with category-aware filtering and improved user alignment.
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
import numpy as np
|
| 7 |
-
import pandas as pd
|
| 8 |
-
from typing import Dict, List, Tuple, Optional
|
| 9 |
-
from collections import Counter
|
| 10 |
-
import random
|
| 11 |
-
|
| 12 |
-
import sys
|
| 13 |
-
import os
|
| 14 |
-
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
| 15 |
-
|
| 16 |
-
from src.inference.recommendation_engine import RecommendationEngine
|
| 17 |
-
from src.utils.real_user_selector import RealUserSelector
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
class EnhancedRecommendationEngine(RecommendationEngine):
|
| 21 |
-
"""Enhanced recommendation engine with category-aware improvements."""
|
| 22 |
-
|
| 23 |
-
def __init__(self, artifacts_path: str = "src/artifacts/"):
|
| 24 |
-
super().__init__(artifacts_path)
|
| 25 |
-
self.real_user_selector = RealUserSelector()
|
| 26 |
-
|
| 27 |
-
def _analyze_user_category_preferences(self, interaction_history: List[int]) -> Dict[str, float]:
|
| 28 |
-
"""Analyze user's category preferences from interaction history."""
|
| 29 |
-
|
| 30 |
-
if not interaction_history:
|
| 31 |
-
return {}
|
| 32 |
-
|
| 33 |
-
category_counts = Counter()
|
| 34 |
-
|
| 35 |
-
for item_id in interaction_history:
|
| 36 |
-
# Get item category from items dataframe
|
| 37 |
-
item_row = self.items_df[self.items_df['product_id'] == item_id]
|
| 38 |
-
if not item_row.empty:
|
| 39 |
-
category = item_row.iloc[0].get('category_code', 'Unknown')
|
| 40 |
-
category_counts[category] += 1
|
| 41 |
-
|
| 42 |
-
# Convert to percentages
|
| 43 |
-
total_interactions = sum(category_counts.values())
|
| 44 |
-
if total_interactions == 0:
|
| 45 |
-
return {}
|
| 46 |
-
|
| 47 |
-
category_preferences = {}
|
| 48 |
-
for category, count in category_counts.items():
|
| 49 |
-
category_preferences[category] = count / total_interactions
|
| 50 |
-
|
| 51 |
-
return category_preferences
|
| 52 |
-
|
| 53 |
-
def _boost_category_aligned_recommendations(self,
|
| 54 |
-
recommendations: List[Tuple[int, float, Dict]],
|
| 55 |
-
user_category_preferences: Dict[str, float],
|
| 56 |
-
boost_factor: float = 1.5) -> List[Tuple[int, float, Dict]]:
|
| 57 |
-
"""Boost recommendations that align with user's category preferences."""
|
| 58 |
-
|
| 59 |
-
if not user_category_preferences:
|
| 60 |
-
return recommendations
|
| 61 |
-
|
| 62 |
-
boosted_recs = []
|
| 63 |
-
|
| 64 |
-
for item_id, score, item_info in recommendations:
|
| 65 |
-
item_category = item_info.get('category_code', 'Unknown')
|
| 66 |
-
|
| 67 |
-
# Apply category boost if user has preference for this category
|
| 68 |
-
category_preference = user_category_preferences.get(item_category, 0)
|
| 69 |
-
|
| 70 |
-
if category_preference > 0:
|
| 71 |
-
# Boost score based on user's preference strength
|
| 72 |
-
boosted_score = score * (1 + boost_factor * category_preference)
|
| 73 |
-
boosted_recs.append((item_id, boosted_score, item_info))
|
| 74 |
-
else:
|
| 75 |
-
boosted_recs.append((item_id, score, item_info))
|
| 76 |
-
|
| 77 |
-
# Re-sort by boosted scores
|
| 78 |
-
boosted_recs.sort(key=lambda x: x[1], reverse=True)
|
| 79 |
-
return boosted_recs
|
| 80 |
-
|
| 81 |
-
def _diversify_recommendations(self,
|
| 82 |
-
recommendations: List[Tuple[int, float, Dict]],
|
| 83 |
-
max_per_category: int = 3) -> List[Tuple[int, float, Dict]]:
|
| 84 |
-
"""Ensure category diversity in recommendations."""
|
| 85 |
-
|
| 86 |
-
category_counts = Counter()
|
| 87 |
-
diversified_recs = []
|
| 88 |
-
|
| 89 |
-
for item_id, score, item_info in recommendations:
|
| 90 |
-
item_category = item_info.get('category_code', 'Unknown')
|
| 91 |
-
|
| 92 |
-
if category_counts[item_category] < max_per_category:
|
| 93 |
-
diversified_recs.append((item_id, score, item_info))
|
| 94 |
-
category_counts[item_category] += 1
|
| 95 |
-
|
| 96 |
-
return diversified_recs
|
| 97 |
-
|
| 98 |
-
def recommend_items_enhanced_hybrid(self,
|
| 99 |
-
age: int,
|
| 100 |
-
gender: str,
|
| 101 |
-
income: float,
|
| 102 |
-
interaction_history: List[int] = None,
|
| 103 |
-
k: int = 10,
|
| 104 |
-
collaborative_weight: float = 0.7,
|
| 105 |
-
category_boost: float = 1.5,
|
| 106 |
-
enable_category_boost: bool = True,
|
| 107 |
-
enable_diversity: bool = True,
|
| 108 |
-
max_per_category: int = 3) -> List[Tuple[int, float, Dict]]:
|
| 109 |
-
"""Generate enhanced hybrid recommendations with category awareness."""
|
| 110 |
-
|
| 111 |
-
# Start with base hybrid recommendations (get more than needed)
|
| 112 |
-
base_k = k * 3 # Get 3x more candidates for filtering
|
| 113 |
-
|
| 114 |
-
base_recommendations = self.recommend_items_hybrid(
|
| 115 |
-
age=age,
|
| 116 |
-
gender=gender,
|
| 117 |
-
income=income,
|
| 118 |
-
interaction_history=interaction_history,
|
| 119 |
-
k=base_k,
|
| 120 |
-
collaborative_weight=collaborative_weight
|
| 121 |
-
)
|
| 122 |
-
|
| 123 |
-
if not base_recommendations:
|
| 124 |
-
return []
|
| 125 |
-
|
| 126 |
-
# Analyze user's category preferences
|
| 127 |
-
if enable_category_boost and interaction_history:
|
| 128 |
-
user_category_preferences = self._analyze_user_category_preferences(interaction_history)
|
| 129 |
-
|
| 130 |
-
# Apply category-based boosting
|
| 131 |
-
base_recommendations = self._boost_category_aligned_recommendations(
|
| 132 |
-
base_recommendations,
|
| 133 |
-
user_category_preferences,
|
| 134 |
-
boost_factor=category_boost
|
| 135 |
-
)
|
| 136 |
-
|
| 137 |
-
# Apply diversity filtering if enabled
|
| 138 |
-
if enable_diversity:
|
| 139 |
-
base_recommendations = self._diversify_recommendations(
|
| 140 |
-
base_recommendations,
|
| 141 |
-
max_per_category=max_per_category
|
| 142 |
-
)
|
| 143 |
-
|
| 144 |
-
# Return top k
|
| 145 |
-
return base_recommendations[:k]
|
| 146 |
-
|
| 147 |
-
def recommend_items_category_focused(self,
|
| 148 |
-
age: int,
|
| 149 |
-
gender: str,
|
| 150 |
-
income: float,
|
| 151 |
-
interaction_history: List[int] = None,
|
| 152 |
-
k: int = 10,
|
| 153 |
-
focus_percentage: float = 0.7) -> List[Tuple[int, float, Dict]]:
|
| 154 |
-
"""Generate recommendations focused on user's preferred categories."""
|
| 155 |
-
|
| 156 |
-
if not interaction_history:
|
| 157 |
-
# Fall back to regular hybrid for users without history
|
| 158 |
-
return self.recommend_items_hybrid(age, gender, income, interaction_history, k)
|
| 159 |
-
|
| 160 |
-
# Analyze user preferences
|
| 161 |
-
user_category_preferences = self._analyze_user_category_preferences(interaction_history)
|
| 162 |
-
|
| 163 |
-
if not user_category_preferences:
|
| 164 |
-
return self.recommend_items_hybrid(age, gender, income, interaction_history, k)
|
| 165 |
-
|
| 166 |
-
# Get top categories (sorted by preference)
|
| 167 |
-
top_categories = sorted(user_category_preferences.items(),
|
| 168 |
-
key=lambda x: x[1], reverse=True)
|
| 169 |
-
|
| 170 |
-
# Determine how many recs to focus on preferred categories
|
| 171 |
-
focused_k = int(k * focus_percentage)
|
| 172 |
-
exploration_k = k - focused_k
|
| 173 |
-
|
| 174 |
-
# Get base recommendations
|
| 175 |
-
all_recommendations = self.recommend_items_hybrid(
|
| 176 |
-
age, gender, income, interaction_history, k * 2
|
| 177 |
-
)
|
| 178 |
-
|
| 179 |
-
# Split into focused and exploration recommendations
|
| 180 |
-
focused_recs = []
|
| 181 |
-
exploration_recs = []
|
| 182 |
-
|
| 183 |
-
# Get user's top 3 categories
|
| 184 |
-
preferred_categories = set([cat for cat, _ in top_categories[:3]])
|
| 185 |
-
|
| 186 |
-
for item_id, score, item_info in all_recommendations:
|
| 187 |
-
item_category = item_info.get('category_code', 'Unknown')
|
| 188 |
-
|
| 189 |
-
if (item_category in preferred_categories and
|
| 190 |
-
len(focused_recs) < focused_k):
|
| 191 |
-
focused_recs.append((item_id, score, item_info))
|
| 192 |
-
elif len(exploration_recs) < exploration_k:
|
| 193 |
-
exploration_recs.append((item_id, score, item_info))
|
| 194 |
-
|
| 195 |
-
# Combine focused and exploration recommendations
|
| 196 |
-
final_recommendations = focused_recs + exploration_recs
|
| 197 |
-
|
| 198 |
-
return final_recommendations[:k]
|
| 199 |
-
|
| 200 |
-
def get_recommendation_explanation(self,
|
| 201 |
-
recommendations: List[Tuple[int, float, Dict]],
|
| 202 |
-
interaction_history: List[int] = None) -> Dict:
|
| 203 |
-
"""Provide explanation for why these recommendations were generated."""
|
| 204 |
-
|
| 205 |
-
if not recommendations:
|
| 206 |
-
return {"message": "No recommendations generated"}
|
| 207 |
-
|
| 208 |
-
# Analyze recommendation categories
|
| 209 |
-
rec_categories = Counter()
|
| 210 |
-
for _, _, item_info in recommendations:
|
| 211 |
-
category = item_info.get('category_code', 'Unknown')
|
| 212 |
-
rec_categories[category] += 1
|
| 213 |
-
|
| 214 |
-
explanation = {
|
| 215 |
-
"total_recommendations": len(recommendations),
|
| 216 |
-
"categories_covered": len(rec_categories),
|
| 217 |
-
"category_breakdown": dict(rec_categories.most_common())
|
| 218 |
-
}
|
| 219 |
-
|
| 220 |
-
# Add user preference analysis if history available
|
| 221 |
-
if interaction_history:
|
| 222 |
-
user_preferences = self._analyze_user_category_preferences(interaction_history)
|
| 223 |
-
|
| 224 |
-
# Calculate alignment
|
| 225 |
-
user_cats = set(user_preferences.keys())
|
| 226 |
-
rec_cats = set(rec_categories.keys())
|
| 227 |
-
alignment = len(user_cats & rec_cats) / len(rec_cats) * 100 if rec_cats else 0
|
| 228 |
-
|
| 229 |
-
explanation.update({
|
| 230 |
-
"user_category_preferences": user_preferences,
|
| 231 |
-
"alignment_percentage": round(alignment, 1),
|
| 232 |
-
"matched_categories": list(user_cats & rec_cats),
|
| 233 |
-
"new_categories": list(rec_cats - user_cats)
|
| 234 |
-
})
|
| 235 |
-
|
| 236 |
-
return explanation
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
def demo_enhanced_recommendations():
|
| 240 |
-
"""Demo the enhanced recommendation engine."""
|
| 241 |
-
|
| 242 |
-
print("π ENHANCED RECOMMENDATION ENGINE DEMO")
|
| 243 |
-
print("="*70)
|
| 244 |
-
|
| 245 |
-
# Initialize enhanced engine
|
| 246 |
-
engine = EnhancedRecommendationEngine()
|
| 247 |
-
|
| 248 |
-
# Get a real user for testing
|
| 249 |
-
real_user_selector = RealUserSelector()
|
| 250 |
-
test_users = real_user_selector.get_real_users(n=3, min_interactions=15)
|
| 251 |
-
|
| 252 |
-
for user in test_users:
|
| 253 |
-
print(f"\nπ Testing User {user['user_id']} ({user['age']}yr {user['gender']}):")
|
| 254 |
-
print(f" Interaction History: {len(user['interaction_history'])} items")
|
| 255 |
-
|
| 256 |
-
# Test different recommendation methods
|
| 257 |
-
methods = [
|
| 258 |
-
("Original Hybrid", lambda: engine.recommend_items_hybrid(
|
| 259 |
-
age=user['age'],
|
| 260 |
-
gender=user['gender'],
|
| 261 |
-
income=user['income'],
|
| 262 |
-
interaction_history=user['interaction_history'][:20],
|
| 263 |
-
k=10,
|
| 264 |
-
collaborative_weight=0.7
|
| 265 |
-
)),
|
| 266 |
-
("Enhanced Hybrid", lambda: engine.recommend_items_enhanced_hybrid(
|
| 267 |
-
age=user['age'],
|
| 268 |
-
gender=user['gender'],
|
| 269 |
-
income=user['income'],
|
| 270 |
-
interaction_history=user['interaction_history'][:20],
|
| 271 |
-
k=10,
|
| 272 |
-
collaborative_weight=0.7,
|
| 273 |
-
category_boost=1.5
|
| 274 |
-
)),
|
| 275 |
-
("Category Focused", lambda: engine.recommend_items_category_focused(
|
| 276 |
-
age=user['age'],
|
| 277 |
-
gender=user['gender'],
|
| 278 |
-
income=user['income'],
|
| 279 |
-
interaction_history=user['interaction_history'][:20],
|
| 280 |
-
k=10,
|
| 281 |
-
focus_percentage=0.8
|
| 282 |
-
))
|
| 283 |
-
]
|
| 284 |
-
|
| 285 |
-
for method_name, method_func in methods:
|
| 286 |
-
try:
|
| 287 |
-
recs = method_func()
|
| 288 |
-
explanation = engine.get_recommendation_explanation(
|
| 289 |
-
recs, user['interaction_history'][:20]
|
| 290 |
-
)
|
| 291 |
-
|
| 292 |
-
print(f"\n π― {method_name}:")
|
| 293 |
-
print(f" Categories: {explanation.get('category_breakdown', {})}")
|
| 294 |
-
print(f" Alignment: {explanation.get('alignment_percentage', 'N/A')}%")
|
| 295 |
-
|
| 296 |
-
except Exception as e:
|
| 297 |
-
print(f" β Error: {str(e)[:40]}...")
|
| 298 |
-
|
| 299 |
-
print(f"\nβ
Enhanced recommendation engine demo completed!")
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
if __name__ == "__main__":
|
| 303 |
-
demo_enhanced_recommendations()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/inference/enhanced_recommendation_engine_128d.py
DELETED
|
@@ -1,499 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""
|
| 3 |
-
Enhanced recommendation engine using 128D embeddings with diversity regularization.
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
import numpy as np
|
| 7 |
-
import pandas as pd
|
| 8 |
-
import tensorflow as tf
|
| 9 |
-
import pickle
|
| 10 |
-
import os
|
| 11 |
-
from typing import Dict, List, Tuple, Optional
|
| 12 |
-
from collections import Counter, defaultdict
|
| 13 |
-
|
| 14 |
-
import sys
|
| 15 |
-
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
| 16 |
-
|
| 17 |
-
from src.models.enhanced_two_tower import EnhancedItemTower, EnhancedUserTower
|
| 18 |
-
from src.inference.faiss_index import FAISSItemIndex
|
| 19 |
-
from src.preprocessing.data_loader import DataProcessor
|
| 20 |
-
from src.preprocessing.user_data_preparation import prepare_user_features
|
| 21 |
-
from src.utils.real_user_selector import RealUserSelector
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
class Enhanced128DRecommendationEngine:
|
| 25 |
-
"""Enhanced recommendation engine with 128D embeddings and all improvements."""
|
| 26 |
-
|
| 27 |
-
def __init__(self, artifacts_path: str = "src/artifacts/"):
|
| 28 |
-
self.artifacts_path = artifacts_path
|
| 29 |
-
self.embedding_dim = 128 # Fixed to 128D
|
| 30 |
-
|
| 31 |
-
# Model components
|
| 32 |
-
self.item_tower = None
|
| 33 |
-
self.user_tower = None
|
| 34 |
-
self.rating_model = None
|
| 35 |
-
self.faiss_index = None
|
| 36 |
-
self.data_processor = None
|
| 37 |
-
|
| 38 |
-
# Data
|
| 39 |
-
self.items_df = None
|
| 40 |
-
self.users_df = None
|
| 41 |
-
self.income_thresholds = None
|
| 42 |
-
|
| 43 |
-
# Load all components
|
| 44 |
-
self._load_all_components()
|
| 45 |
-
|
| 46 |
-
def _load_all_components(self):
|
| 47 |
-
"""Load all enhanced model components."""
|
| 48 |
-
|
| 49 |
-
print("Loading enhanced 128D recommendation engine...")
|
| 50 |
-
|
| 51 |
-
# Load data processor
|
| 52 |
-
self.data_processor = DataProcessor()
|
| 53 |
-
try:
|
| 54 |
-
self.data_processor.load_vocabularies(f"{self.artifacts_path}/vocabularies.pkl")
|
| 55 |
-
except FileNotFoundError:
|
| 56 |
-
print("β Vocabularies not found. Please train the model first.")
|
| 57 |
-
return
|
| 58 |
-
|
| 59 |
-
# Load datasets
|
| 60 |
-
self.items_df = pd.read_csv("datasets/items.csv")
|
| 61 |
-
self.users_df = pd.read_csv("datasets/users.csv")
|
| 62 |
-
|
| 63 |
-
# Load enhanced model components
|
| 64 |
-
self._load_enhanced_models()
|
| 65 |
-
|
| 66 |
-
# Load FAISS index with 128D
|
| 67 |
-
try:
|
| 68 |
-
self.faiss_index = FAISSItemIndex(embedding_dim=self.embedding_dim)
|
| 69 |
-
# Try to load enhanced embeddings first
|
| 70 |
-
if os.path.exists(f"{self.artifacts_path}/enhanced_item_embeddings.npy"):
|
| 71 |
-
enhanced_embeddings = np.load(
|
| 72 |
-
f"{self.artifacts_path}/enhanced_item_embeddings.npy",
|
| 73 |
-
allow_pickle=True
|
| 74 |
-
).item()
|
| 75 |
-
self.faiss_index.build_index(enhanced_embeddings)
|
| 76 |
-
print("β
Loaded enhanced 128D FAISS index")
|
| 77 |
-
else:
|
| 78 |
-
print("β οΈ Enhanced embeddings not found. Train enhanced model first.")
|
| 79 |
-
self.faiss_index = None
|
| 80 |
-
except Exception as e:
|
| 81 |
-
print(f"β οΈ Could not load FAISS index: {e}")
|
| 82 |
-
self.faiss_index = None
|
| 83 |
-
|
| 84 |
-
# Load income thresholds for categorical demographics
|
| 85 |
-
self._load_income_thresholds()
|
| 86 |
-
|
| 87 |
-
print("β
Enhanced 128D engine loaded successfully!")
|
| 88 |
-
|
| 89 |
-
def _load_enhanced_models(self):
|
| 90 |
-
"""Load enhanced model components."""
|
| 91 |
-
|
| 92 |
-
try:
|
| 93 |
-
# Create model architecture
|
| 94 |
-
self.item_tower = EnhancedItemTower(
|
| 95 |
-
item_vocab_size=len(self.data_processor.item_vocab),
|
| 96 |
-
category_vocab_size=len(self.data_processor.category_vocab),
|
| 97 |
-
brand_vocab_size=len(self.data_processor.brand_vocab),
|
| 98 |
-
embedding_dim=self.embedding_dim,
|
| 99 |
-
use_bias=True,
|
| 100 |
-
use_diversity_reg=False # Disable during inference
|
| 101 |
-
)
|
| 102 |
-
|
| 103 |
-
self.user_tower = EnhancedUserTower(
|
| 104 |
-
max_history_length=50,
|
| 105 |
-
embedding_dim=self.embedding_dim,
|
| 106 |
-
use_bias=True,
|
| 107 |
-
use_diversity_reg=False # Disable during inference
|
| 108 |
-
)
|
| 109 |
-
|
| 110 |
-
# Create rating model
|
| 111 |
-
self.rating_model = tf.keras.Sequential([
|
| 112 |
-
tf.keras.layers.Dense(512, activation="relu"),
|
| 113 |
-
tf.keras.layers.BatchNormalization(),
|
| 114 |
-
tf.keras.layers.Dropout(0.3),
|
| 115 |
-
tf.keras.layers.Dense(256, activation="relu"),
|
| 116 |
-
tf.keras.layers.BatchNormalization(),
|
| 117 |
-
tf.keras.layers.Dropout(0.2),
|
| 118 |
-
tf.keras.layers.Dense(64, activation="relu"),
|
| 119 |
-
tf.keras.layers.Dense(1, activation="sigmoid")
|
| 120 |
-
])
|
| 121 |
-
|
| 122 |
-
# Load weights - try enhanced first, fall back to regular
|
| 123 |
-
model_files = [
|
| 124 |
-
('enhanced_item_tower_weights_enhanced_best', 'enhanced_user_tower_weights_enhanced_best', 'enhanced_rating_model_weights_enhanced_best'),
|
| 125 |
-
('enhanced_item_tower_weights_enhanced_final', 'enhanced_user_tower_weights_enhanced_final', 'enhanced_rating_model_weights_enhanced_final'),
|
| 126 |
-
]
|
| 127 |
-
|
| 128 |
-
loaded = False
|
| 129 |
-
for item_file, user_file, rating_file in model_files:
|
| 130 |
-
try:
|
| 131 |
-
# Need to build models first with dummy data
|
| 132 |
-
self._build_models()
|
| 133 |
-
|
| 134 |
-
self.item_tower.load_weights(f"{self.artifacts_path}/{item_file}")
|
| 135 |
-
self.user_tower.load_weights(f"{self.artifacts_path}/{user_file}")
|
| 136 |
-
self.rating_model.load_weights(f"{self.artifacts_path}/{rating_file}")
|
| 137 |
-
|
| 138 |
-
print(f"β
Loaded enhanced model: {item_file}")
|
| 139 |
-
loaded = True
|
| 140 |
-
break
|
| 141 |
-
except Exception as e:
|
| 142 |
-
print(f"β οΈ Could not load {item_file}: {e}")
|
| 143 |
-
continue
|
| 144 |
-
|
| 145 |
-
if not loaded:
|
| 146 |
-
print("β No enhanced model weights found. Please train enhanced model first.")
|
| 147 |
-
self.item_tower = None
|
| 148 |
-
self.user_tower = None
|
| 149 |
-
self.rating_model = None
|
| 150 |
-
|
| 151 |
-
except Exception as e:
|
| 152 |
-
print(f"β Failed to load enhanced models: {e}")
|
| 153 |
-
self.item_tower = None
|
| 154 |
-
self.user_tower = None
|
| 155 |
-
self.rating_model = None
|
| 156 |
-
|
| 157 |
-
def _build_models(self):
|
| 158 |
-
"""Build models with dummy data to initialize weights."""
|
| 159 |
-
|
| 160 |
-
# Dummy item features
|
| 161 |
-
dummy_item_features = {
|
| 162 |
-
'product_id': tf.constant([0]),
|
| 163 |
-
'category_id': tf.constant([0]),
|
| 164 |
-
'brand_id': tf.constant([0]),
|
| 165 |
-
'price': tf.constant([100.0])
|
| 166 |
-
}
|
| 167 |
-
|
| 168 |
-
# Dummy user features
|
| 169 |
-
dummy_user_features = {
|
| 170 |
-
'age': tf.constant([2]), # Adult category
|
| 171 |
-
'gender': tf.constant([0]), # Female
|
| 172 |
-
'income': tf.constant([2]), # Middle income
|
| 173 |
-
'item_history_embeddings': tf.constant(np.zeros((1, 50, self.embedding_dim), dtype=np.float32))
|
| 174 |
-
}
|
| 175 |
-
|
| 176 |
-
# Forward pass to build models
|
| 177 |
-
_ = self.item_tower(dummy_item_features, training=False)
|
| 178 |
-
_ = self.user_tower(dummy_user_features, training=False)
|
| 179 |
-
|
| 180 |
-
# Build rating model
|
| 181 |
-
dummy_concat = tf.constant(np.zeros((1, self.embedding_dim * 2), dtype=np.float32))
|
| 182 |
-
_ = self.rating_model(dummy_concat, training=False)
|
| 183 |
-
|
| 184 |
-
def _load_income_thresholds(self):
|
| 185 |
-
"""Load income thresholds for categorical processing."""
|
| 186 |
-
|
| 187 |
-
# Calculate income thresholds from training data
|
| 188 |
-
user_incomes = self.users_df['income'].values
|
| 189 |
-
self.income_thresholds = np.percentile(user_incomes, [0, 20, 40, 60, 80, 100])
|
| 190 |
-
print(f"Income thresholds: {self.income_thresholds}")
|
| 191 |
-
|
| 192 |
-
def categorize_age(self, age: float) -> int:
    """Map a raw age onto one of six ordinal buckets.

    0=Teen (<18), 1=Young Adult (<26), 2=Adult (<36),
    3=Middle Age (<51), 4=Mature (<66), 5=Senior (66+).
    """
    upper_bounds = (18, 26, 36, 51, 66)
    for bucket, bound in enumerate(upper_bounds):
        if age < bound:
            return bucket
    return 5  # Senior
|
| 201 |
-
def categorize_income(self, income: float) -> int:
    """Bucket an income into one of five percentile-based groups (0-4).

    Uses the interior edges of ``self.income_thresholds`` (set up by
    ``_load_income_thresholds``); the result is clamped to [0, 4].
    """
    interior_edges = self.income_thresholds[1:-1]
    bucket = int(np.digitize([income], interior_edges)[0])
    if bucket < 0:
        return 0
    return bucket if bucket <= 4 else 4
|
| 206 |
-
def categorize_gender(self, gender: str) -> int:
    """Encode gender: 1 for 'male' (case-insensitive), otherwise 0."""
    return int(gender.lower() == 'male')
|
| 210 |
-
def get_user_embedding(self,
                       age: int,
                       gender: str,
                       income: float,
                       interaction_history: List[int] = None) -> np.ndarray:
    """Generate user embedding with categorical demographics.

    Args:
        age: Raw age in years; bucketed via ``categorize_age``.
        gender: Free-form gender string; 'male' (any case) -> 1, else 0.
        income: Raw income; bucketed via ``categorize_income``.
        interaction_history: Optional product ids; only the first 50 are
            embedded (the model's fixed history window).

    Returns:
        A 1-D numpy embedding vector, or None when the user tower is
        not loaded.
    """

    if self.user_tower is None:
        print("β User tower not loaded")
        return None

    # Categorize demographics
    age_cat = self.categorize_age(age)
    gender_cat = self.categorize_gender(gender)
    income_cat = self.categorize_income(income)

    # Prepare interaction history embeddings
    if interaction_history is None:
        interaction_history = []

    # Fixed-size (50, embedding_dim) history matrix; slots for missing
    # or unknown items remain all-zero.
    history_embeddings = np.zeros((50, self.embedding_dim), dtype=np.float32)

    for i, item_id in enumerate(interaction_history[:50]):
        # Only items present in the FAISS index contribute embeddings.
        if self.faiss_index and item_id in self.faiss_index.item_id_to_idx:
            item_emb = self.faiss_index.get_item_embedding(item_id)
            if item_emb is not None:
                history_embeddings[i] = item_emb

    # Create user features (batch of 1)
    user_features = {
        'age': tf.constant([age_cat]),
        'gender': tf.constant([gender_cat]),
        'income': tf.constant([income_cat]),
        'item_history_embeddings': tf.constant([history_embeddings])
    }

    # Get embedding
    user_output = self.user_tower(user_features, training=False)
    # Some tower variants return (embedding, aux); take the embedding.
    if isinstance(user_output, tuple):
        user_embedding = user_output[0].numpy()[0]
    else:
        user_embedding = user_output.numpy()[0]

    return user_embedding
|
| 256 |
-
def get_item_embedding(self, item_id: int) -> Optional[np.ndarray]:
    """Get item embedding.

    Prefers the precomputed FAISS index; falls back to running the item
    tower on the item's raw catalog features. Returns None when the item
    is unknown or no item tower is loaded.
    """

    if self.faiss_index:
        return self.faiss_index.get_item_embedding(item_id)

    # Fallback to model computation
    if self.item_tower is None:
        return None

    item_row = self.items_df[self.items_df['product_id'] == item_id]
    if item_row.empty:
        return None

    item_data = item_row.iloc[0]

    # Prepare features; ids not in the vocab fall back to index 0 (OOV).
    item_features = {
        'product_id': tf.constant([self.data_processor.item_vocab.get(item_id, 0)]),
        'category_id': tf.constant([self.data_processor.category_vocab.get(item_data['category_id'], 0)]),
        'brand_id': tf.constant([self.data_processor.brand_vocab.get(item_data.get('brand', 'unknown'), 0)]),
        'price': tf.constant([float(item_data.get('price', 0.0))])
    }

    # Get embedding
    item_output = self.item_tower(item_features, training=False)
    # Some tower variants return (embedding, aux); take the embedding.
    if isinstance(item_output, tuple):
        item_embedding = item_output[0].numpy()[0]
    else:
        item_embedding = item_output.numpy()[0]

    return item_embedding
|
| 289 |
-
def recommend_items_enhanced(self,
                             age: int,
                             gender: str,
                             income: float,
                             interaction_history: List[int] = None,
                             k: int = 10,
                             diversity_weight: float = 0.3,
                             category_boost: float = 1.5) -> List[Tuple[int, float, Dict]]:
    """Generate enhanced recommendations with diversity and category boosting.

    Pipeline: user embedding -> FAISS candidate retrieval (k*3) ->
    history filtering -> category-preference boosting -> diversity
    re-ranking. Returns up to k (item_id, score, item_info) tuples;
    empty list when the index or the user embedding is unavailable.
    """

    if not self.faiss_index:
        print("β FAISS index not available")
        return []

    # Get user embedding
    user_embedding = self.get_user_embedding(age, gender, income, interaction_history)
    if user_embedding is None:
        return []

    # Get candidate recommendations (more than needed for filtering)
    candidates = self.faiss_index.search_by_embedding(user_embedding, k * 3)

    # Filter out items from interaction history
    if interaction_history:
        history_set = set(interaction_history)
        candidates = [(item_id, score) for item_id, score in candidates
                      if item_id not in history_set]

    # Add item metadata and apply enhancements
    enhanced_candidates = []

    for item_id, similarity_score in candidates[:k * 2]:
        # Get item info (skip items missing from the catalog)
        item_row = self.items_df[self.items_df['product_id'] == item_id]
        if item_row.empty:
            continue

        item_info = item_row.iloc[0].to_dict()

        # Enhanced scoring with multiple factors
        final_score = similarity_score

        # Category boosting based on user history.
        # NOTE(review): _get_user_categories is recomputed for every
        # candidate; it could be hoisted out of this loop.
        if interaction_history and category_boost > 1.0:
            user_categories = self._get_user_categories(interaction_history)
            item_category = item_info.get('category_code', '')

            if item_category in user_categories:
                # Boost scales with the share of history in this category.
                category_preference = user_categories[item_category]
                final_score *= (1 + (category_boost - 1) * category_preference)

        enhanced_candidates.append((item_id, final_score, item_info))

    # Sort by enhanced scores
    enhanced_candidates.sort(key=lambda x: x[1], reverse=True)

    # Apply diversity filtering
    if diversity_weight > 0:
        diversified_candidates = self._apply_diversity_filter(
            enhanced_candidates, diversity_weight
        )
    else:
        diversified_candidates = enhanced_candidates

    return diversified_candidates[:k]
|
| 355 |
-
def _get_user_categories(self, interaction_history: List[int]) -> Dict[str, float]:
|
| 356 |
-
"""Get user's category preferences from history."""
|
| 357 |
-
|
| 358 |
-
category_counts = Counter()
|
| 359 |
-
|
| 360 |
-
for item_id in interaction_history:
|
| 361 |
-
item_row = self.items_df[self.items_df['product_id'] == item_id]
|
| 362 |
-
if not item_row.empty:
|
| 363 |
-
category = item_row.iloc[0].get('category_code', 'Unknown')
|
| 364 |
-
category_counts[category] += 1
|
| 365 |
-
|
| 366 |
-
# Convert to preferences (percentages)
|
| 367 |
-
total = sum(category_counts.values())
|
| 368 |
-
if total == 0:
|
| 369 |
-
return {}
|
| 370 |
-
|
| 371 |
-
return {cat: count / total for cat, count in category_counts.items()}
|
| 372 |
-
|
| 373 |
-
def _apply_diversity_filter(self,
|
| 374 |
-
candidates: List[Tuple[int, float, Dict]],
|
| 375 |
-
diversity_weight: float,
|
| 376 |
-
max_per_category: int = 3) -> List[Tuple[int, float, Dict]]:
|
| 377 |
-
"""Apply diversity filtering to recommendations."""
|
| 378 |
-
|
| 379 |
-
category_counts = defaultdict(int)
|
| 380 |
-
diversified = []
|
| 381 |
-
|
| 382 |
-
for item_id, score, item_info in candidates:
|
| 383 |
-
category = item_info.get('category_code', 'Unknown')
|
| 384 |
-
|
| 385 |
-
# Apply diversity penalty
|
| 386 |
-
if category_counts[category] >= max_per_category:
|
| 387 |
-
# Penalty for over-representation
|
| 388 |
-
diversity_penalty = diversity_weight * (category_counts[category] - max_per_category + 1)
|
| 389 |
-
adjusted_score = score * (1 - diversity_penalty)
|
| 390 |
-
else:
|
| 391 |
-
adjusted_score = score
|
| 392 |
-
|
| 393 |
-
diversified.append((item_id, adjusted_score, item_info))
|
| 394 |
-
category_counts[category] += 1
|
| 395 |
-
|
| 396 |
-
# Re-sort by adjusted scores
|
| 397 |
-
diversified.sort(key=lambda x: x[1], reverse=True)
|
| 398 |
-
return diversified
|
| 399 |
-
|
| 400 |
-
def predict_rating(self,
                   age: int,
                   gender: str,
                   income: float,
                   item_id: int,
                   interaction_history: List[int] = None) -> float:
    """Predict rating for user-item pair.

    Returns 0.5 as a neutral fallback when the rating model or either
    embedding is unavailable.
    """

    if self.rating_model is None:
        return 0.5  # Default rating

    # Get embeddings
    user_embedding = self.get_user_embedding(age, gender, income, interaction_history)
    item_embedding = self.get_item_embedding(item_id)

    if user_embedding is None or item_embedding is None:
        return 0.5

    # Concatenate embeddings (user ++ item) as the rating head's input.
    combined = np.concatenate([user_embedding, item_embedding])
    combined = tf.constant([combined])

    # Predict rating
    rating = self.rating_model(combined, training=False)
    return float(rating.numpy()[0][0])
|
| 426 |
-
|
| 427 |
-
def demo_enhanced_engine():
    """Demo the enhanced 128D recommendation engine.

    Loads the engine, selects two real users with at least 10
    interactions, prints enhanced recommendations, a simple diversity
    metric, and one rating prediction per user. Console-only; returns
    nothing and never raises (errors are printed).
    """

    print("π ENHANCED 128D RECOMMENDATION ENGINE DEMO")
    print("="*70)

    try:
        # Initialize engine
        engine = Enhanced128DRecommendationEngine()

        if engine.item_tower is None:
            print("β Enhanced model not available. Please train first using:")
            print("   python train_enhanced_model.py")
            return

        # Get real user for testing
        real_user_selector = RealUserSelector()
        test_users = real_user_selector.get_real_users(n=2, min_interactions=10)

        for user in test_users:
            print(f"\nπ Testing User {user['user_id']} ({user['age']}yr {user['gender']}):")
            print(f"   Income: ${user['income']:,}")
            print(f"   History: {len(user['interaction_history'])} items")

            # Test enhanced recommendations
            try:
                # Only the 20 most recent history items are used here.
                recs = engine.recommend_items_enhanced(
                    age=user['age'],
                    gender=user['gender'],
                    income=user['income'],
                    interaction_history=user['interaction_history'][:20],
                    k=10,
                    diversity_weight=0.3,
                    category_boost=1.5
                )

                print(f"   π― Enhanced Recommendations:")
                categories = []
                for i, (item_id, score, item_info) in enumerate(recs[:5]):
                    category = item_info.get('category_code', 'Unknown')[:30]
                    price = item_info.get('price', 0)
                    categories.append(category)
                    print(f"   #{i+1} Item {item_id}: {score:.4f} | ${price:.2f} | {category}")

                # Analyze diversity (distinct categories among top 5)
                unique_categories = len(set(categories))
                print(f"   π Diversity: {unique_categories}/{len(categories)} unique categories")

                # Test rating prediction on the top recommendation
                if recs:
                    test_item = recs[0][0]
                    predicted_rating = engine.predict_rating(
                        age=user['age'],
                        gender=user['gender'],
                        income=user['income'],
                        item_id=test_item,
                        interaction_history=user['interaction_history'][:20]
                    )
                    print(f"   β Rating prediction for item {test_item}: {predicted_rating:.3f}")

            except Exception as e:
                # Per-user failures are reported but do not abort the demo.
                print(f"   β Error: {e}")

        print(f"\nβ Enhanced 128D engine demo completed!")

    except Exception as e:
        print(f"β Demo failed: {e}")
        import traceback
        traceback.print_exc()
|
| 498 |
-
# Script entry point: run the demo only when executed directly.
if __name__ == "__main__":
    demo_enhanced_engine()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/inference/recommendation_engine.py
CHANGED
|
@@ -61,6 +61,50 @@ class RecommendationEngine:
|
|
| 61 |
category = np.digitize([income], self.income_thresholds[1:-1])[0]
|
| 62 |
return min(max(category, 0), 4)
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
def _load_all_components(self):
|
| 65 |
"""Load all required components for inference."""
|
| 66 |
|
|
@@ -139,6 +183,10 @@ class RecommendationEngine:
|
|
| 139 |
'age': tf.constant([2]), # Adult category (26-35)
|
| 140 |
'gender': tf.constant([1]), # Male
|
| 141 |
'income': tf.constant([2]), # Middle income category
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
'item_history_embeddings': tf.constant([[[0.0] * 128] * 50]) # Changed from 64 to 128
|
| 143 |
}
|
| 144 |
_ = self.user_tower(dummy_input)
|
|
@@ -195,6 +243,10 @@ class RecommendationEngine:
|
|
| 195 |
age: int,
|
| 196 |
gender: str,
|
| 197 |
income: float,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
interaction_history: List[int] = None) -> Dict[str, tf.Tensor]:
|
| 199 |
"""Prepare user features for inference."""
|
| 200 |
|
|
@@ -204,9 +256,13 @@ class RecommendationEngine:
|
|
| 204 |
# Convert gender
|
| 205 |
gender_numeric = 1 if gender.lower() == 'male' else 0
|
| 206 |
|
| 207 |
-
# Categorize
|
| 208 |
age_category = self.categorize_age(age)
|
| 209 |
income_category = self.categorize_income(income)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
# Get item embeddings for history
|
| 212 |
history_embeddings = []
|
|
@@ -235,6 +291,10 @@ class RecommendationEngine:
|
|
| 235 |
'age': tf.constant([age_category]), # Categorical age (0-5)
|
| 236 |
'gender': tf.constant([gender_numeric]), # Categorical gender (0-1)
|
| 237 |
'income': tf.constant([income_category]), # Categorical income (0-4)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
'item_history_embeddings': tf.constant([history_embeddings])
|
| 239 |
}
|
| 240 |
|
|
@@ -275,14 +335,77 @@ class RecommendationEngine:
|
|
| 275 |
age: int,
|
| 276 |
gender: str,
|
| 277 |
income: float,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
interaction_history: List[int] = None) -> np.ndarray:
|
| 279 |
"""Get user embedding from user tower."""
|
| 280 |
|
| 281 |
-
user_features = self.prepare_user_features(age, gender, income, interaction_history)
|
| 282 |
user_embedding = self.user_tower(user_features, training=False)
|
| 283 |
|
| 284 |
return user_embedding.numpy()[0]
|
| 285 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
def get_item_embedding(self, item_id: int) -> Optional[np.ndarray]:
|
| 287 |
"""Get item embedding from FAISS index or item tower."""
|
| 288 |
|
|
@@ -301,14 +424,18 @@ class RecommendationEngine:
|
|
| 301 |
age: int,
|
| 302 |
gender: str,
|
| 303 |
income: float,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
interaction_history: List[int] = None,
|
| 305 |
k: int = 10,
|
| 306 |
exclude_history: bool = True,
|
| 307 |
category_boost: float = 1.3) -> List[Tuple[int, float, Dict]]:
|
| 308 |
"""Generate recommendations using collaborative filtering with category awareness."""
|
| 309 |
|
| 310 |
-
# Get user embedding
|
| 311 |
-
user_embedding = self.
|
| 312 |
|
| 313 |
# Find similar items using FAISS (get more candidates for boosting)
|
| 314 |
similar_items = self.faiss_index.search_by_embedding(user_embedding, k * 4)
|
|
@@ -354,6 +481,153 @@ class RecommendationEngine:
|
|
| 354 |
|
| 355 |
return recommendations
|
| 356 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
def recommend_items_content_based(self,
|
| 358 |
seed_item_id: int,
|
| 359 |
k: int = 10,
|
|
@@ -431,6 +705,10 @@ class RecommendationEngine:
|
|
| 431 |
age: int,
|
| 432 |
gender: str,
|
| 433 |
income: float,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
interaction_history: List[int] = None,
|
| 435 |
k: int = 10,
|
| 436 |
collaborative_weight: float = 0.7) -> List[Tuple[int, float, Dict]]:
|
|
@@ -438,15 +716,16 @@ class RecommendationEngine:
|
|
| 438 |
|
| 439 |
# Get collaborative recommendations
|
| 440 |
collab_recs = self.recommend_items_collaborative(
|
| 441 |
-
age, gender, income, interaction_history, k * 2
|
| 442 |
)
|
| 443 |
|
| 444 |
-
# Get content-based recommendations from
|
| 445 |
content_recs = []
|
| 446 |
if interaction_history:
|
| 447 |
-
# Use
|
| 448 |
-
|
| 449 |
-
|
|
|
|
| 450 |
|
| 451 |
# Combine recommendations with weighted scores
|
| 452 |
item_scores = {}
|
|
@@ -484,6 +763,254 @@ class RecommendationEngine:
|
|
| 484 |
|
| 485 |
return hybrid_recommendations[:k]
|
| 486 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 487 |
def _get_item_info(self, item_id: int) -> Dict:
|
| 488 |
"""Get item metadata."""
|
| 489 |
|
|
@@ -512,6 +1039,10 @@ class RecommendationEngine:
|
|
| 512 |
gender: str,
|
| 513 |
income: float,
|
| 514 |
item_id: int,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
interaction_history: List[int] = None) -> float:
|
| 516 |
"""Predict rating for a specific user-item pair."""
|
| 517 |
|
|
@@ -519,7 +1050,7 @@ class RecommendationEngine:
|
|
| 519 |
return 0.5 # Default prediction
|
| 520 |
|
| 521 |
# Prepare user features
|
| 522 |
-
user_features = self.prepare_user_features(age, gender, income, interaction_history)
|
| 523 |
|
| 524 |
# Prepare item features
|
| 525 |
if item_id not in self.data_processor.item_vocab:
|
|
@@ -552,6 +1083,10 @@ def main():
|
|
| 552 |
'age': 32,
|
| 553 |
'gender': 'male',
|
| 554 |
'income': 75000,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 555 |
'interaction_history': [1000978, 1001588, 1001618] # Sample item IDs
|
| 556 |
}
|
| 557 |
|
|
@@ -559,20 +1094,28 @@ def main():
|
|
| 559 |
print(f"Age: {demo_user['age']}")
|
| 560 |
print(f"Gender: {demo_user['gender']}")
|
| 561 |
print(f"Income: ${demo_user['income']:,}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 562 |
print(f"Interaction history: {demo_user['interaction_history']}")
|
| 563 |
|
| 564 |
# Generate collaborative recommendations
|
| 565 |
-
print("\n=== Collaborative Filtering Recommendations ===")
|
| 566 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 567 |
|
| 568 |
for i, (item_id, score, info) in enumerate(collab_recs, 1):
|
| 569 |
print(f"{i}. Item {item_id}: {info['brand']} - ${info['price']:.2f} (Score: {score:.4f})")
|
| 570 |
|
| 571 |
-
# Generate content-based recommendations
|
| 572 |
-
print("\n=== Content-Based Recommendations (
|
| 573 |
if demo_user['interaction_history']:
|
| 574 |
-
content_recs = engine.
|
| 575 |
-
|
| 576 |
)
|
| 577 |
|
| 578 |
for i, (item_id, score, info) in enumerate(content_recs, 1):
|
|
@@ -580,7 +1123,9 @@ def main():
|
|
| 580 |
|
| 581 |
# Generate hybrid recommendations
|
| 582 |
print("\n=== Hybrid Recommendations ===")
|
| 583 |
-
hybrid_recs = engine.recommend_items_hybrid(
|
|
|
|
|
|
|
| 584 |
|
| 585 |
for i, (item_id, score, info) in enumerate(hybrid_recs, 1):
|
| 586 |
print(f"{i}. Item {item_id}: {info['brand']} - ${info['price']:.2f} (Score: {score:.4f})")
|
|
|
|
| 61 |
category = np.digitize([income], self.income_thresholds[1:-1])[0]
|
| 62 |
return min(max(category, 0), 4)
|
| 63 |
|
| 64 |
+
def categorize_profession(self, profession: str) -> int:
    """Encode a profession label as an integer id (0-7).

    Unknown labels map to 7 ('Other').
    """
    known = ("Technology", "Healthcare", "Education", "Finance",
             "Retail", "Manufacturing", "Services", "Other")
    try:
        return known.index(profession)
    except ValueError:
        return 7  # Default to "Other"
+
|
| 78 |
+
def categorize_location(self, location: str) -> int:
    """Encode a location label: Urban=0, Suburban=1, Rural=2.

    Unknown labels map to 0 ('Urban').
    """
    codes = {"Urban": 0, "Suburban": 1, "Rural": 2}
    return codes[location] if location in codes else 0
+
|
| 87 |
+
def categorize_education_level(self, education: str) -> int:
    """Encode an education label as an ordinal id (0-4).

    High School=0 ... PhD+=4; unknown labels map to 0 ('High School').
    """
    levels = ("High School", "Some College", "Bachelor's", "Master's", "PhD+")
    return levels.index(education) if education in levels else 0
+
|
| 98 |
+
def categorize_marital_status(self, marital_status: str) -> int:
    """Encode marital status: Single=0, Married=1, Divorced=2, Widowed=3.

    Unknown labels map to 0 ('Single').
    """
    statuses = ("Single", "Married", "Divorced", "Widowed")
    for code, label in enumerate(statuses):
        if marital_status == label:
            return code
    return 0  # Default to "Single"
+
|
| 108 |
def _load_all_components(self):
|
| 109 |
"""Load all required components for inference."""
|
| 110 |
|
|
|
|
| 183 |
'age': tf.constant([2]), # Adult category (26-35)
|
| 184 |
'gender': tf.constant([1]), # Male
|
| 185 |
'income': tf.constant([2]), # Middle income category
|
| 186 |
+
'profession': tf.constant([0]), # Technology
|
| 187 |
+
'location': tf.constant([0]), # Urban
|
| 188 |
+
'education_level': tf.constant([2]), # Bachelor's
|
| 189 |
+
'marital_status': tf.constant([1]), # Married
|
| 190 |
'item_history_embeddings': tf.constant([[[0.0] * 128] * 50]) # Changed from 64 to 128
|
| 191 |
}
|
| 192 |
_ = self.user_tower(dummy_input)
|
|
|
|
| 243 |
age: int,
|
| 244 |
gender: str,
|
| 245 |
income: float,
|
| 246 |
+
profession: str = "Other",
|
| 247 |
+
location: str = "Urban",
|
| 248 |
+
education_level: str = "High School",
|
| 249 |
+
marital_status: str = "Single",
|
| 250 |
interaction_history: List[int] = None) -> Dict[str, tf.Tensor]:
|
| 251 |
"""Prepare user features for inference."""
|
| 252 |
|
|
|
|
| 256 |
# Convert gender
|
| 257 |
gender_numeric = 1 if gender.lower() == 'male' else 0
|
| 258 |
|
| 259 |
+
# Categorize all demographics
|
| 260 |
age_category = self.categorize_age(age)
|
| 261 |
income_category = self.categorize_income(income)
|
| 262 |
+
profession_category = self.categorize_profession(profession)
|
| 263 |
+
location_category = self.categorize_location(location)
|
| 264 |
+
education_category = self.categorize_education_level(education_level)
|
| 265 |
+
marital_category = self.categorize_marital_status(marital_status)
|
| 266 |
|
| 267 |
# Get item embeddings for history
|
| 268 |
history_embeddings = []
|
|
|
|
| 291 |
'age': tf.constant([age_category]), # Categorical age (0-5)
|
| 292 |
'gender': tf.constant([gender_numeric]), # Categorical gender (0-1)
|
| 293 |
'income': tf.constant([income_category]), # Categorical income (0-4)
|
| 294 |
+
'profession': tf.constant([profession_category]), # Categorical profession (0-7)
|
| 295 |
+
'location': tf.constant([location_category]), # Categorical location (0-2)
|
| 296 |
+
'education_level': tf.constant([education_category]), # Categorical education (0-4)
|
| 297 |
+
'marital_status': tf.constant([marital_category]), # Categorical marital status (0-3)
|
| 298 |
'item_history_embeddings': tf.constant([history_embeddings])
|
| 299 |
}
|
| 300 |
|
|
|
|
| 335 |
age: int,
|
| 336 |
gender: str,
|
| 337 |
income: float,
|
| 338 |
+
profession: str = "Other",
|
| 339 |
+
location: str = "Urban",
|
| 340 |
+
education_level: str = "High School",
|
| 341 |
+
marital_status: str = "Single",
|
| 342 |
interaction_history: List[int] = None) -> np.ndarray:
|
| 343 |
"""Get user embedding from user tower."""
|
| 344 |
|
| 345 |
+
user_features = self.prepare_user_features(age, gender, income, profession, location, education_level, marital_status, interaction_history)
|
| 346 |
user_embedding = self.user_tower(user_features, training=False)
|
| 347 |
|
| 348 |
return user_embedding.numpy()[0]
|
| 349 |
|
| 350 |
+
def get_user_embedding_enhanced(self,
                                age: int,
                                gender: str,
                                income: float,
                                profession: str = "Other",
                                location: str = "Urban",
                                education_level: str = "High School",
                                marital_status: str = "Single",
                                interaction_history: List[int] = None) -> np.ndarray:
    """Enhanced user embedding that handles zero interactions better.

    Users with any interaction history get the plain tower embedding.
    Zero-interaction (cold-start) users get a heuristic embedding: the
    demographic-influenced half of the vector is amplified, the
    history-influenced half damped, and deterministic per-profile noise
    is added so distinct cold-start profiles do not collapse to the same
    point. Heuristic until the model is retrained.

    Returns:
        A float32 embedding vector (unit-norm on the cold-start path).
    """

    # Get base embedding from the tower.
    base_embedding = self.get_user_embedding(
        age, gender, income, profession, location, education_level,
        marital_status, interaction_history
    )

    # Users with history: return the tower embedding untouched.
    if interaction_history and len(interaction_history) > 0:
        return base_embedding

    # --- cold-start path ---
    # Amplify the first 50% of dimensions (assumed demographic-influenced)
    # and damp the rest (assumed history-influenced). TODO confirm this
    # dimension split against the tower architecture.
    demographic_mask = np.ones_like(base_embedding)
    mid_point = len(base_embedding) // 2
    demographic_mask[:mid_point] *= 3.0   # strong amplification
    demographic_mask[mid_point:] *= 0.2   # strong reduction
    enhanced_embedding = base_embedding * demographic_mask

    # Deterministic per-profile seed so identical profiles always map to
    # the same vector across calls.
    demographic_hash = (
        age * 1000 +
        (1 if gender.lower() == 'male' else 0) * 100 +
        int(income / 10000) * 10 +
        self.categorize_profession(profession) * 7 +
        self.categorize_location(location) * 3 +
        self.categorize_education_level(education_level) * 5 +
        self.categorize_marital_status(marital_status) * 2
    )

    # FIX: use a local RandomState instead of np.random.seed(), which
    # clobbered the process-global RNG state for unrelated callers.
    # RandomState produces the identical stream the legacy global
    # np.random.normal() would have, so values are unchanged.
    rng = np.random.RandomState(demographic_hash % 2**32)
    demographic_noise = rng.normal(0, 0.02, base_embedding.shape)
    enhanced_embedding = enhanced_embedding + demographic_noise

    # Renormalize; guard against a (theoretical) zero-norm vector to
    # avoid NaNs from division by zero.
    norm = np.linalg.norm(enhanced_embedding)
    if norm > 0:
        enhanced_embedding = enhanced_embedding / norm

    print(f"Enhanced embedding for zero interactions: age={age}, gender={gender}, profession={profession}")

    return enhanced_embedding.astype(np.float32)
+
|
| 409 |
def get_item_embedding(self, item_id: int) -> Optional[np.ndarray]:
|
| 410 |
"""Get item embedding from FAISS index or item tower."""
|
| 411 |
|
|
|
|
| 424 |
age: int,
|
| 425 |
gender: str,
|
| 426 |
income: float,
|
| 427 |
+
profession: str = "Other",
|
| 428 |
+
location: str = "Urban",
|
| 429 |
+
education_level: str = "High School",
|
| 430 |
+
marital_status: str = "Single",
|
| 431 |
interaction_history: List[int] = None,
|
| 432 |
k: int = 10,
|
| 433 |
exclude_history: bool = True,
|
| 434 |
category_boost: float = 1.3) -> List[Tuple[int, float, Dict]]:
|
| 435 |
"""Generate recommendations using collaborative filtering with category awareness."""
|
| 436 |
|
| 437 |
+
# Get enhanced user embedding (better for zero interactions)
|
| 438 |
+
user_embedding = self.get_user_embedding_enhanced(age, gender, income, profession, location, education_level, marital_status, interaction_history)
|
| 439 |
|
| 440 |
# Find similar items using FAISS (get more candidates for boosting)
|
| 441 |
similar_items = self.faiss_index.search_by_embedding(user_embedding, k * 4)
|
|
|
|
| 481 |
|
| 482 |
return recommendations
|
| 483 |
|
| 484 |
+
def _aggregate_user_history_embedding(self,
|
| 485 |
+
interaction_history: List[int],
|
| 486 |
+
aggregation_method: str = "weighted_mean") -> Optional[np.ndarray]:
|
| 487 |
+
"""Aggregate user's interaction history into a single embedding vector."""
|
| 488 |
+
|
| 489 |
+
if not interaction_history:
|
| 490 |
+
return None
|
| 491 |
+
|
| 492 |
+
# Get embeddings for items in history
|
| 493 |
+
item_embeddings = []
|
| 494 |
+
valid_items = []
|
| 495 |
+
|
| 496 |
+
for item_id in interaction_history:
|
| 497 |
+
embedding = self.faiss_index.get_item_embedding(item_id)
|
| 498 |
+
if embedding is not None:
|
| 499 |
+
item_embeddings.append(embedding)
|
| 500 |
+
valid_items.append(item_id)
|
| 501 |
+
|
| 502 |
+
if not item_embeddings:
|
| 503 |
+
print(f"No valid embeddings found for interaction history: {interaction_history}")
|
| 504 |
+
return None
|
| 505 |
+
|
| 506 |
+
item_embeddings = np.array(item_embeddings)
|
| 507 |
+
print(f"Aggregating {len(item_embeddings)} item embeddings using {aggregation_method}")
|
| 508 |
+
|
| 509 |
+
# Apply aggregation method
|
| 510 |
+
if aggregation_method == "mean":
|
| 511 |
+
# Simple mean pooling
|
| 512 |
+
aggregated = np.mean(item_embeddings, axis=0)
|
| 513 |
+
|
| 514 |
+
elif aggregation_method == "weighted_mean":
|
| 515 |
+
# Weight recent interactions higher (exponential decay)
|
| 516 |
+
weights = np.exp(np.linspace(-1, 0, len(item_embeddings))) # More recent = higher weight
|
| 517 |
+
weights = weights / np.sum(weights) # Normalize weights
|
| 518 |
+
aggregated = np.average(item_embeddings, axis=0, weights=weights)
|
| 519 |
+
print(f"Applied weighted mean with weights: {weights[-3:]} (showing last 3)")
|
| 520 |
+
|
| 521 |
+
elif aggregation_method == "max":
|
| 522 |
+
# Element-wise maximum pooling
|
| 523 |
+
aggregated = np.max(item_embeddings, axis=0)
|
| 524 |
+
|
| 525 |
+
else:
|
| 526 |
+
raise ValueError(f"Unknown aggregation method: {aggregation_method}")
|
| 527 |
+
|
| 528 |
+
# L2 normalize the aggregated embedding
|
| 529 |
+
aggregated = aggregated / np.linalg.norm(aggregated)
|
| 530 |
+
|
| 531 |
+
return aggregated.astype('float32')
|
| 532 |
+
|
| 533 |
+
def recommend_items_content_based_from_history(self,
|
| 534 |
+
interaction_history: List[int],
|
| 535 |
+
k: int = 10,
|
| 536 |
+
aggregation_method: str = "weighted_mean",
|
| 537 |
+
same_category_ratio: float = None) -> List[Tuple[int, float, Dict]]:
|
| 538 |
+
"""Generate recommendations using content-based filtering from aggregated user history."""
|
| 539 |
+
|
| 540 |
+
# Aggregate user's interaction history
|
| 541 |
+
aggregated_embedding = self._aggregate_user_history_embedding(
|
| 542 |
+
interaction_history, aggregation_method
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
if aggregated_embedding is None:
|
| 546 |
+
print("Could not create aggregated embedding from interaction history")
|
| 547 |
+
return []
|
| 548 |
+
|
| 549 |
+
if same_category_ratio is None:
|
| 550 |
+
# Direct ANN search with aggregated embedding
|
| 551 |
+
similar_items = self.faiss_index.search_by_embedding(aggregated_embedding, k)
|
| 552 |
+
recommendations = []
|
| 553 |
+
|
| 554 |
+
# Filter out items already in interaction history
|
| 555 |
+
interaction_set = set(interaction_history)
|
| 556 |
+
|
| 557 |
+
for item_id, score in similar_items:
|
| 558 |
+
if item_id not in interaction_set: # Exclude already interacted items
|
| 559 |
+
item_info = self._get_item_info(item_id)
|
| 560 |
+
recommendations.append((item_id, score, item_info))
|
| 561 |
+
|
| 562 |
+
if len(recommendations) >= k:
|
| 563 |
+
break
|
| 564 |
+
|
| 565 |
+
print(f"Found {len(recommendations)} content-based recommendations from aggregated history")
|
| 566 |
+
return recommendations
|
| 567 |
+
|
| 568 |
+
else:
|
| 569 |
+
# Category-aware approach with aggregated embedding
|
| 570 |
+
print(f"Finding similar items with {same_category_ratio*100}% category constraint from aggregated history")
|
| 571 |
+
|
| 572 |
+
# Analyze user's category preferences from interaction history
|
| 573 |
+
user_categories = {}
|
| 574 |
+
total_interactions = len(interaction_history)
|
| 575 |
+
|
| 576 |
+
for item_id in interaction_history:
|
| 577 |
+
item_info = self._get_item_info(item_id)
|
| 578 |
+
category = item_info.get('category_code', '')
|
| 579 |
+
if category:
|
| 580 |
+
user_categories[category] = user_categories.get(category, 0) + 1
|
| 581 |
+
|
| 582 |
+
# Convert to percentages
|
| 583 |
+
for category in user_categories:
|
| 584 |
+
user_categories[category] = user_categories[category] / total_interactions
|
| 585 |
+
|
| 586 |
+
print(f"User category preferences: {user_categories}")
|
| 587 |
+
|
| 588 |
+
# Get more candidates for category filtering
|
| 589 |
+
candidate_items = self.faiss_index.search_by_embedding(aggregated_embedding, k * 3)
|
| 590 |
+
interaction_set = set(interaction_history)
|
| 591 |
+
|
| 592 |
+
# Separate by category alignment with user preferences
|
| 593 |
+
preferred_category_items = []
|
| 594 |
+
other_category_items = []
|
| 595 |
+
|
| 596 |
+
for item_id, score in candidate_items:
|
| 597 |
+
if item_id in interaction_set:
|
| 598 |
+
continue # Skip already interacted items
|
| 599 |
+
|
| 600 |
+
item_info = self._get_item_info(item_id)
|
| 601 |
+
item_category = item_info.get('category_code', '')
|
| 602 |
+
|
| 603 |
+
# Check if item category matches user's preferred categories
|
| 604 |
+
if item_category in user_categories:
|
| 605 |
+
preferred_category_items.append((item_id, score, item_info))
|
| 606 |
+
else:
|
| 607 |
+
other_category_items.append((item_id, score, item_info))
|
| 608 |
+
|
| 609 |
+
# Calculate target distribution
|
| 610 |
+
preferred_count = int(k * same_category_ratio)
|
| 611 |
+
other_count = k - preferred_count
|
| 612 |
+
|
| 613 |
+
print(f"Target: {preferred_count} from preferred categories, {other_count} for exploration")
|
| 614 |
+
|
| 615 |
+
# Build balanced recommendations
|
| 616 |
+
recommendations = []
|
| 617 |
+
recommendations.extend(preferred_category_items[:preferred_count])
|
| 618 |
+
recommendations.extend(other_category_items[:other_count])
|
| 619 |
+
|
| 620 |
+
# Fill remaining slots with best available items
|
| 621 |
+
if len(recommendations) < k:
|
| 622 |
+
remaining_items = (preferred_category_items[preferred_count:] +
|
| 623 |
+
other_category_items[other_count:])
|
| 624 |
+
remaining_items.sort(key=lambda x: x[1], reverse=True) # Sort by score
|
| 625 |
+
needed = k - len(recommendations)
|
| 626 |
+
recommendations.extend(remaining_items[:needed])
|
| 627 |
+
|
| 628 |
+
print(f"Final recommendations: {len(recommendations)} items")
|
| 629 |
+
return recommendations[:k]
|
| 630 |
+
|
| 631 |
def recommend_items_content_based(self,
|
| 632 |
seed_item_id: int,
|
| 633 |
k: int = 10,
|
|
|
|
| 705 |
age: int,
|
| 706 |
gender: str,
|
| 707 |
income: float,
|
| 708 |
+
profession: str = "Other",
|
| 709 |
+
location: str = "Urban",
|
| 710 |
+
education_level: str = "High School",
|
| 711 |
+
marital_status: str = "Single",
|
| 712 |
interaction_history: List[int] = None,
|
| 713 |
k: int = 10,
|
| 714 |
collaborative_weight: float = 0.7) -> List[Tuple[int, float, Dict]]:
|
|
|
|
| 716 |
|
| 717 |
# Get collaborative recommendations
|
| 718 |
collab_recs = self.recommend_items_collaborative(
|
| 719 |
+
age, gender, income, profession, location, education_level, marital_status, interaction_history, k * 2
|
| 720 |
)
|
| 721 |
|
| 722 |
+
# Get content-based recommendations from aggregated user history
|
| 723 |
content_recs = []
|
| 724 |
if interaction_history:
|
| 725 |
+
# Use aggregated history embedding instead of single recent item
|
| 726 |
+
content_recs = self.recommend_items_content_based_from_history(
|
| 727 |
+
interaction_history, k, aggregation_method="weighted_mean"
|
| 728 |
+
)
|
| 729 |
|
| 730 |
# Combine recommendations with weighted scores
|
| 731 |
item_scores = {}
|
|
|
|
| 763 |
|
| 764 |
return hybrid_recommendations[:k]
|
| 765 |
|
| 766 |
+
def recommend_items_category_boosted(self,
|
| 767 |
+
age: int,
|
| 768 |
+
gender: str,
|
| 769 |
+
income: float,
|
| 770 |
+
profession: str = "Other",
|
| 771 |
+
location: str = "Urban",
|
| 772 |
+
education_level: str = "High School",
|
| 773 |
+
marital_status: str = "Single",
|
| 774 |
+
interaction_history: List[int] = None,
|
| 775 |
+
k: int = 10,
|
| 776 |
+
exclude_history: bool = True) -> List[Tuple[int, float, Dict]]:
|
| 777 |
+
"""Generate category-boosted recommendations ensuring 50% from user's interacted categories."""
|
| 778 |
+
|
| 779 |
+
if not interaction_history or len(interaction_history) == 0:
|
| 780 |
+
# Fallback to collaborative filtering if no interaction history
|
| 781 |
+
return self.recommend_items_collaborative(
|
| 782 |
+
age, gender, income, profession, location, education_level, marital_status,
|
| 783 |
+
interaction_history, k, exclude_history
|
| 784 |
+
)
|
| 785 |
+
|
| 786 |
+
# Step 1: Calculate category percentages from interaction history
|
| 787 |
+
category_percentages = self._calculate_category_percentages(interaction_history)
|
| 788 |
+
|
| 789 |
+
if not category_percentages:
|
| 790 |
+
# Fallback if no categories found
|
| 791 |
+
return self.recommend_items_collaborative(
|
| 792 |
+
age, gender, income, profession, location, education_level, marital_status,
|
| 793 |
+
interaction_history, k, exclude_history
|
| 794 |
+
)
|
| 795 |
+
|
| 796 |
+
# Step 2: Get enhanced user embedding and do wide search (increased for better subcategory coverage)
|
| 797 |
+
user_embedding = self.get_user_embedding_enhanced(age, gender, income, profession, location, education_level, marital_status, interaction_history)
|
| 798 |
+
similar_items = self.faiss_index.search_by_embedding(user_embedding, k * 10) # Increased from k*6 to k*10
|
| 799 |
+
|
| 800 |
+
# Step 3: Organize candidates by subcategory with parent fallback
|
| 801 |
+
category_candidates = {category: [] for category in category_percentages.keys()}
|
| 802 |
+
parent_category_mapping = {} # Track parent categories for fallback
|
| 803 |
+
other_candidates = []
|
| 804 |
+
history_set = set(interaction_history) if exclude_history else set()
|
| 805 |
+
|
| 806 |
+
# Build parent category mapping for fallback
|
| 807 |
+
for subcategory in category_percentages.keys():
|
| 808 |
+
if '.' in subcategory:
|
| 809 |
+
parent = subcategory.split('.')[0]
|
| 810 |
+
if parent not in parent_category_mapping:
|
| 811 |
+
parent_category_mapping[parent] = []
|
| 812 |
+
parent_category_mapping[parent].append(subcategory)
|
| 813 |
+
|
| 814 |
+
for item_id, score in similar_items:
|
| 815 |
+
if item_id in history_set:
|
| 816 |
+
continue
|
| 817 |
+
|
| 818 |
+
# Get item category
|
| 819 |
+
item_row = self.items_df[self.items_df['product_id'] == item_id]
|
| 820 |
+
if len(item_row) > 0:
|
| 821 |
+
full_item_category = item_row.iloc[0]['category_code']
|
| 822 |
+
|
| 823 |
+
# Extract 2-level subcategory for matching
|
| 824 |
+
if '.' in full_item_category:
|
| 825 |
+
category_parts = full_item_category.split('.')
|
| 826 |
+
if len(category_parts) >= 2:
|
| 827 |
+
item_subcategory = f"{category_parts[0]}.{category_parts[1]}"
|
| 828 |
+
else:
|
| 829 |
+
item_subcategory = category_parts[0]
|
| 830 |
+
else:
|
| 831 |
+
item_subcategory = full_item_category
|
| 832 |
+
|
| 833 |
+
# Try exact subcategory match first
|
| 834 |
+
if item_subcategory in category_percentages:
|
| 835 |
+
category_candidates[item_subcategory].append((item_id, score))
|
| 836 |
+
else:
|
| 837 |
+
# Fallback: try parent category match
|
| 838 |
+
parent_category = item_subcategory.split('.')[0] if '.' in item_subcategory else item_subcategory
|
| 839 |
+
matched = False
|
| 840 |
+
|
| 841 |
+
if parent_category in parent_category_mapping:
|
| 842 |
+
# Add to the first subcategory of this parent (round-robin could be improved later)
|
| 843 |
+
target_subcategory = parent_category_mapping[parent_category][0]
|
| 844 |
+
category_candidates[target_subcategory].append((item_id, score))
|
| 845 |
+
matched = True
|
| 846 |
+
|
| 847 |
+
if not matched:
|
| 848 |
+
other_candidates.append((item_id, score))
|
| 849 |
+
|
| 850 |
+
# Step 4: Calculate target counts for each subcategory (50% distributed proportionally)
|
| 851 |
+
category_target_count = max(1, k // 2) # At least 50% from user categories
|
| 852 |
+
|
| 853 |
+
# Calculate proportional distribution with proper rounding
|
| 854 |
+
category_counts = self._calculate_proportional_distribution(
|
| 855 |
+
category_percentages, category_target_count
|
| 856 |
+
)
|
| 857 |
+
|
| 858 |
+
# Step 5: Select items with round-robin filling and rebalancing
|
| 859 |
+
selected_recommendations = []
|
| 860 |
+
|
| 861 |
+
# Fill from user's categories with rebalancing for insufficient candidates
|
| 862 |
+
actual_selections = {}
|
| 863 |
+
unused_allocations = {}
|
| 864 |
+
|
| 865 |
+
for category, target_count in category_counts.items():
|
| 866 |
+
candidates = sorted(category_candidates[category], key=lambda x: x[1], reverse=True)
|
| 867 |
+
available_count = len(candidates)
|
| 868 |
+
selected_count = min(target_count, available_count)
|
| 869 |
+
|
| 870 |
+
print(f"[DEBUG] Category {category}: target={target_count}, available={available_count}, selected={selected_count}")
|
| 871 |
+
|
| 872 |
+
actual_selections[category] = selected_count
|
| 873 |
+
if selected_count < target_count:
|
| 874 |
+
unused_allocations[category] = target_count - selected_count
|
| 875 |
+
|
| 876 |
+
# Select items from this category
|
| 877 |
+
for i in range(selected_count):
|
| 878 |
+
item_id, score = candidates[i]
|
| 879 |
+
item_info = self._get_item_info(item_id)
|
| 880 |
+
selected_recommendations.append((item_id, score, item_info))
|
| 881 |
+
|
| 882 |
+
# Step 6: Redistribute unused allocations proportionally
|
| 883 |
+
total_unused = sum(unused_allocations.values())
|
| 884 |
+
if total_unused > 0:
|
| 885 |
+
print(f"[DEBUG] Redistributing {total_unused} unused slots")
|
| 886 |
+
|
| 887 |
+
# Find categories with remaining candidates for redistribution
|
| 888 |
+
categories_with_extras = {}
|
| 889 |
+
for category, candidates in category_candidates.items():
|
| 890 |
+
used_count = actual_selections.get(category, 0)
|
| 891 |
+
available_extras = len(candidates) - used_count
|
| 892 |
+
if available_extras > 0:
|
| 893 |
+
categories_with_extras[category] = available_extras
|
| 894 |
+
|
| 895 |
+
# Redistribute based on original proportions and availability
|
| 896 |
+
redistributed = 0
|
| 897 |
+
for category in sorted(categories_with_extras.keys(), key=lambda c: category_percentages.get(c, 0), reverse=True):
|
| 898 |
+
if redistributed >= total_unused:
|
| 899 |
+
break
|
| 900 |
+
|
| 901 |
+
extra_slots = min(unused_allocations.get(category, 0) + 1, categories_with_extras[category])
|
| 902 |
+
candidates = sorted(category_candidates[category], key=lambda x: x[1], reverse=True)
|
| 903 |
+
used_count = actual_selections.get(category, 0)
|
| 904 |
+
|
| 905 |
+
for i in range(used_count, min(used_count + extra_slots, len(candidates))):
|
| 906 |
+
if redistributed >= total_unused:
|
| 907 |
+
break
|
| 908 |
+
item_id, score = candidates[i]
|
| 909 |
+
item_info = self._get_item_info(item_id)
|
| 910 |
+
selected_recommendations.append((item_id, score, item_info))
|
| 911 |
+
redistributed += 1
|
| 912 |
+
|
| 913 |
+
# Step 7: Fill remaining slots with diverse recommendations
|
| 914 |
+
remaining_slots = k - len(selected_recommendations)
|
| 915 |
+
if remaining_slots > 0:
|
| 916 |
+
# Collect all unused candidates (both from user categories and other categories)
|
| 917 |
+
all_remaining = []
|
| 918 |
+
|
| 919 |
+
# Add unused items from user categories
|
| 920 |
+
for category, candidates in category_candidates.items():
|
| 921 |
+
used_count = len([rec for rec in selected_recommendations if rec[2].get('category_code', '').startswith(category.split('.')[0])])
|
| 922 |
+
sorted_candidates = sorted(candidates, key=lambda x: x[1], reverse=True)
|
| 923 |
+
for i in range(used_count, len(sorted_candidates)):
|
| 924 |
+
all_remaining.append(sorted_candidates[i])
|
| 925 |
+
|
| 926 |
+
# Add items from other categories
|
| 927 |
+
all_remaining.extend(other_candidates)
|
| 928 |
+
|
| 929 |
+
# Sort by score and take best remaining
|
| 930 |
+
all_remaining.sort(key=lambda x: x[1], reverse=True)
|
| 931 |
+
|
| 932 |
+
print(f"[DEBUG] Filling {remaining_slots} remaining slots from {len(all_remaining)} candidates")
|
| 933 |
+
|
| 934 |
+
for i in range(min(remaining_slots, len(all_remaining))):
|
| 935 |
+
item_id, score = all_remaining[i]
|
| 936 |
+
item_info = self._get_item_info(item_id)
|
| 937 |
+
selected_recommendations.append((item_id, score, item_info))
|
| 938 |
+
|
| 939 |
+
# Step 7: Sort final recommendations by score and return top k
|
| 940 |
+
selected_recommendations.sort(key=lambda x: x[1], reverse=True)
|
| 941 |
+
return selected_recommendations[:k]
|
| 942 |
+
|
| 943 |
+
def _calculate_category_percentages(self, interaction_history: List[int]) -> Dict[str, float]:
|
| 944 |
+
"""Calculate subcategory percentages from interaction history (2-level depth)."""
|
| 945 |
+
if not interaction_history:
|
| 946 |
+
return {}
|
| 947 |
+
|
| 948 |
+
category_counts = {}
|
| 949 |
+
total_interactions = 0
|
| 950 |
+
|
| 951 |
+
for item_id in interaction_history:
|
| 952 |
+
item_row = self.items_df[self.items_df['product_id'] == item_id]
|
| 953 |
+
if len(item_row) > 0:
|
| 954 |
+
full_category = item_row.iloc[0]['category_code']
|
| 955 |
+
|
| 956 |
+
# Use 2-level subcategory (e.g., "computers.components" from "computers.components.memory")
|
| 957 |
+
if '.' in full_category:
|
| 958 |
+
category_parts = full_category.split('.')
|
| 959 |
+
if len(category_parts) >= 2:
|
| 960 |
+
subcategory = f"{category_parts[0]}.{category_parts[1]}"
|
| 961 |
+
else:
|
| 962 |
+
subcategory = category_parts[0] # Fallback to top-level if only one part
|
| 963 |
+
else:
|
| 964 |
+
subcategory = full_category
|
| 965 |
+
|
| 966 |
+
category_counts[subcategory] = category_counts.get(subcategory, 0) + 1
|
| 967 |
+
total_interactions += 1
|
| 968 |
+
|
| 969 |
+
# Convert to percentages
|
| 970 |
+
category_percentages = {}
|
| 971 |
+
for category, count in category_counts.items():
|
| 972 |
+
category_percentages[category] = (count / total_interactions) * 100
|
| 973 |
+
|
| 974 |
+
return category_percentages
|
| 975 |
+
|
| 976 |
+
def _calculate_proportional_distribution(self, category_percentages: Dict[str, float],
|
| 977 |
+
total_target: int) -> Dict[str, int]:
|
| 978 |
+
"""Calculate proportional distribution with proper rounding and no minimum distortion."""
|
| 979 |
+
if not category_percentages or total_target <= 0:
|
| 980 |
+
return {}
|
| 981 |
+
|
| 982 |
+
# Calculate raw allocations (without minimum guarantee)
|
| 983 |
+
total_percentage = sum(category_percentages.values())
|
| 984 |
+
raw_allocations = {}
|
| 985 |
+
remainders = {}
|
| 986 |
+
|
| 987 |
+
for category, percentage in category_percentages.items():
|
| 988 |
+
if total_percentage > 0:
|
| 989 |
+
raw_allocation = (percentage / total_percentage) * total_target
|
| 990 |
+
raw_allocations[category] = int(raw_allocation) # Floor
|
| 991 |
+
remainders[category] = raw_allocation - int(raw_allocation) # Remainder
|
| 992 |
+
else:
|
| 993 |
+
raw_allocations[category] = 0
|
| 994 |
+
remainders[category] = 0
|
| 995 |
+
|
| 996 |
+
# Distribute remaining slots based on largest remainders
|
| 997 |
+
allocated_so_far = sum(raw_allocations.values())
|
| 998 |
+
remaining_slots = total_target - allocated_so_far
|
| 999 |
+
|
| 1000 |
+
# Sort categories by remainder (largest first) to distribute remaining slots
|
| 1001 |
+
sorted_by_remainder = sorted(remainders.items(), key=lambda x: x[1], reverse=True)
|
| 1002 |
+
|
| 1003 |
+
for i in range(remaining_slots):
|
| 1004 |
+
if i < len(sorted_by_remainder):
|
| 1005 |
+
category_to_increment = sorted_by_remainder[i][0]
|
| 1006 |
+
raw_allocations[category_to_increment] += 1
|
| 1007 |
+
|
| 1008 |
+
# Filter out zero allocations (no artificial minimum guarantee)
|
| 1009 |
+
final_allocations = {cat: count for cat, count in raw_allocations.items() if count > 0}
|
| 1010 |
+
|
| 1011 |
+
print(f"[DEBUG] Proportional distribution: target={total_target}, allocations={final_allocations}")
|
| 1012 |
+
return final_allocations
|
| 1013 |
+
|
| 1014 |
def _get_item_info(self, item_id: int) -> Dict:
|
| 1015 |
"""Get item metadata."""
|
| 1016 |
|
|
|
|
| 1039 |
gender: str,
|
| 1040 |
income: float,
|
| 1041 |
item_id: int,
|
| 1042 |
+
profession: str = "Other",
|
| 1043 |
+
location: str = "Urban",
|
| 1044 |
+
education_level: str = "High School",
|
| 1045 |
+
marital_status: str = "Single",
|
| 1046 |
interaction_history: List[int] = None) -> float:
|
| 1047 |
"""Predict rating for a specific user-item pair."""
|
| 1048 |
|
|
|
|
| 1050 |
return 0.5 # Default prediction
|
| 1051 |
|
| 1052 |
# Prepare user features
|
| 1053 |
+
user_features = self.prepare_user_features(age, gender, income, profession, location, education_level, marital_status, interaction_history)
|
| 1054 |
|
| 1055 |
# Prepare item features
|
| 1056 |
if item_id not in self.data_processor.item_vocab:
|
|
|
|
| 1083 |
'age': 32,
|
| 1084 |
'gender': 'male',
|
| 1085 |
'income': 75000,
|
| 1086 |
+
'profession': 'Technology',
|
| 1087 |
+
'location': 'Urban',
|
| 1088 |
+
'education_level': "Bachelor's",
|
| 1089 |
+
'marital_status': 'Married',
|
| 1090 |
'interaction_history': [1000978, 1001588, 1001618] # Sample item IDs
|
| 1091 |
}
|
| 1092 |
|
|
|
|
| 1094 |
print(f"Age: {demo_user['age']}")
|
| 1095 |
print(f"Gender: {demo_user['gender']}")
|
| 1096 |
print(f"Income: ${demo_user['income']:,}")
|
| 1097 |
+
print(f"Profession: {demo_user['profession']}")
|
| 1098 |
+
print(f"Location: {demo_user['location']}")
|
| 1099 |
+
print(f"Education: {demo_user['education_level']}")
|
| 1100 |
+
print(f"Marital Status: {demo_user['marital_status']}")
|
| 1101 |
print(f"Interaction history: {demo_user['interaction_history']}")
|
| 1102 |
|
| 1103 |
# Generate collaborative recommendations
|
| 1104 |
+
print("\n=== Collaborative Filtering Recommendations ===")
|
| 1105 |
+
# Extract demographics and history separately to avoid conflicts
|
| 1106 |
+
demo_kwargs = {k: v for k, v in demo_user.items() if k != 'interaction_history'}
|
| 1107 |
+
collab_recs = engine.recommend_items_collaborative(
|
| 1108 |
+
**demo_kwargs, interaction_history=demo_user['interaction_history'], k=5
|
| 1109 |
+
)
|
| 1110 |
|
| 1111 |
for i, (item_id, score, info) in enumerate(collab_recs, 1):
|
| 1112 |
print(f"{i}. Item {item_id}: {info['brand']} - ${info['price']:.2f} (Score: {score:.4f})")
|
| 1113 |
|
| 1114 |
+
# Generate content-based recommendations from aggregated history
|
| 1115 |
+
print("\n=== Content-Based Recommendations (from aggregated user history) ===")
|
| 1116 |
if demo_user['interaction_history']:
|
| 1117 |
+
content_recs = engine.recommend_items_content_based_from_history(
|
| 1118 |
+
interaction_history=demo_user['interaction_history'], k=5
|
| 1119 |
)
|
| 1120 |
|
| 1121 |
for i, (item_id, score, info) in enumerate(content_recs, 1):
|
|
|
|
| 1123 |
|
| 1124 |
# Generate hybrid recommendations
|
| 1125 |
print("\n=== Hybrid Recommendations ===")
|
| 1126 |
+
hybrid_recs = engine.recommend_items_hybrid(
|
| 1127 |
+
**demo_kwargs, interaction_history=demo_user['interaction_history'], k=5
|
| 1128 |
+
)
|
| 1129 |
|
| 1130 |
for i, (item_id, score, info) in enumerate(hybrid_recs, 1):
|
| 1131 |
print(f"{i}. Item {item_id}: {info['brand']} - ${info['price']:.2f} (Score: {score:.4f})")
|
src/models/enhanced_two_tower.py
DELETED
|
@@ -1,574 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""
|
| 3 |
-
Enhanced two-tower model with embedding diversity regularization and improved discrimination.
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
import tensorflow as tf
|
| 7 |
-
import tensorflow_recommenders as tfrs
|
| 8 |
-
import numpy as np
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
class EmbeddingDiversityRegularizer(tf.keras.layers.Layer):
|
| 12 |
-
"""Regularizer to prevent embedding collapse by enforcing diversity."""
|
| 13 |
-
|
| 14 |
-
def __init__(self, diversity_weight=0.01, orthogonality_weight=0.05, **kwargs):
|
| 15 |
-
super().__init__(**kwargs)
|
| 16 |
-
self.diversity_weight = diversity_weight
|
| 17 |
-
self.orthogonality_weight = orthogonality_weight
|
| 18 |
-
|
| 19 |
-
def call(self, embeddings):
|
| 20 |
-
"""Apply diversity regularization to embeddings."""
|
| 21 |
-
batch_size = tf.shape(embeddings)[0]
|
| 22 |
-
|
| 23 |
-
# Compute pairwise cosine similarities
|
| 24 |
-
normalized_embeddings = tf.nn.l2_normalize(embeddings, axis=1)
|
| 25 |
-
similarity_matrix = tf.linalg.matmul(
|
| 26 |
-
normalized_embeddings, normalized_embeddings, transpose_b=True
|
| 27 |
-
)
|
| 28 |
-
|
| 29 |
-
# Remove diagonal (self-similarities)
|
| 30 |
-
mask = 1.0 - tf.eye(batch_size)
|
| 31 |
-
masked_similarities = similarity_matrix * mask
|
| 32 |
-
|
| 33 |
-
# Diversity loss: penalize high similarities between different embeddings
|
| 34 |
-
diversity_loss = tf.reduce_mean(tf.square(masked_similarities))
|
| 35 |
-
|
| 36 |
-
# Orthogonality loss: encourage embeddings to be orthogonal
|
| 37 |
-
identity_target = tf.eye(batch_size)
|
| 38 |
-
orthogonality_loss = tf.reduce_mean(
|
| 39 |
-
tf.square(similarity_matrix - identity_target)
|
| 40 |
-
)
|
| 41 |
-
|
| 42 |
-
# Add as regularization losses
|
| 43 |
-
self.add_loss(self.diversity_weight * diversity_loss)
|
| 44 |
-
self.add_loss(self.orthogonality_weight * orthogonality_loss)
|
| 45 |
-
|
| 46 |
-
return embeddings
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
class AdaptiveTemperatureScaling(tf.keras.layers.Layer):
|
| 50 |
-
"""Advanced temperature scaling with learned parameters."""
|
| 51 |
-
|
| 52 |
-
def __init__(self, initial_temperature=1.0, min_temp=0.1, max_temp=5.0, **kwargs):
|
| 53 |
-
super().__init__(**kwargs)
|
| 54 |
-
self.initial_temperature = initial_temperature
|
| 55 |
-
self.min_temp = min_temp
|
| 56 |
-
self.max_temp = max_temp
|
| 57 |
-
|
| 58 |
-
def build(self, input_shape):
|
| 59 |
-
# Learnable temperature with constraints
|
| 60 |
-
self.raw_temperature = self.add_weight(
|
| 61 |
-
name='raw_temperature',
|
| 62 |
-
shape=(),
|
| 63 |
-
initializer=tf.keras.initializers.Constant(
|
| 64 |
-
np.log(self.initial_temperature - self.min_temp)
|
| 65 |
-
),
|
| 66 |
-
trainable=True
|
| 67 |
-
)
|
| 68 |
-
|
| 69 |
-
# Learnable bias term for better discrimination
|
| 70 |
-
self.similarity_bias = self.add_weight(
|
| 71 |
-
name='similarity_bias',
|
| 72 |
-
shape=(),
|
| 73 |
-
initializer=tf.keras.initializers.Zeros(),
|
| 74 |
-
trainable=True
|
| 75 |
-
)
|
| 76 |
-
|
| 77 |
-
super().build(input_shape)
|
| 78 |
-
|
| 79 |
-
def call(self, user_embeddings, item_embeddings):
|
| 80 |
-
"""Compute adaptive temperature-scaled similarity with bias."""
|
| 81 |
-
# Constrain temperature to valid range
|
| 82 |
-
temperature = self.min_temp + tf.nn.softplus(self.raw_temperature)
|
| 83 |
-
temperature = tf.minimum(temperature, self.max_temp)
|
| 84 |
-
|
| 85 |
-
# Compute similarities
|
| 86 |
-
similarities = tf.reduce_sum(user_embeddings * item_embeddings, axis=1)
|
| 87 |
-
|
| 88 |
-
# Add learnable bias and apply temperature scaling
|
| 89 |
-
scaled_similarities = (similarities + self.similarity_bias) / temperature
|
| 90 |
-
|
| 91 |
-
return scaled_similarities, temperature
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
class EnhancedItemTower(tf.keras.Model):
|
| 95 |
-
"""Enhanced item tower with diversity regularization."""
|
| 96 |
-
|
| 97 |
-
def __init__(self,
|
| 98 |
-
item_vocab_size: int,
|
| 99 |
-
category_vocab_size: int,
|
| 100 |
-
brand_vocab_size: int,
|
| 101 |
-
embedding_dim: int = 128,
|
| 102 |
-
hidden_dims: list = [256, 128],
|
| 103 |
-
dropout_rate: float = 0.3,
|
| 104 |
-
use_bias: bool = True,
|
| 105 |
-
use_diversity_reg: bool = True):
|
| 106 |
-
super().__init__()
|
| 107 |
-
|
| 108 |
-
self.embedding_dim = embedding_dim
|
| 109 |
-
self.use_bias = use_bias
|
| 110 |
-
self.use_diversity_reg = use_diversity_reg
|
| 111 |
-
|
| 112 |
-
# Embedding layers with better initialization
|
| 113 |
-
self.item_embedding = tf.keras.layers.Embedding(
|
| 114 |
-
item_vocab_size, embedding_dim,
|
| 115 |
-
embeddings_initializer='he_normal', # Better initialization
|
| 116 |
-
embeddings_regularizer=tf.keras.regularizers.L2(1e-6),
|
| 117 |
-
name="item_embedding"
|
| 118 |
-
)
|
| 119 |
-
self.category_embedding = tf.keras.layers.Embedding(
|
| 120 |
-
category_vocab_size, embedding_dim,
|
| 121 |
-
embeddings_initializer='he_normal',
|
| 122 |
-
embeddings_regularizer=tf.keras.regularizers.L2(1e-6),
|
| 123 |
-
name="category_embedding"
|
| 124 |
-
)
|
| 125 |
-
self.brand_embedding = tf.keras.layers.Embedding(
|
| 126 |
-
brand_vocab_size, embedding_dim,
|
| 127 |
-
embeddings_initializer='he_normal',
|
| 128 |
-
embeddings_regularizer=tf.keras.regularizers.L2(1e-6),
|
| 129 |
-
name="brand_embedding"
|
| 130 |
-
)
|
| 131 |
-
|
| 132 |
-
# Price processing
|
| 133 |
-
self.price_normalization = tf.keras.layers.Normalization(name="price_norm")
|
| 134 |
-
self.price_projection = tf.keras.layers.Dense(
|
| 135 |
-
embedding_dim // 4, activation='relu', name="price_proj"
|
| 136 |
-
)
|
| 137 |
-
|
| 138 |
-
# Enhanced attention mechanism
|
| 139 |
-
self.feature_attention = tf.keras.layers.MultiHeadAttention(
|
| 140 |
-
num_heads=4,
|
| 141 |
-
key_dim=embedding_dim,
|
| 142 |
-
dropout=0.1,
|
| 143 |
-
name="feature_attention"
|
| 144 |
-
)
|
| 145 |
-
|
| 146 |
-
# Dense layers with residual connections
|
| 147 |
-
self.dense_layers = []
|
| 148 |
-
for i, dim in enumerate(hidden_dims):
|
| 149 |
-
self.dense_layers.extend([
|
| 150 |
-
tf.keras.layers.Dense(dim, activation=None, name=f"dense_{i}"),
|
| 151 |
-
tf.keras.layers.BatchNormalization(name=f"bn_{i}"),
|
| 152 |
-
tf.keras.layers.Activation('relu', name=f"relu_{i}"),
|
| 153 |
-
tf.keras.layers.Dropout(dropout_rate, name=f"dropout_{i}")
|
| 154 |
-
])
|
| 155 |
-
|
| 156 |
-
# Output layer with controlled normalization
|
| 157 |
-
self.output_layer = tf.keras.layers.Dense(
|
| 158 |
-
embedding_dim, activation=None, use_bias=use_bias, name="item_output"
|
| 159 |
-
)
|
| 160 |
-
|
| 161 |
-
# Diversity regularizer
|
| 162 |
-
if use_diversity_reg:
|
| 163 |
-
self.diversity_regularizer = EmbeddingDiversityRegularizer()
|
| 164 |
-
|
| 165 |
-
# Adaptive normalization instead of hard L2 normalization
|
| 166 |
-
self.adaptive_norm = tf.keras.layers.LayerNormalization(name="adaptive_norm")
|
| 167 |
-
|
| 168 |
-
# Item bias
|
| 169 |
-
if use_bias:
|
| 170 |
-
self.item_bias = tf.keras.layers.Embedding(
|
| 171 |
-
item_vocab_size, 1, name="item_bias"
|
| 172 |
-
)
|
| 173 |
-
|
| 174 |
-
def call(self, inputs, training=None):
|
| 175 |
-
"""Enhanced forward pass with diversity regularization."""
|
| 176 |
-
item_id = inputs["product_id"]
|
| 177 |
-
category_id = inputs["category_id"]
|
| 178 |
-
brand_id = inputs["brand_id"]
|
| 179 |
-
price = inputs["price"]
|
| 180 |
-
|
| 181 |
-
# Get embeddings
|
| 182 |
-
item_emb = self.item_embedding(item_id)
|
| 183 |
-
category_emb = self.category_embedding(category_id)
|
| 184 |
-
brand_emb = self.brand_embedding(brand_id)
|
| 185 |
-
|
| 186 |
-
# Process price
|
| 187 |
-
price_norm = self.price_normalization(tf.expand_dims(price, -1))
|
| 188 |
-
price_emb = self.price_projection(price_norm)
|
| 189 |
-
|
| 190 |
-
# Pad price embedding
|
| 191 |
-
price_emb_padded = tf.pad(
|
| 192 |
-
price_emb,
|
| 193 |
-
[[0, 0], [0, self.embedding_dim - tf.shape(price_emb)[-1]]]
|
| 194 |
-
)
|
| 195 |
-
|
| 196 |
-
# Stack features for attention
|
| 197 |
-
features = tf.stack([item_emb, category_emb, brand_emb, price_emb_padded], axis=1)
|
| 198 |
-
|
| 199 |
-
# Apply attention
|
| 200 |
-
attended_features = self.feature_attention(
|
| 201 |
-
query=features,
|
| 202 |
-
value=features,
|
| 203 |
-
key=features,
|
| 204 |
-
training=training
|
| 205 |
-
)
|
| 206 |
-
|
| 207 |
-
# Aggregate with residual connection
|
| 208 |
-
combined = tf.reduce_mean(attended_features + features, axis=1)
|
| 209 |
-
|
| 210 |
-
# Pass through dense layers with residual connections
|
| 211 |
-
x = combined
|
| 212 |
-
residual = x
|
| 213 |
-
for i, layer in enumerate(self.dense_layers):
|
| 214 |
-
x = layer(x, training=training)
|
| 215 |
-
# Add residual connection every 4 layers (complete block)
|
| 216 |
-
if (i + 1) % 4 == 0 and x.shape[-1] == residual.shape[-1]:
|
| 217 |
-
x = x + residual
|
| 218 |
-
residual = x
|
| 219 |
-
|
| 220 |
-
# Final output
|
| 221 |
-
output = self.output_layer(x)
|
| 222 |
-
|
| 223 |
-
# Apply diversity regularization if enabled
|
| 224 |
-
if self.use_diversity_reg and training:
|
| 225 |
-
output = self.diversity_regularizer(output)
|
| 226 |
-
|
| 227 |
-
# Adaptive normalization instead of hard L2
|
| 228 |
-
normalized_output = self.adaptive_norm(output)
|
| 229 |
-
|
| 230 |
-
# Add bias if enabled
|
| 231 |
-
if self.use_bias:
|
| 232 |
-
bias = tf.squeeze(self.item_bias(item_id), axis=-1)
|
| 233 |
-
return normalized_output, bias
|
| 234 |
-
else:
|
| 235 |
-
return normalized_output
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
class EnhancedUserTower(tf.keras.Model):
|
| 239 |
-
"""Enhanced user tower with diversity regularization."""
|
| 240 |
-
|
| 241 |
-
def __init__(self,
|
| 242 |
-
max_history_length: int = 50,
|
| 243 |
-
embedding_dim: int = 128,
|
| 244 |
-
hidden_dims: list = [256, 128],
|
| 245 |
-
dropout_rate: float = 0.3,
|
| 246 |
-
use_bias: bool = True,
|
| 247 |
-
use_diversity_reg: bool = True):
|
| 248 |
-
super().__init__()
|
| 249 |
-
|
| 250 |
-
self.embedding_dim = embedding_dim
|
| 251 |
-
self.max_history_length = max_history_length
|
| 252 |
-
self.use_bias = use_bias
|
| 253 |
-
self.use_diversity_reg = use_diversity_reg
|
| 254 |
-
|
| 255 |
-
# Demographic embeddings with regularization
|
| 256 |
-
self.age_embedding = tf.keras.layers.Embedding(
|
| 257 |
-
6, embedding_dim // 16,
|
| 258 |
-
embeddings_initializer='he_normal',
|
| 259 |
-
embeddings_regularizer=tf.keras.regularizers.L2(1e-6),
|
| 260 |
-
name="age_embedding"
|
| 261 |
-
)
|
| 262 |
-
self.income_embedding = tf.keras.layers.Embedding(
|
| 263 |
-
5, embedding_dim // 16,
|
| 264 |
-
embeddings_initializer='he_normal',
|
| 265 |
-
embeddings_regularizer=tf.keras.regularizers.L2(1e-6),
|
| 266 |
-
name="income_embedding"
|
| 267 |
-
)
|
| 268 |
-
self.gender_embedding = tf.keras.layers.Embedding(
|
| 269 |
-
2, embedding_dim // 16,
|
| 270 |
-
embeddings_initializer='he_normal',
|
| 271 |
-
embeddings_regularizer=tf.keras.regularizers.L2(1e-6),
|
| 272 |
-
name="gender_embedding"
|
| 273 |
-
)
|
| 274 |
-
|
| 275 |
-
# Enhanced history processing
|
| 276 |
-
self.history_transformer = tf.keras.layers.MultiHeadAttention(
|
| 277 |
-
num_heads=8,
|
| 278 |
-
key_dim=embedding_dim,
|
| 279 |
-
dropout=0.1,
|
| 280 |
-
name="history_transformer"
|
| 281 |
-
)
|
| 282 |
-
|
| 283 |
-
# History aggregation with attention pooling
|
| 284 |
-
self.history_attention_pooling = tf.keras.layers.Dense(
|
| 285 |
-
1, activation=None, name="history_attention"
|
| 286 |
-
)
|
| 287 |
-
|
| 288 |
-
# Dense layers with residual connections
|
| 289 |
-
self.dense_layers = []
|
| 290 |
-
for i, dim in enumerate(hidden_dims):
|
| 291 |
-
self.dense_layers.extend([
|
| 292 |
-
tf.keras.layers.Dense(dim, activation=None, name=f"user_dense_{i}"),
|
| 293 |
-
tf.keras.layers.BatchNormalization(name=f"user_bn_{i}"),
|
| 294 |
-
tf.keras.layers.Activation('relu', name=f"user_relu_{i}"),
|
| 295 |
-
tf.keras.layers.Dropout(dropout_rate, name=f"user_dropout_{i}")
|
| 296 |
-
])
|
| 297 |
-
|
| 298 |
-
# Output layer
|
| 299 |
-
self.output_layer = tf.keras.layers.Dense(
|
| 300 |
-
embedding_dim, activation=None, use_bias=use_bias, name="user_output"
|
| 301 |
-
)
|
| 302 |
-
|
| 303 |
-
# Diversity regularizer
|
| 304 |
-
if use_diversity_reg:
|
| 305 |
-
self.diversity_regularizer = EmbeddingDiversityRegularizer()
|
| 306 |
-
|
| 307 |
-
# Adaptive normalization
|
| 308 |
-
self.adaptive_norm = tf.keras.layers.LayerNormalization(name="user_adaptive_norm")
|
| 309 |
-
|
| 310 |
-
# Global user bias
|
| 311 |
-
if use_bias:
|
| 312 |
-
self.global_user_bias = tf.Variable(
|
| 313 |
-
initial_value=0.0, trainable=True, name="global_user_bias"
|
| 314 |
-
)
|
| 315 |
-
|
| 316 |
-
def call(self, inputs, training=None):
|
| 317 |
-
"""Enhanced forward pass with diversity regularization."""
|
| 318 |
-
age = inputs["age"]
|
| 319 |
-
gender = inputs["gender"]
|
| 320 |
-
income = inputs["income"]
|
| 321 |
-
item_history = inputs["item_history_embeddings"]
|
| 322 |
-
|
| 323 |
-
# Process demographics
|
| 324 |
-
age_emb = self.age_embedding(age)
|
| 325 |
-
income_emb = self.income_embedding(income)
|
| 326 |
-
gender_emb = self.gender_embedding(gender)
|
| 327 |
-
|
| 328 |
-
# Combine demographics
|
| 329 |
-
demo_combined = tf.concat([age_emb, income_emb, gender_emb], axis=-1)
|
| 330 |
-
|
| 331 |
-
# Enhanced history processing
|
| 332 |
-
batch_size = tf.shape(item_history)[0]
|
| 333 |
-
seq_len = tf.shape(item_history)[1]
|
| 334 |
-
|
| 335 |
-
# Simplified positional encoding - ensure shape compatibility
|
| 336 |
-
positions = tf.range(seq_len, dtype=tf.float32)
|
| 337 |
-
# Create simpler positional encoding
|
| 338 |
-
pos_encoding_scale = tf.cast(tf.range(self.embedding_dim, dtype=tf.float32), tf.float32) / self.embedding_dim
|
| 339 |
-
position_encoding = tf.sin(positions[:, tf.newaxis] * pos_encoding_scale[tf.newaxis, :])
|
| 340 |
-
|
| 341 |
-
# Ensure correct shape: [seq_len, embedding_dim] -> [batch_size, seq_len, embedding_dim]
|
| 342 |
-
position_encoding = tf.expand_dims(position_encoding, 0)
|
| 343 |
-
position_encoding = tf.tile(position_encoding, [batch_size, 1, 1])
|
| 344 |
-
|
| 345 |
-
# Add positional encoding with shape check
|
| 346 |
-
history_with_pos = item_history + position_encoding
|
| 347 |
-
|
| 348 |
-
# Create attention mask - fix shape for MultiHeadAttention
|
| 349 |
-
# MultiHeadAttention expects mask shape: [batch_size, seq_len] or [batch_size, seq_len, seq_len]
|
| 350 |
-
history_mask = tf.reduce_sum(tf.abs(item_history), axis=-1) > 0 # [batch_size, seq_len]
|
| 351 |
-
|
| 352 |
-
# Apply transformer attention
|
| 353 |
-
attended_history = self.history_transformer(
|
| 354 |
-
query=history_with_pos,
|
| 355 |
-
value=history_with_pos,
|
| 356 |
-
key=history_with_pos,
|
| 357 |
-
attention_mask=history_mask,
|
| 358 |
-
training=training
|
| 359 |
-
)
|
| 360 |
-
|
| 361 |
-
# Attention-based pooling instead of simple mean
|
| 362 |
-
attention_weights = tf.nn.softmax(
|
| 363 |
-
self.history_attention_pooling(attended_history), axis=1
|
| 364 |
-
)
|
| 365 |
-
history_aggregated = tf.reduce_sum(
|
| 366 |
-
attended_history * attention_weights, axis=1
|
| 367 |
-
)
|
| 368 |
-
|
| 369 |
-
# Combine features
|
| 370 |
-
combined = tf.concat([demo_combined, history_aggregated], axis=-1)
|
| 371 |
-
|
| 372 |
-
# Pass through dense layers with residual connections
|
| 373 |
-
x = combined
|
| 374 |
-
residual = x
|
| 375 |
-
for i, layer in enumerate(self.dense_layers):
|
| 376 |
-
x = layer(x, training=training)
|
| 377 |
-
# Add residual connection every 4 layers
|
| 378 |
-
if (i + 1) % 4 == 0 and x.shape[-1] == residual.shape[-1]:
|
| 379 |
-
x = x + residual
|
| 380 |
-
residual = x
|
| 381 |
-
|
| 382 |
-
# Final output
|
| 383 |
-
output = self.output_layer(x)
|
| 384 |
-
|
| 385 |
-
# Apply diversity regularization if enabled
|
| 386 |
-
if self.use_diversity_reg and training:
|
| 387 |
-
output = self.diversity_regularizer(output)
|
| 388 |
-
|
| 389 |
-
# Adaptive normalization
|
| 390 |
-
normalized_output = self.adaptive_norm(output)
|
| 391 |
-
|
| 392 |
-
# Add bias if enabled
|
| 393 |
-
if self.use_bias:
|
| 394 |
-
return normalized_output, self.global_user_bias
|
| 395 |
-
else:
|
| 396 |
-
return normalized_output
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
class EnhancedTwoTowerModel(tfrs.Model):
|
| 400 |
-
"""Enhanced two-tower model with all improvements."""
|
| 401 |
-
|
| 402 |
-
def __init__(self,
|
| 403 |
-
item_tower: EnhancedItemTower,
|
| 404 |
-
user_tower: EnhancedUserTower,
|
| 405 |
-
rating_weight: float = 1.0,
|
| 406 |
-
retrieval_weight: float = 1.0,
|
| 407 |
-
contrastive_weight: float = 0.3,
|
| 408 |
-
diversity_weight: float = 0.1):
|
| 409 |
-
super().__init__()
|
| 410 |
-
|
| 411 |
-
self.item_tower = item_tower
|
| 412 |
-
self.user_tower = user_tower
|
| 413 |
-
self.rating_weight = rating_weight
|
| 414 |
-
self.retrieval_weight = retrieval_weight
|
| 415 |
-
self.contrastive_weight = contrastive_weight
|
| 416 |
-
self.diversity_weight = diversity_weight
|
| 417 |
-
|
| 418 |
-
# Adaptive temperature scaling
|
| 419 |
-
self.temperature_similarity = AdaptiveTemperatureScaling()
|
| 420 |
-
|
| 421 |
-
# Enhanced rating model
|
| 422 |
-
self.rating_model = tf.keras.Sequential([
|
| 423 |
-
tf.keras.layers.Dense(512, activation="relu"),
|
| 424 |
-
tf.keras.layers.BatchNormalization(),
|
| 425 |
-
tf.keras.layers.Dropout(0.3),
|
| 426 |
-
tf.keras.layers.Dense(256, activation="relu"),
|
| 427 |
-
tf.keras.layers.BatchNormalization(),
|
| 428 |
-
tf.keras.layers.Dropout(0.2),
|
| 429 |
-
tf.keras.layers.Dense(64, activation="relu"),
|
| 430 |
-
tf.keras.layers.Dense(1, activation="sigmoid")
|
| 431 |
-
])
|
| 432 |
-
|
| 433 |
-
# Focal loss for imbalanced data
|
| 434 |
-
self.focal_loss = self._focal_loss
|
| 435 |
-
|
| 436 |
-
def _focal_loss(self, y_true, y_pred, alpha=0.25, gamma=2.0):
|
| 437 |
-
"""Focal loss implementation."""
|
| 438 |
-
epsilon = tf.keras.backend.epsilon()
|
| 439 |
-
y_pred = tf.clip_by_value(y_pred, epsilon, 1.0 - epsilon)
|
| 440 |
-
|
| 441 |
-
alpha_t = y_true * alpha + (1 - y_true) * (1 - alpha)
|
| 442 |
-
p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
|
| 443 |
-
focal_weight = alpha_t * tf.pow((1 - p_t), gamma)
|
| 444 |
-
|
| 445 |
-
bce = -(y_true * tf.math.log(y_pred) + (1 - y_true) * tf.math.log(1 - y_pred))
|
| 446 |
-
focal_loss = focal_weight * bce
|
| 447 |
-
|
| 448 |
-
return tf.reduce_mean(focal_loss)
|
| 449 |
-
|
| 450 |
-
def call(self, features):
|
| 451 |
-
# Get embeddings
|
| 452 |
-
user_output = self.user_tower(features)
|
| 453 |
-
item_output = self.item_tower(features)
|
| 454 |
-
|
| 455 |
-
# Handle bias terms
|
| 456 |
-
if isinstance(user_output, tuple):
|
| 457 |
-
user_embeddings, user_bias = user_output
|
| 458 |
-
else:
|
| 459 |
-
user_embeddings = user_output
|
| 460 |
-
user_bias = 0.0
|
| 461 |
-
|
| 462 |
-
if isinstance(item_output, tuple):
|
| 463 |
-
item_embeddings, item_bias = item_output
|
| 464 |
-
else:
|
| 465 |
-
item_embeddings = item_output
|
| 466 |
-
item_bias = 0.0
|
| 467 |
-
|
| 468 |
-
return {
|
| 469 |
-
"user_embedding": user_embeddings,
|
| 470 |
-
"item_embedding": item_embeddings,
|
| 471 |
-
"user_bias": user_bias,
|
| 472 |
-
"item_bias": item_bias
|
| 473 |
-
}
|
| 474 |
-
|
| 475 |
-
def compute_loss(self, features, training=False):
|
| 476 |
-
# Get embeddings and biases
|
| 477 |
-
outputs = self(features)
|
| 478 |
-
user_embeddings = outputs["user_embedding"]
|
| 479 |
-
item_embeddings = outputs["item_embedding"]
|
| 480 |
-
user_bias = outputs["user_bias"]
|
| 481 |
-
item_bias = outputs["item_bias"]
|
| 482 |
-
|
| 483 |
-
# Rating prediction
|
| 484 |
-
concatenated = tf.concat([user_embeddings, item_embeddings], axis=-1)
|
| 485 |
-
rating_predictions = self.rating_model(concatenated, training=training)
|
| 486 |
-
|
| 487 |
-
# Add bias terms
|
| 488 |
-
rating_predictions_with_bias = rating_predictions + user_bias + item_bias
|
| 489 |
-
rating_predictions_with_bias = tf.nn.sigmoid(rating_predictions_with_bias)
|
| 490 |
-
|
| 491 |
-
# Losses
|
| 492 |
-
rating_loss = self.focal_loss(features["rating"], rating_predictions_with_bias)
|
| 493 |
-
|
| 494 |
-
# Adaptive temperature-scaled retrieval loss
|
| 495 |
-
scaled_similarities, temperature = self.temperature_similarity(
|
| 496 |
-
user_embeddings, item_embeddings
|
| 497 |
-
)
|
| 498 |
-
retrieval_loss = tf.keras.losses.binary_crossentropy(
|
| 499 |
-
features["rating"],
|
| 500 |
-
tf.nn.sigmoid(scaled_similarities)
|
| 501 |
-
)
|
| 502 |
-
retrieval_loss = tf.reduce_mean(retrieval_loss)
|
| 503 |
-
|
| 504 |
-
# Enhanced contrastive loss with hard negatives
|
| 505 |
-
batch_size = tf.shape(user_embeddings)[0]
|
| 506 |
-
positive_similarities = tf.reduce_sum(user_embeddings * item_embeddings, axis=1)
|
| 507 |
-
|
| 508 |
-
# Random negative sampling
|
| 509 |
-
shuffled_indices = tf.random.shuffle(tf.range(batch_size))
|
| 510 |
-
negative_item_embeddings = tf.gather(item_embeddings, shuffled_indices)
|
| 511 |
-
negative_similarities = tf.reduce_sum(user_embeddings * negative_item_embeddings, axis=1)
|
| 512 |
-
|
| 513 |
-
# Triplet loss with adaptive margin
|
| 514 |
-
margin = 0.5 / temperature # Adaptive margin based on temperature
|
| 515 |
-
contrastive_loss = tf.reduce_mean(
|
| 516 |
-
tf.maximum(0.0, margin + negative_similarities - positive_similarities)
|
| 517 |
-
)
|
| 518 |
-
|
| 519 |
-
# Combine losses
|
| 520 |
-
total_loss = (
|
| 521 |
-
self.rating_weight * rating_loss +
|
| 522 |
-
self.retrieval_weight * retrieval_loss +
|
| 523 |
-
self.contrastive_weight * contrastive_loss
|
| 524 |
-
)
|
| 525 |
-
|
| 526 |
-
# Add regularization losses from diversity regularizers
|
| 527 |
-
if training:
|
| 528 |
-
regularization_losses = tf.add_n(self.losses) if self.losses else 0.0
|
| 529 |
-
total_loss += self.diversity_weight * regularization_losses
|
| 530 |
-
|
| 531 |
-
return {
|
| 532 |
-
'total_loss': total_loss,
|
| 533 |
-
'rating_loss': rating_loss,
|
| 534 |
-
'retrieval_loss': retrieval_loss,
|
| 535 |
-
'contrastive_loss': contrastive_loss,
|
| 536 |
-
'temperature': temperature,
|
| 537 |
-
'diversity_loss': regularization_losses if training else 0.0
|
| 538 |
-
}
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
def create_enhanced_model(data_processor,
|
| 542 |
-
embedding_dim=128,
|
| 543 |
-
use_bias=True,
|
| 544 |
-
use_diversity_reg=True):
|
| 545 |
-
"""Factory function to create enhanced two-tower model."""
|
| 546 |
-
|
| 547 |
-
# Create enhanced towers
|
| 548 |
-
item_tower = EnhancedItemTower(
|
| 549 |
-
item_vocab_size=len(data_processor.item_vocab),
|
| 550 |
-
category_vocab_size=len(data_processor.category_vocab),
|
| 551 |
-
brand_vocab_size=len(data_processor.brand_vocab),
|
| 552 |
-
embedding_dim=embedding_dim,
|
| 553 |
-
use_bias=use_bias,
|
| 554 |
-
use_diversity_reg=use_diversity_reg
|
| 555 |
-
)
|
| 556 |
-
|
| 557 |
-
user_tower = EnhancedUserTower(
|
| 558 |
-
max_history_length=50,
|
| 559 |
-
embedding_dim=embedding_dim,
|
| 560 |
-
use_bias=use_bias,
|
| 561 |
-
use_diversity_reg=use_diversity_reg
|
| 562 |
-
)
|
| 563 |
-
|
| 564 |
-
# Create enhanced model
|
| 565 |
-
model = EnhancedTwoTowerModel(
|
| 566 |
-
item_tower=item_tower,
|
| 567 |
-
user_tower=user_tower,
|
| 568 |
-
rating_weight=1.0,
|
| 569 |
-
retrieval_weight=0.5,
|
| 570 |
-
contrastive_weight=0.3,
|
| 571 |
-
diversity_weight=0.1
|
| 572 |
-
)
|
| 573 |
-
|
| 574 |
-
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/improved_two_tower.py
DELETED
|
@@ -1,545 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""
|
| 3 |
-
Improved two-tower model with better embedding discrimination and training stability.
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
import tensorflow as tf
|
| 7 |
-
import tensorflow_recommenders as tfrs
|
| 8 |
-
import numpy as np
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
class ImprovedItemTower(tf.keras.Model):
|
| 12 |
-
"""Enhanced item tower with better discrimination and representation capacity."""
|
| 13 |
-
|
| 14 |
-
def __init__(self,
|
| 15 |
-
item_vocab_size: int,
|
| 16 |
-
category_vocab_size: int,
|
| 17 |
-
brand_vocab_size: int,
|
| 18 |
-
embedding_dim: int = 128, # Increased from 64
|
| 19 |
-
hidden_dims: list = [256, 128], # Deeper network
|
| 20 |
-
dropout_rate: float = 0.3,
|
| 21 |
-
use_bias: bool = True):
|
| 22 |
-
super().__init__()
|
| 23 |
-
|
| 24 |
-
self.embedding_dim = embedding_dim
|
| 25 |
-
self.use_bias = use_bias
|
| 26 |
-
|
| 27 |
-
# Larger embedding layers with proper initialization
|
| 28 |
-
self.item_embedding = tf.keras.layers.Embedding(
|
| 29 |
-
item_vocab_size, embedding_dim,
|
| 30 |
-
embeddings_initializer='glorot_uniform',
|
| 31 |
-
name="item_embedding"
|
| 32 |
-
)
|
| 33 |
-
self.category_embedding = tf.keras.layers.Embedding(
|
| 34 |
-
category_vocab_size, embedding_dim,
|
| 35 |
-
embeddings_initializer='glorot_uniform',
|
| 36 |
-
name="category_embedding"
|
| 37 |
-
)
|
| 38 |
-
self.brand_embedding = tf.keras.layers.Embedding(
|
| 39 |
-
brand_vocab_size, embedding_dim,
|
| 40 |
-
embeddings_initializer='glorot_uniform',
|
| 41 |
-
name="brand_embedding"
|
| 42 |
-
)
|
| 43 |
-
|
| 44 |
-
# Price normalization and projection
|
| 45 |
-
self.price_normalization = tf.keras.layers.Normalization(name="price_norm")
|
| 46 |
-
self.price_projection = tf.keras.layers.Dense(
|
| 47 |
-
embedding_dim // 4, activation='relu', name="price_proj"
|
| 48 |
-
)
|
| 49 |
-
|
| 50 |
-
# Attention mechanism for feature fusion
|
| 51 |
-
self.feature_attention = tf.keras.layers.MultiHeadAttention(
|
| 52 |
-
num_heads=4,
|
| 53 |
-
key_dim=embedding_dim,
|
| 54 |
-
name="feature_attention"
|
| 55 |
-
)
|
| 56 |
-
|
| 57 |
-
# Enhanced dense layers with batch normalization
|
| 58 |
-
self.dense_layers = []
|
| 59 |
-
for i, dim in enumerate(hidden_dims):
|
| 60 |
-
self.dense_layers.extend([
|
| 61 |
-
tf.keras.layers.Dense(dim, activation=None, name=f"dense_{i}"),
|
| 62 |
-
tf.keras.layers.BatchNormalization(name=f"bn_{i}"),
|
| 63 |
-
tf.keras.layers.Activation('relu', name=f"relu_{i}"),
|
| 64 |
-
tf.keras.layers.Dropout(dropout_rate, name=f"dropout_{i}")
|
| 65 |
-
])
|
| 66 |
-
|
| 67 |
-
# Output projection with bias term
|
| 68 |
-
self.output_layer = tf.keras.layers.Dense(
|
| 69 |
-
embedding_dim, activation=None, use_bias=use_bias, name="item_output"
|
| 70 |
-
)
|
| 71 |
-
|
| 72 |
-
# Learnable bias term for each item
|
| 73 |
-
if use_bias:
|
| 74 |
-
self.item_bias = tf.keras.layers.Embedding(
|
| 75 |
-
item_vocab_size, 1, name="item_bias"
|
| 76 |
-
)
|
| 77 |
-
|
| 78 |
-
def call(self, inputs, training=None):
|
| 79 |
-
"""Enhanced forward pass with attention and better feature fusion."""
|
| 80 |
-
item_id = inputs["product_id"]
|
| 81 |
-
category_id = inputs["category_id"]
|
| 82 |
-
brand_id = inputs["brand_id"]
|
| 83 |
-
price = inputs["price"]
|
| 84 |
-
|
| 85 |
-
# Get embeddings
|
| 86 |
-
item_emb = self.item_embedding(item_id) # [batch, emb_dim]
|
| 87 |
-
category_emb = self.category_embedding(category_id)
|
| 88 |
-
brand_emb = self.brand_embedding(brand_id)
|
| 89 |
-
|
| 90 |
-
# Process price
|
| 91 |
-
price_norm = self.price_normalization(tf.expand_dims(price, -1))
|
| 92 |
-
price_emb = self.price_projection(price_norm)
|
| 93 |
-
|
| 94 |
-
# Pad price embedding to match others
|
| 95 |
-
price_emb_padded = tf.pad(
|
| 96 |
-
price_emb,
|
| 97 |
-
[[0, 0], [0, self.embedding_dim - tf.shape(price_emb)[-1]]]
|
| 98 |
-
)
|
| 99 |
-
|
| 100 |
-
# Stack features for attention [batch, 4, emb_dim]
|
| 101 |
-
features = tf.stack([item_emb, category_emb, brand_emb, price_emb_padded], axis=1)
|
| 102 |
-
|
| 103 |
-
# Apply self-attention for feature fusion
|
| 104 |
-
attended_features = self.feature_attention(
|
| 105 |
-
query=features,
|
| 106 |
-
value=features,
|
| 107 |
-
key=features,
|
| 108 |
-
training=training
|
| 109 |
-
)
|
| 110 |
-
|
| 111 |
-
# Aggregate features (mean pooling)
|
| 112 |
-
combined = tf.reduce_mean(attended_features, axis=1)
|
| 113 |
-
|
| 114 |
-
# Pass through enhanced dense layers
|
| 115 |
-
x = combined
|
| 116 |
-
for layer in self.dense_layers:
|
| 117 |
-
x = layer(x, training=training)
|
| 118 |
-
|
| 119 |
-
# Final output
|
| 120 |
-
output = self.output_layer(x)
|
| 121 |
-
|
| 122 |
-
# L2 normalize for similarity computations
|
| 123 |
-
normalized_output = tf.nn.l2_normalize(output, axis=-1)
|
| 124 |
-
|
| 125 |
-
# Add bias if enabled
|
| 126 |
-
if self.use_bias:
|
| 127 |
-
bias = tf.squeeze(self.item_bias(item_id), axis=-1)
|
| 128 |
-
return normalized_output, bias
|
| 129 |
-
else:
|
| 130 |
-
return normalized_output
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
class ImprovedUserTower(tf.keras.Model):
|
| 134 |
-
"""Enhanced user tower with better history modeling and representation."""
|
| 135 |
-
|
| 136 |
-
def __init__(self,
|
| 137 |
-
max_history_length: int = 50,
|
| 138 |
-
embedding_dim: int = 128, # Increased from 64
|
| 139 |
-
hidden_dims: list = [256, 128], # Deeper network
|
| 140 |
-
dropout_rate: float = 0.3,
|
| 141 |
-
use_bias: bool = True):
|
| 142 |
-
super().__init__()
|
| 143 |
-
|
| 144 |
-
self.embedding_dim = embedding_dim
|
| 145 |
-
self.max_history_length = max_history_length
|
| 146 |
-
self.use_bias = use_bias
|
| 147 |
-
|
| 148 |
-
# Demographic embeddings (categorical features)
|
| 149 |
-
# Age: 6 categories (Teen, Young Adult, Adult, Middle Age, Mature, Senior)
|
| 150 |
-
self.age_embedding = tf.keras.layers.Embedding(
|
| 151 |
-
6, embedding_dim // 16,
|
| 152 |
-
embeddings_initializer='glorot_uniform',
|
| 153 |
-
name="age_embedding"
|
| 154 |
-
)
|
| 155 |
-
|
| 156 |
-
# Income: 5 categories (percentile-based)
|
| 157 |
-
self.income_embedding = tf.keras.layers.Embedding(
|
| 158 |
-
5, embedding_dim // 16,
|
| 159 |
-
embeddings_initializer='glorot_uniform',
|
| 160 |
-
name="income_embedding"
|
| 161 |
-
)
|
| 162 |
-
|
| 163 |
-
# Gender: 2 categories (0=female, 1=male)
|
| 164 |
-
self.gender_embedding = tf.keras.layers.Embedding(
|
| 165 |
-
2, embedding_dim // 16,
|
| 166 |
-
embeddings_initializer='glorot_uniform',
|
| 167 |
-
name="gender_embedding"
|
| 168 |
-
)
|
| 169 |
-
|
| 170 |
-
# Improved history processing with positional encoding
|
| 171 |
-
self.history_transformer = tf.keras.layers.MultiHeadAttention(
|
| 172 |
-
num_heads=8, # More attention heads
|
| 173 |
-
key_dim=embedding_dim,
|
| 174 |
-
name="history_transformer"
|
| 175 |
-
)
|
| 176 |
-
|
| 177 |
-
# History aggregation with learned weights
|
| 178 |
-
self.history_aggregation = tf.keras.layers.Dense(
|
| 179 |
-
embedding_dim, activation='tanh', name="history_agg"
|
| 180 |
-
)
|
| 181 |
-
|
| 182 |
-
# Enhanced dense layers with batch normalization
|
| 183 |
-
self.dense_layers = []
|
| 184 |
-
for i, dim in enumerate(hidden_dims):
|
| 185 |
-
self.dense_layers.extend([
|
| 186 |
-
tf.keras.layers.Dense(dim, activation=None, name=f"user_dense_{i}"),
|
| 187 |
-
tf.keras.layers.BatchNormalization(name=f"user_bn_{i}"),
|
| 188 |
-
tf.keras.layers.Activation('relu', name=f"user_relu_{i}"),
|
| 189 |
-
tf.keras.layers.Dropout(dropout_rate, name=f"user_dropout_{i}")
|
| 190 |
-
])
|
| 191 |
-
|
| 192 |
-
# Output layer
|
| 193 |
-
self.output_layer = tf.keras.layers.Dense(
|
| 194 |
-
embedding_dim, activation=None, use_bias=use_bias, name="user_output"
|
| 195 |
-
)
|
| 196 |
-
|
| 197 |
-
# Learnable user bias
|
| 198 |
-
if use_bias:
|
| 199 |
-
# We'll need to handle user bias differently since we don't have user vocab in inference
|
| 200 |
-
self.global_user_bias = tf.Variable(
|
| 201 |
-
initial_value=0.0, trainable=True, name="global_user_bias"
|
| 202 |
-
)
|
| 203 |
-
|
| 204 |
-
def call(self, inputs, training=None):
|
| 205 |
-
"""Enhanced forward pass with better history modeling."""
|
| 206 |
-
age = inputs["age"] # Now categorical (0-5)
|
| 207 |
-
gender = inputs["gender"] # Categorical (0-1)
|
| 208 |
-
income = inputs["income"] # Now categorical (0-4)
|
| 209 |
-
item_history = inputs["item_history_embeddings"] # [batch_size, seq_len, emb_dim]
|
| 210 |
-
|
| 211 |
-
# Process demographics through embeddings
|
| 212 |
-
age_emb = self.age_embedding(age) # [batch_size, embedding_dim//16]
|
| 213 |
-
income_emb = self.income_embedding(income) # [batch_size, embedding_dim//16]
|
| 214 |
-
gender_emb = self.gender_embedding(gender) # [batch_size, embedding_dim//16]
|
| 215 |
-
|
| 216 |
-
# Combine all demographic embeddings
|
| 217 |
-
demo_combined = tf.concat([age_emb, income_emb, gender_emb], axis=-1)
|
| 218 |
-
# Total demographics: 3 * (embedding_dim//16) = ~18.75% of embedding_dim
|
| 219 |
-
|
| 220 |
-
# Enhanced history processing with positional encoding
|
| 221 |
-
batch_size = tf.shape(item_history)[0]
|
| 222 |
-
seq_len = tf.shape(item_history)[1]
|
| 223 |
-
|
| 224 |
-
# Create positional encoding
|
| 225 |
-
positions = tf.range(seq_len, dtype=tf.float32)
|
| 226 |
-
position_encoding = tf.sin(
|
| 227 |
-
positions[:, tf.newaxis] /
|
| 228 |
-
tf.pow(10000.0, 2 * tf.range(self.embedding_dim, dtype=tf.float32) / self.embedding_dim)
|
| 229 |
-
)
|
| 230 |
-
position_encoding = tf.expand_dims(position_encoding, 0)
|
| 231 |
-
position_encoding = tf.tile(position_encoding, [batch_size, 1, 1])
|
| 232 |
-
|
| 233 |
-
# Add positional encoding to history
|
| 234 |
-
history_with_pos = item_history + position_encoding
|
| 235 |
-
|
| 236 |
-
# Create attention mask for padding
|
| 237 |
-
history_mask = tf.reduce_sum(tf.abs(item_history), axis=-1) > 0
|
| 238 |
-
|
| 239 |
-
# Apply transformer attention to history
|
| 240 |
-
attended_history = self.history_transformer(
|
| 241 |
-
query=history_with_pos,
|
| 242 |
-
value=history_with_pos,
|
| 243 |
-
key=history_with_pos,
|
| 244 |
-
attention_mask=history_mask,
|
| 245 |
-
training=training
|
| 246 |
-
)
|
| 247 |
-
|
| 248 |
-
# Aggregate history with learned weights
|
| 249 |
-
history_weights = tf.nn.softmax(
|
| 250 |
-
tf.keras.layers.Dense(1)(attended_history), axis=1
|
| 251 |
-
)
|
| 252 |
-
history_aggregated = tf.reduce_sum(
|
| 253 |
-
attended_history * history_weights, axis=1
|
| 254 |
-
)
|
| 255 |
-
|
| 256 |
-
# Apply additional processing
|
| 257 |
-
history_processed = self.history_aggregation(history_aggregated)
|
| 258 |
-
|
| 259 |
-
# Combine all features
|
| 260 |
-
combined = tf.concat([
|
| 261 |
-
demo_combined,
|
| 262 |
-
history_processed
|
| 263 |
-
], axis=-1)
|
| 264 |
-
|
| 265 |
-
# Pass through enhanced dense layers
|
| 266 |
-
x = combined
|
| 267 |
-
for layer in self.dense_layers:
|
| 268 |
-
x = layer(x, training=training)
|
| 269 |
-
|
| 270 |
-
# Final output
|
| 271 |
-
output = self.output_layer(x)
|
| 272 |
-
|
| 273 |
-
# L2 normalize for similarity computations
|
| 274 |
-
normalized_output = tf.nn.l2_normalize(output, axis=-1)
|
| 275 |
-
|
| 276 |
-
# Add global bias if enabled
|
| 277 |
-
if self.use_bias:
|
| 278 |
-
return normalized_output, self.global_user_bias
|
| 279 |
-
else:
|
| 280 |
-
return normalized_output
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
class TemperatureScaledSimilarity(tf.keras.layers.Layer):
|
| 284 |
-
"""Learnable temperature scaling for similarity computations."""
|
| 285 |
-
|
| 286 |
-
def __init__(self, initial_temperature=1.0, **kwargs):
|
| 287 |
-
super().__init__(**kwargs)
|
| 288 |
-
self.initial_temperature = initial_temperature
|
| 289 |
-
|
| 290 |
-
def build(self, input_shape):
|
| 291 |
-
self.temperature = self.add_weight(
|
| 292 |
-
name='temperature',
|
| 293 |
-
shape=(),
|
| 294 |
-
initializer=tf.keras.initializers.Constant(self.initial_temperature),
|
| 295 |
-
trainable=True
|
| 296 |
-
)
|
| 297 |
-
super().build(input_shape)
|
| 298 |
-
|
| 299 |
-
def call(self, user_embeddings, item_embeddings):
|
| 300 |
-
"""Compute temperature-scaled similarity."""
|
| 301 |
-
# Dot product similarity
|
| 302 |
-
similarities = tf.reduce_sum(user_embeddings * item_embeddings, axis=1)
|
| 303 |
-
|
| 304 |
-
# Scale by learnable temperature
|
| 305 |
-
scaled_similarities = similarities / tf.maximum(self.temperature, 0.01) # Prevent division by 0
|
| 306 |
-
|
| 307 |
-
return scaled_similarities
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
class ImprovedTwoTowerModel(tfrs.Model):
|
| 311 |
-
"""Enhanced two-tower model with better discrimination and training stability."""
|
| 312 |
-
|
| 313 |
-
def __init__(self,
|
| 314 |
-
item_tower: ImprovedItemTower,
|
| 315 |
-
user_tower: ImprovedUserTower,
|
| 316 |
-
rating_weight: float = 1.0,
|
| 317 |
-
retrieval_weight: float = 1.0,
|
| 318 |
-
contrastive_weight: float = 0.5,
|
| 319 |
-
use_focal_loss: bool = True):
|
| 320 |
-
super().__init__()
|
| 321 |
-
|
| 322 |
-
self.item_tower = item_tower
|
| 323 |
-
self.user_tower = user_tower
|
| 324 |
-
self.rating_weight = rating_weight
|
| 325 |
-
self.retrieval_weight = retrieval_weight
|
| 326 |
-
self.contrastive_weight = contrastive_weight
|
| 327 |
-
self.use_focal_loss = use_focal_loss
|
| 328 |
-
|
| 329 |
-
# Temperature-scaled similarity
|
| 330 |
-
self.temperature_similarity = TemperatureScaledSimilarity()
|
| 331 |
-
|
| 332 |
-
# Enhanced rating prediction with more capacity
|
| 333 |
-
self.rating_model = tf.keras.Sequential([
|
| 334 |
-
tf.keras.layers.Dense(512, activation="relu"),
|
| 335 |
-
tf.keras.layers.BatchNormalization(),
|
| 336 |
-
tf.keras.layers.Dropout(0.3),
|
| 337 |
-
tf.keras.layers.Dense(256, activation="relu"),
|
| 338 |
-
tf.keras.layers.BatchNormalization(),
|
| 339 |
-
tf.keras.layers.Dropout(0.2),
|
| 340 |
-
tf.keras.layers.Dense(64, activation="relu"),
|
| 341 |
-
tf.keras.layers.Dense(1, activation="sigmoid")
|
| 342 |
-
])
|
| 343 |
-
|
| 344 |
-
# Rating task with better loss
|
| 345 |
-
if use_focal_loss:
|
| 346 |
-
self.rating_loss = self._focal_loss
|
| 347 |
-
else:
|
| 348 |
-
self.rating_loss = tf.keras.losses.BinaryCrossentropy()
|
| 349 |
-
|
| 350 |
-
# Contrastive loss for embedding separation
|
| 351 |
-
self.contrastive_loss = tf.keras.losses.CosineSimilarity()
|
| 352 |
-
|
| 353 |
-
def _focal_loss(self, y_true, y_pred, alpha=0.25, gamma=2.0):
|
| 354 |
-
"""Focal loss for handling imbalanced data."""
|
| 355 |
-
epsilon = tf.keras.backend.epsilon()
|
| 356 |
-
y_pred = tf.clip_by_value(y_pred, epsilon, 1.0 - epsilon)
|
| 357 |
-
|
| 358 |
-
# Compute focal weight
|
| 359 |
-
alpha_t = y_true * alpha + (1 - y_true) * (1 - alpha)
|
| 360 |
-
p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
|
| 361 |
-
focal_weight = alpha_t * tf.pow((1 - p_t), gamma)
|
| 362 |
-
|
| 363 |
-
# Compute loss
|
| 364 |
-
bce = -(y_true * tf.math.log(y_pred) + (1 - y_true) * tf.math.log(1 - y_pred))
|
| 365 |
-
focal_loss = focal_weight * bce
|
| 366 |
-
|
| 367 |
-
return tf.reduce_mean(focal_loss)
|
| 368 |
-
|
| 369 |
-
def call(self, features):
|
| 370 |
-
# Get embeddings (handle bias if present)
|
| 371 |
-
user_output = self.user_tower(features)
|
| 372 |
-
item_output = self.item_tower(features)
|
| 373 |
-
|
| 374 |
-
# Handle bias terms
|
| 375 |
-
if isinstance(user_output, tuple):
|
| 376 |
-
user_embeddings, user_bias = user_output
|
| 377 |
-
else:
|
| 378 |
-
user_embeddings = user_output
|
| 379 |
-
user_bias = 0.0
|
| 380 |
-
|
| 381 |
-
if isinstance(item_output, tuple):
|
| 382 |
-
item_embeddings, item_bias = item_output
|
| 383 |
-
else:
|
| 384 |
-
item_embeddings = item_output
|
| 385 |
-
item_bias = 0.0
|
| 386 |
-
|
| 387 |
-
return {
|
| 388 |
-
"user_embedding": user_embeddings,
|
| 389 |
-
"item_embedding": item_embeddings,
|
| 390 |
-
"user_bias": user_bias,
|
| 391 |
-
"item_bias": item_bias
|
| 392 |
-
}
|
| 393 |
-
|
| 394 |
-
def _hard_negative_mining(self, user_embeddings, item_embeddings, ratings, num_negatives=5):
|
| 395 |
-
"""Mine hard negatives for better training."""
|
| 396 |
-
batch_size = tf.shape(user_embeddings)[0]
|
| 397 |
-
|
| 398 |
-
# Compute all pairwise similarities
|
| 399 |
-
user_norm = tf.nn.l2_normalize(user_embeddings, axis=1)
|
| 400 |
-
item_norm = tf.nn.l2_normalize(item_embeddings, axis=1)
|
| 401 |
-
|
| 402 |
-
# Expand dimensions for broadcasting: [batch, 1, dim] x [1, batch, dim]
|
| 403 |
-
user_expanded = tf.expand_dims(user_norm, 1)
|
| 404 |
-
item_expanded = tf.expand_dims(item_norm, 0)
|
| 405 |
-
|
| 406 |
-
# Compute similarity matrix [batch, batch]
|
| 407 |
-
similarity_matrix = tf.reduce_sum(user_expanded * item_expanded, axis=2)
|
| 408 |
-
|
| 409 |
-
# Create mask to exclude positive pairs
|
| 410 |
-
positive_mask = tf.eye(batch_size, dtype=tf.bool)
|
| 411 |
-
negative_mask = tf.logical_not(positive_mask)
|
| 412 |
-
|
| 413 |
-
# Get negative similarities and find hardest negatives
|
| 414 |
-
negative_similarities = tf.where(negative_mask, similarity_matrix, -tf.float32.max)
|
| 415 |
-
|
| 416 |
-
# Get top-k hardest negatives (highest similarities among negatives)
|
| 417 |
-
_, hard_negative_indices = tf.nn.top_k(negative_similarities, k=num_negatives)
|
| 418 |
-
|
| 419 |
-
return hard_negative_indices
|
| 420 |
-
|
| 421 |
-
def compute_loss(self, features, training=False):
|
| 422 |
-
# Get embeddings and biases
|
| 423 |
-
outputs = self(features)
|
| 424 |
-
user_embeddings = outputs["user_embedding"]
|
| 425 |
-
item_embeddings = outputs["item_embedding"]
|
| 426 |
-
user_bias = outputs["user_bias"]
|
| 427 |
-
item_bias = outputs["item_bias"]
|
| 428 |
-
|
| 429 |
-
# Rating prediction with bias terms
|
| 430 |
-
concatenated = tf.concat([user_embeddings, item_embeddings], axis=-1)
|
| 431 |
-
rating_predictions = self.rating_model(concatenated, training=training)
|
| 432 |
-
|
| 433 |
-
# Add bias terms to rating predictions
|
| 434 |
-
rating_predictions_with_bias = rating_predictions + user_bias + item_bias
|
| 435 |
-
rating_predictions_with_bias = tf.nn.sigmoid(rating_predictions_with_bias)
|
| 436 |
-
|
| 437 |
-
# Rating loss
|
| 438 |
-
rating_loss = self.rating_loss(features["rating"], rating_predictions_with_bias)
|
| 439 |
-
|
| 440 |
-
# Temperature-scaled retrieval loss
|
| 441 |
-
scaled_similarities = self.temperature_similarity(user_embeddings, item_embeddings)
|
| 442 |
-
retrieval_loss = tf.keras.losses.binary_crossentropy(
|
| 443 |
-
features["rating"],
|
| 444 |
-
tf.nn.sigmoid(scaled_similarities)
|
| 445 |
-
)
|
| 446 |
-
retrieval_loss = tf.reduce_mean(retrieval_loss)
|
| 447 |
-
|
| 448 |
-
# Enhanced contrastive loss with hard negative mining
|
| 449 |
-
batch_size = tf.shape(user_embeddings)[0]
|
| 450 |
-
|
| 451 |
-
if training and batch_size > 5: # Only use hard negatives during training with sufficient batch size
|
| 452 |
-
# Hard negative mining
|
| 453 |
-
hard_negative_indices = self._hard_negative_mining(
|
| 454 |
-
user_embeddings, item_embeddings, features["rating"], num_negatives=3
|
| 455 |
-
)
|
| 456 |
-
|
| 457 |
-
# Positive similarities
|
| 458 |
-
positive_similarities = tf.reduce_sum(user_embeddings * item_embeddings, axis=1)
|
| 459 |
-
|
| 460 |
-
# Hard negative similarities
|
| 461 |
-
hard_negative_losses = []
|
| 462 |
-
for i in range(3): # Use top 3 hard negatives
|
| 463 |
-
neg_indices = hard_negative_indices[:, i]
|
| 464 |
-
negative_item_embeddings = tf.gather(item_embeddings, neg_indices)
|
| 465 |
-
negative_similarities = tf.reduce_sum(user_embeddings * negative_item_embeddings, axis=1)
|
| 466 |
-
|
| 467 |
-
# Triplet-like loss with margin
|
| 468 |
-
margin_loss = tf.maximum(0.0, 0.2 + negative_similarities - positive_similarities)
|
| 469 |
-
hard_negative_losses.append(margin_loss)
|
| 470 |
-
|
| 471 |
-
# Average hard negative losses
|
| 472 |
-
contrastive_loss = tf.reduce_mean(tf.stack(hard_negative_losses))
|
| 473 |
-
|
| 474 |
-
else:
|
| 475 |
-
# Fallback to random negative sampling
|
| 476 |
-
shuffled_indices = tf.random.shuffle(tf.range(batch_size))
|
| 477 |
-
negative_item_embeddings = tf.gather(item_embeddings, shuffled_indices)
|
| 478 |
-
|
| 479 |
-
# Positive similarities
|
| 480 |
-
positive_similarities = tf.reduce_sum(user_embeddings * item_embeddings, axis=1)
|
| 481 |
-
|
| 482 |
-
# Negative similarities
|
| 483 |
-
negative_similarities = tf.reduce_sum(user_embeddings * negative_item_embeddings, axis=1)
|
| 484 |
-
|
| 485 |
-
# Contrastive loss (maximize positive, minimize negative)
|
| 486 |
-
contrastive_loss = tf.reduce_mean(
|
| 487 |
-
tf.maximum(0.0, 0.5 + negative_similarities - positive_similarities)
|
| 488 |
-
)
|
| 489 |
-
|
| 490 |
-
# Combine losses
|
| 491 |
-
total_loss = (
|
| 492 |
-
self.rating_weight * rating_loss +
|
| 493 |
-
self.retrieval_weight * retrieval_loss +
|
| 494 |
-
self.contrastive_weight * contrastive_loss
|
| 495 |
-
)
|
| 496 |
-
|
| 497 |
-
# Add L2 regularization to prevent overfitting
|
| 498 |
-
l2_loss = tf.add_n([
|
| 499 |
-
tf.nn.l2_loss(var) for var in self.trainable_variables
|
| 500 |
-
if 'bias' not in var.name and 'normalization' not in var.name
|
| 501 |
-
]) * 1e-5
|
| 502 |
-
|
| 503 |
-
total_loss += l2_loss
|
| 504 |
-
|
| 505 |
-
return {
|
| 506 |
-
'total_loss': total_loss,
|
| 507 |
-
'rating_loss': rating_loss,
|
| 508 |
-
'retrieval_loss': retrieval_loss,
|
| 509 |
-
'contrastive_loss': contrastive_loss,
|
| 510 |
-
'l2_loss': l2_loss
|
| 511 |
-
}
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
def create_improved_model(data_processor,
|
| 515 |
-
embedding_dim=128,
|
| 516 |
-
use_bias=True,
|
| 517 |
-
use_focal_loss=True):
|
| 518 |
-
"""Factory function to create improved two-tower model."""
|
| 519 |
-
|
| 520 |
-
# Create enhanced towers
|
| 521 |
-
item_tower = ImprovedItemTower(
|
| 522 |
-
item_vocab_size=len(data_processor.item_vocab),
|
| 523 |
-
category_vocab_size=len(data_processor.category_vocab),
|
| 524 |
-
brand_vocab_size=len(data_processor.brand_vocab),
|
| 525 |
-
embedding_dim=embedding_dim,
|
| 526 |
-
use_bias=use_bias
|
| 527 |
-
)
|
| 528 |
-
|
| 529 |
-
user_tower = ImprovedUserTower(
|
| 530 |
-
max_history_length=50,
|
| 531 |
-
embedding_dim=embedding_dim,
|
| 532 |
-
use_bias=use_bias
|
| 533 |
-
)
|
| 534 |
-
|
| 535 |
-
# Create improved model
|
| 536 |
-
model = ImprovedTwoTowerModel(
|
| 537 |
-
item_tower=item_tower,
|
| 538 |
-
user_tower=user_tower,
|
| 539 |
-
rating_weight=1.0,
|
| 540 |
-
retrieval_weight=0.5,
|
| 541 |
-
contrastive_weight=0.3,
|
| 542 |
-
use_focal_loss=use_focal_loss
|
| 543 |
-
)
|
| 544 |
-
|
| 545 |
-
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/user_tower.py
CHANGED
|
@@ -32,6 +32,27 @@ class UserTower(tf.keras.Model):
|
|
| 32 |
2, embedding_dim // 16, name="gender_embedding"
|
| 33 |
)
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
# History aggregation layers
|
| 36 |
self.history_attention = tf.keras.layers.MultiHeadAttention(
|
| 37 |
num_heads=4,
|
|
@@ -57,12 +78,20 @@ class UserTower(tf.keras.Model):
|
|
| 57 |
age = inputs["age"] # Now categorical (0-5)
|
| 58 |
gender = inputs["gender"] # Categorical (0-1)
|
| 59 |
income = inputs["income"] # Now categorical (0-4)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
item_history = inputs["item_history_embeddings"] # [batch_size, seq_len, emb_dim]
|
| 61 |
|
| 62 |
# Process demographics through embeddings
|
| 63 |
age_emb = self.age_embedding(age) # [batch_size, embedding_dim//16]
|
| 64 |
income_emb = self.income_embedding(income) # [batch_size, embedding_dim//16]
|
| 65 |
gender_emb = self.gender_embedding(gender) # [batch_size, embedding_dim//16]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
# Aggregate item history using attention
|
| 68 |
# Create attention mask for padding
|
|
@@ -84,6 +113,10 @@ class UserTower(tf.keras.Model):
|
|
| 84 |
age_emb,
|
| 85 |
income_emb,
|
| 86 |
gender_emb,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
history_aggregated
|
| 88 |
], axis=-1)
|
| 89 |
|
|
|
|
| 32 |
2, embedding_dim // 16, name="gender_embedding"
|
| 33 |
)
|
| 34 |
|
| 35 |
+
# New demographic embeddings
|
| 36 |
+
# Profession: 8 categories (Technology, Healthcare, Education, Finance, Retail, Manufacturing, Services, Other)
|
| 37 |
+
self.profession_embedding = tf.keras.layers.Embedding(
|
| 38 |
+
8, embedding_dim // 16, name="profession_embedding"
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
# Location: 3 categories (Urban, Suburban, Rural)
|
| 42 |
+
self.location_embedding = tf.keras.layers.Embedding(
|
| 43 |
+
3, embedding_dim // 16, name="location_embedding"
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
# Education Level: 5 categories (High School, Some College, Bachelor's, Master's, PhD+)
|
| 47 |
+
self.education_embedding = tf.keras.layers.Embedding(
|
| 48 |
+
5, embedding_dim // 16, name="education_embedding"
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
# Marital Status: 4 categories (Single, Married, Divorced, Widowed)
|
| 52 |
+
self.marital_embedding = tf.keras.layers.Embedding(
|
| 53 |
+
4, embedding_dim // 16, name="marital_embedding"
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
# History aggregation layers
|
| 57 |
self.history_attention = tf.keras.layers.MultiHeadAttention(
|
| 58 |
num_heads=4,
|
|
|
|
| 78 |
age = inputs["age"] # Now categorical (0-5)
|
| 79 |
gender = inputs["gender"] # Categorical (0-1)
|
| 80 |
income = inputs["income"] # Now categorical (0-4)
|
| 81 |
+
profession = inputs["profession"] # Categorical (0-7)
|
| 82 |
+
location = inputs["location"] # Categorical (0-2)
|
| 83 |
+
education = inputs["education_level"] # Categorical (0-4)
|
| 84 |
+
marital_status = inputs["marital_status"] # Categorical (0-3)
|
| 85 |
item_history = inputs["item_history_embeddings"] # [batch_size, seq_len, emb_dim]
|
| 86 |
|
| 87 |
# Process demographics through embeddings
|
| 88 |
age_emb = self.age_embedding(age) # [batch_size, embedding_dim//16]
|
| 89 |
income_emb = self.income_embedding(income) # [batch_size, embedding_dim//16]
|
| 90 |
gender_emb = self.gender_embedding(gender) # [batch_size, embedding_dim//16]
|
| 91 |
+
profession_emb = self.profession_embedding(profession) # [batch_size, embedding_dim//16]
|
| 92 |
+
location_emb = self.location_embedding(location) # [batch_size, embedding_dim//16]
|
| 93 |
+
education_emb = self.education_embedding(education) # [batch_size, embedding_dim//16]
|
| 94 |
+
marital_emb = self.marital_embedding(marital_status) # [batch_size, embedding_dim//16]
|
| 95 |
|
| 96 |
# Aggregate item history using attention
|
| 97 |
# Create attention mask for padding
|
|
|
|
| 113 |
age_emb,
|
| 114 |
income_emb,
|
| 115 |
gender_emb,
|
| 116 |
+
profession_emb,
|
| 117 |
+
location_emb,
|
| 118 |
+
education_emb,
|
| 119 |
+
marital_emb,
|
| 120 |
history_aggregated
|
| 121 |
], axis=-1)
|
| 122 |
|
src/preprocessing/optimized_dataset_creator.py
DELETED
|
@@ -1,111 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Optimized dataset creation script with performance improvements.
|
| 3 |
-
"""
|
| 4 |
-
import time
|
| 5 |
-
import numpy as np
|
| 6 |
-
from src.preprocessing.user_data_preparation import UserDatasetCreator
|
| 7 |
-
from src.preprocessing.data_loader import DataProcessor, create_tf_dataset
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
def create_optimized_dataset(max_history_length: int = 50,
|
| 11 |
-
batch_size: int = 512,
|
| 12 |
-
negative_samples_per_positive: int = 2,
|
| 13 |
-
use_sample: bool = False,
|
| 14 |
-
sample_size: int = 10000):
|
| 15 |
-
"""
|
| 16 |
-
Create dataset with optimized performance settings.
|
| 17 |
-
|
| 18 |
-
Args:
|
| 19 |
-
max_history_length: Maximum user interaction history length
|
| 20 |
-
batch_size: Batch size for TensorFlow dataset
|
| 21 |
-
negative_samples_per_positive: Negative sampling ratio
|
| 22 |
-
use_sample: Whether to use a sample of the data for faster processing
|
| 23 |
-
sample_size: Size of sample if use_sample=True
|
| 24 |
-
"""
|
| 25 |
-
print("Starting optimized dataset creation...")
|
| 26 |
-
start_time = time.time()
|
| 27 |
-
|
| 28 |
-
# Initialize with optimized settings
|
| 29 |
-
dataset_creator = UserDatasetCreator(max_history_length=max_history_length)
|
| 30 |
-
data_processor = DataProcessor()
|
| 31 |
-
|
| 32 |
-
# Load data
|
| 33 |
-
print("Loading data...")
|
| 34 |
-
load_start = time.time()
|
| 35 |
-
items_df, users_df, interactions_df = data_processor.load_data()
|
| 36 |
-
print(f"Data loaded in {time.time() - load_start:.2f} seconds")
|
| 37 |
-
|
| 38 |
-
# Optional: Use sample for faster development/testing
|
| 39 |
-
if use_sample:
|
| 40 |
-
print(f"Using sample of {sample_size} interactions for faster processing...")
|
| 41 |
-
sample_interactions = interactions_df.sample(min(sample_size, len(interactions_df)))
|
| 42 |
-
user_ids = set(sample_interactions['user_id'])
|
| 43 |
-
item_ids = set(sample_interactions['product_id'])
|
| 44 |
-
|
| 45 |
-
users_df = users_df[users_df['user_id'].isin(user_ids)]
|
| 46 |
-
items_df = items_df[items_df['product_id'].isin(item_ids)]
|
| 47 |
-
interactions_df = sample_interactions
|
| 48 |
-
|
| 49 |
-
print(f"Sample: {len(items_df)} items, {len(users_df)} users, {len(interactions_df)} interactions")
|
| 50 |
-
|
| 51 |
-
# Load embeddings with caching
|
| 52 |
-
print("Loading item embeddings...")
|
| 53 |
-
embed_start = time.time()
|
| 54 |
-
item_embeddings = dataset_creator.load_item_embeddings()
|
| 55 |
-
print(f"Embeddings loaded in {time.time() - embed_start:.2f} seconds")
|
| 56 |
-
|
| 57 |
-
# Create temporal split
|
| 58 |
-
print("Creating temporal split...")
|
| 59 |
-
split_start = time.time()
|
| 60 |
-
train_interactions, val_interactions = dataset_creator.create_temporal_split(interactions_df)
|
| 61 |
-
print(f"Temporal split created in {time.time() - split_start:.2f} seconds")
|
| 62 |
-
|
| 63 |
-
# Create training dataset with optimizations
|
| 64 |
-
print("Creating optimized training dataset...")
|
| 65 |
-
train_start = time.time()
|
| 66 |
-
training_features = dataset_creator.create_training_dataset(
|
| 67 |
-
train_interactions, items_df, users_df, item_embeddings,
|
| 68 |
-
negative_samples_per_positive=negative_samples_per_positive
|
| 69 |
-
)
|
| 70 |
-
print(f"Training dataset created in {time.time() - train_start:.2f} seconds")
|
| 71 |
-
|
| 72 |
-
# Create TensorFlow dataset optimized for CPU
|
| 73 |
-
print("Creating TensorFlow dataset...")
|
| 74 |
-
tf_start = time.time()
|
| 75 |
-
tf_dataset = create_tf_dataset(training_features, batch_size=batch_size)
|
| 76 |
-
print(f"TensorFlow dataset created in {time.time() - tf_start:.2f} seconds")
|
| 77 |
-
|
| 78 |
-
# Save optimized dataset
|
| 79 |
-
print("Saving dataset...")
|
| 80 |
-
save_start = time.time()
|
| 81 |
-
dataset_creator.save_dataset(training_features, "src/artifacts/")
|
| 82 |
-
|
| 83 |
-
# Save vocabularies for later use
|
| 84 |
-
data_processor.save_vocabularies("src/artifacts/")
|
| 85 |
-
print(f"Dataset saved in {time.time() - save_start:.2f} seconds")
|
| 86 |
-
|
| 87 |
-
total_time = time.time() - start_time
|
| 88 |
-
print(f"\nOptimized dataset creation completed in {total_time:.2f} seconds!")
|
| 89 |
-
print(f"Training samples: {len(training_features['rating'])}")
|
| 90 |
-
print(f"Memory usage optimized for CPU training")
|
| 91 |
-
|
| 92 |
-
return tf_dataset, training_features
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
if __name__ == "__main__":
|
| 96 |
-
# Run with optimized settings
|
| 97 |
-
tf_dataset, features = create_optimized_dataset(
|
| 98 |
-
max_history_length=30, # Reduced for speed
|
| 99 |
-
batch_size=512, # Larger batches for CPU efficiency
|
| 100 |
-
negative_samples_per_positive=2, # Reduced sampling ratio
|
| 101 |
-
use_sample=True, # Use sample for development
|
| 102 |
-
sample_size=50000 # Reasonable sample size
|
| 103 |
-
)
|
| 104 |
-
|
| 105 |
-
print("\nDataset creation optimization complete!")
|
| 106 |
-
print("Key optimizations applied:")
|
| 107 |
-
print("- Vectorized DataFrame operations")
|
| 108 |
-
print("- Parallel negative sampling")
|
| 109 |
-
print("- Memory-efficient embedding lookup")
|
| 110 |
-
print("- Optimized TensorFlow dataset pipeline")
|
| 111 |
-
print("- LRU caching for embeddings")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/preprocessing/user_data_preparation.py
CHANGED
|
@@ -45,6 +45,50 @@ class UserDatasetCreator:
|
|
| 45 |
categories = np.clip(categories, 0, 4)
|
| 46 |
|
| 47 |
return categories.astype(np.int32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
@lru_cache(maxsize=1)
|
| 50 |
def load_item_embeddings(self, embeddings_path: str = "src/artifacts/item_embeddings.npy") -> Dict[int, np.ndarray]:
|
|
@@ -155,6 +199,12 @@ class UserDatasetCreator:
|
|
| 155 |
# Categorize income (5 percentile-based categories)
|
| 156 |
user_demographics['income_category'] = self.categorize_income(user_demographics['income'])
|
| 157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
# Create mapping from user_id to array index
|
| 159 |
user_id_to_index = {uid: idx for idx, uid in enumerate(user_demographics['user_id'])}
|
| 160 |
|
|
@@ -165,6 +215,10 @@ class UserDatasetCreator:
|
|
| 165 |
'age': user_demographics['age_category'].values.astype(np.int32), # Categorical age
|
| 166 |
'gender': user_demographics['gender_numeric'].values.astype(np.int32),
|
| 167 |
'income': user_demographics['income_category'].values.astype(np.int32), # Categorical income
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
'item_history_embeddings': np.array([
|
| 169 |
user_aggregated_embeddings[uid] for uid in user_demographics['user_id']
|
| 170 |
]).astype(np.float32)
|
|
@@ -173,6 +227,10 @@ class UserDatasetCreator:
|
|
| 173 |
print(f"Prepared user features for {len(valid_users)} users")
|
| 174 |
print(f"Age categories: {np.unique(user_features['age'], return_counts=True)}")
|
| 175 |
print(f"Income categories: {np.unique(user_features['income'], return_counts=True)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
print(f"History embeddings shape: {user_features['item_history_embeddings'].shape}")
|
| 177 |
|
| 178 |
return user_features
|
|
@@ -256,6 +314,10 @@ class UserDatasetCreator:
|
|
| 256 |
training_features['age'] = user_features['age'][user_indices]
|
| 257 |
training_features['gender'] = user_features['gender'][user_indices]
|
| 258 |
training_features['income'] = user_features['income'][user_indices]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
training_features['item_history_embeddings'] = user_features['item_history_embeddings'][user_indices]
|
| 260 |
|
| 261 |
# Item features for each pair
|
|
@@ -412,10 +474,20 @@ def prepare_user_features(users_df: pd.DataFrame,
|
|
| 412 |
user_idx = users_df[users_df['user_id'] == user_id].index[0]
|
| 413 |
income_cat = income_categories[user_idx]
|
| 414 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 415 |
user_feature_dict[user_id] = {
|
| 416 |
'age': age_cat,
|
| 417 |
'gender': gender_cat,
|
| 418 |
'income': income_cat,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
'item_history_embeddings': user_aggregated_embeddings[user_id]
|
| 420 |
}
|
| 421 |
|
|
|
|
| 45 |
categories = np.clip(categories, 0, 4)
|
| 46 |
|
| 47 |
return categories.astype(np.int32)
|
| 48 |
+
|
| 49 |
+
def categorize_profession(self, profession: str) -> int:
|
| 50 |
+
"""Categorize profession into numeric categories."""
|
| 51 |
+
profession_map = {
|
| 52 |
+
"Technology": 0,
|
| 53 |
+
"Healthcare": 1,
|
| 54 |
+
"Education": 2,
|
| 55 |
+
"Finance": 3,
|
| 56 |
+
"Retail": 4,
|
| 57 |
+
"Manufacturing": 5,
|
| 58 |
+
"Services": 6,
|
| 59 |
+
"Other": 7
|
| 60 |
+
}
|
| 61 |
+
return profession_map.get(profession, 7) # Default to "Other"
|
| 62 |
+
|
| 63 |
+
def categorize_location(self, location: str) -> int:
|
| 64 |
+
"""Categorize location into numeric categories."""
|
| 65 |
+
location_map = {
|
| 66 |
+
"Urban": 0,
|
| 67 |
+
"Suburban": 1,
|
| 68 |
+
"Rural": 2
|
| 69 |
+
}
|
| 70 |
+
return location_map.get(location, 0) # Default to "Urban"
|
| 71 |
+
|
| 72 |
+
def categorize_education_level(self, education: str) -> int:
|
| 73 |
+
"""Categorize education level into numeric categories."""
|
| 74 |
+
education_map = {
|
| 75 |
+
"High School": 0,
|
| 76 |
+
"Some College": 1,
|
| 77 |
+
"Bachelor's": 2,
|
| 78 |
+
"Master's": 3,
|
| 79 |
+
"PhD+": 4
|
| 80 |
+
}
|
| 81 |
+
return education_map.get(education, 0) # Default to "High School"
|
| 82 |
+
|
| 83 |
+
def categorize_marital_status(self, marital_status: str) -> int:
|
| 84 |
+
"""Categorize marital status into numeric categories."""
|
| 85 |
+
marital_map = {
|
| 86 |
+
"Single": 0,
|
| 87 |
+
"Married": 1,
|
| 88 |
+
"Divorced": 2,
|
| 89 |
+
"Widowed": 3
|
| 90 |
+
}
|
| 91 |
+
return marital_map.get(marital_status, 0) # Default to "Single"
|
| 92 |
|
| 93 |
@lru_cache(maxsize=1)
|
| 94 |
def load_item_embeddings(self, embeddings_path: str = "src/artifacts/item_embeddings.npy") -> Dict[int, np.ndarray]:
|
|
|
|
| 199 |
# Categorize income (5 percentile-based categories)
|
| 200 |
user_demographics['income_category'] = self.categorize_income(user_demographics['income'])
|
| 201 |
|
| 202 |
+
# Categorize new demographic features
|
| 203 |
+
user_demographics['profession_category'] = user_demographics['profession'].apply(self.categorize_profession)
|
| 204 |
+
user_demographics['location_category'] = user_demographics['location'].apply(self.categorize_location)
|
| 205 |
+
user_demographics['education_category'] = user_demographics['education_level'].apply(self.categorize_education_level)
|
| 206 |
+
user_demographics['marital_category'] = user_demographics['marital_status'].apply(self.categorize_marital_status)
|
| 207 |
+
|
| 208 |
# Create mapping from user_id to array index
|
| 209 |
user_id_to_index = {uid: idx for idx, uid in enumerate(user_demographics['user_id'])}
|
| 210 |
|
|
|
|
| 215 |
'age': user_demographics['age_category'].values.astype(np.int32), # Categorical age
|
| 216 |
'gender': user_demographics['gender_numeric'].values.astype(np.int32),
|
| 217 |
'income': user_demographics['income_category'].values.astype(np.int32), # Categorical income
|
| 218 |
+
'profession': user_demographics['profession_category'].values.astype(np.int32), # Categorical profession
|
| 219 |
+
'location': user_demographics['location_category'].values.astype(np.int32), # Categorical location
|
| 220 |
+
'education_level': user_demographics['education_category'].values.astype(np.int32), # Categorical education
|
| 221 |
+
'marital_status': user_demographics['marital_category'].values.astype(np.int32), # Categorical marital status
|
| 222 |
'item_history_embeddings': np.array([
|
| 223 |
user_aggregated_embeddings[uid] for uid in user_demographics['user_id']
|
| 224 |
]).astype(np.float32)
|
|
|
|
| 227 |
print(f"Prepared user features for {len(valid_users)} users")
|
| 228 |
print(f"Age categories: {np.unique(user_features['age'], return_counts=True)}")
|
| 229 |
print(f"Income categories: {np.unique(user_features['income'], return_counts=True)}")
|
| 230 |
+
print(f"Profession categories: {np.unique(user_features['profession'], return_counts=True)}")
|
| 231 |
+
print(f"Location categories: {np.unique(user_features['location'], return_counts=True)}")
|
| 232 |
+
print(f"Education categories: {np.unique(user_features['education_level'], return_counts=True)}")
|
| 233 |
+
print(f"Marital status categories: {np.unique(user_features['marital_status'], return_counts=True)}")
|
| 234 |
print(f"History embeddings shape: {user_features['item_history_embeddings'].shape}")
|
| 235 |
|
| 236 |
return user_features
|
|
|
|
| 314 |
training_features['age'] = user_features['age'][user_indices]
|
| 315 |
training_features['gender'] = user_features['gender'][user_indices]
|
| 316 |
training_features['income'] = user_features['income'][user_indices]
|
| 317 |
+
training_features['profession'] = user_features['profession'][user_indices]
|
| 318 |
+
training_features['location'] = user_features['location'][user_indices]
|
| 319 |
+
training_features['education_level'] = user_features['education_level'][user_indices]
|
| 320 |
+
training_features['marital_status'] = user_features['marital_status'][user_indices]
|
| 321 |
training_features['item_history_embeddings'] = user_features['item_history_embeddings'][user_indices]
|
| 322 |
|
| 323 |
# Item features for each pair
|
|
|
|
| 474 |
user_idx = users_df[users_df['user_id'] == user_id].index[0]
|
| 475 |
income_cat = income_categories[user_idx]
|
| 476 |
|
| 477 |
+
# Get new demographic features from the row
|
| 478 |
+
profession_cat = creator.categorize_profession(user_row.get('profession', 'Other'))
|
| 479 |
+
location_cat = creator.categorize_location(user_row.get('location', 'Urban'))
|
| 480 |
+
education_cat = creator.categorize_education_level(user_row.get('education_level', 'High School'))
|
| 481 |
+
marital_cat = creator.categorize_marital_status(user_row.get('marital_status', 'Single'))
|
| 482 |
+
|
| 483 |
user_feature_dict[user_id] = {
|
| 484 |
'age': age_cat,
|
| 485 |
'gender': gender_cat,
|
| 486 |
'income': income_cat,
|
| 487 |
+
'profession': profession_cat,
|
| 488 |
+
'location': location_cat,
|
| 489 |
+
'education_level': education_cat,
|
| 490 |
+
'marital_status': marital_cat,
|
| 491 |
'item_history_embeddings': user_aggregated_embeddings[user_id]
|
| 492 |
}
|
| 493 |
|
src/training/curriculum_trainer.py
DELETED
|
@@ -1,341 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""
|
| 3 |
-
Curriculum learning trainer for the improved two-tower model.
|
| 4 |
-
Implements progressive difficulty training for better convergence.
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
import tensorflow as tf
|
| 8 |
-
import numpy as np
|
| 9 |
-
import pickle
|
| 10 |
-
import os
|
| 11 |
-
import time
|
| 12 |
-
from typing import Dict, List, Tuple
|
| 13 |
-
|
| 14 |
-
from src.models.improved_two_tower import create_improved_model
|
| 15 |
-
from src.preprocessing.data_loader import DataProcessor
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
class CurriculumTrainer:
|
| 19 |
-
"""Trainer with curriculum learning for improved two-tower model."""
|
| 20 |
-
|
| 21 |
-
def __init__(self,
|
| 22 |
-
embedding_dim: int = 128,
|
| 23 |
-
learning_rate: float = 0.001,
|
| 24 |
-
use_focal_loss: bool = True,
|
| 25 |
-
curriculum_stages: int = 3):
|
| 26 |
-
|
| 27 |
-
self.embedding_dim = embedding_dim
|
| 28 |
-
self.learning_rate = learning_rate
|
| 29 |
-
self.use_focal_loss = use_focal_loss
|
| 30 |
-
self.curriculum_stages = curriculum_stages
|
| 31 |
-
|
| 32 |
-
self.data_processor = None
|
| 33 |
-
self.model = None
|
| 34 |
-
|
| 35 |
-
def load_data_processor(self, artifacts_path: str = "src/artifacts/"):
|
| 36 |
-
"""Load data processor with vocabularies."""
|
| 37 |
-
self.data_processor = DataProcessor()
|
| 38 |
-
self.data_processor.load_vocabularies(f"{artifacts_path}/vocabularies.pkl")
|
| 39 |
-
print("Data processor loaded successfully")
|
| 40 |
-
|
| 41 |
-
def create_model(self):
|
| 42 |
-
"""Create improved two-tower model."""
|
| 43 |
-
if self.data_processor is None:
|
| 44 |
-
raise ValueError("Data processor must be loaded first")
|
| 45 |
-
|
| 46 |
-
self.model = create_improved_model(
|
| 47 |
-
data_processor=self.data_processor,
|
| 48 |
-
embedding_dim=self.embedding_dim,
|
| 49 |
-
use_bias=True,
|
| 50 |
-
use_focal_loss=self.use_focal_loss
|
| 51 |
-
)
|
| 52 |
-
|
| 53 |
-
# Compile model
|
| 54 |
-
self.model.compile(
|
| 55 |
-
optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
|
| 56 |
-
)
|
| 57 |
-
|
| 58 |
-
print("Improved two-tower model created successfully")
|
| 59 |
-
|
| 60 |
-
def _create_curriculum_stages(self, features: Dict[str, np.ndarray]) -> List[Dict[str, np.ndarray]]:
|
| 61 |
-
"""Create curriculum stages based on interaction complexity."""
|
| 62 |
-
|
| 63 |
-
# Calculate interaction history lengths for curriculum
|
| 64 |
-
history_lengths = []
|
| 65 |
-
for i in range(len(features['age'])):
|
| 66 |
-
hist = features['item_history_embeddings'][i]
|
| 67 |
-
# Count non-zero embeddings
|
| 68 |
-
length = np.sum(np.any(hist != 0, axis=1))
|
| 69 |
-
history_lengths.append(length)
|
| 70 |
-
|
| 71 |
-
history_lengths = np.array(history_lengths)
|
| 72 |
-
|
| 73 |
-
# Create stages based on history length percentiles
|
| 74 |
-
stages = []
|
| 75 |
-
|
| 76 |
-
if self.curriculum_stages == 3:
|
| 77 |
-
# Stage 1: Simple cases (short or no history)
|
| 78 |
-
stage1_mask = history_lengths <= np.percentile(history_lengths, 33)
|
| 79 |
-
|
| 80 |
-
# Stage 2: Medium complexity (medium history)
|
| 81 |
-
stage2_mask = (history_lengths > np.percentile(history_lengths, 33)) & \
|
| 82 |
-
(history_lengths <= np.percentile(history_lengths, 67))
|
| 83 |
-
|
| 84 |
-
# Stage 3: Complex cases (long history)
|
| 85 |
-
stage3_mask = history_lengths > np.percentile(history_lengths, 67)
|
| 86 |
-
|
| 87 |
-
masks = [stage1_mask, stage2_mask, stage3_mask]
|
| 88 |
-
stage_names = ["Simple (short history)", "Medium (moderate history)", "Complex (long history)"]
|
| 89 |
-
|
| 90 |
-
else:
|
| 91 |
-
# Flexible number of stages
|
| 92 |
-
percentiles = np.linspace(0, 100, self.curriculum_stages + 1)
|
| 93 |
-
masks = []
|
| 94 |
-
stage_names = []
|
| 95 |
-
|
| 96 |
-
for i in range(self.curriculum_stages):
|
| 97 |
-
if i == 0:
|
| 98 |
-
mask = history_lengths <= np.percentile(history_lengths, percentiles[i+1])
|
| 99 |
-
stage_names.append(f"Stage {i+1} (β€{percentiles[i+1]:.0f}%ile)")
|
| 100 |
-
elif i == self.curriculum_stages - 1:
|
| 101 |
-
mask = history_lengths > np.percentile(history_lengths, percentiles[i])
|
| 102 |
-
stage_names.append(f"Stage {i+1} (>{percentiles[i]:.0f}%ile)")
|
| 103 |
-
else:
|
| 104 |
-
mask = (history_lengths > np.percentile(history_lengths, percentiles[i])) & \
|
| 105 |
-
(history_lengths <= np.percentile(history_lengths, percentiles[i+1]))
|
| 106 |
-
stage_names.append(f"Stage {i+1} ({percentiles[i]:.0f}-{percentiles[i+1]:.0f}%ile)")
|
| 107 |
-
|
| 108 |
-
masks.append(mask)
|
| 109 |
-
|
| 110 |
-
# Create stage datasets
|
| 111 |
-
for i, (mask, name) in enumerate(zip(masks, stage_names)):
|
| 112 |
-
stage_features = {}
|
| 113 |
-
for key, values in features.items():
|
| 114 |
-
stage_features[key] = values[mask]
|
| 115 |
-
|
| 116 |
-
print(f" Stage {i+1} ({name}): {np.sum(mask)} samples")
|
| 117 |
-
stages.append(stage_features)
|
| 118 |
-
|
| 119 |
-
return stages
|
| 120 |
-
|
| 121 |
-
def _create_tf_dataset(self, features: Dict[str, np.ndarray],
|
| 122 |
-
batch_size: int = 256,
|
| 123 |
-
shuffle: bool = True) -> tf.data.Dataset:
|
| 124 |
-
"""Create TensorFlow dataset from features."""
|
| 125 |
-
|
| 126 |
-
dataset = tf.data.Dataset.from_tensor_slices(features)
|
| 127 |
-
|
| 128 |
-
if shuffle:
|
| 129 |
-
dataset = dataset.shuffle(buffer_size=10000)
|
| 130 |
-
|
| 131 |
-
dataset = dataset.batch(batch_size, drop_remainder=False)
|
| 132 |
-
dataset = dataset.prefetch(tf.data.AUTOTUNE)
|
| 133 |
-
|
| 134 |
-
return dataset
|
| 135 |
-
|
| 136 |
-
def train_with_curriculum(self,
|
| 137 |
-
training_features: Dict[str, np.ndarray],
|
| 138 |
-
validation_features: Dict[str, np.ndarray],
|
| 139 |
-
epochs_per_stage: int = 10,
|
| 140 |
-
batch_size: int = 256) -> Dict:
|
| 141 |
-
"""Train model using curriculum learning."""
|
| 142 |
-
|
| 143 |
-
print(f"π CURRICULUM LEARNING TRAINING")
|
| 144 |
-
print(f"Stages: {self.curriculum_stages} | Epochs per stage: {epochs_per_stage}")
|
| 145 |
-
print("="*70)
|
| 146 |
-
|
| 147 |
-
# Create curriculum stages
|
| 148 |
-
print("\nπ Creating curriculum stages...")
|
| 149 |
-
training_stages = self._create_curriculum_stages(training_features)
|
| 150 |
-
|
| 151 |
-
# Training history
|
| 152 |
-
history = {
|
| 153 |
-
'stage_losses': [],
|
| 154 |
-
'stage_val_losses': [],
|
| 155 |
-
'stage_times': [],
|
| 156 |
-
'total_loss': [],
|
| 157 |
-
'rating_loss': [],
|
| 158 |
-
'retrieval_loss': [],
|
| 159 |
-
'contrastive_loss': [],
|
| 160 |
-
'val_total_loss': [],
|
| 161 |
-
'val_rating_loss': [],
|
| 162 |
-
'val_retrieval_loss': []
|
| 163 |
-
}
|
| 164 |
-
|
| 165 |
-
# Validation dataset (constant across stages)
|
| 166 |
-
val_dataset = self._create_tf_dataset(validation_features, batch_size, shuffle=False)
|
| 167 |
-
|
| 168 |
-
total_start_time = time.time()
|
| 169 |
-
|
| 170 |
-
# Train through curriculum stages
|
| 171 |
-
for stage_idx, stage_features in enumerate(training_stages):
|
| 172 |
-
stage_start_time = time.time()
|
| 173 |
-
|
| 174 |
-
print(f"\nπ― STAGE {stage_idx + 1}/{self.curriculum_stages}")
|
| 175 |
-
print(f"Training samples: {len(stage_features['rating'])}")
|
| 176 |
-
|
| 177 |
-
# Create training dataset for this stage
|
| 178 |
-
train_dataset = self._create_tf_dataset(stage_features, batch_size, shuffle=True)
|
| 179 |
-
|
| 180 |
-
# Adaptive learning rate (decrease as stages progress)
|
| 181 |
-
stage_lr = self.learning_rate * (0.8 ** stage_idx)
|
| 182 |
-
self.model.optimizer.learning_rate.assign(stage_lr)
|
| 183 |
-
print(f"Learning rate: {stage_lr:.6f}")
|
| 184 |
-
|
| 185 |
-
# Train on this stage
|
| 186 |
-
stage_history = {'loss': [], 'val_loss': []}
|
| 187 |
-
|
| 188 |
-
for epoch in range(epochs_per_stage):
|
| 189 |
-
epoch_start = time.time()
|
| 190 |
-
|
| 191 |
-
# Training step
|
| 192 |
-
train_losses = []
|
| 193 |
-
for batch in train_dataset:
|
| 194 |
-
with tf.GradientTape() as tape:
|
| 195 |
-
loss_dict = self.model.compute_loss(batch, training=True)
|
| 196 |
-
total_loss = loss_dict['total_loss']
|
| 197 |
-
|
| 198 |
-
gradients = tape.gradient(total_loss, self.model.trainable_variables)
|
| 199 |
-
self.model.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
|
| 200 |
-
|
| 201 |
-
train_losses.append({k: v.numpy() for k, v in loss_dict.items()})
|
| 202 |
-
|
| 203 |
-
# Average training losses
|
| 204 |
-
avg_train_loss = {}
|
| 205 |
-
for key in train_losses[0].keys():
|
| 206 |
-
avg_train_loss[key] = np.mean([loss[key] for loss in train_losses])
|
| 207 |
-
|
| 208 |
-
# Validation step
|
| 209 |
-
val_losses = []
|
| 210 |
-
for batch in val_dataset:
|
| 211 |
-
loss_dict = self.model.compute_loss(batch, training=False)
|
| 212 |
-
val_losses.append({k: v.numpy() for k, v in loss_dict.items()})
|
| 213 |
-
|
| 214 |
-
# Average validation losses
|
| 215 |
-
avg_val_loss = {}
|
| 216 |
-
for key in val_losses[0].keys():
|
| 217 |
-
avg_val_loss[key] = np.mean([loss[key] for loss in val_losses])
|
| 218 |
-
|
| 219 |
-
# Record epoch results
|
| 220 |
-
stage_history['loss'].append(avg_train_loss['total_loss'])
|
| 221 |
-
stage_history['val_loss'].append(avg_val_loss['total_loss'])
|
| 222 |
-
|
| 223 |
-
# Add to overall history
|
| 224 |
-
for key in ['total_loss', 'rating_loss', 'retrieval_loss', 'contrastive_loss']:
|
| 225 |
-
history[key].append(avg_train_loss[key])
|
| 226 |
-
history[f'val_{key}'].append(avg_val_loss[key])
|
| 227 |
-
|
| 228 |
-
epoch_time = time.time() - epoch_start
|
| 229 |
-
print(f" Epoch {epoch+1:2d}/{epochs_per_stage} | "
|
| 230 |
-
f"Loss: {avg_train_loss['total_loss']:.4f} | "
|
| 231 |
-
f"Val: {avg_val_loss['total_loss']:.4f} | "
|
| 232 |
-
f"Time: {epoch_time:.1f}s")
|
| 233 |
-
|
| 234 |
-
stage_time = time.time() - stage_start_time
|
| 235 |
-
|
| 236 |
-
# Record stage results
|
| 237 |
-
history['stage_losses'].append(stage_history['loss'])
|
| 238 |
-
history['stage_val_losses'].append(stage_history['val_loss'])
|
| 239 |
-
history['stage_times'].append(stage_time)
|
| 240 |
-
|
| 241 |
-
print(f"β
Stage {stage_idx + 1} completed in {stage_time:.1f}s")
|
| 242 |
-
|
| 243 |
-
# Save intermediate model after each stage
|
| 244 |
-
self.save_model(f"src/artifacts/", suffix=f"_stage_{stage_idx + 1}")
|
| 245 |
-
|
| 246 |
-
total_time = time.time() - total_start_time
|
| 247 |
-
|
| 248 |
-
print(f"\nπ CURRICULUM TRAINING COMPLETED!")
|
| 249 |
-
print(f"Total time: {total_time:.1f}s")
|
| 250 |
-
print(f"Average time per stage: {np.mean(history['stage_times']):.1f}s")
|
| 251 |
-
|
| 252 |
-
return history
|
| 253 |
-
|
| 254 |
-
def save_model(self, save_path: str = "src/artifacts/", suffix: str = ""):
|
| 255 |
-
"""Save the trained model."""
|
| 256 |
-
os.makedirs(save_path, exist_ok=True)
|
| 257 |
-
|
| 258 |
-
# Save model weights
|
| 259 |
-
self.model.user_tower.save_weights(f"{save_path}/improved_user_tower_weights{suffix}")
|
| 260 |
-
self.model.item_tower.save_weights(f"{save_path}/improved_item_tower_weights{suffix}")
|
| 261 |
-
self.model.rating_model.save_weights(f"{save_path}/improved_rating_model_weights{suffix}")
|
| 262 |
-
|
| 263 |
-
# Save temperature parameter
|
| 264 |
-
temp_value = self.model.temperature_similarity.temperature.numpy()
|
| 265 |
-
with open(f"{save_path}/temperature_value{suffix}.txt", 'w') as f:
|
| 266 |
-
f.write(str(temp_value))
|
| 267 |
-
|
| 268 |
-
# Save configuration
|
| 269 |
-
config = {
|
| 270 |
-
'embedding_dim': self.embedding_dim,
|
| 271 |
-
'learning_rate': self.learning_rate,
|
| 272 |
-
'use_focal_loss': self.use_focal_loss,
|
| 273 |
-
'curriculum_stages': self.curriculum_stages
|
| 274 |
-
}
|
| 275 |
-
|
| 276 |
-
with open(f"{save_path}/curriculum_model_config{suffix}.txt", 'w') as f:
|
| 277 |
-
for key, value in config.items():
|
| 278 |
-
f.write(f"{key}: {value}\n")
|
| 279 |
-
|
| 280 |
-
if not suffix:
|
| 281 |
-
print(f"Model saved to {save_path}")
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
def main():
|
| 285 |
-
"""Main function for curriculum training."""
|
| 286 |
-
|
| 287 |
-
print("π INITIALIZING CURRICULUM TRAINER")
|
| 288 |
-
|
| 289 |
-
# Initialize trainer
|
| 290 |
-
trainer = CurriculumTrainer(
|
| 291 |
-
embedding_dim=128,
|
| 292 |
-
learning_rate=0.001,
|
| 293 |
-
use_focal_loss=True,
|
| 294 |
-
curriculum_stages=3
|
| 295 |
-
)
|
| 296 |
-
|
| 297 |
-
# Load data processor
|
| 298 |
-
print("Loading data processor...")
|
| 299 |
-
trainer.load_data_processor()
|
| 300 |
-
|
| 301 |
-
# Create improved model
|
| 302 |
-
print("Creating improved two-tower model...")
|
| 303 |
-
trainer.create_model()
|
| 304 |
-
|
| 305 |
-
# Load training data
|
| 306 |
-
print("Loading training data...")
|
| 307 |
-
with open("src/artifacts/training_features.pkl", 'rb') as f:
|
| 308 |
-
training_features = pickle.load(f)
|
| 309 |
-
|
| 310 |
-
with open("src/artifacts/validation_features.pkl", 'rb') as f:
|
| 311 |
-
validation_features = pickle.load(f)
|
| 312 |
-
|
| 313 |
-
print(f"Training samples: {len(training_features['rating'])}")
|
| 314 |
-
print(f"Validation samples: {len(validation_features['rating'])}")
|
| 315 |
-
|
| 316 |
-
# Train with curriculum learning
|
| 317 |
-
start_time = time.time()
|
| 318 |
-
|
| 319 |
-
history = trainer.train_with_curriculum(
|
| 320 |
-
training_features=training_features,
|
| 321 |
-
validation_features=validation_features,
|
| 322 |
-
epochs_per_stage=15,
|
| 323 |
-
batch_size=512
|
| 324 |
-
)
|
| 325 |
-
|
| 326 |
-
total_time = time.time() - start_time
|
| 327 |
-
|
| 328 |
-
# Save final model and history
|
| 329 |
-
print("Saving final model...")
|
| 330 |
-
trainer.save_model()
|
| 331 |
-
|
| 332 |
-
with open("src/artifacts/curriculum_training_history.pkl", 'wb') as f:
|
| 333 |
-
pickle.dump(history, f)
|
| 334 |
-
|
| 335 |
-
print(f"\nβ
CURRICULUM TRAINING COMPLETED!")
|
| 336 |
-
print(f"Total training time: {total_time:.1f}s")
|
| 337 |
-
print(f"All improvements implemented successfully!")
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
if __name__ == "__main__":
|
| 341 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/training/fast_joint_training.py
DELETED
|
@@ -1,268 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Fast joint training with key optimizations for CPU performance.
|
| 3 |
-
"""
|
| 4 |
-
import tensorflow as tf
|
| 5 |
-
import numpy as np
|
| 6 |
-
import pickle
|
| 7 |
-
import os
|
| 8 |
-
import time
|
| 9 |
-
from typing import Dict
|
| 10 |
-
|
| 11 |
-
from src.models.item_tower import ItemTower
|
| 12 |
-
from src.models.user_tower import UserTower, TwoTowerModel
|
| 13 |
-
from src.preprocessing.data_loader import DataProcessor
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class FastJointTrainer:
|
| 17 |
-
"""Simplified fast joint training optimized for CPU."""
|
| 18 |
-
|
| 19 |
-
def __init__(self):
|
| 20 |
-
self.item_tower = None
|
| 21 |
-
self.user_tower = None
|
| 22 |
-
self.model = None
|
| 23 |
-
|
| 24 |
-
# Optimized hyperparameters for fast training
|
| 25 |
-
self.user_lr = 0.003
|
| 26 |
-
self.item_lr = 0.0003
|
| 27 |
-
self.batch_size = 2048 # Large batch for efficiency
|
| 28 |
-
self.epochs = 20 # Reduced epochs
|
| 29 |
-
|
| 30 |
-
def load_components(self):
|
| 31 |
-
"""Load all required components."""
|
| 32 |
-
print("Loading components...")
|
| 33 |
-
|
| 34 |
-
# Load data processor
|
| 35 |
-
data_processor = DataProcessor()
|
| 36 |
-
data_processor.load_vocabularies("src/artifacts/vocabularies.pkl")
|
| 37 |
-
|
| 38 |
-
# Load item tower config
|
| 39 |
-
with open("src/artifacts/item_tower_config.txt", 'r') as f:
|
| 40 |
-
config = {}
|
| 41 |
-
for line in f:
|
| 42 |
-
key, value = line.strip().split(': ')
|
| 43 |
-
if key in ['embedding_dim', 'dropout_rate']:
|
| 44 |
-
config[key] = float(value) if '.' in value else int(value)
|
| 45 |
-
elif key == 'hidden_dims':
|
| 46 |
-
config[key] = eval(value)
|
| 47 |
-
|
| 48 |
-
# Build item tower
|
| 49 |
-
self.item_tower = ItemTower(
|
| 50 |
-
item_vocab_size=len(data_processor.item_vocab),
|
| 51 |
-
category_vocab_size=len(data_processor.category_vocab),
|
| 52 |
-
brand_vocab_size=len(data_processor.brand_vocab),
|
| 53 |
-
**config
|
| 54 |
-
)
|
| 55 |
-
|
| 56 |
-
# Initialize and load weights
|
| 57 |
-
dummy_input = {
|
| 58 |
-
'product_id': tf.constant([0]),
|
| 59 |
-
'category_id': tf.constant([0]),
|
| 60 |
-
'brand_id': tf.constant([0]),
|
| 61 |
-
'price': tf.constant([0.0])
|
| 62 |
-
}
|
| 63 |
-
_ = self.item_tower(dummy_input)
|
| 64 |
-
self.item_tower.load_weights("src/artifacts/item_tower_weights")
|
| 65 |
-
|
| 66 |
-
# Build user tower (simplified)
|
| 67 |
-
self.user_tower = UserTower(
|
| 68 |
-
max_history_length=50,
|
| 69 |
-
embedding_dim=128, # Updated to 128D
|
| 70 |
-
hidden_dims=[64], # Simplified architecture
|
| 71 |
-
dropout_rate=0.1
|
| 72 |
-
)
|
| 73 |
-
|
| 74 |
-
# Build complete model
|
| 75 |
-
self.model = TwoTowerModel(
|
| 76 |
-
item_tower=self.item_tower,
|
| 77 |
-
user_tower=self.user_tower,
|
| 78 |
-
rating_weight=1.0,
|
| 79 |
-
retrieval_weight=0.2 # Reduced for faster training
|
| 80 |
-
)
|
| 81 |
-
|
| 82 |
-
print("Components loaded successfully")
|
| 83 |
-
|
| 84 |
-
def create_fast_dataset(self, features: Dict, is_training: bool = True):
|
| 85 |
-
"""Create optimized dataset pipeline."""
|
| 86 |
-
dataset = tf.data.Dataset.from_tensor_slices(features)
|
| 87 |
-
|
| 88 |
-
if is_training:
|
| 89 |
-
dataset = dataset.shuffle(buffer_size=5000)
|
| 90 |
-
dataset = dataset.repeat()
|
| 91 |
-
|
| 92 |
-
dataset = dataset.batch(self.batch_size, drop_remainder=True)
|
| 93 |
-
dataset = dataset.prefetch(2) # Conservative prefetch for CPU
|
| 94 |
-
|
| 95 |
-
return dataset
|
| 96 |
-
|
| 97 |
-
def train_fast(self, training_features: Dict, validation_features: Dict):
|
| 98 |
-
"""Fast training loop with key optimizations."""
|
| 99 |
-
|
| 100 |
-
print(f"Starting fast training: {self.epochs} epochs, batch size {self.batch_size}")
|
| 101 |
-
|
| 102 |
-
# Setup datasets
|
| 103 |
-
steps_per_epoch = len(training_features['rating']) // self.batch_size
|
| 104 |
-
val_steps = len(validation_features['rating']) // self.batch_size
|
| 105 |
-
|
| 106 |
-
train_ds = self.create_fast_dataset(training_features, is_training=True)
|
| 107 |
-
val_ds = self.create_fast_dataset(validation_features, is_training=False)
|
| 108 |
-
|
| 109 |
-
# Note: Age and income are now categorical - no normalization needed
|
| 110 |
-
|
| 111 |
-
# Setup optimizers
|
| 112 |
-
user_optimizer = tf.keras.optimizers.Adam(learning_rate=self.user_lr)
|
| 113 |
-
item_optimizer = tf.keras.optimizers.Adam(learning_rate=self.item_lr)
|
| 114 |
-
|
| 115 |
-
# Training loop
|
| 116 |
-
train_iter = iter(train_ds)
|
| 117 |
-
val_iter = iter(val_ds)
|
| 118 |
-
|
| 119 |
-
best_val_loss = float('inf')
|
| 120 |
-
|
| 121 |
-
for epoch in range(self.epochs):
|
| 122 |
-
epoch_start = time.time()
|
| 123 |
-
|
| 124 |
-
# Progressive unfreezing - simple strategy
|
| 125 |
-
train_item = epoch >= (self.epochs // 4) # Unfreeze after 25%
|
| 126 |
-
|
| 127 |
-
print(f"Epoch {epoch+1}/{self.epochs} - Item training: {'ON' if train_item else 'OFF'}")
|
| 128 |
-
|
| 129 |
-
# Training
|
| 130 |
-
train_losses = []
|
| 131 |
-
for step in range(steps_per_epoch):
|
| 132 |
-
try:
|
| 133 |
-
batch = next(train_iter)
|
| 134 |
-
except StopIteration:
|
| 135 |
-
train_iter = iter(train_ds)
|
| 136 |
-
batch = next(train_iter)
|
| 137 |
-
|
| 138 |
-
with tf.GradientTape() as tape:
|
| 139 |
-
# Forward pass
|
| 140 |
-
user_emb = self.user_tower(batch, training=True)
|
| 141 |
-
item_emb = self.item_tower(batch, training=True)
|
| 142 |
-
|
| 143 |
-
# Rating prediction
|
| 144 |
-
concat_emb = tf.concat([user_emb, item_emb], axis=-1)
|
| 145 |
-
rating_pred = self.model.rating_model(concat_emb, training=True)
|
| 146 |
-
|
| 147 |
-
# Simple loss calculation
|
| 148 |
-
rating_loss = tf.keras.losses.binary_crossentropy(
|
| 149 |
-
batch["rating"], tf.squeeze(rating_pred)
|
| 150 |
-
)
|
| 151 |
-
rating_loss = tf.reduce_mean(rating_loss)
|
| 152 |
-
|
| 153 |
-
# Simplified retrieval loss
|
| 154 |
-
similarity = tf.reduce_sum(user_emb * item_emb, axis=1)
|
| 155 |
-
retrieval_loss = tf.keras.losses.binary_crossentropy(
|
| 156 |
-
batch["rating"], tf.nn.sigmoid(similarity)
|
| 157 |
-
)
|
| 158 |
-
retrieval_loss = tf.reduce_mean(retrieval_loss)
|
| 159 |
-
|
| 160 |
-
total_loss = rating_loss + 0.2 * retrieval_loss
|
| 161 |
-
|
| 162 |
-
# Gradient computation and application
|
| 163 |
-
if train_item:
|
| 164 |
-
# Train both towers
|
| 165 |
-
user_vars = self.user_tower.trainable_variables + self.model.rating_model.trainable_variables
|
| 166 |
-
item_vars = self.item_tower.trainable_variables
|
| 167 |
-
all_vars = user_vars + item_vars
|
| 168 |
-
|
| 169 |
-
grads = tape.gradient(total_loss, all_vars)
|
| 170 |
-
user_grads = grads[:len(user_vars)]
|
| 171 |
-
item_grads = grads[len(user_vars):]
|
| 172 |
-
|
| 173 |
-
user_optimizer.apply_gradients(zip(user_grads, user_vars))
|
| 174 |
-
item_optimizer.apply_gradients(zip(item_grads, item_vars))
|
| 175 |
-
else:
|
| 176 |
-
# Train only user tower
|
| 177 |
-
user_vars = self.user_tower.trainable_variables + self.model.rating_model.trainable_variables
|
| 178 |
-
grads = tape.gradient(total_loss, user_vars)
|
| 179 |
-
user_optimizer.apply_gradients(zip(grads, user_vars))
|
| 180 |
-
|
| 181 |
-
train_losses.append(total_loss.numpy())
|
| 182 |
-
|
| 183 |
-
# Validation
|
| 184 |
-
val_losses = []
|
| 185 |
-
for step in range(val_steps):
|
| 186 |
-
try:
|
| 187 |
-
batch = next(val_iter)
|
| 188 |
-
except StopIteration:
|
| 189 |
-
val_iter = iter(val_ds)
|
| 190 |
-
batch = next(val_iter)
|
| 191 |
-
|
| 192 |
-
user_emb = self.user_tower(batch, training=False)
|
| 193 |
-
item_emb = self.item_tower(batch, training=False)
|
| 194 |
-
|
| 195 |
-
concat_emb = tf.concat([user_emb, item_emb], axis=-1)
|
| 196 |
-
rating_pred = self.model.rating_model(concat_emb, training=False)
|
| 197 |
-
|
| 198 |
-
rating_loss = tf.reduce_mean(
|
| 199 |
-
tf.keras.losses.binary_crossentropy(batch["rating"], tf.squeeze(rating_pred))
|
| 200 |
-
)
|
| 201 |
-
|
| 202 |
-
similarity = tf.reduce_sum(user_emb * item_emb, axis=1)
|
| 203 |
-
retrieval_loss = tf.reduce_mean(
|
| 204 |
-
tf.keras.losses.binary_crossentropy(batch["rating"], tf.nn.sigmoid(similarity))
|
| 205 |
-
)
|
| 206 |
-
|
| 207 |
-
total_loss = rating_loss + 0.2 * retrieval_loss
|
| 208 |
-
val_losses.append(total_loss.numpy())
|
| 209 |
-
|
| 210 |
-
# Calculate averages
|
| 211 |
-
avg_train_loss = np.mean(train_losses)
|
| 212 |
-
avg_val_loss = np.mean(val_losses)
|
| 213 |
-
epoch_time = time.time() - epoch_start
|
| 214 |
-
|
| 215 |
-
print(f"Time: {epoch_time:.1f}s | Train: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f}")
|
| 216 |
-
|
| 217 |
-
# Save best model
|
| 218 |
-
if avg_val_loss < best_val_loss:
|
| 219 |
-
best_val_loss = avg_val_loss
|
| 220 |
-
self.save_model("_best")
|
| 221 |
-
|
| 222 |
-
print("Fast training completed!")
|
| 223 |
-
|
| 224 |
-
def save_model(self, suffix=""):
|
| 225 |
-
"""Save trained model."""
|
| 226 |
-
save_path = "src/artifacts/"
|
| 227 |
-
|
| 228 |
-
self.user_tower.save_weights(f"{save_path}/user_tower_weights{suffix}")
|
| 229 |
-
self.item_tower.save_weights(f"{save_path}/item_tower_weights_finetuned{suffix}")
|
| 230 |
-
self.model.rating_model.save_weights(f"{save_path}/rating_model_weights{suffix}")
|
| 231 |
-
|
| 232 |
-
if not suffix:
|
| 233 |
-
print("Model saved successfully")
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
def main():
|
| 237 |
-
"""Main function for fast joint training."""
|
| 238 |
-
|
| 239 |
-
print("=== Fast Joint Training ===")
|
| 240 |
-
|
| 241 |
-
# Initialize trainer
|
| 242 |
-
trainer = FastJointTrainer()
|
| 243 |
-
trainer.load_components()
|
| 244 |
-
|
| 245 |
-
# Load training data
|
| 246 |
-
print("Loading training data...")
|
| 247 |
-
with open("src/artifacts/training_features.pkl", 'rb') as f:
|
| 248 |
-
training_features = pickle.load(f)
|
| 249 |
-
|
| 250 |
-
with open("src/artifacts/validation_features.pkl", 'rb') as f:
|
| 251 |
-
validation_features = pickle.load(f)
|
| 252 |
-
|
| 253 |
-
print(f"Training samples: {len(training_features['rating']):,}")
|
| 254 |
-
print(f"Validation samples: {len(validation_features['rating']):,}")
|
| 255 |
-
|
| 256 |
-
# Start training
|
| 257 |
-
start_time = time.time()
|
| 258 |
-
trainer.train_fast(training_features, validation_features)
|
| 259 |
-
|
| 260 |
-
total_time = time.time() - start_time
|
| 261 |
-
trainer.save_model()
|
| 262 |
-
|
| 263 |
-
print(f"\\nTraining completed in {total_time:.1f} seconds!")
|
| 264 |
-
print(f"Average time per epoch: {total_time/trainer.epochs:.1f}s")
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
if __name__ == "__main__":
|
| 268 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/training/improved_joint_training.py
DELETED
|
@@ -1,462 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""
|
| 3 |
-
Improved joint training with hard negative mining, curriculum learning, and better optimization.
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
import tensorflow as tf
|
| 7 |
-
import numpy as np
|
| 8 |
-
import pickle
|
| 9 |
-
import os
|
| 10 |
-
from typing import Dict, List, Tuple, Optional
|
| 11 |
-
import time
|
| 12 |
-
from collections import defaultdict
|
| 13 |
-
|
| 14 |
-
from src.models.improved_two_tower import create_improved_model
|
| 15 |
-
from src.preprocessing.data_loader import DataProcessor, create_tf_dataset
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
class HardNegativeSampler:
|
| 19 |
-
"""Hard negative sampling strategy for better training."""
|
| 20 |
-
|
| 21 |
-
def __init__(self, model, item_embeddings, sampling_strategy='mixed'):
|
| 22 |
-
self.model = model
|
| 23 |
-
self.item_embeddings = item_embeddings # Pre-computed item embeddings
|
| 24 |
-
self.sampling_strategy = sampling_strategy
|
| 25 |
-
|
| 26 |
-
def sample_hard_negatives(self, user_embeddings, positive_items, k_hard=2, k_random=2):
|
| 27 |
-
"""Sample hard negatives based on user-item similarity."""
|
| 28 |
-
batch_size = tf.shape(user_embeddings)[0]
|
| 29 |
-
|
| 30 |
-
# Compute similarities between users and all items
|
| 31 |
-
similarities = tf.linalg.matmul(user_embeddings, self.item_embeddings, transpose_b=True)
|
| 32 |
-
|
| 33 |
-
# Mask out positive items
|
| 34 |
-
positive_mask = tf.one_hot(positive_items, depth=tf.shape(self.item_embeddings)[0])
|
| 35 |
-
similarities = similarities - positive_mask * 1e9 # Large negative value
|
| 36 |
-
|
| 37 |
-
# Get top-k similar items (hard negatives)
|
| 38 |
-
_, hard_negative_indices = tf.nn.top_k(similarities, k=k_hard)
|
| 39 |
-
|
| 40 |
-
# Sample random negatives
|
| 41 |
-
total_items = tf.shape(self.item_embeddings)[0]
|
| 42 |
-
random_negatives = tf.random.uniform(
|
| 43 |
-
shape=[batch_size, k_random],
|
| 44 |
-
minval=0,
|
| 45 |
-
maxval=total_items,
|
| 46 |
-
dtype=tf.int32
|
| 47 |
-
)
|
| 48 |
-
|
| 49 |
-
# Combine hard and random negatives
|
| 50 |
-
if self.sampling_strategy == 'hard':
|
| 51 |
-
return hard_negative_indices
|
| 52 |
-
elif self.sampling_strategy == 'random':
|
| 53 |
-
return random_negatives
|
| 54 |
-
else: # mixed
|
| 55 |
-
return tf.concat([hard_negative_indices, random_negatives], axis=1)
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
class CurriculumLearningScheduler:
|
| 59 |
-
"""Curriculum learning scheduler for progressive difficulty."""
|
| 60 |
-
|
| 61 |
-
def __init__(self, total_epochs, warmup_epochs=10):
|
| 62 |
-
self.total_epochs = total_epochs
|
| 63 |
-
self.warmup_epochs = warmup_epochs
|
| 64 |
-
|
| 65 |
-
def get_difficulty_schedule(self, epoch):
|
| 66 |
-
"""Get curriculum parameters for current epoch."""
|
| 67 |
-
if epoch < self.warmup_epochs:
|
| 68 |
-
# Easy phase: more random negatives, lower temperature
|
| 69 |
-
hard_negative_ratio = 0.2
|
| 70 |
-
temperature = 2.0
|
| 71 |
-
negative_samples = 2
|
| 72 |
-
elif epoch < self.total_epochs * 0.6:
|
| 73 |
-
# Medium phase: balanced negatives
|
| 74 |
-
hard_negative_ratio = 0.5
|
| 75 |
-
temperature = 1.0
|
| 76 |
-
negative_samples = 4
|
| 77 |
-
else:
|
| 78 |
-
# Hard phase: more hard negatives, higher temperature
|
| 79 |
-
hard_negative_ratio = 0.8
|
| 80 |
-
temperature = 0.5
|
| 81 |
-
negative_samples = 6
|
| 82 |
-
|
| 83 |
-
return {
|
| 84 |
-
'hard_negative_ratio': hard_negative_ratio,
|
| 85 |
-
'temperature': temperature,
|
| 86 |
-
'negative_samples': negative_samples
|
| 87 |
-
}
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
class ImprovedJointTrainer:
|
| 91 |
-
"""Enhanced joint trainer with advanced techniques."""
|
| 92 |
-
|
| 93 |
-
def __init__(self,
|
| 94 |
-
embedding_dim: int = 128,
|
| 95 |
-
learning_rate: float = 0.001,
|
| 96 |
-
use_mixed_precision: bool = True,
|
| 97 |
-
use_curriculum_learning: bool = True,
|
| 98 |
-
use_hard_negatives: bool = True):
|
| 99 |
-
|
| 100 |
-
self.embedding_dim = embedding_dim
|
| 101 |
-
self.learning_rate = learning_rate
|
| 102 |
-
self.use_mixed_precision = use_mixed_precision
|
| 103 |
-
self.use_curriculum_learning = use_curriculum_learning
|
| 104 |
-
self.use_hard_negatives = use_hard_negatives
|
| 105 |
-
|
| 106 |
-
# Enable mixed precision if requested
|
| 107 |
-
if use_mixed_precision:
|
| 108 |
-
policy = tf.keras.mixed_precision.Policy('mixed_float16')
|
| 109 |
-
tf.keras.mixed_precision.set_global_policy(policy)
|
| 110 |
-
|
| 111 |
-
self.model = None
|
| 112 |
-
self.data_processor = None
|
| 113 |
-
self.curriculum_scheduler = None
|
| 114 |
-
self.hard_negative_sampler = None
|
| 115 |
-
|
| 116 |
-
def setup_model(self, data_processor: DataProcessor):
|
| 117 |
-
"""Setup the improved model."""
|
| 118 |
-
self.data_processor = data_processor
|
| 119 |
-
|
| 120 |
-
# Create improved model
|
| 121 |
-
self.model = create_improved_model(
|
| 122 |
-
data_processor=data_processor,
|
| 123 |
-
embedding_dim=self.embedding_dim,
|
| 124 |
-
use_bias=True,
|
| 125 |
-
use_focal_loss=True
|
| 126 |
-
)
|
| 127 |
-
|
| 128 |
-
print(f"Created improved two-tower model with {self.embedding_dim}D embeddings")
|
| 129 |
-
|
| 130 |
-
def setup_curriculum_learning(self, total_epochs: int):
|
| 131 |
-
"""Setup curriculum learning scheduler."""
|
| 132 |
-
if self.use_curriculum_learning:
|
| 133 |
-
self.curriculum_scheduler = CurriculumLearningScheduler(
|
| 134 |
-
total_epochs=total_epochs,
|
| 135 |
-
warmup_epochs=max(5, total_epochs // 10)
|
| 136 |
-
)
|
| 137 |
-
print("Curriculum learning enabled")
|
| 138 |
-
|
| 139 |
-
def setup_hard_negative_sampling(self, item_features: Dict[str, np.ndarray]):
|
| 140 |
-
"""Setup hard negative sampling."""
|
| 141 |
-
if self.use_hard_negatives:
|
| 142 |
-
# Pre-compute item embeddings for efficient hard negative sampling
|
| 143 |
-
print("Pre-computing item embeddings for hard negative sampling...")
|
| 144 |
-
|
| 145 |
-
# Create a dummy batch to get item embeddings
|
| 146 |
-
batch_size = 1000
|
| 147 |
-
total_items = len(item_features['product_id'])
|
| 148 |
-
|
| 149 |
-
item_embeddings_list = []
|
| 150 |
-
for i in range(0, total_items, batch_size):
|
| 151 |
-
end_idx = min(i + batch_size, total_items)
|
| 152 |
-
batch_features = {
|
| 153 |
-
key: tf.constant(value[i:end_idx])
|
| 154 |
-
for key, value in item_features.items()
|
| 155 |
-
}
|
| 156 |
-
|
| 157 |
-
item_emb_output = self.model.item_tower(batch_features, training=False)
|
| 158 |
-
if isinstance(item_emb_output, tuple):
|
| 159 |
-
item_emb = item_emb_output[0] # Get embeddings, ignore bias
|
| 160 |
-
else:
|
| 161 |
-
item_emb = item_emb_output
|
| 162 |
-
|
| 163 |
-
item_embeddings_list.append(item_emb.numpy())
|
| 164 |
-
|
| 165 |
-
item_embeddings = np.vstack(item_embeddings_list)
|
| 166 |
-
|
| 167 |
-
self.hard_negative_sampler = HardNegativeSampler(
|
| 168 |
-
model=self.model,
|
| 169 |
-
item_embeddings=tf.constant(item_embeddings, dtype=tf.float32),
|
| 170 |
-
sampling_strategy='mixed'
|
| 171 |
-
)
|
| 172 |
-
print(f"Hard negative sampling enabled with {len(item_embeddings)} items")
|
| 173 |
-
|
| 174 |
-
def create_advanced_training_dataset(self,
|
| 175 |
-
features: Dict[str, np.ndarray],
|
| 176 |
-
batch_size: int = 256,
|
| 177 |
-
epoch: int = 0) -> tf.data.Dataset:
|
| 178 |
-
"""Create training dataset with curriculum learning and hard negatives."""
|
| 179 |
-
|
| 180 |
-
# Get curriculum parameters
|
| 181 |
-
if self.curriculum_scheduler:
|
| 182 |
-
curriculum_params = self.curriculum_scheduler.get_difficulty_schedule(epoch)
|
| 183 |
-
print(f"Epoch {epoch}: {curriculum_params}")
|
| 184 |
-
else:
|
| 185 |
-
curriculum_params = {
|
| 186 |
-
'hard_negative_ratio': 0.5,
|
| 187 |
-
'temperature': 1.0,
|
| 188 |
-
'negative_samples': 4
|
| 189 |
-
}
|
| 190 |
-
|
| 191 |
-
# Filter data based on curriculum (start with easier examples)
|
| 192 |
-
if epoch < 5: # Warmup epochs - use only high-confidence positive examples
|
| 193 |
-
positive_mask = features['rating'] == 1.0
|
| 194 |
-
if np.sum(positive_mask) > 0:
|
| 195 |
-
# Sample subset of positives and all negatives
|
| 196 |
-
positive_indices = np.where(positive_mask)[0]
|
| 197 |
-
negative_indices = np.where(features['rating'] == 0.0)[0]
|
| 198 |
-
|
| 199 |
-
# Sample subset for easier learning
|
| 200 |
-
n_positive_samples = min(len(positive_indices), len(negative_indices))
|
| 201 |
-
selected_positive = np.random.choice(
|
| 202 |
-
positive_indices, size=n_positive_samples, replace=False
|
| 203 |
-
)
|
| 204 |
-
selected_negative = np.random.choice(
|
| 205 |
-
negative_indices, size=n_positive_samples, replace=False
|
| 206 |
-
)
|
| 207 |
-
|
| 208 |
-
selected_indices = np.concatenate([selected_positive, selected_negative])
|
| 209 |
-
np.random.shuffle(selected_indices)
|
| 210 |
-
|
| 211 |
-
# Filter features
|
| 212 |
-
filtered_features = {
|
| 213 |
-
key: value[selected_indices] for key, value in features.items()
|
| 214 |
-
}
|
| 215 |
-
else:
|
| 216 |
-
filtered_features = features
|
| 217 |
-
else:
|
| 218 |
-
filtered_features = features
|
| 219 |
-
|
| 220 |
-
# Create dataset
|
| 221 |
-
dataset = create_tf_dataset(filtered_features, batch_size, shuffle=True)
|
| 222 |
-
|
| 223 |
-
return dataset
|
| 224 |
-
|
| 225 |
-
def compile_model(self):
|
| 226 |
-
"""Compile model with advanced optimizer."""
|
| 227 |
-
# Use AdamW with learning rate scheduling
|
| 228 |
-
initial_learning_rate = self.learning_rate
|
| 229 |
-
lr_schedule = tf.keras.optimizers.schedules.CosineDecayRestarts(
|
| 230 |
-
initial_learning_rate=initial_learning_rate,
|
| 231 |
-
first_decay_steps=1000,
|
| 232 |
-
t_mul=2.0,
|
| 233 |
-
m_mul=0.9,
|
| 234 |
-
alpha=0.01
|
| 235 |
-
)
|
| 236 |
-
|
| 237 |
-
optimizer = tf.keras.optimizers.AdamW(
|
| 238 |
-
learning_rate=lr_schedule,
|
| 239 |
-
weight_decay=1e-5,
|
| 240 |
-
beta_1=0.9,
|
| 241 |
-
beta_2=0.999,
|
| 242 |
-
epsilon=1e-7
|
| 243 |
-
)
|
| 244 |
-
|
| 245 |
-
# Enable mixed precision optimizer if needed
|
| 246 |
-
if self.use_mixed_precision:
|
| 247 |
-
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
|
| 248 |
-
|
| 249 |
-
self.optimizer = optimizer
|
| 250 |
-
print(f"Model compiled with AdamW optimizer (lr={self.learning_rate})")
|
| 251 |
-
|
| 252 |
-
@tf.function
|
| 253 |
-
def train_step(self, features):
|
| 254 |
-
"""Optimized training step with gradient scaling."""
|
| 255 |
-
with tf.GradientTape() as tape:
|
| 256 |
-
# Forward pass
|
| 257 |
-
loss_dict = self.model.compute_loss(features, training=True)
|
| 258 |
-
total_loss = loss_dict['total_loss']
|
| 259 |
-
|
| 260 |
-
# Scale loss for mixed precision
|
| 261 |
-
if self.use_mixed_precision:
|
| 262 |
-
scaled_loss = self.optimizer.get_scaled_loss(total_loss)
|
| 263 |
-
else:
|
| 264 |
-
scaled_loss = total_loss
|
| 265 |
-
|
| 266 |
-
# Compute gradients
|
| 267 |
-
if self.use_mixed_precision:
|
| 268 |
-
scaled_gradients = tape.gradient(scaled_loss, self.model.trainable_variables)
|
| 269 |
-
gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)
|
| 270 |
-
else:
|
| 271 |
-
gradients = tape.gradient(scaled_loss, self.model.trainable_variables)
|
| 272 |
-
|
| 273 |
-
# Clip gradients to prevent exploding gradients
|
| 274 |
-
gradients, _ = tf.clip_by_global_norm(gradients, 1.0)
|
| 275 |
-
|
| 276 |
-
# Apply gradients
|
| 277 |
-
self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
|
| 278 |
-
|
| 279 |
-
return loss_dict
|
| 280 |
-
|
| 281 |
-
def evaluate_model(self, validation_dataset):
|
| 282 |
-
"""Evaluate model on validation set."""
|
| 283 |
-
total_losses = defaultdict(list)
|
| 284 |
-
|
| 285 |
-
for batch in validation_dataset:
|
| 286 |
-
loss_dict = self.model.compute_loss(batch, training=False)
|
| 287 |
-
for key, value in loss_dict.items():
|
| 288 |
-
total_losses[key].append(float(value))
|
| 289 |
-
|
| 290 |
-
# Average losses
|
| 291 |
-
avg_losses = {key: np.mean(values) for key, values in total_losses.items()}
|
| 292 |
-
return avg_losses
|
| 293 |
-
|
| 294 |
-
def train(self,
|
| 295 |
-
training_features: Dict[str, np.ndarray],
|
| 296 |
-
validation_features: Dict[str, np.ndarray],
|
| 297 |
-
epochs: int = 50,
|
| 298 |
-
batch_size: int = 256,
|
| 299 |
-
save_path: str = "src/artifacts/") -> Dict:
|
| 300 |
-
"""Enhanced training loop with all improvements."""
|
| 301 |
-
|
| 302 |
-
print(f"Starting improved training for {epochs} epochs...")
|
| 303 |
-
|
| 304 |
-
# Setup components
|
| 305 |
-
self.setup_curriculum_learning(epochs)
|
| 306 |
-
self.compile_model()
|
| 307 |
-
|
| 308 |
-
# Create validation dataset
|
| 309 |
-
validation_dataset = create_tf_dataset(validation_features, batch_size, shuffle=False)
|
| 310 |
-
|
| 311 |
-
# Training history
|
| 312 |
-
history = defaultdict(list)
|
| 313 |
-
best_val_loss = float('inf')
|
| 314 |
-
patience_counter = 0
|
| 315 |
-
early_stopping_patience = 10
|
| 316 |
-
|
| 317 |
-
# Training loop
|
| 318 |
-
for epoch in range(epochs):
|
| 319 |
-
epoch_start_time = time.time()
|
| 320 |
-
|
| 321 |
-
# Create training dataset for this epoch (curriculum learning)
|
| 322 |
-
training_dataset = self.create_advanced_training_dataset(
|
| 323 |
-
training_features, batch_size, epoch
|
| 324 |
-
)
|
| 325 |
-
|
| 326 |
-
# Training
|
| 327 |
-
epoch_losses = defaultdict(list)
|
| 328 |
-
num_batches = 0
|
| 329 |
-
|
| 330 |
-
for batch in training_dataset:
|
| 331 |
-
loss_dict = self.train_step(batch)
|
| 332 |
-
|
| 333 |
-
for key, value in loss_dict.items():
|
| 334 |
-
epoch_losses[key].append(float(value))
|
| 335 |
-
num_batches += 1
|
| 336 |
-
|
| 337 |
-
# Average training losses
|
| 338 |
-
avg_train_losses = {
|
| 339 |
-
key: np.mean(values) for key, values in epoch_losses.items()
|
| 340 |
-
}
|
| 341 |
-
|
| 342 |
-
# Validation
|
| 343 |
-
avg_val_losses = self.evaluate_model(validation_dataset)
|
| 344 |
-
|
| 345 |
-
# Log progress
|
| 346 |
-
epoch_time = time.time() - epoch_start_time
|
| 347 |
-
print(f"Epoch {epoch+1}/{epochs} ({epoch_time:.1f}s):")
|
| 348 |
-
print(f" Train Loss: {avg_train_losses['total_loss']:.4f}")
|
| 349 |
-
print(f" Val Loss: {avg_val_losses['total_loss']:.4f}")
|
| 350 |
-
print(f" Val Rating Loss: {avg_val_losses['rating_loss']:.4f}")
|
| 351 |
-
print(f" Val Retrieval Loss: {avg_val_losses['retrieval_loss']:.4f}")
|
| 352 |
-
|
| 353 |
-
# Save history
|
| 354 |
-
for key, value in avg_train_losses.items():
|
| 355 |
-
history[f'train_{key}'].append(value)
|
| 356 |
-
for key, value in avg_val_losses.items():
|
| 357 |
-
history[f'val_{key}'].append(value)
|
| 358 |
-
|
| 359 |
-
# Early stopping and model saving
|
| 360 |
-
current_val_loss = avg_val_losses['total_loss']
|
| 361 |
-
if current_val_loss < best_val_loss:
|
| 362 |
-
best_val_loss = current_val_loss
|
| 363 |
-
patience_counter = 0
|
| 364 |
-
|
| 365 |
-
# Save best model
|
| 366 |
-
self.save_model(save_path, suffix='_improved_best')
|
| 367 |
-
print(f" πΎ Saved best model (val_loss: {best_val_loss:.4f})")
|
| 368 |
-
else:
|
| 369 |
-
patience_counter += 1
|
| 370 |
-
|
| 371 |
-
if patience_counter >= early_stopping_patience:
|
| 372 |
-
print(f"Early stopping at epoch {epoch+1}")
|
| 373 |
-
break
|
| 374 |
-
|
| 375 |
-
# Save final model and history
|
| 376 |
-
self.save_model(save_path, suffix='_improved_final')
|
| 377 |
-
self.save_training_history(dict(history), save_path)
|
| 378 |
-
|
| 379 |
-
print("β
Improved training completed!")
|
| 380 |
-
return dict(history)
|
| 381 |
-
|
| 382 |
-
def save_model(self, save_path: str, suffix: str = ''):
|
| 383 |
-
"""Save the trained model components."""
|
| 384 |
-
os.makedirs(save_path, exist_ok=True)
|
| 385 |
-
|
| 386 |
-
# Save model weights
|
| 387 |
-
self.model.item_tower.save_weights(f"{save_path}/improved_item_tower_weights{suffix}")
|
| 388 |
-
self.model.user_tower.save_weights(f"{save_path}/improved_user_tower_weights{suffix}")
|
| 389 |
-
|
| 390 |
-
if hasattr(self.model, 'rating_model'):
|
| 391 |
-
self.model.rating_model.save_weights(f"{save_path}/improved_rating_model_weights{suffix}")
|
| 392 |
-
|
| 393 |
-
# Save configuration
|
| 394 |
-
config = {
|
| 395 |
-
'embedding_dim': self.embedding_dim,
|
| 396 |
-
'learning_rate': self.learning_rate,
|
| 397 |
-
'use_mixed_precision': self.use_mixed_precision,
|
| 398 |
-
'use_curriculum_learning': self.use_curriculum_learning,
|
| 399 |
-
'use_hard_negatives': self.use_hard_negatives
|
| 400 |
-
}
|
| 401 |
-
|
| 402 |
-
with open(f"{save_path}/improved_model_config{suffix}.txt", 'w') as f:
|
| 403 |
-
for key, value in config.items():
|
| 404 |
-
f.write(f"{key}: {value}\n")
|
| 405 |
-
|
| 406 |
-
print(f"Model saved to {save_path} with suffix '{suffix}'")
|
| 407 |
-
|
| 408 |
-
def save_training_history(self, history: Dict, save_path: str):
|
| 409 |
-
"""Save training history."""
|
| 410 |
-
with open(f"{save_path}/improved_training_history.pkl", 'wb') as f:
|
| 411 |
-
pickle.dump(history, f)
|
| 412 |
-
print(f"Training history saved to {save_path}")
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
def main():
|
| 416 |
-
"""Demo of improved training."""
|
| 417 |
-
print("π IMPROVED TWO-TOWER TRAINING DEMO")
|
| 418 |
-
print("="*60)
|
| 419 |
-
|
| 420 |
-
# Load data
|
| 421 |
-
print("Loading training data...")
|
| 422 |
-
try:
|
| 423 |
-
with open("src/artifacts/training_features.pkl", 'rb') as f:
|
| 424 |
-
training_features = pickle.load(f)
|
| 425 |
-
with open("src/artifacts/validation_features.pkl", 'rb') as f:
|
| 426 |
-
validation_features = pickle.load(f)
|
| 427 |
-
|
| 428 |
-
print(f"Loaded {len(training_features['rating'])} training samples")
|
| 429 |
-
print(f"Loaded {len(validation_features['rating'])} validation samples")
|
| 430 |
-
except FileNotFoundError:
|
| 431 |
-
print("β Training data not found. Please run data preparation first.")
|
| 432 |
-
return
|
| 433 |
-
|
| 434 |
-
# Load data processor
|
| 435 |
-
data_processor = DataProcessor()
|
| 436 |
-
data_processor.load_vocabularies("src/artifacts/vocabularies.pkl")
|
| 437 |
-
|
| 438 |
-
# Create trainer
|
| 439 |
-
trainer = ImprovedJointTrainer(
|
| 440 |
-
embedding_dim=128,
|
| 441 |
-
learning_rate=0.001,
|
| 442 |
-
use_mixed_precision=True,
|
| 443 |
-
use_curriculum_learning=True,
|
| 444 |
-
use_hard_negatives=True
|
| 445 |
-
)
|
| 446 |
-
|
| 447 |
-
# Setup and train
|
| 448 |
-
trainer.setup_model(data_processor)
|
| 449 |
-
|
| 450 |
-
# Train model
|
| 451 |
-
history = trainer.train(
|
| 452 |
-
training_features=training_features,
|
| 453 |
-
validation_features=validation_features,
|
| 454 |
-
epochs=30,
|
| 455 |
-
batch_size=256
|
| 456 |
-
)
|
| 457 |
-
|
| 458 |
-
print("β
Improved training completed successfully!")
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
if __name__ == "__main__":
|
| 462 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/training/optimized_joint_training.py
DELETED
|
@@ -1,439 +0,0 @@
|
|
| 1 |
-
import tensorflow as tf
|
| 2 |
-
import numpy as np
|
| 3 |
-
import pickle
|
| 4 |
-
import os
|
| 5 |
-
import time
|
| 6 |
-
from typing import Dict, List, Tuple
|
| 7 |
-
|
| 8 |
-
from src.models.item_tower import ItemTower
|
| 9 |
-
from src.models.user_tower import UserTower, TwoTowerModel
|
| 10 |
-
from src.preprocessing.data_loader import DataProcessor, create_tf_dataset
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class OptimizedJointTrainer:
|
| 14 |
-
"""Optimized joint training with performance enhancements."""
|
| 15 |
-
|
| 16 |
-
def __init__(self,
|
| 17 |
-
embedding_dim: int = 128, # Updated to 128D output
|
| 18 |
-
user_learning_rate: float = 0.001,
|
| 19 |
-
item_learning_rate: float = 0.0001,
|
| 20 |
-
rating_weight: float = 1.0,
|
| 21 |
-
retrieval_weight: float = 1.0,
|
| 22 |
-
gradient_accumulation_steps: int = 1,
|
| 23 |
-
use_mixed_precision: bool = False): # Disabled for CPU training
|
| 24 |
-
|
| 25 |
-
self.embedding_dim = embedding_dim
|
| 26 |
-
self.user_learning_rate = user_learning_rate
|
| 27 |
-
self.item_learning_rate = item_learning_rate
|
| 28 |
-
self.rating_weight = rating_weight
|
| 29 |
-
self.retrieval_weight = retrieval_weight
|
| 30 |
-
self.gradient_accumulation_steps = gradient_accumulation_steps
|
| 31 |
-
self.use_mixed_precision = use_mixed_precision
|
| 32 |
-
|
| 33 |
-
# Enable mixed precision for faster training
|
| 34 |
-
if self.use_mixed_precision:
|
| 35 |
-
tf.keras.mixed_precision.set_global_policy('mixed_float16')
|
| 36 |
-
print("Mixed precision training enabled")
|
| 37 |
-
|
| 38 |
-
self.item_tower = None
|
| 39 |
-
self.user_tower = None
|
| 40 |
-
self.model = None
|
| 41 |
-
|
| 42 |
-
# Precompile TensorFlow functions for speed
|
| 43 |
-
self._compiled_train_step = None
|
| 44 |
-
self._compiled_val_step = None
|
| 45 |
-
|
| 46 |
-
def load_pre_trained_item_tower(self, artifacts_path: str = "src/artifacts/") -> ItemTower:
|
| 47 |
-
"""Load pre-trained item tower with optimizations."""
|
| 48 |
-
data_processor = DataProcessor()
|
| 49 |
-
data_processor.load_vocabularies(f"{artifacts_path}/vocabularies.pkl")
|
| 50 |
-
|
| 51 |
-
with open(f"{artifacts_path}/item_tower_config.txt", 'r') as f:
|
| 52 |
-
config = {}
|
| 53 |
-
for line in f:
|
| 54 |
-
key, value = line.strip().split(': ')
|
| 55 |
-
if key in ['embedding_dim', 'dropout_rate']:
|
| 56 |
-
config[key] = float(value) if '.' in value else int(value)
|
| 57 |
-
elif key == 'hidden_dims':
|
| 58 |
-
config[key] = eval(value)
|
| 59 |
-
|
| 60 |
-
self.item_tower = ItemTower(
|
| 61 |
-
item_vocab_size=len(data_processor.item_vocab),
|
| 62 |
-
category_vocab_size=len(data_processor.category_vocab),
|
| 63 |
-
brand_vocab_size=len(data_processor.brand_vocab),
|
| 64 |
-
**config
|
| 65 |
-
)
|
| 66 |
-
|
| 67 |
-
dummy_input = {
|
| 68 |
-
'product_id': tf.constant([0]),
|
| 69 |
-
'category_id': tf.constant([0]),
|
| 70 |
-
'brand_id': tf.constant([0]),
|
| 71 |
-
'price': tf.constant([0.0])
|
| 72 |
-
}
|
| 73 |
-
_ = self.item_tower(dummy_input)
|
| 74 |
-
self.item_tower.load_weights(f"{artifacts_path}/item_tower_weights")
|
| 75 |
-
|
| 76 |
-
print("Pre-trained item tower loaded successfully")
|
| 77 |
-
return self.item_tower
|
| 78 |
-
|
| 79 |
-
def build_user_tower(self, max_history_length: int = 50) -> UserTower:
|
| 80 |
-
"""Build user tower with optimizations."""
|
| 81 |
-
self.user_tower = UserTower(
|
| 82 |
-
max_history_length=max_history_length,
|
| 83 |
-
embedding_dim=self.embedding_dim,
|
| 84 |
-
hidden_dims=[128, 64],
|
| 85 |
-
dropout_rate=0.1 # Reduced dropout for faster training
|
| 86 |
-
)
|
| 87 |
-
|
| 88 |
-
print("User tower initialized")
|
| 89 |
-
return self.user_tower
|
| 90 |
-
|
| 91 |
-
def build_two_tower_model(self) -> TwoTowerModel:
|
| 92 |
-
"""Build complete two-tower model."""
|
| 93 |
-
if self.item_tower is None or self.user_tower is None:
|
| 94 |
-
raise ValueError("Both towers must be initialized first")
|
| 95 |
-
|
| 96 |
-
self.model = TwoTowerModel(
|
| 97 |
-
item_tower=self.item_tower,
|
| 98 |
-
user_tower=self.user_tower,
|
| 99 |
-
rating_weight=self.rating_weight,
|
| 100 |
-
retrieval_weight=self.retrieval_weight
|
| 101 |
-
)
|
| 102 |
-
|
| 103 |
-
print("Two-tower model built successfully")
|
| 104 |
-
return self.model
|
| 105 |
-
|
| 106 |
-
def create_optimized_dataset(self, features: Dict[str, np.ndarray],
|
| 107 |
-
batch_size: int,
|
| 108 |
-
is_training: bool = True) -> tf.data.Dataset:
|
| 109 |
-
"""Create optimized dataset pipeline for faster training."""
|
| 110 |
-
|
| 111 |
-
dataset = tf.data.Dataset.from_tensor_slices(features)
|
| 112 |
-
|
| 113 |
-
if is_training:
|
| 114 |
-
# Optimized shuffling and prefetching
|
| 115 |
-
dataset = dataset.shuffle(buffer_size=min(10000, len(features['rating'])))
|
| 116 |
-
dataset = dataset.repeat() # Repeat for multiple epochs
|
| 117 |
-
|
| 118 |
-
dataset = dataset.batch(batch_size, drop_remainder=True)
|
| 119 |
-
|
| 120 |
-
# Optimize for CPU training
|
| 121 |
-
dataset = dataset.prefetch(tf.data.AUTOTUNE)
|
| 122 |
-
|
| 123 |
-
return dataset
|
| 124 |
-
|
| 125 |
-
@tf.function(experimental_relax_shapes=True)
|
| 126 |
-
def optimized_train_step(self, batch: Dict[str, tf.Tensor],
|
| 127 |
-
user_optimizer: tf.keras.optimizers.Optimizer,
|
| 128 |
-
item_optimizer: tf.keras.optimizers.Optimizer,
|
| 129 |
-
train_item: bool) -> Dict[str, tf.Tensor]:
|
| 130 |
-
"""Optimized training step with tf.function compilation."""
|
| 131 |
-
|
| 132 |
-
with tf.GradientTape() as tape:
|
| 133 |
-
# Forward pass
|
| 134 |
-
user_embeddings = self.user_tower(batch, training=True)
|
| 135 |
-
item_embeddings = self.item_tower(batch, training=True)
|
| 136 |
-
|
| 137 |
-
# Concatenate and predict rating
|
| 138 |
-
concatenated = tf.concat([user_embeddings, item_embeddings], axis=-1)
|
| 139 |
-
rating_predictions = self.model.rating_model(concatenated, training=True)
|
| 140 |
-
|
| 141 |
-
# Compute losses - fix shape mismatch
|
| 142 |
-
rating_loss = tf.keras.losses.binary_crossentropy(
|
| 143 |
-
tf.expand_dims(batch["rating"], -1), rating_predictions
|
| 144 |
-
)
|
| 145 |
-
rating_loss = tf.reduce_mean(rating_loss)
|
| 146 |
-
|
| 147 |
-
# Retrieval loss - cosine similarity
|
| 148 |
-
user_norm = tf.nn.l2_normalize(user_embeddings, axis=1)
|
| 149 |
-
item_norm = tf.nn.l2_normalize(item_embeddings, axis=1)
|
| 150 |
-
similarities = tf.reduce_sum(user_norm * item_norm, axis=1)
|
| 151 |
-
|
| 152 |
-
retrieval_loss = tf.keras.losses.binary_crossentropy(
|
| 153 |
-
batch["rating"], tf.nn.sigmoid(similarities)
|
| 154 |
-
)
|
| 155 |
-
retrieval_loss = tf.reduce_mean(retrieval_loss)
|
| 156 |
-
|
| 157 |
-
total_loss = (
|
| 158 |
-
self.rating_weight * rating_loss +
|
| 159 |
-
self.retrieval_weight * retrieval_loss
|
| 160 |
-
)
|
| 161 |
-
|
| 162 |
-
# Handle mixed precision
|
| 163 |
-
if self.use_mixed_precision:
|
| 164 |
-
total_loss = user_optimizer.get_scaled_loss(total_loss)
|
| 165 |
-
|
| 166 |
-
# Compute gradients
|
| 167 |
-
user_vars = self.user_tower.trainable_variables + self.model.rating_model.trainable_variables
|
| 168 |
-
|
| 169 |
-
if train_item:
|
| 170 |
-
all_vars = user_vars + self.item_tower.trainable_variables
|
| 171 |
-
gradients = tape.gradient(total_loss, all_vars)
|
| 172 |
-
|
| 173 |
-
if self.use_mixed_precision:
|
| 174 |
-
gradients = user_optimizer.get_unscaled_gradients(gradients)
|
| 175 |
-
|
| 176 |
-
# Split gradients
|
| 177 |
-
user_grads = gradients[:len(user_vars)]
|
| 178 |
-
item_grads = gradients[len(user_vars):]
|
| 179 |
-
|
| 180 |
-
# Apply gradients
|
| 181 |
-
user_optimizer.apply_gradients(zip(user_grads, user_vars))
|
| 182 |
-
item_optimizer.apply_gradients(zip(item_grads, self.item_tower.trainable_variables))
|
| 183 |
-
else:
|
| 184 |
-
gradients = tape.gradient(total_loss, user_vars)
|
| 185 |
-
|
| 186 |
-
if self.use_mixed_precision:
|
| 187 |
-
gradients = user_optimizer.get_unscaled_gradients(gradients)
|
| 188 |
-
|
| 189 |
-
user_optimizer.apply_gradients(zip(gradients, user_vars))
|
| 190 |
-
|
| 191 |
-
# Convert back from scaled loss for logging
|
| 192 |
-
if self.use_mixed_precision:
|
| 193 |
-
total_loss = total_loss / user_optimizer.loss_scale
|
| 194 |
-
rating_loss = rating_loss / user_optimizer.loss_scale
|
| 195 |
-
retrieval_loss = retrieval_loss / user_optimizer.loss_scale
|
| 196 |
-
|
| 197 |
-
return {
|
| 198 |
-
'total_loss': total_loss,
|
| 199 |
-
'rating_loss': rating_loss,
|
| 200 |
-
'retrieval_loss': retrieval_loss
|
| 201 |
-
}
|
| 202 |
-
|
| 203 |
-
@tf.function(experimental_relax_shapes=True)
|
| 204 |
-
def optimized_val_step(self, batch: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
|
| 205 |
-
"""Optimized validation step."""
|
| 206 |
-
|
| 207 |
-
user_embeddings = self.user_tower(batch, training=False)
|
| 208 |
-
item_embeddings = self.item_tower(batch, training=False)
|
| 209 |
-
|
| 210 |
-
concatenated = tf.concat([user_embeddings, item_embeddings], axis=-1)
|
| 211 |
-
rating_predictions = self.model.rating_model(concatenated, training=False)
|
| 212 |
-
|
| 213 |
-
rating_loss = tf.reduce_mean(
|
| 214 |
-
tf.keras.losses.binary_crossentropy(tf.expand_dims(batch["rating"], -1), rating_predictions)
|
| 215 |
-
)
|
| 216 |
-
|
| 217 |
-
# Retrieval loss
|
| 218 |
-
user_norm = tf.nn.l2_normalize(user_embeddings, axis=1)
|
| 219 |
-
item_norm = tf.nn.l2_normalize(item_embeddings, axis=1)
|
| 220 |
-
similarities = tf.reduce_sum(user_norm * item_norm, axis=1)
|
| 221 |
-
|
| 222 |
-
retrieval_loss = tf.reduce_mean(
|
| 223 |
-
tf.keras.losses.binary_crossentropy(batch["rating"], tf.nn.sigmoid(similarities))
|
| 224 |
-
)
|
| 225 |
-
|
| 226 |
-
total_loss = self.rating_weight * rating_loss + self.retrieval_weight * retrieval_loss
|
| 227 |
-
|
| 228 |
-
return {
|
| 229 |
-
'total_loss': total_loss,
|
| 230 |
-
'rating_loss': rating_loss,
|
| 231 |
-
'retrieval_loss': retrieval_loss
|
| 232 |
-
}
|
| 233 |
-
|
| 234 |
-
def train(self,
|
| 235 |
-
training_features: Dict[str, np.ndarray],
|
| 236 |
-
validation_features: Dict[str, np.ndarray],
|
| 237 |
-
epochs: int = 50, # Reduced default epochs
|
| 238 |
-
batch_size: int = 512) -> Dict: # Larger batch size for efficiency
|
| 239 |
-
"""Optimized training loop."""
|
| 240 |
-
|
| 241 |
-
print(f"Starting optimized joint training for {epochs} epochs...")
|
| 242 |
-
print(f"Batch size: {batch_size}")
|
| 243 |
-
print(f"Mixed precision: {self.use_mixed_precision}")
|
| 244 |
-
|
| 245 |
-
# Create optimized datasets
|
| 246 |
-
steps_per_epoch = len(training_features['rating']) // batch_size
|
| 247 |
-
val_steps = len(validation_features['rating']) // batch_size
|
| 248 |
-
|
| 249 |
-
train_dataset = self.create_optimized_dataset(training_features, batch_size, is_training=True)
|
| 250 |
-
val_dataset = self.create_optimized_dataset(validation_features, batch_size, is_training=False)
|
| 251 |
-
|
| 252 |
-
# Note: Age and income are now categorical - no normalization needed
|
| 253 |
-
|
| 254 |
-
# Setup optimizers with mixed precision
|
| 255 |
-
if self.use_mixed_precision:
|
| 256 |
-
user_optimizer = tf.keras.optimizers.Adam(learning_rate=self.user_learning_rate)
|
| 257 |
-
user_optimizer = tf.keras.mixed_precision.LossScaleOptimizer(user_optimizer)
|
| 258 |
-
item_optimizer = tf.keras.optimizers.Adam(learning_rate=self.item_learning_rate)
|
| 259 |
-
item_optimizer = tf.keras.mixed_precision.LossScaleOptimizer(item_optimizer)
|
| 260 |
-
else:
|
| 261 |
-
user_optimizer = tf.keras.optimizers.Adam(learning_rate=self.user_learning_rate)
|
| 262 |
-
item_optimizer = tf.keras.optimizers.Adam(learning_rate=self.item_learning_rate)
|
| 263 |
-
|
| 264 |
-
# Training history
|
| 265 |
-
history = {
|
| 266 |
-
'total_loss': [], 'rating_loss': [], 'retrieval_loss': [],
|
| 267 |
-
'val_total_loss': [], 'val_rating_loss': [], 'val_retrieval_loss': [],
|
| 268 |
-
'epoch_times': []
|
| 269 |
-
}
|
| 270 |
-
|
| 271 |
-
best_val_loss = float('inf')
|
| 272 |
-
patience_counter = 0
|
| 273 |
-
patience = 10 # Reduced patience for faster training
|
| 274 |
-
|
| 275 |
-
train_iter = iter(train_dataset)
|
| 276 |
-
val_iter = iter(val_dataset)
|
| 277 |
-
|
| 278 |
-
for epoch in range(epochs):
|
| 279 |
-
epoch_start_time = time.time()
|
| 280 |
-
print(f"\\nEpoch {epoch + 1}/{epochs}")
|
| 281 |
-
|
| 282 |
-
# Determine training strategy
|
| 283 |
-
freeze_threshold = int(0.2 * epochs) # Reduced freeze period
|
| 284 |
-
train_item = epoch >= freeze_threshold
|
| 285 |
-
|
| 286 |
-
print(f"Training: User=β, Item={'β' if train_item else 'β'}")
|
| 287 |
-
|
| 288 |
-
# Training loop
|
| 289 |
-
epoch_losses = {'total_loss': [], 'rating_loss': [], 'retrieval_loss': []}
|
| 290 |
-
|
| 291 |
-
for step in range(steps_per_epoch):
|
| 292 |
-
try:
|
| 293 |
-
batch = next(train_iter)
|
| 294 |
-
except StopIteration:
|
| 295 |
-
train_iter = iter(train_dataset)
|
| 296 |
-
batch = next(train_iter)
|
| 297 |
-
|
| 298 |
-
losses = self.optimized_train_step(batch, user_optimizer, item_optimizer, train_item)
|
| 299 |
-
|
| 300 |
-
for key in epoch_losses:
|
| 301 |
-
epoch_losses[key].append(losses[key])
|
| 302 |
-
|
| 303 |
-
# Calculate training averages
|
| 304 |
-
avg_train_losses = {k: tf.reduce_mean(v).numpy() for k, v in epoch_losses.items()}
|
| 305 |
-
|
| 306 |
-
# Validation loop
|
| 307 |
-
val_losses = {'total_loss': [], 'rating_loss': [], 'retrieval_loss': []}
|
| 308 |
-
|
| 309 |
-
for step in range(val_steps):
|
| 310 |
-
try:
|
| 311 |
-
batch = next(val_iter)
|
| 312 |
-
except StopIteration:
|
| 313 |
-
val_iter = iter(val_dataset)
|
| 314 |
-
batch = next(val_iter)
|
| 315 |
-
|
| 316 |
-
losses = self.optimized_val_step(batch)
|
| 317 |
-
|
| 318 |
-
for key in val_losses:
|
| 319 |
-
val_losses[key].append(losses[key])
|
| 320 |
-
|
| 321 |
-
avg_val_losses = {k: tf.reduce_mean(v).numpy() for k, v in val_losses.items()}
|
| 322 |
-
|
| 323 |
-
# Record history
|
| 324 |
-
epoch_time = time.time() - epoch_start_time
|
| 325 |
-
history['epoch_times'].append(epoch_time)
|
| 326 |
-
|
| 327 |
-
for key in ['total_loss', 'rating_loss', 'retrieval_loss']:
|
| 328 |
-
history[key].append(avg_train_losses[key])
|
| 329 |
-
history[f'val_{key}'].append(avg_val_losses[key])
|
| 330 |
-
|
| 331 |
-
# Print progress
|
| 332 |
-
print(f"Time: {epoch_time:.1f}s | "
|
| 333 |
-
f"Train Loss: {avg_train_losses['total_loss']:.4f} | "
|
| 334 |
-
f"Val Loss: {avg_val_losses['total_loss']:.4f}")
|
| 335 |
-
|
| 336 |
-
# Early stopping with model saving
|
| 337 |
-
if avg_val_losses['total_loss'] < best_val_loss:
|
| 338 |
-
best_val_loss = avg_val_losses['total_loss']
|
| 339 |
-
patience_counter = 0
|
| 340 |
-
self.save_model("src/artifacts/", suffix="_best")
|
| 341 |
-
else:
|
| 342 |
-
patience_counter += 1
|
| 343 |
-
if patience_counter >= patience:
|
| 344 |
-
print(f"Early stopping at epoch {epoch + 1}")
|
| 345 |
-
break
|
| 346 |
-
|
| 347 |
-
avg_epoch_time = np.mean(history['epoch_times'])
|
| 348 |
-
print(f"\\nTraining completed!")
|
| 349 |
-
print(f"Average epoch time: {avg_epoch_time:.1f}s")
|
| 350 |
-
print(f"Total training time: {sum(history['epoch_times']):.1f}s")
|
| 351 |
-
|
| 352 |
-
return history
|
| 353 |
-
|
| 354 |
-
def save_model(self, save_path: str = "src/artifacts/", suffix: str = ""):
|
| 355 |
-
"""Save the trained model."""
|
| 356 |
-
os.makedirs(save_path, exist_ok=True)
|
| 357 |
-
|
| 358 |
-
self.user_tower.save_weights(f"{save_path}/user_tower_weights{suffix}")
|
| 359 |
-
self.item_tower.save_weights(f"{save_path}/item_tower_weights_finetuned{suffix}")
|
| 360 |
-
self.model.rating_model.save_weights(f"{save_path}/rating_model_weights{suffix}")
|
| 361 |
-
|
| 362 |
-
config = {
|
| 363 |
-
'embedding_dim': self.embedding_dim,
|
| 364 |
-
'user_learning_rate': self.user_learning_rate,
|
| 365 |
-
'item_learning_rate': self.item_learning_rate,
|
| 366 |
-
'rating_weight': self.rating_weight,
|
| 367 |
-
'retrieval_weight': self.retrieval_weight,
|
| 368 |
-
'use_mixed_precision': self.use_mixed_precision
|
| 369 |
-
}
|
| 370 |
-
|
| 371 |
-
with open(f"{save_path}/optimized_joint_model_config{suffix}.txt", 'w') as f:
|
| 372 |
-
for key, value in config.items():
|
| 373 |
-
f.write(f"{key}: {value}\\n")
|
| 374 |
-
|
| 375 |
-
if not suffix:
|
| 376 |
-
print(f"Optimized model saved to {save_path}")
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
def main():
|
| 380 |
-
"""Main function for optimized joint training."""
|
| 381 |
-
|
| 382 |
-
print("Initializing optimized joint trainer...")
|
| 383 |
-
trainer = OptimizedJointTrainer(
|
| 384 |
-
embedding_dim=128, # Updated to 128D
|
| 385 |
-
user_learning_rate=0.002, # Slightly higher for faster convergence
|
| 386 |
-
item_learning_rate=0.0002,
|
| 387 |
-
rating_weight=1.0,
|
| 388 |
-
retrieval_weight=0.3, # Reduced for faster training
|
| 389 |
-
use_mixed_precision=False # Disabled for CPU
|
| 390 |
-
)
|
| 391 |
-
|
| 392 |
-
# Load components
|
| 393 |
-
print("Loading pre-trained item tower...")
|
| 394 |
-
trainer.load_pre_trained_item_tower()
|
| 395 |
-
|
| 396 |
-
print("Building user tower...")
|
| 397 |
-
trainer.build_user_tower(max_history_length=50)
|
| 398 |
-
|
| 399 |
-
print("Building two-tower model...")
|
| 400 |
-
trainer.build_two_tower_model()
|
| 401 |
-
|
| 402 |
-
# Load training data
|
| 403 |
-
print("Loading training data...")
|
| 404 |
-
with open("src/artifacts/training_features.pkl", 'rb') as f:
|
| 405 |
-
training_features = pickle.load(f)
|
| 406 |
-
|
| 407 |
-
with open("src/artifacts/validation_features.pkl", 'rb') as f:
|
| 408 |
-
validation_features = pickle.load(f)
|
| 409 |
-
|
| 410 |
-
print(f"Training samples: {len(training_features['rating'])}")
|
| 411 |
-
print(f"Validation samples: {len(validation_features['rating'])}")
|
| 412 |
-
|
| 413 |
-
# Train with optimizations
|
| 414 |
-
print("Starting optimized training...")
|
| 415 |
-
start_time = time.time()
|
| 416 |
-
|
| 417 |
-
history = trainer.train(
|
| 418 |
-
training_features=training_features,
|
| 419 |
-
validation_features=validation_features,
|
| 420 |
-
epochs=30, # Reduced epochs for faster training
|
| 421 |
-
batch_size=1024 # Larger batch size for better GPU utilization
|
| 422 |
-
)
|
| 423 |
-
|
| 424 |
-
total_time = time.time() - start_time
|
| 425 |
-
|
| 426 |
-
# Save final model and history
|
| 427 |
-
print("Saving final model...")
|
| 428 |
-
trainer.save_model()
|
| 429 |
-
|
| 430 |
-
with open("src/artifacts/optimized_training_history.pkl", 'wb') as f:
|
| 431 |
-
pickle.dump(history, f)
|
| 432 |
-
|
| 433 |
-
print(f"\\nOptimized joint training completed!")
|
| 434 |
-
print(f"Total training time: {total_time:.1f}s")
|
| 435 |
-
print(f"Average time per epoch: {total_time/len(history['epoch_times']):.1f}s")
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
if __name__ == "__main__":
|
| 439 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/utils/real_user_selector.py
CHANGED
|
@@ -93,22 +93,58 @@ class RealUserSelector:
|
|
| 93 |
"""
|
| 94 |
print(f"Selecting {n} real users with at least {min_interactions} interactions...")
|
| 95 |
|
| 96 |
-
# Filter users with sufficient interactions
|
| 97 |
-
|
|
|
|
|
|
|
| 98 |
for _, user in self.users_df.iterrows():
|
| 99 |
user_id = user['user_id']
|
| 100 |
-
if
|
| 101 |
-
self.user_stats[user_id]['total_interactions']
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
-
|
|
|
|
| 105 |
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
|
|
|
| 110 |
else:
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
# Build user profiles with real data
|
| 114 |
real_user_profiles = []
|
|
@@ -124,6 +160,10 @@ class RealUserSelector:
|
|
| 124 |
'age': int(user['age']),
|
| 125 |
'gender': user['gender'],
|
| 126 |
'income': int(user['income']),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
'interaction_history': stats['unique_items'][:50], # Limit to 50 most recent
|
| 128 |
'interaction_stats': {
|
| 129 |
'total_interactions': stats['total_interactions'],
|
|
|
|
| 93 |
"""
|
| 94 |
print(f"Selecting {n} real users with at least {min_interactions} interactions...")
|
| 95 |
|
| 96 |
+
# Filter users with sufficient interactions, separating by interaction count
|
| 97 |
+
high_interaction_users = [] # >14 interactions
|
| 98 |
+
low_interaction_users = [] # min_interactions to 14 interactions
|
| 99 |
+
|
| 100 |
for _, user in self.users_df.iterrows():
|
| 101 |
user_id = user['user_id']
|
| 102 |
+
if user_id in self.user_stats:
|
| 103 |
+
interaction_count = self.user_stats[user_id]['total_interactions']
|
| 104 |
+
if interaction_count >= min_interactions:
|
| 105 |
+
if interaction_count > 14:
|
| 106 |
+
high_interaction_users.append(user)
|
| 107 |
+
else:
|
| 108 |
+
low_interaction_users.append(user)
|
| 109 |
+
|
| 110 |
+
print(f"Found {len(high_interaction_users)} high-interaction users (>14) and {len(low_interaction_users)} low-interaction users ({min_interactions}-14)")
|
| 111 |
|
| 112 |
+
# Ensure more than half have >14 interactions
|
| 113 |
+
min_high_interaction = (n // 2) + 1 # More than 50%
|
| 114 |
|
| 115 |
+
selected_users = []
|
| 116 |
+
|
| 117 |
+
# First, select from high-interaction users (prioritize these)
|
| 118 |
+
if len(high_interaction_users) >= min_high_interaction:
|
| 119 |
+
selected_high = random.sample(high_interaction_users, min_high_interaction)
|
| 120 |
else:
|
| 121 |
+
print(f"Warning: Only {len(high_interaction_users)} high-interaction users available, using all")
|
| 122 |
+
selected_high = high_interaction_users
|
| 123 |
+
|
| 124 |
+
selected_users.extend(selected_high)
|
| 125 |
+
remaining_slots = n - len(selected_high)
|
| 126 |
+
|
| 127 |
+
# Fill remaining slots with low-interaction users if available
|
| 128 |
+
if remaining_slots > 0 and len(low_interaction_users) > 0:
|
| 129 |
+
if len(low_interaction_users) >= remaining_slots:
|
| 130 |
+
selected_low = random.sample(low_interaction_users, remaining_slots)
|
| 131 |
+
else:
|
| 132 |
+
selected_low = low_interaction_users
|
| 133 |
+
selected_users.extend(selected_low)
|
| 134 |
+
|
| 135 |
+
# If we still need more users, add remaining high-interaction users
|
| 136 |
+
remaining_slots = n - len(selected_users)
|
| 137 |
+
if remaining_slots > 0:
|
| 138 |
+
remaining_high = [user for user in high_interaction_users if user not in selected_users]
|
| 139 |
+
if len(remaining_high) >= remaining_slots:
|
| 140 |
+
selected_users.extend(random.sample(remaining_high, remaining_slots))
|
| 141 |
+
else:
|
| 142 |
+
selected_users.extend(remaining_high)
|
| 143 |
+
|
| 144 |
+
print(f"Selected {len(selected_users)} total users: {len([u for u in selected_users if self.user_stats[u['user_id']]['total_interactions'] > 14])} high-interaction (>14), {len([u for u in selected_users if self.user_stats[u['user_id']]['total_interactions'] <= 14])} low-interaction (β€14)")
|
| 145 |
+
|
| 146 |
+
if len(selected_users) < n:
|
| 147 |
+
print(f"Warning: Only {len(selected_users)} users available, returning all")
|
| 148 |
|
| 149 |
# Build user profiles with real data
|
| 150 |
real_user_profiles = []
|
|
|
|
| 160 |
'age': int(user['age']),
|
| 161 |
'gender': user['gender'],
|
| 162 |
'income': int(user['income']),
|
| 163 |
+
'profession': user.get('profession', 'Other'),
|
| 164 |
+
'location': user.get('location', 'Urban'),
|
| 165 |
+
'education_level': user.get('education_level', 'High School'),
|
| 166 |
+
'marital_status': user.get('marital_status', 'Single'),
|
| 167 |
'interaction_history': stats['unique_items'][:50], # Limit to 50 most recent
|
| 168 |
'interaction_stats': {
|
| 169 |
'total_interactions': stats['total_interactions'],
|
train_improved_model.py
DELETED
|
@@ -1,111 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""
|
| 3 |
-
Train the improved two-tower model with all enhancements to address the identified issues.
|
| 4 |
-
|
| 5 |
-
This script implements:
|
| 6 |
-
β
128D embeddings (vs 64D) - Better representation capacity
|
| 7 |
-
β
Temperature scaling - Improved score discrimination
|
| 8 |
-
β
Category-aware boosting - Enhanced personalization
|
| 9 |
-
β
Contrastive loss - Prevents embedding collapse
|
| 10 |
-
β
Hard negative mining - Better training signal
|
| 11 |
-
β
User/item bias terms - Improved modeling capacity
|
| 12 |
-
β
Curriculum learning - Progressive training strategy
|
| 13 |
-
"""
|
| 14 |
-
|
| 15 |
-
import argparse
|
| 16 |
-
import sys
|
| 17 |
-
import os
|
| 18 |
-
|
| 19 |
-
# Add project root to path
|
| 20 |
-
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 21 |
-
|
| 22 |
-
from src.training.curriculum_trainer import CurriculumTrainer
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def main():
|
| 26 |
-
parser = argparse.ArgumentParser(description='Train improved two-tower model')
|
| 27 |
-
parser.add_argument('--embedding-dim', type=int, default=128,
|
| 28 |
-
help='Embedding dimension (default: 128)')
|
| 29 |
-
parser.add_argument('--learning-rate', type=float, default=0.001,
|
| 30 |
-
help='Learning rate (default: 0.001)')
|
| 31 |
-
parser.add_argument('--epochs-per-stage', type=int, default=15,
|
| 32 |
-
help='Epochs per curriculum stage (default: 15)')
|
| 33 |
-
parser.add_argument('--batch-size', type=int, default=512,
|
| 34 |
-
help='Batch size (default: 512)')
|
| 35 |
-
parser.add_argument('--curriculum-stages', type=int, default=3,
|
| 36 |
-
help='Number of curriculum stages (default: 3)')
|
| 37 |
-
parser.add_argument('--use-focal-loss', action='store_true', default=True,
|
| 38 |
-
help='Use focal loss for imbalanced data')
|
| 39 |
-
|
| 40 |
-
args = parser.parse_args()
|
| 41 |
-
|
| 42 |
-
print("π TRAINING IMPROVED TWO-TOWER MODEL")
|
| 43 |
-
print("="*70)
|
| 44 |
-
print("IMPROVEMENTS IMPLEMENTED:")
|
| 45 |
-
print("β
128D embeddings (increased from 64D)")
|
| 46 |
-
print("β
Temperature scaling for better score discrimination")
|
| 47 |
-
print("β
Category-aware boosting for personalization")
|
| 48 |
-
print("β
Contrastive loss to prevent embedding collapse")
|
| 49 |
-
print("β
Hard negative mining for better training")
|
| 50 |
-
print("β
User/item bias terms for improved modeling")
|
| 51 |
-
print("β
Curriculum learning for progressive training")
|
| 52 |
-
print("="*70)
|
| 53 |
-
|
| 54 |
-
# Initialize trainer with improved settings
|
| 55 |
-
trainer = CurriculumTrainer(
|
| 56 |
-
embedding_dim=args.embedding_dim,
|
| 57 |
-
learning_rate=args.learning_rate,
|
| 58 |
-
use_focal_loss=args.use_focal_loss,
|
| 59 |
-
curriculum_stages=args.curriculum_stages
|
| 60 |
-
)
|
| 61 |
-
|
| 62 |
-
try:
|
| 63 |
-
# Load data and train
|
| 64 |
-
trainer.load_data_processor()
|
| 65 |
-
trainer.create_model()
|
| 66 |
-
|
| 67 |
-
# Load training data
|
| 68 |
-
import pickle
|
| 69 |
-
with open("src/artifacts/training_features.pkl", 'rb') as f:
|
| 70 |
-
training_features = pickle.load(f)
|
| 71 |
-
|
| 72 |
-
with open("src/artifacts/validation_features.pkl", 'rb') as f:
|
| 73 |
-
validation_features = pickle.load(f)
|
| 74 |
-
|
| 75 |
-
# Train with curriculum learning
|
| 76 |
-
history = trainer.train_with_curriculum(
|
| 77 |
-
training_features=training_features,
|
| 78 |
-
validation_features=validation_features,
|
| 79 |
-
epochs_per_stage=args.epochs_per_stage,
|
| 80 |
-
batch_size=args.batch_size
|
| 81 |
-
)
|
| 82 |
-
|
| 83 |
-
# Save results
|
| 84 |
-
trainer.save_model()
|
| 85 |
-
|
| 86 |
-
with open("src/artifacts/improved_training_history.pkl", 'wb') as f:
|
| 87 |
-
pickle.dump(history, f)
|
| 88 |
-
|
| 89 |
-
print("\nπ― EXPECTED IMPROVEMENTS:")
|
| 90 |
-
print("β’ Score variance: 0.0007 β 0.01+ (15x better discrimination)")
|
| 91 |
-
print("β’ Category alignment: 12% β 60%+ (5x better personalization)")
|
| 92 |
-
print("β’ Reduced embedding collapse (more diverse user representations)")
|
| 93 |
-
print("β’ Better negative sampling and contrastive learning")
|
| 94 |
-
print("β’ Improved bias modeling for users and items")
|
| 95 |
-
|
| 96 |
-
print("\nβ
TRAINING COMPLETED SUCCESSFULLY!")
|
| 97 |
-
print("The improved model should address all critical issues identified in your analysis.")
|
| 98 |
-
|
| 99 |
-
except FileNotFoundError as e:
|
| 100 |
-
print(f"β ERROR: {e}")
|
| 101 |
-
print("Please ensure training data exists in src/artifacts/")
|
| 102 |
-
print("Run data preprocessing first if needed.")
|
| 103 |
-
|
| 104 |
-
except Exception as e:
|
| 105 |
-
print(f"β TRAINING ERROR: {e}")
|
| 106 |
-
import traceback
|
| 107 |
-
traceback.print_exc()
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
if __name__ == "__main__":
|
| 111 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|