raycosine commited on
Commit
f158b5d
·
1 Parent(s): 4fd08f0

new augmentation

Browse files
Files changed (4) hide show
  1. app.py +31 -145
  2. features.py +24 -2
  3. features_preproc.py +57 -14
  4. requirements.txt +0 -6
app.py CHANGED
@@ -2,41 +2,10 @@ import gradio as gr, numpy as np
2
  from PIL import Image, ImageOps, ImageDraw, ImageFont
3
  from pathlib import Path
4
  import os, requests
5
- from features import binarize, feat_vec, cosine_sim, stroke_normalize
6
- from features import _ensure_ink_true
7
  from features_preproc import crop_and_center as crop_ref, LO
8
- from skimage.morphology import binary_dilation, disk
9
- import torch
10
- from annoy import AnnoyIndex
11
  from huggingface_hub import hf_hub_download
12
  ASSET_REPO = "raycosine/detangutify-data"
13
-
14
- EMBED_DIM = 128
15
- EMBEDDER_PATH = hf_hub_download(repo_id=ASSET_REPO, repo_type="dataset",
16
- filename="tangut_embedder.torchscript")
17
- CPS_PATH = hf_hub_download(repo_id=ASSET_REPO, repo_type="dataset",
18
- filename="tangut_cps.npy")
19
- EMBEDS_PATH = hf_hub_download(repo_id=ASSET_REPO, repo_type="dataset",
20
- filename="tangut_embeds.npy")
21
- ANNOY_PATH = hf_hub_download(repo_id=ASSET_REPO, repo_type="dataset",
22
- filename="tangut_index.ann")
23
-
24
- USE_CNN_EMB = os.path.exists(EMBEDDER_PATH) and os.path.exists(CPS_PATH) \
25
- and os.path.exists(EMBEDS_PATH) and os.path.exists(ANNOY_PATH)
26
-
27
-
28
- if USE_CNN_EMB:
29
- EMBEDDER = torch.jit.load(EMBEDDER_PATH, map_location="cpu").eval()
30
- CPS = np.load(CPS_PATH)
31
- E_TEMPL = np.load(EMBEDS_PATH)
32
- ANN = AnnoyIndex(EMBED_DIM, 'angular')
33
- ANN.load(ANNOY_PATH)
34
-
35
- def to_embed(bw_float01: np.ndarray) -> np.ndarray:
36
- x = torch.from_numpy(bw_float01[None, None, ...]).float()
37
- with torch.no_grad():
38
- e = EMBEDDER(x).detach().cpu().numpy()[0]
39
- return e.astype(np.float32)
40
  FONT_PATH = "data/NotoSerifTangut-Regular.ttf"
41
  URL = "https://notofonts.github.io/tangut/fonts/NotoSerifTangut/full/ttf/NotoSerifTangut-Regular.ttf"
42
 
@@ -45,10 +14,13 @@ if not os.path.exists(FONT_PATH):
45
  r = requests.get(URL)
46
  with open(FONT_PATH, "wb") as f:
47
  f.write(r.content)
48
-
49
- DATA = np.load("data/templates_aug.npz")
 
50
  X = DATA["X"]
51
  Y = DATA["y"]
 
 
52
  SIZE = 64
53
 
54
 
@@ -116,122 +88,36 @@ def infer(img):
116
  if arr.dtype != np.uint8:
117
  arr = np.clip(arr, 0, 255).astype(np.uint8)
118
 
119
- bw0 = binarize(arr, keep_largest=False, min_size=3)
120
- #bw0 = binary_dilation(bw0, disk(1))
121
- bw0 = _ensure_ink_true(bw0)
 
 
 
 
 
122
  bw = crop_ref(bw0, out_size=LO, margin_frac=0.08)
123
  bw = stroke_normalize(bw, target_px=3)
124
  viz_img = Image.fromarray((bw*255).astype(np.uint8))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
 
