File size: 48,275 Bytes
4b7bdcf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 
1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 | """
LLM-based entity detection using AWS Bedrock and other supported LLM backends.
This module provides functions to detect PII entities using LLMs instead of a dedicated PII-detection service (e.g. AWS Comprehend).
"""
import json
import os
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
import boto3
from gradio import Progress
from tools.config import (
CHOSEN_LLM_PII_INFERENCE_METHOD,
CLOUD_LLM_PII_CUSTOM_INSTRUCTIONS_MODEL_CHOICE,
CLOUD_LLM_PII_MODEL_CHOICE,
INFERENCE_SERVER_API_URL,
LLM_MAX_NEW_TOKENS,
LLM_TEMPERATURE,
model_name_map,
)
from tools.llm_entity_detection_prompts import (
create_entity_detection_prompt,
create_entity_detection_system_prompt,
)
# Max length for column/sheet name in tabular log filenames (to keep filenames short)
LLM_LOG_TABULAR_NAME_MAX_LEN = 25

# Import LLM functions from local tools.llm_funcs
try:
    # Use send_request from llm_funcs.py which handles all model sources, retries, and response parsing
    from tools.llm_funcs import (
        send_request,
    )
except ImportError as e:
    print(f"Warning: Could not import LLM functions: {e}")
    print("LLM-based entity detection will not be available.")
    print("Please ensure llm_funcs.py is in the tools folder.")
    # Bind the name the try-block actually imports so it always exists at module
    # level (callers can test `send_request is None`) instead of being undefined.
    send_request = None
    # Placeholders kept for backward compatibility with code that may check them.
    call_aws_bedrock = None
    construct_azure_client = None
    ResponseObject = None
def _find_text_in_passage(
    search_text: str,
    original_text: str,
    reported_offset: Optional[int] = None,
    start_from: int = 0,
) -> Optional[Tuple[int, int]]:
    """
    Find the position of search_text in original_text and return (begin, end) offsets.

    Only considers occurrences at or after start_from. This allows a "first pass" where
    each entity is matched starting after the previous entity's end, so repeated phrases
    (e.g. "University of Notre Dame" vs "University" + "of Notre Dame") map to the
    correct occurrence.

    Search tiers are tried in order until one yields an occurrence at or after
    start_from:
      1. exact match of search_text;
      2. exact match of the "cleaned" text (trailing dots / "…" removed, then
         surrounding whitespace stripped), since LLMs sometimes append an ellipsis;
      3. case-insensitive match of search_text;
      4. case-insensitive match of the cleaned text.
    A cleaned text that ends up empty (e.g. search_text == "...") is never searched,
    so no zero-length span can be returned.

    Offsets always index original_text. Case-insensitive matching uses
    re.IGNORECASE instead of lower-casing both strings, because str.lower() can
    change a string's length (e.g. "İ" lowers to two code points), which would
    shift every offset after such a character.

    Args:
        search_text: The text to search for
        original_text: The text to search in
        reported_offset: Optional offset reported by LLM (used to disambiguate multiple matches)
        start_from: Only consider matches at or after this position (default 0).

    Returns:
        Tuple of (begin_offset, end_offset) if found, None otherwise
    """
    if not search_text:
        return None

    def occurrences(needle: str, ignore_case: bool) -> List[Tuple[int, int]]:
        # The zero-width lookahead lets matches overlap ("aa" in "aaa" -> 0 and 1),
        # mirroring a str.find loop that advances one character at a time.
        pattern = re.compile(
            "(?=(" + re.escape(needle) + "))", re.IGNORECASE if ignore_case else 0
        )
        return [match.span(1) for match in pattern.finditer(original_text)]

    # str.rstrip takes a character set: every trailing "." / "…" is removed.
    search_text_clean = search_text.rstrip(".\u2026").strip()
    cleaned_differs = bool(search_text_clean) and search_text_clean != search_text

    tiers: List[Tuple[str, bool]] = [(search_text, False)]
    if cleaned_differs:
        tiers.append((search_text_clean, False))
    tiers.append((search_text, True))
    if cleaned_differs:
        tiers.append((search_text_clean, True))

    for needle, ignore_case in tiers:
        candidates = [
            span for span in occurrences(needle, ignore_case) if span[0] >= start_from
        ]
        if not candidates:
            continue
        if reported_offset is None:
            # Spans come out in ascending order, so the first is the earliest.
            return candidates[0]
        # Ties resolve to the earliest span because min() keeps the first minimum.
        return min(candidates, key=lambda span: abs(span[0] - reported_offset))
    return None
def _find_all_text_in_passage(
    search_text: str, original_text: str
) -> List[Tuple[int, int]]:
    """
    Find all positions of search_text in original_text and return a list of (begin, end) offsets.

    Search tiers are tried in order and the first tier with any occurrence wins:
      1. exact match of search_text;
      2. exact match of the "cleaned" text (trailing dots / "…" removed, then
         surrounding whitespace stripped);
      3. case-insensitive match of search_text;
      4. case-insensitive match of the cleaned text.
    An empty cleaned text is never searched, so no zero-length spans are returned.
    LLM offset values are never used; positions come only from search.

    Offsets always index original_text. Case-insensitive matching uses
    re.IGNORECASE rather than lower-casing both strings, because str.lower() can
    change a string's length (e.g. "İ" lowers to two code points) and would shift
    every subsequent offset.

    Returns:
        List of (begin_offset, end_offset) tuples, sorted by begin_offset (ascending).
        Occurrences may overlap (e.g. "aa" in "aaa" yields (0, 2) and (1, 3)).
    """
    if not search_text:
        return []

    def occurrences(needle: str, ignore_case: bool) -> List[Tuple[int, int]]:
        # Zero-width lookahead yields overlapping matches in ascending order.
        pattern = re.compile(
            "(?=(" + re.escape(needle) + "))", re.IGNORECASE if ignore_case else 0
        )
        return [match.span(1) for match in pattern.finditer(original_text)]

    # str.rstrip takes a character set: every trailing "." / "…" is removed.
    search_text_clean = search_text.rstrip(".\u2026").strip()
    cleaned_differs = bool(search_text_clean) and search_text_clean != search_text

    tiers: List[Tuple[str, bool]] = [(search_text, False)]
    if cleaned_differs:
        tiers.append((search_text_clean, False))
    tiers.append((search_text, True))
    if cleaned_differs:
        tiers.append((search_text_clean, True))

    for needle, ignore_case in tiers:
        spans = occurrences(needle, ignore_case)
        if spans:
            return spans
    return []
