datamatters24 commited on
Commit
a329614
·
verified ·
1 Parent(s): d531db8

Upload notebooks/02_entity_network/20_entity_cooccurrence.ipynb with huggingface_hub

Browse files
notebooks/02_entity_network/20_entity_cooccurrence.ipynb ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# 20 - Entity Co-occurrence Analysis\n",
8
+ "\n",
9
+ "Pipeline notebook for computing entity co-occurrence from page-level entity extractions.\n",
10
+ "\n",
11
+ "For each page, computes all entity pairs (entity_a < entity_b) and aggregates co-occurrence\n",
12
+ "counts across the corpus. Results are stored in the `entity_relationships` table.\n",
13
+ "\n",
14
+ "**Incremental**: Only processes documents not yet in entity_relationships."
15
+ ]
16
+ },
17
# Parameters (papermill-injected; the values below are the defaults).
source_section = None   # optional filter: restrict processing to one source section
min_count = 3           # minimum co-occurrence count required to store a pair
batch_size = 10000      # number of entity rows fetched per SQL batch
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": null,
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "import sys\n",
40
+ "sys.path.insert(0, '/opt/epstein_env/research')\n",
41
+ "\n",
42
+ "import pandas as pd\n",
43
+ "import numpy as np\n",
44
+ "from collections import Counter\n",
45
+ "from itertools import combinations\n",
46
+ "from tqdm.auto import tqdm\n",
47
+ "\n",
48
+ "from research_lib.db import fetch_df, fetch_all, bulk_insert, get_conn\n",
49
+ "from research_lib.incremental import start_run, finish_run, get_processed_doc_ids"
50
+ ]
51
+ },
52
# Register this execution with the incremental-run tracker so that
# progress and parameters are recorded for auditing.
run_parameters = {'min_count': min_count, 'batch_size': batch_size}
run_id = start_run(
    'entity_cooccurrence',
    source_section=source_section,
    parameters=run_parameters,
)
print(f'Started run {run_id}')
67
# Documents already represented in entity_relationships are skipped,
# making re-runs of this notebook incremental.
processed_ids = get_processed_doc_ids(
    'entity_cooccurrence', feature_table='entity_relationships'
)
print(f'Already processed: {len(processed_ids)} documents')
81
# Build the WHERE clause selecting entity rows not yet processed.
where_clauses = []
params = []

if source_section:
    where_clauses.append('d.source_section = %s')
    params.append(source_section)

if processed_ids:
    # FIX: previously the IDs were string-joined directly into the SQL text
    # (`NOT IN (1,2,3,...)`), which can balloon the statement to megabytes,
    # defeats plan caching, and is fragile if an ID is ever non-numeric.
    # Bind the ID list as a single parameter instead; psycopg2 adapts a
    # Python list to a SQL array, and `<> ALL(array)` is equivalent to
    # NOT IN for non-NULL values.
    where_clauses.append('e.document_id <> ALL(%s)')
    params.append(list(processed_ids))

where_sql = ('WHERE ' + ' AND '.join(where_clauses)) if where_clauses else ''

count_sql = f"""
    SELECT COUNT(DISTINCT e.document_id)
    FROM entities e
    JOIN documents d ON d.id = e.document_id
    {where_sql}
"""
# NOTE(review): result key 'count' assumes a dict-row cursor against
# PostgreSQL (COUNT's default column name) — confirm against fetch_all.
total_docs = fetch_all(count_sql, params or None)[0]['count']
print(f'Documents to process: {total_docs}')
110
# Stream entity rows in batches and count page-level co-occurrence pairs.
#
# FIX: plain LIMIT/OFFSET paging can split a single (document_id,
# page_number) group across two consecutive batches; pairs straddling the
# boundary were then silently lost.  We now hold back the last (possibly
# incomplete) group of every full batch and merge it into the next batch
# before counting, so each page is counted exactly once and in full.
pair_counter = Counter()
pair_types = {}  # (entity_a, entity_b) -> (type_a, type_b)
doc_ids_processed = set()
offset = 0
carry = None  # ((doc_id, page_number), entities) deferred from the previous batch


def _count_page_pairs(entities):
    """Count every unordered entity pair on one page, alphabetically canonicalized."""
    for (text_a, type_a), (text_b, type_b) in combinations(entities, 2):
        # Canonical ordering: alphabetical by entity text.
        if text_a > text_b:
            text_a, text_b = text_b, text_a
            type_a, type_b = type_b, type_a
        pair = (text_a, text_b)
        pair_counter[pair] += 1
        # First-seen types win, matching the original behavior.
        pair_types.setdefault(pair, (type_a, type_b))


