Spaces:
Running
Running
| """ | |
| Unified embedding cache builder for ReguRAG (single implementation). | |
| Behavior: | |
| - Only one cache policy: if cache exists, reuse it; otherwise build and save. | |
| - Cache key only uses (model, chunk_mode). | |
| - No target/doc-mode/overwrite mode split. | |
| Benchmark query caches: | |
| - To support offline benchmark evaluation, this script also builds benchmark caches | |
| for single-doc and multi-doc question sets. | |
| Supported models: | |
| - BM25 | |
| - Qwen3-Embedding-0.6B | |
| - Qwen3-Embedding-4B | |
| - text-embedding-3-large | |
| - text-embedding-3-small | |
| - text-embedding-ada-002 | |
| Examples: | |
| python openai_embedding.py --models all --chunk-mode all | |
| python openai_embedding.py --models text-embedding-3-small --chunk-mode structure --base-url https://88996.cloud/v1 | |
| python openai_embedding.py --models Qwen3-Embedding-4B --chunk-mode structure --device cuda | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import os | |
| import pickle | |
| import re | |
| import sys | |
| import time | |
| from typing import Dict, List, Tuple | |
| import numpy as np | |
| SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| if SCRIPT_DIR not in sys.path: | |
| sys.path.insert(0, SCRIPT_DIR) | |
| from rag_app_backend import ( | |
| EMBED_CACHE_DIR, | |
| EMBED_MODEL_PATHS, | |
| OPENAI_EMBED_BASE_URL, | |
| OPENAI_EMBED_MODELS, | |
| _build_chunk_pool, | |
| _extract_embedding_vectors, | |
| _resolve_api_key, | |
| get_openai_client, | |
| get_report_chunks, | |
| ) | |
# Optional progress bars: if tqdm is unavailable, fall back to a no-op
# pass-through so the script still runs (just without progress output).
try:
    from tqdm.auto import tqdm
except Exception:
    def tqdm(x, **kwargs):  # type: ignore
        # Stand-in: return the iterable unchanged, ignore tqdm kwargs.
        return x
# Sparse lexical baseline; cached as a pickled rank_bm25 index, not vectors.
BM25_MODEL = "BM25"
# Local models loaded from disk paths declared in EMBED_MODEL_PATHS.
QWEN_MODELS = {"Qwen3-Embedding-0.6B", "Qwen3-Embedding-4B"}
# Models served through the OpenAI-compatible embeddings endpoint.
OPENAI_MODELS = set(OPENAI_EMBED_MODELS)
# Canonical ordering used when --models=all is requested.
ALL_MODELS = [
    BM25_MODEL,
    "Qwen3-Embedding-0.6B",
    "Qwen3-Embedding-4B",
    "text-embedding-3-large",
    "text-embedding-3-small",
    "text-embedding-ada-002",
]
# (chunk_mode, doc_mode) -> annotated benchmark CSV holding chunks + questions.
BENCH_DATASET_PATHS = {
    ("length", "single"): os.path.join(SCRIPT_DIR, "..", "OCR_Chunked_Annotated", "ocr_chunks_annotated.csv"),
    ("structure", "single"): os.path.join(SCRIPT_DIR, "..", "OCR_Chunked_Annotated_structure", "ocr_chunks_annotated_structure.csv"),
    ("length", "multi"): os.path.join(SCRIPT_DIR, "..", "OCR_Chunked_Annotated_cross", "ocr_chunks_annotated_length_multi.csv"),
    ("structure", "multi"): os.path.join(SCRIPT_DIR, "..", "OCR_Chunked_Annotated_structure_cross", "ocr_chunks_annotated_structure_multi.csv"),
}
| def _sanitize_name(text: str) -> str: | |
| return "".join(ch if ch.isalnum() or ch in ("-", "_", ".") else "_" for ch in str(text)) | |
def _simple_npy_cache_file(model_name: str, chunk_mode: str) -> str:
    """Path of the document-embedding .npy cache keyed by (model, chunk_mode)."""
    os.makedirs(EMBED_CACHE_DIR, exist_ok=True)
    stem = "__".join((_sanitize_name(model_name), _sanitize_name(chunk_mode)))
    return os.path.join(EMBED_CACHE_DIR, stem + ".npy")
def _simple_bm25_cache_file(model_name: str, chunk_mode: str) -> str:
    """Path of the pickled BM25 index cache keyed by (model, chunk_mode)."""
    os.makedirs(EMBED_CACHE_DIR, exist_ok=True)
    stem = "__".join((_sanitize_name(model_name), _sanitize_name(chunk_mode)))
    return os.path.join(EMBED_CACHE_DIR, stem + ".pkl")
def _simple_query_npy_cache_file(model_name: str, chunk_mode: str) -> str:
    """Path of the benchmark query-embedding .npy cache keyed by (model, chunk_mode)."""
    os.makedirs(EMBED_CACHE_DIR, exist_ok=True)
    stem = "__".join((_sanitize_name(model_name), _sanitize_name(chunk_mode)))
    return os.path.join(EMBED_CACHE_DIR, stem + "__query.npy")
| def _normalize_rows(x: np.ndarray) -> np.ndarray: | |
| if x.ndim != 2: | |
| return x | |
| norms = np.linalg.norm(x, axis=1, keepdims=True) | |
| norms[norms == 0] = 1.0 | |
| return x / norms | |
| def _tokenize(text: str) -> List[str]: | |
| return re.findall(r"[a-z0-9]+", str(text or "").lower()) | |
def _parse_models(raw: str) -> List[str]:
    """Parse a comma-separated model spec; 'all' (any case) expands to ALL_MODELS.

    Raises ValueError for an unknown model name or an effectively empty spec.
    """
    spec = str(raw or "all").strip()
    if spec.lower() == "all":
        return list(ALL_MODELS)
    chosen: List[str] = []
    for token in spec.split(","):
        name = token.strip()
        if not name:
            continue  # tolerate stray commas / whitespace-only entries
        if name not in ALL_MODELS:
            raise ValueError(f"Unsupported model: {name}")
        chosen.append(name)
    if not chosen:
        raise ValueError("No valid models provided.")
    return chosen
| def _parse_chunk_modes(raw: str) -> List[str]: | |
| s = str(raw or "structure").strip().lower() | |
| if s == "all": | |
| return ["length", "structure"] | |
| if s in {"length", "structure"}: | |
| return [s] | |
| raise ValueError("chunk-mode must be one of: length, structure, all") | |
| def _load_npy(cache_file: str, expected_rows: int): | |
| try: | |
| if not os.path.isfile(cache_file): | |
| return None | |
| arr = np.load(cache_file) | |
| if arr.ndim != 2 or arr.shape[0] != expected_rows: | |
| return None | |
| return arr | |
| except Exception: | |
| return None | |
def _save_npy(cache_file: str, arr: np.ndarray) -> None:
    """Persist an embedding matrix to ``cache_file`` in NumPy .npy format."""
    np.save(cache_file, arr)
def _encode_openai_texts(
    texts: List[str],
    model_name: str,
    api_key: str,
    base_url: str,
    batch_size: int,
    desc: str,
) -> np.ndarray:
    """Embed ``texts`` in batches via an OpenAI-compatible API.

    Returns an L2-row-normalized float32 matrix with one row per input text;
    an empty input yields a (0, 1) array. Raises RuntimeError if the collected
    vectors do not form a 2-D matrix.
    """
    if not texts:
        return np.zeros((0, 1), dtype="float32")
    client = get_openai_client(api_key=api_key, base_url=base_url)
    stride = max(1, int(batch_size))
    n_batches = (len(texts) + stride - 1) // stride
    rows = []
    for start in tqdm(range(0, len(texts), stride), total=n_batches, desc=desc, unit="batch"):
        # None entries are sent as empty strings to keep row alignment.
        payload = [str(t or "") for t in texts[start : start + stride]]
        response = client.embeddings.create(model=model_name, input=payload)
        rows.extend(_extract_embedding_vectors(response))
    matrix = np.asarray(rows, dtype="float32")
    if matrix.ndim != 2:
        raise RuntimeError(f"OpenAI embedding output shape invalid: {matrix.shape}")
    return _normalize_rows(matrix)
def _get_qwen_model(model_name: str, device: str):
    """Load a local Qwen3 embedding model from disk onto ``device``.

    Raises ValueError for an unknown model name and FileNotFoundError when the
    configured model directory is absent.
    """
    if model_name not in EMBED_MODEL_PATHS:
        raise ValueError(f"Unknown local Qwen model: {model_name}")
    model_path = EMBED_MODEL_PATHS[model_name]
    if not os.path.isdir(model_path):
        raise FileNotFoundError(f"Model path not found: {model_path}")
    # Set before the import below; presumably keeps transformers from loading
    # TensorFlow alongside torch — TODO confirm against transformers docs.
    os.environ.setdefault("TRANSFORMERS_NO_TF", "1")
    # Lazy import so heavy model dependencies load only when actually needed.
    from create_embedding_search_results_qwen import Qwen3EmbeddingModel
    return Qwen3EmbeddingModel(model_path, device=device)
def _encode_qwen_texts(
    texts: List[str],
    model,
    batch_size: int,
    desc: str,
) -> np.ndarray:
    """Embed ``texts`` with a local Qwen model in batches.

    Returns an L2-row-normalized float32 matrix; an empty input yields a
    (0, 1) array.
    """
    if not texts:
        return np.zeros((0, 1), dtype="float32")
    stride = max(1, int(batch_size))
    slices = [texts[pos : pos + stride] for pos in range(0, len(texts), stride)]
    pieces = []
    for batch in tqdm(slices, total=len(slices), desc=desc, unit="batch"):
        encoded = model.encode_documents(batch, batch_size=stride)
        pieces.append(np.asarray(encoded, dtype="float32"))
    stacked = np.vstack(pieces) if pieces else np.zeros((0, 1), dtype="float32")
    return _normalize_rows(stacked)
def _build_bm25_payload(pool: List[Tuple[str, int, str]]) -> Dict:
    """Build a pickleable BM25 payload from (report, chunk_idx, text) triples.

    Raises RuntimeError when the optional rank_bm25 dependency is missing.
    """
    try:
        from rank_bm25 import BM25Okapi
    except Exception as e:
        raise RuntimeError("BM25 requires rank_bm25. Install with: pip install rank_bm25") from e
    corpus_tokens = [_tokenize(text) for _, _, text in pool]
    index = BM25Okapi(corpus_tokens)
    reports = [report for report, _, _ in pool]
    chunk_ids = [int(idx) for _, idx, _ in pool]
    return {
        "bm25": index,
        "report": reports,
        "chunk_idx": chunk_ids,
        "n_docs": len(pool),
    }
def _load_benchmark_corpus(chunk_mode: str, doc_mode: str) -> Tuple[List[Tuple[str, int, str]], List[str]]:
    """Load one benchmark CSV and return (chunk pool, sorted unique questions).

    The pool is deduplicated and sorted by (report, chunk_idx). Raises
    FileNotFoundError when the CSV is absent and ValueError when any of the
    required columns is missing.
    """
    import pandas as pd

    path = BENCH_DATASET_PATHS[(chunk_mode, doc_mode)]
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Benchmark corpus not found: {path}")
    df = pd.read_csv(path)
    required = {"report", "chunk_idx", "chunk_text", "question"}
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise ValueError(f"Missing columns in benchmark corpus {path}: {missing}")
    chunks = df[["report", "chunk_idx", "chunk_text"]].drop_duplicates().copy()
    chunks["report"] = chunks["report"].astype(str)
    chunks["chunk_idx"] = chunks["chunk_idx"].astype(int)
    chunks["chunk_text"] = chunks["chunk_text"].astype(str)
    chunks = chunks.sort_values(["report", "chunk_idx"], ascending=[True, True])
    pool = [
        (row.report, int(row.chunk_idx), str(row.chunk_text))
        for row in chunks.itertuples(index=False)
    ]
    questions = sorted(df["question"].astype(str).drop_duplicates().tolist())
    return pool, questions
def _collect_benchmark_questions(chunk_mode: str) -> List[str]:
    """Union of benchmark questions from the single- and multi-doc sets, sorted."""
    merged = set()
    for doc_mode in ("single", "multi"):
        _, questions = _load_benchmark_corpus(chunk_mode, doc_mode)
        merged.update(str(q) for q in questions)
    return sorted(merged)
def _run_probe(models: List[str], api_key: str, base_url: str, text: str) -> None:
    """Send one probe text to each requested API model and print vector stats.

    Non-API models in ``models`` are silently skipped.
    """
    print("=== Probe API embedding responses ===")
    client = get_openai_client(api_key=api_key, base_url=base_url)
    api_models = [m for m in models if m in OPENAI_MODELS]
    for model in api_models:
        resp = client.embeddings.create(model=model, input=text)
        vectors = _extract_embedding_vectors(resp)
        if not vectors:
            continue
        head = vectors[0]
        print(f"[probe] model={model}, dim={len(head)}, first5={head[:5]}")
def _build_for_chunk_mode(
    models: List[str],
    chunk_mode: str,
    api_key: str,
    base_url: str,
    batch_size: int,
    device: str,
) -> None:
    """Build (or reuse) document caches for every model in one chunk mode.

    Cache policy: a file keyed only by (model, chunk_mode) is reused when it
    exists (and, for dense models, has the expected row count); otherwise the
    cache is rebuilt and saved. BM25 is stored as a pickle; dense models as an
    .npy matrix with one L2-normalized row per chunk.
    """
    report_chunks = get_report_chunks(chunk_mode)
    pool = _build_chunk_pool(report_chunks)
    if not pool:
        # Nothing to embed for this chunk mode.
        print(f"[skip] chunk_mode={chunk_mode}: empty chunk pool")
        return
    texts = [p[2] for p in pool]  # chunk text is the third field of each entry
    qwen_model_cache = {}  # model name -> loaded model, so each loads only once
    for model in models:
        if model == BM25_MODEL:
            # BM25 path: existence check only (no row-count validation possible).
            cache_file = _simple_bm25_cache_file(model, chunk_mode)
            if os.path.isfile(cache_file):
                print(f"[hit] {model} | {chunk_mode} | file={cache_file}")
                continue
            t0 = time.perf_counter()
            payload = _build_bm25_payload(pool)
            with open(cache_file, "wb") as f:
                pickle.dump(payload, f)
            print(
                f"[saved] {model} | {chunk_mode} | n_docs={payload['n_docs']} | "
                f"time={time.perf_counter() - t0:.1f}s | file={cache_file}"
            )
            continue
        # Dense path: a cached matrix is only accepted if its row count matches.
        cache_file = _simple_npy_cache_file(model, chunk_mode)
        cached = _load_npy(cache_file, expected_rows=len(texts))
        if cached is not None:
            print(f"[hit] {model} | {chunk_mode} | docs={cached.shape} | file={cache_file}")
            continue
        t0 = time.perf_counter()
        if model in OPENAI_MODELS:
            emb = _encode_openai_texts(
                texts=texts,
                model_name=model,
                api_key=api_key,
                base_url=base_url,
                batch_size=batch_size,
                desc=f"{model} [{chunk_mode}]",
            )
        elif model in QWEN_MODELS:
            if model not in qwen_model_cache:
                qwen_model_cache[model] = _get_qwen_model(model_name=model, device=device)
            emb = _encode_qwen_texts(
                texts=texts,
                model=qwen_model_cache[model],
                batch_size=batch_size,
                desc=f"{model} [{chunk_mode}]",
            )
        else:
            raise ValueError(f"Unsupported model: {model}")
        if emb.shape[0] != len(texts):
            # Guard against a partial/misaligned embedding result before caching.
            raise RuntimeError(f"Embedding row mismatch: expected {len(texts)}, got {emb.shape[0]}")
        _save_npy(cache_file, emb)
        print(
            f"[saved] {model} | {chunk_mode} | docs={emb.shape} | "
            f"time={time.perf_counter() - t0:.1f}s | file={cache_file}"
        )
def _build_benchmark_for_chunk_mode(
    models: List[str],
    chunk_mode: str,
    api_key: str,
    base_url: str,
    batch_size: int,
    device: str,
) -> None:
    """Build (or reuse) benchmark query-embedding caches for one chunk mode.

    BM25 is skipped: it scores raw tokens at query time and needs no cached
    query vectors. For every dense model the union of single/multi benchmark
    questions is embedded once and stored in an .npy cache keyed by
    (model, chunk_mode).

    Fix: the OpenAI and Qwen branches previously duplicated the identical
    save-and-print tail; it is now factored after the branch, matching the
    structure of _build_for_chunk_mode. All printed messages are unchanged.
    """
    questions = _collect_benchmark_questions(chunk_mode)
    if not questions:
        print(f"[skip] benchmark queries chunk_mode={chunk_mode}: no questions")
        return
    qwen_model_cache = {}  # model name -> loaded model, so each loads only once
    print(f"[info] benchmark queries chunk_mode={chunk_mode}, count={len(questions)}")
    for model in models:
        if model == BM25_MODEL:
            continue
        q_cache = _simple_query_npy_cache_file(model, chunk_mode)
        # Reuse the cache only when its row count matches the question list.
        qry_cached = _load_npy(q_cache, expected_rows=len(questions))
        if qry_cached is not None:
            print(f"[hit] {model} | {chunk_mode} | benchmark queries={qry_cached.shape} | file={q_cache}")
            continue
        t0 = time.perf_counter()
        if model in OPENAI_MODELS:
            qry_emb = _encode_openai_texts(
                texts=questions,
                model_name=model,
                api_key=api_key,
                base_url=base_url,
                batch_size=batch_size,
                desc=f"{model} queries [{chunk_mode}]",
            )
        elif model in QWEN_MODELS:
            if model not in qwen_model_cache:
                qwen_model_cache[model] = _get_qwen_model(model_name=model, device=device)
            qry_emb = _encode_qwen_texts(
                texts=questions,
                model=qwen_model_cache[model],
                batch_size=batch_size,
                desc=f"{model} queries [{chunk_mode}]",
            )
        else:
            raise ValueError(f"Unsupported model: {model}")
        _save_npy(q_cache, qry_emb)
        print(
            f"[saved] {model} | {chunk_mode} | benchmark queries={qry_emb.shape} | "
            f"time={time.perf_counter() - t0:.1f}s | file={q_cache}"
        )
def main() -> None:
    """CLI entry point: parse arguments, then build document caches (and,
    unless skipped, benchmark query caches) for each requested chunk mode."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--models", default="all", help="Comma list or 'all'")
    parser.add_argument("--chunk-mode", default="structure", help="length|structure|all")
    parser.add_argument("--api-key", default="", help="Optional. Falls back to OPENAI_API_KEY")
    parser.add_argument("--base-url", default=OPENAI_EMBED_BASE_URL, help="OpenAI-compatible base URL")
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--device", default="cuda", help="cuda|cpu for local Qwen models")
    parser.add_argument("--probe", action="store_true", help="Run one-text probe for API models")
    parser.add_argument("--probe-text", default="Climate-related financial disclosure under IFRS S2.")
    parser.add_argument("--skip-benchmark", action="store_true", help="Skip building benchmark single/multi query caches")
    args = parser.parse_args()
    models = _parse_models(args.models)
    chunk_modes = _parse_chunk_modes(args.chunk_mode)
    # Blank --base-url falls back to the configured default; trailing '/' stripped.
    base_url = (str(args.base_url or "").strip() or OPENAI_EMBED_BASE_URL).rstrip("/")
    # An API key is only mandatory when at least one OpenAI model is requested.
    needs_api = any(m in OPENAI_MODELS for m in models)
    api_key = _resolve_api_key(args.api_key)
    if needs_api and not api_key:
        raise RuntimeError("Missing API key for OpenAI embedding models. Use --api-key or set OPENAI_API_KEY.")
    print(f"Models: {models}")
    print(f"Chunk modes: {chunk_modes}")
    print(f"Base URL: {base_url}")
    print(f"Batch size: {args.batch_size}")
    print(f"Device: {args.device}")
    print(f"Build benchmark caches: {not bool(args.skip_benchmark)}")
    # Probe is a no-op when no API model is selected (nothing to probe).
    if args.probe and needs_api:
        _run_probe(models=models, api_key=api_key, base_url=base_url, text=args.probe_text)
    for chunk_mode in chunk_modes:
        _build_for_chunk_mode(
            models=models,
            chunk_mode=chunk_mode,
            api_key=api_key,
            base_url=base_url,
            batch_size=max(1, int(args.batch_size)),
            device=str(args.device).strip(),
        )
        if not bool(args.skip_benchmark):
            _build_benchmark_for_chunk_mode(
                models=models,
                chunk_mode=chunk_mode,
                api_key=api_key,
                base_url=base_url,
                batch_size=max(1, int(args.batch_size)),
                device=str(args.device).strip(),
            )
    print("Done.")


if __name__ == "__main__":
    main()