File size: 2,876 Bytes
8314c30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02d8fbb
8314c30
 
 
 
 
89fffaa
8314c30
89fffaa
 
8314c30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7add99b
 
 
89fffaa
 
 
7add99b
 
 
8314c30
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import torch
from torchvision.transforms import v2
import gradio as gr
from PIL import Image
from colorizer import ColorComicNet, MODEL_CFG
from utils import smart_padding, remove_padding

# Input pipeline for the model: PIL image -> tensor -> float32 scaled to
# [0, 1] -> normalized with mean 0.5 / std 0.5, i.e. mapped to [-1, 1]
# (postprocess_output undoes this with (x + 1) / 2). The one-element
# mean/std is presumably broadcast across all channels by torchvision.
TRANSFORM = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Image preprocessing and postprocessing functions
def preprocess_image(image: Image.Image, divisor=16):
    """Turn a PIL image into a padded, batched tensor ready for the model.

    Args:
        image: Source picture; it is converted to RGB before transforming.
        divisor: Forwarded to ``smart_padding`` -- presumably the multiple
            the padded height/width must be rounded up to (TODO confirm).

    Returns:
        ``(tensor, padding)``: the padded single-image batch (``(1, 3, H, W)``
        before padding) and the padding descriptor that ``remove_padding``
        needs later.
    """
    # Add a leading batch dimension of size 1 to the (3, H, W) image tensor.
    batch = TRANSFORM(image.convert('RGB')).unsqueeze(0)
    padded_batch, pad_info = smart_padding(batch, divisor=divisor)
    return padded_batch, pad_info

def postprocess_output(output_tensor, padding):
    """Convert a model output batch into an ``(H, W, C)`` NumPy image.

    Args:
        output_tensor: Model output for a single image, assumed to be shaped
            ``(1, C, H, W)`` and in the [-1, 1] range implied by the input
            normalization (TODO confirm against the model's output head).
        padding: Padding descriptor returned by ``preprocess_image``.

    Returns:
        A NumPy array of shape ``(H, W, C)`` with values clamped to [0, 1]
        (not a PIL image).
    """
    output_tensor = remove_padding(output_tensor, padding)
    # Undo the (x - 0.5) / 0.5 normalization: [-1, 1] -> [0, 1].
    output_tensor = (output_tensor + 1) / 2
    # detach()/cpu() keep .numpy() working if this is called outside
    # torch.no_grad() or if the model ever runs on a GPU.
    output_image = (
        output_tensor.detach()
        .clamp(0, 1)
        .squeeze(0)
        .permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
        .cpu()
        .numpy()
    )
    return output_image

# Define the colorization function
def colorize_image(gray_image: Image.Image):
    """Colorize a single grayscale image with the module-level ``model``.

    Args:
        gray_image: PIL image from the Gradio input component. Gradio passes
            ``None`` when the button is clicked before an image is uploaded.

    Returns:
        The ``(H, W, C)`` NumPy array produced by ``postprocess_output``,
        with values in [0, 1].

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of an AttributeError from ``None.convert``.
    """
    if gray_image is None:
        raise gr.Error("Please upload an image first.")
    with torch.no_grad():
        # Pad to a multiple of 64 -- presumably the model's overall
        # downsampling factor (TODO confirm).
        input_tensor, padding = preprocess_image(gray_image, divisor=64)
        output = model(input_tensor)
        return postprocess_output(output, padding)

# Initialize the model from its config and load the checkpoint onto the CPU.
# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state dict needs, and refuses to execute arbitrary pickled objects.
model = ColorComicNet(**MODEL_CFG)
model.load_state_dict(
    torch.load(
        "./weights/colorizer.pth",
        map_location=torch.device('cpu'),
        weights_only=True,
    )
)
# fuse() is project-defined; presumably merges layers for faster inference
# (TODO confirm in colorizer.ColorComicNet).
model.fuse()
# Put the module in evaluation mode before serving requests.
model.eval()

# Create the Gradio interface
# Build the Gradio UI: input image + button on the left, result on the right,
# followed by clickable examples.
with gr.Blocks() as demo:
    # Header
    gr.Markdown("# 🎨 Comic Colorization")
    gr.Markdown("Bring your grayscale comics to life with **ColorComicNet**")
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            # type="pil" makes Gradio hand colorize_image a PIL image (or None).
            input_image = gr.Image(
                label="📥 Upload Grayscale Image",
                type="pil",
            )
            # variant="primary" applies Gradio's built-in primary styling;
            # the elem_classes value alone has no effect without custom CSS.
            colorize_button = gr.Button(
                "✨ Colorize Image",
                variant="primary",
                elem_classes="button-primary",
            )
        with gr.Column(scale=1):
            # colorize_image returns a NumPy array, hence type="numpy".
            output_image = gr.Image(
                label="📤 Colorized Result",
                type="numpy",
            )
    # Example section: clicking an example only fills input_image.
    gr.Markdown("### 🖼️ Try an example")
    examples = gr.Examples(
        examples=[
            ["./examples/gray.jpg"],
            ["./examples/gray_2.jpg"],
            ["./examples/gray_4.jpg"],
        ],
        inputs=input_image,
    )
    # Footer
    gr.Markdown("---")
    # Interaction: run the colorizer on the uploaded image.
    colorize_button.click(
        fn=colorize_image,
        inputs=input_image,
        outputs=output_image,
    )
demo.launch()