Major codebase cleanup and feature additions
New Features:
- Add dual API implementations (api_2phase.py, api_joint.py)
- Add 2-phase and joint training scripts
- Add enhanced recommendation engine with 128D embeddings
- Add improved joint training with curriculum learning
- Add enhanced two-tower model architecture
- Add optimized dataset creator
Enhancements:
- Enhance main API with new endpoints and features
- Improve frontend UI with advanced styling and components
- Upgrade inference engines with better recommendation logic
- Enhance data preprocessing with categorical demographics
- Improve training modules with better optimization
Cleanup:
- Remove debug, analysis, and documentation files
- Update gitignore with ML-specific patterns
- Remove backup directories and temporary files
- Streamline codebase structure
Files: 25 changed, 5843 insertions, 829 deletions
- .gitignore +29 -1
- CATEGORICAL_DEMOGRAPHICS_SUMMARY.md +0 -113
- analyze_recommendation_quality.py +0 -556
- analyze_recommendations.py +543 -0
- api/main.py +171 -5
- api_2phase.py +521 -0
- api_joint.py +522 -0
- frontend/src/App.css +1030 -15
- frontend/src/App.js +505 -46
- run_2phase_training.py +206 -0
- run_joint_training.py +453 -0
- src/inference/enhanced_recommendation_engine_128d.py +499 -0
- src/inference/faiss_index.py +17 -13
- src/inference/recommendation_engine.py +46 -19
- src/models/enhanced_two_tower.py +574 -0
- src/models/item_tower.py +2 -2
- src/models/user_tower.py +2 -2
- src/preprocessing/data_loader.py +72 -41
- src/preprocessing/optimized_dataset_creator.py +111 -0
- src/preprocessing/user_data_preparation.py +58 -3
- src/training/fast_joint_training.py +1 -1
- src/training/improved_joint_training.py +462 -0
- src/training/item_pretraining.py +15 -8
- src/training/joint_training.py +2 -2
- src/training/optimized_joint_training.py +2 -2
.gitignore

```diff
@@ -193,4 +193,32 @@ data/
 .Spotlight-V100
 .Trashes
 ehthumbs.db
-Thumbs.db
+Thumbs.db
+
+# Analysis and debug files
+debug_*.py
+test_*.py
+analyze_enhanced_*.py
+analyze_recommendation_quality.py
+*_analysis_*
+*_report.*
+simple_*.py
+
+# Keep analyze_recommendations.py - it's wanted
+!analyze_recommendations.py
+
+# Backup directories
+*_backup/
+*_bak/
+artifacts_backup/
+
+# Temporary files
+*.tmp
+*.temp
+*.cache
+
+# Configuration files with sensitive data
+config.json
+secrets.json
+.secret
+.key
```
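Git evaluates these patterns in file order, with the last matching pattern winning, which is why the `!analyze_recommendations.py` negation keeps that script tracked even though neighboring `analyze_*` patterns are ignored. A rough sketch of that rule (simplified: it ignores anchoring, directory semantics, and `**`, so it is only a model of Git's real matcher):

```python
from fnmatch import fnmatch

# Patterns added in the hunk above, in file order; "!" re-includes.
PATTERNS = [
    "debug_*.py", "test_*.py", "analyze_enhanced_*.py",
    "analyze_recommendation_quality.py", "*_analysis_*", "*_report.*",
    "simple_*.py", "!analyze_recommendations.py",
    "*_backup/", "*_bak/", "artifacts_backup/",
    "*.tmp", "*.temp", "*.cache",
    "config.json", "secrets.json", ".secret", ".key",
]

def is_ignored(name: str) -> bool:
    """Simplified gitignore check: the last matching pattern wins."""
    ignored = False
    for pat in PATTERNS:
        negated = pat.startswith("!")
        body = pat.lstrip("!").rstrip("/")  # drop "!" and dir-only "/"
        if fnmatch(name, body):
            ignored = not negated
    return ignored
```

`git check-ignore -v <path>` reports which pattern actually decided a path, which is the authoritative way to verify these rules in the repo.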
@@ -1,113 +0,0 @@
|
|
| 1 |
-
# Categorical Demographics Implementation Summary
|
| 2 |
-
|
| 3 |
-
## ✅ **IMPLEMENTATION COMPLETE**
|
| 4 |
-
|
| 5 |
-
Successfully converted age and income from continuous normalized features to categorical embeddings, achieving the goal of reducing demographics to 25% of total input dimensions.
|
| 6 |
-
|
| 7 |
-
---
|
| 8 |
-
|
| 9 |
-
## 🎯 **Key Changes Made**
|
| 10 |
-
|
| 11 |
-
### **1. Age Categorization (6 Categories)**
|
| 12 |
-
- **Teen (0)**: Under 18
|
| 13 |
-
- **Young Adult (1)**: 18-25
|
| 14 |
-
- **Adult (2)**: 26-35
|
| 15 |
-
- **Middle Age (3)**: 36-50
|
| 16 |
-
- **Mature (4)**: 51-65
|
| 17 |
-
- **Senior (5)**: 65+
|
| 18 |
-
|
| 19 |
-
### **2. Income Categorization (5 Categories)**
|
| 20 |
-
- **Low Income (0)**: Bottom 20% (≤$56,276)
|
| 21 |
-
- **Lower Middle (1)**: 20-40% ($56,276-$69,236)
|
| 22 |
-
- **Middle (2)**: 40-60% ($69,236-$80,661)
|
| 23 |
-
- **Upper Middle (3)**: 60-80% ($80,661-$94,284)
|
| 24 |
-
- **High Income (4)**: Top 20% (≥$94,284)
|
| 25 |
-
|
| 26 |
-
### **3. Embedding Dimensions**
|
| 27 |
-
**Original Tower (64D):**
|
| 28 |
-
- Age: 4D, Income: 4D, Gender: 4D
|
| 29 |
-
- **Total Demographics**: 12D (18.8% of input)
|
| 30 |
-
|
| 31 |
-
**Improved Tower (128D):**
|
| 32 |
-
- Age: 8D, Income: 8D, Gender: 8D
|
| 33 |
-
- **Total Demographics**: 24D (18.8% of input)
|
| 34 |
-
|
| 35 |
-
---
|
| 36 |
-
|
| 37 |
-
## 📁 **Files Modified**
|
| 38 |
-
|
| 39 |
-
### **Data Preparation**
|
| 40 |
-
- `src/preprocessing/user_data_preparation.py`
|
| 41 |
-
- Added `categorize_age()` and `categorize_income()` functions
|
| 42 |
-
- Updated `prepare_user_features()` to output categorical features
|
| 43 |
-
|
| 44 |
-
### **Model Architecture**
|
| 45 |
-
- `src/models/user_tower.py`
|
| 46 |
-
- Replaced normalization layers with embedding layers
|
| 47 |
-
- Updated forward pass for categorical inputs
|
| 48 |
-
|
| 49 |
-
- `src/models/improved_two_tower.py`
|
| 50 |
-
- Same embedding updates as original tower
|
| 51 |
-
- Maintained sophisticated history aggregation
|
| 52 |
-
|
| 53 |
-
### **Training Scripts**
|
| 54 |
-
- `src/training/optimized_joint_training.py`
|
| 55 |
-
- `src/training/joint_training.py`
|
| 56 |
-
- `src/training/fast_joint_training.py`
|
| 57 |
-
- Removed normalization adaptation calls
|
| 58 |
-
|
| 59 |
-
### **Inference Engine**
|
| 60 |
-
- `src/inference/recommendation_engine.py`
|
| 61 |
-
- Added categorization functions for real-time inference
|
| 62 |
-
- Updated `prepare_user_features()` to categorize raw inputs
|
| 63 |
-
- Added income threshold loading from training data
|
| 64 |
-
|
| 65 |
-
---
|
| 66 |
-
|
| 67 |
-
## 🔍 **Verification Results**
|
| 68 |
-
|
| 69 |
-
✅ **All Tests Pass:**
|
| 70 |
-
- Age categorization: 6 categories (0-5) ✅
|
| 71 |
-
- Income categorization: 5 categories (0-4) ✅
|
| 72 |
-
- Training features: Correct int32 dtypes ✅
|
| 73 |
-
- User towers: Proper embedding dimensions ✅
|
| 74 |
-
- Inference engine: Successful categorical conversion ✅
|
| 75 |
-
- Recommendation engines: Working with categorical inputs ✅
|
| 76 |
-
|
| 77 |
-
---
|
| 78 |
-
|
| 79 |
-
## 📊 **Benefits Achieved**
|
| 80 |
-
|
| 81 |
-
### **1. Balanced Feature Representation**
|
| 82 |
-
- **Before**: Demographics 75% (96D), History 25% (32D)
|
| 83 |
-
- **After**: Demographics 19% (24D), History 81% (104D)
|
| 84 |
-
|
| 85 |
-
### **2. Better Learning Patterns**
|
| 86 |
-
- **Interpretable segments**: Clear demographic groups vs continuous values
|
| 87 |
-
- **Non-linear relationships**: Each category learns distinct behaviors
|
| 88 |
-
- **Reduced bias**: Less dependence on exact age/income values
|
| 89 |
-
- **Better generalization**: Discrete categories vs continuous normalization
|
| 90 |
-
|
| 91 |
-
### **3. Improved Model Architecture**
|
| 92 |
-
- **Smaller demographics footprint**: More capacity for behavioral signals
|
| 93 |
-
- **Category-specific patterns**: Age/income groups with unique preferences
|
| 94 |
-
- **Embedding benefits**: Learned representations vs fixed normalization
|
| 95 |
-
|
| 96 |
-
---
|
| 97 |
-
|
| 98 |
-
## 🚀 **Ready for Training**
|
| 99 |
-
|
| 100 |
-
The categorical demographics implementation is complete and verified. The system now:
|
| 101 |
-
|
| 102 |
-
1. **Prioritizes behavioral signals** (81%) over demographics (19%)
|
| 103 |
-
2. **Uses interpretable demographic segments** instead of continuous values
|
| 104 |
-
3. **Maintains all existing functionality** with enhanced representation
|
| 105 |
-
4. **Is ready for improved model training** with better feature balance
|
| 106 |
-
|
| 107 |
-
To train the improved model with categorical demographics:
|
| 108 |
-
|
| 109 |
-
```bash
|
| 110 |
-
python train_improved_model.py --embedding-dim 128 --epochs-per-stage 15
|
| 111 |
-
```
|
| 112 |
-
|
| 113 |
-
The enhanced recommendation system should now achieve better personalization through balanced feature representation and categorical demographic learning.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
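The bucketing that the deleted summary describes can be sketched as below. This is a hypothetical reconstruction from the documented boundaries; the real `categorize_age()` and `categorize_income()` live in `src/preprocessing/user_data_preparation.py`, and their exact edge handling (inclusive vs. exclusive bounds, threshold loading from training data) may differ:

```python
def categorize_age(age: float) -> int:
    """Map raw age to one of 6 buckets (0=Teen ... 5=Senior), per the summary."""
    if age < 18:
        return 0  # Teen
    if age <= 25:
        return 1  # Young Adult
    if age <= 35:
        return 2  # Adult
    if age <= 50:
        return 3  # Middle Age
    if age <= 65:
        return 4  # Mature
    return 5      # Senior

# Quintile thresholds quoted in the summary; in the repo these are
# derived from the training data, not hard-coded.
INCOME_THRESHOLDS = [56_276, 69_236, 80_661, 94_284]

def categorize_income(income: float) -> int:
    """Map raw income to one of 5 quintile buckets (0=Low ... 4=High)."""
    for bucket, threshold in enumerate(INCOME_THRESHOLDS):
        if income <= threshold:
            return bucket
    return 4  # High Income
```

With these buckets each demographic becomes an integer index into an `nn.Embedding` (8D per field in the 128D tower), rather than a normalized float.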
@@ -1,556 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""
|
| 3 |
-
Comprehensive analysis of recommendation quality from the two-tower model.
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
import sys
|
| 7 |
-
import os
|
| 8 |
-
import numpy as np
|
| 9 |
-
import pandas as pd
|
| 10 |
-
from collections import Counter, defaultdict
|
| 11 |
-
import time
|
| 12 |
-
|
| 13 |
-
sys.path.append('/home/user/Desktop/RecSys-HP')
|
| 14 |
-
from src.inference.recommendation_engine import RecommendationEngine
|
| 15 |
-
from src.utils.real_user_selector import RealUserSelector
|
| 16 |
-
|
| 17 |
-
def analyze_score_distribution():
|
| 18 |
-
"""Analyze the distribution of recommendation scores."""
|
| 19 |
-
|
| 20 |
-
print("📊 SCORE DISTRIBUTION ANALYSIS")
|
| 21 |
-
print("="*50)
|
| 22 |
-
|
| 23 |
-
try:
|
| 24 |
-
engine = RecommendationEngine()
|
| 25 |
-
real_user_selector = RealUserSelector()
|
| 26 |
-
|
| 27 |
-
# Get multiple users for comprehensive analysis
|
| 28 |
-
test_users = real_user_selector.get_real_users(n=10, min_interactions=10)
|
| 29 |
-
|
| 30 |
-
all_scores = {
|
| 31 |
-
'collaborative': [],
|
| 32 |
-
'hybrid': [],
|
| 33 |
-
'content': []
|
| 34 |
-
}
|
| 35 |
-
|
| 36 |
-
print(f"Testing with {len(test_users)} users...")
|
| 37 |
-
|
| 38 |
-
for i, user in enumerate(test_users):
|
| 39 |
-
print(f"\nUser {i+1}/10 - {user['user_id']} ({user['age']}yr {user['gender']}):")
|
| 40 |
-
|
| 41 |
-
# Test collaborative filtering
|
| 42 |
-
try:
|
| 43 |
-
collab_recs = engine.recommend_items_collaborative(
|
| 44 |
-
age=user['age'],
|
| 45 |
-
gender=user['gender'],
|
| 46 |
-
income=user['income'],
|
| 47 |
-
interaction_history=user['interaction_history'][:20],
|
| 48 |
-
k=20
|
| 49 |
-
)
|
| 50 |
-
collab_scores = [score for _, score, _ in collab_recs]
|
| 51 |
-
all_scores['collaborative'].extend(collab_scores)
|
| 52 |
-
|
| 53 |
-
print(f" Collaborative: {min(collab_scores):.4f} - {max(collab_scores):.4f} (std: {np.std(collab_scores):.4f})")
|
| 54 |
-
|
| 55 |
-
except Exception as e:
|
| 56 |
-
print(f" Collaborative failed: {e}")
|
| 57 |
-
|
| 58 |
-
# Test hybrid
|
| 59 |
-
try:
|
| 60 |
-
hybrid_recs = engine.recommend_items_hybrid(
|
| 61 |
-
age=user['age'],
|
| 62 |
-
gender=user['gender'],
|
| 63 |
-
income=user['income'],
|
| 64 |
-
interaction_history=user['interaction_history'][:20],
|
| 65 |
-
k=20,
|
| 66 |
-
collaborative_weight=0.7
|
| 67 |
-
)
|
| 68 |
-
hybrid_scores = [score for _, score, _ in hybrid_recs]
|
| 69 |
-
all_scores['hybrid'].extend(hybrid_scores)
|
| 70 |
-
|
| 71 |
-
print(f" Hybrid: {min(hybrid_scores):.4f} - {max(hybrid_scores):.4f} (std: {np.std(hybrid_scores):.4f})")
|
| 72 |
-
|
| 73 |
-
except Exception as e:
|
| 74 |
-
print(f" Hybrid failed: {e}")
|
| 75 |
-
|
| 76 |
-
# Test content-based (if user has history)
|
| 77 |
-
if user['interaction_history']:
|
| 78 |
-
try:
|
| 79 |
-
content_recs = engine.recommend_items_content_based(
|
| 80 |
-
seed_item_id=user['interaction_history'][0],
|
| 81 |
-
k=20
|
| 82 |
-
)
|
| 83 |
-
content_scores = [score for _, score, _ in content_recs]
|
| 84 |
-
all_scores['content'].extend(content_scores)
|
| 85 |
-
|
| 86 |
-
print(f" Content: {min(content_scores):.4f} - {max(content_scores):.4f} (std: {np.std(content_scores):.4f})")
|
| 87 |
-
|
| 88 |
-
except Exception as e:
|
| 89 |
-
print(f" Content failed: {e}")
|
| 90 |
-
|
| 91 |
-
# Overall score analysis
|
| 92 |
-
print(f"\n📈 OVERALL SCORE STATISTICS:")
|
| 93 |
-
for method, scores in all_scores.items():
|
| 94 |
-
if scores:
|
| 95 |
-
print(f"\n{method.upper()}:")
|
| 96 |
-
print(f" Total scores: {len(scores)}")
|
| 97 |
-
print(f" Range: {min(scores):.4f} - {max(scores):.4f}")
|
| 98 |
-
print(f" Mean: {np.mean(scores):.4f}")
|
| 99 |
-
print(f" Std: {np.std(scores):.4f}")
|
| 100 |
-
print(f" Variance: {np.var(scores):.6f}")
|
| 101 |
-
|
| 102 |
-
# Score distribution percentiles
|
| 103 |
-
percentiles = [10, 25, 50, 75, 90]
|
| 104 |
-
perc_values = np.percentile(scores, percentiles)
|
| 105 |
-
print(f" Percentiles: {dict(zip(percentiles, perc_values))}")
|
| 106 |
-
|
| 107 |
-
# Quality assessment
|
| 108 |
-
score_range = max(scores) - min(scores)
|
| 109 |
-
if score_range < 0.1:
|
| 110 |
-
print(f" ⚠️ WARNING: Low score range ({score_range:.4f}) - poor discrimination")
|
| 111 |
-
elif score_range < 0.3:
|
| 112 |
-
print(f" ⚠️ CAUTION: Moderate score range ({score_range:.4f})")
|
| 113 |
-
else:
|
| 114 |
-
print(f" ✅ GOOD: Wide score range ({score_range:.4f})")
|
| 115 |
-
|
| 116 |
-
if np.var(scores) < 0.001:
|
| 117 |
-
print(f" ⚠️ WARNING: Very low variance - poor ranking ability")
|
| 118 |
-
elif np.var(scores) < 0.01:
|
| 119 |
-
print(f" ⚠️ CAUTION: Low variance")
|
| 120 |
-
else:
|
| 121 |
-
print(f" ✅ GOOD: Adequate variance for ranking")
|
| 122 |
-
|
| 123 |
-
return all_scores
|
| 124 |
-
|
| 125 |
-
except Exception as e:
|
| 126 |
-
print(f"❌ Score analysis failed: {e}")
|
| 127 |
-
return None
|
| 128 |
-
|
| 129 |
-
def analyze_category_alignment():
|
| 130 |
-
"""Analyze how well recommendations align with user category preferences."""
|
| 131 |
-
|
| 132 |
-
print(f"\n🎯 CATEGORY ALIGNMENT ANALYSIS")
|
| 133 |
-
print("="*40)
|
| 134 |
-
|
| 135 |
-
try:
|
| 136 |
-
engine = RecommendationEngine()
|
| 137 |
-
real_user_selector = RealUserSelector()
|
| 138 |
-
|
| 139 |
-
test_users = real_user_selector.get_real_users(n=5, min_interactions=15)
|
| 140 |
-
|
| 141 |
-
alignment_results = []
|
| 142 |
-
|
| 143 |
-
for user in test_users:
|
| 144 |
-
print(f"\nUser {user['user_id']} ({user['age']}yr {user['gender']}):")
|
| 145 |
-
|
| 146 |
-
# Get user's detailed interactions
|
| 147 |
-
user_details = real_user_selector.get_user_interaction_details(user['user_id'])
|
| 148 |
-
|
| 149 |
-
# Analyze user's category preferences
|
| 150 |
-
user_categories = []
|
| 151 |
-
for interaction in user_details['timeline']:
|
| 152 |
-
category = interaction.get('category_code', 'Unknown')
|
| 153 |
-
user_categories.append(category)
|
| 154 |
-
|
| 155 |
-
user_category_counts = Counter(user_categories)
|
| 156 |
-
total_user_interactions = len(user_categories)
|
| 157 |
-
|
| 158 |
-
print(f" User's top categories:")
|
| 159 |
-
for category, count in user_category_counts.most_common(3):
|
| 160 |
-
percentage = (count / total_user_interactions) * 100
|
| 161 |
-
print(f" {category}: {count} ({percentage:.1f}%)")
|
| 162 |
-
|
| 163 |
-
# Get recommendations
|
| 164 |
-
recs = engine.recommend_items_hybrid(
|
| 165 |
-
age=user['age'],
|
| 166 |
-
gender=user['gender'],
|
| 167 |
-
income=user['income'],
|
| 168 |
-
interaction_history=user['interaction_history'][:20],
|
| 169 |
-
k=20,
|
| 170 |
-
collaborative_weight=0.7
|
| 171 |
-
)
|
| 172 |
-
|
| 173 |
-
# Analyze recommendation categories
|
| 174 |
-
rec_categories = []
|
| 175 |
-
for _, _, item_info in recs:
|
| 176 |
-
category = item_info.get('category_code', 'Unknown')
|
| 177 |
-
rec_categories.append(category)
|
| 178 |
-
|
| 179 |
-
rec_category_counts = Counter(rec_categories)
|
| 180 |
-
|
| 181 |
-
print(f" Recommendation categories:")
|
| 182 |
-
for category, count in rec_category_counts.most_common(3):
|
| 183 |
-
percentage = (count / len(rec_categories)) * 100
|
| 184 |
-
match = "✅" if category in user_category_counts else "🆕"
|
| 185 |
-
print(f" {category}: {count} ({percentage:.1f}%) {match}")
|
| 186 |
-
|
| 187 |
-
# Calculate alignment metrics
|
| 188 |
-
user_cats = set(user_category_counts.keys())
|
| 189 |
-
rec_cats = set(rec_category_counts.keys())
|
| 190 |
-
|
| 191 |
-
intersection = user_cats & rec_cats
|
| 192 |
-
alignment_percentage = len(intersection) / len(rec_cats) * 100 if rec_cats else 0
|
| 193 |
-
|
| 194 |
-
# Calculate weighted alignment (by user preference strength)
|
| 195 |
-
weighted_alignment = 0
|
| 196 |
-
for category in intersection:
|
| 197 |
-
user_weight = user_category_counts[category] / total_user_interactions
|
| 198 |
-
rec_weight = rec_category_counts[category] / len(rec_categories)
|
| 199 |
-
weighted_alignment += min(user_weight, rec_weight)
|
| 200 |
-
|
| 201 |
-
alignment_results.append({
|
| 202 |
-
'user_id': user['user_id'],
|
| 203 |
-
'alignment_percentage': alignment_percentage,
|
| 204 |
-
'weighted_alignment': weighted_alignment * 100,
|
| 205 |
-
'user_categories': len(user_cats),
|
| 206 |
-
'rec_categories': len(rec_cats),
|
| 207 |
-
'matched_categories': len(intersection)
|
| 208 |
-
})
|
| 209 |
-
|
| 210 |
-
print(f" Alignment: {alignment_percentage:.1f}% ({len(intersection)}/{len(rec_cats)} categories)")
|
| 211 |
-
print(f" Weighted alignment: {weighted_alignment * 100:.1f}%")
|
| 212 |
-
|
| 213 |
-
# Overall alignment analysis
|
| 214 |
-
print(f"\n📊 OVERALL ALIGNMENT STATISTICS:")
|
| 215 |
-
avg_alignment = np.mean([r['alignment_percentage'] for r in alignment_results])
|
| 216 |
-
avg_weighted = np.mean([r['weighted_alignment'] for r in alignment_results])
|
| 217 |
-
avg_user_cats = np.mean([r['user_categories'] for r in alignment_results])
|
| 218 |
-
avg_rec_cats = np.mean([r['rec_categories'] for r in alignment_results])
|
| 219 |
-
|
| 220 |
-
print(f" Average alignment: {avg_alignment:.1f}%")
|
| 221 |
-
print(f" Average weighted alignment: {avg_weighted:.1f}%")
|
| 222 |
-
print(f" Average user categories: {avg_user_cats:.1f}")
|
| 223 |
-
print(f" Average rec categories: {avg_rec_cats:.1f}")
|
| 224 |
-
|
| 225 |
-
# Quality assessment
|
| 226 |
-
if avg_alignment < 20:
|
| 227 |
-
print(f" ❌ POOR: Very low category alignment")
|
| 228 |
-
elif avg_alignment < 40:
|
| 229 |
-
print(f" ⚠️ FAIR: Low category alignment")
|
| 230 |
-
elif avg_alignment < 60:
|
| 231 |
-
print(f" ✅ GOOD: Moderate category alignment")
|
| 232 |
-
else:
|
| 233 |
-
print(f" 🎉 EXCELLENT: High category alignment")
|
| 234 |
-
|
| 235 |
-
return alignment_results
|
| 236 |
-
|
| 237 |
-
except Exception as e:
|
| 238 |
-
print(f"❌ Category alignment analysis failed: {e}")
|
| 239 |
-
return None
|
| 240 |
-
|
| 241 |
-
def analyze_diversity_metrics():
|
| 242 |
-
"""Analyze diversity metrics in recommendations."""
|
| 243 |
-
|
| 244 |
-
print(f"\n🌈 DIVERSITY ANALYSIS")
|
| 245 |
-
print("="*30)
|
| 246 |
-
|
| 247 |
-
try:
|
| 248 |
-
engine = RecommendationEngine()
|
| 249 |
-
real_user_selector = RealUserSelector()
|
| 250 |
-
|
| 251 |
-
test_users = real_user_selector.get_real_users(n=5, min_interactions=10)
|
| 252 |
-
|
| 253 |
-
diversity_results = []
|
| 254 |
-
|
| 255 |
-
for user in test_users:
|
| 256 |
-
print(f"\nUser {user['user_id']}:")
|
| 257 |
-
|
| 258 |
-
# Get recommendations
|
| 259 |
-
recs = engine.recommend_items_hybrid(
|
| 260 |
-
age=user['age'],
|
| 261 |
-
gender=user['gender'],
|
| 262 |
-
income=user['income'],
|
| 263 |
-
interaction_history=user['interaction_history'][:20],
|
| 264 |
-
k=20,
|
| 265 |
-
collaborative_weight=0.7
|
| 266 |
-
)
|
| 267 |
-
|
| 268 |
-
# Extract features for diversity analysis
|
| 269 |
-
categories = [item_info.get('category_code', 'Unknown') for _, _, item_info in recs]
|
| 270 |
-
brands = [item_info.get('brand', 'Unknown') for _, _, item_info in recs]
|
| 271 |
-
prices = [item_info.get('price', 0) for _, _, item_info in recs]
|
| 272 |
-
|
| 273 |
-
# Calculate diversity metrics
|
| 274 |
-
category_diversity = len(set(categories)) / len(categories) if categories else 0
|
| 275 |
-
brand_diversity = len(set(brands)) / len(brands) if brands else 0
|
| 276 |
-
|
| 277 |
-
# Price diversity (coefficient of variation)
|
| 278 |
-
price_diversity = np.std(prices) / np.mean(prices) if np.mean(prices) > 0 else 0
|
| 279 |
-
|
| 280 |
-
# Intra-list diversity (average pairwise dissimilarity)
|
| 281 |
-
category_counts = Counter(categories)
|
| 282 |
-
gini_categories = 1 - sum((count / len(categories)) ** 2 for count in category_counts.values())
|
| 283 |
-
|
| 284 |
-
diversity_results.append({
|
| 285 |
-
'user_id': user['user_id'],
|
| 286 |
-
'category_diversity': category_diversity,
|
| 287 |
-
'brand_diversity': brand_diversity,
|
| 288 |
-
'price_diversity': price_diversity,
|
| 289 |
-
'gini_categories': gini_categories,
|
| 290 |
-
'unique_categories': len(set(categories)),
|
| 291 |
-
'unique_brands': len(set(brands))
|
| 292 |
-
})
|
| 293 |
-
|
| 294 |
-
print(f" Categories: {len(set(categories))} unique ({category_diversity:.2f} ratio)")
|
| 295 |
-
print(f" Brands: {len(set(brands))} unique ({brand_diversity:.2f} ratio)")
|
| 296 |
-
print(f" Price range: ${min(prices):.2f} - ${max(prices):.2f}")
|
| 297 |
-
print(f" Gini (categories): {gini_categories:.2f}")
|
| 298 |
-
|
| 299 |
-
# Overall diversity statistics
|
| 300 |
-
print(f"\n📊 OVERALL DIVERSITY STATISTICS:")
|
| 301 |
-
avg_cat_diversity = np.mean([r['category_diversity'] for r in diversity_results])
|
| 302 |
-
avg_brand_diversity = np.mean([r['brand_diversity'] for r in diversity_results])
|
| 303 |
-
avg_gini = np.mean([r['gini_categories'] for r in diversity_results])
|
| 304 |
-
avg_unique_cats = np.mean([r['unique_categories'] for r in diversity_results])
|
| 305 |
-
|
| 306 |
-
print(f" Average category diversity: {avg_cat_diversity:.2f}")
|
| 307 |
-
print(f" Average brand diversity: {avg_brand_diversity:.2f}")
|
| 308 |
-
print(f" Average Gini coefficient: {avg_gini:.2f}")
|
| 309 |
-
print(f" Average unique categories: {avg_unique_cats:.1f}")
|
| 310 |
-
|
| 311 |
-
# Quality assessment
|
| 312 |
-
if avg_cat_diversity < 0.3:
|
| 313 |
-
print(f" ❌ POOR: Low category diversity - recommendations too similar")
|
| 314 |
-
elif avg_cat_diversity < 0.5:
|
| 315 |
-
print(f" ⚠️ FAIR: Moderate category diversity")
|
| 316 |
-
else:
|
| 317 |
-
print(f" ✅ GOOD: High category diversity")
|
| 318 |
-
|
| 319 |
-
return diversity_results
|
| 320 |
-
|
| 321 |
-
except Exception as e:
|
| 322 |
-
print(f"❌ Diversity analysis failed: {e}")
|
| 323 |
-
return None
|
| 324 |
-
|
| 325 |
-
def analyze_embedding_quality():
|
| 326 |
-
"""Analyze the quality of user and item embeddings."""
|
| 327 |
-
|
| 328 |
-
print(f"\n🧠 EMBEDDING QUALITY ANALYSIS")
|
| 329 |
-
print("="*35)
|
| 330 |
-
|
| 331 |
-
try:
|
| 332 |
-
engine = RecommendationEngine()
|
| 333 |
-
real_user_selector = RealUserSelector()
|
| 334 |
-
|
| 335 |
-
test_users = real_user_selector.get_real_users(n=3, min_interactions=10)
|
| 336 |
-
|
| 337 |
-
user_embeddings = []
|
| 338 |
-
user_item_similarities = []
|
| 339 |
-
|
| 340 |
-
for user in test_users:
|
| 341 |
-
print(f"\nUser {user['user_id']}:")
|
| 342 |
-
|
| 343 |
-
# Get user embedding
|
| 344 |
-
user_emb = engine.get_user_embedding(
|
| 345 |
-
age=user['age'],
|
| 346 |
-
gender=user['gender'],
|
| 347 |
-
income=user['income'],
|
| 348 |
-
interaction_history=user['interaction_history'][:10]
|
| 349 |
-
)
|
| 350 |
-
|
| 351 |
-
user_embeddings.append(user_emb)
|
| 352 |
-
|
| 353 |
-
print(f" User embedding shape: {user_emb.shape}")
|
| 354 |
-
print(f" User embedding norm: {np.linalg.norm(user_emb):.4f}")
|
| 355 |
-
print(f" User embedding mean: {user_emb.mean():.4f}")
|
| 356 |
-
print(f" User embedding std: {user_emb.std():.4f}")
|
| 357 |
-
|
| 358 |
-
# Get embeddings for user's interaction history
|
| 359 |
-
item_similarities = []
|
| 360 |
-
for item_id in user['interaction_history'][:5]:
|
| 361 |
-
item_emb = engine.get_item_embedding(item_id)
|
| 362 |
-
if item_emb is not None:
|
| 363 |
-
similarity = np.dot(user_emb, item_emb)
|
| 364 |
-
item_similarities.append(similarity)
|
| 365 |
-
|
| 366 |
-
if item_similarities:
|
| 367 |
-
user_item_similarities.extend(item_similarities)
|
| 368 |
-
print(f" Avg similarity with interacted items: {np.mean(item_similarities):.4f}")
|
| 369 |
-
print(f" Similarity range: {min(item_similarities):.4f} - {max(item_similarities):.4f}")
|
| 370 |
-
|
| 371 |
-
# Analyze user embedding diversity
|
| 372 |
-
if len(user_embeddings) > 1:
|
| 373 |
-
user_embeddings = np.array(user_embeddings)
|
| 374 |
-
|
| 375 |
-
# User-user similarities
|
| 376 |
-
user_similarities = []
|
| 377 |
-
for i in range(len(user_embeddings)):
|
| 378 |
-
for j in range(i+1, len(user_embeddings)):
|
| 379 |
-
sim = np.dot(user_embeddings[i], user_embeddings[j])
|
| 380 |
-
user_similarities.append(sim)
|
| 381 |
-
|
| 382 |
-
print(f"\n📊 USER EMBEDDING ANALYSIS:")
|
| 383 |
-
print(f" User-user similarities: {np.mean(user_similarities):.4f} ± {np.std(user_similarities):.4f}")
|
| 384 |
-
print(f" User-item similarities: {np.mean(user_item_similarities):.4f} ± {np.std(user_item_similarities):.4f}")
|
| 385 |
-
|
| 386 |
-
# Quality assessment
|
| 387 |
-
if np.mean(user_similarities) > 0.9:
|
| 388 |
-
print(f" ⚠️ WARNING: Users too similar - possible embedding collapse")
|
| 389 |
-
elif np.mean(user_similarities) > 0.7:
|
| 390 |
-
print(f" ⚠️ CAUTION: High user similarity - limited personalization")
|
| 391 |
-
else:
|
| 392 |
-
print(f" ✅ GOOD: Adequate user embedding diversity")
|
| 393 |
-
|
| 394 |
-
return {
|
| 395 |
-
'user_embeddings': user_embeddings,
|
| 396 |
-
'user_similarities': user_similarities if len(user_embeddings) > 1 else [],
|
| 397 |
-
'user_item_similarities': user_item_similarities
|
| 398 |
-
}
|
| 399 |
-
|
| 400 |
-
except Exception as e:
|
| 401 |
-
print(f"❌ Embedding analysis failed: {e}")
|
| 402 |
-
return None
|
| 403 |
-
|
| 404 |
-
def analyze_performance_metrics():
|
| 405 |
-
"""Analyze performance and efficiency metrics."""
|
| 406 |
-
|
| 407 |
-
print(f"\n⚡ PERFORMANCE ANALYSIS")
|
| 408 |
-
print("="*25)
|
| 409 |
-
|
| 410 |
-
try:
|
| 411 |
-
engine = RecommendationEngine()
|
| 412 |
-
real_user_selector = RealUserSelector()
|
| 413 |
-
|
| 414 |
-
test_user = real_user_selector.get_real_users(n=1, min_interactions=10)[0]
|
| 415 |
-
|
| 416 |
-
# Test recommendation generation speed
|
| 417 |
-
print("Testing recommendation generation speed...")
|
| 418 |
-
|
| 419 |
-
methods = [
|
| 420 |
-
('Collaborative', lambda: engine.recommend_items_collaborative(
|
| 421 |
-
age=test_user['age'], gender=test_user['gender'],
|
| 422 |
-
income=test_user['income'], interaction_history=test_user['interaction_history'][:20], k=10
|
| 423 |
-
)),
|
| 424 |
-
('Hybrid', lambda: engine.recommend_items_hybrid(
|
| 425 |
-
age=test_user['age'], gender=test_user['gender'],
|
| 426 |
-
income=test_user['income'], interaction_history=test_user['interaction_history'][:20], k=10
|
| 427 |
-
)),
|
| 428 |
-
]
|
| 429 |
-
|
| 430 |
-
for method_name, method_func in methods:
|
| 431 |
-
times = []
|
| 432 |
-
for _ in range(5): # Run 5 times for average
|
| 433 |
-
start_time = time.time()
|
| 434 |
-
recs = method_func()
|
| 435 |
-
end_time = time.time()
|
| 436 |
-
times.append(end_time - start_time)
|
| 437 |
-
|
| 438 |
-
avg_time = np.mean(times)
|
| 439 |
-
print(f" {method_name}: {avg_time:.3f}s ± {np.std(times):.3f}s")
|
| 440 |
-
|
| 441 |
-
if avg_time > 1.0:
|
| 442 |
-
print(f" ⚠️ SLOW: Consider optimization")
|
| 443 |
-
elif avg_time > 0.5:
|
| 444 |
-
print(f" ⚠️ MODERATE: Acceptable for real-time")
|
| 445 |
-
else:
|
| 446 |
-
print(f" ✅ FAST: Good for real-time recommendations")
|
| 447 |
-
|
| 448 |
-
# Test scalability with different recommendation counts
|
| 449 |
-
print(f"\nTesting scalability...")
|
| 450 |
-
for k in [10, 50, 100]:
|
| 451 |
-
start_time = time.time()
|
| 452 |
-
recs = engine.recommend_items_hybrid(
|
| 453 |
-
age=test_user['age'], gender=test_user['gender'],
|
| 454 |
-
income=test_user['income'], interaction_history=test_user['interaction_history'][:20], k=k
|
| 455 |
-
)
|
| 456 |
-
end_time = time.time()
|
| 457 |
-
|
| 458 |
-
print(f" {k} recommendations: {end_time - start_time:.3f}s")
|
| 459 |
-
|
| 460 |
-
return True
|
| 461 |
-
|
| 462 |
-
except Exception as e:
|
| 463 |
-
print(f"❌ Performance analysis failed: {e}")
|
| 464 |
-
return False
|
| 465 |
-
|
| 466 |
-
def generate_quality_report():
|
| 467 |
-
"""Generate a comprehensive quality report."""
|
| 468 |
-
|
| 469 |
-
print(f"\n📋 COMPREHENSIVE QUALITY REPORT")
|
| 470 |
-
print("="*40)
|
| 471 |
-
|
| 472 |
-
# Run all analyses
|
| 473 |
-
score_results = analyze_score_distribution()
|
| 474 |
-
alignment_results = analyze_category_alignment()
|
| 475 |
-
diversity_results = analyze_diversity_metrics()
|
| 476 |
-
embedding_results = analyze_embedding_quality()
|
| 477 |
-
performance_results = analyze_performance_metrics()
|
| 478 |
-
|
| 479 |
-
# Generate summary
|
| 480 |
-
print(f"\n🎯 QUALITY SUMMARY:")
|
| 481 |
-
|
| 482 |
-
issues = []
|
| 483 |
-
strengths = []
|
| 484 |
-
|
| 485 |
-
# Check score quality
|
| 486 |
-
if score_results:
|
| 487 |
-
for method, scores in score_results.items():
|
| 488 |
-
if scores:
|
| 489 |
-
score_variance = np.var(scores)
|
| 490 |
-
score_range = max(scores) - min(scores)
|
| 491 |
-
|
| 492 |
-
if score_variance < 0.001:
|
| 493 |
-
issues.append(f"Low {method} score variance ({score_variance:.6f})")
|
| 494 |
-
if score_range < 0.1:
|
| 495 |
-
issues.append(f"Narrow {method} score range ({score_range:.4f})")
|
| 496 |
-
|
| 497 |
-
if score_variance > 0.01 and score_range > 0.3:
|
| 498 |
-
strengths.append(f"Good {method} score discrimination")
|
| 499 |
-
|
| 500 |
-
# Check alignment quality
|
| 501 |
-
if alignment_results:
|
| 502 |
-
    avg_alignment = np.mean([r['alignment_percentage'] for r in alignment_results])
    if avg_alignment < 30:
        issues.append(f"Low category alignment ({avg_alignment:.1f}%)")
    elif avg_alignment > 50:
        strengths.append(f"Good category alignment ({avg_alignment:.1f}%)")

    # Check diversity
    if diversity_results:
        avg_diversity = np.mean([r['category_diversity'] for r in diversity_results])
        if avg_diversity < 0.3:
            issues.append(f"Low category diversity ({avg_diversity:.2f})")
        elif avg_diversity > 0.5:
            strengths.append(f"Good category diversity ({avg_diversity:.2f})")

    # Print results
    if issues:
        print(f"\n❌ ISSUES IDENTIFIED:")
        for issue in issues:
            print(f"  • {issue}")

    if strengths:
        print(f"\n✅ STRENGTHS:")
        for strength in strengths:
            print(f"  • {strength}")

    # Recommendations
    print(f"\n💡 RECOMMENDATIONS:")
    if any("score variance" in issue for issue in issues):
        print("  • Increase embedding dimensions or add temperature scaling")
    if any("alignment" in issue for issue in issues):
        print("  • Implement category-aware recommendation boosting")
    if any("diversity" in issue for issue in issues):
        print("  • Add diversity regularization to recommendation algorithm")

    if not issues:
        print("  • No major issues detected - model performing well!")

def main():
    """Main analysis function."""

    print("🔍 TWO-TOWER RECOMMENDATION QUALITY ANALYSIS")
    print("="*60)

    try:
        generate_quality_report()
        print(f"\n✅ Analysis completed successfully!")
    except Exception as e:
        print(f"❌ Analysis failed: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    main()
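The deleted script's quality checks reduce to a small threshold classifier over aggregate metrics. A minimal standalone sketch of that logic, using the same 30%/50% alignment thresholds as above (function name is mine, for illustration):

```python
def classify_alignment(avg_alignment):
    """Map an average alignment percentage to an issue/strength bucket.

    Thresholds mirror the deleted script: below 30% is flagged as an issue,
    above 50% is counted as a strength, anything in between is neutral.
    """
    if avg_alignment < 30:
        return 'issue'
    if avg_alignment > 50:
        return 'strength'
    return 'neutral'

print(classify_alignment(25.0))  # issue
print(classify_alignment(62.5))  # strength
print(classify_alignment(40.0))  # neutral
```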
@@ -0,0 +1,543 @@
#!/usr/bin/env python3
"""
Recommendation Analysis Script

This script compares recommendations from both training approaches:
1. 2-phase training (pre-trained item tower + joint fine-tuning)
2. Single joint training (end-to-end optimization)

It analyzes:
- Category alignment between user interactions and recommendations
- Diversity of recommended categories
- Overlap between the two approaches
- Performance on real users

Usage:
    python analyze_recommendations.py
"""

import os
import sys
import numpy as np
import pandas as pd
from collections import defaultdict, Counter
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt
import seaborn as sns

# Add src to path
sys.path.append('src')

from src.inference.recommendation_engine import RecommendationEngine
from src.utils.real_user_selector import RealUserSelector


class RecommendationAnalyzer:
    """Analyzer for comparing different recommendation approaches."""

    def __init__(self):
        self.recommendation_engine = None
        self.real_user_selector = None
        self.items_df = None
        self.setup_engines()

    def setup_engines(self):
        """Setup recommendation engines and data."""
        print("Loading recommendation engines...")

        try:
            # Load recommendation engine (assumes trained model artifacts exist)
            self.recommendation_engine = RecommendationEngine()
            print("✅ Recommendation engine loaded")
        except Exception as e:
            print(f"❌ Error loading recommendation engine: {e}")
            return

        try:
            # Load real user selector
            self.real_user_selector = RealUserSelector()
            print("✅ Real user selector loaded")
        except Exception as e:
            print(f"❌ Error loading real user selector: {e}")

        # Load items data for category analysis
        self.items_df = pd.read_csv("datasets/items.csv")
        print(f"✅ Loaded {len(self.items_df)} items")

    def get_item_categories(self, item_ids: List[int]) -> List[str]:
        """Get category codes for given item IDs."""
        categories = []
        for item_id in item_ids:
            item_row = self.items_df[self.items_df['product_id'] == item_id]
            if len(item_row) > 0:
                categories.append(item_row.iloc[0]['category_code'])
            else:
                categories.append('unknown')
        return categories

    def analyze_user_recommendations(self,
                                     user_profile: Dict,
                                     recommendation_types: List[str] = None) -> Dict:
        """Analyze recommendations for a single user across different approaches."""

        if recommendation_types is None:
            recommendation_types = ['collaborative', 'hybrid', 'content']

        results = {
            'user_profile': user_profile,
            'interaction_categories': [],
            'recommendations': {},
            'category_analysis': {}
        }

        # Get categories from user's interaction history
        if user_profile['interaction_history']:
            results['interaction_categories'] = self.get_item_categories(
                user_profile['interaction_history']
            )

        # Get recommendations for each type
        for rec_type in recommendation_types:
            try:
                if rec_type == 'collaborative':
                    recs = self.recommendation_engine.recommend_items_collaborative(
                        age=user_profile['age'],
                        gender=user_profile['gender'],
                        income=user_profile['income'],
                        interaction_history=user_profile['interaction_history'],
                        k=10
                    )
                elif rec_type == 'hybrid':
                    recs = self.recommendation_engine.recommend_items_hybrid(
                        age=user_profile['age'],
                        gender=user_profile['gender'],
                        income=user_profile['income'],
                        interaction_history=user_profile['interaction_history'],
                        k=10
                    )
                elif rec_type == 'content' and user_profile['interaction_history']:
                    recs = self.recommendation_engine.recommend_items_content_based(
                        seed_item_id=user_profile['interaction_history'][-1],
                        k=10
                    )
                else:
                    continue

                # Extract item IDs and categories
                item_ids = [item_id for item_id, score, info in recs]
                rec_categories = self.get_item_categories(item_ids)

                results['recommendations'][rec_type] = {
                    'items': recs,
                    'item_ids': item_ids,
                    'categories': rec_categories,
                    'scores': [score for item_id, score, info in recs]
                }

                # Analyze category alignment
                results['category_analysis'][rec_type] = self.analyze_category_alignment(
                    results['interaction_categories'],
                    rec_categories
                )

            except Exception as e:
                print(f"Error generating {rec_type} recommendations: {e}")

        return results

    def analyze_category_alignment(self,
                                   interaction_categories: List[str],
                                   recommendation_categories: List[str]) -> Dict:
        """Analyze alignment between interaction and recommendation categories."""

        if not interaction_categories:
            return {
                'overlap_ratio': 0.0,
                'unique_interaction_categories': 0,
                'unique_recommendation_categories': len(set(recommendation_categories)),
                'common_categories': [],
                'category_distribution': Counter(recommendation_categories)
            }

        interaction_set = set(interaction_categories)
        recommendation_set = set(recommendation_categories)

        common_categories = interaction_set.intersection(recommendation_set)
        overlap_ratio = len(common_categories) / len(interaction_set) if interaction_set else 0.0

        return {
            'overlap_ratio': overlap_ratio,
            'unique_interaction_categories': len(interaction_set),
            'unique_recommendation_categories': len(recommendation_set),
            'common_categories': list(common_categories),
            'category_distribution': Counter(recommendation_categories),
            'interaction_category_distribution': Counter(interaction_categories)
        }
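The alignment metric in `analyze_category_alignment` is simply the fraction of a user's interaction categories that reappear in the recommendation list. A minimal standalone sketch of that computation (the category codes below are made up for illustration):

```python
def overlap_ratio(interaction_categories, recommendation_categories):
    """Fraction of the user's interacted categories that reappear in the recommendations."""
    interacted = set(interaction_categories)
    recommended = set(recommendation_categories)
    if not interacted:
        return 0.0
    return len(interacted & recommended) / len(interacted)

# Hypothetical category codes, for illustration only
history = ['electronics.smartphone', 'electronics.audio', 'appliances.kitchen']
recs = ['electronics.smartphone', 'electronics.tv', 'electronics.audio']
print(overlap_ratio(history, recs))  # 2 of 3 interacted categories are covered
```

Note that the ratio is asymmetric: it measures coverage of the user's history, not of the recommendation list, so a long diverse recommendation list cannot lower it.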

    def compare_recommendation_approaches(self,
                                          users_sample: List[Dict],
                                          approaches: List[str] = None) -> Dict:
        """Compare different recommendation approaches across multiple users."""

        if approaches is None:
            approaches = ['collaborative', 'hybrid', 'content']

        comparison_results = {
            'approach_stats': defaultdict(list),
            'cross_approach_analysis': {},
            'user_results': []
        }

        print(f"Analyzing {len(users_sample)} users across {len(approaches)} approaches...")

        for i, user in enumerate(users_sample):
            print(f"Analyzing user {i+1}/{len(users_sample)}...")

            user_results = self.analyze_user_recommendations(user, approaches)
            comparison_results['user_results'].append(user_results)

            # Aggregate stats by approach
            for approach in approaches:
                if approach in user_results['category_analysis']:
                    analysis = user_results['category_analysis'][approach]
                    comparison_results['approach_stats'][approach].append({
                        'overlap_ratio': analysis['overlap_ratio'],
                        'unique_rec_categories': analysis['unique_recommendation_categories'],
                        'common_categories_count': len(analysis['common_categories'])
                    })

        # Calculate aggregate statistics
        for approach in approaches:
            stats = comparison_results['approach_stats'][approach]
            if stats:
                comparison_results['approach_stats'][approach] = {
                    'avg_overlap_ratio': np.mean([s['overlap_ratio'] for s in stats]),
                    'std_overlap_ratio': np.std([s['overlap_ratio'] for s in stats]),
                    'avg_unique_categories': np.mean([s['unique_rec_categories'] for s in stats]),
                    'avg_common_categories': np.mean([s['common_categories_count'] for s in stats]),
                    'total_users': len(stats)
                }

        # Cross-approach analysis
        comparison_results['cross_approach_analysis'] = self.cross_approach_analysis(
            comparison_results['user_results'], approaches
        )

        return comparison_results

    def cross_approach_analysis(self, user_results: List[Dict], approaches: List[str]) -> Dict:
        """Analyze similarities and differences between approaches."""

        cross_analysis = {
            'item_overlap': defaultdict(dict),
            'category_overlap': defaultdict(dict),
            'score_correlation': defaultdict(dict)
        }

        for user_result in user_results:
            recommendations = user_result['recommendations']

            # Compare each pair of approaches
            for i, approach1 in enumerate(approaches):
                for approach2 in approaches[i+1:]:
                    if approach1 in recommendations and approach2 in recommendations:

                        # Item overlap
                        items1 = set(recommendations[approach1]['item_ids'])
                        items2 = set(recommendations[approach2]['item_ids'])
                        item_overlap_ratio = len(items1.intersection(items2)) / len(items1.union(items2))

                        # Category overlap
                        cats1 = set(recommendations[approach1]['categories'])
                        cats2 = set(recommendations[approach2]['categories'])
                        cat_overlap_ratio = len(cats1.intersection(cats2)) / len(cats1.union(cats2)) if cats1.union(cats2) else 0

                        # Store results
                        pair_key = f"{approach1}_vs_{approach2}"
                        if pair_key not in cross_analysis['item_overlap']:
                            cross_analysis['item_overlap'][pair_key] = []
                            cross_analysis['category_overlap'][pair_key] = []

                        cross_analysis['item_overlap'][pair_key].append(item_overlap_ratio)
                        cross_analysis['category_overlap'][pair_key].append(cat_overlap_ratio)

        # Calculate averages
        for pair_key in cross_analysis['item_overlap']:
            cross_analysis['item_overlap'][pair_key] = {
                'avg': np.mean(cross_analysis['item_overlap'][pair_key]),
                'std': np.std(cross_analysis['item_overlap'][pair_key])
            }
            cross_analysis['category_overlap'][pair_key] = {
                'avg': np.mean(cross_analysis['category_overlap'][pair_key]),
                'std': np.std(cross_analysis['category_overlap'][pair_key])
            }

        return cross_analysis
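The pairwise overlap used in `cross_approach_analysis` is a Jaccard index over the two top-k lists (intersection size divided by union size). A minimal sketch with arbitrary example item IDs:

```python
def jaccard(a, b):
    """Jaccard index |A ∩ B| / |A ∪ B| between two recommendation lists."""
    sa, sb = set(a), set(b)
    union = sa | sb
    return len(sa & sb) / len(union) if union else 0.0

# Two hypothetical top-5 lists sharing two items: 2 shared / 8 distinct = 0.25
collab = [101, 102, 103, 104, 105]
hybrid = [103, 105, 201, 202, 203]
print(jaccard(collab, hybrid))  # 0.25
```

Unlike the category-alignment ratio, Jaccard is symmetric, which is why it fits the approach-vs-approach comparison.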

    def generate_report(self, comparison_results: Dict, output_file: str = "recommendation_analysis_report.md"):
        """Generate a comprehensive analysis report."""

        report = []
        report.append("# Recommendation System Analysis Report")
        report.append(f"Generated: {pd.Timestamp.now()}")
        report.append("")

        # Overall Statistics
        report.append("## Overall Statistics")
        report.append("")

        for approach, stats in comparison_results['approach_stats'].items():
            if isinstance(stats, dict):
                report.append(f"### {approach.title()} Recommendations")
                report.append(f"- **Average Category Overlap**: {stats['avg_overlap_ratio']:.3f} ± {stats['std_overlap_ratio']:.3f}")
                report.append(f"- **Average Unique Categories per User**: {stats['avg_unique_categories']:.1f}")
                report.append(f"- **Average Common Categories**: {stats['avg_common_categories']:.1f}")
                report.append(f"- **Users Analyzed**: {stats['total_users']}")
                report.append("")

        # Cross-Approach Analysis
        report.append("## Cross-Approach Comparison")
        report.append("")

        cross_analysis = comparison_results['cross_approach_analysis']

        report.append("### Item Overlap Between Approaches")
        for pair, overlap_stats in cross_analysis['item_overlap'].items():
            report.append(f"- **{pair.replace('_', ' ').title()}**: {overlap_stats['avg']:.3f} ± {overlap_stats['std']:.3f}")
        report.append("")

        report.append("### Category Overlap Between Approaches")
        for pair, overlap_stats in cross_analysis['category_overlap'].items():
            report.append(f"- **{pair.replace('_', ' ').title()}**: {overlap_stats['avg']:.3f} ± {overlap_stats['std']:.3f}")
        report.append("")

        # Category Alignment Analysis
        report.append("## Category Alignment Analysis")
        report.append("")
        report.append("Category alignment measures how well recommendations match the categories")
        report.append("of items users have previously interacted with.")
        report.append("")

        # Find best performing approach
        best_approach = max(
            comparison_results['approach_stats'].keys(),
            key=lambda k: comparison_results['approach_stats'][k]['avg_overlap_ratio']
            if isinstance(comparison_results['approach_stats'][k], dict) else 0
        )

        report.append(f"**Best Category Alignment**: {best_approach.title()} approach")
        report.append("")

        # Recommendations
        report.append("## Key Findings & Recommendations")
        report.append("")

        # Analyze overlap ratios to provide insights
        overlap_ratios = {
            k: v['avg_overlap_ratio'] for k, v in comparison_results['approach_stats'].items()
            if isinstance(v, dict)
        }

        if overlap_ratios:
            avg_overlap = np.mean(list(overlap_ratios.values()))
            if avg_overlap > 0.5:
                report.append("✅ **Strong Category Alignment**: Recommendations show good alignment with user interaction patterns.")
            elif avg_overlap > 0.3:
                report.append("⚠️ **Moderate Category Alignment**: Some alignment present but room for improvement.")
            else:
                report.append("❌ **Weak Category Alignment**: Recommendations may be too diverse or not well-aligned with user preferences.")

            report.append("")

            # Compare approaches
            if len(overlap_ratios) > 1:
                sorted_approaches = sorted(overlap_ratios.items(), key=lambda x: x[1], reverse=True)
                report.append("### Approach Rankings (by category alignment):")
                for i, (approach, ratio) in enumerate(sorted_approaches, 1):
                    report.append(f"{i}. **{approach.title()}**: {ratio:.3f}")
                report.append("")

        # Write report
        with open(output_file, 'w') as f:
            f.write('\n'.join(report))

        print(f"✅ Analysis report saved to: {output_file}")
        return '\n'.join(report)
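`generate_report` builds the whole markdown document as a list of lines and joins once at the end, which avoids repeated string concatenation. The same pattern in miniature (function and sample data are mine, for illustration):

```python
def build_report(stats):
    """Assemble a small markdown report from a {approach: score} dict, one line at a time."""
    lines = ["# Mini Report", ""]
    # Rank approaches by score, best first, mirroring the rankings section above
    for name, score in sorted(stats.items(), key=lambda kv: kv[1], reverse=True):
        lines.append(f"- **{name.title()}**: {score:.3f}")
    return "\n".join(lines)

print(build_report({"hybrid": 0.42, "collaborative": 0.55}))
```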

    def visualize_results(self, comparison_results: Dict, save_plots: bool = True):
        """Create visualizations for the analysis results."""

        # Set up plotting style
        plt.style.use('default')
        sns.set_palette("husl")

        # Create figure with subplots
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        fig.suptitle('Recommendation System Analysis', fontsize=16, fontweight='bold')

        # 1. Category Overlap by Approach
        ax1 = axes[0, 0]
        approaches = []
        overlap_means = []
        overlap_stds = []

        for approach, stats in comparison_results['approach_stats'].items():
            if isinstance(stats, dict):
                approaches.append(approach.title())
                overlap_means.append(stats['avg_overlap_ratio'])
                overlap_stds.append(stats['std_overlap_ratio'])

        bars1 = ax1.bar(approaches, overlap_means, yerr=overlap_stds, capsize=5, alpha=0.7)
        ax1.set_title('Average Category Overlap by Approach')
        ax1.set_ylabel('Category Overlap Ratio')
        ax1.set_ylim(0, 1)

        # Add value labels on bars
        for bar, mean in zip(bars1, overlap_means):
            height = bar.get_height()
            ax1.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                     f'{mean:.3f}', ha='center', va='bottom')

        # 2. Cross-Approach Item Overlap
        ax2 = axes[0, 1]
        cross_analysis = comparison_results['cross_approach_analysis']

        pair_names = []
        item_overlaps = []

        for pair, overlap_stats in cross_analysis['item_overlap'].items():
            pair_names.append(pair.replace('_vs_', ' vs ').title())
            item_overlaps.append(overlap_stats['avg'])

        if pair_names:
            bars2 = ax2.bar(pair_names, item_overlaps, alpha=0.7, color='coral')
            ax2.set_title('Item Overlap Between Approaches')
            ax2.set_ylabel('Item Overlap Ratio')
            ax2.set_ylim(0, 1)
            plt.setp(ax2.get_xticklabels(), rotation=45, ha='right')

            # Add value labels
            for bar, overlap in zip(bars2, item_overlaps):
                height = bar.get_height()
                ax2.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                         f'{overlap:.3f}', ha='center', va='bottom')

        # 3. Category Diversity
        ax3 = axes[1, 0]
        unique_categories = []
        for approach, stats in comparison_results['approach_stats'].items():
            if isinstance(stats, dict):
                unique_categories.append(stats['avg_unique_categories'])

        bars3 = ax3.bar(approaches, unique_categories, alpha=0.7, color='lightgreen')
        ax3.set_title('Average Unique Categories per Recommendation')
        ax3.set_ylabel('Number of Unique Categories')

        for bar, cats in zip(bars3, unique_categories):
            height = bar.get_height()
            ax3.text(bar.get_x() + bar.get_width()/2., height + 0.1,
                     f'{cats:.1f}', ha='center', va='bottom')

        # 4. Category vs Item Overlap Comparison
        ax4 = axes[1, 1]

        if cross_analysis['item_overlap'] and cross_analysis['category_overlap']:
            pairs = list(cross_analysis['item_overlap'].keys())
            item_overlaps = [cross_analysis['item_overlap'][p]['avg'] for p in pairs]
            cat_overlaps = [cross_analysis['category_overlap'][p]['avg'] for p in pairs]

            x = np.arange(len(pairs))
            width = 0.35

            bars4a = ax4.bar(x - width/2, item_overlaps, width, label='Item Overlap', alpha=0.7)
            bars4b = ax4.bar(x + width/2, cat_overlaps, width, label='Category Overlap', alpha=0.7)

            ax4.set_title('Item vs Category Overlap Between Approaches')
            ax4.set_ylabel('Overlap Ratio')
            ax4.set_xticks(x)
            ax4.set_xticklabels([p.replace('_vs_', ' vs ') for p in pairs], rotation=45, ha='right')
            ax4.legend()
            ax4.set_ylim(0, 1)

        plt.tight_layout()

        if save_plots:
            plt.savefig('recommendation_analysis_plots.png', dpi=300, bbox_inches='tight')
            print("✅ Plots saved to: recommendation_analysis_plots.png")

        plt.show()

def main():
    """Main function to run the recommendation analysis."""

    print("🔍 Starting Recommendation Analysis...")
    print("=" * 50)

    # Initialize analyzer
    analyzer = RecommendationAnalyzer()

    if analyzer.recommendation_engine is None:
        print("❌ Cannot proceed without recommendation engine. Please ensure model is trained.")
        return

    # Get sample of real users for analysis
    print("Getting real user sample...")
    try:
        real_users = analyzer.real_user_selector.get_real_users(n=20, min_interactions=3)
        print(f"✅ Loaded {len(real_users)} real users for analysis")
    except Exception as e:
        print(f"❌ Error loading real users: {e}")
        # Fallback to synthetic users
        real_users = [
            {
                'age': 32, 'gender': 'male', 'income': 75000,
                'interaction_history': [1000978, 1001588, 1001618, 1002039]
            },
            {
                'age': 28, 'gender': 'female', 'income': 45000,
                'interaction_history': [1003456, 1004567, 1005678]
            },
            {
                'age': 45, 'gender': 'male', 'income': 85000,
                'interaction_history': [1006789, 1007890, 1008901, 1009012, 1010123]
            }
        ]
        print(f"Using {len(real_users)} synthetic users for analysis")

    # Run comprehensive analysis
    print("Running recommendation analysis...")
    approaches = ['collaborative', 'hybrid', 'content']

    comparison_results = analyzer.compare_recommendation_approaches(
        users_sample=real_users,
        approaches=approaches
    )

    # Generate report
    print("Generating analysis report...")
    report = analyzer.generate_report(comparison_results)

    # Create visualizations
    print("Creating visualizations...")
    try:
        analyzer.visualize_results(comparison_results, save_plots=True)
    except Exception as e:
        print(f"Warning: Could not create visualizations: {e}")

    # Print summary
    print("\n" + "=" * 50)
    print("📊 ANALYSIS SUMMARY")
    print("=" * 50)

    for approach, stats in comparison_results['approach_stats'].items():
        if isinstance(stats, dict):
            print(f"{approach.title()}: {stats['avg_overlap_ratio']:.3f} avg category overlap")

    print(f"\n✅ Analysis complete! Check:")
    print("  📄 recommendation_analysis_report.md")
    print("  📊 recommendation_analysis_plots.png")


if __name__ == "__main__":
    main()
@@ -13,6 +13,7 @@ sys.path.append(parent_dir)
 os.chdir(parent_dir)  # Change to project root directory

 from src.inference.recommendation_engine import RecommendationEngine
+from src.utils.real_user_selector import RealUserSelector

 # Initialize FastAPI app
 app = FastAPI(
@@ -30,8 +31,10 @@ app.add_middleware(
     allow_headers=["*"],
 )

-# Global
 recommendation_engine = None

 # Pydantic models for request/response
@@ -45,8 +48,11 @@ class UserProfile(BaseModel):
 class RecommendationRequest(BaseModel):
     user_profile: UserProfile
     num_recommendations: int = 10
-    recommendation_type: str = "hybrid"  # "collaborative", "content", "hybrid"
     collaborative_weight: Optional[float] = 0.7

 class ItemSimilarityRequest(BaseModel):
@@ -87,10 +93,27 @@ class RatingPredictionResponse(BaseModel):
     item_info: ItemInfo

 @app.on_event("startup")
 async def startup_event():
-    """Initialize the recommendation engine on startup."""
-    global recommendation_engine
     try:
         print("Loading recommendation engine...")
@@ -99,6 +122,30 @@ async def startup_event():
     except Exception as e:
         print(f"Error loading recommendation engine: {e}")
         recommendation_engine = None


 @app.get("/")
@@ -120,6 +167,68 @@ async def health_check():
|
|
| 120 |
}
|
| 121 |
|
| 122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
@app.post("/recommendations", response_model=RecommendationsResponse)
|
| 124 |
async def get_recommendations(request: RecommendationRequest):
|
| 125 |
"""Get item recommendations for a user."""
|
|
@@ -164,10 +273,67 @@ async def get_recommendations(request: RecommendationRequest):
|
|
| 164 |
collaborative_weight=request.collaborative_weight
|
| 165 |
)
|
| 166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
else:
|
| 168 |
raise HTTPException(
|
| 169 |
status_code=400,
|
| 170 |
-
detail="Invalid recommendation_type. Must be 'collaborative', 'content', or '
|
| 171 |
)
|
| 172 |
|
| 173 |
# Format response
|
|
|
|
| 13 |
os.chdir(parent_dir) # Change to project root directory
|
| 14 |
|
| 15 |
from src.inference.recommendation_engine import RecommendationEngine
|
| 16 |
+
from src.utils.real_user_selector import RealUserSelector
|
| 17 |
|
| 18 |
# Initialize FastAPI app
|
| 19 |
app = FastAPI(
|
|
|
|
| 31 |
allow_headers=["*"],
|
| 32 |
)
|
| 33 |
|
| 34 |
+
# Global instances
|
| 35 |
recommendation_engine = None
|
| 36 |
+
enhanced_recommendation_engine = None
|
| 37 |
+
real_user_selector = None
|
| 38 |
|
| 39 |
|
| 40 |
# Pydantic models for request/response
|
|
|
|
| 48 |
class RecommendationRequest(BaseModel):
|
| 49 |
user_profile: UserProfile
|
| 50 |
num_recommendations: int = 10
|
| 51 |
+
recommendation_type: str = "hybrid" # "collaborative", "content", "hybrid", "enhanced", "enhanced_128d", "category_focused"
|
| 52 |
collaborative_weight: Optional[float] = 0.7
|
| 53 |
+
category_boost: Optional[float] = 1.5 # For enhanced recommendations
|
| 54 |
+
enable_category_boost: Optional[bool] = True
|
| 55 |
+
enable_diversity: Optional[bool] = True
|
| 56 |
|
| 57 |
|
| 58 |
class ItemSimilarityRequest(BaseModel):
|
|
|
|
| 93 |
item_info: ItemInfo
|
| 94 |
|
| 95 |
|
| 96 |
+
class RealUserProfile(BaseModel):
|
| 97 |
+
user_id: int
|
| 98 |
+
age: int
|
| 99 |
+
gender: str
|
| 100 |
+
income: int
|
| 101 |
+
interaction_history: List[int]
|
| 102 |
+
interaction_stats: Dict[str, int]
|
| 103 |
+
interaction_pattern: str
|
| 104 |
+
summary: str
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class RealUsersResponse(BaseModel):
|
| 108 |
+
users: List[RealUserProfile]
|
| 109 |
+
total_count: int
|
| 110 |
+
dataset_summary: Dict[str, Any]
|
| 111 |
+
|
| 112 |
+
|
| 113 |
@app.on_event("startup")
|
| 114 |
async def startup_event():
|
| 115 |
+
"""Initialize the recommendation engines and real user selector on startup."""
|
| 116 |
+
global recommendation_engine, enhanced_recommendation_engine, real_user_selector
|
| 117 |
|
| 118 |
try:
|
| 119 |
print("Loading recommendation engine...")
|
|
|
|
| 122 |
except Exception as e:
|
| 123 |
print(f"Error loading recommendation engine: {e}")
|
| 124 |
recommendation_engine = None
|
| 125 |
+
|
| 126 |
+
try:
|
| 127 |
+
print("Loading enhanced recommendation engine...")
|
| 128 |
+
# Try enhanced 128D engine first, fallback to regular enhanced
|
| 129 |
+
try:
|
| 130 |
+
from src.inference.enhanced_recommendation_engine_128d import Enhanced128DRecommendationEngine
|
| 131 |
+
enhanced_recommendation_engine = Enhanced128DRecommendationEngine()
|
| 132 |
+
print("✅ Using Enhanced 128D Recommendation Engine")
|
| 133 |
+
except:
|
| 134 |
+
from src.inference.enhanced_recommendation_engine import EnhancedRecommendationEngine
|
| 135 |
+
enhanced_recommendation_engine = EnhancedRecommendationEngine()
|
| 136 |
+
print("⚠️ Using fallback Enhanced Recommendation Engine")
|
| 137 |
+
print("Enhanced recommendation engine loaded successfully!")
|
| 138 |
+
except Exception as e:
|
| 139 |
+
print(f"Error loading enhanced recommendation engine: {e}")
|
| 140 |
+
enhanced_recommendation_engine = None
|
| 141 |
+
|
| 142 |
+
try:
|
| 143 |
+
print("Loading real user selector...")
|
| 144 |
+
real_user_selector = RealUserSelector()
|
| 145 |
+
print("Real user selector loaded successfully!")
|
| 146 |
+
except Exception as e:
|
| 147 |
+
print(f"Error loading real user selector: {e}")
|
| 148 |
+
real_user_selector = None
|
| 149 |
|
| 150 |
|
| 151 |
@app.get("/")
|
|
|
|
| 167 |
}
|
| 168 |
|
| 169 |
|
| 170 |
+
@app.get("/real-users", response_model=RealUsersResponse)
|
| 171 |
+
async def get_real_users(count: int = 100, min_interactions: int = 5):
|
| 172 |
+
"""Get real user profiles with genuine interaction histories."""
|
| 173 |
+
|
| 174 |
+
if real_user_selector is None:
|
| 175 |
+
raise HTTPException(status_code=503, detail="Real user selector not available")
|
| 176 |
+
|
| 177 |
+
try:
|
| 178 |
+
# Get real user profiles
|
| 179 |
+
real_users = real_user_selector.get_real_users(n=count, min_interactions=min_interactions)
|
| 180 |
+
|
| 181 |
+
# Get dataset summary
|
| 182 |
+
dataset_summary = real_user_selector.get_dataset_summary()
|
| 183 |
+
|
| 184 |
+
# Format users for response
|
| 185 |
+
formatted_users = []
|
| 186 |
+
for user in real_users:
|
| 187 |
+
formatted_users.append(RealUserProfile(**user))
|
| 188 |
+
|
| 189 |
+
return RealUsersResponse(
|
| 190 |
+
users=formatted_users,
|
| 191 |
+
total_count=len(formatted_users),
|
| 192 |
+
dataset_summary=dataset_summary
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
except Exception as e:
|
| 196 |
+
raise HTTPException(status_code=500, detail=f"Error retrieving real users: {str(e)}")
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
@app.get("/real-users/{user_id}")
|
| 200 |
+
async def get_real_user_details(user_id: int):
|
| 201 |
+
"""Get detailed interaction breakdown for a specific real user."""
|
| 202 |
+
|
| 203 |
+
if real_user_selector is None:
|
| 204 |
+
raise HTTPException(status_code=503, detail="Real user selector not available")
|
| 205 |
+
|
| 206 |
+
try:
|
| 207 |
+
user_details = real_user_selector.get_user_interaction_details(user_id)
|
| 208 |
+
|
| 209 |
+
if "error" in user_details:
|
| 210 |
+
raise HTTPException(status_code=404, detail=user_details["error"])
|
| 211 |
+
|
| 212 |
+
return user_details
|
| 213 |
+
|
| 214 |
+
except Exception as e:
|
| 215 |
+
raise HTTPException(status_code=500, detail=f"Error retrieving user details: {str(e)}")
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
@app.get("/dataset-summary")
|
| 219 |
+
async def get_dataset_summary():
|
| 220 |
+
"""Get summary statistics of the real dataset."""
|
| 221 |
+
|
| 222 |
+
if real_user_selector is None:
|
| 223 |
+
raise HTTPException(status_code=503, detail="Real user selector not available")
|
| 224 |
+
|
| 225 |
+
try:
|
| 226 |
+
return real_user_selector.get_dataset_summary()
|
| 227 |
+
|
| 228 |
+
except Exception as e:
|
| 229 |
+
raise HTTPException(status_code=500, detail=f"Error retrieving dataset summary: {str(e)}")
|
| 230 |
+
|
| 231 |
+
|
| 232 |
@app.post("/recommendations", response_model=RecommendationsResponse)
|
| 233 |
async def get_recommendations(request: RecommendationRequest):
|
| 234 |
"""Get item recommendations for a user."""
|
|
|
|
| 273 |
collaborative_weight=request.collaborative_weight
|
| 274 |
)
|
| 275 |
|
| 276 |
+
elif request.recommendation_type == "enhanced":
|
| 277 |
+
if enhanced_recommendation_engine is None:
|
| 278 |
+
raise HTTPException(status_code=503, detail="Enhanced recommendation engine not available")
|
| 279 |
+
|
| 280 |
+
# Check if it's the 128D engine or fallback
|
| 281 |
+
if hasattr(enhanced_recommendation_engine, 'recommend_items_enhanced'):
|
| 282 |
+
# 128D Enhanced engine
|
| 283 |
+
recommendations = enhanced_recommendation_engine.recommend_items_enhanced(
|
| 284 |
+
age=user_profile.age,
|
| 285 |
+
gender=user_profile.gender,
|
| 286 |
+
income=user_profile.income,
|
| 287 |
+
interaction_history=user_profile.interaction_history,
|
| 288 |
+
k=request.num_recommendations,
|
| 289 |
+
diversity_weight=0.3 if request.enable_diversity else 0.0,
|
| 290 |
+
category_boost=request.category_boost if request.enable_category_boost else 1.0
|
| 291 |
+
)
|
| 292 |
+
else:
|
| 293 |
+
# Fallback enhanced engine
|
| 294 |
+
recommendations = enhanced_recommendation_engine.recommend_items_enhanced_hybrid(
|
| 295 |
+
age=user_profile.age,
|
| 296 |
+
gender=user_profile.gender,
|
| 297 |
+
income=user_profile.income,
|
| 298 |
+
interaction_history=user_profile.interaction_history,
|
| 299 |
+
k=request.num_recommendations,
|
| 300 |
+
collaborative_weight=request.collaborative_weight,
|
| 301 |
+
category_boost=request.category_boost,
|
| 302 |
+
enable_category_boost=request.enable_category_boost,
|
| 303 |
+
enable_diversity=request.enable_diversity
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
elif request.recommendation_type == "enhanced_128d":
|
| 307 |
+
if enhanced_recommendation_engine is None or not hasattr(enhanced_recommendation_engine, 'recommend_items_enhanced'):
|
| 308 |
+
raise HTTPException(status_code=503, detail="Enhanced 128D recommendation engine not available")
|
| 309 |
+
|
| 310 |
+
recommendations = enhanced_recommendation_engine.recommend_items_enhanced(
|
| 311 |
+
age=user_profile.age,
|
| 312 |
+
gender=user_profile.gender,
|
| 313 |
+
income=user_profile.income,
|
| 314 |
+
interaction_history=user_profile.interaction_history,
|
| 315 |
+
k=request.num_recommendations,
|
| 316 |
+
diversity_weight=0.3 if request.enable_diversity else 0.0,
|
| 317 |
+
category_boost=request.category_boost if request.enable_category_boost else 1.0
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
elif request.recommendation_type == "category_focused":
|
| 321 |
+
if enhanced_recommendation_engine is None:
|
| 322 |
+
raise HTTPException(status_code=503, detail="Enhanced recommendation engine not available")
|
| 323 |
+
|
| 324 |
+
recommendations = enhanced_recommendation_engine.recommend_items_category_focused(
|
| 325 |
+
age=user_profile.age,
|
| 326 |
+
gender=user_profile.gender,
|
| 327 |
+
income=user_profile.income,
|
| 328 |
+
interaction_history=user_profile.interaction_history,
|
| 329 |
+
k=request.num_recommendations,
|
| 330 |
+
focus_percentage=0.8
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
else:
|
| 334 |
raise HTTPException(
|
| 335 |
status_code=400,
|
| 336 |
+
detail="Invalid recommendation_type. Must be 'collaborative', 'content', 'hybrid', 'enhanced', 'enhanced_128d', or 'category_focused'"
|
| 337 |
)
|
| 338 |
|
| 339 |
# Format response
|
|
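For reference, the extended RecommendationRequest schema above can be exercised with a minimal client sketch. The field names come from the diff; the profile values are made up, and the request line is commented out since it assumes a server running at http://localhost:8000:

```python
import json

# Example payload matching the extended RecommendationRequest schema
# (illustrative values only).
payload = {
    "user_profile": {
        "age": 34,
        "gender": "female",
        "income": 55000.0,
        "interaction_history": [101, 202, 303],
    },
    "num_recommendations": 10,
    "recommendation_type": "enhanced",  # or "enhanced_128d", "category_focused"
    "collaborative_weight": 0.7,
    "category_boost": 1.5,
    "enable_category_boost": True,
    "enable_diversity": True,
}

body = json.dumps(payload)

# To actually send it (requires the API to be running):
#   import requests
#   r = requests.post("http://localhost:8000/recommendations", json=payload)
#   print(r.json()["total_count"])
```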
api_2phase.py (new file)
@@ -0,0 +1,521 @@
+#!/usr/bin/env python3
+"""
+API for 2-Phase Trained Recommendation System
+
+This API serves recommendations from a model trained using the 2-phase approach:
+1. Pre-trained item tower
+2. Joint training with fine-tuned item tower
+
+Usage:
+    python api_2phase.py
+
+Then access: http://localhost:8000
+"""
+
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+from typing import List, Optional, Dict, Any
+import uvicorn
+import os
+import sys
+import pandas as pd
+
+# Add src to path for imports and set working directory
+parent_dir = os.path.dirname(__file__)
+sys.path.append(parent_dir)
+os.chdir(parent_dir)  # Change to project root directory
+
+from src.inference.recommendation_engine import RecommendationEngine
+from src.utils.real_user_selector import RealUserSelector
+
+# Initialize FastAPI app
+app = FastAPI(
+    title="Two-Tower Recommendation API (2-Phase Training)",
+    description="API for serving recommendations using a two-tower architecture trained with 2-phase approach",
+    version="1.0.0-2phase"
+)
+
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # Configure appropriately for production
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Global instances
+recommendation_engine = None
+enhanced_recommendation_engine = None
+real_user_selector = None
+
+
+# Pydantic models for request/response
+class UserProfile(BaseModel):
+    age: int
+    gender: str  # "male" or "female"
+    income: float
+    interaction_history: Optional[List[int]] = []
+
+
+class RecommendationRequest(BaseModel):
+    user_profile: UserProfile
+    num_recommendations: int = 10
+    recommendation_type: str = "hybrid"  # "collaborative", "content", "hybrid", "enhanced", "enhanced_128d", "category_focused"
+    collaborative_weight: Optional[float] = 0.7
+    category_boost: Optional[float] = 1.5  # For enhanced recommendations
+    enable_category_boost: Optional[bool] = True
+    enable_diversity: Optional[bool] = True
+
+
+class ItemSimilarityRequest(BaseModel):
+    item_id: int
+    num_recommendations: int = 10
+
+
+class RatingPredictionRequest(BaseModel):
+    user_profile: UserProfile
+    item_id: int
+
+
+class ItemInfo(BaseModel):
+    product_id: int
+    category_id: int
+    category_code: str
+    brand: str
+    price: float
+
+
+class RecommendationResponse(BaseModel):
+    item_id: int
+    score: float
+    item_info: ItemInfo
+
+
+class RecommendationsResponse(BaseModel):
+    recommendations: List[RecommendationResponse]
+    user_profile: UserProfile
+    recommendation_type: str
+    total_count: int
+    training_approach: str = "2-phase"
+
+
+class RatingPredictionResponse(BaseModel):
+    user_profile: UserProfile
+    item_id: int
+    predicted_rating: float
+    item_info: ItemInfo
+
+
+class RealUserProfile(BaseModel):
+    user_id: int
+    age: int
+    gender: str
+    income: int
+    interaction_history: List[int]
+    interaction_stats: Dict[str, int]
+    interaction_pattern: str
+    summary: str
+
+
+class RealUsersResponse(BaseModel):
+    users: List[RealUserProfile]
+    total_count: int
+    dataset_summary: Dict[str, Any]
+
+
+@app.on_event("startup")
+async def startup_event():
+    """Initialize the recommendation engines and real user selector on startup."""
+    global recommendation_engine, enhanced_recommendation_engine, real_user_selector
+
+    print("🚀 Starting 2-Phase Training API...")
+    print("   Training approach: Pre-trained item tower + Joint fine-tuning")
+
+    try:
+        print("Loading 2-phase trained recommendation engine...")
+        recommendation_engine = RecommendationEngine()
+        print("✅ 2-phase recommendation engine loaded successfully!")
+    except Exception as e:
+        print(f"❌ Error loading recommendation engine: {e}")
+        recommendation_engine = None
+
+    try:
+        print("Loading enhanced recommendation engine...")
+        # Try enhanced 128D engine first, fallback to regular enhanced
+        try:
+            from src.inference.enhanced_recommendation_engine_128d import Enhanced128DRecommendationEngine
+            enhanced_recommendation_engine = Enhanced128DRecommendationEngine()
+            print("✅ Using Enhanced 128D Recommendation Engine")
+        except:
+            from src.inference.enhanced_recommendation_engine import EnhancedRecommendationEngine
+            enhanced_recommendation_engine = EnhancedRecommendationEngine()
+            print("⚠️ Using fallback Enhanced Recommendation Engine")
+        print("Enhanced recommendation engine loaded successfully!")
+    except Exception as e:
+        print(f"Error loading enhanced recommendation engine: {e}")
+        enhanced_recommendation_engine = None
+
+    try:
+        print("Loading real user selector...")
+        real_user_selector = RealUserSelector()
+        print("Real user selector loaded successfully!")
+    except Exception as e:
+        print(f"Error loading real user selector: {e}")
+        real_user_selector = None
+
+    print("🎯 2-Phase API ready to serve recommendations!")
+
+
+@app.get("/")
+async def root():
+    """Root endpoint with API information."""
+    return {
+        "message": "Two-Tower Recommendation API (2-Phase Training)",
+        "version": "1.0.0-2phase",
+        "training_approach": "2-phase (pre-trained item tower + joint fine-tuning)",
+        "status": "active" if recommendation_engine is not None else "initialization_failed"
+    }
+
+
+@app.get("/health")
+async def health_check():
+    """Health check endpoint."""
+    return {
+        "status": "healthy" if recommendation_engine is not None else "unhealthy",
+        "engine_loaded": recommendation_engine is not None,
+        "training_approach": "2-phase"
+    }
+
+
+@app.get("/model-info")
+async def model_info():
+    """Get information about the loaded model."""
+    if recommendation_engine is None:
+        raise HTTPException(status_code=503, detail="Recommendation engine not available")
+
+    return {
+        "training_approach": "2-phase",
+        "description": "Pre-trained item tower followed by joint training with user tower",
+        "phases": [
+            "Phase 1: Item tower pre-training on item features only",
+            "Phase 2: Joint training of user tower + fine-tuning pre-trained item tower"
+        ],
+        "embedding_dimension": 128,
+        "item_vocab_size": len(recommendation_engine.data_processor.item_vocab) if recommendation_engine.data_processor else "unknown",
+        "artifacts_loaded": {
+            "item_tower_pretrained": "src/artifacts/item_tower_weights",
+            "item_tower_finetuned": "src/artifacts/item_tower_weights_finetuned_best",
+            "user_tower": "src/artifacts/user_tower_weights_best",
+            "rating_model": "src/artifacts/rating_model_weights_best"
+        }
+    }
+
+
+@app.get("/real-users", response_model=RealUsersResponse)
+async def get_real_users(count: int = 100, min_interactions: int = 5):
+    """Get real user profiles with genuine interaction histories."""
+
+    if real_user_selector is None:
+        raise HTTPException(status_code=503, detail="Real user selector not available")
+
+    try:
+        # Get real user profiles
+        real_users = real_user_selector.get_real_users(n=count, min_interactions=min_interactions)
+
+        # Get dataset summary
+        dataset_summary = real_user_selector.get_dataset_summary()
+
+        # Format users for response
+        formatted_users = []
+        for user in real_users:
+            formatted_users.append(RealUserProfile(**user))
+
+        return RealUsersResponse(
+            users=formatted_users,
+            total_count=len(formatted_users),
+            dataset_summary=dataset_summary
+        )
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error retrieving real users: {str(e)}")
+
+
+@app.get("/real-users/{user_id}")
+async def get_real_user_details(user_id: int):
+    """Get detailed interaction breakdown for a specific real user."""
+
+    if real_user_selector is None:
+        raise HTTPException(status_code=503, detail="Real user selector not available")
+
+    try:
+        user_details = real_user_selector.get_user_interaction_details(user_id)
+
+        if "error" in user_details:
+            raise HTTPException(status_code=404, detail=user_details["error"])
+
+        return user_details
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error retrieving user details: {str(e)}")
+
+
+@app.get("/dataset-summary")
+async def get_dataset_summary():
+    """Get summary statistics of the real dataset."""
+
+    if real_user_selector is None:
+        raise HTTPException(status_code=503, detail="Real user selector not available")
+
+    try:
+        return real_user_selector.get_dataset_summary()
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error retrieving dataset summary: {str(e)}")
+
+
+@app.post("/recommendations", response_model=RecommendationsResponse)
+async def get_recommendations(request: RecommendationRequest):
+    """Get item recommendations for a user."""
+
+    if recommendation_engine is None:
+        raise HTTPException(status_code=503, detail="Recommendation engine not available")
+
+    try:
+        user_profile = request.user_profile
+
+        # Generate recommendations based on type
+        if request.recommendation_type == "collaborative":
+            recommendations = recommendation_engine.recommend_items_collaborative(
+                age=user_profile.age,
+                gender=user_profile.gender,
+                income=user_profile.income,
+                interaction_history=user_profile.interaction_history,
+                k=request.num_recommendations
+            )
+
+        elif request.recommendation_type == "content":
+            if not user_profile.interaction_history:
+                raise HTTPException(
+                    status_code=400,
+                    detail="Content-based recommendations require interaction history"
+                )
+
+            # Use most recent interaction as seed
+            seed_item = user_profile.interaction_history[-1]
+            recommendations = recommendation_engine.recommend_items_content_based(
+                seed_item_id=seed_item,
+                k=request.num_recommendations
+            )
+
+        elif request.recommendation_type == "hybrid":
+            recommendations = recommendation_engine.recommend_items_hybrid(
+                age=user_profile.age,
+                gender=user_profile.gender,
+                income=user_profile.income,
+                interaction_history=user_profile.interaction_history,
+                k=request.num_recommendations,
+                collaborative_weight=request.collaborative_weight
+            )
+
+        elif request.recommendation_type == "enhanced":
+            if enhanced_recommendation_engine is None:
+                raise HTTPException(status_code=503, detail="Enhanced recommendation engine not available")
+
+            # Check if it's the 128D engine or fallback
+            if hasattr(enhanced_recommendation_engine, 'recommend_items_enhanced'):
+                # 128D Enhanced engine
+                recommendations = enhanced_recommendation_engine.recommend_items_enhanced(
+                    age=user_profile.age,
+                    gender=user_profile.gender,
+                    income=user_profile.income,
+                    interaction_history=user_profile.interaction_history,
+                    k=request.num_recommendations,
+                    diversity_weight=0.3 if request.enable_diversity else 0.0,
+                    category_boost=request.category_boost if request.enable_category_boost else 1.0
+                )
+            else:
+                # Fallback enhanced engine
+                recommendations = enhanced_recommendation_engine.recommend_items_enhanced_hybrid(
+                    age=user_profile.age,
+                    gender=user_profile.gender,
+                    income=user_profile.income,
+                    interaction_history=user_profile.interaction_history,
+                    k=request.num_recommendations,
+                    collaborative_weight=request.collaborative_weight,
+                    category_boost=request.category_boost,
+                    enable_category_boost=request.enable_category_boost,
+                    enable_diversity=request.enable_diversity
+                )
+
+        elif request.recommendation_type == "enhanced_128d":
+            if enhanced_recommendation_engine is None or not hasattr(enhanced_recommendation_engine, 'recommend_items_enhanced'):
+                raise HTTPException(status_code=503, detail="Enhanced 128D recommendation engine not available")
+
+            recommendations = enhanced_recommendation_engine.recommend_items_enhanced(
+                age=user_profile.age,
+                gender=user_profile.gender,
+                income=user_profile.income,
+                interaction_history=user_profile.interaction_history,
+                k=request.num_recommendations,
+                diversity_weight=0.3 if request.enable_diversity else 0.0,
+                category_boost=request.category_boost if request.enable_category_boost else 1.0
+            )
+
+        elif request.recommendation_type == "category_focused":
+            if enhanced_recommendation_engine is None:
+                raise HTTPException(status_code=503, detail="Enhanced recommendation engine not available")
+
+            recommendations = enhanced_recommendation_engine.recommend_items_category_focused(
+                age=user_profile.age,
+                gender=user_profile.gender,
+                income=user_profile.income,
+                interaction_history=user_profile.interaction_history,
+                k=request.num_recommendations,
+                focus_percentage=0.8
+            )
+
+        else:
+            raise HTTPException(
+                status_code=400,
+                detail="Invalid recommendation_type. Must be 'collaborative', 'content', 'hybrid', 'enhanced', 'enhanced_128d', or 'category_focused'"
+            )
+
+        # Format response
+        formatted_recommendations = []
+        for item_id, score, item_info in recommendations:
+            formatted_recommendations.append(
+                RecommendationResponse(
+                    item_id=item_id,
+                    score=score,
+                    item_info=ItemInfo(**item_info)
+                )
+            )
+
+        return RecommendationsResponse(
+            recommendations=formatted_recommendations,
+            user_profile=user_profile,
+            recommendation_type=request.recommendation_type,
+            total_count=len(formatted_recommendations),
+            training_approach="2-phase"
+        )
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error generating recommendations: {str(e)}")
+
+
+@app.post("/item-similarity", response_model=List[RecommendationResponse])
+async def get_similar_items(request: ItemSimilarityRequest):
+    """Get items similar to a given item."""
+
+    if recommendation_engine is None:
+        raise HTTPException(status_code=503, detail="Recommendation engine not available")
+
+    try:
+        recommendations = recommendation_engine.recommend_items_content_based(
+            seed_item_id=request.item_id,
+            k=request.num_recommendations
+        )
+
+        formatted_recommendations = []
+        for item_id, score, item_info in recommendations:
+            formatted_recommendations.append(
+                RecommendationResponse(
+                    item_id=item_id,
+                    score=score,
+                    item_info=ItemInfo(**item_info)
+                )
+            )
+
+        return formatted_recommendations
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error finding similar items: {str(e)}")
+
+
+@app.post("/predict-rating", response_model=RatingPredictionResponse)
+async def predict_user_item_rating(request: RatingPredictionRequest):
+    """Predict rating for a user-item pair."""
+
+    if recommendation_engine is None:
+        raise HTTPException(status_code=503, detail="Recommendation engine not available")
+
+    try:
+        user_profile = request.user_profile
+
+        predicted_rating = recommendation_engine.predict_rating(
+            age=user_profile.age,
+            gender=user_profile.gender,
+            income=user_profile.income,
+            interaction_history=user_profile.interaction_history,
+            item_id=request.item_id
+        )
+
+        item_info = recommendation_engine._get_item_info(request.item_id)
+
+        return RatingPredictionResponse(
+            user_profile=user_profile,
+            item_id=request.item_id,
+            predicted_rating=predicted_rating,
+            item_info=ItemInfo(**item_info)
+        )
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error predicting rating: {str(e)}")
+
+
+@app.get("/items/{item_id}", response_model=ItemInfo)
+async def get_item_info(item_id: int):
+    """Get information about a specific item."""
|
| 471 |
+
|
| 472 |
+
if recommendation_engine is None:
|
| 473 |
+
raise HTTPException(status_code=503, detail="Recommendation engine not available")
|
| 474 |
+
|
| 475 |
+
try:
|
| 476 |
+
item_info = recommendation_engine._get_item_info(item_id)
|
| 477 |
+
return ItemInfo(**item_info)
|
| 478 |
+
|
| 479 |
+
except Exception as e:
|
| 480 |
+
raise HTTPException(status_code=500, detail=f"Error retrieving item info: {str(e)}")
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
@app.get("/items")
|
| 484 |
+
async def get_sample_items(limit: int = 20):
|
| 485 |
+
"""Get a sample of items for testing."""
|
| 486 |
+
|
| 487 |
+
if recommendation_engine is None:
|
| 488 |
+
raise HTTPException(status_code=503, detail="Recommendation engine not available")
|
| 489 |
+
|
| 490 |
+
try:
|
| 491 |
+
# Get sample items from the dataframe
|
| 492 |
+
sample_items = recommendation_engine.items_df.sample(n=min(limit, len(recommendation_engine.items_df)))
|
| 493 |
+
|
| 494 |
+
items = []
|
| 495 |
+
for _, row in sample_items.iterrows():
|
| 496 |
+
items.append({
|
| 497 |
+
"product_id": int(row['product_id']),
|
| 498 |
+
"category_id": int(row['category_id']),
|
| 499 |
+
"category_code": str(row['category_code']),
|
| 500 |
+
"brand": str(row['brand']) if pd.notna(row['brand']) else 'Unknown',
|
| 501 |
+
"price": float(row['price'])
|
| 502 |
+
})
|
| 503 |
+
|
| 504 |
+
return {"items": items, "total": len(items), "training_approach": "2-phase"}
|
| 505 |
+
|
| 506 |
+
except Exception as e:
|
| 507 |
+
raise HTTPException(status_code=500, detail=f"Error retrieving sample items: {str(e)}")
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
if __name__ == "__main__":
|
| 511 |
+
print("🚀 Starting 2-Phase Training Recommendation API...")
|
| 512 |
+
print("📊 Training approach: Pre-trained item tower + Joint fine-tuning")
|
| 513 |
+
print("🌐 Server will be available at: http://localhost:8000")
|
| 514 |
+
print("📚 API docs at: http://localhost:8000/docs")
|
| 515 |
+
|
| 516 |
+
uvicorn.run(
|
| 517 |
+
"api_2phase:app",
|
| 518 |
+
host="0.0.0.0",
|
| 519 |
+
port=8000,
|
| 520 |
+
reload=True
|
| 521 |
+
)
|
|
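The `/recommendations` endpoint above expects a JSON body matching the `RecommendationRequest` and `UserProfile` models. As a sketch, the payload can be built like this; the user values and product IDs here are hypothetical, chosen only to illustrate the field shapes:

```python
import json

# Hypothetical user values; field names mirror the RecommendationRequest /
# UserProfile Pydantic models defined in this API.
payload = {
    "user_profile": {
        "age": 34,
        "gender": "female",
        "income": 72000.0,
        "interaction_history": [1004856, 1005115],  # illustrative product IDs
    },
    "num_recommendations": 5,
    "recommendation_type": "category_focused",
}

body = json.dumps(payload)
print(body)
```

With the server running locally, this body could then be POSTed to `http://localhost:8000/recommendations` (for example with `curl -X POST -H 'Content-Type: application/json' -d "$body"`).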
api_joint.py  @@ -0,0 +1,522 @@
#!/usr/bin/env python3
"""
API for Single Joint Trained Recommendation System

This API serves recommendations from a model trained using the single joint approach:
- Both user and item towers trained simultaneously from scratch
- End-to-end optimization without pre-training phases

Usage:
    python api_joint.py

Then access: http://localhost:8000
"""

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
import uvicorn
import os
import sys
import pandas as pd

# Add src to path for imports and set working directory
parent_dir = os.path.dirname(__file__)
sys.path.append(parent_dir)
os.chdir(parent_dir)  # Change to project root directory

from src.inference.recommendation_engine import RecommendationEngine
from src.utils.real_user_selector import RealUserSelector

# Initialize FastAPI app
app = FastAPI(
    title="Two-Tower Recommendation API (Single Joint Training)",
    description="API for serving recommendations using a two-tower architecture trained with single joint approach",
    version="1.0.0-joint"
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure appropriately for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global instances
recommendation_engine = None
enhanced_recommendation_engine = None
real_user_selector = None


# Pydantic models for request/response
class UserProfile(BaseModel):
    age: int
    gender: str  # "male" or "female"
    income: float
    interaction_history: Optional[List[int]] = []


class RecommendationRequest(BaseModel):
    user_profile: UserProfile
    num_recommendations: int = 10
    recommendation_type: str = "hybrid"  # "collaborative", "content", "hybrid", "enhanced", "enhanced_128d", "category_focused"
    collaborative_weight: Optional[float] = 0.7
    category_boost: Optional[float] = 1.5  # For enhanced recommendations
    enable_category_boost: Optional[bool] = True
    enable_diversity: Optional[bool] = True


class ItemSimilarityRequest(BaseModel):
    item_id: int
    num_recommendations: int = 10


class RatingPredictionRequest(BaseModel):
    user_profile: UserProfile
    item_id: int


class ItemInfo(BaseModel):
    product_id: int
    category_id: int
    category_code: str
    brand: str
    price: float


class RecommendationResponse(BaseModel):
    item_id: int
    score: float
    item_info: ItemInfo


class RecommendationsResponse(BaseModel):
    recommendations: List[RecommendationResponse]
    user_profile: UserProfile
    recommendation_type: str
    total_count: int
    training_approach: str = "single-joint"


class RatingPredictionResponse(BaseModel):
    user_profile: UserProfile
    item_id: int
    predicted_rating: float
    item_info: ItemInfo


class RealUserProfile(BaseModel):
    user_id: int
    age: int
    gender: str
    income: int
    interaction_history: List[int]
    interaction_stats: Dict[str, int]
    interaction_pattern: str
    summary: str


class RealUsersResponse(BaseModel):
    users: List[RealUserProfile]
    total_count: int
    dataset_summary: Dict[str, Any]


@app.on_event("startup")
async def startup_event():
    """Initialize the recommendation engines and real user selector on startup."""
    global recommendation_engine, enhanced_recommendation_engine, real_user_selector

    print("🚀 Starting Single Joint Training API...")
    print("   Training approach: End-to-end joint optimization from scratch")

    try:
        print("Loading single joint trained recommendation engine...")
        recommendation_engine = RecommendationEngine()
        print("✅ Single joint recommendation engine loaded successfully!")
    except Exception as e:
        print(f"❌ Error loading recommendation engine: {e}")
        recommendation_engine = None

    try:
        print("Loading enhanced recommendation engine...")
        # Try enhanced 128D engine first, fall back to regular enhanced
        try:
            from src.inference.enhanced_recommendation_engine_128d import Enhanced128DRecommendationEngine
            enhanced_recommendation_engine = Enhanced128DRecommendationEngine()
            print("✅ Using Enhanced 128D Recommendation Engine")
        except Exception:  # was a bare `except:`; this at least lets KeyboardInterrupt through
            from src.inference.enhanced_recommendation_engine import EnhancedRecommendationEngine
            enhanced_recommendation_engine = EnhancedRecommendationEngine()
            print("⚠️ Using fallback Enhanced Recommendation Engine")
        print("Enhanced recommendation engine loaded successfully!")
    except Exception as e:
        print(f"Error loading enhanced recommendation engine: {e}")
        enhanced_recommendation_engine = None

    try:
        print("Loading real user selector...")
        real_user_selector = RealUserSelector()
        print("Real user selector loaded successfully!")
    except Exception as e:
        print(f"Error loading real user selector: {e}")
        real_user_selector = None

    print("🎯 Single Joint API ready to serve recommendations!")


@app.get("/")
async def root():
    """Root endpoint with API information."""
    return {
        "message": "Two-Tower Recommendation API (Single Joint Training)",
        "version": "1.0.0-joint",
        "training_approach": "single-joint (end-to-end optimization from scratch)",
        "status": "active" if recommendation_engine is not None else "initialization_failed"
    }


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {
        "status": "healthy" if recommendation_engine is not None else "unhealthy",
        "engine_loaded": recommendation_engine is not None,
        "training_approach": "single-joint"
    }


@app.get("/model-info")
async def model_info():
    """Get information about the loaded model."""
    if recommendation_engine is None:
        raise HTTPException(status_code=503, detail="Recommendation engine not available")

    return {
        "training_approach": "single-joint",
        "description": "User and item towers trained simultaneously from scratch",
        "advantages": [
            "End-to-end optimization for better task alignment",
            "No pre-training phase required",
            "Faster overall training pipeline",
            "Direct optimization for recommendation objectives"
        ],
        "embedding_dimension": 128,
        "item_vocab_size": len(recommendation_engine.data_processor.item_vocab) if recommendation_engine.data_processor else "unknown",
        "artifacts_loaded": {
            "user_tower": "src/artifacts/user_tower_weights_best",
            "item_tower_joint": "src/artifacts/item_tower_weights_finetuned_best",
            "rating_model": "src/artifacts/rating_model_weights_best"
        }
    }


@app.get("/real-users", response_model=RealUsersResponse)
async def get_real_users(count: int = 100, min_interactions: int = 5):
    """Get real user profiles with genuine interaction histories."""

    if real_user_selector is None:
        raise HTTPException(status_code=503, detail="Real user selector not available")

    try:
        # Get real user profiles
        real_users = real_user_selector.get_real_users(n=count, min_interactions=min_interactions)

        # Get dataset summary
        dataset_summary = real_user_selector.get_dataset_summary()

        # Format users for response
        formatted_users = []
        for user in real_users:
            formatted_users.append(RealUserProfile(**user))

        return RealUsersResponse(
            users=formatted_users,
            total_count=len(formatted_users),
            dataset_summary=dataset_summary
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error retrieving real users: {str(e)}")


@app.get("/real-users/{user_id}")
async def get_real_user_details(user_id: int):
    """Get detailed interaction breakdown for a specific real user."""

    if real_user_selector is None:
        raise HTTPException(status_code=503, detail="Real user selector not available")

    try:
        user_details = real_user_selector.get_user_interaction_details(user_id)

        if "error" in user_details:
            raise HTTPException(status_code=404, detail=user_details["error"])

        return user_details

    except HTTPException:
        raise  # let the 404 above pass through instead of being re-wrapped as a 500
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error retrieving user details: {str(e)}")


@app.get("/dataset-summary")
async def get_dataset_summary():
    """Get summary statistics of the real dataset."""

    if real_user_selector is None:
        raise HTTPException(status_code=503, detail="Real user selector not available")

    try:
        return real_user_selector.get_dataset_summary()

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error retrieving dataset summary: {str(e)}")


@app.post("/recommendations", response_model=RecommendationsResponse)
async def get_recommendations(request: RecommendationRequest):
    """Get item recommendations for a user."""

    if recommendation_engine is None:
        raise HTTPException(status_code=503, detail="Recommendation engine not available")

    try:
        user_profile = request.user_profile

        # Generate recommendations based on type
        if request.recommendation_type == "collaborative":
            recommendations = recommendation_engine.recommend_items_collaborative(
                age=user_profile.age,
                gender=user_profile.gender,
                income=user_profile.income,
                interaction_history=user_profile.interaction_history,
                k=request.num_recommendations
            )

        elif request.recommendation_type == "content":
            if not user_profile.interaction_history:
                raise HTTPException(
                    status_code=400,
                    detail="Content-based recommendations require interaction history"
                )

            # Use most recent interaction as seed
            seed_item = user_profile.interaction_history[-1]
            recommendations = recommendation_engine.recommend_items_content_based(
                seed_item_id=seed_item,
                k=request.num_recommendations
            )

        elif request.recommendation_type == "hybrid":
            recommendations = recommendation_engine.recommend_items_hybrid(
                age=user_profile.age,
                gender=user_profile.gender,
                income=user_profile.income,
                interaction_history=user_profile.interaction_history,
                k=request.num_recommendations,
                collaborative_weight=request.collaborative_weight
            )

        elif request.recommendation_type == "enhanced":
            if enhanced_recommendation_engine is None:
                raise HTTPException(status_code=503, detail="Enhanced recommendation engine not available")

            # Check if it's the 128D engine or fallback
            if hasattr(enhanced_recommendation_engine, 'recommend_items_enhanced'):
                # 128D Enhanced engine
                recommendations = enhanced_recommendation_engine.recommend_items_enhanced(
                    age=user_profile.age,
                    gender=user_profile.gender,
                    income=user_profile.income,
                    interaction_history=user_profile.interaction_history,
                    k=request.num_recommendations,
                    diversity_weight=0.3 if request.enable_diversity else 0.0,
                    category_boost=request.category_boost if request.enable_category_boost else 1.0
                )
            else:
                # Fallback enhanced engine
                recommendations = enhanced_recommendation_engine.recommend_items_enhanced_hybrid(
                    age=user_profile.age,
                    gender=user_profile.gender,
                    income=user_profile.income,
                    interaction_history=user_profile.interaction_history,
                    k=request.num_recommendations,
                    collaborative_weight=request.collaborative_weight,
                    category_boost=request.category_boost,
                    enable_category_boost=request.enable_category_boost,
                    enable_diversity=request.enable_diversity
                )

        elif request.recommendation_type == "enhanced_128d":
            if enhanced_recommendation_engine is None or not hasattr(enhanced_recommendation_engine, 'recommend_items_enhanced'):
                raise HTTPException(status_code=503, detail="Enhanced 128D recommendation engine not available")

            recommendations = enhanced_recommendation_engine.recommend_items_enhanced(
                age=user_profile.age,
                gender=user_profile.gender,
                income=user_profile.income,
                interaction_history=user_profile.interaction_history,
                k=request.num_recommendations,
                diversity_weight=0.3 if request.enable_diversity else 0.0,
                category_boost=request.category_boost if request.enable_category_boost else 1.0
            )

        elif request.recommendation_type == "category_focused":
            if enhanced_recommendation_engine is None:
                raise HTTPException(status_code=503, detail="Enhanced recommendation engine not available")

            recommendations = enhanced_recommendation_engine.recommend_items_category_focused(
                age=user_profile.age,
                gender=user_profile.gender,
                income=user_profile.income,
                interaction_history=user_profile.interaction_history,
                k=request.num_recommendations,
                focus_percentage=0.8
            )

        else:
            raise HTTPException(
                status_code=400,
                detail="Invalid recommendation_type. Must be 'collaborative', 'content', 'hybrid', 'enhanced', 'enhanced_128d', or 'category_focused'"
            )

        # Format response
        formatted_recommendations = []
        for item_id, score, item_info in recommendations:
            formatted_recommendations.append(
                RecommendationResponse(
                    item_id=item_id,
                    score=score,
                    item_info=ItemInfo(**item_info)
                )
            )

        return RecommendationsResponse(
            recommendations=formatted_recommendations,
            user_profile=user_profile,
            recommendation_type=request.recommendation_type,
            total_count=len(formatted_recommendations),
            training_approach="single-joint"
        )

    except HTTPException:
        raise  # let the 400/503 errors raised above pass through, not be re-wrapped as 500
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating recommendations: {str(e)}")


@app.post("/item-similarity", response_model=List[RecommendationResponse])
async def get_similar_items(request: ItemSimilarityRequest):
    """Get items similar to a given item."""

    if recommendation_engine is None:
        raise HTTPException(status_code=503, detail="Recommendation engine not available")

    try:
        recommendations = recommendation_engine.recommend_items_content_based(
            seed_item_id=request.item_id,
            k=request.num_recommendations
        )

        formatted_recommendations = []
        for item_id, score, item_info in recommendations:
            formatted_recommendations.append(
                RecommendationResponse(
                    item_id=item_id,
                    score=score,
                    item_info=ItemInfo(**item_info)
                )
            )

        return formatted_recommendations

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error finding similar items: {str(e)}")


@app.post("/predict-rating", response_model=RatingPredictionResponse)
async def predict_user_item_rating(request: RatingPredictionRequest):
    """Predict rating for a user-item pair."""

    if recommendation_engine is None:
        raise HTTPException(status_code=503, detail="Recommendation engine not available")

    try:
        user_profile = request.user_profile

        predicted_rating = recommendation_engine.predict_rating(
            age=user_profile.age,
            gender=user_profile.gender,
            income=user_profile.income,
            interaction_history=user_profile.interaction_history,
            item_id=request.item_id
        )

        item_info = recommendation_engine._get_item_info(request.item_id)

        return RatingPredictionResponse(
            user_profile=user_profile,
            item_id=request.item_id,
            predicted_rating=predicted_rating,
            item_info=ItemInfo(**item_info)
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error predicting rating: {str(e)}")


@app.get("/items/{item_id}", response_model=ItemInfo)
async def get_item_info(item_id: int):
    """Get information about a specific item."""

    if recommendation_engine is None:
        raise HTTPException(status_code=503, detail="Recommendation engine not available")

    try:
        item_info = recommendation_engine._get_item_info(item_id)
        return ItemInfo(**item_info)

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error retrieving item info: {str(e)}")


@app.get("/items")
async def get_sample_items(limit: int = 20):
    """Get a sample of items for testing."""

    if recommendation_engine is None:
        raise HTTPException(status_code=503, detail="Recommendation engine not available")

    try:
        # Get sample items from the dataframe
        sample_items = recommendation_engine.items_df.sample(n=min(limit, len(recommendation_engine.items_df)))

        items = []
        for _, row in sample_items.iterrows():
            items.append({
                "product_id": int(row['product_id']),
                "category_id": int(row['category_id']),
                "category_code": str(row['category_code']),
                "brand": str(row['brand']) if pd.notna(row['brand']) else 'Unknown',
                "price": float(row['price'])
            })

        return {"items": items, "total": len(items), "training_approach": "single-joint"}

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error retrieving sample items: {str(e)}")


if __name__ == "__main__":
    print("🚀 Starting Single Joint Training Recommendation API...")
    print("⚡ Training approach: End-to-end joint optimization from scratch")
    print("🌐 Server will be available at: http://localhost:8000")
    print("📚 API docs at: http://localhost:8000/docs")

    uvicorn.run(
        "api_joint:app",
        host="0.0.0.0",
        port=8000,
        reload=True
    )
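Both API variants use the same startup fallback: try to import and construct the 128D engine, and fall back to the base enhanced engine if that fails. The pattern can be sketched in isolation (the names below are stand-ins, not the real imports); catching `ImportError` specifically, rather than a bare `except:`, keeps genuine initialization bugs from being silently swallowed:

```python
def load_engine(preferred_available: bool = False) -> str:
    """Mimic the startup fallback: try the preferred engine, fall back on ImportError."""
    try:
        if not preferred_available:
            # Stands in for `from ... import Enhanced128DRecommendationEngine`
            # failing when the 128D module/artifacts are absent.
            raise ImportError("128D engine not available")
        return "Enhanced128DRecommendationEngine"  # preferred path
    except ImportError:
        return "EnhancedRecommendationEngine"  # fallback path

print(load_engine())        # falls back
print(load_engine(True))    # preferred engine loads
```

The request handlers then branch on `hasattr(engine, 'recommend_items_enhanced')` to detect which engine actually loaded, so the fallback stays transparent to callers.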
|
frontend/src/App.css  @@ -1,33 +1,1048 @@
.App {
  text-align: center;
}

/* Only fragments of the 1,048-line App.css rewrite survive in this render
   (closing braces and a stray `padding: 20px;`); the full stylesheet was not
   captured. */
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
color: white;
|
| 20 |
}
|
| 21 |
|
| 22 |
-
.
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
}
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
}
|
| 30 |
-
|
| 31 |
-
|
|
|
|
| 32 |
}
|
| 33 |
}
|
|
|
|
+.App {
+  text-align: center;
+  font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', sans-serif;
+}
+
+.container {
+  max-width: 1200px;
+  margin: 0 auto;
+  padding: 20px;
+  text-align: left;
+}
+
+.header {
+  text-align: center;
+  margin-bottom: 30px;
+  padding: 20px;
+  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+  color: white;
+  border-radius: 10px;
+}
+
+.header h1 {
+  margin: 0 0 10px 0;
+  font-size: 2.5rem;
+}
+
+.header p {
+  margin: 0;
+  opacity: 0.9;
+  font-size: 1.1rem;
+}
+
+.dataset-info {
+  margin-top: 15px;
+  padding: 10px;
+  background: rgba(255, 255, 255, 0.2);
+  border-radius: 5px;
+  font-size: 0.95rem;
+}
+
+/* Real User Selector */
+.real-user-selector {
+  background: #e8f5e8;
+  padding: 25px;
+  border-radius: 10px;
+  margin-bottom: 30px;
+  border: 1px solid #c3e6c3;
+}
+
+.real-user-selector h2 {
+  margin-top: 0;
+  color: #2d5a2d;
+  border-bottom: 2px solid #28a745;
+  padding-bottom: 10px;
+}
+
+.user-selector-controls {
+  display: flex;
+  align-items: center;
+  gap: 15px;
+  margin-bottom: 20px;
+}
+
+.user-selector-controls label {
+  font-weight: bold;
+  min-width: 200px;
+}
+
+.user-selector-controls select {
+  flex: 1;
+  padding: 8px 12px;
+  border: 1px solid #28a745;
+  border-radius: 5px;
+  font-size: 14px;
+}
+
+.selected-real-user {
+  background: rgba(255, 255, 255, 0.8);
+  padding: 20px;
+  border-radius: 8px;
+  border: 1px solid #28a745;
+}
+
+.real-user-stats {
+  display: grid;
+  gap: 12px;
+}
+
+.user-stat {
+  display: flex;
+  justify-content: space-between;
+  padding: 8px 0;
+  border-bottom: 1px solid #ddd;
+}
+
+.stat-label {
+  font-weight: bold;
+  color: #495057;
+}
+
+.stat-value {
+  color: #28a745;
+  font-weight: 600;
+}
+
+.real-interaction-summary {
+  display: flex;
+  gap: 20px;
+  justify-content: center;
+  margin: 20px 0;
+}
+
+.real-history-info {
+  background: rgba(255, 255, 255, 0.8);
+  padding: 15px;
+  border-radius: 8px;
+  margin-top: 15px;
+}
+
+.real-history-info p {
+  margin: 8px 0;
+}
+
+/* Expand Interactions Button */
+.expand-interactions-btn {
+  background: #007bff;
+  color: white;
+  border: 1px solid #007bff;
+  border-radius: 5px;
+  padding: 10px 20px;
+  margin-top: 15px;
+  cursor: pointer;
+  font-size: 14px;
+  transition: background-color 0.3s;
+}
+
+.expand-interactions-btn:hover:not(:disabled) {
+  background: #0056b3;
+}
+
+.expand-interactions-btn:disabled {
+  opacity: 0.6;
+  cursor: not-allowed;
+}
+
+/* User Interactions Timeline */
+.user-interactions-timeline {
+  margin-top: 20px;
+  padding: 20px;
+  background: rgba(255, 255, 255, 0.95);
+  border-radius: 8px;
+  border: 1px solid #007bff;
+}
+
+.user-interactions-timeline h4 {
+  margin-top: 0;
+  color: #007bff;
+  border-bottom: 2px solid #007bff;
+  padding-bottom: 8px;
+}
+
+.user-interactions-timeline h5 {
+  color: #495057;
+  margin: 15px 0 10px 0;
+}
+
+.timeline-stats {
+  display: flex;
+  flex-wrap: wrap;
+  gap: 20px;
+  margin: 15px 0;
+  padding: 15px;
+  background: #f8f9fa;
+  border-radius: 5px;
+  font-size: 14px;
+}
+
+.timeline-stats span {
+  color: #495057;
+}
+
+/* Interactions List */
+.interactions-list {
+  max-height: 400px;
+  overflow-y: auto;
+  border: 1px solid #e9ecef;
+  border-radius: 5px;
+  background: white;
+}
+
+.interaction-timeline-item {
+  display: flex;
+  align-items: center;
+  padding: 12px 15px;
+  border-bottom: 1px solid #e9ecef;
+  transition: background-color 0.2s;
+}
+
+.interaction-timeline-item:hover {
+  background: #f8f9fa;
+}
+
+.interaction-timeline-item:last-child {
+  border-bottom: none;
+}
+
+.interaction-timeline-time {
+  flex: 0 0 180px;
+  font-size: 12px;
+  color: #6c757d;
+  font-family: monospace;
+}
+
+.interaction-timeline-content {
+  flex: 1;
+  display: flex;
+  flex-direction: column;
+  gap: 5px;
+}
+
+.interaction-main-info {
+  display: flex;
+  align-items: center;
+  gap: 10px;
+}
+
+.interaction-item-details {
+  display: flex;
+  align-items: center;
+  gap: 15px;
+  padding-left: 20px;
+  font-size: 13px;
+}
+
+.interaction-item-id {
+  color: #495057;
+  font-size: 13px;
+  font-weight: 500;
+}
+
+.interaction-icon {
+  font-size: 16px;
+  min-width: 20px;
+}
+
+.item-brand {
+  color: #007bff;
+  font-size: 14px;
+  min-width: 120px;
+}
+
+.item-category {
+  color: #6c757d;
+  font-size: 12px;
+  background: #f8f9fa;
+  padding: 2px 6px;
+  border-radius: 3px;
+  border: 1px solid #e9ecef;
+  min-width: 150px;
+  text-align: center;
+}
+
+.item-price {
+  color: #28a745;
+  font-weight: 600;
+  font-size: 13px;
+  min-width: 70px;
+  text-align: right;
+}
+
+/* Interaction type badges in timeline */
+.interaction-timeline-content .interaction-type {
+  font-size: 11px;
+  font-weight: bold;
+  padding: 3px 8px;
+  border-radius: 12px;
+  text-transform: uppercase;
+  letter-spacing: 0.5px;
+}
+
+.interaction-timeline-content .interaction-type.view {
+  background: #d1ecf1;
+  color: #0c5460;
+}
+
+.interaction-timeline-content .interaction-type.cart {
+  background: #fff3cd;
+  color: #856404;
+}
+
+.interaction-timeline-content .interaction-type.purchase {
+  background: #d4edda;
+  color: #155724;
+}
+
+/* Category Analysis Columns */
+.category-analysis {
+  margin-top: 20px;
+  padding: 20px;
+  background: rgba(255, 255, 255, 0.95);
+  border-radius: 8px;
+  border: 1px solid #6c757d;
+}
+
+.category-analysis h4 {
+  margin-top: 0;
+  color: #495057;
+  border-bottom: 2px solid #6c757d;
+  padding-bottom: 8px;
+}
+
+.category-columns {
+  display: grid;
+  grid-template-columns: 1fr 1fr;
+  gap: 30px;
+  margin: 20px 0;
+}
+
+.category-column {
+  background: #f8f9fa;
+  padding: 15px;
+  border-radius: 8px;
+  border: 1px solid #dee2e6;
+}
+
+.category-column h5 {
+  margin-top: 0;
+  margin-bottom: 15px;
+  color: #495057;
+  font-size: 16px;
+  padding-bottom: 5px;
+  border-bottom: 1px solid #dee2e6;
+}
+
+.category-percentages {
+  display: flex;
+  flex-direction: column;
+  gap: 8px;
+}
+
+.category-item {
+  display: flex;
+  align-items: center;
+  gap: 10px;
+  padding: 6px 0;
+}
+
+.category-bar-container {
+  flex: 0 0 80px;
+  height: 16px;
+  background: #e9ecef;
+  border-radius: 8px;
+  overflow: hidden;
+  position: relative;
+}
+
+.category-bar {
+  height: 100%;
+  border-radius: 8px;
+  transition: width 0.3s ease;
+}
+
+.category-bar.user-category {
+  background: linear-gradient(90deg, #007bff, #0056b3);
+}
+
+.category-bar.rec-category-matched {
+  background: linear-gradient(90deg, #28a745, #1e7e34);
+}
+
+.category-bar.rec-category-new {
+  background: linear-gradient(90deg, #ffc107, #e0a800);
 }
 
+.category-label {
+  flex: 1;
+  font-size: 12px;
+  color: #495057;
+  text-overflow: ellipsis;
+  overflow: hidden;
+  white-space: nowrap;
+}
+
+.category-percent {
+  flex: 0 0 40px;
+  font-size: 11px;
+  font-weight: bold;
+  text-align: right;
+  color: #495057;
+}
+
+.match-indicator {
+  color: #28a745;
+  font-weight: bold;
+  font-size: 14px;
+}
+
+.category-match-summary {
+  margin-top: 15px;
+  padding: 15px;
+  background: #e7f3ff;
+  border-radius: 6px;
+  border-left: 4px solid #007bff;
+}
+
+.category-match-summary p {
+  margin: 0;
+  font-size: 14px;
+  color: #495057;
+}
+
+.match-legend {
+  display: flex;
+  gap: 20px;
+  margin-top: 8px;
+  font-size: 12px;
+}
+
+.legend-item {
+  display: flex;
+  align-items: center;
+  gap: 6px;
+}
+
+.legend-dot {
+  width: 12px;
+  height: 12px;
+  border-radius: 50%;
+}
+
+.legend-dot.matched {
+  background: #28a745;
+}
+
+.legend-dot.new {
+  background: #ffc107;
+}
+
+/* Responsive design for smaller screens */
+@media (max-width: 768px) {
+  .category-columns {
+    grid-template-columns: 1fr;
+    gap: 20px;
   }
+
+  .category-bar-container {
+    flex: 0 0 60px;
+  }
+
+  .match-legend {
+    flex-direction: column;
+    gap: 8px;
+  }
+}
+
+/* Pagination Styles */
+.pagination-info {
+  margin: 20px 0;
+  text-align: center;
 }
 
+.pagination-info p {
+  margin: 0 0 15px 0;
+  color: #6c757d;
+  font-size: 14px;
+}
+
+.pagination-controls {
+  display: flex;
+  align-items: center;
+  justify-content: center;
+  gap: 10px;
+  margin: 15px 0;
+}
+
+.pagination-btn {
+  background: #007bff;
+  color: white;
+  border: 1px solid #007bff;
+  border-radius: 5px;
+  padding: 8px 16px;
+  cursor: pointer;
+  font-size: 14px;
+  transition: background-color 0.3s;
+}
+
+.pagination-btn:hover:not(:disabled) {
+  background: #0056b3;
+  border-color: #0056b3;
+}
+
+.pagination-btn:disabled {
+  background: #6c757d;
+  border-color: #6c757d;
+  cursor: not-allowed;
+  opacity: 0.65;
+}
+
+.page-numbers {
+  display: flex;
+  align-items: center;
+  gap: 5px;
+}
+
+.page-number {
+  background: #f8f9fa;
+  color: #007bff;
+  border: 1px solid #dee2e6;
+  border-radius: 3px;
+  padding: 6px 10px;
+  cursor: pointer;
+  font-size: 14px;
+  transition: all 0.3s;
+  min-width: 35px;
+}
+
+.page-number:hover {
+  background: #e9ecef;
+  border-color: #007bff;
+}
+
+.page-number.active {
+  background: #007bff;
+  color: white;
+  border-color: #007bff;
+}
+
+.pagination-ellipsis {
+  color: #6c757d;
+  padding: 0 5px;
+}
+
+.bottom-pagination {
+  border-top: 1px solid #dee2e6;
+  padding-top: 20px;
+  margin-top: 20px;
+}
+
+.page-indicator {
+  color: #6c757d;
+  font-size: 14px;
+  margin: 0 15px;
+}
+
+/* Responsive pagination */
+@media (max-width: 768px) {
+  .pagination-controls {
+    flex-wrap: wrap;
+    gap: 8px;
+  }
+
+  .page-numbers {
+    order: 3;
+    width: 100%;
+    justify-content: center;
+    margin-top: 10px;
+  }
+
+  .pagination-btn {
+    padding: 6px 12px;
+    font-size: 13px;
+  }
+
+  .page-number {
+    padding: 5px 8px;
+    min-width: 30px;
+    font-size: 13px;
+  }
+}
+
+/* User Profile Form */
+.user-profile-form {
+  background: #f8f9fa;
+  padding: 25px;
+  border-radius: 10px;
+  margin-bottom: 30px;
+  border: 1px solid #e9ecef;
+}
+
+.user-profile-form h2 {
+  margin-top: 0;
+  color: #495057;
+  border-bottom: 2px solid #007bff;
+  padding-bottom: 10px;
+}
+
+.form-row {
+  display: flex;
+  gap: 20px;
+  flex-wrap: wrap;
+}
+
+.form-group {
+  flex: 1;
+  min-width: 200px;
+}
+
+.form-group label {
+  display: block;
+  margin-bottom: 5px;
+  font-weight: 600;
+  color: #495057;
+}
+
+.form-group input,
+.form-group select {
+  width: 100%;
+  padding: 10px;
+  border: 1px solid #ced4da;
+  border-radius: 5px;
+  font-size: 14px;
+}
+
+.form-group input:focus,
+.form-group select:focus {
+  outline: none;
+  border-color: #007bff;
+  box-shadow: 0 0 0 0.2rem rgba(0, 123, 255, 0.25);
+}
+
+/* Interaction Patterns */
+.interaction-patterns {
+  background: #fff;
+  padding: 25px;
+  border-radius: 10px;
+  margin-bottom: 30px;
+  border: 1px solid #e9ecef;
+  box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+}
+
+.interaction-patterns h2 {
+  margin-top: 0;
+  color: #495057;
+  border-bottom: 2px solid #28a745;
+  padding-bottom: 10px;
+}
+
+.pattern-buttons {
+  display: flex;
+  gap: 15px;
+  margin: 20px 0;
+  flex-wrap: wrap;
+}
+
+.pattern-btn {
+  padding: 15px 20px;
+  border: 2px solid #007bff;
+  background: white;
+  color: #007bff;
+  border-radius: 8px;
+  cursor: pointer;
+  transition: all 0.3s ease;
+  font-size: 14px;
+  text-align: center;
+  min-width: 120px;
+}
+
+.pattern-btn:hover {
+  background: #007bff;
+  color: white;
+  transform: translateY(-2px);
+  box-shadow: 0 4px 8px rgba(0,123,255,0.3);
+}
+
+.pattern-btn.active {
+  background: #007bff;
+  color: white;
+  box-shadow: 0 4px 8px rgba(0,123,255,0.3);
+}
+
+.pattern-btn small {
+  opacity: 0.8;
+  font-size: 12px;
+}
+
+.pattern-summary {
+  display: flex;
+  gap: 20px;
+  margin: 20px 0;
+  flex-wrap: wrap;
+}
+
+.summary-card {
+  background: white;
+  border: 1px solid #e9ecef;
+  border-radius: 8px;
   padding: 20px;
+  text-align: center;
+  flex: 1;
+  min-width: 100px;
+  box-shadow: 0 2px 4px rgba(0,0,0,0.05);
+}
+
+.summary-card.views {
+  border-left: 4px solid #17a2b8;
+}
+
+.summary-card.carts {
+  border-left: 4px solid #ffc107;
+}
+
+.summary-card.purchases {
+  border-left: 4px solid #28a745;
+}
+
+.summary-number {
+  font-size: 2rem;
+  font-weight: bold;
+  margin-bottom: 5px;
+}
+
+.summary-label {
+  color: #6c757d;
+  font-size: 14px;
+}
+
+/* Interaction History */
+.interaction-history {
+  margin-top: 25px;
+}
+
+.interaction-history h3 {
+  color: #495057;
+  margin-bottom: 15px;
+}
+
+.interaction-item {
+  background: white;
+  border: 1px solid #e9ecef;
+  border-radius: 8px;
+  margin-bottom: 10px;
+  padding: 15px;
+  display: flex;
+  justify-content: space-between;
+  align-items: center;
+  transition: all 0.2s ease;
+}
+
+.interaction-item:hover {
+  box-shadow: 0 2px 8px rgba(0,0,0,0.1);
+  transform: translateY(-1px);
+}
+
+.interaction-main {
+  display: flex;
+  align-items: center;
+  gap: 15px;
+  flex: 1;
+}
+
+.interaction-type {
+  padding: 4px 12px;
+  border-radius: 20px;
+  font-size: 12px;
+  font-weight: 600;
+  text-transform: uppercase;
+  min-width: 80px;
+  text-align: center;
+}
+
+.interaction-type.view {
+  background: #cce7ff;
+  color: #0066cc;
+}
+
+.interaction-type.cart {
+  background: #fff3cd;
+  color: #856404;
+}
+
+.interaction-type.purchase {
+  background: #d4edda;
+  color: #155724;
+}
+
+.interaction-details {
+  flex: 1;
+  font-size: 14px;
+}
+
+.category-tag {
+  background: #e9ecef;
+  color: #495057;
+  padding: 2px 8px;
+  border-radius: 12px;
+  font-size: 12px;
+  font-weight: 500;
+  display: inline-block;
+  margin: 0 5px;
+}
+
+.interaction-expand {
+  padding: 6px 12px;
+  background: #f8f9fa;
+  border: 1px solid #dee2e6;
+  border-radius: 4px;
+  cursor: pointer;
+  font-size: 12px;
+  color: #495057;
+  transition: all 0.2s ease;
+}
+
+.interaction-expand:hover {
+  background: #e9ecef;
+  border-color: #adb5bd;
+}
+
+.interaction-expanded {
+  background: #f8f9fa;
+  border: 1px solid #dee2e6;
+  border-radius: 8px;
+  padding: 20px;
+  margin-top: 10px;
+}
+
+.interaction-meta {
+  display: grid;
+  grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
+  gap: 15px;
+}
+
+.interaction-meta-item {
+  display: flex;
+  justify-content: space-between;
+  padding: 8px 0;
+  border-bottom: 1px solid #e9ecef;
+}
+
+.interaction-meta-label {
+  font-weight: 600;
+  color: #495057;
+}
+
+.interaction-meta-value {
+  color: #6c757d;
+  font-family: monospace;
+}
+
+/* Recommendation Controls */
+.recommendation-controls {
+  background: #fff;
+  padding: 25px;
+  border-radius: 10px;
+  margin-bottom: 30px;
+  border: 1px solid #e9ecef;
+  box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+}
+
+.recommendation-controls h2 {
+  margin-top: 0;
+  color: #495057;
+  border-bottom: 2px solid #dc3545;
+  padding-bottom: 10px;
+}
+
+.controls-row {
+  display: flex;
+  gap: 20px;
+  align-items: end;
+  flex-wrap: wrap;
+}
+
+.btn {
+  padding: 12px 24px;
+  border: none;
+  border-radius: 6px;
+  cursor: pointer;
+  font-size: 14px;
+  font-weight: 600;
+  transition: all 0.3s ease;
+  text-decoration: none;
+  display: inline-block;
+  text-align: center;
+}
+
+.btn-primary {
+  background: #007bff;
   color: white;
 }
 
+.btn-primary:hover:not(:disabled) {
+  background: #0056b3;
+  transform: translateY(-1px);
+  box-shadow: 0 4px 8px rgba(0,123,255,0.3);
+}
+
+.btn:disabled {
+  background: #6c757d;
+  cursor: not-allowed;
+  opacity: 0.65;
+}
+
+/* Recommendations */
+.recommendations {
+  background: #fff;
+  padding: 25px;
+  border-radius: 10px;
+  border: 1px solid #e9ecef;
+  box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+}
+
+.recommendations h2 {
+  margin-top: 0;
+  color: #495057;
+  border-bottom: 2px solid #6f42c1;
+  padding-bottom: 10px;
 }
 
+.stats {
+  background: #f8f9fa;
+  padding: 15px;
+  border-radius: 6px;
+  margin-bottom: 20px;
+  font-size: 14px;
+  color: #495057;
+}
+
+.recommendations-grid {
+  display: grid;
+  grid-template-columns: repeat(auto-fill, minmax(300px, 1fr));
+  gap: 20px;
+  margin-top: 20px;
+}
+
+.recommendation-card {
+  background: white;
+  border: 1px solid #e9ecef;
+  border-radius: 8px;
+  padding: 20px;
+  transition: all 0.2s ease;
+  box-shadow: 0 2px 4px rgba(0,0,0,0.05);
+}
+
+.recommendation-card:hover {
+  transform: translateY(-2px);
+  box-shadow: 0 4px 12px rgba(0,0,0,0.1);
+  border-color: #007bff;
+}
+
+.card-header {
+  display: flex;
+  justify-content: space-between;
+  align-items: center;
+  margin-bottom: 15px;
+  padding-bottom: 10px;
+  border-bottom: 1px solid #e9ecef;
+}
+
+.item-id {
+  font-weight: 600;
+  color: #495057;
+}
+
+.score {
+  background: #007bff;
+  color: white;
+  padding: 4px 8px;
+  border-radius: 12px;
+  font-size: 12px;
+  font-weight: 600;
+}
+
+.item-details p {
+  margin: 8px 0;
+}
+
+.brand {
+  font-weight: 600;
+  color: #495057;
+  font-size: 16px;
+}
+
+.price {
+  color: #28a745;
+  font-weight: 600;
+  font-size: 18px;
+}
+
+.category {
+  color: #6c757d;
+  font-size: 14px;
+  background: #f8f9fa;
+  padding: 4px 8px;
+  border-radius: 4px;
+  display: inline-block;
+}
+
+/* Error and Loading States */
+.error {
+  background: #f8d7da;
+  color: #721c24;
+  padding: 15px;
+  border-radius: 6px;
+  border: 1px solid #f5c6cb;
+  margin-bottom: 20px;
+}
+
+.loading {
+  text-align: center;
+  padding: 40px;
+  background: #f8f9fa;
+  border-radius: 10px;
+  border: 1px solid #e9ecef;
+}
+
+.loading h3 {
+  color: #495057;
+  margin-bottom: 10px;
+}
+
+.loading p {
+  color: #6c757d;
+  margin: 0;
+}
+
+/* Responsive Design */
+@media (max-width: 768px) {
+  .container {
+    padding: 10px;
+  }
+
+  .form-row,
+  .controls-row {
+    flex-direction: column;
+  }
+
+  .pattern-buttons {
+    flex-direction: column;
+  }
+
+  .pattern-btn {
+    width: 100%;
+  }
+
+  .recommendations-grid {
+    grid-template-columns: 1fr;
+  }
+
+  .interaction-main {
+    flex-direction: column;
+    align-items: flex-start;
+    gap: 10px;
   }
+
+  .interaction-meta {
+    grid-template-columns: 1fr;
   }
 }
@@ -22,21 +22,38 @@ function App() {
   });
 
   const [recommendationType, setRecommendationType] = useState('hybrid');
-  const [numRecommendations, setNumRecommendations] = useState(
   const [collaborativeWeight, setCollaborativeWeight] = useState(0.7);
 
   const [recommendations, setRecommendations] = useState([]);
   const [loading, setLoading] = useState(false);
   const [error, setError] = useState(null);
 
   const [sampleItems, setSampleItems] = useState([]);
   const [interactions, setInteractions] = useState([]);
   const [expandedInteraction, setExpandedInteraction] = useState(null);
   const [selectedPattern, setSelectedPattern] = useState(null);
 
-  // Load sample items on component mount
   useEffect(() => {
     fetchSampleItems();
   }, []);
 
   const fetchSampleItems = async () => {
@@ -48,6 +65,30 @@ function App() {
     }
   };
 
   const handleProfileChange = (field, value) => {
     setUserProfile(prev => ({
       ...prev,
@@ -55,6 +96,41 @@ function App() {
     }));
   };
 
   const generateTimestamp = (baseTime, offsetHours) => {
     const timestamp = new Date(baseTime.getTime() - (offsetHours * 60 * 60 * 1000));
     return timestamp.toISOString().replace('T', ' ').slice(0, 19);
@@ -170,9 +246,77 @@ function App() {
     }));
   };
 
   const getRecommendations = async () => {
     setLoading(true);
     setError(null);
 
     try {
       const requestData = {
@@ -192,27 +336,154 @@ function App() {
     }
   };
 
-  const getInteractionCounts = () => {
-    const counts = { views: 0, carts: 0, purchases: 0 };
-    interactions.forEach(interaction => {
-      counts[interaction.type + 's'] = (counts[interaction.type + 's'] || 0) + 1;
-    });
-    return counts;
-  };
-
-  const counts = getInteractionCounts();
-
   return (
     <div className="App">
       <div className="container">
         <header className="header">
          <h1>Two-Tower Recommendation System Demo</h1>
-          <p>
         </header>
 
         {/* User Profile Form */}
         <div className="user-profile-form">
-          <h2>User Demographics</h2>
 
           <div className="form-row">
             <div className="form-group">
@@ -224,6 +495,8 @@ function App() {
               onChange={(e) => handleProfileChange('age', e.target.value)}
               min="18"
               max="100"
             />
           </div>
@@ -233,6 +506,8 @@ function App() {
              id="gender"
              value={userProfile.gender}
              onChange={(e) => handleProfileChange('gender', e.target.value)}
            >
              <option value="male">Male</option>
              <option value="female">Female</option>
@@ -248,6 +523,8 @@ function App() {
              onChange={(e) => handleProfileChange('income', e.target.value)}
              min="0"
              step="1000"
            />
          </div>
        </div>
@@ -255,29 +532,140 @@ function App() {
 
        {/* Interaction Patterns */}
        <div className="interaction-patterns">
-          <
 
        {interactions.length > 0 && (
          <>
@@ -305,7 +693,7 @@ function App() {
              {interaction.type}
            </span>
            <span className="interaction-details">
-              <strong>{interaction.brand}</strong> - ${interaction.price}
              {interaction.quantity && ` (x${interaction.quantity})`}
              {interaction.total_amount && ` = $${interaction.total_amount}`}
            </span>
@@ -387,6 +775,8 @@ function App() {
              onChange={(e) => setRecommendationType(e.target.value)}
            >
              <option value="hybrid">Hybrid</option>
              <option value="collaborative">Collaborative Filtering</option>
              <option value="content">Content-Based</option>
            </select>
@@ -399,14 +789,15 @@ function App() {
              value={numRecommendations}
              onChange={(e) => setNumRecommendations(parseInt(e.target.value))}
            >
-              <option value="5">5</option>
              <option value="10">10</option>
-              <option value="15">15</option>
              <option value="20">20</option>
            </select>
          </div>
 
-          {recommendationType === 'hybrid' && (
            <div className="form-group">
              <label htmlFor="collabWeight">Collaborative Weight:</label>
              <input
@@ -448,19 +839,66 @@ function App() {
        {/* Recommendations Display */}
        {recommendations.length > 0 && (
          <div className="recommendations">
-            <h2>Recommendations ({recommendationType})</h2>
 
            <div className="stats">
              <strong>User Profile:</strong> {userProfile.age}yr {userProfile.gender},
-              ${userProfile.income.toLocaleString()} income
-              {
            </div>
 
            <div className="recommendations-grid">
-              {
              <div key={rec.item_id} className="recommendation-card">
                <div className="card-header">
-                  <span className="item-id">#{index + 1} Item {rec.item_id}</span>
                  <span className="score">{rec.score.toFixed(4)}</span>
                </div>
@@ -472,6 +910,27 @@ function App() {
            </div>
          ))}
        </div>
      </div>
    )}
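The `getInteractionCounts` helper that this diff moves (removed near old line 195, re-added near new line 250) simply tallies events by pluralizing the interaction type. A minimal standalone sketch of that logic, with illustrative sample data not taken from the dataset:

```javascript
// Tally interactions by pluralized event type (view -> views, cart -> carts, ...).
// Mirrors the getInteractionCounts helper in App.js, but takes the list as an argument.
function getInteractionCounts(interactions) {
  const counts = { views: 0, carts: 0, purchases: 0 };
  interactions.forEach((interaction) => {
    counts[interaction.type + 's'] = (counts[interaction.type + 's'] || 0) + 1;
  });
  return counts;
}

// Illustrative sample data:
const sample = [
  { type: 'view' }, { type: 'view' }, { type: 'cart' }, { type: 'purchase' }
];
console.log(getInteractionCounts(sample)); // { views: 2, carts: 1, purchases: 1 }
```

Note that because the key is built as `type + 's'`, any unexpected event type silently creates a new key rather than being dropped.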
   });
 
   const [recommendationType, setRecommendationType] = useState('hybrid');
+  const [numRecommendations, setNumRecommendations] = useState(100);
   const [collaborativeWeight, setCollaborativeWeight] = useState(0.7);
 
   const [recommendations, setRecommendations] = useState([]);
   const [loading, setLoading] = useState(false);
   const [error, setError] = useState(null);
 
+  // Pagination for recommendations
+  const [currentPage, setCurrentPage] = useState(1);
+  const [itemsPerPage] = useState(20); // Show 20 recommendations per page
+
   const [sampleItems, setSampleItems] = useState([]);
   const [interactions, setInteractions] = useState([]);
   const [expandedInteraction, setExpandedInteraction] = useState(null);
   const [selectedPattern, setSelectedPattern] = useState(null);
+
+  // Real user data states
+  const [realUsers, setRealUsers] = useState([]);
+  const [selectedRealUser, setSelectedRealUser] = useState(null);
+  const [datasetSummary, setDatasetSummary] = useState(null);
+  const [useRealUsers, setUseRealUsers] = useState(true);
+
+  // Expanded interaction states
+  const [showUserInteractions, setShowUserInteractions] = useState(false);
+  const [userInteractionDetails, setUserInteractionDetails] = useState(null);
+  const [loadingInteractions, setLoadingInteractions] = useState(false);
 
+  // Load sample items and real users on component mount
   useEffect(() => {
     fetchSampleItems();
+    fetchRealUsers();
+    fetchDatasetSummary();
   }, []);
 
   const fetchSampleItems = async () => {
 
     }
   };
 
+  const fetchRealUsers = async () => {
+    try {
+      const response = await axios.get(`${API_BASE_URL}/real-users?count=100&min_interactions=5`);
+      setRealUsers(response.data.users || []);
+      if (response.data.users && response.data.users.length > 0) {
+        // Auto-select the first (most active) user
+        handleRealUserSelect(response.data.users[0]);
+      }
+    } catch (error) {
+      console.error('Error fetching real users:', error);
+      setError('Could not load real users. Using synthetic data mode.');
+      setUseRealUsers(false);
+    }
+  };
+
+  const fetchDatasetSummary = async () => {
+    try {
+      const response = await axios.get(`${API_BASE_URL}/dataset-summary`);
+      setDatasetSummary(response.data);
+    } catch (error) {
+      console.error('Error fetching dataset summary:', error);
+    }
+  };
+
   const handleProfileChange = (field, value) => {
     setUserProfile(prev => ({
       ...prev,
     }));
   };
 
+  const handleRealUserSelect = (user) => {
+    setSelectedRealUser(user);
+    setUserProfile({
+      age: user.age,
+      gender: user.gender,
+      income: user.income,
+      interaction_history: user.interaction_history.slice(0, 50) // Limit to 50 items
+    });
+    // Clear any synthetic interactions and expanded states
+    setInteractions([]);
+    setSelectedPattern(null);
+    setShowUserInteractions(false);
+    setUserInteractionDetails(null);
+  };
+
+  const fetchUserInteractionDetails = async (userId) => {
+    setLoadingInteractions(true);
+    try {
+      const response = await axios.get(`${API_BASE_URL}/real-users/${userId}`);
+      setUserInteractionDetails(response.data);
+    } catch (error) {
+      console.error('Error fetching user interaction details:', error);
+      setError('Could not load user interaction details');
+    } finally {
+      setLoadingInteractions(false);
+    }
+  };
+
+  const toggleUserInteractions = async () => {
+    if (!showUserInteractions && selectedRealUser && !userInteractionDetails) {
+      await fetchUserInteractionDetails(selectedRealUser.user_id);
+    }
+    setShowUserInteractions(!showUserInteractions);
+  };
+
   const generateTimestamp = (baseTime, offsetHours) => {
     const timestamp = new Date(baseTime.getTime() - (offsetHours * 60 * 60 * 1000));
     return timestamp.toISOString().replace('T', ' ').slice(0, 19);
 
     }));
   };
 
+  const getInteractionCounts = () => {
+    const counts = { views: 0, carts: 0, purchases: 0 };
+    interactions.forEach(interaction => {
+      counts[interaction.type + 's'] = (counts[interaction.type + 's'] || 0) + 1;
+    });
+    return counts;
+  };
+
+  const counts = getInteractionCounts();
+
+  // Calculate category percentages from user interactions
+  const getCategoryPercentages = () => {
+    if (!selectedRealUser || !userInteractionDetails) return {};
+
+    const categoryCounts = {};
+    let totalInteractions = 0;
+
+    userInteractionDetails.timeline?.forEach(interaction => {
+      const category = interaction.category_code || 'Unknown';
+      categoryCounts[category] = (categoryCounts[category] || 0) + 1;
+      totalInteractions++;
+    });
+
+    const categoryPercentages = {};
+    Object.keys(categoryCounts).forEach(category => {
+      categoryPercentages[category] = ((categoryCounts[category] / totalInteractions) * 100).toFixed(1);
+    });
+
+    return categoryPercentages;
+  };
+
+  // Calculate recommendation category percentages
+  const getRecommendationCategoryPercentages = () => {
+    if (!recommendations || recommendations.length === 0) return {};
+
+    const recCategoryCounts = {};
+
+    recommendations.forEach(rec => {
+      const category = rec.item_info?.category_code || 'Unknown';
+      recCategoryCounts[category] = (recCategoryCounts[category] || 0) + 1;
+    });
+
+    const recCategoryPercentages = {};
+    Object.keys(recCategoryCounts).forEach(category => {
+      recCategoryPercentages[category] = ((recCategoryCounts[category] / recommendations.length) * 100).toFixed(1);
+    });
+
+    return recCategoryPercentages;
+  };
+
+  const categoryPercentages = getCategoryPercentages();
+  const recommendationCategoryPercentages = getRecommendationCategoryPercentages();
+
+  // Pagination logic
+  const totalPages = Math.ceil(recommendations.length / itemsPerPage);
+  const startIndex = (currentPage - 1) * itemsPerPage;
+  const endIndex = startIndex + itemsPerPage;
+  const currentRecommendations = recommendations.slice(startIndex, endIndex);
+
+  const goToPage = (page) => {
+    setCurrentPage(page);
+    // Scroll to recommendations section
+    document.querySelector('.recommendations')?.scrollIntoView({ behavior: 'smooth' });
+  };
+
+  // Reset pagination when new recommendations are generated
   const getRecommendations = async () => {
     setLoading(true);
     setError(null);
+    setCurrentPage(1); // Reset to first page
 
     try {
       const requestData = {
 
     }
   };
 
   return (
     <div className="App">
       <div className="container">
         <header className="header">
           <h1>Two-Tower Recommendation System Demo</h1>
+          <p>Select from {realUsers.length} real users or configure custom demographics to get personalized recommendations</p>
+
+          {datasetSummary && (
+            <div className="dataset-info">
+              📊 Dataset: {datasetSummary.total_users?.toLocaleString()} users, {datasetSummary.total_interactions?.toLocaleString()} interactions |
+              👥 Demographics: Avg age {datasetSummary.demographics?.avg_age}, avg income ${datasetSummary.demographics?.avg_income?.toLocaleString()}
+            </div>
+          )}
         </header>
 
+        {/* Real User Selector */}
+        {useRealUsers && realUsers.length > 0 && (
+          <div className="real-user-selector">
+            <h2>Real User Selection</h2>
+            <div className="user-selector-controls">
+              <label htmlFor="realUserSelect">Choose from {realUsers.length} real users:</label>
+              <select
+                id="realUserSelect"
+                value={selectedRealUser?.user_id || ''}
+                onChange={(e) => {
+                  const userId = parseInt(e.target.value);
+                  const user = realUsers.find(u => u.user_id === userId);
+                  if (user) handleRealUserSelect(user);
+                }}
+              >
+                <option value="">Select a real user...</option>
+                {realUsers.map((user, index) => (
+                  <option key={user.user_id} value={user.user_id}>
+                    #{index + 1}: {user.summary} - {user.interaction_pattern}
+                  </option>
+                ))}
+              </select>
+
+              <button
+                onClick={() => setUseRealUsers(false)}
+                className="btn btn-secondary"
+                style={{marginLeft: '10px'}}
+              >
+                Use Custom User Instead
+              </button>
+            </div>
+
+            {selectedRealUser && (
+              <div className="selected-real-user">
+                <h3>Selected User: {selectedRealUser.user_id}</h3>
+                <div className="real-user-stats">
+                  <div className="user-stat">
+                    <span className="stat-label">Demographics:</span>
+                    <span className="stat-value">{selectedRealUser.age}yr {selectedRealUser.gender}, ${selectedRealUser.income.toLocaleString()}</span>
+                  </div>
+                  <div className="user-stat">
+                    <span className="stat-label">Behavior Pattern:</span>
+                    <span className="stat-value">{selectedRealUser.interaction_pattern}</span>
+                  </div>
+                  <div className="user-stat">
+                    <span className="stat-label">Interactions:</span>
+                    <span className="stat-value">
+                      {selectedRealUser.interaction_stats.total_interactions} total
+                      ({selectedRealUser.interaction_stats.views} views, {selectedRealUser.interaction_stats.cart_adds} carts, {selectedRealUser.interaction_stats.purchases} purchases)
+                    </span>
+                  </div>
+                  <div className="user-stat">
+                    <span className="stat-label">History:</span>
+                    <span className="stat-value">{selectedRealUser.interaction_stats.unique_items} unique items</span>
+                  </div>
+                </div>
+
+                <button
+                  onClick={toggleUserInteractions}
+                  className="btn btn-info expand-interactions-btn"
+                  disabled={loadingInteractions}
+                >
+                  {loadingInteractions ? 'Loading...' : showUserInteractions ? 'Hide Interaction Timeline' : 'Show All Interactions Timeline'}
+                </button>
+
+                {showUserInteractions && userInteractionDetails && (
+                  <div className="user-interactions-timeline">
+                    <h4>Complete Interaction Timeline</h4>
+                    <div className="timeline-stats">
+                      <span><strong>Total Events:</strong> {userInteractionDetails.total_interactions}</span>
+                      <span><strong>Pattern:</strong> {userInteractionDetails.interaction_pattern}</span>
+                      <span><strong>Breakdown:</strong> {userInteractionDetails.breakdown.views} views, {userInteractionDetails.breakdown.cart_adds} carts, {userInteractionDetails.breakdown.purchases} purchases</span>
+                    </div>
+
+                    <div className="interactions-list">
+                      <h5>Recent Interactions (Last {userInteractionDetails.timeline?.length || 0} events):</h5>
+                      {userInteractionDetails.timeline?.map((interaction, index) => (
+                        <div key={index} className="interaction-timeline-item">
+                          <div className="interaction-timeline-time">
+                            {new Date(interaction.timestamp).toLocaleString()}
+                          </div>
+                          <div className="interaction-timeline-content">
+                            <div className="interaction-main-info">
+                              <span className={`interaction-type ${interaction.event_type}`}>
+                                {interaction.event_type.toUpperCase()}
+                              </span>
+                              <span className="interaction-icon">
+                                {interaction.event_type === 'purchase' && '💰'}
+                                {interaction.event_type === 'cart' && '🛒'}
+                                {interaction.event_type === 'view' && '👁️'}
+                              </span>
+                              <span className="interaction-item-id">
+                                Item #{interaction.product_id}
+                              </span>
+                            </div>
+                            <div className="interaction-item-details">
+                              <span className="item-brand">
+                                <strong>{interaction.brand || 'Unknown Brand'}</strong>
+                              </span>
+                              <span className="item-category">
+                                {interaction.category_code || 'Unknown Category'}
+                              </span>
+                              <span className="item-price">
+                                ${interaction.price ? interaction.price.toFixed(2) : '0.00'}
+                              </span>
+                            </div>
+                          </div>
+                        </div>
+                      ))}
+                    </div>
+                  </div>
+                )}
+              </div>
+            )}
+          </div>
+        )}
+
         {/* User Profile Form */}
         <div className="user-profile-form">
+          <h2>User Demographics {useRealUsers && selectedRealUser ? '(From Real User)' : '(Custom)'}</h2>
+
+          {!useRealUsers && (
+            <button
+              onClick={() => {
+                setUseRealUsers(true);
+                if (realUsers.length > 0) handleRealUserSelect(realUsers[0]);
+              }}
+              className="btn btn-secondary"
+              style={{marginBottom: '15px'}}
+            >
+              Switch to Real Users
+            </button>
+          )}
 
           <div className="form-row">
             <div className="form-group">
 
               onChange={(e) => handleProfileChange('age', e.target.value)}
               min="18"
               max="100"
+              disabled={useRealUsers && selectedRealUser}
+              style={{backgroundColor: useRealUsers && selectedRealUser ? '#f5f5f5' : 'white'}}
             />
           </div>
 
               id="gender"
               value={userProfile.gender}
               onChange={(e) => handleProfileChange('gender', e.target.value)}
+              disabled={useRealUsers && selectedRealUser}
+              style={{backgroundColor: useRealUsers && selectedRealUser ? '#f5f5f5' : 'white'}}
             >
               <option value="male">Male</option>
               <option value="female">Female</option>
 
               onChange={(e) => handleProfileChange('income', e.target.value)}
               min="0"
               step="1000"
+              disabled={useRealUsers && selectedRealUser}
+              style={{backgroundColor: useRealUsers && selectedRealUser ? '#f5f5f5' : 'white'}}
             />
           </div>
         </div>
 
         {/* Interaction Patterns */}
         <div className="interaction-patterns">
+          {useRealUsers && selectedRealUser ? (
+            <>
+              <h2>Real User Interaction History</h2>
+              <p>This user has genuine interaction history from the dataset - no synthetic patterns needed.</p>
+
+              <div className="real-interaction-summary">
+                <div className="summary-card views">
+                  <div className="summary-number">{selectedRealUser.interaction_stats.views}</div>
+                  <div className="summary-label">Views</div>
+                </div>
+                <div className="summary-card carts">
+                  <div className="summary-number">{selectedRealUser.interaction_stats.cart_adds}</div>
+                  <div className="summary-label">Cart Adds</div>
+                </div>
+                <div className="summary-card purchases">
+                  <div className="summary-number">{selectedRealUser.interaction_stats.purchases}</div>
+                  <div className="summary-label">Purchases</div>
+                </div>
+              </div>
+
+              <div className="real-history-info">
+                <p><strong>Pattern:</strong> {selectedRealUser.interaction_pattern}</p>
+                <p><strong>Total Interactions:</strong> {selectedRealUser.interaction_stats.total_interactions}</p>
+                <p><strong>Unique Items:</strong> {selectedRealUser.interaction_stats.unique_items}</p>
+                <p><strong>Items in History:</strong> {userProfile.interaction_history.length} (showing up to 50 most recent)</p>
+              </div>
+
+              {/* Category Analysis Columns */}
+              {(Object.keys(categoryPercentages).length > 0 || Object.keys(recommendationCategoryPercentages).length > 0) && (
+                <div className="category-analysis">
+                  <h4>Category Analysis</h4>
+                  <div className="category-columns">
+
+                    {/* User's Interacted Categories */}
+                    {Object.keys(categoryPercentages).length > 0 && (
+                      <div className="category-column">
+                        <h5>👁️ User's Category Interests</h5>
+                        <div className="category-percentages">
+                          {Object.entries(categoryPercentages)
+                            .sort((a, b) => parseFloat(b[1]) - parseFloat(a[1]))
+                            .slice(0, 5)
+                            .map(([category, percentage]) => (
+                              <div key={category} className="category-item">
+                                <div className="category-bar-container">
+                                  <div
+                                    className="category-bar user-category"
+                                    style={{ width: `${Math.max(parseFloat(percentage), 5)}%` }}
+                                  ></div>
+                                </div>
+                                <span className="category-label">{category.replace('_', ' ')}</span>
+                                <span className="category-percent">{percentage}%</span>
+                              </div>
+                            ))}
+                        </div>
+                      </div>
+                    )}
+
+                    {/* Recommendation Categories */}
+                    {Object.keys(recommendationCategoryPercentages).length > 0 && (
+                      <div className="category-column">
+                        <h5>🎯 Recommendation Categories</h5>
+                        <div className="category-percentages">
+                          {Object.entries(recommendationCategoryPercentages)
+                            .sort((a, b) => parseFloat(b[1]) - parseFloat(a[1]))
+                            .map(([category, percentage]) => {
+                              const userPercentage = categoryPercentages[category] || 0;
+                              const isMatch = parseFloat(userPercentage) > 0;
+
+                              return (
+                                <div key={category} className={`category-item ${isMatch ? 'matched' : 'new'}`}>
+                                  <div className="category-bar-container">
+                                    <div
+                                      className={`category-bar ${isMatch ? 'rec-category-matched' : 'rec-category-new'}`}
+                                      style={{ width: `${Math.max(parseFloat(percentage), 5)}%` }}
+                                    ></div>
+                                  </div>
+                                  <span className="category-label">{category.replace('_', ' ')}</span>
+                                  <span className="category-percent">{percentage}%</span>
+                                  {isMatch && <span className="match-indicator">✓</span>}
+                                </div>
+                              );
+                            })}
+                        </div>
+                      </div>
+                    )}
+
+                  </div>
+
+                  {/* Category Match Analysis */}
+                  {Object.keys(categoryPercentages).length > 0 && Object.keys(recommendationCategoryPercentages).length > 0 && (
+                    <div className="category-match-summary">
+                      <p>
+                        <strong>Category Alignment:</strong> {
+                          Object.keys(recommendationCategoryPercentages).filter(cat =>
+                            parseFloat(categoryPercentages[cat] || 0) > 0
+                          ).length
+                        } of {Object.keys(recommendationCategoryPercentages).length} recommended categories match user interests
+                        <span className="match-legend">
+                          <span className="legend-item"><span className="legend-dot matched"></span> Matches user interest</span>
+                          <span className="legend-item"><span className="legend-dot new"></span> New category exploration</span>
+                        </span>
+                      </p>
+                    </div>
+                  )}
+                </div>
+              )}
+            </>
+          ) : (
+            <>
+              <h2>Synthetic Interaction Patterns</h2>
+              <p>Generate realistic user behavior patterns with proportional view, cart, and purchase events</p>
+
+              <div className="pattern-buttons">
+                {INTERACTION_PATTERNS.map((pattern, index) => (
+                  <button
+                    key={index}
+                    className={`pattern-btn ${selectedPattern?.name === pattern.name ? 'active' : ''}`}
+                    onClick={() => handlePatternSelect(pattern)}
+                  >
+                    {pattern.name}
+                    <br />
+                    <small>{pattern.views}V • {pattern.carts}C • {pattern.purchases}P</small>
+                  </button>
+                ))}
+                <button
+                  className="pattern-btn"
+                  onClick={clearInteractions}
+                  style={{backgroundColor: '#dc3545', color: 'white', borderColor: '#dc3545'}}
+                >
+                  Clear All
+                </button>
+              </div>
+            </>
+          )}
 
         {interactions.length > 0 && (
           <>
 
               {interaction.type}
             </span>
             <span className="interaction-details">
+              <strong>{interaction.brand}</strong> - <span className="category-tag">{interaction.category}</span> - ${interaction.price}
               {interaction.quantity && ` (x${interaction.quantity})`}
               {interaction.total_amount && ` = $${interaction.total_amount}`}
             </span>
 
               onChange={(e) => setRecommendationType(e.target.value)}
             >
               <option value="hybrid">Hybrid</option>
+              <option value="enhanced">🎯 Enhanced Hybrid (Category-Aware)</option>
+              <option value="category_focused">🎯 Category Focused (80% Match)</option>
               <option value="collaborative">Collaborative Filtering</option>
               <option value="content">Content-Based</option>
             </select>
 
               value={numRecommendations}
               onChange={(e) => setNumRecommendations(parseInt(e.target.value))}
             >
               <option value="10">10</option>
               <option value="20">20</option>
+              <option value="50">50</option>
+              <option value="100">100 (Top Items)</option>
+              <option value="200">200 (Extended)</option>
             </select>
           </div>
 
+          {(recommendationType === 'hybrid' || recommendationType === 'enhanced') && (
             <div className="form-group">
               <label htmlFor="collabWeight">Collaborative Weight:</label>
               <input
 
         {/* Recommendations Display */}
         {recommendations.length > 0 && (
           <div className="recommendations">
+            <h2>Top {recommendations.length} Recommendations ({recommendationType})</h2>
 
             <div className="stats">
               <strong>User Profile:</strong> {userProfile.age}yr {userProfile.gender},
+              ${userProfile.income.toLocaleString()} income
+              {useRealUsers && selectedRealUser ? (
+                <span> | <strong>Real User {selectedRealUser.user_id}:</strong> {selectedRealUser.interaction_pattern} -
+                  {selectedRealUser.interaction_stats.total_interactions} total interactions
+                  ({selectedRealUser.interaction_stats.views} views, {selectedRealUser.interaction_stats.cart_adds} carts, {selectedRealUser.interaction_stats.purchases} purchases)
+                </span>
+              ) : (
+                <span> | <strong>Synthetic:</strong> {interactions.length} total interactions ({counts.views || 0} views, {counts.carts || 0} carts, {counts.purchases || 0} purchases)</span>
+              )}
             </div>
 
+            {/* Pagination Controls */}
+            {totalPages > 1 && (
+              <div className="pagination-info">
+                <p>Showing {startIndex + 1}-{Math.min(endIndex, recommendations.length)} of {recommendations.length} recommendations</p>
+                <div className="pagination-controls">
+                  <button
+                    onClick={() => goToPage(currentPage - 1)}
+                    disabled={currentPage === 1}
+                    className="pagination-btn"
+                  >
+                    ← Previous
+                  </button>
+
+                  <div className="page-numbers">
+                    {Array.from({length: Math.min(totalPages, 10)}, (_, i) => {
+                      const page = i + 1;
+                      return (
+                        <button
+                          key={page}
+                          onClick={() => goToPage(page)}
+                          className={`page-number ${currentPage === page ? 'active' : ''}`}
+                        >
+                          {page}
+                        </button>
+                      );
+                    })}
+                    {totalPages > 10 && <span className="pagination-ellipsis">...</span>}
|
| 884 |
+
</div>
|
| 885 |
+
|
| 886 |
+
<button
|
| 887 |
+
onClick={() => goToPage(currentPage + 1)}
|
| 888 |
+
disabled={currentPage === totalPages}
|
| 889 |
+
className="pagination-btn"
|
| 890 |
+
>
|
| 891 |
+
Next →
|
| 892 |
+
</button>
|
| 893 |
+
</div>
|
| 894 |
+
</div>
|
| 895 |
+
)}
|
| 896 |
+
|
| 897 |
<div className="recommendations-grid">
|
| 898 |
+
{currentRecommendations.map((rec, index) => (
|
| 899 |
<div key={rec.item_id} className="recommendation-card">
|
| 900 |
<div className="card-header">
|
| 901 |
+
<span className="item-id">#{startIndex + index + 1} Item {rec.item_id}</span>
|
| 902 |
<span className="score">{rec.score.toFixed(4)}</span>
|
| 903 |
</div>
|
| 904 |
|
|
|
|
| 910 |
</div>
|
| 911 |
))}
|
| 912 |
</div>
|
| 913 |
+
|
| 914 |
+
{/* Bottom Pagination */}
|
| 915 |
+
{totalPages > 1 && (
|
| 916 |
+
<div className="pagination-controls bottom-pagination">
|
| 917 |
+
<button
|
| 918 |
+
onClick={() => goToPage(currentPage - 1)}
|
| 919 |
+
disabled={currentPage === 1}
|
| 920 |
+
className="pagination-btn"
|
| 921 |
+
>
|
| 922 |
+
← Previous
|
| 923 |
+
</button>
|
| 924 |
+
<span className="page-indicator">Page {currentPage} of {totalPages}</span>
|
| 925 |
+
<button
|
| 926 |
+
onClick={() => goToPage(currentPage + 1)}
|
| 927 |
+
disabled={currentPage === totalPages}
|
| 928 |
+
className="pagination-btn"
|
| 929 |
+
>
|
| 930 |
+
Next →
|
| 931 |
+
</button>
|
| 932 |
+
</div>
|
| 933 |
+
)}
|
| 934 |
</div>
|
| 935 |
)}
|
| 936 |
|
|
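The pagination UI in this hunk reads `currentPage`, `totalPages`, `startIndex`, `endIndex`, `goToPage`, and `currentRecommendations`, whose definitions fall outside the diff. A minimal Python sketch of the arithmetic those values presumably encode (the `page_size` parameter and clamping behavior are assumptions, not shown in the hunk):

```python
import math

def paginate(total_items, page_size, current_page):
    """Compute page count and slice bounds for a paginated list.

    Mirrors the assumed logic behind totalPages/startIndex/endIndex:
    clamp the requested page into range, then slice the item list.
    """
    total_pages = max(1, math.ceil(total_items / page_size))
    page = min(max(current_page, 1), total_pages)  # clamp, as goToPage would
    start_index = (page - 1) * page_size
    end_index = min(start_index + page_size, total_items)
    return page, total_pages, start_index, end_index
```

With 45 recommendations and 20 per page, page 2 covers items 20–39 and a request for page 99 clamps to the last page (items 40–44).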
run_2phase_training.py (new file, +206 lines):
@@ -0,0 +1,206 @@
#!/usr/bin/env python3
"""
Two-Phase Training Pipeline Runner

This script orchestrates the complete 2-phase training approach:
1. Phase 1: Item tower pretraining on item features
2. Phase 2: Joint training of user tower + fine-tuning pre-trained item tower

Usage:
    python run_2phase_training.py
"""

import os
import sys
import time
import pickle
import numpy as np

# Add src to path
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))

from src.training.item_pretraining import ItemTowerPretrainer
from src.training.joint_training import JointTrainer
from src.preprocessing.data_loader import DataProcessor
from src.inference.faiss_index import FAISSItemIndex


def run_phase1_item_pretraining():
    """Phase 1: Pre-train the item tower."""
    print("\n" + "=" * 60)
    print("PHASE 1: ITEM TOWER PRETRAINING")
    print("=" * 60)

    # Initialize components
    data_processor = DataProcessor()
    pretrainer = ItemTowerPretrainer(
        embedding_dim=128,
        hidden_dims=[256, 128],
        dropout_rate=0.2,
        learning_rate=0.001
    )

    # Prepare data
    print("Preparing item data...")
    dataset, data_processor, price_normalizer = pretrainer.prepare_data(data_processor)

    # Build model
    print("Building item tower...")
    model = pretrainer.build_model(
        item_vocab_size=len(data_processor.item_vocab),
        category_vocab_size=len(data_processor.category_vocab),
        brand_vocab_size=len(data_processor.brand_vocab),
        price_normalizer=price_normalizer
    )

    # Train model
    print("Training item tower (Phase 1)...")
    start_time = time.time()
    history = pretrainer.train(dataset, epochs=50)
    phase1_time = time.time() - start_time

    # Generate embeddings
    print("Generating item embeddings...")
    item_embeddings = pretrainer.generate_item_embeddings(dataset, data_processor)

    # Save artifacts
    print("Saving Phase 1 artifacts...")
    os.makedirs("src/artifacts", exist_ok=True)
    data_processor.save_vocabularies()
    pretrainer.save_model()

    # Save embeddings for the FAISS index
    np.save("src/artifacts/item_embeddings.npy", item_embeddings)

    # Build FAISS index
    print("Building FAISS index...")
    faiss_index = FAISSItemIndex()
    faiss_index.build_index(item_embeddings)
    faiss_index.save_index("src/artifacts/")

    print(f"✅ Phase 1 completed in {phase1_time:.2f} seconds")
    print(f"   - Items processed: {len(item_embeddings)}")
    print(f"   - Final loss: {history.history['total_loss'][-1]:.4f}")

    return data_processor


def run_phase2_joint_training(data_processor: DataProcessor):
    """Phase 2: Joint training with the pre-trained item tower."""
    print("\n" + "=" * 60)
    print("PHASE 2: JOINT TRAINING")
    print("=" * 60)

    # Initialize joint trainer
    trainer = JointTrainer(
        embedding_dim=128,
        user_learning_rate=0.001,
        item_learning_rate=0.0001,  # Lower LR for pre-trained item tower
        rating_weight=1.0,
        retrieval_weight=0.5
    )

    # Load pre-trained item tower
    print("Loading pre-trained item tower...")
    trainer.load_pre_trained_item_tower()

    # Build user tower
    print("Building user tower...")
    trainer.build_user_tower(max_history_length=50)

    # Build complete two-tower model
    print("Building complete two-tower model...")
    trainer.build_two_tower_model()

    # Prepare training data
    print("Preparing user interaction data...")

    # Reuse cached training features if they already exist
    if os.path.exists("src/artifacts/training_features.pkl"):
        print("Loading existing training features...")
        with open("src/artifacts/training_features.pkl", 'rb') as f:
            training_features = pickle.load(f)
        with open("src/artifacts/validation_features.pkl", 'rb') as f:
            validation_features = pickle.load(f)
    else:
        # Generate training features
        print("Generating training features...")
        training_features, validation_features = data_processor.prepare_training_data()

        # Save features
        with open("src/artifacts/training_features.pkl", 'wb') as f:
            pickle.dump(training_features, f)
        with open("src/artifacts/validation_features.pkl", 'wb') as f:
            pickle.dump(validation_features, f)

    # Train joint model
    print("Starting joint training (Phase 2)...")
    start_time = time.time()
    history = trainer.train(
        training_features=training_features,
        validation_features=validation_features,
        epochs=100,
        batch_size=256
    )
    phase2_time = time.time() - start_time

    # Save final model
    print("Saving final two-tower model...")
    trainer.save_model()

    # Save training history
    with open("src/artifacts/joint_training_history.pkl", 'wb') as f:
        pickle.dump(history, f)

    print(f"✅ Phase 2 completed in {phase2_time:.2f} seconds")
    print(f"   - Best validation loss: {min(history['val_total_loss']):.4f}")
    print(f"   - Epochs trained: {len(history['total_loss'])}")

    return history


def main():
    """Run the complete 2-phase training pipeline."""
    print("🚀 STARTING 2-PHASE TRAINING PIPELINE")
    print(f"Working directory: {os.getcwd()}")
    print(f"Python path: {sys.executable}")

    total_start_time = time.time()

    try:
        # Phase 1: Item tower pretraining
        data_processor = run_phase1_item_pretraining()

        # Phase 2: Joint training
        history = run_phase2_joint_training(data_processor)

        # Final summary
        total_time = time.time() - total_start_time

        print("\n" + "=" * 60)
        print("🎉 2-PHASE TRAINING COMPLETED SUCCESSFULLY!")
        print("=" * 60)
        print(f"Total training time: {total_time:.2f} seconds ({total_time/60:.1f} minutes)")
        print("Artifacts saved in: src/artifacts/")
        print("\nKey files generated:")
        print("  - item_tower_weights: Pre-trained item embeddings")
        print("  - item_tower_weights_finetuned_best: Fine-tuned item tower")
        print("  - user_tower_weights_best: Trained user tower")
        print("  - rating_model_weights_best: Rating prediction model")
        print("  - faiss_index.index: Item similarity index")
        print("  - vocabularies.pkl: Feature vocabularies")

        print(f"\n🔥 Final validation loss: {min(history['val_total_loss']):.4f}")
        print("\n✅ Ready to run inference with api/main.py!")

    except Exception as e:
        print(f"\n❌ Training failed with error: {str(e)}")
        raise


if __name__ == "__main__":
    main()
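Phase 1 above saves `item_embeddings.npy` and builds a FAISS index over those vectors. As a sanity check on what the index returns, the same top-k lookup can be reproduced brute-force in NumPy (a sketch under the assumption that items are scored by inner product with the user embedding, matching the dot-product retrieval objective; names here are illustrative):

```python
import numpy as np

def top_k_items(user_embedding, item_embeddings, k=5):
    """Brute-force inner-product search equivalent to a flat FAISS lookup."""
    scores = item_embeddings @ user_embedding   # similarity score per item
    top = np.argsort(-scores)[:k]               # indices sorted by descending score
    return top, scores[top]

# Tiny demo with random 128-d embeddings (the pipeline's embedding_dim)
rng = np.random.default_rng(0)
items = rng.normal(size=(100, 128)).astype(np.float32)
user = rng.normal(size=128).astype(np.float32)
idx, scores = top_k_items(user, items, k=5)
```

On small catalogs this exact search is a useful cross-check that the approximate index is returning sensible neighbors.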
run_joint_training.py (new file, +453 lines):
@@ -0,0 +1,453 @@
#!/usr/bin/env python3
"""
Single Joint Training Pipeline Runner

This script orchestrates the single-phase joint training approach:
- Trains user tower and item tower simultaneously from scratch
- No pre-training phase - direct end-to-end optimization
- Supports both regular and fast training modes

Usage:
    python run_joint_training.py [--fast]
"""

import os
import sys
import time
import pickle
import argparse
from typing import Dict

# Add src to path
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))

from src.training.fast_joint_training import FastJointTrainer
from src.models.item_tower import ItemTower
from src.models.user_tower import UserTower, TwoTowerModel
from src.preprocessing.data_loader import DataProcessor, create_tf_dataset
from src.inference.faiss_index import FAISSItemIndex
import tensorflow as tf
import numpy as np


class SingleJointTrainer:
    """Complete single-phase joint training from scratch."""

    def __init__(self):
        self.item_tower = None
        self.user_tower = None
        self.model = None
        self.data_processor = None

        # Training hyperparameters
        self.embedding_dim = 128
        self.learning_rate = 0.001
        self.batch_size = 256
        self.epochs = 80
        self.patience = 20

    def prepare_data(self):
        """Prepare all training data from scratch."""
        print("Loading and preparing data...")

        # Initialize data processor
        self.data_processor = DataProcessor()

        # Reuse preprocessed data when available
        if os.path.exists("src/artifacts/training_features.pkl"):
            print("Loading existing preprocessed data...")

            # Load vocabularies
            self.data_processor.load_vocabularies("src/artifacts/vocabularies.pkl")

            # Load training features
            with open("src/artifacts/training_features.pkl", 'rb') as f:
                training_features = pickle.load(f)
            with open("src/artifacts/validation_features.pkl", 'rb') as f:
                validation_features = pickle.load(f)
        else:
            print("Preprocessing data from scratch...")

            # Load raw data and build vocabularies
            items_df, users_df, interactions_df = self.data_processor.load_data()
            self.data_processor.build_vocabularies(items_df, users_df, interactions_df)

            # Generate training features
            training_features, validation_features = self.data_processor.prepare_training_data()

            # Save for future use
            os.makedirs("src/artifacts", exist_ok=True)
            self.data_processor.save_vocabularies()

            with open("src/artifacts/training_features.pkl", 'wb') as f:
                pickle.dump(training_features, f)
            with open("src/artifacts/validation_features.pkl", 'wb') as f:
                pickle.dump(validation_features, f)

        print(f"Training samples: {len(training_features['rating']):,}")
        print(f"Validation samples: {len(validation_features['rating']):,}")

        return training_features, validation_features

    def build_models(self):
        """Build both towers from scratch."""
        print("Building item tower...")
        self.item_tower = ItemTower(
            item_vocab_size=len(self.data_processor.item_vocab),
            category_vocab_size=len(self.data_processor.category_vocab),
            brand_vocab_size=len(self.data_processor.brand_vocab),
            embedding_dim=self.embedding_dim,
            hidden_dims=[256, 128],
            dropout_rate=0.2
        )

        print("Building user tower...")
        self.user_tower = UserTower(
            max_history_length=50,
            embedding_dim=self.embedding_dim,
            hidden_dims=[128, 64],  # Match trained architecture
            dropout_rate=0.2
        )

        print("Building complete two-tower model...")
        self.model = TwoTowerModel(
            item_tower=self.item_tower,
            user_tower=self.user_tower,
            rating_weight=1.0,
            retrieval_weight=0.5
        )

        print("Models initialized successfully")

    def train_joint_model(self, training_features: Dict, validation_features: Dict):
        """Train both towers jointly from scratch."""
        print("Starting single-phase joint training...")
        print(f"Configuration: {self.epochs} epochs, batch size {self.batch_size}")

        # Create datasets
        train_dataset = create_tf_dataset(training_features, self.batch_size)
        val_dataset = create_tf_dataset(validation_features, self.batch_size)

        # Set up optimizer
        optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)

        # Training history
        history = {
            'total_loss': [],
            'rating_loss': [],
            'retrieval_loss': [],
            'val_total_loss': [],
            'val_rating_loss': [],
            'val_retrieval_loss': []
        }

        best_val_loss = float('inf')
        patience_counter = 0

        for epoch in range(self.epochs):
            epoch_start = time.time()
            print(f"\nEpoch {epoch + 1}/{self.epochs}")

            # Training phase
            epoch_losses = {'total_loss': [], 'rating_loss': [], 'retrieval_loss': []}

            for batch in train_dataset:
                with tf.GradientTape() as tape:
                    # Forward pass
                    user_embeddings = self.user_tower(batch, training=True)
                    item_embeddings = self.item_tower(batch, training=True)

                    # Rating prediction
                    concatenated = tf.concat([user_embeddings, item_embeddings], axis=-1)
                    rating_predictions = self.model.rating_model(concatenated, training=True)

                    # Rating loss
                    rating_loss = self.model.rating_task(
                        labels=batch["rating"],
                        predictions=rating_predictions
                    )

                    # Retrieval loss (dot-product similarity)
                    similarities = tf.reduce_sum(user_embeddings * item_embeddings, axis=1)
                    retrieval_loss = self.model.retrieval_loss(
                        batch["rating"],
                        tf.nn.sigmoid(similarities)
                    )

                    # Combined loss
                    total_loss = (
                        self.model.rating_weight * rating_loss +
                        self.model.retrieval_weight * retrieval_loss
                    )

                # Compute and apply gradients
                all_variables = (
                    self.user_tower.trainable_variables +
                    self.item_tower.trainable_variables +
                    self.model.rating_model.trainable_variables
                )
                gradients = tape.gradient(total_loss, all_variables)
                optimizer.apply_gradients(zip(gradients, all_variables))

                # Track losses
                epoch_losses['total_loss'].append(total_loss)
                epoch_losses['rating_loss'].append(rating_loss)
                epoch_losses['retrieval_loss'].append(retrieval_loss)

            # Validation phase
            val_losses = {'total_loss': [], 'rating_loss': [], 'retrieval_loss': []}

            for batch in val_dataset:
                user_embeddings = self.user_tower(batch, training=False)
                item_embeddings = self.item_tower(batch, training=False)

                concatenated = tf.concat([user_embeddings, item_embeddings], axis=-1)
                rating_predictions = self.model.rating_model(concatenated, training=False)

                rating_loss = self.model.rating_task(
                    labels=batch["rating"],
                    predictions=rating_predictions
                )

                similarities = tf.reduce_sum(user_embeddings * item_embeddings, axis=1)
                retrieval_loss = self.model.retrieval_loss(
                    batch["rating"],
                    tf.nn.sigmoid(similarities)
                )

                total_loss = (
                    self.model.rating_weight * rating_loss +
                    self.model.retrieval_weight * retrieval_loss
                )

                val_losses['total_loss'].append(total_loss)
                val_losses['rating_loss'].append(rating_loss)
                val_losses['retrieval_loss'].append(retrieval_loss)

            # Calculate average losses
            avg_train_losses = {k: tf.reduce_mean(v).numpy() for k, v in epoch_losses.items()}
            avg_val_losses = {k: tf.reduce_mean(v).numpy() for k, v in val_losses.items()}

            # Update history
            for key in history.keys():
                if key.startswith('val_'):
                    history[key].append(avg_val_losses[key.replace('val_', '')])
                else:
                    history[key].append(avg_train_losses[key])

            # Print progress
            epoch_time = time.time() - epoch_start
            print(f"Time: {epoch_time:.1f}s | Train: {avg_train_losses['total_loss']:.4f} | Val: {avg_val_losses['total_loss']:.4f}")
            print(f"  Rating: {avg_val_losses['rating_loss']:.4f} | Retrieval: {avg_val_losses['retrieval_loss']:.4f}")

            # Early stopping and best-model saving
            if avg_val_losses['total_loss'] < best_val_loss:
                best_val_loss = avg_val_losses['total_loss']
                patience_counter = 0
                self.save_model("_best")
                print("  ✅ Best model saved!")
            else:
                patience_counter += 1
                if patience_counter >= self.patience:
                    print(f"Early stopping at epoch {epoch + 1}")
                    break

        print("Joint training completed!")
        return history

    def generate_item_embeddings(self, training_features: Dict):
        """Generate item embeddings for the FAISS index."""
        print("Generating item embeddings...")

        # Get all unique items from training data
        unique_items = np.unique(training_features['product_id'])
        item_embeddings = {}

        # Process in batches
        batch_size = 1000
        for i in range(0, len(unique_items), batch_size):
            batch_items = unique_items[i:i + batch_size]

            # Create batch features
            batch_features = {
                'product_id': batch_items,
                'category_id': training_features['category_id'][:len(batch_items)],
                'brand_id': training_features['brand_id'][:len(batch_items)],
                'price': training_features['price'][:len(batch_items)]
            }

            # Convert to tensors
            batch_tensors = {k: tf.constant(v) for k, v in batch_features.items()}

            # Get embeddings
            embeddings = self.item_tower(batch_tensors, training=False)

            # Store embeddings
            for j, item_id in enumerate(batch_items):
                # Map back from vocab index to actual item ID
                actual_item_id = item_id  # Assuming direct mapping
                item_embeddings[actual_item_id] = embeddings[j].numpy()

        print(f"Generated embeddings for {len(item_embeddings)} items")
        return item_embeddings

    def save_model(self, suffix=""):
        """Save trained models."""
        save_path = "src/artifacts/"
        os.makedirs(save_path, exist_ok=True)

        # Save model weights
        self.user_tower.save_weights(f"{save_path}/user_tower_weights{suffix}")
        self.item_tower.save_weights(f"{save_path}/item_tower_weights_finetuned{suffix}")
        self.model.rating_model.save_weights(f"{save_path}/rating_model_weights{suffix}")

        # Save item tower config for inference
        with open(f"{save_path}/item_tower_config.txt", 'w') as f:
            f.write(f"embedding_dim: {self.embedding_dim}\n")
            f.write("hidden_dims: [256, 128]\n")  # Item tower architecture
            f.write("dropout_rate: 0.2\n")

        if not suffix:
            print("Final model saved")


def run_fast_joint_training():
    """Run fast optimized joint training."""
    print("\n" + "=" * 60)
    print("FAST JOINT TRAINING MODE")
    print("=" * 60)

    # Initialize fast trainer
    trainer = FastJointTrainer()

    # Prepare data first if needed
    if not os.path.exists("src/artifacts/training_features.pkl"):
        print("Preparing data first...")
        single_trainer = SingleJointTrainer()
        training_features, validation_features = single_trainer.prepare_data()

    # Run fast training
    trainer.load_components()

    print("Loading training data...")
    with open("src/artifacts/training_features.pkl", 'rb') as f:
        training_features = pickle.load(f)
    with open("src/artifacts/validation_features.pkl", 'rb') as f:
        validation_features = pickle.load(f)

    start_time = time.time()
    trainer.train_fast(training_features, validation_features)
    training_time = time.time() - start_time

    # Generate embeddings and build the FAISS index
    print("Building FAISS index...")
    # Use the single trainer for embedding generation
    single_trainer = SingleJointTrainer()
    single_trainer.data_processor = DataProcessor()
    single_trainer.data_processor.load_vocabularies("src/artifacts/vocabularies.pkl")
    single_trainer.item_tower = trainer.item_tower

    item_embeddings = single_trainer.generate_item_embeddings(training_features)

    faiss_index = FAISSItemIndex()
    faiss_index.build_index(item_embeddings)
    faiss_index.save_index("src/artifacts/")

    return training_time


def run_regular_joint_training():
    """Run regular comprehensive joint training."""
    print("\n" + "=" * 60)
    print("REGULAR JOINT TRAINING MODE")
    print("=" * 60)

    # Initialize trainer
    trainer = SingleJointTrainer()

    # Prepare data
    training_features, validation_features = trainer.prepare_data()

    # Build models from scratch
    trainer.build_models()

    # Train joint model
    start_time = time.time()
    history = trainer.train_joint_model(training_features, validation_features)
    training_time = time.time() - start_time

    # Generate item embeddings
    item_embeddings = trainer.generate_item_embeddings(training_features)

    # Build FAISS index
    print("Building FAISS index...")
    faiss_index = FAISSItemIndex()
    faiss_index.build_index(item_embeddings)
    faiss_index.save_index("src/artifacts/")

    # Save final model
    trainer.save_model()

    # Save training history
    with open("src/artifacts/single_joint_training_history.pkl", 'wb') as f:
        pickle.dump(history, f)

    return training_time, history


def main():
    """Run the single joint training pipeline."""
    parser = argparse.ArgumentParser(description='Single Joint Training Pipeline')
    parser.add_argument('--fast', action='store_true', help='Use fast training mode')
    args = parser.parse_args()

    print("🚀 STARTING SINGLE JOINT TRAINING PIPELINE")
    print(f"Working directory: {os.getcwd()}")
    print(f"Training mode: {'FAST' if args.fast else 'REGULAR'}")

    total_start_time = time.time()

    try:
        if args.fast:
            training_time = run_fast_joint_training()
            history = None
        else:
            training_time, history = run_regular_joint_training()

        total_time = time.time() - total_start_time

        print("\n" + "=" * 60)
        print("🎉 SINGLE JOINT TRAINING COMPLETED SUCCESSFULLY!")
        print("=" * 60)
        print(f"Training time: {training_time:.2f} seconds ({training_time/60:.1f} minutes)")
        print(f"Total time: {total_time:.2f} seconds ({total_time/60:.1f} minutes)")
        print("Artifacts saved in: src/artifacts/")

        print("\nKey files generated:")
        print("  - user_tower_weights_best: Trained user tower")
        print("  - item_tower_weights_finetuned_best: Trained item tower")
        print("  - rating_model_weights_best: Rating prediction model")
        print("  - faiss_index.index: Item similarity index")
        print("  - vocabularies.pkl: Feature vocabularies")

        if history:
            print(f"\n🔥 Best validation loss: {min(history['val_total_loss']):.4f}")

        print("\n🎯 Training approach: Single-phase joint optimization")
        print("✅ Ready to run inference with api/main.py!")

    except Exception as e:
|
| 448 |
+
print(f"\n❌ Training failed with error: {str(e)}")
|
| 449 |
+
raise
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
if __name__ == "__main__":
|
| 453 |
+
main()
|
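The completion summary above pulls the best validation loss straight out of a Keras-style history dict with `min(history['val_total_loss'])`. A standalone sketch of that lookup, using made-up loss values rather than anything from this run:

```python
# Keras-style history: one list entry per epoch, keyed by metric name.
history = {
    "train_total_loss": [0.92, 0.61, 0.48, 0.45],
    "val_total_loss":   [0.88, 0.59, 0.52, 0.55],  # last epoch overfits slightly
}

# Best epoch = index of the minimum validation loss (0-based).
best_epoch = min(range(len(history["val_total_loss"])),
                 key=history["val_total_loss"].__getitem__)
best_loss = history["val_total_loss"][best_epoch]

print(best_epoch)          # 2
print(f"{best_loss:.4f}")  # 0.5200
```

Tracking the epoch index as well as the minimum is what lets a checkpoint callback restore the corresponding `*_weights_best` files listed above.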
@@ -0,0 +1,499 @@
+#!/usr/bin/env python3
+"""
+Enhanced recommendation engine using 128D embeddings with diversity regularization.
+"""
+
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+import pickle
+import os
+from typing import Dict, List, Tuple, Optional
+from collections import Counter, defaultdict
+
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+
+from src.models.enhanced_two_tower import EnhancedItemTower, EnhancedUserTower
+from src.inference.faiss_index import FAISSItemIndex
+from src.preprocessing.data_loader import DataProcessor
+from src.preprocessing.user_data_preparation import prepare_user_features
+from src.utils.real_user_selector import RealUserSelector
+
+
+class Enhanced128DRecommendationEngine:
+    """Enhanced recommendation engine with 128D embeddings and all improvements."""
+
+    def __init__(self, artifacts_path: str = "src/artifacts/"):
+        self.artifacts_path = artifacts_path
+        self.embedding_dim = 128  # Fixed to 128D
+
+        # Model components
+        self.item_tower = None
+        self.user_tower = None
+        self.rating_model = None
+        self.faiss_index = None
+        self.data_processor = None
+
+        # Data
+        self.items_df = None
+        self.users_df = None
+        self.income_thresholds = None
+
+        # Load all components
+        self._load_all_components()
+
+    def _load_all_components(self):
+        """Load all enhanced model components."""
+
+        print("Loading enhanced 128D recommendation engine...")
+
+        # Load data processor
+        self.data_processor = DataProcessor()
+        try:
+            self.data_processor.load_vocabularies(f"{self.artifacts_path}/vocabularies.pkl")
+        except FileNotFoundError:
+            print("❌ Vocabularies not found. Please train the model first.")
+            return
+
+        # Load datasets
+        self.items_df = pd.read_csv("datasets/items.csv")
+        self.users_df = pd.read_csv("datasets/users.csv")
+
+        # Load enhanced model components
+        self._load_enhanced_models()
+
+        # Load FAISS index with 128D
+        try:
+            self.faiss_index = FAISSItemIndex(embedding_dim=self.embedding_dim)
+            # Try to load enhanced embeddings first
+            if os.path.exists(f"{self.artifacts_path}/enhanced_item_embeddings.npy"):
+                enhanced_embeddings = np.load(
+                    f"{self.artifacts_path}/enhanced_item_embeddings.npy",
+                    allow_pickle=True
+                ).item()
+                self.faiss_index.build_index(enhanced_embeddings)
+                print("✅ Loaded enhanced 128D FAISS index")
+            else:
+                print("⚠️ Enhanced embeddings not found. Train enhanced model first.")
+                self.faiss_index = None
+        except Exception as e:
+            print(f"⚠️ Could not load FAISS index: {e}")
+            self.faiss_index = None
+
+        # Load income thresholds for categorical demographics
+        self._load_income_thresholds()
+
+        print("✅ Enhanced 128D engine loaded successfully!")
+
+    def _load_enhanced_models(self):
+        """Load enhanced model components."""
+
+        try:
+            # Create model architecture
+            self.item_tower = EnhancedItemTower(
+                item_vocab_size=len(self.data_processor.item_vocab),
+                category_vocab_size=len(self.data_processor.category_vocab),
+                brand_vocab_size=len(self.data_processor.brand_vocab),
+                embedding_dim=self.embedding_dim,
+                use_bias=True,
+                use_diversity_reg=False  # Disable during inference
+            )
+
+            self.user_tower = EnhancedUserTower(
+                max_history_length=50,
+                embedding_dim=self.embedding_dim,
+                use_bias=True,
+                use_diversity_reg=False  # Disable during inference
+            )
+
+            # Create rating model
+            self.rating_model = tf.keras.Sequential([
+                tf.keras.layers.Dense(512, activation="relu"),
+                tf.keras.layers.BatchNormalization(),
+                tf.keras.layers.Dropout(0.3),
+                tf.keras.layers.Dense(256, activation="relu"),
+                tf.keras.layers.BatchNormalization(),
+                tf.keras.layers.Dropout(0.2),
+                tf.keras.layers.Dense(64, activation="relu"),
+                tf.keras.layers.Dense(1, activation="sigmoid")
+            ])
+
+            # Load weights - try enhanced first, fall back to regular
+            model_files = [
+                ('enhanced_item_tower_weights_enhanced_best', 'enhanced_user_tower_weights_enhanced_best', 'enhanced_rating_model_weights_enhanced_best'),
+                ('enhanced_item_tower_weights_enhanced_final', 'enhanced_user_tower_weights_enhanced_final', 'enhanced_rating_model_weights_enhanced_final'),
+            ]
+
+            loaded = False
+            for item_file, user_file, rating_file in model_files:
+                try:
+                    # Need to build models first with dummy data
+                    self._build_models()
+
+                    self.item_tower.load_weights(f"{self.artifacts_path}/{item_file}")
+                    self.user_tower.load_weights(f"{self.artifacts_path}/{user_file}")
+                    self.rating_model.load_weights(f"{self.artifacts_path}/{rating_file}")
+
+                    print(f"✅ Loaded enhanced model: {item_file}")
+                    loaded = True
+                    break
+                except Exception as e:
+                    print(f"⚠️ Could not load {item_file}: {e}")
+                    continue
+
+            if not loaded:
+                print("❌ No enhanced model weights found. Please train enhanced model first.")
+                self.item_tower = None
+                self.user_tower = None
+                self.rating_model = None
+
+        except Exception as e:
+            print(f"❌ Failed to load enhanced models: {e}")
+            self.item_tower = None
+            self.user_tower = None
+            self.rating_model = None
+
+    def _build_models(self):
+        """Build models with dummy data to initialize weights."""
+
+        # Dummy item features
+        dummy_item_features = {
+            'product_id': tf.constant([0]),
+            'category_id': tf.constant([0]),
+            'brand_id': tf.constant([0]),
+            'price': tf.constant([100.0])
+        }
+
+        # Dummy user features
+        dummy_user_features = {
+            'age': tf.constant([2]),     # Adult category
+            'gender': tf.constant([0]),  # Female
+            'income': tf.constant([2]),  # Middle income
+            'item_history_embeddings': tf.constant(np.zeros((1, 50, self.embedding_dim), dtype=np.float32))
+        }
+
+        # Forward pass to build models
+        _ = self.item_tower(dummy_item_features, training=False)
+        _ = self.user_tower(dummy_user_features, training=False)
+
+        # Build rating model
+        dummy_concat = tf.constant(np.zeros((1, self.embedding_dim * 2), dtype=np.float32))
+        _ = self.rating_model(dummy_concat, training=False)
+
+    def _load_income_thresholds(self):
+        """Load income thresholds for categorical processing."""
+
+        # Calculate income thresholds from training data
+        user_incomes = self.users_df['income'].values
+        self.income_thresholds = np.percentile(user_incomes, [0, 20, 40, 60, 80, 100])
+        print(f"Income thresholds: {self.income_thresholds}")
+
+    def categorize_age(self, age: float) -> int:
+        """Categorize age into 6 groups."""
+        if age < 18: return 0    # Teen
+        elif age < 26: return 1  # Young Adult
+        elif age < 36: return 2  # Adult
+        elif age < 51: return 3  # Middle Age
+        elif age < 66: return 4  # Mature
+        else: return 5           # Senior
+
+    def categorize_income(self, income: float) -> int:
+        """Categorize income into 5 percentile groups."""
+        category = np.digitize([income], self.income_thresholds[1:-1])[0]
+        return min(max(category, 0), 4)
+
+    def categorize_gender(self, gender: str) -> int:
+        """Categorize gender."""
+        return 1 if gender.lower() == 'male' else 0
+
+    def get_user_embedding(self,
+                           age: int,
+                           gender: str,
+                           income: float,
+                           interaction_history: List[int] = None) -> np.ndarray:
+        """Generate user embedding with categorical demographics."""
+
+        if self.user_tower is None:
+            print("❌ User tower not loaded")
+            return None
+
+        # Categorize demographics
+        age_cat = self.categorize_age(age)
+        gender_cat = self.categorize_gender(gender)
+        income_cat = self.categorize_income(income)
+
+        # Prepare interaction history embeddings
+        if interaction_history is None:
+            interaction_history = []
+
+        # Get item embeddings for history
+        history_embeddings = np.zeros((50, self.embedding_dim), dtype=np.float32)
+
+        for i, item_id in enumerate(interaction_history[:50]):
+            if self.faiss_index and item_id in self.faiss_index.item_id_to_idx:
+                item_emb = self.faiss_index.get_item_embedding(item_id)
+                if item_emb is not None:
+                    history_embeddings[i] = item_emb
+
+        # Create user features
+        user_features = {
+            'age': tf.constant([age_cat]),
+            'gender': tf.constant([gender_cat]),
+            'income': tf.constant([income_cat]),
+            'item_history_embeddings': tf.constant([history_embeddings])
+        }
+
+        # Get embedding
+        user_output = self.user_tower(user_features, training=False)
+        if isinstance(user_output, tuple):
+            user_embedding = user_output[0].numpy()[0]
+        else:
+            user_embedding = user_output.numpy()[0]
+
+        return user_embedding
+
+    def get_item_embedding(self, item_id: int) -> Optional[np.ndarray]:
+        """Get item embedding."""
+
+        if self.faiss_index:
+            return self.faiss_index.get_item_embedding(item_id)
+
+        # Fallback to model computation
+        if self.item_tower is None:
+            return None
+
+        item_row = self.items_df[self.items_df['product_id'] == item_id]
+        if item_row.empty:
+            return None
+
+        item_data = item_row.iloc[0]
+
+        # Prepare features
+        item_features = {
+            'product_id': tf.constant([self.data_processor.item_vocab.get(item_id, 0)]),
+            'category_id': tf.constant([self.data_processor.category_vocab.get(item_data['category_id'], 0)]),
+            'brand_id': tf.constant([self.data_processor.brand_vocab.get(item_data.get('brand', 'unknown'), 0)]),
+            'price': tf.constant([float(item_data.get('price', 0.0))])
+        }
+
+        # Get embedding
+        item_output = self.item_tower(item_features, training=False)
+        if isinstance(item_output, tuple):
+            item_embedding = item_output[0].numpy()[0]
+        else:
+            item_embedding = item_output.numpy()[0]
+
+        return item_embedding
+
+    def recommend_items_enhanced(self,
+                                 age: int,
+                                 gender: str,
+                                 income: float,
+                                 interaction_history: List[int] = None,
+                                 k: int = 10,
+                                 diversity_weight: float = 0.3,
+                                 category_boost: float = 1.5) -> List[Tuple[int, float, Dict]]:
+        """Generate enhanced recommendations with diversity and category boosting."""
+
+        if not self.faiss_index:
+            print("❌ FAISS index not available")
+            return []
+
+        # Get user embedding
+        user_embedding = self.get_user_embedding(age, gender, income, interaction_history)
+        if user_embedding is None:
+            return []
+
+        # Get candidate recommendations (more than needed for filtering)
+        candidates = self.faiss_index.search_by_embedding(user_embedding, k * 3)
+
+        # Filter out items from interaction history
+        if interaction_history:
+            history_set = set(interaction_history)
+            candidates = [(item_id, score) for item_id, score in candidates
+                          if item_id not in history_set]
+
+        # Add item metadata and apply enhancements
+        enhanced_candidates = []
+
+        for item_id, similarity_score in candidates[:k * 2]:
+            # Get item info
+            item_row = self.items_df[self.items_df['product_id'] == item_id]
+            if item_row.empty:
+                continue
+
+            item_info = item_row.iloc[0].to_dict()
+
+            # Enhanced scoring with multiple factors
+            final_score = similarity_score
+
+            # Category boosting based on user history
+            if interaction_history and category_boost > 1.0:
+                user_categories = self._get_user_categories(interaction_history)
+                item_category = item_info.get('category_code', '')
+
+                if item_category in user_categories:
+                    category_preference = user_categories[item_category]
+                    final_score *= (1 + (category_boost - 1) * category_preference)
+
+            enhanced_candidates.append((item_id, final_score, item_info))
+
+        # Sort by enhanced scores
+        enhanced_candidates.sort(key=lambda x: x[1], reverse=True)
+
+        # Apply diversity filtering
+        if diversity_weight > 0:
+            diversified_candidates = self._apply_diversity_filter(
+                enhanced_candidates, diversity_weight
+            )
+        else:
+            diversified_candidates = enhanced_candidates
+
+        return diversified_candidates[:k]
+
+    def _get_user_categories(self, interaction_history: List[int]) -> Dict[str, float]:
+        """Get user's category preferences from history."""
+
+        category_counts = Counter()
+
+        for item_id in interaction_history:
+            item_row = self.items_df[self.items_df['product_id'] == item_id]
+            if not item_row.empty:
+                category = item_row.iloc[0].get('category_code', 'Unknown')
+                category_counts[category] += 1
+
+        # Convert to preferences (percentages)
+        total = sum(category_counts.values())
+        if total == 0:
+            return {}
+
+        return {cat: count / total for cat, count in category_counts.items()}
+
+    def _apply_diversity_filter(self,
+                                candidates: List[Tuple[int, float, Dict]],
+                                diversity_weight: float,
+                                max_per_category: int = 3) -> List[Tuple[int, float, Dict]]:
+        """Apply diversity filtering to recommendations."""
+
+        category_counts = defaultdict(int)
+        diversified = []
+
+        for item_id, score, item_info in candidates:
+            category = item_info.get('category_code', 'Unknown')
+
+            # Apply diversity penalty
+            if category_counts[category] >= max_per_category:
+                # Penalty for over-representation
+                diversity_penalty = diversity_weight * (category_counts[category] - max_per_category + 1)
+                adjusted_score = score * (1 - diversity_penalty)
+            else:
+                adjusted_score = score
+
+            diversified.append((item_id, adjusted_score, item_info))
+            category_counts[category] += 1
+
+        # Re-sort by adjusted scores
+        diversified.sort(key=lambda x: x[1], reverse=True)
+        return diversified
+
+    def predict_rating(self,
+                       age: int,
+                       gender: str,
+                       income: float,
+                       item_id: int,
+                       interaction_history: List[int] = None) -> float:
+        """Predict rating for user-item pair."""
+
+        if self.rating_model is None:
+            return 0.5  # Default rating
+
+        # Get embeddings
+        user_embedding = self.get_user_embedding(age, gender, income, interaction_history)
+        item_embedding = self.get_item_embedding(item_id)
+
+        if user_embedding is None or item_embedding is None:
+            return 0.5
+
+        # Concatenate embeddings
+        combined = np.concatenate([user_embedding, item_embedding])
+        combined = tf.constant([combined])
+
+        # Predict rating
+        rating = self.rating_model(combined, training=False)
+        return float(rating.numpy()[0][0])
+
+
+def demo_enhanced_engine():
+    """Demo the enhanced 128D recommendation engine."""
+
+    print("🚀 ENHANCED 128D RECOMMENDATION ENGINE DEMO")
+    print("="*70)
+
+    try:
+        # Initialize engine
+        engine = Enhanced128DRecommendationEngine()
+
+        if engine.item_tower is None:
+            print("❌ Enhanced model not available. Please train first using:")
+            print("   python train_enhanced_model.py")
+            return
+
+        # Get real user for testing
+        real_user_selector = RealUserSelector()
+        test_users = real_user_selector.get_real_users(n=2, min_interactions=10)
+
+        for user in test_users:
+            print(f"\n📊 Testing User {user['user_id']} ({user['age']}yr {user['gender']}):")
+            print(f"   Income: ${user['income']:,}")
+            print(f"   History: {len(user['interaction_history'])} items")
+
+            # Test enhanced recommendations
+            try:
+                recs = engine.recommend_items_enhanced(
+                    age=user['age'],
+                    gender=user['gender'],
+                    income=user['income'],
+                    interaction_history=user['interaction_history'][:20],
+                    k=10,
+                    diversity_weight=0.3,
+                    category_boost=1.5
+                )
+
+                print(f"   🎯 Enhanced Recommendations:")
+                categories = []
+                for i, (item_id, score, item_info) in enumerate(recs[:5]):
+                    category = item_info.get('category_code', 'Unknown')[:30]
+                    price = item_info.get('price', 0)
+                    categories.append(category)
+                    print(f"   #{i+1} Item {item_id}: {score:.4f} | ${price:.2f} | {category}")
+
+                # Analyze diversity
+                unique_categories = len(set(categories))
+                print(f"   📈 Diversity: {unique_categories}/{len(categories)} unique categories")
+
+                # Test rating prediction
+                if recs:
+                    test_item = recs[0][0]
+                    predicted_rating = engine.predict_rating(
+                        age=user['age'],
+                        gender=user['gender'],
+                        income=user['income'],
+                        item_id=test_item,
+                        interaction_history=user['interaction_history'][:20]
+                    )
+                    print(f"   ⭐ Rating prediction for item {test_item}: {predicted_rating:.3f}")
+
+            except Exception as e:
+                print(f"   ❌ Error: {e}")
+
+        print(f"\n✅ Enhanced 128D engine demo completed!")
+
+    except Exception as e:
+        print(f"❌ Demo failed: {e}")
+        import traceback
+        traceback.print_exc()
+
+
+if __name__ == "__main__":
+    demo_enhanced_engine()
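The `_apply_diversity_filter` method above caps each category at `max_per_category` picks and scales later hits in the same category down by a growing penalty. A self-contained sketch of the same logic on toy candidates (hypothetical item IDs, scores, and categories, not project data):

```python
from collections import defaultdict

def apply_diversity_filter(candidates, diversity_weight=0.3, max_per_category=3):
    """Penalize items whose category already has >= max_per_category picks."""
    category_counts = defaultdict(int)
    diversified = []
    for item_id, score, info in candidates:
        category = info.get("category_code", "Unknown")
        if category_counts[category] >= max_per_category:
            # Penalty grows with each additional over-represented item.
            penalty = diversity_weight * (category_counts[category] - max_per_category + 1)
            score = score * (1 - penalty)
        diversified.append((item_id, score, info))
        category_counts[category] += 1
    diversified.sort(key=lambda x: x[1], reverse=True)
    return diversified

# Toy candidates, already sorted by raw similarity: four phones, one laptop.
cands = [(i, s, {"category_code": c}) for i, s, c in [
    (1, 0.95, "phone"), (2, 0.94, "phone"), (3, 0.93, "phone"),
    (4, 0.92, "phone"), (5, 0.80, "laptop"),
]]
out = apply_diversity_filter(cands)
# The 4th phone is penalized: 0.92 * (1 - 0.3) = 0.644, so the laptop outranks it.
print([item_id for item_id, _, _ in out])  # [1, 2, 3, 5, 4]
```

The penalty is multiplicative, so a strongly matched item can still survive the filter; only marginal over-represented candidates get reshuffled below other categories.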
@@ -8,7 +8,7 @@ from typing import Dict, List, Tuple, Optional
 class FAISSItemIndex:
     """FAISS-based item similarity search index."""
 
-    def __init__(self, embedding_dim: int =
+    def __init__(self, embedding_dim: int = 128):
         self.embedding_dim = embedding_dim
         self.index = None
         self.item_id_to_idx = {}
@@ -40,14 +40,10 @@ class FAISSItemIndex:
             # Exact search (slower but accurate)
             self.index = faiss.IndexFlatIP(self.embedding_dim)
         elif index_type == "IVF":
-            #
-
-
-            self.index = faiss.
-
-            # Train the index
-            self.index.train(embeddings_array)
-            self.index.nprobe = min(10, nlist)  # Search in top 10 clusters
+            # For CPU use exact search (IndexFlatIP) for better accuracy
+            # IVF is mainly beneficial for GPU, for CPU stick with exact search
+            print("Using IndexFlatIP for CPU (exact search)")
+            self.index = faiss.IndexFlatIP(self.embedding_dim)
         else:
             raise ValueError(f"Unsupported index type: {index_type}")
 
@@ -134,6 +130,7 @@ class FAISSItemIndex:
         sample_queries = list(self.item_id_to_idx.keys())[:5]
 
         print("Validating FAISS index...")
+        print("Note: Higher similarity scores = more similar items (cosine similarity)")
 
         for query_item in sample_queries:
            if query_item not in self.item_id_to_idx:
@@ -141,9 +138,16 @@ class FAISSItemIndex:
 
            similar_items = self.search_similar_items(query_item, k=5)
 
-           print(f"\nSimilar items to {query_item}:")
-           for item_id, score in similar_items:
-               print(f"  Item {item_id}: similarity = {score:.4f}")
+           print(f"\nSimilar items to {query_item} (sorted by similarity DESC):")
+           for i, (item_id, score) in enumerate(similar_items):
+               print(f"  #{i+1} Item {item_id}: similarity = {score:.4f}")
+
+           # Check if scores are properly ordered (descending)
+           scores = [score for _, score in similar_items]
+           if len(scores) > 1 and not all(scores[i] >= scores[i+1] for i in range(len(scores)-1)):
+               print(f"  WARNING: Scores not in descending order! {scores}")
+           else:
+               print(f"  ✓ Scores properly ordered (most to least similar)")
 
     def save_index(self, save_path: str = "src/artifacts/") -> None:
         """Save FAISS index and mappings."""
@@ -197,7 +201,7 @@ def main():
 
     # Create and build FAISS index
     print("Building FAISS index...")
-    faiss_index = FAISSItemIndex(embedding_dim=
+    faiss_index = FAISSItemIndex(embedding_dim=128)
     faiss_index.build_index(item_embeddings, index_type="IVF")
 
     # Validate index
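`IndexFlatIP` ranks by inner product, which coincides with cosine similarity once the embeddings are L2-normalized — the reason the validation output above can describe the scores as cosine similarity. A dependency-free sketch of that equivalence with toy 3-D vectors (not real item embeddings):

```python
import math

def normalize(v):
    """Scale a vector to unit L2 norm."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def inner(a, b):
    """Plain inner (dot) product - what IndexFlatIP scores with."""
    return sum(x * y for x, y in zip(a, b))

a = normalize([1.0, 2.0, 2.0])
b = normalize([2.0, 4.0, 4.0])   # same direction as a, different magnitude
c = normalize([2.0, -1.0, 0.0])  # orthogonal to a

print(round(inner(a, b), 4))       # 1.0  (identical direction -> max similarity)
print(abs(round(inner(a, c), 4)))  # 0.0  (orthogonal -> no similarity)
```

Because normalization removes magnitude, the ordering returned by an inner-product index on unit vectors is exactly the cosine-similarity ordering, so higher scores really do mean more similar items.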
@@ -129,8 +129,8 @@ class RecommendationEngine:
 
         self.user_tower = UserTower(
             max_history_length=50,
-            embedding_dim=
-            hidden_dims=[128, 64],
+            embedding_dim=128,  # Changed from 64 to 128
+            hidden_dims=[128, 64],  # Match training architecture
             dropout_rate=0.2
         )
 
@@ -139,7 +139,7 @@ class RecommendationEngine:
             'age': tf.constant([2]),     # Adult category (26-35)
             'gender': tf.constant([1]),  # Male
             'income': tf.constant([2]),  # Middle income category
-            'item_history_embeddings': tf.constant([[[0.0] *
+            'item_history_embeddings': tf.constant([[[0.0] * 128] * 50])  # Changed from 64 to 128
         }
         _ = self.user_tower(dummy_input)
 
@@ -182,7 +182,7 @@ class RecommendationEngine:
         ])
 
         # Build model with dummy input (concatenated user and item embeddings)
-        dummy_input = tf.constant([[0.0] *
+        dummy_input = tf.constant([[0.0] * 256])  # 128 + 128 = 256
         _ = self.rating_model(dummy_input)
 
         try:
@@ -216,14 +216,16 @@ class RecommendationEngine:
                 history_embeddings.append(embedding)
             else:
                 # Use zero embedding for unknown items
-                history_embeddings.append(np.zeros(
+                history_embeddings.append(np.zeros(128))  # Changed from 64 to 128
 
         # Pad or truncate to max_history_length
         max_history_length = 50
         if len(history_embeddings) < max_history_length:
-            padding
-
+            # Add padding at the END so real interactions are at the BEGINNING
+            padding = [np.zeros(128)] * (max_history_length - len(history_embeddings))
+            history_embeddings = history_embeddings + padding
         else:
+            # Keep most recent interactions
             history_embeddings = history_embeddings[-max_history_length:]
 
         history_embeddings = np.array(history_embeddings, dtype=np.float32)
 
@@ -301,27 +303,52 @@ class RecommendationEngine:
             income: float,
             interaction_history: List[int] = None,
             k: int = 10,
-            exclude_history: bool = True
-
+            exclude_history: bool = True,
+            category_boost: float = 1.3) -> List[Tuple[int, float, Dict]]:
 
         # Get user embedding
         user_embedding = self.get_user_embedding(age, gender, income, interaction_history)
 
-        # Find similar items using FAISS
-        similar_items = self.faiss_index.search_by_embedding(user_embedding, k *
 
-        #
-
-
-
 
-        #
-
 
         # Add item metadata
         recommendations = []
-        for item_id, score in
             item_info = self._get_item_info(item_id)
             recommendations.append((item_id, score, item_info))
|
| 308 |
+
"""Generate recommendations using collaborative filtering with category awareness."""
|
| 309 |
|
| 310 |
# Get user embedding
|
| 311 |
user_embedding = self.get_user_embedding(age, gender, income, interaction_history)
|
| 312 |
|
| 313 |
+
# Find similar items using FAISS (get more candidates for boosting)
|
| 314 |
+
similar_items = self.faiss_index.search_by_embedding(user_embedding, k * 4)
|
| 315 |
|
| 316 |
+
# Get user's preferred categories from interaction history
|
| 317 |
+
user_categories = set()
|
| 318 |
+
if interaction_history:
|
| 319 |
+
for item_id in interaction_history[-10:]: # Focus on recent interactions
|
| 320 |
+
item_row = self.items_df[self.items_df['product_id'] == item_id]
|
| 321 |
+
if len(item_row) > 0:
|
| 322 |
+
user_categories.add(item_row.iloc[0]['category_code'])
|
| 323 |
+
|
| 324 |
+
# Filter out interaction history and apply category boosting
|
| 325 |
+
boosted_items = []
|
| 326 |
+
history_set = set(interaction_history) if (exclude_history and interaction_history) else set()
|
| 327 |
+
|
| 328 |
+
for item_id, score in similar_items:
|
| 329 |
+
if item_id in history_set:
|
| 330 |
+
continue
|
| 331 |
+
|
| 332 |
+
# Get item category
|
| 333 |
+
item_row = self.items_df[self.items_df['product_id'] == item_id]
|
| 334 |
+
if len(item_row) > 0:
|
| 335 |
+
item_category = item_row.iloc[0]['category_code']
|
| 336 |
+
|
| 337 |
+
# Boost score if item is in user's preferred categories
|
| 338 |
+
if item_category in user_categories:
|
| 339 |
+
boosted_score = score * category_boost
|
| 340 |
+
else:
|
| 341 |
+
boosted_score = score
|
| 342 |
+
|
| 343 |
+
boosted_items.append((item_id, boosted_score))
|
| 344 |
|
| 345 |
+
# Sort by boosted score and take top k
|
| 346 |
+
boosted_items.sort(key=lambda x: x[1], reverse=True)
|
| 347 |
+
boosted_items = boosted_items[:k]
|
| 348 |
|
| 349 |
# Add item metadata
|
| 350 |
recommendations = []
|
| 351 |
+
for item_id, score in boosted_items:
|
| 352 |
item_info = self._get_item_info(item_id)
|
| 353 |
recommendations.append((item_id, score, item_info))
|
| 354 |
|
|
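The new recommend path boils down to a simple re-ranking step: drop already-seen items, multiply the FAISS similarity by `category_boost` when the item's category matches the user's recent history, then take the top k. A minimal standalone sketch (the candidate tuples and category strings here are toy stand-ins for the FAISS results and `items_df` lookups):

```python
def rerank_with_category_boost(candidates, user_categories, history, k=10, boost=1.3):
    """Re-rank (item_id, score, category) candidates: drop seen items,
    boost scores for the user's preferred categories, keep top k."""
    boosted = []
    for item_id, score, category in candidates:
        if item_id in history:
            continue  # exclude items the user already interacted with
        factor = boost if category in user_categories else 1.0
        boosted.append((item_id, score * factor))
    boosted.sort(key=lambda x: x[1], reverse=True)
    return boosted[:k]

# Toy candidates: (item_id, raw similarity, category)
candidates = [(1, 0.9, "tv"), (2, 0.8, "phone"), (3, 0.7, "phone"), (4, 0.6, "tv")]
top = rerank_with_category_boost(candidates, {"phone"}, history={1}, k=2)
```

Note how item 2 overtakes nothing here but item 4 drops out: the boost lets a slightly lower raw similarity in a preferred category win the tie against an out-of-category item.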
src/models/enhanced_two_tower.py (new file, +574 lines)
@@ -0,0 +1,574 @@
#!/usr/bin/env python3
"""
Enhanced two-tower model with embedding diversity regularization and improved discrimination.
"""

import tensorflow as tf
import tensorflow_recommenders as tfrs
import numpy as np


class EmbeddingDiversityRegularizer(tf.keras.layers.Layer):
    """Regularizer to prevent embedding collapse by enforcing diversity."""

    def __init__(self, diversity_weight=0.01, orthogonality_weight=0.05, **kwargs):
        super().__init__(**kwargs)
        self.diversity_weight = diversity_weight
        self.orthogonality_weight = orthogonality_weight

    def call(self, embeddings):
        """Apply diversity regularization to embeddings."""
        batch_size = tf.shape(embeddings)[0]

        # Compute pairwise cosine similarities
        normalized_embeddings = tf.nn.l2_normalize(embeddings, axis=1)
        similarity_matrix = tf.linalg.matmul(
            normalized_embeddings, normalized_embeddings, transpose_b=True
        )

        # Remove diagonal (self-similarities)
        mask = 1.0 - tf.eye(batch_size)
        masked_similarities = similarity_matrix * mask

        # Diversity loss: penalize high similarities between different embeddings
        diversity_loss = tf.reduce_mean(tf.square(masked_similarities))

        # Orthogonality loss: encourage embeddings to be orthogonal
        identity_target = tf.eye(batch_size)
        orthogonality_loss = tf.reduce_mean(
            tf.square(similarity_matrix - identity_target)
        )

        # Add as regularization losses
        self.add_loss(self.diversity_weight * diversity_loss)
        self.add_loss(self.orthogonality_weight * orthogonality_loss)

        return embeddings

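The diversity term above is just the mean squared off-diagonal cosine similarity within a batch: identical (collapsed) embeddings score high, mutually orthogonal embeddings score zero. A NumPy mirror of that term makes the behavior easy to check:

```python
import numpy as np

def diversity_penalty(embeddings):
    """NumPy mirror of EmbeddingDiversityRegularizer's diversity term:
    mean squared off-diagonal cosine similarity within the batch."""
    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    sim = normed @ normed.T                      # pairwise cosine similarities
    mask = 1.0 - np.eye(len(embeddings))         # zero out the diagonal
    return float(np.mean((sim * mask) ** 2))

collapsed = np.tile([1.0, 0.0], (4, 1))   # all four embeddings identical
orthogonal = np.eye(4)                    # four mutually orthogonal embeddings
```

For the collapsed batch every off-diagonal similarity is 1, so the penalty is 12/16 = 0.75; for the orthogonal batch it is exactly 0.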
class AdaptiveTemperatureScaling(tf.keras.layers.Layer):
    """Advanced temperature scaling with learned parameters."""

    def __init__(self, initial_temperature=1.0, min_temp=0.1, max_temp=5.0, **kwargs):
        super().__init__(**kwargs)
        self.initial_temperature = initial_temperature
        self.min_temp = min_temp
        self.max_temp = max_temp

    def build(self, input_shape):
        # Learnable temperature with constraints
        self.raw_temperature = self.add_weight(
            name='raw_temperature',
            shape=(),
            initializer=tf.keras.initializers.Constant(
                np.log(self.initial_temperature - self.min_temp)
            ),
            trainable=True
        )

        # Learnable bias term for better discrimination
        self.similarity_bias = self.add_weight(
            name='similarity_bias',
            shape=(),
            initializer=tf.keras.initializers.Zeros(),
            trainable=True
        )

        super().build(input_shape)

    def call(self, user_embeddings, item_embeddings):
        """Compute adaptive temperature-scaled similarity with bias."""
        # Constrain temperature to valid range
        temperature = self.min_temp + tf.nn.softplus(self.raw_temperature)
        temperature = tf.minimum(temperature, self.max_temp)

        # Compute similarities
        similarities = tf.reduce_sum(user_embeddings * item_embeddings, axis=1)

        # Add learnable bias and apply temperature scaling
        scaled_similarities = (similarities + self.similarity_bias) / temperature

        return scaled_similarities, temperature

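One subtlety worth noting in the layer above: the temperature is constrained as `min_temp + softplus(raw)`, but the initializer sets `raw = log(t0 - min_temp)`, and since `softplus(log(x)) = log(1 + x)` this does not recover `t0` exactly. A small sketch of the math (pure Python, no TF needed); the exact round-trip would use the inverse softplus `log(exp(t0 - min_temp) - 1)`:

```python
import math

def softplus(x):
    """softplus(x) = log(1 + exp(x)), the constraint used by the layer."""
    return math.log1p(math.exp(x))

min_temp, t0 = 0.1, 1.0

# Initializer used above: raw = log(t0 - min_temp) -> starts slightly below t0
approx = min_temp + softplus(math.log(t0 - min_temp))   # 0.1 + log(1.9) ~ 0.74

# Inverse-softplus init would recover t0 exactly
raw_exact = math.log(math.exp(t0 - min_temp) - 1.0)
exact = min_temp + softplus(raw_exact)                  # == 1.0
```

This is only an observation about the starting point; since `raw_temperature` is trainable, the model can move the temperature wherever the loss prefers within `[min_temp, max_temp]`.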
class EnhancedItemTower(tf.keras.Model):
    """Enhanced item tower with diversity regularization."""

    def __init__(self,
                 item_vocab_size: int,
                 category_vocab_size: int,
                 brand_vocab_size: int,
                 embedding_dim: int = 128,
                 hidden_dims: list = [256, 128],
                 dropout_rate: float = 0.3,
                 use_bias: bool = True,
                 use_diversity_reg: bool = True):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.use_bias = use_bias
        self.use_diversity_reg = use_diversity_reg

        # Embedding layers with better initialization
        self.item_embedding = tf.keras.layers.Embedding(
            item_vocab_size, embedding_dim,
            embeddings_initializer='he_normal',  # Better initialization
            embeddings_regularizer=tf.keras.regularizers.L2(1e-6),
            name="item_embedding"
        )
        self.category_embedding = tf.keras.layers.Embedding(
            category_vocab_size, embedding_dim,
            embeddings_initializer='he_normal',
            embeddings_regularizer=tf.keras.regularizers.L2(1e-6),
            name="category_embedding"
        )
        self.brand_embedding = tf.keras.layers.Embedding(
            brand_vocab_size, embedding_dim,
            embeddings_initializer='he_normal',
            embeddings_regularizer=tf.keras.regularizers.L2(1e-6),
            name="brand_embedding"
        )

        # Price processing
        self.price_normalization = tf.keras.layers.Normalization(name="price_norm")
        self.price_projection = tf.keras.layers.Dense(
            embedding_dim // 4, activation='relu', name="price_proj"
        )

        # Enhanced attention mechanism
        self.feature_attention = tf.keras.layers.MultiHeadAttention(
            num_heads=4,
            key_dim=embedding_dim,
            dropout=0.1,
            name="feature_attention"
        )

        # Dense layers with residual connections
        self.dense_layers = []
        for i, dim in enumerate(hidden_dims):
            self.dense_layers.extend([
                tf.keras.layers.Dense(dim, activation=None, name=f"dense_{i}"),
                tf.keras.layers.BatchNormalization(name=f"bn_{i}"),
                tf.keras.layers.Activation('relu', name=f"relu_{i}"),
                tf.keras.layers.Dropout(dropout_rate, name=f"dropout_{i}")
            ])

        # Output layer with controlled normalization
        self.output_layer = tf.keras.layers.Dense(
            embedding_dim, activation=None, use_bias=use_bias, name="item_output"
        )

        # Diversity regularizer
        if use_diversity_reg:
            self.diversity_regularizer = EmbeddingDiversityRegularizer()

        # Adaptive normalization instead of hard L2 normalization
        self.adaptive_norm = tf.keras.layers.LayerNormalization(name="adaptive_norm")

        # Item bias
        if use_bias:
            self.item_bias = tf.keras.layers.Embedding(
                item_vocab_size, 1, name="item_bias"
            )

    def call(self, inputs, training=None):
        """Enhanced forward pass with diversity regularization."""
        item_id = inputs["product_id"]
        category_id = inputs["category_id"]
        brand_id = inputs["brand_id"]
        price = inputs["price"]

        # Get embeddings
        item_emb = self.item_embedding(item_id)
        category_emb = self.category_embedding(category_id)
        brand_emb = self.brand_embedding(brand_id)

        # Process price
        price_norm = self.price_normalization(tf.expand_dims(price, -1))
        price_emb = self.price_projection(price_norm)

        # Pad price embedding
        price_emb_padded = tf.pad(
            price_emb,
            [[0, 0], [0, self.embedding_dim - tf.shape(price_emb)[-1]]]
        )

        # Stack features for attention
        features = tf.stack([item_emb, category_emb, brand_emb, price_emb_padded], axis=1)

        # Apply attention
        attended_features = self.feature_attention(
            query=features,
            value=features,
            key=features,
            training=training
        )

        # Aggregate with residual connection
        combined = tf.reduce_mean(attended_features + features, axis=1)

        # Pass through dense layers with residual connections
        x = combined
        residual = x
        for i, layer in enumerate(self.dense_layers):
            x = layer(x, training=training)
            # Add residual connection every 4 layers (complete block)
            if (i + 1) % 4 == 0 and x.shape[-1] == residual.shape[-1]:
                x = x + residual
                residual = x

        # Final output
        output = self.output_layer(x)

        # Apply diversity regularization if enabled
        if self.use_diversity_reg and training:
            output = self.diversity_regularizer(output)

        # Adaptive normalization instead of hard L2
        normalized_output = self.adaptive_norm(output)

        # Add bias if enabled
        if self.use_bias:
            bias = tf.squeeze(self.item_bias(item_id), axis=-1)
            return normalized_output, bias
        else:
            return normalized_output

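The residual pattern in the tower's dense stack is easy to miss: each logical block is four Keras layers (Dense, BatchNorm, ReLU, Dropout), so the skip connection fires only after every fourth layer, and only when the block preserved the feature width. A NumPy sketch of that control flow (the `+1.0` "layers" are toy stand-ins):

```python
import numpy as np

def apply_blocks(x, layers, block_size=4):
    """Mimic the tower's loop: add a skip connection after every complete
    block of `block_size` layers, only when widths still match."""
    residual = x
    for i, layer in enumerate(layers):
        x = layer(x)
        if (i + 1) % block_size == 0 and x.shape[-1] == residual.shape[-1]:
            x = x + residual   # skip connection over the whole block
            residual = x       # next block's skip starts here
    return x

# Four width-preserving "layers" (each just adds 1) form one residual block.
layers = [lambda v: v + 1.0] * 4
out = apply_blocks(np.zeros((1, 8)), layers)

# If the last layer changes the width, the skip is silently skipped.
narrowing = [lambda v: v + 1.0] * 3 + [lambda v: v[:, :4]]
out_narrow = apply_blocks(np.zeros((1, 8)), narrowing)
```

With `hidden_dims=[256, 128]` the widths of consecutive blocks differ, so in practice the shape check disables most of these skips; that is worth knowing when tuning the architecture.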
class EnhancedUserTower(tf.keras.Model):
    """Enhanced user tower with diversity regularization."""

    def __init__(self,
                 max_history_length: int = 50,
                 embedding_dim: int = 128,
                 hidden_dims: list = [256, 128],
                 dropout_rate: float = 0.3,
                 use_bias: bool = True,
                 use_diversity_reg: bool = True):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.max_history_length = max_history_length
        self.use_bias = use_bias
        self.use_diversity_reg = use_diversity_reg

        # Demographic embeddings with regularization
        self.age_embedding = tf.keras.layers.Embedding(
            6, embedding_dim // 16,
            embeddings_initializer='he_normal',
            embeddings_regularizer=tf.keras.regularizers.L2(1e-6),
            name="age_embedding"
        )
        self.income_embedding = tf.keras.layers.Embedding(
            5, embedding_dim // 16,
            embeddings_initializer='he_normal',
            embeddings_regularizer=tf.keras.regularizers.L2(1e-6),
            name="income_embedding"
        )
        self.gender_embedding = tf.keras.layers.Embedding(
            2, embedding_dim // 16,
            embeddings_initializer='he_normal',
            embeddings_regularizer=tf.keras.regularizers.L2(1e-6),
            name="gender_embedding"
        )

        # Enhanced history processing
        self.history_transformer = tf.keras.layers.MultiHeadAttention(
            num_heads=8,
            key_dim=embedding_dim,
            dropout=0.1,
            name="history_transformer"
        )

        # History aggregation with attention pooling
        self.history_attention_pooling = tf.keras.layers.Dense(
            1, activation=None, name="history_attention"
        )

        # Dense layers with residual connections
        self.dense_layers = []
        for i, dim in enumerate(hidden_dims):
            self.dense_layers.extend([
                tf.keras.layers.Dense(dim, activation=None, name=f"user_dense_{i}"),
                tf.keras.layers.BatchNormalization(name=f"user_bn_{i}"),
                tf.keras.layers.Activation('relu', name=f"user_relu_{i}"),
                tf.keras.layers.Dropout(dropout_rate, name=f"user_dropout_{i}")
            ])

        # Output layer
        self.output_layer = tf.keras.layers.Dense(
            embedding_dim, activation=None, use_bias=use_bias, name="user_output"
        )

        # Diversity regularizer
        if use_diversity_reg:
            self.diversity_regularizer = EmbeddingDiversityRegularizer()

        # Adaptive normalization
        self.adaptive_norm = tf.keras.layers.LayerNormalization(name="user_adaptive_norm")

        # Global user bias
        if use_bias:
            self.global_user_bias = tf.Variable(
                initial_value=0.0, trainable=True, name="global_user_bias"
            )

    def call(self, inputs, training=None):
        """Enhanced forward pass with diversity regularization."""
        age = inputs["age"]
        gender = inputs["gender"]
        income = inputs["income"]
        item_history = inputs["item_history_embeddings"]

        # Process demographics
        age_emb = self.age_embedding(age)
        income_emb = self.income_embedding(income)
        gender_emb = self.gender_embedding(gender)

        # Combine demographics
        demo_combined = tf.concat([age_emb, income_emb, gender_emb], axis=-1)

        # Enhanced history processing
        batch_size = tf.shape(item_history)[0]
        seq_len = tf.shape(item_history)[1]

        # Simplified positional encoding - ensure shape compatibility
        positions = tf.range(seq_len, dtype=tf.float32)
        # Create simpler positional encoding
        pos_encoding_scale = tf.cast(tf.range(self.embedding_dim, dtype=tf.float32), tf.float32) / self.embedding_dim
        position_encoding = tf.sin(positions[:, tf.newaxis] * pos_encoding_scale[tf.newaxis, :])

        # Ensure correct shape: [seq_len, embedding_dim] -> [batch_size, seq_len, embedding_dim]
        position_encoding = tf.expand_dims(position_encoding, 0)
        position_encoding = tf.tile(position_encoding, [batch_size, 1, 1])

        # Add positional encoding with shape check
        history_with_pos = item_history + position_encoding

        # Create attention mask - fix shape for MultiHeadAttention
        # MultiHeadAttention expects mask shape: [batch_size, seq_len] or [batch_size, seq_len, seq_len]
        history_mask = tf.reduce_sum(tf.abs(item_history), axis=-1) > 0  # [batch_size, seq_len]

        # Apply transformer attention
        attended_history = self.history_transformer(
            query=history_with_pos,
            value=history_with_pos,
            key=history_with_pos,
            attention_mask=history_mask,
            training=training
        )

        # Attention-based pooling instead of simple mean
        attention_weights = tf.nn.softmax(
            self.history_attention_pooling(attended_history), axis=1
        )
        history_aggregated = tf.reduce_sum(
            attended_history * attention_weights, axis=1
        )

        # Combine features
        combined = tf.concat([demo_combined, history_aggregated], axis=-1)

        # Pass through dense layers with residual connections
        x = combined
        residual = x
        for i, layer in enumerate(self.dense_layers):
            x = layer(x, training=training)
            # Add residual connection every 4 layers
            if (i + 1) % 4 == 0 and x.shape[-1] == residual.shape[-1]:
                x = x + residual
                residual = x

        # Final output
        output = self.output_layer(x)

        # Apply diversity regularization if enabled
        if self.use_diversity_reg and training:
            output = self.diversity_regularizer(output)

        # Adaptive normalization
        normalized_output = self.adaptive_norm(output)

        # Add bias if enabled
        if self.use_bias:
            return normalized_output, self.global_user_bias
        else:
            return normalized_output

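The "simplified positional encoding" above is a single broadcasted `sin(position * scale)` rather than the usual interleaved sin/cos scheme: positions of shape `(seq_len,)` times per-dimension scales of shape `(dim,)` broadcast to `(seq_len, dim)`. A NumPy mirror confirms the shapes and two easy-to-predict properties (row 0 and column 0 are all zeros, since `sin(0) = 0`):

```python
import numpy as np

def simple_position_encoding(seq_len, dim):
    """NumPy mirror of the user tower's simplified positional encoding."""
    positions = np.arange(seq_len, dtype=np.float32)        # (seq_len,)
    scale = np.arange(dim, dtype=np.float32) / dim          # (dim,)
    # Outer-product broadcast -> (seq_len, dim)
    return np.sin(positions[:, None] * scale[None, :])

pe = simple_position_encoding(50, 128)  # matches max_history_length=50, dim=128
```

One consequence of this scheme: dimension 0 carries no positional signal at all (its scale is 0), which is harmless here since the encoding is only an additive hint before the history transformer.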
class EnhancedTwoTowerModel(tfrs.Model):
    """Enhanced two-tower model with all improvements."""

    def __init__(self,
                 item_tower: EnhancedItemTower,
                 user_tower: EnhancedUserTower,
                 rating_weight: float = 1.0,
                 retrieval_weight: float = 1.0,
                 contrastive_weight: float = 0.3,
                 diversity_weight: float = 0.1):
        super().__init__()

        self.item_tower = item_tower
        self.user_tower = user_tower
        self.rating_weight = rating_weight
        self.retrieval_weight = retrieval_weight
        self.contrastive_weight = contrastive_weight
        self.diversity_weight = diversity_weight

        # Adaptive temperature scaling
        self.temperature_similarity = AdaptiveTemperatureScaling()

        # Enhanced rating model
        self.rating_model = tf.keras.Sequential([
            tf.keras.layers.Dense(512, activation="relu"),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Dropout(0.3),
            tf.keras.layers.Dense(256, activation="relu"),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(64, activation="relu"),
            tf.keras.layers.Dense(1, activation="sigmoid")
        ])

        # Focal loss for imbalanced data
        self.focal_loss = self._focal_loss

    def _focal_loss(self, y_true, y_pred, alpha=0.25, gamma=2.0):
        """Focal loss implementation."""
        epsilon = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, epsilon, 1.0 - epsilon)

        alpha_t = y_true * alpha + (1 - y_true) * (1 - alpha)
        p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
        focal_weight = alpha_t * tf.pow((1 - p_t), gamma)

        bce = -(y_true * tf.math.log(y_pred) + (1 - y_true) * tf.math.log(1 - y_pred))
        focal_loss = focal_weight * bce

        return tf.reduce_mean(focal_loss)

    def call(self, features):
        # Get embeddings
        user_output = self.user_tower(features)
        item_output = self.item_tower(features)

        # Handle bias terms
        if isinstance(user_output, tuple):
            user_embeddings, user_bias = user_output
        else:
            user_embeddings = user_output
            user_bias = 0.0

        if isinstance(item_output, tuple):
            item_embeddings, item_bias = item_output
        else:
            item_embeddings = item_output
            item_bias = 0.0

        return {
            "user_embedding": user_embeddings,
            "item_embedding": item_embeddings,
            "user_bias": user_bias,
            "item_bias": item_bias
        }

    def compute_loss(self, features, training=False):
        # Get embeddings and biases
        outputs = self(features)
        user_embeddings = outputs["user_embedding"]
        item_embeddings = outputs["item_embedding"]
        user_bias = outputs["user_bias"]
        item_bias = outputs["item_bias"]

        # Rating prediction
        concatenated = tf.concat([user_embeddings, item_embeddings], axis=-1)
        rating_predictions = self.rating_model(concatenated, training=training)

        # Add bias terms
        rating_predictions_with_bias = rating_predictions + user_bias + item_bias
        rating_predictions_with_bias = tf.nn.sigmoid(rating_predictions_with_bias)

        # Losses
        rating_loss = self.focal_loss(features["rating"], rating_predictions_with_bias)

        # Adaptive temperature-scaled retrieval loss
        scaled_similarities, temperature = self.temperature_similarity(
            user_embeddings, item_embeddings
        )
        retrieval_loss = tf.keras.losses.binary_crossentropy(
            features["rating"],
            tf.nn.sigmoid(scaled_similarities)
        )
        retrieval_loss = tf.reduce_mean(retrieval_loss)

        # Enhanced contrastive loss with hard negatives
        batch_size = tf.shape(user_embeddings)[0]
        positive_similarities = tf.reduce_sum(user_embeddings * item_embeddings, axis=1)

        # Random negative sampling
        shuffled_indices = tf.random.shuffle(tf.range(batch_size))
        negative_item_embeddings = tf.gather(item_embeddings, shuffled_indices)
        negative_similarities = tf.reduce_sum(user_embeddings * negative_item_embeddings, axis=1)

        # Triplet loss with adaptive margin
        margin = 0.5 / temperature  # Adaptive margin based on temperature
        contrastive_loss = tf.reduce_mean(
            tf.maximum(0.0, margin + negative_similarities - positive_similarities)
        )

        # Combine losses
        total_loss = (
            self.rating_weight * rating_loss +
            self.retrieval_weight * retrieval_loss +
            self.contrastive_weight * contrastive_loss
        )

        # Add regularization losses from diversity regularizers
        if training:
            regularization_losses = tf.add_n(self.losses) if self.losses else 0.0
            total_loss += self.diversity_weight * regularization_losses

        return {
            'total_loss': total_loss,
            'rating_loss': rating_loss,
            'retrieval_loss': retrieval_loss,
            'contrastive_loss': contrastive_loss,
            'temperature': temperature,
            'diversity_loss': regularization_losses if training else 0.0
        }

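The focal loss used for the rating head down-weights well-classified examples by the factor `(1 - p_t)^gamma`; with `gamma = 0` and `alpha = 0.5` it collapses to exactly half the plain binary cross-entropy, which makes a handy sanity check. A NumPy mirror of `_focal_loss`:

```python
import numpy as np

def focal_loss(y_true, y_pred, alpha=0.25, gamma=2.0, eps=1e-7):
    """NumPy mirror of the model's _focal_loss."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    alpha_t = y_true * alpha + (1 - y_true) * (1 - alpha)
    p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)   # prob of the true class
    bce = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    return float(np.mean(alpha_t * (1 - p_t) ** gamma * bce))

y = np.array([1.0, 0.0, 1.0, 0.0])
p = np.array([0.9, 0.2, 0.6, 0.4])
```

Since every `(1 - p_t)^gamma` factor is below 1 for `gamma > 0`, the focal value is always at most the `gamma = 0` (alpha-weighted BCE) value, with the gap concentrated on confident predictions.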
def create_enhanced_model(data_processor,
                          embedding_dim=128,
                          use_bias=True,
                          use_diversity_reg=True):
    """Factory function to create enhanced two-tower model."""

    # Create enhanced towers
    item_tower = EnhancedItemTower(
        item_vocab_size=len(data_processor.item_vocab),
        category_vocab_size=len(data_processor.category_vocab),
        brand_vocab_size=len(data_processor.brand_vocab),
        embedding_dim=embedding_dim,
        use_bias=use_bias,
        use_diversity_reg=use_diversity_reg
    )

    user_tower = EnhancedUserTower(
        max_history_length=50,
        embedding_dim=embedding_dim,
        use_bias=use_bias,
        use_diversity_reg=use_diversity_reg
    )

    # Create enhanced model
    model = EnhancedTwoTowerModel(
        item_tower=item_tower,
        user_tower=user_tower,
        rating_weight=1.0,
        retrieval_weight=0.5,
        contrastive_weight=0.3,
        diversity_weight=0.1
    )

    return model
src/models/item_tower.py
@@ -10,8 +10,8 @@ class ItemTower(tf.keras.Model):
                  item_vocab_size: int,
                  category_vocab_size: int,
                  brand_vocab_size: int,
-                 embedding_dim: int = 64,
-                 hidden_dims: list = [128, 64],
+                 embedding_dim: int = 128,  # Output embedding dimension
+                 hidden_dims: list = [256, 128],  # Internal dims can be larger
                  dropout_rate: float = 0.2):
         super().__init__()

src/models/user_tower.py:

@@ -8,8 +8,8 @@ class UserTower(tf.keras.Model):

     def __init__(self,
                  max_history_length: int = 50,
-                 embedding_dim: int =
-                 hidden_dims: list = [
+                 embedding_dim: int = 128,  # Output embedding dimension
+                 hidden_dims: list = [256, 128],  # Internal dims for processing
                  dropout_rate: float = 0.2):
         super().__init__()

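Both towers now share `embedding_dim=128` with `hidden_dims=[256, 128]`, so user and item outputs stay dot-product compatible. A minimal numpy sketch of the shape flow through such an MLP; the weights and the input width of 300 are random stand-ins, not the real layers:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 300))      # batch of 4 concatenated input feature vectors
for width in [256, 128]:           # hidden_dims: 256 internally, 128 at the output
    w = rng.normal(size=(x.shape[1], width)) * 0.01
    x = np.maximum(x @ w, 0.0)     # dense layer + ReLU

user_emb, item_emb = x[:2], x[2:]  # pretend the halves come from the two towers
scores = user_emb @ item_emb.T     # 128-d outputs are dot-product compatible
```

The point of fixing both towers to the same output width is exactly that last line: retrieval scoring is a plain matrix product between the two embedding tables.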
src/preprocessing/data_loader.py:

@@ -4,6 +4,8 @@ import tensorflow as tf
 from typing import Dict, List, Tuple, Optional
 from collections import defaultdict
 import pickle
+from concurrent.futures import ThreadPoolExecutor
+import multiprocessing as mp


 class DataProcessor:
@@ -97,54 +99,73 @@ class DataProcessor:
                               interactions_df: pd.DataFrame,
                               items_df: pd.DataFrame,
                               negative_samples_per_positive: int = 4) -> pd.DataFrame:
-        """Create positive and negative user-item pairs for training."""
+        """Create positive and negative user-item pairs for training (optimized)."""
+
+        # Filter valid interactions once
+        valid_interactions = interactions_df[
+            (interactions_df['user_id'].isin(self.user_vocab)) &
+            (interactions_df['product_id'].isin(self.item_vocab))
+        ].copy()
+
+        # Create positive pairs vectorized
+        positive_pairs = valid_interactions[['user_id', 'product_id']].copy()
+        positive_pairs['rating'] = 1.0
+
+        # Pre-compute user interactions for faster lookup
+        user_items_dict = (
+            valid_interactions.groupby('user_id')['product_id']
+            .apply(set).to_dict()
+        )

-        # Get all unique items for negative sampling
         all_items = set(self.item_vocab.keys())
+        all_items_array = np.array(list(all_items))

-        positive_pairs.append({
-            'user_id': row['user_id'],
-            'product_id': row['product_id'],
-            'rating': 1.0  # Implicit positive feedback
-        })
-
-        # Create negative pairs
-        negative_pairs = []
-        user_item_interactions = set(
-            (row['user_id'], row['product_id'])
-            for _, row in interactions_df.iterrows()
-        )
-
-        for pos_pair in positive_pairs:
-            user_id = pos_pair['user_id']
-            user_interactions = set(
-                row['product_id'] for _, row in interactions_df.iterrows()
-                if row['user_id'] == user_id
-            )
-
-            # Sample negative items
-            negative_items = all_items - user_interactions
+        # Generate negative samples in parallel
+        def generate_negatives_for_user(user_data):
+            user_id, user_items = user_data
+            negative_items = all_items - user_items

             if len(negative_items) >= negative_samples_per_positive:
+                neg_items_array = np.array(list(negative_items))
                 sampled_negatives = np.random.choice(
-                    size=negative_samples_per_positive
-                    replace=
+                    neg_items_array,
+                    size=negative_samples_per_positive * len(user_items),
+                    replace=len(negative_items) < negative_samples_per_positive * len(user_items)
                 )

-            for
+                # Repeat user_id for each negative sample
+                user_ids = np.repeat(user_id, len(sampled_negatives))
+                ratings = np.zeros(len(sampled_negatives))
+
+                return pd.DataFrame({
+                    'user_id': user_ids,
+                    'product_id': sampled_negatives,
+                    'rating': ratings
+                })
+            return pd.DataFrame(columns=['user_id', 'product_id', 'rating'])
+
+        # Process in parallel chunks
+        chunk_size = max(1, len(user_items_dict) // mp.cpu_count())
+        user_chunks = [
+            list(user_items_dict.items())[i:i + chunk_size]
+            for i in range(0, len(user_items_dict), chunk_size)
+        ]
+
+        negative_dfs = []
+        with ThreadPoolExecutor(max_workers=mp.cpu_count()) as executor:
+            for chunk in user_chunks:
+                chunk_results = list(executor.map(generate_negatives_for_user, chunk))
+                negative_dfs.extend(chunk_results)
+
+        # Combine all negative samples
+        if negative_dfs:
+            negative_pairs = pd.concat(negative_dfs, ignore_index=True)
+        else:
+            negative_pairs = pd.DataFrame(columns=['user_id', 'product_id', 'rating'])

         # Combine positive and negative pairs
-        all_pairs = positive_pairs
-        return
+        all_pairs = pd.concat([positive_pairs, negative_pairs], ignore_index=True)
+        return all_pairs

     def save_vocabularies(self, save_path: str = "src/artifacts/"):
         """Save vocabularies for later use."""
@@ -176,9 +197,19 @@ class DataProcessor:
     print("Vocabularies loaded successfully")


-def create_tf_dataset(features: Dict[str, np.ndarray], batch_size: int = 256) -> tf.data.Dataset:
-    """Create TensorFlow dataset from features."""
+def create_tf_dataset(features: Dict[str, np.ndarray], batch_size: int = 256, shuffle: bool = True) -> tf.data.Dataset:
+    """Create optimized TensorFlow dataset from features for CPU training."""
     dataset = tf.data.Dataset.from_tensor_slices(features)
+
+    if shuffle:
+        # Use reasonable buffer size for memory efficiency - handle different feature types
+        sample_key = next(iter(features.keys()))
+        buffer_size = min(len(features[sample_key]), 10000)
+        dataset = dataset.shuffle(buffer_size)
+
     dataset = dataset.batch(batch_size)
+
+    # Optimize for CPU with reasonable prefetch
+    dataset = dataset.prefetch(2)  # Reduced from AUTOTUNE for CPU efficiency
+
     return dataset
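The groupby-to-set change above replaces per-row `iterrows` scans with a single pass over the interactions. A small self-contained pandas illustration of the same sampling logic, using toy data with hypothetical item ids:

```python
import numpy as np
import pandas as pd

interactions = pd.DataFrame({'user_id': [1, 1, 2],
                             'product_id': ['a', 'b', 'a']})
all_items = {'a', 'b', 'c', 'd'}

# One pass to build each user's seen-item set (the groupby-apply trick)
user_items = interactions.groupby('user_id')['product_id'].apply(set).to_dict()

rng = np.random.default_rng(0)
rows = []
for user_id, seen in user_items.items():
    # Candidates are items the user has never interacted with
    candidates = np.array(sorted(all_items - seen))
    sampled = rng.choice(candidates, size=2, replace=len(candidates) < 2)
    rows += [{'user_id': user_id, 'product_id': p, 'rating': 0.0} for p in sampled]

negatives = pd.DataFrame(rows)
```

Each sampled negative is guaranteed to come from outside the user's own history, which is the invariant the optimized `create_user_item_pairs` preserves while parallelizing the per-user work.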
src/preprocessing/optimized_dataset_creator.py (new file):

@@ -0,0 +1,111 @@
+"""
+Optimized dataset creation script with performance improvements.
+"""
+import time
+import numpy as np
+from src.preprocessing.user_data_preparation import UserDatasetCreator
+from src.preprocessing.data_loader import DataProcessor, create_tf_dataset
+
+
+def create_optimized_dataset(max_history_length: int = 50,
+                             batch_size: int = 512,
+                             negative_samples_per_positive: int = 2,
+                             use_sample: bool = False,
+                             sample_size: int = 10000):
+    """
+    Create dataset with optimized performance settings.
+
+    Args:
+        max_history_length: Maximum user interaction history length
+        batch_size: Batch size for TensorFlow dataset
+        negative_samples_per_positive: Negative sampling ratio
+        use_sample: Whether to use a sample of the data for faster processing
+        sample_size: Size of sample if use_sample=True
+    """
+    print("Starting optimized dataset creation...")
+    start_time = time.time()
+
+    # Initialize with optimized settings
+    dataset_creator = UserDatasetCreator(max_history_length=max_history_length)
+    data_processor = DataProcessor()
+
+    # Load data
+    print("Loading data...")
+    load_start = time.time()
+    items_df, users_df, interactions_df = data_processor.load_data()
+    print(f"Data loaded in {time.time() - load_start:.2f} seconds")
+
+    # Optional: Use sample for faster development/testing
+    if use_sample:
+        print(f"Using sample of {sample_size} interactions for faster processing...")
+        sample_interactions = interactions_df.sample(min(sample_size, len(interactions_df)))
+        user_ids = set(sample_interactions['user_id'])
+        item_ids = set(sample_interactions['product_id'])
+
+        users_df = users_df[users_df['user_id'].isin(user_ids)]
+        items_df = items_df[items_df['product_id'].isin(item_ids)]
+        interactions_df = sample_interactions
+
+        print(f"Sample: {len(items_df)} items, {len(users_df)} users, {len(interactions_df)} interactions")
+
+    # Load embeddings with caching
+    print("Loading item embeddings...")
+    embed_start = time.time()
+    item_embeddings = dataset_creator.load_item_embeddings()
+    print(f"Embeddings loaded in {time.time() - embed_start:.2f} seconds")
+
+    # Create temporal split
+    print("Creating temporal split...")
+    split_start = time.time()
+    train_interactions, val_interactions = dataset_creator.create_temporal_split(interactions_df)
+    print(f"Temporal split created in {time.time() - split_start:.2f} seconds")
+
+    # Create training dataset with optimizations
+    print("Creating optimized training dataset...")
+    train_start = time.time()
+    training_features = dataset_creator.create_training_dataset(
+        train_interactions, items_df, users_df, item_embeddings,
+        negative_samples_per_positive=negative_samples_per_positive
+    )
+    print(f"Training dataset created in {time.time() - train_start:.2f} seconds")
+
+    # Create TensorFlow dataset optimized for CPU
+    print("Creating TensorFlow dataset...")
+    tf_start = time.time()
+    tf_dataset = create_tf_dataset(training_features, batch_size=batch_size)
+    print(f"TensorFlow dataset created in {time.time() - tf_start:.2f} seconds")
+
+    # Save optimized dataset
+    print("Saving dataset...")
+    save_start = time.time()
+    dataset_creator.save_dataset(training_features, "src/artifacts/")
+
+    # Save vocabularies for later use
+    data_processor.save_vocabularies("src/artifacts/")
+    print(f"Dataset saved in {time.time() - save_start:.2f} seconds")
+
+    total_time = time.time() - start_time
+    print(f"\nOptimized dataset creation completed in {total_time:.2f} seconds!")
+    print(f"Training samples: {len(training_features['rating'])}")
+    print("Memory usage optimized for CPU training")
+
+    return tf_dataset, training_features
+
+
+if __name__ == "__main__":
+    # Run with optimized settings
+    tf_dataset, features = create_optimized_dataset(
+        max_history_length=30,  # Reduced for speed
+        batch_size=512,  # Larger batches for CPU efficiency
+        negative_samples_per_positive=2,  # Reduced sampling ratio
+        use_sample=True,  # Use sample for development
+        sample_size=50000  # Reasonable sample size
+    )
+
+    print("\nDataset creation optimization complete!")
+    print("Key optimizations applied:")
+    print("- Vectorized DataFrame operations")
+    print("- Parallel negative sampling")
+    print("- Memory-efficient embedding lookup")
+    print("- Optimized TensorFlow dataset pipeline")
+    print("- LRU caching for embeddings")
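`create_temporal_split` itself is not shown in this diff. For orientation, a plausible minimal version of a temporal split, assuming interactions carry a `timestamp` column; the function and column names here are hypothetical, not taken from the repository:

```python
import pandas as pd

def temporal_split(interactions: pd.DataFrame, val_fraction: float = 0.2):
    """Hold out the most recent interactions for validation."""
    ordered = interactions.sort_values('timestamp')
    cutoff = int(len(ordered) * (1 - val_fraction))
    return ordered.iloc[:cutoff], ordered.iloc[cutoff:]
```

Splitting by time rather than at random avoids leaking future interactions into training, which is the usual motivation for a temporal split in recommendation pipelines.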
src/preprocessing/user_data_preparation.py:

@@ -63,7 +63,7 @@ class UserDatasetCreator:
         # Use more efficient random generation
         num_items = len(items_df['product_id'].unique())
         item_ids = items_df['product_id'].unique()
-        embedding_matrix = np.random.rand(num_items,
+        embedding_matrix = np.random.rand(num_items, 128).astype(np.float32)  # Updated to 128D

         dummy_embeddings = dict(zip(item_ids, embedding_matrix))
         print(f"Created dummy embeddings for {len(dummy_embeddings)} items")
@@ -72,7 +72,7 @@ class UserDatasetCreator:
     def aggregate_user_history_embeddings(self,
                                           user_histories: Dict[int, List[int]],
                                           item_embeddings: Dict[int, np.ndarray],
-                                          embedding_dim: int =
+                                          embedding_dim: int = 128) -> Dict[int, np.ndarray]:  # Updated to 128D
         """Aggregate item embeddings for each user's interaction history."""

         user_aggregated_embeddings = {}
@@ -103,9 +103,11 @@ class UserDatasetCreator:

         # Pad or truncate to max_history_length
         if len(history_embeddings) < self.max_history_length:
+            # Add padding at the END so real interactions are at the BEGINNING
             padding = np.zeros((self.max_history_length - len(history_embeddings), embedding_dim))
-            history_embeddings = np.vstack([
+            history_embeddings = np.vstack([history_embeddings, padding])
         else:
+            # Keep most recent interactions
             history_embeddings = history_embeddings[-self.max_history_length:]

         user_aggregated_embeddings[user_id] = history_embeddings
@@ -368,5 +370,58 @@ def main():
     print("User dataset creation completed!")


+def prepare_user_features(users_df: pd.DataFrame,
+                          user_histories: Dict[int, List[int]],
+                          item_features: Dict[str, np.ndarray],
+                          max_history_length: int = 50,
+                          embedding_dim: int = 128) -> Dict[int, Dict]:
+    """Standalone function to prepare user features with categorical demographics."""
+
+    creator = UserDatasetCreator(max_history_length=max_history_length)
+
+    # Create dummy item embeddings if not available (for 128D)
+    item_embeddings = {}
+    unique_items = set()
+    for history in user_histories.values():
+        unique_items.update(history)
+
+    # Create random embeddings for items (will be replaced by actual embeddings later)
+    for item_vocab_idx in unique_items:
+        item_embeddings[item_vocab_idx] = np.random.randn(embedding_dim).astype(np.float32)
+
+    # Get user aggregated embeddings
+    user_aggregated_embeddings = creator.aggregate_user_history_embeddings(
+        user_histories, item_embeddings, embedding_dim
+    )
+
+    # Process user features
+    user_feature_dict = {}
+
+    for _, user_row in users_df.iterrows():
+        user_id = user_row['user_id']
+
+        if user_id not in user_aggregated_embeddings:
+            continue
+
+        # Categorize demographics
+        age_cat = creator.categorize_age(user_row['age'])
+        gender_cat = 1 if user_row['gender'].lower() == 'male' else 0
+
+        # Categorize income using percentiles from all users
+        income_categories = creator.categorize_income(users_df['income'])
+        user_idx = users_df[users_df['user_id'] == user_id].index[0]
+        income_cat = income_categories[user_idx]
+
+        user_feature_dict[user_id] = {
+            'age': age_cat,
+            'gender': gender_cat,
+            'income': income_cat,
+            'item_history_embeddings': user_aggregated_embeddings[user_id]
+        }
+
+    print(f"Prepared features for {len(user_feature_dict)} users with {embedding_dim}D embeddings")
+    return user_feature_dict
+
+
 if __name__ == "__main__":
     main()
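The padding fix in `aggregate_user_history_embeddings` matters: zero rows go after the real interactions, so downstream pooling or masking sees real rows first. A tiny numpy check of that behavior:

```python
import numpy as np

max_history_length, embedding_dim = 5, 4
history = np.ones((3, embedding_dim))            # 3 real interaction embeddings

if len(history) < max_history_length:
    pad = np.zeros((max_history_length - len(history), embedding_dim))
    history = np.vstack([history, pad])          # real rows first, zero rows last
```

Had the two `vstack` arguments been swapped, the zeros would lead and a model that attends to the front of the sequence would mostly see padding.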
src/training/fast_joint_training.py:

@@ -66,7 +66,7 @@ class FastJointTrainer:
         # Build user tower (simplified)
         self.user_tower = UserTower(
             max_history_length=50,
-            embedding_dim=
+            embedding_dim=128,  # Updated to 128D
             hidden_dims=[64],  # Simplified architecture
             dropout_rate=0.1
         )
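The new improved_joint_training.py samples hard negatives by masking each user's positive item with a large negative value and then taking the top-k remaining similarities. A toy numpy version of that masking step, with made-up scores:

```python
import numpy as np

# Toy similarity matrix: 2 users x 5 items (user_emb @ item_emb.T)
similarities = np.array([[0.9, 0.8, 0.1, 0.7, 0.2],
                         [0.3, 0.95, 0.6, 0.1, 0.5]])
positive_items = np.array([0, 1])       # each user's positive item index

# Mask positives so they can never be selected as "hard" negatives
masked = similarities - np.eye(5)[positive_items] * 1e9

# Top-k most similar non-positive items are the hard negatives
k_hard = 2
hard_negatives = np.argsort(-masked, axis=1)[:, :k_hard]
# -> [[1, 3], [2, 4]]
```

These are the items the model currently confuses most with each user's positive, which is what makes them informative training signal.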
|
@@ -0,0 +1,462 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Improved joint training with hard negative mining, curriculum learning, and better optimization.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import tensorflow as tf
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pickle
|
| 9 |
+
import os
|
| 10 |
+
from typing import Dict, List, Tuple, Optional
|
| 11 |
+
import time
|
| 12 |
+
from collections import defaultdict
|
| 13 |
+
|
| 14 |
+
from src.models.improved_two_tower import create_improved_model
|
| 15 |
+
from src.preprocessing.data_loader import DataProcessor, create_tf_dataset
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class HardNegativeSampler:
|
| 19 |
+
"""Hard negative sampling strategy for better training."""
|
| 20 |
+
|
| 21 |
+
def __init__(self, model, item_embeddings, sampling_strategy='mixed'):
|
| 22 |
+
self.model = model
|
| 23 |
+
self.item_embeddings = item_embeddings # Pre-computed item embeddings
|
| 24 |
+
self.sampling_strategy = sampling_strategy
|
| 25 |
+
|
| 26 |
+
def sample_hard_negatives(self, user_embeddings, positive_items, k_hard=2, k_random=2):
|
| 27 |
+
"""Sample hard negatives based on user-item similarity."""
|
| 28 |
+
batch_size = tf.shape(user_embeddings)[0]
|
| 29 |
+
|
| 30 |
+
# Compute similarities between users and all items
|
| 31 |
+
similarities = tf.linalg.matmul(user_embeddings, self.item_embeddings, transpose_b=True)
|
| 32 |
+
|
| 33 |
+
# Mask out positive items
|
| 34 |
+
positive_mask = tf.one_hot(positive_items, depth=tf.shape(self.item_embeddings)[0])
|
| 35 |
+
similarities = similarities - positive_mask * 1e9 # Large negative value
|
| 36 |
+
|
| 37 |
+
# Get top-k similar items (hard negatives)
|
| 38 |
+
_, hard_negative_indices = tf.nn.top_k(similarities, k=k_hard)
|
| 39 |
+
|
| 40 |
+
# Sample random negatives
|
| 41 |
+
total_items = tf.shape(self.item_embeddings)[0]
|
| 42 |
+
random_negatives = tf.random.uniform(
|
| 43 |
+
shape=[batch_size, k_random],
|
| 44 |
+
minval=0,
|
| 45 |
+
maxval=total_items,
|
| 46 |
+
dtype=tf.int32
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
# Combine hard and random negatives
|
| 50 |
+
if self.sampling_strategy == 'hard':
|
| 51 |
+
return hard_negative_indices
|
| 52 |
+
elif self.sampling_strategy == 'random':
|
| 53 |
+
return random_negatives
|
| 54 |
+
else: # mixed
|
| 55 |
+
return tf.concat([hard_negative_indices, random_negatives], axis=1)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class CurriculumLearningScheduler:
|
| 59 |
+
"""Curriculum learning scheduler for progressive difficulty."""
|
| 60 |
+
|
| 61 |
+
def __init__(self, total_epochs, warmup_epochs=10):
|
| 62 |
+
self.total_epochs = total_epochs
|
| 63 |
+
self.warmup_epochs = warmup_epochs
|
| 64 |
+
|
| 65 |
+
def get_difficulty_schedule(self, epoch):
|
| 66 |
+
"""Get curriculum parameters for current epoch."""
|
| 67 |
+
if epoch < self.warmup_epochs:
|
| 68 |
+
# Easy phase: more random negatives, lower temperature
|
| 69 |
+
hard_negative_ratio = 0.2
|
| 70 |
+
temperature = 2.0
|
| 71 |
+
negative_samples = 2
|
| 72 |
+
elif epoch < self.total_epochs * 0.6:
|
| 73 |
+
# Medium phase: balanced negatives
|
| 74 |
+
hard_negative_ratio = 0.5
|
| 75 |
+
temperature = 1.0
|
| 76 |
+
negative_samples = 4
|
| 77 |
+
else:
|
| 78 |
+
# Hard phase: more hard negatives, higher temperature
|
| 79 |
+
hard_negative_ratio = 0.8
|
| 80 |
+
temperature = 0.5
|
| 81 |
+
negative_samples = 6
|
| 82 |
+
|
| 83 |
+
return {
|
| 84 |
+
'hard_negative_ratio': hard_negative_ratio,
|
| 85 |
+
'temperature': temperature,
|
| 86 |
+
'negative_samples': negative_samples
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class ImprovedJointTrainer:
|
| 91 |
+
"""Enhanced joint trainer with advanced techniques."""
|
| 92 |
+
|
| 93 |
+
def __init__(self,
|
| 94 |
+
embedding_dim: int = 128,
|
| 95 |
+
learning_rate: float = 0.001,
|
| 96 |
+
use_mixed_precision: bool = True,
|
| 97 |
+
use_curriculum_learning: bool = True,
|
| 98 |
+
use_hard_negatives: bool = True):
|
| 99 |
+
|
| 100 |
+
self.embedding_dim = embedding_dim
|
| 101 |
+
self.learning_rate = learning_rate
|
| 102 |
+
self.use_mixed_precision = use_mixed_precision
|
| 103 |
+
self.use_curriculum_learning = use_curriculum_learning
|
| 104 |
+
self.use_hard_negatives = use_hard_negatives
|
| 105 |
+
|
| 106 |
+
# Enable mixed precision if requested
|
| 107 |
+
if use_mixed_precision:
|
| 108 |
+
policy = tf.keras.mixed_precision.Policy('mixed_float16')
|
| 109 |
+
tf.keras.mixed_precision.set_global_policy(policy)
|
| 110 |
+
|
| 111 |
+
self.model = None
|
| 112 |
+
self.data_processor = None
|
| 113 |
+
self.curriculum_scheduler = None
|
| 114 |
+
self.hard_negative_sampler = None
|
| 115 |
+
|
| 116 |
+
def setup_model(self, data_processor: DataProcessor):
|
| 117 |
+
"""Setup the improved model."""
|
| 118 |
+
self.data_processor = data_processor
|
| 119 |
+
|
| 120 |
+
# Create improved model
|
| 121 |
+
self.model = create_improved_model(
|
| 122 |
+
data_processor=data_processor,
|
| 123 |
+
embedding_dim=self.embedding_dim,
|
| 124 |
+
use_bias=True,
|
| 125 |
+
use_focal_loss=True
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
print(f"Created improved two-tower model with {self.embedding_dim}D embeddings")
|
| 129 |
+
|
| 130 |
+
def setup_curriculum_learning(self, total_epochs: int):
|
| 131 |
+
"""Setup curriculum learning scheduler."""
|
| 132 |
+
if self.use_curriculum_learning:
|
| 133 |
+
self.curriculum_scheduler = CurriculumLearningScheduler(
|
| 134 |
+
total_epochs=total_epochs,
|
| 135 |
+
warmup_epochs=max(5, total_epochs // 10)
|
| 136 |
+
)
|
| 137 |
+
print("Curriculum learning enabled")
|
| 138 |
+
|
| 139 |
+
def setup_hard_negative_sampling(self, item_features: Dict[str, np.ndarray]):
|
| 140 |
+
"""Setup hard negative sampling."""
|
| 141 |
+
if self.use_hard_negatives:
|
| 142 |
+
# Pre-compute item embeddings for efficient hard negative sampling
|
| 143 |
+
print("Pre-computing item embeddings for hard negative sampling...")
|
| 144 |
+
|
| 145 |
+
# Create a dummy batch to get item embeddings
|
| 146 |
+
batch_size = 1000
|
| 147 |
+
total_items = len(item_features['product_id'])
|
| 148 |
+
|
| 149 |
+
item_embeddings_list = []
|
| 150 |
+
for i in range(0, total_items, batch_size):
|
| 151 |
+
end_idx = min(i + batch_size, total_items)
|
| 152 |
+
batch_features = {
|
| 153 |
+
key: tf.constant(value[i:end_idx])
|
| 154 |
+
for key, value in item_features.items()
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
item_emb_output = self.model.item_tower(batch_features, training=False)
|
| 158 |
+
if isinstance(item_emb_output, tuple):
|
| 159 |
+
item_emb = item_emb_output[0] # Get embeddings, ignore bias
|
| 160 |
+
else:
|
| 161 |
+
item_emb = item_emb_output
|
| 162 |
+
|
| 163 |
+
item_embeddings_list.append(item_emb.numpy())
|
| 164 |
+
|
| 165 |
+
item_embeddings = np.vstack(item_embeddings_list)
|
| 166 |
+
|
| 167 |
+
self.hard_negative_sampler = HardNegativeSampler(
|
| 168 |
+
model=self.model,
|
| 169 |
+
item_embeddings=tf.constant(item_embeddings, dtype=tf.float32),
|
| 170 |
+
sampling_strategy='mixed'
|
| 171 |
+
)
|
| 172 |
+
print(f"Hard negative sampling enabled with {len(item_embeddings)} items")
|
| 173 |
+
|
| 174 |
+
def create_advanced_training_dataset(self,
|
| 175 |
+
features: Dict[str, np.ndarray],
|
| 176 |
+
batch_size: int = 256,
|
| 177 |
+
epoch: int = 0) -> tf.data.Dataset:
|
| 178 |
+
"""Create training dataset with curriculum learning and hard negatives."""
|
| 179 |
+
|
| 180 |
+
# Get curriculum parameters
|
| 181 |
+
if self.curriculum_scheduler:
|
| 182 |
+
curriculum_params = self.curriculum_scheduler.get_difficulty_schedule(epoch)
|
| 183 |
+
print(f"Epoch {epoch}: {curriculum_params}")
|
| 184 |
+
else:
|
| 185 |
+
curriculum_params = {
|
| 186 |
+
'hard_negative_ratio': 0.5,
|
| 187 |
+
'temperature': 1.0,
|
| 188 |
+
'negative_samples': 4
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
# Filter data based on curriculum (start with easier examples)
|
| 192 |
+
if epoch < 5: # Warmup epochs - use only high-confidence positive examples
|
| 193 |
+
positive_mask = features['rating'] == 1.0
|
| 194 |
+
if np.sum(positive_mask) > 0:
|
| 195 |
+
# Sample subset of positives and all negatives
|
| 196 |
+
positive_indices = np.where(positive_mask)[0]
|
| 197 |
+
negative_indices = np.where(features['rating'] == 0.0)[0]
|
| 198 |
+
|
| 199 |
+
# Sample subset for easier learning
|
| 200 |
+
n_positive_samples = min(len(positive_indices), len(negative_indices))
|
| 201 |
+
selected_positive = np.random.choice(
|
| 202 |
+
positive_indices, size=n_positive_samples, replace=False
|
| 203 |
+
)
|
| 204 |
+
selected_negative = np.random.choice(
|
| 205 |
+
negative_indices, size=n_positive_samples, replace=False
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
selected_indices = np.concatenate([selected_positive, selected_negative])
|
| 209 |
+
np.random.shuffle(selected_indices)
|
| 210 |
+
|
| 211 |
+
# Filter features
|
| 212 |
+
filtered_features = {
|
| 213 |
+
key: value[selected_indices] for key, value in features.items()
|
| 214 |
+
}
|
| 215 |
+
else:
|
| 216 |
+
filtered_features = features
|
| 217 |
+
else:
|
| 218 |
+
filtered_features = features
|
| 219 |
+
|
| 220 |
+
# Create dataset
|
| 221 |
+
dataset = create_tf_dataset(filtered_features, batch_size, shuffle=True)
|
| 222 |
+
|
| 223 |
+
return dataset
|
| 224 |
+
|
| 225 |
+
+    def compile_model(self):
+        """Compile the model with an AdamW optimizer and cosine-restart LR schedule."""
+        initial_learning_rate = self.learning_rate
+        lr_schedule = tf.keras.optimizers.schedules.CosineDecayRestarts(
+            initial_learning_rate=initial_learning_rate,
+            first_decay_steps=1000,
+            t_mul=2.0,
+            m_mul=0.9,
+            alpha=0.01
+        )
+
+        optimizer = tf.keras.optimizers.AdamW(
+            learning_rate=lr_schedule,
+            weight_decay=1e-5,
+            beta_1=0.9,
+            beta_2=0.999,
+            epsilon=1e-7
+        )
+
+        # Wrap with loss scaling when mixed precision is enabled
+        if self.use_mixed_precision:
+            optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
+
+        self.optimizer = optimizer
+        print(f"Model compiled with AdamW optimizer (lr={self.learning_rate})")
+
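For intuition, the SGDR-style schedule configured above can be reproduced in plain Python: restart period `i` has length `first_decay_steps * t_mul**i` and peak `initial_lr * m_mul**i`, with a cosine from peak down to `alpha` times peak inside each period. Treat this as an illustrative sketch of that formula, not TensorFlow's exact implementation:

```python
import math

def cosine_decay_restarts(step, initial_lr=0.001, first_decay_steps=1000,
                          t_mul=2.0, m_mul=0.9, alpha=0.01):
    """Approximate the cosine-decay-with-restarts schedule used in compile_model."""
    start, period, i = 0.0, float(first_decay_steps), 0
    # Walk forward until `step` falls inside the current restart period
    while step >= start + period:
        start += period
        period *= t_mul
        i += 1
    frac = (step - start) / period                   # progress within the period
    cosine = 0.5 * (1.0 + math.cos(math.pi * frac))  # 1 -> 0 over the period
    decayed = (1.0 - alpha) * cosine + alpha         # floor at alpha * peak
    return initial_lr * (m_mul ** i) * decayed
```

With the defaults above, the rate starts at 0.001, decays toward 1% of the peak by step 1000, then restarts at 0.0009 (peak scaled by `m_mul`) for a period twice as long.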
+    @tf.function
+    def train_step(self, features):
+        """Single training step with optional mixed-precision loss scaling."""
+        with tf.GradientTape() as tape:
+            # Forward pass
+            loss_dict = self.model.compute_loss(features, training=True)
+            total_loss = loss_dict['total_loss']
+
+            # Scale loss for mixed precision
+            if self.use_mixed_precision:
+                scaled_loss = self.optimizer.get_scaled_loss(total_loss)
+            else:
+                scaled_loss = total_loss
+
+        # Compute gradients (unscale them when mixed precision is enabled)
+        if self.use_mixed_precision:
+            scaled_gradients = tape.gradient(scaled_loss, self.model.trainable_variables)
+            gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)
+        else:
+            gradients = tape.gradient(scaled_loss, self.model.trainable_variables)
+
+        # Clip gradients to prevent exploding gradients
+        gradients, _ = tf.clip_by_global_norm(gradients, 1.0)
+
+        # Apply gradients
+        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
+
+        return loss_dict
+
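The `tf.clip_by_global_norm(gradients, 1.0)` call above rescales all gradient tensors jointly by a single factor rather than clipping each one independently. A NumPy sketch of that behavior (function name and data are mine, for illustration):

```python
import math
import numpy as np

def clip_by_global_norm(grads, clip_norm):
    """Rescale a list of gradient arrays so their joint L2 norm <= clip_norm."""
    global_norm = math.sqrt(sum(float(np.sum(g * g)) for g in grads))
    if global_norm > clip_norm:
        scale = clip_norm / global_norm
        grads = [g * scale for g in grads]
    return grads, global_norm

grads = [np.array([3.0]), np.array([4.0])]       # global norm = 5.0
clipped, norm = clip_by_global_norm(grads, 1.0)  # every tensor scaled by 1/5
```

Because one scale factor is shared, the relative direction of the full gradient vector is preserved; only its magnitude is capped.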
+    def evaluate_model(self, validation_dataset):
+        """Evaluate the model on the validation set."""
+        total_losses = defaultdict(list)
+
+        for batch in validation_dataset:
+            loss_dict = self.model.compute_loss(batch, training=False)
+            for key, value in loss_dict.items():
+                total_losses[key].append(float(value))
+
+        # Average losses across batches
+        avg_losses = {key: np.mean(values) for key, values in total_losses.items()}
+        return avg_losses
+
+    def train(self,
+              training_features: Dict[str, np.ndarray],
+              validation_features: Dict[str, np.ndarray],
+              epochs: int = 50,
+              batch_size: int = 256,
+              save_path: str = "src/artifacts/") -> Dict:
+        """Training loop with curriculum learning, early stopping, and checkpointing."""
+
+        print(f"Starting improved training for {epochs} epochs...")
+
+        # Setup components
+        self.setup_curriculum_learning(epochs)
+        self.compile_model()
+
+        # Create validation dataset
+        validation_dataset = create_tf_dataset(validation_features, batch_size, shuffle=False)
+
+        # Training history
+        history = defaultdict(list)
+        best_val_loss = float('inf')
+        patience_counter = 0
+        early_stopping_patience = 10
+
+        # Training loop
+        for epoch in range(epochs):
+            epoch_start_time = time.time()
+
+            # Create training dataset for this epoch (curriculum learning)
+            training_dataset = self.create_advanced_training_dataset(
+                training_features, batch_size, epoch
+            )
+
+            # Training
+            epoch_losses = defaultdict(list)
+            num_batches = 0
+
+            for batch in training_dataset:
+                loss_dict = self.train_step(batch)
+
+                for key, value in loss_dict.items():
+                    epoch_losses[key].append(float(value))
+                num_batches += 1
+
+            # Average training losses
+            avg_train_losses = {
+                key: np.mean(values) for key, values in epoch_losses.items()
+            }
+
+            # Validation
+            avg_val_losses = self.evaluate_model(validation_dataset)
+
+            # Log progress
+            epoch_time = time.time() - epoch_start_time
+            print(f"Epoch {epoch+1}/{epochs} ({epoch_time:.1f}s):")
+            print(f"  Train Loss: {avg_train_losses['total_loss']:.4f}")
+            print(f"  Val Loss: {avg_val_losses['total_loss']:.4f}")
+            print(f"  Val Rating Loss: {avg_val_losses['rating_loss']:.4f}")
+            print(f"  Val Retrieval Loss: {avg_val_losses['retrieval_loss']:.4f}")
+
+            # Save history
+            for key, value in avg_train_losses.items():
+                history[f'train_{key}'].append(value)
+            for key, value in avg_val_losses.items():
+                history[f'val_{key}'].append(value)
+
+            # Early stopping and model saving
+            current_val_loss = avg_val_losses['total_loss']
+            if current_val_loss < best_val_loss:
+                best_val_loss = current_val_loss
+                patience_counter = 0
+
+                # Save best model
+                self.save_model(save_path, suffix='_improved_best')
+                print(f"  💾 Saved best model (val_loss: {best_val_loss:.4f})")
+            else:
+                patience_counter += 1
+
+            if patience_counter >= early_stopping_patience:
+                print(f"Early stopping at epoch {epoch+1}")
+                break
+
+        # Save final model and history
+        self.save_model(save_path, suffix='_improved_final')
+        self.save_training_history(dict(history), save_path)
+
+        print("✅ Improved training completed!")
+        return dict(history)
+
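The best-loss/patience bookkeeping in the loop above is a standard early-stopping pattern; stripped of the model and I/O, it reduces to a sketch like this (helper name and loss values are made up for illustration):

```python
def early_stop_epoch(val_losses, patience=10):
    """Return the 1-based epoch at which training would stop, or None."""
    best, counter = float('inf'), 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best, counter = loss, 0  # new best: reset patience
        else:
            counter += 1
        if counter >= patience:
            return epoch + 1
    return None

# Loss plateaus after epoch 2; with patience=3 training stops at epoch 5
stop = early_stop_epoch([1.0, 0.8, 0.9, 0.9, 0.9], patience=3)
```

Note that the checkpoint is written only on improvement, so the `_improved_best` weights always correspond to the lowest validation loss seen, even when training runs past it.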
+    def save_model(self, save_path: str, suffix: str = ''):
+        """Save the trained model components."""
+        os.makedirs(save_path, exist_ok=True)
+
+        # Save model weights
+        self.model.item_tower.save_weights(f"{save_path}/improved_item_tower_weights{suffix}")
+        self.model.user_tower.save_weights(f"{save_path}/improved_user_tower_weights{suffix}")
+
+        if hasattr(self.model, 'rating_model'):
+            self.model.rating_model.save_weights(f"{save_path}/improved_rating_model_weights{suffix}")
+
+        # Save configuration
+        config = {
+            'embedding_dim': self.embedding_dim,
+            'learning_rate': self.learning_rate,
+            'use_mixed_precision': self.use_mixed_precision,
+            'use_curriculum_learning': self.use_curriculum_learning,
+            'use_hard_negatives': self.use_hard_negatives
+        }
+
+        with open(f"{save_path}/improved_model_config{suffix}.txt", 'w') as f:
+            for key, value in config.items():
+                f.write(f"{key}: {value}\n")
+
+        print(f"Model saved to {save_path} with suffix '{suffix}'")
+
+    def save_training_history(self, history: Dict, save_path: str):
+        """Save training history."""
+        with open(f"{save_path}/improved_training_history.pkl", 'wb') as f:
+            pickle.dump(history, f)
+        print(f"Training history saved to {save_path}")
+
+
+def main():
+    """Demo of improved training."""
+    print("🚀 IMPROVED TWO-TOWER TRAINING DEMO")
+    print("="*60)
+
+    # Load data
+    print("Loading training data...")
+    try:
+        with open("src/artifacts/training_features.pkl", 'rb') as f:
+            training_features = pickle.load(f)
+        with open("src/artifacts/validation_features.pkl", 'rb') as f:
+            validation_features = pickle.load(f)
+
+        print(f"Loaded {len(training_features['rating'])} training samples")
+        print(f"Loaded {len(validation_features['rating'])} validation samples")
+    except FileNotFoundError:
+        print("❌ Training data not found. Please run data preparation first.")
+        return
+
+    # Load data processor
+    data_processor = DataProcessor()
+    data_processor.load_vocabularies("src/artifacts/vocabularies.pkl")
+
+    # Create trainer
+    trainer = ImprovedJointTrainer(
+        embedding_dim=128,
+        learning_rate=0.001,
+        use_mixed_precision=True,
+        use_curriculum_learning=True,
+        use_hard_negatives=True
+    )
+
+    # Setup and train
+    trainer.setup_model(data_processor)
+
+    # Train model
+    history = trainer.train(
+        training_features=training_features,
+        validation_features=validation_features,
+        epochs=30,
+        batch_size=256
+    )
+
+    print("✅ Improved training completed successfully!")
+
+
+if __name__ == "__main__":
+    main()
@@ -12,8 +12,8 @@ class ItemTowerPretrainer:
     """Handles pre-training of the item tower."""
 
     def __init__(self,
-                 embedding_dim: int =
-                 hidden_dims: List[int] = [
+                 embedding_dim: int = 128,  # Updated to 128D output
+                 hidden_dims: List[int] = [256, 128],  # Scaled up
                  dropout_rate: float = 0.2,
                  learning_rate: float = 0.001):
@@ -111,19 +111,26 @@ class ItemTowerPretrainer:
         return history
 
     def generate_item_embeddings(self,
-                                 dataset: tf.data.Dataset
+                                 dataset: tf.data.Dataset,
+                                 data_processor: 'DataProcessor') -> Dict[int, np.ndarray]:
         """Generate embeddings for all items in the catalog."""
 
         item_embeddings = {}
 
+        # Create reverse mapping from vocab indices to actual item IDs
+        idx_to_item_id = {idx: item_id for item_id, idx in data_processor.item_vocab.items()}
+
         for batch in dataset:
             embeddings = self.item_tower(batch)
-
+            product_idx_batch = batch['product_id'].numpy()
 
-            for i,
-
+            for i, product_idx in enumerate(product_idx_batch):
+                # Convert vocab index back to actual item ID
+                actual_item_id = idx_to_item_id.get(product_idx, product_idx)
+                item_embeddings[actual_item_id] = embeddings[i].numpy()
 
         print(f"Generated embeddings for {len(item_embeddings)} items")
+        print(f"Sample item IDs: {list(item_embeddings.keys())[:5]}")
         return item_embeddings
 
     def save_model(self, save_path: str = "src/artifacts/"):
@@ -185,7 +192,7 @@ def main():
     # Initialize components
     data_processor = DataProcessor()
     pretrainer = ItemTowerPretrainer(
-        embedding_dim=
+        embedding_dim=128,  # Updated to 128D
         hidden_dims=[128, 64],
         dropout_rate=0.2,
         learning_rate=0.001
@@ -210,7 +217,7 @@ def main():
 
     # Generate embeddings
     print("Generating item embeddings...")
-    item_embeddings = pretrainer.generate_item_embeddings(dataset)
+    item_embeddings = pretrainer.generate_item_embeddings(dataset, data_processor)
 
     # Save everything
     print("Saving artifacts...")
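The fix in the hunk above keys embeddings by real catalog item IDs instead of contiguous vocabulary indices, by inverting the vocab dict. In isolation, with made-up `item_vocab` contents and tower output:

```python
import numpy as np

# Hypothetical vocabulary: catalog item ID -> contiguous model index
item_vocab = {'B0001': 0, 'B0002': 1, 'B0003': 2}
idx_to_item_id = {idx: item_id for item_id, idx in item_vocab.items()}

# Pretend tower output: one embedding row per vocab index in the batch
batch_indices = np.array([2, 0])
embeddings = np.array([[0.1, 0.2], [0.3, 0.4]])

# Key each row by its real item ID, falling back to the raw index if unknown
item_embeddings = {
    idx_to_item_id.get(idx, idx): embeddings[i]
    for i, idx in enumerate(batch_indices)
}
```

Without the reverse mapping, a downstream FAISS or lookup layer would silently serve embeddings under the wrong IDs whenever vocab order differs from catalog order.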
@@ -13,7 +13,7 @@ class JointTrainer:
     """Handles joint training of user and item towers."""
 
     def __init__(self,
-                 embedding_dim: int =
+                 embedding_dim: int = 128,  # Updated to 128D output
                  user_learning_rate: float = 0.001,
                  item_learning_rate: float = 0.0001,  # Lower LR for pre-trained item tower
                  rating_weight: float = 1.0,
@@ -330,7 +330,7 @@ def main():
 
     # Initialize trainer
     trainer = JointTrainer(
-        embedding_dim=
+        embedding_dim=128,  # Updated to 128D
        user_learning_rate=0.001,
        item_learning_rate=0.0001,
        rating_weight=1.0,
@@ -14,7 +14,7 @@ class OptimizedJointTrainer:
     """Optimized joint training with performance enhancements."""
 
    def __init__(self,
-                 embedding_dim: int =
+                 embedding_dim: int = 128,  # Updated to 128D output
                  user_learning_rate: float = 0.001,
                  item_learning_rate: float = 0.0001,
                  rating_weight: float = 1.0,
@@ -381,7 +381,7 @@ def main():
 
     print("Initializing optimized joint trainer...")
     trainer = OptimizedJointTrainer(
-        embedding_dim=
+        embedding_dim=128,  # Updated to 128D
        user_learning_rate=0.002,  # Slightly higher for faster convergence
        item_learning_rate=0.0002,
        rating_weight=1.0,
|