# Source: ClimateRAG_QA / Experiments / openai_embedding.py
# (Hugging Face page-header residue preserved as comments: uploaded by
#  tengfeiCheng, commit "add cleaned experiments code", revision 12323e1)
"""
Unified embedding cache builder for ReguRAG (single implementation).
Behavior:
- Only one cache policy: if cache exists, reuse it; otherwise build and save.
- Cache key only uses (model, chunk_mode).
- No target/doc-mode/overwrite mode split.
Benchmark query caches:
- To support offline benchmark evaluation, this script also builds benchmark caches
for single-doc and multi-doc question sets.
Supported models:
- BM25
- Qwen3-Embedding-0.6B
- Qwen3-Embedding-4B
- text-embedding-3-large
- text-embedding-3-small
- text-embedding-ada-002
Examples:
python openai_embedding.py --models all --chunk-mode all
python openai_embedding.py --models text-embedding-3-small --chunk-mode structure --base-url https://88996.cloud/v1
python openai_embedding.py --models Qwen3-Embedding-4B --chunk-mode structure --device cuda
"""
from __future__ import annotations
import argparse
import os
import pickle
import re
import sys
import time
from typing import Dict, List, Tuple
import numpy as np
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
if SCRIPT_DIR not in sys.path:
sys.path.insert(0, SCRIPT_DIR)
from rag_app_backend import (
EMBED_CACHE_DIR,
EMBED_MODEL_PATHS,
OPENAI_EMBED_BASE_URL,
OPENAI_EMBED_MODELS,
_build_chunk_pool,
_extract_embedding_vectors,
_resolve_api_key,
get_openai_client,
get_report_chunks,
)
# Optional dependency: tqdm progress bars. When tqdm is unavailable, fall back
# to a pass-through shim so call sites can keep the same iteration syntax.
try:
    from tqdm.auto import tqdm
except Exception:
    def tqdm(x, **kwargs):  # type: ignore  # shim: ignores desc/total/unit kwargs
        return x
# Sentinel model name for the lexical BM25 baseline (cached as a pickle, not .npy).
BM25_MODEL = "BM25"
# Local Qwen embedding models; weight directories come from EMBED_MODEL_PATHS.
QWEN_MODELS = {"Qwen3-Embedding-0.6B", "Qwen3-Embedding-4B"}
# API-served OpenAI-compatible embedding models, as configured in rag_app_backend.
OPENAI_MODELS = set(OPENAI_EMBED_MODELS)
# Full set accepted by --models; also the expansion order for --models=all.
ALL_MODELS = [
    BM25_MODEL,
    "Qwen3-Embedding-0.6B",
    "Qwen3-Embedding-4B",
    "text-embedding-3-large",
    "text-embedding-3-small",
    "text-embedding-ada-002",
]
# Benchmark question/corpus CSVs keyed by (chunk_mode, doc_mode); paths resolve
# relative to this script's directory.
BENCH_DATASET_PATHS = {
    ("length", "single"): os.path.join(SCRIPT_DIR, "..", "OCR_Chunked_Annotated", "ocr_chunks_annotated.csv"),
    ("structure", "single"): os.path.join(SCRIPT_DIR, "..", "OCR_Chunked_Annotated_structure", "ocr_chunks_annotated_structure.csv"),
    ("length", "multi"): os.path.join(SCRIPT_DIR, "..", "OCR_Chunked_Annotated_cross", "ocr_chunks_annotated_length_multi.csv"),
    ("structure", "multi"): os.path.join(SCRIPT_DIR, "..", "OCR_Chunked_Annotated_structure_cross", "ocr_chunks_annotated_structure_multi.csv"),
}
def _sanitize_name(text: str) -> str:
return "".join(ch if ch.isalnum() or ch in ("-", "_", ".") else "_" for ch in str(text))
def _simple_npy_cache_file(model_name: str, chunk_mode: str) -> str:
    """Return the document-embedding ``.npy`` cache path for (model, chunk_mode).

    Ensures the cache directory exists before returning the path.
    """
    os.makedirs(EMBED_CACHE_DIR, exist_ok=True)
    stem = f"{_sanitize_name(model_name)}__{_sanitize_name(chunk_mode)}"
    return os.path.join(EMBED_CACHE_DIR, stem + ".npy")
