# safe_rag/generator/safe_generate.py
import logging
import re
from typing import Any, Dict, List

from .vllm_server import VLLMServer
from .prompt_templates import PromptTemplates
from ..calibration.features import RiskFeatureExtractor

logger = logging.getLogger(__name__)


class SafeGenerator:
    """Risk-aware answer generator that adapts its decoding strategy to an estimated risk score."""

    def __init__(self, vllm_server: VLLMServer,
                 risk_extractor: RiskFeatureExtractor,
                 tau1: float = 0.3, tau2: float = 0.7):
        self.vllm_server = vllm_server
        self.risk_extractor = risk_extractor
        self.prompt_templates = PromptTemplates()
        self.tau1 = tau1  # Low-risk threshold
        self.tau2 = tau2  # High-risk threshold
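
    # Strategy bands used by generate_with_strategy (summarizing the branches below):
    #   risk_score <  tau1          -> "normal" generation
    #   tau1 <= risk_score < tau2   -> "conservative" generation with citations
    #   risk_score >= tau2          -> "conservative_or_refuse"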

    def generate_with_strategy(self, question: str,
                               retrieved_passages: List[Dict[str, Any]],
                               force_citation: bool = False) -> Dict[str, Any]:
        """Generate an answer with an adaptive strategy based on risk assessment."""
        # Extract risk features
        risk_features = self.risk_extractor.extract_features(
            question, retrieved_passages
        )
        # Get risk score (placeholder - will be implemented in calibration module)
        risk_score = self._estimate_risk_score(risk_features)

        # Determine strategy based on risk score
        if risk_score < self.tau1:
            # Low risk: normal generation
            strategy = "normal"
            temperature = 0.7
            template_name = "rag"
        elif risk_score < self.tau2:
            # Medium risk: conservative generation with citations
            strategy = "conservative"
            temperature = 0.5
            template_name = "rag_with_citations"
            force_citation = True
        else:
            # High risk: very conservative or refuse
            strategy = "conservative_or_refuse"
            temperature = 0.3
            template_name = "rag_safe"
            force_citation = True

        # Generate prompt
        prompt = self.prompt_templates.create_rag_prompt(
            question, retrieved_passages, template_name
        )

        # Generate answer
        try:
            result = self.vllm_server.generate_single(
                prompt,
                max_tokens=512,
                temperature=temperature
            )
            # Post-process for citations if needed
            if force_citation:
                result = self._add_citations(result, retrieved_passages)
            return {
                'answer': result,
                'risk_score': risk_score,
                'strategy': strategy,
                'temperature': temperature,
                'features': risk_features,
                'citations': self._extract_citations(result, retrieved_passages)
            }
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            return {
                'answer': "I apologize, but I encountered an error while generating a response.",
                'risk_score': 1.0,
                'strategy': 'error',
                'temperature': 0.0,
                'features': risk_features,
                'citations': []
            }

    def generate_batch(self, questions: List[str],
                       retrieved_passages_list: List[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
        """Generate answers for multiple questions."""
        results = []
        for question, passages in zip(questions, retrieved_passages_list):
            result = self.generate_with_strategy(question, passages)
            results.append(result)
        return results

    def _estimate_risk_score(self, features: Dict[str, Any]) -> float:
        """Estimate a risk score in [0, 1] from retrieval features (placeholder implementation)."""
        # Simplified heuristic; in practice a trained calibration model would produce this score.
        # Higher average question-passage similarity -> lower risk
        avg_similarity = features.get('avg_similarity', 0.5)
        # More diverse passages -> lower risk
        diversity = features.get('diversity', 0.5)
        # More passages -> lower risk, saturating at 10 passages
        num_passages = min(features.get('num_passages', 1), 10)
        passage_score = num_passages / 10.0
        # Weighted combination of the "safety" signals, inverted into a risk score
        risk_score = 1.0 - (avg_similarity * 0.4 + diversity * 0.3 + passage_score * 0.3)
        return max(0.0, min(1.0, risk_score))
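
    # Worked example for _estimate_risk_score (illustrative numbers, not taken from the project):
    #   avg_similarity=0.8, diversity=0.6, num_passages=5  ->  passage_score=0.5
    #   risk_score = 1.0 - (0.8*0.4 + 0.6*0.3 + 0.5*0.3) = 1.0 - 0.65 = 0.35
    #   With the default thresholds (tau1=0.3, tau2=0.7) this falls in the "conservative" band.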

    def _add_citations(self, answer: str, passages: List[Dict[str, Any]]) -> str:
        """Add citation markers to the answer if none are present."""
        if '[' in answer and ']' in answer:
            return answer  # Already has citations
        # Simple heuristic: cite a passage when any of its first five words appears
        # in the answer (in practice, use more sophisticated attribution methods).
        cited_answer = answer
        for i, passage in enumerate(passages[:3]):  # Limit to first 3 passages
            first_words = passage.get('text', '').lower().split()[:5]
            if any(word in answer.lower() for word in first_words):
                cited_answer += f" [{i + 1}]"
        return cited_answer
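
    # Illustrative example of the heuristic above (hypothetical inputs): for a passage whose
    # text starts "Paris is the capital of France", the first five lowercased words are
    # ['paris', 'is', 'the', 'capital', 'of']; an answer containing "Paris" therefore has
    # " [1]" appended.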

    def _extract_citations(self, answer: str, passages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Extract the passages referenced by citation markers in the answer."""
        citations = []
        # Find citation markers like [1], [2], etc.
        citation_matches = re.findall(r'\[(\d+)\]', answer)
        for match in citation_matches:
            idx = int(match) - 1
            if 0 <= idx < len(passages):
                citations.append({
                    'id': idx,
                    'text': passages[idx].get('text', ''),
                    'metadata': passages[idx].get('metadata', {})
                })
        return citations
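
    # Example (hypothetical inputs): for the answer "Paris is the capital of France [1]." and
    # two retrieved passages, re.findall(r'\[(\d+)\]', answer) yields ['1'], so the method
    # returns a single citation dict with 'id' 0 and the text/metadata of the first passage.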

    def get_generation_stats(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Get statistics from generation results."""
        if not results:
            return {}
        risk_scores = [r['risk_score'] for r in results]
        strategies = [r['strategy'] for r in results]
        strategy_counts = {}
        for strategy in strategies:
            strategy_counts[strategy] = strategy_counts.get(strategy, 0) + 1
        return {
            'num_queries': len(results),
            'avg_risk_score': sum(risk_scores) / len(risk_scores),
            'min_risk_score': min(risk_scores),
            'max_risk_score': max(risk_scores),
            'strategy_distribution': strategy_counts,
            'avg_citations_per_answer': sum(len(r.get('citations', [])) for r in results) / len(results)
        }
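

# Usage sketch (hypothetical wiring; the constructor arguments for VLLMServer and
# RiskFeatureExtractor are not defined in this module, so the ellipses below are
# placeholders, not real call signatures):
#
#     server = VLLMServer(...)               # however the vLLM endpoint is configured
#     extractor = RiskFeatureExtractor(...)  # however retrieval features are configured
#     generator = SafeGenerator(server, extractor, tau1=0.3, tau2=0.7)
#
#     passages = [{'text': 'Paris is the capital of France.', 'metadata': {'source': 'wiki'}}]
#     out = generator.generate_with_strategy('What is the capital of France?', passages)
#     print(out['strategy'], out['risk_score'])
#     print(out['answer'])
#
#     stats = generator.get_generation_stats([out])
#     print(stats['strategy_distribution'])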