datamatters24 committed on
Commit
fe76792
·
verified ·
1 Parent(s): 6c33ee2

Upload ml/03_correlate_crises.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ml/03_correlate_crises.py +322 -0
ml/03_correlate_crises.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Stage 3: Correlate documents with historical crisis events.
4
+
5
+ Scoring methods:
6
+ 1. Date overlap — document date falls within event date range (or within buffer)
7
+ 2. Keyword match — OCR text contains event-specific keywords
8
+ 3. Entity overlap — extracted entities match event keywords
9
+ 4. Collection affinity — source_section naturally maps to certain events
10
+
11
+ Each method contributes a partial score; combined score determines relevance.
12
+
13
+ Populates: document_events table
14
+ """
15
+
16
+ import json
17
+ import logging
18
+ import re
19
+ import sys
20
+ from datetime import date, timedelta
21
+
22
+ import psycopg2
23
+ import psycopg2.extras
24
+ from config import BATCH_SIZE
25
+ from db import get_conn, fetch_all
26
+
27
# Log to stdout so the pipeline output is visible in container/job logs.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-8s %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
log = logging.getLogger(__name__)

# Buffer days: documents created within this window of an event still correlate
DATE_BUFFER_DAYS = 365

# Minimum total score to record a correlation
MIN_SCORE = 0.15

# Collection -> event affinity (source_section maps naturally to events).
# Keys are `documents.source_section` values; values are `event_name` strings
# that must match `historical_events.event_name` exactly.
COLLECTION_AFFINITY = {
    "jfk_assassination": ["JFK Assassination", "RFK Assassination", "MLK Assassination",
                          "Church Committee Investigations", "Warren Commission"],
    "cia_mkultra": ["MKUltra Program", "Church Committee Investigations"],
    "cia_stargate": ["CIA Stargate / Remote Viewing Program"],
    "cia_declassified": ["Bay of Pigs Invasion", "Cuban Missile Crisis",
                         "Area 51 / U-2 Program", "Iran-Contra Affair"],
    "lincoln_archives": ["Lincoln Assassination", "Civil War End / Reconstruction"],
}
50
+
51
+
52
def load_events():
    """Fetch all rows of historical_events and normalize them for scoring.

    The keywords column may arrive as a JSON string or an already-decoded
    list; both forms are accepted. Keywords are lower-cased so downstream
    matching is case-insensitive, and point-in-time events (no end_date)
    fall back to their start_date as the end.
    """
    normalized = []
    for row in fetch_all("SELECT * FROM historical_events ORDER BY id"):
        keywords = row["keywords"]
        if isinstance(keywords, str):
            keywords = json.loads(keywords)
        normalized.append({
            "id": row["id"],
            "name": row["event_name"],
            "start": row["start_date"],
            # Single-day events store NULL end_date; treat as start_date.
            "end": row["end_date"] or row["start_date"],
            "category": row["category"],
            "keywords": [word.lower() for word in keywords],
        })
    return normalized
69
+
70
+
71
+ def score_date_overlap(doc_date: date | None, doc_range_start: date | None,
72
+ doc_range_end: date | None, event: dict) -> float:
73
+ """Score based on temporal overlap between document date and event."""
74
+ if not doc_date:
75
+ return 0.0
76
+
77
+ ev_start = event["start"] - timedelta(days=DATE_BUFFER_DAYS)
78
+ ev_end = event["end"] + timedelta(days=DATE_BUFFER_DAYS)
79
+
80
+ # Direct overlap: doc date within event range (no buffer)
81
+ if event["start"] <= doc_date <= event["end"]:
82
+ return 0.5
83
+
84
+ # Within buffer range
85
+ if ev_start <= doc_date <= ev_end:
86
+ # Score decays with distance
87
+ if doc_date < event["start"]:
88
+ days_away = (event["start"] - doc_date).days
89
+ else:
90
+ days_away = (doc_date - event["end"]).days
91
+ decay = max(0, 1.0 - days_away / DATE_BUFFER_DAYS)
92
+ return 0.3 * decay
93
+
94
+ # Check doc range overlap with event range
95
+ if doc_range_start and doc_range_end:
96
+ if doc_range_start <= event["end"] and doc_range_end >= event["start"]:
97
+ # Partial overlap
98
+ overlap_start = max(doc_range_start, event["start"])
99
+ overlap_end = min(doc_range_end, event["end"])
100
+ overlap_days = (overlap_end - overlap_start).days + 1
101
+ doc_span = (doc_range_end - doc_range_start).days + 1
102
+ if doc_span > 0:
103
+ return 0.3 * min(overlap_days / doc_span, 1.0)
104
+
105
+ return 0.0
106
+
107
+
108
def score_keyword_match(doc_id: int, event: dict, conn) -> tuple[float, list[str]]:
    """Check the first five OCR'd pages of a document for event keywords.

    Short keywords (fewer than 5 chars) are matched on word boundaries to
    reduce false positives; longer keywords use plain substring search.
    All keywords are assumed pre-lowercased (see load_events).

    Returns:
        (score, matched_keywords): 0.0 with no matches; otherwise 0.15 plus
        up to 0.25 scaled by the fraction of keywords hit, capped at 0.4.
    """
    keywords = event["keywords"]
    if not keywords:
        return 0.0, []

    with conn.cursor() as cur:
        cur.execute(
            """SELECT string_agg(ocr_text, ' ') as combined_text
               FROM (
                   SELECT ocr_text FROM pages
                   WHERE document_id = %s AND ocr_text IS NOT NULL
                   ORDER BY page_number
                   LIMIT 5
               ) sub""",
            (doc_id,),
        )
        row = cur.fetchone()

    if not row or not row[0]:
        return 0.0, []

    haystack = row[0].lower()
    hits = []
    for keyword in keywords:
        if len(keyword) < 5:
            # Word-boundary match keeps short tokens from firing inside
            # unrelated words.
            if re.search(r'\b' + re.escape(keyword) + r'\b', haystack):
                hits.append(keyword)
        elif keyword in haystack:
            hits.append(keyword)

    if not hits:
        return 0.0, hits

    # More keyword matches -> higher score, capped at 0.4.
    fraction = len(hits) / len(keywords)
    return min(0.4, 0.15 + 0.25 * fraction), hits
151
+
152
+
153
def score_entity_match(doc_id: int, event: dict, conn) -> tuple[float, list[str]]:
    """Score overlap between a document's named entities and event keywords.

    Fetches the document's distinct PERSON/ORG/GPE/EVENT/NORP entities
    (lower-cased in SQL) and counts a keyword as matched when it contains,
    or is contained in, any entity string.

    Returns:
        (score, matched_keywords): 0.0 with no matches; otherwise 0.1 plus
        up to 0.2 scaled by the fraction of keywords matched, capped at 0.3.
    """
    keywords = event["keywords"]
    if not keywords:
        return 0.0, []

    with conn.cursor() as cur:
        cur.execute(
            """SELECT DISTINCT lower(entity_text) as ent
               FROM entities
               WHERE document_id = %s
               AND entity_type IN ('PERSON', 'ORG', 'GPE', 'EVENT', 'NORP')""",
            (doc_id,),
        )
        entity_set = {record[0] for record in cur.fetchall()}

    if not entity_set:
        return 0.0, []

    # Substring containment in either direction counts; each keyword is
    # counted at most once regardless of how many entities it matches.
    hits = [kw for kw in keywords
            if any(kw in ent or ent in kw for ent in entity_set)]

    if not hits:
        return 0.0, hits

    return min(0.3, 0.1 + 0.2 * len(hits) / len(keywords)), hits
185
+
186
+
187
def score_collection_affinity(source_section: str, event: dict) -> float:
    """Return 0.2 when the document's collection naturally maps to the event.

    The mapping lives in the module-level COLLECTION_AFFINITY table; any
    collection not listed there (or event not in its list) scores 0.0.
    """
    return 0.2 if event["name"] in COLLECTION_AFFINITY.get(source_section, []) else 0.0
193
+
194
+
195
def process_correlations():
    """Main correlation loop.

    Scores every not-yet-correlated document against every historical event
    using up to four partial signals (date overlap, collection affinity,
    keyword match, entity match) and upserts rows into document_events
    whenever the combined score reaches MIN_SCORE. Batches are flushed every
    500 documents; a per-event summary is logged at the end.
    """
    events = load_events()
    log.info(f"Loaded {len(events)} historical events")

    conn = get_conn()
    conn.autocommit = False

    # BUG FIX: wrap the whole scoring pass in try/finally so the connection
    # is released even when a query or flush raises mid-loop (previously the
    # connection leaked on any exception).
    try:
        # Get documents with dates, that haven't been correlated yet
        with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
            cur.execute("""
                SELECT d.id, d.source_section, d.file_path,
                       dd.estimated_date, dd.date_range_start, dd.date_range_end
                FROM documents d
                LEFT JOIN document_dates dd ON dd.document_id = d.id
                WHERE NOT EXISTS (
                    SELECT 1 FROM document_events de WHERE de.document_id = d.id
                )
                ORDER BY d.id
            """)
            docs = cur.fetchall()

        total = len(docs)
        log.info(f"Processing {total} documents for crisis correlation")

        batch = []
        processed = 0
        correlations_found = 0

        for doc in docs:
            doc_id = doc["id"]

            for event in events:
                methods = []
                details = {}
                total_score = 0.0

                # 1. Date overlap
                date_score = score_date_overlap(
                    doc["estimated_date"], doc["date_range_start"],
                    doc["date_range_end"], event
                )
                if date_score > 0:
                    total_score += date_score
                    methods.append("date")
                    details["date_score"] = round(date_score, 3)

                # 2. Collection affinity (cheap — no DB query)
                affinity = score_collection_affinity(doc["source_section"] or "", event)
                if affinity > 0:
                    total_score += affinity
                    methods.append("collection")

                # Only do expensive keyword/entity lookups if we already have
                # some signal or if the collection has natural affinity
                if total_score > 0.05 or affinity > 0:
                    # 3. Keyword match
                    kw_score, kw_matched = score_keyword_match(doc_id, event, conn)
                    if kw_score > 0:
                        total_score += kw_score
                        methods.append("keyword")
                        details["matched_keywords"] = kw_matched

                    # 4. Entity match
                    ent_score, ent_matched = score_entity_match(doc_id, event, conn)
                    if ent_score > 0:
                        total_score += ent_score
                        methods.append("entity")
                        details["matched_entities"] = ent_matched

                if total_score >= MIN_SCORE:
                    batch.append((
                        doc_id, event["id"], round(total_score, 4),
                        json.dumps(methods), json.dumps(details),
                    ))
                    correlations_found += 1

            processed += 1
            if processed % 500 == 0:
                if batch:
                    _flush_batch(conn, batch)
                    batch = []
                log.info(
                    f"Progress: {processed}/{total} ({processed*100//total}%) "
                    f"— {correlations_found} correlations found"
                )

        if batch:
            _flush_batch(conn, batch)
    finally:
        conn.close()

    log.info(f"Done. {processed} docs processed, {correlations_found} correlations found.")

    # Print summary (fetch_all opens its own connection via the db helper)
    stats = fetch_all("""
        SELECT he.event_name, COUNT(*) as doc_count,
               ROUND(AVG(de.relevance_score)::numeric, 3) as avg_score
        FROM document_events de
        JOIN historical_events he ON he.id = de.event_id
        GROUP BY he.event_name
        ORDER BY doc_count DESC
    """)
    log.info("Crisis correlation summary:")
    for row in stats:
        log.info(f"  {row['event_name']}: {row['doc_count']} docs (avg score: {row['avg_score']})")
300
+
301
+
302
def _flush_batch(conn, batch):
    """Upsert a batch of correlation tuples into document_events and commit.

    Each tuple is (document_id, event_id, relevance_score, match_methods,
    details); re-correlating an existing (document_id, event_id) pair
    overwrites the stored score/methods/details and refreshes created_at.
    """
    upsert_sql = """INSERT INTO document_events
           (document_id, event_id, relevance_score, match_methods, details)
           VALUES (%s, %s, %s, %s, %s)
           ON CONFLICT (document_id, event_id) DO UPDATE SET
               relevance_score = EXCLUDED.relevance_score,
               match_methods = EXCLUDED.match_methods,
               details = EXCLUDED.details,
               created_at = NOW()
        """
    with conn.cursor() as cur:
        # execute_batch groups many rows per round trip (500 at a time).
        psycopg2.extras.execute_batch(cur, upsert_sql, batch, page_size=500)
    conn.commit()
319
+
320
+
321
# Entry point: run the full correlation pass when executed as a script.
if __name__ == "__main__":
    process_correlations()