Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
| """Extract structured cascade chains from news articles using LLM.""" | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import re | |
| from datetime import date | |
| from pathlib import Path | |
| from src.llm.client import LLMClient, load_config | |
| from src.models.schemas import CascadeChain, CascadeNode, FloodEvent | |
| logger = logging.getLogger(__name__) | |
# Single-shot character budget for the formatted article text. Scaled up for
# v0.2 issue #10 (Qwen3.5-9B, 64K context) so that all real flood events fit
# in one LLM call; `_extract_batched` only triggers for genuinely huge corpora
# and now invokes `_consolidate_with_llm` to merge across batches before the
# code-side dedupe runs.
MAX_ARTICLES_TEXT_LEN = 80000

# Upper bound on the number of nodes kept per chain. Per v0.2 issue #10 the
# budget was raised 30 → 40 to give the cascade DAG more room for cross-domain
# depth without sacrificing the major-impacts-only bar.
MAX_NODES = 40

# Nodes whose time offset exceeds this many hours are dropped by
# `_filter_scope`.
MAX_TIME_OFFSET_HOURS = 336  # 2 weeks

# Closed taxonomy: every node that survives `_normalize_domains` carries one
# of these labels.
ALLOWED_DOMAINS = {
    "infrastructure/power",
    "infrastructure/water",
    "infrastructure/transport",
    "infrastructure/communication",
    "health/casualties",
    "health/hospital_service",
    "health/disease_outbreak",
    "social/evacuation",
    "social/supply_shortage",
    "economy/business_damage",
    "economy/agriculture",
    "environment/contamination",
}

# Legacy / near-miss labels seen from v1 runs → canonical label.
# A value of None means nodes carrying that label are dropped.
DOMAIN_REMAP = {
    "health/disease": "health/disease_outbreak",
    "health/hospital": "health/hospital_service",
    "health/emergency_services": "health/hospital_service",
    "emergency_services": "health/hospital_service",
    "government/emergency_services": "health/hospital_service",
    "social/displacement": "social/evacuation",
    "social/housing": "social/evacuation",
    "social/mental_health": None,  # out of scope
    "economy/business": "economy/business_damage",
    "environment/pollution": "environment/contamination",
    "environment/ecology": "environment/contamination",
    "environment/weather": None,  # weather is the trigger, not a cascade
    "education": None,
    "education/business": None,
}

# Words ignored by `_wordset` when comparing node descriptions. Besides
# ordinary function words this includes flood vocabulary, which would
# otherwise count as overlap between unrelated descriptions.
_STOPWORDS = {
    "the", "a", "an", "of", "to", "in", "and", "for", "on", "due", "with",
    "at", "by", "from", "as", "is", "are", "was", "were", "be", "been",
    "have", "has", "had", "caused", "resulted", "led", "flood", "flooding",
    "floodwaters", "floodwater", "waters", "water", "event", "events",
    "impact", "impacts", "disruption", "caused", "affecting", "affected",
    "after", "following", "during", "this", "that", "these", "those",
}