def _simple_bm25_cache_file(model_name: str, chunk_mode: str) -> str:
    """Return the pickled BM25 cache path for (model, chunk_mode).

    Ensures the cache directory exists before returning the path.
    """
    os.makedirs(EMBED_CACHE_DIR, exist_ok=True)
    stem = f"{_sanitize_name(model_name)}__{_sanitize_name(chunk_mode)}"
    return os.path.join(EMBED_CACHE_DIR, stem + ".pkl")
def _simple_query_npy_cache_file(model_name: str, chunk_mode: str) -> str:
    """Return the benchmark query-embedding ``.npy`` cache path for (model, chunk_mode).

    Ensures the cache directory exists before returning the path.
    """
    os.makedirs(EMBED_CACHE_DIR, exist_ok=True)
    stem = f"{_sanitize_name(model_name)}__{_sanitize_name(chunk_mode)}"
    return os.path.join(EMBED_CACHE_DIR, stem + "__query.npy")
def _normalize_rows(x: np.ndarray) -> np.ndarray:
if x.ndim != 2:
return x
norms = np.linalg.norm(x, axis=1, keepdims=True)
norms[norms == 0] = 1.0
return x / norms
def _tokenize(text: str) -> List[str]:
return re.findall(r"[a-z0-9]+", str(text or "").lower())
def _parse_models(raw: str) -> List[str]:
    """Parse the --models argument into a list of model names.

    "all" (case-insensitive) expands to ALL_MODELS; otherwise the value is a
    comma-separated list. Raises ValueError on an unknown name or when no
    usable name remains after stripping.
    """
    spec = str(raw or "all").strip()
    if spec.lower() == "all":
        return list(ALL_MODELS)
    selected: List[str] = []
    for name in (part.strip() for part in spec.split(",")):
        if not name:
            continue
        if name not in ALL_MODELS:
            raise ValueError(f"Unsupported model: {name}")
        selected.append(name)
    if not selected:
        raise ValueError("No valid models provided.")
    return selected
def _parse_chunk_modes(raw: str) -> List[str]:
s = str(raw or "structure").strip().lower()
if s == "all":
return ["length", "structure"]
if s in {"length", "structure"}:
return [s]
raise ValueError("chunk-mode must be one of: length, structure, all")
def _load_npy(cache_file: str, expected_rows: int):
try:
if not os.path.isfile(cache_file):
return None
arr = np.load(cache_file)
if arr.ndim != 2 or arr.shape[0] != expected_rows:
return None
return arr
except Exception:
return None
def _save_npy(cache_file: str, arr: np.ndarray) -> None:
np.save(cache_file, arr)
def _encode_openai_texts(
    texts: List[str],
    model_name: str,
    api_key: str,
    base_url: str,
    batch_size: int,
    desc: str,
) -> np.ndarray:
    """Embed *texts* through an OpenAI-compatible endpoint and L2-normalize the rows.

    Requests are issued in batches of at least 1; None entries are sent as "".
    An empty input yields a (0, 1) float32 array. Raises RuntimeError when the
    collected vectors do not form a 2-D matrix.
    """
    if not texts:
        return np.zeros((0, 1), dtype="float32")
    client = get_openai_client(api_key=api_key, base_url=base_url)
    step = max(1, int(batch_size))
    n_batches = (len(texts) + step - 1) // step
    collected = []
    for start in tqdm(range(0, len(texts), step), total=n_batches, desc=desc, unit="batch"):
        payload = [str(t or "") for t in texts[start : start + step]]
        response = client.embeddings.create(model=model_name, input=payload)
        collected.extend(_extract_embedding_vectors(response))
    matrix = np.asarray(collected, dtype="float32")
    if matrix.ndim != 2:
        raise RuntimeError(f"OpenAI embedding output shape invalid: {matrix.shape}")
    return _normalize_rows(matrix)
