File size: 7,232 Bytes
cd266a5
 
 
 
a610ce4
c7133f4
 
cd266a5
 
 
ebbd49e
e97699c
9f0da7b
e97699c
5491531
ebbd49e
c7133f4
9f0da7b
cd266a5
c7133f4
cd266a5
50ab09a
0c81fa1
cd266a5
 
 
 
 
 
641185f
cd266a5
c7133f4
cd266a5
43b802c
fea3890
c7133f4
a610ce4
c7133f4
fea3890
cd266a5
 
c7133f4
cd266a5
c7133f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c28ff15
 
 
fea3890
c7133f4
 
 
6d7ba5b
cd266a5
c7133f4
cd266a5
fea3890
 
c7133f4
 
c28ff15
 
fea3890
c28ff15
 
fea3890
c28ff15
 
fea3890
c28ff15
 
6d7ba5b
fea3890
c7133f4
fea3890
 
c7133f4
fea3890
 
 
 
c7133f4
fea3890
 
 
 
 
 
c7133f4
fea3890
c7133f4
 
fea3890
 
 
 
 
c7133f4
fea3890
 
 
 
 
 
 
c7133f4
fea3890
 
 
 
 
 
c7133f4
fea3890
c7133f4
fea3890
 
c7133f4
fea3890
 
 
c7133f4
 
 
 
 
fea3890
 
c7133f4
 
 
5491531
 
c28ff15
c7133f4
fea3890
c28ff15
fea3890
b41f253
c28ff15
5491531
fea3890
c7133f4
 
c28ff15
c7133f4
 
 
 
 
 
 
 
 
 
 
 
743f89e
c7133f4
 
 
fea3890
 
c7133f4
fea3890
 
 
 
 
 
 
 
 
 
c7133f4
 
 
 
 
fea3890
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
"""
qa.py — Retrieval + Generation Layer
-------------------------------------
Handles:
• Query embedding (SentenceTransformer / E5-compatible)
• Chunk retrieval (FAISS with neighborhood merging + re-ranking)
• Answer generation (OpenAI GPT-4o-mini → FLAN-T5 fallback)
Optimized for Hugging Face Spaces & Streamlit.
"""

import os
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from vectorstore import search_faiss

print("✅ qa.py loaded from:", __file__)

# ==========================================================
# 1️⃣ Hugging Face Cache Setup
# ==========================================================
# Every Hugging Face cache location is pointed at one writable temp directory.
# NOTE(review): sentence_transformers (imported above) likely pulls in
# huggingface_hub/transformers, which may read these variables at import time;
# if so, setting them here does not change their defaults. The explicit
# cache_folder= / cache_dir= arguments used when loading models are what
# reliably pin the cache location — confirm before relying on the env vars.
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ.update(
    dict.fromkeys(
        ("HF_HOME", "TRANSFORMERS_CACHE", "HF_DATASETS_CACHE", "HF_MODULES_CACHE"),
        CACHE_DIR,
    )
)

# ==========================================================
# 2️⃣ Query Embedding Model
# ==========================================================
# Prefer the E5 model (it is trained with "query: " / "passage: " prefixes, which
# retrieve_chunks and the __main__ demo add). If it cannot be loaded, fall back to
# MiniLM; a failure of the fallback itself propagates and aborts module import.
try:
    _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
except Exception as load_error:
    print(f"⚠️ Query model load failed ({load_error}), falling back to MiniLM.")
    _query_model = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2",
        cache_folder=CACHE_DIR,
    )
else:
    print("✅ Loaded query model: intfloat/e5-small-v2")

# ==========================================================
# 3️⃣ LLM Setup: OpenAI (primary) + FLAN (fallback)
# ==========================================================
# OpenAI is used only when an API key is present AND the client initializes.
USE_OPENAI = bool(os.getenv("OPENAI_API_KEY"))
client = None         # OpenAI client; stays None unless initialization below succeeds
_answer_model = None  # FLAN-T5 pipeline; stays None unless the fallback loads

