warbler-cda / warbler_cda /answer_generator.py
Bellok
feat: Enhance answer generation and synthesis fragment processing with improved filtering and snippet extraction
0ee3d53
"""Grounded answer generation for Warbler-CDA retrieval results."""
from __future__ import annotations

import logging
import os
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
@dataclass
class GeneratedAnswer:
    """Answer text bundled with the provenance needed to display or serialize it."""

    # The answer text shown to the user.
    answer: str
    # One dict per retrieved result describing where the evidence came from.
    citations: List[Dict[str, Any]]
    # Label of whatever produced the answer (e.g. "openai", a model name, "extractive").
    provider: str
    # True when the answer did not come from a generative model.
    used_fallback: bool = False
class AnswerGenerator:
    """Generate grounded answers from retrieval results with lazy provider loading.

    Providers are tried in order: the OpenAI chat-completions API (only when
    ``OPENAI_API_KEY`` is set), a local ``transformers`` text2text pipeline,
    and finally an extractive fallback assembled from the top results. The
    OpenAI client and the local pipeline are created on first use.
    """

    _logger = logging.getLogger(__name__)

    # Values of the "enabled" setting (case-insensitive) that turn generation off.
    _DISABLED_FLAGS = frozenset({"0", "false", "no", "off"})
    # Maximum characters returned by ``_extract_relevant_snippet``.
    _SNIPPET_CHARS = 240
    # Number of top results quoted by the extractive fallback.
    _FALLBACK_RESULTS = 3
    _TOKEN_RE = re.compile(r"[a-zA-Z0-9']+")
    # Split after sentence-ending punctuation followed by whitespace, or on newlines.
    _SENTENCE_SPLIT_RE = re.compile(r"(?<=[.!?])\s+|\n+")

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Read settings from ``config``, falling back to environment variables.

        Args:
            config: Optional mapping with any of the keys ``model_name``,
                ``max_context_chars``, ``max_new_tokens``, ``temperature`` and
                ``enabled``. A key present in ``config`` takes precedence over
                the corresponding ``WARBLER_*`` environment variable.
        """
        self.config = config or {}

        def setting(key: str, env_var: str, default: str) -> Any:
            return self.config.get(key, os.getenv(env_var, default))

        self.model_name = setting("model_name", "WARBLER_GENERATION_MODEL", "google/flan-t5-small")
        self.max_context_chars = int(
            setting("max_context_chars", "WARBLER_GENERATION_MAX_CONTEXT_CHARS", "4000")
        )
        self.max_new_tokens = int(
            setting("max_new_tokens", "WARBLER_GENERATION_MAX_NEW_TOKENS", "220")
        )
        self.temperature = float(setting("temperature", "WARBLER_GENERATION_TEMPERATURE", "0.2"))
        enabled_flag = str(setting("enabled", "WARBLER_ENABLE_GENERATION", "1")).lower()
        self.enabled = enabled_flag not in self._DISABLED_FLAGS
        self._local_pipeline = None
        self._openai_client = None

    def generate_answer(self, query_text: str, results: List[Dict[str, Any]]) -> GeneratedAnswer:
        """Generate a grounded answer from retrieved results.

        Args:
            query_text: The user's question.
            results: Retrieved result dicts; ``content``, ``metadata``, ``id`` /
                ``content_id`` and ``relevance_score`` are read when present.

        Returns:
            A ``GeneratedAnswer``. ``used_fallback`` is True when no results were
            given, generation is disabled, or every generative provider failed.
        """
        # Guard first so an empty (or None) result list never reaches the helpers.
        if not results:
            return GeneratedAnswer(
                answer="I could not find enough relevant context to generate a grounded answer.",
                citations=[],
                provider="none",
                used_fallback=True,
            )

        citations = self._build_citations(results)
        if not self.enabled:
            return self._extractive_fallback(query_text, results, citations, provider="disabled")

        prompt = self._build_prompt(query_text, results)

        openai_answer = self._try_openai_generation(prompt)
        if openai_answer:
            return GeneratedAnswer(answer=openai_answer, citations=citations, provider="openai")

        local_answer = self._try_local_generation(prompt)
        if local_answer:
            return GeneratedAnswer(answer=local_answer, citations=citations, provider=self.model_name)

        return self._extractive_fallback(query_text, results, citations, provider="extractive")

    def _try_openai_generation(self, prompt: str) -> Optional[str]:
        """Return an OpenAI chat completion for ``prompt``, or None if unavailable.

        Returns None when ``OPENAI_API_KEY`` is unset, the ``openai`` package is
        missing, the request fails, or the response carries no text. Failures are
        logged rather than raised so the caller can try the next provider.
        """
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            return None
        try:
            if self._openai_client is None:
                from openai import OpenAI

                self._openai_client = OpenAI(api_key=api_key)
            response = self._openai_client.chat.completions.create(
                model=os.getenv("WARBLER_OPENAI_CHAT_MODEL", "gpt-4o-mini"),
                temperature=self.temperature,
                max_tokens=self.max_new_tokens,
                messages=[
                    {
                        "role": "system",
                        "content": (
                            "You are a grounded retrieval-augmented assistant. "
                            "Answer only from the provided context. If the context is insufficient, say so clearly."
                        ),
                    },
                    {"role": "user", "content": prompt},
                ],
            )
            content = response.choices[0].message.content if response.choices else None
            return content.strip() if content else None
        except ImportError:
            self._logger.debug("openai package is not installed; skipping OpenAI generation")
            return None
        except Exception:
            # Best-effort provider: record the failure and let the caller fall back.
            self._logger.warning("OpenAI generation failed; falling back", exc_info=True)
            return None

    def _try_local_generation(self, prompt: str) -> Optional[str]:
        """Return text from the local ``transformers`` pipeline, or None on failure.

        The pipeline for ``self.model_name`` is created on first call. Failures
        (missing package, model load error, generation error) are logged and
        turned into None so the caller can use the extractive fallback.
        """
        try:
            if self._local_pipeline is None:
                from transformers import pipeline

                self._local_pipeline = pipeline(
                    "text2text-generation",
                    model=self.model_name,
                    tokenizer=self.model_name,
                )
            sampling = self.temperature > 0
            generation_kwargs: Dict[str, Any] = {
                "max_new_tokens": self.max_new_tokens,
                "do_sample": sampling,
            }
            # Temperature only applies when sampling; omit it for greedy decoding.
            if sampling:
                generation_kwargs["temperature"] = self.temperature
            outputs = self._local_pipeline(prompt, **generation_kwargs)
            if not outputs:
                return None
            text = outputs[0].get("generated_text") or outputs[0].get("summary_text")
            return text.strip() if text else None
        except ImportError:
            self._logger.debug("transformers package is not installed; skipping local generation")
            return None
        except Exception:
            # Best-effort provider: record the failure and let the caller fall back.
            self._logger.warning(
                "Local generation with %s failed; falling back", self.model_name, exc_info=True
            )
            return None

    def _build_prompt(self, query_text: str, results: List[Dict[str, Any]]) -> str:
        """Assemble the instruction prompt with numbered context chunks.

        Chunks are numbered by their 1-based position in ``results`` so they line
        up with ``_build_citations``. Results with empty content are skipped.
        Chunks are added in order until the next one would push the running total
        past ``self.max_context_chars`` (separators are not counted). If the very
        first chunk is already over budget it is truncated to the budget instead
        of being dropped, so the prompt never ends up with no context at all.
        """
        context_chunks: List[str] = []
        current_length = 0
        for index, result in enumerate(results, start=1):
            snippet = (result.get("content") or "").strip()
            if not snippet:
                continue
            # ``metadata`` may be present but None; treat that like a missing key.
            pack_name = (result.get("metadata") or {}).get("pack", "unknown-pack")
            chunk = f"[{index}] Source: {pack_name}\nSnippet: {snippet}"
            if current_length + len(chunk) > self.max_context_chars:
                if not context_chunks:
                    context_chunks.append(chunk[: self.max_context_chars])
                break
            context_chunks.append(chunk)
            current_length += len(chunk)
        context_text = "\n\n".join(context_chunks)
        return (
            "Answer the user question using only the retrieved context below. "
            "Be concise but useful. Quote uncertainty when context is incomplete. "
            "At the end, include a short 'Sources:' line with bracketed source numbers you relied on.\n\n"
            f"Question: {query_text}\n\n"
            f"Retrieved Context:\n{context_text}\n"
        )

    def _build_citations(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return one citation dict per result, numbered from 1.

        Each citation has the keys ``index``, ``id`` (from ``id`` or
        ``content_id``), ``pack`` (from ``metadata["pack"]``) and ``score``
        (from ``relevance_score``); missing values are None.
        """
        return [
            {
                "index": index,
                "id": result.get("id") or result.get("content_id"),
                # ``metadata`` may be present but None; treat that like a missing key.
                "pack": (result.get("metadata") or {}).get("pack"),
                "score": result.get("relevance_score"),
            }
            for index, result in enumerate(results, start=1)
        ]

    def _extractive_fallback(
        self,
        query_text: str,
        results: List[Dict[str, Any]],
        citations: List[Dict[str, Any]],
        provider: str,
    ) -> GeneratedAnswer:
        """Build an answer by quoting the most relevant snippet of the top results.

        Only the first ``_FALLBACK_RESULTS`` results are considered; those with
        empty content are skipped. The returned answer always has
        ``used_fallback=True``.
        """
        limit = self._FALLBACK_RESULTS
        snippets = []
        for citation, result in zip(citations[:limit], results[:limit]):
            content = (result.get("content") or "").strip()
            if content:
                snippet = self._extract_relevant_snippet(query_text, content)
                snippets.append(f"[{citation['index']}] {snippet}")
        if snippets:
            answer = (
                f"Based on the retrieved context for '{query_text}', the strongest evidence is: "
                + " ".join(snippets)
            )
        else:
            answer = "I could not extract enough grounded context to form an answer."
        return GeneratedAnswer(
            answer=answer,
            citations=citations,
            provider=provider,
            used_fallback=True,
        )

    def _extract_relevant_snippet(self, query_text: str, content: str) -> str:
        """Return the sentence of ``content`` that best matches the query.

        Query tokens longer than two characters are matched as lowercase
        substrings against each sentence; the first sentence with the highest
        match count wins. When the query has no usable tokens, the content has no
        sentences, or nothing matches, the start of ``content`` is returned. The
        result is capped at ``_SNIPPET_CHARS`` characters.
        """
        limit = self._SNIPPET_CHARS
        query_tokens = {
            token for token in self._TOKEN_RE.findall(query_text.lower()) if len(token) > 2
        }
        if not query_tokens:
            return content[:limit]

        sentences = [
            segment.strip() for segment in self._SENTENCE_SPLIT_RE.split(content) if segment.strip()
        ]
        if not sentences:
            return content[:limit]

        best_sentence = None
        best_score = 0
        for sentence in sentences:
            lowered = sentence.lower()
            overlap = sum(1 for token in query_tokens if token in lowered)
            # Strict comparison keeps the earliest sentence on ties.
            if overlap > best_score:
                best_score = overlap
                best_sentence = sentence
        if best_sentence is None:
            return content[:limit]
        return best_sentence[:limit]