qox committed on
Commit
eba17eb
·
verified ·
1 Parent(s): c02b9f9

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +460 -0
inference.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VN Address Normalizer — Standalone Inference
3
+ ============================================
4
+ No FST, no vietnam-provinces. Runs standalone on any machine with:
5
+ pip install -r requirements.txt
6
+
7
+ Usage (CLI):
8
+ python inference.py "p tan dinh q1 tphcm"
9
+
10
+ Usage (import):
11
+ from inference import normalize
12
+ result = normalize("p tan dinh q1 tphcm")
13
+ print(result["canonical"])
14
+ """
15
+
16
+ import json, re, time, sys
17
+ import torch, torch.nn as nn, torch.nn.functional as F
18
+ from collections import defaultdict
19
+ from pathlib import Path
20
+ from unidecode import unidecode
21
+
22
# Directory holding the model artifacts. It is resolved relative to this file,
# not the current working directory, so the script can be run from anywhere.
_HERE = Path(__file__).resolve().parent
MODEL_DIR = _HERE / "model_v3_final"
23
+
24
def slug(s: str) -> str:
    """Return *s* passed through ``unidecode``, lower-cased and trimmed.

    This is the key form used by every lookup table in this module.
    """
    ascii_text = unidecode(s)
    return ascii_text.lower().strip()
26
+
27
# ── Load artifacts ────────────────────────────────────────────────────────────
def _load_json(name: str):
    """Parse the JSON artifact *name* located in ``MODEL_DIR``.

    The file is opened as UTF-8 inside a ``with`` block so the handle is
    closed as soon as parsing finishes instead of being left to the garbage
    collector.
    """
    with open(MODEL_DIR / name, encoding="utf-8") as fh:
        return json.load(fh)


cfg = _load_json("config.json")
src_vocab = _load_json("src_vocab.json")
tgt_vocab = _load_json("tgt_vocab.json")
clean = _load_json("clean_canonicals.json")
legacy_idx = _load_json("legacy_ward_idx.json")

# Symbol → id tables: an entry's id is its position when iterating the vocab.
src_ch2id = {c: i for i, c in enumerate(src_vocab)}
tgt_ch2id = {c: i for i, c in enumerate(tgt_vocab)}
# Reserved ids of the special tokens (identical layout for source and target).
SRC_PAD, SRC_UNK, SRC_BOS, SRC_EOS = 0, 1, 2, 3
TGT_PAD, TGT_UNK, TGT_BOS, TGT_EOS = 0, 1, 2, 3

print(f"Canonicals: {len(clean):,}", flush=True)
40
+
41
# ── Build indexes from clean_canonicals.json (no FST) ─────────────────────────
prov_to_c = defaultdict(list)   # province_name → [canonical, ...]
pw_to_c = defaultdict(list)     # (prov, ward_slug) → [canonical, ...]
ward_idx = defaultdict(list)    # ward_slug → [canonical, ...]
ps = {}                         # province_slug → canonical_province_name

for _c in clean:
    # The last two comma-separated fields of a canonical are read as
    # "<ward>, <province>"; entries with fewer than two fields are skipped.
    _parts = [p.strip() for p in _c.split(",")]
    if len(_parts) < 2:
        continue
    _prov = _parts[-1]
    _ward_part = _parts[-2]
    _ps = slug(_prov)

    # Register the province under its full slug and, when it carries an
    # administrative prefix, also under the bare name.
    ps[_ps] = _prov
    _stripped = re.sub(r"^(tinh|thanh pho|tp\.?)\s*", "", _ps).strip()
    if _stripped != _ps:
        ps[_stripped] = _prov

    prov_to_c[_prov].append(_c)

    _ward_slug = slug(_ward_part)
    _ward_bare = re.sub(
        r"^(phuong|xa|thi tran|dac khu)\s+", "", _ward_slug).strip()
    # dict.fromkeys de-duplicates while keeping order: when the ward has no
    # prefix both keys are equal, and the canonical must be indexed only once
    # (otherwise candidate lists and the reported search space are inflated).
    for _ws in dict.fromkeys((_ward_slug, _ward_bare)):
        pw_to_c[(_prov, _ws)].append(_c)
        ward_idx[_ws].append(_c)
66
+
67
# ── Province aliases (historical / colloquial names) ──────────────────────────
# Maps colloquial abbreviations and pre-merger province slugs to the slug of
# the province they belong to now. Values are matched as substrings of the
# province slugs in ``ps`` by ``_resolve_prov``.
# Merger targets follow the 2025 provincial reorganisation; Dong Thap is not
# listed because it still exists as a province (it absorbed Tien Giang).
_OLD = {
    # Colloquial names / abbreviations
    "hcm": "ho chi minh", "tphcm": "ho chi minh",
    "saigon": "ho chi minh", "sai gon": "ho chi minh",
    "hanoi": "ha noi",
    "thua thien hue": "hue", "tt hue": "hue",
    # Former provinces → province that absorbed them
    "ha giang": "tuyen quang", "yen bai": "lao cai",
    "bac kan": "thai nguyen", "vinh phuc": "phu tho",
    "hoa binh": "phu tho", "bac giang": "bac ninh",
    "thai binh": "hung yen", "hai duong": "hai phong",
    "ha nam": "ninh binh", "nam dinh": "ninh binh",
    "quang binh": "quang tri", "quang nam": "da nang",
    "kon tum": "quang ngai", "binh dinh": "gia lai",
    "phu yen": "dak lak", "ninh thuan": "khanh hoa",
    "dak nong": "lam dong", "binh thuan": "lam dong",
    "binh phuoc": "dong nai",
    "binh duong": "ho chi minh", "ba ria vung tau": "ho chi minh",
    "brvt": "ho chi minh", "vung tau": "ho chi minh",
    "long an": "tay ninh", "tien giang": "dong thap",
    "ben tre": "vinh long", "tra vinh": "vinh long",
    "kien giang": "an giang",
    "hau giang": "can tho", "soc trang": "can tho",
    "bac lieu": "ca mau",
}
90
+
91
+
92
def _resolve_prov(ts: str):
    """Map a slugged province mention to its canonical province name.

    Three spellings of *ts* are tried in order: as given, with a leading
    administrative prefix removed, and with all dots and whitespace removed.
    Each is looked up directly in ``ps`` and then through the ``_OLD`` alias
    table. If none resolves, the prefix-stripped form is compared against
    every province slug by substring containment in either direction.

    Returns the canonical province name, or ``None`` when nothing matches.
    """
    without_prefix = re.sub(r"^(tinh|tp\.?\s*|thanh pho)\s+", "", ts).strip()
    compact = re.sub(r"[.\s]", "", ts)

    for candidate in (ts, without_prefix, compact):
        if candidate in ps:
            return ps[candidate]
        target = _OLD.get(candidate)
        if not target:
            continue
        # The alias value is a bare name; take the first province slug that
        # contains it.
        for prov_slug, prov_name in ps.items():
            if target in prov_slug:
                return prov_name

    # Fuzzy fallback, skipped for very short strings to limit false hits.
    if without_prefix and len(without_prefix) > 2:
        for prov_slug, prov_name in ps.items():
            if without_prefix in prov_slug or prov_slug in without_prefix:
                return prov_name
    return None
107
+
108
+
109
# ── Address component parser (inlined — no normalizer.py dependency) ──────────
# _WARD_PFX / _PROV_PFX / _DIST_PFX operate on raw Vietnamese text (one
# comma-separated part at a time). Each lists the accented and the ASCII
# spelling of a prefix. The accented spellings must be real Vietnamese
# characters: a mis-encoded literal silently never matches.
# NOTE(review): in _DIST_PFX the dot after "q" / "h" is optional, so any part
# that merely starts with one of those letters is classified as a district.
_WARD_PFX = re.compile(
    r"^(phường|phuong|ph\.|p\.|x\xe3|xa|x\."
    r"|đặc\s*khu|dk\.?)\s*", re.I)
_PROV_PFX = re.compile(
    r"^(tỉnh|tinh|th\xe0nh\s*phố|thanh\s*pho|tp\.?|t\.p\.?)\s*", re.I)
_DIST_PFX = re.compile(
    r"^(quận|quan|q\.?|huyện|huyen|h\.?|tx\.?)\s*", re.I)
# Leading house number such as "12", "12a" or "12/3b", followed by the rest.
_NUM_STR = re.compile(r"^(\d+[a-z]?(?:/\d+[a-z]?)*)[\s,]+(.+)", re.I)

# _NC_* operate on slug text (unidecode+lower — no diacritics)
_NC_PROV = re.compile(
    r"\b(tphcm|hcm|hanoi|saigon|sai gon"
    r"|ho chi minh|hai phong|da nang|can tho|hue"
    r"|tp\s+[\w\s]{1,20}|tinh\s+[\w\s]{1,20})\b", re.I)
_NC_DIST = re.compile(r"\b(q\.?\s*\d+|quan\s*\d+|h\.\s*\w+|huyen\s+\w+)\b", re.I)
_NC_WARD = re.compile(r"^(phuong|xa|tt|p\.\s*|x\.\s*)([\w][\w\s]*)", re.I)
127
+
128
+
129
def _extract(raw: str) -> dict:
    """Split a comma/semicolon-separated address into its components.

    A leading house number on the first part is dropped. Each part is then
    classified by prefix, in the order province, district, ward. A part with
    no recognised prefix becomes the ward only if no ward has been found yet.
    When no part carries a province prefix and there are at least two parts,
    the last part is taken as the province.

    Returns a dict with keys ``ward``, ``province`` and ``district_hint``;
    a missing component is ``None``.
    """
    pieces = [seg.strip() for seg in re.split(r"[,;]", raw)]
    pieces = [seg for seg in pieces if seg]

    if pieces:
        numbered = _NUM_STR.match(pieces[0])
        if numbered:
            pieces[0] = numbered.group(2)

    ward = province = district = None
    for piece in pieces:
        if _PROV_PFX.match(piece):
            province = _PROV_PFX.sub("", piece).strip()
        elif _DIST_PFX.match(piece):
            # The district keeps its prefix; it is only used as a hint.
            district = piece
        elif _WARD_PFX.match(piece):
            ward = _WARD_PFX.sub("", piece).strip()
        elif not ward:
            ward = piece

    if not province and len(pieces) >= 2:
        province = pieces[-1]

    return {"ward": ward, "province": province, "district_hint": district}
145
+
146
+
147
def _parse_no_comma(raw: str) -> dict:
    """Pull components out of an address that has no separators.

    The input is slugged first. The first province match and then the first
    district match are recorded and cut out of the text. Whatever remains is
    the ward, with a leading ward prefix removed when one is present.

    Returns a dict with keys ``ward``, ``province`` and ``district_hint``.
    """
    result = {"ward": None, "province": None, "district_hint": None}
    remaining = slug(raw)

    for field, pattern in (("province", _NC_PROV), ("district_hint", _NC_DIST)):
        found = pattern.search(remaining)
        if found is None:
            continue
        result[field] = found.group(0)
        # Replace the matched span with a space so its neighbours stay apart.
        remaining = (remaining[:found.start()] + " "
                     + remaining[found.end():]).strip()

    remaining = remaining.strip()
    ward = _NC_WARD.match(remaining)
    result["ward"] = remaining if ward is None else ward.group(2).strip()
    return result
163
+
164
+
165
def detect_prov(raw: str):
    """Return the canonical province name found in *raw*, or ``None``.

    The address is parsed with the comma-aware parser when it contains a
    comma, otherwise with the separator-free parser. The parsed province and
    then the district hint are tried in turn. If neither resolves, the whole
    slugged input is tried as a last resort.
    """
    parser = _extract if "," in raw else _parse_no_comma
    comps = parser(raw)

    for hint in (comps.get("province"), comps.get("district_hint")):
        if not hint:
            continue
        resolved = _resolve_prov(slug(hint))
        if resolved:
            return resolved

    return _resolve_prov(slug(raw))
174
+
175
+
176
# ── Ward hint extractor ───────────────────────────────────────────────────────
# _WS finds a ward prefix anywhere in slug text and captures up to 41
# characters of letters, digits and whitespace after it.
_WS = re.compile(
    r"\b(?:phuong|p\.|p\s|xa|x\.)\s*([a-z0-9][a-z0-9\s]{1,40})",
    re.IGNORECASE,
)
# _NUM matches a token made only of one to three digits.
_NUM = re.compile(r"^\d{1,3}$")
179
+
180
+
181
def detect_ward(raw: str, prov: str):
    """Find a ward mention in *raw* and the canonicals it may refer to.

    Returns a ``(ward_slug, candidates)`` pair:

    * ``(None, None)`` when no ward prefix is present or nothing matches;
    * ``(None, "numbered")`` when the first word after the prefix is a one to
      three digit number;
    * ``(slug, [canonical, ...])`` otherwise.

    The longest phrase is tried first (up to four words after the prefix),
    both as written and with a leading ward prefix removed. A hit within
    *prov* is preferred. Failing that, the province-independent index and the
    legacy index are consulted, keeping only canonicals that contain *prov*
    when a province is given.
    """
    hit = _WS.search(slug(raw))
    if hit is None:
        return None, None

    tokens = hit.group(1).strip().split()
    # Every candidate phrase starts with the same first token, so the
    # numbered-ward test only needs to run once.
    if tokens and _NUM.match(tokens[0]):
        return None, "numbered"

    for size in range(min(4, len(tokens)), 0, -1):
        phrase = " ".join(tokens[:size])
        bare = re.sub(r"^(phuong|xa|thi tran)\s+", "", phrase).strip()
        for key in (phrase, bare):
            if prov:
                scoped = pw_to_c.get((prov, key), [])
                if scoped:
                    return key, scoped
            pooled = ward_idx.get(key, []) + legacy_idx.get(key, [])
            if not pooled:
                continue
            matches = [c for c in pooled if prov in c] if prov else pooled
            if matches:
                return key, matches
    return None, None
203
+
204
+
205
+ # ── Trie ──────────────────────────────────────────────────────────────────────
206
class TrieNode:
    """A single trie node.

    ``children`` maps a character to the child node reached by it.
    ``is_terminal`` starts as False and is set by ``Trie.insert`` on the node
    that ends an inserted string.
    """

    __slots__ = ("children", "is_terminal")

    def __init__(self):
        self.is_terminal = False
        self.children = {}
211
+
212
+
213
class Trie:
    """Character trie over a set of strings."""

    def __init__(self, strings=None):
        """Create the trie, inserting every item of *strings* when given."""
        self.root = TrieNode()
        for item in strings or ():
            self.insert(item)

    def insert(self, s: str):
        """Add *s*, creating nodes as needed and marking its last node."""
        node = self.root
        for ch in s:
            child = node.children.get(ch)
            if child is None:
                child = node.children[ch] = TrieNode()
            node = child
        node.is_terminal = True

    def _walk(self, text: str):
        """Return the node reached by following *text*, or None if it falls off."""
        node = self.root
        for ch in text:
            node = node.children.get(ch)
            if node is None:
                return None
        return node

    def valid_next(self, p: str):
        """Return ``(next_chars, is_terminal)`` for prefix *p*.

        ``next_chars`` is a frozenset of the characters that may follow *p*.
        A prefix that is not in the trie yields ``(frozenset(), False)``.
        """
        node = self._walk(p)
        if node is None:
            return frozenset(), False
        return frozenset(node.children), node.is_terminal

    def accepts(self, s: str) -> bool:
        """Return True only when *s* is exactly one of the inserted strings."""
        node = self._walk(s)
        return False if node is None else node.is_terminal
243
+
244
+
245
# Trie over every canonical; used when no province can be resolved and to
# validate the final result.
full_trie = Trie(clean)
# Per-province tries, built on first use and kept for the process lifetime.
_pt: dict = {}


def get_pt(prov: str) -> Trie:
    """Return the trie of canonicals for *prov*, building it on first request.

    An unknown province yields an empty trie, which is cached as well.
    """
    cached = _pt.get(prov)
    if cached is None:
        cached = _pt[prov] = Trie(prov_to_c.get(prov, []))
    return cached


print("Tries built.", flush=True)
256
+
257
+
258
+ # ── Seq2Seq model ─────────────────────────────────────────────────────────────
259
class S2S(nn.Module):
    """Transformer encoder–decoder whose sizes are read from ``cfg``.

    Both sides use a token embedding plus a learned positional embedding,
    pre-norm layers with GELU activation and dropout 0.1, and a final
    LayerNorm. Token id 0 is the padding id on both sides.
    """

    def __init__(self):
        super().__init__()
        d_model = cfg["D_MODEL"]

        # Source side
        self.src_emb = nn.Embedding(cfg["SRC_VOCAB"], d_model, padding_idx=0)
        self.src_pos = nn.Embedding(cfg["MAX_SRC"], d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model, cfg["N_HEADS"], cfg["D_FF"], 0.1,
            batch_first=True, norm_first=True, activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(enc_layer, cfg["ENC_LAYERS"])
        self.enc_norm = nn.LayerNorm(d_model)

        # Target side
        self.tgt_emb = nn.Embedding(cfg["TGT_VOCAB"], d_model, padding_idx=0)
        self.tgt_pos = nn.Embedding(cfg["MAX_TGT"], d_model)
        dec_layer = nn.TransformerDecoderLayer(
            d_model, cfg["N_HEADS"], cfg["D_FF"], 0.1,
            batch_first=True, norm_first=True, activation="gelu",
        )
        self.decoder = nn.TransformerDecoder(dec_layer, cfg["DEC_LAYERS"])
        self.dec_norm = nn.LayerNorm(d_model)
        self.out_proj = nn.Linear(d_model, cfg["TGT_VOCAB"])

    def encode(self, src):
        """Encode a 2-D tensor of source ids.

        Returns ``(memory, pad_mask)``, where ``pad_mask`` is True at every
        position whose id is 0.
        """
        _, seq_len = src.shape
        pad_mask = src == 0
        positions = torch.arange(seq_len, device=src.device)
        hidden = self.src_emb(src) + self.src_pos(positions)
        hidden = self.encoder(hidden, src_key_padding_mask=pad_mask)
        return self.enc_norm(hidden), pad_mask

    def step(self, tgt, mem, sp):
        """Return the output logits for the last position of *tgt*.

        *mem* and *sp* are the memory and padding mask from ``encode``. The
        whole prefix is decoded again under a causal mask on every call.
        """
        seq_len = tgt.shape[1]
        causal = nn.Transformer.generate_square_subsequent_mask(
            seq_len, device=tgt.device)
        positions = torch.arange(seq_len, device=tgt.device)
        hidden = self.tgt_emb(tgt) + self.tgt_pos(positions)
        hidden = self.decoder(
            hidden, mem, tgt_mask=causal, memory_key_padding_mask=sp)
        logits = self.out_proj(self.dec_norm(hidden))
        return logits[:, -1, :]
293
+
294
+
295
def _load_model() -> S2S:
    """Build the network and load its weights from ``MODEL_DIR``.

    ``model.safetensors`` is preferred. If it is missing, or loading it fails
    for any reason, ``model_best.pt`` is used instead.

    Raises:
        FileNotFoundError: neither weight file could be used.
    """
    net = S2S()
    st_path = MODEL_DIR / "model.safetensors"
    pt_path = MODEL_DIR / "model_best.pt"

    if st_path.exists():
        try:
            # Imported lazily so the .pt path works without safetensors.
            from safetensors.torch import load_file
            net.load_state_dict(load_file(str(st_path)))
            print("Model loaded (safetensors).", flush=True)
        except Exception as e:
            # Deliberate best-effort: report and fall through to the .pt file.
            print(f"safetensors failed ({e}), trying .pt", flush=True)
        else:
            return net

    if not pt_path.exists():
        raise FileNotFoundError(
            f"No model weights in {MODEL_DIR}. "
            "Expected model.safetensors or model_best.pt.")

    state = torch.load(str(pt_path), map_location="cpu", weights_only=True)
    net.load_state_dict(state)
    print("Model loaded (.pt).", flush=True)
    return net


# Loaded once at import time; eval() switches off dropout for inference.
model = _load_model()
model.eval()
319
+
320
+
321
def enc_src(text: str) -> list:
    """Turn *text* into a list of source ids of length ``cfg["MAX_SRC"]``.

    Layout: BOS, one id per character (unknown characters become UNK), EOS,
    then PAD up to the fixed length. Text is cut so BOS and EOS always fit.
    """
    max_len = cfg["MAX_SRC"]
    ids = [SRC_BOS]
    ids.extend(src_ch2id.get(ch, SRC_UNK) for ch in text[:max_len - 2])
    ids.append(SRC_EOS)
    ids.extend([SRC_PAD] * (max_len - len(ids)))
    return ids
326
+
327
+
328
def beam_search(mem, sp, trie: Trie, B: int = 5, maxs: int = 96):
    """Trie-constrained beam search over the decoder.

    Only characters that keep the prefix inside *trie* are expanded, so every
    returned string is one that the trie accepts.

    Args:
        mem, sp: memory and padding mask returned by ``model.encode``.
        trie: the set of strings the output must belong to.
        B: beam width, applied per hypothesis and to the merged beam.
        maxs: step limit; at most ``maxs - 1`` expansion rounds are run.

    Returns:
        ``(text, score)`` of the best finished hypothesis, where *score* is a
        sum of log-probabilities, or ``("", 0.)`` if none finished.
    """
    def by_score(entry):
        return entry[0]

    device = mem.device
    live = [(0., "", [TGT_BOS])]    # (score, text so far, token ids)
    finished = []                   # (score, text)

    for _ in range(maxs - 1):
        if not live:
            break
        pool = []
        for score, text, token_ids in live:
            next_chars, is_complete = trie.valid_next(text)
            if is_complete and not next_chars:
                # Complete string with no continuation: finish without EOS.
                finished.append((score, text))
                continue

            step_in = torch.tensor([token_ids], dtype=torch.long, device=device)
            with torch.no_grad():
                logp = F.log_softmax(model.step(step_in, mem, sp)[0], dim=-1)

            options = []
            if is_complete:
                options.append((score + logp[TGT_EOS].item(), text,
                                token_ids + [TGT_EOS], True))
            for ch in next_chars:
                tok = tgt_ch2id.get(ch)
                if tok is None:
                    # Character unknown to the target vocab: cannot be emitted.
                    continue
                options.append((score + logp[tok].item(), text + ch,
                                token_ids + [tok], False))

            # A complete prefix always has the EOS option, so an empty list
            # means the hypothesis is a dead end and is dropped.
            if not options:
                continue

            options.sort(key=by_score, reverse=True)
            for new_score, new_text, new_ids, ended in options[:B]:
                if ended:
                    finished.append((new_score, new_text))
                else:
                    pool.append((new_score, new_text, new_ids))

        pool.sort(key=by_score, reverse=True)
        live = pool[:B]

    # Hypotheses still alive at the step limit count if they are complete.
    for score, text, _ in live:
        _, is_complete = trie.valid_next(text)
        if is_complete:
            finished.append((score, text))

    if not finished:
        return "", 0.
    best_score, best_text = max(finished, key=by_score)
    return best_text, best_score
371
+
372
+
373
+ # ── Public API ────────────────────────────────────────────────────────────────
374
def normalize(raw: str, beam_size: int = 5) -> dict:
    """
    Normalize a Vietnamese address string.

    Args:
        raw: Raw address string, e.g. "p tan dinh q1 tphcm".
            Accepts Vietnamese diacritics or ASCII-slugified input.
            Truncated to 300 characters if longer.
        beam_size: Beam width. Higher = better accuracy, slower (default 5).

    Returns:
        dict:
            canonical (str) — normalized address; empty if not found
            valid (bool) — True if canonical is in the address database
            confidence (float) — raw log-prob score (higher = more confident)
            province (str) — resolved province name, or None
            ward_hint (str) — detected ward slug, or None
            search_space (int) — number of trie candidates searched
            latency_ms (float) — wall-clock time in milliseconds

        Blank input and numbered wards (e.g. "phuong 5") give an empty
        canonical with valid=False, confidence=0 and search_space=0.
    """
    if not raw or not raw.strip():
        return {
            "canonical": "", "valid": False, "confidence": 0.,
            "province": None, "ward_hint": None,
            "search_space": 0, "latency_ms": 0.,
        }

    raw = raw.strip()[:300]
    t0 = time.perf_counter()

    # Rule-based detection runs first: it is cheap, and a numbered ward exits
    # here without paying for an encoder forward pass.
    prov = detect_prov(raw)
    ward_hint = None
    ward_c = None

    if prov:
        ward_hint, ward_c = detect_ward(raw, prov)
        if ward_c == "numbered":
            return {
                "canonical": "", "valid": False, "confidence": 0.,
                "province": prov, "ward_hint": None,
                "search_space": 0,
                "latency_ms": round((time.perf_counter() - t0) * 1e3, 1),
            }

    # Use the narrowest candidate set available: ward hits, then the whole
    # province, then every canonical.
    if ward_hint and isinstance(ward_c, list) and ward_c:
        trie = Trie(ward_c)
        n = len(ward_c)
    elif prov and prov_to_c.get(prov):
        trie = get_pt(prov)
        n = len(prov_to_c[prov])
    else:
        trie = full_trie
        n = len(clean)

    src = torch.tensor([enc_src(raw)], dtype=torch.long)
    with torch.no_grad():
        mem, sp = model.encode(src)

    res, sc = beam_search(mem, sp, trie, B=beam_size)
    ms = round((time.perf_counter() - t0) * 1e3, 1)

    return {
        "canonical": res,
        "valid": bool(res and full_trie.accepts(res)),
        "confidence": round(float(sc), 4),
        "province": prov,
        "ward_hint": ward_hint,
        "search_space": n,
        "latency_ms": ms,
    }
444
+
445
+
446
+ # ── CLI ───────────────────────────────────────────────────────────────────────
447
if __name__ == "__main__":
    # Command-line entry point: all arguments are joined into one address.
    if len(sys.argv) < 2:
        # Placeholder text is Vietnamese for "address to normalize".
        print("Usage: python inference.py \"địa chỉ cần normalize\"")
        sys.exit(1)

    address = " ".join(sys.argv[1:])
    r = normalize(address)
    print(f"Input: {address}")
    print(f"Canonical: {r['canonical'] or '(not found)'}")
    print(f"Valid: {r['valid']}")
    print(f"Province: {r['province'] or '(unknown)'}")
    print(f"Ward hint: {r['ward_hint'] or '(none)'}")
    print(f"Space: {r['search_space']:,} candidates")
    print(f"Latency: {r['latency_ms']} ms")