Spaces:
Running
Running
| """ | |
| Topic Validation Module | |
| Validates cluster coherence, keyword overlap, and label consistency | |
| """ | |
import logging
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
# Module-level logger, named after this module's import path (`__name__`).
logger: logging.Logger = logging.getLogger(name=__name__)
class TopicValidator:
    """Validates topic quality and consistency.

    Every check is a ``staticmethod`` because the validator holds no state.
    The checks can be called on the class (``TopicValidator.check(...)``) or
    on an instance (``TopicValidator().check(...)``).
    """

    @staticmethod
    def keyword_overlap_score(
        keywords_per_topic: Dict[int, List[str]],
        threshold: float = 0.3,
    ) -> Dict[str, Any]:
        """
        Calculate pairwise keyword overlap between topics.

        High overlap may indicate mergeable/redundant topics. For each
        unordered pair of topics, the overlap is
        ``len(A & B) / max(len(A), len(B))`` over the de-duplicated keyword
        sets. It is 0.0 when either set is empty.

        Args:
            keywords_per_topic: Dict mapping topic_id to keyword list.
            threshold: Pairs whose overlap is strictly greater than this
                produce a warning.

        Returns:
            Dict with ``mean_overlap`` (mean over all pairs; 0.0 when there
            are fewer than two topics), ``high_overlap_pairs`` (number of
            warnings) and ``warnings`` (one message per flagged pair).
        """
        if len(keywords_per_topic) < 2:
            # Same keys as the normal path, so callers can index safely.
            return {"mean_overlap": 0.0, "high_overlap_pairs": 0, "warnings": []}

        # Build each keyword set once, not once per pair.
        topic_sets = [(tid, set(kws)) for tid, kws in keywords_per_topic.items()]
        overlaps: List[float] = []
        warnings: List[str] = []
        for i, (tid1, set1) in enumerate(topic_sets):
            for tid2, set2 in topic_sets[i + 1:]:
                if set1 and set2:
                    overlap = len(set1 & set2) / max(len(set1), len(set2))
                else:
                    overlap = 0.0
                overlaps.append(overlap)
                if overlap > threshold:
                    warnings.append(
                        f"Topics {tid1} and {tid2} share {overlap:.1%} keywords (potential merge)"
                    )

        return {
            # At least one pair exists here, so the mean is well defined.
            "mean_overlap": float(np.mean(overlaps)),
            "high_overlap_pairs": len(warnings),
            "warnings": warnings,
        }

    @staticmethod
    def cluster_coherence(
        embeddings: np.ndarray,
        labels: np.ndarray,
    ) -> Dict[str, float]:
        """
        Calculate within-cluster coherence as mean pairwise cosine similarity.

        For each distinct value in ``labels``, the score is the mean cosine
        similarity over all distinct pairs of that cluster's documents.
        Clusters with a single document score 0.0. No label value is
        excluded. Higher means more internally similar. Cosine similarity
        lies in [-1, 1], so scores can be negative.

        Args:
            embeddings: Document embeddings (n_docs, n_dims); any array-like.
            labels: Cluster labels (n_docs,); any array-like.

        Returns:
            Dict with the mean/min/max/std of the per-cluster scores. All
            four are 0.0 when fewer than two documents are given.
        """
        # Accept lists as well as arrays. Boolean masking below needs ndarrays.
        embeddings = np.asarray(embeddings)
        labels = np.asarray(labels)
        if len(embeddings) < 2:
            return {
                "mean_coherence": 0.0,
                "min_coherence": 0.0,
                "max_coherence": 0.0,
                "std_coherence": 0.0,
            }

        coherence_scores: List[float] = []
        for label in np.unique(labels):
            cluster_embeddings = embeddings[labels == label]
            n = len(cluster_embeddings)
            if n < 2:
                coherence_scores.append(0.0)
                continue
            similarities = cosine_similarity(cluster_embeddings)
            # Upper triangle without the diagonal visits each distinct pair
            # once and skips self-similarity. Non-empty because n >= 2.
            upper_triangle = similarities[np.triu_indices(n, k=1)]
            coherence_scores.append(float(np.mean(upper_triangle)))

        return {
            "mean_coherence": float(np.mean(coherence_scores)),
            "min_coherence": float(np.min(coherence_scores)),
            "max_coherence": float(np.max(coherence_scores)),
            "std_coherence": float(np.std(coherence_scores)),
        }

    @staticmethod
    def label_consistency(
        topic_labels: Dict[int, str],
        topic_keywords: Dict[int, List[str]],
    ) -> Dict[str, Any]:
        """
        Check topic labels for basic quality problems.

        A label is flagged when, after ``str()`` and whitespace stripping:
        - it is shorter than 3 characters (including empty); or
        - it has fewer than 2 or more than 10 words; or
        - it contains both "topic" and "unknown" (case-insensitive), which
          looks like a generic fallback label.

        Args:
            topic_labels: Dict mapping topic_id to label.
            topic_keywords: Dict mapping topic_id to keyword list. No check
                uses it yet; it is accepted for interface compatibility.

        Returns:
            Dict with:
            - ``consistent``: True when there are no issues.
            - ``consistency_rate``: fraction of topics with no issue, in
              [0, 1].
            - ``issue_count``: total number of issues.
            - ``issues``: the first 10 issue messages.
        """
        if not topic_labels:
            return {"consistent": True, "consistency_rate": 1.0, "issue_count": 0, "issues": []}

        issues: List[str] = []
        flagged_topics = 0
        for tid, label in topic_labels.items():
            label_text = str(label).strip()
            topic_issues: List[str] = []
            if len(label_text) < 3:
                # Empty or tiny label. The word-count check would only
                # restate this problem, so it is skipped.
                topic_issues.append(f"Topic {tid}: Label too short or empty ('{label_text}')")
            else:
                # Labels are expected to be a few words long.
                word_count = len(label_text.split())
                if word_count < 2:
                    topic_issues.append(f"Topic {tid}: Label too short ({word_count} word)")
                elif word_count > 10:
                    topic_issues.append(f"Topic {tid}: Label too long ({word_count} words)")
                lowered = label_text.lower()
                if "topic" in lowered and "unknown" in lowered:
                    topic_issues.append(f"Topic {tid}: Generic/fallback label '{label}'")
            if topic_issues:
                flagged_topics += 1
                issues.extend(topic_issues)

        # Rate over topics, not issues. One topic can have several issues,
        # and dividing the issue count would let the rate drop below zero.
        consistency_rate = 1.0 - flagged_topics / len(topic_labels)
        return {
            "consistent": not issues,
            "consistency_rate": float(consistency_rate),
            "issue_count": len(issues),
            "issues": issues[:10],  # first 10 issues, in topic order
        }

    @staticmethod
    def cluster_count_validation(
        label_count: int,
        min_clusters: int = 15,
        max_clusters: int = 30,
    ) -> Dict[str, Any]:
        """
        Validate that the cluster count lies within [min_clusters, max_clusters].

        Args:
            label_count: Number of clusters.
            min_clusters: Minimum acceptable clusters (inclusive).
            max_clusters: Maximum acceptable clusters (inclusive).

        Returns:
            Dict with the boolean verdict, a status string, the inputs and a
            one-line human-readable message.
        """
        valid = min_clusters <= label_count <= max_clusters
        status = "✅ VALID" if valid else "❌ INVALID"
        return {
            "valid": valid,
            "status": status,
            "cluster_count": label_count,
            "min_required": min_clusters,
            "max_allowed": max_clusters,
            "message": (
                f"{status}: {label_count} clusters "
                f"(target: {min_clusters}-{max_clusters})"
            ),
        }
