"""Build Python code generation + GraphQL generation training datasets.

Each pair: question → gold_output with multi-dimensional reward signals.

Usage:
    python3 scripts/build_python_graphql_datasets.py
"""

from __future__ import annotations

import json
from pathlib import Path
from typing import Callable, Dict, List

RESULTS_DIR = Path("training/kan_bench_results")
SOTA_PATH = RESULTS_DIR / "sota_training_data.json"


def _reward(syntax=1.0, pythonic=1.0, type_correct=1.0, tests=1.0, complexity=1.0) -> Dict[str, float]:
    """Return the reward-signal dict for a Python pair.

    ``composite`` is the unweighted mean of the five component scores,
    rounded to 3 decimal places.
    """
    return {
        "syntax_valid": syntax,
        "pythonic": pythonic,
        "type_correct": type_correct,
        "passes_tests": tests,
        "complexity_appropriate": complexity,
        "composite": round((syntax + pythonic + type_correct + tests + complexity) / 5, 3),
    }


def _gql_reward(schema=1.0, type_correct=1.0, efficiency=1.0, syntax=1.0) -> Dict[str, float]:
    """Return the reward-signal dict for a GraphQL pair.

    ``composite`` is the unweighted mean of the four component scores,
    rounded to 3 decimal places.
    """
    return {
        "schema_valid": schema,
        "type_correct": type_correct,
        "efficiency": efficiency,
        "syntax_valid": syntax,
        "composite": round((schema + type_correct + efficiency + syntax) / 4, 3),
    }


def _make_pairs(
    templates: List[tuple],
    dialect: str,
    reward_fn: Callable[[], Dict[str, float]],
) -> List[Dict]:
    """Convert ``(question, gold_output, category)`` triples into dataset pairs.

    ``reward_fn`` is called once per pair so each record owns its own
    reward-signal dict rather than sharing one mutable object.
    """
    return [
        {"source": q, "target": target, "dialect": dialect, "category": cat, "reward_signals": reward_fn()}
        for q, target, cat in templates
    ]


def _make_mistake_pairs(
    mistakes: List[tuple],
    dialect: str,
    good_reward_fn: Callable[[], Dict[str, float]],
    bad_reward_fn: Callable[[], Dict[str, float]],
) -> List[Dict]:
    """Expand ``(question, bad, good, explanation)`` tuples into two pairs each.

    For every tuple the corrected version is emitted first (category
    ``mistake_correction``), followed by the flawed version (category
    ``common_mistake``) carrying the explanation and a reduced reward.
    """
    pairs: List[Dict] = []
    for q, bad, good, explanation in mistakes:
        pairs.append({"source": q, "target": good, "dialect": dialect,
                      "category": "mistake_correction", "reward_signals": good_reward_fn()})
        pairs.append({"source": q, "target": bad, "dialect": dialect,
                      "category": "common_mistake", "mistake_explanation": explanation,
                      "reward_signals": bad_reward_fn()})
    return pairs


# ── Python Code Generation ──────────────────────────────────────────────────

