minhajHP committed on
Commit 7b5d392 · 1 Parent(s): 644ceea

Clean codebase and add demographic enhancements

- Enhanced demographic handling for zero interactions
- Improved content-based filtering with aggregated history
- Added comprehensive UI support for new user scenarios
- Cleaned up analysis files and redundant components
- Updated documentation and project structure

DEEP_ARCHITECTURE.md ADDED
@@ -0,0 +1,549 @@
+ # Deep Architecture Documentation - RecSys-HP
+
+ ## 🏗️ Complete System Architecture Overview
+
+ ```mermaid
+ graph TB
+     subgraph "Data Layer"
+         D1[items.csv<br/>15K+ products]
+         D2[users.csv<br/>Enhanced demographics]
+         D3[interactions.csv<br/>User-item interactions]
+         D4[Artifacts<br/>Trained models & indices]
+     end
+
+     subgraph "ML Pipeline"
+         P1[Data Preprocessing]
+         P2[Item Tower Pre-training]
+         P3[FAISS Index Creation]
+         P4[Joint Training]
+         P5[User Tower Training]
+     end
+
+     subgraph "Inference Layer"
+         I1[Recommendation Engine<br/>Category-Boosted Algorithm]
+         I2[FAISS Similarity Search]
+         I3[Real User Selection]
+         I4[Hybrid Scoring]
+     end
+
+     subgraph "API Layer"
+         A1[FastAPI Server<br/>Port 8000]
+         A2[Recommendation Endpoints]
+         A3[User Management]
+         A4[Item Retrieval]
+     end
+
+     subgraph "Frontend Layer"
+         F1[React.js Application]
+         F2[Interactive UI Components]
+         F3[Real-time Analytics]
+         F4[User Profile Management]
+     end
+
+     D1 --> P1
+     D2 --> P1
+     D3 --> P1
+     P1 --> P2
+     P2 --> P3
+     P3 --> P4
+     P4 --> P5
+     P5 --> D4
+     D4 --> I1
+     I1 --> I2
+     I2 --> I3
+     I3 --> I4
+     I4 --> A1
+     A1 --> A2
+     A2 --> A3
+     A3 --> A4
+     A4 --> F1
+     F1 --> F2
+     F2 --> F3
+     F3 --> F4
+ ```
+
+ ---
+
+ ## 📁 Project Structure
+
+ ```
+ RecSys-HP/
+ ├── 🗂️ Data Layer
+ │   ├── datasets/
+ │   │   ├── items.csv                 # 15K+ product catalog
+ │   │   ├── users.csv                 # Enhanced user demographics (7 features)
+ │   │   ├── interactions.csv          # User-item interaction history
+ │   │   └── users_enhanced.csv        # Backup with enhanced features
+ │   └── src/artifacts/                # Trained models and indices
+ │       ├── item_embeddings.npy       # Pre-trained item vectors (128D)
+ │       ├── faiss_item_index.bin      # FAISS similarity index
+ │       ├── faiss_metadata.pkl        # Item metadata mapping
+ │       ├── vocabularies.pkl          # Categorical encoders
+ │       └── *.weights.* files         # TensorFlow model weights
+ │
+ ├── 🧠 ML Pipeline
+ │   ├── src/preprocessing/
+ │   │   ├── data_loader.py            # Data loading and preprocessing
+ │   │   └── user_data_preparation.py  # User feature engineering
+ │   ├── src/training/
+ │   │   ├── item_pretraining.py       # Item tower pre-training
+ │   │   └── joint_training.py         # Two-tower joint training
+ │   ├── src/models/
+ │   │   ├── item_tower.py             # Item embedding model (TensorFlow)
+ │   │   └── user_tower.py             # User embedding model (TensorFlow)
+ │   └── Training Scripts
+ │       ├── run_training_pipeline.py  # Complete pipeline executor
+ │       ├── run_2phase_training.py    # 2-phase training approach
+ │       └── run_joint_training.py     # Joint training approach
+ │
+ ├── 🔍 Inference Layer
+ │   ├── src/inference/
+ │   │   ├── recommendation_engine.py  # Core recommendation algorithms
+ │   │   └── faiss_index.py            # FAISS index management
+ │   ├── src/utils/
+ │   │   └── real_user_selector.py     # Real user data selection
+ │   └── src/data_generation/
+ │       └── generate_demographics.py  # Synthetic user generation
+ │
+ ├── 🌐 API Layer
+ │   └── api/
+ │       └── main.py                   # FastAPI server with all endpoints
+ │
+ ├── 🎨 Frontend Layer
+ │   └── frontend/
+ │       ├── src/
+ │       │   ├── App.js                # Main React application
+ │       │   └── index.js              # Entry point
+ │       ├── public/                   # Static assets
+ │       ├── build/                    # Production build
+ │       └── package.json              # Dependencies
+ │
+ └── 🧪 Testing & Analysis
+     ├── test_category_boosted.py          # Basic algorithm testing
+     ├── test_enhanced_category_boosted.py # Advanced subcategory testing
+     └── deep_analyze_category_boosted.py  # Comprehensive analysis tool
+ ```
+
+ ---
+
+ ## 🔄 Data Flow Architecture
+
+ ### 1. Training Pipeline Flow
+ ```mermaid
+ sequenceDiagram
+     participant D as Data Files
+     participant P as Preprocessing
+     participant IT as Item Tower
+     participant F as FAISS Index
+     participant JT as Joint Training
+     participant UT as User Tower
+     participant A as Artifacts
+
+     D->>P: Load datasets (items, users, interactions)
+     P->>IT: Preprocessed item features
+     IT->>IT: Pre-train item embeddings (128D)
+     IT->>F: Generate item vectors
+     F->>F: Build FAISS similarity index
+     IT->>JT: Pre-trained item tower
+     P->>JT: User features (7 demographics)
+     JT->>UT: Train user tower
+     JT->>A: Save trained models
+     F->>A: Save FAISS index
+ ```
+
+ ### 2. Inference Pipeline Flow
+ ```mermaid
+ sequenceDiagram
+     participant U as User Request
+     participant API as FastAPI
+     participant RE as Recommendation Engine
+     participant F as FAISS Search
+     participant CB as Category Boosted
+     participant R as Response
+
+     U->>API: POST /recommendations
+     API->>RE: User profile + preferences
+     RE->>F: Query item embeddings
+     F->>RE: Similar items (k*10 wide search)
+     RE->>CB: Apply category-boosted algorithm
+     CB->>CB: 50% from user categories + proportional distribution
+     CB->>RE: Balanced recommendations
+     RE->>API: Scored & ranked items
+     API->>R: JSON response with recommendations
+ ```
+
+ ---
+
+ ## 🧠 Machine Learning Architecture
+
+ ### Two-Tower Architecture
+ ```
+ ITEM TOWER
+   Item Features: product_id, category_code, brand, price (log)
+         │
+         ▼
+   Dense Layers: Dense(256, ReLU) → Dropout(0.3) → Dense(128, ReLU),
+                 with L2 regularization
+         │
+         ▼
+   Item Embedding (128D)
+
+ USER TOWER
+   Demographic Features: age (normalized), gender (encoded), income (binned),
+                         profession, location, education_level, marital_status
+         │
+         ▼
+   Dense Layers: Dense(128, ReLU) → Dropout(0.2) → Dense(64, ReLU),
+                 with L2 regularization
+         │
+         ▼
+   User Embedding (64D)
+
+ SCORING
+   similarity = user_emb · item_emb   (dot-product similarity score)
+ ```
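
The scoring block reduces to a single dot product. A minimal NumPy sketch with toy random vectors (the tower outputs above are 128D for items vs. 64D for users, so a shared scoring dimension is assumed here and both sides use 128D):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy embeddings; in the real system these come from the two towers.
item_embs = rng.normal(size=(5, 128)).astype(np.float32)  # 5 candidate items
user_emb = rng.normal(size=128).astype(np.float32)        # one user

# Dot-product similarity between the user and every candidate item
scores = item_embs @ user_emb

# Rank candidates by descending similarity
ranking = np.argsort(-scores)
print(ranking)
```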
+
+ ### Category-Boosted Algorithm
+ ```
+ CATEGORY-BOOSTED RECOMMENDATION FLOW
+
+ 1. USER INTERACTION ANALYSIS
+    interaction_history: [1001, 2003, 3045, 1099]
+         │
+         ▼
+    Extract 2-level subcategories:
+      • computers.components: 40%
+      • electronics.audio: 35%
+      • computers.peripherals: 25%
+
+ 2. WIDE SIMILARITY SEARCH
+    FAISS.search(user_embedding, k = requested * 10)
+    Returns: ~1000 similar items
+         │
+         ▼
+ 3. CATEGORY ORGANIZATION
+    Group by subcategories:
+      computers.components:  [1001, 1099, 1203, ...]
+      electronics.audio:     [2003, 2156, 2089, ...]
+      computers.peripherals: [3045, 3201, 3078, ...]
+      other_categories:      [4001, 5002, ...]
+         │
+         ▼
+ 4. PROPORTIONAL ALLOCATION
+    Target: 50% from user categories
+    (50 items for 100 recommendations)
+      computers.components: 40% → 20 items
+      electronics.audio: 35% → 18 items
+      computers.peripherals: 25% → 12 items
+    Remaining 50 items: diverse mix
+         │
+         ▼
+ 5. FINAL RECOMMENDATION SET
+    • 50 items from user's categories (proportionally distributed)
+    • 50 items for exploration
+    • All ranked by similarity score
+    • Ensures category diversity
+ ```
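
The slot arithmetic of the proportional-allocation step can be sketched as follows. `allocate_slots` is a hypothetical helper, not the project's actual function; largest-remainder rounding (ties going to the larger category) is assumed so the worked numbers above (20/18/12) come out:

```python
def allocate_slots(category_shares, n_total, user_share=0.5):
    """Split recommendation slots: `user_share` of them go to the user's own
    subcategories, proportionally to interaction share; the rest is exploration.
    Sketch of the allocation step, assuming largest-remainder rounding."""
    user_slots = int(n_total * user_share)
    raw = {cat: user_slots * share for cat, share in category_shares.items()}
    alloc = {cat: int(r) for cat, r in raw.items()}  # floor of each raw count
    leftover = user_slots - sum(alloc.values())
    # Hand out slots lost to flooring: largest fractional remainder first,
    # breaking ties in favour of the category with the bigger share.
    order = sorted(category_shares,
                   key=lambda c: (raw[c] - alloc[c], category_shares[c]),
                   reverse=True)
    for cat in order[:leftover]:
        alloc[cat] += 1
    return alloc, n_total - user_slots

shares = {"computers.components": 0.40,
          "electronics.audio": 0.35,
          "computers.peripherals": 0.25}
alloc, exploration_slots = allocate_slots(shares, 100)
print(alloc, exploration_slots)  # 20 / 18 / 12 user-category slots, 50 exploration
```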
+
+ ---
+
+ ## 🌐 API Architecture
+
+ ### FastAPI Server Endpoints
+ ```
+ # Core Recommendation Endpoints
+ POST /recommendations                   # Main recommendation engine
+ GET  /real-users                        # Fetch real user profiles
+ GET  /items/{item_id}                   # Get item details
+ GET  /dataset-summary                   # Dataset statistics
+
+ # Algorithm-Specific Endpoints
+ POST /recommendations/hybrid            # Hybrid collaborative + content
+ POST /recommendations/collaborative     # Pure collaborative filtering
+ POST /recommendations/content           # Aggregated history content-based recommendations
+ POST /recommendations/category_boosted  # Category-boosted algorithm
+
+ # Utility Endpoints
+ GET  /                                  # Health check
+ GET  /sample-items                      # Random item samples
+ POST /generate-interactions             # Synthetic interaction generation
+ ```
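
A request body for `POST /recommendations` might be assembled like this. The schema is an assumption inferred from the frontend state fields; names like `user_profile`, `algorithm`, and `num_recommendations` are illustrative, not the server's confirmed contract:

```python
import json

# Hypothetical request body for POST /recommendations. The profile fields
# mirror the frontend's userProfile state; the wrapper keys are assumptions.
payload = {
    "user_profile": {
        "age": 30,
        "gender": "male",
        "income": 50000,
        "profession": "Technology",
        "location": "Urban",
        "education_level": "Bachelor's",
        "marital_status": "Single",
        "interaction_history": [1001, 2003, 3045, 1099],
    },
    "algorithm": "category_boosted",
    "num_recommendations": 100,
}

body = json.dumps(payload)  # serialized JSON ready to send to the API
print(len(body) > 0)
```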
+
+ ### Request/Response Flow
+ ```mermaid
+ graph LR
+     subgraph "Request Processing"
+         A[User Request] --> B[Validation]
+         B --> C[Feature Engineering]
+         C --> D[Model Inference]
+     end
+
+     subgraph "Recommendation Engine"
+         D --> E[FAISS Search]
+         E --> F[Category Analysis]
+         F --> G[Score Calculation]
+         G --> H[Ranking & Filtering]
+     end
+
+     subgraph "Response Generation"
+         H --> I[Item Enrichment]
+         I --> J[Metadata Addition]
+         J --> K[JSON Response]
+     end
+ ```
+
+ ---
+
+ ## 🎨 Frontend Architecture
+
+ ### React.js Component Structure
+ ```
+ App.js (Main Container)
+ ├── User Profile Management
+ │   ├── Demographics Form
+ │   ├── Real User Selection
+ │   └── Interaction History Display
+ ├── Recommendation Controls
+ │   ├── Algorithm Selection
+ │   ├── Count Configuration
+ │   └── Weight Adjustment
+ ├── Results Display
+ │   ├── Recommendation Cards
+ │   ├── Category Analytics
+ │   ├── Pagination Controls
+ │   └── Similar Items View
+ └── Analysis Components
+     ├── Category Interest Graphs
+     ├── Interaction Patterns
+     └── Performance Metrics
+ ```
+
+ ### State Management
+ ```javascript
+ const [userProfile, setUserProfile] = useState({
+   age: 30,
+   gender: 'male',
+   income: 50000,
+   profession: 'Technology',
+   location: 'Urban',
+   education_level: "Bachelor's",
+   marital_status: 'Single',
+   interaction_history: []
+ });
+
+ const [recommendationType, setRecommendationType] = useState('category_boosted');
+ const [recommendations, setRecommendations] = useState([]);
+ const [realUsers, setRealUsers] = useState([]);
+ const [datasetSummary, setDatasetSummary] = useState(null);
+ ```
+
+ ---
+
+ ## 🔍 Algorithm Deep Dive
+
+ ### 1. Hybrid Recommendation
+ - **Collaborative Filtering**: User-item interaction patterns
+ - **Aggregated Content-Based**: User's complete interaction history aggregated into a single embedding
+ - **Weight Balance**: Configurable collaborative weight (default: 0.7)
+
+ ### 1.5. Aggregated History Content-Based Filtering
+ - **Core Idea**: Aggregates the user's entire interaction history instead of relying on single-item similarity
+ - **Aggregation Methods**:
+   - **Weighted Mean**: `weights = exp(linspace(-1, 0, len(history)))` (recent interactions weighted higher)
+   - **Simple Mean**: Equal weighting of all interaction embeddings
+   - **Max Pooling**: Element-wise maximum across all embeddings
+ - **Process Flow**:
+   1. **Embedding Extraction**: Get the 128D vector for each item in the user's history
+   2. **Aggregation**: Apply the selected aggregation method (`weighted_mean` by default)
+   3. **Normalization**: L2-normalize the aggregated embedding
+   4. **ANN Search**: Direct FAISS similarity search using the aggregated user profile
+   5. **Filtering**: Remove already-interacted items from the results
+ - **Benefits**: Captures the complete user preference profile; more robust than a single-item seed
+
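The five-step flow above can be sketched in NumPy. `aggregate_history` is an assumed helper name, not the project's actual function; the weighted-mean formula is taken verbatim from the bullet above:

```python
import numpy as np

def aggregate_history(history_embs, method="weighted_mean"):
    """Aggregate a user's interaction-history embeddings into one profile
    vector. history_embs: (n, d) array, oldest interaction first."""
    n = len(history_embs)
    if method == "weighted_mean":
        weights = np.exp(np.linspace(-1.0, 0.0, n))  # recent items weigh more
        weights /= weights.sum()
        agg = weights @ history_embs
    elif method == "mean":
        agg = history_embs.mean(axis=0)
    elif method == "max":
        agg = history_embs.max(axis=0)  # element-wise max pooling
    else:
        raise ValueError(f"unknown method: {method}")
    return agg / np.linalg.norm(agg)  # L2-normalize before the ANN search

rng = np.random.default_rng(1)
history = rng.normal(size=(4, 128))  # 4 interactions, 128D item embeddings
profile = aggregate_history(history)
print(profile.shape, round(float(np.linalg.norm(profile)), 6))  # (128,) 1.0
```

The L2-normalized `profile` vector is what gets handed to the FAISS search in step 4.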
+ ### 2. Category-Boosted Algorithm
+ - **Step 1**: Analyze the user's subcategory preferences (2-level depth)
+ - **Step 2**: Wide FAISS search (k × 10 multiplier)
+ - **Step 3**: Category organization and candidate grouping
+ - **Step 4**: Proportional allocation (50% from user categories)
+ - **Step 5**: Fill the remaining 50% with exploration items
+
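Step 1 (2-level subcategory preference extraction) might look like the sketch below; the helper name and sample category codes are illustrative, not the project's actual code:

```python
from collections import Counter

def subcategory_shares(category_codes, levels=2):
    """Truncate each category_code to `levels` dot-separated segments and
    compute each subcategory's share of the user's interaction history."""
    truncated = [".".join(code.split(".")[:levels]) for code in category_codes]
    counts = Counter(truncated)
    total = sum(counts.values())
    return {cat: n / total for cat, n in counts.items()}

# Hypothetical category codes for a 4-item interaction history
history_categories = [
    "computers.components.cpu",
    "electronics.audio.headphone",
    "computers.peripherals.mouse",
    "computers.components.videocards",
]
print(subcategory_shares(history_categories))
# {'computers.components': 0.5, 'electronics.audio': 0.25, 'computers.peripherals': 0.25}
```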
+ ### 3. FAISS Integration
+ - **Index Type**: Flat L2 similarity search
+ - **Vector Dimension**: 128D item embeddings
+ - **Search Strategy**: Wide retrieval + post-processing
+ - **Metadata**: Item-to-index mapping via pickle files
+
+ ---
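
For reference, a flat L2 index is just brute-force nearest-neighbour search. The sketch below reproduces the wide-retrieval step (`k = requested × 10`) in plain NumPy so it runs without the `faiss` dependency; the real system would call `index.search` on the loaded `faiss_item_index.bin` instead:

```python
import numpy as np

rng = np.random.default_rng(2)
index_vectors = rng.normal(size=(1000, 128)).astype(np.float32)  # item embeddings
query = rng.normal(size=(1, 128)).astype(np.float32)             # user profile

requested = 5
k = requested * 10  # wide search multiplier, as in step 2 above

# Squared L2 distance to every indexed item (what an IndexFlatL2 computes)
d2 = ((index_vectors - query) ** 2).sum(axis=1)
candidates = np.argsort(d2)[:k]  # 50 nearest candidates for post-processing

print(candidates.shape)  # (50,)
```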
+
+ ## 📊 Performance Characteristics
+
+ ### Scalability Metrics
+ - **Items**: 15K+ products supported
+ - **Users**: Unlimited (stateless design)
+ - **Recommendations**: 1-1000 per request
+ - **Response Time**: <2s for 100 recommendations
+ - **Memory Usage**: ~500MB for full model + index
+
+ ### Algorithm Performance
+ - **Category Matching**: ≥50% from user's categories
+ - **Diversity Score**: Balanced exploration vs exploitation
+ - **Cold Start**: Handles new users via demographic features
+ - **Subcategory Precision**: 2-level category matching
+
+ ---
+
+ ## 🚀 Deployment Architecture
+
+ ### Development Environment
+ ```bash
+ # Backend (FastAPI)
+ cd api && python main.py
+
+ # Frontend (React)
+ cd frontend && npm start
+
+ # Training Pipeline
+ python run_training_pipeline.py
+ ```
+
+ ### Production Considerations
+ - **Containerization**: Docker support for the API + frontend
+ - **Database**: PostgreSQL for production user/item storage
+ - **Caching**: Redis for recommendation caching
+ - **Load Balancing**: Nginx in front of multiple API instances
+ - **Monitoring**: Prometheus + Grafana for metrics
+
+ ---
+
+ ## 🔧 Configuration & Customization
+
+ ### Model Configuration
+ ```python
+ # Item Tower
+ ITEM_EMBEDDING_DIM = 128
+ ITEM_HIDDEN_LAYERS = [256, 128]
+ ITEM_DROPOUT_RATE = 0.3
+
+ # User Tower
+ USER_EMBEDDING_DIM = 64
+ USER_HIDDEN_LAYERS = [128, 64]
+ USER_DROPOUT_RATE = 0.2
+
+ # Training
+ BATCH_SIZE = 512
+ LEARNING_RATE = 0.001
+ EPOCHS = 100
+ VALIDATION_SPLIT = 0.2
+ ```
+
+ ### Algorithm Parameters
+ ```python
+ # Category-Boosted
+ WIDE_SEARCH_MULTIPLIER = 10
+ USER_CATEGORY_PERCENTAGE = 0.5
+ SUBCATEGORY_LEVELS = 2
+ MIN_INTERACTION_THRESHOLD = 5
+
+ # FAISS
+ INDEX_TYPE = "Flat"
+ SIMILARITY_METRIC = "L2"
+ SEARCH_PARAMS = {"nprobe": 10}  # nprobe applies to IVF-style indices; a Flat index ignores it
+
+ # Aggregated Content-Based
+ AGGREGATION_METHOD = "weighted_mean"  # "mean", "weighted_mean", "max"
+ TEMPORAL_DECAY_ALPHA = 1.0            # Controls recency weighting strength
+ HISTORY_LIMIT = 50                    # Max items to consider for aggregation
+ ```
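
One plausible reading of how `TEMPORAL_DECAY_ALPHA` modulates the recency weighting (the exact parameterization is an assumption: scaling the exponent range means `alpha = 0` degenerates to a simple mean):

```python
import numpy as np

def recency_weights(n, alpha=1.0):
    """Assumed form of the recency weighting: alpha scales the exponent
    range of the weighted-mean formula, so larger alpha means stronger
    preference for recent interactions. Oldest interaction comes first."""
    w = np.exp(alpha * np.linspace(-1.0, 0.0, n))
    return w / w.sum()  # normalize so the weights sum to 1

print(np.round(recency_weights(4, alpha=1.0), 3))  # newest item gets the largest weight
print(np.round(recency_weights(4, alpha=0.0), 3))  # uniform: [0.25 0.25 0.25 0.25]
```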
+
+ ---
+
+ ## 🧪 Testing Framework
+
+ ### Test Coverage
+ - **Unit Tests**: Individual algorithm components
+ - **Integration Tests**: End-to-end recommendation flow
+ - **Performance Tests**: Latency and throughput benchmarks
+ - **Accuracy Tests**: Category matching validation
+
+ ### Analysis Tools
+ - `test_category_boosted.py`: Basic algorithm validation
+ - `test_enhanced_category_boosted.py`: Advanced subcategory testing
+ - `deep_analyze_category_boosted.py`: Comprehensive performance analysis
+ - **`analyze_recommendation_alignment.py`**: **NEW** - Multi-algorithm alignment analysis
+   - Tests all 4 algorithms (collaborative, content, hybrid, category_boosted)
+   - Category alignment scoring and coverage analysis
+   - Diversity vs. relevance trade-off analysis
+   - User-specific algorithm performance comparison
+   - Generates comprehensive visualizations and reports
+
+ ### Algorithm Comparison Metrics
+ - **Top-Level Alignment**: % of recommendations matching the user's preferred categories
+ - **Subcategory Precision**: 2-level category matching accuracy
+ - **Coverage Score**: % of the user's categories represented in the recommendations
+ - **Diversity Score**: Shannon entropy of recommendation categories
+ - **Performance by Scale**: Algorithm behavior across 10-100+ recommendations
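
The Diversity Score (Shannon entropy over recommended categories) can be computed as below; log base 2 is an assumption of this sketch, and the two sample lists are synthetic:

```python
from collections import Counter
from math import log2

def diversity_score(categories):
    """Shannon entropy of the category distribution of a recommendation
    list: higher means the list is more evenly spread across categories."""
    counts = Counter(categories)
    total = sum(counts.values())
    return -sum((n / total) * log2(n / total) for n in counts.values())

uniform = ["a", "b", "c", "d"] * 25    # 4 categories, evenly spread
skewed = ["a"] * 97 + ["b", "c", "d"]  # dominated by one category

print(diversity_score(uniform))  # 2.0 — the maximum for 4 categories
print(diversity_score(skewed))   # much lower
```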
+
+ ---
+
+ ## 📈 Future Enhancements
+
+ ### Planned Features
+ 1. **Real-time Learning**: Online model updates
+ 2. **A/B Testing**: Algorithm comparison framework
+ 3. **Explainability**: Recommendation reasoning
+ 4. **Multi-objective**: Balancing relevance, diversity, novelty
+ 5. **Graph Neural Networks**: Advanced relationship modeling
+
+ ### Technical Debt
+ - [ ] Add comprehensive error handling
+ - [ ] Implement request caching
+ - [ ] Add model versioning
+ - [ ] Create automated testing pipeline
+ - [ ] Add performance monitoring
+
+ ---
+
+ This deep architecture documentation provides a comprehensive view of the RecSys-HP recommendation system, covering all layers from data storage to user interface, with detailed technical specifications and implementation details.
README.md CHANGED
@@ -1,6 +1,6 @@
  # Advanced Two-Tower Recommendation System
 
- A production-ready recommendation system implementation using TensorFlow Recommenders with an enhanced two-tower architecture. This system provides personalized item recommendations through collaborative filtering, content-based filtering, and hybrid approaches, featuring categorical demographics, curriculum learning, and advanced training strategies.
 
  ## 🎯 Project Overview
 
@@ -11,7 +11,7 @@ This recommendation system addresses the challenge of providing personalized ite
  - **🧠 Enhanced Two-Tower Architecture**: 128D embeddings with temperature scaling and contrastive learning
  - **📚 Curriculum Learning**: Progressive training strategy for improved convergence
  - **⚡ Real-time Inference**: Sub-100ms recommendation serving with FAISS indexing
- - **🔄 Multi-strategy Recommendations**: Collaborative, content-based, and hybrid approaches
  - **🎪 Category-Aware Boosting**: Enhanced personalization through user preference alignment
  - **🔍 Interactive Similar Items**: Click-to-explore with 60/40 category-balanced discovery
  - **📊 Comprehensive Analysis**: Quality metrics and performance evaluation tools
@@ -70,11 +70,22 @@ The system implements a sophisticated two-tower neural network architecture opti
  - **Stage 3**: Complex cases (long history) - 67th+ percentile
  - **Adaptive Learning Rates**: Decrease as stages progress for stability
 
- ### 5. Category-Aware Recommendation Engine 🎪
  - **Enhanced Hybrid Recommendations**: Category boosting based on user preferences
  - **Category Alignment Analysis**: Measures personalization effectiveness
  - **Diversity Controls**: Balanced category representation in recommendations
- - **Explanation Generation**: Detailed reasoning for recommendation choices
 
  ## 📁 Project Structure
 
@@ -134,10 +145,8 @@ RecSys-HP/
  ├── 🚀 Training Scripts              # Multiple training approaches
  │   ├── run_training_pipeline.py     # Main training orchestration
  │   ├── run_2phase_training.py       # 2-phase training approach
- │   ├── run_joint_training.py        # Joint training approach
- │   └── train_improved_model.py      # Enhanced model training
  │
- ├── 📊 analyze_recommendations.py    # Recommendation quality analysis
  └── 📋 requirements.txt              # Python dependencies
  ```
 
@@ -478,24 +487,11 @@ python api_2phase.py
  python api_joint.py
  ```
 
- ### 📊 Analysis & Testing Tools
-
  ```bash
- # Comprehensive recommendation analysis
- python analyze_recommendations.py
- # → Generates recommendation_analysis_report.md + plots
-
- # Test individual engines
- python -m src.inference.enhanced_recommendation_engine_128d  # 128D enhanced
- python -m src.inference.enhanced_recommendation_engine       # Standard enhanced
- python -m src.inference.recommendation_engine                # Basic engine
-
- # Real user data utilities
  python -m src.utils.real_user_selector  # Demo real user extraction
-
- # Data processing utilities
- python -m src.preprocessing.data_loader
- python -m src.preprocessing.optimized_dataset_creator
  ```
 
  ### 🧪 Frontend Development
@@ -667,7 +663,8 @@ RecSys-HP/
  ✅ **Enhanced Architecture**: 128D embeddings, temperature scaling, contrastive learning
  ✅ **Curriculum Learning**: Progressive training for better convergence
  ✅ **Category-Aware Recommendations**: Intelligent personalization with diversity
- ✅ **Comprehensive Analysis**: Quality metrics and performance evaluation
  ✅ **Production Ready**: Scalable API with enhanced frontend features
 
  **🎉 Ready to deliver next-generation personalized recommendations!**
@@ -687,11 +684,12 @@ This project provides multiple training strategies:
  - **2-Phase API** (`api_2phase.py`) - Specialized for 2-phase training
  - **Joint API** (`api_joint.py`) - Optimized for joint training approach
 
- ## 📊 Analysis Tools
 
- - **Recommendation Analysis** (`analyze_recommendations.py`) - Quality metrics and evaluation
 
- ## 🔧 Development & Testing
 
  ### Frontend Development
  ```bash
1
  # Advanced Two-Tower Recommendation System
2
 
3
+ A production-ready recommendation system implementation using TensorFlow Recommenders with an enhanced two-tower architecture. This system provides personalized item recommendations through collaborative filtering, **aggregated history content-based filtering**, and hybrid approaches, featuring categorical demographics, curriculum learning, and advanced training strategies.
4
 
5
  ## 🎯 Project Overview
6
 
 
11
  - **🧠 Enhanced Two-Tower Architecture**: 128D embeddings with temperature scaling and contrastive learning
12
  - **πŸ“š Curriculum Learning**: Progressive training strategy for improved convergence
13
  - **⚑ Real-time Inference**: Sub-100ms recommendation serving with FAISS indexing
14
+ - **πŸ”„ Multi-strategy Recommendations**: Collaborative, **aggregated history content-based**, and hybrid approaches
15
  - **πŸŽͺ Category-Aware Boosting**: Enhanced personalization through user preference alignment
16
  - **πŸ” Interactive Similar Items**: Click-to-explore with 60/40 category-balanced discovery
17
  - **πŸ“Š Comprehensive Analysis**: Quality metrics and performance evaluation tools
 
70
  - **Stage 3**: Complex cases (long history) - 67th+ percentile
71
  - **Adaptive Learning Rates**: Decrease as stages progress for stability
72
 
73
+ ### 5. Aggregated History Content-Based Filtering πŸ”„
74
+ - **Revolutionary Approach**: Uses aggregated user interaction history instead of single-item similarity
75
+ - **Multiple Aggregation Methods**:
76
+ - **Weighted Mean**: Recent interactions weighted higher (exponential decay)
77
+ - **Simple Mean**: Equal weighting of all interactions
78
+ - **Max Pooling**: Element-wise maximum of embeddings
79
+ - **ANN Search**: Direct similarity search using FAISS with aggregated user profile
80
+ - **Enhanced Personalization**: Captures complete user preference profile, not just recent item
81
+ - **Category-Aware**: Analyzes user's full category distribution for balanced recommendations
82
+
83
+ ### 6. Category-Aware Recommendation Engine πŸŽͺ
84
  - **Enhanced Hybrid Recommendations**: Category boosting based on user preferences
85
  - **Category Alignment Analysis**: Measures personalization effectiveness
86
  - **Diversity Controls**: Balanced category representation in recommendations
87
+ - **Subcategory Precision**: 2-level category matching (e.g., "computers.components")
88
+ - **Comprehensive Analysis Tools**: Multi-algorithm comparison and alignment scoring
89
 
90
  ## πŸ“ Project Structure
91
 
 
145
  β”œβ”€β”€ πŸš€ Training Scripts # Multiple training approaches
146
  β”‚ β”œβ”€β”€ run_training_pipeline.py # Main training orchestration
147
  β”‚ β”œβ”€β”€ run_2phase_training.py # 2-phase training approach
148
+ β”‚ └── run_joint_training.py # Joint training approach
 
149
  β”‚
 
150
  └── πŸ“‹ requirements.txt # Python dependencies
151
  ```
152
 
 
487
  python api_joint.py
488
  ```
489
 
490
+ ### πŸ”§ System Testing
 
491
  ```bash
492
+ # Test core system components
 
 
 
 
 
 
 
 
 
493
  python -m src.utils.real_user_selector # Demo real user extraction
494
+ python -m src.preprocessing.data_loader # Verify data loading
 
 
 
495
  ```
496
 
497
  ### πŸ§ͺ Frontend Development
 
663
  βœ… **Enhanced Architecture**: 128D embeddings, temperature scaling, contrastive learning
664
  βœ… **Curriculum Learning**: Progressive training for better convergence
665
  βœ… **Category-Aware Recommendations**: Intelligent personalization with diversity
666
+ βœ… **Aggregated Content-Based Filtering**: Revolutionary user history aggregation approach
667
+ βœ… **Enhanced Demographic Support**: Improved cold-start user handling
668
  βœ… **Production Ready**: Scalable API with enhanced frontend features
669
 
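The cold-start path relies on demographics alone when a user has zero interactions. A toy encoder illustrates the idea — a sketch only, with made-up vocabularies and scaling; the real user tower learns embeddings for these seven fields rather than using hand-built one-hot vectors:

```python
import numpy as np

# Hypothetical vocabularies for the categorical demographic fields
GENDERS = ["male", "female"]
PROFESSIONS = ["Engineer", "Teacher", "Other"]
LOCATIONS = ["Urban", "Suburban", "Rural"]

def one_hot(value, vocab):
    vec = np.zeros(len(vocab))
    vec[vocab.index(value)] = 1.0
    return vec

def encode_demographics(age, gender, income, profession="Other", location="Urban"):
    """Fixed-length feature vector usable even with zero interactions (cold start)."""
    return np.concatenate([
        [age / 100.0],               # normalized age
        one_hot(gender, GENDERS),
        [np.log1p(income) / 12.0],   # log-scaled income
        one_hot(profession, PROFESSIONS),
        one_hot(location, LOCATIONS),
    ])

v = encode_demographics(32, "male", 75000)
print(v.shape)
```

Because the vector is defined entirely by profile fields, a brand-new user still gets a usable query for the item index before their first click.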
670
  **πŸŽ‰ Ready to deliver next-generation personalized recommendations!**
 
684
  - **2-Phase API** (`api_2phase.py`) - Specialized for 2-phase training
685
  - **Joint API** (`api_joint.py`) - Optimized for joint training approach
686
 
687
+ ## πŸ”§ Development Tools
688
 
689
+ - **Real User Selection** (`src.utils.real_user_selector`) - Extract real user profiles for testing
690
+ - **Data Loading Utilities** (`src.preprocessing.data_loader`) - Dataset loading and validation
691
 
692
+ ## πŸ§ͺ Development & Testing
693
 
694
  ### Frontend Development
695
  ```bash
analyze_recommendations.py DELETED
@@ -1,543 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Recommendation Analysis Script
4
-
5
- This script compares recommendations from both training approaches:
6
- 1. 2-phase training (pre-trained item tower + joint fine-tuning)
7
- 2. Single joint training (end-to-end optimization)
8
-
9
- It analyzes:
10
- - Category alignment between user interactions and recommendations
11
- - Diversity of recommended categories
12
- - Overlap between the two approaches
13
- - Performance on real users
14
-
15
- Usage:
16
- python analyze_recommendations.py
17
- """
18
-
19
- import os
20
- import sys
21
- import numpy as np
22
- import pandas as pd
23
- from collections import defaultdict, Counter
24
- from typing import Dict, List, Tuple
25
- import matplotlib.pyplot as plt
26
- import seaborn as sns
27
-
28
- # Add src to path
29
- sys.path.append('src')
30
-
31
- from src.inference.recommendation_engine import RecommendationEngine
32
- from src.utils.real_user_selector import RealUserSelector
33
-
34
-
35
- class RecommendationAnalyzer:
36
- """Analyzer for comparing different recommendation approaches."""
37
-
38
- def __init__(self):
39
- self.recommendation_engine = None
40
- self.real_user_selector = None
41
- self.items_df = None
42
- self.setup_engines()
43
-
44
- def setup_engines(self):
45
- """Setup recommendation engines and data."""
46
- print("Loading recommendation engines...")
47
-
48
- try:
49
- # Load recommendation engine (assumes trained model artifacts exist)
50
- self.recommendation_engine = RecommendationEngine()
51
- print("βœ… Recommendation engine loaded")
52
- except Exception as e:
53
- print(f"❌ Error loading recommendation engine: {e}")
54
- return
55
-
56
- try:
57
- # Load real user selector
58
- self.real_user_selector = RealUserSelector()
59
- print("βœ… Real user selector loaded")
60
- except Exception as e:
61
- print(f"❌ Error loading real user selector: {e}")
62
-
63
- # Load items data for category analysis
64
- self.items_df = pd.read_csv("datasets/items.csv")
65
- print(f"βœ… Loaded {len(self.items_df)} items")
66
-
67
- def get_item_categories(self, item_ids: List[int]) -> List[str]:
68
- """Get category codes for given item IDs."""
69
- categories = []
70
- for item_id in item_ids:
71
- item_row = self.items_df[self.items_df['product_id'] == item_id]
72
- if len(item_row) > 0:
73
- categories.append(item_row.iloc[0]['category_code'])
74
- else:
75
- categories.append('unknown')
76
- return categories
77
-
78
- def analyze_user_recommendations(self,
79
- user_profile: Dict,
80
- recommendation_types: List[str] = None) -> Dict:
81
- """Analyze recommendations for a single user across different approaches."""
82
-
83
- if recommendation_types is None:
84
- recommendation_types = ['collaborative', 'hybrid', 'content']
85
-
86
- results = {
87
- 'user_profile': user_profile,
88
- 'interaction_categories': [],
89
- 'recommendations': {},
90
- 'category_analysis': {}
91
- }
92
-
93
- # Get categories from user's interaction history
94
- if user_profile['interaction_history']:
95
- results['interaction_categories'] = self.get_item_categories(
96
- user_profile['interaction_history']
97
- )
98
-
99
- # Get recommendations for each type
100
- for rec_type in recommendation_types:
101
- try:
102
- if rec_type == 'collaborative':
103
- recs = self.recommendation_engine.recommend_items_collaborative(
104
- age=user_profile['age'],
105
- gender=user_profile['gender'],
106
- income=user_profile['income'],
107
- interaction_history=user_profile['interaction_history'],
108
- k=10
109
- )
110
- elif rec_type == 'hybrid':
111
- recs = self.recommendation_engine.recommend_items_hybrid(
112
- age=user_profile['age'],
113
- gender=user_profile['gender'],
114
- income=user_profile['income'],
115
- interaction_history=user_profile['interaction_history'],
116
- k=10
117
- )
118
- elif rec_type == 'content' and user_profile['interaction_history']:
119
- recs = self.recommendation_engine.recommend_items_content_based(
120
- seed_item_id=user_profile['interaction_history'][-1],
121
- k=10
122
- )
123
- else:
124
- continue
125
-
126
- # Extract item IDs and categories
127
- item_ids = [item_id for item_id, score, info in recs]
128
- rec_categories = self.get_item_categories(item_ids)
129
-
130
- results['recommendations'][rec_type] = {
131
- 'items': recs,
132
- 'item_ids': item_ids,
133
- 'categories': rec_categories,
134
- 'scores': [score for item_id, score, info in recs]
135
- }
136
-
137
- # Analyze category alignment
138
- results['category_analysis'][rec_type] = self.analyze_category_alignment(
139
- results['interaction_categories'],
140
- rec_categories
141
- )
142
-
143
- except Exception as e:
144
- print(f"Error generating {rec_type} recommendations: {e}")
145
-
146
- return results
147
-
148
- def analyze_category_alignment(self,
149
- interaction_categories: List[str],
150
- recommendation_categories: List[str]) -> Dict:
151
- """Analyze alignment between interaction and recommendation categories."""
152
-
153
- if not interaction_categories:
154
- return {
155
- 'overlap_ratio': 0.0,
156
- 'unique_interaction_categories': 0,
157
- 'unique_recommendation_categories': len(set(recommendation_categories)),
158
- 'common_categories': [],
159
- 'category_distribution': Counter(recommendation_categories)
160
- }
161
-
162
- interaction_set = set(interaction_categories)
163
- recommendation_set = set(recommendation_categories)
164
-
165
- common_categories = interaction_set.intersection(recommendation_set)
166
- overlap_ratio = len(common_categories) / len(interaction_set) if interaction_set else 0.0
167
-
168
- return {
169
- 'overlap_ratio': overlap_ratio,
170
- 'unique_interaction_categories': len(interaction_set),
171
- 'unique_recommendation_categories': len(recommendation_set),
172
- 'common_categories': list(common_categories),
173
- 'category_distribution': Counter(recommendation_categories),
174
- 'interaction_category_distribution': Counter(interaction_categories)
175
- }
176
-
177
- def compare_recommendation_approaches(self,
178
- users_sample: List[Dict],
179
- approaches: List[str] = None) -> Dict:
180
- """Compare different recommendation approaches across multiple users."""
181
-
182
- if approaches is None:
183
- approaches = ['collaborative', 'hybrid', 'content']
184
-
185
- comparison_results = {
186
- 'approach_stats': defaultdict(list),
187
- 'cross_approach_analysis': {},
188
- 'user_results': []
189
- }
190
-
191
- print(f"Analyzing {len(users_sample)} users across {len(approaches)} approaches...")
192
-
193
- for i, user in enumerate(users_sample):
194
- print(f"Analyzing user {i+1}/{len(users_sample)}...")
195
-
196
- user_results = self.analyze_user_recommendations(user, approaches)
197
- comparison_results['user_results'].append(user_results)
198
-
199
- # Aggregate stats by approach
200
- for approach in approaches:
201
- if approach in user_results['category_analysis']:
202
- analysis = user_results['category_analysis'][approach]
203
- comparison_results['approach_stats'][approach].append({
204
- 'overlap_ratio': analysis['overlap_ratio'],
205
- 'unique_rec_categories': analysis['unique_recommendation_categories'],
206
- 'common_categories_count': len(analysis['common_categories'])
207
- })
208
-
209
- # Calculate aggregate statistics
210
- for approach in approaches:
211
- stats = comparison_results['approach_stats'][approach]
212
- if stats:
213
- comparison_results['approach_stats'][approach] = {
214
- 'avg_overlap_ratio': np.mean([s['overlap_ratio'] for s in stats]),
215
- 'std_overlap_ratio': np.std([s['overlap_ratio'] for s in stats]),
216
- 'avg_unique_categories': np.mean([s['unique_rec_categories'] for s in stats]),
217
- 'avg_common_categories': np.mean([s['common_categories_count'] for s in stats]),
218
- 'total_users': len(stats)
219
- }
220
-
221
- # Cross-approach analysis
222
- comparison_results['cross_approach_analysis'] = self.cross_approach_analysis(
223
- comparison_results['user_results'], approaches
224
- )
225
-
226
- return comparison_results
227
-
228
- def cross_approach_analysis(self, user_results: List[Dict], approaches: List[str]) -> Dict:
229
- """Analyze similarities and differences between approaches."""
230
-
231
- cross_analysis = {
232
- 'item_overlap': defaultdict(dict),
233
- 'category_overlap': defaultdict(dict),
234
- 'score_correlation': defaultdict(dict)
235
- }
236
-
237
- for user_result in user_results:
238
- recommendations = user_result['recommendations']
239
-
240
- # Compare each pair of approaches
241
- for i, approach1 in enumerate(approaches):
242
- for approach2 in approaches[i+1:]:
243
- if approach1 in recommendations and approach2 in recommendations:
244
-
245
- # Item overlap
246
- items1 = set(recommendations[approach1]['item_ids'])
247
- items2 = set(recommendations[approach2]['item_ids'])
248
- item_overlap_ratio = len(items1.intersection(items2)) / len(items1.union(items2))
249
-
250
- # Category overlap
251
- cats1 = set(recommendations[approach1]['categories'])
252
- cats2 = set(recommendations[approach2]['categories'])
253
- cat_overlap_ratio = len(cats1.intersection(cats2)) / len(cats1.union(cats2)) if cats1.union(cats2) else 0
254
-
255
- # Store results
256
- pair_key = f"{approach1}_vs_{approach2}"
257
- if pair_key not in cross_analysis['item_overlap']:
258
- cross_analysis['item_overlap'][pair_key] = []
259
- cross_analysis['category_overlap'][pair_key] = []
260
-
261
- cross_analysis['item_overlap'][pair_key].append(item_overlap_ratio)
262
- cross_analysis['category_overlap'][pair_key].append(cat_overlap_ratio)
263
-
264
- # Calculate averages
265
- for pair_key in cross_analysis['item_overlap']:
266
- cross_analysis['item_overlap'][pair_key] = {
267
- 'avg': np.mean(cross_analysis['item_overlap'][pair_key]),
268
- 'std': np.std(cross_analysis['item_overlap'][pair_key])
269
- }
270
- cross_analysis['category_overlap'][pair_key] = {
271
- 'avg': np.mean(cross_analysis['category_overlap'][pair_key]),
272
- 'std': np.std(cross_analysis['category_overlap'][pair_key])
273
- }
274
-
275
- return cross_analysis
276
-
277
- def generate_report(self, comparison_results: Dict, output_file: str = "recommendation_analysis_report.md"):
278
- """Generate a comprehensive analysis report."""
279
-
280
- report = []
281
- report.append("# Recommendation System Analysis Report")
282
- report.append(f"Generated: {pd.Timestamp.now()}")
283
- report.append("")
284
-
285
- # Overall Statistics
286
- report.append("## Overall Statistics")
287
- report.append("")
288
-
289
- for approach, stats in comparison_results['approach_stats'].items():
290
- if isinstance(stats, dict):
291
- report.append(f"### {approach.title()} Recommendations")
292
- report.append(f"- **Average Category Overlap**: {stats['avg_overlap_ratio']:.3f} Β± {stats['std_overlap_ratio']:.3f}")
293
- report.append(f"- **Average Unique Categories per User**: {stats['avg_unique_categories']:.1f}")
294
- report.append(f"- **Average Common Categories**: {stats['avg_common_categories']:.1f}")
295
- report.append(f"- **Users Analyzed**: {stats['total_users']}")
296
- report.append("")
297
-
298
- # Cross-Approach Analysis
299
- report.append("## Cross-Approach Comparison")
300
- report.append("")
301
-
302
- cross_analysis = comparison_results['cross_approach_analysis']
303
-
304
- report.append("### Item Overlap Between Approaches")
305
- for pair, overlap_stats in cross_analysis['item_overlap'].items():
306
- report.append(f"- **{pair.replace('_', ' ').title()}**: {overlap_stats['avg']:.3f} Β± {overlap_stats['std']:.3f}")
307
- report.append("")
308
-
309
- report.append("### Category Overlap Between Approaches")
310
- for pair, overlap_stats in cross_analysis['category_overlap'].items():
311
- report.append(f"- **{pair.replace('_', ' ').title()}**: {overlap_stats['avg']:.3f} Β± {overlap_stats['std']:.3f}")
312
- report.append("")
313
-
314
- # Category Alignment Analysis
315
- report.append("## Category Alignment Analysis")
316
- report.append("")
317
- report.append("Category alignment measures how well recommendations match the categories")
318
- report.append("of items users have previously interacted with.")
319
- report.append("")
320
-
321
- # Find best performing approach
322
- best_approach = max(
323
- comparison_results['approach_stats'].keys(),
324
- key=lambda k: comparison_results['approach_stats'][k]['avg_overlap_ratio']
325
- if isinstance(comparison_results['approach_stats'][k], dict) else 0
326
- )
327
-
328
- report.append(f"**Best Category Alignment**: {best_approach.title()} approach")
329
- report.append("")
330
-
331
- # Recommendations
332
- report.append("## Key Findings & Recommendations")
333
- report.append("")
334
-
335
- # Analyze overlap ratios to provide insights
336
- overlap_ratios = {
337
- k: v['avg_overlap_ratio'] for k, v in comparison_results['approach_stats'].items()
338
- if isinstance(v, dict)
339
- }
340
-
341
- if overlap_ratios:
342
- avg_overlap = np.mean(list(overlap_ratios.values()))
343
- if avg_overlap > 0.5:
344
- report.append("βœ… **Strong Category Alignment**: Recommendations show good alignment with user interaction patterns.")
345
- elif avg_overlap > 0.3:
346
- report.append("⚠️ **Moderate Category Alignment**: Some alignment present but room for improvement.")
347
- else:
348
- report.append("❌ **Weak Category Alignment**: Recommendations may be too diverse or not well-aligned with user preferences.")
349
-
350
- report.append("")
351
-
352
- # Compare approaches
353
- if len(overlap_ratios) > 1:
354
- sorted_approaches = sorted(overlap_ratios.items(), key=lambda x: x[1], reverse=True)
355
- report.append("### Approach Rankings (by category alignment):")
356
- for i, (approach, ratio) in enumerate(sorted_approaches, 1):
357
- report.append(f"{i}. **{approach.title()}**: {ratio:.3f}")
358
- report.append("")
359
-
360
- # Write report
361
- with open(output_file, 'w') as f:
362
- f.write('\n'.join(report))
363
-
364
- print(f"βœ… Analysis report saved to: {output_file}")
365
- return '\n'.join(report)
366
-
367
- def visualize_results(self, comparison_results: Dict, save_plots: bool = True):
368
- """Create visualizations for the analysis results."""
369
-
370
- # Set up plotting style
371
- plt.style.use('default')
372
- sns.set_palette("husl")
373
-
374
- # Create figure with subplots
375
- fig, axes = plt.subplots(2, 2, figsize=(15, 12))
376
- fig.suptitle('Recommendation System Analysis', fontsize=16, fontweight='bold')
377
-
378
- # 1. Category Overlap by Approach
379
- ax1 = axes[0, 0]
380
- approaches = []
381
- overlap_means = []
382
- overlap_stds = []
383
-
384
- for approach, stats in comparison_results['approach_stats'].items():
385
- if isinstance(stats, dict):
386
- approaches.append(approach.title())
387
- overlap_means.append(stats['avg_overlap_ratio'])
388
- overlap_stds.append(stats['std_overlap_ratio'])
389
-
390
- bars1 = ax1.bar(approaches, overlap_means, yerr=overlap_stds, capsize=5, alpha=0.7)
391
- ax1.set_title('Average Category Overlap by Approach')
392
- ax1.set_ylabel('Category Overlap Ratio')
393
- ax1.set_ylim(0, 1)
394
-
395
- # Add value labels on bars
396
- for bar, mean in zip(bars1, overlap_means):
397
- height = bar.get_height()
398
- ax1.text(bar.get_x() + bar.get_width()/2., height + 0.01,
399
- f'{mean:.3f}', ha='center', va='bottom')
400
-
401
- # 2. Cross-Approach Item Overlap
402
- ax2 = axes[0, 1]
403
- cross_analysis = comparison_results['cross_approach_analysis']
404
-
405
- pair_names = []
406
- item_overlaps = []
407
-
408
- for pair, overlap_stats in cross_analysis['item_overlap'].items():
409
- pair_names.append(pair.replace('_vs_', ' vs ').title())
410
- item_overlaps.append(overlap_stats['avg'])
411
-
412
- if pair_names:
413
- bars2 = ax2.bar(pair_names, item_overlaps, alpha=0.7, color='coral')
414
- ax2.set_title('Item Overlap Between Approaches')
415
- ax2.set_ylabel('Item Overlap Ratio')
416
- ax2.set_ylim(0, 1)
417
- plt.setp(ax2.get_xticklabels(), rotation=45, ha='right')
418
-
419
- # Add value labels
420
- for bar, overlap in zip(bars2, item_overlaps):
421
- height = bar.get_height()
422
- ax2.text(bar.get_x() + bar.get_width()/2., height + 0.01,
423
- f'{overlap:.3f}', ha='center', va='bottom')
424
-
425
- # 3. Category Diversity
426
- ax3 = axes[1, 0]
427
- unique_categories = []
428
- for approach, stats in comparison_results['approach_stats'].items():
429
- if isinstance(stats, dict):
430
- unique_categories.append(stats['avg_unique_categories'])
431
-
432
- bars3 = ax3.bar(approaches, unique_categories, alpha=0.7, color='lightgreen')
433
- ax3.set_title('Average Unique Categories per Recommendation')
434
- ax3.set_ylabel('Number of Unique Categories')
435
-
436
- for bar, cats in zip(bars3, unique_categories):
437
- height = bar.get_height()
438
- ax3.text(bar.get_x() + bar.get_width()/2., height + 0.1,
439
- f'{cats:.1f}', ha='center', va='bottom')
440
-
441
- # 4. Category vs Item Overlap Comparison
442
- ax4 = axes[1, 1]
443
-
444
- if cross_analysis['item_overlap'] and cross_analysis['category_overlap']:
445
- pairs = list(cross_analysis['item_overlap'].keys())
446
- item_overlaps = [cross_analysis['item_overlap'][p]['avg'] for p in pairs]
447
- cat_overlaps = [cross_analysis['category_overlap'][p]['avg'] for p in pairs]
448
-
449
- x = np.arange(len(pairs))
450
- width = 0.35
451
-
452
- bars4a = ax4.bar(x - width/2, item_overlaps, width, label='Item Overlap', alpha=0.7)
453
- bars4b = ax4.bar(x + width/2, cat_overlaps, width, label='Category Overlap', alpha=0.7)
454
-
455
- ax4.set_title('Item vs Category Overlap Between Approaches')
456
- ax4.set_ylabel('Overlap Ratio')
457
- ax4.set_xticks(x)
458
- ax4.set_xticklabels([p.replace('_vs_', ' vs ') for p in pairs], rotation=45, ha='right')
459
- ax4.legend()
460
- ax4.set_ylim(0, 1)
461
-
462
- plt.tight_layout()
463
-
464
- if save_plots:
465
- plt.savefig('recommendation_analysis_plots.png', dpi=300, bbox_inches='tight')
466
- print("βœ… Plots saved to: recommendation_analysis_plots.png")
467
-
468
- plt.show()
469
-
470
-
471
- def main():
472
- """Main function to run the recommendation analysis."""
473
-
474
- print("πŸ” Starting Recommendation Analysis...")
475
- print("=" * 50)
476
-
477
- # Initialize analyzer
478
- analyzer = RecommendationAnalyzer()
479
-
480
- if analyzer.recommendation_engine is None:
481
- print("❌ Cannot proceed without recommendation engine. Please ensure model is trained.")
482
- return
483
-
484
- # Get sample of real users for analysis
485
- print("Getting real user sample...")
486
- try:
487
- real_users = analyzer.real_user_selector.get_real_users(n=20, min_interactions=3)
488
- print(f"βœ… Loaded {len(real_users)} real users for analysis")
489
- except Exception as e:
490
- print(f"❌ Error loading real users: {e}")
491
- # Fallback to synthetic users
492
- real_users = [
493
- {
494
- 'age': 32, 'gender': 'male', 'income': 75000,
495
- 'interaction_history': [1000978, 1001588, 1001618, 1002039]
496
- },
497
- {
498
- 'age': 28, 'gender': 'female', 'income': 45000,
499
- 'interaction_history': [1003456, 1004567, 1005678]
500
- },
501
- {
502
- 'age': 45, 'gender': 'male', 'income': 85000,
503
- 'interaction_history': [1006789, 1007890, 1008901, 1009012, 1010123]
504
- }
505
- ]
506
- print(f"Using {len(real_users)} synthetic users for analysis")
507
-
508
- # Run comprehensive analysis
509
- print("Running recommendation analysis...")
510
- approaches = ['collaborative', 'hybrid', 'content']
511
-
512
- comparison_results = analyzer.compare_recommendation_approaches(
513
- users_sample=real_users,
514
- approaches=approaches
515
- )
516
-
517
- # Generate report
518
- print("Generating analysis report...")
519
- report = analyzer.generate_report(comparison_results)
520
-
521
- # Create visualizations
522
- print("Creating visualizations...")
523
- try:
524
- analyzer.visualize_results(comparison_results, save_plots=True)
525
- except Exception as e:
526
- print(f"Warning: Could not create visualizations: {e}")
527
-
528
- # Print summary
529
- print("\n" + "=" * 50)
530
- print("πŸ“Š ANALYSIS SUMMARY")
531
- print("=" * 50)
532
-
533
- for approach, stats in comparison_results['approach_stats'].items():
534
- if isinstance(stats, dict):
535
- print(f"{approach.title()}: {stats['avg_overlap_ratio']:.3f} avg category overlap")
536
-
537
- print(f"\nβœ… Analysis complete! Check:")
538
- print(" πŸ“„ recommendation_analysis_report.md")
539
- print(" πŸ“Š recommendation_analysis_plots.png")
540
-
541
-
542
- if __name__ == "__main__":
543
- main()
api/main.py CHANGED
@@ -33,7 +33,6 @@ app.add_middleware(
33
 
34
  # Global instances
35
  recommendation_engine = None
36
- enhanced_recommendation_engine = None
37
  real_user_selector = None
38
 
39
 
@@ -75,13 +74,17 @@ class UserProfile(BaseModel):
75
  age: int
76
  gender: str # "male" or "female"
77
  income: float
 
 
 
 
78
  interaction_history: Optional[List[int]] = []
79
 
80
 
81
  class RecommendationRequest(BaseModel):
82
  user_profile: UserProfile
83
  num_recommendations: int = 10
84
- recommendation_type: str = "hybrid" # "collaborative", "content", "hybrid", "enhanced", "enhanced_128d", "category_focused"
85
  collaborative_weight: Optional[float] = 0.7
86
  category_boost: Optional[float] = 1.5 # For enhanced recommendations
87
  enable_category_boost: Optional[bool] = True
@@ -132,6 +135,10 @@ class RealUserProfile(BaseModel):
132
  age: int
133
  gender: str
134
  income: int
 
 
 
 
135
  interaction_history: List[int]
136
  interaction_stats: Dict[str, int]
137
  interaction_pattern: str
@@ -169,39 +176,24 @@ class EnrichedBehavioralPatternsResponse(BaseModel):
169
 
170
  @app.on_event("startup")
171
  async def startup_event():
172
- """Initialize the recommendation engines and real user selector on startup."""
173
- global recommendation_engine, enhanced_recommendation_engine, real_user_selector
174
 
175
  try:
176
- print("Loading recommendation engine...")
177
  recommendation_engine = RecommendationEngine()
178
- print("Recommendation engine loaded successfully!")
 
179
  except Exception as e:
180
- print(f"Error loading recommendation engine: {e}")
181
  recommendation_engine = None
182
 
183
- try:
184
- print("Loading enhanced recommendation engine...")
185
- # Try enhanced 128D engine first, fallback to regular enhanced
186
- try:
187
- from src.inference.enhanced_recommendation_engine_128d import Enhanced128DRecommendationEngine
188
- enhanced_recommendation_engine = Enhanced128DRecommendationEngine()
189
- print("βœ… Using Enhanced 128D Recommendation Engine")
190
- except:
191
- from src.inference.enhanced_recommendation_engine import EnhancedRecommendationEngine
192
- enhanced_recommendation_engine = EnhancedRecommendationEngine()
193
- print("⚠️ Using fallback Enhanced Recommendation Engine")
194
- print("Enhanced recommendation engine loaded successfully!")
195
- except Exception as e:
196
- print(f"Error loading enhanced recommendation engine: {e}")
197
- enhanced_recommendation_engine = None
198
-
199
  try:
200
  print("Loading real user selector...")
201
  real_user_selector = RealUserSelector()
202
- print("Real user selector loaded successfully!")
203
  except Exception as e:
204
- print(f"Error loading real user selector: {e}")
205
  real_user_selector = None
206
 
207
 
@@ -211,7 +203,12 @@ async def root():
211
  return {
212
  "message": "Two-Tower Recommendation API",
213
  "version": "1.0.0",
214
- "status": "active" if recommendation_engine is not None else "initialization_failed"
 
 
 
 
 
215
  }
216
 
217
 
@@ -220,7 +217,13 @@ async def health_check():
220
  """Health check endpoint."""
221
  return {
222
  "status": "healthy" if recommendation_engine is not None else "unhealthy",
223
- "engine_loaded": recommendation_engine is not None
 
 
 
 
 
 
224
  }
225
 
226
 
@@ -378,6 +381,10 @@ async def get_recommendations(request: RecommendationRequest):
378
  age=user_profile.age,
379
  gender=user_profile.gender,
380
  income=user_profile.income,
 
 
 
 
381
  interaction_history=filtered_interaction_history,
382
  k=request.num_recommendations * 2 # Get more to allow for filtering
383
  )
@@ -390,11 +397,11 @@ async def get_recommendations(request: RecommendationRequest):
390
  (f" in category '{request.selected_category}'" if request.selected_category else "")
391
  )
392
 
393
- # Use most recent interaction as seed
394
- seed_item = filtered_interaction_history[-1]
395
- recommendations = recommendation_engine.recommend_items_content_based(
396
- seed_item_id=seed_item,
397
- k=request.num_recommendations * 2 # Get more to allow for filtering
398
  )
399
 
400
  elif request.recommendation_type == "hybrid":
@@ -402,79 +409,40 @@ async def get_recommendations(request: RecommendationRequest):
402
  age=user_profile.age,
403
  gender=user_profile.gender,
404
  income=user_profile.income,
 
 
 
 
405
  interaction_history=filtered_interaction_history,
406
  k=request.num_recommendations * 2, # Get more to allow for filtering
407
  collaborative_weight=request.collaborative_weight
408
  )
409
 
410
- elif request.recommendation_type == "enhanced":
411
- if enhanced_recommendation_engine is None:
412
- raise HTTPException(status_code=503, detail="Enhanced recommendation engine not available")
413
-
414
- # Check if it's the 128D engine or fallback
415
- if hasattr(enhanced_recommendation_engine, 'recommend_items_enhanced'):
416
- # 128D Enhanced engine
417
- recommendations = enhanced_recommendation_engine.recommend_items_enhanced(
418
- age=user_profile.age,
419
- gender=user_profile.gender,
420
- income=user_profile.income,
421
- interaction_history=filtered_interaction_history,
422
- k=request.num_recommendations * 2, # Get more to allow for filtering
423
- diversity_weight=0.3 if request.enable_diversity else 0.0,
424
- category_boost=request.category_boost if request.enable_category_boost else 1.0
425
- )
426
- else:
427
- # Fallback enhanced engine
428
- recommendations = enhanced_recommendation_engine.recommend_items_enhanced_hybrid(
429
- age=user_profile.age,
430
- gender=user_profile.gender,
431
- income=user_profile.income,
432
- interaction_history=filtered_interaction_history,
433
- k=request.num_recommendations * 2, # Get more to allow for filtering
434
- collaborative_weight=request.collaborative_weight,
435
- category_boost=request.category_boost,
436
- enable_category_boost=request.enable_category_boost,
437
- enable_diversity=request.enable_diversity
438
- )
439
-
440
- elif request.recommendation_type == "enhanced_128d":
441
- if enhanced_recommendation_engine is None or not hasattr(enhanced_recommendation_engine, 'recommend_items_enhanced'):
442
- raise HTTPException(status_code=503, detail="Enhanced 128D recommendation engine not available")
443
-
444
- recommendations = enhanced_recommendation_engine.recommend_items_enhanced(
445
  age=user_profile.age,
446
  gender=user_profile.gender,
447
  income=user_profile.income,
 
 
 
 
448
  interaction_history=filtered_interaction_history,
449
- k=request.num_recommendations * 2, # Get more to allow for filtering
450
- diversity_weight=0.3 if request.enable_diversity else 0.0,
451
- category_boost=request.category_boost if request.enable_category_boost else 1.0
452
- )
453
-
454
- elif request.recommendation_type == "category_focused":
455
- if enhanced_recommendation_engine is None:
456
- raise HTTPException(status_code=503, detail="Enhanced recommendation engine not available")
457
-
458
- recommendations = enhanced_recommendation_engine.recommend_items_category_focused(
459
- age=user_profile.age,
460
- gender=user_profile.gender,
461
- income=user_profile.income,
462
- interaction_history=filtered_interaction_history,
463
- k=request.num_recommendations * 2, # Get more to allow for filtering
464
- focus_percentage=0.8
465
  )
466
 
467
  else:
468
  raise HTTPException(
469
  status_code=400,
470
- detail="Invalid recommendation_type. Must be 'collaborative', 'content', 'hybrid', 'enhanced', 'enhanced_128d', or 'category_focused'"
471
  )
472
 
473
  # Apply category filtering to final recommendations if needed
474
  if request.selected_category:
475
  recommendations = filter_recommendations_by_category(recommendations, request.selected_category)
476
- # Limit to requested number after filtering
477
- recommendations = recommendations[:request.num_recommendations]
 
478
 
479
  # Format response
480
  formatted_recommendations = []
@@ -543,8 +511,12 @@ async def predict_user_item_rating(request: RatingPredictionRequest):
543
  age=user_profile.age,
544
  gender=user_profile.gender,
545
  income=user_profile.income,
546
- interaction_history=user_profile.interaction_history,
547
- item_id=request.item_id
 
 
 
 
548
  )
549
 
550
  item_info = recommendation_engine._get_item_info(request.item_id)
 
33
 
34
  # Global instances
35
  recommendation_engine = None
 
36
  real_user_selector = None
37
 
38
 
 
74
  age: int
75
  gender: str # "male" or "female"
76
  income: float
77
+ profession: Optional[str] = "Other"
78
+ location: Optional[str] = "Urban"
79
+ education_level: Optional[str] = "High School"
80
+ marital_status: Optional[str] = "Single"
81
  interaction_history: Optional[List[int]] = []
82
 
83
 
84
  class RecommendationRequest(BaseModel):
85
  user_profile: UserProfile
86
  num_recommendations: int = 10
87
+ recommendation_type: str = "hybrid" # "collaborative", "content" (aggregated history), "hybrid"
88
  collaborative_weight: Optional[float] = 0.7
89
  category_boost: Optional[float] = 1.5 # For enhanced recommendations
90
  enable_category_boost: Optional[bool] = True
 
135
  age: int
136
  gender: str
137
  income: int
138
+ profession: Optional[str] = None
139
+ location: Optional[str] = None
140
+ education_level: Optional[str] = None
141
+ marital_status: Optional[str] = None
142
  interaction_history: List[int]
143
  interaction_stats: Dict[str, int]
144
  interaction_pattern: str
 

 @app.on_event("startup")
 async def startup_event():
+    """Initialize the recommendation engine and real user selector on startup."""
+    global recommendation_engine, real_user_selector

     try:
+        print("Loading recommendation engine with enhanced demographics...")
         recommendation_engine = RecommendationEngine()
+        print("βœ… Recommendation engine loaded successfully!")
+        print(" Supports 7 demographic features: age, gender, income, profession, location, education, marital_status")
     except Exception as e:
+        print(f"❌ Error loading recommendation engine: {e}")
         recommendation_engine = None

     try:
         print("Loading real user selector...")
         real_user_selector = RealUserSelector()
+        print("βœ… Real user selector loaded successfully!")
     except Exception as e:
+        print(f"❌ Error loading real user selector: {e}")
         real_user_selector = None
 
     return {
         "message": "Two-Tower Recommendation API",
         "version": "1.0.0",
+        "status": "active" if recommendation_engine is not None else "initialization_failed",
+        "enhanced_demographics": True,
+        "supported_demographics": [
+            "age", "gender", "income", "profession",
+            "location", "education_level", "marital_status"
+        ]
     }


     """Health check endpoint."""
     return {
         "status": "healthy" if recommendation_engine is not None else "unhealthy",
+        "engine_loaded": recommendation_engine is not None,
+        "enhanced_demographics": True,
+        "demographic_features": 7,
+        "supported_demographics": [
+            "age", "gender", "income", "profession",
+            "location", "education_level", "marital_status"
+        ]
     }
 
             age=user_profile.age,
             gender=user_profile.gender,
             income=user_profile.income,
+            profession=user_profile.profession or "Other",
+            location=user_profile.location or "Urban",
+            education_level=user_profile.education_level or "High School",
+            marital_status=user_profile.marital_status or "Single",
             interaction_history=filtered_interaction_history,
             k=request.num_recommendations * 2  # Get more to allow for filtering
         )

             (f" in category '{request.selected_category}'" if request.selected_category else "")
         )

+        # Use aggregated interaction history for content-based recommendations
+        recommendations = recommendation_engine.recommend_items_content_based_from_history(
+            interaction_history=filtered_interaction_history,
+            k=request.num_recommendations * 2,  # Get more to allow for filtering
+            aggregation_method="weighted_mean"
         )
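The new `recommend_items_content_based_from_history` call aggregates the whole interaction history into a single query instead of seeding from the most recent item. The engine internals are not shown in this diff; a rough sketch of what a `weighted_mean` aggregation over item embeddings could look like, assuming a linear recency weighting (an assumption, not the confirmed implementation):

```python
def weighted_mean_embedding(history_embeddings):
    """Aggregate per-item embeddings with linearly increasing recency weights.

    history_embeddings: list of equal-length float vectors, oldest first.
    Later items get higher weight, so recent interactions dominate the
    aggregated query vector. Illustrative only; the real
    aggregation_method="weighted_mean" may differ.
    """
    n = len(history_embeddings)
    dim = len(history_embeddings[0])
    weights = [i + 1 for i in range(n)]  # 1, 2, ..., n (most recent heaviest)
    total = sum(weights)
    agg = [0.0] * dim
    for w, emb in zip(weights, history_embeddings):
        for d in range(dim):
            agg[d] += w * emb[d] / total
    return agg

# Two-item history: the second (more recent) item gets twice the weight.
print(weighted_mean_embedding([[1.0, 0.0], [0.0, 1.0]]))
```

The aggregated vector would then be used as the query against the FAISS item index, exactly as a single seed-item embedding would be.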

         elif request.recommendation_type == "hybrid":

             age=user_profile.age,
             gender=user_profile.gender,
             income=user_profile.income,
+            profession=user_profile.profession or "Other",
+            location=user_profile.location or "Urban",
+            education_level=user_profile.education_level or "High School",
+            marital_status=user_profile.marital_status or "Single",
             interaction_history=filtered_interaction_history,
             k=request.num_recommendations * 2,  # Get more to allow for filtering
             collaborative_weight=request.collaborative_weight
         )

+        elif request.recommendation_type == "category_boosted":
+            recommendations = recommendation_engine.recommend_items_category_boosted(
             age=user_profile.age,
             gender=user_profile.gender,
             income=user_profile.income,
+            profession=user_profile.profession or "Other",
+            location=user_profile.location or "Urban",
+            education_level=user_profile.education_level or "High School",
+            marital_status=user_profile.marital_status or "Single",
             interaction_history=filtered_interaction_history,
+            k=request.num_recommendations * 2  # Get more to allow for filtering
         )

         else:
             raise HTTPException(
                 status_code=400,
+                detail="Invalid recommendation_type. Must be 'collaborative', 'content', 'hybrid', or 'category_boosted'"
             )

         # Apply category filtering to final recommendations if needed
         if request.selected_category:
             recommendations = filter_recommendations_by_category(recommendations, request.selected_category)
+
+        # Always limit to requested number of recommendations
+        recommendations = recommendations[:request.num_recommendations]
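The over-fetch / filter / truncate pattern introduced here (request `2 * k` candidates, filter by category if one is selected, then always slice back to `k`) can be sketched in isolation; the tuple layout below is an assumption for illustration, not the engine's actual return type:

```python
def select_recommendations(candidates, k, category=None):
    """Over-fetch, optionally filter by category, then cap at k.

    candidates: ranked list of (item_id, score, category) tuples.
    Mirrors the endpoint logic above: filtering happens first, and the
    final list is always limited to the requested count, so callers
    never receive more than k items even when no filter applies.
    """
    if category is not None:
        candidates = [c for c in candidates if c[2] == category]
    return candidates[:k]

# Four candidates over-fetched for k=2; filtering keeps ranked order.
ranked = [(1, 0.9, "electronics"), (2, 0.8, "apparel"),
          (3, 0.7, "electronics"), (4, 0.6, "electronics")]
print(select_recommendations(ranked, 2, "electronics"))
```

Fetching `2 * k` before filtering gives the category filter headroom, so a selective category still usually yields a full page of `k` results.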

         # Format response
         formatted_recommendations = []

             age=user_profile.age,
             gender=user_profile.gender,
             income=user_profile.income,
+            item_id=request.item_id,
+            profession=user_profile.profession or "Other",
+            location=user_profile.location or "Urban",
+            education_level=user_profile.education_level or "High School",
+            marital_status=user_profile.marital_status or "Single",
+            interaction_history=user_profile.interaction_history
         )

         item_info = recommendation_engine._get_item_info(request.item_id)
api_2phase.py DELETED
@@ -1,521 +0,0 @@
-#!/usr/bin/env python3
-"""
-API for 2-Phase Trained Recommendation System
-
-This API serves recommendations from a model trained using the 2-phase approach:
-1. Pre-trained item tower
-2. Joint training with fine-tuned item tower
-
-Usage:
-    python api_2phase.py
-
-Then access: http://localhost:8000
-"""
-
-from fastapi import FastAPI, HTTPException
-from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel
-from typing import List, Optional, Dict, Any
-import uvicorn
-import os
-import sys
-import pandas as pd
-
-# Add src to path for imports and set working directory
-parent_dir = os.path.dirname(__file__)
-sys.path.append(parent_dir)
-os.chdir(parent_dir)  # Change to project root directory
-
-from src.inference.recommendation_engine import RecommendationEngine
-from src.utils.real_user_selector import RealUserSelector
-
-# Initialize FastAPI app
-app = FastAPI(
-    title="Two-Tower Recommendation API (2-Phase Training)",
-    description="API for serving recommendations using a two-tower architecture trained with 2-phase approach",
-    version="1.0.0-2phase"
-)
-
-# Add CORS middleware
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],  # Configure appropriately for production
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-# Global instances
-recommendation_engine = None
-enhanced_recommendation_engine = None
-real_user_selector = None
-
-
-# Pydantic models for request/response
-class UserProfile(BaseModel):
-    age: int
-    gender: str  # "male" or "female"
-    income: float
-    interaction_history: Optional[List[int]] = []
-
-
-class RecommendationRequest(BaseModel):
-    user_profile: UserProfile
-    num_recommendations: int = 10
-    recommendation_type: str = "hybrid"  # "collaborative", "content", "hybrid", "enhanced", "enhanced_128d", "category_focused"
-    collaborative_weight: Optional[float] = 0.7
-    category_boost: Optional[float] = 1.5  # For enhanced recommendations
-    enable_category_boost: Optional[bool] = True
-    enable_diversity: Optional[bool] = True
-
-
-class ItemSimilarityRequest(BaseModel):
-    item_id: int
-    num_recommendations: int = 10
-
-
-class RatingPredictionRequest(BaseModel):
-    user_profile: UserProfile
-    item_id: int
-
-
-class ItemInfo(BaseModel):
-    product_id: int
-    category_id: int
-    category_code: str
-    brand: str
-    price: float
-
-
-class RecommendationResponse(BaseModel):
-    item_id: int
-    score: float
-    item_info: ItemInfo
-
-
-class RecommendationsResponse(BaseModel):
-    recommendations: List[RecommendationResponse]
-    user_profile: UserProfile
-    recommendation_type: str
-    total_count: int
-    training_approach: str = "2-phase"
-
-
-class RatingPredictionResponse(BaseModel):
-    user_profile: UserProfile
-    item_id: int
-    predicted_rating: float
-    item_info: ItemInfo
-
-
-class RealUserProfile(BaseModel):
-    user_id: int
-    age: int
-    gender: str
-    income: int
-    interaction_history: List[int]
-    interaction_stats: Dict[str, int]
-    interaction_pattern: str
-    summary: str
-
-
-class RealUsersResponse(BaseModel):
-    users: List[RealUserProfile]
-    total_count: int
-    dataset_summary: Dict[str, Any]
-
-
-@app.on_event("startup")
-async def startup_event():
-    """Initialize the recommendation engines and real user selector on startup."""
-    global recommendation_engine, enhanced_recommendation_engine, real_user_selector
-
-    print("πŸš€ Starting 2-Phase Training API...")
-    print(" Training approach: Pre-trained item tower + Joint fine-tuning")
-
-    try:
-        print("Loading 2-phase trained recommendation engine...")
-        recommendation_engine = RecommendationEngine()
-        print("βœ… 2-phase recommendation engine loaded successfully!")
-    except Exception as e:
-        print(f"❌ Error loading recommendation engine: {e}")
-        recommendation_engine = None
-
-    try:
-        print("Loading enhanced recommendation engine...")
-        # Try enhanced 128D engine first, fallback to regular enhanced
-        try:
-            from src.inference.enhanced_recommendation_engine_128d import Enhanced128DRecommendationEngine
-            enhanced_recommendation_engine = Enhanced128DRecommendationEngine()
-            print("βœ… Using Enhanced 128D Recommendation Engine")
-        except:
-            from src.inference.enhanced_recommendation_engine import EnhancedRecommendationEngine
-            enhanced_recommendation_engine = EnhancedRecommendationEngine()
-            print("⚠️ Using fallback Enhanced Recommendation Engine")
-        print("Enhanced recommendation engine loaded successfully!")
-    except Exception as e:
-        print(f"Error loading enhanced recommendation engine: {e}")
-        enhanced_recommendation_engine = None
-
-    try:
-        print("Loading real user selector...")
-        real_user_selector = RealUserSelector()
-        print("Real user selector loaded successfully!")
-    except Exception as e:
-        print(f"Error loading real user selector: {e}")
-        real_user_selector = None
-
-    print("🎯 2-Phase API ready to serve recommendations!")
-
-
-@app.get("/")
-async def root():
-    """Root endpoint with API information."""
-    return {
-        "message": "Two-Tower Recommendation API (2-Phase Training)",
-        "version": "1.0.0-2phase",
-        "training_approach": "2-phase (pre-trained item tower + joint fine-tuning)",
-        "status": "active" if recommendation_engine is not None else "initialization_failed"
-    }
-
-
-@app.get("/health")
-async def health_check():
-    """Health check endpoint."""
-    return {
-        "status": "healthy" if recommendation_engine is not None else "unhealthy",
-        "engine_loaded": recommendation_engine is not None,
-        "training_approach": "2-phase"
-    }
-
-
-@app.get("/model-info")
-async def model_info():
-    """Get information about the loaded model."""
-    if recommendation_engine is None:
-        raise HTTPException(status_code=503, detail="Recommendation engine not available")
-
-    return {
-        "training_approach": "2-phase",
-        "description": "Pre-trained item tower followed by joint training with user tower",
-        "phases": [
-            "Phase 1: Item tower pre-training on item features only",
-            "Phase 2: Joint training of user tower + fine-tuning pre-trained item tower"
-        ],
-        "embedding_dimension": 128,
-        "item_vocab_size": len(recommendation_engine.data_processor.item_vocab) if recommendation_engine.data_processor else "unknown",
-        "artifacts_loaded": {
-            "item_tower_pretrained": "src/artifacts/item_tower_weights",
-            "item_tower_finetuned": "src/artifacts/item_tower_weights_finetuned_best",
-            "user_tower": "src/artifacts/user_tower_weights_best",
-            "rating_model": "src/artifacts/rating_model_weights_best"
-        }
-    }
-
-
-@app.get("/real-users", response_model=RealUsersResponse)
-async def get_real_users(count: int = 100, min_interactions: int = 5):
-    """Get real user profiles with genuine interaction histories."""
-
-    if real_user_selector is None:
-        raise HTTPException(status_code=503, detail="Real user selector not available")
-
-    try:
-        # Get real user profiles
-        real_users = real_user_selector.get_real_users(n=count, min_interactions=min_interactions)
-
-        # Get dataset summary
-        dataset_summary = real_user_selector.get_dataset_summary()
-
-        # Format users for response
-        formatted_users = []
-        for user in real_users:
-            formatted_users.append(RealUserProfile(**user))
-
-        return RealUsersResponse(
-            users=formatted_users,
-            total_count=len(formatted_users),
-            dataset_summary=dataset_summary
-        )
-
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Error retrieving real users: {str(e)}")
-
-
-@app.get("/real-users/{user_id}")
-async def get_real_user_details(user_id: int):
-    """Get detailed interaction breakdown for a specific real user."""
-
-    if real_user_selector is None:
-        raise HTTPException(status_code=503, detail="Real user selector not available")
-
-    try:
-        user_details = real_user_selector.get_user_interaction_details(user_id)
-
-        if "error" in user_details:
-            raise HTTPException(status_code=404, detail=user_details["error"])
-
-        return user_details
-
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Error retrieving user details: {str(e)}")
-
-
-@app.get("/dataset-summary")
-async def get_dataset_summary():
-    """Get summary statistics of the real dataset."""
-
-    if real_user_selector is None:
-        raise HTTPException(status_code=503, detail="Real user selector not available")
-
-    try:
-        return real_user_selector.get_dataset_summary()
-
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Error retrieving dataset summary: {str(e)}")
-
-
-@app.post("/recommendations", response_model=RecommendationsResponse)
-async def get_recommendations(request: RecommendationRequest):
-    """Get item recommendations for a user."""
-
-    if recommendation_engine is None:
-        raise HTTPException(status_code=503, detail="Recommendation engine not available")
-
-    try:
-        user_profile = request.user_profile
-
-        # Generate recommendations based on type
-        if request.recommendation_type == "collaborative":
-            recommendations = recommendation_engine.recommend_items_collaborative(
-                age=user_profile.age,
-                gender=user_profile.gender,
-                income=user_profile.income,
-                interaction_history=user_profile.interaction_history,
-                k=request.num_recommendations
-            )
-
-        elif request.recommendation_type == "content":
-            if not user_profile.interaction_history:
-                raise HTTPException(
-                    status_code=400,
-                    detail="Content-based recommendations require interaction history"
-                )
-
-            # Use most recent interaction as seed
-            seed_item = user_profile.interaction_history[-1]
-            recommendations = recommendation_engine.recommend_items_content_based(
-                seed_item_id=seed_item,
-                k=request.num_recommendations
-            )
-
-        elif request.recommendation_type == "hybrid":
-            recommendations = recommendation_engine.recommend_items_hybrid(
-                age=user_profile.age,
-                gender=user_profile.gender,
-                income=user_profile.income,
-                interaction_history=user_profile.interaction_history,
-                k=request.num_recommendations,
-                collaborative_weight=request.collaborative_weight
-            )
-
-        elif request.recommendation_type == "enhanced":
-            if enhanced_recommendation_engine is None:
-                raise HTTPException(status_code=503, detail="Enhanced recommendation engine not available")
-
-            # Check if it's the 128D engine or fallback
-            if hasattr(enhanced_recommendation_engine, 'recommend_items_enhanced'):
-                # 128D Enhanced engine
-                recommendations = enhanced_recommendation_engine.recommend_items_enhanced(
-                    age=user_profile.age,
-                    gender=user_profile.gender,
-                    income=user_profile.income,
-                    interaction_history=user_profile.interaction_history,
-                    k=request.num_recommendations,
-                    diversity_weight=0.3 if request.enable_diversity else 0.0,
-                    category_boost=request.category_boost if request.enable_category_boost else 1.0
-                )
-            else:
-                # Fallback enhanced engine
-                recommendations = enhanced_recommendation_engine.recommend_items_enhanced_hybrid(
-                    age=user_profile.age,
-                    gender=user_profile.gender,
-                    income=user_profile.income,
-                    interaction_history=user_profile.interaction_history,
-                    k=request.num_recommendations,
-                    collaborative_weight=request.collaborative_weight,
-                    category_boost=request.category_boost,
-                    enable_category_boost=request.enable_category_boost,
-                    enable_diversity=request.enable_diversity
-                )
-
-        elif request.recommendation_type == "enhanced_128d":
-            if enhanced_recommendation_engine is None or not hasattr(enhanced_recommendation_engine, 'recommend_items_enhanced'):
-                raise HTTPException(status_code=503, detail="Enhanced 128D recommendation engine not available")
-
-            recommendations = enhanced_recommendation_engine.recommend_items_enhanced(
-                age=user_profile.age,
-                gender=user_profile.gender,
-                income=user_profile.income,
-                interaction_history=user_profile.interaction_history,
-                k=request.num_recommendations,
-                diversity_weight=0.3 if request.enable_diversity else 0.0,
-                category_boost=request.category_boost if request.enable_category_boost else 1.0
-            )
-
-        elif request.recommendation_type == "category_focused":
-            if enhanced_recommendation_engine is None:
-                raise HTTPException(status_code=503, detail="Enhanced recommendation engine not available")
-
-            recommendations = enhanced_recommendation_engine.recommend_items_category_focused(
-                age=user_profile.age,
-                gender=user_profile.gender,
-                income=user_profile.income,
-                interaction_history=user_profile.interaction_history,
-                k=request.num_recommendations,
-                focus_percentage=0.8
-            )
-
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail="Invalid recommendation_type. Must be 'collaborative', 'content', 'hybrid', 'enhanced', 'enhanced_128d', or 'category_focused'"
-            )
-
-        # Format response
-        formatted_recommendations = []
-        for item_id, score, item_info in recommendations:
-            formatted_recommendations.append(
-                RecommendationResponse(
-                    item_id=item_id,
-                    score=score,
-                    item_info=ItemInfo(**item_info)
-                )
-            )
-
-        return RecommendationsResponse(
-            recommendations=formatted_recommendations,
-            user_profile=user_profile,
-            recommendation_type=request.recommendation_type,
-            total_count=len(formatted_recommendations),
-            training_approach="2-phase"
-        )
-
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Error generating recommendations: {str(e)}")
-
-
-@app.post("/item-similarity", response_model=List[RecommendationResponse])
-async def get_similar_items(request: ItemSimilarityRequest):
-    """Get items similar to a given item."""
-
-    if recommendation_engine is None:
-        raise HTTPException(status_code=503, detail="Recommendation engine not available")
-
-    try:
-        recommendations = recommendation_engine.recommend_items_content_based(
-            seed_item_id=request.item_id,
-            k=request.num_recommendations
-        )
-
-        formatted_recommendations = []
-        for item_id, score, item_info in recommendations:
-            formatted_recommendations.append(
-                RecommendationResponse(
-                    item_id=item_id,
-                    score=score,
-                    item_info=ItemInfo(**item_info)
-                )
-            )
-
-        return formatted_recommendations
-
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Error finding similar items: {str(e)}")
-
-
-@app.post("/predict-rating", response_model=RatingPredictionResponse)
-async def predict_user_item_rating(request: RatingPredictionRequest):
-    """Predict rating for a user-item pair."""
-
-    if recommendation_engine is None:
-        raise HTTPException(status_code=503, detail="Recommendation engine not available")
-
-    try:
-        user_profile = request.user_profile
-
-        predicted_rating = recommendation_engine.predict_rating(
-            age=user_profile.age,
-            gender=user_profile.gender,
-            income=user_profile.income,
-            interaction_history=user_profile.interaction_history,
-            item_id=request.item_id
-        )
-
-        item_info = recommendation_engine._get_item_info(request.item_id)
-
-        return RatingPredictionResponse(
-            user_profile=user_profile,
-            item_id=request.item_id,
-            predicted_rating=predicted_rating,
-            item_info=ItemInfo(**item_info)
-        )
-
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Error predicting rating: {str(e)}")
-
-
-@app.get("/items/{item_id}", response_model=ItemInfo)
-async def get_item_info(item_id: int):
-    """Get information about a specific item."""
-
-    if recommendation_engine is None:
-        raise HTTPException(status_code=503, detail="Recommendation engine not available")
-
-    try:
-        item_info = recommendation_engine._get_item_info(item_id)
-        return ItemInfo(**item_info)
-
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Error retrieving item info: {str(e)}")
-
-
-@app.get("/items")
-async def get_sample_items(limit: int = 20):
-    """Get a sample of items for testing."""
-
-    if recommendation_engine is None:
-        raise HTTPException(status_code=503, detail="Recommendation engine not available")
-
-    try:
-        # Get sample items from the dataframe
-        sample_items = recommendation_engine.items_df.sample(n=min(limit, len(recommendation_engine.items_df)))
-
-        items = []
-        for _, row in sample_items.iterrows():
-            items.append({
-                "product_id": int(row['product_id']),
-                "category_id": int(row['category_id']),
-                "category_code": str(row['category_code']),
-                "brand": str(row['brand']) if pd.notna(row['brand']) else 'Unknown',
-                "price": float(row['price'])
-            })
-
-        return {"items": items, "total": len(items), "training_approach": "2-phase"}
-
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Error retrieving sample items: {str(e)}")
-
-
-if __name__ == "__main__":
-    print("πŸš€ Starting 2-Phase Training Recommendation API...")
-    print("πŸ“Š Training approach: Pre-trained item tower + Joint fine-tuning")
-    print("🌐 Server will be available at: http://localhost:8000")
-    print("πŸ“š API docs at: http://localhost:8000/docs")
-
-    uvicorn.run(
-        "api_2phase:app",
-        host="0.0.0.0",
-        port=8000,
-        reload=True
-    )

api_joint.py DELETED
@@ -1,522 +0,0 @@
-#!/usr/bin/env python3
-"""
-API for Single Joint Trained Recommendation System
-
-This API serves recommendations from a model trained using the single joint approach:
-- Both user and item towers trained simultaneously from scratch
-- End-to-end optimization without pre-training phases
-
-Usage:
-    python api_joint.py
-
-Then access: http://localhost:8000
-"""
-
-from fastapi import FastAPI, HTTPException
-from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel
-from typing import List, Optional, Dict, Any
-import uvicorn
-import os
-import sys
-import pandas as pd
-
-# Add src to path for imports and set working directory
-parent_dir = os.path.dirname(__file__)
-sys.path.append(parent_dir)
-os.chdir(parent_dir)  # Change to project root directory
-
-from src.inference.recommendation_engine import RecommendationEngine
-from src.utils.real_user_selector import RealUserSelector
-
-# Initialize FastAPI app
-app = FastAPI(
-    title="Two-Tower Recommendation API (Single Joint Training)",
-    description="API for serving recommendations using a two-tower architecture trained with single joint approach",
-    version="1.0.0-joint"
-)
-
-# Add CORS middleware
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],  # Configure appropriately for production
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-# Global instances
-recommendation_engine = None
-enhanced_recommendation_engine = None
-real_user_selector = None
-
-
-# Pydantic models for request/response
-class UserProfile(BaseModel):
-    age: int
-    gender: str  # "male" or "female"
-    income: float
-    interaction_history: Optional[List[int]] = []
-
-
-class RecommendationRequest(BaseModel):
-    user_profile: UserProfile
-    num_recommendations: int = 10
-    recommendation_type: str = "hybrid"  # "collaborative", "content", "hybrid", "enhanced", "enhanced_128d", "category_focused"
-    collaborative_weight: Optional[float] = 0.7
-    category_boost: Optional[float] = 1.5  # For enhanced recommendations
-    enable_category_boost: Optional[bool] = True
-    enable_diversity: Optional[bool] = True
-
-
-class ItemSimilarityRequest(BaseModel):
-    item_id: int
-    num_recommendations: int = 10
-
-
-class RatingPredictionRequest(BaseModel):
-    user_profile: UserProfile
-    item_id: int
-
-
-class ItemInfo(BaseModel):
-    product_id: int
-    category_id: int
-    category_code: str
-    brand: str
-    price: float
-
-
-class RecommendationResponse(BaseModel):
-    item_id: int
-    score: float
-    item_info: ItemInfo
-
-
-class RecommendationsResponse(BaseModel):
-    recommendations: List[RecommendationResponse]
-    user_profile: UserProfile
-    recommendation_type: str
-    total_count: int
-    training_approach: str = "single-joint"
-
-
-class RatingPredictionResponse(BaseModel):
-    user_profile: UserProfile
-    item_id: int
-    predicted_rating: float
-    item_info: ItemInfo
-
-
-class RealUserProfile(BaseModel):
-    user_id: int
-    age: int
-    gender: str
-    income: int
-    interaction_history: List[int]
-    interaction_stats: Dict[str, int]
-    interaction_pattern: str
-    summary: str
-
-
-class RealUsersResponse(BaseModel):
-    users: List[RealUserProfile]
-    total_count: int
-    dataset_summary: Dict[str, Any]
-
-
-@app.on_event("startup")
-async def startup_event():
-    """Initialize the recommendation engines and real user selector on startup."""
-    global recommendation_engine, enhanced_recommendation_engine, real_user_selector
-
-    print("πŸš€ Starting Single Joint Training API...")
-    print(" Training approach: End-to-end joint optimization from scratch")
-
-    try:
-        print("Loading single joint trained recommendation engine...")
-        recommendation_engine = RecommendationEngine()
-        print("βœ… Single joint recommendation engine loaded successfully!")
-    except Exception as e:
-        print(f"❌ Error loading recommendation engine: {e}")
-        recommendation_engine = None
-
-    try:
-        print("Loading enhanced recommendation engine...")
-        # Try enhanced 128D engine first, fallback to regular enhanced
-        try:
-            from src.inference.enhanced_recommendation_engine_128d import Enhanced128DRecommendationEngine
-            enhanced_recommendation_engine = Enhanced128DRecommendationEngine()
-            print("βœ… Using Enhanced 128D Recommendation Engine")
-        except:
-            from src.inference.enhanced_recommendation_engine import EnhancedRecommendationEngine
-            enhanced_recommendation_engine = EnhancedRecommendationEngine()
-            print("⚠️ Using fallback Enhanced Recommendation Engine")
-        print("Enhanced recommendation engine loaded successfully!")
-    except Exception as e:
-        print(f"Error loading enhanced recommendation engine: {e}")
-        enhanced_recommendation_engine = None
-
-    try:
-        print("Loading real user selector...")
-        real_user_selector = RealUserSelector()
-        print("Real user selector loaded successfully!")
-    except Exception as e:
-        print(f"Error loading real user selector: {e}")
-        real_user_selector = None
-
-    print("🎯 Single Joint API ready to serve recommendations!")
-
-
-@app.get("/")
-async def root():
-    """Root endpoint with API information."""
-    return {
-        "message": "Two-Tower Recommendation API (Single Joint Training)",
-        "version": "1.0.0-joint",
-        "training_approach": "single-joint (end-to-end optimization from scratch)",
-        "status": "active" if recommendation_engine is not None else "initialization_failed"
-    }
-
-
-@app.get("/health")
-async def health_check():
-    """Health check endpoint."""
-    return {
-        "status": "healthy" if recommendation_engine is not None else "unhealthy",
-        "engine_loaded": recommendation_engine is not None,
-        "training_approach": "single-joint"
-    }
-
-
-@app.get("/model-info")
-async def model_info():
-    """Get information about the loaded model."""
-    if recommendation_engine is None:
-        raise HTTPException(status_code=503, detail="Recommendation engine not available")
-
-    return {
-        "training_approach": "single-joint",
-        "description": "User and item towers trained simultaneously from scratch",
-        "advantages": [
-            "End-to-end optimization for better task alignment",
-            "No pre-training phase required",
-            "Faster overall training pipeline",
-            "Direct optimization for recommendation objectives"
-        ],
-        "embedding_dimension": 128,
-        "item_vocab_size": len(recommendation_engine.data_processor.item_vocab) if recommendation_engine.data_processor else "unknown",
-        "artifacts_loaded": {
-            "user_tower": "src/artifacts/user_tower_weights_best",
-            "item_tower_joint": "src/artifacts/item_tower_weights_finetuned_best",
-            "rating_model": "src/artifacts/rating_model_weights_best"
-        }
-    }
-
-
-@app.get("/real-users", response_model=RealUsersResponse)
-async def get_real_users(count: int = 100, min_interactions: int = 5):
-    """Get real user profiles with genuine interaction histories."""
-
-    if real_user_selector is None:
-        raise HTTPException(status_code=503, detail="Real user selector not available")
-
-    try:
-        # Get real user profiles
-        real_users = real_user_selector.get_real_users(n=count, min_interactions=min_interactions)
-
-        # Get dataset summary
-        dataset_summary = real_user_selector.get_dataset_summary()
-
-        # Format users for response
-        formatted_users = []
-        for user in real_users:
-            formatted_users.append(RealUserProfile(**user))
-
-        return RealUsersResponse(
-            users=formatted_users,
-            total_count=len(formatted_users),
-            dataset_summary=dataset_summary
-        )
-
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Error retrieving real users: {str(e)}")
-
-
-@app.get("/real-users/{user_id}")
-async def get_real_user_details(user_id: int):
-    """Get detailed interaction breakdown for a specific real user."""
-
-    if real_user_selector is None:
-        raise HTTPException(status_code=503, detail="Real user selector not available")
-
-    try:
-        user_details = real_user_selector.get_user_interaction_details(user_id)
-
-        if "error" in user_details:
-            raise HTTPException(status_code=404, detail=user_details["error"])
-
-        return user_details
-
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Error retrieving user details: {str(e)}")
-
-
-@app.get("/dataset-summary")
-async def get_dataset_summary():
-    """Get summary statistics of the real dataset."""
-
-    if real_user_selector is None:
-        raise HTTPException(status_code=503, detail="Real user selector not available")
-
-    try:
-        return real_user_selector.get_dataset_summary()
-
-    except Exception as e:
- except Exception as e:
276
- raise HTTPException(status_code=500, detail=f"Error retrieving dataset summary: {str(e)}")
277
-
278
-
279
- @app.post("/recommendations", response_model=RecommendationsResponse)
280
- async def get_recommendations(request: RecommendationRequest):
281
- """Get item recommendations for a user."""
282
-
283
- if recommendation_engine is None:
284
- raise HTTPException(status_code=503, detail="Recommendation engine not available")
285
-
286
- try:
287
- user_profile = request.user_profile
288
-
289
- # Generate recommendations based on type
290
- if request.recommendation_type == "collaborative":
291
- recommendations = recommendation_engine.recommend_items_collaborative(
292
- age=user_profile.age,
293
- gender=user_profile.gender,
294
- income=user_profile.income,
295
- interaction_history=user_profile.interaction_history,
296
- k=request.num_recommendations
297
- )
298
-
299
- elif request.recommendation_type == "content":
300
- if not user_profile.interaction_history:
301
- raise HTTPException(
302
- status_code=400,
303
- detail="Content-based recommendations require interaction history"
304
- )
305
-
306
- # Use most recent interaction as seed
307
- seed_item = user_profile.interaction_history[-1]
308
- recommendations = recommendation_engine.recommend_items_content_based(
309
- seed_item_id=seed_item,
310
- k=request.num_recommendations
311
- )
312
-
313
- elif request.recommendation_type == "hybrid":
314
- recommendations = recommendation_engine.recommend_items_hybrid(
315
- age=user_profile.age,
316
- gender=user_profile.gender,
317
- income=user_profile.income,
318
- interaction_history=user_profile.interaction_history,
319
- k=request.num_recommendations,
320
- collaborative_weight=request.collaborative_weight
321
- )
322
-
323
- elif request.recommendation_type == "enhanced":
324
- if enhanced_recommendation_engine is None:
325
- raise HTTPException(status_code=503, detail="Enhanced recommendation engine not available")
326
-
327
- # Check if it's the 128D engine or fallback
328
- if hasattr(enhanced_recommendation_engine, 'recommend_items_enhanced'):
329
- # 128D Enhanced engine
330
- recommendations = enhanced_recommendation_engine.recommend_items_enhanced(
331
- age=user_profile.age,
332
- gender=user_profile.gender,
333
- income=user_profile.income,
334
- interaction_history=user_profile.interaction_history,
335
- k=request.num_recommendations,
336
- diversity_weight=0.3 if request.enable_diversity else 0.0,
337
- category_boost=request.category_boost if request.enable_category_boost else 1.0
338
- )
339
- else:
340
- # Fallback enhanced engine
341
- recommendations = enhanced_recommendation_engine.recommend_items_enhanced_hybrid(
342
- age=user_profile.age,
343
- gender=user_profile.gender,
344
- income=user_profile.income,
345
- interaction_history=user_profile.interaction_history,
346
- k=request.num_recommendations,
347
- collaborative_weight=request.collaborative_weight,
348
- category_boost=request.category_boost,
349
- enable_category_boost=request.enable_category_boost,
350
- enable_diversity=request.enable_diversity
351
- )
352
-
353
- elif request.recommendation_type == "enhanced_128d":
354
- if enhanced_recommendation_engine is None or not hasattr(enhanced_recommendation_engine, 'recommend_items_enhanced'):
355
- raise HTTPException(status_code=503, detail="Enhanced 128D recommendation engine not available")
356
-
357
- recommendations = enhanced_recommendation_engine.recommend_items_enhanced(
358
- age=user_profile.age,
359
- gender=user_profile.gender,
360
- income=user_profile.income,
361
- interaction_history=user_profile.interaction_history,
362
- k=request.num_recommendations,
363
- diversity_weight=0.3 if request.enable_diversity else 0.0,
364
- category_boost=request.category_boost if request.enable_category_boost else 1.0
365
- )
366
-
367
- elif request.recommendation_type == "category_focused":
368
- if enhanced_recommendation_engine is None:
369
- raise HTTPException(status_code=503, detail="Enhanced recommendation engine not available")
370
-
371
- recommendations = enhanced_recommendation_engine.recommend_items_category_focused(
372
- age=user_profile.age,
373
- gender=user_profile.gender,
374
- income=user_profile.income,
375
- interaction_history=user_profile.interaction_history,
376
- k=request.num_recommendations,
377
- focus_percentage=0.8
378
- )
379
-
380
- else:
381
- raise HTTPException(
382
- status_code=400,
383
- detail="Invalid recommendation_type. Must be 'collaborative', 'content', 'hybrid', 'enhanced', 'enhanced_128d', or 'category_focused'"
384
- )
385
-
386
- # Format response
387
- formatted_recommendations = []
388
- for item_id, score, item_info in recommendations:
389
- formatted_recommendations.append(
390
- RecommendationResponse(
391
- item_id=item_id,
392
- score=score,
393
- item_info=ItemInfo(**item_info)
394
- )
395
- )
396
-
397
- return RecommendationsResponse(
398
- recommendations=formatted_recommendations,
399
- user_profile=user_profile,
400
- recommendation_type=request.recommendation_type,
401
- total_count=len(formatted_recommendations),
402
- training_approach="single-joint"
403
- )
404
-
405
- except Exception as e:
406
- raise HTTPException(status_code=500, detail=f"Error generating recommendations: {str(e)}")
407
-
408
-
409
- @app.post("/item-similarity", response_model=List[RecommendationResponse])
410
- async def get_similar_items(request: ItemSimilarityRequest):
411
- """Get items similar to a given item."""
412
-
413
- if recommendation_engine is None:
414
- raise HTTPException(status_code=503, detail="Recommendation engine not available")
415
-
416
- try:
417
- recommendations = recommendation_engine.recommend_items_content_based(
418
- seed_item_id=request.item_id,
419
- k=request.num_recommendations
420
- )
421
-
422
- formatted_recommendations = []
423
- for item_id, score, item_info in recommendations:
424
- formatted_recommendations.append(
425
- RecommendationResponse(
426
- item_id=item_id,
427
- score=score,
428
- item_info=ItemInfo(**item_info)
429
- )
430
- )
431
-
432
- return formatted_recommendations
433
-
434
- except Exception as e:
435
- raise HTTPException(status_code=500, detail=f"Error finding similar items: {str(e)}")
436
-
437
-
438
- @app.post("/predict-rating", response_model=RatingPredictionResponse)
439
- async def predict_user_item_rating(request: RatingPredictionRequest):
440
- """Predict rating for a user-item pair."""
441
-
442
- if recommendation_engine is None:
443
- raise HTTPException(status_code=503, detail="Recommendation engine not available")
444
-
445
- try:
446
- user_profile = request.user_profile
447
-
448
- predicted_rating = recommendation_engine.predict_rating(
449
- age=user_profile.age,
450
- gender=user_profile.gender,
451
- income=user_profile.income,
452
- interaction_history=user_profile.interaction_history,
453
- item_id=request.item_id
454
- )
455
-
456
- item_info = recommendation_engine._get_item_info(request.item_id)
457
-
458
- return RatingPredictionResponse(
459
- user_profile=user_profile,
460
- item_id=request.item_id,
461
- predicted_rating=predicted_rating,
462
- item_info=ItemInfo(**item_info)
463
- )
464
-
465
- except Exception as e:
466
- raise HTTPException(status_code=500, detail=f"Error predicting rating: {str(e)}")
467
-
468
-
469
- @app.get("/items/{item_id}", response_model=ItemInfo)
470
- async def get_item_info(item_id: int):
471
- """Get information about a specific item."""
472
-
473
- if recommendation_engine is None:
474
- raise HTTPException(status_code=503, detail="Recommendation engine not available")
475
-
476
- try:
477
- item_info = recommendation_engine._get_item_info(item_id)
478
- return ItemInfo(**item_info)
479
-
480
- except Exception as e:
481
- raise HTTPException(status_code=500, detail=f"Error retrieving item info: {str(e)}")
482
-
483
-
484
- @app.get("/items")
485
- async def get_sample_items(limit: int = 20):
486
- """Get a sample of items for testing."""
487
-
488
- if recommendation_engine is None:
489
- raise HTTPException(status_code=503, detail="Recommendation engine not available")
490
-
491
- try:
492
- # Get sample items from the dataframe
493
- sample_items = recommendation_engine.items_df.sample(n=min(limit, len(recommendation_engine.items_df)))
494
-
495
- items = []
496
- for _, row in sample_items.iterrows():
497
- items.append({
498
- "product_id": int(row['product_id']),
499
- "category_id": int(row['category_id']),
500
- "category_code": str(row['category_code']),
501
- "brand": str(row['brand']) if pd.notna(row['brand']) else 'Unknown',
502
- "price": float(row['price'])
503
- })
504
-
505
- return {"items": items, "total": len(items), "training_approach": "single-joint"}
506
-
507
- except Exception as e:
508
- raise HTTPException(status_code=500, detail=f"Error retrieving sample items: {str(e)}")
509
-
510
-
511
- if __name__ == "__main__":
512
- print("πŸš€ Starting Single Joint Training Recommendation API...")
513
- print("⚑ Training approach: End-to-end joint optimization from scratch")
514
- print("🌐 Server will be available at: http://localhost:8000")
515
- print("πŸ“š API docs at: http://localhost:8000/docs")
516
-
517
- uvicorn.run(
518
- "api_joint:app",
519
- host="0.0.0.0",
520
- port=8000,
521
- reload=True
522
- )
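The removed file above exposed the single-joint API over HTTP. For orientation, a client payload for the (now removed) `/recommendations` route could be sketched as follows. This is an illustrative sketch only: the field names mirror the request-handling code above, but the exact Pydantic request models are not shown in this diff, so treat `build_recommendation_request` and its defaults as assumptions.

```python
import json

# Recommendation strategies accepted by the removed /recommendations
# endpoint (taken from its validation branch above).
VALID_TYPES = {
    "collaborative", "content", "hybrid",
    "enhanced", "enhanced_128d", "category_focused",
}


def build_recommendation_request(age, gender, income, history,
                                 rec_type="hybrid", k=10,
                                 collaborative_weight=0.7):
    """Build a JSON payload shaped like RecommendationRequest.

    Mirrors the server-side checks: rec_type must be a supported
    strategy, and content-based requests need a non-empty history.
    """
    if rec_type not in VALID_TYPES:
        raise ValueError(f"Invalid recommendation_type: {rec_type}")
    if rec_type == "content" and not history:
        raise ValueError("Content-based recommendations require interaction history")
    return {
        "user_profile": {
            "age": age,
            "gender": gender,
            "income": income,
            "interaction_history": list(history),
        },
        "recommendation_type": rec_type,
        "num_recommendations": k,
        "collaborative_weight": collaborative_weight,
    }


payload = build_recommendation_request(30, "male", 50000, [1004856, 1005115])
print(json.dumps(payload, indent=2))
# To exercise a running server:
#   requests.post("http://localhost:8000/recommendations", json=payload)
```

The local validation duplicates the endpoint's 400-level checks so malformed requests can be caught before any network call is made.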
datasets/interactions.csv ADDED
The diff for this file is too large to render. See raw diff
 
datasets/items.csv ADDED
The diff for this file is too large to render. See raw diff
 
datasets/users.csv ADDED
The diff for this file is too large to render. See raw diff
 
frontend/src/App.css CHANGED
@@ -3,6 +3,50 @@
   font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', sans-serif;
 }
 
 /* Performance Monitoring Widget */
 .performance-widget {
   position: fixed;
@@ -1220,6 +1264,24 @@
   font-size: 12px;
 }
 
 .pattern-summary {
   display: flex;
   gap: 20px;
 
   font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', sans-serif;
 }
 
+/* Enhanced Demographics Styling */
+.demographic-features {
+  background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%);
+  padding: 20px;
+  border-radius: 10px;
+  border: 2px solid #dee2e6;
+  margin: 15px 0;
+}
+
+.demographic-features .form-group {
+  position: relative;
+}
+
+.demographic-features .form-group label {
+  font-weight: 600;
+  color: #495057;
+  font-size: 14px;
+  margin-bottom: 8px;
+  display: block;
+}
+
+.demographic-features .form-group select {
+  width: 100%;
+  padding: 10px 12px;
+  border: 2px solid #ced4da;
+  border-radius: 6px;
+  font-size: 14px;
+  background: white;
+  color: #495057;
+  transition: all 0.2s ease;
+}
+
+.demographic-features .form-group select:focus {
+  outline: none;
+  border-color: #007bff;
+  box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.25);
+}
+
+.demographic-features .form-group select:disabled {
+  background-color: #f8f9fa;
+  color: #6c757d;
+  cursor: not-allowed;
+}
+
 /* Performance Monitoring Widget */
 .performance-widget {
   position: fixed;
 
   font-size: 12px;
 }
 
+.pattern-btn.new-user-pattern {
+  border-color: #6c757d;
+  color: #6c757d;
+  background: #f8f9fa;
+}
+
+.pattern-btn.new-user-pattern:hover {
+  background: #6c757d;
+  color: white;
+  box-shadow: 0 4px 8px rgba(108, 117, 125, 0.3);
+}
+
+.pattern-btn.new-user-pattern.active {
+  background: #6c757d;
+  color: white;
+  box-shadow: 0 4px 8px rgba(108, 117, 125, 0.3);
+}
+
 .pattern-summary {
   display: flex;
   gap: 20px;
frontend/src/App.js CHANGED
@@ -6,6 +6,7 @@ const API_BASE_URL = process.env.REACT_APP_API_URL || 'http://localhost:8000';
 
 // Interaction patterns with realistic ratios
 const INTERACTION_PATTERNS = [
   { name: 'Light Browsing', views: 15, carts: 2, purchases: 0 },
   { name: 'Window Shopping', views: 25, carts: 5, purchases: 1 },
   { name: 'Serious Shopper', views: 35, carts: 8, purchases: 3 },
@@ -18,11 +19,15 @@ function App() {
     age: 30,
     gender: 'male',
     income: 50000,
     interaction_history: []
   });
 
   const [recommendationType, setRecommendationType] = useState('hybrid');
-  const [numRecommendations, setNumRecommendations] = useState(100);
   const [collaborativeWeight, setCollaborativeWeight] = useState(0.7);
 
   const [recommendations, setRecommendations] = useState([]);
@@ -301,6 +306,10 @@ function App() {
       age: user.age,
       gender: user.gender,
       income: user.income,
       interaction_history: user.interaction_history.slice(0, 50) // Limit to 50 items
     });
     // Clear any synthetic interactions and expanded states
@@ -443,7 +452,20 @@ function App() {
 
   const handlePatternSelect = (pattern) => {
     setSelectedPattern(pattern);
-    generateRealisticInteractions(pattern);
   };
 
   const toggleInteractionExpand = (interactionId) => {
@@ -472,6 +494,46 @@ function App() {
 
   const counts = getInteractionCounts();
 
   // Calculate category percentages from user interactions
   const getCategoryPercentages = () => {
     console.log('getCategoryPercentages called:', {
@@ -504,10 +566,7 @@ function App() {
     console.log('Enriched behavioral pattern results:', { categoryCounts, totalInteractions });
 
     if (totalInteractions > 0) {
-      const categoryPercentages = {};
-      Object.keys(categoryCounts).forEach(category => {
-        categoryPercentages[category] = ((categoryCounts[category] / totalInteractions) * 100).toFixed(1);
-      });
       console.log('Returning enriched behavioral pattern percentages:', categoryPercentages);
       return categoryPercentages;
     }
@@ -522,20 +581,15 @@ function App() {
 
     interactions.forEach(interaction => {
       console.log('Processing interaction:', interaction);
-      if (interaction.category && interaction.category !== 'Unknown') {
-        const category = interaction.category;
-        categoryCounts[category] = (categoryCounts[category] || 0) + 1;
-        totalInteractions++;
-      }
     });
 
     console.log('Synthetic interaction results:', { categoryCounts, totalInteractions });
 
     if (totalInteractions > 0) {
-      const categoryPercentages = {};
-      Object.keys(categoryCounts).forEach(category => {
-        categoryPercentages[category] = ((categoryCounts[category] / totalInteractions) * 100).toFixed(1);
-      });
       console.log('Returning synthetic percentages:', categoryPercentages);
       return categoryPercentages;
     }
@@ -554,11 +608,7 @@ function App() {
       totalInteractions++;
     });
 
-    const categoryPercentages = {};
-    Object.keys(categoryCounts).forEach(category => {
-      categoryPercentages[category] = ((categoryCounts[category] / totalInteractions) * 100).toFixed(1);
-    });
-
     return categoryPercentages;
   }
@@ -708,7 +758,13 @@ function App() {
 <div className="real-user-stats">
   <div className="user-stat">
     <span className="stat-label">Demographics:</span>
-    <span className="stat-value">{selectedRealUser.age}yr {selectedRealUser.gender}, ${selectedRealUser.income.toLocaleString()}</span>
   </div>
   <div className="user-stat">
     <span className="stat-label">Behavior Pattern:</span>
@@ -847,6 +903,77 @@ function App() {
   />
 </div>
 </div>
 </div>
 
 {/* Random Behavioral Patterns for Custom Users */}
@@ -979,7 +1106,7 @@ function App() {
 <div className="category-percentages">
   {Object.entries(categoryPercentages)
     .sort((a, b) => parseFloat(b[1]) - parseFloat(a[1]))
-    .slice(0, 5)
     .map(([category, percentage]) => (
       <div key={category} className="category-item">
         <div className="category-bar-container">
@@ -1110,7 +1237,7 @@ function App() {
 )}
 
 {/* Category Analysis for Custom Users */}
-{(selectedBehavioralPattern || interactions.length > 0 || userProfile.interaction_history.length > 0) && (
   <div
     key={`category-analysis-${interactions.length}-${selectedBehavioralPattern?.id || 'none'}-${sampleItems.length}`}
     className="category-analysis"
@@ -1128,7 +1255,7 @@ function App() {
 {Object.keys(categoryPercentages).length > 0 ? (
   Object.entries(categoryPercentages)
     .sort((a, b) => parseFloat(b[1]) - parseFloat(a[1]))
-    .slice(0, 5)
     .map(([category, percentage]) => (
       <div key={category} className="category-item">
         <div className="category-bar-container">
@@ -1141,6 +1268,23 @@ function App() {
   <span className="category-percent">{percentage}%</span>
 </div>
 ))
 ) : (
   <div className="category-loading">
     <p>Processing interaction categories...</p>
@@ -1325,18 +1469,24 @@ function App() {
 )}
 
 <h3>Synthetic Interaction Patterns</h3>
-<p>Generate realistic user behavior patterns with proportional view, cart, and purchase events</p>
 
 <div className="pattern-buttons">
   {INTERACTION_PATTERNS.map((pattern, index) => (
     <button
       key={index}
-      className={`pattern-btn ${selectedPattern?.name === pattern.name ? 'active' : ''}`}
       onClick={() => handlePatternSelect(pattern)}
     >
       {pattern.name}
       <br />
-      <small>{pattern.views}V β€’ {pattern.carts}C β€’ {pattern.purchases}P</small>
     </button>
   ))}
   <button
@@ -1347,6 +1497,25 @@ function App() {
     Clear All
   </button>
 </div>
 </>
 )}
 
@@ -1498,11 +1667,10 @@ function App() {
   value={recommendationType}
   onChange={(e) => setRecommendationType(e.target.value)}
 >
-  <option value="hybrid">Hybrid</option>
-  <option value="enhanced">🎯 Enhanced Hybrid (Category-Aware)</option>
-  <option value="category_focused">🎯 Category Focused (80% Match)</option>
   <option value="collaborative">Collaborative Filtering</option>
   <option value="content">Content-Based</option>
 </select>
 </div>
 
@@ -1521,7 +1689,7 @@ function App() {
 </select>
 </div>
 
-{(recommendationType === 'hybrid' || recommendationType === 'enhanced') && (
 <div className="form-group">
   <label htmlFor="collabWeight">Collaborative Weight:</label>
   <input
@@ -1548,7 +1716,13 @@ function App() {
 
 {recommendationType === 'content' && userProfile.interaction_history.length === 0 && (
   <p style={{color: '#dc3545', marginTop: '10px', fontSize: '14px'}}>
-    Content-based recommendations require interaction history. Please select an interaction pattern above.
   </p>
 )}
 </div>
@@ -1567,7 +1741,7 @@ function App() {
 
 <div className="stats">
   <strong>User Profile:</strong> {userProfile.age}yr {userProfile.gender},
-  ${userProfile.income.toLocaleString()} income
   {selectedCategory && (
     <span> | <strong>Category Filter:</strong> <span className="category-filter-display">{selectedCategory.replace(/\./g, ' > ')}</span></span>
   )}
 
6
 
7
  // Interaction patterns with realistic ratios
8
  const INTERACTION_PATTERNS = [
9
+ { name: 'New User (No History)', views: 0, carts: 0, purchases: 0, isNewUser: true },
10
  { name: 'Light Browsing', views: 15, carts: 2, purchases: 0 },
11
  { name: 'Window Shopping', views: 25, carts: 5, purchases: 1 },
12
  { name: 'Serious Shopper', views: 35, carts: 8, purchases: 3 },
 
19
  age: 30,
20
  gender: 'male',
21
  income: 50000,
22
+ profession: 'Technology',
23
+ location: 'Urban',
24
+ education_level: "Bachelor's",
25
+ marital_status: 'Single',
26
  interaction_history: []
27
  });
28
 
29
  const [recommendationType, setRecommendationType] = useState('hybrid');
30
+ const [numRecommendations, setNumRecommendations] = useState(10);
31
  const [collaborativeWeight, setCollaborativeWeight] = useState(0.7);
32
 
33
  const [recommendations, setRecommendations] = useState([]);
 
306
  age: user.age,
307
  gender: user.gender,
308
  income: user.income,
309
+ profession: user.profession || 'Other',
310
+ location: user.location || 'Urban',
311
+ education_level: user.education_level || 'High School',
312
+ marital_status: user.marital_status || 'Single',
313
  interaction_history: user.interaction_history.slice(0, 50) // Limit to 50 items
314
  });
315
  // Clear any synthetic interactions and expanded states
 
452
 
453
  const handlePatternSelect = (pattern) => {
454
  setSelectedPattern(pattern);
455
+
456
+ // Handle New User (zero interactions) pattern specially
457
+ if (pattern.isNewUser) {
458
+ // Clear all interactions and interaction history for new user
459
+ setInteractions([]);
460
+ setUserProfile(prev => ({
461
+ ...prev,
462
+ interaction_history: []
463
+ }));
464
+ console.log('Selected New User pattern - cleared all interactions');
465
+ } else {
466
+ // Generate realistic interactions for other patterns
467
+ generateRealisticInteractions(pattern);
468
+ }
469
  };
470
 
471
  const toggleInteractionExpand = (interactionId) => {
 
494
 
495
  const counts = getInteractionCounts();
496
 
497
+ // Utility function to normalize percentages to sum to exactly 100%
498
+ const normalizePercentages = (categoryCounts, totalInteractions) => {
499
+ if (totalInteractions === 0) return {};
500
+
501
+ const categories = Object.keys(categoryCounts);
502
+ if (categories.length === 0) return {};
503
+
504
+ // Calculate raw percentages
505
+ const rawPercentages = {};
506
+ categories.forEach(category => {
507
+ rawPercentages[category] = (categoryCounts[category] / totalInteractions) * 100;
508
+ });
509
+
510
+ // Round all percentages to 1 decimal place
511
+ const roundedPercentages = {};
512
+ let totalRounded = 0;
513
+ categories.forEach(category => {
514
+ roundedPercentages[category] = Math.round(rawPercentages[category] * 10) / 10;
515
+ totalRounded += roundedPercentages[category];
516
+ });
517
+
518
+ // Adjust the largest category to make total exactly 100%
519
+ const difference = 100.0 - totalRounded;
520
+ if (Math.abs(difference) > 0.01) {
521
+ // Find category with largest raw percentage
522
+ const largestCategory = categories.reduce((max, category) =>
523
+ rawPercentages[category] > rawPercentages[max] ? category : max
524
+ );
525
+ roundedPercentages[largestCategory] = Math.round((roundedPercentages[largestCategory] + difference) * 10) / 10;
526
+ }
527
+
528
+ // Convert to string with 1 decimal place
529
+ const normalizedPercentages = {};
530
+ categories.forEach(category => {
531
+ normalizedPercentages[category] = roundedPercentages[category].toFixed(1);
532
+ });
533
+
534
+ return normalizedPercentages;
535
+ };
536
+
537
  // Calculate category percentages from user interactions
538
  const getCategoryPercentages = () => {
539
  console.log('getCategoryPercentages called:', {
 
566
  console.log('Enriched behavioral pattern results:', { categoryCounts, totalInteractions });
567
 
568
  if (totalInteractions > 0) {
569
+ const categoryPercentages = normalizePercentages(categoryCounts, totalInteractions);
 
 
 
570
  console.log('Returning enriched behavioral pattern percentages:', categoryPercentages);
571
  return categoryPercentages;
572
  }
 
581
 
582
  interactions.forEach(interaction => {
583
  console.log('Processing interaction:', interaction);
584
+ const category = interaction.category_code || interaction.category || 'Unknown';
585
+ categoryCounts[category] = (categoryCounts[category] || 0) + 1;
586
+ totalInteractions++;
 
 
587
  });
588
 
589
  console.log('Synthetic interaction results:', { categoryCounts, totalInteractions });
590
 
591
  if (totalInteractions > 0) {
592
+ const categoryPercentages = normalizePercentages(categoryCounts, totalInteractions);
 
 
 
593
  console.log('Returning synthetic percentages:', categoryPercentages);
594
  return categoryPercentages;
595
  }
 
608
  totalInteractions++;
609
  });
610
 
611
+ const categoryPercentages = normalizePercentages(categoryCounts, totalInteractions);
 
 
 
 
612
  return categoryPercentages;
613
  }
614
 
 
758
  <div className="real-user-stats">
759
  <div className="user-stat">
760
  <span className="stat-label">Demographics:</span>
761
+ <span className="stat-value">
762
+ {selectedRealUser.age}yr {selectedRealUser.gender}, ${selectedRealUser.income.toLocaleString()}
763
+ {selectedRealUser.profession && ` | ${selectedRealUser.profession}`}
764
+ {selectedRealUser.location && ` | ${selectedRealUser.location}`}
765
+ {selectedRealUser.education_level && ` | ${selectedRealUser.education_level}`}
766
+ {selectedRealUser.marital_status && ` | ${selectedRealUser.marital_status}`}
767
+ </span>
768
  </div>
769
  <div className="user-stat">
770
  <span className="stat-label">Behavior Pattern:</span>
 
903
  />
904
  </div>
905
  </div>
906
+
907
+ {/* New Demographic Features */}
908
+ <div className="form-row demographic-features">
909
+ <div className="form-group">
910
+ <label htmlFor="profession">Profession:</label>
911
+ <select
912
+ id="profession"
913
+ value={userProfile.profession}
914
+ onChange={(e) => handleProfileChange('profession', e.target.value)}
915
+ disabled={useRealUsers && selectedRealUser}
916
+ style={{backgroundColor: useRealUsers && selectedRealUser ? '#f5f5f5' : 'white'}}
917
+ >
918
+ <option value="Technology">Technology</option>
919
+ <option value="Healthcare">Healthcare</option>
920
+ <option value="Education">Education</option>
921
+ <option value="Finance">Finance</option>
922
+ <option value="Retail">Retail</option>
923
+ <option value="Manufacturing">Manufacturing</option>
924
+ <option value="Services">Services</option>
925
+ <option value="Other">Other</option>
926
+ </select>
927
+ </div>
928
+
929
+ <div className="form-group">
930
+ <label htmlFor="location">Location:</label>
931
+ <select
932
+ id="location"
933
+ value={userProfile.location}
934
+ onChange={(e) => handleProfileChange('location', e.target.value)}
935
+ disabled={useRealUsers && selectedRealUser}
936
+ style={{backgroundColor: useRealUsers && selectedRealUser ? '#f5f5f5' : 'white'}}
937
+ >
938
+ <option value="Urban">Urban</option>
939
+ <option value="Suburban">Suburban</option>
940
+ <option value="Rural">Rural</option>
941
+ </select>
942
+ </div>
943
+
944
+ <div className="form-group">
945
+ <label htmlFor="education_level">Education Level:</label>
946
+ <select
947
+ id="education_level"
948
+ value={userProfile.education_level}
949
+ onChange={(e) => handleProfileChange('education_level', e.target.value)}
950
+ disabled={useRealUsers && selectedRealUser}
951
+ style={{backgroundColor: useRealUsers && selectedRealUser ? '#f5f5f5' : 'white'}}
952
+ >
953
+ <option value="High School">High School</option>
954
+ <option value="Some College">Some College</option>
955
+ <option value="Bachelor's">Bachelor's</option>
956
+ <option value="Master's">Master's</option>
957
+ <option value="PhD+">PhD+</option>
958
+ </select>
959
+ </div>
960
+
961
+ <div className="form-group">
962
+ <label htmlFor="marital_status">Marital Status:</label>
963
+ <select
964
+ id="marital_status"
965
+ value={userProfile.marital_status}
966
+ onChange={(e) => handleProfileChange('marital_status', e.target.value)}
967
+ disabled={useRealUsers && selectedRealUser}
968
+ style={{backgroundColor: useRealUsers && selectedRealUser ? '#f5f5f5' : 'white'}}
969
+ >
970
+ <option value="Single">Single</option>
971
+ <option value="Married">Married</option>
972
+ <option value="Divorced">Divorced</option>
973
+ <option value="Widowed">Widowed</option>
974
+ </select>
975
+ </div>
976
+ </div>
977
  </div>
978
 
979
  {/* Random Behavioral Patterns for Custom Users */}
 
1106
  <div className="category-percentages">
1107
  {Object.entries(categoryPercentages)
1108
  .sort((a, b) => parseFloat(b[1]) - parseFloat(a[1]))
1109
+ .slice(0, 10)
1110
  .map(([category, percentage]) => (
1111
  <div key={category} className="category-item">
1112
  <div className="category-bar-container">
 
1237
  )}
1238
 
1239
  {/* Category Analysis for Custom Users */}
1240
+ {(selectedBehavioralPattern || interactions.length > 0 || userProfile.interaction_history.length > 0 || (recommendations.length > 0 && selectedPattern?.isNewUser)) && (
1241
  <div
1242
  key={`category-analysis-${interactions.length}-${selectedBehavioralPattern?.id || 'none'}-${sampleItems.length}`}
1243
  className="category-analysis"
 
1255
  {Object.keys(categoryPercentages).length > 0 ? (
1256
  Object.entries(categoryPercentages)
1257
  .sort((a, b) => parseFloat(b[1]) - parseFloat(a[1]))
1258
+ .slice(0, 10)
1259
  .map(([category, percentage]) => (
1260
  <div key={category} className="category-item">
1261
  <div className="category-bar-container">
 
1268
  <span className="category-percent">{percentage}%</span>
1269
  </div>
1270
  ))
1271
+ ) : selectedPattern?.isNewUser ? (
1272
+ <div className="new-user-category-message">
1273
+ <div style={{
1274
+ padding: '20px',
1275
+ backgroundColor: '#f8f9fa',
1276
+ border: '2px dashed #6c757d',
1277
+ borderRadius: '8px',
1278
+ textAlign: 'center',
1279
+ color: '#495057'
1280
+ }}>
1281
+ <h6 style={{margin: '0 0 8px 0', color: '#343a40'}}>πŸ†• New User - No History</h6>
1282
+ <p style={{margin: '0', fontSize: '14px'}}>
1283
+ No category preferences yet.<br />
1284
+ Recommendations are based on demographics only.
1285
+ </p>
1286
+ </div>
1287
+ </div>
1288
  ) : (
1289
  <div className="category-loading">
1290
  <p>Processing interaction categories...</p>
 
1469
  )}
1470
 
1471
  <h3>Synthetic Interaction Patterns</h3>
1472
+ <p>Generate realistic user behavior patterns with proportional view, cart, and purchase events. Choose "New User" to test cold-start scenarios.</p>
1473
 
1474
  <div className="pattern-buttons">
1475
  {INTERACTION_PATTERNS.map((pattern, index) => (
1476
  <button
1477
  key={index}
1478
+ className={`pattern-btn ${selectedPattern?.name === pattern.name ? 'active' : ''} ${pattern.isNewUser ? 'new-user-pattern' : ''}`}
1479
  onClick={() => handlePatternSelect(pattern)}
1480
  >
1481
  {pattern.name}
1482
  <br />
1483
+ <small>
1484
+ {pattern.isNewUser ? (
1485
+ <span style={{fontStyle: 'italic', color: '#6c757d'}}>Cold Start User</span>
1486
+ ) : (
1487
+ `${pattern.views}V β€’ ${pattern.carts}C β€’ ${pattern.purchases}P`
1488
+ )}
1489
+ </small>
1490
  </button>
1491
  ))}
1492
  <button
 
1497
  Clear All
1498
  </button>
1499
  </div>
1500
+
1501
+ {/* Show informational message for New User pattern */}
1502
+ {selectedPattern?.isNewUser && (
1503
+ <div style={{
1504
+ backgroundColor: '#e3f2fd',
1505
+ border: '1px solid #90caf9',
1506
+ borderRadius: '8px',
1507
+ padding: '15px',
1508
+ margin: '15px 0',
1509
+ color: '#1565c0'
1510
+ }}>
1511
+ <h4 style={{margin: '0 0 10px 0', color: '#0d47a1'}}>πŸ†• New User (Cold Start) Selected</h4>
1512
+ <p style={{margin: '0', fontSize: '14px', lineHeight: '1.4'}}>
1513
+ Testing cold-start scenario with no interaction history.
1514
+ <br /><strong>Compatible algorithms:</strong> Collaborative βœ…, Hybrid βœ… (demographics-based)
1515
+ <br /><strong>Incompatible:</strong> Content-based ❌, Category-boosted ❌ (require history)
1516
+ </p>
1517
+ </div>
1518
+ )}
1519
  </>
1520
  )}
1521
 
 
1667
  value={recommendationType}
1668
  onChange={(e) => setRecommendationType(e.target.value)}
1669
  >
1670
+ <option value="hybrid">Hybrid (Recommended)</option>
 
 
1671
  <option value="collaborative">Collaborative Filtering</option>
1672
  <option value="content">Content-Based</option>
1673
+ <option value="category_boosted">πŸ“Š Category Boosted (50% from user categories)</option>
1674
  </select>
1675
  </div>
1676
 
 
1689
  </select>
1690
  </div>
1691
 
1692
+ {recommendationType === 'hybrid' && (
1693
  <div className="form-group">
1694
  <label htmlFor="collabWeight">Collaborative Weight:</label>
1695
  <input
 
1716
 
1717
  {recommendationType === 'content' && userProfile.interaction_history.length === 0 && (
1718
  <p style={{color: '#dc3545', marginTop: '10px', fontSize: '14px'}}>
1719
+ ⚠️ Content-based recommendations require interaction history. Please select a pattern with interactions above, or choose 'Collaborative' or 'Hybrid' for new users.
1720
+ </p>
1721
+ )}
1722
+
1723
+ {recommendationType === 'category_boosted' && userProfile.interaction_history.length === 0 && (
1724
+ <p style={{color: '#dc3545', marginTop: '10px', fontSize: '14px'}}>
1725
+ ⚠️ Category-boosted recommendations require interaction history to analyze preferences. Please select a pattern with interactions above, or choose 'Collaborative' for new users.
1726
  </p>
1727
  )}
1728
  </div>
 
1741
 
1742
  <div className="stats">
1743
  <strong>User Profile:</strong> {userProfile.age}yr {userProfile.gender},
1744
+ ${userProfile.income.toLocaleString()} income, {userProfile.profession}, {userProfile.location}, {userProfile.education_level}, {userProfile.marital_status}
1745
  {selectedCategory && (
1746
  <span> | <strong>Category Filter:</strong> <span className="category-filter-display">{selectedCategory.replace(/\./g, ' > ')}</span></span>
1747
  )}
src/data_generation/generate_demographics.py ADDED
@@ -0,0 +1,292 @@
+ import pandas as pd
+ import numpy as np
+ from typing import Dict, List
+ import os
+
+ class DemographicDataGenerator:
+     """Generate realistic categorical demographic data correlated with existing age/income."""
+
+     def __init__(self, seed: int = 42):
+         np.random.seed(seed)
+
+         # Define categorical mappings
+         self.profession_categories = [
+             "Technology", "Healthcare", "Education", "Finance",
+             "Retail", "Manufacturing", "Services", "Other"
+         ]
+
+         self.location_categories = ["Urban", "Suburban", "Rural"]
+
+         self.education_categories = [
+             "High School", "Some College", "Bachelor's", "Master's", "PhD+"
+         ]
+
+         self.marital_categories = ["Single", "Married", "Divorced", "Widowed"]
+
+     def generate_profession(self, age: int, income: float, gender: str) -> str:
+         """Generate profession based on age, income, and gender correlations."""
+
+         # Age-based profession probabilities
+         if age < 25:
+             # Young adults - more likely in retail, services, some tech
+             probs = [0.15, 0.10, 0.08, 0.05, 0.25, 0.10, 0.20, 0.07]
+         elif age < 35:
+             # Early career - tech, healthcare, finance growth
+             probs = [0.25, 0.15, 0.10, 0.15, 0.12, 0.08, 0.10, 0.05]
+         elif age < 50:
+             # Mid career - established in all fields
+             probs = [0.20, 0.18, 0.15, 0.18, 0.08, 0.12, 0.07, 0.02]
+         else:
+             # Senior career - more in education, healthcare, services
+             probs = [0.15, 0.20, 0.20, 0.15, 0.05, 0.15, 0.08, 0.02]
+
+         # Income adjustments
+         if income > 90000:  # High income
+             # Boost tech, finance, healthcare
+             probs[0] *= 1.5  # Technology
+             probs[3] *= 1.5  # Finance
+             probs[1] *= 1.3  # Healthcare
+             probs[4] *= 0.5  # Retail
+             probs[6] *= 0.7  # Services
+         elif income < 40000:  # Lower income
+             # Boost retail, services, manufacturing
+             probs[4] *= 2.0  # Retail
+             probs[6] *= 1.8  # Services
+             probs[5] *= 1.5  # Manufacturing
+             probs[0] *= 0.3  # Technology
+             probs[3] *= 0.3  # Finance
+
+         # Normalize probabilities
+         probs = np.array(probs)
+         probs = probs / np.sum(probs)
+
+         return np.random.choice(self.profession_categories, p=probs)
+
+     def generate_location(self, income: float, profession: str) -> str:
+         """Generate location based on income and profession."""
+
+         # Base probabilities (roughly US distribution)
+         probs = [0.62, 0.27, 0.11]  # Urban, Suburban, Rural
+
+         # Income adjustments
+         if income > 80000:
+             # Higher income -> more suburban
+             probs = [0.45, 0.45, 0.10]
+         elif income < 35000:
+             # Lower income -> more urban/rural
+             probs = [0.70, 0.15, 0.15]
+
+         # Profession adjustments
+         if profession in ["Technology", "Finance"]:
+             # Tech/Finance -> more urban
+             probs[0] *= 1.4
+             probs[2] *= 0.5
+         elif profession in ["Manufacturing", "Other"]:
+             # Manufacturing -> more rural/suburban
+             probs[1] *= 1.3
+             probs[2] *= 1.5
+             probs[0] *= 0.7
+
+         # Normalize
+         probs = np.array(probs)
+         probs = probs / np.sum(probs)
+
+         return np.random.choice(self.location_categories, p=probs)
+
+     def generate_education_level(self, age: int, income: float, profession: str) -> str:
+         """Generate education level based on age, income, and profession."""
+
+         # Base probabilities (roughly US distribution)
+         probs = [0.27, 0.20, 0.33, 0.13, 0.07]  # HS, Some College, Bachelor's, Master's, PhD+
+
+         # Age adjustments (older generations had less college access)
+         if age > 55:
+             probs = [0.40, 0.25, 0.25, 0.08, 0.02]
+         elif age > 40:
+             probs = [0.32, 0.23, 0.30, 0.12, 0.03]
+         elif age < 30:
+             # Younger generation has more education
+             probs = [0.20, 0.15, 0.40, 0.18, 0.07]
+
+         # Income adjustments
+         if income > 100000:
+             # High income -> more advanced degrees
+             probs = [0.10, 0.10, 0.35, 0.30, 0.15]
+         elif income > 70000:
+             # Good income -> more bachelor's/master's
+             probs = [0.15, 0.15, 0.45, 0.20, 0.05]
+         elif income < 40000:
+             # Lower income -> less higher education
+             probs = [0.45, 0.30, 0.20, 0.04, 0.01]
+
+         # Profession adjustments
+         if profession in ["Technology", "Healthcare", "Finance"]:
+             # Professional fields -> more degrees
+             probs = [0.05, 0.10, 0.40, 0.30, 0.15]
+         elif profession == "Education":
+             # Education -> even more advanced degrees
+             probs = [0.02, 0.05, 0.25, 0.45, 0.23]
+         elif profession in ["Retail", "Services", "Manufacturing"]:
+             # Service industries -> less higher education
+             probs = [0.40, 0.25, 0.25, 0.08, 0.02]
+
+         # Normalize
+         probs = np.array(probs)
+         probs = probs / np.sum(probs)
+
+         return np.random.choice(self.education_categories, p=probs)
+
+     def generate_marital_status(self, age: int, gender: str) -> str:
+         """Generate marital status based on age and gender."""
+
+         # Age-based probabilities
+         if age < 25:
+             probs = [0.85, 0.13, 0.02, 0.00]  # Single, Married, Divorced, Widowed
+         elif age < 35:
+             probs = [0.45, 0.50, 0.05, 0.00]
+         elif age < 50:
+             probs = [0.15, 0.70, 0.14, 0.01]
+         elif age < 65:
+             probs = [0.10, 0.65, 0.20, 0.05]
+         else:
+             probs = [0.08, 0.55, 0.15, 0.22]
+
+         # Gender adjustments (women tend to be widowed more often at older ages)
+         if age > 65 and gender == 'female':
+             probs[3] *= 2.0  # More widowed women
+             probs[1] *= 0.8  # Fewer married
+
+         # Normalize
+         probs = np.array(probs)
+         probs = probs / np.sum(probs)
+
+         return np.random.choice(self.marital_categories, p=probs)
+
+     def generate_user_demographics(self, users_df: pd.DataFrame) -> pd.DataFrame:
+         """Generate all demographic features for all users."""
+
+         print(f"Generating demographic data for {len(users_df)} users...")
+
+         # Create a copy to avoid modifying the original
+         enhanced_users = users_df.copy()
+
+         # Generate each demographic feature
+         professions = []
+         locations = []
+         education_levels = []
+         marital_statuses = []
+
+         for idx, row in users_df.iterrows():
+             age = row['age']
+             income = row['income']
+             gender = row['gender']
+
+             # Generate profession first as it influences other features
+             profession = self.generate_profession(age, income, gender)
+             professions.append(profession)
+
+             # Generate location based on income and profession
+             location = self.generate_location(income, profession)
+             locations.append(location)
+
+             # Generate education based on age, income, and profession
+             education = self.generate_education_level(age, income, profession)
+             education_levels.append(education)
+
+             # Generate marital status based on age and gender
+             marital_status = self.generate_marital_status(age, gender)
+             marital_statuses.append(marital_status)
+
+         # Add new columns
+         enhanced_users['profession'] = professions
+         enhanced_users['location'] = locations
+         enhanced_users['education_level'] = education_levels
+         enhanced_users['marital_status'] = marital_statuses
+
+         return enhanced_users
+
+     def print_demographic_statistics(self, users_df: pd.DataFrame):
+         """Print statistics about the generated demographics."""
+
+         print("\n=== Demographic Statistics ===")
+
+         # Profession distribution
+         print(f"\nProfession Distribution:")
+         prof_counts = users_df['profession'].value_counts()
+         for prof, count in prof_counts.items():
+             pct = (count / len(users_df)) * 100
+             print(f"  {prof}: {count:,} ({pct:.1f}%)")
+
+         # Location distribution
+         print(f"\nLocation Distribution:")
+         loc_counts = users_df['location'].value_counts()
+         for loc, count in loc_counts.items():
+             pct = (count / len(users_df)) * 100
+             print(f"  {loc}: {count:,} ({pct:.1f}%)")
+
+         # Education distribution
+         print(f"\nEducation Level Distribution:")
+         edu_counts = users_df['education_level'].value_counts()
+         for edu, count in edu_counts.items():
+             pct = (count / len(users_df)) * 100
+             print(f"  {edu}: {count:,} ({pct:.1f}%)")
+
+         # Marital status distribution
+         print(f"\nMarital Status Distribution:")
+         marital_counts = users_df['marital_status'].value_counts()
+         for status, count in marital_counts.items():
+             pct = (count / len(users_df)) * 100
+             print(f"  {status}: {count:,} ({pct:.1f}%)")
+
+         print(f"\nTotal users: {len(users_df):,}")
+
+         # Cross-tabulations to show correlations
+         print(f"\n=== Key Correlations ===")
+
+         # High-income professions
+         high_income = users_df[users_df['income'] > 80000]
+         print(f"\nTop professions for high income (>${80000:,}):")
+         high_income_prof = high_income['profession'].value_counts(normalize=True) * 100
+         for prof, pct in high_income_prof.head().items():
+             print(f"  {prof}: {pct:.1f}%")
+
+         # Education by profession
+         print(f"\nEducation levels in Technology:")
+         tech_edu = users_df[users_df['profession'] == 'Technology']['education_level'].value_counts(normalize=True) * 100
+         for edu, pct in tech_edu.items():
+             print(f"  {edu}: {pct:.1f}%")
+
+
+ def main():
+     """Main function to generate and save enhanced demographic data."""
+
+     # Load existing users data
+     users_path = "datasets/users.csv"
+     if not os.path.exists(users_path):
+         print(f"Error: {users_path} not found!")
+         return
+
+     print(f"Loading users data from {users_path}")
+     users_df = pd.read_csv(users_path)
+
+     print(f"Original data shape: {users_df.shape}")
+     print(f"Original columns: {list(users_df.columns)}")
+
+     # Generate demographic data
+     generator = DemographicDataGenerator(seed=42)
+     enhanced_users = generator.generate_user_demographics(users_df)
+
+     # Print statistics
+     generator.print_demographic_statistics(enhanced_users)
+
+     # Save enhanced data
+     output_path = "datasets/users_enhanced.csv"
+     enhanced_users.to_csv(output_path, index=False)
+     print(f"\nEnhanced users data saved to {output_path}")
+
+     print(f"Enhanced data shape: {enhanced_users.shape}")
+     print(f"New columns: {list(enhanced_users.columns)}")
+
+
+ if __name__ == "__main__":
+     main()
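The generator above follows one recurring pattern: start from base category probabilities, apply multiplicative adjustments for correlated attributes, renormalize, then draw one category. A minimal standalone sketch of that pattern (names and weights here are illustrative, not taken from the repo):

```python
import numpy as np

CATEGORIES = ["Urban", "Suburban", "Rural"]

def sample_location(income: float, rng: np.random.Generator) -> str:
    # Base distribution, overridden by income-correlated adjustments.
    probs = np.array([0.62, 0.27, 0.11])
    if income > 80000:
        probs = np.array([0.45, 0.45, 0.10])  # higher income -> more suburban
    elif income < 35000:
        probs = np.array([0.70, 0.15, 0.15])  # lower income -> more urban/rural
    probs = probs / probs.sum()  # renormalize after any adjustment
    return rng.choice(CATEGORIES, p=probs)

rng = np.random.default_rng(42)
draws = [sample_location(90000, rng) for _ in range(1000)]
# With these weights, "Rural" should be the rarest bucket for high earners.
print({c: draws.count(c) for c in CATEGORIES})
```

The renormalization step matters: after multiplicative boosts the weights no longer sum to 1, and `np.random.choice`/`Generator.choice` raise an error for an unnormalized `p`.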
src/inference/enhanced_recommendation_engine.py DELETED
@@ -1,303 +0,0 @@
- #!/usr/bin/env python3
- """
- Enhanced recommendation engine with category-aware filtering and improved user alignment.
- """
-
- import numpy as np
- import pandas as pd
- from typing import Dict, List, Tuple, Optional
- from collections import Counter
- import random
-
- import sys
- import os
- sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
-
- from src.inference.recommendation_engine import RecommendationEngine
- from src.utils.real_user_selector import RealUserSelector
-
-
- class EnhancedRecommendationEngine(RecommendationEngine):
-     """Enhanced recommendation engine with category-aware improvements."""
-
-     def __init__(self, artifacts_path: str = "src/artifacts/"):
-         super().__init__(artifacts_path)
-         self.real_user_selector = RealUserSelector()
-
-     def _analyze_user_category_preferences(self, interaction_history: List[int]) -> Dict[str, float]:
-         """Analyze user's category preferences from interaction history."""
-
-         if not interaction_history:
-             return {}
-
-         category_counts = Counter()
-
-         for item_id in interaction_history:
-             # Get item category from items dataframe
-             item_row = self.items_df[self.items_df['product_id'] == item_id]
-             if not item_row.empty:
-                 category = item_row.iloc[0].get('category_code', 'Unknown')
-                 category_counts[category] += 1
-
-         # Convert to percentages
-         total_interactions = sum(category_counts.values())
-         if total_interactions == 0:
-             return {}
-
-         category_preferences = {}
-         for category, count in category_counts.items():
-             category_preferences[category] = count / total_interactions
-
-         return category_preferences
-
-     def _boost_category_aligned_recommendations(self,
-                                                 recommendations: List[Tuple[int, float, Dict]],
-                                                 user_category_preferences: Dict[str, float],
-                                                 boost_factor: float = 1.5) -> List[Tuple[int, float, Dict]]:
-         """Boost recommendations that align with user's category preferences."""
-
-         if not user_category_preferences:
-             return recommendations
-
-         boosted_recs = []
-
-         for item_id, score, item_info in recommendations:
-             item_category = item_info.get('category_code', 'Unknown')
-
-             # Apply category boost if user has preference for this category
-             category_preference = user_category_preferences.get(item_category, 0)
-
-             if category_preference > 0:
-                 # Boost score based on user's preference strength
-                 boosted_score = score * (1 + boost_factor * category_preference)
-                 boosted_recs.append((item_id, boosted_score, item_info))
-             else:
-                 boosted_recs.append((item_id, score, item_info))
-
-         # Re-sort by boosted scores
-         boosted_recs.sort(key=lambda x: x[1], reverse=True)
-         return boosted_recs
-
-     def _diversify_recommendations(self,
-                                    recommendations: List[Tuple[int, float, Dict]],
-                                    max_per_category: int = 3) -> List[Tuple[int, float, Dict]]:
-         """Ensure category diversity in recommendations."""
-
-         category_counts = Counter()
-         diversified_recs = []
-
-         for item_id, score, item_info in recommendations:
-             item_category = item_info.get('category_code', 'Unknown')
-
-             if category_counts[item_category] < max_per_category:
-                 diversified_recs.append((item_id, score, item_info))
-                 category_counts[item_category] += 1
-
-         return diversified_recs
-
-     def recommend_items_enhanced_hybrid(self,
-                                         age: int,
-                                         gender: str,
-                                         income: float,
-                                         interaction_history: List[int] = None,
-                                         k: int = 10,
-                                         collaborative_weight: float = 0.7,
-                                         category_boost: float = 1.5,
-                                         enable_category_boost: bool = True,
-                                         enable_diversity: bool = True,
-                                         max_per_category: int = 3) -> List[Tuple[int, float, Dict]]:
-         """Generate enhanced hybrid recommendations with category awareness."""
-
-         # Start with base hybrid recommendations (get more than needed)
-         base_k = k * 3  # Get 3x more candidates for filtering
-
-         base_recommendations = self.recommend_items_hybrid(
-             age=age,
-             gender=gender,
-             income=income,
-             interaction_history=interaction_history,
-             k=base_k,
-             collaborative_weight=collaborative_weight
-         )
-
-         if not base_recommendations:
-             return []
-
-         # Analyze user's category preferences
-         if enable_category_boost and interaction_history:
-             user_category_preferences = self._analyze_user_category_preferences(interaction_history)
-
-             # Apply category-based boosting
-             base_recommendations = self._boost_category_aligned_recommendations(
-                 base_recommendations,
-                 user_category_preferences,
-                 boost_factor=category_boost
-             )
-
-         # Apply diversity filtering if enabled
-         if enable_diversity:
-             base_recommendations = self._diversify_recommendations(
-                 base_recommendations,
-                 max_per_category=max_per_category
-             )
-
-         # Return top k
-         return base_recommendations[:k]
-
-     def recommend_items_category_focused(self,
-                                          age: int,
-                                          gender: str,
-                                          income: float,
-                                          interaction_history: List[int] = None,
-                                          k: int = 10,
-                                          focus_percentage: float = 0.7) -> List[Tuple[int, float, Dict]]:
-         """Generate recommendations focused on user's preferred categories."""
-
-         if not interaction_history:
-             # Fall back to regular hybrid for users without history
-             return self.recommend_items_hybrid(age, gender, income, interaction_history, k)
-
-         # Analyze user preferences
-         user_category_preferences = self._analyze_user_category_preferences(interaction_history)
-
-         if not user_category_preferences:
-             return self.recommend_items_hybrid(age, gender, income, interaction_history, k)
-
-         # Get top categories (sorted by preference)
-         top_categories = sorted(user_category_preferences.items(),
-                                 key=lambda x: x[1], reverse=True)
-
-         # Determine how many recs to focus on preferred categories
-         focused_k = int(k * focus_percentage)
-         exploration_k = k - focused_k
-
-         # Get base recommendations
-         all_recommendations = self.recommend_items_hybrid(
-             age, gender, income, interaction_history, k * 2
-         )
-
-         # Split into focused and exploration recommendations
-         focused_recs = []
-         exploration_recs = []
-
-         # Get user's top 3 categories
-         preferred_categories = set([cat for cat, _ in top_categories[:3]])
-
-         for item_id, score, item_info in all_recommendations:
-             item_category = item_info.get('category_code', 'Unknown')
-
-             if (item_category in preferred_categories and
-                     len(focused_recs) < focused_k):
-                 focused_recs.append((item_id, score, item_info))
-             elif len(exploration_recs) < exploration_k:
-                 exploration_recs.append((item_id, score, item_info))
-
-         # Combine focused and exploration recommendations
-         final_recommendations = focused_recs + exploration_recs
-
-         return final_recommendations[:k]
-
-     def get_recommendation_explanation(self,
-                                        recommendations: List[Tuple[int, float, Dict]],
-                                        interaction_history: List[int] = None) -> Dict:
-         """Provide explanation for why these recommendations were generated."""
-
-         if not recommendations:
-             return {"message": "No recommendations generated"}
-
-         # Analyze recommendation categories
-         rec_categories = Counter()
-         for _, _, item_info in recommendations:
-             category = item_info.get('category_code', 'Unknown')
-             rec_categories[category] += 1
-
-         explanation = {
-             "total_recommendations": len(recommendations),
-             "categories_covered": len(rec_categories),
-             "category_breakdown": dict(rec_categories.most_common())
-         }
-
-         # Add user preference analysis if history available
-         if interaction_history:
-             user_preferences = self._analyze_user_category_preferences(interaction_history)
-
-             # Calculate alignment
-             user_cats = set(user_preferences.keys())
-             rec_cats = set(rec_categories.keys())
-             alignment = len(user_cats & rec_cats) / len(rec_cats) * 100 if rec_cats else 0
-
-             explanation.update({
-                 "user_category_preferences": user_preferences,
-                 "alignment_percentage": round(alignment, 1),
-                 "matched_categories": list(user_cats & rec_cats),
-                 "new_categories": list(rec_cats - user_cats)
-             })
-
-         return explanation
-
-
- def demo_enhanced_recommendations():
-     """Demo the enhanced recommendation engine."""
-
-     print("🚀 ENHANCED RECOMMENDATION ENGINE DEMO")
-     print("="*70)
-
-     # Initialize enhanced engine
-     engine = EnhancedRecommendationEngine()
-
-     # Get a real user for testing
-     real_user_selector = RealUserSelector()
-     test_users = real_user_selector.get_real_users(n=3, min_interactions=15)
-
-     for user in test_users:
-         print(f"\n📊 Testing User {user['user_id']} ({user['age']}yr {user['gender']}):")
-         print(f"   Interaction History: {len(user['interaction_history'])} items")
-
-         # Test different recommendation methods
-         methods = [
-             ("Original Hybrid", lambda: engine.recommend_items_hybrid(
-                 age=user['age'],
-                 gender=user['gender'],
-                 income=user['income'],
-                 interaction_history=user['interaction_history'][:20],
-                 k=10,
-                 collaborative_weight=0.7
-             )),
-             ("Enhanced Hybrid", lambda: engine.recommend_items_enhanced_hybrid(
-                 age=user['age'],
-                 gender=user['gender'],
-                 income=user['income'],
-                 interaction_history=user['interaction_history'][:20],
-                 k=10,
-                 collaborative_weight=0.7,
-                 category_boost=1.5
-             )),
-             ("Category Focused", lambda: engine.recommend_items_category_focused(
-                 age=user['age'],
-                 gender=user['gender'],
-                 income=user['income'],
-                 interaction_history=user['interaction_history'][:20],
-                 k=10,
-                 focus_percentage=0.8
-             ))
-         ]
-
-         for method_name, method_func in methods:
-             try:
-                 recs = method_func()
-                 explanation = engine.get_recommendation_explanation(
-                     recs, user['interaction_history'][:20]
-                 )
-
-                 print(f"\n   🎯 {method_name}:")
-                 print(f"      Categories: {explanation.get('category_breakdown', {})}")
-                 print(f"      Alignment: {explanation.get('alignment_percentage', 'N/A')}%")
-
-             except Exception as e:
-                 print(f"   ❌ Error: {str(e)[:40]}...")
-
-     print(f"\n✅ Enhanced recommendation engine demo completed!")
-
-
- if __name__ == "__main__":
-     demo_enhanced_recommendations()
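Although this file is deleted by the commit, the category-boost re-ranking it implemented survives elsewhere (the UI still offers a "Category Boosted" mode). The core idea can be sketched standalone: score each candidate higher in proportion to the user's historical share of that candidate's category, then re-sort. The tuple shape mirrors the removed module; the data below is illustrative only.

```python
from collections import Counter

def category_preferences(history_categories):
    # Fraction of the user's history falling in each category.
    counts = Counter(history_categories)
    total = sum(counts.values())
    return {c: n / total for c, n in counts.items()} if total else {}

def boost(recs, prefs, boost_factor=1.5):
    # recs: list of (item_id, score, {"category_code": ...}) tuples.
    boosted = [
        (item_id, score * (1 + boost_factor * prefs.get(info["category_code"], 0.0)), info)
        for item_id, score, info in recs
    ]
    return sorted(boosted, key=lambda r: r[1], reverse=True)

prefs = category_preferences(["electronics", "electronics", "apparel"])
recs = [(1, 0.80, {"category_code": "apparel"}),
        (2, 0.75, {"category_code": "electronics"})]
ranked = boost(recs, prefs)
# Item 2 overtakes item 1 because 2/3 of the history is electronics.
print([item_id for item_id, _, _ in ranked])
```

Note that the boost is multiplicative, so a candidate with zero base score stays at zero regardless of category match.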
src/inference/enhanced_recommendation_engine_128d.py DELETED
@@ -1,499 +0,0 @@
- #!/usr/bin/env python3
- """
- Enhanced recommendation engine using 128D embeddings with diversity regularization.
- """
-
- import numpy as np
- import pandas as pd
- import tensorflow as tf
- import pickle
- import os
- from typing import Dict, List, Tuple, Optional
- from collections import Counter, defaultdict
-
- import sys
- sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
-
- from src.models.enhanced_two_tower import EnhancedItemTower, EnhancedUserTower
- from src.inference.faiss_index import FAISSItemIndex
- from src.preprocessing.data_loader import DataProcessor
- from src.preprocessing.user_data_preparation import prepare_user_features
- from src.utils.real_user_selector import RealUserSelector
-
-
- class Enhanced128DRecommendationEngine:
-     """Enhanced recommendation engine with 128D embeddings and all improvements."""
-
-     def __init__(self, artifacts_path: str = "src/artifacts/"):
-         self.artifacts_path = artifacts_path
-         self.embedding_dim = 128  # Fixed to 128D
-
-         # Model components
-         self.item_tower = None
-         self.user_tower = None
-         self.rating_model = None
-         self.faiss_index = None
-         self.data_processor = None
-
-         # Data
-         self.items_df = None
-         self.users_df = None
-         self.income_thresholds = None
-
-         # Load all components
-         self._load_all_components()
-
-     def _load_all_components(self):
-         """Load all enhanced model components."""
-
-         print("Loading enhanced 128D recommendation engine...")
-
-         # Load data processor
-         self.data_processor = DataProcessor()
-         try:
-             self.data_processor.load_vocabularies(f"{self.artifacts_path}/vocabularies.pkl")
-         except FileNotFoundError:
-             print("❌ Vocabularies not found. Please train the model first.")
-             return
-
-         # Load datasets
-         self.items_df = pd.read_csv("datasets/items.csv")
-         self.users_df = pd.read_csv("datasets/users.csv")
-
-         # Load enhanced model components
-         self._load_enhanced_models()
-
-         # Load FAISS index with 128D
-         try:
-             self.faiss_index = FAISSItemIndex(embedding_dim=self.embedding_dim)
-             # Try to load enhanced embeddings first
-             if os.path.exists(f"{self.artifacts_path}/enhanced_item_embeddings.npy"):
-                 enhanced_embeddings = np.load(
-                     f"{self.artifacts_path}/enhanced_item_embeddings.npy",
-                     allow_pickle=True
-                 ).item()
-                 self.faiss_index.build_index(enhanced_embeddings)
-                 print("✅ Loaded enhanced 128D FAISS index")
-             else:
-                 print("⚠️ Enhanced embeddings not found. Train enhanced model first.")
-                 self.faiss_index = None
-         except Exception as e:
-             print(f"⚠️ Could not load FAISS index: {e}")
-             self.faiss_index = None
-
-         # Load income thresholds for categorical demographics
-         self._load_income_thresholds()
-
-         print("✅ Enhanced 128D engine loaded successfully!")
-
-     def _load_enhanced_models(self):
-         """Load enhanced model components."""
-
-         try:
-             # Create model architecture
-             self.item_tower = EnhancedItemTower(
-                 item_vocab_size=len(self.data_processor.item_vocab),
-                 category_vocab_size=len(self.data_processor.category_vocab),
-                 brand_vocab_size=len(self.data_processor.brand_vocab),
-                 embedding_dim=self.embedding_dim,
-                 use_bias=True,
-                 use_diversity_reg=False  # Disable during inference
-             )
-
-             self.user_tower = EnhancedUserTower(
-                 max_history_length=50,
-                 embedding_dim=self.embedding_dim,
-                 use_bias=True,
-                 use_diversity_reg=False  # Disable during inference
-             )
-
-             # Create rating model
-             self.rating_model = tf.keras.Sequential([
-                 tf.keras.layers.Dense(512, activation="relu"),
-                 tf.keras.layers.BatchNormalization(),
-                 tf.keras.layers.Dropout(0.3),
-                 tf.keras.layers.Dense(256, activation="relu"),
-                 tf.keras.layers.BatchNormalization(),
-                 tf.keras.layers.Dropout(0.2),
-                 tf.keras.layers.Dense(64, activation="relu"),
-                 tf.keras.layers.Dense(1, activation="sigmoid")
-             ])
-
-             # Load weights - try enhanced first, fall back to regular
-             model_files = [
-                 ('enhanced_item_tower_weights_enhanced_best', 'enhanced_user_tower_weights_enhanced_best', 'enhanced_rating_model_weights_enhanced_best'),
-                 ('enhanced_item_tower_weights_enhanced_final', 'enhanced_user_tower_weights_enhanced_final', 'enhanced_rating_model_weights_enhanced_final'),
-             ]
-
-             loaded = False
-             for item_file, user_file, rating_file in model_files:
-                 try:
-                     # Need to build models first with dummy data
-                     self._build_models()
-
-                     self.item_tower.load_weights(f"{self.artifacts_path}/{item_file}")
-                     self.user_tower.load_weights(f"{self.artifacts_path}/{user_file}")
-                     self.rating_model.load_weights(f"{self.artifacts_path}/{rating_file}")
-
-                     print(f"✅ Loaded enhanced model: {item_file}")
-                     loaded = True
-                     break
-                 except Exception as e:
-                     print(f"⚠️ Could not load {item_file}: {e}")
-                     continue
-
-             if not loaded:
-                 print("❌ No enhanced model weights found. Please train enhanced model first.")
-                 self.item_tower = None
-                 self.user_tower = None
-                 self.rating_model = None
-
-         except Exception as e:
-             print(f"❌ Failed to load enhanced models: {e}")
-             self.item_tower = None
-             self.user_tower = None
-             self.rating_model = None
-
-     def _build_models(self):
-         """Build models with dummy data to initialize weights."""
-
-         # Dummy item features
-         dummy_item_features = {
-             'product_id': tf.constant([0]),
-             'category_id': tf.constant([0]),
-             'brand_id': tf.constant([0]),
-             'price': tf.constant([100.0])
-         }
-
-         # Dummy user features
-         dummy_user_features = {
-             'age': tf.constant([2]),  # Adult category
-             'gender': tf.constant([0]),  # Female
-             'income': tf.constant([2]),  # Middle income
-             'item_history_embeddings': tf.constant(np.zeros((1, 50, self.embedding_dim), dtype=np.float32))
-         }
-
-         # Forward pass to build models
-         _ = self.item_tower(dummy_item_features, training=False)
-         _ = self.user_tower(dummy_user_features, training=False)
-
-         # Build rating model
-         dummy_concat = tf.constant(np.zeros((1, self.embedding_dim * 2), dtype=np.float32))
-         _ = self.rating_model(dummy_concat, training=False)
-
-     def _load_income_thresholds(self):
185
- """Load income thresholds for categorical processing."""
186
-
187
- # Calculate income thresholds from training data
188
- user_incomes = self.users_df['income'].values
189
- self.income_thresholds = np.percentile(user_incomes, [0, 20, 40, 60, 80, 100])
190
- print(f"Income thresholds: {self.income_thresholds}")
191
-
192
- def categorize_age(self, age: float) -> int:
193
- """Categorize age into 6 groups."""
194
- if age < 18: return 0 # Teen
195
- elif age < 26: return 1 # Young Adult
196
- elif age < 36: return 2 # Adult
197
- elif age < 51: return 3 # Middle Age
198
- elif age < 66: return 4 # Mature
199
- else: return 5 # Senior
200
-
201
- def categorize_income(self, income: float) -> int:
202
- """Categorize income into 5 percentile groups."""
203
- category = np.digitize([income], self.income_thresholds[1:-1])[0]
204
- return min(max(category, 0), 4)
205
-
206
- def categorize_gender(self, gender: str) -> int:
207
- """Categorize gender."""
208
- return 1 if gender.lower() == 'male' else 0
209
-
210
- def get_user_embedding(self,
211
- age: int,
212
- gender: str,
213
- income: float,
214
- interaction_history: List[int] = None) -> np.ndarray:
215
- """Generate user embedding with categorical demographics."""
216
-
217
- if self.user_tower is None:
218
- print("❌ User tower not loaded")
219
- return None
220
-
221
- # Categorize demographics
222
- age_cat = self.categorize_age(age)
223
- gender_cat = self.categorize_gender(gender)
224
- income_cat = self.categorize_income(income)
225
-
226
- # Prepare interaction history embeddings
227
- if interaction_history is None:
228
- interaction_history = []
229
-
230
- # Get item embeddings for history
231
- history_embeddings = np.zeros((50, self.embedding_dim), dtype=np.float32)
232
-
233
- for i, item_id in enumerate(interaction_history[:50]):
234
- if self.faiss_index and item_id in self.faiss_index.item_id_to_idx:
235
- item_emb = self.faiss_index.get_item_embedding(item_id)
236
- if item_emb is not None:
237
- history_embeddings[i] = item_emb
238
-
239
- # Create user features
240
- user_features = {
241
- 'age': tf.constant([age_cat]),
242
- 'gender': tf.constant([gender_cat]),
243
- 'income': tf.constant([income_cat]),
244
- 'item_history_embeddings': tf.constant([history_embeddings])
245
- }
246
-
247
- # Get embedding
248
- user_output = self.user_tower(user_features, training=False)
249
- if isinstance(user_output, tuple):
250
- user_embedding = user_output[0].numpy()[0]
251
- else:
252
- user_embedding = user_output.numpy()[0]
253
-
254
- return user_embedding
255
-
256
- def get_item_embedding(self, item_id: int) -> Optional[np.ndarray]:
257
- """Get item embedding."""
258
-
259
- if self.faiss_index:
260
- return self.faiss_index.get_item_embedding(item_id)
261
-
262
- # Fallback to model computation
263
- if self.item_tower is None:
264
- return None
265
-
266
- item_row = self.items_df[self.items_df['product_id'] == item_id]
267
- if item_row.empty:
268
- return None
269
-
270
- item_data = item_row.iloc[0]
271
-
272
- # Prepare features
273
- item_features = {
274
- 'product_id': tf.constant([self.data_processor.item_vocab.get(item_id, 0)]),
275
- 'category_id': tf.constant([self.data_processor.category_vocab.get(item_data['category_id'], 0)]),
276
- 'brand_id': tf.constant([self.data_processor.brand_vocab.get(item_data.get('brand', 'unknown'), 0)]),
277
- 'price': tf.constant([float(item_data.get('price', 0.0))])
278
- }
279
-
280
- # Get embedding
281
- item_output = self.item_tower(item_features, training=False)
282
- if isinstance(item_output, tuple):
283
- item_embedding = item_output[0].numpy()[0]
284
- else:
285
- item_embedding = item_output.numpy()[0]
286
-
287
- return item_embedding
288
-
289
- def recommend_items_enhanced(self,
290
- age: int,
291
- gender: str,
292
- income: float,
293
- interaction_history: List[int] = None,
294
- k: int = 10,
295
- diversity_weight: float = 0.3,
296
- category_boost: float = 1.5) -> List[Tuple[int, float, Dict]]:
297
- """Generate enhanced recommendations with diversity and category boosting."""
298
-
299
- if not self.faiss_index:
300
- print("❌ FAISS index not available")
301
- return []
302
-
303
- # Get user embedding
304
- user_embedding = self.get_user_embedding(age, gender, income, interaction_history)
305
- if user_embedding is None:
306
- return []
307
-
308
- # Get candidate recommendations (more than needed for filtering)
309
- candidates = self.faiss_index.search_by_embedding(user_embedding, k * 3)
310
-
311
- # Filter out items from interaction history
312
- if interaction_history:
313
- history_set = set(interaction_history)
314
- candidates = [(item_id, score) for item_id, score in candidates
315
- if item_id not in history_set]
316
-
317
- # Add item metadata and apply enhancements
318
- enhanced_candidates = []
319
-
320
- for item_id, similarity_score in candidates[:k * 2]:
321
- # Get item info
322
- item_row = self.items_df[self.items_df['product_id'] == item_id]
323
- if item_row.empty:
324
- continue
325
-
326
- item_info = item_row.iloc[0].to_dict()
327
-
328
- # Enhanced scoring with multiple factors
329
- final_score = similarity_score
330
-
331
- # Category boosting based on user history
332
- if interaction_history and category_boost > 1.0:
333
- user_categories = self._get_user_categories(interaction_history)
334
- item_category = item_info.get('category_code', '')
335
-
336
- if item_category in user_categories:
337
- category_preference = user_categories[item_category]
338
- final_score *= (1 + (category_boost - 1) * category_preference)
339
-
340
- enhanced_candidates.append((item_id, final_score, item_info))
341
-
342
- # Sort by enhanced scores
343
- enhanced_candidates.sort(key=lambda x: x[1], reverse=True)
344
-
345
- # Apply diversity filtering
346
- if diversity_weight > 0:
347
- diversified_candidates = self._apply_diversity_filter(
348
- enhanced_candidates, diversity_weight
349
- )
350
- else:
351
- diversified_candidates = enhanced_candidates
352
-
353
- return diversified_candidates[:k]
354
-
355
- def _get_user_categories(self, interaction_history: List[int]) -> Dict[str, float]:
356
- """Get user's category preferences from history."""
357
-
358
- category_counts = Counter()
359
-
360
- for item_id in interaction_history:
361
- item_row = self.items_df[self.items_df['product_id'] == item_id]
362
- if not item_row.empty:
363
- category = item_row.iloc[0].get('category_code', 'Unknown')
364
- category_counts[category] += 1
365
-
366
- # Convert to preferences (percentages)
367
- total = sum(category_counts.values())
368
- if total == 0:
369
- return {}
370
-
371
- return {cat: count / total for cat, count in category_counts.items()}
372
-
373
- def _apply_diversity_filter(self,
374
- candidates: List[Tuple[int, float, Dict]],
375
- diversity_weight: float,
376
- max_per_category: int = 3) -> List[Tuple[int, float, Dict]]:
377
- """Apply diversity filtering to recommendations."""
378
-
379
- category_counts = defaultdict(int)
380
- diversified = []
381
-
382
- for item_id, score, item_info in candidates:
383
- category = item_info.get('category_code', 'Unknown')
384
-
385
- # Apply diversity penalty
386
- if category_counts[category] >= max_per_category:
387
- # Penalty for over-representation
388
- diversity_penalty = diversity_weight * (category_counts[category] - max_per_category + 1)
389
- adjusted_score = score * (1 - diversity_penalty)
390
- else:
391
- adjusted_score = score
392
-
393
- diversified.append((item_id, adjusted_score, item_info))
394
- category_counts[category] += 1
395
-
396
- # Re-sort by adjusted scores
397
- diversified.sort(key=lambda x: x[1], reverse=True)
398
- return diversified
399
-
400
- def predict_rating(self,
401
- age: int,
402
- gender: str,
403
- income: float,
404
- item_id: int,
405
- interaction_history: List[int] = None) -> float:
406
- """Predict rating for user-item pair."""
407
-
408
- if self.rating_model is None:
409
- return 0.5 # Default rating
410
-
411
- # Get embeddings
412
- user_embedding = self.get_user_embedding(age, gender, income, interaction_history)
413
- item_embedding = self.get_item_embedding(item_id)
414
-
415
- if user_embedding is None or item_embedding is None:
416
- return 0.5
417
-
418
- # Concatenate embeddings
419
- combined = np.concatenate([user_embedding, item_embedding])
420
- combined = tf.constant([combined])
421
-
422
- # Predict rating
423
- rating = self.rating_model(combined, training=False)
424
- return float(rating.numpy()[0][0])
425
-
426
-
427
- def demo_enhanced_engine():
428
- """Demo the enhanced 128D recommendation engine."""
429
-
430
- print("πŸš€ ENHANCED 128D RECOMMENDATION ENGINE DEMO")
431
- print("="*70)
432
-
433
- try:
434
- # Initialize engine
435
- engine = Enhanced128DRecommendationEngine()
436
-
437
- if engine.item_tower is None:
438
- print("❌ Enhanced model not available. Please train first using:")
439
- print(" python train_enhanced_model.py")
440
- return
441
-
442
- # Get real user for testing
443
- real_user_selector = RealUserSelector()
444
- test_users = real_user_selector.get_real_users(n=2, min_interactions=10)
445
-
446
- for user in test_users:
447
- print(f"\nπŸ“Š Testing User {user['user_id']} ({user['age']}yr {user['gender']}):")
448
- print(f" Income: ${user['income']:,}")
449
- print(f" History: {len(user['interaction_history'])} items")
450
-
451
- # Test enhanced recommendations
452
- try:
453
- recs = engine.recommend_items_enhanced(
454
- age=user['age'],
455
- gender=user['gender'],
456
- income=user['income'],
457
- interaction_history=user['interaction_history'][:20],
458
- k=10,
459
- diversity_weight=0.3,
460
- category_boost=1.5
461
- )
462
-
463
- print(f" 🎯 Enhanced Recommendations:")
464
- categories = []
465
- for i, (item_id, score, item_info) in enumerate(recs[:5]):
466
- category = item_info.get('category_code', 'Unknown')[:30]
467
- price = item_info.get('price', 0)
468
- categories.append(category)
469
- print(f" #{i+1} Item {item_id}: {score:.4f} | ${price:.2f} | {category}")
470
-
471
- # Analyze diversity
472
- unique_categories = len(set(categories))
473
- print(f" πŸ“ˆ Diversity: {unique_categories}/{len(categories)} unique categories")
474
-
475
- # Test rating prediction
476
- if recs:
477
- test_item = recs[0][0]
478
- predicted_rating = engine.predict_rating(
479
- age=user['age'],
480
- gender=user['gender'],
481
- income=user['income'],
482
- item_id=test_item,
483
- interaction_history=user['interaction_history'][:20]
484
- )
485
- print(f" ⭐ Rating prediction for item {test_item}: {predicted_rating:.3f}")
486
-
487
- except Exception as e:
488
- print(f" ❌ Error: {e}")
489
-
490
- print(f"\nβœ… Enhanced 128D engine demo completed!")
491
-
492
- except Exception as e:
493
- print(f"❌ Demo failed: {e}")
494
- import traceback
495
- traceback.print_exc()
496
-
497
-
498
- if __name__ == "__main__":
499
- demo_enhanced_engine()
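The percentile bucketing used by `_load_income_thresholds` and `categorize_income` above can be reproduced in isolation. This is a minimal sketch: the `incomes` array is hypothetical stand-in data, not the project's `users.csv`.

```python
import numpy as np

# Hypothetical stand-in for users_df['income'] (the real engine reads users.csv).
incomes = np.array([18000, 25000, 32000, 41000, 55000, 68000, 80000, 95000, 120000, 250000])

# Quintile boundaries, as in _load_income_thresholds.
thresholds = np.percentile(incomes, [0, 20, 40, 60, 80, 100])

def categorize_income(income: float) -> int:
    """Map an income to one of 5 percentile buckets (0-4), mirroring categorize_income."""
    # digitize against the 4 interior boundaries; clamp to the valid bucket range
    category = np.digitize([income], thresholds[1:-1])[0]
    return int(min(max(category, 0), 4))

print(categorize_income(20000), categorize_income(300000))  # β†’ 0 4
```

Dropping the outer `[0, 100]` percentiles before `np.digitize` leaves exactly 4 interior boundaries, so the returned bucket is always in `0..4` even for incomes outside the training range.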
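The over-representation penalty in `_apply_diversity_filter` above can be sketched standalone. The candidate tuples here are toy data (hypothetical item IDs, scores, and categories), not real index output.

```python
from collections import defaultdict

# Toy candidates: (item_id, score, category), already sorted by score.
candidates = [
    (1, 0.95, "electronics"), (2, 0.94, "electronics"), (3, 0.93, "electronics"),
    (4, 0.92, "electronics"), (5, 0.80, "kitchen"),
]
diversity_weight, max_per_category = 0.3, 3

category_counts = defaultdict(int)
rescored = []
for item_id, score, category in candidates:
    if category_counts[category] >= max_per_category:
        # Each extra item from an over-represented category is penalized harder.
        penalty = diversity_weight * (category_counts[category] - max_per_category + 1)
        score = score * (1 - penalty)
    rescored.append((item_id, score, category))
    category_counts[category] += 1

rescored.sort(key=lambda x: x[1], reverse=True)
print([item_id for item_id, _, _ in rescored])  # β†’ [1, 2, 3, 5, 4]
```

The fourth electronics item drops from 0.92 to 0.92 Γ— (1 βˆ’ 0.3) = 0.644, letting the kitchen item overtake it after the re-sort.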
src/inference/recommendation_engine.py CHANGED
@@ -61,6 +61,50 @@ class RecommendationEngine:
         category = np.digitize([income], self.income_thresholds[1:-1])[0]
         return min(max(category, 0), 4)
 
+    def categorize_profession(self, profession: str) -> int:
+        """Categorize profession into numeric categories."""
+        profession_map = {
+            "Technology": 0,
+            "Healthcare": 1,
+            "Education": 2,
+            "Finance": 3,
+            "Retail": 4,
+            "Manufacturing": 5,
+            "Services": 6,
+            "Other": 7
+        }
+        return profession_map.get(profession, 7)  # Default to "Other"
+
+    def categorize_location(self, location: str) -> int:
+        """Categorize location into numeric categories."""
+        location_map = {
+            "Urban": 0,
+            "Suburban": 1,
+            "Rural": 2
+        }
+        return location_map.get(location, 0)  # Default to "Urban"
+
+    def categorize_education_level(self, education: str) -> int:
+        """Categorize education level into numeric categories."""
+        education_map = {
+            "High School": 0,
+            "Some College": 1,
+            "Bachelor's": 2,
+            "Master's": 3,
+            "PhD+": 4
+        }
+        return education_map.get(education, 0)  # Default to "High School"
+
+    def categorize_marital_status(self, marital_status: str) -> int:
+        """Categorize marital status into numeric categories."""
+        marital_map = {
+            "Single": 0,
+            "Married": 1,
+            "Divorced": 2,
+            "Widowed": 3
+        }
+        return marital_map.get(marital_status, 0)  # Default to "Single"
+
     def _load_all_components(self):
         """Load all required components for inference."""
 
@@ -139,6 +183,10 @@
             'age': tf.constant([2]),  # Adult category (26-35)
             'gender': tf.constant([1]),  # Male
             'income': tf.constant([2]),  # Middle income category
+            'profession': tf.constant([0]),  # Technology
+            'location': tf.constant([0]),  # Urban
+            'education_level': tf.constant([2]),  # Bachelor's
+            'marital_status': tf.constant([1]),  # Married
             'item_history_embeddings': tf.constant([[[0.0] * 128] * 50])  # Changed from 64 to 128
         }
         _ = self.user_tower(dummy_input)
@@ -195,6 +243,10 @@
                               age: int,
                               gender: str,
                               income: float,
+                              profession: str = "Other",
+                              location: str = "Urban",
+                              education_level: str = "High School",
+                              marital_status: str = "Single",
                               interaction_history: List[int] = None) -> Dict[str, tf.Tensor]:
         """Prepare user features for inference."""
 
@@ -204,9 +256,13 @@
         # Convert gender
         gender_numeric = 1 if gender.lower() == 'male' else 0
 
-        # Categorize age and income
+        # Categorize all demographics
         age_category = self.categorize_age(age)
         income_category = self.categorize_income(income)
+        profession_category = self.categorize_profession(profession)
+        location_category = self.categorize_location(location)
+        education_category = self.categorize_education_level(education_level)
+        marital_category = self.categorize_marital_status(marital_status)
 
         # Get item embeddings for history
         history_embeddings = []
@@ -235,6 +291,10 @@
             'age': tf.constant([age_category]),  # Categorical age (0-5)
             'gender': tf.constant([gender_numeric]),  # Categorical gender (0-1)
             'income': tf.constant([income_category]),  # Categorical income (0-4)
+            'profession': tf.constant([profession_category]),  # Categorical profession (0-7)
+            'location': tf.constant([location_category]),  # Categorical location (0-2)
+            'education_level': tf.constant([education_category]),  # Categorical education (0-4)
+            'marital_status': tf.constant([marital_category]),  # Categorical marital status (0-3)
             'item_history_embeddings': tf.constant([history_embeddings])
         }
 
@@ -275,14 +335,77 @@
                            age: int,
                            gender: str,
                            income: float,
+                           profession: str = "Other",
+                           location: str = "Urban",
+                           education_level: str = "High School",
+                           marital_status: str = "Single",
                            interaction_history: List[int] = None) -> np.ndarray:
         """Get user embedding from user tower."""
 
-        user_features = self.prepare_user_features(age, gender, income, interaction_history)
+        user_features = self.prepare_user_features(age, gender, income, profession, location, education_level, marital_status, interaction_history)
         user_embedding = self.user_tower(user_features, training=False)
 
         return user_embedding.numpy()[0]
 
+    def get_user_embedding_enhanced(self,
+                                    age: int,
+                                    gender: str,
+                                    income: float,
+                                    profession: str = "Other",
+                                    location: str = "Urban",
+                                    education_level: str = "High School",
+                                    marital_status: str = "Single",
+                                    interaction_history: List[int] = None) -> np.ndarray:
+        """Enhanced user embedding that handles zero interactions better."""
+
+        # Get base embedding
+        base_embedding = self.get_user_embedding(
+            age, gender, income, profession, location, education_level, marital_status, interaction_history
+        )
+
+        # Check if this is a zero-interaction user
+        has_interactions = interaction_history and len(interaction_history) > 0
+
+        if not has_interactions:
+            # For zero interactions, amplify the demographic component
+            # This is a heuristic fix until we retrain the model
+
+            # Create demographic-enhanced embedding
+            demographic_mask = np.ones_like(base_embedding)
+
+            # Amplify first 50% of dimensions (likely demographic-influenced)
+            mid_point = len(base_embedding) // 2
+            demographic_mask[:mid_point] *= 3.0  # Strong amplification
+
+            # Reduce influence of latter dimensions (likely history-influenced)
+            demographic_mask[mid_point:] *= 0.2  # Strong reduction
+
+            enhanced_embedding = base_embedding * demographic_mask
+
+            # Add demographic-specific variation to differentiate profiles
+            demographic_hash = (
+                age * 1000 +
+                (1 if gender.lower() == 'male' else 0) * 100 +
+                int(income / 10000) * 10 +
+                self.categorize_profession(profession) * 7 +
+                self.categorize_location(location) * 3 +
+                self.categorize_education_level(education_level) * 5 +
+                self.categorize_marital_status(marital_status) * 2
+            )
+
+            np.random.seed(demographic_hash % 2**32)  # Reproducible noise
+            demographic_noise = np.random.normal(0, 0.02, base_embedding.shape)  # Increased noise
+            enhanced_embedding += demographic_noise
+
+            # Renormalize
+            enhanced_embedding = enhanced_embedding / np.linalg.norm(enhanced_embedding)
+
+            print(f"Enhanced embedding for zero interactions: age={age}, gender={gender}, profession={profession}")
+
+            return enhanced_embedding.astype(np.float32)
+
+        return base_embedding
+
     def get_item_embedding(self, item_id: int) -> Optional[np.ndarray]:
         """Get item embedding from FAISS index or item tower."""
 
@@ -301,14 +424,18 @@
                                       age: int,
                                       gender: str,
                                       income: float,
+                                      profession: str = "Other",
+                                      location: str = "Urban",
+                                      education_level: str = "High School",
+                                      marital_status: str = "Single",
                                       interaction_history: List[int] = None,
                                       k: int = 10,
                                       exclude_history: bool = True,
                                       category_boost: float = 1.3) -> List[Tuple[int, float, Dict]]:
         """Generate recommendations using collaborative filtering with category awareness."""
 
-        # Get user embedding
-        user_embedding = self.get_user_embedding(age, gender, income, interaction_history)
+        # Get enhanced user embedding (better for zero interactions)
+        user_embedding = self.get_user_embedding_enhanced(age, gender, income, profession, location, education_level, marital_status, interaction_history)
 
         # Find similar items using FAISS (get more candidates for boosting)
         similar_items = self.faiss_index.search_by_embedding(user_embedding, k * 4)
@@ -354,6 +481,153 @@
 
         return recommendations
 
+    def _aggregate_user_history_embedding(self,
+                                          interaction_history: List[int],
+                                          aggregation_method: str = "weighted_mean") -> Optional[np.ndarray]:
+        """Aggregate user's interaction history into a single embedding vector."""
+
+        if not interaction_history:
+            return None
+
+        # Get embeddings for items in history
+        item_embeddings = []
+        valid_items = []
+
+        for item_id in interaction_history:
+            embedding = self.faiss_index.get_item_embedding(item_id)
+            if embedding is not None:
+                item_embeddings.append(embedding)
+                valid_items.append(item_id)
+
+        if not item_embeddings:
+            print(f"No valid embeddings found for interaction history: {interaction_history}")
+            return None
+
+        item_embeddings = np.array(item_embeddings)
+        print(f"Aggregating {len(item_embeddings)} item embeddings using {aggregation_method}")
+
+        # Apply aggregation method
+        if aggregation_method == "mean":
+            # Simple mean pooling
+            aggregated = np.mean(item_embeddings, axis=0)
+
+        elif aggregation_method == "weighted_mean":
+            # Weight recent interactions higher (exponential decay)
+            weights = np.exp(np.linspace(-1, 0, len(item_embeddings)))  # More recent = higher weight
+            weights = weights / np.sum(weights)  # Normalize weights
+            aggregated = np.average(item_embeddings, axis=0, weights=weights)
+            print(f"Applied weighted mean with weights: {weights[-3:]} (showing last 3)")
+
+        elif aggregation_method == "max":
+            # Element-wise maximum pooling
+            aggregated = np.max(item_embeddings, axis=0)
+
+        else:
+            raise ValueError(f"Unknown aggregation method: {aggregation_method}")
+
+        # L2 normalize the aggregated embedding
+        aggregated = aggregated / np.linalg.norm(aggregated)
+
+        return aggregated.astype('float32')
+
+    def recommend_items_content_based_from_history(self,
+                                                   interaction_history: List[int],
+                                                   k: int = 10,
+                                                   aggregation_method: str = "weighted_mean",
+                                                   same_category_ratio: float = None) -> List[Tuple[int, float, Dict]]:
+        """Generate recommendations using content-based filtering from aggregated user history."""
+
+        # Aggregate user's interaction history
+        aggregated_embedding = self._aggregate_user_history_embedding(
+            interaction_history, aggregation_method
+        )
+
+        if aggregated_embedding is None:
+            print("Could not create aggregated embedding from interaction history")
+            return []
+
+        if same_category_ratio is None:
+            # Direct ANN search with aggregated embedding
+            similar_items = self.faiss_index.search_by_embedding(aggregated_embedding, k)
+            recommendations = []
+
+            # Filter out items already in interaction history
+            interaction_set = set(interaction_history)
+
+            for item_id, score in similar_items:
+                if item_id not in interaction_set:  # Exclude already interacted items
+                    item_info = self._get_item_info(item_id)
+                    recommendations.append((item_id, score, item_info))
+
+                if len(recommendations) >= k:
+                    break
+
+            print(f"Found {len(recommendations)} content-based recommendations from aggregated history")
+            return recommendations
+
+        else:
+            # Category-aware approach with aggregated embedding
+            print(f"Finding similar items with {same_category_ratio*100}% category constraint from aggregated history")
+
+            # Analyze user's category preferences from interaction history
+            user_categories = {}
+            total_interactions = len(interaction_history)
+
+            for item_id in interaction_history:
+                item_info = self._get_item_info(item_id)
+                category = item_info.get('category_code', '')
+                if category:
+                    user_categories[category] = user_categories.get(category, 0) + 1
+
+            # Convert to percentages
+            for category in user_categories:
+                user_categories[category] = user_categories[category] / total_interactions
+
+            print(f"User category preferences: {user_categories}")
+
+            # Get more candidates for category filtering
+            candidate_items = self.faiss_index.search_by_embedding(aggregated_embedding, k * 3)
+            interaction_set = set(interaction_history)
+
+            # Separate by category alignment with user preferences
+            preferred_category_items = []
+            other_category_items = []
+
+            for item_id, score in candidate_items:
+                if item_id in interaction_set:
+                    continue  # Skip already interacted items
+
+                item_info = self._get_item_info(item_id)
+                item_category = item_info.get('category_code', '')
+
+                # Check if item category matches user's preferred categories
+                if item_category in user_categories:
+                    preferred_category_items.append((item_id, score, item_info))
+                else:
+                    other_category_items.append((item_id, score, item_info))
+
+            # Calculate target distribution
+            preferred_count = int(k * same_category_ratio)
+            other_count = k - preferred_count
+
+            print(f"Target: {preferred_count} from preferred categories, {other_count} for exploration")
+
+            # Build balanced recommendations
+            recommendations = []
+            recommendations.extend(preferred_category_items[:preferred_count])
+            recommendations.extend(other_category_items[:other_count])
+
+            # Fill remaining slots with best available items
+            if len(recommendations) < k:
     def recommend_items_content_based(self,
                                       seed_item_id: int,
                                       k: int = 10,
@@ -431,6 +705,10 @@
                                age: int,
                                gender: str,
                                income: float,
                                interaction_history: List[int] = None,
                                k: int = 10,
                                collaborative_weight: float = 0.7) -> List[Tuple[int, float, Dict]]:
@@ -438,15 +716,16 @@
 
         # Get collaborative recommendations
         collab_recs = self.recommend_items_collaborative(
-            age, gender, income, interaction_history, k * 2
         )
 
-        # Get content-based recommendations from recent interactions
         content_recs = []
         if interaction_history:
-            # Use most recent item as seed
-            recent_item = interaction_history[-1]
-            content_recs = self.recommend_items_content_based(recent_item, k)
 
         # Combine recommendations with weighted scores
         item_scores = {}
@@ -484,6 +763,254 @@
 
         return hybrid_recommendations[:k]
 
     def _get_item_info(self, item_id: int) -> Dict:
         """Get item metadata."""
 
@@ -512,6 +1039,10 @@
                        gender: str,
                        income: float,
                        item_id: int,
                        interaction_history: List[int] = None) -> float:
         """Predict rating for a specific user-item pair."""
 
@@ -519,7 +1050,7 @@
             return 0.5  # Default prediction
 
         # Prepare user features
-        user_features = self.prepare_user_features(age, gender, income, interaction_history)
 
         # Prepare item features
         if item_id not in self.data_processor.item_vocab:
@@ -552,6 +1083,10 @@ def main():
         'age': 32,
         'gender': 'male',
         'income': 75000,
         'interaction_history': [1000978, 1001588, 1001618]  # Sample item IDs
     }
 
@@ -559,20 +1094,28 @@
     print(f"Age: {demo_user['age']}")
     print(f"Gender: {demo_user['gender']}")
     print(f"Income: ${demo_user['income']:,}")
     print(f"Interaction history: {demo_user['interaction_history']}")
 
     # Generate collaborative recommendations
-    print("\n=== Collaborative Filtering Recommendations ===")
-    collab_recs = engine.recommend_items_collaborative(**demo_user, k=5)
 
     for i, (item_id, score, info) in enumerate(collab_recs, 1):
         print(f"{i}. Item {item_id}: {info['brand']} - ${info['price']:.2f} (Score: {score:.4f})")
 
-    # Generate content-based recommendations
-    print("\n=== Content-Based Recommendations (similar to recent item) ===")
    if demo_user['interaction_history']:
-        content_recs = engine.recommend_items_content_based(
-            seed_item_id=demo_user['interaction_history'][-1], k=5
        )
 
        for i, (item_id, score, info) in enumerate(content_recs, 1):
@@ -580,7 +1123,9 @@
 
     # Generate hybrid recommendations
     print("\n=== Hybrid Recommendations ===")
-    hybrid_recs = engine.recommend_items_hybrid(**demo_user, k=5)
 
     for i, (item_id, score, info) in enumerate(hybrid_recs, 1):
         print(f"{i}. Item {item_id}: {info['brand']} - ${info['price']:.2f} (Score: {score:.4f})")
622
+ remaining_items = (preferred_category_items[preferred_count:] +
623
+ other_category_items[other_count:])
624
+ remaining_items.sort(key=lambda x: x[1], reverse=True) # Sort by score
625
+ needed = k - len(recommendations)
626
+ recommendations.extend(remaining_items[:needed])
627
+
628
+ print(f"Final recommendations: {len(recommendations)} items")
629
+ return recommendations[:k]
630
+
631
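The `same_category_ratio` branch above boils down to a quota split over ranked candidates. A minimal sketch of just that split (the tuple layout and category set are illustrative, not the repo's types):

```python
from typing import List, Set, Tuple

Candidate = Tuple[int, float, str]  # (item_id, score, category)

def split_by_ratio(candidates: List[Candidate], preferred: Set[str],
                   k: int, ratio: float) -> List[Candidate]:
    """Take ~ratio*k items from preferred categories, the rest for exploration."""
    pref = [c for c in candidates if c[2] in preferred]
    other = [c for c in candidates if c[2] not in preferred]
    n_pref = int(k * ratio)
    picks = pref[:n_pref] + other[:k - n_pref]
    if len(picks) < k:
        # Backfill from whatever is left, best score first
        leftovers = pref[n_pref:] + other[k - n_pref:]
        leftovers.sort(key=lambda c: c[1], reverse=True)
        picks += leftovers[:k - len(picks)]
    return picks[:k]
```

Because candidates arrive pre-ranked from the ANN search, slicing each bucket preserves score order within the quota.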
     def recommend_items_content_based(self,
                                       seed_item_id: int,
                                       k: int = 10,
 
                                 age: int,
                                 gender: str,
                                 income: float,
+                                profession: str = "Other",
+                                location: str = "Urban",
+                                education_level: str = "High School",
+                                marital_status: str = "Single",
                                 interaction_history: List[int] = None,
                                 k: int = 10,
                                 collaborative_weight: float = 0.7) -> List[Tuple[int, float, Dict]]:
 
         # Get collaborative recommendations
         collab_recs = self.recommend_items_collaborative(
+            age, gender, income, profession, location, education_level, marital_status, interaction_history, k * 2
         )
 
+        # Get content-based recommendations from aggregated user history
         content_recs = []
         if interaction_history:
+            # Use aggregated history embedding instead of single recent item
+            content_recs = self.recommend_items_content_based_from_history(
+                interaction_history, k, aggregation_method="weighted_mean"
+            )
 
         # Combine recommendations with weighted scores
         item_scores = {}
 
         return hybrid_recommendations[:k]
 
+    def recommend_items_category_boosted(self,
+                                         age: int,
+                                         gender: str,
+                                         income: float,
+                                         profession: str = "Other",
+                                         location: str = "Urban",
+                                         education_level: str = "High School",
+                                         marital_status: str = "Single",
+                                         interaction_history: List[int] = None,
+                                         k: int = 10,
+                                         exclude_history: bool = True) -> List[Tuple[int, float, Dict]]:
+        """Generate category-boosted recommendations ensuring 50% from user's interacted categories."""
+
+        if not interaction_history or len(interaction_history) == 0:
+            # Fallback to collaborative filtering if no interaction history
+            return self.recommend_items_collaborative(
+                age, gender, income, profession, location, education_level, marital_status,
+                interaction_history, k, exclude_history
+            )
+
+        # Step 1: Calculate category percentages from interaction history
+        category_percentages = self._calculate_category_percentages(interaction_history)
+
+        if not category_percentages:
+            # Fallback if no categories found
+            return self.recommend_items_collaborative(
+                age, gender, income, profession, location, education_level, marital_status,
+                interaction_history, k, exclude_history
+            )
+
+        # Step 2: Get enhanced user embedding and do wide search (increased for better subcategory coverage)
+        user_embedding = self.get_user_embedding_enhanced(age, gender, income, profession, location, education_level, marital_status, interaction_history)
+        similar_items = self.faiss_index.search_by_embedding(user_embedding, k * 10)  # Increased from k*6 to k*10
+
+        # Step 3: Organize candidates by subcategory with parent fallback
+        category_candidates = {category: [] for category in category_percentages.keys()}
+        parent_category_mapping = {}  # Track parent categories for fallback
+        other_candidates = []
+        history_set = set(interaction_history) if exclude_history else set()
+
+        # Build parent category mapping for fallback
+        for subcategory in category_percentages.keys():
+            if '.' in subcategory:
+                parent = subcategory.split('.')[0]
+                if parent not in parent_category_mapping:
+                    parent_category_mapping[parent] = []
+                parent_category_mapping[parent].append(subcategory)
+
+        for item_id, score in similar_items:
+            if item_id in history_set:
+                continue
+
+            # Get item category
+            item_row = self.items_df[self.items_df['product_id'] == item_id]
+            if len(item_row) > 0:
+                full_item_category = item_row.iloc[0]['category_code']
+
+                # Extract 2-level subcategory for matching
+                if '.' in full_item_category:
+                    category_parts = full_item_category.split('.')
+                    if len(category_parts) >= 2:
+                        item_subcategory = f"{category_parts[0]}.{category_parts[1]}"
+                    else:
+                        item_subcategory = category_parts[0]
+                else:
+                    item_subcategory = full_item_category
+
+                # Try exact subcategory match first
+                if item_subcategory in category_percentages:
+                    category_candidates[item_subcategory].append((item_id, score))
+                else:
+                    # Fallback: try parent category match
+                    parent_category = item_subcategory.split('.')[0] if '.' in item_subcategory else item_subcategory
+                    matched = False
+
+                    if parent_category in parent_category_mapping:
+                        # Add to the first subcategory of this parent (round-robin could be improved later)
+                        target_subcategory = parent_category_mapping[parent_category][0]
+                        category_candidates[target_subcategory].append((item_id, score))
+                        matched = True
+
+                    if not matched:
+                        other_candidates.append((item_id, score))
+
+        # Step 4: Calculate target counts for each subcategory (50% distributed proportionally)
+        category_target_count = max(1, k // 2)  # At least 50% from user categories
+
+        # Calculate proportional distribution with proper rounding
+        category_counts = self._calculate_proportional_distribution(
+            category_percentages, category_target_count
+        )
+
+        # Step 5: Select items with round-robin filling and rebalancing
+        selected_recommendations = []
+
+        # Fill from user's categories with rebalancing for insufficient candidates
+        actual_selections = {}
+        unused_allocations = {}
+
+        for category, target_count in category_counts.items():
+            candidates = sorted(category_candidates[category], key=lambda x: x[1], reverse=True)
+            available_count = len(candidates)
+            selected_count = min(target_count, available_count)
+
+            print(f"[DEBUG] Category {category}: target={target_count}, available={available_count}, selected={selected_count}")
+
+            actual_selections[category] = selected_count
+            if selected_count < target_count:
+                unused_allocations[category] = target_count - selected_count
+
+            # Select items from this category
+            for i in range(selected_count):
+                item_id, score = candidates[i]
+                item_info = self._get_item_info(item_id)
+                selected_recommendations.append((item_id, score, item_info))
+
+        # Step 6: Redistribute unused allocations proportionally
+        total_unused = sum(unused_allocations.values())
+        if total_unused > 0:
+            print(f"[DEBUG] Redistributing {total_unused} unused slots")
+
+            # Find categories with remaining candidates for redistribution
+            categories_with_extras = {}
+            for category, candidates in category_candidates.items():
+                used_count = actual_selections.get(category, 0)
+                available_extras = len(candidates) - used_count
+                if available_extras > 0:
+                    categories_with_extras[category] = available_extras
+
+            # Redistribute based on original proportions and availability
+            redistributed = 0
+            for category in sorted(categories_with_extras.keys(), key=lambda c: category_percentages.get(c, 0), reverse=True):
+                if redistributed >= total_unused:
+                    break
+
+                extra_slots = min(unused_allocations.get(category, 0) + 1, categories_with_extras[category])
+                candidates = sorted(category_candidates[category], key=lambda x: x[1], reverse=True)
+                used_count = actual_selections.get(category, 0)
+
+                for i in range(used_count, min(used_count + extra_slots, len(candidates))):
+                    if redistributed >= total_unused:
+                        break
+                    item_id, score = candidates[i]
+                    item_info = self._get_item_info(item_id)
+                    selected_recommendations.append((item_id, score, item_info))
+                    redistributed += 1
+
+        # Step 7: Fill remaining slots with diverse recommendations
+        remaining_slots = k - len(selected_recommendations)
+        if remaining_slots > 0:
+            # Collect all unused candidates (both from user categories and other categories)
+            all_remaining = []
+
+            # Add unused items from user categories
+            for category, candidates in category_candidates.items():
+                used_count = len([rec for rec in selected_recommendations if rec[2].get('category_code', '').startswith(category.split('.')[0])])
+                sorted_candidates = sorted(candidates, key=lambda x: x[1], reverse=True)
+                for i in range(used_count, len(sorted_candidates)):
+                    all_remaining.append(sorted_candidates[i])
+
+            # Add items from other categories
+            all_remaining.extend(other_candidates)
+
+            # Sort by score and take best remaining
+            all_remaining.sort(key=lambda x: x[1], reverse=True)
+
+            print(f"[DEBUG] Filling {remaining_slots} remaining slots from {len(all_remaining)} candidates")
+
+            for i in range(min(remaining_slots, len(all_remaining))):
+                item_id, score = all_remaining[i]
+                item_info = self._get_item_info(item_id)
+                selected_recommendations.append((item_id, score, item_info))
+
+        # Step 8: Sort final recommendations by score and return top k
+        selected_recommendations.sort(key=lambda x: x[1], reverse=True)
+        return selected_recommendations[:k]
+
+    def _calculate_category_percentages(self, interaction_history: List[int]) -> Dict[str, float]:
+        """Calculate subcategory percentages from interaction history (2-level depth)."""
+        if not interaction_history:
+            return {}
+
+        category_counts = {}
+        total_interactions = 0
+
+        for item_id in interaction_history:
+            item_row = self.items_df[self.items_df['product_id'] == item_id]
+            if len(item_row) > 0:
+                full_category = item_row.iloc[0]['category_code']
+
+                # Use 2-level subcategory (e.g., "computers.components" from "computers.components.memory")
+                if '.' in full_category:
+                    category_parts = full_category.split('.')
+                    if len(category_parts) >= 2:
+                        subcategory = f"{category_parts[0]}.{category_parts[1]}"
+                    else:
+                        subcategory = category_parts[0]  # Fallback to top-level if only one part
+                else:
+                    subcategory = full_category
+
+                category_counts[subcategory] = category_counts.get(subcategory, 0) + 1
+                total_interactions += 1
+
+        # Convert to percentages
+        category_percentages = {}
+        for category, count in category_counts.items():
+            category_percentages[category] = (count / total_interactions) * 100
+
+        return category_percentages
+
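The truncation-and-count logic in `_calculate_category_percentages` can be shown compactly. This is a simplified sketch: the `catalog` dict stands in for the repo's `items_df` lookup, and the item IDs and category codes are made up for illustration.

```python
from collections import Counter
from typing import Dict, List

def category_percentages(history: List[int], catalog: Dict[int, str]) -> Dict[str, float]:
    """Share (in %) of each 2-level subcategory in a user's history.

    `catalog` maps item_id -> full category_code, e.g. "computers.components.memory".
    """
    counts = Counter()
    for item_id in history:
        parts = catalog[item_id].split(".")
        counts[".".join(parts[:2])] += 1  # truncate to at most 2 levels
    total = sum(counts.values())
    return {cat: 100.0 * n / total for cat, n in counts.items()}
```

Truncating "computers.components.memory" and "computers.components.cpu" to the same "computers.components" bucket is what lets sibling leaf categories pool their counts.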
+    def _calculate_proportional_distribution(self, category_percentages: Dict[str, float],
+                                             total_target: int) -> Dict[str, int]:
+        """Calculate proportional distribution with proper rounding and no minimum distortion."""
+        if not category_percentages or total_target <= 0:
+            return {}
+
+        # Calculate raw allocations (without minimum guarantee)
+        total_percentage = sum(category_percentages.values())
+        raw_allocations = {}
+        remainders = {}
+
+        for category, percentage in category_percentages.items():
+            if total_percentage > 0:
+                raw_allocation = (percentage / total_percentage) * total_target
+                raw_allocations[category] = int(raw_allocation)  # Floor
+                remainders[category] = raw_allocation - int(raw_allocation)  # Remainder
+            else:
+                raw_allocations[category] = 0
+                remainders[category] = 0
+
+        # Distribute remaining slots based on largest remainders
+        allocated_so_far = sum(raw_allocations.values())
+        remaining_slots = total_target - allocated_so_far
+
+        # Sort categories by remainder (largest first) to distribute remaining slots
+        sorted_by_remainder = sorted(remainders.items(), key=lambda x: x[1], reverse=True)
+
+        for i in range(remaining_slots):
+            if i < len(sorted_by_remainder):
+                category_to_increment = sorted_by_remainder[i][0]
+                raw_allocations[category_to_increment] += 1
+
+        # Filter out zero allocations (no artificial minimum guarantee)
+        final_allocations = {cat: count for cat, count in raw_allocations.items() if count > 0}
+
+        print(f"[DEBUG] Proportional distribution: target={total_target}, allocations={final_allocations}")
+        return final_allocations
+
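`_calculate_proportional_distribution` is an instance of the largest-remainder (Hamilton) apportionment method: floor each proportional share, then hand the leftover slots to the largest fractional remainders. A compact standalone version, sketched without the class plumbing:

```python
import math
from typing import Dict

def largest_remainder(shares: Dict[str, float], total: int) -> Dict[str, int]:
    """Allocate `total` integer slots proportionally to `shares` (any positive weights)."""
    weight = sum(shares.values())
    exact = {c: s / weight * total for c, s in shares.items()}
    floors = {c: math.floor(x) for c, x in exact.items()}
    # Rank categories by fractional remainder, largest first
    by_remainder = sorted(shares, key=lambda c: exact[c] - floors[c], reverse=True)
    for c in by_remainder[: total - sum(floors.values())]:
        floors[c] += 1  # give leftover slots to the biggest remainders
    return {c: n for c, n in floors.items() if n > 0}
```

The guarantee this buys is that the allocations always sum exactly to `total` and each category's count differs from its exact proportional share by less than one slot.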
     def _get_item_info(self, item_id: int) -> Dict:
         """Get item metadata."""
 
                        gender: str,
                        income: float,
                        item_id: int,
+                       profession: str = "Other",
+                       location: str = "Urban",
+                       education_level: str = "High School",
+                       marital_status: str = "Single",
                        interaction_history: List[int] = None) -> float:
         """Predict rating for a specific user-item pair."""
 
             return 0.5  # Default prediction
 
         # Prepare user features
+        user_features = self.prepare_user_features(age, gender, income, profession, location, education_level, marital_status, interaction_history)
 
         # Prepare item features
         if item_id not in self.data_processor.item_vocab:
 
         'age': 32,
         'gender': 'male',
         'income': 75000,
+        'profession': 'Technology',
+        'location': 'Urban',
+        'education_level': "Bachelor's",
+        'marital_status': 'Married',
         'interaction_history': [1000978, 1001588, 1001618]  # Sample item IDs
     }
 
     print(f"Age: {demo_user['age']}")
     print(f"Gender: {demo_user['gender']}")
     print(f"Income: ${demo_user['income']:,}")
+    print(f"Profession: {demo_user['profession']}")
+    print(f"Location: {demo_user['location']}")
+    print(f"Education: {demo_user['education_level']}")
+    print(f"Marital Status: {demo_user['marital_status']}")
     print(f"Interaction history: {demo_user['interaction_history']}")
 
     # Generate collaborative recommendations
+    print("\n=== Collaborative Filtering Recommendations ===")
+    # Extract demographics and history separately to avoid conflicts
+    demo_kwargs = {k: v for k, v in demo_user.items() if k != 'interaction_history'}
+    collab_recs = engine.recommend_items_collaborative(
+        **demo_kwargs, interaction_history=demo_user['interaction_history'], k=5
+    )
 
     for i, (item_id, score, info) in enumerate(collab_recs, 1):
         print(f"{i}. Item {item_id}: {info['brand']} - ${info['price']:.2f} (Score: {score:.4f})")
 
+    # Generate content-based recommendations from aggregated history
+    print("\n=== Content-Based Recommendations (from aggregated user history) ===")
     if demo_user['interaction_history']:
+        content_recs = engine.recommend_items_content_based_from_history(
+            interaction_history=demo_user['interaction_history'], k=5
         )
 
         for i, (item_id, score, info) in enumerate(content_recs, 1):
 
     # Generate hybrid recommendations
     print("\n=== Hybrid Recommendations ===")
+    hybrid_recs = engine.recommend_items_hybrid(
+        **demo_kwargs, interaction_history=demo_user['interaction_history'], k=5
+    )
 
     for i, (item_id, score, info) in enumerate(hybrid_recs, 1):
         print(f"{i}. Item {item_id}: {info['brand']} - ${info['price']:.2f} (Score: {score:.4f})")
src/models/enhanced_two_tower.py DELETED
@@ -1,574 +0,0 @@
- #!/usr/bin/env python3
- """
- Enhanced two-tower model with embedding diversity regularization and improved discrimination.
- """
-
- import tensorflow as tf
- import tensorflow_recommenders as tfrs
- import numpy as np
-
-
- class EmbeddingDiversityRegularizer(tf.keras.layers.Layer):
-     """Regularizer to prevent embedding collapse by enforcing diversity."""
-
-     def __init__(self, diversity_weight=0.01, orthogonality_weight=0.05, **kwargs):
-         super().__init__(**kwargs)
-         self.diversity_weight = diversity_weight
-         self.orthogonality_weight = orthogonality_weight
-
-     def call(self, embeddings):
-         """Apply diversity regularization to embeddings."""
-         batch_size = tf.shape(embeddings)[0]
-
-         # Compute pairwise cosine similarities
-         normalized_embeddings = tf.nn.l2_normalize(embeddings, axis=1)
-         similarity_matrix = tf.linalg.matmul(
-             normalized_embeddings, normalized_embeddings, transpose_b=True
-         )
-
-         # Remove diagonal (self-similarities)
-         mask = 1.0 - tf.eye(batch_size)
-         masked_similarities = similarity_matrix * mask
-
-         # Diversity loss: penalize high similarities between different embeddings
-         diversity_loss = tf.reduce_mean(tf.square(masked_similarities))
-
-         # Orthogonality loss: encourage embeddings to be orthogonal
-         identity_target = tf.eye(batch_size)
-         orthogonality_loss = tf.reduce_mean(
-             tf.square(similarity_matrix - identity_target)
-         )
-
-         # Add as regularization losses
-         self.add_loss(self.diversity_weight * diversity_loss)
-         self.add_loss(self.orthogonality_weight * orthogonality_loss)
-
-         return embeddings
-
-
- class AdaptiveTemperatureScaling(tf.keras.layers.Layer):
-     """Advanced temperature scaling with learned parameters."""
-
-     def __init__(self, initial_temperature=1.0, min_temp=0.1, max_temp=5.0, **kwargs):
-         super().__init__(**kwargs)
-         self.initial_temperature = initial_temperature
-         self.min_temp = min_temp
-         self.max_temp = max_temp
-
-     def build(self, input_shape):
-         # Learnable temperature with constraints
-         self.raw_temperature = self.add_weight(
-             name='raw_temperature',
-             shape=(),
-             initializer=tf.keras.initializers.Constant(
-                 np.log(self.initial_temperature - self.min_temp)
-             ),
-             trainable=True
-         )
-
-         # Learnable bias term for better discrimination
-         self.similarity_bias = self.add_weight(
-             name='similarity_bias',
-             shape=(),
-             initializer=tf.keras.initializers.Zeros(),
-             trainable=True
-         )
-
-         super().build(input_shape)
-
-     def call(self, user_embeddings, item_embeddings):
-         """Compute adaptive temperature-scaled similarity with bias."""
-         # Constrain temperature to valid range
-         temperature = self.min_temp + tf.nn.softplus(self.raw_temperature)
-         temperature = tf.minimum(temperature, self.max_temp)
-
-         # Compute similarities
-         similarities = tf.reduce_sum(user_embeddings * item_embeddings, axis=1)
-
-         # Add learnable bias and apply temperature scaling
-         scaled_similarities = (similarities + self.similarity_bias) / temperature
-
-         return scaled_similarities, temperature
-
-
- class EnhancedItemTower(tf.keras.Model):
-     """Enhanced item tower with diversity regularization."""
-
-     def __init__(self,
-                  item_vocab_size: int,
-                  category_vocab_size: int,
-                  brand_vocab_size: int,
-                  embedding_dim: int = 128,
-                  hidden_dims: list = [256, 128],
-                  dropout_rate: float = 0.3,
-                  use_bias: bool = True,
-                  use_diversity_reg: bool = True):
-         super().__init__()
-
-         self.embedding_dim = embedding_dim
-         self.use_bias = use_bias
-         self.use_diversity_reg = use_diversity_reg
-
-         # Embedding layers with better initialization
-         self.item_embedding = tf.keras.layers.Embedding(
-             item_vocab_size, embedding_dim,
-             embeddings_initializer='he_normal',  # Better initialization
-             embeddings_regularizer=tf.keras.regularizers.L2(1e-6),
-             name="item_embedding"
-         )
-         self.category_embedding = tf.keras.layers.Embedding(
-             category_vocab_size, embedding_dim,
-             embeddings_initializer='he_normal',
-             embeddings_regularizer=tf.keras.regularizers.L2(1e-6),
-             name="category_embedding"
-         )
-         self.brand_embedding = tf.keras.layers.Embedding(
-             brand_vocab_size, embedding_dim,
-             embeddings_initializer='he_normal',
-             embeddings_regularizer=tf.keras.regularizers.L2(1e-6),
-             name="brand_embedding"
-         )
-
-         # Price processing
-         self.price_normalization = tf.keras.layers.Normalization(name="price_norm")
-         self.price_projection = tf.keras.layers.Dense(
-             embedding_dim // 4, activation='relu', name="price_proj"
-         )
-
-         # Enhanced attention mechanism
-         self.feature_attention = tf.keras.layers.MultiHeadAttention(
-             num_heads=4,
-             key_dim=embedding_dim,
-             dropout=0.1,
-             name="feature_attention"
-         )
-
-         # Dense layers with residual connections
-         self.dense_layers = []
-         for i, dim in enumerate(hidden_dims):
-             self.dense_layers.extend([
-                 tf.keras.layers.Dense(dim, activation=None, name=f"dense_{i}"),
-                 tf.keras.layers.BatchNormalization(name=f"bn_{i}"),
-                 tf.keras.layers.Activation('relu', name=f"relu_{i}"),
-                 tf.keras.layers.Dropout(dropout_rate, name=f"dropout_{i}")
-             ])
-
-         # Output layer with controlled normalization
-         self.output_layer = tf.keras.layers.Dense(
-             embedding_dim, activation=None, use_bias=use_bias, name="item_output"
-         )
-
-         # Diversity regularizer
-         if use_diversity_reg:
-             self.diversity_regularizer = EmbeddingDiversityRegularizer()
-
-         # Adaptive normalization instead of hard L2 normalization
-         self.adaptive_norm = tf.keras.layers.LayerNormalization(name="adaptive_norm")
-
-         # Item bias
-         if use_bias:
-             self.item_bias = tf.keras.layers.Embedding(
-                 item_vocab_size, 1, name="item_bias"
-             )
-
-     def call(self, inputs, training=None):
-         """Enhanced forward pass with diversity regularization."""
-         item_id = inputs["product_id"]
-         category_id = inputs["category_id"]
-         brand_id = inputs["brand_id"]
-         price = inputs["price"]
-
-         # Get embeddings
-         item_emb = self.item_embedding(item_id)
-         category_emb = self.category_embedding(category_id)
-         brand_emb = self.brand_embedding(brand_id)
-
-         # Process price
-         price_norm = self.price_normalization(tf.expand_dims(price, -1))
-         price_emb = self.price_projection(price_norm)
-
-         # Pad price embedding
-         price_emb_padded = tf.pad(
-             price_emb,
-             [[0, 0], [0, self.embedding_dim - tf.shape(price_emb)[-1]]]
-         )
-
-         # Stack features for attention
-         features = tf.stack([item_emb, category_emb, brand_emb, price_emb_padded], axis=1)
-
-         # Apply attention
-         attended_features = self.feature_attention(
-             query=features,
-             value=features,
-             key=features,
-             training=training
-         )
-
-         # Aggregate with residual connection
-         combined = tf.reduce_mean(attended_features + features, axis=1)
-
-         # Pass through dense layers with residual connections
-         x = combined
-         residual = x
-         for i, layer in enumerate(self.dense_layers):
-             x = layer(x, training=training)
-             # Add residual connection every 4 layers (complete block)
-             if (i + 1) % 4 == 0 and x.shape[-1] == residual.shape[-1]:
-                 x = x + residual
-                 residual = x
-
-         # Final output
-         output = self.output_layer(x)
-
-         # Apply diversity regularization if enabled
-         if self.use_diversity_reg and training:
-             output = self.diversity_regularizer(output)
-
-         # Adaptive normalization instead of hard L2
-         normalized_output = self.adaptive_norm(output)
-
-         # Add bias if enabled
-         if self.use_bias:
-             bias = tf.squeeze(self.item_bias(item_id), axis=-1)
-             return normalized_output, bias
-         else:
-             return normalized_output
-
-
- class EnhancedUserTower(tf.keras.Model):
-     """Enhanced user tower with diversity regularization."""
-
-     def __init__(self,
-                  max_history_length: int = 50,
-                  embedding_dim: int = 128,
-                  hidden_dims: list = [256, 128],
-                  dropout_rate: float = 0.3,
-                  use_bias: bool = True,
-                  use_diversity_reg: bool = True):
-         super().__init__()
-
-         self.embedding_dim = embedding_dim
-         self.max_history_length = max_history_length
-         self.use_bias = use_bias
-         self.use_diversity_reg = use_diversity_reg
-
-         # Demographic embeddings with regularization
-         self.age_embedding = tf.keras.layers.Embedding(
-             6, embedding_dim // 16,
-             embeddings_initializer='he_normal',
-             embeddings_regularizer=tf.keras.regularizers.L2(1e-6),
-             name="age_embedding"
-         )
-         self.income_embedding = tf.keras.layers.Embedding(
-             5, embedding_dim // 16,
-             embeddings_initializer='he_normal',
-             embeddings_regularizer=tf.keras.regularizers.L2(1e-6),
-             name="income_embedding"
-         )
-         self.gender_embedding = tf.keras.layers.Embedding(
-             2, embedding_dim // 16,
-             embeddings_initializer='he_normal',
-             embeddings_regularizer=tf.keras.regularizers.L2(1e-6),
-             name="gender_embedding"
-         )
-
-         # Enhanced history processing
-         self.history_transformer = tf.keras.layers.MultiHeadAttention(
-             num_heads=8,
-             key_dim=embedding_dim,
-             dropout=0.1,
-             name="history_transformer"
-         )
-
-         # History aggregation with attention pooling
-         self.history_attention_pooling = tf.keras.layers.Dense(
-             1, activation=None, name="history_attention"
-         )
-
-         # Dense layers with residual connections
-         self.dense_layers = []
-         for i, dim in enumerate(hidden_dims):
-             self.dense_layers.extend([
-                 tf.keras.layers.Dense(dim, activation=None, name=f"user_dense_{i}"),
-                 tf.keras.layers.BatchNormalization(name=f"user_bn_{i}"),
-                 tf.keras.layers.Activation('relu', name=f"user_relu_{i}"),
-                 tf.keras.layers.Dropout(dropout_rate, name=f"user_dropout_{i}")
-             ])
-
-         # Output layer
-         self.output_layer = tf.keras.layers.Dense(
-             embedding_dim, activation=None, use_bias=use_bias, name="user_output"
-         )
-
-         # Diversity regularizer
-         if use_diversity_reg:
-             self.diversity_regularizer = EmbeddingDiversityRegularizer()
-
-         # Adaptive normalization
-         self.adaptive_norm = tf.keras.layers.LayerNormalization(name="user_adaptive_norm")
-
-         # Global user bias
-         if use_bias:
-             self.global_user_bias = tf.Variable(
-                 initial_value=0.0, trainable=True, name="global_user_bias"
-             )
-
-     def call(self, inputs, training=None):
-         """Enhanced forward pass with diversity regularization."""
-         age = inputs["age"]
-         gender = inputs["gender"]
-         income = inputs["income"]
-         item_history = inputs["item_history_embeddings"]
-
-         # Process demographics
-         age_emb = self.age_embedding(age)
-         income_emb = self.income_embedding(income)
-         gender_emb = self.gender_embedding(gender)
-
-         # Combine demographics
-         demo_combined = tf.concat([age_emb, income_emb, gender_emb], axis=-1)
-
-         # Enhanced history processing
-         batch_size = tf.shape(item_history)[0]
-         seq_len = tf.shape(item_history)[1]
-
-         # Simplified positional encoding - ensure shape compatibility
-         positions = tf.range(seq_len, dtype=tf.float32)
-         # Create simpler positional encoding
-         pos_encoding_scale = tf.cast(tf.range(self.embedding_dim, dtype=tf.float32), tf.float32) / self.embedding_dim
-         position_encoding = tf.sin(positions[:, tf.newaxis] * pos_encoding_scale[tf.newaxis, :])
-
-         # Ensure correct shape: [seq_len, embedding_dim] -> [batch_size, seq_len, embedding_dim]
-         position_encoding = tf.expand_dims(position_encoding, 0)
-         position_encoding = tf.tile(position_encoding, [batch_size, 1, 1])
-
-         # Add positional encoding with shape check
-         history_with_pos = item_history + position_encoding
-
-         # Create attention mask - fix shape for MultiHeadAttention
-         # MultiHeadAttention expects mask shape: [batch_size, seq_len] or [batch_size, seq_len, seq_len]
-         history_mask = tf.reduce_sum(tf.abs(item_history), axis=-1) > 0  # [batch_size, seq_len]
-
-         # Apply transformer attention
-         attended_history = self.history_transformer(
-             query=history_with_pos,
-             value=history_with_pos,
-             key=history_with_pos,
-             attention_mask=history_mask,
-             training=training
-         )
-
-         # Attention-based pooling instead of simple mean
-         attention_weights = tf.nn.softmax(
-             self.history_attention_pooling(attended_history), axis=1
-         )
-         history_aggregated = tf.reduce_sum(
-             attended_history * attention_weights, axis=1
-         )
-
-         # Combine features
-         combined = tf.concat([demo_combined, history_aggregated], axis=-1)
-
-         # Pass through dense layers with residual connections
-         x = combined
-         residual = x
-         for i, layer in enumerate(self.dense_layers):
-             x = layer(x, training=training)
-             # Add residual connection every 4 layers
-             if (i + 1) % 4 == 0 and x.shape[-1] == residual.shape[-1]:
-                 x = x + residual
-                 residual = x
-
-         # Final output
-         output = self.output_layer(x)
-
-         # Apply diversity regularization if enabled
-         if self.use_diversity_reg and training:
-             output = self.diversity_regularizer(output)
-
-         # Adaptive normalization
-         normalized_output = self.adaptive_norm(output)
-
-         # Add bias if enabled
-         if self.use_bias:
-             return normalized_output, self.global_user_bias
-         else:
-             return normalized_output
-
-
- class EnhancedTwoTowerModel(tfrs.Model):
-     """Enhanced two-tower model with all improvements."""
-
-     def __init__(self,
-                  item_tower: EnhancedItemTower,
-                  user_tower: EnhancedUserTower,
-                  rating_weight: float = 1.0,
-                  retrieval_weight: float = 1.0,
-                  contrastive_weight: float = 0.3,
-                  diversity_weight: float = 0.1):
-         super().__init__()
-
-         self.item_tower = item_tower
-         self.user_tower = user_tower
-         self.rating_weight = rating_weight
-         self.retrieval_weight = retrieval_weight
-         self.contrastive_weight = contrastive_weight
-         self.diversity_weight = diversity_weight
-
-         # Adaptive temperature scaling
-         self.temperature_similarity = AdaptiveTemperatureScaling()
-
-         # Enhanced rating model
-         self.rating_model = tf.keras.Sequential([
-             tf.keras.layers.Dense(512, activation="relu"),
-             tf.keras.layers.BatchNormalization(),
-             tf.keras.layers.Dropout(0.3),
-             tf.keras.layers.Dense(256, activation="relu"),
-             tf.keras.layers.BatchNormalization(),
-             tf.keras.layers.Dropout(0.2),
-             tf.keras.layers.Dense(64, activation="relu"),
-             tf.keras.layers.Dense(1, activation="sigmoid")
-         ])
-
-         # Focal loss for imbalanced data
-         self.focal_loss = self._focal_loss
-
-     def _focal_loss(self, y_true, y_pred, alpha=0.25, gamma=2.0):
-         """Focal loss implementation."""
-         epsilon = tf.keras.backend.epsilon()
-         y_pred = tf.clip_by_value(y_pred, epsilon, 1.0 - epsilon)
-
-         alpha_t = y_true * alpha + (1 - y_true) * (1 - alpha)
-         p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
-         focal_weight = alpha_t * tf.pow((1 - p_t), gamma)
-
-         bce = -(y_true * tf.math.log(y_pred) + (1 - y_true) * tf.math.log(1 - y_pred))
-         focal_loss = focal_weight * bce
-
-         return tf.reduce_mean(focal_loss)
-
450
- def call(self, features):
451
- # Get embeddings
452
- user_output = self.user_tower(features)
453
- item_output = self.item_tower(features)
454
-
455
- # Handle bias terms
456
- if isinstance(user_output, tuple):
457
- user_embeddings, user_bias = user_output
458
- else:
459
- user_embeddings = user_output
460
- user_bias = 0.0
461
-
462
- if isinstance(item_output, tuple):
463
- item_embeddings, item_bias = item_output
464
- else:
465
- item_embeddings = item_output
466
- item_bias = 0.0
467
-
468
- return {
469
- "user_embedding": user_embeddings,
470
- "item_embedding": item_embeddings,
471
- "user_bias": user_bias,
472
- "item_bias": item_bias
473
- }
474
-
475
- def compute_loss(self, features, training=False):
476
- # Get embeddings and biases
477
- outputs = self(features)
478
- user_embeddings = outputs["user_embedding"]
479
- item_embeddings = outputs["item_embedding"]
480
- user_bias = outputs["user_bias"]
481
- item_bias = outputs["item_bias"]
482
-
483
- # Rating prediction
484
- concatenated = tf.concat([user_embeddings, item_embeddings], axis=-1)
485
- rating_predictions = self.rating_model(concatenated, training=training)
486
-
487
- # Add bias terms
488
- rating_predictions_with_bias = rating_predictions + user_bias + item_bias
489
- rating_predictions_with_bias = tf.nn.sigmoid(rating_predictions_with_bias)
490
-
491
- # Losses
492
- rating_loss = self.focal_loss(features["rating"], rating_predictions_with_bias)
493
-
494
- # Adaptive temperature-scaled retrieval loss
495
- scaled_similarities, temperature = self.temperature_similarity(
496
- user_embeddings, item_embeddings
497
- )
498
- retrieval_loss = tf.keras.losses.binary_crossentropy(
499
- features["rating"],
500
- tf.nn.sigmoid(scaled_similarities)
501
- )
502
- retrieval_loss = tf.reduce_mean(retrieval_loss)
503
-
504
- # Enhanced contrastive loss with hard negatives
505
- batch_size = tf.shape(user_embeddings)[0]
506
- positive_similarities = tf.reduce_sum(user_embeddings * item_embeddings, axis=1)
507
-
508
- # Random negative sampling
509
- shuffled_indices = tf.random.shuffle(tf.range(batch_size))
510
- negative_item_embeddings = tf.gather(item_embeddings, shuffled_indices)
511
- negative_similarities = tf.reduce_sum(user_embeddings * negative_item_embeddings, axis=1)
512
-
513
- # Triplet loss with adaptive margin
514
- margin = 0.5 / temperature # Adaptive margin based on temperature
515
- contrastive_loss = tf.reduce_mean(
516
- tf.maximum(0.0, margin + negative_similarities - positive_similarities)
517
- )
518
-
519
- # Combine losses
520
- total_loss = (
521
- self.rating_weight * rating_loss +
522
- self.retrieval_weight * retrieval_loss +
523
- self.contrastive_weight * contrastive_loss
524
- )
525
-
526
- # Add regularization losses from diversity regularizers
527
- if training:
528
- regularization_losses = tf.add_n(self.losses) if self.losses else 0.0
529
- total_loss += self.diversity_weight * regularization_losses
530
-
531
- return {
532
- 'total_loss': total_loss,
533
- 'rating_loss': rating_loss,
534
- 'retrieval_loss': retrieval_loss,
535
- 'contrastive_loss': contrastive_loss,
536
- 'temperature': temperature,
537
- 'diversity_loss': regularization_losses if training else 0.0
538
- }
539
-
540
-
541
- def create_enhanced_model(data_processor,
542
- embedding_dim=128,
543
- use_bias=True,
544
- use_diversity_reg=True):
545
- """Factory function to create enhanced two-tower model."""
546
-
547
- # Create enhanced towers
548
- item_tower = EnhancedItemTower(
549
- item_vocab_size=len(data_processor.item_vocab),
550
- category_vocab_size=len(data_processor.category_vocab),
551
- brand_vocab_size=len(data_processor.brand_vocab),
552
- embedding_dim=embedding_dim,
553
- use_bias=use_bias,
554
- use_diversity_reg=use_diversity_reg
555
- )
556
-
557
- user_tower = EnhancedUserTower(
558
- max_history_length=50,
559
- embedding_dim=embedding_dim,
560
- use_bias=use_bias,
561
- use_diversity_reg=use_diversity_reg
562
- )
563
-
564
- # Create enhanced model
565
- model = EnhancedTwoTowerModel(
566
- item_tower=item_tower,
567
- user_tower=user_tower,
568
- rating_weight=1.0,
569
- retrieval_weight=0.5,
570
- contrastive_weight=0.3,
571
- diversity_weight=0.1
572
- )
573
-
574
- return model
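The `_focal_loss` removed above is alpha-balanced binary cross-entropy whose well-classified examples are down-weighted by `(1 - p_t)**gamma`. As a sanity sketch (a NumPy stand-in, not the project's TensorFlow code), the behavior is easy to check: a confident correct prediction contributes almost nothing, while a borderline one dominates the loss.

```python
import numpy as np

def focal_loss(y_true, y_pred, alpha=0.25, gamma=2.0, eps=1e-7):
    """NumPy mirror of the deleted _focal_loss: alpha-balanced BCE
    scaled by (1 - p_t)**gamma, which suppresses easy examples."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    alpha_t = y_true * alpha + (1 - y_true) * (1 - alpha)
    p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
    focal_weight = alpha_t * (1 - p_t) ** gamma
    bce = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    return float(np.mean(focal_weight * bce))

# A confident positive vs. a borderline positive:
easy = focal_loss(np.array([1.0]), np.array([0.95]))
hard = focal_loss(np.array([1.0]), np.array([0.60]))
# The borderline example contributes far more loss than the confident one.
```

With `gamma=0` and `alpha=0.5`, the focal weight collapses to a constant 0.5, recovering plain (halved) binary cross-entropy.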
src/models/improved_two_tower.py DELETED
@@ -1,545 +0,0 @@
-#!/usr/bin/env python3
-"""
-Improved two-tower model with better embedding discrimination and training stability.
-"""
-
-import tensorflow as tf
-import tensorflow_recommenders as tfrs
-import numpy as np
-
-
-class ImprovedItemTower(tf.keras.Model):
-    """Enhanced item tower with better discrimination and representation capacity."""
-
-    def __init__(self,
-                 item_vocab_size: int,
-                 category_vocab_size: int,
-                 brand_vocab_size: int,
-                 embedding_dim: int = 128,  # Increased from 64
-                 hidden_dims: list = [256, 128],  # Deeper network
-                 dropout_rate: float = 0.3,
-                 use_bias: bool = True):
-        super().__init__()
-
-        self.embedding_dim = embedding_dim
-        self.use_bias = use_bias
-
-        # Larger embedding layers with proper initialization
-        self.item_embedding = tf.keras.layers.Embedding(
-            item_vocab_size, embedding_dim,
-            embeddings_initializer='glorot_uniform',
-            name="item_embedding"
-        )
-        self.category_embedding = tf.keras.layers.Embedding(
-            category_vocab_size, embedding_dim,
-            embeddings_initializer='glorot_uniform',
-            name="category_embedding"
-        )
-        self.brand_embedding = tf.keras.layers.Embedding(
-            brand_vocab_size, embedding_dim,
-            embeddings_initializer='glorot_uniform',
-            name="brand_embedding"
-        )
-
-        # Price normalization and projection
-        self.price_normalization = tf.keras.layers.Normalization(name="price_norm")
-        self.price_projection = tf.keras.layers.Dense(
-            embedding_dim // 4, activation='relu', name="price_proj"
-        )
-
-        # Attention mechanism for feature fusion
-        self.feature_attention = tf.keras.layers.MultiHeadAttention(
-            num_heads=4,
-            key_dim=embedding_dim,
-            name="feature_attention"
-        )
-
-        # Enhanced dense layers with batch normalization
-        self.dense_layers = []
-        for i, dim in enumerate(hidden_dims):
-            self.dense_layers.extend([
-                tf.keras.layers.Dense(dim, activation=None, name=f"dense_{i}"),
-                tf.keras.layers.BatchNormalization(name=f"bn_{i}"),
-                tf.keras.layers.Activation('relu', name=f"relu_{i}"),
-                tf.keras.layers.Dropout(dropout_rate, name=f"dropout_{i}")
-            ])
-
-        # Output projection with bias term
-        self.output_layer = tf.keras.layers.Dense(
-            embedding_dim, activation=None, use_bias=use_bias, name="item_output"
-        )
-
-        # Learnable bias term for each item
-        if use_bias:
-            self.item_bias = tf.keras.layers.Embedding(
-                item_vocab_size, 1, name="item_bias"
-            )
-
-    def call(self, inputs, training=None):
-        """Enhanced forward pass with attention and better feature fusion."""
-        item_id = inputs["product_id"]
-        category_id = inputs["category_id"]
-        brand_id = inputs["brand_id"]
-        price = inputs["price"]
-
-        # Get embeddings
-        item_emb = self.item_embedding(item_id)  # [batch, emb_dim]
-        category_emb = self.category_embedding(category_id)
-        brand_emb = self.brand_embedding(brand_id)
-
-        # Process price
-        price_norm = self.price_normalization(tf.expand_dims(price, -1))
-        price_emb = self.price_projection(price_norm)
-
-        # Pad price embedding to match others
-        price_emb_padded = tf.pad(
-            price_emb,
-            [[0, 0], [0, self.embedding_dim - tf.shape(price_emb)[-1]]]
-        )
-
-        # Stack features for attention [batch, 4, emb_dim]
-        features = tf.stack([item_emb, category_emb, brand_emb, price_emb_padded], axis=1)
-
-        # Apply self-attention for feature fusion
-        attended_features = self.feature_attention(
-            query=features,
-            value=features,
-            key=features,
-            training=training
-        )
-
-        # Aggregate features (mean pooling)
-        combined = tf.reduce_mean(attended_features, axis=1)
-
-        # Pass through enhanced dense layers
-        x = combined
-        for layer in self.dense_layers:
-            x = layer(x, training=training)
-
-        # Final output
-        output = self.output_layer(x)
-
-        # L2 normalize for similarity computations
-        normalized_output = tf.nn.l2_normalize(output, axis=-1)
-
-        # Add bias if enabled
-        if self.use_bias:
-            bias = tf.squeeze(self.item_bias(item_id), axis=-1)
-            return normalized_output, bias
-        else:
-            return normalized_output
-
-
-class ImprovedUserTower(tf.keras.Model):
-    """Enhanced user tower with better history modeling and representation."""
-
-    def __init__(self,
-                 max_history_length: int = 50,
-                 embedding_dim: int = 128,  # Increased from 64
-                 hidden_dims: list = [256, 128],  # Deeper network
-                 dropout_rate: float = 0.3,
-                 use_bias: bool = True):
-        super().__init__()
-
-        self.embedding_dim = embedding_dim
-        self.max_history_length = max_history_length
-        self.use_bias = use_bias
-
-        # Demographic embeddings (categorical features)
-        # Age: 6 categories (Teen, Young Adult, Adult, Middle Age, Mature, Senior)
-        self.age_embedding = tf.keras.layers.Embedding(
-            6, embedding_dim // 16,
-            embeddings_initializer='glorot_uniform',
-            name="age_embedding"
-        )
-
-        # Income: 5 categories (percentile-based)
-        self.income_embedding = tf.keras.layers.Embedding(
-            5, embedding_dim // 16,
-            embeddings_initializer='glorot_uniform',
-            name="income_embedding"
-        )
-
-        # Gender: 2 categories (0=female, 1=male)
-        self.gender_embedding = tf.keras.layers.Embedding(
-            2, embedding_dim // 16,
-            embeddings_initializer='glorot_uniform',
-            name="gender_embedding"
-        )
-
-        # Improved history processing with positional encoding
-        self.history_transformer = tf.keras.layers.MultiHeadAttention(
-            num_heads=8,  # More attention heads
-            key_dim=embedding_dim,
-            name="history_transformer"
-        )
-
-        # History aggregation with learned weights
-        self.history_aggregation = tf.keras.layers.Dense(
-            embedding_dim, activation='tanh', name="history_agg"
-        )
-
-        # Enhanced dense layers with batch normalization
-        self.dense_layers = []
-        for i, dim in enumerate(hidden_dims):
-            self.dense_layers.extend([
-                tf.keras.layers.Dense(dim, activation=None, name=f"user_dense_{i}"),
-                tf.keras.layers.BatchNormalization(name=f"user_bn_{i}"),
-                tf.keras.layers.Activation('relu', name=f"user_relu_{i}"),
-                tf.keras.layers.Dropout(dropout_rate, name=f"user_dropout_{i}")
-            ])
-
-        # Output layer
-        self.output_layer = tf.keras.layers.Dense(
-            embedding_dim, activation=None, use_bias=use_bias, name="user_output"
-        )
-
-        # Learnable user bias
-        if use_bias:
-            # We'll need to handle user bias differently since we don't have user vocab in inference
-            self.global_user_bias = tf.Variable(
-                initial_value=0.0, trainable=True, name="global_user_bias"
-            )
-
-    def call(self, inputs, training=None):
-        """Enhanced forward pass with better history modeling."""
-        age = inputs["age"]  # Now categorical (0-5)
-        gender = inputs["gender"]  # Categorical (0-1)
-        income = inputs["income"]  # Now categorical (0-4)
-        item_history = inputs["item_history_embeddings"]  # [batch_size, seq_len, emb_dim]
-
-        # Process demographics through embeddings
-        age_emb = self.age_embedding(age)  # [batch_size, embedding_dim//16]
-        income_emb = self.income_embedding(income)  # [batch_size, embedding_dim//16]
-        gender_emb = self.gender_embedding(gender)  # [batch_size, embedding_dim//16]
-
-        # Combine all demographic embeddings
-        demo_combined = tf.concat([age_emb, income_emb, gender_emb], axis=-1)
-        # Total demographics: 3 * (embedding_dim//16) = ~18.75% of embedding_dim
-
-        # Enhanced history processing with positional encoding
-        batch_size = tf.shape(item_history)[0]
-        seq_len = tf.shape(item_history)[1]
-
-        # Create positional encoding
-        positions = tf.range(seq_len, dtype=tf.float32)
-        position_encoding = tf.sin(
-            positions[:, tf.newaxis] /
-            tf.pow(10000.0, 2 * tf.range(self.embedding_dim, dtype=tf.float32) / self.embedding_dim)
-        )
-        position_encoding = tf.expand_dims(position_encoding, 0)
-        position_encoding = tf.tile(position_encoding, [batch_size, 1, 1])
-
-        # Add positional encoding to history
-        history_with_pos = item_history + position_encoding
-
-        # Create attention mask for padding
-        history_mask = tf.reduce_sum(tf.abs(item_history), axis=-1) > 0
-
-        # Apply transformer attention to history
-        attended_history = self.history_transformer(
-            query=history_with_pos,
-            value=history_with_pos,
-            key=history_with_pos,
-            attention_mask=history_mask,
-            training=training
-        )
-
-        # Aggregate history with learned weights
-        history_weights = tf.nn.softmax(
-            tf.keras.layers.Dense(1)(attended_history), axis=1
-        )
-        history_aggregated = tf.reduce_sum(
-            attended_history * history_weights, axis=1
-        )
-
-        # Apply additional processing
-        history_processed = self.history_aggregation(history_aggregated)
-
-        # Combine all features
-        combined = tf.concat([
-            demo_combined,
-            history_processed
-        ], axis=-1)
-
-        # Pass through enhanced dense layers
-        x = combined
-        for layer in self.dense_layers:
-            x = layer(x, training=training)
-
-        # Final output
-        output = self.output_layer(x)
-
-        # L2 normalize for similarity computations
-        normalized_output = tf.nn.l2_normalize(output, axis=-1)
-
-        # Add global bias if enabled
-        if self.use_bias:
-            return normalized_output, self.global_user_bias
-        else:
-            return normalized_output
-
-
-class TemperatureScaledSimilarity(tf.keras.layers.Layer):
-    """Learnable temperature scaling for similarity computations."""
-
-    def __init__(self, initial_temperature=1.0, **kwargs):
-        super().__init__(**kwargs)
-        self.initial_temperature = initial_temperature
-
-    def build(self, input_shape):
-        self.temperature = self.add_weight(
-            name='temperature',
-            shape=(),
-            initializer=tf.keras.initializers.Constant(self.initial_temperature),
-            trainable=True
-        )
-        super().build(input_shape)
-
-    def call(self, user_embeddings, item_embeddings):
-        """Compute temperature-scaled similarity."""
-        # Dot product similarity
-        similarities = tf.reduce_sum(user_embeddings * item_embeddings, axis=1)
-
-        # Scale by learnable temperature
-        scaled_similarities = similarities / tf.maximum(self.temperature, 0.01)  # Prevent division by 0
-
-        return scaled_similarities
-
-
-class ImprovedTwoTowerModel(tfrs.Model):
-    """Enhanced two-tower model with better discrimination and training stability."""
-
-    def __init__(self,
-                 item_tower: ImprovedItemTower,
-                 user_tower: ImprovedUserTower,
-                 rating_weight: float = 1.0,
-                 retrieval_weight: float = 1.0,
-                 contrastive_weight: float = 0.5,
-                 use_focal_loss: bool = True):
-        super().__init__()
-
-        self.item_tower = item_tower
-        self.user_tower = user_tower
-        self.rating_weight = rating_weight
-        self.retrieval_weight = retrieval_weight
-        self.contrastive_weight = contrastive_weight
-        self.use_focal_loss = use_focal_loss
-
-        # Temperature-scaled similarity
-        self.temperature_similarity = TemperatureScaledSimilarity()
-
-        # Enhanced rating prediction with more capacity
-        self.rating_model = tf.keras.Sequential([
-            tf.keras.layers.Dense(512, activation="relu"),
-            tf.keras.layers.BatchNormalization(),
-            tf.keras.layers.Dropout(0.3),
-            tf.keras.layers.Dense(256, activation="relu"),
-            tf.keras.layers.BatchNormalization(),
-            tf.keras.layers.Dropout(0.2),
-            tf.keras.layers.Dense(64, activation="relu"),
-            tf.keras.layers.Dense(1, activation="sigmoid")
-        ])
-
-        # Rating task with better loss
-        if use_focal_loss:
-            self.rating_loss = self._focal_loss
-        else:
-            self.rating_loss = tf.keras.losses.BinaryCrossentropy()
-
-        # Contrastive loss for embedding separation
-        self.contrastive_loss = tf.keras.losses.CosineSimilarity()
-
-    def _focal_loss(self, y_true, y_pred, alpha=0.25, gamma=2.0):
-        """Focal loss for handling imbalanced data."""
-        epsilon = tf.keras.backend.epsilon()
-        y_pred = tf.clip_by_value(y_pred, epsilon, 1.0 - epsilon)
-
-        # Compute focal weight
-        alpha_t = y_true * alpha + (1 - y_true) * (1 - alpha)
-        p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
-        focal_weight = alpha_t * tf.pow((1 - p_t), gamma)
-
-        # Compute loss
-        bce = -(y_true * tf.math.log(y_pred) + (1 - y_true) * tf.math.log(1 - y_pred))
-        focal_loss = focal_weight * bce
-
-        return tf.reduce_mean(focal_loss)
-
-    def call(self, features):
-        # Get embeddings (handle bias if present)
-        user_output = self.user_tower(features)
-        item_output = self.item_tower(features)
-
-        # Handle bias terms
-        if isinstance(user_output, tuple):
-            user_embeddings, user_bias = user_output
-        else:
-            user_embeddings = user_output
-            user_bias = 0.0
-
-        if isinstance(item_output, tuple):
-            item_embeddings, item_bias = item_output
-        else:
-            item_embeddings = item_output
-            item_bias = 0.0
-
-        return {
-            "user_embedding": user_embeddings,
-            "item_embedding": item_embeddings,
-            "user_bias": user_bias,
-            "item_bias": item_bias
-        }
-
-    def _hard_negative_mining(self, user_embeddings, item_embeddings, ratings, num_negatives=5):
-        """Mine hard negatives for better training."""
-        batch_size = tf.shape(user_embeddings)[0]
-
-        # Compute all pairwise similarities
-        user_norm = tf.nn.l2_normalize(user_embeddings, axis=1)
-        item_norm = tf.nn.l2_normalize(item_embeddings, axis=1)
-
-        # Expand dimensions for broadcasting: [batch, 1, dim] x [1, batch, dim]
-        user_expanded = tf.expand_dims(user_norm, 1)
-        item_expanded = tf.expand_dims(item_norm, 0)
-
-        # Compute similarity matrix [batch, batch]
-        similarity_matrix = tf.reduce_sum(user_expanded * item_expanded, axis=2)
-
-        # Create mask to exclude positive pairs
-        positive_mask = tf.eye(batch_size, dtype=tf.bool)
-        negative_mask = tf.logical_not(positive_mask)
-
-        # Get negative similarities and find hardest negatives
-        negative_similarities = tf.where(negative_mask, similarity_matrix, -tf.float32.max)
-
-        # Get top-k hardest negatives (highest similarities among negatives)
-        _, hard_negative_indices = tf.nn.top_k(negative_similarities, k=num_negatives)
-
-        return hard_negative_indices
-
-    def compute_loss(self, features, training=False):
-        # Get embeddings and biases
-        outputs = self(features)
-        user_embeddings = outputs["user_embedding"]
-        item_embeddings = outputs["item_embedding"]
-        user_bias = outputs["user_bias"]
-        item_bias = outputs["item_bias"]
-
-        # Rating prediction with bias terms
-        concatenated = tf.concat([user_embeddings, item_embeddings], axis=-1)
-        rating_predictions = self.rating_model(concatenated, training=training)
-
-        # Add bias terms to rating predictions
-        rating_predictions_with_bias = rating_predictions + user_bias + item_bias
-        rating_predictions_with_bias = tf.nn.sigmoid(rating_predictions_with_bias)
-
-        # Rating loss
-        rating_loss = self.rating_loss(features["rating"], rating_predictions_with_bias)
-
-        # Temperature-scaled retrieval loss
-        scaled_similarities = self.temperature_similarity(user_embeddings, item_embeddings)
-        retrieval_loss = tf.keras.losses.binary_crossentropy(
-            features["rating"],
-            tf.nn.sigmoid(scaled_similarities)
-        )
-        retrieval_loss = tf.reduce_mean(retrieval_loss)
-
-        # Enhanced contrastive loss with hard negative mining
-        batch_size = tf.shape(user_embeddings)[0]
-
-        if training and batch_size > 5:  # Only use hard negatives during training with sufficient batch size
-            # Hard negative mining
-            hard_negative_indices = self._hard_negative_mining(
-                user_embeddings, item_embeddings, features["rating"], num_negatives=3
-            )
-
-            # Positive similarities
-            positive_similarities = tf.reduce_sum(user_embeddings * item_embeddings, axis=1)
-
-            # Hard negative similarities
-            hard_negative_losses = []
-            for i in range(3):  # Use top 3 hard negatives
-                neg_indices = hard_negative_indices[:, i]
-                negative_item_embeddings = tf.gather(item_embeddings, neg_indices)
-                negative_similarities = tf.reduce_sum(user_embeddings * negative_item_embeddings, axis=1)
-
-                # Triplet-like loss with margin
-                margin_loss = tf.maximum(0.0, 0.2 + negative_similarities - positive_similarities)
-                hard_negative_losses.append(margin_loss)
-
-            # Average hard negative losses
-            contrastive_loss = tf.reduce_mean(tf.stack(hard_negative_losses))
-
-        else:
-            # Fallback to random negative sampling
-            shuffled_indices = tf.random.shuffle(tf.range(batch_size))
-            negative_item_embeddings = tf.gather(item_embeddings, shuffled_indices)
-
-            # Positive similarities
-            positive_similarities = tf.reduce_sum(user_embeddings * item_embeddings, axis=1)
-
-            # Negative similarities
-            negative_similarities = tf.reduce_sum(user_embeddings * negative_item_embeddings, axis=1)
-
-            # Contrastive loss (maximize positive, minimize negative)
-            contrastive_loss = tf.reduce_mean(
-                tf.maximum(0.0, 0.5 + negative_similarities - positive_similarities)
-            )
-
-        # Combine losses
-        total_loss = (
-            self.rating_weight * rating_loss +
-            self.retrieval_weight * retrieval_loss +
-            self.contrastive_weight * contrastive_loss
-        )
-
-        # Add L2 regularization to prevent overfitting
-        l2_loss = tf.add_n([
-            tf.nn.l2_loss(var) for var in self.trainable_variables
-            if 'bias' not in var.name and 'normalization' not in var.name
-        ]) * 1e-5
-
-        total_loss += l2_loss
-
-        return {
-            'total_loss': total_loss,
-            'rating_loss': rating_loss,
-            'retrieval_loss': retrieval_loss,
-            'contrastive_loss': contrastive_loss,
-            'l2_loss': l2_loss
-        }
-
-
-def create_improved_model(data_processor,
-                          embedding_dim=128,
-                          use_bias=True,
-                          use_focal_loss=True):
-    """Factory function to create improved two-tower model."""
-
-    # Create enhanced towers
-    item_tower = ImprovedItemTower(
-        item_vocab_size=len(data_processor.item_vocab),
-        category_vocab_size=len(data_processor.category_vocab),
-        brand_vocab_size=len(data_processor.brand_vocab),
-        embedding_dim=embedding_dim,
-        use_bias=use_bias
-    )
-
-    user_tower = ImprovedUserTower(
-        max_history_length=50,
-        embedding_dim=embedding_dim,
-        use_bias=use_bias
-    )
-
-    # Create improved model
-    model = ImprovedTwoTowerModel(
-        item_tower=item_tower,
-        user_tower=user_tower,
-        rating_weight=1.0,
-        retrieval_weight=0.5,
-        contrastive_weight=0.3,
-        use_focal_loss=use_focal_loss
-    )
-
-    return model
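The `_hard_negative_mining` helper deleted with this file picks, for each user in the batch, the in-batch items most similar to that user while masking out the user's own positive pair. A minimal NumPy sketch of that idea (a stand-in for illustration, not the TensorFlow code above):

```python
import numpy as np

def hard_negative_indices(user_emb, item_emb, num_negatives=3):
    """Sketch of _hard_negative_mining: L2-normalize both towers,
    build the in-batch similarity matrix, mask the diagonal
    (each row's positive pair), and return the top-k most similar
    remaining items per user as hard negatives."""
    u = user_emb / np.linalg.norm(user_emb, axis=1, keepdims=True)
    v = item_emb / np.linalg.norm(item_emb, axis=1, keepdims=True)
    sim = u @ v.T                   # [batch, batch] cosine similarities
    np.fill_diagonal(sim, -np.inf)  # exclude positives from mining
    # argsort descending and keep the k highest-similarity negatives
    return np.argsort(-sim, axis=1)[:, :num_negatives]

rng = np.random.default_rng(0)
users = rng.normal(size=(8, 16))
items = rng.normal(size=(8, 16))
idx = hard_negative_indices(users, items, num_negatives=3)
# idx has shape (8, 3), and no row ever selects its own diagonal item
```

The triplet margin in `compute_loss` then pushes these mined negatives below the positive similarity by at least the margin.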
src/models/user_tower.py CHANGED
@@ -32,6 +32,27 @@ class UserTower(tf.keras.Model):
             2, embedding_dim // 16, name="gender_embedding"
         )
 
+        # New demographic embeddings
+        # Profession: 8 categories (Technology, Healthcare, Education, Finance, Retail, Manufacturing, Services, Other)
+        self.profession_embedding = tf.keras.layers.Embedding(
+            8, embedding_dim // 16, name="profession_embedding"
+        )
+
+        # Location: 3 categories (Urban, Suburban, Rural)
+        self.location_embedding = tf.keras.layers.Embedding(
+            3, embedding_dim // 16, name="location_embedding"
+        )
+
+        # Education Level: 5 categories (High School, Some College, Bachelor's, Master's, PhD+)
+        self.education_embedding = tf.keras.layers.Embedding(
+            5, embedding_dim // 16, name="education_embedding"
+        )
+
+        # Marital Status: 4 categories (Single, Married, Divorced, Widowed)
+        self.marital_embedding = tf.keras.layers.Embedding(
+            4, embedding_dim // 16, name="marital_embedding"
+        )
+
         # History aggregation layers
         self.history_attention = tf.keras.layers.MultiHeadAttention(
             num_heads=4,
@@ -57,12 +78,20 @@ class UserTower(tf.keras.Model):
         age = inputs["age"]  # Now categorical (0-5)
         gender = inputs["gender"]  # Categorical (0-1)
         income = inputs["income"]  # Now categorical (0-4)
+        profession = inputs["profession"]  # Categorical (0-7)
+        location = inputs["location"]  # Categorical (0-2)
+        education = inputs["education_level"]  # Categorical (0-4)
+        marital_status = inputs["marital_status"]  # Categorical (0-3)
         item_history = inputs["item_history_embeddings"]  # [batch_size, seq_len, emb_dim]
 
         # Process demographics through embeddings
         age_emb = self.age_embedding(age)  # [batch_size, embedding_dim//16]
         income_emb = self.income_embedding(income)  # [batch_size, embedding_dim//16]
         gender_emb = self.gender_embedding(gender)  # [batch_size, embedding_dim//16]
+        profession_emb = self.profession_embedding(profession)  # [batch_size, embedding_dim//16]
+        location_emb = self.location_embedding(location)  # [batch_size, embedding_dim//16]
+        education_emb = self.education_embedding(education)  # [batch_size, embedding_dim//16]
+        marital_emb = self.marital_embedding(marital_status)  # [batch_size, embedding_dim//16]
 
         # Aggregate item history using attention
         # Create attention mask for padding
@@ -84,6 +113,10 @@ class UserTower(tf.keras.Model):
             age_emb,
             income_emb,
             gender_emb,
+            profession_emb,
+            location_emb,
+            education_emb,
+            marital_emb,
             history_aggregated
         ], axis=-1)
 
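A quick way to see the effect of this change on the user tower's input width (an illustrative calculation using the `embedding_dim // 16` sizing from the diff; `embedding_dim = 128` is assumed here, not stated in this hunk): each demographic embedding is 8 units wide, so the demographic slice of the concatenated feature vector grows from 3 fields to 7.

```python
# Each demographic embedding is embedding_dim // 16 wide (per the diff).
embedding_dim = 128          # assumed value for illustration
demo_width = embedding_dim // 16

# Before this commit: age + income + gender
old_demo_features = 3 * demo_width
# After: + profession + location + education + marital status
new_demo_features = 7 * demo_width

# With embedding_dim = 128: demographics grow from 24 to 56 features,
# so the first Dense layer after tf.concat sees 32 more inputs.
```

Because the wider concat changes the first dense layer's input shape, checkpoints trained before this commit are not weight-compatible with the new tower.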
src/preprocessing/optimized_dataset_creator.py DELETED
@@ -1,111 +0,0 @@
- """
- Optimized dataset creation script with performance improvements.
- """
- import time
- import numpy as np
- from src.preprocessing.user_data_preparation import UserDatasetCreator
- from src.preprocessing.data_loader import DataProcessor, create_tf_dataset
-
-
- def create_optimized_dataset(max_history_length: int = 50,
-                              batch_size: int = 512,
-                              negative_samples_per_positive: int = 2,
-                              use_sample: bool = False,
-                              sample_size: int = 10000):
-     """
-     Create dataset with optimized performance settings.
-
-     Args:
-         max_history_length: Maximum user interaction history length
-         batch_size: Batch size for TensorFlow dataset
-         negative_samples_per_positive: Negative sampling ratio
-         use_sample: Whether to use a sample of the data for faster processing
-         sample_size: Size of sample if use_sample=True
-     """
-     print("Starting optimized dataset creation...")
-     start_time = time.time()
-
-     # Initialize with optimized settings
-     dataset_creator = UserDatasetCreator(max_history_length=max_history_length)
-     data_processor = DataProcessor()
-
-     # Load data
-     print("Loading data...")
-     load_start = time.time()
-     items_df, users_df, interactions_df = data_processor.load_data()
-     print(f"Data loaded in {time.time() - load_start:.2f} seconds")
-
-     # Optional: Use sample for faster development/testing
-     if use_sample:
-         print(f"Using sample of {sample_size} interactions for faster processing...")
-         sample_interactions = interactions_df.sample(min(sample_size, len(interactions_df)))
-         user_ids = set(sample_interactions['user_id'])
-         item_ids = set(sample_interactions['product_id'])
-
-         users_df = users_df[users_df['user_id'].isin(user_ids)]
-         items_df = items_df[items_df['product_id'].isin(item_ids)]
-         interactions_df = sample_interactions
-
-         print(f"Sample: {len(items_df)} items, {len(users_df)} users, {len(interactions_df)} interactions")
-
-     # Load embeddings with caching
-     print("Loading item embeddings...")
-     embed_start = time.time()
-     item_embeddings = dataset_creator.load_item_embeddings()
-     print(f"Embeddings loaded in {time.time() - embed_start:.2f} seconds")
-
-     # Create temporal split
-     print("Creating temporal split...")
-     split_start = time.time()
-     train_interactions, val_interactions = dataset_creator.create_temporal_split(interactions_df)
-     print(f"Temporal split created in {time.time() - split_start:.2f} seconds")
-
-     # Create training dataset with optimizations
-     print("Creating optimized training dataset...")
-     train_start = time.time()
-     training_features = dataset_creator.create_training_dataset(
-         train_interactions, items_df, users_df, item_embeddings,
-         negative_samples_per_positive=negative_samples_per_positive
-     )
-     print(f"Training dataset created in {time.time() - train_start:.2f} seconds")
-
-     # Create TensorFlow dataset optimized for CPU
-     print("Creating TensorFlow dataset...")
-     tf_start = time.time()
-     tf_dataset = create_tf_dataset(training_features, batch_size=batch_size)
-     print(f"TensorFlow dataset created in {time.time() - tf_start:.2f} seconds")
-
-     # Save optimized dataset
-     print("Saving dataset...")
-     save_start = time.time()
-     dataset_creator.save_dataset(training_features, "src/artifacts/")
-
-     # Save vocabularies for later use
-     data_processor.save_vocabularies("src/artifacts/")
-     print(f"Dataset saved in {time.time() - save_start:.2f} seconds")
-
-     total_time = time.time() - start_time
-     print(f"\nOptimized dataset creation completed in {total_time:.2f} seconds!")
-     print(f"Training samples: {len(training_features['rating'])}")
-     print(f"Memory usage optimized for CPU training")
-
-     return tf_dataset, training_features
-
-
- if __name__ == "__main__":
-     # Run with optimized settings
-     tf_dataset, features = create_optimized_dataset(
-         max_history_length=30,  # Reduced for speed
-         batch_size=512,  # Larger batches for CPU efficiency
-         negative_samples_per_positive=2,  # Reduced sampling ratio
-         use_sample=True,  # Use sample for development
-         sample_size=50000  # Reasonable sample size
-     )
-
-     print("\nDataset creation optimization complete!")
-     print("Key optimizations applied:")
-     print("- Vectorized DataFrame operations")
-     print("- Parallel negative sampling")
-     print("- Memory-efficient embedding lookup")
-     print("- Optimized TensorFlow dataset pipeline")
-     print("- LRU caching for embeddings")
 
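The deleted `create_optimized_dataset` above relied on a simple invariant when subsampling for development: after sampling interactions and then filtering `users_df`/`items_df` with `isin()`, every sampled (user, item) pair remains resolvable. A tiny self-contained illustration of that invariant on toy ids (all values hypothetical):

```python
import random

random.seed(0)
# Toy (user_id, product_id) interaction pairs
interactions = [(u, i) for u in range(5) for i in range(4)]
sample = random.sample(interactions, 6)

# Restrict users/items to those present in the sampled interactions,
# mirroring the isin() filtering in create_optimized_dataset
user_ids = {u for u, _ in sample}
item_ids = {i for _, i in sample}

# Every sampled pair still references a surviving user and item
assert all(u in user_ids and i in item_ids for u, i in sample)
print(len(sample), len(user_ids), len(item_ids))
```

The converse does not hold: a surviving user/item combination may have no sampled interaction, which is why the sampled interactions themselves become the new `interactions_df`.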
src/preprocessing/user_data_preparation.py CHANGED
@@ -45,6 +45,50 @@ class UserDatasetCreator:
         categories = np.clip(categories, 0, 4)

         return categories.astype(np.int32)
+
+    def categorize_profession(self, profession: str) -> int:
+        """Categorize profession into numeric categories."""
+        profession_map = {
+            "Technology": 0,
+            "Healthcare": 1,
+            "Education": 2,
+            "Finance": 3,
+            "Retail": 4,
+            "Manufacturing": 5,
+            "Services": 6,
+            "Other": 7
+        }
+        return profession_map.get(profession, 7)  # Default to "Other"
+
+    def categorize_location(self, location: str) -> int:
+        """Categorize location into numeric categories."""
+        location_map = {
+            "Urban": 0,
+            "Suburban": 1,
+            "Rural": 2
+        }
+        return location_map.get(location, 0)  # Default to "Urban"
+
+    def categorize_education_level(self, education: str) -> int:
+        """Categorize education level into numeric categories."""
+        education_map = {
+            "High School": 0,
+            "Some College": 1,
+            "Bachelor's": 2,
+            "Master's": 3,
+            "PhD+": 4
+        }
+        return education_map.get(education, 0)  # Default to "High School"
+
+    def categorize_marital_status(self, marital_status: str) -> int:
+        """Categorize marital status into numeric categories."""
+        marital_map = {
+            "Single": 0,
+            "Married": 1,
+            "Divorced": 2,
+            "Widowed": 3
+        }
+        return marital_map.get(marital_status, 0)  # Default to "Single"

    @lru_cache(maxsize=1)
    def load_item_embeddings(self, embeddings_path: str = "src/artifacts/item_embeddings.npy") -> Dict[int, np.ndarray]:
@@ -155,6 +199,12 @@
        # Categorize income (5 percentile-based categories)
        user_demographics['income_category'] = self.categorize_income(user_demographics['income'])

+       # Categorize new demographic features
+       user_demographics['profession_category'] = user_demographics['profession'].apply(self.categorize_profession)
+       user_demographics['location_category'] = user_demographics['location'].apply(self.categorize_location)
+       user_demographics['education_category'] = user_demographics['education_level'].apply(self.categorize_education_level)
+       user_demographics['marital_category'] = user_demographics['marital_status'].apply(self.categorize_marital_status)
+
        # Create mapping from user_id to array index
        user_id_to_index = {uid: idx for idx, uid in enumerate(user_demographics['user_id'])}
@@ -165,6 +215,10 @@
            'age': user_demographics['age_category'].values.astype(np.int32),  # Categorical age
            'gender': user_demographics['gender_numeric'].values.astype(np.int32),
            'income': user_demographics['income_category'].values.astype(np.int32),  # Categorical income
+           'profession': user_demographics['profession_category'].values.astype(np.int32),  # Categorical profession
+           'location': user_demographics['location_category'].values.astype(np.int32),  # Categorical location
+           'education_level': user_demographics['education_category'].values.astype(np.int32),  # Categorical education
+           'marital_status': user_demographics['marital_category'].values.astype(np.int32),  # Categorical marital status
            'item_history_embeddings': np.array([
                user_aggregated_embeddings[uid] for uid in user_demographics['user_id']
            ]).astype(np.float32)
@@ -173,6 +227,10 @@
        print(f"Prepared user features for {len(valid_users)} users")
        print(f"Age categories: {np.unique(user_features['age'], return_counts=True)}")
        print(f"Income categories: {np.unique(user_features['income'], return_counts=True)}")
+       print(f"Profession categories: {np.unique(user_features['profession'], return_counts=True)}")
+       print(f"Location categories: {np.unique(user_features['location'], return_counts=True)}")
+       print(f"Education categories: {np.unique(user_features['education_level'], return_counts=True)}")
+       print(f"Marital status categories: {np.unique(user_features['marital_status'], return_counts=True)}")
        print(f"History embeddings shape: {user_features['item_history_embeddings'].shape}")

        return user_features
@@ -256,6 +314,10 @@
        training_features['age'] = user_features['age'][user_indices]
        training_features['gender'] = user_features['gender'][user_indices]
        training_features['income'] = user_features['income'][user_indices]
+       training_features['profession'] = user_features['profession'][user_indices]
+       training_features['location'] = user_features['location'][user_indices]
+       training_features['education_level'] = user_features['education_level'][user_indices]
+       training_features['marital_status'] = user_features['marital_status'][user_indices]
        training_features['item_history_embeddings'] = user_features['item_history_embeddings'][user_indices]

        # Item features for each pair
@@ -412,10 +474,20 @@ def prepare_user_features(users_df: pd.DataFrame,
        user_idx = users_df[users_df['user_id'] == user_id].index[0]
        income_cat = income_categories[user_idx]

+       # Get new demographic features from the row
+       profession_cat = creator.categorize_profession(user_row.get('profession', 'Other'))
+       location_cat = creator.categorize_location(user_row.get('location', 'Urban'))
+       education_cat = creator.categorize_education_level(user_row.get('education_level', 'High School'))
+       marital_cat = creator.categorize_marital_status(user_row.get('marital_status', 'Single'))
+
        user_feature_dict[user_id] = {
            'age': age_cat,
            'gender': gender_cat,
            'income': income_cat,
+           'profession': profession_cat,
+           'location': location_cat,
+           'education_level': education_cat,
+           'marital_status': marital_cat,
            'item_history_embeddings': user_aggregated_embeddings[user_id]
        }
 
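The new categorizers added to `user_data_preparation.py` are plain dictionary lookups with an explicit fallback id, so unseen strings never raise and always land in a well-defined bucket. The same method, extracted from the diff as a standalone function (minus `self`) to show the fallback behavior:

```python
def categorize_profession(profession: str) -> int:
    """Map a profession string to a compact integer id.

    Unknown or missing values fall back to the "Other" bucket (7),
    matching the .get(..., 7) default in UserDatasetCreator.
    """
    profession_map = {
        "Technology": 0,
        "Healthcare": 1,
        "Education": 2,
        "Finance": 3,
        "Retail": 4,
        "Manufacturing": 5,
        "Services": 6,
        "Other": 7
    }
    return profession_map.get(profession, 7)

print(categorize_profession("Finance"))     # 3
print(categorize_profession("Astronaut"))   # 7 -> unseen value maps to "Other"
```

The location, education, and marital-status categorizers follow the identical pattern with their own default buckets ("Urban", "High School", "Single"), which is what makes the zero-interaction (new user) path safe: demographics alone always produce valid categorical ids.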
src/training/curriculum_trainer.py DELETED
@@ -1,341 +0,0 @@
- #!/usr/bin/env python3
- """
- Curriculum learning trainer for the improved two-tower model.
- Implements progressive difficulty training for better convergence.
- """
-
- import tensorflow as tf
- import numpy as np
- import pickle
- import os
- import time
- from typing import Dict, List, Tuple
-
- from src.models.improved_two_tower import create_improved_model
- from src.preprocessing.data_loader import DataProcessor
-
-
- class CurriculumTrainer:
-     """Trainer with curriculum learning for improved two-tower model."""
-
-     def __init__(self,
-                  embedding_dim: int = 128,
-                  learning_rate: float = 0.001,
-                  use_focal_loss: bool = True,
-                  curriculum_stages: int = 3):
-
-         self.embedding_dim = embedding_dim
-         self.learning_rate = learning_rate
-         self.use_focal_loss = use_focal_loss
-         self.curriculum_stages = curriculum_stages
-
-         self.data_processor = None
-         self.model = None
-
-     def load_data_processor(self, artifacts_path: str = "src/artifacts/"):
-         """Load data processor with vocabularies."""
-         self.data_processor = DataProcessor()
-         self.data_processor.load_vocabularies(f"{artifacts_path}/vocabularies.pkl")
-         print("Data processor loaded successfully")
-
-     def create_model(self):
-         """Create improved two-tower model."""
-         if self.data_processor is None:
-             raise ValueError("Data processor must be loaded first")
-
-         self.model = create_improved_model(
-             data_processor=self.data_processor,
-             embedding_dim=self.embedding_dim,
-             use_bias=True,
-             use_focal_loss=self.use_focal_loss
-         )
-
-         # Compile model
-         self.model.compile(
-             optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
-         )
-
-         print("Improved two-tower model created successfully")
-
-     def _create_curriculum_stages(self, features: Dict[str, np.ndarray]) -> List[Dict[str, np.ndarray]]:
-         """Create curriculum stages based on interaction complexity."""
-
-         # Calculate interaction history lengths for curriculum
-         history_lengths = []
-         for i in range(len(features['age'])):
-             hist = features['item_history_embeddings'][i]
-             # Count non-zero embeddings
-             length = np.sum(np.any(hist != 0, axis=1))
-             history_lengths.append(length)
-
-         history_lengths = np.array(history_lengths)
-
-         # Create stages based on history length percentiles
-         stages = []
-
-         if self.curriculum_stages == 3:
-             # Stage 1: Simple cases (short or no history)
-             stage1_mask = history_lengths <= np.percentile(history_lengths, 33)
-
-             # Stage 2: Medium complexity (medium history)
-             stage2_mask = (history_lengths > np.percentile(history_lengths, 33)) & \
-                           (history_lengths <= np.percentile(history_lengths, 67))
-
-             # Stage 3: Complex cases (long history)
-             stage3_mask = history_lengths > np.percentile(history_lengths, 67)
-
-             masks = [stage1_mask, stage2_mask, stage3_mask]
-             stage_names = ["Simple (short history)", "Medium (moderate history)", "Complex (long history)"]
-
-         else:
-             # Flexible number of stages
-             percentiles = np.linspace(0, 100, self.curriculum_stages + 1)
-             masks = []
-             stage_names = []
-
-             for i in range(self.curriculum_stages):
-                 if i == 0:
-                     mask = history_lengths <= np.percentile(history_lengths, percentiles[i+1])
-                     stage_names.append(f"Stage {i+1} (≀{percentiles[i+1]:.0f}%ile)")
-                 elif i == self.curriculum_stages - 1:
-                     mask = history_lengths > np.percentile(history_lengths, percentiles[i])
-                     stage_names.append(f"Stage {i+1} (>{percentiles[i]:.0f}%ile)")
-                 else:
-                     mask = (history_lengths > np.percentile(history_lengths, percentiles[i])) & \
-                            (history_lengths <= np.percentile(history_lengths, percentiles[i+1]))
-                     stage_names.append(f"Stage {i+1} ({percentiles[i]:.0f}-{percentiles[i+1]:.0f}%ile)")
-
-                 masks.append(mask)
-
-         # Create stage datasets
-         for i, (mask, name) in enumerate(zip(masks, stage_names)):
-             stage_features = {}
-             for key, values in features.items():
-                 stage_features[key] = values[mask]
-
-             print(f"  Stage {i+1} ({name}): {np.sum(mask)} samples")
-             stages.append(stage_features)
-
-         return stages
-
-     def _create_tf_dataset(self, features: Dict[str, np.ndarray],
-                            batch_size: int = 256,
-                            shuffle: bool = True) -> tf.data.Dataset:
-         """Create TensorFlow dataset from features."""
-
-         dataset = tf.data.Dataset.from_tensor_slices(features)
-
-         if shuffle:
-             dataset = dataset.shuffle(buffer_size=10000)
-
-         dataset = dataset.batch(batch_size, drop_remainder=False)
-         dataset = dataset.prefetch(tf.data.AUTOTUNE)
-
-         return dataset
-
-     def train_with_curriculum(self,
-                               training_features: Dict[str, np.ndarray],
-                               validation_features: Dict[str, np.ndarray],
-                               epochs_per_stage: int = 10,
-                               batch_size: int = 256) -> Dict:
-         """Train model using curriculum learning."""
-
-         print(f"πŸŽ“ CURRICULUM LEARNING TRAINING")
-         print(f"Stages: {self.curriculum_stages} | Epochs per stage: {epochs_per_stage}")
-         print("="*70)
-
-         # Create curriculum stages
-         print("\nπŸ“š Creating curriculum stages...")
-         training_stages = self._create_curriculum_stages(training_features)
-
-         # Training history
-         history = {
-             'stage_losses': [],
-             'stage_val_losses': [],
-             'stage_times': [],
-             'total_loss': [],
-             'rating_loss': [],
-             'retrieval_loss': [],
-             'contrastive_loss': [],
-             'val_total_loss': [],
-             'val_rating_loss': [],
-             'val_retrieval_loss': []
-         }
-
-         # Validation dataset (constant across stages)
-         val_dataset = self._create_tf_dataset(validation_features, batch_size, shuffle=False)
-
-         total_start_time = time.time()
-
-         # Train through curriculum stages
-         for stage_idx, stage_features in enumerate(training_stages):
-             stage_start_time = time.time()
-
-             print(f"\n🎯 STAGE {stage_idx + 1}/{self.curriculum_stages}")
-             print(f"Training samples: {len(stage_features['rating'])}")
-
-             # Create training dataset for this stage
-             train_dataset = self._create_tf_dataset(stage_features, batch_size, shuffle=True)
-
-             # Adaptive learning rate (decrease as stages progress)
-             stage_lr = self.learning_rate * (0.8 ** stage_idx)
-             self.model.optimizer.learning_rate.assign(stage_lr)
-             print(f"Learning rate: {stage_lr:.6f}")
-
-             # Train on this stage
-             stage_history = {'loss': [], 'val_loss': []}
-
-             for epoch in range(epochs_per_stage):
-                 epoch_start = time.time()
-
-                 # Training step
-                 train_losses = []
-                 for batch in train_dataset:
-                     with tf.GradientTape() as tape:
-                         loss_dict = self.model.compute_loss(batch, training=True)
-                         total_loss = loss_dict['total_loss']
-
-                     gradients = tape.gradient(total_loss, self.model.trainable_variables)
-                     self.model.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
-
-                     train_losses.append({k: v.numpy() for k, v in loss_dict.items()})
-
-                 # Average training losses
-                 avg_train_loss = {}
-                 for key in train_losses[0].keys():
-                     avg_train_loss[key] = np.mean([loss[key] for loss in train_losses])
-
-                 # Validation step
-                 val_losses = []
-                 for batch in val_dataset:
-                     loss_dict = self.model.compute_loss(batch, training=False)
-                     val_losses.append({k: v.numpy() for k, v in loss_dict.items()})
-
-                 # Average validation losses
-                 avg_val_loss = {}
-                 for key in val_losses[0].keys():
-                     avg_val_loss[key] = np.mean([loss[key] for loss in val_losses])
-
-                 # Record epoch results
-                 stage_history['loss'].append(avg_train_loss['total_loss'])
-                 stage_history['val_loss'].append(avg_val_loss['total_loss'])
-
-                 # Add to overall history
-                 for key in ['total_loss', 'rating_loss', 'retrieval_loss', 'contrastive_loss']:
-                     history[key].append(avg_train_loss[key])
-                     history[f'val_{key}'].append(avg_val_loss[key])
-
-                 epoch_time = time.time() - epoch_start
-                 print(f"  Epoch {epoch+1:2d}/{epochs_per_stage} | "
-                       f"Loss: {avg_train_loss['total_loss']:.4f} | "
-                       f"Val: {avg_val_loss['total_loss']:.4f} | "
-                       f"Time: {epoch_time:.1f}s")
-
-             stage_time = time.time() - stage_start_time
-
-             # Record stage results
-             history['stage_losses'].append(stage_history['loss'])
-             history['stage_val_losses'].append(stage_history['val_loss'])
-             history['stage_times'].append(stage_time)
-
-             print(f"βœ… Stage {stage_idx + 1} completed in {stage_time:.1f}s")
-
-             # Save intermediate model after each stage
-             self.save_model(f"src/artifacts/", suffix=f"_stage_{stage_idx + 1}")
-
-         total_time = time.time() - total_start_time
-
-         print(f"\nπŸŽ“ CURRICULUM TRAINING COMPLETED!")
-         print(f"Total time: {total_time:.1f}s")
-         print(f"Average time per stage: {np.mean(history['stage_times']):.1f}s")
-
-         return history
-
-     def save_model(self, save_path: str = "src/artifacts/", suffix: str = ""):
-         """Save the trained model."""
-         os.makedirs(save_path, exist_ok=True)
-
-         # Save model weights
-         self.model.user_tower.save_weights(f"{save_path}/improved_user_tower_weights{suffix}")
-         self.model.item_tower.save_weights(f"{save_path}/improved_item_tower_weights{suffix}")
-         self.model.rating_model.save_weights(f"{save_path}/improved_rating_model_weights{suffix}")
-
-         # Save temperature parameter
-         temp_value = self.model.temperature_similarity.temperature.numpy()
-         with open(f"{save_path}/temperature_value{suffix}.txt", 'w') as f:
-             f.write(str(temp_value))
-
-         # Save configuration
-         config = {
-             'embedding_dim': self.embedding_dim,
-             'learning_rate': self.learning_rate,
-             'use_focal_loss': self.use_focal_loss,
-             'curriculum_stages': self.curriculum_stages
-         }
-
-         with open(f"{save_path}/curriculum_model_config{suffix}.txt", 'w') as f:
-             for key, value in config.items():
-                 f.write(f"{key}: {value}\n")
-
-         if not suffix:
-             print(f"Model saved to {save_path}")
-
-
- def main():
-     """Main function for curriculum training."""
-
-     print("πŸš€ INITIALIZING CURRICULUM TRAINER")
-
-     # Initialize trainer
-     trainer = CurriculumTrainer(
-         embedding_dim=128,
-         learning_rate=0.001,
-         use_focal_loss=True,
-         curriculum_stages=3
-     )
-
-     # Load data processor
-     print("Loading data processor...")
-     trainer.load_data_processor()
-
-     # Create improved model
-     print("Creating improved two-tower model...")
-     trainer.create_model()
-
-     # Load training data
-     print("Loading training data...")
-     with open("src/artifacts/training_features.pkl", 'rb') as f:
-         training_features = pickle.load(f)
-
-     with open("src/artifacts/validation_features.pkl", 'rb') as f:
-         validation_features = pickle.load(f)
-
-     print(f"Training samples: {len(training_features['rating'])}")
-     print(f"Validation samples: {len(validation_features['rating'])}")
-
-     # Train with curriculum learning
-     start_time = time.time()
-
-     history = trainer.train_with_curriculum(
-         training_features=training_features,
-         validation_features=validation_features,
-         epochs_per_stage=15,
-         batch_size=512
-     )
-
-     total_time = time.time() - start_time
-
-     # Save final model and history
-     print("Saving final model...")
-     trainer.save_model()
-
-     with open("src/artifacts/curriculum_training_history.pkl", 'wb') as f:
-         pickle.dump(history, f)
-
-     print(f"\nβœ… CURRICULUM TRAINING COMPLETED!")
-     print(f"Total training time: {total_time:.1f}s")
-     print(f"All improvements implemented successfully!")
-
-
- if __name__ == "__main__":
-     main()
 
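The deleted `_create_curriculum_stages` split users into three stages by the 33rd/67th percentiles of interaction-history length, so every sample lands in exactly one stage. A minimal NumPy sketch of that partitioning on a toy set of hypothetical history lengths:

```python
import numpy as np

# Hypothetical per-user history lengths for a toy batch
history_lengths = np.array([0, 1, 2, 3, 5, 8, 13, 21, 34])

# Three stages split at the 33rd/67th percentiles, as in _create_curriculum_stages
p33 = np.percentile(history_lengths, 33)
p67 = np.percentile(history_lengths, 67)
stage1 = history_lengths <= p33                              # simple: short or no history
stage2 = (history_lengths > p33) & (history_lengths <= p67)  # medium complexity
stage3 = history_lengths > p67                               # complex: long history

# The three boolean masks partition the samples: disjoint and exhaustive
assert (stage1.astype(int) + stage2.astype(int) + stage3.astype(int)).sum() == len(history_lengths)
print(stage1.sum(), stage2.sum(), stage3.sum())  # 3 3 3
```

The trainer then iterated the stages in order, shrinking the learning rate per stage (`stage_lr = learning_rate * 0.8 ** stage_idx`), so the model saw easy short-history users before long-history ones.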
src/training/fast_joint_training.py DELETED
@@ -1,268 +0,0 @@
- """
- Fast joint training with key optimizations for CPU performance.
- """
- import tensorflow as tf
- import numpy as np
- import pickle
- import os
- import time
- from typing import Dict
-
- from src.models.item_tower import ItemTower
- from src.models.user_tower import UserTower, TwoTowerModel
- from src.preprocessing.data_loader import DataProcessor
-
-
- class FastJointTrainer:
-     """Simplified fast joint training optimized for CPU."""
-
-     def __init__(self):
-         self.item_tower = None
-         self.user_tower = None
-         self.model = None
-
-         # Optimized hyperparameters for fast training
-         self.user_lr = 0.003
-         self.item_lr = 0.0003
-         self.batch_size = 2048  # Large batch for efficiency
-         self.epochs = 20  # Reduced epochs
-
-     def load_components(self):
-         """Load all required components."""
-         print("Loading components...")
-
-         # Load data processor
-         data_processor = DataProcessor()
-         data_processor.load_vocabularies("src/artifacts/vocabularies.pkl")
-
-         # Load item tower config
-         with open("src/artifacts/item_tower_config.txt", 'r') as f:
-             config = {}
-             for line in f:
-                 key, value = line.strip().split(': ')
-                 if key in ['embedding_dim', 'dropout_rate']:
-                     config[key] = float(value) if '.' in value else int(value)
-                 elif key == 'hidden_dims':
-                     config[key] = eval(value)
-
-         # Build item tower
-         self.item_tower = ItemTower(
-             item_vocab_size=len(data_processor.item_vocab),
-             category_vocab_size=len(data_processor.category_vocab),
-             brand_vocab_size=len(data_processor.brand_vocab),
-             **config
-         )
-
-         # Initialize and load weights
-         dummy_input = {
-             'product_id': tf.constant([0]),
-             'category_id': tf.constant([0]),
-             'brand_id': tf.constant([0]),
-             'price': tf.constant([0.0])
-         }
-         _ = self.item_tower(dummy_input)
-         self.item_tower.load_weights("src/artifacts/item_tower_weights")
-
-         # Build user tower (simplified)
-         self.user_tower = UserTower(
-             max_history_length=50,
-             embedding_dim=128,  # Updated to 128D
-             hidden_dims=[64],  # Simplified architecture
-             dropout_rate=0.1
-         )
-
-         # Build complete model
-         self.model = TwoTowerModel(
-             item_tower=self.item_tower,
-             user_tower=self.user_tower,
-             rating_weight=1.0,
-             retrieval_weight=0.2  # Reduced for faster training
-         )
-
-         print("Components loaded successfully")
-
-     def create_fast_dataset(self, features: Dict, is_training: bool = True):
-         """Create optimized dataset pipeline."""
-         dataset = tf.data.Dataset.from_tensor_slices(features)
-
-         if is_training:
-             dataset = dataset.shuffle(buffer_size=5000)
-             dataset = dataset.repeat()
-
-         dataset = dataset.batch(self.batch_size, drop_remainder=True)
-         dataset = dataset.prefetch(2)  # Conservative prefetch for CPU
-
-         return dataset
-
-     def train_fast(self, training_features: Dict, validation_features: Dict):
-         """Fast training loop with key optimizations."""
-
-         print(f"Starting fast training: {self.epochs} epochs, batch size {self.batch_size}")
-
-         # Setup datasets
-         steps_per_epoch = len(training_features['rating']) // self.batch_size
-         val_steps = len(validation_features['rating']) // self.batch_size
-
-         train_ds = self.create_fast_dataset(training_features, is_training=True)
-         val_ds = self.create_fast_dataset(validation_features, is_training=False)
-
-         # Note: Age and income are now categorical - no normalization needed
-
-         # Setup optimizers
-         user_optimizer = tf.keras.optimizers.Adam(learning_rate=self.user_lr)
-         item_optimizer = tf.keras.optimizers.Adam(learning_rate=self.item_lr)
-
-         # Training loop
-         train_iter = iter(train_ds)
-         val_iter = iter(val_ds)
-
-         best_val_loss = float('inf')
-
-         for epoch in range(self.epochs):
-             epoch_start = time.time()
-
-             # Progressive unfreezing - simple strategy
-             train_item = epoch >= (self.epochs // 4)  # Unfreeze after 25%
-
-             print(f"Epoch {epoch+1}/{self.epochs} - Item training: {'ON' if train_item else 'OFF'}")
-
-             # Training
-             train_losses = []
-             for step in range(steps_per_epoch):
-                 try:
-                     batch = next(train_iter)
-                 except StopIteration:
-                     train_iter = iter(train_ds)
-                     batch = next(train_iter)
-
-                 with tf.GradientTape() as tape:
-                     # Forward pass
-                     user_emb = self.user_tower(batch, training=True)
-                     item_emb = self.item_tower(batch, training=True)
-
-                     # Rating prediction
-                     concat_emb = tf.concat([user_emb, item_emb], axis=-1)
-                     rating_pred = self.model.rating_model(concat_emb, training=True)
-
-                     # Simple loss calculation
-                     rating_loss = tf.keras.losses.binary_crossentropy(
-                         batch["rating"], tf.squeeze(rating_pred)
-                     )
-                     rating_loss = tf.reduce_mean(rating_loss)
-
-                     # Simplified retrieval loss
-                     similarity = tf.reduce_sum(user_emb * item_emb, axis=1)
-                     retrieval_loss = tf.keras.losses.binary_crossentropy(
-                         batch["rating"], tf.nn.sigmoid(similarity)
-                     )
-                     retrieval_loss = tf.reduce_mean(retrieval_loss)
-
-                     total_loss = rating_loss + 0.2 * retrieval_loss
-
-                 # Gradient computation and application
-                 if train_item:
-                     # Train both towers
-                     user_vars = self.user_tower.trainable_variables + self.model.rating_model.trainable_variables
-                     item_vars = self.item_tower.trainable_variables
-                     all_vars = user_vars + item_vars
-
-                     grads = tape.gradient(total_loss, all_vars)
-                     user_grads = grads[:len(user_vars)]
-                     item_grads = grads[len(user_vars):]
-
-                     user_optimizer.apply_gradients(zip(user_grads, user_vars))
-                     item_optimizer.apply_gradients(zip(item_grads, item_vars))
-                 else:
-                     # Train only user tower
-                     user_vars = self.user_tower.trainable_variables + self.model.rating_model.trainable_variables
-                     grads = tape.gradient(total_loss, user_vars)
-                     user_optimizer.apply_gradients(zip(grads, user_vars))
-
-                 train_losses.append(total_loss.numpy())
-
-             # Validation
-             val_losses = []
-             for step in range(val_steps):
-                 try:
-                     batch = next(val_iter)
-                 except StopIteration:
-                     val_iter = iter(val_ds)
-                     batch = next(val_iter)
-
-                 user_emb = self.user_tower(batch, training=False)
-                 item_emb = self.item_tower(batch, training=False)
-
-                 concat_emb = tf.concat([user_emb, item_emb], axis=-1)
-                 rating_pred = self.model.rating_model(concat_emb, training=False)
-
-                 rating_loss = tf.reduce_mean(
-                     tf.keras.losses.binary_crossentropy(batch["rating"], tf.squeeze(rating_pred))
-                 )
-
-                 similarity = tf.reduce_sum(user_emb * item_emb, axis=1)
-                 retrieval_loss = tf.reduce_mean(
-                     tf.keras.losses.binary_crossentropy(batch["rating"], tf.nn.sigmoid(similarity))
-                 )
-
-                 total_loss = rating_loss + 0.2 * retrieval_loss
-                 val_losses.append(total_loss.numpy())
-
-             # Calculate averages
-             avg_train_loss = np.mean(train_losses)
-             avg_val_loss = np.mean(val_losses)
-             epoch_time = time.time() - epoch_start
-
-             print(f"Time: {epoch_time:.1f}s | Train: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f}")
-
-             # Save best model
-             if avg_val_loss < best_val_loss:
-                 best_val_loss = avg_val_loss
-                 self.save_model("_best")
-
-         print("Fast training completed!")
-
-     def save_model(self, suffix=""):
-         """Save trained model."""
-         save_path = "src/artifacts/"
-
-         self.user_tower.save_weights(f"{save_path}/user_tower_weights{suffix}")
-         self.item_tower.save_weights(f"{save_path}/item_tower_weights_finetuned{suffix}")
-         self.model.rating_model.save_weights(f"{save_path}/rating_model_weights{suffix}")
-
-         if not suffix:
-             print("Model saved successfully")
-
-
- def main():
-     """Main function for fast joint training."""
-
-     print("=== Fast Joint Training ===")
-
-     # Initialize trainer
-     trainer = FastJointTrainer()
-     trainer.load_components()
-
-     # Load training data
-     print("Loading training data...")
-     with open("src/artifacts/training_features.pkl", 'rb') as f:
-         training_features = pickle.load(f)
-
-     with open("src/artifacts/validation_features.pkl", 'rb') as f:
-         validation_features = pickle.load(f)
-
-     print(f"Training samples: {len(training_features['rating']):,}")
-     print(f"Validation samples: {len(validation_features['rating']):,}")
-
-     # Start training
-     start_time = time.time()
-     trainer.train_fast(training_features, validation_features)
-
-     total_time = time.time() - start_time
-     trainer.save_model()
-
-     print(f"\\nTraining completed in {total_time:.1f} seconds!")
-     print(f"Average time per epoch: {total_time/trainer.epochs:.1f}s")
-
-
- if __name__ == "__main__":
268
- main()
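The deleted trainer above combines a pointwise rating loss with a dot-product retrieval loss at a fixed 1 : 0.2 weighting. A framework-free NumPy sketch of that objective (names like `joint_loss` and `rating_pred` are illustrative, not from the repo):

```python
import numpy as np

def bce(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy over a batch."""
    y_prob = np.clip(y_prob, eps, 1 - eps)
    return float(np.mean(-(y_true * np.log(y_prob) + (1 - y_true) * np.log(1 - y_prob))))

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def joint_loss(rating_true, rating_pred, user_emb, item_emb, retrieval_weight=0.2):
    """Rating BCE plus sigmoid(dot-product) retrieval BCE, as in the deleted loop."""
    rating_loss = bce(rating_true, rating_pred)
    similarity = np.sum(user_emb * item_emb, axis=1)  # per-example dot products
    retrieval_loss = bce(rating_true, sigmoid(similarity))
    return rating_loss + retrieval_weight * retrieval_loss
```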
src/training/improved_joint_training.py DELETED
@@ -1,462 +0,0 @@
- #!/usr/bin/env python3
- """
- Improved joint training with hard negative mining, curriculum learning, and better optimization.
- """
-
- import tensorflow as tf
- import numpy as np
- import pickle
- import os
- from typing import Dict, List, Tuple, Optional
- import time
- from collections import defaultdict
-
- from src.models.improved_two_tower import create_improved_model
- from src.preprocessing.data_loader import DataProcessor, create_tf_dataset
-
-
- class HardNegativeSampler:
-     """Hard negative sampling strategy for better training."""
-
-     def __init__(self, model, item_embeddings, sampling_strategy='mixed'):
-         self.model = model
-         self.item_embeddings = item_embeddings  # Pre-computed item embeddings
-         self.sampling_strategy = sampling_strategy
-
-     def sample_hard_negatives(self, user_embeddings, positive_items, k_hard=2, k_random=2):
-         """Sample hard negatives based on user-item similarity."""
-         batch_size = tf.shape(user_embeddings)[0]
-
-         # Compute similarities between users and all items
-         similarities = tf.linalg.matmul(user_embeddings, self.item_embeddings, transpose_b=True)
-
-         # Mask out positive items
-         positive_mask = tf.one_hot(positive_items, depth=tf.shape(self.item_embeddings)[0])
-         similarities = similarities - positive_mask * 1e9  # Large negative value
-
-         # Get top-k similar items (hard negatives)
-         _, hard_negative_indices = tf.nn.top_k(similarities, k=k_hard)
-
-         # Sample random negatives
-         total_items = tf.shape(self.item_embeddings)[0]
-         random_negatives = tf.random.uniform(
-             shape=[batch_size, k_random],
-             minval=0,
-             maxval=total_items,
-             dtype=tf.int32
-         )
-
-         # Combine hard and random negatives
-         if self.sampling_strategy == 'hard':
-             return hard_negative_indices
-         elif self.sampling_strategy == 'random':
-             return random_negatives
-         else:  # mixed
-             return tf.concat([hard_negative_indices, random_negatives], axis=1)
-
-
- class CurriculumLearningScheduler:
-     """Curriculum learning scheduler for progressive difficulty."""
-
-     def __init__(self, total_epochs, warmup_epochs=10):
-         self.total_epochs = total_epochs
-         self.warmup_epochs = warmup_epochs
-
-     def get_difficulty_schedule(self, epoch):
-         """Get curriculum parameters for current epoch."""
-         if epoch < self.warmup_epochs:
-             # Easy phase: more random negatives, lower temperature
-             hard_negative_ratio = 0.2
-             temperature = 2.0
-             negative_samples = 2
-         elif epoch < self.total_epochs * 0.6:
-             # Medium phase: balanced negatives
-             hard_negative_ratio = 0.5
-             temperature = 1.0
-             negative_samples = 4
-         else:
-             # Hard phase: more hard negatives, higher temperature
-             hard_negative_ratio = 0.8
-             temperature = 0.5
-             negative_samples = 6
-
-         return {
-             'hard_negative_ratio': hard_negative_ratio,
-             'temperature': temperature,
-             'negative_samples': negative_samples
-         }
-
-
- class ImprovedJointTrainer:
-     """Enhanced joint trainer with advanced techniques."""
-
-     def __init__(self,
-                  embedding_dim: int = 128,
-                  learning_rate: float = 0.001,
-                  use_mixed_precision: bool = True,
-                  use_curriculum_learning: bool = True,
-                  use_hard_negatives: bool = True):
-
-         self.embedding_dim = embedding_dim
-         self.learning_rate = learning_rate
-         self.use_mixed_precision = use_mixed_precision
-         self.use_curriculum_learning = use_curriculum_learning
-         self.use_hard_negatives = use_hard_negatives
-
-         # Enable mixed precision if requested
-         if use_mixed_precision:
-             policy = tf.keras.mixed_precision.Policy('mixed_float16')
-             tf.keras.mixed_precision.set_global_policy(policy)
-
-         self.model = None
-         self.data_processor = None
-         self.curriculum_scheduler = None
-         self.hard_negative_sampler = None
-
-     def setup_model(self, data_processor: DataProcessor):
-         """Setup the improved model."""
-         self.data_processor = data_processor
-
-         # Create improved model
-         self.model = create_improved_model(
-             data_processor=data_processor,
-             embedding_dim=self.embedding_dim,
-             use_bias=True,
-             use_focal_loss=True
-         )
-
-         print(f"Created improved two-tower model with {self.embedding_dim}D embeddings")
-
-     def setup_curriculum_learning(self, total_epochs: int):
-         """Setup curriculum learning scheduler."""
-         if self.use_curriculum_learning:
-             self.curriculum_scheduler = CurriculumLearningScheduler(
-                 total_epochs=total_epochs,
-                 warmup_epochs=max(5, total_epochs // 10)
-             )
-             print("Curriculum learning enabled")
-
-     def setup_hard_negative_sampling(self, item_features: Dict[str, np.ndarray]):
-         """Setup hard negative sampling."""
-         if self.use_hard_negatives:
-             # Pre-compute item embeddings for efficient hard negative sampling
-             print("Pre-computing item embeddings for hard negative sampling...")
-
-             # Create a dummy batch to get item embeddings
-             batch_size = 1000
-             total_items = len(item_features['product_id'])
-
-             item_embeddings_list = []
-             for i in range(0, total_items, batch_size):
-                 end_idx = min(i + batch_size, total_items)
-                 batch_features = {
-                     key: tf.constant(value[i:end_idx])
-                     for key, value in item_features.items()
-                 }
-
-                 item_emb_output = self.model.item_tower(batch_features, training=False)
-                 if isinstance(item_emb_output, tuple):
-                     item_emb = item_emb_output[0]  # Get embeddings, ignore bias
-                 else:
-                     item_emb = item_emb_output
-
-                 item_embeddings_list.append(item_emb.numpy())
-
-             item_embeddings = np.vstack(item_embeddings_list)
-
-             self.hard_negative_sampler = HardNegativeSampler(
-                 model=self.model,
-                 item_embeddings=tf.constant(item_embeddings, dtype=tf.float32),
-                 sampling_strategy='mixed'
-             )
-             print(f"Hard negative sampling enabled with {len(item_embeddings)} items")
-
-     def create_advanced_training_dataset(self,
-                                          features: Dict[str, np.ndarray],
-                                          batch_size: int = 256,
-                                          epoch: int = 0) -> tf.data.Dataset:
-         """Create training dataset with curriculum learning and hard negatives."""
-
-         # Get curriculum parameters
-         if self.curriculum_scheduler:
-             curriculum_params = self.curriculum_scheduler.get_difficulty_schedule(epoch)
-             print(f"Epoch {epoch}: {curriculum_params}")
-         else:
-             curriculum_params = {
-                 'hard_negative_ratio': 0.5,
-                 'temperature': 1.0,
-                 'negative_samples': 4
-             }
-
-         # Filter data based on curriculum (start with easier examples)
-         if epoch < 5:  # Warmup epochs - use only high-confidence positive examples
-             positive_mask = features['rating'] == 1.0
-             if np.sum(positive_mask) > 0:
-                 # Sample subset of positives and all negatives
-                 positive_indices = np.where(positive_mask)[0]
-                 negative_indices = np.where(features['rating'] == 0.0)[0]
-
-                 # Sample subset for easier learning
-                 n_positive_samples = min(len(positive_indices), len(negative_indices))
-                 selected_positive = np.random.choice(
-                     positive_indices, size=n_positive_samples, replace=False
-                 )
-                 selected_negative = np.random.choice(
-                     negative_indices, size=n_positive_samples, replace=False
-                 )
-
-                 selected_indices = np.concatenate([selected_positive, selected_negative])
-                 np.random.shuffle(selected_indices)
-
-                 # Filter features
-                 filtered_features = {
-                     key: value[selected_indices] for key, value in features.items()
-                 }
-             else:
-                 filtered_features = features
-         else:
-             filtered_features = features
-
-         # Create dataset
-         dataset = create_tf_dataset(filtered_features, batch_size, shuffle=True)
-
-         return dataset
-
-     def compile_model(self):
-         """Compile model with advanced optimizer."""
-         # Use AdamW with learning rate scheduling
-         initial_learning_rate = self.learning_rate
-         lr_schedule = tf.keras.optimizers.schedules.CosineDecayRestarts(
-             initial_learning_rate=initial_learning_rate,
-             first_decay_steps=1000,
-             t_mul=2.0,
-             m_mul=0.9,
-             alpha=0.01
-         )
-
-         optimizer = tf.keras.optimizers.AdamW(
-             learning_rate=lr_schedule,
-             weight_decay=1e-5,
-             beta_1=0.9,
-             beta_2=0.999,
-             epsilon=1e-7
-         )
-
-         # Enable mixed precision optimizer if needed
-         if self.use_mixed_precision:
-             optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
-
-         self.optimizer = optimizer
-         print(f"Model compiled with AdamW optimizer (lr={self.learning_rate})")
-
-     @tf.function
-     def train_step(self, features):
-         """Optimized training step with gradient scaling."""
-         with tf.GradientTape() as tape:
-             # Forward pass
-             loss_dict = self.model.compute_loss(features, training=True)
-             total_loss = loss_dict['total_loss']
-
-             # Scale loss for mixed precision
-             if self.use_mixed_precision:
-                 scaled_loss = self.optimizer.get_scaled_loss(total_loss)
-             else:
-                 scaled_loss = total_loss
-
-         # Compute gradients
-         if self.use_mixed_precision:
-             scaled_gradients = tape.gradient(scaled_loss, self.model.trainable_variables)
-             gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)
-         else:
-             gradients = tape.gradient(scaled_loss, self.model.trainable_variables)
-
-         # Clip gradients to prevent exploding gradients
-         gradients, _ = tf.clip_by_global_norm(gradients, 1.0)
-
-         # Apply gradients
-         self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
-
-         return loss_dict
-
-     def evaluate_model(self, validation_dataset):
-         """Evaluate model on validation set."""
-         total_losses = defaultdict(list)
-
-         for batch in validation_dataset:
-             loss_dict = self.model.compute_loss(batch, training=False)
-             for key, value in loss_dict.items():
-                 total_losses[key].append(float(value))
-
-         # Average losses
-         avg_losses = {key: np.mean(values) for key, values in total_losses.items()}
-         return avg_losses
-
-     def train(self,
-               training_features: Dict[str, np.ndarray],
-               validation_features: Dict[str, np.ndarray],
-               epochs: int = 50,
-               batch_size: int = 256,
-               save_path: str = "src/artifacts/") -> Dict:
-         """Enhanced training loop with all improvements."""
-
-         print(f"Starting improved training for {epochs} epochs...")
-
-         # Setup components
-         self.setup_curriculum_learning(epochs)
-         self.compile_model()
-
-         # Create validation dataset
-         validation_dataset = create_tf_dataset(validation_features, batch_size, shuffle=False)
-
-         # Training history
-         history = defaultdict(list)
-         best_val_loss = float('inf')
-         patience_counter = 0
-         early_stopping_patience = 10
-
-         # Training loop
-         for epoch in range(epochs):
-             epoch_start_time = time.time()
-
-             # Create training dataset for this epoch (curriculum learning)
-             training_dataset = self.create_advanced_training_dataset(
-                 training_features, batch_size, epoch
-             )
-
-             # Training
-             epoch_losses = defaultdict(list)
-             num_batches = 0
-
-             for batch in training_dataset:
-                 loss_dict = self.train_step(batch)
-
-                 for key, value in loss_dict.items():
-                     epoch_losses[key].append(float(value))
-                 num_batches += 1
-
-             # Average training losses
-             avg_train_losses = {
-                 key: np.mean(values) for key, values in epoch_losses.items()
-             }
-
-             # Validation
-             avg_val_losses = self.evaluate_model(validation_dataset)
-
-             # Log progress
-             epoch_time = time.time() - epoch_start_time
-             print(f"Epoch {epoch+1}/{epochs} ({epoch_time:.1f}s):")
-             print(f"  Train Loss: {avg_train_losses['total_loss']:.4f}")
-             print(f"  Val Loss: {avg_val_losses['total_loss']:.4f}")
-             print(f"  Val Rating Loss: {avg_val_losses['rating_loss']:.4f}")
-             print(f"  Val Retrieval Loss: {avg_val_losses['retrieval_loss']:.4f}")
-
-             # Save history
-             for key, value in avg_train_losses.items():
-                 history[f'train_{key}'].append(value)
-             for key, value in avg_val_losses.items():
-                 history[f'val_{key}'].append(value)
-
-             # Early stopping and model saving
-             current_val_loss = avg_val_losses['total_loss']
-             if current_val_loss < best_val_loss:
-                 best_val_loss = current_val_loss
-                 patience_counter = 0
-
-                 # Save best model
-                 self.save_model(save_path, suffix='_improved_best')
-                 print(f"  πŸ’Ύ Saved best model (val_loss: {best_val_loss:.4f})")
-             else:
-                 patience_counter += 1
-
-             if patience_counter >= early_stopping_patience:
-                 print(f"Early stopping at epoch {epoch+1}")
-                 break
-
-         # Save final model and history
-         self.save_model(save_path, suffix='_improved_final')
-         self.save_training_history(dict(history), save_path)
-
-         print("βœ… Improved training completed!")
-         return dict(history)
-
-     def save_model(self, save_path: str, suffix: str = ''):
-         """Save the trained model components."""
-         os.makedirs(save_path, exist_ok=True)
-
-         # Save model weights
-         self.model.item_tower.save_weights(f"{save_path}/improved_item_tower_weights{suffix}")
-         self.model.user_tower.save_weights(f"{save_path}/improved_user_tower_weights{suffix}")
-
-         if hasattr(self.model, 'rating_model'):
-             self.model.rating_model.save_weights(f"{save_path}/improved_rating_model_weights{suffix}")
-
-         # Save configuration
-         config = {
-             'embedding_dim': self.embedding_dim,
-             'learning_rate': self.learning_rate,
-             'use_mixed_precision': self.use_mixed_precision,
-             'use_curriculum_learning': self.use_curriculum_learning,
-             'use_hard_negatives': self.use_hard_negatives
-         }
-
-         with open(f"{save_path}/improved_model_config{suffix}.txt", 'w') as f:
-             for key, value in config.items():
-                 f.write(f"{key}: {value}\n")
-
-         print(f"Model saved to {save_path} with suffix '{suffix}'")
-
-     def save_training_history(self, history: Dict, save_path: str):
-         """Save training history."""
-         with open(f"{save_path}/improved_training_history.pkl", 'wb') as f:
-             pickle.dump(history, f)
-         print(f"Training history saved to {save_path}")
-
-
- def main():
-     """Demo of improved training."""
-     print("πŸš€ IMPROVED TWO-TOWER TRAINING DEMO")
-     print("="*60)
-
-     # Load data
-     print("Loading training data...")
-     try:
-         with open("src/artifacts/training_features.pkl", 'rb') as f:
-             training_features = pickle.load(f)
-         with open("src/artifacts/validation_features.pkl", 'rb') as f:
-             validation_features = pickle.load(f)
-
-         print(f"Loaded {len(training_features['rating'])} training samples")
-         print(f"Loaded {len(validation_features['rating'])} validation samples")
-     except FileNotFoundError:
-         print("❌ Training data not found. Please run data preparation first.")
-         return
-
-     # Load data processor
-     data_processor = DataProcessor()
-     data_processor.load_vocabularies("src/artifacts/vocabularies.pkl")
-
-     # Create trainer
-     trainer = ImprovedJointTrainer(
-         embedding_dim=128,
-         learning_rate=0.001,
-         use_mixed_precision=True,
-         use_curriculum_learning=True,
-         use_hard_negatives=True
-     )
-
-     # Setup and train
-     trainer.setup_model(data_processor)
-
-     # Train model
-     history = trainer.train(
-         training_features=training_features,
-         validation_features=validation_features,
-         epochs=30,
-         batch_size=256
-     )
-
-     print("βœ… Improved training completed successfully!")
-
-
- if __name__ == "__main__":
-     main()
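The three-phase schedule from the deleted `CurriculumLearningScheduler` can be exercised without TensorFlow. This pure-Python mirror of `get_difficulty_schedule` (the free-function form is our own, not from the repo) reproduces its warmup/medium/hard phases:

```python
def get_difficulty_schedule(epoch, total_epochs, warmup_epochs=10):
    """Mirror of CurriculumLearningScheduler.get_difficulty_schedule."""
    if epoch < warmup_epochs:
        # Easy phase: mostly random negatives, soft temperature
        return {'hard_negative_ratio': 0.2, 'temperature': 2.0, 'negative_samples': 2}
    elif epoch < total_epochs * 0.6:
        # Medium phase: balanced negatives
        return {'hard_negative_ratio': 0.5, 'temperature': 1.0, 'negative_samples': 4}
    else:
        # Hard phase: mostly hard negatives, sharp temperature
        return {'hard_negative_ratio': 0.8, 'temperature': 0.5, 'negative_samples': 6}
```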
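The mask-then-top-k trick in the deleted `HardNegativeSampler` is easy to verify in NumPy: subtract a large constant from each user's positive item score so it can never rank among the hardest negatives. A small sketch (single positive per user for simplicity; the function name mirrors the deleted method but this standalone form is ours):

```python
import numpy as np

def sample_hard_negatives(user_emb, item_emb, positive_items, k_hard=2):
    """Mask each user's positive item, then return the k most similar items."""
    similarities = user_emb @ item_emb.T                 # (batch, n_items)
    rows = np.arange(len(positive_items))
    similarities[rows, positive_items] -= 1e9            # mask positives out
    # argsort descending, keep the top-k columns per row
    return np.argsort(-similarities, axis=1)[:, :k_hard]
```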
src/training/optimized_joint_training.py DELETED
@@ -1,439 +0,0 @@
- import tensorflow as tf
- import numpy as np
- import pickle
- import os
- import time
- from typing import Dict, List, Tuple
-
- from src.models.item_tower import ItemTower
- from src.models.user_tower import UserTower, TwoTowerModel
- from src.preprocessing.data_loader import DataProcessor, create_tf_dataset
-
-
- class OptimizedJointTrainer:
-     """Optimized joint training with performance enhancements."""
-
-     def __init__(self,
-                  embedding_dim: int = 128,  # Updated to 128D output
-                  user_learning_rate: float = 0.001,
-                  item_learning_rate: float = 0.0001,
-                  rating_weight: float = 1.0,
-                  retrieval_weight: float = 1.0,
-                  gradient_accumulation_steps: int = 1,
-                  use_mixed_precision: bool = False):  # Disabled for CPU training
-
-         self.embedding_dim = embedding_dim
-         self.user_learning_rate = user_learning_rate
-         self.item_learning_rate = item_learning_rate
-         self.rating_weight = rating_weight
-         self.retrieval_weight = retrieval_weight
-         self.gradient_accumulation_steps = gradient_accumulation_steps
-         self.use_mixed_precision = use_mixed_precision
-
-         # Enable mixed precision for faster training
-         if self.use_mixed_precision:
-             tf.keras.mixed_precision.set_global_policy('mixed_float16')
-             print("Mixed precision training enabled")
-
-         self.item_tower = None
-         self.user_tower = None
-         self.model = None
-
-         # Precompile TensorFlow functions for speed
-         self._compiled_train_step = None
-         self._compiled_val_step = None
-
-     def load_pre_trained_item_tower(self, artifacts_path: str = "src/artifacts/") -> ItemTower:
-         """Load pre-trained item tower with optimizations."""
-         data_processor = DataProcessor()
-         data_processor.load_vocabularies(f"{artifacts_path}/vocabularies.pkl")
-
-         with open(f"{artifacts_path}/item_tower_config.txt", 'r') as f:
-             config = {}
-             for line in f:
-                 key, value = line.strip().split(': ')
-                 if key in ['embedding_dim', 'dropout_rate']:
-                     config[key] = float(value) if '.' in value else int(value)
-                 elif key == 'hidden_dims':
-                     config[key] = eval(value)
-
-         self.item_tower = ItemTower(
-             item_vocab_size=len(data_processor.item_vocab),
-             category_vocab_size=len(data_processor.category_vocab),
-             brand_vocab_size=len(data_processor.brand_vocab),
-             **config
-         )
-
-         dummy_input = {
-             'product_id': tf.constant([0]),
-             'category_id': tf.constant([0]),
-             'brand_id': tf.constant([0]),
-             'price': tf.constant([0.0])
-         }
-         _ = self.item_tower(dummy_input)
-         self.item_tower.load_weights(f"{artifacts_path}/item_tower_weights")
-
-         print("Pre-trained item tower loaded successfully")
-         return self.item_tower
-
-     def build_user_tower(self, max_history_length: int = 50) -> UserTower:
-         """Build user tower with optimizations."""
-         self.user_tower = UserTower(
-             max_history_length=max_history_length,
-             embedding_dim=self.embedding_dim,
-             hidden_dims=[128, 64],
-             dropout_rate=0.1  # Reduced dropout for faster training
-         )
-
-         print("User tower initialized")
-         return self.user_tower
-
-     def build_two_tower_model(self) -> TwoTowerModel:
-         """Build complete two-tower model."""
-         if self.item_tower is None or self.user_tower is None:
-             raise ValueError("Both towers must be initialized first")
-
-         self.model = TwoTowerModel(
-             item_tower=self.item_tower,
-             user_tower=self.user_tower,
-             rating_weight=self.rating_weight,
-             retrieval_weight=self.retrieval_weight
-         )
-
-         print("Two-tower model built successfully")
-         return self.model
-
-     def create_optimized_dataset(self, features: Dict[str, np.ndarray],
-                                  batch_size: int,
-                                  is_training: bool = True) -> tf.data.Dataset:
-         """Create optimized dataset pipeline for faster training."""
-
-         dataset = tf.data.Dataset.from_tensor_slices(features)
-
-         if is_training:
-             # Optimized shuffling and prefetching
-             dataset = dataset.shuffle(buffer_size=min(10000, len(features['rating'])))
-             dataset = dataset.repeat()  # Repeat for multiple epochs
-
-         dataset = dataset.batch(batch_size, drop_remainder=True)
-
-         # Optimize for CPU training
-         dataset = dataset.prefetch(tf.data.AUTOTUNE)
-
-         return dataset
-
-     @tf.function(experimental_relax_shapes=True)
-     def optimized_train_step(self, batch: Dict[str, tf.Tensor],
-                              user_optimizer: tf.keras.optimizers.Optimizer,
-                              item_optimizer: tf.keras.optimizers.Optimizer,
-                              train_item: bool) -> Dict[str, tf.Tensor]:
-         """Optimized training step with tf.function compilation."""
-
-         with tf.GradientTape() as tape:
-             # Forward pass
-             user_embeddings = self.user_tower(batch, training=True)
-             item_embeddings = self.item_tower(batch, training=True)
-
-             # Concatenate and predict rating
-             concatenated = tf.concat([user_embeddings, item_embeddings], axis=-1)
-             rating_predictions = self.model.rating_model(concatenated, training=True)
-
-             # Compute losses - fix shape mismatch
-             rating_loss = tf.keras.losses.binary_crossentropy(
-                 tf.expand_dims(batch["rating"], -1), rating_predictions
-             )
-             rating_loss = tf.reduce_mean(rating_loss)
-
-             # Retrieval loss - cosine similarity
-             user_norm = tf.nn.l2_normalize(user_embeddings, axis=1)
-             item_norm = tf.nn.l2_normalize(item_embeddings, axis=1)
-             similarities = tf.reduce_sum(user_norm * item_norm, axis=1)
-
-             retrieval_loss = tf.keras.losses.binary_crossentropy(
-                 batch["rating"], tf.nn.sigmoid(similarities)
-             )
-             retrieval_loss = tf.reduce_mean(retrieval_loss)
-
-             total_loss = (
-                 self.rating_weight * rating_loss +
-                 self.retrieval_weight * retrieval_loss
-             )
-
-             # Handle mixed precision
-             if self.use_mixed_precision:
-                 total_loss = user_optimizer.get_scaled_loss(total_loss)
-
-         # Compute gradients
-         user_vars = self.user_tower.trainable_variables + self.model.rating_model.trainable_variables
-
-         if train_item:
-             all_vars = user_vars + self.item_tower.trainable_variables
-             gradients = tape.gradient(total_loss, all_vars)
-
-             if self.use_mixed_precision:
-                 gradients = user_optimizer.get_unscaled_gradients(gradients)
-
-             # Split gradients
-             user_grads = gradients[:len(user_vars)]
-             item_grads = gradients[len(user_vars):]
-
-             # Apply gradients
-             user_optimizer.apply_gradients(zip(user_grads, user_vars))
-             item_optimizer.apply_gradients(zip(item_grads, self.item_tower.trainable_variables))
-         else:
-             gradients = tape.gradient(total_loss, user_vars)
-
-             if self.use_mixed_precision:
-                 gradients = user_optimizer.get_unscaled_gradients(gradients)
-
-             user_optimizer.apply_gradients(zip(gradients, user_vars))
-
-         # Convert back from scaled loss for logging
-         if self.use_mixed_precision:
-             total_loss = total_loss / user_optimizer.loss_scale
-             rating_loss = rating_loss / user_optimizer.loss_scale
-             retrieval_loss = retrieval_loss / user_optimizer.loss_scale
-
-         return {
-             'total_loss': total_loss,
-             'rating_loss': rating_loss,
-             'retrieval_loss': retrieval_loss
-         }
-
-     @tf.function(experimental_relax_shapes=True)
-     def optimized_val_step(self, batch: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
-         """Optimized validation step."""
-
-         user_embeddings = self.user_tower(batch, training=False)
-         item_embeddings = self.item_tower(batch, training=False)
-
-         concatenated = tf.concat([user_embeddings, item_embeddings], axis=-1)
-         rating_predictions = self.model.rating_model(concatenated, training=False)
-
-         rating_loss = tf.reduce_mean(
-             tf.keras.losses.binary_crossentropy(tf.expand_dims(batch["rating"], -1), rating_predictions)
-         )
-
-         # Retrieval loss
-         user_norm = tf.nn.l2_normalize(user_embeddings, axis=1)
-         item_norm = tf.nn.l2_normalize(item_embeddings, axis=1)
-         similarities = tf.reduce_sum(user_norm * item_norm, axis=1)
-
-         retrieval_loss = tf.reduce_mean(
-             tf.keras.losses.binary_crossentropy(batch["rating"], tf.nn.sigmoid(similarities))
-         )
-
-         total_loss = self.rating_weight * rating_loss + self.retrieval_weight * retrieval_loss
-
-         return {
-             'total_loss': total_loss,
-             'rating_loss': rating_loss,
-             'retrieval_loss': retrieval_loss
-         }
-
-     def train(self,
-               training_features: Dict[str, np.ndarray],
-               validation_features: Dict[str, np.ndarray],
-               epochs: int = 50,  # Reduced default epochs
-               batch_size: int = 512) -> Dict:  # Larger batch size for efficiency
-         """Optimized training loop."""
-
-         print(f"Starting optimized joint training for {epochs} epochs...")
-         print(f"Batch size: {batch_size}")
-         print(f"Mixed precision: {self.use_mixed_precision}")
-
-         # Create optimized datasets
-         steps_per_epoch = len(training_features['rating']) // batch_size
-         val_steps = len(validation_features['rating']) // batch_size
-
-         train_dataset = self.create_optimized_dataset(training_features, batch_size, is_training=True)
-         val_dataset = self.create_optimized_dataset(validation_features, batch_size, is_training=False)
-
-         # Note: Age and income are now categorical - no normalization needed
-
-         # Setup optimizers with mixed precision
-         if self.use_mixed_precision:
-             user_optimizer = tf.keras.optimizers.Adam(learning_rate=self.user_learning_rate)
-             user_optimizer = tf.keras.mixed_precision.LossScaleOptimizer(user_optimizer)
-             item_optimizer = tf.keras.optimizers.Adam(learning_rate=self.item_learning_rate)
-             item_optimizer = tf.keras.mixed_precision.LossScaleOptimizer(item_optimizer)
-         else:
-             user_optimizer = tf.keras.optimizers.Adam(learning_rate=self.user_learning_rate)
-             item_optimizer = tf.keras.optimizers.Adam(learning_rate=self.item_learning_rate)
-
-         # Training history
-         history = {
-             'total_loss': [], 'rating_loss': [], 'retrieval_loss': [],
-             'val_total_loss': [], 'val_rating_loss': [], 'val_retrieval_loss': [],
-             'epoch_times': []
-         }
-
-         best_val_loss = float('inf')
-         patience_counter = 0
-         patience = 10  # Reduced patience for faster training
-
-         train_iter = iter(train_dataset)
-         val_iter = iter(val_dataset)
-
-         for epoch in range(epochs):
-             epoch_start_time = time.time()
-             print(f"\nEpoch {epoch + 1}/{epochs}")
-
-             # Determine training strategy
-             freeze_threshold = int(0.2 * epochs)  # Reduced freeze period
-             train_item = epoch >= freeze_threshold
-
-             print(f"Training: User=βœ“, Item={'βœ“' if train_item else 'βœ—'}")
-
-             # Training loop
-             epoch_losses = {'total_loss': [], 'rating_loss': [], 'retrieval_loss': []}
-
-             for step in range(steps_per_epoch):
-                 try:
-                     batch = next(train_iter)
-                 except StopIteration:
-                     train_iter = iter(train_dataset)
-                     batch = next(train_iter)
-
-                 losses = self.optimized_train_step(batch, user_optimizer, item_optimizer, train_item)
-
-                 for key in epoch_losses:
-                     epoch_losses[key].append(losses[key])
-
-             # Calculate training averages
-             avg_train_losses = {k: tf.reduce_mean(v).numpy() for k, v in epoch_losses.items()}
-
-             # Validation loop
-             val_losses = {'total_loss': [], 'rating_loss': [], 'retrieval_loss': []}
-
-             for step in range(val_steps):
-                 try:
-                     batch = next(val_iter)
-                 except StopIteration:
-                     val_iter = iter(val_dataset)
-                     batch = next(val_iter)
-
-                 losses = self.optimized_val_step(batch)
-
-                 for key in val_losses:
-                     val_losses[key].append(losses[key])
-
-             avg_val_losses = {k: tf.reduce_mean(v).numpy() for k, v in val_losses.items()}
-
-             # Record history
-             epoch_time = time.time() - epoch_start_time
-             history['epoch_times'].append(epoch_time)
-
-             for key in ['total_loss', 'rating_loss', 'retrieval_loss']:
-                 history[key].append(avg_train_losses[key])
-                 history[f'val_{key}'].append(avg_val_losses[key])
-
-             # Print progress
-             print(f"Time: {epoch_time:.1f}s | "
-                   f"Train Loss: {avg_train_losses['total_loss']:.4f} | "
-                   f"Val Loss: {avg_val_losses['total_loss']:.4f}")
-
-             # Early stopping with model saving
-             if avg_val_losses['total_loss'] < best_val_loss:
-                 best_val_loss = avg_val_losses['total_loss']
-                 patience_counter = 0
-                 self.save_model("src/artifacts/", suffix="_best")
-             else:
-                 patience_counter += 1
-                 if patience_counter >= patience:
-                     print(f"Early stopping at epoch {epoch + 1}")
-                     break
-
-         avg_epoch_time = np.mean(history['epoch_times'])
-         print(f"\nTraining completed!")
-         print(f"Average epoch time: {avg_epoch_time:.1f}s")
-         print(f"Total training time: {sum(history['epoch_times']):.1f}s")
-
-         return history
-
-     def save_model(self, save_path: str = "src/artifacts/", suffix: str = ""):
-         """Save the trained model."""
-         os.makedirs(save_path, exist_ok=True)
-
-         self.user_tower.save_weights(f"{save_path}/user_tower_weights{suffix}")
-         self.item_tower.save_weights(f"{save_path}/item_tower_weights_finetuned{suffix}")
-         self.model.rating_model.save_weights(f"{save_path}/rating_model_weights{suffix}")
-
-         config = {
-             'embedding_dim': self.embedding_dim,
-             'user_learning_rate': self.user_learning_rate,
-             'item_learning_rate': self.item_learning_rate,
-             'rating_weight': self.rating_weight,
-             'retrieval_weight': self.retrieval_weight,
-             'use_mixed_precision': self.use_mixed_precision
-         }
-
-         with open(f"{save_path}/optimized_joint_model_config{suffix}.txt", 'w') as f:
-             for key, value in config.items():
-                 f.write(f"{key}: {value}\n")
-
-         if not suffix:
-             print(f"Optimized model saved to {save_path}")
-
-
- def main():
-     """Main function for optimized joint training."""
-
-     print("Initializing optimized joint trainer...")
-     trainer = OptimizedJointTrainer(
384
- embedding_dim=128, # Updated to 128D
385
- user_learning_rate=0.002, # Slightly higher for faster convergence
386
- item_learning_rate=0.0002,
387
- rating_weight=1.0,
388
- retrieval_weight=0.3, # Reduced for faster training
389
- use_mixed_precision=False # Disabled for CPU
390
- )
391
-
392
- # Load components
393
- print("Loading pre-trained item tower...")
394
- trainer.load_pre_trained_item_tower()
395
-
396
- print("Building user tower...")
397
- trainer.build_user_tower(max_history_length=50)
398
-
399
- print("Building two-tower model...")
400
- trainer.build_two_tower_model()
401
-
402
- # Load training data
403
- print("Loading training data...")
404
- with open("src/artifacts/training_features.pkl", 'rb') as f:
405
- training_features = pickle.load(f)
406
-
407
- with open("src/artifacts/validation_features.pkl", 'rb') as f:
408
- validation_features = pickle.load(f)
409
-
410
- print(f"Training samples: {len(training_features['rating'])}")
411
- print(f"Validation samples: {len(validation_features['rating'])}")
412
-
413
- # Train with optimizations
414
- print("Starting optimized training...")
415
- start_time = time.time()
416
-
417
- history = trainer.train(
418
- training_features=training_features,
419
- validation_features=validation_features,
420
- epochs=30, # Reduced epochs for faster training
421
- batch_size=1024 # Larger batch size for better GPU utilization
422
- )
423
-
424
- total_time = time.time() - start_time
425
-
426
- # Save final model and history
427
- print("Saving final model...")
428
- trainer.save_model()
429
-
430
- with open("src/artifacts/optimized_training_history.pkl", 'wb') as f:
431
- pickle.dump(history, f)
432
-
433
- print(f"\\nOptimized joint training completed!")
434
- print(f"Total training time: {total_time:.1f}s")
435
- print(f"Average time per epoch: {total_time/len(history['epoch_times']):.1f}s")
436
-
437
-
438
- if __name__ == "__main__":
439
- main()
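The deleted training loop above checkpoints on every validation improvement (saving "_best" weights) and stops after 10 epochs without improvement. A minimal, framework-free sketch of that early-stopping pattern, with hypothetical names:

```python
class EarlyStopper:
    """Stop after `patience` epochs without a new best validation loss."""

    def __init__(self, patience: int = 10):
        self.patience = patience
        self.best = float("inf")
        self.counter = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True to stop training."""
        if val_loss < self.best:
            self.best = val_loss
            self.counter = 0  # improvement: reset patience (the real loop also checkpoints here)
            return False
        self.counter += 1
        return self.counter >= self.patience
```

In the loop above this corresponds to `patience_counter`/`best_val_loss` plus the `save_model(..., suffix="_best")` call on each improvement.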
src/utils/real_user_selector.py CHANGED
@@ -93,22 +93,58 @@ class RealUserSelector:
         """
         print(f"Selecting {n} real users with at least {min_interactions} interactions...")
 
-        # Filter users with sufficient interactions
-        active_users = []
+        # Filter users with sufficient interactions, separating by interaction count
+        high_interaction_users = []  # >14 interactions
+        low_interaction_users = []   # min_interactions to 14 interactions
+
         for _, user in self.users_df.iterrows():
             user_id = user['user_id']
-            if (user_id in self.user_stats and
-                self.user_stats[user_id]['total_interactions'] >= min_interactions):
-                active_users.append(user)
+            if user_id in self.user_stats:
+                interaction_count = self.user_stats[user_id]['total_interactions']
+                if interaction_count >= min_interactions:
+                    if interaction_count > 14:
+                        high_interaction_users.append(user)
+                    else:
+                        low_interaction_users.append(user)
+
+        print(f"Found {len(high_interaction_users)} high-interaction users (>14) and {len(low_interaction_users)} low-interaction users ({min_interactions}-14)")
 
-        print(f"Found {len(active_users)} active users with >={min_interactions} interactions")
+        # Ensure more than half have >14 interactions
+        min_high_interaction = (n // 2) + 1  # More than 50%
 
-        # Randomly sample n users
-        if len(active_users) < n:
-            print(f"Warning: Only {len(active_users)} users available, returning all")
-            selected_users = active_users
+        selected_users = []
+
+        # First, select from high-interaction users (prioritize these)
+        if len(high_interaction_users) >= min_high_interaction:
+            selected_high = random.sample(high_interaction_users, min_high_interaction)
         else:
-            selected_users = random.sample(active_users, n)
+            print(f"Warning: Only {len(high_interaction_users)} high-interaction users available, using all")
+            selected_high = high_interaction_users
+
+        selected_users.extend(selected_high)
+        remaining_slots = n - len(selected_high)
+
+        # Fill remaining slots with low-interaction users if available
+        if remaining_slots > 0 and len(low_interaction_users) > 0:
+            if len(low_interaction_users) >= remaining_slots:
+                selected_low = random.sample(low_interaction_users, remaining_slots)
+            else:
+                selected_low = low_interaction_users
+            selected_users.extend(selected_low)
+
+        # If we still need more users, add remaining high-interaction users
+        remaining_slots = n - len(selected_users)
+        if remaining_slots > 0:
+            remaining_high = [user for user in high_interaction_users if user not in selected_users]
+            if len(remaining_high) >= remaining_slots:
+                selected_users.extend(random.sample(remaining_high, remaining_slots))
+            else:
+                selected_users.extend(remaining_high)
+
+        print(f"Selected {len(selected_users)} total users: {len([u for u in selected_users if self.user_stats[u['user_id']]['total_interactions'] > 14])} high-interaction (>14), {len([u for u in selected_users if self.user_stats[u['user_id']]['total_interactions'] <= 14])} low-interaction (≀14)")
+
+        if len(selected_users) < n:
+            print(f"Warning: Only {len(selected_users)} users available, returning all")
 
         # Build user profiles with real data
         real_user_profiles = []
@@ -124,6 +160,10 @@ class RealUserSelector:
                 'age': int(user['age']),
                 'gender': user['gender'],
                 'income': int(user['income']),
+                'profession': user.get('profession', 'Other'),
+                'location': user.get('location', 'Urban'),
+                'education_level': user.get('education_level', 'High School'),
+                'marital_status': user.get('marital_status', 'Single'),
                 'interaction_history': stats['unique_items'][:50],  # Limit to 50 most recent
                 'interaction_stats': {
                     'total_interactions': stats['total_interactions'],
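The selection logic in the diff above (stratify users by interaction count, guarantee a greater-than-half share of high-interaction users, then fill from the low pool and top up from leftover high users) can be sketched framework-free. The function name, plain-dict input, and `min_interactions` default below are illustrative; the real method operates on pandas rows:

```python
import random

def select_users(counts, n, min_interactions=5, high_cutoff=14):
    """Pick n user ids so that, when possible, more than half have
    > high_cutoff interactions. `counts` maps user_id -> interaction count."""
    high = [u for u, c in counts.items() if c > high_cutoff]
    low = [u for u, c in counts.items() if min_interactions <= c <= high_cutoff]

    # More than 50% of the selection must come from the high pool
    selected = random.sample(high, min((n // 2) + 1, len(high)))

    # Fill remaining slots from the low pool
    remaining = n - len(selected)
    if remaining > 0 and low:
        selected += random.sample(low, min(remaining, len(low)))

    # Top up from leftover high-interaction users if still short
    remaining = n - len(selected)
    if remaining > 0:
        leftover = [u for u in high if u not in selected]
        selected += random.sample(leftover, min(remaining, len(leftover)))
    return selected
```

As in the patched method, the function degrades gracefully: with too few high-interaction users it takes all of them, and with too few users overall it returns fewer than n.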
train_improved_model.py DELETED
@@ -1,111 +0,0 @@
-#!/usr/bin/env python3
-"""
-Train the improved two-tower model with all enhancements to address the identified issues.
-
-This script implements:
-βœ… 128D embeddings (vs 64D) - Better representation capacity
-βœ… Temperature scaling - Improved score discrimination
-βœ… Category-aware boosting - Enhanced personalization
-βœ… Contrastive loss - Prevents embedding collapse
-βœ… Hard negative mining - Better training signal
-βœ… User/item bias terms - Improved modeling capacity
-βœ… Curriculum learning - Progressive training strategy
-"""
-
-import argparse
-import sys
-import os
-
-# Add project root to path
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-from src.training.curriculum_trainer import CurriculumTrainer
-
-
-def main():
-    parser = argparse.ArgumentParser(description='Train improved two-tower model')
-    parser.add_argument('--embedding-dim', type=int, default=128,
-                        help='Embedding dimension (default: 128)')
-    parser.add_argument('--learning-rate', type=float, default=0.001,
-                        help='Learning rate (default: 0.001)')
-    parser.add_argument('--epochs-per-stage', type=int, default=15,
-                        help='Epochs per curriculum stage (default: 15)')
-    parser.add_argument('--batch-size', type=int, default=512,
-                        help='Batch size (default: 512)')
-    parser.add_argument('--curriculum-stages', type=int, default=3,
-                        help='Number of curriculum stages (default: 3)')
-    parser.add_argument('--use-focal-loss', action='store_true', default=True,
-                        help='Use focal loss for imbalanced data')
-
-    args = parser.parse_args()
-
-    print("πŸš€ TRAINING IMPROVED TWO-TOWER MODEL")
-    print("="*70)
-    print("IMPROVEMENTS IMPLEMENTED:")
-    print("βœ… 128D embeddings (increased from 64D)")
-    print("βœ… Temperature scaling for better score discrimination")
-    print("βœ… Category-aware boosting for personalization")
-    print("βœ… Contrastive loss to prevent embedding collapse")
-    print("βœ… Hard negative mining for better training")
-    print("βœ… User/item bias terms for improved modeling")
-    print("βœ… Curriculum learning for progressive training")
-    print("="*70)
-
-    # Initialize trainer with improved settings
-    trainer = CurriculumTrainer(
-        embedding_dim=args.embedding_dim,
-        learning_rate=args.learning_rate,
-        use_focal_loss=args.use_focal_loss,
-        curriculum_stages=args.curriculum_stages
-    )
-
-    try:
-        # Load data and train
-        trainer.load_data_processor()
-        trainer.create_model()
-
-        # Load training data
-        import pickle
-        with open("src/artifacts/training_features.pkl", 'rb') as f:
-            training_features = pickle.load(f)
-
-        with open("src/artifacts/validation_features.pkl", 'rb') as f:
-            validation_features = pickle.load(f)
-
-        # Train with curriculum learning
-        history = trainer.train_with_curriculum(
-            training_features=training_features,
-            validation_features=validation_features,
-            epochs_per_stage=args.epochs_per_stage,
-            batch_size=args.batch_size
-        )
-
-        # Save results
-        trainer.save_model()
-
-        with open("src/artifacts/improved_training_history.pkl", 'wb') as f:
-            pickle.dump(history, f)
-
-        print("\n🎯 EXPECTED IMPROVEMENTS:")
-        print("β€’ Score variance: 0.0007 β†’ 0.01+ (15x better discrimination)")
-        print("β€’ Category alignment: 12% β†’ 60%+ (5x better personalization)")
-        print("β€’ Reduced embedding collapse (more diverse user representations)")
-        print("β€’ Better negative sampling and contrastive learning")
-        print("β€’ Improved bias modeling for users and items")
-
-        print("\nβœ… TRAINING COMPLETED SUCCESSFULLY!")
-        print("The improved model should address all critical issues identified in your analysis.")
-
-    except FileNotFoundError as e:
-        print(f"❌ ERROR: {e}")
-        print("Please ensure training data exists in src/artifacts/")
-        print("Run data preprocessing first if needed.")
-
-    except Exception as e:
-        print(f"❌ TRAINING ERROR: {e}")
-        import traceback
-        traceback.print_exc()
-
-
-if __name__ == "__main__":
-    main()
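The deleted script exposes a `--use-focal-loss` flag for imbalanced data. As a reference point, the standard binary focal loss (Lin et al., 2017) for one example can be sketched as below; the actual `CurriculumTrainer` implementation is not part of this diff, so treat this as an assumption about what the flag enables:

```python
import math

def focal_loss(p, y, gamma=2.0, alpha=0.25):
    """Per-example binary focal loss.

    p is the predicted probability of the positive class, y is the 0/1 label.
    With gamma=0 and alpha=1 this reduces to cross-entropy; larger gamma
    down-weights well-classified (easy) examples.
    """
    pt = p if y == 1 else 1.0 - p          # probability of the true class
    a = alpha if y == 1 else 1.0 - alpha   # class-balancing weight
    return -a * (1.0 - pt) ** gamma * math.log(max(pt, 1e-12))
```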