# ClimateRAG_QA / Experiments /generation.py
# tengfeiCheng's picture
# add cleaned experiments code
# 12323e1
"""RAG Generation for single / cross / multi document settings."""
from __future__ import annotations
import os
import re
import json
import argparse
from typing import Dict, List, Tuple, Any
import pandas as pd
# ======================== Config ========================
# Default value of the --model CLI option (passed to the vLLM LLM constructor in main).
GEN_MODEL_PATH = "../model/Qwen3-4B-Instruct-2507-FP8"
# Workbook read by load_ground_truth for doc mode "single".
BASE_XLSX_PATH = "../Expert-Annotated Relevant Sources Dataset/ClimRetrieve_base.xlsx"
# Workbook read by load_ground_truth for doc mode "cross".
CROSS_XLSX_PATH = "../Expert-Annotated Relevant Sources Dataset/ClimRetrieve_cross.xlsx"
# Base directory name; get_result_dir appends the chunk/doc mode suffixes to it.
RESULT_DIR_BASE = "./Embedding_Search_Results_Qwen"
# Default value of the --top_k CLI option.
RETRIEVAL_TOP_K = 5
# Default value of the --max_tokens CLI option (SamplingParams.max_tokens in main).
MAX_NEW_TOKENS = 768
# Default value of the --temperature CLI option.
TEMPERATURE = 0.0
# ======================== Prompt Templates ========================
# Few-shot examples that build_yes_no_prompt inserts into the YES/NO prompt
# whenever doc_mode is not "cross". The text is sent to the model verbatim
# (after .strip()), so any edit here changes the prompt.
FEW_SHOT_EXAMPLES_SINGLE = """Example 1:
Question: Does the company have a strategy on waste management?
Context:
- "To meet our commitment to being zero waste by 2030, we are reducing our waste footprint through reuse, recycling and recovery."
- "We increased our reuse and recycle rates of all cloud hardware to 82 percent."
Answer: [YES]. The company has a clear strategy on waste management focused on reduction, recycling, and recovery to meet its 2030 zero waste commitment.
Example 2:
Question: Does the company report the climate change scenarios used to test the resilience of its business strategy?
Context:
- "Our ESG ratings improved from 2021 to 2022."
- "We continue to focus on sustainable sourcing across our supply chain."
Answer: [NO]. The provided context does not contain information about climate change scenarios used to test business strategy resilience.
"""
# Few-shot examples that build_yes_no_prompt inserts into the YES/NO prompt
# when doc_mode == "cross". Each context line carries a "(Company X)" prefix
# and the questions span several reports.
FEW_SHOT_EXAMPLES_CROSS = """Example 1:
Question: Among [Company A Report, Company B Report], did at least one company address "Does the company have a strategy on waste management?"
Context:
- (Company A) "We are committed to reducing waste by 50% by 2030 through our circular economy approach."
- (Company B) "Our revenue grew by 15% year-over-year driven by strong demand."
Answer: [YES]. Company A provides clear evidence of a waste management strategy with specific targets, while Company B's context does not address waste management.
Example 2:
Question: Among [Company A Report, Company B Report], did Company A provide stronger disclosure than Company B on "Does the company report its Scope 1 emissions?"
Context:
- (Company A) "We focus on employee diversity and inclusion programs."
- (Company B) "Our sustainability team oversees environmental compliance."
Answer: [NO]. Neither company's context provides evidence about Scope 1 emissions reporting, so the comparison claim is not supported.
"""
# Instruction preamble that build_multi_zero_shot_prompt places at the very top
# of the multi-document prompt. The literal ends with a newline, so the
# assembled prompt has an empty line between this preamble and the
# "Matched skill:" line that follows it.
MULTI_ZERO_SHOT_BASE = """You are a regulator/auditor analyzing multi-document climate disclosure questions.
Use ONLY the provided context chunks.
Do not fabricate facts, numeric values, years, units, page numbers, or source ids.
Prioritize quantitative comparison: extract comparable metrics (value + unit + period) only if explicitly stated.
If evidence is insufficient, keep fields conservative: null/[]/"insufficient".
Set `conclusion` as a direct decision in this exact style: "[YES] <reason>" or "[NO] <reason>".
Return strict JSON only (no markdown, no prose outside JSON).
"""
# JSON skeleton (one string built by implicit literal concatenation) that
# get_multi_skill_spec returns as "output_json_schema" for every skill.
# The "E1"/"E4" evidence ids follow the [E<n>] labels that
# build_multi_zero_shot_prompt puts in front of each context chunk.
MULTI_OUTPUT_SCHEMA = (
'{'
'"dimension":"string",'
'"rows":[{"report":"string","year":"string","disclosure_status":"explicit | partial | missing",'
'"key_points":["string"],"evidence_chunks":["E1","E4"]}],'
'"ranking":[{"rank":1,"report":"string","rationale":"string"}],'
'"conclusion":"string ([YES]/[NO] + reason)"'
'}'
)
# Registry of multi-document "skills", keyed by canonical skill name (the same
# names returned by normalize_multi_skill_name / infer_multi_skill_name).
# Each entry holds:
#   "prompt" - skill guidance text; get_multi_skill_spec returns it as "skill_prompt".
#   "schema" - a skill-specific JSON skeleton.
# NOTE(review): within this file only "prompt" is read; get_multi_skill_spec
# returns MULTI_OUTPUT_SCHEMA instead of the per-skill "schema". Confirm whether
# the "schema" entries are consumed elsewhere or are leftovers.
MULTI_SKILL_SPECS = {
# Default skill: normalize/infer fall back to this name when nothing else matches.
"Comparative Table Builder": {
"prompt": (
"Build a regulator-facing comparative table. "
"For each report, output disclosure maturity and quantified metrics extracted from context. "
"Use `quant_metrics` with objects: {metric, value, unit, period, note}. "
"If a metric is unavailable for a report, do not invent it; use null values or omit the metric."
),
"schema": (
'{"skill":"Comparative Table Builder","dimension":"string","reports":[{"report":"string","year":"number|null",'
'"maturity_level":"high|moderate|low|insufficient","key_evidence":["string"],'
'"quant_metrics":[{"metric":"string","value":"number|null","unit":"string|null","period":"string|null","note":"string|null"}]}],'
'"comparison_metrics":["string"],"conclusion":"string ([YES]/[NO] + reason)"}'
),
},
# Quantitative trend comparison across reports.
"Trend & Quant Comparator": {
"prompt": (
"Build a quantitative trend comparison across reports. "
"For each report, return measurable indicators with fields: value, intensity, attainment_rate, "
"change_magnitude, trend_direction. Keep values numeric when present in context, otherwise null. "
"Include 1-3 concise key_evidence bullets per report."
),
"schema": (
'{"skill":"Trend & Quant Comparator","dimension":"string","reports":[{"report":"string","year":"number|null",'
'"key_evidence":["string"],"strength_score":"number|null",'
'"quant_metrics":[{"metric":"string","value":"number|null","unit":"string|null","period":"string|null",'
'"intensity":"number|null","attainment_rate":"number|null","change_magnitude":"number|null",'
'"trend_direction":"up|down|flat|unknown|insufficient","note":"string|null"}]}],'
'"metric_highlights":["string"],"conclusion":"string ([YES]/[NO] + reason)"}'
),
},
# Baseline/current/target benchmark with attainment and delta fields.
"Target Attainment & Delta Benchmark": {
"prompt": (
"Create a target-attainment benchmark. For each report, extract baseline/current/target values where available, "
"compute attainment_rate and delta fields (absolute and percent), and infer trend_direction conservatively. "
"If any field is unavailable in context, return null/insufficient rather than guessing."
),
"schema": (
'{"skill":"Target Attainment & Delta Benchmark","dimension":"string","reports":[{"report":"string","year":"number|null",'
'"overall_strength":"high|moderate|low|insufficient","key_evidence":["string"],'
'"benchmarks":[{"metric":"string","baseline_value":"number|null","baseline_period":"string|null",'
'"current_value":"number|null","current_period":"string|null","target_value":"number|null","target_period":"string|null",'
'"attainment_rate":"number|null","delta_abs":"number|null","delta_percent":"number|null","intensity":"number|null",'
'"unit":"string|null","trend_direction":"up|down|flat|unknown|insufficient","note":"string|null"}]}],'
'"leaderboard":[{"report":"string","score":"number|null","reason":"string"}],"conclusion":"string ([YES]/[NO] + reason)"}'
),
},
# Item-level pass/partial/fail checklist per report.
"Compliance Checklist": {
"prompt": (
"Create a compliance checklist comparison. "
"Return item-level pass/partial/fail for each report and include quantified indicators where available. "
"Also compute summary counts per report and include 1-3 concise key_evidence bullets per report."
),
"schema": (
'{"skill":"Compliance Checklist","dimension":"string","required_checks":["string"],"reports":[{"report":"string",'
'"key_evidence":["string"],'
'"summary":{"pass":"number","partial":"number","fail":"number","completion_rate":"number|null"},'
'"checks":[{"item":"string","status":"pass|partial|fail|insufficient","quant_value":"number|null","quant_unit":"string|null","note":"string"}]}],'
'"conclusion":"string ([YES]/[NO] + reason)"}'
),
},
# Coverage counts per disclosure bucket.
"Dimension Extractor": {
"prompt": (
"Extract disclosure dimensions and quantify coverage by bucket. "
"Return bucket-level counts and any explicit quantitative metric snippets, plus 1-3 key_evidence bullets per report."
),
"schema": (
'{"skill":"Dimension Extractor","dimension":"string","reports":[{"report":"string","bucket_counts":{"Process":"number","Input":"number","Output":"number","Outcome":"number","Governance":"number","Risk":"number"},'
'"key_evidence":["string"],'
'"quant_metrics":[{"metric":"string","value":"number|null","unit":"string|null","period":"string|null","note":"string|null"}],'
'"coverage_level":"high|moderate|low|insufficient"}],"conclusion":"string ([YES]/[NO] + reason)"}'
),
},
# Cross-report consistency rules and scores.
"Contradiction/Consistency Check": {
"prompt": (
"Evaluate consistency across reports. "
"Return rule checks and a quantified consistency score per report pair/group. "
"Include a concise key_evidence list that supports the consistency conclusion."
),
"schema": (
'{"skill":"Contradiction/Consistency Check","dimension":"string","key_evidence":["string"],'
'"checks":[{"rule":"string","result":"consistent|inconsistent|insufficient","note":"string"}],'
'"scores":{"consistent":"number","inconsistent":"number","insufficient":"number","consistency_rate":"number|null"},'
'"conclusion":"string ([YES]/[NO] + reason)"}'
),
},
# Portfolio-level counts and percentage split of disclosure labels.
"Consensus/Count (Portfolio Statistics)": {
"prompt": (
"Produce portfolio-level statistics from retrieved context only. "
"Count explicit/partial/missing disclosures and include percentage split. "
"For each report, include 1-3 key_evidence bullets."
),
"schema": (
'{"skill":"Consensus/Count (Portfolio Statistics)","dimension":"string","counts":{"explicit":"number","partial":"number","missing":"number","total":"number"},'
'"percentages":{"explicit":"number|null","partial":"number|null","missing":"number|null"},'
'"per_report":[{"report":"string","label":"explicit|partial|missing|insufficient","quant_metrics":[{"metric":"string","value":"number|null","unit":"string|null","period":"string|null"}]}],'
'"key_evidence_by_report":[{"report":"string","key_evidence":["string"]}],'
'"consensus_items":["string"],"outliers":["string"],"conclusion":"string ([YES]/[NO] + reason)"}'
),
},
}
# (trigger, metric names) pairs consumed by infer_metric_hints.
# A rule fires when its trigger occurs as a substring of the lowercased
# question; the metrics of all firing rules are merged in list order.
# "methodolog" is presumably a deliberate stem so that both "methodology" and
# "methodological" match - confirm before "fixing" the spelling.
METRIC_HINT_RULES = [
("trend", ["value", "intensity", "attainment_rate", "change_magnitude", "delta_percent"]),
("attainment", ["target_value", "current_value", "attainment_rate", "delta_percent"]),
("delta", ["delta_abs", "delta_percent", "baseline_value", "current_value"]),
("baseline", ["baseline_value", "current_value", "target_value", "delta_percent"]),
("kpi", ["value", "unit", "period", "change_magnitude"]),
("risk", ["number_of_risk_categories", "assessment_frequency_per_year", "scenario_count", "horizon_year"]),
("methodolog", ["framework_count", "method_steps_count", "indicator_count", "coverage_percent"]),
("scenario", ["scenario_count", "temperature_pathway_count", "horizon_year"]),
("resilience", ["scenario_count", "stress_test_count", "horizon_year"]),
("impact", ["scope1_emissions", "scope2_emissions", "scope3_emissions", "reduction_percent"]),
("waste", ["waste_reduction_percent", "recycling_rate_percent", "waste_diverted_tons", "target_year"]),
("industry peers", ["initiative_count", "partnership_count", "engagement_frequency_per_year"]),
("downstream", ["supplier_coverage_percent", "partner_count", "assessment_completion_percent"]),
("water", ["water_withdrawal_change_percent", "water_intensity_change_percent", "sites_count", "target_year"]),
]
def infer_metric_hints(question: str) -> List[str]:
    """Suggest comparable metric names for *question*.

    The question is lowercased and every rule in ``METRIC_HINT_RULES`` whose
    trigger occurs in it as a substring contributes its metrics. Metrics are
    returned without duplicates, in order of first appearance. When no rule
    fires, a fixed generic list is returned.

    Args:
        question: Question text; ``None`` or empty is treated as "".

    Returns:
        A non-empty list of metric name strings.
    """
    lowered = str(question or "").lower()

    # Dict keys keep first-insertion order, which gives an ordered de-duplication.
    collected: Dict[str, None] = {}
    for trigger, metric_names in METRIC_HINT_RULES:
        if trigger in lowered:
            collected.update(dict.fromkeys(metric_names))

    if collected:
        return list(collected)
    return ["reported_metric_count", "target_year", "percent_change", "absolute_value"]
