Spaces:
Sleeping
Sleeping
Bellok
feat: Enhance answer generation and synthesis fragment processing with improved filtering and snippet extraction
0ee3d53 | """Grounded answer generation for Warbler-CDA retrieval results.""" | |
| from __future__ import annotations | |
| import os | |
| import re | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, List, Optional | |
@dataclass
class GeneratedAnswer:
    """A grounded answer plus provenance for display and API responses.

    The class is a dataclass so it can be built with keyword arguments
    (``GeneratedAnswer(answer=..., citations=..., provider=...)``); without
    the decorator the bare annotations below would not generate an
    ``__init__`` and every construction would raise ``TypeError``.
    """

    # Final answer text shown to the user.
    answer: str
    # One dict per retrieved result (keys: index, id, pack, score).
    citations: List[Dict[str, Any]]
    # Label of whatever produced the answer: "openai", a local model name,
    # "extractive", "disabled" or "none".
    provider: str
    # True when the answer did not come from a generative model.
    used_fallback: bool = False
class AnswerGenerator:
    """Generate grounded answers from retrieval results with lazy provider loading.

    Providers are tried in order: the OpenAI chat API (only when the
    ``OPENAI_API_KEY`` environment variable is set), a local ``transformers``
    text2text pipeline, and finally an extractive fallback that quotes the
    most relevant sentence of the top results. Provider clients are created
    on first use and cached on the instance.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Read settings from *config*, falling back to ``WARBLER_*`` env vars.

        Args:
            config: Optional overrides for ``model_name``,
                ``max_context_chars``, ``max_new_tokens``, ``temperature``
                and ``enabled``.
        """
        self.config = config or {}
        self.model_name = self.config.get(
            "model_name",
            os.getenv("WARBLER_GENERATION_MODEL", "google/flan-t5-small"),
        )
        self.max_context_chars = int(
            self.config.get(
                "max_context_chars",
                os.getenv("WARBLER_GENERATION_MAX_CONTEXT_CHARS", "4000"),
            )
        )
        self.max_new_tokens = int(
            self.config.get(
                "max_new_tokens",
                os.getenv("WARBLER_GENERATION_MAX_NEW_TOKENS", "220"),
            )
        )
        self.temperature = float(
            self.config.get(
                "temperature",
                os.getenv("WARBLER_GENERATION_TEMPERATURE", "0.2"),
            )
        )
        # Anything other than an explicit "off" spelling enables generation.
        enabled_flag = self.config.get("enabled", os.getenv("WARBLER_ENABLE_GENERATION", "1"))
        self.enabled = str(enabled_flag).lower() not in {"0", "false", "no", "off"}
        # Provider handles are created lazily by the _try_* methods.
        self._local_pipeline = None
        self._openai_client = None

    def generate_answer(self, query_text: str, results: List[Dict[str, Any]]) -> GeneratedAnswer:
        """Generate a grounded answer from retrieved results.

        Args:
            query_text: The user's question.
            results: Retrieved result dicts; ``content``, ``metadata``,
                ``id``/``content_id`` and ``relevance_score`` are read.

        Returns:
            A ``GeneratedAnswer``; ``used_fallback`` is True unless a
            generative provider produced the text.
        """
        # Guard first so an empty (or None) result list never reaches the
        # citation builder, which iterates over it.
        if not results:
            return GeneratedAnswer(
                answer="I could not find enough relevant context to generate a grounded answer.",
                citations=[],
                provider="none",
                used_fallback=True,
            )

        citations = self._build_citations(results)
        if not self.enabled:
            return self._extractive_fallback(query_text, results, citations, provider="disabled")

        prompt = self._build_prompt(query_text, results)

        openai_answer = self._try_openai_generation(prompt)
        if openai_answer:
            return GeneratedAnswer(answer=openai_answer, citations=citations, provider="openai")

        local_answer = self._try_local_generation(prompt)
        if local_answer:
            return GeneratedAnswer(answer=local_answer, citations=citations, provider=self.model_name)

        return self._extractive_fallback(query_text, results, citations, provider="extractive")

    def _try_openai_generation(self, prompt: str) -> Optional[str]:
        """Return an OpenAI chat completion for *prompt*, or None if unavailable.

        None is returned when ``OPENAI_API_KEY`` is unset, when the response
        has no usable content, or when anything raises.
        """
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            return None
        try:
            if self._openai_client is None:
                # Imported lazily so the package is only needed when a key is set.
                from openai import OpenAI

                self._openai_client = OpenAI(api_key=api_key)
            response = self._openai_client.chat.completions.create(
                model=os.getenv("WARBLER_OPENAI_CHAT_MODEL", "gpt-4o-mini"),
                temperature=self.temperature,
                max_tokens=self.max_new_tokens,
                messages=[
                    {
                        "role": "system",
                        "content": (
                            "You are a grounded retrieval-augmented assistant. "
                            "Answer only from the provided context. If the context is insufficient, say so clearly."
                        ),
                    },
                    {"role": "user", "content": prompt},
                ],
            )
            content = response.choices[0].message.content if response.choices else None
            return content.strip() if content else None
        except Exception:
            # Deliberate best-effort: any failure falls through to the next provider.
            return None

    def _try_local_generation(self, prompt: str) -> Optional[str]:
        """Return text from the local ``transformers`` pipeline, or None on failure."""
        try:
            if self._local_pipeline is None:
                # Imported lazily so transformers is optional at import time.
                from transformers import pipeline

                self._local_pipeline = pipeline(
                    "text2text-generation",
                    model=self.model_name,
                    tokenizer=self.model_name,
                )
            sampling = self.temperature > 0
            generation_kwargs: Dict[str, Any] = {
                "max_new_tokens": self.max_new_tokens,
                "do_sample": sampling,
            }
            if sampling:
                # temperature is a sampling-only option; sending it together
                # with do_sample=False (temperature <= 0) is an invalid
                # combination for greedy decoding.
                generation_kwargs["temperature"] = self.temperature
            outputs = self._local_pipeline(prompt, **generation_kwargs)
            if not outputs:
                return None
            text = outputs[0].get("generated_text") or outputs[0].get("summary_text")
            return text.strip() if text else None
        except Exception:
            # Deliberate best-effort: the caller falls back to extractive output.
            return None

    def _build_prompt(self, query_text: str, results: List[Dict[str, Any]]) -> str:
        """Assemble the instruction prompt with numbered context chunks.

        Chunks are added in result order until ``max_context_chars`` would be
        exceeded. Source numbers follow the position in *results* (empty
        results keep their number), so they line up with ``_build_citations``.
        """
        context_chunks: List[str] = []
        current_length = 0
        for index, result in enumerate(results, start=1):
            snippet = (result.get("content") or "").strip()
            if not snippet:
                continue
            # metadata may be present but None, so do not rely on the .get default.
            pack_name = (result.get("metadata") or {}).get("pack", "unknown-pack")
            header = f"[{index}] Source: {pack_name}\nSnippet: "
            remaining = self.max_context_chars - current_length
            if len(header) + len(snippet) > remaining:
                # If even the first chunk is over budget, keep a truncated
                # version instead of sending the model an empty context.
                if not context_chunks and remaining > len(header):
                    context_chunks.append(header + snippet[: remaining - len(header)])
                break
            chunk = header + snippet
            context_chunks.append(chunk)
            current_length += len(chunk)
        context_text = "\n\n".join(context_chunks)
        return (
            "Answer the user question using only the retrieved context below. "
            "Be concise but useful. Quote uncertainty when context is incomplete. "
            "At the end, include a short 'Sources:' line with bracketed source numbers you relied on.\n\n"
            f"Question: {query_text}\n\n"
            f"Retrieved Context:\n{context_text}\n"
        )

    def _build_citations(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return one citation dict (index, id, pack, score) per result, 1-indexed."""
        return [
            {
                "index": index,
                "id": result.get("id") or result.get("content_id"),
                # Tolerate an explicit ``"metadata": None`` entry.
                "pack": (result.get("metadata") or {}).get("pack"),
                "score": result.get("relevance_score"),
            }
            for index, result in enumerate(results, start=1)
        ]

    def _extractive_fallback(
        self,
        query_text: str,
        results: List[Dict[str, Any]],
        citations: List[Dict[str, Any]],
        provider: str,
    ) -> GeneratedAnswer:
        """Build an answer by quoting the best sentence of the first three results.

        Args:
            query_text: The user's question, used to pick relevant sentences.
            results: Retrieved result dicts.
            citations: Citations aligned with *results*.
            provider: Label stored on the returned answer.

        Returns:
            A ``GeneratedAnswer`` with ``used_fallback=True``.
        """
        snippets = []
        for citation, result in zip(citations[:3], results[:3]):
            content = (result.get("content") or "").strip()
            if content:
                snippet = self._extract_relevant_snippet(query_text, content)
                snippets.append(f"[{citation['index']}] {snippet}")
        if snippets:
            answer = (
                f"Based on the retrieved context for '{query_text}', the strongest evidence is: "
                + " ".join(snippets)
            )
        else:
            answer = "I could not extract enough grounded context to form an answer."
        return GeneratedAnswer(
            answer=answer,
            citations=citations,
            provider=provider,
            used_fallback=True,
        )

    def _extract_relevant_snippet(self, query_text: str, content: str) -> str:
        """Return the sentence of *content* sharing the most query tokens.

        Query tokens are lowercase alphanumeric runs longer than two
        characters, matched as substrings of each lowercased sentence. The
        first 240 characters of *content* are returned when there are no
        usable tokens, no sentences, or no overlap; otherwise the winning
        sentence is returned, cut to 240 characters.
        """
        query_tokens = {
            token for token in re.findall(r"[a-zA-Z0-9']+", query_text.lower()) if len(token) > 2
        }
        if not query_tokens:
            return content[:240]
        candidate_sentences = [
            segment.strip()
            for segment in re.split(r"(?<=[.!?])\s+|\n+", content)
            if segment.strip()
        ]
        if not candidate_sentences:
            return content[:240]
        scored = [
            (sum(1 for token in query_tokens if token in sentence.lower()), sentence)
            for sentence in candidate_sentences
        ]
        # max() keeps the earliest sentence on ties.
        best_score, best_sentence = max(scored, key=lambda pair: pair[0])
        if best_score <= 0:
            return content[:240]
        return best_sentence[:240]