Spaces:
Running
Running
| """RAG Generation for single / cross / multi document settings.""" | |
| from __future__ import annotations | |
| import os | |
| import re | |
| import json | |
| import argparse | |
| from typing import Dict, List, Tuple, Any | |
| import pandas as pd | |
# ======================== Config ========================
# Local path of the generation LLM loaded through vLLM (see main, Step 4).
GEN_MODEL_PATH = "../model/Qwen3-4B-Instruct-2507-FP8"
# Expert-annotated ground-truth spreadsheets: single-document and cross-document.
BASE_XLSX_PATH = "../Expert-Annotated Relevant Sources Dataset/ClimRetrieve_base.xlsx"
CROSS_XLSX_PATH = "../Expert-Annotated Relevant Sources Dataset/ClimRetrieve_cross.xlsx"
# Base directory of precomputed embedding-retrieval results (see get_result_dir).
RESULT_DIR_BASE = "./Embedding_Search_Results_Qwen"
# Default number of retrieved chunks placed into each prompt.
RETRIEVAL_TOP_K = 5
# Default generation settings; TEMPERATURE 0.0 gives greedy decoding.
MAX_NEW_TOKENS = 768
TEMPERATURE = 0.0
# ======================== Prompt Templates ========================
# Two worked examples (one YES, one NO) for single-document YES/NO answering.
FEW_SHOT_EXAMPLES_SINGLE = """Example 1:
Question: Does the company have a strategy on waste management?
Context:
- "To meet our commitment to being zero waste by 2030, we are reducing our waste footprint through reuse, recycling and recovery."
- "We increased our reuse and recycle rates of all cloud hardware to 82 percent."
Answer: [YES]. The company has a clear strategy on waste management focused on reduction, recycling, and recovery to meet its 2030 zero waste commitment.
Example 2:
Question: Does the company report the climate change scenarios used to test the resilience of its business strategy?
Context:
- "Our ESG ratings improved from 2021 to 2022."
- "We continue to focus on sustainable sourcing across our supply chain."
Answer: [NO]. The provided context does not contain information about climate change scenarios used to test business strategy resilience.
"""
# Cross-document variant: context bullets carry the source company in parentheses.
FEW_SHOT_EXAMPLES_CROSS = """Example 1:
Question: Among [Company A Report, Company B Report], did at least one company address "Does the company have a strategy on waste management?"
Context:
- (Company A) "We are committed to reducing waste by 50% by 2030 through our circular economy approach."
- (Company B) "Our revenue grew by 15% year-over-year driven by strong demand."
Answer: [YES]. Company A provides clear evidence of a waste management strategy with specific targets, while Company B's context does not address waste management.
Example 2:
Question: Among [Company A Report, Company B Report], did Company A provide stronger disclosure than Company B on "Does the company report its Scope 1 emissions?"
Context:
- (Company A) "We focus on employee diversity and inclusion programs."
- (Company B) "Our sustainability team oversees environmental compliance."
Answer: [NO]. Neither company's context provides evidence about Scope 1 emissions reporting, so the comparison claim is not supported.
"""
# Zero-shot preamble for multi-document mode: instructs strict-JSON output and
# conservative handling of missing evidence (used by build_multi_zero_shot_prompt).
MULTI_ZERO_SHOT_BASE = """You are a regulator/auditor analyzing multi-document climate disclosure questions.
Use ONLY the provided context chunks.
Do not fabricate facts, numeric values, years, units, page numbers, or source ids.
Prioritize quantitative comparison: extract comparable metrics (value + unit + period) only if explicitly stated.
If evidence is insufficient, keep fields conservative: null/[]/"insufficient".
Set `conclusion` as a direct decision in this exact style: "[YES] <reason>" or "[NO] <reason>".
Return strict JSON only (no markdown, no prose outside JSON).
"""
# Shared target JSON schema for multi-document answers; get_multi_skill_spec()
# hands this (not the per-skill schemas) to the prompt builder.
MULTI_OUTPUT_SCHEMA = (
    '{'
    '"dimension":"string",'
    '"rows":[{"report":"string","year":"string","disclosure_status":"explicit | partial | missing",'
    '"key_points":["string"],"evidence_chunks":["E1","E4"]}],'
    '"ranking":[{"rank":1,"report":"string","rationale":"string"}],'
    '"conclusion":"string ([YES]/[NO] + reason)"'
    '}'
)
# Skill library for multi-document mode. Each entry pairs a skill-specific
# guidance prompt with a skill-specific JSON schema string; the dict keys are
# the canonical names produced by normalize_multi_skill_name().
# NOTE(review): get_multi_skill_spec() currently returns the shared
# MULTI_OUTPUT_SCHEMA instead of the per-skill "schema" strings below —
# confirm which schema is intended before relying on the per-skill ones.
MULTI_SKILL_SPECS = {
    "Comparative Table Builder": {
        "prompt": (
            "Build a regulator-facing comparative table. "
            "For each report, output disclosure maturity and quantified metrics extracted from context. "
            "Use `quant_metrics` with objects: {metric, value, unit, period, note}. "
            "If a metric is unavailable for a report, do not invent it; use null values or omit the metric."
        ),
        "schema": (
            '{"skill":"Comparative Table Builder","dimension":"string","reports":[{"report":"string","year":"number|null",'
            '"maturity_level":"high|moderate|low|insufficient","key_evidence":["string"],'
            '"quant_metrics":[{"metric":"string","value":"number|null","unit":"string|null","period":"string|null","note":"string|null"}]}],'
            '"comparison_metrics":["string"],"conclusion":"string ([YES]/[NO] + reason)"}'
        ),
    },
    "Trend & Quant Comparator": {
        "prompt": (
            "Build a quantitative trend comparison across reports. "
            "For each report, return measurable indicators with fields: value, intensity, attainment_rate, "
            "change_magnitude, trend_direction. Keep values numeric when present in context, otherwise null. "
            "Include 1-3 concise key_evidence bullets per report."
        ),
        "schema": (
            '{"skill":"Trend & Quant Comparator","dimension":"string","reports":[{"report":"string","year":"number|null",'
            '"key_evidence":["string"],"strength_score":"number|null",'
            '"quant_metrics":[{"metric":"string","value":"number|null","unit":"string|null","period":"string|null",'
            '"intensity":"number|null","attainment_rate":"number|null","change_magnitude":"number|null",'
            '"trend_direction":"up|down|flat|unknown|insufficient","note":"string|null"}]}],'
            '"metric_highlights":["string"],"conclusion":"string ([YES]/[NO] + reason)"}'
        ),
    },
    "Target Attainment & Delta Benchmark": {
        "prompt": (
            "Create a target-attainment benchmark. For each report, extract baseline/current/target values where available, "
            "compute attainment_rate and delta fields (absolute and percent), and infer trend_direction conservatively. "
            "If any field is unavailable in context, return null/insufficient rather than guessing."
        ),
        "schema": (
            '{"skill":"Target Attainment & Delta Benchmark","dimension":"string","reports":[{"report":"string","year":"number|null",'
            '"overall_strength":"high|moderate|low|insufficient","key_evidence":["string"],'
            '"benchmarks":[{"metric":"string","baseline_value":"number|null","baseline_period":"string|null",'
            '"current_value":"number|null","current_period":"string|null","target_value":"number|null","target_period":"string|null",'
            '"attainment_rate":"number|null","delta_abs":"number|null","delta_percent":"number|null","intensity":"number|null",'
            '"unit":"string|null","trend_direction":"up|down|flat|unknown|insufficient","note":"string|null"}]}],'
            '"leaderboard":[{"report":"string","score":"number|null","reason":"string"}],"conclusion":"string ([YES]/[NO] + reason)"}'
        ),
    },
    "Compliance Checklist": {
        "prompt": (
            "Create a compliance checklist comparison. "
            "Return item-level pass/partial/fail for each report and include quantified indicators where available. "
            "Also compute summary counts per report and include 1-3 concise key_evidence bullets per report."
        ),
        "schema": (
            '{"skill":"Compliance Checklist","dimension":"string","required_checks":["string"],"reports":[{"report":"string",'
            '"key_evidence":["string"],'
            '"summary":{"pass":"number","partial":"number","fail":"number","completion_rate":"number|null"},'
            '"checks":[{"item":"string","status":"pass|partial|fail|insufficient","quant_value":"number|null","quant_unit":"string|null","note":"string"}]}],'
            '"conclusion":"string ([YES]/[NO] + reason)"}'
        ),
    },
    "Dimension Extractor": {
        "prompt": (
            "Extract disclosure dimensions and quantify coverage by bucket. "
            "Return bucket-level counts and any explicit quantitative metric snippets, plus 1-3 key_evidence bullets per report."
        ),
        "schema": (
            '{"skill":"Dimension Extractor","dimension":"string","reports":[{"report":"string","bucket_counts":{"Process":"number","Input":"number","Output":"number","Outcome":"number","Governance":"number","Risk":"number"},'
            '"key_evidence":["string"],'
            '"quant_metrics":[{"metric":"string","value":"number|null","unit":"string|null","period":"string|null","note":"string|null"}],'
            '"coverage_level":"high|moderate|low|insufficient"}],"conclusion":"string ([YES]/[NO] + reason)"}'
        ),
    },
    "Contradiction/Consistency Check": {
        "prompt": (
            "Evaluate consistency across reports. "
            "Return rule checks and a quantified consistency score per report pair/group. "
            "Include a concise key_evidence list that supports the consistency conclusion."
        ),
        "schema": (
            '{"skill":"Contradiction/Consistency Check","dimension":"string","key_evidence":["string"],'
            '"checks":[{"rule":"string","result":"consistent|inconsistent|insufficient","note":"string"}],'
            '"scores":{"consistent":"number","inconsistent":"number","insufficient":"number","consistency_rate":"number|null"},'
            '"conclusion":"string ([YES]/[NO] + reason)"}'
        ),
    },
    "Consensus/Count (Portfolio Statistics)": {
        "prompt": (
            "Produce portfolio-level statistics from retrieved context only. "
            "Count explicit/partial/missing disclosures and include percentage split. "
            "For each report, include 1-3 key_evidence bullets."
        ),
        "schema": (
            '{"skill":"Consensus/Count (Portfolio Statistics)","dimension":"string","counts":{"explicit":"number","partial":"number","missing":"number","total":"number"},'
            '"percentages":{"explicit":"number|null","partial":"number|null","missing":"number|null"},'
            '"per_report":[{"report":"string","label":"explicit|partial|missing|insufficient","quant_metrics":[{"metric":"string","value":"number|null","unit":"string|null","period":"string|null"}]}],'
            '"key_evidence_by_report":[{"report":"string","key_evidence":["string"]}],'
            '"consensus_items":["string"],"outliers":["string"],"conclusion":"string ([YES]/[NO] + reason)"}'
        ),
    },
}
# (keyword substring, suggested metric field names) pairs consumed by
# infer_metric_hints(). Keywords are matched against the lowercased question;
# all matching rules contribute, in list order, with duplicates removed.
# "methodolog" is a deliberate stem so it hits methodology/methodological.
METRIC_HINT_RULES = [
    ("trend", ["value", "intensity", "attainment_rate", "change_magnitude", "delta_percent"]),
    ("attainment", ["target_value", "current_value", "attainment_rate", "delta_percent"]),
    ("delta", ["delta_abs", "delta_percent", "baseline_value", "current_value"]),
    ("baseline", ["baseline_value", "current_value", "target_value", "delta_percent"]),
    ("kpi", ["value", "unit", "period", "change_magnitude"]),
    ("risk", ["number_of_risk_categories", "assessment_frequency_per_year", "scenario_count", "horizon_year"]),
    ("methodolog", ["framework_count", "method_steps_count", "indicator_count", "coverage_percent"]),
    ("scenario", ["scenario_count", "temperature_pathway_count", "horizon_year"]),
    ("resilience", ["scenario_count", "stress_test_count", "horizon_year"]),
    ("impact", ["scope1_emissions", "scope2_emissions", "scope3_emissions", "reduction_percent"]),
    ("waste", ["waste_reduction_percent", "recycling_rate_percent", "waste_diverted_tons", "target_year"]),
    ("industry peers", ["initiative_count", "partnership_count", "engagement_frequency_per_year"]),
    ("downstream", ["supplier_coverage_percent", "partner_count", "assessment_completion_percent"]),
    ("water", ["water_withdrawal_change_percent", "water_intensity_change_percent", "sites_count", "target_year"]),
]
def infer_metric_hints(question: str) -> List[str]:
    """Suggest comparable quantitative metric names for *question*.

    Scans METRIC_HINT_RULES for case-insensitive keyword hits and merges the
    metric lists of every matching rule, in rule order, without duplicates.
    Falls back to a generic metric set when nothing matches.
    """
    lowered = str(question or "").lower()
    matched = [
        metric
        for keyword, metrics in METRIC_HINT_RULES
        if keyword in lowered
        for metric in metrics
    ]
    # dict.fromkeys keeps first-seen order while dropping duplicates.
    deduped = list(dict.fromkeys(matched))
    if deduped:
        return deduped
    return ["reported_metric_count", "target_year", "percent_change", "absolute_value"]