127
-
128
- if USE_CNN_EMB:
129
- e = to_embed(bw).astype(np.float32)
130
- e = e / (np.linalg.norm(e) + 1e-8)
131
- K = 800
132
- idxs = ANN.get_nns_by_vector(e.tolist(), K, include_distances=False)
133
- cnn_cos = (E_TEMPL[idxs] @ e).astype(np.float32)
134
- from skimage.morphology import skeletonize
135
- import numpy as _np
136
- from scipy.signal import convolve2d
137
- def preprocess_glyph_for_rank(cp:int) -> _np.ndarray:
138
- gimg = render_glyph(cp)
139
- garr = _np.array(gimg, dtype=_np.uint8)
140
- gbw0 = binarize(garr, keep_largest=False, min_size=3)
141
- #gbw0 = binary_dilation(gbw0, disk(1))
142
- gbw0 = _ensure_ink_true(gbw0)
143
- gbw = crop_ref(gbw0, out_size=LO, margin_frac=0.08)
144
- gbw = stroke_normalize(gbw, target_px=3) > 0.5
145
- return gbw
146
-
147
- def skel_bool(bw_bool:_np.ndarray):
148
- return skeletonize(bw_bool.astype(bool))
149
- from skimage.morphology import skeletonize
150
- from scipy.ndimage import distance_transform_edt
151
-
152
- q_bool = (bw > 0.5)
153
- q_skel = skeletonize(q_bool)
154
-
155
- from features import feat_vec, cosine_sim
156
- q_shape = feat_vec(q_bool.astype(np.float32))
157
-
158
- def overlap_score(a, b):
159
- inter = (a & b).sum()
160
- denom = max(1, min(a.sum(), b.sum()))
161
- return float(inter) / float(denom)
162
-
163
- def block_occ(bw, m=8):
164
- H, W = bw.shape
165
- ys = np.array_split(np.arange(H), m)
166
- xs = np.array_split(np.arange(W), m)
167
- occ = []
168
- for yy in ys:
169
- for xx in xs:
170
- occ.append(bw[np.ix_(yy, xx)].any())
171
- return np.asarray(occ, dtype=np.uint8)
172
-
173
- def iou_bool(a, b):
174
- inter = (a & b).sum()
175
- union = (a | b).sum()
176
- return float(inter) / max(1, union)
177
-
178
- def chamfer_sim(a_bool, b_bool, gamma=0.5):
179
- if a_bool.sum() == 0 or b_bool.sum() == 0:
180
- return 0.0
181
- da = distance_transform_edt(~a_bool)
182
- db = distance_transform_edt(~b_bool)
183
- s1 = np.exp(-gamma * float(db[a_bool].mean()))
184
- s2 = np.exp(-gamma * float(da[b_bool].mean()))
185
- return 0.5 * (s1 + s2)
186
-
187
- q_occ = block_occ(q_bool, m=8)
188
-
189
- K = 800
190
- idxs = ANN.get_nns_by_vector(e.tolist(), K, include_distances=False)
191
- cnn_cos = (E_TEMPL[idxs] @ e).astype(np.float32)
192
-
193
- rank_scores = []
194
- for i, cos_sc in zip(idxs, cnn_cos):
195
- cp_i = int(CPS[i])
196
- gbw = preprocess_glyph_for_rank(cp_i)
197
- gskel = skeletonize(gbw)
198
- ov = overlap_score(q_skel, gskel)
199
- g_shape = feat_vec(gbw.astype(np.float32))
200
- sh_cos = float(cosine_sim(q_shape, np.expand_dims(g_shape,0))[0])
201
- occ_iou = iou_bool(q_occ, block_occ(gbw, m=8))
202
- chf = chamfer_sim(q_skel, gskel, gamma=0.6)
203
-
204
- # 权重:降低 CNN,提升几何一致性;Chamfer 占 0.25
205
- final = 0.30*float(cos_sc) + 0.20*ov + 0.15*sh_cos + 0.10*occ_iou + 0.25*chf
206
- rank_scores.append(final)
207
-
208
- order = np.argsort(-np.asarray(rank_scores))[:10]
209
- idxs = [idxs[i] for i in order]
210
- sims = [float(np.asarray(rank_scores)[i]) for i in order]
211
- gallery_items, results_json = [], []
212
- for idx, sc in zip(idxs, sims):
213
- cp = int(CPS[idx])
214
- glyph_img = render_glyph(cp)
215
- caption = f"U+{cp:05X} {chr(cp)}\nScore: {float(sc):.6f}"
216
- gallery_items.append((glyph_img, caption))
217
- results_json.append({"cp": cp, "char": chr(cp), "score": float(sc)})
218
- return gallery_items, viz_img, results_json
219
- else:
220
- q = feat_vec(bw)
221
- #q = pca_transform(q).astype(np.float32)
222
- s = cosine_sim(q, X).astype(np.float32)
223
- idxs = np.argsort(-s)[:10]
224
-
225
- gallery_items = []
226
- results_json = []
227
- for idx in idxs:
228
- cp = int(Y[idx]); sc = float(s[idx])
229
- glyph_img = render_glyph(cp)
230
- caption = f"U+{cp:05X} {chr(cp)}\nScore: {sc:.6f}"
231
- gallery_items.append((glyph_img, caption))
232
- results_json.append({"cp": cp, "char": chr(cp), "score": sc})
233
- return gallery_items, viz_img, results_json
234
-
235
  with gr.Blocks() as demo:
