Spaces:
Running on Zero
Running on Zero
| """ | |
| LLM-based entity detection using AWS Bedrock. | |
This module provides functions to detect PII entities using LLMs instead of AWS Comprehend.
| """ | |
| import json | |
| import os | |
| import re | |
| from datetime import datetime | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import boto3 | |
| from gradio import Progress | |
| from tools.config import ( | |
| CHOSEN_LLM_PII_INFERENCE_METHOD, | |
| CLOUD_LLM_PII_CUSTOM_INSTRUCTIONS_MODEL_CHOICE, | |
| CLOUD_LLM_PII_MODEL_CHOICE, | |
| INFERENCE_SERVER_API_URL, | |
| LLM_MAX_NEW_TOKENS, | |
| LLM_TEMPERATURE, | |
| model_name_map, | |
| ) | |
| from tools.llm_entity_detection_prompts import ( | |
| create_entity_detection_prompt, | |
| create_entity_detection_system_prompt, | |
| ) | |
| # Max length for column/sheet name in tabular log filenames (to keep filenames short) | |
| LLM_LOG_TABULAR_NAME_MAX_LEN = 25 | |
| # Import LLM functions from local tools.llm_funcs | |
| try: | |
| # Use send_request from llm_funcs.py which handles all model sources, retries, and response parsing | |
| from tools.llm_funcs import ( | |
| send_request, | |
| ) | |
| except ImportError as e: | |
| print(f"Warning: Could not import LLM functions: {e}") | |
| print("LLM-based entity detection will not be available.") | |
| print("Please ensure llm_funcs.py is in the tools folder.") | |
| call_aws_bedrock = None | |
| construct_azure_client = None | |
| ResponseObject = None | |
| def _find_text_in_passage( | |
| search_text: str, | |
| original_text: str, | |
| reported_offset: Optional[int] = None, | |
| start_from: int = 0, | |
| ) -> Optional[Tuple[int, int]]: | |
| """ | |
| Find the position of search_text in original_text and return (begin, end) offsets. | |
| Only considers occurrences at or after start_from. This allows a "first pass" where | |
| each entity is matched starting after the previous entity's end, so repeated phrases | |
| (e.g. "University of Notre Dame" vs "University" + "of Notre Dame") map to the | |
| correct occurrence. | |
| Args: | |
| search_text: The text to search for | |
| original_text: The text to search in | |
| reported_offset: Optional offset reported by LLM (used to disambiguate multiple matches) | |
| start_from: Only consider matches at or after this position (default 0). | |
| Returns: | |
| Tuple of (begin_offset, end_offset) if found, None otherwise | |
| """ | |
| if not search_text: | |
| return None | |
| def first_or_closest( | |
| positions: List[int], length: int | |
| ) -> Optional[Tuple[int, int]]: | |
| candidates = [p for p in positions if p >= start_from] | |
| if not candidates: | |
| return None | |
| if reported_offset is not None: | |
| closest_pos = min(candidates, key=lambda p: abs(p - reported_offset)) | |
| else: | |
| closest_pos = min(candidates) | |
| return (closest_pos, closest_pos + length) | |
| # Clean search text - remove trailing ellipsis that LLM might add | |
| search_text_clean = search_text.rstrip("...").strip() | |
| # Find all occurrences of the exact text | |
| all_positions = [] | |
| start = 0 | |
| while True: | |
| pos = original_text.find(search_text, start) | |
| if pos == -1: | |
| break | |
| all_positions.append(pos) | |
| start = pos + 1 | |
| if all_positions: | |
| result = first_or_closest(all_positions, len(search_text)) | |
| if result is not None: | |
| return result | |
| # Try with cleaned text (without ellipsis) if original didn't match | |
| if search_text_clean != search_text: | |
| all_positions_clean = [] | |
| start = 0 | |
| while True: | |
| pos = original_text.find(search_text_clean, start) | |
| if pos == -1: | |
| break | |
| all_positions_clean.append(pos) | |
| start = pos + 1 | |
| if all_positions_clean: | |
| result = first_or_closest(all_positions_clean, len(search_text_clean)) | |
| if result is not None: | |
| return result | |
| # Try case-insensitive match | |
| search_text_lower = search_text.lower() | |
| original_text_lower = original_text.lower() | |
| all_positions_lower = [] | |
| start = 0 | |
| while True: | |
| pos = original_text_lower.find(search_text_lower, start) | |
| if pos == -1: | |
| break | |
| all_positions_lower.append(pos) | |
| start = pos + 1 | |
| if all_positions_lower: | |
| result = first_or_closest(all_positions_lower, len(search_text)) | |
| if result is not None: | |
| return result | |
| # Try case-insensitive match with cleaned text | |
| if search_text_clean != search_text: | |
| search_text_clean_lower = search_text_clean.lower() | |
| all_positions_clean_lower = [] | |
| start = 0 | |
| while True: | |
| pos = original_text_lower.find(search_text_clean_lower, start) | |
| if pos == -1: | |
| break | |
| all_positions_clean_lower.append(pos) | |
| start = pos + 1 | |
| if all_positions_clean_lower: | |
| result = first_or_closest(all_positions_clean_lower, len(search_text_clean)) | |
| if result is not None: | |
| return result | |
| return None | |
| def _find_all_text_in_passage( | |
| search_text: str, original_text: str | |
| ) -> List[Tuple[int, int]]: | |
| """ | |
| Find all positions of search_text in original_text and return a list of (begin, end) offsets. | |
| Uses the same search strategy as _find_text_in_passage (exact, then cleaned, then case-insensitive). | |
| LLM offset values are never used; positions come only from search. | |
| Returns: | |
| List of (begin_offset, end_offset) tuples, sorted by begin_offset (ascending). | |
| """ | |
| if not search_text: | |
| return [] | |
| search_text_clean = search_text.rstrip("...").strip() | |
| def find_all_exact(needle: str, haystack: str) -> List[Tuple[int, int]]: | |
| result = [] | |
| start = 0 | |
| while True: | |
| pos = haystack.find(needle, start) | |
| if pos == -1: | |
| break | |
| result.append((pos, pos + len(needle))) | |
| start = pos + 1 | |
| return result | |
| positions = find_all_exact(search_text, original_text) | |
| if positions: | |
| return sorted(positions, key=lambda p: p[0]) | |
| if search_text_clean != search_text: | |
| positions = find_all_exact(search_text_clean, original_text) | |
| if positions: | |
| return sorted(positions, key=lambda p: p[0]) | |
| # Case-insensitive | |
| needle_lower = search_text.lower() | |
| haystack_lower = original_text.lower() | |
| positions = find_all_exact(needle_lower, haystack_lower) | |
| if positions: | |
| # Return (start, start + len(search_text)) so length matches original entity text | |
| return sorted( | |
| [(p[0], p[0] + len(search_text)) for p in positions], key=lambda p: p[0] | |
| ) | |
| if search_text_clean != search_text: | |
| needle_clean_lower = search_text_clean.lower() | |
| positions = find_all_exact(needle_clean_lower, haystack_lower) | |
| if positions: | |
| return sorted( | |
| [(p[0], p[0] + len(search_text_clean)) for p in positions], | |
| key=lambda p: p[0], | |
| ) | |
| return [] | |
| def _entity_get(obj: Dict[str, Any], key: str, default: Any = None) -> Any: | |
| """Get value from entity dict with case-insensitive key lookup (e.g. BeginOffset vs beginOffset).""" | |
| key_lower = key.lower() | |
| for k, v in obj.items(): | |
| if k.lower() == key_lower: | |
| return v | |
| return default | |
def parse_llm_entity_response(
    response_text: str,
    original_text: str,
) -> List[Dict[str, Any]]:
    """
    Parse LLM response and extract entity information.

    LLM BeginOffset/EndOffset are used only to define order. Positions are
    resolved by a first-pass text search: for each entity (in reported order),
    search for the entity's Text in the passage starting from the end of the
    preceding entity's resolved span. If not found there, search from the
    start of the passage. This ensures repeated phrases (e.g. "University of
    Notre Dame" once, then "University" and "of Notre Dame" separately) map
    to the correct occurrence and avoid duplicate redaction boxes.

    Args:
        response_text: The LLM response text (should contain JSON)
        original_text: The original text that was analyzed (for validation)

    Returns:
        List of entity dictionaries with keys: Type, BeginOffset, EndOffset, Score, Text
    """

    def strip_trailing_tokens(s: str) -> str:
        # Remove trailing artifacts some models append after the JSON:
        # <end_of_turn> markers, stray backticks, and whitespace.
        s = re.sub(r"<end_of_turn>\s*$", "", s, flags=re.IGNORECASE)
        s = re.sub(r"`+$", "", s)
        return s.rstrip()

    def extract_root_object(s: str) -> Optional[str]:
        # Return the first brace-balanced {...} object in s so trailing garbage
        # is never included. If braces never balance (truncated output), fall
        # back to everything from the first '{'. None when s has no '{' at all.
        start = s.find("{")
        if start < 0:
            return None
        depth = 0
        for i in range(start, len(s)):
            if s[i] == "{":
                depth += 1
            elif s[i] == "}":
                depth -= 1
                if depth == 0:
                    return s[start : i + 1]
        return s[start:]

    def to_int(value: Any) -> Optional[int]:
        # Coerce an LLM-reported offset to int; None when absent/unparseable.
        if value is None:
            return None
        try:
            return int(value)
        except (ValueError, TypeError):
            return None

    entities_out: List[Dict[str, Any]] = []

    # Remove <think>/<thinking> tags and their content (common in some LLM outputs).
    response_text = re.sub(
        r"<think>.*?</think>", "", response_text, flags=re.DOTALL | re.IGNORECASE
    )
    response_text = re.sub(
        r"<thinking>.*?</thinking>", "", response_text, flags=re.DOTALL | re.IGNORECASE
    )

    # Prefer extracting from a markdown code block (e.g. ```json\n...\n```<end_of_turn>)
    # so we get a clean slice and can strip trailing tokens before parsing.
    json_str: Optional[str] = None
    if "```" in response_text:
        code_block = re.search(
            r"```(?:json)?\s*\n?(.*?)(?:\n?```|$)", response_text, re.DOTALL
        )
        if code_block:
            candidate = strip_trailing_tokens(code_block.group(1).strip())
            json_str = extract_root_object(candidate)

    # Fallback: try regex-based extraction (fragile for nested braces).
    if json_str is None:
        json_match = re.search(
            r'\{[^{}]*"entities"[^{}]*\[.*?\].*?\}', response_text, re.DOTALL
        )
        if not json_match:
            json_match = re.search(r'\{.*?"entities".*?\}', response_text, re.DOTALL)
        if json_match:
            json_str = json_match.group(0)

    if not json_str:
        print("Warning: Could not find JSON in LLM response")
        print(f"Response text: {response_text[:500]}")
        return entities_out

    try:
        # Clean up the JSON string (in case we came from the regex path):
        # remove code-block markers, trailing tokens, and keep only the root object.
        json_str = json_str.strip()
        json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE)
        json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE)
        json_str = strip_trailing_tokens(json_str)
        json_str = extract_root_object(json_str) or json_str

        # Fix common JSON issues:
        # 1. Remove trailing commas before closing brackets/braces.
        json_str = re.sub(r",\s*}", "}", json_str)
        json_str = re.sub(r",\s*]", "]", json_str)

        # 2. Fix unquoted string values (e.g., "Type": NAME should be "Type": "NAME").
        #    This handles cases where LLMs output unquoted identifiers as values.
        def fix_unquoted_value(match):
            key_part = match.group(1)  # The quoted key (e.g., "Type")
            value = match.group(2)  # The unquoted value
            separator = match.group(3)  # The separator (comma, closing brace, etc.)
            # Only fix identifiers - leave numbers and true/false/null alone.
            if re.match(
                r"^[A-Za-z_][A-Za-z0-9_]*$", value
            ) and value.lower() not in ["true", "false", "null"]:
                return f'{key_part}: "{value}"{separator}'
            return match.group(0)

        # Match: "key": VALUE where VALUE is an unquoted identifier followed by , } or ]
        json_str = re.sub(
            r'("[\w]+")\s*:\s*([A-Za-z_][A-Za-z0-9_]*)\s*([,}\]])',
            fix_unquoted_value,
            json_str,
        )
        # Same fix when the unquoted value is followed by a newline.
        json_str = re.sub(
            r'("[\w]+")\s*:\s*([A-Za-z_][A-Za-z0-9_]*)\s*(\n)',
            r'\1: "\2"\3',
            json_str,
        )

        # Final trim: strip trailing bytes and re-extract the root object so no
        # trailing \r, ```, <end_of_turn>, etc. break json.loads.
        json_str = strip_trailing_tokens(json_str)
        json_str = extract_root_object(json_str) or json_str

        try:
            data = json.loads(json_str)
        except json.JSONDecodeError as e:
            # Fallback: quote ANY unquoted identifier after a colon.
            print(
                f"Initial JSON parse failed: {e}. Attempting more aggressive fixes..."
            )

            def quote_unquoted_identifier(match):
                prefix = match.group(1)  # ':' plus following whitespace
                value = match.group(2)  # The unquoted value
                suffix = match.group(3)  # Trailing delimiter (comma, brace, bracket)
                if re.match(
                    r"^[A-Za-z_][A-Za-z0-9_]*$", value
                ) and value.lower() not in ["true", "false", "null"]:
                    # prefix already contains the ':' - do not add another one.
                    return f'{prefix}"{value}"{suffix}'
                return match.group(0)

            json_str = re.sub(
                r"(:\s*)([A-Za-z_][A-Za-z0-9_]*)(\s*[,}\]])",
                quote_unquoted_identifier,
                json_str,
            )
            try:
                data = json.loads(json_str)
            except json.JSONDecodeError as e2:
                print(f"JSON parsing failed after fixes: {e2}")
                print(f"Cleaned JSON string (first 1000 chars): {json_str[:1000]}")
                raise e2

        if "entities" in data and isinstance(data["entities"], list):
            # Collect raw entity records (Type, Text, Score, reported BeginOffset
            # for ordering only - never trusted as a position).
            raw_entities: List[Dict[str, Any]] = []
            for entity in data["entities"]:
                entity_type_val = _entity_get(entity, "Type")
                if entity_type_val is None:
                    print(f"Warning: Entity missing Type field: {entity}")
                    continue
                entity_text = _entity_get(entity, "Text", "")
                reported_begin = to_int(_entity_get(entity, "BeginOffset"))
                reported_end = to_int(_entity_get(entity, "EndOffset"))
                # If no Text, try to derive it from reported offsets
                # (for display/grouping only).
                if (
                    not entity_text
                    and reported_begin is not None
                    and reported_end is not None
                    and 0 <= reported_begin < reported_end <= len(original_text)
                ):
                    entity_text = original_text[reported_begin:reported_end]
                if not entity_text:
                    print(
                        f"Warning: Entity of type '{entity_type_val}' has no Text value and invalid offsets"
                    )
                    continue
                raw_entities.append(
                    {
                        "Type": str(entity_type_val),
                        "Text": entity_text,
                        "Score": float(_entity_get(entity, "Score", 0.8)),
                        "reported_begin": reported_begin,
                    }
                )

            # Process entities in reported order. First pass: search for each
            # entity's Text starting from the preceding entity's resolved end;
            # if not found there, search from the start of the passage. This
            # disambiguates repeated phrases. Entities with no reported offset
            # sort last.
            ordered = sorted(
                raw_entities,
                key=lambda r: (
                    r["reported_begin"] is None,
                    r["reported_begin"] or 0,
                ),
            )
            search_start = 0
            for rec in ordered:
                search_text = rec["Text"]
                result = _find_text_in_passage(
                    search_text,
                    original_text,
                    reported_offset=rec["reported_begin"],
                    start_from=search_start,
                )
                if result is None:
                    result = _find_text_in_passage(
                        search_text,
                        original_text,
                        reported_offset=rec["reported_begin"],
                        start_from=0,
                    )
                if result is None:
                    print(
                        f"Warning: Could not find text '{search_text[:50]}...' in original passage"
                    )
                    continue
                begin, end = result
                entities_out.append(
                    {
                        "Type": rec["Type"],
                        "BeginOffset": begin,
                        "EndOffset": end,
                        "Score": rec["Score"],
                        # Use the passage's own slice so Text matches the span exactly.
                        "Text": original_text[begin:end],
                    }
                )
                search_start = end
    except json.JSONDecodeError as e:
        print(f"Error parsing JSON from LLM response: {e}")
        print(f"Response text: {response_text[:500]}")
    except (ValueError, KeyError) as e:
        print(f"Error processing entity data: {e}")

    return entities_out