def normalize_multi_skill_name(skill_name: str) -> str:
    """Map a free-form skill label onto a canonical MULTI_SKILL_SPECS key.

    Matching is keyword-based and case-insensitive; the first rule that fires
    wins, and anything unrecognized falls back to "Comparative Table Builder".
    """
    label = str(skill_name or "").strip().lower()
    rules = [
        ("Trend & Quant Comparator",
         lambda s: "trend" in s and "quant" in s),
        ("Target Attainment & Delta Benchmark",
         lambda s: "attainment" in s or ("delta" in s and "benchmark" in s)),
        ("Compliance Checklist",
         lambda s: "compliance" in s and "check" in s),
        ("Dimension Extractor",
         lambda s: "dimension" in s and "extract" in s),
        ("Contradiction/Consistency Check",
         lambda s: "contradiction" in s or "consistency" in s),
        ("Consensus/Count (Portfolio Statistics)",
         lambda s: "consensus" in s or "portfolio" in s or "count" in s),
    ]
    for canonical, matches in rules:
        if matches(label):
            return canonical
    return "Comparative Table Builder"
def infer_multi_skill_name(question: str) -> str:
    """Guess the multi-document skill from the question text itself.

    Used when retrieval results carry no explicit skill label. Rules are
    checked in priority order; unrecognized questions default to the
    "Comparative Table Builder" skill.
    """
    lowered = str(question or "").lower()
    rules = [
        ("Trend & Quant Comparator",
         lambda s: ("trend" in s and "quant" in s) or "change magnitude" in s or "measurable progress" in s),
        ("Target Attainment & Delta Benchmark",
         lambda s: "attainment rate" in s or "target gap" in s or "baseline" in s or "delta percent" in s),
        ("Compliance Checklist",
         lambda s: "checklist" in s or "compliance" in s),
        ("Dimension Extractor",
         lambda s: "extract" in s and "dimension" in s),
        ("Contradiction/Consistency Check",
         lambda s: "contradiction" in s or "consistent" in s or "inconsistent" in s),
        ("Consensus/Count (Portfolio Statistics)",
         lambda s: "consensus" in s or "outlier" in s or "portfolio" in s or "count" in s),
    ]
    for canonical, matches in rules:
        if matches(lowered):
            return canonical
    return "Comparative Table Builder"
