"""Extract structured cascade chains from news articles using LLM.""" from __future__ import annotations import json import logging import re from datetime import date from pathlib import Path from src.llm.client import LLMClient, load_config from src.models.schemas import CascadeChain, CascadeNode, FloodEvent logger = logging.getLogger(__name__) # Single-shot character budget. Scaled up for v0.2 issue #10 (Qwen3.5-9B, # 64K context) so that all real flood events fit in one LLM call; # `_extract_batched` only triggers for genuinely huge corpora and now invokes # `_consolidate_with_llm` to merge across batches before code-side dedupe. MAX_ARTICLES_TEXT_LEN = 80000 # Per v0.2 issue #10: budget raised 30 → 40 to give the cascade DAG more room # for cross-domain depth without sacrificing the major-impacts-only bar. MAX_NODES = 40 MAX_TIME_OFFSET_HOURS = 336 # 2 weeks ALLOWED_DOMAINS = { "infrastructure/power", "infrastructure/water", "infrastructure/transport", "infrastructure/communication", "health/casualties", "health/hospital_service", "health/disease_outbreak", "social/evacuation", "social/supply_shortage", "economy/business_damage", "economy/agriculture", "environment/contamination", } # Legacy / near-miss labels seen from v1 runs → canonical DOMAIN_REMAP = { "health/disease": "health/disease_outbreak", "health/hospital": "health/hospital_service", "health/emergency_services": "health/hospital_service", "emergency_services": "health/hospital_service", "government/emergency_services": "health/hospital_service", "social/displacement": "social/evacuation", "social/housing": "social/evacuation", "social/mental_health": None, # out of scope "economy/business": "economy/business_damage", "environment/pollution": "environment/contamination", "environment/ecology": "environment/contamination", "environment/weather": None, # weather is the trigger, not a cascade "education": None, "education/business": None, } _STOPWORDS = { "the", "a", "an", "of", "to", "in", "and", "for", "on", "due", "with", "at", "by", "from", "as", "is", "are", "was", "were", "be", "been", "have", "has", "had", "caused", "resulted", "led", "flood", "flooding", "floodwaters", "floodwater", "waters", "water", "event", "events", "impact", "impacts", "disruption", "caused", "affecting", "affected", "after", "following", "during", "this", "that", "these", "those", } SEVERITY_RANK = {"low": 0, "medium": 1, "high": 2, "critical": 3} def extract_cascade_chain( event: FloodEvent, articles: list[dict], llm_client: LLMClient, config: dict | None = None, ) -> CascadeChain | None: """Extract a cascade chain from scraped articles for a single event.""" if config is None: config = load_config() if not articles: logger.warning(f"No articles for {event.event_id}, skipping extraction") return None articles_text = _format_articles(articles) if len(articles_text) > MAX_ARTICLES_TEXT_LEN: chain = _extract_batched(event, articles, llm_client, config) else: chain = _extract_single(event, articles_text, articles, llm_client, config) if chain is None: return None n_raw = len(chain.cascade_events) nodes = _normalize_domains(chain.cascade_events) nodes = _filter_scope(nodes) nodes = _semantic_dedupe(nodes) n_prefilter = len(nodes) if len(nodes) > MAX_NODES: logger.info( f" {event.event_id}: {n_prefilter} nodes → " f"diversified truncation to {MAX_NODES}" ) nodes = _truncate_by_priority(nodes, MAX_NODES) # Final renumber for clean sequential ids after all merges nodes = _renumber(nodes) logger.info( f" {event.event_id}: {n_raw} raw → {n_prefilter} after scope/dedupe → " f"{len(nodes)} final" ) chain.cascade_events = nodes return chain def _extract_single( event: FloodEvent, articles_text: str, articles: list[dict], llm_client: LLMClient, config: dict, ) -> CascadeChain | None: variables = { "event_id": event.event_id, "country": event.country, "location": event.location or event.country, "start_date": str(event.start_date), "origin": event.origin or "Unknown", "articles_text": articles_text, } response = llm_client.call_with_config( prompt_key="extract_cascade", knowledge_key="extraction", variables=variables, config=config, ) try: json_str = _extract_json(response) data = json.loads(json_str) return _parse_chain(data, event, articles) except (json.JSONDecodeError, KeyError, ValueError) as e: logger.error(f"Failed to parse cascade chain for {event.event_id}: {e}") return None def _extract_batched( event: FloodEvent, articles: list[dict], llm_client: LLMClient, config: dict, ) -> CascadeChain | None: """Run extraction batch by batch, then LLM-consolidate across batches. Code-side renumbering keeps the id space disjoint per batch but cannot decide whether two same-domain nodes from different batches actually refer to the same real-world cascade — that requires LLM judgement. Issue #10 wires `_consolidate_with_llm` into this path so the batched fallback gets the same cross-batch semantic merging that single-shot extraction does implicitly. With `MAX_ARTICLES_TEXT_LEN = 80000` only truly huge corpora hit this branch, so the extra LLM call is rare. """ batches = _batch_articles(articles) all_nodes: list[CascadeNode] = [] node_id_counter = 0 for batch in batches: batch_text = _format_articles(batch) chain = _extract_single(event, batch_text, batch, llm_client, config) if not chain: continue id_map: dict[str, str] = {} for node in chain.cascade_events: node_id_counter += 1 new_id = f"E{node_id_counter}" id_map[node.id] = new_id node.id = new_id for node in chain.cascade_events: node.parent_ids = [id_map.get(pid, pid) for pid in node.parent_ids] all_nodes.extend(chain.cascade_events) if not all_nodes: return None if len(batches) > 1: logger.info( f" {event.event_id}: consolidating {len(all_nodes)} nodes from " f"{len(batches)} batches via LLM" ) consolidated = _consolidate_with_llm(event, all_nodes, llm_client, config) if consolidated is not None: all_nodes = consolidated else: logger.warning( f" {event.event_id}: LLM consolidation failed; falling back to " "unconsolidated batch union (code-side dedupe will still run)" ) return CascadeChain( event_id=event.event_id, trigger_summary=f"Flood in {event.location or event.country} on {event.start_date}", trigger_country=event.country, trigger_iso=event.iso, trigger_date=event.start_date, trigger_severity=_infer_severity(event), cascade_events=all_nodes, source_articles=[a["url"] for a in articles], extraction_date=date.today(), ) def save_cascade_chain(chain: CascadeChain, config: dict | None = None) -> Path: if config is None: config = load_config() out_dir = Path(config["paths"]["cascade_chains_dir"]) out_dir.mkdir(parents=True, exist_ok=True) path = out_dir / f"{chain.event_id}.json" path.write_text( json.dumps(chain.model_dump(mode="json"), indent=2, ensure_ascii=False) ) return path def _format_articles(articles: list[dict]) -> str: parts = [] for i, a in enumerate(articles, 1): parts.append( f"--- Article {i} ---\n" f"Title: {a.get('title', 'N/A')}\n" f"Date: {a.get('date', 'N/A')}\n" f"Source: {a.get('source', 'N/A')}\n" f"Domain: {a.get('domain', 'N/A')}\n\n" f"{a.get('text', '')}\n" ) return "\n".join(parts) def _batch_articles(articles: list[dict]) -> list[list[dict]]: batches = [] current_batch: list[dict] = [] current_len = 0 for article in articles: text_len = len(_format_articles([article])) if current_len + text_len > MAX_ARTICLES_TEXT_LEN and current_batch: batches.append(current_batch) current_batch = [] current_len = 0 current_batch.append(article) current_len += text_len if current_batch: batches.append(current_batch) return batches def _extract_json(response: str) -> str: if "```json" in response: return response.split("```json")[1].split("```")[0].strip() if "```" in response: return response.split("```")[1].split("```")[0].strip() return response.strip() def _parse_chain(data: dict, event: FloodEvent, articles: list[dict]) -> CascadeChain: nodes = [] for node_data in data.get("cascade_events", []): nodes.append( CascadeNode( id=node_data["id"], description=node_data["description"], domain=node_data.get("domain", "unknown"), severity=node_data.get("severity", "medium"), time_offset_hours=node_data.get("time_offset_hours"), mechanism=node_data.get("mechanism", ""), parent_ids=node_data.get("parent_ids", []), ) ) return CascadeChain( event_id=event.event_id, trigger_summary=data.get( "trigger_summary", f"Flood in {event.location or event.country} on {event.start_date}", ), trigger_country=event.country, trigger_iso=event.iso, trigger_date=event.start_date, trigger_severity=_infer_severity(event), cascade_events=nodes, source_articles=[a["url"] for a in articles], extraction_date=date.today(), ) def _infer_severity(event: FloodEvent) -> str: deaths = event.total_deaths or 0 affected = event.total_affected or 0 if deaths > 50 or affected > 100000: return "critical" if deaths > 10 or affected > 10000: return "high" if deaths > 0 or affected > 1000: return "medium" return "low" def _normalize_domains(nodes: list[CascadeNode]) -> list[CascadeNode]: """Map legacy domain labels to the closed taxonomy; drop nodes with no mapping.""" kept: list[CascadeNode] = [] for node in nodes: d = node.domain.strip() if d in ALLOWED_DOMAINS: kept.append(node) continue if d in DOMAIN_REMAP: target = DOMAIN_REMAP[d] if target is None: continue node.domain = target kept.append(node) continue # Unknown label: try top-level prefix match prefix = d.split("/", 1)[0] fallback = next( (x for x in ALLOWED_DOMAINS if x.startswith(prefix + "/")), None, ) if fallback: node.domain = fallback kept.append(node) # else: drop return kept def _filter_scope(nodes: list[CascadeNode]) -> list[CascadeNode]: """Drop out-of-window and trivially-low nodes.""" kept = [] for node in nodes: if ( node.time_offset_hours is not None and node.time_offset_hours > MAX_TIME_OFFSET_HOURS ): continue kept.append(node) return kept def _semantic_dedupe( nodes: list[CascadeNode], jaccard_threshold: float = 0.4 ) -> list[CascadeNode]: """Merge nodes in the same domain with high word-overlap descriptions. Cluster-merge keeps the longer description, picks higher severity, and takes the union of parent_ids as a starting point. The merged graph then has each node's parent_ids re-pruned by `_prune_ancestors` so the v0.2 issue #9 / B' no-grandparent rule still holds after structural merging (the union of two correctly-pruned sets can introduce new ancestor violations). """ by_domain: dict[str, list[CascadeNode]] = {} for n in nodes: by_domain.setdefault(n.domain, []).append(n) merged: list[CascadeNode] = [] id_rewrite: dict[str, str] = {} for domain, group in by_domain.items(): clusters: list[list[CascadeNode]] = [] for node in group: wset = _wordset(node.description) placed = False for cluster in clusters: rep_words = _wordset(cluster[0].description) # Merge if high jaccard OR one wordset is mostly subset of the other j = _jaccard(wset, rep_words) subset_ratio = _subset_ratio(wset, rep_words) if j >= jaccard_threshold or subset_ratio >= 0.75: cluster.append(node) placed = True break if not placed: clusters.append([node]) for cluster in clusters: rep = _merge_cluster(cluster) merged.append(rep) for n in cluster: id_rewrite[n.id] = rep.id # Rewrite parent refs and drop self-loops introduced by merging for n in merged: n.parent_ids = list({ id_rewrite.get(p, p) for p in n.parent_ids if id_rewrite.get(p, p) != n.id }) # Drop parent refs that no longer exist valid_ids = {n.id for n in merged} for n in merged: n.parent_ids = [p for p in n.parent_ids if p in valid_ids] # Re-apply no-grandparent rule across the merged graph. Cluster merging can # combine two correctly-pruned parent sets into one that now has an # ancestor violation (e.g. cluster A's parent X is an ancestor of cluster # B's parent Y once both are part of the merged node). for n in merged: n.parent_ids = _prune_ancestors(n.parent_ids, merged) return merged def _merge_cluster(cluster: list[CascadeNode]) -> CascadeNode: """Pick the representative from a cluster of near-duplicate nodes.""" # longest description = most specific rep = max(cluster, key=lambda n: len(n.description)) # highest severity rep.severity = max( (n.severity for n in cluster), key=lambda s: SEVERITY_RANK.get(s, 1), ) # earliest non-null time times = [n.time_offset_hours for n in cluster if n.time_offset_hours is not None] if times: rep.time_offset_hours = min(times) # union parents — _semantic_dedupe runs _prune_ancestors over the merged # graph afterward to enforce the no-grandparent rule. parents: set[str] = set() for n in cluster: parents.update(n.parent_ids) rep.parent_ids = sorted(parents) return rep def _wordset(text: str) -> set[str]: words = re.findall(r"[a-zA-Z]+", text.lower()) return {w for w in words if len(w) > 2 and w not in _STOPWORDS} def _jaccard(a: set[str], b: set[str]) -> float: if not a or not b: return 0.0 return len(a & b) / len(a | b) def _subset_ratio(a: set[str], b: set[str]) -> float: """Fraction of the smaller wordset contained in the larger.""" if not a or not b: return 0.0 small, large = (a, b) if len(a) <= len(b) else (b, a) return len(small & large) / len(small) def _prune_ancestors(parents: list[str], all_nodes: list[CascadeNode]) -> list[str]: """Drop any parent that is an ancestor of another listed parent. Implements the v0.2 issue #9 / B' no-grandparent rule: when a node declares multiple parents, only the most-direct ones are kept; transitive ancestors are pruned because they are already implicit in the graph. `all_nodes` is the full node set used to build the parent-of relation by walking each node's ``parent_ids`` upward. """ if len(parents) <= 1: return list(parents) parent_map = {n.id: list(n.parent_ids) for n in all_nodes} def ancestors_of(nid: str) -> set[str]: seen: set[str] = set() stack = list(parent_map.get(nid, [])) while stack: cur = stack.pop() if cur in seen: continue seen.add(cur) stack.extend(parent_map.get(cur, [])) return seen candidate_set = set(parents) drop: set[str] = set() for p in parents: anc = ancestors_of(p) for q in candidate_set: if q != p and q in anc: drop.add(q) return [p for p in parents if p not in drop] def _truncate_by_priority(nodes: list[CascadeNode], budget: int) -> list[CascadeNode]: """Keep a diverse, high-priority subset within budget. Strategy: 1. Ensure at least 1 node per domain present in the input 2. Fill remaining slots ranked by (severity, parent_count, early time) — multi-parent nodes rank higher because they carry more structural signal under the v0.2 issue #9 / B' AND-conjunction semantics. """ def sev_rank(n: CascadeNode) -> int: return SEVERITY_RANK.get(n.severity, 1) def fill_rank(n: CascadeNode) -> tuple: return ( -sev_rank(n), -len(n.parent_ids), n.time_offset_hours or 9999, ) kept: list[CascadeNode] = [] kept_set: set[str] = set() # Pass 1: best node per domain by_domain: dict[str, list[CascadeNode]] = {} for n in nodes: by_domain.setdefault(n.domain, []).append(n) for dom, group in by_domain.items(): if len(kept) >= budget: break best = min(group, key=fill_rank) kept.append(best) kept_set.add(best.id) # Pass 2: fill remaining by global priority for n in sorted(nodes, key=fill_rank): if len(kept) >= budget: break if n.id not in kept_set: kept.append(n) kept_set.add(n.id) # Rewrite dangling parent refs for n in kept: n.parent_ids = [p for p in n.parent_ids if p in kept_set] return kept def _consolidate_with_llm( event: FloodEvent, nodes: list[CascadeNode], llm_client: LLMClient, config: dict, ) -> list[CascadeNode] | None: """Call the consolidate prompt to merge a noisy chain down to ≤30 nodes.""" nodes_payload = json.dumps( [ { "id": n.id, "description": n.description, "domain": n.domain, "severity": n.severity, "time_offset_hours": n.time_offset_hours, "mechanism": n.mechanism, "parent_ids": n.parent_ids, } for n in nodes ], ensure_ascii=False, indent=2, ) variables = { "event_id": event.event_id, "country": event.country, "location": event.location or event.country, "start_date": str(event.start_date), "n_input": len(nodes), "nodes_json": nodes_payload, } try: response = llm_client.call_with_config( prompt_key="consolidate_cascade", knowledge_key="extraction", variables=variables, config=config, ) data = json.loads(_extract_json(response)) except (json.JSONDecodeError, KeyError, ValueError) as e: logger.error(f"Consolidation parse failed for {event.event_id}: {e}") return None consolidated: list[CascadeNode] = [] for nd in data.get("cascade_events", []): try: consolidated.append( CascadeNode( id=nd["id"], description=nd["description"], domain=nd.get("domain", "unknown"), severity=nd.get("severity", "medium"), time_offset_hours=nd.get("time_offset_hours"), mechanism=nd.get("mechanism", ""), parent_ids=nd.get("parent_ids", []), ) ) except (KeyError, ValueError) as e: logger.warning(f" skipping bad consolidated node: {e}") if not consolidated: return None # Re-validate domains in case LLM invented new labels during merge consolidated = _normalize_domains(consolidated) consolidated = _filter_scope(consolidated) if len(consolidated) > MAX_NODES: consolidated = _truncate_by_priority(consolidated, MAX_NODES) return consolidated def _renumber(nodes: list[CascadeNode]) -> list[CascadeNode]: """Renumber ids sequentially as E1..En and rewrite parent_ids.""" id_map: dict[str, str] = {} for i, node in enumerate(nodes, 1): new_id = f"E{i}" id_map[node.id] = new_id for node in nodes: node.id = id_map[node.id] node.parent_ids = [id_map.get(p, p) for p in node.parent_ids if p in id_map] return nodes