def _entity_get(obj: Dict[str, Any], key: str, default: Any = None) -> Any:
    """
    Look up ``key`` in an entity dict, ignoring the case of the dict's keys.

    LLM output is inconsistent about key casing (e.g. ``BeginOffset`` vs
    ``beginOffset``). The first key, in the dict's iteration order, whose
    lower-cased form equals ``key.lower()`` wins; ``default`` is returned when no
    key matches.
    """
    wanted = key.lower()
    return next(
        (value for name, value in obj.items() if name.lower() == wanted),
        default,
    )
def _slice_root_json_object(text: str) -> Optional[str]:
    """
    Return the first balanced top-level ``{...}`` object in ``text``, or None.

    Braces inside JSON string literals (including escaped quotes) are ignored, so an
    entity Text such as "a}b" does not end the object early. Returns None when
    ``text`` contains no "{" or the first object is never closed.
    """
    start = text.find("{")
    if start < 0:
        return None
    depth = 0
    in_string = False
    escaped = False
    for index, char in enumerate(text[start:], start):
        if in_string:
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == '"':
                in_string = False
        elif char == '"':
            in_string = True
        elif char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return text[start : index + 1]
    return None


def _extract_llm_json_string(response_text: str) -> Optional[str]:
    """Locate the JSON payload in an LLM response (markdown code block first, then regex)."""
    # Prefer extracting from a markdown code block (e.g. ```json\n...\n```<end_of_turn>)
    # so we get a clean slice and can strip trailing tokens before parsing.
    if "```" in response_text:
        code_block = re.search(
            r"```(?:json)?\s*\n?(.*?)(?:\n?```|$)", response_text, re.DOTALL
        )
        if code_block:
            candidate = code_block.group(1).strip()
            # Strip trailing tokens that some models append (e.g. <end_of_turn>)
            candidate = re.sub(
                r"<end_of_turn>\s*$", "", candidate, flags=re.IGNORECASE
            ).rstrip()
            start = candidate.find("{")
            if start >= 0:
                root = _slice_root_json_object(candidate)
                # Unterminated object: fall back to everything from the first "{".
                return root if root is not None else candidate[start:]
    # Fallback: regex-based extraction (fragile for nested braces)
    json_match = re.search(
        r'\{[^{}]*"entities"[^{}]*\[.*?\].*?\}', response_text, re.DOTALL
    ) or re.search(r'\{.*?"entities".*?\}', response_text, re.DOTALL)
    return json_match.group(0) if json_match else None


def _is_bare_identifier(value: str) -> bool:
    """True for identifier-like tokens that are not JSON literals (true/false/null)."""
    return bool(re.match(r"^[A-Za-z_][A-Za-z0-9_]*$", value)) and value.lower() not in (
        "true",
        "false",
        "null",
    )


def _clean_llm_json_string(json_str: str) -> str:
    """Apply light-weight repairs for common LLM JSON mistakes before parsing."""

    def fix_unquoted_value(match) -> str:
        # Groups: quoted key, bare value, separator (",", "}", "]" or "\n").
        key_part, value, separator = match.group(1), match.group(2), match.group(3)
        if _is_bare_identifier(value):
            return f'{key_part}: "{value}"{separator}'
        return match.group(0)

    json_str = json_str.strip()
    # Remove markdown code block markers if present (regex path may include them)
    json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE)
    json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE)
    # Strip trailing tokens again (e.g. <end_of_turn> after closing })
    json_str = re.sub(r"<end_of_turn>\s*$", "", json_str, flags=re.IGNORECASE).strip()
    # Keep only the root object if trailing garbage remains
    root = _slice_root_json_object(json_str)
    if root is not None:
        json_str = root
    # 1. Remove trailing commas before closing brackets/braces
    json_str = re.sub(r",\s*}", "}", json_str)
    json_str = re.sub(r",\s*]", "]", json_str)
    # 2. Quote unquoted identifier values, e.g. "Type": NAME, -> "Type": "NAME",
    json_str = re.sub(
        r'("[\w]+")\s*:\s*([A-Za-z_][A-Za-z0-9_]*)\s*([,}\]])',
        fix_unquoted_value,
        json_str,
    )
    #    ...and when the bare value ends the line. JSON literals (true/false/null)
    #    are left alone here too.
    json_str = re.sub(
        r'("[\w]+")\s*:\s*([A-Za-z_][A-Za-z0-9_]*)\s*(\n)',
        fix_unquoted_value,
        json_str,
    )
    # Final trim of trailing whitespace, backticks and <end_of_turn>, then truncate to
    # the root object (avoids "Expecting ',' delimiter" from trailing bytes).
    json_str = json_str.rstrip()
    json_str = re.sub(r"`+$", "", json_str)
    json_str = re.sub(r"<end_of_turn>\s*$", "", json_str, flags=re.IGNORECASE)
    json_str = json_str.rstrip()
    root = _slice_root_json_object(json_str)
    if root is not None:
        json_str = root
    return json_str