def _build_python_data_processing() -> List[Dict]:
    """Data processing / pandas patterns."""
    templates = [
        (
            "Write a function to calculate moving average of a list",
            'def moving_average(data: list[float], window: int) -> list[float]:\n """Calculate moving average with given window size."""\n if window <= 0 or window > len(data):\n return []\n result = []\n for i in range(len(data) - window + 1):\n avg = sum(data[i:i + window]) / window\n result.append(round(avg, 4))\n return result',
            "data_processing",
        ),
        (
            "Write a function to merge two sorted lists",
            'def merge_sorted(a: list[int], b: list[int]) -> list[int]:\n """Merge two sorted lists into one sorted list."""\n result = []\n i = j = 0\n while i < len(a) and j < len(b):\n if a[i] <= b[j]:\n result.append(a[i])\n i += 1\n else:\n result.append(b[j])\n j += 1\n result.extend(a[i:])\n result.extend(b[j:])\n return result',
            "data_processing",
        ),
        (
            "Write a function to group items by a key function",
            'from collections import defaultdict\nfrom typing import Callable, TypeVar, Hashable\n\nT = TypeVar("T")\nK = TypeVar("K", bound=Hashable)\n\ndef group_by(items: list[T], key_fn: Callable[[T], K]) -> dict[K, list[T]]:\n """Group items by the result of key_fn."""\n groups: dict[K, list[T]] = defaultdict(list)\n for item in items:\n groups[key_fn(item)].append(item)\n return dict(groups)',
            "data_processing",
        ),
        (
            "Write a function to flatten nested dictionaries",
            # Gold output must annotate with typing.Any — lowercase `any` is the builtin function.
            'from typing import Any\n\ndef flatten_dict(d: dict, prefix: str = "", sep: str = ".") -> dict[str, Any]:\n """Flatten nested dict with dot-separated keys."""\n result = {}\n for key, value in d.items():\n new_key = f"{prefix}{sep}{key}" if prefix else key\n if isinstance(value, dict):\n result.update(flatten_dict(value, new_key, sep))\n else:\n result[new_key] = value\n return result',
            "data_processing",
        ),
        (
            "Write a function to deduplicate preserving order",
            'def deduplicate(items: list) -> list:\n """Remove duplicates while preserving insertion order."""\n seen = set()\n result = []\n for item in items:\n if item not in seen:\n seen.add(item)\n result.append(item)\n return result',
            "data_processing",
        ),
        (
            "Write a CSV parser that handles quoted fields",
            'def parse_csv_line(line: str, delimiter: str = ",") -> list[str]:\n """Parse a single CSV line handling quoted fields with commas."""\n fields = []\n current = []\n in_quotes = False\n for char in line:\n if char == \'"\' and not in_quotes:\n in_quotes = True\n elif char == \'"\' and in_quotes:\n in_quotes = False\n elif char == delimiter and not in_quotes:\n fields.append("".join(current).strip())\n current = []\n else:\n current.append(char)\n fields.append("".join(current).strip())\n return fields',
            "data_processing",
        ),
    ]
    return _make_pairs(templates, "python", _reward)


def _build_python_async() -> List[Dict]:
    """Async/concurrency patterns."""
    templates = [
        (
            "Write an async function to fetch multiple URLs concurrently",
            'import asyncio\nimport aiohttp\n\nasync def fetch_all(urls: list[str], timeout: int = 30) -> list[dict]:\n """Fetch multiple URLs concurrently and return results."""\n async def fetch_one(session: aiohttp.ClientSession, url: str) -> dict:\n try:\n async with session.get(url, timeout=aiohttp.ClientTimeout(total=timeout)) as resp:\n return {"url": url, "status": resp.status, "body": await resp.text()}\n except Exception as e:\n return {"url": url, "status": -1, "error": str(e)}\n\n async with aiohttp.ClientSession() as session:\n tasks = [fetch_one(session, url) for url in urls]\n return await asyncio.gather(*tasks)',
            "async",
        ),
        (
            "Write a rate limiter using asyncio semaphore",
            'import asyncio\nfrom typing import Callable, Awaitable, TypeVar\n\nT = TypeVar("T")\n\nclass RateLimiter:\n """Limit concurrent async operations."""\n\n def __init__(self, max_concurrent: int = 10):\n self._semaphore = asyncio.Semaphore(max_concurrent)\n\n async def execute(self, fn: Callable[..., Awaitable[T]], *args, **kwargs) -> T:\n async with self._semaphore:\n return await fn(*args, **kwargs)',
            "async",
        ),
        (
            "Write a producer-consumer pattern with asyncio queue",
            'import asyncio\nfrom typing import Any, Callable, Awaitable\n\nasync def producer_consumer(\n items: list[Any],\n process_fn: Callable[[Any], Awaitable[Any]],\n n_consumers: int = 5,\n) -> list[Any]:\n """Process items with N concurrent consumers."""\n queue: asyncio.Queue = asyncio.Queue()\n results: list[Any] = []\n\n for item in items:\n await queue.put(item)\n\n async def consumer():\n while not queue.empty():\n try:\n item = queue.get_nowait()\n except asyncio.QueueEmpty:\n break\n result = await process_fn(item)\n results.append(result)\n queue.task_done()\n\n consumers = [asyncio.create_task(consumer()) for _ in range(n_consumers)]\n await asyncio.gather(*consumers)\n return results',
            "async",
        ),
    ]
    return _make_pairs(templates, "python", _reward)


