import os
import sys
import glob
from typing import List, Generator
from bs4 import BeautifulSoup
import google.generativeai as genai
from qdrant_client import QdrantClient, models
# Add the project root to the Python path to allow importing from `src`
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
from src.core.config import settings
# --- Configuration ---
# Gemini embedding model identifier passed to genai.embed_content.
EMBEDDING_MODEL = 'models/gemini-embedding-001'
# Qdrant collection that receives the embedded chunks.
COLLECTION_NAME = "textbook_content"
# Markdown sources live in the sibling `frontend/docs` tree, three levels up
# from this file.
DOCS_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..', 'frontend', 'docs'))
# Vector size must match the model's output dimension exactly, or Qdrant
# rejects the upserts.
EMBEDDING_DIMENSION = 3072 # For Gemini gemini-embedding-001
def get_documents(path: str, patterns: tuple = ("*.md", "*.mdx")) -> List[str]:
    """Recursively find documentation files under *path*.

    Args:
        path: Root directory to search.
        patterns: Glob filename patterns to match. Defaults to markdown
            and MDX files, preserving the original behavior; callers may
            pass other patterns to ingest additional formats.

    Returns:
        List of matching file paths, grouped by pattern in the order the
        patterns are given (all ``*.md`` hits before ``*.mdx`` by default).
    """
    print(f"Searching for markdown files in: {path}")
    files: List[str] = []
    for pattern in patterns:
        # `**` with recursive=True also matches files directly under `path`.
        files.extend(glob.glob(f"{path}/**/{pattern}", recursive=True))
    print(f"Found {len(files)} documents.")
    return files
def get_text_chunks(file_path: str, chunk_size: int = 2000, overlap: int = 200) -> Generator[str, None, None]:
    """Read a file, strip HTML tags, and yield overlapping text chunks.

    Args:
        file_path: Path of the markdown/MDX file to read (UTF-8).
        chunk_size: Maximum characters per chunk.
        overlap: Characters shared between consecutive chunks; must be
            smaller than ``chunk_size`` so each step advances.

    Yields:
        Windows of at most ``chunk_size`` characters, each starting
        ``chunk_size - overlap`` characters after the previous one.

    Raises:
        ValueError: If ``overlap`` is not smaller than ``chunk_size``
            (the loop would otherwise never terminate).
    """
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
        # BeautifulSoup strips inline HTML tags; plain markdown syntax
        # passes through get_text() unchanged.
        text = BeautifulSoup(content, 'html.parser').get_text()
        if not text:
            return
        start = 0
        while start < len(text):
            end = start + chunk_size
            yield text[start:end]
            if end >= len(text):
                # BUGFIX: the last chunk reached end-of-text. Without this
                # break, a text whose final chunk ends exactly at len(text)
                # would emit one more overlap-only chunk that is entirely
                # contained in the chunk just yielded (duplicate data).
                break
            start += chunk_size - overlap
    except Exception as e:
        # Best-effort ingestion: log and skip unreadable/unparsable files
        # rather than aborting the whole run.
        print(f"Error processing file {file_path}: {e}")
        return
def main():
    """Run the full data-ingestion pipeline.

    Steps:
      1. Configure the Gemini client and connect to Qdrant.
      2. (Re)create the target vector collection.
      3. Chunk every markdown document, embed the chunks in batches with
         Gemini, and upsert the vectors (with text/source payloads) into
         Qdrant, using the global chunk index as the point ID.

    Setup failures abort the run; per-batch embedding/upsert failures are
    logged and skipped so one bad batch does not stop ingestion.
    """
    print("--- Starting Data Ingestion ---")
    # --- Initialize Clients ---
    try:
        genai.configure(api_key=settings.GEMINI_API_KEY)
        qdrant_client = QdrantClient(url=settings.QDRANT_URL, api_key=settings.QDRANT_API_KEY)
        print("Successfully initialized Gemini and Qdrant clients.")
    except Exception as e:
        print(f"Error initializing clients: {e}")
        return
    # --- Setup Qdrant Collection ---
    print(f"Setting up Qdrant collection: '{COLLECTION_NAME}'")
    try:
        # recreate_collection drops any existing data, so repeated runs are
        # idempotent rather than accumulating stale points.
        qdrant_client.recreate_collection(
            collection_name=COLLECTION_NAME,
            vectors_config=models.VectorParams(size=EMBEDDING_DIMENSION, distance=models.Distance.COSINE),
        )
        print(f"Collection '{COLLECTION_NAME}' created/recreated successfully.")
    except Exception as e:
        print(f"Error creating Qdrant collection: {e}")
        return
    # --- Process and Upload Documents ---
    documents = get_documents(DOCS_PATH)
    all_chunks = []
    chunk_metadata = []
    for doc_path in documents:
        print(f"\nProcessing document: {doc_path}")
        for chunk in get_text_chunks(doc_path):
            all_chunks.append(chunk)
            chunk_metadata.append({"text": chunk, "source": os.path.basename(doc_path)})
    # Embed and upsert in batches; the global chunk index doubles as the
    # Qdrant point ID, so IDs stay stable across batches.
    batch_size = 100
    # BUGFIX: count only points actually written to Qdrant. The previous
    # version reported the number of chunks *prepared*, which overstated
    # the total whenever an embedding/upsert batch failed.
    upserted_count = 0
    for i in range(0, len(all_chunks), batch_size):
        batch_chunks = all_chunks[i:i + batch_size]
        batch_metadata = chunk_metadata[i:i + batch_size]
        batch_ids = list(range(i, i + len(batch_chunks)))
        try:
            # Generate embeddings for the whole batch in one API call.
            response = genai.embed_content(
                model=EMBEDDING_MODEL,
                content=batch_chunks,
                task_type="retrieval_document"
            )
            embeddings = response['embedding']
            points_to_upsert = [
                models.PointStruct(
                    id=batch_ids[j],
                    vector=embedding,
                    payload=batch_metadata[j],
                )
                for j, embedding in enumerate(embeddings)
            ]
            # wait=True blocks until Qdrant has persisted the batch, so the
            # success log (and counter) reflect durable writes.
            qdrant_client.upsert(collection_name=COLLECTION_NAME, points=points_to_upsert, wait=True)
            upserted_count += len(points_to_upsert)
            print(f"Upserted a batch of {len(points_to_upsert)} points (IDs {batch_ids[0]} - {batch_ids[-1]}).")
        except Exception as e:
            print(f"Error processing batch {i} to {i + len(batch_chunks)}: {e}")
    print(f"\n--- Data Ingestion Complete ---")
    print(f"Total points added to collection '{COLLECTION_NAME}': {upserted_count}")
# Allow this module to be executed directly as an ingestion script.
if __name__ == "__main__":
    main()