|
|
import torch
|
|
|
from src.data import make_loaders
|
|
|
from src.model import CRNN
|
|
|
|
|
|
def main(batch_size: int = 2, img_height: int = 64) -> None:
    """Run one CRNN forward pass on a single training batch and print shapes.

    Smoke test for the data pipeline and model wiring. It builds the loaders,
    pulls the first training batch, instantiates ``CRNN`` sized to the
    vocabulary, and prints the model outputs' shape/lengths next to the
    batch targets so they can be compared by eye.

    Args:
        batch_size: Batch size forwarded to ``make_loaders``.
        img_height: Image height forwarded to ``make_loaders``.
    """
    # make_loaders returns three values; only the training loader (first)
    # and the vocabulary mapping ``stoi`` (third) are needed here.
    train_loader, _, stoi = make_loaders(
        batch_size=batch_size,
        img_height=img_height,
    )
    batch = next(iter(train_loader))

    # NOTE(review): presumably ``stoi`` maps characters to class indices.
    # If these outputs feed a CTC loss, confirm that ``stoi`` already
    # reserves an index for the blank symbol; otherwise num_classes may
    # need to be len(stoi) + 1.
    num_classes = len(stoi)
    model = CRNN(num_classes=num_classes)

    # This is a forward-only shape check. eval() disables training-time
    # behaviour in any dropout / batch-norm layers CRNN may contain, and
    # no_grad() skips building an autograd graph that is never used.
    model.eval()
    with torch.no_grad():
        log_probs, input_lengths = model(batch.images)

    print("log_probs shape:", log_probs.shape)
    print("input_lengths:", input_lengths)
    print("targets shape:", batch.targets.shape)
    print("target_lengths:", batch.target_lengths)
|
|
|
|
|
|
# Script entry point: run the smoke test only when executed directly,
# not when this module is imported.
if "__main__" == __name__:
    main()
|
|
|
|