"""
Post-processing utilities for risk discovery results
Includes merging duplicate topics and validating cluster quality
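Typical pipeline: detect_duplicate_topics() -> merge_duplicate_topics() -> validate_cluster_quality()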
"""
import numpy as np
from typing import Any, Dict, List, Optional, Tuple
from collections import defaultdict, Counter
import re
def merge_duplicate_topics(discovered_patterns: Dict, cluster_labels: np.ndarray,
                           merge_rules: Optional[Dict[str, List[Any]]] = None) -> Tuple[Dict, np.ndarray]:
"""
Merge duplicate or highly similar topics in discovered risk patterns.
This addresses the issue where clustering/topic modeling discovers semantically
similar categories (e.g., "LIABILITY_Insurance" and "LIABILITY_Breach").
Args:
discovered_patterns: Dictionary from discover_risk_patterns() or just the topics dict
cluster_labels: Array of cluster assignments for each document
merge_rules: Optional dict mapping new topic name to list of old topic names/IDs
Example: {'LIABILITY': ['Topic_LIABILITY_INSURANCE', 'Topic_LIABILITY_BREACH']}
Or: {'LIABILITY': [0, 6]} for numeric IDs
Returns:
tuple: (merged_patterns, new_cluster_labels)
"""
# PHASE 2 FIX: Handle both formats
if 'discovered_topics' in discovered_patterns:
topics = discovered_patterns['discovered_topics']
else:
topics = discovered_patterns
    if merge_rules is None:
        # Default: auto-detect duplicates by shared base name (see detect_duplicate_topics)
        merge_rules = detect_duplicate_topics(discovered_patterns)
if not merge_rules:
print("โ„น๏ธ No duplicate topics detected - no merging needed")
return topics, cluster_labels
print(f"๐Ÿ”ง Merging duplicate topics...")
# Create mapping from old to new IDs
old_to_new = {}
new_id = 0
merged_patterns = {}
# Track which old IDs have been merged
merged_old_ids = set()
for new_name, old_names_or_ids in merge_rules.items():
print(f" Merging {len(old_names_or_ids)} topics โ†’ {new_name}")
# Collect all patterns to merge
patterns_to_merge = []
old_ids_to_merge = []
        for old_ref in old_names_or_ids:
            if isinstance(old_ref, int):
                # Numeric ID reference
                resolved_ids = [old_ref]
            else:
                # Name reference - collect every matching pattern ID
                resolved_ids = []
                for pattern_id, pattern in topics.items():
                    pattern_name = pattern.get('topic_name') or pattern.get('pattern_name', '')
                    if old_ref in pattern_name or pattern_name in old_ref:
                        resolved_ids.append(
                            int(pattern_id) if isinstance(pattern_id, str) and pattern_id.isdigit() else pattern_id
                        )
            # Get pattern data for every resolved ID (a name reference may match several topics)
            for old_id in resolved_ids:
                old_ids_to_merge.append(old_id)
                pattern_key = str(old_id) if isinstance(old_id, int) else old_id
                if pattern_key in topics:
                    patterns_to_merge.append(topics[pattern_key])
                    merged_old_ids.add(pattern_key)
        if patterns_to_merge:
            # Merge patterns into a single consolidated entry
            merged_pattern = merge_topic_data(patterns_to_merge, new_name)
            merged_pattern['topic_id'] = new_id
            merged_patterns[str(new_id)] = merged_pattern
            # Map old IDs to new ID
            for old_id in old_ids_to_merge:
                old_to_new[old_id] = new_id
            new_id += 1
# Add non-merged patterns
for pattern_id, pattern in topics.items():
if pattern_id not in merged_old_ids:
old_id = int(pattern_id) if isinstance(pattern_id, str) and pattern_id.isdigit() else pattern_id
old_to_new[old_id] = new_id
merged_patterns[str(new_id)] = pattern.copy()
merged_patterns[str(new_id)]['topic_id'] = new_id
new_id += 1
# Remap cluster labels
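    # Unmapped labels pass through unchanged (so, for example, a -1 noise label
    # from HDBSCAN-style clustering would be preserved)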
new_labels = np.array([old_to_new.get(label, label) for label in cluster_labels])
print(f"โœ… Merging complete: {len(discovered_patterns)} โ†’ {len(merged_patterns)} topics")
return merged_patterns, new_labels
def detect_duplicate_topics(discovered_patterns: Dict) -> Dict[str, List]:
"""
Automatically detect duplicate topics based on name similarity.
    Groups topics that share the same base word (e.g., "LIABILITY" appearing in
    multiple topic names). Keyword-overlap similarity is not checked here.
Args:
discovered_patterns: Dictionary from discover_risk_patterns() or just the topics dict
Returns:
Merge rules dict mapping new name to list of old topic IDs
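        Example: {'LIABILITY': [0, 6]} when two topics share the base name 'LIABILITY'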
"""
merge_rules = {}
# PHASE 2 FIX: Handle both formats
if 'discovered_topics' in discovered_patterns:
topics = discovered_patterns['discovered_topics']
else:
topics = discovered_patterns
# Group topics by base name
base_name_groups = defaultdict(list)
for topic_id, topic in topics.items():
topic_name = topic.get('topic_name') or topic.get('pattern_name', '')
        # Strip common prefixes first so "Topic_LIABILITY" yields "LIABILITY", not "TOPIC"
        base_name = re.sub(r'^(topic_|pattern_)', '', topic_name, flags=re.IGNORECASE)
        # Extract base name (text before parentheses, underscore, or descriptive suffix)
        base_name = re.sub(r'[(_\s].+', '', base_name).upper()
if base_name:
topic_id_int = int(topic_id) if isinstance(topic_id, str) and topic_id.isdigit() else topic_id
base_name_groups[base_name].append(topic_id_int)
# Identify groups with duplicates
for base_name, topic_ids in base_name_groups.items():
if len(topic_ids) > 1:
merge_rules[base_name] = topic_ids
print(f" ๐Ÿ” Detected duplicate: {len(topic_ids)} topics with base name '{base_name}'")
return merge_rules
def merge_topic_data(patterns: List[Dict], new_name: str) -> Dict:
"""
Merge multiple topic patterns into a single consolidated pattern.
Args:
patterns: List of topic pattern dictionaries to merge
new_name: Name for the merged topic
Returns:
Merged topic dictionary
"""
merged = {
'topic_name': f"Topic_{new_name}",
'clause_count': sum(p.get('clause_count', 0) for p in patterns),
}
# Merge keywords/top_words (take union and sort by frequency)
all_keywords = []
for pattern in patterns:
keywords = pattern.get('keywords', pattern.get('top_words', []))
all_keywords.extend(keywords[:10]) # Top 10 from each
    # Count and sort (words shared by several source topics rank first)
    keyword_counts = Counter(all_keywords)
merged['top_words'] = [word for word, _ in keyword_counts.most_common(15)]
merged['keywords'] = merged['top_words'] # For compatibility
# Merge word weights if available
    if any('word_weights' in p for p in patterns):
all_weights = []
for pattern in patterns:
weights = pattern.get('word_weights', [])
all_weights.extend(weights[:10])
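        # Note: sorting detaches these weights from 'top_words'; they are kept
        # only as a rough magnitude summary, not a word-to-weight mapping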
merged['word_weights'] = sorted(all_weights, reverse=True)[:15]
# Average numeric features
numeric_fields = ['avg_risk_intensity', 'avg_legal_complexity', 'avg_obligation_strength', 'proportion']
for field in numeric_fields:
values = [p.get(field, 0) for p in patterns if field in p]
if values:
            merged[field] = float(np.mean(values))  # Plain float keeps the dict JSON-serializable
# Combine sample clauses
all_samples = []
for pattern in patterns:
samples = pattern.get('sample_clauses', [])
all_samples.extend(samples[:2]) # Top 2 from each
merged['sample_clauses'] = all_samples[:5] # Keep top 5 overall
return merged
def validate_cluster_quality(discovered_patterns: Dict, min_cluster_size: int = 150) -> Dict:
"""
Validate cluster quality and flag issues.
Checks for:
- Clusters that are too small (< min_cluster_size samples)
- Clusters with duplicate names
- Imbalanced cluster sizes (largest > 3x smallest)
Args:
discovered_patterns: Dictionary from discover_risk_patterns() or just the topics dict
min_cluster_size: Minimum acceptable cluster size
Returns:
Validation report dictionary
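        Keys: 'is_valid' (bool), 'issues' (list), 'warnings' (list), 'cluster_sizes' (dict)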
"""
report = {
'is_valid': True,
'issues': [],
'warnings': [],
'cluster_sizes': {}
}
# PHASE 2 FIX: Handle both formats - full result dict or just topics dict
if 'discovered_topics' in discovered_patterns:
# Full result dictionary from discover_risk_patterns()
topics = discovered_patterns['discovered_topics']
elif any(isinstance(v, dict) and ('topic_name' in v or 'pattern_name' in v or 'key_terms' in v)
for v in discovered_patterns.values()):
# Already the topics dictionary
topics = discovered_patterns
else:
# Unknown format
report['is_valid'] = False
report['issues'].append("Invalid format: expected 'discovered_topics' key or topics dictionary")
return report
sizes = []
names = []
for topic_id, topic in topics.items():
count = topic.get('clause_count', 0)
name = topic.get('topic_name', topic.get('pattern_name', f"Topic_{topic_id}"))
sizes.append(count)
names.append(name)
report['cluster_sizes'][name] = count
# Check cluster size
if count < min_cluster_size:
report['is_valid'] = False
report['issues'].append(f"Cluster '{name}' too small: {count} < {min_cluster_size}")
    # Check for duplicate names
    name_counts = Counter(names)
for name, count in name_counts.items():
if count > 1:
report['is_valid'] = False
report['issues'].append(f"Duplicate cluster name: '{name}' appears {count} times")
# Check balance
if sizes:
max_size = max(sizes)
min_size = min(sizes)
ratio = max_size / min_size if min_size > 0 else float('inf')
if ratio > 3.0:
report['warnings'].append(
f"Imbalanced clusters: largest ({max_size}) is {ratio:.1f}x bigger than smallest ({min_size})"
)
return report
# Example usage
if __name__ == "__main__":
print("๐Ÿ”ง Risk Discovery Post-Processing Utilities\n")
# Simulate discovered patterns with duplicates
test_patterns = {
'0': {'topic_name': 'Topic_LIABILITY', 'clause_count': 400, 'top_words': ['insurance', 'coverage']},
'1': {'topic_name': 'Topic_COMPLIANCE', 'clause_count': 300, 'top_words': ['laws', 'governed']},
'2': {'topic_name': 'Topic_TERMINATION', 'clause_count': 350, 'top_words': ['term', 'notice']},
'6': {'topic_name': 'Topic_LIABILITY', 'clause_count': 250, 'top_words': ['damages', 'breach']},
}
test_labels = np.array([0, 1, 2, 0, 1, 6, 2, 0, 6])
# Detect duplicates
print("1. Detecting duplicate topics:")
merge_rules = detect_duplicate_topics(test_patterns)
print()
# Merge duplicates
print("2. Merging duplicates:")
merged_patterns, new_labels = merge_duplicate_topics(test_patterns, test_labels, merge_rules)
print()
# Validate quality
print("3. Validating cluster quality:")
report = validate_cluster_quality(merged_patterns, min_cluster_size=200)
print(f" Valid: {report['is_valid']}")
print(f" Issues: {report['issues']}")
print(f" Warnings: {report['warnings']}")