# ClimateRAG_QA / Experiments / rag_app_backend.py
# (Hugging Face page header residue — uploader "tengfeiCheng", commit 03cb531 —
# converted to a comment so the module remains valid Python.)
"""
RAG Backend — global model loading + instant retrieve / generate functions.
On server (GPU available):
MOCK_MODE is auto-disabled; real models are loaded once at first call.
On local dev (no GPU / no model files):
MOCK_MODE is auto-enabled; placeholder results are returned so the UI works.
Heavy dependencies (torch, sentence_transformers, vllm) are only imported
when MOCK_MODE is False, so the Gradio UI can start locally without them.
"""
import os
import re
import sys
import json
import hashlib
import pickle
import multiprocessing as mp
from pathlib import Path
from typing import List, Dict, Any
# Make sibling modules in this Experiments/ directory importable even when the
# file is executed as a script rather than a package member.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
if SCRIPT_DIR not in sys.path:
    sys.path.insert(0, SCRIPT_DIR)
# NOTE(review): "spawn" is presumably required for vLLM workers with CUDA —
# confirm; setdefault/force=False keep any start method set earlier.
os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
try:
    mp.set_start_method("spawn", force=False)
except Exception:
    pass  # start method already fixed by another component; keep it
# ======================== Config ========================
# Local embedding model checkpoints, relative to this Experiments/ directory.
EMBED_MODEL_PATHS = {
    "Qwen3-Embedding-0.6B": os.path.join(SCRIPT_DIR, "..", "model", "Qwen3-Embedding-0.6B"),
    "Qwen3-Embedding-4B": os.path.join(SCRIPT_DIR, "..", "model", "Qwen3-Embedding-4B"),
}
# Hosted embedding models served through the OpenAI embeddings API.
OPENAI_EMBED_MODELS = {
    "text-embedding-3-large",
    "text-embedding-3-small",
    "text-embedding-ada-002",
}
OPENAI_EMBED_BASE_URL = os.getenv("OPENAI_EMBED_BASE_URL", "https://api.openai.com/v1")
OPENAI_CHAT_MODEL = os.getenv("OPENAI_CHAT_MODEL", "gpt-5-mini")
# UI label -> API model id for hosted generation models.
API_GEN_MODEL_ALIASES = {
    "GPT-5 (API)": "gpt-5",
    "GPT-5-mini (API)": "gpt-5-mini",
    "GPT-4o (API)": "gpt-4o",
    "Gemini-2.5-Flash (API)": "gemini-2.5-flash",
    "Gemini-3-Flash-Preview (API)": "gemini-3-flash-preview",
}
# Local generation model served through vLLM.
GEN_MODEL_PATH = os.path.join(SCRIPT_DIR, "..", "model", "Qwen3-4B-Instruct-2507-FP8")
# Data directories: source reports, MinerU OCR output, on-disk embedding cache.
REPORTS_DIR = os.path.join(SCRIPT_DIR, "..", "Reports")
MINERU_ABS = os.path.join(SCRIPT_DIR, "..", "MinerU_Reports")
EMBED_CACHE_DIR = os.path.join(SCRIPT_DIR, "Embedding_Cache")
# Pre-chunked OCR corpora: fixed-length chunks vs structure-aware chunks.
OCR_CHUNKS_ALL_LENGTH_JSON = os.path.join(
    SCRIPT_DIR, "..", "OCR_Chunked_Annotated", "ocr_chunks_all.json"
)
OCR_CHUNKS_ALL_STRUCTURE_JSON = os.path.join(
    SCRIPT_DIR, "..", "OCR_Chunked_Annotated_structure", "ocr_chunks_all_structure.json"
)
# Chunking parameters — units presumably tokens; confirm against the chunker.
CHUNK_SIZE = 350
CHUNK_OVERLAP = 50
STRUCTURE_MAX_TOKENS = 550
# ======================== Auto-detect mock mode ========================
def _detect_gpu() -> bool:
try:
import torch
return bool(torch.cuda.is_available())
except Exception:
return False
# Probed once at import time; drives the MOCK_MODE decision below.
HAS_GPU = _detect_gpu()
def _check_mock_mode() -> bool:
    """Decide whether placeholder (mock) results should be served.

    Mock mode is enabled when no GPU was detected, or when the first
    configured embedding model directory is missing from disk.
    """
    if not HAS_GPU:
        return True
    first_model_dir = list(EMBED_MODEL_PATHS.values())[0]
    return not os.path.isdir(first_model_dir)
# Computed once at import; every model-loading path below checks this flag.
MOCK_MODE = _check_mock_mode()
# ======================== Global State ========================
# Process-wide caches so heavy resources are loaded at most once per worker.
_embed_models: Dict[str, object] = {}  # embedding model name -> loaded model
_llm_model = None  # vLLM engine singleton (loaded lazily in _get_llm)
_sampling_params = None  # default SamplingParams paired with _llm_model
_report_chunks_cache: Dict[str, Dict[str, list]] = {}  # presumably chunk_mode -> report -> chunks; confirm at call sites
_chunk_embeddings_cache: Dict[str, object] = {}  # cache key -> embedding matrix
_bm25_cache: Dict[str, object] = {}  # "BM25__<chunk_mode>" -> BM25 index
def _lazy_import_heavy():
    """Import the heavy embedding/search module only when it is actually needed."""
    from importlib import import_module
    return import_module("create_embedding_search_results_qwen")
def _get_embed_model(embed_name: str):
    """Return a cached Qwen embedding model, loading it on first use.

    Returns None in mock mode. Raises FileNotFoundError when the model
    directory is missing, RuntimeError when weight files look corrupted.
    """
    if MOCK_MODE:
        return None
    if embed_name not in _embed_models:
        path = EMBED_MODEL_PATHS.get(embed_name)
        if path is None or not os.path.isdir(path):
            raise FileNotFoundError(f"Embedding model not found: {path}")
        # Zero-byte .safetensors files indicate an interrupted download/copy;
        # fail early with an actionable message instead of a cryptic load error.
        safe_files = list(Path(path).rglob("*.safetensors"))
        if safe_files and any(f.stat().st_size == 0 for f in safe_files):
            bad = [str(f) for f in safe_files if f.stat().st_size == 0][:3]
            raise RuntimeError(
                "Detected empty .safetensors file(s) in embedding model directory. "
                f"Please re-download model files. Examples: {bad}"
            )
        mod = _lazy_import_heavy()
        try:
            _embed_models[embed_name] = mod.Qwen3EmbeddingModel(path, device="cuda")
        except Exception as e:
            msg = str(e)
            # These substrings match safetensors deserialization failures.
            if "Error while deserializing header" in msg or "invalid JSON in header" in msg:
                raise RuntimeError(
                    "Failed to load local Qwen embedding model. "
                    "This usually means corrupted/incomplete .safetensors files or interrupted copy. "
                    f"Model path: {path}"
                ) from e
            raise
    return _embed_models[embed_name]
def _get_llm():
    """Return the (LLM, SamplingParams) singletons, loading vLLM on first call.

    Returns (None, None) in mock mode. Sampling is deterministic
    (temperature=0.0, top_p=1.0) with a 2048-token output cap.
    """
    global _llm_model, _sampling_params
    if MOCK_MODE:
        return None, None
    if _llm_model is None:
        # vLLM is imported lazily so the UI can start without GPU dependencies.
        from vllm import LLM, SamplingParams
        _llm_model = LLM(model=GEN_MODEL_PATH, max_model_len=16384,
                         dtype="auto", trust_remote_code=True, gpu_memory_utilization=0.6,)
        _sampling_params = SamplingParams(temperature=0.0, max_tokens=2048, top_p=1.0)
    return _llm_model, _sampling_params
def _sanitize_name(text: str) -> str:
return "".join(ch if ch.isalnum() or ch in ("-", "_", ".") else "_" for ch in str(text))
def _build_chunk_pool(report_chunks: Dict[str, list]) -> List[tuple]:
pool = []
for report in sorted(report_chunks.keys()):
chunks = report_chunks.get(report, [])
for idx, item in enumerate(chunks):
chunk_idx = idx
text = item
page = None
if isinstance(item, (tuple, list)) and len(item) >= 2:
try:
chunk_idx = int(item[0])
except Exception:
chunk_idx = idx
text = item[1]
if len(item) >= 3:
page = item[2]
elif isinstance(item, dict):
try:
chunk_idx = int(item.get("chunk_idx", idx))
except Exception:
chunk_idx = idx
text = item.get("chunk_text", item.get("text", ""))
page = item.get("page", item.get("page_start", None))
try:
page = int(page) if page not in (None, "", "NA") else None
except Exception:
page = None
pool.append((report, chunk_idx, str(text or ""), page))
return pool
def _chunk_pool_signature(pool: List[tuple]) -> str:
h = hashlib.sha256()
for entry in pool:
report = entry[0]
idx = entry[1]
text = entry[2]
h.update(str(report).encode("utf-8", errors="ignore"))
h.update(b"\x1f")
h.update(str(idx).encode("utf-8", errors="ignore"))
h.update(b"\x1f")
h.update(str(text).encode("utf-8", errors="ignore"))
h.update(b"\x1e")
return h.hexdigest()[:16]
def _get_embed_cache_file(embed_name: str, chunk_mode: str, signature: str) -> str:
    """Path of the .npy embedding cache keyed by (model, chunk mode, pool signature)."""
    os.makedirs(EMBED_CACHE_DIR, exist_ok=True)
    stem = "__".join((_sanitize_name(embed_name), _sanitize_name(chunk_mode), signature))
    return os.path.join(EMBED_CACHE_DIR, stem + ".npy")
def _get_simple_embed_cache_file(embed_name: str, chunk_mode: str) -> str:
    """Cache path keyed only by (model, chunk_mode) — no pool signature."""
    os.makedirs(EMBED_CACHE_DIR, exist_ok=True)
    stem = "__".join((_sanitize_name(embed_name), _sanitize_name(chunk_mode)))
    return os.path.join(EMBED_CACHE_DIR, stem + ".npy")
def _get_simple_bm25_cache_file(chunk_mode: str) -> str:
    """Path of the pickled BM25 index cache keyed only by chunk mode."""
    os.makedirs(EMBED_CACHE_DIR, exist_ok=True)
    # _sanitize_name("BM25") is the identity, so the literal prefix is equivalent.
    return os.path.join(EMBED_CACHE_DIR, f"BM25__{_sanitize_name(chunk_mode)}.pkl")
