"""Smoke test: push one training batch through CRNN and print CTC-relevant shapes."""

import torch

from src.data import make_loaders
from src.model import CRNN


def main():
    """Run one training batch through an untrained CRNN and print tensor shapes.

    A quick sanity check that the data pipeline and the model produce the
    tensors a CTC loss would consume (log-probs, input lengths, targets,
    target lengths). Nothing is trained and nothing is returned.
    """
    train_loader, _, stoi = make_loaders(batch_size=2, img_height=64)
    batch = next(iter(train_loader))

    # Vocabulary size. Assumed to include the CTC blank symbol at index 0
    # (TODO confirm against src.data.make_loaders).
    num_classes = len(stoi)
    model = CRNN(num_classes=num_classes)

    # Inference-only forward pass: no need to build the autograd graph.
    with torch.no_grad():
        log_probs, input_lengths = model(batch.images)

    print("log_probs shape:", log_probs.shape)  # expected [T, B, C]
    print("input_lengths:", input_lengths)  # expected [B]
    # Presumably the 1-D concatenation of all labels: [sum(target_lengths)].
    print("targets shape:", batch.targets.shape)
    print("target_lengths:", batch.target_lengths)  # expected [B]


if __name__ == "__main__":
    main()