def normalize_multi_skill_name(skill_name: str) -> str:
    """Map a free-form skill label onto one of the canonical skill names.

    Matching is done on the stripped, lowercased label using substring
    keywords. Rules are tried in a fixed order and the first match wins;
    anything unmatched (including ``None``/empty) maps to
    "Comparative Table Builder".

    Args:
        skill_name: Skill label, possibly with different casing or wording.

    Returns:
        The canonical skill name.
    """
    label = str(skill_name or "").strip().lower()

    # (canonical name, alternatives); an alternative matches when ALL of its
    # keywords occur in the label, and a rule matches when ANY alternative does.
    rules = (
        ("Trend & Quant Comparator", (("trend", "quant"),)),
        ("Target Attainment & Delta Benchmark", (("attainment",), ("delta", "benchmark"))),
        ("Compliance Checklist", (("compliance", "check"),)),
        ("Dimension Extractor", (("dimension", "extract"),)),
        ("Contradiction/Consistency Check", (("contradiction",), ("consistency",))),
        ("Consensus/Count (Portfolio Statistics)", (("consensus",), ("portfolio",), ("count",))),
    )
    for canonical, alternatives in rules:
        if any(all(word in label for word in group) for group in alternatives):
            return canonical
    return "Comparative Table Builder"
def infer_multi_skill_name(question: str) -> str:
    """Guess the canonical multi-document skill from the question wording.

    The question is lowercased and checked against keyword rules in a fixed
    order; the first rule that matches decides the skill. Questions that match
    nothing (including ``None``/empty) map to "Comparative Table Builder".

    Args:
        question: Natural-language question text.

    Returns:
        One of the canonical skill names used as keys of ``MULTI_SKILL_SPECS``.
    """
    q = str(question or "").lower()

    def mentions(*phrases: str) -> bool:
        # True when any phrase occurs as a substring of the question.
        return any(phrase in q for phrase in phrases)

    if ("trend" in q and "quant" in q) or mentions("change magnitude", "measurable progress"):
        return "Trend & Quant Comparator"
    if mentions("attainment rate", "target gap", "baseline", "delta percent"):
        return "Target Attainment & Delta Benchmark"
    if mentions("checklist", "compliance"):
        return "Compliance Checklist"
    if "extract" in q and "dimension" in q:
        return "Dimension Extractor"
    # "consistent" also covers "inconsistent" (substring match).
    if mentions("contradiction", "consistent"):
        return "Contradiction/Consistency Check"
    # "count" must be a whole word: a bare substring test also fires on
    # "account", "country", "discount", ... and misroutes those questions.
    if mentions("consensus", "outlier", "portfolio") or re.search(r"\bcount(?:s|ed|ing)?\b", q):
        return "Consensus/Count (Portfolio Statistics)"
    return "Comparative Table Builder"