def _tokenize_for_bm25(text: str) -> List[str]:
return re.findall(r"[a-z0-9]+", str(text or "").lower())
def _load_bm25_from_cache(cache_file: str, pool: List[tuple]):
    """Load a pickled BM25 index if it still matches the current chunk pool.

    Validates the cached document count and spot-checks the first 50
    (report, chunk_idx) pairs; returns None on any mismatch or error so the
    caller rebuilds the index.
    NOTE(review): pickle.load executes arbitrary code — acceptable only because
    this cache is produced locally by _save_bm25_cache; never load an
    untrusted file here.
    """
    try:
        if not os.path.isfile(cache_file):
            return None
        with open(cache_file, "rb") as f:
            payload = pickle.load(f)
        bm25 = payload.get("bm25", None) if isinstance(payload, dict) else None
        if bm25 is None:
            return None
        cached_n = int(payload.get("n_docs", -1)) if isinstance(payload, dict) else -1
        if cached_n != len(pool):
            return None
        cached_reports = payload.get("report", []) if isinstance(payload, dict) else []
        cached_idx = payload.get("chunk_idx", []) if isinstance(payload, dict) else []
        if len(cached_reports) == len(pool) and len(cached_idx) == len(pool):
            # Cheap alignment check: compare only the first 50 entries.
            for i in range(min(50, len(pool))):
                if str(cached_reports[i]) != str(pool[i][0]) or int(cached_idx[i]) != int(pool[i][1]):
                    return None
        return bm25
    except Exception:
        return None
def _build_bm25(pool: List[tuple]):
    """Build a BM25Okapi index over the chunk texts (third field of each entry)."""
    try:
        from rank_bm25 import BM25Okapi
    except Exception as e:
        raise RuntimeError("BM25 requires rank_bm25. Install with: pip install rank_bm25") from e
    corpus = [_tokenize_for_bm25(entry[2]) for entry in pool]
    return BM25Okapi(corpus)
def _save_bm25_cache(cache_file: str, bm25, pool: List[tuple]) -> None:
try:
payload = {
"bm25": bm25,
"report": [p[0] for p in pool],
"chunk_idx": [int(p[1]) for p in pool],
"n_docs": len(pool),
}
with open(cache_file, "wb") as f:
pickle.dump(payload, f)
except Exception:
pass
def _get_bm25(chunk_mode: str, pool: List[tuple]):
    """Return the BM25 index for chunk_mode, consulting memory then disk caches."""
    key = f"BM25__{chunk_mode}"
    cached = _bm25_cache.get(key)
    if cached is not None:
        return cached
    cache_file = _get_simple_bm25_cache_file(chunk_mode)
    bm25 = _load_bm25_from_cache(cache_file, pool)
    if bm25 is None:
        bm25 = _build_bm25(pool)
        _save_bm25_cache(cache_file, bm25, pool)
    _bm25_cache[key] = bm25
    return bm25
def _load_cached_embeddings(cache_file: str, expected_rows: int):
try:
if not os.path.isfile(cache_file):
return None
import numpy as np
arr = np.load(cache_file)
if arr.ndim != 2 or arr.shape[0] != expected_rows:
return None
return arr
except Exception:
return None
def _save_cached_embeddings(cache_file: str, embeddings) -> None:
try:
import numpy as np
np.save(cache_file, embeddings)
except Exception:
# Best-effort caching; retrieval should still work without disk persistence.
pass
def _resolve_api_key(api_key: str) -> str:
key = str(api_key or "").strip()
if key:
return key
env_key = str(os.getenv("OPENAI_API_KEY", "")).strip()
if env_key:
return env_key
return str(os.getenv("OPENAI_API_KEY_88996", "")).strip()
def _resolve_gemini_api_key(api_key: str) -> str:
env_key = str(os.getenv("GEMINI_API_KEY", "")).strip()
if env_key:
return env_key
return str(api_key or "").strip()
def _is_openai_embedding_model(embed_name: str) -> bool:
    """True when the (trimmed) name is one of the known OpenAI embedding models."""
    name = str(embed_name or "").strip()
    return name in OPENAI_EMBED_MODELS
def _is_api_generation_model(gen_model: str) -> bool:
    """Heuristic: a UI alias, or a name containing GPT/GEMINI/(API), means a hosted model."""
    name = str(gen_model or "").strip()
    if name in API_GEN_MODEL_ALIASES:
        return True
    upper = name.upper()
    return any(tag in upper for tag in ("GPT", "GEMINI", "(API)"))
def _resolve_api_generation_model_name(gen_model: str) -> str:
    """Map a UI alias to its API model id; otherwise the raw name, else the default."""
    name = str(gen_model or "").strip()
    return API_GEN_MODEL_ALIASES.get(name, name or OPENAI_CHAT_MODEL)
def _is_gemini_generation_model(gen_model: str) -> bool:
    """True when the resolved API model id refers to a Gemini model."""
    resolved = _resolve_api_generation_model_name(gen_model)
    return "gemini" in str(resolved or "").strip().lower()
def _extract_gemini_text(response: Any) -> str:
text = getattr(response, "text", None)
if isinstance(text, str) and text.strip():
return text.strip()
candidates = getattr(response, "candidates", None)
if isinstance(candidates, list):
out = []
for c in candidates:
content = getattr(c, "content", None)
parts = getattr(content, "parts", None) if content is not None else None
if not isinstance(parts, list):
continue
for p in parts:
t = getattr(p, "text", None)
if isinstance(t, str) and t.strip():
out.append(t.strip())
if out:
return "\n".join(out).strip()
return ""
def _call_openai_chat_model(
    model_name: str,
    system_msg: str,
    user_msg: str,
    api_key: str,
    max_tokens: int,
) -> str:
    """Run one system+user chat completion against the OpenAI API.

    Returns the model text, or a "\u26a0\ufe0f ..."-prefixed message string on any
    failure (missing key, missing package, API error) instead of raising.
    """
    resolved_key = _resolve_api_key(api_key)
    if not resolved_key:
        return "\u26a0\ufe0f Please provide an OpenAI API key to use OpenAI API generation models."
    try:
        import openai
    except ImportError:
        return "\u26a0\ufe0f The 'openai' package is not installed. Run: pip install openai"
    try:
        # Allow routing through an OpenAI-compatible gateway via OPENAI_BASE_URL.
        openai_base_url = str(os.getenv("OPENAI_BASE_URL", "")).strip()
        if openai_base_url:
            client = openai.OpenAI(api_key=resolved_key, base_url=openai_base_url.rstrip("/"))
        else:
            client = openai.OpenAI(api_key=resolved_key)
        response = client.chat.completions.create(
            model=str(model_name).strip(),
            messages=[
                {"role": "system", "content": str(system_msg or "")},
                {"role": "user", "content": str(user_msg or "")},
            ],
            temperature=0,  # deterministic decoding for evaluation runs
            max_tokens=int(max_tokens),
            timeout=120.0,
        )
        # NOTE(review): message.content can be None for some responses; the
        # resulting AttributeError is caught below and surfaced as an API error.
        return response.choices[0].message.content.strip()
    except Exception as e:
        return f"\u26a0\ufe0f API Error: {e}"
def _call_gemini_chat_model(
    model_name: str,
    system_msg: str,
    user_msg: str,
    api_key: str,
    max_tokens: int,
) -> str:
    """Run one prompt against the Gemini API via google-genai.

    System and user messages are folded into a single text prompt. Returns the
    model text, or a "\u26a0\ufe0f ..."-prefixed message string on any failure
    (missing key, missing package, API error, empty completion).
    """
    resolved_key = _resolve_gemini_api_key(api_key)
    if not resolved_key:
        return "\u26a0\ufe0f Please provide a Gemini API key (GEMINI_API_KEY) to use Gemini API generation models."
    try:
        from google import genai
    except ImportError:
        return "\u26a0\ufe0f The 'google-genai' package is not installed. Run: pip install google-genai"
    try:
        client = genai.Client(api_key=resolved_key)
        full_prompt = (
            f"System instruction:\n{str(system_msg or '').strip()}\n\nUser task:\n{str(user_msg or '').strip()}"
        ).strip()
        # Only pass a generation config when a positive token cap was requested.
        config = {"max_output_tokens": int(max_tokens)} if int(max_tokens) > 0 else None
        if config:
            response = client.models.generate_content(
                model=str(model_name).strip(),
                contents=full_prompt,
                config=config,
            )
        else:
            response = client.models.generate_content(
                model=str(model_name).strip(),
                contents=full_prompt,
            )
        text = _extract_gemini_text(response)
        if text:
            return text
        return "\u26a0\ufe0f API Error: Gemini returned empty content."
    except Exception as e:
        return f"\u26a0\ufe0f API Error: {e}"
def get_openai_client(api_key: str, base_url: str = OPENAI_EMBED_BASE_URL):
    """Build an OpenAI client from the resolved key.

    Raises RuntimeError when no key can be resolved or the openai package is
    missing. A falsy base_url falls back to the SDK default endpoint.
    """
    key = _resolve_api_key(api_key)
    if not key:
        raise RuntimeError("Missing API key.")
    try:
        from openai import OpenAI
    except Exception as e:
        raise RuntimeError(f"openai package is required for API embedding/generation: {e}") from e
    if base_url:
        return OpenAI(api_key=key, base_url=str(base_url).rstrip("/"))
    return OpenAI(api_key=key)
def get_embedding(text: str, model_name: str, client):
    """Embed one text via the OpenAI-compatible API and return its vector.

    Raises RuntimeError when the API response carries no vectors.
    """
    response = client.embeddings.create(model=model_name, input=text)
    vectors = _extract_embedding_vectors(response)
    if not vectors:
        raise RuntimeError("Embedding API returned no vectors.")
    return vectors[0]
