# Residue from the hosting page (not part of the program):
# TungDuong's picture
# source code
# 06142a4 verified
import json
import sys
import os
import argparse
import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm
sys.path.append(os.getcwd())
from src.Text_Recognization.text_recognization import *
from src.Text_Recognization.prepare_dataset import *
from src.Text_Recognization.dataloader import *
# Choose the compute device once at import time: use the GPU when CUDA is
# usable, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
def load_json_config(config_path):
    """Load a JSON configuration file and return the parsed content.

    Args:
        config_path: Path to the JSON file.

    Returns:
        The deserialized JSON value (a ``dict`` when the file's root is a
        JSON object).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON text is UTF-8 by specification. Pinning the encoding keeps
    # non-ASCII values from being decoded with the platform default
    # (e.g. cp1252 on Windows).
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)
def evaluate(model, dataloader, criterion, device):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    The model is switched to eval mode and gradients are disabled for the
    whole pass. The model itself is not moved; it is expected to already be
    on ``device``.

    Args:
        model: Module called as ``model(images)``. ``outputs.size(0)`` is
            used as the input length of every sample and ``outputs.size(1)``
            as the number of samples in the batch.
        dataloader: Iterable yielding ``(images, labels, labels_len)``.
        criterion: Loss called as
            ``criterion(outputs, labels, input_lengths, labels_len)``.
        device: Device that images, labels and the input-length tensor are
            placed on.

    Returns:
        float: Unweighted mean of the per-batch loss values, or ``nan`` when
        the dataloader yields no batches (matching the mean of an empty
        tensor instead of raising ``ZeroDivisionError``).
    """
    model.eval()
    losses = []
    with torch.no_grad():
        for images, labels, labels_len in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            # Every sample in the batch spans the full output length, so the
            # length tensor is a constant; build it directly on the target
            # device rather than on the CPU followed by a copy.
            logits_lens = torch.full(
                size=(outputs.size(1),),
                fill_value=outputs.size(0),
                dtype=torch.long,
                device=device,
            )
            loss = criterion(outputs, labels, logits_lens, labels_len)
            losses.append(loss.item())
    if not losses:
        # No batches were produced: there is nothing to average.
        return float("nan")
    return sum(losses) / len(losses)
def training_loop(model, train_loader, val_loader, learning_rate, epochs, optimizer, criterion, scheduler, device):
    """Train ``model`` for ``epochs`` epochs and validate after each one.

    For every batch the gradients are clipped to a total norm of 5 before
    the optimizer step. After each epoch the validation loss is computed
    with ``evaluate``, a progress line is printed, and the scheduler is
    stepped once.

    Args:
        model: Module to train; moved to ``device`` before training starts.
        train_loader: Iterable yielding ``(images, labels, labels_len)``.
        val_loader: Iterable passed to ``evaluate`` after each epoch.
        learning_rate: Unused here (the rate is already configured on
            ``optimizer``); kept so existing callers keep working.
        epochs: Number of epochs to run.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Loss called as
            ``criterion(outputs, labels, input_lengths, labels_len)``.
        scheduler: Learning-rate scheduler stepped once per epoch.
        device: Device the model and batches are placed on.

    Returns:
        tuple[list[float], list[float]]: Per-epoch mean training losses and
        per-epoch validation losses. An epoch whose training loader yields
        no batches records ``nan`` instead of raising ``ZeroDivisionError``.
    """
    model.to(device)
    train_losses = []
    val_losses = []
    for epoch in range(epochs):
        model.train()
        batch_losses = []
        for images, labels, labels_len in tqdm(train_loader):
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            # Every sample spans the full output length; create the constant
            # length tensor directly on the target device.
            logits_lens = torch.full(
                size=(outputs.size(1),),
                fill_value=outputs.size(0),
                dtype=torch.long,
                device=device,
            )
            loss = criterion(outputs, labels, logits_lens, labels_len)
            loss.backward()
            # Clip the global gradient norm to keep updates bounded.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()
            batch_losses.append(loss.item())

        if batch_losses:
            train_loss = sum(batch_losses) / len(batch_losses)
        else:
            # Empty training loader: nothing to average for this epoch.
            train_loss = float("nan")
        train_losses.append(train_loss)

        val_loss = evaluate(model, val_loader, criterion, device)
        val_losses.append(val_loss)

        print(f"epoch: {epoch+1}/{epochs}\ttrain_loss:{train_loss}\tval_loss:{val_loss}")
        scheduler.step()
    return train_losses, val_losses
def main():
    """Train the CRNN model, save its weights and plot the loss curves.

    Command-line options:
        --root_path: Root directory; the dataset is read from
            ``<root_path>/Dataset``. Defaults to the current directory.
        --checkpoints_path: Directory receiving ``crnn.pt`` and
            ``losses/losses.png``. Defaults to ``<cwd>/checkpoints``.

    Hyper-parameters are read from ``src/config.json`` (resolved relative to
    the current working directory).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--root_path', type=str, default=os.getcwd(), help='Path to the root directory')
    parser.add_argument('--checkpoints_path', type=str, default=os.path.join(os.getcwd(), 'checkpoints'), help='Path to the checkpoint directory')
    args = parser.parse_args()

    config_path = 'src/config.json'
    dataset_path = os.path.join(args.root_path, 'Dataset')
    config = load_json_config(config_path)

    # Character <-> index dictionaries.
    # idx_to_char is not used in this function.
    char_to_idx, idx_to_char = build_vocab(dataset_path)

    # Model.
    model = CRNN(vocab_size=config['vocab_size'], hidden_size=config['CRNN']['hidden_size'], n_layers=config['CRNN']['n_layers'])

    # Dataloaders. The test loader is not used during training.
    train_loader, val_loader, test_loader = get_dataloader()

    # Loss, optimizer and learning-rate schedule.
    criterion = nn.CTCLoss(
        blank=char_to_idx[config['blank_char']],
        zero_infinity=True,
        reduction='mean'
    )
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config['CRNN']['learning_rate'],
        weight_decay=config['CRNN']['weight_decay']
    )
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer=optimizer,
        step_size=config['CRNN']['scheduler_step_size'],
        gamma=0.1
    )

    # Training loop.
    train_losses, val_losses = training_loop(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=config['CRNN']['learning_rate'],
        epochs=config['CRNN']['epochs'],
        optimizer=optimizer,
        criterion=criterion,
        scheduler=scheduler,
        device=device
    )

    # Save the model. exist_ok=True creates both the checkpoint directory and
    # its 'losses' sub-directory, and keeps a re-run (or a pre-existing
    # checkpoint directory without 'losses') from failing.
    losses_dir = os.path.join(args.checkpoints_path, 'losses')
    os.makedirs(losses_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(args.checkpoints_path, 'crnn.pt'))

    # Draw the loss curves.
    # NOTE(review): plt is not imported in this file; presumably it comes
    # from one of the wildcard imports at the top of the file. Confirm.
    fig, axes = plt.subplots(1, 2, figsize=(8, 8))
    for ax, values, label in (
        (axes[0], train_losses, 'train_loss'),
        (axes[1], val_losses, 'val_loss'),
    ):
        ax.plot(values, label=label)
        ax.set_xlabel('Epochs')
        ax.set_ylabel('Loss')
        # The axes stay visible: hiding them would also hide the tick values
        # and the labels set above, leaving the curve without any scale.
        ax.legend()
    plt.savefig(os.path.join(losses_dir, 'losses.png'))
    # Release the figure once it has been written to disk.
    plt.close(fig)
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()