def _build_python_design_patterns() -> List[Dict]:
    """Design patterns in Python."""
    templates = [
        (
            "Implement the Observer pattern in Python",
            'from abc import ABC, abstractmethod\nfrom typing import Any\n\nclass Observer(ABC):\n @abstractmethod\n def update(self, event: str, data: Any) -> None: ...\n\nclass Subject:\n def __init__(self):\n self._observers: list[Observer] = []\n\n def attach(self, observer: Observer) -> None:\n self._observers.append(observer)\n\n def detach(self, observer: Observer) -> None:\n self._observers.remove(observer)\n\n def notify(self, event: str, data: Any = None) -> None:\n for observer in self._observers:\n observer.update(event, data)',
            "design_pattern",
        ),
        (
            "Implement the Strategy pattern in Python",
            'from abc import ABC, abstractmethod\nfrom typing import TypeVar\n\nT = TypeVar("T")\n\nclass Strategy(ABC):\n @abstractmethod\n def execute(self, data: list[float]) -> float: ...\n\nclass MeanStrategy(Strategy):\n def execute(self, data: list[float]) -> float:\n return sum(data) / len(data) if data else 0.0\n\nclass MedianStrategy(Strategy):\n def execute(self, data: list[float]) -> float:\n if not data:\n return 0.0\n s = sorted(data)\n n = len(s)\n return (s[n // 2] + s[(n - 1) // 2]) / 2\n\nclass Aggregator:\n def __init__(self, strategy: Strategy):\n self._strategy = strategy\n\n def aggregate(self, data: list[float]) -> float:\n return self._strategy.execute(data)',
            "design_pattern",
        ),
        (
            "Implement a builder pattern for configuration objects",
            'from dataclasses import dataclass, field\nfrom typing import Optional\n\n@dataclass(frozen=True)\nclass Config:\n host: str\n port: int\n database: str\n user: str\n password: str\n pool_size: int = 5\n timeout: int = 30\n ssl: bool = True\n\nclass ConfigBuilder:\n def __init__(self):\n self._host = "localhost"\n self._port = 5432\n self._database = "default"\n self._user = "admin"\n self._password = ""\n self._pool_size = 5\n self._timeout = 30\n self._ssl = True\n\n def host(self, h: str) -> "ConfigBuilder":\n self._host = h\n return self\n\n def port(self, p: int) -> "ConfigBuilder":\n self._port = p\n return self\n\n def database(self, d: str) -> "ConfigBuilder":\n self._database = d\n return self\n\n def credentials(self, user: str, password: str) -> "ConfigBuilder":\n self._user = user\n self._password = password\n return self\n\n def pool_size(self, n: int) -> "ConfigBuilder":\n self._pool_size = n\n return self\n\n def build(self) -> Config:\n return Config(\n host=self._host, port=self._port, database=self._database,\n user=self._user, password=self._password,\n pool_size=self._pool_size, timeout=self._timeout, ssl=self._ssl,\n )',
            "design_pattern",
        ),
        (
            "Implement a factory pattern for creating database connections",
            'from abc import ABC, abstractmethod\n\nclass Connection(ABC):\n @abstractmethod\n def execute(self, query: str) -> list[dict]: ...\n\n @abstractmethod\n def close(self) -> None: ...\n\nclass PostgresConnection(Connection):\n def __init__(self, dsn: str):\n self._dsn = dsn\n\n def execute(self, query: str) -> list[dict]:\n return [] # placeholder\n\n def close(self) -> None:\n pass\n\nclass SnowflakeConnection(Connection):\n def __init__(self, account: str, user: str, password: str):\n self._account = account\n\n def execute(self, query: str) -> list[dict]:\n return []\n\n def close(self) -> None:\n pass\n\nclass ConnectionFactory:\n _registry: dict[str, type[Connection]] = {\n "postgres": PostgresConnection,\n "snowflake": SnowflakeConnection,\n }\n\n @classmethod\n def create(cls, db_type: str, **kwargs) -> Connection:\n conn_class = cls._registry.get(db_type)\n if not conn_class:\n raise ValueError(f"Unknown db type: {db_type}")\n return conn_class(**kwargs)',
            "design_pattern",
        ),
    ]
    return _make_pairs(templates, "python", _reward)