def _parse_llm_json(json_str: str) -> Any:
    """
    Parse cleaned LLM JSON, retrying once after quoting bare identifier values.

    Raises:
        json.JSONDecodeError: if the string still cannot be parsed after the retry.
    """
    try:
        return json.loads(json_str)
    except json.JSONDecodeError as e:
        print(f"Initial JSON parse failed: {e}. Attempting more aggressive fixes...")

    def quote_unquoted_identifier(match) -> str:
        # prefix already holds the colon and any spacing after it, so the value is
        # appended directly (adding another ": " would yield invalid ": : " JSON).
        prefix, value, suffix = match.group(1), match.group(2), match.group(3)
        if _is_bare_identifier(value):
            return f'{prefix}"{value}"{suffix}'
        return match.group(0)

    json_str = re.sub(
        r"(:\s*)([A-Za-z_][A-Za-z0-9_]*)(\s*[,}\]])",
        quote_unquoted_identifier,
        json_str,
    )
    try:
        return json.loads(json_str)
    except json.JSONDecodeError as e2:
        print(f"JSON parsing failed after fixes: {e2}")
        print(f"Cleaned JSON string (first 1000 chars): {json_str[:1000]}")
        raise


def _coerce_optional_int(value: Any) -> Optional[int]:
    """Convert an LLM-reported offset to int; None when missing or unconvertible."""
    if value is None:
        return None
    try:
        return int(value)
    except (ValueError, TypeError, OverflowError):  # OverflowError: int(Infinity)
        return None


def _coerce_score(value: Any, default: float = 0.8) -> float:
    """Convert an LLM-reported confidence to float, falling back to ``default``."""
    try:
        return float(value)
    except (ValueError, TypeError, OverflowError):
        return default


def _collect_raw_llm_entities(
    entities: List[Any], original_text: str
) -> List[Dict[str, Any]]:
    """Normalise raw LLM entity records into Type/Text/Score/reported_begin dicts."""
    raw_entities: List[Dict[str, Any]] = []
    for entity in entities:
        if not isinstance(entity, dict):
            print(f"Warning: Skipping entity that is not a JSON object: {entity!r}")
            continue
        entity_type_val = _entity_get(entity, "Type")
        if entity_type_val is None:
            print(f"Warning: Entity missing Type field: {entity}")
            continue
        entity_text = _entity_get(entity, "Text", "")
        if entity_text and not isinstance(entity_text, str):
            # e.g. a phone number emitted as a JSON number; search needs a str
            entity_text = str(entity_text)
        reported_begin = _coerce_optional_int(_entity_get(entity, "BeginOffset"))
        reported_end = _coerce_optional_int(_entity_get(entity, "EndOffset"))
        # If no Text, try to derive from reported offsets (for display/grouping only)
        if (
            not entity_text
            and reported_begin is not None
            and reported_end is not None
            and 0 <= reported_begin < reported_end <= len(original_text)
        ):
            entity_text = original_text[reported_begin:reported_end]
        if not entity_text:
            print(
                f"Warning: Entity of type '{entity_type_val}' has no Text value and invalid offsets"
            )
            continue
        raw_entities.append(
            {
                "Type": str(entity_type_val),
                "Text": entity_text,
                "Score": _coerce_score(_entity_get(entity, "Score"), 0.8),
                "reported_begin": reported_begin,
            }
        )
    return raw_entities


def _resolve_llm_entity_spans(
    raw_entities: List[Dict[str, Any]], original_text: str
) -> List[Dict[str, Any]]:
    """Map each raw entity's Text to offsets in original_text, in reported order."""
    # Entities with a reported BeginOffset come first (ascending); the rest keep
    # their relative order at the end (sorted() is stable).
    ordered = sorted(
        raw_entities,
        key=lambda r: (
            r["reported_begin"] is None,
            r["reported_begin"] or 0,
        ),
    )
    entities_out: List[Dict[str, Any]] = []
    search_start = 0
    for rec in ordered:
        search_text = rec["Text"]
        # First pass: search after the preceding entity's end; then from the start.
        result = _find_text_in_passage(
            search_text,
            original_text,
            reported_offset=rec["reported_begin"],
            start_from=search_start,
        )
        if result is None:
            result = _find_text_in_passage(
                search_text,
                original_text,
                reported_offset=rec["reported_begin"],
                start_from=0,
            )
        if result is None:
            print(
                f"Warning: Could not find text '{search_text[:50]}...' in original passage"
            )
            continue
        start, end = result
        entities_out.append(
            {
                "Type": rec["Type"],
                "BeginOffset": start,
                "EndOffset": end,
                "Score": rec["Score"],
                "Text": original_text[start:end],
            }
        )
        search_start = end
    return entities_out


