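"""Risk-gated answer generation for retrieval-augmented QA.

SafeGenerator scores each (question, retrieved passages) pair with a
heuristic risk estimate and routes generation accordingly: normal decoding
below tau1, conservative decoding with forced citations between tau1 and
tau2, and a very conservative, refusal-leaning mode above tau2.
"""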
import logging
import re
from typing import Any, Dict, List
from .vllm_server import VLLMServer
from .prompt_templates import PromptTemplates
from ..calibration.features import RiskFeatureExtractor

logger = logging.getLogger(__name__)

class SafeGenerator:
    def __init__(self, vllm_server: VLLMServer, 
                 risk_extractor: RiskFeatureExtractor,
                 tau1: float = 0.3, tau2: float = 0.7):
        self.vllm_server = vllm_server
        self.risk_extractor = risk_extractor
        self.prompt_templates = PromptTemplates()
        self.tau1 = tau1  # Low risk threshold
        self.tau2 = tau2  # High risk threshold
        
    def generate_with_strategy(self, question: str,
                               retrieved_passages: List[Dict[str, Any]],
                               force_citation: bool = False) -> Dict[str, Any]:
        """Generate answer with adaptive strategy based on risk assessment"""
        
        # Extract risk features
        risk_features = self.risk_extractor.extract_features(
            question, retrieved_passages
        )
        
        # Get risk score (placeholder - will be implemented in calibration module)
        risk_score = self._estimate_risk_score(risk_features)
        
        # Determine strategy based on risk score
        if risk_score < self.tau1:
            # Low risk: normal generation
            strategy = "normal"
            temperature = 0.7
            template_name = "rag"
        elif risk_score < self.tau2:
            # Medium risk: conservative generation with citations
            strategy = "conservative"
            temperature = 0.5
            template_name = "rag_with_citations"
            force_citation = True
        else:
            # High risk: very conservative or refuse
            strategy = "conservative_or_refuse"
            temperature = 0.3
            template_name = "rag_safe"
            force_citation = True
        
        # Generate prompt
        prompt = self.prompt_templates.create_rag_prompt(
            question, retrieved_passages, template_name
        )
        
        # Generate answer
        try:
            result = self.vllm_server.generate_single(
                prompt, 
                max_tokens=512,
                temperature=temperature
            )
            
            # Post-process for citations if needed
            if force_citation:
                result = self._add_citations(result, retrieved_passages)
            
            return {
                'answer': result,
                'risk_score': risk_score,
                'strategy': strategy,
                'temperature': temperature,
                'features': risk_features,
                'citations': self._extract_citations(result, retrieved_passages)
            }
            
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            return {
                'answer': "I apologize, but I encountered an error while generating a response.",
                'risk_score': 1.0,
                'strategy': 'error',
                'temperature': 0.0,
                'features': risk_features,
                'citations': []
            }
    
    def generate_batch(self, questions: List[str], 
                      retrieved_passages_list: List[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
        """Generate answers for multiple questions"""
        results = []
        
        for question, passages in zip(questions, retrieved_passages_list):
            result = self.generate_with_strategy(question, passages)
            results.append(result)
        
        return results
    
    def _estimate_risk_score(self, features: Dict[str, Any]) -> float:
        """Estimate risk score from features (placeholder implementation)"""
        # This is a simplified risk estimation
        # In practice, this would use a trained calibration model
        
        # Higher similarity scores = lower risk
        avg_similarity = features.get('avg_similarity', 0.5)
        
        # More diverse passages = lower risk
        diversity = features.get('diversity', 0.5)
        
        # More passages = lower risk, capped at 10
        num_passages = min(features.get('num_passages', 1), 10)
        passage_factor = num_passages / 10.0

        # Combine factors: higher similarity, diversity, and passage count
        # all pull the risk score down
        risk_score = 1.0 - (avg_similarity * 0.4 + diversity * 0.3 + passage_factor * 0.3)
        
        return max(0.0, min(1.0, risk_score))
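
    # Worked example for the heuristic above: avg_similarity=0.8,
    # diversity=0.6 and num_passages=5 give
    #   risk = 1 - (0.8*0.4 + 0.6*0.3 + 0.5*0.3) = 1 - 0.65 = 0.35,
    # i.e. the medium band (tau1=0.3 <= 0.35 < tau2=0.7), which selects the
    # conservative, citation-forcing strategy.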
    
    def _add_citations(self, answer: str, passages: List[Dict[str, Any]]) -> str:
        """Add citations to answer if not present"""
        # Treat bracketed numbers such as [1] as existing citations; checking
        # for bare brackets would be fooled by unrelated '[' in the answer
        if re.search(r'\[\d+\]', answer):
            return answer
        
        # Simple citation addition (in practice, use more sophisticated methods)
        cited_answer = answer
        for i, passage in enumerate(passages[:3]):  # Limit to first 3 passages
            if any(word in answer.lower() for word in passage['text'].lower().split()[:5]):
                cited_answer += f" [{i+1}]"
        
        return cited_answer
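
    # Example of the heuristic above: if passages[0]['text'] begins
    # "The Eiffel Tower is in Paris", any answer containing one of its first
    # five words (even "the" or "is") gets " [1]" appended. This is a
    # deliberately loose placeholder, per the inline caveat above.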
    
    def _extract_citations(self, answer: str, passages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Extract citations from answer"""
        citations = []
        
        # Find citation markers like [1], [2], etc.
        citation_matches = re.findall(r'\[(\d+)\]', answer)
        
        for match in citation_matches:
            idx = int(match) - 1
            if 0 <= idx < len(passages):
                citations.append({
                    'id': idx,
                    'text': passages[idx]['text'],
                    'metadata': passages[idx].get('metadata', {})
                })
        
        return citations
    
    def get_generation_stats(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Get statistics from generation results"""
        if not results:
            return {}
        
        risk_scores = [r['risk_score'] for r in results]
        strategies = [r['strategy'] for r in results]
        
        strategy_counts = {}
        for strategy in strategies:
            strategy_counts[strategy] = strategy_counts.get(strategy, 0) + 1
        
        return {
            'num_queries': len(results),
            'avg_risk_score': sum(risk_scores) / len(risk_scores),
            'min_risk_score': min(risk_scores),
            'max_risk_score': max(risk_scores),
            'strategy_distribution': strategy_counts,
            'avg_citations_per_answer': sum(len(r.get('citations', [])) for r in results) / len(results)
        }