# src/utils.py — helper utilities for the image/video colorization Streamlit app.
import numpy as np
import requests
import streamlit as st
from PIL import Image
from PIL import Image
import numpy as np
import requests
import streamlit as st
import torch # Ensure torch is imported for type hints if not already
from typing import Optional, Tuple, Union # For type hints
from io import BytesIO # For type hints
# Assuming these are the actual model types, otherwise use torch.nn.Module
from .models.deep_colorization.colorizers import eccv16, siggraph17, BaseColor as ColorizationModule
from .models.deep_colorization.colorizers import postprocess_tens, preprocess_img, load_img
@st.cache_data()
def load_lottieurl(url: str) -> Optional[dict]:
    """
    Loads a Lottie animation from a URL.

    Lottie files are JSON-based animation files that enable designers to ship
    animations on any platform as easily as shipping static assets.

    Args:
        url: The URL of the Lottie JSON file.

    Returns:
        A dictionary representing the Lottie animation data if successful,
        None otherwise (non-200 status, network failure, or invalid JSON).
    """
    try:
        # A timeout keeps the Streamlit app from hanging indefinitely if the
        # Lottie host is slow or unreachable.
        r = requests.get(url, timeout=10)
    except requests.RequestException:
        # Treat any network-level failure the same as a bad status: no animation.
        return None
    if r.status_code != 200:
        return None
    try:
        return r.json()
    except ValueError:
        # Body was not valid JSON (requests raises a ValueError subclass).
        return None
@st.cache_resource()
def change_model(model_name: str) -> ColorizationModule:
    """
    Loads a specified pre-trained colorization model.

    Args:
        model_name: The name of the model to load ("ECCV16" or "SIGGRAPH17").

    Returns:
        The loaded PyTorch model (evaluated and pre-trained).

    Raises:
        ValueError: If the model_name is not recognized.
    """
    # Dispatch table keeps the set of recognized model names in one place.
    factories = {
        "ECCV16": eccv16,
        "SIGGRAPH17": siggraph17,
    }
    factory = factories.get(model_name)
    if factory is None:
        raise ValueError(f"Unknown model name: {model_name}. Choose 'ECCV16' or 'SIGGRAPH17'.")
    # Build with pre-trained weights and switch to inference mode.
    return factory(pretrained=True).eval()
def format_time(seconds: float) -> str:
    """
    Formats time in seconds to a human-readable string.

    The output will be in the format of "X days, Y hours, Z minutes, and S seconds",
    omitting larger units if they are zero.

    Args:
        seconds: The total number of seconds. Fractional seconds are truncated.

    Returns:
        A string representing the formatted time.

    Raises:
        TypeError: If seconds is not a number.
        ValueError: If seconds is negative.
    """
    # NOTE(review): bool is a subclass of int, so format_time(True) is accepted
    # and formats as "1 second" — confirm whether callers ever pass booleans.
    if not isinstance(seconds, (int, float)):
        raise TypeError("Input 'seconds' must be a number.")
    if seconds < 0:
        raise ValueError("Input 'seconds' cannot be negative.")
    if seconds == 0:
        return "0 seconds"

    # Truncate fractional seconds once, then peel off each unit with divmod
    # instead of repeated %-and-// arithmetic.
    total = int(seconds)
    days, remainder = divmod(total, 86400)
    hours, remainder = divmod(remainder, 3600)
    minutes, secs = divmod(remainder, 60)

    parts = []
    if days > 0:
        parts.append(f"{days} day{'s' if days != 1 else ''}")
    if hours > 0:
        parts.append(f"{hours} hour{'s' if hours != 1 else ''}")
    if minutes > 0:
        parts.append(f"{minutes} minute{'s' if minutes != 1 else ''}")
    # Always show seconds when non-zero, or when it is the only unit
    # (e.g. a sub-second positive input formats as "0 seconds").
    if secs > 0 or not parts:
        parts.append(f"{secs} second{'s' if secs != 1 else ''}")

    # parts is guaranteed non-empty here: the seconds clause above appends
    # whenever every other unit was zero.
    if len(parts) == 1:
        return parts[0]
    return ", ".join(parts[:-1]) + " and " + parts[-1]
def colorize_frame(frame: np.ndarray, colorizer: ColorizationModule) -> np.ndarray:
    """
    Colorizes a single video frame.

    Args:
        frame: The input video frame as a NumPy array (BGR format expected by OpenCV).
        colorizer: The pre-loaded colorization model.

    Returns:
        The colorized frame as a NumPy array (RGB format).
    """
    # OpenCV delivers BGR; the preprocessing step wants RGB, so reverse the
    # channel axis before feeding the model.
    rgb_frame = frame[:, :, ::-1]
    orig_tensor, resized_tensor = preprocess_img(rgb_frame, HW=(256, 256))
    # Run inference and move the result to the CPU; postprocess_tens handles
    # unnormalization and returns the full-resolution RGB image.
    prediction = colorizer(resized_tensor).cpu()
    return postprocess_tens(orig_tensor, prediction)
def colorize_image(file: Union[str, BytesIO, np.ndarray], loaded_model: ColorizationModule) -> Tuple[np.ndarray, Image.Image]:
    """
    Colorizes an image.

    Args:
        file: The image file, can be a path (str), a file-like object (BytesIO),
              or an already loaded image as a NumPy array (RGB).
        loaded_model: The pre-loaded colorization model.

    Returns:
        A tuple containing:
            - out_img (np.ndarray): The colorized image as a NumPy array (RGB format),
              suitable for display with st.image.
            - new_img (PIL.Image.Image): The colorized image as a PIL Image object.
    """
    # load_img accepts a path or file-like object and yields an RGB array.
    image = load_img(file)
    # Drop the alpha channel from RGBA inputs; the colorizer works on 3 channels.
    if image.ndim == 3 and image.shape[2] == 4:
        image = image[:, :, :3]
    orig_tensor, resized_tensor = preprocess_img(image, HW=(256, 256))
    # Run inference on the resized tensor and bring the output to the CPU;
    # postprocess_tens unnormalizes and returns a float RGB array in [0, 1].
    prediction = loaded_model(resized_tensor).cpu()
    colorized = postprocess_tens(orig_tensor, prediction)
    # Scale to uint8 [0, 255] so PIL can construct an Image from the array.
    pil_image = Image.fromarray((colorized * 255).astype(np.uint8))
    return colorized, pil_image