# vit-matte / app / utils.py
# pillipop
# finish
# 3c5f6c6 unverified
import os, torch
from io import BytesIO
from base64 import b64encode
from urllib.request import urlretrieve
from urllib.parse import urlparse
from PIL import Image
from typing import Optional
from collections import OrderedDict
def image_to_base64(image: Optional["Image.Image"]) -> Optional[str]:
    """Encode an image as a base64 JPEG ``data:`` URI.

    Args:
        image: The PIL image to encode, or ``None``.

    Returns:
        A ``data:image/jpeg;base64,...`` string, or ``None`` when *image*
        is ``None``.
    """
    if image is None:
        return None
    # The JPEG encoder cannot write modes such as RGBA or P (it raises
    # OSError), so anything outside the modes it accepts is flattened to RGB.
    if image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    # The MIME type must match the actual encoding (JPEG, not PNG).
    return "data:image/jpeg;base64," + b64encode(buffered.getvalue()).decode()
def is_valid_url(url):
    """Report whether *url* parses with both a scheme and a network location.

    Returns ``False`` when ``urlparse`` rejects the input with ``ValueError``.
    """
    try:
        parts = urlparse(url)
    except ValueError:
        return False
    return bool(parts.scheme and parts.netloc)
def download_with_progress(model_url, file_path):
    """Download *model_url* to *file_path*, printing progress as it goes.

    The data is first written to ``<file_path>.part`` and only moved to
    *file_path* once the transfer has finished, so an interrupted download
    never leaves a truncated file at the final location.

    Args:
        model_url: URL to fetch.
        file_path: Destination path for the downloaded file.

    Returns:
        ``True`` on success, ``False`` if the download or the final move
        failed (the error is printed).
    """
    partial_path = f"{file_path}.part"
    try:
        urlretrieve(model_url, partial_path, reporthook=download_progress)
        os.replace(partial_path, file_path)
    except Exception as e:
        print(f"Error downloading the model: {e}")
        # Best-effort cleanup: the partial file may not exist if the
        # request failed before anything was written.
        try:
            os.remove(partial_path)
        except OSError:
            pass
        return False
    return True
def download_progress(block_num, block_size, total_size):
    """Print a single-line text progress indicator for a download.

    The signature matches the ``reporthook`` callback of ``urlretrieve``.

    Args:
        block_num: Number of blocks transferred so far.
        block_size: Size of one block in bytes.
        total_size: Total size in bytes; zero or negative means the size is
            unknown, in which case only the byte count is shown.
    """
    downloaded = block_num * block_size
    if total_size <= 0:
        # urlretrieve reports -1 when the server sends no size; a ratio
        # would be negative (or divide by zero), so show raw bytes instead.
        print(f"Downloaded {downloaded} bytes\r", end='', flush=True)
        return
    # The last block usually overshoots the total, hence the clamp to 100%.
    progress = min(1.0, downloaded / total_size)
    bar_length = 50
    block = int(round(bar_length * progress))
    progress_percent = progress * 100
    progress_bar = f"[{'=' * block}{' ' * (bar_length - block)}] {progress_percent:.2f}%\r"
    print(progress_bar, end='', flush=True)
def check_or_download_model(model_url, file_path):
    """Make sure a model file exists at *file_path*, downloading it if needed.

    Nothing is downloaded when *model_url* is not a valid URL or when a file
    already exists at *file_path*. Status messages are printed; the function
    returns ``None`` in every case.

    Args:
        model_url: URL the model is fetched from when it is missing.
        file_path: Local path where the model is expected or stored.
    """
    if not is_valid_url(model_url):
        print("Invalid model URL.")
        return
    if os.path.exists(file_path):
        print("Model already exists at:", file_path)
        return

    print("No model found, downloading model.")
    # dirname is "" for a bare filename, and os.makedirs("") raises.
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    # Only report success when the download actually succeeded.
    if download_with_progress(model_url, file_path):
        print("\nModel downloaded successfully.")
    else:
        print("\nModel download failed.")
def load_checkpoint(model, checkpoint_path):
    """Load a saved state dict from *checkpoint_path* into *model*.

    A leading ``module.`` on a parameter name is removed before loading;
    keys without that prefix are used unchanged.

    Args:
        model: Module whose ``load_state_dict`` receives the weights.
        checkpoint_path: Path to the checkpoint file read by ``torch.load``.

    Returns:
        The same *model* object, with the checkpoint weights loaded.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (or pass weights_only=True where the torch version
    # supports it).
    model_state_dict = torch.load(checkpoint_path, map_location=device)

    # Strip the prefix only where it is present, so checkpoints saved without
    # it are not corrupted by blindly dropping seven characters.
    prefix = "module."
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in model_state_dict.items()
    )

    model.load_state_dict(new_state_dict)
    print(f"---- Checkpoint loaded from path: {checkpoint_path} ----")
    return model