jayyd committed on
Commit
0ca02ad
·
verified ·
1 Parent(s): 0356af1

Update utils/generator.py

Browse files
Files changed (1) hide show
  1. utils/generator.py +104 -0
utils/generator.py CHANGED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # utils/generator.py
2
+ from typing import List, Tuple
3
+ from transformers import pipeline
4
+ from sentence_transformers import SentenceTransformer, util
5
+ import nltk, re
6
+ from nltk.tokenize import sent_tokenize
7
+ import torch
8
+ import functools
9
+
10
# Ensure the NLTK sentence-tokenizer data is available before sent_tokenize is
# used. Older NLTK releases load the pickled "punkt" models, while newer
# releases load the "punkt_tab" tables instead, so both are checked: having
# only "punkt" installed lets this guard pass but still makes sent_tokenize
# raise LookupError on a recent NLTK.
for _resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_resource}")
    except LookupError:
        # Best-effort fetch; quiet=True keeps import-time output clean.
        nltk.download(_resource, quiet=True)
15
+
16
# Checkpoint identifiers for the extractive reader and the sentence embedder.
EXTRACTIVE_MODEL_NAME = "deepset/roberta-base-squad2"
EMBED_MODEL_NAME = "all-MiniLM-L6-v2"

# Both models are built once, at import time, and shared by every call below.
# The QA pipeline is given an integer device (0 when CUDA is available, -1
# otherwise); the sentence embedder is given the matching device string.
_cuda_available = torch.cuda.is_available()
device = 0 if _cuda_available else -1
qa_pipeline = pipeline(
    "question-answering",
    model=EXTRACTIVE_MODEL_NAME,
    device=device,
)
embedder = SentenceTransformer(
    EMBED_MODEL_NAME,
    device="cuda" if _cuda_available else "cpu",
)
26
+
27
+
28
@functools.lru_cache(maxsize=512)
def embed_text(text: str):
    """Encode ``text`` with the module-level embedder, memoized per string.

    Up to 512 distinct inputs are kept in an LRU cache, so asking for the
    same text again returns the previously computed tensor instead of running
    the encoder. The returned tensor object is shared between callers that
    pass the same text, so it should be treated as read-only.
    """
    embedding = embedder.encode(text, convert_to_tensor=True)
    return embedding
32
+
33
+
34
def _select_relevant_sentences(query: str, chunks: List[str], top_k: int = 3) -> str:
    """Pick the sentences from ``chunks`` most similar to ``query``.

    Each chunk is split into sentences, fragments made only of digits and
    non-word characters are discarded, and the survivors are ranked by cosine
    similarity between their embeddings and the query embedding.

    Args:
        query: The question to rank sentences against.
        chunks: Retrieved text chunks to draw sentences from.
        top_k: Maximum number of sentences to return.

    Returns:
        The selected sentences joined by single spaces, highest-scoring
        first, or ``""`` when no usable sentence exists or ``top_k`` is not
        positive.
    """
    # topk() rejects a negative k; nothing can be selected in that case anyway.
    if top_k <= 0:
        return ""

    # Matches fragments that contain nothing but digits/punctuation/whitespace
    # (numeric or table residue).
    is_junk = re.compile(r"[\d\W]+").fullmatch

    # dict.fromkeys drops repeated sentences while keeping first-seen order, so
    # a sentence that appears in several chunks cannot occupy more than one of
    # the top_k slots.
    sentences = list(
        dict.fromkeys(
            sentence
            for chunk in chunks
            for sentence in sent_tokenize(chunk)
            if not is_junk(sentence.strip())
        )
    )
    if not sentences:
        return ""

    query_emb = embed_text(query)
    sent_embs = embedder.encode(sentences, convert_to_tensor=True)
    cos_scores = util.cos_sim(query_emb, sent_embs)[0]

    # Convert the index tensor to plain ints before indexing the Python list.
    best = cos_scores.topk(k=min(top_k, len(sentences))).indices.tolist()
    return " ".join(sentences[i] for i in best)
52
+
53
+
54
def generate_answer(
    query: str,
    context_chunks: List[str],
) -> Tuple[str, str]:
    """Answer ``query`` from ``context_chunks`` using extractive QA.

    The five most relevant sentences are selected from the chunks and passed
    to the question-answering pipeline as context. The extracted span is
    normalized for display, and the context is narrowed to the sentences
    surrounding that span.

    Args:
        query: The user's question.
        context_chunks: Retrieved text chunks that may contain the answer.

    Returns:
        A ``(answer, supporting_context)`` tuple. When nothing usable is
        found the answer is a fixed "cannot find" message. When the QA step
        raises, the answer is an error message and the unrefined context is
        returned alongside it.
    """
    not_found = "I cannot find this information in the financial documents."

    supporting_context = _select_relevant_sentences(query, context_chunks, top_k=5)
    if not supporting_context.strip():
        return (not_found, "")

    try:
        result = qa_pipeline(question=query, context=supporting_context)
        raw_answer = result.get("answer", "").strip()
        if not raw_answer:
            return (not_found, supporting_context)

        # Locate the supporting sentences with the span exactly as it was
        # extracted. The normalized form (e.g. "$1.23 million") does not occur
        # in the context, so searching for it would always hit the fallback.
        refined_context = get_supporting_context(supporting_context, raw_answer)
        return (normalize_answer(raw_answer), refined_context)

    except Exception as e:
        # Deliberate best-effort boundary: report the failure as the answer
        # instead of propagating it to the caller.
        return (f"Error in extractive QA: {e}", supporting_context)
77
+
78
+
79
def normalize_answer(ans: str) -> str:
    """Format a purely numeric answer as a dollar amount.

    Commas, dollar signs and surrounding whitespace are ignored when deciding
    whether ``ans`` is a whole number. Whole numbers are rendered as:

    * ``"$X.XX billion"`` for values of at least 1e9,
    * ``"$X.XX million"`` for values of at least 1e6,
    * ``"$<value>"`` (no separators) otherwise, e.g. ``"57,094"`` becomes
      ``"$57094"``.

    Anything else, including decimals, negatives, text, and digit-like
    characters that ``int`` cannot parse, is returned unchanged.
    """
    digits = ans.translate(str.maketrans("", "", ",$")).strip()
    if not digits.isdigit():
        return ans

    try:
        amount = int(digits)
        # Largest unit first so a billion-scale value is not labelled million.
        for scale, unit in ((1e9, "billion"), (1e6, "million")):
            if amount >= scale:
                return f"${amount / scale:.2f} {unit}"
        return f"${amount}"
    except Exception:
        # isdigit() accepts characters such as superscripts that int() rejects;
        # fall back to the untouched input in that case.
        return ans
94
+
95
+
96
def get_supporting_context(context: str, answer: str, window: int = 1) -> str:
    """Return the sentences of ``context`` surrounding the answer.

    The context is split into sentences and the first sentence containing
    ``answer`` is located; commas are ignored on both sides of the comparison.
    That sentence is returned together with up to ``window`` sentences before
    and after it (at most ``2 * window + 1`` sentences). If no sentence
    contains the answer, the first two sentences of the context are returned.
    """
    sentences = sent_tokenize(context)

    hit = next(
        (
            position
            for position, sentence in enumerate(sentences)
            if answer.replace(",", "") in sentence.replace(",", "")
        ),
        None,
    )
    if hit is None:
        return " ".join(sentences[:2])  # fallback

    first = max(0, hit - window)
    last = min(len(sentences), hit + window + 1)
    return " ".join(sentences[first:last])