# model.py - inference service for HINT image-restoration models.
#
# The processing pipeline (reflect padding, overlapping tiling, weight-map
# merging, GT-mean correction) deliberately mirrors the original HINT test
# scripts so served results match the published evaluation numbers.
import yaml
import torch
import math
import numpy as np
import torch.nn.functional as F
from PIL import Image
from io import BytesIO
import cv2
from basicsr.models import create_model
from basicsr.models.archs.HINT_arch import HINT
from basicsr.utils.options import parse
from skimage import img_as_ubyte
import gc
import os

# Force CPU execution to match the typical Hugging Face Space environment.
os.environ['CUDA_VISIBLE_DEVICES'] = ''
torch.backends.cudnn.enabled = False

# Thread counts tuned for small shared CPU instances; adjust as needed.
torch.set_num_threads(2)
torch.set_num_interop_threads(1)

device = torch.device('cpu')

# Global model handles, populated by init(); None means "not loaded".
dehazing_model = None
deraining_model = None
desnowing_model = None
denoising_blind_model = None
enhancement_synthetic_model = None
enhancement_real_model = None


def splitimage(imgtensor, crop_size=128, overlap_size=64):
    """Split a (B, C, H, W) tensor into overlapping square tiles.

    Returns (tiles, starts) where starts[i] is the (h, w) origin of
    tiles[i].  Tiling matches the original test script exactly: a regular
    grid with stride (crop_size - overlap_size), plus a final row/column
    anchored at the bottom/right edge so the whole image is covered.

    NOTE(review): assumes H >= crop_size and W >= crop_size — a smaller
    image yields a negative start and a malformed tile.  Confirm callers
    always pad first or use full-image inference for small inputs.
    """
    _, C, H, W = imgtensor.shape
    stride = crop_size - overlap_size

    hstarts = [h for h in range(0, H, stride)]
    # Drop trailing starts that would overrun, then anchor a final tile
    # flush with the bottom edge (same for the right edge below).
    while hstarts and hstarts[-1] + crop_size >= H:
        hstarts.pop()
    hstarts.append(H - crop_size)

    wstarts = [w for w in range(0, W, stride)]
    while wstarts and wstarts[-1] + crop_size >= W:
        wstarts.pop()
    wstarts.append(W - crop_size)

    starts = []
    split_data = []
    for hs in hstarts:
        for ws in wstarts:
            split_data.append(imgtensor[:, :, hs:hs + crop_size, ws:ws + crop_size])
            starts.append((hs, ws))
    return split_data, starts


def get_scoremap(H, W, C, B=1, is_mean=True):
    """Build a (B, C, H, W) blending-weight map for tile merging.

    With is_mean=True (the only mode used in this file) every pixel gets
    weight 1.0, so overlapping tiles are simply averaged.  With
    is_mean=False pixels are weighted by inverse distance to the tile
    centre, favouring tile centres over seams.
    """
    center_h = H / 2
    center_w = W / 2
    score = torch.ones((B, C, H, W))
    if not is_mean:
        for h in range(H):
            for w in range(W):
                # 1e-6 guards against division by zero at the exact centre.
                score[:, :, h, w] = 1.0 / (
                    math.sqrt((h - center_h) ** 2 + (w - center_w) ** 2 + 1e-6))
    return score


