alaajabari commited on
Commit
56abf08
·
verified ·
1 Parent(s): c8d7c34

Create ner.py

Browse files
Files changed (1) hide show
  1. ner.py +106 -0
ner.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ner_engine.py
2
+
3
+ import json
4
+ import pickle
5
+ from collections import namedtuple
6
+
7
+ from huggingface_hub import hf_hub_download, snapshot_download
8
+ from Nested.utils.helpers import load_checkpoint
9
+ from Nested.utils.data import get_dataloaders, text2segments
10
+ from NER_Distiller import distill_entities
11
+
12
+
13
+ # =============================
14
+ # Load model ONCE (important)
15
+ # =============================
16
+ checkpoint_path = snapshot_download(
17
+ repo_id="SinaLab/Nested",
18
+ allow_patterns="checkpoints/"
19
+ )
20
+
21
+ args_path = hf_hub_download(
22
+ repo_id="SinaLab/Nested",
23
+ filename="args.json"
24
+ )
25
+
26
+ with open(args_path, "r") as f:
27
+ args_data = json.load(f)
28
+
29
+ # load vocab
30
+ with open("Nested/utils/tag_vocab.pkl", "rb") as f:
31
+ label_vocab = pickle.load(f)
32
+
33
+ label_vocab = label_vocab[0]
34
+
35
+
36
+ # =============================
37
+ # Load tagger ONCE
38
+ # =============================
39
+ tagger, tag_vocab, train_config = load_checkpoint(checkpoint_path)
40
+
41
+
42
+ # =============================
43
+ # Core NER extraction (your logic preserved)
44
+ # =============================
45
+ def extract(sentence: str):
46
+ dataset, token_vocab = text2segments(sentence)
47
+
48
+ vocabs = namedtuple("Vocab", ["tags", "tokens"])
49
+ vocab = vocabs(tokens=token_vocab, tags=tag_vocab)
50
+
51
+ dataloader = get_dataloaders(
52
+ (dataset,),
53
+ vocab,
54
+ args_data,
55
+ batch_size=32,
56
+ shuffle=(False,),
57
+ )[0]
58
+
59
+ segments = tagger.infer(dataloader)
60
+
61
+ lists = []
62
+
63
+ for segment in segments:
64
+ for token in segment:
65
+ tags = [t["tag"] for t in token.pred_tag]
66
+ tags = [t for t in tags if t not in ("O", " ", "")]
67
+
68
+ lists.append({
69
+ "token": token.text,
70
+ "tags": " ".join(tags) if tags else "O"
71
+ })
72
+
73
+ return lists
74
+
75
+
76
+ # =============================
77
+ # convert format for distiller
78
+ # =============================
79
+ def _to_list_of_lists(json_list):
80
+ return [[d["token"], d["tags"]] for d in json_list]
81
+
82
+
83
+ # =============================
84
+ # FINAL FUNCTION USED BY RE
85
+ # =============================
86
+ def entities_and_types(sentence: str):
87
+ """
88
+ Returns:
89
+ dict: {entity_text: entity_type}
90
+ """
91
+
92
+ ner_output = extract(sentence)
93
+ converted = _to_list_of_lists(ner_output)
94
+
95
+ entities = distill_entities(converted)
96
+
97
+ entity_dict = {}
98
+
99
+ for item in entities:
100
+ # item format: [text, type, start, end]
101
+ if len(item) >= 2:
102
+ entity_text = item[0].strip()
103
+ entity_type = item[1]
104
+ entity_dict[entity_text] = entity_type
105
+
106
+ return entity_dict