File size: 3,905 Bytes
6bf5137
 
 
 
 
 
 
 
360222f
 
6bf5137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
---
license: apache-2.0
language:
- zh
- en
base_model:
- microsoft/resnet-50
pipeline_tag: image-to-image
datasets:
- ynyg/InkEraser
---
# ink-eraser-latest(手写墨迹擦除模型)

本目录是一个用于“手写墨迹擦除 / 文档去涂写”的模型导出包(Hugging Face 兼容格式)。模型输入为带墨迹的 RGB 图像,输出为去除墨迹后的 RGB 图像。

## 模型信息

- 架构:U-Net++(`segmentation-models-pytorch`)+ ResNet50 编码器
- 任务:图像到图像(去除手写笔迹/墨迹)
- 输入:RGB,形状 `[B, 3, H, W]`,数值范围 `[0, 1]`
- 输出:RGB,形状 `[B, 3, H, W]`,数值范围 `[0, 1]`(末端 `sigmoid`## 文件说明

- `config.json`:模型结构与训练超参数(导出时写入)
- `model.safetensors`:推理用权重(推荐)
- `best.ckpt`:原始 PyTorch Lightning checkpoint(用于继续训练/复现实验)
- `configuration.json`:简要元数据(framework/task)

## 快速推理(SafeTensors,推荐)

依赖:`torch``torchvision``segmentation-models-pytorch``safetensors`,以及 `Pillow`(读写图片可选)。

```bash
pip install torch torchvision segmentation-models-pytorch safetensors pillow
```

```python
import json
from pathlib import Path

import torch
import segmentation_models_pytorch as smp
from safetensors.torch import load_file
from PIL import Image
import torchvision.transforms.functional as TF

device = "cuda" if torch.cuda.is_available() else "cpu"

# 1) 读取配置
cfg = json.loads(Path("config.json").read_text(encoding="utf-8"))

# 2) 构建网络(与导出配置保持一致)
model = smp.UnetPlusPlus(
    encoder_name=cfg["encoder_name"],
    encoder_weights=None,  # 权重来自 model.safetensors
    in_channels=cfg["in_channels"],
    classes=cfg["classes"],
    decoder_attention_type=cfg.get("decoder_attention_type"),
    activation=cfg.get("activation"),  # 通常为 "sigmoid"
).to(device)

# 3) 加载权重
# 说明:导出时可能混入非网络权重(例如 `edge_loss.kx/ky`),推理只需要 Unet++ 本体参数,过滤掉即可。
state_dict = load_file("model.safetensors")
model_keys = set(model.state_dict().keys())
state_dict = {k: v for k, v in state_dict.items() if k in model_keys}
model.load_state_dict(state_dict, strict=True)
model.eval()

# 4) 准备输入(训练时仅做 0~1 归一化;如需更贴近训练分布可 resize 到 512x512)
img = Image.open("input.png").convert("RGB")
x = TF.to_tensor(img).unsqueeze(0).to(device)  # [1,3,H,W] in [0,1]

with torch.no_grad():
    y = model(x).clamp(0, 1)  # [1,3,H,W]

out = TF.to_pil_image(y.squeeze(0).cpu())
out.save("output.png")
```

提示:若输入尺寸不是 32 的倍数,部分编码器结构可能要求先 `pad/resize` 到合适尺寸(例如 `512x512`)。

也可以直接使用本项目提供的高清切块推理脚本(自动对大图切块并融合回原图),从项目根目录运行:

```bash
python infer_hd.py --model-dir assets/InkErase --input input.png --output output.png
```

## 使用 `best.ckpt`(继续训练/复现实验)

`best.ckpt` 是 PyTorch Lightning checkpoint,通常需要配合本项目的 `InkEraserModel` 代码使用,并提供 ResNet50 预训练权重文件(例如 `pretrained_weights/resnet50-0676ba61.pth`)。

```python
import torch
from model import InkEraserModel

model = InkEraserModel.load_from_checkpoint(
    "best.ckpt",
    weight="pretrained_weights/resnet50-0676ba61.pth",
)
model.eval()

with torch.no_grad():
    y = model(x)
```

## 训练超参数(来自 `config.json`)

以下参数主要用于训练/复现,推理不必关心:

```json
{
  "lr": 0.0001,
  "weight_decay": 0.01,
  "loss_w_charb": 0.78,
  "loss_w_ssim": 0.16,
  "loss_w_edge": 0.06,
  "use_mask_loss": true,
  "loss_mask_weight": 10.0,
  "charbonnier_eps": 0.001
}
```

## 许可证

MIT