| """ | |
| Answer generation using LLM with RAG context. | |
| Supports Gemini and OpenAI as providers. | |
| """ | |
| import google.generativeai as genai | |
| from openai import OpenAI | |
| from typing import Optional, Dict, Any, List | |
| import logging | |
| import os | |
| import re | |
| from app.config import settings | |
| from app.rag.prompts import ( | |
| format_rag_prompt, | |
| format_draft_prompt, | |
| get_no_context_response, | |
| get_low_confidence_response | |
| ) | |
| from app.rag.verifier import get_verifier_service | |
| from app.rag.intent import detect_intents | |
| from app.models.schemas import Citation | |
| from abc import ABC, abstractmethod | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |

class LLMProvider(ABC):
    """Base class for LLM providers."""

    @abstractmethod
    def generate(self, system_prompt: str, user_prompt: str) -> str:
        """Generate response from prompts."""
        raise NotImplementedError

    @abstractmethod
    def generate_with_usage(self, system_prompt: str, user_prompt: str) -> tuple[str, dict]:
        """
        Generate response and return usage information.

        Returns:
            (response_text, usage_info)
            usage_info: dict with keys: prompt_tokens, completion_tokens, total_tokens, model_used
        """
        raise NotImplementedError
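
# Additional providers (e.g. a locally hosted model) could be added by
# subclassing LLMProvider and implementing generate() and generate_with_usage().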

class GeminiProvider(LLMProvider):
    """Google Gemini LLM provider."""

    def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None):
        self.api_key = api_key or settings.GEMINI_API_KEY or os.getenv("GEMINI_API_KEY")
        # Default to gemini-1.5-flash if not specified
        self.model = model or settings.GEMINI_MODEL or "gemini-1.5-flash"

        if not self.api_key:
            raise ValueError("Gemini API key not configured. Set GEMINI_API_KEY environment variable.")

        genai.configure(api_key=self.api_key)
        # Don't initialize the client here - do it lazily in generate() to handle errors better
        self._client = None
        logger.info(f"Gemini provider initialized (will use model: {self.model})")

    def generate(self, system_prompt: str, user_prompt: str) -> str:
        """Generate response using Gemini."""
        text, _ = self.generate_with_usage(system_prompt, user_prompt)
        return text
    def generate_with_usage(self, system_prompt: str, user_prompt: str) -> tuple[str, dict]:
        """Generate response using Gemini and return usage info."""
        # Combine system and user prompts for Gemini
        full_prompt = f"{system_prompt}\n\n{user_prompt}"

        # Estimate prompt tokens (rough: 1 token ≈ 4 chars)
        prompt_tokens = len(full_prompt) // 4

        # Try to list the available models first and use the first available ones.
        # If that fails, fall back to common model names.
        models_to_try = []

        # First, try to get the available models
        try:
            available_models = genai.list_models()
            model_names = [m.name for m in available_models if 'generateContent' in m.supported_generation_methods]
            if model_names:
                # Extract just the model name (remove the 'models/' prefix if present)
                clean_names = [name.split('/')[-1] if '/' in name else name for name in model_names]
                models_to_try.extend(clean_names[:3])  # Use the first 3 available models
                logger.info(f"Found {len(model_names)} available models, will try: {clean_names[:3]}")
        except Exception as e:
            logger.warning(f"Could not list available models: {e}, using fallback list")

        # Fall back to common model names if listing failed
        if not models_to_try:
            models_to_try = ["gemini-pro", "gemini-1.0-pro", "models/gemini-pro"]

        # Add the configured model if different
        if self.model and self.model not in models_to_try:
            models_to_try.insert(0, self.model)

        # Remove duplicates while preserving order
        seen = set()
        models_to_try = [m for m in models_to_try if not (m in seen or seen.add(m))]

        last_error = None
        for model_name in models_to_try:
            try:
                logger.info(f"Attempting to generate with model: {model_name}")
                # Create a new client for this model
                client = genai.GenerativeModel(model_name)
                response = client.generate_content(
                    full_prompt,
                    generation_config=genai.types.GenerationConfig(
                        temperature=settings.TEMPERATURE,
                        max_output_tokens=1024,
                    )
                )

                # Extract the response text
                response_text = response.text

                # Estimated usage info; overwritten below with actual values when available
                usage_info = {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": len(response_text) // 4,  # Estimate
                    "total_tokens": prompt_tokens + (len(response_text) // 4),
                    "model_used": model_name.split('/')[-1] if '/' in model_name else model_name
                }

                # Use the actual usage from the response if available
                if hasattr(response, 'usage_metadata'):
                    usage_metadata = response.usage_metadata
                    if hasattr(usage_metadata, 'prompt_token_count'):
                        usage_info["prompt_tokens"] = usage_metadata.prompt_token_count
                    if hasattr(usage_metadata, 'candidates_token_count'):
                        usage_info["completion_tokens"] = usage_metadata.candidates_token_count
                    if hasattr(usage_metadata, 'total_token_count'):
                        usage_info["total_tokens"] = usage_metadata.total_token_count

                if model_name != self.model:
                    logger.info(f"✅ Successfully used fallback model: {model_name}")

                return response_text, usage_info
            except Exception as e:
                error_str = str(e).lower()
                last_error = e
                if "not found" in error_str or "not supported" in error_str or "404" in error_str:
                    logger.warning(f"Model {model_name} failed: {e}")
                    continue  # Try the next model
                else:
                    # Different error (not "model not found"), re-raise
                    logger.error(f"Gemini generation error with {model_name}: {e}")
                    raise

        # All models failed - raise a helpful error
        error_msg = (
            f"All Gemini model attempts failed. Last error: {last_error}. "
            "Please check your GEMINI_API_KEY and ensure it has access to Gemini models."
        )
        logger.error(error_msg)
        raise Exception(error_msg)
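
# Minimal usage sketch for GeminiProvider (assumes GEMINI_API_KEY is configured;
# the prompts below are illustrative only):
#
#     provider = GeminiProvider()
#     text, usage = provider.generate_with_usage(
#         "You are a support assistant.",
#         "How do I reset my password?",
#     )
#     print(text, usage["total_tokens"])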

class OpenAIProvider(LLMProvider):
    """OpenAI LLM provider."""

    def __init__(self, api_key: Optional[str] = None, model: str = settings.OPENAI_MODEL):
        self.api_key = api_key or settings.OPENAI_API_KEY or os.getenv("OPENAI_API_KEY")
        self.model = model

        if not self.api_key:
            raise ValueError("OpenAI API key not configured. Set OPENAI_API_KEY environment variable.")

        self.client = OpenAI(api_key=self.api_key)
        logger.info(f"OpenAI provider initialized with model: {model}")

    def generate(self, system_prompt: str, user_prompt: str) -> str:
        """Generate response using OpenAI."""
        text, _ = self.generate_with_usage(system_prompt, user_prompt)
        return text
    def generate_with_usage(self, system_prompt: str, user_prompt: str) -> tuple[str, dict]:
        """Generate response using OpenAI and return usage info."""
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                temperature=settings.TEMPERATURE,
                max_tokens=1024
            )
            response_text = response.choices[0].message.content

            # Extract usage info from the OpenAI response
            usage_info = {
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
                "model_used": self.model
            }
            return response_text, usage_info
        except Exception as e:
            logger.error(f"OpenAI generation error: {e}")
            raise
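
# Minimal usage sketch for OpenAIProvider (assumes OPENAI_API_KEY and
# settings.OPENAI_MODEL are configured; the prompt is illustrative only):
#
#     provider = OpenAIProvider()
#     answer = provider.generate("You are a support assistant.", "What plans do you offer?")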

class AnswerService:
    """
    Generates answers using RAG context and LLM.
    Handles confidence scoring and citation extraction.
    """

    # Confidence thresholds
    HIGH_CONFIDENCE_THRESHOLD = 0.5
    LOW_CONFIDENCE_THRESHOLD = 0.20  # Lowered to match the retrieval similarity threshold
    STRICT_CONFIDENCE_THRESHOLD = 0.30  # Strict threshold for answer generation (lowered from 0.45 to allow good retrieval results)
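    # How the thresholds are used in generate_answer():
    #   - below STRICT_CONFIDENCE_THRESHOLD the service refuses to answer (Gate 2)
    #   - below HIGH_CONFIDENCE_THRESHOLD an answer is still generated, but escalation is suggested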

    def __init__(self, provider: str = settings.LLM_PROVIDER):
        """
        Initialize the answer service.

        Args:
            provider: LLM provider to use ("gemini" or "openai")
        """
        self.provider_name = provider
        self._provider: Optional[LLMProvider] = None

    @property
    def provider(self) -> LLMProvider:
        """Lazily load the LLM provider."""
        if self._provider is None:
            if self.provider_name == "gemini":
                self._provider = GeminiProvider()
            elif self.provider_name == "openai":
                self._provider = OpenAIProvider()
            else:
                raise ValueError(f"Unknown LLM provider: {self.provider_name}")
        return self._provider
    def generate_answer(
        self,
        question: str,
        context: str,
        citations_info: List[Dict[str, Any]],
        confidence: float,
        has_relevant_results: bool,
        use_verifier: Optional[bool] = None  # None = use config default
    ) -> Dict[str, Any]:
        """
        Generate an answer based on retrieved context, with a mandatory verifier pass.

        Args:
            question: User's question
            context: Retrieved context from the knowledge base
            citations_info: List of citation information
            confidence: Average confidence score from retrieval
            has_relevant_results: Whether any results passed the threshold
            use_verifier: Whether to use verifier mode (None = use config default)

        Returns:
            Dictionary with answer, citations, confidence, and metadata
        """
        # Determine whether the verifier should be used (mandatory by default)
        if use_verifier is None:
            use_verifier = settings.REQUIRE_VERIFIER

        # GATE 1: No relevant results found - REFUSE
        if not has_relevant_results or not context:
            logger.info("No relevant context found, returning no-context response")
            return {
                "answer": get_no_context_response(),
                "citations": [],
                "confidence": 0.0,
                "from_knowledge_base": False,
                "escalation_suggested": True,
                "refused": True
            }

        # GATE 2: Strict confidence threshold - REFUSE if below it
        if confidence < self.STRICT_CONFIDENCE_THRESHOLD:
            logger.warning(
                f"Confidence ({confidence:.3f}) below strict threshold ({self.STRICT_CONFIDENCE_THRESHOLD}), "
                f"REFUSING to answer to prevent hallucination"
            )
            return {
                "answer": get_no_context_response(),
                "citations": [],
                "confidence": confidence,
                "from_knowledge_base": False,
                "escalation_suggested": True,
                "refused": True,
                "refusal_reason": f"Confidence {confidence:.3f} below strict threshold {self.STRICT_CONFIDENCE_THRESHOLD}"
            }

        # GATE 3: Intent-based gating for specific intents (integration, API, etc.)
        intents = detect_intents(question)
        if "integration" in intents or "api" in question.lower():
            # Integration/API questions require stronger relevance
            if confidence < 0.50:  # Even stricter threshold for integration questions
                logger.warning(
                    f"Integration/API question with low confidence ({confidence:.3f}), "
                    f"REFUSING to prevent hallucination"
                )
                return {
                    "answer": get_no_context_response(),
                    "citations": [],
                    "confidence": confidence,
                    "from_knowledge_base": False,
                    "escalation_suggested": True,
                    "refused": True,
                    "refusal_reason": "Integration/API questions require higher confidence"
                }
        # All gates passed - generate an answer with the MANDATORY verifier
        logger.info(f"Confidence ({confidence:.3f}) passed all gates, generating answer with verifier={use_verifier}")

        try:
            # VERIFIER MODE IS MANDATORY: Draft → Verify → Final
            # Step 1: Generate a draft answer with usage tracking
            draft_system, draft_user = format_draft_prompt(context, question)
            draft_answer, usage_info = self.provider.generate_with_usage(draft_system, draft_user)
            logger.info("Generated draft answer, running verifier...")

            # Step 2: Verify the draft answer (MANDATORY)
            verifier = get_verifier_service()
            verification = verifier.verify_answer(
                draft_answer=draft_answer,
                context=context,
                citations_info=citations_info
            )

            # Step 3: Handle the verification result
            if verification["pass"]:
                logger.info("✅ Verifier PASSED - Using draft answer")
                citations = self._extract_citations(draft_answer, citations_info)
                return {
                    "answer": draft_answer,
                    "citations": citations,
                    "confidence": confidence,
                    "from_knowledge_base": True,
                    "escalation_suggested": confidence < self.HIGH_CONFIDENCE_THRESHOLD,
                    "verifier_passed": True,
                    "refused": False,
                    "usage": usage_info  # Include usage info for tracking
                }
            else:
                # Verifier failed - REFUSE to answer
                issues = verification.get('issues', [])
                unsupported = verification.get('unsupported_claims', [])
                logger.warning(
                    f"❌ Verifier FAILED - Issues: {issues}, "
                    f"Unsupported claims: {unsupported}"
                )
                refusal_message = (
                    get_no_context_response() +
                    "\n\n**Note:** The system could not verify the accuracy of the information needed to answer your question. "
                    "This helps prevent providing incorrect information."
                )
                return {
                    "answer": refusal_message,
                    "citations": [],
                    "confidence": 0.0,
                    "from_knowledge_base": False,
                    "escalation_suggested": True,
                    "verifier_passed": False,
                    "verifier_issues": issues,
                    "unsupported_claims": unsupported,
                    "refused": True,
                    "refusal_reason": "Verifier failed: claims not supported by context",
                    "usage": usage_info  # Still track usage even if refused
                }
        except ValueError as e:
            # Configuration errors (e.g., missing API key)
            error_msg = str(e)
            logger.error(f"Configuration error in answer generation: {error_msg}")
            if "api key" in error_msg.lower():
                raise ValueError(f"LLM API key not configured: {error_msg}")
            raise
        except Exception as e:
            logger.error(f"Error generating answer: {e}", exc_info=True)
            # Re-raise to be handled by the endpoint
            raise
    def _extract_citations(
        self,
        answer: str,
        citations_info: List[Dict[str, Any]]
    ) -> List[Citation]:
        """
        Extract and format citations from the answer.

        Args:
            answer: Generated answer with [Source X] references
            citations_info: Available citation information

        Returns:
            List of Citation objects
        """
        citations = []

        # Find all [Source X] references in the answer
        source_pattern = r'\[Source\s*(\d+)\]'
        matches = re.findall(source_pattern, answer)
        referenced_indices = set(int(m) for m in matches)

        # Build citation objects for the referenced sources
        for info in citations_info:
            if info.get("index") in referenced_indices:
                citations.append(Citation(
                    file_name=info.get("file_name", "Unknown"),
                    chunk_id=info.get("chunk_id", ""),
                    page_number=info.get("page_number"),
                    relevance_score=info.get("similarity_score", 0.0),
                    excerpt=info.get("excerpt", "")
                ))

        # If no specific citations were found but we have context, include the top sources
        if not citations and citations_info:
            for info in citations_info[:3]:  # Top 3 sources
                citations.append(Citation(
                    file_name=info.get("file_name", "Unknown"),
                    chunk_id=info.get("chunk_id", ""),
                    page_number=info.get("page_number"),
                    relevance_score=info.get("similarity_score", 0.0),
                    excerpt=info.get("excerpt", "")
                ))

        return citations
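
# Illustrative note: for a draft answer such as "Refunds are processed within
# 5 days [Source 1].", _extract_citations matches index 1 and returns Citation
# objects only for entries in citations_info whose "index" field is 1; if the
# answer contains no [Source X] markers, the top 3 sources are cited instead.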

# Global answer service instance
_answer_service: Optional[AnswerService] = None


def get_answer_service() -> AnswerService:
    """Get the global answer service instance."""
    global _answer_service
    if _answer_service is None:
        _answer_service = AnswerService()
    return _answer_service
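
# End-to-end usage sketch (assumes retrieval has already produced `context`,
# `citations_info`, and `confidence`; the variable names are illustrative):
#
#     service = get_answer_service()
#     result = service.generate_answer(
#         question="How do I cancel my subscription?",
#         context=context,
#         citations_info=citations_info,
#         confidence=confidence,
#         has_relevant_results=bool(citations_info),
#     )
#     if result["refused"]:
#         ...  # e.g. escalate to a human agent
#     else:
#         print(result["answer"], result["citations"])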