File size: 5,971 Bytes
0e868b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import os
import cv2
import numpy as np
from PIL import Image, ImageEnhance, ImageFilter
import time

try:
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks
    from modelscope.outputs import OutputKeys
    HAS_MODELSCOPE = True
except ImportError:
    HAS_MODELSCOPE = False

try:
    import torch
except ImportError:
    torch = None

class MockPipeline:
    """Stand-in colorization pipeline used when ModelScope is unavailable.

    Mirrors the real pipeline's call contract: accepts an RGB numpy image
    and returns a dict holding a BGR image under 'output_img'.
    """

    def __call__(self, image):
        """Fake-colorize *image* (RGB uint8 array); return {'output_img': BGR array}."""
        height, width = image.shape[:2]
        # Sleep proportionally to pixel count to imitate real inference cost.
        time.sleep((height * width) / 10_000_000.0)

        # The real model emits BGR, so mirror that channel order here.
        tinted = cv2.cvtColor(image.copy(), cv2.COLOR_RGB2BGR)

        # Apply a mild warm tint so the mock output is visibly "colorized".
        for channel, factor in ((0, 0.9), (1, 0.95), (2, 1.1)):  # B, G, R
            tinted[:, :, channel] = np.clip(tinted[:, :, channel] * factor, 0, 255)

        return {'output_img': tinted}

class Colorizer:
    """Image colorization wrapper around the ModelScope DDColor pipeline.

    Falls back to MockPipeline when ModelScope is unavailable or the real
    model fails to load, so callers always get a working `process()`.
    """

    def __init__(self, model_id="iic/cv_ddcolor_image-colorization", device="cpu"):
        """
        Args:
            model_id: ModelScope model identifier to load.
            device: Target device string; currently only used to decide
                    whether to attempt CPU dynamic quantization.
        """
        self.model_id = model_id
        self.device = device
        self.pipeline = None  # set by load_model()
        self.load_model()

    def load_model(self):
        """Load the real colorization pipeline, falling back to the mock.

        On CPU, dynamic quantization of Linear layers is attempted to speed
        up inference; quantization failures are non-fatal (the unquantized
        model is kept).
        """
        if not HAS_MODELSCOPE:
            print("ModelScope not found. Using Mock.")
            self.pipeline = MockPipeline()
            return

        try:
            print(f"Loading model {self.model_id}...")
            self.pipeline = pipeline(
                Tasks.image_colorization,
                model=self.model_id,
                # device=self.device
            )
            print("Model loaded.")

            # Dynamic quantization for CPU inference (Linear layers -> int8).
            if self.device == 'cpu' and torch is not None and hasattr(self.pipeline, 'model'):
                try:
                    print("Applying dynamic quantization...")
                    self.pipeline.model = torch.quantization.quantize_dynamic(
                        self.pipeline.model, {torch.nn.Linear}, dtype=torch.qint8
                    )
                    print("Quantization applied.")
                except Exception as qe:
                    # Best-effort optimization; keep the unquantized model.
                    print(f"Quantization failed: {qe}")

        except Exception as e:
            print(f"Failed to load real model: {e}. Using mock.")
            self.pipeline = MockPipeline()

    def process(self, img_pil: Image.Image, brightness: float = 1.0, contrast: float = 1.0, edge_enhance: bool = False, adaptive_resolution: int = 512) -> Image.Image:
        """
        Process a PIL Image: Colorize -> Enhance.

        Args:
            img_pil: Input image (PIL). Non-RGB modes (e.g. grayscale "L",
                     "RGBA") are converted to RGB first.
            brightness: Brightness factor
            contrast: Contrast factor
            edge_enhance: Apply edge enhancement
            adaptive_resolution: Max dimension for inference.
                                 If image is larger, it's resized for colorization,
                                 then upscaled and merged with original Luma.
                                 Set to 0 to disable.

        Returns a PIL Image.
        """
        t0 = time.time()

        # Normalize to 3-channel RGB up front: a grayscale ("L") or RGBA
        # input would otherwise yield a 2-D / 4-channel array and crash the
        # cv2.cvtColor(..., RGB2BGR) calls below. No-op copy for RGB input.
        if img_pil.mode != "RGB":
            img_pil = img_pil.convert("RGB")

        w_orig, h_orig = img_pil.size
        use_adaptive = adaptive_resolution > 0 and (w_orig > adaptive_resolution or h_orig > adaptive_resolution)

        if use_adaptive:
            # Downscale for inference; max(1, ...) guards against a zero
            # dimension for extreme aspect ratios (e.g. 2000x1 banners).
            scale = adaptive_resolution / max(w_orig, h_orig)
            new_w = max(1, int(w_orig * scale))
            new_h = max(1, int(h_orig * scale))
            img_input = img_pil.resize((new_w, new_h), Image.BILINEAR)
        else:
            img_input = img_pil

        # PIL -> numpy, RGB channel order.
        img_np = np.array(img_input)

        t1 = time.time()
        # Colorize
        try:
            output = self.pipeline(img_np)
        except Exception as e:
            print(f"Inference error: {e}")
            raise  # bare raise preserves the original traceback
        t2 = time.time()

        # Extract result (both real model and mock return BGR).
        if isinstance(output, dict):
            key = OutputKeys.OUTPUT_IMG if HAS_MODELSCOPE else 'output_img'
            result_bgr = output[key]
        else:
            result_bgr = output

        result_bgr = result_bgr.astype(np.uint8)

        if use_adaptive:
            # Luma-preserving upscale: chroma (a/b) comes from the low-res
            # colorized result, luma (L) from the full-res original.

            # 1. Low-res result -> LAB
            result_lab = cv2.cvtColor(result_bgr, cv2.COLOR_BGR2LAB)

            # 2. High-res original luma
            orig_bgr = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
            orig_lab = cv2.cvtColor(orig_bgr, cv2.COLOR_BGR2LAB)

            # 3. Upscale the colorized LAB to the original size
            result_lab_up = cv2.resize(result_lab, (w_orig, h_orig), interpolation=cv2.INTER_CUBIC)

            # 4. Merge: original L, upscaled a/b
            merged_lab = np.empty_like(orig_lab)
            merged_lab[:, :, 0] = orig_lab[:, :, 0]
            merged_lab[:, :, 1] = result_lab_up[:, :, 1]
            merged_lab[:, :, 2] = result_lab_up[:, :, 2]

            # 5. LAB -> BGR -> RGB
            result_rgb = cv2.cvtColor(cv2.cvtColor(merged_lab, cv2.COLOR_LAB2BGR), cv2.COLOR_BGR2RGB)
        else:
            # Convert BGR to RGB
            result_rgb = cv2.cvtColor(result_bgr, cv2.COLOR_BGR2RGB)

        t3 = time.time()
        # Post-colorization enhancement chain.
        out_pil = Image.fromarray(result_rgb)

        if brightness != 1.0:
            out_pil = ImageEnhance.Brightness(out_pil).enhance(brightness)
        if contrast != 1.0:
            out_pil = ImageEnhance.Contrast(out_pil).enhance(contrast)
        if edge_enhance:
            out_pil = out_pil.filter(ImageFilter.EDGE_ENHANCE)

        t4 = time.time()
        # print(f"Timing: Pre={t1-t0:.4f}, Infer={t2-t1:.4f}, Post={t3-t2:.4f}, Enhance={t4-t3:.4f}")
        return out_pil