lenawilli commited on
Commit
6f0fc78
·
verified ·
1 Parent(s): 7f34736

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +12 -36
src/streamlit_app.py CHANGED
@@ -20,50 +20,26 @@ from recommendation_utils import (
20
  # recommend_with_nn, recommend_with_svd, load_encodings
21
  #)
22
 
23
- @st.cache_resource(show_spinner=False)
24
- def download_models_once():
25
- DOWNLOAD_DIR = "/tmp"
26
-
27
- files = {
28
- #"encodings.pkl": "1EzpdpaopfUp-Tfc7YjxPVYQUwnU_BX5-",
29
- #"config.json": "1ZcTrVR0QtS-5EL4amsTR_y9xITDA6zlJ",
30
- #"model.weights.h5": "1VjjUx_7ulIVM-W1lqH-nDHI7HWfpcMQP",
31
- "svd_model.pkl": "1fN2biQruVjJHHv2vX1g1hLuMeJoPqyFX",
32
- "trainset.pkl": "1IDVVAQ57Xvf3HCAikbOSQgAigdHP7Ik7"
33
- }
34
 
35
- def gdrive_download(file_id, destination):
36
- URL = "https://drive.google.com/uc?export=download"
37
- session = requests.Session()
38
- response = session.get(URL, params={'id': file_id}, stream=True)
39
- token = None
40
- for k, v in response.cookies.items():
41
- if k.startswith('download_warning'):
42
- token = v
43
- if token:
44
- response = session.get(URL, params={'id': file_id, 'confirm': token}, stream=True)
45
- with open(destination, "wb") as f:
46
- for chunk in response.iter_content(32768):
47
- if chunk:
48
- f.write(chunk)
49
-
50
- for filename, file_id in files.items():
51
- full_path = os.path.join(DOWNLOAD_DIR, filename)
52
- if not os.path.exists(full_path):
53
- with st.spinner(f"Downloading {filename}..."):
54
- gdrive_download(file_id, full_path)
55
-
56
- download_models_once()
57
 
58
  @st.cache_resource
59
  def load_models():
60
- # nn_model = load_nn_model("/tmp/config.json", "/tmp/model.weights.h5")
61
- svd_model = load_svd_model("/tmp/svd_model.pkl")
62
- trainset = load_trainset("/tmp/trainset.pkl")
 
 
 
63
  return svd_model, trainset
64
 
65
  svd_model, trainset = load_models()
66
 
 
67
  # encodings = load_encodings("/tmp/encodings.pkl")
68
 
69
  st.set_page_config(layout="wide")
 
20
  # recommend_with_nn, recommend_with_svd, load_encodings
21
  #)
22
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
+ @st.cache_resource
25
+ def load_remote_pickle(url):
26
+ response = requests.get(url)
27
+ response.raise_for_status()
28
+ return pickle.loads(response.content)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  @st.cache_resource
31
  def load_models():
32
+ SVD_URL = "https://huggingface.co/lenawilli/App_models_Py/resolve/main/svd_model.pkl"
33
+ TRAINSET_URL = "https://huggingface.co/lenawilli/App_models_Py/resolve/main/trainset.pkl"
34
+
35
+ svd_model = load_remote_pickle(SVD_URL)
36
+ trainset = load_remote_pickle(TRAINSET_URL)
37
+
38
  return svd_model, trainset
39
 
40
  svd_model, trainset = load_models()
41
 
42
+
43
  # encodings = load_encodings("/tmp/encodings.pkl")
44
 
45
  st.set_page_config(layout="wide")