def _get_qwen_model(model_name: str, device: str):
    """Instantiate a local Qwen3 embedding model on *device*.

    Raises ValueError for a name missing from EMBED_MODEL_PATHS and
    FileNotFoundError when the configured weights directory does not exist.
    """
    if model_name not in EMBED_MODEL_PATHS:
        raise ValueError(f"Unknown local Qwen model: {model_name}")
    weights_dir = EMBED_MODEL_PATHS[model_name]
    if not os.path.isdir(weights_dir):
        raise FileNotFoundError(f"Model path not found: {weights_dir}")
    # Keep transformers from importing TensorFlow alongside the model.
    os.environ.setdefault("TRANSFORMERS_NO_TF", "1")
    from create_embedding_search_results_qwen import Qwen3EmbeddingModel
    return Qwen3EmbeddingModel(weights_dir, device=device)
def _encode_qwen_texts(
    texts: List[str],
    model,
    batch_size: int,
    desc: str,
) -> np.ndarray:
    """Embed *texts* with a local Qwen model and L2-normalize the rows.

    The input is split into slices of at least 1 item, each encoded via
    ``model.encode_documents``. An empty input yields a (0, 1) float32 array.
    """
    if not texts:
        return np.zeros((0, 1), dtype="float32")
    step = max(1, int(batch_size))
    slices = [texts[start : start + step] for start in range(0, len(texts), step)]
    pieces = []
    for chunk in tqdm(slices, total=len(slices), desc=desc, unit="batch"):
        encoded = model.encode_documents(chunk, batch_size=step)
        pieces.append(np.asarray(encoded, dtype="float32"))
    stacked = np.vstack(pieces) if pieces else np.zeros((0, 1), dtype="float32")
    return _normalize_rows(stacked)
def _build_bm25_payload(pool: List[Tuple[str, int, str]]) -> Dict:
    """Build the pickled BM25 cache payload for a chunk pool.

    The payload carries the fitted BM25Okapi index plus parallel report /
    chunk-index lists so scores can be mapped back to their chunks. Raises
    RuntimeError when the optional rank_bm25 dependency is missing.
    """
    try:
        from rank_bm25 import BM25Okapi
    except Exception as e:
        raise RuntimeError("BM25 requires rank_bm25. Install with: pip install rank_bm25") from e
    corpus_tokens = [_tokenize(chunk_text) for _, _, chunk_text in pool]
    index = BM25Okapi(corpus_tokens)
    return {
        "bm25": index,
        "report": [report for report, _, _ in pool],
        "chunk_idx": [int(idx) for _, idx, _ in pool],
        "n_docs": len(pool),
    }
def _load_benchmark_corpus(chunk_mode: str, doc_mode: str) -> Tuple[List[Tuple[str, int, str]], List[str]]:
    """Load one benchmark CSV and return (chunk pool, sorted unique questions).

    The pool is a deduplicated list of (report, chunk_idx, chunk_text) tuples
    sorted by report then chunk index. Raises FileNotFoundError when the CSV is
    missing and ValueError when required columns are absent.
    """
    import pandas as pd
    path = BENCH_DATASET_PATHS[(chunk_mode, doc_mode)]
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Benchmark corpus not found: {path}")
    df = pd.read_csv(path)
    need_cols = {"report", "chunk_idx", "chunk_text", "question"}
    miss = [c for c in need_cols if c not in df.columns]
    if miss:
        raise ValueError(f"Missing columns in benchmark corpus {path}: {miss}")
    chunks = df[["report", "chunk_idx", "chunk_text"]].drop_duplicates().copy()
    chunks["report"] = chunks["report"].astype(str)
    chunks["chunk_idx"] = chunks["chunk_idx"].astype(int)
    chunks["chunk_text"] = chunks["chunk_text"].astype(str)
    chunks = chunks.sort_values(["report", "chunk_idx"], ascending=[True, True])
    pool = [
        (row.report, int(row.chunk_idx), str(row.chunk_text))
        for row in chunks.itertuples(index=False)
    ]
    questions = sorted(df["question"].astype(str).drop_duplicates().tolist())
    return pool, questions
def _collect_benchmark_questions(chunk_mode: str) -> List[str]:
    """Return the sorted union of single- and multi-doc benchmark questions for *chunk_mode*."""
    unique = {
        str(question)
        for doc_mode in ("single", "multi")
        for question in _load_benchmark_corpus(chunk_mode, doc_mode)[1]
    }
    return sorted(unique)
