Seth0330 commited on
Commit
3ade64c
·
verified ·
1 Parent(s): 2d9244e

Create train.py

Browse files
Files changed (1) hide show
  1. pdrt/train.py +173 -0
pdrt/train.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import torch
3
+ from tqdm import tqdm
4
+ from torch.utils.data import DataLoader
5
+ from torch.optim import AdamW
6
+ from transformers import get_scheduler
7
+
8
+ import utils_ctc
9
+ from models import Swin_CTC, VED
10
+ from mydatasets import myDatasetCTC, myDatasetTransformerDecoder
11
+
12
+ torch.set_float32_matmul_precision('medium')
13
+
14
+ #################################################################
15
+ # Experiment Settings
16
+ #################################################################
17
+
18
+ NUM_EPOCHS = int(sys.argv[0])
19
+ LR = float(sys.argv[1])
20
+ STRATEGY = str(sys.argv[2])
21
+ BATCH_SIZE = int(sys.argv[3])
22
+ MODEL_NAME = str(sys.argv[4])
23
+ NUM_ACCUMULATION_STEPS = int(sys.argv[5])
24
+
25
+ print(30*'*')
26
+ print("EXPERIMENT PARAMS: ")
27
+ print("\tNUM_EPOCHS: ", NUM_EPOCHS)
28
+ print("\tLR: ", LR)
29
+ print("\tSTRATEGY: ", STRATEGY)
30
+ print("\tBATCH_SIZE: ", BATCH_SIZE)
31
+ print("\tMODEL_NAME: ", MODEL_NAME)
32
+ print("\tNUM_ACCUMULATION_BATCHES: ", NUM_ACCUMULATION_STEPS)
33
+ print(30*'*')
34
+
35
+
36
+ #################################################################
37
+ # Load Torch Dataset and Create Vocab
38
+ #################################################################
39
+
40
+ l_of_transcrips = []
41
+ if MODEL_NAME == "Swin_CTC":
42
+ train_dataset = myDatasetCTC(partition="train")
43
+ else:
44
+ train_dataset = myDatasetTransformerDecoder(partition="train")
45
+
46
+ l_of_transcrips = train_dataset.label_list
47
+ text_to_seq, seq_to_text = utils_ctc.create_char_dicts(l_of_transcrips)
48
+
49
+ # update dics in datasets
50
+ train_dataset.text_to_seq = text_to_seq
51
+ train_dataset.seq_to_text = seq_to_text
52
+ print("Len dict text_to_seq: ", len(text_to_seq))
53
+ print("Len dict seq_to_text: ", len(seq_to_text))
54
+ print("Dict text_to_seq: ", (text_to_seq))
55
+ print("Dict seq_to_text: ", (seq_to_text))
56
+
57
+ #################################################################
58
+ # Load Model
59
+ #################################################################
60
+
61
+ # Create model
62
+ if MODEL_NAME == "Swin_CTC":
63
+ model = Swin_CTC(len(text_to_seq))
64
+ else:
65
+ model = VED()
66
+
67
+ #################################################################
68
+ # Training Settings
69
+ #################################################################
70
+
71
+ device = "cuda:0"
72
+
73
+ if MODEL_NAME == "Swin_CTC":
74
+ mycollate_fn = utils_ctc.custom_collate
75
+ else:
76
+ mycollate_fn = None
77
+
78
+ train_dataloader = DataLoader(
79
+ train_dataset,
80
+ BATCH_SIZE,
81
+ shuffle=True,
82
+ num_workers=23,
83
+ collate_fn=mycollate_fn)
84
+
85
+ optimizer = AdamW(model.parameters(), lr=LR, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0)
86
+
87
+ num_training_steps = NUM_EPOCHS # * len(train_dataloader)
88
+ lr_scheduler = get_scheduler(
89
+ "linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
90
+ )
91
+
92
+ #################################################################
93
+ # Frozen Strategies
94
+ #################################################################
95
+
96
+ model.to(device)
97
+ model.train()
98
+
99
+ if MODEL_NAME == "Swin_CTC":
100
+ if STRATEGY == "CTC-fclayer":
101
+ for name_p,p in model.named_parameters():
102
+ p.requires_grad = False
103
+ if "projection_V" in name_p:
104
+ p.requires_grad = True
105
+ print("Train only: ", name_p)
106
+ elif STRATEGY == "CTC-Swin":
107
+ for name_p,p in model.named_parameters():
108
+ p.requires_grad = True
109
+ if "projection_V" in name_p:
110
+ p.requires_grad = False
111
+ print("No train: ", name_p)
112
+ else:
113
+ for name_p,p in model.named_parameters():
114
+ p.requires_grad = True
115
+ print("Train all layers")
116
+ else:
117
+ if STRATEGY == "VED-encoder":
118
+ for name_p,p in model.named_parameters():
119
+ p.requires_grad = False
120
+ if "model.encoder." in name_p:
121
+ p.requires_grad = True
122
+ print("Train only: ", name_p)
123
+ elif STRATEGY == "VED-decoder":
124
+ for name_p,p in model.named_parameters():
125
+ p.requires_grad = False
126
+ if "model.decoder." in name_p:
127
+ p.requires_grad = True
128
+ print("Train only: ", name_p)
129
+ else:
130
+ for name_p,p in model.named_parameters():
131
+ p.requires_grad = True
132
+ print("Train all layers")
133
+
134
+ def count_parameters(model):
135
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
136
+ print("Params: ", count_parameters(model))
137
+
138
+ #################################################################
139
+ # Training
140
+ #################################################################
141
+
142
+ for epoch in range(NUM_EPOCHS):
143
+
144
+ epoch_loss = 0
145
+ print("Epoch ", epoch)
146
+ idx = 0
147
+ optimizer.zero_grad(set_to_none=True)
148
+ model.train()
149
+
150
+ with tqdm(iter(train_dataloader), desc="Training set", unit="batch") as tepoch:
151
+ for batch in tepoch:
152
+
153
+ inputs: torch.Tensor = batch["img"].to(device)
154
+ labels: torch.Tensor = batch["label"].to(device)
155
+
156
+ if MODEL_NAME == "Swin_CTC":
157
+ target_lengths: torch.Tensor = batch["target_lengths"].to(device)
158
+ outputs, loss = model(inputs, labels, target_lengths)
159
+ else:
160
+ outputs, loss = model(inputs, labels)
161
+
162
+ loss.backward()
163
+
164
+ if ((idx + 1) % NUM_ACCUMULATION_STEPS == 0):
165
+ optimizer.step()
166
+ optimizer.zero_grad(set_to_none=True)
167
+
168
+ tepoch.set_postfix(loss=loss.data.item())
169
+ epoch_loss += loss.data.item()
170
+ idx += 1
171
+
172
+ # Save Final model
173
+ torch.save(model.state_dict(), './FINAL_MODEL')