Spaces:
Running
Running
| """ | |
| RAG Backend — global model loading + instant retrieve / generate functions. | |
| On server (GPU available): | |
| MOCK_MODE is auto-disabled; real models are loaded once at first call. | |
| On local dev (no GPU / no model files): | |
| MOCK_MODE is auto-enabled; placeholder results are returned so the UI works. | |
| Heavy dependencies (torch, sentence_transformers, vllm) are only imported | |
| when MOCK_MODE is False, so the Gradio UI can start locally without them. | |
| """ | |
| import os | |
| import re | |
| import sys | |
| import json | |
| import hashlib | |
| import pickle | |
| import multiprocessing as mp | |
| from pathlib import Path | |
| from typing import List, Dict, Any | |
# Make sibling modules importable regardless of the directory the app is launched from.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
if SCRIPT_DIR not in sys.path:
    sys.path.insert(0, SCRIPT_DIR)
# Force "spawn" for vLLM worker processes (set before any CUDA work;
# presumably to avoid fork-related CUDA issues — see vLLM docs).
os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
try:
    mp.set_start_method("spawn", force=False)
except Exception:
    # The start method may already be fixed by the host process; that's fine.
    pass
# ======================== Config ========================
# Local Qwen embedding checkpoints (paths relative to this script's parent dir).
EMBED_MODEL_PATHS = {
    "Qwen3-Embedding-0.6B": os.path.join(SCRIPT_DIR, "..", "model", "Qwen3-Embedding-0.6B"),
    "Qwen3-Embedding-4B": os.path.join(SCRIPT_DIR, "..", "model", "Qwen3-Embedding-4B"),
}
# Hosted embedding models reachable through the OpenAI embeddings API.
OPENAI_EMBED_MODELS = {
    "text-embedding-3-large",
    "text-embedding-3-small",
    "text-embedding-ada-002",
}
OPENAI_EMBED_BASE_URL = os.getenv("OPENAI_EMBED_BASE_URL", "https://api.openai.com/v1")
OPENAI_CHAT_MODEL = os.getenv("OPENAI_CHAT_MODEL", "gpt-5-mini")
# UI label -> API model id for hosted generation models.
API_GEN_MODEL_ALIASES = {
    "GPT-5 (API)": "gpt-5",
    "GPT-5-mini (API)": "gpt-5-mini",
    "GPT-4o (API)": "gpt-4o",
    "Gemini-2.5-Flash (API)": "gemini-2.5-flash",
    "Gemini-3-Flash-Preview (API)": "gemini-3-flash-preview",
}
# Local vLLM generation checkpoint.
GEN_MODEL_PATH = os.path.join(SCRIPT_DIR, "..", "model", "Qwen3-4B-Instruct-2507-FP8")
# Data directories: source reports, MinerU OCR output, and the embedding cache.
REPORTS_DIR = os.path.join(SCRIPT_DIR, "..", "Reports")
MINERU_ABS = os.path.join(SCRIPT_DIR, "..", "MinerU_Reports")
EMBED_CACHE_DIR = os.path.join(SCRIPT_DIR, "Embedding_Cache")
# Pre-chunked OCR text in two flavours: fixed-length vs structure-aware chunking.
OCR_CHUNKS_ALL_LENGTH_JSON = os.path.join(
    SCRIPT_DIR, "..", "OCR_Chunked_Annotated", "ocr_chunks_all.json"
)
OCR_CHUNKS_ALL_STRUCTURE_JSON = os.path.join(
    SCRIPT_DIR, "..", "OCR_Chunked_Annotated_structure", "ocr_chunks_all_structure.json"
)
# Chunking parameters (units presumably tokens — confirm against the chunker).
CHUNK_SIZE = 350
CHUNK_OVERLAP = 50
STRUCTURE_MAX_TOKENS = 550
| # ======================== Auto-detect mock mode ======================== | |
| def _detect_gpu() -> bool: | |
| try: | |
| import torch | |
| return bool(torch.cuda.is_available()) | |
| except Exception: | |
| return False | |
| HAS_GPU = _detect_gpu() | |
def _check_mock_mode() -> bool:
    """Mock mode is on when no GPU is present or the default embed model dir is missing."""
    if not HAS_GPU:
        return True
    first_model_dir = next(iter(EMBED_MODEL_PATHS.values()))
    return not os.path.isdir(first_model_dir)
MOCK_MODE = _check_mock_mode()
# ======================== Global State ========================
# Process-wide singletons, all populated lazily on first use.
_embed_models: Dict[str, object] = {}  # embed model name -> loaded model instance
_llm_model = None  # vLLM LLM instance, loaded once by _get_llm()
_sampling_params = None  # default SamplingParams paired with _llm_model
_report_chunks_cache: Dict[str, Dict[str, list]] = {}  # presumably chunk_mode -> {report -> chunks} — confirm against callers
_chunk_embeddings_cache: Dict[str, object] = {}  # cache key -> embedding matrix (in-memory layer over the .npy cache)
_bm25_cache: Dict[str, object] = {}  # "BM25__<chunk_mode>" -> BM25Okapi index
def _lazy_import_heavy():
    """Import the heavyweight embedding/search module only when actually needed."""
    from importlib import import_module
    return import_module("create_embedding_search_results_qwen")
def _get_embed_model(embed_name: str):
    """Return (and cache) the local Qwen embedding model for `embed_name`.

    Returns None in mock mode. Raises FileNotFoundError when the model
    directory is missing and RuntimeError when weight files look corrupted.
    """
    if MOCK_MODE:
        return None
    cached = _embed_models.get(embed_name)
    if cached is not None:
        return cached
    path = EMBED_MODEL_PATHS.get(embed_name)
    if path is None or not os.path.isdir(path):
        raise FileNotFoundError(f"Embedding model not found: {path}")
    # Zero-byte .safetensors files indicate an interrupted download/copy.
    weight_files = list(Path(path).rglob("*.safetensors"))
    empty_weights = [str(f) for f in weight_files if f.stat().st_size == 0]
    if empty_weights:
        raise RuntimeError(
            "Detected empty .safetensors file(s) in embedding model directory. "
            f"Please re-download model files. Examples: {empty_weights[:3]}"
        )
    mod = _lazy_import_heavy()
    try:
        model = mod.Qwen3EmbeddingModel(path, device="cuda")
    except Exception as e:
        msg = str(e)
        # Translate the safetensors deserialization failure into an actionable message.
        if "Error while deserializing header" in msg or "invalid JSON in header" in msg:
            raise RuntimeError(
                "Failed to load local Qwen embedding model. "
                "This usually means corrupted/incomplete .safetensors files or interrupted copy. "
                f"Model path: {path}"
            ) from e
        raise
    _embed_models[embed_name] = model
    return model
def _get_llm():
    """Load the local vLLM generator once; return (llm, sampling_params).

    Returns (None, None) in mock mode.
    """
    global _llm_model, _sampling_params
    if MOCK_MODE:
        return None, None
    if _llm_model is not None:
        return _llm_model, _sampling_params
    from vllm import LLM, SamplingParams
    _llm_model = LLM(
        model=GEN_MODEL_PATH,
        max_model_len=16384,
        dtype="auto",
        trust_remote_code=True,
        gpu_memory_utilization=0.6,
    )
    # Deterministic decoding by default.
    _sampling_params = SamplingParams(temperature=0.0, max_tokens=2048, top_p=1.0)
    return _llm_model, _sampling_params
| def _sanitize_name(text: str) -> str: | |
| return "".join(ch if ch.isalnum() or ch in ("-", "_", ".") else "_" for ch in str(text)) | |
| def _build_chunk_pool(report_chunks: Dict[str, list]) -> List[tuple]: | |
| pool = [] | |
| for report in sorted(report_chunks.keys()): | |
| chunks = report_chunks.get(report, []) | |
| for idx, item in enumerate(chunks): | |
| chunk_idx = idx | |
| text = item | |
| page = None | |
| if isinstance(item, (tuple, list)) and len(item) >= 2: | |
| try: | |
| chunk_idx = int(item[0]) | |
| except Exception: | |
| chunk_idx = idx | |
| text = item[1] | |
| if len(item) >= 3: | |
| page = item[2] | |
| elif isinstance(item, dict): | |
| try: | |
| chunk_idx = int(item.get("chunk_idx", idx)) | |
| except Exception: | |
| chunk_idx = idx | |
| text = item.get("chunk_text", item.get("text", "")) | |
| page = item.get("page", item.get("page_start", None)) | |
| try: | |
| page = int(page) if page not in (None, "", "NA") else None | |
| except Exception: | |
| page = None | |
| pool.append((report, chunk_idx, str(text or ""), page)) | |
| return pool | |
| def _chunk_pool_signature(pool: List[tuple]) -> str: | |
| h = hashlib.sha256() | |
| for entry in pool: | |
| report = entry[0] | |
| idx = entry[1] | |
| text = entry[2] | |
| h.update(str(report).encode("utf-8", errors="ignore")) | |
| h.update(b"\x1f") | |
| h.update(str(idx).encode("utf-8", errors="ignore")) | |
| h.update(b"\x1f") | |
| h.update(str(text).encode("utf-8", errors="ignore")) | |
| h.update(b"\x1e") | |
| return h.hexdigest()[:16] | |
def _get_embed_cache_file(embed_name: str, chunk_mode: str, signature: str) -> str:
    """Embedding cache path keyed by (model, chunk_mode, chunk-pool signature)."""
    os.makedirs(EMBED_CACHE_DIR, exist_ok=True)
    stem = f"{_sanitize_name(embed_name)}__{_sanitize_name(chunk_mode)}__{signature}"
    return os.path.join(EMBED_CACHE_DIR, stem + ".npy")
def _get_simple_embed_cache_file(embed_name: str, chunk_mode: str) -> str:
    """Embedding cache path keyed only by (model, chunk_mode) — no pool signature."""
    os.makedirs(EMBED_CACHE_DIR, exist_ok=True)
    stem = f"{_sanitize_name(embed_name)}__{_sanitize_name(chunk_mode)}"
    return os.path.join(EMBED_CACHE_DIR, stem + ".npy")
def _get_simple_bm25_cache_file(chunk_mode: str) -> str:
    """Pickle cache path for the BM25 index of a given chunk mode."""
    os.makedirs(EMBED_CACHE_DIR, exist_ok=True)
    stem = f"{_sanitize_name('BM25')}__{_sanitize_name(chunk_mode)}"
    return os.path.join(EMBED_CACHE_DIR, stem + ".pkl")
