"""Embed blank-line-separated text blocks and store them in a LanceDB table.

Every file under ``INPUT_DIR`` is split into blocks at blank lines, each block
is embedded with ``EMB_MODEL_NAME``, and the (vector, text) pairs are written
to the ``DB_TABLE_NAME`` table, which is recreated on every run.
"""
from pathlib import Path

import lancedb
import numpy as np
import pandas as pd
import pyarrow as pa
import torch
import tqdm
from sentence_transformers import SentenceTransformer

EMB_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# Width of the fixed-size vector column; must match the embedding size
# produced by EMB_MODEL_NAME (all-MiniLM-L6-v2 emits 384-dim vectors).
EMB_DIM = 384
DB_TABLE_NAME = "ChunkedBigIndexSEM"
VECTOR_COLUMN_NAME = "vector"
TEXT_COLUMN_NAME = "text"
INPUT_DIR = 'semchunksSEN'


def _read_blocks(path):
    """Return the blank-line-separated text blocks found in *path*.

    The stripped, non-empty lines of a block are joined with single spaces.
    Each block keeps one trailing space so the stored text is identical to
    what earlier versions of this script produced.

    Args:
        path: File to read. Decoded as UTF-8 (assumes the chunk files are
            UTF-8 -- TODO confirm for inputs produced on other platforms).

    Returns:
        A list of non-empty block strings, in file order.
    """
    blocks = []
    parts = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            stripped = line.strip()
            if stripped:
                parts.append(stripped)
            elif parts:
                # A blank line closes the block collected so far.
                blocks.append(" ".join(parts) + " ")
                parts = []
    # Flush the last block when the file does not end with a blank line.
    if parts:
        blocks.append(" ".join(parts) + " ")
    return blocks


db = lancedb.connect(".lancedb")  # db location
batch_size = 32

model = SentenceTransformer(EMB_MODEL_NAME)
model.eval()

# Prefer Apple MPS, then CUDA, and fall back to the CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

schema = pa.schema(
    [
        pa.field(VECTOR_COLUMN_NAME, pa.list_(pa.float32(), EMB_DIM)),
        pa.field(TEXT_COLUMN_NAME, pa.string()),
    ]
)
tbl = db.create_table(DB_TABLE_NAME, schema=schema, mode="overwrite")

input_dir = Path(INPUT_DIR)
# rglob("*") also yields directories, which cannot be opened as files; keep
# regular files only and sort them so the ingestion order is reproducible.
files = sorted(p for p in input_dir.rglob("*") if p.is_file())

sentences = []
for file in files:
    sentences.extend(_read_blocks(file))

num_batches = int(np.ceil(len(sentences) / batch_size))
for i in tqdm.tqdm(range(num_batches)):
    # Best effort: a failing batch is reported and skipped so that one bad
    # chunk does not abort the whole ingestion run.
    try:
        start = i * batch_size
        batch = [sent for sent in sentences[start:start + batch_size] if sent]
        encoded = model.encode(batch, normalize_embeddings=True, device=device)
        encoded = [list(vec) for vec in encoded]

        df = pd.DataFrame({
            VECTOR_COLUMN_NAME: encoded,
            TEXT_COLUMN_NAME: batch,
        })

        tbl.add(df)
    except Exception as e:
        print(f"batch {i} was skipped")
        print(e)

# Create an IVF-PQ index: https://lancedb.github.io/lancedb/ann_indexes/
# With a corpus the size of the transformer docs an index is not really
# needed, but it is built here for demonstration purposes.
# 384 dimensions / 96 sub-vectors = 4 dimensions per sub-vector.
# NOTE(review): index training presumably needs a reasonable number of rows
# relative to num_partitions -- confirm behaviour on very small corpora.
tbl.create_index(
    num_partitions=256,
    num_sub_vectors=96,
    vector_column_name=VECTOR_COLUMN_NAME,
)