236
 
237
  gr.Markdown("### Detangutify (Tangut Character classifier)")
 
2
  from PIL import Image, ImageOps, ImageDraw, ImageFont
3
  from pathlib import Path
4
  import os, requests
5
+ from features import binarize, feat_vec, cosine_sim, stroke_normalize, _ensure_ink_true
 
6
  from features_preproc import crop_and_center as crop_ref, LO
 
 
 
7
  from huggingface_hub import hf_hub_download
8
  ASSET_REPO = "raycosine/detangutify-data"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  FONT_PATH = "data/NotoSerifTangut-Regular.ttf"
10
  URL = "https://notofonts.github.io/tangut/fonts/NotoSerifTangut/full/ttf/NotoSerifTangut-Regular.ttf"
11
 
 
14
  r = requests.get(URL)
15
  with open(FONT_PATH, "wb") as f:
16
  f.write(r.content)
17
+ DATA_PATH = hf_hub_download(repo_id=ASSET_REPO, repo_type="dataset",
18
+ filename="templates_aug.npz")
19
+ DATA = np.load(DATA_PATH)
20
  X = DATA["X"]
21
  Y = DATA["y"]
22
+ MEAN = DATA.get("mean", None)
23
+ STD = DATA.get("std", None)
24
  SIZE = 64
25
 
26
 
 
88
  if arr.dtype != np.uint8:
89
  arr = np.clip(arr, 0, 255).astype(np.uint8)
90
 
91
+ #pil = Image.fromarray(arr, mode="L").resize((SIZE, SIZE), Image.BILINEAR)
92
+ #bw = binarize(np.array(pil, dtype=np.uint8))
93
+ #bw = crop_and_center(bw, SIZE)
94
+ #bw = stroke_normalize(bw, target_px=3)
95
+ #bw = crop_ref(bw, out_size=LO, margin_frac=0.08) # 用训练同款
96
+ #bw = stroke_normalize(bw, target_px=2)
97
+ bw0 = binarize(arr)
98
+ bw0 = _ensure_ink_true(bw0)
99
  bw = crop_ref(bw0, out_size=LO, margin_frac=0.08)
100
  bw = stroke_normalize(bw, target_px=3)
101
  viz_img = Image.fromarray((bw*255).astype(np.uint8))
102
+ q = feat_vec(bw)
103
+ if MEAN is not None and STD is not None:
104
+ q = (q - MEAN.ravel()) / STD.ravel()
105
+ s = cosine_sim(q, X)
106
+ idxs = np.argsort(-s)[:10]
107
+ top, sec = float(s[idxs[0]]), float(s[idxs[1]]) if len(idxs)>1 else (float(s[idxs[0]]), -1)
108
+ low_conf = (top < 0.58) or (top - sec < 0.05)
109
+ gallery_items = []
110
+ results_json = []
111
+ for idx in idxs:
112
+ cp = int(Y[idx]); sc = float(s[idx])
113
+ glyph_img = render_glyph(cp)
114
+ #caption = f"U+{cp:05X} {chr(cp)}\nScore: {sc:.6f}"
115
+ caption = f"U+{cp:05X} {chr(cp)}\nScore: {sc:.6f}" + (" ⚠️" if low_conf and idx==idxs[0] else "")
116
+ gallery_items.append((glyph_img, caption))
117
+ results_json.append({"cp": cp, "char": chr(cp), "score": sc})
118
+ return gallery_items, viz_img, results_json
119
 
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  with gr.Blocks() as demo:
122
 
123
  gr.Markdown("### Detangutify (Tangut Character classifier)")
features.py CHANGED
@@ -4,7 +4,24 @@ from skimage.morphology import remove_small_objects
4
  from skimage.feature import hog
5
  from skimage.measure import moments_hu, label
6
  from skimage.morphology import skeletonize, binary_dilation, disk
7
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  def _ensure_ink_true(bw_bool: np.ndarray) -> np.ndarray:
9
  bw = bw_bool.astype(bool)
10
  if bw.mean() > 0.5:
@@ -15,12 +32,17 @@ def stroke_normalize(bw: np.ndarray, target_px: int = 2) -> np.ndarray:
15
  bw = (bw > 0)
16
  if bw.mean() > 0.5:
