File size: 25,267 Bytes
f04e061 | 1 | {"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.12.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"nvidiaTeslaT4","dataSources":[],"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification, Trainer, TrainingArguments\nfrom datasets import load_dataset, concatenate_datasets, Features, ClassLabel, Value, Sequence\n\nds_imdb = load_dataset(\"imdb\")\nds_sst2 = load_dataset(\"glue\", \"sst2\")\n\ntokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\")\n\ncommon_features = Features({\n 'input_ids': Sequence(Value('int32')),\n 'attention_mask': Sequence(Value('int8')),\n 'label': ClassLabel(names=['negative', 'positive'])\n})\n\ndef tokenize_fn(examples):\n return tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True, max_length=128)\n\ndef tokenize_sst2_fn(examples):\n return tokenizer(examples[\"sentence\"], padding=\"max_length\", truncation=True, max_length=128)\n\ntokenized_imdb = ds_imdb.map(tokenize_fn, batched=True)\ntokenized_sst2 = ds_sst2.map(tokenize_sst2_fn, batched=True)\n\ncolumns_to_keep = ['input_ids', 'attention_mask', 'label']\n\ntrain_imdb = tokenized_imdb[\"train\"].remove_columns([c for c in tokenized_imdb[\"train\"].column_names if c not in columns_to_keep])\ntrain_sst2 = tokenized_sst2[\"train\"].remove_columns([c for c in tokenized_sst2[\"train\"].column_names if c not in columns_to_keep])\n\ntest_imdb = tokenized_imdb[\"test\"].remove_columns([c for c in tokenized_imdb[\"test\"].column_names if c not in columns_to_keep])\nval_sst2 = tokenized_sst2[\"validation\"].remove_columns([c for c in tokenized_sst2[\"validation\"].column_names if c not in columns_to_keep])\n\ntrain_imdb = train_imdb.cast(common_features)\ntrain_sst2 = train_sst2.cast(common_features)\ntest_imdb = test_imdb.cast(common_features)\nval_sst2 = val_sst2.cast(common_features)\n\ntrain_mixed = concatenate_datasets([train_imdb, train_sst2]).shuffle(seed=42)\ntest_mixed = concatenate_datasets([test_imdb, val_sst2]).shuffle(seed=42)\n\nprint(f\"β
Datasets loaded successfully!\")\nprint(f\"Number of training samples: {len(train_mixed)}\")","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","trusted":true,"execution":{"iopub.status.busy":"2026-04-05T18:28:22.715564Z","iopub.execute_input":"2026-04-05T18:28:22.715957Z","iopub.status.idle":"2026-04-05T18:28:25.728039Z","shell.execute_reply.started":"2026-04-05T18:28:22.715925Z","shell.execute_reply":"2026-04-05T18:28:25.727226Z"}},"outputs":[{"output_type":"display_data","data":{"text/plain":"Casting the dataset: 0%| | 0/25000 [00:00<?, ? examples/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"19ce5b5196ef4468863f10e8e7d4e61b"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"Casting the dataset: 0%| | 0/67349 [00:00<?, ? examples/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"e62bb1ea5c144ab1b0473a77ae299370"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"Casting the dataset: 0%| | 0/25000 [00:00<?, ? examples/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"1ad819f4820945379fab0f49a27b01a3"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"Casting the dataset: 0%| | 0/872 [00:00<?, ? examples/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"bd95a6f98fcf41ab9afe9ec85186bedc"}},"metadata":{}},{"name":"stdout","text":"β
Datasets loaded successfully!\nNumber of training samples: 92349\n","output_type":"stream"}],"execution_count":2},{"cell_type":"code","source":"config = AutoConfig.from_pretrained(\n \"distilbert-base-uncased\",\n num_labels=2,\n n_layers=4,\n n_heads=8,\n dim=256,\n hidden_dim=1024\n)\n\nmodel = AutoModelForSequenceClassification.from_config(config)\nprint(f\"VibeCheck v1 size: {model.num_parameters():,} parameters\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-05T18:28:29.886911Z","iopub.execute_input":"2026-04-05T18:28:29.887229Z","iopub.status.idle":"2026-04-05T18:28:30.274162Z","shell.execute_reply.started":"2026-04-05T18:28:29.887200Z","shell.execute_reply":"2026-04-05T18:28:30.273528Z"}},"outputs":[{"name":"stdout","text":"VibeCheck v1 size: 11,170,562 parameters\n","output_type":"stream"}],"execution_count":3},{"cell_type":"code","source":"!pip install evaluate","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-05T18:32:43.535586Z","iopub.execute_input":"2026-04-05T18:32:43.535912Z","iopub.status.idle":"2026-04-05T18:32:48.212496Z","shell.execute_reply.started":"2026-04-05T18:32:43.535884Z","shell.execute_reply":"2026-04-05T18:32:48.211839Z"}},"outputs":[{"name":"stdout","text":"Collecting evaluate\n Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)\nRequirement already satisfied: datasets>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from evaluate) (4.8.3)\nRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.12/dist-packages (from evaluate) (2.0.2)\nRequirement already satisfied: dill in /usr/local/lib/python3.12/dist-packages (from evaluate) (0.4.1)\nRequirement already satisfied: pandas in /usr/local/lib/python3.12/dist-packages (from evaluate) (2.3.3)\nRequirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.12/dist-packages (from evaluate) (2.32.4)\nRequirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.12/dist-packages (from evaluate) (4.67.3)\nRequirement already satisfied: xxhash in /usr/local/lib/python3.12/dist-packages (from evaluate) (3.6.0)\nRequirement already satisfied: multiprocess in /usr/local/lib/python3.12/dist-packages (from evaluate) (0.70.16)\nRequirement already satisfied: fsspec>=2021.05.0 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]>=2021.05.0->evaluate) (2026.2.0)\nRequirement already satisfied: huggingface-hub>=0.7.0 in /usr/local/lib/python3.12/dist-packages (from evaluate) (1.4.1)\nRequirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from evaluate) (26.0)\nRequirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from datasets>=2.0.0->evaluate) (3.24.3)\nRequirement already satisfied: pyarrow>=21.0.0 in /usr/local/lib/python3.12/dist-packages (from datasets>=2.0.0->evaluate) (23.0.1)\nRequirement already satisfied: httpx<1.0.0 in /usr/local/lib/python3.12/dist-packages (from datasets>=2.0.0->evaluate) (0.28.1)\nRequirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from datasets>=2.0.0->evaluate) (6.0.3)\nRequirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]>=2021.05.0->evaluate) (3.13.3)\nRequirement already satisfied: hf-xet<2.0.0,>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.7.0->evaluate) (1.3.0)\nRequirement already satisfied: shellingham in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.7.0->evaluate) (1.5.4)\nRequirement already satisfied: typer-slim in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.7.0->evaluate) (0.24.0)\nRequirement already satisfied: typing-extensions>=4.1.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.7.0->evaluate) (4.15.0)\nRequirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->evaluate) (3.4.4)\nRequirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->evaluate) (3.11)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->evaluate) (2.5.0)\nRequirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->evaluate) (2026.1.4)\nRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas->evaluate) (2.9.0.post0)\nRequirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas->evaluate) (2025.2)\nRequirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas->evaluate) (2025.3)\nRequirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (2.6.1)\nRequirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (1.4.0)\nRequirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (25.4.0)\nRequirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (1.8.0)\nRequirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (6.7.1)\nRequirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (0.4.1)\nRequirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (1.22.0)\nRequirement already satisfied: anyio in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->datasets>=2.0.0->evaluate) (4.12.1)\nRequirement already satisfied: httpcore==1.* in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->datasets>=2.0.0->evaluate) (1.0.9)\nRequirement already satisfied: h11>=0.16 in /usr/local/lib/python3.12/dist-packages (from httpcore==1.*->httpx<1.0.0->datasets>=2.0.0->evaluate) (0.16.0)\nRequirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.8.2->pandas->evaluate) (1.17.0)\nRequirement already satisfied: typer>=0.24.0 in /usr/local/lib/python3.12/dist-packages (from typer-slim->huggingface-hub>=0.7.0->evaluate) (0.24.1)\nRequirement already satisfied: click>=8.2.1 in /usr/local/lib/python3.12/dist-packages (from typer>=0.24.0->typer-slim->huggingface-hub>=0.7.0->evaluate) (8.3.1)\nRequirement already satisfied: rich>=12.3.0 in /usr/local/lib/python3.12/dist-packages (from typer>=0.24.0->typer-slim->huggingface-hub>=0.7.0->evaluate) (13.9.4)\nRequirement already satisfied: annotated-doc>=0.0.2 in /usr/local/lib/python3.12/dist-packages (from typer>=0.24.0->typer-slim->huggingface-hub>=0.7.0->evaluate) (0.0.4)\nRequirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.12/dist-packages (from rich>=12.3.0->typer>=0.24.0->typer-slim->huggingface-hub>=0.7.0->evaluate) (4.0.0)\nRequirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.12/dist-packages (from rich>=12.3.0->typer>=0.24.0->typer-slim->huggingface-hub>=0.7.0->evaluate) (2.19.2)\nRequirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.12/dist-packages (from markdown-it-py>=2.2.0->rich>=12.3.0->typer>=0.24.0->typer-slim->huggingface-hub>=0.7.0->evaluate) (0.1.2)\nDownloading evaluate-0.4.6-py3-none-any.whl (84 kB)\n\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m84.1/84.1 kB\u001b[0m \u001b[31m3.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n\u001b[?25hInstalling collected packages: evaluate\nSuccessfully installed evaluate-0.4.6\n","output_type":"stream"}],"execution_count":6},{"cell_type":"code","source":"import numpy as np\nfrom evaluate import load\n\nmetric = load(\"accuracy\")\n\ndef compute_metrics(eval_pred):\n logits, labels = eval_pred\n predictions = np.argmax(logits, axis=-1)\n return metric.compute(predictions=predictions, references=labels)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-05T18:32:50.719412Z","iopub.execute_input":"2026-04-05T18:32:50.720245Z","iopub.status.idle":"2026-04-05T18:32:51.634078Z","shell.execute_reply.started":"2026-04-05T18:32:50.720208Z","shell.execute_reply":"2026-04-05T18:32:51.633544Z"}},"outputs":[{"output_type":"display_data","data":{"text/plain":"Downloading builder script: 0.00B [00:00, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"aa3cd8ad5032426682e4297c0ddfbc69"}},"metadata":{}}],"execution_count":7},{"cell_type":"code","source":"training_args = TrainingArguments(\n output_dir=\"./VibeCheck_v1\",\n eval_strategy=\"epoch\",\n save_strategy=\"epoch\",\n learning_rate=3e-5,\n per_device_train_batch_size=32,\n num_train_epochs=10,\n weight_decay=0.02,\n lr_scheduler_type=\"cosine\",\n warmup_steps=500,\n load_best_model_at_end=True,\n metric_for_best_model=\"accuracy\",\n report_to=\"none\"\n)\n\ntrainer = Trainer(\n model=model,\n args=training_args,\n train_dataset=train_mixed,\n eval_dataset=test_mixed,\n compute_metrics=compute_metrics,\n)\n\ntrainer.train()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-05T18:33:20.417350Z","iopub.execute_input":"2026-04-05T18:33:20.417666Z","iopub.status.idle":"2026-04-05T18:50:52.500464Z","shell.execute_reply.started":"2026-04-05T18:33:20.417637Z","shell.execute_reply":"2026-04-05T18:50:52.499917Z"}},"outputs":[{"name":"stderr","text":"/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py:583: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n return super().apply(*args, **kwargs) # type: ignore[misc]\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"<IPython.core.display.HTML object>","text/html":"\n <div>\n \n <progress value='14430' max='14430' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [14430/14430 17:31, Epoch 10/10]\n </div>\n <table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: left;\">\n <th>Epoch</th>\n <th>Training Loss</th>\n <th>Validation Loss</th>\n <th>Accuracy</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <td>1</td>\n <td>0.591538</td>\n <td>0.957027</td>\n <td>0.785019</td>\n </tr>\n <tr>\n <td>2</td>\n <td>0.551588</td>\n <td>0.878036</td>\n <td>0.802914</td>\n </tr>\n <tr>\n <td>3</td>\n <td>0.496823</td>\n <td>1.004667</td>\n <td>0.782352</td>\n </tr>\n <tr>\n <td>4</td>\n <td>0.460600</td>\n <td>0.911087</td>\n <td>0.793677</td>\n </tr>\n <tr>\n <td>5</td>\n <td>0.411595</td>\n <td>1.006384</td>\n <td>0.791473</td>\n </tr>\n <tr>\n <td>6</td>\n <td>0.388303</td>\n <td>1.062831</td>\n <td>0.785135</td>\n </tr>\n <tr>\n <td>7</td>\n <td>0.351578</td>\n <td>1.188706</td>\n <td>0.783472</td>\n </tr>\n <tr>\n <td>8</td>\n <td>0.331163</td>\n <td>1.257352</td>\n <td>0.780612</td>\n </tr>\n <tr>\n <td>9</td>\n <td>0.310947</td>\n <td>1.282940</td>\n <td>0.781385</td>\n </tr>\n <tr>\n <td>10</td>\n <td>0.316977</td>\n <td>1.294634</td>\n <td>0.780999</td>\n </tr>\n </tbody>\n</table><p>"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"Writing model shards: 0%| | 0/1 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"e7209249ac6d43349cb93118c1ab75eb"}},"metadata":{}},{"name":"stderr","text":"/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py:583: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n return super().apply(*args, **kwargs) # type: ignore[misc]\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"Writing model shards: 0%| | 0/1 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"e6d3eaa6809242f0901db471a0bedf25"}},"metadata":{}},{"name":"stderr","text":"/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py:583: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n return super().apply(*args, **kwargs) # type: ignore[misc]\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"Writing model shards: 0%| | 0/1 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"43c71330e86b4c89959b2435389951ba"}},"metadata":{}},{"name":"stderr","text":"/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py:583: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n return super().apply(*args, **kwargs) # type: ignore[misc]\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"Writing model shards: 0%| | 0/1 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"f55f932ffc494862886840655a0ee431"}},"metadata":{}},{"name":"stderr","text":"/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py:583: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n return super().apply(*args, **kwargs) # type: ignore[misc]\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"Writing model shards: 0%| | 0/1 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"78e811ac42fc4176b580b010bf134987"}},"metadata":{}},{"name":"stderr","text":"/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py:583: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n return super().apply(*args, **kwargs) # type: ignore[misc]\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"Writing model shards: 0%| | 0/1 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"83f31ec6906b45d8900fa7cda16ff6b3"}},"metadata":{}},{"name":"stderr","text":"/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py:583: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n return super().apply(*args, **kwargs) # type: ignore[misc]\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"Writing model shards: 0%| | 0/1 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"0543f65a68484120be30b5107fbb9852"}},"metadata":{}},{"name":"stderr","text":"/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py:583: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n return super().apply(*args, **kwargs) # type: ignore[misc]\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"Writing model shards: 0%| | 0/1 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"ef50172beeae4d138ff1f00bf64b3d30"}},"metadata":{}},{"name":"stderr","text":"/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py:583: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n return super().apply(*args, **kwargs) # type: ignore[misc]\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"Writing model shards: 0%| | 0/1 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"d287f2de5dfa4007b6a61dac98b85681"}},"metadata":{}},{"name":"stderr","text":"/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py:583: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n return super().apply(*args, **kwargs) # type: ignore[misc]\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"Writing model shards: 0%| | 0/1 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"5d3b39c7006b435ebaa76908a6fe236e"}},"metadata":{}},{"execution_count":8,"output_type":"execute_result","data":{"text/plain":"TrainOutput(global_step=14430, training_loss=0.41753544268803056, metrics={'train_runtime': 1051.4727, 'train_samples_per_second': 878.282, 'train_steps_per_second': 13.724, 'total_flos': 2287908560194560.0, 'train_loss': 0.41753544268803056, 'epoch': 10.0})"},"metadata":{}}],"execution_count":8},{"cell_type":"code","source":"import os\nimport shutil\n\n# 1. Save the model and tokenizer\nexport_path = \"./VibeCheck_v1_HF\"\nos.makedirs(export_path, exist_ok=True)\n\ntrainer.save_model(export_path)\ntokenizer.save_pretrained(export_path)\n\n# 2. Create a ZIP for easy download\nshutil.make_archive(\"VibeCheck_v1_Model\", 'zip', export_path)\n\nprint(f\"β
Model saved to {export_path} and zipped as VibeCheck_v1_Model.zip\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-05T18:51:32.850699Z","iopub.execute_input":"2026-04-05T18:51:32.851085Z","iopub.status.idle":"2026-04-05T18:51:35.185954Z","shell.execute_reply.started":"2026-04-05T18:51:32.851053Z","shell.execute_reply":"2026-04-05T18:51:35.185139Z"}},"outputs":[{"output_type":"display_data","data":{"text/plain":"Writing model shards: 0%| | 0/1 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"48fb8c1700c7400e95cbc38cabd40cf0"}},"metadata":{}},{"name":"stdout","text":"β
Model saved to ./VibeCheck_v1_HF and zipped as VibeCheck_v1_Model.zip\n","output_type":"stream"}],"execution_count":9},{"cell_type":"code","source":"import torch\nfrom transformers import AutoTokenizer, AutoModelForSequenceClassification\n\nclass VibeCheckInference:\n def __init__(self, model_path):\n self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n self.tokenizer = AutoTokenizer.from_pretrained(model_path)\n self.model = AutoModelForSequenceClassification.from_pretrained(model_path).to(self.device)\n self.model.eval()\n\n def analyze(self, text):\n inputs = self.tokenizer(\n text, \n return_tensors=\"pt\", \n padding=True, \n truncation=True, \n max_length=128\n ).to(self.device)\n\n with torch.no_grad():\n outputs = self.model(**inputs)\n probs = torch.nn.functional.softmax(outputs.logits, dim=-1)\n conf, pred = torch.max(probs, dim=-1)\n\n result = \"POSITIVE\" if pred.item() == 1 else \"NEGATIVE\"\n return {\n \"text\": text,\n \"label\": result,\n \"confidence\": f\"{conf.item() * 100:.2f}%\"\n }\n\n# Usage\nif __name__ == \"__main__\":\n # Point this to your unzipped folder\n engine = VibeCheckInference(\"./VibeCheck_v1_HF\")\n \n sample = \"Did you see the new movie?' B: 'Yeah, it was okay, but the ending felt a bit rushed.' A: 'I totally agree, it could have been better.'\"\n prediction = engine.analyze(sample)\n print(f\"Result: {prediction['label']} | Confidence: {prediction['confidence']}\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-05T18:54:15.722632Z","iopub.execute_input":"2026-04-05T18:54:15.722968Z","iopub.status.idle":"2026-04-05T18:54:15.960733Z","shell.execute_reply.started":"2026-04-05T18:54:15.722936Z","shell.execute_reply":"2026-04-05T18:54:15.959939Z"}},"outputs":[{"output_type":"display_data","data":{"text/plain":"Loading weights: 0%| | 0/72 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"6ae760c792b043288f8e9b38a190a3dc"}},"metadata":{}},{"name":"stdout","text":"Result: NEGATIVE | Confidence: 80.98%\n","output_type":"stream"}],"execution_count":17}]} |