# my-research-agent / src/vector_store.py
# Builds or loads a FAISS vector store over chunks of scientific papers.
import os
from datasets import load_dataset
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
import torch
# Use /tmp for cache directory which should be writable
# (e.g. on Hugging Face Spaces the app dir is read-only at runtime).
CACHE_DIR = "/tmp/huggingface_cache"  # download cache for the embedding model
FAISS_INDEX_PATH = "/tmp/faiss_index_scientific_papers"  # persisted FAISS index location
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"  # small, fast sentence-embedding model
# Create cache directories
os.makedirs(CACHE_DIR, exist_ok=True)
# NOTE(review): dirname of FAISS_INDEX_PATH is "/tmp", which normally already
# exists — harmless no-op thanks to exist_ok=True.
os.makedirs(os.path.dirname(FAISS_INDEX_PATH), exist_ok=True)
def get_vector_store():
    """Create or load the FAISS vector store over scientific-paper chunks.

    First tries to load a previously saved index from ``FAISS_INDEX_PATH``;
    on any load failure (missing, corrupt, incompatible version) it falls
    back to building a fresh index from the streamed dataset and saving it.

    Returns:
        FAISS: a ready-to-query LangChain FAISS vector store.

    Raises:
        Exception: re-raised after logging if building a new index fails.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL_NAME,
        cache_folder=CACHE_DIR,
        # Prefer GPU when available; CPU works but embeds more slowly.
        model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"},
    )

    # Fast path: reuse a previously built index.
    if os.path.exists(FAISS_INDEX_PATH):
        try:
            return FAISS.load_local(
                FAISS_INDEX_PATH,
                embeddings,
                # FAISS persistence uses pickle; this flag is required by
                # newer LangChain versions and is safe here only because we
                # wrote the index ourselves.
                allow_dangerous_deserialization=True,
            )
        except Exception as e:  # any load failure means "rebuild from scratch"
            print(f"Failed to load existing index: {e}")

    # Slow path: build a new index from the dataset.
    print("Creating new FAISS index...")
    try:
        all_chunks = _load_paper_chunks()
        print(f"Created {len(all_chunks)} document chunks")
        if not all_chunks:
            # FAISS.from_documents fails with an opaque error on an empty
            # list; fail early with a clear message instead.
            raise ValueError("No document chunks produced from the dataset")
        vector_store = FAISS.from_documents(all_chunks, embeddings)
        vector_store.save_local(FAISS_INDEX_PATH)
        print("FAISS index saved successfully")
        return vector_store
    except Exception as e:
        print(f"Error creating vector store: {e}")
        raise


def _load_paper_chunks(num_papers=100, chunk_size=1000, chunk_overlap=100):
    """Stream the first *num_papers* papers and split them into Documents.

    Args:
        num_papers: how many records to take from the streamed 'train' split.
        chunk_size: maximum characters per chunk.
        chunk_overlap: overlapping characters between consecutive chunks.

    Returns:
        list[Document]: one Document per chunk, tagged with its paper id.

    NOTE(review): assumes every record has 'full_text' and 'id' fields —
    TODO confirm against the franz96521/scientific_papers schema.
    """
    # streaming=True avoids downloading the full dataset just for a prefix.
    dataset = load_dataset(
        "franz96521/scientific_papers", split="train", streaming=True
    )
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    chunks = []
    for paper in dataset.take(num_papers):
        for piece in splitter.split_text(paper["full_text"]):
            chunks.append(
                Document(page_content=piece, metadata={"paper_id": paper["id"]})
            )
    return chunks