def _extract_embedding_vectors(response: Any) -> List[List[float]]:
"""Handle SDK/object/dict/string responses from OpenAI-compatible gateways."""
payload: Any = response
if isinstance(payload, str):
try:
payload = json.loads(payload)
except Exception:
raise RuntimeError(
f"Unexpected embedding response string (not JSON): {payload[:240]}"
)
if hasattr(payload, "model_dump"):
try:
payload = payload.model_dump()
except Exception:
pass
data = None
if hasattr(payload, "data"):
data = payload.data
elif isinstance(payload, dict):
data = payload.get("data", None)
if data is None:
if isinstance(payload, list):
data = payload
else:
raise RuntimeError(
"Unexpected embedding response format: cannot find `data` field."
)
vectors: List[List[float]] = []
for item in data:
vec = None
if hasattr(item, "embedding"):
vec = item.embedding
elif isinstance(item, dict):
vec = item.get("embedding", None)
elif isinstance(item, list):
vec = item
if vec is None:
continue
if hasattr(vec, "tolist"):
vec = vec.tolist()
vectors.append([float(x) for x in vec])
return vectors
def _encode_openai_embeddings(
    texts: List[str],
    model_name: str,
    api_key: str,
    base_url: str = OPENAI_EMBED_BASE_URL,
    batch_size: int = 32,
    show_progress: bool = False,
    progress_desc: str = "Embedding",
):
    """Embed texts in batches through an OpenAI-compatible endpoint.

    Returns a float32 numpy array of shape (len(texts), dim); a (0, 1) array
    for empty input. Raises RuntimeError when the API yields no vectors.
    """
    import numpy as np
    if not texts:
        return np.zeros((0, 1), dtype="float32")
    client = get_openai_client(api_key=api_key, base_url=base_url)
    # Fix: compute the batch step once and reuse it for slicing (the original
    # recomputed max(1, int(batch_size)) on every batch).
    step = max(1, int(batch_size))
    starts = range(0, len(texts), step)
    if show_progress:
        try:
            from tqdm.auto import tqdm
            starts = tqdm(starts, total=(len(texts) + step - 1) // step, desc=progress_desc, unit="batch")
        except Exception:
            pass  # tqdm is optional; fall back to silent batching
    vectors: List[List[float]] = []
    for start in starts:
        batch = [str(t or "") for t in texts[start : start + step]]
        resp = client.embeddings.create(model=model_name, input=batch)
        vectors.extend(_extract_embedding_vectors(resp))
    if not vectors:
        raise RuntimeError("Embedding API returned empty vector list.")
    return np.asarray(vectors, dtype="float32")
def _safe_float(value: Any, default: float = 0.0) -> float:
try:
return float(value)
except Exception:
return default
def _clamp01(value: Any) -> float:
    """Clamp value into [0.0, 1.0] after safe float coercion (bad input -> 0.0)."""
    return min(1.0, max(0.0, _safe_float(value, default=0.0)))
def _mean(values: List[float], default: float = 0.0) -> float:
if not values:
return default
return float(sum(values) / len(values))
def _normalize_scores_unit_interval(values: List[float]) -> List[float]:
if not values:
return []
try:
import numpy as np
arr = np.asarray(values, dtype="float64")
finite = np.isfinite(arr)
if not finite.any():
return [0.0 for _ in values]
safe = arr.copy()
safe[~finite] = np.nan
lo = np.nanmin(safe)
hi = np.nanmax(safe)
if not np.isfinite(lo) or not np.isfinite(hi):
return [0.0 for _ in values]
if abs(hi - lo) <= 1e-12:
base = 1.0 if hi > 0 else 0.0
return [base for _ in values]
out = (safe - lo) / (hi - lo)
out = np.clip(out, 0.0, 1.0)
out[~finite] = 0.0
return [float(x) for x in out.tolist()]
except Exception:
return [0.0 for _ in values]
def _average_similarity(contexts: List[Dict]) -> float:
    """Mean of clamped `score` values across dict contexts; 0.0 when none exist."""
    scores = [_clamp01(ctx.get("score", 0.0)) for ctx in contexts if isinstance(ctx, dict)]
    return _mean(scores, default=0.0)
def _extract_json_payload(text: str):
raw = str(text or "").strip()
if not raw:
return None
try:
return json.loads(raw)
except Exception:
pass
for block in re.findall(r"```(?:json)?\s*([\s\S]*?)```", raw, flags=re.IGNORECASE):
block = block.strip()
if not block:
continue
try:
return json.loads(block)
except Exception:
continue
marker_match = re.search(r"final\s+answer\s+in\s+json\s*:\s*", raw, flags=re.IGNORECASE)
if marker_match:
tail = raw[marker_match.end():].strip()
decoder = json.JSONDecoder()
try:
obj, _ = decoder.raw_decode(tail)
return obj
except Exception:
pass
decoder = json.JSONDecoder()
candidates = []
for i, ch in enumerate(raw):
if ch not in "{[":
continue
try:
obj, end = decoder.raw_decode(raw[i:])
candidates.append((obj, end))
except Exception:
continue
if not candidates:
return None
candidates.sort(key=lambda x: x[1], reverse=True)
return candidates[0][0]
def _first_sentence(text: str) -> str:
raw = str(text or "").strip().replace("\n", " ")
if not raw:
return ""
m = re.split(r"(?<=[.!?])\s+", raw)
if m:
return m[0].strip()
return raw[:220].strip()
def extract_yes_no_label(text: str) -> str:
    """Extract a YES/NO verdict from model output.

    Order of checks: JSON `conclusion`/`answer` fields, a bracketed [YES]/[NO]
    tag, a leading YES/NO token, then the earliest standalone YES or NO word.
    Returns "UNKNOWN" when no verdict is found.
    """
    payload = _extract_json_payload(text)
    if isinstance(payload, dict):
        conclusion = payload.get("conclusion")
        if isinstance(conclusion, str):
            text = conclusion
        elif isinstance(payload.get("answer"), str):
            text = payload.get("answer")
    s = str(text or "").strip()
    m = re.search(r"\[(YES|NO)\]", s, flags=re.IGNORECASE)
    if m:
        return m.group(1).upper()
    m = re.match(r"\s*(YES|NO)\b", s, flags=re.IGNORECASE)
    if m:
        return m.group(1).upper()
    # Fix: the old substring find() matched "NO" inside words such as "not" /
    # "cannot" and "YES" inside "eyes", skewing verdicts. Use whole-word search;
    # re.search naturally returns the earliest match, preserving the original
    # "whichever token appears first wins" precedence.
    m = re.search(r"\b(YES|NO)\b", s, flags=re.IGNORECASE)
    if m:
        return m.group(1).upper()
    return "UNKNOWN"
def _run_model_prompt(
    prompt: str,
    gen_model: str,
    api_key: str = "",
    max_tokens: int = 1024,
) -> str:
    """Route a prompt to the right backend: hosted API (OpenAI/Gemini), mock,
    or the local vLLM engine with deterministic sampling."""
    if _is_api_generation_model(gen_model):
        chat_model = _resolve_api_generation_model_name(gen_model)
        system_msg = (
            "You are a strict JSON-oriented climate disclosure analysis assistant. "
            "Follow the user task exactly."
        )
        backend = (
            _call_gemini_chat_model
            if _is_gemini_generation_model(chat_model)
            else _call_openai_chat_model
        )
        return backend(
            model_name=chat_model,
            system_msg=system_msg,
            user_msg=prompt,
            api_key=api_key,
            max_tokens=int(max_tokens),
        )
    if MOCK_MODE:
        return _mock_model_prompt(prompt)
    llm, _ = _get_llm()
    from vllm import SamplingParams
    sampling = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=int(max_tokens))
    return llm.generate([prompt], sampling)[0].outputs[0].text.strip()
def _resolve_evidence_index(ref: Any, n_contexts: int):
if n_contexts <= 0:
return None
idx = None
if isinstance(ref, int):
idx = ref
elif isinstance(ref, str):
m = re.search(r"(\d+)", ref)
if m:
idx = int(m.group(1))
elif isinstance(ref, dict):
for key in ("id", "evidence_id", "index", "evidence_index"):
if key in ref:
return _resolve_evidence_index(ref.get(key), n_contexts)
if idx is None:
return None
if 1 <= idx <= n_contexts:
return idx - 1
if 0 <= idx < n_contexts:
return idx
return None
def _build_cluster_prompt(question: str, contexts: List[Dict]) -> str:
    """Assemble the evidence-clustering prompt: numbered [E*] evidence lines,
    naming/count rules, and a strict JSON output schema."""
    lines = []
    for i, c in enumerate(contexts, start=1):
        report = c.get("report", "unknown")
        chunk_idx = c.get("chunk_idx", "NA")
        score = _clamp01(c.get("score", 0.0))
        # Flatten evidence text; cap at 650 chars to bound prompt size.
        text = str(c.get("text", "")).replace("\n", " ").strip()
        lines.append(
            f"[E{i}] report={report}; paragraph={chunk_idx}; similarity={score:.4f}; text={text[:650]}"
        )
    schema = (
        '{"skill":"Evidence Clusterer","clusters":[{"cluster_name":"string","summary":"string",'
        '"evidence_ids":["E1"]}]}'
    )
    return (
        "__TASK__: EVIDENCE_CLUSTERING\n"
        "You are an Evidence Clusterer for trustworthy RAG.\n"
        "Task: cluster retrieved evidence by semantic theme and summarize each cluster.\n"
        "Rules: use only provided evidence; do not fabricate ids.\n"
        "Use clear semantic cluster names such as Governance & Oversight, Risk & Scenario Analysis, "
        "Emissions Metrics, Strategy & Transition Plan, Targets & Progress, Value Chain Engagement.\n"
        "Do not use generic names like Theme 1 / Cluster 2 / Semantic Cluster.\n"
        "Return between 2 and 5 clusters whenever evidence size allows (n>=4).\n"
        "If evidence count >= 4, do NOT return only one cluster.\n"
        f"Return strict JSON only with schema: {schema}\n\n"
        f"Question: {question}\n\n"
        "Evidence:\n"
        + "\n".join(lines)
    )
def _choose_cluster_k(n_contexts: int) -> int:
if n_contexts <= 3:
return max(1, n_contexts)
if n_contexts <= 6:
return 2
if n_contexts <= 12:
return 3
if n_contexts <= 18:
return 4
return 5
# Keyword-tuple -> human-readable label mapping used by _semantic_cluster_name.
# Order matters: the first entry whose keywords hit the seed text wins.
_SEMANTIC_CLUSTER_LABELS = [
    (("governance", "board", "oversight", "committee", "accountability"), "Governance & Oversight"),
    (("risk", "scenario", "ngfs", "ipcc", "stress test", "resilience", "transition risk", "physical risk"), "Risk & Scenario Analysis"),
    (("scope 1", "scope 2", "scope 3", "emission", "ghg", "carbon", "intensity", "kpi", "metric"), "Emissions Metrics"),
    (("target", "baseline", "attainment", "goal", "net zero", "roadmap"), "Targets & Progress"),
    (("strategy", "transition plan", "investment", "capex", "allocation"), "Strategy & Transition Plan"),
    (("supplier", "downstream", "upstream", "value chain", "partner", "procurement"), "Value Chain Engagement"),
    (("water", "waste", "recycl", "pollution", "withdrawal", "discharge"), "Waste & Water Management"),
    (("gri", "sasb", "tcfd", "ifrs", "csrd", "framework", "standard", "assurance", "audit"), "Disclosure Standards & Assurance"),
]
def _is_generic_cluster_name(name: str) -> bool:
low = str(name or "").strip().lower()
if not low:
return True
if any(x in low for x in ("semantic cluster", "general evidence", "theme", "cluster", "topic", "group")):
cleaned = re.sub(r"\s+", " ", low)
if re.fullmatch(r"(cluster|theme|topic|group)\s*[a-z0-9\- ]*", cleaned):
return True
if cleaned in {"general evidence", "semantic cluster"}:
return True
if "theme " in cleaned or "cluster " in cleaned:
return True
return False
def _semantic_cluster_name(seed_text: str, idx: int) -> str:
    """Pick the first semantic label whose keywords appear in the seed text;
    otherwise fall back to a numbered "Disclosure Focus" name."""
    haystack = str(seed_text or "").lower()
    for keywords, label in _SEMANTIC_CLUSTER_LABELS:
        if any(keyword in haystack for keyword in keywords):
            return label
    return f"Disclosure Focus {idx + 1}"