# Ordering used to compare severity labels; code that looks up an unknown
# label falls back to rank 1 (same as "medium").
SEVERITY_RANK = {"low": 0, "medium": 1, "high": 2, "critical": 3}
def extract_cascade_chain(
    event: FloodEvent,
    articles: list[dict],
    llm_client: LLMClient,
    config: dict | None = None,
) -> CascadeChain | None:
    """Build the cleaned cascade chain for one flood event.

    The articles are formatted into a single text; if that text fits within
    ``MAX_ARTICLES_TEXT_LEN`` one LLM extraction is run, otherwise the batched
    path is used. The resulting nodes are then domain-normalised,
    scope-filtered, de-duplicated, truncated to ``MAX_NODES`` when necessary
    and renumbered sequentially.

    Args:
        event: The flood event the articles describe.
        articles: Scraped article dicts.
        llm_client: Client used for the LLM calls.
        config: Configuration mapping; loaded with ``load_config()`` when
            omitted.

    Returns:
        The chain with its ``cascade_events`` replaced by the cleaned nodes,
        or ``None`` when there are no articles or extraction failed.
    """
    config = load_config() if config is None else config

    if not articles:
        logger.warning(f"No articles for {event.event_id}, skipping extraction")
        return None

    formatted = _format_articles(articles)
    if len(formatted) <= MAX_ARTICLES_TEXT_LEN:
        chain = _extract_single(event, formatted, articles, llm_client, config)
    else:
        chain = _extract_batched(event, articles, llm_client, config)
    if chain is None:
        return None

    n_raw = len(chain.cascade_events)

    # Post-processing pipeline: taxonomy → time window → near-duplicate merge.
    nodes = _semantic_dedupe(_filter_scope(_normalize_domains(chain.cascade_events)))
    n_prefilter = len(nodes)

    over_budget = n_prefilter > MAX_NODES
    if over_budget:
        logger.info(
            f" {event.event_id}: {n_prefilter} nodes → "
            f"diversified truncation to {MAX_NODES}"
        )
        nodes = _truncate_by_priority(nodes, MAX_NODES)

    # Ids are made sequential only once every merge/drop step has finished.
    nodes = _renumber(nodes)
    logger.info(
        f" {event.event_id}: {n_raw} raw → {n_prefilter} after scope/dedupe → "
        f"{len(nodes)} final"
    )

    chain.cascade_events = nodes
    return chain
def _extract_single(
    event: FloodEvent,
    articles_text: str,
    articles: list[dict],
    llm_client: LLMClient,
    config: dict,
) -> CascadeChain | None:
    """Run one extraction LLM call and parse its reply into a chain.

    Args:
        event: Flood event supplying the prompt variables.
        articles_text: Pre-formatted article text inserted into the prompt.
        articles: The article dicts behind ``articles_text``; forwarded to
            ``_parse_chain``.
        llm_client: Client used to issue the ``extract_cascade`` prompt.
        config: Configuration mapping passed through to the client.

    Returns:
        The parsed chain, or ``None`` (after logging an error) when the reply
        is not valid JSON, is not a JSON object, or has malformed nodes.
    """
    variables = {
        "event_id": event.event_id,
        "country": event.country,
        "location": event.location or event.country,
        "start_date": str(event.start_date),
        "origin": event.origin or "Unknown",
        "articles_text": articles_text,
    }
    response = llm_client.call_with_config(
        prompt_key="extract_cascade",
        knowledge_key="extraction",
        variables=variables,
        config=config,
    )
    try:
        data = json.loads(_extract_json(response))
        # A reply such as a bare JSON array would otherwise fail later with an
        # AttributeError that bypasses the handler below.
        if not isinstance(data, dict):
            raise ValueError(
                f"expected a JSON object, got {type(data).__name__}"
            )
        return _parse_chain(data, event, articles)
    except (json.JSONDecodeError, KeyError, TypeError, ValueError) as e:
        # TypeError covers node entries that are not JSON objects.
        logger.error(f"Failed to parse cascade chain for {event.event_id}: {e}")
        return None
