# core.py — image colorization pipeline: ModelScope DDColor with optional
# CPU dynamic quantization, falling back to a mock pipeline when ModelScope
# is not installed.
import os
import cv2
import numpy as np
from PIL import Image, ImageEnhance, ImageFilter
import time
# Optional dependency: ModelScope supplies the real colorization pipeline.
# When missing, HAS_MODELSCOPE gates a fallback to MockPipeline below.
try:
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks
    from modelscope.outputs import OutputKeys
    HAS_MODELSCOPE = True
except ImportError:
    HAS_MODELSCOPE = False
# Optional dependency: torch is only needed for dynamic quantization on CPU;
# code below must check `torch is not None` before use.
try:
    import torch
except ImportError:
    torch = None
class MockPipeline:
    """Stand-in for the ModelScope colorization pipeline.

    Mirrors the real pipeline's interface: called with an RGB numpy image,
    returns a dict whose 'output_img' entry holds a BGR numpy image.
    """

    def __call__(self, image):
        """Fake-colorize *image* (RGB ndarray) by applying a simple tint."""
        height, width = image.shape[:2]
        # Simulate inference latency proportional to pixel count.
        time.sleep((height * width) / 10_000_000.0)
        # The real model emits BGR, so convert before tinting for consistency.
        tinted = cv2.cvtColor(image.copy(), cv2.COLOR_RGB2BGR)
        # Warm tint: damp blue and green slightly, boost red.
        for channel, factor in ((0, 0.9), (1, 0.95), (2, 1.1)):
            tinted[:, :, channel] = np.clip(tinted[:, :, channel] * factor, 0, 255)
        return {'output_img': tinted}
