"""
Post-processing utilities for risk discovery results.

Includes merging duplicate topics and validating cluster quality.
"""

import re
from collections import Counter, defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

# Generic prefixes that carry no meaning for duplicate detection.
_NAME_PREFIX_RE = re.compile(r'^(?:TOPIC|PATTERN)_', re.IGNORECASE)
# Everything from the first "(", "_" or whitespace onwards is a descriptive suffix.
_NAME_SUFFIX_RE = re.compile(r'[(_\s].*', re.DOTALL)


def _extract_topics(discovered_patterns: Dict) -> Dict:
    """Return the topics dict from either a full result dict or a bare topics dict."""
    if 'discovered_topics' in discovered_patterns:
        return discovered_patterns['discovered_topics']
    return discovered_patterns


def _topic_name(topic: Dict) -> str:
    """Return the topic's display name ('topic_name', else 'pattern_name', else '')."""
    return str(topic.get('topic_name') or topic.get('pattern_name') or '')


def _normalize_id(topic_id: Any) -> Any:
    """Convert decimal-string and NumPy integer IDs to int; leave other keys untouched."""
    if isinstance(topic_id, str) and topic_id.isdecimal():
        return int(topic_id)
    if isinstance(topic_id, np.integer):
        return int(topic_id)
    return topic_id


def _base_name(topic_name: str) -> str:
    """Return the upper-cased base word of a topic name.

    The generic prefix is stripped *before* the suffix is cut off; doing it the
    other way round reduces every "Topic_X" name to "TOPIC".
    """
    without_prefix = _NAME_PREFIX_RE.sub('', topic_name.strip())
    return _NAME_SUFFIX_RE.sub('', without_prefix).upper()


def _resolve_reference(topics: Dict, old_ref: Any) -> List[Tuple[Any, Any]]:
    """Resolve one merge-rule entry to a list of ``(old_id, topics_key)`` pairs.

    Integer references yield exactly one pair; ``topics_key`` is None when no
    topic is stored under that ID (the ID is still remapped in the labels).
    Any other reference is treated as a name and matched by substring in either
    direction against every topic name; it may yield zero or more pairs.
    """
    if isinstance(old_ref, (int, np.integer)) and not isinstance(old_ref, bool):
        old_id = int(old_ref)
        # Topics may be keyed by int or by the string form of the ID.
        for candidate in (old_id, str(old_id)):
            if candidate in topics:
                return [(old_id, candidate)]
        return [(old_id, None)]

    ref_name = str(old_ref)
    if not ref_name:
        return []

    matches = []
    for key, pattern in topics.items():
        name = _topic_name(pattern)
        # An empty name is a substring of everything, so it must never match.
        if name and (ref_name in name or name in ref_name):
            matches.append((_normalize_id(key), key))
    return matches


def merge_duplicate_topics(
    discovered_patterns: Dict,
    cluster_labels: np.ndarray,
    merge_rules: Optional[Dict[str, List[Union[str, int]]]] = None,
) -> Tuple[Dict, np.ndarray]:
    """
    Merge duplicate or highly similar topics in discovered risk patterns.

    This addresses the issue where clustering/topic modeling discovers
    semantically similar categories (e.g., "LIABILITY_Insurance" and
    "LIABILITY_Breach").

    Args:
        discovered_patterns: Either a dict containing a 'discovered_topics'
            key or the topics dict itself (topic ID -> topic dict).
        cluster_labels: Cluster assignment for each document.
        merge_rules: Optional dict mapping new topic name to a list of old
            topic names or numeric IDs. When None, rules are derived with
            detect_duplicate_topics().
            Example: {'LIABILITY': ['Topic_LIABILITY_INSURANCE', 'Topic_LIABILITY_BREACH']}
            Or: {'LIABILITY': [0, 6]} for numeric IDs

    Returns:
        tuple: (merged_patterns, new_cluster_labels). Topics are renumbered
        from 0: merged topics first (in rule order), then the untouched ones
        (in their original order). When there is nothing to merge, the
        original topics dict and labels are returned unchanged.
    """
    topics = _extract_topics(discovered_patterns)

    if merge_rules is None:
        merge_rules = detect_duplicate_topics(discovered_patterns)

    if not merge_rules:
        print("ℹ️ No duplicate topics detected - no merging needed")
        return topics, cluster_labels

    print("🔧 Merging duplicate topics...")

    old_to_new = {}        # old topic ID -> new topic ID
    merged_patterns = {}   # str(new ID) -> topic dict
    merged_keys = set()    # keys of `topics` already absorbed by a rule
    new_id = 0

    for new_name, old_names_or_ids in merge_rules.items():
        print(f" Merging {len(old_names_or_ids)} topics → {new_name}")

        patterns_to_merge = []
        old_ids_to_merge = []

        for old_ref in old_names_or_ids:
            for old_id, key in _resolve_reference(topics, old_ref):
                if key is not None:
                    if key in merged_keys:
                        # Already absorbed (by an earlier rule or an earlier
                        # reference); merging it again would double-count it.
                        continue
                    merged_keys.add(key)
                    patterns_to_merge.append(topics[key])
                old_ids_to_merge.append(old_id)

        if not patterns_to_merge:
            # Nothing resolved to real topic data; leave IDs and labels alone.
            continue

        merged_pattern = merge_topic_data(patterns_to_merge, new_name)
        merged_pattern['topic_id'] = new_id
        merged_patterns[str(new_id)] = merged_pattern
        for old_id in old_ids_to_merge:
            old_to_new[old_id] = new_id
        new_id += 1

    # Carry over every topic that no rule touched, renumbering as we go.
    for key, pattern in topics.items():
        if key in merged_keys:
            continue
        old_to_new[_normalize_id(key)] = new_id
        kept = pattern.copy()  # do not mutate the caller's topic dict
        kept['topic_id'] = new_id
        merged_patterns[str(new_id)] = kept
        new_id += 1

    # Labels without a mapping (e.g. IDs that have no topic entry) pass through.
    new_labels = np.array([old_to_new.get(label, label) for label in cluster_labels])

    print(f"✅ Merging complete: {len(topics)} → {len(merged_patterns)} topics")

    return merged_patterns, new_labels