def _extract_batched(
    event: FloodEvent,
    articles: list[dict],
    llm_client: LLMClient,
    config: dict,
) -> CascadeChain | None:
    """Run extraction batch by batch, then LLM-consolidate across batches.

    Code-side renumbering keeps the id space disjoint per batch but cannot
    decide whether two same-domain nodes from different batches actually
    refer to the same real-world cascade — that requires LLM judgement.
    Issue #10 wires `_consolidate_with_llm` into this path so the batched
    fallback gets the same cross-batch semantic merging that single-shot
    extraction does implicitly. With `MAX_ARTICLES_TEXT_LEN = 80000` only
    truly huge corpora hit this branch, so the extra LLM call is rare.

    Args:
        event: Flood event being processed.
        articles: All article dicts for the event.
        llm_client: Client used for extraction and consolidation calls.
        config: Configuration mapping passed through to the client.

    Returns:
        A chain holding the (possibly consolidated) union of all batch nodes,
        or ``None`` when no batch produced any node.
    """
    batches = _batch_articles(articles)
    all_nodes: list[CascadeNode] = []
    node_id_counter = 0

    for batch in batches:
        batch_text = _format_articles(batch)
        chain = _extract_single(event, batch_text, batch, llm_client, config)
        if not chain:
            continue

        # Give this batch's nodes globally unique ids.
        id_map: dict[str, str] = {}
        for node in chain.cascade_events:
            node_id_counter += 1
            new_id = f"E{node_id_counter}"
            id_map[node.id] = new_id
            node.id = new_id

        # Parent refs are batch-local. A ref the batch never defined must be
        # dropped: if kept verbatim (e.g. "E3"), it could coincide with an id
        # assigned to another batch and create a spurious cross-batch edge.
        for node in chain.cascade_events:
            node.parent_ids = [
                id_map[pid] for pid in node.parent_ids if pid in id_map
            ]
        all_nodes.extend(chain.cascade_events)

    if not all_nodes:
        return None

    if len(batches) > 1:
        logger.info(
            f" {event.event_id}: consolidating {len(all_nodes)} nodes from "
            f"{len(batches)} batches via LLM"
        )
        consolidated = _consolidate_with_llm(event, all_nodes, llm_client, config)
        if consolidated is not None:
            all_nodes = consolidated
        else:
            logger.warning(
                f" {event.event_id}: LLM consolidation failed; falling back to "
                "unconsolidated batch union (code-side dedupe will still run)"
            )

    return CascadeChain(
        event_id=event.event_id,
        trigger_summary=f"Flood in {event.location or event.country} on {event.start_date}",
        trigger_country=event.country,
        trigger_iso=event.iso,
        trigger_date=event.start_date,
        trigger_severity=_infer_severity(event),
        cascade_events=all_nodes,
        source_articles=[a["url"] for a in articles],
        extraction_date=date.today(),
    )
def save_cascade_chain(chain: CascadeChain, config: dict | None = None) -> Path:
    """Write ``chain`` as pretty-printed JSON and return the file path.

    The file is named ``<event_id>.json`` inside
    ``config["paths"]["cascade_chains_dir"]``; the directory is created if it
    does not exist.

    Args:
        chain: Chain to serialise via ``model_dump(mode="json")``.
        config: Configuration mapping; loaded with ``load_config()`` when
            omitted.

    Returns:
        Path of the written file.
    """
    if config is None:
        config = load_config()
    out_dir = Path(config["paths"]["cascade_chains_dir"])
    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / f"{chain.event_id}.json"
    payload = json.dumps(chain.model_dump(mode="json"), indent=2, ensure_ascii=False)
    # ensure_ascii=False leaves non-ASCII characters in the payload, so the
    # encoding must be explicit rather than the platform's locale default.
    path.write_text(payload, encoding="utf-8")
    return path
| def _format_articles(articles: list[dict]) -> str: | |
| parts = [] | |
| for i, a in enumerate(articles, 1): | |
| parts.append( | |
| f"--- Article {i} ---\n" | |
| f"Title: {a.get('title', 'N/A')}\n" | |
| f"Date: {a.get('date', 'N/A')}\n" | |
| f"Source: {a.get('source', 'N/A')}\n" | |
| f"Domain: {a.get('domain', 'N/A')}\n\n" | |
| f"{a.get('text', '')}\n" | |
| ) | |
| return "\n".join(parts) | |
def _batch_articles(articles: list[dict]) -> list[list[dict]]:
    """Greedily group articles into batches within the character budget.

    Articles are taken in order. A batch is closed as soon as adding the next
    article's formatted length would push it past ``MAX_ARTICLES_TEXT_LEN``.
    An article that alone exceeds the budget still gets a batch of its own.
    """
    groups: list[list[dict]] = []
    pending: list[dict] = []
    pending_chars = 0

    for item in articles:
        size = len(_format_articles([item]))
        would_overflow = pending_chars + size > MAX_ARTICLES_TEXT_LEN
        if pending and would_overflow:
            groups.append(pending)
            pending, pending_chars = [], 0
        pending.append(item)
        pending_chars += size

    if pending:
        groups.append(pending)
    return groups
