from io import BytesIO
from typing import Optional, Tuple, Union

import numpy as np
import requests
import streamlit as st
from PIL import Image

# BaseColor is assumed to be the common base class of the colorizers;
# otherwise, fall back to torch.nn.Module for the type hints below.
from .models.deep_colorization.colorizers import (
    BaseColor as ColorizationModule,
    eccv16,
    load_img,
    postprocess_tens,
    preprocess_img,
    siggraph17,
)


@st.cache_data()
def load_lottieurl(url: str) -> Optional[dict]:
    """Load a Lottie animation from a URL.

    Lottie files are JSON-based animation files that let designers ship
    animations on any platform as easily as static assets.

    Args:
        url: The URL of the Lottie JSON file.

    Returns:
        A dictionary with the Lottie animation data if the request
        succeeds, None otherwise.
    """
    # A timeout prevents the app from hanging on an unresponsive host.
    r = requests.get(url, timeout=10)
    if r.status_code != 200:
        return None
    return r.json()


@st.cache_resource()
def change_model(model_name: str) -> ColorizationModule:
    """Load a pre-trained colorization model.

    Args:
        model_name: The name of the model to load ("ECCV16" or "SIGGRAPH17").

    Returns:
        The loaded PyTorch model, pre-trained and set to eval mode.

    Raises:
        ValueError: If model_name is not recognized.
    """
    if model_name == "ECCV16":
        loaded_model = eccv16(pretrained=True).eval()
    elif model_name == "SIGGRAPH17":
        loaded_model = siggraph17(pretrained=True).eval()
    else:
        raise ValueError(
            f"Unknown model name: {model_name}. Choose 'ECCV16' or 'SIGGRAPH17'."
        )
    return loaded_model


def format_time(seconds: float) -> str:
    """Format a duration in seconds as a human-readable string.

    The output has the form "X days, Y hours, Z minutes and S seconds",
    omitting any larger units that are zero.

    Args:
        seconds: The total number of seconds.

    Returns:
        A string representing the formatted time.
""" if not isinstance(seconds, (int, float)): raise TypeError("Input 'seconds' must be a number.") if seconds < 0: raise ValueError("Input 'seconds' cannot be negative.") if seconds == 0: return "0 seconds" days = int(seconds // 86400) hours = int((seconds % 86400) // 3600) minutes = int((seconds % 3600) // 60) secs = int(seconds % 60) parts = [] if days > 0: parts.append(f"{days} day{'s' if days != 1 else ''}") if hours > 0: parts.append(f"{hours} hour{'s' if hours != 1 else ''}") if minutes > 0: parts.append(f"{minutes} minute{'s' if minutes != 1 else ''}") if secs > 0 or not parts: # Always show seconds if it's the only unit or non-zero parts.append(f"{secs} second{'s' if secs != 1 else ''}") if not parts: # Should not happen if seconds >= 0 return "0 seconds" if len(parts) == 1: return parts[0] return ", ".join(parts[:-1]) + " and " + parts[-1] def colorize_frame(frame: np.ndarray, colorizer: ColorizationModule) -> np.ndarray: """ Colorizes a single video frame. Args: frame: The input video frame as a NumPy array (BGR format expected by OpenCV). colorizer: The pre-loaded colorization model. Returns: The colorized frame as a NumPy array (RGB format). """ # preprocess_img expects RGB, cv2 frames are BGR frame_rgb = frame[:,:,::-1] tens_l_orig, tens_l_rs = preprocess_img(frame_rgb, HW=(256, 256)) # Model output is normalized, postprocess_tens handles unnormalization and returns RGB colorized_rgb = postprocess_tens(tens_l_orig, colorizer(tens_l_rs).cpu()) return colorized_rgb def colorize_image(file: Union[str, BytesIO, np.ndarray], loaded_model: ColorizationModule) -> Tuple[np.ndarray, Image.Image]: """ Colorizes an image. Args: file: The image file, can be a path (str), a file-like object (BytesIO), or an already loaded image as a NumPy array (RGB). loaded_model: The pre-loaded colorization model. Returns: A tuple containing: - out_img (np.ndarray): The colorized image as a NumPy array (RGB format), suitable for display with st.image. 
            - new_img (PIL.Image.Image): The colorized image as a PIL
              Image object.
    """
    # load_img handles a path or BytesIO and returns an RGB NumPy array.
    img = load_img(file)
    # If the input image has four channels (RGBA), discard the alpha channel.
    if img.ndim == 3 and img.shape[2] == 4:
        img = img[:, :, :3]

    tens_l_orig, tens_l_rs = preprocess_img(img, HW=(256, 256))
    # The model output is normalized; postprocess_tens un-normalizes it
    # and returns an RGB array.
    out_img_rgb = postprocess_tens(tens_l_orig, loaded_model(tens_l_rs).cpu())
    # Convert the float [0, 1] RGB array to uint8 [0, 255] for PIL.
    new_img_pil = Image.fromarray((out_img_rgb * 255).astype(np.uint8))
    return out_img_rgb, new_img_pil
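As a standalone sanity sketch of the day/hour/minute/second breakdown that `format_time` performs, the same arithmetic can be expressed with `divmod`; the helper name `split_duration` and the sample value are illustrative assumptions, not part of the app.

```python
# Illustrative sketch of the unit arithmetic used by format_time above.
def split_duration(total_seconds: int) -> tuple:
    """Return (days, hours, minutes, seconds) for a non-negative duration."""
    days, rem = divmod(total_seconds, 86400)
    hours, rem = divmod(rem, 3600)
    minutes, secs = divmod(rem, 60)
    return days, hours, minutes, secs


# 90061 s = 1 day + 1 hour + 1 minute + 1 second
print(split_duration(90061))  # -> (1, 1, 1, 1)
```

The `divmod` form avoids repeating the modulo expressions and makes it easy to verify each unit against the next larger one.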