Spaces:
Sleeping
Sleeping
| # model.py - 完整修复版本,确保与原始测试脚本处理逻辑完全一致 | |
| import yaml | |
| import torch | |
| import math | |
| import numpy as np | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from io import BytesIO | |
| import cv2 | |
| from basicsr.models import create_model | |
| from basicsr.models.archs.HINT_arch import HINT | |
| from basicsr.utils.options import parse | |
| from skimage import img_as_ubyte | |
| import gc | |
| import os | |
# Force CPU execution to match the typical Hugging Face Space environment.
# NOTE(review): hiding GPUs via CUDA_VISIBLE_DEVICES presumably only works if
# CUDA has not been initialized earlier in this process — confirm import order.
os.environ['CUDA_VISIBLE_DEVICES'] = ''
torch.backends.cudnn.enabled = False
# Thread counts chosen for CPU inference; tune as needed for the host.
# NOTE(review): torch restricts when set_num_interop_threads may be called
# (before inter-op parallel work starts), so keep this at import time.
torch.set_num_threads(2)
torch.set_num_interop_threads(1)
device = torch.device('cpu')
# Global model slots, filled by init(); a slot left as None is treated as
# "not loaded" by inference() and get_model_status().
dehazing_model = None
deraining_model = None
desnowing_model = None
denoising_blind_model = None
enhancement_synthetic_model = None
enhancement_real_model = None
def _tile_starts(length, crop_size, step):
    """Return tile start offsets along one axis, ending with a border-flush tile."""
    starts = list(range(0, length, step))
    # Drop every start whose tile would touch or pass the border ...
    while starts and starts[-1] + crop_size >= length:
        starts.pop()
    # ... and replace them with one tile flush with the border. Clamp at 0 so an
    # axis shorter than crop_size yields a single tile at the origin instead of
    # a negative coordinate.
    starts.append(max(length - crop_size, 0))
    return starts


def splitimage(imgtensor, crop_size=128, overlap_size=64):
    """Cut an NCHW tensor into overlapping tiles (same scheme as the test script).

    Tile starts advance by ``crop_size - overlap_size`` along each axis and the
    last tile on each axis is aligned with the bottom/right border.

    Args:
        imgtensor: tensor of shape ``(B, C, H, W)``.
        crop_size: side length of each square tile.
        overlap_size: overlap between neighbouring tiles.

    Returns:
        ``(split_data, starts)``: the tile tensors (views into ``imgtensor``) in
        row-major order, and the matching ``(row, col)`` top-left coordinates.
        When H or W is smaller than ``crop_size`` the tile is clipped to the
        image along that axis and its start is 0.
    """
    _, _, H, W = imgtensor.shape
    step = crop_size - overlap_size
    hstarts = _tile_starts(H, crop_size, step)
    wstarts = _tile_starts(W, crop_size, step)
    split_data = []
    starts = []
    for hs in hstarts:
        for ws in wstarts:
            split_data.append(imgtensor[:, :, hs:hs + crop_size, ws:ws + crop_size])
            starts.append((hs, ws))
    return split_data, starts
def get_scoremap(H, W, C, B=1, is_mean=True):
    """Build a ``(B, C, H, W)`` float tensor of per-pixel blending weights.

    With ``is_mean`` every weight is 1 (plain averaging of overlapping tiles).
    Otherwise each pixel is weighted by the inverse of its distance to the
    tile centre (``H / 2``, ``W / 2``), with a 1e-6 term keeping the centre finite.
    """
    weights = torch.ones((B, C, H, W))
    if is_mean:
        return weights
    mid_h, mid_w = H / 2, W / 2
    for row in range(H):
        for col in range(W):
            dist = math.sqrt((row - mid_h) ** 2 + (col - mid_w) ** 2 + 1e-6)
            weights[:, :, row, col] = 1.0 / dist
    return weights