| def _extract_json(response: str) -> str: | |
| if "```json" in response: | |
| return response.split("```json")[1].split("```")[0].strip() | |
| if "```" in response: | |
| return response.split("```")[1].split("```")[0].strip() | |
| return response.strip() | |
def _parse_chain(data: dict, event: FloodEvent, articles: list[dict]) -> CascadeChain:
    """Convert decoded LLM JSON into a ``CascadeChain``.

    Args:
        data: Decoded JSON object; nodes are read from ``cascade_events``.
        event: Flood event supplying the trigger metadata.
        articles: Article dicts; their ``url`` values become
            ``source_articles``.

    Returns:
        The constructed chain.

    Raises:
        ValueError: If ``data`` or one of its node entries is not a JSON
            object.
        KeyError: If a node lacks ``id`` or ``description``, or an article
            lacks ``url``.
    """
    if not isinstance(data, dict):
        raise ValueError(f"expected a JSON object, got {type(data).__name__}")

    nodes = []
    # `or []` tolerates an explicit JSON null for the node list.
    for node_data in data.get("cascade_events") or []:
        if not isinstance(node_data, dict):
            raise ValueError(
                f"cascade event must be an object, got {type(node_data).__name__}"
            )
        nodes.append(
            CascadeNode(
                id=node_data["id"],
                description=node_data["description"],
                domain=node_data.get("domain", "unknown"),
                severity=node_data.get("severity", "medium"),
                time_offset_hours=node_data.get("time_offset_hours"),
                mechanism=node_data.get("mechanism", ""),
                # A null parent list is treated as "no parents"; the rest of
                # the pipeline iterates parent_ids and expects a list.
                parent_ids=node_data.get("parent_ids") or [],
            )
        )
    return CascadeChain(
        event_id=event.event_id,
        trigger_summary=data.get(
            "trigger_summary",
            f"Flood in {event.location or event.country} on {event.start_date}",
        ),
        trigger_country=event.country,
        trigger_iso=event.iso,
        trigger_date=event.start_date,
        trigger_severity=_infer_severity(event),
        cascade_events=nodes,
        source_articles=[a["url"] for a in articles],
        extraction_date=date.today(),
    )
| def _infer_severity(event: FloodEvent) -> str: | |
| deaths = event.total_deaths or 0 | |
| affected = event.total_affected or 0 | |
| if deaths > 50 or affected > 100000: | |
| return "critical" | |
| if deaths > 10 or affected > 10000: | |
| return "high" | |
| if deaths > 0 or affected > 1000: | |
| return "medium" | |
| return "low" | |
def _normalize_domains(nodes: list[CascadeNode]) -> list[CascadeNode]:
    """Map domain labels onto the closed taxonomy; drop nodes with no mapping.

    Resolution per node, using the whitespace-stripped label:
      1. Already in ``ALLOWED_DOMAINS`` → kept.
      2. Listed in ``DOMAIN_REMAP`` → relabelled, or dropped when the remap
         target is ``None``.
      3. Otherwise the alphabetically first allowed domain sharing the
         label's top-level prefix is used; if none exists the node is dropped.

    Nodes are modified in place. The returned list preserves input order.
    """
    # Sorted so the prefix fallback picks the same domain on every run;
    # iterating the set directly depends on string hash randomization.
    fallback_pool = sorted(ALLOWED_DOMAINS)

    kept: list[CascadeNode] = []
    for node in nodes:
        label = node.domain.strip()

        if label in ALLOWED_DOMAINS:
            # Store the stripped form so later grouping by domain does not
            # treat " health/casualties" and "health/casualties" as different.
            if node.domain != label:
                node.domain = label
            kept.append(node)
            continue

        if label in DOMAIN_REMAP:
            target = DOMAIN_REMAP[label]
            if target is not None:
                node.domain = target
                kept.append(node)
            continue

        # Unknown label: try top-level prefix match.
        prefix = label.split("/", 1)[0]
        fallback = next(
            (known for known in fallback_pool if known.startswith(prefix + "/")),
            None,
        )
        if fallback:
            node.domain = fallback
            kept.append(node)
        # else: drop
    return kept