def _build_python_algorithms() -> List[Dict]:
    """Core algorithms."""
    templates = [
        (
            "Implement binary search",
            'def binary_search(arr: list[int], target: int) -> int:\n """Return index of target in sorted array, or -1 if not found."""\n lo, hi = 0, len(arr) - 1\n while lo <= hi:\n mid = (lo + hi) // 2\n if arr[mid] == target:\n return mid\n elif arr[mid] < target:\n lo = mid + 1\n else:\n hi = mid - 1\n return -1',
            "algorithm",
        ),
        (
            "Implement topological sort using DFS",
            'def topological_sort(graph: dict[str, list[str]]) -> list[str]:\n """Topological sort of a DAG represented as adjacency list."""\n visited: set[str] = set()\n result: list[str] = []\n\n def dfs(node: str) -> None:\n if node in visited:\n return\n visited.add(node)\n for neighbor in graph.get(node, []):\n dfs(neighbor)\n result.append(node)\n\n for node in graph:\n dfs(node)\n result.reverse()\n return result',
            "algorithm",
        ),
        (
            "Implement LRU cache from scratch",
            'from collections import OrderedDict\nfrom typing import TypeVar, Hashable\n\nK = TypeVar("K", bound=Hashable)\nV = TypeVar("V")\n\nclass LRUCache:\n """Least Recently Used cache with O(1) get/put."""\n\n def __init__(self, capacity: int):\n self._capacity = capacity\n self._cache: OrderedDict = OrderedDict()\n\n def get(self, key: K) -> V | None:\n if key not in self._cache:\n return None\n self._cache.move_to_end(key)\n return self._cache[key]\n\n def put(self, key: K, value: V) -> None:\n if key in self._cache:\n self._cache.move_to_end(key)\n self._cache[key] = value\n if len(self._cache) > self._capacity:\n self._cache.popitem(last=False)',
            "algorithm",
        ),
        (
            "Implement Dijkstra's shortest path",
            'import heapq\n\ndef dijkstra(graph: dict[str, list[tuple[str, float]]], start: str) -> dict[str, float]:\n """Shortest paths from start using Dijkstra. graph: {node: [(neighbor, weight)]}."""\n dist: dict[str, float] = {start: 0.0}\n pq: list[tuple[float, str]] = [(0.0, start)]\n\n while pq:\n d, u = heapq.heappop(pq)\n if d > dist.get(u, float("inf")):\n continue\n for v, w in graph.get(u, []):\n new_dist = d + w\n if new_dist < dist.get(v, float("inf")):\n dist[v] = new_dist\n heapq.heappush(pq, (new_dist, v))\n\n return dist',
            "algorithm",
        ),
        (
            "Implement trie data structure",
            'class TrieNode:\n def __init__(self):\n self.children: dict[str, "TrieNode"] = {}\n self.is_end: bool = False\n\nclass Trie:\n def __init__(self):\n self.root = TrieNode()\n\n def insert(self, word: str) -> None:\n node = self.root\n for ch in word:\n if ch not in node.children:\n node.children[ch] = TrieNode()\n node = node.children[ch]\n node.is_end = True\n\n def search(self, word: str) -> bool:\n node = self._find(word)\n return node is not None and node.is_end\n\n def starts_with(self, prefix: str) -> bool:\n return self._find(prefix) is not None\n\n def _find(self, prefix: str) -> TrieNode | None:\n node = self.root\n for ch in prefix:\n if ch not in node.children:\n return None\n node = node.children[ch]\n return node',
            "algorithm",
        ),
    ]
    return _make_pairs(templates, "python", _reward)


