# model.py — HINT-based image restoration service (dehazing, deraining,
# desnowing, blind denoising, low-light enhancement).
# Complete fixed version: processing logic kept identical to the original
# test script.
import yaml
import torch
import math
import numpy as np
import torch.nn.functional as F
from PIL import Image
from io import BytesIO
import cv2
from basicsr.models import create_model
from basicsr.models.archs.HINT_arch import HINT
from basicsr.utils.options import parse
from skimage import img_as_ubyte
import gc
import os
# Force CPU execution to match a typical Hugging Face Space environment.
# NOTE(review): this is set after `import torch`; it only takes effect if
# CUDA has not been initialized yet — confirm on GPU machines.
os.environ['CUDA_VISIBLE_DEVICES'] = ''
torch.backends.cudnn.enabled = False
# Thread counts tuned for small CPU instances; adjust as needed.
torch.set_num_threads(2)
torch.set_num_interop_threads(1)
device = torch.device('cpu')
# Module-level model singletons; populated by init(), None until loaded.
dehazing_model = None
deraining_model = None
desnowing_model = None
denoising_blind_model = None
enhancement_synthetic_model = None
enhancement_real_model = None
def splitimage(imgtensor, crop_size=128, overlap_size=64):
    """Split a (B, C, H, W) tensor into overlapping square crops.

    Crop origins are laid out every ``crop_size - overlap_size`` pixels
    along each axis; the final origin is snapped to the image edge so the
    last crop ends exactly at H (resp. W). Returns the list of crop tensors
    (views, not copies) and the parallel list of (h, w) origins.

    NOTE(review): assumes crop_size <= H and crop_size <= W; a smaller
    image would yield a negative start index. Kept as-is to stay identical
    to the original test script.
    """
    _, _, height, width = imgtensor.shape
    stride = crop_size - overlap_size

    def _origins(extent):
        # Regular grid of starts; drop any start whose crop would reach the
        # edge (or beyond) and append the edge-aligned final start instead.
        pts = list(range(0, extent, stride))
        while pts and pts[-1] + crop_size >= extent:
            pts.pop()
        pts.append(extent - crop_size)
        return pts

    starts = [(hs, ws)
              for hs in _origins(height)
              for ws in _origins(width)]
    split_data = [imgtensor[:, :, hs:hs + crop_size, ws:ws + crop_size]
                  for hs, ws in starts]
    return split_data, starts
def get_scoremap(H, W, C, B=1, is_mean=True):
    """Build a (B, C, H, W) blending-weight map for tile merging.

    With ``is_mean=True`` (the only mode used in this file) the map is all
    ones, i.e. plain averaging over overlaps. Otherwise each pixel's weight
    is the inverse of its Euclidean distance to the crop center, favoring
    tile centers over borders.
    """
    if is_mean:
        return torch.ones((B, C, H, W))
    center_h = H / 2
    center_w = W / 2
    # FIX(perf): vectorized replacement for the original per-pixel Python
    # double loop. Computed in float64 so every value matches the original
    # math.sqrt result bit-for-bit before the final cast to float32.
    dh = (torch.arange(H, dtype=torch.float64) - center_h) ** 2
    dw = (torch.arange(W, dtype=torch.float64) - center_w) ** 2
    inv_dist = 1.0 / torch.sqrt(dh[:, None] + dw[None, :] + 1e-6)
    # .clone() returns an independent, writable tensor like the original.
    return inv_dist.to(torch.float32).expand(B, C, H, W).clone()
def mergeimage(split_data, starts, crop_size=128, resolution=(1, 3, 128, 128)):
    """Blend overlapping crops back into one (B, C, H, W) image.

    Each crop is weighted by the map from ``get_scoremap`` and accumulated
    into the output; overlapping regions are then normalized by the summed
    weights. With the uniform map used here this is plain averaging.
    """
    B, C, H, W = resolution
    weights = get_scoremap(crop_size, crop_size, C, B=B, is_mean=True)
    accum = torch.zeros((B, C, H, W))
    norm = torch.zeros((B, C, H, W))
    for tile, (hs, ws) in zip(split_data, starts):
        accum[:, :, hs:hs + crop_size, ws:ws + crop_size] += weights * tile
        norm[:, :, hs:hs + crop_size, ws:ws + crop_size] += weights
    return accum / norm