| def _tokenize_for_bm25(text: str) -> List[str]: | |
| return re.findall(r"[a-z0-9]+", str(text or "").lower()) | |
| def _load_bm25_from_cache(cache_file: str, pool: List[tuple]): | |
| try: | |
| if not os.path.isfile(cache_file): | |
| return None | |
| with open(cache_file, "rb") as f: | |
| payload = pickle.load(f) | |
| bm25 = payload.get("bm25", None) if isinstance(payload, dict) else None | |
| if bm25 is None: | |
| return None | |
| cached_n = int(payload.get("n_docs", -1)) if isinstance(payload, dict) else -1 | |
| if cached_n != len(pool): | |
| return None | |
| cached_reports = payload.get("report", []) if isinstance(payload, dict) else [] | |
| cached_idx = payload.get("chunk_idx", []) if isinstance(payload, dict) else [] | |
| if len(cached_reports) == len(pool) and len(cached_idx) == len(pool): | |
| for i in range(min(50, len(pool))): | |
| if str(cached_reports[i]) != str(pool[i][0]) or int(cached_idx[i]) != int(pool[i][1]): | |
| return None | |
| return bm25 | |
| except Exception: | |
| return None | |
def _build_bm25(pool: List[tuple]):
    """Build a BM25Okapi index over the chunk texts in `pool`.

    Raises RuntimeError when the optional rank_bm25 dependency is missing.
    """
    try:
        from rank_bm25 import BM25Okapi
    except Exception as e:
        raise RuntimeError("BM25 requires rank_bm25. Install with: pip install rank_bm25") from e
    corpus = [_tokenize_for_bm25(entry[2]) for entry in pool]
    return BM25Okapi(corpus)
| def _save_bm25_cache(cache_file: str, bm25, pool: List[tuple]) -> None: | |
| try: | |
| payload = { | |
| "bm25": bm25, | |
| "report": [p[0] for p in pool], | |
| "chunk_idx": [int(p[1]) for p in pool], | |
| "n_docs": len(pool), | |
| } | |
| with open(cache_file, "wb") as f: | |
| pickle.dump(payload, f) | |
| except Exception: | |
| pass | |
def _get_bm25(chunk_mode: str, pool: List[tuple]):
    """Return the BM25 index for `chunk_mode`, consulting memory then disk caches."""
    cache_key = f"BM25__{chunk_mode}"
    cached = _bm25_cache.get(cache_key)
    if cached is not None:
        return cached
    disk_path = _get_simple_bm25_cache_file(chunk_mode)
    index = _load_bm25_from_cache(disk_path, pool)
    if index is None:
        index = _build_bm25(pool)
        _save_bm25_cache(disk_path, index, pool)
    _bm25_cache[cache_key] = index
    return index
| def _load_cached_embeddings(cache_file: str, expected_rows: int): | |
| try: | |
| if not os.path.isfile(cache_file): | |
| return None | |
| import numpy as np | |
| arr = np.load(cache_file) | |
| if arr.ndim != 2 or arr.shape[0] != expected_rows: | |
| return None | |
| return arr | |
| except Exception: | |
| return None | |
| def _save_cached_embeddings(cache_file: str, embeddings) -> None: | |
| try: | |
| import numpy as np | |
| np.save(cache_file, embeddings) | |
| except Exception: | |
| # Best-effort caching; retrieval should still work without disk persistence. | |
| pass | |
| def _resolve_api_key(api_key: str) -> str: | |
| key = str(api_key or "").strip() | |
| if key: | |
| return key | |
| env_key = str(os.getenv("OPENAI_API_KEY", "")).strip() | |
| if env_key: | |
| return env_key | |
| return str(os.getenv("OPENAI_API_KEY_88996", "")).strip() | |
| def _resolve_gemini_api_key(api_key: str) -> str: | |
| env_key = str(os.getenv("GEMINI_API_KEY", "")).strip() | |
| if env_key: | |
| return env_key | |
| return str(api_key or "").strip() | |
def _is_openai_embedding_model(embed_name: str) -> bool:
    """True when `embed_name` names one of the hosted OpenAI embedding models."""
    name = str(embed_name or "").strip()
    return name in OPENAI_EMBED_MODELS
def _is_api_generation_model(gen_model: str) -> bool:
    """True when the model name refers to a hosted API model (alias, GPT, or Gemini)."""
    label = str(gen_model or "").strip()
    if label in API_GEN_MODEL_ALIASES:
        return True
    upper = label.upper()
    return any(tag in upper for tag in ("GPT", "GEMINI", "(API)"))
def _resolve_api_generation_model_name(gen_model: str) -> str:
    """Translate a UI alias to its API model id; fall back to OPENAI_CHAT_MODEL."""
    label = str(gen_model or "").strip()
    return API_GEN_MODEL_ALIASES.get(label, label or OPENAI_CHAT_MODEL)
def _is_gemini_generation_model(gen_model: str) -> bool:
    """True when the resolved API model id looks like a Gemini model."""
    resolved = str(_resolve_api_generation_model_name(gen_model) or "")
    return "gemini" in resolved.strip().lower()
| def _extract_gemini_text(response: Any) -> str: | |
| text = getattr(response, "text", None) | |
| if isinstance(text, str) and text.strip(): | |
| return text.strip() | |
| candidates = getattr(response, "candidates", None) | |
| if isinstance(candidates, list): | |
| out = [] | |
| for c in candidates: | |
| content = getattr(c, "content", None) | |
| parts = getattr(content, "parts", None) if content is not None else None | |
| if not isinstance(parts, list): | |
| continue | |
| for p in parts: | |
| t = getattr(p, "text", None) | |
| if isinstance(t, str) and t.strip(): | |
| out.append(t.strip()) | |
| if out: | |
| return "\n".join(out).strip() | |
| return "" | |
def _call_openai_chat_model(
    model_name: str,
    system_msg: str,
    user_msg: str,
    api_key: str,
    max_tokens: int,
) -> str:
    """Run one chat completion against an OpenAI-compatible endpoint.

    Returns the model text, or a "\u26a0\ufe0f ..." string describing the failure
    (missing key, missing SDK, or API error) instead of raising.
    """
    resolved_key = _resolve_api_key(api_key)
    if not resolved_key:
        return "\u26a0\ufe0f Please provide an OpenAI API key to use OpenAI API generation models."
    try:
        import openai
    except ImportError:
        return "\u26a0\ufe0f The 'openai' package is not installed. Run: pip install openai"
    try:
        # Allow redirecting to a compatible gateway via OPENAI_BASE_URL.
        base_url = str(os.getenv("OPENAI_BASE_URL", "")).strip()
        client_kwargs = {"api_key": resolved_key}
        if base_url:
            client_kwargs["base_url"] = base_url.rstrip("/")
        client = openai.OpenAI(**client_kwargs)
        response = client.chat.completions.create(
            model=str(model_name).strip(),
            messages=[
                {"role": "system", "content": str(system_msg or "")},
                {"role": "user", "content": str(user_msg or "")},
            ],
            temperature=0,
            max_tokens=int(max_tokens),
            timeout=120.0,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        return f"\u26a0\ufe0f API Error: {e}"
def _call_gemini_chat_model(
    model_name: str,
    system_msg: str,
    user_msg: str,
    api_key: str,
    max_tokens: int,
) -> str:
    """Run one generation call through google-genai.

    Returns the model text, or a "\u26a0\ufe0f ..." string describing the failure
    (missing key, missing SDK, empty content, or API error) instead of raising.
    """
    resolved_key = _resolve_gemini_api_key(api_key)
    if not resolved_key:
        return "\u26a0\ufe0f Please provide a Gemini API key (GEMINI_API_KEY) to use Gemini API generation models."
    try:
        from google import genai
    except ImportError:
        return "\u26a0\ufe0f The 'google-genai' package is not installed. Run: pip install google-genai"
    try:
        client = genai.Client(api_key=resolved_key)
        # Fold the system instruction and user task into a single prompt string.
        full_prompt = (
            f"System instruction:\n{str(system_msg or '').strip()}\n\nUser task:\n{str(user_msg or '').strip()}"
        ).strip()
        request = {"model": str(model_name).strip(), "contents": full_prompt}
        if int(max_tokens) > 0:
            request["config"] = {"max_output_tokens": int(max_tokens)}
        response = client.models.generate_content(**request)
        text = _extract_gemini_text(response)
        if text:
            return text
        return "\u26a0\ufe0f API Error: Gemini returned empty content."
    except Exception as e:
        return f"\u26a0\ufe0f API Error: {e}"
def get_openai_client(api_key: str, base_url: str = OPENAI_EMBED_BASE_URL):
    """Build an OpenAI SDK client for embedding/generation calls.

    Raises RuntimeError when no key resolves or the SDK is not installed.
    """
    resolved = _resolve_api_key(api_key)
    if not resolved:
        raise RuntimeError("Missing API key.")
    try:
        from openai import OpenAI
    except Exception as e:
        raise RuntimeError(f"openai package is required for API embedding/generation: {e}") from e
    if base_url:
        return OpenAI(api_key=resolved, base_url=str(base_url).rstrip("/"))
    return OpenAI(api_key=resolved)
def get_embedding(text: str, model_name: str, client):
    """Embed a single `text` via the given OpenAI-compatible client.

    Returns the first embedding vector; raises RuntimeError on an empty response.
    """
    response = client.embeddings.create(model=model_name, input=text)
    vectors = _extract_embedding_vectors(response)
    if not vectors:
        raise RuntimeError("Embedding API returned no vectors.")
    return vectors[0]
| def _extract_embedding_vectors(response: Any) -> List[List[float]]: | |
| """Handle SDK/object/dict/string responses from OpenAI-compatible gateways.""" | |
| payload: Any = response | |
| if isinstance(payload, str): | |
| try: | |
| payload = json.loads(payload) | |
| except Exception: | |
| raise RuntimeError( | |
| f"Unexpected embedding response string (not JSON): {payload[:240]}" | |
| ) | |
| if hasattr(payload, "model_dump"): | |
| try: | |
| payload = payload.model_dump() | |
| except Exception: | |
| pass | |
| data = None | |
| if hasattr(payload, "data"): | |
| data = payload.data | |
| elif isinstance(payload, dict): | |
| data = payload.get("data", None) | |
| if data is None: | |
| if isinstance(payload, list): | |
| data = payload | |
| else: | |
| raise RuntimeError( | |
| "Unexpected embedding response format: cannot find `data` field." | |
| ) | |
| vectors: List[List[float]] = [] | |
| for item in data: | |
| vec = None | |
| if hasattr(item, "embedding"): | |
| vec = item.embedding | |
| elif isinstance(item, dict): | |
| vec = item.get("embedding", None) | |
| elif isinstance(item, list): | |
| vec = item | |
| if vec is None: | |
| continue | |
| if hasattr(vec, "tolist"): | |
| vec = vec.tolist() | |
| vectors.append([float(x) for x in vec]) | |
| return vectors | |
def _encode_openai_embeddings(
    texts: List[str],
    model_name: str,
    api_key: str,
    base_url: str = OPENAI_EMBED_BASE_URL,
    batch_size: int = 32,
    show_progress: bool = False,
    progress_desc: str = "Embedding",
):
    """Embed `texts` via an OpenAI-compatible endpoint, batching requests.

    Returns a float32 numpy array of shape (len(texts), dim); shape (0, 1)
    when `texts` is empty. Raises RuntimeError if the API yields no vectors.
    """
    if not texts:
        import numpy as np
        return np.zeros((0, 1), dtype="float32")
    client = get_openai_client(api_key=api_key, base_url=base_url)
    step = max(1, int(batch_size))
    starts = range(0, len(texts), step)
    if show_progress:
        try:
            from tqdm.auto import tqdm
            n_batches = (len(texts) + step - 1) // step
            starts = tqdm(starts, total=n_batches, desc=progress_desc, unit="batch")
        except Exception:
            pass  # tqdm is optional; fall back to a silent loop
    vectors: List[List[float]] = []
    for start in starts:
        # Reuse `step` (previously recomputed as max(1, int(batch_size)) per batch).
        batch = [str(t or "") for t in texts[start : start + step]]
        resp = client.embeddings.create(model=model_name, input=batch)
        vectors.extend(_extract_embedding_vectors(resp))
    if not vectors:
        raise RuntimeError("Embedding API returned empty vector list.")
    import numpy as np
    return np.asarray(vectors, dtype="float32")
| def _safe_float(value: Any, default: float = 0.0) -> float: | |
| try: | |
| return float(value) | |
| except Exception: | |
| return default | |
def _clamp01(value: Any) -> float:
    """Clamp `value` (coerced via _safe_float) into the closed interval [0, 1]."""
    coerced = _safe_float(value, default=0.0)
    return max(0.0, min(1.0, coerced))
| def _mean(values: List[float], default: float = 0.0) -> float: | |
| if not values: | |
| return default | |
| return float(sum(values) / len(values)) | |
| def _normalize_scores_unit_interval(values: List[float]) -> List[float]: | |
| if not values: | |
| return [] | |
| try: | |
| import numpy as np | |
| arr = np.asarray(values, dtype="float64") | |
| finite = np.isfinite(arr) | |
| if not finite.any(): | |
| return [0.0 for _ in values] | |
| safe = arr.copy() | |
| safe[~finite] = np.nan | |
| lo = np.nanmin(safe) | |
| hi = np.nanmax(safe) | |
| if not np.isfinite(lo) or not np.isfinite(hi): | |
| return [0.0 for _ in values] | |
| if abs(hi - lo) <= 1e-12: | |
| base = 1.0 if hi > 0 else 0.0 | |
| return [base for _ in values] | |
| out = (safe - lo) / (hi - lo) | |
| out = np.clip(out, 0.0, 1.0) | |
| out[~finite] = 0.0 | |
| return [float(x) for x in out.tolist()] | |
| except Exception: | |
| return [0.0 for _ in values] | |
def _average_similarity(contexts: List[Dict]) -> float:
    """Mean clamped `score` across dict contexts; 0.0 when there are none."""
    clamped = [_clamp01(ctx.get("score", 0.0)) for ctx in contexts if isinstance(ctx, dict)]
    return _mean(clamped, default=0.0)
| def _extract_json_payload(text: str): | |
| raw = str(text or "").strip() | |
| if not raw: | |
| return None | |
| try: | |
| return json.loads(raw) | |
| except Exception: | |
| pass | |
| for block in re.findall(r"```(?:json)?\s*([\s\S]*?)```", raw, flags=re.IGNORECASE): | |
| block = block.strip() | |
| if not block: | |
| continue | |
| try: | |
| return json.loads(block) | |
| except Exception: | |
| continue | |
| marker_match = re.search(r"final\s+answer\s+in\s+json\s*:\s*", raw, flags=re.IGNORECASE) | |
| if marker_match: | |
| tail = raw[marker_match.end():].strip() | |
| decoder = json.JSONDecoder() | |
| try: | |
| obj, _ = decoder.raw_decode(tail) | |
| return obj | |
| except Exception: | |
| pass | |
| decoder = json.JSONDecoder() | |
| candidates = [] | |
| for i, ch in enumerate(raw): | |
| if ch not in "{[": | |
| continue | |
| try: | |
| obj, end = decoder.raw_decode(raw[i:]) | |
| candidates.append((obj, end)) | |
| except Exception: | |
| continue | |
| if not candidates: | |
| return None | |
| candidates.sort(key=lambda x: x[1], reverse=True) | |
| return candidates[0][0] | |
| def _first_sentence(text: str) -> str: | |
| raw = str(text or "").strip().replace("\n", " ") | |
| if not raw: | |
| return "" | |
| m = re.split(r"(?<=[.!?])\s+", raw) | |
| if m: | |
| return m[0].strip() | |
| return raw[:220].strip() | |
def extract_yes_no_label(text: str) -> str:
    """Extract a YES/NO verdict from model output; 'UNKNOWN' when neither is found.

    Prefers a JSON `conclusion`/`answer` field, then a bracketed [YES]/[NO] tag,
    then a leading YES/NO token, and finally whichever word occurs first.
    """
    payload = _extract_json_payload(text)
    if isinstance(payload, dict):
        conclusion = payload.get("conclusion")
        if isinstance(conclusion, str):
            text = conclusion
        elif isinstance(payload.get("answer"), str):
            text = payload.get("answer")
    cleaned = str(text or "").strip()
    bracketed = re.search(r"\[(YES|NO)\]", cleaned, flags=re.IGNORECASE)
    if bracketed:
        return bracketed.group(1).upper()
    leading = re.match(r"\s*(YES|NO)\b", cleaned, flags=re.IGNORECASE)
    if leading:
        return leading.group(1).upper()
    upper = cleaned.upper()
    yes_at = upper.find("YES")
    no_at = upper.find("NO")
    if yes_at >= 0 and (no_at < 0 or yes_at < no_at):
        return "YES"
    if no_at >= 0 and (yes_at < 0 or no_at < yes_at):
        return "NO"
    return "UNKNOWN"
def _run_model_prompt(
    prompt: str,
    gen_model: str,
    api_key: str = "",
    max_tokens: int = 1024,
) -> str:
    """Run `prompt` through the selected backend: hosted API, mock, or local vLLM."""
    if _is_api_generation_model(gen_model):
        chat_model = _resolve_api_generation_model_name(gen_model)
        system_msg = (
            "You are a strict JSON-oriented climate disclosure analysis assistant. "
            "Follow the user task exactly."
        )
        # Both API callers share a signature, so dispatch through one variable.
        if _is_gemini_generation_model(chat_model):
            caller = _call_gemini_chat_model
        else:
            caller = _call_openai_chat_model
        return caller(
            model_name=chat_model,
            system_msg=system_msg,
            user_msg=prompt,
            api_key=api_key,
            max_tokens=int(max_tokens),
        )
    if MOCK_MODE:
        return _mock_model_prompt(prompt)
    # Local vLLM path: deterministic decoding with a per-call token budget.
    llm, _ = _get_llm()
    from vllm import SamplingParams
    sampling = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=int(max_tokens))
    return llm.generate([prompt], sampling)[0].outputs[0].text.strip()
