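"""Query side of a small RAG pipeline: embed the question with MiniLM,
retrieve matching chunks from a FAISS index (via the local vectorstore
module), and generate an answer with Flan-T5."""
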
# ----------------------------
# Hugging Face cache bootstrap
# ----------------------------
import os

CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)

os.environ["HF_HOME"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
os.environ["HF_MODULES_CACHE"] = CACHE_DIR

print(f"✅ Using Hugging Face cache at {CACHE_DIR}")

# ----------------------------
# Imports AFTER cache bootstrap
# ----------------------------
from sentence_transformers import SentenceTransformer
from transformers import pipeline
from vectorstore import search_faiss

# ----------------------------
# Query embedding model
# ----------------------------
# all-MiniLM-L6-v2 produces 384-dimensional embeddings; the FAISS index must
# have been built with the same model so query and chunk vectors are comparable.
_query_model = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",
    cache_folder=CACHE_DIR
)

# ----------------------------
# LLM for answers
# ----------------------------
MODEL_NAME = "google/flan-t5-small"

# cache_dir is not a pipeline() argument; loader options must go through
# model_kwargs, which pipeline() forwards to from_pretrained().
_answer_model = pipeline(
    "text2text-generation",
    model=MODEL_NAME,
    model_kwargs={"cache_dir": CACHE_DIR}
)

# ----------------------------
# Functions
# ----------------------------
def retrieve_chunks(query, index, chunks, top_k=3):
    """Embed the query and return the top_k most similar chunks from the index."""
    q_emb = _query_model.encode([query], convert_to_numpy=True)[0]
    return search_faiss(q_emb, index, chunks, top_k)

def generate_answer(query, retrieved_chunks):
    """Build a prompt from the retrieved chunks and answer the query with Flan-T5."""
    if not retrieved_chunks:
        return "Sorry, I could not find relevant information."

    context = " ".join(retrieved_chunks)
    prompt = (
        "You are an assistant. Use the context to answer the question clearly.\n"
        f"Context:\n{context}\n\nQuestion:\n{query}\n\nAnswer:"
    )
    # Greedy decoding (do_sample=False) keeps answers deterministic;
    # max_length bounds the generated sequence.
    result = _answer_model(prompt, max_length=300, do_sample=False)
    return result[0]["generated_text"].strip()
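
# ----------------------------
# Example usage (illustrative sketch)
# ----------------------------
# A minimal end-to-end check, assuming vectorstore.search_faiss accepts a plain
# faiss index plus the parallel chunk list (as the retrieve_chunks call
# suggests). The sample chunks and the inline index build below are
# illustrative only; in the real pipeline the index comes from ingestion.
if __name__ == "__main__":
    import faiss

    chunks = [
        "The capital of France is Paris.",
        "FAISS is a library for efficient similarity search over dense vectors.",
        "Flan-T5 is an instruction-tuned sequence-to-sequence model.",
    ]
    # Embed the chunks with the same model used for queries, then index them.
    embs = _query_model.encode(chunks, convert_to_numpy=True)
    index = faiss.IndexFlatL2(embs.shape[1])
    index.add(embs)

    question = "What is FAISS used for?"
    hits = retrieve_chunks(question, index, chunks, top_k=2)
    print(generate_answer(question, hits))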