def get_multi_skill_spec(skill_name: str) -> Dict[str, str]:
    """Resolve a possibly free-form skill label into its prompt spec.

    Returns the canonical skill name, that skill's guidance prompt, and the
    shared multi-document output JSON schema (MULTI_OUTPUT_SCHEMA).
    """
    canonical = normalize_multi_skill_name(skill_name)
    spec = MULTI_SKILL_SPECS[canonical]
    return {
        "skill_name": canonical,
        "skill_prompt": spec["prompt"],
        "output_json_schema": MULTI_OUTPUT_SCHEMA,
    }
def build_yes_no_prompt(question: str, contexts: List[str], doc_mode: str) -> str:
    """Assemble the few-shot YES/NO prompt for single- or cross-document QA.

    "cross" mode uses the company-tagged few-shot examples; any other mode
    uses the single-document ones. Each context chunk is rendered as a quoted
    bullet line.
    """
    if doc_mode == "cross":
        examples = FEW_SHOT_EXAMPLES_CROSS
    else:
        examples = FEW_SHOT_EXAMPLES_SINGLE
    bullets = [f'- "{chunk}"' for chunk in contexts]
    context_block = "\n".join(bullets)
    return f"""You are an expert analyst evaluating corporate sustainability reports.
Based ONLY on the provided context passages, answer the question with [YES] or [NO], followed by a brief reasoning.
Format your answer as:
[YES]. <reasoning> OR [NO]. <reasoning>
{examples.strip()}
Now answer the following:
Question: {question}
Context:
{context_block}
Answer:"""
def build_multi_zero_shot_prompt(
    question: str,
    contexts: List[str],
    skill_name: str = "",
    skill_prompt: str = "",
    output_json_schema: str = "",
    retrieval_query: str = "",
    metric_hints: List[str] | None = None,
) -> str:
    """Assemble the zero-shot JSON prompt for multi-document mode.

    Context chunks are labelled [E1], [E2], ... so the model can cite them in
    `evidence_chunks`. Empty skill metadata fields collapse to "", and missing
    metric hints are inferred from the question via infer_metric_hints().
    """
    def _clean(value: str) -> str:
        # Normalize None / non-string metadata to a stripped string.
        return str(value or "").strip()

    skill_label = _clean(skill_name)
    guidance = _clean(skill_prompt)
    schema_text = _clean(output_json_schema)
    query_dim = _clean(retrieval_query)
    hints = metric_hints or infer_metric_hints(question)
    hint_text = ", ".join(hints)
    numbered = [f'[E{pos}] "{chunk}"' for pos, chunk in enumerate(contexts, start=1)]
    context_block = "\n".join(numbered)
    return f"""{MULTI_ZERO_SHOT_BASE}
Matched skill: {skill_label}
Skill guidance: {guidance}
Retrieval query dimension: {query_dim}
Suggested comparable metrics: {hint_text}
Target JSON schema: {schema_text}
Question:
{question}
Context:
{context_block}
JSON Answer:"""
| # ======================== Ground Truth ======================== | |
def extract_yes_no(text: str) -> str | None:
    """Extract a leading YES/NO verdict from a model answer or gold label.

    Accepts forms like "[YES]. reason", "[ NO ] reason", "yes, because ...".
    Returns "YES" or "NO", or None when no verdict can be parsed.
    """
    text_upper = str(text).strip().upper()
    # \b guards against false positives where the text merely starts with the
    # letters: without it, "NOTE:", "NONE" (str(None)) and "YESTERDAY" were
    # parsed as verdicts by the old regex / startswith fallback.
    m = re.match(r"\[*\s*(YES|NO)\b", text_upper)
    return m.group(1) if m else None
