DevAssist / smebuilder_vector.py
alaselababatunde's picture
Updated
82fd433
raw
history blame
1.91 kB
import os
import pandas as pd
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
# ----------------- CONFIG -----------------
# Static settings plus environment-overridable locations for the Chroma
# persistence directory, the embedding model, and the Hugging Face cache.
DATASET_PATH = "sme_builder_dataset.csv"
COLLECTION_NAME = "landing_page_generation_examples"
DB_LOCATION = os.getenv("CHROMA_DB_DIR", "./Dev_Assist_SME_Builder_DB")
EMBEDDING_MODEL = os.getenv("HF_EMBEDDING_MODEL", "intfloat/e5-large-v2")
HF_CACHE_DIR = os.getenv("HF_CACHE_DIR", "/app/huggingface_cache")

# Ensure both working directories exist (cache first, then the DB dir).
for _required_dir in (HF_CACHE_DIR, DB_LOCATION):
    os.makedirs(_required_dir, exist_ok=True)
# ----------------- LOAD DATASET -----------------
# Read the example dataset; fail fast with an explicit message when the CSV
# is missing instead of surfacing pandas' own error.
if os.path.exists(DATASET_PATH):
    df = pd.read_csv(DATASET_PATH)
else:
    raise FileNotFoundError(f"Dataset file not found: {DATASET_PATH}")
# ----------------- EMBEDDINGS -----------------
# Embedding model handed to Chroma below as its embedding_function, so it
# embeds both the indexed documents and the retriever's queries.
# HF_CACHE_DIR is created in the config section; pass it explicitly so the
# model weights are cached there rather than in the library's default cache.
embeddings = HuggingFaceEmbeddings(
    model_name=EMBEDDING_MODEL,
    cache_folder=HF_CACHE_DIR,
)
# ----------------- VECTOR STORE -----------------
# Open (or create) the persistent Chroma collection.
vector_store = Chroma(
    collection_name=COLLECTION_NAME,
    persist_directory=DB_LOCATION,
    embedding_function=embeddings,
)
# Seed only when this *collection* is empty. Checking the collection itself,
# rather than whether DB_LOCATION contains any files, still seeds correctly
# when the directory holds other collections or when a previous run created
# the database files but failed before inserting any documents.
add_documents = vector_store._collection.count() == 0
if add_documents:
    # Build one Document per CSV row by concatenating the example's text fields.
    documents = []
    for i, row in df.iterrows():
        parts = []
        for column in ("prompt", "html_code", "css_code", "js_code", "sector"):
            value = row.get(column, "")
            # pandas reads empty CSV cells as NaN; str(NaN) would embed the
            # literal text "nan", so missing values contribute nothing.
            if pd.isna(value):
                continue
            parts.append(str(value))
        content = " ".join(parts).strip()
        # Skip rows with no usable text instead of embedding an empty string.
        if not content:
            continue
        documents.append(Document(page_content=content, metadata={"id": str(i)}))
    if documents:
        vector_store.add_documents(documents=documents)
# ----------------- RETRIEVER -----------------
# Expose the store as a retriever that returns the 20 nearest examples.
_search_kwargs = {"k": 20}
retriever = vector_store.as_retriever(search_kwargs=_search_kwargs)

# Report how many documents the collection holds after initialisation.
_doc_count = vector_store._collection.count()
print(f"SME vector store initialized. collection={COLLECTION_NAME}, documents={_doc_count}")