---
license: apache-2.0
language:
- zh
- en
base_model:
- microsoft/resnet-50
pipeline_tag: image-to-image
datasets:
- ynyg/InkEraser
---
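
# ink-eraser-latest (Handwritten Ink Erasure Model)

This directory is a model export package (Hugging Face-compatible format) for handwritten ink erasure / document scribble removal. The model takes an RGB image containing ink marks as input and outputs an RGB image with the ink removed.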

## Model Information

- Architecture: U-Net++ (`segmentation-models-pytorch`) with a ResNet50 encoder
- Task: image-to-image (removing handwriting / ink marks)
- Input: RGB, shape `[B, 3, H, W]`, value range `[0, 1]`
- Output: RGB, shape `[B, 3, H, W]`, value range `[0, 1]` (final `sigmoid` activation)

## Files

- `config.json`: model architecture and training hyperparameters (written at export time)
- `model.safetensors`: weights for inference (recommended)
- `best.ckpt`: original PyTorch Lightning checkpoint (for resuming training / reproducing experiments)
- `configuration.json`: brief metadata (framework/task)
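
Here is a minimal sketch of reading `config.json` and separating two groups of keys: the architecture arguments used by the loader, and the training hyperparameters. It assumes the file is a flat JSON object containing the keys that appear in this README. The `ARCH_KEYS`/`TRAIN_KEYS` grouping and the `split_config` helper are illustrative and not part of the project.

```python
import json
from pathlib import Path

# Illustrative grouping of the config keys referenced in this README.
ARCH_KEYS = ("encoder_name", "in_channels", "classes",
             "decoder_attention_type", "activation")
TRAIN_KEYS = ("lr", "weight_decay", "loss_w_charb", "loss_w_ssim", "loss_w_edge",
              "use_mask_loss", "loss_mask_weight", "charbonnier_eps")
# Keys the loader reads with cfg[...] (the rest are read with cfg.get(...)).
REQUIRED = ("encoder_name", "in_channels", "classes")


def split_config(cfg: dict) -> tuple[dict, dict]:
    """Return (architecture kwargs, training hyperparameters); raise if a required key is missing."""
    missing = [k for k in REQUIRED if k not in cfg]
    if missing:
        raise KeyError(f"config.json is missing: {missing}")
    arch = {k: cfg.get(k) for k in ARCH_KEYS}  # optional keys default to None
    train = {k: cfg[k] for k in TRAIN_KEYS if k in cfg}
    return arch, train


def load_split_config(path: str = "config.json") -> tuple[dict, dict]:
    """Read the JSON file and split it with split_config."""
    return split_config(json.loads(Path(path).read_text(encoding="utf-8")))
```

With this split, you could build the network as `smp.UnetPlusPlus(encoder_weights=None, **arch)`, because all five architecture keys are `UnetPlusPlus` keyword arguments.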

## Quick Inference (SafeTensors, recommended)

Dependencies: `torch`, `torchvision`, `segmentation-models-pytorch`, `safetensors`, and optionally `Pillow` for reading and writing images.

```bash
pip install torch torchvision segmentation-models-pytorch safetensors pillow
```

```python
import json
from pathlib import Path

import segmentation_models_pytorch as smp
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
from safetensors.torch import load_file

device = "cuda" if torch.cuda.is_available() else "cpu"

# 1) Read the config
cfg = json.loads(Path("config.json").read_text(encoding="utf-8"))

# 2) Build the network (must match the exported config)
model = smp.UnetPlusPlus(
    encoder_name=cfg["encoder_name"],
    encoder_weights=None,  # weights come from model.safetensors
    in_channels=cfg["in_channels"],
    classes=cfg["classes"],
    decoder_attention_type=cfg.get("decoder_attention_type"),
    activation=cfg.get("activation"),  # usually "sigmoid"
).to(device)

# 3) Load the weights
# The export may include non-network tensors (e.g. `edge_loss.kx/ky`). Inference only
# needs the U-Net++ parameters, so keep just the keys the network defines.
state_dict = load_file("model.safetensors")
model_keys = set(model.state_dict().keys())
state_dict = {k: v for k, v in state_dict.items() if k in model_keys}
model.load_state_dict(state_dict, strict=True)  # still raises if a network key is missing
model.eval()

# 4) Prepare the input (training only scaled pixels to [0, 1]; resizing to 512x512
#    can bring the input closer to the training distribution)
img = Image.open("input.png").convert("RGB")
x = TF.to_tensor(img).unsqueeze(0).to(device)  # [1,3,H,W] in [0,1]

# The encoder downsamples by 32: pad H and W up to multiples of 32 by edge
# replication, then crop the prediction back to the original size.
h, w = x.shape[-2:]
x = F.pad(x, (0, (-w) % 32, 0, (-h) % 32), mode="replicate")

with torch.no_grad():
    y = model(x)[..., :h, :w].clamp(0, 1)  # [1,3,H,W]

out = TF.to_pil_image(y.squeeze(0).cpu())
out.save("output.png")
```
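
Filtering the keys silently drops anything the network does not define. Before calling `load_state_dict(..., strict=True)`, you may want to see what gets dropped and whether any network parameter is missing. A small pure-Python check on the two key sets can show this. `diff_state_dict_keys` is a hypothetical helper, not part of any library. You would pass it, for example, `load_file("model.safetensors").keys()` and `model.state_dict().keys()`.

```python
from collections.abc import Iterable


def diff_state_dict_keys(exported: Iterable[str], expected: Iterable[str]) -> dict[str, list[str]]:
    """Compare exported tensor names with the names the network defines.

    - "dropped": exported keys the network does not use (e.g. loss buffers)
    - "missing": network keys absent from the export (strict loading fails on these)
    """
    exported_set, expected_set = set(exported), set(expected)
    return {
        "dropped": sorted(exported_set - expected_set),
        "missing": sorted(expected_set - exported_set),
    }
```

Entries in `dropped` such as `edge_loss.kx` are expected. A non-empty `missing` list usually means that `config.json` and the weights disagree (for example, a different encoder), and strict loading will raise an error.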
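
Note: U-Net++ with a ResNet50 encoder downsamples the input by a factor of 32, so H and W should be multiples of 32. `segmentation-models-pytorch` raises an error for other sizes. Pad or resize the input first (e.g., to `512x512`).

Alternatively, use the project's high-resolution tiled inference script. It automatically splits large images into tiles and blends the results back into the full-size image. Run it from the project root:

```bash
python infer_hd.py --model-dir assets/InkErase --input input.png --output output.png
```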
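
This README does not document how `infer_hd.py` works internally. As a rough illustration of overlapping tiling in general, the sketch below computes tile origins along each axis. Fixed-size tiles cover the whole image with at least a given overlap, and the last tile is aligned to the image edge. The tile size (512) and overlap (64) are assumptions, not values taken from the script. The real script may also choose tiles and blend overlaps differently.

```python
def tile_starts(length: int, tile: int = 512, overlap: int = 64) -> list[int]:
    """Start offsets of overlapping tiles covering [0, length) along one axis.

    Tiles advance by (tile - overlap); the last tile is shifted back so it ends
    exactly at `length`. An axis shorter than one tile gets a single tile at 0
    (the caller would pad that tile up to `tile`).
    """
    if tile <= overlap:
        raise ValueError("tile must be larger than overlap")
    if length <= tile:
        return [0]
    stride = tile - overlap
    starts = list(range(0, length - tile, stride))
    starts.append(length - tile)  # align the final tile with the edge
    return starts


def tile_boxes(height: int, width: int, tile: int = 512, overlap: int = 64):
    """(top, left, bottom, right) boxes for a 2-D grid of tiles."""
    return [
        (y, x, min(y + tile, height), min(x + tile, width))
        for y in tile_starts(height, tile, overlap)
        for x in tile_starts(width, tile, overlap)
    ]
```

A tile size that is a multiple of 32 (as 512 is) keeps every tile compatible with the encoder's downsampling. When tiles are merged back, the overlapping regions are commonly averaged or feathered to hide seams.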
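
## Using `best.ckpt` (resuming training / reproducing experiments)

`best.ckpt` is a PyTorch Lightning checkpoint. It generally requires this project's `InkEraserModel` code, plus the ResNet50 pretrained weight file (e.g., `pretrained_weights/resnet50-0676ba61.pth`).

```python
import torch
import torchvision.transforms.functional as TF
from PIL import Image

from model import InkEraserModel  # this project's LightningModule

model = InkEraserModel.load_from_checkpoint(
    "best.ckpt",
    weight="pretrained_weights/resnet50-0676ba61.pth",
    map_location="cpu",  # lets a GPU-saved checkpoint load on a CPU-only machine
)
model.eval()

# Input: RGB tensor [B,3,H,W] in [0,1] (pad or resize to multiples of 32 if needed)
x = TF.to_tensor(Image.open("input.png").convert("RGB")).unsqueeze(0)

with torch.no_grad():
    y = model(x)
```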

## Training Hyperparameters (from `config.json`)

The following parameters matter mainly for training and reproduction; inference does not depend on them:

```json
{
  "lr": 0.0001,
  "weight_decay": 0.01,
  "loss_w_charb": 0.78,
  "loss_w_ssim": 0.16,
  "loss_w_edge": 0.06,
  "use_mask_loss": true,
  "loss_mask_weight": 10.0,
  "charbonnier_eps": 0.001
}
```
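
The three loss weights above sum to 1 (0.78 + 0.16 + 0.06). The sketch below is a rough illustration of how such weights could be combined, under three assumptions:

- The Charbonnier term is `sqrt((pred - target)**2 + eps**2)` with `eps = charbonnier_eps`.
- `use_mask_loss` means that pixels inside an ink mask are weighted by `loss_mask_weight` in a weighted mean.
- The SSIM and edge terms are computed elsewhere and passed in as scalars.

None of this is taken from the training code, which may define these terms differently. Plain Python lists stand in for image tensors.

```python
import math

# Hyperparameters copied from config.json above.
HP = {
    "loss_w_charb": 0.78,
    "loss_w_ssim": 0.16,
    "loss_w_edge": 0.06,
    "use_mask_loss": True,
    "loss_mask_weight": 10.0,
    "charbonnier_eps": 0.001,
}


def charbonnier(pred, target, mask=None, hp=HP):
    """Weighted mean of sqrt(d^2 + eps^2) over flat lists of pixel values.

    Assumption: when use_mask_loss is on, pixels whose mask entry is truthy
    (ink) get weight loss_mask_weight and all other pixels weight 1.
    """
    if not pred or len(pred) != len(target):
        raise ValueError("pred and target must be non-empty and equally long")
    eps = hp["charbonnier_eps"]
    total = weight_sum = 0.0
    for i, (p, t) in enumerate(zip(pred, target)):
        inked = hp["use_mask_loss"] and mask is not None and mask[i]
        w = hp["loss_mask_weight"] if inked else 1.0
        total += w * math.sqrt((p - t) ** 2 + eps ** 2)
        weight_sum += w
    return total / weight_sum


def total_loss(pred, target, ssim_loss, edge_loss, mask=None, hp=HP):
    """Weighted sum of the three terms; ssim_loss and edge_loss are precomputed scalars."""
    return (
        hp["loss_w_charb"] * charbonnier(pred, target, mask, hp)
        + hp["loss_w_ssim"] * ssim_loss
        + hp["loss_w_edge"] * edge_loss
    )
```

Under these assumptions, a perfect reconstruction still has a Charbonnier value of about `eps` rather than 0, and errors on ink pixels count ten times as much as errors on background pixels.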

## License

MIT