| def _resolve_evidence_index(ref: Any, n_contexts: int): | |
| if n_contexts <= 0: | |
| return None | |
| idx = None | |
| if isinstance(ref, int): | |
| idx = ref | |
| elif isinstance(ref, str): | |
| m = re.search(r"(\d+)", ref) | |
| if m: | |
| idx = int(m.group(1)) | |
| elif isinstance(ref, dict): | |
| for key in ("id", "evidence_id", "index", "evidence_index"): | |
| if key in ref: | |
| return _resolve_evidence_index(ref.get(key), n_contexts) | |
| if idx is None: | |
| return None | |
| if 1 <= idx <= n_contexts: | |
| return idx - 1 | |
| if 0 <= idx < n_contexts: | |
| return idx | |
| return None | |
def _build_cluster_prompt(question: str, contexts: List[Dict]) -> str:
    """Build the strict-JSON evidence-clustering prompt for the generator model."""
    evidence_lines = []
    for rank, ctx in enumerate(contexts, start=1):
        snippet = str(ctx.get("text", "")).replace("\n", " ").strip()
        evidence_lines.append(
            f"[E{rank}] report={ctx.get('report', 'unknown')}; "
            f"paragraph={ctx.get('chunk_idx', 'NA')}; "
            f"similarity={_clamp01(ctx.get('score', 0.0)):.4f}; text={snippet[:650]}"
        )
    schema = (
        '{"skill":"Evidence Clusterer","clusters":[{"cluster_name":"string","summary":"string",'
        '"evidence_ids":["E1"]}]}'
    )
    header = (
        "__TASK__: EVIDENCE_CLUSTERING\n"
        "You are an Evidence Clusterer for trustworthy RAG.\n"
        "Task: cluster retrieved evidence by semantic theme and summarize each cluster.\n"
        "Rules: use only provided evidence; do not fabricate ids.\n"
        "Use clear semantic cluster names such as Governance & Oversight, Risk & Scenario Analysis, "
        "Emissions Metrics, Strategy & Transition Plan, Targets & Progress, Value Chain Engagement.\n"
        "Do not use generic names like Theme 1 / Cluster 2 / Semantic Cluster.\n"
        "Return between 2 and 5 clusters whenever evidence size allows (n>=4).\n"
        "If evidence count >= 4, do NOT return only one cluster.\n"
        f"Return strict JSON only with schema: {schema}\n\n"
        f"Question: {question}\n\n"
        "Evidence:\n"
    )
    return header + "\n".join(evidence_lines)
| def _choose_cluster_k(n_contexts: int) -> int: | |
| if n_contexts <= 3: | |
| return max(1, n_contexts) | |
| if n_contexts <= 6: | |
| return 2 | |
| if n_contexts <= 12: | |
| return 3 | |
| if n_contexts <= 18: | |
| return 4 | |
| return 5 | |
# Ordered (keywords, label) pairs used by _semantic_cluster_name: the first
# entry whose keyword appears in the (lowercased) seed text wins.
_SEMANTIC_CLUSTER_LABELS = [
    (("governance", "board", "oversight", "committee", "accountability"), "Governance & Oversight"),
    (("risk", "scenario", "ngfs", "ipcc", "stress test", "resilience", "transition risk", "physical risk"), "Risk & Scenario Analysis"),
    (("scope 1", "scope 2", "scope 3", "emission", "ghg", "carbon", "intensity", "kpi", "metric"), "Emissions Metrics"),
    (("target", "baseline", "attainment", "goal", "net zero", "roadmap"), "Targets & Progress"),
    (("strategy", "transition plan", "investment", "capex", "allocation"), "Strategy & Transition Plan"),
    (("supplier", "downstream", "upstream", "value chain", "partner", "procurement"), "Value Chain Engagement"),
    (("water", "waste", "recycl", "pollution", "withdrawal", "discharge"), "Waste & Water Management"),
    (("gri", "sasb", "tcfd", "ifrs", "csrd", "framework", "standard", "assurance", "audit"), "Disclosure Standards & Assurance"),
]
| def _is_generic_cluster_name(name: str) -> bool: | |
| low = str(name or "").strip().lower() | |
| if not low: | |
| return True | |
| if any(x in low for x in ("semantic cluster", "general evidence", "theme", "cluster", "topic", "group")): | |
| cleaned = re.sub(r"\s+", " ", low) | |
| if re.fullmatch(r"(cluster|theme|topic|group)\s*[a-z0-9\- ]*", cleaned): | |
| return True | |
| if cleaned in {"general evidence", "semantic cluster"}: | |
| return True | |
| if "theme " in cleaned or "cluster " in cleaned: | |
| return True | |
| return False | |
def _semantic_cluster_name(seed_text: str, idx: int) -> str:
    """Map seed text to a canonical thematic label via keyword matching."""
    haystack = str(seed_text or "").lower()
    for keywords, label in _SEMANTIC_CLUSTER_LABELS:
        if any(keyword in haystack for keyword in keywords):
            return label
    # No keyword hit: fall back to a numbered generic label.
    return f"Disclosure Focus {idx + 1}"
