# AStv2 / model.py
# (Hugging Face upload by yssszzzzzzzzy, commit 9cd663d verified)
import numpy as np
import os
import math
from tqdm import tqdm
import cv2
import torch.nn as nn
import torch
import torch.nn.functional as F
from basicsr.models import create_model
from basicsr.utils.options import parse
from skimage import img_as_ubyte
import yaml
from PIL import Image
class ASTv2ImageRestoration:
    """CPU-only inference wrapper for the ASTv2 image-restoration networks.

    Supported tasks: dehazing, deshadowing, desnowing, deraining,
    low-light enhancement (LOL-v2 real/synthetic, SMID) and motion
    deblurring.  Each task maps to an options yaml plus a pretrained
    weight file (see ``config_map``); large inputs can optionally be
    processed as overlapping tiles to bound peak memory.
    """

    def __init__(self, task='dehazing', device='cpu', use_GT_mean=False):
        """Build the restoration model for ``task`` on CPU.

        Args:
            task: one of 'dehazing', 'deshadowing', 'desnowing',
                'deraining', 'enhancement_lol_v2_real',
                'enhancement_lol_v2_synthetic', 'enhancement_smid',
                'deblurring'.
            device: kept for API compatibility; this build always runs
                on CPU regardless of the value passed.
            use_GT_mean: apply ground-truth mean brightness correction to
                the output (enhancement tasks only; needs a target image).
        """
        self.task = task
        # This build has no GPU support: force CPU whatever was requested.
        self.device = 'cpu'
        self.use_GT_mean = use_GT_mean
        if device != 'cpu':
            print(f"注意: 传入的device参数为'{device}',但此版本只支持CPU,已自动设置为'cpu'")
        # Enhancement networks need spatial dims divisible by 4, all others by 8.
        self.factor = 4 if 'enhancement' in task else 8
        # Per-task options yaml / weight file, plus optional tiling parameters.
        self.config_map = {
            'dehazing': {
                'yaml': './options/RealDehazing_ASTv2_syn.yml',
                'weights': './model/synHaze_ASTv2.pth',
                'use_split': True,
                'crop_size': 256,
                'overlap_size': 200
            },
            'deshadowing': {
                'yaml': './options/Deshadowing_ASTv2.yml',
                'weights': './model/shaodw_ASTv2.pth',
                'use_split': False
            },
            'desnowing': {
                'yaml': './options/Desnowing_ASTv2.yml',
                'weights': './model/snow_astv2.pth',
                'use_split': False
            },
            'deraining': {
                'yaml': './options/Deraining_ASTv2_spad.yml',
                'weights': './model/rainStreak_ASTv2.pth',
                'use_split': False
            },
            'enhancement_lol_v2_real': {
                'yaml': './options/AST_LOL_v2_real.yml',
                'weights': './model/LOLv2_Real_ASTv2.pth',
                'use_split': True,
                'crop_size': 400,
                'overlap_size': 384
            },
            'enhancement_lol_v2_synthetic': {
                'yaml': './options/AST_LOL_v2_synthetic.yml',
                'weights': './model/LOLv2_Syn_ASTv2.pth',
                'use_split': False
            },
            'enhancement_smid': {
                'yaml': './options/AST_SMID.yml',
                'weights': './model/SMID_ASTv2.pth',
                'use_split': False
            },
            'deblurring': {
                'yaml': './options/Deblurring_ASTv2_L.yml',
                'weights': './model/motionblur_ASTv2.pth',
                'use_split': True,
                'crop_size': 512,
                'overlap_size': 300
            }
        }
        self.model = self._load_model()

    @staticmethod
    def _read_network_kwargs(yaml_path):
        """Return the ``network_g`` kwargs from an options yaml (``type`` key removed)."""
        try:
            from yaml import CLoader as Loader
        except ImportError:
            from yaml import Loader
        with open(yaml_path, mode='r') as f:
            cfg = yaml.load(f, Loader=Loader)
        net_kwargs = cfg['network_g']
        net_kwargs.pop('type', None)
        return net_kwargs

    @staticmethod
    def _load_weights(model, weights_path):
        """Load a checkpoint's ``params`` state dict into ``model`` on CPU.

        Tolerates DataParallel key mismatches by retrying with the
        ``module.`` prefix stripped from, then added to, every key.

        Raises:
            RuntimeError: if no key variant matches the model's state dict.
        """
        # NOTE(security): torch.load unpickles the file — only load trusted checkpoints.
        checkpoint = torch.load(weights_path, map_location='cpu')
        params = checkpoint['params']
        candidates = (
            params,
            {(k[7:] if k.startswith('module.') else k): v for k, v in params.items()},
            {'module.' + k: v for k, v in params.items()},
        )
        last_err = None
        for state_dict in candidates:
            try:
                model.load_state_dict(state_dict)
                return
            except RuntimeError as err:
                last_err = err
        raise last_err

    def _load_model(self):
        """Instantiate the ASTv2 network for ``self.task``, load weights, set eval mode."""
        config_info = self.config_map[self.task]
        net_kwargs = self._read_network_kwargs(config_info['yaml'])
        try:
            from basicsr.models.archs.ASTv2_arch import ASTv2
            model = ASTv2(**net_kwargs)
        except Exception:
            # Only enhancement tasks have a fallback path; other tasks must
            # construct ASTv2 directly, so re-raise for them.
            if 'enhancement' not in self.task:
                raise
            # Fallback: let basicsr build the network, forced to CPU mode.
            opt = parse(config_info['yaml'], is_train=False)
            opt['dist'] = False
            opt['num_gpu'] = 0
            if 'gpu_ids' in opt:
                opt['gpu_ids'] = None
            model = create_model(opt).net_g
        self._load_weights(model, config_info['weights'])
        print(f"===>Testing using weights: {config_info['weights']}")
        model.to(self.device)
        # CPU build: no DataParallel wrapping.
        model.eval()
        return model

    def splitimage(self, imgtensor, crop_size=128, overlap_size=64):
        """Split a (B, C, H, W) tensor into overlapping ``crop_size`` tiles.

        Assumes H and W are both >= ``crop_size`` (smaller inputs are not
        supported by this tiling scheme).

        Returns:
            (tiles, starts): list of tile tensors and their (h, w) origins.
        """
        _, _, H, W = imgtensor.shape
        stride = crop_size - overlap_size

        def _grid(dim):
            # Regular grid of starts, trimmed so the final tile ends exactly at ``dim``.
            pts = list(range(0, dim, stride))
            while pts and pts[-1] + crop_size >= dim:
                pts.pop()
            pts.append(dim - crop_size)
            return pts

        starts = [(hs, ws) for hs in _grid(H) for ws in _grid(W)]
        tiles = [imgtensor[:, :, hs:hs + crop_size, ws:ws + crop_size]
                 for hs, ws in starts]
        return tiles, starts

    def get_scoremap(self, H, W, C, B=1, is_mean=True):
        """Per-pixel blending weights of shape (B, C, H, W) for tile merging.

        ``is_mean=True`` gives uniform weights (plain averaging in overlaps);
        otherwise weights fall off with inverse distance from the tile centre.
        """
        if is_mean:
            return torch.ones((B, C, H, W))
        # Vectorized inverse-distance map (replaces the original O(H*W) Python loop).
        hh = torch.arange(H, dtype=torch.float64).view(H, 1) - H / 2
        ww = torch.arange(W, dtype=torch.float64).view(1, W) - W / 2
        inv_dist = 1.0 / torch.sqrt(hh ** 2 + ww ** 2 + 1e-6)
        return inv_dist.to(torch.float32).expand(B, C, H, W).clone()

    def mergeimage(self, split_data, starts, crop_size=128, resolution=(1, 3, 128, 128)):
        """Blend processed tiles back into a full (B, C, H, W) image.

        Overlapping regions are averaged using ``get_scoremap`` weights.
        """
        B, C, H, W = resolution
        weight_sum = torch.zeros((B, C, H, W))
        canvas = torch.zeros((B, C, H, W))
        weights = self.get_scoremap(crop_size, crop_size, C, B=B, is_mean=True)
        for tile, (hs, ws) in zip(split_data, starts):
            canvas[:, :, hs:hs + crop_size, ws:ws + crop_size] += weights * tile
            weight_sum[:, :, hs:hs + crop_size, ws:ws + crop_size] += weights
        return canvas / weight_sum

    def process_image(self, image_path, use_split=None, target_image=None):
        """Restore a single image.

        Args:
            image_path: file path or PIL.Image of the degraded input.
            use_split: force tiled processing on/off; ``None`` uses the
                per-task default from ``config_map``.
            target_image: ground-truth image (path or PIL.Image) used for
                GT-mean correction on enhancement tasks.

        Returns:
            Restored image as a uint8 HxWx3 RGB numpy array.
        """
        if isinstance(image_path, str):
            img = Image.open(image_path).convert('RGB')
        else:
            img = image_path.convert('RGB')
        img_np = np.array(img).astype(np.float32) / 255.
        input_ = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(self.device)
        if use_split is None:
            use_split = self.config_map[self.task].get('use_split', False)
        with torch.no_grad():
            if use_split:
                # Tiled inference keeps peak memory bounded for large inputs.
                B, C, H, W = input_.shape
                config_info = self.config_map[self.task]
                crop_size = config_info.get('crop_size', 256)
                overlap_size = config_info.get('overlap_size', 200)
                tiles, starts = self.splitimage(input_, crop_size=crop_size,
                                                overlap_size=overlap_size)
                for i, tile in enumerate(tiles):
                    tiles[i] = self.model(tile).cpu()
                restored = self.mergeimage(tiles, starts, crop_size=crop_size,
                                           resolution=(B, C, H, W))
            else:
                # Pad so H and W are multiples of self.factor, then crop back.
                h, w = input_.shape[2], input_.shape[3]
                padh = (self.factor - h % self.factor) % self.factor
                padw = (self.factor - w % self.factor) % self.factor
                input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')
                restored = self.model(input_)
                restored = restored[:, :, :h, :w]
        restored = torch.clamp(restored, 0, 1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
        # GT-mean brightness correction (enhancement tasks only).
        if self.use_GT_mean and 'enhancement' in self.task and target_image is not None:
            if isinstance(target_image, str):
                target = np.array(Image.open(target_image).convert('RGB')).astype(np.float32) / 255.
            else:
                target = np.array(target_image.convert('RGB')).astype(np.float32) / 255.
            # Arrays here are RGB (loaded via PIL), so use RGB2GRAY — the original
            # BGR flag mis-weighted the red/blue channels.
            mean_restored = cv2.cvtColor(restored.astype(np.float32), cv2.COLOR_RGB2GRAY).mean()
            mean_target = cv2.cvtColor(target.astype(np.float32), cv2.COLOR_RGB2GRAY).mean()
            # Guard against an all-black restoration (division by zero).
            restored = np.clip(restored * (mean_target / max(mean_restored, 1e-6)), 0, 1)
        restored = img_as_ubyte(restored)
        return restored