Update README.md
Browse files
README.md
CHANGED
|
@@ -102,7 +102,7 @@ formats = {
|
|
| 102 |
"entity_swapping": """<|im_start|>system\nEntity Swapping<|im_end|>\n<|im_start|>user\nentities:{entities}\ntext:\n{text}<|im_end|>\n<|im_start|>assistant\n"""
|
| 103 |
}
|
| 104 |
|
| 105 |
-
def model_inference(text, mode="anonymization", max_new_tokens=
|
| 106 |
if mode not in formats and mode != "anonymization":
|
| 107 |
raise ValueError("Invalid mode. Choose from 'sensitivity', 'complexity', 'entity_detection', 'anonymization'.")
|
| 108 |
|
|
@@ -154,7 +154,6 @@ def model_inference(text, mode="anonymization", max_new_tokens=50, config=None,
|
|
| 154 |
# Step 2: Select entities based on config
|
| 155 |
selected_entities = select_entities_based_on_json(detected_entities, config)
|
| 156 |
entities_str = "\n".join([f"{entity} : {label}" for entity, label in selected_entities])
|
| 157 |
-
|
| 158 |
# Step 3: Entity swapping for anonymization
|
| 159 |
swapping_prompt = formats["entity_swapping"].format(entities=entities_str, text=text)
|
| 160 |
swapping_inputs = tokenizer(swapping_prompt, return_tensors="pt").to(device)
|
|
@@ -168,24 +167,25 @@ def model_inference(text, mode="anonymization", max_new_tokens=50, config=None,
|
|
| 168 |
anonymized_text = tokenizer.decode(swapping_output[0], skip_special_tokens=True)
|
| 169 |
anonymized_text = anonymized_text.split("assistant\n", 1)[-1].strip() # Extract only the assistant's response
|
| 170 |
|
| 171 |
-
|
|
|
|
|
|
|
|
|
|
| 172 |
|
| 173 |
# Entity Restoration Mode using entity_swapping
|
| 174 |
elif mode == "entity_swapping" and entity_mapping:
|
| 175 |
-
#
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
|
| 184 |
-
# Combine all replacement mappings for the prompt
|
| 185 |
-
reversed_entities_str = "\n".join(reversed_entities)
|
| 186 |
|
| 187 |
# Create the swapping prompt with the aggregated reversed mappings
|
| 188 |
-
swapping_prompt = formats["entity_swapping"].format(entities=
|
| 189 |
swapping_inputs = tokenizer(swapping_prompt, return_tensors="pt").to(device)
|
| 190 |
swapping_output = model.generate(
|
| 191 |
**swapping_inputs,
|
|
@@ -206,7 +206,7 @@ def model_inference(text, mode="anonymization", max_new_tokens=50, config=None,
|
|
| 206 |
model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
| 207 |
generation_output = model.generate(
|
| 208 |
**model_inputs,
|
| 209 |
-
max_new_tokens=
|
| 210 |
use_cache=True,
|
| 211 |
eos_token_id=151645
|
| 212 |
)
|
|
@@ -224,7 +224,7 @@ def postprocess_entity_recognition(detection_output: str) -> dict:
|
|
| 224 |
entity_pattern = re.compile(
|
| 225 |
r'(?P<entity>[\w\s]+)--(?P<type>[\w]+)--(?P<random>[\w\s]+)--(?P<generalizations>.+)'
|
| 226 |
)
|
| 227 |
-
generalization_pattern = re.compile(r'(\
|
| 228 |
|
| 229 |
lines = detection_output.strip().split("\n")
|
| 230 |
for line in lines:
|
|
@@ -236,8 +236,21 @@ def postprocess_entity_recognition(detection_output: str) -> dict:
|
|
| 236 |
|
| 237 |
generalizations = []
|
| 238 |
for gen_match in generalization_pattern.findall(match.group("generalizations")):
|
| 239 |
-
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
|
| 242 |
output_json[entity_name] = {
|
| 243 |
"TYPE": entity_type,
|
|
@@ -304,11 +317,16 @@ To protect sensitive information, the model detects specific entities in the tex
|
|
| 304 |
|
| 305 |
```python
|
| 306 |
# Anonymize the text
|
| 307 |
-
anonymized_text
|
| 308 |
print(f"Anonymized Text: {anonymized_text}\n")
|
|
|
|
| 309 |
|
|
|
|
| 310 |
# Restore the original text
|
| 311 |
-
|
|
|
|
|
|
|
|
|
|
| 312 |
print(f"Restored Text: {restored_text}")
|
| 313 |
```
|
| 314 |
|
|
|
|
| 102 |
"entity_swapping": """<|im_start|>system\nEntity Swapping<|im_end|>\n<|im_start|>user\nentities:{entities}\ntext:\n{text}<|im_end|>\n<|im_start|>assistant\n"""
|
| 103 |
}
|
| 104 |
|
| 105 |
+
def model_inference(text, mode="anonymization", max_new_tokens=2028, config=None, entity_mapping=None, return_entities=False, reverse_mapping=False):
|
| 106 |
if mode not in formats and mode != "anonymization":
|
| 107 |
raise ValueError("Invalid mode. Choose from 'sensitivity', 'complexity', 'entity_detection', 'anonymization'.")
|
| 108 |
|
|
|
|
| 154 |
# Step 2: Select entities based on config
|
| 155 |
selected_entities = select_entities_based_on_json(detected_entities, config)
|
| 156 |
entities_str = "\n".join([f"{entity} : {label}" for entity, label in selected_entities])
|
|
|
|
| 157 |
# Step 3: Entity swapping for anonymization
|
| 158 |
swapping_prompt = formats["entity_swapping"].format(entities=entities_str, text=text)
|
| 159 |
swapping_inputs = tokenizer(swapping_prompt, return_tensors="pt").to(device)
|
|
|
|
| 167 |
anonymized_text = tokenizer.decode(swapping_output[0], skip_special_tokens=True)
|
| 168 |
anonymized_text = anonymized_text.split("assistant\n", 1)[-1].strip() # Extract only the assistant's response
|
| 169 |
|
| 170 |
+
if return_entities:
|
| 171 |
+
return anonymized_text, entities_str
|
| 172 |
+
|
| 173 |
+
return anonymized_text
|
| 174 |
|
| 175 |
# Entity Restoration Mode using entity_swapping
|
| 176 |
elif mode == "entity_swapping" and entity_mapping:
|
| 177 |
+
# Reverse the entity mapping
|
| 178 |
+
if reverse_mapping:
|
| 179 |
+
reversed_mapping = []
|
| 180 |
+
for line in entity_mapping.splitlines():
|
| 181 |
+
if ':' in line: # Ensure the line contains a colon
|
| 182 |
+
left, right = map(str.strip, line.split(":", 1)) # Split and strip spaces
|
| 183 |
+
reversed_mapping.append(f"{right} : {left}") # Reverse and format
|
| 184 |
+
entity_mapping = "\n".join(reversed_mapping)
|
| 185 |
|
|
|
|
|
|
|
| 186 |
|
| 187 |
# Create the swapping prompt from the entity mapping (reversed only when reverse_mapping is set)
|
| 188 |
+
swapping_prompt = formats["entity_swapping"].format(entities=entity_mapping, text=text)
|
| 189 |
swapping_inputs = tokenizer(swapping_prompt, return_tensors="pt").to(device)
|
| 190 |
swapping_output = model.generate(
|
| 191 |
**swapping_inputs,
|
|
|
|
| 206 |
model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
| 207 |
generation_output = model.generate(
|
| 208 |
**model_inputs,
|
| 209 |
+
max_new_tokens=5,
|
| 210 |
use_cache=True,
|
| 211 |
eos_token_id=151645
|
| 212 |
)
|
|
|
|
| 224 |
entity_pattern = re.compile(
|
| 225 |
r'(?P<entity>[\w\s]+)--(?P<type>[\w]+)--(?P<random>[\w\s]+)--(?P<generalizations>.+)'
|
| 226 |
)
|
| 227 |
+
generalization_pattern = re.compile(r'([\w\s]+)::([\w\s]+)')
|
| 228 |
|
| 229 |
lines = detection_output.strip().split("\n")
|
| 230 |
for line in lines:
|
|
|
|
| 236 |
|
| 237 |
generalizations = []
|
| 238 |
for gen_match in generalization_pattern.findall(match.group("generalizations")):
|
| 239 |
+
first, second = gen_match
|
| 240 |
+
|
| 241 |
+
# Check if the first part is a digit (score) and swap if needed
|
| 242 |
+
if first.isdigit() and not second.isdigit():
|
| 243 |
+
score = first
|
| 244 |
+
label = second
|
| 245 |
+
generalizations.append([label.strip(), score.strip()])
|
| 246 |
+
|
| 247 |
+
elif not first.isdigit() and second.isdigit():
|
| 248 |
+
label = first
|
| 249 |
+
score = second
|
| 250 |
+
generalizations.append([label.strip(), score.strip()])
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
|
| 254 |
|
| 255 |
output_json[entity_name] = {
|
| 256 |
"TYPE": entity_type,
|
|
|
|
| 317 |
|
| 318 |
```python
|
| 319 |
# Anonymize the text
|
| 320 |
+
anonymized_text = model_inference(text, mode="anonymization")
|
| 321 |
print(f"Anonymized Text: {anonymized_text}\n")
|
| 322 |
+
```
|
| 323 |
|
| 324 |
+
```python
|
| 325 |
# Restore the original text
|
| 326 |
+
anonymized_text, entity_mapping = model_inference(text, mode="anonymization", return_entities=True)
|
| 327 |
+
print(f"Entity Mapping:\n{entity_mapping}\n")
|
| 328 |
+
print(f"Anonymized Text: {anonymized_text}\n")
|
| 329 |
+
restored_text = model_inference(anonymized_text, mode="entity_swapping", entity_mapping=entity_mapping, reverse_mapping=True)
|
| 330 |
print(f"Restored Text: {restored_text}")
|
| 331 |
```
|
| 332 |
|