| import os |
| import cv2 |
| import json |
| import torch |
| import torchvision.transforms as transforms |
| from CPNet_model import LiteAWBISPNet |
| import torchvision |
| import numpy as np |
| from Utiles import white_balance,apply_color_space_transform, transform_xyz_to_srgb, apply_gamma,fix_orientation,binning,Four2One,One2Four |
| import time |
| from net.mwrcanet import Net |
| import torch.nn as nn |
| from PIL import Image |
| import torch.nn.functional as F |
|
|
| |
# Directory scanned for input frames; each .png is paired with a same-named
# .json metadata file.
Rpath = './Input'
# Bare file names (not full paths) of the frames to process.
image_files = []
# Per-frame wall-clock durations in seconds.
infer_times = []

# 3x3 colour matrix, row-major, kept as a flat 9-element list because that is
# the form this script passes to apply_color_space_transform. The step after
# it is transform_xyz_to_srgb, so this presumably maps white-balanced camera
# RGB to XYZ -- NOTE(review): confirm against the sensor calibration.
_COLOR_MATRIX_ROWS = (
    (1.06835938, -0.29882812, -0.14257812),
    (-0.43164062, 1.35546875, 0.05078125),
    (-0.1015625, 0.24414062, 0.5859375),
)
color_matrix = [coef for row in _COLOR_MATRIX_ROWS for coef in row]
|
|
|
|
| |
# Input conversion for the final colour-correction network: ToTensor turns
# the HWC array into a CHW tensor, which is then resized to 768 x 1024 (H x W).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize([768, 1024]),
])

# Input conversion for the denoiser: tensor conversion only, no resize.
transforms_ = [transforms.ToTensor()]
transformo = transforms.Compose(transforms_)
|
|
| |
# Final-stage network (LiteAWBISPNet): the main loop feeds it the
# gamma-corrected image. Weights are loaded onto the CPU first and copied into
# the CUDA parameters by load_state_dict, so the checkpoint does not have to
# match the GPU index it was saved from.
model = LiteAWBISPNet()
model.cuda()
model.load_state_dict(torch.load('./model_zoo/CC2.pth', map_location='cpu'))
# Inference only: put layers such as BatchNorm/Dropout (if present) into
# evaluation mode so outputs are deterministic and use the stored statistics.
model.eval()

# Denoiser, applied in the main loop before white balance. The checkpoint
# stores weights under 'state_dict', keyed like the DataParallel-wrapped model.
last_ckpt = './model_zoo/dn_mwrcanet_raw_c1.pth'
dn_net = Net()
dn_model = nn.DataParallel(dn_net).cuda()
tmp_ckpt = torch.load(last_ckpt, map_location='cpu')
pretrained_dict = tmp_ckpt['state_dict']
model_dict = dn_model.state_dict()

# The checkpoint must cover exactly the model's parameters. Validate with an
# explicit check (assert is stripped under `python -O`) and report which keys
# differ, instead of a bare length comparison.
missing = sorted(set(model_dict) - set(pretrained_dict))
unexpected = sorted(set(pretrained_dict) - set(model_dict))
if missing or unexpected:
    raise RuntimeError(
        f"Checkpoint {last_ckpt} does not match the denoiser: "
        f"missing keys {missing}, unexpected keys {unexpected}"
    )
# Key sets are identical, so loading the checkpoint dict directly is
# equivalent to merging it into model_dict first.
dn_model.load_state_dict(pretrained_dict)
dn_model.eval()
|
|
| |
|
|
# Collect the bare names of every PNG in Rpath (extension matched
# case-insensitively), preserving os.listdir order.
image_files.extend(
    entry
    for entry in os.listdir(Rpath)
    if os.path.splitext(entry)[1].lower() == ".png"
)
|
|
# Run the pipeline on every collected frame:
#   16-bit PNG -> black/white-level normalisation -> binning -> resize
#   -> denoiser (dn_model) -> white balance -> colour matrix -> XYZ->sRGB
#   -> gamma -> colour-correction network (model) -> orientation fix
#   -> 8-bit PNG written to out_dir under the input file name.
# The recorded per-frame time runs from normalisation up to (not including)
# the BGR conversion and file write.
out_dir = './Output'
# cv2.imwrite does not create missing directories, so make sure the
# destination exists before the loop.
os.makedirs(out_dir, exist_ok=True)

with torch.no_grad():
    for name in image_files:
        fp = os.path.join(Rpath, name)
        # Metadata sidecar: same path with the extension replaced by .json.
        mf = os.path.splitext(fp)[0] + '.json'

        # -1 == cv2.IMREAD_UNCHANGED: keep the stored bit depth and channels.
        raw_image = cv2.imread(fp, -1)
        if raw_image is None:
            # imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {fp}")
        with open(mf, 'r') as file:
            data = json.load(file)

        time_BL_S = time.time()

        # Subtract the hard-coded black level (256) and scale by the
        # white-minus-black range (4095 - 256) into [0, 1].
        # NOTE(review): these look like 12-bit sensor levels -- confirm, or
        # read them from `data` if the metadata provides them.
        raw_image = (raw_image.astype(np.float32) - 256.) / (4095. - 256.)
        raw_image = np.clip(raw_image, 0.0, 1.0)

        # Presumably uses the metadata (e.g. CFA layout) -- see Utiles.binning.
        raw_image = binning(raw_image, data)

        # cv2 takes dsize as (width, height): result is 768 rows x 1024 cols.
        raw_image = cv2.resize(raw_image, (1024, 768))

        # Denoise. Four2One/One2Four convert to and from the layout the
        # denoiser consumes (see Utiles); the two squeezes drop the batch and
        # channel dims, assuming a (1, 1, H, W) output -- TODO confirm.
        Temp_I = Four2One(raw_image)
        Temp_I = transformo(Temp_I).unsqueeze(0).cuda()
        Temp_I = dn_model(Temp_I)
        Temp_I = np.asarray(Temp_I.squeeze(0).squeeze(0).cpu())
        raw_image = One2Four(Temp_I)

        # Classic ISP steps driven by the frame metadata and the fixed matrix.
        raw_image = white_balance(raw_image, data['as_shot_neutral'])
        raw_image = apply_color_space_transform(raw_image, color_matrix)
        raw_image = transform_xyz_to_srgb(raw_image)
        raw_image = apply_gamma(raw_image)

        # Final colour-correction network on the gamma-encoded image.
        Source = transform(raw_image).unsqueeze(0).float().cuda()
        Out = model(Source)

        # CHW float in [0, 1] -> HWC uint8 (truncating, as before).
        Out = Out.clip(0, 1)
        OA = np.asarray(Out.squeeze(0).cpu()).transpose(1, 2, 0).astype(np.float32)
        OA = (OA * 255.).astype(np.uint8)
        OA = fix_orientation(OA, data["orientation"])
        # The .cpu() copy above waits for the GPU, so this timestamp includes
        # all queued CUDA work.
        time_Save_F = time.time()

        # The pipeline produces RGB; OpenCV writes BGR.
        OA = cv2.cvtColor(OA, cv2.COLOR_RGB2BGR)
        out_path = os.path.join(out_dir, os.path.basename(fp))
        if not cv2.imwrite(out_path, OA):
            raise OSError(f"Could not write image: {out_path}")

        infer_times.append(time_Save_F - time_BL_S)

# np.mean of an empty list is nan (with a RuntimeWarning); report that case
# explicitly instead.
if infer_times:
    print(f"Average inference time: {np.mean(infer_times)} seconds")
else:
    print(f"No .png images found in {Rpath}")
|
|