def get_multi_skill_spec(skill_name: str) -> Dict[str, str]:
    """Return the prompt pieces for a (possibly non-canonical) skill name.

    The name is first canonicalised with ``normalize_multi_skill_name``.

    Args:
        skill_name: Skill label in any wording accepted by the normaliser.

    Returns:
        A dict with keys ``skill_name`` (canonical name), ``skill_prompt``
        (the skill's "prompt" text from ``MULTI_SKILL_SPECS``) and
        ``output_json_schema``.
    """
    canonical = normalize_multi_skill_name(skill_name)
    spec: Dict[str, str] = {"skill_name": canonical}
    spec["skill_prompt"] = MULTI_SKILL_SPECS[canonical]["prompt"]
    # NOTE(review): every skill gets the shared MULTI_OUTPUT_SCHEMA; the
    # per-skill "schema" entry in MULTI_SKILL_SPECS is not used here.
    spec["output_json_schema"] = MULTI_OUTPUT_SCHEMA
    return spec
def build_yes_no_prompt(question: str, contexts: List[str], doc_mode: str) -> str:
    """Assemble the few-shot [YES]/[NO] prompt for one question.

    Args:
        question: The question to answer.
        contexts: Retrieved passages; each is rendered as ``- "<passage>"``.
        doc_mode: "cross" selects the cross-document few-shot examples; any
            other value selects the single-document ones.

    Returns:
        The full prompt text, ending with "Answer:".
    """
    examples = FEW_SHOT_EXAMPLES_SINGLE
    if doc_mode == "cross":
        examples = FEW_SHOT_EXAMPLES_CROSS

    quoted_passages = [f'- "{passage}"' for passage in contexts]

    # One entry per prompt section; sections are separated by a single newline.
    sections = [
        "You are an expert analyst evaluating corporate sustainability reports.",
        "Based ONLY on the provided context passages, answer the question with [YES] or [NO], followed by a brief reasoning.",
        "Format your answer as:",
        "[YES]. <reasoning> OR [NO]. <reasoning>",
        examples.strip(),
        "Now answer the following:",
        f"Question: {question}",
        "Context:",
        "\n".join(quoted_passages),
        "Answer:",
    ]
    return "\n".join(sections)
