ynyg commited on
Commit
6bf5137
·
verified ·
1 Parent(s): ad308d4

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +125 -116
README.md CHANGED
@@ -1,116 +1,125 @@
1
- # ink-eraser-latest(手写墨迹擦除模型)
2
-
3
- 本目录是一个用于“手写墨迹擦除 / 文档去涂写”的模型导出包(Hugging Face 兼容格式)。模型输入为带墨迹的 RGB 图像,输出为去除墨迹后的 RGB 图像。
4
-
5
- ## 模型信息
6
-
7
- - 架构:U-Net++(`segmentation-models-pytorch`)+ ResNet50 编码器
8
- - 任务:图像到图像(去除手写笔迹/墨迹)
9
- - 输入:RGB,形状 `[B, 3, H, W]`,数值范围 `[0, 1]`
10
- - 输出:RGB,形状 `[B, 3, H, W]`,数值范围 `[0, 1]`(末端 `sigmoid`)
11
-
12
- ## 文件说明
13
-
14
- - `config.json`:模型结构与训练超参数(导出时写入)
15
- - `model.safetensors`:推理用权重(推荐)
16
- - `best.ckpt`:原始 PyTorch Lightning checkpoint(用于继续训练/复现实验)
17
- - `configuration.json`:简要元数据(framework/task)
18
-
19
- ## 快速推理(SafeTensors,推荐)
20
-
21
- 依赖:`torch`、`torchvision`、`segmentation-models-pytorch`、`safetensors`,以及 `Pillow`(读写图片可选)。
22
-
23
- ```bash
24
- pip install torch torchvision segmentation-models-pytorch safetensors pillow
25
- ```
26
-
27
- ```python
28
- import json
29
- from pathlib import Path
30
-
31
- import torch
32
- import segmentation_models_pytorch as smp
33
- from safetensors.torch import load_file
34
- from PIL import Image
35
- import torchvision.transforms.functional as TF
36
-
37
- device = "cuda" if torch.cuda.is_available() else "cpu"
38
-
39
- # 1) 读取配置
40
- cfg = json.loads(Path("config.json").read_text(encoding="utf-8"))
41
-
42
- # 2) 构建网络(与导出配置保持一致)
43
- model = smp.UnetPlusPlus(
44
- encoder_name=cfg["encoder_name"],
45
- encoder_weights=None, # 权重来自 model.safetensors
46
- in_channels=cfg["in_channels"],
47
- classes=cfg["classes"],
48
- decoder_attention_type=cfg.get("decoder_attention_type"),
49
- activation=cfg.get("activation"), # 通常为 "sigmoid"
50
- ).to(device)
51
-
52
- # 3) 加载权重
53
- # 说明:导出时可能混入非网络权重(例如 `edge_loss.kx/ky`),推理只需要 Unet++ 本体参数,过滤掉即可。
54
- state_dict = load_file("model.safetensors")
55
- model_keys = set(model.state_dict().keys())
56
- state_dict = {k: v for k, v in state_dict.items() if k in model_keys}
57
- model.load_state_dict(state_dict, strict=True)
58
- model.eval()
59
-
60
- # 4) 准备输入(训练时仅做 0~1 归一化;如需更贴近训练分布可 resize 到 512x512)
61
- img = Image.open("input.png").convert("RGB")
62
- x = TF.to_tensor(img).unsqueeze(0).to(device) # [1,3,H,W] in [0,1]
63
-
64
- with torch.no_grad():
65
- y = model(x).clamp(0, 1) # [1,3,H,W]
66
-
67
- out = TF.to_pil_image(y.squeeze(0).cpu())
68
- out.save("output.png")
69
- ```
70
-
71
- 提示:若输入尺寸不是 32 的倍数,部分编码器结构可能要求先 `pad/resize` 到合适尺寸(例如 `512x512`)。
72
-
73
- 也可以直接使用本项目提供的高清切块推理脚本(自动对大图切块并融合回原图),从项目根目录运行:
74
-
75
- ```bash
76
- python infer_hd.py --model-dir assets/InkErase --input input.png --output output.png
77
- ```
78
-
79
- ## 使用 `best.ckpt`(继续训练/复现实验)
80
-
81
- `best.ckpt` 是 PyTorch Lightning checkpoint,通常需要配合本项目的 `InkEraserModel` 代码使用,并提供 ResNet50 预训练权重文件(例如 `pretrained_weights/resnet50-0676ba61.pth`)。
82
-
83
- ```python
84
- import torch
85
- from model import InkEraserModel
86
-
87
- model = InkEraserModel.load_from_checkpoint(
88
- "best.ckpt",
89
- weight="pretrained_weights/resnet50-0676ba61.pth",
90
- )
91
- model.eval()
92
-
93
- with torch.no_grad():
94
- y = model(x)
95
- ```
96
-
97
- ## 训练超参数(来自 `config.json`)
98
-
99
- 以下参数主要用于训练/复现,推理不必关心:
100
-
101
- ```json
102
- {
103
- "lr": 0.0001,
104
- "weight_decay": 0.01,
105
- "loss_w_charb": 0.78,
106
- "loss_w_ssim": 0.16,
107
- "loss_w_edge": 0.06,
108
- "use_mask_loss": true,
109
- "loss_mask_weight": 10.0,
110
- "charbonnier_eps": 0.001
111
- }
112
- ```
113
-
114
- ## 许可证
115
-
116
- MIT
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - zh
5
+ - en
6
+ base_model:
7
+ - microsoft/resnet-50
8
+ pipeline_tag: image-to-image
9
+ ---
10
+ # ink-eraser-latest(手写墨迹擦除模型)
11
+
12
+ 本目录是一个用于“手写墨迹擦除 / 文档去涂写”的模型导出包(Hugging Face 兼容格式)。模型输入为带墨迹的 RGB 图像,输出为去除墨迹后的 RGB 图像。
13
+
14
+ ## 模型信息
15
+
16
+ - 架构:U-Net++(`segmentation-models-pytorch`)+ ResNet50 编码器
17
+ - 任务:图像到图像(去除手写笔迹/墨迹)
18
+ - 输入:RGB,形状 `[B, 3, H, W]`,数值范围 `[0, 1]`
19
+ - 输出:RGB,形状 `[B, 3, H, W]`,数值范围 `[0, 1]`(末端 `sigmoid`)
20
+
21
+ ## 文件说明
22
+
23
+ - `config.json`:模型结构与训练超参数(导出时写入)
24
+ - `model.safetensors`:推理用权重(推荐)
25
+ - `best.ckpt`:原始 PyTorch Lightning checkpoint(用于继续训练/复现实验)
26
+ - `configuration.json`:简要元数据(framework/task)
27
+
28
+ ## 快速推理(SafeTensors,推荐)
29
+
30
+ 依赖:`torch`、`torchvision`、`segmentation-models-pytorch`、`safetensors`,以及 `Pillow`(读写图片可选)。
31
+
32
+ ```bash
33
+ pip install torch torchvision segmentation-models-pytorch safetensors pillow
34
+ ```
35
+
36
+ ```python
37
+ import json
38
+ from pathlib import Path
39
+
40
+ import torch
41
+ import segmentation_models_pytorch as smp
42
+ from safetensors.torch import load_file
43
+ from PIL import Image
44
+ import torchvision.transforms.functional as TF
45
+
46
+ device = "cuda" if torch.cuda.is_available() else "cpu"
47
+
48
+ # 1) 读取配置
49
+ cfg = json.loads(Path("config.json").read_text(encoding="utf-8"))
50
+
51
+ # 2) 构建网络(与导出配置保持一致)
52
+ model = smp.UnetPlusPlus(
53
+ encoder_name=cfg["encoder_name"],
54
+ encoder_weights=None, # 权重来自 model.safetensors
55
+ in_channels=cfg["in_channels"],
56
+ classes=cfg["classes"],
57
+ decoder_attention_type=cfg.get("decoder_attention_type"),
58
+ activation=cfg.get("activation"), # 通常为 "sigmoid"
59
+ ).to(device)
60
+
61
+ # 3) 加载权重
62
+ # 说明:导出时可能混入非网络权重(例如 `edge_loss.kx/ky`),推理只需要 Unet++ 本体参数,过滤掉即可。
63
+ state_dict = load_file("model.safetensors")
64
+ model_keys = set(model.state_dict().keys())
65
+ state_dict = {k: v for k, v in state_dict.items() if k in model_keys}
66
+ model.load_state_dict(state_dict, strict=True)
67
+ model.eval()
68
+
69
+ # 4) 准备输入(训练时仅做 0~1 归一化;如需更贴近训练分布可 resize 到 512x512)
70
+ img = Image.open("input.png").convert("RGB")
71
+ x = TF.to_tensor(img).unsqueeze(0).to(device) # [1,3,H,W] in [0,1]
72
+
73
+ with torch.no_grad():
74
+ y = model(x).clamp(0, 1) # [1,3,H,W]
75
+
76
+ out = TF.to_pil_image(y.squeeze(0).cpu())
77
+ out.save("output.png")
78
+ ```
79
+
80
+ 提示:若输入尺寸不是 32 的倍数,部分编码器结构可能要求先 `pad/resize` 到合适尺寸(例如 `512x512`)。
81
+
82
+ 也可以直接使用本项目提供的高清切块推理脚本(自动对大图切块并融合回原图),从项目根目录运行:
83
+
84
+ ```bash
85
+ python infer_hd.py --model-dir assets/InkErase --input input.png --output output.png
86
+ ```
87
+
88
+ ## 使用 `best.ckpt`(继续训练/复现实验)
89
+
90
+ `best.ckpt` 是 PyTorch Lightning checkpoint,通常需要配合本项目的 `InkEraserModel` 代码使用,并提供 ResNet50 预训练权重文件(例如 `pretrained_weights/resnet50-0676ba61.pth`)。
91
+
92
+ ```python
93
+ import torch
94
+ from model import InkEraserModel
95
+
96
+ model = InkEraserModel.load_from_checkpoint(
97
+ "best.ckpt",
98
+ weight="pretrained_weights/resnet50-0676ba61.pth",
99
+ )
100
+ model.eval()
101
+
102
+ with torch.no_grad():
103
+ y = model(x)
104
+ ```
105
+
106
+ ## 训练超参数(来自 `config.json`)
107
+
108
+ 以下参数主要用于训练/复现,推理不必关心:
109
+
110
+ ```json
111
+ {
112
+ "lr": 0.0001,
113
+ "weight_decay": 0.01,
114
+ "loss_w_charb": 0.78,
115
+ "loss_w_ssim": 0.16,
116
+ "loss_w_edge": 0.06,
117
+ "use_mask_loss": true,
118
+ "loss_mask_weight": 10.0,
119
+ "charbonnier_eps": 0.001
120
+ }
121
+ ```
122
+
123
+ ## 许可证
124
+
125
+ MIT