Duskfallcrew commited on
Commit
5b28d86
·
verified ·
1 Parent(s): d4bcc30

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -10
app.py CHANGED
@@ -23,7 +23,9 @@ from typing import Dict, List, Optional
23
  from huggingface_hub import login, HfApi, hf_hub_download # Import hf_hub_download
24
  from huggingface_hub.utils import validate_repo_id, HFValidationError
25
  from huggingface_hub.errors import HfHubHTTPError
26
-
 
 
27
 
28
  # ---------------------- DEPENDENCIES ----------------------
29
  def install_dependencies_gradio():
@@ -58,26 +60,50 @@ def create_model_repo(api, user, orgs_name, model_name, make_private=False):
58
 
59
  # ---------------------- MODEL LOADING AND CONVERSION ----------------------
60
  def download_model(model_path_or_url):
61
- """Downloads a model, handling URLs, HF repos, and local paths."""
62
  try:
63
- # Check if it's a valid Hugging Face repo ID (and potentially a file within)
64
  try:
65
  validate_repo_id(model_path_or_url)
66
- # It's a valid repo ID; use hf_hub_download without a filename
67
  local_path = hf_hub_download(repo_id=model_path_or_url)
68
  return local_path
69
  except HFValidationError:
70
- pass # Not a simple repo ID. Might be repo ID + filename, or a URL.
71
 
 
72
  if model_path_or_url.startswith("http://") or model_path_or_url.startswith("https://"):
73
- # It's a URL: use hf_hub_download with the url parameter
74
- local_path = hf_hub_download(url=model_path_or_url) # Corrected line
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  return local_path
 
 
76
  elif os.path.isfile(model_path_or_url):
77
- # It's a local file
78
  return model_path_or_url
 
 
79
  else:
80
- # Try splitting into repo ID and filename
81
  try:
82
  parts = model_path_or_url.split("/", 1)
83
  if len(parts) == 2:
@@ -86,7 +112,8 @@ def download_model(model_path_or_url):
86
  local_path = hf_hub_download(repo_id=repo_id, filename=filename)
87
  return local_path
88
  else:
89
- raise ValueError("Invalid input format")
 
90
  except HFValidationError:
91
  raise ValueError(f"Invalid model path or URL: {model_path_or_url}")
92
 
 
23
  from huggingface_hub import login, HfApi, hf_hub_download # Import hf_hub_download
24
  from huggingface_hub.utils import validate_repo_id, HFValidationError
25
  from huggingface_hub.errors import HfHubHTTPError
26
+ from huggingface_hub import HfApi, hf_hub_download, cached_download, get_from_cache # Import cached_download and get_from_cache
27
+ from huggingface_hub.utils import validate_repo_id, HFValidationError
28
+ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
29
 
30
  # ---------------------- DEPENDENCIES ----------------------
31
  def install_dependencies_gradio():
 
60
 
61
  # ---------------------- MODEL LOADING AND CONVERSION ----------------------
62
  def download_model(model_path_or_url):
63
+ """Downloads a model, handling URLs, HF repos, and local paths, caching appropriately."""
64
  try:
65
+ # 1. Check if it's a valid Hugging Face repo ID (and potentially a file within)
66
  try:
67
  validate_repo_id(model_path_or_url)
68
+ # It's a valid repo ID; use hf_hub_download
69
  local_path = hf_hub_download(repo_id=model_path_or_url)
70
  return local_path
71
  except HFValidationError:
72
+ pass # Not a simple repo ID. Might be repo ID + filename, or a URL.
73
 
74
+ # 2. Check if it's a URL
75
  if model_path_or_url.startswith("http://") or model_path_or_url.startswith("https://"):
76
+ # Check if it's already in the cache
77
+ cache_path = get_from_cache(model_path_or_url) # Use get_from_cache
78
+ if cache_path is not None:
79
+ return cache_path
80
+
81
+ # It's a URL and not in cache: download manually and put into HF cache
82
+ response = requests.get(model_path_or_url, stream=True)
83
+ response.raise_for_status() # Raise HTTPError for bad requests (4xx or 5xx)
84
+
85
+ # Get filename from URL, or use a hash if we can't determine it
86
+ parsed_url = urlparse(model_path_or_url)
87
+ filename = os.path.basename(unquote(parsed_url.path))
88
+ if not filename:
89
+ filename = hashlib.sha256(model_path_or_url.encode()).hexdigest()
90
+
91
+ # Construct the cache path (using HF_HUB_CACHE + "downloads" )
92
+ cache_dir = os.path.join(HUGGINGFACE_HUB_CACHE, "downloads")
93
+ os.makedirs(cache_dir, exist_ok=True) # Ensure the cache directory exists
94
+ local_path = os.path.join(cache_dir, filename)
95
+
96
+ with open(local_path, "wb") as f:
97
+ for chunk in response.iter_content(chunk_size=8192):
98
+ f.write(chunk)
99
  return local_path
100
+
101
+ # 3. Check if it's a local file
102
  elif os.path.isfile(model_path_or_url):
 
103
  return model_path_or_url
104
+
105
+ # 4. Handle Hugging Face repo with a specific file
106
  else:
 
107
  try:
108
  parts = model_path_or_url.split("/", 1)
109
  if len(parts) == 2:
 
112
  local_path = hf_hub_download(repo_id=repo_id, filename=filename)
113
  return local_path
114
  else:
115
+ raise ValueError("Invalid input format.")
116
+
117
  except HFValidationError:
118
  raise ValueError(f"Invalid model path or URL: {model_path_or_url}")
119