def _pick_cluster_name(raw_name: str, summary: str, evidence_text: str, idx: int) -> str:
    """Keep a specific, short model-proposed name (after stripping a leading
    "Cluster N -" prefix); otherwise derive one from summary + evidence text."""
    name = str(raw_name or "").strip()
    name = re.sub(r"^\s*cluster\s*[a-z0-9]+\s*[-:]\s*", "", name, flags=re.IGNORECASE).strip()
    if name and len(name) <= 64 and not _is_generic_cluster_name(name):
        return name
    seed = " ".join([str(summary or ""), str(evidence_text or "")]).strip()
    return _semantic_cluster_name(seed, idx)
def _dedupe_cluster_names(clusters: List[Dict]) -> List[Dict]:
seen = {}
for c in clusters:
base = str(c.get("cluster_name", "")).strip() or "Disclosure Focus"
count = seen.get(base, 0) + 1
seen[base] = count
if count > 1:
c["cluster_name"] = f"{base} ({count})"
return clusters
def _local_semantic_clusters(contexts: List[Dict], fallback_conf: float) -> List[Dict]:
    """Cluster contexts locally (no LLM): TF-IDF + KMeans with graceful fallbacks.

    Guarantees: one cluster for n<=3; 2..5 clusters for n>=4 (splitting in half
    when KMeans collapses to one, merging the tail when it exceeds five).
    `fallback_conf` is used as the confidence when no scores are available.
    """
    n = len(contexts)
    if n == 0:
        return []
    if n <= 3:
        # Keep one cluster for tiny evidence sets.
        tiny_summary = _first_sentence(" ".join(str(c.get("text", "")) for c in contexts)) or "Top retrieved evidence grouped into one cluster."
        tiny_name = _pick_cluster_name("", tiny_summary, tiny_summary, 0)
        return [
            {
                "cluster_name": tiny_name,
                "summary": tiny_summary,
                "cluster_confidence": fallback_conf,
                "evidence": [
                    {
                        "evidence_id": f"E{i + 1}",
                        "report": c.get("report", ""),
                        "chunk_idx": c.get("chunk_idx", ""),
                        "page": c.get("page", None),
                        "text": str(c.get("text", "")),
                        "score": _clamp01(c.get("score", 0.0)),
                    }
                    for i, c in enumerate(contexts)
                ],
            }
        ]
    k = _choose_cluster_k(n)
    texts = [str(c.get("text", "")) for c in contexts]
    labels = None
    try:
        # Preferred path: TF-IDF vectors + KMeans (deterministic via fixed seed).
        from sklearn.feature_extraction.text import TfidfVectorizer
        from sklearn.cluster import KMeans
        vec = TfidfVectorizer(max_features=1500, stop_words="english")
        x = vec.fit_transform(texts)
        km = KMeans(n_clusters=k, random_state=42, n_init=10)
        labels = km.fit_predict(x).tolist()
    except Exception:
        # fallback: simple hash bucketing to still produce 2-5 groups
        labels = [i % k for i in range(n)]
    by_label = {}
    for idx, lb in enumerate(labels):
        by_label.setdefault(lb, []).append(idx)
    clusters = []
    # Emit clusters largest-first so the most-populated theme leads.
    for lb, idxs in sorted(by_label.items(), key=lambda x: -len(x[1])):
        evidence = []
        for idx in idxs:
            c = contexts[idx]
            evidence.append(
                {
                    "evidence_id": f"E{idx + 1}",
                    "report": c.get("report", ""),
                    "chunk_idx": c.get("chunk_idx", ""),
                    "page": c.get("page", None),
                    "text": str(c.get("text", "")),
                    "score": _clamp01(c.get("score", 0.0)),
                }
            )
        avg_score = _mean([_clamp01(e.get("score", 0.0)) for e in evidence], default=fallback_conf)
        top_text = str(evidence[0].get("text", "")) if evidence else ""
        summary = _first_sentence(top_text) or "Semantically related evidence."
        joined_seed = " ".join(str(e.get("text", "")) for e in evidence[:3])
        clusters.append(
            {
                "cluster_name": _pick_cluster_name("", summary, joined_seed, len(clusters)),
                "summary": summary,
                "cluster_confidence": avg_score,
                "evidence": evidence,
            }
        )
    # enforce 2-5 when n>=4
    if len(clusters) < 2 and n >= 4:
        # KMeans collapsed everything into one group: split the contexts in half.
        mid = n // 2
        left = list(range(0, mid))
        right = list(range(mid, n))
        split = []
        for grp in (left, right):
            evidence = []
            for idx in grp:
                c = contexts[idx]
                evidence.append(
                    {
                        "evidence_id": f"E{idx + 1}",
                        "report": c.get("report", ""),
                        "chunk_idx": c.get("chunk_idx", ""),
                        "page": c.get("page", None),
                        "text": str(c.get("text", "")),
                        "score": _clamp01(c.get("score", 0.0)),
                    }
                )
            avg_score = _mean([_clamp01(e.get("score", 0.0)) for e in evidence], default=fallback_conf)
            summary = _first_sentence(str(evidence[0].get("text", ""))) if evidence else "Related evidence."
            joined_seed = " ".join(str(e.get("text", "")) for e in evidence[:3])
            split.append(
                {
                    "cluster_name": _pick_cluster_name("", summary or "", joined_seed, len(split)),
                    "summary": summary or "Related evidence.",
                    "cluster_confidence": avg_score,
                    "evidence": evidence,
                }
            )
        clusters = split
    if len(clusters) > 5:
        # Keep top-4, merge rest into cluster 5.
        kept = clusters[:4]
        rest = clusters[4:]
        merged_evidence = []
        for c in rest:
            merged_evidence.extend(c.get("evidence", []))
        merged_score = _mean([_clamp01(e.get("score", 0.0)) for e in merged_evidence], default=fallback_conf)
        merged_seed = " ".join(str(e.get("text", "")) for e in merged_evidence[:3])
        kept.append(
            {
                "cluster_name": _pick_cluster_name("", "Merged tail semantic groups.", merged_seed, 4),
                "summary": "Merged tail semantic groups.",
                "cluster_confidence": merged_score,
                "evidence": merged_evidence,
            }
        )
        clusters = kept
    return _dedupe_cluster_names(clusters)
def _build_single_fallback_cluster(contexts: List[Dict]) -> List[Dict]:
    """Wrap all contexts in one "Retrieved Evidence" cluster (clustering fallback)."""
    if not contexts:
        return []
    lead_text = " ".join(str(c.get("text", "")) for c in contexts[:3])
    summary = _first_sentence(lead_text) or "Top retrieved evidence cluster."
    evidence = [
        {
            "evidence_id": f"E{i + 1}",
            "report": ctx.get("report", ""),
            "chunk_idx": ctx.get("chunk_idx", ""),
            "page": ctx.get("page", None),
            "text": str(ctx.get("text", "")),
            "score": _clamp01(ctx.get("score", 0.0)),
        }
        for i, ctx in enumerate(contexts)
    ]
    return [{"cluster_name": "Retrieved Evidence", "summary": summary, "evidence": evidence}]
def _parse_cluster_result(output_text: str, contexts: List[Dict]) -> Dict[str, Any]:
    """Parse the LLM clustering output into normalized cluster dicts.

    Resolves evidence references back to `contexts`, repairs generic/duplicate
    cluster names, and falls back to local semantic clustering (n>=4) or a
    single catch-all cluster (n<4) when the model output is unusable.
    Returns {"clusters", "raw_output", "payload"}.
    """
    payload = _extract_json_payload(output_text)
    n_contexts = len(contexts)
    clusters = []
    if isinstance(payload, dict):
        raw_clusters = payload.get("clusters", [])
        if isinstance(raw_clusters, list):
            for i, rc in enumerate(raw_clusters):
                if not isinstance(rc, dict):
                    continue
                summary = str(rc.get("summary", "")).strip()
                # Accept either key the model may have used for references.
                refs = rc.get("evidence_ids", rc.get("evidence_indices", []))
                if not isinstance(refs, list):
                    refs = [refs] if refs not in (None, "") else []
                evidence_items = []
                used = set()
                for ref in refs:
                    idx = _resolve_evidence_index(ref, n_contexts)
                    if idx is None or idx in used:
                        continue
                    used.add(idx)
                    ctx = contexts[idx]
                    evidence_items.append({
                        "evidence_id": f"E{idx + 1}",
                        "report": ctx.get("report", ""),
                        "chunk_idx": ctx.get("chunk_idx", ""),
                        "page": ctx.get("page", None),
                        "text": str(ctx.get("text", "")),
                        "score": _clamp01(ctx.get("score", 0.0)),
                    })
                # A cluster without resolvable evidence borrows context #i so it
                # is never empty (when enough contexts exist).
                if not evidence_items and i < n_contexts:
                    ctx = contexts[i]
                    evidence_items.append({
                        "evidence_id": f"E{i + 1}",
                        "report": ctx.get("report", ""),
                        "chunk_idx": ctx.get("chunk_idx", ""),
                        "page": ctx.get("page", None),
                        "text": str(ctx.get("text", "")),
                        "score": _clamp01(ctx.get("score", 0.0)),
                    })
                evidence_seed = " ".join(str(e.get("text", "")) for e in evidence_items[:3])
                name = _pick_cluster_name(rc.get("cluster_name", ""), summary, evidence_seed, i)
                clusters.append({
                    "cluster_name": name,
                    "summary": summary,
                    "evidence": evidence_items,
                })
    fallback_conf = _average_similarity(contexts)
    if clusters:
        # Re-run name repair after all clusters exist, then dedupe suffixes.
        normalized = []
        for i, c in enumerate(clusters):
            summary = str(c.get("summary", ""))
            evidence_seed = " ".join(str(e.get("text", "")) for e in c.get("evidence", [])[:3])
            c["cluster_name"] = _pick_cluster_name(c.get("cluster_name", ""), summary, evidence_seed, i)
            normalized.append(c)
        clusters = _dedupe_cluster_names(normalized)
        if len(contexts) >= 4 and len(clusters) < 2:
            clusters = _local_semantic_clusters(contexts, fallback_conf=fallback_conf)
    else:
        # If model output is missing/invalid, force semantic multi-cluster fallback when evidence is enough.
        if len(contexts) >= 4:
            clusters = _local_semantic_clusters(contexts, fallback_conf=fallback_conf)
        else:
            clusters = _build_single_fallback_cluster(contexts)
    return {
        "clusters": clusters,
        "raw_output": output_text,
        "payload": payload,
    }
