# IR_expeiment/PART2/PromptIR/worker_promptir.py
# hugaagg's picture
# Upload folder using huggingface_hub
# 2ecc7ab verified
import argparse
import torch
import cv2
import numpy as np
import os
import sys
import glob
# 1. Path fix-up: put the directory containing this script on sys.path so the
#    `net` package imported below can be resolved from it.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)

# Try to import PromptIR; without it the worker cannot do anything, so report
# the failure and terminate with exit status 1.
try:
    from net.model import PromptIR
except ImportError as import_error:
    print(f"❌ 导入失败: {import_error}")
    raise SystemExit(1)
def _resolve_weights_path(model_path_arg):
    """Return the checkpoint path to load, or None when none can be found.

    The explicitly requested path wins when it exists. Otherwise the ``ckpt``
    directory next to this script is searched for ``*.pth`` and then ``*.ckpt``
    files; matches are sorted so the pick is deterministic across runs.
    """
    if os.path.exists(model_path_arg):
        return model_path_arg

    ckpt_dir = os.path.join(current_dir, 'ckpt')
    candidates = (
        sorted(glob.glob(os.path.join(ckpt_dir, "*.pth")))
        + sorted(glob.glob(os.path.join(ckpt_dir, "*.ckpt")))
    )
    if not candidates:
        print("❌ 找不到权重文件!")
        return None

    print(f"⚠️ 指定权重不存在,自动使用: {candidates[0]}")
    return candidates[0]


def _extract_state_dict(checkpoint):
    """Return the parameter mapping of a checkpoint with 'net.' prefixes removed.

    The mapping is taken from the ``'state_dict'`` or ``'params'`` entry when
    present; otherwise the checkpoint object itself is used as the mapping.
    """
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'params' in checkpoint:
        state_dict = checkpoint['params']
    else:
        state_dict = checkpoint
    return {
        (key[4:] if key.startswith('net.') else key): value
        for key, value in state_dict.items()
    }


def run_inference(input_path, output_path, model_path_arg):
    """Restore one image with PromptIR and write the result to disk.

    Args:
        input_path: Path of the degraded input image (read with OpenCV).
        output_path: Path the restored image is written to. Its parent
            directory is created when the path has one.
        model_path_arg: Preferred checkpoint path. When it does not exist, the
            first checkpoint found in the ``ckpt`` directory next to this
            script is used instead.

    Returns:
        None. Every failure is reported on stdout and ends the call early.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("🚀 [PromptIR] 启动中...")

    # ================= 1. Locate the weight file =================
    final_model_path = _resolve_weights_path(model_path_arg)
    if final_model_path is None:
        return

    # ================= 2. Build the model =================
    print("⚙️ 初始化模型 (配置: All-in-One)...")
    try:
        # Per the original author's note, these values come from
        # configs/all-in-one.yml and must match for the weights to load.
        model = PromptIR(
            inp_channels=3,
            out_channels=3,
            dim=48,  # author's note: the default may differ, must be 48
            num_blocks=[4, 6, 6, 8],
            num_refinement_blocks=4,
            heads=[1, 2, 4, 8],
            ffn_expansion_factor=2.66,
            bias=False,
            LayerNorm_type='WithBias',
            decoder=True,  # author's note: the decoder must be enabled
        )
    except TypeError as e:
        # The constructor rejected the keyword set; fall back to its defaults.
        print(f"⚠️ 参数初始化失败 ({e}),尝试回退到无参初始化...")
        model = PromptIR()
    except Exception as e:
        print(f"❌ 初始化严重错误: {e}")
        return

    # ================= 3. Load the weights =================
    print(f"📦 加载权重: {os.path.basename(final_model_path)}")
    # Reading the file is kept separate from applying it to the model: when
    # reading fails there is no state dict, so a non-strict retry is pointless.
    try:
        checkpoint = torch.load(final_model_path, map_location=device)
        new_state_dict = _extract_state_dict(checkpoint)
    except Exception as e:
        print(f"❌ 权重加载失败: {e}")
        return

    try:
        # Strict loading verifies that the architecture matches the checkpoint.
        model.load_state_dict(new_state_dict, strict=True)
        print("✅ 权重加载成功 (Strict Mode)")
    except Exception as e:
        print(f"❌ 权重加载失败: {e}")
        # Fallback: retry while tolerating missing/unexpected keys.
        try:
            print("🔄 尝试非严格加载...")
            model.load_state_dict(new_state_dict, strict=False)
            print("✅ 非严格加载成功")
        except Exception as fallback_error:
            print(f"❌ 非严格加载失败: {fallback_error}")
            return

    model.eval().to(device)

    # ================= 4. Inference =================
    if not os.path.exists(input_path):
        print(f"❌ 输入图片不存在: {input_path}")
        return

    img_lq = cv2.imread(input_path, cv2.IMREAD_COLOR)
    if img_lq is None:
        print(f"❌ 图片读取失败: {input_path}")
        return

    # NOTE(review): cv2.imread yields BGR channel order and the tensor is fed
    # to the model as-is. Confirm the checkpoint expects BGR; if it expects
    # RGB, a conversion is needed here and again before saving.
    img_lq = img_lq.astype(np.float32) / 255.0
    img_lq = np.transpose(img_lq, (2, 0, 1))  # HWC -> CHW
    img_lq = torch.from_numpy(img_lq).float().unsqueeze(0).to(device)

    with torch.no_grad():
        # Pad bottom/right so both spatial sizes are multiples of `factor`.
        _, _, h, w = img_lq.size()
        factor = 8
        pad_h = (factor - h % factor) % factor
        pad_w = (factor - w % factor) % factor
        padded = pad_h > 0 or pad_w > 0
        if padded:
            # 'reflect' needs the pad to be smaller than the dimension it
            # extends; very small images fall back to edge replication.
            pad_mode = 'reflect' if pad_h < h and pad_w < w else 'replicate'
            img_lq = torch.nn.functional.pad(
                img_lq, (0, pad_w, 0, pad_h), mode=pad_mode
            )

        output = model(img_lq)

        if padded:
            # Crop the padding away again.
            output = output[:, :, :h, :w]

    # ================= 5. Save =================
    # Drop only the batch axis; a bare squeeze() would also collapse a
    # height or width of 1 and break the transpose below.
    output = output.detach().squeeze(0).float().cpu().clamp_(0, 1).numpy()
    output = np.transpose(output, (1, 2, 0))  # CHW -> HWC
    output = (output * 255.0).round().astype(np.uint8)

    # A bare filename has an empty dirname, which os.makedirs rejects.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # cv2.imwrite reports failure through its return value.
    if not cv2.imwrite(output_path, output):
        print(f"❌ 图片保存失败: {output_path}")
        return
    print(f"✅ PromptIR 处理完成: {output_path}")
if __name__ == "__main__":
    # Command-line entry point: input and output paths are mandatory, the
    # checkpoint path is optional and defaults to ckpt/model.ckpt.
    cli = argparse.ArgumentParser()
    for flags, options in (
        (('-i', '--input'), {'required': True}),
        (('-o', '--output'), {'required': True}),
        (('-m', '--model'), {'default': "ckpt/model.ckpt"}),
    ):
        cli.add_argument(*flags, **options)
    opts = cli.parse_args()
    run_inference(opts.input, opts.output, opts.model)