{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Fine-tune DeepSeek Coder 1.3B for SQLite Generation on the NBA + Tennis Kaggle Databases" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\tf_keras\\src\\losses.py:2976: The name tf.losses.sparse_softmax_cross_entropy is deprecated. Please use tf.compat.v1.losses.sparse_softmax_cross_entropy instead.\n", "\n" ] } ], "source": [ "import contextlib  # helps make pip silent: its output is redirected away from the notebook below\n", "import sys\n", "\n", "# Install BEFORE importing: `datasets` is imported further down, so an install placed after that import\n", "# could never rescue a fresh environment (the import would already have raised ModuleNotFoundError).\n", "with contextlib.redirect_stdout(sys.__stdout__), contextlib.redirect_stderr(sys.__stderr__):\n", "    %pip install datasets\n", "\n", "import os\n", "import re\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", "from datasets import Dataset\n", "from huggingface_hub import snapshot_download\n", "from peft import LoraConfig, get_peft_model, TaskType\n", "from torch.utils.data import DataLoader\n", "from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, BitsAndBytesConfig, EarlyStoppingCallback, PreTrainedTokenizer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Define constants for Google Colab vs. local runs" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "is_google_colab = False  # True: mount Google Drive and download train data/model files from the Hugging Face Hub (next cell)\n", "use_bnb = True  # presumably enables bitsandbytes quantization (BitsAndBytesConfig) when loading the model - TODO confirm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Establish read and write paths" ] }, { "cell_type": "code", "execution_count": 
3, "metadata": {}, "outputs": [], "source": [ "current_read_path = \"./\"\n", "current_write_path = \"./\"\n", "\n", "def read_path(rel_path):\n", " return os.path.join(current_read_path, rel_path)\n", "\n", "def write_path(rel_path):\n", " return os.path.join(current_write_path, rel_path)\n", "\n", "if is_google_colab:\n", " from google.colab import drive\n", " drive.mount('/content/drive')\n", " current_write_path = \"/content/drive/MyDrive/sql_gen\"\n", "\n", " hugging_face_path = snapshot_download(\n", " repo_id=\"USC-Applied-NLP-Group/SQL-Generation\",\n", " repo_type=\"model\",\n", " allow_patterns=[\"train-data/*\", \"deepseek-coder-1.3b-instruct/*\"], \n", " )\n", " sys.path.append(hugging_face_path)\n", " current_read_path = hugging_face_path" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## First define prompt" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "9035\n", "7990\n" ] } ], "source": [ "from utils.prompts.nba_prompt import input_text as nba_prompt\n", "from utils.prompts.tennis_prompt import input_text as tennis_prompt\n", "\n", "print(len(nba_prompt))\n", "print(len(tennis_prompt))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load data and convert to Dataset object tokenized by the DeepSeek model" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total train dataset examples: 1014\n", " natural_query \\\n", "0 How many matches were played at Wimbledon in 2... \n", "1 How many US Open matches has Novak Djokovic pa... \n", "2 List all matches won by Cameron Norrie. \n", "3 What is the highest number of personal fouls c... \n", "4 What is the average points scored by the San A... \n", "\n", " sql_query result is_nba \n", "0 SELECT COUNT(*) FROM matches WHERE tourney_n... 239 False \n", "1 SELECT COUNT(*) FROM matches WHERE tourney_nam... 
84 False \n", "2 SELECT tourney_name FROM matches WHERE winner_... NaN False \n", "3 SELECT MAX(pf_away) as max_pf FROM game WHERE ... 41 True \n", "4 SELECT AVG(pts_away) as avg_points FROM game ... 102.35 True \n", "Total test dataset examples: 250\n", " natural_query \\\n", "0 How many spanish (ESP) players are there? \n", "1 How many distinct players appear in the rankin... \n", "2 How many times did the Los Angeles Clippers lo... \n", "3 How many times have the Boston Celtics won an ... \n", "4 Show the most successful player by win count \n", "\n", " sql_query result \\\n", "0 SELECT COUNT(*) AS spanish_players FROM player... 3026 \n", "1 SELECT COUNT(DISTINCT player) AS distinct_play... 16174 \n", "2 SELECT COUNT(*) FROM game g WHERE g.team_abbre... 4 \n", "3 SELECT COUNT(*) FROM game WHERE team_abbre... 179 \n", "4 SELECT winner_name, COUNT(*) as total_wins FRO... Roger Federer|1305 \n", "\n", " is_nba \n", "0 False \n", "1 False \n", "2 True \n", "3 True \n", "4 False \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\Dean\\AppData\\Local\\Temp\\ipykernel_13948\\3424569434.py:10: FutureWarning: DataFrame.applymap has been deprecated. Use DataFrame.map instead.\n", " df_train.applymap(lambda x: re.sub(r'\\s+', ' ', x) if isinstance(x, str) else x)\n", "C:\\Users\\Dean\\AppData\\Local\\Temp\\ipykernel_13948\\3424569434.py:11: FutureWarning: DataFrame.applymap has been deprecated. 
Use DataFrame.map instead.\n", " df_test.applymap(lambda x: re.sub(r'\\s+', ' ', x) if isinstance(x, str) else x)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "adding special token\n", "32022\n", "32023\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Map: 100%|██████████| 1014/1014 [00:19<00:00, 51.05 examples/s]\n", "Map: 100%|██████████| 250/250 [00:04<00:00, 51.36 examples/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'is_nba': False, 'input_ids': [32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 
32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 
32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32013, 2042, 417, 274, 20926, 20391, 344, 6145, 1267, 3881, 4694, 12780, 878, 4629, 5975, 547, 12780, 13, 185, 14108, 11593, 732, 285, 2066, 11767, 715, 185, 185, 3101, 3238, 6922, 185, 50, 577, 379, 1748, 782, 519, 17382, 262, 12050, 13, 185, 13403, 11866, 15787, 5787, 7449, 30862, 440, 3101, 3238, 1, 334, 185, 207, 440, 83, 7466, 62, 304, 1, 323, 13532, 11, 436, 1585, 1198, 2710, 21868, 21411, 334, 2600, 1017, 8, 185, 207, 440, 83, 7466, 62, 1523, 1, 323, 13532, 11, 730, 1585, 9715, 280, 254, 21868, 334, 68, 13, 70, 1787, 440, 54, 308, 11259, 249, 2456, 185, 207, 440, 25030, 1, 323, 13532, 11, 3137, 1585, 8592, 5426, 8507, 25110, 950, 440, 1982, 333, 950, 440, 8508, 468, 2456, 185, 207, 440, 4833, 62, 3017, 1, 323, 13532, 11, 294, 1585, 11988, 280, 6594, 279, 254, 21868, 4293, 185, 207, 440, 83, 7466, 62, 6217, 1, 323, 13532, 11, 251, 1585, 323, 17460, 8771, 334, 68, 13, 70, 1787, 440, 38, 950, 440, 44, 950, 440, 32, 2456, 185, 207, 440, 83, 7466, 62, 1984, 1, 5878, 1743, 11, 730, 1585, 8830, 4278, 280, 254, 21868, 334, 19393, 19393, 8213, 7127, 372, 1582, 8, 185, 207, 440, 10108, 62, 6487, 1, 323, 13532, 11, 294, 1585, 25395, 4168, 3750, 2372, 245, 21868, 185, 207, 440, 86, 4799, 62, 304, 1, 5878, 1743, 11, 294, 1585, 1198, 2710, 4982, 280, 254, 24556, 4168, 16813, 185, 207, 440, 86, 4799, 62, 25888, 1, 323, 13532, 11, 1032, 1585, 1972, 271, 1594, 280, 254, 4168, 16813, 334, 351, 683, 8, 185, 207, 440, 86, 4799, 62, 8470, 1, 323, 13532, 11, 730, 1585, 7481, 280, 6048, 8507, 54, 34, 950, 440, 48, 950, 440, 2360, 950, 3371, 3087, 185, 207, 440, 86, 4799, 62, 1523, 1, 323, 13532, 11, 1032, 1585, 9715, 280, 254, 4168, 16813, 185, 
207, 440, 86, 4799, 62, 4560, 1, 323, 13532, 11, 1032, 1585, 422, 4799, 6, 82, 6066, 1295, 8507, 49, 1, 409, 440, 43, 2456, 185, 207, 440, 86, 4799, 62, 383, 1, 5878, 1743, 11, 294, 1585, 422, 4799, 6, 82, 5471, 279, 1783, 23990, 407, 185, 207, 440, 86, 4799, 62, 72, 404, 1, 323, 13532, 11, 436, 1585, 422, 4799, 6, 82, 3073, 2974, 334, 18680, 1616, 30619, 12409, 4797, 8, 185, 207, 440, 86, 4799, 62, 490, 1, 5878, 1743, 11, 436, 1585, 422, 4799, 6, 82, 4489, 429, 761, 280, 4168, 185, 207, 440, 9222, 250, 62, 304, 1, 5878, 1743, 11, 1574, 1585, 1198, 2710, 4982, 280, 254, 24556, 4168, 3298, 250, 185, 207, 440, 9222, 250, 62, 25888, 1, 323, 13532, 11, 436, 1585, 1972, 271, 1594, 280, 254, 4168, 3298, 250, 185, 207, 440, 9222, 250, 62, 8470, 1, 323, 13532, 11, 1032, 1585, 7481, 280, 6048, 327, 254, 3298, 250, 185, 207, 440, 9222, 250, 62, 1523, 1, 323, 13532, 11, 436, 1585, 10851, 250, 6, 82, 2192, 1208, 185, 207, 440, 9222, 250, 62, 4560, 1, 323, 13532, 11, 436, 1585, 10851, 250, 6, 82, 6066, 1295, 185, 207, 440, 9222, 250, 62, 383, 1, 5878, 1743, 11, 1574, 1585, 10851, 250, 6, 82, 5471, 185, 207, 440, 9222, 250, 62, 72, 404, 1, 323, 13532, 11, 294, 1585, 10851, 250, 6, 82, 3073, 2974, 185, 207, 440, 9222, 250, 62, 490, 1, 5878, 1743, 11, 294, 1585, 10851, 250, 6, 82, 4489, 185, 207, 440, 20709, 1, 323, 13532, 11, 2481, 1585, 16131, 4168, 8129, 279, 24450, 16464, 185, 207, 440, 15041, 62, 990, 1, 323, 13532, 11, 3137, 1585, 11988, 280, 6229, 254, 4168, 317, 7226, 276, 334, 27804, 207, 18, 409, 207, 20, 8, 185, 207, 440, 1033, 1, 323, 13532, 11, 2481, 1585, 323, 17460, 4069, 8507, 49, 21, 19, 950, 440, 48, 37, 950, 440, 31708, 950, 440, 37, 2456, 185, 207, 440, 1513, 3263, 1, 5878, 1743, 11, 3137, 1585, 23772, 13672, 279, 4054, 185, 207, 440, 86, 62, 582, 1, 5878, 1743, 11, 2481, 1585, 338, 1516, 5901, 457, 254, 16813, 185, 207, 440, 86, 62, 3934, 1, 5878, 1743, 11, 3462, 1585, 18010, 10456, 82, 1396, 457, 254, 16813, 185, 207, 440, 86, 62, 10477, 462, 1, 5878, 1743, 
11, 655, 1585, 19090, 5029, 3472, 7226, 457, 254, 16813, 185, 207, 440, 86, 62, 16, 292, 769, 1, 5878, 1743, 11, 3137, 1585, 5899, 14716, 1396, 457, 254, 16813, 185, 207, 440, 86, 62, 16, 292, 54, 249, 1, 5878, 1743, 11, 1574, 1585, 5899, 12, 8994, 3472, 2103, 457, 254, 16813, 185, 207, 440, 86, 62, 17, 425, 54, 249, 1, 5878, 1743, 11, 1574, 1585, 11419, 12, 8994, 3472, 2103, 457, 254, 16813, 185, 207, 440, 86, 62, 50, 85, 38, 1400, 1, 5878, 1743, 11, 3137, 1585, 30905, 4951, 7226, 457, 254, 16813, 185, 207, 440, 86, 62, 17606, 50, 12614, 1, 5878, 1743, 11, 294, 1585, 23067, 3472, 9637, 457, 254, 16813, 185, 207, 440, 86, 62, 17606, 37, 3250, 1, 5878, 1743, 11, 294, 1585, 23067, 3472, 17879, 457, 254, 16813, 185, 207, 440, 75, 62, 582, 1, 5878, 1743, 11, 2481, 1585, 338, 1516, 457, 254, 3298, 250, 185, 207, 440, 75, 62, 3934, 1, 5878, 1743, 11, 3462, 1585, 18010, 10456, 82, 457, 254, 3298, 250, 185, 207, 440, 75, 62, 10477, 462, 1, 5878, 1743, 11, 655, 1585, 30905, 3472, 7226, 457, 254, 3298, 250, 185, 207, 440, 75, 62, 16, 292, 769, 1, 5878, 1743, 11, 3137, 1585, 5899, 14716, 1396, 457, 254, 3298, 250, 185, 207, 440, 75, 62, 16, 292, 54, 249, 1, 5878, 1743, 11, 1574, 1585, 5899, 12, 8994, 3472, 2103, 457, 254, 3298, 250, 185, 207, 440, 75, 62, 17, 425, 54, 249, 1, 5878, 1743, 11, 1574, 1585, 11419, 12, 8994, 3472, 2103, 457, 254, 3298, 250, 185, 207, 440, 75, 62, 50, 85, 38, 1400, 1, 5878, 1743, 11, 3137, 1585, 30905, 4951, 7226, 457, 254, 3298, 250, 185, 207, 440, 75, 62, 17606, 50, 12614, 1, 5878, 1743, 11, 294, 1585, 23067, 3472, 9637, 457, 254, 3298, 250, 185, 207, 440, 75, 62, 17606, 37, 3250, 1, 5878, 1743, 11, 294, 1585, 23067, 3472, 17879, 457, 254, 3298, 250, 185, 207, 440, 86, 4799, 62, 17712, 1, 5878, 1743, 11, 1032, 1585, 422, 4799, 6, 82, 338, 7118, 24958, 429, 254, 761, 280, 4168, 185, 207, 440, 86, 4799, 62, 17712, 62, 12168, 1, 5878, 1743, 11, 315, 1585, 422, 4799, 6, 82, 338, 7118, 24958, 3472, 185, 207, 440, 9222, 250, 62, 17712, 1, 5878, 1743, 
11, 436, 1585, 10851, 250, 6, 82, 338, 7118, 24958, 185, 207, 440, 9222, 250, 62, 17712, 62, 12168, 1, 5878, 1743, 11, 1585, 10851, 250, 6, 82, 338, 7118, 3472, 185, 207, 440, 86, 4799, 16, 62, 304, 1, 5878, 1743, 11, 436, 1585, 22536, 4982, 82, 327, 254, 12697, 4678, 7666, 2547, 185, 207, 440, 86, 4799, 17, 62, 304, 1, 323, 13532, 11, 436, 1585, 22536, 4982, 82, 327, 254, 12697, 4678, 7666, 2547, 185, 207, 440, 9222, 250, 16, 62, 304, 1, 323, 13532, 11, 294, 1585, 22536, 4982, 82, 327, 254, 13957, 4678, 7666, 2547, 185, 207, 440, 9222, 250, 17, 62, 304, 1, 5878, 1743, 11, 294, 1585, 22536, 4982, 82, 327, 254, 13957, 4678, 7666, 2547, 185, 207, 440, 86, 4799, 16, 62, 1523, 1, 323, 13532, 11, 730, 1585, 9715, 280, 12697, 4678, 7666, 6706, 1494, 16, 185, 207, 440, 86, 4799, 16, 62, 4560, 1, 5878, 1743, 11, 730, 1585, 7836, 272, 1295, 327, 12697, 4678, 7666, 6594, 1494, 16, 185, 207, 440, 86, 4799, 16, 62, 383, 1, 323, 13532, 11, 436, 1585, 1061, 447, 280, 12697, 4678, 7666, 6706, 1494, 16, 185, 207, 440, 86, 4799, 16, 62, 72, 404, 1, 5878, 1743, 11, 1032, 1585, 5549, 465, 280, 12697, 4678, 7666, 6706, 1494, 16, 185, 207, 440, 86, 4799, 16, 62, 490, 1, 323, 13532, 11, 1032, 1585, 20252, 280, 12697, 4678, 7666, 6706, 1494, 16, 185, 207, 440, 86, 4799, 17, 62, 1523, 1, 323, 13532, 11, 730, 1585, 9715, 280, 12697, 4678, 7666, 6706, 1494, 17, 185, 207, 440, 86, 4799, 17, 62, 4560, 1, 5878, 1743, 11, 730, 1585, 7836, 272, 1295, 327, 12697, 4678, 7666, 6706, 1494, 17, 185, 207, 440, 86, 4799, 17, 62, 383, 1, 323, 13532, 11, 436, 1585, 1061, 447, 280, 12697, 4678, 7666, 6706, 1494, 17, 185, 207, 440, 86, 4799, 17, 62, 72, 404, 1, 5878, 1743, 11, 1032, 1585, 5549, 465, 280, 12697, 4678, 7666, 6706, 1494, 17, 185, 207, 440, 86, 4799, 17, 62, 490, 1, 323, 13532, 11, 1032, 1585, 20252, 280, 12697, 4678, 7666, 6706, 1494, 17, 185, 207, 440, 9222, 250, 16, 62, 1523, 1, 323, 13532, 11, 1032, 1585, 9715, 280, 13957, 4678, 7666, 6706, 1494, 16, 185, 207, 440, 9222, 250, 16, 62, 4560, 
1, 5878, 1743, 11, 1032, 1585, 7836, 272, 1295, 327, 13957, 4678, 7666, 6594, 1494, 16, 185, 207, 440, 9222, 250, 16, 62, 383, 1, 323, 13532, 11, 294, 1585, 1061, 447, 280, 13957, 4678, 7666, 6706, 1494, 16, 185, 207, 440, 9222, 250, 16, 62, 72, 404, 1, 5878, 1743, 11, 436, 1585, 5549, 465, 280, 13957, 4678, 7666, 6706, 1494, 16, 185, 207, 440, 9222, 250, 16, 62, 490, 1, 323, 13532, 11, 436, 1585, 20252, 280, 207, 13957, 4678, 7666, 6706, 1494, 16, 185, 207, 440, 9222, 250, 17, 62, 1523, 1, 323, 13532, 11, 1032, 1585, 9715, 280, 13957, 4678, 7666, 6706, 1494, 17, 185, 207, 440, 9222, 250, 17, 62, 4560, 1, 5878, 1743, 11, 1032, 1585, 7836, 272, 1295, 327, 13957, 4678, 7666, 6706, 1494, 17, 185, 207, 440, 9222, 250, 17, 62, 383, 1, 323, 13532, 11, 294, 1585, 1061, 447, 280, 13957, 4678, 7666, 6706, 1494, 17, 185, 207, 440, 9222, 250, 17, 62, 72, 404, 1, 5878, 1743, 11, 436, 1585, 5549, 465, 280, 13957, 4678, 7666, 6706, 1494, 17, 185, 207, 440, 9222, 250, 17, 62, 490, 1, 5878, 1743, 11, 436, 1585, 20252, 280, 13957, 4678, 7666, 6706, 1494, 17, 185, 207, 440, 86, 4799, 16, 62, 17712, 1, 5878, 1743, 11, 1574, 1585, 14790, 7666, 1494, 16, 16813, 24958, 185, 207, 440, 86, 4799, 16, 62, 17712, 62, 12168, 1, 5878, 1743, 11, 243, 1585, 14790, 7666, 1494, 16, 16813, 7053, 3472, 185, 207, 440, 86, 4799, 17, 62, 17712, 1, 5878, 1743, 11, 1574, 1585, 14790, 7666, 1494, 17, 16813, 24958, 185, 207, 440, 86, 4799, 17, 62, 17712, 62, 12168, 1, 5878, 1743, 11, 243, 1585, 14790, 7666, 1494, 17, 16813, 7053, 3472, 185, 207, 440, 9222, 250, 16, 62, 17712, 1, 5878, 1743, 11, 3137, 1585, 14790, 7666, 1494, 16, 3298, 250, 24958, 185, 207, 440, 9222, 250, 16, 62, 17712, 62, 12168, 1, 5878, 1743, 11, 315, 1585, 14790, 7666, 1494, 16, 3298, 250, 7053, 3472, 185, 207, 440, 9222, 250, 17, 62, 17712, 1, 5878, 1743, 11, 3137, 1585, 14790, 7666, 1494, 17, 3298, 250, 24958, 185, 207, 440, 9222, 250, 17, 62, 17712, 62, 12168, 1, 5878, 1743, 251, 1585, 14790, 7666, 1494, 17, 3298, 250, 7053, 3472, 
185, 477, 185, 185, 2035, 407, 6922, 185, 50, 577, 379, 1748, 782, 519, 24450, 6594, 13, 185, 13403, 11866, 15787, 5787, 7449, 30862, 440, 2035, 407, 1, 334, 185, 207, 440, 15276, 62, 304, 1, 3379, 4463, 18924, 11, 207, 1585, 1198, 2710, 6706, 21411, 334, 14775, 2119, 8, 185, 207, 440, 4560, 1, 323, 13532, 11, 1574, 1585, 7836, 272, 1295, 8507, 49, 950, 440, 43, 2456, 185, 207, 440, 67, 656, 1, 5878, 1743, 11, 3137, 1585, 9312, 280, 7394, 334, 19393, 19393, 8213, 7127, 8, 185, 207, 440, 72, 404, 1, 323, 13532, 11, 3137, 1585, 17430, 2974, 334, 609, 13, 11156, 327, 4783, 5098, 280, 6092, 8, 185, 207, 440, 6107, 1, 5878, 1743, 11, 436, 1585, 1061, 447, 279, 1783, 23990, 407, 185, 207, 1208, 323, 13532, 2481, 1585, 22536, 6, 82, 2192, 1208, 185, 477, 185, 185, 17712, 787, 6922, 185, 29133, 1748, 782, 6706, 7053, 787, 851, 1442, 13567, 185, 13403, 11866, 15787, 5787, 7449, 30862, 440, 17712, 787, 1, 334, 185, 207, 440, 17712, 272, 62, 1984, 1, 3379, 4463, 18924, 11, 243, 1585, 9312, 280, 254, 24958, 23561, 334, 19393, 19393, 8213, 7127, 8, 185, 207, 440, 17712, 1, 3379, 4463, 18924, 11, 3137, 1585, 22536, 6, 82, 1835, 24958, 331, 344, 4278, 185, 207, 440, 15276, 1, 3379, 4463, 18924, 11, 294, 1585, 4982, 280, 254, 6706, 334, 1251, 617, 2119, 276, 6594, 13, 15276, 62, 304, 8, 185, 207, 440, 12168, 1, 5878, 1743, 2481, 1585, 31175, 272, 3472, 331, 344, 4278, 207, 185, 477, 185, 185, 13956, 21539, 30975, 185, 16, 13, 7310, 885, 254, 7214, 285, 10115, 4212, 2321, 13, 185, 17, 13, 7310, 3493, 3812, 4761, 13, 185, 18, 13, 9320, 6419, 7214, 750, 4362, 334, 68, 13, 70, 1787, 6594, 207, 8797, 229, 7053, 787, 207, 8797, 229, 12050, 3752, 6706, 62, 304, 409, 16813, 62, 304, 14, 9222, 250, 62, 304, 628, 185, 19, 13, 1271, 245, 2503, 317, 21707, 621, 11, 4340, 254, 1093, 11050, 15864, 13, 185, 20, 13, 3119, 441, 11510, 1238, 17575, 10115, 409, 2365, 4761, 13, 207, 185, 185, 9138, 16813, 62, 1523, 409, 3298, 250, 62, 1523, 276, 2893, 2461, 327, 245, 2017, 6706, 473, 254, 12050, 
2365, 13, 7310, 6975, 3281, 62, 1523, 276, 2893, 1748, 473, 3041, 244, 4307, 7270, 13, 207, 185, 185, 1889, 6226, 457, 13567, 11, 931, 254, 4278, 4797, 765, 19393, 56, 8213, 7127, 185, 185, 15013, 10481, 10413, 6074, 285, 5975, 547, 3130, 7486, 185, 4397, 25, 185, 2808, 1311, 6594, 417, 2104, 12, 29125, 30, 185, 6231, 547, 25, 185, 7507, 31970, 7, 10230, 7432, 6594, 11294, 1295, 405, 651, 43, 4057, 185, 185, 4397, 25, 185, 2808, 1311, 12050, 429, 254, 2604, 6304, 30789, 686, 849, 207, 17, 19, 15, 4054, 30, 185, 6231, 547, 25, 185, 7507, 31970, 7, 10230, 7432, 12050, 11294, 6975, 3281, 62, 1523, 405, 651, 3648, 6304, 6, 5584, 4054, 6213, 17, 19, 15, 26, 185, 185, 4397, 25, 185, 2628, 317, 254, 2567, 370, 4168, 7226, 457, 11784, 556, 413, 73, 541, 872, 278, 30, 185, 6231, 547, 25, 185, 7507, 31329, 7, 1513, 3263, 8, 7432, 12050, 11294, 16813, 62, 1523, 405, 651, 24682, 556, 413, 73, 541, 872, 278, 6, 6982, 3298, 250, 62, 1523, 405, 651, 24682, 556, 413, 73, 541, 872, 278, 4057, 185, 185, 4397, 25, 185, 2628, 317, 254, 6054, 1594, 280, 24958, 3472, 4578, 457, 683, 6706, 331, 683, 4278, 30, 185, 6231, 547, 25, 185, 7507, 21234, 7, 12168, 8, 4958, 3034, 62, 12168, 7432, 7053, 787, 7840, 185, 185, 4397, 25, 185, 2808, 1311, 12050, 638, 21623, 23115, 12184, 4726, 276, 24857, 26529, 30, 185, 6231, 547, 25, 185, 7507, 31970, 7, 10230, 7432, 12050, 11294, 3298, 250, 62, 1523, 405, 651, 49, 493, 250, 23115, 12184, 6, 5584, 16813, 62, 1523, 405, 651, 2270, 88, 26529, 4057, 185, 185, 4397, 25, 185, 2808, 1311, 6594, 773, 7730, 1321, 207, 16, 24, 23, 15, 30, 185, 6231, 547, 25, 185, 7507, 31970, 7, 10230, 7432, 6594, 11294, 533, 65, 8086, 16, 24, 23, 15, 15, 16, 15, 16, 26, 185, 185, 7605, 387, 885, 254, 5975, 547, 5151, 3651, 3250, 457, 5975, 547, 25, 285, 637, 746, 2422, 11, 533, 441, 2816, 274, 11543, 280, 254, 5151, 13, 4195, 8297, 274, 5975, 547, 5151, 327, 254, 1884, 2664, 3092, 13, 17858, 25, 185, 2808, 1311, 731, 11506, 334, 2718, 47, 8, 6594, 417, 741, 30, 185, 6231, 
547, 25, 185, 7507, 31970, 7, 10230, 4958, 731, 11506, 62, 2035, 407, 7432, 6594, 11294, 460, 404, 405, 651, 2718, 47, 4057, 32022], 'attention_mask': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 17712, 1, 3379, 4463, 18924, 11, 3137, 1585, 22536, 6, 82, 1835, 24958, 331, 344, 4278, 185, 207, 440, 15276, 1, 3379, 4463, 18924, 11, 294, 1585, 4982, 280, 254, 6706, 334, 1251, 617, 2119, 276, 6594, 13, 15276, 62, 304, 8, 185, 207, 440, 12168, 1, 5878, 1743, 2481, 1585, 31175, 272, 3472, 331, 344, 4278, 207, 185, 477, 185, 185, 13956, 21539, 30975, 185, 16, 13, 7310, 885, 254, 7214, 285, 10115, 4212, 2321, 13, 185, 17, 13, 7310, 3493, 3812, 4761, 13, 185, 18, 13, 9320, 6419, 7214, 750, 4362, 334, 68, 13, 70, 1787, 6594, 207, 8797, 229, 7053, 787, 207, 8797, 229, 12050, 3752, 6706, 62, 304, 409, 16813, 62, 304, 14, 9222, 250, 62, 304, 628, 185, 19, 13, 1271, 245, 2503, 317, 21707, 621, 11, 4340, 254, 1093, 11050, 15864, 13, 185, 20, 13, 3119, 441, 11510, 1238, 17575, 10115, 409, 2365, 4761, 13, 207, 185, 185, 9138, 16813, 
62, 1523, 409, 3298, 250, 62, 1523, 276, 2893, 2461, 327, 245, 2017, 6706, 473, 254, 12050, 2365, 13, 7310, 6975, 3281, 62, 1523, 276, 2893, 1748, 473, 3041, 244, 4307, 7270, 13, 207, 185, 185, 1889, 6226, 457, 13567, 11, 931, 254, 4278, 4797, 765, 19393, 56, 8213, 7127, 185, 185, 15013, 10481, 10413, 6074, 285, 5975, 547, 3130, 7486, 185, 4397, 25, 185, 2808, 1311, 6594, 417, 2104, 12, 29125, 30, 185, 6231, 547, 25, 185, 7507, 31970, 7, 10230, 7432, 6594, 11294, 1295, 405, 651, 43, 4057, 185, 185, 4397, 25, 185, 2808, 1311, 12050, 429, 254, 2604, 6304, 30789, 686, 849, 207, 17, 19, 15, 4054, 30, 185, 6231, 547, 25, 185, 7507, 31970, 7, 10230, 7432, 12050, 11294, 6975, 3281, 62, 1523, 405, 651, 3648, 6304, 6, 5584, 4054, 6213, 17, 19, 15, 26, 185, 185, 4397, 25, 185, 2628, 317, 254, 2567, 370, 4168, 7226, 457, 11784, 556, 413, 73, 541, 872, 278, 30, 185, 6231, 547, 25, 185, 7507, 31329, 7, 1513, 3263, 8, 7432, 12050, 11294, 16813, 62, 1523, 405, 651, 24682, 556, 413, 73, 541, 872, 278, 6, 6982, 3298, 250, 62, 1523, 405, 651, 24682, 556, 413, 73, 541, 872, 278, 4057, 185, 185, 4397, 25, 185, 2628, 317, 254, 6054, 1594, 280, 24958, 3472, 4578, 457, 683, 6706, 331, 683, 4278, 30, 185, 6231, 547, 25, 185, 7507, 21234, 7, 12168, 8, 4958, 3034, 62, 12168, 7432, 7053, 787, 7840, 185, 185, 4397, 25, 185, 2808, 1311, 12050, 638, 21623, 23115, 12184, 4726, 276, 24857, 26529, 30, 185, 6231, 547, 25, 185, 7507, 31970, 7, 10230, 7432, 12050, 11294, 3298, 250, 62, 1523, 405, 651, 49, 493, 250, 23115, 12184, 6, 5584, 16813, 62, 1523, 405, 651, 2270, 88, 26529, 4057, 185, 185, 4397, 25, 185, 2808, 1311, 6594, 773, 7730, 1321, 207, 16, 24, 23, 15, 30, 185, 6231, 547, 25, 185, 7507, 31970, 7, 10230, 7432, 6594, 11294, 533, 65, 8086, 16, 24, 23, 15, 15, 16, 15, 16, 26, 185, 185, 7605, 387, 885, 254, 5975, 547, 5151, 3651, 3250, 457, 5975, 547, 25, 285, 637, 746, 2422, 11, 533, 441, 2816, 274, 11543, 280, 254, 5151, 13, 4195, 8297, 274, 5975, 547, 5151, 327, 254, 1884, 2664, 3092, 13, 
17858, 25, 185, 2808, 1311, 731, 11506, 334, 2718, 47, 8, 6594, 417, 741, 30, 185, 6231, 547, 25, 185, 7507, 31970, 7, 10230, 4958, 731, 11506, 62, 2035, 407, 7432, 6594, 11294, 460, 404, 405, 651, 2718, 47, 4057, 32022]}\n", "3156\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [
"# Model output directories\n",
"MODEL_DIR = write_path(\"finetuned-model-16-full\")\n",
"VAL_OUTPUT = write_path(\"val-16-full.hf\")\n",
"\n",
"# Load dataset\n",
"df_train = pd.read_csv(read_path(\"training-data/combined_full_dataset.tsv\"), sep='\\t')\n",
"df_test = pd.read_csv(read_path(\"training-data/test_set.tsv\"), sep='\\t')\n",
"\n",
"# Collapse runs of whitespace inside every string cell to a single space.\n",
"# DataFrame.replace returns a new frame, so the result must be assigned back.\n",
"df_train = df_train.replace(r'\\s+', ' ', regex=True)\n",
"df_test = df_test.replace(r'\\s+', ' ', regex=True)\n",
"\n",
"# Display dataset info\n",
"print(f\"Total train dataset examples: {len(df_train)}\")\n",
"print(df_train.head())\n",
"print(f\"Total test dataset examples: {len(df_test)}\")\n",
"print(df_test.head())\n",
"# Load tokenizer\n",
"model_name = read_path(\"deepseek-coder-1.3b-instruct\")\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"\n",
"# Enable 8-bit quantization for lower memory usage\n",
"bnb_config = None\n",
"if use_bnb:\n",
"    # BitsAndBytesConfig has no `bnb_8bit_compute_dtype` option (only the 4-bit\n",
"    # variant takes a compute dtype), so nothing else is passed here.\n",
"    bnb_config = BitsAndBytesConfig(load_in_8bit=True)\n",
"\n",
"# Load model with quantization\n",
"device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
"device = torch.device(device_name)\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
"    model_name,\n",
"    quantization_config=bnb_config,\n",
"    device_map=device\n",
")\n",
"\n",
"# Add a custom stop token (can be anything that won’t show up in your data)\n",
"special_token = \"<|endofsql|>\"\n",
"\n",
"# Only add if it doesn’t already exist\n",
"print(\"adding special token\")\n",
"print(len(tokenizer))\n",
"if special_token not in tokenizer.get_vocab():\n",
"    tokenizer.add_special_tokens({\"additional_special_tokens\": [special_token]})\n",
"tokenizer.eos_token = special_token\n",
"model.resize_token_embeddings(len(tokenizer))\n",
"print(len(tokenizer))\n",
"\n",
"tokenizer.truncation_side = \"left\"\n",
"\n",
"def format_deepseek_chat(example, tokenizer, special_token=\"<|endofsql|>\", max_length=3156):\n",
"    \"\"\"Tokenize one question/SQL pair into fixed-length causal-LM training features.\n",
"\n",
"    Only the completion (the 'SQLite:' header, the SQL and `special_token`) is supervised:\n",
"    prompt and padding positions get label -100 so the loss ignores them. The prompt mask\n",
"    is built on the unpadded token list and truncated together with it, so it stays aligned\n",
"    with the tokens for both left and right padding.\n",
"\n",
"    Args:\n",
"        example: row with 'natural_query', 'sql_query' and 'is_nba' fields.\n",
"        tokenizer: HF tokenizer; its truncation_side and padding_side are honoured.\n",
"        special_token: end-of-SQL marker appended to the completion.\n",
"        max_length: length of every returned sequence, in tokens.\n",
"\n",
"    Returns:\n",
"        dict with 'input_ids', 'attention_mask' and 'labels', each of length max_length.\n",
"    \"\"\"\n",
"    # Manually build the prompt as one flat string\n",
"    if example['is_nba']:\n",
"        prompt = f\"{nba_prompt}{example['natural_query']}\\n\"\n",
"    else:\n",
"        prompt = f\"{tennis_prompt}{example['natural_query']}\\n\"\n",
"\n",
"    completion = f\"SQLite:\\n{example['sql_query']}{special_token}\"\n",
"\n",
"    input_ids = tokenizer(prompt + completion)[\"input_ids\"]\n",
"    # Leading tokens that belong to the prompt (BOS included, if the tokenizer adds one).\n",
"    prompt_len = min(len(tokenizer(prompt)[\"input_ids\"]), len(input_ids))\n",
"    labels = [-100] * prompt_len + input_ids[prompt_len:]\n",
"\n",
"    # Truncate before padding so labels stay aligned with input_ids.\n",
"    overflow = len(input_ids) - max_length\n",
"    if overflow > 0:\n",
"        if tokenizer.truncation_side == \"left\":\n",
"            # Keep a leading BOS token, like HF truncation (special tokens are re-added after truncating).\n",
"            keep = 1 if input_ids[0] == tokenizer.bos_token_id else 0\n",
"            input_ids = input_ids[:keep] + input_ids[keep + overflow:]\n",
"            labels = labels[:keep] + labels[keep + overflow:]\n",
"        else:\n",
"            input_ids = input_ids[:max_length]\n",
"            labels = labels[:max_length]\n",
"\n",
"    # Pad to max_length; padded positions are ignored by attention and by the loss.\n",
"    pad_len = max_length - len(input_ids)\n",
"    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id\n",
"    attention_mask = [1] * len(input_ids)\n",
"    if tokenizer.padding_side == \"left\":\n",
"        input_ids = [pad_id] * pad_len + input_ids\n",
"        attention_mask = [0] * pad_len + attention_mask\n",
"        labels = [-100] * pad_len + labels\n",
"    else:\n",
"        input_ids = input_ids + [pad_id] * pad_len\n",
"        attention_mask = attention_mask + [0] * pad_len\n",
"        labels = labels + [-100] * pad_len\n",
"\n",
"    return {\"input_ids\": input_ids, \"attention_mask\": attention_mask, \"labels\": labels}\n",
"\n",
"# Build dataset dict\n",
"train_dataset_dict = {\n",
"    \"natural_query\": df_train[\"natural_query\"].tolist(),\n",
"    \"sql_query\": df_train[\"sql_query\"].tolist(),\n",
"    \"is_nba\": df_train[\"is_nba\"].tolist(),\n",
"}\n",
"\n",
"\n",
"val_dataset_dict = {\n",
"    \"natural_query\": df_test[\"natural_query\"].tolist(),\n",
"    \"sql_query\": df_test[\"sql_query\"].tolist(),\n",
"    \"is_nba\": df_test[\"is_nba\"].tolist(),\n",
"}\n",
"\n",
"# Create HuggingFace Dataset\n",
"train_dataset = Dataset.from_dict(train_dataset_dict)\n",
"val_dataset = Dataset.from_dict(val_dataset_dict)\n",
"\n",
"# Apply formatting\n",
"train_dataset = train_dataset.map(\n",
"    lambda x: format_deepseek_chat(x, tokenizer, special_token),\n",
"    remove_columns=[\"natural_query\", \"sql_query\"]\n",
")\n",
"\n",
"val_dataset = val_dataset.map(\n",
"    lambda x: format_deepseek_chat(x, tokenizer, special_token),\n",
"    remove_columns=[\"natural_query\", \"sql_query\"]\n",
")\n",
"\n",
"del df_train, df_test, train_dataset_dict, val_dataset_dict\n",
"\n",
"# Sanity-check one example: its length and the decoded supervised (label) tokens only.\n",
"sample = val_dataset[0]\n",
"print(len(sample['input_ids']))\n",
"print(tokenizer.decode([t for t in sample['labels'] if t != -100]))\n"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load model and define training arguments" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "trainable params: 14,991,360 || all params: 1,360,508,928 || trainable%: 1.1019\n" ] } ], "source": [
"# Define LoRA configuration\n",
"lora_config = LoraConfig(\n",
"    r=16,  # Rank of LoRA matrices (adjust for memory vs. accuracy)\n",
"    lora_alpha=32,  # Scaling factor\n",
"    lora_dropout=0.0,  # Dropout on the LoRA branch (0.0 disables it)\n",
"    bias=\"none\",\n",
"    task_type=TaskType.CAUSAL_LM,\n",
"    target_modules=[\n",
"        \"q_proj\",\n",
"        \"k_proj\",\n",
"        \"v_proj\",\n",
"        \"o_proj\",\n",
"        \"gate_proj\",\n",
"        \"up_proj\",\n",
"        \"down_proj\"\n",
"    ]\n",
")\n",
"\n",
"# Wrap model with LoRA adapters\n",
"model = get_peft_model(model, lora_config)\n",
"model = model.to(device)\n",
"model.print_trainable_parameters()  # Show trainable parameters count"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup model trainer" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\transformers\\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n", " warnings.warn(\n", "C:\\Users\\Dean\\AppData\\Local\\Temp\\ipykernel_13948\\2528805260.py:21: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. 
Use `processing_class` instead.\n", " trainer = Trainer(\n", "No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.\n" ] } ], "source": [
"training_args = TrainingArguments(\n",
"    output_dir=MODEL_DIR,\n",
"    eval_strategy=\"epoch\",  # Evaluate at the end of each epoch\n",
"    save_strategy=\"epoch\",  # Save model every epoch\n",
"    per_device_train_batch_size=1,\n",
"    per_device_eval_batch_size=1,\n",
"    gradient_accumulation_steps=16,  # Effective batch size of 16 per device\n",
"    num_train_epochs=10,  # Upper bound; early stopping may end training sooner\n",
"    learning_rate=5e-5,  # Applied to the LoRA adapter weights only\n",
"    weight_decay=0.001,\n",
"    logging_steps=50,  # Print loss every 50 steps\n",
"    save_total_limit=2,  # Keep at most 2 checkpoints; the best one is always retained\n",
"    bf16=torch.cuda.is_available(),\n",
"    push_to_hub=False,\n",
"    load_best_model_at_end=True,\n",
"    metric_for_best_model=\"eval_loss\",\n",
"    greater_is_better=False,\n",
"    # PeftModel hides the wrapped model's forward signature, so Trainer cannot infer this.\n",
"    label_names=[\"labels\"],\n",
")\n",
"\n",
"# Trainer setup\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    args=training_args,\n",
"    train_dataset=train_dataset,\n",
"    eval_dataset=val_dataset,\n",
"    processing_class=tokenizer,\n",
"    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]\n",
")"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run fine-tuning and save model weights when complete" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\transformers\\integrations\\sdpa_attention.py:54: UserWarning: 1Torch was not compiled with flash attention. 
(Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\transformers\\cuda\\sdp_utils.cpp:555.)\n", " attn_output = torch.nn.functional.scaled_dot_product_attention(\n", "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\autograd\\_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization\n", " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n", "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\autograd\\_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n", " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [576/630 45:29:22 < 4:16:46, 0.00 it/s, Epoch 9/10]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation Loss
10.5977000.167915
20.1177000.137110
30.0936000.125831
40.0773000.116754
50.0715000.113243
60.0678000.110944
70.0628000.109875
80.0563000.111329
90.0518000.110340

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\peft\\utils\\save_and_load.py:250: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n", " warnings.warn(\n", "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\autograd\\_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization\n", " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n", "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\autograd\\_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n", " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n", "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\peft\\utils\\save_and_load.py:250: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n", " warnings.warn(\n", "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\autograd\\_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization\n", " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n", "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\autograd\\_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n", " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n", 
"c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\peft\\utils\\save_and_load.py:250: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n", " warnings.warn(\n", "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\autograd\\_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization\n", " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n", "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\autograd\\_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n", " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n", "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\peft\\utils\\save_and_load.py:250: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n", " warnings.warn(\n", "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\autograd\\_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization\n", " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n", "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\autograd\\_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n", " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n", "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\peft\\utils\\save_and_load.py:250: UserWarning: Setting 
`save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n", " warnings.warn(\n", "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\autograd\\_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization\n", " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n", "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\autograd\\_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n", " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n", "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\peft\\utils\\save_and_load.py:250: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n", " warnings.warn(\n", "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\autograd\\_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization\n", " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n", "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\autograd\\_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n", " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n", "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\peft\\utils\\save_and_load.py:250: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n", " warnings.warn(\n", 
"c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\autograd\\_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization\n",
"  warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\autograd\\_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
"  warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\peft\\utils\\save_and_load.py:250: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n",
"  warnings.warn(\n",
"c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\autograd\\_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization\n",
"  warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\autograd\\_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
"  warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\peft\\utils\\save_and_load.py:250: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n",
"  warnings.warn(\n",
"c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\peft\\tuners\\lora\\bnb.py:85: UserWarning: Merge lora module to 8-bit linear may get different generations due to rounding errors.\n",
"  warnings.warn(\n"
] },
{ "data": { "text/plain": [ "('./finetuned-model-16-full\\\\tokenizer_config.json',\n", " './finetuned-model-16-full\\\\special_tokens_map.json',\n", " './finetuned-model-16-full\\\\tokenizer.json')" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ],
"source": [
"# Run training\n",
"trainer.train()\n",
"\n",
"# Merge LoRA adapters with the base model before saving.\n",
"# NOTE(review): the base model is 8-bit quantized here (see the MatMul8bitLt warnings), and PEFT\n",
"# warns that merging LoRA into 8-bit linear layers can change generations through rounding.\n",
"# If inference output degrades, consider saving the adapter and merging into an unquantized base.\n",
"model = model.merge_and_unload()\n",
"model.save_pretrained(MODEL_DIR)\n",
"tokenizer.save_pretrained(MODEL_DIR)"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [
"## Try inference using fine-tuned model\n",
"\n",
"An earlier run of the cell below produced the expected query followed by degenerate text (`anyes` and a long run of `!` lines) instead of stopping at `<|endofsql|>`; see the NOTE in the training cell above."
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"model = AutoModelForCausalLM.from_pretrained(MODEL_DIR, torch_dtype=torch.bfloat16, device_map=device)\n",
"tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)\n",
"\n",
"# Prepare query with the same prompt used for the training examples\n",
"input_text = \"How many points do the Los Angeles Lakers average at home?\"\n",
"message = [{'role': 'user', 'content': nba_prompt + input_text}]\n",
"inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n",
"\n",
"# Stop at the custom end-of-SQL token. The prompt is a single unpadded sequence, so an\n",
"# all-ones attention mask is exact; passing it and a pad id explicitly avoids generate()'s\n",
"# \"attention mask and the pad token id were not set\" warning.\n",
"end_of_sql_id = tokenizer.convert_tokens_to_ids(\"<|endofsql|>\")\n",
"pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else end_of_sql_id\n",
"\n",
"# Generate SQL query\n",
"outputs = model.generate(\n",
"    inputs,\n",
"    attention_mask=torch.ones_like(inputs),\n",
"    pad_token_id=pad_id,\n",
"    max_new_tokens=256,\n",
"    eos_token_id=end_of_sql_id\n",
")\n",
"model_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n",
"\n",
"print(\"Generated SQL:\", model_output)"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Save validation set to disk" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Saving the dataset (1/1 shards): 100%|██████████| 250/250 [00:00<00:00, 27776.11 examples/s]\n" ] } ], "source": [ "val_dataset.save_to_disk(VAL_OUTPUT)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Test logic for obtaining original prompt and SQLite" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import sqlite3 as sql\n",
"\n",
"# Create connection to the NBA sqlite3 database\n",
"connection = sql.connect(read_path('nba-data/nba.sqlite'))\n",
"cursor = connection.cursor()"
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"def split_example(example):\n",
"    \"\"\"Decode a tokenized validation example into its (question, sql_query) pair.\n",
"\n",
"    The decoded text is the domain prompt (`nba_prompt` for NBA examples, `tennis_prompt`\n",
"    otherwise), then the question, then the \"SQLite:\" marker line, then the reference query.\n",
"    The prompt length is chosen per example so both domains are sliced correctly.\n",
"    \"\"\"\n",
"    domain_prompt = nba_prompt if example[\"is_nba\"] else tennis_prompt\n",
"    full_example = tokenizer.decode(example[\"input_ids\"], skip_special_tokens=True)\n",
"    question, sql_query = full_example[len(domain_prompt):].split(\"SQLite:\\n\")\n",
"    return question, sql_query"
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# `cursor` is connected to the NBA database only (tennis tables such as `players` are not\n",
"# in it), so check the decoding logic on the first NBA example.\n",
"for v in val_dataset:\n",
"    if not v[\"is_nba\"]:\n",
"        continue\n",
"    question, sql_query = split_example(v)\n",
"    print(question)\n",
"    print(sql_query)\n",
"    cursor.execute(sql_query)\n",
"    for row in cursor.fetchall():\n",
"        print(row)\n",
"    break"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [
"## Run evaluation over entire validation set\n",
"\n",
"Only NBA examples are scored: the SQLite connection above is to the NBA database, so tennis reference queries cannot be executed there."
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import math\n",
"\n",
"# Labels that may precede the query in the model output; checked in order, longest variant first.\n",
"SQL_PREFIXES = (\"SQLite:\\n\", \"SQLite: \", \"SQLite:\", \"SQL:\\n\", \"SQL: \", \"SQL:\")\n",
"\n",
"\n",
"def generate_sql(domain_prompt, question, max_new_tokens=256):\n",
"    \"\"\"Generate SQL for `question` with the fine-tuned model; return only the new text.\n",
"\n",
"    Generation stops at the `<|endofsql|>` token or after `max_new_tokens` tokens. The\n",
"    prompt is a single unpadded sequence, so an all-ones attention mask is exact; passing\n",
"    it and a pad id explicitly avoids generate()'s missing-attention-mask warning.\n",
"    \"\"\"\n",
"    message = [{'role': 'user', 'content': domain_prompt + question}]\n",
"    inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n",
"    end_of_sql_id = tokenizer.convert_tokens_to_ids(\"<|endofsql|>\")\n",
"    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else end_of_sql_id\n",
"    outputs = model.generate(\n",
"        inputs,\n",
"        attention_mask=torch.ones_like(inputs),\n",
"        pad_token_id=pad_id,\n",
"        max_new_tokens=max_new_tokens,\n",
"        eos_token_id=end_of_sql_id,\n",
"    )\n",
"    return tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n",
"\n",
"\n",
"def extract_query(model_output):\n",
"    \"\"\"Drop a leading SQLite/SQL label and any text after the first semicolon.\"\"\"\n",
"    query = model_output\n",
"    for prefix in SQL_PREFIXES:\n",
"        if query.startswith(prefix):\n",
"            query = query[len(prefix):]\n",
"            break\n",
"    if \";\" in query:\n",
"        query = query[:query.index(\";\") + 1]\n",
"    return query\n",
"\n",
"\n",
"def _values_match(expected, actual):\n",
"    \"\"\"Lenient value comparison: numbers within 0.5 of each other, else substring containment.\"\"\"\n",
"    try:\n",
"        return math.isclose(float(expected), float(actual), abs_tol=0.5)\n",
"    except ValueError:\n",
"        return expected in actual or actual in expected\n",
"\n",
"\n",
"def compare_result(sample_query, model_output):\n",
"    \"\"\"Score generated SQL against the reference query using `cursor` (NBA database).\n",
"\n",
"    Returns a tuple (valid, sql_matched, result_matched):\n",
"      valid          -- the extracted model query executed without a sqlite3 error\n",
"      sql_matched    -- both queries are identical once all whitespace is removed\n",
"      result_matched -- the flattened result values are equal, or some reference value\n",
"                        leniently matches some model value (see `_values_match`)\n",
"    Errors from executing the reference query are not caught.\n",
"    \"\"\"\n",
"    query = extract_query(model_output)\n",
"\n",
"    cursor.execute(sample_query)\n",
"    sample_result = [str(item) for row in cursor.fetchall() for item in row]\n",
"\n",
"    try:\n",
"        cursor.execute(query)\n",
"    except (sql.Error, sql.Warning):\n",
"        # An unexecutable model query counts as invalid rather than aborting the evaluation.\n",
"        return False, False, False\n",
"    model_result = [str(item) for row in cursor.fetchall() for item in row]\n",
"\n",
"    print(sample_result)\n",
"    print(model_result)\n",
"\n",
"    # Ignore spacing differences (spaces, newlines, tabs, ...) when comparing query text;\n",
"    # identical queries necessarily produce identical results.\n",
"    if \"\".join(query.split()) == \"\".join(sample_query.split()):\n",
"        return True, True, True\n",
"\n",
"    # Identical results (including both empty) are a match.\n",
"    if model_result == sample_result:\n",
"        return True, False, True\n",
"\n",
"    # NOTE(review): very lenient -- one matching value anywhere counts, and substring\n",
"    # matching means e.g. \"1\" matches \"2010\".\n",
"    result_matched = any(_values_match(r, res) for r in sample_result for res in model_result)\n",
"    return True, False, result_matched\n",
"\n",
"\n",
"def evaluate(examples, max_sql_length=None):\n",
"    \"\"\"Generate SQL for each NBA example, print per-example scores, and return metrics.\n",
"\n",
"    Tennis examples are skipped because `cursor` is connected to the NBA database only.\n",
"    When `max_sql_length` is set, examples whose reference query is longer than that many\n",
"    characters are skipped too. Returns a dict with the number of evaluated examples and\n",
"    the fraction that were valid / SQL-matched / result-matched among those examples.\n",
"    \"\"\"\n",
"    # TODO: connect the tennis database so tennis examples can be scored as well.\n",
"    num_evaluated = num_valid = num_sql_matched = num_result_matched = 0\n",
"    for v in examples:\n",
"        if not v[\"is_nba\"]:\n",
"            continue\n",
"        question, sql_query = split_example(v)\n",
"        if max_sql_length is not None and len(sql_query) > max_sql_length:\n",
"            continue\n",
"\n",
"        model_output = generate_sql(nba_prompt, question)\n",
"        print(sql_query)\n",
"        print(model_output.split(\";\")[0])\n",
"        valid, sql_matched, result_matched = compare_result(sql_query, model_output)\n",
"        print(\"Statement valid? \" + str(valid))\n",
"        print(\"SQLite matched? \" + str(sql_matched))\n",
"        print(\"Result matched? \" + str(result_matched))\n",
"        print()\n",
"        print()\n",
"\n",
"        num_evaluated += 1\n",
"        num_valid += valid\n",
"        num_sql_matched += sql_matched\n",
"        num_result_matched += result_matched\n",
"\n",
"    if num_evaluated == 0:\n",
"        print(\"No examples were evaluated.\")\n",
"        return {\"evaluated\": 0}\n",
"\n",
"    metrics = {\n",
"        \"evaluated\": num_evaluated,\n",
"        \"valid\": num_valid / num_evaluated,\n",
"        \"sql_matched\": num_sql_matched / num_evaluated,\n",
"        \"result_matched\": num_result_matched / num_evaluated,\n",
"    }\n",
"    print(\"Percent valid: \" + str(metrics[\"valid\"]))\n",
"    print(\"Percent SQLite matched: \" + str(metrics[\"sql_matched\"]))\n",
"    print(\"Percent result matched: \" + str(metrics[\"result_matched\"]))\n",
"    return metrics"
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"full_metrics = evaluate(val_dataset)"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Test Tennis Inference" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"input_text = \"Which hand does Pete Sampras use?\"\n",
"print(\"Generated SQL:\", generate_sql(tennis_prompt, input_text))"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Test validation set only on short queries" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"MAX_SHORT_SQL_LENGTH = 90  # maximum length (in characters) of the reference query\n",
"short_metrics = evaluate(val_dataset, max_sql_length=MAX_SHORT_SQL_LENGTH)"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Test privacy breaking inference" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# See if we can generate NBA SQL with the tennis prompt\n",
"input_text = \"What is the abbreviation of the team nicknamed 'Heat'?\"\n",
"print(\"Generated SQL:\", generate_sql(tennis_prompt, input_text))"
] }
], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.6" } }, "nbformat": 4, "nbformat_minor": 2 }