def _build_claim_prompt(
    question: str,
    answer: str,
    doc_mode: str,
    n_claims: int,
    report_names: List[str],
    clusters: List[Dict],
) -> str:
    """Assemble the claim-extraction prompt from the answer plus up to eight
    cluster summaries, requesting exactly n_claims claims as strict JSON."""
    report_text = ", ".join(report_names) if report_names else "N/A"
    cluster_lines = []
    for i, c in enumerate(clusters, start=1):
        cluster_lines.append(
            f"- Cluster {i} ({c.get('cluster_name', 'Unknown')}): {c.get('summary', '')}"
        )
    # Cap at 8 cluster summaries to bound prompt size.
    cluster_block = "\n".join(cluster_lines[:8])
    schema = '{"claims":[{"claim":"string"}]}'
    return (
        "__TASK__: CLAIM_GENERATION\n"
        "You are the Claim Extractor in a trustworthy RAG pipeline.\n"
        "Generate concise claims that represent the answer's key arguments.\n"
        "Rules:\n"
        "- Return exactly N_CLAIMS claims.\n"
        "- If multi-document: provide one claim per company/report, and the final claim must be overall conclusion.\n"
        "- If single-document: provide 3 atomic claims.\n"
        "- Keep each claim <= 28 words.\n"
        f"Return strict JSON only with schema: {schema}\n\n"
        f"DOC_MODE: {doc_mode}\n"
        f"N_CLAIMS: {n_claims}\n"
        f"REPORTS: {report_text}\n"
        f"Question: {question}\n\n"
        f"Answer:\n{answer}\n\n"
        f"Cluster summaries:\n{cluster_block}"
    )
def _contexts_from_clusters(clusters: List[Dict], fallback_contexts: List[Dict]) -> List[Dict]:
    """Flatten cluster evidence back into a deduplicated, score-sorted context list.

    Deduplicates on (report, chunk_idx) keeping the first text seen and the
    maximum score; returns fallback_contexts when no usable evidence exists.
    """
    merged: Dict[tuple, Dict] = {}
    for cluster in clusters:
        if not isinstance(cluster, dict):
            continue
        for ev in cluster.get("evidence", []):
            if not isinstance(ev, dict):
                continue
            text = str(ev.get("text", ""))
            if not text:
                continue
            report = str(ev.get("report", ""))
            chunk_idx = ev.get("chunk_idx", "")
            key = (report, str(chunk_idx))
            score = _clamp01(ev.get("score", 0.0))
            entry = merged.get(key)
            if entry is None:
                merged[key] = {
                    "report": report,
                    "chunk_idx": chunk_idx,
                    "text": text,
                    "score": score,
                }
            else:
                entry["score"] = max(entry["score"], score)
    if not merged:
        return fallback_contexts
    # dict preserves insertion order, so ties keep first-seen order (stable sort).
    return sorted(merged.values(), key=lambda c: c.get("score", 0.0), reverse=True)
def _infer_multi_company_count(question: str, report_names: List[str]) -> int:
fallback = max(1, len(report_names))
q = str(question or "").lower()
q = re.sub(r"\"[^\"]+\"", " ", q)
connectors = q.count(" or ") + q.count(" and ")
commas = q.count(",")
guessed = 0
if connectors > 0 or commas > 0:
guessed = max(connectors + 1, commas + 1)
if guessed < 2:
guessed = fallback
return max(1, min(8, guessed))
def _sanitize_claim_text(text: str) -> str:
t = str(text or "").strip()
if not t:
return ""
if len(t) > 500:
t = _first_sentence(t)
if (t.startswith("{") and t.endswith("}")) or (t.startswith("[") and t.endswith("]")):
payload = _extract_json_payload(t)
if isinstance(payload, dict):
if isinstance(payload.get("conclusion"), str) and payload.get("conclusion", "").strip():
return payload["conclusion"].strip()
if isinstance(payload.get("answer"), str) and payload.get("answer", "").strip():
return payload["answer"].strip()
return "Structured result supports an evidence-based claim."
if isinstance(payload, list):
return "Structured result supports an evidence-based claim."
return t
def _extract_conclusion_text_from_answer(answer: str) -> str:
    """Pull a human-readable conclusion sentence out of a model answer.

    Tries, in order: a JSON payload's "conclusion" field (string, or a
    {label, reason} dict), then its "answer" field; a trailing
    "Conclusion: ..." line in the raw text; finally the first sentence.
    """
    payload = _extract_json_payload(answer)
    if isinstance(payload, dict):
        conclusion = payload.get("conclusion")
        if isinstance(conclusion, str) and conclusion.strip():
            return conclusion.strip()
        if isinstance(conclusion, dict):
            label = str(conclusion.get("label", "")).strip()
            why = str(conclusion.get("reason", "")).strip()
            if label or why:
                return f"{label} {why}".strip()
        fallback_answer = payload.get("answer")
        if isinstance(fallback_answer, str) and fallback_answer.strip():
            return fallback_answer.strip()
    raw_text = str(answer or "").strip()
    match = re.search(r"(?:^|\n)\s*conclusion\s*:\s*(.+)$", raw_text, flags=re.IGNORECASE)
    if match:
        return match.group(1).strip()
    return _first_sentence(raw_text)
def _parse_claims(output_text: str, answer: str, n_claims: int, doc_mode: str, report_names: List[str]) -> List[str]:
    """Parse claim strings from raw model output, padding/truncating to target.

    Accepts either {"claims": [...]} or a bare JSON list in *output_text*.
    When parsing yields nothing, synthesizes conservative fallback claims
    from the answer. In multi-doc mode the returned list always ends with a
    single "Conclusion: ..." entry; single-doc mode returns at least 3 claims.
    """
    payload = _extract_json_payload(output_text)
    claims = []
    raw_claims = None
    if isinstance(payload, dict):
        raw_claims = payload.get("claims")
    elif isinstance(payload, list):
        raw_claims = payload
    if isinstance(raw_claims, list):
        for item in raw_claims:
            if isinstance(item, str):
                claim = _sanitize_claim_text(item)
            elif isinstance(item, dict):
                claim = _sanitize_claim_text(item.get("claim", ""))
            else:
                claim = ""
            # Deduplicate while preserving the model's ordering.
            if claim and claim not in claims:
                claims.append(claim)
    if not claims:
        # Model output unusable: synthesize conservative fallback claims.
        sentence = _first_sentence(answer) or "Evidence supports part of the question."
        if doc_mode == "multi":
            for rn in report_names:
                claims.append(f"{rn.replace('.pdf', '')}: disclosure evidence is mixed and must be interpreted conservatively.")
            claims.append(f"Conclusion: {_extract_conclusion_text_from_answer(answer) or sentence}")
        else:
            claims = [
                sentence,
                "Retrieved evidence provides partial quantitative support.",
                "Some sub-claims remain weakly supported or uncertain.",
            ]
    # Multi-doc answers need at least one body claim plus a conclusion;
    # single-doc answers expose at least 3 atomic claims.
    if doc_mode == "multi":
        target_n = max(2, n_claims)
    else:
        target_n = max(3, n_claims)
    if doc_mode == "multi":
        conclusion_text = _extract_conclusion_text_from_answer(answer)
        if not conclusion_text:
            conclusion_text = _first_sentence(answer) or "No clear conclusion."
        conclusion_claim = f"Conclusion: {conclusion_text}".strip()
        body_target = max(1, target_n - 1)
        # Drop any model-written conclusion rows; the canonical conclusion is
        # appended as the final claim below.
        body_claims = [c for c in claims if c and not str(c).strip().lower().startswith("conclusion:")]
        body_claims = body_claims[:body_target]
        while len(body_claims) < body_target:
            body_claims.append(f"Additional claim {len(body_claims) + 1}: evidence is limited.")
        return body_claims + [conclusion_claim]
    claims = claims[:target_n]
    while len(claims) < target_n:
        # Pad with placeholder claims so callers always get target_n rows.
        claims.append(f"Additional claim {len(claims) + 1}: evidence is limited.")
    return claims
def _build_faithfulness_prompt(claims: List[str], clusters: List[Dict]) -> str:
claim_lines = [f"C{i+1}: {c}" for i, c in enumerate(claims)]
cluster_lines = []
for i, c in enumerate(clusters, start=1):
ev_lines = []
for ev in c.get("evidence", [])[:4]:
ev_lines.append(
f" - {ev.get('evidence_id', '')} report={ev.get('report', '')} "
f"paragraph={ev.get('chunk_idx', '')}: {str(ev.get('text', ''))[:200]}"
)
cluster_lines.append(
f"Cluster {i} ({c.get('cluster_name', 'Unknown')}): {c.get('summary', '')}\n" + "\n".join(ev_lines)
)
schema = (
'{"claim_scores":[{"claim":"string","score":"number(0-1)","label":"supported|partial|unsupported","reason":"string"}],'
'"faithfulness":"number(0-1)"}'
)
return (
"__TASK__: FAITHFULNESS_VERIFIER\n"
"You are a Faithfulness Verifier.\n"
"Score each claim by whether it is supported by clustered evidence.\n"
"Rules: score in [0,1]; unsupported claims should receive low scores.\n"
f"Return strict JSON only with schema: {schema}\n\n"
"Claims:\n" + "\n".join(claim_lines) + "\n\n"
"Clustered evidence:\n" + "\n\n".join(cluster_lines)
)
def _token_set(text: str) -> set:
return set(re.findall(r"[a-z0-9]{3,}", str(text or "").lower()))
def _parse_faithfulness(output_text: str, claims: List[str], clusters: List[Dict]) -> Dict[str, Any]:
    """Parse faithfulness-verifier output into per-claim score rows.

    Expects {"claim_scores": [...], "faithfulness": x} in *output_text*.
    When the model output is unusable, falls back to a token-overlap
    heuristic between each claim and the clustered evidence text. Always
    returns exactly one score row per input claim.
    """
    payload = _extract_json_payload(output_text)
    scores = []
    if isinstance(payload, dict):
        rows = payload.get("claim_scores", [])
        if isinstance(rows, list):
            for i, row in enumerate(rows):
                if not isinstance(row, dict):
                    continue
                claim = str(row.get("claim", "")).strip()
                # Re-attach the original claim text if the model omitted it.
                if not claim and i < len(claims):
                    claim = claims[i]
                score = _clamp01(row.get("score", 0.0))
                label = str(row.get("label", "")).strip().lower()
                if not label:
                    # Derive a label from the score thresholds when missing.
                    label = "supported" if score >= 0.75 else ("partial" if score >= 0.45 else "unsupported")
                reason = str(row.get("reason", "")).strip()
                scores.append({"claim": claim, "score": score, "label": label, "reason": reason})
    if not scores:
        # Heuristic fallback: score each claim by lexical overlap with the
        # concatenated evidence text (ratio scaled by 1.4, clamped to [0, 1]).
        evidence_text = " ".join(
            str(ev.get("text", ""))
            for c in clusters
            for ev in c.get("evidence", [])
            if isinstance(ev, dict)
        )
        ev_tokens = _token_set(evidence_text)
        for claim in claims:
            c_tokens = _token_set(claim)
            if not c_tokens:
                overlap_ratio = 0.0
            else:
                overlap_ratio = len(c_tokens & ev_tokens) / max(1, len(c_tokens))
            score = _clamp01(overlap_ratio * 1.4)
            label = "supported" if score >= 0.75 else ("partial" if score >= 0.45 else "unsupported")
            scores.append({"claim": claim, "score": score, "label": label, "reason": ""})
    # Pad with neutral 0.35/"partial" rows, then trim, so that
    # len(scores) == len(claims) exactly.
    while len(scores) < len(claims):
        claim = claims[len(scores)]
        scores.append({"claim": claim, "score": 0.35, "label": "partial", "reason": ""})
    scores = scores[: len(claims)]
    model_faith = None
    if isinstance(payload, dict) and payload.get("faithfulness", None) is not None:
        model_faith = _clamp01(payload.get("faithfulness"))
    avg_faith = _mean([s["score"] for s in scores], default=0.0)
    # Prefer the model's own overall score; otherwise average the claim scores.
    faithfulness = model_faith if model_faith is not None else avg_faith
    return {
        "claim_scores": scores,
        "faithfulness": faithfulness,
        "raw_output": output_text,
        "payload": payload,
    }
