File size: 643 Bytes
2411029 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import torch
from src.data import make_loaders
from src.model import CRNN
def main():
    """Smoke-test the data pipeline and the CRNN forward pass by printing shapes.

    Pulls a single batch from the training loader (batch_size=2,
    img_height=64), runs it through a freshly constructed ``CRNN``, and
    prints the model outputs alongside the batch targets.  The goal is to
    spot mismatched dimensions before a real training run.  Judging by the
    ``<BLANK>`` class and the lengths printed below, these are presumably
    the inputs of a CTC-style loss.
    """
    train_loader, _, stoi = make_loaders(batch_size=2, img_height=64)
    batch = next(iter(train_loader))

    num_classes = len(stoi)  # includes <BLANK> at 0
    model = CRNN(num_classes=num_classes)

    # Shape inspection only.  eval() turns off any dropout or batch-norm
    # statistic updates.  no_grad() skips building an autograd graph that
    # nothing will ever backpropagate through.
    model.eval()
    with torch.no_grad():
        log_probs, input_lengths = model(batch.images)

    print("log_probs shape:", log_probs.shape)  # [T, B, C]
    print("input_lengths:", input_lengths)  # [B]
    # NOTE(review): presumably all label sequences are concatenated into one
    # 1-D tensor (CTC layout).  In that case its length is
    # sum(target_lengths), not a sum over the time steps T.  Confirm against
    # make_loaders.
    print("targets shape:", batch.targets.shape)
    print("target_lengths:", batch.target_lengths)  # [B]
if __name__ == "__main__":
    # Run the shape smoke test only when executed as a script, never on import.
    main()
|