import os
from base64 import b64encode
from collections import OrderedDict
from io import BytesIO
from typing import Optional
from urllib.parse import urlparse
from urllib.request import urlretrieve

import torch
from PIL import Image


def image_to_base64(image: Optional[Image.Image]) -> Optional[str]:
    """Encode a PIL image as a base64 JPEG data URI.

    Args:
        image: Image to encode, or ``None``.

    Returns:
        A ``data:image/jpeg;base64,...`` string, or ``None`` when *image*
        is ``None``.
    """
    if image is None:
        return None
    # The JPEG encoder cannot write alpha or palette modes (e.g. RGBA, P);
    # convert those to RGB (alpha is discarded) instead of letting save() fail.
    if image.mode not in ("L", "RGB"):
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    # The payload is JPEG, so the MIME type in the data URI must say so.
    return "data:image/jpeg;base64," + b64encode(buffered.getvalue()).decode()


def is_valid_url(url) -> bool:
    """Return True if *url* parses with both a scheme and a network location."""
    try:
        result = urlparse(url)
    except ValueError:
        # e.g. a malformed bracketed IPv6 host such as "http://[::1"
        return False
    return bool(result.scheme and result.netloc)


def download_with_progress(model_url, file_path):
    """Download *model_url* to *file_path* while printing a progress bar.

    Data is first written to ``<file_path>.part`` and moved onto *file_path*
    only after the transfer completes. A failed or interrupted download
    therefore never leaves a truncated file at *file_path*, which a later
    existence check would mistake for a complete model.

    Returns:
        True on success; False if any error occurred (the error is printed).
    """
    tmp_path = os.fspath(file_path) + ".part"
    try:
        urlretrieve(model_url, tmp_path, reporthook=download_progress)
        os.replace(tmp_path, file_path)
    except Exception as e:  # best-effort boundary: report and signal failure
        print(f"Error downloading the model: {e}")
        return False
    finally:
        # After a successful os.replace there is nothing left to remove;
        # otherwise discard the partial download.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
    return True


def download_progress(block_num, block_size, total_size):
    """``urlretrieve`` report hook that redraws a single-line progress bar.

    Args:
        block_num: Number of blocks transferred so far.
        block_size: Size of each block in bytes.
        total_size: Total size in bytes. ``urlretrieve`` passes -1 when the
            server does not report a size.
    """
    downloaded = block_num * block_size
    if total_size <= 0:
        # Size unknown: a percentage is meaningless (and 0 would divide by zero).
        print(f"Downloaded {downloaded} bytes\r", end="", flush=True)
        return
    progress = min(1.0, downloaded / total_size)
    bar_length = 50
    filled = int(round(bar_length * progress))
    bar = "=" * filled + " " * (bar_length - filled)
    print(f"[{bar}] {progress * 100:.2f}%\r", end="", flush=True)


def check_or_download_model(model_url, file_path):
    """Ensure a model file exists at *file_path*, downloading it if missing.

    Prints a message and does nothing when *model_url* is not a valid URL
    (see ``is_valid_url``) or when *file_path* already exists.
    """
    if not is_valid_url(model_url):
        print("Invalid model URL.")
        return
    if os.path.exists(file_path):
        print("Model already exists at:", file_path)
        return
    print("No model found, downloading model.")
    directory = os.path.dirname(file_path)
    # dirname is "" for a bare filename, and os.makedirs("") raises.
    if directory:
        os.makedirs(directory, exist_ok=True)
    if download_with_progress(model_url, file_path):
        print("\nModel downloaded successfully.")
    else:
        print("\nModel download failed.")


def load_checkpoint(model, checkpoint_path):
    """Load a saved state dict from *checkpoint_path* into *model*.

    Tensors are mapped to CUDA when it is available, otherwise to CPU. A
    leading ``"module."`` (typically added by DataParallel-style wrappers)
    is stripped from each key. Keys without that prefix are kept unchanged.

    Returns:
        The same *model* instance.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    model_state_dict = torch.load(checkpoint_path, map_location=device)
    # Strip only a real "module." prefix; blindly dropping 7 characters
    # mangles keys saved from an unwrapped model (e.g. "weight" -> "").
    new_state_dict = OrderedDict(
        (k.removeprefix("module."), v) for k, v in model_state_dict.items()
    )
    model.load_state_dict(new_state_dict)
    print("---- Checkpoint loaded from path: {} ----".format(checkpoint_path))
    return model