def compute_calibrated_confidence(average_similarity: float, cluster_confidence: float, faithfulness: float) -> float:
    """Average the three trust signals (each clamped to [0, 1]) into one score."""
    signals = [
        _clamp01(average_similarity),
        _clamp01(cluster_confidence),
        _clamp01(faithfulness),
    ]
    return _mean(signals, default=0.0)
def _chunk_json_source_file(chunk_mode: str) -> str:
    """Map a chunking mode to its pre-chunked JSON source file path."""
    mode = str(chunk_mode or "").strip().lower()
    if mode == "structure":
        return OCR_CHUNKS_ALL_STRUCTURE_JSON
    # Any other mode (including "length") uses the length-chunked file.
    return OCR_CHUNKS_ALL_LENGTH_JSON
def _load_report_chunks_from_json(chunk_mode: str):
    """Load pre-chunked report text from the chunk mode's JSON file.

    Returns {report: [(chunk_idx, text, page), ...]} sorted by chunk_idx,
    or None when the file is missing, unreadable, not a list, or yields
    no usable rows.
    """
    src = _chunk_json_source_file(chunk_mode)
    if not os.path.isfile(src):
        return None
    try:
        with open(src, "r", encoding="utf-8") as f:
            rows = json.load(f)
    except Exception:
        # Corrupt/unreadable JSON is treated the same as a missing file.
        return None
    if not isinstance(rows, list):
        return None
    grouped: Dict[str, List[tuple]] = {}
    for row in rows:
        if not isinstance(row, dict):
            continue
        report = str(row.get("report", "")).strip()
        if not report:
            continue
        # Prefer "chunk_text" but accept legacy "text" rows.
        text = str(row.get("chunk_text", row.get("text", "")) or "")
        try:
            chunk_idx = int(row.get("chunk_idx"))
        except Exception:
            # Missing/non-numeric index: append at the report's current end.
            chunk_idx = len(grouped.get(report, []))
        page = row.get("page", row.get("page_start", None))
        try:
            page = int(page) if page not in (None, "", "NA") else None
        except Exception:
            page = None
        grouped.setdefault(report, []).append((chunk_idx, text, page))
    if not grouped:
        return None
    normalized: Dict[str, List[tuple]] = {}
    for report, items in grouped.items():
        dedup = {}
        for idx, text, page in items:
            if idx in dedup:
                # Keep non-empty replacement if duplicate index appears.
                if (not str(dedup[idx][0]).strip()) and str(text).strip():
                    dedup[idx] = (text, page)
                elif dedup[idx][1] is None and page is not None:
                    # Text already kept; backfill a missing page number.
                    dedup[idx] = (dedup[idx][0], page)
            else:
                dedup[idx] = (text, page)
        normalized[report] = sorted(
            [(int(idx), str(txt or ""), pg) for idx, (txt, pg) in dedup.items()],
            key=lambda x: x[0],
        )
    return normalized
def get_report_chunks(chunk_mode: str) -> Dict[str, list]:
    """Return {report: [(chunk_idx, text, page), ...]} for *chunk_mode*.

    Results are memoized per chunk_mode. Only the pre-chunked JSON source is
    supported; a missing/unreadable source raises FileNotFoundError (this
    build has no MinerU_Reports fallback).
    """
    cached = _report_chunks_cache.get(chunk_mode)
    if cached is not None:
        return cached
    # One global chunk pool serves both single-doc and multi-doc retrieval,
    # so the JSON source is shared by both modes.
    loaded = _load_report_chunks_from_json(chunk_mode)
    if isinstance(loaded, dict) and loaded:
        _report_chunks_cache[chunk_mode] = loaded
        return loaded
    src = _chunk_json_source_file(chunk_mode)
    raise FileNotFoundError(
        f"Required chunk JSON is missing or unreadable: {src}. "
        "This build does not fallback to MinerU_Reports."
    )
# ======================== Core Functions ========================
def retrieve(
    question: str,
    chunk_mode: str = "structure",
    doc_mode: str = "single",
    top_k: int = 5,
    embed_name: str = "Qwen3-Embedding-0.6B",
    api_key: str = "",
    openai_base_url: str = OPENAI_EMBED_BASE_URL,
) -> List[Dict]:
    """Retrieve the top_k most relevant chunks for *question*.

    Backend is chosen from *embed_name*:
    - "BM25": lexical scoring over the shared chunk pool;
    - OpenAI embedding models: API-encoded dense retrieval (needs a key);
    - otherwise: local embedding model dense retrieval.

    Returns a list of dicts with report, chunk_idx, text, page, score
    (cosine similarity for dense backends) and raw_score.
    """
    openai_base_url = str(openai_base_url or OPENAI_EMBED_BASE_URL).rstrip("/")
    use_openai_embed = _is_openai_embedding_model(embed_name)
    use_bm25 = str(embed_name or "").strip().upper() == "BM25"
    # Mock mode only replaces local embedding retrieval; BM25 and API
    # embedding paths always run for real.
    if MOCK_MODE and not use_openai_embed and not use_bm25:
        return _mock_retrieve(question, chunk_mode, doc_mode, top_k)
    if use_openai_embed and (not _resolve_api_key(api_key)):
        raise RuntimeError(
            "OpenAI embedding model selected but API key is missing. "
            "Please provide OPENAI_API_KEY or pass api_key."
        )
    # Imported lazily so the UI can start without numpy installed.
    import numpy as np
    report_chunks = get_report_chunks(chunk_mode)
    pool = _build_chunk_pool(report_chunks)
    if not pool:
        return []
    if use_bm25:
        # Lexical retrieval path: no embeddings involved; raw_score == score.
        bm25 = _get_bm25(chunk_mode=chunk_mode, pool=pool)
        scores = bm25.get_scores(_tokenize_for_bm25(question))
        top_indices = np.argsort(scores)[::-1][:top_k]
        results = []
        for i in top_indices:
            results.append(
                {
                    "report": pool[i][0],
                    "chunk_idx": pool[i][1],
                    "text": pool[i][2],
                    "page": (pool[i][3] if len(pool[i]) > 3 else None),
                    "score": float(scores[i]),
                    "raw_score": float(scores[i]),
                }
            )
        return results
    if use_openai_embed:
        # API embedding path: document embeddings cached per (model, mode).
        signature = _chunk_pool_signature(pool)
        cache_key = f"{embed_name}__{chunk_mode}"
        if cache_key not in _chunk_embeddings_cache:
            simple_cache_file = _get_simple_embed_cache_file(embed_name, chunk_mode)
            # Older builds keyed the cache file by base URL + pool signature;
            # try the simple cache first, then the legacy one.
            legacy_cache_file = _get_embed_cache_file(
                f"{embed_name}__api__{openai_base_url}",
                chunk_mode,
                signature,
            )
            cached = _load_cached_embeddings(simple_cache_file, expected_rows=len(pool))
            if cached is None:
                cached = _load_cached_embeddings(legacy_cache_file, expected_rows=len(pool))
            if cached is not None:
                _chunk_embeddings_cache[cache_key] = cached
            else:
                texts = [p[2] for p in pool]
                encoded = _encode_openai_embeddings(
                    texts=texts,
                    model_name=embed_name,
                    api_key=api_key,
                    base_url=openai_base_url,
                    batch_size=32,
                )
                _chunk_embeddings_cache[cache_key] = encoded
                _save_cached_embeddings(simple_cache_file, encoded)
        doc_emb = _chunk_embeddings_cache[cache_key]
        query_emb = _encode_openai_embeddings(
            texts=[question],
            model_name=embed_name,
            api_key=api_key,
            base_url=openai_base_url,
            batch_size=1,
        )
    else:
        # Local embedding path: same caching scheme with a local encoder.
        embed_model = _get_embed_model(embed_name)
        signature = _chunk_pool_signature(pool)
        cache_key = f"{embed_name}__{chunk_mode}"
        if cache_key not in _chunk_embeddings_cache:
            simple_cache_file = _get_simple_embed_cache_file(embed_name, chunk_mode)
            legacy_cache_file = _get_embed_cache_file(embed_name, chunk_mode, signature)
            cached = _load_cached_embeddings(simple_cache_file, expected_rows=len(pool))
            if cached is None:
                cached = _load_cached_embeddings(legacy_cache_file, expected_rows=len(pool))
            if cached is not None:
                _chunk_embeddings_cache[cache_key] = cached
            else:
                texts = [p[2] for p in pool]
                encoded = embed_model.encode_documents(texts, batch_size=32)
                _chunk_embeddings_cache[cache_key] = encoded
                _save_cached_embeddings(simple_cache_file, encoded)
        doc_emb = _chunk_embeddings_cache[cache_key]
        query_emb = embed_model.encode_queries([question], batch_size=1)
    # Sanity-check encoder output before doing linear algebra on it.
    if getattr(doc_emb, "ndim", 0) != 2 or getattr(query_emb, "ndim", 0) != 2:
        raise RuntimeError("Embedding encoding failed: expected 2D arrays for document/query embeddings.")
    if doc_emb.shape[1] != query_emb.shape[1]:
        raise RuntimeError(
            f"Embedding dimension mismatch: docs={doc_emb.shape[1]}, query={query_emb.shape[1]} "
            f"(model={embed_name})."
        )
    # Use true cosine similarity for dense retrieval:
    # sim(q, d) = (q·d) / (||q|| * ||d||)
    query_vec = query_emb[0]
    dot_scores = doc_emb @ query_vec
    doc_norms = np.linalg.norm(doc_emb, axis=1)
    q_norm = float(np.linalg.norm(query_vec))
    denom = doc_norms * max(q_norm, 1e-12)
    # Zero-norm documents get similarity 0 instead of NaN/inf.
    scores = np.divide(
        dot_scores,
        denom,
        out=np.zeros_like(dot_scores, dtype="float32"),
        where=denom > 1e-12,
    )
    top_indices = np.argsort(scores)[::-1][:top_k]
    results = []
    for i in top_indices:
        results.append({
            "report": pool[i][0],
            "chunk_idx": pool[i][1],
            "text": pool[i][2],
            "page": (pool[i][3] if len(pool[i]) > 3 else None),
            "score": float(scores[i]),
            "raw_score": float(dot_scores[i]),
        })
    return results
