CaptchaOCR / src /collate.py
mohakapoor's picture
Initial project setup on Dev branch
ada63c0
raw
history blame
1.12 kB
from typing import Callable, List, Optional, Sequence, Tuple

import torch

from src.config import cfg
from src.vocab import encode_text
def ctc_collate(
    batch: List[Tuple[torch.Tensor, str, str]],
    *,
    total_stride: Optional[int] = None,
    encode: Optional[Callable[[str], Sequence[int]]] = None,
):
    """Collate ``(image, label, path)`` samples into one batch for CTC training.

    Args:
        batch: list of ``(image_tensor [C, H, W], label_str, rel_path)``.
            Every image must already have the same shape: they are combined
            with ``torch.stack``, so no padding happens here.
        total_stride: divisor applied to the batch image width to obtain the
            CTC input length. Defaults to ``cfg.total_stride`` (presumably the
            model's total horizontal downsampling factor; confirm against the
            model definition).
        encode: maps a label string to a sequence of integer class ids.
            Defaults to ``encode_text``.

    The extra parameters are keyword-only, so passing this function directly
    as a ``DataLoader`` ``collate_fn`` keeps working; use ``functools.partial``
    to override them.

    Returns:
        images: ``[B, C, H, W]``, the stacked image tensors.
        targets_flat: ``[sum(len(encode(label_i)))]`` long tensor, i.e. all
            encoded labels concatenated in batch order.
        target_lengths: ``[B]`` long tensor of per-sample encoded label lengths.
        input_lengths: ``[B]`` long tensor; every entry is ``W // total_stride``
            because all images in the batch share width ``W``.
        rel_paths: list of the samples' ``rel_path`` strings, in batch order.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    # torch.stack on an empty list fails with an opaque error, and no image
    # shape (hence no input length) can be derived; reject it explicitly.
    if not batch:
        raise ValueError("ctc_collate received an empty batch")

    stride = cfg.total_stride if total_stride is None else total_stride
    encoder = encode_text if encode is None else encode

    images = torch.stack([item[0] for item in batch], dim=0)
    encoded = [
        torch.tensor(encoder(item[1]), dtype=torch.long) for item in batch
    ]
    target_lengths = torch.tensor([len(t) for t in encoded], dtype=torch.long)
    # batch is non-empty, so there is at least one (possibly zero-length)
    # tensor to concatenate.
    targets_flat = torch.cat(encoded, dim=0)

    batch_size, width = images.shape[0], images.shape[-1]
    # NOTE(review): CTC is only feasible when each input length is at least
    # the target length (more with repeated characters); not checked here.
    input_lengths = torch.full(
        (batch_size,), width // stride, dtype=torch.long
    )
    rel_paths = [item[2] for item in batch]
    return images, targets_flat, target_lengths, input_lengths, rel_paths