File size: 1,115 Bytes
ada63c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from typing import List,Tuple
import torch
from src.config import cfg
from src.vocab import encode_text

def ctc_collate(batch: List[Tuple[torch.Tensor, str, str]]):
    """Collate ``(image, label, rel_path)`` samples into a CTC-ready batch.

    Args:
        batch: Samples as ``(image_tensor, label_str, rel_path)``. The image
            tensors are combined with ``torch.stack``, so they must all share
            one shape (expected ``[C, H, W_max]``).

    Returns:
        A tuple ``(images, targets_flat, target_lengths, input_lengths,
        rel_paths)`` where:

        * ``images`` is the stacked batch, ``[B, C, H, W_max]``.
        * ``targets_flat`` is every encoded label concatenated in batch
          order, ``[sum(len(label_i))]``, dtype ``long``.
        * ``target_lengths`` holds the encoded length of each label, ``[B]``.
        * ``input_lengths`` is ``W_max // cfg.total_stride`` repeated for
          every sample, ``[B]``.
        * ``rel_paths`` is the list of relative paths, in batch order.
    """
    images = torch.stack([sample[0] for sample in batch], dim=0)

    # Encode each label separately so its individual length can be recorded
    # before the labels are flattened into one 1-D tensor.
    texts = [sample[1] for sample in batch]
    label_ids = [
        torch.tensor(encode_text(text), dtype=torch.long) for text in texts
    ]
    target_lengths = torch.tensor(list(map(len, label_ids)), dtype=torch.long)
    targets_flat = (
        torch.cat(label_ids, dim=0)
        if label_ids
        else torch.empty(0, dtype=torch.long)
    )

    # Unpacking four dims also enforces that the stacked batch is 4-D.
    batch_size, _channels, _height, width = images.shape

    # Every sample shares the same width, so one length value is broadcast.
    # NOTE(review): presumably cfg.total_stride is the model's horizontal
    # downsampling factor -- confirm against the network definition.
    frames = width // cfg.total_stride
    input_lengths = torch.full((batch_size,), frames, dtype=torch.long)

    rel_paths = [sample[2] for sample in batch]
    return images, targets_flat, target_lengths, input_lengths, rel_paths