def build_multi_zero_shot_prompt(
    question: str,
    contexts: List[str],
    skill_name: str = "",
    skill_prompt: str = "",
    output_json_schema: str = "",
    retrieval_query: str = "",
    metric_hints: List[str] | None = None,
) -> str:
    """Assemble the zero-shot JSON prompt for a multi-document question.

    Args:
        question: The question to answer.
        contexts: Retrieved chunks; rendered as ``[E<n>] "<chunk>"`` with n
            starting at 1.
        skill_name: Name of the matched skill (stripped; falsy becomes "").
        skill_prompt: Skill guidance text (stripped; falsy becomes "").
        output_json_schema: Target JSON schema text (stripped; falsy becomes "").
        retrieval_query: Retrieval query text (stripped; falsy becomes "").
        metric_hints: Metric names to suggest; when falsy they are derived
            from the question with ``infer_metric_hints``.

    Returns:
        The full prompt text, ending with "JSON Answer:".
    """

    def _clean(value: Any) -> str:
        # Falsy values (None, "") collapse to the empty string.
        return str(value or "").strip()

    evidence_lines = [f'[E{idx}] "{chunk}"' for idx, chunk in enumerate(contexts, start=1)]
    hints = metric_hints or infer_metric_hints(question)

    # One entry per prompt section; sections are separated by a single newline.
    sections = [
        MULTI_ZERO_SHOT_BASE,
        f"Matched skill: {_clean(skill_name)}",
        f"Skill guidance: {_clean(skill_prompt)}",
        f"Retrieval query dimension: {_clean(retrieval_query)}",
        f"Suggested comparable metrics: {', '.join(hints)}",
        f"Target JSON schema: {_clean(output_json_schema)}",
        "Question:",
        f"{question}",
        "Context:",
        "\n".join(evidence_lines),
        "JSON Answer:",
    ]
    return "\n".join(sections)
