{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "c2ed359a", "metadata": {}, "outputs": [], "source": [ "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 2, "id": "2d441603", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
textIDtextselected_textsentimentTime of TweetAge of UserCountryPopulation -2020Land Area (Km²)Density (P/Km²)
0cb774db0d1I`d have responded, if I were goingI`d have responded, if I were goingneutralmorning0-20Afghanistan38928346652860.060
1549e992a42Sooo SAD I will miss you here in San Diego!!!Sooo SADnegativenoon21-30Albania287779727400.0105
2088c60f138my boss is bullying me...bullying menegativenight31-45Algeria438510442381740.018
39642c003efwhat interview! leave me aloneleave me alonenegativemorning46-60Andorra77265470.0164
4358bd9e861Sons of ****, why couldn`t they put them on t...Sons of ****,negativenoon60-70Angola328662721246700.026
.................................
274764eac33d1c0wish we could come see u on Denver husband l...d lostnegativenight31-45Ghana31072940227540.0137
274774f4c4fc327I`ve wondered about rake to. The client has ..., don`t forcenegativemorning46-60Greece10423054128900.081
27478f67aae2310Yay good for both of you. Enjoy the break - y...Yay good for both of you.positivenoon60-70Grenada112523340.0331
27479ed167662a5But it was worth it ****.But it was worth it ****.positivenight70-100Guatemala17915568107160.0167
274806f7127d9d7All this flirting going on - The ATG smiles...All this flirting going on - The ATG smiles. Y...neutralmorning0-20Guinea13132795246000.053
\n", "

27481 rows × 10 columns

\n", "
" ], "text/plain": [ " textID text \\\n", "0 cb774db0d1 I`d have responded, if I were going \n", "1 549e992a42 Sooo SAD I will miss you here in San Diego!!! \n", "2 088c60f138 my boss is bullying me... \n", "3 9642c003ef what interview! leave me alone \n", "4 358bd9e861 Sons of ****, why couldn`t they put them on t... \n", "... ... ... \n", "27476 4eac33d1c0 wish we could come see u on Denver husband l... \n", "27477 4f4c4fc327 I`ve wondered about rake to. The client has ... \n", "27478 f67aae2310 Yay good for both of you. Enjoy the break - y... \n", "27479 ed167662a5 But it was worth it ****. \n", "27480 6f7127d9d7 All this flirting going on - The ATG smiles... \n", "\n", " selected_text sentiment \\\n", "0 I`d have responded, if I were going neutral \n", "1 Sooo SAD negative \n", "2 bullying me negative \n", "3 leave me alone negative \n", "4 Sons of ****, negative \n", "... ... ... \n", "27476 d lost negative \n", "27477 , don`t force negative \n", "27478 Yay good for both of you. positive \n", "27479 But it was worth it ****. positive \n", "27480 All this flirting going on - The ATG smiles. Y... neutral \n", "\n", " Time of Tweet Age of User Country Population -2020 \\\n", "0 morning 0-20 Afghanistan 38928346 \n", "1 noon 21-30 Albania 2877797 \n", "2 night 31-45 Algeria 43851044 \n", "3 morning 46-60 Andorra 77265 \n", "4 noon 60-70 Angola 32866272 \n", "... ... ... ... ... \n", "27476 night 31-45 Ghana 31072940 \n", "27477 morning 46-60 Greece 10423054 \n", "27478 noon 60-70 Grenada 112523 \n", "27479 night 70-100 Guatemala 17915568 \n", "27480 morning 0-20 Guinea 13132795 \n", "\n", " Land Area (Km²) Density (P/Km²) \n", "0 652860.0 60 \n", "1 27400.0 105 \n", "2 2381740.0 18 \n", "3 470.0 164 \n", "4 1246700.0 26 \n", "... ... ... \n", "27476 227540.0 137 \n", "27477 128900.0 81 \n", "27478 340.0 331 \n", "27479 107160.0 167 \n", "27480 246000.0 53 \n", "\n", "[27481 rows x 10 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df=pd.read_csv('train.csv',encoding='unicode_escape')\n", "df" ] }, { "cell_type": "code", "execution_count": 3, "id": "60b7c4de", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
textsentiment
0I`d have responded, if I were goingneutral
1Sooo SAD I will miss you here in San Diego!!!negative
2my boss is bullying me...negative
3what interview! leave me alonenegative
4Sons of ****, why couldn`t they put them on t...negative
.........
27476wish we could come see u on Denver husband l...negative
27477I`ve wondered about rake to. The client has ...negative
27478Yay good for both of you. Enjoy the break - y...positive
27479But it was worth it ****.positive
27480All this flirting going on - The ATG smiles...neutral
\n", "

27481 rows × 2 columns

\n", "
" ], "text/plain": [ " text sentiment\n", "0 I`d have responded, if I were going neutral\n", "1 Sooo SAD I will miss you here in San Diego!!! negative\n", "2 my boss is bullying me... negative\n", "3 what interview! leave me alone negative\n", "4 Sons of ****, why couldn`t they put them on t... negative\n", "... ... ...\n", "27476 wish we could come see u on Denver husband l... negative\n", "27477 I`ve wondered about rake to. The client has ... negative\n", "27478 Yay good for both of you. Enjoy the break - y... positive\n", "27479 But it was worth it ****. positive\n", "27480 All this flirting going on - The ATG smiles... neutral\n", "\n", "[27481 rows x 2 columns]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.drop(df.columns[[0,2,4,5,6,7,8,9]], axis=1, inplace=True)\n", "df" ] }, { "cell_type": "code", "execution_count": 4, "id": "296d66e2", "metadata": {}, "outputs": [], "source": [ "labels=df.sentiment.unique()" ] }, { "cell_type": "code", "execution_count": 5, "id": "99085d3e", "metadata": {}, "outputs": [], "source": [ "label_dict={}\n", "for id,label in enumerate(labels):\n", " label_dict[label]=id" ] }, { "cell_type": "code", "execution_count": 6, "id": "fa1e4160", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'neutral': 0, 'negative': 1, 'positive': 2}" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "label_dict" ] }, { "cell_type": "code", "execution_count": 7, "id": "eaa2c872", "metadata": {}, "outputs": [], "source": [ "df['label']=df.sentiment.replace(label_dict)" ] }, { "cell_type": "code", "execution_count": 8, "id": "63ed05e3", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
textsentimentlabel
0I`d have responded, if I were goingneutral0
1Sooo SAD I will miss you here in San Diego!!!negative1
2my boss is bullying me...negative1
3what interview! leave me alonenegative1
4Sons of ****, why couldn`t they put them on t...negative1
............
27476wish we could come see u on Denver husband l...negative1
27477I`ve wondered about rake to. The client has ...negative1
27478Yay good for both of you. Enjoy the break - y...positive2
27479But it was worth it ****.positive2
27480All this flirting going on - The ATG smiles...neutral0
\n", "

27481 rows × 3 columns

\n", "
" ], "text/plain": [ " text sentiment label\n", "0 I`d have responded, if I were going neutral 0\n", "1 Sooo SAD I will miss you here in San Diego!!! negative 1\n", "2 my boss is bullying me... negative 1\n", "3 what interview! leave me alone negative 1\n", "4 Sons of ****, why couldn`t they put them on t... negative 1\n", "... ... ... ...\n", "27476 wish we could come see u on Denver husband l... negative 1\n", "27477 I`ve wondered about rake to. The client has ... negative 1\n", "27478 Yay good for both of you. Enjoy the break - y... positive 2\n", "27479 But it was worth it ****. positive 2\n", "27480 All this flirting going on - The ATG smiles... neutral 0\n", "\n", "[27481 rows x 3 columns]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df" ] }, { "cell_type": "code", "execution_count": 9, "id": "b1dc846d", "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "X_train,X_val,y_train,y_val=train_test_split(df.index.values,df['label'].values,train_size=0.8,random_state=0)" ] }, { "cell_type": "code", "execution_count": 10, "id": "418fb78e", "metadata": {}, "outputs": [], "source": [ "df['data_type'] = 'not_set'\n", "df.loc[X_train, 'data_type'] = 'train'\n", "df.loc[X_val, 'data_type'] = 'val'" ] }, { "cell_type": "code", "execution_count": 11, "id": "ceed0315", "metadata": {}, "outputs": [], "source": [ "df['text']=df['text'].astype(str)" ] }, { "cell_type": "code", "execution_count": 12, "id": "d3c74651", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
textsentimentlabeldata_type
0I`d have responded, if I were goingneutral0train
1Sooo SAD I will miss you here in San Diego!!!negative1train
2my boss is bullying me...negative1train
3what interview! leave me alonenegative1train
4Sons of ****, why couldn`t they put them on t...negative1val
...............
27476wish we could come see u on Denver husband l...negative1train
27477I`ve wondered about rake to. The client has ...negative1train
27478Yay good for both of you. Enjoy the break - y...positive2val
27479But it was worth it ****.positive2val
27480All this flirting going on - The ATG smiles...neutral0val
\n", "

27481 rows × 4 columns

\n", "
" ], "text/plain": [ " text sentiment label \\\n", "0 I`d have responded, if I were going neutral 0 \n", "1 Sooo SAD I will miss you here in San Diego!!! negative 1 \n", "2 my boss is bullying me... negative 1 \n", "3 what interview! leave me alone negative 1 \n", "4 Sons of ****, why couldn`t they put them on t... negative 1 \n", "... ... ... ... \n", "27476 wish we could come see u on Denver husband l... negative 1 \n", "27477 I`ve wondered about rake to. The client has ... negative 1 \n", "27478 Yay good for both of you. Enjoy the break - y... positive 2 \n", "27479 But it was worth it ****. positive 2 \n", "27480 All this flirting going on - The ATG smiles... neutral 0 \n", "\n", " data_type \n", "0 train \n", "1 train \n", "2 train \n", "3 train \n", "4 val \n", "... ... \n", "27476 train \n", "27477 train \n", "27478 val \n", "27479 val \n", "27480 val \n", "\n", "[27481 rows x 4 columns]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df" ] }, { "cell_type": "code", "execution_count": 13, "id": "b018cca8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([' I`d have responded, if I were going',\n", " ' Sooo SAD I will miss you here in San Diego!!!',\n", " 'my boss is bullying me...', ...,\n", " 'So I get up early and I feel good about the day. I walk to work and I`m feeling alright. But guess what... I don`t work today.',\n", " ' wish we could come see u on Denver husband lost his job and can`t afford it',\n", " ' I`ve wondered about rake to. The client has made it clear .NET only, don`t force devs to learn a new lang #agile #ccnet'],\n", " dtype=object)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df[df.data_type=='train'].text.values" ] }, { "cell_type": "code", "execution_count": 14, "id": "0d03c58e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\KARAN\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "from transformers import BertTokenizer\n", "from torch.utils.data import TensorDataset\n", "import torch" ] }, { "cell_type": "code", "execution_count": 15, "id": "1fc7bfd6", "metadata": {}, "outputs": [], "source": [ "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', \n", " do_lower_case=True)" ] }, { "cell_type": "code", "execution_count": 16, "id": "ea3521d9", "metadata": {}, "outputs": [], "source": [ "encoded_data_train = tokenizer.batch_encode_plus(\n", " df[df.data_type=='train'].text.values.tolist(), \n", " add_special_tokens=True, \n", " return_attention_mask=True, \n", " max_length=256,\n", " padding='max_length',\n", " truncation=True,\n", " return_tensors='pt',\n", ")\n", "\n", "encoded_data_val = tokenizer.batch_encode_plus(\n", " df[df.data_type=='val'].text.values.tolist(), \n", " add_special_tokens=True, \n", " return_attention_mask=True, \n", " max_length=256,\n", " truncation=True,\n", " padding='max_length', \n", " return_tensors='pt'\n", ")\n", "\n", "\n", "input_ids_train = encoded_data_train['input_ids']\n", "attention_masks_train = encoded_data_train['attention_mask']\n", "labels_train = torch.tensor(df[df.data_type=='train'].label.values)\n", "\n", "input_ids_val = encoded_data_val['input_ids']\n", "attention_masks_val = encoded_data_val['attention_mask']\n", "labels_val = torch.tensor(df[df.data_type=='val'].label.values)" ] }, { "cell_type": "code", "execution_count": 17, "id": "d56c3636", "metadata": {}, "outputs": [], "source": [ "train_data=TensorDataset(input_ids_train,attention_masks_train,labels_train)\n", "val_data=TensorDataset(input_ids_val,attention_masks_val,labels_val)" ] }, { "cell_type": "code", "execution_count": 18, "id": "c1e6192b", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "from transformers import BertForSequenceClassification\n", "model = BertForSequenceClassification.from_pretrained(\"bert-base-uncased\",\n", " num_labels=len(label_dict),\n", " output_attentions=False,\n", " output_hidden_states=False)" ] }, { "cell_type": "code", "execution_count": 19, "id": "18b4fca0", "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import DataLoader,RandomSampler,SequentialSampler\n", "train_loader=DataLoader(\n", " train_data,\n", " sampler=RandomSampler(train_data),\n", " batch_size=4\n", ")\n", "\n", "val_loader=DataLoader(\n", " val_data,\n", " sampler=SequentialSampler(val_data),\n", " batch_size=4\n", ")" ] }, { "cell_type": "code", "execution_count": 20, "id": "b3f37358", "metadata": {}, "outputs": [], "source": [ "from transformers import get_linear_schedule_with_warmup\n", "from torch.optim import AdamW\n", "\n", "optimizer=AdamW(\n", " model.parameters(),\n", " lr=1e-5,\n", " eps=1e-8\n", ")\n", "\n", "epochs=5\n", "\n", "scheduler=get_linear_schedule_with_warmup(\n", " optimizer,\n", " num_warmup_steps=0,\n", " num_training_steps=len(train_data)*epochs\n", ")" ] }, { "cell_type": "code", "execution_count": 21, "id": "a5ccc6d8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cuda\n" ] } ], "source": [ "device = torch.device('cuda')\n", "model.to(device)\n", "print(device)" ] }, { "cell_type": "code", "execution_count": 22, "id": "1ad2f635", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "def eval(val_loader,model):\n", " model.eval()\n", " loss_val_total=0\n", " preds,true=[],[]\n", " \n", " for batch in val_loader:\n", " batch=tuple(b.to(device) for b in batch)\n", " inputs = {'input_ids': batch[0],\n", " 'attention_mask': batch[1],\n", " 'labels': batch[2],\n", " }\n", " \n", " with torch.no_grad():\n", " outputs=model(**inputs)\n", " \n", " loss=outputs[0]\n", " logits=outputs[1]\n", " loss_val_total+=loss.item()\n", " logits=logits.detach().cpu().numpy()\n", " labels=inputs['labels'].cpu().numpy()\n", " preds.append(logits)\n", " true.append(labels)\n", " \n", " loss_val_avg=loss_val_total/len(val_loader)\n", " predictions=np.concatenate(preds,axis=0)\n", " true_vals=np.concatenate(true,axis=0)\n", " \n", " return loss_val_avg,predictions,true_vals" ] }, { "cell_type": "code", "execution_count": 138, "id": "05f1146d", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 1\n", "Training loss: 1.1087568206265244\n", "Validation loss: 1.1073771828738126\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 2\n", "Training loss: 1.1035335373561803\n", "Validation loss: 1.0943231875246222\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 3\n", "Training loss: 1.0946122174852106\n", "Validation loss: 1.0898548677617854\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 4\n", "Training loss: 1.0907055499164993\n", "Validation loss: 1.0901057242480192\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 5\n", "Training loss: 1.0898382831825786\n", "Validation loss: 1.0943078843030063\n" ] } ], "source": [ "from tqdm import tqdm\n", "for epoch in range(1,epochs+1):\n", " model.train()\n", " loss_train_total=0\n", " progress_bar = tqdm(train_loader,desc='Epoch {:1d}'.format(epoch), leave=False, disable=False)\n", " for batch in progress_bar:\n", "\n", " model.zero_grad()\n", " \n", " batch = tuple(b.to(device) for b in batch)\n", " \n", " inputs = {'input_ids': batch[0],\n", " 'attention_mask': batch[1],\n", " 'labels': batch[2],\n", " } \n", "\n", " outputs = model(**inputs)\n", " \n", " loss = outputs[0]\n", " loss_train_total += loss.item()\n", " loss.backward()\n", "\n", " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n", "\n", " optimizer.step()\n", " scheduler.step()\n", " \n", " progress_bar.set_postfix({'training_loss': '{:.3f}'.format(loss.item()/len(batch))})\n", " torch.save(model.state_dict(), f'finetuned_BERT_epoch_{epoch}.model')\n", " \n", " tqdm.write(f'\\nEpoch {epoch}')\n", " \n", " loss_train_avg = loss_train_total/len(train_loader) \n", " tqdm.write(f'Training loss: {loss_train_avg}')\n", " \n", " val_loss, predictions, true_vals = eval(val_loader)\n", " tqdm.write(f'Validation loss: {val_loss}')" ] }, { "cell_type": "code", "execution_count": 23, "id": "e9f7735a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "data": { "text/plain": [ "BertForSequenceClassification(\n", " (bert): BertModel(\n", " (embeddings): BertEmbeddings(\n", " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n", " (position_embeddings): Embedding(512, 768)\n", " (token_type_embeddings): Embedding(2, 768)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (encoder): BertEncoder(\n", " (layer): ModuleList(\n", " (0-11): 12 x BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " )\n", " (pooler): BertPooler(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (activation): Tanh()\n", " )\n", " )\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (classifier): Linear(in_features=768, out_features=3, bias=True)\n", ")" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model=BertForSequenceClassification.from_pretrained(\n", " 'bert-base-uncased',\n", " num_labels=len(label_dict),\n", " output_attentions=False,\n", " output_hidden_states=False\n", ")\n", " \n", "model.to(device)" ] }, { "cell_type": "code", "execution_count": 40, "id": "2cb2cb43", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.load_state_dict(torch.load('finetuned_BERT_epoch_2.model',map_location=torch.device('cpu')))" ] }, { "cell_type": "code", "execution_count": 41, "id": "86053301", "metadata": {}, "outputs": [], "source": [ "loss,predictions,true_vals=eval(val_loader,model)" ] }, { "cell_type": "code", "execution_count": null, "id": "26089e26", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[-0.94487095, 2.4501007 , -2.4328873 ],\n", " [ 3.0208707 , -1.7925887 , 0.409608 ],\n", " [-1.245395 , 2.8607914 , -2.7080884 ],\n", " ...,\n", " [-0.13207848, -2.0695374 , 3.5249124 ],\n", " [-1.0361273 , -2.475614 , 3.9253955 ],\n", " [-0.3563956 , -2.541143 , 3.703467 ]], dtype=float32)" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictions" ] }, { "cell_type": "code", "execution_count": null, "id": "bf20a5ca", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1, 0, 1, ..., 2, 2, 2], dtype=int64)" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds_flat = np.argmax(predictions, axis=1).flatten()\n", "preds_flat" ] }, { "cell_type": "code", "execution_count": null, "id": "70d73cf6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1, 0, 1, ..., 2, 2, 0], dtype=int64)" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "true_vals" ] }, { "cell_type": "code", "execution_count": null, "id": "f4d78070", "metadata": {}, "outputs": [], "source": [ "def accuracy_per_class(preds, labels):\n", " label_dict_inverse = {v: k for k, v in label_dict.items()}\n", " \n", " preds_flat = np.argmax(preds, axis=1).flatten()\n", " labels_flat = labels.flatten()\n", "\n", " for label in np.unique(labels_flat):\n", " y_preds = preds_flat[labels_flat==label]\n", " y_true = labels_flat[labels_flat==label]\n", " print(f'Class: {label_dict_inverse[label]}')\n", " print(f'Accuracy: {len(y_preds[y_preds==label])}/{len(y_true)}\\n')" ] }, { "cell_type": "code", "execution_count": null, "id": "46eb06a4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Class: neutral\n", "Accuracy: 1571/2195\n", "\n", "Class: negative\n", "Accuracy: 1230/1563\n", "\n", "Class: positive\n", "Accuracy: 1501/1739\n", "\n" ] } ], "source": [ "accuracy_per_class(predictions, true_vals)" ] }, { "cell_type": "code", "execution_count": null, "id": "284950e9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0 0.76 0.72 0.74 2195\n", " 1 0.80 0.79 0.79 1563\n", " 2 0.80 0.86 0.83 1739\n", "\n", " accuracy 0.78 5497\n", " macro avg 0.78 0.79 0.79 5497\n", "weighted avg 0.78 0.78 0.78 5497\n", "\n" ] } ], "source": [ "from sklearn.metrics import classification_report\n", "print(classification_report(true_vals,preds_flat))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" } }, "nbformat": 4, "nbformat_minor": 5 }