| | import sys |
| | import torch |
| | from tqdm import tqdm |
| | from torch.utils.data import DataLoader |
| | from torch.optim import AdamW |
| | from transformers import get_scheduler |
| |
|
| | import utils_ctc |
| | from models import Swin_CTC, VED |
| | from mydatasets import myDatasetCTC, myDatasetTransformerDecoder |
| |
|
# Allow lower-precision float32 matmul kernels for throughput on supporting GPUs
# (see torch.set_float32_matmul_precision; 'medium' trades precision for speed).
torch.set_float32_matmul_precision('medium')
| |
|
| | |
| | |
| | |
| |
|
# Hyperparameters come from the command line, in positional order:
#   NUM_EPOCHS LR STRATEGY BATCH_SIZE MODEL_NAME NUM_ACCUMULATION_STEPS
# BUG FIX: sys.argv[0] is the script path itself, so real arguments start at
# index 1 — the original indexed from 0 and shifted every parameter by one.
if len(sys.argv) != 7:
    raise SystemExit(
        f"usage: {sys.argv[0]} NUM_EPOCHS LR STRATEGY BATCH_SIZE "
        "MODEL_NAME NUM_ACCUMULATION_STEPS"
    )

NUM_EPOCHS = int(sys.argv[1])
LR = float(sys.argv[2])
STRATEGY = str(sys.argv[3])
BATCH_SIZE = int(sys.argv[4])
MODEL_NAME = str(sys.argv[5])
NUM_ACCUMULATION_STEPS = int(sys.argv[6])

print(30 * '*')
print("EXPERIMENT PARAMS: ")
print("\tNUM_EPOCHS: ", NUM_EPOCHS)
print("\tLR: ", LR)
print("\tSTRATEGY: ", STRATEGY)
print("\tBATCH_SIZE: ", BATCH_SIZE)
print("\tMODEL_NAME: ", MODEL_NAME)
# Label fixed to match the variable name (was "NUM_ACCUMULATION_BATCHES").
print("\tNUM_ACCUMULATION_STEPS: ", NUM_ACCUMULATION_STEPS)
print(30 * '*')
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Build the training dataset: the CTC model needs the CTC-specific dataset,
# any other model uses the transformer-decoder one.
if MODEL_NAME == "Swin_CTC":
    train_dataset = myDatasetCTC(partition="train")
else:
    train_dataset = myDatasetTransformerDecoder(partition="train")

# Character-level vocabulary built from every transcript of the training split.
# (Removed the dead `l_of_transcrips = []` initialiser: it was unconditionally
# overwritten here.)
l_of_transcrips = train_dataset.label_list
text_to_seq, seq_to_text = utils_ctc.create_char_dicts(l_of_transcrips)

# Attach the encode/decode tables so the dataset maps text <-> int sequences.
train_dataset.text_to_seq = text_to_seq
train_dataset.seq_to_text = seq_to_text
print("Len dict text_to_seq: ", len(text_to_seq))
print("Len dict seq_to_text: ", len(seq_to_text))
print("Dict text_to_seq: ", (text_to_seq))
print("Dict seq_to_text: ", (seq_to_text))
| |
|
| | |
| | |
| | |
| |
|
| | |
# Instantiate the model matching MODEL_NAME; only the CTC head needs the
# vocabulary size (output units), the VED model takes no arguments here.
model = Swin_CTC(len(text_to_seq)) if MODEL_NAME == "Swin_CTC" else VED()
| |
|
| | |
| | |
| | |
| |
|
# Single-GPU setup; a custom collate function is only needed for the CTC
# pipeline (utils_ctc.custom_collate), the default collation is used otherwise.
device = "cuda:0"

mycollate_fn = utils_ctc.custom_collate if MODEL_NAME == "Swin_CTC" else None

# NOTE(review): num_workers is hard-coded to 23 — confirm it suits the host.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=23,
    collate_fn=mycollate_fn,
)
| |
|
# AdamW over all model parameters (frozen ones are filtered by requires_grad
# at step time); betas/eps follow common transformer fine-tuning defaults.
optimizer = AdamW(model.parameters(), lr=LR, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0)

