{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Phase 3: Extract Python Test Cases from HuggingFace Dataset\n", "\n", "This notebook loads the `newfacade/LeetCodeDataset` from HuggingFace and extracts the Python test function (`def check(candidate): ...`) of every problem in the `train` split; the `test` split is not processed.\n", "\n", "For each problem it normalizes indentation (the `def check` line is unindented and its body is indented by 4 spaces), keeps at most the first 30 `assert` lines, and writes one JSON object with `problem_id` and `test` fields per line to `python_tests.jsonl`." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "✅ Dependencies installed\n" ] } ], "source": [ "# Install required packages\n", "import subprocess\n", "import sys\n", "\n", "# Install datasets if not already installed\n", "subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"datasets\", \"-q\"])\n", "print(\"✅ Dependencies installed\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loading HuggingFace dataset: newfacade/LeetCodeDataset...\n", "✅ Dataset loaded: DatasetDict({\n", " train: Dataset({\n", " features: ['task_id', 'question_id', 'difficulty', 'tags', 'problem_description', 'starter_code', 'estimated_date', 'prompt', 'completion', 'entry_point', 'test', 'input_output', 'query', 'response'],\n", " num_rows: 2641\n", " })\n", " test: Dataset({\n", " features: ['task_id', 'question_id', 'difficulty', 'tags', 'problem_description', 'starter_code', 'estimated_date', 'prompt', 'completion', 'entry_point', 'test', 'input_output', 'query', 'response'],\n", " num_rows: 228\n", " })\n", "})\n", "\n", "Dataset keys: dict_keys(['train', 'test'])\n" ] } ], "source": [ "import json\n", "from datasets import load_dataset\n", "from pathlib import Path\n", "\n", "print(\"Loading HuggingFace dataset: newfacade/LeetCodeDataset...\")\n", "dataset = 
load_dataset(\"newfacade/LeetCodeDataset\")\n", "print(f\"✅ Dataset loaded: {dataset}\")\n", "print(f\"\\nDataset keys: {dataset.keys()}\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "First sample keys: dict_keys(['task_id', 'question_id', 'difficulty', 'tags', 'problem_description', 'starter_code', 'estimated_date', 'prompt', 'completion', 'entry_point', 'test', 'input_output', 'query', 'response'])\n", "\n", "First sample:\n", " task_id: two-sum\n", " question_id: 1\n", " difficulty: Easy\n", " tags: ['Array', 'Hash Table']\n", " problem_description: Given an array of integers nums and an integer target, return indices of the two numbers such that t...\n", " starter_code: class Solution:\n", " def twoSum(self, nums: List[int], target: int) -> List[int]:\n", " \n", " estimated_date: 2015-08-07 00:00:00\n", " prompt: import random\n", "import functools\n", "import collections\n", "import string\n", "import math\n", "import datetime\n", "\n", "from ty...\n", " completion: class Solution:\n", " def twoSum(self, nums: List[int], target: int) -> List[int]:\n", " d = {}\n", " ...\n", " entry_point: Solution().twoSum\n", " test: def check(candidate):\n", " assert candidate(nums = [3, 3],target = 6) == [0, 1]\n", " assert candidate(...\n", " input_output: [{'input': 'nums = [3,3], target = 6', 'output': '[0, 1]'}, {'input': 'nums = [-1,-2,-3,-4], target = -8', 'output': 'None'}, {'input': 'nums = [1000000000, 1000000000], target = 2000000000', 'output': '[0, 1]'}, {'input': 'nums = [1,5,7,9], target = 10', 'output': '[0, 3]'}, {'input': 'nums = [1,2,3,4,5,6,7,8,9,10], target = 3', 'output': '[0, 1]'}, {'input': 'nums = [0,4,3,0], target = 0', 'output': '[0, 3]'}, {'input': 'nums = [1000000000, -1000000000, 500000000, -500000000], target = 0', 'output': '[0, 1]'}, {'input': 'nums = [1,2,3,4,5,6,7,8,9,10], target = 17', 'output': '[7, 8]'}, {'input': 'nums = [1,5,7,8], target = 15', 
'output': '[2, 3]'}, {'input': 'nums = [1000000000, -1000000000], target = 0', 'output': '[0, 1]'}, {'input': 'nums = [2,7,11,15], target = 9', 'output': '[0, 1]'}, {'input': 'nums = [1,2,3,4,5,6,7,8,9,10], target = 19', 'output': '[8, 9]'}, {'input': 'nums = [1,5,7,11], target = 16', 'output': '[1, 3]'}, {'input': 'nums = [5,5,5,5,5,5,5,5,5,5], target = 10', 'output': '[0, 1]'}, {'input': 'nums = [3,2,4], target = 6', 'output': '[1, 2]'}, {'input': 'nums = [15,11,7,2], target = 9', 'output': '[2, 3]'}, {'input': 'nums = [1000000000,-1000000000,2000000000,-2000000000], target = 0', 'output': '[0, 1]'}, {'input': 'nums = [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900, 2000, 2100, 2200, 2300, 2400, 2500, 2600, 2700, 2800, 2900, 3000], target = 4000', 'output': '[18, 20]'}, {'input': 'nums = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75, 77, 79, 81, 83, 85, 87, 89, 91, 93, 95, 97, 99], target = 100', 'output': '[24, 25]'}, {'input': 'nums = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100], target = 199', 'output': '[98, 99]'}, {'input': 'nums = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5], target = 0', 'output': '[4, 6]'}, {'input': 'nums = [-10, -20, -30, -40, -50, -60, -70, -80, -90, -100], target = -150', 'output': '[6, 7]'}, {'input': 'nums = [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20], target = -39', 'output': '[18, 19]'}, {'input': 'nums = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 
47, 49, 51, 53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75, 77, 79, 81, 83, 85, 87, 89, 91, 93, 95, 97, 99, 101, 103, 105, 107, 109, 111, 113, 115, 117, 119], target = 110', 'output': '[26, 28]'}, {'input': 'nums = [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], target = 1100', 'output': '[4, 5]'}, {'input': 'nums = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39], target = 70', 'output': '[16, 18]'}, {'input': 'nums = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2], target = 3', 'output': '[18, 19]'}, {'input': 'nums = [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900, 2000, 2100, 2200, 2300, 2400, 2500], target = 3000', 'output': '[13, 15]'}, {'input': 'nums = [1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000, 11000, 12000, 13000, 14000, 15000, 16000, 17000, 18000, 19000, 20000], target = 30000', 'output': '[13, 15]'}, {'input': 'nums = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], target = 39', 'output': '[18, 19]'}, {'input': 'nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99], target = 99', 'output': '[49, 50]'}, {'input': 'nums = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200, 210, 220, 230, 240, 250, 260, 270, 280, 290, 300, 310, 320, 330, 340, 350, 360, 370, 380, 390, 400, 410, 420, 430, 440, 450, 460, 470, 480, 490, 500], target = 900', 'output': '[43, 45]'}, {'input': 'nums = [-1000000000, -2000000000, -3000000000, -4000000000, -5000000000, -6000000000, -7000000000, -8000000000, -9000000000, -10000000000], 
target = -15000000000', 'output': '[6, 7]'}, {'input': 'nums = [23, 8, 15, 37, 48, 5, 21, 7, 40, 6], target = 33', 'output': 'None'}, {'input': 'nums = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200], target = 390', 'output': '[18, 19]'}, {'input': 'nums = [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000], target = 1500', 'output': '[6, 7]'}, {'input': 'nums = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], target = 21', 'output': '[9, 10]'}, {'input': 'nums = [2, 5, 1, 9, 3, 8, 7, 6, 4, 0], target = 17', 'output': '[3, 5]'}, {'input': 'nums = [1000000000, -1000000000, 500000000, 500000000], target = 0', 'output': '[0, 1]'}, {'input': 'nums = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], target = 29', 'output': '[13, 14]'}, {'input': 'nums = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], target = 38', 'output': '[17, 19]'}, {'input': 'nums = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39], target = 78', 'output': 'None'}, {'input': 'nums = [-1000000000, 1000000000, 500000000, -500000000], target = 0', 'output': '[0, 1]'}, {'input': 'nums = [1000000000, 999999999, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], target = 1999999999', 'output': '[0, 1]'}, {'input': 'nums = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200], target = 300', 'output': '[13, 15]'}, {'input': 'nums = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59], target = 100', 'output': '[24, 25]'}, {'input': 'nums = [5, 12, 7, 3, 9, 14, 10, 23, 1, 11], target = 22', 'output': '[1, 6]'}, {'input': 'nums = [-3, 4, 3, 90, -11, 23, -5, 67, 100, -45, 89], target = 53', 'output': 'None'}, {'input': 'nums = [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30, -31, -32, -33, -34, -35, -36, -37, -38, 
-39, -40, -41, -42, -43, -44, -45, -46, -47, -48, -49, -50, -51, -52, -53, -54, -55, -56, -57, -58, -59, -60, -61, -62, -63, -64, -65, -66, -67, -68, -69, -70, -71, -72, -73, -74, -75, -76, -77, -78, -79, -80, -81, -82, -83, -84, -85, -86, -87, -88, -89, -90, -91, -92, -93, -94, -95, -96, -97, -98, -99, -100], target = -199', 'output': '[98, 99]'}, {'input': 'nums = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1], target = 1', 'output': '[0, 1]'}, {'input': 'nums = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2], target = 3', 'output': '[23, 24]'}, {'input': 'nums = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200, 210, 220, 230, 240, 250, 260, 270, 280, 290, 300, 310, 320, 330, 340, 350, 360, 370, 380, 390, 400, 410, 420, 430, 440, 450, 460, 470, 480, 490, 500, 510, 520, 530, 540, 550, 560, 570, 580, 590, 600, 610, 620, 630, 640, 650, 660, 670, 680, 690, 700, 710, 720, 730, 740, 750, 760, 770, 780, 790, 800, 810, 820, 830, 840, 850, 860, 870, 880, 890, 900, 910, 920, 930, 940, 950, 960, 970, 980, 990, 1000], target = 1990', 'output': '[98, 99]'}, {'input': 'nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30], target = 59', 'output': '[29, 30]'}, {'input': 'nums = [123456789, 987654321, 456789123, 321987654, 654321987, 789123456], target = 1111111110', 'output': '[0, 1]'}, {'input': 'nums = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], target = 0', 'output': '[0, 1]'}, {'input': 'nums = [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10], target = -11', 'output': '[4, 5]'}, {'input': 'nums = [999999999, 999999998, 999999997, 999999996, 999999995, 999999994, 999999993, 999999992], target = 1999999997', 'output': '[0, 1]'}, {'input': 'nums = [-1,-2,-3,-4,-5,-6,-7,-8,-9,-10], target = -15', 'output': '[6, 7]'}, {'input': 'nums = [1000000000, -1000000000, 500000000, 500000000, -500000000, -500000000, 1, 2, 3, 4], 
target = 0', 'output': '[0, 1]'}, {'input': 'nums = [1000000000, 1000000000, 1000000000, 1000000000, 1000000000, 1000000000, 1000000000, 1000000000, 1000000000, 1000000000], target = 2000000000', 'output': '[0, 1]'}, {'input': 'nums = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150], target = 299', 'output': '[148, 149]'}, {'input': 'nums = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30], target = 59', 'output': '[28, 29]'}, {'input': 'nums = [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20], target = -31', 'output': '[14, 15]'}, {'input': 'nums = [1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000, 11000, 12000, 13000, 14000, 15000, 16000, 17000, 18000, 19000, 20000, 21000, 22000, 23000, 24000, 25000, 26000, 27000, 28000, 29000, 30000], target = 60000', 'output': 'None'}, {'input': 'nums = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20], target = 39', 'output': '[18, 19]'}, {'input': 'nums = [1000000000, -500000000, 2000000000, -1000000000, 0, 500000000], target = 1000000000', 'output': '[2, 3]'}, {'input': 'nums = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50], target = 99', 'output': '[48, 49]'}, {'input': 'nums = [1, 
2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25], target = 49', 'output': '[23, 24]'}, {'input': 'nums = [999999999, 999999998, 999999997, 999999996, 999999995, 999999994, 999999993, 999999992, 999999991, 999999990], target = 1999999989', 'output': '[4, 5]'}, {'input': 'nums = [1,2,4,8,16,32,64,128,256,512], target = 513', 'output': '[0, 9]'}, {'input': 'nums = [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10], target = -18', 'output': '[7, 9]'}, {'input': 'nums = [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200, 1300, 1400, 1500], target = 1300', 'output': '[5, 6]'}, {'input': 'nums = [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10], target = -17', 'output': '[7, 8]'}, {'input': 'nums = [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000], target = 1900', 'output': '[8, 9]'}, {'input': 'nums = [0,0,0,0,0,0,0,0,0,0], target = 0', 'output': '[0, 1]'}, {'input': 'nums = [1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31,33,35,37,39], target = 79', 'output': 'None'}, {'input': 'nums = [1000000000, 2000000000, 3000000000, 4000000000, 5000000000, 6000000000, 7000000000, 8000000000, 9000000000, 10000000000], target = 30000000000', 'output': 'None'}, {'input': 'nums = [-3, -1, 0, 2, 5, 7, 8, 10], target = 4', 'output': '[1, 4]'}, {'input': 'nums = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50], target = 100', 'output': 'None'}, {'input': 'nums = [29, 37, 10, 55, 44, 3, 67, 90, 11, 38, 2, 9, 100, 34, 65, 23, 89, 12, 33, 22], target = 62', 'output': '[0, 18]'}]\n", " query: You are an expert Python programmer. 
You will be given a question (problem specification) and will g...\n",
    " response: To solve this problem efficiently, we can use a hash map (dictionary in Python) to store the numbers...\n"
   ]
  }
 ],
 "source": [
  "# Inspect the first sample to understand the dataset structure\n",
  "if \"train\" in dataset:\n",
  "    sample_split = dataset[\"train\"]\n",
  "else:\n",
  "    sample_split = list(dataset.values())[0]\n",
  "\n",
  "# Field values longer than this are truncated in the preview\n",
  "PREVIEW_CHARS = 100\n",
  "\n",
  "first_sample = sample_split[0]\n",
  "print(f\"First sample keys: {first_sample.keys()}\")\n",
  "print(f\"\\nFirst sample:\")\n",
  "for key, value in first_sample.items():\n",
  "    # Truncate the *string form* of every value: checking only\n",
  "    # `isinstance(value, str)` let long non-string fields such as the\n",
  "    # `input_output` list be printed in full.\n",
  "    text = value if isinstance(value, str) else str(value)\n",
  "    if len(text) > PREVIEW_CHARS:\n",
  "        print(f\" {key}: {text[:PREVIEW_CHARS]}...\")\n",
  "    else:\n",
  "        print(f\" {key}: {text}\")"
 ]
},
{
 "cell_type": "code",
 "execution_count": 4,
 "metadata": {},
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "Before:\n",
    "'def check(candidate):\\nassert candidate(nums = [3, 3],target = 6) == [0, 1]\\nassert candidate(nums = [-1, -2, -3, -4],target = -8) == None'\n",
    "\n",
    "After:\n",
    "'def check(candidate):\\n    assert candidate(nums = [3, 3],target = 6) == [0, 1]\\n    assert candidate(nums = [-1, -2, -3, -4],target = -8) == None'\n",
    "\n",
    "Formatted:\n",
    "def check(candidate):\n",
    "    assert candidate(nums = [3, 3],target = 6) == [0, 1]\n",
    "    assert candidate(nums = [-1, -2, -3, -4],target = -8) == None\n"
   ]
  }
 ],
 "source": [
  "def fix_test_indentation(test_code: str) -> str:\n",
  "    \"\"\"\n",
  "    Normalize the indentation of a `def check(candidate):` test function.\n",
  "\n",
  "    The `def check` line is emitted with no indentation, and every other\n",
  "    non-blank line is indented by 4 spaces relative to the check body's base\n",
  "    indentation (the smallest indentation among the lines after `def check`).\n",
  "    A flat body such as\n",
  "\n",
  "        def check(candidate):\n",
  "        assert candidate(...) == ...\n",
  "\n",
  "    therefore becomes\n",
  "\n",
  "        def check(candidate):\n",
  "            assert candidate(...) == ...\n",
  "\n",
  "    An already-indented body comes out unchanged. Statements nested inside\n",
  "    the body (e.g. a `for` loop) keep their extra indentation instead of being\n",
  "    flattened to one level, which would produce invalid Python.\n",
  "    Blank lines and trailing whitespace (including carriage returns left by\n",
  "    CRLF line endings) are dropped.\n",
  "\n",
  "    Args:\n",
  "        test_code: Source text of the test function.\n",
  "\n",
  "    Returns:\n",
  "        The re-indented source, lines joined by newlines with no trailing\n",
  "        newline; an empty string if `test_code` has no non-blank lines.\n",
  "    \"\"\"\n",
  "    lines = [line.rstrip() for line in test_code.split(\"\\n\") if line.strip()]\n",
  "\n",
  "    def is_header(line: str) -> bool:\n",
  "        return line.lstrip().startswith(\"def check\")\n",
  "\n",
  "    def indent_of(line: str) -> int:\n",
  "        return len(line) - len(line.lstrip())\n",
  "\n",
  "    # Base indentation of the check body: the smallest indent among the lines\n",
  "    # after the `def check` header (all non-header lines if there is none).\n",
  "    header_pos = next((pos for pos, line in enumerate(lines) if is_header(line)), -1)\n",
  "    body_indents = [indent_of(line) for line in lines[header_pos + 1:] if not is_header(line)]\n",
  "    base_indent = min(body_indents, default=0)\n",
  "\n",
  "    fixed_lines = []\n",
  "    for line in lines:\n",
  "        if is_header(line):\n",
  "            fixed_lines.append(line.strip())\n",
  "        else:\n",
  "            # Remove the base indentation (never more than the line has), then\n",
  "            # add one 4-space level; deeper lines keep their extra indentation.\n",
  "            fixed_lines.append(\"    \" + line[min(base_indent, indent_of(line)):])\n",
  "\n",
  "    return \"\\n\".join(fixed_lines)\n",
  "\n",
  "# Demonstrate the fix on a flat (unindented) body\n",
  "test_input = \"\"\"def check(candidate):\n",
  "assert candidate(nums = [3, 3],target = 6) == [0, 1]\n",
  "assert candidate(nums = [-1, -2, -3, -4],target = -8) == None\"\"\"\n",
  "\n",
  "print(\"Before:\")\n",
  "print(repr(test_input))\n",
  "print(\"\\nAfter:\")\n",
  "print(repr(fix_test_indentation(test_input)))\n",
  "print(\"\\nFormatted:\")\n",
  "print(fix_test_indentation(test_input))\n",
  "\n",
  "# Sanity checks: a flat body gets one indentation level, a nested body stays valid Python\n",
  "assert fix_test_indentation(test_input) == (\n",
  "    \"def check(candidate):\\n\"\n",
  "    \"    assert candidate(nums = [3, 3],target = 6) == [0, 1]\\n\"\n",
  "    \"    assert candidate(nums = [-1, -2, -3, -4],target = -8) == None\"\n",
  ")\n",
  "nested_input = \"def check(candidate):\\nfor n in [1, 2]:\\n    assert candidate(n) == n\"\n",
  "compile(fix_test_indentation(nested_input), \"<nested_input>\", \"exec\")"
 ]
},
{
 "cell_type": "code",
 "execution_count": 5,
 "metadata": {},
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "Processing 2641 samples...\n"
   ]
  },
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    " Processed 500/2641 samples... (extracted: 500, skipped: 0)\n",
    " Processed 1000/2641 samples... (extracted: 1000, skipped: 0)\n",
    " Processed 1500/2641 samples... (extracted: 1500, skipped: 0)\n",
    " Processed 2000/2641 samples... (extracted: 2000, skipped: 0)\n",
    " Processed 2500/2641 samples... (extracted: 2500, skipped: 0)\n",
    "\n",
    "✅ Extraction complete!\n",
    " Extracted: 2641 test cases\n",
    " Skipped: 0\n",
    " Errors: 0\n",
    " Output: /teamspace/studios/this_studio/dataset/complexity_reasoning_data/python_tests.jsonl\n"
   ]
  }
 ],
 "source": [
  "# Extract test cases and save to JSONL\n",
  "output_path = Path(\"/teamspace/studios/this_studio/dataset/complexity_reasoning_data/python_tests.jsonl\")\n",
  "output_path.parent.mkdir(parents=True, exist_ok=True)\n",
  "\n",
  "MAX_ASSERT_LINES = 30  # Keep only the first 30 top-level assert lines per problem\n",
  "\n",
  "\n",
  "def build_limited_test(test_code: str, max_asserts: int = MAX_ASSERT_LINES) -> \"str | None\":\n",
  "    \"\"\"\n",
  "    Normalize `test_code` with `fix_test_indentation` and keep its `def check`\n",
  "    header plus at most `max_asserts` top-level assert lines.\n",
  "\n",
  "    The header is located by content instead of being assumed to be the first\n",
  "    line, so a test with lines before `def check` keeps its real header.\n",
  "    Only asserts at exactly one indentation level are kept: dropping the\n",
  "    enclosing statement of a nested assert would leave invalid code.\n",
  "\n",
  "    Returns:\n",
  "        The limited test source, or None if no `def check` line is found.\n",
  "    \"\"\"\n",
  "    fixed_lines = fix_test_indentation(test_code).split(\"\\n\")\n",
  "    # fix_test_indentation leaves only `def check` lines at column 0\n",
  "    header = next((line for line in fixed_lines if line.startswith(\"def check\")), None)\n",
  "    if header is None:\n",
  "        return None\n",
  "    top_level_asserts = [line for line in fixed_lines if line.startswith(\"    assert\")]\n",
  "    return \"\\n\".join([header] + top_level_asserts[:max_asserts])\n",
  "\n",
  "\n",
  "# Determine which split to use\n",
  "if \"train\" in dataset:\n",
  "    data_split = dataset[\"train\"]\n",
  "else:\n",
  "    data_split = list(dataset.values())[0]\n",
  "\n",
  "print(f\"Processing {len(data_split)} samples...\")\n",
  "\n",
  "extracted_count = 0\n",
  "skipped_count = 0\n",
  "error_count = 0\n",
  "\n",
  "with open(output_path, \"w\") as f:\n",
  "    for i, sample in enumerate(data_split):\n",
  "        try:\n",
  "            # Try several field names for the problem id and the test source\n",
  "            problem_id = sample.get(\"question_id\") or sample.get(\"problem_id\") or sample.get(\"id\")\n",
  "            test_code = sample.get(\"test\") or sample.get(\"test_code\") or sample.get(\"tests\")\n",
  "\n",
  "            final_test = None\n",
  "            # Only `def check(candidate): ...` style tests are supported\n",
  "            if problem_id and test_code and \"def check\" in str(test_code):\n",
  "                final_test = build_limited_test(test_code)\n",
  "\n",
  "            if final_test is None:\n",
  "                skipped_count += 1\n",
  "            else:\n",
  "                json.dump({\"problem_id\": problem_id, \"test\": final_test}, f)\n",
  "                f.write(\"\\n\")\n",
  "                extracted_count += 1\n",
  "        except Exception as e:\n",
  "            # Best effort: count the failure, report the first few, keep going\n",
  "            error_count += 1\n",
  "            if error_count <= 5:  # Print first 5 errors\n",
  "                print(f\"Error at index {i}: {e}\")\n",
  "\n",
  "        if (i + 1) % 500 == 0:\n",
  "            print(f\" Processed {i + 1}/{len(data_split)} samples... (extracted: {extracted_count}, skipped: {skipped_count})\")\n",
  "\n",
  "print(f\"\\n✅ Extraction complete!\")\n",
  "print(f\" Extracted: {extracted_count} test cases\")\n",
  "print(f\" Skipped: {skipped_count}\")\n",
  "print(f\" Errors: {error_count}\")\n",
  "print(f\" Output: {output_path}\")"
 ]
},
{
 "cell_type": "code",
 "execution_count": 6,
 "metadata": {},
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "Output file: /teamspace/studios/this_studio/dataset/complexity_reasoning_data/python_tests.jsonl\n",
    " File size: 10,779,921 bytes\n",
    " Line count: 2641\n",
    "\n",
    "First 3 samples:\n",
    "\n",
    " Sample 1:\n",
    " problem_id: 1\n",
    " test: (first 3 lines)\n",
    " 'def check(candidate):'\n",
    " ' assert candidate(nums = [3, 3],target = 6) == [0, 1]'\n",
    " ' assert candidate(nums = [-1, -2, -3, -4],target = -8) == None'\n",
    " ... (28 more lines)\n",
    "\n",
    " Sample 2:\n",
    " problem_id: 2\n",
    " test: (first 3 lines)\n",
    " 'def check(candidate):'\n",
    " ' assert is_same_list(candidate(l1 = list_node([9, 8, 7]),l2 = list_node([1, 2, 3])), list_node([0, 1, 1, 1]))'\n",
    " ' assert is_same_list(candidate(l1 = list_node([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),l2 = list_node([5, 6, 4])), list_node([6, 6, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))'\n",
    " ... (28 more lines)\n",
    "\n",
    " Sample 3:\n",
    " problem_id: 3\n",
    " test: (first 3 lines)\n",
    " 'def check(candidate):'\n",
    " ' assert candidate(s = \"abcabcbb\") == 3'\n",
    " ' assert candidate(s = \"bbbbb\") == 1'\n",
    " ... (28 more lines)\n"
   ]
  }
 ],
 "source": [
  "# Verify output and show samples\n",
  "file_size = output_path.stat().st_size\n",
  "# Count lines inside a `with` block so the file handle is closed promptly\n",
  "with open(output_path) as f:\n",
  "    line_count = sum(1 for _ in f)\n",
  "\n",
  "print(f\"Output file: {output_path}\")\n",
  "print(f\" File size: {file_size:,} bytes\")\n",
  "print(f\" Line count: {line_count}\")\n",
  "print(f\"\\nFirst 3 samples:\")\n",
  "\n",
  "with open(output_path) as f:\n",
  "    for i, line in enumerate(f):\n",
  "        if i >= 3:\n",
  "            break\n",
  "        sample = json.loads(line)\n",
  "        print(f\"\\n Sample {i+1}:\")\n",
  "        print(f\" problem_id: {sample['problem_id']}\")\n",
  "        test_lines = sample['test'].split(\"\\n\")\n",
  "        print(f\" test: (first 3 lines)\")\n",
  "        for tline in test_lines[:3]:\n",
  "            print(f\" {repr(tline)}\")\n",
  "        if len(test_lines) > 3:\n",
  "            print(f\" ... ({len(test_lines) - 3} more lines)\")\n",
  "\n",
  "# Every extracted test case must have produced exactly one JSONL line\n",
  "assert line_count == extracted_count, f\"expected {extracted_count} lines, found {line_count}\""
 ]
}
],
"metadata": {
 "kernelspec": {
  "display_name": "Python 3",
  "language": "python",
  "name": "python3"
 },
 "language_info": {
  "name": "python",
  "version": "3.10.0"
 }
},
"nbformat": 4,
"nbformat_minor": 4
}