class Colorizer:
    """Image colorizer wrapping the ModelScope DDColor pipeline.

    Falls back to :class:`MockPipeline` when ModelScope is unavailable or the
    real model fails to load, so downstream code keeps working either way.
    """

    def __init__(self, model_id="iic/cv_ddcolor_image-colorization", device="cpu"):
        self.model_id = model_id   # ModelScope model identifier
        self.device = device       # 'cpu' enables dynamic quantization
        self.pipeline = None       # set by load_model()
        self.load_model()

    def load_model(self):
        """Load the ModelScope pipeline, falling back to MockPipeline.

        On CPU (with torch available), applies dynamic int8 quantization to
        the model's Linear layers; quantization failures are logged and the
        unquantized model is kept (best effort).
        """
        if not HAS_MODELSCOPE:
            print("ModelScope not found. Using Mock.")
            self.pipeline = MockPipeline()
            return
        try:
            print(f"Loading model {self.model_id}...")
            self.pipeline = pipeline(
                Tasks.image_colorization,
                model=self.model_id,
                # device=self.device  # NOTE(review): device selection was disabled upstream
            )
            print("Model loaded.")
            # Dynamic quantization speeds up Linear layers on CPU inference.
            if self.device == 'cpu' and torch is not None and hasattr(self.pipeline, 'model'):
                try:
                    print("Applying dynamic quantization...")
                    self.pipeline.model = torch.quantization.quantize_dynamic(
                        self.pipeline.model, {torch.nn.Linear}, dtype=torch.qint8
                    )
                    print("Quantization applied.")
                except Exception as qe:
                    # Best effort: keep the unquantized model on failure.
                    print(f"Quantization failed: {qe}")
        except Exception as e:
            print(f"Failed to load real model: {e}. Using mock.")
            self.pipeline = MockPipeline()

    def process(self, img_pil: Image.Image, brightness: float = 1.0, contrast: float = 1.0, edge_enhance: bool = False, adaptive_resolution: int = 512) -> Image.Image:
        """
        Process a PIL Image: Colorize -> Enhance.
        Args:
            img_pil: Input image (PIL). Any mode; normalized to RGB first.
            brightness: Brightness factor (1.0 = unchanged)
            contrast: Contrast factor (1.0 = unchanged)
            edge_enhance: Apply PIL EDGE_ENHANCE filter
            adaptive_resolution: Max dimension for inference.
                    If image is larger, it's resized for colorization,
                    then upscaled and merged with original Luma.
                    Set to 0 to disable.
        Returns a PIL Image (RGB).
        Raises: re-raises any inference error from the pipeline.
        """
        t0 = time.time()
        # Normalize to RGB so np.array() always yields an HxWx3 uint8 array.
        # Grayscale 'L' (the typical input for a colorizer!) or 'RGBA' images
        # would otherwise break cv2.cvtColor(..., COLOR_RGB2BGR) below.
        if img_pil.mode != "RGB":
            img_pil = img_pil.convert("RGB")
        w_orig, h_orig = img_pil.size
        use_adaptive = adaptive_resolution > 0 and max(w_orig, h_orig) > adaptive_resolution
        if use_adaptive:
            # Downscale for inference; clamp to >= 1 px per side so extreme
            # aspect ratios can't produce a zero dimension for resize().
            scale = adaptive_resolution / max(w_orig, h_orig)
            new_w = max(1, int(w_orig * scale))
            new_h = max(1, int(h_orig * scale))
            img_input = img_pil.resize((new_w, new_h), Image.BILINEAR)
        else:
            img_input = img_pil
        # Convert PIL to Numpy RGB
        img_np = np.array(img_input)
        t1 = time.time()
        # Colorize
        try:
            output = self.pipeline(img_np)
        except Exception as e:
            print(f"Inference error: {e}")
            raise  # bare raise preserves the original traceback
        t2 = time.time()
        # Extract result (BGR). The real pipeline keys by OutputKeys.OUTPUT_IMG;
        # the mock uses the literal 'output_img'.
        if isinstance(output, dict):
            key = OutputKeys.OUTPUT_IMG if HAS_MODELSCOPE else 'output_img'
            result_bgr = output[key]
        else:
            result_bgr = output
        result_bgr = result_bgr.astype(np.uint8)
        if use_adaptive:
            # Keep the full-resolution luminance and take only the chroma from
            # the low-res colorized result: chroma upscales well, luma does not.
            # 1. Convert Low-Res Result to LAB
            result_lab = cv2.cvtColor(result_bgr, cv2.COLOR_BGR2LAB)
            # 2. Get High-Res Original Luma
            orig_bgr = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
            orig_lab = cv2.cvtColor(orig_bgr, cv2.COLOR_BGR2LAB)
            # 3. Resize Low-Res AB channels to Original Size
            result_lab_up = cv2.resize(result_lab, (w_orig, h_orig), interpolation=cv2.INTER_CUBIC)
            # 4. Merge: high-res L + upscaled a/b
            merged_lab = np.empty_like(orig_lab)
            merged_lab[:, :, 0] = orig_lab[:, :, 0]
            merged_lab[:, :, 1:] = result_lab_up[:, :, 1:]
            # 5. Convert back to RGB
            result_bgr_final = cv2.cvtColor(merged_lab, cv2.COLOR_LAB2BGR)
            result_rgb = cv2.cvtColor(result_bgr_final, cv2.COLOR_BGR2RGB)
        else:
            # Convert BGR to RGB
            result_rgb = cv2.cvtColor(result_bgr, cv2.COLOR_BGR2RGB)
        t3 = time.time()
        # Optional post-processing enhancements (skipped at neutral settings).
        out_pil = Image.fromarray(result_rgb)
        if brightness != 1.0:
            out_pil = ImageEnhance.Brightness(out_pil).enhance(brightness)
        if contrast != 1.0:
            out_pil = ImageEnhance.Contrast(out_pil).enhance(contrast)
        if edge_enhance:
            out_pil = out_pil.filter(ImageFilter.EDGE_ENHANCE)
        t4 = time.time()
        # print(f"Timing: Pre={t1-t0:.4f}, Infer={t2-t1:.4f}, Post={t3-t2:.4f}, Enhance={t4-t3:.4f}")
        return out_pil