msIntui commited on
Commit
984fc5d
Β·
1 Parent(s): e2c1993

feat: add progress bar for model downloads and update deployment files

Browse files
Files changed (1) hide show
  1. download_models.py +31 -6
download_models.py CHANGED
@@ -1,10 +1,33 @@
1
  import os
2
  from azure.storage.blob import BlobServiceClient
3
  from dotenv import load_dotenv
 
4
 
5
  # Load environment variables
6
  load_dotenv()
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  def download_from_azure():
9
  """Download models from Azure Blob Storage"""
10
  try:
@@ -15,6 +38,7 @@ def download_from_azure():
15
 
16
  container_name = os.getenv('AZURE_STORAGE_CONTAINER_NAME', 'pnid-models')
17
 
 
18
  # Create the BlobServiceClient
19
  blob_service_client = BlobServiceClient.from_connection_string(connect_str)
20
  container_client = blob_service_client.get_container_client(container_name)
@@ -38,20 +62,21 @@ def download_from_azure():
38
  if not local_path:
39
  continue
40
 
41
- print(f"Downloading {blob_name}...")
42
  os.makedirs(os.path.dirname(local_path), exist_ok=True)
43
 
44
  if not os.path.exists(local_path) or os.path.getsize(local_path) == 0:
45
  blob_client = container_client.get_blob_client(blob_name)
46
- with open(local_path, "wb") as file:
47
- data = blob_client.download_blob()
48
- file.write(data.readall())
49
- print(f"Downloaded {blob_name} to {local_path}")
50
  else:
51
  print(f"Skipping {blob_name}, already exists")
52
 
 
 
53
  except Exception as e:
54
- print(f"Error downloading models from Azure: {str(e)}")
55
  raise
56
 
57
  if __name__ == "__main__":
 
1
  import os
2
  from azure.storage.blob import BlobServiceClient
3
  from dotenv import load_dotenv
4
+ from tqdm import tqdm
5
 
6
  # Load environment variables
7
  load_dotenv()
8
 
9
+ def download_with_progress(blob_client, local_path):
10
+ """Download blob with progress bar"""
11
+ # Get blob properties for total size
12
+ properties = blob_client.get_blob_properties()
13
+ total_size = properties.size
14
+
15
+ # Create progress bar
16
+ progress = tqdm(
17
+ total=total_size,
18
+ unit='iB',
19
+ unit_scale=True,
20
+ desc=f"Downloading {os.path.basename(local_path)}"
21
+ )
22
+
23
+ # Download in chunks
24
+ with open(local_path, "wb") as file:
25
+ download_stream = blob_client.download_blob()
26
+ for chunk in download_stream.chunks():
27
+ file.write(chunk)
28
+ progress.update(len(chunk))
29
+ progress.close()
30
+
31
  def download_from_azure():
32
  """Download models from Azure Blob Storage"""
33
  try:
 
38
 
39
  container_name = os.getenv('AZURE_STORAGE_CONTAINER_NAME', 'pnid-models')
40
 
41
+ print("Connecting to Azure Blob Storage...")
42
  # Create the BlobServiceClient
43
  blob_service_client = BlobServiceClient.from_connection_string(connect_str)
44
  container_client = blob_service_client.get_container_client(container_name)
 
62
  if not local_path:
63
  continue
64
 
65
+ print(f"\nChecking {blob_name}...")
66
  os.makedirs(os.path.dirname(local_path), exist_ok=True)
67
 
68
  if not os.path.exists(local_path) or os.path.getsize(local_path) == 0:
69
  blob_client = container_client.get_blob_client(blob_name)
70
+ print(f"Starting download of {blob_name}...")
71
+ download_with_progress(blob_client, local_path)
72
+ print(f"Successfully downloaded {blob_name}")
 
73
  else:
74
  print(f"Skipping {blob_name}, already exists")
75
 
76
+ print("\nAll models downloaded successfully!")
77
+
78
  except Exception as e:
79
+ print(f"\nError downloading models from Azure: {str(e)}")
80
  raise
81
 
82
  if __name__ == "__main__":