def _filter_scope(nodes: list[CascadeNode]) -> list[CascadeNode]:
    """Drop nodes that fall outside the time window.

    A node is removed only when it has a time offset and that offset exceeds
    ``MAX_TIME_OFFSET_HOURS``; nodes without an offset are kept. No
    severity-based filtering happens here.
    """
    return [
        node
        for node in nodes
        if node.time_offset_hours is None
        or not (node.time_offset_hours > MAX_TIME_OFFSET_HOURS)
    ]
def _semantic_dedupe(
    nodes: list[CascadeNode], jaccard_threshold: float = 0.4
) -> list[CascadeNode]:
    """Merge nodes in the same domain with high word-overlap descriptions.

    Within each domain a node joins the first existing cluster whose first
    member's description is similar enough (Jaccard ≥ ``jaccard_threshold``
    or subset ratio ≥ 0.75); otherwise it starts a new cluster. Each cluster
    is collapsed by `_merge_cluster`, which keeps the longer description,
    picks the higher severity and takes the union of parent_ids as a starting
    point. The merged graph then has each node's parent_ids re-pruned by
    `_prune_ancestors` so the v0.2 issue #9 / B' no-grandparent rule still
    holds after structural merging (the union of two correctly-pruned sets
    can introduce new ancestor violations).

    Args:
        nodes: Nodes to de-duplicate; representatives are modified in place.
        jaccard_threshold: Minimum Jaccard similarity that triggers a merge.

    Returns:
        One representative node per cluster, with parent references rewritten
        to surviving ids in a deterministic order.
    """
    by_domain: dict[str, list[CascadeNode]] = {}
    for n in nodes:
        by_domain.setdefault(n.domain, []).append(n)

    merged: list[CascadeNode] = []
    id_rewrite: dict[str, str] = {}
    for group in by_domain.values():
        clusters: list[list[CascadeNode]] = []
        for node in group:
            wset = _wordset(node.description)
            placed = False
            for cluster in clusters:
                rep_words = _wordset(cluster[0].description)
                # Merge if high jaccard OR one wordset is mostly a subset of
                # the other.
                j = _jaccard(wset, rep_words)
                subset_ratio = _subset_ratio(wset, rep_words)
                if j >= jaccard_threshold or subset_ratio >= 0.75:
                    cluster.append(node)
                    placed = True
                    break
            if not placed:
                clusters.append([node])

        for cluster in clusters:
            rep = _merge_cluster(cluster)
            merged.append(rep)
            for n in cluster:
                id_rewrite[n.id] = rep.id

    # Rewrite parent refs to representative ids, then drop duplicates,
    # self-loops introduced by merging, and refs to nodes that no longer
    # exist. dict.fromkeys de-duplicates while keeping first-seen order; a
    # set would make the order depend on hash randomization.
    valid_ids = {n.id for n in merged}
    for n in merged:
        rewritten = dict.fromkeys(id_rewrite.get(p, p) for p in n.parent_ids)
        n.parent_ids = [p for p in rewritten if p != n.id and p in valid_ids]

    # Re-apply no-grandparent rule across the merged graph. Cluster merging can
    # combine two correctly-pruned parent sets into one that now has an
    # ancestor violation (e.g. cluster A's parent X is an ancestor of cluster
    # B's parent Y once both are part of the merged node).
    for n in merged:
        n.parent_ids = _prune_ancestors(n.parent_ids, merged)
    return merged
def _merge_cluster(cluster: list[CascadeNode]) -> CascadeNode:
    """Collapse near-duplicate nodes into one representative.

    The member with the longest description is chosen and updated in place
    with the cluster's highest severity, its earliest known time offset (if
    any member has one) and the sorted union of all members' parent ids.
    """

    def severity_rank(level: str) -> int:
        return SEVERITY_RANK.get(level, 1)

    # Longest description is taken as the most specific.
    winner = max(cluster, key=lambda member: len(member.description))

    top_severity = max((member.severity for member in cluster), key=severity_rank)
    known_offsets = [
        member.time_offset_hours
        for member in cluster
        if member.time_offset_hours is not None
    ]
    # Union of parents; _semantic_dedupe runs _prune_ancestors over the merged
    # graph afterward to enforce the no-grandparent rule.
    all_parents: set[str] = set()
    for member in cluster:
        all_parents |= set(member.parent_ids)

    winner.severity = top_severity
    if known_offsets:
        winner.time_offset_hours = min(known_offsets)
    winner.parent_ids = sorted(all_parents)
    return winner
