File size: 1,230 Bytes
2615cde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36

import os
import numpy as np
import pandas as pd

# Embedding backend. NOTE: this import runs eagerly at module load (it is not lazy).
from sentence_transformers import SentenceTransformer

# Runtime configuration. Each setting is read from the environment variable of
# the same name; the second argument is the fallback used when it is unset.

# Input CSV holding the records to embed.
DATASET_CSV = os.environ.get("DATASET_CSV", "cars1200_text_dataset.csv")
# Output path for the embedding matrix (written with np.save).
EMB_PATH = os.environ.get("EMB_PATH", "embeddings.npy")
# Output path for the CSV of identifying columns.
ID_PATH = os.environ.get("ID_PATH", "ids.csv")
# Name passed to SentenceTransformer to select the embedding model.
MODEL_NAME = os.environ.get(
    "MODEL_NAME", "sentence-transformers/all-MiniLM-L6-v2"
)
# Column of DATASET_CSV whose text is embedded.
TEXT_COL = os.environ.get("TEXT_COL", "text_record")

def main():
    """Embed the dataset's text column and save embeddings plus ID columns.

    Reads ``DATASET_CSV``, encodes every value of ``TEXT_COL`` with the
    SentenceTransformer model named by ``MODEL_NAME`` (requesting normalized
    embeddings), and writes:

    * the embeddings as a float32 array to ``EMB_PATH`` via ``np.save``;
    * the ``name``/``make``/``model``/``trim``/``year`` columns to ``ID_PATH``
      as CSV, in the same row order as the embeddings.

    Raises:
        FileNotFoundError: If ``DATASET_CSV`` does not exist.
        KeyError: If ``TEXT_COL`` or any of the ID columns is missing from
            the dataset.
    """
    id_cols = ["name", "make", "model", "trim", "year"]

    if not os.path.exists(DATASET_CSV):
        raise FileNotFoundError(f"Dataset not found: {DATASET_CSV}")
    df = pd.read_csv(DATASET_CSV)
    if TEXT_COL not in df.columns:
        raise KeyError(f"Column '{TEXT_COL}' not found in {DATASET_CSV}.")

    # Validate the ID columns *before* the expensive model load and encode.
    # Otherwise a missing column only surfaces after the embeddings file has
    # already been written, leaving it without a matching ids file.
    missing = [col for col in id_cols if col not in df.columns]
    if missing:
        raise KeyError(f"Column(s) {missing} not found in {DATASET_CSV}.")

    print(f"Loading model: {MODEL_NAME}")
    model = SentenceTransformer(MODEL_NAME)

    # Fill missing text with "" first: astype(str) alone would turn NaN into
    # the literal string "nan", which would then be embedded as a real word.
    # Rows are kept (not dropped) so embeddings stay aligned with the ids.
    texts = df[TEXT_COL].fillna("").astype(str).tolist()
    print(f"Encoding {len(texts)} records...")
    embs = model.encode(
        texts,
        batch_size=256,
        show_progress_bar=True,
        normalize_embeddings=True,
    )
    embs = np.asarray(embs, dtype="float32")

    np.save(EMB_PATH, embs)
    df[id_cols].to_csv(ID_PATH, index=False)
    print(f"Saved embeddings to {EMB_PATH} and ids to {ID_PATH}")

# Run the embedding pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()