def generate(
    question: str,
    contexts: List[Dict],
    doc_mode: str = "single",
    gen_model: str = "Qwen3-4B-Instruct-2507-FP8",
    api_key: str = "",
) -> str:
    """Generate an answer to *question* from retrieved *contexts*.

    Routing: API models are delegated to _generate_with_api; without a GPU
    a user-facing warning string is returned; in MOCK_MODE a canned answer
    is returned; otherwise the local vLLM model is used.
    """
    if _is_api_generation_model(gen_model):
        return _generate_with_api(question, contexts, doc_mode, gen_model, api_key)
    if not HAS_GPU:
        # Local vLLM generation requires a GPU; return a warning instead of raising.
        return (
            "\u26a0\ufe0f No GPU detected on this host. Local generation models are disabled. "
            "Please use an API generation model."
        )
    if MOCK_MODE:
        return _mock_generate(question, contexts, doc_mode)
    # Imported lazily so the UI can start without the generation module's deps.
    from generation import (
        build_yes_no_prompt,
        build_multi_zero_shot_prompt,
        infer_multi_skill_name,
        get_multi_skill_spec,
    )
    llm, sampling = _get_llm()
    if doc_mode == "multi":
        # Multi-document mode: choose a skill-specific structured prompt.
        skill_name = infer_multi_skill_name(question)
        spec = get_multi_skill_spec(skill_name)
        ctx_texts = [
            f"(report={c['report']}, chunk_idx={c['chunk_idx']}) {c['text']}"
            for c in contexts
        ]
        prompt = build_multi_zero_shot_prompt(
            question=question,
            contexts=ctx_texts,
            skill_name=spec["skill_name"],
            skill_prompt=spec["skill_prompt"],
            output_json_schema=spec["output_json_schema"],
            retrieval_query=question,
        )
    else:
        ctx_texts = [c["text"] for c in contexts]
        prompt = build_yes_no_prompt(question, ctx_texts, doc_mode)
    outputs = llm.generate([prompt], sampling)
    return outputs[0].outputs[0].text.strip()
def run_rag(
    question: str,
    chunk_mode: str = "structure",
    doc_mode: str = "single",
    top_k: int = 5,
    embed_name: str = "Qwen3-Embedding-0.6B",
    gen_model: str = "Qwen3-4B-Instruct-2507-FP8",
    api_key: str = "",
) -> tuple:
    """Run the plain retrieve-then-generate pipeline.

    Returns:
        (answer, contexts) — the generated answer string and the retrieved
        context dicts used to produce it.
    """
    retrieved = retrieve(
        question=question,
        chunk_mode=chunk_mode,
        doc_mode=doc_mode,
        top_k=top_k,
        embed_name=embed_name,
        api_key=api_key,
        openai_base_url=OPENAI_EMBED_BASE_URL,
    )
    final_answer = generate(question, retrieved, doc_mode, gen_model, api_key)
    return final_answer, retrieved
def run_trustworthy_step1(
    question: str,
    doc_mode: str = "single",
    top_k: int = 10,
    embed_name: str = "Qwen3-Embedding-0.6B",
    gen_model: str = "Qwen3-4B-Instruct-2507-FP8",
    api_key: str = "",
) -> Dict[str, Any]:
    """Step 1 of the trustworthy pipeline: retrieve, then cluster evidence."""
    retrieved = retrieve(
        question=question,
        chunk_mode="structure",
        doc_mode=doc_mode,
        top_k=int(top_k),
        embed_name=embed_name,
        api_key=api_key,
        openai_base_url=OPENAI_EMBED_BASE_URL,
    )
    avg_sim = _average_similarity(retrieved)
    raw_clustering = _run_model_prompt(
        prompt=_build_cluster_prompt(question, retrieved),
        gen_model=gen_model,
        api_key=api_key,
        max_tokens=1200,
    )
    parsed = _parse_cluster_result(raw_clustering, retrieved)
    return {
        "contexts": retrieved,
        "average_similarity": avg_sim,
        "clusters": parsed["clusters"],
        "cluster_raw_output": parsed["raw_output"],
    }
def run_trustworthy_recluster(
    question: str,
    contexts: List[Dict],
    gen_model: str = "Qwen3-4B-Instruct-2507-FP8",
    api_key: str = "",
) -> Dict[str, Any]:
    """Re-run evidence clustering over an already-retrieved context list."""
    raw = _run_model_prompt(
        prompt=_build_cluster_prompt(question, contexts),
        gen_model=gen_model,
        api_key=api_key,
        max_tokens=1200,
    )
    parsed = _parse_cluster_result(raw, contexts)
    return {
        "clusters": parsed["clusters"],
        "cluster_raw_output": parsed["raw_output"],
    }
def run_trustworthy_step2(
    question: str,
    doc_mode: str,
    contexts: List[Dict],
    clusters: List[Dict],
    gen_model: str = "Qwen3-4B-Instruct-2507-FP8",
    api_key: str = "",
) -> Dict[str, Any]:
    """Step 2 of the trustworthy pipeline: answer using cluster-derived contexts."""
    selected = _contexts_from_clusters(clusters, contexts)
    produced = generate(
        question=question,
        contexts=selected,
        doc_mode=doc_mode,
        gen_model=gen_model,
        api_key=api_key,
    )
    return {
        "answer": produced,
        "qa_contexts": selected,
    }
def run_rag_trustworthy(
    question: str,
    doc_mode: str = "single",
    top_k: int = 10,
    embed_name: str = "Qwen3-Embedding-0.6B",
    gen_model: str = "Qwen3-4B-Instruct-2507-FP8",
    api_key: str = "",
) -> Dict[str, Any]:
    """Run trustworthy steps 1 and 2 end-to-end and bundle both results."""
    first = run_trustworthy_step1(
        question=question,
        doc_mode=doc_mode,
        top_k=top_k,
        embed_name=embed_name,
        gen_model=gen_model,
        api_key=api_key,
    )
    second = run_trustworthy_step2(
        question=question,
        doc_mode=doc_mode,
        contexts=first["contexts"],
        clusters=first["clusters"],
        gen_model=gen_model,
        api_key=api_key,
    )
    return {"step1": first, "step2": second}
def _build_claim_alignment_prompt(
question: str,
answer: str,
contexts: List[Dict],
doc_mode: str,
n_claims: int,
) -> str:
lines = []
for i, c in enumerate(contexts, start=1):
report = c.get("report", "unknown")
chunk_idx = c.get("chunk_idx", "NA")
text = str(c.get("text", "")).replace("\n", " ").strip()
lines.append(f"[E{i}] report={report}; paragraph={chunk_idx}; text={text[:650]}")
schema = (
'{"claims":[{"claim":"string","evidence_ids":["E1"]}]}'
)
return (
"__TASK__: CLAIM_EXTRACTION\n"
"You are the Claim Extractor in a trustworthy RAG pipeline.\n"
"Given answer and evidence list, extract concise claims and link each claim to supporting evidence ids.\n"
"Rules:\n"
"- Return exactly N_CLAIMS claims.\n"
"- Use only evidence ids from provided list.\n"
"- Keep each claim <= 28 words.\n"
f"Return strict JSON only with schema: {schema}\n\n"
f"DOC_MODE: {doc_mode}\n"
f"N_CLAIMS: {n_claims}\n"
f"Question: {question}\n\n"
f"Answer:\n{answer}\n\n"
"Evidence:\n"
+ "\n".join(lines)
)
def _parse_claim_alignment(
    output_text: str,
    contexts: List[Dict],
    n_claims: int,
    answer: str,
    doc_mode: str,
) -> Dict[str, Any]:
    """Parse CLAIM_EXTRACTION output into claim rows with linked evidence.

    Each returned row is {"claim": str, "evidence": [evidence dicts]}.
    Falls back to _parse_claims-seeded rows when the model output is
    unusable, and pads/trims so exactly n_claims rows are returned.
    """
    payload = _extract_json_payload(output_text)
    n_contexts = len(contexts)
    claim_rows = []
    if isinstance(payload, dict):
        rows = payload.get("claims", [])
        if isinstance(rows, list):
            for i, row in enumerate(rows):
                if isinstance(row, str):
                    claim = _sanitize_claim_text(row)
                    refs = []
                elif isinstance(row, dict):
                    claim = _sanitize_claim_text(row.get("claim", ""))
                    # Accept several key spellings for the evidence reference list.
                    refs = row.get("evidence_ids", row.get("evidence", row.get("sources", [])))
                else:
                    continue
                if not claim:
                    continue
                if not isinstance(refs, list):
                    # Scalar reference -> one-element list; None/"" -> no refs.
                    refs = [refs] if refs not in (None, "") else []
                evidence = []
                used = set()
                for ref in refs:
                    idx = _resolve_evidence_index(ref, n_contexts)
                    if idx is None or idx in used:
                        continue
                    used.add(idx)
                    ctx = contexts[idx]
                    evidence.append(
                        {
                            "evidence_id": f"E{idx + 1}",
                            "report": str(ctx.get("report", "")),
                            "chunk_idx": ctx.get("chunk_idx", ""),
                            "page": ctx.get("page", None),
                            "text": str(ctx.get("text", "")),
                            "score": _clamp01(ctx.get("score", 0.0)),
                        }
                    )
                # fallback to one evidence by order if model omitted refs
                if not evidence and i < n_contexts:
                    ctx = contexts[i]
                    evidence.append(
                        {
                            "evidence_id": f"E{i + 1}",
                            "report": str(ctx.get("report", "")),
                            "chunk_idx": ctx.get("chunk_idx", ""),
                            "page": ctx.get("page", None),
                            "text": str(ctx.get("text", "")),
                            "score": _clamp01(ctx.get("score", 0.0)),
                        }
                    )
                claim_rows.append({"claim": claim, "evidence": evidence})
    if not claim_rows:
        # Model output unusable: seed claims heuristically and pair each with
        # the same-positioned context when one exists.
        report_names = sorted({str(c.get("report", "")) for c in contexts if str(c.get("report", "")).strip()})
        seed_claims = _parse_claims(
            output_text="",
            answer=answer,
            n_claims=n_claims,
            doc_mode=doc_mode,
            report_names=report_names,
        )
        for i, claim in enumerate(seed_claims[:n_claims]):
            evidence = []
            if i < n_contexts:
                ctx = contexts[i]
                evidence.append(
                    {
                        "evidence_id": f"E{i + 1}",
                        "report": str(ctx.get("report", "")),
                        "chunk_idx": ctx.get("chunk_idx", ""),
                        "page": ctx.get("page", None),
                        "text": str(ctx.get("text", "")),
                        "score": _clamp01(ctx.get("score", 0.0)),
                    }
                )
            claim_rows.append({"claim": claim, "evidence": evidence})
    # Trim, then pad with placeholder rows, so exactly n_claims are returned.
    claim_rows = claim_rows[:n_claims]
    while len(claim_rows) < n_claims:
        claim_rows.append({"claim": f"Additional claim {len(claim_rows) + 1}: evidence is limited.", "evidence": []})
    return {
        "claims": [c["claim"] for c in claim_rows],
        "claim_links": claim_rows,
        "raw_output": output_text,
        "payload": payload,
    }
