File size: 5,873 Bytes
cd266a5
 
 
 
a610ce4
cd266a5
 
 
 
 
ebbd49e
9f0da7b
4724824
9f0da7b
ebbd49e
9f0da7b
 
cd266a5
43b802c
cd266a5
50ab09a
0c81fa1
 
cd266a5
 
 
 
 
 
641185f
cd266a5
a610ce4
cd266a5
a610ce4
43b802c
 
a610ce4
43b802c
 
a610ce4
 
 
 
 
 
 
 
cd266a5
 
43b802c
cd266a5
a610ce4
cd266a5
641185f
4724824
 
 
93a72c6
 
4724824
cd266a5
a610ce4
93a72c6
641185f
cd266a5
43b802c
cd266a5
43b802c
 
 
 
cd266a5
6d7ba5b
cd266a5
6d7ba5b
 
cd266a5
6d7ba5b
 
cd266a5
43b802c
 
6d7ba5b
cd266a5
43b802c
cd266a5
 
6d7ba5b
43b802c
a610ce4
6d7ba5b
cd266a5
 
 
 
a610ce4
43b802c
a610ce4
43b802c
 
 
 
 
cd266a5
43b802c
cd266a5
 
 
641185f
6d7ba5b
cd266a5
 
 
 
6d7ba5b
cd266a5
6d7ba5b
641185f
6d7ba5b
 
a610ce4
6d7ba5b
a610ce4
 
cd266a5
6d7ba5b
 
09c2f03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d7ba5b
09c2f03
 
 
 
 
 
 
cd266a5
 
 
a610ce4
cd266a5
 
 
 
 
 
 
 
43b802c
 
cd266a5
43b802c
cd266a5
 
43b802c
cd266a5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
"""
qa.py — Retrieval + Generation Layer
-------------------------------------
Handles:
• Query embedding (SentenceTransformer / E5-compatible)
• Chunk retrieval (FAISS)
• Answer generation (Flan-T5)
Optimized for Hugging Face Spaces & Streamlit.
"""

import os

# ==========================================================
# 1️⃣ Hugging Face Cache Setup (Safe for Spaces)
# ==========================================================
# NOTE: this section deliberately runs BEFORE the Hugging Face imports below.
# The HF libraries resolve their cache locations from the environment when
# they are first imported, so exporting these variables afterwards would be
# too late to influence them.
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Point every Hugging Face cache variable set here at the same directory.
os.environ.update({
    "HF_HOME": CACHE_DIR,
    "TRANSFORMERS_CACHE": CACHE_DIR,
    "HF_DATASETS_CACHE": CACHE_DIR,
    "HF_MODULES_CACHE": CACHE_DIR
})

# Imported only after the cache environment is in place (hence the E402 waivers).
from sentence_transformers import SentenceTransformer  # noqa: E402
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline  # noqa: E402
from vectorstore import search_faiss  # noqa: E402

print("✅ qa.py loaded from:", __file__)

# ==========================================================
# 2️⃣ Query Embedding Model
# ==========================================================
def _load_query_model():
    """
    Build the sentence encoder used to embed user queries.

    The preferred model is intfloat/e5-small-v2 (chosen, per the original
    note in this file, for retrieval consistency with embeddings.py). If
    anything in that attempt raises, the failure is printed and
    sentence-transformers/all-MiniLM-L6-v2 is loaded instead. An error raised
    by the fallback load itself is not handled here and propagates.

    NOTE(review): whether the fallback's embeddings are compatible with an
    index built by embeddings.py cannot be confirmed from this file.
    """
    try:
        encoder = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
        print("✅ Loaded query model: intfloat/e5-small-v2")
        return encoder
    except Exception as e:
        print(f"⚠️ Query model load failed ({e}), falling back to MiniLM.")
        encoder = SentenceTransformer(
            "sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR
        )
        print("✅ Loaded fallback model: all-MiniLM-L6-v2")
        return encoder


# Module-level encoder shared by every retrieval call.
_query_model = _load_query_model()

# ==========================================================
# 3️⃣ LLM for Answer Generation (FLAN-T5)
# ==========================================================
# Checkpoint used for answer generation; the 'large' variant can be swapped
# in here when the host has enough RAM.
MODEL_NAME = "google/flan-t5-base"
print(f"✅ Loading LLM: {MODEL_NAME}")

# Tokenizer and seq2seq weights are fetched into the shared cache directory.
_tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    cache_dir=CACHE_DIR,
)
_model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_NAME,
    cache_dir=CACHE_DIR,
)

# Text-to-text generation pipeline built from the objects loaded above.
# device=-1 keeps inference on the CPU (safe default for Spaces).
_answer_model = pipeline(
    task="text2text-generation",
    model=_model,
    tokenizer=_tokenizer,
    device=-1,
)