# ======================== Ground Truth ========================
def extract_yes_no(text: str) -> str | None:
    """Parse a leading YES/NO verdict from *text*.

    The verdict must be the first token of the text, optionally preceded by
    whitespace, "[", "(" or markdown "*" characters, and must be a whole word.
    Matching is case-insensitive.

    A word boundary is required after the verdict so that answers such as
    "Not enough information", "Nothing ..." or "Yesterday ..." are not
    mistaken for NO / YES. ``None`` is rejected explicitly because
    ``str(None)`` is "None", which would otherwise start with "NO".

    Args:
        text: Model output or annotated answer; non-string values are
            converted with ``str()``.

    Returns:
        "YES" or "NO", or ``None`` when no leading verdict is found.
    """
    if text is None:
        return None
    candidate = str(text).strip().upper()
    verdict = re.match(r"[\s\[(*]*(YES|NO)\b", candidate)
    return verdict.group(1) if verdict else None
def load_ground_truth(doc_mode: str) -> Dict[Any, str]:
    """Load YES/NO ground-truth labels for the given document mode.

    Args:
        doc_mode: "single" reads ``BASE_XLSX_PATH`` and keys labels by
            ``(Document, Question)``; "cross" reads ``CROSS_XLSX_PATH`` and
            keys labels by ``Question``. Any other mode has no closed-form
            YES/NO ground truth.

    Returns:
        Mapping from key to "YES"/"NO". Rows whose "Answer" cannot be parsed
        by ``extract_yes_no`` are skipped. For "single" a later row with the
        same key overwrites an earlier one; for "cross" the first row wins.
        Empty for any other mode.
    """
    if doc_mode not in ("single", "cross"):
        return {}

    is_single = doc_mode == "single"
    if is_single:
        frame = pd.read_excel(BASE_XLSX_PATH, index_col=0)
    else:
        frame = pd.read_excel(CROSS_XLSX_PATH)
        # A saved index column shows up as "Unnamed: 0"; it is not data.
        if "Unnamed: 0" in frame.columns:
            frame = frame.drop(columns=["Unnamed: 0"])

    labels: Dict[Any, str] = {}
    for _, record in frame.iterrows():
        verdict = extract_yes_no(record["Answer"])
        if verdict is None:
            continue
        if is_single:
            labels[(record["Document"], record["Question"])] = verdict
        else:
            # Keep the first label seen for a repeated question.
            labels.setdefault(record["Question"], verdict)
    return labels
