Thouph committed · Commit 1574172 · 1 Parent(s): c3310dd

Upload train_k.py

Files changed (1)
  1. train_k.py +235 -0
train_k.py ADDED
@@ -0,0 +1,235 @@
+ """from IPython.display import clear_output
+ #!pip install rouge_score -q
+ #!pip install deep-phonemizer -q
+ clear_output()"""
+ 
+ 
+ import os
+ 
+ import datasets
+ import numpy as np
+ import pandas as pd
+ import torchvision
+ from PIL import Image
+ from pathlib import Path
+ from tqdm.auto import tqdm
+ import multiprocessing as mp
+ import matplotlib.pyplot as plt
+ from sklearn.model_selection import train_test_split
+ 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision import io, transforms
+ from torch.utils.data import Dataset, DataLoader, random_split
+ 
+ from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
+ from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor
+ # GPT2Tokenizer is needed for the special-tokens patch below.
+ from transformers import AutoTokenizer, GPT2Tokenizer, default_data_collator
+ 
+ os.environ["WANDB_DISABLED"] = "true"
+ import torch_xla.core.xla_model as xm
+ 
+ dev = xm.xla_device()
+ 
+ 
+ if torch.cuda.is_available():
+     device = torch.device("cuda")
+     print('There are %d GPU(s) available.' % torch.cuda.device_count())
+     print('We will use the GPU:', torch.cuda.get_device_name(0))
+ else:
+     print('No GPU available, using the CPU instead.')
+     device = torch.device("cpu")
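+ 
+ # Note: `dev` (the XLA/TPU handle) and `device` are probed above but never
+ # used below; Seq2SeqTrainer performs its own device placement.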
+ 
+ 
+ class config:
+     ENCODER = "google/vit-base-patch16-224"
+     DECODER = "gpt2"
+     TRAIN_BATCH_SIZE = 4  # 8
+     VAL_BATCH_SIZE = 4  # 8
+     VAL_EPOCHS = 1
+     LR = 5e-5
+     SEED = 42
+     MAX_LEN = 128
+     SUMMARY_LEN = 20
+     WEIGHT_DECAY = 0.01
+     MEAN = (0.485, 0.456, 0.406)
+     STD = (0.229, 0.224, 0.225)
+     TRAIN_PCT = 0.95
+     NUM_WORKERS = mp.cpu_count()
+     EPOCHS = 1
+     IMG_SIZE = (224, 224)
+     LABEL_MASK = -100
+     TOP_K = 10
+     TOP_P = 0.95
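+ 
+ # Note: several config fields (LR, SEED, MAX_LEN, SUMMARY_LEN, WEIGHT_DECAY,
+ # MEAN, STD, TRAIN_PCT, NUM_WORKERS, LABEL_MASK, TOP_K, TOP_P) are defined
+ # but never read below; they appear to be leftovers from a notebook template.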
+ 
+ 
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+     # Wrap each caption in BOS ... EOS so the decoder sees explicit
+     # sequence boundaries.
+     return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
+ 
+ # Patch the slow GPT-2 tokenizer class directly. Patching AutoTokenizer has
+ # no effect: from_pretrained returns a concrete tokenizer instance, and fast
+ # tokenizers bypass this Python hook entirely.
+ GPT2Tokenizer.build_inputs_with_special_tokens = build_inputs_with_special_tokens
+ 
+ 
+ 
+ rouge = datasets.load_metric("rouge")
+ 
+ def compute_metrics(pred):
+     labels_ids = pred.label_ids
+     pred_ids = pred.predictions
+ 
+     # Decode predictions; masked (-100) label positions are first mapped back
+     # to the pad token so they can be decoded, then dropped as special tokens.
+     pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
+     labels_ids[labels_ids == -100] = tokenizer.pad_token_id
+     label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
+ 
+     rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid
+ 
+     return {
+         "rouge2_precision": round(rouge_output.precision, 4),
+         "rouge2_recall": round(rouge_output.recall, 4),
+         "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
+     }
+ 
+ 
+ feature_extractor = ViTFeatureExtractor.from_pretrained(config.ENCODER)
+ # use_fast=False so the patched build_inputs_with_special_tokens on
+ # GPT2Tokenizer is actually invoked during encoding. GPT-2 ships without a
+ # pad token, so the unk token (identical to eos) is reused for padding.
+ tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
+ tokenizer.pad_token = tokenizer.unk_token
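+ 
+ # Added sanity check (not in the original script): with the patch above,
+ # every encoded caption should be wrapped in BOS ... EOS. For GPT-2, bos,
+ # eos and unk all share the same id (50256).
+ _ids = tokenizer("example tags", add_special_tokens=True).input_ids
+ assert _ids[0] == tokenizer.bos_token_id and _ids[-1] == tokenizer.eos_token_id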
+ 
+ # Named image_transforms to avoid shadowing the torchvision.transforms module.
+ image_transforms = transforms.Compose(
+     [
+         # transforms.Resize(config.IMG_SIZE),
+         transforms.ToTensor(),
+         transforms.Normalize(
+             mean=[0.5, 0.5, 0.5],
+             std=[0.5, 0.5, 0.5],
+         )
+     ]
+ )
+ 
+ 
+ class ImgDataset(torch.utils.data.Dataset):
+     def __init__(self, df, root_dir, tokenizer, feature_extractor, transform):
+         self.df = df
+         self.transform = transform
+         self.root_dir = root_dir
+         self.tokenizer = tokenizer
+         self.feature_extractor = feature_extractor
+         self.max_length = 128
+ 
+     def __len__(self):
+         return len(self.df)
+ 
+     def __getitem__(self, idx):
+         caption = self.df.tags.iloc[idx]
+         image = self.df.image_id.iloc[idx] + ".jpg"
+         folder_name = str(self.df.folder_name.iloc[idx])
+         img_path = os.path.join(self.root_dir, folder_name, image)
+         img = Image.open(img_path).convert("RGB")
+ 
+         img = self.transform(img)
+ 
+         # Normalize(mean=0.5, std=0.5) maps pixels to [-1, 1]; undo that here,
+         # since the feature extractor applies ViT's own normalization below.
+         if img.min() < 0.0:
+             img = (img + 1.0) / 2.0
+ 
+         pixel_values = self.feature_extractor(img, return_tensors="pt").pixel_values
+         captions = self.tokenizer(caption,
+                                   padding='max_length',
+                                   max_length=self.max_length,
+                                   truncation=True).input_ids
+         # Mask padding so it is ignored by the loss. Because GPT-2's pad
+         # (= unk) id equals bos/eos, the wrapping BOS/EOS tokens are masked too.
+         captions = [tok if tok != self.tokenizer.pad_token_id else -100 for tok in captions]
+         encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(captions)}
+         return encoding
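+ 
+ # Added illustration (hypothetical, left commented out): one sample yields
+ #   sample = ImgDataset(df, "dump_small", tokenizer, feature_extractor, image_transforms)[0]
+ #   sample["pixel_values"].shape  # torch.Size([3, 224, 224])
+ #   sample["labels"].shape        # torch.Size([128])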
+ 
+ # Train over 179 pre-sifted CSV chunks in sequence, one pass per chunk.
+ for j in range(1, 179 + 1):
+     df = pd.read_csv(rf"posts/posts-2023-04-17_MD5_caption_sifted_no_symbol_purged_folder_{j}.csv")  # r"Z:\posts-2023-04-17_MD5_caption_sifted_no_symbol_purged.csv"
+     train_df, val_df = train_test_split(df, test_size=0.02)
+     print(df.head(3))
+ 
+     train_dataset = ImgDataset(
+         train_df,
+         root_dir="dump_small",
+         tokenizer=tokenizer,
+         feature_extractor=feature_extractor,
+         transform=image_transforms,
+     )
+ 
+     val_dataset = ImgDataset(
+         val_df,
+         root_dir="dump_small",
+         tokenizer=tokenizer,
+         feature_extractor=feature_extractor,
+         transform=image_transforms,
+     )
+ 
+ 
+     model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(config.ENCODER, config.DECODER)
+ 
+     # GPT-2 has no cls/sep tokens (their ids are None), so bos/eos serve as
+     # the decoder start and end-of-sequence ids.
+     model.config.decoder_start_token_id = tokenizer.bos_token_id
+     model.config.pad_token_id = tokenizer.pad_token_id
+     # make sure vocab size is set correctly
+     model.config.vocab_size = model.config.decoder.vocab_size
+     # set beam search parameters
+     model.config.eos_token_id = tokenizer.eos_token_id
+     model.config.max_length = 128
+     model.config.early_stopping = True
+     model.config.no_repeat_ngram_size = 2
+     model.config.length_penalty = 2.0
+     model.config.num_beams = 2
+ 
+     training_args = Seq2SeqTrainingArguments(
+         output_dir='VIT_large_gpt2',
+         per_device_train_batch_size=config.TRAIN_BATCH_SIZE,
+         per_device_eval_batch_size=config.VAL_BATCH_SIZE,
+         predict_with_generate=True,
+         evaluation_strategy="steps",
+         do_train=True,
+         do_eval=True,
+         logging_steps=1000,
+         save_steps=1000,
+         warmup_steps=200,
+         # Decay the peak LR linearly with the chunk index j (5e-5 down to
+         # roughly 1e-5 at j=179), since each chunk restarts the scheduler.
+         learning_rate=5e-5 - j * 2.2e-7,
+         # max_steps=400,  # delete for full training
+         num_train_epochs=config.EPOCHS,  # TRAIN_EPOCHS
+         overwrite_output_dir=True,
+         save_total_limit=3,
+     )
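+ 
+     # With evaluation_strategy="steps" and no eval_steps set, eval_steps
+     # defaults to logging_steps (1000); predict_with_generate makes each
+     # evaluation run full beam-search generation over the validation split.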
+ 
+     # Optional: force sequential (non-shuffled) sampling inside the Trainer.
+     """import transformers.trainer
+     from transformers.trainer import SequentialSampler
+ 
+     def sampler_monkey_patch(dataset, generator):
+         return SequentialSampler(dataset)
+ 
+     transformers.trainer.RandomSampler = sampler_monkey_patch"""
+ 
+     trainer = Seq2SeqTrainer(
+         tokenizer=feature_extractor,  # passed as `tokenizer` so it is saved with checkpoints
+         model=model,
+         args=training_args,
+         compute_metrics=compute_metrics,
+         train_dataset=train_dataset,
+         eval_dataset=val_dataset,
+         data_collator=default_data_collator,
+     )
+     # Resume from the previous chunk's saved weights when they exist; on the
+     # first chunk (nothing saved yet) fall back to a fresh run.
+     try:
+         trainer.train(resume_from_checkpoint='VIT_large_gpt2_model')
+     except Exception:
+         trainer.train()
+     trainer.save_model('VIT_large_gpt2_model')
+ 
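+ 
+ # Usage sketch for the saved model (an addition, not part of the original
+ # script; "example.jpg" is a placeholder):
+ # model = VisionEncoderDecoderModel.from_pretrained('VIT_large_gpt2_model')
+ # pixel_values = feature_extractor(Image.open("example.jpg").convert("RGB"), return_tensors="pt").pixel_values
+ # generated = model.generate(pixel_values, num_beams=model.config.num_beams, max_length=model.config.max_length)
+ # print(tokenizer.decode(generated[0], skip_special_tokens=True))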