# NOTE(review): num_training_steps is set to the number of EPOCHS, not the
# number of optimizer steps, and lr_scheduler.step() is never called in the
# training loop below — so the linear decay schedule currently has no effect
# (LR stays constant). Confirm whether a per-step schedule was intended.
num_training_steps = NUM_EPOCHS
lr_scheduler = get_scheduler(
    "linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)
| |
|
| | |
| | |
| | |
| |
|
model.to(device)
model.train()

# Map the fine-tuning strategy to a parameter-name marker: `only_train` keeps
# ONLY matching parameters trainable; `only_freeze` freezes ONLY matching
# parameters; both None means full fine-tuning.
if MODEL_NAME == "Swin_CTC":
    only_train = "projection_V" if STRATEGY == "CTC-fclayer" else None
    only_freeze = "projection_V" if STRATEGY == "CTC-Swin" else None
else:
    if STRATEGY == "VED-encoder":
        only_train = "model.encoder."
    elif STRATEGY == "VED-decoder":
        only_train = "model.decoder."
    else:
        only_train = None
    only_freeze = None

if only_train is not None:
    # Freeze everything except parameters whose name contains the marker.
    for pname, param in model.named_parameters():
        param.requires_grad = only_train in pname
        if param.requires_grad:
            print("Train only: ", pname)
elif only_freeze is not None:
    # Train everything except parameters whose name contains the marker.
    for pname, param in model.named_parameters():
        param.requires_grad = only_freeze not in pname
        if not param.requires_grad:
            print("No train: ", pname)
else:
    # Full fine-tuning: every parameter stays trainable.
    for pname, param in model.named_parameters():
        param.requires_grad = True
    print("Train all layers")
| |
|
def count_parameters(model):
    """Return the total number of trainable parameters in *model*.

    Only parameters with requires_grad=True are counted, so after the
    freezing strategy is applied this reflects what the optimizer updates.
    """
    total = 0
    for p in model.parameters():
        if p.requires_grad:
            total += p.numel()
    return total
# Report how many parameters will actually be optimized under the chosen strategy.
print("Params: ", count_parameters(model))
| |
|
| | |
| | |
| | |
| |
|
# Main training loop with gradient accumulation: gradients are accumulated
# over NUM_ACCUMULATION_STEPS batches before each optimizer step.
# NOTE(review): lr_scheduler (defined above) is never stepped here, so the
# learning rate stays constant — confirm whether scheduler.step() belongs
# after optimizer.step().
for epoch in range(NUM_EPOCHS):

    epoch_loss = 0
    print("Epoch ", epoch)
    optimizer.zero_grad(set_to_none=True)
    model.train()

    seen = 0  # number of batches processed this epoch (0 if the loader is empty)
    with tqdm(train_dataloader, desc="Training set", unit="batch") as tepoch:
        for seen, batch in enumerate(tepoch, start=1):
            inputs: torch.Tensor = batch["img"].to(device)
            labels: torch.Tensor = batch["label"].to(device)

            if MODEL_NAME == "Swin_CTC":
                # CTC additionally needs the unpadded length of each target.
                target_lengths: torch.Tensor = batch["target_lengths"].to(device)
                outputs, loss = model(inputs, labels, target_lengths)
            else:
                outputs, loss = model(inputs, labels)

            # BUG FIX: scale by the accumulation factor so the accumulated
            # gradient matches the gradient of one large batch; the original
            # effectively multiplied the learning rate by NUM_ACCUMULATION_STEPS.
            (loss / NUM_ACCUMULATION_STEPS).backward()

            if seen % NUM_ACCUMULATION_STEPS == 0:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

            # Progress/statistics report the unscaled per-batch loss.
            tepoch.set_postfix(loss=loss.item())
            epoch_loss += loss.item()

    # BUG FIX: flush gradients left over when the batch count is not a multiple
    # of NUM_ACCUMULATION_STEPS, so the tail of the epoch is not discarded.
    if seen % NUM_ACCUMULATION_STEPS != 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
| |
|
| | |
# Persist only the final weights (state_dict) — no optimizer state and no
# per-epoch checkpoints are kept. NOTE(review): the path './FINAL_MODEL' has
# no extension; confirm the downstream loader expects this exact name.
torch.save(model.state_dict(), './FINAL_MODEL')