def _run_probe(models: List[str], api_key: str, base_url: str, text: str) -> None:
    """Send one probe *text* to each selected API model and print dimension + leading values.

    Models outside OPENAI_MODELS are ignored; a response with no vectors prints nothing.
    """
    print("=== Probe API embedding responses ===")
    client = get_openai_client(api_key=api_key, base_url=base_url)
    api_models = [m for m in models if m in OPENAI_MODELS]
    for model in api_models:
        resp = client.embeddings.create(model=model, input=text)
        vectors = _extract_embedding_vectors(resp)
        if not vectors:
            continue
        head = vectors[0]
        print(f"[probe] model={model}, dim={len(head)}, first5={head[:5]}")
def _build_for_chunk_mode(
    models: List[str],
    chunk_mode: str,
    api_key: str,
    base_url: str,
    batch_size: int,
    device: str,
) -> None:
    """Build (or reuse) the document-side caches for every model in *models*.

    BM25 is cached as a pickled payload; embedding models are cached as a
    row-normalized .npy matrix aligned with the chunk-pool order. Caches that
    already exist (and, for .npy, match the pool size) are reused. Raises
    ValueError for a model outside the supported sets.
    """
    report_chunks = get_report_chunks(chunk_mode)
    pool = _build_chunk_pool(report_chunks)
    if not pool:
        print(f"[skip] chunk_mode={chunk_mode}: empty chunk pool")
        return
    texts = [entry[2] for entry in pool]
    loaded_qwen = {}  # local Qwen models are expensive to load; keep one per name
    for model in models:
        if model == BM25_MODEL:
            pkl_file = _simple_bm25_cache_file(model, chunk_mode)
            if os.path.isfile(pkl_file):
                print(f"[hit] {model} | {chunk_mode} | file={pkl_file}")
                continue
            started = time.perf_counter()
            payload = _build_bm25_payload(pool)
            with open(pkl_file, "wb") as f:
                pickle.dump(payload, f)
            print(
                f"[saved] {model} | {chunk_mode} | n_docs={payload['n_docs']} | "
                f"time={time.perf_counter() - started:.1f}s | file={pkl_file}"
            )
            continue
        npy_file = _simple_npy_cache_file(model, chunk_mode)
        existing = _load_npy(npy_file, expected_rows=len(texts))
        if existing is not None:
            print(f"[hit] {model} | {chunk_mode} | docs={existing.shape} | file={npy_file}")
            continue
        started = time.perf_counter()
        if model in OPENAI_MODELS:
            emb = _encode_openai_texts(
                texts=texts,
                model_name=model,
                api_key=api_key,
                base_url=base_url,
                batch_size=batch_size,
                desc=f"{model} [{chunk_mode}]",
            )
        elif model in QWEN_MODELS:
            if model not in loaded_qwen:
                loaded_qwen[model] = _get_qwen_model(model_name=model, device=device)
            emb = _encode_qwen_texts(
                texts=texts,
                model=loaded_qwen[model],
                batch_size=batch_size,
                desc=f"{model} [{chunk_mode}]",
            )
        else:
            raise ValueError(f"Unsupported model: {model}")
        # Guard against caching a matrix misaligned with the chunk pool.
        if emb.shape[0] != len(texts):
            raise RuntimeError(f"Embedding row mismatch: expected {len(texts)}, got {emb.shape[0]}")
        _save_npy(npy_file, emb)
        print(
            f"[saved] {model} | {chunk_mode} | docs={emb.shape} | "
            f"time={time.perf_counter() - started:.1f}s | file={npy_file}"
        )
