# utils/generator.py
from typing import List, Tuple

import functools
import re

import nltk
import torch
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline

# Ensure the Punkt sentence tokenizer is available
# (newer NLTK releases ship it as "punkt_tab").
for _resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_resource}")
    except LookupError:
        nltk.download(_resource, quiet=True)

# Model names
EXTRACTIVE_MODEL_NAME = "deepset/roberta-base-squad2"
EMBED_MODEL_NAME = "all-MiniLM-L6-v2"

# Load models once at import time
device = 0 if torch.cuda.is_available() else -1
qa_pipeline = pipeline("question-answering", model=EXTRACTIVE_MODEL_NAME, device=device)
embedder = SentenceTransformer(
    EMBED_MODEL_NAME, device="cuda" if torch.cuda.is_available() else "cpu"
)


@functools.lru_cache(maxsize=512)
def embed_text(text: str):
    """Cache query embeddings to avoid recomputation across calls."""
    return embedder.encode(text, convert_to_tensor=True)


def _select_relevant_sentences(query: str, chunks: List[str], top_k: int = 3) -> str:
    """Select the top-k sentences from the retrieved chunks most similar to the query."""
    sentences = []
    for ch in chunks:
        sentences.extend(sent_tokenize(ch))

    # Filter out sentences that are only digits/punctuation (table junk)
    sentences = [s for s in sentences if not re.fullmatch(r"[\d\W]+", s.strip())]
    if not sentences:
        return ""

    query_emb = embed_text(query)
    sent_embs = embedder.encode(sentences, convert_to_tensor=True)
    cos_scores = util.cos_sim(query_emb, sent_embs)[0]
    top_results = cos_scores.topk(k=min(top_k, len(sentences)))
    selected = [sentences[int(idx)] for idx in top_results.indices]
    return " ".join(selected)


def generate_answer(
    query: str,
    context_chunks: List[str],
) -> Tuple[str, str]:
    """Generate (answer, supporting_context) using extractive QA."""
    supporting_context = _select_relevant_sentences(query, context_chunks, top_k=5)
    if not supporting_context.strip():
        return ("I cannot find this information in the financial documents.", "")

    try:
        result = qa_pipeline({"question": query, "context": supporting_context})
        answer = normalize_answer(result.get("answer", "").strip())
        if not answer:
            return (
                "I cannot find this information in the financial documents.",
                supporting_context,
            )
        refined_context = get_supporting_context(supporting_context, answer)
        return (answer, refined_context)
    except Exception as e:
        return (f"Error in extractive QA: {e}", supporting_context)


def normalize_answer(ans: str) -> str:
    """Format bare numeric answers, e.g. 57094000000 -> $57.09 billion."""
    cleaned = ans.replace(",", "").replace("$", "").strip()
    if not cleaned.isdigit():
        return ans
    val = int(cleaned)
    if val >= 1e9:
        return f"${val / 1e9:.2f} billion"
    if val >= 1e6:
        return f"${val / 1e6:.2f} million"
    return f"${val:,}"  # re-insert thousands separators stripped above


def get_supporting_context(context: str, answer: str, window: int = 1) -> str:
    """Return the sentence containing the answer plus `window` sentences on each side."""
    sentences = sent_tokenize(context)
    for i, sent in enumerate(sentences):
        # Compare with thousands separators stripped so "57,094" matches "57094"
        if answer.replace(",", "") in sent.replace(",", ""):
            start = max(0, i - window)
            end = min(len(sentences), i + window + 1)
            return " ".join(sentences[start:end])
    return " ".join(sentences[:2])  # fallback: first two sentences
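

if __name__ == "__main__":
    # Minimal usage sketch with hypothetical data: the chunk strings below are
    # made-up stand-ins for whatever the retrieval layer returns. In the real
    # pipeline, context_chunks would come from the document retriever, not be
    # hard-coded here.
    sample_chunks = [
        "Total revenue for fiscal year 2023 was $57,094 million, an increase of 7% year over year.",
        "Operating expenses grew to $21,300 million, driven primarily by research and development.",
    ]
    answer, context = generate_answer(
        "What was the total revenue in fiscal year 2023?", sample_chunks
    )
    print("Answer:", answer)
    print("Supporting context:", context)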