while True:
    sql = f"""
        SELECT e.document_id, e.page_number, e.entity_text, e.entity_type
        FROM entities e
        JOIN documents d ON d.id = e.document_id
        {where_sql}
        ORDER BY e.document_id, e.page_number
        LIMIT %s OFFSET %s
    """
    batch_df = fetch_df(sql, (params or []) + [batch_size, offset])

    if batch_df.empty:
        break

    # Materialize (key, entities) groups in (document_id, page_number) order,
    # which matches the SQL ORDER BY.
    groups = [
        (key, list(zip(grp['entity_text'], grp['entity_type'])))
        for key, grp in batch_df.groupby(['document_id', 'page_number'])
    ]

    # Merge the group deferred from the previous batch if it continues here.
    if carry is not None:
        if groups and groups[0][0] == carry[0]:
            groups[0] = (carry[0], carry[1] + groups[0][1])
        else:
            groups.insert(0, carry)
        carry = None

    last_batch = len(batch_df) < batch_size
    if not last_batch and groups:
        # The final group of a full batch may be truncated — defer it.
        carry = groups.pop()

    for (doc_id, _page_num), entities in groups:
        doc_ids_processed.add(doc_id)
        _count_page_pairs(entities)

    offset += batch_size
    print(f'  Processed batch at offset {offset}, running pairs: {len(pair_counter)}')

    if last_batch:
        break

# Flush anything still deferred (e.g. a page larger than batch_size at EOF).
if carry is not None:
    doc_ids_processed.add(carry[0][0])
    _count_page_pairs(carry[1])
    carry = None

print(f'\nTotal unique pairs found: {len(pair_counter)}')
print(f'Documents processed: {len(doc_ids_processed)}')
160
# Drop pairs below the min_count threshold and shape the survivors as
# insert-ready tuples for the entity_relationships table.
rows = []
for pair, n_seen in pair_counter.items():
    if n_seen < min_count:
        continue
    ent_a, ent_b = pair
    t_a, t_b = pair_types[pair]
    rows.append((ent_a, t_a, ent_b, t_b, n_seen, source_section))

print(f'Pairs after filtering (min_count={min_count}): {len(rows)}')
183
# Persist the filtered pairs; conflicting (already-present) rows are skipped.
if not rows:
    print('No rows to insert.')
else:
    columns = [
        'entity_a', 'entity_a_type',
        'entity_b', 'entity_b_type',
        'co_occurrence_count', 'source_section',
    ]
    inserted = bulk_insert(
        'entity_relationships', columns, rows, on_conflict='DO NOTHING'
    )
    print(f'Inserted {inserted} rows into entity_relationships')
202
# Close out the incremental-run record with the document count handled.
n_docs = len(doc_ids_processed)
finish_run(run_id, documents_processed=n_docs)
print(f'Run {run_id} completed.')
213
# Human-readable summary of what this run produced.
print('=== Co-occurrence Summary ===')
print(f'Total documents processed this run: {len(doc_ids_processed)}')
print(f'Total unique pairs (before filtering): {len(pair_counter)}')
print(f'Pairs stored (min_count >= {min_count}): {len(rows)}')

if rows:
    counts = [row[4] for row in rows]
    print(f'Co-occurrence count range: {min(counts)} - {max(counts)}')
    print(f'Mean co-occurrence count: {np.mean(counts):.1f}')
    print(f'Median co-occurrence count: {np.median(counts):.1f}')

    # Show the 20 strongest pairs, highest co-occurrence first.
    print('\nTop 20 co-occurring entity pairs:')
    strongest = sorted(rows, key=lambda row: row[4], reverse=True)[:20]
    for ent_a, t_a, ent_b, t_b, n_seen, _section in strongest:
        print(f'  {ent_a} ({t_a}) <-> {ent_b} ({t_b}): {n_seen}')
238
+ ],
239
+ "metadata": {
240
+ "kernelspec": {
241
+ "display_name": "Python 3",
242
+ "language": "python",
243
+ "name": "python3"
244
+ },
245
+ "language_info": {
246
+ "name": "python",
247
+ "version": "3.10.0"
248
+ }
249
+ },
250
+ "nbformat": 4,
251
+ "nbformat_minor": 5
252
+ }