Spaces:
Sleeping
Sleeping
File size: 1,115 Bytes
ada63c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
from typing import List,Tuple
import torch
from src.config import cfg
from src.vocab import encode_text
def ctc_collate(batch: List[Tuple[torch.Tensor, str, str]]):
    """Collate (image, label, rel_path) samples into one CTC batch.

    Args:
        batch: samples of the form ``(image_tensor, label_str, rel_path)``.
            The image tensors are stacked, so they must all share one
            shape (``[C, H, W_max]`` per the data pipeline's convention).

    Returns:
        A 5-tuple ``(images, targets_flat, target_lengths, input_lengths,
        rel_paths)`` where

        * ``images`` is the stacked image tensor ``[B, C, H, W_max]``;
        * ``targets_flat`` is every encoded label concatenated into one
          1-D ``torch.long`` tensor of size ``sum(target_lengths)``;
        * ``target_lengths`` is a ``[B]`` ``torch.long`` tensor holding the
          encoded length of each label;
        * ``input_lengths`` is a ``[B]`` ``torch.long`` tensor filled with
          ``W_max // cfg.total_stride`` (the same value for every sample);
        * ``rel_paths`` is the list of each sample's third element.
    """
    # Stack along a new leading batch axis.
    images = torch.stack([sample[0] for sample in batch])

    # Encode each label string into a 1-D tensor of token ids.
    label_strings = [sample[1] for sample in batch]
    encoded_labels = [
        torch.tensor(encode_text(text), dtype=torch.long)
        for text in label_strings
    ]
    target_lengths = torch.tensor(
        [len(ids) for ids in encoded_labels], dtype=torch.long
    )

    # The labels are concatenated back to back; target_lengths tells the
    # consumer where each one ends.
    # NOTE(review): torch.stack above already rejects an empty batch, so the
    # empty fallback here looks purely defensive.
    targets_flat = (
        torch.cat(encoded_labels)
        if encoded_labels
        else torch.empty(0, dtype=torch.long)
    )

    # Every image has the same width after stacking, so every sample gets
    # the same input length: width divided by cfg.total_stride.
    batch_size, _channels, _height, width = images.shape
    steps = width // cfg.total_stride
    input_lengths = torch.full((batch_size,), steps, dtype=torch.long)

    rel_paths = [sample[2] for sample in batch]
    return images, targets_flat, target_lengths, input_lengths, rel_paths