def _wordset(text: str) -> set[str]:
    """Return the distinct lowercase content words of ``text``.

    Words are maximal runs of ASCII letters. Words of one or two letters and
    anything in ``_STOPWORDS`` are left out.
    """
    content_words: set[str] = set()
    for token in re.findall(r"[a-zA-Z]+", text.lower()):
        if len(token) <= 2 or token in _STOPWORDS:
            continue
        content_words.add(token)
    return content_words
| def _jaccard(a: set[str], b: set[str]) -> float: | |
| if not a or not b: | |
| return 0.0 | |
| return len(a & b) / len(a | b) | |
| def _subset_ratio(a: set[str], b: set[str]) -> float: | |
| """Fraction of the smaller wordset contained in the larger.""" | |
| if not a or not b: | |
| return 0.0 | |
| small, large = (a, b) if len(a) <= len(b) else (b, a) | |
| return len(small & large) / len(small) | |
| def _prune_ancestors(parents: list[str], all_nodes: list[CascadeNode]) -> list[str]: | |
| """Drop any parent that is an ancestor of another listed parent. | |
| Implements the v0.2 issue #9 / B' no-grandparent rule: when a node | |
| declares multiple parents, only the most-direct ones are kept; transitive | |
| ancestors are pruned because they are already implicit in the graph. | |
| `all_nodes` is the full node set used to build the parent-of relation by | |
| walking each node's ``parent_ids`` upward. | |
| """ | |
| if len(parents) <= 1: | |
| return list(parents) | |
| parent_map = {n.id: list(n.parent_ids) for n in all_nodes} | |
| def ancestors_of(nid: str) -> set[str]: | |
| seen: set[str] = set() | |
| stack = list(parent_map.get(nid, [])) | |
| while stack: | |
| cur = stack.pop() | |
| if cur in seen: | |
| continue | |
| seen.add(cur) | |
| stack.extend(parent_map.get(cur, [])) | |
| return seen | |
| candidate_set = set(parents) | |
| drop: set[str] = set() | |
| for p in parents: | |
| anc = ancestors_of(p) | |
| for q in candidate_set: | |
| if q != p and q in anc: | |
| drop.add(q) | |
| return [p for p in parents if p not in drop] | |
def _truncate_by_priority(nodes: list[CascadeNode], budget: int) -> list[CascadeNode]:
    """Keep a diverse, high-priority subset within budget.

    Strategy:
      1. Ensure at least 1 node per domain present in the input (until the
         budget is exhausted).
      2. Fill remaining slots ranked by (severity, parent_count, early time) —
         multi-parent nodes rank higher because they carry more structural
         signal under the v0.2 issue #9 / B' AND-conjunction semantics.

    Nodes without a time offset sort after nodes with one. Kept nodes are
    modified in place: parent ids that refer to discarded nodes are removed.

    Args:
        nodes: Candidate nodes.
        budget: Maximum number of nodes to keep.

    Returns:
        The kept nodes: per-domain picks first, then the global fill.
    """

    def sev_rank(n: CascadeNode) -> int:
        return SEVERITY_RANK.get(n.severity, 1)

    def fill_rank(n: CascadeNode) -> tuple:
        offset = n.time_offset_hours
        return (
            -sev_rank(n),
            -len(n.parent_ids),
            # Only a missing offset gets the sentinel. `offset or 9999` would
            # also replace a genuine 0 and rank immediate impacts last.
            9999 if offset is None else offset,
        )

    kept: list[CascadeNode] = []
    kept_set: set[str] = set()

    # Pass 1: best node per domain.
    by_domain: dict[str, list[CascadeNode]] = {}
    for n in nodes:
        by_domain.setdefault(n.domain, []).append(n)
    for group in by_domain.values():
        if len(kept) >= budget:
            break
        best = min(group, key=fill_rank)
        kept.append(best)
        kept_set.add(best.id)

    # Pass 2: fill remaining slots by global priority.
    for n in sorted(nodes, key=fill_rank):
        if len(kept) >= budget:
            break
        if n.id not in kept_set:
            kept.append(n)
            kept_set.add(n.id)

    # Remove parent refs that point at nodes which were cut.
    for n in kept:
        n.parent_ids = [p for p in n.parent_ids if p in kept_set]
    return kept