def run_trustworthy_step3_claims(
    question: str,
    answer: str,
    contexts: List[Dict],
    doc_mode: str = "single",
    gen_model: str = "Qwen3-4B-Instruct-2507-FP8",
    api_key: str = "",
) -> Dict[str, Any]:
    """Step 3 of the trustworthy pipeline: extract claims linked to evidence."""
    if doc_mode == "single":
        claim_target = 3
    else:
        # One claim per inferred company plus a final conclusion claim.
        names = sorted({str(c.get("report", "")) for c in contexts if str(c.get("report", "")).strip()})
        claim_target = max(2, _infer_multi_company_count(question, names) + 1)
    raw = _run_model_prompt(
        prompt=_build_claim_alignment_prompt(
            question=question,
            answer=answer,
            contexts=contexts,
            doc_mode=doc_mode,
            n_claims=claim_target,
        ),
        gen_model=gen_model,
        api_key=api_key,
        max_tokens=900,
    )
    parsed = _parse_claim_alignment(
        output_text=raw,
        contexts=contexts,
        n_claims=claim_target,
        answer=answer,
        doc_mode=doc_mode,
    )
    return {
        "claims": parsed["claims"],
        "claim_links": parsed["claim_links"],
        "claims_raw_output": parsed["raw_output"],
    }
# ======================== Report Library ========================
def list_reports() -> List[Dict]:
    """List PDF files in REPORTS_DIR as {"name", "path"} dicts, sorted by name."""
    if not os.path.isdir(REPORTS_DIR):
        return []
    return [
        {"name": fname, "path": os.path.join(REPORTS_DIR, fname)}
        for fname in sorted(os.listdir(REPORTS_DIR))
        if fname.lower().endswith(".pdf")
    ]
# ======================== API Generation ========================
def _generate_with_api(
    question: str,
    contexts: List[Dict],
    doc_mode: str,
    gen_model: str,
    api_key: str,
) -> str:
    """Generate an answer through a hosted (OpenAI- or Gemini-style) chat model.

    Builds the same prompts as local generation, then dispatches to the
    Gemini or OpenAI chat helper based on the resolved model name.
    """
    # Imported lazily so the UI can start without the generation module's deps.
    from generation import (
        build_yes_no_prompt,
        build_multi_zero_shot_prompt,
        infer_multi_skill_name,
        get_multi_skill_spec,
    )
    if doc_mode == "multi":
        # Multi-document mode uses a skill-specific structured JSON prompt.
        skill = infer_multi_skill_name(question)
        spec = get_multi_skill_spec(skill)
        ctx_texts = [
            f"(report={c['report']}, chunk_idx={c['chunk_idx']}) {c['text']}"
            for c in contexts
        ]
        user_msg = build_multi_zero_shot_prompt(
            question=question,
            contexts=ctx_texts,
            skill_name=spec["skill_name"],
            skill_prompt=spec["skill_prompt"],
            output_json_schema=spec["output_json_schema"],
            retrieval_query=question,
        )
    else:
        ctx_texts = [str(c.get("text", "")) for c in contexts]
        # Guard against unexpected doc_mode values; default to "single".
        mode = doc_mode if doc_mode in {"single", "cross"} else "single"
        user_msg = build_yes_no_prompt(question, ctx_texts, mode)
    system_msg = (
        "You are a strict climate disclosure analysis assistant. "
        "Follow the user task exactly and use only provided evidence."
    )
    chat_model = _resolve_api_generation_model_name(gen_model)
    if _is_gemini_generation_model(chat_model):
        return _call_gemini_chat_model(
            model_name=chat_model,
            system_msg=system_msg,
            user_msg=user_msg,
            api_key=api_key,
            max_tokens=2048,
        )
    return _call_openai_chat_model(
        model_name=chat_model,
        system_msg=system_msg,
        user_msg=user_msg,
        api_key=api_key,
        max_tokens=2048,
    )
# ======================== Mock Functions ========================
# Report file names cycled through by _mock_retrieve when running without
# real model files / GPU (MOCK_MODE).
_MOCK_REPORTS = [
    "2022 Microsoft Environmental Sustainability Report.pdf",
    "Apple_2023_EnvironmentalProgress_Report.pdf",
    "Shell sustainability report 2022.pdf",
    "Boeing 2023 Sustainability Report.pdf",
    "Deloitte 2023 Global Impact Report.pdf",
]
# Placeholder passages cycled through by _mock_retrieve; each mimics a typical
# sustainability-report excerpt (targets, TCFD, emissions, R&D, water).
_MOCK_PASSAGES = [
    "The company has established comprehensive environmental targets aligned with "
    "the Paris Agreement and Science Based Targets initiative (SBTi). Key commitments "
    "include achieving carbon neutrality by 2030 and net-zero emissions across the "
    "entire value chain by 2050. The sustainability roadmap outlines specific milestones "
    "for renewable energy adoption, waste reduction, and water stewardship programs.",
    "In alignment with TCFD recommendations, we have integrated climate risk assessment "
    "into our enterprise risk management framework. Physical risks including extreme "
    "weather events and transition risks such as regulatory changes are evaluated "
    "annually. Scenario analysis covering 1.5\u00b0C, 2\u00b0C, and 4\u00b0C pathways informs "
    "our long-term strategic planning.",
    "Our Scope 1 and 2 greenhouse gas emissions decreased by 23% compared to the "
    "2019 baseline year, driven by accelerated procurement of renewable electricity "
    "and fleet electrification. Scope 3 emissions from the supply chain remain our "
    "largest challenge, and we have engaged 85% of Tier-1 suppliers in setting "
    "their own science-based targets.",
    "The company invested $2.1 billion in clean technology R&D during the reporting "
    "period, representing a 35% increase year-over-year. Priority areas include "
    "carbon capture and storage, sustainable aviation fuels, and next-generation "
    "battery technology for energy storage solutions.",
    "Water stewardship initiatives achieved a 40% reduction in freshwater withdrawal "
    "intensity across manufacturing operations. Zero-liquid-discharge systems were "
    "deployed at three additional facilities, bringing the total to twelve sites "
    "operating under closed-loop water management protocols.",
]
def _mock_model_prompt(prompt: str) -> str:
text = str(prompt or "")
if "__TASK__: EVIDENCE_CLUSTERING" in text:
evidence_ids = re.findall(r"\[E(\d+)\]", text)
ids = [f"E{x}" for x in evidence_ids]
split = max(1, len(ids) // 2)
cluster_a = ids[:split] if ids else ["E1"]
cluster_b = ids[split:] if len(ids) > split else []
clusters = [
{
"cluster_name": "Governance",
"summary": "Oversight, risk process, and governance-related evidence.",
"evidence_ids": cluster_a,
}
]
if cluster_b:
clusters.append(
{
"cluster_name": "Metrics & Targets",
"summary": "Quantitative metrics, targets, and progress evidence.",
"evidence_ids": cluster_b,
}
)
return json.dumps(
{
"skill": "Evidence Clusterer",
"clusters": clusters,
},
ensure_ascii=False,
)
if "__TASK__: CLAIM_GENERATION" in text:
n = 3
m = re.search(r"N_CLAIMS:\s*(\d+)", text)
if m:
n = max(1, int(m.group(1)))
claims = [{"claim": f"Claim {i+1}: evidence-backed summary statement."} for i in range(n)]
return json.dumps({"claims": claims}, ensure_ascii=False)
if "__TASK__: CLAIM_EXTRACTION" in text:
n = 3
m = re.search(r"N_CLAIMS:\s*(\d+)", text)
if m:
n = max(1, int(m.group(1)))
evidence_ids = re.findall(r"\[E(\d+)\]", text)
evidences = [f"E{x}" for x in evidence_ids]
rows = []
for i in range(n):
refs = []
if evidences:
refs = [evidences[i % len(evidences)]]
rows.append(
{
"claim": f"Claim {i+1}: evidence-backed summary statement.",
"evidence_ids": refs,
}
)
return json.dumps({"claims": rows}, ensure_ascii=False)
if "__TASK__: FAITHFULNESS_VERIFIER" in text:
claim_lines = re.findall(r"^C\d+:\s*(.+)$", text, flags=re.MULTILINE)
if not claim_lines:
claim_lines = ["Claim 1"]
rows = []
for i, c in enumerate(claim_lines):
score = max(0.55, 0.9 - i * 0.08)
label = "supported" if score >= 0.75 else "partial"
rows.append(
{
"claim": c.strip(),
"score": round(score, 3),
"label": label,
"reason": "Mock verifier support signal.",
}
)
faith = _mean([r["score"] for r in rows], default=0.0)
return json.dumps({"claim_scores": rows, "faithfulness": round(faith, 3)}, ensure_ascii=False)
return "{}"
def _mock_retrieve(question: str, chunk_mode: str, doc_mode: str, top_k: int) -> List[Dict]:
return [
{
"report": _MOCK_REPORTS[i % len(_MOCK_REPORTS)],
"chunk_idx": i * 3 + 7,
"text": _MOCK_PASSAGES[i % len(_MOCK_PASSAGES)],
"score": round(0.9231 - i * 0.0547, 4),
}
for i in range(top_k)
]
def _mock_generate(question: str, contexts: List[Dict], doc_mode: str) -> str:
n = len(contexts)
if doc_mode == "multi":
return (
'{'
'"dimension":"Mock dimension",'
'"rows":[{"report":"Mock Report A","year":"2022","disclosure_status":"explicit",'
'"key_points":["Mock evidence A"],"evidence_chunks":["E1"]}],'
'"ranking":[{"rank":1,"report":"Mock Report A","rationale":"Mock rationale."}],'
'"conclusion":"[YES] Mock conclusion based on provided evidence."'
'}'
)
return (
f"[YES]. Based on the {n} retrieved passages, the company demonstrates "
f"alignment with external climate targets through specific commitments "
f"to net-zero emissions and science-based targets. The report references "
f"the Paris Agreement, SBTi, and TCFD frameworks as guiding principles "
f"for their environmental strategy."
)