17
  bw = ~bw
 
18
  skel = skeletonize(bw)
 
 
19
  if target_px <= 1:
20
  return skel.astype(np.float32)
21
  rad = max(1, int(round(target_px/2)))
22
  thick = binary_dilation(skel, disk(rad))
23
- return (thick).astype(np.float32)
 
 
24
  def to_64_gray(imgPIL):
25
  return np.array(imgPIL, dtype=np.uint8)
26
 
 
4
  from skimage.feature import hog
5
  from skimage.measure import moments_hu, label
6
  from skimage.morphology import skeletonize, binary_dilation, disk
7
+ from scipy.ndimage import convolve
8
+ from skimage.morphology import binary_opening
9
+ def _prune_spurs(skel: np.ndarray, iters: int = 2) -> np.ndarray:
10
+ """
11
+ 迭代剪掉骨架上长度很短的端点分支(spur)。
12
+ iters 表示最多向内剪掉几步(像素)。推荐 1~3。
13
+ """
14
+ s = skel.copy().astype(bool)
15
+ # 用 3x3 邻域统计端点:中心权重10,其它1;“10+1=11”即1个邻居的端点
16
+ K = np.array([[1,1,1],
17
+ [1,10,1],
18
+ [1,1,1]], dtype=np.uint8)
19
+ for _ in range(iters):
20
+ nb = convolve(s.astype(np.uint8), K, mode="constant", cval=0)
21
+ endpoints = (nb == 11) # 只有 1 个邻居
22
+ # 只剪 endpoints,不动分叉/主干
23
+ s = s & ~endpoints
24
+ return s
25
  def _ensure_ink_true(bw_bool: np.ndarray) -> np.ndarray:
26
  bw = bw_bool.astype(bool)
27
  if bw.mean() > 0.5:
 
32
  bw = (bw > 0)
33
  if bw.mean() > 0.5:
34
  bw = ~bw
35
+
36
  skel = skeletonize(bw)
37
+ skel = _prune_spurs(skel, iters=2) # ← 新增:剪短刺,去笔锋小尖
38
+
39
  if target_px <= 1:
40
  return skel.astype(np.float32)
41
  rad = max(1, int(round(target_px/2)))
42
  thick = binary_dilation(skel, disk(rad))
43
+ #thick = binary_opening(thick, disk(1))#optional
44
+
45
+ return (thick & bw).astype(np.float32)
46
  def to_64_gray(imgPIL):
47
  return np.array(imgPIL, dtype=np.uint8)
48
 
features_preproc.py CHANGED
@@ -3,7 +3,7 @@ from typing import Tuple
3
  import numpy as np
4
  from skimage.filters import threshold_otsu
5
  from skimage.morphology import remove_small_objects, binary_dilation, square
6
- from skimage.measure import label, moments_hu
7
  from skimage.transform import resize
8
  from skimage.feature import hog
9
 
@@ -15,22 +15,65 @@ def binarize_from_gray01(gray01: np.ndarray, thr: float = 0.5) -> np.ndarray:
15
  g /= 255.0
16
  return (g < thr)
17
 
18
- def binarize_otsu(gray: np.ndarray) -> np.ndarray:
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  g = gray.astype(np.float32)
20
- if g.max() > 1:
21
- g /= 255.0
22
- t = threshold_otsu(g)
23
- bw = g <= t
 
24
  lab = label(bw)
 
25
  if lab.max() > 0:
26
- areas = np.bincount(lab.ravel())
27
- areas[0] = 0
28
- keep = areas.argmax()
29
- bw = (lab == keep)
30
- bw = remove_small_objects(bw.astype(bool), min_size=4).astype(bool)
31
- #bw = binary_dilation(bw, np.ones((2, 2), dtype=bool))
32
- return bw
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def crop_and_center(bw: np.ndarray, out_size: int = LO, margin_frac: float = 0.08) -> np.ndarray:
35
  ys, xs = np.where(bw)
36
  if len(xs) == 0 or len(ys) == 0:
@@ -46,7 +89,7 @@ def crop_and_center(bw: np.ndarray, out_size: int = LO, margin_frac: float = 0.0
46
  pad_x_lft = (side - w) // 2 + margin
47
  pad_x_rgt = side - w - (side - w) // 2 + margin
48
  sq = np.pad(crop, ((pad_y_top, pad_y_bot), (pad_x_lft, pad_x_rgt)), mode='constant')
49
- sq = resize(sq, (out_size, out_size), order=0, anti_aliasing=True, preserve_range=True)
50
  return (sq > 0.5).astype(bool)
51
 
52
  def proj_features(bw: np.ndarray, m: int = 32) -> np.ndarray:
 
3
  import numpy as np
4
  from skimage.filters import threshold_otsu
5
  from skimage.morphology import remove_small_objects, binary_dilation, square
6
+ from skimage.measure import label, moments_hu, regionprops
7
  from skimage.transform import resize
8
  from skimage.feature import hog
9
 
 
15
  g /= 255.0
16
  return (g < thr)
17
 
18
+
19
+
20
+ def binarize_otsu(
21
+ gray: np.ndarray,
22
+ min_size: int = 12,
23
+ dilate_k: int = 2,
24
+ keep: str = "largest", # "largest" | "multi" | "smart"
25
+ area_ratio: float = 0.08, # ↓ 放宽一点
26
+ topk: int = 8, # ↑ 多留一点备选
27
+ horiz_keep_frac: float = 0.50, # ↓ 细长横更容易保留
28
+ vert_keep_frac: float = 0.55, # ↓ 细长竖更容易保留
29
+ ar_keep: float = 3.2, # 新增:细长(长/宽≥ar_keep)也保
30
+ top_edge_frac: float = 0.15 # 新增:靠顶部的细长撇也保(y0<=H*0.15)
31
+ ) -> np.ndarray:
32
  g = gray.astype(np.float32)
33
+ if g.max() > 1: g /= 255.0
34
+ t = threshold_otsu(g)
35
+ bw = (g <= t)
36
+
37
+ bw = remove_small_objects(bw.astype(bool), min_size=min_size).astype(bool)
38
  lab = label(bw)
39
+
40
  if lab.max() > 0:
41
+ areas = np.bincount(lab.ravel()); areas[0] = 0
42
+ if keep == "largest":
43
+ bw = (lab == areas.argmax())
44
+ else:
45
+ props = regionprops(lab)
46
+ H, W = bw.shape
47
+ max_area = areas.max()
48
+ max_w = max([p.bbox[3]-p.bbox[1] for p in props]) if props else 0
49
+ max_h = max([p.bbox[2]-p.bbox[0] for p in props]) if props else 0
50
+
51
+ keep_labels = []
52
+ for p in props:
53
+ k = p.label
54
+ y0, x0, y1, x1 = p.bbox
55
+ w = x1 - x0; h = y1 - y0
56
+ aspect = max(w, h) / max(1, min(w, h)) # 细长度
57
+ near_top = (y0 <= int(H * top_edge_frac))
58
+
59
+ cond_area = (areas[k] >= max_area * area_ratio)
60
+ cond_long = (max_w>0 and w >= max_w*horiz_keep_frac) or (max_h>0 and h >= max_h*vert_keep_frac)
61
+ cond_slim = (aspect >= ar_keep) # 细长撇/挑
62
+ cond_top = near_top and (w >= 0.45*max_w) # 顶边细长撇
63
 
64
+ if cond_area or cond_long or cond_slim or cond_top:
65
+ keep_labels.append(k)
66
+ if len(keep_labels) >= topk:
67
+ break
68
+
69
+ mask = np.zeros_like(bw, dtype=bool)
70
+ for k in keep_labels:
71
+ mask |= (lab == k)
72
+ bw = mask
73
+
74
+ if dilate_k > 0:
75
+ bw = binary_dilation(bw, square(dilate_k))
76
+ return bw
77
  def crop_and_center(bw: np.ndarray, out_size: int = LO, margin_frac: float = 0.08) -> np.ndarray:
78
  ys, xs = np.where(bw)
79
  if len(xs) == 0 or len(ys) == 0:
 
89
  pad_x_lft = (side - w) // 2 + margin
90
  pad_x_rgt = side - w - (side - w) // 2 + margin
91
  sq = np.pad(crop, ((pad_y_top, pad_y_bot), (pad_x_lft, pad_x_rgt)), mode='constant')
92
+ sq = resize(sq, (out_size, out_size), order=1, anti_aliasing=True, preserve_range=True)
93
  return (sq > 0.5).astype(bool)
94
 
95
  def proj_features(bw: np.ndarray, m: int = 32) -> np.ndarray:
requirements.txt CHANGED
@@ -2,9 +2,3 @@ gradio>=4.0.0
2
  numpy
3
  Pillow
4
  scikit-image
5
- torch
6
- torchvision
7
- tqdm
8
- annoy
9
- huggingface_hub
10
-
 
2
  numpy
3
  Pillow
4
  scikit-image