def load_ground_truth(doc_mode: str) -> Dict[Any, str]:
    """Load YES/NO gold labels for the given document mode.

    Returns {(report, question): "YES"/"NO"} for "single",
    {question: "YES"/"NO"} for "cross" (first label per question wins), and
    an empty dict otherwise — multi mode has no closed-form YES/NO ground truth.
    """
    if doc_mode not in ("single", "cross"):
        return {}
    labels: Dict[Any, str] = {}
    if doc_mode == "single":
        frame = pd.read_excel(BASE_XLSX_PATH, index_col=0)
        for _, record in frame.iterrows():
            verdict = extract_yes_no(record["Answer"])
            if verdict is not None:
                labels[(record["Document"], record["Question"])] = verdict
        return labels
    frame = pd.read_excel(CROSS_XLSX_PATH)
    # Drop the leftover pandas index column if the sheet was saved with one.
    if "Unnamed: 0" in frame.columns:
        frame = frame.drop(columns=["Unnamed: 0"])
    for _, record in frame.iterrows():
        verdict = extract_yes_no(record["Answer"])
        key = record["Question"]
        if verdict is not None and key not in labels:
            labels[key] = verdict
    return labels
| # ======================== Retrieval Result Loading ======================== | |
def get_result_dir(chunk_mode: str, doc_mode: str) -> str:
    """Return the retrieval-result directory name for a (chunk, doc) setting.

    The default modes ("length" chunking, "single" doc) are omitted from the
    directory name, so the base directory is used unchanged for the defaults.
    """
    name_parts = [RESULT_DIR_BASE]
    for mode, default in ((chunk_mode, "length"), (doc_mode, "single")):
        if mode != default:
            name_parts.append(mode)
    return "_".join(name_parts)
