|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torchvision
|
|
|
import torch.optim
|
|
|
import os
|
|
|
import sys
|
|
|
import argparse
|
|
|
import cv2
|
|
|
import numpy as np
|
|
|
import model
|
|
|
|
|
|
def run_inference(input_path, output_path, model_path, device="cuda"):
    """Enhance one low-light image with a Zero-DCE network and save the result.

    Args:
        input_path: Path of the image to enhance (read with ``cv2.imread``).
        output_path: Destination file for the enhanced image. Missing parent
            directories are created; a bare file name writes to the CWD.
        model_path: Path of a state dict for ``model.enhance_net_nopool``.
        device: Torch device to run on. Defaults to ``"cuda"`` (the original
            hard-coded behaviour); e.g. ``"cpu"`` runs without a GPU.

    Returns:
        None. A missing or undecodable input image is reported on stdout and
        the function returns early; other failures (e.g. a bad checkpoint
        path) propagate as exceptions.
    """
    # Validate the input first so we don't build the model / load weights
    # for a run that cannot succeed.
    if not os.path.exists(input_path):
        print(f"❌ 错误: 找不到输入图片 {input_path}")
        return

    # Restrict the process to GPU 0 (kept from the original script). This
    # only has an effect if CUDA has not yet been initialised in the process.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    dce_net = model.enhance_net_nopool().to(device)
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. map_location lets weights saved on another device load
    # onto `device`.
    dce_net.load_state_dict(torch.load(model_path, map_location=device))
    dce_net.eval()

    # cv2.imread signals failure by returning None rather than raising.
    data_lowlight = cv2.imread(input_path)
    if data_lowlight is None:
        print(f"❌ 错误: 无法读取输入图片 {input_path}")
        return

    # OpenCV yields an HWC BGR array; convert to RGB, scale to [0, 1] and
    # reshape to a (1, C, H, W) batch for the network.
    data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_BGR2RGB)
    data_lowlight = torch.from_numpy(data_lowlight).float() / 255.0
    data_lowlight = data_lowlight.permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        # The network returns a 3-tuple whose middle element is used as the
        # enhanced image (assumed Zero-DCE convention — confirm in model.py).
        _, enhanced_image, _ = dce_net(data_lowlight)

    # os.path.dirname("out.png") is "" and os.makedirs("") raises, so only
    # create a directory when the output path actually contains one.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torchvision.utils.save_image(enhanced_image, output_path)
    print(f"ZeroDCE_Success: {output_path}")
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: enhance a single image and write it to disk.
    parser = argparse.ArgumentParser(
        description="Enhance a single low-light image with a pretrained Zero-DCE model."
    )
    parser.add_argument('-i', '--input', required=True,
                        help="path of the low-light input image")
    parser.add_argument('-o', '--output', required=True,
                        help="path to write the enhanced image to")
    parser.add_argument('-m', '--model', required=True,
                        help="path of the model weights (state dict) to load")
    args = parser.parse_args()
    run_inference(args.input, args.output, args.model)