File size: 4,491 Bytes
910e0d4
e2c1993
910e0d4
984fc5d
897d2f8
 
910e0d4
 
 
 
897d2f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
984fc5d
897d2f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2c1993
897d2f8
e2c1993
 
 
897d2f8
e2c1993
897d2f8
 
e2c1993
984fc5d
e2c1993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
897d2f8
 
e2c1993
897d2f8
 
 
 
 
 
 
 
e2c1993
897d2f8
 
 
 
 
 
984fc5d
e2c1993
 
 
897d2f8
 
 
 
 
984fc5d
e2c1993
897d2f8
e2c1993
910e0d4
 
897d2f8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
from azure.storage.blob import BlobServiceClient
from dotenv import load_dotenv
from tqdm import tqdm
import requests
from huggingface_hub import HfApi, upload_file

# Load variables from a local .env file into the process environment so the
# os.getenv() calls in this module can see them.
load_dotenv()

def upload_to_hf_spaces(local_path, repo_id, token, repo_type="space"):
    """Upload a single local file to a Hugging Face Hub repository.

    The file is stored at the repository root under its base name.

    Args:
        local_path: Path of the file to upload.
        repo_id: Target repository id, e.g. ``"user/name"``.
        token: Hugging Face access token used to authenticate the upload.
        repo_type: Hub repository type. Defaults to ``"space"`` because this
            helper targets Spaces; without an explicit type the Hub resolves
            ``repo_id`` as a *model* repository and a Space upload fails.

    Returns:
        True if the upload succeeded, False otherwise (the error is printed,
        not raised).
    """
    try:
        api = HfApi()
        file_name = os.path.basename(local_path)
        print(f"\nUploading {file_name} to {repo_id}...")
        response = api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=file_name,
            repo_id=repo_id,
            repo_type=repo_type,
            token=token,
        )
        print(f"Successfully uploaded to {response}")
        return True
    except Exception as e:
        # Best-effort: an upload failure must not abort the download loop.
        print(f"Error uploading to Hugging Face: {str(e)}")
        return False

def download_with_progress(blob_client, local_path):
    """Stream a blob to ``local_path`` while showing a tqdm progress bar.

    Args:
        blob_client: Azure blob client for the blob to fetch; must provide
            ``get_blob_properties()`` and ``download_blob()``.
        local_path: Destination file path; overwritten if it already exists.

    Returns:
        True on success. False on any error; in that case the error is
        printed and any partially written file at ``local_path`` is removed.
    """
    try:
        total_size = blob_client.get_blob_properties().size

        # Using tqdm as a context manager guarantees the bar is closed even
        # when the download fails part-way, so the error message printed
        # below is not drawn on top of a dangling bar.
        with tqdm(
            total=total_size,
            unit='B',
            unit_scale=True,
            desc=f"Downloading {os.path.basename(local_path)}",
            ncols=80,
        ) as progress, open(local_path, "wb") as file:
            written = 0
            # StorageStreamDownloader.chunks() accepts no arguments; the
            # chunk size is governed by the client's max_chunk_get_size
            # setting. Passing chunk_size= here raises TypeError.
            for chunk in blob_client.download_blob().chunks():
                if chunk:
                    file.write(chunk)
                    written += len(chunk)
                    progress.update(len(chunk))

        # A short read would leave a non-empty file that the caller's
        # "missing or zero bytes" check would treat as complete forever.
        if written != total_size:
            raise OSError(
                f"expected {total_size} bytes, received {written}"
            )
        return True
    except Exception as e:
        print(f"\nError during download: {str(e)}")
        # Clean up partial download
        if os.path.exists(local_path):
            os.remove(local_path)
        return False

