{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "01d11866",
   "metadata": {},
   "source": [
    "# Open NBA and tennis datasets\n",
    "\n",
    "Load both training sets, collapse stray whitespace in every text field, and downsample the NBA set to at most `SAMPLE_SIZE` rows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "155a7ecb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "from pathlib import Path\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "# Configuration\n",
    "DATA_DIR = Path(\"../../training-data\")\n",
    "SAMPLE_SIZE = 600  # maximum number of NBA examples to keep\n",
    "SEED = 42  # fixed seed so the downsample and the shuffle are reproducible\n",
    "\n",
    "\n",
    "def normalize_whitespace(df: pd.DataFrame) -> pd.DataFrame:\n",
    "    \"\"\"Return a copy of df with whitespace runs in string cells collapsed to one space.\n",
    "\n",
    "    Non-string cells (numbers, NaN, ...) are passed through unchanged.\n",
    "    \"\"\"\n",
    "\n",
    "    def collapse(value):\n",
    "        return re.sub(r'\\s+', ' ', value) if isinstance(value, str) else value\n",
    "\n",
    "    # DataFrame.map supersedes the deprecated DataFrame.applymap (pandas >= 2.1);\n",
    "    # fall back to applymap so older pandas versions keep working.\n",
    "    elementwise = df.map if hasattr(pd.DataFrame, \"map\") else df.applymap\n",
    "    return elementwise(collapse)\n",
    "\n",
    "\n",
    "# Open the two datasets\n",
    "nba_raw = pd.read_csv(DATA_DIR / \"nba_train_set.tsv\", sep=\"\\t\")\n",
    "tennis_raw = pd.read_csv(DATA_DIR / \"tennis_train_set.tsv\", sep=\"\\t\")\n",
    "\n",
    "# Fix any spacing issues. The result must be assigned: the element-wise map\n",
    "# returns a new frame and leaves its input untouched.\n",
    "nba_clean = normalize_whitespace(nba_raw)\n",
    "tennis_df = normalize_whitespace(tennis_raw)\n",
    "\n",
    "# Downsample NBA; n is capped at the number of available rows so sample() cannot raise\n",
    "nba_df = nba_clean.sample(n=min(SAMPLE_SIZE, len(nba_clean)), random_state=SEED)\n",
    "\n",
    "# Display dataset info\n",
    "print(f\"Total NBA dataset examples: {len(nba_df)}\")\n",
    "display(nba_df.head(1))\n",
    "print(f\"Total Tennis dataset examples: {len(tennis_df)}\")\n",
    "display(tennis_df.head(1))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb357705",
   "metadata": {},
   "source": [
    "# Combine into one TSV with an extra column indicating which set each example belongs to"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3acd217",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tag each example with its source via an \"is_nba\" indicator column.\n",
    "# assign() returns new frames, so nba_df / tennis_df are not mutated and this cell is safe to re-run.\n",
    "combined_df = pd.concat(\n",
    "    [nba_df.assign(is_nba=True), tennis_df.assign(is_nba=False)],\n",
    "    ignore_index=True,\n",
    ")\n",
    "\n",
    "# Shuffle with a fixed seed so the saved file is reproducible\n",
    "combined_df = combined_df.sample(frac=1, random_state=SEED).reset_index(drop=True)\n",
    "\n",
    "# Sanity checks: no rows lost, and the indicator matches the NBA row count\n",
    "assert len(combined_df) == len(nba_df) + len(tennis_df)\n",
    "assert combined_df[\"is_nba\"].sum() == len(nba_df)\n",
    "\n",
    "# Save to combined TSV\n",
    "combined_df.to_csv(DATA_DIR / \"combined_dataset.tsv\", sep=\"\\t\", index=False)\n",
    "print(\"Saved combined dataset with\", len(combined_df), \"rows\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}