"""
qa.py — Retrieval + Generation (Phi-2 Fast Reasoning)
-----------------------------------------------------
Uses:
 - intfloat/e5-small-v2 for embeddings
 - microsoft/phi-2 as main LLM (fast, strong reasoning)
 - Fallback: google/flan-t5-base (used if Phi-2 fails to load)
Optimized for CPU inference (Hugging Face Spaces / Streamlit)
"""

import os
import numpy as np
from sentence_transformers import SentenceTransformer

print("✅ qa.py (Phi-2 optimized) loaded from:", __file__)

# ==========================================================
# 1️⃣ Cache Setup
# ==========================================================
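# Hosted platforms such as Hugging Face Spaces may not allow writes to the
# default cache location, so all Hugging Face downloads are redirected to
# /tmp. These variables must be set before any model is downloaded.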
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ.update({
    "HF_HOME": CACHE_DIR,
    "TRANSFORMERS_CACHE": CACHE_DIR,
    "HF_DATASETS_CACHE": CACHE_DIR,
    "HF_MODULES_CACHE": CACHE_DIR
})

# ==========================================================
# 2️⃣ Embedding Model
# ==========================================================
try:
    _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
    print("✅ Loaded embedding model: intfloat/e5-small-v2")
except Exception as e:
    print(f"⚠️ Fallback to MiniLM due to {e}")
    _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR)
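
# Note: E5 is an asymmetric model. Indexed text should be embedded with a
# "passage: " prefix and queries with "query: " (see the intfloat/e5 model
# card); the MiniLM fallback simply treats these prefixes as plain text.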

# ==========================================================
# 3️⃣ Phi-2 LLM Setup
# ==========================================================
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

try:
    MODEL_NAME = "microsoft/phi-2"
    print(f"✅ Loading LLM: {MODEL_NAME}")
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
    _model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        cache_dir=CACHE_DIR,
        torch_dtype="auto",
        low_cpu_mem_usage=True,
    )
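    # Note: torch_dtype="auto" keeps the checkpoint's dtype (float16 for
    # phi-2); on CPU-only hosts, forcing torch.float32 is a common tweak,
    # since half precision is poorly supported on CPU.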
    _answer_model = pipeline(
        "text-generation",
        model=_model,
        tokenizer=_tokenizer,
        device=-1,
        max_new_tokens=250,
        do_sample=False,
    )
    print("✅ Phi-2 generation pipeline ready.")
except Exception as e:
    print(f"⚠️ Phi-2 load failed: {e}")
    try:
        # Fall back to the smaller seq2seq model named in the module
        # docstring; it also runs on CPU and shares the same call interface.
        _answer_model = pipeline(
            "text2text-generation",
            model="google/flan-t5-base",
            device=-1,
        )
        _tokenizer = _answer_model.tokenizer
        print("✅ Fallback pipeline ready: google/flan-t5-base")
    except Exception as fallback_error:
        print(f"⚠️ Fallback load failed: {fallback_error}")
        _answer_model = None

# ==========================================================
# 4️⃣ Prompt Template
# ==========================================================
PROMPT_TEMPLATE = (
    "You are an expert assistant for enterprise document understanding.\n"
    "Use ONLY the context below to answer the question clearly and factually.\n"
    "If the context doesn’t contain the answer, reply: "
    "'I don't know based on the provided document.'\n\n"
    "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
)
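
# A rendered prompt looks like this (illustrative values):
#
#   Context:
#   Step 2: Click 'Export' to download a CSV summary.
#
#   Question: How do I export a report?
#   Answer: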

# ==========================================================
# 5️⃣ Retrieval Function
# ==========================================================
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
    """Fast FAISS retrieval with E5 embeddings."""
    if index is None or not chunks:
        return []

    try:
        # E5 expects a "query: " prefix on search queries (and "passage: "
        # on indexed text), per the intfloat/e5 model card.
        q_emb = _query_model.encode(
            [f"query: {query.strip()}"],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )[0]
        # Over-fetch so neighbor merging below still yields enough hits.
        _distances, indices = index.search(np.array([q_emb]).astype("float32"), top_k * 2)

        # Merge each hit with its immediate neighbors for continuity.
        selected = set()
        for idx in indices[0]:
            if idx < 0:  # FAISS pads results with -1 when the index is small
                continue
            for i in range(max(0, idx - 1), min(len(chunks), idx + 2)):
                selected.add(i)

        # Return the merged chunks in document order.
        return [chunks[i] for i in sorted(selected)]
    except Exception as e:
        print(f"⚠️ Retrieval error: {e}")
        return []

# ==========================================================
# 6️⃣ Answer Generation Function
# ==========================================================
def generate_answer(query: str, retrieved_chunks: list):
    """Generate grounded answers using Phi-2."""
    if not retrieved_chunks:
        return "Sorry, I couldn’t find relevant information in the document."

    context = "\n".join(chunk.strip() for chunk in retrieved_chunks)
    prompt = PROMPT_TEMPLATE.format(context=context, query=query)

    try:
        result = _answer_model(
            prompt,
            max_new_tokens=250,
            do_sample=False,
            pad_token_id=_tokenizer.eos_token_id,
        )
        # The text-generation pipeline returns the prompt plus the
        # completion; keep only the newly generated part.
        text = result[0]["generated_text"]
        if text.startswith(prompt):
            text = text[len(prompt):]
        return text.strip()
    except Exception as e:
        print(f"⚠️ Generation failed: {e}")
        return "⚠️ Error: Could not generate an answer at the moment."

# ==========================================================
# 7️⃣ Local Test (optional)
# ==========================================================
if __name__ == "__main__":
    from vectorstore import build_faiss_index
    dummy_chunks = [
        "Step 1: Open the dashboard and navigate to reports.",
        "Step 2: Click 'Export' to download a CSV summary.",
        "Step 3: Review the generated report in your downloads folder."
    ]

    embeddings = [
        _query_model.encode([f"passage: {chunk}"], convert_to_numpy=True, normalize_embeddings=True)[0]
        for chunk in dummy_chunks
    ]
    index = build_faiss_index(embeddings)
    query = "What are the steps to export a report?"
    retrieved = retrieve_chunks(query, index, dummy_chunks)
    print("🔍 Retrieved:", retrieved)
    print("💬 Answer:", generate_answer(query, retrieved))