def _pick_cluster_name(raw_name: str, summary: str, evidence_text: str, idx: int) -> str:
    """Use the model-proposed name when specific enough; else derive one from text."""
    candidate = str(raw_name or "").strip()
    # Strip leading "Cluster 3:"-style prefixes before judging the name.
    candidate = re.sub(r"^\s*cluster\s*[a-z0-9]+\s*[-:]\s*", "", candidate, flags=re.IGNORECASE).strip()
    if candidate and len(candidate) <= 64 and not _is_generic_cluster_name(candidate):
        return candidate
    seed = " ".join([str(summary or ""), str(evidence_text or "")]).strip()
    return _semantic_cluster_name(seed, idx)
| def _dedupe_cluster_names(clusters: List[Dict]) -> List[Dict]: | |
| seen = {} | |
| for c in clusters: | |
| base = str(c.get("cluster_name", "")).strip() or "Disclosure Focus" | |
| count = seen.get(base, 0) + 1 | |
| seen[base] = count | |
| if count > 1: | |
| c["cluster_name"] = f"{base} ({count})" | |
| return clusters | |
def _local_semantic_clusters(contexts: List[Dict], fallback_conf: float) -> List[Dict]:
    """Cluster retrieved evidence locally (no LLM call).

    Uses TF-IDF + KMeans when scikit-learn is available, otherwise a round-robin
    bucketing fallback. Post-processing enforces 2-5 clusters once there are at
    least 4 evidence items (splitting a single cluster in half, or merging the
    tail clusters into a fifth). Returns a list of cluster dicts with
    cluster_name / summary / cluster_confidence / evidence entries.
    """
    n = len(contexts)
    if n == 0:
        return []
    if n <= 3:
        # Keep one cluster for tiny evidence sets.
        tiny_summary = _first_sentence(" ".join(str(c.get("text", "")) for c in contexts)) or "Top retrieved evidence grouped into one cluster."
        tiny_name = _pick_cluster_name("", tiny_summary, tiny_summary, 0)
        return [
            {
                "cluster_name": tiny_name,
                "summary": tiny_summary,
                "cluster_confidence": fallback_conf,
                "evidence": [
                    {
                        "evidence_id": f"E{i + 1}",
                        "report": c.get("report", ""),
                        "chunk_idx": c.get("chunk_idx", ""),
                        "page": c.get("page", None),
                        "text": str(c.get("text", "")),
                        "score": _clamp01(c.get("score", 0.0)),
                    }
                    for i, c in enumerate(contexts)
                ],
            }
        ]
    k = _choose_cluster_k(n)
    texts = [str(c.get("text", "")) for c in contexts]
    labels = None
    try:
        # Primary path: TF-IDF features + KMeans with a fixed seed for determinism.
        from sklearn.feature_extraction.text import TfidfVectorizer
        from sklearn.cluster import KMeans
        vec = TfidfVectorizer(max_features=1500, stop_words="english")
        x = vec.fit_transform(texts)
        km = KMeans(n_clusters=k, random_state=42, n_init=10)
        labels = km.fit_predict(x).tolist()
    except Exception:
        # fallback: simple hash bucketing to still produce 2-5 groups
        labels = [i % k for i in range(n)]
    # Group context indices by cluster label.
    by_label = {}
    for idx, lb in enumerate(labels):
        by_label.setdefault(lb, []).append(idx)
    clusters = []
    # Emit clusters largest-first; evidence ids keep their original 1-based position.
    for lb, idxs in sorted(by_label.items(), key=lambda x: -len(x[1])):
        evidence = []
        for idx in idxs:
            c = contexts[idx]
            evidence.append(
                {
                    "evidence_id": f"E{idx + 1}",
                    "report": c.get("report", ""),
                    "chunk_idx": c.get("chunk_idx", ""),
                    "page": c.get("page", None),
                    "text": str(c.get("text", "")),
                    "score": _clamp01(c.get("score", 0.0)),
                }
            )
        # Confidence = mean clamped score; summary = first sentence of the top item.
        avg_score = _mean([_clamp01(e.get("score", 0.0)) for e in evidence], default=fallback_conf)
        top_text = str(evidence[0].get("text", "")) if evidence else ""
        summary = _first_sentence(top_text) or "Semantically related evidence."
        joined_seed = " ".join(str(e.get("text", "")) for e in evidence[:3])
        clusters.append(
            {
                "cluster_name": _pick_cluster_name("", summary, joined_seed, len(clusters)),
                "summary": summary,
                "cluster_confidence": avg_score,
                "evidence": evidence,
            }
        )
    # enforce 2-5 when n>=4
    if len(clusters) < 2 and n >= 4:
        # Split the evidence into two positional halves to guarantee >= 2 clusters.
        mid = n // 2
        left = list(range(0, mid))
        right = list(range(mid, n))
        split = []
        for grp in (left, right):
            evidence = []
            for idx in grp:
                c = contexts[idx]
                evidence.append(
                    {
                        "evidence_id": f"E{idx + 1}",
                        "report": c.get("report", ""),
                        "chunk_idx": c.get("chunk_idx", ""),
                        "page": c.get("page", None),
                        "text": str(c.get("text", "")),
                        "score": _clamp01(c.get("score", 0.0)),
                    }
                )
            avg_score = _mean([_clamp01(e.get("score", 0.0)) for e in evidence], default=fallback_conf)
            summary = _first_sentence(str(evidence[0].get("text", ""))) if evidence else "Related evidence."
            joined_seed = " ".join(str(e.get("text", "")) for e in evidence[:3])
            split.append(
                {
                    "cluster_name": _pick_cluster_name("", summary or "", joined_seed, len(split)),
                    "summary": summary or "Related evidence.",
                    "cluster_confidence": avg_score,
                    "evidence": evidence,
                }
            )
        clusters = split
    if len(clusters) > 5:
        # Keep top-4, merge rest into cluster 5.
        kept = clusters[:4]
        rest = clusters[4:]
        merged_evidence = []
        for c in rest:
            merged_evidence.extend(c.get("evidence", []))
        merged_score = _mean([_clamp01(e.get("score", 0.0)) for e in merged_evidence], default=fallback_conf)
        merged_seed = " ".join(str(e.get("text", "")) for e in merged_evidence[:3])
        kept.append(
            {
                "cluster_name": _pick_cluster_name("", "Merged tail semantic groups.", merged_seed, 4),
                "summary": "Merged tail semantic groups.",
                "cluster_confidence": merged_score,
                "evidence": merged_evidence,
            }
        )
        clusters = kept
    return _dedupe_cluster_names(clusters)
def _build_single_fallback_cluster(contexts: List[Dict]) -> List[Dict]:
    """Wrap all contexts into one 'Retrieved Evidence' cluster (fallback grouping)."""
    if not contexts:
        return []
    seed_text = " ".join(str(ctx.get("text", "")) for ctx in contexts[:3])
    summary = _first_sentence(seed_text) or "Top retrieved evidence cluster."
    evidence = []
    for rank, ctx in enumerate(contexts):
        evidence.append(
            {
                "evidence_id": f"E{rank + 1}",
                "report": ctx.get("report", ""),
                "chunk_idx": ctx.get("chunk_idx", ""),
                "page": ctx.get("page", None),
                "text": str(ctx.get("text", "")),
                "score": _clamp01(ctx.get("score", 0.0)),
            }
        )
    return [
        {
            "cluster_name": "Retrieved Evidence",
            "summary": summary,
            "evidence": evidence,
        }
    ]
def _parse_cluster_result(output_text: str, contexts: List[Dict]) -> Dict[str, Any]:
    """Parse the clustering model's raw output into normalized clusters.

    Resolves the model's evidence references back to the retrieved contexts,
    normalizes/dedupes cluster names, and falls back to local semantic
    clustering (>= 4 contexts) or a single catch-all cluster when the model
    output is missing or unusable.

    Returns a dict with "clusters", "raw_output", and the parsed "payload".
    """
    payload = _extract_json_payload(output_text)
    n_contexts = len(contexts)
    clusters = []
    if isinstance(payload, dict):
        raw_clusters = payload.get("clusters", [])
        if isinstance(raw_clusters, list):
            for i, rc in enumerate(raw_clusters):
                if not isinstance(rc, dict):
                    continue
                summary = str(rc.get("summary", "")).strip()
                # Accept either key the model may have used for references.
                refs = rc.get("evidence_ids", rc.get("evidence_indices", []))
                if not isinstance(refs, list):
                    refs = [refs] if refs not in (None, "") else []
                evidence_items = []
                used = set()
                for ref in refs:
                    idx = _resolve_evidence_index(ref, n_contexts)
                    # Skip unresolvable or duplicate references.
                    if idx is None or idx in used:
                        continue
                    used.add(idx)
                    ctx = contexts[idx]
                    evidence_items.append({
                        "evidence_id": f"E{idx + 1}",
                        "report": ctx.get("report", ""),
                        "chunk_idx": ctx.get("chunk_idx", ""),
                        "page": ctx.get("page", None),
                        "text": str(ctx.get("text", "")),
                        "score": _clamp01(ctx.get("score", 0.0)),
                    })
                # Model gave no usable refs: seed cluster i with context i.
                if not evidence_items and i < n_contexts:
                    ctx = contexts[i]
                    evidence_items.append({
                        "evidence_id": f"E{i + 1}",
                        "report": ctx.get("report", ""),
                        "chunk_idx": ctx.get("chunk_idx", ""),
                        "page": ctx.get("page", None),
                        "text": str(ctx.get("text", "")),
                        "score": _clamp01(ctx.get("score", 0.0)),
                    })
                evidence_seed = " ".join(str(e.get("text", "")) for e in evidence_items[:3])
                name = _pick_cluster_name(rc.get("cluster_name", ""), summary, evidence_seed, i)
                clusters.append({
                    "cluster_name": name,
                    "summary": summary,
                    "evidence": evidence_items,
                })
    fallback_conf = _average_similarity(contexts)
    if clusters:
        # Re-derive names on the assembled clusters, then dedupe.
        normalized = []
        for i, c in enumerate(clusters):
            summary = str(c.get("summary", ""))
            evidence_seed = " ".join(str(e.get("text", "")) for e in c.get("evidence", [])[:3])
            c["cluster_name"] = _pick_cluster_name(c.get("cluster_name", ""), summary, evidence_seed, i)
            normalized.append(c)
        clusters = _dedupe_cluster_names(normalized)
        # A single cluster over ample evidence is uninformative; re-cluster locally.
        if len(contexts) >= 4 and len(clusters) < 2:
            clusters = _local_semantic_clusters(contexts, fallback_conf=fallback_conf)
    else:
        # If model output is missing/invalid, force semantic multi-cluster fallback when evidence is enough.
        if len(contexts) >= 4:
            clusters = _local_semantic_clusters(contexts, fallback_conf=fallback_conf)
        else:
            clusters = _build_single_fallback_cluster(contexts)
    return {
        "clusters": clusters,
        "raw_output": output_text,
        "payload": payload,
    }