def _build_python_testing() -> List[Dict]:
    """Testing patterns."""
    templates = [
        (
            "Write pytest tests for a calculator class",
            'import pytest\n\nclass Calculator:\n def add(self, a: float, b: float) -> float:\n return a + b\n\n def divide(self, a: float, b: float) -> float:\n if b == 0:\n raise ValueError("Cannot divide by zero")\n return a / b\n\nclass TestCalculator:\n @pytest.fixture\n def calc(self) -> Calculator:\n return Calculator()\n\n def test_add(self, calc: Calculator) -> None:\n assert calc.add(2, 3) == 5\n assert calc.add(-1, 1) == 0\n assert calc.add(0, 0) == 0\n\n def test_divide(self, calc: Calculator) -> None:\n assert calc.divide(10, 2) == 5.0\n assert calc.divide(7, 2) == 3.5\n\n def test_divide_by_zero(self, calc: Calculator) -> None:\n with pytest.raises(ValueError, match="Cannot divide by zero"):\n calc.divide(1, 0)\n\n @pytest.mark.parametrize("a,b,expected", [(1, 1, 2), (0, 0, 0), (-1, -1, -2)])\n def test_add_parametrized(self, calc: Calculator, a, b, expected) -> None:\n assert calc.add(a, b) == expected',
            "testing",
        ),
        (
            "Write a mock-based test for an API client",
            # `session.get(...)` is entered with `async with`, not awaited, so it must be a
            # MagicMock returning the context manager; an AsyncMock would hand back a
            # coroutine and the gold test would raise TypeError.
            'from unittest.mock import AsyncMock, MagicMock, patch\nimport pytest\n\nclass APIClient:\n def __init__(self, base_url: str):\n self.base_url = base_url\n\n async def get_user(self, user_id: str) -> dict:\n import aiohttp\n async with aiohttp.ClientSession() as session:\n async with session.get(f"{self.base_url}/users/{user_id}") as resp:\n return await resp.json()\n\n@pytest.mark.asyncio\nasync def test_get_user():\n client = APIClient("https://api.example.com")\n mock_response = {"id": "123", "name": "Alice"}\n\n with patch("aiohttp.ClientSession") as mock_session:\n mock_resp = AsyncMock()\n mock_resp.json = AsyncMock(return_value=mock_response)\n mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)\n mock_resp.__aexit__ = AsyncMock(return_value=False)\n\n # session.get() is used as an async context manager (not awaited), so it must be a MagicMock\n mock_get = MagicMock(return_value=mock_resp)\n mock_session_inst = AsyncMock()\n mock_session_inst.get = mock_get\n mock_session_inst.__aenter__ = AsyncMock(return_value=mock_session_inst)\n mock_session_inst.__aexit__ = AsyncMock(return_value=False)\n mock_session.return_value = mock_session_inst\n\n result = await client.get_user("123")\n assert result == mock_response',
            "testing",
        ),
    ]
    return _make_pairs(templates, "python", _reward)


