# Author: DanielKiani — "Initial commit of recommender system project" (commit 38ae75d, 1.92 kB)
import pandas as pd
import numpy as np
import torch
import datetime
from models import SASRec
from utils import prepare_ground_truth, calculate_metrics, load_item_properties, load_category_tree, get_popular_items, show_user_recommendations
from data_prepare import prepare_data, SASRecDataset, SASRecDataModule
def main(
    checkpoint_path="checkpoints/sasrec-epoch=06-val_hitrate@10=0.3629.ckpt",
    data_folder="data/",
    num_examples=3,
    cold_start_user_id=-999,
):
    """
    Main function to run the inference and qualitative analysis pipeline.

    Loads a trained SASRec model from ``checkpoint_path``, prepares the data
    splits and item metadata from ``data_folder``, and prints recommendations
    (via ``show_user_recommendations``) for up to ``num_examples`` users that
    appear both in the datamodule's user history and in its test set. Finally
    it prints recommendations for ``cold_start_user_id`` to exercise the
    cold-start path, which is given the ``popular_items_list`` fallback.

    Args:
        checkpoint_path: Checkpoint file passed to ``SASRec.load_from_checkpoint``.
        data_folder: Folder passed to ``prepare_data``, ``load_item_properties``
            and ``load_category_tree``.
        num_examples: Maximum number of known users to show recommendations for.
        cold_start_user_id: User id used to demonstrate the cold-start case.
            NOTE(review): assumed not to occur in the training history — confirm.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("Loading model from checkpoint...")
    # map_location remaps stored tensors onto the chosen device, so a checkpoint
    # saved on a GPU can still be loaded on a CPU-only machine.
    best_model = SASRec.load_from_checkpoint(checkpoint_path, map_location=device)
    best_model.to(device)
    # Inference only: switch off dropout and other training-time behavior.
    best_model.eval()

    print("Preparing data...")
    train_set, validation_set, test_set = prepare_data(data_folder=data_folder)
    datamodule = SASRecDataModule(train_set, validation_set, test_set)
    datamodule.setup()
    item_category_map = load_item_properties(data_folder=data_folder)
    category_parent_map = load_category_tree(data_folder=data_folder)

    print("\nCalculating popular items for cold-start users...")
    popular_items_list = get_popular_items(train_set, k=10)

    # Candidates must have a training history (so the model has a sequence to
    # condition on) and also appear in the test split.
    users_in_train_history = set(datamodule.user_history.keys())
    users_in_test_set = set(datamodule.test_df['visitorid'].unique())
    # Sort so the selected example users are explicit and reproducible rather
    # than depending on set iteration order.
    valid_example_users = sorted(users_in_train_history & users_in_test_set)
    print(f"\nFound {len(valid_example_users)} users for qualitative analysis.")

    # No gradients are needed for generating recommendations.
    with torch.no_grad():
        for user_id in valid_example_users[:num_examples]:
            show_user_recommendations(
                user_id,
                best_model,
                datamodule,
                popular_items_list,
                item_category_map,
                category_parent_map,
            )
        show_user_recommendations(
            cold_start_user_id,
            best_model,
            datamodule,
            popular_items_list,
            item_category_map,
            category_parent_map,
        )
# Script entry point: run the pipeline with main()'s default checkpoint path
# and data folder. Importing this module does not trigger the pipeline.
if __name__ == "__main__":
    main()