# LLM / rag.py
# anaghanagesh's picture
# Update rag.py
# bb4c3e3 verified
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
# Shared sentence-embedding model, created once at import time so that every
# call to retrieve_context() reuses the same instance.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
def retrieve_context(query, papers, top_k=2):
    """Return the abstracts most similar to *query*, joined by newlines.

    Every non-empty ``"abstract"`` found in *papers* is embedded with the
    module-level ``model``, indexed in an exact L2 FAISS index, and the
    nearest neighbours of the embedded *query* are returned, closest first.

    Args:
        query: Text to search for.
        papers: Iterable of dict-like records. Records with a missing or
            empty ``"abstract"`` are skipped. ``None`` is treated as empty.
        top_k: Maximum number of abstracts to return (default 2). It is
            capped at the number of available abstracts.

    Returns:
        The selected abstracts joined with ``"\\n"``, or ``""`` when there is
        nothing to search, ``top_k`` is not positive, or retrieval fails.
    """
    try:
        # .get() so one record without an "abstract" key does not abort the
        # whole retrieval.
        texts = [
            abstract
            for abstract in (paper.get("abstract") for paper in papers or ())
            if abstract
        ]
        if not texts or top_k <= 0:
            return ""

        # FAISS works on contiguous float32 matrices; convert once.
        embeddings = np.ascontiguousarray(model.encode(texts), dtype="float32")
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)

        query_embedding = np.ascontiguousarray(
            model.encode([query]), dtype="float32"
        )

        # Never request more neighbours than there are vectors: FAISS pads
        # missing results with -1, and texts[-1] would silently return the
        # last abstract (duplicating it when only one abstract exists).
        k = min(top_k, len(texts))
        _, indices = index.search(query_embedding, k)

        # Defensive: skip any -1 padding that may still be present.
        retrieved = [texts[i] for i in indices[0] if i >= 0]
        return "\n".join(retrieved)
    except Exception as e:
        # Retrieval is best-effort: report the problem and fall back to no
        # context rather than failing the caller.
        print("RAG Error:", e)
        return ""