# DevAssist / smebuilder_vector.py
# Hugging Face file-viewer metadata: last commit "Updated" (9fee2b7) by alaselababatunde, 1.86 kB.
import os
import pandas as pd
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.schema import Document
# ----------------- CONFIG -----------------
# Fixed settings: the CSV that is indexed and the Chroma collection name.
DATASET_PATH = "sme_builder_dataset.csv"
COLLECTION_NAME = "landing_page_generation_examples"

# Settings that can be overridden through environment variables; the second
# argument is the fallback used when the variable is not set.
DB_LOCATION = os.environ.get("CHROMA_DB_DIR", "./Dev_Assist_SME_Builder_DB")
EMBEDDING_MODEL = os.environ.get("HF_EMBEDDING_MODEL", "intfloat/e5-large-v2")
HF_CACHE_DIR = os.environ.get("HF_CACHE_DIR", "/app/huggingface_cache")

# Create both directories up front; exist_ok keeps repeated runs from failing.
os.makedirs(HF_CACHE_DIR, exist_ok=True)
os.makedirs(DB_LOCATION, exist_ok=True)
# ----------------- LOAD DATASET -----------------
# Read the CSV into a DataFrame, failing with an explicit message that names
# the expected path when the file is absent.
if os.path.exists(DATASET_PATH):
    df = pd.read_csv(DATASET_PATH)
else:
    raise FileNotFoundError(f"Dataset file not found: {DATASET_PATH}")
# ----------------- EMBEDDINGS -----------------
# Embedding model used both for indexing documents and for embedding queries.
# cache_folder points the model download at HF_CACHE_DIR: that directory is
# configurable and created at import time, but was previously never handed to
# the embeddings class, so the setting had no effect.
embeddings = HuggingFaceEmbeddings(
    model_name=EMBEDDING_MODEL,
    cache_folder=HF_CACHE_DIR,
)
# ----------------- VECTOR STORE -----------------
# Dataset columns concatenated, in this order, into the text that is embedded.
_TEXT_COLUMNS = ("prompt", "html_code", "css_code", "js_code", "sector")


def _row_to_text(row):
    """Join one dataset row's text columns into a single space-separated string.

    Columns absent from the dataset and cells that are missing (None/NaN,
    which is how pandas represents empty CSV fields) are skipped, so the
    literal text "nan" never ends up in the embedded content.
    """
    parts = []
    for column in _TEXT_COLUMNS:
        value = row.get(column)
        if value is None or pd.isna(value):
            continue
        text = str(value)
        if text:
            parts.append(text)
    return " ".join(parts).strip()


vector_store = Chroma(
    collection_name=COLLECTION_NAME,
    persist_directory=DB_LOCATION,
    embedding_function=embeddings,
)

# Only add documents if the collection is empty. Asking the collection itself
# (rather than checking whether DB_LOCATION contains files) also covers the
# case where the directory already has files but no documents were stored,
# e.g. after an earlier run was interrupted before indexing finished.
add_documents = vector_store._collection.count() == 0

if add_documents:
    documents = []
    for i, row in df.iterrows():
        content = _row_to_text(row)
        if not content:
            # A row with no text at all has nothing to embed or retrieve.
            continue
        # The DataFrame index is kept as metadata so a hit can be traced
        # back to its source row.
        documents.append(Document(page_content=content, metadata={"id": str(i)}))
    if documents:
        vector_store.add_documents(documents=documents)
# ----------------- RETRIEVER -----------------
# Module-level retriever over the vector store, configured with k=20 so a
# query asks for up to 20 matching documents.
retriever = vector_store.as_retriever(search_kwargs=dict(k=20))

# Startup message reporting how many documents the collection currently holds.
print(f"Vector store ready with {vector_store._collection.count()} documents.")