minhajHP committed
Commit d32ca60 · 1 Parent(s): e69bfae

Major codebase cleanup and feature additions


New Features:
- Add dual API implementations (api_2phase.py, api_joint.py)
- Add 2-phase and joint training scripts
- Add enhanced recommendation engine with 128D embeddings
- Add improved joint training with curriculum learning
- Add enhanced two-tower model architecture
- Add optimized dataset creator

Enhancements:
- Enhance main API with new endpoints and features
- Improve frontend UI with advanced styling and components
- Upgrade inference engines with better recommendation logic
- Enhance data preprocessing with categorical demographics
- Improve training modules with better optimization

Cleanup:
- Remove debug, analysis, and documentation files
- Update gitignore with ML-specific patterns
- Remove backup directories and temporary files
- Streamline codebase structure

Files: 25 changed, 5843 insertions, 829 deletions

.gitignore CHANGED
@@ -193,4 +193,32 @@ data/
 .Spotlight-V100
 .Trashes
 ehthumbs.db
-Thumbs.db
+Thumbs.db
+
+# Analysis and debug files
+debug_*.py
+test_*.py
+analyze_enhanced_*.py
+analyze_recommendation_quality.py
+*_analysis_*
+*_report.*
+simple_*.py
+
+# Keep analyze_recommendations.py - it's wanted
+!analyze_recommendations.py
+
+# Backup directories
+*_backup/
+*_bak/
+artifacts_backup/
+
+# Temporary files
+*.tmp
+*.temp
+*.cache
+
+# Configuration files with sensitive data
+config.json
+secrets.json
+.secret
+.key
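The new ignore rules rely on pattern order: `analyze_recommendation_quality.py` is listed explicitly as ignored, while the later `!analyze_recommendations.py` negation keeps that one script tracked. A rough sanity check of this logic, sketched with Python's `fnmatch` (a simplification of git's matching: real `.gitignore` handling also covers directories and anchoring):

```python
from fnmatch import fnmatch

# Subset of the ignore patterns added in this commit, in file order.
ignore = ["debug_*.py", "test_*.py", "analyze_enhanced_*.py",
          "analyze_recommendation_quality.py", "simple_*.py"]
# Negation rule appearing later in the file, so it overrides the ignores.
unignore = ["analyze_recommendations.py"]

def is_ignored(name):
    """Approximate .gitignore semantics: a later `!` rule un-ignores a file."""
    ignored = any(fnmatch(name, p) for p in ignore)
    if any(fnmatch(name, p) for p in unignore):
        ignored = False
    return ignored

print(is_ignored("analyze_recommendation_quality.py"))  # True: explicitly ignored
print(is_ignored("analyze_recommendations.py"))         # False: re-included
```

`git check-ignore -v <path>` would give the authoritative answer against the actual repository.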
CATEGORICAL_DEMOGRAPHICS_SUMMARY.md DELETED
@@ -1,113 +0,0 @@
-# Categorical Demographics Implementation Summary
-
-## ✅ **IMPLEMENTATION COMPLETE**
-
-Successfully converted age and income from continuous normalized features to categorical embeddings, achieving the goal of reducing demographics to 25% of total input dimensions.
-
----
-
-## 🎯 **Key Changes Made**
-
-### **1. Age Categorization (6 Categories)**
-- **Teen (0)**: Under 18
-- **Young Adult (1)**: 18-25
-- **Adult (2)**: 26-35
-- **Middle Age (3)**: 36-50
-- **Mature (4)**: 51-65
-- **Senior (5)**: 65+
-
-### **2. Income Categorization (5 Categories)**
-- **Low Income (0)**: Bottom 20% (≤$56,276)
-- **Lower Middle (1)**: 20-40% ($56,276-$69,236)
-- **Middle (2)**: 40-60% ($69,236-$80,661)
-- **Upper Middle (3)**: 60-80% ($80,661-$94,284)
-- **High Income (4)**: Top 20% (≥$94,284)
-
-### **3. Embedding Dimensions**
-**Original Tower (64D):**
-- Age: 4D, Income: 4D, Gender: 4D
-- **Total Demographics**: 12D (18.8% of input)
-
-**Improved Tower (128D):**
-- Age: 8D, Income: 8D, Gender: 8D
-- **Total Demographics**: 24D (18.8% of input)
-
----
-
-## 📁 **Files Modified**
-
-### **Data Preparation**
-- `src/preprocessing/user_data_preparation.py`
-  - Added `categorize_age()` and `categorize_income()` functions
-  - Updated `prepare_user_features()` to output categorical features
-
-### **Model Architecture**
-- `src/models/user_tower.py`
-  - Replaced normalization layers with embedding layers
-  - Updated forward pass for categorical inputs
-
-- `src/models/improved_two_tower.py`
-  - Same embedding updates as original tower
-  - Maintained sophisticated history aggregation
-
-### **Training Scripts**
-- `src/training/optimized_joint_training.py`
-- `src/training/joint_training.py`
-- `src/training/fast_joint_training.py`
-  - Removed normalization adaptation calls
-
-### **Inference Engine**
-- `src/inference/recommendation_engine.py`
-  - Added categorization functions for real-time inference
-  - Updated `prepare_user_features()` to categorize raw inputs
-  - Added income threshold loading from training data
-
----
-
-## 🔍 **Verification Results**
-
-✅ **All Tests Pass:**
-- Age categorization: 6 categories (0-5) ✅
-- Income categorization: 5 categories (0-4) ✅
-- Training features: Correct int32 dtypes ✅
-- User towers: Proper embedding dimensions ✅
-- Inference engine: Successful categorical conversion ✅
-- Recommendation engines: Working with categorical inputs ✅
-
----
-
-## 📊 **Benefits Achieved**
-
-### **1. Balanced Feature Representation**
-- **Before**: Demographics 75% (96D), History 25% (32D)
-- **After**: Demographics 19% (24D), History 81% (104D)
-
-### **2. Better Learning Patterns**
-- **Interpretable segments**: Clear demographic groups vs continuous values
-- **Non-linear relationships**: Each category learns distinct behaviors
-- **Reduced bias**: Less dependence on exact age/income values
-- **Better generalization**: Discrete categories vs continuous normalization
-
-### **3. Improved Model Architecture**
-- **Smaller demographics footprint**: More capacity for behavioral signals
-- **Category-specific patterns**: Age/income groups with unique preferences
-- **Embedding benefits**: Learned representations vs fixed normalization
-
----
-
-## 🚀 **Ready for Training**
-
-The categorical demographics implementation is complete and verified. The system now:
-
-1. **Prioritizes behavioral signals** (81%) over demographics (19%)
-2. **Uses interpretable demographic segments** instead of continuous values
-3. **Maintains all existing functionality** with enhanced representation
-4. **Is ready for improved model training** with better feature balance
-
-To train the improved model with categorical demographics:
-
-```bash
-python train_improved_model.py --embedding-dim 128 --epochs-per-stage 15
-```
-
-The enhanced recommendation system should now achieve better personalization through balanced feature representation and categorical demographic learning.
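The age and income bucketing described in the deleted summary can be sketched as follows. The function names `categorize_age` and `categorize_income` come from the summary; the income thresholds are the training-set quintiles it quotes, whereas the real pipeline loads them from training data:

```python
def categorize_age(age):
    """Map raw age to one of 6 categories (0-5), per the summary's buckets."""
    if age < 18:
        return 0  # Teen
    if age <= 25:
        return 1  # Young Adult
    if age <= 35:
        return 2  # Adult
    if age <= 50:
        return 3  # Middle Age
    if age <= 65:
        return 4  # Mature
    return 5      # Senior

def categorize_income(income):
    """Map raw income to one of 5 quintile categories (0-4).

    Thresholds are the quintile boundaries quoted in the summary
    ($56,276 / $69,236 / $80,661 / $94,284).
    """
    thresholds = [56276, 69236, 80661, 94284]
    for cat, t in enumerate(thresholds):
        if income <= t:
            return cat
    return 4  # High Income: top 20%

print(categorize_age(30), categorize_income(75000))  # 2 2
```

With 6 age, 5 income, and a handful of gender categories, each feature becomes an integer index feeding an embedding layer instead of a normalized float.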
analyze_recommendation_quality.py DELETED
@@ -1,556 +0,0 @@
-#!/usr/bin/env python3
-"""
-Comprehensive analysis of recommendation quality from the two-tower model.
-"""
-
-import sys
-import os
-import numpy as np
-import pandas as pd
-from collections import Counter, defaultdict
-import time
-
-sys.path.append('/home/user/Desktop/RecSys-HP')
-from src.inference.recommendation_engine import RecommendationEngine
-from src.utils.real_user_selector import RealUserSelector
-
-def analyze_score_distribution():
-    """Analyze the distribution of recommendation scores."""
-
-    print("📊 SCORE DISTRIBUTION ANALYSIS")
-    print("="*50)
-
-    try:
-        engine = RecommendationEngine()
-        real_user_selector = RealUserSelector()
-
-        # Get multiple users for comprehensive analysis
-        test_users = real_user_selector.get_real_users(n=10, min_interactions=10)
-
-        all_scores = {
-            'collaborative': [],
-            'hybrid': [],
-            'content': []
-        }
-
-        print(f"Testing with {len(test_users)} users...")
-
-        for i, user in enumerate(test_users):
-            print(f"\nUser {i+1}/10 - {user['user_id']} ({user['age']}yr {user['gender']}):")
-
-            # Test collaborative filtering
-            try:
-                collab_recs = engine.recommend_items_collaborative(
-                    age=user['age'],
-                    gender=user['gender'],
-                    income=user['income'],
-                    interaction_history=user['interaction_history'][:20],
-                    k=20
-                )
-                collab_scores = [score for _, score, _ in collab_recs]
-                all_scores['collaborative'].extend(collab_scores)
-
-                print(f"   Collaborative: {min(collab_scores):.4f} - {max(collab_scores):.4f} (std: {np.std(collab_scores):.4f})")
-
-            except Exception as e:
-                print(f"   Collaborative failed: {e}")
-
-            # Test hybrid
-            try:
-                hybrid_recs = engine.recommend_items_hybrid(
-                    age=user['age'],
-                    gender=user['gender'],
-                    income=user['income'],
-                    interaction_history=user['interaction_history'][:20],
-                    k=20,
-                    collaborative_weight=0.7
-                )
-                hybrid_scores = [score for _, score, _ in hybrid_recs]
-                all_scores['hybrid'].extend(hybrid_scores)
-
-                print(f"   Hybrid: {min(hybrid_scores):.4f} - {max(hybrid_scores):.4f} (std: {np.std(hybrid_scores):.4f})")
-
-            except Exception as e:
-                print(f"   Hybrid failed: {e}")
-
-            # Test content-based (if user has history)
-            if user['interaction_history']:
-                try:
-                    content_recs = engine.recommend_items_content_based(
-                        seed_item_id=user['interaction_history'][0],
-                        k=20
-                    )
-                    content_scores = [score for _, score, _ in content_recs]
-                    all_scores['content'].extend(content_scores)
-
-                    print(f"   Content: {min(content_scores):.4f} - {max(content_scores):.4f} (std: {np.std(content_scores):.4f})")
-
-                except Exception as e:
-                    print(f"   Content failed: {e}")
-
-        # Overall score analysis
-        print(f"\n📈 OVERALL SCORE STATISTICS:")
-        for method, scores in all_scores.items():
-            if scores:
-                print(f"\n{method.upper()}:")
-                print(f"   Total scores: {len(scores)}")
-                print(f"   Range: {min(scores):.4f} - {max(scores):.4f}")
-                print(f"   Mean: {np.mean(scores):.4f}")
-                print(f"   Std: {np.std(scores):.4f}")
-                print(f"   Variance: {np.var(scores):.6f}")
-
-                # Score distribution percentiles
-                percentiles = [10, 25, 50, 75, 90]
-                perc_values = np.percentile(scores, percentiles)
-                print(f"   Percentiles: {dict(zip(percentiles, perc_values))}")
-
-                # Quality assessment
-                score_range = max(scores) - min(scores)
-                if score_range < 0.1:
-                    print(f"   ⚠️ WARNING: Low score range ({score_range:.4f}) - poor discrimination")
-                elif score_range < 0.3:
-                    print(f"   ⚠️ CAUTION: Moderate score range ({score_range:.4f})")
-                else:
-                    print(f"   ✅ GOOD: Wide score range ({score_range:.4f})")
-
-                if np.var(scores) < 0.001:
-                    print(f"   ⚠️ WARNING: Very low variance - poor ranking ability")
-                elif np.var(scores) < 0.01:
-                    print(f"   ⚠️ CAUTION: Low variance")
-                else:
-                    print(f"   ✅ GOOD: Adequate variance for ranking")
-
-        return all_scores
-
-    except Exception as e:
-        print(f"❌ Score analysis failed: {e}")
-        return None
-
-def analyze_category_alignment():
-    """Analyze how well recommendations align with user category preferences."""
-
-    print(f"\n🎯 CATEGORY ALIGNMENT ANALYSIS")
-    print("="*40)
-
-    try:
-        engine = RecommendationEngine()
-        real_user_selector = RealUserSelector()
-
-        test_users = real_user_selector.get_real_users(n=5, min_interactions=15)
-
-        alignment_results = []
-
-        for user in test_users:
-            print(f"\nUser {user['user_id']} ({user['age']}yr {user['gender']}):")
-
-            # Get user's detailed interactions
-            user_details = real_user_selector.get_user_interaction_details(user['user_id'])
-
-            # Analyze user's category preferences
-            user_categories = []
-            for interaction in user_details['timeline']:
-                category = interaction.get('category_code', 'Unknown')
-                user_categories.append(category)
-
-            user_category_counts = Counter(user_categories)
-            total_user_interactions = len(user_categories)
-
-            print(f"   User's top categories:")
-            for category, count in user_category_counts.most_common(3):
-                percentage = (count / total_user_interactions) * 100
-                print(f"      {category}: {count} ({percentage:.1f}%)")
-
-            # Get recommendations
-            recs = engine.recommend_items_hybrid(
-                age=user['age'],
-                gender=user['gender'],
-                income=user['income'],
-                interaction_history=user['interaction_history'][:20],
-                k=20,
-                collaborative_weight=0.7
-            )
-
-            # Analyze recommendation categories
-            rec_categories = []
-            for _, _, item_info in recs:
-                category = item_info.get('category_code', 'Unknown')
-                rec_categories.append(category)
-
-            rec_category_counts = Counter(rec_categories)
-
-            print(f"   Recommendation categories:")
-            for category, count in rec_category_counts.most_common(3):
-                percentage = (count / len(rec_categories)) * 100
-                match = "✅" if category in user_category_counts else "🆕"
-                print(f"      {category}: {count} ({percentage:.1f}%) {match}")
-
-            # Calculate alignment metrics
-            user_cats = set(user_category_counts.keys())
-            rec_cats = set(rec_category_counts.keys())
-
-            intersection = user_cats & rec_cats
-            alignment_percentage = len(intersection) / len(rec_cats) * 100 if rec_cats else 0
-
-            # Calculate weighted alignment (by user preference strength)
-            weighted_alignment = 0
-            for category in intersection:
-                user_weight = user_category_counts[category] / total_user_interactions
-                rec_weight = rec_category_counts[category] / len(rec_categories)
-                weighted_alignment += min(user_weight, rec_weight)
-
-            alignment_results.append({
-                'user_id': user['user_id'],
-                'alignment_percentage': alignment_percentage,
-                'weighted_alignment': weighted_alignment * 100,
-                'user_categories': len(user_cats),
-                'rec_categories': len(rec_cats),
-                'matched_categories': len(intersection)
-            })
-
-            print(f"   Alignment: {alignment_percentage:.1f}% ({len(intersection)}/{len(rec_cats)} categories)")
-            print(f"   Weighted alignment: {weighted_alignment * 100:.1f}%")
-
-        # Overall alignment analysis
-        print(f"\n📊 OVERALL ALIGNMENT STATISTICS:")
-        avg_alignment = np.mean([r['alignment_percentage'] for r in alignment_results])
-        avg_weighted = np.mean([r['weighted_alignment'] for r in alignment_results])
-        avg_user_cats = np.mean([r['user_categories'] for r in alignment_results])
-        avg_rec_cats = np.mean([r['rec_categories'] for r in alignment_results])
-
-        print(f"   Average alignment: {avg_alignment:.1f}%")
-        print(f"   Average weighted alignment: {avg_weighted:.1f}%")
-        print(f"   Average user categories: {avg_user_cats:.1f}")
-        print(f"   Average rec categories: {avg_rec_cats:.1f}")
-
-        # Quality assessment
-        if avg_alignment < 20:
-            print(f"   ❌ POOR: Very low category alignment")
-        elif avg_alignment < 40:
-            print(f"   ⚠️ FAIR: Low category alignment")
-        elif avg_alignment < 60:
-            print(f"   ✅ GOOD: Moderate category alignment")
-        else:
-            print(f"   🎉 EXCELLENT: High category alignment")
-
-        return alignment_results
-
-    except Exception as e:
-        print(f"❌ Category alignment analysis failed: {e}")
-        return None
-
-def analyze_diversity_metrics():
-    """Analyze diversity metrics in recommendations."""
-
-    print(f"\n🌈 DIVERSITY ANALYSIS")
-    print("="*30)
-
-    try:
-        engine = RecommendationEngine()
-        real_user_selector = RealUserSelector()
-
-        test_users = real_user_selector.get_real_users(n=5, min_interactions=10)
-
-        diversity_results = []
-
-        for user in test_users:
-            print(f"\nUser {user['user_id']}:")
-
-            # Get recommendations
-            recs = engine.recommend_items_hybrid(
-                age=user['age'],
-                gender=user['gender'],
-                income=user['income'],
-                interaction_history=user['interaction_history'][:20],
-                k=20,
-                collaborative_weight=0.7
-            )
-
-            # Extract features for diversity analysis
-            categories = [item_info.get('category_code', 'Unknown') for _, _, item_info in recs]
-            brands = [item_info.get('brand', 'Unknown') for _, _, item_info in recs]
-            prices = [item_info.get('price', 0) for _, _, item_info in recs]
-
-            # Calculate diversity metrics
-            category_diversity = len(set(categories)) / len(categories) if categories else 0
-            brand_diversity = len(set(brands)) / len(brands) if brands else 0
-
-            # Price diversity (coefficient of variation)
-            price_diversity = np.std(prices) / np.mean(prices) if np.mean(prices) > 0 else 0
-
-            # Intra-list diversity (average pairwise dissimilarity)
-            category_counts = Counter(categories)
-            gini_categories = 1 - sum((count / len(categories)) ** 2 for count in category_counts.values())
-
-            diversity_results.append({
-                'user_id': user['user_id'],
-                'category_diversity': category_diversity,
-                'brand_diversity': brand_diversity,
-                'price_diversity': price_diversity,
-                'gini_categories': gini_categories,
-                'unique_categories': len(set(categories)),
-                'unique_brands': len(set(brands))
-            })
-
-            print(f"   Categories: {len(set(categories))} unique ({category_diversity:.2f} ratio)")
-            print(f"   Brands: {len(set(brands))} unique ({brand_diversity:.2f} ratio)")
-            print(f"   Price range: ${min(prices):.2f} - ${max(prices):.2f}")
-            print(f"   Gini (categories): {gini_categories:.2f}")
-
-        # Overall diversity statistics
-        print(f"\n📊 OVERALL DIVERSITY STATISTICS:")
-        avg_cat_diversity = np.mean([r['category_diversity'] for r in diversity_results])
-        avg_brand_diversity = np.mean([r['brand_diversity'] for r in diversity_results])
-        avg_gini = np.mean([r['gini_categories'] for r in diversity_results])
-        avg_unique_cats = np.mean([r['unique_categories'] for r in diversity_results])
-
-        print(f"   Average category diversity: {avg_cat_diversity:.2f}")
-        print(f"   Average brand diversity: {avg_brand_diversity:.2f}")
-        print(f"   Average Gini coefficient: {avg_gini:.2f}")
-        print(f"   Average unique categories: {avg_unique_cats:.1f}")
-
-        # Quality assessment
-        if avg_cat_diversity < 0.3:
-            print(f"   ❌ POOR: Low category diversity - recommendations too similar")
-        elif avg_cat_diversity < 0.5:
-            print(f"   ⚠️ FAIR: Moderate category diversity")
-        else:
-            print(f"   ✅ GOOD: High category diversity")
-
-        return diversity_results
-
-    except Exception as e:
-        print(f"❌ Diversity analysis failed: {e}")
-        return None
-
-def analyze_embedding_quality():
-    """Analyze the quality of user and item embeddings."""
-
-    print(f"\n🧠 EMBEDDING QUALITY ANALYSIS")
-    print("="*35)
-
-    try:
-        engine = RecommendationEngine()
-        real_user_selector = RealUserSelector()
-
-        test_users = real_user_selector.get_real_users(n=3, min_interactions=10)
-
-        user_embeddings = []
-        user_item_similarities = []
-
-        for user in test_users:
-            print(f"\nUser {user['user_id']}:")
-
-            # Get user embedding
-            user_emb = engine.get_user_embedding(
-                age=user['age'],
-                gender=user['gender'],
-                income=user['income'],
-                interaction_history=user['interaction_history'][:10]
-            )
-
-            user_embeddings.append(user_emb)
-
-            print(f"   User embedding shape: {user_emb.shape}")
-            print(f"   User embedding norm: {np.linalg.norm(user_emb):.4f}")
-            print(f"   User embedding mean: {user_emb.mean():.4f}")
-            print(f"   User embedding std: {user_emb.std():.4f}")
-
-            # Get embeddings for user's interaction history
-            item_similarities = []
-            for item_id in user['interaction_history'][:5]:
-                item_emb = engine.get_item_embedding(item_id)
-                if item_emb is not None:
-                    similarity = np.dot(user_emb, item_emb)
-                    item_similarities.append(similarity)
-
-            if item_similarities:
-                user_item_similarities.extend(item_similarities)
-                print(f"   Avg similarity with interacted items: {np.mean(item_similarities):.4f}")
-                print(f"   Similarity range: {min(item_similarities):.4f} - {max(item_similarities):.4f}")
-
-        # Analyze user embedding diversity
-        if len(user_embeddings) > 1:
-            user_embeddings = np.array(user_embeddings)
-
-            # User-user similarities
-            user_similarities = []
-            for i in range(len(user_embeddings)):
-                for j in range(i+1, len(user_embeddings)):
-                    sim = np.dot(user_embeddings[i], user_embeddings[j])
-                    user_similarities.append(sim)
-
-            print(f"\n📊 USER EMBEDDING ANALYSIS:")
-            print(f"   User-user similarities: {np.mean(user_similarities):.4f} ± {np.std(user_similarities):.4f}")
-            print(f"   User-item similarities: {np.mean(user_item_similarities):.4f} ± {np.std(user_item_similarities):.4f}")
-
-            # Quality assessment
-            if np.mean(user_similarities) > 0.9:
-                print(f"   ⚠️ WARNING: Users too similar - possible embedding collapse")
-            elif np.mean(user_similarities) > 0.7:
-                print(f"   ⚠️ CAUTION: High user similarity - limited personalization")
-            else:
-                print(f"   ✅ GOOD: Adequate user embedding diversity")
-
-        return {
-            'user_embeddings': user_embeddings,
-            'user_similarities': user_similarities if len(user_embeddings) > 1 else [],
-            'user_item_similarities': user_item_similarities
-        }
-
-    except Exception as e:
-        print(f"❌ Embedding analysis failed: {e}")
-        return None
-
-def analyze_performance_metrics():
-    """Analyze performance and efficiency metrics."""
-
-    print(f"\n⚡ PERFORMANCE ANALYSIS")
-    print("="*25)
-
-    try:
-        engine = RecommendationEngine()
-        real_user_selector = RealUserSelector()
-
-        test_user = real_user_selector.get_real_users(n=1, min_interactions=10)[0]
-
-        # Test recommendation generation speed
-        print("Testing recommendation generation speed...")
-
-        methods = [
-            ('Collaborative', lambda: engine.recommend_items_collaborative(
-                age=test_user['age'], gender=test_user['gender'],
-                income=test_user['income'], interaction_history=test_user['interaction_history'][:20], k=10
-            )),
-            ('Hybrid', lambda: engine.recommend_items_hybrid(
-                age=test_user['age'], gender=test_user['gender'],
-                income=test_user['income'], interaction_history=test_user['interaction_history'][:20], k=10
-            )),
-        ]
-
-        for method_name, method_func in methods:
-            times = []
-            for _ in range(5):  # Run 5 times for average
-                start_time = time.time()
-                recs = method_func()
-                end_time = time.time()
-                times.append(end_time - start_time)
-
-            avg_time = np.mean(times)
-            print(f"   {method_name}: {avg_time:.3f}s ± {np.std(times):.3f}s")
-
-            if avg_time > 1.0:
-                print(f"      ⚠️ SLOW: Consider optimization")
-            elif avg_time > 0.5:
-                print(f"      ⚠️ MODERATE: Acceptable for real-time")
-            else:
-                print(f"      ✅ FAST: Good for real-time recommendations")
-
-        # Test scalability with different recommendation counts
-        print(f"\nTesting scalability...")
-        for k in [10, 50, 100]:
-            start_time = time.time()
-            recs = engine.recommend_items_hybrid(
-                age=test_user['age'], gender=test_user['gender'],
-                income=test_user['income'], interaction_history=test_user['interaction_history'][:20], k=k
-            )
-            end_time = time.time()
-
-            print(f"   {k} recommendations: {end_time - start_time:.3f}s")
-
-        return True
-
-    except Exception as e:
-        print(f"❌ Performance analysis failed: {e}")
-        return False
-
-def generate_quality_report():
-    """Generate a comprehensive quality report."""
-
-    print(f"\n📋 COMPREHENSIVE QUALITY REPORT")
-    print("="*40)
-
-    # Run all analyses
-    score_results = analyze_score_distribution()
-    alignment_results = analyze_category_alignment()
-    diversity_results = analyze_diversity_metrics()
-    embedding_results = analyze_embedding_quality()
-    performance_results = analyze_performance_metrics()
-
-    # Generate summary
-    print(f"\n🎯 QUALITY SUMMARY:")
-
-    issues = []
-    strengths = []
-
-    # Check score quality
-    if score_results:
-        for method, scores in score_results.items():
-            if scores:
-                score_variance = np.var(scores)
-                score_range = max(scores) - min(scores)
-
-                if score_variance < 0.001:
-                    issues.append(f"Low {method} score variance ({score_variance:.6f})")
-                if score_range < 0.1:
-                    issues.append(f"Narrow {method} score range ({score_range:.4f})")
-
-                if score_variance > 0.01 and score_range > 0.3:
-                    strengths.append(f"Good {method} score discrimination")
-
-    # Check alignment quality
-    if alignment_results:
-        avg_alignment = np.mean([r['alignment_percentage'] for r in alignment_results])
-        if avg_alignment < 30:
-            issues.append(f"Low category alignment ({avg_alignment:.1f}%)")
-        elif avg_alignment > 50:
-            strengths.append(f"Good category alignment ({avg_alignment:.1f}%)")
-
-    # Check diversity
-    if diversity_results:
-        avg_diversity = np.mean([r['category_diversity'] for r in diversity_results])
-        if avg_diversity < 0.3:
-            issues.append(f"Low category diversity ({avg_diversity:.2f})")
-        elif avg_diversity > 0.5:
-            strengths.append(f"Good category diversity ({avg_diversity:.2f})")
-
-    # Print results
-    if issues:
-        print(f"\n❌ ISSUES IDENTIFIED:")
-        for issue in issues:
-            print(f"   • {issue}")
-
-    if strengths:
-        print(f"\n✅ STRENGTHS:")
-        for strength in strengths:
-            print(f"   • {strength}")
-
-    # Recommendations
-    print(f"\n💡 RECOMMENDATIONS:")
-    if any("score variance" in issue for issue in issues):
-        print("   • Increase embedding dimensions or add temperature scaling")
-    if any("alignment" in issue for issue in issues):
-        print("   • Implement category-aware recommendation boosting")
-    if any("diversity" in issue for issue in issues):
-        print("   • Add diversity regularization to recommendation algorithm")
-
-    if not issues:
-        print("   • No major issues detected - model performing well!")
-
-def main():
-    """Main analysis function."""
-
-    print("🔍 TWO-TOWER RECOMMENDATION QUALITY ANALYSIS")
-    print("="*60)
-
-    try:
-        generate_quality_report()
-
-        print(f"\n✅ Analysis completed successfully!")
-
    except Exception as e:
-        print(f"❌ Analysis failed: {e}")
-        import traceback
-        traceback.print_exc()
-
-if __name__ == "__main__":
-    main()
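The diversity metrics computed in the removed script reduce to a unique-category ratio and a Gini-style impurity (1 minus the sum of squared category shares). Isolated from the engine, the same arithmetic looks like this (a sketch; the script computes it inline rather than in a helper):

```python
from collections import Counter

def diversity_metrics(categories):
    """Return (unique-category ratio, Gini impurity) for a recommendation list.

    Gini impurity is 1 - sum of squared category shares: 0 when every item
    shares one category, approaching 1 as categories spread out evenly.
    """
    counts = Counter(categories)
    ratio = len(counts) / len(categories)
    gini = 1 - sum((c / len(categories)) ** 2 for c in counts.values())
    return ratio, gini

ratio, gini = diversity_metrics(["shoes", "shoes", "bags", "watches"])
print(round(ratio, 2), round(gini, 3))  # 0.75 0.625
```

Against the script's thresholds, a ratio below 0.3 would be flagged as poor diversity and above 0.5 as good.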
analyze_recommendations.py ADDED
@@ -0,0 +1,543 @@
+#!/usr/bin/env python3
+"""
+Recommendation Analysis Script
+
+This script compares recommendations from both training approaches:
+1. 2-phase training (pre-trained item tower + joint fine-tuning)
+2. Single joint training (end-to-end optimization)
+
+It analyzes:
+- Category alignment between user interactions and recommendations
+- Diversity of recommended categories
+- Overlap between the two approaches
+- Performance on real users
+
+Usage:
+    python analyze_recommendations.py
+"""
+
+import os
+import sys
+import numpy as np
+import pandas as pd
+from collections import defaultdict, Counter
+from typing import Dict, List, Tuple
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+# Add src to path
+sys.path.append('src')
+
+from src.inference.recommendation_engine import RecommendationEngine
+from src.utils.real_user_selector import RealUserSelector
+
+
+class RecommendationAnalyzer:
+    """Analyzer for comparing different recommendation approaches."""
+
+    def __init__(self):
+        self.recommendation_engine = None
+        self.real_user_selector = None
+        self.items_df = None
+        self.setup_engines()
+
+    def setup_engines(self):
+        """Setup recommendation engines and data."""
+        print("Loading recommendation engines...")
+
+        try:
+            # Load recommendation engine (assumes trained model artifacts exist)
+            self.recommendation_engine = RecommendationEngine()
+            print("✅ Recommendation engine loaded")
+        except Exception as e:
+            print(f"❌ Error loading recommendation engine: {e}")
+            return
+
+        try:
+            # Load real user selector
+            self.real_user_selector = RealUserSelector()
+            print("✅ Real user selector loaded")
+        except Exception as e:
+            print(f"❌ Error loading real user selector: {e}")
+
+        # Load items data for category analysis
+        self.items_df = pd.read_csv("datasets/items.csv")
+        print(f"✅ Loaded {len(self.items_df)} items")
+
+    def get_item_categories(self, item_ids: List[int]) -> List[str]:
+        """Get category codes for given item IDs."""
+        categories = []
+        for item_id in item_ids:
+            item_row = self.items_df[self.items_df['product_id'] == item_id]
+            if len(item_row) > 0:
+                categories.append(item_row.iloc[0]['category_code'])
+            else:
+                categories.append('unknown')
+        return categories
+
+    def analyze_user_recommendations(self,
+                                     user_profile: Dict,
+                                     recommendation_types: List[str] = None) -> Dict:
+        """Analyze recommendations for a single user across different approaches."""
+
+        if recommendation_types is None:
+            recommendation_types = ['collaborative', 'hybrid', 'content']
+
+        results = {
+            'user_profile': user_profile,
+            'interaction_categories': [],
+            'recommendations': {},
+            'category_analysis': {}
+        }
+
+        # Get categories from user's interaction history
+        if user_profile['interaction_history']:
+            results['interaction_categories'] = self.get_item_categories(
+                user_profile['interaction_history']
+            )
+
+        # Get recommendations for each type
+        for rec_type in recommendation_types:
+            try:
+                if rec_type == 'collaborative':
+                    recs = self.recommendation_engine.recommend_items_collaborative(
+                        age=user_profile['age'],
+                        gender=user_profile['gender'],
+                        income=user_profile['income'],
+                        interaction_history=user_profile['interaction_history'],
+                        k=10
+                    )
+                elif rec_type == 'hybrid':
+                    recs = self.recommendation_engine.recommend_items_hybrid(
+                        age=user_profile['age'],
+                        gender=user_profile['gender'],
+                        income=user_profile['income'],
+                        interaction_history=user_profile['interaction_history'],
+                        k=10
+                    )
+                elif rec_type == 'content' and user_profile['interaction_history']:
+                    recs = self.recommendation_engine.recommend_items_content_based(
+                        seed_item_id=user_profile['interaction_history'][-1],
+                        k=10
+                    )
+                else:
+                    continue
+
+                # Extract item IDs and categories
+                item_ids = [item_id for item_id, score, info in recs]
+                rec_categories = self.get_item_categories(item_ids)
+
+                results['recommendations'][rec_type] = {
+                    'items': recs,
+                    'item_ids': item_ids,
+                    'categories': rec_categories,
+                    'scores': [score for item_id, score, info in recs]
+                }
+
+                # Analyze category alignment
+                results['category_analysis'][rec_type] = self.analyze_category_alignment(
+                    results['interaction_categories'],
+                    rec_categories
+                )
+
+            except Exception as e:
+                print(f"Error generating {rec_type} recommendations: {e}")
+
+        return results
+
+    def analyze_category_alignment(self,
+                                   interaction_categories: List[str],
+                                   recommendation_categories: List[str]) -> Dict:
+        """Analyze alignment between interaction and recommendation categories."""
+
+        if not interaction_categories:
+            return {
+                'overlap_ratio': 0.0,
+                'unique_interaction_categories': 0,
+                'unique_recommendation_categories': len(set(recommendation_categories)),
+                'common_categories': [],
+                'category_distribution': Counter(recommendation_categories)
160
+ }
161
+
162
+ interaction_set = set(interaction_categories)
163
+ recommendation_set = set(recommendation_categories)
164
+
165
+ common_categories = interaction_set.intersection(recommendation_set)
166
+ overlap_ratio = len(common_categories) / len(interaction_set) if interaction_set else 0.0
167
+
168
+ return {
169
+ 'overlap_ratio': overlap_ratio,
170
+ 'unique_interaction_categories': len(interaction_set),
171
+ 'unique_recommendation_categories': len(recommendation_set),
172
+ 'common_categories': list(common_categories),
173
+ 'category_distribution': Counter(recommendation_categories),
174
+ 'interaction_category_distribution': Counter(interaction_categories)
175
+ }
176
+
177
+ def compare_recommendation_approaches(self,
178
+ users_sample: List[Dict],
179
+ approaches: List[str] = None) -> Dict:
180
+ """Compare different recommendation approaches across multiple users."""
181
+
182
+ if approaches is None:
183
+ approaches = ['collaborative', 'hybrid', 'content']
184
+
185
+ comparison_results = {
186
+ 'approach_stats': defaultdict(list),
187
+ 'cross_approach_analysis': {},
188
+ 'user_results': []
189
+ }
190
+
191
+ print(f"Analyzing {len(users_sample)} users across {len(approaches)} approaches...")
192
+
193
+ for i, user in enumerate(users_sample):
194
+ print(f"Analyzing user {i+1}/{len(users_sample)}...")
195
+
196
+ user_results = self.analyze_user_recommendations(user, approaches)
197
+ comparison_results['user_results'].append(user_results)
198
+
199
+ # Aggregate stats by approach
200
+ for approach in approaches:
201
+ if approach in user_results['category_analysis']:
202
+ analysis = user_results['category_analysis'][approach]
203
+ comparison_results['approach_stats'][approach].append({
204
+ 'overlap_ratio': analysis['overlap_ratio'],
205
+ 'unique_rec_categories': analysis['unique_recommendation_categories'],
206
+ 'common_categories_count': len(analysis['common_categories'])
207
+ })
208
+
209
+ # Calculate aggregate statistics
210
+ for approach in approaches:
211
+ stats = comparison_results['approach_stats'][approach]
212
+ if stats:
213
+ comparison_results['approach_stats'][approach] = {
214
+ 'avg_overlap_ratio': np.mean([s['overlap_ratio'] for s in stats]),
215
+ 'std_overlap_ratio': np.std([s['overlap_ratio'] for s in stats]),
216
+ 'avg_unique_categories': np.mean([s['unique_rec_categories'] for s in stats]),
217
+ 'avg_common_categories': np.mean([s['common_categories_count'] for s in stats]),
218
+ 'total_users': len(stats)
219
+ }
220
+
221
+ # Cross-approach analysis
222
+ comparison_results['cross_approach_analysis'] = self.cross_approach_analysis(
223
+ comparison_results['user_results'], approaches
224
+ )
225
+
226
+ return comparison_results
227
+
228
+ def cross_approach_analysis(self, user_results: List[Dict], approaches: List[str]) -> Dict:
229
+ """Analyze similarities and differences between approaches."""
230
+
231
+ cross_analysis = {
232
+ 'item_overlap': defaultdict(dict),
233
+ 'category_overlap': defaultdict(dict),
234
+ 'score_correlation': defaultdict(dict)
235
+ }
236
+
237
+ for user_result in user_results:
238
+ recommendations = user_result['recommendations']
239
+
240
+ # Compare each pair of approaches
241
+ for i, approach1 in enumerate(approaches):
242
+ for approach2 in approaches[i+1:]:
243
+ if approach1 in recommendations and approach2 in recommendations:
244
+
245
+ # Item overlap
246
+ items1 = set(recommendations[approach1]['item_ids'])
247
+ items2 = set(recommendations[approach2]['item_ids'])
248
+ item_overlap_ratio = len(items1.intersection(items2)) / len(items1.union(items2))
249
+
250
+ # Category overlap
251
+ cats1 = set(recommendations[approach1]['categories'])
252
+ cats2 = set(recommendations[approach2]['categories'])
253
+ cat_overlap_ratio = len(cats1.intersection(cats2)) / len(cats1.union(cats2)) if cats1.union(cats2) else 0
254
+
255
+ # Store results
256
+ pair_key = f"{approach1}_vs_{approach2}"
257
+ if pair_key not in cross_analysis['item_overlap']:
258
+ cross_analysis['item_overlap'][pair_key] = []
259
+ cross_analysis['category_overlap'][pair_key] = []
260
+
261
+ cross_analysis['item_overlap'][pair_key].append(item_overlap_ratio)
262
+ cross_analysis['category_overlap'][pair_key].append(cat_overlap_ratio)
263
+
264
+ # Calculate averages
265
+ for pair_key in cross_analysis['item_overlap']:
266
+ cross_analysis['item_overlap'][pair_key] = {
267
+ 'avg': np.mean(cross_analysis['item_overlap'][pair_key]),
268
+ 'std': np.std(cross_analysis['item_overlap'][pair_key])
269
+ }
270
+ cross_analysis['category_overlap'][pair_key] = {
271
+ 'avg': np.mean(cross_analysis['category_overlap'][pair_key]),
272
+ 'std': np.std(cross_analysis['category_overlap'][pair_key])
273
+ }
274
+
275
+ return cross_analysis
276
+
277
+ def generate_report(self, comparison_results: Dict, output_file: str = "recommendation_analysis_report.md"):
278
+ """Generate a comprehensive analysis report."""
279
+
280
+ report = []
281
+ report.append("# Recommendation System Analysis Report")
282
+ report.append(f"Generated: {pd.Timestamp.now()}")
283
+ report.append("")
284
+
285
+ # Overall Statistics
286
+ report.append("## Overall Statistics")
287
+ report.append("")
288
+
289
+ for approach, stats in comparison_results['approach_stats'].items():
290
+ if isinstance(stats, dict):
291
+ report.append(f"### {approach.title()} Recommendations")
292
+ report.append(f"- **Average Category Overlap**: {stats['avg_overlap_ratio']:.3f} ± {stats['std_overlap_ratio']:.3f}")
293
+ report.append(f"- **Average Unique Categories per User**: {stats['avg_unique_categories']:.1f}")
294
+ report.append(f"- **Average Common Categories**: {stats['avg_common_categories']:.1f}")
295
+ report.append(f"- **Users Analyzed**: {stats['total_users']}")
296
+ report.append("")
297
+
298
+ # Cross-Approach Analysis
299
+ report.append("## Cross-Approach Comparison")
300
+ report.append("")
301
+
302
+ cross_analysis = comparison_results['cross_approach_analysis']
303
+
304
+ report.append("### Item Overlap Between Approaches")
305
+ for pair, overlap_stats in cross_analysis['item_overlap'].items():
306
+ report.append(f"- **{pair.replace('_', ' ').title()}**: {overlap_stats['avg']:.3f} ± {overlap_stats['std']:.3f}")
307
+ report.append("")
308
+
309
+ report.append("### Category Overlap Between Approaches")
310
+ for pair, overlap_stats in cross_analysis['category_overlap'].items():
311
+ report.append(f"- **{pair.replace('_', ' ').title()}**: {overlap_stats['avg']:.3f} ± {overlap_stats['std']:.3f}")
312
+ report.append("")
313
+
314
+ # Category Alignment Analysis
315
+ report.append("## Category Alignment Analysis")
316
+ report.append("")
317
+ report.append("Category alignment measures how well recommendations match the categories")
318
+ report.append("of items users have previously interacted with.")
319
+ report.append("")
320
+
321
+ # Find best performing approach
322
+ best_approach = max(
323
+ comparison_results['approach_stats'].keys(),
324
+ key=lambda k: comparison_results['approach_stats'][k]['avg_overlap_ratio']
325
+ if isinstance(comparison_results['approach_stats'][k], dict) else 0
326
+ )
327
+
328
+ report.append(f"**Best Category Alignment**: {best_approach.title()} approach")
329
+ report.append("")
330
+
331
+ # Recommendations
332
+ report.append("## Key Findings & Recommendations")
333
+ report.append("")
334
+
335
+ # Analyze overlap ratios to provide insights
336
+ overlap_ratios = {
337
+ k: v['avg_overlap_ratio'] for k, v in comparison_results['approach_stats'].items()
338
+ if isinstance(v, dict)
339
+ }
340
+
341
+ if overlap_ratios:
342
+ avg_overlap = np.mean(list(overlap_ratios.values()))
343
+ if avg_overlap > 0.5:
344
+ report.append("✅ **Strong Category Alignment**: Recommendations show good alignment with user interaction patterns.")
345
+ elif avg_overlap > 0.3:
346
+ report.append("⚠️ **Moderate Category Alignment**: Some alignment present but room for improvement.")
347
+ else:
348
+ report.append("❌ **Weak Category Alignment**: Recommendations may be too diverse or not well-aligned with user preferences.")
349
+
350
+ report.append("")
351
+
352
+ # Compare approaches
353
+ if len(overlap_ratios) > 1:
354
+ sorted_approaches = sorted(overlap_ratios.items(), key=lambda x: x[1], reverse=True)
355
+ report.append("### Approach Rankings (by category alignment):")
356
+ for i, (approach, ratio) in enumerate(sorted_approaches, 1):
357
+ report.append(f"{i}. **{approach.title()}**: {ratio:.3f}")
358
+ report.append("")
359
+
360
+ # Write report
361
+ with open(output_file, 'w') as f:
362
+ f.write('\n'.join(report))
363
+
364
+ print(f"✅ Analysis report saved to: {output_file}")
365
+ return '\n'.join(report)
366
+
367
+ def visualize_results(self, comparison_results: Dict, save_plots: bool = True):
368
+ """Create visualizations for the analysis results."""
369
+
370
+ # Set up plotting style
371
+ plt.style.use('default')
372
+ sns.set_palette("husl")
373
+
374
+ # Create figure with subplots
375
+ fig, axes = plt.subplots(2, 2, figsize=(15, 12))
376
+ fig.suptitle('Recommendation System Analysis', fontsize=16, fontweight='bold')
377
+
378
+ # 1. Category Overlap by Approach
379
+ ax1 = axes[0, 0]
380
+ approaches = []
381
+ overlap_means = []
382
+ overlap_stds = []
383
+
384
+ for approach, stats in comparison_results['approach_stats'].items():
385
+ if isinstance(stats, dict):
386
+ approaches.append(approach.title())
387
+ overlap_means.append(stats['avg_overlap_ratio'])
388
+ overlap_stds.append(stats['std_overlap_ratio'])
389
+
390
+ bars1 = ax1.bar(approaches, overlap_means, yerr=overlap_stds, capsize=5, alpha=0.7)
391
+ ax1.set_title('Average Category Overlap by Approach')
392
+ ax1.set_ylabel('Category Overlap Ratio')
393
+ ax1.set_ylim(0, 1)
394
+
395
+ # Add value labels on bars
396
+ for bar, mean in zip(bars1, overlap_means):
397
+ height = bar.get_height()
398
+ ax1.text(bar.get_x() + bar.get_width()/2., height + 0.01,
399
+ f'{mean:.3f}', ha='center', va='bottom')
400
+
401
+ # 2. Cross-Approach Item Overlap
402
+ ax2 = axes[0, 1]
403
+ cross_analysis = comparison_results['cross_approach_analysis']
404
+
405
+ pair_names = []
406
+ item_overlaps = []
407
+
408
+ for pair, overlap_stats in cross_analysis['item_overlap'].items():
409
+ pair_names.append(pair.replace('_vs_', ' vs ').title())
410
+ item_overlaps.append(overlap_stats['avg'])
411
+
412
+ if pair_names:
413
+ bars2 = ax2.bar(pair_names, item_overlaps, alpha=0.7, color='coral')
414
+ ax2.set_title('Item Overlap Between Approaches')
415
+ ax2.set_ylabel('Item Overlap Ratio')
416
+ ax2.set_ylim(0, 1)
417
+ plt.setp(ax2.get_xticklabels(), rotation=45, ha='right')
418
+
419
+ # Add value labels
420
+ for bar, overlap in zip(bars2, item_overlaps):
421
+ height = bar.get_height()
422
+ ax2.text(bar.get_x() + bar.get_width()/2., height + 0.01,
423
+ f'{overlap:.3f}', ha='center', va='bottom')
424
+
425
+ # 3. Category Diversity
426
+ ax3 = axes[1, 0]
427
+ unique_categories = []
428
+ for approach, stats in comparison_results['approach_stats'].items():
429
+ if isinstance(stats, dict):
430
+ unique_categories.append(stats['avg_unique_categories'])
431
+
432
+ bars3 = ax3.bar(approaches, unique_categories, alpha=0.7, color='lightgreen')
433
+ ax3.set_title('Average Unique Categories per Recommendation')
434
+ ax3.set_ylabel('Number of Unique Categories')
435
+
436
+ for bar, cats in zip(bars3, unique_categories):
437
+ height = bar.get_height()
438
+ ax3.text(bar.get_x() + bar.get_width()/2., height + 0.1,
439
+ f'{cats:.1f}', ha='center', va='bottom')
440
+
441
+ # 4. Category vs Item Overlap Comparison
442
+ ax4 = axes[1, 1]
443
+
444
+ if cross_analysis['item_overlap'] and cross_analysis['category_overlap']:
445
+ pairs = list(cross_analysis['item_overlap'].keys())
446
+ item_overlaps = [cross_analysis['item_overlap'][p]['avg'] for p in pairs]
447
+ cat_overlaps = [cross_analysis['category_overlap'][p]['avg'] for p in pairs]
448
+
449
+ x = np.arange(len(pairs))
450
+ width = 0.35
451
+
452
+ bars4a = ax4.bar(x - width/2, item_overlaps, width, label='Item Overlap', alpha=0.7)
453
+ bars4b = ax4.bar(x + width/2, cat_overlaps, width, label='Category Overlap', alpha=0.7)
454
+
455
+ ax4.set_title('Item vs Category Overlap Between Approaches')
456
+ ax4.set_ylabel('Overlap Ratio')
457
+ ax4.set_xticks(x)
458
+ ax4.set_xticklabels([p.replace('_vs_', ' vs ') for p in pairs], rotation=45, ha='right')
459
+ ax4.legend()
460
+ ax4.set_ylim(0, 1)
461
+
462
+ plt.tight_layout()
463
+
464
+ if save_plots:
465
+ plt.savefig('recommendation_analysis_plots.png', dpi=300, bbox_inches='tight')
466
+ print("✅ Plots saved to: recommendation_analysis_plots.png")
467
+
468
+ plt.show()
469
+
470
+
471
+ def main():
472
+ """Main function to run the recommendation analysis."""
473
+
474
+ print("🔍 Starting Recommendation Analysis...")
475
+ print("=" * 50)
476
+
477
+ # Initialize analyzer
478
+ analyzer = RecommendationAnalyzer()
479
+
480
+ if analyzer.recommendation_engine is None:
481
+ print("❌ Cannot proceed without recommendation engine. Please ensure model is trained.")
482
+ return
483
+
484
+ # Get sample of real users for analysis
485
+ print("Getting real user sample...")
486
+ try:
487
+ real_users = analyzer.real_user_selector.get_real_users(n=20, min_interactions=3)
488
+ print(f"✅ Loaded {len(real_users)} real users for analysis")
489
+ except Exception as e:
490
+ print(f"❌ Error loading real users: {e}")
491
+ # Fallback to synthetic users
492
+ real_users = [
493
+ {
494
+ 'age': 32, 'gender': 'male', 'income': 75000,
495
+ 'interaction_history': [1000978, 1001588, 1001618, 1002039]
496
+ },
497
+ {
498
+ 'age': 28, 'gender': 'female', 'income': 45000,
499
+ 'interaction_history': [1003456, 1004567, 1005678]
500
+ },
501
+ {
502
+ 'age': 45, 'gender': 'male', 'income': 85000,
503
+ 'interaction_history': [1006789, 1007890, 1008901, 1009012, 1010123]
504
+ }
505
+ ]
506
+ print(f"Using {len(real_users)} synthetic users for analysis")
507
+
508
+ # Run comprehensive analysis
509
+ print("Running recommendation analysis...")
510
+ approaches = ['collaborative', 'hybrid', 'content']
511
+
512
+ comparison_results = analyzer.compare_recommendation_approaches(
513
+ users_sample=real_users,
514
+ approaches=approaches
515
+ )
516
+
517
+ # Generate report
518
+ print("Generating analysis report...")
519
+ report = analyzer.generate_report(comparison_results)
520
+
521
+ # Create visualizations
522
+ print("Creating visualizations...")
523
+ try:
524
+ analyzer.visualize_results(comparison_results, save_plots=True)
525
+ except Exception as e:
526
+ print(f"Warning: Could not create visualizations: {e}")
527
+
528
+ # Print summary
529
+ print("\n" + "=" * 50)
530
+ print("📊 ANALYSIS SUMMARY")
531
+ print("=" * 50)
532
+
533
+ for approach, stats in comparison_results['approach_stats'].items():
534
+ if isinstance(stats, dict):
535
+ print(f"{approach.title()}: {stats['avg_overlap_ratio']:.3f} avg category overlap")
536
+
537
+ print(f"\n✅ Analysis complete! Check:")
538
+ print(" 📄 recommendation_analysis_report.md")
539
+ print(" 📊 recommendation_analysis_plots.png")
540
+
541
+
542
+ if __name__ == "__main__":
543
+ main()
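The category-alignment metric computed by `analyze_category_alignment` above is plain set overlap: the share of a user's distinct interaction categories that reappear among the recommended items' categories. A minimal standalone restatement of that logic (the function name and sample values here are illustrative, not part of the repo):

```python
def category_overlap(interaction_categories, recommendation_categories):
    """Fraction of the user's distinct interaction categories that also
    appear among the recommendation categories (same ratio the script's
    analyze_category_alignment reports as 'overlap_ratio')."""
    interaction_set = set(interaction_categories)
    if not interaction_set:
        return 0.0
    common = interaction_set & set(recommendation_categories)
    return len(common) / len(interaction_set)


# Two distinct interaction categories; one of them is recommended back:
print(category_overlap(["audio", "audio", "phones"], ["phones", "tv"]))  # 0.5
```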
api/main.py CHANGED
@@ -13,6 +13,7 @@ sys.path.append(parent_dir)
 os.chdir(parent_dir)  # Change to project root directory
 
 from src.inference.recommendation_engine import RecommendationEngine
+from src.utils.real_user_selector import RealUserSelector
 
 # Initialize FastAPI app
 app = FastAPI(
@@ -30,8 +31,10 @@
     allow_headers=["*"],
 )
 
-# Global recommendation engine instance
+# Global instances
 recommendation_engine = None
+enhanced_recommendation_engine = None
+real_user_selector = None
 
 
 # Pydantic models for request/response
@@ -45,8 +48,11 @@ class UserProfile(BaseModel):
 class RecommendationRequest(BaseModel):
     user_profile: UserProfile
     num_recommendations: int = 10
-    recommendation_type: str = "hybrid"  # "collaborative", "content", "hybrid"
+    recommendation_type: str = "hybrid"  # "collaborative", "content", "hybrid", "enhanced", "enhanced_128d", "category_focused"
     collaborative_weight: Optional[float] = 0.7
+    category_boost: Optional[float] = 1.5  # For enhanced recommendations
+    enable_category_boost: Optional[bool] = True
+    enable_diversity: Optional[bool] = True
 
 
 class ItemSimilarityRequest(BaseModel):
@@ -87,10 +93,27 @@ class RatingPredictionResponse(BaseModel):
     item_info: ItemInfo
 
 
+class RealUserProfile(BaseModel):
+    user_id: int
+    age: int
+    gender: str
+    income: int
+    interaction_history: List[int]
+    interaction_stats: Dict[str, int]
+    interaction_pattern: str
+    summary: str
+
+
+class RealUsersResponse(BaseModel):
+    users: List[RealUserProfile]
+    total_count: int
+    dataset_summary: Dict[str, Any]
+
+
 @app.on_event("startup")
 async def startup_event():
-    """Initialize the recommendation engine on startup."""
-    global recommendation_engine
+    """Initialize the recommendation engines and real user selector on startup."""
+    global recommendation_engine, enhanced_recommendation_engine, real_user_selector
 
     try:
         print("Loading recommendation engine...")
@@ -99,6 +122,30 @@ async def startup_event():
     except Exception as e:
         print(f"Error loading recommendation engine: {e}")
         recommendation_engine = None
+
+    try:
+        print("Loading enhanced recommendation engine...")
+        # Try enhanced 128D engine first, fallback to regular enhanced
+        try:
+            from src.inference.enhanced_recommendation_engine_128d import Enhanced128DRecommendationEngine
+            enhanced_recommendation_engine = Enhanced128DRecommendationEngine()
+            print("✅ Using Enhanced 128D Recommendation Engine")
+        except:
+            from src.inference.enhanced_recommendation_engine import EnhancedRecommendationEngine
+            enhanced_recommendation_engine = EnhancedRecommendationEngine()
+            print("⚠️ Using fallback Enhanced Recommendation Engine")
+        print("Enhanced recommendation engine loaded successfully!")
+    except Exception as e:
+        print(f"Error loading enhanced recommendation engine: {e}")
+        enhanced_recommendation_engine = None
+
+    try:
+        print("Loading real user selector...")
+        real_user_selector = RealUserSelector()
+        print("Real user selector loaded successfully!")
+    except Exception as e:
+        print(f"Error loading real user selector: {e}")
+        real_user_selector = None
 
 
 @app.get("/")
@@ -120,6 +167,68 @@ async def health_check():
     }
 
 
+@app.get("/real-users", response_model=RealUsersResponse)
+async def get_real_users(count: int = 100, min_interactions: int = 5):
+    """Get real user profiles with genuine interaction histories."""
+
+    if real_user_selector is None:
+        raise HTTPException(status_code=503, detail="Real user selector not available")
+
+    try:
+        # Get real user profiles
+        real_users = real_user_selector.get_real_users(n=count, min_interactions=min_interactions)
+
+        # Get dataset summary
+        dataset_summary = real_user_selector.get_dataset_summary()
+
+        # Format users for response
+        formatted_users = []
+        for user in real_users:
+            formatted_users.append(RealUserProfile(**user))
+
+        return RealUsersResponse(
+            users=formatted_users,
+            total_count=len(formatted_users),
+            dataset_summary=dataset_summary
+        )
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error retrieving real users: {str(e)}")
+
+
+@app.get("/real-users/{user_id}")
+async def get_real_user_details(user_id: int):
+    """Get detailed interaction breakdown for a specific real user."""
+
+    if real_user_selector is None:
+        raise HTTPException(status_code=503, detail="Real user selector not available")
+
+    try:
+        user_details = real_user_selector.get_user_interaction_details(user_id)
+
+        if "error" in user_details:
+            raise HTTPException(status_code=404, detail=user_details["error"])
+
+        return user_details
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error retrieving user details: {str(e)}")
+
+
+@app.get("/dataset-summary")
+async def get_dataset_summary():
+    """Get summary statistics of the real dataset."""
+
+    if real_user_selector is None:
+        raise HTTPException(status_code=503, detail="Real user selector not available")
+
+    try:
+        return real_user_selector.get_dataset_summary()
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error retrieving dataset summary: {str(e)}")
+
+
 @app.post("/recommendations", response_model=RecommendationsResponse)
 async def get_recommendations(request: RecommendationRequest):
     """Get item recommendations for a user."""
@@ -164,10 +273,67 @@ async def get_recommendations(request: RecommendationRequest):
             collaborative_weight=request.collaborative_weight
         )
 
+    elif request.recommendation_type == "enhanced":
+        if enhanced_recommendation_engine is None:
+            raise HTTPException(status_code=503, detail="Enhanced recommendation engine not available")
+
+        # Check if it's the 128D engine or fallback
+        if hasattr(enhanced_recommendation_engine, 'recommend_items_enhanced'):
+            # 128D Enhanced engine
+            recommendations = enhanced_recommendation_engine.recommend_items_enhanced(
+                age=user_profile.age,
+                gender=user_profile.gender,
+                income=user_profile.income,
+                interaction_history=user_profile.interaction_history,
+                k=request.num_recommendations,
+                diversity_weight=0.3 if request.enable_diversity else 0.0,
+                category_boost=request.category_boost if request.enable_category_boost else 1.0
+            )
+        else:
+            # Fallback enhanced engine
+            recommendations = enhanced_recommendation_engine.recommend_items_enhanced_hybrid(
+                age=user_profile.age,
+                gender=user_profile.gender,
+                income=user_profile.income,
+                interaction_history=user_profile.interaction_history,
+                k=request.num_recommendations,
+                collaborative_weight=request.collaborative_weight,
+                category_boost=request.category_boost,
+                enable_category_boost=request.enable_category_boost,
+                enable_diversity=request.enable_diversity
+            )
+
+    elif request.recommendation_type == "enhanced_128d":
+        if enhanced_recommendation_engine is None or not hasattr(enhanced_recommendation_engine, 'recommend_items_enhanced'):
+            raise HTTPException(status_code=503, detail="Enhanced 128D recommendation engine not available")
+
+        recommendations = enhanced_recommendation_engine.recommend_items_enhanced(
+            age=user_profile.age,
+            gender=user_profile.gender,
+            income=user_profile.income,
+            interaction_history=user_profile.interaction_history,
+            k=request.num_recommendations,
+            diversity_weight=0.3 if request.enable_diversity else 0.0,
+            category_boost=request.category_boost if request.enable_category_boost else 1.0
+        )
+
+    elif request.recommendation_type == "category_focused":
+        if enhanced_recommendation_engine is None:
+            raise HTTPException(status_code=503, detail="Enhanced recommendation engine not available")
+
+        recommendations = enhanced_recommendation_engine.recommend_items_category_focused(
+            age=user_profile.age,
+            gender=user_profile.gender,
+            income=user_profile.income,
+            interaction_history=user_profile.interaction_history,
+            k=request.num_recommendations,
+            focus_percentage=0.8
+        )
+
     else:
         raise HTTPException(
             status_code=400,
-            detail="Invalid recommendation_type. Must be 'collaborative', 'content', or 'hybrid'"
+            detail="Invalid recommendation_type. Must be 'collaborative', 'content', 'hybrid', 'enhanced', 'enhanced_128d', or 'category_focused'"
         )
 
     # Format response
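For reference, a client request exercising the new recommendation types added to `api/main.py` might look like the following. This is a sketch only: the localhost URL is an assumption about the local dev setup, and the field names simply mirror the extended `RecommendationRequest` model from the diff above.

```python
import json

# Illustrative payload for POST /recommendations using the new "enhanced" type.
# All values are example data, not taken from the repository.
payload = {
    "user_profile": {
        "age": 32,
        "gender": "male",
        "income": 75000.0,
        "interaction_history": [1000978, 1001588],
    },
    "num_recommendations": 5,
    "recommendation_type": "enhanced",  # or "enhanced_128d" / "category_focused"
    "collaborative_weight": 0.7,
    "category_boost": 1.5,              # applied only when enable_category_boost is True
    "enable_category_boost": True,
    "enable_diversity": True,
}

# e.g. requests.post("http://localhost:8000/recommendations", json=payload)
print(json.dumps(payload, indent=2))
```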
api_2phase.py ADDED
@@ -0,0 +1,521 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ API for 2-Phase Trained Recommendation System
4
+
5
+ This API serves recommendations from a model trained using the 2-phase approach:
6
+ 1. Pre-trained item tower
7
+ 2. Joint training with fine-tuned item tower
8
+
9
+ Usage:
10
+ python api_2phase.py
11
+
12
+ Then access: http://localhost:8000
13
+ """
14
+
15
+ from fastapi import FastAPI, HTTPException
16
+ from fastapi.middleware.cors import CORSMiddleware
17
+ from pydantic import BaseModel
18
+ from typing import List, Optional, Dict, Any
19
+ import uvicorn
20
+ import os
21
+ import sys
22
+ import pandas as pd
23
+
24
+ # Add src to path for imports and set working directory
25
+ parent_dir = os.path.dirname(__file__)
26
+ sys.path.append(parent_dir)
27
+ os.chdir(parent_dir) # Change to project root directory
28
+
29
+ from src.inference.recommendation_engine import RecommendationEngine
30
+ from src.utils.real_user_selector import RealUserSelector
31
+
32
+ # Initialize FastAPI app
33
+ app = FastAPI(
34
+ title="Two-Tower Recommendation API (2-Phase Training)",
35
+ description="API for serving recommendations using a two-tower architecture trained with 2-phase approach",
36
+ version="1.0.0-2phase"
37
+ )
38
+
39
+ # Add CORS middleware
40
+ app.add_middleware(
41
+ CORSMiddleware,
42
+ allow_origins=["*"], # Configure appropriately for production
43
+ allow_credentials=True,
44
+ allow_methods=["*"],
45
+ allow_headers=["*"],
46
+ )
47
+
48
+ # Global instances
49
+ recommendation_engine = None
50
+ enhanced_recommendation_engine = None
51
+ real_user_selector = None
52
+
53
+
54
+ # Pydantic models for request/response
55
+ class UserProfile(BaseModel):
56
+ age: int
57
+ gender: str # "male" or "female"
58
+ income: float
59
+ interaction_history: Optional[List[int]] = []
60
+
61
+
62
+ class RecommendationRequest(BaseModel):
63
+ user_profile: UserProfile
64
+ num_recommendations: int = 10
65
+ recommendation_type: str = "hybrid" # "collaborative", "content", "hybrid", "enhanced", "enhanced_128d", "category_focused"
66
+ collaborative_weight: Optional[float] = 0.7
67
+ category_boost: Optional[float] = 1.5 # For enhanced recommendations
68
+ enable_category_boost: Optional[bool] = True
69
+ enable_diversity: Optional[bool] = True
70
+
71
+
72
+ class ItemSimilarityRequest(BaseModel):
73
+ item_id: int
74
+ num_recommendations: int = 10
75
+
76
+
77
+ class RatingPredictionRequest(BaseModel):
78
+ user_profile: UserProfile
79
+ item_id: int
80
+
81
+
82
+ class ItemInfo(BaseModel):
83
+ product_id: int
84
+ category_id: int
85
+ category_code: str
86
+ brand: str
87
+ price: float
88
+
89
+
90
+ class RecommendationResponse(BaseModel):
91
+ item_id: int
92
+ score: float
93
+ item_info: ItemInfo
94
+
95
+
96
+ class RecommendationsResponse(BaseModel):
97
+ recommendations: List[RecommendationResponse]
98
+ user_profile: UserProfile
99
+ recommendation_type: str
100
+ total_count: int
101
+ training_approach: str = "2-phase"
102
+
103
+
104
+ class RatingPredictionResponse(BaseModel):
105
+ user_profile: UserProfile
106
+ item_id: int
107
+ predicted_rating: float
108
+ item_info: ItemInfo
109
+
110
+
111
+ class RealUserProfile(BaseModel):
112
+ user_id: int
113
+ age: int
114
+ gender: str
115
+ income: int
116
+ interaction_history: List[int]
117
+ interaction_stats: Dict[str, int]
118
+ interaction_pattern: str
119
+ summary: str
120
+
121
+
122
+ class RealUsersResponse(BaseModel):
123
+ users: List[RealUserProfile]
124
+ total_count: int
125
+ dataset_summary: Dict[str, Any]
126
+
127
+
128
+ @app.on_event("startup")
129
+ async def startup_event():
130
+ """Initialize the recommendation engines and real user selector on startup."""
131
+ global recommendation_engine, enhanced_recommendation_engine, real_user_selector
132
+
133
+ print("🚀 Starting 2-Phase Training API...")
134
+ print(" Training approach: Pre-trained item tower + Joint fine-tuning")
135
+
136
+ try:
137
+ print("Loading 2-phase trained recommendation engine...")
138
+ recommendation_engine = RecommendationEngine()
139
+ print("✅ 2-phase recommendation engine loaded successfully!")
140
+ except Exception as e:
141
+ print(f"❌ Error loading recommendation engine: {e}")
142
+ recommendation_engine = None
143
+
144
+ try:
145
+ print("Loading enhanced recommendation engine...")
146
+ # Try enhanced 128D engine first, fallback to regular enhanced
147
+ try:
148
+ from src.inference.enhanced_recommendation_engine_128d import Enhanced128DRecommendationEngine
149
+ enhanced_recommendation_engine = Enhanced128DRecommendationEngine()
150
+ print("✅ Using Enhanced 128D Recommendation Engine")
151
+ except Exception:  # 128D engine unavailable; fall back
152
+ from src.inference.enhanced_recommendation_engine import EnhancedRecommendationEngine
153
+ enhanced_recommendation_engine = EnhancedRecommendationEngine()
154
+ print("⚠️ Using fallback Enhanced Recommendation Engine")
155
+ print("Enhanced recommendation engine loaded successfully!")
156
+ except Exception as e:
157
+ print(f"Error loading enhanced recommendation engine: {e}")
158
+ enhanced_recommendation_engine = None
159
+
160
+ try:
161
+ print("Loading real user selector...")
162
+ real_user_selector = RealUserSelector()
163
+ print("Real user selector loaded successfully!")
164
+ except Exception as e:
165
+ print(f"Error loading real user selector: {e}")
166
+ real_user_selector = None
167
+
168
+ print("🎯 2-Phase API ready to serve recommendations!")
169
+
170
+
171
+ @app.get("/")
172
+ async def root():
173
+ """Root endpoint with API information."""
174
+ return {
175
+ "message": "Two-Tower Recommendation API (2-Phase Training)",
176
+ "version": "1.0.0-2phase",
177
+ "training_approach": "2-phase (pre-trained item tower + joint fine-tuning)",
178
+ "status": "active" if recommendation_engine is not None else "initialization_failed"
179
+ }
180
+
181
+
182
+ @app.get("/health")
183
+ async def health_check():
184
+ """Health check endpoint."""
185
+ return {
186
+ "status": "healthy" if recommendation_engine is not None else "unhealthy",
187
+ "engine_loaded": recommendation_engine is not None,
188
+ "training_approach": "2-phase"
189
+ }
190
+
191
+
192
+ @app.get("/model-info")
193
+ async def model_info():
194
+ """Get information about the loaded model."""
195
+ if recommendation_engine is None:
196
+ raise HTTPException(status_code=503, detail="Recommendation engine not available")
197
+
198
+ return {
199
+ "training_approach": "2-phase",
200
+ "description": "Pre-trained item tower followed by joint training with user tower",
201
+ "phases": [
202
+ "Phase 1: Item tower pre-training on item features only",
203
+ "Phase 2: Joint training of user tower + fine-tuning pre-trained item tower"
204
+ ],
205
+ "embedding_dimension": 128,
206
+ "item_vocab_size": len(recommendation_engine.data_processor.item_vocab) if recommendation_engine.data_processor else "unknown",
207
+ "artifacts_loaded": {
208
+ "item_tower_pretrained": "src/artifacts/item_tower_weights",
209
+ "item_tower_finetuned": "src/artifacts/item_tower_weights_finetuned_best",
210
+ "user_tower": "src/artifacts/user_tower_weights_best",
211
+ "rating_model": "src/artifacts/rating_model_weights_best"
212
+ }
213
+ }
214
+
215
+
216
+ @app.get("/real-users", response_model=RealUsersResponse)
217
+ async def get_real_users(count: int = 100, min_interactions: int = 5):
218
+ """Get real user profiles with genuine interaction histories."""
219
+
220
+ if real_user_selector is None:
221
+ raise HTTPException(status_code=503, detail="Real user selector not available")
222
+
223
+ try:
224
+ # Get real user profiles
225
+ real_users = real_user_selector.get_real_users(n=count, min_interactions=min_interactions)
226
+
227
+ # Get dataset summary
228
+ dataset_summary = real_user_selector.get_dataset_summary()
229
+
230
+ # Format users for response
231
+ formatted_users = []
232
+ for user in real_users:
233
+ formatted_users.append(RealUserProfile(**user))
234
+
235
+ return RealUsersResponse(
236
+ users=formatted_users,
237
+ total_count=len(formatted_users),
238
+ dataset_summary=dataset_summary
239
+ )
240
+
241
+ except Exception as e:
242
+ raise HTTPException(status_code=500, detail=f"Error retrieving real users: {str(e)}")
243
+
244
+
245
+ @app.get("/real-users/{user_id}")
246
+ async def get_real_user_details(user_id: int):
247
+ """Get detailed interaction breakdown for a specific real user."""
248
+
249
+ if real_user_selector is None:
250
+ raise HTTPException(status_code=503, detail="Real user selector not available")
251
+
252
+ try:
253
+ user_details = real_user_selector.get_user_interaction_details(user_id)
254
+
255
+ if "error" in user_details:
256
+ raise HTTPException(status_code=404, detail=user_details["error"])
257
+
258
+ return user_details
259
+
260
+ except Exception as e:
261
+ raise HTTPException(status_code=500, detail=f"Error retrieving user details: {str(e)}")
262
+
263
+
264
+ @app.get("/dataset-summary")
265
+ async def get_dataset_summary():
266
+ """Get summary statistics of the real dataset."""
267
+
268
+ if real_user_selector is None:
269
+ raise HTTPException(status_code=503, detail="Real user selector not available")
270
+
271
+ try:
272
+ return real_user_selector.get_dataset_summary()
273
+
274
+ except Exception as e:
275
+ raise HTTPException(status_code=500, detail=f"Error retrieving dataset summary: {str(e)}")
276
+
277
+
278
+ @app.post("/recommendations", response_model=RecommendationsResponse)
279
+ async def get_recommendations(request: RecommendationRequest):
280
+ """Get item recommendations for a user."""
281
+
282
+ if recommendation_engine is None:
283
+ raise HTTPException(status_code=503, detail="Recommendation engine not available")
284
+
285
+ try:
286
+ user_profile = request.user_profile
287
+
288
+ # Generate recommendations based on type
289
+ if request.recommendation_type == "collaborative":
290
+ recommendations = recommendation_engine.recommend_items_collaborative(
291
+ age=user_profile.age,
292
+ gender=user_profile.gender,
293
+ income=user_profile.income,
294
+ interaction_history=user_profile.interaction_history,
295
+ k=request.num_recommendations
296
+ )
297
+
298
+ elif request.recommendation_type == "content":
299
+ if not user_profile.interaction_history:
300
+ raise HTTPException(
301
+ status_code=400,
302
+ detail="Content-based recommendations require interaction history"
303
+ )
304
+
305
+ # Use most recent interaction as seed
306
+ seed_item = user_profile.interaction_history[-1]
307
+ recommendations = recommendation_engine.recommend_items_content_based(
308
+ seed_item_id=seed_item,
309
+ k=request.num_recommendations
310
+ )
311
+
312
+ elif request.recommendation_type == "hybrid":
313
+ recommendations = recommendation_engine.recommend_items_hybrid(
314
+ age=user_profile.age,
315
+ gender=user_profile.gender,
316
+ income=user_profile.income,
317
+ interaction_history=user_profile.interaction_history,
318
+ k=request.num_recommendations,
319
+ collaborative_weight=request.collaborative_weight
320
+ )
321
+
322
+ elif request.recommendation_type == "enhanced":
323
+ if enhanced_recommendation_engine is None:
324
+ raise HTTPException(status_code=503, detail="Enhanced recommendation engine not available")
325
+
326
+ # Check if it's the 128D engine or fallback
327
+ if hasattr(enhanced_recommendation_engine, 'recommend_items_enhanced'):
328
+ # 128D Enhanced engine
329
+ recommendations = enhanced_recommendation_engine.recommend_items_enhanced(
330
+ age=user_profile.age,
331
+ gender=user_profile.gender,
332
+ income=user_profile.income,
333
+ interaction_history=user_profile.interaction_history,
334
+ k=request.num_recommendations,
335
+ diversity_weight=0.3 if request.enable_diversity else 0.0,
336
+ category_boost=request.category_boost if request.enable_category_boost else 1.0
337
+ )
338
+ else:
339
+ # Fallback enhanced engine
340
+ recommendations = enhanced_recommendation_engine.recommend_items_enhanced_hybrid(
341
+ age=user_profile.age,
342
+ gender=user_profile.gender,
343
+ income=user_profile.income,
344
+ interaction_history=user_profile.interaction_history,
345
+ k=request.num_recommendations,
346
+ collaborative_weight=request.collaborative_weight,
347
+ category_boost=request.category_boost,
348
+ enable_category_boost=request.enable_category_boost,
349
+ enable_diversity=request.enable_diversity
350
+ )
351
+
352
+ elif request.recommendation_type == "enhanced_128d":
353
+ if enhanced_recommendation_engine is None or not hasattr(enhanced_recommendation_engine, 'recommend_items_enhanced'):
354
+ raise HTTPException(status_code=503, detail="Enhanced 128D recommendation engine not available")
355
+
356
+ recommendations = enhanced_recommendation_engine.recommend_items_enhanced(
357
+ age=user_profile.age,
358
+ gender=user_profile.gender,
359
+ income=user_profile.income,
360
+ interaction_history=user_profile.interaction_history,
361
+ k=request.num_recommendations,
362
+ diversity_weight=0.3 if request.enable_diversity else 0.0,
363
+ category_boost=request.category_boost if request.enable_category_boost else 1.0
364
+ )
365
+
366
+ elif request.recommendation_type == "category_focused":
367
+ if enhanced_recommendation_engine is None:
368
+ raise HTTPException(status_code=503, detail="Enhanced recommendation engine not available")
369
+
370
+ recommendations = enhanced_recommendation_engine.recommend_items_category_focused(
371
+ age=user_profile.age,
372
+ gender=user_profile.gender,
373
+ income=user_profile.income,
374
+ interaction_history=user_profile.interaction_history,
375
+ k=request.num_recommendations,
376
+ focus_percentage=0.8
377
+ )
378
+
379
+ else:
380
+ raise HTTPException(
381
+ status_code=400,
382
+ detail="Invalid recommendation_type. Must be 'collaborative', 'content', 'hybrid', 'enhanced', 'enhanced_128d', or 'category_focused'"
383
+ )
384
+
385
+ # Format response
386
+ formatted_recommendations = []
387
+ for item_id, score, item_info in recommendations:
388
+ formatted_recommendations.append(
389
+ RecommendationResponse(
390
+ item_id=item_id,
391
+ score=score,
392
+ item_info=ItemInfo(**item_info)
393
+ )
394
+ )
395
+
396
+ return RecommendationsResponse(
397
+ recommendations=formatted_recommendations,
398
+ user_profile=user_profile,
399
+ recommendation_type=request.recommendation_type,
400
+ total_count=len(formatted_recommendations),
401
+ training_approach="2-phase"
402
+ )
403
+
404
+ except Exception as e:
405
+ raise HTTPException(status_code=500, detail=f"Error generating recommendations: {str(e)}")
406
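The `collaborative_weight` parameter passed to `recommend_items_hybrid` above controls how collaborative and content-based scores are combined. A minimal sketch of such blending (illustrative only — not the engine's actual implementation) could look like:

```python
def blend_scores(collab, content, collaborative_weight=0.7):
    """Hypothetical hybrid blending: per-item weighted sum of
    collaborative and content-based scores."""
    w = collaborative_weight
    items = set(collab) | set(content)
    return {
        item: w * collab.get(item, 0.0) + (1 - w) * content.get(item, 0.0)
        for item in items
    }

# Item 101 scores high collaboratively; 102 and 103 via content similarity.
scores = blend_scores({101: 0.9, 102: 0.4}, {102: 0.8, 103: 0.6})
top = max(scores, key=scores.get)
```

With the default weight of 0.7, the collaborative signal dominates, so item 101 ranks first here.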
+
407
+
408
+ @app.post("/item-similarity", response_model=List[RecommendationResponse])
409
+ async def get_similar_items(request: ItemSimilarityRequest):
410
+ """Get items similar to a given item."""
411
+
412
+ if recommendation_engine is None:
413
+ raise HTTPException(status_code=503, detail="Recommendation engine not available")
414
+
415
+ try:
416
+ recommendations = recommendation_engine.recommend_items_content_based(
417
+ seed_item_id=request.item_id,
418
+ k=request.num_recommendations
419
+ )
420
+
421
+ formatted_recommendations = []
422
+ for item_id, score, item_info in recommendations:
423
+ formatted_recommendations.append(
424
+ RecommendationResponse(
425
+ item_id=item_id,
426
+ score=score,
427
+ item_info=ItemInfo(**item_info)
428
+ )
429
+ )
430
+
431
+ return formatted_recommendations
432
+
433
+ except Exception as e:
434
+ raise HTTPException(status_code=500, detail=f"Error finding similar items: {str(e)}")
435
+
436
+
437
+ @app.post("/predict-rating", response_model=RatingPredictionResponse)
438
+ async def predict_user_item_rating(request: RatingPredictionRequest):
439
+ """Predict rating for a user-item pair."""
440
+
441
+ if recommendation_engine is None:
442
+ raise HTTPException(status_code=503, detail="Recommendation engine not available")
443
+
444
+ try:
445
+ user_profile = request.user_profile
446
+
447
+ predicted_rating = recommendation_engine.predict_rating(
448
+ age=user_profile.age,
449
+ gender=user_profile.gender,
450
+ income=user_profile.income,
451
+ interaction_history=user_profile.interaction_history,
452
+ item_id=request.item_id
453
+ )
454
+
455
+ item_info = recommendation_engine._get_item_info(request.item_id)
456
+
457
+ return RatingPredictionResponse(
458
+ user_profile=user_profile,
459
+ item_id=request.item_id,
460
+ predicted_rating=predicted_rating,
461
+ item_info=ItemInfo(**item_info)
462
+ )
463
+
464
+ except Exception as e:
465
+ raise HTTPException(status_code=500, detail=f"Error predicting rating: {str(e)}")
466
+
467
+
468
+ @app.get("/items/{item_id}", response_model=ItemInfo)
469
+ async def get_item_info(item_id: int):
470
+ """Get information about a specific item."""
471
+
472
+ if recommendation_engine is None:
473
+ raise HTTPException(status_code=503, detail="Recommendation engine not available")
474
+
475
+ try:
476
+ item_info = recommendation_engine._get_item_info(item_id)
477
+ return ItemInfo(**item_info)
478
+
479
+ except Exception as e:
480
+ raise HTTPException(status_code=500, detail=f"Error retrieving item info: {str(e)}")
481
+
482
+
483
+ @app.get("/items")
484
+ async def get_sample_items(limit: int = 20):
485
+ """Get a sample of items for testing."""
486
+
487
+ if recommendation_engine is None:
488
+ raise HTTPException(status_code=503, detail="Recommendation engine not available")
489
+
490
+ try:
491
+ # Get sample items from the dataframe
492
+ sample_items = recommendation_engine.items_df.sample(n=min(limit, len(recommendation_engine.items_df)))
493
+
494
+ items = []
495
+ for _, row in sample_items.iterrows():
496
+ items.append({
497
+ "product_id": int(row['product_id']),
498
+ "category_id": int(row['category_id']),
499
+ "category_code": str(row['category_code']),
500
+ "brand": str(row['brand']) if pd.notna(row['brand']) else 'Unknown',
501
+ "price": float(row['price'])
502
+ })
503
+
504
+ return {"items": items, "total": len(items), "training_approach": "2-phase"}
505
+
506
+ except Exception as e:
507
+ raise HTTPException(status_code=500, detail=f"Error retrieving sample items: {str(e)}")
508
+
509
+
510
+ if __name__ == "__main__":
511
+ print("🚀 Starting 2-Phase Training Recommendation API...")
512
+ print("📊 Training approach: Pre-trained item tower + Joint fine-tuning")
513
+ print("🌐 Server will be available at: http://localhost:8000")
514
+ print("📚 API docs at: http://localhost:8000/docs")
515
+
516
+ uvicorn.run(
517
+ "api_2phase:app",
518
+ host="0.0.0.0",
519
+ port=8000,
520
+ reload=True
521
+ )
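With the server running on port 8000 as configured above, a client only needs to build URLs against the routes. A small helper for the `GET /real-users` endpoint (query parameter names come from the endpoint signature; the base URL is the default from the `uvicorn.run` call):

```python
from urllib.parse import urlencode

def real_users_url(base_url="http://localhost:8000", count=100, min_interactions=5):
    """Build the GET /real-users URL with its query parameters."""
    query = urlencode({"count": count, "min_interactions": min_interactions})
    return f"{base_url}/real-users?{query}"

url = real_users_url(count=20)
```

The resulting URL can then be fetched with any HTTP client (e.g. `urllib.request.urlopen(url)`) once the API is up.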
api_joint.py ADDED
@@ -0,0 +1,522 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ API for Single Joint Trained Recommendation System
4
+
5
+ This API serves recommendations from a model trained using the single joint approach:
6
+ - Both user and item towers trained simultaneously from scratch
7
+ - End-to-end optimization without pre-training phases
8
+
9
+ Usage:
10
+ python api_joint.py
11
+
12
+ Then access: http://localhost:8000
13
+ """
14
+
15
+ from fastapi import FastAPI, HTTPException
16
+ from fastapi.middleware.cors import CORSMiddleware
17
+ from pydantic import BaseModel
18
+ from typing import List, Optional, Dict, Any
19
+ import uvicorn
20
+ import os
21
+ import sys
22
+ import pandas as pd
23
+
24
+ # Add src to path for imports and set working directory
25
+ parent_dir = os.path.dirname(__file__)
26
+ sys.path.append(parent_dir)
27
+ os.chdir(parent_dir) # Change to project root directory
28
+
29
+ from src.inference.recommendation_engine import RecommendationEngine
30
+ from src.utils.real_user_selector import RealUserSelector
31
+
32
+ # Initialize FastAPI app
33
+ app = FastAPI(
34
+ title="Two-Tower Recommendation API (Single Joint Training)",
35
+ description="API for serving recommendations using a two-tower architecture trained with a single joint approach",
36
+ version="1.0.0-joint"
37
+ )
38
+
39
+ # Add CORS middleware
40
+ app.add_middleware(
41
+ CORSMiddleware,
42
+ allow_origins=["*"], # Configure appropriately for production
43
+ allow_credentials=True,
44
+ allow_methods=["*"],
45
+ allow_headers=["*"],
46
+ )
47
+
48
+ # Global instances
49
+ recommendation_engine = None
50
+ enhanced_recommendation_engine = None
51
+ real_user_selector = None
52
+
53
+
54
+ # Pydantic models for request/response
55
+ class UserProfile(BaseModel):
56
+ age: int
57
+ gender: str # "male" or "female"
58
+ income: float
59
+ interaction_history: Optional[List[int]] = []
60
+
61
+
62
+ class RecommendationRequest(BaseModel):
63
+ user_profile: UserProfile
64
+ num_recommendations: int = 10
65
+ recommendation_type: str = "hybrid" # "collaborative", "content", "hybrid", "enhanced", "enhanced_128d", "category_focused"
66
+ collaborative_weight: Optional[float] = 0.7
67
+ category_boost: Optional[float] = 1.5 # For enhanced recommendations
68
+ enable_category_boost: Optional[bool] = True
69
+ enable_diversity: Optional[bool] = True
70
+
71
+
72
+ class ItemSimilarityRequest(BaseModel):
73
+ item_id: int
74
+ num_recommendations: int = 10
75
+
76
+
77
+ class RatingPredictionRequest(BaseModel):
78
+ user_profile: UserProfile
79
+ item_id: int
80
+
81
+
82
+ class ItemInfo(BaseModel):
83
+ product_id: int
84
+ category_id: int
85
+ category_code: str
86
+ brand: str
87
+ price: float
88
+
89
+
90
+ class RecommendationResponse(BaseModel):
91
+ item_id: int
92
+ score: float
93
+ item_info: ItemInfo
94
+
95
+
96
+ class RecommendationsResponse(BaseModel):
97
+ recommendations: List[RecommendationResponse]
98
+ user_profile: UserProfile
99
+ recommendation_type: str
100
+ total_count: int
101
+ training_approach: str = "single-joint"
102
+
103
+
104
+ class RatingPredictionResponse(BaseModel):
105
+ user_profile: UserProfile
106
+ item_id: int
107
+ predicted_rating: float
108
+ item_info: ItemInfo
109
+
110
+
111
+ class RealUserProfile(BaseModel):
112
+ user_id: int
113
+ age: int
114
+ gender: str
115
+ income: int
116
+ interaction_history: List[int]
117
+ interaction_stats: Dict[str, int]
118
+ interaction_pattern: str
119
+ summary: str
120
+
121
+
122
+ class RealUsersResponse(BaseModel):
123
+ users: List[RealUserProfile]
124
+ total_count: int
125
+ dataset_summary: Dict[str, Any]
126
+
127
+
128
+ @app.on_event("startup")
129
+ async def startup_event():
130
+ """Initialize the recommendation engines and real user selector on startup."""
131
+ global recommendation_engine, enhanced_recommendation_engine, real_user_selector
132
+
133
+ print("🚀 Starting Single Joint Training API...")
134
+ print(" Training approach: End-to-end joint optimization from scratch")
135
+
136
+ try:
137
+ print("Loading single joint trained recommendation engine...")
138
+ recommendation_engine = RecommendationEngine()
139
+ print("✅ Single joint recommendation engine loaded successfully!")
140
+ except Exception as e:
141
+ print(f"❌ Error loading recommendation engine: {e}")
142
+ recommendation_engine = None
143
+
144
+ try:
145
+ print("Loading enhanced recommendation engine...")
146
+ # Try enhanced 128D engine first, fallback to regular enhanced
147
+ try:
148
+ from src.inference.enhanced_recommendation_engine_128d import Enhanced128DRecommendationEngine
149
+ enhanced_recommendation_engine = Enhanced128DRecommendationEngine()
150
+ print("✅ Using Enhanced 128D Recommendation Engine")
151
+ except Exception:  # 128D engine unavailable; fall back
152
+ from src.inference.enhanced_recommendation_engine import EnhancedRecommendationEngine
153
+ enhanced_recommendation_engine = EnhancedRecommendationEngine()
154
+ print("⚠️ Using fallback Enhanced Recommendation Engine")
155
+ print("Enhanced recommendation engine loaded successfully!")
156
+ except Exception as e:
157
+ print(f"Error loading enhanced recommendation engine: {e}")
158
+ enhanced_recommendation_engine = None
159
+
160
+ try:
161
+ print("Loading real user selector...")
162
+ real_user_selector = RealUserSelector()
163
+ print("Real user selector loaded successfully!")
164
+ except Exception as e:
165
+ print(f"Error loading real user selector: {e}")
166
+ real_user_selector = None
167
+
168
+ print("🎯 Single Joint API ready to serve recommendations!")
169
+
170
+
171
+ @app.get("/")
172
+ async def root():
173
+ """Root endpoint with API information."""
174
+ return {
175
+ "message": "Two-Tower Recommendation API (Single Joint Training)",
176
+ "version": "1.0.0-joint",
177
+ "training_approach": "single-joint (end-to-end optimization from scratch)",
178
+ "status": "active" if recommendation_engine is not None else "initialization_failed"
179
+ }
180
+
181
+
182
+ @app.get("/health")
183
+ async def health_check():
184
+ """Health check endpoint."""
185
+ return {
186
+ "status": "healthy" if recommendation_engine is not None else "unhealthy",
187
+ "engine_loaded": recommendation_engine is not None,
188
+ "training_approach": "single-joint"
189
+ }
190
+
191
+
192
+ @app.get("/model-info")
193
+ async def model_info():
194
+ """Get information about the loaded model."""
195
+ if recommendation_engine is None:
196
+ raise HTTPException(status_code=503, detail="Recommendation engine not available")
197
+
198
+ return {
199
+ "training_approach": "single-joint",
200
+ "description": "User and item towers trained simultaneously from scratch",
201
+ "advantages": [
202
+ "End-to-end optimization for better task alignment",
203
+ "No pre-training phase required",
204
+ "Faster overall training pipeline",
205
+ "Direct optimization for recommendation objectives"
206
+ ],
207
+ "embedding_dimension": 128,
208
+ "item_vocab_size": len(recommendation_engine.data_processor.item_vocab) if recommendation_engine.data_processor else "unknown",
209
+ "artifacts_loaded": {
210
+ "user_tower": "src/artifacts/user_tower_weights_best",
211
+ "item_tower_joint": "src/artifacts/item_tower_weights_finetuned_best",
212
+ "rating_model": "src/artifacts/rating_model_weights_best"
213
+ }
214
+ }
215
+
216
+
217
+ @app.get("/real-users", response_model=RealUsersResponse)
218
+ async def get_real_users(count: int = 100, min_interactions: int = 5):
219
+ """Get real user profiles with genuine interaction histories."""
220
+
221
+ if real_user_selector is None:
222
+ raise HTTPException(status_code=503, detail="Real user selector not available")
223
+
224
+ try:
225
+ # Get real user profiles
226
+ real_users = real_user_selector.get_real_users(n=count, min_interactions=min_interactions)
227
+
228
+ # Get dataset summary
229
+ dataset_summary = real_user_selector.get_dataset_summary()
230
+
231
+ # Format users for response
232
+ formatted_users = []
233
+ for user in real_users:
234
+ formatted_users.append(RealUserProfile(**user))
235
+
236
+ return RealUsersResponse(
237
+ users=formatted_users,
238
+ total_count=len(formatted_users),
239
+ dataset_summary=dataset_summary
240
+ )
241
+
242
+ except Exception as e:
243
+ raise HTTPException(status_code=500, detail=f"Error retrieving real users: {str(e)}")
244
+
245
+
246
+ @app.get("/real-users/{user_id}")
247
+ async def get_real_user_details(user_id: int):
248
+ """Get detailed interaction breakdown for a specific real user."""
249
+
250
+ if real_user_selector is None:
251
+ raise HTTPException(status_code=503, detail="Real user selector not available")
252
+
253
+ try:
254
+ user_details = real_user_selector.get_user_interaction_details(user_id)
255
+
256
+ if "error" in user_details:
257
+ raise HTTPException(status_code=404, detail=user_details["error"])
258
+
259
+ return user_details
260
+
261
+ except Exception as e:
262
+ raise HTTPException(status_code=500, detail=f"Error retrieving user details: {str(e)}")
263
+
264
+
265
+ @app.get("/dataset-summary")
266
+ async def get_dataset_summary():
267
+ """Get summary statistics of the real dataset."""
268
+
269
+ if real_user_selector is None:
270
+ raise HTTPException(status_code=503, detail="Real user selector not available")
271
+
272
+ try:
273
+ return real_user_selector.get_dataset_summary()
274
+
275
+ except Exception as e:
276
+ raise HTTPException(status_code=500, detail=f"Error retrieving dataset summary: {str(e)}")
277
+
278
+
279
+ @app.post("/recommendations", response_model=RecommendationsResponse)
280
+ async def get_recommendations(request: RecommendationRequest):
281
+ """Get item recommendations for a user."""
282
+
283
+ if recommendation_engine is None:
284
+ raise HTTPException(status_code=503, detail="Recommendation engine not available")
285
+
286
+ try:
287
+ user_profile = request.user_profile
288
+
289
+ # Generate recommendations based on type
290
+ if request.recommendation_type == "collaborative":
291
+ recommendations = recommendation_engine.recommend_items_collaborative(
292
+ age=user_profile.age,
293
+ gender=user_profile.gender,
294
+ income=user_profile.income,
295
+ interaction_history=user_profile.interaction_history,
296
+ k=request.num_recommendations
297
+ )
298
+
299
+ elif request.recommendation_type == "content":
300
+ if not user_profile.interaction_history:
301
+ raise HTTPException(
302
+ status_code=400,
303
+ detail="Content-based recommendations require interaction history"
304
+ )
305
+
306
+ # Use most recent interaction as seed
307
+ seed_item = user_profile.interaction_history[-1]
308
+ recommendations = recommendation_engine.recommend_items_content_based(
309
+ seed_item_id=seed_item,
310
+ k=request.num_recommendations
311
+ )
312
+
313
+ elif request.recommendation_type == "hybrid":
314
+ recommendations = recommendation_engine.recommend_items_hybrid(
315
+ age=user_profile.age,
316
+ gender=user_profile.gender,
317
+ income=user_profile.income,
318
+ interaction_history=user_profile.interaction_history,
319
+ k=request.num_recommendations,
320
+ collaborative_weight=request.collaborative_weight
321
+ )
322
+
323
+ elif request.recommendation_type == "enhanced":
324
+ if enhanced_recommendation_engine is None:
325
+ raise HTTPException(status_code=503, detail="Enhanced recommendation engine not available")
326
+
327
+ # Check if it's the 128D engine or fallback
328
+ if hasattr(enhanced_recommendation_engine, 'recommend_items_enhanced'):
329
+ # 128D Enhanced engine
330
+ recommendations = enhanced_recommendation_engine.recommend_items_enhanced(
331
+ age=user_profile.age,
332
+ gender=user_profile.gender,
333
+ income=user_profile.income,
334
+ interaction_history=user_profile.interaction_history,
335
+ k=request.num_recommendations,
336
+ diversity_weight=0.3 if request.enable_diversity else 0.0,
337
+ category_boost=request.category_boost if request.enable_category_boost else 1.0
338
+ )
339
+ else:
340
+ # Fallback enhanced engine
341
+ recommendations = enhanced_recommendation_engine.recommend_items_enhanced_hybrid(
342
+ age=user_profile.age,
343
+ gender=user_profile.gender,
344
+ income=user_profile.income,
345
+ interaction_history=user_profile.interaction_history,
346
+ k=request.num_recommendations,
347
+ collaborative_weight=request.collaborative_weight,
348
+ category_boost=request.category_boost,
349
+ enable_category_boost=request.enable_category_boost,
350
+ enable_diversity=request.enable_diversity
351
+ )
352
+
353
+ elif request.recommendation_type == "enhanced_128d":
354
+ if enhanced_recommendation_engine is None or not hasattr(enhanced_recommendation_engine, 'recommend_items_enhanced'):
355
+ raise HTTPException(status_code=503, detail="Enhanced 128D recommendation engine not available")
356
+
357
+ recommendations = enhanced_recommendation_engine.recommend_items_enhanced(
358
+ age=user_profile.age,
359
+ gender=user_profile.gender,
360
+ income=user_profile.income,
361
+ interaction_history=user_profile.interaction_history,
362
+ k=request.num_recommendations,
363
+ diversity_weight=0.3 if request.enable_diversity else 0.0,
364
+ category_boost=request.category_boost if request.enable_category_boost else 1.0
365
+ )
366
+
367
+ elif request.recommendation_type == "category_focused":
368
+ if enhanced_recommendation_engine is None:
369
+            raise HTTPException(status_code=503, detail="Enhanced recommendation engine not available")
+
+            recommendations = enhanced_recommendation_engine.recommend_items_category_focused(
+                age=user_profile.age,
+                gender=user_profile.gender,
+                income=user_profile.income,
+                interaction_history=user_profile.interaction_history,
+                k=request.num_recommendations,
+                focus_percentage=0.8
+            )
+
+        else:
+            raise HTTPException(
+                status_code=400,
+                detail="Invalid recommendation_type. Must be 'collaborative', 'content', 'hybrid', 'enhanced', 'enhanced_128d', or 'category_focused'"
+            )
+
+        # Format response
+        formatted_recommendations = []
+        for item_id, score, item_info in recommendations:
+            formatted_recommendations.append(
+                RecommendationResponse(
+                    item_id=item_id,
+                    score=score,
+                    item_info=ItemInfo(**item_info)
+                )
+            )
+
+        return RecommendationsResponse(
+            recommendations=formatted_recommendations,
+            user_profile=user_profile,
+            recommendation_type=request.recommendation_type,
+            total_count=len(formatted_recommendations),
+            training_approach="single-joint"
+        )
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error generating recommendations: {str(e)}")
+
+
+@app.post("/item-similarity", response_model=List[RecommendationResponse])
+async def get_similar_items(request: ItemSimilarityRequest):
+    """Get items similar to a given item."""
+
+    if recommendation_engine is None:
+        raise HTTPException(status_code=503, detail="Recommendation engine not available")
+
+    try:
+        recommendations = recommendation_engine.recommend_items_content_based(
+            seed_item_id=request.item_id,
+            k=request.num_recommendations
+        )
+
+        formatted_recommendations = []
+        for item_id, score, item_info in recommendations:
+            formatted_recommendations.append(
+                RecommendationResponse(
+                    item_id=item_id,
+                    score=score,
+                    item_info=ItemInfo(**item_info)
+                )
+            )
+
+        return formatted_recommendations
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error finding similar items: {str(e)}")
+
+
+@app.post("/predict-rating", response_model=RatingPredictionResponse)
+async def predict_user_item_rating(request: RatingPredictionRequest):
+    """Predict rating for a user-item pair."""
+
+    if recommendation_engine is None:
+        raise HTTPException(status_code=503, detail="Recommendation engine not available")
+
+    try:
+        user_profile = request.user_profile
+
+        predicted_rating = recommendation_engine.predict_rating(
+            age=user_profile.age,
+            gender=user_profile.gender,
+            income=user_profile.income,
+            interaction_history=user_profile.interaction_history,
+            item_id=request.item_id
+        )
+
+        item_info = recommendation_engine._get_item_info(request.item_id)
+
+        return RatingPredictionResponse(
+            user_profile=user_profile,
+            item_id=request.item_id,
+            predicted_rating=predicted_rating,
+            item_info=ItemInfo(**item_info)
+        )
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error predicting rating: {str(e)}")
+
+
+@app.get("/items/{item_id}", response_model=ItemInfo)
+async def get_item_info(item_id: int):
+    """Get information about a specific item."""
+
+    if recommendation_engine is None:
+        raise HTTPException(status_code=503, detail="Recommendation engine not available")
+
+    try:
+        item_info = recommendation_engine._get_item_info(item_id)
+        return ItemInfo(**item_info)
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error retrieving item info: {str(e)}")
+
+
+@app.get("/items")
+async def get_sample_items(limit: int = 20):
+    """Get a sample of items for testing."""
+
+    if recommendation_engine is None:
+        raise HTTPException(status_code=503, detail="Recommendation engine not available")
+
+    try:
+        # Get sample items from the dataframe
+        sample_items = recommendation_engine.items_df.sample(n=min(limit, len(recommendation_engine.items_df)))
+
+        items = []
+        for _, row in sample_items.iterrows():
+            items.append({
+                "product_id": int(row['product_id']),
+                "category_id": int(row['category_id']),
+                "category_code": str(row['category_code']),
+                "brand": str(row['brand']) if pd.notna(row['brand']) else 'Unknown',
+                "price": float(row['price'])
+            })
+
+        return {"items": items, "total": len(items), "training_approach": "single-joint"}
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error retrieving sample items: {str(e)}")
+
+
+if __name__ == "__main__":
+    print("🚀 Starting Single Joint Training Recommendation API...")
+    print("⚡ Training approach: End-to-end joint optimization from scratch")
+    print("🌐 Server will be available at: http://localhost:8000")
+    print("📚 API docs at: http://localhost:8000/docs")
+
+    uvicorn.run(
+        "api_joint:app",
+        host="0.0.0.0",
+        port=8000,
+        reload=True
+    )
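For a local smoke test of the new endpoints above, a minimal client sketch can assemble the request body the handlers read (`user_profile.age`, `gender`, `income`, `interaction_history`, plus `num_recommendations` and `recommendation_type`). The exact Pydantic request models live in parts of api_joint.py outside this hunk, so the payload shape below is an assumption inferred from the handler code, not the confirmed schema.

```python
import json

def build_recommendation_request(age, gender, income, history,
                                 k=10, rec_type="category_focused"):
    """Assemble a hypothetical /recommendations request body.

    Field names mirror what the diff's handlers access; the real
    schema is defined elsewhere in api_joint.py.
    """
    return {
        "user_profile": {
            "age": age,
            "gender": gender,
            "income": income,
            "interaction_history": history,
        },
        "num_recommendations": k,
        "recommendation_type": rec_type,
    }

if __name__ == "__main__":
    payload = build_recommendation_request(34, "female", 72000, [], k=5)
    print(json.dumps(payload, indent=2))
    # Send with e.g.:
    # requests.post("http://localhost:8000/recommendations", json=payload)
```

With the server started via `python api_joint.py`, the same payload can also be tried interactively from the auto-generated docs at http://localhost:8000/docs.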
frontend/src/App.css CHANGED
@@ -1,33 +1,1048 @@
 .App {
   text-align: center;
 }
 
-.App-logo {
-  height: 40vmin;
-  pointer-events: none;
 }
 
-@media (prefers-reduced-motion: no-preference) {
-  .App-logo {
-    animation: App-logo-spin infinite 20s linear;
   }
 }
 
-.App-header {
-  background-color: #282c34;
   padding: 20px;
   color: white;
 }
 
-.App-link {
-  color: #61dafb;
 }
 
-@keyframes App-logo-spin {
-  from {
-    transform: rotate(0deg);
 }
-  to {
-    transform: rotate(360deg);
 }
 }
 .App {
   text-align: center;
+  font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', sans-serif;
 }
 
+.container {
+  max-width: 1200px;
+  margin: 0 auto;
+  padding: 20px;
+  text-align: left;
+}
+
+.header {
+  text-align: center;
+  margin-bottom: 30px;
+  padding: 20px;
+  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+  color: white;
+  border-radius: 10px;
+}
+
+.header h1 {
+  margin: 0 0 10px 0;
+  font-size: 2.5rem;
+}
+
+.header p {
+  margin: 0;
+  opacity: 0.9;
+  font-size: 1.1rem;
+}
+
+.dataset-info {
+  margin-top: 15px;
+  padding: 10px;
+  background: rgba(255, 255, 255, 0.2);
+  border-radius: 5px;
+  font-size: 0.95rem;
+}
+
+/* Real User Selector */
+.real-user-selector {
+  background: #e8f5e8;
+  padding: 25px;
+  border-radius: 10px;
+  margin-bottom: 30px;
+  border: 1px solid #c3e6c3;
+}
+
+.real-user-selector h2 {
+  margin-top: 0;
+  color: #2d5a2d;
+  border-bottom: 2px solid #28a745;
+  padding-bottom: 10px;
+}
+
+.user-selector-controls {
+  display: flex;
+  align-items: center;
+  gap: 15px;
+  margin-bottom: 20px;
+}
+
+.user-selector-controls label {
+  font-weight: bold;
+  min-width: 200px;
+}
+
+.user-selector-controls select {
+  flex: 1;
+  padding: 8px 12px;
+  border: 1px solid #28a745;
+  border-radius: 5px;
+  font-size: 14px;
+}
+
+.selected-real-user {
+  background: rgba(255, 255, 255, 0.8);
+  padding: 20px;
+  border-radius: 8px;
+  border: 1px solid #28a745;
+}
+
+.real-user-stats {
+  display: grid;
+  gap: 12px;
+}
+
+.user-stat {
+  display: flex;
+  justify-content: space-between;
+  padding: 8px 0;
+  border-bottom: 1px solid #ddd;
+}
+
+.stat-label {
+  font-weight: bold;
+  color: #495057;
+}
+
+.stat-value {
+  color: #28a745;
+  font-weight: 600;
+}
+
+.real-interaction-summary {
+  display: flex;
+  gap: 20px;
+  justify-content: center;
+  margin: 20px 0;
+}
+
+.real-history-info {
+  background: rgba(255, 255, 255, 0.8);
+  padding: 15px;
+  border-radius: 8px;
+  margin-top: 15px;
+}
+
+.real-history-info p {
+  margin: 8px 0;
+}
+
+/* Expand Interactions Button */
+.expand-interactions-btn {
+  background: #007bff;
+  color: white;
+  border: 1px solid #007bff;
+  border-radius: 5px;
+  padding: 10px 20px;
+  margin-top: 15px;
+  cursor: pointer;
+  font-size: 14px;
+  transition: background-color 0.3s;
+}
+
+.expand-interactions-btn:hover:not(:disabled) {
+  background: #0056b3;
+}
+
+.expand-interactions-btn:disabled {
+  opacity: 0.6;
+  cursor: not-allowed;
+}
+
+/* User Interactions Timeline */
+.user-interactions-timeline {
+  margin-top: 20px;
+  padding: 20px;
+  background: rgba(255, 255, 255, 0.95);
+  border-radius: 8px;
+  border: 1px solid #007bff;
+}
+
+.user-interactions-timeline h4 {
+  margin-top: 0;
+  color: #007bff;
+  border-bottom: 2px solid #007bff;
+  padding-bottom: 8px;
+}
+
+.user-interactions-timeline h5 {
+  color: #495057;
+  margin: 15px 0 10px 0;
+}
+
+.timeline-stats {
+  display: flex;
+  flex-wrap: wrap;
+  gap: 20px;
+  margin: 15px 0;
+  padding: 15px;
+  background: #f8f9fa;
+  border-radius: 5px;
+  font-size: 14px;
+}
+
+.timeline-stats span {
+  color: #495057;
+}
+
+/* Interactions List */
+.interactions-list {
+  max-height: 400px;
+  overflow-y: auto;
+  border: 1px solid #e9ecef;
+  border-radius: 5px;
+  background: white;
+}
+
+.interaction-timeline-item {
+  display: flex;
+  align-items: center;
+  padding: 12px 15px;
+  border-bottom: 1px solid #e9ecef;
+  transition: background-color 0.2s;
+}
+
+.interaction-timeline-item:hover {
+  background: #f8f9fa;
+}
+
+.interaction-timeline-item:last-child {
+  border-bottom: none;
+}
+
+.interaction-timeline-time {
+  flex: 0 0 180px;
+  font-size: 12px;
+  color: #6c757d;
+  font-family: monospace;
+}
+
+.interaction-timeline-content {
+  flex: 1;
+  display: flex;
+  flex-direction: column;
+  gap: 5px;
+}
+
+.interaction-main-info {
+  display: flex;
+  align-items: center;
+  gap: 10px;
+}
+
+.interaction-item-details {
+  display: flex;
+  align-items: center;
+  gap: 15px;
+  padding-left: 20px;
+  font-size: 13px;
+}
+
+.interaction-item-id {
+  color: #495057;
+  font-size: 13px;
+  font-weight: 500;
+}
+
+.interaction-icon {
+  font-size: 16px;
+  min-width: 20px;
+}
+
+.item-brand {
+  color: #007bff;
+  font-size: 14px;
+  min-width: 120px;
+}
+
+.item-category {
+  color: #6c757d;
+  font-size: 12px;
+  background: #f8f9fa;
+  padding: 2px 6px;
+  border-radius: 3px;
+  border: 1px solid #e9ecef;
+  min-width: 150px;
+  text-align: center;
+}
+
+.item-price {
+  color: #28a745;
+  font-weight: 600;
+  font-size: 13px;
+  min-width: 70px;
+  text-align: right;
+}
+
+/* Interaction type badges in timeline */
+.interaction-timeline-content .interaction-type {
+  font-size: 11px;
+  font-weight: bold;
+  padding: 3px 8px;
+  border-radius: 12px;
+  text-transform: uppercase;
+  letter-spacing: 0.5px;
+}
+
+.interaction-timeline-content .interaction-type.view {
+  background: #d1ecf1;
+  color: #0c5460;
+}
+
+.interaction-timeline-content .interaction-type.cart {
+  background: #fff3cd;
+  color: #856404;
+}
+
+.interaction-timeline-content .interaction-type.purchase {
+  background: #d4edda;
+  color: #155724;
+}
+
+/* Category Analysis Columns */
+.category-analysis {
+  margin-top: 20px;
+  padding: 20px;
+  background: rgba(255, 255, 255, 0.95);
+  border-radius: 8px;
+  border: 1px solid #6c757d;
+}
+
+.category-analysis h4 {
+  margin-top: 0;
+  color: #495057;
+  border-bottom: 2px solid #6c757d;
+  padding-bottom: 8px;
+}
+
+.category-columns {
+  display: grid;
+  grid-template-columns: 1fr 1fr;
+  gap: 30px;
+  margin: 20px 0;
+}
+
+.category-column {
+  background: #f8f9fa;
+  padding: 15px;
+  border-radius: 8px;
+  border: 1px solid #dee2e6;
+}
+
+.category-column h5 {
+  margin-top: 0;
+  margin-bottom: 15px;
+  color: #495057;
+  font-size: 16px;
+  padding-bottom: 5px;
+  border-bottom: 1px solid #dee2e6;
+}
+
+.category-percentages {
+  display: flex;
+  flex-direction: column;
+  gap: 8px;
+}
+
+.category-item {
+  display: flex;
+  align-items: center;
+  gap: 10px;
+  padding: 6px 0;
+}
+
+.category-bar-container {
+  flex: 0 0 80px;
+  height: 16px;
+  background: #e9ecef;
+  border-radius: 8px;
+  overflow: hidden;
+  position: relative;
+}
+
+.category-bar {
+  height: 100%;
+  border-radius: 8px;
+  transition: width 0.3s ease;
+}
+
+.category-bar.user-category {
+  background: linear-gradient(90deg, #007bff, #0056b3);
+}
+
+.category-bar.rec-category-matched {
+  background: linear-gradient(90deg, #28a745, #1e7e34);
+}
+
+.category-bar.rec-category-new {
+  background: linear-gradient(90deg, #ffc107, #e0a800);
 }
 
+.category-label {
+  flex: 1;
+  font-size: 12px;
+  color: #495057;
+  text-overflow: ellipsis;
+  overflow: hidden;
+  white-space: nowrap;
+}
+
+.category-percent {
+  flex: 0 0 40px;
+  font-size: 11px;
+  font-weight: bold;
+  text-align: right;
+  color: #495057;
+}
+
+.match-indicator {
+  color: #28a745;
+  font-weight: bold;
+  font-size: 14px;
+}
+
+.category-match-summary {
+  margin-top: 15px;
+  padding: 15px;
+  background: #e7f3ff;
+  border-radius: 6px;
+  border-left: 4px solid #007bff;
+}
+
+.category-match-summary p {
+  margin: 0;
+  font-size: 14px;
+  color: #495057;
+}
+
+.match-legend {
+  display: flex;
+  gap: 20px;
+  margin-top: 8px;
+  font-size: 12px;
+}
+
+.legend-item {
+  display: flex;
+  align-items: center;
+  gap: 6px;
+}
+
+.legend-dot {
+  width: 12px;
+  height: 12px;
+  border-radius: 50%;
+}
+
+.legend-dot.matched {
+  background: #28a745;
+}
+
+.legend-dot.new {
+  background: #ffc107;
+}
+
+/* Responsive design for smaller screens */
+@media (max-width: 768px) {
+  .category-columns {
+    grid-template-columns: 1fr;
+    gap: 20px;
   }
+
+  .category-bar-container {
+    flex: 0 0 60px;
+  }
+
+  .match-legend {
+    flex-direction: column;
+    gap: 8px;
+  }
+}
+
+/* Pagination Styles */
+.pagination-info {
+  margin: 20px 0;
+  text-align: center;
 }
 
+.pagination-info p {
+  margin: 0 0 15px 0;
+  color: #6c757d;
+  font-size: 14px;
+}
+
+.pagination-controls {
+  display: flex;
+  align-items: center;
+  justify-content: center;
+  gap: 10px;
+  margin: 15px 0;
+}
+
+.pagination-btn {
+  background: #007bff;
+  color: white;
+  border: 1px solid #007bff;
+  border-radius: 5px;
+  padding: 8px 16px;
+  cursor: pointer;
+  font-size: 14px;
+  transition: background-color 0.3s;
+}
+
+.pagination-btn:hover:not(:disabled) {
+  background: #0056b3;
+  border-color: #0056b3;
+}
+
+.pagination-btn:disabled {
+  background: #6c757d;
+  border-color: #6c757d;
+  cursor: not-allowed;
+  opacity: 0.65;
+}
+
+.page-numbers {
+  display: flex;
+  align-items: center;
+  gap: 5px;
+}
+
+.page-number {
+  background: #f8f9fa;
+  color: #007bff;
+  border: 1px solid #dee2e6;
+  border-radius: 3px;
+  padding: 6px 10px;
+  cursor: pointer;
+  font-size: 14px;
+  transition: all 0.3s;
+  min-width: 35px;
+}
+
+.page-number:hover {
+  background: #e9ecef;
+  border-color: #007bff;
+}
+
+.page-number.active {
+  background: #007bff;
+  color: white;
+  border-color: #007bff;
+}
+
+.pagination-ellipsis {
+  color: #6c757d;
+  padding: 0 5px;
+}
+
+.bottom-pagination {
+  border-top: 1px solid #dee2e6;
+  padding-top: 20px;
+  margin-top: 20px;
+}
+
+.page-indicator {
+  color: #6c757d;
+  font-size: 14px;
+  margin: 0 15px;
+}
+
+/* Responsive pagination */
+@media (max-width: 768px) {
+  .pagination-controls {
+    flex-wrap: wrap;
+    gap: 8px;
+  }
+
+  .page-numbers {
+    order: 3;
+    width: 100%;
+    justify-content: center;
+    margin-top: 10px;
+  }
+
+  .pagination-btn {
+    padding: 6px 12px;
+    font-size: 13px;
+  }
+
+  .page-number {
+    padding: 5px 8px;
+    min-width: 30px;
+    font-size: 13px;
+  }
+}
+
+/* User Profile Form */
+.user-profile-form {
+  background: #f8f9fa;
+  padding: 25px;
+  border-radius: 10px;
+  margin-bottom: 30px;
+  border: 1px solid #e9ecef;
+}
+
+.user-profile-form h2 {
+  margin-top: 0;
+  color: #495057;
+  border-bottom: 2px solid #007bff;
+  padding-bottom: 10px;
+}
+
+.form-row {
+  display: flex;
+  gap: 20px;
+  flex-wrap: wrap;
+}
+
+.form-group {
+  flex: 1;
+  min-width: 200px;
+}
+
+.form-group label {
+  display: block;
+  margin-bottom: 5px;
+  font-weight: 600;
+  color: #495057;
+}
+
+.form-group input,
+.form-group select {
+  width: 100%;
+  padding: 10px;
+  border: 1px solid #ced4da;
+  border-radius: 5px;
+  font-size: 14px;
+}
+
+.form-group input:focus,
+.form-group select:focus {
+  outline: none;
+  border-color: #007bff;
+  box-shadow: 0 0 0 0.2rem rgba(0, 123, 255, 0.25);
+}
+
+/* Interaction Patterns */
+.interaction-patterns {
+  background: #fff;
+  padding: 25px;
+  border-radius: 10px;
+  margin-bottom: 30px;
+  border: 1px solid #e9ecef;
+  box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+}
+
+.interaction-patterns h2 {
+  margin-top: 0;
+  color: #495057;
+  border-bottom: 2px solid #28a745;
+  padding-bottom: 10px;
+}
+
+.pattern-buttons {
+  display: flex;
+  gap: 15px;
+  margin: 20px 0;
+  flex-wrap: wrap;
+}
+
+.pattern-btn {
+  padding: 15px 20px;
+  border: 2px solid #007bff;
+  background: white;
+  color: #007bff;
+  border-radius: 8px;
+  cursor: pointer;
+  transition: all 0.3s ease;
+  font-size: 14px;
+  text-align: center;
+  min-width: 120px;
+}
+
+.pattern-btn:hover {
+  background: #007bff;
+  color: white;
+  transform: translateY(-2px);
+  box-shadow: 0 4px 8px rgba(0,123,255,0.3);
+}
+
+.pattern-btn.active {
+  background: #007bff;
+  color: white;
+  box-shadow: 0 4px 8px rgba(0,123,255,0.3);
+}
+
+.pattern-btn small {
+  opacity: 0.8;
+  font-size: 12px;
+}
+
+.pattern-summary {
+  display: flex;
+  gap: 20px;
+  margin: 20px 0;
+  flex-wrap: wrap;
+}
+
+.summary-card {
+  background: white;
+  border: 1px solid #e9ecef;
+  border-radius: 8px;
   padding: 20px;
+  text-align: center;
+  flex: 1;
+  min-width: 100px;
+  box-shadow: 0 2px 4px rgba(0,0,0,0.05);
+}
+
+.summary-card.views {
+  border-left: 4px solid #17a2b8;
+}
+
+.summary-card.carts {
+  border-left: 4px solid #ffc107;
+}
+
+.summary-card.purchases {
+  border-left: 4px solid #28a745;
+}
+
+.summary-number {
+  font-size: 2rem;
+  font-weight: bold;
+  margin-bottom: 5px;
+}
+
+.summary-label {
+  color: #6c757d;
+  font-size: 14px;
+}
+
+/* Interaction History */
+.interaction-history {
+  margin-top: 25px;
+}
+
+.interaction-history h3 {
+  color: #495057;
+  margin-bottom: 15px;
+}
+
+.interaction-item {
+  background: white;
+  border: 1px solid #e9ecef;
+  border-radius: 8px;
+  margin-bottom: 10px;
+  padding: 15px;
+  display: flex;
+  justify-content: space-between;
+  align-items: center;
+  transition: all 0.2s ease;
+}
+
+.interaction-item:hover {
+  box-shadow: 0 2px 8px rgba(0,0,0,0.1);
+  transform: translateY(-1px);
+}
+
+.interaction-main {
+  display: flex;
+  align-items: center;
+  gap: 15px;
+  flex: 1;
+}
+
+.interaction-type {
+  padding: 4px 12px;
+  border-radius: 20px;
+  font-size: 12px;
+  font-weight: 600;
+  text-transform: uppercase;
+  min-width: 80px;
+  text-align: center;
+}
+
+.interaction-type.view {
+  background: #cce7ff;
+  color: #0066cc;
+}
+
+.interaction-type.cart {
+  background: #fff3cd;
+  color: #856404;
+}
+
+.interaction-type.purchase {
+  background: #d4edda;
+  color: #155724;
+}
+
+.interaction-details {
+  flex: 1;
+  font-size: 14px;
+}
+
+.category-tag {
+  background: #e9ecef;
+  color: #495057;
+  padding: 2px 8px;
+  border-radius: 12px;
+  font-size: 12px;
+  font-weight: 500;
+  display: inline-block;
+  margin: 0 5px;
+}
+
+.interaction-expand {
+  padding: 6px 12px;
+  background: #f8f9fa;
+  border: 1px solid #dee2e6;
+  border-radius: 4px;
+  cursor: pointer;
+  font-size: 12px;
+  color: #495057;
+  transition: all 0.2s ease;
+}
+
+.interaction-expand:hover {
+  background: #e9ecef;
+  border-color: #adb5bd;
+}
+
+.interaction-expanded {
+  background: #f8f9fa;
+  border: 1px solid #dee2e6;
+  border-radius: 8px;
+  padding: 20px;
+  margin-top: 10px;
+}
+
+.interaction-meta {
+  display: grid;
+  grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
+  gap: 15px;
+}
+
+.interaction-meta-item {
+  display: flex;
+  justify-content: space-between;
+  padding: 8px 0;
+  border-bottom: 1px solid #e9ecef;
+}
+
+.interaction-meta-label {
+  font-weight: 600;
+  color: #495057;
+}
+
+.interaction-meta-value {
+  color: #6c757d;
+  font-family: monospace;
+}
+
+/* Recommendation Controls */
+.recommendation-controls {
+  background: #fff;
+  padding: 25px;
+  border-radius: 10px;
+  margin-bottom: 30px;
+  border: 1px solid #e9ecef;
+  box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+}
+
+.recommendation-controls h2 {
+  margin-top: 0;
+  color: #495057;
+  border-bottom: 2px solid #dc3545;
+  padding-bottom: 10px;
+}
+
+.controls-row {
+  display: flex;
+  gap: 20px;
+  align-items: end;
+  flex-wrap: wrap;
+}
+
+.btn {
+  padding: 12px 24px;
+  border: none;
+  border-radius: 6px;
+  cursor: pointer;
+  font-size: 14px;
+  font-weight: 600;
+  transition: all 0.3s ease;
+  text-decoration: none;
+  display: inline-block;
+  text-align: center;
+}
+
+.btn-primary {
+  background: #007bff;
   color: white;
 }
 
+.btn-primary:hover:not(:disabled) {
+  background: #0056b3;
+  transform: translateY(-1px);
+  box-shadow: 0 4px 8px rgba(0,123,255,0.3);
+}
+
+.btn:disabled {
+  background: #6c757d;
+  cursor: not-allowed;
+  opacity: 0.65;
+}
+
+/* Recommendations */
+.recommendations {
+  background: #fff;
+  padding: 25px;
+  border-radius: 10px;
+  border: 1px solid #e9ecef;
+  box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+}
+
+.recommendations h2 {
+  margin-top: 0;
+  color: #495057;
+  border-bottom: 2px solid #6f42c1;
+  padding-bottom: 10px;
 }
 
+.stats {
+  background: #f8f9fa;
+  padding: 15px;
+  border-radius: 6px;
+  margin-bottom: 20px;
+  font-size: 14px;
+  color: #495057;
+}
+
+.recommendations-grid {
+  display: grid;
+  grid-template-columns: repeat(auto-fill, minmax(300px, 1fr));
+  gap: 20px;
+  margin-top: 20px;
+}
+
+.recommendation-card {
+  background: white;
+  border: 1px solid #e9ecef;
+  border-radius: 8px;
+  padding: 20px;
+  transition: all 0.2s ease;
+  box-shadow: 0 2px 4px rgba(0,0,0,0.05);
+}
+
+.recommendation-card:hover {
+  transform: translateY(-2px);
+  box-shadow: 0 4px 12px rgba(0,0,0,0.1);
+  border-color: #007bff;
+}
+
+.card-header {
+  display: flex;
+  justify-content: space-between;
+  align-items: center;
+  margin-bottom: 15px;
+  padding-bottom: 10px;
+  border-bottom: 1px solid #e9ecef;
+}
+
+.item-id {
+  font-weight: 600;
+  color: #495057;
+}
+
+.score {
+  background: #007bff;
+  color: white;
+  padding: 4px 8px;
+  border-radius: 12px;
+  font-size: 12px;
+  font-weight: 600;
+}
+
+.item-details p {
+  margin: 8px 0;
+}
+
+.brand {
+  font-weight: 600;
+  color: #495057;
+  font-size: 16px;
+}
+
+.price {
+  color: #28a745;
+  font-weight: 600;
+  font-size: 18px;
+}
+
+.category {
+  color: #6c757d;
+  font-size: 14px;
+  background: #f8f9fa;
+  padding: 4px 8px;
+  border-radius: 4px;
+  display: inline-block;
+}
+
+/* Error and Loading States */
+.error {
+  background: #f8d7da;
+  color: #721c24;
+  padding: 15px;
+  border-radius: 6px;
+  border: 1px solid #f5c6cb;
+  margin-bottom: 20px;
+}
+
+.loading {
+  text-align: center;
+  padding: 40px;
+  background: #f8f9fa;
+  border-radius: 10px;
+  border: 1px solid #e9ecef;
+}
+
+.loading h3 {
+  color: #495057;
+  margin-bottom: 10px;
+}
+
+.loading p {
+  color: #6c757d;
+  margin: 0;
+}
+
+/* Responsive Design */
+@media (max-width: 768px) {
+  .container {
+    padding: 10px;
+  }
+
+  .form-row,
+  .controls-row {
+    flex-direction: column;
+  }
+
+  .pattern-buttons {
+    flex-direction: column;
+  }
+
+  .pattern-btn {
+    width: 100%;
+  }
+
+  .recommendations-grid {
+    grid-template-columns: 1fr;
+  }
+
+  .interaction-main {
+    flex-direction: column;
+    align-items: flex-start;
+    gap: 10px;
   }
+
+  .interaction-meta {
+    grid-template-columns: 1fr;
   }
 }
frontend/src/App.js CHANGED
@@ -22,21 +22,38 @@ function App() {
22
  });
23
 
24
  const [recommendationType, setRecommendationType] = useState('hybrid');
25
- const [numRecommendations, setNumRecommendations] = useState(10);
26
  const [collaborativeWeight, setCollaborativeWeight] = useState(0.7);
27
 
28
  const [recommendations, setRecommendations] = useState([]);
29
  const [loading, setLoading] = useState(false);
30
  const [error, setError] = useState(null);
31
 
 
 
 
 
32
  const [sampleItems, setSampleItems] = useState([]);
33
  const [interactions, setInteractions] = useState([]);
34
  const [expandedInteraction, setExpandedInteraction] = useState(null);
35
  const [selectedPattern, setSelectedPattern] = useState(null);
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- // Load sample items on component mount
38
  useEffect(() => {
39
  fetchSampleItems();
 
 
40
  }, []);
41
 
42
  const fetchSampleItems = async () => {
@@ -48,6 +65,30 @@ function App() {
48
  }
49
  };
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  const handleProfileChange = (field, value) => {
52
  setUserProfile(prev => ({
53
  ...prev,
@@ -55,6 +96,41 @@ function App() {
55
  }));
56
  };
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  const generateTimestamp = (baseTime, offsetHours) => {
59
  const timestamp = new Date(baseTime.getTime() - (offsetHours * 60 * 60 * 1000));
60
  return timestamp.toISOString().replace('T', ' ').slice(0, 19);
@@ -170,9 +246,77 @@ function App() {
170
  }));
171
  };
172
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  const getRecommendations = async () => {
174
  setLoading(true);
175
  setError(null);
 
176
 
177
  try {
178
  const requestData = {
@@ -192,27 +336,154 @@ function App() {
192
  }
193
  };
194
 
195
- const getInteractionCounts = () => {
196
- const counts = { views: 0, carts: 0, purchases: 0 };
197
- interactions.forEach(interaction => {
198
- counts[interaction.type + 's'] = (counts[interaction.type + 's'] || 0) + 1;
199
- });
200
- return counts;
201
- };
202
-
203
- const counts = getInteractionCounts();
204
-
205
  return (
206
  <div className="App">
207
  <div className="container">
208
  <header className="header">
209
  <h1>Two-Tower Recommendation System Demo</h1>
210
- <p>Configure user demographics and realistic interaction patterns to get personalized recommendations</p>
 
 
 
 
 
 
 
211
  </header>
212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  {/* User Profile Form */}
214
  <div className="user-profile-form">
215
- <h2>User Demographics</h2>
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
  <div className="form-row">
218
  <div className="form-group">
@@ -224,6 +495,8 @@ function App() {
224
  onChange={(e) => handleProfileChange('age', e.target.value)}
225
  min="18"
226
  max="100"
 
 
227
  />
228
  </div>
229
 
@@ -233,6 +506,8 @@ function App() {
233
  id="gender"
234
  value={userProfile.gender}
235
  onChange={(e) => handleProfileChange('gender', e.target.value)}
 
 
236
  >
237
  <option value="male">Male</option>
238
  <option value="female">Female</option>
@@ -248,6 +523,8 @@ function App() {
248
  onChange={(e) => handleProfileChange('income', e.target.value)}
249
  min="0"
250
  step="1000"
 
 
251
  />
252
  </div>
253
  </div>
@@ -255,29 +532,140 @@ function App() {

  {/* Interaction Patterns */}
  <div className="interaction-patterns">
- <h2>Interaction Patterns</h2>
- <p>Generate realistic user behavior patterns with proportional view, cart, and purchase events</p>
-
- <div className="pattern-buttons">
- {INTERACTION_PATTERNS.map((pattern, index) => (
- <button
- key={index}
- className={`pattern-btn ${selectedPattern?.name === pattern.name ? 'active' : ''}`}
- onClick={() => handlePatternSelect(pattern)}
- >
- {pattern.name}
- <br />
- <small>{pattern.views}V • {pattern.carts}C • {pattern.purchases}P</small>
- </button>
- ))}
- <button
- className="pattern-btn"
- onClick={clearInteractions}
- style={{backgroundColor: '#dc3545', color: 'white', borderColor: '#dc3545'}}
- >
- Clear All
- </button>
- </div>

  {interactions.length > 0 && (
  <>
@@ -305,7 +693,7 @@ function App() {
  {interaction.type}
  </span>
  <span className="interaction-details">
- <strong>{interaction.brand}</strong> - ${interaction.price}
  {interaction.quantity && ` (x${interaction.quantity})`}
  {interaction.total_amount && ` = $${interaction.total_amount}`}
  </span>
@@ -387,6 +775,8 @@ function App() {
  onChange={(e) => setRecommendationType(e.target.value)}
  >
  <option value="hybrid">Hybrid</option>
  <option value="collaborative">Collaborative Filtering</option>
  <option value="content">Content-Based</option>
  </select>
@@ -399,14 +789,15 @@ function App() {
  value={numRecommendations}
  onChange={(e) => setNumRecommendations(parseInt(e.target.value))}
  >
- <option value="5">5</option>
  <option value="10">10</option>
- <option value="15">15</option>
  <option value="20">20</option>
  </select>
  </div>

- {recommendationType === 'hybrid' && (
  <div className="form-group">
  <label htmlFor="collabWeight">Collaborative Weight:</label>
  <input
@@ -448,19 +839,66 @@ function App() {
  {/* Recommendations Display */}
  {recommendations.length > 0 && (
  <div className="recommendations">
- <h2>Recommendations ({recommendationType})</h2>

  <div className="stats">
  <strong>User Profile:</strong> {userProfile.age}yr {userProfile.gender},
- ${userProfile.income.toLocaleString()} income,
- {interactions.length} total interactions ({counts.views || 0} views, {counts.carts || 0} carts, {counts.purchases || 0} purchases)
  </div>

  <div className="recommendations-grid">
- {recommendations.map((rec, index) => (
  <div key={rec.item_id} className="recommendation-card">
  <div className="card-header">
- <span className="item-id">#{index + 1} Item {rec.item_id}</span>
  <span className="score">{rec.score.toFixed(4)}</span>
  </div>

@@ -472,6 +910,27 @@ function App() {
  </div>
  ))}
  </div>
  </div>
  )}

  });

  const [recommendationType, setRecommendationType] = useState('hybrid');
+ const [numRecommendations, setNumRecommendations] = useState(100);
  const [collaborativeWeight, setCollaborativeWeight] = useState(0.7);

  const [recommendations, setRecommendations] = useState([]);
  const [loading, setLoading] = useState(false);
  const [error, setError] = useState(null);

+ // Pagination for recommendations
+ const [currentPage, setCurrentPage] = useState(1);
+ const [itemsPerPage] = useState(20); // Show 20 recommendations per page
+
  const [sampleItems, setSampleItems] = useState([]);
  const [interactions, setInteractions] = useState([]);
  const [expandedInteraction, setExpandedInteraction] = useState(null);
  const [selectedPattern, setSelectedPattern] = useState(null);
+
+ // Real user data states
+ const [realUsers, setRealUsers] = useState([]);
+ const [selectedRealUser, setSelectedRealUser] = useState(null);
+ const [datasetSummary, setDatasetSummary] = useState(null);
+ const [useRealUsers, setUseRealUsers] = useState(true);
+
+ // Expanded interaction states
+ const [showUserInteractions, setShowUserInteractions] = useState(false);
+ const [userInteractionDetails, setUserInteractionDetails] = useState(null);
+ const [loadingInteractions, setLoadingInteractions] = useState(false);

+ // Load sample items and real users on component mount
  useEffect(() => {
  fetchSampleItems();
+ fetchRealUsers();
+ fetchDatasetSummary();
  }, []);

  const fetchSampleItems = async () => {
 
  }
  };

+ const fetchRealUsers = async () => {
+ try {
+ const response = await axios.get(`${API_BASE_URL}/real-users?count=100&min_interactions=5`);
+ setRealUsers(response.data.users || []);
+ if (response.data.users && response.data.users.length > 0) {
+ // Auto-select the first (most active) user
+ handleRealUserSelect(response.data.users[0]);
+ }
+ } catch (error) {
+ console.error('Error fetching real users:', error);
+ setError('Could not load real users. Using synthetic data mode.');
+ setUseRealUsers(false);
+ }
+ };
+
+ const fetchDatasetSummary = async () => {
+ try {
+ const response = await axios.get(`${API_BASE_URL}/dataset-summary`);
+ setDatasetSummary(response.data);
+ } catch (error) {
+ console.error('Error fetching dataset summary:', error);
+ }
+ };
+
  const handleProfileChange = (field, value) => {
  setUserProfile(prev => ({
  ...prev,
 
  }));
  };

+ const handleRealUserSelect = (user) => {
+ setSelectedRealUser(user);
+ setUserProfile({
+ age: user.age,
+ gender: user.gender,
+ income: user.income,
+ interaction_history: user.interaction_history.slice(0, 50) // Limit to 50 items
+ });
+ // Clear any synthetic interactions and expanded states
+ setInteractions([]);
+ setSelectedPattern(null);
+ setShowUserInteractions(false);
+ setUserInteractionDetails(null);
+ };
+
+ const fetchUserInteractionDetails = async (userId) => {
+ setLoadingInteractions(true);
+ try {
+ const response = await axios.get(`${API_BASE_URL}/real-users/${userId}`);
+ setUserInteractionDetails(response.data);
+ } catch (error) {
+ console.error('Error fetching user interaction details:', error);
+ setError('Could not load user interaction details');
+ } finally {
+ setLoadingInteractions(false);
+ }
+ };
+
+ const toggleUserInteractions = async () => {
+ if (!showUserInteractions && selectedRealUser && !userInteractionDetails) {
+ await fetchUserInteractionDetails(selectedRealUser.user_id);
+ }
+ setShowUserInteractions(!showUserInteractions);
+ };
+
  const generateTimestamp = (baseTime, offsetHours) => {
  const timestamp = new Date(baseTime.getTime() - (offsetHours * 60 * 60 * 1000));
  return timestamp.toISOString().replace('T', ' ').slice(0, 19);
 
  }));
  };

+
+ const getInteractionCounts = () => {
+ const counts = { views: 0, carts: 0, purchases: 0 };
+ interactions.forEach(interaction => {
+ counts[interaction.type + 's'] = (counts[interaction.type + 's'] || 0) + 1;
+ });
+ return counts;
+ };
+
+ const counts = getInteractionCounts();
+
+ // Calculate category percentages from user interactions
+ const getCategoryPercentages = () => {
+ if (!selectedRealUser || !userInteractionDetails) return {};
+
+ const categoryCounts = {};
+ let totalInteractions = 0;
+
+ userInteractionDetails.timeline?.forEach(interaction => {
+ const category = interaction.category_code || 'Unknown';
+ categoryCounts[category] = (categoryCounts[category] || 0) + 1;
+ totalInteractions++;
+ });
+
+ const categoryPercentages = {};
+ Object.keys(categoryCounts).forEach(category => {
+ categoryPercentages[category] = ((categoryCounts[category] / totalInteractions) * 100).toFixed(1);
+ });
+
+ return categoryPercentages;
+ };
+
+ // Calculate recommendation category percentages
+ const getRecommendationCategoryPercentages = () => {
+ if (!recommendations || recommendations.length === 0) return {};
+
+ const recCategoryCounts = {};
+
+ recommendations.forEach(rec => {
+ const category = rec.item_info?.category_code || 'Unknown';
+ recCategoryCounts[category] = (recCategoryCounts[category] || 0) + 1;
+ });
+
+ const recCategoryPercentages = {};
+ Object.keys(recCategoryCounts).forEach(category => {
+ recCategoryPercentages[category] = ((recCategoryCounts[category] / recommendations.length) * 100).toFixed(1);
+ });
+
+ return recCategoryPercentages;
+ };
+
+ const categoryPercentages = getCategoryPercentages();
+ const recommendationCategoryPercentages = getRecommendationCategoryPercentages();
+
+ // Pagination logic
+ const totalPages = Math.ceil(recommendations.length / itemsPerPage);
+ const startIndex = (currentPage - 1) * itemsPerPage;
+ const endIndex = startIndex + itemsPerPage;
+ const currentRecommendations = recommendations.slice(startIndex, endIndex);
+
+ const goToPage = (page) => {
+ setCurrentPage(page);
+ // Scroll to recommendations section
+ document.querySelector('.recommendations')?.scrollIntoView({ behavior: 'smooth' });
+ };
+
+ // Reset pagination when new recommendations are generated
  const getRecommendations = async () => {
  setLoading(true);
  setError(null);
+ setCurrentPage(1); // Reset to first page

  try {
  const requestData = {
 
  }
  };

  return (
  <div className="App">
  <div className="container">
  <header className="header">
  <h1>Two-Tower Recommendation System Demo</h1>
+ <p>Select from {realUsers.length} real users or configure custom demographics to get personalized recommendations</p>
+
+ {datasetSummary && (
+ <div className="dataset-info">
+ 📊 Dataset: {datasetSummary.total_users?.toLocaleString()} users, {datasetSummary.total_interactions?.toLocaleString()} interactions |
+ 👥 Demographics: Avg age {datasetSummary.demographics?.avg_age}, avg income ${datasetSummary.demographics?.avg_income?.toLocaleString()}
+ </div>
+ )}
  </header>

+ {/* Real User Selector */}
+ {useRealUsers && realUsers.length > 0 && (
+ <div className="real-user-selector">
+ <h2>Real User Selection</h2>
+ <div className="user-selector-controls">
+ <label htmlFor="realUserSelect">Choose from {realUsers.length} real users:</label>
+ <select
+ id="realUserSelect"
+ value={selectedRealUser?.user_id || ''}
+ onChange={(e) => {
+ const userId = parseInt(e.target.value);
+ const user = realUsers.find(u => u.user_id === userId);
+ if (user) handleRealUserSelect(user);
+ }}
+ >
+ <option value="">Select a real user...</option>
+ {realUsers.map((user, index) => (
+ <option key={user.user_id} value={user.user_id}>
+ #{index + 1}: {user.summary} - {user.interaction_pattern}
+ </option>
+ ))}
+ </select>
+
+ <button
+ onClick={() => setUseRealUsers(false)}
+ className="btn btn-secondary"
+ style={{marginLeft: '10px'}}
+ >
+ Use Custom User Instead
+ </button>
+ </div>
+
+ {selectedRealUser && (
+ <div className="selected-real-user">
+ <h3>Selected User: {selectedRealUser.user_id}</h3>
+ <div className="real-user-stats">
+ <div className="user-stat">
+ <span className="stat-label">Demographics:</span>
+ <span className="stat-value">{selectedRealUser.age}yr {selectedRealUser.gender}, ${selectedRealUser.income.toLocaleString()}</span>
+ </div>
+ <div className="user-stat">
+ <span className="stat-label">Behavior Pattern:</span>
+ <span className="stat-value">{selectedRealUser.interaction_pattern}</span>
+ </div>
+ <div className="user-stat">
+ <span className="stat-label">Interactions:</span>
+ <span className="stat-value">
+ {selectedRealUser.interaction_stats.total_interactions} total
+ ({selectedRealUser.interaction_stats.views} views, {selectedRealUser.interaction_stats.cart_adds} carts, {selectedRealUser.interaction_stats.purchases} purchases)
+ </span>
+ </div>
+ <div className="user-stat">
+ <span className="stat-label">History:</span>
+ <span className="stat-value">{selectedRealUser.interaction_stats.unique_items} unique items</span>
+ </div>
+ </div>
+
+ <button
+ onClick={toggleUserInteractions}
+ className="btn btn-info expand-interactions-btn"
+ disabled={loadingInteractions}
+ >
+ {loadingInteractions ? 'Loading...' : showUserInteractions ? 'Hide Interaction Timeline' : 'Show All Interactions Timeline'}
+ </button>
+
+ {showUserInteractions && userInteractionDetails && (
+ <div className="user-interactions-timeline">
+ <h4>Complete Interaction Timeline</h4>
+ <div className="timeline-stats">
+ <span><strong>Total Events:</strong> {userInteractionDetails.total_interactions}</span>
+ <span><strong>Pattern:</strong> {userInteractionDetails.interaction_pattern}</span>
+ <span><strong>Breakdown:</strong> {userInteractionDetails.breakdown.views} views, {userInteractionDetails.breakdown.cart_adds} carts, {userInteractionDetails.breakdown.purchases} purchases</span>
+ </div>
+
+ <div className="interactions-list">
+ <h5>Recent Interactions (Last {userInteractionDetails.timeline?.length || 0} events):</h5>
+ {userInteractionDetails.timeline?.map((interaction, index) => (
+ <div key={index} className="interaction-timeline-item">
+ <div className="interaction-timeline-time">
+ {new Date(interaction.timestamp).toLocaleString()}
+ </div>
+ <div className="interaction-timeline-content">
+ <div className="interaction-main-info">
+ <span className={`interaction-type ${interaction.event_type}`}>
+ {interaction.event_type.toUpperCase()}
+ </span>
+ <span className="interaction-icon">
+ {interaction.event_type === 'purchase' && '💰'}
+ {interaction.event_type === 'cart' && '🛒'}
+ {interaction.event_type === 'view' && '👁️'}
+ </span>
+ <span className="interaction-item-id">
+ Item #{interaction.product_id}
+ </span>
+ </div>
+ <div className="interaction-item-details">
+ <span className="item-brand">
+ <strong>{interaction.brand || 'Unknown Brand'}</strong>
+ </span>
+ <span className="item-category">
+ {interaction.category_code || 'Unknown Category'}
+ </span>
+ <span className="item-price">
+ ${interaction.price ? interaction.price.toFixed(2) : '0.00'}
+ </span>
+ </div>
+ </div>
+ </div>
+ ))}
+ </div>
+ </div>
+ )}
+ </div>
+ )}
+ </div>
+ )}
+
  {/* User Profile Form */}
  <div className="user-profile-form">
+ <h2>User Demographics {useRealUsers && selectedRealUser ? '(From Real User)' : '(Custom)'}</h2>
+
+ {!useRealUsers && (
+ <button
+ onClick={() => {
+ setUseRealUsers(true);
+ if (realUsers.length > 0) handleRealUserSelect(realUsers[0]);
+ }}
+ className="btn btn-secondary"
+ style={{marginBottom: '15px'}}
+ >
+ Switch to Real Users
+ </button>
+ )}

  <div className="form-row">
  <div className="form-group">
 
  onChange={(e) => handleProfileChange('age', e.target.value)}
  min="18"
  max="100"
+ disabled={useRealUsers && selectedRealUser}
+ style={{backgroundColor: useRealUsers && selectedRealUser ? '#f5f5f5' : 'white'}}
  />
  </div>

  id="gender"
  value={userProfile.gender}
  onChange={(e) => handleProfileChange('gender', e.target.value)}
+ disabled={useRealUsers && selectedRealUser}
+ style={{backgroundColor: useRealUsers && selectedRealUser ? '#f5f5f5' : 'white'}}
  >
  <option value="male">Male</option>
  <option value="female">Female</option>
 
  onChange={(e) => handleProfileChange('income', e.target.value)}
  min="0"
  step="1000"
+ disabled={useRealUsers && selectedRealUser}
+ style={{backgroundColor: useRealUsers && selectedRealUser ? '#f5f5f5' : 'white'}}
  />
  </div>
  </div>
 

  {/* Interaction Patterns */}
  <div className="interaction-patterns">
+ {useRealUsers && selectedRealUser ? (
+ <>
+ <h2>Real User Interaction History</h2>
+ <p>This user has genuine interaction history from the dataset - no synthetic patterns needed.</p>
+
+ <div className="real-interaction-summary">
+ <div className="summary-card views">
+ <div className="summary-number">{selectedRealUser.interaction_stats.views}</div>
+ <div className="summary-label">Views</div>
+ </div>
+ <div className="summary-card carts">
+ <div className="summary-number">{selectedRealUser.interaction_stats.cart_adds}</div>
+ <div className="summary-label">Cart Adds</div>
+ </div>
+ <div className="summary-card purchases">
+ <div className="summary-number">{selectedRealUser.interaction_stats.purchases}</div>
+ <div className="summary-label">Purchases</div>
+ </div>
+ </div>
+
+ <div className="real-history-info">
+ <p><strong>Pattern:</strong> {selectedRealUser.interaction_pattern}</p>
+ <p><strong>Total Interactions:</strong> {selectedRealUser.interaction_stats.total_interactions}</p>
+ <p><strong>Unique Items:</strong> {selectedRealUser.interaction_stats.unique_items}</p>
+ <p><strong>Items in History:</strong> {userProfile.interaction_history.length} (showing up to 50 most recent)</p>
+ </div>
+
+ {/* Category Analysis Columns */}
+ {(Object.keys(categoryPercentages).length > 0 || Object.keys(recommendationCategoryPercentages).length > 0) && (
+ <div className="category-analysis">
+ <h4>Category Analysis</h4>
+ <div className="category-columns">
+
+ {/* User's Interacted Categories */}
+ {Object.keys(categoryPercentages).length > 0 && (
+ <div className="category-column">
+ <h5>👁️ User's Category Interests</h5>
+ <div className="category-percentages">
+ {Object.entries(categoryPercentages)
+ .sort((a, b) => parseFloat(b[1]) - parseFloat(a[1]))
+ .slice(0, 5)
+ .map(([category, percentage]) => (
+ <div key={category} className="category-item">
+ <div className="category-bar-container">
+ <div
+ className="category-bar user-category"
+ style={{ width: `${Math.max(parseFloat(percentage), 5)}%` }}
+ ></div>
+ </div>
+ <span className="category-label">{category.replace('_', ' ')}</span>
+ <span className="category-percent">{percentage}%</span>
+ </div>
+ ))}
+ </div>
+ </div>
+ )}
+
+ {/* Recommendation Categories */}
+ {Object.keys(recommendationCategoryPercentages).length > 0 && (
+ <div className="category-column">
+ <h5>🎯 Recommendation Categories</h5>
+ <div className="category-percentages">
+ {Object.entries(recommendationCategoryPercentages)
+ .sort((a, b) => parseFloat(b[1]) - parseFloat(a[1]))
+ .map(([category, percentage]) => {
+ const userPercentage = categoryPercentages[category] || 0;
+ const isMatch = parseFloat(userPercentage) > 0;
+
+ return (
+ <div key={category} className={`category-item ${isMatch ? 'matched' : 'new'}`}>
+ <div className="category-bar-container">
+ <div
+ className={`category-bar ${isMatch ? 'rec-category-matched' : 'rec-category-new'}`}
+ style={{ width: `${Math.max(parseFloat(percentage), 5)}%` }}
+ ></div>
+ </div>
+ <span className="category-label">{category.replace('_', ' ')}</span>
+ <span className="category-percent">{percentage}%</span>
+ {isMatch && <span className="match-indicator">✓</span>}
+ </div>
+ );
+ })}
+ </div>
+ </div>
+ )}
+
+ </div>
+
+ {/* Category Match Analysis */}
+ {Object.keys(categoryPercentages).length > 0 && Object.keys(recommendationCategoryPercentages).length > 0 && (
+ <div className="category-match-summary">
+ <p>
+ <strong>Category Alignment:</strong> {
+ Object.keys(recommendationCategoryPercentages).filter(cat =>
+ parseFloat(categoryPercentages[cat] || 0) > 0
+ ).length
+ } of {Object.keys(recommendationCategoryPercentages).length} recommended categories match user interests
+ <span className="match-legend">
+ <span className="legend-item"><span className="legend-dot matched"></span> Matches user interest</span>
+ <span className="legend-item"><span className="legend-dot new"></span> New category exploration</span>
+ </span>
+ </p>
+ </div>
+ )}
+ </div>
+ )}
+ </>
+ ) : (
+ <>
+ <h2>Synthetic Interaction Patterns</h2>
+ <p>Generate realistic user behavior patterns with proportional view, cart, and purchase events</p>
+
+ <div className="pattern-buttons">
+ {INTERACTION_PATTERNS.map((pattern, index) => (
+ <button
+ key={index}
+ className={`pattern-btn ${selectedPattern?.name === pattern.name ? 'active' : ''}`}
+ onClick={() => handlePatternSelect(pattern)}
+ >
+ {pattern.name}
+ <br />
+ <small>{pattern.views}V • {pattern.carts}C • {pattern.purchases}P</small>
+ </button>
+ ))}
+ <button
+ className="pattern-btn"
+ onClick={clearInteractions}
+ style={{backgroundColor: '#dc3545', color: 'white', borderColor: '#dc3545'}}
+ >
+ Clear All
+ </button>
+ </div>
+ </>
+ )}

  {interactions.length > 0 && (
  <>
 
  {interaction.type}
  </span>
  <span className="interaction-details">
+ <strong>{interaction.brand}</strong> - <span className="category-tag">{interaction.category}</span> - ${interaction.price}
  {interaction.quantity && ` (x${interaction.quantity})`}
  {interaction.total_amount && ` = $${interaction.total_amount}`}
  </span>
 
  onChange={(e) => setRecommendationType(e.target.value)}
  >
  <option value="hybrid">Hybrid</option>
+ <option value="enhanced">🎯 Enhanced Hybrid (Category-Aware)</option>
+ <option value="category_focused">🎯 Category Focused (80% Match)</option>
  <option value="collaborative">Collaborative Filtering</option>
  <option value="content">Content-Based</option>
  </select>
 
  value={numRecommendations}
  onChange={(e) => setNumRecommendations(parseInt(e.target.value))}
  >
  <option value="10">10</option>
  <option value="20">20</option>
+ <option value="50">50</option>
+ <option value="100">100 (Top Items)</option>
+ <option value="200">200 (Extended)</option>
  </select>
  </div>

+ {(recommendationType === 'hybrid' || recommendationType === 'enhanced') && (
  <div className="form-group">
  <label htmlFor="collabWeight">Collaborative Weight:</label>
  <input
 
  {/* Recommendations Display */}
  {recommendations.length > 0 && (
  <div className="recommendations">
+ <h2>Top {recommendations.length} Recommendations ({recommendationType})</h2>

  <div className="stats">
  <strong>User Profile:</strong> {userProfile.age}yr {userProfile.gender},
+ ${userProfile.income.toLocaleString()} income
+ {useRealUsers && selectedRealUser ? (
+ <span> | <strong>Real User {selectedRealUser.user_id}:</strong> {selectedRealUser.interaction_pattern} -
+ {selectedRealUser.interaction_stats.total_interactions} total interactions
+ ({selectedRealUser.interaction_stats.views} views, {selectedRealUser.interaction_stats.cart_adds} carts, {selectedRealUser.interaction_stats.purchases} purchases)
+ </span>
+ ) : (
+ <span> | <strong>Synthetic:</strong> {interactions.length} total interactions ({counts.views || 0} views, {counts.carts || 0} carts, {counts.purchases || 0} purchases)</span>
+ )}
  </div>

+ {/* Pagination Controls */}
+ {totalPages > 1 && (
+ <div className="pagination-info">
+ <p>Showing {startIndex + 1}-{Math.min(endIndex, recommendations.length)} of {recommendations.length} recommendations</p>
+ <div className="pagination-controls">
+ <button
+ onClick={() => goToPage(currentPage - 1)}
+ disabled={currentPage === 1}
+ className="pagination-btn"
+ >
+ ← Previous
+ </button>
+
+ <div className="page-numbers">
+ {Array.from({length: Math.min(totalPages, 10)}, (_, i) => {
+ const page = i + 1;
+ return (
+ <button
+ key={page}
+ onClick={() => goToPage(page)}
+ className={`page-number ${currentPage === page ? 'active' : ''}`}
+ >
+ {page}
+ </button>
+ );
+ })}
+ {totalPages > 10 && <span className="pagination-ellipsis">...</span>}
+ </div>
+
+ <button
+ onClick={() => goToPage(currentPage + 1)}
+ disabled={currentPage === totalPages}
+ className="pagination-btn"
+ >
+ Next →
+ </button>
+ </div>
+ </div>
+ )}
+
  <div className="recommendations-grid">
+ {currentRecommendations.map((rec, index) => (
  <div key={rec.item_id} className="recommendation-card">
  <div className="card-header">
+ <span className="item-id">#{startIndex + index + 1} Item {rec.item_id}</span>
  <span className="score">{rec.score.toFixed(4)}</span>
  </div>

 
  </div>
  ))}
  </div>
+
+ {/* Bottom Pagination */}
+ {totalPages > 1 && (
+ <div className="pagination-controls bottom-pagination">
+ <button
+ onClick={() => goToPage(currentPage - 1)}
+ disabled={currentPage === 1}
+ className="pagination-btn"
+ >
+ ← Previous
+ </button>
+ <span className="page-indicator">Page {currentPage} of {totalPages}</span>
+ <button
+ onClick={() => goToPage(currentPage + 1)}
+ disabled={currentPage === totalPages}
+ className="pagination-btn"
+ >
+ Next →
+ </button>
+ </div>
+ )}
  </div>
  )}

 
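Reviewer note on the frontend pagination added above: the slicing math (`itemsPerPage = 20`, `totalPages = Math.ceil(...)`, `recommendations.slice(startIndex, endIndex)`) is language-agnostic and worth sanity-checking. This is a minimal Python sketch mirroring the JS logic; the function name `paginate` is illustrative, not part of the codebase.

```python
import math

ITEMS_PER_PAGE = 20  # mirrors the itemsPerPage state in the frontend


def paginate(recommendations, current_page, items_per_page=ITEMS_PER_PAGE):
    """Return (page_slice, total_pages, start_index) for 1-indexed pages,
    matching the frontend's Math.ceil / slice arithmetic."""
    total_pages = math.ceil(len(recommendations) / items_per_page)
    start_index = (current_page - 1) * items_per_page
    end_index = start_index + items_per_page
    return recommendations[start_index:end_index], total_pages, start_index


# 100 recommendations at 20 per page -> 5 pages; page 5 holds items 80..99
page, total, start = paginate(list(range(100)), 5)
```

A ragged last page falls out naturally from the slice: 101 items yields 6 pages with a single item on page 6.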
run_2phase_training.py ADDED
@@ -0,0 +1,206 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Two-Phase Training Pipeline Runner
4
+
5
+ This script orchestrates the complete 2-phase training approach:
6
+ 1. Phase 1: Item tower pretraining on item features
7
+ 2. Phase 2: Joint training of user tower + fine-tuning pre-trained item tower
8
+
9
+ Usage:
10
+ python run_2phase_training.py
11
+ """
12
+
13
+ import os
14
+ import sys
15
+ import time
16
+ import pickle
17
+ import numpy as np
18
+ from typing import Dict
19
+
20
+ # Add src to path
21
+ sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
22
+
23
+ from src.training.item_pretraining import ItemTowerPretrainer
24
+ from src.training.joint_training import JointTrainer
25
+ from src.preprocessing.data_loader import DataProcessor
26
+ from src.inference.faiss_index import FAISSItemIndex
27
+
28
+
29
+ def run_phase1_item_pretraining():
30
+ """Phase 1: Pre-train the item tower."""
31
+
32
+ print("\n" + "="*60)
33
+ print("PHASE 1: ITEM TOWER PRETRAINING")
34
+ print("="*60)
35
+
36
+ # Initialize components
37
+ data_processor = DataProcessor()
38
+ pretrainer = ItemTowerPretrainer(
39
+ embedding_dim=128,
40
+ hidden_dims=[256, 128],
41
+ dropout_rate=0.2,
42
+ learning_rate=0.001
43
+ )
44
+
45
+ # Prepare data
46
+ print("Preparing item data...")
47
+ dataset, data_processor, price_normalizer = pretrainer.prepare_data(data_processor)
48
+
49
+ # Build model
50
+ print("Building item tower...")
51
+ model = pretrainer.build_model(
52
+ item_vocab_size=len(data_processor.item_vocab),
53
+ category_vocab_size=len(data_processor.category_vocab),
54
+ brand_vocab_size=len(data_processor.brand_vocab),
55
+ price_normalizer=price_normalizer
56
+ )
57
+
58
+ # Train model
59
+ print("Training item tower (Phase 1)...")
60
+ start_time = time.time()
61
+ history = pretrainer.train(dataset, epochs=50)
62
+ phase1_time = time.time() - start_time
63
+
64
+ # Generate embeddings
65
+ print("Generating item embeddings...")
66
+ item_embeddings = pretrainer.generate_item_embeddings(dataset, data_processor)
67
+
68
+ # Save artifacts
69
+ print("Saving Phase 1 artifacts...")
70
+ os.makedirs("src/artifacts", exist_ok=True)
71
+ data_processor.save_vocabularies()
72
+ pretrainer.save_model()
73
+
74
+ # Save embeddings for FAISS index
75
+ np.save("src/artifacts/item_embeddings.npy", item_embeddings)
76
+
77
+ # Build FAISS index
78
+ print("Building FAISS index...")
79
+ faiss_index = FAISSItemIndex()
80
+ faiss_index.build_index(item_embeddings)
81
+ faiss_index.save_index("src/artifacts/")
82
+
83
+ print(f"✅ Phase 1 completed in {phase1_time:.2f} seconds")
84
+ print(f" - Items processed: {len(item_embeddings)}")
85
+ print(f" - Final loss: {history.history['total_loss'][-1]:.4f}")
86
+
87
+ return data_processor
88
+
89
+
90
+ def run_phase2_joint_training(data_processor: DataProcessor):
91
+ """Phase 2: Joint training with pre-trained item tower."""
92
+
93
+ print("\n" + "="*60)
94
+ print("PHASE 2: JOINT TRAINING")
95
+ print("="*60)
96
+
97
+ # Initialize joint trainer
98
+ trainer = JointTrainer(
99
+ embedding_dim=128,
100
+ user_learning_rate=0.001,
101
+ item_learning_rate=0.0001, # Lower LR for pre-trained item tower
102
+ rating_weight=1.0,
103
+ retrieval_weight=0.5
104
+ )
105
+
106
+ # Load pre-trained item tower
107
+ print("Loading pre-trained item tower...")
108
+ trainer.load_pre_trained_item_tower()
109
+
110
+ # Build user tower
111
+ print("Building user tower...")
112
+ trainer.build_user_tower(max_history_length=50)
113
+
114
+ # Build complete two-tower model
115
+ print("Building complete two-tower model...")
116
+ trainer.build_two_tower_model()
117
+
118
+ # Prepare training data
119
+ print("Preparing user interaction data...")
120
+
121
+ # Check if training features already exist
122
+ if os.path.exists("src/artifacts/training_features.pkl"):
123
+ print("Loading existing training features...")
124
+ with open("src/artifacts/training_features.pkl", 'rb') as f:
125
+ training_features = pickle.load(f)
126
+ with open("src/artifacts/validation_features.pkl", 'rb') as f:
127
+ validation_features = pickle.load(f)
128
+ else:
129
+ # Generate training features
130
+ print("Generating training features...")
131
+ training_features, validation_features = data_processor.prepare_training_data()
132
+
133
+ # Save features
134
+ with open("src/artifacts/training_features.pkl", 'wb') as f:
135
+ pickle.dump(training_features, f)
136
+ with open("src/artifacts/validation_features.pkl", 'wb') as f:
137
+ pickle.dump(validation_features, f)
138
+
139
+ # Train joint model
140
+ print("Starting joint training (Phase 2)...")
141
+ start_time = time.time()
142
+ history = trainer.train(
143
+ training_features=training_features,
144
+ validation_features=validation_features,
145
+ epochs=100,
146
+ batch_size=256
147
+ )
148
+ phase2_time = time.time() - start_time
149
+
150
+ # Save final model
151
+ print("Saving final two-tower model...")
152
+ trainer.save_model()
153
+
154
+ # Save training history
155
+ with open("src/artifacts/joint_training_history.pkl", 'wb') as f:
156
+ pickle.dump(history, f)
157
+
158
+ print(f"✅ Phase 2 completed in {phase2_time:.2f} seconds")
159
+ print(f" - Best validation loss: {min(history['val_total_loss']):.4f}")
160
+ print(f" - Epochs trained: {len(history['total_loss'])}")
161
+
162
+ return history
163
+
164
+
165
+ def main():
166
+ """Main function to run complete 2-phase training pipeline."""
167
+
168
+ print("🚀 STARTING 2-PHASE TRAINING PIPELINE")
169
+ print(f"Working directory: {os.getcwd()}")
170
+ print(f"Python path: {sys.executable}")
171
+
172
+ total_start_time = time.time()
173
+
174
+ try:
175
+ # Phase 1: Item tower pretraining
176
+ data_processor = run_phase1_item_pretraining()
177
+
178
+ # Phase 2: Joint training
179
+ history = run_phase2_joint_training(data_processor)
180
+
181
+ # Final summary
182
+ total_time = time.time() - total_start_time
183
+
184
+ print("\n" + "="*60)
185
+ print("🎉 2-PHASE TRAINING COMPLETED SUCCESSFULLY!")
186
+ print("="*60)
187
+ print(f"Total training time: {total_time:.2f} seconds ({total_time/60:.1f} minutes)")
188
+ print(f"Artifacts saved in: src/artifacts/")
189
+ print("\nKey files generated:")
190
+ print(" - item_tower_weights: Pre-trained item embeddings")
191
+ print(" - item_tower_weights_finetuned_best: Fine-tuned item tower")
192
+ print(" - user_tower_weights_best: Trained user tower")
193
+ print(" - rating_model_weights_best: Rating prediction model")
194
+ print(" - faiss_index.index: Item similarity index")
195
+ print(" - vocabularies.pkl: Feature vocabularies")
196
+
197
+ print(f"\n🔥 Final validation loss: {min(history['val_total_loss']):.4f}")
198
+ print("\n✅ Ready to run inference with api/main.py!")
199
+
200
+ except Exception as e:
201
+ print(f"\n❌ Training failed with error: {str(e)}")
202
+ raise
203
+
204
+
205
+ if __name__ == "__main__":
206
+ main()
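Both training scripts reuse the same cache-or-compute pattern for the pickled feature artifacts: load the file if it exists, otherwise generate and save it. A minimal sketch of that pattern, where the `build` callable and the path are placeholders rather than names from this repo:

```python
import os
import pickle

def load_or_build(path, build):
    """Load a pickled artifact if present; otherwise build, save, and return it."""
    if os.path.exists(path):
        with open(path, 'rb') as f:
            return pickle.load(f)
    artifact = build()
    os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
    with open(path, 'wb') as f:
        pickle.dump(artifact, f)
    return artifact
```

The second call for the same path skips `build` entirely, which is why rerunning the pipeline with existing `training_features.pkl` starts straight at the training phase.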
run_joint_training.py ADDED
@@ -0,0 +1,453 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Single Joint Training Pipeline Runner
4
+
5
+ This script orchestrates the single-phase joint training approach:
6
+ - Trains user tower and item tower simultaneously from scratch
7
+ - No pre-training phase - direct end-to-end optimization
8
+ - Supports both regular and fast training modes
9
+
10
+ Usage:
11
+ python run_joint_training.py [--fast]
12
+ """
13
+
14
+ import os
15
+ import sys
16
+ import time
17
+ import pickle
18
+ import argparse
19
+ from typing import Dict
20
+
21
+ # Add src to path
22
+ sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
23
+
24
+ from src.training.fast_joint_training import FastJointTrainer
25
+ from src.models.item_tower import ItemTower
26
+ from src.models.user_tower import UserTower, TwoTowerModel
27
+ from src.preprocessing.data_loader import DataProcessor, create_tf_dataset
28
+ from src.inference.faiss_index import FAISSItemIndex
29
+ import tensorflow as tf
30
+ import numpy as np
31
+
32
+
33
+ class SingleJointTrainer:
34
+ """Complete single-phase joint training from scratch."""
35
+
36
+ def __init__(self):
37
+ self.item_tower = None
38
+ self.user_tower = None
39
+ self.model = None
40
+ self.data_processor = None
41
+
42
+ # Training hyperparameters
43
+ self.embedding_dim = 128
44
+ self.learning_rate = 0.001
45
+ self.batch_size = 256
46
+ self.epochs = 80
47
+ self.patience = 20
48
+
49
+ def prepare_data(self):
50
+ """Prepare all training data from scratch."""
51
+
52
+ print("Loading and preparing data...")
53
+
54
+ # Initialize data processor
55
+ self.data_processor = DataProcessor()
56
+
57
+ # Check if preprocessed data exists
58
+ if os.path.exists("src/artifacts/training_features.pkl"):
59
+ print("Loading existing preprocessed data...")
60
+
61
+ # Load vocabularies
62
+ self.data_processor.load_vocabularies("src/artifacts/vocabularies.pkl")
63
+
64
+ # Load training features
65
+ with open("src/artifacts/training_features.pkl", 'rb') as f:
66
+ training_features = pickle.load(f)
67
+ with open("src/artifacts/validation_features.pkl", 'rb') as f:
68
+ validation_features = pickle.load(f)
69
+ else:
70
+ print("Preprocessing data from scratch...")
71
+
72
+ # Load raw data and build vocabularies
73
+ items_df, users_df, interactions_df = self.data_processor.load_data()
74
+ self.data_processor.build_vocabularies(items_df, users_df, interactions_df)
75
+
76
+ # Generate training features
77
+ training_features, validation_features = self.data_processor.prepare_training_data()
78
+
79
+ # Save for future use
80
+ os.makedirs("src/artifacts", exist_ok=True)
81
+ self.data_processor.save_vocabularies()
82
+
83
+ with open("src/artifacts/training_features.pkl", 'wb') as f:
84
+ pickle.dump(training_features, f)
85
+ with open("src/artifacts/validation_features.pkl", 'wb') as f:
86
+ pickle.dump(validation_features, f)
87
+
88
+ print(f"Training samples: {len(training_features['rating']):,}")
89
+ print(f"Validation samples: {len(validation_features['rating']):,}")
90
+
91
+ return training_features, validation_features
92
+
93
+ def build_models(self):
94
+ """Build both towers from scratch."""
95
+
96
+ print("Building item tower...")
97
+ self.item_tower = ItemTower(
98
+ item_vocab_size=len(self.data_processor.item_vocab),
99
+ category_vocab_size=len(self.data_processor.category_vocab),
100
+ brand_vocab_size=len(self.data_processor.brand_vocab),
101
+ embedding_dim=self.embedding_dim,
102
+ hidden_dims=[256, 128],
103
+ dropout_rate=0.2
104
+ )
105
+
106
+ print("Building user tower...")
107
+ self.user_tower = UserTower(
108
+ max_history_length=50,
109
+ embedding_dim=self.embedding_dim,
110
+ hidden_dims=[128, 64],  # Smaller head than the item tower; kept compatible with saved checkpoints
111
+ dropout_rate=0.2
112
+ )
113
+
114
+ print("Building complete two-tower model...")
115
+ self.model = TwoTowerModel(
116
+ item_tower=self.item_tower,
117
+ user_tower=self.user_tower,
118
+ rating_weight=1.0,
119
+ retrieval_weight=0.5
120
+ )
121
+
122
+ print("Models initialized successfully")
123
+
124
+ def train_joint_model(self, training_features: Dict, validation_features: Dict):
125
+ """Train both towers jointly from scratch."""
126
+
127
+ print(f"Starting single-phase joint training...")
128
+ print(f"Configuration: {self.epochs} epochs, batch size {self.batch_size}")
129
+
130
+ # Create datasets
131
+ train_dataset = create_tf_dataset(training_features, self.batch_size)
132
+ val_dataset = create_tf_dataset(validation_features, self.batch_size)
133
+
134
+ # Setup optimizer
135
+ optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
136
+
137
+ # Training history
138
+ history = {
139
+ 'total_loss': [],
140
+ 'rating_loss': [],
141
+ 'retrieval_loss': [],
142
+ 'val_total_loss': [],
143
+ 'val_rating_loss': [],
144
+ 'val_retrieval_loss': []
145
+ }
146
+
147
+ best_val_loss = float('inf')
148
+ patience_counter = 0
149
+
150
+ for epoch in range(self.epochs):
151
+ epoch_start = time.time()
152
+ print(f"\nEpoch {epoch + 1}/{self.epochs}")
153
+
154
+ # Training phase
155
+ epoch_losses = {'total_loss': [], 'rating_loss': [], 'retrieval_loss': []}
156
+
157
+ for batch in train_dataset:
158
+ with tf.GradientTape() as tape:
159
+ # Forward pass
160
+ user_embeddings = self.user_tower(batch, training=True)
161
+ item_embeddings = self.item_tower(batch, training=True)
162
+
163
+ # Rating prediction
164
+ concatenated = tf.concat([user_embeddings, item_embeddings], axis=-1)
165
+ rating_predictions = self.model.rating_model(concatenated, training=True)
166
+
167
+ # Rating loss
168
+ rating_loss = self.model.rating_task(
169
+ labels=batch["rating"],
170
+ predictions=rating_predictions
171
+ )
172
+
173
+ # Retrieval loss (dot product similarity)
174
+ similarities = tf.reduce_sum(user_embeddings * item_embeddings, axis=1)
175
+ retrieval_loss = self.model.retrieval_loss(
176
+ batch["rating"],
177
+ tf.nn.sigmoid(similarities)
178
+ )
179
+
180
+ # Combined loss
181
+ total_loss = (
182
+ self.model.rating_weight * rating_loss +
183
+ self.model.retrieval_weight * retrieval_loss
184
+ )
185
+
186
+ # Compute and apply gradients
187
+ all_variables = (
188
+ self.user_tower.trainable_variables +
189
+ self.item_tower.trainable_variables +
190
+ self.model.rating_model.trainable_variables
191
+ )
192
+ gradients = tape.gradient(total_loss, all_variables)
193
+ optimizer.apply_gradients(zip(gradients, all_variables))
194
+
195
+ # Track losses
196
+ epoch_losses['total_loss'].append(total_loss)
197
+ epoch_losses['rating_loss'].append(rating_loss)
198
+ epoch_losses['retrieval_loss'].append(retrieval_loss)
199
+
200
+ # Validation phase
201
+ val_losses = {'total_loss': [], 'rating_loss': [], 'retrieval_loss': []}
202
+
203
+ for batch in val_dataset:
204
+ user_embeddings = self.user_tower(batch, training=False)
205
+ item_embeddings = self.item_tower(batch, training=False)
206
+
207
+ concatenated = tf.concat([user_embeddings, item_embeddings], axis=-1)
208
+ rating_predictions = self.model.rating_model(concatenated, training=False)
209
+
210
+ rating_loss = self.model.rating_task(
211
+ labels=batch["rating"],
212
+ predictions=rating_predictions
213
+ )
214
+
215
+ similarities = tf.reduce_sum(user_embeddings * item_embeddings, axis=1)
216
+ retrieval_loss = self.model.retrieval_loss(
217
+ batch["rating"],
218
+ tf.nn.sigmoid(similarities)
219
+ )
220
+
221
+ total_loss = (
222
+ self.model.rating_weight * rating_loss +
223
+ self.model.retrieval_weight * retrieval_loss
224
+ )
225
+
226
+ val_losses['total_loss'].append(total_loss)
227
+ val_losses['rating_loss'].append(rating_loss)
228
+ val_losses['retrieval_loss'].append(retrieval_loss)
229
+
230
+ # Calculate average losses
231
+ avg_train_losses = {k: tf.reduce_mean(v).numpy() for k, v in epoch_losses.items()}
232
+ avg_val_losses = {k: tf.reduce_mean(v).numpy() for k, v in val_losses.items()}
233
+
234
+ # Update history
235
+ for key in history.keys():
236
+ if key.startswith('val_'):
237
+ history[key].append(avg_val_losses[key.replace('val_', '')])
238
+ else:
239
+ history[key].append(avg_train_losses[key])
240
+
241
+ # Print progress
242
+ epoch_time = time.time() - epoch_start
243
+ print(f"Time: {epoch_time:.1f}s | Train: {avg_train_losses['total_loss']:.4f} | Val: {avg_val_losses['total_loss']:.4f}")
244
+ print(f" Rating: {avg_val_losses['rating_loss']:.4f} | Retrieval: {avg_val_losses['retrieval_loss']:.4f}")
245
+
246
+ # Early stopping and best model saving
247
+ if avg_val_losses['total_loss'] < best_val_loss:
248
+ best_val_loss = avg_val_losses['total_loss']
249
+ patience_counter = 0
250
+ self.save_model("_best")
251
+ print(" ✅ Best model saved!")
252
+ else:
253
+ patience_counter += 1
254
+ if patience_counter >= self.patience:
255
+ print(f"Early stopping at epoch {epoch + 1}")
256
+ break
257
+
258
+ print("Joint training completed!")
259
+ return history
260
+
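The loop above implements early stopping by tracking the best validation loss and a patience counter. The same stopping rule can be sketched in isolation (a simplified stand-in for the in-loop bookkeeping):

```python
def should_stop(val_losses, patience):
    """True once the best validation loss is at least `patience` epochs old."""
    best_epoch = min(range(len(val_losses)), key=val_losses.__getitem__)
    return len(val_losses) - 1 - best_epoch >= patience
```

With `patience = 20` as configured above, training continues as long as any of the last 20 epochs improved on the best validation loss seen so far.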
261
+ def generate_item_embeddings(self, training_features: Dict):
262
+ """Generate item embeddings for FAISS index."""
263
+
264
+ print("Generating item embeddings...")
265
+
266
+ # Get all unique items from training data
267
+ unique_items = np.unique(training_features['product_id'])
268
+ item_embeddings = {}
269
+
270
+ # Process in batches
271
+ batch_size = 1000
272
+ for i in range(0, len(unique_items), batch_size):
273
+ batch_items = unique_items[i:i+batch_size]
274
+
275
+ # Create batch features
276
+ batch_features = {
277
+ 'product_id': batch_items,
278
+ 'category_id': training_features['category_id'][:len(batch_items)],
279
+ 'brand_id': training_features['brand_id'][:len(batch_items)],
280
+ 'price': training_features['price'][:len(batch_items)]
281
+ }
282
+
283
+ # Convert to tensors
284
+ batch_tensors = {k: tf.constant(v) for k, v in batch_features.items()}
285
+
286
+ # Get embeddings
287
+ embeddings = self.item_tower(batch_tensors, training=False)
288
+
289
+ # Store embeddings
290
+ for j, item_id in enumerate(batch_items):
291
+ # Map back from vocab index to actual item ID
292
+ actual_item_id = item_id  # Assumes vocab index == raw item ID; use an inverse vocabulary lookup if they differ
293
+ item_embeddings[actual_item_id] = embeddings[j].numpy()
294
+
295
+ print(f"Generated embeddings for {len(item_embeddings)} items")
296
+ return item_embeddings
297
+
298
+ def save_model(self, suffix=""):
299
+ """Save trained models."""
300
+
301
+ save_path = "src/artifacts/"
302
+ os.makedirs(save_path, exist_ok=True)
303
+
304
+ # Save model weights
305
+ self.user_tower.save_weights(f"{save_path}/user_tower_weights{suffix}")
306
+ self.item_tower.save_weights(f"{save_path}/item_tower_weights_finetuned{suffix}")
307
+ self.model.rating_model.save_weights(f"{save_path}/rating_model_weights{suffix}")
308
+
309
+ # Save item tower config for inference
310
+ with open(f"{save_path}/item_tower_config.txt", 'w') as f:
311
+ f.write(f"embedding_dim: {self.embedding_dim}\n")
312
+ f.write(f"hidden_dims: [256, 128]\n") # Item tower architecture
313
+ f.write(f"dropout_rate: 0.2\n")
314
+
315
+ if not suffix:
316
+ print("Final model saved")
317
+
318
+
319
+ def run_fast_joint_training():
320
+ """Run fast optimized joint training."""
321
+
322
+ print("\n" + "="*60)
323
+ print("FAST JOINT TRAINING MODE")
324
+ print("="*60)
325
+
326
+ # Initialize fast trainer
327
+ trainer = FastJointTrainer()
328
+
329
+ # Check if we need to prepare data first
330
+ if not os.path.exists("src/artifacts/training_features.pkl"):
331
+ print("Preparing data first...")
332
+ single_trainer = SingleJointTrainer()
333
+ training_features, validation_features = single_trainer.prepare_data()
334
+
335
+ # Run fast training
336
+ trainer.load_components()
337
+
338
+ print("Loading training data...")
339
+ with open("src/artifacts/training_features.pkl", 'rb') as f:
340
+ training_features = pickle.load(f)
341
+ with open("src/artifacts/validation_features.pkl", 'rb') as f:
342
+ validation_features = pickle.load(f)
343
+
344
+ start_time = time.time()
345
+ trainer.train_fast(training_features, validation_features)
346
+ training_time = time.time() - start_time
347
+
348
+ # Generate embeddings and build FAISS index
349
+ print("Building FAISS index...")
350
+ # Use single trainer for embedding generation
351
+ single_trainer = SingleJointTrainer()
352
+ single_trainer.data_processor = DataProcessor()
353
+ single_trainer.data_processor.load_vocabularies("src/artifacts/vocabularies.pkl")
354
+ single_trainer.item_tower = trainer.item_tower
355
+
356
+ item_embeddings = single_trainer.generate_item_embeddings(training_features)
357
+
358
+ faiss_index = FAISSItemIndex()
359
+ faiss_index.build_index(item_embeddings)
360
+ faiss_index.save_index("src/artifacts/")
361
+
362
+ return training_time
363
+
364
+
365
+ def run_regular_joint_training():
366
+ """Run regular comprehensive joint training."""
367
+
368
+ print("\n" + "="*60)
369
+ print("REGULAR JOINT TRAINING MODE")
370
+ print("="*60)
371
+
372
+ # Initialize trainer
373
+ trainer = SingleJointTrainer()
374
+
375
+ # Prepare data
376
+ training_features, validation_features = trainer.prepare_data()
377
+
378
+ # Build models from scratch
379
+ trainer.build_models()
380
+
381
+ # Train joint model
382
+ start_time = time.time()
383
+ history = trainer.train_joint_model(training_features, validation_features)
384
+ training_time = time.time() - start_time
385
+
386
+ # Generate item embeddings
387
+ item_embeddings = trainer.generate_item_embeddings(training_features)
388
+
389
+ # Build FAISS index
390
+ print("Building FAISS index...")
391
+ faiss_index = FAISSItemIndex()
392
+ faiss_index.build_index(item_embeddings)
393
+ faiss_index.save_index("src/artifacts/")
394
+
395
+ # Save final model
396
+ trainer.save_model()
397
+
398
+ # Save training history
399
+ with open("src/artifacts/single_joint_training_history.pkl", 'wb') as f:
400
+ pickle.dump(history, f)
401
+
402
+ return training_time, history
403
+
404
+
405
+ def main():
406
+ """Main function to run single joint training pipeline."""
407
+
408
+ parser = argparse.ArgumentParser(description='Single Joint Training Pipeline')
409
+ parser.add_argument('--fast', action='store_true', help='Use fast training mode')
410
+ args = parser.parse_args()
411
+
412
+ print("🚀 STARTING SINGLE JOINT TRAINING PIPELINE")
413
+ print(f"Working directory: {os.getcwd()}")
414
+ print(f"Training mode: {'FAST' if args.fast else 'REGULAR'}")
415
+
416
+ total_start_time = time.time()
417
+
418
+ try:
419
+ if args.fast:
420
+ training_time = run_fast_joint_training()
421
+ history = None
422
+ else:
423
+ training_time, history = run_regular_joint_training()
424
+
425
+ total_time = time.time() - total_start_time
426
+
427
+ print("\n" + "="*60)
428
+ print("🎉 SINGLE JOINT TRAINING COMPLETED SUCCESSFULLY!")
429
+ print("="*60)
430
+ print(f"Training time: {training_time:.2f} seconds ({training_time/60:.1f} minutes)")
431
+ print(f"Total time: {total_time:.2f} seconds ({total_time/60:.1f} minutes)")
432
+ print(f"Artifacts saved in: src/artifacts/")
433
+
434
+ print("\nKey files generated:")
435
+ print(" - user_tower_weights_best: Trained user tower")
436
+ print(" - item_tower_weights_finetuned_best: Trained item tower")
437
+ print(" - rating_model_weights_best: Rating prediction model")
438
+ print(" - faiss_index.index: Item similarity index")
439
+ print(" - vocabularies.pkl: Feature vocabularies")
440
+
441
+ if history:
442
+ print(f"\n🔥 Best validation loss: {min(history['val_total_loss']):.4f}")
443
+
444
+ print(f"\n🎯 Training approach: Single-phase joint optimization")
445
+ print("✅ Ready to run inference with api/main.py!")
446
+
447
+ except Exception as e:
448
+ print(f"\n❌ Training failed with error: {str(e)}")
449
+ raise
450
+
451
+
452
+ if __name__ == "__main__":
453
+ main()
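Once both towers are trained, retrieval reduces to scoring candidate item embeddings against a user embedding and keeping the top k, which the FAISS index accelerates. A brute-force NumPy equivalent is useful as a reference when validating the index (the function name is illustrative, not from this repo):

```python
import numpy as np

def top_k_items(user_emb, item_embs, k):
    """Dot-product score every item row against the user embedding and
    return the indices and scores of the k best, best first."""
    scores = item_embs @ user_emb
    top = np.argsort(-scores)[:k]
    return top.tolist(), scores[top].tolist()
```

For inner-product search this matches what `IndexFlatIP`-style exact FAISS indexes return, at O(n·d) cost per query.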
src/inference/enhanced_recommendation_engine_128d.py ADDED
@@ -0,0 +1,499 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Enhanced recommendation engine using 128D embeddings with diversity regularization.
4
+ """
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import tensorflow as tf
9
+ import pickle
10
+ import os
11
+ from typing import Dict, List, Tuple, Optional
12
+ from collections import Counter, defaultdict
13
+
14
+ import sys
15
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
16
+
17
+ from src.models.enhanced_two_tower import EnhancedItemTower, EnhancedUserTower
18
+ from src.inference.faiss_index import FAISSItemIndex
19
+ from src.preprocessing.data_loader import DataProcessor
20
+ from src.preprocessing.user_data_preparation import prepare_user_features
21
+ from src.utils.real_user_selector import RealUserSelector
22
+
23
+
24
+ class Enhanced128DRecommendationEngine:
25
+ """Enhanced recommendation engine with 128D embeddings and all improvements."""
26
+
27
+ def __init__(self, artifacts_path: str = "src/artifacts/"):
28
+ self.artifacts_path = artifacts_path
29
+ self.embedding_dim = 128 # Fixed to 128D
30
+
31
+ # Model components
32
+ self.item_tower = None
33
+ self.user_tower = None
34
+ self.rating_model = None
35
+ self.faiss_index = None
36
+ self.data_processor = None
37
+
38
+ # Data
39
+ self.items_df = None
40
+ self.users_df = None
41
+ self.income_thresholds = None
42
+
43
+ # Load all components
44
+ self._load_all_components()
45
+
46
+ def _load_all_components(self):
47
+ """Load all enhanced model components."""
48
+
49
+ print("Loading enhanced 128D recommendation engine...")
50
+
51
+ # Load data processor
52
+ self.data_processor = DataProcessor()
53
+ try:
54
+ self.data_processor.load_vocabularies(f"{self.artifacts_path}/vocabularies.pkl")
55
+ except FileNotFoundError:
56
+ print("❌ Vocabularies not found. Please train the model first.")
57
+ return
58
+
59
+ # Load datasets
60
+ self.items_df = pd.read_csv("datasets/items.csv")
61
+ self.users_df = pd.read_csv("datasets/users.csv")
62
+
63
+ # Load enhanced model components
64
+ self._load_enhanced_models()
65
+
66
+ # Load FAISS index with 128D
67
+ try:
68
+ self.faiss_index = FAISSItemIndex(embedding_dim=self.embedding_dim)
69
+ # Try to load enhanced embeddings first
70
+ if os.path.exists(f"{self.artifacts_path}/enhanced_item_embeddings.npy"):
71
+ enhanced_embeddings = np.load(
72
+ f"{self.artifacts_path}/enhanced_item_embeddings.npy",
73
+ allow_pickle=True
74
+ ).item()
75
+ self.faiss_index.build_index(enhanced_embeddings)
76
+ print("✅ Loaded enhanced 128D FAISS index")
77
+ else:
78
+ print("⚠️ Enhanced embeddings not found. Train enhanced model first.")
79
+ self.faiss_index = None
80
+ except Exception as e:
81
+ print(f"⚠️ Could not load FAISS index: {e}")
82
+ self.faiss_index = None
83
+
84
+ # Load income thresholds for categorical demographics
85
+ self._load_income_thresholds()
86
+
87
+ print("✅ Enhanced 128D engine loaded successfully!")
88
+
89
+ def _load_enhanced_models(self):
90
+ """Load enhanced model components."""
91
+
92
+ try:
93
+ # Create model architecture
94
+ self.item_tower = EnhancedItemTower(
95
+ item_vocab_size=len(self.data_processor.item_vocab),
96
+ category_vocab_size=len(self.data_processor.category_vocab),
97
+ brand_vocab_size=len(self.data_processor.brand_vocab),
98
+ embedding_dim=self.embedding_dim,
99
+ use_bias=True,
100
+ use_diversity_reg=False # Disable during inference
101
+ )
102
+
103
+ self.user_tower = EnhancedUserTower(
104
+ max_history_length=50,
105
+ embedding_dim=self.embedding_dim,
106
+ use_bias=True,
107
+ use_diversity_reg=False # Disable during inference
108
+ )
109
+
110
+ # Create rating model
111
+ self.rating_model = tf.keras.Sequential([
112
+ tf.keras.layers.Dense(512, activation="relu"),
113
+ tf.keras.layers.BatchNormalization(),
114
+ tf.keras.layers.Dropout(0.3),
115
+ tf.keras.layers.Dense(256, activation="relu"),
116
+ tf.keras.layers.BatchNormalization(),
117
+ tf.keras.layers.Dropout(0.2),
118
+ tf.keras.layers.Dense(64, activation="relu"),
119
+ tf.keras.layers.Dense(1, activation="sigmoid")
120
+ ])
121
+
122
+ # Load weights - try enhanced first, fall back to regular
123
+ model_files = [
124
+ ('enhanced_item_tower_weights_enhanced_best', 'enhanced_user_tower_weights_enhanced_best', 'enhanced_rating_model_weights_enhanced_best'),
125
+ ('enhanced_item_tower_weights_enhanced_final', 'enhanced_user_tower_weights_enhanced_final', 'enhanced_rating_model_weights_enhanced_final'),
126
+ ]
127
+
128
+ loaded = False
129
+ for item_file, user_file, rating_file in model_files:
130
+ try:
131
+ # Need to build models first with dummy data
132
+ self._build_models()
133
+
134
+ self.item_tower.load_weights(f"{self.artifacts_path}/{item_file}")
135
+ self.user_tower.load_weights(f"{self.artifacts_path}/{user_file}")
136
+ self.rating_model.load_weights(f"{self.artifacts_path}/{rating_file}")
137
+
138
+ print(f"✅ Loaded enhanced model: {item_file}")
139
+ loaded = True
140
+ break
141
+ except Exception as e:
142
+ print(f"⚠️ Could not load {item_file}: {e}")
143
+ continue
144
+
145
+ if not loaded:
146
+ print("❌ No enhanced model weights found. Please train enhanced model first.")
147
+ self.item_tower = None
148
+ self.user_tower = None
149
+ self.rating_model = None
150
+
151
+ except Exception as e:
152
+ print(f"❌ Failed to load enhanced models: {e}")
153
+ self.item_tower = None
154
+ self.user_tower = None
155
+ self.rating_model = None
156
+
157
+ def _build_models(self):
158
+ """Build models with dummy data to initialize weights."""
159
+
160
+ # Dummy item features
161
+ dummy_item_features = {
162
+ 'product_id': tf.constant([0]),
163
+ 'category_id': tf.constant([0]),
164
+ 'brand_id': tf.constant([0]),
165
+ 'price': tf.constant([100.0])
166
+ }
167
+
168
+ # Dummy user features
169
+ dummy_user_features = {
170
+ 'age': tf.constant([2]), # Adult category
171
+ 'gender': tf.constant([0]), # Female
172
+ 'income': tf.constant([2]), # Middle income
173
+ 'item_history_embeddings': tf.constant(np.zeros((1, 50, self.embedding_dim), dtype=np.float32))
174
+ }
175
+
176
+ # Forward pass to build models
177
+ _ = self.item_tower(dummy_item_features, training=False)
178
+ _ = self.user_tower(dummy_user_features, training=False)
179
+
180
+ # Build rating model
181
+ dummy_concat = tf.constant(np.zeros((1, self.embedding_dim * 2), dtype=np.float32))
182
+ _ = self.rating_model(dummy_concat, training=False)
183
+
184
+ def _load_income_thresholds(self):
185
+ """Load income thresholds for categorical processing."""
186
+
187
+ # Calculate income thresholds from training data
188
+ user_incomes = self.users_df['income'].values
189
+ self.income_thresholds = np.percentile(user_incomes, [0, 20, 40, 60, 80, 100])
190
+ print(f"Income thresholds: {self.income_thresholds}")
191
+
192
+ def categorize_age(self, age: float) -> int:
193
+ """Categorize age into 6 groups."""
194
+ if age < 18: return 0 # Teen
195
+ elif age < 26: return 1 # Young Adult
196
+ elif age < 36: return 2 # Adult
197
+ elif age < 51: return 3 # Middle Age
198
+ elif age < 66: return 4 # Mature
199
+ else: return 5 # Senior
200
+
201
+ def categorize_income(self, income: float) -> int:
202
+ """Categorize income into 5 percentile groups."""
203
+ category = np.digitize([income], self.income_thresholds[1:-1])[0]
204
+ return min(max(category, 0), 4)
205
+
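`categorize_income` clamps the output of `numpy.digitize` against the four interior percentile cut points (`income_thresholds[1:-1]`). A self-contained version with hypothetical threshold values shows the boundary behavior:

```python
import numpy as np

# Hypothetical percentile values [0th, 20th, 40th, 60th, 80th, 100th]
thresholds = np.array([0.0, 20.0, 40.0, 60.0, 80.0, 100.0])

def categorize_income(income):
    """Bucket an income into one of 5 percentile bands (0..4)."""
    category = np.digitize([income], thresholds[1:-1])[0]
    return int(min(max(category, 0), 4))
```

Values at a cut point fall into the higher band (`digitize` uses `bins[i-1] <= x < bins[i]` by default), and the final clamp keeps out-of-range incomes in band 0 or 4.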
206
+ def categorize_gender(self, gender: str) -> int:
207
+ """Categorize gender."""
208
+ return 1 if gender.lower() == 'male' else 0
209
+
210
+ def get_user_embedding(self,
211
+ age: int,
212
+ gender: str,
213
+ income: float,
214
+ interaction_history: List[int] = None) -> np.ndarray:
215
+ """Generate user embedding with categorical demographics."""
216
+
217
+ if self.user_tower is None:
218
+ print("❌ User tower not loaded")
219
+ return None
220
+
221
+ # Categorize demographics
222
+ age_cat = self.categorize_age(age)
223
+ gender_cat = self.categorize_gender(gender)
224
+ income_cat = self.categorize_income(income)
225
+
226
+ # Prepare interaction history embeddings
227
+ if interaction_history is None:
228
+ interaction_history = []
229
+
230
+ # Get item embeddings for history
231
+ history_embeddings = np.zeros((50, self.embedding_dim), dtype=np.float32)
232
+
233
+ for i, item_id in enumerate(interaction_history[:50]):
234
+ if self.faiss_index and item_id in self.faiss_index.item_id_to_idx:
235
+ item_emb = self.faiss_index.get_item_embedding(item_id)
236
+ if item_emb is not None:
237
+ history_embeddings[i] = item_emb
238
+
239
+ # Create user features
240
+ user_features = {
241
+ 'age': tf.constant([age_cat]),
242
+ 'gender': tf.constant([gender_cat]),
243
+ 'income': tf.constant([income_cat]),
244
+ 'item_history_embeddings': tf.constant([history_embeddings])
245
+ }
246
+
247
+ # Get embedding
248
+ user_output = self.user_tower(user_features, training=False)
249
+ if isinstance(user_output, tuple):
250
+ user_embedding = user_output[0].numpy()[0]
251
+ else:
252
+ user_embedding = user_output.numpy()[0]
253
+
254
+ return user_embedding
255
+
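`get_user_embedding` packs the interaction history into a fixed-size, zero-padded `(50, 128)` block before the forward pass. The padding step in isolation, with dimensions shrunk for illustration (the function name is illustrative):

```python
import numpy as np

def pad_history(embs, max_len=50, dim=128):
    """Stack up to max_len item embeddings into a fixed (max_len, dim) array,
    zero-padding the tail and truncating anything past max_len."""
    out = np.zeros((max_len, dim), dtype=np.float32)
    for i, e in enumerate(embs[:max_len]):
        out[i] = e
    return out
```

Unknown items (missing from the FAISS index) simply stay as zero rows, which is also how an empty history is represented.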
256
+ def get_item_embedding(self, item_id: int) -> Optional[np.ndarray]:
257
+ """Get item embedding."""
258
+
259
+ if self.faiss_index:
260
+ return self.faiss_index.get_item_embedding(item_id)
261
+
262
+ # Fallback to model computation
263
+ if self.item_tower is None:
264
+ return None
265
+
266
+ item_row = self.items_df[self.items_df['product_id'] == item_id]
267
+ if item_row.empty:
268
+ return None
269
+
270
+ item_data = item_row.iloc[0]
271
+
272
+ # Prepare features
273
+ item_features = {
274
+ 'product_id': tf.constant([self.data_processor.item_vocab.get(item_id, 0)]),
275
+ 'category_id': tf.constant([self.data_processor.category_vocab.get(item_data['category_id'], 0)]),
276
+ 'brand_id': tf.constant([self.data_processor.brand_vocab.get(item_data.get('brand', 'unknown'), 0)]),
277
+ 'price': tf.constant([float(item_data.get('price', 0.0))])
278
+ }
279
+
280
+ # Get embedding
281
+ item_output = self.item_tower(item_features, training=False)
282
+ if isinstance(item_output, tuple):
283
+ item_embedding = item_output[0].numpy()[0]
284
+ else:
285
+ item_embedding = item_output.numpy()[0]
286
+
287
+ return item_embedding
288
+
289
+ def recommend_items_enhanced(self,
290
+ age: int,
291
+ gender: str,
292
+ income: float,
293
+ interaction_history: List[int] = None,
294
+ k: int = 10,
295
+ diversity_weight: float = 0.3,
296
+ category_boost: float = 1.5) -> List[Tuple[int, float, Dict]]:
297
+ """Generate enhanced recommendations with diversity and category boosting."""
298
+
299
+ if not self.faiss_index:
300
+ print("❌ FAISS index not available")
301
+ return []
302
+
303
+ # Get user embedding
304
+ user_embedding = self.get_user_embedding(age, gender, income, interaction_history)
305
+ if user_embedding is None:
306
+ return []
307
+
308
+ # Get candidate recommendations (more than needed for filtering)
309
+ candidates = self.faiss_index.search_by_embedding(user_embedding, k * 3)
310
+
311
+ # Filter out items from interaction history
312
+ if interaction_history:
313
+ history_set = set(interaction_history)
314
+ candidates = [(item_id, score) for item_id, score in candidates
315
+ if item_id not in history_set]
316
+
317
+ # Add item metadata and apply enhancements
318
+ enhanced_candidates = []
319
+
320
+ for item_id, similarity_score in candidates[:k * 2]:
321
+ # Get item info
322
+ item_row = self.items_df[self.items_df['product_id'] == item_id]
323
+ if item_row.empty:
324
+ continue
325
+
326
+ item_info = item_row.iloc[0].to_dict()
327
+
328
+ # Enhanced scoring with multiple factors
329
+ final_score = similarity_score
330
+
331
+ # Category boosting based on user history
332
+ if interaction_history and category_boost > 1.0:
333
+ user_categories = self._get_user_categories(interaction_history)
334
+ item_category = item_info.get('category_code', '')
335
+
336
+ if item_category in user_categories:
337
+ category_preference = user_categories[item_category]
338
+ final_score *= (1 + (category_boost - 1) * category_preference)
339
+
340
+ enhanced_candidates.append((item_id, final_score, item_info))
341
+
342
+ # Sort by enhanced scores
343
+ enhanced_candidates.sort(key=lambda x: x[1], reverse=True)
344
+
345
+ # Apply diversity filtering
346
+ if diversity_weight > 0:
347
+ diversified_candidates = self._apply_diversity_filter(
348
+ enhanced_candidates, diversity_weight
349
+ )
350
+ else:
351
+ diversified_candidates = enhanced_candidates
352
+
353
+ return diversified_candidates[:k]
354
+
355
+ def _get_user_categories(self, interaction_history: List[int]) -> Dict[str, float]:
356
+ """Get user's category preferences from history."""
357
+
358
+ category_counts = Counter()
359
+
360
+ for item_id in interaction_history:
361
+ item_row = self.items_df[self.items_df['product_id'] == item_id]
362
+ if not item_row.empty:
363
+ category = item_row.iloc[0].get('category_code', 'Unknown')
364
+ category_counts[category] += 1
365
+
366
+ # Convert to preferences (percentages)
367
+ total = sum(category_counts.values())
368
+ if total == 0:
369
+ return {}
370
+
371
+ return {cat: count / total for cat, count in category_counts.items()}
372
+
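The counting logic in `_get_user_categories` reduces to normalizing a `Counter`. A self-contained sketch with plain category labels instead of the DataFrame lookup (`category_preferences` is a hypothetical name):

```python
from collections import Counter

def category_preferences(history_categories):
    """Normalize raw category counts from the user's history into
    preference weights that sum to 1 (the DataFrame lookup is elided)."""
    counts = Counter(history_categories)
    total = sum(counts.values())
    if total == 0:
        return {}
    return {cat: n / total for cat, n in counts.items()}
```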
373
+ def _apply_diversity_filter(self,
374
+ candidates: List[Tuple[int, float, Dict]],
375
+ diversity_weight: float,
376
+ max_per_category: int = 3) -> List[Tuple[int, float, Dict]]:
377
+ """Apply diversity filtering to recommendations."""
378
+
379
+ category_counts = defaultdict(int)
380
+ diversified = []
381
+
382
+ for item_id, score, item_info in candidates:
383
+ category = item_info.get('category_code', 'Unknown')
384
+
385
+ # Apply diversity penalty
386
+ if category_counts[category] >= max_per_category:
387
+ # Penalty for over-representation
388
+ diversity_penalty = diversity_weight * (category_counts[category] - max_per_category + 1)
389
+ adjusted_score = score * (1 - diversity_penalty)
390
+ else:
391
+ adjusted_score = score
392
+
393
+ diversified.append((item_id, adjusted_score, item_info))
394
+ category_counts[category] += 1
395
+
396
+ # Re-sort by adjusted scores
397
+ diversified.sort(key=lambda x: x[1], reverse=True)
398
+ return diversified
399
+
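The penalty schedule in `_apply_diversity_filter` grows linearly with each extra item past `max_per_category`. This isolated sketch mirrors that logic on `(item_id, score, category)` tuples (a simplification of the metadata dicts used above):

```python
from collections import defaultdict

def apply_diversity_penalty(candidates, diversity_weight=0.3, max_per_category=3):
    """Down-weight an item's score once its category has already filled
    max_per_category slots, then re-sort by the adjusted scores."""
    category_counts = defaultdict(int)
    adjusted = []
    for item_id, score, category in candidates:
        if category_counts[category] >= max_per_category:
            penalty = diversity_weight * (category_counts[category] - max_per_category + 1)
            score = score * (1 - penalty)
        adjusted.append((item_id, score, category))
        category_counts[category] += 1
    adjusted.sort(key=lambda x: x[1], reverse=True)
    return adjusted
```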
400
+ def predict_rating(self,
401
+ age: int,
402
+ gender: str,
403
+ income: float,
404
+ item_id: int,
405
+ interaction_history: List[int] = None) -> float:
406
+ """Predict rating for user-item pair."""
407
+
408
+ if self.rating_model is None:
409
+ return 0.5 # Default rating
410
+
411
+ # Get embeddings
412
+ user_embedding = self.get_user_embedding(age, gender, income, interaction_history)
413
+ item_embedding = self.get_item_embedding(item_id)
414
+
415
+ if user_embedding is None or item_embedding is None:
416
+ return 0.5
417
+
418
+ # Concatenate embeddings
419
+ combined = np.concatenate([user_embedding, item_embedding])
420
+ combined = tf.constant([combined])
421
+
422
+ # Predict rating
423
+ rating = self.rating_model(combined, training=False)
424
+ return float(rating.numpy()[0][0])
425
+
426
+
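`predict_rating` feeds the rating head a single 256-d vector: the two 128-d tower outputs concatenated, plus a batch dimension. A numpy sketch of that input shape (hypothetical `rating_input` helper, no TensorFlow dependency):

```python
import numpy as np

def rating_input(user_embedding, item_embedding):
    """Concatenate the 128-d user and item embeddings into the single
    256-d vector the rating head consumes, with a batch dimension added."""
    combined = np.concatenate([user_embedding, item_embedding])
    return combined[np.newaxis, :]
```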
427
+ def demo_enhanced_engine():
428
+ """Demo the enhanced 128D recommendation engine."""
429
+
430
+ print("🚀 ENHANCED 128D RECOMMENDATION ENGINE DEMO")
431
+ print("="*70)
432
+
433
+ try:
434
+ # Initialize engine
435
+ engine = Enhanced128DRecommendationEngine()
436
+
437
+ if engine.item_tower is None:
438
+ print("❌ Enhanced model not available. Please train first using:")
439
+ print(" python train_enhanced_model.py")
440
+ return
441
+
442
+ # Get real user for testing
443
+ real_user_selector = RealUserSelector()
444
+ test_users = real_user_selector.get_real_users(n=2, min_interactions=10)
445
+
446
+ for user in test_users:
447
+ print(f"\n📊 Testing User {user['user_id']} ({user['age']}yr {user['gender']}):")
448
+ print(f" Income: ${user['income']:,}")
449
+ print(f" History: {len(user['interaction_history'])} items")
450
+
451
+ # Test enhanced recommendations
452
+ try:
453
+ recs = engine.recommend_items_enhanced(
454
+ age=user['age'],
455
+ gender=user['gender'],
456
+ income=user['income'],
457
+ interaction_history=user['interaction_history'][:20],
458
+ k=10,
459
+ diversity_weight=0.3,
460
+ category_boost=1.5
461
+ )
462
+
463
+ print(f" 🎯 Enhanced Recommendations:")
464
+ categories = []
465
+ for i, (item_id, score, item_info) in enumerate(recs[:5]):
466
+ category = item_info.get('category_code', 'Unknown')[:30]
467
+ price = item_info.get('price', 0)
468
+ categories.append(category)
469
+ print(f" #{i+1} Item {item_id}: {score:.4f} | ${price:.2f} | {category}")
470
+
471
+ # Analyze diversity
472
+ unique_categories = len(set(categories))
473
+ print(f" 📈 Diversity: {unique_categories}/{len(categories)} unique categories")
474
+
475
+ # Test rating prediction
476
+ if recs:
477
+ test_item = recs[0][0]
478
+ predicted_rating = engine.predict_rating(
479
+ age=user['age'],
480
+ gender=user['gender'],
481
+ income=user['income'],
482
+ item_id=test_item,
483
+ interaction_history=user['interaction_history'][:20]
484
+ )
485
+ print(f" ⭐ Rating prediction for item {test_item}: {predicted_rating:.3f}")
486
+
487
+ except Exception as e:
488
+ print(f" ❌ Error: {e}")
489
+
490
+ print(f"\n✅ Enhanced 128D engine demo completed!")
491
+
492
+ except Exception as e:
493
+ print(f"❌ Demo failed: {e}")
494
+ import traceback
495
+ traceback.print_exc()
496
+
497
+
498
+ if __name__ == "__main__":
499
+ demo_enhanced_engine()
src/inference/faiss_index.py CHANGED
@@ -8,7 +8,7 @@ from typing import Dict, List, Tuple, Optional
8
  class FAISSItemIndex:
9
  """FAISS-based item similarity search index."""
10
 
11
- def __init__(self, embedding_dim: int = 64):
12
  self.embedding_dim = embedding_dim
13
  self.index = None
14
  self.item_id_to_idx = {}
@@ -40,14 +40,10 @@ class FAISSItemIndex:
40
  # Exact search (slower but accurate)
41
  self.index = faiss.IndexFlatIP(self.embedding_dim)
42
  elif index_type == "IVF":
43
- # Approximate search (faster)
44
- nlist = min(100, len(item_ids) // 10) # Number of clusters
45
- quantizer = faiss.IndexFlatIP(self.embedding_dim)
46
- self.index = faiss.IndexIVFFlat(quantizer, self.embedding_dim, nlist)
47
-
48
- # Train the index
49
- self.index.train(embeddings_array)
50
- self.index.nprobe = min(10, nlist) # Search in top 10 clusters
51
  else:
52
  raise ValueError(f"Unsupported index type: {index_type}")
53
 
@@ -134,6 +130,7 @@ class FAISSItemIndex:
134
  sample_queries = list(self.item_id_to_idx.keys())[:5]
135
 
136
  print("Validating FAISS index...")
 
137
 
138
  for query_item in sample_queries:
139
  if query_item not in self.item_id_to_idx:
@@ -141,9 +138,16 @@ class FAISSItemIndex:
141
 
142
  similar_items = self.search_similar_items(query_item, k=5)
143
 
144
- print(f"\nSimilar items to {query_item}:")
145
- for item_id, score in similar_items:
146
- print(f" Item {item_id}: similarity = {score:.4f}")
147
 
148
  def save_index(self, save_path: str = "src/artifacts/") -> None:
149
  """Save FAISS index and mappings."""
@@ -197,7 +201,7 @@ def main():
197
 
198
  # Create and build FAISS index
199
  print("Building FAISS index...")
200
- faiss_index = FAISSItemIndex(embedding_dim=64)
201
  faiss_index.build_index(item_embeddings, index_type="IVF")
202
 
203
  # Validate index
 
8
  class FAISSItemIndex:
9
  """FAISS-based item similarity search index."""
10
 
11
+ def __init__(self, embedding_dim: int = 128):
12
  self.embedding_dim = embedding_dim
13
  self.index = None
14
  self.item_id_to_idx = {}
 
40
  # Exact search (slower but accurate)
41
  self.index = faiss.IndexFlatIP(self.embedding_dim)
42
  elif index_type == "IVF":
43
+ # On CPU, use exact search (IndexFlatIP) for better accuracy;
44
+ # IVF's approximate search mainly pays off on GPU or at much larger scale
45
+ print("Using IndexFlatIP for CPU (exact search)")
46
+ self.index = faiss.IndexFlatIP(self.embedding_dim)
47
  else:
48
  raise ValueError(f"Unsupported index type: {index_type}")
49
 
 
130
  sample_queries = list(self.item_id_to_idx.keys())[:5]
131
 
132
  print("Validating FAISS index...")
133
+ print("Note: Higher similarity scores = more similar items (cosine similarity)")
134
 
135
  for query_item in sample_queries:
136
  if query_item not in self.item_id_to_idx:
 
138
 
139
  similar_items = self.search_similar_items(query_item, k=5)
140
 
141
+ print(f"\nSimilar items to {query_item} (sorted by similarity DESC):")
142
+ for i, (item_id, score) in enumerate(similar_items):
143
+ print(f" #{i+1} Item {item_id}: similarity = {score:.4f}")
144
+
145
+ # Check if scores are properly ordered (descending)
146
+ scores = [score for _, score in similar_items]
147
+ if len(scores) > 1 and not all(scores[i] >= scores[i+1] for i in range(len(scores)-1)):
148
+ print(f" WARNING: Scores not in descending order! {scores}")
149
+ else:
150
+ print(f" ✓ Scores properly ordered (most to least similar)")
151
 
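The ordering check added above relies on `IndexFlatIP` returning inner products in descending order; for L2-normalized embeddings the inner product is exactly cosine similarity. A small numpy sketch of that equivalence (hypothetical helper, no FAISS dependency):

```python
import numpy as np

def top_k_by_inner_product(query, items, k=3):
    """Inner product of L2-normalized vectors equals cosine similarity
    (what IndexFlatIP returns for pre-normalized embeddings); results
    come back most-similar first."""
    q = query / np.linalg.norm(query)
    m = items / np.linalg.norm(items, axis=1, keepdims=True)
    scores = m @ q
    order = [int(i) for i in np.argsort(-scores)[:k]]
    return order, [float(scores[i]) for i in order]
```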
152
  def save_index(self, save_path: str = "src/artifacts/") -> None:
153
  """Save FAISS index and mappings."""
 
201
 
202
  # Create and build FAISS index
203
  print("Building FAISS index...")
204
+ faiss_index = FAISSItemIndex(embedding_dim=128)
205
  faiss_index.build_index(item_embeddings, index_type="IVF")
206
 
207
  # Validate index
src/inference/recommendation_engine.py CHANGED
@@ -129,8 +129,8 @@ class RecommendationEngine:
129
 
130
  self.user_tower = UserTower(
131
  max_history_length=50,
132
- embedding_dim=64,
133
- hidden_dims=[128, 64],
134
  dropout_rate=0.2
135
  )
136
 
@@ -139,7 +139,7 @@ class RecommendationEngine:
139
  'age': tf.constant([2]), # Adult category (26-35)
140
  'gender': tf.constant([1]), # Male
141
  'income': tf.constant([2]), # Middle income category
142
- 'item_history_embeddings': tf.constant([[[0.0] * 64] * 50])
143
  }
144
  _ = self.user_tower(dummy_input)
145
 
@@ -182,7 +182,7 @@ class RecommendationEngine:
182
  ])
183
 
184
  # Build model with dummy input (concatenated user and item embeddings)
185
- dummy_input = tf.constant([[0.0] * 128]) # 64 + 64 = 128
186
  _ = self.rating_model(dummy_input)
187
 
188
  try:
@@ -216,14 +216,16 @@ class RecommendationEngine:
216
  history_embeddings.append(embedding)
217
  else:
218
  # Use zero embedding for unknown items
219
- history_embeddings.append(np.zeros(64))
220
 
221
  # Pad or truncate to max_history_length
222
  max_history_length = 50
223
  if len(history_embeddings) < max_history_length:
224
- padding = [np.zeros(64)] * (max_history_length - len(history_embeddings))
225
- history_embeddings = padding + history_embeddings
 
226
  else:
 
227
  history_embeddings = history_embeddings[-max_history_length:]
228
 
229
  history_embeddings = np.array(history_embeddings, dtype=np.float32)
@@ -301,27 +303,52 @@ class RecommendationEngine:
301
  income: float,
302
  interaction_history: List[int] = None,
303
  k: int = 10,
304
- exclude_history: bool = True) -> List[Tuple[int, float, Dict]]:
305
- """Generate recommendations using collaborative filtering (user-item similarity)."""
 
306
 
307
  # Get user embedding
308
  user_embedding = self.get_user_embedding(age, gender, income, interaction_history)
309
 
310
- # Find similar items using FAISS
311
- similar_items = self.faiss_index.search_by_embedding(user_embedding, k * 2)
312
 
313
- # Filter out interaction history if requested
314
- if exclude_history and interaction_history:
315
- history_set = set(interaction_history)
316
- similar_items = [(item_id, score) for item_id, score in similar_items
317
- if item_id not in history_set]
318
 
319
- # Take top k
320
- similar_items = similar_items[:k]
 
321
 
322
  # Add item metadata
323
  recommendations = []
324
- for item_id, score in similar_items:
325
  item_info = self._get_item_info(item_id)
326
  recommendations.append((item_id, score, item_info))
327
 
 
129
 
130
  self.user_tower = UserTower(
131
  max_history_length=50,
132
+ embedding_dim=128, # Changed from 64 to 128
133
+ hidden_dims=[128, 64], # Match training architecture
134
  dropout_rate=0.2
135
  )
136
 
 
139
  'age': tf.constant([2]), # Adult category (26-35)
140
  'gender': tf.constant([1]), # Male
141
  'income': tf.constant([2]), # Middle income category
142
+ 'item_history_embeddings': tf.constant([[[0.0] * 128] * 50]) # Changed from 64 to 128
143
  }
144
  _ = self.user_tower(dummy_input)
145
 
 
182
  ])
183
 
184
  # Build model with dummy input (concatenated user and item embeddings)
185
+ dummy_input = tf.constant([[0.0] * 256]) # 128 + 128 = 256
186
  _ = self.rating_model(dummy_input)
187
 
188
  try:
 
216
  history_embeddings.append(embedding)
217
  else:
218
  # Use zero embedding for unknown items
219
+ history_embeddings.append(np.zeros(128)) # Changed from 64 to 128
220
 
221
  # Pad or truncate to max_history_length
222
  max_history_length = 50
223
  if len(history_embeddings) < max_history_length:
224
+ # Add padding at the END so real interactions are at the BEGINNING
225
+ padding = [np.zeros(128)] * (max_history_length - len(history_embeddings))
226
+ history_embeddings = history_embeddings + padding
227
  else:
228
+ # Keep most recent interactions
229
  history_embeddings = history_embeddings[-max_history_length:]
230
 
231
  history_embeddings = np.array(history_embeddings, dtype=np.float32)
 
303
  income: float,
304
  interaction_history: List[int] = None,
305
  k: int = 10,
306
+ exclude_history: bool = True,
307
+ category_boost: float = 1.3) -> List[Tuple[int, float, Dict]]:
308
+ """Generate recommendations using collaborative filtering with category awareness."""
309
 
310
  # Get user embedding
311
  user_embedding = self.get_user_embedding(age, gender, income, interaction_history)
312
 
313
+ # Find similar items using FAISS (get more candidates for boosting)
314
+ similar_items = self.faiss_index.search_by_embedding(user_embedding, k * 4)
315
 
316
+ # Get user's preferred categories from interaction history
317
+ user_categories = set()
318
+ if interaction_history:
319
+ for item_id in interaction_history[-10:]: # Focus on recent interactions
320
+ item_row = self.items_df[self.items_df['product_id'] == item_id]
321
+ if len(item_row) > 0:
322
+ user_categories.add(item_row.iloc[0]['category_code'])
323
+
324
+ # Filter out interaction history and apply category boosting
325
+ boosted_items = []
326
+ history_set = set(interaction_history) if (exclude_history and interaction_history) else set()
327
+
328
+ for item_id, score in similar_items:
329
+ if item_id in history_set:
330
+ continue
331
+
332
+ # Get item category
333
+ item_row = self.items_df[self.items_df['product_id'] == item_id]
334
+ if len(item_row) > 0:
335
+ item_category = item_row.iloc[0]['category_code']
336
+
337
+ # Boost score if item is in user's preferred categories
338
+ if item_category in user_categories:
339
+ boosted_score = score * category_boost
340
+ else:
341
+ boosted_score = score
342
+
343
+ boosted_items.append((item_id, boosted_score))
344
 
345
+ # Sort by boosted score and take top k
346
+ boosted_items.sort(key=lambda x: x[1], reverse=True)
347
+ boosted_items = boosted_items[:k]
348
 
349
  # Add item metadata
350
  recommendations = []
351
+ for item_id, score in boosted_items:
352
  item_info = self._get_item_info(item_id)
353
  recommendations.append((item_id, score, item_info))
354
 
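The category-aware re-ranking added to `recommend_items` boils down to a multiply-then-sort over FAISS candidates. A minimal sketch on `(item_id, score, category)` tuples (`boost_by_category` is a hypothetical name; the DataFrame lookup is elided):

```python
def boost_by_category(candidates, user_categories, category_boost=1.3):
    """Re-rank candidates: multiply the similarity score of items whose
    category appears in the user's recent history, then sort descending."""
    boosted = []
    for item_id, score, category in candidates:
        if category in user_categories:
            score *= category_boost
        boosted.append((item_id, score))
    boosted.sort(key=lambda x: x[1], reverse=True)
    return boosted
```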
src/models/enhanced_two_tower.py ADDED
@@ -0,0 +1,574 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Enhanced two-tower model with embedding diversity regularization and improved discrimination.
4
+ """
5
+
6
+ import tensorflow as tf
7
+ import tensorflow_recommenders as tfrs
8
+ import numpy as np
9
+
10
+
11
+ class EmbeddingDiversityRegularizer(tf.keras.layers.Layer):
12
+ """Regularizer to prevent embedding collapse by enforcing diversity."""
13
+
14
+ def __init__(self, diversity_weight=0.01, orthogonality_weight=0.05, **kwargs):
15
+ super().__init__(**kwargs)
16
+ self.diversity_weight = diversity_weight
17
+ self.orthogonality_weight = orthogonality_weight
18
+
19
+ def call(self, embeddings):
20
+ """Apply diversity regularization to embeddings."""
21
+ batch_size = tf.shape(embeddings)[0]
22
+
23
+ # Compute pairwise cosine similarities
24
+ normalized_embeddings = tf.nn.l2_normalize(embeddings, axis=1)
25
+ similarity_matrix = tf.linalg.matmul(
26
+ normalized_embeddings, normalized_embeddings, transpose_b=True
27
+ )
28
+
29
+ # Remove diagonal (self-similarities)
30
+ mask = 1.0 - tf.eye(batch_size)
31
+ masked_similarities = similarity_matrix * mask
32
+
33
+ # Diversity loss: penalize high similarities between different embeddings
34
+ diversity_loss = tf.reduce_mean(tf.square(masked_similarities))
35
+
36
+ # Orthogonality loss: encourage embeddings to be orthogonal
37
+ identity_target = tf.eye(batch_size)
38
+ orthogonality_loss = tf.reduce_mean(
39
+ tf.square(similarity_matrix - identity_target)
40
+ )
41
+
42
+ # Add as regularization losses
43
+ self.add_loss(self.diversity_weight * diversity_loss)
44
+ self.add_loss(self.orthogonality_weight * orthogonality_loss)
45
+
46
+ return embeddings
47
+
48
+
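The diversity term above is the mean squared off-diagonal cosine similarity of the batch: zero when embeddings are mutually orthogonal, large when they collapse. A numpy sketch of that scalar (hypothetical `diversity_loss` helper, no TensorFlow dependency):

```python
import numpy as np

def diversity_loss(embeddings):
    """Mean squared off-diagonal cosine similarity of a batch of embeddings:
    zero for mutually orthogonal rows, large when embeddings collapse."""
    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    sim = normed @ normed.T
    off_diag = sim * (1.0 - np.eye(sim.shape[0]))
    return float(np.mean(off_diag ** 2))
```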
49
+ class AdaptiveTemperatureScaling(tf.keras.layers.Layer):
50
+ """Advanced temperature scaling with learned parameters."""
51
+
52
+ def __init__(self, initial_temperature=1.0, min_temp=0.1, max_temp=5.0, **kwargs):
53
+ super().__init__(**kwargs)
54
+ self.initial_temperature = initial_temperature
55
+ self.min_temp = min_temp
56
+ self.max_temp = max_temp
57
+
58
+ def build(self, input_shape):
59
+ # Learnable temperature with constraints
60
+ self.raw_temperature = self.add_weight(
61
+ name='raw_temperature',
62
+ shape=(),
63
+ initializer=tf.keras.initializers.Constant(
64
+ np.log(self.initial_temperature - self.min_temp)
65
+ ),
66
+ trainable=True
67
+ )
68
+
69
+ # Learnable bias term for better discrimination
70
+ self.similarity_bias = self.add_weight(
71
+ name='similarity_bias',
72
+ shape=(),
73
+ initializer=tf.keras.initializers.Zeros(),
74
+ trainable=True
75
+ )
76
+
77
+ super().build(input_shape)
78
+
79
+ def call(self, user_embeddings, item_embeddings):
80
+ """Compute adaptive temperature-scaled similarity with bias."""
81
+ # Constrain temperature to valid range
82
+ temperature = self.min_temp + tf.nn.softplus(self.raw_temperature)
83
+ temperature = tf.minimum(temperature, self.max_temp)
84
+
85
+ # Compute similarities
86
+ similarities = tf.reduce_sum(user_embeddings * item_embeddings, axis=1)
87
+
88
+ # Add learnable bias and apply temperature scaling
89
+ scaled_similarities = (similarities + self.similarity_bias) / temperature
90
+
91
+ return scaled_similarities, temperature
92
+
93
+
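The `min_temp + softplus(raw)` parameterization used by `AdaptiveTemperatureScaling` keeps the learned temperature strictly above `min_temp`, with a hard cap at `max_temp`. A scalar sketch of the same constraint (hypothetical helper names, pure Python):

```python
import math

def constrained_temperature(raw, min_temp=0.1, max_temp=5.0):
    """min_temp + softplus(raw) keeps the temperature strictly above
    min_temp; the cap bounds it at max_temp."""
    return min(min_temp + math.log1p(math.exp(raw)), max_temp)

def scaled_similarity(similarity, bias, raw_temp):
    """Temperature-scaled similarity with a learnable bias, mirroring
    AdaptiveTemperatureScaling.call on scalars."""
    return (similarity + bias) / constrained_temperature(raw_temp)
```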
94
+ class EnhancedItemTower(tf.keras.Model):
95
+ """Enhanced item tower with diversity regularization."""
96
+
97
+ def __init__(self,
98
+ item_vocab_size: int,
99
+ category_vocab_size: int,
100
+ brand_vocab_size: int,
101
+ embedding_dim: int = 128,
102
+ hidden_dims: list = [256, 128],
103
+ dropout_rate: float = 0.3,
104
+ use_bias: bool = True,
105
+ use_diversity_reg: bool = True):
106
+ super().__init__()
107
+
108
+ self.embedding_dim = embedding_dim
109
+ self.use_bias = use_bias
110
+ self.use_diversity_reg = use_diversity_reg
111
+
112
+ # Embedding layers with better initialization
113
+ self.item_embedding = tf.keras.layers.Embedding(
114
+ item_vocab_size, embedding_dim,
115
+ embeddings_initializer='he_normal', # Better initialization
116
+ embeddings_regularizer=tf.keras.regularizers.L2(1e-6),
117
+ name="item_embedding"
118
+ )
119
+ self.category_embedding = tf.keras.layers.Embedding(
120
+ category_vocab_size, embedding_dim,
121
+ embeddings_initializer='he_normal',
122
+ embeddings_regularizer=tf.keras.regularizers.L2(1e-6),
123
+ name="category_embedding"
124
+ )
125
+ self.brand_embedding = tf.keras.layers.Embedding(
126
+ brand_vocab_size, embedding_dim,
127
+ embeddings_initializer='he_normal',
128
+ embeddings_regularizer=tf.keras.regularizers.L2(1e-6),
129
+ name="brand_embedding"
130
+ )
131
+
132
+ # Price processing
133
+ self.price_normalization = tf.keras.layers.Normalization(name="price_norm")
134
+ self.price_projection = tf.keras.layers.Dense(
135
+ embedding_dim // 4, activation='relu', name="price_proj"
136
+ )
137
+
138
+ # Enhanced attention mechanism
139
+ self.feature_attention = tf.keras.layers.MultiHeadAttention(
140
+ num_heads=4,
141
+ key_dim=embedding_dim,
142
+ dropout=0.1,
143
+ name="feature_attention"
144
+ )
145
+
146
+ # Dense layers with residual connections
147
+ self.dense_layers = []
148
+ for i, dim in enumerate(hidden_dims):
149
+ self.dense_layers.extend([
150
+ tf.keras.layers.Dense(dim, activation=None, name=f"dense_{i}"),
151
+ tf.keras.layers.BatchNormalization(name=f"bn_{i}"),
152
+ tf.keras.layers.Activation('relu', name=f"relu_{i}"),
153
+ tf.keras.layers.Dropout(dropout_rate, name=f"dropout_{i}")
154
+ ])
155
+
156
+ # Output layer with controlled normalization
157
+ self.output_layer = tf.keras.layers.Dense(
158
+ embedding_dim, activation=None, use_bias=use_bias, name="item_output"
159
+ )
160
+
161
+ # Diversity regularizer
162
+ if use_diversity_reg:
163
+ self.diversity_regularizer = EmbeddingDiversityRegularizer()
164
+
165
+ # Adaptive normalization instead of hard L2 normalization
166
+ self.adaptive_norm = tf.keras.layers.LayerNormalization(name="adaptive_norm")
167
+
168
+ # Item bias
169
+ if use_bias:
170
+ self.item_bias = tf.keras.layers.Embedding(
171
+ item_vocab_size, 1, name="item_bias"
172
+ )
173
+
174
+ def call(self, inputs, training=None):
175
+ """Enhanced forward pass with diversity regularization."""
176
+ item_id = inputs["product_id"]
177
+ category_id = inputs["category_id"]
178
+ brand_id = inputs["brand_id"]
179
+ price = inputs["price"]
180
+
181
+ # Get embeddings
182
+ item_emb = self.item_embedding(item_id)
183
+ category_emb = self.category_embedding(category_id)
184
+ brand_emb = self.brand_embedding(brand_id)
185
+
186
+ # Process price
187
+ price_norm = self.price_normalization(tf.expand_dims(price, -1))
188
+ price_emb = self.price_projection(price_norm)
189
+
190
+ # Pad price embedding
191
+ price_emb_padded = tf.pad(
192
+ price_emb,
193
+ [[0, 0], [0, self.embedding_dim - tf.shape(price_emb)[-1]]]
194
+ )
195
+
196
+ # Stack features for attention
197
+ features = tf.stack([item_emb, category_emb, brand_emb, price_emb_padded], axis=1)
198
+
199
+ # Apply attention
200
+ attended_features = self.feature_attention(
201
+ query=features,
202
+ value=features,
203
+ key=features,
204
+ training=training
205
+ )
206
+
207
+ # Aggregate with residual connection
208
+ combined = tf.reduce_mean(attended_features + features, axis=1)
209
+
210
+ # Pass through dense layers with residual connections
211
+ x = combined
212
+ residual = x
213
+ for i, layer in enumerate(self.dense_layers):
214
+ x = layer(x, training=training)
215
+ # Add residual connection every 4 layers (complete block)
216
+ if (i + 1) % 4 == 0 and x.shape[-1] == residual.shape[-1]:
217
+ x = x + residual
218
+ residual = x
219
+
220
+ # Final output
221
+ output = self.output_layer(x)
222
+
223
+ # Apply diversity regularization if enabled
224
+ if self.use_diversity_reg and training:
225
+ output = self.diversity_regularizer(output)
226
+
227
+ # Adaptive normalization instead of hard L2
228
+ normalized_output = self.adaptive_norm(output)
229
+
230
+ # Add bias if enabled
231
+ if self.use_bias:
232
+ bias = tf.squeeze(self.item_bias(item_id), axis=-1)
233
+ return normalized_output, bias
234
+ else:
235
+ return normalized_output
236
+
237
+
238
+ class EnhancedUserTower(tf.keras.Model):
239
+ """Enhanced user tower with diversity regularization."""
240
+
241
+ def __init__(self,
242
+ max_history_length: int = 50,
243
+ embedding_dim: int = 128,
244
+ hidden_dims: list = [256, 128],
245
+ dropout_rate: float = 0.3,
246
+ use_bias: bool = True,
247
+ use_diversity_reg: bool = True):
248
+ super().__init__()
249
+
250
+ self.embedding_dim = embedding_dim
251
+ self.max_history_length = max_history_length
252
+ self.use_bias = use_bias
253
+ self.use_diversity_reg = use_diversity_reg
254
+
255
+ # Demographic embeddings with regularization
256
+ self.age_embedding = tf.keras.layers.Embedding(
257
+ 6, embedding_dim // 16,
258
+ embeddings_initializer='he_normal',
259
+ embeddings_regularizer=tf.keras.regularizers.L2(1e-6),
260
+ name="age_embedding"
261
+ )
262
+ self.income_embedding = tf.keras.layers.Embedding(
263
+ 5, embedding_dim // 16,
264
+ embeddings_initializer='he_normal',
265
+ embeddings_regularizer=tf.keras.regularizers.L2(1e-6),
266
+ name="income_embedding"
267
+ )
268
+ self.gender_embedding = tf.keras.layers.Embedding(
269
+ 2, embedding_dim // 16,
270
+ embeddings_initializer='he_normal',
271
+ embeddings_regularizer=tf.keras.regularizers.L2(1e-6),
272
+ name="gender_embedding"
273
+ )
274
+
275
+ # Enhanced history processing
276
+ self.history_transformer = tf.keras.layers.MultiHeadAttention(
277
+ num_heads=8,
278
+ key_dim=embedding_dim,
279
+ dropout=0.1,
280
+ name="history_transformer"
281
+ )
282
+
283
+ # History aggregation with attention pooling
284
+ self.history_attention_pooling = tf.keras.layers.Dense(
285
+ 1, activation=None, name="history_attention"
286
+ )
287
+
288
+ # Dense layers with residual connections
289
+ self.dense_layers = []
290
+ for i, dim in enumerate(hidden_dims):
291
+ self.dense_layers.extend([
292
+ tf.keras.layers.Dense(dim, activation=None, name=f"user_dense_{i}"),
293
+ tf.keras.layers.BatchNormalization(name=f"user_bn_{i}"),
294
+ tf.keras.layers.Activation('relu', name=f"user_relu_{i}"),
295
+ tf.keras.layers.Dropout(dropout_rate, name=f"user_dropout_{i}")
296
+ ])
297
+
298
+ # Output layer
299
+ self.output_layer = tf.keras.layers.Dense(
300
+ embedding_dim, activation=None, use_bias=use_bias, name="user_output"
301
+ )
302
+
303
+ # Diversity regularizer
304
+ if use_diversity_reg:
305
+ self.diversity_regularizer = EmbeddingDiversityRegularizer()
306
+
307
+ # Adaptive normalization
308
+ self.adaptive_norm = tf.keras.layers.LayerNormalization(name="user_adaptive_norm")
309
+
310
+ # Global user bias
311
+ if use_bias:
312
+ self.global_user_bias = tf.Variable(
313
+ initial_value=0.0, trainable=True, name="global_user_bias"
314
+ )
315
+
316
+ def call(self, inputs, training=None):
317
+ """Enhanced forward pass with diversity regularization."""
318
+ age = inputs["age"]
319
+ gender = inputs["gender"]
320
+ income = inputs["income"]
321
+ item_history = inputs["item_history_embeddings"]
322
+
323
+ # Process demographics
324
+ age_emb = self.age_embedding(age)
325
+ income_emb = self.income_embedding(income)
326
+ gender_emb = self.gender_embedding(gender)
327
+
328
+ # Combine demographics
329
+ demo_combined = tf.concat([age_emb, income_emb, gender_emb], axis=-1)
330
+
331
+ # Enhanced history processing
332
+ batch_size = tf.shape(item_history)[0]
333
+ seq_len = tf.shape(item_history)[1]
334
+
335
+ # Simplified positional encoding - ensure shape compatibility
336
+ positions = tf.range(seq_len, dtype=tf.float32)
337
+ # Create simpler positional encoding
338
+ pos_encoding_scale = tf.range(self.embedding_dim, dtype=tf.float32) / self.embedding_dim
339
+ position_encoding = tf.sin(positions[:, tf.newaxis] * pos_encoding_scale[tf.newaxis, :])
340
+
341
+ # Ensure correct shape: [seq_len, embedding_dim] -> [batch_size, seq_len, embedding_dim]
342
+ position_encoding = tf.expand_dims(position_encoding, 0)
343
+ position_encoding = tf.tile(position_encoding, [batch_size, 1, 1])
344
+
345
+ # Add positional encoding with shape check
346
+ history_with_pos = item_history + position_encoding
347
+
348
+ # Create attention mask - fix shape for MultiHeadAttention
349
+ # MultiHeadAttention expects mask shape: [batch_size, seq_len] or [batch_size, seq_len, seq_len]
350
+ history_mask = tf.reduce_sum(tf.abs(item_history), axis=-1) > 0 # [batch_size, seq_len]
351
+
352
+ # Apply transformer attention
353
+ attended_history = self.history_transformer(
354
+ query=history_with_pos,
355
+ value=history_with_pos,
356
+ key=history_with_pos,
357
+ attention_mask=history_mask,
358
+ training=training
359
+ )
360
+
361
+ # Attention-based pooling instead of simple mean
362
+ attention_weights = tf.nn.softmax(
363
+ self.history_attention_pooling(attended_history), axis=1
364
+ )
365
+ history_aggregated = tf.reduce_sum(
366
+ attended_history * attention_weights, axis=1
367
+ )
368
+
369
+ # Combine features
370
+ combined = tf.concat([demo_combined, history_aggregated], axis=-1)
371
+
372
+ # Pass through dense layers with residual connections
373
+ x = combined
374
+ residual = x
375
+ for i, layer in enumerate(self.dense_layers):
376
+ x = layer(x, training=training)
377
+ # Add residual connection every 4 layers
378
+ if (i + 1) % 4 == 0 and x.shape[-1] == residual.shape[-1]:
379
+ x = x + residual
380
+ residual = x
381
+
382
+ # Final output
383
+ output = self.output_layer(x)
384
+
385
+ # Apply diversity regularization if enabled
386
+ if self.use_diversity_reg and training:
387
+ output = self.diversity_regularizer(output)
388
+
389
+ # Adaptive normalization
390
+ normalized_output = self.adaptive_norm(output)
391
+
392
+ # Add bias if enabled
393
+ if self.use_bias:
394
+ return normalized_output, self.global_user_bias
395
+ else:
396
+ return normalized_output
397
+
398
+
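The attention pooling above replaces a plain mean over history with a softmax-weighted one. A numpy sketch of the pooling step (hypothetical `attention_pool` helper; the learned scoring Dense layer is replaced by given scores):

```python
import numpy as np

def attention_pool(history, attention_scores):
    """Softmax-weighted mean over the history axis, replacing a plain
    average: items with higher attention scores dominate the summary."""
    w = np.exp(attention_scores - attention_scores.max())
    w = w / w.sum()
    return (history * w[:, None]).sum(axis=0)
```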
399
+ class EnhancedTwoTowerModel(tfrs.Model):
400
+ """Enhanced two-tower model with all improvements."""
401
+
402
+ def __init__(self,
403
+ item_tower: EnhancedItemTower,
404
+ user_tower: EnhancedUserTower,
405
+ rating_weight: float = 1.0,
406
+ retrieval_weight: float = 1.0,
407
+ contrastive_weight: float = 0.3,
408
+ diversity_weight: float = 0.1):
409
+ super().__init__()
410
+
411
+ self.item_tower = item_tower
412
+ self.user_tower = user_tower
413
+ self.rating_weight = rating_weight
414
+ self.retrieval_weight = retrieval_weight
415
+ self.contrastive_weight = contrastive_weight
416
+ self.diversity_weight = diversity_weight
417
+
418
+ # Adaptive temperature scaling
419
+ self.temperature_similarity = AdaptiveTemperatureScaling()
420
+
421
+ # Enhanced rating model
422
+ self.rating_model = tf.keras.Sequential([
423
+ tf.keras.layers.Dense(512, activation="relu"),
424
+ tf.keras.layers.BatchNormalization(),
425
+ tf.keras.layers.Dropout(0.3),
426
+ tf.keras.layers.Dense(256, activation="relu"),
427
+ tf.keras.layers.BatchNormalization(),
428
+ tf.keras.layers.Dropout(0.2),
429
+ tf.keras.layers.Dense(64, activation="relu"),
430
+ tf.keras.layers.Dense(1, activation="sigmoid")
431
+ ])
432
+
433
+ # Focal loss for imbalanced data
434
+ self.focal_loss = self._focal_loss
435
+
436
+ def _focal_loss(self, y_true, y_pred, alpha=0.25, gamma=2.0):
437
+ """Focal loss implementation."""
438
+ epsilon = tf.keras.backend.epsilon()
439
+ y_pred = tf.clip_by_value(y_pred, epsilon, 1.0 - epsilon)
440
+
441
+ alpha_t = y_true * alpha + (1 - y_true) * (1 - alpha)
442
+ p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
443
+ focal_weight = alpha_t * tf.pow((1 - p_t), gamma)
444
+
445
+ bce = -(y_true * tf.math.log(y_pred) + (1 - y_true) * tf.math.log(1 - y_pred))
446
+ focal_loss = focal_weight * bce
447
+
448
+ return tf.reduce_mean(focal_loss)
449
+
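The focal loss above down-weights easy, well-classified examples via the `(1 - p_t)**gamma` factor; with `gamma=0` it reduces to alpha-weighted cross-entropy. A scalar sketch of the same formula (hypothetical `focal_loss_scalar` helper, pure Python):

```python
import math

def focal_loss_scalar(y_true, y_pred, alpha=0.25, gamma=2.0, eps=1e-7):
    """Per-example focal loss: the (1 - p_t)**gamma factor down-weights
    easy, well-classified examples relative to plain cross-entropy."""
    p = min(max(y_pred, eps), 1.0 - eps)
    alpha_t = alpha if y_true == 1 else 1.0 - alpha
    p_t = p if y_true == 1 else 1.0 - p
    return alpha_t * (1.0 - p_t) ** gamma * -math.log(p_t)
```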
450
+ def call(self, features):
451
+ # Get embeddings
452
+ user_output = self.user_tower(features)
453
+ item_output = self.item_tower(features)
454
+
455
+ # Handle bias terms
456
+ if isinstance(user_output, tuple):
457
+ user_embeddings, user_bias = user_output
458
+ else:
459
+ user_embeddings = user_output
460
+ user_bias = 0.0
461
+
462
+ if isinstance(item_output, tuple):
463
+ item_embeddings, item_bias = item_output
464
+ else:
465
+ item_embeddings = item_output
466
+ item_bias = 0.0
467
+
468
+ return {
469
+ "user_embedding": user_embeddings,
470
+ "item_embedding": item_embeddings,
471
+ "user_bias": user_bias,
472
+ "item_bias": item_bias
473
+ }
474
+
475
+ def compute_loss(self, features, training=False):
476
+ # Get embeddings and biases
477
+ outputs = self(features)
478
+ user_embeddings = outputs["user_embedding"]
479
+ item_embeddings = outputs["item_embedding"]
480
+ user_bias = outputs["user_bias"]
481
+ item_bias = outputs["item_bias"]
482
+
483
+ # Rating prediction
484
+ concatenated = tf.concat([user_embeddings, item_embeddings], axis=-1)
485
+ rating_predictions = self.rating_model(concatenated, training=training)
486
+
487
+ # Add bias terms
488
+ rating_predictions_with_bias = rating_predictions + user_bias + item_bias
489
+ rating_predictions_with_bias = tf.nn.sigmoid(rating_predictions_with_bias)
490
+
491
+ # Losses
492
+ rating_loss = self.focal_loss(features["rating"], rating_predictions_with_bias)
493
+
494
+ # Adaptive temperature-scaled retrieval loss
495
+ scaled_similarities, temperature = self.temperature_similarity(
496
+ user_embeddings, item_embeddings
497
+ )
498
+ retrieval_loss = tf.keras.losses.binary_crossentropy(
499
+ features["rating"],
500
+ tf.nn.sigmoid(scaled_similarities)
501
+ )
502
+ retrieval_loss = tf.reduce_mean(retrieval_loss)
503
+
504
+ # Enhanced contrastive loss with hard negatives
505
+ batch_size = tf.shape(user_embeddings)[0]
506
+ positive_similarities = tf.reduce_sum(user_embeddings * item_embeddings, axis=1)
507
+
508
+ # Random negative sampling
509
+ shuffled_indices = tf.random.shuffle(tf.range(batch_size))
510
+ negative_item_embeddings = tf.gather(item_embeddings, shuffled_indices)
511
+ negative_similarities = tf.reduce_sum(user_embeddings * negative_item_embeddings, axis=1)
512
+
513
+ # Triplet loss with adaptive margin
514
+ margin = 0.5 / temperature # Adaptive margin based on temperature
515
+ contrastive_loss = tf.reduce_mean(
516
+ tf.maximum(0.0, margin + negative_similarities - positive_similarities)
517
+ )
518
+
519
+ # Combine losses
520
+ total_loss = (
521
+ self.rating_weight * rating_loss +
522
+ self.retrieval_weight * retrieval_loss +
523
+ self.contrastive_weight * contrastive_loss
524
+ )
525
+
526
+ # Add regularization losses from diversity regularizers
527
+ if training:
528
+ regularization_losses = tf.add_n(self.losses) if self.losses else 0.0
529
+ total_loss += self.diversity_weight * regularization_losses
530
+
531
+ return {
532
+ 'total_loss': total_loss,
533
+ 'rating_loss': rating_loss,
534
+ 'retrieval_loss': retrieval_loss,
535
+ 'contrastive_loss': contrastive_loss,
536
+ 'temperature': temperature,
537
+ 'diversity_loss': regularization_losses if training else 0.0
538
+ }
539
+
540
+
541
+ def create_enhanced_model(data_processor,
542
+ embedding_dim=128,
543
+ use_bias=True,
544
+ use_diversity_reg=True):
545
+ """Factory function to create enhanced two-tower model."""
546
+
547
+ # Create enhanced towers
548
+ item_tower = EnhancedItemTower(
549
+ item_vocab_size=len(data_processor.item_vocab),
550
+ category_vocab_size=len(data_processor.category_vocab),
551
+ brand_vocab_size=len(data_processor.brand_vocab),
552
+ embedding_dim=embedding_dim,
553
+ use_bias=use_bias,
554
+ use_diversity_reg=use_diversity_reg
555
+ )
556
+
557
+ user_tower = EnhancedUserTower(
558
+ max_history_length=50,
559
+ embedding_dim=embedding_dim,
560
+ use_bias=use_bias,
561
+ use_diversity_reg=use_diversity_reg
562
+ )
563
+
564
+ # Create enhanced model
565
+ model = EnhancedTwoTowerModel(
566
+ item_tower=item_tower,
567
+ user_tower=user_tower,
568
+ rating_weight=1.0,
569
+ retrieval_weight=0.5,
570
+ contrastive_weight=0.3,
571
+ diversity_weight=0.1
572
+ )
573
+
574
+ return model
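As a sanity check on the focal-loss weighting used in `_focal_loss` above, here is a small NumPy sketch (illustrative probabilities only) showing that a confidently-classified positive contributes far less loss than a hard one, which is the point of focal loss on imbalanced data:

```python
import numpy as np

def focal_loss(y_true, y_pred, alpha=0.25, gamma=2.0, eps=1e-7):
    # Mirrors the TF implementation above, in plain NumPy.
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    alpha_t = y_true * alpha + (1 - y_true) * (1 - alpha)
    p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
    bce = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    return alpha_t * (1 - p_t) ** gamma * bce

# An easy positive (p=0.9) is heavily down-weighted vs a hard positive (p=0.1).
easy = focal_loss(np.array([1.0]), np.array([0.9]))
hard = focal_loss(np.array([1.0]), np.array([0.1]))
assert hard[0] > 100 * easy[0]
```

With `gamma=2.0` the modulating factor `(1 - p_t)^2` shrinks the easy example's loss by roughly two orders of magnitude relative to the hard one.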
src/models/item_tower.py CHANGED
@@ -10,8 +10,8 @@ class ItemTower(tf.keras.Model):
10
  item_vocab_size: int,
11
  category_vocab_size: int,
12
  brand_vocab_size: int,
13
- embedding_dim: int = 64,
14
- hidden_dims: list = [128, 64],
15
  dropout_rate: float = 0.2):
16
  super().__init__()
17
 
 
10
  item_vocab_size: int,
11
  category_vocab_size: int,
12
  brand_vocab_size: int,
13
+ embedding_dim: int = 128, # Output embedding dimension
14
+ hidden_dims: list = [256, 128], # Internal dims can be larger
15
  dropout_rate: float = 0.2):
16
  super().__init__()
17
 
src/models/user_tower.py CHANGED
@@ -8,8 +8,8 @@ class UserTower(tf.keras.Model):
8
 
9
  def __init__(self,
10
  max_history_length: int = 50,
11
- embedding_dim: int = 64,
12
- hidden_dims: list = [128, 64],
13
  dropout_rate: float = 0.2):
14
  super().__init__()
15
 
 
8
 
9
  def __init__(self,
10
  max_history_length: int = 50,
11
+ embedding_dim: int = 128, # Output embedding dimension
12
+ hidden_dims: list = [256, 128], # Internal dims for processing
13
  dropout_rate: float = 0.2):
14
  super().__init__()
15
 
src/preprocessing/data_loader.py CHANGED
@@ -4,6 +4,8 @@ import tensorflow as tf
4
  from typing import Dict, List, Tuple, Optional
5
  from collections import defaultdict
6
  import pickle
 
 
7
 
8
 
9
  class DataProcessor:
@@ -97,54 +99,73 @@ class DataProcessor:
97
  interactions_df: pd.DataFrame,
98
  items_df: pd.DataFrame,
99
  negative_samples_per_positive: int = 4) -> pd.DataFrame:
100
- """Create positive and negative user-item pairs for training."""
101
 
102
- # Get all unique items for negative sampling
103
  all_items = set(self.item_vocab.keys())
 
104
 
105
- # Create positive pairs
106
- positive_pairs = []
107
- for _, row in interactions_df.iterrows():
108
- if row['user_id'] in self.user_vocab and row['product_id'] in self.item_vocab:
109
- positive_pairs.append({
110
- 'user_id': row['user_id'],
111
- 'product_id': row['product_id'],
112
- 'rating': 1.0 # Implicit positive feedback
113
- })
114
-
115
- # Create negative pairs
116
- negative_pairs = []
117
- user_item_interactions = set(
118
- (row['user_id'], row['product_id'])
119
- for _, row in interactions_df.iterrows()
120
- )
121
-
122
- for pos_pair in positive_pairs:
123
- user_id = pos_pair['user_id']
124
- user_interactions = set(
125
- row['product_id'] for _, row in interactions_df.iterrows()
126
- if row['user_id'] == user_id
127
- )
128
 
129
- # Sample negative items
130
- negative_items = all_items - user_interactions
131
  if len(negative_items) >= negative_samples_per_positive:
 
132
  sampled_negatives = np.random.choice(
133
- list(negative_items),
134
- size=negative_samples_per_positive,
135
- replace=False
136
  )
137
 
138
- for neg_item in sampled_negatives:
139
- negative_pairs.append({
140
- 'user_id': user_id,
141
- 'product_id': neg_item,
142
- 'rating': 0.0 # Negative feedback
143
- })
144
 
145
  # Combine positive and negative pairs
146
- all_pairs = positive_pairs + negative_pairs
147
- return pd.DataFrame(all_pairs)
148
 
149
  def save_vocabularies(self, save_path: str = "src/artifacts/"):
150
  """Save vocabularies for later use."""
@@ -176,9 +197,19 @@ class DataProcessor:
176
  print("Vocabularies loaded successfully")
177
 
178
 
179
- def create_tf_dataset(features: Dict[str, np.ndarray], batch_size: int = 256) -> tf.data.Dataset:
180
- """Create TensorFlow dataset from features."""
181
  dataset = tf.data.Dataset.from_tensor_slices(features)
 
 
 
 
 
 
 
182
  dataset = dataset.batch(batch_size)
183
- dataset = dataset.prefetch(tf.data.AUTOTUNE)
 
 
 
184
  return dataset
 
4
  from typing import Dict, List, Tuple, Optional
5
  from collections import defaultdict
6
  import pickle
7
+ from concurrent.futures import ThreadPoolExecutor
8
+ import multiprocessing as mp
9
 
10
 
11
  class DataProcessor:
 
99
  interactions_df: pd.DataFrame,
100
  items_df: pd.DataFrame,
101
  negative_samples_per_positive: int = 4) -> pd.DataFrame:
102
+ """Create positive and negative user-item pairs for training (optimized)."""
103
+
104
+ # Filter valid interactions once
105
+ valid_interactions = interactions_df[
106
+ (interactions_df['user_id'].isin(self.user_vocab)) &
107
+ (interactions_df['product_id'].isin(self.item_vocab))
108
+ ].copy()
109
+
110
+ # Create positive pairs vectorized
111
+ positive_pairs = valid_interactions[['user_id', 'product_id']].copy()
112
+ positive_pairs['rating'] = 1.0
113
+
114
+ # Pre-compute user interactions for faster lookup
115
+ user_items_dict = (
116
+ valid_interactions.groupby('user_id')['product_id']
117
+ .apply(set).to_dict()
118
+ )
119
 
 
120
  all_items = set(self.item_vocab.keys())
122
 
123
+ # Generate negative samples in parallel
124
+ def generate_negatives_for_user(user_data):
125
+ user_id, user_items = user_data
126
+ negative_items = all_items - user_items
 
127
 
 
 
128
  if len(negative_items) >= negative_samples_per_positive:
129
+ neg_items_array = np.array(list(negative_items))
130
  sampled_negatives = np.random.choice(
131
+ neg_items_array,
132
+ size=negative_samples_per_positive * len(user_items),
133
+ replace=len(negative_items) < negative_samples_per_positive * len(user_items)
134
  )
135
 
136
+ # Repeat user_id for each negative sample
137
+ user_ids = np.repeat(user_id, len(sampled_negatives))
138
+ ratings = np.zeros(len(sampled_negatives))
139
+
140
+ return pd.DataFrame({
141
+ 'user_id': user_ids,
142
+ 'product_id': sampled_negatives,
143
+ 'rating': ratings
144
+ })
145
+ return pd.DataFrame(columns=['user_id', 'product_id', 'rating'])
146
+
147
+ # Process in parallel chunks
148
+ chunk_size = max(1, len(user_items_dict) // mp.cpu_count())
149
+ user_chunks = [
150
+ list(user_items_dict.items())[i:i + chunk_size]
151
+ for i in range(0, len(user_items_dict), chunk_size)
152
+ ]
153
+
154
+ negative_dfs = []
155
+ with ThreadPoolExecutor(max_workers=mp.cpu_count()) as executor:
156
+ for chunk in user_chunks:
157
+ chunk_results = list(executor.map(generate_negatives_for_user, chunk))
158
+ negative_dfs.extend(chunk_results)
159
+
160
+ # Combine all negative samples
161
+ if negative_dfs:
162
+ negative_pairs = pd.concat(negative_dfs, ignore_index=True)
163
+ else:
164
+ negative_pairs = pd.DataFrame(columns=['user_id', 'product_id', 'rating'])
165
 
166
  # Combine positive and negative pairs
167
+ all_pairs = pd.concat([positive_pairs, negative_pairs], ignore_index=True)
168
+ return all_pairs
169
 
170
  def save_vocabularies(self, save_path: str = "src/artifacts/"):
171
  """Save vocabularies for later use."""
 
197
  print("Vocabularies loaded successfully")
198
 
199
 
200
+ def create_tf_dataset(features: Dict[str, np.ndarray], batch_size: int = 256, shuffle: bool = True) -> tf.data.Dataset:
201
+ """Create optimized TensorFlow dataset from features for CPU training."""
202
  dataset = tf.data.Dataset.from_tensor_slices(features)
203
+
204
+ if shuffle:
205
+ # Use reasonable buffer size for memory efficiency - handle different feature types
206
+ sample_key = next(iter(features.keys()))
207
+ buffer_size = min(len(features[sample_key]), 10000)
208
+ dataset = dataset.shuffle(buffer_size)
209
+
210
  dataset = dataset.batch(batch_size)
211
+
212
+ # Optimize for CPU with reasonable prefetch
213
+ dataset = dataset.prefetch(2) # Reduced from AUTOTUNE for CPU efficiency
214
+
215
  return dataset
src/preprocessing/optimized_dataset_creator.py ADDED
@@ -0,0 +1,111 @@
1
+ """
2
+ Optimized dataset creation script with performance improvements.
3
+ """
4
+ import time
5
+ import numpy as np
6
+ from src.preprocessing.user_data_preparation import UserDatasetCreator
7
+ from src.preprocessing.data_loader import DataProcessor, create_tf_dataset
8
+
9
+
10
+ def create_optimized_dataset(max_history_length: int = 50,
11
+ batch_size: int = 512,
12
+ negative_samples_per_positive: int = 2,
13
+ use_sample: bool = False,
14
+ sample_size: int = 10000):
15
+ """
16
+ Create dataset with optimized performance settings.
17
+
18
+ Args:
19
+ max_history_length: Maximum user interaction history length
20
+ batch_size: Batch size for TensorFlow dataset
21
+ negative_samples_per_positive: Negative sampling ratio
22
+ use_sample: Whether to use a sample of the data for faster processing
23
+ sample_size: Size of sample if use_sample=True
24
+ """
25
+ print("Starting optimized dataset creation...")
26
+ start_time = time.time()
27
+
28
+ # Initialize with optimized settings
29
+ dataset_creator = UserDatasetCreator(max_history_length=max_history_length)
30
+ data_processor = DataProcessor()
31
+
32
+ # Load data
33
+ print("Loading data...")
34
+ load_start = time.time()
35
+ items_df, users_df, interactions_df = data_processor.load_data()
36
+ print(f"Data loaded in {time.time() - load_start:.2f} seconds")
37
+
38
+ # Optional: Use sample for faster development/testing
39
+ if use_sample:
40
+ print(f"Using sample of {sample_size} interactions for faster processing...")
41
+ sample_interactions = interactions_df.sample(min(sample_size, len(interactions_df)))
42
+ user_ids = set(sample_interactions['user_id'])
43
+ item_ids = set(sample_interactions['product_id'])
44
+
45
+ users_df = users_df[users_df['user_id'].isin(user_ids)]
46
+ items_df = items_df[items_df['product_id'].isin(item_ids)]
47
+ interactions_df = sample_interactions
48
+
49
+ print(f"Sample: {len(items_df)} items, {len(users_df)} users, {len(interactions_df)} interactions")
50
+
51
+ # Load embeddings with caching
52
+ print("Loading item embeddings...")
53
+ embed_start = time.time()
54
+ item_embeddings = dataset_creator.load_item_embeddings()
55
+ print(f"Embeddings loaded in {time.time() - embed_start:.2f} seconds")
56
+
57
+ # Create temporal split
58
+ print("Creating temporal split...")
59
+ split_start = time.time()
60
+ train_interactions, val_interactions = dataset_creator.create_temporal_split(interactions_df)
61
+ print(f"Temporal split created in {time.time() - split_start:.2f} seconds")
62
+
63
+ # Create training dataset with optimizations
64
+ print("Creating optimized training dataset...")
65
+ train_start = time.time()
66
+ training_features = dataset_creator.create_training_dataset(
67
+ train_interactions, items_df, users_df, item_embeddings,
68
+ negative_samples_per_positive=negative_samples_per_positive
69
+ )
70
+ print(f"Training dataset created in {time.time() - train_start:.2f} seconds")
71
+
72
+ # Create TensorFlow dataset optimized for CPU
73
+ print("Creating TensorFlow dataset...")
74
+ tf_start = time.time()
75
+ tf_dataset = create_tf_dataset(training_features, batch_size=batch_size)
76
+ print(f"TensorFlow dataset created in {time.time() - tf_start:.2f} seconds")
77
+
78
+ # Save optimized dataset
79
+ print("Saving dataset...")
80
+ save_start = time.time()
81
+ dataset_creator.save_dataset(training_features, "src/artifacts/")
82
+
83
+ # Save vocabularies for later use
84
+ data_processor.save_vocabularies("src/artifacts/")
85
+ print(f"Dataset saved in {time.time() - save_start:.2f} seconds")
86
+
87
+ total_time = time.time() - start_time
88
+ print(f"\nOptimized dataset creation completed in {total_time:.2f} seconds!")
89
+ print(f"Training samples: {len(training_features['rating'])}")
90
+ print("Memory usage optimized for CPU training")
91
+
92
+ return tf_dataset, training_features
93
+
94
+
95
+ if __name__ == "__main__":
96
+ # Run with optimized settings
97
+ tf_dataset, features = create_optimized_dataset(
98
+ max_history_length=30, # Reduced for speed
99
+ batch_size=512, # Larger batches for CPU efficiency
100
+ negative_samples_per_positive=2, # Reduced sampling ratio
101
+ use_sample=True, # Use sample for development
102
+ sample_size=50000 # Reasonable sample size
103
+ )
104
+
105
+ print("\nDataset creation optimization complete!")
106
+ print("Key optimizations applied:")
107
+ print("- Vectorized DataFrame operations")
108
+ print("- Parallel negative sampling")
109
+ print("- Memory-efficient embedding lookup")
110
+ print("- Optimized TensorFlow dataset pipeline")
111
+ print("- LRU caching for embeddings")
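The `use_sample` branch above keeps the three frames mutually consistent: users and items are filtered down to exactly those referenced by the sampled interactions. A minimal sketch of that invariant on synthetic data (names are illustrative, not the project's schema):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Synthetic stand-ins for the three data frames.
interactions = pd.DataFrame({
    'user_id': rng.integers(0, 100, size=1000),
    'product_id': rng.integers(0, 50, size=1000),
})
users = pd.DataFrame({'user_id': range(100)})
items = pd.DataFrame({'product_id': range(50)})

# Sample interactions first, then restrict users/items to what the sample references.
sample = interactions.sample(200, random_state=0)
users_s = users[users['user_id'].isin(set(sample['user_id']))]
items_s = items[items['product_id'].isin(set(sample['product_id']))]

# Every sampled interaction still resolves to a known user and item.
assert sample['user_id'].isin(users_s['user_id']).all()
assert sample['product_id'].isin(items_s['product_id']).all()
```

Filtering in this order is what prevents dangling IDs when the sampled dataset is later fed through vocabulary building.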
src/preprocessing/user_data_preparation.py CHANGED
@@ -63,7 +63,7 @@ class UserDatasetCreator:
63
  # Use more efficient random generation
64
  num_items = len(items_df['product_id'].unique())
65
  item_ids = items_df['product_id'].unique()
66
- embedding_matrix = np.random.rand(num_items, 64).astype(np.float32)
67
 
68
  dummy_embeddings = dict(zip(item_ids, embedding_matrix))
69
  print(f"Created dummy embeddings for {len(dummy_embeddings)} items")
@@ -72,7 +72,7 @@ class UserDatasetCreator:
72
  def aggregate_user_history_embeddings(self,
73
  user_histories: Dict[int, List[int]],
74
  item_embeddings: Dict[int, np.ndarray],
75
- embedding_dim: int = 64) -> Dict[int, np.ndarray]:
76
  """Aggregate item embeddings for each user's interaction history."""
77
 
78
  user_aggregated_embeddings = {}
@@ -103,9 +103,11 @@ class UserDatasetCreator:
103
 
104
  # Pad or truncate to max_history_length
105
  if len(history_embeddings) < self.max_history_length:
 
106
  padding = np.zeros((self.max_history_length - len(history_embeddings), embedding_dim))
107
- history_embeddings = np.vstack([padding, history_embeddings])
108
  else:
 
109
  history_embeddings = history_embeddings[-self.max_history_length:]
110
 
111
  user_aggregated_embeddings[user_id] = history_embeddings
@@ -368,5 +370,58 @@ def main():
368
  print("User dataset creation completed!")
369
 
370
 
371
  if __name__ == "__main__":
372
  main()
 
63
  # Use more efficient random generation
64
  num_items = len(items_df['product_id'].unique())
65
  item_ids = items_df['product_id'].unique()
66
+ embedding_matrix = np.random.rand(num_items, 128).astype(np.float32) # Updated to 128D
67
 
68
  dummy_embeddings = dict(zip(item_ids, embedding_matrix))
69
  print(f"Created dummy embeddings for {len(dummy_embeddings)} items")
 
72
  def aggregate_user_history_embeddings(self,
73
  user_histories: Dict[int, List[int]],
74
  item_embeddings: Dict[int, np.ndarray],
75
+ embedding_dim: int = 128) -> Dict[int, np.ndarray]: # Updated to 128D
76
  """Aggregate item embeddings for each user's interaction history."""
77
 
78
  user_aggregated_embeddings = {}
 
103
 
104
  # Pad or truncate to max_history_length
105
  if len(history_embeddings) < self.max_history_length:
106
+ # Add padding at the END so real interactions are at the BEGINNING
107
  padding = np.zeros((self.max_history_length - len(history_embeddings), embedding_dim))
108
+ history_embeddings = np.vstack([history_embeddings, padding])
109
  else:
110
+ # Keep most recent interactions
111
  history_embeddings = history_embeddings[-self.max_history_length:]
112
 
113
  user_aggregated_embeddings[user_id] = history_embeddings
 
370
  print("User dataset creation completed!")
371
 
372
 
373
+ def prepare_user_features(users_df: pd.DataFrame,
374
+ user_histories: Dict[int, List[int]],
375
+ item_features: Dict[str, np.ndarray],
376
+ max_history_length: int = 50,
377
+ embedding_dim: int = 128) -> Dict[int, Dict]:
378
+ """Standalone function to prepare user features with categorical demographics."""
379
+
380
+ creator = UserDatasetCreator(max_history_length=max_history_length)
381
+
382
+ # Create dummy item embeddings if not available (for 128D)
383
+ item_embeddings = {}
384
+ unique_items = set()
385
+ for history in user_histories.values():
386
+ unique_items.update(history)
387
+
388
+ # Create random embeddings for items (will be replaced by actual embeddings later)
389
+ for item_vocab_idx in unique_items:
390
+ item_embeddings[item_vocab_idx] = np.random.randn(embedding_dim).astype(np.float32)
391
+
392
+ # Get user aggregated embeddings
393
+ user_aggregated_embeddings = creator.aggregate_user_history_embeddings(
394
+ user_histories, item_embeddings, embedding_dim
395
+ )
396
+
397
+ # Process user features
398
+ user_feature_dict = {}
399
+
400
+ # Categorize income once using percentiles from all users (hoisted out of the loop)
401
+ income_categories = creator.categorize_income(users_df['income'])
402
+
403
+ for user_idx, user_row in users_df.iterrows():
404
+ user_id = user_row['user_id']
405
+
406
+ if user_id not in user_aggregated_embeddings:
407
+ continue
408
+
409
+ # Categorize demographics
410
+ age_cat = creator.categorize_age(user_row['age'])
411
+ gender_cat = 1 if user_row['gender'].lower() == 'male' else 0
412
+ income_cat = income_categories[user_idx]
414
+
415
+ user_feature_dict[user_id] = {
416
+ 'age': age_cat,
417
+ 'gender': gender_cat,
418
+ 'income': income_cat,
419
+ 'item_history_embeddings': user_aggregated_embeddings[user_id]
420
+ }
421
+
422
+ print(f"Prepared features for {len(user_feature_dict)} users with {embedding_dim}D embeddings")
423
+ return user_feature_dict
424
+
425
+
426
  if __name__ == "__main__":
427
  main()
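The padding fix above (pad at the end so real interactions come first, truncate to the most recent entries) can be checked in isolation; this sketch uses made-up lengths and dimensions:

```python
import numpy as np

MAX_LEN, DIM = 5, 4

def pad_history(history_embeddings: np.ndarray) -> np.ndarray:
    if len(history_embeddings) < MAX_LEN:
        # Pad at the END so real interactions stay at the beginning.
        padding = np.zeros((MAX_LEN - len(history_embeddings), DIM))
        return np.vstack([history_embeddings, padding])
    # Keep only the most recent MAX_LEN interactions.
    return history_embeddings[-MAX_LEN:]

short = pad_history(np.ones((2, DIM)))
long = pad_history(np.arange(7)[:, None].repeat(DIM, axis=1).astype(float))

assert short.shape == (MAX_LEN, DIM)
assert short[0, 0] == 1.0 and short[-1, 0] == 0.0   # real rows first, zeros last
assert long.shape == (MAX_LEN, DIM)
assert long[0, 0] == 2.0 and long[-1, 0] == 6.0     # oldest rows dropped
```

Padding at the front (the pre-fix behavior) would instead place zeros where sequence models expect the earliest real interactions.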
src/training/fast_joint_training.py CHANGED
@@ -66,7 +66,7 @@ class FastJointTrainer:
66
  # Build user tower (simplified)
67
  self.user_tower = UserTower(
68
  max_history_length=50,
69
- embedding_dim=64,
70
  hidden_dims=[64], # Simplified architecture
71
  dropout_rate=0.1
72
  )
 
66
  # Build user tower (simplified)
67
  self.user_tower = UserTower(
68
  max_history_length=50,
69
+ embedding_dim=128, # Updated to 128D
70
  hidden_dims=[64], # Simplified architecture
71
  dropout_rate=0.1
72
  )
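The 64→128 bump only works if it is applied consistently to both towers, since the retrieval losses score a user against an item via a dot product over matching dimensions. A minimal NumPy shape check (illustrative batch size):

```python
import numpy as np

rng = np.random.default_rng(0)

EMBEDDING_DIM = 128  # must match between user and item towers
user_emb = rng.standard_normal((8, EMBEDDING_DIM))
item_emb = rng.standard_normal((8, EMBEDDING_DIM))

# Per-pair affinity: row-wise dot product, as in the retrieval/contrastive losses.
scores = (user_emb * item_emb).sum(axis=1)
assert scores.shape == (8,)

# Full user x item similarity matrix, as used for hard-negative mining.
sim_matrix = user_emb @ item_emb.T
assert sim_matrix.shape == (8, 8)
```

A mismatched pair (e.g. a 128-D user tower against a 64-D item tower) fails at the first matmul, which is why the dimension change has to propagate through every file touched here.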
src/training/improved_joint_training.py ADDED
@@ -0,0 +1,462 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Improved joint training with hard negative mining, curriculum learning, and better optimization.
4
+ """
5
+
6
+ import tensorflow as tf
7
+ import numpy as np
8
+ import pickle
9
+ import os
10
+ from typing import Dict, List, Tuple, Optional
11
+ import time
12
+ from collections import defaultdict
13
+
14
+ from src.models.improved_two_tower import create_improved_model
15
+ from src.preprocessing.data_loader import DataProcessor, create_tf_dataset
16
+
17
+
18
+ class HardNegativeSampler:
19
+ """Hard negative sampling strategy for better training."""
20
+
21
+ def __init__(self, model, item_embeddings, sampling_strategy='mixed'):
22
+ self.model = model
23
+ self.item_embeddings = item_embeddings # Pre-computed item embeddings
24
+ self.sampling_strategy = sampling_strategy
25
+
26
+ def sample_hard_negatives(self, user_embeddings, positive_items, k_hard=2, k_random=2):
27
+ """Sample hard negatives based on user-item similarity."""
28
+ batch_size = tf.shape(user_embeddings)[0]
29
+
30
+ # Compute similarities between users and all items
31
+ similarities = tf.linalg.matmul(user_embeddings, self.item_embeddings, transpose_b=True)
32
+
33
+ # Mask out positive items
34
+ positive_mask = tf.one_hot(positive_items, depth=tf.shape(self.item_embeddings)[0])
35
+ similarities = similarities - positive_mask * 1e9 # Large negative value
36
+
37
+ # Get top-k similar items (hard negatives)
38
+ _, hard_negative_indices = tf.nn.top_k(similarities, k=k_hard)
39
+
40
+ # Sample random negatives
41
+ total_items = tf.shape(self.item_embeddings)[0]
42
+ random_negatives = tf.random.uniform(
43
+ shape=[batch_size, k_random],
44
+ minval=0,
45
+ maxval=total_items,
46
+ dtype=tf.int32
47
+ )
48
+
49
+ # Combine hard and random negatives
50
+ if self.sampling_strategy == 'hard':
51
+ return hard_negative_indices
52
+ elif self.sampling_strategy == 'random':
53
+ return random_negatives
54
+ else: # mixed
55
+ return tf.concat([hard_negative_indices, random_negatives], axis=1)
56
+
57
+
58
+ class CurriculumLearningScheduler:
59
+ """Curriculum learning scheduler for progressive difficulty."""
60
+
61
+ def __init__(self, total_epochs, warmup_epochs=10):
62
+ self.total_epochs = total_epochs
63
+ self.warmup_epochs = warmup_epochs
64
+
65
+ def get_difficulty_schedule(self, epoch):
66
+ """Get curriculum parameters for current epoch."""
67
+ if epoch < self.warmup_epochs:
68
+ # Easy phase: more random negatives, higher temperature (softer similarities)
69
+ hard_negative_ratio = 0.2
70
+ temperature = 2.0
71
+ negative_samples = 2
72
+ elif epoch < self.total_epochs * 0.6:
73
+ # Medium phase: balanced negatives
74
+ hard_negative_ratio = 0.5
75
+ temperature = 1.0
76
+ negative_samples = 4
77
+ else:
78
+ # Hard phase: more hard negatives, lower temperature (sharper similarities)
79
+ hard_negative_ratio = 0.8
80
+ temperature = 0.5
81
+ negative_samples = 6
82
+
83
+ return {
84
+ 'hard_negative_ratio': hard_negative_ratio,
85
+ 'temperature': temperature,
86
+ 'negative_samples': negative_samples
87
+ }
88
+
89
+
90
+ class ImprovedJointTrainer:
91
+ """Enhanced joint trainer with advanced techniques."""
92
+
93
+ def __init__(self,
94
+ embedding_dim: int = 128,
95
+ learning_rate: float = 0.001,
96
+ use_mixed_precision: bool = True,
97
+ use_curriculum_learning: bool = True,
98
+ use_hard_negatives: bool = True):
99
+
100
+ self.embedding_dim = embedding_dim
101
+ self.learning_rate = learning_rate
102
+ self.use_mixed_precision = use_mixed_precision
103
+ self.use_curriculum_learning = use_curriculum_learning
104
+ self.use_hard_negatives = use_hard_negatives
105
+
106
+ # Enable mixed precision if requested
107
+ if use_mixed_precision:
108
+ policy = tf.keras.mixed_precision.Policy('mixed_float16')
109
+ tf.keras.mixed_precision.set_global_policy(policy)
110
+
111
+ self.model = None
112
+ self.data_processor = None
113
+ self.curriculum_scheduler = None
114
+ self.hard_negative_sampler = None
115
+
116
+ def setup_model(self, data_processor: DataProcessor):
117
+ """Setup the improved model."""
118
+ self.data_processor = data_processor
119
+
120
+ # Create improved model
121
+ self.model = create_improved_model(
122
+ data_processor=data_processor,
123
+ embedding_dim=self.embedding_dim,
124
+ use_bias=True,
125
+ use_focal_loss=True
126
+ )
127
+
128
+ print(f"Created improved two-tower model with {self.embedding_dim}D embeddings")
129
+
130
+ def setup_curriculum_learning(self, total_epochs: int):
131
+ """Setup curriculum learning scheduler."""
132
+ if self.use_curriculum_learning:
133
+ self.curriculum_scheduler = CurriculumLearningScheduler(
134
+ total_epochs=total_epochs,
135
+ warmup_epochs=max(5, total_epochs // 10)
136
+ )
137
+ print("Curriculum learning enabled")
138
+
139
+ def setup_hard_negative_sampling(self, item_features: Dict[str, np.ndarray]):
140
+ """Setup hard negative sampling."""
141
+ if self.use_hard_negatives:
142
+ # Pre-compute item embeddings for efficient hard negative sampling
143
+ print("Pre-computing item embeddings for hard negative sampling...")
144
+
145
+ # Create a dummy batch to get item embeddings
146
+ batch_size = 1000
147
+ total_items = len(item_features['product_id'])
148
+
149
+ item_embeddings_list = []
150
+ for i in range(0, total_items, batch_size):
151
+ end_idx = min(i + batch_size, total_items)
152
+ batch_features = {
153
+ key: tf.constant(value[i:end_idx])
154
+ for key, value in item_features.items()
155
+ }
156
+
157
+ item_emb_output = self.model.item_tower(batch_features, training=False)
158
+ if isinstance(item_emb_output, tuple):
159
+ item_emb = item_emb_output[0] # Get embeddings, ignore bias
160
+ else:
161
+ item_emb = item_emb_output
162
+
163
+ item_embeddings_list.append(item_emb.numpy())
164
+
165
+ item_embeddings = np.vstack(item_embeddings_list)
166
+
167
+ self.hard_negative_sampler = HardNegativeSampler(
168
+ model=self.model,
169
+ item_embeddings=tf.constant(item_embeddings, dtype=tf.float32),
170
+ sampling_strategy='mixed'
171
+ )
172
+ print(f"Hard negative sampling enabled with {len(item_embeddings)} items")
173
+
174
+ def create_advanced_training_dataset(self,
175
+ features: Dict[str, np.ndarray],
176
+ batch_size: int = 256,
177
+ epoch: int = 0) -> tf.data.Dataset:
178
+ """Create training dataset with curriculum learning and hard negatives."""
179
+
180
+ # Get curriculum parameters
181
+ if self.curriculum_scheduler:
182
+ curriculum_params = self.curriculum_scheduler.get_difficulty_schedule(epoch)
183
+ print(f"Epoch {epoch}: {curriculum_params}")
184
+ else:
185
+ curriculum_params = {
186
+ 'hard_negative_ratio': 0.5,
187
+ 'temperature': 1.0,
188
+ 'negative_samples': 4
189
+ }
190
+
191
+ # Filter data based on curriculum (start with easier examples)
192
+ if epoch < 5: # Warmup epochs - use only high-confidence positive examples
193
+ positive_mask = features['rating'] == 1.0
194
+ if np.sum(positive_mask) > 0:
195
+ # Sample subset of positives and all negatives
196
+ positive_indices = np.where(positive_mask)[0]
197
+                negative_indices = np.where(features['rating'] == 0.0)[0]
+
+                # Sample subset for easier learning
+                n_positive_samples = min(len(positive_indices), len(negative_indices))
+                selected_positive = np.random.choice(
+                    positive_indices, size=n_positive_samples, replace=False
+                )
+                selected_negative = np.random.choice(
+                    negative_indices, size=n_positive_samples, replace=False
+                )
+
+                selected_indices = np.concatenate([selected_positive, selected_negative])
+                np.random.shuffle(selected_indices)
+
+                # Filter features
+                filtered_features = {
+                    key: value[selected_indices] for key, value in features.items()
+                }
+            else:
+                filtered_features = features
+        else:
+            filtered_features = features
+
+        # Create dataset
+        dataset = create_tf_dataset(filtered_features, batch_size, shuffle=True)
+
+        return dataset
+
+    def compile_model(self):
+        """Compile model with advanced optimizer."""
+        # Use AdamW with learning rate scheduling
+        initial_learning_rate = self.learning_rate
+        lr_schedule = tf.keras.optimizers.schedules.CosineDecayRestarts(
+            initial_learning_rate=initial_learning_rate,
+            first_decay_steps=1000,
+            t_mul=2.0,
+            m_mul=0.9,
+            alpha=0.01
+        )
+
+        optimizer = tf.keras.optimizers.AdamW(
+            learning_rate=lr_schedule,
+            weight_decay=1e-5,
+            beta_1=0.9,
+            beta_2=0.999,
+            epsilon=1e-7
+        )
+
+        # Enable mixed precision optimizer if needed
+        if self.use_mixed_precision:
+            optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
+
+        self.optimizer = optimizer
+        print(f"Model compiled with AdamW optimizer (lr={self.learning_rate})")
+
+    @tf.function
+    def train_step(self, features):
+        """Optimized training step with gradient scaling."""
+        with tf.GradientTape() as tape:
+            # Forward pass
+            loss_dict = self.model.compute_loss(features, training=True)
+            total_loss = loss_dict['total_loss']
+
+            # Scale loss for mixed precision
+            if self.use_mixed_precision:
+                scaled_loss = self.optimizer.get_scaled_loss(total_loss)
+            else:
+                scaled_loss = total_loss
+
+        # Compute gradients
+        if self.use_mixed_precision:
+            scaled_gradients = tape.gradient(scaled_loss, self.model.trainable_variables)
+            gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)
+        else:
+            gradients = tape.gradient(scaled_loss, self.model.trainable_variables)
+
+        # Clip gradients to prevent exploding gradients
+        gradients, _ = tf.clip_by_global_norm(gradients, 1.0)
+
+        # Apply gradients
+        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
+
+        return loss_dict
+
+    def evaluate_model(self, validation_dataset):
+        """Evaluate model on validation set."""
+        total_losses = defaultdict(list)
+
+        for batch in validation_dataset:
+            loss_dict = self.model.compute_loss(batch, training=False)
+            for key, value in loss_dict.items():
+                total_losses[key].append(float(value))
+
+        # Average losses
+        avg_losses = {key: np.mean(values) for key, values in total_losses.items()}
+        return avg_losses
+
+    def train(self,
+              training_features: Dict[str, np.ndarray],
+              validation_features: Dict[str, np.ndarray],
+              epochs: int = 50,
+              batch_size: int = 256,
+              save_path: str = "src/artifacts/") -> Dict:
+        """Enhanced training loop with all improvements."""
+
+        print(f"Starting improved training for {epochs} epochs...")
+
+        # Setup components
+        self.setup_curriculum_learning(epochs)
+        self.compile_model()
+
+        # Create validation dataset
+        validation_dataset = create_tf_dataset(validation_features, batch_size, shuffle=False)
+
+        # Training history
+        history = defaultdict(list)
+        best_val_loss = float('inf')
+        patience_counter = 0
+        early_stopping_patience = 10
+
+        # Training loop
+        for epoch in range(epochs):
+            epoch_start_time = time.time()
+
+            # Create training dataset for this epoch (curriculum learning)
+            training_dataset = self.create_advanced_training_dataset(
+                training_features, batch_size, epoch
+            )
+
+            # Training
+            epoch_losses = defaultdict(list)
+            num_batches = 0
+
+            for batch in training_dataset:
+                loss_dict = self.train_step(batch)
+
+                for key, value in loss_dict.items():
+                    epoch_losses[key].append(float(value))
+                num_batches += 1
+
+            # Average training losses
+            avg_train_losses = {
+                key: np.mean(values) for key, values in epoch_losses.items()
+            }
+
+            # Validation
+            avg_val_losses = self.evaluate_model(validation_dataset)
+
+            # Log progress
+            epoch_time = time.time() - epoch_start_time
+            print(f"Epoch {epoch+1}/{epochs} ({epoch_time:.1f}s):")
+            print(f"  Train Loss: {avg_train_losses['total_loss']:.4f}")
+            print(f"  Val Loss: {avg_val_losses['total_loss']:.4f}")
+            print(f"  Val Rating Loss: {avg_val_losses['rating_loss']:.4f}")
+            print(f"  Val Retrieval Loss: {avg_val_losses['retrieval_loss']:.4f}")
+
+            # Save history
+            for key, value in avg_train_losses.items():
+                history[f'train_{key}'].append(value)
+            for key, value in avg_val_losses.items():
+                history[f'val_{key}'].append(value)
+
+            # Early stopping and model saving
+            current_val_loss = avg_val_losses['total_loss']
+            if current_val_loss < best_val_loss:
+                best_val_loss = current_val_loss
+                patience_counter = 0
+
+                # Save best model
+                self.save_model(save_path, suffix='_improved_best')
+                print(f"  💾 Saved best model (val_loss: {best_val_loss:.4f})")
+            else:
+                patience_counter += 1
+
+            if patience_counter >= early_stopping_patience:
+                print(f"Early stopping at epoch {epoch+1}")
+                break
+
+        # Save final model and history
+        self.save_model(save_path, suffix='_improved_final')
+        self.save_training_history(dict(history), save_path)
+
+        print("✅ Improved training completed!")
+        return dict(history)
+
+    def save_model(self, save_path: str, suffix: str = ''):
+        """Save the trained model components."""
+        os.makedirs(save_path, exist_ok=True)
+
+        # Save model weights
+        self.model.item_tower.save_weights(f"{save_path}/improved_item_tower_weights{suffix}")
+        self.model.user_tower.save_weights(f"{save_path}/improved_user_tower_weights{suffix}")
+
+        if hasattr(self.model, 'rating_model'):
+            self.model.rating_model.save_weights(f"{save_path}/improved_rating_model_weights{suffix}")
+
+        # Save configuration
+        config = {
+            'embedding_dim': self.embedding_dim,
+            'learning_rate': self.learning_rate,
+            'use_mixed_precision': self.use_mixed_precision,
+            'use_curriculum_learning': self.use_curriculum_learning,
+            'use_hard_negatives': self.use_hard_negatives
+        }
+
+        with open(f"{save_path}/improved_model_config{suffix}.txt", 'w') as f:
+            for key, value in config.items():
+                f.write(f"{key}: {value}\n")
+
+        print(f"Model saved to {save_path} with suffix '{suffix}'")
+
+    def save_training_history(self, history: Dict, save_path: str):
+        """Save training history."""
+        with open(f"{save_path}/improved_training_history.pkl", 'wb') as f:
+            pickle.dump(history, f)
+        print(f"Training history saved to {save_path}")
+
+
+def main():
+    """Demo of improved training."""
+    print("🚀 IMPROVED TWO-TOWER TRAINING DEMO")
+    print("="*60)
+
+    # Load data
+    print("Loading training data...")
+    try:
+        with open("src/artifacts/training_features.pkl", 'rb') as f:
+            training_features = pickle.load(f)
+        with open("src/artifacts/validation_features.pkl", 'rb') as f:
+            validation_features = pickle.load(f)
+
+        print(f"Loaded {len(training_features['rating'])} training samples")
+        print(f"Loaded {len(validation_features['rating'])} validation samples")
+    except FileNotFoundError:
+        print("❌ Training data not found. Please run data preparation first.")
+        return
+
+    # Load data processor
+    data_processor = DataProcessor()
+    data_processor.load_vocabularies("src/artifacts/vocabularies.pkl")
+
+    # Create trainer
+    trainer = ImprovedJointTrainer(
+        embedding_dim=128,
+        learning_rate=0.001,
+        use_mixed_precision=True,
+        use_curriculum_learning=True,
+        use_hard_negatives=True
+    )
+
+    # Setup and train
+    trainer.setup_model(data_processor)
+
+    # Train model
+    history = trainer.train(
+        training_features=training_features,
+        validation_features=validation_features,
+        epochs=30,
+        batch_size=256
+    )
+
+    print("✅ Improved training completed successfully!")
+
+
+if __name__ == "__main__":
+    main()
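The balanced positive/negative sampling used by `create_advanced_training_dataset` can be sketched in plain NumPy. This is an illustrative standalone version, not code from the repo; the function name `balanced_sample` and the seeded generator are assumptions for the sketch:

```python
import numpy as np

def balanced_sample(ratings: np.ndarray, seed: int = 0) -> np.ndarray:
    """Return shuffled row indices with equal counts of positive and negative examples."""
    rng = np.random.default_rng(seed)
    positive = np.where(ratings > 0.0)[0]
    negative = np.where(ratings == 0.0)[0]
    # Cap both classes at the size of the smaller one, as the trainer does.
    n = min(len(positive), len(negative))
    chosen = np.concatenate([
        rng.choice(positive, size=n, replace=False),
        rng.choice(negative, size=n, replace=False),
    ])
    rng.shuffle(chosen)
    return chosen

ratings = np.array([1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0])
idx = balanced_sample(ratings)
print(len(idx))            # 6: three positives matched with three negatives
print(ratings[idx].sum())  # 3.0: exactly half of the sampled rows are positive
```

In the trainer itself the same index array is then used to slice every feature array in the batch dictionary before building the `tf.data.Dataset`.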
src/training/item_pretraining.py CHANGED
@@ -12,8 +12,8 @@ class ItemTowerPretrainer:
     """Handles pre-training of the item tower."""
 
     def __init__(self,
-                 embedding_dim: int = 64,
-                 hidden_dims: List[int] = [128, 64],
+                 embedding_dim: int = 128,  # Updated to 128D output
+                 hidden_dims: List[int] = [256, 128],  # Scaled up
                  dropout_rate: float = 0.2,
                  learning_rate: float = 0.001):
@@ -111,19 +111,26 @@ class ItemTowerPretrainer:
         return history
 
     def generate_item_embeddings(self,
-                                 dataset: tf.data.Dataset) -> Dict[int, np.ndarray]:
+                                 dataset: tf.data.Dataset,
+                                 data_processor: 'DataProcessor') -> Dict[int, np.ndarray]:
         """Generate embeddings for all items in the catalog."""
 
         item_embeddings = {}
 
+        # Create reverse mapping from vocab indices to actual item IDs
+        idx_to_item_id = {idx: item_id for item_id, idx in data_processor.item_vocab.items()}
+
         for batch in dataset:
             embeddings = self.item_tower(batch)
-            product_ids = batch['product_id'].numpy()
+            product_idx_batch = batch['product_id'].numpy()
 
-            for i, product_id in enumerate(product_ids):
-                item_embeddings[product_id] = embeddings[i].numpy()
+            for i, product_idx in enumerate(product_idx_batch):
+                # Convert vocab index back to actual item ID
+                actual_item_id = idx_to_item_id.get(product_idx, product_idx)
+                item_embeddings[actual_item_id] = embeddings[i].numpy()
 
         print(f"Generated embeddings for {len(item_embeddings)} items")
+        print(f"Sample item IDs: {list(item_embeddings.keys())[:5]}")
         return item_embeddings
 
     def save_model(self, save_path: str = "src/artifacts/"):
@@ -185,7 +192,7 @@ def main():
     # Initialize components
     data_processor = DataProcessor()
     pretrainer = ItemTowerPretrainer(
-        embedding_dim=64,
+        embedding_dim=128,  # Updated to 128D
         hidden_dims=[128, 64],
         dropout_rate=0.2,
         learning_rate=0.001
@@ -210,7 +217,7 @@ def main():
 
     # Generate embeddings
     print("Generating item embeddings...")
-    item_embeddings = pretrainer.generate_item_embeddings(dataset)
+    item_embeddings = pretrainer.generate_item_embeddings(dataset, data_processor)
 
     # Save everything
     print("Saving artifacts...")
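The fix in `generate_item_embeddings` hinges on inverting the vocabulary once so that exported embeddings are keyed by real catalog IDs instead of the dense indices the embedding layer uses. A minimal sketch of that mapping (the vocab entries below are made-up placeholders):

```python
# Vocabulary maps real catalog IDs to the dense indices used by the embedding layer.
item_vocab = {"B00123": 0, "B00456": 1, "B00789": 2}

# Invert it once so lookups during embedding export are O(1) per item.
idx_to_item_id = {idx: item_id for item_id, idx in item_vocab.items()}

# A batch of model outputs arrives keyed by vocab index ...
batch_indices = [2, 0]
# ... but downstream consumers (serving, analysis) want catalog IDs.
keys = [idx_to_item_id.get(i, i) for i in batch_indices]
print(keys)  # ['B00789', 'B00123']
```

The `.get(i, i)` fallback mirrors the diff: an index missing from the vocab falls through as-is rather than raising mid-export.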
src/training/joint_training.py CHANGED
@@ -13,7 +13,7 @@ class JointTrainer:
     """Handles joint training of user and item towers."""
 
     def __init__(self,
-                 embedding_dim: int = 64,
+                 embedding_dim: int = 128,  # Updated to 128D output
                  user_learning_rate: float = 0.001,
                  item_learning_rate: float = 0.0001,  # Lower LR for pre-trained item tower
                  rating_weight: float = 1.0,
@@ -330,7 +330,7 @@ def main():
 
     # Initialize trainer
     trainer = JointTrainer(
-        embedding_dim=64,
+        embedding_dim=128,  # Updated to 128D
         user_learning_rate=0.001,
         item_learning_rate=0.0001,
         rating_weight=1.0,
src/training/optimized_joint_training.py CHANGED
@@ -14,7 +14,7 @@ class OptimizedJointTrainer:
     """Optimized joint training with performance enhancements."""
 
     def __init__(self,
-                 embedding_dim: int = 64,
+                 embedding_dim: int = 128,  # Updated to 128D output
                  user_learning_rate: float = 0.001,
                  item_learning_rate: float = 0.0001,
                  rating_weight: float = 1.0,
@@ -381,7 +381,7 @@ def main():
 
     print("Initializing optimized joint trainer...")
    trainer = OptimizedJointTrainer(
-        embedding_dim=64,
+        embedding_dim=128,  # Updated to 128D
         user_learning_rate=0.002,  # Slightly higher for faster convergence
         item_learning_rate=0.0002,
         rating_weight=1.0,
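All three trainers now emit 128-dimensional towers. At serving time, two-tower retrieval reduces to a dot product between one user embedding and the item embedding matrix; a NumPy sketch of that scoring step (only the embedding dimension comes from this diff, the rest is illustrative):

```python
import numpy as np

rng = np.random.default_rng(42)
embedding_dim = 128                      # matches the updated towers
n_items = 1000

item_matrix = rng.normal(size=(n_items, embedding_dim)).astype(np.float32)
user_vec = rng.normal(size=(embedding_dim,)).astype(np.float32)

# L2-normalize both sides so the dot product is cosine similarity.
item_matrix /= np.linalg.norm(item_matrix, axis=1, keepdims=True)
user_vec /= np.linalg.norm(user_vec)

scores = item_matrix @ user_vec          # shape (n_items,)
top_k = np.argsort(-scores)[:10]         # indices of the 10 best-scoring items
print(scores.shape, len(top_k))
```

Doubling the dimension from 64 to 128 doubles this matrix's memory and per-query FLOPs, which is the serving-side cost of the added model capacity.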