"""Training script for a CRNN text-recognition model optimized with CTC loss.

Loads hyperparameters from a JSON config, trains on the project dataloaders,
saves the final checkpoint, and plots train/validation loss curves.
"""
import argparse
import json
import os
import sys

import matplotlib.pyplot as plt  # was missing: plt is used in main() for loss plots
import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm

sys.path.append(os.getcwd())
from src.Text_Recognization.text_recognization import *
from src.Text_Recognization.prepare_dataset import *
from src.Text_Recognization.dataloader import *

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def load_json_config(config_path):
    """Read a JSON configuration file and return it as a dict.

    Args:
        config_path: Path to the JSON config file.

    Returns:
        dict parsed from the file.
    """
    with open(config_path, "r") as f:
        config = json.load(f)
    return config


def evaluate(model, dataloader, criterion, device):
    """Compute the mean CTC loss of ``model`` over ``dataloader``.

    Args:
        model: The CRNN model to evaluate.
        dataloader: Yields (images, labels, labels_len) batches.
        criterion: A ``nn.CTCLoss`` instance.
        device: Target device string ('cuda' or 'cpu').

    Returns:
        float: Mean loss across all batches. Assumes the loader yields at
        least one batch (otherwise division by zero).
    """
    model.eval()
    losses = []
    with torch.no_grad():
        for images, labels, labels_len in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            # outputs presumably has shape (T, batch, vocab) as CTCLoss expects
            # — TODO confirm against the CRNN forward definition.
            outputs = model(images)
            # Every sample is assigned the full output time dimension as its
            # logit length, as required by nn.CTCLoss.
            logits_lens = torch.full(
                size=(outputs.size(1),),
                fill_value=outputs.size(0),
                dtype=torch.long,
            ).to(device)
            loss = criterion(outputs, labels, logits_lens, labels_len)
            losses.append(loss.item())
    eval_loss = sum(losses) / len(losses)
    return eval_loss


def training_loop(model, train_loader, val_loader, learning_rate, epochs,
                  optimizer, criterion, scheduler, device):
    """Train ``model`` for ``epochs`` epochs, evaluating after each epoch.

    Args:
        model: The CRNN model to train.
        train_loader: Training dataloader yielding (images, labels, labels_len).
        val_loader: Validation dataloader with the same batch structure.
        learning_rate: Unused; kept for backward-compatible signature
            (the optimizer already carries its own learning rate).
        epochs: Number of training epochs.
        optimizer: Optimizer over ``model.parameters()``.
        criterion: ``nn.CTCLoss`` instance.
        scheduler: LR scheduler stepped once per epoch.
        device: Target device string.

    Returns:
        tuple[list[float], list[float]]: Per-epoch train and validation losses.
    """
    model.to(device)
    train_losses = []
    val_losses = []
    for epoch in range(epochs):
        model.train()
        batch_losses = []
        for images, labels, labels_len in tqdm(train_loader):
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            logits_lens = torch.full(
                size=(outputs.size(1),),
                fill_value=outputs.size(0),
                dtype=torch.long,
            ).to(device)
            loss = criterion(outputs, labels, logits_lens, labels_len)
            loss.backward()
            # Gradient clipping guards against exploding gradients in the RNN.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()
            batch_losses.append(loss.item())
        train_loss = sum(batch_losses) / len(batch_losses)
        train_losses.append(train_loss)
        val_loss = evaluate(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        print(f"epoch: {epoch+1}/{epochs}\ttrain_loss:{train_loss}\tval_loss:{val_loss}")
        scheduler.step()
    return train_losses, val_losses


def main():
    """Parse CLI args, train the CRNN, save the checkpoint, plot losses."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--root_path', type=str, default=os.getcwd(),
                        help='Path to the root directory')
    parser.add_argument('--checkpoints_path', type=str,
                        default=os.path.join(os.getcwd(), 'checkpoints'),
                        help='Path to the checkpoint directory')
    args = parser.parse_args()

    config_path = 'src/config.json'
    dataset_path = os.path.join(args.root_path, 'Dataset')
    config = load_json_config(config_path)

    # dictionary char and idx
    char_to_idx, idx_to_char = build_vocab(dataset_path)

    # model
    model = CRNN(vocab_size=config['vocab_size'],
                 hidden_size=config['CRNN']['hidden_size'],
                 n_layers=config['CRNN']['n_layers'])

    # dataloader
    train_loader, val_loader, test_loader = get_dataloader()

    # define hyper parameters
    criterion = nn.CTCLoss(
        blank=char_to_idx[config['blank_char']],
        zero_infinity=True,
        reduction='mean',
    )
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config['CRNN']['learning_rate'],
        weight_decay=config['CRNN']['weight_decay'],
    )
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer=optimizer,
        step_size=config['CRNN']['scheduler_step_size'],
        gamma=0.1,
    )

    # training loop
    train_losses, val_losses = training_loop(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=config['CRNN']['learning_rate'],
        epochs=config['CRNN']['epochs'],
        optimizer=optimizer,
        criterion=criterion,
        scheduler=scheduler,
        device=device,
    )

    # save model — exist_ok=True fixes the case where checkpoints/ already
    # exists but checkpoints/losses/ does not (old code skipped both makedirs).
    os.makedirs(args.checkpoints_path, exist_ok=True)
    os.makedirs(os.path.join(args.checkpoints_path, 'losses'), exist_ok=True)
    torch.save(model.state_dict(),
               os.path.join(args.checkpoints_path, 'crnn.pt'))

    # draw losses (axis('off') removed: it hid the x/y labels set just above)
    fig, axis = plt.subplots(1, 2, figsize=(8, 8))
    axis[0].plot(train_losses, label='train_loss')
    axis[0].set_xlabel('Epochs')
    axis[0].set_ylabel('Loss')
    axis[0].legend()
    axis[1].plot(val_losses, label='val_loss')
    axis[1].set_xlabel('Epochs')
    axis[1].set_ylabel('Loss')
    axis[1].legend()
    plt.savefig(os.path.join(args.checkpoints_path, 'losses', 'losses.png'))


if __name__ == '__main__':
    main()