# cascade_risk/src/data/cascade_extractor.py
# Synced from GitHub main (part 2) by Lucasoppem, commit 36f9d47 (21.2 kB).
"""Extract structured cascade chains from news articles using LLM."""
from __future__ import annotations
import json
import logging
import re
from datetime import date
from pathlib import Path
from src.llm.client import LLMClient, load_config
from src.models.schemas import CascadeChain, CascadeNode, FloodEvent
logger = logging.getLogger(__name__)
# Single-shot character budget. Scaled up for v0.2 issue #10 (Qwen3.5-9B,
# 64K context) so that all real flood events fit in one LLM call;
# `_extract_batched` only triggers for genuinely huge corpora and now invokes
# `_consolidate_with_llm` to merge across batches before code-side dedupe.
MAX_ARTICLES_TEXT_LEN = 80000
# Per v0.2 issue #10: budget raised 30 → 40 to give the cascade DAG more room
# for cross-domain depth without sacrificing the major-impacts-only bar.
MAX_NODES = 40
MAX_TIME_OFFSET_HOURS = 336 # 2 weeks
ALLOWED_DOMAINS = {
"infrastructure/power",
"infrastructure/water",
"infrastructure/transport",
"infrastructure/communication",
"health/casualties",
"health/hospital_service",
"health/disease_outbreak",
"social/evacuation",
"social/supply_shortage",
"economy/business_damage",
"economy/agriculture",
"environment/contamination",
}
# Legacy / near-miss labels seen from v1 runs → canonical
DOMAIN_REMAP = {
"health/disease": "health/disease_outbreak",
"health/hospital": "health/hospital_service",
"health/emergency_services": "health/hospital_service",
"emergency_services": "health/hospital_service",
"government/emergency_services": "health/hospital_service",
"social/displacement": "social/evacuation",
"social/housing": "social/evacuation",
"social/mental_health": None, # out of scope
"economy/business": "economy/business_damage",
"environment/pollution": "environment/contamination",
"environment/ecology": "environment/contamination",
"environment/weather": None, # weather is the trigger, not a cascade
"education": None,
"education/business": None,
}
_STOPWORDS = {
"the", "a", "an", "of", "to", "in", "and", "for", "on", "due", "with",
"at", "by", "from", "as", "is", "are", "was", "were", "be", "been",
"have", "has", "had", "caused", "resulted", "led", "flood", "flooding",
"floodwaters", "floodwater", "waters", "water", "event", "events",
"impact", "impacts", "disruption", "caused", "affecting", "affected",
"after", "following", "during", "this", "that", "these", "those",
}
SEVERITY_RANK = {"low": 0, "medium": 1, "high": 2, "critical": 3}
def extract_cascade_chain(
    event: FloodEvent,
    articles: list[dict],
    llm_client: LLMClient,
    config: dict | None = None,
) -> CascadeChain | None:
    """Extract a cascade chain from scraped articles for a single event.

    The articles are rendered with `_format_articles`. If the rendered text
    fits in ``MAX_ARTICLES_TEXT_LEN`` characters, one extraction call is made;
    otherwise extraction runs per batch and is consolidated.

    The resulting nodes are post-processed in order:
    domain normalisation → time-window filter → semantic dedupe → (only if
    more than ``MAX_NODES`` remain) diversified truncation → sequential
    renumbering.

    Args:
        event: Flood event the articles describe.
        articles: Scraped article dicts (see `_format_articles` for the keys read).
        llm_client: Client used for the extraction (and consolidation) calls.
        config: Project config; ``load_config()`` is used when omitted.

    Returns:
        The chain with cleaned ``cascade_events``, or None when there are no
        articles or extraction produced nothing.
    """
    cfg = load_config() if config is None else config

    if not articles:
        logger.warning(f"No articles for {event.event_id}, skipping extraction")
        return None

    formatted = _format_articles(articles)
    if len(formatted) <= MAX_ARTICLES_TEXT_LEN:
        chain = _extract_single(event, formatted, articles, llm_client, cfg)
    else:
        chain = _extract_batched(event, articles, llm_client, cfg)
    if chain is None:
        return None

    raw_count = len(chain.cascade_events)
    cleaned = _semantic_dedupe(_filter_scope(_normalize_domains(chain.cascade_events)))
    pre_truncation_count = len(cleaned)

    if pre_truncation_count > MAX_NODES:
        logger.info(
            f" {event.event_id}: {pre_truncation_count} nodes → "
            f"diversified truncation to {MAX_NODES}"
        )
        cleaned = _truncate_by_priority(cleaned, MAX_NODES)

    # Sequential ids only after every merge/drop step has run.
    cleaned = _renumber(cleaned)
    logger.info(
        f" {event.event_id}: {raw_count} raw → {pre_truncation_count} after scope/dedupe → "
        f"{len(cleaned)} final"
    )
    chain.cascade_events = cleaned
    return chain
