File size: 18,832 Bytes
c19c7bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
"""
Answer generation using LLM with RAG context.
Supports Gemini and OpenAI as providers.
"""
import google.generativeai as genai
from openai import OpenAI
from typing import Optional, Dict, Any, List
import logging
import os
import re

from app.config import settings
from app.rag.prompts import (
    format_rag_prompt,
    format_draft_prompt,
    get_no_context_response,
    get_low_confidence_response
)
from app.rag.verifier import get_verifier_service
from app.rag.intent import detect_intents
from app.models.schemas import Citation
from abc import ABC, abstractmethod

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class LLMProvider(ABC):
    """Base class for LLM providers."""
    
    @abstractmethod
    def generate(self, system_prompt: str, user_prompt: str) -> str:
        """Generate response from prompts."""
        raise NotImplementedError
    
    @abstractmethod
    def generate_with_usage(self, system_prompt: str, user_prompt: str) -> tuple[str, dict]:
        """
        Generate response and return usage information.
        
        Returns:
            (response_text, usage_info)
            usage_info: dict with keys: prompt_tokens, completion_tokens, total_tokens, model_used
        """
        raise NotImplementedError


class GeminiProvider(LLMProvider):
    """Google Gemini LLM provider."""
    
    def __init__(self, api_key: Optional[str] = None, model: str = None):
        self.api_key = api_key or settings.GEMINI_API_KEY or os.getenv("GEMINI_API_KEY")
        # Default to gemini-1.5-flash if not specified
        self.model = model or settings.GEMINI_MODEL or "gemini-1.5-flash"
        
        if not self.api_key:
            raise ValueError("Gemini API key not configured. Set GEMINI_API_KEY environment variable.")
        
        genai.configure(api_key=self.api_key)
        
        # Don't initialize client here - do it lazily in generate() to handle errors better
        self._client = None
        logger.info(f"Gemini provider initialized (will use model: {self.model})")
    
    def generate(self, system_prompt: str, user_prompt: str) -> str:
        """Generate response using Gemini."""
        text, _ = self.generate_with_usage(system_prompt, user_prompt)
        return text
    
    def generate_with_usage(self, system_prompt: str, user_prompt: str) -> tuple[str, dict]:
        """Generate response using Gemini and return usage info."""
        # Combine system and user prompts for Gemini
        full_prompt = f"{system_prompt}\n\n{user_prompt}"
        
        # Estimate prompt tokens (rough: 1 token ≈ 4 chars)
        prompt_tokens = len(full_prompt) // 4
        
        # Try to list available models first, then use the first available one
        # If that fails, try common model names
        models_to_try = []
        
        # First, try to get available models
        try:
            available_models = genai.list_models()
            model_names = [m.name for m in available_models if 'generateContent' in m.supported_generation_methods]
            if model_names:
                # Extract just the model name (remove 'models/' prefix if present)
                clean_names = [name.split('/')[-1] if '/' in name else name for name in model_names]
                models_to_try.extend(clean_names[:3])  # Use first 3 available models
                logger.info(f"Found {len(model_names)} available models, will try: {clean_names[:3]}")
        except Exception as e:
            logger.warning(f"Could not list available models: {e}, using fallback list")
        
        # Fallback to common model names if listing failed
        if not models_to_try:
            models_to_try = ["gemini-pro", "gemini-1.0-pro", "models/gemini-pro"]
        
        # Add configured model if different
        if self.model and self.model not in models_to_try:
            models_to_try.insert(0, self.model)
        
        # Remove duplicates while preserving order
        seen = set()
        models_to_try = [m for m in models_to_try if not (m in seen or seen.add(m))]
        
        last_error = None
        for model_name in models_to_try:
            try:
                logger.info(f"Attempting to generate with model: {model_name}")
                # Create a new client for this model
                client = genai.GenerativeModel(model_name)
                response = client.generate_content(
                    full_prompt,
                    generation_config=genai.types.GenerationConfig(
                        temperature=settings.TEMPERATURE,
                        max_output_tokens=1024,
                    )
                )
                
                # Extract response text
                response_text = response.text
                
                # Try to get usage info from response
                usage_info = {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": len(response_text) // 4,  # Estimate
                    "total_tokens": prompt_tokens + (len(response_text) // 4),
                    "model_used": model_name.split('/')[-1] if '/' in model_name else model_name
                }
                
                # Try to get actual usage from response if available
                if hasattr(response, 'usage_metadata'):
                    usage_metadata = response.usage_metadata
                    if hasattr(usage_metadata, 'prompt_token_count'):
                        usage_info["prompt_tokens"] = usage_metadata.prompt_token_count
                    if hasattr(usage_metadata, 'candidates_token_count'):
                        usage_info["completion_tokens"] = usage_metadata.candidates_token_count
                    if hasattr(usage_metadata, 'total_token_count'):
                        usage_info["total_tokens"] = usage_metadata.total_token_count
                
                if model_name != self.model:
                    logger.info(f"✅ Successfully used model: {model_name}")
                
                return response_text, usage_info
            except Exception as e:
                error_str = str(e).lower()
                last_error = e
                if "not found" in error_str or "not supported" in error_str or "404" in error_str:
                    logger.warning(f"Model {model_name} failed: {e}")
                    continue  # Try next model
                else:
                    # Different error (not model not found), re-raise
                    logger.error(f"Gemini generation error with {model_name}: {e}")
                    raise
        
        # All models failed - return a helpful error message
        error_msg = f"All Gemini model attempts failed. Last error: {last_error}. Please check your GEMINI_API_KEY and ensure it has access to Gemini models."
        logger.error(error_msg)
        raise Exception(error_msg)


class OpenAIProvider(LLMProvider):
    """OpenAI LLM provider."""
    
    def __init__(self, api_key: Optional[str] = None, model: str = settings.OPENAI_MODEL):
        self.api_key = api_key or settings.OPENAI_API_KEY or os.getenv("OPENAI_API_KEY")
        self.model = model
        
        if not self.api_key:
            raise ValueError("OpenAI API key not configured. Set OPENAI_API_KEY environment variable.")
        
        self.client = OpenAI(api_key=self.api_key)
        logger.info(f"OpenAI provider initialized with model: {model}")
    
    def generate(self, system_prompt: str, user_prompt: str) -> str:
        """Generate response using OpenAI."""
        text, _ = self.generate_with_usage(system_prompt, user_prompt)
        return text
    
    def generate_with_usage(self, system_prompt: str, user_prompt: str) -> tuple[str, dict]:
        """Generate response using OpenAI and return usage info."""
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                temperature=settings.TEMPERATURE,
                max_tokens=1024
            )
            
            response_text = response.choices[0].message.content
            
            # Extract usage info from OpenAI response
            usage_info = {
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
                "model_used": self.model
            }
            
            return response_text, usage_info
        except Exception as e:
            logger.error(f"OpenAI generation error: {e}")
            raise


