Spaces:
Running
Running
import math
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from PIL import Image
from skimage import img_as_ubyte
from tqdm import tqdm

from basicsr.models import create_model
from basicsr.utils.options import parse
class ASTv2ImageRestoration:
    """CPU-only inference wrapper around task-specific ASTv2 checkpoints.

    One instance serves a single task; the YAML config and weight file for it
    are looked up in ``config_map``. Images are restored either in a single
    padded pass or, for tasks configured with ``use_split``, tile by tile with
    overlapping crops whose outputs are averaged back together.
    """

    def __init__(self, task='dehazing', device='cpu', use_GT_mean=False):
        """Create the restorer and load the model for ``task``.

        Args:
            task: One of 'dehazing', 'deshadowing', 'desnowing', 'deraining',
                'enhancement_lol_v2_real', 'enhancement_lol_v2_synthetic',
                'enhancement_smid', 'deblurring'.
            device: Accepted only for API compatibility; inference always runs
                on 'cpu'.
            use_GT_mean: Rescale the output brightness to the target's mean
                (applied only for enhancement tasks when a target is given).

        Raises:
            KeyError: If ``task`` is not one of the supported task names.
        """
        self.task = task
        # Whatever device was requested, this version always runs on CPU.
        self.device = 'cpu'
        self.use_GT_mean = use_GT_mean
        if device != 'cpu':
            print(f"注意: 传入的device参数为'{device}',但此版本只支持CPU,已自动设置为'cpu'")
        # The padded (non-tiled) pass pads height/width up to a multiple of this.
        self.factor = 4 if 'enhancement' in task else 8
        # Per-task config file, weight file, and tiling parameters.
        self.config_map = {
            'dehazing': {
                'yaml': './options/RealDehazing_ASTv2_syn.yml',
                'weights': './model/synHaze_ASTv2.pth',
                'use_split': True,
                'crop_size': 256,
                'overlap_size': 200
            },
            'deshadowing': {
                'yaml': './options/Deshadowing_ASTv2.yml',
                'weights': './model/shaodw_ASTv2.pth',
                'use_split': False
            },
            'desnowing': {
                'yaml': './options/Desnowing_ASTv2.yml',
                'weights': './model/snow_astv2.pth',
                'use_split': False
            },
            'deraining': {
                'yaml': './options/Deraining_ASTv2_spad.yml',
                'weights': './model/rainStreak_ASTv2.pth',
                'use_split': False
            },
            'enhancement_lol_v2_real': {
                'yaml': './options/AST_LOL_v2_real.yml',
                'weights': './model/LOLv2_Real_ASTv2.pth',
                'use_split': True,
                'crop_size': 400,
                'overlap_size': 384
            },
            'enhancement_lol_v2_synthetic': {
                'yaml': './options/AST_LOL_v2_synthetic.yml',
                'weights': './model/LOLv2_Syn_ASTv2.pth',
                'use_split': False
            },
            'enhancement_smid': {
                'yaml': './options/AST_SMID.yml',
                'weights': './model/SMID_ASTv2.pth',
                'use_split': False
            },
            'deblurring': {
                'yaml': './options/Deblurring_ASTv2_L.yml',
                'weights': './model/motionblur_ASTv2.pth',
                'use_split': True,
                'crop_size': 512,
                'overlap_size': 300
            }
        }
        if task not in self.config_map:
            raise KeyError(
                f"unsupported task {task!r}; expected one of {sorted(self.config_map)}"
            )
        self.model = self._load_model()

    def _load_model(self):
        """Build the network for ``self.task`` and load its weights on CPU.

        Returns:
            The model, moved to ``self.device`` and switched to eval mode.
        """
        config_info = self.config_map[self.task]
        network_cfg = self._read_network_config(config_info['yaml'])
        if 'enhancement' in self.task:
            try:
                from basicsr.models.archs.ASTv2_arch import ASTv2
                model = ASTv2(**network_cfg)
            except Exception:
                # Direct construction failed (import or constructor error):
                # fall back to basicsr's factory with all GPU options disabled.
                opt = parse(config_info['yaml'], is_train=False)
                opt['dist'] = False
                opt['num_gpu'] = 0
                if 'gpu_ids' in opt:
                    opt['gpu_ids'] = None
                model = create_model(opt).net_g
        else:
            from basicsr.models.archs.ASTv2_arch import ASTv2
            model = ASTv2(**network_cfg)
        self._load_weights(model, config_info['weights'])
        print(f"===>Testing using weights: {config_info['weights']}")
        model.to(self.device)
        # No DataParallel wrapping in the CPU version.
        model.eval()
        return model

    @staticmethod
    def _read_network_config(yaml_path):
        """Return a copy of the ``network_g`` section of ``yaml_path`` minus ``type``.

        The result is passed directly as keyword arguments to ``ASTv2``.
        """
        try:
            from yaml import CLoader as Loader
        except ImportError:
            from yaml import Loader
        # NOTE: the full Loader can construct arbitrary Python objects; this is
        # acceptable only for the trusted paths hard-coded in config_map.
        with open(yaml_path, mode='r', encoding='utf-8') as f:
            config = yaml.load(f, Loader=Loader)
        network_cfg = dict(config['network_g'])
        network_cfg.pop('type', None)
        return network_cfg

    @classmethod
    def _load_weights(cls, model, weights_path):
        """Load the checkpoint at ``weights_path`` (mapped to CPU) into ``model``."""
        checkpoint = torch.load(weights_path, map_location='cpu')
        # The task checkpoints store weights under 'params'; a bare state dict
        # is accepted as well.
        if isinstance(checkpoint, dict) and 'params' in checkpoint:
            checkpoint = checkpoint['params']
        cls._load_state_dict_flexible(model, checkpoint)

    @staticmethod
    def _load_state_dict_flexible(model, state_dict):
        """Load ``state_dict`` into ``model``, tolerating a DataParallel prefix.

        Tries the keys as-is, then with any leading ``module.`` stripped, then
        with ``module.`` prepended.

        Raises:
            RuntimeError: The error from the as-is attempt if no variant matches.
        """
        prefix = 'module.'
        stripped = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        prefixed = {prefix + key: value for key, value in state_dict.items()}
        errors = []
        for candidate in (state_dict, stripped, prefixed):
            try:
                model.load_state_dict(candidate)
                return
            except RuntimeError as err:
                errors.append(err)
        raise errors[0]

    @staticmethod
    def _tile_starts(length, crop_size, overlap_size):
        """Return tile start offsets covering ``[0, length)`` along one axis.

        The last tile is aligned to the end; when ``length < crop_size`` the
        single start is clamped to 0 so slicing never uses a negative index.

        Raises:
            ValueError: If ``crop_size <= overlap_size`` (non-positive stride).
        """
        stride = crop_size - overlap_size
        if stride <= 0:
            raise ValueError(
                f"crop_size ({crop_size}) must be larger than overlap_size ({overlap_size})"
            )
        starts = list(range(0, length, stride))
        while starts and starts[-1] + crop_size >= length:
            starts.pop()
        starts.append(max(length - crop_size, 0))
        return starts

    def splitimage(self, imgtensor, crop_size=128, overlap_size=64):
        """Split a (B, C, H, W) tensor into overlapping square tiles to save memory.

        Returns:
            ``(tiles, starts)``: tile views in row-major order and their
            ``(row, col)`` top-left offsets. Tiles are ``crop_size`` square,
            or smaller along any axis shorter than ``crop_size``.
        """
        _, _, height, width = imgtensor.shape
        row_starts = self._tile_starts(height, crop_size, overlap_size)
        col_starts = self._tile_starts(width, crop_size, overlap_size)
        starts = [(hs, ws) for hs in row_starts for ws in col_starts]
        split_data = [
            imgtensor[:, :, hs:hs + crop_size, ws:ws + crop_size] for hs, ws in starts
        ]
        return split_data, starts

    def get_scoremap(self, H, W, C, B=1, is_mean=True):
        """Return a (B, C, H, W) blending weight map for merging tiles.

        With ``is_mean`` the weights are all ones (plain averaging); otherwise
        each pixel is weighted by the inverse distance to the tile center.
        """
        if is_mean:
            return torch.ones((B, C, H, W))
        # Computed in float64 and cast once, matching a per-pixel math.sqrt loop.
        rows = torch.arange(H, dtype=torch.float64).view(H, 1) - H / 2
        cols = torch.arange(W, dtype=torch.float64).view(1, W) - W / 2
        weights = 1.0 / torch.sqrt(rows ** 2 + cols ** 2 + 1e-6)
        weights = weights.to(torch.get_default_dtype())
        return weights.view(1, 1, H, W).repeat(B, C, 1, 1)

    def mergeimage(self, split_data, starts, crop_size=128, resolution=(1, 3, 128, 128)):
        """Blend restored tiles back into one (B, C, H, W) tensor.

        Overlapping pixels are averaged with the score map weights. Each tile's
        own spatial shape is used for placement, so ``crop_size`` is kept only
        for backward compatibility.
        """
        B, C, H, W = resolution[:4]
        tot_score = torch.zeros((B, C, H, W))
        merge_img = torch.zeros((B, C, H, W))
        scoremaps = {}
        for simg, (hs, ws) in zip(split_data, starts):
            ph, pw = simg.shape[-2], simg.shape[-1]
            if (ph, pw) not in scoremaps:
                scoremaps[(ph, pw)] = self.get_scoremap(ph, pw, C, B=B, is_mean=True)
            scoremap = scoremaps[(ph, pw)]
            merge_img[:, :, hs:hs + ph, ws:ws + pw] += scoremap * simg
            tot_score[:, :, hs:hs + ph, ws:ws + pw] += scoremap
        return merge_img / tot_score

    def _restore_padded(self, input_):
        """Run the model once on ``input_`` reflect-padded to a multiple of ``self.factor``."""
        h, w = input_.shape[2], input_.shape[3]
        padh = (-h) % self.factor
        padw = (-w) % self.factor
        if padh or padw:
            input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')
        restored = self.model(input_)
        return restored[:, :, :h, :w]

    def _restore_tiled(self, input_, crop_size, overlap_size):
        """Run the model tile by tile and blend the outputs (on CPU)."""
        tiles, starts = self.splitimage(input_, crop_size=crop_size, overlap_size=overlap_size)
        outputs = [self.model(tile).cpu() for tile in tiles]
        return self.mergeimage(outputs, starts, crop_size=crop_size,
                               resolution=tuple(input_.shape))

    @staticmethod
    def _apply_gt_mean(restored, target_image):
        """Scale float ``restored`` (H, W, C in [0, 1]) so its gray mean matches the target's."""
        if isinstance(target_image, str):
            target = np.array(Image.open(target_image).convert('RGB')).astype(np.float32) / 255.
        else:
            target = np.array(target_image.convert('RGB')).astype(np.float32) / 255.
        # NOTE(review): COLOR_BGR2GRAY is applied to RGB arrays here; presumably
        # kept for parity with upstream evaluation code — confirm before
        # switching to COLOR_RGB2GRAY.
        mean_restored = cv2.cvtColor(restored.astype(np.float32), cv2.COLOR_BGR2GRAY).mean()
        mean_target = cv2.cvtColor(target.astype(np.float32), cv2.COLOR_BGR2GRAY).mean()
        if mean_restored <= 0:
            # All-black output: the ratio is undefined, leave the image unchanged.
            return restored
        return np.clip(restored * (mean_target / mean_restored), 0, 1)

    def process_image(self, image_path, use_split=None, target_image=None):
        """Restore a single image.

        Args:
            image_path: Image file path or PIL Image object.
            use_split: Force tiled (True) or single-pass (False) processing;
                None uses the task's configured ``use_split``. Images smaller
                than one tile are always processed in a single padded pass.
            target_image: Reference image (path or PIL Image) for the GT-mean
                correction; used only for enhancement tasks with use_GT_mean.

        Returns:
            The restored image as an H x W x C uint8 numpy array.
        """
        if isinstance(image_path, str):
            img = Image.open(image_path).convert('RGB')
        else:
            img = image_path.convert('RGB')
        img_np = np.array(img).astype(np.float32) / 255.
        input_ = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(self.device)

        config_info = self.config_map[self.task]
        if use_split is None:
            use_split = config_info.get('use_split', False)
        crop_size = config_info.get('crop_size', 256)
        overlap_size = config_info.get('overlap_size', 200)
        h, w = input_.shape[2], input_.shape[3]
        if use_split and (h < crop_size or w < crop_size):
            # A full tile does not fit; odd-sized tiles could break the model's
            # downsampling, so process the whole (small) image padded instead.
            use_split = False

        with torch.no_grad():
            if use_split:
                restored = self._restore_tiled(input_, crop_size, overlap_size)
            else:
                restored = self._restore_padded(input_)

        restored = torch.clamp(restored, 0, 1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
        if self.use_GT_mean and 'enhancement' in self.task and target_image is not None:
            restored = self._apply_gt_mean(restored, target_image)
        return img_as_ubyte(restored)