def detect_duplicate_topics(discovered_patterns: Dict) -> Dict[str, List]:
    """
    Automatically detect duplicate topics based on name similarity.

    Topics are grouped by base name: the name with a leading "Topic_" /
    "Pattern_" prefix removed, cut at the first "(", "_" or whitespace and
    upper-cased (e.g. "Topic_LIABILITY_BREACH" -> "LIABILITY"). Only the name
    is considered; keywords are not compared.

    Args:
        discovered_patterns: Either a dict containing a 'discovered_topics'
            key or the topics dict itself.

    Returns:
        Merge rules dict mapping each base name shared by more than one topic
        to the list of those topics' IDs (decimal-string IDs become ints).
    """
    topics = _extract_topics(discovered_patterns)

    base_name_groups = defaultdict(list)
    for topic_id, topic in topics.items():
        base_name = _base_name(_topic_name(topic))
        if base_name:
            base_name_groups[base_name].append(_normalize_id(topic_id))

    merge_rules = {}
    for base_name, topic_ids in base_name_groups.items():
        if len(topic_ids) > 1:
            merge_rules[base_name] = topic_ids
            print(f" 🔍 Detected duplicate: {len(topic_ids)} topics with base name '{base_name}'")

    return merge_rules


def merge_topic_data(patterns: List[Dict], new_name: str) -> Dict:
    """
    Merge multiple topic patterns into a single consolidated pattern.

    Args:
        patterns: List of topic pattern dictionaries to merge (may be empty).
        new_name: Name for the merged topic (stored as "Topic_<new_name>").

    Returns:
        Merged topic dictionary with summed 'clause_count', combined
        'top_words'/'keywords', combined 'word_weights' (if any input has
        them), averaged numeric features and up to 5 'sample_clauses'.
    """
    merged = {
        'topic_name': f"Topic_{new_name}",
        'clause_count': sum(p.get('clause_count', 0) for p in patterns),
    }

    # Keywords: top 10 from each topic, ranked by how many topics share them.
    # Counter.most_common keeps first-seen order among equal counts.
    keyword_counts = Counter()
    for pattern in patterns:
        keywords = pattern.get('keywords') or pattern.get('top_words') or []
        keyword_counts.update(keywords[:10])
    merged['top_words'] = [word for word, _ in keyword_counts.most_common(15)]
    merged['keywords'] = merged['top_words']  # For compatibility

    # Word weights: pooled and sorted descending, so they are no longer aligned
    # position-by-position with 'top_words'.
    # NOTE(review): presumably a list of numeric weights per topic - confirm.
    if any('word_weights' in p for p in patterns):
        all_weights = []
        for pattern in patterns:
            all_weights.extend((pattern.get('word_weights') or [])[:10])
        merged['word_weights'] = sorted(all_weights, reverse=True)[:15]

    # Numeric features: unweighted mean over the topics that define them.
    # NOTE(review): 'proportion' is averaged while 'clause_count' is summed;
    # confirm whether a merged proportion should be a sum instead.
    numeric_fields = ['avg_risk_intensity', 'avg_legal_complexity',
                      'avg_obligation_strength', 'proportion']
    for field in numeric_fields:
        values = [p[field] for p in patterns if field in p]
        if values:
            merged[field] = float(np.mean(values))

    # Sample clauses: first 2 from each topic, capped at 5 overall.
    all_samples = []
    for pattern in patterns:
        all_samples.extend(pattern.get('sample_clauses', [])[:2])
    merged['sample_clauses'] = all_samples[:5]

    return merged


