File size: 5,227 Bytes
7c26b33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import numpy as np
import requests
import streamlit as st
from PIL import Image

from PIL import Image
import numpy as np
import requests
import streamlit as st
import torch # Ensure torch is imported for type hints if not already
from typing import Optional, Tuple, Union # For type hints
from io import BytesIO # For type hints

# Assuming these are the actual model types, otherwise use torch.nn.Module
from .models.deep_colorization.colorizers import eccv16, siggraph17, BaseColor as ColorizationModule
from .models.deep_colorization.colorizers import postprocess_tens, preprocess_img, load_img


@st.cache_data()
def load_lottieurl(url: str) -> Optional[dict]:
    """
    Loads a Lottie animation from a URL.

    Lottie files are JSON-based animation files that enable designers to ship
    animations on any platform as easily as shipping static assets.

    Args:
        url: The URL of the Lottie JSON file.

    Returns:
        A dictionary representing the Lottie animation data if successful,
        None otherwise (network failure, non-200 status, or invalid JSON).
    """
    try:
        # Bounded timeout so a hung endpoint cannot stall the Streamlit script.
        r = requests.get(url, timeout=10)
    except requests.exceptions.RequestException:
        # DNS failure, connection refused, timeout, ... — honor the
        # documented "None otherwise" contract instead of crashing the app.
        return None
    if r.status_code != 200:
        return None
    try:
        return r.json()
    except ValueError:
        # Body was not valid JSON (requests' JSONDecodeError subclasses ValueError).
        return None


@st.cache_resource()
def change_model(model_name: str) -> ColorizationModule:
    """
    Loads a specified pre-trained colorization model.

    Args:
        model_name: The name of the model to load ("ECCV16" or "SIGGRAPH17").

    Returns:
        The loaded PyTorch model (evaluated and pre-trained).

    Raises:
        ValueError: If the model_name is not recognized.
    """
    # Dispatch table: model name -> constructor for the pre-trained network.
    factories = {"ECCV16": eccv16, "SIGGRAPH17": siggraph17}
    if model_name not in factories:
        raise ValueError(f"Unknown model name: {model_name}. Choose 'ECCV16' or 'SIGGRAPH17'.")
    return factories[model_name](pretrained=True).eval()


def format_time(seconds: float) -> str:
    """
    Formats time in seconds to a human-readable string.

    The output will be in the format of "X days, Y hours, Z minutes, and S seconds",
    omitting larger units if they are zero.

    Args:
        seconds: The total number of seconds.

    Returns:
        A string representing the formatted time.
    """
    if not isinstance(seconds, (int, float)):
        raise TypeError("Input 'seconds' must be a number.")
    if seconds < 0:
        raise ValueError("Input 'seconds' cannot be negative.")

    if seconds == 0:
        return "0 seconds"

    # Larger units, each paired with its singular label; zero amounts are skipped.
    larger_units = [
        (int(seconds // 86400), "day"),
        (int((seconds % 86400) // 3600), "hour"),
        (int((seconds % 3600) // 60), "minute"),
    ]
    secs = int(seconds % 60)

    pieces = [
        f"{amount} {label}{'s' if amount != 1 else ''}"
        for amount, label in larger_units
        if amount > 0
    ]
    # Seconds are shown when non-zero, or when they are the only unit available.
    if secs > 0 or not pieces:
        pieces.append(f"{secs} second{'s' if secs != 1 else ''}")

    if len(pieces) == 1:
        return pieces[0]
    # Oxford-free join: comma-separate all but the last, then "and" the final part.
    return ", ".join(pieces[:-1]) + " and " + pieces[-1]


def colorize_frame(frame: np.ndarray, colorizer: ColorizationModule) -> np.ndarray:
    """
    Colorizes a single video frame.

    Args:
        frame: The input video frame as a NumPy array (BGR format expected by OpenCV).
        colorizer: The pre-loaded colorization model.

    Returns:
        The colorized frame as a NumPy array (RGB format).
    """
    # OpenCV delivers BGR; the colorization pipeline works in RGB, so reverse channels.
    rgb_frame = frame[:, :, ::-1]
    orig_l, resized_l = preprocess_img(rgb_frame, HW=(256, 256))
    predicted_ab = colorizer(resized_l).cpu()
    # postprocess_tens merges the original L channel with predicted ab and returns RGB.
    return postprocess_tens(orig_l, predicted_ab)


def colorize_image(file: Union[str, BytesIO, np.ndarray], loaded_model: ColorizationModule) -> Tuple[np.ndarray, Image.Image]:
    """
    Colorizes an image.

    Args:
        file: The image file, can be a path (str), a file-like object (BytesIO),
              or an already loaded image as a NumPy array (RGB).
        loaded_model: The pre-loaded colorization model.

    Returns:
        A tuple containing:
            - out_img (np.ndarray): The colorized image as a NumPy array (RGB format),
                                    suitable for display with st.image.
            - new_img (PIL.Image.Image): The colorized image as a PIL Image object.
    """
    # load_img accepts a path or file-like object and yields an RGB numpy array.
    img = load_img(file)

    # RGBA input: keep only the RGB planes, dropping alpha.
    if img.ndim == 3 and img.shape[2] == 4:
        img = img[..., :3]

    orig_l, resized_l = preprocess_img(img, HW=(256, 256))
    # postprocess_tens unnormalizes the model output and returns a float RGB array.
    colorized = postprocess_tens(orig_l, loaded_model(resized_l).cpu())

    # Scale the float [0,1] RGB array to uint8 [0,255] so PIL can wrap it.
    pil_image = Image.fromarray((colorized * 255).astype(np.uint8))

    return colorized, pil_image