handwritten-notes-ocr / src /test_model.py
lakshmi-charan's picture
Upload 15 files
2411029 verified
raw
history blame contribute delete
643 Bytes
import torch
from src.data import make_loaders
from src.model import CRNN
def main():
    """Smoke-test the CRNN forward pass on a single training batch.

    Builds the data loaders, takes the first batch from the training
    loader, runs its images through a freshly constructed ``CRNN`` and
    prints the model outputs' shape/lengths next to the batch's target
    shape/lengths so they can be compared by eye.

    Raises:
        RuntimeError: If the training loader yields no batches.
    """
    # The second return value is not needed here
    # (presumably a validation loader -- confirm against src.data).
    train_loader, _, stoi = make_loaders(batch_size=2, img_height=64)

    # Fail with a clear message instead of leaking a bare StopIteration
    # when the loader is empty.
    batch = next(iter(train_loader), None)
    if batch is None:
        raise RuntimeError(
            "train_loader yielded no batches; cannot run the smoke test"
        )

    num_classes = len(stoi)  # includes <BLANK> at 0
    model = CRNN(num_classes=num_classes)

    # Only shapes are inspected, so skip building the autograd graph.
    with torch.no_grad():
        log_probs, input_lengths = model(batch.images)

    print("log_probs shape:", log_probs.shape)  # [T, B, C]
    print("input_lengths:", input_lengths)  # [B]
    print("targets shape:", batch.targets.shape)  # [sum(T)]
    print("target_lengths:", batch.target_lengths)  # [B]
# Run the smoke test only when executed as a script, not when imported.
if __name__ == "__main__":
    main()