| def _sanitize_for_filename(s: str, max_len: Optional[int] = None) -> str: | |
| """Sanitize a string for use in a filename (alphanumeric, spaces to underscores).""" | |
| out = ( | |
| "".join(c for c in (s or "") if c.isalnum() or c in (" ", "-", "_")) | |
| .strip() | |
| .replace(" ", "_") | |
| ) | |
| if max_len is not None and len(out) > max_len: | |
| out = out[:max_len] | |
| return out or "unknown" | |
def save_llm_prompt_response(
    system_prompt: str,
    user_prompt: str,
    response_text: str,
    output_folder: str,
    batch_number: int,
    model_choice: str,
    entities_to_detect: List[str],
    language: str,
    temperature: float,
    max_tokens: int,
    file_name: Optional[str] = None,
    page_number: Optional[int] = None,
    sheet_name: Optional[str] = None,
    column_name: Optional[str] = None,
    row_number: Optional[int] = None,
    input_tokens: Optional[int] = None,
    output_tokens: Optional[int] = None,
) -> str:
    """
    Save LLM prompt and response to a text file for traceability.

    Writes the exact system prompt and user prompt that were sent to the model
    (e.g. for local transformers, inference-server, AWS, etc.). Each section is
    clearly delimited so the log never duplicates or conflates system vs user.

    Args:
        system_prompt: System prompt sent to LLM (exactly as passed to the model).
        user_prompt: User prompt sent to LLM (exactly as passed to the model).
        response_text: Response text from LLM
        output_folder: Output folder path
        batch_number: Batch number for this call
        model_choice: Model used
        entities_to_detect: List of entities being detected
        language: Language code
        temperature: Temperature used
        max_tokens: Max tokens used
        file_name: Optional file name (without extension) for the filename / log header
        page_number: Optional page number (0-based) for the filename; displayed in log as 1-based.
        sheet_name: Optional Excel sheet name (tabular data); included in log and filename if present.
        column_name: Optional column name (tabular data); included in log and filename (shortened if long).
        row_number: Optional row number (1-based for display; tabular data); included in log and filename.
        input_tokens: Optional input token count from the LLM call
        output_tokens: Optional output token count from the LLM call

    Returns:
        Path to the saved file
    """
    # Normalise to strings so we never write "None" or non-string types
    system_prompt_str = (system_prompt if system_prompt is not None else "").strip()
    user_prompt_str = (user_prompt if user_prompt is not None else "").strip()
    # Create LLM logs subfolder (exist_ok so repeat calls don't raise)
    llm_logs_folder = os.path.join(output_folder, "llm_prompts_responses")
    os.makedirs(llm_logs_folder, exist_ok=True)
    # Any tabular context parameter switches the filename to the tabular scheme.
    # Tabular: filename = sheet (if relevant) + column (shortened) + row
    is_tabular = (
        column_name is not None or sheet_name is not None or row_number is not None
    )
    if is_tabular:
        parts = ["llm"]
        if sheet_name:
            # Sheet/column names are truncated so filenames stay short
            parts.append(
                _sanitize_for_filename(sheet_name, LLM_LOG_TABULAR_NAME_MAX_LEN)
            )
        if column_name:
            parts.append(
                _sanitize_for_filename(column_name, LLM_LOG_TABULAR_NAME_MAX_LEN)
            )
        if row_number is not None:
            # Zero-padded so filenames sort lexicographically by row
            parts.append(f"row{row_number:05d}")
        parts.append(f"batch_{batch_number:04d}")
        filename = "_".join(parts) + ".txt"
    elif file_name and page_number is not None:
        # Document: file name + page number
        safe_file_name = _sanitize_for_filename(file_name)
        filename = (
            f"llm_{safe_file_name}_page_{page_number:04d}_batch_{batch_number:04d}.txt"
        )
    else:
        # Fallback: timestamp keeps the filename unique without any document context
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"llm_batch_{batch_number:04d}_{timestamp}.txt"
    filepath = os.path.join(llm_logs_folder, filename)
    # Write prompt and response to file with explicit section boundaries
    # so system and user prompts are never duplicated or mixed.
    with open(filepath, "w", encoding="utf-8") as f:
        f.write("=" * 80 + "\n")
        f.write("LLM ENTITY DETECTION - PROMPT AND RESPONSE LOG\n")
        f.write("=" * 80 + "\n\n")
        f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        # Optional context lines: only written when the caller supplied them
        if file_name:
            f.write(f"File: {file_name}\n")
        if sheet_name:
            f.write(f"Sheet: {sheet_name}\n")
        if column_name is not None:
            f.write(f"Column: {column_name}\n")
        if row_number is not None:
            f.write(f"Row: {row_number}\n")
        if page_number is not None:
            # page_number is 0-based internally; display as 1-based
            f.write(f"Page: {page_number + 1}\n")
        if input_tokens is not None:
            f.write(f"Input tokens: {input_tokens}\n")
        if output_tokens is not None:
            f.write(f"Output tokens: {output_tokens}\n")
        f.write(f"Batch Number: {batch_number}\n")
        f.write(f"Model: {model_choice}\n")
        f.write(f"Language: {language}\n")
        f.write(f"Temperature: {temperature}\n")
        f.write(f"Max Tokens: {max_tokens}\n")
        f.write(f"Entities to Detect: {', '.join(entities_to_detect)}\n")
        f.write("\n" + "=" * 80 + "\n")
        f.write("SYSTEM PROMPT (sent as system role)\n")
        f.write("=" * 80 + "\n")
        f.write("--- BEGIN SYSTEM PROMPT ---\n")
        f.write(system_prompt_str)
        f.write("\n--- END SYSTEM PROMPT ---\n")
        f.write("\n" + "=" * 80 + "\n")
        f.write("USER PROMPT (sent as user role)\n")
        f.write("=" * 80 + "\n")
        # Flag a likely caller bug: system and user prompts should differ
        if (
            system_prompt_str
            and user_prompt_str
            and system_prompt_str == user_prompt_str
        ):
            f.write(
                "[NOTE: System and user prompt content were identical - check caller.]\n"
            )
        f.write("--- BEGIN USER PROMPT ---\n")
        f.write(user_prompt_str)
        f.write("\n--- END USER PROMPT ---\n")
        f.write("\n\n" + "=" * 80 + "\n")
        f.write("LLM RESPONSE\n")
        f.write("=" * 80 + "\n\n")
        f.write(response_text)
        f.write("\n\n" + "=" * 80 + "\n")
        f.write("END OF LOG\n")
        f.write("=" * 80 + "\n")
    return filepath