def _extract_single(
    event: FloodEvent,
    articles_text: str,
    articles: list[dict],
    llm_client: LLMClient,
    config: dict,
) -> CascadeChain | None:
    """Run one ``extract_cascade`` LLM call over ``articles_text`` and parse it.

    If the reply is a bare JSON array, it is treated as the
    ``cascade_events`` list. Any other non-object reply is rejected.

    Args:
        event: Event whose metadata fills the prompt variables.
        articles_text: Pre-formatted article text (see `_format_articles`).
        articles: The articles behind ``articles_text``; their URLs become the
            chain's ``source_articles``.
        llm_client: Client used for the call.
        config: Project config passed through to the client.

    Returns:
        The parsed chain, or None (logged as an error) when the reply cannot be
        turned into one.
    """
    variables = {
        "event_id": event.event_id,
        "country": event.country,
        "location": event.location or event.country,
        "start_date": str(event.start_date),
        "origin": event.origin or "Unknown",
        "articles_text": articles_text,
    }
    response = llm_client.call_with_config(
        prompt_key="extract_cascade",
        knowledge_key="extraction",
        variables=variables,
        config=config,
    )
    try:
        data = json.loads(_extract_json(response))
        if isinstance(data, list):
            # The model skipped the wrapper object and returned just the nodes.
            data = {"cascade_events": data}
        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")
        return _parse_chain(data, event, articles)
    except (json.JSONDecodeError, KeyError, TypeError, ValueError) as e:
        # TypeError covers malformed shapes such as `"cascade_events": null`
        # or a node that is a string instead of an object.
        logger.error(f"Failed to parse cascade chain for {event.event_id}: {e}")
        return None
def _extract_batched(
    event: FloodEvent,
    articles: list[dict],
    llm_client: LLMClient,
    config: dict,
) -> CascadeChain | None:
    """Run extraction batch by batch, then LLM-consolidate across batches.

    Code-side renumbering keeps the id space disjoint per batch but cannot
    decide whether two same-domain nodes from different batches actually
    refer to the same real-world cascade — that requires LLM judgement.
    Issue #10 wires `_consolidate_with_llm` into this path so the batched
    fallback gets the same cross-batch semantic merging that single-shot
    extraction does implicitly. With `MAX_ARTICLES_TEXT_LEN = 80000` only
    truly huge corpora hit this branch, so the extra LLM call is rare.

    Id handling: each batch's nodes are renumbered into one running
    E1, E2, ... sequence. Parent references that do not resolve to a node of
    the same batch are dropped. Left untranslated, an LLM id such as "E3"
    could coincide with a renumbered node from another batch and fabricate a
    cross-batch edge. If a batch repeats an id, parent references resolve to
    its first occurrence.

    The trigger summary is taken from the first batch that produced one. It
    falls back to a generic "Flood in <location> on <date>" string.

    Returns:
        The merged chain, or None if no batch produced any node.
    """
    batches = _batch_articles(articles)
    all_nodes: list[CascadeNode] = []
    trigger_summary: str | None = None
    node_id_counter = 0
    for batch in batches:
        batch_text = _format_articles(batch)
        chain = _extract_single(event, batch_text, batch, llm_client, config)
        if chain is None:
            continue
        if not trigger_summary:
            trigger_summary = chain.trigger_summary
        id_map: dict[str, str] = {}
        for node in chain.cascade_events:
            node_id_counter += 1
            new_id = f"E{node_id_counter}"
            id_map.setdefault(node.id, new_id)
            node.id = new_id
        for node in chain.cascade_events:
            # parent_ids still hold the batch-local LLM ids at this point.
            node.parent_ids = [id_map[pid] for pid in node.parent_ids if pid in id_map]
        all_nodes.extend(chain.cascade_events)
    if not all_nodes:
        return None
    if len(batches) > 1:
        logger.info(
            f" {event.event_id}: consolidating {len(all_nodes)} nodes from "
            f"{len(batches)} batches via LLM"
        )
        consolidated = _consolidate_with_llm(event, all_nodes, llm_client, config)
        if consolidated is not None:
            all_nodes = consolidated
        else:
            logger.warning(
                f" {event.event_id}: LLM consolidation failed; falling back to "
                "unconsolidated batch union (code-side dedupe will still run)"
            )
    return CascadeChain(
        event_id=event.event_id,
        trigger_summary=trigger_summary
        or f"Flood in {event.location or event.country} on {event.start_date}",
        trigger_country=event.country,
        trigger_iso=event.iso,
        trigger_date=event.start_date,
        trigger_severity=_infer_severity(event),
        cascade_events=all_nodes,
        source_articles=[a["url"] for a in articles],
        extraction_date=date.today(),
    )