| def _build_claim_prompt( | |
| question: str, | |
| answer: str, | |
| doc_mode: str, | |
| n_claims: int, | |
| report_names: List[str], | |
| clusters: List[Dict], | |
| ) -> str: | |
| report_text = ", ".join(report_names) if report_names else "N/A" | |
| cluster_lines = [] | |
| for i, c in enumerate(clusters, start=1): | |
| cluster_lines.append( | |
| f"- Cluster {i} ({c.get('cluster_name', 'Unknown')}): {c.get('summary', '')}" | |
| ) | |
| cluster_block = "\n".join(cluster_lines[:8]) | |
| schema = '{"claims":[{"claim":"string"}]}' | |
| return ( | |
| "__TASK__: CLAIM_GENERATION\n" | |
| "You are the Claim Extractor in a trustworthy RAG pipeline.\n" | |
| "Generate concise claims that represent the answer's key arguments.\n" | |
| "Rules:\n" | |
| "- Return exactly N_CLAIMS claims.\n" | |
| "- If multi-document: provide one claim per company/report, and the final claim must be overall conclusion.\n" | |
| "- If single-document: provide 3 atomic claims.\n" | |
| "- Keep each claim <= 28 words.\n" | |
| f"Return strict JSON only with schema: {schema}\n\n" | |
| f"DOC_MODE: {doc_mode}\n" | |
| f"N_CLAIMS: {n_claims}\n" | |
| f"REPORTS: {report_text}\n" | |
| f"Question: {question}\n\n" | |
| f"Answer:\n{answer}\n\n" | |
| f"Cluster summaries:\n{cluster_block}" | |
| ) | |
| def _contexts_from_clusters(clusters: List[Dict], fallback_contexts: List[Dict]) -> List[Dict]: | |
| picked = {} | |
| order = [] | |
| for c in clusters: | |
| if not isinstance(c, dict): | |
| continue | |
| for ev in c.get("evidence", []): | |
| if not isinstance(ev, dict): | |
| continue | |
| report = str(ev.get("report", "")) | |
| chunk_idx = ev.get("chunk_idx", "") | |
| text = str(ev.get("text", "")) | |
| if not text: | |
| continue | |
| key = (report, str(chunk_idx)) | |
| if key not in picked: | |
| picked[key] = { | |
| "report": report, | |
| "chunk_idx": chunk_idx, | |
| "text": text, | |
| "score": _clamp01(ev.get("score", 0.0)), | |
| } | |
| order.append(key) | |
| else: | |
| picked[key]["score"] = max(picked[key]["score"], _clamp01(ev.get("score", 0.0))) | |
| if not order: | |
| return fallback_contexts | |
| out = [picked[k] for k in order] | |
| out.sort(key=lambda x: x.get("score", 0.0), reverse=True) | |
| return out | |
| def _infer_multi_company_count(question: str, report_names: List[str]) -> int: | |
| fallback = max(1, len(report_names)) | |
| q = str(question or "").lower() | |
| q = re.sub(r"\"[^\"]+\"", " ", q) | |
| connectors = q.count(" or ") + q.count(" and ") | |
| commas = q.count(",") | |
| guessed = 0 | |
| if connectors > 0 or commas > 0: | |
| guessed = max(connectors + 1, commas + 1) | |
| if guessed < 2: | |
| guessed = fallback | |
| return max(1, min(8, guessed)) | |
| def _sanitize_claim_text(text: str) -> str: | |
| t = str(text or "").strip() | |
| if not t: | |
| return "" | |
| if len(t) > 500: | |
| t = _first_sentence(t) | |
| if (t.startswith("{") and t.endswith("}")) or (t.startswith("[") and t.endswith("]")): | |
| payload = _extract_json_payload(t) | |
| if isinstance(payload, dict): | |
| if isinstance(payload.get("conclusion"), str) and payload.get("conclusion", "").strip(): | |
| return payload["conclusion"].strip() | |
| if isinstance(payload.get("answer"), str) and payload.get("answer", "").strip(): | |
| return payload["answer"].strip() | |
| return "Structured result supports an evidence-based claim." | |
| if isinstance(payload, list): | |
| return "Structured result supports an evidence-based claim." | |
| return t | |
def _extract_conclusion_text_from_answer(answer: str) -> str:
    """Pull a human-readable conclusion out of a generated answer.

    Preference order: JSON "conclusion" field (plain string, or a
    {label, reason} dict joined into one line), then JSON "answer" field,
    then a "Conclusion: ..." line in the raw text, finally the answer's
    first sentence.
    """
    parsed = _extract_json_payload(answer)
    if isinstance(parsed, dict):
        conclusion = parsed.get("conclusion")
        if isinstance(conclusion, str) and conclusion.strip():
            return conclusion.strip()
        if isinstance(conclusion, dict):
            label = str(conclusion.get("label", "")).strip()
            rationale = str(conclusion.get("reason", "")).strip()
            if label or rationale:
                return f"{label} {rationale}".strip()
        answer_field = parsed.get("answer")
        if isinstance(answer_field, str) and answer_field.strip():
            return answer_field.strip()
    raw_text = str(answer or "").strip()
    match = re.search(r"(?:^|\n)\s*conclusion\s*:\s*(.+)$", raw_text, flags=re.IGNORECASE)
    if match:
        return match.group(1).strip()
    return _first_sentence(raw_text)
def _parse_claims(output_text: str, answer: str, n_claims: int, doc_mode: str, report_names: List[str]) -> List[str]:
    """Parse model output into a fixed-length list of claim strings.

    Accepts either a {"claims": [...]} payload or a bare list; sanitizes and
    dedupes claims. When parsing yields nothing, synthesizes per-report
    fallback claims (multi mode) or three generic claims (single mode).
    Multi mode always ends with exactly one "Conclusion: ..." claim.
    """
    payload = _extract_json_payload(output_text)
    claims = []
    raw_claims = None
    if isinstance(payload, dict):
        raw_claims = payload.get("claims")
    elif isinstance(payload, list):
        raw_claims = payload
    if isinstance(raw_claims, list):
        for item in raw_claims:
            if isinstance(item, str):
                claim = _sanitize_claim_text(item)
            elif isinstance(item, dict):
                claim = _sanitize_claim_text(item.get("claim", ""))
            else:
                claim = ""
            # Dedupe while preserving order.
            if claim and claim not in claims:
                claims.append(claim)
    if not claims:
        # Model produced nothing usable; synthesize placeholder claims.
        sentence = _first_sentence(answer) or "Evidence supports part of the question."
        if doc_mode == "multi":
            for rn in report_names:
                claims.append(f"{rn.replace('.pdf', '')}: disclosure evidence is mixed and must be interpreted conservatively.")
            claims.append(f"Conclusion: {_extract_conclusion_text_from_answer(answer) or sentence}")
        else:
            claims = [
                sentence,
                "Retrieved evidence provides partial quantitative support.",
                "Some sub-claims remain weakly supported or uncertain.",
            ]
    # Minimum claim counts: multi needs >= 2 (body + conclusion), single >= 3.
    if doc_mode == "multi":
        target_n = max(2, n_claims)
    else:
        target_n = max(3, n_claims)
    if doc_mode == "multi":
        # Rebuild as body claims + one trailing conclusion claim, stripping any
        # "Conclusion:" entries already present so it appears exactly once.
        conclusion_text = _extract_conclusion_text_from_answer(answer)
        if not conclusion_text:
            conclusion_text = _first_sentence(answer) or "No clear conclusion."
        conclusion_claim = f"Conclusion: {conclusion_text}".strip()
        body_target = max(1, target_n - 1)
        body_claims = [c for c in claims if c and not str(c).strip().lower().startswith("conclusion:")]
        body_claims = body_claims[:body_target]
        while len(body_claims) < body_target:
            body_claims.append(f"Additional claim {len(body_claims) + 1}: evidence is limited.")
        return body_claims + [conclusion_claim]
    # Single mode: pad/truncate to exactly target_n claims.
    claims = claims[:target_n]
    while len(claims) < target_n:
        claims.append(f"Additional claim {len(claims) + 1}: evidence is limited.")
    return claims
| def _build_faithfulness_prompt(claims: List[str], clusters: List[Dict]) -> str: | |
| claim_lines = [f"C{i+1}: {c}" for i, c in enumerate(claims)] | |
| cluster_lines = [] | |
| for i, c in enumerate(clusters, start=1): | |
| ev_lines = [] | |
| for ev in c.get("evidence", [])[:4]: | |
| ev_lines.append( | |
| f" - {ev.get('evidence_id', '')} report={ev.get('report', '')} " | |
| f"paragraph={ev.get('chunk_idx', '')}: {str(ev.get('text', ''))[:200]}" | |
| ) | |
| cluster_lines.append( | |
| f"Cluster {i} ({c.get('cluster_name', 'Unknown')}): {c.get('summary', '')}\n" + "\n".join(ev_lines) | |
| ) | |
| schema = ( | |
| '{"claim_scores":[{"claim":"string","score":"number(0-1)","label":"supported|partial|unsupported","reason":"string"}],' | |
| '"faithfulness":"number(0-1)"}' | |
| ) | |
| return ( | |
| "__TASK__: FAITHFULNESS_VERIFIER\n" | |
| "You are a Faithfulness Verifier.\n" | |
| "Score each claim by whether it is supported by clustered evidence.\n" | |
| "Rules: score in [0,1]; unsupported claims should receive low scores.\n" | |
| f"Return strict JSON only with schema: {schema}\n\n" | |
| "Claims:\n" + "\n".join(claim_lines) + "\n\n" | |
| "Clustered evidence:\n" + "\n\n".join(cluster_lines) | |
| ) | |
| def _token_set(text: str) -> set: | |
| return set(re.findall(r"[a-z0-9]{3,}", str(text or "").lower())) | |
def _parse_faithfulness(output_text: str, claims: List[str], clusters: List[Dict]) -> Dict[str, Any]:
    """Parse the verifier model's output into per-claim scores.

    When the model returned usable "claim_scores", those are taken (with
    labels derived from score thresholds if absent). Otherwise a lexical
    token-overlap heuristic against all cluster evidence is used. The result
    is padded/truncated to exactly len(claims) rows, and the overall
    "faithfulness" is the model's value if present, else the mean score.
    """
    payload = _extract_json_payload(output_text)
    scores = []
    if isinstance(payload, dict):
        rows = payload.get("claim_scores", [])
        if isinstance(rows, list):
            for i, row in enumerate(rows):
                if not isinstance(row, dict):
                    continue
                claim = str(row.get("claim", "")).strip()
                # Backfill the claim text from our own list when omitted.
                if not claim and i < len(claims):
                    claim = claims[i]
                score = _clamp01(row.get("score", 0.0))
                label = str(row.get("label", "")).strip().lower()
                if not label:
                    # Thresholds: >= 0.75 supported, >= 0.45 partial, else unsupported.
                    label = "supported" if score >= 0.75 else ("partial" if score >= 0.45 else "unsupported")
                reason = str(row.get("reason", "")).strip()
                scores.append({"claim": claim, "score": score, "label": label, "reason": reason})
    if not scores:
        # Heuristic fallback: score each claim by token overlap with the
        # concatenated evidence text (boosted by 1.4, clamped to [0, 1]).
        evidence_text = " ".join(
            str(ev.get("text", ""))
            for c in clusters
            for ev in c.get("evidence", [])
            if isinstance(ev, dict)
        )
        ev_tokens = _token_set(evidence_text)
        for claim in claims:
            c_tokens = _token_set(claim)
            if not c_tokens:
                overlap_ratio = 0.0
            else:
                overlap_ratio = len(c_tokens & ev_tokens) / max(1, len(c_tokens))
            score = _clamp01(overlap_ratio * 1.4)
            label = "supported" if score >= 0.75 else ("partial" if score >= 0.45 else "unsupported")
            scores.append({"claim": claim, "score": score, "label": label, "reason": ""})
    # Pad missing rows with a neutral "partial" score, then truncate extras.
    while len(scores) < len(claims):
        claim = claims[len(scores)]
        scores.append({"claim": claim, "score": 0.35, "label": "partial", "reason": ""})
    scores = scores[: len(claims)]
    model_faith = None
    if isinstance(payload, dict) and payload.get("faithfulness", None) is not None:
        model_faith = _clamp01(payload.get("faithfulness"))
    avg_faith = _mean([s["score"] for s in scores], default=0.0)
    faithfulness = model_faith if model_faith is not None else avg_faith
    return {
        "claim_scores": scores,
        "faithfulness": faithfulness,
        "raw_output": output_text,
        "payload": payload,
    }
