AlexandrKovalenko1981 committed
Commit dd7e34c · verified · 1 Parent(s): 2bcea48

Upload fine_tune_loop.py with huggingface_hub
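For reference, a script like this is typically pushed to the Hub with the `upload_file` method of `huggingface_hub`. A minimal sketch, assuming you are already authenticated (e.g. via `huggingface-cli login`); the `repo_id` below is a placeholder, not taken from this commit:

```python
from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="fine_tune_loop.py",   # local file to upload
    path_in_repo="fine_tune_loop.py",      # destination path inside the repo
    repo_id="your-username/your-repo",     # placeholder: substitute the real repo
    commit_message="Upload fine_tune_loop.py with huggingface_hub",
)
```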

Files changed (1):
  1. fine_tune_loop.py  +96 -0
fine_tune_loop.py ADDED
@@ -0,0 +1,96 @@
from datasets import load_dataset
from torch.optim import AdamW  # transformers' AdamW is deprecated in recent versions
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    get_scheduler,
)
import torch
from tqdm.auto import tqdm
import evaluate

# Load the MRPC paraphrase dataset and the tokenizer matching the checkpoint.
raw_datasets = load_dataset('glue', 'mrpc')
checkpoint = 'bert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)


def tokenize_function(example):
    # Tokenize the sentence pair; padding is deferred to the collator.
    return tokenizer(example['sentence1'], example['sentence2'], truncation=True)


# Tokenize, then keep only the columns the model expects.
tokenized_dataset = raw_datasets.map(tokenize_function, batched=True)
tokenized_dataset = tokenized_dataset.remove_columns(['sentence1', 'sentence2', 'idx'])
tokenized_dataset = tokenized_dataset.rename_column('label', 'labels')
tokenized_dataset.set_format('torch')
# print(tokenized_dataset.column_names['train'])  # quick check of the remaining columns

# Pad each batch dynamically to the longest sequence in that batch.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Shuffle the training split; keep the validation split for evaluation.
train_dataloader = DataLoader(
    tokenized_dataset['train'], shuffle=True, batch_size=8, collate_fn=data_collator
)
eval_dataloader = DataLoader(
    tokenized_dataset['validation'], batch_size=8, collate_fn=data_collator
)

# Sanity check (uncomment to inspect one collated batch):
# batch = next(iter(train_dataloader))
# print({k: v.shape for k, v in batch.items()})

model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

optimizer = AdamW(model.parameters(), lr=5e-5)

# Single forward/backward sanity check (uncomment; needs `batch` from above):
# outputs = model(**batch)
# print(outputs.loss, outputs.logits.shape)
# outputs.loss.backward()
# optimizer.step()
# optimizer.zero_grad()

# Linear decay from the initial learning rate down to 0 over all training steps.
num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    'linear',
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

# Use the Apple-silicon GPU (MPS) when available, otherwise fall back to CPU.
device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
model.to(device)
print(f'Using device: {device}')

progress_bar = tqdm(range(num_training_steps))

# Training loop: forward pass, backward pass, optimizer step, scheduler step.
model.train()
for epoch in range(num_epochs):
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

# Evaluate on the validation split with the GLUE MRPC metric.
metric = evaluate.load('glue', 'mrpc')
model.eval()
for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch['labels'])

result = metric.compute()
print(result)
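
Note: for GLUE MRPC, `metric.compute()` returns a dict with `accuracy` and `f1` keys, so the final `print(result)` reports both scores on the validation split.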