Spaces:
Sleeping
Sleeping
| from typing import List, Dict, Any | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| import logging | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import re | |
| logger = logging.getLogger(__name__) | |
class RiskFeatureExtractor:
    """Extracts retrieval-risk features for a question and its retrieved passages.

    Feature groups:
      * retrieval score statistics (mean/std/min/max/variance of passage scores),
      * lexical coverage of the question by the passages (token & entity overlap),
      * semantic consistency between passages (embedding cosine similarity),
      * topical diversity between passages (TF-IDF cosine dissimilarity).

    The returned feature dict always uses the same key set (see
    ``_get_empty_features``) so downstream consumers get a stable schema.
    """

    def __init__(self, embedding_model: str = "BAAI/bge-large-en-v1.5"):
        """Load the sentence-embedding model and prepare the TF-IDF vectorizer.

        Args:
            embedding_model: Name or path of the SentenceTransformer model.
        """
        self.embedding_model = SentenceTransformer(embedding_model)
        self.tfidf_vectorizer = TfidfVectorizer(max_features=1000, stop_words='english')

    def extract_features(self, question: str, retrieved_passages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Extract risk assessment features for one question/passage set.

        Args:
            question: The user question text.
            retrieved_passages: Passage dicts; each may carry 'text' and 'score'.

        Returns:
            A flat feature dict. With no passages, an all-zero dict with the
            full key set is returned so the feature schema stays stable.
        """
        if not retrieved_passages:
            return self._get_empty_features()
        features: Dict[str, Any] = {}
        features.update(self._extract_retrieval_stats(retrieved_passages))
        features.update(self._extract_coverage_features(question, retrieved_passages))
        features.update(self._extract_consistency_features(question, retrieved_passages))
        features.update(self._extract_diversity_features(retrieved_passages))
        return features

    def _extract_retrieval_stats(self, passages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Summary statistics over the retrieval scores of the passages.

        Missing 'score' entries default to 0.0. Values are plain floats so the
        feature dict is JSON-serializable.
        """
        if not passages:
            return {}
        scores = [p.get('score', 0.0) for p in passages]
        return {
            'num_passages': len(passages),
            'avg_similarity': float(np.mean(scores)),
            'std_similarity': float(np.std(scores)),
            'max_similarity': float(np.max(scores)),
            'min_similarity': float(np.min(scores)),
            'score_variance': float(np.var(scores)),
        }

    def _extract_coverage_features(self, question: str, passages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Lexical coverage of the question by each passage.

        Computes per-passage whitespace-token overlap and (crude) entity
        overlap, each normalized by the question-side set size; a 0.0 ratio is
        used when the question has no tokens/entities to avoid division by zero.
        """
        if not passages:
            return {}
        question_tokens = set(question.lower().split())
        question_entities = self._extract_entities(question)
        token_overlaps = []
        entity_overlaps = []
        # Single pass over the passages computes both overlap families.
        for passage_text in (p.get('text', '') for p in passages):
            passage_tokens = set(passage_text.lower().split())
            token_overlaps.append(
                len(question_tokens & passage_tokens) / len(question_tokens)
                if question_tokens else 0.0
            )
            passage_entities = self._extract_entities(passage_text)
            entity_overlaps.append(
                len(question_entities & passage_entities) / len(question_entities)
                if question_entities else 0.0
            )
        return {
            'avg_token_overlap': float(np.mean(token_overlaps)),
            'max_token_overlap': float(np.max(token_overlaps)),
            'avg_entity_overlap': float(np.mean(entity_overlaps)),
            'max_entity_overlap': float(np.max(entity_overlaps)),
        }

    def _extract_consistency_features(self, question: str, passages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Pairwise semantic agreement between the retrieved passages.

        Note: ``question`` is accepted for interface symmetry with the other
        extractors but is not used in the computation.
        """
        if len(passages) < 2:
            # A single passage cannot disagree with itself. Return the full
            # key set so the feature schema matches the multi-passage case.
            return {
                'passage_consistency': 1.0,
                'passage_consistency_std': 0.0,
                'min_passage_similarity': 1.0,
            }
        passage_texts = [p.get('text', '') for p in passages]
        embeddings = self.embedding_model.encode(passage_texts)
        similarities = cosine_similarity(embeddings)
        # Upper triangle excluding the diagonal = each unordered passage pair once.
        upper_triangle = similarities[np.triu_indices_from(similarities, k=1)]
        return {
            'passage_consistency': float(np.mean(upper_triangle)),
            'passage_consistency_std': float(np.std(upper_triangle)),
            'min_passage_similarity': float(np.min(upper_triangle)),
        }

    def _extract_diversity_features(self, passages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Topical diversity of the passages via TF-IDF cosine dissimilarity."""
        if len(passages) < 2:
            # Single passage: maximally "diverse" by convention; emit the full
            # key set so the feature schema stays stable.
            return {'diversity': 1.0, 'topic_variance': 0.0}
        passage_texts = [p.get('text', '') for p in passages]
        try:
            tfidf_matrix = self.tfidf_vectorizer.fit_transform(passage_texts)
            similarities = cosine_similarity(tfidf_matrix)
            # Diversity is 1 minus the mean pairwise similarity.
            upper_triangle = similarities[np.triu_indices_from(similarities, k=1)]
            return {
                'diversity': float(1.0 - np.mean(upper_triangle)),
                'topic_variance': float(np.var(upper_triangle)),
            }
        except ValueError:
            # TfidfVectorizer raises ValueError when it ends up with an empty
            # vocabulary (e.g. all passages are stop words or empty strings);
            # fall back to a neutral diversity value rather than failing.
            return {'diversity': 0.5, 'topic_variance': 0.0}

    def _extract_entities(self, text: str) -> set:
        """Crude entity extraction: capitalized words plus standalone numbers.

        Placeholder for a real NER model; returns a set of surface strings.
        """
        entities = set(re.findall(r'\b[A-Z][a-z]+\b', text))
        entities.update(re.findall(r'\b\d+\b', text))
        return entities

    def _get_empty_features(self) -> Dict[str, Any]:
        """Zero-valued features with the complete key set (no-passages case)."""
        return {
            'num_passages': 0,
            'avg_similarity': 0.0,
            'std_similarity': 0.0,
            'max_similarity': 0.0,
            'min_similarity': 0.0,
            'score_variance': 0.0,
            'avg_token_overlap': 0.0,
            'max_token_overlap': 0.0,
            'avg_entity_overlap': 0.0,
            'max_entity_overlap': 0.0,
            'passage_consistency': 0.0,
            'passage_consistency_std': 0.0,
            'min_passage_similarity': 0.0,
            'diversity': 0.0,
            'topic_variance': 0.0,
        }

    def extract_batch_features(self, questions: List[str],
                               passages_list: List[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
        """Extract features for each aligned (question, passages) pair.

        Pairs are aligned positionally; extra items in the longer list are
        ignored (``zip`` semantics).
        """
        return [
            self.extract_features(question, passages)
            for question, passages in zip(questions, passages_list)
        ]