| import os, torch |
| from io import BytesIO |
| from base64 import b64encode |
| from urllib.request import urlretrieve |
| from urllib.parse import urlparse |
| from PIL import Image |
| from typing import Optional |
| from collections import OrderedDict |
|
|
def image_to_base64(image: "Image.Image | None") -> "str | None":
    """Encode a PIL image as a JPEG data-URI string.

    Args:
        image: the image to encode, or None.

    Returns:
        A ``data:image/jpeg;base64,...`` string, or None when *image* is None.
    """
    if image is None:  # identity check, not equality (PEP 8)
        return None
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    # The payload is JPEG, so the MIME type must say so (was wrongly image/png).
    return "data:image/jpeg;base64," + b64encode(buffered.getvalue()).decode()
|
|
def is_valid_url(url):
    """Return True when *url* parses with both a scheme and a network location."""
    try:
        parsed = urlparse(url)
    except ValueError:
        return False
    return bool(parsed.scheme and parsed.netloc)
| |
def download_with_progress(model_url, file_path):
    """Download *model_url* to *file_path*, rendering a console progress bar.

    Returns:
        True on success, False when any error occurred (best-effort: the
        error is printed, not re-raised).
    """
    try:
        urlretrieve(model_url, file_path, reporthook=download_progress)
    except Exception as e:
        print(f"Error downloading the model: {e}")
        return False
    return True
|
|
def download_progress(block_num, block_size, total_size):
    """Reporthook for ``urllib.request.urlretrieve``: print a progress bar.

    Args:
        block_num: number of blocks transferred so far.
        block_size: size of one transfer block, in bytes.
        total_size: total file size in bytes; urlretrieve passes -1 when the
            server does not report a Content-Length.
    """
    bar_length = 50
    if total_size > 0:
        progress = min(1.0, block_num * block_size / total_size)
    else:
        # Unknown/zero total size: avoid ZeroDivisionError and negative ratios.
        progress = 0.0
    block = int(round(bar_length * progress))
    progress_percent = progress * 100
    progress_bar = f"[{'=' * block}{' ' * (bar_length - block)}] {progress_percent:.2f}%\r"
    print(progress_bar, end='', flush=True)
|
|
def check_or_download_model(model_url, file_path):
    """Ensure the model file exists at *file_path*, downloading it if missing.

    Args:
        model_url: URL to fetch the model from (validated first).
        file_path: local destination path for the model file.
    """
    if not is_valid_url(model_url):
        print("Invalid model URL.")
        return

    if os.path.exists(file_path):
        print("Model already exists at:", file_path)
        return

    print("No model found, downloading model.")
    directory = os.path.dirname(file_path)
    if directory:  # makedirs("") raises when file_path has no directory part
        os.makedirs(directory, exist_ok=True)
    # Only report success when the download actually succeeded (the original
    # printed the success message unconditionally).
    if download_with_progress(model_url, file_path):
        print("\nModel downloaded successfully.")
|
|
def load_checkpoint(model, checkpoint_path):
    """Load a state dict from *checkpoint_path* into *model* and return the model.

    Handles checkpoints saved from ``torch.nn.DataParallel`` /
    ``DistributedDataParallel`` by stripping the ``"module."`` key prefix when
    it is present; plain checkpoints are loaded unchanged.

    Args:
        model: the module to load parameters into (mutated in place).
        checkpoint_path: path to a file readable by ``torch.load``.

    Returns:
        The same *model* instance, with the checkpoint weights loaded.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): the checkpoint may come from an untrusted download;
    # consider weights_only=True (torch >= 1.13) since torch.load unpickles.
    model_state_dict = torch.load(checkpoint_path, map_location=device)

    # Strip the prefix only when it is actually there — the original k[7:]
    # slice corrupted keys of checkpoints saved without DataParallel.
    new_state_dict = OrderedDict(
        (k.removeprefix("module."), v) for k, v in model_state_dict.items()
    )

    model.load_state_dict(new_state_dict)

    print("---- Checkpoint loaded from path: {} ----".format(checkpoint_path))

    return model
|
|