| def call_llm_for_entity_detection( | |
| text: str, | |
| entities_to_detect: List[str], | |
| language: str, | |
| bedrock_runtime: Optional[boto3.Session.client] = None, | |
| model_choice: str = CLOUD_LLM_PII_MODEL_CHOICE, | |
| temperature: float = LLM_TEMPERATURE, | |
| max_tokens: int = LLM_MAX_NEW_TOKENS, | |
| max_retries: int = 10, | |
| retry_delay: int = 3, | |
| output_folder: Optional[str] = None, | |
| batch_number: int = 0, | |
| custom_instructions: str = "", | |
| file_name: Optional[str] = None, | |
| page_number: Optional[int] = None, | |
| sheet_name: Optional[str] = None, | |
| column_name: Optional[str] = None, | |
| row_number: Optional[int] = None, | |
| inference_method: Optional[str] = None, | |
| local_model=None, | |
| tokenizer=None, | |
| assistant_model=None, | |
| client=None, | |
| client_config=None, | |
| api_url: Optional[str] = None, | |
| ) -> List[Dict[str, Any]]: | |
| """ | |
| Call LLM to detect entities in text using various inference methods. | |
| Args: | |
| text: Text to analyze | |
| entities_to_detect: List of entity types to detect | |
| language: Language code | |
| bedrock_runtime: AWS Bedrock runtime client (required for AWS method) | |
| model_choice: Model identifier (varies by inference method) | |
| temperature: Temperature for LLM generation (lower = more deterministic) | |
| max_tokens: Maximum tokens in response | |
| max_retries: Maximum retry attempts | |
| retry_delay: Delay between retries (seconds) | |
| output_folder: Optional folder to save prompt/response logs | |
| batch_number: Batch number for logging | |
| custom_instructions: Optional custom instructions to include in the prompt | |
| file_name: Optional file name (without extension) for saving logs | |
| page_number: Optional page number for saving logs (document flow) | |
| sheet_name: Optional Excel sheet name for tabular logs | |
| column_name: Optional column name for tabular logs | |
| row_number: Optional row number (1-based) for tabular logs | |
| inference_method: Inference method to use ("aws-bedrock", "local", "inference-server", "azure-openai", "gemini") | |
| If None, uses CHOSEN_LLM_PII_INFERENCE_METHOD from config | |
| local_model: Local model instance (required for "local" method) | |
| tokenizer: Tokenizer instance (required for "local" method with transformers) | |
| assistant_model: Assistant model for speculative decoding (optional) | |
| client: API client (required for "azure-openai" or "gemini" methods) | |
| client_config: Client config (required for "gemini" method) | |
| api_url: API URL for inference-server (required for "inference-server" method) | |
| Returns: | |
| List of entity dictionaries | |
| """ | |
| # Ensure custom_instructions is a string (callers may pass bool or other types). | |
| # Treat boolean True and the string "True" as empty (e.g. from an unchecked/empty Gradio box). | |
| if not isinstance(custom_instructions, str): | |
| custom_instructions = ( | |
| "" | |
| if custom_instructions is True or not custom_instructions | |
| else str(custom_instructions) | |
| ) | |
| if ( | |
| isinstance(custom_instructions, str) | |
| and custom_instructions.strip().lower() == "true" | |
| ): | |
| custom_instructions = "" | |
| # Determine inference method | |
| if inference_method is None: | |
| inference_method = CHOSEN_LLM_PII_INFERENCE_METHOD | |
| # When custom instructions are provided, use the upgraded model if configured | |
| custom_instructions_model = ( | |
| CLOUD_LLM_PII_CUSTOM_INSTRUCTIONS_MODEL_CHOICE.strip() | |
| if isinstance(CLOUD_LLM_PII_CUSTOM_INSTRUCTIONS_MODEL_CHOICE, str) | |
| and CLOUD_LLM_PII_CUSTOM_INSTRUCTIONS_MODEL_CHOICE.strip() | |
| else "" | |
| ) | |
| if ( | |
| custom_instructions.strip() | |
| and model_choice == CLOUD_LLM_PII_MODEL_CHOICE | |
| and custom_instructions_model | |
| ): | |
| model_choice = custom_instructions_model | |
| # Filter out CUSTOM_VLM_* entities (these are handled separately via VLM) | |
| filtered_entities = [ | |
| entity for entity in entities_to_detect if not entity.startswith("CUSTOM_VLM_") | |
| ] | |
| # No standard entities and no custom instructions | |
| if not filtered_entities and ( | |
| not custom_instructions or not custom_instructions.strip() | |
| ): | |
| # Nothing selected at all → error | |
| if not entities_to_detect: | |
| raise ValueError( | |
| "No standard entities selected and no custom instructions provided. " | |
| "Please select at least one entity type (excluding CUSTOM_VLM_* entities) or provide custom instructions for LLM-based PII detection." | |
| ) | |
| # Only CUSTOM_VLM_* entities selected (handled separately via VLM) → return blank | |
| return [] | |
| # Determine model source from model_choice if using model_name_map | |
| model_source = None | |
| if model_choice and model_name_map and model_choice in model_name_map: | |
| model_source = model_name_map[model_choice].get("source", "AWS") | |
| # Map model source to inference method | |
| if model_source == "Local": | |
| inference_method = "local" | |
| elif model_source == "inference-server": | |
| inference_method = "inference-server" | |
| elif model_source == "Azure/OpenAI": | |
| inference_method = "azure-openai" | |
| elif model_source == "Gemini": | |
| inference_method = "gemini" | |
| elif model_source == "AWS": | |
| inference_method = "aws-bedrock" | |
| system_prompt = create_entity_detection_system_prompt( | |
| filtered_entities, language, custom_instructions | |
| ) | |
| user_prompt = create_entity_detection_prompt( | |
| text, filtered_entities, language, custom_instructions | |
| ) | |
| # Map inference_method to model_source format expected by send_request | |
| model_source_map = { | |
| "aws-bedrock": "AWS", | |
| "local": "Local", | |
| "inference-server": "inference-server", | |
| "azure-openai": "Azure/OpenAI", | |
| "gemini": "Gemini", | |
| } | |
| model_source = model_source_map.get(inference_method, "AWS") | |
| # Prepare client and config for Gemini if needed | |
| if inference_method == "gemini" and (client is None or client_config is None): | |
| from tools.llm_funcs import construct_gemini_generative_model | |
| try: | |
| client, client_config = construct_gemini_generative_model( | |
| in_api_key="", # Will use environment variable | |
| temperature=temperature, | |
| model_choice=model_choice, | |
| system_prompt=system_prompt, | |
| max_tokens=max_tokens, # Use our specific max_tokens for entity detection | |
| ) | |
| except Exception as e: | |
| raise ValueError( | |
| f"Failed to construct Gemini client: {e}. " | |
| f"Ensure GEMINI_API_KEY is set or pass client and client_config." | |
| ) | |
| # Prepare client for Azure/OpenAI if needed | |
| if inference_method == "azure-openai" and client is None: | |
| from tools.llm_funcs import construct_azure_client | |
| try: | |
| client, _ = construct_azure_client( | |
| in_api_key="", # Will use environment variable | |
| endpoint="", # Will use environment variable | |
| ) | |
| except Exception as e: | |
| raise ValueError( | |
| f"Failed to construct Azure/OpenAI client: {e}. " | |
| f"Ensure AZURE_OPENAI_API_KEY is set or pass client." | |
| ) | |
| # Set up API URL for inference-server if needed | |
| if inference_method == "inference-server" and api_url is None: | |
| api_url = INFERENCE_SERVER_API_URL | |
| if not api_url: | |
| raise ValueError( | |
| "api_url is required when using inference-server method. " | |
| "Set INFERENCE_SERVER_API_URL in config or pass api_url parameter." | |
| ) | |
| try: | |
| # Call send_request which handles all routing, retries, and response parsing | |
| # Note: send_request signature shows local_model=list() but it's actually used as a single model object | |
| ( | |
| response, | |
| conversation_history, | |
| response_text, | |
| num_transformer_input_tokens, | |
| num_transformer_generated_tokens, | |
| ) = send_request( | |
| prompt=user_prompt, | |
| conversation_history=[], # Empty for entity detection (no conversation history needed) | |
| client=client, | |
| config=client_config, | |
| model_choice=model_choice, | |
| system_prompt=system_prompt, | |
| temperature=temperature, | |
| bedrock_runtime=bedrock_runtime, | |
| model_source=model_source, | |
| # local_model=( | |
| # local_model if local_model else [] | |
| # ), # Pass model directly (signature shows list but uses as single object) | |
| # tokenizer=tokenizer, | |
| # assistant_model=assistant_model, | |
| progress=Progress( | |
| track_tqdm=False | |
| ), # Disable progress bar for entity detection | |
| api_url=api_url, | |
| ) | |
| except Exception as e: | |
| print(f"LLM entity detection failed: {e}") | |
| raise | |
| # Extract token usage from response (before save so we can write it to the log file) | |
| input_tokens = 0 | |
| output_tokens = 0 | |
| try: | |
| if isinstance(response, dict) and "usage" in response: | |
| # inference-server or llama-cpp format | |
| input_tokens = response["usage"].get("prompt_tokens", 0) | |
| output_tokens = response["usage"].get("completion_tokens", 0) | |
| elif hasattr(response, "usage_metadata"): | |
| # Check if it's AWS Bedrock format | |
| if isinstance(response.usage_metadata, dict): | |
| input_tokens = response.usage_metadata.get("inputTokens", 0) | |
| output_tokens = response.usage_metadata.get("outputTokens", 0) | |
| # Check if it's Gemini format | |
| elif hasattr(response.usage_metadata, "prompt_token_count"): | |
| input_tokens = response.usage_metadata.prompt_token_count | |
| output_tokens = response.usage_metadata.candidates_token_count | |
| except (KeyError, AttributeError) as e: | |
| print(f"Warning: Could not extract token usage from response: {e}") | |
| # Fallback for Local/transformers: response is plain text, so use token counts from send_request | |
| if num_transformer_input_tokens and num_transformer_input_tokens > 0: | |
| input_tokens = num_transformer_input_tokens | |
| if num_transformer_generated_tokens and num_transformer_generated_tokens > 0: | |
| output_tokens = num_transformer_generated_tokens | |
| # Save prompt and response if output_folder is provided. | |
| # Use the same system_prompt and user_prompt that were sent to the model | |
| # (no combined/rendered version) so the log correctly shows system vs user. | |
| if output_folder and response_text: | |
| try: | |
| saved_file = save_llm_prompt_response( | |
| system_prompt=system_prompt, | |
| user_prompt=user_prompt, | |
| response_text=response_text, | |
| output_folder=output_folder, | |
| batch_number=batch_number, | |
| model_choice=model_choice, | |
| entities_to_detect=entities_to_detect, | |
| language=language, | |
| temperature=temperature, | |
| max_tokens=max_tokens, | |
| file_name=file_name, | |
| page_number=page_number, | |
| sheet_name=sheet_name, | |
| column_name=column_name, | |
| row_number=row_number, | |
| input_tokens=input_tokens, | |
| output_tokens=output_tokens, | |
| ) | |
| if 0 == 1: # To avoid lint check issue | |
| print(f"Saved LLM prompt/response to: {saved_file}") | |
| except Exception as e: | |
| print(f"Warning: Could not save LLM prompt/response: {e}") | |
| # Parse the response | |
| entities = parse_llm_entity_response(response_text, text) | |
| return entities, input_tokens, output_tokens | |
def map_back_llm_entity_results(
    entities: List[Dict[str, Any]],
    current_batch_mapping: List[Tuple],
    allow_list: List[str],
    chosen_redact_llm_entities: List[str],
    all_text_line_results: List[Tuple],
) -> List[Tuple]:
    """
    Map LLM-detected entities (batch-level offsets) back to line-level results.

    Each entity's BeginOffset/EndOffset is expressed relative to the batched
    text; this walks current_batch_mapping to find every line segment the
    entity overlaps and appends a recogniser result with line-relative offsets.

    Args:
        entities: Entity dictionaries from the LLM, each with "Type",
            "BeginOffset", "EndOffset" and optionally "Score".
        current_batch_mapping: Tuples of
            (batch_start, line_idx, original_line, chars, line_offset) mapping
            batch positions back to source lines. line_offset, when not None,
            is the start position of the segment within the original line.
        allow_list: Text values to skip - matched case-insensitively so that
            allow_list terms "overrule" LLM PII detection.
        chosen_redact_llm_entities: Entity types originally requested. All
            types returned by the LLM are kept (custom instructions may yield
            extra types), so this argument is currently informational only.
        all_text_line_results: Existing (line_idx, [recogniser_entity, ...])
            entries; new results are appended in place.

    Returns:
        Updated all_text_line_results.
    """
    if not entities:
        return all_text_line_results

    # Normalize the allow list ONCE into a set: O(1) case-insensitive
    # membership tests instead of an O(n) list scan per matched entity.
    allowed_terms = (
        {item.strip().lower() for item in allow_list if item}
        if allow_list
        else set()
    )

    for entity in entities:
        # All entity types returned by the LLM are included, even custom
        # types from custom instructions that were not in the original list.
        entity_type = entity.get("Type")
        entity_start = entity["BeginOffset"]
        entity_end = entity["EndOffset"]

        # Track whether the entity overlapped at least one line segment.
        added_to_line = False

        # An entity may span several line segments, so check every mapping
        # entry rather than stopping at the first overlap.
        for (
            batch_start,
            line_idx,
            original_line,
            chars,
            line_offset,
        ) in current_batch_mapping:
            # Length of this segment as it appears in the batch text.
            if line_offset is not None:
                line_text_length = len(original_line.text[line_offset:])
            else:
                line_text_length = len(original_line.text)
            batch_end = batch_start + line_text_length

            # Does the entity overlap this line segment in batch coordinates?
            if batch_start < entity_end and batch_end > entity_start:
                # Translate batch offsets into offsets within the original
                # line, clamped to the line bounds.
                if line_offset is not None:
                    relative_start = max(0, entity_start - batch_start + line_offset)
                    relative_end = min(
                        entity_end - batch_start + line_offset, len(original_line.text)
                    )
                else:
                    relative_start = max(0, entity_start - batch_start)
                    relative_end = min(
                        entity_end - batch_start, len(original_line.text)
                    )

                result_text = original_line.text[relative_start:relative_end]

                # allow_list terms overrule LLM PII detection (case-insensitive).
                if result_text.strip().lower() not in allowed_terms:
                    # Entity dict in llm-like format with line-relative offsets.
                    adjusted_entity = {
                        "Type": entity_type,
                        "BeginOffset": relative_start,
                        "EndOffset": relative_end,
                        # Default score when the LLM did not supply one.
                        "Score": entity.get("Score", 0.8),
                    }
                    # Imported lazily here (not at module level) to avoid a
                    # circular import with presidio_analyzer_custom.
                    from tools.presidio_analyzer_custom import (
                        recognizer_result_from_dict,
                    )

                    recogniser_entity = recognizer_result_from_dict(adjusted_entity)

                    # Append to an existing entry for this line if one exists.
                    existing_entry = next(
                        (
                            entry
                            for idx, entry in all_text_line_results
                            if idx == line_idx
                        ),
                        None,
                    )
                    if existing_entry is None:
                        all_text_line_results.append((line_idx, [recogniser_entity]))
                    else:
                        existing_entry.append(recogniser_entity)

                added_to_line = True

        # Diagnostic for entities whose offsets fall outside every mapped line.
        if not added_to_line:
            print(
                f"Entity '{entity_type}' at position {entity_start}-{entity_end} does not fit in any line."
            )

    return all_text_line_results
