Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import requests | |
| import streamlit as st | |
| from PIL import Image | |
| from PIL import Image | |
| import numpy as np | |
| import requests | |
| import streamlit as st | |
| import torch # Ensure torch is imported for type hints if not already | |
| from typing import Optional, Tuple, Union # For type hints | |
| from io import BytesIO # For type hints | |
| # Assuming these are the actual model types, otherwise use torch.nn.Module | |
| from .models.deep_colorization.colorizers import eccv16, siggraph17, BaseColor as ColorizationModule | |
| from .models.deep_colorization.colorizers import postprocess_tens, preprocess_img, load_img | |
def load_lottieurl(url: str, timeout: float = 10.0) -> Optional[dict]:
    """
    Loads a Lottie animation from a URL.

    Lottie files are JSON-based animation files that enable designers to ship
    animations on any platform as easily as shipping static assets.

    Args:
        url: The URL of the Lottie JSON file.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` may wait indefinitely for a response.

    Returns:
        A dictionary representing the Lottie animation data if successful,
        None otherwise (non-200 status, network failure, or a response body
        that is not valid JSON).
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException:
        # Connection errors, DNS failures, timeouts, ... are reported the same
        # way as a non-200 response: by returning None.
        return None
    if response.status_code != 200:
        return None
    try:
        return response.json()
    except ValueError:
        # Response.json() signals an undecodable body with a ValueError
        # subclass (requests' JSONDecodeError in recent versions).
        return None
def change_model(model_name: str) -> ColorizationModule:
    """
    Build one of the pre-trained colorization networks by name.

    Args:
        model_name: Either "ECCV16" or "SIGGRAPH17" (compared exactly,
            so the match is case-sensitive).

    Returns:
        The requested network, constructed with ``pretrained=True`` and
        already switched to evaluation mode via ``.eval()``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # Reject unsupported names up front so only valid names reach a factory.
    if model_name not in ("ECCV16", "SIGGRAPH17"):
        raise ValueError(f"Unknown model name: {model_name}. Choose 'ECCV16' or 'SIGGRAPH17'.")
    factory = eccv16 if model_name == "ECCV16" else siggraph17
    return factory(pretrained=True).eval()
def format_time(seconds: float) -> str:
    """
    Formats a duration in seconds as a human-readable English string.

    Units with a zero count are omitted, each unit is pluralised as needed,
    and the last two parts are joined with " and " (no Oxford comma), e.g.
    ``format_time(90061)`` -> ``"1 day, 1 hour, 1 minute and 1 second"``.
    Fractional seconds are truncated, so any duration below one second
    (including 0) is reported as ``"0 seconds"``.

    Args:
        seconds: The total number of seconds. Any real number is accepted
            (int, float, NumPy numeric scalars, ...); ``bool`` is rejected.

    Returns:
        A string representing the formatted time.

    Raises:
        TypeError: If ``seconds`` is not a real number, or is a bool.
        ValueError: If ``seconds`` is negative, NaN or infinite.
    """
    import math
    import numbers

    # bool subclasses int, but format_time(True) == "1 second" would almost
    # certainly hide a caller bug rather than describe a real duration.
    if isinstance(seconds, bool) or not isinstance(seconds, numbers.Real):
        raise TypeError("Input 'seconds' must be a number.")
    if seconds < 0:
        raise ValueError("Input 'seconds' cannot be negative.")
    if not math.isfinite(seconds):
        # NaN/inf would otherwise fail deep inside int() with a cryptic message.
        raise ValueError("Input 'seconds' must be finite.")

    units = (
        ("day", int(seconds // 86400)),
        ("hour", int((seconds % 86400) // 3600)),
        ("minute", int((seconds % 3600) // 60)),
        ("second", int(seconds % 60)),
    )
    parts = [f"{count} {name}{'' if count == 1 else 's'}" for name, count in units if count > 0]

    if not parts:
        # Every unit truncated to zero: the duration is under one second.
        return "0 seconds"
    if len(parts) == 1:
        return parts[0]
    return ", ".join(parts[:-1]) + " and " + parts[-1]
def colorize_frame(frame: np.ndarray, colorizer: ColorizationModule) -> np.ndarray:
    """
    Colorizes a single video frame.

    Args:
        frame: The input video frame as a NumPy array in OpenCV channel
            order. Accepted layouts are BGR (H, W, 3), BGRA (H, W, 4), whose
            alpha channel is discarded, and single-channel (H, W) or
            (H, W, 1).
        colorizer: The pre-loaded colorization model.

    Returns:
        The colorized frame as a NumPy array (RGB format), as produced by
        ``postprocess_tens``.
    """
    if frame.ndim == 2:
        frame = frame[:, :, np.newaxis]
    if frame.shape[2] == 1:
        # Single-channel frame: replicate it into three identical channels
        # (channel order is irrelevant when all channels are equal).
        frame_rgb = np.repeat(frame, 3, axis=2)
    else:
        # Drop any alpha channel, then reverse OpenCV's BGR order into the
        # RGB order preprocess_img expects.
        frame_rgb = frame[:, :, :3][:, :, ::-1]

    tens_l_orig, tens_l_rs = preprocess_img(frame_rgb, HW=(256, 256))

    # Pure inference: disable autograd so no graph is recorded per frame.
    with torch.no_grad():
        model_out = colorizer(tens_l_rs).cpu()

    # Model output is normalized; postprocess_tens handles unnormalization
    # and returns RGB.
    return postprocess_tens(tens_l_orig, model_out)
def colorize_image(file: Union[str, BytesIO, np.ndarray], loaded_model: ColorizationModule) -> Tuple[np.ndarray, Image.Image]:
    """
    Colorizes an image.

    Args:
        file: The image file. It can be a path (str), a file-like object
            (BytesIO), or an already loaded image as a NumPy array (RGB,
            RGBA or single-channel).
        loaded_model: The pre-loaded colorization model.

    Returns:
        A tuple containing:
            - out_img (np.ndarray): The colorized image as a NumPy array (RGB
              format), suitable for display with st.image.
            - new_img (PIL.Image.Image): The colorized image as a PIL Image
              object (uint8).
    """
    if isinstance(file, np.ndarray):
        # Already decoded: use the array directly rather than passing it to
        # load_img, which handles paths / file-like objects.
        img = file
    else:
        img = load_img(file)  # load_img handles path or BytesIO, returns RGB np.array

    if img.ndim == 2:
        # Grayscale array: expand to three identical channels.
        img = np.repeat(img[:, :, np.newaxis], 3, axis=2)
    # If user input a colored image with 4 channels (RGBA), discard the alpha channel.
    if img.ndim == 3 and img.shape[2] == 4:
        img = img[:, :, :3]

    tens_l_orig, tens_l_rs = preprocess_img(img, HW=(256, 256))

    # Pure inference: disable autograd so no computation graph is recorded.
    with torch.no_grad():
        model_out = loaded_model(tens_l_rs).cpu()

    # Model output is normalized; postprocess_tens handles unnormalization
    # and returns RGB.
    out_img_rgb = postprocess_tens(tens_l_orig, model_out)

    # Convert the float RGB array (expected in [0, 1]) to uint8 for PIL.
    # Rounding avoids the systematic downward bias of truncation. Clipping
    # keeps slightly out-of-range floats from wrapping around in the uint8 cast.
    img_uint8 = np.clip(np.rint(out_img_rgb * 255.0), 0, 255).astype(np.uint8)
    new_img_pil = Image.fromarray(img_uint8)
    return out_img_rgb, new_img_pil