if USE_OPENAI:
    try:
        from openai import OpenAI
        client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        print("✅ Using OpenAI GPT-4o-mini for answer generation")
    except Exception as e:
        print(f"⚠️ Failed to initialize OpenAI client: {e}")
        client = None
        USE_OPENAI = False

# Always prepare the local fallback, even when OpenAI is available, so that a
# runtime OpenAI failure in generate_answer can still be served locally.
MODEL_NAME = "google/flan-t5-base"
try:
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
    _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
    # device=-1 selects CPU for the transformers pipeline.
    _answer_model = pipeline("text2text-generation", model=_model, tokenizer=_tokenizer, device=-1)
    print("💡 Fallback FLAN-T5 ready.")
except Exception as e:
    print(f"⚠️ Could not initialize FLAN fallback: {e}")

# ==========================================================
# 4️⃣ Prompt Template
# ==========================================================
# Prompt text sent to the answer generator.
# It is filled with str.format using two named fields:
#   {context} - the numbered, retrieved chunks
#   {query}   - the user's question
# Because str.format is applied to this text, any literal curly brace added
# here must be written doubled ("{{" / "}}"), or formatting will fail.
PROMPT_TEMPLATE = """
You are an enterprise knowledge assistant.
Use ONLY the CONTEXT below to answer the QUESTION clearly, completely, and factually.
If the context doesn’t contain the answer, reply exactly:
"I don't know based on the provided document."

---
Context:
{context}
---
Question:
{query}
---
Answer:
"""

# ==========================================================
# 5️⃣ Chunk Retrieval Function
# ==========================================================
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
    """Retrieve top-K relevant chunks, merge nearby ones, and re-rank by semantic similarity.

    Pipeline:
      1. Embed ``query`` with the E5-style ``"query: "`` prefix.
      2. Ask ``index`` for ``2 * top_k`` nearest neighbours, over-fetching so
         the re-ranker has spare candidates.
      3. Expand every valid hit into a window of itself plus its immediate
         neighbours in ``chunks`` (joined with spaces), dropping duplicates.
      4. Re-embed each window with the ``"passage: "`` prefix and sort by
         cosine similarity to the query.

    Args:
        query: The user's question.
        index: Object exposing a FAISS-style ``search(vectors, k)`` that returns
            ``(distances, ids)``. Assumed to hold one vector per entry of
            ``chunks``, in the same order. TODO: confirm for callers other
            than the ``__main__`` demo.
        chunks: Chunk texts, positionally aligned with the vectors in ``index``.
        top_k: Maximum number of merged windows to return.

    Returns:
        Up to ``top_k`` merged windows, best first. Returns ``[]`` in these cases:
        ``index`` is None, ``chunks`` is empty, ``top_k`` is not positive,
        nothing valid was retrieved, or any step raised. In the last case the
        error is printed, not propagated.
    """
    if index is None or not chunks or top_k <= 0:
        return []

    try:
        # Step 1: Encode the query (normalized, so cosine == dot product).
        query_emb = _query_model.encode(
            [f"query: {query.strip()}"],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )[0]

        # Step 2: Initial FAISS retrieval with 2x over-fetch.
        _, hit_ids = index.search(np.asarray([query_emb], dtype="float32"), top_k * 2)

        # Step 3: Merge neighboring chunks into windows.
        # FAISS pads results with id -1 when the index holds fewer than k
        # vectors. Skip those (and any id outside `chunks`) instead of letting
        # -1 silently turn into chunks[0]. Identical windows are kept once.
        windows = []
        seen = set()
        for raw_id in hit_ids[0]:
            idx = int(raw_id)
            if idx < 0 or idx >= len(chunks):
                continue
            window = " ".join(chunks[max(0, idx - 1): idx + 2])
            if window not in seen:
                seen.add(window)
                windows.append(window)

        if not windows:
            return []

        # Step 4: Re-rank windows in a single batched encode call. Windows get
        # the "passage: " prefix, which E5 expects for documents and which
        # matches how passages are embedded for the index in __main__.
        window_vecs = _query_model.encode(
            [f"passage: {w}" for w in windows],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )
        scores = cosine_similarity(query_emb.reshape(1, -1), window_vecs)[0]
        best_first = np.argsort(scores)[::-1][:top_k]

        # Step 5: Return top-ranked merged windows.
        return [windows[i] for i in best_first]

    except Exception as e:
        print(f"⚠️ Retrieval error: {e}")
        return []


