File size: 1,785 Bytes
5e51aba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# smebuilder_vector.py

import os
import pandas as pd
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_core.documents import Document

# ----------------- CONFIG -----------------
# CSV file read with pandas to obtain the example rows.
DATASET_PATH = "sme_builder_dataset.csv"
# Directory passed to Chroma as persist_directory; its presence on disk is
# also what decides whether documents are added when this module runs.
DB_LOCATION = "./Dev_Assist_SME_Builder_DB"
# Collection name passed to the Chroma vector store.
COLLECTION_NAME = "landing_page_generation_examples"
# Model name passed to HuggingFaceEmbeddings.
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

# ----------------- LOAD DATASET -----------------
# Read the CSV when it is present; otherwise stop with an explicit error
# naming the missing path.
if os.path.exists(DATASET_PATH):
    df = pd.read_csv(DATASET_PATH)
else:
    raise FileNotFoundError(f"Dataset file not found: {DATASET_PATH}")

# ----------------- EMBEDDINGS -----------------
embeddings = HuggingFaceEmbeddings(
    model_name=EMBEDDING_MODEL,
)

# Documents are only built and added when the persistence directory is not
# on disk yet.
# NOTE(review): an existing directory is treated as "already populated";
# presumably an empty or half-written store would never be refilled — confirm.
add_documents = not os.path.exists(DB_LOCATION)

# ----------------- CREATE DOCUMENTS -----------------
# Columns whose text is concatenated, in this order, into each document.
_TEXT_COLUMNS = ("prompt", "html_code", "css_code", "js_code", "sector")

# Parallel lists: documents[n] is stored under ids[n].
documents, ids = [], []
if add_documents:
    for i, row in df.iterrows():
        # pandas reads empty CSV cells as NaN, and str(NaN) would inject the
        # literal word "nan" into the embedded text, so missing values are
        # dropped before joining. Absent columns fall back to "".
        values = (row.get(column, "") for column in _TEXT_COLUMNS)
        page_content = " ".join(
            str(value) for value in values if not pd.isna(value)
        ).strip()

        # A row with no text at all has nothing to embed or retrieve.
        if not page_content:
            continue

        # The DataFrame index label doubles as the document id.
        doc_id = str(i)
        documents.append(Document(page_content=page_content, id=doc_id))
        ids.append(doc_id)

# ----------------- VECTOR STORE -----------------
# Chroma store bound to the configured collection, persistence directory
# and embedding function.
vector_store = Chroma(
    embedding_function=embeddings,
    collection_name=COLLECTION_NAME,
    persist_directory=DB_LOCATION,
)

# Insert only when document building was enabled and produced something.
if add_documents:
    if documents:
        vector_store.add_documents(documents=documents, ids=ids)

# ----------------- RETRIEVER -----------------
# Retriever over the store, configured with k=20.
retriever = vector_store.as_retriever(search_kwargs=dict(k=20))