def compute_calibrated_confidence(average_similarity: float, cluster_confidence: float, faithfulness: float) -> float:
    """Aggregate the three trust signals into one calibrated confidence.

    Each component is clamped to [0, 1]; the result is their simple mean.
    """
    components = [
        _clamp01(average_similarity),
        _clamp01(cluster_confidence),
        _clamp01(faithfulness),
    ]
    return _mean(components, default=0.0)
def _chunk_json_source_file(chunk_mode: str) -> str:
    """Map a chunking mode to its pre-chunked OCR JSON file path.

    "structure" (case/whitespace-insensitive) selects the structure-chunked
    file; anything else falls back to the length-chunked file.
    """
    mode = str(chunk_mode or "").strip().lower()
    return OCR_CHUNKS_ALL_STRUCTURE_JSON if mode == "structure" else OCR_CHUNKS_ALL_LENGTH_JSON
def _load_report_chunks_from_json(chunk_mode: str):
    """Load pre-chunked report text from the mode's JSON file.

    Returns {report_name: [(chunk_idx, text, page), ...]} sorted by chunk
    index per report, or None when the file is missing, unreadable, not a
    list, or yields no usable rows.
    """
    src = _chunk_json_source_file(chunk_mode)
    if not os.path.isfile(src):
        return None
    try:
        with open(src, "r", encoding="utf-8") as f:
            rows = json.load(f)
    except Exception:
        # Corrupt/unreadable JSON is treated as "no source available".
        return None
    if not isinstance(rows, list):
        return None
    grouped: Dict[str, List[tuple]] = {}
    for row in rows:
        if not isinstance(row, dict):
            continue
        report = str(row.get("report", "")).strip()
        if not report:
            continue
        # Accept either "chunk_text" or legacy "text" as the chunk body.
        text = str(row.get("chunk_text", row.get("text", "")) or "")
        try:
            chunk_idx = int(row.get("chunk_idx"))
        except Exception:
            # Missing/invalid index: append at the report's current tail.
            chunk_idx = len(grouped.get(report, []))
        page = row.get("page", row.get("page_start", None))
        try:
            page = int(page) if page not in (None, "", "NA") else None
        except Exception:
            page = None
        grouped.setdefault(report, []).append((chunk_idx, text, page))
    if not grouped:
        return None
    normalized: Dict[str, List[tuple]] = {}
    for report, items in grouped.items():
        dedup = {}
        for idx, text, page in items:
            if idx in dedup:
                # Keep non-empty replacement if duplicate index appears.
                if (not str(dedup[idx][0]).strip()) and str(text).strip():
                    dedup[idx] = (text, page)
                elif dedup[idx][1] is None and page is not None:
                    # Same text, but this duplicate carries the page number.
                    dedup[idx] = (dedup[idx][0], page)
            else:
                dedup[idx] = (text, page)
        normalized[report] = sorted(
            [(int(idx), str(txt or ""), pg) for idx, (txt, pg) in dedup.items()],
            key=lambda x: x[0],
        )
    return normalized
def get_report_chunks(chunk_mode: str) -> Dict[str, list]:
    """Return {report_name: [(chunk_idx, text, page), ...]} for *chunk_mode*.

    Results are memoized per chunk mode. Only the pre-chunked JSON source is
    supported; a missing or unreadable file raises FileNotFoundError.
    """
    cached = _report_chunks_cache.get(chunk_mode)
    if cached is not None:
        return cached
    # Prefer pre-chunked JSON cache for both single-doc and multi-doc queries.
    # Retrieval builds one global pool, so this source is shared by both modes.
    loaded = _load_report_chunks_from_json(chunk_mode)
    if isinstance(loaded, dict) and loaded:
        _report_chunks_cache[chunk_mode] = loaded
        return loaded
    src = _chunk_json_source_file(chunk_mode)
    raise FileNotFoundError(
        f"Required chunk JSON is missing or unreadable: {src}. "
        "This build does not fallback to MinerU_Reports."
    )
| # ======================== Core Functions ======================== | |
def retrieve(
    question: str,
    chunk_mode: str = "structure",
    doc_mode: str = "single",
    top_k: int = 5,
    embed_name: str = "Qwen3-Embedding-0.6B",
    api_key: str = "",
    openai_base_url: str = OPENAI_EMBED_BASE_URL,
) -> List[Dict]:
    """Retrieve the top_k chunks for *question* from the global chunk pool.

    Supports three retrieval backends selected by *embed_name*: BM25 lexical
    scoring, OpenAI embedding models (requires an API key), and local
    embedding models. Dense paths cache document embeddings in memory and on
    disk, keyed by embedding model and chunk mode.

    Returns a list of dicts with "report", "chunk_idx", "text", "page",
    "score" (cosine similarity for dense paths) and "raw_score" (dot product,
    or the BM25 score on the BM25 path). Raises RuntimeError for a missing
    API key or malformed/mismatched embeddings.
    """
    openai_base_url = str(openai_base_url or OPENAI_EMBED_BASE_URL).rstrip("/")
    use_openai_embed = _is_openai_embedding_model(embed_name)
    use_bm25 = str(embed_name or "").strip().upper() == "BM25"
    # Mock mode only replaces local embedding retrieval; BM25 and API paths run for real.
    if MOCK_MODE and not use_openai_embed and not use_bm25:
        return _mock_retrieve(question, chunk_mode, doc_mode, top_k)
    if use_openai_embed and (not _resolve_api_key(api_key)):
        raise RuntimeError(
            "OpenAI embedding model selected but API key is missing. "
            "Please provide OPENAI_API_KEY or pass api_key."
        )
    # Deferred import so the UI can start without numpy in mock mode.
    import numpy as np
    report_chunks = get_report_chunks(chunk_mode)
    pool = _build_chunk_pool(report_chunks)
    if not pool:
        return []
    if use_bm25:
        # Lexical path: BM25 scores double as both "score" and "raw_score".
        bm25 = _get_bm25(chunk_mode=chunk_mode, pool=pool)
        scores = bm25.get_scores(_tokenize_for_bm25(question))
        top_indices = np.argsort(scores)[::-1][:top_k]
        results = []
        for i in top_indices:
            results.append(
                {
                    "report": pool[i][0],
                    "chunk_idx": pool[i][1],
                    "text": pool[i][2],
                    "page": (pool[i][3] if len(pool[i]) > 3 else None),
                    "score": float(scores[i]),
                    "raw_score": float(scores[i]),
                }
            )
        return results
    if use_openai_embed:
        # API embedding path: check in-memory cache, then the simple on-disk
        # cache, then a legacy cache keyed by base URL + pool signature;
        # encode from scratch only when all caches miss.
        signature = _chunk_pool_signature(pool)
        cache_key = f"{embed_name}__{chunk_mode}"
        if cache_key not in _chunk_embeddings_cache:
            simple_cache_file = _get_simple_embed_cache_file(embed_name, chunk_mode)
            legacy_cache_file = _get_embed_cache_file(
                f"{embed_name}__api__{openai_base_url}",
                chunk_mode,
                signature,
            )
            cached = _load_cached_embeddings(simple_cache_file, expected_rows=len(pool))
            if cached is None:
                cached = _load_cached_embeddings(legacy_cache_file, expected_rows=len(pool))
            if cached is not None:
                _chunk_embeddings_cache[cache_key] = cached
            else:
                texts = [p[2] for p in pool]
                encoded = _encode_openai_embeddings(
                    texts=texts,
                    model_name=embed_name,
                    api_key=api_key,
                    base_url=openai_base_url,
                    batch_size=32,
                )
                _chunk_embeddings_cache[cache_key] = encoded
                _save_cached_embeddings(simple_cache_file, encoded)
        doc_emb = _chunk_embeddings_cache[cache_key]
        query_emb = _encode_openai_embeddings(
            texts=[question],
            model_name=embed_name,
            api_key=api_key,
            base_url=openai_base_url,
            batch_size=1,
        )
    else:
        # Local embedding path: same cache layering as the API path.
        embed_model = _get_embed_model(embed_name)
        signature = _chunk_pool_signature(pool)
        cache_key = f"{embed_name}__{chunk_mode}"
        if cache_key not in _chunk_embeddings_cache:
            simple_cache_file = _get_simple_embed_cache_file(embed_name, chunk_mode)
            legacy_cache_file = _get_embed_cache_file(embed_name, chunk_mode, signature)
            cached = _load_cached_embeddings(simple_cache_file, expected_rows=len(pool))
            if cached is None:
                cached = _load_cached_embeddings(legacy_cache_file, expected_rows=len(pool))
            if cached is not None:
                _chunk_embeddings_cache[cache_key] = cached
            else:
                texts = [p[2] for p in pool]
                encoded = embed_model.encode_documents(texts, batch_size=32)
                _chunk_embeddings_cache[cache_key] = encoded
                _save_cached_embeddings(simple_cache_file, encoded)
        doc_emb = _chunk_embeddings_cache[cache_key]
        query_emb = embed_model.encode_queries([question], batch_size=1)
    # Sanity checks before the similarity math: both must be 2D and agree on dim.
    if getattr(doc_emb, "ndim", 0) != 2 or getattr(query_emb, "ndim", 0) != 2:
        raise RuntimeError("Embedding encoding failed: expected 2D arrays for document/query embeddings.")
    if doc_emb.shape[1] != query_emb.shape[1]:
        raise RuntimeError(
            f"Embedding dimension mismatch: docs={doc_emb.shape[1]}, query={query_emb.shape[1]} "
            f"(model={embed_name})."
        )
    # Use true cosine similarity for dense retrieval:
    # sim(q, d) = (q·d) / (||q|| * ||d||)
    query_vec = query_emb[0]
    dot_scores = doc_emb @ query_vec
    doc_norms = np.linalg.norm(doc_emb, axis=1)
    q_norm = float(np.linalg.norm(query_vec))
    denom = doc_norms * max(q_norm, 1e-12)
    # Guard against zero-norm vectors: those rows get score 0 instead of NaN.
    scores = np.divide(
        dot_scores,
        denom,
        out=np.zeros_like(dot_scores, dtype="float32"),
        where=denom > 1e-12,
    )
    top_indices = np.argsort(scores)[::-1][:top_k]
    results = []
    for i in top_indices:
        results.append({
            "report": pool[i][0],
            "chunk_idx": pool[i][1],
            "text": pool[i][2],
            "page": (pool[i][3] if len(pool[i]) > 3 else None),
            "score": float(scores[i]),
            "raw_score": float(dot_scores[i]),
        })
    return results