# ==========================================================
# 4️⃣ Prompt Template (concise and factual)
# ==========================================================
# Prompt skeleton for the generation model. It is filled with str.format()
# using two named fields: {context} (the retrieved chunks) and {query} (the
# user's question). The whole literal, including its leading and trailing
# newlines and the exact refusal sentence, is sent to the model as-is.
PROMPT_TEMPLATE = """
You are an expert enterprise assistant.
Using ONLY the CONTEXT below, answer the QUESTION clearly and factually.
If the context doesn’t contain the answer, reply exactly:
"I don't know based on the provided document."

---
Context:
{context}
---
Question:
{query}
---
Answer:
"""

# ==========================================================
# 5️⃣ Chunk Retrieval Function
# ==========================================================
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 3):
    """
    Embed the user query and look up the best-matching chunks.

    The query is stripped, prefixed with 'query: ' (the E5 convention noted
    in this file) and encoded as a single normalized numpy vector. That
    vector is passed to search_faiss together with the index, the chunk list
    and top_k, and whatever search_faiss returns is handed back unchanged.

    Returns an empty list when the index or the chunk list is falsy, and
    also when any step of encoding or searching raises (the error is printed).
    """
    # Nothing to search against: bail out before touching the encoder.
    if not (index and chunks):
        return []

    try:
        prefixed_query = f"query: {query.strip()}"
        embeddings = _query_model.encode(
            [prefixed_query],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )
        # encode() was given a one-element batch; take its only row.
        return search_faiss(embeddings[0], index, chunks, top_k)
    except Exception as e:
        # Best-effort retrieval: report the problem and return no results.
        print(f"⚠️ Retrieval error: {e}")
        return []


# ==========================================================
# 6️⃣ Answer Generation Function
# ==========================================================
def generate_answer(query: str, retrieved_chunks: list):
    """
    Generate an answer with FLAN-T5, using the retrieved chunks as context.

    Args:
        query: The user's question; inserted into PROMPT_TEMPLATE as {query}.
        retrieved_chunks: Chunks to ground the answer on. Each is labelled
            "[Chunk N]" (1-based) and the labelled chunks are joined with
            blank lines to form the {context} field of the prompt.

    Returns:
        The generated answer as a stripped string. When no chunks are given a
        fixed "Sorry, ..." message is returned without calling the model, and
        when the model call raises a fixed error message is returned instead.
        Answers shorter than eight words are prefixed with a short lead-in.
    """
    if not retrieved_chunks:
        return "Sorry, I couldn’t find relevant information in the document."

    # Merge retrieved chunks into one labelled context block.
    context = "\n\n".join(
        f"[Chunk {number}]: {chunk}"
        for number, chunk in enumerate(retrieved_chunks, start=1)
    )

    # Build structured prompt
    prompt = PROMPT_TEMPLATE.format(context=context, query=query)

    # Only the model call and result unpacking are guarded: the original
    # try/except was mis-indented (an IndentationError that prevented the
    # module from importing at all).
    try:
        result = _answer_model(
            prompt,
            max_new_tokens=350,        # allow longer, more complete answers
            do_sample=True,            # enable sampling for natural flow
            temperature=0.7,           # slightly higher = more expressive responses
            top_p=0.95,                # nucleus sampling for coherence
            repetition_penalty=1.2     # discourages repetitive phrasing
        )
        answer = result[0]["generated_text"].strip()
    except Exception as e:
        print(f"⚠️ Generation failed: {e}")
        return "⚠️ Error: Could not generate an answer at the moment."

    # 🧩 If the model outputs something too short, expand gracefully
    if len(answer.split()) < 8:
        answer = (
            "The document mentions this briefly. Based on the context, here's what it suggests: "
            + answer
        )

    return answer



# ==========================================================
# 7️⃣ Optional Local Test
# ==========================================================
if __name__ == "__main__":
    # Local smoke test: index three sample passages, retrieve for one
    # question, then generate an answer from the retrieved chunks.
    from vectorstore import build_faiss_index

    dummy_chunks = [
        "SAP Ariba is a cloud-based procurement solution.",
        "It helps companies manage suppliers and sourcing processes efficiently.",
        "Integration with SAP ERP allows for seamless data synchronization."
    ]

    # Passages are embedded one at a time with the 'passage: ' prefix, the
    # counterpart of the 'query: ' prefix used in retrieve_chunks.
    index = build_faiss_index([
        _query_model.encode([f"passage: {chunk}"], convert_to_numpy=True, normalize_embeddings=True)[0]
        for chunk in dummy_chunks
    ])

    query = "What is SAP Ariba used for?"
    retrieved = retrieve_chunks(query, index, dummy_chunks)
    print("🔍 Retrieved:", retrieved)
    print("💬 Answer:", generate_answer(query, retrieved))