# safe_rag/calibration/features.py
# Tairun Meng
# Initial commit: SafeRAG project ready for HF Spaces
# db06013
from typing import List, Dict, Any
import numpy as np
from sentence_transformers import SentenceTransformer
import logging
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import re
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class RiskFeatureExtractor:
    """Compute numeric risk-assessment features for a question and its retrieved passages.

    Four feature groups are produced and merged into one flat dict:

    * retrieval statistics over each passage's ``'score'`` value,
    * coverage of the question by each passage's ``'text'`` (token and
      simplified-entity overlap),
    * consistency between passages (pairwise cosine similarity of embeddings),
    * diversity between passages (1 - mean pairwise TF-IDF cosine similarity).

    Every call to :meth:`extract_features` returns the same set of keys, whether
    zero, one, or many passages are supplied, and every statistic is a built-in
    ``float`` (``num_passages`` is an ``int``).
    """

    # Simplified "entity" patterns: a capitalised word, or a run of digits.
    _CAPITALIZED_WORD_RE = re.compile(r'\b[A-Z][a-z]+\b')
    _NUMBER_RE = re.compile(r'\b\d+\b')

    def __init__(self, embedding_model: str = "BAAI/bge-large-en-v1.5"):
        """Load the sentence-embedding model and create the TF-IDF vectorizer.

        Args:
            embedding_model: Model name passed to ``SentenceTransformer``.
        """
        self.embedding_model = SentenceTransformer(embedding_model)
        self.tfidf_vectorizer = TfidfVectorizer(max_features=1000, stop_words='english')

    def extract_features(self, question: str, retrieved_passages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Extract risk assessment features.

        Args:
            question: The user question.
            retrieved_passages: Passage dicts; the ``'text'`` and ``'score'``
                keys are read (missing keys default to ``''`` and ``0.0``).

        Returns:
            A flat dict of feature name -> value. With no passages, the
            all-zero dict from :meth:`_get_empty_features` is returned.
        """
        if not retrieved_passages:
            return self._get_empty_features()

        features: Dict[str, Any] = {}
        # Retrieval statistics
        features.update(self._extract_retrieval_stats(retrieved_passages))
        # Coverage features
        features.update(self._extract_coverage_features(question, retrieved_passages))
        # Consistency features
        features.update(self._extract_consistency_features(question, retrieved_passages))
        # Diversity features
        features.update(self._extract_diversity_features(retrieved_passages))
        return features

    def _extract_retrieval_stats(self, passages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Extract summary statistics of the passages' ``'score'`` values.

        Returns an empty dict when ``passages`` is empty.
        """
        if not passages:
            return {}

        scores = [p.get('score', 0.0) for p in passages]
        # float(...) turns NumPy scalars into plain Python floats so the
        # result can be serialised (e.g. with json) without special handling.
        return {
            'num_passages': len(passages),
            'avg_similarity': float(np.mean(scores)),
            'std_similarity': float(np.std(scores)),
            'max_similarity': float(np.max(scores)),
            'min_similarity': float(np.min(scores)),
            'score_variance': float(np.var(scores)),
        }

    @staticmethod
    def _overlap_fraction(reference: set, candidate: set) -> float:
        """Return the fraction of ``reference`` items found in ``candidate`` (0 if ``reference`` is empty)."""
        if not reference:
            return 0.0
        return len(reference & candidate) / len(reference)

    def _extract_coverage_features(self, question: str, passages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Extract coverage features between question and passages.

        Token overlap uses lower-cased, whitespace-split tokens; entity overlap
        uses :meth:`_extract_entities`. Both are expressed as the fraction of
        the question's items that also appear in a passage.

        Returns an empty dict when ``passages`` is empty.
        """
        if not passages:
            return {}

        passage_texts = [p.get('text', '') for p in passages]

        # Token overlap
        question_tokens = set(question.lower().split())
        token_overlaps = [
            self._overlap_fraction(question_tokens, set(text.lower().split()))
            for text in passage_texts
        ]

        # Entity overlap (simplified)
        question_entities = self._extract_entities(question)
        entity_overlaps = [
            self._overlap_fraction(question_entities, self._extract_entities(text))
            for text in passage_texts
        ]

        return {
            'avg_token_overlap': float(np.mean(token_overlaps)),
            'max_token_overlap': float(np.max(token_overlaps)),
            'avg_entity_overlap': float(np.mean(entity_overlaps)),
            'max_entity_overlap': float(np.max(entity_overlaps)),
        }

    def _extract_consistency_features(self, question: str, passages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Extract consistency features between passages.

        ``question`` is accepted for signature symmetry but is not used.
        With fewer than two passages there are no pairs to compare, so the
        passages are reported as fully consistent with zero spread.
        """
        if len(passages) < 2:
            # Same keys as the multi-passage branch so the schema stays fixed.
            return {
                'passage_consistency': 1.0,
                'passage_consistency_std': 0.0,
                'min_passage_similarity': 1.0,
            }

        # Semantic similarity between passages
        passage_texts = [p.get('text', '') for p in passages]
        embeddings = self.embedding_model.encode(passage_texts)

        # Pairwise similarities; keep the upper triangle (excluding the
        # diagonal) so each unordered pair is counted exactly once.
        similarities = cosine_similarity(embeddings)
        pairwise = similarities[np.triu_indices_from(similarities, k=1)]

        return {
            'passage_consistency': float(np.mean(pairwise)),
            'passage_consistency_std': float(np.std(pairwise)),
            'min_passage_similarity': float(np.min(pairwise)),
        }

    def _extract_diversity_features(self, passages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Extract diversity features from TF-IDF similarity between passages.

        Diversity is ``1 - mean pairwise cosine similarity``. If the TF-IDF
        computation fails, a neutral fallback (``diversity=0.5``,
        ``topic_variance=0.0``) is returned and the failure is logged.
        """
        if len(passages) < 2:
            # Same keys as the multi-passage branch so the schema stays fixed.
            return {'diversity': 1.0, 'topic_variance': 0.0}

        # Topic diversity using TF-IDF
        passage_texts = [p.get('text', '') for p in passages]
        try:
            tfidf_matrix = self.tfidf_vectorizer.fit_transform(passage_texts)
            similarities = cosine_similarity(tfidf_matrix)
            pairwise = similarities[np.triu_indices_from(similarities, k=1)]
            return {
                # Diversity as 1 - average similarity
                'diversity': float(1.0 - np.mean(pairwise)),
                'topic_variance': float(np.var(pairwise)),
            }
        except Exception:
            # Best-effort feature: keep going with a neutral value, but do not
            # swallow KeyboardInterrupt/SystemExit and do record what happened.
            logging.getLogger(__name__).warning(
                "TF-IDF diversity computation failed; using fallback values",
                exc_info=True,
            )
            return {'diversity': 0.5, 'topic_variance': 0.0}

    def _extract_entities(self, text: str) -> set:
        """Extract entities from text (simplified).

        This is a heuristic, not NER: it returns the set of capitalised words
        (an upper-case letter followed by lower-case letters) together with
        standalone digit runs found in ``text``.
        """
        entities = set(self._CAPITALIZED_WORD_RE.findall(text))
        entities.update(self._NUMBER_RE.findall(text))
        return entities

    def _get_empty_features(self) -> Dict[str, Any]:
        """Return the all-zero feature dict used when no passages are available."""
        return {
            'num_passages': 0,
            'avg_similarity': 0.0,
            'std_similarity': 0.0,
            'max_similarity': 0.0,
            'min_similarity': 0.0,
            'score_variance': 0.0,
            'avg_token_overlap': 0.0,
            'max_token_overlap': 0.0,
            'avg_entity_overlap': 0.0,
            'max_entity_overlap': 0.0,
            'passage_consistency': 0.0,
            'passage_consistency_std': 0.0,
            'min_passage_similarity': 0.0,
            'diversity': 0.0,
            'topic_variance': 0.0,
        }

    def extract_batch_features(self, questions: List[str],
                               passages_list: List[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
        """Extract features for multiple question-passage pairs.

        Questions and passage lists are paired positionally with ``zip``, so
        the result has as many entries as the shorter of the two inputs.
        """
        return [
            self.extract_features(question, passages)
            for question, passages in zip(questions, passages_list)
        ]