def parse_llm_entity_response(
    response_text: str,
    original_text: str,
) -> List[Dict[str, Any]]:
    """
    Parse LLM response and extract entity information.

    LLM BeginOffset/EndOffset are used only to define order. Positions are
    resolved by a first-pass text search: for each entity (in reported order),
    search for the entity's Text in the passage starting from the end of the
    preceding entity's resolved span. If not found there, search from the
    start of the passage. This ensures repeated phrases (e.g. "University of
    Notre Dame" once, then "University" and "of Notre Dame" separately) map
    to the correct occurrence and avoid duplicate redaction boxes.

    Processing steps: strip <think>/<thinking> blocks, locate the JSON payload
    (markdown code block, else regex), repair common LLM JSON mistakes (trailing
    commas, bare identifier values, trailing tokens), parse, then normalise and
    position each entity. Malformed entries (non-objects, missing Type, missing
    Text with unusable offsets) are skipped with a warning; an unparseable Score
    falls back to 0.8. Parsing failures are printed and yield an empty list.

    Args:
        response_text: The LLM response text (should contain JSON)
        original_text: The original text that was analyzed (for validation)

    Returns:
        List of entity dictionaries with keys: Type, BeginOffset, EndOffset, Score, Text
    """
    entities_out: List[Dict[str, Any]] = []
    # Remove <think>/<thinking> tags and their content (common in some LLM outputs)
    response_text = re.sub(
        r"<think>.*?</think>", "", response_text, flags=re.DOTALL | re.IGNORECASE
    )
    response_text = re.sub(
        r"<thinking>.*?</thinking>", "", response_text, flags=re.DOTALL | re.IGNORECASE
    )

    json_str = _extract_llm_json_string(response_text)
    if not json_str:
        print("Warning: Could not find JSON in LLM response")
        print(f"Response text: {response_text[:500]}")
        return entities_out

    try:
        data = _parse_llm_json(_clean_llm_json_string(json_str))
        if isinstance(data, dict) and isinstance(data.get("entities"), list):
            raw_entities = _collect_raw_llm_entities(data["entities"], original_text)
            entities_out = _resolve_llm_entity_spans(raw_entities, original_text)
    except json.JSONDecodeError as e:
        print(f"Error parsing JSON from LLM response: {e}")
        print(f"Response text: {response_text[:500]}")
    except (ValueError, KeyError) as e:
        print(f"Error processing entity data: {e}")
    return entities_out
def _sanitize_for_filename(s: str, max_len: Optional[int] = None) -> str:
    """
    Make ``s`` safe to embed in a filename.

    Keeps letters, digits, spaces, hyphens and underscores, trims surrounding
    spaces, turns the remaining spaces into underscores and, when ``max_len`` is
    given, truncates to that many characters. ``None`` or a result that ends up
    empty yields ``"unknown"``.
    """
    allowed_punctuation = {" ", "-", "_"}
    kept_chars = [
        ch for ch in (s or "") if ch.isalnum() or ch in allowed_punctuation
    ]
    cleaned = "".join(kept_chars).strip().replace(" ", "_")
    if max_len is not None:
        # Slicing is a no-op when the string is already short enough.
        cleaned = cleaned[:max_len]
    return cleaned if cleaned else "unknown"
