edwinbh committed on
Commit
e73cff5
·
verified ·
1 Parent(s): 2968f9e

Upload dlrm_inference.py

Browse files
Files changed (1) hide show
  1. dlrm_inference.py +527 -0
dlrm_inference.py ADDED
@@ -0,0 +1,527 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DLRM Inference Engine for Book Recommendations
3
+ Loads trained DLRM model and provides recommendation functionality
4
+ """
5
+
6
+ import torch
7
+ import numpy as np
8
+ import pandas as pd
9
+ import pickle
10
+ import mlflow
11
+ from mlflow import MlflowClient
12
+ import tempfile
13
+ import os
14
+ from typing import List, Dict, Tuple, Optional, Any
15
+ from functools import partial
16
+ import warnings
17
+ warnings.filterwarnings('ignore')
18
+
19
+ from torchrec import EmbeddingBagCollection
20
+ from torchrec.models.dlrm import DLRM, DLRMTrain
21
+ from torchrec.modules.embedding_configs import EmbeddingBagConfig
22
+ from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
23
+ from torchrec.datasets.utils import Batch
24
+
25
class DLRMBookRecommender:
    """DLRM-based book recommender for inference.

    On construction the preprocessing artefacts (encoders, scaler, column
    lists) are read from ``book_dlrm_preprocessing.pkl`` in the current
    working directory, then a model is loaded either from a local state-dict
    file or from an MLflow run.  Model loading is best-effort: on failure a
    message is printed and ``self.model`` stays ``None``; the public scoring
    methods check for that and return an empty/zero result.
    """

    # Pickle holding encoders/scaler/column lists (path relative to the CWD).
    PREPROCESSING_PATH = 'book_dlrm_preprocessing.pkl'

    # Architecture used when loading from a bare state-dict file, where no
    # hyper-parameters are stored alongside the weights.
    DEFAULT_EMBEDDING_DIM = 64
    DEFAULT_DENSE_ARCH_LAYER_SIZES = (256, 128, 64)
    DEFAULT_OVER_ARCH_LAYER_SIZES = (512, 256, 128, 1)

    # MLflow artifact directories probed, in order of preference.
    MLFLOW_ARTIFACT_PATHS = (
        'model_state_dict_final',
        'model_state_dict_2',
        'model_state_dict_1',
        'model_state_dict_0',
    )

    # Fallbacks used when user/book metadata is missing or unusable.
    DEFAULT_AGE = 30
    DEFAULT_LOCATION = 'usa'
    DEFAULT_PUBLICATION_YEAR = 2000
    DEFAULT_DECADE_INDEX = 6  # the original code labels this "2000s"
    DEFAULT_RATING_LEVEL = 1  # the original code labels this "medium"

    # age < 18 -> 0, < 25 -> 1, < 35 -> 2, < 50 -> 3, < 65 -> 4, else 5
    AGE_GROUP_UPPER_BOUNDS = (18, 25, 35, 50, 65)

    def __init__(self, model_path: Optional[str] = None, run_id: Optional[str] = None):
        """Load preprocessing info and, if possible, a trained model.

        Args:
            model_path: Path to a saved model state dict.  Used only when the
                path exists on disk.
            run_id: MLflow run ID to load the model from.  Used when
                ``model_path`` is not given or does not exist.

        Raises:
            FileNotFoundError: If the preprocessing pickle is missing.
        """
        self.device = torch.device("cpu")
        self.model = None
        self.preprocessing_info = None

        self._load_preprocessing_info()

        if model_path and os.path.exists(model_path):
            self._load_model_from_path(model_path)
        elif run_id:
            self._load_model_from_mlflow(run_id)
        else:
            print("⚠️ No model loaded. Please provide model_path or run_id")

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------

    def _load_preprocessing_info(self):
        """Read the preprocessing pickle and expose its entries as attributes.

        Raises:
            FileNotFoundError: If ``PREPROCESSING_PATH`` does not exist.
        """
        if not os.path.exists(self.PREPROCESSING_PATH):
            raise FileNotFoundError("book_dlrm_preprocessing.pkl not found. Run preprocessing first.")

        # NOTE(security): pickle.load executes code embedded in the file.
        # Only load preprocessing files produced by a trusted pipeline.
        with open(self.PREPROCESSING_PATH, 'rb') as f:
            self.preprocessing_info = pickle.load(f)

        info = self.preprocessing_info
        self.dense_cols = info['dense_cols']
        self.cat_cols = info['cat_cols']
        self.emb_counts = info['emb_counts']
        self.user_encoder = info['user_encoder']
        self.book_encoder = info['book_encoder']
        self.publisher_encoder = info['publisher_encoder']
        self.location_encoder = info['location_encoder']
        self.scaler = info['scaler']

        print("✅ Preprocessing info loaded")

    def _build_model(self, cat_cols, emb_counts, dense_in_features: int,
                     embedding_dim: int, dense_arch_layer_sizes, over_arch_layer_sizes):
        """Construct an untrained DLRM with one embedding table per categorical column."""
        eb_configs = [
            EmbeddingBagConfig(
                name=f"t_{feature_name}",
                embedding_dim=embedding_dim,
                num_embeddings=emb_counts[feature_idx],
                feature_names=[feature_name],
            )
            for feature_idx, feature_name in enumerate(cat_cols)
        ]
        return DLRM(
            embedding_bag_collection=EmbeddingBagCollection(
                tables=eb_configs, device=self.device
            ),
            dense_in_features=dense_in_features,
            dense_arch_layer_sizes=list(dense_arch_layer_sizes),
            over_arch_layer_sizes=list(over_arch_layer_sizes),
            dense_device=self.device,
        )

    @staticmethod
    def _strip_model_prefix(state_dict: Dict[str, Any], prefix: str = 'model.') -> Dict[str, Any]:
        """Return ``state_dict`` with a leading ``prefix`` removed from its keys.

        Only keys that actually start with ``prefix`` are shortened; other
        keys are kept verbatim.  (Slicing every key unconditionally would
        corrupt the ones without the prefix.)  If no key carries the prefix
        the input mapping is returned unchanged.
        """
        if not any(key.startswith(prefix) for key in state_dict):
            return state_dict
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def _install_model(self, dlrm_model, state_dict: Dict[str, Any]) -> None:
        """Load weights into ``dlrm_model`` and publish it in eval mode.

        ``self.model`` is assigned only after the weights loaded successfully,
        so a failed load leaves it as ``None``.
        """
        dlrm_model.load_state_dict(self._strip_model_prefix(state_dict))
        dlrm_model.eval()
        self.model = dlrm_model

    def _load_model_from_path(self, model_path: str):
        """Load the model from a saved state dict using the default architecture.

        Best-effort: any error is printed and ``self.model`` is left unset.
        """
        try:
            dlrm_model = self._build_model(
                cat_cols=self.cat_cols,
                emb_counts=self.emb_counts,
                dense_in_features=len(self.dense_cols),
                embedding_dim=self.DEFAULT_EMBEDDING_DIM,
                dense_arch_layer_sizes=self.DEFAULT_DENSE_ARCH_LAYER_SIZES,
                over_arch_layer_sizes=self.DEFAULT_OVER_ARCH_LAYER_SIZES,
            )
            state_dict = torch.load(model_path, map_location=self.device)
            self._install_model(dlrm_model, state_dict)

            print(f"✅ Model loaded from {model_path}")

        except Exception as e:
            print(f"❌ Error loading model: {e}")

    @staticmethod
    def _parse_logged_param(params: Dict[str, str], name: str) -> Any:
        """Turn the string stored for MLflow param ``name`` back into a Python value.

        Raises:
            KeyError: If the run has no such param.
        """
        raw = params.get(name)
        if raw is None:
            raise KeyError(f"MLflow run is missing required param '{name}'")
        # NOTE(security): eval() runs arbitrary expressions, so this trusts
        # whoever can write params to the tracking server.  If every logged
        # value is a plain literal, switch to ast.literal_eval.
        return eval(raw)

    def _load_model_from_mlflow(self, run_id: str):
        """Load the model architecture and weights recorded in an MLflow run.

        Best-effort: any error is printed and ``self.model`` is left unset.
        """
        try:
            client = MlflowClient()
            params = client.get_run(run_id).data.params

            cat_cols = self._parse_logged_param(params, 'cat_cols')
            emb_counts = self._parse_logged_param(params, 'emb_counts')
            dense_cols = self._parse_logged_param(params, 'dense_cols')
            embedding_dim = int(params.get('embedding_dim', self.DEFAULT_EMBEDDING_DIM))
            dense_arch_layer_sizes = self._parse_logged_param(params, 'dense_arch_layer_sizes')
            over_arch_layer_sizes = self._parse_logged_param(params, 'over_arch_layer_sizes')

            state_dict = None
            last_error = None
            # The download directory is only needed until the weights are in
            # memory; the context manager removes it afterwards.
            with tempfile.TemporaryDirectory() as temp_dir:
                for artifact_path in self.MLFLOW_ARTIFACT_PATHS:
                    try:
                        client.download_artifacts(run_id, f"{artifact_path}/state_dict.pth", temp_dir)
                        state_dict = mlflow.pytorch.load_state_dict(f"{temp_dir}/{artifact_path}")
                        break
                    except Exception as exc:  # this artifact is unavailable; try the next
                        last_error = exc

            if state_dict is None:
                raise FileNotFoundError(f"No model artifacts found (last error: {last_error})")

            dlrm_model = self._build_model(
                cat_cols=cat_cols,
                emb_counts=emb_counts,
                dense_in_features=len(dense_cols),
                embedding_dim=embedding_dim,
                dense_arch_layer_sizes=dense_arch_layer_sizes,
                over_arch_layer_sizes=over_arch_layer_sizes,
            )
            self._install_model(dlrm_model, state_dict)

            print(f"✅ Model loaded from MLflow run: {run_id}")

        except Exception as e:
            print(f"❌ Error loading model from MLflow: {e}")

    # ------------------------------------------------------------------
    # Feature preparation
    # ------------------------------------------------------------------

    @staticmethod
    def _safe_encode(encoder: Any, label: str, default: int = 0) -> Any:
        """Return ``encoder.transform([label])[0]``, or ``default`` if encoding fails.

        Failure typically means the label was not seen when the encoder was
        fitted; such unknown values are mapped to ``default``.
        """
        try:
            return encoder.transform([label])[0]
        except Exception:
            return default

    @classmethod
    def _coerce_age(cls, raw_age: Any) -> float:
        """Convert ``raw_age`` to a float, using ``DEFAULT_AGE`` for None/NaN/garbage."""
        try:
            age = float(raw_age)
        except (TypeError, ValueError):
            return float(cls.DEFAULT_AGE)
        if age != age:  # NaN is the only value not equal to itself
            return float(cls.DEFAULT_AGE)
        return age

    @classmethod
    def _age_group(cls, age: float) -> int:
        """Map an age to its bucket index as defined by ``AGE_GROUP_UPPER_BOUNDS``."""
        for group, upper_bound in enumerate(cls.AGE_GROUP_UPPER_BOUNDS):
            if age < upper_bound:
                return group
        return len(cls.AGE_GROUP_UPPER_BOUNDS)

    def _prepare_user_features(self, user_id: int,
                               user_data: Optional[Dict] = None) -> Tuple[torch.Tensor, int, int, int]:
        """Prepare user-side features for inference.

        Args:
            user_id: Raw user ID; encoded with ``self.user_encoder``.
            user_data: Optional dict that may contain ``'Age'``,
                ``'Location'``, ``'user_activity'`` and ``'user_avg_rating'``.
                Missing entries fall back to defaults.

        Returns:
            ``(dense_features, user_id_encoded, country_encoded, age_group)``
            where ``dense_features`` is a ``(1, 6)`` float32 tensor produced
            by ``self.scaler``.
        """
        if user_data is None:
            user_data = {
                'User-ID': user_id,
                'Age': self.DEFAULT_AGE,
                'Location': self.DEFAULT_LOCATION,
            }

        # Unknown users / countries are mapped to index 0.
        user_id_encoded = self._safe_encode(self.user_encoder, str(user_id))

        # The country is taken as the last comma-separated part of the location.
        location = str(user_data.get('Location', self.DEFAULT_LOCATION)).split(',')[-1].strip().lower()
        country_encoded = self._safe_encode(self.location_encoder, location)

        age = self._coerce_age(user_data.get('Age', self.DEFAULT_AGE))
        age_group = self._age_group(age)

        user_activity = user_data.get('user_activity', 10)
        user_avg_rating = user_data.get('user_avg_rating', 6.0)

        # NOTE(review): the six columns presumably follow self.dense_cols
        # (age, publication year, user activity, book popularity, user average
        # rating, book average rating) -- confirm against the training
        # pipeline.  Book-side values are fixed defaults here.
        dense_row = np.array([[age, 2000, user_activity, 10, user_avg_rating, 6.0]])
        dense_features = torch.tensor(self.scaler.transform(dense_row), dtype=torch.float32)

        return dense_features, user_id_encoded, country_encoded, age_group

    def _prepare_book_features(self, book_isbn: str,
                               book_data: Optional[Dict] = None) -> Tuple[int, int, int, int]:
        """Prepare book-side categorical features for inference.

        Args:
            book_isbn: Book ISBN; encoded with ``self.book_encoder``.
            book_data: Optional dict that may contain ``'Publisher'`` and
                ``'Year-Of-Publication'``.

        Returns:
            ``(book_id_encoded, publisher_encoded, decade_encoded, rating_level)``.
            Unknown books/publishers map to 0; an unknown decade (or a missing
            decade encoder) maps to ``DEFAULT_DECADE_INDEX``.
        """
        if book_data is None:
            book_data = {}

        book_id_encoded = self._safe_encode(self.book_encoder, str(book_isbn))

        publisher = str(book_data.get('Publisher', 'Unknown'))
        publisher_encoded = self._safe_encode(self.publisher_encoder, publisher)

        # A non-numeric year is treated like a missing one instead of
        # aborting the whole prediction.
        try:
            year = int(book_data.get('Year-Of-Publication', self.DEFAULT_PUBLICATION_YEAR))
        except (TypeError, ValueError):
            year = self.DEFAULT_PUBLICATION_YEAR
        decade = (year // 10) * 10

        # The encoder lives in the preprocessing dict under 'decade_encoder'.
        # The previous code referenced an undefined global here, so the lookup
        # always failed and the default index was used for every book.
        decade_encoder = (self.preprocessing_info or {}).get('decade_encoder')
        if decade_encoder is None:
            decade_encoded = self.DEFAULT_DECADE_INDEX
        else:
            decade_encoded = self._safe_encode(
                decade_encoder, str(decade), default=self.DEFAULT_DECADE_INDEX
            )

        rating_level = self.DEFAULT_RATING_LEVEL

        return book_id_encoded, publisher_encoded, decade_encoded, rating_level

    # ------------------------------------------------------------------
    # Scoring
    # ------------------------------------------------------------------

    def predict_rating(self, user_id: int, book_isbn: str,
                       user_data: Optional[Dict] = None,
                       book_data: Optional[Dict] = None) -> float:
        """Predict the rating probability for a user-book pair.

        Args:
            user_id: User ID.
            book_isbn: Book ISBN.
            user_data: Additional user data (optional).
            book_data: Additional book data (optional).

        Returns:
            Sigmoid of the model output (0-1).  Returns 0.0 when no model is
            loaded or when any error occurs during scoring (the error is
            printed).
        """
        if self.model is None:
            print("❌ Model not loaded")
            return 0.0

        try:
            dense_features, user_id_encoded, country_encoded, age_group = \
                self._prepare_user_features(user_id, user_data)
            book_id_encoded, publisher_encoded, decade_encoded, rating_level = \
                self._prepare_book_features(book_isbn, book_data)

            # NOTE(review): this order is assumed to match self.cat_cols
            # (one value per categorical column) -- confirm against training.
            kjt_values = [user_id_encoded, book_id_encoded, publisher_encoded,
                          country_encoded, age_group, decade_encoded, rating_level]
            kjt_lengths = [1] * len(kjt_values)

            sparse_features = KeyedJaggedTensor.from_lengths_sync(
                self.cat_cols,
                torch.tensor(kjt_values),
                torch.tensor(kjt_lengths, dtype=torch.int32),
            )

            with torch.no_grad():
                logits = self.model(dense_features=dense_features, sparse_features=sparse_features)
                prediction = torch.sigmoid(logits).item()

            return prediction

        except Exception as e:
            print(f"Error in prediction: {e}")
            return 0.0

    def get_user_recommendations(self, user_id: int,
                                 candidate_books: List[str],
                                 k: int = 10,
                                 user_data: Optional[Dict] = None) -> List[Tuple[str, float]]:
        """Get the top-k book recommendations for a user.

        Args:
            user_id: User ID.
            candidate_books: Candidate book ISBNs to score.
            k: Number of recommendations to return.
            user_data: Additional user data (optional).

        Returns:
            Up to ``k`` ``(book_isbn, prediction_score)`` tuples, best first.
            Empty when no model is loaded.
        """
        if self.model is None:
            print("❌ Model not loaded")
            return []

        print(f"Generating recommendations for user {user_id} from {len(candidate_books)} candidates...")

        scored = [
            (book_isbn, self.predict_rating(user_id, book_isbn, user_data))
            for book_isbn in candidate_books
        ]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return scored[:k]

    def batch_recommend(self, user_ids: List[int],
                        candidate_books: List[str],
                        k: int = 10) -> Dict[int, List[Tuple[str, float]]]:
        """Generate recommendations for multiple users.

        Args:
            user_ids: User IDs.
            candidate_books: Candidate book ISBNs.
            k: Number of recommendations per user.

        Returns:
            Dictionary mapping each user ID to its recommendation list.
        """
        return {
            user_id: self.get_user_recommendations(user_id, candidate_books, k)
            for user_id in user_ids
        }

    def get_similar_books(self, target_book_isbn: str,
                          candidate_books: List[str],
                          sample_users: List[int],
                          k: int = 10) -> List[Tuple[str, float]]:
        """Find books whose predicted scores correlate with the target's.

        Every book is scored for each sample user; similarity is the Pearson
        correlation between a candidate's scores and the target's scores.

        Args:
            target_book_isbn: Target book ISBN.
            candidate_books: Candidate book ISBNs.  Duplicates and the target
                itself are ignored.
            sample_users: Users over which the score profiles are compared.
            k: Number of similar books to return.

        Returns:
            Up to ``k`` ``(book_isbn, similarity_score)`` tuples, most similar
            first.  Candidates with an undefined correlation (fewer than two
            users, or constant scores) are omitted.
        """
        if self.model is None:
            print("❌ Model not loaded")
            return []

        # De-duplicate while keeping order: a repeated ISBN would otherwise
        # collect two scores per user and no longer line up with the target.
        candidates = [
            isbn for isbn in dict.fromkeys(candidate_books) if isbn != target_book_isbn
        ]

        target_scores = np.array(
            [self.predict_rating(user_id, target_book_isbn) for user_id in sample_users]
        )
        # Correlation is undefined without variance in the target profile.
        if target_scores.size < 2 or target_scores.max() == target_scores.min():
            return []

        similarities = []
        for isbn in candidates:
            scores = np.array(
                [self.predict_rating(user_id, isbn) for user_id in sample_users]
            )
            if scores.max() == scores.min():
                continue  # constant profile -> correlation would be NaN
            correlation = np.corrcoef(target_scores, scores)[0, 1]
            if not np.isnan(correlation):
                similarities.append((isbn, float(correlation)))

        similarities.sort(key=lambda pair: pair[1], reverse=True)
        return similarities[:k]
398
+
399
+
400
def load_dlrm_recommender(model_source: str = "latest") -> DLRMBookRecommender:
    """Load a DLRM recommender from one of several sources.

    Args:
        model_source: ``"latest"`` to use the most recent run of the
            ``dlrm-book-recommendation-book_recommender`` MLflow experiment,
            ``"file"`` to use the first existing local checkpoint file, or
            any other string to treat it as an MLflow run ID.

    Returns:
        A ``DLRMBookRecommender``.  Its ``model`` attribute is ``None`` when
        no trained model could be located or loaded, so callers should check
        it before scoring.

    Raises:
        FileNotFoundError: Propagated from ``DLRMBookRecommender`` when the
            preprocessing pickle is missing.
    """
    if model_source == "latest":
        latest_run_id = None
        try:
            experiment = mlflow.get_experiment_by_name('dlrm-book-recommendation-book_recommender')
            if experiment:
                runs = mlflow.search_runs(experiment_ids=[experiment.experiment_id],
                                          order_by=["start_time desc"], max_results=1)
                if len(runs) > 0:
                    latest_run_id = runs.iloc[0].run_id
        except Exception as exc:
            # Best-effort lookup: report why it failed rather than hiding it,
            # then fall through to the "no model" result below.
            print(f"⚠️ Could not query MLflow for the latest run: {exc}")

        # Constructed outside the try so a missing preprocessing file is not
        # mistaken for an MLflow problem.
        if latest_run_id is not None:
            return DLRMBookRecommender(run_id=latest_run_id)

    elif model_source == "file":
        for filename in ('dlrm_book_model_final.pth',
                         'dlrm_book_model_epoch_2.pth',
                         'dlrm_book_model_epoch_1.pth'):
            if os.path.exists(filename):
                return DLRMBookRecommender(model_path=filename)

    else:
        # Anything else is treated as an explicit MLflow run ID.
        return DLRMBookRecommender(run_id=model_source)

    # Only build the model-less instance when nothing else worked; building
    # it up front would load the preprocessing file twice and print a
    # misleading "No model loaded" warning even on success.
    print("⚠️ Could not load any trained model")
    return DLRMBookRecommender()
440
+
441
+
442
def demo_dlrm_recommendations():
    """Demo: print DLRM recommendations and similar books for one sample user.

    Loads a model from a local checkpoint file, reads ``Books.csv``,
    ``Users.csv`` and ``Ratings.csv`` from the current working directory,
    then prints the top recommendations among the first 20 books for the
    first user in the ratings file, a few of that user's actual ratings, and
    books similar to the top recommendation.  Returns early with a message
    when no trained model is available.
    """
    print("🚀 DLRM Book Recommendation Demo")
    print("=" * 50)

    # Load the model first so a missing checkpoint is reported before the
    # (potentially large) CSV files are read.
    recommender = load_dlrm_recommender("file")

    if recommender.model is None:
        print("❌ No trained model found. Please run training first.")
        return

    books_df = pd.read_csv('Books.csv', encoding='latin-1', low_memory=False)
    users_df = pd.read_csv('Users.csv', encoding='latin-1', low_memory=False)
    ratings_df = pd.read_csv('Ratings.csv', encoding='latin-1', low_memory=False)

    # Column names are stripped of stray double quotes.
    for frame in (books_df, users_df, ratings_df):
        frame.columns = frame.columns.str.replace('"', '')

    def find_book(isbn):
        """Return the first Books.csv row with this ISBN, or None."""
        matches = books_df[books_df['ISBN'] == isbn]
        return matches.iloc[0] if len(matches) > 0 else None

    sample_user_id = ratings_df['User-ID'].iloc[0]
    sample_books = books_df['ISBN'].head(20).tolist()

    print(f"\n📚 Getting recommendations for User {sample_user_id}")
    print(f"Testing with {len(sample_books)} candidate books...")

    recommendations = recommender.get_user_recommendations(
        user_id=sample_user_id,
        candidate_books=sample_books,
        k=10
    )

    print(f"\n🎯 Top 10 DLRM Recommendations:")
    print("-" * 50)

    for i, (book_isbn, score) in enumerate(recommendations, 1):
        book = find_book(book_isbn)
        if book is not None:
            title = book['Book-Title']
            author = book['Book-Author']
            print(f"{i:2d}. {title} by {author}")
            print(f"    ISBN: {book_isbn}, Score: {score:.4f}")
        else:
            print(f"{i:2d}. ISBN: {book_isbn}, Score: {score:.4f}")
        print()

    # Show the user's actual ratings for comparison.
    user_ratings = ratings_df[ratings_df['User-ID'] == sample_user_id]
    if len(user_ratings) > 0:
        print(f"\n📖 User {sample_user_id}'s Actual Reading History:")
        print("-" * 50)

        for _, rating in user_ratings.head(5).iterrows():
            book = find_book(rating['ISBN'])
            if book is not None:
                print(f"• {book['Book-Title']} by {book['Book-Author']} - Rating: {rating['Book-Rating']}/10")

    # Test book similarity against the top recommendation.
    if len(recommendations) > 0:
        target_book = recommendations[0][0]
        print(f"\n🔍 Finding books similar to: {target_book}")

        similar_books = recommender.get_similar_books(
            target_book_isbn=target_book,
            candidate_books=sample_books,
            sample_users=ratings_df['User-ID'].head(10).tolist(),
            k=5
        )

        print(f"\n📚 Similar Books:")
        print("-" * 30)
        for i, (book_isbn, similarity) in enumerate(similar_books, 1):
            book = find_book(book_isbn)
            if book is not None:
                print(f"{i}. {book['Book-Title']} (similarity: {similarity:.3f})")


if __name__ == "__main__":
    demo_dlrm_recommendations()