import os
import time

import cv2
import numpy as np
from PIL import Image, ImageEnhance, ImageFilter

# ModelScope is optional: without it we fall back to a mock pipeline so the
# rest of the app (UI, enhancement controls) still works.
try:
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks
    from modelscope.outputs import OutputKeys
    HAS_MODELSCOPE = True
except ImportError:
    HAS_MODELSCOPE = False

# torch is only needed for optional CPU quantization of the real model.
try:
    import torch
except ImportError:
    torch = None


class MockPipeline:
    """Stand-in for the ModelScope colorization pipeline.

    Used when ModelScope is not installed or the real model fails to load.
    Mimics the real pipeline's contract: takes an RGB (or grayscale) numpy
    image, returns a dict with a BGR image under 'output_img'.
    """

    def __call__(self, image):
        # Simulate inference latency proportional to pixel count.
        h, w = image.shape[:2]
        time.sleep((h * w) / 10_000_000.0)

        # Accept 2-D grayscale input (as the real model does) by expanding
        # to 3 channels before the color-space conversion below.
        if image.ndim == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

        # Real model emits BGR; convert here for output consistency.
        output = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        # Fake "colorization": a simple warm tint so the result is visibly
        # different from the input.
        output[:, :, 0] = np.clip(output[:, :, 0] * 0.9, 0, 255)   # B
        output[:, :, 1] = np.clip(output[:, :, 1] * 0.95, 0, 255)  # G
        output[:, :, 2] = np.clip(output[:, :, 2] * 1.1, 0, 255)   # R
        return {'output_img': output}


