"""
qa.py — Phi-2 Fast + Smart Reasoning Mode (Hybrid)
-------------------------------------------------
✅ Uses intfloat/e5-small-v2 for embeddings
✅ Uses microsoft/phi-2 (fast CPU inference, reduced precision)
✅ Reasoning Mode toggle integrated cleanly
✅ Retrieval unaffected by reasoning mode
"""

import os
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

print("✅ qa.py (Phi-2 Hybrid) loaded from:", __file__)

# ==========================================================
# 1️⃣ Cache Setups
# ==========================================================
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ.update({
    "HF_HOME": CACHE_DIR,
    "TRANSFORMERS_CACHE": CACHE_DIR,
    "HF_DATASETS_CACHE": CACHE_DIR,
    "HF_MODULES_CACHE": CACHE_DIR
})

# ==========================================================
# 2️⃣ Embedding Model
# ==========================================================
try:
    _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
    print("✅ Loaded embedding model: intfloat/e5-small-v2")
except Exception as e:
    print(f"⚠️ Fallback to MiniLM due to {e}")
    _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR)

# ==========================================================
# 3️⃣ Phi-2 LLM Setup (Optimized for CPU)
# ==========================================================
try:
    MODEL_NAME = "microsoft/phi-2"
    print(f"✅ Loading LLM: {MODEL_NAME} (optimized for reasoning)")

    _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
    _model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        cache_dir=CACHE_DIR,
        # Inference runs on CPU (see .to("cpu") below), so always load in bfloat16;
        # float16 generation on CPU is unsupported, so the CUDA check was misleading here.
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    ).to("cpu")

    _answer_model = pipeline(
        "text-generation",
        model=_model,
        tokenizer=_tokenizer,
        device=-1,  # CPU; model_kwargs are unnecessary since the model is already instantiated
    )

    print("✅ Phi-2 text-generation pipeline ready.")

except Exception as e:
    print(f"⚠️ Phi-2 load failed: {e}")
    _answer_model = None

# ==========================================================
# 4️⃣ Prompt Templates
# ==========================================================
STRICT_PROMPT = (
    "Answer based ONLY on the context below.\n"
    "If the answer isn’t in the context, say: 'I don't know based on the provided document.'\n\n"
    "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
)

REASONING_PROMPT = (
    "You are an expert assistant with strong reasoning skills.\n"
    "Think step by step and form a detailed, logical answer.\n"
    "You can combine hints from the context with your general understanding.\n"
    "If the context doesn't mention the answer, acknowledge that.\n\n"
    "Context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
)

# ==========================================================
# 5️⃣ Chunk Retrieval (Unchanged — Fast)
# ==========================================================
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
    """Fast FAISS retrieval using cosine similarity."""
    if not index or not chunks:
        return []

    try:
        q_emb = _query_model.encode(
            [f"query: {query.strip()}"],
            convert_to_numpy=True,
            normalize_embeddings=True
        )[0]
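        # Over-fetch (top_k * 2) so that after neighbor expansion and de-duplication
        # there are still enough distinct chunks to return.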
        distances, indices = index.search(np.array([q_emb]).astype("float32"), top_k * 2)

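        # Expand each hit to include its immediate neighbors so that content split
        # across adjacent chunks comes back together, in document order.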
        selected = set()
        for idx in indices[0]:
            for i in range(max(0, idx - 1), min(len(chunks), idx + 2)):
                selected.add(i)

        return [chunks[i] for i in sorted(selected)]
    except Exception as e:
        print(f"⚠️ Retrieval error: {e}")
        return []

# ==========================================================
# 6️⃣ Answer Generation (Enhanced — Balanced Reasoning + Speed)
# ==========================================================
def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = False):
    """
    Generate answers with Phi-2.
    - reasoning_mode=False → strict factual, fast
    - reasoning_mode=True  → analytical, richer reasoning (slower)
    """
    if _answer_model is None:
        return "⚠️ Error: The language model is not available (Phi-2 failed to load)."

    if not retrieved_chunks:
        return "Sorry, I couldn’t find relevant information in the document."

    context = "\n".join(chunk.strip() for chunk in retrieved_chunks)

    # Use the module-level prompt templates (section 4️⃣) instead of redefining them per call.

    prompt = (REASONING_PROMPT if reasoning_mode else STRICT_PROMPT).format(context=context, query=query)

    try:
        if reasoning_mode:
            # 🧩 Analytical config: larger token budget for step-by-step answers.
            # Greedy decoding keeps output deterministic; temperature is omitted
            # because it has no effect when do_sample=False.
            result = _answer_model(
                prompt,
                max_new_tokens=180,
                do_sample=False,
                pad_token_id=_tokenizer.eos_token_id,
            )
        else:
            # ⚡ Fast factual config: shorter budget, same greedy decoding.
            result = _answer_model(
                prompt,
                max_new_tokens=120,
                do_sample=False,
                pad_token_id=_tokenizer.eos_token_id,
            )

        raw = result[0]["generated_text"].strip()
        if "Answer:" in raw:
            raw = raw.split("Answer:")[-1].strip()
        return raw

    except Exception as e:
        print(f"⚠️ Generation failed: {e}")
        return "⚠️ Error: Could not generate an answer."


# ==========================================================
# 7️⃣ Local Test
# ==========================================================
if __name__ == "__main__":
    from vectorstore import build_faiss_index
    dummy_chunks = [
        "Step 1: Open the dashboard and navigate to reports.",
        "Step 2: Click 'Export' to download a CSV summary.",
        "Step 3: Review the generated report in your downloads folder."
    ]
    embeddings = [
        _query_model.encode([f"passage: {chunk}"], convert_to_numpy=True, normalize_embeddings=True)[0]
        for chunk in dummy_chunks
    ]
    index = build_faiss_index(embeddings)

    query = "What are the steps to export a report?"
    retrieved = retrieve_chunks(query, index, dummy_chunks)

    print("\n--- Strict Mode ---")
    print(generate_answer(query, retrieved, reasoning_mode=False))

    print("\n--- Reasoning Mode ---")
    print(generate_answer(query, retrieved, reasoning_mode=True))