def generate(
    question: str,
    contexts: List[Dict],
    doc_mode: str = "single",
    gen_model: str = "Qwen3-4B-Instruct-2507-FP8",
    api_key: str = "",
) -> str:
    """Generate an answer from retrieved contexts.

    API-backed model names are dispatched to the API path. Local generation
    requires a GPU (otherwise a warning string is returned) and is replaced
    by a placeholder in MOCK_MODE.
    """
    if _is_api_generation_model(gen_model):
        return _generate_with_api(question, contexts, doc_mode, gen_model, api_key)
    if not HAS_GPU:
        return (
            "\u26a0\ufe0f No GPU detected on this host. Local generation models are disabled. "
            "Please use an API generation model."
        )
    if MOCK_MODE:
        return _mock_generate(question, contexts, doc_mode)
    # Heavy dependency: only imported when real local generation runs.
    from generation import (
        build_yes_no_prompt,
        build_multi_zero_shot_prompt,
        infer_multi_skill_name,
        get_multi_skill_spec,
    )
    llm, sampling = _get_llm()
    if doc_mode == "multi":
        spec = get_multi_skill_spec(infer_multi_skill_name(question))
        tagged_contexts = [
            f"(report={c['report']}, chunk_idx={c['chunk_idx']}) {c['text']}"
            for c in contexts
        ]
        prompt = build_multi_zero_shot_prompt(
            question=question,
            contexts=tagged_contexts,
            skill_name=spec["skill_name"],
            skill_prompt=spec["skill_prompt"],
            output_json_schema=spec["output_json_schema"],
            retrieval_query=question,
        )
    else:
        prompt = build_yes_no_prompt(question, [c["text"] for c in contexts], doc_mode)
    outputs = llm.generate([prompt], sampling)
    return outputs[0].outputs[0].text.strip()
def run_rag(
    question: str,
    chunk_mode: str = "structure",
    doc_mode: str = "single",
    top_k: int = 5,
    embed_name: str = "Qwen3-Embedding-0.6B",
    gen_model: str = "Qwen3-4B-Instruct-2507-FP8",
    api_key: str = "",
) -> tuple:
    """Plain retrieve-then-generate RAG; returns (answer, contexts)."""
    retrieved = retrieve(
        question=question,
        chunk_mode=chunk_mode,
        doc_mode=doc_mode,
        top_k=top_k,
        embed_name=embed_name,
        api_key=api_key,
        openai_base_url=OPENAI_EMBED_BASE_URL,
    )
    return generate(question, retrieved, doc_mode, gen_model, api_key), retrieved
def run_trustworthy_step1(
    question: str,
    doc_mode: str = "single",
    top_k: int = 10,
    embed_name: str = "Qwen3-Embedding-0.6B",
    gen_model: str = "Qwen3-4B-Instruct-2507-FP8",
    api_key: str = "",
) -> Dict[str, Any]:
    """Trustworthy-RAG step 1: retrieve evidence and cluster it.

    Always retrieves with "structure" chunking, runs the clustering prompt,
    and returns contexts, their average similarity, parsed clusters, and the
    raw clustering model output.
    """
    contexts = retrieve(
        question=question,
        chunk_mode="structure",
        doc_mode=doc_mode,
        top_k=int(top_k),
        embed_name=embed_name,
        api_key=api_key,
        openai_base_url=OPENAI_EMBED_BASE_URL,
    )
    raw_cluster_output = _run_model_prompt(
        prompt=_build_cluster_prompt(question, contexts),
        gen_model=gen_model,
        api_key=api_key,
        max_tokens=1200,
    )
    parsed = _parse_cluster_result(raw_cluster_output, contexts)
    return {
        "contexts": contexts,
        "average_similarity": _average_similarity(contexts),
        "clusters": parsed["clusters"],
        "cluster_raw_output": parsed["raw_output"],
    }
def run_trustworthy_recluster(
    question: str,
    contexts: List[Dict],
    gen_model: str = "Qwen3-4B-Instruct-2507-FP8",
    api_key: str = "",
) -> Dict[str, Any]:
    """Re-run clustering over already-retrieved contexts (no new retrieval)."""
    raw_output = _run_model_prompt(
        prompt=_build_cluster_prompt(question, contexts),
        gen_model=gen_model,
        api_key=api_key,
        max_tokens=1200,
    )
    parsed = _parse_cluster_result(raw_output, contexts)
    return {
        "clusters": parsed["clusters"],
        "cluster_raw_output": parsed["raw_output"],
    }
def run_trustworthy_step2(
    question: str,
    doc_mode: str,
    contexts: List[Dict],
    clusters: List[Dict],
    gen_model: str = "Qwen3-4B-Instruct-2507-FP8",
    api_key: str = "",
) -> Dict[str, Any]:
    """Trustworthy-RAG step 2: answer the question from clustered evidence.

    Flattens cluster evidence into QA contexts (falling back to the raw
    contexts) and generates the answer over them.
    """
    qa_contexts = _contexts_from_clusters(clusters, contexts)
    return {
        "answer": generate(
            question=question,
            contexts=qa_contexts,
            doc_mode=doc_mode,
            gen_model=gen_model,
            api_key=api_key,
        ),
        "qa_contexts": qa_contexts,
    }
def run_rag_trustworthy(
    question: str,
    doc_mode: str = "single",
    top_k: int = 10,
    embed_name: str = "Qwen3-Embedding-0.6B",
    gen_model: str = "Qwen3-4B-Instruct-2507-FP8",
    api_key: str = "",
) -> Dict[str, Any]:
    """Run trustworthy-RAG steps 1 (retrieve+cluster) and 2 (answer) in sequence."""
    step1 = run_trustworthy_step1(
        question=question,
        doc_mode=doc_mode,
        top_k=top_k,
        embed_name=embed_name,
        gen_model=gen_model,
        api_key=api_key,
    )
    step2 = run_trustworthy_step2(
        question=question,
        doc_mode=doc_mode,
        contexts=step1["contexts"],
        clusters=step1["clusters"],
        gen_model=gen_model,
        api_key=api_key,
    )
    return {"step1": step1, "step2": step2}
| def _build_claim_alignment_prompt( | |
| question: str, | |
| answer: str, | |
| contexts: List[Dict], | |
| doc_mode: str, | |
| n_claims: int, | |
| ) -> str: | |
| lines = [] | |
| for i, c in enumerate(contexts, start=1): | |
| report = c.get("report", "unknown") | |
| chunk_idx = c.get("chunk_idx", "NA") | |
| text = str(c.get("text", "")).replace("\n", " ").strip() | |
| lines.append(f"[E{i}] report={report}; paragraph={chunk_idx}; text={text[:650]}") | |
| schema = ( | |
| '{"claims":[{"claim":"string","evidence_ids":["E1"]}]}' | |
| ) | |
| return ( | |
| "__TASK__: CLAIM_EXTRACTION\n" | |
| "You are the Claim Extractor in a trustworthy RAG pipeline.\n" | |
| "Given answer and evidence list, extract concise claims and link each claim to supporting evidence ids.\n" | |
| "Rules:\n" | |
| "- Return exactly N_CLAIMS claims.\n" | |
| "- Use only evidence ids from provided list.\n" | |
| "- Keep each claim <= 28 words.\n" | |
| f"Return strict JSON only with schema: {schema}\n\n" | |
| f"DOC_MODE: {doc_mode}\n" | |
| f"N_CLAIMS: {n_claims}\n" | |
| f"Question: {question}\n\n" | |
| f"Answer:\n{answer}\n\n" | |
| "Evidence:\n" | |
| + "\n".join(lines) | |
| ) | |
def _parse_claim_alignment(
    output_text: str,
    contexts: List[Dict],
    n_claims: int,
    answer: str,
    doc_mode: str,
) -> Dict[str, Any]:
    """Parse claim-extraction output into claims linked to evidence.

    Resolves each claim's evidence references back to the retrieved contexts.
    When the model output is unusable, seeds claims via _parse_claims and
    attaches the i-th context to the i-th claim. Always returns exactly
    n_claims rows (padded with placeholder claims carrying no evidence).
    """
    payload = _extract_json_payload(output_text)
    n_contexts = len(contexts)
    claim_rows = []
    if isinstance(payload, dict):
        rows = payload.get("claims", [])
        if isinstance(rows, list):
            for i, row in enumerate(rows):
                if isinstance(row, str):
                    claim = _sanitize_claim_text(row)
                    refs = []
                elif isinstance(row, dict):
                    claim = _sanitize_claim_text(row.get("claim", ""))
                    # Accept any of the reference keys the model may emit.
                    refs = row.get("evidence_ids", row.get("evidence", row.get("sources", [])))
                else:
                    continue
                if not claim:
                    continue
                if not isinstance(refs, list):
                    refs = [refs] if refs not in (None, "") else []
                evidence = []
                used = set()
                for ref in refs:
                    idx = _resolve_evidence_index(ref, n_contexts)
                    # Skip unresolvable or duplicate references.
                    if idx is None or idx in used:
                        continue
                    used.add(idx)
                    ctx = contexts[idx]
                    evidence.append(
                        {
                            "evidence_id": f"E{idx + 1}",
                            "report": str(ctx.get("report", "")),
                            "chunk_idx": ctx.get("chunk_idx", ""),
                            "page": ctx.get("page", None),
                            "text": str(ctx.get("text", "")),
                            "score": _clamp01(ctx.get("score", 0.0)),
                        }
                    )
                # fallback to one evidence by order if model omitted refs
                if not evidence and i < n_contexts:
                    ctx = contexts[i]
                    evidence.append(
                        {
                            "evidence_id": f"E{i + 1}",
                            "report": str(ctx.get("report", "")),
                            "chunk_idx": ctx.get("chunk_idx", ""),
                            "page": ctx.get("page", None),
                            "text": str(ctx.get("text", "")),
                            "score": _clamp01(ctx.get("score", 0.0)),
                        }
                    )
                claim_rows.append({"claim": claim, "evidence": evidence})
    if not claim_rows:
        # Model output unusable: seed fallback claims from the answer text.
        report_names = sorted({str(c.get("report", "")) for c in contexts if str(c.get("report", "")).strip()})
        seed_claims = _parse_claims(
            output_text="",
            answer=answer,
            n_claims=n_claims,
            doc_mode=doc_mode,
            report_names=report_names,
        )
        for i, claim in enumerate(seed_claims[:n_claims]):
            evidence = []
            if i < n_contexts:
                ctx = contexts[i]
                evidence.append(
                    {
                        "evidence_id": f"E{i + 1}",
                        "report": str(ctx.get("report", "")),
                        "chunk_idx": ctx.get("chunk_idx", ""),
                        "page": ctx.get("page", None),
                        "text": str(ctx.get("text", "")),
                        "score": _clamp01(ctx.get("score", 0.0)),
                    }
                )
            claim_rows.append({"claim": claim, "evidence": evidence})
    # Normalize to exactly n_claims rows.
    claim_rows = claim_rows[:n_claims]
    while len(claim_rows) < n_claims:
        claim_rows.append({"claim": f"Additional claim {len(claim_rows) + 1}: evidence is limited.", "evidence": []})
    return {
        "claims": [c["claim"] for c in claim_rows],
        "claim_links": claim_rows,
        "raw_output": output_text,
        "payload": payload,
    }
