File size: 1,886 Bytes
2ecc7ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import torch.nn as nn
import torchvision
import torch.optim
import os
import sys
import argparse
import cv2
import numpy as np
import model # 引用同目录下的 model.py

def run_inference(input_path, output_path, model_path):
    # 强制单卡
    os.environ['CUDA_VISIBLE_DEVICES']='0'
    
    # ================= 核心修复 =================
    # 根据你提供的 model.py,类名是 enhance_net_nopool,且不需要参数
    DCE_net = model.enhance_net_nopool().cuda()
    # ==========================================
    
    # 加载权重
    DCE_net.load_state_dict(torch.load(model_path))
    DCE_net.eval()
    
    # 读取图片
    if not os.path.exists(input_path):
        print(f"❌ 错误: 找不到输入图片 {input_path}")
        return

    data_lowlight = cv2.imread(input_path)
    data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_BGR2RGB) 
    data_lowlight = torch.from_numpy(data_lowlight).float() / 255.0
    data_lowlight = data_lowlight.permute(2,0,1).cuda().unsqueeze(0)

    # 推理
    with torch.no_grad():
        # model.py 的 forward 返回三个值: enhance_image_1, enhance_image, r
        # 我们通常要第二个 (enhance_image) 或者第一个,看效果,通常取最后一个作为最终结果
        _, enhanced_image, _ = DCE_net(data_lowlight)

    # 保存
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    torchvision.utils.save_image(enhanced_image, output_path)
    print(f"ZeroDCE_Success: {output_path}")

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', required=True)
    parser.add_argument('-o', '--output', required=True)
    parser.add_argument('-m', '--model', required=True)
    args = parser.parse_args()
    run_inference(args.input, args.output, args.model)