"""Generate document embeddings for a catalog and persist them to disk."""

from __future__ import annotations

import json
import time
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd
from tqdm import tqdm

from data.catalog_loader import make_assessment_id
from models.embedding_model import EmbeddingModel


def generate_embeddings(
    catalog_path: str,
    model_name: str,
    batch_size: int = 32,
    output_dir: str = "data/embeddings",
) -> Tuple[np.ndarray, List[str]]:
    """Embed every ``doc_text`` in a catalog and write the results to disk.

    The catalog is read as JSON Lines when ``catalog_path`` ends with
    ``.jsonl`` and as Parquet otherwise. Rows are sorted by
    ``assessment_id`` before encoding, so row ``i`` of the returned array
    corresponds to element ``i`` of the returned id list.

    Three files are written into ``output_dir`` (created if missing):
    ``embeddings.npy``, ``assessment_ids.json`` and ``generation_log.json``.

    Args:
        catalog_path: Path to the catalog file; must contain a ``doc_text``
            column and either ``assessment_id`` or ``url``.
        model_name: Name handed to ``EmbeddingModel``.
        batch_size: Number of documents per ``encode`` call; must be >= 1.
        output_dir: Directory receiving the output files.

    Returns:
        A ``(embeddings, ids)`` tuple: a float32 array with one row per
        document, and the matching assessment ids. For an empty catalog the
        array has shape ``(0, 0)`` because the embedding width is unknown
        without encoding at least one document.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        KeyError: If ``assessment_id`` is absent and cannot be derived from
            ``url``, or if ``doc_text`` is absent.
    """
    # A zero step makes range() raise and a negative one yields no batches at
    # all, so reject both up front instead of silently embedding nothing.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    if catalog_path.endswith(".jsonl"):
        df = pd.read_json(catalog_path, lines=True)
    else:
        df = pd.read_parquet(catalog_path)

    if "assessment_id" not in df.columns:
        if "url" not in df.columns:
            raise KeyError("assessment_id not found and url missing to derive it.")
        df["assessment_id"] = df["url"].apply(make_assessment_id)

    # Sort so that the saved embedding rows and id list share one stable order.
    df = df.sort_values("assessment_id")
    texts = df["doc_text"].tolist()
    ids = df["assessment_id"].tolist()

    model = EmbeddingModel(model_name)
    embeddings: List[np.ndarray] = []

    # Timing starts after the model is constructed and ends after the
    # embeddings and ids have been written, matching what the log reports.
    start = time.time()
    for offset in tqdm(range(0, len(texts), batch_size), desc="Embedding"):
        batch = texts[offset : offset + batch_size]
        embeddings.append(
            model.encode(batch, normalize=True, batch_size=batch_size, is_query=False)
        )

    if embeddings:
        embeddings_arr = np.vstack(embeddings).astype(np.float32)
    else:
        # np.vstack rejects an empty sequence; an empty catalog produces an
        # empty 2-D array so that the .shape[1] lookup below still works.
        embeddings_arr = np.zeros((0, 0), dtype=np.float32)

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    np.save(out_dir / "embeddings.npy", embeddings_arr)
    with open(out_dir / "assessment_ids.json", "w", encoding="utf-8") as f:
        json.dump(ids, f, indent=2)

    total_time = time.time() - start
    num_documents = len(texts)
    log = {
        "generated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "model_name": model_name,
        "num_documents": num_documents,
        "embedding_dim": embeddings_arr.shape[1],
        "batch_size": batch_size,
        "total_time_seconds": total_time,
        "avg_time_per_doc_ms": (
            (total_time / num_documents * 1000) if num_documents else None
        ),
        "normalized": True,
        "catalog_path": catalog_path,
    }
    with open(out_dir / "generation_log.json", "w", encoding="utf-8") as f:
        json.dump(log, f, indent=2)

    return embeddings_arr, ids


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--catalog", required=True, help="Enriched catalog with doc_text"
    )
    parser.add_argument("--model", default="sentence-transformers/all-MiniLM-L6-v2")
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--output-dir", default="data/embeddings")
    args = parser.parse_args()

    generate_embeddings(
        args.catalog,
        args.model,
        batch_size=args.batch_size,
        output_dir=args.output_dir,
    )