"""Generate Cypher 25 training pairs from official Neo4j cheat sheet patterns. Covers all Cypher 25 features: FILTER, LET, WHEN/THEN/ELSE, NEXT, shortest paths, quantified path patterns, dynamic labels/types, CALL subqueries, LOAD CSV, type predicates, and more. Each template is expanded with parametric variation (10x) to generate diverse training examples covering the full Cypher 25 syntax surface. """ from __future__ import annotations import json import random import re from pathlib import Path from typing import List, Tuple # ── Entity pools for parametric expansion ───────────────────────────────────── LABELS = [ "Person", "Movie", "User", "Customer", "Product", "Order", "Employee", "Department", "Company", "Actor", "Director", "Station", "Artist", "Student", "Teacher", "Doctor", "Patient", "Team", "Player", "Supplier", ] PROPS = [ "name", "title", "age", "email", "salary", "price", "score", "rating", "year", "status", "role", "level", "count", "date", "amount", ] RELS = [ "KNOWS", "ACTED_IN", "DIRECTED", "WORKS_FOR", "BOUGHT", "FOLLOWS", "MANAGES", "PLAYS_FOR", "TEACHES", "TREATS", "SUPPLIES", "REVIEWED", "LIKES", "REPORTS_TO", "MEMBER_OF", "LOCATED_IN", "CALLS_AT", "NEXT", ] NAMES = [ "Alice", "Bob", "Charlie", "Diana", "Eve", "Frank", "Grace", "Henry", "Ivy", "Jack", "Kate", "Leo", "Maya", "Noah", "Olivia", "Peter", ] VALUES = ["100", "200", "500", "1000", "50", "30", "10", "42", "99"] def _sub(template: str) -> str: """Substitute placeholders with random values.""" s = template s = s.replace("{L}", random.choice(LABELS)) s = s.replace("{L2}", random.choice(LABELS)) s = s.replace("{P}", random.choice(PROPS)) s = s.replace("{P2}", random.choice(PROPS)) s = s.replace("{R}", random.choice(RELS)) s = s.replace("{R2}", random.choice(RELS)) s = s.replace("{N}", random.choice(NAMES)) s = s.replace("{N2}", random.choice(NAMES)) s = s.replace("{V}", random.choice(VALUES)) return s # ── Cypher 25 training pair templates ───────────────────────────────────────── _TEMPLATES: List[Tuple[str, str]] = [ # --- FILTER clause (Cypher 25 new) --- ("Find {L} nodes where {P} is less than {V}", "MATCH (n:{L}) FILTER n.{P} < {V} RETURN n.{P2} AS {P2}"), ("Filter {L} by dynamic property {P}", "MATCH (n:{L}) FILTER n[$propname] > {V} RETURN n.{P} AS {P}, n.{P2} AS {P2}"), ("Load CSV and filter out null IDs", "LOAD CSV WITH HEADERS FROM 'file:///data.csv' AS row " "FILTER row.Id IS NOT NULL MERGE (c:{L} {{id: row.Id}})"), # --- LET clause (Cypher 25 new) --- ("Bind supplier and product names using LET", "MATCH (s:{L})-[:SUPPLIES]->(p:{L2}) " "LET supplier = s.{P}, product = p.{P2} RETURN supplier, product"), ("Chain LET expressions for classification", "MATCH (p:{L}) " "LET isExpensive = p.{P} >= {V} " "LET isAffordable = NOT isExpensive " "LET category = CASE WHEN isExpensive THEN 'High-end' ELSE 'Budget' END " "RETURN p.{P2} AS item, p.{P} AS price, isAffordable, category ORDER BY price"), # --- WHEN/THEN/ELSE conditional queries (Cypher 25 new) --- ("Conditional query: return different results based on condition", "WHEN true THEN RETURN 1 AS x " "WHEN false THEN RETURN 2 AS x " "ELSE RETURN 3 AS x"), ("Conditional match based on label", 'WHEN true THEN {{ MATCH (n:{L}) WHERE n.{P} STARTS WITH "{N}" ' "RETURN n.{P} AS result }} " "ELSE {{ MATCH (n:{L}) RETURN n.{P} AS result }}"), ("WHEN inside CALL subquery", "MATCH (n:{L}) OPTIONAL MATCH (n)-[:{R}]->(mgr:{L2}) " "CALL (*) {{ WHEN mgr IS NULL THEN {{ " "MERGE (newMgr:{L2} {{name: '{N}'}}) " "MERGE (n)-[:{R}]->(newMgr) " "RETURN newMgr, n.{P} AS employee }} }} " "RETURN newMgr.{P} AS manager, collect(employee) AS employees"), # --- NEXT sequential queries (Cypher 25 new) --- ("Sequential query with NEXT", "MATCH (c:{L})-[:{R}]->(p:{L2}) RETURN c AS customer, p AS product " "NEXT RETURN product.{P} AS product, COUNT(customer) AS total"), ("NEXT with aggregation", "MATCH (c:{L})-[:{R}]->(p:{L2}) RETURN c, p " "NEXT RETURN c.{P} AS name, COLLECT(p.{P2}) AS items"), ("NEXT with WHEN branches", "MATCH (c:{L})-[:{R}]->({L2})<-[:SUPPLIES]-(s:{L}) " "RETURN c.{P} AS customer, s.{P} AS supplier " "NEXT WHEN supplier = '{N}' THEN RETURN customer, 'VIP' AS tier " "ELSE RETURN customer, 'Standard' AS tier"), # --- Quantified path patterns --- ("Find paths of length 1 to 3 between stations", "MATCH (:{L} {{name: '{N}'}})<-[:{R}]-(d:{L2}) " "(({L2})-[:{R2}]->({L2})){{1,3}} " "(a:{L2})-[:{R}]->(:{L} {{name: '{N2}'}}) " "RETURN d.{P} AS start, a.{P} AS end"), ("Quantified relationship 1 to 10 hops", "MATCH (d:{L} {{name: '{N}'}})<-[:{R}]- " "(n:{L2})-[:{R2}]->{{1,10}}(m:{L2})-[:{R}]-> " "(a:{L} {{name: '{N2}'}}) " "WHERE m.{P} < {V} RETURN n.{P} AS departure"), # --- Shortest paths --- ("Find shortest path between two nodes", "MATCH p = SHORTEST 1 (a:{L})-[:{R}]-+(b:{L}) " "WHERE a.{P} = '{N}' AND b.{P} = '{N2}' " "RETURN length(p) AS pathLength"), ("Find all shortest paths", "MATCH p = ALL SHORTEST (a:{L})-[:{R}]-+(b:{L}) " "WHERE a.{P} = '{N}' AND b.{P} = '{N2}' " "RETURN [n in nodes(p) | n.{P}] AS stops"), ("Check reachability with ANY", "MATCH path = ANY (:{L} {{name: '{N}'}})-[:{R} WHERE r.{P2} < {V}]-+ " "(b:{L} {{name: '{N2}'}}) " "RETURN [r IN relationships(path) | r.{P2}] AS values"), # --- Dynamic labels and types --- ("Create node with dynamic label", "CREATE (n:$($nodeLabels) {{name: '{N}'}}) RETURN n"), ("Match with dynamic label variable", "MATCH (n) WHERE n:$($label) RETURN labels(n) AS labels"), ("Dynamic property access", "MATCH (n:{L}) WHERE n[$propname] > {V} " "RETURN n.{P} AS {P}, n.{P2} AS {P2}"), ("Set dynamic label", "MATCH (n) SET n:$($label)"), ("Remove dynamic label", "MATCH (n {{name: '{N}'}}) REMOVE n:$($label) RETURN n.name"), # --- CALL subqueries --- ("CALL subquery to collect players per team", "MATCH (t:{L}) CALL (t) {{ " "MATCH (p:{L2})-[:{R}]->(t) RETURN collect(p) as members }} " "RETURN t AS team, members"), ("OPTIONAL CALL subquery", "MATCH (p:{L}) OPTIONAL CALL (p) {{ " "MATCH (p)-[:{R}]->(t:{L2}) RETURN t }} " "RETURN p.{P} AS name, t.{P} AS target"), ("CALL subquery in transactions", "LOAD CSV FROM 'file:///data.csv' AS line " "CALL (line) {{ CREATE (:{L} {{name: line[1]}}) }} IN TRANSACTIONS OF 200 ROWS"), # --- Type predicates (Cypher 25) --- ("Filter by type predicate", "MATCH (n:{L}) WHERE n.{P} IS :: INTEGER AND n.{P} > {V} " "RETURN n.{P2} AS name, n.{P} AS value"), ("Check value is not a string", "UNWIND [42, true, 'abc', null] AS val " "RETURN val, val IS NOT :: STRING AS notString"), ("Dynamic union type check", "UNWIND [42, 42.0, '42'] AS val " "RETURN val, val IS :: INTEGER | FLOAT AS isNumber"), # --- EXISTS / COUNT / COLLECT subqueries --- ("Count subquery for filtering", "MATCH (p:{L}) WHERE COUNT {{ (p)-[:{R}]->({L2}) }} > 1 " "RETURN p.{P} AS name"), ("EXISTS subquery", "MATCH (p:{L}) WHERE EXISTS {{ " "MATCH (p)-[:{R}]->(d:{L2}) WHERE p.{P} = d.{P} }} " "RETURN p.{P} AS name"), ("COLLECT subquery to set property", "MATCH (p:{L}) WHERE p.{P} = '{N}' " "SET p.items = COLLECT {{ MATCH (p)-[:{R}]->(d:{L2}) RETURN d.{P} }} " "RETURN p.items"), # --- MATCH modes --- ("Default DIFFERENT RELATIONSHIPS match", "MATCH p = (:{L} {{name: '{N}'}})--{{7}}() RETURN count(p) AS pathCount"), ("REPEATABLE ELEMENTS match", "MATCH REPEATABLE ELEMENTS p = (:{L} {{name: '{N}'}})-[:{R}]-{{7}}() " "WITH collect(p)[0] AS samplePath " "RETURN [n IN nodes(samplePath) | n.{P}] AS nodes"), # --- Advanced patterns --- ("Pattern comprehension", "MATCH (a:{L} {{name: '{N}'}}) " "RETURN [(a)-->(b WHERE b:{L2}) | b.{P}] AS connected"), ("WHERE in variable-length pattern", "MATCH p = (a:{L} {{name: '{N}'}})-[r:{R} WHERE r.{P2} < {V}]->{{1,4}}(:{L}) " "RETURN [n IN nodes(p) | n.{P}] AS paths"), ("Equijoin pattern", "MATCH (n:{L} {{name: '{N}'}})<-[:{R}]-(s1:{L2}) " "-[:{R2}]->(s2:{L2})-[:{R}]->(n) " "RETURN s1.{P} AS outbound, s2.{P} AS inbound"), ("List comprehension with filter", "WITH [1, 2, 3, 4, 5] AS list " "RETURN [n IN list WHERE n > 2 | n] AS filtered"), # --- CRUD + advanced write --- ("Create with dynamic labels and relationship", "CREATE (a:$($nodeLabels) {{name: '{N}'}})-[r:$($relType)]->" "(m:{L} {{title: '{N2}'}}) RETURN a, r, m"), ("MERGE with ON CREATE / ON MATCH", "MERGE (n:{L} {{name: '{N}'}}) " "ON CREATE SET n.created = timestamp() " "ON MATCH SET n.counter = coalesce(n.counter, 0) + 1 " "RETURN n"), ("DETACH DELETE node", "MATCH (n:{L}) WHERE n.{P} = '{N}' DETACH DELETE n"), ("SET with map replacement", "MATCH (n:{L} {{name: '{N}'}}) SET n += $map"), ("Remove property", "MATCH (n:{L}) WHERE n.{P} = '{N}' REMOVE n.{P2}"), # --- Aggregation + WITH chaining --- ("WITH for aggregation and filtering", "MATCH (s:{L})-[r]->(p:{L2})<-[:{R}]-(c:{L}) " "WITH s, sum(p.{P}) AS totalSales, count(DISTINCT c) AS uniqueCustomers " "WHERE totalSales > {V} " "RETURN s.{P2} AS supplier, totalSales, uniqueCustomers"), ("WITH DISTINCT", "MATCH (c:{L}) WITH DISTINCT c.{P} AS values RETURN values ORDER BY values"), ("WITH ORDER BY LIMIT", "MATCH (c:{L})-[:{R}]->(p:{L2}) " "WITH c, sum(p.{P}) AS total ORDER BY total DESC LIMIT 3 " "SET c.topSpender = true RETURN c.{P2}, total"), # --- LOAD CSV --- ("Load CSV with headers and merge", "LOAD CSV WITH HEADERS FROM 'file:///data.csv' AS row " "MERGE (n:{L} {{name: row.Name}}) RETURN n"), ("Load CSV with dynamic labels", "LOAD CSV WITH HEADERS FROM 'file:///data.csv' AS line " "MERGE (n:$(line.Label) {{name: line.Name}}) RETURN n"), # --- UNION + WHEN composition --- ("UNION of different labels", "MATCH (n:{L}) RETURN n.{P} AS name UNION " "MATCH (n:{L2}) RETURN n.{P2} AS name"), ("UNION ALL", "MATCH (n:{L}) RETURN n.{P} AS name UNION ALL " "MATCH (n:{L2}) RETURN n.{P2} AS name"), # --- FINISH clause (Cypher 25 new) --- ("Set property and finish without returning", "MATCH (p:{L}) FINISH"), # --- FOREACH --- ("FOREACH to set properties along path", "MATCH p=(start)-[*]->(finish) WHERE start.{P} = '{N}' AND finish.{P} = '{N2}' " "FOREACH (n IN nodes(p) | SET n.marked = true)"), # --- UNWIND --- ("UNWIND list and create nodes", "WITH ['{N}', '{N2}'] AS names " "FOREACH (value IN names | CREATE (:{L} {{name: value}}))"), ("UNWIND for batch processing", "UNWIND $events AS event " "MERGE (y:{L} {{year: event.year}}) " "MERGE (y)<-[:IN]-(e:{L2} {{id: event.id}}) RETURN e.id"), # --- Schema operations --- ("Create index for search performance", "CREATE INDEX FOR (n:{L}) ON (n.{P})"), ("Create composite index", "CREATE INDEX FOR (n:{L}) ON (n.{P}, n.{P2})"), ("Create uniqueness constraint", "CREATE CONSTRAINT FOR (n:{L}) REQUIRE n.{P} IS UNIQUE"), ("Create node key constraint", "CREATE CONSTRAINT FOR (n:{L}) REQUIRE (n.{P}, n.{P2}) IS NODE KEY"), # --- Temporal + spatial --- ("Filter by date", "MATCH (o:{L}) WHERE o.date > datetime() - duration('P30D') " "RETURN o.{P}, o.date ORDER BY o.date DESC"), ("Duration arithmetic", "MATCH (u:{L}) WHERE u.registered.year = date().year RETURN u.{P}, u.registered"), ("Point distance", "MATCH (a:{L}), (b:{L}) " "WHERE point.distance(a.location, b.location) < {V} " "RETURN a.{P}, b.{P}"), # --- String functions --- ("STARTS WITH filter", "MATCH (n:{L}) WHERE n.{P} STARTS WITH '{N}' RETURN n.{P}"), ("CONTAINS filter", "MATCH (n:{L}) WHERE n.{P} CONTAINS 'eng' RETURN n.{P}"), ("Regex match", "MATCH (n:{L}) WHERE n.{P} =~ '(?i){N}.*' RETURN n.{P}"), # --- NULL handling --- ("IS NULL check", "MATCH (n:{L}) WHERE n.{P} IS NULL RETURN n.{P2}"), ("IS NOT NULL check", "MATCH (n:{L}) WHERE n.{P} IS NOT NULL RETURN n.{P2}, n.{P}"), ("COALESCE for defaults", "MATCH (n:{L}) RETURN coalesce(n.{P}, 'Unknown') AS value"), # --- CASE expressions --- ("Simple CASE expression", "MATCH (n:{L}) RETURN CASE n.{P} WHEN 'A' THEN 1 WHEN 'B' THEN 2 ELSE 3 END AS code"), ("Generic CASE with conditions", "MATCH (n:{L}) RETURN CASE WHEN n.{P} = '{N}' THEN 1 " "WHEN n.{P2} < {V} THEN 2 ELSE 3 END AS result"), ("Extended CASE with IS NULL (Cypher 25)", "MATCH (n:{L}) RETURN n.{P2}, CASE n.{P} " "WHEN IS NULL, IS NOT TYPED INTEGER | FLOAT THEN 'Unknown' " "WHEN = 0 THEN 'Zero' WHEN <= {V} THEN 'Low' ELSE 'High' END AS category"), # --- Label expressions --- ("OR label expression", "MATCH (n:{L}|{L2}) RETURN n.{P} AS name"), ("Negation label expression", "MATCH (n:!{L}) RETURN labels(n) AS label, count(n) AS cnt"), ("OR relationship type expression", "MATCH (:{L})<-[:{R}|{R2}]-(p:{L2}) RETURN p.{P}"), ] def generate_pairs(n_per_template: int = 10) -> List[Tuple[str, str]]: """Generate training pairs from all templates with parametric expansion.""" pairs: List[Tuple[str, str]] = [] for q_tmpl, c_tmpl in _TEMPLATES: # Add the raw template (with one substitution) for _ in range(n_per_template): q = _sub(q_tmpl) c = _sub(c_tmpl) pairs.append((q, c)) return pairs def main(): import argparse parser = argparse.ArgumentParser() parser.add_argument("--n-per-template", type=int, default=10, help="Variants per template") parser.add_argument("--output", default="training/kan_bench_results/cypher25_training_data.json") args = parser.parse_args() pairs = generate_pairs(args.n_per_template) random.shuffle(pairs) data = [{"question": q, "gold": c, "domain": "cypher25"} for q, c in pairs] Path(args.output).parent.mkdir(parents=True, exist_ok=True) with open(args.output, "w") as f: json.dump(data, f, indent=1) print(f"Generated {len(data)} Cypher 25 training pairs from {len(_TEMPLATES)} templates") print(f"Saved to {args.output}") # Also append to SOTA training data sota_path = Path("training/kan_bench_results/sota_training_data.json") if sota_path.exists(): existing = json.loads(sota_path.read_text()) existing.extend(data) sota_path.write_text(json.dumps(existing, indent=1)) print(f"Appended to SOTA data: {len(existing)} total pairs") if __name__ == "__main__": main()