def load_hint_model(config_path, weight_path):
    """Load a HINT model from a YAML config plus a checkpoint file.

    Returns the model in eval mode on the global ``device``, or None when
    either file is missing or loading fails (errors are printed, never
    raised — callers treat None as "model unavailable").
    """
    try:
        if not os.path.exists(config_path) or not os.path.exists(weight_path):
            print(f"HINT模型文件不存在: config='{config_path}', weights='{weight_path}'")
            return None
        # Prefer the C-accelerated YAML loader when it is available.
        try:
            from yaml import CLoader as Loader
        except ImportError:
            from yaml import Loader
        # FIX: close the config file deterministically; the original leaked
        # the handle via yaml.load(open(...)).
        with open(config_path, mode='r') as f:
            x = yaml.load(f, Loader=Loader)
        # Network configuration handling identical to the test script:
        # 'type' is removed because HINT's constructor does not accept it.
        network_config = x['network_g']
        network_config.pop('type')
        model = HINT(**network_config)
        # Checkpoints may store weights under 'params' or at the top level.
        checkpoint = torch.load(weight_path, map_location=device)
        if 'params' in checkpoint:
            model.load_state_dict(checkpoint['params'])
        else:
            model.load_state_dict(checkpoint)
        model.eval()
        model = model.to(device)
        print(f"HINT模型加载成功: {weight_path}")
        return model
    except Exception as e:
        print(f"HINT模型加载失败: {e}")
        return None
def load_basicsr_model(config_path, weight_path):
    """Load a BasicSR-built model (its ``net_g``) from config + checkpoint.

    Forces CPU-only, non-distributed settings. Returns the model in eval
    mode on the global ``device``, or None when either file is missing or
    loading fails (errors are printed, never raised).
    """
    try:
        if not os.path.exists(config_path) or not os.path.exists(weight_path):
            print(f"BasicSR模型文件不存在: config='{config_path}', weights='{weight_path}'")
            return None
        opt = parse(config_path, is_train=False)
        opt['dist'] = False
        opt['num_gpu'] = 0
        opt['device'] = 'cpu'
        model = create_model(opt).net_g
        checkpoint = torch.load(weight_path, map_location=device)
        # Checkpoints saved from a DataParallel-wrapped model need a
        # 'module.' prefix on every key; retry with the prefix on failure.
        # FIX: the original used a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit; catch only the load failures.
        try:
            model.load_state_dict(checkpoint['params'])
        except (KeyError, RuntimeError):
            new_checkpoint = {}
            for k in checkpoint['params']:
                new_checkpoint['module.' + k] = checkpoint['params'][k]
            model.load_state_dict(new_checkpoint)
        model.eval()
        model = model.to(device)
        print(f"BasicSR模型加载成功: {weight_path}")
        return model
    except Exception as e:
        print(f"BasicSR模型加载失败: {e}")
        return None
def init():
    """Load every restoration model into its module-level global.

    A missing or broken model only produces a warning (the corresponding
    global stays None), so the service can run with a subset of the models.
    """
    global dehazing_model, deraining_model, desnowing_model, denoising_blind_model
    global enhancement_synthetic_model, enhancement_real_model
    print("开始加载所有模型...")
    dehazing_model = load_hint_model(
        "./options/RealDehazing_HINT.yml", "./models/Dehazing.pth")
    deraining_model = load_hint_model(
        "./options/Deraining_HINT_syn_rain100L.yml", "./models/Rain100L_HINT.pth")
    desnowing_model = load_hint_model(
        "./options/Desnow_snow100k_HINT.yml", "./models/snow100k.pth")
    denoising_blind_model = load_hint_model(
        "./options/GaussianColorDenoising_HINT.yml",
        "./models/net_g_latest_blind.pth")
    enhancement_synthetic_model = load_basicsr_model(
        "./options/HINT_LOL_v2_synthetic.yml", "./models/lolv2Syn.pth")
    enhancement_real_model = load_basicsr_model(
        "./options/HINT_LOL_v2_real.yml", "./models/lolv2Real.pth")
    # Release transient load-time allocations before serving requests.
    gc.collect()
    print("所有模型加载完成。")