# ======================== Retrieval Result Loading ========================
def get_result_dir(chunk_mode: str, doc_mode: str) -> str:
    """Build the results directory name for a chunk/doc mode combination.

    Starts from ``RESULT_DIR_BASE`` and appends ``_<chunk_mode>`` unless the
    chunk mode is "length", then ``_<doc_mode>`` unless the doc mode is
    "single".

    Args:
        chunk_mode: Chunking mode name.
        doc_mode: Document mode name.

    Returns:
        The directory path as a string.
    """
    # Each mode is paired with the value that contributes no suffix.
    mode_suffixes = [
        mode
        for mode, unsuffixed in ((chunk_mode, "length"), (doc_mode, "single"))
        if mode != unsuffixed
    ]
    return "_".join([RESULT_DIR_BASE, *mode_suffixes])
def load_retrieval_contexts(chunk_mode: str, doc_mode: str, top_k: int) -> Dict[Any, Any]:
    """Load retrieved chunks from the retrieval result CSV.

    Only rows with ``question_retrieved == 1`` are kept, ordered by
    ``similarity_score`` descending. ``top_k`` is used solely to locate the
    file; the number of chunks per group is not truncated here.

    Args:
        chunk_mode: Chunking mode; part of the directory and file name.
        doc_mode: "single" groups by ``(report, question)``; any other mode
            groups by ``question``.
        top_k: Top-k value embedded in the result file name.

    Returns:
        For "single": ``{(report, question): [chunk_text, ...]}``.
        Otherwise: ``{question: payload}`` where payload has ``"chunks"``
        (strings of the form ``(<report without .pdf>, chunk=<idx>) <text>``)
        plus ``skill_name``, ``skill_prompt``, ``output_json_schema`` and
        ``retrieval_query`` (first non-null value in the group, or "").

    Raises:
        FileNotFoundError: If the expected result file does not exist.
    """
    csv_path = os.path.join(
        get_result_dir(chunk_mode, doc_mode),
        f"Qwen3-Embedding-0.6B__{chunk_mode}__{doc_mode}__{top_k}.csv",
    )
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"Retrieval result file not found: {csv_path}")

    table = pd.read_csv(csv_path)
    hits = table.loc[table["question_retrieved"] == 1].sort_values(
        "similarity_score", ascending=False
    )

    if doc_mode == "single":
        # sort=False keeps groups in order of first appearance (best score first).
        return {
            pair: rows["chunk_text"].tolist()
            for pair, rows in hits.groupby(["report", "question"], sort=False)
        }

    meta_columns = ("skill_name", "skill_prompt", "output_json_schema", "retrieval_query")
    by_question: Dict[Any, Any] = {}
    for question, rows in hits.groupby("question", sort=False):
        entry: Dict[str, Any] = {
            "chunks": [
                f"({rec['report'].replace('.pdf', '')}, chunk={rec['chunk_idx']}) {rec['chunk_text']}"
                for _, rec in rows.iterrows()
            ]
        }
        for column in meta_columns:
            value: Any = ""
            if column in rows.columns:
                present = rows[column].dropna()
                if len(present) > 0:
                    value = present.iloc[0]
            entry[column] = value
        by_question[question] = entry
    return by_question