def mergeimage(split_data, starts, crop_size=128, resolution=(1, 3, 128, 128)):
    """Blend overlapping restored tiles back into one full-size tensor.

    Every tile is accumulated into its window of a zero canvas, multiplied by a
    weight map from ``get_scoremap``. Overlaps are averaged by dividing by the
    accumulated weights.

    Args:
        split_data: restored tiles, each an NCHW tensor.
        starts: ``(row, col)`` top-left coordinates, parallel to ``split_data``.
        crop_size: nominal tile size. Kept for compatibility; the real tile
            shape is read from each tile, so tiles clipped at the border of
            images smaller than ``crop_size`` are merged correctly.
        resolution: ``(B, C, H, W)`` of the merged output.

    Returns:
        The merged ``(B, C, H, W)`` tensor.
    """
    B, C, H, W = resolution[0], resolution[1], resolution[2], resolution[3]
    tot_score = torch.zeros((B, C, H, W))
    merge_img = torch.zeros((B, C, H, W))
    scoremaps = {}  # weight maps cached per tile shape
    for simg, (hs, ws) in zip(split_data, starts):
        # A start may be negative when the image is smaller than the crop size;
        # tensor slicing clipped that tile to begin at 0, so align with it.
        hs, ws = max(hs, 0), max(ws, 0)
        th, tw = simg.shape[-2], simg.shape[-1]
        if (th, tw) not in scoremaps:
            scoremaps[(th, tw)] = get_scoremap(th, tw, C, B=B, is_mean=True)
        scoremap = scoremaps[(th, tw)]
        merge_img[:, :, hs:hs + th, ws:ws + tw] += scoremap * simg
        tot_score[:, :, hs:hs + th, ws:ws + tw] += scoremap
    return merge_img / tot_score
def load_hint_model(config_path, weight_path):
    """Build a HINT network from a YAML options file and load its weights.

    Mirrors the original test script: the ``network_g`` section of the YAML,
    without its ``type`` key, is passed as keyword arguments to ``HINT``. The
    checkpoint's ``params`` entry is loaded if present; otherwise the whole
    checkpoint is used as the state dict.

    Args:
        config_path: path to the YAML options file.
        weight_path: path to the ``.pth`` checkpoint.

    Returns:
        The model in eval mode on ``device``, or ``None`` if either file is
        missing or any step fails (the error is printed, not raised).
    """
    try:
        if not os.path.exists(config_path) or not os.path.exists(weight_path):
            print(f"HINT模型文件不存在: config='{config_path}', weights='{weight_path}'")
            return None
        try:
            from yaml import CLoader as Loader
        except ImportError:
            from yaml import Loader
        # NOTE(review): yaml's (C)Loader can build arbitrary Python objects;
        # fine for bundled option files, never use it on user-supplied YAML.
        # The context manager closes the file handle the original leaked.
        with open(config_path, mode='r') as f:
            x = yaml.load(f, Loader=Loader)
        network_config = x['network_g']
        # 'type' names the architecture; it is removed before the remaining
        # keys are forwarded to the HINT constructor.
        network_config.pop('type')
        model = HINT(**network_config)
        checkpoint = torch.load(weight_path, map_location=device)
        if 'params' in checkpoint:
            model.load_state_dict(checkpoint['params'])
        else:
            model.load_state_dict(checkpoint)
        model.eval()
        model = model.to(device)
        print(f"HINT模型加载成功: {weight_path}")
        return model
    except Exception as e:
        print(f"HINT模型加载失败: {e}")
        return None
def load_basicsr_model(config_path, weight_path):
    """Build a BasicSR model from an options file and load its weights on CPU.

    The options are parsed with BasicSR's ``parse`` and forced to
    non-distributed CPU mode before ``create_model``; only its ``net_g`` is kept.

    Args:
        config_path: path to the BasicSR YAML options file.
        weight_path: path to the ``.pth`` checkpoint. Both ``{'params': ...}``
            checkpoints and bare state dicts are accepted, as in
            ``load_hint_model``.

    Returns:
        The network in eval mode on ``device``, or ``None`` if either file is
        missing or any step fails (the error is printed, not raised).
    """
    try:
        if not os.path.exists(config_path) or not os.path.exists(weight_path):
            print(f"BasicSR模型文件不存在: config='{config_path}', weights='{weight_path}'")
            return None
        opt = parse(config_path, is_train=False)
        opt['dist'] = False
        opt['num_gpu'] = 0
        opt['device'] = 'cpu'
        model = create_model(opt).net_g
        checkpoint = torch.load(weight_path, map_location=device)
        state_dict = checkpoint['params'] if 'params' in checkpoint else checkpoint
        try:
            model.load_state_dict(state_dict)
        except RuntimeError:
            # Key mismatch: presumably the network is wrapped (e.g. in
            # DataParallel), so retry with every key prefixed by 'module.'.
            model.load_state_dict({'module.' + k: v for k, v in state_dict.items()})
        model.eval()
        model = model.to(device)
        print(f"BasicSR模型加载成功: {weight_path}")
        return model
    except Exception as e:
        print(f"BasicSR模型加载失败: {e}")
        return None