def mergeimage(split_data, starts, crop_size=128, resolution=(1, 3, 128, 128)):
    """Blend restored tiles back into a full (B, C, H, W) image.

    Each tile is accumulated with its score map and the result is
    normalised by the total weight per pixel, averaging overlap regions.
    """
    B, C, H, W = resolution[0], resolution[1], resolution[2], resolution[3]
    tot_score = torch.zeros((B, C, H, W))
    merge_img = torch.zeros((B, C, H, W))
    scoremap = get_scoremap(crop_size, crop_size, C, B=B, is_mean=True)
    for simg, cstart in zip(split_data, starts):
        hs, ws = cstart
        merge_img[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap * simg
        tot_score[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap
    # Full coverage is guaranteed by splitimage, so tot_score > 0 everywhere.
    merge_img = merge_img / tot_score
    return merge_img


def _pad_to_multiple(tensor, factor):
    """Reflect-pad the bottom/right of a (B, C, H, W) tensor so that both
    spatial dims become multiples of *factor* (no-op when already aligned).
    """
    _, _, h, w = tensor.shape
    padh = (factor - h % factor) % factor
    padw = (factor - w % factor) % factor
    return F.pad(tensor, (0, padw, 0, padh), 'reflect')


def _tiled_restore(model, padded, crop_size, overlap_size):
    """Run *model* tile-by-tile over *padded* and merge the tile outputs."""
    b, c, hp, wp = padded.shape
    tiles, starts = splitimage(padded, crop_size=crop_size,
                               overlap_size=overlap_size)
    # .cpu() is a no-op on this CPU-only setup but keeps merge inputs local.
    outputs = [model(t).cpu() for t in tiles]
    return mergeimage(outputs, starts, crop_size=crop_size,
                      resolution=(b, c, hp, wp))


def load_hint_model(config_path, weight_path):
    """Load a HINT network from a YAML config and a checkpoint file.

    Returns the model in eval mode on CPU, or None if files are missing or
    loading fails (failures are logged, never raised, so init() can proceed
    with the remaining models).
    """
    try:
        if not os.path.exists(config_path) or not os.path.exists(weight_path):
            print(f"HINT模型文件不存在: config='{config_path}', weights='{weight_path}'")
            return None

        try:
            from yaml import CLoader as Loader
        except ImportError:
            from yaml import Loader
        # Context manager ensures the config file handle is closed.
        with open(config_path, mode='r') as f:
            x = yaml.load(f, Loader=Loader)

        # The 'type' key selects the arch in BasicSR; here the arch is
        # fixed to HINT, so pop it and pass the rest as kwargs.
        network_config = x['network_g']
        network_config.pop('type')
        model = HINT(**network_config)

        # Checkpoints may store weights under 'params' or at top level.
        checkpoint = torch.load(weight_path, map_location=device)
        if 'params' in checkpoint:
            model.load_state_dict(checkpoint['params'])
        else:
            model.load_state_dict(checkpoint)

        model.eval()
        model = model.to(device)
        print(f"HINT模型加载成功: {weight_path}")
        return model
    except Exception as e:
        print(f"HINT模型加载失败: {e}")
        return None


def load_basicsr_model(config_path, weight_path):
    """Load a model through the BasicSR factory (used for enhancement).

    Returns the underlying net_g in eval mode on CPU, or None on failure.
    """
    try:
        if not os.path.exists(config_path) or not os.path.exists(weight_path):
            print(f"BasicSR模型文件不存在: config='{config_path}', weights='{weight_path}'")
            return None

        opt = parse(config_path, is_train=False)
        opt['dist'] = False
        opt['num_gpu'] = 0
        opt['device'] = 'cpu'

        model = create_model(opt).net_g
        checkpoint = torch.load(weight_path, map_location=device)
        # Checkpoints saved from DataParallel carry a 'module.' prefix;
        # retry with the prefix added when the plain load fails.
        try:
            model.load_state_dict(checkpoint['params'])
        except Exception:
            new_checkpoint = {}
            for k in checkpoint['params']:
                new_checkpoint['module.' + k] = checkpoint['params'][k]
            model.load_state_dict(new_checkpoint)

        model.eval()
        model = model.to(device)
        print(f"BasicSR模型加载成功: {weight_path}")
        return model
    except Exception as e:
        print(f"BasicSR模型加载失败: {e}")
        return None


def init():
    """Load all task models into the module-level globals."""
    global dehazing_model, deraining_model, desnowing_model, denoising_blind_model
    global enhancement_synthetic_model, enhancement_real_model

    print("开始加载所有模型...")
    dehazing_model = load_hint_model("./options/RealDehazing_HINT.yml",
                                     "./models/Dehazing.pth")
    deraining_model = load_hint_model("./options/Deraining_HINT_syn_rain100L.yml",
                                      "./models/Rain100L_HINT.pth")
    desnowing_model = load_hint_model("./options/Desnow_snow100k_HINT.yml",
                                      "./models/snow100k.pth")
    denoising_blind_model = load_hint_model("./options/GaussianColorDenoising_HINT.yml",
                                            "./models/net_g_latest_blind.pth")
    enhancement_synthetic_model = load_basicsr_model("./options/HINT_LOL_v2_synthetic.yml",
                                                     "./models/lolv2Syn.pth")
    enhancement_real_model = load_basicsr_model("./options/HINT_LOL_v2_real.yml",
                                               "./models/lolv2Real.pth")
    gc.collect()
    print("所有模型加载完成。")


def preprocess_image(img, task=None, sigma=None):
    """Convert a PIL image (or ndarray) to a (1, C, H, W) float tensor.

    For the blind-denoising task with sigma > 0, synthetic Gaussian noise
    is added with a fixed seed, reproducing the evaluation protocol of the
    original test script exactly (seed 0, noise std = sigma/255 in [0,1]
    space).  All other tasks are plain /255 normalisation.
    """
    if isinstance(img, Image.Image):
        arr = np.array(img)
    else:
        arr = img

    if task == 'denoising_blind' and sigma is not None and sigma > 0:
        # Step 1: normalise to float32 in [0, 1] (matches original code).
        img_normalized = np.float32(arr) / 255.0
        # Step 2: fixed seed for reproducible synthetic noise.
        np.random.seed(seed=0)
        # Step 3: additive Gaussian noise with std sigma/255.
        img_normalized += np.random.normal(0, sigma / 255., img_normalized.shape)
        # Step 4: HWC -> CHW and add the batch dim.
        tensor = torch.from_numpy(img_normalized).permute(2, 0, 1).unsqueeze(0)
    else:
        # Regular path, identical to the test script:
        #   img = np.float32(utils.load_img(inp_path)) / 255.
        #   img = torch.from_numpy(img).permute(2, 0, 1)
        arr = np.float32(arr) / 255.0
        tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
    return tensor


def inference(body: bytes, task_type: str, sigma: int = None,
              GT_mean: bool = False, target_image: bytes = None) -> bytes:
    """Unified inference entry point.

    Decodes *body* as an RGB image, runs the model selected by *task_type*
    (padding/tiling exactly as in the original test scripts), optionally
    applies GT-mean brightness correction against *target_image*, and
    returns the restored image encoded as PNG bytes.

    Raises Exception when the requested model is not loaded, and re-raises
    any inference failure after logging a traceback.
    """
    model_map = {
        "dehazing": dehazing_model,
        "deraining": deraining_model,
        "desnowing": desnowing_model,
        "denoising_blind": denoising_blind_model,
        "enhancement_synthetic": enhancement_synthetic_model,
        "enhancement_real": enhancement_real_model,
    }
    model = model_map.get(task_type)
    if model is None:
        raise Exception(f"模型 '{task_type}' 未加载或不存在。")

    img = Image.open(BytesIO(body)).convert("RGB")
    input_tensor = preprocess_image(img, task_type, sigma).to(device)
    orig_h, orig_w = input_tensor.shape[2], input_tensor.shape[3]

    # Decode the reference image up front when GT correction is requested.
    target_np = None
    if GT_mean and target_image is not None:
        target_img = Image.open(BytesIO(target_image)).convert("RGB")
        target_np = np.array(target_img).astype(np.float32) / 255.0

    try:
        with torch.no_grad():
            if task_type in ["enhancement_synthetic", "enhancement_real"]:
                # Low-light enhancement: whole-image, factor-4 padding.
                padded_input = _pad_to_multiple(input_tensor, factor=4)
                restored_tensor = model(padded_input)
            elif task_type == "dehazing":
                # Dehazing: factor-8 padding, 256px tiles with 64px overlap.
                padded_input = _pad_to_multiple(input_tensor, factor=8)
                restored_tensor = _tiled_restore(model, padded_input,
                                                 crop_size=256, overlap_size=64)
            elif task_type == "desnowing":
                # Desnowing: factor-8 padding, 256px tiles with 128px overlap.
                padded_input = _pad_to_multiple(input_tensor, factor=8)
                restored_tensor = _tiled_restore(model, padded_input,
                                                 crop_size=256, overlap_size=128)
            else:
                # Deraining and blind denoising: whole-image, factor-8 padding.
                padded_input = _pad_to_multiple(input_tensor, factor=8)
                restored_tensor = model(padded_input)

            # Crop padding away to recover the original resolution.
            restored_tensor = restored_tensor[:, :, :orig_h, :orig_w]

        # (1, C, H, W) -> (H, W, C) float array in [0, 1].
        restored_np = (torch.clamp(restored_tensor, 0, 1)
                       .cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy())

        # GT-mean correction: rescale so the restored image's mean gray
        # level matches the reference.  The original script used
        # COLOR_BGR2GRAY on cv2-loaded (BGR) arrays; images here are RGB,
        # so COLOR_RGB2GRAY is the equivalent conversion.
        if GT_mean and target_np is not None:
            mean_restored = cv2.cvtColor(restored_np.astype(np.float32),
                                         cv2.COLOR_RGB2GRAY).mean()
            mean_target = cv2.cvtColor(target_np.astype(np.float32),
                                       cv2.COLOR_RGB2GRAY).mean()
            restored_np = np.clip(restored_np * (mean_target / mean_restored), 0, 1)

        # Encode to 8-bit PNG.
        result_np = img_as_ubyte(restored_np)
        out_img = Image.fromarray(result_np)
        buf = BytesIO()
        out_img.save(buf, format="PNG")
        gc.collect()
        return buf.getvalue()
    except Exception as e:
        print(f"在任务 '{task_type}' 中推理失败: {e}")
        import traceback
        traceback.print_exc()
        raise e


# --- Exported per-task inference functions ---
def dehazing_inference(body: bytes) -> bytes:
    return inference(body, "dehazing")


def deraining_inference(body: bytes) -> bytes:
    return inference(body, "deraining")


def desnowing_inference(body: bytes) -> bytes:
    return inference(body, "desnowing")


def denoising_blind_inference(body: bytes, sigma: int = 25) -> bytes:
    return inference(body, "denoising_blind", sigma=sigma)


def enhancement_synthetic_inference(body: bytes, GT_mean: bool = False,
                                    target_image: bytes = None) -> bytes:
    """Low-light enhancement (synthetic-trained model)."""
    return inference(body, "enhancement_synthetic", GT_mean=GT_mean,
                     target_image=target_image)


def enhancement_real_inference(body: bytes, GT_mean: bool = False,
                               target_image: bytes = None) -> bytes:
    """Low-light enhancement (real-trained model)."""
    return inference(body, "enhancement_real", GT_mean=GT_mean,
                     target_image=target_image)


# Backward-compatible simplified entry points (no GT correction).
def enhancement_synthetic_inference_simple(body: bytes) -> bytes:
    """Simplified synthetic low-light enhancement without GT correction."""
    return inference(body, "enhancement_synthetic")


def enhancement_real_inference_simple(body: bytes) -> bytes:
    """Simplified real low-light enhancement without GT correction."""
    return inference(body, "enhancement_real")


# --- Helper functions ---
def get_model_status():
    """Report which models were successfully loaded."""
    return {
        "dehazing_model_loaded": dehazing_model is not None,
        "deraining_model_loaded": deraining_model is not None,
        "desnowing_model_loaded": desnowing_model is not None,
        "denoising_blind_model_loaded": denoising_blind_model is not None,
        "enhancement_synthetic_model_loaded": enhancement_synthetic_model is not None,
        "enhancement_real_model_loaded": enhancement_real_model is not None,
    }


def get_available_tasks():
    """Return the list of supported task identifiers."""
    return ['dehazing', 'deraining', 'desnowing', 'denoising_blind',
            'enhancement_synthetic', 'enhancement_real']


def get_task_description(task):
    """Return a human-readable description for *task*."""
    descriptions = {
        'dehazing': '去除图像中的雾霾效果',
        'deraining': '去除图像中的雨线',
        'desnowing': '去除图像中的雪花',
        'denoising_blind': '通用降噪',
        'enhancement_synthetic': '低光增强 (合成数据训练)',
        'enhancement_real': '低光增强 (真实数据训练)',
    }
    return descriptions.get(task, '未知任务')


# --- Test functions ---
def test_model_loading():
    """Print the load status of every model."""
    print("测试模型加载情况:")
    status = get_model_status()
    for model_name, loaded in status.items():
        print(f" {model_name}: {'✓' if loaded else '✗'}")
    print()


def test_inference(task_type, test_image_path=None):
    """Smoke-test one task end-to-end; returns True on success."""
    if test_image_path is None:
        # Fall back to a solid-colour synthetic test image.
        test_img = Image.new('RGB', (256, 256), color='red')
        buf = BytesIO()
        test_img.save(buf, format='PNG')
        test_bytes = buf.getvalue()
    else:
        with open(test_image_path, 'rb') as f:
            test_bytes = f.read()

    try:
        if task_type == 'dehazing':
            result = dehazing_inference(test_bytes)
        elif task_type == 'deraining':
            result = deraining_inference(test_bytes)
        elif task_type == 'desnowing':
            result = desnowing_inference(test_bytes)
        elif task_type == 'denoising_blind':
            result = denoising_blind_inference(test_bytes, sigma=25)
        elif task_type == 'enhancement_synthetic':
            result = enhancement_synthetic_inference(test_bytes)
        elif task_type == 'enhancement_real':
            result = enhancement_real_inference(test_bytes)
        else:
            print(f"未知任务类型: {task_type}")
            return False
        print(f"✓ {task_type} 推理测试成功,输出大小: {len(result)} bytes")
        return True
    except Exception as e:
        print(f"✗ {task_type} 推理测试失败: {e}")
        return False


if __name__ == "__main__":
    # Load the models, report status, then smoke-test every task.
    init()
    test_model_loading()
    print("测试推理功能:")
    for task in get_available_tasks():
        test_inference(task)
    print("\n所有测试完成!")