def load_retrieval_contexts(chunk_mode: str, doc_mode: str, top_k: int) -> Dict[Any, Any]:
    """Load retrieved context chunks produced by the embedding-search step.

    Note: top_k only selects which result CSV to read; truncation to the
    top-k chunks happens later in main().

    Returns:
        single mode: {(report, question): [chunk_text, ...]} best-first.
        cross/multi mode: {question: {"chunks": [...], "skill_name": ...,
            "skill_prompt": ..., "output_json_schema": ..., "retrieval_query": ...}}.

    Raises:
        FileNotFoundError: when the expected retrieval CSV is missing.
    """
    result_dir = get_result_dir(chunk_mode, doc_mode)
    # File name encodes the embedding model plus the (chunk, doc, top_k) setting.
    result_file = os.path.join(result_dir, f"Qwen3-Embedding-0.6B__{chunk_mode}__{doc_mode}__{top_k}.csv")
    if not os.path.exists(result_file):
        raise FileNotFoundError(f"Retrieval result file not found: {result_file}")
    df = pd.read_csv(result_file)
    # question_retrieved == 1 marks the chunks selected for their question.
    retrieved = df[df["question_retrieved"] == 1].copy()
    # Global sort by score; groupby(..., sort=False) below preserves row order
    # within each group, so every group's chunk list comes out best-first.
    retrieved = retrieved.sort_values("similarity_score", ascending=False)
    if doc_mode == "single":
        contexts = {}
        for (report, question), group in retrieved.groupby(["report", "question"], sort=False):
            contexts[(report, question)] = group["chunk_text"].tolist()
        return contexts
    # cross / multi: one payload per question; each chunk is prefixed with its
    # source report (".pdf" stripped) and chunk index for attribution.
    contexts = {}
    for question, group in retrieved.groupby("question", sort=False):
        chunks = []
        for _, row in group.iterrows():
            chunks.append(f"({row['report'].replace('.pdf', '')}, chunk={row['chunk_idx']}) {row['chunk_text']}")
        payload = {"chunks": chunks}
        # Carry through optional skill metadata columns; first non-null value
        # in the group wins, missing columns become "".
        for col in ["skill_name", "skill_prompt", "output_json_schema", "retrieval_query"]:
            if col in group.columns:
                non_null = group[col].dropna()
                payload[col] = non_null.iloc[0] if len(non_null) > 0 else ""
            else:
                payload[col] = ""
        contexts[question] = payload
    return contexts