def init():
    """Load every restoration model into its module-level slot.

    The four HINT models are loaded first, then the two BasicSR low-light
    models; a slot stays ``None`` if its loader fails.
    """
    global dehazing_model, deraining_model, desnowing_model, denoising_blind_model
    global enhancement_synthetic_model, enhancement_real_model
    print("开始加载所有模型...")
    hint_jobs = [
        ("./options/RealDehazing_HINT.yml", "./models/Dehazing.pth"),
        ("./options/Deraining_HINT_syn_rain100L.yml", "./models/Rain100L_HINT.pth"),
        ("./options/Desnow_snow100k_HINT.yml", "./models/snow100k.pth"),
        ("./options/GaussianColorDenoising_HINT.yml", "./models/net_g_latest_blind.pth"),
    ]
    (dehazing_model, deraining_model,
     desnowing_model, denoising_blind_model) = [
        load_hint_model(cfg, ckpt) for cfg, ckpt in hint_jobs
    ]
    basicsr_jobs = [
        ("./options/HINT_LOL_v2_synthetic.yml", "./models/lolv2Syn.pth"),
        ("./options/HINT_LOL_v2_real.yml", "./models/lolv2Real.pth"),
    ]
    enhancement_synthetic_model, enhancement_real_model = [
        load_basicsr_model(cfg, ckpt) for cfg, ckpt in basicsr_jobs
    ]
    gc.collect()
    print("所有模型加载完成。")
def preprocess_image(img, task=None, sigma=None):
    """Convert an image to a normalized ``(1, C, H, W)`` float32 tensor.

    Follows the original test scripts: pixel values are divided by 255. For
    ``task == 'denoising_blind'`` with a positive ``sigma``, Gaussian noise
    with standard deviation ``sigma / 255`` is added, drawn from a generator
    seeded with 0 so the result is reproducible.

    Args:
        img: a PIL image or an ``(H, W, C)`` array-like of 0-255 values.
        task: task name; only ``'denoising_blind'`` changes the processing.
        sigma: noise level on the 0-255 scale for the denoising task.

    Returns:
        A float32 tensor of shape ``(1, C, H, W)``.
    """
    # np.asarray handles PIL images (via the array interface) and ndarrays
    # alike; the float32 conversion + division always produces a new array,
    # so the caller's data is never modified.
    arr = np.float32(np.asarray(img)) / 255.0
    if task == 'denoising_blind' and sigma is not None and sigma > 0:
        # A private RandomState(0) yields exactly the values of the test
        # script's np.random.seed(0) + np.random.normal(...), but does not
        # reset the process-wide NumPy RNG shared by concurrent requests.
        rng = np.random.RandomState(0)
        arr += rng.normal(0, sigma / 255., arr.shape)
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
# Tasks processed in overlapping tiles: task -> (crop_size, overlap_size),
# values taken from the original test scripts. Other tasks run on the whole image.
_TILED_TASKS = {
    "dehazing": (256, 64),
    "desnowing": (256, 128),
}


def _pad_to_multiple(tensor, factor):
    """Pad an NCHW tensor on the bottom/right so H and W are multiples of ``factor``.

    Reflect padding matches the test scripts. PyTorch only allows reflect
    padding smaller than the padded dimension, so tiny images fall back to
    replicate padding instead of raising.
    """
    h, w = tensor.shape[-2], tensor.shape[-1]
    padh = (-h) % factor
    padw = (-w) % factor
    mode = 'reflect' if padh < h and padw < w else 'replicate'
    return F.pad(tensor, (0, padw, 0, padh), mode)


def _tiled_forward(model, padded, crop_size, overlap_size):
    """Run ``model`` on overlapping crops of ``padded`` and blend the outputs."""
    tiles, starts = splitimage(padded, crop_size=crop_size, overlap_size=overlap_size)
    outputs = [model(tile).cpu() for tile in tiles]
    return mergeimage(outputs, starts, crop_size=crop_size,
                      resolution=tuple(padded.shape))


