|
|
import os |
|
|
import sys |
|
|
import hashlib |
|
|
import requests |
|
|
from tqdm import tqdm |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
# Local folder holding the downloaded weights: a "models" directory located
# two directory levels above this script file.
MODELS_DIR = Path(__file__).parent.parent / "models"

# Prefix for raw-file downloads from the Hugging Face repository; a model's
# filename is appended directly to this string to form its URL.
BASE_URL_HF = "https://huggingface.co/thookham/DeOldify/resolve/main/"

# (filename, expected SHA-256 hex digest) pairs, in the order they are processed.
_MODEL_CHECKSUMS = (
    # PyTorch generator checkpoints.
    ("ColorizeArtistic_gen.pth", "3f750246fa220529323b85a8905f9b49c0e5d427099185334d048fb5b5e22477"),
    ("ColorizeStable_gen.pth", "ca9cd7f43fb8b222c9a70f7b292e305a000694b0ff9d2ae4a6747b1a2e1ee5af"),
    ("ColorizeVideo_gen.pth", "e3d98bb6a222354c79f95485c2f078a89dc724e9183662506d9e0f5aafac83ad"),
    # ONNX exports.
    ("deoldify-art.onnx", "be026e17c47c85527b3084cacad352f7ca0e021c33aa827062c5997ebe72c61f"),
    ("deoldify-quant.onnx", "35c3fb7ec52122e30e98ed03fa5082b175d0beb7951d62f8bc2178870229e44a"),
)

# Mapping of model filename -> expected SHA-256 hex digest.
MODELS = dict(_MODEL_CHECKSUMS)
|
|
|
|
|
def calculate_sha256(filepath, chunk_size=1024 * 1024):
    """Return the hex-encoded SHA-256 digest of a file's contents.

    The file is read in binary mode in fixed-size chunks, so memory use stays
    bounded regardless of file size.

    Args:
        filepath: Path (``pathlib.Path`` or string) of the file to hash.
        chunk_size: Number of bytes read per iteration. Defaults to 1 MiB,
            which keeps the number of Python-level loop iterations low for
            large model files.

    Returns:
        The 64-character lowercase hexadecimal SHA-256 digest.

    Raises:
        ValueError: If ``chunk_size`` is not a positive integer.
        OSError: If the file cannot be opened or read.
    """
    # f.read(0) returns b"" (the loop would stop immediately and report the
    # empty-file digest) and f.read(-1) reads the whole file at once, so both
    # are rejected explicitly.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size!r}")

    sha256_hash = hashlib.sha256()
    with open(filepath, "rb") as f:
        # iter(callable, sentinel) keeps reading until read() returns b"" (EOF).
        for byte_block in iter(lambda: f.read(chunk_size), b""):
            sha256_hash.update(byte_block)
    return sha256_hash.hexdigest()
|
|
|
|
|
def download_file(url, filepath, expected_hash=None, timeout=60, chunk_size=1024 * 1024):
    """Download ``url`` to ``filepath`` with a progress bar and optional SHA-256 check.

    The response body is streamed into a sibling ``<name>.part`` file, which is
    moved onto ``filepath`` only after the transfer completed and, if
    ``expected_hash`` is given, the digest matched. An interrupted, failed, or
    corrupt download therefore never leaves a partial file at ``filepath``;
    the ``.part`` file is removed on every exit path.

    Args:
        url: HTTP(S) URL to fetch.
        filepath: Destination path (``pathlib.Path`` or string).
        expected_hash: Optional hex SHA-256 digest the downloaded content must
            match (compared case-insensitively).
        timeout: Timeout in seconds passed to ``requests.get``. It applies to
            connecting and to each wait for data, not to the whole transfer.
        chunk_size: Number of bytes per streamed chunk.

    Returns:
        True if the file was downloaded (and verified when ``expected_hash``
        is given) and is now at ``filepath``. False on a hash mismatch; in that
        case ``filepath`` is left as it was before the call.

    Raises:
        requests.HTTPError: On a non-success HTTP status.
        requests.RequestException: On connection errors or timeouts.
        OSError: If the file cannot be written or moved into place.
    """
    filepath = Path(filepath)
    tmp_path = filepath.with_name(filepath.name + ".part")
    print(f"Downloading {filepath.name}...")

    try:
        # The context manager releases the streamed connection even on error.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()

            # Content-Length may be absent; tqdm treats total=None as
            # "unknown size" instead of a 0-byte target.
            total_size = int(response.headers.get('content-length', 0)) or None

            with open(tmp_path, 'wb') as f, tqdm(
                desc=filepath.name,
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in response.iter_content(chunk_size=chunk_size):
                    bar.update(f.write(data))

        if expected_hash:
            print("Verifying hash...")
            file_hash = calculate_sha256(tmp_path)
            if file_hash != expected_hash.lower():
                print(f"❌ Hash mismatch for {filepath.name}!")
                print(f"Expected: {expected_hash}")
                print(f"Got: {file_hash}")
                return False
            print("✅ Hash verified.")

        # Rename over the destination (replacing any stale copy) only once
        # the content is known to be complete and, if requested, verified.
        tmp_path.replace(filepath)
        return True
    finally:
        # After a successful replace() the .part file no longer exists. On any
        # other path (mismatch, HTTP/IO error, Ctrl-C) it is discarded here.
        # Cleanup is best-effort so it never masks the original exception.
        try:
            tmp_path.unlink()
        except OSError:
            pass
|
|
|
|
|
def _ensure_model(filename, expected_hash):
    """Make sure ``MODELS_DIR / filename`` exists with the given SHA-256 digest.

    An existing file with the right digest is kept. A missing or mismatched
    file is (re)downloaded from ``BASE_URL_HF``.

    Returns:
        True if the file is present and verified afterwards, False otherwise.
        Errors are reported rather than raised, so one failing model does not
        stop the remaining ones from being processed.
    """
    filepath = MODELS_DIR / filename
    try:
        if filepath.exists():
            print(f"\nChecking existing file: {filename}")
            if calculate_sha256(filepath) == expected_hash:
                print(f"✅ {filename} is up to date.")
                return True
            print(f"⚠️ {filename} exists but hash mismatch. Re-downloading...")
        else:
            print(f"\nMissing file: {filename}")

        if download_file(BASE_URL_HF + filename, filepath, expected_hash):
            return True
        print(f"❌ Failed to verify {filename}")
        return False
    # Broad catch is deliberate: this is the per-model boundary of a
    # best-effort batch, and the failure is recorded via the return value.
    except Exception as e:
        print(f"❌ Error downloading {filename}: {e}")
        return False


def main():
    """Check every model in ``MODELS`` and (re)download any missing or corrupt one.

    Returns:
        Process exit status: 0 if every model is present and verified,
        1 if at least one model could not be downloaded or verified.
    """
    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    print(f"Checking models in {MODELS_DIR}...")

    failed = []
    for filename, expected_hash in MODELS.items():
        if not _ensure_model(filename, expected_hash):
            failed.append(filename)

    # Report failures explicitly instead of printing "Done!" unconditionally,
    # so both users and calling scripts can tell the run did not succeed.
    if failed:
        print(f"\n❌ {len(failed)} model(s) failed: {', '.join(failed)}")
        return 1

    print("\nDone!")
    return 0
|
|
|
|
|
if __name__ == "__main__":
    # main()'s return value becomes the process exit status (None maps to 0),
    # so a failed model download can be detected by shells and CI.
    sys.exit(main())
|
|
|