# ==========================================================
# 6️⃣ Answer Generation Function
# ==========================================================
def generate_answer(query: str, retrieved_chunks: list):
    """Generate factual, context-grounded answers using OpenAI or fallback FLAN-T5.

    The retrieved chunks are numbered (``"[Chunk 1]: ..."``) and substituted,
    together with the question, into ``PROMPT_TEMPLATE``.

    OpenAI ``gpt-4o-mini`` is tried first when ``USE_OPENAI`` is set. The
    local FLAN-T5 pipeline (``_answer_model``) is used instead if OpenAI is
    disabled, raises, or returns an empty reply.

    Args:
        query: The user's question.
        retrieved_chunks: Context passages, e.g. the output of ``retrieve_chunks``.

    Returns:
        The generated answer text. When there is no context, or when every
        generator fails, a human-readable message is returned (no exception
        is raised).
    """
    if not retrieved_chunks:
        return "Sorry, I couldn’t find relevant information in the document."

    # Build full context
    context = "\n\n".join(
        f"[Chunk {i}]: {chunk.strip()}"
        for i, chunk in enumerate(retrieved_chunks, start=1)
    )
    prompt = PROMPT_TEMPLATE.format(context=context, query=query)

    # --- Try OpenAI first ---
    if USE_OPENAI:
        try:
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": "You are a precise enterprise document assistant."},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.4,
                max_tokens=800,
            )
            # `content` may be None (e.g. a refusal) or blank. In that case,
            # fall back explicitly rather than crash on `.strip()` or return "".
            answer = (response.choices[0].message.content or "").strip()
            if answer:
                return answer
            print("⚠️ OpenAI returned an empty answer. Switching to fallback...")
        except Exception as e:
            print(f"⚠️ OpenAI generation failed: {e}. Switching to fallback...")

    # --- Fallback to FLAN-T5 ---
    if _answer_model is None:
        return "⚠️ Error: Fallback model not available."
    try:
        # Greedy decoding. No `temperature` is passed: it is ignored when
        # do_sample=False and only makes transformers emit a warning.
        result = _answer_model(prompt, max_new_tokens=600, do_sample=False)
        return result[0]["generated_text"].strip()
    except Exception as e:
        print(f"⚠️ Fallback model failed: {e}")
        return "⚠️ Error: Both OpenAI and fallback generation failed."


# ==========================================================
# 7️⃣ Local Test
# ==========================================================
if __name__ == "__main__":
    # Smoke test: index three toy chunks, retrieve for one question, then answer it.
    dummy_chunks = [
        "Step 1: Open the dashboard and navigate to reports.",
        "Step 2: Click 'Export' to download a CSV summary.",
        "Step 3: Review the generated report in your downloads folder."
    ]
    from vectorstore import build_faiss_index

    # Each passage is embedded on its own with the "passage: " prefix, which
    # mirrors the "query: " prefix retrieve_chunks puts on the question.
    # NOTE(review): assumes build_faiss_index accepts a list of 1-D vectors.
    passage_vectors = []
    for text in dummy_chunks:
        vec = _query_model.encode(
            [f"passage: {text}"],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )[0]
        passage_vectors.append(vec)
    index = build_faiss_index(passage_vectors)

    query = "What are the steps to export a report?"
    retrieved = retrieve_chunks(query, index, dummy_chunks)
    print("🔍 Retrieved:", retrieved)
    print("💬 Answer:", generate_answer(query, retrieved))