File size: 2,343 Bytes
5240a76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import lancedb
import torch
import pyarrow as pa
import pandas as pd
from pathlib import Path
import tqdm
import numpy as np

from sentence_transformers import SentenceTransformer


# --- configuration -----------------------------------------------------------
EMB_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"  # embedding model id
DB_TABLE_NAME = "ChunkedBigIndexSEM"  # name of the LanceDB table
VECTOR_COLUMN_NAME = "vector"  # column holding the embedding
TEXT_COLUMN_NAME = "text"  # column holding the source chunk text
INPUT_DIR = 'semchunksSEN'  # directory containing the chunk files
batch_size = 32  # number of chunks passed to model.encode per call

# Connect to the LanceDB database stored at ./.lancedb.
db = lancedb.connect(".lancedb")

# Load the embedding model and put it in inference mode; nn.Module.eval()
# returns the module itself, so the call can be chained onto the constructor.
model = SentenceTransformer(EMB_MODEL_NAME).eval()

# Pick the compute device for encoding: Apple-silicon GPU (MPS) first,
# then an NVIDIA GPU (CUDA), and fall back to the CPU otherwise.
device = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

# Arrow schema for the table: one fixed-size float32 embedding per row plus
# the chunk text it was computed from. The vector width is read from the
# loaded model instead of being hard-coded (all-MiniLM-L6-v2 yields 384), so
# changing EMB_MODEL_NAME cannot silently leave a schema whose width no
# longer matches the embeddings being inserted.
EMB_DIM = model.get_sentence_embedding_dimension()
if EMB_DIM is None:
    raise ValueError(
        f"cannot determine the embedding dimension of model {EMB_MODEL_NAME}"
    )

schema = pa.schema(
    [
        pa.field(VECTOR_COLUMN_NAME, pa.list_(pa.float32(), EMB_DIM)),
        pa.field(TEXT_COLUMN_NAME, pa.string()),
    ]
)
# mode="overwrite" replaces any existing table with the same name, so each
# run starts from an empty table.
tbl = db.create_table(DB_TABLE_NAME, schema=schema, mode="overwrite")

input_dir = Path(INPUT_DIR)
# rglob("*") yields directories as well as files; opening a directory raises
# IsADirectoryError, so keep regular files only. Sorting makes the ingestion
# order (and therefore the row order in the table) reproducible across runs.
files = sorted(p for p in input_dir.rglob("*") if p.is_file())

# Split every file into blank-line-separated paragraphs. Each paragraph
# becomes one entry in `sentences`, with its stripped lines joined by single
# spaces (no trailing whitespace).
sentences = []
for file in files:
    paragraph_lines = []
    # Explicit encoding: the default is locale-dependent and differs by platform.
    with open(file, encoding="utf-8") as f:
        for line in f:
            stripped = line.strip()
            if stripped:
                paragraph_lines.append(stripped)
            elif paragraph_lines:
                # A blank line closes the current paragraph.
                sentences.append(" ".join(paragraph_lines))
                paragraph_lines = []

    # Flush the last paragraph when the file does not end with a blank line.
    if paragraph_lines:
        sentences.append(" ".join(paragraph_lines))

# Embed the chunks batch by batch and append each batch to the table.
# normalize_embeddings=True asks sentence-transformers for unit-length
# vectors. A batch that fails to encode or insert is reported and skipped, so
# one bad batch does not abort the whole run; its chunks are then missing
# from the table.
for i, start in enumerate(tqdm.tqdm(range(0, len(sentences), batch_size))):
    batch = [sent for sent in sentences[start:start + batch_size] if sent]
    if not batch:
        continue
    try:
        encoded = model.encode(batch, normalize_embeddings=True, device=device)

        df = pd.DataFrame({
            VECTOR_COLUMN_NAME: [list(vec) for vec in encoded],
            TEXT_COLUMN_NAME: batch,
        })

        tbl.add(df)
    except Exception as e:
        # tqdm.write prints above the progress bar instead of garbling it,
        # which a bare print() inside the tqdm loop does.
        tqdm.tqdm.write(f"batch {i} was skipped")
        tqdm.tqdm.write(str(e))

# Build an IVF-PQ approximate-nearest-neighbour index on the vector column
# (see https://lancedb.github.io/lancedb/ann_indexes/). For a corpus this
# small the index is not really needed, since a brute-force scan is fast
# enough, but it is created for demonstration purposes.
#   num_partitions=256 : number of IVF partitions (clusters)
#   num_sub_vectors=96 : number of PQ sub-vectors; the vector width must be
#                        divisible by it (384 / 96 = 4 dimensions each)
# NOTE(review): IVF training presumably needs at least num_partitions rows
# in the table; confirm behaviour for very small inputs.
tbl.create_index(
    num_partitions=256,
    num_sub_vectors=96,
    vector_column_name=VECTOR_COLUMN_NAME,
)