File size: 54,361 Bytes
d7173d6
1
2
{"cells":[{"cell_type":"markdown","metadata":{},"source":["# <b><span style='color:#F1A424'>|</span> Import Libraries</b><a class='anchor' id='import_libraries'></a> [↑](#top) \n","\n","***\n","\n","Import all the required libraries for this notebook."]},{"cell_type":"code","execution_count":1,"metadata":{"execution":{"iopub.execute_input":"2024-09-19T17:05:23.520178Z","iopub.status.busy":"2024-09-19T17:05:23.519758Z","iopub.status.idle":"2024-09-19T17:05:27.482622Z","shell.execute_reply":"2024-09-19T17:05:27.481379Z","shell.execute_reply.started":"2024-09-19T17:05:23.520134Z"},"trusted":true},"outputs":[{"name":"stderr","output_type":"stream","text":["/home/ea301b/anaconda3/envs/binh_mamba/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n","  from .autonotebook import tqdm as notebook_tqdm\n"]},{"name":"stdout","output_type":"stream","text":["Current device is: cuda\n","mkdir: cannot create directory ‘output’: File exists\n"]}],"source":["import matplotlib.pyplot as plt\n","import pandas as pd\n","\n","import warnings\n","# import wandb\n","\n","\n","from sklearn.metrics import roc_auc_score\n","from sklearn.utils import shuffle\n","from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold\n","import torch\n","import torch.nn as nn\n","from torch.nn import Parameter\n","import torch.nn.functional as F\n","from torch.optim import Adam, SGD, AdamW\n","from torch.optim.lr_scheduler import OneCycleLR\n","from torch.utils.data import DataLoader, Dataset\n","from tqdm.auto import tqdm\n","\n","# ======= OPTIONS =========\n","pd.set_option('display.max_rows', 500)\n","pd.set_option('display.max_columns', 500)\n","pd.set_option('display.width', 1000)\n","device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n","print(f\"Current device is: {device}\")\n","warnings.filterwarnings(\"ignore\")\n","!mkdir 
output"]},{"cell_type":"code","execution_count":2,"metadata":{"execution":{"iopub.execute_input":"2024-09-19T17:05:27.484828Z","iopub.status.busy":"2024-09-19T17:05:27.484129Z","iopub.status.idle":"2024-09-19T17:05:43.959428Z","shell.execute_reply":"2024-09-19T17:05:43.958637Z","shell.execute_reply.started":"2024-09-19T17:05:27.484792Z"},"trusted":true},"outputs":[],"source":["import random\n","import torch.nn as nn\n","from torch.nn import BCEWithLogitsLoss\n","from collections import namedtuple\n","from dataclasses import dataclass, field, asdict\n","from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel\n","from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf\n","# from huggingface_hub import HfApi\n","\n","# import evaluate\n","import numpy as np\n","# from datasets import load_dataset\n","# from transformers import Trainer\n","from transformers import DataCollatorWithPadding\n","from transformers import AutoTokenizer, TrainingArguments\n","import re"]},{"cell_type":"markdown","metadata":{"execution":{"iopub.execute_input":"2024-09-18T04:07:22.734063Z","iopub.status.busy":"2024-09-18T04:07:22.733462Z","iopub.status.idle":"2024-09-18T04:07:22.850537Z","shell.execute_reply":"2024-09-18T04:07:22.849644Z","shell.execute_reply.started":"2024-09-18T04:07:22.734029Z"},"trusted":true},"source":["import os\n","import wandb\n","from huggingface_hub import login\n","\n","# Read the token from the environment; never commit it to the notebook\n","login(token=os.environ[\"HF_TOKEN\"])\n","notes = \"Train Mamba With 400k row dataset\""]},{"cell_type":"markdown","metadata":{"papermill":{"duration":0.012589,"end_time":"2022-08-31T07:03:04.13341","exception":false,"start_time":"2022-08-31T07:03:04.120821","status":"completed"},"tags":[]},"source":["# <b><span style='color:#F1A424'>|</span> Load Data</b><a class='anchor' id='load_data'></a> [↑](#top) \n","\n","***\n","\n","Load 
data."]},{"cell_type":"code","execution_count":3,"metadata":{"execution":{"iopub.execute_input":"2024-09-19T17:05:43.962860Z","iopub.status.busy":"2024-09-19T17:05:43.962010Z","iopub.status.idle":"2024-09-19T17:05:53.355460Z","shell.execute_reply":"2024-09-19T17:05:53.354541Z","shell.execute_reply.started":"2024-09-19T17:05:43.962813Z"},"trusted":true},"outputs":[{"name":"stderr","output_type":"stream","text":["Processing Text: 100%|██████████| 165767/165767 [00:04<00:00, 37690.26it/s]\n","Processing Text: 100%|██████████| 1679/1679 [00:00<00:00, 40374.47it/s]"]},{"name":"stdout","output_type":"stream","text":["Trainging DF Processing\n","<class 'pandas.core.frame.DataFrame'>\n","RangeIndex: 165767 entries, 0 to 165766\n","Data columns (total 4 columns):\n"," #   Column     Non-Null Count   Dtype \n","---  ------     --------------   ----- \n"," 0   id         165767 non-null  object\n"," 1   prompt_id  165767 non-null  int64 \n"," 2   text       165767 non-null  object\n"," 3   generated  165767 non-null  int64 \n","dtypes: int64(2), object(2)\n","memory usage: 5.1+ MB\n","None\n","Testing DF Processing\n","<class 'pandas.core.frame.DataFrame'>\n","RangeIndex: 1679 entries, 0 to 1678\n","Data columns (total 4 columns):\n"," #   Column     Non-Null Count  Dtype \n","---  ------     --------------  ----- \n"," 0   id         1679 non-null   object\n"," 1   prompt_id  1679 non-null   int64 \n"," 2   text       1679 non-null   object\n"," 3   generated  1679 non-null   int64 \n","dtypes: int64(2), object(2)\n","memory usage: 52.6+ KB\n","None\n"]},{"name":"stderr","output_type":"stream","text":["\n"]}],"source":["import pandas as pd\n","import re\n","import unicodedata\n","from tqdm import tqdm\n","\n","# Load DataFrame\n","train_df = pd.read_parquet('/home/HardDisk/binh230_intern/Mamba-AI-generated-text-detection/data/Mix-AI-Dataset/train_essays.parquet')\n","valid_df = 
pd.read_parquet('/home/HardDisk/binh230_intern/Mamba-AI-generated-text-detection/data/Mix-AI-Dataset/valid_essays.parquet')\n","\n","# Define characters to remove\n","char_to_remove = ['{', '£', '\\x97', '¹', 'å', '\\\\', '\\x85', '<', '\\x99', \n","                  'é', ']', '+', 'Ö', '\\xa0', '>', '|', '\\x80', '~', '©', \n","                  '/', '\\x93', '$', 'Ó', '²', '^', ';', '`', 'á', '*', '(', \n","                  '¶', '®', '[', '\\x94', '\\x91', '#', '-', 'ó', ')', '}', '=']\n","\n","# Define preprocessing function\n","def preprocess_text(text, strategy='light'):\n","    \"\"\"Clean a single essay string before tokenization.\n","\n","    strategy='none'  -> return the text unchanged.\n","    strategy='light' -> drop non-ASCII characters, strip surrounding whitespace\n","                        and double quotes, delete every character listed in\n","                        char_to_remove, then cut an unfinished trailing\n","                        sentence (anything after the last '.').\n","    other values     -> NFKD-normalize, lowercase, keep only a whitelist of\n","                        characters and collapse runs of whitespace.\n","    \"\"\"\n","    if strategy == \"none\":\n","        return text\n","    if strategy == \"light\":\n","        text = text.encode(\"ascii\", \"ignore\").decode('ascii')\n","        text = text.strip()\n","        text = text.strip(\"\\\"\")\n","        for c in char_to_remove:\n","            text = text.replace(c, \"\")\n","        if text and text[-1] != \".\":\n","            # Trim only when a '.' exists; the old split/join discarded the\n","            # whole essay and returned just \".\" when there was no period.\n","            head, sep, _ = text.rpartition(\".\")\n","            if sep:\n","                text = head + \".\"\n","    else:\n","        text = unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode('ascii')\n","        text = text.lower()\n","        text = re.sub(r'[^a-z0-9\\s.,;?!:()\\'\\\"%-]', '', text)\n","        text = re.sub(r'\\s+', ' ', text).strip()\n","\n","    return text\n","\n","# Apply preprocessing with progress bar\n","tqdm.pandas(desc=\"Processing Text\")\n","train_df['text'] = train_df['text'].progress_apply(lambda x: preprocess_text(x, strategy='light'))\n","valid_df['text'] = valid_df['text'].progress_apply(lambda x: preprocess_text(x, strategy='light'))\n","\n","# Display the first few rows to verify\n","print(\"Training DF Processing\")\n","print(train_df.info())\n","print(\"Testing DF 
Processing\")\n","print(valid_df.info())\n","\n"]},{"cell_type":"markdown","metadata":{"papermill":{"duration":0.008127,"end_time":"2022-08-31T07:03:11.985369","exception":false,"start_time":"2022-08-31T07:03:11.977242","status":"completed"},"tags":[]},"source":["# <b><span style='color:#F1A424'>|</span> Dataset</b><a class='anchor' id='dataset'></a> [↑](#top) \n","\n","***\n","\n","    \n","Each essay is tokenized with the `EleutherAI/gpt-neox-20b` tokenizer, and its EOS token is reused as the padding token. Every sequence is truncated or padded to a fixed `max_length` of 1024 tokens (`padding='max_length'`), which matches the model's `context_length`.\n","\n","- [Hugging Face Padding and Truncation](https://huggingface.co/docs/transformers/pad_truncation): check truncation to `max_length` or `True` (batch max length)."]},{"cell_type":"markdown","metadata":{},"source":["One sample from the training set (after the `text` and `id` columns are dropped) looks as follows:\n","```python\n","{\n","\t'prompt_id': 0,\n","\t'labels': 1,\n","\t'input_ids': [...],       # 1024 token ids, padded with the EOS token\n","\t'attention_mask': [...]   # 1 for real tokens, 0 for padding\n","}\n","```\n","You can check it by running `train_dataset[0]` once the datasets below have been built."]},{"cell_type":"markdown","metadata":{},"source":["import os\n","import wandb\n","# Name of the wandb project used to log the training run\n","os.environ[\"WANDB_PROJECT\"] = \"mamba_LLM_detect_binary_classification\"\n","# WANDB_API_KEY must be supplied through the environment; never hardcode it here\n","\n","\n","wandb.login()\n","# start a new wandb run to track this script\n","wandb.init(\n","    # set the wandb project where this run will be logged\n","    project=\"mamba_LLM_detect_binary_classification\",\n","\n","    # track hyperparameters and run metadata\n","    config={\n","    \"learning_rate\": 6e-5,\n","    \"architecture\": \"Mamba-130m-with-Linear-Head\",\n","    \"dataset\": \"Test\",\n","    \"epochs\": 1,\n","    \"lr_scheduler_type\": 
\"cosine\"\n","    }\n",")"]},{"cell_type":"markdown","metadata":{"papermill":{"duration":0.008073,"end_time":"2022-08-31T07:03:17.933189","exception":false,"start_time":"2022-08-31T07:03:17.925116","status":"completed"},"tags":[]},"source":["# <b><span style='color:#F1A424'>|</span> Model</b><a class='anchor' id='model'></a> [↑](#top) \n","\n","***"]},{"cell_type":"code","execution_count":4,"metadata":{"execution":{"iopub.execute_input":"2024-09-19T17:05:53.357180Z","iopub.status.busy":"2024-09-19T17:05:53.356739Z","iopub.status.idle":"2024-09-19T17:05:53.373217Z","shell.execute_reply":"2024-09-19T17:05:53.372238Z","shell.execute_reply.started":"2024-09-19T17:05:53.357130Z"},"trusted":true},"outputs":[{"data":{"text/html":["<div>\n","<style scoped>\n","    .dataframe tbody tr th:only-of-type {\n","        vertical-align: middle;\n","    }\n","\n","    .dataframe tbody tr th {\n","        vertical-align: top;\n","    }\n","\n","    .dataframe thead th {\n","        text-align: right;\n","    }\n","</style>\n","<table border=\"1\" class=\"dataframe\">\n","  <thead>\n","    <tr style=\"text-align: right;\">\n","      <th></th>\n","      <th>id</th>\n","      <th>prompt_id</th>\n","      <th>text</th>\n","      <th>generated</th>\n","    </tr>\n","  </thead>\n","  <tbody>\n","    <tr>\n","      <th>0</th>\n","      <td>e_ddxvqx5i</td>\n","      <td>0</td>\n","      <td>In recent years, there has been a growing move...</td>\n","      <td>1</td>\n","    </tr>\n","    <tr>\n","      <th>1</th>\n","      <td>e_hi0yzrcv</td>\n","      <td>0</td>\n","      <td>\\nWhy not cars in our life\\n\\nI have ever met ...</td>\n","      <td>1</td>\n","    </tr>\n","    <tr>\n","      <th>2</th>\n","      <td>e_uesv4xha</td>\n","      <td>0</td>\n","      <td>A car is considered by many a nessecity for ev...</td>\n","      <td>1</td>\n","    </tr>\n","    <tr>\n","      <th>3</th>\n","      <td>e_2tl5ylwy</td>\n","      <td>0</td>\n","      <td>H\\n\\nello fellow citezens , we are 
here to inf...</td>\n","      <td>0</td>\n","    </tr>\n","    <tr>\n","      <th>4</th>\n","      <td>e_s6ci4vj0</td>\n","      <td>0</td>\n","      <td>Have you ever known how if feels not being abl...</td>\n","      <td>1</td>\n","    </tr>\n","  </tbody>\n","</table>\n","</div>"],"text/plain":["           id  prompt_id                                               text  generated\n","0  e_ddxvqx5i          0  In recent years, there has been a growing move...          1\n","1  e_hi0yzrcv          0  \\nWhy not cars in our life\\n\\nI have ever met ...          1\n","2  e_uesv4xha          0  A car is considered by many a nessecity for ev...          1\n","3  e_2tl5ylwy          0  H\\n\\nello fellow citezens , we are here to inf...          0\n","4  e_s6ci4vj0          0  Have you ever known how if feels not being abl...          1"]},"execution_count":4,"metadata":{},"output_type":"execute_result"}],"source":["train_df.head()"]},{"cell_type":"code","execution_count":5,"metadata":{"execution":{"iopub.execute_input":"2024-09-19T17:05:53.374547Z","iopub.status.busy":"2024-09-19T17:05:53.374259Z","iopub.status.idle":"2024-09-19T17:05:56.661369Z","shell.execute_reply":"2024-09-19T17:05:56.660369Z","shell.execute_reply.started":"2024-09-19T17:05:53.374516Z"},"trusted":true},"outputs":[{"data":{"text/plain":["DatasetDict({\n","    train: Dataset({\n","        features: ['id', 'prompt_id', 'text', 'labels'],\n","        num_rows: 165767\n","    })\n","    test: Dataset({\n","        features: ['id', 'prompt_id', 'text', 'labels'],\n","        num_rows: 1679\n","    })\n","})"]},"execution_count":5,"metadata":{},"output_type":"execute_result"}],"source":["import pandas as pd\n","from datasets import Dataset, DatasetDict\n","from sklearn.model_selection import train_test_split\n","\n","# Assuming train_df is your DataFrame with a 'text' column\n","# Convert the 'id' column to a string to avoid ArrowTypeError\n","# df['id'] = df['id'].astype(str)\n","\n","# Rename the 
'generated' column to 'labels'\n","train_df.rename(columns={'generated': 'labels'}, inplace=True)\n","valid_df.rename(columns={'generated': 'labels'}, inplace=True)\n","\n","# # Access the train and test datasets\n","# train_dataset, test_dataset = train_test_split(df, test_size=0.05)\n","\n","# Combine the splits into a DatasetDict\n","dataset_dict = DatasetDict({\n","    'train': Dataset.from_pandas(train_df),\n","    'test': Dataset.from_pandas(valid_df),\n","})\n","\n","# Display the first example from each dataset\n","dataset_dict"]},{"cell_type":"code","execution_count":6,"metadata":{"execution":{"iopub.execute_input":"2024-09-19T17:05:56.662755Z","iopub.status.busy":"2024-09-19T17:05:56.662437Z","iopub.status.idle":"2024-09-19T17:08:50.995345Z","shell.execute_reply":"2024-09-19T17:08:50.994344Z","shell.execute_reply.started":"2024-09-19T17:05:56.662708Z"},"trusted":true},"outputs":[{"name":"stderr","output_type":"stream","text":["Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n","Map: 100%|██████████| 165767/165767 [02:04<00:00, 1335.75 examples/s]\n","Map: 100%|██████████| 1679/1679 [00:01<00:00, 1349.49 examples/s]\n"]}],"source":["tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')\n","# Add eos tokens\n","# tokenizer.eos_token = \"<|endoftext|>\"\n","tokenizer.pad_token = tokenizer.eos_token\n","def preprocess_function(examples):\n","    # Tokenize the text with truncation\n","    samples = tokenizer(examples['text'], \n","                        truncation=True, \n","                        padding='max_length', \n","                        max_length=1024,         \n","                        return_tensors=\"pt\")\n","    \n","    return samples\n","\n","# Apply preprocessing to the dataset\n","tokenized_dataset = dataset_dict.map(preprocess_function, 
batched=True)\n"]},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[{"data":{"text/plain":["DatasetDict({\n","    train: Dataset({\n","        features: ['id', 'prompt_id', 'text', 'labels', 'input_ids', 'attention_mask'],\n","        num_rows: 165767\n","    })\n","    test: Dataset({\n","        features: ['id', 'prompt_id', 'text', 'labels', 'input_ids', 'attention_mask'],\n","        num_rows: 1679\n","    })\n","})"]},"execution_count":7,"metadata":{},"output_type":"execute_result"}],"source":["tokenized_dataset"]},{"cell_type":"code","execution_count":8,"metadata":{"execution":{"iopub.execute_input":"2024-09-19T17:08:50.996892Z","iopub.status.busy":"2024-09-19T17:08:50.996548Z","iopub.status.idle":"2024-09-19T17:08:51.011804Z","shell.execute_reply":"2024-09-19T17:08:51.010782Z","shell.execute_reply.started":"2024-09-19T17:08:50.996858Z"},"trusted":true},"outputs":[],"source":["# Seed Python's random module (used by the eval-subset sampling below)\n","random.seed(42)\n","\n","# Take the train and test splits from the tokenized DatasetDict\n","train_dataset = tokenized_dataset[\"train\"]\n","test_dataset = tokenized_dataset[\"test\"]\n","# Drop the raw 'text' and 'id' columns from both datasets ('prompt_id' is kept)\n","train_dataset = train_dataset.remove_columns([\"text\"]).remove_columns([\"id\"])\n","test_dataset = test_dataset.remove_columns([\"text\"]).remove_columns([\"id\"])\n","\n","# Build an evaluation subset to monitor during training\n","# The test set is large, so only a random sample of it is used (the commented code below samples 50%)\n","# total_samples = len(test_dataset)\n","# eval_samples = int(0.5 * total_samples)\n","# eval_indices = random.sample(range(total_samples), eval_samples)\n","# eval_dataset = 
test_dataset.select(eval_indices)"]},{"cell_type":"code","execution_count":9,"metadata":{"execution":{"iopub.execute_input":"2024-09-19T17:08:51.013411Z","iopub.status.busy":"2024-09-19T17:08:51.013088Z","iopub.status.idle":"2024-09-19T17:10:28.712704Z","shell.execute_reply":"2024-09-19T17:10:28.711832Z","shell.execute_reply.started":"2024-09-19T17:08:51.013380Z"},"trusted":true},"outputs":[],"source":["import torch\n","import numpy as np\n","from transformers import Trainer, TrainingArguments, DataCollatorWithPadding\n","# Load the model\n","import torch\n","from xlstm import (\n","    xLSTMBlockStack,\n","    xLSTMBlockStackConfig,\n","    mLSTMBlockConfig,\n","    mLSTMLayerConfig,\n","    sLSTMBlockConfig,\n","    sLSTMLayerConfig,\n","    FeedForwardConfig,\n",")\n","\n","# Dataset and Tokenizer Setup\n","data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n"]},{"cell_type":"code","execution_count":10,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["{'verbose': True, 'with_cuda': True, 'extra_ldflags': ['-L/home/ea301b/anaconda3/envs/binh_mamba/lib', '-lcublas'], 'extra_cflags': ['-DSLSTM_HIDDEN_SIZE=1024', '-DSLSTM_BATCH_SIZE=8', '-DSLSTM_NUM_HEADS=4', '-DSLSTM_NUM_STATES=4', '-DSLSTM_DTYPE_B=float', '-DSLSTM_DTYPE_R=__nv_bfloat16', '-DSLSTM_DTYPE_W=__nv_bfloat16', '-DSLSTM_DTYPE_G=__nv_bfloat16', '-DSLSTM_DTYPE_S=__nv_bfloat16', '-DSLSTM_DTYPE_A=float', '-DSLSTM_NUM_GATES=4', '-DSLSTM_SIMPLE_AGG=true', '-DSLSTM_GRADIENT_RECURRENT_CLIPVAL_VALID=false', '-DSLSTM_GRADIENT_RECURRENT_CLIPVAL=0.0', '-DSLSTM_FORWARD_CLIPVAL_VALID=false', '-DSLSTM_FORWARD_CLIPVAL=0.0', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_BFLOAT16_OPERATORS__', '-U__CUDA_NO_BFLOAT16_CONVERSIONS__', '-U__CUDA_NO_BFLOAT162_OPERATORS__', '-U__CUDA_NO_BFLOAT162_CONVERSIONS__'], 'extra_cuda_cflags': ['-Xptxas=\"-v\"', '-gencode', 'arch=compute_80,code=compute_80', '-res-usage', '--use_fast_math', '-O3', '-Xptxas -O3', 
'--extra-device-vectorization', '-DSLSTM_HIDDEN_SIZE=1024', '-DSLSTM_BATCH_SIZE=8', '-DSLSTM_NUM_HEADS=4', '-DSLSTM_NUM_STATES=4', '-DSLSTM_DTYPE_B=float', '-DSLSTM_DTYPE_R=__nv_bfloat16', '-DSLSTM_DTYPE_W=__nv_bfloat16', '-DSLSTM_DTYPE_G=__nv_bfloat16', '-DSLSTM_DTYPE_S=__nv_bfloat16', '-DSLSTM_DTYPE_A=float', '-DSLSTM_NUM_GATES=4', '-DSLSTM_SIMPLE_AGG=true', '-DSLSTM_GRADIENT_RECURRENT_CLIPVAL_VALID=false', '-DSLSTM_GRADIENT_RECURRENT_CLIPVAL=0.0', '-DSLSTM_FORWARD_CLIPVAL_VALID=false', '-DSLSTM_FORWARD_CLIPVAL=0.0', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_BFLOAT16_OPERATORS__', '-U__CUDA_NO_BFLOAT16_CONVERSIONS__', '-U__CUDA_NO_BFLOAT162_OPERATORS__', '-U__CUDA_NO_BFLOAT162_CONVERSIONS__']}\n"]},{"name":"stderr","output_type":"stream","text":["Using /home/ea301b/.cache/torch_extensions/py312_cu124 as PyTorch extensions root...\n","Detected CUDA files, patching ldflags\n","Emitting ninja build file /home/ea301b/.cache/torch_extensions/py312_cu124/slstm_HS1024BS8NH4NS4DBfDRbDWbDGbDSbDAfNG4SA1GRCV0GRC0d0FCV0FC0d0/build.ninja...\n","Building extension module slstm_HS1024BS8NH4NS4DBfDRbDWbDGbDSbDAfNG4SA1GRCV0GRC0d0FCV0FC0d0...\n","Allowing ninja to set a default number of workers... 
(overridable by setting the environment variable MAX_JOBS=N)\n","Loading extension module slstm_HS1024BS8NH4NS4DBfDRbDWbDGbDSbDAfNG4SA1GRCV0GRC0d0FCV0FC0d0...\n"]},{"name":"stdout","output_type":"stream","text":["ninja: no work to do.\n"]},{"data":{"text/plain":["True"]},"execution_count":10,"metadata":{},"output_type":"execute_result"}],"source":["from omegaconf import OmegaConf\n","from dacite import from_dict\n","from dacite import Config as DaciteConfig\n","from xlstm import xLSTMLMModel, xLSTMLMModelConfig\n","\n","xlstm_cfg = \"\"\" \n","vocab_size: 50304\n","mlstm_block:\n","  mlstm:\n","    conv1d_kernel_size: 4\n","    qkv_proj_blocksize: 4\n","    num_heads: 4\n","slstm_block:\n","  slstm:\n","    backend: cuda\n","    num_heads: 4\n","    conv1d_kernel_size: 4\n","    bias_init: powerlaw_blockdependent\n","  feedforward:\n","    proj_factor: 1.3\n","    act_fn: gelu\n","context_length: 1024\n","num_blocks: 16\n","embedding_dim: 1024\n","slstm_at: [1]\n","\"\"\"\n","cfg = OmegaConf.create(xlstm_cfg)\n","cfg = from_dict(data_class=xLSTMLMModelConfig, data=OmegaConf.to_container(cfg), config=DaciteConfig(strict=True))\n","xlstm_stack = xLSTMLMModel(cfg)\n","\n","x = torch.randint(0, 50304, size=(4, 256)).to(\"cuda\")\n","xlstm_stack = xlstm_stack.to(\"cuda\")\n","y = xlstm_stack(x)\n","y.shape[1:] == (256, 50304)\n","# model = xlstm_stack.lm_head"]},{"cell_type":"code","execution_count":11,"metadata":{},"outputs":[{"data":{"text/plain":["xLSTMLMModel(\n","  (xlstm_block_stack): xLSTMBlockStack(\n","    (blocks): ModuleList(\n","      (0): mLSTMBlock(\n","        (xlstm_norm): LayerNorm()\n","        (xlstm): mLSTMLayer(\n","          (proj_up): Linear(in_features=1024, out_features=4096, bias=False)\n","          (q_proj): LinearHeadwiseExpand(in_features=2048, num_heads=512, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n","          (k_proj): LinearHeadwiseExpand(in_features=2048, num_heads=512, expand_factor_up=1, 
bias=False, trainable_weight=True, trainable_bias=True, )\n","          (v_proj): LinearHeadwiseExpand(in_features=2048, num_heads=512, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n","          (conv1d): CausalConv1d(\n","            (conv): Conv1d(2048, 2048, kernel_size=(4,), stride=(1,), padding=(3,), groups=2048)\n","          )\n","          (conv_act_fn): SiLU()\n","          (mlstm_cell): mLSTMCell(\n","            (igate): Linear(in_features=6144, out_features=4, bias=True)\n","            (fgate): Linear(in_features=6144, out_features=4, bias=True)\n","            (outnorm): MultiHeadLayerNorm()\n","          )\n","          (ogate_act_fn): SiLU()\n","          (proj_down): Linear(in_features=2048, out_features=1024, bias=False)\n","          (dropout): Dropout(p=0.0, inplace=False)\n","        )\n","      )\n","      (1): sLSTMBlock(\n","        (xlstm_norm): LayerNorm()\n","        (xlstm): sLSTMLayer(\n","          (conv1d): CausalConv1d(\n","            (conv): Conv1d(1024, 1024, kernel_size=(4,), stride=(1,), padding=(3,), groups=1024)\n","          )\n","          (conv_act_fn): SiLU()\n","          (fgate): LinearHeadwiseExpand(in_features=1024, num_heads=4, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n","          (igate): LinearHeadwiseExpand(in_features=1024, num_heads=4, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n","          (zgate): LinearHeadwiseExpand(in_features=1024, num_heads=4, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n","          (ogate): LinearHeadwiseExpand(in_features=1024, num_heads=4, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n","          (slstm_cell): sLSTMCell_cuda(function=slstm, hidden_size=1024, num_heads=4)\n","          (group_norm): MultiHeadLayerNorm()\n","          (dropout): Dropout(p=0.0, inplace=False)\n","        )\n","        (ffn_norm): 
LayerNorm()\n","        (ffn): GatedFeedForward(\n","          (proj_up): Linear(in_features=1024, out_features=2688, bias=False)\n","          (proj_down): Linear(in_features=1344, out_features=1024, bias=False)\n","          (dropout): Dropout(p=0.0, inplace=False)\n","        )\n","      )\n","      (2-15): 14 x mLSTMBlock(\n","        (xlstm_norm): LayerNorm()\n","        (xlstm): mLSTMLayer(\n","          (proj_up): Linear(in_features=1024, out_features=4096, bias=False)\n","          (q_proj): LinearHeadwiseExpand(in_features=2048, num_heads=512, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n","          (k_proj): LinearHeadwiseExpand(in_features=2048, num_heads=512, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n","          (v_proj): LinearHeadwiseExpand(in_features=2048, num_heads=512, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n","          (conv1d): CausalConv1d(\n","            (conv): Conv1d(2048, 2048, kernel_size=(4,), stride=(1,), padding=(3,), groups=2048)\n","          )\n","          (conv_act_fn): SiLU()\n","          (mlstm_cell): mLSTMCell(\n","            (igate): Linear(in_features=6144, out_features=4, bias=True)\n","            (fgate): Linear(in_features=6144, out_features=4, bias=True)\n","            (outnorm): MultiHeadLayerNorm()\n","          )\n","          (ogate_act_fn): SiLU()\n","          (proj_down): Linear(in_features=2048, out_features=1024, bias=False)\n","          (dropout): Dropout(p=0.0, inplace=False)\n","        )\n","      )\n","    )\n","    (post_blocks_norm): LayerNorm()\n","  )\n","  (token_embedding): Embedding(50304, 1024)\n","  (emb_dropout): Identity()\n","  (lm_head): Linear(in_features=1024, out_features=2, bias=True)\n",")"]},"execution_count":11,"metadata":{},"output_type":"execute_result"}],"source":["xlstm_stack.lm_head = nn.Linear(xlstm_stack.lm_head.in_features, 2)\n","model = 
xlstm_stack\n","model.cuda()"]},{"cell_type":"code","execution_count":12,"metadata":{},"outputs":[],"source":["from transformers import DataCollatorWithPadding, AutoConfig\n","\n","# Dataset and Tokenizer Setup\n","data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"]},{"cell_type":"code","execution_count":13,"metadata":{},"outputs":[],"source":["# import torch\n","# import numpy as np\n","# import wandb  # Weights & Biases integration\n","# from torch import nn\n","# import torch\n","# from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score\n","# from typing import Dict, Union\n","# import torch\n","# from transformers import (\n","#     DataCollatorWithPadding, \n","#     AdamW, \n","#     Trainer, \n","#     TrainingArguments,\n","#     get_cosine_schedule_with_warmup,\n","#     TrainerCallback\n","# )\n","# from torch.utils.data import DataLoader\n","# import os\n","# from huggingface_hub import login  # For pushing to the Hugging Face Hub\n","\n","# # Authenticate Hugging Face API token\n","# # Make sure you've logged in before running the script\n","# login(token=os.environ[\"HF_TOKEN\"])  # read the token from the environment; never hardcode it\n","# # Initialize wandb run\n","# wandb.init(project=\"Detect AI Generated Text\", \n","#            name=\"xLSTM-base\",\n","#            config={\n","#                \"learning_rate\": 1e-4,\n","#                \"label_smoothing\": 0.03,\n","#                \"batch_size\": 8,\n","#                \"num_epochs\": 1,\n","#                \"optimizer\": \"AdamW\",\n","#                \"model\": 'xLSTM',\n","#                \"model_params\": sum(p.numel() for p in xlstm_stack.parameters() if p.requires_grad)\n","#            })\n","\n","    \n","# # Access the configuration\n","# config = wandb.config\n","\n","# # Now you can call config values like this\n","# learning_rate = config.learning_rate\n","# label_smoothing = config.label_smoothing\n","# batch_size = config.batch_size\n","\n","# # Data Collator Setup\n","# 
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n","\n","# # Dataloader Setup\n","# train_data_loader = DataLoader(\n","#     train_dataset, \n","#     batch_size=batch_size,  # Increased batch size since it will be split across GPUs\n","#     num_workers=4, \n","#     shuffle=True, \n","#     pin_memory=True, \n","#     collate_fn=data_collator\n","# )\n","\n","# test_data_loader = DataLoader(\n","#     test_dataset, \n","#     batch_size=2,  # Increased batch size\n","#     num_workers=4, \n","#     shuffle=False, \n","#     pin_memory=True, \n","#     collate_fn=data_collator\n","# )\n","\n","# # Optimizer Setup\n","# optimizer = AdamW(\n","#     xlstm_stack.parameters(),\n","#     lr=learning_rate,  # Define your learning_rate\n","#     weight_decay=0.1\n","# )\n","\n","# # Scheduler Setup (Cosine Annealing)\n","# total_train_steps = len(train_data_loader) * 1  # Adjust based on your epochs\n","# lr_scheduler = get_cosine_schedule_with_warmup(\n","#     optimizer,\n","#     num_warmup_steps=500,  # Can adjust based on needs\n","#     num_training_steps=total_train_steps\n","# )\n","\n","\n","# def compute_metrics(eval_pred):\n","#     \"\"\"\n","#     Compute metrics for Hugging Face Trainer, including AUROC.\n","\n","#     Args:\n","#         eval_pred: tuple of (predictions, labels) where predictions are logits.\n","\n","#     Returns:\n","#         dictionary containing the computed metrics, including AUROC.\n","#     \"\"\"\n","#     # Unpack predictions and labels\n","#     logits, labels = eval_pred\n","#     preds = logits.argmax(-1)  # Get the predicted class\n","\n","#     # Calculate accuracy\n","#     accuracy = accuracy_score(labels, preds)\n","\n","#     # Calculate precision, recall, and F1-score\n","#     precision = precision_score(labels, preds, average='weighted')\n","#     recall = recall_score(labels, preds, average='weighted')\n","#     f1 = f1_score(labels, preds, average='weighted')\n","\n","#     # Calculate probabilities 
using softmax on logits (not on preds)\n","#     probs = torch.softmax(torch.tensor(logits), dim=-1).numpy()\n","#     # For binary classification, take the probability of the positive class (class 1)\n","#     auroc = roc_auc_score(labels, probs[:, 1])\n","\n","\n","#     return {\n","#         'accuracy': accuracy,\n","#         'precision': precision,\n","#         'recall': recall,\n","#         'f1': f1,\n","#         'auroc': auroc\n","#     }\n","\n","\n","\n","# # Training Arguments Setup\n","# training_args = TrainingArguments(\n","#     output_dir=\"./results\",  # Directory to save model checkpoints\n","#     evaluation_strategy=\"steps\",  # Evaluate every few steps\n","#     # eval_steps=,  # Evaluate every 1000 steps\n","#     per_device_train_batch_size=batch_size,  # Adjust batch size per device (GPU)\n","#     per_device_eval_batch_size=4,  # Same for evaluation\n","#     num_train_epochs=1,  # Define total number of epochs\n","#     weight_decay=0.1,  # L2 regularization\n","#     logging_dir=\"./logs\",  # Log directory\n","#     fp16=False,  # Use mixed precision training\n","#     save_steps=2000,  # Save model every 20000 steps\n","#     label_smoothing_factor=0.03,\n","#     hub_model_id=\"xLSTM-4-1\",  # Set model name for HF Hub\n","#     push_to_hub=True,  # Push to Hugging Face Hub\n","#     save_total_limit=2,  # Only keep the last 2 checkpoints,\n","#     metric_for_best_model=\"eval_auroc\",  # Use AUROC to determine best model\n","#     greater_is_better=True,         # Higher AUROC is better\n","#     max_grad_norm=1,\n","#     report_to=\"wandb\",               # Report metrics to Weights & Biases\n","# )\n","\n","# # Initialize the Trainer\n","# trainer = Trainer(\n","#     model=model,\n","#     args=training_args,\n","#     train_dataset=train_dataset,  # Replace with your actual training dataset\n","#     eval_dataset=test_dataset,    # Replace with your actual evaluation dataset\n","#     tokenizer=tokenizer,\n","#     
data_collator=data_collator,\n","#     optimizers=(optimizer, lr_scheduler),  # Pass the optimizer and scheduler\n","#     compute_metrics=compute_metrics  # Optional custom metric computation\n","# )\n","\n","# # Training and evaluation\n","# trainer.train()\n","# trainer.evaluate()\n","\n","# # Push to Hub\n","# trainer.push_to_hub()\n","\n","# # Finish wandb logging\n","# wandb.finish()\n"]},{"cell_type":"code","execution_count":14,"metadata":{"execution":{"iopub.execute_input":"2024-09-19T17:10:28.716075Z","iopub.status.busy":"2024-09-19T17:10:28.715683Z","iopub.status.idle":"2024-09-19T17:10:28.726787Z","shell.execute_reply":"2024-09-19T17:10:28.726000Z","shell.execute_reply.started":"2024-09-19T17:10:28.716039Z"},"trusted":true},"outputs":[],"source":["import torch\n","import numpy as np\n","from tqdm import tqdm\n","from sklearn.metrics import roc_auc_score, accuracy_score\n","\n","# Accuracy Calculation\n","def compute_accuracy(labels, predictions):\n","    \"\"\"Return the fraction of rows whose argmax over `predictions` (dim=1) equals `labels`.\"\"\"\n","    preds = torch.argmax(predictions, dim=1)\n","    correct = torch.sum(preds == labels)\n","    return correct.item() / len(labels)\n","\n",
"def TestModel(test_data_loader, model, criterion, threshold=0.5):\n","    \"\"\"Evaluate `model` on `test_data_loader` and return (accuracy, auroc, mean_loss).\n","\n","    Each batch is scored from the logits at the last sequence position\n","    (`output[:, -1, :]`); the class-1 softmax probability feeds both AUROC and,\n","    after thresholding, accuracy. Gradients are disabled; the caller is\n","    responsible for switching the model between eval() and train().\n","\n","    Args:\n","        test_data_loader: iterable of batches exposing `input_ids` and `labels`.\n","        model: callable mapping token ids to logits whose last dimension is 2.\n","        criterion: loss called as criterion(logits, labels).\n","        threshold: probability above which a row counts as class 1. The default\n","            0.5 equals the argmax rule of compute_accuracy, so validation and\n","            training accuracy are comparable (this used to be hard-coded to 0.7).\n","\n","    Returns:\n","        (accuracy, auroc, mean_loss)\n","\n","    Raises:\n","        ValueError: if the loader yields no non-empty batch.\n","    \"\"\"\n","    test_losses = []\n","    all_predictions = []\n","    all_actual_values = []\n","\n","    with torch.no_grad():\n","        for batch in tqdm(test_data_loader):\n","            if len(batch.input_ids) == 0:\n","                # Safeguard against empty sequences.\n","                continue\n","\n","            # Have shape (batch size, token count)\n","            token_sequences = batch.input_ids.cuda()\n","            # Has shape (batch size)\n","            labels = batch.labels.cuda()\n","\n","            with torch.autocast(device_type='cuda'):\n","                output = model(token_sequences)\n","            # NOTE(review): the attention mask is not passed to the model and logits are\n","            # read at the last position; with right padding that position is a pad token\n","            # for every row shorter than the batch maximum. Confirm the padding side.\n","            # Upcast so the loss and probabilities are computed in fp32 even if autocast\n","            # produced half-precision logits (fp16 probabilities near 1 collapse into ties).\n","            raw_predictions = output[:, -1, :].float()\n","\n","            loss = criterion(raw_predictions.view(-1, 2), labels.view(-1))\n","            test_losses.append(loss.item())\n","\n","            scaled_predictions = raw_predictions.softmax(dim=1)[:, 1]\n","            all_predictions.extend(scaled_predictions.cpu().numpy())\n","            all_actual_values.extend(labels.cpu().numpy())\n","\n","    if not all_predictions:\n","        raise ValueError('TestModel: the data loader yielded no non-empty batches')\n","\n","    all_predictions, all_actual_values = np.array(all_predictions), np.array(all_actual_values)\n","\n","    auroc = roc_auc_score(all_actual_values, all_predictions)\n","\n","    # Binarize predictions and compute accuracy\n","    binary_predictions = (all_predictions > threshold).astype(int)\n","    accuracy = accuracy_score(all_actual_values, binary_predictions)\n","\n","    return accuracy, auroc, np.mean(test_losses)\n"]},
{"cell_type":"code","execution_count":15,"metadata":{"collapsed":true,"execution":{"iopub.execute_input":"2024-09-19T17:10:28.728787Z","iopub.status.busy":"2024-09-19T17:10:28.728385Z","iopub.status.idle":"2024-09-20T03:37:05.179864Z","shell.execute_reply":"2024-09-20T03:37:05.178362Z","shell.execute_reply.started":"2024-09-19T17:10:28.728728Z"},"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[],"source":["import wandb\n","import torch\n","import torch.nn as nn\n","from torch.utils.data import DataLoader\n","from tqdm import tqdm\n","# torch.optim.AdamW replaces transformers.AdamW, which is deprecated upstream\n","from torch.optim import AdamW\n","from transformers import Adafactor\n","\n","\n","\n","# Variables for the experiment\n","label_smoothing = 0.03\n","output_subdir = '3090_1'\n","max_learning_rates = [3e-4]\n","batch_size = 8\n","\n","\n","# Initialize Weights & Biases\n","wandb.init(project=\"your_project_name\", name=\"experiment_3090_1\", config={\n","    \"label_smoothing\": label_smoothing,\n","    \"output_subdir\": output_subdir,\n","    \"max_learning_rate\": max_learning_rates[0],\n","    \"batch_size\": batch_size\n","})\n","\n","\n","\n","# Run experiment\n","# NOTE(review): `model` is not re-initialised per learning rate, so with more than one\n","# entry in max_learning_rates each run continues from the previous run's weights.\n","for max_learning_rate in max_learning_rates:\n","    print(f'lr = {max_learning_rate}, label_smoothing = {label_smoothing}, output_subdir = {output_subdir}')\n","\n","    # Dataloader Setup\n","    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n","    train_data_loader = DataLoader(\n","        train_dataset,\n","        batch_size=batch_size,\n","        num_workers=4,\n","        shuffle=True,\n","        pin_memory=True,\n","        collate_fn=data_collator\n","    )\n","    test_data_loader = DataLoader(\n","        test_dataset,\n","        batch_size=batch_size,\n","        num_workers=4,\n","        shuffle=False,\n","        pin_memory=True,\n","        collate_fn=data_collator\n","    )\n","\n","    # Optimizer, Criterion, and Scaler Setup\n","    optimizer = AdamW(\n","        model.parameters(),\n","        lr=max_learning_rate,\n","        weight_decay=0.1,\n","        eps=1e-6  # matches the eps default of the transformers.AdamW used previously\n","    )\n","    criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)\n","    scaler = torch.cuda.amp.GradScaler(enabled=True)\n","\n","    
total_step_count = len(train_data_loader)\n","    lr_schedule = torch.optim.lr_scheduler.OneCycleLR(\n","        optimizer=optimizer,\n","        max_lr=max_learning_rate,\n","        total_steps=total_step_count,\n","        pct_start=0.1,\n","        anneal_strategy='linear',\n","        cycle_momentum=False\n","    )\n","\n","    best_auroc = -99999999\n","    train_losses = []\n","    # Per-batch accuracies of the current logging window, averaged alongside\n","    # train_losses so the logged accuracy covers the same batches as the loss.\n","    train_accuracies = []\n","    model.train()\n","\n","    # Tracking the number of rows processed\n","    total_rows_processed = 0\n","    row_threshold = 50000\n","\n","    print_steps = 500\n","\n","    # torch.save does not create directories; create the checkpoint folder up front\n","    # so the first save, possibly hours into the run, cannot fail on a missing path.\n","    import os\n","    os.makedirs('./weight_models', exist_ok=True)\n","\n","    for batch_index, train_batch in enumerate(tqdm(train_data_loader)):\n","        if len(train_batch.input_ids) == 0:\n","            continue\n","\n","        # Send data to GPU(s)\n","        token_sequences = train_batch.input_ids.to(\"cuda\")\n","        labels = train_batch.labels.to(\"cuda\")\n","\n","        optimizer.zero_grad()\n","\n","        with torch.autocast(device_type=\"cuda\"):\n","            output = model(token_sequences)\n","            # NOTE(review): the attention mask is not passed to the model and logits are\n","            # read at the last position; with right padding that position is a pad token\n","            # for every row shorter than the batch maximum. Confirm the padding side.\n","            raw_predictions = output[:, -1, :]\n","\n","            loss = criterion(raw_predictions.view(-1, 2), labels.view(-1))\n","\n","        # Training accuracy of this batch\n","        train_accuracies.append(compute_accuracy(labels.view(-1), raw_predictions.view(-1, 2)))\n","\n","        scaler.scale(loss).backward()\n","        scaler.unscale_(optimizer)\n","        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n","        scaler.step(optimizer)\n","        scaler.update()\n","        lr_schedule.step()\n","\n","        train_losses.append(loss.item())\n","\n","        # Every 500 steps, log loss and accuracy averaged over the same window\n","        if (batch_index + 1) % print_steps == 0 and train_losses:\n","            avg_train_loss = sum(train_losses) / len(train_losses)\n","            accuracy = sum(train_accuracies) / len(train_accuracies)\n","            print(f\"Step {batch_index+1}/{total_step_count}: Avg Train Loss = {avg_train_loss:.4f}, Train Accuracy = {accuracy*100:.2f}%\")\n","\n",
"            # Log to Weights & Biases\n","            wandb.log({\"train_loss\": avg_train_loss, \"train_accuracy\": accuracy * 100})\n","\n","            train_losses = []  # Reset train loss tracking for the next 500 steps\n","            train_accuracies = []\n","\n","        # Increment the number of rows processed\n","        total_rows_processed += len(train_batch.input_ids)\n","\n","        # Evaluate the model every 50,000 rows\n","        if total_rows_processed >= row_threshold:\n","            model.eval()\n","            val_accuracy, auroc, test_loss = TestModel(test_data_loader, model, criterion)\n","            model.train()\n","\n","            print(f'Validation Loss: {test_loss:.4f}, Validation Accuracy: {val_accuracy*100:.2f}%, AuROC Score: {auroc*100:.2f}%')\n","\n","            # Log validation metrics to Weights & Biases\n","            wandb.log({\"val_loss\": test_loss, \"val_accuracy\": val_accuracy * 100, \"auroc\": auroc * 100})\n","\n","            total_rows_processed = 0  # Reset after each evaluation\n","\n","            # Save model if improved\n","            if auroc > best_auroc:\n","                best_auroc = auroc\n","                torch.save(model.state_dict(), f'./weight_models/xLSTM-Base-Val_Accuracy-{val_accuracy*100}%-AuROC_Score-{auroc*100}-Loss-{int(test_loss*1000)}.pth')\n","\n","    print(f'Training Finish !!!')\n","\n","# Finish the W&B run\n","# wandb.finish()\n"]},{"cell_type":"code","execution_count":14,"metadata":{"execution":{"iopub.execute_input":"2024-09-20T04:13:25.274416Z","iopub.status.busy":"2024-09-20T04:13:25.274045Z","iopub.status.idle":"2024-09-20T04:13:30.882150Z","shell.execute_reply":"2024-09-20T04:13:30.880422Z","shell.execute_reply.started":"2024-09-20T04:13:25.274380Z"},"trusted":true},"outputs":[],"source":["torch.save(model.state_dict(), 
f'./Models/Mamba-780m-Step-{batch_index+1}-Loss-{int(test_loss*1000)}.pth')\n"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.status.busy":"2024-09-20T03:37:05.183718Z","iopub.status.idle":"2024-09-20T03:37:05.184069Z","shell.execute_reply":"2024-09-20T03:37:05.183920Z","shell.execute_reply.started":"2024-09-20T03:37:05.183903Z"},"trusted":true},"outputs":[],"source":["import torch\n","import numpy as np\n","from tqdm import tqdm\n","from sklearn.metrics import roc_auc_score\n","\n","# Re-evaluate the current weights. TestModel returns (accuracy, auroc, mean_loss);\n","# unpacking it into two names raised ValueError.\n","model.eval()\n","try:\n","    val_accuracy, auroc, test_loss = TestModel(test_data_loader, model, criterion)\n","finally:\n","    # Restore training mode even if evaluation fails part-way\n","    model.train()\n","\n","val_accuracy, auroc, test_loss"]},{"cell_type":"markdown","metadata":{},"source":["### <b><span style='color:#F1A424'>Confusion Matrix</span></b>\n"]},
{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.status.busy":"2024-09-20T03:37:05.185888Z","iopub.status.idle":"2024-09-20T03:37:05.186360Z","shell.execute_reply":"2024-09-20T03:37:05.186134Z","shell.execute_reply.started":"2024-09-20T03:37:05.186109Z"},"trusted":true},"outputs":[],"source":["import matplotlib.pyplot as plt\n","import numpy as np\n","import pandas as pd\n","import seaborn as sns\n","\n","from sklearn.metrics import confusion_matrix\n","\n","def binarize(x, threshold):\n","    \"\"\"Return 1 if x is strictly greater than threshold, otherwise 0.\"\"\"\n","    return 1 if x > threshold else 0\n","\n","# Probability cut-off that turns predicted scores into hard labels\n","BINARY_THRESHOLD = 0.5\n","\n","# NOTE(review): confirm oof_df (with \"preds\" and \"generated\" columns) is created in an\n","# earlier cell; TestModel only returns aggregate metrics, not per-row predictions.\n","oof_df[\"binary\"] = oof_df[\"preds\"].apply(lambda x: binarize(x, BINARY_THRESHOLD))\n","true_labels = oof_df[\"generated\"].values\n","predicted_labels = oof_df[\"binary\"].values\n","\n","# Get the unique classes from both true and predicted labels\n","classes = np.unique(np.concatenate((true_labels, predicted_labels)))\n","\n","# Compute the confusion matrix\n","cm = confusion_matrix(true_labels, predicted_labels, labels=classes)\n","fig, ax = plt.subplots(figsize=(8, 6))\n","sns.heatmap(cm, annot=True, fmt=\"d\", cmap=\"Blues\", xticklabels=classes, yticklabels=classes, ax=ax)\n","ax.set(xlabel=\"Predicted Labels\", ylabel=\"True Labels\", title=\"Confusion Matrix\")\n","plt.show()"]}],"metadata":{"kaggle":{"accelerator":"nvidiaTeslaT4","dataSources":[{"databundleVersionId":7516023,"sourceId":61542,"sourceType":"competition"},{"datasetId":3936750,"sourceId":6847931,"sourceType":"datasetVersion"},{"datasetId":4325258,"sourceId":7432540,"sourceType":"datasetVersion"},{"datasetId":4336615,"sourceId":7452416,"sourceType":"datasetVersion"}],"dockerImageVersionId":30762,"isGpuEnabled":true,"isInternetEnabled":true,"language":"python","sourceType":"notebook"},"kernelspec":{"display_name":"binh_mamba","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.12.5"}},"nbformat":4,"nbformat_minor":4}