File size: 643 Bytes
2411029
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
from src.data import make_loaders
from src.model import CRNN

def main():
    """Run one CRNN forward pass on a single training batch and print shapes.

    Builds the data loaders and takes the first batch from the training
    loader. It feeds ``batch.images`` through a freshly constructed (untrained)
    ``CRNN`` and prints the model outputs next to the batch's targets so their
    shapes can be eyeballed. Nothing is trained, saved, or returned.
    """
    # Only the training loader and the vocabulary are needed here; the middle
    # return value of make_loaders is ignored.
    train_loader, _, stoi = make_loaders(batch_size=2, img_height=64)
    batch = next(iter(train_loader))

    num_classes = len(stoi)  # includes <BLANK> at 0
    model = CRNN(num_classes=num_classes)

    # Shapes are all we inspect, so skip autograd graph construction for the
    # forward pass; output shapes are identical with or without grad tracking.
    with torch.no_grad():
        log_probs, input_lengths = model(batch.images)

    print("log_probs shape:", log_probs.shape)      # [T, B, C]
    print("input_lengths:", input_lengths)          # [B]
    # NOTE(review): targets presumably hold every sample's label sequence
    # concatenated into one 1-D tensor, so its length is the sum of
    # target_lengths (not of the time dimension T above) — confirm in src.data.
    print("targets shape:", batch.targets.shape)    # [sum(target_lengths)]
    print("target_lengths:", batch.target_lengths)  # [B]


if __name__ == "__main__":
    main()