IR_expeiment / PART2 /run_maxim.py
hugaagg's picture
Upload folder using huggingface_hub
2ecc7ab verified
import torch
from transformers import MaximImageProcessor, MaximForImageDeblurring
from PIL import Image
import numpy as np
# 1. 设置文件路径 (请修改这里为你真实的路径)
input_image_path = r"G:\datasets\realblur_dataset_test\075_blur_1.png"
output_image_path = r"G:\datasets\maxim_result.png"
print(">>> 正在加载模型 (第一次运行会自动下载约 600MB 模型,请耐心等待)...")
# 2. 加载 Google 的 MAXIM 模型 (专用于 GoPro 去模糊任务)
# 这个库是纯 Python 的,不需要编译 C++,所以一定能跑通
processor = MaximImageProcessor.from_pretrained("google/maxim-s3-deblurring-gopro")
model = MaximForImageDeblurring.from_pretrained("google/maxim-s3-deblurring-gopro")
print(">>> 模型加载成功!正在读取图片...")
# 3. 读取并预处理图片
image = Image.open(input_image_path).convert("RGB")
inputs = processor(images=image, return_tensors="pt")
print(">>> 正在进行去模糊处理 (CPU运行可能需要1-2分钟,请稍候)...")
# 4. 推理 (如果不使用 GPU,这里会自动用 CPU)
with torch.no_grad():
outputs = model(**inputs)
# 5. 后处理并保存
# MAXIM 输出的是重构后的像素值
reconstructed_data = outputs.reconstruction.squeeze().permute(1, 2, 0).clamp(0, 1).numpy()
# 转换为 0-255 格式
reconstructed_image = (reconstructed_data * 255).astype(np.uint8)
reconstructed_image = Image.fromarray(reconstructed_image)
# 保存
reconstructed_image.save(output_image_path)
print(f">>> 处理完成!结果已保存至: {output_image_path}")