# ======================== Main ========================
def _build_prompts(
    contexts: Dict[Any, Any], doc_mode: str, top_k: int
) -> Tuple[List[str], List[Dict[str, Any]]]:
    """Turn retrieval contexts into prompts plus the key identifying each prompt."""
    prompts: List[str] = []
    prompt_keys: List[Dict[str, Any]] = []

    if doc_mode == "single":
        for (report, question), ctx_list in contexts.items():
            prompts.append(build_yes_no_prompt(question, ctx_list[:top_k], doc_mode))
            prompt_keys.append({"report": report, "question": question})
        return prompts, prompt_keys

    if doc_mode == "cross":
        for question, payload in contexts.items():
            ctx_list = payload["chunks"] if isinstance(payload, dict) else payload
            prompts.append(build_yes_no_prompt(question, ctx_list[:top_k], doc_mode))
            prompt_keys.append({"question": question})
        return prompts, prompt_keys

    for question, payload in contexts.items():
        ctx_list = payload.get("chunks", [])
        # Payload fields come from a CSV; coerce them the same way
        # build_multi_zero_shot_prompt does so a non-string value cannot
        # break .strip().
        stored_skill = str(payload.get("skill_name") or "").strip()
        spec = get_multi_skill_spec(stored_skill or infer_multi_skill_name(question))
        skill_name = spec["skill_name"]
        skill_prompt = str(payload.get("skill_prompt") or "").strip() or spec["skill_prompt"]
        # The schema always comes from the spec; a payload-provided
        # "output_json_schema" is not consulted.
        output_json_schema = spec["output_json_schema"]
        retrieval_query = str(payload.get("retrieval_query") or "").strip() or question
        prompts.append(
            build_multi_zero_shot_prompt(
                question=question,
                contexts=ctx_list[:top_k],
                skill_name=skill_name,
                skill_prompt=skill_prompt,
                output_json_schema=output_json_schema,
                retrieval_query=retrieval_query,
                metric_hints=infer_metric_hints(question),
            )
        )
        prompt_keys.append(
            {
                "question": question,
                "skill_name": skill_name,
                "retrieval_query": retrieval_query,
            }
        )
    return prompts, prompt_keys


def _collect_results(
    outputs: List[Any],
    prompt_keys: List[Dict[str, Any]],
    gt: Dict[Any, str],
    doc_mode: str,
) -> Tuple[List[Dict[str, Any]], int, int]:
    """Pair generations with their keys and score YES/NO predictions.

    Returns ``(results, correct, total_eval)``. A row is only evaluated when
    both the predicted label and the ground-truth label are available.
    """
    results: List[Dict[str, Any]] = []
    correct = 0
    total_eval = 0
    for output, key in zip(outputs, prompt_keys):
        generated_text = output.outputs[0].text.strip()
        if doc_mode == "multi":
            # Multi mode has no YES/NO ground truth; keep the raw generation.
            results.append({**key, "generated_answer": generated_text})
            continue

        predicted_label = extract_yes_no(generated_text)
        gt_key = (key["report"], key["question"]) if doc_mode == "single" else key["question"]
        gt_label = gt.get(gt_key)

        is_correct = None
        if predicted_label is not None and gt_label is not None:
            is_correct = 1 if predicted_label == gt_label else 0
            correct += is_correct
            total_eval += 1

        results.append(
            {
                **key,
                "generated_answer": generated_text[:800],
                "predicted_label": predicted_label,
                "gt_label": gt_label,
                "is_correct": is_correct,
            }
        )
    return results, correct, total_eval