def save_llm_prompt_response(
    system_prompt: str,
    user_prompt: str,
    response_text: str,
    output_folder: str,
    batch_number: int,
    model_choice: str,
    entities_to_detect: List[str],
    language: str,
    temperature: float,
    max_tokens: int,
    file_name: Optional[str] = None,
    page_number: Optional[int] = None,
    sheet_name: Optional[str] = None,
    column_name: Optional[str] = None,
    row_number: Optional[int] = None,
    input_tokens: Optional[int] = None,
    output_tokens: Optional[int] = None,
) -> str:
    """
    Save LLM prompt and response to a text file for traceability.

    Writes the exact system prompt and user prompt that were sent to the model
    (e.g. for local transformers, inference-server, AWS, etc.). Each section is
    clearly delimited so the log never duplicates or conflates system vs user.

    The file goes into ``<output_folder>/llm_prompts_responses/`` (created if
    needed) and is named:
      * tabular data (any of sheet/column/row given):
        ``llm[_<sheet>][_<column>][_row<NNNNN>]_batch_<NNNN>.txt``
      * documents (file_name and page_number given):
        ``llm_<file>_page_<NNNN>_batch_<NNNN>.txt``
      * otherwise: ``llm_batch_<NNNN>_<YYYYmmdd_HHMMSS>.txt``
    An existing file with the same name is overwritten.

    Args:
        system_prompt: System prompt sent to LLM (exactly as passed to the model).
        user_prompt: User prompt sent to LLM (exactly as passed to the model).
        response_text: Response text from LLM. None is logged as an empty response.
        output_folder: Output folder path
        batch_number: Batch number for this call
        model_choice: Model used
        entities_to_detect: List of entities being detected (None treated as empty)
        language: Language code
        temperature: Temperature used
        max_tokens: Max tokens used
        file_name: Optional file name (without extension) for the filename / log header
        page_number: Optional page number (0-based) for the filename; displayed in log as 1-based.
        sheet_name: Optional Excel sheet name (tabular data); included in log and filename if present.
        column_name: Optional column name (tabular data); included in log and filename (shortened if long).
        row_number: Optional row number (1-based for display; tabular data); included in log and filename.
        input_tokens: Optional input token count from the LLM call
        output_tokens: Optional output token count from the LLM call

    Returns:
        Path to the saved file
    """

    def as_text(value: Any) -> str:
        # None -> "", strings unchanged, anything else via str(), so f.write and
        # str methods below never fail on unexpected types.
        if value is None:
            return ""
        return value if isinstance(value, str) else str(value)

    # Normalise to strings so we never write "None" or non-string types
    system_prompt_str = as_text(system_prompt).strip()
    user_prompt_str = as_text(user_prompt).strip()
    response_str = as_text(response_text)
    entities_str = ", ".join(as_text(entity) for entity in (entities_to_detect or []))
    # One clock reading so the filename timestamp and the header timestamp agree.
    now = datetime.now()

    # Create LLM logs subfolder
    llm_logs_folder = os.path.join(output_folder, "llm_prompts_responses")
    os.makedirs(llm_logs_folder, exist_ok=True)

    # Tabular: filename = sheet (if relevant) + column (shortened) + row
    is_tabular = (
        column_name is not None or sheet_name is not None or row_number is not None
    )
    if is_tabular:
        parts = ["llm"]
        if sheet_name:
            parts.append(
                _sanitize_for_filename(sheet_name, LLM_LOG_TABULAR_NAME_MAX_LEN)
            )
        if column_name:
            parts.append(
                _sanitize_for_filename(column_name, LLM_LOG_TABULAR_NAME_MAX_LEN)
            )
        if row_number is not None:
            parts.append(f"row{row_number:05d}")
        parts.append(f"batch_{batch_number:04d}")
        filename = "_".join(parts) + ".txt"
    elif file_name and page_number is not None:
        # Document: file name + page number
        safe_file_name = _sanitize_for_filename(file_name)
        filename = (
            f"llm_{safe_file_name}_page_{page_number:04d}_batch_{batch_number:04d}.txt"
        )
    else:
        timestamp = now.strftime("%Y%m%d_%H%M%S")
        filename = f"llm_batch_{batch_number:04d}_{timestamp}.txt"
    filepath = os.path.join(llm_logs_folder, filename)

    # Write prompt and response to file with explicit section boundaries
    # so system and user prompts are never duplicated or mixed.
    with open(filepath, "w", encoding="utf-8") as f:
        f.write("=" * 80 + "\n")
        f.write("LLM ENTITY DETECTION - PROMPT AND RESPONSE LOG\n")
        f.write("=" * 80 + "\n\n")
        f.write(f"Timestamp: {now.strftime('%Y-%m-%d %H:%M:%S')}\n")
        if file_name:
            f.write(f"File: {file_name}\n")
        if sheet_name:
            f.write(f"Sheet: {sheet_name}\n")
        if column_name is not None:
            f.write(f"Column: {column_name}\n")
        if row_number is not None:
            f.write(f"Row: {row_number}\n")
        if page_number is not None:
            f.write(f"Page: {page_number + 1}\n")
        if input_tokens is not None:
            f.write(f"Input tokens: {input_tokens}\n")
        if output_tokens is not None:
            f.write(f"Output tokens: {output_tokens}\n")
        f.write(f"Batch Number: {batch_number}\n")
        f.write(f"Model: {model_choice}\n")
        f.write(f"Language: {language}\n")
        f.write(f"Temperature: {temperature}\n")
        f.write(f"Max Tokens: {max_tokens}\n")
        f.write(f"Entities to Detect: {entities_str}\n")

        f.write("\n" + "=" * 80 + "\n")
        f.write("SYSTEM PROMPT (sent as system role)\n")
        f.write("=" * 80 + "\n")
        f.write("--- BEGIN SYSTEM PROMPT ---\n")
        f.write(system_prompt_str)
        f.write("\n--- END SYSTEM PROMPT ---\n")

        f.write("\n" + "=" * 80 + "\n")
        f.write("USER PROMPT (sent as user role)\n")
        f.write("=" * 80 + "\n")
        # Flag a likely caller bug where the same text was passed for both roles.
        if (
            system_prompt_str
            and user_prompt_str
            and system_prompt_str == user_prompt_str
        ):
            f.write(
                "[NOTE: System and user prompt content were identical - check caller.]\n"
            )
        f.write("--- BEGIN USER PROMPT ---\n")
        f.write(user_prompt_str)
        f.write("\n--- END USER PROMPT ---\n")

        f.write("\n\n" + "=" * 80 + "\n")
        f.write("LLM RESPONSE\n")
        f.write("=" * 80 + "\n\n")
        f.write(response_str)
        f.write("\n\n" + "=" * 80 + "\n")
        f.write("END OF LOG\n")
        f.write("=" * 80 + "\n")
    return filepath
