akhatre commited on
Commit
dff5567
·
1 Parent(s): ac320eb

minor code changes and readme improvements

Browse files
Files changed (3) hide show
  1. README.md +3 -3
  2. anonymise.py +128 -63
  3. requirements.txt +2 -0
README.md CHANGED
@@ -31,7 +31,7 @@ model-index:
31
  pipeline_tag: token-classification
32
  ---
33
 
34
- # NERPA Fine-Tuned GLiNER2 for PII Anonymisation
35
 
36
  A fine-tuned [GLiNER2 Large](https://huggingface.co/fastino/gliner2-large-v1) (340M params) model trained to detect Personally Identifiable Information (PII) in text. Built as a flexible, self-hosted replacement for AWS Comprehend at [Overmind](https://overmindlab.ai).
37
 
@@ -164,7 +164,7 @@ entities = detect_entities(model, text, entities={
164
 
165
  The inference pipeline in `anonymise.py`:
166
 
167
- 1. **Chunking** — Long texts are split into 3000-character chunks with 100-char overlap to stay within the model's context window. Specific chunk size can be varied since DeBERTa-v2 (underlying encoder) uses relative position encoding. We found that this size works as well as smaller ones.
168
  2. **Batch prediction** — Chunks are fed through `GLiNER2.batch_extract_entities()` with `include_spans=True` to get character-level offsets.
169
  3. **Date disambiguation** — Both `DATE_TIME` and `DATE_OF_BIRTH` are always detected together so the model can choose the best label per span.
170
  4. **De-duplication** — Overlapping detections from chunk boundaries are merged, keeping the highest-confidence label for each position.
@@ -205,4 +205,4 @@ If you use NERPA, please cite both this model and the original GLiNER2 paper:
205
 
206
  Built by [Akhat Rakishev](https://github.com/akhatre) at [Overmind](https://overmindlab.ai).
207
 
208
- Overmind is infrastructure to make agents more reliable. Learn more at [overmindlab.ai](https://overmindlab.ai).
 
31
  pipeline_tag: token-classification
32
  ---
33
 
34
+ # NERPA - Fine-Tuned GLiNER2 for PII Anonymisation
35
 
36
  A fine-tuned [GLiNER2 Large](https://huggingface.co/fastino/gliner2-large-v1) (340M params) model trained to detect Personally Identifiable Information (PII) in text. Built as a flexible, self-hosted replacement for AWS Comprehend at [Overmind](https://overmindlab.ai).
37
 
 
164
 
165
  The inference pipeline in `anonymise.py`:
166
 
167
+ 1. **Chunking** — Long texts are split into 3000-character chunks with 100-char overlap to stay within the model's context window. Specific chunk size can be varied since DeBERTa-v3 (underlying encoder) uses relative position encoding. We found that this size works as well as smaller ones.
168
  2. **Batch prediction** — Chunks are fed through `GLiNER2.batch_extract_entities()` with `include_spans=True` to get character-level offsets.
169
  3. **Date disambiguation** — Both `DATE_TIME` and `DATE_OF_BIRTH` are always detected together so the model can choose the best label per span.
170
  4. **De-duplication** — Overlapping detections from chunk boundaries are merged, keeping the highest-confidence label for each position.
 
205
 
206
  Built by [Akhat Rakishev](https://github.com/akhatre) at [Overmind](https://overmindlab.ai).
207
 
208
+ Overmind is infrastructure for end-to-end agent optimisation. Learn more at [overmindlab.ai](https://overmindlab.ai).
anonymise.py CHANGED
@@ -8,16 +8,21 @@ Usage:
8
  """
9
 
10
  import argparse
 
11
  import sys
12
- from typing import Dict, List, Tuple
 
 
 
13
 
14
  import torch
15
  from gliner2 import GLiNER2
16
 
 
17
 
18
  # Entity types the model was fine-tuned to recognise, with descriptions
19
  # that guide the bi-encoder towards better detection.
20
- PII_ENTITIES = {
21
  "LOCATION": "Address, country, city, postcode, street, any other location",
22
  "AGE": "Age of a person",
23
  "DIGITAL_KEYS": "Digital keys, passwords, pins used to access anything like servers, banks, APIs, accounts etc",
@@ -39,6 +44,7 @@ PII_ENTITIES = {
39
  CONFIDENCE_THRESHOLD = 0.25
40
  CHUNK_SIZE = 3000
41
  CHUNK_OVERLAP = 100
 
42
 
43
 
44
  def load_model(model_path: str = ".") -> GLiNER2:
@@ -51,40 +57,47 @@ def load_model(model_path: str = ".") -> GLiNER2:
51
  device = torch.device("cpu")
52
 
53
  model = GLiNER2.from_pretrained(model_path)
54
- model.to(device)
 
 
 
 
 
 
55
  return model
56
 
57
 
58
- def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> Tuple[List[str], List[int]]:
 
 
 
 
59
  """Split text into overlapping chunks, returning chunks and their start offsets."""
60
  if not text:
61
  return [], []
62
- chunks, starts = [], []
 
63
  step = chunk_size - overlap
64
- pos = 0
65
- while pos < len(text):
66
  chunks.append(text[pos : pos + chunk_size])
67
  starts.append(pos)
68
- if pos + chunk_size >= len(text):
69
- break
70
- pos += step
71
  return chunks, starts
72
 
73
 
74
  def detect_entities(
75
  model: GLiNER2,
76
  text: str,
77
- entities: Dict[str, str] = None,
78
  threshold: float = CONFIDENCE_THRESHOLD,
79
- ) -> List[dict]:
80
  """
81
  Detect PII entities in text, returning a list of
82
- {"type": str, "start": int, "end": int, "score": float} dicts
83
  with character offsets into the original text.
84
  """
85
  entities = entities or PII_ENTITIES
86
 
87
- # Always detect both date types so the model can disambiguate
88
  detect = dict(entities)
89
  if "DATE_TIME" in detect and "DATE_OF_BIRTH" not in detect:
90
  detect["DATE_OF_BIRTH"] = PII_ENTITIES["DATE_OF_BIRTH"]
@@ -93,9 +106,9 @@ def detect_entities(
93
 
94
  chunks, offsets = chunk_text(text)
95
 
96
- all_chunk_results = []
97
- for batch_start in range(0, len(chunks), 32):
98
- batch = chunks[batch_start : batch_start + 32]
99
  results = model.batch_extract_entities(
100
  batch,
101
  detect,
@@ -105,63 +118,108 @@ def detect_entities(
105
  )
106
  all_chunk_results.extend(results)
107
 
108
- # Merge results across chunks: de-duplicate overlapping detections
109
- seen: Dict[Tuple[int, int], dict] = {}
110
  for chunk_result, chunk_offset in zip(all_chunk_results, offsets):
111
  for label, occurrences in chunk_result["entities"].items():
112
- for occ in occurrences:
113
- start = occ["start"] + chunk_offset
114
- end = occ["end"] + chunk_offset
115
- pos = (start, end)
116
- if pos not in seen or seen[pos]["score"] < occ["confidence"]:
117
- seen[pos] = {"type": label, "score": occ["confidence"]}
118
-
119
- # Merge overlapping spans, keeping highest confidence label
 
 
 
 
 
 
 
 
120
  items = sorted(
121
- [(s, e, info) for (s, e), info in seen.items() if info["type"] in entities],
 
 
 
 
122
  key=lambda x: (x[0], x[1]),
123
  )
124
  if not items:
125
  return []
126
 
127
- merged = []
128
- cur_s, cur_e, cur_info = items[0]
129
- for s, e, info in items[1:]:
130
- if s < cur_e: # overlapping
131
- cur_e = max(cur_e, e)
132
- if info["score"] > cur_info["score"]:
133
- cur_info = info
134
  else:
135
- merged.append({"type": cur_info["type"], "start": cur_s, "end": cur_e, "score": cur_info["score"]})
136
- cur_s, cur_e, cur_info = s, e, info
137
- merged.append({"type": cur_info["type"], "start": cur_s, "end": cur_e, "score": cur_info["score"]})
 
 
 
 
 
 
 
 
 
 
138
 
139
  return merged
140
 
141
 
142
- def anonymise(text: str, detected: List[dict]) -> str:
143
- """Replace detected entities with placeholders like [PERSON_NAME]."""
144
- # Process from end to start so offsets stay valid
145
- result = text
146
- for entity in sorted(detected, key=lambda e: e["start"], reverse=True):
147
- placeholder = f'[{entity["type"]}]'
148
- result = result[: entity["start"]] + placeholder + result[entity["end"] :]
149
- return result
150
-
151
-
152
- def main():
153
- parser = argparse.ArgumentParser(description="Anonymise PII in text using the NERPA model.")
154
- parser.add_argument("text", nargs="?", help="Text to anonymise (or use --file)")
155
- parser.add_argument("--file", "-f", help="Read text from a file instead")
156
- parser.add_argument("--output", "-o", help="Write anonymised text to file (default: stdout)")
157
- parser.add_argument("--model", "-m", default=".", help="Path to model directory (default: current dir)")
158
- parser.add_argument("--threshold", "-t", type=float, default=CONFIDENCE_THRESHOLD, help="Confidence threshold (default: 0.25)")
159
- parser.add_argument("--show-entities", action="store_true", help="Print detected entities before anonymised text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  args = parser.parse_args()
161
 
162
  if args.file:
163
- with open(args.file) as f:
164
- text = f.read()
 
 
 
165
  elif args.text:
166
  text = args.text
167
  else:
@@ -171,15 +229,22 @@ def main():
171
  detected = detect_entities(model, text, threshold=args.threshold)
172
 
173
  if args.show_entities:
174
- for e in detected:
175
- print(f' {e["type"]:25s} [{e["start"]:5d}:{e["end"]:5d}] (score={e["score"]:.2f}) "{text[e["start"]:e["end"]]}"', file=sys.stderr)
176
- print(file=sys.stderr)
 
 
 
 
177
 
178
  result = anonymise(text, detected)
179
 
180
  if args.output:
181
- with open(args.output, "w") as f:
182
- f.write(result)
 
 
 
183
  else:
184
  print(result)
185
 
 
8
  """
9
 
10
  import argparse
11
+ import logging
12
  import sys
13
+ import warnings
14
+ from typing import Optional
15
+
16
+ warnings.filterwarnings("ignore", message=r".*incorrect regex pattern.*fix_mistral_regex.*")
17
 
18
  import torch
19
  from gliner2 import GLiNER2
20
 
21
+ logger = logging.getLogger(__name__)
22
 
23
  # Entity types the model was fine-tuned to recognise, with descriptions
24
  # that guide the bi-encoder towards better detection.
25
+ PII_ENTITIES: dict[str, str] = {
26
  "LOCATION": "Address, country, city, postcode, street, any other location",
27
  "AGE": "Age of a person",
28
  "DIGITAL_KEYS": "Digital keys, passwords, pins used to access anything like servers, banks, APIs, accounts etc",
 
44
  CONFIDENCE_THRESHOLD = 0.25
45
  CHUNK_SIZE = 3000
46
  CHUNK_OVERLAP = 100
47
+ BATCH_SIZE = 32
48
 
49
 
50
  def load_model(model_path: str = ".") -> GLiNER2:
 
57
  device = torch.device("cpu")
58
 
59
  model = GLiNER2.from_pretrained(model_path)
60
+ try:
61
+ model.to(device)
62
+ except RuntimeError:
63
+ logger.warning(
64
+ "Failed to load model on %s, falling back to CPU.", device
65
+ )
66
+ model.to(torch.device("cpu"))
67
  return model
68
 
69
 
70
def chunk_text(
    text: str,
    chunk_size: int = CHUNK_SIZE,
    overlap: int = CHUNK_OVERLAP,
) -> tuple[list[str], list[int]]:
    """Split ``text`` into overlapping chunks.

    Args:
        text: The input string; may be empty.
        chunk_size: Maximum number of characters per chunk.
        overlap: Characters shared between consecutive chunks so entities
            straddling a boundary are seen whole by at least one chunk.

    Returns:
        A ``(chunks, starts)`` pair where ``starts[i]`` is the character
        offset of ``chunks[i]`` in the original text.

    Raises:
        ValueError: If ``overlap`` is not smaller than ``chunk_size``
            (the stride would be zero or negative, which previously either
            raised from ``range`` or silently produced no chunks).
    """
    if not text:
        return [], []
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    step = chunk_size - overlap
    chunks: list[str] = []
    starts: list[int] = []
    pos = 0
    while True:
        chunks.append(text[pos : pos + chunk_size])
        starts.append(pos)
        # Stop as soon as this chunk reaches the end of the text. A plain
        # ``range(0, len(text), step)`` loop keeps going and emits a
        # redundant trailing chunk that is fully contained in the previous
        # one (e.g. a text of exactly ``chunk_size`` chars yields 2 chunks).
        if pos + chunk_size >= len(text):
            break
        pos += step
    return chunks, starts
85
 
86
 
87
  def detect_entities(
88
  model: GLiNER2,
89
  text: str,
90
+ entities: Optional[dict[str, str]] = None,
91
  threshold: float = CONFIDENCE_THRESHOLD,
92
+ ) -> list[dict]:
93
  """
94
  Detect PII entities in text, returning a list of
95
+ ``{"type": str, "start": int, "end": int, "score": float}`` dicts
96
  with character offsets into the original text.
97
  """
98
  entities = entities or PII_ENTITIES
99
 
100
+ # Always detect both date types so the model can disambiguate.
101
  detect = dict(entities)
102
  if "DATE_TIME" in detect and "DATE_OF_BIRTH" not in detect:
103
  detect["DATE_OF_BIRTH"] = PII_ENTITIES["DATE_OF_BIRTH"]
 
106
 
107
  chunks, offsets = chunk_text(text)
108
 
109
+ all_chunk_results: list[dict] = []
110
+ for batch_start in range(0, len(chunks), BATCH_SIZE):
111
+ batch = chunks[batch_start : batch_start + BATCH_SIZE]
112
  results = model.batch_extract_entities(
113
  batch,
114
  detect,
 
118
  )
119
  all_chunk_results.extend(results)
120
 
121
+ # Merge results across chunks: de-duplicate overlapping detections.
122
+ seen: dict[tuple[int, int], dict] = {}
123
  for chunk_result, chunk_offset in zip(all_chunk_results, offsets):
124
  for label, occurrences in chunk_result["entities"].items():
125
+ for occurrence in occurrences:
126
+ start = occurrence["start"] + chunk_offset
127
+ end = occurrence["end"] + chunk_offset
128
+ position = (start, end)
129
+ if (
130
+ position not in seen
131
+ or seen[position]["score"] < occurrence["confidence"]
132
+ ):
133
+ seen[position] = {
134
+ "type": label,
135
+ "score": occurrence["confidence"],
136
+ }
137
+
138
+ # Merge overlapping spans, keeping the highest-confidence label.
139
+ # NOTE: when two spans overlap they are fused into one span and
140
+ # assigned the label with the higher confidence score.
141
  items = sorted(
142
+ [
143
+ (start, end, info)
144
+ for (start, end), info in seen.items()
145
+ if info["type"] in entities
146
+ ],
147
  key=lambda x: (x[0], x[1]),
148
  )
149
  if not items:
150
  return []
151
 
152
+ merged: list[dict] = []
153
+ current_start, current_end, current_info = items[0]
154
+ for start, end, info in items[1:]:
155
+ if start < current_end: # overlapping
156
+ current_end = max(current_end, end)
157
+ if info["score"] > current_info["score"]:
158
+ current_info = info
159
  else:
160
+ merged.append({
161
+ "type": current_info["type"],
162
+ "start": current_start,
163
+ "end": current_end,
164
+ "score": current_info["score"],
165
+ })
166
+ current_start, current_end, current_info = start, end, info
167
+ merged.append({
168
+ "type": current_info["type"],
169
+ "start": current_start,
170
+ "end": current_end,
171
+ "score": current_info["score"],
172
+ })
173
 
174
  return merged
175
 
176
 
177
def anonymise(text: str, detected: list[dict]) -> str:
    """Replace detected entities with placeholders like ``[PERSON_NAME]``.

    Walks the text left to right, copying untouched spans verbatim and
    substituting each detected span with ``[TYPE]``.
    """
    ordered = sorted(detected, key=lambda entity: entity["start"])
    out = ""
    cursor = 0
    for entity in ordered:
        out += text[cursor : entity["start"]] + f'[{entity["type"]}]'
        cursor = entity["end"]
    return out + text[cursor:]
187
+
188
+
189
+ def main() -> None:
190
+ parser = argparse.ArgumentParser(
191
+ description="Anonymise PII in text using the NERPA model.",
192
+ )
193
+ parser.add_argument(
194
+ "text", nargs="?", help="Text to anonymise (or use --file)",
195
+ )
196
+ parser.add_argument(
197
+ "--file", "-f", help="Read text from a file instead",
198
+ )
199
+ parser.add_argument(
200
+ "--output", "-o",
201
+ help="Write anonymised text to file (default: stdout)",
202
+ )
203
+ parser.add_argument(
204
+ "--model", "-m", default=".",
205
+ help="Path to model directory (default: current dir)",
206
+ )
207
+ parser.add_argument(
208
+ "--threshold", "-t", type=float, default=CONFIDENCE_THRESHOLD,
209
+ help=f"Confidence threshold (default: {CONFIDENCE_THRESHOLD})",
210
+ )
211
+ parser.add_argument(
212
+ "--show-entities", action="store_true",
213
+ help="Print detected entities before anonymised text",
214
+ )
215
  args = parser.parse_args()
216
 
217
  if args.file:
218
+ try:
219
+ with open(args.file, encoding="utf-8") as f:
220
+ text = f.read()
221
+ except OSError as exc:
222
+ sys.exit(f"Error reading {args.file}: {exc}")
223
  elif args.text:
224
  text = args.text
225
  else:
 
229
  detected = detect_entities(model, text, threshold=args.threshold)
230
 
231
  if args.show_entities:
232
+ for entity in detected:
233
+ span = text[entity["start"] : entity["end"]]
234
+ logger.info(
235
+ " %-25s [%5d:%5d] (score=%.2f) %r",
236
+ entity["type"], entity["start"], entity["end"],
237
+ entity["score"], span,
238
+ )
239
 
240
  result = anonymise(text, detected)
241
 
242
  if args.output:
243
+ try:
244
+ with open(args.output, "w", encoding="utf-8") as f:
245
+ f.write(result)
246
+ except OSError as exc:
247
+ sys.exit(f"Error writing {args.output}: {exc}")
248
  else:
249
  print(result)
250
 
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ gliner2>=1.2.4
2
+ torch>=2.8.0