Update README.md
Browse files
README.md
CHANGED
|
@@ -1,116 +1,125 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
-
|
| 8 |
-
-
|
| 9 |
-
|
| 10 |
-
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
-
|
| 17 |
-
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
language:
|
| 4 |
+
- zh
|
| 5 |
+
- en
|
| 6 |
+
base_model:
|
| 7 |
+
- microsoft/resnet-50
|
| 8 |
+
pipeline_tag: image-to-image
|
| 9 |
+
---
|
| 10 |
+
# ink-eraser-latest(手写墨迹擦除模型)
|
| 11 |
+
|
| 12 |
+
本目录是一个用于“手写墨迹擦除 / 文档去涂写”的模型导出包(Hugging Face 兼容格式)。模型输入为带墨迹的 RGB 图像,输出为去除墨迹后的 RGB 图像。
|
| 13 |
+
|
| 14 |
+
## 模型信息
|
| 15 |
+
|
| 16 |
+
- 架构:U-Net++(`segmentation-models-pytorch`)+ ResNet50 编码器
|
| 17 |
+
- 任务:图像到图像(去除手写笔迹/墨迹)
|
| 18 |
+
- 输入:RGB,形状 `[B, 3, H, W]`,数值范围 `[0, 1]`
|
| 19 |
+
- 输出:RGB,形状 `[B, 3, H, W]`,数值范围 `[0, 1]`(末端 `sigmoid`)
|
| 20 |
+
|
| 21 |
+
## 文件说明
|
| 22 |
+
|
| 23 |
+
- `config.json`:模型结构与训练超参数(导出时写入)
|
| 24 |
+
- `model.safetensors`:推理用权重(推荐)
|
| 25 |
+
- `best.ckpt`:原始 PyTorch Lightning checkpoint(用于继续训练/复现实验)
|
| 26 |
+
- `configuration.json`:简要元数据(framework/task)
|
| 27 |
+
|
| 28 |
+
## 快速推理(SafeTensors,推荐)
|
| 29 |
+
|
| 30 |
+
依赖:`torch`、`torchvision`、`segmentation-models-pytorch`、`safetensors`,以及 `Pillow`(读写图片可选)。
|
| 31 |
+
|
| 32 |
+
```bash
|
| 33 |
+
pip install torch torchvision segmentation-models-pytorch safetensors pillow
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
```python
|
| 37 |
+
import json
|
| 38 |
+
from pathlib import Path
|
| 39 |
+
|
| 40 |
+
import torch
|
| 41 |
+
import segmentation_models_pytorch as smp
|
| 42 |
+
from safetensors.torch import load_file
|
| 43 |
+
from PIL import Image
|
| 44 |
+
import torchvision.transforms.functional as TF
|
| 45 |
+
|
| 46 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 47 |
+
|
| 48 |
+
# 1) 读取配置
|
| 49 |
+
cfg = json.loads(Path("config.json").read_text(encoding="utf-8"))
|
| 50 |
+
|
| 51 |
+
# 2) 构建网络(与导出配置保持一致)
|
| 52 |
+
model = smp.UnetPlusPlus(
|
| 53 |
+
encoder_name=cfg["encoder_name"],
|
| 54 |
+
encoder_weights=None, # 权重来自 model.safetensors
|
| 55 |
+
in_channels=cfg["in_channels"],
|
| 56 |
+
classes=cfg["classes"],
|
| 57 |
+
decoder_attention_type=cfg.get("decoder_attention_type"),
|
| 58 |
+
activation=cfg.get("activation"), # 通常为 "sigmoid"
|
| 59 |
+
).to(device)
|
| 60 |
+
|
| 61 |
+
# 3) 加载权重
|
| 62 |
+
# 说明:导出时可能混入非网络权重(例如 `edge_loss.kx/ky`),推理只需要 Unet++ 本体参数,过滤掉即可。
|
| 63 |
+
state_dict = load_file("model.safetensors")
|
| 64 |
+
model_keys = set(model.state_dict().keys())
|
| 65 |
+
state_dict = {k: v for k, v in state_dict.items() if k in model_keys}
|
| 66 |
+
model.load_state_dict(state_dict, strict=True)
|
| 67 |
+
model.eval()
|
| 68 |
+
|
| 69 |
+
# 4) 准备输入(训练时仅做 0~1 归一化;如需更贴近训练分布可 resize 到 512x512)
|
| 70 |
+
img = Image.open("input.png").convert("RGB")
|
| 71 |
+
x = TF.to_tensor(img).unsqueeze(0).to(device) # [1,3,H,W] in [0,1]
|
| 72 |
+
|
| 73 |
+
with torch.no_grad():
|
| 74 |
+
y = model(x).clamp(0, 1) # [1,3,H,W]
|
| 75 |
+
|
| 76 |
+
out = TF.to_pil_image(y.squeeze(0).cpu())
|
| 77 |
+
out.save("output.png")
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
提示:若输入尺寸不是 32 的倍数,部分编码器结构可能要求先 `pad/resize` 到合适尺寸(例如 `512x512`)。
|
| 81 |
+
|
| 82 |
+
也可以直接使用本项目提供的高清切块推理脚本(自动对大图切块并融合回原图),从项目根目录运行:
|
| 83 |
+
|
| 84 |
+
```bash
|
| 85 |
+
python infer_hd.py --model-dir assets/InkErase --input input.png --output output.png
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
## 使用 `best.ckpt`(继续训练/复现实验)
|
| 89 |
+
|
| 90 |
+
`best.ckpt` 是 PyTorch Lightning checkpoint,通常需要配合本项目的 `InkEraserModel` 代码使用,并提供 ResNet50 预训练权重文件(例如 `pretrained_weights/resnet50-0676ba61.pth`)。
|
| 91 |
+
|
| 92 |
+
```python
|
| 93 |
+
import torch
|
| 94 |
+
from model import InkEraserModel
|
| 95 |
+
|
| 96 |
+
model = InkEraserModel.load_from_checkpoint(
|
| 97 |
+
"best.ckpt",
|
| 98 |
+
weight="pretrained_weights/resnet50-0676ba61.pth",
|
| 99 |
+
)
|
| 100 |
+
model.eval()
|
| 101 |
+
|
| 102 |
+
with torch.no_grad():
|
| 103 |
+
y = model(x)
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
## 训练超参数(来自 `config.json`)
|
| 107 |
+
|
| 108 |
+
以下参数主要用于训练/复现,推理不必关心:
|
| 109 |
+
|
| 110 |
+
```json
|
| 111 |
+
{
|
| 112 |
+
"lr": 0.0001,
|
| 113 |
+
"weight_decay": 0.01,
|
| 114 |
+
"loss_w_charb": 0.78,
|
| 115 |
+
"loss_w_ssim": 0.16,
|
| 116 |
+
"loss_w_edge": 0.06,
|
| 117 |
+
"use_mask_loss": true,
|
| 118 |
+
"loss_mask_weight": 10.0,
|
| 119 |
+
"charbonnier_eps": 0.001
|
| 120 |
+
}
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
## 许可证
|
| 124 |
+
|
| 125 |
+
MIT
|