Spaces:
Runtime error
Runtime error
| import os | |
| from azure.storage.blob import BlobServiceClient | |
| from dotenv import load_dotenv | |
| from tqdm import tqdm | |
| import requests | |
| from huggingface_hub import HfApi, upload_file | |
# Load environment variables via python-dotenv (by default from a .env file) so
# the os.getenv() lookups below can see the AZURE_*, HF_* and MODEL_*_PATH settings.
load_dotenv()
def upload_to_hf_spaces(local_path, repo_id, token, repo_type="space", path_in_repo=None):
    """Upload a single local file to a Hugging Face Hub repository (a Space by default).

    Args:
        local_path: Path of the file on disk to upload.
        repo_id: Target repository id, e.g. ``"user/my-space"``.
        token: Hugging Face access token with write access to ``repo_id``.
        repo_type: Hub repository type (``"space"``, ``"model"`` or ``"dataset"``).
            Defaults to ``"space"``: ``HfApi.upload_file`` otherwise assumes a
            model repo, so a Space ``repo_id`` would not be found.
        path_in_repo: Destination path inside the repo. Defaults to the file's
            base name, i.e. the file lands at the repo root.

    Returns:
        True on success, False if the upload raised (the error is printed and
        swallowed so the caller can continue with other files).
    """
    file_name = os.path.basename(local_path)
    if path_in_repo is None:
        path_in_repo = file_name

    print(f"\nUploading {file_name} to {repo_id}...")
    try:
        api = HfApi()
        response = api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            repo_type=repo_type,
            token=token,
        )
    except Exception as e:
        # Best-effort by design: report and let the caller move on.
        print(f"Error uploading to Hugging Face: {str(e)}")
        return False

    print(f"Successfully uploaded to {response}")
    return True
def download_with_progress(blob_client, local_path):
    """Download one Azure blob to ``local_path`` while showing a tqdm progress bar.

    Data is streamed into ``<local_path>.part`` and moved onto ``local_path``
    only after the whole stream has been written. An interrupted or failed
    download therefore never leaves a truncated file at the final path, which
    a later "file exists and is non-empty" check would mistake for a complete one.

    Args:
        blob_client: Azure ``BlobClient`` for the blob to fetch.
        local_path: Destination file path; its directory must already exist.

    Returns:
        True on success. False if any step raised; the error is printed and
        the partial ``.part`` file is removed.
    """
    tmp_path = f"{local_path}.part"
    try:
        total_size = blob_client.get_blob_properties().size
        download_stream = blob_client.download_blob()

        # Context managers guarantee the file handle and the progress bar are
        # closed on both the success and the error path.
        with open(tmp_path, "wb") as file, tqdm(
            total=total_size,
            unit='B',
            unit_scale=True,
            desc=f"Downloading {os.path.basename(local_path)}",
            ncols=80,
        ) as progress:
            # StorageStreamDownloader.chunks() takes no arguments; chunk size
            # is configured on the client (max_chunk_get_size), not here.
            for chunk in download_stream.chunks():
                if chunk:
                    file.write(chunk)
                    progress.update(len(chunk))

        # Publish the finished file at its final name only after all bytes are written.
        os.replace(tmp_path, local_path)
        return True
    except Exception as e:
        print(f"\nError during download: {str(e)}")
        # Clean up partial download
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        return False
def download_from_azure(upload_to_hf=False):
    """Download model files from Azure Blob Storage, optionally uploading them to Hugging Face.

    Configuration is read from environment variables:
      - AZURE_STORAGE_PRIMARY_CONNECTION, falling back to
        AZURE_STORAGE_SECONDARY_CONNECTION: Azure connection string.
      - AZURE_STORAGE_CONTAINER_NAME: container name (default ``'pnid-models'``).
      - HF_TOKEN / HF_REPO_ID: credentials and target repo for the optional upload.
      - MODEL_*_PATH: local destination for each known blob. Blobs whose
        variable is unset are skipped.

    A file is downloaded only when it is missing or empty. A failed download
    is reported and skipped; it does not abort the remaining files.

    Args:
        upload_to_hf: When True and both HF_TOKEN and HF_REPO_ID are set, each
            successfully obtained file is passed to ``upload_to_hf_spaces``.

    Raises:
        ValueError: If neither Azure connection-string variable is set.
        Exception: Any other setup or client error is printed and re-raised.
    """
    try:
        # Get connection strings (primary preferred, secondary as fallback)
        connect_str = (
            os.getenv('AZURE_STORAGE_PRIMARY_CONNECTION')
            or os.getenv('AZURE_STORAGE_SECONDARY_CONNECTION')
        )
        if not connect_str:
            # Fail with a clear message instead of passing None to the Azure SDK.
            raise ValueError(
                "No Azure connection string found: set "
                "AZURE_STORAGE_PRIMARY_CONNECTION or AZURE_STORAGE_SECONDARY_CONNECTION"
            )
        container_name = os.getenv('AZURE_STORAGE_CONTAINER_NAME', 'pnid-models')
        hf_token = os.getenv('HF_TOKEN')
        hf_repo = os.getenv('HF_REPO_ID')

        print("Connecting to Azure Blob Storage...")
        blob_service_client = BlobServiceClient.from_connection_string(connect_str)
        container_client = blob_service_client.get_container_client(container_name)

        os.makedirs('models', exist_ok=True)

        # Blob name in the container -> local destination path (None = skip).
        model_files = {
            'Intui_SDM_41.pt': os.getenv('MODEL_SDM_41_PATH'),
            'Intui_SDM_30.pt': os.getenv('MODEL_SDM_30_PATH'),
            'Intui_SDM_20.pt': os.getenv('MODEL_SDM_20_PATH'),
            'deeplsd_md.tar': os.getenv('MODEL_DEEPLSD_PATH'),
            'craft_mlt_25k.pth': os.getenv('MODEL_CRAFT_PATH'),
            'english_g2.pth': os.getenv('MODEL_ENGLISH_PATH'),
            'intui_LDM_01.pt': os.getenv('MODEL_LDM_PATH'),
        }

        failed = []
        for blob_name, local_path in model_files.items():
            if not local_path:
                continue

            print(f"\nProcessing {blob_name}...")
            # dirname() is '' for a bare filename, and os.makedirs('') raises.
            parent_dir = os.path.dirname(local_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            # Check if file needs to be downloaded
            needs_download = (
                not os.path.exists(local_path)
                or os.path.getsize(local_path) == 0
            )

            if needs_download:
                blob_client = container_client.get_blob_client(blob_name)
                if not download_with_progress(blob_client, local_path):
                    print(f"Failed to download {blob_name}")
                    failed.append(blob_name)
                    continue
                print(f"Successfully downloaded {blob_name}")
            else:
                print(f"Skipping {blob_name}, already exists")

            # Upload to HF Spaces if requested
            if upload_to_hf and hf_token and hf_repo:
                upload_to_hf_spaces(local_path, hf_repo, hf_token)

        if failed:
            print(f"\nCompleted with {len(failed)} failed download(s): {', '.join(failed)}")
        else:
            print("\nAll operations completed!")

    except Exception as e:
        print(f"\nError in main process: {str(e)}")
        raise
if __name__ == "__main__":
    # Flip to True to also push every obtained model file to the Hugging Face repo
    # named by HF_REPO_ID (requires HF_TOKEN as well).
    upload_after_download = False
    download_from_azure(upload_to_hf=upload_after_download)