def _build_python_mistakes() -> List[Dict]:
    """Common Python mistakes (each yields a corrected and a flawed pair)."""
    mistakes = [
        (
            "Write function with default list parameter",
            'def append_to(item, target=[]):\n target.append(item)\n return target',
            'def append_to(item, target: list | None = None) -> list:\n if target is None:\n target = []\n target.append(item)\n return target',
            "Mutable default arguments are shared across calls",
        ),
        (
            "Write bare except handler",
            'try:\n result = process(data)\nexcept:\n pass',
            'try:\n result = process(data)\nexcept ValueError as e:\n logger.warning("Invalid data: %s", e)\n result = default_value',
            "Never use bare except — catch specific exceptions",
        ),
        (
            "String concatenation in a loop",
            'def build_report(items):\n result = ""\n for item in items:\n result += str(item) + "\\n"\n return result',
            'def build_report(items: list) -> str:\n return "\\n".join(str(item) for item in items)',
            "Use str.join() instead of += in loops for O(n) vs O(n²)",
        ),
        (
            "Not using context manager for file",
            'def read_file(path):\n f = open(path)\n data = f.read()\n f.close()\n return data',
            'def read_file(path: str) -> str:\n with open(path) as f:\n return f.read()',
            "Always use context managers (with statement) for file I/O",
        ),
    ]
    return _make_mistake_pairs(
        mistakes,
        "python",
        _reward,
        lambda: _reward(syntax=0.8, pythonic=0.0, complexity=0.3),
    )


# ── GraphQL Dataset ──────────────────────────────────────────────────────────

def _build_graphql_queries() -> List[Dict]:
    """GraphQL query patterns."""
    templates = [
        (
            "Get user by ID with their posts",
            'query GetUser($userId: ID!) {\n user(id: $userId) {\n id\n name\n email\n posts(first: 10, orderBy: CREATED_AT_DESC) {\n edges {\n node {\n id\n title\n content\n createdAt\n }\n }\n pageInfo {\n hasNextPage\n endCursor\n }\n }\n }\n}',
            "query",
        ),
        (
            "Search products with filtering and pagination",
            'query SearchProducts($query: String!, $category: Category, $first: Int = 20, $after: String) {\n searchProducts(query: $query, filter: { category: $category }, first: $first, after: $after) {\n edges {\n node {\n id\n name\n price\n category\n rating\n reviewCount\n }\n }\n totalCount\n pageInfo {\n hasNextPage\n endCursor\n }\n }\n}',
            "query",
        ),
        (
            "Get dashboard analytics data",
            'query DashboardAnalytics($dateRange: DateRangeInput!) {\n analytics(dateRange: $dateRange) {\n totalRevenue\n orderCount\n averageOrderValue\n conversionRate\n topProducts(limit: 5) {\n product {\n id\n name\n }\n revenue\n unitsSold\n }\n revenueByDay {\n date\n amount\n }\n }\n}',
            "query",
        ),
        # NOTE(review): `@cypher` appears to be a type-definition (schema) directive in the
        # Neo4j GraphQL Library, so using it inside an executable query document will likely
        # fail validation — confirm, and consider re-expressing this gold as a type definition.
        (
            "Get Neo4j graph data with Cypher resolver",
            'query GetMovieNetwork($movieTitle: String!) {\n movies(where: { title: $movieTitle }) {\n title\n released\n actors {\n name\n born\n }\n directors {\n name\n }\n similarMovies @cypher(statement: """\n MATCH (this)<-[:ACTED_IN]-(:Person)-[:ACTED_IN]->(other:Movie)\n WHERE other <> this\n RETURN DISTINCT other\n LIMIT 5\n """) {\n title\n released\n }\n }\n}',
            "cypher_resolver",
        ),
    ]
    return _make_pairs(templates, "graphql", _gql_reward)