| # ======================== Main ======================== | |
def main() -> None:
    """CLI entry point: load retrieval contexts, build prompts, generate with
    vLLM, score YES/NO answers against ground truth (single/cross modes), and
    write a results CSV plus a metrics JSON into the retrieval result dir."""
    parser = argparse.ArgumentParser(description="RAG generation for ClimRetrieve")
    parser.add_argument("--chunk", type=str, default="length", choices=["length", "structure"])
    parser.add_argument("--doc", type=str, default="single", choices=["single", "cross", "multi"])
    parser.add_argument("--top_k", type=int, default=RETRIEVAL_TOP_K)
    parser.add_argument("--model", type=str, default=GEN_MODEL_PATH)
    parser.add_argument("--max_tokens", type=int, default=MAX_NEW_TOKENS)
    parser.add_argument("--temperature", type=float, default=TEMPERATURE)
    parser.add_argument("--limit", type=int, default=None, help="Only generate for first N questions")
    args = parser.parse_args()
    chunk_mode = args.chunk
    doc_mode = args.doc
    top_k = args.top_k
    limit = args.limit
    print("=" * 70)
    print("RAG Generation")
    print(f" CHUNK={chunk_mode}, DOC={doc_mode}, TOP_K={top_k}, LIMIT={limit}")
    print(f" Model: {args.model}")
    print("=" * 70)
    print("\n[Step 1] Load retrieval contexts")
    contexts = load_retrieval_contexts(chunk_mode, doc_mode, top_k)
    print(f" Context groups: {len(contexts)}")
    # --limit keeps only the first N question groups (debug / smoke runs).
    if limit is not None and limit > 0:
        items = list(contexts.items())[:limit]
        contexts = dict(items)
        print(f" Applied limit -> {len(contexts)} groups")
    gt = {}
    if doc_mode in ("single", "cross"):
        print("\n[Step 2] Load ground truth")
        gt = load_ground_truth(doc_mode)
        print(f" Ground truth count: {len(gt)}")
    else:
        print("\n[Step 2] Multi mode: skip YES/NO ground truth")
    print("\n[Step 3] Build prompts")
    # prompts[i] and prompt_keys[i] stay index-aligned through generation.
    prompts = []
    prompt_keys = []
    if doc_mode == "single":
        for (report, question), ctx_list in contexts.items():
            prompts.append(build_yes_no_prompt(question, ctx_list[:top_k], doc_mode))
            prompt_keys.append({"report": report, "question": question})
    elif doc_mode == "cross":
        for question, payload in contexts.items():
            # Payload may be a dict (with "chunks") or a bare chunk list.
            ctx_list = payload["chunks"] if isinstance(payload, dict) else payload
            prompts.append(build_yes_no_prompt(question, ctx_list[:top_k], doc_mode))
            prompt_keys.append({"question": question})
    else:
        # multi: resolve the skill spec, preferring metadata carried with the
        # retrieval payload and falling back to question-based inference.
        for question, payload in contexts.items():
            ctx_list = payload.get("chunks", [])
            inferred_skill = payload.get("skill_name", "") or infer_multi_skill_name(question)
            spec = get_multi_skill_spec(inferred_skill)
            skill_name = spec["skill_name"]
            skill_prompt = payload.get("skill_prompt", "").strip() or spec["skill_prompt"]
            output_json_schema = spec["output_json_schema"]
            retrieval_query = payload.get("retrieval_query", "").strip() or question
            prompts.append(
                build_multi_zero_shot_prompt(
                    question=question,
                    contexts=ctx_list[:top_k],
                    skill_name=skill_name,
                    skill_prompt=skill_prompt,
                    output_json_schema=output_json_schema,
                    retrieval_query=retrieval_query,
                    metric_hints=infer_metric_hints(question),
                )
            )
            prompt_keys.append(
                {
                    "question": question,
                    "skill_name": skill_name,
                    "retrieval_query": retrieval_query,
                }
            )
    print(f" Prompt count: {len(prompts)}")
    if prompts:
        print(f" Prompt preview: {prompts[0][:400]}...")
    print("\n[Step 4] Run generation")
    # Imported lazily so the script can build prompts without vLLM installed.
    from vllm import LLM, SamplingParams
    llm = LLM(model=args.model, max_model_len=8192, dtype="auto", trust_remote_code=True)
    sampling_params = SamplingParams(
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        top_p=1.0,
    )
    outputs = llm.generate(prompts, sampling_params)
    print("\n[Step 5] Collect results")
    results = []
    correct = 0
    total_eval = 0
    for i, output in enumerate(outputs):
        generated_text = output.outputs[0].text.strip()
        key = prompt_keys[i]
        if doc_mode == "multi":
            # Multi mode: store the raw JSON answer; no label scoring.
            results.append(
                {
                    **key,
                    "generated_answer": generated_text,
                }
            )
            continue
        predicted_label = extract_yes_no(generated_text)
        if doc_mode == "single":
            gt_label = gt.get((key["report"], key["question"]))
        else:
            gt_label = gt.get(key["question"])
        # Only score items where both prediction and gold label exist.
        is_correct = None
        if predicted_label is not None and gt_label is not None:
            is_correct = 1 if predicted_label == gt_label else 0
            correct += is_correct
            total_eval += 1
        results.append(
            {
                **key,
                # Truncated to keep the CSV readable.
                "generated_answer": generated_text[:800],
                "predicted_label": predicted_label,
                "gt_label": gt_label,
                "is_correct": is_correct,
            }
        )
    print("\n[Step 6] Save outputs")
    output_dir = get_result_dir(chunk_mode, doc_mode)
    os.makedirs(output_dir, exist_ok=True)
    limit_tag = f"__limit{limit}" if (limit is not None and limit > 0) else ""
    output_file = os.path.join(output_dir, f"generation__{chunk_mode}__{doc_mode}__top{top_k}{limit_tag}.csv")
    # utf-8-sig so the CSV opens cleanly in Excel.
    pd.DataFrame(results).to_csv(output_file, index=False, encoding="utf-8-sig")
    print(f" Result file: {output_file}")
    if doc_mode == "multi":
        metrics = {
            "chunk_mode": chunk_mode,
            "doc_mode": doc_mode,
            "top_k": top_k,
            "limit": limit,
            "generated_count": len(results),
        }
    else:
        accuracy = correct / total_eval if total_eval > 0 else 0.0
        metrics = {
            "chunk_mode": chunk_mode,
            "doc_mode": doc_mode,
            "top_k": top_k,
            "limit": limit,
            "total_eval": total_eval,
            "correct": correct,
            "accuracy": accuracy,
            "unparsed": sum(1 for r in results if r.get("predicted_label") is None),
            "no_gt": sum(1 for r in results if r.get("gt_label") is None),
        }
        print(f" Accuracy: {accuracy:.4f} ({correct}/{total_eval})")
    metrics_file = os.path.join(output_dir, f"generation_metrics__{chunk_mode}__{doc_mode}__top{top_k}{limit_tag}.json")
    with open(metrics_file, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)
    print(f" Metrics file: {metrics_file}")
    print("\n" + "=" * 70)
    print("Generation completed.")
    print("=" * 70)


if __name__ == "__main__":
    main()