"""Run Zero-DCE low-light enhancement on a single image from the command line.

Usage::

    python <this_script> -i input.jpg -o out/enhanced.png -m weights.pth

On success prints ``ZeroDCE_Success: <output_path>`` and exits with status 0;
if the input image is missing or cannot be decoded, prints an error message
and exits with status 1.
"""
import argparse
import os
import sys

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim
import torchvision

import model  # the model.py in the same directory


def _load_model(model_path):
    """Build ``model.enhance_net_nopool`` on the GPU and load its weights."""
    # Per model.py, the network class is enhance_net_nopool and takes no
    # constructor arguments.
    dce_net = model.enhance_net_nopool().cuda()
    # map_location="cpu" lets a checkpoint saved on any device (e.g. cuda:1)
    # load even though only GPU 0 is visible; load_state_dict then copies the
    # tensors into the model's CUDA parameters.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location="cpu")
    dce_net.load_state_dict(state_dict)
    dce_net.eval()
    return dce_net


def _read_rgb_tensor(input_path):
    """Return *input_path* as a 1x3xHxW float tensor in [0, 1], or None if undecodable."""
    bgr = cv2.imread(input_path)
    if bgr is None:
        # cv2.imread cannot open non-ASCII paths on Windows; retry by reading
        # the raw bytes with numpy and decoding them in memory.
        try:
            raw = np.fromfile(input_path, dtype=np.uint8)
        except OSError:
            return None
        if raw.size == 0:
            return None
        bgr = cv2.imdecode(raw, cv2.IMREAD_COLOR)
        if bgr is None:
            return None
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    tensor = torch.from_numpy(rgb).float() / 255.0
    # HWC -> CHW, then add a leading batch dimension.
    return tensor.permute(2, 0, 1).unsqueeze(0)


def _ensure_parent_dir(output_path):
    """Create the directory that will contain *output_path*, if it names one."""
    parent = os.path.dirname(output_path)
    # dirname() is "" for a bare filename, and os.makedirs("") would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)


def run_inference(input_path, output_path, model_path):
    """Enhance one low-light image with Zero-DCE and save the result.

    Args:
        input_path: Path of the image to enhance.
        output_path: Where to write the enhanced image; missing parent
            directories are created.
        model_path: Path of the state_dict checkpoint for
            ``model.enhance_net_nopool``.

    Returns:
        True if the enhanced image was written; False if the input image was
        missing or could not be decoded (an error message is printed).
    """
    # Force single-GPU use. This only has an effect if CUDA has not yet been
    # initialised in this process (true when run as a script).
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    # Validate the input before paying for model construction.
    if not os.path.exists(input_path):
        print(f"❌ 错误: 找不到输入图片 {input_path}")
        return False
    data_lowlight = _read_rgb_tensor(input_path)
    if data_lowlight is None:
        print(f"❌ 错误: 无法读取图片 {input_path}")
        return False

    dce_net = _load_model(model_path)
    with torch.no_grad():
        # NOTE(review): forward is assumed to return
        # (enhance_image_1, enhance_image, r); the second value is used as the
        # enhanced result - confirm against model.py.
        _, enhanced_image, _ = dce_net(data_lowlight.cuda())

    _ensure_parent_dir(output_path)
    torchvision.utils.save_image(enhanced_image, output_path)
    print(f"ZeroDCE_Success: {output_path}")
    return True


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Zero-DCE single-image low-light enhancement."
    )
    parser.add_argument('-i', '--input', required=True, help="input image path")
    parser.add_argument('-o', '--output', required=True, help="output image path")
    parser.add_argument('-m', '--model', required=True, help="model weights (.pth)")
    args = parser.parse_args()
    # Exit non-zero on failure so calling processes can detect it.
    sys.exit(0 if run_inference(args.input, args.output, args.model) else 1)