def _build_graphql_mutations() -> List[Dict]:
    """GraphQL mutation patterns."""
    templates = [
        (
            "Create a new user account",
            'mutation CreateUser($input: CreateUserInput!) {\n createUser(input: $input) {\n user {\n id\n name\n email\n createdAt\n }\n errors {\n field\n message\n }\n }\n}',
            "mutation",
        ),
        (
            "Place an order with multiple items",
            'mutation PlaceOrder($input: PlaceOrderInput!) {\n placeOrder(input: $input) {\n order {\n id\n status\n totalAmount\n items {\n product {\n id\n name\n }\n quantity\n unitPrice\n }\n shippingAddress {\n street\n city\n state\n zipCode\n }\n }\n errors {\n field\n message\n }\n }\n}',
            "mutation",
        ),
        (
            "Update user profile with optimistic locking",
            'mutation UpdateProfile($id: ID!, $input: UpdateProfileInput!, $version: Int!) {\n updateProfile(id: $id, input: $input, expectedVersion: $version) {\n profile {\n id\n displayName\n bio\n avatarUrl\n version\n }\n errors {\n field\n message\n code\n }\n }\n}',
            "mutation",
        ),
        (
            "Create Neo4j relationship via GraphQL",
            'mutation ConnectActorToMovie($actorName: String!, $movieTitle: String!, $role: String!) {\n createActedInRelationship(\n input: {\n actor: { where: { name: $actorName } }\n movie: { where: { title: $movieTitle } }\n edge: { role: $role }\n }\n ) {\n actors {\n name\n }\n movies {\n title\n }\n }\n}',
            "mutation",
        ),
    ]
    return _make_pairs(templates, "graphql", _gql_reward)


def _build_graphql_subscriptions() -> List[Dict]:
    """GraphQL subscription patterns."""
    templates = [
        (
            "Subscribe to order status updates",
            'subscription OrderUpdates($orderId: ID!) {\n orderStatusChanged(orderId: $orderId) {\n order {\n id\n status\n updatedAt\n estimatedDelivery\n }\n previousStatus\n newStatus\n }\n}',
            "subscription",
        ),
        (
            "Subscribe to real-time sensor alerts",
            'subscription SensorAlerts($deviceIds: [ID!]!, $minSeverity: AlertSeverity = WARNING) {\n sensorAlert(deviceIds: $deviceIds, minSeverity: $minSeverity) {\n alert {\n id\n deviceId\n severity\n message\n reading {\n sensorType\n value\n unit\n timestamp\n }\n }\n }\n}',
            "subscription",
        ),
    ]
    return _make_pairs(templates, "graphql", _gql_reward)


def _build_graphql_fragments() -> List[Dict]:
    """Fragment and directive patterns."""
    templates = [
        (
            "Use fragments for reusable user fields",
            'fragment UserFields on User {\n id\n name\n email\n avatarUrl\n}\n\nfragment UserWithPosts on User {\n ...UserFields\n posts(first: 5) {\n edges {\n node {\n id\n title\n createdAt\n }\n }\n }\n}\n\nquery GetUsers {\n users(first: 20) {\n edges {\n node {\n ...UserWithPosts\n }\n }\n }\n}',
            "fragment",
        ),
        (
            "Conditional fields with directives",
            # Only executable directives (@include / @skip) are valid in a query. The former
            # `legacyField @deprecated(...)` selection was removed: @deprecated is a
            # type-system directive and fails validation on a FIELD selection.
            'query GetProduct($id: ID!, $includeReviews: Boolean!, $includeInventory: Boolean!) {\n product(id: $id) {\n id\n name\n price\n description\n reviews @include(if: $includeReviews) {\n edges {\n node {\n rating\n comment\n author {\n name\n }\n }\n }\n }\n inventory @include(if: $includeInventory) {\n warehouse\n quantity\n lastUpdated\n }\n }\n}',
            "directive",
        ),
    ]
    return _make_pairs(templates, "graphql", _gql_reward)


