File size: 5,901 Bytes
cd266a5
 
 
 
a610ce4
cd266a5
 
 
 
 
ebbd49e
9f0da7b
4724824
9f0da7b
ebbd49e
9f0da7b
 
cd266a5
f41d618
cd266a5
50ab09a
0c81fa1
 
cd266a5
 
 
 
 
 
641185f
cd266a5
a610ce4
cd266a5
43b802c
 
a610ce4
43b802c
 
a610ce4
 
 
 
 
 
 
 
cd266a5
 
43b802c
cd266a5
743f89e
cd266a5
641185f
4724824
 
 
93a72c6
 
4724824
cd266a5
743f89e
93a72c6
641185f
cd266a5
743f89e
cd266a5
43b802c
743f89e
 
43b802c
cd266a5
6d7ba5b
cd266a5
6d7ba5b
 
cd266a5
6d7ba5b
 
cd266a5
43b802c
 
6d7ba5b
cd266a5
43b802c
cd266a5
 
6d7ba5b
43b802c
a610ce4
6d7ba5b
cd266a5
 
 
 
43b802c
a610ce4
43b802c
 
 
 
 
cd266a5
43b802c
cd266a5
 
 
641185f
6d7ba5b
cd266a5
 
 
 
6d7ba5b
cd266a5
743f89e
6d7ba5b
641185f
6d7ba5b
 
743f89e
 
 
 
 
a610ce4
cd266a5
6d7ba5b
 
743f89e
 
 
 
 
 
 
6d7ba5b
09c2f03
743f89e
09c2f03
743f89e
7b609f8
 
 
 
 
09c2f03
743f89e
 
 
 
 
cd266a5
 
 
743f89e
cd266a5
 
 
 
 
 
 
 
43b802c
cd266a5
743f89e
 
 
 
 
cd266a5
 
43b802c
cd266a5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
"""
qa.py — Retrieval + Generation Layer
-------------------------------------
Handles:
• Query embedding (SentenceTransformer / E5-compatible)
• Chunk retrieval (FAISS)
• Answer generation (Flan-T5)
Optimized for Hugging Face Spaces & Streamlit.
"""

import os

# ==========================================================
# 1️⃣ Hugging Face Cache Setup (Safe for Spaces)
# ==========================================================
# All Hugging Face downloads are redirected to a directory under /tmp,
# which is created here if it does not exist yet.
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# NOTE: these variables are exported *before* the Hugging Face libraries are
# imported below. huggingface_hub / transformers resolve their cache
# locations when they are first imported, so exporting the variables after
# the imports would not affect the already-loaded modules.
# NOTE(review): newer transformers releases flag TRANSFORMERS_CACHE as
# deprecated in favour of HF_HOME; it is kept for older versions — confirm
# against the pinned version.
os.environ.update({
    "HF_HOME": CACHE_DIR,
    "TRANSFORMERS_CACHE": CACHE_DIR,
    "HF_DATASETS_CACHE": CACHE_DIR,
    "HF_MODULES_CACHE": CACHE_DIR,
})

from sentence_transformers import SentenceTransformer  # noqa: E402
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline  # noqa: E402
from vectorstore import search_faiss  # noqa: E402

print("✅ qa.py loaded from:", __file__)

# ==========================================================
# 2️⃣ Query Embedding Model
# ==========================================================
def _load_query_model():
    """Load the query encoder, preferring E5 and falling back to MiniLM.

    Any exception raised while loading (or announcing) the primary model
    triggers the fallback. A failure of the fallback itself propagates to
    the caller, i.e. aborts the module import.
    """
    try:
        encoder = SentenceTransformer(
            "intfloat/e5-small-v2",
            cache_folder=CACHE_DIR,
        )
        print("✅ Loaded query model: intfloat/e5-small-v2")
        return encoder
    except Exception as load_error:
        print(f"⚠️ Query model load failed ({load_error}), falling back to MiniLM.")
        encoder = SentenceTransformer(
            "sentence-transformers/all-MiniLM-L6-v2",
            cache_folder=CACHE_DIR,
        )
        print("✅ Loaded fallback model: all-MiniLM-L6-v2")
        return encoder


# Module-level encoder shared by the retrieval code in this file.
_query_model = _load_query_model()

# ==========================================================
# 3️⃣ LLM for Answer Generation (FLAN-T5)
# ==========================================================
# Checkpoint used for answer generation; a 'large' variant can be swapped in
# when enough memory is available.
MODEL_NAME = "google/flan-t5-base"
print("✅ Loading LLM: " + MODEL_NAME)

# Tokenizer first, then the seq2seq weights; both are cached under CACHE_DIR.
_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)

# device=-1 keeps the pipeline on CPU (safe default for Hugging Face Spaces).
_answer_model = pipeline(
    task="text2text-generation",
    model=_model,
    tokenizer=_tokenizer,
    device=-1,
)

# ==========================================================
# 4️⃣ Prompt Template
# ==========================================================
# Prompt skeleton with two placeholders, {context} and {query}, filled in via
# str.format(). The text starts and ends with a newline.
PROMPT_TEMPLATE = (
    "\n"
    "You are an expert enterprise knowledge assistant.\n"
    "Use ONLY the CONTEXT below to answer the QUESTION clearly, factually, and completely.\n"
    "If the context doesn’t contain the answer, reply exactly:\n"
    '"I don\'t know based on the provided document."\n'
    "\n"
    "---\n"
    "Context:\n"
    "{context}\n"
    "---\n"
    "Question:\n"
    "{query}\n"
    "---\n"
    "Answer:\n"
)