def main():
    """Command-line entry point: build prompts, run generation, save results.

    Steps: load retrieval contexts (optionally limited to the first N
    groups), load ground truth for "single"/"cross", build prompts, generate
    with vLLM, score YES/NO predictions, and write a results CSV plus a
    metrics JSON into the directory returned by ``get_result_dir``.
    """
    parser = argparse.ArgumentParser(description="RAG generation for ClimRetrieve")
    parser.add_argument("--chunk", type=str, default="length", choices=["length", "structure"])
    parser.add_argument("--doc", type=str, default="single", choices=["single", "cross", "multi"])
    parser.add_argument("--top_k", type=int, default=RETRIEVAL_TOP_K)
    parser.add_argument("--model", type=str, default=GEN_MODEL_PATH)
    parser.add_argument("--max_tokens", type=int, default=MAX_NEW_TOKENS)
    parser.add_argument("--temperature", type=float, default=TEMPERATURE)
    parser.add_argument("--limit", type=int, default=None, help="Only generate for first N questions")
    args = parser.parse_args()

    chunk_mode = args.chunk
    doc_mode = args.doc
    top_k = args.top_k
    limit = args.limit
    has_limit = limit is not None and limit > 0

    print("=" * 70)
    print("RAG Generation")
    print(f" CHUNK={chunk_mode}, DOC={doc_mode}, TOP_K={top_k}, LIMIT={limit}")
    print(f" Model: {args.model}")
    print("=" * 70)

    print("\n[Step 1] Load retrieval contexts")
    contexts = load_retrieval_contexts(chunk_mode, doc_mode, top_k)
    print(f" Context groups: {len(contexts)}")
    if has_limit:
        contexts = dict(list(contexts.items())[:limit])
        print(f" Applied limit -> {len(contexts)} groups")

    gt = {}
    if doc_mode in ("single", "cross"):
        print("\n[Step 2] Load ground truth")
        gt = load_ground_truth(doc_mode)
        print(f" Ground truth count: {len(gt)}")
    else:
        print("\n[Step 2] Multi mode: skip YES/NO ground truth")

    print("\n[Step 3] Build prompts")
    prompts, prompt_keys = _build_prompts(contexts, doc_mode, top_k)
    print(f" Prompt count: {len(prompts)}")
    if prompts:
        print(f" Prompt preview: {prompts[0][:400]}...")

    print("\n[Step 4] Run generation")
    if prompts:
        # Imported lazily so the heavy dependency is only needed when generating.
        from vllm import LLM, SamplingParams

        llm = LLM(model=args.model, max_model_len=8192, dtype="auto", trust_remote_code=True)
        sampling_params = SamplingParams(
            temperature=args.temperature,
            max_tokens=args.max_tokens,
            top_p=1.0,
        )
        outputs = llm.generate(prompts, sampling_params)
    else:
        # Nothing to generate: do not load the model just to run an empty batch.
        print(" No prompts to run; skipping model load")
        outputs = []

    print("\n[Step 5] Collect results")
    results, correct, total_eval = _collect_results(outputs, prompt_keys, gt, doc_mode)

    print("\n[Step 6] Save outputs")
    output_dir = get_result_dir(chunk_mode, doc_mode)
    os.makedirs(output_dir, exist_ok=True)
    limit_tag = f"__limit{limit}" if has_limit else ""
    output_file = os.path.join(output_dir, f"generation__{chunk_mode}__{doc_mode}__top{top_k}{limit_tag}.csv")
    pd.DataFrame(results).to_csv(output_file, index=False, encoding="utf-8-sig")
    print(f" Result file: {output_file}")

    if doc_mode == "multi":
        metrics = {
            "chunk_mode": chunk_mode,
            "doc_mode": doc_mode,
            "top_k": top_k,
            "limit": limit,
            "generated_count": len(results),
        }
    else:
        accuracy = correct / total_eval if total_eval > 0 else 0.0
        metrics = {
            "chunk_mode": chunk_mode,
            "doc_mode": doc_mode,
            "top_k": top_k,
            "limit": limit,
            "total_eval": total_eval,
            "correct": correct,
            "accuracy": accuracy,
            "unparsed": sum(1 for r in results if r.get("predicted_label") is None),
            "no_gt": sum(1 for r in results if r.get("gt_label") is None),
        }
        print(f" Accuracy: {accuracy:.4f} ({correct}/{total_eval})")

    metrics_file = os.path.join(output_dir, f"generation_metrics__{chunk_mode}__{doc_mode}__top{top_k}{limit_tag}.json")
    with open(metrics_file, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)
    print(f" Metrics file: {metrics_file}")

    print("\n" + "=" * 70)
    print("Generation completed.")
    print("=" * 70)


if __name__ == "__main__":
    main()