rishikasrinivas committed
Commit 03012a7 · verified · 1 Parent(s): 2566a7e

Create train_model.py

Files changed (1): train_model.py +171 -0

train_model.py ADDED
@@ -0,0 +1,171 @@
+ from transformers import BertForSequenceClassification, BertTokenizer
+ from transformers import get_linear_schedule_with_warmup
+ import numpy as np
+ import torch
+ from torch.utils.data import TensorDataset, random_split
+ from torch.utils.data import DataLoader, SequentialSampler
+
+ from process_data import getDF
+ from sampler import BalanceSampler
+
+ NUM_CLASSES = 13
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+ # Load the pre-trained model and tokenizer.
+ model = BertForSequenceClassification.from_pretrained("bert-base-uncased").to(device)
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+
+ # Replace the output head so it matches the number of genre classes, and move
+ # the new layer to the same device as the rest of the model. (Passing
+ # num_labels=NUM_CLASSES to from_pretrained would achieve the same thing.)
+ model.classifier = torch.nn.Linear(in_features=768, out_features=NUM_CLASSES).to(device)
+ print(model)
+
+ def get_input_id_and_attention_masks():
+     df = getDF()  # DataFrame built in process_data.py
+
+     input_ids = []
+     attention_masks = []
+     for summ in df['summary']:
+         encoded_dict = tokenizer.encode_plus(
+             summ,                         # Sentence to encode.
+             add_special_tokens=True,      # Add '[CLS]' and '[SEP]'.
+             max_length=512,               # Pad & truncate all sentences.
+             truncation=True,
+             padding='max_length',         # (pad_to_max_length is deprecated.)
+             return_attention_mask=True,   # Construct attention masks.
+             return_tensors='pt',          # Return PyTorch tensors.
+         )
+         input_ids.append(encoded_dict['input_ids'])
+         # The attention mask simply differentiates padding from non-padding.
+         attention_masks.append(encoded_dict['attention_mask'])
+
+     input_ids = torch.cat(input_ids, dim=0)
+     attention_masks = torch.cat(attention_masks, dim=0)
+
+     # genre_id holds multi-hot vectors (see calc_accuracy below);
+     # BCEWithLogitsLoss expects float targets, hence the cast.
+     labels = torch.from_numpy(np.array(df['genre_id'].tolist())).float()
+     return input_ids, attention_masks, labels
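+
+ # Shape note: encode_plus with return_tensors='pt' yields 'input_ids' and
+ # 'attention_mask' tensors of shape (1, 512) per summary, so torch.cat(..., dim=0)
+ # above stacks them into (num_examples, 512) matrices.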
+
+ def createTensorDS(input_ids, attention_masks, labels):
+     return TensorDataset(input_ids, attention_masks, labels)
+
+ def split(tensorDataset):
+     # 85/15 train/validation split.
+     train_size = int(0.85 * len(tensorDataset))
+     val_size = len(tensorDataset) - train_size
+     train_dataset, val_dataset = random_split(tensorDataset, [train_size, val_size])
+     return train_dataset, val_dataset
+
+ def createDataloaders(train_dataset, val_dataset):
+     batch_size = 16
+
+     # BalanceSampler (from sampler.py, not shown in this commit) presumably
+     # rebalances the genre distribution when drawing training batches.
+     train_dataloader = DataLoader(
+         train_dataset,
+         sampler=BalanceSampler(train_dataset),
+         batch_size=batch_size,
+     )
+     valid_dataloader = DataLoader(
+         val_dataset,
+         sampler=SequentialSampler(val_dataset),
+         batch_size=batch_size,
+     )
+     return train_dataloader, valid_dataloader
+
+
+ def calc_accuracy(logits, labels):
+     # Collect the indices of the positive labels in each multi-hot row.
+     label = []
+     num_ones = 0
+     acc = 0
+     for label_set in labels:
+         labs = []
+         for ind, res in enumerate(label_set):
+             if res.item() == 1:
+                 labs.append(ind)
+         label.append(labs)
+         num_ones += len(labs)
+
+     # Count how many true labels appear among each example's five highest
+     # logits, then normalise by the total number of true labels.
+     for i, log in enumerate(logits):
+         top_out = (-log).argsort()[:5]
+         for ind in top_out:
+             if ind.item() in label[i]:
+                 acc = acc + 1
+     return acc / max(num_ones, 1)  # guard against batches with no positives
+
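+ # Illustrative example: if one example's true labels are indices {2, 5} and
+ # index 2 appears in its top-5 logits while index 5 does not, then acc = 1
+ # and num_ones = 2, so calc_accuracy returns 0.5; in effect it measures
+ # recall@5 over the batch.
+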
+ def train(model, train_dataloader, val_dataloader, epochs):
+     total_steps = len(train_dataloader) * epochs
+     optimizer = torch.optim.Adam(model.parameters(),
+                                  lr=2e-5,  # args.learning_rate - default is 5e-5, our notebook had 2e-5
+                                  eps=1e-8  # args.adam_epsilon - default is 1e-8.
+                                  )
+     scheduler = get_linear_schedule_with_warmup(optimizer,
+                                                 num_warmup_steps=0,
+                                                 num_training_steps=total_steps)
+     # One independent sigmoid per genre, since a book can have several genres.
+     loss_fn = torch.nn.BCEWithLogitsLoss()
+     for epoch in range(epochs):
+         total_train_loss = 0
+         acc = 0
+         print("")
+         print('======== Epoch {:} / {:} ========'.format(epoch + 1, epochs))
+         print('Training...')
+         model.train()
+         for step, batch in enumerate(train_dataloader):
+             input_ids = batch[0].to(device)
+             input_mask = batch[1].to(device)
+             labels = batch[2].to(device)
+
+             optimizer.zero_grad()
+             out = model(input_ids, attention_mask=input_mask)
+             logits = out['logits']
+             loss = loss_fn(logits, labels)
+
+             acc += calc_accuracy(logits.detach(), labels)
+             total_train_loss += loss.item()
+
+             loss.backward()
+             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+             optimizer.step()
+             scheduler.step()
+
+         avg_train_loss = total_train_loss / len(train_dataloader)
+         print('train_loss: ', avg_train_loss)
+         print('train_acc: ', acc / len(train_dataloader))
+
+         print("Running Validation...")
+         model.eval()
+         total_eval_accuracy = 0
+         total_eval_loss = 0
+         for batch in val_dataloader:
+             input_ids = batch[0].to(device)
+             input_mask = batch[1].to(device)
+             labels = batch[2].to(device)
+             with torch.no_grad():
+                 out = model(input_ids, attention_mask=input_mask)
+
+             logits = out['logits']
+             loss = loss_fn(logits, labels)
+             total_eval_loss += loss.item()
+             total_eval_accuracy += calc_accuracy(logits, labels)
+
+         avg_loss_Eval = total_eval_loss / len(val_dataloader)
+         avg_acc_Eval = total_eval_accuracy / len(val_dataloader)
+         print(
+             'epoch: ', epoch,
+             'train_loss: ', avg_train_loss,
+             'valid_loss: ', avg_loss_Eval,
+             'valid_acc: ', avg_acc_Eval,
+         )
+
+ input_ids, attention_masks, labels = get_input_id_and_attention_masks()
+ ds = createTensorDS(input_ids, attention_masks, labels)
+ train_dataset, val_dataset = split(ds)
+ train_dataloader, valid_dataloader = createDataloaders(train_dataset, val_dataset)
+ train(model, train_dataloader, valid_dataloader, 3)
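
The script imports getDF from process_data and BalanceSampler from sampler, neither of which is part of this commit. Below is a minimal, hypothetical sketch of the interfaces train_model.py appears to rely on, inferred only from how they are called above (a DataFrame with a 'summary' text column and multi-hot 'genre_id' vectors, plus a sampler that oversamples rare genres); the real process_data.py and sampler.py may differ.

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Sampler

def getDF():
    # Hypothetical stand-in for process_data.getDF(): the real function
    # presumably loads the book dataset. train_model.py only requires a
    # 'summary' text column and a 'genre_id' column of length-13 multi-hot lists.
    rng = np.random.default_rng(0)
    n, num_classes = 32, 13
    return pd.DataFrame({
        'summary': [f"placeholder summary {i}" for i in range(n)],
        'genre_id': [rng.integers(0, 2, size=num_classes).tolist() for _ in range(n)],
    })

class BalanceSampler(Sampler):
    # Hypothetical balancing sampler: weights each example by the rarity of
    # its positive genres and samples with replacement, so rare genres show
    # up more often in training batches. Examples with no positive labels
    # get weight 0 and are never drawn.
    def __init__(self, dataset):
        labels = torch.stack([dataset[i][2] for i in range(len(dataset))])
        class_freq = labels.sum(dim=0).clamp(min=1)       # positives per class
        self.weights = (labels / class_freq).sum(dim=1)   # rarer classes -> larger weight
        self.num_samples = len(dataset)

    def __iter__(self):
        idx = torch.multinomial(self.weights, self.num_samples, replacement=True)
        return iter(idx.tolist())

    def __len__(self):
        return self.num_samples

With these stand-ins in place the script runs end to end on synthetic data, which makes for a convenient smoke test before pointing getDF at the real dataset.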