def preprocess_image(img, task=None, sigma=None):
    """Convert a PIL image (or ndarray) to a (1, C, H, W) float tensor in [0, 1].

    For the blind-denoising task with sigma > 0, synthetic Gaussian noise
    with standard deviation sigma/255 is added after normalization, using a
    fixed random seed (0) so outputs match the original test script exactly.
    """
    arr = np.array(img) if isinstance(img, Image.Image) else img
    # Normalize to float32 in [0, 1], as in the test script.
    data = np.float32(arr) / 255.0
    if task == 'denoising_blind' and sigma is not None and sigma > 0:
        # Deterministic synthetic noise: seeded exactly like the reference
        # code so results are reproducible.
        np.random.seed(seed=0)
        data += np.random.normal(0, sigma / 255., data.shape)
    # HWC -> CHW, then add the batch dimension.
    return torch.from_numpy(data).permute(2, 0, 1).unsqueeze(0)
def inference(body: bytes, task_type: str, sigma: int = None, GT_mean: bool = False,
              target_image: bytes = None) -> bytes:
    """Unified inference entry point — mirrors the original test script.

    Args:
        body: Input image bytes (any format PIL can decode); converted to RGB.
        task_type: One of the keys of ``model_map`` below.
        sigma: Gaussian noise level; only used by 'denoising_blind'.
        GT_mean: When True and ``target_image`` is given, rescale the output
            so its grayscale mean matches the target's (GT-mean correction).
        target_image: Optional reference image bytes for GT-mean correction.

    Returns:
        The restored image encoded as PNG bytes.

    Raises:
        Exception: If the model for ``task_type`` is not loaded, or if
            restoration fails (the original error is re-raised).
    """
    model_map = {
        "dehazing": dehazing_model,
        "deraining": deraining_model,
        "desnowing": desnowing_model,
        "denoising_blind": denoising_blind_model,
        "enhancement_synthetic": enhancement_synthetic_model,
        "enhancement_real": enhancement_real_model
    }
    model = model_map.get(task_type)
    if model is None:
        raise Exception(f"模型 '{task_type}' 未加载或不存在。")
    img = Image.open(BytesIO(body)).convert("RGB")
    input_tensor = preprocess_image(img, task_type, sigma).to(device)
    orig_h, orig_w = input_tensor.shape[2], input_tensor.shape[3]
    # Decode the reference image up front if GT-mean correction is requested.
    target_np = None
    if GT_mean and target_image is not None:
        target_img = Image.open(BytesIO(target_image)).convert("RGB")
        target_np = np.array(target_img).astype(np.float32) / 255.0
    try:
        with torch.no_grad():
            B, C, H, W = input_tensor.shape
            # Low-light enhancement: reflect-pad to a multiple of 4 and run
            # the whole image through the model at once.
            if task_type in ["enhancement_synthetic", "enhancement_real"]:
                factor = 4  # matches the test script
                H_pad = ((orig_h + factor) // factor) * factor
                W_pad = ((orig_w + factor) // factor) * factor
                padh = H_pad - orig_h if orig_h % factor != 0 else 0
                padw = W_pad - orig_w if orig_w % factor != 0 else 0
                padded_input = F.pad(input_tensor, (0, padw, 0, padh), 'reflect')
                # Model inference on the padded image.
                restored_tensor = model(padded_input)
                # Crop the padding off to restore the original size.
                restored_tensor = restored_tensor[:, :, :orig_h, :orig_w]
            # Dehazing: pad to a multiple of 8, then process in overlapping
            # 256px tiles (64px overlap) to bound peak memory.
            elif task_type == "dehazing":
                factor = 8
                H_pad = ((orig_h + factor) // factor) * factor
                W_pad = ((orig_w + factor) // factor) * factor
                padh = H_pad - orig_h if orig_h % factor != 0 else 0
                padw = W_pad - orig_w if orig_w % factor != 0 else 0
                padded_input = F.pad(input_tensor, (0, padw, 0, padh), 'reflect')
                # Tiling parameters from the test script ("corp" sic, kept
                # to match the reference code).
                corp_size_arg = 256
                overlap_size_arg = 64
                B, C, H_pad, W_pad = padded_input.shape
                split_data, starts = splitimage(padded_input, crop_size=corp_size_arg, overlap_size=overlap_size_arg)
                processed_splits = []
                for data in split_data:
                    processed_splits.append(model(data).cpu())
                restored_tensor = mergeimage(processed_splits, starts, crop_size=corp_size_arg,
                                             resolution=(B, C, H_pad, W_pad))
                # Crop the padding off to restore the original size.
                restored_tensor = restored_tensor[:, :, :orig_h, :orig_w]
            # Desnowing: same tiled approach, with a larger (128px) overlap.
            elif task_type == "desnowing":
                factor = 8
                crop_size = 256
                overlap_size = 128
                H_pad = ((orig_h + factor) // factor) * factor
                W_pad = ((orig_w + factor) // factor) * factor
                padh = H_pad - orig_h if orig_h % factor != 0 else 0
                padw = W_pad - orig_w if orig_w % factor != 0 else 0
                padded_input = F.pad(input_tensor, (0, padw, 0, padh), 'reflect')
                split_data, starts = splitimage(padded_input, crop_size=crop_size, overlap_size=overlap_size)
                processed_splits = [model(s) for s in split_data]
                restored_tensor = mergeimage(processed_splits, starts, crop_size=crop_size,
                                             resolution=(1, 3, H_pad, W_pad))
                # Crop the padding off to restore the original size.
                restored_tensor = restored_tensor[:, :, :orig_h, :orig_w]
            # Deraining / blind denoising: pad to a multiple of 8, run the
            # whole image through the model at once.
            else:  # deraining, denoising_blind
                factor = 8
                H_pad = ((orig_h + factor) // factor) * factor
                W_pad = ((orig_w + factor) // factor) * factor
                padh = H_pad - orig_h if orig_h % factor != 0 else 0
                padw = W_pad - orig_w if orig_w % factor != 0 else 0
                padded_input = F.pad(input_tensor, (0, padw, 0, padh), 'reflect')
                restored_tensor = model(padded_input)
                # Crop the padding off to restore the original size.
                restored_tensor = restored_tensor[:, :, :orig_h, :orig_w]
        # BCHW tensor -> HWC numpy array clamped to [0, 1].
        restored_np = torch.clamp(restored_tensor, 0, 1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
        # GT-mean correction, matching the original code's logic. The
        # reference worked on BGR arrays (cv2 loading); arrays here are RGB,
        # hence COLOR_RGB2GRAY — the grayscale mean is the same either way
        # up to the channel-weight ordering.
        if GT_mean and target_np is not None:
            mean_restored = cv2.cvtColor(restored_np.astype(np.float32), cv2.COLOR_RGB2GRAY).mean()
            mean_target = cv2.cvtColor(target_np.astype(np.float32), cv2.COLOR_RGB2GRAY).mean()
            restored_np = np.clip(restored_np * (mean_target / mean_restored), 0, 1)
        # Convert to 8-bit and encode as PNG.
        result_np = img_as_ubyte(restored_np)
        out_img = Image.fromarray(result_np)
        buf = BytesIO()
        out_img.save(buf, format="PNG")
        gc.collect()
        return buf.getvalue()
    except Exception as e:
        print(f"在任务 '{task_type}' 中推理失败: {e}")
        import traceback
        traceback.print_exc()
        raise e
# --- Exported inference entry points ---
def dehazing_inference(body: bytes) -> bytes:
    """Dehaze the given image bytes; returns PNG bytes."""
    return inference(body=body, task_type="dehazing")
def deraining_inference(body: bytes) -> bytes:
    """Remove rain streaks from the given image bytes; returns PNG bytes."""
    return inference(body=body, task_type="deraining")
def desnowing_inference(body: bytes) -> bytes:
    """Remove snow from the given image bytes; returns PNG bytes."""
    return inference(body=body, task_type="desnowing")
def denoising_blind_inference(body: bytes, sigma: int = 25) -> bytes:
    """Blind-denoise the given image bytes at noise level ``sigma``; returns PNG bytes."""
    return inference(body=body, task_type="denoising_blind", sigma=sigma)
def enhancement_synthetic_inference(body: bytes, GT_mean: bool = False, target_image: bytes = None) -> bytes:
    """Low-light enhancement (synthetic-data model), optionally GT-mean corrected."""
    return inference(body=body, task_type="enhancement_synthetic",
                     GT_mean=GT_mean, target_image=target_image)
def enhancement_real_inference(body: bytes, GT_mean: bool = False, target_image: bytes = None) -> bytes:
    """Low-light enhancement (real-data model), optionally GT-mean corrected."""
    return inference(body=body, task_type="enhancement_real",
                     GT_mean=GT_mean, target_image=target_image)
# Backward-compatible simplified entry points (no GT-mean correction).
def enhancement_synthetic_inference_simple(body: bytes) -> bytes:
    """Low-light enhancement (synthetic-data model) without GT-mean correction."""
    return inference(body=body, task_type="enhancement_synthetic")
def enhancement_real_inference_simple(body: bytes) -> bytes:
    """Low-light enhancement (real-data model) without GT-mean correction."""
    return inference(body=body, task_type="enhancement_real")
# --- Helper functions ---
def get_model_status():
    """Report, for each task, whether its model global was loaded (not None)."""
    models = {
        "dehazing_model_loaded": dehazing_model,
        "deraining_model_loaded": deraining_model,
        "desnowing_model_loaded": desnowing_model,
        "denoising_blind_model_loaded": denoising_blind_model,
        "enhancement_synthetic_model_loaded": enhancement_synthetic_model,
        "enhancement_real_model_loaded": enhancement_real_model,
    }
    return {name: mdl is not None for name, mdl in models.items()}
def get_available_tasks():
    """Return the list of task identifiers accepted by `inference`."""
    tasks = (
        'dehazing',
        'deraining',
        'desnowing',
        'denoising_blind',
        'enhancement_synthetic',
        'enhancement_real',
    )
    return list(tasks)
def get_task_description(task):
    """Return a human-readable (Chinese) description for a task id."""
    pairs = (
        ('dehazing', '去除图像中的雾霾效果'),
        ('deraining', '去除图像中的雨线'),
        ('desnowing', '去除图像中的雪花'),
        ('denoising_blind', '通用降噪'),
        ('enhancement_synthetic', '低光增强 (合成数据训练)'),
        ('enhancement_real', '低光增强 (真实数据训练)'),
    )
    return dict(pairs).get(task, '未知任务')
# --- Self-test functions ---
def test_model_loading():
    """Print one ✓/✗ status line per model global."""
    print("测试模型加载情况:")
    for model_name, loaded in get_model_status().items():
        mark = '✓' if loaded else '✗'
        print(f" {model_name}: {mark}")
    print()
def test_inference(task_type, test_image_path=None):
    """Smoke-test one task end-to-end on a file or a synthetic red image.

    Returns True on success, False on an unknown task or a failed run
    (errors are printed, never raised).
    """
    if test_image_path is None:
        # Synthesize a solid red 256x256 PNG as the test input.
        buf = BytesIO()
        Image.new('RGB', (256, 256), color='red').save(buf, format='PNG')
        test_bytes = buf.getvalue()
    else:
        with open(test_image_path, 'rb') as f:
            test_bytes = f.read()
    dispatch = {
        'dehazing': lambda b: dehazing_inference(b),
        'deraining': lambda b: deraining_inference(b),
        'desnowing': lambda b: desnowing_inference(b),
        'denoising_blind': lambda b: denoising_blind_inference(b, sigma=25),
        'enhancement_synthetic': lambda b: enhancement_synthetic_inference(b),
        'enhancement_real': lambda b: enhancement_real_inference(b),
    }
    runner = dispatch.get(task_type)
    if runner is None:
        print(f"未知任务类型: {task_type}")
        return False
    try:
        result = runner(test_bytes)
        print(f"✓ {task_type} 推理测试成功,输出大小: {len(result)} bytes")
        return True
    except Exception as e:
        print(f"✗ {task_type} 推理测试失败: {e}")
        return False
if __name__ == "__main__":
    # Load all models into the module globals.
    init()
    # Report which models loaded successfully.
    test_model_loading()
    # Smoke-test inference for every available task.
    print("测试推理功能:")
    for task in get_available_tasks():
        test_inference(task)
    print("\n所有测试完成!")