"""Grounded answer generation for Warbler-CDA retrieval results.""" from __future__ import annotations import os import re from dataclasses import dataclass from typing import Any, Dict, List, Optional @dataclass class GeneratedAnswer: """A grounded answer plus provenance for display and API responses.""" answer: str citations: List[Dict[str, Any]] provider: str used_fallback: bool = False class AnswerGenerator: """Generate grounded answers from retrieval results with lazy provider loading.""" def __init__(self, config: Optional[Dict[str, Any]] = None): self.config = config or {} self.model_name = self.config.get("model_name", os.getenv("WARBLER_GENERATION_MODEL", "google/flan-t5-small")) self.max_context_chars = int(self.config.get("max_context_chars", os.getenv("WARBLER_GENERATION_MAX_CONTEXT_CHARS", "4000"))) self.max_new_tokens = int(self.config.get("max_new_tokens", os.getenv("WARBLER_GENERATION_MAX_NEW_TOKENS", "220"))) self.temperature = float(self.config.get("temperature", os.getenv("WARBLER_GENERATION_TEMPERATURE", "0.2"))) self.enabled = str(self.config.get("enabled", os.getenv("WARBLER_ENABLE_GENERATION", "1"))).lower() not in {"0", "false", "no", "off"} self._local_pipeline = None self._openai_client = None def generate_answer(self, query_text: str, results: List[Dict[str, Any]]) -> GeneratedAnswer: """Generate a grounded answer from retrieved results.""" citations = self._build_citations(results) if not results: return GeneratedAnswer( answer="I could not find enough relevant context to generate a grounded answer.", citations=[], provider="none", used_fallback=True, ) if not self.enabled: return self._extractive_fallback(query_text, results, citations, provider="disabled") prompt = self._build_prompt(query_text, results) openai_answer = self._try_openai_generation(prompt) if openai_answer: return GeneratedAnswer(answer=openai_answer, citations=citations, provider="openai") local_answer = self._try_local_generation(prompt) if local_answer: return GeneratedAnswer(answer=local_answer, citations=citations, provider=self.model_name) return self._extractive_fallback(query_text, results, citations, provider="extractive") def _try_openai_generation(self, prompt: str) -> Optional[str]: api_key = os.getenv("OPENAI_API_KEY") if not api_key: return None try: if self._openai_client is None: from openai import OpenAI self._openai_client = OpenAI(api_key=api_key) response = self._openai_client.chat.completions.create( model=os.getenv("WARBLER_OPENAI_CHAT_MODEL", "gpt-4o-mini"), temperature=self.temperature, max_tokens=self.max_new_tokens, messages=[ { "role": "system", "content": ( "You are a grounded retrieval-augmented assistant. " "Answer only from the provided context. If the context is insufficient, say so clearly." ), }, {"role": "user", "content": prompt}, ], ) content = response.choices[0].message.content if response.choices else None return content.strip() if content else None except Exception: return None def _try_local_generation(self, prompt: str) -> Optional[str]: try: if self._local_pipeline is None: from transformers import pipeline self._local_pipeline = pipeline( "text2text-generation", model=self.model_name, tokenizer=self.model_name, ) outputs = self._local_pipeline( prompt, max_new_tokens=self.max_new_tokens, do_sample=self.temperature > 0, temperature=self.temperature, ) if not outputs: return None text = outputs[0].get("generated_text") or outputs[0].get("summary_text") return text.strip() if text else None except Exception: return None def _build_prompt(self, query_text: str, results: List[Dict[str, Any]]) -> str: context_chunks: List[str] = [] current_length = 0 for index, result in enumerate(results, start=1): snippet = (result.get("content") or "").strip() if not snippet: continue pack_name = result.get("metadata", {}).get("pack", "unknown-pack") chunk = f"[{index}] Source: {pack_name}\nSnippet: {snippet}" if current_length + len(chunk) > self.max_context_chars: break context_chunks.append(chunk) current_length += len(chunk) context_text = "\n\n".join(context_chunks) return ( "Answer the user question using only the retrieved context below. " "Be concise but useful. Quote uncertainty when context is incomplete. " "At the end, include a short 'Sources:' line with bracketed source numbers you relied on.\n\n" f"Question: {query_text}\n\n" f"Retrieved Context:\n{context_text}\n" ) def _build_citations(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: citations = [] for index, result in enumerate(results, start=1): citations.append( { "index": index, "id": result.get("id") or result.get("content_id"), "pack": result.get("metadata", {}).get("pack"), "score": result.get("relevance_score"), } ) return citations def _extractive_fallback( self, query_text: str, results: List[Dict[str, Any]], citations: List[Dict[str, Any]], provider: str, ) -> GeneratedAnswer: snippets = [] for citation, result in zip(citations[:3], results[:3]): content = (result.get("content") or "").strip() if content: snippet = self._extract_relevant_snippet(query_text, content) snippets.append(f"[{citation['index']}] {snippet}") if snippets: answer = ( f"Based on the retrieved context for '{query_text}', the strongest evidence is: " + " ".join(snippets) ) else: answer = "I could not extract enough grounded context to form an answer." return GeneratedAnswer( answer=answer, citations=citations, provider=provider, used_fallback=True, ) def _extract_relevant_snippet(self, query_text: str, content: str) -> str: query_tokens = { token for token in re.findall(r"[a-zA-Z0-9']+", query_text.lower()) if len(token) > 2 } if not query_tokens: return content[:240] candidate_sentences = [segment.strip() for segment in re.split(r"(?<=[.!?])\s+|\n+", content) if segment.strip()] if not candidate_sentences: return content[:240] best_sentence = None best_score = -1 for sentence in candidate_sentences: lowered = sentence.lower() overlap = sum(1 for token in query_tokens if token in lowered) if overlap > best_score: best_score = overlap best_sentence = sentence if not best_sentence or best_score <= 0: return content[:240] return best_sentence[:240]