def download_from_azure(upload_to_hf=False):
    """Download model files from Azure Blob Storage, optionally uploading them to HF.

    Configuration comes from environment variables:

    * ``AZURE_STORAGE_PRIMARY_CONNECTION`` / ``AZURE_STORAGE_SECONDARY_CONNECTION``:
      Azure connection string. The secondary is used only when the primary is
      unset or empty.
    * ``AZURE_STORAGE_CONTAINER_NAME``: container to read from (default
      ``pnid-models``).
    * ``MODEL_*_PATH``: local destination for each known blob. Blobs whose
      variable is unset or empty are skipped.
    * ``HF_TOKEN`` / ``HF_REPO_ID``: credentials and target for the optional
      upload.

    A blob is downloaded only when its local file is missing or zero bytes
    long. Individual download failures are reported and do not stop the loop.

    Args:
        upload_to_hf: If True and both ``HF_TOKEN`` and ``HF_REPO_ID`` are set,
            pass each processed file to ``upload_to_hf_spaces``.

    Raises:
        ValueError: If neither Azure connection string is configured.
        Exception: Any other error raised outside the per-file download
            helper is printed and re-raised.
    """
    try:
        # Fall back to the secondary connection string when the primary is
        # unset or empty.
        connect_str = (
            os.getenv('AZURE_STORAGE_PRIMARY_CONNECTION')
            or os.getenv('AZURE_STORAGE_SECONDARY_CONNECTION')
        )
        if not connect_str:
            raise ValueError(
                "No Azure connection string configured: set "
                "AZURE_STORAGE_PRIMARY_CONNECTION or "
                "AZURE_STORAGE_SECONDARY_CONNECTION"
            )

        container_name = os.getenv('AZURE_STORAGE_CONTAINER_NAME', 'pnid-models')
        hf_token = os.getenv('HF_TOKEN')
        hf_repo = os.getenv('HF_REPO_ID')
        if upload_to_hf and not (hf_token and hf_repo):
            print("HF_TOKEN and HF_REPO_ID must both be set to upload; "
                  "skipping Hugging Face upload")

        print("Connecting to Azure Blob Storage...")
        blob_service_client = BlobServiceClient.from_connection_string(connect_str)
        container_client = blob_service_client.get_container_client(container_name)

        os.makedirs('models', exist_ok=True)

        # Blob name in the container -> local destination path (from env).
        model_files = {
            'Intui_SDM_41.pt': os.getenv('MODEL_SDM_41_PATH'),
            'Intui_SDM_30.pt': os.getenv('MODEL_SDM_30_PATH'),
            'Intui_SDM_20.pt': os.getenv('MODEL_SDM_20_PATH'),
            'deeplsd_md.tar': os.getenv('MODEL_DEEPLSD_PATH'),
            'craft_mlt_25k.pth': os.getenv('MODEL_CRAFT_PATH'),
            'english_g2.pth': os.getenv('MODEL_ENGLISH_PATH'),
            'intui_LDM_01.pt': os.getenv('MODEL_LDM_PATH'),
        }

        failed = []
        for blob_name, local_path in model_files.items():
            if not local_path:
                continue

            print(f"\nProcessing {blob_name}...")
            # dirname() is '' for a bare filename, and os.makedirs('') raises.
            local_dir = os.path.dirname(local_path)
            if local_dir:
                os.makedirs(local_dir, exist_ok=True)

            needs_download = (
                not os.path.exists(local_path)
                or os.path.getsize(local_path) == 0
            )

            if needs_download:
                blob_client = container_client.get_blob_client(blob_name)
                if not download_with_progress(blob_client, local_path):
                    print(f"Failed to download {blob_name}")
                    failed.append(blob_name)
                    continue
                print(f"Successfully downloaded {blob_name}")
            else:
                print(f"Skipping {blob_name}, already exists")

            # Upload to HF Spaces if requested
            if upload_to_hf and hf_token and hf_repo:
                upload_to_hf_spaces(local_path, hf_repo, hf_token)

        if failed:
            print(f"\nCompleted with {len(failed)} failed download(s): "
                  f"{', '.join(failed)}")
        else:
            print("\nAll operations completed!")

    except Exception as e:
        print(f"\nError in main process: {str(e)}")
        raise

if __name__ == "__main__":
    # Flip to True to also push each model file to Hugging Face Spaces
    # (only happens when HF_TOKEN and HF_REPO_ID are also set).
    upload_after_download = False
    download_from_azure(upload_to_hf=upload_after_download)