def run_trustworthy_step3_claims(
    question: str,
    answer: str,
    contexts: List[Dict],
    doc_mode: str = "single",
    gen_model: str = "Qwen3-4B-Instruct-2507-FP8",
    api_key: str = "",
) -> Dict[str, Any]:
    """Trustworthy-RAG step 3: extract claims aligned with evidence.

    Single-doc mode requests 3 claims; multi-doc mode requests one claim per
    inferred company plus a final conclusion claim (minimum 2 total).
    """
    if doc_mode == "single":
        n_claims = 3
    else:
        report_names = sorted(
            {str(c.get("report", "")) for c in contexts if str(c.get("report", "")).strip()}
        )
        n_claims = max(2, _infer_multi_company_count(question, report_names) + 1)
    raw_output = _run_model_prompt(
        prompt=_build_claim_alignment_prompt(
            question=question,
            answer=answer,
            contexts=contexts,
            doc_mode=doc_mode,
            n_claims=n_claims,
        ),
        gen_model=gen_model,
        api_key=api_key,
        max_tokens=900,
    )
    parsed = _parse_claim_alignment(
        output_text=raw_output,
        contexts=contexts,
        n_claims=n_claims,
        answer=answer,
        doc_mode=doc_mode,
    )
    return {
        "claims": parsed["claims"],
        "claim_links": parsed["claim_links"],
        "claims_raw_output": parsed["raw_output"],
    }
| # ======================== Report Library ======================== | |
def list_reports() -> List[Dict]:
    """List PDF reports in REPORTS_DIR as {"name", "path"} dicts, sorted by filename."""
    if not os.path.isdir(REPORTS_DIR):
        return []
    return [
        {"name": fname, "path": os.path.join(REPORTS_DIR, fname)}
        for fname in sorted(os.listdir(REPORTS_DIR))
        if fname.lower().endswith(".pdf")
    ]
| # ======================== API Generation ======================== | |
def _generate_with_api(
    question: str,
    contexts: List[Dict],
    doc_mode: str,
    gen_model: str,
    api_key: str,
) -> str:
    """Answer via a hosted chat model (OpenAI- or Gemini-family).

    Builds the same prompts as local generation, wraps them with a strict
    system message, and routes to the Gemini or OpenAI caller based on the
    resolved model name.
    """
    from generation import (
        build_yes_no_prompt,
        build_multi_zero_shot_prompt,
        infer_multi_skill_name,
        get_multi_skill_spec,
    )
    if doc_mode == "multi":
        spec = get_multi_skill_spec(infer_multi_skill_name(question))
        tagged_contexts = [
            f"(report={c['report']}, chunk_idx={c['chunk_idx']}) {c['text']}"
            for c in contexts
        ]
        user_msg = build_multi_zero_shot_prompt(
            question=question,
            contexts=tagged_contexts,
            skill_name=spec["skill_name"],
            skill_prompt=spec["skill_prompt"],
            output_json_schema=spec["output_json_schema"],
            retrieval_query=question,
        )
    else:
        # Only "single"/"cross" are valid yes/no prompt modes.
        prompt_mode = doc_mode if doc_mode in {"single", "cross"} else "single"
        user_msg = build_yes_no_prompt(
            question, [str(c.get("text", "")) for c in contexts], prompt_mode
        )
    system_msg = (
        "You are a strict climate disclosure analysis assistant. "
        "Follow the user task exactly and use only provided evidence."
    )
    chat_model = _resolve_api_generation_model_name(gen_model)
    caller = (
        _call_gemini_chat_model
        if _is_gemini_generation_model(chat_model)
        else _call_openai_chat_model
    )
    return caller(
        model_name=chat_model,
        system_msg=system_msg,
        user_msg=user_msg,
        api_key=api_key,
        max_tokens=2048,
    )
| # ======================== Mock Functions ======================== | |
# Report filenames served by _mock_retrieve when MOCK_MODE is active.
_MOCK_REPORTS = [
    "2022 Microsoft Environmental Sustainability Report.pdf",
    "Apple_2023_EnvironmentalProgress_Report.pdf",
    "Shell sustainability report 2022.pdf",
    "Boeing 2023 Sustainability Report.pdf",
    "Deloitte 2023 Global Impact Report.pdf",
]
# Canned sustainability-report passages cycled through by _mock_retrieve so
# the UI shows realistic-looking evidence without any model or index loaded.
_MOCK_PASSAGES = [
    "The company has established comprehensive environmental targets aligned with "
    "the Paris Agreement and Science Based Targets initiative (SBTi). Key commitments "
    "include achieving carbon neutrality by 2030 and net-zero emissions across the "
    "entire value chain by 2050. The sustainability roadmap outlines specific milestones "
    "for renewable energy adoption, waste reduction, and water stewardship programs.",
    "In alignment with TCFD recommendations, we have integrated climate risk assessment "
    "into our enterprise risk management framework. Physical risks including extreme "
    "weather events and transition risks such as regulatory changes are evaluated "
    "annually. Scenario analysis covering 1.5\u00b0C, 2\u00b0C, and 4\u00b0C pathways informs "
    "our long-term strategic planning.",
    "Our Scope 1 and 2 greenhouse gas emissions decreased by 23% compared to the "
    "2019 baseline year, driven by accelerated procurement of renewable electricity "
    "and fleet electrification. Scope 3 emissions from the supply chain remain our "
    "largest challenge, and we have engaged 85% of Tier-1 suppliers in setting "
    "their own science-based targets.",
    "The company invested $2.1 billion in clean technology R&D during the reporting "
    "period, representing a 35% increase year-over-year. Priority areas include "
    "carbon capture and storage, sustainable aviation fuels, and next-generation "
    "battery technology for energy storage solutions.",
    "Water stewardship initiatives achieved a 40% reduction in freshwater withdrawal "
    "intensity across manufacturing operations. Zero-liquid-discharge systems were "
    "deployed at three additional facilities, bringing the total to twelve sites "
    "operating under closed-loop water management protocols.",
]
| def _mock_model_prompt(prompt: str) -> str: | |
| text = str(prompt or "") | |
| if "__TASK__: EVIDENCE_CLUSTERING" in text: | |
| evidence_ids = re.findall(r"\[E(\d+)\]", text) | |
| ids = [f"E{x}" for x in evidence_ids] | |
| split = max(1, len(ids) // 2) | |
| cluster_a = ids[:split] if ids else ["E1"] | |
| cluster_b = ids[split:] if len(ids) > split else [] | |
| clusters = [ | |
| { | |
| "cluster_name": "Governance", | |
| "summary": "Oversight, risk process, and governance-related evidence.", | |
| "evidence_ids": cluster_a, | |
| } | |
| ] | |
| if cluster_b: | |
| clusters.append( | |
| { | |
| "cluster_name": "Metrics & Targets", | |
| "summary": "Quantitative metrics, targets, and progress evidence.", | |
| "evidence_ids": cluster_b, | |
| } | |
| ) | |
| return json.dumps( | |
| { | |
| "skill": "Evidence Clusterer", | |
| "clusters": clusters, | |
| }, | |
| ensure_ascii=False, | |
| ) | |
| if "__TASK__: CLAIM_GENERATION" in text: | |
| n = 3 | |
| m = re.search(r"N_CLAIMS:\s*(\d+)", text) | |
| if m: | |
| n = max(1, int(m.group(1))) | |
| claims = [{"claim": f"Claim {i+1}: evidence-backed summary statement."} for i in range(n)] | |
| return json.dumps({"claims": claims}, ensure_ascii=False) | |
| if "__TASK__: CLAIM_EXTRACTION" in text: | |
| n = 3 | |
| m = re.search(r"N_CLAIMS:\s*(\d+)", text) | |
| if m: | |
| n = max(1, int(m.group(1))) | |
| evidence_ids = re.findall(r"\[E(\d+)\]", text) | |
| evidences = [f"E{x}" for x in evidence_ids] | |
| rows = [] | |
| for i in range(n): | |
| refs = [] | |
| if evidences: | |
| refs = [evidences[i % len(evidences)]] | |
| rows.append( | |
| { | |
| "claim": f"Claim {i+1}: evidence-backed summary statement.", | |
| "evidence_ids": refs, | |
| } | |
| ) | |
| return json.dumps({"claims": rows}, ensure_ascii=False) | |
| if "__TASK__: FAITHFULNESS_VERIFIER" in text: | |
| claim_lines = re.findall(r"^C\d+:\s*(.+)$", text, flags=re.MULTILINE) | |
| if not claim_lines: | |
| claim_lines = ["Claim 1"] | |
| rows = [] | |
| for i, c in enumerate(claim_lines): | |
| score = max(0.55, 0.9 - i * 0.08) | |
| label = "supported" if score >= 0.75 else "partial" | |
| rows.append( | |
| { | |
| "claim": c.strip(), | |
| "score": round(score, 3), | |
| "label": label, | |
| "reason": "Mock verifier support signal.", | |
| } | |
| ) | |
| faith = _mean([r["score"] for r in rows], default=0.0) | |
| return json.dumps({"claim_scores": rows, "faithfulness": round(faith, 3)}, ensure_ascii=False) | |
| return "{}" | |
def _mock_retrieve(question: str, chunk_mode: str, doc_mode: str, top_k: int) -> List[Dict]:
    """Fabricate ``top_k`` retrieval hits by cycling the canned mock reports/passages."""
    hits: List[Dict] = []
    for rank in range(top_k):
        hits.append(
            {
                "report": _MOCK_REPORTS[rank % len(_MOCK_REPORTS)],
                "chunk_idx": rank * 3 + 7,
                "text": _MOCK_PASSAGES[rank % len(_MOCK_PASSAGES)],
                # Monotonically decreasing pseudo-similarity score.
                "score": round(0.9231 - rank * 0.0547, 4),
            }
        )
    return hits
| def _mock_generate(question: str, contexts: List[Dict], doc_mode: str) -> str: | |
| n = len(contexts) | |
| if doc_mode == "multi": | |
| return ( | |
| '{' | |
| '"dimension":"Mock dimension",' | |
| '"rows":[{"report":"Mock Report A","year":"2022","disclosure_status":"explicit",' | |
| '"key_points":["Mock evidence A"],"evidence_chunks":["E1"]}],' | |
| '"ranking":[{"rank":1,"report":"Mock Report A","rationale":"Mock rationale."}],' | |
| '"conclusion":"[YES] Mock conclusion based on provided evidence."' | |
| '}' | |
| ) | |
| return ( | |
| f"[YES]. Based on the {n} retrieved passages, the company demonstrates " | |
| f"alignment with external climate targets through specific commitments " | |
| f"to net-zero emissions and science-based targets. The report references " | |
| f"the Paris Agreement, SBTi, and TCFD frameworks as guiding principles " | |
| f"for their environmental strategy." | |
| ) | |