def do_llm_entity_detection_call(
    current_batch: str,
    current_batch_mapping: List[Tuple],
    bedrock_runtime: Optional[boto3.Session.client] = None,
    language: str = "en",
    allow_list: Optional[List[str]] = None,
    chosen_redact_llm_entities: Optional[List[str]] = None,
    all_text_line_results: Optional[List[Tuple]] = None,
    model_choice: str = CLOUD_LLM_PII_MODEL_CHOICE,
    temperature: float = LLM_TEMPERATURE,
    max_tokens: int = LLM_MAX_NEW_TOKENS,
    output_folder: Optional[str] = None,
    batch_number: int = 0,
    custom_instructions: str = "",
    file_name: Optional[str] = None,
    page_number: Optional[int] = None,
    inference_method: Optional[str] = None,
    local_model=None,
    tokenizer=None,
    assistant_model=None,
    client=None,
    client_config=None,
    api_url: Optional[str] = None,
) -> Tuple[List[Tuple], int, int]:
    """
    Call LLM for entity detection on a batch of text, then map the detected
    entities back onto individual lines.
    Similar interface to do_aws_llm_call.
    Args:
        current_batch: Text batch to analyze
        current_batch_mapping: Mapping of batch positions to line indices
        bedrock_runtime: AWS Bedrock runtime client (required for AWS method)
        language: Language code
        allow_list: List of allowed text values
        chosen_redact_llm_entities: List of entity types to detect
        all_text_line_results: Existing line-level results
        model_choice: Model identifier (varies by inference method)
        temperature: Temperature for LLM generation
        max_tokens: Maximum tokens in response
        output_folder: Optional folder to save prompt/response logs
        batch_number: Batch number for logging
        custom_instructions: Optional custom instructions to include in the prompt
        file_name: Optional file name (without extension) for saving logs
        page_number: Optional page number for saving logs
        inference_method: Inference method to use (if None, uses config default)
        local_model: Local model instance (required for "local" method)
        tokenizer: Tokenizer instance (required for "local" method with transformers)
        assistant_model: Assistant model for speculative decoding (optional)
        client: API client (required for "azure-openai" or "gemini" methods)
        client_config: Client config (required for "gemini" method)
        api_url: API URL for inference-server (required for "inference-server" method)
    Returns:
        Tuple of (updated all_text_line_results, input_tokens, output_tokens)
    Raises:
        Exception: re-raises whatever call_llm_for_entity_detection or
            map_back_llm_entity_results raise, after printing a diagnostic.
    """
    # Empty batch: nothing to detect, so return the existing results (or an
    # empty list) with zero token usage rather than making an LLM call.
    if not current_batch:
        return (all_text_line_results or [], 0, 0)
    # Normalize None defaults to fresh lists per call (avoids the shared
    # mutable-default pitfall and lets downstream code iterate safely).
    if allow_list is None:
        allow_list = []
    if chosen_redact_llm_entities is None:
        chosen_redact_llm_entities = []
    if all_text_line_results is None:
        all_text_line_results = []
    try:
        # One LLM call over the whole (stripped) batch; returns batch-level
        # entity offsets plus token usage for accounting.
        entities, input_tokens, output_tokens = call_llm_for_entity_detection(
            text=current_batch.strip(),
            entities_to_detect=chosen_redact_llm_entities,
            language=language,
            bedrock_runtime=bedrock_runtime,
            model_choice=model_choice,
            temperature=temperature,
            max_tokens=max_tokens,
            output_folder=output_folder,
            batch_number=batch_number,
            custom_instructions=custom_instructions,
            file_name=file_name,
            page_number=page_number,
            inference_method=inference_method,
            local_model=local_model,
            tokenizer=tokenizer,
            assistant_model=assistant_model,
            client=client,
            client_config=client_config,
            api_url=api_url,
        )
        # Translate batch-level offsets into per-line recogniser results,
        # applying the allow_list filter.
        all_text_line_results = map_back_llm_entity_results(
            entities,
            current_batch_mapping,
            allow_list,
            chosen_redact_llm_entities,
            all_text_line_results,
        )
        return all_text_line_results, input_tokens, output_tokens
    except Exception as e:
        # Log for visibility, then propagate so callers can decide how to
        # handle the failure.
        print(f"LLM entity detection call failed: {e}")
        raise