Daniel kiani commited on
Commit
66ca941
·
verified ·
1 Parent(s): c6233c8

Updated app.py to use the RetailRocket-Recommender-Data in huggingface datasets

Browse files
Files changed (1) hide show
  1. scripts/app.py +31 -7
scripts/app.py CHANGED
@@ -3,8 +3,12 @@ import torch
3
  import numpy as np
4
  import pandas as pd
5
  from datetime import datetime, timedelta
 
 
 
 
6
  from models import SASRec
7
- from data_prepare import SASRecDataset, SASRecDataModule, prepare_data
8
  from utils import load_item_properties, load_category_tree, get_popular_items
9
 
10
  # --- Global variables to hold loaded artifacts ---
@@ -16,29 +20,49 @@ CATEGORY_PARENT_MAP = None
16
  POPULAR_ITEMS = None
17
 
18
  # --- Data Loading and Preparation Functions ---
19
-
20
  def load_artifacts():
21
  """
22
- Loads all necessary artifacts (model, data, mappings) into global variables.
 
23
  This function is called only once when the app starts.
24
  """
25
  global MODEL, DATAMODULE, ITEM_CATEGORY_MAP, CATEGORY_PARENT_MAP, POPULAR_ITEMS
26
 
27
  print("--- Loading all artifacts for the Gradio app ---")
28
 
29
- # HF-FRIENDLY: Path is relative, assuming the checkpoint is in the root of the Space repo.
30
  CHECKPOINT_PATH = "checkpoints/sasrec-epoch=06-val_hitrate@10=0.3629.ckpt"
31
  DATA_FOLDER = "data/"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
  print(f"Using device: {device}")
35
-
36
  print(f"Loading model from checkpoint: {CHECKPOINT_PATH}...")
37
  MODEL = SASRec.load_from_checkpoint(CHECKPOINT_PATH)
38
  MODEL.to(device)
39
  MODEL.eval()
40
 
41
- print("Preparing data...")
42
  train_set, validation_set, test_set = prepare_data(data_folder=DATA_FOLDER)
43
 
44
  DATAMODULE = SASRecDataModule(train_set, validation_set, test_set)
@@ -156,7 +180,7 @@ if __name__ == "__main__":
156
  inputs=visitor_id_input,
157
  outputs=[history_output, recs_output, status_message]
158
  )
159
-
160
  # For local testing, this creates a shareable link.
161
  # On Hugging Face Spaces, this is not strictly necessary but doesn't hurt.
162
  iface.launch(share=True)
 
3
  import numpy as np
4
  import pandas as pd
5
  from datetime import datetime, timedelta
6
+ import os
7
+ from huggingface_hub import hf_hub_download
8
+
9
+ # Import from your project's modules
10
  from models import SASRec
11
+ from data_prepare import SASRecDataModule, prepare_data
12
  from utils import load_item_properties, load_category_tree, get_popular_items
13
 
14
  # --- Global variables to hold loaded artifacts ---
 
20
  POPULAR_ITEMS = None
21
 
22
  # --- Data Loading and Preparation Functions ---
 
23
  def load_artifacts():
24
  """
25
+ Downloads data from Hugging Face Hub, then loads all necessary artifacts
26
+ (model, data, mappings) into global variables.
27
  This function is called only once when the app starts.
28
  """
29
  global MODEL, DATAMODULE, ITEM_CATEGORY_MAP, CATEGORY_PARENT_MAP, POPULAR_ITEMS
30
 
31
  print("--- Loading all artifacts for the Gradio app ---")
32
 
33
+ # Configuration
34
  CHECKPOINT_PATH = "checkpoints/sasrec-epoch=06-val_hitrate@10=0.3629.ckpt"
35
  DATA_FOLDER = "data/"
36
+ DATA_REPO_ID = "Deathshot78/RetailRocket-Recommender-Data"
37
+
38
+ # --- Download Data from Hugging Face Hub ---
39
+ print(f"Downloading data from Hugging Face Hub repo: {DATA_REPO_ID}")
40
+ os.makedirs(DATA_FOLDER, exist_ok=True)
41
+
42
+ files_to_download = [
43
+ "events.csv", "item_properties_part1.csv",
44
+ "item_properties_part2.csv", "category_tree.csv"
45
+ ]
46
+
47
+ for filename in files_to_download:
48
+ hf_hub_download(
49
+ repo_id=DATA_REPO_ID,
50
+ filename=f"data/{filename}", # Path within the dataset repo
51
+ local_dir=".", # Download to the root of the Space
52
+ repo_type="dataset"
53
+ )
54
+ print("All data files downloaded successfully.")
55
+ # --- End of Download Logic ---
56
 
57
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
  print(f"Using device: {device}")
59
+
60
  print(f"Loading model from checkpoint: {CHECKPOINT_PATH}...")
61
  MODEL = SASRec.load_from_checkpoint(CHECKPOINT_PATH)
62
  MODEL.to(device)
63
  MODEL.eval()
64
 
65
+ print("Preparing data from downloaded files...")
66
  train_set, validation_set, test_set = prepare_data(data_folder=DATA_FOLDER)
67
 
68
  DATAMODULE = SASRecDataModule(train_set, validation_set, test_set)
 
180
  inputs=visitor_id_input,
181
  outputs=[history_output, recs_output, status_message]
182
  )
183
+
184
  # For local testing, this creates a shareable link.
185
  # On Hugging Face Spaces, this is not strictly necessary but doesn't hurt.
186
  iface.launch(share=True)