def validate_cluster_quality(
    discovered_patterns: Dict,
    min_cluster_size: int = 150,
    max_imbalance_ratio: float = 3.0,
) -> Dict:
    """
    Validate cluster quality and flag issues.

    Checks for:
    - Clusters that are too small (< min_cluster_size samples) -> issue
    - Clusters with duplicate names -> issue
    - Imbalanced cluster sizes (largest > max_imbalance_ratio x smallest)
      -> warning only

    Args:
        discovered_patterns: Either a dict containing a 'discovered_topics'
            key or the topics dict itself.
        min_cluster_size: Minimum acceptable cluster size.
        max_imbalance_ratio: Largest tolerated largest/smallest size ratio.

    Returns:
        Validation report dictionary with keys 'is_valid', 'issues',
        'warnings' and 'cluster_sizes' (name -> clause count). Any issue sets
        'is_valid' to False; warnings do not.
    """
    report = {
        'is_valid': True,
        'issues': [],
        'warnings': [],
        'cluster_sizes': {}
    }

    if 'discovered_topics' in discovered_patterns:
        # Full result dictionary
        topics = discovered_patterns['discovered_topics']
    elif any(
        isinstance(v, dict) and ('topic_name' in v or 'pattern_name' in v or 'key_terms' in v)
        for v in discovered_patterns.values()
    ):
        # Already the topics dictionary
        topics = discovered_patterns
    else:
        report['is_valid'] = False
        report['issues'].append("Invalid format: expected 'discovered_topics' key or topics dictionary")
        return report

    sizes = []
    names = []
    for topic_id, topic in topics.items():
        count = topic.get('clause_count', 0)
        name = topic.get('topic_name', topic.get('pattern_name', f"Topic_{topic_id}"))
        sizes.append(count)
        names.append(name)
        report['cluster_sizes'][name] = count

        if count < min_cluster_size:
            report['is_valid'] = False
            report['issues'].append(f"Cluster '{name}' too small: {count} < {min_cluster_size}")

    for name, occurrences in Counter(names).items():
        if occurrences > 1:
            report['is_valid'] = False
            report['issues'].append(f"Duplicate cluster name: '{name}' appears {occurrences} times")

    if sizes:
        max_size = max(sizes)
        min_size = min(sizes)
        # An empty smallest cluster makes the ratio unbounded.
        ratio = max_size / min_size if min_size > 0 else float('inf')
        if ratio > max_imbalance_ratio:
            report['warnings'].append(
                f"Imbalanced clusters: largest ({max_size}) is {ratio:.1f}x bigger than smallest ({min_size})"
            )

    return report


# Example usage
if __name__ == "__main__":
    print("🔧 Risk Discovery Post-Processing Utilities\n")

    # Simulate discovered patterns with duplicates
    test_patterns = {
        '0': {'topic_name': 'Topic_LIABILITY', 'clause_count': 400, 'top_words': ['insurance', 'coverage']},
        '1': {'topic_name': 'Topic_COMPLIANCE', 'clause_count': 300, 'top_words': ['laws', 'governed']},
        '2': {'topic_name': 'Topic_TERMINATION', 'clause_count': 350, 'top_words': ['term', 'notice']},
        '6': {'topic_name': 'Topic_LIABILITY', 'clause_count': 250, 'top_words': ['damages', 'breach']},
    }
    test_labels = np.array([0, 1, 2, 0, 1, 6, 2, 0, 6])

    # Detect duplicates
    print("1. Detecting duplicate topics:")
    merge_rules = detect_duplicate_topics(test_patterns)
    print()

    # Merge duplicates
    print("2. Merging duplicates:")
    merged_patterns, new_labels = merge_duplicate_topics(test_patterns, test_labels, merge_rules)
    print()

    # Validate quality
    print("3. Validating cluster quality:")
    report = validate_cluster_quality(merged_patterns, min_cluster_size=200)
    print(f" Valid: {report['is_valid']}")
    print(f" Issues: {report['issues']}")
    print(f" Warnings: {report['warnings']}")