def _build_graphql_federation() -> List[Dict]:
    """Apollo Federation patterns."""
    templates = [
        (
            "Define federated product type with key",
            'type Product @key(fields: "id") {\n id: ID!\n name: String!\n price: Float!\n category: Category!\n}\n\nextend type Query {\n product(id: ID!): Product\n products(first: Int, after: String, filter: ProductFilter): ProductConnection!\n}',
            "federation_schema",
        ),
        (
            "Extend product type from another service",
            'type Product @key(fields: "id") @extends {\n id: ID! @external\n reviews: [Review!]!\n averageRating: Float!\n reviewCount: Int!\n}\n\ntype Review {\n id: ID!\n rating: Int!\n comment: String\n author: User!\n createdAt: DateTime!\n}',
            "federation_extend",
        ),
        (
            "Query across federated services",
            'query GetProductWithReviews($productId: ID!) {\n product(id: $productId) {\n id\n name\n price\n category\n reviews {\n rating\n comment\n author {\n name\n avatarUrl\n }\n }\n averageRating\n inventory {\n warehouse\n quantity\n }\n }\n}',
            "federation_query",
        ),
    ]
    return _make_pairs(templates, "graphql", _gql_reward)


def _build_graphql_mistakes() -> List[Dict]:
    """Common GraphQL mistakes (each yields a corrected and a flawed pair)."""
    mistakes = [
        (
            "Query user without required argument",
            '{ user { name email } }',
            'query GetUser($userId: ID!) {\n user(id: $userId) {\n name\n email\n }\n}',
            "Missing required arguments — user needs id parameter",
        ),
        (
            "N+1 query pattern",
            'query { users { name posts { comments { author { name } } } } }',
            'query GetUsersWithPosts {\n users(first: 20) {\n edges {\n node {\n name\n posts(first: 10) {\n edges {\n node {\n title\n commentCount\n }\n }\n }\n }\n }\n }\n}',
            "Deeply nested queries cause N+1 — limit depth, use pagination",
        ),
        (
            "Mutation without error handling",
            'mutation { createUser(name: "Alice") { id } }',
            'mutation CreateUser($input: CreateUserInput!) {\n createUser(input: $input) {\n user {\n id\n name\n }\n errors {\n field\n message\n }\n }\n}',
            "Mutations should return union of result + errors, use input types",
        ),
    ]
    return _make_mistake_pairs(
        mistakes,
        "graphql",
        _gql_reward,
        lambda: _gql_reward(schema=0.3, efficiency=0.2),
    )


# ── Main ─────────────────────────────────────────────────────────────────────

def build_all() -> List[Dict]:
    """Run every builder, print a per-builder summary, and return all pairs in order."""
    builders = [
        _build_python_data_processing,
        _build_python_async,
        _build_python_design_patterns,
        _build_python_algorithms,
        _build_python_testing,
        _build_python_mistakes,
        _build_graphql_queries,
        _build_graphql_mutations,
        _build_graphql_subscriptions,
        _build_graphql_fragments,
        _build_graphql_federation,
        _build_graphql_mistakes,
    ]
    all_pairs: List[Dict] = []
    for builder in builders:
        pairs = builder()
        # A builder can emit several categories (the mistake builders emit both
        # "mistake_correction" and "common_mistake"), so list each one once, in order.
        cats = ", ".join(dict.fromkeys(p["category"] for p in pairs)) or "unknown"
        print(f" {builder.__name__}: {len(pairs)} pairs ({cats})")
        all_pairs.extend(pairs)
    return all_pairs


def main():
    """Build all pairs, print dialect counts, and write them to RESULTS_DIR as JSON."""
    print("=== Building Python + GraphQL Generation Datasets ===\n")
    pairs = build_all()

    py_count = sum(1 for p in pairs if p["dialect"] == "python")
    gql_count = sum(1 for p in pairs if p["dialect"] == "graphql")
    print(f"\nTotal: {len(pairs)} pairs (Python: {py_count}, GraphQL: {gql_count})")

    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    out_path = RESULTS_DIR / "python_graphql_dataset.json"
    # Explicit encoding so the output does not depend on the platform's locale default.
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(pairs, f, indent=2)
    print(f"Saved → {out_path}")
    print("Run scripts/combine_and_push_datasets.py to merge into SOTA data")
    return pairs


if __name__ == "__main__":
    main()