"""RAG Generation for single / cross / multi document settings.""" from __future__ import annotations import os import re import json import argparse from typing import Dict, List, Tuple, Any import pandas as pd # ======================== Config ======================== GEN_MODEL_PATH = "../model/Qwen3-4B-Instruct-2507-FP8" BASE_XLSX_PATH = "../Expert-Annotated Relevant Sources Dataset/ClimRetrieve_base.xlsx" CROSS_XLSX_PATH = "../Expert-Annotated Relevant Sources Dataset/ClimRetrieve_cross.xlsx" RESULT_DIR_BASE = "./Embedding_Search_Results_Qwen" RETRIEVAL_TOP_K = 5 MAX_NEW_TOKENS = 768 TEMPERATURE = 0.0 # ======================== Prompt Templates ======================== FEW_SHOT_EXAMPLES_SINGLE = """Example 1: Question: Does the company have a strategy on waste management? Context: - "To meet our commitment to being zero waste by 2030, we are reducing our waste footprint through reuse, recycling and recovery." - "We increased our reuse and recycle rates of all cloud hardware to 82 percent." Answer: [YES]. The company has a clear strategy on waste management focused on reduction, recycling, and recovery to meet its 2030 zero waste commitment. Example 2: Question: Does the company report the climate change scenarios used to test the resilience of its business strategy? Context: - "Our ESG ratings improved from 2021 to 2022." - "We continue to focus on sustainable sourcing across our supply chain." Answer: [NO]. The provided context does not contain information about climate change scenarios used to test business strategy resilience. """ FEW_SHOT_EXAMPLES_CROSS = """Example 1: Question: Among [Company A Report, Company B Report], did at least one company address "Does the company have a strategy on waste management?" Context: - (Company A) "We are committed to reducing waste by 50% by 2030 through our circular economy approach." - (Company B) "Our revenue grew by 15% year-over-year driven by strong demand." Answer: [YES]. Company A provides clear evidence of a waste management strategy with specific targets, while Company B's context does not address waste management. Example 2: Question: Among [Company A Report, Company B Report], did Company A provide stronger disclosure than Company B on "Does the company report its Scope 1 emissions?" Context: - (Company A) "We focus on employee diversity and inclusion programs." - (Company B) "Our sustainability team oversees environmental compliance." Answer: [NO]. Neither company's context provides evidence about Scope 1 emissions reporting, so the comparison claim is not supported. """ MULTI_ZERO_SHOT_BASE = """You are a regulator/auditor analyzing multi-document climate disclosure questions. Use ONLY the provided context chunks. Do not fabricate facts, numeric values, years, units, page numbers, or source ids. Prioritize quantitative comparison: extract comparable metrics (value + unit + period) only if explicitly stated. If evidence is insufficient, keep fields conservative: null/[]/"insufficient". Set `conclusion` as a direct decision in this exact style: "[YES] " or "[NO] ". Return strict JSON only (no markdown, no prose outside JSON). """ MULTI_OUTPUT_SCHEMA = ( '{' '"dimension":"string",' '"rows":[{"report":"string","year":"string","disclosure_status":"explicit | partial | missing",' '"key_points":["string"],"evidence_chunks":["E1","E4"]}],' '"ranking":[{"rank":1,"report":"string","rationale":"string"}],' '"conclusion":"string ([YES]/[NO] + reason)"' '}' ) MULTI_SKILL_SPECS = { "Comparative Table Builder": { "prompt": ( "Build a regulator-facing comparative table. " "For each report, output disclosure maturity and quantified metrics extracted from context. " "Use `quant_metrics` with objects: {metric, value, unit, period, note}. " "If a metric is unavailable for a report, do not invent it; use null values or omit the metric." ), "schema": ( '{"skill":"Comparative Table Builder","dimension":"string","reports":[{"report":"string","year":"number|null",' '"maturity_level":"high|moderate|low|insufficient","key_evidence":["string"],' '"quant_metrics":[{"metric":"string","value":"number|null","unit":"string|null","period":"string|null","note":"string|null"}]}],' '"comparison_metrics":["string"],"conclusion":"string ([YES]/[NO] + reason)"}' ), }, "Trend & Quant Comparator": { "prompt": ( "Build a quantitative trend comparison across reports. " "For each report, return measurable indicators with fields: value, intensity, attainment_rate, " "change_magnitude, trend_direction. Keep values numeric when present in context, otherwise null. " "Include 1-3 concise key_evidence bullets per report." ), "schema": ( '{"skill":"Trend & Quant Comparator","dimension":"string","reports":[{"report":"string","year":"number|null",' '"key_evidence":["string"],"strength_score":"number|null",' '"quant_metrics":[{"metric":"string","value":"number|null","unit":"string|null","period":"string|null",' '"intensity":"number|null","attainment_rate":"number|null","change_magnitude":"number|null",' '"trend_direction":"up|down|flat|unknown|insufficient","note":"string|null"}]}],' '"metric_highlights":["string"],"conclusion":"string ([YES]/[NO] + reason)"}' ), }, "Target Attainment & Delta Benchmark": { "prompt": ( "Create a target-attainment benchmark. For each report, extract baseline/current/target values where available, " "compute attainment_rate and delta fields (absolute and percent), and infer trend_direction conservatively. " "If any field is unavailable in context, return null/insufficient rather than guessing." ), "schema": ( '{"skill":"Target Attainment & Delta Benchmark","dimension":"string","reports":[{"report":"string","year":"number|null",' '"overall_strength":"high|moderate|low|insufficient","key_evidence":["string"],' '"benchmarks":[{"metric":"string","baseline_value":"number|null","baseline_period":"string|null",' '"current_value":"number|null","current_period":"string|null","target_value":"number|null","target_period":"string|null",' '"attainment_rate":"number|null","delta_abs":"number|null","delta_percent":"number|null","intensity":"number|null",' '"unit":"string|null","trend_direction":"up|down|flat|unknown|insufficient","note":"string|null"}]}],' '"leaderboard":[{"report":"string","score":"number|null","reason":"string"}],"conclusion":"string ([YES]/[NO] + reason)"}' ), }, "Compliance Checklist": { "prompt": ( "Create a compliance checklist comparison. " "Return item-level pass/partial/fail for each report and include quantified indicators where available. " "Also compute summary counts per report and include 1-3 concise key_evidence bullets per report." ), "schema": ( '{"skill":"Compliance Checklist","dimension":"string","required_checks":["string"],"reports":[{"report":"string",' '"key_evidence":["string"],' '"summary":{"pass":"number","partial":"number","fail":"number","completion_rate":"number|null"},' '"checks":[{"item":"string","status":"pass|partial|fail|insufficient","quant_value":"number|null","quant_unit":"string|null","note":"string"}]}],' '"conclusion":"string ([YES]/[NO] + reason)"}' ), }, "Dimension Extractor": { "prompt": ( "Extract disclosure dimensions and quantify coverage by bucket. " "Return bucket-level counts and any explicit quantitative metric snippets, plus 1-3 key_evidence bullets per report." ), "schema": ( '{"skill":"Dimension Extractor","dimension":"string","reports":[{"report":"string","bucket_counts":{"Process":"number","Input":"number","Output":"number","Outcome":"number","Governance":"number","Risk":"number"},' '"key_evidence":["string"],' '"quant_metrics":[{"metric":"string","value":"number|null","unit":"string|null","period":"string|null","note":"string|null"}],' '"coverage_level":"high|moderate|low|insufficient"}],"conclusion":"string ([YES]/[NO] + reason)"}' ), }, "Contradiction/Consistency Check": { "prompt": ( "Evaluate consistency across reports. " "Return rule checks and a quantified consistency score per report pair/group. " "Include a concise key_evidence list that supports the consistency conclusion." ), "schema": ( '{"skill":"Contradiction/Consistency Check","dimension":"string","key_evidence":["string"],' '"checks":[{"rule":"string","result":"consistent|inconsistent|insufficient","note":"string"}],' '"scores":{"consistent":"number","inconsistent":"number","insufficient":"number","consistency_rate":"number|null"},' '"conclusion":"string ([YES]/[NO] + reason)"}' ), }, "Consensus/Count (Portfolio Statistics)": { "prompt": ( "Produce portfolio-level statistics from retrieved context only. " "Count explicit/partial/missing disclosures and include percentage split. " "For each report, include 1-3 key_evidence bullets." ), "schema": ( '{"skill":"Consensus/Count (Portfolio Statistics)","dimension":"string","counts":{"explicit":"number","partial":"number","missing":"number","total":"number"},' '"percentages":{"explicit":"number|null","partial":"number|null","missing":"number|null"},' '"per_report":[{"report":"string","label":"explicit|partial|missing|insufficient","quant_metrics":[{"metric":"string","value":"number|null","unit":"string|null","period":"string|null"}]}],' '"key_evidence_by_report":[{"report":"string","key_evidence":["string"]}],' '"consensus_items":["string"],"outliers":["string"],"conclusion":"string ([YES]/[NO] + reason)"}' ), }, } METRIC_HINT_RULES = [ ("trend", ["value", "intensity", "attainment_rate", "change_magnitude", "delta_percent"]), ("attainment", ["target_value", "current_value", "attainment_rate", "delta_percent"]), ("delta", ["delta_abs", "delta_percent", "baseline_value", "current_value"]), ("baseline", ["baseline_value", "current_value", "target_value", "delta_percent"]), ("kpi", ["value", "unit", "period", "change_magnitude"]), ("risk", ["number_of_risk_categories", "assessment_frequency_per_year", "scenario_count", "horizon_year"]), ("methodolog", ["framework_count", "method_steps_count", "indicator_count", "coverage_percent"]), ("scenario", ["scenario_count", "temperature_pathway_count", "horizon_year"]), ("resilience", ["scenario_count", "stress_test_count", "horizon_year"]), ("impact", ["scope1_emissions", "scope2_emissions", "scope3_emissions", "reduction_percent"]), ("waste", ["waste_reduction_percent", "recycling_rate_percent", "waste_diverted_tons", "target_year"]), ("industry peers", ["initiative_count", "partnership_count", "engagement_frequency_per_year"]), ("downstream", ["supplier_coverage_percent", "partner_count", "assessment_completion_percent"]), ("water", ["water_withdrawal_change_percent", "water_intensity_change_percent", "sites_count", "target_year"]), ] def infer_metric_hints(question: str) -> List[str]: q = str(question or "").lower() hints = [] for key, metrics in METRIC_HINT_RULES: if key in q: for m in metrics: if m not in hints: hints.append(m) if not hints: hints = ["reported_metric_count", "target_year", "percent_change", "absolute_value"] return hints def normalize_multi_skill_name(skill_name: str) -> str: s = str(skill_name or "").strip().lower() if "trend" in s and "quant" in s: return "Trend & Quant Comparator" if "attainment" in s or ("delta" in s and "benchmark" in s): return "Target Attainment & Delta Benchmark" if "compliance" in s and "check" in s: return "Compliance Checklist" if "dimension" in s and "extract" in s: return "Dimension Extractor" if "contradiction" in s or "consistency" in s: return "Contradiction/Consistency Check" if "consensus" in s or "portfolio" in s or "count" in s: return "Consensus/Count (Portfolio Statistics)" return "Comparative Table Builder" def infer_multi_skill_name(question: str) -> str: q = str(question or "").lower() if ("trend" in q and "quant" in q) or "change magnitude" in q or "measurable progress" in q: return "Trend & Quant Comparator" if "attainment rate" in q or "target gap" in q or "baseline" in q or "delta percent" in q: return "Target Attainment & Delta Benchmark" if "checklist" in q or "compliance" in q: return "Compliance Checklist" if "extract" in q and "dimension" in q: return "Dimension Extractor" if "contradiction" in q or "consistent" in q or "inconsistent" in q: return "Contradiction/Consistency Check" if "consensus" in q or "outlier" in q or "portfolio" in q or "count" in q: return "Consensus/Count (Portfolio Statistics)" return "Comparative Table Builder" def get_multi_skill_spec(skill_name: str) -> Dict[str, str]: name = normalize_multi_skill_name(skill_name) return { "skill_name": name, "skill_prompt": MULTI_SKILL_SPECS[name]["prompt"], "output_json_schema": MULTI_OUTPUT_SCHEMA, } def build_yes_no_prompt(question: str, contexts: List[str], doc_mode: str) -> str: few_shot = FEW_SHOT_EXAMPLES_CROSS if doc_mode == "cross" else FEW_SHOT_EXAMPLES_SINGLE ctx_text = "\n".join(f'- "{c}"' for c in contexts) return f"""You are an expert analyst evaluating corporate sustainability reports. Based ONLY on the provided context passages, answer the question with [YES] or [NO], followed by a brief reasoning. Format your answer as: [YES]. OR [NO]. {few_shot.strip()} Now answer the following: Question: {question} Context: {ctx_text} Answer:""" def build_multi_zero_shot_prompt( question: str, contexts: List[str], skill_name: str = "", skill_prompt: str = "", output_json_schema: str = "", retrieval_query: str = "", metric_hints: List[str] | None = None, ) -> str: ctx_text = "\n".join(f'[E{i}] "{c}"' for i, c in enumerate(contexts, start=1)) skill_name = str(skill_name or "").strip() skill_prompt = str(skill_prompt or "").strip() output_json_schema = str(output_json_schema or "").strip() retrieval_query = str(retrieval_query or "").strip() metric_hints = metric_hints or infer_metric_hints(question) metric_hint_text = ", ".join(metric_hints) return f"""{MULTI_ZERO_SHOT_BASE} Matched skill: {skill_name} Skill guidance: {skill_prompt} Retrieval query dimension: {retrieval_query} Suggested comparable metrics: {metric_hint_text} Target JSON schema: {output_json_schema} Question: {question} Context: {ctx_text} JSON Answer:""" # ======================== Ground Truth ======================== def extract_yes_no(text: str) -> str | None: text_upper = str(text).strip().upper() m = re.match(r"\[*\s*(YES|NO)\s*\]*", text_upper) if m: return m.group(1) if text_upper.startswith("YES"): return "YES" if text_upper.startswith("NO"): return "NO" return None def load_ground_truth(doc_mode: str) -> Dict[Any, str]: if doc_mode == "single": df = pd.read_excel(BASE_XLSX_PATH, index_col=0) gt = {} for _, row in df.iterrows(): report = row["Document"] question = row["Question"] label = extract_yes_no(row["Answer"]) if label is not None: gt[(report, question)] = label return gt if doc_mode == "cross": df = pd.read_excel(CROSS_XLSX_PATH) if "Unnamed: 0" in df.columns: df = df.drop(columns=["Unnamed: 0"]) gt = {} for _, row in df.iterrows(): question = row["Question"] label = extract_yes_no(row["Answer"]) if label is not None and question not in gt: gt[question] = label return gt # multi has no closed-form YES/NO ground truth return {} # ======================== Retrieval Result Loading ======================== def get_result_dir(chunk_mode: str, doc_mode: str) -> str: parts = [RESULT_DIR_BASE] if chunk_mode != "length": parts.append(chunk_mode) if doc_mode != "single": parts.append(doc_mode) return "_".join(parts) def load_retrieval_contexts(chunk_mode: str, doc_mode: str, top_k: int) -> Dict[Any, Any]: result_dir = get_result_dir(chunk_mode, doc_mode) result_file = os.path.join(result_dir, f"Qwen3-Embedding-0.6B__{chunk_mode}__{doc_mode}__{top_k}.csv") if not os.path.exists(result_file): raise FileNotFoundError(f"Retrieval result file not found: {result_file}") df = pd.read_csv(result_file) retrieved = df[df["question_retrieved"] == 1].copy() retrieved = retrieved.sort_values("similarity_score", ascending=False) if doc_mode == "single": contexts = {} for (report, question), group in retrieved.groupby(["report", "question"], sort=False): contexts[(report, question)] = group["chunk_text"].tolist() return contexts contexts = {} for question, group in retrieved.groupby("question", sort=False): chunks = [] for _, row in group.iterrows(): chunks.append(f"({row['report'].replace('.pdf', '')}, chunk={row['chunk_idx']}) {row['chunk_text']}") payload = {"chunks": chunks} for col in ["skill_name", "skill_prompt", "output_json_schema", "retrieval_query"]: if col in group.columns: non_null = group[col].dropna() payload[col] = non_null.iloc[0] if len(non_null) > 0 else "" else: payload[col] = "" contexts[question] = payload return contexts # ======================== Main ======================== def main(): parser = argparse.ArgumentParser(description="RAG generation for ClimRetrieve") parser.add_argument("--chunk", type=str, default="length", choices=["length", "structure"]) parser.add_argument("--doc", type=str, default="single", choices=["single", "cross", "multi"]) parser.add_argument("--top_k", type=int, default=RETRIEVAL_TOP_K) parser.add_argument("--model", type=str, default=GEN_MODEL_PATH) parser.add_argument("--max_tokens", type=int, default=MAX_NEW_TOKENS) parser.add_argument("--temperature", type=float, default=TEMPERATURE) parser.add_argument("--limit", type=int, default=None, help="Only generate for first N questions") args = parser.parse_args() chunk_mode = args.chunk doc_mode = args.doc top_k = args.top_k limit = args.limit print("=" * 70) print("RAG Generation") print(f" CHUNK={chunk_mode}, DOC={doc_mode}, TOP_K={top_k}, LIMIT={limit}") print(f" Model: {args.model}") print("=" * 70) print("\n[Step 1] Load retrieval contexts") contexts = load_retrieval_contexts(chunk_mode, doc_mode, top_k) print(f" Context groups: {len(contexts)}") if limit is not None and limit > 0: items = list(contexts.items())[:limit] contexts = dict(items) print(f" Applied limit -> {len(contexts)} groups") gt = {} if doc_mode in ("single", "cross"): print("\n[Step 2] Load ground truth") gt = load_ground_truth(doc_mode) print(f" Ground truth count: {len(gt)}") else: print("\n[Step 2] Multi mode: skip YES/NO ground truth") print("\n[Step 3] Build prompts") prompts = [] prompt_keys = [] if doc_mode == "single": for (report, question), ctx_list in contexts.items(): prompts.append(build_yes_no_prompt(question, ctx_list[:top_k], doc_mode)) prompt_keys.append({"report": report, "question": question}) elif doc_mode == "cross": for question, payload in contexts.items(): ctx_list = payload["chunks"] if isinstance(payload, dict) else payload prompts.append(build_yes_no_prompt(question, ctx_list[:top_k], doc_mode)) prompt_keys.append({"question": question}) else: for question, payload in contexts.items(): ctx_list = payload.get("chunks", []) inferred_skill = payload.get("skill_name", "") or infer_multi_skill_name(question) spec = get_multi_skill_spec(inferred_skill) skill_name = spec["skill_name"] skill_prompt = payload.get("skill_prompt", "").strip() or spec["skill_prompt"] output_json_schema = spec["output_json_schema"] retrieval_query = payload.get("retrieval_query", "").strip() or question prompts.append( build_multi_zero_shot_prompt( question=question, contexts=ctx_list[:top_k], skill_name=skill_name, skill_prompt=skill_prompt, output_json_schema=output_json_schema, retrieval_query=retrieval_query, metric_hints=infer_metric_hints(question), ) ) prompt_keys.append( { "question": question, "skill_name": skill_name, "retrieval_query": retrieval_query, } ) print(f" Prompt count: {len(prompts)}") if prompts: print(f" Prompt preview: {prompts[0][:400]}...") print("\n[Step 4] Run generation") from vllm import LLM, SamplingParams llm = LLM(model=args.model, max_model_len=8192, dtype="auto", trust_remote_code=True) sampling_params = SamplingParams( temperature=args.temperature, max_tokens=args.max_tokens, top_p=1.0, ) outputs = llm.generate(prompts, sampling_params) print("\n[Step 5] Collect results") results = [] correct = 0 total_eval = 0 for i, output in enumerate(outputs): generated_text = output.outputs[0].text.strip() key = prompt_keys[i] if doc_mode == "multi": results.append( { **key, "generated_answer": generated_text, } ) continue predicted_label = extract_yes_no(generated_text) if doc_mode == "single": gt_label = gt.get((key["report"], key["question"])) else: gt_label = gt.get(key["question"]) is_correct = None if predicted_label is not None and gt_label is not None: is_correct = 1 if predicted_label == gt_label else 0 correct += is_correct total_eval += 1 results.append( { **key, "generated_answer": generated_text[:800], "predicted_label": predicted_label, "gt_label": gt_label, "is_correct": is_correct, } ) print("\n[Step 6] Save outputs") output_dir = get_result_dir(chunk_mode, doc_mode) os.makedirs(output_dir, exist_ok=True) limit_tag = f"__limit{limit}" if (limit is not None and limit > 0) else "" output_file = os.path.join(output_dir, f"generation__{chunk_mode}__{doc_mode}__top{top_k}{limit_tag}.csv") pd.DataFrame(results).to_csv(output_file, index=False, encoding="utf-8-sig") print(f" Result file: {output_file}") if doc_mode == "multi": metrics = { "chunk_mode": chunk_mode, "doc_mode": doc_mode, "top_k": top_k, "limit": limit, "generated_count": len(results), } else: accuracy = correct / total_eval if total_eval > 0 else 0.0 metrics = { "chunk_mode": chunk_mode, "doc_mode": doc_mode, "top_k": top_k, "limit": limit, "total_eval": total_eval, "correct": correct, "accuracy": accuracy, "unparsed": sum(1 for r in results if r.get("predicted_label") is None), "no_gt": sum(1 for r in results if r.get("gt_label") is None), } print(f" Accuracy: {accuracy:.4f} ({correct}/{total_eval})") metrics_file = os.path.join(output_dir, f"generation_metrics__{chunk_mode}__{doc_mode}__top{top_k}{limit_tag}.json") with open(metrics_file, "w", encoding="utf-8") as f: json.dump(metrics, f, indent=2) print(f" Metrics file: {metrics_file}") print("\n" + "=" * 70) print("Generation completed.") print("=" * 70) if __name__ == "__main__": main()