def _consolidate_with_llm(
    event: FloodEvent,
    nodes: list[CascadeNode],
    llm_client: LLMClient,
    config: dict,
) -> list[CascadeNode] | None:
    """Ask the LLM to merge a noisy node list, then re-validate the result.

    The nodes are serialised to JSON and sent with the
    ``consolidate_cascade`` prompt. Nodes in the reply are re-parsed
    (malformed ones are skipped with a warning), domain-normalised,
    scope-filtered and truncated to at most ``MAX_NODES``.

    Args:
        event: Flood event supplying the prompt variables.
        nodes: Nodes to consolidate.
        llm_client: Client used for the call.
        config: Configuration mapping passed through to the client.

    Returns:
        The consolidated nodes, or ``None`` when the reply cannot be parsed
        as a JSON object or contains no usable node.
    """
    nodes_payload = json.dumps(
        [
            {
                "id": n.id,
                "description": n.description,
                "domain": n.domain,
                "severity": n.severity,
                "time_offset_hours": n.time_offset_hours,
                "mechanism": n.mechanism,
                "parent_ids": n.parent_ids,
            }
            for n in nodes
        ],
        ensure_ascii=False,
        indent=2,
    )
    variables = {
        "event_id": event.event_id,
        "country": event.country,
        "location": event.location or event.country,
        "start_date": str(event.start_date),
        "n_input": len(nodes),
        "nodes_json": nodes_payload,
    }
    try:
        response = llm_client.call_with_config(
            prompt_key="consolidate_cascade",
            knowledge_key="extraction",
            variables=variables,
            config=config,
        )
        data = json.loads(_extract_json(response))
        # A bare JSON array (or scalar) has no "cascade_events" key to read.
        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")
    except (json.JSONDecodeError, KeyError, ValueError) as e:
        logger.error(f"Consolidation parse failed for {event.event_id}: {e}")
        return None

    consolidated: list[CascadeNode] = []
    for nd in data.get("cascade_events") or []:
        try:
            if not isinstance(nd, dict):
                raise ValueError(
                    f"cascade event must be an object, got {type(nd).__name__}"
                )
            consolidated.append(
                CascadeNode(
                    id=nd["id"],
                    description=nd["description"],
                    domain=nd.get("domain", "unknown"),
                    severity=nd.get("severity", "medium"),
                    time_offset_hours=nd.get("time_offset_hours"),
                    mechanism=nd.get("mechanism", ""),
                    # A null parent list is treated as "no parents".
                    parent_ids=nd.get("parent_ids") or [],
                )
            )
        except (KeyError, TypeError, ValueError) as e:
            logger.warning(f" skipping bad consolidated node: {e}")

    if not consolidated:
        return None

    # Re-validate domains in case the LLM invented new labels during merge.
    consolidated = _normalize_domains(consolidated)
    consolidated = _filter_scope(consolidated)
    if len(consolidated) > MAX_NODES:
        consolidated = _truncate_by_priority(consolidated, MAX_NODES)
    return consolidated
| def _renumber(nodes: list[CascadeNode]) -> list[CascadeNode]: | |
| """Renumber ids sequentially as E1..En and rewrite parent_ids.""" | |
| id_map: dict[str, str] = {} | |
| for i, node in enumerate(nodes, 1): | |
| new_id = f"E{i}" | |
| id_map[node.id] = new_id | |
| for node in nodes: | |
| node.id = id_map[node.id] | |
| node.parent_ids = [id_map.get(p, p) for p in node.parent_ids if p in id_map] | |
| return nodes | |