DeOldify / scripts /download_models.py
thookham's picture
Initial commit for Hugging Face sync (Clean History)
e9f9fd3
import os
import sys
import hashlib
import requests
from tqdm import tqdm
from pathlib import Path
# --- Configuration ---

# Directory holding the model files: "models", two directory levels above
# this script file.
MODELS_DIR = Path(__file__).parent.parent.joinpath("models")

# URL prefix; a model's filename is appended to it to form the download URL.
BASE_URL_HF = "https://huggingface.co/thookham/DeOldify/resolve/main/"

# Model filename -> expected SHA256 hex digest (values taken from models.json).
MODELS = {
    "ColorizeArtistic_gen.pth": "3f750246fa220529323b85a8905f9b49c0e5d427099185334d048fb5b5e22477",
    "ColorizeStable_gen.pth": "ca9cd7f43fb8b222c9a70f7b292e305a000694b0ff9d2ae4a6747b1a2e1ee5af",
    "ColorizeVideo_gen.pth": "e3d98bb6a222354c79f95485c2f078a89dc724e9183662506d9e0f5aafac83ad",
    "deoldify-art.onnx": "be026e17c47c85527b3084cacad352f7ca0e021c33aa827062c5997ebe72c61f",
    "deoldify-quant.onnx": "35c3fb7ec52122e30e98ed03fa5082b175d0beb7951d62f8bc2178870229e44a",
}
def calculate_sha256(filepath, chunk_size=1024 * 1024):
    """Return the SHA256 hex digest of the file at ``filepath``.

    The file is read in fixed-size chunks, so files of any size can be hashed
    without loading them fully into memory.

    Args:
        filepath: Path (``str`` or path-like) of the file to hash.
        chunk_size: Number of bytes read per iteration. Defaults to 1 MiB,
            which keeps the number of read calls low for large files. Must be
            a positive integer.

    Returns:
        The hexadecimal SHA256 digest as a ``str``.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
        OSError: If the file cannot be opened or read.
    """
    # read(0) returns b"" immediately, which would end the loop at once and
    # silently produce the digest of empty input; reject it up front.
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    digest = hashlib.sha256()
    with open(filepath, "rb") as f:
        # iter() with a sentinel keeps calling read() until it returns b"" (EOF).
        for block in iter(lambda: f.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()
def download_file(url, filepath, expected_hash=None, timeout=(10, 60)):
    """Download ``url`` to ``filepath`` with a progress bar and hash verification.

    The payload is streamed into a sibling ``<name>.part`` file and is moved
    onto ``filepath`` only after the download has finished and, when a hash is
    given, the SHA256 digest has matched. An interrupted or corrupt download
    therefore never ends up at ``filepath``; any file already there is left
    untouched on failure.

    Args:
        url: Address fetched with an HTTP GET request.
        filepath: Destination as a ``pathlib.Path``.
        expected_hash: Expected SHA256 hex digest (compared case-insensitively).
            If ``None`` or empty, verification is skipped.
        timeout: Forwarded to ``requests.get`` as the (connect, read) timeout
            in seconds, so a stalled connection cannot hang the script forever.

    Returns:
        ``True`` if the file was downloaded (and verified, when a hash was
        given); ``False`` if the digest did not match ``expected_hash``.

    Raises:
        requests.RequestException: On connection errors, timeouts, or an HTTP
            error status (raised by ``raise_for_status``).
        OSError: If the file cannot be written or moved into place.
    """
    print(f"Downloading {filepath.name}...")
    part_path = filepath.with_name(filepath.name + ".part")

    try:
        # The context manager releases the connection even if streaming fails.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()

            # Without a Content-Length header the size is unknown; tqdm takes
            # total=None for that case rather than a misleading total of 0.
            content_length = response.headers.get('content-length')
            total_size = int(content_length) if content_length else None

            with open(part_path, 'wb') as f, tqdm(
                desc=filepath.name,
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                # 1 MiB chunks: far fewer Python-level iterations than 1 KiB
                # when the files are large.
                for data in response.iter_content(chunk_size=1024 * 1024):
                    bar.update(f.write(data))

        if expected_hash:
            print("Verifying hash...")
            file_hash = calculate_sha256(part_path)
            if file_hash != expected_hash.lower():
                print(f"❌ Hash mismatch for {filepath.name}!")
                print(f"Expected: {expected_hash}")
                print(f"Got: {file_hash}")
                return False
            print("✅ Hash verified.")

        # Move the completed download onto the final path in a single step.
        part_path.replace(filepath)
        return True
    finally:
        # After a successful replace() the .part file no longer exists; on any
        # failure path this removes the incomplete or mismatching download.
        if part_path.exists():
            part_path.unlink()
def main():
    """Ensure every model in ``MODELS`` is present in ``MODELS_DIR`` with a matching hash.

    For each entry, an existing file whose SHA256 digest matches is kept.
    A missing or mismatching file is downloaded from ``BASE_URL_HF`` plus the
    filename. A failure for one model does not stop processing of the others.

    Returns:
        ``0`` if every model is present and verified afterwards, ``1`` if at
        least one download failed or could not be verified.
    """
    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    print(f"Checking models in {MODELS_DIR}...")

    failed = []
    for filename, expected_hash in MODELS.items():
        filepath = MODELS_DIR / filename

        if filepath.exists():
            print(f"\nChecking existing file: {filename}")
            if calculate_sha256(filepath) == expected_hash:
                print(f"✅ {filename} is up to date.")
                continue
            print(f"⚠️ {filename} exists but hash mismatch. Re-downloading...")
        else:
            print(f"\nMissing file: {filename}")

        url = BASE_URL_HF + filename
        try:
            success = download_file(url, filepath, expected_hash)
        except Exception as e:
            # Deliberately broad: one failing model must not prevent the
            # remaining models from being attempted. The failure is reported
            # and recorded so it is reflected in the exit status.
            print(f"❌ Error downloading {filename}: {e}")
            failed.append(filename)
        else:
            if not success:
                print(f"❌ Failed to verify {filename}")
                failed.append(filename)

    if failed:
        print(f"\n❌ {len(failed)} model(s) could not be downloaded or verified: {', '.join(failed)}")
        return 1

    print("\nDone!")
    return 0


if __name__ == "__main__":
    # Propagate the result as the process exit status so shells and CI can
    # detect a failed download.
    sys.exit(main())