Spaces:
Sleeping
Sleeping
# Module-level dependencies and logger for the safe RAG generator.
import logging
from typing import Any, Dict, List, Optional, Tuple

from .prompt_templates import PromptTemplates
from .vllm_server import VLLMServer
from ..calibration.features import RiskFeatureExtractor

logger = logging.getLogger(__name__)
class SafeGenerator:
    """Risk-aware answer generator for RAG pipelines.

    Wraps a vLLM server and a risk-feature extractor and adapts the
    generation strategy (prompt template, sampling temperature, citation
    policy) to an estimated risk score using two thresholds:

    * ``risk < tau1``          -> normal generation
    * ``tau1 <= risk < tau2``  -> conservative generation with citations
    * ``risk >= tau2``         -> very conservative generation (may refuse)
    """

    def __init__(self, vllm_server: VLLMServer,
                 risk_extractor: RiskFeatureExtractor,
                 tau1: float = 0.3, tau2: float = 0.7):
        """
        Args:
            vllm_server: Backend used to run text generation.
            risk_extractor: Produces risk features for (question, passages).
            tau1: Low-risk threshold; scores below it use the normal strategy.
            tau2: High-risk threshold; scores at/above it use the safest strategy.
        """
        self.vllm_server = vllm_server
        self.risk_extractor = risk_extractor
        self.prompt_templates = PromptTemplates()
        self.tau1 = tau1  # Low risk threshold
        self.tau2 = tau2  # High risk threshold

    def _select_strategy(self, risk_score: float,
                         force_citation: bool) -> Tuple[str, float, str, bool]:
        """Map a risk score to (strategy, temperature, template_name, force_citation)."""
        if risk_score < self.tau1:
            # Low risk: normal generation; honor the caller's citation preference.
            return "normal", 0.7, "rag", force_citation
        if risk_score < self.tau2:
            # Medium risk: conservative generation, citations always on.
            return "conservative", 0.5, "rag_with_citations", True
        # High risk: very conservative template (may refuse), citations always on.
        return "conservative_or_refuse", 0.3, "rag_safe", True

    def generate_with_strategy(self, question: str,
                               retrieved_passages: List[Dict[str, Any]],
                               force_citation: bool = False) -> Dict[str, Any]:
        """Generate an answer with an adaptive strategy based on risk assessment.

        Args:
            question: User question to answer.
            retrieved_passages: Retrieved passage dicts (each with a 'text' key).
            force_citation: Force citation post-processing even for low risk.

        Returns:
            Dict with keys 'answer', 'risk_score', 'strategy', 'temperature',
            'features', 'citations'. On generation failure, an apology answer
            with strategy 'error' is returned instead of raising.
        """
        risk_features = self.risk_extractor.extract_features(
            question, retrieved_passages
        )
        # Placeholder risk estimate; a trained calibration model will replace it.
        risk_score = self._estimate_risk_score(risk_features)

        strategy, temperature, template_name, force_citation = \
            self._select_strategy(risk_score, force_citation)

        prompt = self.prompt_templates.create_rag_prompt(
            question, retrieved_passages, template_name
        )

        try:
            result = self.vllm_server.generate_single(
                prompt,
                max_tokens=512,
                temperature=temperature,
            )
            # Post-process for citations if needed.
            if force_citation:
                result = self._add_citations(result, retrieved_passages)
            return {
                'answer': result,
                'risk_score': risk_score,
                'strategy': strategy,
                'temperature': temperature,
                'features': risk_features,
                'citations': self._extract_citations(result, retrieved_passages),
            }
        except Exception as e:
            # Fail soft: callers always receive a well-formed result dict.
            logger.error("Generation failed: %s", e)
            return {
                'answer': "I apologize, but I encountered an error while generating a response.",
                'risk_score': 1.0,
                'strategy': 'error',
                'temperature': 0.0,
                'features': risk_features,
                'citations': [],
            }

    def generate_batch(self, questions: List[str],
                       retrieved_passages_list: List[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
        """Generate answers for multiple questions, one result dict per question."""
        if len(questions) != len(retrieved_passages_list):
            # zip() below silently truncates; surface the mismatch to the operator.
            logger.warning(
                "generate_batch length mismatch: %d questions vs %d passage lists",
                len(questions), len(retrieved_passages_list),
            )
        return [
            self.generate_with_strategy(question, passages)
            for question, passages in zip(questions, retrieved_passages_list)
        ]

    def _estimate_risk_score(self, features: Dict[str, Any]) -> float:
        """Estimate a risk score in [0, 1] from retrieval features.

        Placeholder heuristic until a trained calibration model is wired in:
        higher average similarity, higher passage diversity, and more passages
        (contribution capped at 10) all lower the risk.
        """
        avg_similarity = features.get('avg_similarity', 0.5)
        diversity = features.get('diversity', 0.5)
        # Passage-count support saturates at 10 passages.
        passage_support = min(features.get('num_passages', 1), 10) / 10.0
        # Weighted evidence score, inverted into a risk score and clamped.
        risk_score = 1.0 - (avg_similarity * 0.4 + diversity * 0.3 + passage_support * 0.3)
        return max(0.0, min(1.0, risk_score))

    def _add_citations(self, answer: str, passages: List[Dict[str, Any]]) -> str:
        """Append bracketed citation markers to *answer* if it has none.

        NOTE(review): the heuristic marks a passage as cited when any of its
        first five words occurs in the answer — common stopwords make this
        over-eager; a proper attribution method should replace it.
        """
        if '[' in answer and ']' in answer:
            return answer  # Already has citations
        cited_answer = answer
        for i, passage in enumerate(passages[:3]):  # Limit to first 3 passages
            if any(word in answer.lower() for word in passage['text'].lower().split()[:5]):
                cited_answer += f" [{i+1}]"
        return cited_answer

    def _extract_citations(self, answer: str, passages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Resolve bracketed markers like ``[1]`` in *answer* to passage records.

        Returns one record per marker occurrence (duplicates preserved);
        out-of-range indices are ignored.
        """
        import re
        citations = []
        for match in re.findall(r'\[(\d+)\]', answer):
            idx = int(match) - 1  # Markers are 1-based; passages are 0-based.
            if 0 <= idx < len(passages):
                citations.append({
                    'id': idx,
                    'text': passages[idx]['text'],
                    'metadata': passages[idx].get('metadata', {}),
                })
        return citations

    def get_generation_stats(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Aggregate summary statistics over generation results.

        Returns {} for an empty input; otherwise query count, risk-score
        min/max/mean, strategy distribution, and mean citations per answer.
        """
        if not results:
            return {}
        risk_scores = [r['risk_score'] for r in results]
        strategy_counts: Dict[str, int] = {}
        for r in results:
            strategy_counts[r['strategy']] = strategy_counts.get(r['strategy'], 0) + 1
        return {
            'num_queries': len(results),
            'avg_risk_score': sum(risk_scores) / len(risk_scores),
            'min_risk_score': min(risk_scores),
            'max_risk_score': max(risk_scores),
            'strategy_distribution': strategy_counts,
            'avg_citations_per_answer': sum(len(r.get('citations', [])) for r in results) / len(results),
        }