# hugaagg's picture
# Upload folder using huggingface_hub
# 2ecc7ab verified
import torch
import torch.nn as nn
import torchvision
import torch.optim
import os
import sys
import argparse
import cv2
import numpy as np
import model # 引用同目录下的 model.py
def run_inference(input_path, output_path, model_path):
    """Enhance one low-light image with a pretrained Zero-DCE network.

    Reads ``input_path`` with OpenCV, runs it through
    ``model.enhance_net_nopool`` on the GPU, and writes the enhanced image to
    ``output_path`` with ``torchvision.utils.save_image``. Errors and success
    are reported on stdout; the function always returns ``None``.

    Args:
        input_path: Path of the image to enhance.
        output_path: Destination image file. Missing parent directories are
            created; a bare file name (no directory part) is written to the
            current working directory.
        model_path: Path of a ``state_dict`` checkpoint for
            ``enhance_net_nopool``.
    """
    # Force single-GPU use. NOTE(review): this only takes effect if CUDA has
    # not been initialised yet in this process, and it overrides any value
    # the caller had already set.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    # Validate the input first so a bad path fails before the (comparatively
    # expensive) model construction and weight loading.
    if not os.path.exists(input_path):
        print(f"❌ 错误: 找不到输入图片 {input_path}")
        return

    # Per model.py, the network class is enhance_net_nopool and its
    # constructor takes no arguments.
    DCE_net = model.enhance_net_nopool().cuda()

    # Deserialize to CPU and let load_state_dict copy into the CUDA params:
    # a checkpoint saved from another GPU index would otherwise fail to load,
    # since only device 0 is visible (see above).
    # SECURITY: torch.load unpickles the file -- only load trusted checkpoints.
    DCE_net.load_state_dict(torch.load(model_path, map_location='cpu'))
    DCE_net.eval()

    # cv2.imread does not raise on failure; it returns None for files it
    # cannot open or decode (corrupt data, unsupported format, ...).
    bgr_image = cv2.imread(input_path)
    if bgr_image is None:
        print(f"❌ 错误: 无法读取图片 {input_path}")
        return
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

    # uint8 HxWxC in [0, 255] -> float 1xCxHxW in [0, 1] on the GPU.
    data_lowlight = torch.from_numpy(rgb_image).float() / 255.0
    data_lowlight = data_lowlight.permute(2, 0, 1).cuda().unsqueeze(0)

    with torch.no_grad():
        # model.py's forward returns three values
        # (enhance_image_1, enhance_image, r); the second one,
        # enhance_image, is used as the final result.
        _, enhanced_image, _ = DCE_net(data_lowlight)

    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when output_path actually contains one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    torchvision.utils.save_image(enhanced_image, output_path)
    print(f"ZeroDCE_Success: {output_path}")
if __name__ == '__main__':
    # CLI entry point: input image, output image and model weights are all
    # mandatory paths, forwarded unchanged to run_inference.
    cli = argparse.ArgumentParser()
    for short_flag, long_flag in (('-i', '--input'),
                                  ('-o', '--output'),
                                  ('-m', '--model')):
        cli.add_argument(short_flag, long_flag, required=True)
    opts = cli.parse_args()
    run_inference(
        input_path=opts.input,
        output_path=opts.output,
        model_path=opts.model,
    )