| import argparse
|
| import torch
|
| import cv2
|
| import numpy as np
|
| import os
|
| import sys
|
| import glob
|
|
|
|
|
# Make the script's own directory importable so `net.model` resolves
# regardless of the caller's current working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))

sys.path.append(current_dir)


# Fail fast with a readable message if the PromptIR package layout is missing.
try:
    from net.model import PromptIR
except ImportError as e:
    print(f"❌ 导入失败: {e}")
    sys.exit(1)
|
|
|
def run_inference(input_path, output_path, model_path_arg):
    """Restore a degraded image with PromptIR and write the result to disk.

    Args:
        input_path: Path to the degraded input image (any format cv2 reads).
        output_path: Destination file for the restored image; its parent
            directory is created if needed.
        model_path_arg: Preferred checkpoint path. If it does not exist,
            the first ``*.pth``/``*.ckpt`` under ``<script_dir>/ckpt`` is
            used instead.

    Returns:
        None. Progress and errors are printed; the function returns early
        on any unrecoverable failure instead of raising.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"🚀 [PromptIR] 启动中...")

    # --- Resolve the checkpoint path, falling back to the local ckpt/ dir ---
    final_model_path = model_path_arg
    if not os.path.exists(model_path_arg):
        ckpt_dir = os.path.join(current_dir, 'ckpt')
        candidates = glob.glob(os.path.join(ckpt_dir, "*.pth")) + glob.glob(os.path.join(ckpt_dir, "*.ckpt"))
        if candidates:
            final_model_path = candidates[0]
            print(f"⚠️ 指定权重不存在,自动使用: {final_model_path}")
        else:
            print(f"❌ 找不到权重文件!")
            return

    # --- Build the model (All-in-One configuration) ---
    print("⚙️ 初始化模型 (配置: All-in-One)...")
    try:
        model = PromptIR(
            inp_channels=3,
            out_channels=3,
            dim=48,
            num_blocks=[4, 6, 6, 8],
            num_refinement_blocks=4,
            heads=[1, 2, 4, 8],
            ffn_expansion_factor=2.66,
            bias=False,
            LayerNorm_type='WithBias',
            decoder=True
        )
    except TypeError as e:
        # A different PromptIR version may expose another signature;
        # fall back to the library defaults rather than aborting.
        print(f"⚠️ 参数初始化失败 ({e}),尝试回退到无参初始化...")
        model = PromptIR()
    except Exception as e:
        print(f"❌ 初始化严重错误: {e}")
        return

    # --- Load checkpoint weights ---
    print(f"📦 加载权重: {os.path.basename(final_model_path)}")
    # Bound before the try so the non-strict fallback below can never hit a
    # NameError when torch.load itself fails (bug in the original code).
    new_state_dict = {}
    try:
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(final_model_path, map_location=device)

        # Checkpoints may nest weights under 'state_dict' (Lightning) or 'params'.
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        elif 'params' in checkpoint:
            state_dict = checkpoint['params']
        else:
            state_dict = checkpoint

        # Strip the 'net.' prefix added by Lightning-style wrappers.
        for k, v in state_dict.items():
            new_state_dict[k[4:] if k.startswith('net.') else k] = v

        model.load_state_dict(new_state_dict, strict=True)
        print("✅ 权重加载成功 (Strict Mode)")
    except Exception as e:
        print(f"❌ 权重加载失败: {e}")
        if not new_state_dict:
            # torch.load (or the key remapping) failed before any weights
            # were collected — there is nothing to retry with.
            return
        try:
            print("🔄 尝试非严格加载...")
            model.load_state_dict(new_state_dict, strict=False)
            print("✅ 非严格加载成功")
        except Exception:
            # Narrowed from a bare `except:` so Ctrl-C is not swallowed.
            return

    model.eval().to(device)

    # --- Read and preprocess the input image ---
    if not os.path.exists(input_path):
        print(f"❌ 输入图片不存在: {input_path}")
        return

    img_lq = cv2.imread(input_path, cv2.IMREAD_COLOR)
    if img_lq is None:
        print(f"❌ 图片读取失败: {input_path}")
        return

    # uint8 HWC (BGR) -> float32 CHW in [0, 1], plus a leading batch dim.
    img_lq = img_lq.astype(np.float32) / 255.0
    img_lq = np.transpose(img_lq, (2, 0, 1))
    img_lq = torch.from_numpy(img_lq).float().unsqueeze(0).to(device)

    # --- Inference; pad H/W up to a multiple of 8 as the network expects ---
    with torch.no_grad():
        _, _, h, w = img_lq.size()
        factor = 8
        pad_h = (factor - h % factor) % factor
        pad_w = (factor - w % factor) % factor

        if pad_h > 0 or pad_w > 0:
            img_lq = torch.nn.functional.pad(img_lq, (0, pad_w, 0, pad_h), mode='reflect')

        output = model(img_lq)

        if pad_h > 0 or pad_w > 0:
            # Crop the padding back off to restore the original resolution.
            output = output[:, :, :h, :w]

    # --- Postprocess and save ---
    output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
    output = np.transpose(output, (1, 2, 0))
    output = (output * 255.0).round().astype(np.uint8)

    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when output_path actually has a parent component.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    cv2.imwrite(output_path, output)
    print(f"✅ PromptIR 处理完成: {output_path}")
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse paths, then delegate to run_inference.
    cli = argparse.ArgumentParser()
    cli.add_argument('-i', '--input', required=True)
    cli.add_argument('-o', '--output', required=True)
    cli.add_argument('-m', '--model', default="ckpt/model.ckpt")
    opts = cli.parse_args()
    run_inference(opts.input, opts.output, opts.model)