File size: 2,361 Bytes
80b95e8 eae02ee 80b95e8 eae02ee 80b95e8 eae02ee 80b95e8 eae02ee 80b95e8 eae02ee 80b95e8 eae02ee 80b95e8 eae02ee 80b95e8 eae02ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
"""
This script handles document embedding using EmbeddingGemma.
This is the entry point for indexing documents.
"""
import os
import pickle
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
def embed_documents(path: str, config: dict, output_dir: str = "vector_cache") -> list:
    """
    Embed documents from a directory and save to FAISS index.

    Every regular file directly inside ``path`` is read as UTF-8 text,
    encoded with the configured SentenceTransformer model, and added to an
    inner-product FAISS index built from L2-normalized copies of the
    embeddings (so inner product equals cosine similarity). The index and a
    pickle holding the texts and filenames are written to ``output_dir``.

    Args:
        path (str): Path to the directory containing the documents to embed.
        config (dict): Configuration dictionary. The value at
            ``config["embedding"]["model_path"]`` is passed to
            ``SentenceTransformer``.
        output_dir (str): Directory that receives ``faiss_index.bin`` and
            ``metadata.pkl``. Created if missing. Defaults to
            ``"vector_cache"``.

    Returns:
        list: ``(filename, embedding)`` pairs for every embedded document, in
        the same order as the rows of the FAISS index. The embeddings
        returned are the raw model outputs, not the normalized copies stored
        in the index. An empty list is returned when the model cannot be
        initialized or when no document was embedded.
    """
    model_path = config["embedding"]["model_path"]
    try:
        model = SentenceTransformer(model_path)
    except (OSError, ValueError) as e:
        # A missing or unreadable model path / failed download usually
        # surfaces as an OSError subclass rather than ValueError; both are
        # reported the same way.
        print(f"Error initializing embedding model: {e}")
        return []
    print(f"Initalized embedding model: {model_path}")

    embeddings = []
    texts = []
    filenames = []

    # Read all documents. os.listdir() returns entries in arbitrary order;
    # sorting keeps FAISS row ids and metadata order reproducible across runs.
    for fname in sorted(os.listdir(path)):
        fpath = os.path.join(path, fname)
        if not os.path.isfile(fpath):
            continue

        try:
            with open(fpath, "r", encoding="utf-8") as f:
                text = f.read()
        except (OSError, UnicodeError) as e:
            print(f"Error reading file {fpath}: {e}")
            continue

        if not text.strip():  # Only process non-empty files
            continue

        try:
            emb = model.encode(text)
        except Exception as e:
            # Best effort: one document that fails to embed must not abort
            # the whole indexing run.
            print(f"Error embedding file {fpath}: {e}")
            continue

        embeddings.append(emb)
        texts.append(text)
        filenames.append(fname)

    if not embeddings:
        print("No documents were successfully embedded.")
        return []

    # Stack into a fresh float32 matrix; normalize_L2 works in place, so the
    # per-document arrays returned to the caller stay untouched.
    embeddings_matrix = np.asarray(embeddings, dtype="float32")
    faiss.normalize_L2(embeddings_matrix)

    # Create FAISS index and add the normalized embeddings
    index = faiss.IndexFlatIP(embeddings_matrix.shape[1])
    index.add(embeddings_matrix)

    # Save FAISS index and metadata
    os.makedirs(output_dir, exist_ok=True)
    faiss.write_index(index, os.path.join(output_dir, "faiss_index.bin"))
    # NOTE: metadata is stored with pickle; only load it back from a
    # location you trust.
    with open(os.path.join(output_dir, "metadata.pkl"), "wb") as f:
        pickle.dump({"texts": texts, "filenames": filenames}, f)

    print(f"Saved FAISS index to {output_dir}/ with {len(embeddings)} documents.")
    print(f"Total embeddings created: {len(embeddings)}")
    return list(zip(filenames, embeddings))
|