Spaces:
Runtime error
Runtime error
| import lancedb | |
| import torch | |
| import pyarrow as pa | |
| import pandas as pd | |
| from pathlib import Path | |
| import tqdm | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
# --- Configuration -----------------------------------------------------------
# Sentence-embedding model used to vectorise every text chunk.
EMB_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# Name of the LanceDB table and of its two columns.
DB_TABLE_NAME = "ChunkedBigIndexSEM"
VECTOR_COLUMN_NAME = "vector"
TEXT_COLUMN_NAME = "text"
# Directory that is scanned recursively for the input text files.
INPUT_DIR = "semchunksSEN"
# Number of text chunks embedded and written to the table per iteration.
batch_size = 32

# --- Runtime objects ---------------------------------------------------------
# Database location (a path relative to the current working directory).
db = lancedb.connect(".lancedb")

# Load the embedding model and put it in evaluation mode.
model = SentenceTransformer(EMB_MODEL_NAME)
model.eval()
# Choose the compute device for encoding: MPS is checked first, then CUDA,
# and the CPU is the fallback when neither accelerator is available.
device = (
    "mps"
    if torch.backends.mps.is_available()
    else ("cuda" if torch.cuda.is_available() else "cpu")
)
# Derive the vector width from the loaded model instead of hard-coding it, so
# the table schema stays valid if EMB_MODEL_NAME is ever changed.
embedding_dim = model.get_sentence_embedding_dimension()
if embedding_dim is None:
    raise RuntimeError(
        f"cannot determine the embedding dimension of {EMB_MODEL_NAME!r}"
    )

# Two columns: a fixed-size float32 vector and the text it was computed from.
schema = pa.schema(
    [
        pa.field(VECTOR_COLUMN_NAME, pa.list_(pa.float32(), embedding_dim)),
        pa.field(TEXT_COLUMN_NAME, pa.string()),
    ]
)

# mode="overwrite": every run starts from a fresh table of this name.
tbl = db.create_table(DB_TABLE_NAME, schema=schema, mode="overwrite")
def _read_paragraphs(path):
    """Return the blank-line-separated paragraphs of the text file at *path*.

    Each paragraph is built from its non-empty lines, stripped of surrounding
    whitespace and joined with single spaces. A trailing space is kept on
    every paragraph so the stored text matches what earlier runs produced.
    A paragraph that is not followed by a blank line (end of file) is still
    returned.
    """
    paragraphs = []
    current = []
    with open(path, encoding="utf-8") as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            if stripped:
                current.append(stripped)
            elif current:
                # A blank line closes the paragraph collected so far.
                paragraphs.append(" ".join(current) + " ")
                current = []
    if current:
        paragraphs.append(" ".join(current) + " ")
    return paragraphs


input_dir = Path(INPUT_DIR)
# rglob("*") yields directories as well as files, and open() on a directory
# raises, so keep regular files only. Sorting makes the row order reproducible.
files = sorted(path for path in input_dir.rglob("*") if path.is_file())

# Despite the name, each entry is a whole paragraph (one text chunk).
sentences = []
for file in files:
    sentences.extend(_read_paragraphs(file))
# Embed the collected chunks batch by batch and append each batch to the
# table. A failing batch is reported and skipped so one bad batch does not
# abort the whole run.
batch_starts = range(0, len(sentences), batch_size)
for i, start in enumerate(tqdm.tqdm(batch_starts)):
    try:
        batch = [chunk for chunk in sentences[start:start + batch_size] if chunk]
        embeddings = model.encode(batch, normalize_embeddings=True, device=device)
        rows = pd.DataFrame(
            {
                VECTOR_COLUMN_NAME: [list(vector) for vector in embeddings],
                TEXT_COLUMN_NAME: batch,
            }
        )
        tbl.add(rows)
    except Exception as e:
        print(f"batch {i} was skipped")
        print(e)
# Create an IVF-PQ index: https://lancedb.github.io/lancedb/ann_indexes/
# Per the original author's note, at this corpus size an index is not really
# needed; it is built for demonstration purposes.
# NOTE(review): index training presumably needs enough rows for the requested
# number of partitions - confirm the table is large enough before relying on it.
NUM_PARTITIONS = 256
NUM_SUB_VECTORS = 96
tbl.create_index(
    num_partitions=NUM_PARTITIONS,
    num_sub_vectors=NUM_SUB_VECTORS,
    vector_column_name=VECTOR_COLUMN_NAME,
)