"""Risk-adaptive answer generation for RAG.

SafeGenerator estimates a per-query risk score and selects a generation
strategy (normal, conservative, or conservative/refuse) by comparing it
against two thresholds, tau1 and tau2.
"""

from typing import List, Dict, Any
import logging
import re

from .vllm_server import VLLMServer
from .prompt_templates import PromptTemplates
from ..calibration.features import RiskFeatureExtractor

logger = logging.getLogger(__name__)


class SafeGenerator:
    def __init__(self, vllm_server: VLLMServer,
                 risk_extractor: RiskFeatureExtractor,
                 tau1: float = 0.3, tau2: float = 0.7):
        self.vllm_server = vllm_server
        self.risk_extractor = risk_extractor
        self.prompt_templates = PromptTemplates()
        self.tau1 = tau1  # Low risk threshold
        self.tau2 = tau2  # High risk threshold

    def generate_with_strategy(self, question: str,
                               retrieved_passages: List[Dict[str, Any]],
                               force_citation: bool = False) -> Dict[str, Any]:
        """Generate an answer with an adaptive strategy based on risk assessment."""
        # Extract risk features
        risk_features = self.risk_extractor.extract_features(
            question, retrieved_passages
        )

        # Get risk score (placeholder - will be implemented in the calibration module)
        risk_score = self._estimate_risk_score(risk_features)

        # Determine strategy based on risk score
        if risk_score < self.tau1:
            # Low risk: normal generation
            strategy = "normal"
            temperature = 0.7
            template_name = "rag"
        elif risk_score < self.tau2:
            # Medium risk: conservative generation with citations
            strategy = "conservative"
            temperature = 0.5
            template_name = "rag_with_citations"
            force_citation = True
        else:
            # High risk: very conservative generation, or refuse to answer
            strategy = "conservative_or_refuse"
            temperature = 0.3
            template_name = "rag_safe"
            force_citation = True

        # Generate prompt
        prompt = self.prompt_templates.create_rag_prompt(
            question, retrieved_passages, template_name
        )

        # Generate answer
        try:
            result = self.vllm_server.generate_single(
                prompt,
                max_tokens=512,
                temperature=temperature
            )

            # Post-process for citations if needed
            if force_citation:
                result = self._add_citations(result, retrieved_passages)

            return {
                'answer': result,
                'risk_score': risk_score,
                'strategy': strategy,
                'temperature': temperature,
                'features': risk_features,
                'citations': self._extract_citations(result, retrieved_passages)
            }
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            return {
                'answer': "I apologize, but I encountered an error while generating a response.",
                'risk_score': 1.0,  # Treat failures as maximum risk
                'strategy': 'error',
                'temperature': 0.0,
                'features': risk_features,
                'citations': []
            }

    def generate_batch(self, questions: List[str],
                       retrieved_passages_list: List[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
        """Generate answers for multiple questions."""
        results = []
        for question, passages in zip(questions, retrieved_passages_list):
            result = self.generate_with_strategy(question, passages)
            results.append(result)
        return results

    def _estimate_risk_score(self, features: Dict[str, Any]) -> float:
        """Estimate a risk score from features (placeholder implementation).

        This is a simplified heuristic; in practice it would be replaced
        by a trained calibration model.
        """
        # Higher similarity scores = lower risk
        avg_similarity = features.get('avg_similarity', 0.5)
        # More diverse passages = lower risk
        diversity = features.get('diversity', 0.5)
        # More passages = lower risk (capped at 10)
        num_passages = min(features.get('num_passages', 1), 10)
        passage_coverage = num_passages / 10.0

        # Combine factors: each term reduces risk as it grows
        risk_score = 1.0 - (avg_similarity * 0.4 + diversity * 0.3 + passage_coverage * 0.3)
        return max(0.0, min(1.0, risk_score))
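
    # Worked example (illustrative numbers, not taken from real data):
    # avg_similarity=0.8, diversity=0.6, num_passages=5 gives
    #   risk = 1 - (0.8*0.4 + 0.6*0.3 + 0.5*0.3) = 1 - 0.65 = 0.35,
    # which, under the default thresholds, falls in [tau1=0.3, tau2=0.7)
    # and therefore selects the "conservative" strategy.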

    def _add_citations(self, answer: str, passages: List[Dict[str, Any]]) -> str:
        """Add citation markers to the answer if none are present."""
        if re.search(r'\[\d+\]', answer):
            return answer  # Already has citation markers

        # Simple citation heuristic (in practice, use more sophisticated
        # alignment): cite a passage when the answer contains one of its
        # first few non-trivial words; short tokens are skipped so that
        # stopwords like "the" do not trigger spurious citations.
        cited_answer = answer
        for i, passage in enumerate(passages[:3]):  # Limit to first 3 passages
            key_words = [w for w in passage['text'].lower().split() if len(w) > 3][:5]
            if any(word in answer.lower() for word in key_words):
                cited_answer += f" [{i+1}]"
        return cited_answer

    def _extract_citations(self, answer: str, passages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Extract the passages cited in the answer."""
        citations = []
        # Find citation markers like [1], [2], etc.
        # e.g. re.findall(r'\[(\d+)\]', "Blue [1], red [3].") -> ['1', '3']
        citation_matches = re.findall(r'\[(\d+)\]', answer)
        for match in citation_matches:
            idx = int(match) - 1  # Markers are 1-based; the passage list is 0-based
            if 0 <= idx < len(passages):
                citations.append({
                    'id': idx,
                    'text': passages[idx]['text'],
                    'metadata': passages[idx].get('metadata', {})
                })
        return citations

    def get_generation_stats(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Aggregate statistics over a list of generation results."""
        if not results:
            return {}

        risk_scores = [r['risk_score'] for r in results]
        strategies = [r['strategy'] for r in results]

        strategy_counts = {}
        for strategy in strategies:
            strategy_counts[strategy] = strategy_counts.get(strategy, 0) + 1

        return {
            'num_queries': len(results),
            'avg_risk_score': sum(risk_scores) / len(risk_scores),
            'min_risk_score': min(risk_scores),
            'max_risk_score': max(risk_scores),
            'strategy_distribution': strategy_counts,
            'avg_citations_per_answer': sum(len(r.get('citations', [])) for r in results) / len(results)
        }
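

# Usage sketch (illustrative only; VLLMServer and RiskFeatureExtractor are
# assumed to be constructed elsewhere with their project-specific arguments):
#
#   generator = SafeGenerator(vllm_server, risk_extractor, tau1=0.3, tau2=0.7)
#   result = generator.generate_with_strategy(question, retrieved_passages)
#   print(result['strategy'], result['risk_score'], result['answer'])
#
# Risk bands under the default thresholds:
#   risk_score < 0.3         -> "normal"                 (temperature 0.7)
#   0.3 <= risk_score < 0.7  -> "conservative"           (temperature 0.5, citations)
#   risk_score >= 0.7        -> "conservative_or_refuse" (temperature 0.3, citations)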