def _build_benchmark_for_chunk_mode(
    models: List[str],
    chunk_mode: str,
    api_key: str,
    base_url: str,
    batch_size: int,
    device: str,
) -> None:
    """Build (or reuse) the benchmark query-embedding caches for *models*.

    Collects the union of single- and multi-doc benchmark questions for
    *chunk_mode* and embeds them once per embedding model; BM25 scores raw
    token queries and needs no cache, so it is skipped. Existing caches with
    the expected row count are reused.

    Raises:
        ValueError: if a model is neither BM25, an OpenAI model, nor a local
            Qwen model.
        RuntimeError: if the produced matrix row count does not match the
            question count (mirrors the check in _build_for_chunk_mode).
    """
    questions = _collect_benchmark_questions(chunk_mode)
    if not questions:
        print(f"[skip] benchmark queries chunk_mode={chunk_mode}: no questions")
        return
    qwen_model_cache = {}
    print(f"[info] benchmark queries chunk_mode={chunk_mode}, count={len(questions)}")
    for model in models:
        if model == BM25_MODEL:
            continue
        q_cache = _simple_query_npy_cache_file(model, chunk_mode)
        qry_cached = _load_npy(q_cache, expected_rows=len(questions))
        if qry_cached is not None:
            print(f"[hit] {model} | {chunk_mode} | benchmark queries={qry_cached.shape} | file={q_cache}")
            continue
        if model in OPENAI_MODELS:
            t0 = time.perf_counter()
            qry_emb = _encode_openai_texts(
                texts=questions,
                model_name=model,
                api_key=api_key,
                base_url=base_url,
                batch_size=batch_size,
                desc=f"{model} queries [{chunk_mode}]",
            )
        elif model in QWEN_MODELS:
            if model not in qwen_model_cache:
                qwen_model_cache[model] = _get_qwen_model(model_name=model, device=device)
            # Timer starts after model load so timings reflect encoding only.
            t0 = time.perf_counter()
            qry_emb = _encode_qwen_texts(
                texts=questions,
                model=qwen_model_cache[model],
                batch_size=batch_size,
                desc=f"{model} queries [{chunk_mode}]",
            )
        else:
            raise ValueError(f"Unsupported model: {model}")
        # Sanity check (consistent with _build_for_chunk_mode): never cache a
        # matrix misaligned with the question list.
        if qry_emb.shape[0] != len(questions):
            raise RuntimeError(f"Embedding row mismatch: expected {len(questions)}, got {qry_emb.shape[0]}")
        # Shared save/report tail (was duplicated in both branches).
        _save_npy(q_cache, qry_emb)
        print(
            f"[saved] {model} | {chunk_mode} | benchmark queries={qry_emb.shape} | "
            f"time={time.perf_counter() - t0:.1f}s | file={q_cache}"
        )
def main() -> None:
    """CLI entry point: parse arguments, optionally probe, then build all caches."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--models", default="all", help="Comma list or 'all'")
    parser.add_argument("--chunk-mode", default="structure", help="length|structure|all")
    parser.add_argument("--api-key", default="", help="Optional. Falls back to OPENAI_API_KEY")
    parser.add_argument("--base-url", default=OPENAI_EMBED_BASE_URL, help="OpenAI-compatible base URL")
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--device", default="cuda", help="cuda|cpu for local Qwen models")
    parser.add_argument("--probe", action="store_true", help="Run one-text probe for API models")
    parser.add_argument("--probe-text", default="Climate-related financial disclosure under IFRS S2.")
    parser.add_argument("--skip-benchmark", action="store_true", help="Skip building benchmark single/multi query caches")
    args = parser.parse_args()

    selected_models = _parse_models(args.models)
    selected_modes = _parse_chunk_modes(args.chunk_mode)
    base_url = (str(args.base_url or "").strip() or OPENAI_EMBED_BASE_URL).rstrip("/")
    needs_api = any(m in OPENAI_MODELS for m in selected_models)
    api_key = _resolve_api_key(args.api_key)
    if needs_api and not api_key:
        raise RuntimeError("Missing API key for OpenAI embedding models. Use --api-key or set OPENAI_API_KEY.")

    build_benchmark = not bool(args.skip_benchmark)
    batch = max(1, int(args.batch_size))
    device = str(args.device).strip()
    print(f"Models: {selected_models}")
    print(f"Chunk modes: {selected_modes}")
    print(f"Base URL: {base_url}")
    print(f"Batch size: {args.batch_size}")
    print(f"Device: {args.device}")
    print(f"Build benchmark caches: {build_benchmark}")

    if args.probe and needs_api:
        _run_probe(models=selected_models, api_key=api_key, base_url=base_url, text=args.probe_text)

    for mode in selected_modes:
        _build_for_chunk_mode(
            models=selected_models,
            chunk_mode=mode,
            api_key=api_key,
            base_url=base_url,
            batch_size=batch,
            device=device,
        )
        if build_benchmark:
            _build_benchmark_for_chunk_mode(
                models=selected_models,
                chunk_mode=mode,
                api_key=api_key,
                base_url=base_url,
                batch_size=batch,
                device=device,
            )
    print("Done.")


if __name__ == "__main__":
    main()