"""
graph/write_graph.py

Phase 3: Write entities and relationships to Neo4j.

Reads from data/graph/embeddings_<year>.npy
and data/graph/embedding_index_<year>.jsonl.

No GPU needed; pure network writes to Neo4j.
IDEMPOTENT: safe to re-run; relationships are only updated when the ark_id
is not already present on them.

Run:
    python -m graph.write_graph --year 1900
    python -m graph.write_graph --all
"""
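
# Expected shape of one embedding_index_<year>.jsonl record, inferred from the
# fields read below (Phase 2 defines the actual schema; the values here are
# illustrative placeholders only):
#
#   {"ark_id": "...", "title": "...", "year": [1900], "institution": "...",
#    "source_url": "...", "issue_date": "...",
#    "emb_indices": [0, 1, 2],
#    "entities": [{"text": "Boston", "type": "GPE", "count": 4}, ...]}
#
# Each position in emb_indices points at the embedding row for the entity at
# the same position in entities.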
from __future__ import annotations

import argparse
import functools
import json
import os
import time
from itertools import combinations
from pathlib import Path

# Prevent numpy from spawning extra threads (avoids memory bloat on cluster).
# These must be set BEFORE numpy is imported, or they have no effect.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

import numpy as np

from graph.neo4j_client import get_session, create_schema

# Flush all print output immediately so logs update in real time
print = functools.partial(print, flush=True)

OUTPUT_DIR = Path("data/graph")
BATCH_SIZE = 100  # documents per Neo4j transaction
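
# Batching tradeoff: larger batches mean fewer network round trips but bigger
# transactions. With the top-10 entity cap applied below, each document
# contributes at most C(10, 2) = 45 co-occurrence pairs, so a 100-document
# batch stays under ~4,500 pairs per transaction.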
def write_all(year: int | None = None, suffix: str | None = None):
    print("write_all() started")
    if suffix:
        file_suffix = suffix
    else:
        file_suffix = str(year) if year else "all"
    emb_file = OUTPUT_DIR / f"embeddings_{file_suffix}.npy"
    index_file = OUTPUT_DIR / f"embedding_index_{file_suffix}.jsonl"

    print(f"\n{'='*60}")
    print("BPL Graph - Phase 3: Write to Neo4j")
    print(f"  Embeddings : {emb_file}")
    print(f"  Index      : {index_file}")
    print(f"  Batch size : {BATCH_SIZE}")
    print(f"{'='*60}\n")

    if not emb_file.exists() or not index_file.exists():
        raise FileNotFoundError(
            f"Missing {emb_file} or {index_file}. Run Phase 1 and Phase 2 first."
        )
    create_schema()

    # Load embeddings
    print("Loading embeddings...")
    embeddings = np.load(emb_file)
    print(f"  Shape: {embeddings.shape}")

    # Load index
    print("Loading index...")
    records = []
    with open(index_file, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    print(f"  Found {len(records)} documents\n")
    start_time = time.monotonic()
    total_written = 0

    for batch_start in range(0, len(records), BATCH_SIZE):
        batch = records[batch_start:batch_start + BATCH_SIZE]
        batch_end = min(batch_start + BATCH_SIZE, len(records))

        # Build all data for the batch upfront
        docs_data = []
        for record in batch:
            entities = record["entities"]
            emb_indices = record["emb_indices"]
            if not entities:
                continue
            # emb_indices and entities are assumed to be parallel lists:
            # emb_indices[i] is the embedding row for entities[i].
            doc_embs = embeddings[emb_indices]
            top_entities = entities[:10]
            docs_data.append({
                "ark_id": record["ark_id"],
                "title": record["title"],
                "year": record["year"][0] if record["year"] else None,
                "institution": record["institution"],
                "source_url": record["source_url"],
                "issue_date": record["issue_date"],
                "entities": [
                    {
                        "name": ent["text"],
                        "type": ent["type"],
                        "count": ent["count"],
                        "embedding": doc_embs[i].tolist(),
                    }
                    for i, ent in enumerate(entities)
                ],
                "pairs": [
                    {
                        "name1": e1["text"], "type1": e1["type"],
                        "name2": e2["text"], "type2": e2["type"],
                    }
                    for e1, e2 in combinations(top_entities, 2)
                ],
            })
        with get_session() as session:
            # Single round trip for all document + entity upserts in the batch
            session.run(
                """
                UNWIND $docs AS doc
                MERGE (d:Document {ark_id: doc.ark_id})
                SET d.title = doc.title,
                    d.year = doc.year,
                    d.institution = doc.institution,
                    d.source_url = doc.source_url,
                    d.issue_date = doc.issue_date
                WITH d, doc
                UNWIND doc.entities AS ent
                MERGE (e:Entity {name: ent.name, type: ent.type})
                ON CREATE SET e.embedding = ent.embedding
                MERGE (d)-[r:MENTIONS]->(e)
                ON CREATE SET r.count = ent.count,
                              r.documents = [doc.ark_id]
                ON MATCH SET r.count = CASE
                        WHEN NOT doc.ark_id IN coalesce(r.documents, [])
                        THEN coalesce(r.count, 0) + ent.count
                        ELSE r.count
                    END,
                    r.documents = CASE
                        WHEN NOT doc.ark_id IN coalesce(r.documents, [])
                        THEN coalesce(r.documents, []) + [doc.ark_id]
                        ELSE r.documents
                    END
                """,
                docs=docs_data,
            )
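
            # Why the documents list: r.documents records which arks have
            # already contributed to r.count, so re-running the same input
            # leaves the counts unchanged (the idempotency guarantee from
            # the module docstring).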
            # Co-occurrence: canonicalize pair ordering and make idempotent
            all_pairs_with_ark = []
            for doc in docs_data:
                for p in doc["pairs"]:
                    all_pairs_with_ark.append({
                        "ark_id": doc["ark_id"],
                        "name1": p["name1"],
                        "type1": p["type1"],
                        "name2": p["name2"],
                        "type2": p["type2"],
                    })
            if all_pairs_with_ark:
                session.run(
                    """
                    UNWIND $pairs AS pair
                    MATCH (e1:Entity {name: pair.name1, type: pair.type1})
                    MATCH (e2:Entity {name: pair.name2, type: pair.type2})
                    // Canonicalize: ensure consistent ordering (smaller name/type first)
                    WITH e1, e2, pair,
                         CASE
                             WHEN pair.name1 < pair.name2 THEN e1
                             WHEN pair.name1 > pair.name2 THEN e2
                             WHEN pair.type1 <= pair.type2 THEN e1
                             ELSE e2
                         END AS a,
                         CASE
                             WHEN pair.name1 < pair.name2 THEN e2
                             WHEN pair.name1 > pair.name2 THEN e1
                             WHEN pair.type1 <= pair.type2 THEN e2
                             ELSE e1
                         END AS b
                    MERGE (a)-[r:CO_OCCURS_WITH]-(b)
                    ON CREATE SET r.weight = 1,
                                  r.documents = [pair.ark_id]
                    ON MATCH SET r.weight = CASE
                            WHEN NOT pair.ark_id IN coalesce(r.documents, [])
                            THEN coalesce(r.weight, 0) + 1
                            ELSE r.weight
                        END,
                        r.documents = CASE
                            WHEN NOT pair.ark_id IN coalesce(r.documents, [])
                            THEN coalesce(r.documents, []) + [pair.ark_id]
                            ELSE r.documents
                        END
                    """,
                    pairs=all_pairs_with_ark,
                )
        total_written += len(docs_data)
        elapsed = time.monotonic() - start_time
        remaining = (elapsed / total_written) * (len(records) - total_written) if total_written else 0
        print(
            f"  [{batch_end}/{len(records)}] "
            f"Written {total_written} docs | "
            f"ETA: {remaining/60:.1f}min"
        )

    print("\n✓ Graph write complete.")
    print(f"  Documents written : {total_written}")
    print(f"  Total time        : {(time.monotonic()-start_time)/60:.1f} min")
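
# Quick sanity check after a run (a sketch, assuming the schema written above;
# run in cypher-shell or the Neo4j Browser):
#
#   MATCH (d:Document)-[r:MENTIONS]->(e:Entity)
#   RETURN count(DISTINCT d) AS docs,
#          count(DISTINCT e) AS entities,
#          count(r)          AS mentions;
#
# Because of the r.documents bookkeeping, re-running write_all() on the same
# input should leave all three counts unchanged.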
# ── CLI ───────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Phase 3: Write graph to Neo4j")
    parser.add_argument("--year", type=int, default=None)
    parser.add_argument("--all", action="store_true")
    parser.add_argument("--suffix", type=str, default=None,
                        help="Explicit file suffix, e.g. 'all_gpt' or 'metadata'")
    args = parser.parse_args()
    write_all(
        year=None if (args.all or args.suffix) else (args.year or 1900),
        suffix=args.suffix,
    )