msIntui committed on
Commit
897d2f8
·
1 Parent(s): 984fc5d

Add Hugging Face Spaces upload functionality

Browse files

- Add direct upload to HF Spaces
- Improve download progress tracking
- Add better error handling
- Implement smaller chunk sizes for downloads

Files changed (1) hide show
  1. download_models.py +81 -38
download_models.py CHANGED
@@ -2,51 +2,80 @@ import os
2
  from azure.storage.blob import BlobServiceClient
3
  from dotenv import load_dotenv
4
  from tqdm import tqdm
 
 
5
 
6
  # Load environment variables
7
  load_dotenv()
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def download_with_progress(blob_client, local_path):
10
- """Download blob with progress bar"""
11
- # Get blob properties for total size
12
- properties = blob_client.get_blob_properties()
13
- total_size = properties.size
14
-
15
- # Create progress bar
16
- progress = tqdm(
17
- total=total_size,
18
- unit='iB',
19
- unit_scale=True,
20
- desc=f"Downloading {os.path.basename(local_path)}"
21
- )
22
-
23
- # Download in chunks
24
- with open(local_path, "wb") as file:
25
- download_stream = blob_client.download_blob()
26
- for chunk in download_stream.chunks():
27
- file.write(chunk)
28
- progress.update(len(chunk))
29
- progress.close()
30
-
31
- def download_from_azure():
32
- """Download models from Azure Blob Storage"""
 
 
 
 
 
 
 
 
 
 
 
33
  try:
34
- # Try primary connection first, fallback to secondary
35
  connect_str = os.getenv('AZURE_STORAGE_PRIMARY_CONNECTION')
36
  if not connect_str:
37
  connect_str = os.getenv('AZURE_STORAGE_SECONDARY_CONNECTION')
38
-
39
  container_name = os.getenv('AZURE_STORAGE_CONTAINER_NAME', 'pnid-models')
 
 
40
 
41
  print("Connecting to Azure Blob Storage...")
42
- # Create the BlobServiceClient
43
  blob_service_client = BlobServiceClient.from_connection_string(connect_str)
44
  container_client = blob_service_client.get_container_client(container_name)
45
 
46
- # Create models directory
47
  os.makedirs('models', exist_ok=True)
48
 
49
- # Define model files to download
50
  model_files = {
51
  'Intui_SDM_41.pt': os.getenv('MODEL_SDM_41_PATH'),
52
  'Intui_SDM_30.pt': os.getenv('MODEL_SDM_30_PATH'),
@@ -57,27 +86,41 @@ def download_from_azure():
57
  'intui_LDM_01.pt': os.getenv('MODEL_LDM_PATH')
58
  }
59
 
60
- # Download each model
61
  for blob_name, local_path in model_files.items():
62
  if not local_path:
63
  continue
64
-
65
- print(f"\nChecking {blob_name}...")
66
  os.makedirs(os.path.dirname(local_path), exist_ok=True)
67
-
68
- if not os.path.exists(local_path) or os.path.getsize(local_path) == 0:
 
 
 
 
 
 
69
  blob_client = container_client.get_blob_client(blob_name)
70
- print(f"Starting download of {blob_name}...")
71
- download_with_progress(blob_client, local_path)
 
 
 
 
72
  print(f"Successfully downloaded {blob_name}")
73
  else:
74
  print(f"Skipping {blob_name}, already exists")
75
 
76
- print("\nAll models downloaded successfully!")
 
 
 
 
77
 
78
  except Exception as e:
79
- print(f"\nError downloading models from Azure: {str(e)}")
80
  raise
81
 
82
  if __name__ == "__main__":
83
- download_from_azure()
 
 
2
  from azure.storage.blob import BlobServiceClient
3
  from dotenv import load_dotenv
4
  from tqdm import tqdm
5
+ import requests
6
+ from huggingface_hub import HfApi, upload_file
7
 
8
  # Load environment variables
9
  load_dotenv()
10
 
def upload_to_hf_spaces(local_path, repo_id, token, repo_type="space"):
    """Upload a local file to the root of a Hugging Face repository.

    Args:
        local_path: Path of the file on disk. Its base name is reused as the
            file's path inside the repository.
        repo_id: Identifier of the target repository.
        token: Access token forwarded to ``HfApi.upload_file``.
        repo_type: Kind of repository to upload to. Defaults to ``"space"``
            because ``HfApi.upload_file`` treats a missing ``repo_type`` as a
            model repository, which is not what this helper is meant for.

    Returns:
        True if the upload call completed, False if it raised (the error is
        printed rather than propagated).
    """
    file_name = os.path.basename(local_path)
    try:
        api = HfApi()
        print(f"\nUploading {file_name} to {repo_id}...")
        response = api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=file_name,
            repo_id=repo_id,
            repo_type=repo_type,
            token=token,
        )
        print(f"Successfully uploaded to {response}")
        return True
    except Exception as e:
        # Best-effort by design: a failed upload must not abort the caller's
        # loop over the remaining model files.
        print(f"Error uploading to Hugging Face: {str(e)}")
        return False
