from typing import List, Dict, Any
import logging
import re

from .vllm_server import VLLMServer
from .prompt_templates import PromptTemplates
from ..calibration.features import RiskFeatureExtractor

logger = logging.getLogger(__name__)


class SafeGenerator:
    """Risk-aware generator that adapts its decoding strategy (template,
    temperature, citation enforcement) to an estimated risk score."""

    def __init__(self, vllm_server: VLLMServer,
                 risk_extractor: RiskFeatureExtractor,
                 tau1: float = 0.3, tau2: float = 0.7):
        self.vllm_server = vllm_server
        self.risk_extractor = risk_extractor
        self.prompt_templates = PromptTemplates()
        self.tau1 = tau1  # Low-risk threshold
        self.tau2 = tau2  # High-risk threshold

    def generate_with_strategy(self, question: str,
                               retrieved_passages: List[Dict[str, Any]],
                               force_citation: bool = False) -> Dict[str, Any]:
        """Generate an answer with an adaptive strategy based on risk assessment."""
        # Extract risk features
        risk_features = self.risk_extractor.extract_features(
            question, retrieved_passages
        )

        # Get risk score (placeholder - will be implemented in the calibration module)
        risk_score = self._estimate_risk_score(risk_features)

        # Determine the strategy based on the risk score
        if risk_score < self.tau1:
            # Low risk: normal generation
            strategy = "normal"
            temperature = 0.7
            template_name = "rag"
        elif risk_score < self.tau2:
            # Medium risk: conservative generation with citations
            strategy = "conservative"
            temperature = 0.5
            template_name = "rag_with_citations"
            force_citation = True
        else:
            # High risk: very conservative generation, or refuse
            strategy = "conservative_or_refuse"
            temperature = 0.3
            template_name = "rag_safe"
            force_citation = True

        # Build the prompt
        prompt = self.prompt_templates.create_rag_prompt(
            question, retrieved_passages, template_name
        )

        # Generate the answer
        try:
            result = self.vllm_server.generate_single(
                prompt, max_tokens=512, temperature=temperature
            )

            # Post-process for citations if needed
            if force_citation:
                result = self._add_citations(result, retrieved_passages)

            return {
                'answer': result,
                'risk_score': risk_score,
                'strategy': strategy,
                'temperature': temperature,
                'features': risk_features,
                'citations': self._extract_citations(result, retrieved_passages)
            }
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            return {
                'answer': "I apologize, but I encountered an error while generating a response.",
                'risk_score': 1.0,
                'strategy': 'error',
                'temperature': 0.0,
                'features': risk_features,
                'citations': []
            }

    def generate_batch(self, questions: List[str],
                       retrieved_passages_list: List[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
        """Generate answers for multiple questions."""
        results = []
        for question, passages in zip(questions, retrieved_passages_list):
            results.append(self.generate_with_strategy(question, passages))
        return results

    def _estimate_risk_score(self, features: Dict[str, Any]) -> float:
        """Estimate a risk score from features (placeholder implementation).

        This is a simplified heuristic; in practice it would be replaced by a
        trained calibration model.
        """
        # Higher average similarity = lower risk
        avg_similarity = features.get('avg_similarity', 0.5)
        # More diverse passages = lower risk
        diversity = features.get('diversity', 0.5)
        # More passages = lower risk, up to a cap of 10
        num_passages = min(features.get('num_passages', 1), 10)
        passage_factor = num_passages / 10.0

        # Combine factors: each term reduces risk as the evidence improves
        risk_score = 1.0 - (avg_similarity * 0.4 + diversity * 0.3 + passage_factor * 0.3)
        return max(0.0, min(1.0, risk_score))

    def _add_citations(self, answer: str, passages: List[Dict[str, Any]]) -> str:
        """Append citation markers to the answer if none are present."""
        if '[' in answer and ']' in answer:
            return answer  # Already has citations
        # Simple keyword-overlap citation addition (in practice, use more
        # sophisticated attribution methods)
        answer_words = set(answer.lower().split())
        cited_answer = answer
        for i, passage in enumerate(passages[:3]):  # Limit to the first 3 passages
            # Cite a passage if any of its first 5 words occurs in the answer
            if any(word in answer_words for word in passage['text'].lower().split()[:5]):
                cited_answer += f" [{i+1}]"
        return cited_answer

    def _extract_citations(self, answer: str,
                           passages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Extract cited passages from the answer."""
        citations = []
        # Find citation markers like [1], [2], etc.
        citation_matches = re.findall(r'\[(\d+)\]', answer)
        for match in citation_matches:
            idx = int(match) - 1
            if 0 <= idx < len(passages):
                citations.append({
                    'id': idx,
                    'text': passages[idx]['text'],
                    'metadata': passages[idx].get('metadata', {})
                })
        return citations

    def get_generation_stats(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Compute summary statistics over generation results."""
        if not results:
            return {}

        risk_scores = [r['risk_score'] for r in results]
        strategies = [r['strategy'] for r in results]

        strategy_counts = {}
        for strategy in strategies:
            strategy_counts[strategy] = strategy_counts.get(strategy, 0) + 1

        return {
            'num_queries': len(results),
            'avg_risk_score': sum(risk_scores) / len(risk_scores),
            'min_risk_score': min(risk_scores),
            'max_risk_score': max(risk_scores),
            'strategy_distribution': strategy_counts,
            'avg_citations_per_answer': sum(len(r.get('citations', []))
                                            for r in results) / len(results)
        }
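
# Minimal usage sketch (illustrative, not executable as-is): the constructor
# arguments for VLLMServer and RiskFeatureExtractor are elided because their
# interfaces live in sibling modules; consult those modules for the real
# signatures. The question and passage below are placeholder data.
#
#   server = VLLMServer(...)                # configured vLLM endpoint
#   extractor = RiskFeatureExtractor(...)   # configured feature extractor
#   generator = SafeGenerator(server, extractor, tau1=0.3, tau2=0.7)
#   result = generator.generate_with_strategy(
#       "Who discovered penicillin?",
#       retrieved_passages=[{'text': 'Alexander Fleming discovered penicillin in 1928.'}],
#   )
#   print(result['strategy'], result['risk_score'], result['answer'])
#   stats = generator.get_generation_stats([result])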