Lim committed
Commit 3af5ba5 · 1 Parent(s): 35b5385

add fine tunning

Files changed (1):
  1. fine_tunning.py +192 -0
fine_tunning.py ADDED
@@ -0,0 +1,192 @@
from datasets import load_dataset, DatasetDict, Dataset

from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer,
)

from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig

import torch
import evaluate
import numpy as np

model_checkpoint = "distilbert/distilbert-base-uncased"

# define label maps
id2label = {0: "Negative", 1: "Positive"}
label2id = {"Negative": 0, "Positive": 1}

# load sequence classification model from model_checkpoint
model = AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint,
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)

# load dataset
dataset = load_dataset("shawhin/imdb-truncated")
# dataset = DatasetDict({
#     train: Dataset({
#         features: ['label', 'text'],
#         num_rows: 1000
#     }),
#     validation: Dataset({
#         features: ['label', 'text'],
#         num_rows: 1000
#     })
# })

# create tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)


# create tokenize function
def tokenize_function(examples):
    # extract text
    text = examples["text"]

    # tokenize and truncate text (truncation_side="left" keeps the end of long reviews)
    tokenizer.truncation_side = "left"
    tokenized_inputs = tokenizer(
        text, return_tensors="np", truncation=True, max_length=512
    )

    return tokenized_inputs

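
# illustrative call (assuming the tokenizer above; exact ids depend on the vocab):
# tokenize_function({"text": ["loved it"]}) -> {"input_ids": ..., "attention_mask": ...}
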
# add pad token if none exists
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    model.resize_token_embeddings(len(tokenizer))

# tokenize training and validation datasets
tokenized_dataset = dataset.map(tokenize_function, batched=True)

# create data collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
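# the collator pads each batch to its longest sequence on the fly, which is
# why tokenize_function above truncates but does not pad to max_length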

# import accuracy evaluation metric
accuracy = evaluate.load("accuracy")


# define an evaluation function to pass into trainer later
def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=1)

    # accuracy.compute already returns {"accuracy": value}
    return accuracy.compute(predictions=predictions, references=labels)

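# worked example:
# compute_metrics((np.array([[0.1, 0.9]]), np.array([1]))) == {"accuracy": 1.0}
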
# define list of examples
text_list = [
    "It was good.",
    "Not a fan, don't recommend.",
    "Better than the first one.",
    "This is not worth watching even once.",
    "This one is a pass.",
]

print("Untrained model predictions:")
print("----------------------------")

for text in text_list:
    # tokenize text
    inputs = tokenizer.encode(text, return_tensors="pt")
    # compute logits
    logits = model(inputs).logits
    # convert logits to label
    predictions = torch.argmax(logits)

    print(text + " - " + id2label[predictions.tolist()])

# Output:
# Untrained model predictions:
# ----------------------------
# It was good. - Negative
# Not a fan, don't recommend. - Negative
# Better than the first one. - Negative
# This is not worth watching even once. - Negative
# This one is a pass. - Negative

peft_config = LoraConfig(
    task_type="SEQ_CLS",  # sequence classification
    r=4,  # intrinsic rank of the trainable low-rank matrices
    lora_alpha=32,  # scaling factor for the adapter updates
    lora_dropout=0.01,  # probability of dropout
    target_modules=["q_lin"],  # apply LoRA to the query layers only
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# trainable params: 1,221,124 || all params: 67,584,004 || trainable%: 1.8068239934
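
# rough breakdown of the count above (assuming DistilBERT: 6 layers,
# 768-dim q_lin): each adapted layer adds A (4x768) and B (768x4), i.e.
# 6 * 6,144 = 36,864 LoRA weights; the remainder comes from the
# classification head, which task_type="SEQ_CLS" keeps trainable.
# The adapted forward pass is h = W @ x + (lora_alpha / r) * (B @ A) @ x,
# so adapter updates are scaled by 32 / 4 = 8 while W stays frozen.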

# hyperparameters
lr = 1e-3  # size of optimization step
batch_size = 4  # number of examples processed per optimization step
num_epochs = 10  # number of passes through the training data

# define training arguments
training_args = TrainingArguments(
    output_dir=model_checkpoint + "-lora-text-classification",
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_epochs,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)
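
# with evaluation and saving every epoch, load_best_model_at_end restores the
# checkpoint with the best eval loss (the default metric) once training ends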

# create trainer object
trainer = Trainer(
    model=model,  # our peft model
    args=training_args,  # training arguments defined above
    train_dataset=tokenized_dataset["train"],  # training data
    eval_dataset=tokenized_dataset["validation"],  # validation data
    tokenizer=tokenizer,  # define tokenizer
    data_collator=data_collator,  # this will dynamically pad examples in each batch
    compute_metrics=compute_metrics,  # evaluates model using compute_metrics
)

# train model
trainer.train()

# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# model.to(device)

path = "./model"
trainer.save_model(path)  # saves the LoRA adapter weights and config
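
# (optional) the adapter can also be folded back into the base weights for
# deployment; PeftModel.merge_and_unload() returns a plain transformers model:
# merged = model.merge_and_unload()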

print("Trained model predictions:")
print("--------------------------")

# trainer.save_model on a PEFT model saves only the adapter, so rebuild the
# base model and load the adapter onto it once, outside the loop
peft_config = PeftConfig.from_pretrained(path)
base_model = AutoModelForSequenceClassification.from_pretrained(
    peft_config.base_model_name_or_path,
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)
model = PeftModel.from_pretrained(base_model, path)

for text in text_list:
    inputs = tokenizer.encode(text, return_tensors="pt")
    # inputs = tokenizer.encode(text, return_tensors="pt").to("mps")  # moving to mps

    logits = model(inputs).logits
    predictions = torch.max(logits, 1).indices

    print(text + " - " + id2label[predictions.tolist()[0]])

# Output:
# Trained model predictions:
# --------------------------
# It was good. - Positive
# Not a fan, don't recommend. - Negative
# Better than the first one. - Positive
# This is not worth watching even once. - Negative
# This one is a pass. - Positive  # this one is tricky
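
Beyond the per-example loop above, a minimal batched-inference sketch (assumed, not part of this commit): it reuses model, tokenizer, text_list, and id2label from fine_tunning.py and classifies all five examples in one forward pass.

import torch

model.eval()
with torch.no_grad():
    # tokenize all examples at once; padding aligns them into one batch
    batch = tokenizer(text_list, padding=True, truncation=True, return_tensors="pt")
    logits = model(**batch).logits
    preds = logits.argmax(dim=-1).tolist()

for text, pred in zip(text_list, preds):
    print(f"{text} - {id2label[pred]}")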