class AnswerService:
    """
    Generates answers using RAG context and LLM.
    Handles confidence scoring and citation extraction.
    """
    
    # Confidence thresholds
    HIGH_CONFIDENCE_THRESHOLD = 0.5
    LOW_CONFIDENCE_THRESHOLD = 0.20  # Lowered to match similarity threshold
    STRICT_CONFIDENCE_THRESHOLD = 0.30  # Strict threshold for answer generation (lowered from 0.45 to allow good retrieval results)
    
    def __init__(self, provider: str = settings.LLM_PROVIDER):
        """
        Initialize the answer service.
        
        Args:
            provider: LLM provider to use ("gemini" or "openai")
        """
        self.provider_name = provider
        self._provider: Optional[LLMProvider] = None
    
    @property
    def provider(self) -> LLMProvider:
        """Lazy load the LLM provider."""
        if self._provider is None:
            if self.provider_name == "gemini":
                self._provider = GeminiProvider()
            elif self.provider_name == "openai":
                self._provider = OpenAIProvider()
            else:
                raise ValueError(f"Unknown LLM provider: {self.provider_name}")
        return self._provider
    
    def generate_answer(
        self,
        question: str,
        context: str,
        citations_info: List[Dict[str, Any]],
        confidence: float,
        has_relevant_results: bool,
        use_verifier: bool = None  # None = use config default
    ) -> Dict[str, Any]:
        """
        Generate an answer based on retrieved context with mandatory verifier.
        
        Args:
            question: User's question
            context: Retrieved context from knowledge base
            citations_info: List of citation information
            confidence: Average confidence score from retrieval
            has_relevant_results: Whether any results passed the threshold
            use_verifier: Whether to use verifier mode (None = use config default)
            
        Returns:
            Dictionary with answer, citations, confidence, and metadata
        """
        # Determine if verifier should be used (mandatory by default)
        if use_verifier is None:
            use_verifier = settings.REQUIRE_VERIFIER
        
        # GATE 1: No relevant results found - REFUSE
        if not has_relevant_results or not context:
            logger.info("No relevant context found, returning no-context response")
            return {
                "answer": get_no_context_response(),
                "citations": [],
                "confidence": 0.0,
                "from_knowledge_base": False,
                "escalation_suggested": True,
                "refused": True
            }
        
        # GATE 2: Strict confidence threshold - REFUSE if below strict threshold
        if confidence < self.STRICT_CONFIDENCE_THRESHOLD:
            logger.warning(
                f"Confidence ({confidence:.3f}) below strict threshold ({self.STRICT_CONFIDENCE_THRESHOLD}), "
                f"REFUSING to answer to prevent hallucination"
            )
            return {
                "answer": get_no_context_response(),
                "citations": [],
                "confidence": confidence,
                "from_knowledge_base": False,
                "escalation_suggested": True,
                "refused": True,
                "refusal_reason": f"Confidence {confidence:.3f} below strict threshold {self.STRICT_CONFIDENCE_THRESHOLD}"
            }
        
        # GATE 3: Intent-based gating for specific intents (integration, API, etc.)
        intents = detect_intents(question)
        if "integration" in intents or "api" in question.lower():
            # For integration/API questions, require strong relevance
            if confidence < 0.50:  # Even stricter for integration questions
                logger.warning(
                    f"Integration/API question with low confidence ({confidence:.3f}), "
                    f"REFUSING to prevent hallucination"
                )
                return {
                    "answer": get_no_context_response(),
                    "citations": [],
                    "confidence": confidence,
                    "from_knowledge_base": False,
                    "escalation_suggested": True,
                    "refused": True,
                    "refusal_reason": "Integration/API questions require higher confidence"
                }
        
        # Case 3: Passed all gates - generate answer with MANDATORY verifier
        logger.info(f"Confidence ({confidence:.3f}) passed all gates, generating answer with verifier={use_verifier}")
        
        try:
            # VERIFIER MODE IS MANDATORY: Draft → Verify → Final
            # Step 1: Generate draft answer with usage tracking
            draft_system, draft_user = format_draft_prompt(context, question)
            draft_answer, usage_info = self.provider.generate_with_usage(draft_system, draft_user)
            logger.info("Generated draft answer, running verifier...")
            
            # Step 2: Verify draft answer (MANDATORY)
            verifier = get_verifier_service()
            verification = verifier.verify_answer(
                draft_answer=draft_answer,
                context=context,
                citations_info=citations_info
            )
            
            # Step 3: Handle verification result
            if verification["pass"]:
                logger.info("✅ Verifier PASSED - Using draft answer")
                citations = self._extract_citations(draft_answer, citations_info)
                return {
                    "answer": draft_answer,
                    "citations": citations,
                    "confidence": confidence,
                    "from_knowledge_base": True,
                    "escalation_suggested": confidence < self.HIGH_CONFIDENCE_THRESHOLD,
                    "verifier_passed": True,
                    "refused": False,
                    "usage": usage_info  # Include usage info for tracking
                }
            else:
                # Verifier failed - REFUSE to answer
                issues = verification.get('issues', [])
                unsupported = verification.get('unsupported_claims', [])
                logger.warning(
                    f"❌ Verifier FAILED - Issues: {issues}, "
                    f"Unsupported claims: {unsupported}"
                )
                refusal_message = (
                    get_no_context_response() + 
                    "\n\n**Note:** The system could not verify the accuracy of the information needed to answer your question. "
                    "This helps prevent providing incorrect information."
                )
                return {
                    "answer": refusal_message,
                    "citations": [],
                    "confidence": 0.0,
                    "from_knowledge_base": False,
                    "escalation_suggested": True,
                    "verifier_passed": False,
                    "verifier_issues": issues,
                    "unsupported_claims": unsupported,
                    "refused": True,
                    "refusal_reason": "Verifier failed: claims not supported by context",
                    "usage": usage_info  # Still track usage even if refused
                }
                
        except ValueError as e:
            # Configuration errors (e.g., missing API key)
            error_msg = str(e)
            logger.error(f"Configuration error in answer generation: {error_msg}")
            if "API key" in error_msg.lower():
                raise ValueError(f"LLM API key not configured: {error_msg}")
            raise
        except Exception as e:
            logger.error(f"Error generating answer: {e}", exc_info=True)
            # Re-raise to be handled by the endpoint
            raise
    
    def _extract_citations(
        self,
        answer: str,
        citations_info: List[Dict[str, Any]]
    ) -> List[Citation]:
        """
        Extract and format citations from the answer.
        
        Args:
            answer: Generated answer with [Source X] references
            citations_info: Available citation information
            
        Returns:
            List of Citation objects
        """
        citations = []
        
        # Find all [Source X] references in the answer
        source_pattern = r'\[Source\s*(\d+)\]'
        matches = re.findall(source_pattern, answer)
        referenced_indices = set(int(m) for m in matches)
        
        # Build citation objects for referenced sources
        for info in citations_info:
            if info.get("index") in referenced_indices:
                citations.append(Citation(
                    file_name=info.get("file_name", "Unknown"),
                    chunk_id=info.get("chunk_id", ""),
                    page_number=info.get("page_number"),
                    relevance_score=info.get("similarity_score", 0.0),
                    excerpt=info.get("excerpt", "")
                ))
        
        # If no specific citations found but we have context, include top sources
        if not citations and citations_info:
            for info in citations_info[:3]:  # Top 3 sources
                citations.append(Citation(
                    file_name=info.get("file_name", "Unknown"),
                    chunk_id=info.get("chunk_id", ""),
                    page_number=info.get("page_number"),
                    relevance_score=info.get("similarity_score", 0.0),
                    excerpt=info.get("excerpt", "")
                ))
        
        return citations


# Global answer service instance
_answer_service: Optional[AnswerService] = None


def get_answer_service() -> AnswerService:
    """Get the global answer service instance."""
    global _answer_service
    if _answer_service is None:
        _answer_service = AnswerService()
    return _answer_service