27
+
def download_with_progress(blob_client, local_path):
    """Stream a blob to ``local_path`` while displaying a tqdm progress bar.

    Args:
        blob_client: Client object for the blob to fetch; it must provide
            ``get_blob_properties()`` and ``download_blob()``.
        local_path: Destination file path. It is opened in ``"wb"`` mode, so
            an existing file is overwritten.

    Returns:
        True when the whole stream was written, False if anything raised.
        On failure the error is printed and any file at ``local_path`` is
        removed so a partial download is never mistaken for a complete one.
    """
    try:
        total_size = blob_client.get_blob_properties().size

        # Both resources are context-managed so the file handle and the
        # progress bar are released even when the transfer fails midway.
        with open(local_path, "wb") as file, tqdm(
            total=total_size,
            unit='B',
            unit_scale=True,
            desc=f"Downloading {os.path.basename(local_path)}",
            ncols=80
        ) as progress:
            download_stream = blob_client.download_blob()
            # StorageStreamDownloader.chunks() accepts no arguments; passing
            # chunk_size raised TypeError and made every download fail.
            # Chunk size is controlled on the client (max_chunk_get_size).
            for chunk in download_stream.chunks():
                if chunk:
                    file.write(chunk)
                    progress.update(len(chunk))

        return True
    except Exception as e:
        print(f"\nError during download: {str(e)}")
        # Clean up partial download
        if os.path.exists(local_path):
            os.remove(local_path)
        return False
60
+
61
+ def download_from_azure(upload_to_hf=False):
62
+ """Download models from Azure Blob Storage with optional HF upload"""
63
  try:
64
+ # Get connection strings
65
  connect_str = os.getenv('AZURE_STORAGE_PRIMARY_CONNECTION')
66
  if not connect_str:
67
  connect_str = os.getenv('AZURE_STORAGE_SECONDARY_CONNECTION')
68
+
69
  container_name = os.getenv('AZURE_STORAGE_CONTAINER_NAME', 'pnid-models')
70
+ hf_token = os.getenv('HF_TOKEN')
71
+ hf_repo = os.getenv('HF_REPO_ID')
72
 
73
  print("Connecting to Azure Blob Storage...")
 
74
  blob_service_client = BlobServiceClient.from_connection_string(connect_str)
75
  container_client = blob_service_client.get_container_client(container_name)
76
 
 
77
  os.makedirs('models', exist_ok=True)
78
 
 
79
  model_files = {
80
  'Intui_SDM_41.pt': os.getenv('MODEL_SDM_41_PATH'),
81
  'Intui_SDM_30.pt': os.getenv('MODEL_SDM_30_PATH'),
 
86
  'intui_LDM_01.pt': os.getenv('MODEL_LDM_PATH')
87
  }
88
 
 
89
  for blob_name, local_path in model_files.items():
90
  if not local_path:
91
  continue
92
+
93
+ print(f"\nProcessing {blob_name}...")
94
  os.makedirs(os.path.dirname(local_path), exist_ok=True)
95
+
96
+ # Check if file needs to be downloaded
97
+ needs_download = (
98
+ not os.path.exists(local_path) or
99
+ os.path.getsize(local_path) == 0
100
+ )
101
+
102
+ if needs_download:
103
  blob_client = container_client.get_blob_client(blob_name)
104
+ success = download_with_progress(blob_client, local_path)
105
+
106
+ if not success:
107
+ print(f"Failed to download {blob_name}")
108
+ continue
109
+
110
  print(f"Successfully downloaded {blob_name}")
111
  else:
112
  print(f"Skipping {blob_name}, already exists")
113
 
114
+ # Upload to HF Spaces if requested
115
+ if upload_to_hf and hf_token and hf_repo:
116
+ upload_to_hf_spaces(local_path, hf_repo, hf_token)
117
+
118
+ print("\nAll operations completed!")
119
 
120
  except Exception as e:
121
+ print(f"\nError in main process: {str(e)}")
122
  raise
123
 
124
  if __name__ == "__main__":
125
+ # Set to True if you want to upload to HF Spaces
126
+ download_from_azure(upload_to_hf=False)