def full_validation(
    embeddings: np.ndarray,
    labels: np.ndarray,
    topic_labels: Dict[int, str],
    topic_keywords: Dict[int, List[str]],
    min_clusters: int = 15,
    max_clusters: int = 30,
) -> Dict[str, Any]:
    """
    Run the full TopicValidator suite and collect each check's result.

    The cluster count that is validated is ``len(topic_labels)``, i.e. the
    number of labelled topics. It is not the number of distinct values in
    ``labels``.

    Args:
        embeddings: Document embeddings.
        labels: Cluster labels.
        topic_labels: Topic ID to label mapping.
        topic_keywords: Topic ID to keywords mapping.
        min_clusters: Minimum clusters (inclusive).
        max_clusters: Maximum clusters (inclusive).

    Returns:
        Report dict with the keys "coherence", "keyword_overlap",
        "label_consistency" and "cluster_count".
    """
    # Call through the class rather than an instance. The checks take no
    # `self`, so an instance call would pass the instance as the first
    # argument unless the checks are declared as staticmethods.
    return {
        "coherence": TopicValidator.cluster_coherence(embeddings, labels),
        "keyword_overlap": TopicValidator.keyword_overlap_score(topic_keywords),
        "label_consistency": TopicValidator.label_consistency(topic_labels, topic_keywords),
        "cluster_count": TopicValidator.cluster_count_validation(
            len(topic_labels), min_clusters, max_clusters
        ),
    }
def print_validation_report(validation_result: Dict[str, Any]) -> None:
    """
    Pretty-print a validation report to stdout.

    Missing sections and keys are tolerated: numeric fields default to 0 and
    the cluster-count message defaults to "N/A". At most three overlap
    warnings and three label issues are listed.

    Args:
        validation_result: Dict with optional "coherence", "keyword_overlap",
            "label_consistency" and "cluster_count" sections.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("TOPIC VALIDATION REPORT")
    print(rule)

    # Coherence: mean pairwise cosine similarity, which can be negative.
    print("\n📊 CLUSTER COHERENCE")
    coh = validation_result.get("coherence", {})
    print(f" Mean Coherence: {coh.get('mean_coherence', 0):.3f} (range -1.0 to 1.0)")
    print(f" Range: {coh.get('min_coherence', 0):.3f} - {coh.get('max_coherence', 0):.3f}")

    # Keyword overlap
    print("\n🔑 KEYWORD OVERLAP")
    overlap = validation_result.get("keyword_overlap", {})
    print(f" Mean Overlap: {overlap.get('mean_overlap', 0):.1%}")
    high_pairs = overlap.get("high_overlap_pairs", 0)
    if high_pairs > 0:
        print(f" ⚠️ {high_pairs} high-overlap pairs (potential merges)")
        for warning in overlap.get("warnings", [])[:3]:
            print(f" → {warning}")

    # Label consistency
    print("\n✅ LABEL CONSISTENCY")
    consistency = validation_result.get("label_consistency", {})
    rate = consistency.get("consistency_rate", 0)
    print(f" Consistency Rate: {rate:.1%}")
    issues = consistency.get("issues", [])
    if issues:
        # The "issues" list may be truncated (label_consistency keeps only
        # the first 10); "issue_count" holds the full total when present.
        print(f" Issues Found: {consistency.get('issue_count', len(issues))}")
        for issue in issues[:3]:
            print(f" → {issue}")

    # Cluster count
    print("\n📈 CLUSTER COUNT VALIDATION")
    cc = validation_result.get("cluster_count", {})
    print(f" {cc.get('message', 'N/A')}")
    print("\n" + rule + "\n")