from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset, concatenate_datasets, Features, ClassLabel, Value, Sequence

# Tunables for the data pipeline, kept in one place instead of repeated inline.
MAX_LENGTH = 128    # every example is padded/truncated to this many tokens
SHUFFLE_SEED = 42   # seed used when shuffling the mixed train / eval sets

ds_imdb = load_dataset("imdb")
ds_sst2 = load_dataset("glue", "sst2")

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Shared schema: both corpora are cast to this before being concatenated.
common_features = Features({
    'input_ids': Sequence(Value('int32')),
    'attention_mask': Sequence(Value('int8')),
    'label': ClassLabel(names=['negative', 'positive'])
})

columns_to_keep = ['input_ids', 'attention_mask', 'label']


def _tokenize_column(examples, text_column):
    """Tokenize ``examples[text_column]`` to fixed-length model inputs.

    Args:
        examples: A batch as passed by ``map(..., batched=True)``.
        text_column: Name of the column holding the raw text.

    Returns:
        The tokenizer output, padded and truncated to ``MAX_LENGTH`` tokens.
    """
    return tokenizer(
        examples[text_column],
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
    )


def tokenize_fn(examples):
    """Tokenize a batch whose text lives in the ``"text"`` column (used for IMDB)."""
    return _tokenize_column(examples, "text")


def tokenize_sst2_fn(examples):
    """Tokenize a batch whose text lives in the ``"sentence"`` column (used for SST-2)."""
    return _tokenize_column(examples, "sentence")


def _prepare_split(split):
    """Drop every column not in ``columns_to_keep`` and cast to ``common_features``.

    Args:
        split: A single tokenized dataset split.

    Returns:
        The split restricted to the kept columns, cast to the shared schema.
    """
    extra_columns = [c for c in split.column_names if c not in columns_to_keep]
    return split.remove_columns(extra_columns).cast(common_features)


# NOTE: ``map`` runs over every split of each DatasetDict, although only
# train/test (IMDB) and train/validation (SST-2) are used below.
tokenized_imdb = ds_imdb.map(tokenize_fn, batched=True)
tokenized_sst2 = ds_sst2.map(tokenize_sst2_fn, batched=True)

train_imdb = _prepare_split(tokenized_imdb["train"])
train_sst2 = _prepare_split(tokenized_sst2["train"])
test_imdb = _prepare_split(tokenized_imdb["test"])
val_sst2 = _prepare_split(tokenized_sst2["validation"])

train_mixed = concatenate_datasets([train_imdb, train_sst2]).shuffle(seed=SHUFFLE_SEED)
test_mixed = concatenate_datasets([test_imdb, val_sst2]).shuffle(seed=SHUFFLE_SEED)

print("✅ Datasets loaded successfully!")
print(f"Number of training samples: {len(train_mixed)}")
fsspec[http]>=2021.05.0->evaluate) (2026.2.0)\nRequirement already satisfied: huggingface-hub>=0.7.0 in /usr/local/lib/python3.12/dist-packages (from evaluate) (1.4.1)\nRequirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from evaluate) (26.0)\nRequirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from datasets>=2.0.0->evaluate) (3.24.3)\nRequirement already satisfied: pyarrow>=21.0.0 in /usr/local/lib/python3.12/dist-packages (from datasets>=2.0.0->evaluate) (23.0.1)\nRequirement already satisfied: httpx<1.0.0 in /usr/local/lib/python3.12/dist-packages (from datasets>=2.0.0->evaluate) (0.28.1)\nRequirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from datasets>=2.0.0->evaluate) (6.0.3)\nRequirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]>=2021.05.0->evaluate) (3.13.3)\nRequirement already satisfied: hf-xet<2.0.0,>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.7.0->evaluate) (1.3.0)\nRequirement already satisfied: shellingham in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.7.0->evaluate) (1.5.4)\nRequirement already satisfied: typer-slim in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.7.0->evaluate) (0.24.0)\nRequirement already satisfied: typing-extensions>=4.1.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.7.0->evaluate) (4.15.0)\nRequirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->evaluate) (3.4.4)\nRequirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->evaluate) (3.11)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->evaluate) (2.5.0)\nRequirement already satisfied: certifi>=2017.4.17 in 
/usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->evaluate) (2026.1.4)\nRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas->evaluate) (2.9.0.post0)\nRequirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas->evaluate) (2025.2)\nRequirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas->evaluate) (2025.3)\nRequirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (2.6.1)\nRequirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (1.4.0)\nRequirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (25.4.0)\nRequirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (1.8.0)\nRequirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (6.7.1)\nRequirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (0.4.1)\nRequirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (1.22.0)\nRequirement already satisfied: anyio in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->datasets>=2.0.0->evaluate) (4.12.1)\nRequirement already satisfied: httpcore==1.* in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->datasets>=2.0.0->evaluate) (1.0.9)\nRequirement already satisfied: h11>=0.16 in 
# --- Metric cell -----------------------------------------------------------
import numpy as np
from evaluate import load

metric = load("accuracy")


def compute_metrics(eval_pred):
    """Compute accuracy for the Trainer's evaluation loop.

    Args:
        eval_pred: A ``(logits, labels)`` pair. ``logits`` may also be a tuple
            whose first element is the logits array.

    Returns:
        Whatever ``metric.compute`` returns for the argmax predictions
        (for the "accuracy" metric, a dict with an ``"accuracy"`` entry).
    """
    logits, labels = eval_pred
    # Some models hand back (logits, extra_outputs...); keep only the logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)


# --- Training cell ---------------------------------------------------------
# Imported here so this cell is self-contained for the callback it adds.
from transformers import EarlyStoppingCallback

training_args = TrainingArguments(
    output_dir="./VibeCheck_v1",
    eval_strategy="epoch",
    save_strategy="epoch",
    # Saving every epoch would otherwise leave one checkpoint per epoch on
    # disk; with load_best_model_at_end the best checkpoint is still retained.
    save_total_limit=2,
    learning_rate=3e-5,
    per_device_train_batch_size=32,
    # Evaluation needs no gradients, so a larger batch than training fits.
    per_device_eval_batch_size=64,
    num_train_epochs=10,
    weight_decay=0.02,
    lr_scheduler_type="cosine",
    warmup_steps=500,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to="none"
)

# NOTE(review): `model` is not defined in this cell; it is expected to come
# from an earlier cell of the notebook.
# NOTE(review): `test_mixed` is used both for per-epoch evaluation and for
# best-checkpoint selection, so the reported accuracy is selected on the test
# data. Consider holding out a validation split from `train_mixed` instead.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_mixed,
    eval_dataset=test_mixed,
    compute_metrics=compute_metrics,
    # In the recorded run, accuracy peaked at epoch 2 while validation loss
    # kept rising through epoch 10; stop once accuracy has not improved for
    # three consecutive evaluations instead of always running all epochs.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

trainer.train()
\n \n \n [14430/14430 17:31, Epoch 10/10]\n
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
EpochTraining LossValidation LossAccuracy
10.5915380.9570270.785019
20.5515880.8780360.802914
30.4968231.0046670.782352
40.4606000.9110870.793677
50.4115951.0063840.791473
60.3883031.0628310.785135
70.3515781.1887060.783472
80.3311631.2573520.780612
90.3109471.2829400.781385
100.3169771.2946340.780999

"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"Writing model shards: 0%| | 0/1 [00:00