colab support part two
Browse files- test_pretrained.ipynb +17 -13
test_pretrained.ipynb
CHANGED
|
@@ -9,7 +9,7 @@
|
|
| 9 |
},
|
| 10 |
{
|
| 11 |
"cell_type": "code",
|
| 12 |
-
"execution_count":
|
| 13 |
"metadata": {},
|
| 14 |
"outputs": [],
|
| 15 |
"source": [
|
|
@@ -19,13 +19,14 @@
|
|
| 19 |
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
|
| 20 |
"import torch\n",
|
| 21 |
"import sys\n",
|
|
|
|
| 22 |
"import sqlite3 as sql\n",
|
| 23 |
"from huggingface_hub import snapshot_download"
|
| 24 |
]
|
| 25 |
},
|
| 26 |
{
|
| 27 |
"cell_type": "code",
|
| 28 |
-
"execution_count":
|
| 29 |
"metadata": {},
|
| 30 |
"outputs": [],
|
| 31 |
"source": [
|
|
@@ -34,22 +35,25 @@
|
|
| 34 |
},
|
| 35 |
{
|
| 36 |
"cell_type": "code",
|
| 37 |
-
"execution_count":
|
| 38 |
"metadata": {},
|
| 39 |
"outputs": [],
|
| 40 |
"source": [
|
|
|
|
|
|
|
| 41 |
"if is_google_colab:\n",
|
| 42 |
" hugging_face_path = snapshot_download(\n",
|
| 43 |
" repo_id=\"USC-Applied-NLP-Group/SQL-Generation\",\n",
|
| 44 |
" repo_type=\"model\", \n",
|
| 45 |
" allow_patterns=[\"src/*\"], \n",
|
| 46 |
" )\n",
|
| 47 |
-
" sys.path.append(hugging_face_path)"
|
|
|
|
| 48 |
]
|
| 49 |
},
|
| 50 |
{
|
| 51 |
"cell_type": "code",
|
| 52 |
-
"execution_count":
|
| 53 |
"metadata": {},
|
| 54 |
"outputs": [],
|
| 55 |
"source": [
|
|
@@ -66,7 +70,7 @@
|
|
| 66 |
},
|
| 67 |
{
|
| 68 |
"cell_type": "code",
|
| 69 |
-
"execution_count":
|
| 70 |
"metadata": {},
|
| 71 |
"outputs": [
|
| 72 |
{
|
|
@@ -76,15 +80,15 @@
|
|
| 76 |
"Total dataset examples: 1044\n",
|
| 77 |
"\n",
|
| 78 |
"\n",
|
| 79 |
-
"
|
| 80 |
-
"SELECT
|
| 81 |
-
"
|
| 82 |
]
|
| 83 |
}
|
| 84 |
],
|
| 85 |
"source": [
|
| 86 |
"# Load dataset and check length\n",
|
| 87 |
-
"df = pd.read_csv(\"
|
| 88 |
"print(\"Total dataset examples: \" + str(len(df)))\n",
|
| 89 |
"print(\"\\n\")\n",
|
| 90 |
"\n",
|
|
@@ -126,7 +130,7 @@
|
|
| 126 |
},
|
| 127 |
{
|
| 128 |
"cell_type": "code",
|
| 129 |
-
"execution_count":
|
| 130 |
"metadata": {},
|
| 131 |
"outputs": [
|
| 132 |
{
|
|
@@ -159,7 +163,7 @@
|
|
| 159 |
},
|
| 160 |
{
|
| 161 |
"cell_type": "code",
|
| 162 |
-
"execution_count":
|
| 163 |
"metadata": {},
|
| 164 |
"outputs": [
|
| 165 |
{
|
|
@@ -202,7 +206,7 @@
|
|
| 202 |
},
|
| 203 |
{
|
| 204 |
"cell_type": "code",
|
| 205 |
-
"execution_count":
|
| 206 |
"metadata": {},
|
| 207 |
"outputs": [
|
| 208 |
{
|
|
|
|
| 9 |
},
|
| 10 |
{
|
| 11 |
"cell_type": "code",
|
| 12 |
+
"execution_count": 31,
|
| 13 |
"metadata": {},
|
| 14 |
"outputs": [],
|
| 15 |
"source": [
|
|
|
|
| 19 |
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
|
| 20 |
"import torch\n",
|
| 21 |
"import sys\n",
|
| 22 |
+
"import os\n",
|
| 23 |
"import sqlite3 as sql\n",
|
| 24 |
"from huggingface_hub import snapshot_download"
|
| 25 |
]
|
| 26 |
},
|
| 27 |
{
|
| 28 |
"cell_type": "code",
|
| 29 |
+
"execution_count": 32,
|
| 30 |
"metadata": {},
|
| 31 |
"outputs": [],
|
| 32 |
"source": [
|
|
|
|
| 35 |
},
|
| 36 |
{
|
| 37 |
"cell_type": "code",
|
| 38 |
+
"execution_count": 33,
|
| 39 |
"metadata": {},
|
| 40 |
"outputs": [],
|
| 41 |
"source": [
|
| 42 |
+
"current_path = \"./\"\n",
|
| 43 |
+
"\n",
|
| 44 |
"if is_google_colab:\n",
|
| 45 |
" hugging_face_path = snapshot_download(\n",
|
| 46 |
" repo_id=\"USC-Applied-NLP-Group/SQL-Generation\",\n",
|
| 47 |
" repo_type=\"model\", \n",
|
| 48 |
" allow_patterns=[\"src/*\"], \n",
|
| 49 |
" )\n",
|
| 50 |
+
" sys.path.append(hugging_face_path)\n",
|
| 51 |
+
" current_path = hugging_face_path"
|
| 52 |
]
|
| 53 |
},
|
| 54 |
{
|
| 55 |
"cell_type": "code",
|
| 56 |
+
"execution_count": 34,
|
| 57 |
"metadata": {},
|
| 58 |
"outputs": [],
|
| 59 |
"source": [
|
|
|
|
| 70 |
},
|
| 71 |
{
|
| 72 |
"cell_type": "code",
|
| 73 |
+
"execution_count": 36,
|
| 74 |
"metadata": {},
|
| 75 |
"outputs": [
|
| 76 |
{
|
|
|
|
| 80 |
"Total dataset examples: 1044\n",
|
| 81 |
"\n",
|
| 82 |
"\n",
|
| 83 |
+
"How many points did the Phoenix Suns score in the highest scoring away game they played?\n",
|
| 84 |
+
"SELECT MAX(pts_away) FROM game WHERE team_abbreviation_away = 'PHX';\n",
|
| 85 |
+
"161.0\n"
|
| 86 |
]
|
| 87 |
}
|
| 88 |
],
|
| 89 |
"source": [
|
| 90 |
"# Load dataset and check length\n",
|
| 91 |
+
"df = pd.read_csv(os.path.join(current_path, \"train-data/sql_train.tsv\"), sep=\"\\t\")\n",
|
| 92 |
"print(\"Total dataset examples: \" + str(len(df)))\n",
|
| 93 |
"print(\"\\n\")\n",
|
| 94 |
"\n",
|
|
|
|
| 130 |
},
|
| 131 |
{
|
| 132 |
"cell_type": "code",
|
| 133 |
+
"execution_count": 28,
|
| 134 |
"metadata": {},
|
| 135 |
"outputs": [
|
| 136 |
{
|
|
|
|
| 163 |
},
|
| 164 |
{
|
| 165 |
"cell_type": "code",
|
| 166 |
+
"execution_count": 29,
|
| 167 |
"metadata": {},
|
| 168 |
"outputs": [
|
| 169 |
{
|
|
|
|
| 206 |
},
|
| 207 |
{
|
| 208 |
"cell_type": "code",
|
| 209 |
+
"execution_count": 12,
|
| 210 |
"metadata": {},
|
| 211 |
"outputs": [
|
| 212 |
{
|