File size: 6,511 Bytes
825751a
 
 
 
 
 
 
 
 
 
 
 
e92176b
825751a
 
 
 
 
 
 
e92176b
 
 
 
 
 
 
825751a
e92176b
 
 
825751a
e92176b
 
825751a
 
e92176b
 
 
 
 
 
825751a
 
 
 
 
 
e92176b
825751a
e92176b
 
 
825751a
 
 
 
 
 
 
 
e92176b
825751a
 
 
e92176b
 
 
 
 
 
 
825751a
 
 
 
 
e92176b
 
 
 
 
825751a
 
 
e92176b
 
 
825751a
 
 
 
 
 
e92176b
 
 
 
 
825751a
 
 
 
 
 
 
 
 
 
 
 
e92176b
825751a
 
 
 
 
 
 
e92176b
825751a
 
 
 
 
 
 
 
 
 
 
 
 
 
e92176b
825751a
 
e92176b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
825751a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "01d11866",
   "metadata": {},
   "source": [
    "# Open NBA and tennis datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "155a7ecb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total NBA dataset examples: 500\n",
      "                                          natural_query  \\\n",
      "2096  How many times have the Memphis Grizzlies won ...   \n",
      "\n",
      "                                              sql_query result  \n",
      "2096  SELECT COUNT(*) FROM game WHERE (team_abbrevia...     31  \n",
      "\n",
      "\n",
      "Total Tennis dataset examples: 514\n",
      "                       natural_query  \\\n",
      "1  How many players are left-handed?   \n",
      "\n",
      "                                        sql_query result  \n",
      "1  SELECT COUNT(*) FROM players WHERE hand = 'L';   1435  \n",
      "\n",
      "\n",
      "Total Tennis test examples: 100\n",
      "                                         natural_query  \\\n",
      "144  What is the average ranking of players defeate...   \n",
      "\n",
      "                                             sql_query            result  \n",
      "144  SELECT AVG(r.rank) FROM matches m JOIN ranking...  212.317855446654  \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Dean\\AppData\\Local\\Temp\\ipykernel_22452\\2246720866.py:17: FutureWarning: DataFrame.applymap has been deprecated. Use DataFrame.map instead.\n",
      "  nba_df.applymap(lambda x: re.sub(r'\\s+', ' ', x) if isinstance(x, str) else x)\n",
      "C:\\Users\\Dean\\AppData\\Local\\Temp\\ipykernel_22452\\2246720866.py:18: FutureWarning: DataFrame.applymap has been deprecated. Use DataFrame.map instead.\n",
      "  tennis_df.applymap(lambda x: re.sub(r'\\s+', ' ', x) if isinstance(x, str) else x)\n",
      "C:\\Users\\Dean\\AppData\\Local\\Temp\\ipykernel_22452\\2246720866.py:29: FutureWarning: DataFrame.applymap has been deprecated. Use DataFrame.map instead.\n",
      "  tennis_df.applymap(lambda x: re.sub(r'\\s+', ' ', x) if isinstance(x, str) else x)\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import re\n",
    "\n",
    "SAMPLE_SIZE = 500  # NBA training examples kept after downsampling\n",
    "TEST_SIZE = 100  # tennis examples held out for testing\n",
    "SEED = 42  # fixed so the sampling below gives the same split on every run\n",
    "\n",
    "\n",
    "def collapse_whitespace(df):\n",
    "    \"\"\"Return a copy of df with every whitespace run inside string cells replaced by one space.\n",
    "\n",
    "    Non-string cells (numbers, NaN, ...) are passed through unchanged.\n",
    "    \"\"\"\n",
    "    # Per-column Series.map replaces the deprecated DataFrame.applymap\n",
    "    return df.apply(\n",
    "        lambda col: col.map(lambda x: re.sub(r'\\s+', ' ', x) if isinstance(x, str) else x)\n",
    "    )\n",
    "\n",
    "\n",
    "# Open the datasets\n",
    "nba_raw_df = pd.read_csv(\"../../training-data/nba_train_set.tsv\", sep='\\t')\n",
    "dean_df = pd.read_csv(\"../../training-data/tennis_train_set_dean.tsv\", sep='\\t')\n",
    "connor_df = pd.read_csv(\"../../training-data/tennis_train_set_connor.tsv\", sep='\\t')\n",
    "mehul_df = pd.read_csv(\"../../training-data/tennis_train_set_mehul.tsv\", sep='\\t').drop(columns='tennis')\n",
    "\n",
    "# Fix any spacing issues. The cleaned frames must be assigned: the cleaning is not in place.\n",
    "nba_clean_df = collapse_whitespace(nba_raw_df)\n",
    "tennis_pool_df = collapse_whitespace(pd.concat([dean_df, mehul_df], ignore_index=True))\n",
    "connor_clean_df = collapse_whitespace(connor_df)\n",
    "\n",
    "# Separate testing data for tennis; only Dean's and Mehul's rows are eligible\n",
    "test_tennis_df = tennis_pool_df.sample(n=TEST_SIZE, random_state=SEED)\n",
    "\n",
    "# Training tennis data = the remaining pool rows plus all of Connor's data.\n",
    "# The test rows are dropped from the same frame they were sampled from, so the\n",
    "# index labels are guaranteed to refer to the sampled rows.\n",
    "tennis_df = pd.concat(\n",
    "    [tennis_pool_df.drop(test_tennis_df.index), connor_clean_df],\n",
    "    ignore_index=True,\n",
    ")\n",
    "assert len(tennis_df) + len(test_tennis_df) == len(tennis_pool_df) + len(connor_clean_df)\n",
    "\n",
    "# Downsample NBA\n",
    "nba_df = nba_clean_df.sample(n=SAMPLE_SIZE, random_state=SEED)\n",
    "\n",
    "# Display dataset info\n",
    "print(f\"Total NBA dataset examples: {len(nba_df)}\")\n",
    "print(nba_df.head(1))\n",
    "print()\n",
    "print()\n",
    "print(f\"Total Tennis dataset examples: {len(tennis_df)}\")\n",
    "print(tennis_df.head(1))\n",
    "print()\n",
    "print()\n",
    "print(f\"Total Tennis test examples: {len(test_tennis_df)}\")\n",
    "print(test_tennis_df.head(1))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb357705",
   "metadata": {},
   "source": [
    "# Combine into one TSV with an extra column indicating which set each example belongs to"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b3acd217",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved combined dataset with 1014 rows\n"
     ]
    }
   ],
   "source": [
    "# Fixed seed so the shuffled row order of the saved file is the same on every run\n",
    "TRAIN_SHUFFLE_SEED = 42\n",
    "\n",
    "# Add \"is_nba\" indicator column. assign() returns new frames, so the frames built\n",
    "# in the previous cell are not modified and this cell can be re-run safely.\n",
    "nba_labeled_df = nba_df.assign(is_nba=True)\n",
    "tennis_labeled_df = tennis_df.assign(is_nba=False)\n",
    "\n",
    "# Combine into single dataframe, then shuffle\n",
    "combined_df = (\n",
    "    pd.concat([nba_labeled_df, tennis_labeled_df], ignore_index=True)\n",
    "    .sample(frac=1, random_state=TRAIN_SHUFFLE_SEED)\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "assert len(combined_df) == len(nba_df) + len(tennis_df)\n",
    "\n",
    "# Save to combined TSV\n",
    "combined_df.to_csv(\"../../training-data/combined_full_dataset.tsv\", sep=\"\\t\", index=False)\n",
    "print(\"Saved combined dataset with\", len(combined_df), \"rows\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ce62029",
   "metadata": {},
   "source": [
    "# Combine tennis test data with the NBA test TSV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "72a934e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved combined test dataset with 250 rows\n"
     ]
    }
   ],
   "source": [
    "# Fixed seed so the shuffled row order of the saved file is the same on every run\n",
    "TEST_SHUFFLE_SEED = 42\n",
    "\n",
    "nba_test_df = pd.read_csv(\"../../training-data/nba_test_set.tsv\", sep='\\t')\n",
    "\n",
    "# Label each source with the \"is_nba\" indicator. assign() returns new frames, so\n",
    "# test_tennis_df from the first cell is not modified and this cell can be re-run safely.\n",
    "nba_test_labeled_df = nba_test_df.assign(is_nba=True)\n",
    "tennis_test_labeled_df = test_tennis_df.assign(is_nba=False)\n",
    "\n",
    "# Combine into single dataframe, then shuffle\n",
    "combined_test_df = (\n",
    "    pd.concat([nba_test_labeled_df, tennis_test_labeled_df], ignore_index=True)\n",
    "    .sample(frac=1, random_state=TEST_SHUFFLE_SEED)\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "assert len(combined_test_df) == len(nba_test_df) + len(test_tennis_df)\n",
    "\n",
    "combined_test_df.to_csv(\"../../training-data/test_set.tsv\", sep='\\t', index=False)\n",
    "print(\"Saved combined test dataset with\", len(combined_test_df), \"rows\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}