def _apply_gt_mean(restored_np, target_np):
    """Rescale ``restored_np`` so its mean gray level matches ``target_np``'s.

    Both arrays are HWC RGB floats in [0, 1] (decoded via PIL), hence
    COLOR_RGB2GRAY rather than the test script's BGR2GRAY. An all-black result
    is returned unchanged instead of dividing by zero.
    """
    mean_restored = cv2.cvtColor(restored_np.astype(np.float32), cv2.COLOR_RGB2GRAY).mean()
    mean_target = cv2.cvtColor(target_np.astype(np.float32), cv2.COLOR_RGB2GRAY).mean()
    if mean_restored <= 0:
        return restored_np
    return np.clip(restored_np * (mean_target / mean_restored), 0, 1)


def inference(body: bytes, task_type: str, sigma: int = None, GT_mean: bool = False,
              target_image: bytes = None) -> bytes:
    """Run one restoration task on an encoded image and return PNG bytes.

    Args:
        body: encoded input image (any format PIL can open); converted to RGB.
        task_type: one of the keys of ``model_map`` below.
        sigma: noise level (0-255 scale) added before denoising; only used by
            ``'denoising_blind'``.
        GT_mean: if true and ``target_image`` is given, rescale the output so
            its mean gray level matches the target's.
        target_image: encoded reference image used by ``GT_mean``.

    Returns:
        PNG-encoded bytes of the restored RGB image, same size as the input.

    Raises:
        Exception: if the requested model is unknown or not loaded. Errors
            raised during inference are printed with a traceback and re-raised.
    """
    model_map = {
        "dehazing": dehazing_model,
        "deraining": deraining_model,
        "desnowing": desnowing_model,
        "denoising_blind": denoising_blind_model,
        "enhancement_synthetic": enhancement_synthetic_model,
        "enhancement_real": enhancement_real_model
    }
    model = model_map.get(task_type)
    if model is None:
        raise Exception(f"模型 '{task_type}' 未加载或不存在。")
    img = Image.open(BytesIO(body)).convert("RGB")
    input_tensor = preprocess_image(img, task_type, sigma).to(device)
    orig_h, orig_w = input_tensor.shape[2], input_tensor.shape[3]
    target_np = None
    if GT_mean and target_image is not None:
        target_img = Image.open(BytesIO(target_image)).convert("RGB")
        target_np = np.array(target_img).astype(np.float32) / 255.0
    try:
        with torch.no_grad():
            # Low-light models use a padding factor of 4; the others use 8
            # (both as in the original test scripts).
            if task_type in ("enhancement_synthetic", "enhancement_real"):
                factor = 4
            else:
                factor = 8
            padded_input = _pad_to_multiple(input_tensor, factor)
            if task_type in _TILED_TASKS:
                crop_size, overlap_size = _TILED_TASKS[task_type]
                restored_tensor = _tiled_forward(model, padded_input, crop_size, overlap_size)
            else:
                restored_tensor = model(padded_input)
            # Drop the padding to restore the original size.
            restored_tensor = restored_tensor[:, :, :orig_h, :orig_w]
            restored_np = torch.clamp(restored_tensor, 0, 1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
            if GT_mean and target_np is not None:
                restored_np = _apply_gt_mean(restored_np, target_np)
            result_np = img_as_ubyte(restored_np)
            out_img = Image.fromarray(result_np)
            buf = BytesIO()
            out_img.save(buf, format="PNG")
            gc.collect()
            return buf.getvalue()
    except Exception as e:
        print(f"在任务 '{task_type}' 中推理失败: {e}")
        import traceback
        traceback.print_exc()
        raise
| # --- 导出的推理函数 --- | |
def dehazing_inference(body: bytes) -> bytes:
    """Remove haze from the encoded image ``body``; returns the result of ``inference``."""
    return inference(body=body, task_type="dehazing")
def deraining_inference(body: bytes) -> bytes:
    """Remove rain streaks from the encoded image ``body`` via ``inference``."""
    return inference(body=body, task_type="deraining")
def desnowing_inference(body: bytes) -> bytes:
    """Remove snow from the encoded image ``body`` via ``inference``."""
    return inference(body=body, task_type="desnowing")
def denoising_blind_inference(body: bytes, sigma: int = 25) -> bytes:
    """Denoise ``body`` after adding Gaussian noise of level ``sigma`` (0-255 scale)."""
    return inference(body, task_type="denoising_blind", sigma=sigma)
def enhancement_synthetic_inference(body: bytes, GT_mean: bool = False, target_image: bytes = None) -> bytes:
    """Low-light enhancement with the model trained on synthetic data.

    ``GT_mean`` together with ``target_image`` enables mean-brightness
    correction against the reference image.
    """
    return inference(body, task_type="enhancement_synthetic",
                     GT_mean=GT_mean, target_image=target_image)
def enhancement_real_inference(body: bytes, GT_mean: bool = False, target_image: bytes = None) -> bytes:
    """Low-light enhancement with the model trained on real data.

    ``GT_mean`` together with ``target_image`` enables mean-brightness
    correction against the reference image.
    """
    return inference(body, task_type="enhancement_real",
                     GT_mean=GT_mean, target_image=target_image)
| # 为了保持向后兼容性,提供简化的调用方式 | |
def enhancement_synthetic_inference_simple(body: bytes) -> bytes:
    """Synthetic-data low-light enhancement without GT-mean correction."""
    return inference(body=body, task_type="enhancement_synthetic")
def enhancement_real_inference_simple(body: bytes) -> bytes:
    """Real-data low-light enhancement without GT-mean correction."""
    return inference(body=body, task_type="enhancement_real")
| # --- 辅助函数 --- | |
def get_model_status():
    """Return a dict mapping ``<task>_model_loaded`` keys to whether that slot is set."""
    slots = (
        ("dehazing_model_loaded", dehazing_model),
        ("deraining_model_loaded", deraining_model),
        ("desnowing_model_loaded", desnowing_model),
        ("denoising_blind_model_loaded", denoising_blind_model),
        ("enhancement_synthetic_model_loaded", enhancement_synthetic_model),
        ("enhancement_real_model_loaded", enhancement_real_model),
    )
    return {key: model is not None for key, model in slots}
def get_available_tasks():
    """Return a fresh list with the name of every supported task."""
    task_names = (
        'dehazing',
        'deraining',
        'desnowing',
        'denoising_blind',
        'enhancement_synthetic',
        'enhancement_real',
    )
    return list(task_names)
def get_task_description(task):
    """Return the (Chinese) UI description of ``task``, or '未知任务' if unknown."""
    descriptions = dict(
        dehazing='去除图像中的雾霾效果',
        deraining='去除图像中的雨线',
        desnowing='去除图像中的雪花',
        denoising_blind='通用降噪',
        enhancement_synthetic='低光增强 (合成数据训练)',
        enhancement_real='低光增强 (真实数据训练)',
    )
    try:
        return descriptions[task]
    except KeyError:
        return '未知任务'
| # --- 测试函数 --- | |
def test_model_loading():
    """Print a ✓/✗ line for each model slot reported by get_model_status()."""
    print("测试模型加载情况:")
    for name, is_loaded in get_model_status().items():
        mark = '✓' if is_loaded else '✗'
        print(f" {name}: {mark}")
    print()
def test_inference(task_type, test_image_path=None):
    """Smoke-test one task and report the outcome.

    Uses the file at ``test_image_path`` if given, otherwise a synthetic
    256x256 solid-red PNG. Prints the result and returns True on success,
    False for an unknown task or when inference raises.
    """
    if test_image_path is not None:
        with open(test_image_path, 'rb') as f:
            test_bytes = f.read()
    else:
        sample = Image.new('RGB', (256, 256), color='red')
        buf = BytesIO()
        sample.save(buf, format='PNG')
        test_bytes = buf.getvalue()
    try:
        runners = (
            ('dehazing', dehazing_inference),
            ('deraining', deraining_inference),
            ('desnowing', desnowing_inference),
            ('denoising_blind', lambda data: denoising_blind_inference(data, sigma=25)),
            ('enhancement_synthetic', enhancement_synthetic_inference),
            ('enhancement_real', enhancement_real_inference),
        )
        runner = next((fn for name, fn in runners if task_type == name), None)
        if runner is None:
            print(f"未知任务类型: {task_type}")
            return False
        result = runner(test_bytes)
        print(f"✓ {task_type} 推理测试成功,输出大小: {len(result)} bytes")
        return True
    except Exception as e:
        print(f"✗ {task_type} 推理测试失败: {e}")
        return False
if __name__ == "__main__":
    # Script mode: load all models, report which slots are filled, then run
    # every task once on the built-in sample image.
    init()
    test_model_loading()
    print("测试推理功能:")
    for task_name in get_available_tasks():
        test_inference(task_name)
    print("\n所有测试完成!")