File size: 3,096 Bytes
e9f9fd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import sys
import hashlib
import requests
from tqdm import tqdm
from pathlib import Path

# --- Configuration ---------------------------------------------------------
# Local target directory: the "models" folder located in the parent of the
# directory that contains this file.
MODELS_DIR = Path(__file__).parent.parent.joinpath("models")
# Every model filename is appended to this base URL to build its download URL.
BASE_URL_HF = "https://huggingface.co/thookham/DeOldify/resolve/main/"

# Model filename -> expected SHA256 hex digest (values taken from models.json).
MODELS = {
    "ColorizeArtistic_gen.pth": (
        "3f750246fa220529323b85a8905f9b49c0e5d427099185334d048fb5b5e22477"
    ),
    "ColorizeStable_gen.pth": (
        "ca9cd7f43fb8b222c9a70f7b292e305a000694b0ff9d2ae4a6747b1a2e1ee5af"
    ),
    "ColorizeVideo_gen.pth": (
        "e3d98bb6a222354c79f95485c2f078a89dc724e9183662506d9e0f5aafac83ad"
    ),
    "deoldify-art.onnx": (
        "be026e17c47c85527b3084cacad352f7ca0e021c33aa827062c5997ebe72c61f"
    ),
    "deoldify-quant.onnx": (
        "35c3fb7ec52122e30e98ed03fa5082b175d0beb7951d62f8bc2178870229e44a"
    ),
}

def calculate_sha256(filepath, chunk_size=1024 * 1024):
    """Return the SHA256 hex digest of a file's contents.

    The file is read in fixed-size chunks so memory use stays bounded no
    matter how large the file is.

    Args:
        filepath: Path of the file to hash (anything ``open`` accepts).
        chunk_size: Number of bytes to read per iteration. Defaults to
            1 MiB; the digest does not depend on this value.

    Returns:
        The digest as a lowercase hexadecimal string.

    Raises:
        ValueError: If ``chunk_size`` is not a positive integer.
        OSError: If the file cannot be opened or read.
    """
    # read(0) returns b"" immediately, which would end the loop at once and
    # silently yield the digest of empty input -- reject it up front.
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    digest = hashlib.sha256()
    with open(filepath, "rb") as f:
        # iter() with a sentinel keeps reading until read() returns b"" (EOF).
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def download_file(url, filepath, expected_hash=None, timeout=(10, 60)):
    """Download a file with a progress bar and optional SHA256 verification.

    The body is streamed into a sibling ``<name>.part`` file and only moved
    to ``filepath`` after the download finished (and, when a hash was given,
    verified). An interrupted or corrupt download therefore never replaces
    whatever is currently stored at ``filepath``.

    Args:
        url: URL to fetch.
        filepath: Destination as a ``pathlib.Path``.
        expected_hash: Expected SHA256 hex digest, or ``None``/empty to skip
            verification.
        timeout: Passed to ``requests.get``; a ``(connect, read)`` tuple in
            seconds. Without it a stalled connection would block forever.

    Returns:
        ``True`` if the file was downloaded (and verified when a hash was
        given), ``False`` on a hash mismatch.

    Raises:
        requests.RequestException: On connection errors, timeouts, or a
            non-success HTTP status (via ``raise_for_status``).
        OSError: If the file cannot be written or moved into place.
    """
    print(f"Downloading {filepath.name}...")

    partial_path = filepath.with_name(filepath.name + ".part")
    # Hash the bytes as they arrive so the file is not re-read afterwards.
    digest = hashlib.sha256()

    try:
        # The context manager releases the connection even if streaming fails.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))

            with open(partial_path, 'wb') as f, tqdm(
                desc=filepath.name,
                # No Content-Length header: let tqdm show an open-ended
                # counter instead of a bar with a total of zero.
                total=total_size or None,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in response.iter_content(chunk_size=1024 * 1024):
                    if not data:
                        continue
                    f.write(data)
                    digest.update(data)
                    bar.update(len(data))

        if expected_hash:
            print("Verifying hash...")
            file_hash = digest.hexdigest()
            # hexdigest() is lowercase; tolerate an uppercase expected value.
            if file_hash != expected_hash.lower():
                print(f"❌ Hash mismatch for {filepath.name}!")
                print(f"Expected: {expected_hash}")
                print(f"Got:      {file_hash}")
                return False
            print("✅ Hash verified.")

        # Move the finished download into place, overwriting any old file.
        partial_path.replace(filepath)
        return True
    finally:
        # Still present only when the download failed or the hash mismatched.
        if partial_path.exists():
            partial_path.unlink()

def main():
    """Ensure every model listed in ``MODELS`` exists in ``MODELS_DIR``.

    For each entry, an existing file whose SHA256 matches is kept; a missing
    or mismatching file is downloaded from ``BASE_URL_HF + filename``. A
    failure for one model is reported and does not stop the remaining ones.

    Returns:
        ``0`` if every model is present and verified, ``1`` if at least one
        download failed or could not be verified.
    """
    # parents=True so a missing intermediate directory is not an error.
    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    print(f"Checking models in {MODELS_DIR}...")

    failed = []

    for filename, expected_hash in MODELS.items():
        filepath = MODELS_DIR / filename

        if filepath.exists():
            print(f"\nChecking existing file: {filename}")
            current_hash = calculate_sha256(filepath)
            if current_hash == expected_hash:
                print(f"✅ {filename} is up to date.")
                continue
            print(f"⚠️ {filename} exists but hash mismatch. Re-downloading...")
        else:
            print(f"\nMissing file: {filename}")

        url = BASE_URL_HF + filename
        try:
            success = download_file(url, filepath, expected_hash)
        except Exception as e:
            # Deliberately broad: this is the per-model boundary, and one
            # failing download must not abort the remaining models.
            print(f"❌ Error downloading {filename}: {e}")
            failed.append(filename)
        else:
            if not success:
                print(f"❌ Failed to verify {filename}")
                failed.append(filename)

    if failed:
        print(f"\n❌ {len(failed)} model(s) failed: {', '.join(failed)}")

    print("\nDone!")
    return 1 if failed else 0


if __name__ == "__main__":
    # Propagate the result as the process exit status so callers (shell
    # scripts, CI) can detect failed downloads.
    sys.exit(main())