# ==========================================================
# 5️⃣ Chunk Retrieval Function
# ==========================================================
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 3):
    """
    Encode the user query and retrieve the top-k relevant chunks via FAISS.

    The query is prefixed with 'query: ' (E5 training style) and encoded
    into a normalized embedding, which is then handed to ``search_faiss``.

    Args:
        query: Raw user question; surrounding whitespace is ignored.
        index: Search index passed through to ``search_faiss``.
        chunks: Text chunks passed through to ``search_faiss``.
        top_k: Maximum number of chunks to request.

    Returns:
        The result of ``search_faiss`` for the query, or an empty list when
        there is nothing to search (missing index/chunks, blank or
        non-string query, non-positive ``top_k``) or when encoding/search
        raises.
    """
    if not index or not chunks:
        return []

    # A blank query carries no meaning: embedding the bare "query: " prefix
    # would still hand back arbitrary chunks, so stop here instead.
    cleaned_query = query.strip() if isinstance(query, str) else ""
    if not cleaned_query:
        return []

    try:
        # Never request more neighbours than there are chunks.
        # NOTE(review): FAISS reports missing neighbours as index -1 when k
        # exceeds the number of stored vectors — assumes index and chunks
        # are aligned; confirm against search_faiss.
        k = min(top_k, len(chunks))
        if k <= 0:
            return []

        query_emb = _query_model.encode(
            [f"query: {cleaned_query}"],
            convert_to_numpy=True,
            normalize_embeddings=True
        )[0]

        return search_faiss(query_emb, index, chunks, k)

    except Exception as e:
        # Best-effort boundary: a retrieval failure is reported on stdout and
        # surfaces to the caller as "no results" rather than as a crash.
        print(f"⚠️ Retrieval error: {e}")
        return []


# ==========================================================
# 6️⃣ Answer Generation Function
# ==========================================================
def generate_answer(query: str, retrieved_chunks: list):
    """
    Generate an answer with FLAN-T5, using the retrieved chunks as context.

    Chunks are stripped, blank ones are dropped, and the remainder is
    numbered ("[Chunk 1]: ...") and merged into PROMPT_TEMPLATE together
    with the query. The model is sampled (temperature 0.7, top-p 0.9,
    repetition penalty 1.15) with at most 400 new tokens.

    Args:
        query: The user question inserted into the prompt.
        retrieved_chunks: Chunk strings to use as context.

    Returns:
        The stripped generated text; a fixed apology string when there is no
        usable context; or a fixed error string when generation raises.
    """
    no_context_reply = "Sorry, I couldn’t find relevant information in the document."
    if not retrieved_chunks:
        return no_context_reply

    # Blank chunks would only add empty "[Chunk n]:" entries to the prompt,
    # so drop them; if nothing is left there is no context to answer from.
    usable_chunks = [
        text
        for text in (chunk.strip() for chunk in retrieved_chunks)
        if text
    ]
    if not usable_chunks:
        return no_context_reply

    # Merge the chunks into one numbered context block.
    context = "\n\n".join(
        f"[Chunk {i}]: {text}"
        for i, text in enumerate(usable_chunks, start=1)
    )

    prompt = PROMPT_TEMPLATE.format(context=context, query=query)

    # NOTE(review): the prompt is passed to the pipeline without truncation.
    # FLAN-T5 is typically used with ~512 input tokens — confirm that long
    # contexts behave acceptably. Plain right-side truncation would cut off
    # the question, which sits at the end of the prompt.
    try:
        result = _answer_model(
            prompt,
            max_new_tokens=400,        # upper bound on answer length
            do_sample=True,            # sampled, so answers vary between calls
            temperature=0.7,
            top_p=0.9,                 # nucleus sampling
            repetition_penalty=1.15    # discourage repetition
        )
        return result[0]["generated_text"].strip()

    except Exception as e:
        # Best-effort boundary: report on stdout and return a user-facing
        # error string instead of propagating the failure.
        print(f"⚠️ Generation failed: {e}")
        return "⚠️ Error: Could not generate an answer at the moment."


# ==========================================================
# 7️⃣ Optional Local Test (runs only in dev mode)
# ==========================================================
if __name__ == "__main__":
    # Manual smoke test: index three sample passages, then run one question
    # through retrieval and generation and print both results.
    from vectorstore import build_faiss_index

    dummy_chunks = [
        "SAP Ariba is a cloud-based procurement solution.",
        "It helps companies manage suppliers and sourcing processes efficiently.",
        "Integration with SAP ERP allows for seamless data synchronization.",
    ]

    # Each passage is encoded on its own with the "passage: " prefix.
    passage_vectors = []
    for passage in dummy_chunks:
        encoded = _query_model.encode(
            [f"passage: {passage}"],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )
        passage_vectors.append(encoded[0])

    index = build_faiss_index(passage_vectors)

    query = "What is SAP Ariba used for?"
    retrieved = retrieve_chunks(query, index, dummy_chunks)
    print("🔍 Retrieved:", retrieved)
    print("💬 Answer:", generate_answer(query, retrieved))