def call_llm_for_entity_detection(
text: str,
entities_to_detect: List[str],
language: str,
bedrock_runtime: Optional[boto3.Session.client] = None,
model_choice: str = CLOUD_LLM_PII_MODEL_CHOICE,
temperature: float = LLM_TEMPERATURE,
max_tokens: int = LLM_MAX_NEW_TOKENS,
max_retries: int = 10,
retry_delay: int = 3,
output_folder: Optional[str] = None,
batch_number: int = 0,
custom_instructions: str = "",
file_name: Optional[str] = None,
page_number: Optional[int] = None,
sheet_name: Optional[str] = None,
column_name: Optional[str] = None,
row_number: Optional[int] = None,
inference_method: Optional[str] = None,
local_model=None,
tokenizer=None,
assistant_model=None,
client=None,
client_config=None,
api_url: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""
Call LLM to detect entities in text using various inference methods.
Args:
text: Text to analyze
entities_to_detect: List of entity types to detect
language: Language code
bedrock_runtime: AWS Bedrock runtime client (required for AWS method)
model_choice: Model identifier (varies by inference method)
temperature: Temperature for LLM generation (lower = more deterministic)
max_tokens: Maximum tokens in response
max_retries: Maximum retry attempts
retry_delay: Delay between retries (seconds)
output_folder: Optional folder to save prompt/response logs
batch_number: Batch number for logging
custom_instructions: Optional custom instructions to include in the prompt
file_name: Optional file name (without extension) for saving logs
page_number: Optional page number for saving logs (document flow)
sheet_name: Optional Excel sheet name for tabular logs
column_name: Optional column name for tabular logs
row_number: Optional row number (1-based) for tabular logs
inference_method: Inference method to use ("aws-bedrock", "local", "inference-server", "azure-openai", "gemini")
If None, uses CHOSEN_LLM_PII_INFERENCE_METHOD from config
local_model: Local model instance (required for "local" method)
tokenizer: Tokenizer instance (required for "local" method with transformers)
assistant_model: Assistant model for speculative decoding (optional)
client: API client (required for "azure-openai" or "gemini" methods)
client_config: Client config (required for "gemini" method)
api_url: API URL for inference-server (required for "inference-server" method)
Returns:
List of entity dictionaries
"""
# Ensure custom_instructions is a string (callers may pass bool or other types).
# Treat boolean True and the string "True" as empty (e.g. from an unchecked/empty Gradio box).
if not isinstance(custom_instructions, str):
custom_instructions = (
""
if custom_instructions is True or not custom_instructions
else str(custom_instructions)
)
if (
isinstance(custom_instructions, str)
and custom_instructions.strip().lower() == "true"
):
custom_instructions = ""
# Determine inference method
if inference_method is None:
inference_method = CHOSEN_LLM_PII_INFERENCE_METHOD
# When custom instructions are provided, use the upgraded model if configured
custom_instructions_model = (
CLOUD_LLM_PII_CUSTOM_INSTRUCTIONS_MODEL_CHOICE.strip()
if isinstance(CLOUD_LLM_PII_CUSTOM_INSTRUCTIONS_MODEL_CHOICE, str)
and CLOUD_LLM_PII_CUSTOM_INSTRUCTIONS_MODEL_CHOICE.strip()
else ""
)
if (
custom_instructions.strip()
and model_choice == CLOUD_LLM_PII_MODEL_CHOICE
and custom_instructions_model
):
model_choice = custom_instructions_model
# Filter out CUSTOM_VLM_* entities (these are handled separately via VLM)
filtered_entities = [
entity for entity in entities_to_detect if not entity.startswith("CUSTOM_VLM_")
]
# No standard entities and no custom instructions
if not filtered_entities and (
not custom_instructions or not custom_instructions.strip()
):
# Nothing selected at all → error
if not entities_to_detect:
raise ValueError(
"No standard entities selected and no custom instructions provided. "
"Please select at least one entity type (excluding CUSTOM_VLM_* entities) or provide custom instructions for LLM-based PII detection."
)
# Only CUSTOM_VLM_* entities selected (handled separately via VLM) → return blank
return []
# Determine model source from model_choice if using model_name_map
model_source = None
if model_choice and model_name_map and model_choice in model_name_map:
model_source = model_name_map[model_choice].get("source", "AWS")
# Map model source to inference method
if model_source == "Local":
inference_method = "local"
elif model_source == "inference-server":
inference_method = "inference-server"
elif model_source == "Azure/OpenAI":
inference_method = "azure-openai"
elif model_source == "Gemini":
inference_method = "gemini"
elif model_source == "AWS":
inference_method = "aws-bedrock"
system_prompt = create_entity_detection_system_prompt(
filtered_entities, language, custom_instructions
)
user_prompt = create_entity_detection_prompt(
text, filtered_entities, language, custom_instructions
)
# Map inference_method to model_source format expected by send_request
model_source_map = {
"aws-bedrock": "AWS",
"local": "Local",
"inference-server": "inference-server",
"azure-openai": "Azure/OpenAI",
"gemini": "Gemini",
}
model_source = model_source_map.get(inference_method, "AWS")
# Prepare client and config for Gemini if needed
if inference_method == "gemini" and (client is None or client_config is None):
from tools.llm_funcs import construct_gemini_generative_model
try:
client, client_config = construct_gemini_generative_model(
in_api_key="", # Will use environment variable
temperature=temperature,
model_choice=model_choice,
system_prompt=system_prompt,
max_tokens=max_tokens, # Use our specific max_tokens for entity detection
)
except Exception as e:
raise ValueError(
f"Failed to construct Gemini client: {e}. "
f"Ensure GEMINI_API_KEY is set or pass client and client_config."
)
# Prepare client for Azure/OpenAI if needed
if inference_method == "azure-openai" and client is None:
from tools.llm_funcs import construct_azure_client
try:
client, _ = construct_azure_client(
in_api_key="", # Will use environment variable
endpoint="", # Will use environment variable
)
except Exception as e:
raise ValueError(
f"Failed to construct Azure/OpenAI client: {e}. "
f"Ensure AZURE_OPENAI_API_KEY is set or pass client."
)
# Set up API URL for inference-server if needed
if inference_method == "inference-server" and api_url is None:
api_url = INFERENCE_SERVER_API_URL
if not api_url:
raise ValueError(
"api_url is required when using inference-server method. "
"Set INFERENCE_SERVER_API_URL in config or pass api_url parameter."
)
try:
# Call send_request which handles all routing, retries, and response parsing
# Note: send_request signature shows local_model=list() but it's actually used as a single model object
(
response,
conversation_history,
response_text,
num_transformer_input_tokens,
num_transformer_generated_tokens,
) = send_request(
prompt=user_prompt,
conversation_history=[], # Empty for entity detection (no conversation history needed)
client=client,
config=client_config,
model_choice=model_choice,
system_prompt=system_prompt,
temperature=temperature,
bedrock_runtime=bedrock_runtime,
model_source=model_source,
# local_model=(
# local_model if local_model else []
# ), # Pass model directly (signature shows list but uses as single object)
# tokenizer=tokenizer,
# assistant_model=assistant_model,
progress=Progress(
track_tqdm=False
), # Disable progress bar for entity detection
api_url=api_url,
)
except Exception as e:
print(f"LLM entity detection failed: {e}")
raise
# Extract token usage from response (before save so we can write it to the log file)
input_tokens = 0
output_tokens = 0
try:
if isinstance(response, dict) and "usage" in response:
# inference-server or llama-cpp format
input_tokens = response["usage"].get("prompt_tokens", 0)
output_tokens = response["usage"].get("completion_tokens", 0)
elif hasattr(response, "usage_metadata"):
# Check if it's AWS Bedrock format
if isinstance(response.usage_metadata, dict):
input_tokens = response.usage_metadata.get("inputTokens", 0)
output_tokens = response.usage_metadata.get("outputTokens", 0)
# Check if it's Gemini format
elif hasattr(response.usage_metadata, "prompt_token_count"):
input_tokens = response.usage_metadata.prompt_token_count
output_tokens = response.usage_metadata.candidates_token_count
except (KeyError, AttributeError) as e:
print(f"Warning: Could not extract token usage from response: {e}")
# Fallback for Local/transformers: response is plain text, so use token counts from send_request
if num_transformer_input_tokens and num_transformer_input_tokens > 0:
input_tokens = num_transformer_input_tokens
if num_transformer_generated_tokens and num_transformer_generated_tokens > 0:
output_tokens = num_transformer_generated_tokens
# Save prompt and response if output_folder is provided.
# Use the same system_prompt and user_prompt that were sent to the model
# (no combined/rendered version) so the log correctly shows system vs user.
if output_folder and response_text:
try:
saved_file = save_llm_prompt_response(
system_prompt=system_prompt,
user_prompt=user_prompt,
response_text=response_text,
output_folder=output_folder,
batch_number=batch_number,
model_choice=model_choice,
entities_to_detect=entities_to_detect,
language=language,
temperature=temperature,
max_tokens=max_tokens,
file_name=file_name,
page_number=page_number,
sheet_name=sheet_name,
column_name=column_name,
row_number=row_number,
input_tokens=input_tokens,
output_tokens=output_tokens,
)
if 0 == 1: # To avoid lint check issue
print(f"Saved LLM prompt/response to: {saved_file}")
except Exception as e:
print(f"Warning: Could not save LLM prompt/response: {e}")
# Parse the response
entities = parse_llm_entity_response(response_text, text)
return entities, input_tokens, output_tokens
def map_back_llm_entity_results(
    entities: List[Dict[str, Any]],
    current_batch_mapping: List[Tuple],
    allow_list: List[str],
    chosen_redact_llm_entities: List[str],
    all_text_line_results: List[Tuple],
) -> List[Tuple]:
    """
    Map LLM-detected entities back to line-level results.

    Each entity's batch-level ``BeginOffset``/``EndOffset`` is projected onto
    every line segment in ``current_batch_mapping`` that it overlaps. The span
    is converted to offsets within that line, turned into a recogniser result,
    and appended to that line's entry in ``all_text_line_results``.

    Args:
        entities: Entity dictionaries from the LLM. "BeginOffset" and
            "EndOffset" are required. "Type" and "Score" are optional
            ("Score" defaults to 0.8).
        current_batch_mapping: Tuples of
            (batch_start, line_idx, original_line, chars, line_offset).
            ``original_line`` must expose ``.text``. ``line_offset`` is where
            this segment starts within the line, or None for the whole line.
        allow_list: Text values to skip, compared after strip() + lower().
            Allow-listed spans overrule the LLM's PII detection.
        chosen_redact_llm_entities: Requested entity types. Not used for
            filtering: every type the LLM returns is kept, including custom
            types produced by custom instructions.
        all_text_line_results: Existing ``(line_idx, [results])`` list.
            It is mutated in place and returned.

    Returns:
        The updated all_text_line_results.
    """
    if not entities:
        return all_text_line_results

    # Normalise once into a set for O(1) case-insensitive lookups.
    allowed = {item.strip().lower() for item in allow_list or [] if item}

    # Index existing per-line entries once, rather than scanning the whole
    # (growing) results list for every matched span. setdefault keeps the
    # FIRST entry for a line, as a linear first-match search would.
    results_by_line: Dict[Any, list] = {}
    for existing_idx, existing_results in all_text_line_results:
        results_by_line.setdefault(existing_idx, existing_results)

    for entity in entities:
        entity_type = entity.get("Type")
        entity_start = entity["BeginOffset"]
        entity_end = entity["EndOffset"]
        overlapped_any_line = False

        for (
            batch_start,
            line_idx,
            original_line,
            _chars,
            line_offset,
        ) in current_batch_mapping:
            line_text = original_line.text
            # Position within the line where this batch segment begins.
            segment_start = line_offset if line_offset is not None else 0
            batch_end = batch_start + len(line_text[segment_start:])

            if not (batch_start < entity_end and batch_end > entity_start):
                continue
            overlapped_any_line = True

            # Convert batch offsets to line offsets and clamp them to this
            # segment. An entity that starts before the segment (e.g. on the
            # previous line of the batch) must not pull in line text before
            # line_offset, because that text belongs to an earlier batch.
            relative_start = max(
                segment_start, entity_start - batch_start + segment_start
            )
            relative_end = min(
                entity_end - batch_start + segment_start, len(line_text)
            )

            # Allow-listed text overrules LLM PII detection.
            result_text = line_text[relative_start:relative_end]
            if result_text.strip().lower() in allowed:
                continue

            # Imported lazily to avoid a circular import.
            from tools.presidio_analyzer_custom import recognizer_result_from_dict

            recogniser_entity = recognizer_result_from_dict(
                {
                    "Type": entity_type,
                    "BeginOffset": relative_start,
                    "EndOffset": relative_end,
                    "Score": entity.get("Score", 0.8),
                }
            )

            line_results = results_by_line.get(line_idx)
            if line_results is None:
                line_results = [recogniser_entity]
                results_by_line[line_idx] = line_results
                all_text_line_results.append((line_idx, line_results))
            else:
                line_results.append(recogniser_entity)

        # Only report spans that truly overlap no line. Allow-listed spans did
        # overlap a line; they were skipped on purpose.
        if not overlapped_any_line:
            print(
                f"Entity '{entity_type}' at position {entity_start}-{entity_end} does not fit in any line."
            )

    return all_text_line_results
def do_llm_entity_detection_call(
    current_batch: str,
    current_batch_mapping: List[Tuple],
    bedrock_runtime: Optional[boto3.Session.client] = None,
    language: str = "en",
    allow_list: List[str] = None,
    chosen_redact_llm_entities: List[str] = None,
    all_text_line_results: List[Tuple] = None,
    model_choice: str = CLOUD_LLM_PII_MODEL_CHOICE,
    temperature: float = LLM_TEMPERATURE,
    max_tokens: int = LLM_MAX_NEW_TOKENS,
    output_folder: Optional[str] = None,
    batch_number: int = 0,
    custom_instructions: str = "",
    file_name: Optional[str] = None,
    page_number: Optional[int] = None,
    inference_method: Optional[str] = None,
    local_model=None,
    tokenizer=None,
    assistant_model=None,
    client=None,
    client_config=None,
    api_url: Optional[str] = None,
) -> Tuple[List[Tuple], int, int]:
    """
    Call LLM for entity detection on a batch of text.

    Similar interface to do_aws_llm_call. The batch is stripped before it is
    sent to the LLM. The returned entity offsets are then shifted back onto
    the unstripped batch, which ``current_batch_mapping`` indexes, before
    being mapped to lines via map_back_llm_entity_results.

    Args:
        current_batch: Text batch to analyze
        current_batch_mapping: Mapping of batch positions to line indices
        bedrock_runtime: AWS Bedrock runtime client (required for AWS method)
        language: Language code
        allow_list: List of allowed text values
        chosen_redact_llm_entities: List of entity types to detect
        all_text_line_results: Existing line-level results
        model_choice: Model identifier (varies by inference method)
        temperature: Temperature for LLM generation
        max_tokens: Maximum tokens in response
        output_folder: Optional folder to save prompt/response logs
        batch_number: Batch number for logging
        custom_instructions: Optional custom instructions to include in the prompt
        file_name: Optional file name (without extension) for saving logs
        page_number: Optional page number for saving logs
        inference_method: Inference method to use (if None, uses config default)
        local_model: Local model instance (required for "local" method)
        tokenizer: Tokenizer instance (required for "local" method with transformers)
        assistant_model: Assistant model for speculative decoding (optional)
        client: API client (required for "azure-openai" or "gemini" methods)
        client_config: Client config (required for "gemini" method)
        api_url: API URL for inference-server (required for "inference-server" method)

    Returns:
        Tuple of (updated all_text_line_results, input_tokens, output_tokens).
        Token counts are 0 when no LLM call produced a result.

    Raises:
        Any exception raised during detection or mapping. It is printed and
        then re-raised.
    """
    if all_text_line_results is None:
        all_text_line_results = []

    # Empty or whitespace-only batches have nothing to detect; skip the LLM call.
    stripped_batch = current_batch.strip() if current_batch else ""
    if not stripped_batch:
        return all_text_line_results, 0, 0

    if allow_list is None:
        allow_list = []
    if chosen_redact_llm_entities is None:
        chosen_redact_llm_entities = []

    # The LLM only sees the stripped text, so its offsets are relative to that.
    # current_batch_mapping indexes the unstripped batch, so record how many
    # leading characters were removed in order to shift offsets back.
    leading_trim = len(current_batch) - len(current_batch.lstrip())

    try:
        detection_result = call_llm_for_entity_detection(
            text=stripped_batch,
            entities_to_detect=chosen_redact_llm_entities,
            language=language,
            bedrock_runtime=bedrock_runtime,
            model_choice=model_choice,
            temperature=temperature,
            max_tokens=max_tokens,
            output_folder=output_folder,
            batch_number=batch_number,
            custom_instructions=custom_instructions,
            file_name=file_name,
            page_number=page_number,
            inference_method=inference_method,
            local_model=local_model,
            tokenizer=tokenizer,
            assistant_model=assistant_model,
            client=client,
            client_config=client_config,
            api_url=api_url,
        )
        # The detection helper appears able to return a bare [] instead of an
        # (entities, input_tokens, output_tokens) tuple (e.g. when only
        # CUSTOM_VLM_* entity types are selected). Unpacking that would raise.
        if not detection_result:
            return all_text_line_results, 0, 0
        entities, input_tokens, output_tokens = detection_result

        if leading_trim and entities:
            entities = [
                {
                    **entity,
                    "BeginOffset": entity["BeginOffset"] + leading_trim,
                    "EndOffset": entity["EndOffset"] + leading_trim,
                }
                for entity in entities
            ]

        all_text_line_results = map_back_llm_entity_results(
            entities,
            current_batch_mapping,
            allow_list,
            chosen_redact_llm_entities,
            all_text_line_results,
        )
        return all_text_line_results, input_tokens, output_tokens
    except Exception as e:
        print(f"LLM entity detection call failed: {e}")
        raise
|