def save_cascade_chain(chain: CascadeChain, config: dict | None = None) -> Path:
    """Write ``chain`` as pretty-printed JSON to ``<cascade_chains_dir>/<event_id>.json``.

    Args:
        chain: Chain to serialise via ``model_dump(mode="json")``.
        config: Project config; ``load_config()`` is used when omitted.
            ``config["paths"]["cascade_chains_dir"]`` is the output directory
            and is created (with parents) if missing.

    Returns:
        Path of the written file. An existing file is overwritten.
    """
    if config is None:
        config = load_config()
    out_dir = Path(config["paths"]["cascade_chains_dir"])
    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / f"{chain.event_id}.json"
    # ensure_ascii=False keeps non-ASCII text (e.g. place names) verbatim.
    # The file must therefore be written as UTF-8 explicitly; otherwise the
    # platform's locale encoding is used and can fail or mangle characters.
    path.write_text(
        json.dumps(chain.model_dump(mode="json"), indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
    return path
def _format_articles(articles: list[dict]) -> str:
parts = []
for i, a in enumerate(articles, 1):
parts.append(
f"--- Article {i} ---\n"
f"Title: {a.get('title', 'N/A')}\n"
f"Date: {a.get('date', 'N/A')}\n"
f"Source: {a.get('source', 'N/A')}\n"
f"Domain: {a.get('domain', 'N/A')}\n\n"
f"{a.get('text', '')}\n"
)
return "\n".join(parts)
def _batch_articles(
    articles: list[dict], max_len: int | None = None
) -> list[list[dict]]:
    """Greedily split articles into consecutive batches that fit the character budget.

    Each batch's length is computed exactly as ``_format_articles(batch)``
    will render it:
    - per-article sections are joined by ``"\\n"``;
    - the ``--- Article N ---`` header widens as N gains digits.

    The batched prompts therefore respect the same budget that
    `extract_cascade_chain` checks. An article whose section alone exceeds
    the budget still gets its own (oversized) batch rather than being
    dropped; a warning is logged.

    Args:
        articles: Articles in input order; order is preserved across batches.
        max_len: Character budget per batch; defaults to ``MAX_ARTICLES_TEXT_LEN``.

    Returns:
        Non-empty batches covering every article exactly once (``[]`` for no articles).
    """
    budget = MAX_ARTICLES_TEXT_LEN if max_len is None else max_len
    batches: list[list[dict]] = []
    current_batch: list[dict] = []
    current_len = 0
    for article in articles:
        # Rendered standalone, the article is "Article 1". At position k the
        # header needs len(str(k)) digits, and every section after the first
        # adds one "\n" separator.
        base_len = len(_format_articles([article]))
        position = len(current_batch) + 1
        added = base_len + len(str(position)) - 1 + (1 if current_batch else 0)
        if current_batch and current_len + added > budget:
            batches.append(current_batch)
            current_batch = []
            current_len = 0
            added = base_len  # first in a fresh batch: "Article 1", no separator
        if not current_batch and base_len > budget:
            logger.warning(
                f"Article {article.get('url', '?')} alone is {base_len} chars "
                f"(> {budget}); sending it as an oversized batch"
            )
        current_batch.append(article)
        current_len += added
    if current_batch:
        batches.append(current_batch)
    return batches
def _extract_json(response: str) -> str:
if "```json" in response:
return response.split("```json")[1].split("```")[0].strip()
if "```" in response:
return response.split("```")[1].split("```")[0].strip()
return response.strip()
def _parse_chain(data: dict, event: FloodEvent, articles: list[dict]) -> CascadeChain:
    """Build a CascadeChain from the extraction LLM's parsed JSON reply.

    Common LLM sloppiness is tolerated:
    - a missing or ``null`` ``cascade_events`` yields no nodes;
    - ``null`` optional fields fall back to the same defaults as missing ones;
    - numeric ids are coerced to strings so they compare equal to ``parent_ids``;
    - a single parent id written as a bare value is wrapped in a list.

    Args:
        data: Parsed JSON reply; must be an object (dict).
        event: Source event; supplies trigger metadata and the fallback summary.
        articles: Articles the chain was extracted from; each must have ``url``.

    Raises:
        ValueError: ``data`` is not a dict, or model construction rejects a
            value (assumes the schema models raise a ``ValueError`` subclass,
            as pydantic's ``ValidationError`` does — confirm).
        KeyError: a node lacks ``id`` or ``description``, or an article lacks ``url``.
        TypeError: a ``cascade_events`` entry is not a JSON object.
    """
    if not isinstance(data, dict):
        raise ValueError(f"expected a JSON object, got {type(data).__name__}")

    def as_id_list(value) -> list[str]:
        # Accept ["E1", "E2"], a bare "E1", or null.
        if value is None:
            return []
        if isinstance(value, (str, int)):
            return [str(value)]
        return [str(p) for p in value]

    nodes = [
        CascadeNode(
            id=str(node_data["id"]),
            description=node_data["description"],
            domain=node_data.get("domain") or "unknown",
            severity=node_data.get("severity") or "medium",
            time_offset_hours=node_data.get("time_offset_hours"),
            mechanism=node_data.get("mechanism") or "",
            parent_ids=as_id_list(node_data.get("parent_ids")),
        )
        for node_data in data.get("cascade_events") or []
    ]
    return CascadeChain(
        event_id=event.event_id,
        trigger_summary=data.get("trigger_summary")
        or f"Flood in {event.location or event.country} on {event.start_date}",
        trigger_country=event.country,
        trigger_iso=event.iso,
        trigger_date=event.start_date,
        trigger_severity=_infer_severity(event),
        cascade_events=nodes,
        source_articles=[a["url"] for a in articles],
        extraction_date=date.today(),
    )
def _infer_severity(event: FloodEvent) -> str:
deaths = event.total_deaths or 0
affected = event.total_affected or 0
if deaths > 50 or affected > 100000:
return "critical"
if deaths > 10 or affected > 10000:
return "high"
if deaths > 0 or affected > 1000:
return "medium"
return "low"
def _normalize_domains(nodes: list[CascadeNode]) -> list[CascadeNode]:
    """Map domain labels onto the closed ``ALLOWED_DOMAINS`` taxonomy.

    Each node's label is whitespace-stripped and lower-cased, then resolved:
    1. already canonical → kept;
    2. a ``DOMAIN_REMAP`` key → rewritten, or dropped when mapped to None;
    3. otherwise the top-level prefix (text before the first ``/``) is matched
       against the canonical labels. The alphabetically first match is used;
       this is deterministic but arbitrary, so add a ``DOMAIN_REMAP`` entry
       for any label whose mapping matters;
    4. no match → the node is dropped.

    Kept nodes are mutated in place (``node.domain`` is set to the canonical
    label) and returned in input order.
    """
    # Sorted once so the prefix fallback is reproducible. Iterating the set
    # directly depends on per-process string hash randomisation.
    canonical = sorted(ALLOWED_DOMAINS)
    kept: list[CascadeNode] = []
    for node in nodes:
        label = node.domain.strip().lower()
        if label in ALLOWED_DOMAINS:
            target = label
        elif label in DOMAIN_REMAP:
            target = DOMAIN_REMAP[label]
        else:
            prefix = label.split("/", 1)[0]
            target = next((d for d in canonical if d.startswith(prefix + "/")), None)
        if target is None:
            continue
        node.domain = target
        kept.append(node)
    return kept
def _filter_scope(nodes: list[CascadeNode]) -> list[CascadeNode]:
"""Drop out-of-window and trivially-low nodes."""
kept = []
for node in nodes:
if (
node.time_offset_hours is not None
and node.time_offset_hours > MAX_TIME_OFFSET_HOURS
):
continue
kept.append(node)
return kept
def _semantic_dedupe(
    nodes: list[CascadeNode], jaccard_threshold: float = 0.4
) -> list[CascadeNode]:
    """Merge same-domain nodes whose descriptions overlap heavily.

    Within each domain, nodes are clustered greedily in input order. A node
    joins the first cluster whose first member's word set either:
    - has Jaccard similarity >= ``jaccard_threshold`` with the node's words, or
    - overlaps it so that >= 75% of the smaller word set is contained in the larger.
    Otherwise the node starts a new cluster.

    Each cluster collapses to one representative via `_merge_cluster`
    (longest description, highest severity, earliest time, union of parents).

    Afterwards, for every node's parents:
    - references are rewritten to their cluster representative;
    - self-loops and references to unknown ids are dropped;
    - duplicates are removed, keeping the first occurrence so the output is
      deterministic;
    - the list is re-pruned by `_prune_ancestors`, so the v0.2 issue #9 / B'
      no-grandparent rule still holds. (The union of two correctly-pruned
      parent sets can introduce new ancestor violations.)

    Output is grouped by domain in order of first appearance, then by cluster
    creation order.
    """
    by_domain: dict[str, list[CascadeNode]] = {}
    for n in nodes:
        by_domain.setdefault(n.domain, []).append(n)

    merged: list[CascadeNode] = []
    id_rewrite: dict[str, str] = {}
    for group in by_domain.values():
        # Each cluster is stored with its first member's word set, computed
        # once rather than on every comparison.
        clusters: list[tuple[set[str], list[CascadeNode]]] = []
        for node in group:
            wset = _wordset(node.description)
            for rep_words, members in clusters:
                # Merge if high jaccard OR one wordset is mostly a subset of the other
                if (
                    _jaccard(wset, rep_words) >= jaccard_threshold
                    or _subset_ratio(wset, rep_words) >= 0.75
                ):
                    members.append(node)
                    break
            else:
                clusters.append((wset, [node]))
        for _, members in clusters:
            rep = _merge_cluster(members)
            merged.append(rep)
            for n in members:
                id_rewrite[n.id] = rep.id

    # Rewrite parent refs to representatives, dropping self-loops introduced by
    # merging and refs to ids that no longer exist. dict.fromkeys dedupes while
    # keeping first-seen order; a set would make the order hash-dependent.
    valid_ids = {n.id for n in merged}
    for n in merged:
        rewritten = (id_rewrite.get(p, p) for p in n.parent_ids)
        n.parent_ids = list(
            dict.fromkeys(p for p in rewritten if p != n.id and p in valid_ids)
        )

    # Re-apply the no-grandparent rule across the merged graph. Cluster merging
    # can combine two correctly-pruned parent sets into one with an ancestor
    # violation (e.g. cluster A's parent X is an ancestor of cluster B's
    # parent Y once both are part of the merged node).
    for n in merged:
        n.parent_ids = _prune_ancestors(n.parent_ids, merged)
    return merged
def _merge_cluster(cluster: list[CascadeNode]) -> CascadeNode:
    """Collapse a cluster of near-duplicate nodes into one representative.

    The member with the longest description (the first such member on ties)
    is reused and mutated in place. It receives:
    - the cluster's highest severity (unknown labels rank like "medium");
    - the earliest non-null ``time_offset_hours``, if any member has one;
    - the sorted union of all members' ``parent_ids``.

    Ancestor pruning of that union is left to the caller (`_semantic_dedupe`).
    """
    # Longest description is taken as the most specific one.
    rep = max(cluster, key=lambda member: len(member.description))

    def severity_rank(label):
        return SEVERITY_RANK.get(label, 1)

    rep.severity = max((member.severity for member in cluster), key=severity_rank)

    known_times = [
        member.time_offset_hours
        for member in cluster
        if member.time_offset_hours is not None
    ]
    if known_times:
        rep.time_offset_hours = min(known_times)

    rep.parent_ids = sorted({pid for member in cluster for pid in member.parent_ids})
    return rep
def _wordset(text: str) -> set[str]:
    """Return the content words of ``text`` for description-overlap comparison.

    The text is lower-cased and split into runs of letters in any script.
    Accented place names such as "Zürich" therefore stay whole instead of
    being cut at the non-ASCII character. Words of two characters or fewer,
    and words in ``_STOPWORDS``, are discarded.
    """
    # [^\W\d_] = word characters minus digits and underscore, i.e. letters.
    words = re.findall(r"[^\W\d_]+", text.lower())
    return {w for w in words if len(w) > 2 and w not in _STOPWORDS}
def _jaccard(a: set[str], b: set[str]) -> float:
if not a or not b:
return 0.0
return len(a & b) / len(a | b)
def _subset_ratio(a: set[str], b: set[str]) -> float:
"""Fraction of the smaller wordset contained in the larger."""
if not a or not b:
return 0.0
small, large = (a, b) if len(a) <= len(b) else (b, a)
return len(small & large) / len(small)
def _prune_ancestors(parents: list[str], all_nodes: list[CascadeNode]) -> list[str]:
"""Drop any parent that is an ancestor of another listed parent.
Implements the v0.2 issue #9 / B' no-grandparent rule: when a node
declares multiple parents, only the most-direct ones are kept; transitive
ancestors are pruned because they are already implicit in the graph.
`all_nodes` is the full node set used to build the parent-of relation by
walking each node's ``parent_ids`` upward.
"""
if len(parents) <= 1:
return list(parents)
parent_map = {n.id: list(n.parent_ids) for n in all_nodes}
def ancestors_of(nid: str) -> set[str]:
seen: set[str] = set()
stack = list(parent_map.get(nid, []))
while stack:
cur = stack.pop()
if cur in seen:
continue
seen.add(cur)
stack.extend(parent_map.get(cur, []))
return seen
candidate_set = set(parents)
drop: set[str] = set()
for p in parents:
anc = ancestors_of(p)
for q in candidate_set:
if q != p and q in anc:
drop.add(q)
return [p for p in parents if p not in drop]
def _truncate_by_priority(nodes: list[CascadeNode], budget: int) -> list[CascadeNode]:
    """Keep a diverse, high-priority subset of at most ``budget`` nodes.

    Strategy:
    1. Keep the best-ranked node of each domain present in the input. Domains
       are taken in order of first appearance, stopping if the budget runs out.
    2. Fill the remaining slots by global rank.

    Rank, in order of precedence:
    - severity, descending (unknown labels rank like "medium");
    - parent count, descending. Multi-parent nodes carry more structural
      signal under the v0.2 issue #9 / B' AND-conjunction semantics;
    - ``time_offset_hours``, ascending, with unknown times last.

    Parent references to dropped nodes are removed from the kept nodes in
    place. No re-linking to surviving ancestors is attempted.
    """
    def sev_rank(n: CascadeNode) -> int:
        return SEVERITY_RANK.get(n.severity, 1)

    def fill_rank(n: CascadeNode) -> tuple:
        # 0 h (an immediate impact) is a real value and must sort first, so
        # test for None rather than falsiness.
        t = n.time_offset_hours
        return (
            -sev_rank(n),
            -len(n.parent_ids),
            float("inf") if t is None else t,
        )

    kept: list[CascadeNode] = []
    kept_set: set[str] = set()
    # Pass 1: best node per domain
    by_domain: dict[str, list[CascadeNode]] = {}
    for n in nodes:
        by_domain.setdefault(n.domain, []).append(n)
    for group in by_domain.values():
        if len(kept) >= budget:
            break
        best = min(group, key=fill_rank)
        kept.append(best)
        kept_set.add(best.id)
    # Pass 2: fill remaining by global priority
    for n in sorted(nodes, key=fill_rank):
        if len(kept) >= budget:
            break
        if n.id not in kept_set:
            kept.append(n)
            kept_set.add(n.id)
    # Drop dangling parent refs
    for n in kept:
        n.parent_ids = [p for p in n.parent_ids if p in kept_set]
    return kept
def _consolidate_with_llm(
    event: FloodEvent,
    nodes: list[CascadeNode],
    llm_client: LLMClient,
    config: dict,
) -> list[CascadeNode] | None:
    """Ask the LLM to merge a noisy multi-batch node list into one coherent chain.

    Steps:
    1. Send ``nodes`` (as JSON) through the ``consolidate_cascade`` prompt.
    2. Parse the returned ``cascade_events``, with the same null/id tolerance
       as `_parse_chain`. A bare JSON array is taken as the node list; entries
       that are not valid nodes are skipped with a warning.
    3. Re-apply domain normalisation and time-window filtering, and truncate
       to ``MAX_NODES`` if still above it, since the LLM may invent labels or
       exceed the budget.

    Returns:
        The consolidated nodes. Returns None when the reply cannot be parsed
        or yields no valid node; the caller then keeps the unconsolidated union.
    """
    nodes_payload = json.dumps(
        [
            {
                "id": n.id,
                "description": n.description,
                "domain": n.domain,
                "severity": n.severity,
                "time_offset_hours": n.time_offset_hours,
                "mechanism": n.mechanism,
                "parent_ids": n.parent_ids,
            }
            for n in nodes
        ],
        ensure_ascii=False,
        indent=2,
    )
    variables = {
        "event_id": event.event_id,
        "country": event.country,
        "location": event.location or event.country,
        "start_date": str(event.start_date),
        "n_input": len(nodes),
        "nodes_json": nodes_payload,
    }
    try:
        response = llm_client.call_with_config(
            prompt_key="consolidate_cascade",
            knowledge_key="extraction",
            variables=variables,
            config=config,
        )
        data = json.loads(_extract_json(response))
    except (json.JSONDecodeError, KeyError, ValueError) as e:
        logger.error(f"Consolidation parse failed for {event.event_id}: {e}")
        return None
    if isinstance(data, list):
        data = {"cascade_events": data}
    if not isinstance(data, dict):
        logger.error(
            f"Consolidation parse failed for {event.event_id}: "
            f"expected a JSON object, got {type(data).__name__}"
        )
        return None

    def as_id_list(value) -> list[str]:
        # Accept ["E1", "E2"], a bare "E1", or null.
        if value is None:
            return []
        if isinstance(value, (str, int)):
            return [str(value)]
        return [str(p) for p in value]

    consolidated: list[CascadeNode] = []
    for nd in data.get("cascade_events") or []:
        try:
            consolidated.append(
                CascadeNode(
                    id=str(nd["id"]),
                    description=nd["description"],
                    domain=nd.get("domain") or "unknown",
                    severity=nd.get("severity") or "medium",
                    time_offset_hours=nd.get("time_offset_hours"),
                    mechanism=nd.get("mechanism") or "",
                    parent_ids=as_id_list(nd.get("parent_ids")),
                )
            )
        except (KeyError, TypeError, ValueError) as e:
            # TypeError: the entry is not a JSON object (e.g. a bare string).
            logger.warning(f" skipping bad consolidated node: {e}")
    if not consolidated:
        return None
    # Re-validate domains in case LLM invented new labels during merge
    consolidated = _normalize_domains(consolidated)
    consolidated = _filter_scope(consolidated)
    if len(consolidated) > MAX_NODES:
        consolidated = _truncate_by_priority(consolidated, MAX_NODES)
    return consolidated
def _renumber(nodes: list[CascadeNode]) -> list[CascadeNode]:
"""Renumber ids sequentially as E1..En and rewrite parent_ids."""
id_map: dict[str, str] = {}
for i, node in enumerate(nodes, 1):
new_id = f"E{i}"
id_map[node.id] = new_id
for node in nodes:
node.id = id_map[node.id]
node.parent_ids = [id_map.get(p, p) for p in node.parent_ids if p in id_map]
return nodes