rajaatif786 commited on
Commit
6248d60
·
verified ·
1 Parent(s): e908ffb

Update load_model.py

Browse files
Files changed (1) hide show
  1. load_model.py +49 -5
load_model.py CHANGED
@@ -1,9 +1,53 @@
1
  import os
2
- import git
3
- #if(os.path.isdir("./TickerExtraction")==False):
4
- git.Git("./").clone("https://huggingface.co/rajaatif786/TickerExtraction")
5
- print(os.path.exists("./rajaatif786/TickerExtraction/entity_model2.pt"))
6
- #st.write(os.listdir("./TickerExtraction/"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  #x = st.slider('Select a value')
8
  #st.write(x, 'squared is', x * x)
9
 
 
1
  import os
2
+ import os
3
+ import requests
4
+
5
+ def download_model_if_not_exists(model_url, destination_dir):
6
+ """Downloads the model from the given URL if it doesn't exist locally.
7
+
8
+ Args:
9
+ model_url (str): The URL of the model to download.
10
+ destination_dir (str): The directory to download the model to.
11
+
12
+ Returns:
13
+ bool: True if the model was downloaded, False if it already existed.
14
+ """
15
+
16
+ if not os.path.exists(destination_dir):
17
+ print(f"Model directory '{destination_dir}' not found. Creating it.")
18
+ os.makedirs(destination_dir)
19
+
20
+ model_path = os.path.join(destination_dir, "entity_model2.pt") # Assuming model filename
21
+ if os.path.exists(model_path):
22
+ print(f"Model already exists at '{model_path}'. Skipping download.")
23
+ return False
24
+
25
+ print(f"Downloading model from '{model_url}' to '{model_path}'...")
26
+ try:
27
+ response = requests.get(model_url, stream=True)
28
+ response.raise_for_status() # Raise an exception for non-2xx status codes
29
+ with open(model_path, 'wb') as f:
30
+ for chunk in response.iter_content(1024):
31
+ f.write(chunk)
32
+ print("Download complete.")
33
+ return True
34
+ except requests.exceptions.RequestException as e:
35
+ print(f"Error downloading model: {e}")
36
+ return False
37
+
38
+ # Assuming your model URL is https://huggingface.co/rajaatif786/TickerExtraction
39
+ model_url = "https://huggingface.co/rajaatif786/TickerExtraction"
40
+ model_dir = "./TickerExtraction" # Change this if needed
41
+
42
+ downloaded = download_model_if_not_exists(model_url, model_dir)
43
+
44
+ if downloaded:
45
+ # Import your entity extraction module (assuming it's in 'TickerExtraction')
46
+ from TickerExtraction import load_model, entity_extractor, LABEL_MAP
47
+
48
+ # Rest of your code using the imported functions...
49
+ else:
50
+ print("Model download failed. Please check the model URL and network connectivity.")
51
  #x = st.slider('Select a value')
52
  #st.write(x, 'squared is', x * x)
53