# gray_to_color / inference.py
# Author: VanNguyen1214
# Last commit: "Update inference.py" (2e802be, verified)
from GAN.infer import ImageColorizationPipeline
from PIL import Image
import numpy as np
import cv2
class GrayscaleColorizer:
    """Callable wrapper that colorizes an image with ``ImageColorizationPipeline``.

    A PIL image is converted to an OpenCV-ordered (BGR) numpy array, passed to
    the GAN pipeline's ``process`` method, and the result is converted back to
    an RGB PIL image.
    """

    def __init__(
        self,
        model_path="GAN/pytorch_model.bin",
        input_size=256,
        model_size="tiny",
    ):
        """Load the colorization pipeline.

        Args:
            model_path: Path to the model weights. A relative path is tried
                against the current working directory first, then against the
                directory containing this module.
            input_size: Forwarded unchanged to ``ImageColorizationPipeline``.
            model_size: Forwarded unchanged to ``ImageColorizationPipeline``.
        """
        self.pipeline = ImageColorizationPipeline(
            model_path=self._resolve_model_path(model_path),
            input_size=input_size,
            model_size=model_size,
        )

    @staticmethod
    def _resolve_model_path(model_path):
        """Return *model_path*, falling back to a path next to this module.

        The original behaviour (resolving against the CWD) is kept when that
        file exists. Otherwise the same relative path is tried beside this
        file. If neither exists, the path is returned unchanged, so the
        pipeline reports the missing file exactly as before.
        """
        import os

        if os.path.isabs(model_path) or os.path.exists(model_path):
            return model_path
        module_dir = os.path.dirname(os.path.abspath(__file__))
        candidate = os.path.join(module_dir, model_path)
        return candidate if os.path.exists(candidate) else model_path

    def __call__(self, inputs):
        """Colorize a single image.

        Args:
            inputs: Either a mapping with an ``"image"`` key holding a PIL
                image, or a PIL image passed directly.

        Returns:
            ``{"image": result}``, where ``result`` is a PIL image built from
            the pipeline output after converting it from BGR to RGB.

        Raises:
            KeyError: If *inputs* is a mapping without an ``"image"`` key.
        """
        from collections.abc import Mapping

        pil_image = inputs["image"] if isinstance(inputs, Mapping) else inputs
        # Normalise any mode (e.g. grayscale "L", "RGBA") to 3-channel RGB so
        # the colour-space conversion below always gets three channels.
        rgb = np.array(pil_image.convert("RGB"))
        # The pipeline is fed OpenCV channel order (BGR).
        bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        result_bgr = self.pipeline.process(bgr)
        # NOTE(review): Image.fromarray needs a uint8 array for RGB output.
        # This assumes process() returns an HxWx3 uint8 BGR array; confirm
        # against GAN/infer.py.
        result_rgb = cv2.cvtColor(result_bgr, cv2.COLOR_BGR2RGB)
        return {"image": Image.fromarray(result_rgb)}
def pipeline():
    """Factory entry point: build and return a ready-to-use colorizer."""
    colorizer = GrayscaleColorizer()
    return colorizer