|
|
|
|
|
"""Embedding generation CLI script for the RAG pipeline. |
|
|
|
|
|
This script generates embeddings from text chunks and builds search indexes |
|
|
for the pythermalcomfort RAG chatbot. It supports: |
|
|
- Reading chunks from JSONL files |
|
|
- Generating embeddings using BGE encoder with GPU acceleration |
|
|
- Building FAISS indexes for dense retrieval |
|
|
- Building BM25 indexes for sparse retrieval |
|
|
- Publishing artifacts to HuggingFace (optional) |
|
|
|
|
|
Usage: |
|
|
# Basic embedding generation |
|
|
poetry run python scripts/embed.py data/chunks/chunks.jsonl data/embeddings/ |
|
|
|
|
|
# With HuggingFace publishing |
|
|
poetry run python scripts/embed.py data/chunks/chunks.jsonl data/embeddings/ \ |
|
|
--publish |
|
|
|
|
|
# Custom batch size and model |
|
|
poetry run python scripts/embed.py data/chunks/chunks.jsonl data/embeddings/ \ |
|
|
--batch-size 64 --model BAAI/bge-base-en-v1.5 |
|
|
|
|
|
Output Files: |
|
|
    {output_dir}/
    ├── embeddings.parquet         # Embeddings with chunk_id mapping
    ├── metadata.json              # Model metadata
    ├── faiss_index.bin            # FAISS index for dense retrieval
    ├── faiss_index.bin.ids.json   # Chunk ID mapping for FAISS
    ├── bm25_index.pkl             # BM25 index for sparse retrieval
    └── chunks.parquet             # Chunks in parquet format (if --publish)
|
|
|
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import argparse |
|
|
import json |
|
|
import logging |
|
|
import sys |
|
|
import time |
|
|
from pathlib import Path |
|
|
from typing import TYPE_CHECKING |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dotenv import load_dotenv |
|
|
|
|
|
|
|
|
# Resolve the repository root (this script lives one level down, in scripts/).
_PROJECT_ROOT = Path(__file__).parent.parent
_ENV_FILE = _PROJECT_ROOT / ".env"

# Load environment variables (e.g. HF_TOKEN used by the --publish path)
# before importing any library that reads them at import time.
if _ENV_FILE.exists():
    load_dotenv(_ENV_FILE)
|
|
|
|
|
|
|
|
from rich.console import Console |
|
|
from rich.progress import ( |
|
|
BarColumn, |
|
|
MofNCompleteColumn, |
|
|
Progress, |
|
|
SpinnerColumn, |
|
|
TaskID, |
|
|
TaskProgressColumn, |
|
|
TextColumn, |
|
|
TimeElapsedColumn, |
|
|
TimeRemainingColumn, |
|
|
) |
|
|
from rich.table import Table |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from rag_chatbot.chunking.models import Chunk |
|
|
from rag_chatbot.embeddings import ( |
|
|
BGEEncoder, |
|
|
EmbeddingRecord, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Default embedding model: the small English BGE variant.
DEFAULT_MODEL: str = "BAAI/bge-small-en-v1.5"

# Default number of texts encoded per forward pass.
DEFAULT_BATCH_SIZE: int = 32

# Embedding dimension produced by BAAI/bge-small-en-v1.5.
BGE_SMALL_DIM: int = 384
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Log to stderr so stdout stays clean for Rich output and shell piping.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stderr)],
)
logger = logging.getLogger(__name__)

# Shared Rich console used for all user-facing CLI output.
console = Console()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_device_info() -> tuple[str, str]:
    """Report which compute device is available for embedding generation.

    Torch is imported lazily so that merely importing this module does not
    pull in the heavy ML stack.

    Returns
    -------
    Tuple of (device_type, device_name) where:
        - device_type: "cuda" or "cpu"
        - device_name: GPU name (e.g., "NVIDIA RTX 4090") or "CPU"

    Example
    -------
    >>> device_type, device_name = get_device_info()
    >>> print(f"Using {device_name}")
    Using NVIDIA RTX 4090

    """
    import torch

    # Guard clause: no CUDA means we fall back to the CPU pair.
    if not torch.cuda.is_available():
        return "cpu", "CPU"

    return "cuda", torch.cuda.get_device_name(0)
|
|
|
|
|
|
|
|
def load_chunks_from_jsonl(input_path: Path) -> list[Chunk]:
    """Load and validate chunks from a JSONL file.

    Each non-empty line is parsed as JSON and turned into a ``Chunk``.
    Text normalization is applied while loading to repair common PDF
    extraction artifacts. Malformed lines are logged and skipped rather
    than aborting the whole load.

    Args:
    ----
        input_path: Path to the chunks.jsonl file.

    Returns:
    -------
        List of Chunk objects parsed from the file.

    Raises:
    ------
        FileNotFoundError: If the input file doesn't exist.
        ValueError: If no valid chunks could be parsed from the file.

    Example:
    -------
        >>> chunks = load_chunks_from_jsonl(Path("data/chunks/chunks.jsonl"))
        >>> len(chunks)
        1500

    """
    from rag_chatbot.chunking.models import Chunk, TextNormalizer

    if not input_path.exists():
        msg = f"Input file not found: {input_path}"
        raise FileNotFoundError(msg)

    normalizer = TextNormalizer()
    parsed: list[Chunk] = []

    with open(input_path, encoding="utf-8") as handle:
        for line_no, raw in enumerate(handle, start=1):
            stripped = raw.strip()
            if not stripped:
                continue

            try:
                payload = json.loads(stripped)

                # Normalize chunk text in place before model validation.
                if "text" in payload:
                    payload["text"] = normalizer.normalize(
                        payload["text"], is_heading=False
                    )

                parsed.append(Chunk(**payload))
            except json.JSONDecodeError as exc:
                logger.warning("Invalid JSON at line %d: %s", line_no, exc)
                continue
            except Exception as exc:
                # Best-effort load: a single bad record should not kill the run.
                logger.warning("Error parsing chunk at line %d: %s", line_no, exc)
                continue

    if not parsed:
        msg = f"No valid chunks found in {input_path}"
        raise ValueError(msg)

    return parsed
|
|
|
|
|
|
|
|
def create_embedding_records(
    chunks: list[Chunk],
    encoder: BGEEncoder,
    batch_size: int,
    progress: Progress,
    task_id: TaskID,
) -> list[EmbeddingRecord]:
    """Encode chunk texts and wrap the results as EmbeddingRecord objects.

    All texts are passed to the encoder in one call; per-batch completion is
    reported back to the Rich progress bar through a callback. The returned
    list is parallel to ``chunks`` (one record per chunk, same order).

    Args:
    ----
        chunks: List of Chunk objects to embed.
        encoder: BGEEncoder instance for generating embeddings.
        batch_size: Number of chunks to process per batch.
        progress: Rich Progress instance for progress tracking.
        task_id: Task ID for the progress bar.

    Returns:
    -------
        List of EmbeddingRecord objects with generated embeddings.

    Example:
    -------
        >>> records = create_embedding_records(chunks, encoder, 32, progress, task_id)
        >>> len(records) == len(chunks)
        True

    """
    from rag_chatbot.embeddings import EmbeddingRecord

    def on_batch_done(done: int, _total: int) -> None:
        """Advance the progress bar to the number of completed batches."""
        progress.update(task_id, completed=done)

    vectors = encoder.encode(
        texts=[item.text for item in chunks],
        batch_size=batch_size,
        show_progress=False,
        progress_callback=on_batch_done,
    )

    # Pair each chunk with its embedding row, preserving input order.
    return [
        EmbeddingRecord(
            chunk_id=item.chunk_id,
            chunk_hash=item.chunk_hash,
            embedding=vectors[pos].tolist(),
        )
        for pos, item in enumerate(chunks)
    ]
|
|
|
|
|
|
|
|
def build_indexes(
    output_dir: Path,
    chunks: list[Chunk],
    progress: Progress,
) -> tuple[float, float]:
    """Build the FAISS (dense) and BM25 (sparse) retrieval indexes.

    The FAISS index is built from the embeddings parquet file already saved
    in ``output_dir``; the BM25 index is built directly from chunk texts.
    Both are persisted next to the embeddings.

    Args:
    ----
        output_dir: Directory containing embeddings.parquet and for saving indexes.
        chunks: List of Chunk objects for BM25 indexing.
        progress: Rich Progress instance for progress tracking.

    Returns:
    -------
        Tuple of (faiss_build_time, bm25_build_time) in seconds.

    Example:
    -------
        >>> faiss_time, bm25_time = build_indexes(output_dir, chunks, progress)
        >>> print(f"FAISS: {faiss_time:.2f}s, BM25: {bm25_time:.2f}s")

    """
    from rag_chatbot.embeddings import BM25IndexBuilder, FAISSIndexBuilder

    parquet_path = output_dir / "embeddings.parquet"

    # Dense index: FAISS over the stored embeddings.
    dense_task = progress.add_task("[cyan]Building FAISS index...", total=1)
    dense_started = time.perf_counter()

    dense_builder = FAISSIndexBuilder()
    dense_index = dense_builder.build_from_parquet(parquet_path)
    dense_builder.save_index(dense_index, output_dir / "faiss_index.bin")

    dense_elapsed = time.perf_counter() - dense_started
    progress.update(dense_task, completed=1)

    # Sparse index: BM25 over raw chunk texts.
    sparse_task = progress.add_task("[cyan]Building BM25 index...", total=1)
    sparse_started = time.perf_counter()

    sparse_builder = BM25IndexBuilder()
    sparse_index, chunk_ids = sparse_builder.build_from_chunks(chunks)
    sparse_builder.save_index(sparse_index, chunk_ids, output_dir / "bm25_index.pkl")

    sparse_elapsed = time.perf_counter() - sparse_started
    progress.update(sparse_task, completed=1)

    return dense_elapsed, sparse_elapsed
|
|
|
|
|
|
|
|
def publish_to_huggingface(
    output_dir: Path,
    chunks: list[Chunk],
    model_name: str,
    embedding_dim: int,
    progress: Progress,
) -> str:
    """Upload all embedding artifacts to a HuggingFace dataset repository.

    Runs the full publishing workflow, advancing the progress bar once per
    step:
    1. Save chunks to parquet format.
    2. Generate the source manifest.
    3. Authenticate with HuggingFace.
    4. Upload all artifacts.

    Args:
    ----
        output_dir: Directory containing artifacts to publish.
        chunks: List of Chunk objects for chunks.parquet.
        model_name: Name of the embedding model used.
        embedding_dim: Dimension of embeddings.
        progress: Rich Progress instance for progress tracking.

    Returns:
    -------
        URL of the published HuggingFace dataset.

    Raises:
    ------
        ValueError: If HF_TOKEN is not set.
        RuntimeError: If publishing fails.

    Example:
    -------
        >>> url = publish_to_huggingface(
        ...     output_dir, chunks, "BAAI/bge-small-en-v1.5", 384, progress
        ... )
        >>> print(url)
        'https://huggingface.co/datasets/sadickam/pytherm_index'

    """
    from rag_chatbot.embeddings import HuggingFacePublisher, PublisherConfig

    task = progress.add_task("[cyan]Publishing to HuggingFace...", total=4)

    hub = HuggingFacePublisher(PublisherConfig())

    # Step 1: chunks as parquet alongside the other artifacts.
    hub.save_chunks_parquet(chunks, output_dir)
    progress.update(task, advance=1)

    # Step 2: manifest describing what is being published.
    # NOTE(review): source_files is passed empty here — confirm whether the
    # manifest is expected to list the original source documents.
    manifest = hub.generate_source_manifest(
        source_files=[],
        total_chunks=len(chunks),
        total_embeddings=len(chunks),
    )
    progress.update(task, advance=1)

    # Step 3: authenticate against the Hub.
    hub.authenticate()
    progress.update(task, advance=1)

    # Step 4: upload everything and collect the dataset URL.
    published_url = hub.publish(
        artifacts_dir=output_dir,
        manifest=manifest,
        model_name=model_name,
        embedding_dimension=embedding_dim,
    )
    progress.update(task, advance=1)

    return published_url
|
|
|
|
|
|
|
|
def print_statistics(
    total_chunks: int,
    total_time: float,
    embedding_time: float,
    faiss_time: float,
    bm25_time: float,
    device_name: str,
    model_name: str,
    output_dir: Path,
    dataset_url: str | None = None,
) -> None:
    """Print final statistics table using Rich.

    Displays a formatted table with embedding statistics including:
    - Total chunks processed
    - Time breakdowns (embedding, indexing)
    - Throughput metrics
    - Device information
    - Output file sizes

    Args:
    ----
        total_chunks: Number of chunks embedded.
        total_time: Total elapsed time in seconds.
        embedding_time: Time spent on embedding generation.
        faiss_time: Time spent building FAISS index.
        bm25_time: Time spent building BM25 index.
        device_name: Name of the device used (GPU name or "CPU").
        model_name: Name of the embedding model.
        output_dir: Directory where outputs were saved.
        dataset_url: Optional URL of published HuggingFace dataset.

    """
    # Guard against division by zero when embedding finished instantly.
    throughput = total_chunks / embedding_time if embedding_time > 0 else 0

    table = Table(
        title="Embedding Statistics", show_header=True, header_style="bold cyan"
    )
    table.add_column("Metric", style="dim", width=25)
    table.add_column("Value", justify="right")

    table.add_row("Total Chunks", f"{total_chunks:,}")
    table.add_row("Model", model_name)
    table.add_row("Device", device_name)
    table.add_row("", "")
    table.add_row("Embedding Time", f"{embedding_time:.2f}s")
    table.add_row("FAISS Build Time", f"{faiss_time:.2f}s")
    table.add_row("BM25 Build Time", f"{bm25_time:.2f}s")
    table.add_row("Total Time", f"{total_time:.2f}s")
    table.add_row("", "")
    table.add_row("Throughput", f"{throughput:.1f} chunks/sec")

    # Report on-disk size for each artifact that exists. A single loop
    # replaces three copy-pasted stat blocks (DRY).
    artifact_rows = (
        ("Embeddings Size", output_dir / "embeddings.parquet"),
        ("FAISS Index Size", output_dir / "faiss_index.bin"),
        ("BM25 Index Size", output_dir / "bm25_index.pkl"),
    )
    for label, artifact in artifact_rows:
        if artifact.exists():
            size_mb = artifact.stat().st_size / (1024 * 1024)
            table.add_row(label, f"{size_mb:.2f} MB")

    # Only shown when --publish actually produced a dataset URL.
    if dataset_url:
        table.add_row("", "")
        table.add_row("Published URL", dataset_url)

    console.print()
    console.print(table)
    console.print()
|
|
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse command line arguments.

    Returns
    -------
    Parsed argument namespace with input_path, output_dir, and options.

    """
    cli = argparse.ArgumentParser(
        description="Generate embeddings and build indexes for RAG retrieval.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    poetry run python scripts/embed.py chunks.jsonl output/
    poetry run python scripts/embed.py chunks.jsonl output/ --publish
    poetry run python scripts/embed.py chunks.jsonl output/ --batch-size 64
""",
    )

    # Positional arguments: where to read chunks and where to write artifacts.
    cli.add_argument(
        "input_path",
        type=Path,
        help="Path to chunks.jsonl file containing text chunks",
    )
    cli.add_argument(
        "output_dir",
        type=Path,
        help="Directory to save embeddings and indexes",
    )

    # Optional behavior flags.
    cli.add_argument(
        "--publish",
        action="store_true",
        help="Publish artifacts to HuggingFace after embedding",
    )
    cli.add_argument(
        "--batch-size",
        type=int,
        default=DEFAULT_BATCH_SIZE,
        help=f"Batch size for embedding generation (default: {DEFAULT_BATCH_SIZE})",
    )
    cli.add_argument(
        "--model",
        type=str,
        default=DEFAULT_MODEL,
        help=f"Embedding model name (default: {DEFAULT_MODEL})",
    )
    cli.add_argument(
        "--device",
        type=str,
        default="auto",
        choices=["cpu", "cuda", "auto"],
        help="Device to use for embedding (default: auto)",
    )

    return cli.parse_args()
|
|
|
|
|
|
|
|
def main() -> int:
    """Run the embedding generation pipeline.

    This is the main entry point for the embed.py CLI script. It orchestrates:
    1. Loading chunks from JSONL
    2. Initializing the encoder with GPU if available
    3. Generating embeddings with progress tracking
    4. Saving embeddings to parquet storage
    5. Building FAISS and BM25 indexes
    6. Optionally publishing to HuggingFace

    Returns
    -------
    Exit code (0 for success, 1 for error).

    """
    args = parse_args()

    # Wall-clock start for the "Total Time" statistic printed at the end.
    start_time = time.perf_counter()

    try:
        console.print("\n[bold cyan]Embedding Generation Pipeline[/bold cyan]\n")

        # device_name is for display only; the encoder receives args.device,
        # where None (from "auto") lets the encoder pick its own device.
        device_type, device_name = get_device_info()
        device = args.device if args.device != "auto" else None

        console.print(f"[green]Device:[/green] {device_name}")
        console.print(f"[green]Model:[/green] {args.model}")
        console.print(f"[green]Batch Size:[/green] {args.batch_size}")
        console.print()

        # --- Step 1: load chunks from JSONL ---
        console.print(f"[cyan]Loading chunks from {args.input_path}...[/cyan]")

        chunks = load_chunks_from_jsonl(args.input_path)
        console.print(f"[green]Loaded {len(chunks):,} chunks[/green]\n")

        # Defensive guard: load_chunks_from_jsonl raises ValueError when it
        # finds no valid chunks, so this branch is unlikely to be reached.
        if not chunks:
            console.print("[yellow]Warning: No chunks to process. Exiting.[/yellow]")
            return 0

        # Heavy ML imports are deferred until input validation has passed.
        from rag_chatbot.embeddings import (
            BGEEncoder,
            EmbeddingBatch,
            EmbeddingStorage,
        )

        args.output_dir.mkdir(parents=True, exist_ok=True)

        # --- Step 2: initialize encoder and storage ---
        # normalize_text=False: chunks were already normalized at load time.
        encoder = BGEEncoder(
            model_name=args.model,
            device=device,
            normalize_text=False,
        )

        storage = EmbeddingStorage(args.output_dir)

        import math

        # One progress tick per encoder batch.
        total_batches = math.ceil(len(chunks) / args.batch_size)

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            MofNCompleteColumn(),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
            console=console,
        ) as progress:
            # --- Step 3: generate embeddings with progress tracking ---
            embed_task = progress.add_task(
                "[cyan]Embedding chunks...",
                total=total_batches,
            )

            embedding_start = time.perf_counter()

            records = create_embedding_records(
                chunks=chunks,
                encoder=encoder,
                batch_size=args.batch_size,
                progress=progress,
                task_id=embed_task,
            )

            embedding_time = time.perf_counter() - embedding_start

            # Force the bar to 100% in case the callback undercounted.
            progress.update(embed_task, completed=total_batches)

            # --- Step 4: save embeddings to parquet storage ---
            save_task = progress.add_task("[cyan]Saving embeddings...", total=1)

            batch = EmbeddingBatch(
                model_name=args.model,
                dimension=encoder.embedding_dim,
                dtype="float16",
                records=records,
            )

            storage.save(batch)

            progress.update(save_task, completed=1)

            # --- Step 5: build FAISS and BM25 indexes ---
            faiss_time, bm25_time = build_indexes(
                output_dir=args.output_dir,
                chunks=chunks,
                progress=progress,
            )

            # --- Step 6: optional HuggingFace publishing ---
            dataset_url: str | None = None

            if args.publish:
                dataset_url = publish_to_huggingface(
                    output_dir=args.output_dir,
                    chunks=chunks,
                    model_name=args.model,
                    embedding_dim=encoder.embedding_dim,
                    progress=progress,
                )

        # Final summary table (printed after the progress display closes).
        total_time = time.perf_counter() - start_time

        print_statistics(
            total_chunks=len(chunks),
            total_time=total_time,
            embedding_time=embedding_time,
            faiss_time=faiss_time,
            bm25_time=bm25_time,
            device_name=device_name,
            model_name=args.model,
            output_dir=args.output_dir,
            dataset_url=dataset_url,
        )

    except FileNotFoundError as exc:
        console.print(f"[bold red]Error:[/bold red] {exc}")
        return 1
    except ValueError as exc:
        console.print(f"[bold red]Error:[/bold red] {exc}")
        return 1
    except KeyboardInterrupt:
        console.print("\n[yellow]Interrupted by user[/yellow]")
        return 1
    except Exception as exc:
        # Catch-all boundary for the CLI: report, log traceback, exit nonzero.
        console.print(f"[bold red]Unexpected error:[/bold red] {exc}")
        logger.exception("Unexpected error during embedding generation")
        return 1
    else:
        console.print("[bold green]Embedding generation complete![/bold green]\n")
        return 0
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s exit code (0 success, 1 error) to the shell.
    sys.exit(main())
|
|
|