# File size: 3,470 Bytes
# 6e89f30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from src.config import cfg
from src.collate import ctc_collate
from src.captcha_dataset import CaptchaDataset
from src.vocab import vocab_size, ctc_greedy_decode
from src.model_crnn import CRNN


def _cycle_batches(loader):
    """Yield batches from ``loader`` endlessly, restarting it when exhausted.

    Each pass re-iterates ``loader`` itself rather than replaying cached
    batches, so a shuffling loader produces a new order on every pass.
    The caller must make sure ``loader`` is non-empty, otherwise this
    generator spins forever without yielding.
    """
    while True:
        yield from loader


def main(steps: int = 200, lr: float = 0.001, log_every: int = 20) -> None:
    """Run a short training loop and one validation decode as a smoke test.

    Builds the train/val datasets and loaders, trains a freshly created CRNN
    with CTC loss for ``steps`` optimizer steps (cycling through the training
    loader as often as needed), then decodes a single validation batch with
    greedy CTC decoding and prints a few of the results.

    Args:
        steps: Number of optimizer steps to run. ``0`` skips training.
        lr: Learning rate passed to Adam.
        log_every: Print the loss every this many steps; ``0`` or a negative
            value disables loss logging.

    Raises:
        RuntimeError: If training is requested but the training loader yields
            no batches (dataset smaller than ``cfg.batch_size``).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    on_cuda = device.type == "cuda"

    print("Creating datasets...")
    train_ds = CaptchaDataset("train")
    val_ds = CaptchaDataset("val")

    print(f"Training dataset size: {len(train_ds)}")
    print(f"Validation dataset size: {len(val_ds)}")

    shared_loader_args = dict(
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        # Pinned host memory only helps host-to-GPU copies; skip it on CPU.
        pin_memory=on_cuda,
        collate_fn=ctc_collate,
    )
    train_dl = DataLoader(train_ds, shuffle=True, drop_last=True,
                          **shared_loader_args)
    # Keep the trailing partial batch: with drop_last=True a validation set
    # smaller than one batch would yield nothing and the check below would fail.
    val_dl = DataLoader(val_ds, shuffle=False, drop_last=False,
                        **shared_loader_args)

    if steps > 0 and len(train_dl) == 0:
        raise RuntimeError(
            f"Training loader is empty: {len(train_ds)} samples is fewer than "
            f"one full batch of {cfg.batch_size}"
        )

    model = CRNN(vocab_size=vocab_size()).to(device)
    # zero_infinity=True zeroes out infinite losses and their gradients.
    # NOTE(review): blank=0 assumes index 0 is the blank symbol in src.vocab
    # - confirm against the vocabulary definition.
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scaler = torch.amp.GradScaler('cuda', enabled=cfg.amp and on_cuda)

    model.train()
    # zip stops as soon as the step range is exhausted, so exactly ``steps``
    # batches are drawn from the endless batch stream.
    for step, batch in zip(range(1, steps + 1), _cycle_batches(train_dl)):
        images, targets_flat, target_lengths, input_lengths, _paths = batch

        images = images.to(device)
        targets_flat = targets_flat.to(device)
        target_lengths = target_lengths.to(device)
        input_lengths = input_lengths.to(device)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast('cuda', enabled=scaler.is_enabled()):
            logits = model(images)
            log_probs = logits.log_softmax(dim=-1)
            loss = criterion(log_probs, targets_flat, input_lengths,
                             target_lengths)
        # With the scaler disabled these calls reduce to a plain
        # backward() / optimizer.step().
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        if log_every > 0 and step % log_every == 0:
            print(f"step {step}/{steps} - loss {loss.item():.4f}")

    model.eval()
    val_batch = next(iter(val_dl), None)
    if val_batch is None:
        print("Validation loader is empty - skipping decode check")
    else:
        with torch.no_grad():
            images = val_batch[0].to(device)
            logits = model(images)
            preds = ctc_greedy_decode(logits)
        # presumably one decoded item per image; list() keeps this working
        # for any iterable return type - verify against src.vocab.
        print(f"Sample predictions: {list(preds)[:5]}")

    print("Sanity check complete")


if __name__ == "__main__":
    # Script entry point: ensure the checkpoint directory exists (no error if
    # it is already there), then run the sanity check.
    os.makedirs(name="checkpoints", exist_ok=True)
    main()