# Spaces: Sleeping
| import pandas as pd | |
| import numpy as np | |
| import torch | |
| import datetime | |
| from models import SASRec | |
| from utils import prepare_ground_truth, calculate_metrics, load_item_properties, load_category_tree, get_popular_items, show_user_recommendations | |
| from data_prepare import prepare_data, SASRecDataset, SASRecDataModule | |
def main(
    checkpoint_path="checkpoints/sasrec-epoch=06-val_hitrate@10=0.3629.ckpt",
    data_folder="data/",
):
    """Run the inference and qualitative analysis pipeline.

    Loads a SASRec model from ``checkpoint_path``, prepares the data splits
    found under ``data_folder``, computes a list of popular items from the
    training split, and prints recommendations for up to three users that
    appear both in the data module's user history and in its test frame.
    Finally it prints recommendations for the user id ``-999``, which is
    presumably absent from the data and therefore exercises the popular-item
    fallback (confirm against ``show_user_recommendations``).

    Args:
        checkpoint_path: Path of the model checkpoint to load.
        data_folder: Folder handed to the data and metadata loading helpers.

    Returns:
        None. All results are reported through ``print`` and the helper
        ``show_user_recommendations``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("Loading model from checkpoint...")
    # map_location lets a checkpoint that was saved from a GPU be restored on
    # a CPU-only host instead of failing during deserialization.
    # NOTE(review): assumes SASRec is a Lightning module whose
    # load_from_checkpoint accepts map_location -- confirm in models.py.
    best_model = SASRec.load_from_checkpoint(checkpoint_path, map_location=device)
    best_model.to(device)
    # This script only runs inference: leave training mode so dropout and
    # similar train-time layers do not randomize the recommendations.
    best_model.eval()

    print("Preparing data...")
    train_set, validation_set, test_set = prepare_data(data_folder=data_folder)
    datamodule = SASRecDataModule(train_set, validation_set, test_set)
    datamodule.setup()

    item_category_map = load_item_properties(data_folder=data_folder)
    category_parent_map = load_category_tree(data_folder=data_folder)

    print("\nCalculating popular items for cold-start users...")
    popular_items_list = get_popular_items(train_set, k=10)

    # Only users present in both the stored history and the test frame are
    # eligible as qualitative examples.
    users_in_train_history = set(datamodule.user_history.keys())
    users_in_test_set = set(datamodule.test_df['visitorid'].unique())
    valid_example_users = list(users_in_train_history.intersection(users_in_test_set))
    print(f"\nFound {len(valid_example_users)} users for qualitative analysis.")

    for user_id in valid_example_users[:3]:
        show_user_recommendations(
            user_id,
            best_model,
            datamodule,
            popular_items_list,
            item_category_map,
            category_parent_map,
        )

    # An id that is not expected to exist in the data set.
    new_user_id = -999
    show_user_recommendations(
        new_user_id,
        best_model,
        datamodule,
        popular_items_list,
        item_category_map,
        category_parent_map,
    )
# Script entry point: run the pipeline with its default checkpoint and data
# folder when this file is executed directly (not when it is imported).
if __name__ == "__main__":
    main()