class Colorizer:
    """Image colorizer backed by a ModelScope DDColor pipeline.

    Falls back to MockPipeline when ModelScope is unavailable or model
    loading fails, so callers always get a working `process` method.
    """

    def __init__(self, model_id="iic/cv_ddcolor_image-colorization", device="cpu"):
        self.model_id = model_id
        self.device = device
        self.pipeline = None  # set by load_model()
        self.load_model()

    def load_model(self):
        """Load the real colorization pipeline, or fall back to the mock.

        On CPU, additionally applies dynamic int8 quantization to the
        model's Linear layers (best-effort; failures are logged and ignored).
        """
        if not HAS_MODELSCOPE:
            print("ModelScope not found. Using Mock.")
            self.pipeline = MockPipeline()
            return

        try:
            print(f"Loading model {self.model_id}...")
            self.pipeline = pipeline(
                Tasks.image_colorization,
                model=self.model_id,
                # device=self.device
            )
            print("Model loaded.")

            # Dynamic quantization: shrink Linear layers to int8 for faster
            # CPU inference. Best-effort only.
            if self.device == 'cpu' and torch is not None and hasattr(self.pipeline, 'model'):
                try:
                    print("Applying dynamic quantization...")
                    self.pipeline.model = torch.quantization.quantize_dynamic(
                        self.pipeline.model, {torch.nn.Linear}, dtype=torch.qint8
                    )
                    print("Quantization applied.")
                except Exception as qe:
                    print(f"Quantization failed: {qe}")
        except Exception as e:
            print(f"Failed to load real model: {e}. Using mock.")
            self.pipeline = MockPipeline()

    def process(self, img_pil: Image.Image,
                brightness: float = 1.0,
                contrast: float = 1.0,
                edge_enhance: bool = False,
                adaptive_resolution: int = 512) -> Image.Image:
        """Process a PIL Image: Colorize -> Enhance.

        Args:
            img_pil: Input image (PIL). Grayscale images are converted to RGB.
            brightness: Brightness factor (1.0 = unchanged).
            contrast: Contrast factor (1.0 = unchanged).
            edge_enhance: Apply edge enhancement.
            adaptive_resolution: Max dimension for inference. If the image is
                larger, it is downscaled for colorization, then the low-res
                chroma (AB) is upscaled and merged with the original
                full-resolution luma (L). Set to 0 to disable.

        Returns:
            A PIL Image (RGB).
        """
        t0 = time.time()

        # Normalize to RGB up front: the cv2 color conversions below require
        # a 3-channel array, and colorizer inputs are often 'L'-mode scans.
        if img_pil.mode != 'RGB':
            img_pil = img_pil.convert('RGB')

        w_orig, h_orig = img_pil.size
        use_adaptive = (adaptive_resolution > 0
                        and max(w_orig, h_orig) > adaptive_resolution)

        if use_adaptive:
            # Downscale for inference; clamp to >=1 px for extreme aspect ratios.
            scale = adaptive_resolution / max(w_orig, h_orig)
            new_w = max(1, int(w_orig * scale))
            new_h = max(1, int(h_orig * scale))
            # print(f"Adaptive: Resizing {w_orig}x{h_orig} -> {new_w}x{new_h}")
            img_input = img_pil.resize((new_w, new_h), Image.BILINEAR)
        else:
            img_input = img_pil

        # Convert PIL to numpy RGB.
        img_np = np.array(img_input)
        t1 = time.time()

        # Colorize.
        try:
            output = self.pipeline(img_np)
        except Exception as e:
            print(f"Inference error: {e}")
            raise  # bare raise preserves the original traceback

        t2 = time.time()

        # Extract result (BGR). Real pipeline keys by OutputKeys; mock uses
        # the same literal key.
        if isinstance(output, dict):
            key = OutputKeys.OUTPUT_IMG if HAS_MODELSCOPE else 'output_img'
            result_bgr = output[key]
        else:
            result_bgr = output
        result_bgr = result_bgr.astype(np.uint8)

        if use_adaptive:
            # Chroma-transfer upscale: keep full-res luma from the original,
            # take only color (AB) from the low-res colorized result.
            # 1. Low-res result -> LAB.
            result_lab = cv2.cvtColor(result_bgr, cv2.COLOR_BGR2LAB)
            # 2. High-res original luma.
            orig_np = np.array(img_pil)                         # RGB
            orig_bgr = cv2.cvtColor(orig_np, cv2.COLOR_RGB2BGR)
            orig_lab = cv2.cvtColor(orig_bgr, cv2.COLOR_BGR2LAB)
            L_orig = orig_lab[:, :, 0]
            # 3. Upscale low-res AB channels to original size.
            result_lab_up = cv2.resize(result_lab, (w_orig, h_orig),
                                       interpolation=cv2.INTER_CUBIC)
            # 4. Merge: original L + upscaled AB.
            merged_lab = np.empty_like(orig_lab)
            merged_lab[:, :, 0] = L_orig
            merged_lab[:, :, 1] = result_lab_up[:, :, 1]
            merged_lab[:, :, 2] = result_lab_up[:, :, 2]
            # 5. Back to RGB.
            result_bgr_final = cv2.cvtColor(merged_lab, cv2.COLOR_LAB2BGR)
            result_rgb = cv2.cvtColor(result_bgr_final, cv2.COLOR_BGR2RGB)
        else:
            result_rgb = cv2.cvtColor(result_bgr, cv2.COLOR_BGR2RGB)

        t3 = time.time()

        # Enhance (each step skipped when it would be a no-op).
        out_pil = Image.fromarray(result_rgb)
        if brightness != 1.0:
            out_pil = ImageEnhance.Brightness(out_pil).enhance(brightness)
        if contrast != 1.0:
            out_pil = ImageEnhance.Contrast(out_pil).enhance(contrast)
        if edge_enhance:
            out_pil = out_pil.filter(ImageFilter.EDGE_ENHANCE)

        t4 = time.time()
        # print(f"Timing: Pre={t1-t0:.4f}, Infer={t2-t1:.4f}, Post={t3-t2:.4f}, Enhance={t4-t3:.4f}")
        return out_pil