File size: 3,488 Bytes
0ca02ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# utils/generator.py
from typing import List, Tuple
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util
import nltk, re
from nltk.tokenize import sent_tokenize
import torch
import functools

# Ensure the NLTK "punkt" sentence tokenizer data is present; download it
# on first run so sent_tokenize() below does not raise LookupError.
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt", quiet=True)

# Model names
# - extractive QA model (SQuAD2-style span extraction)
# - sentence-embedding model used for relevance ranking
EXTRACTIVE_MODEL_NAME = "deepset/roberta-base-squad2"
EMBED_MODEL_NAME = "all-MiniLM-L6-v2"

# Load models once at import time (loading is expensive).
# transformers pipelines take a device *index* (-1 = CPU, 0 = first GPU),
# while SentenceTransformer takes a device *string* ("cpu"/"cuda").
device = 0 if torch.cuda.is_available() else -1
qa_pipeline = pipeline("question-answering", model=EXTRACTIVE_MODEL_NAME, device=device)
embedder = SentenceTransformer(
    EMBED_MODEL_NAME, device="cuda" if torch.cuda.is_available() else "cpu"
)


@functools.lru_cache(maxsize=512)
def embed_text(text: str):
    """Encode *text* into an embedding tensor, memoizing the result.

    Repeated queries hit the LRU cache instead of re-running the encoder;
    the 512-entry bound keeps cached tensors from growing without limit.
    """
    embedding = embedder.encode(text, convert_to_tensor=True)
    return embedding


def _select_relevant_sentences(query: str, chunks: List[str], top_k: int = 3) -> str:
    """Return the *top_k* sentences from *chunks* most similar to *query*.

    Sentences are ranked by cosine similarity between the query embedding
    and each sentence embedding; the winners are joined with spaces.
    Returns "" when no usable sentence survives filtering.
    """
    candidates: List[str] = []
    for chunk in chunks:
        candidates.extend(sent_tokenize(chunk))

    # Drop sentences that are nothing but digits/punctuation (table debris).
    candidates = [s for s in candidates if not re.fullmatch(r"[\d\W]+", s.strip())]
    if not candidates:
        return ""

    candidate_embs = embedder.encode(candidates, convert_to_tensor=True)
    similarities = util.cos_sim(embed_text(query), candidate_embs)[0]
    best = similarities.topk(k=min(top_k, len(candidates)))
    return " ".join(candidates[i] for i in best.indices)


def generate_answer(
    query: str,
    context_chunks: List[str],
) -> Tuple[str, str]:
    """Answer *query* extractively from *context_chunks*.

    Returns a (answer, supporting_context) pair.  When no relevant context
    or no answer span is found, the answer is a fixed "cannot find" message;
    pipeline failures are reported as an error string rather than raised.
    """
    supporting_context = _select_relevant_sentences(query, context_chunks, top_k=5)

    # Nothing relevant retrieved at all.
    if not supporting_context.strip():
        return ("I cannot find this information in the financial documents.", "")

    try:
        qa_output = qa_pipeline({"question": query, "context": supporting_context})
        answer = normalize_answer(qa_output.get("answer", "").strip())
        if not answer:
            return ("I cannot find this information in the financial documents.", supporting_context)
        # Trim the context down to the sentences around the answer span.
        return (answer, get_supporting_context(supporting_context, answer))
    except Exception as e:
        # Surface pipeline errors to the caller instead of raising.
        return (f"Error in extractive QA: {e}", supporting_context)


def normalize_answer(ans: str) -> str:
    """Format a plain-integer answer string as a dollar amount.

    Commas and dollar signs are stripped; if what remains is a decimal
    integer it is rendered with a billion/million suffix for large values:
    "1,500,000,000" -> "$1.50 billion", "2,500,000" -> "$2.50 million",
    "57,094" -> "$57094".  Anything else (words, decimals, negatives)
    is returned unchanged.

    NOTE(review): the original docstring claimed "57,094 -> $57.09 billion",
    which the code never did; the docstring now matches actual behavior.
    """
    cleaned = ans.replace(",", "").replace("$", "").strip()
    # isdecimal() (unlike isdigit()) guarantees int() succeeds, so no
    # try/except is needed; superscript/circled digits fall through to the
    # unchanged-return path exactly as the old except branch did.
    if not cleaned.isdecimal():
        return ans
    val = int(cleaned)
    if val >= 1e9:
        return f"${val/1e9:.2f} billion"
    if val >= 1e6:
        return f"${val/1e6:.2f} million"
    return f"${val}"


def get_supporting_context(context: str, answer: str, window: int = 1) -> str:
    """Return the sentence containing *answer* plus up to *window* neighbors
    on each side (i.e. at most ``2 * window + 1`` sentences).

    Matching ignores commas so a normalized number like "57094" still
    matches "57,094" in the source text.  When the answer is not found,
    falls back to the first two sentences of *context*.

    NOTE(review): the original docstring said "up to 2 sentences", but
    window=1 yields up to 3; the docstring now matches the code.
    """
    sentences = sent_tokenize(context)
    # Hoist the loop-invariant normalization out of the loop.
    needle = answer.replace(",", "")
    for i, sent in enumerate(sentences):
        if needle in sent.replace(",", ""):
            start = max(0, i - window)
            end = min(len(sentences), i + window + 1)
            return " ".join(sentences[start:end])
    return " ".join(sentences[:2])  # fallback: answer not located