"""
CLPRNet with PARSeq Tiny as the OCR recognition backbone.

Architecture:
- Detection: Original CLPRNet shared backbone (FPN) + detection head (unchanged)
- Recognition: PARSeq Tiny (ViT encoder + autoregressive decoder)

Integration approach:
During training, ground-truth bounding boxes are used to crop plate regions
from the input image via grid_sample (differentiable with respect to the image
pixels; the box coordinates themselves are treated as constants). These crops
are resized to (32, 128) and fed to PARSeq Tiny, which outputs character logits.

During inference, the detection head produces bounding boxes (NMS is done by
the caller), then crops are extracted and fed to PARSeq Tiny for recognition.

The 8-channel per-character spatial attention maps (at_ch) are removed since
PARSeq handles character-level attention internally via its Transformer
decoder. Only at_lp (license plate region attention) is kept to assist
detection.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from functools import partial
from itertools import permutations


# ============================================================================
# Original CLPRNet building blocks (unchanged)
# ============================================================================

class SE(nn.Module):
    """Squeeze-and-excitation gate built from average- and max-pooled descriptors.

    The two pooled (B, C) vectors are concatenated to (B, 2C), passed through a
    bottleneck MLP ending in a sigmoid, and used to rescale the input channels.
    """

    def __init__(self, in_channel, reduction=16):
        super(SE, self).__init__()
        self.avepool = nn.AdaptiveAvgPool2d(1)
        self.maxpool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channel * 2, in_channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channel // reduction, in_channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        ax = self.avepool(x).view(x.size(0), -1)
        mx = self.maxpool(x).view(x.size(0), -1)
        se = torch.concat([ax, mx], dim=1)  # (B, 2C)
        out = self.fc(se)
        out = out.view(out.size(0), out.size(1), 1, 1)  # per-channel gate
        return out * x


class BasicBlock(nn.Module):
    """Residual block: two 3x3 convs + projection shortcut when shape changes."""

    def __init__(self, in_channels, out_channels, down_simple=1):
        super(BasicBlock, self).__init__()
        self.feature = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=3, stride=down_simple, padding=1),
            nn.BatchNorm2d(num_features=out_channels),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                      kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=out_channels),
        )
        self.resize = nn.Sequential()
        if down_simple > 1 or in_channels != out_channels:
            # Shortcut must match the main path's channels and stride.
            self.resize = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                          kernel_size=down_simple, stride=down_simple),
                nn.BatchNorm2d(num_features=out_channels),
            )

    def forward(self, x):
        f = self.feature(x)
        x = self.resize(x)
        y = F.leaky_relu(x + f)
        return y


class SEBasicBlock(nn.Module):
    """BasicBlock whose residual branch is re-weighted by an SE gate."""

    def __init__(self, in_channels, out_channels, down_simple=1, reduction=16):
        super(SEBasicBlock, self).__init__()
        self.feature = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=3, stride=down_simple, padding=1),
            nn.BatchNorm2d(num_features=out_channels),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                      kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=out_channels),
        )
        self.se = SE(in_channel=out_channels, reduction=reduction)
        self.resize = nn.Sequential()
        if down_simple > 1 or in_channels != out_channels:
            self.resize = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                          kernel_size=down_simple, stride=down_simple),
                nn.BatchNorm2d(num_features=out_channels),
            )

    def forward(self, x):
        f = self.feature(x)
        se = self.se(f)
        x = self.resize(x)
        y = F.leaky_relu(x + se)
        return y


# ============================================================================
# PARSeq Tiny - Minimal self-contained implementation
# Based on baudm/parseq (ECCV 2022)
# ============================================================================

class PatchEmbed(nn.Module):
    """Image to Patch Embedding for PARSeq (non-overlapping conv patches + LayerNorm)."""

    def __init__(self, img_size=(32, 128), patch_size=(4, 8), in_chans=3, embed_dim=192):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.proj(x)  # (B, embed_dim, H/ph, W/pw)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
        x = self.norm(x)
        return x


class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection."""

    def __init__(self, dim, num_heads=3, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MLP(nn.Module):
    """Feed-forward network (Linear -> GELU -> Linear, with dropout)."""

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class EncoderBlock(nn.Module):
    """Pre-norm Transformer encoder block."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias,
                              attn_drop=attn_drop, proj_drop=drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio), drop=drop)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class DecoderBlock(nn.Module):
    """Pre-norm Transformer decoder block: self-attention, cross-attention, MLP."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.self_attn = nn.MultiheadAttention(dim, num_heads, dropout=attn_drop, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, dropout=attn_drop, batch_first=True)
        self.norm_mem = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio), drop=drop)

    def forward(self, tgt, memory, tgt_mask=None, tgt_key_padding_mask=None):
        """Args:
            tgt: (B, L, dim) decoder input
            memory: (B, S, dim) encoder output
            tgt_mask: (L, L) self-attention mask (bool: True = blocked)
            tgt_key_padding_mask: (B, L) bool, True marks padded target keys
        """
        tgt2 = self.norm1(tgt)
        tgt2, _ = self.self_attn(tgt2, tgt2, tgt2, attn_mask=tgt_mask,
                                 key_padding_mask=tgt_key_padding_mask)
        tgt = tgt + tgt2
        tgt2 = self.norm2(tgt)
        mem = self.norm_mem(memory)
        tgt2, _ = self.cross_attn(tgt2, mem, mem)
        tgt = tgt + tgt2
        tgt = tgt + self.mlp(self.norm3(tgt))
        return tgt


class Tokenizer:
    """Minimal tokenizer for Chinese license plates.

    Vocab: 73 characters (31 provinces + 24 letters + 10 digits + 7 specials + 1 empty)
    Special tokens: [EOS]=73, [BOS]=74, [PAD]=75
    Head output: num_classes = 74 (73 charset + [EOS])
    """

    # Chinese LP character set
    CHARSET = ["京", "津", "冀", "晋", "蒙", "辽", "吉", "黑", "沪", "苏", "浙", "皖", "闽", "赣",
               "鲁", "豫", "鄂", "湘", "粤", "桂", "琼", "渝", "川", "贵", "云", "藏", "陕", "甘",
               "青", "宁", "新",
               "A", "B", "C", "D", "E", "F", "G", "H", "J", "K", "L", "M", "N", "P", "Q", "R",
               "S", "T", "U", "V", "W", "X", "Y", "Z",
               "0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
               "港", "澳", "使", "领", "学", "警", "挂", ""]

    def __init__(self):
        self.num_chars = len(self.CHARSET)  # 73
        self.eos_id = self.num_chars  # 73
        self.bos_id = self.num_chars + 1  # 74
        self.pad_id = self.num_chars + 2  # 75
        self.num_tokens = self.num_chars + 1  # 74 (charset + EOS), head output size
        # O(1) character -> id lookup (CHARSET contains no duplicates).
        self._char_to_id = {ch: idx for idx, ch in enumerate(self.CHARSET)}

    def encode(self, labels, max_length=8, device='cpu'):
        """Encode string labels to token indices.

        Labels longer than ``max_length`` are truncated. A character that is
        not in CHARSET leaves [PAD] at its position (a loss using
        ``ignore_index=pad_id`` would skip it); the EOS position still follows
        the (truncated) label length.

        Args:
            labels: list of plate strings
            max_length: max number of characters (8 for CN plates)
            device: target device

        Returns:
            targets: (B, max_length + 2) tensor [BOS, char1, ..., charN, EOS, PAD...]
        """
        batch_size = len(labels)
        targets = torch.full((batch_size, max_length + 2), self.pad_id,
                             dtype=torch.long, device=device)
        targets[:, 0] = self.bos_id
        for i, label in enumerate(labels):
            for j, ch in enumerate(label[:max_length]):
                idx = self._char_to_id.get(ch)
                if idx is not None:
                    targets[i, j + 1] = idx
            # EOS after last character
            end_pos = min(len(label), max_length) + 1
            targets[i, end_pos] = self.eos_id
        return targets

    def decode(self, logits):
        """Decode logits to plate strings (greedy argmax, stop at the first EOS).

        Args:
            logits: (B, L, num_tokens) where num_tokens = 74

        Returns:
            list of decoded plate strings
        """
        preds = logits.argmax(dim=-1).tolist()  # (B, L) as Python ints
        results = []
        for row in preds:
            chars = []
            for idx in row:
                if idx == self.eos_id:
                    break
                if idx < self.num_chars:
                    chars.append(self.CHARSET[idx])
            results.append(''.join(chars))
        return results


class PARSeqTiny(nn.Module):
    """Simplified PARSeq Tiny text recognizer (ViT encoder + Transformer decoder).

    NOTE(review): this is a minimal re-implementation, not the reference
    two-stream PARSeq decoder. Permutation-language-model training is not wired
    in (see ``generate_permutation_masks``) and ``refine_iters`` is stored but
    currently unused.

    Architecture (from DeiT-Ti configuration):
    - Encoder: ViT with embed_dim=192, 3 heads, 12 layers
    - Decoder: 1-layer Transformer decoder with 6 heads
    - Input: (B, 3, 32, 128) plate crops
    - Output: (B, max_label_length + 1, num_tokens) logits

    Decoding modes:
    - ``tgt`` given (training): teacher forcing under a causal mask, so output
      position i only sees inputs [BOS, y_1, ..., y_i] and predicts y_{i+1}.
    - ``tgt`` None and ``decode_ar=True`` (default): greedy autoregressive
      decoding that feeds back argmax predictions, matching the training-time
      conditioning.
    - ``tgt`` None and ``decode_ar=False``: a single non-autoregressive pass
      using only the learned positional queries.

    For Chinese LP recognition:
    - max_label_length = 8 (Chinese plates have 7-8 characters)
    - num_tokens = 74 (73 charset chars + EOS)
    """

    def __init__(self, max_label_length=8, num_tokens=74, img_size=(32, 128),
                 patch_size=(4, 8), embed_dim=192, enc_num_heads=3, enc_depth=12,
                 dec_num_heads=6, dec_depth=1, mlp_ratio=4., dropout=0.1,
                 decode_ar=True, refine_iters=1):
        super().__init__()
        self.max_label_length = max_label_length
        self.num_tokens = num_tokens
        self.embed_dim = embed_dim
        self.decode_ar = decode_ar
        self.refine_iters = refine_iters  # currently unused

        # Token ids, consistent with Tokenizer: ids below num_tokens - 1 are
        # characters, num_tokens - 1 is EOS (the last head output); BOS and PAD
        # are input-only ids outside the head's output range.
        self.bos_id = num_tokens
        self.eos_id = num_tokens - 1
        self.pad_id = num_tokens + 1

        # Encoder (ViT)
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size,
                                      in_chans=3, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=dropout)
        self.encoder_blocks = nn.ModuleList([
            EncoderBlock(dim=embed_dim, num_heads=enc_num_heads, mlp_ratio=mlp_ratio,
                         drop=dropout, attn_drop=dropout)
            for _ in range(enc_depth)
        ])
        self.encoder_norm = nn.LayerNorm(embed_dim)

        # Decoder
        self.token_embed = nn.Embedding(num_tokens + 2, embed_dim)  # +2 for BOS and PAD
        self.pos_queries = nn.Parameter(torch.zeros(1, max_label_length + 1, embed_dim))  # +1 for EOS position
        self.decoder_blocks = nn.ModuleList([
            DecoderBlock(dim=embed_dim, num_heads=dec_num_heads, mlp_ratio=mlp_ratio,
                         drop=dropout, attn_drop=dropout)
            for _ in range(dec_depth)
        ])
        self.decoder_norm = nn.LayerNorm(embed_dim)

        # Classification head
        self.head = nn.Linear(embed_dim, num_tokens)

        # Initialize weights
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.pos_queries, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    @staticmethod
    def _causal_mask(length, device):
        """(length, length) bool mask; True above the diagonal blocks future positions."""
        return torch.triu(torch.ones(length, length, dtype=torch.bool, device=device), diagonal=1)

    def encode(self, img):
        """Encode image patches.

        Args:
            img: (B, 3, 32, 128) normalized plate image

        Returns:
            memory: (B, num_patches, embed_dim)
        """
        x = self.patch_embed(img)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.encoder_blocks:
            x = blk(x)
        x = self.encoder_norm(x)
        return x

    def decode(self, memory, tgt=None, tgt_mask=None, tgt_key_padding_mask=None):
        """Decode from encoder memory.

        Args:
            memory: (B, num_patches, embed_dim) encoder output
            tgt: (B, L) target token indices (teacher forcing / AR prefix), with
                L <= max_label_length + 1, or None for the positional-query pass
            tgt_mask: self-attention mask. When ``tgt`` is given and this is
                None, a causal mask is built so no position can see later tokens.
            tgt_key_padding_mask: optional (B, L) bool padding mask

        Returns:
            logits: (B, L, num_tokens)
        """
        B = memory.shape[0]
        if tgt is not None:
            # Teacher forcing: embed the target tokens and add positional queries
            tgt_emb = self.token_embed(tgt)  # (B, L, embed_dim)
            tgt_emb = tgt_emb + self.pos_queries[:, :tgt_emb.shape[1], :]
            if tgt_mask is None:
                # Without this, position i could attend to input i+1, which is
                # exactly the token it is trained to predict (label leakage).
                tgt_mask = self._causal_mask(tgt_emb.shape[1], tgt_emb.device)
        else:
            # Non-autoregressive pass: learned positional queries only
            tgt_emb = self.pos_queries.expand(B, -1, -1)  # (B, L, embed_dim)

        x = tgt_emb
        for blk in self.decoder_blocks:
            x = blk(x, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask)
        x = self.decoder_norm(x)
        logits = self.head(x)  # (B, L, num_tokens)
        return logits

    def _decode_autoregressive(self, memory):
        """Greedy left-to-right decoding that feeds back argmax predictions.

        Returns (B, max_label_length + 1, num_tokens) logits, one step per slot.
        """
        B = memory.shape[0]
        num_steps = self.max_label_length + 1
        tokens = torch.full((B, 1), self.bos_id, dtype=torch.long, device=memory.device)
        step_logits = []
        for step in range(num_steps):
            logits = self.decode(memory, tokens)  # causal mask built inside decode
            last = logits[:, -1:, :]  # prediction for the next slot
            step_logits.append(last)
            if step + 1 < num_steps:
                # torch.cat (not in-place writes) keeps earlier index tensors
                # intact for autograd when gradients are enabled.
                tokens = torch.cat([tokens, last.argmax(dim=-1)], dim=1)
        return torch.cat(step_logits, dim=1)

    def forward(self, img, tgt=None):
        """Full forward pass.

        Args:
            img: (B, 3, 32, 128) plate crops, normalized to [-1, 1] or ImageNet stats
            tgt: (B, L) teacher-forcing inputs [BOS, c1, ..., cN] (EOS excluded,
                it is predicted), or None for inference

        Returns:
            logits: (B, max_label_length+1, num_tokens=74)
        """
        memory = self.encode(img)
        if tgt is not None:
            return self.decode(memory, tgt)
        if self.decode_ar:
            return self._decode_autoregressive(memory)
        # Legacy non-autoregressive pass (single pass with positional queries)
        return self.decode(memory)

    def generate_permutation_masks(self, max_length, num_perms=6):
        """Generate character orderings for permutation-based training.

        Despite the name, this returns a list of ``num_perms`` 1-D permutation
        tensors of ``range(max_length)`` (L2R, R2L, then random; random ones may
        repeat), not attention masks. It is not used elsewhere in this module.
        """
        # Standard left-to-right + right-to-left + random permutations
        perms = [torch.arange(max_length)]  # L2R
        perms.append(torch.arange(max_length - 1, -1, -1))  # R2L
        # Random permutations
        for _ in range(num_perms - 2):
            perm = torch.randperm(max_length)
            perms.append(perm)
        return perms


# ============================================================================
# CLPRNet with PARSeq Tiny backbone
# ============================================================================

class CLPRNetPARSeq(nn.Module):
    """CLPRNet with PARSeq Tiny replacing the CNN-based recognition branch.

    Changes from original CLPRNet:
    - REMOVED: self.recognition (4x SEBasicBlock CNN)
    - REMOVED: self.recognition_head (Conv2d 256->73)
    - REMOVED: 8-channel character attention maps (at_ch)
    - REMOVED: 8-branch attention-masked batching logic
    - ADDED: PARSeq Tiny for plate text recognition
    - ADDED: Plate cropping via grid_sample (differentiable w.r.t. the image)
    - MODIFIED: at_head now outputs 1 channel only (at_lp)

    Forward pass:
    1. Shared backbone extracts features -> detection head -> boxes
    2. Plate regions are cropped from input image using GT/predicted boxes
    3. Crops are resized to (32, 128) and normalized for PARSeq
    4. PARSeq Tiny produces character logits
    """

    def __init__(self, max_label_length=8, parseq_pretrained_path=None):
        super(CLPRNetPARSeq, self).__init__()
        self.max_label_length = max_label_length
        self.tokenizer = Tokenizer()

        # --- Shared Backbone (FPN) - UNCHANGED ---
        self.feature = nn.Sequential(
            BasicBlock(in_channels=3, out_channels=4),
            BasicBlock(in_channels=4, out_channels=16, down_simple=2),
            BasicBlock(in_channels=16, out_channels=16),
            BasicBlock(in_channels=16, out_channels=16),
            BasicBlock(in_channels=16, out_channels=64, down_simple=2),
            BasicBlock(in_channels=64, out_channels=64),
            BasicBlock(in_channels=64, out_channels=64),
        )
        self.feature_128 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(num_features=64),
            nn.LeakyReLU(),
        )
        self.feature_64 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(num_features=128),
            nn.LeakyReLU(),
        )
        self.feature_32 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(num_features=128),
            nn.LeakyReLU(),
        )
        self.feature_up_64 = nn.Sequential(
            nn.Conv2d(in_channels=(128 + 128), out_channels=64, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(num_features=64),
            nn.LeakyReLU(),
        )
        self.feature_up_128 = nn.Sequential(
            nn.Conv2d(in_channels=(64 + 64), out_channels=64, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(num_features=64),
            nn.LeakyReLU(),
        )
        self.feature_up_256 = nn.Sequential(
            nn.Conv2d(in_channels=(64 + 64), out_channels=32, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(num_features=32),
            nn.LeakyReLU(),
        )

        # --- Attention Head: only LP attention (1 channel instead of 9) ---
        self.at_head = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=16, kernel_size=1),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1),
            nn.Sigmoid(),
        )

        # --- Detection Branch - UNCHANGED ---
        self.detection = nn.Sequential(
            SEBasicBlock(in_channels=64, out_channels=64, down_simple=2, reduction=2),
            SEBasicBlock(in_channels=64, out_channels=64, reduction=2),
            SEBasicBlock(in_channels=64, out_channels=128, down_simple=2, reduction=4),
            SEBasicBlock(in_channels=128, out_channels=128, reduction=2),
            SEBasicBlock(in_channels=128, out_channels=128, down_simple=1, reduction=4),
            SEBasicBlock(in_channels=128, out_channels=128, reduction=2),
        )
        self.detection_head = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=32, kernel_size=1),
            nn.BatchNorm2d(num_features=32),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=5, kernel_size=1),
            nn.BatchNorm2d(num_features=5),
            nn.Sigmoid(),
        )

        # --- PARSeq Tiny Recognition Backbone ---
        self.parseq = PARSeqTiny(
            max_label_length=max_label_length,
            num_tokens=self.tokenizer.num_tokens,  # 74 (73 chars + EOS)
            img_size=(32, 128),
            patch_size=(4, 8),
            embed_dim=192,
            enc_num_heads=3,
            enc_depth=12,
            dec_num_heads=6,
            dec_depth=1,
            mlp_ratio=4.,
            dropout=0.1,
        )

        # Load pretrained PARSeq weights if available
        if parseq_pretrained_path is not None:
            self._load_parseq_pretrained(parseq_pretrained_path)

    def _load_parseq_pretrained(self, path):
        """Partially load pretrained PARSeq Tiny weights into ``self.parseq``.

        ``head.*`` and ``token_embed.*`` are always dropped because the charset
        differs. Tensors whose shape differs from this model's (e.g.
        ``pos_queries`` for another max label length) are dropped too, because
        ``load_state_dict(strict=False)`` still raises on shape mismatches.
        A ``{'state_dict': ...}`` wrapper and a ``model.`` key prefix are
        unwrapped if present.
        """
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(path, map_location='cpu')
        state_dict = checkpoint
        if isinstance(checkpoint, dict) and isinstance(checkpoint.get('state_dict'), dict):
            state_dict = checkpoint['state_dict']

        own_state = self.parseq.state_dict()
        filtered = {}
        shape_skipped = []
        for key, value in state_dict.items():
            if key.startswith('model.'):
                key = key[len('model.'):]
            # Filter out head weights since our charset is different
            if key.startswith(('head.', 'token_embed.')):
                continue
            if key in own_state and tuple(own_state[key].shape) != tuple(getattr(value, 'shape', ())):
                shape_skipped.append(key)
                continue
            filtered[key] = value

        missing, unexpected = self.parseq.load_state_dict(filtered, strict=False)
        print(f"PARSeq pretrained loaded. Missing: {len(missing)}, Unexpected: {len(unexpected)}")
        if shape_skipped:
            print(f"PARSeq pretrained: skipped {len(shape_skipped)} tensors with mismatched shapes")

    def crop_plates(self, images, boxes, img_size):
        """Crop plate regions with grid_sample and resize them to (32, 128).

        The crop is differentiable with respect to ``images`` only; the box
        coordinates are converted to Python floats, so no gradient reaches them.

        Args:
            images: (B, 3, H, W) input images
            boxes: (B, N, 4) boxes [l, t, r, b] in pixel coords, or a list of
                (N_i, 4) tensors with a variable number of plates per image.
                Boxes with r <= l or b <= t (e.g. zero padding) are skipped.
            img_size: (H, W) image size

        Returns:
            crops: (total_plates, 3, 32, 128) plate crops ready for PARSeq, or a
                single all-zero dummy crop when no box is valid
            plate_counts: number of kept plates per image in batch
        """
        crops, plate_counts, _, _ = self._crop_plates_with_indices(images, boxes, img_size)
        return crops, plate_counts

    @staticmethod
    def _crop_plates_with_indices(images, boxes, img_size):
        """Crop all valid boxes and report which boxes were kept.

        Returns:
            (crops, plate_counts, kept_indices, num_boxes), where ``kept_indices``
            are flat indices (image-major, then box order) of the kept boxes and
            ``num_boxes`` is the total number of boxes examined.
        """
        H, W = img_size
        crops = []
        plate_counts = []
        kept_indices = []
        num_boxes = 0
        for img_idx in range(images.shape[0]):
            box_set = boxes[img_idx]  # (N, 4), or (4,) for a single box
            if box_set.dim() == 1:
                box_set = box_set.unsqueeze(0)

            count = 0
            for box in box_set:
                flat_idx = num_boxes
                num_boxes += 1
                l, t, r, b = box
                # Skip invalid (empty / inverted / padding) boxes
                if r <= l or b <= t:
                    continue

                # Map pixel coordinates to grid_sample's [-1, 1] range.
                # NOTE(review): with align_corners=True, -1/+1 are the centres of the
                # first/last pixel, so the exact mapping would divide by (W - 1) and
                # (H - 1); the W/H divisor is kept as-is (error is at most ~1 pixel).
                x1 = (l / W) * 2 - 1
                x2 = (r / W) * 2 - 1
                y1 = (t / H) * 2 - 1
                y2 = (b / H) * 2 - 1

                # Sampling grid for the (32, 128) output
                grid_x = torch.linspace(x1.item(), x2.item(), 128, device=images.device)
                grid_y = torch.linspace(y1.item(), y2.item(), 32, device=images.device)
                grid_yy, grid_xx = torch.meshgrid(grid_y, grid_x, indexing='ij')
                grid = torch.stack([grid_xx, grid_yy], dim=-1).unsqueeze(0)  # (1, 32, 128, 2)
                grid = grid.to(images.dtype)  # grid_sample needs matching dtypes

                crop = F.grid_sample(images[img_idx:img_idx + 1], grid, mode='bilinear',
                                     padding_mode='zeros', align_corners=True)
                crops.append(crop.squeeze(0))  # (3, 32, 128)
                kept_indices.append(flat_idx)
                count += 1
            plate_counts.append(count)

        if not crops:
            # Return dummy crop if no valid boxes found
            dummy = torch.zeros(1, 3, 32, 128, device=images.device, dtype=images.dtype)
            return dummy, [0] * images.shape[0], kept_indices, num_boxes

        return torch.stack(crops, dim=0), plate_counts, kept_indices, num_boxes

    @staticmethod
    def _align_plate_labels(plate_labels, kept_indices, num_boxes):
        """Return the labels matching the kept crops, in crop order.

        ``plate_labels`` may hold one label per examined box (then the skipped
        boxes' labels are dropped) or one label per kept box (used as-is).

        Raises:
            ValueError: if the label count matches neither convention.
        """
        labels = list(plate_labels)
        if len(labels) == num_boxes:
            return [labels[k] for k in kept_indices]
        if len(labels) == len(kept_indices):
            return labels
        raise ValueError(
            f"plate_labels has {len(labels)} entries, but there are {num_boxes} boxes "
            f"({len(kept_indices)} valid)"
        )

    def normalize_crops_for_parseq(self, crops):
        """Normalize plate crops for PARSeq input.

        PARSeq expects: mean=0.5, std=0.5 (maps [0,1] to [-1,1])
        CLPRNet input is ImageNet normalized: mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        We first de-normalize from ImageNet stats, then re-normalize for PARSeq.
        """
        # De-normalize from ImageNet
        imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=crops.device).view(1, 3, 1, 1)
        imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=crops.device).view(1, 3, 1, 1)
        crops = crops * imagenet_std + imagenet_mean  # back to [0, 1]

        # Normalize for PARSeq (mean=0.5, std=0.5)
        crops = (crops - 0.5) / 0.5  # to [-1, 1]
        return crops

    def forward(self, x, boxes_lurd=None, plate_labels=None):
        """Forward pass.

        Args:
            x: (B, 3, H, W) input images (ImageNet normalized)
            boxes_lurd: Ground truth boxes for training, (B, N, 4) or a list of
                (N_i, 4) tensors, [l, t, r, b] in pixels. If None, only detection
                outputs are returned (recognition is done externally after NMS).
            plate_labels: plate strings for teacher forcing (one per box, or one
                per valid box), or None. Used only in training mode.

        Returns:
            y_detection: (B, H/16, W/16, 5) detection output [l, t, r, b, conf]
                (e.g. (B, 64, 64, 5) for 1024x1024 input)
            y_recognition: (total_plates, max_label_length+1, 74) PARSeq logits, or None
            at_lp: (B, 1, H/4, W/4) license plate attention map
            plate_counts: list of int, number of plates per image

        Raises:
            ValueError: if ``plate_labels`` cannot be aligned with the boxes.
        """
        B = x.shape[0]
        H, W = x.shape[2], x.shape[3]

        # --- Shared backbone (FPN) ---
        x_256 = self.feature(x)
        x_128 = self.feature_128(x_256)
        x_64 = self.feature_64(x_128)
        x_32 = self.feature_32(x_64)

        x_up_64 = self.feature_up_64(torch.concat(
            [x_64, F.interpolate(x_32, size=x_64.shape[2:], mode='nearest')], dim=1))
        x_up_128 = self.feature_up_128(torch.concat(
            [x_128, F.interpolate(x_up_64, size=x_128.shape[2:], mode='nearest')], dim=1))
        x_up_256 = self.feature_up_256(torch.concat(
            [x_256, F.interpolate(x_up_128, size=x_256.shape[2:], mode='nearest')], dim=1))

        # --- Attention map (LP only, no per-char attention) ---
        at_lp = self.at_head(x_up_256)  # (B, 1, H/4, W/4)

        # --- Detection ---
        y_detection = self.detection(x_256)
        y_detection = self.detection_head(y_detection)
        y_detection = y_detection.permute(0, 2, 3, 1)  # (B, H/16, W/16, 5)

        # --- Recognition via PARSeq Tiny ---
        if boxes_lurd is None:
            # Inference: boxes come from the detection head after NMS, which is
            # handled externally; return detection only.
            return y_detection, None, at_lp, [0] * B

        # Training: use GT boxes to crop plates
        crops, plate_counts, kept_indices, num_boxes = self._crop_plates_with_indices(
            x, boxes_lurd, (H, W))
        crops = self.normalize_crops_for_parseq(crops)

        if plate_labels is not None and self.training and kept_indices:
            # Teacher forcing with target tokens aligned to the kept crops
            labels = self._align_plate_labels(plate_labels, kept_indices, num_boxes)
            tgt = self.tokenizer.encode(labels, max_length=self.max_label_length, device=x.device)
            # Decoder input: BOS + target chars (drop the last slot; the model predicts tgt[:, 1:])
            tgt_input = tgt[:, :-1]  # (total_plates, max_label_length + 1)
            y_recognition = self.parseq(crops, tgt_input)
        else:
            # No teacher forcing (eval mode, no labels, or only the dummy crop)
            y_recognition = self.parseq(crops)

        return y_detection, y_recognition, at_lp, plate_counts

    def recognize_plates(self, images, boxes):
        """Convenience method for inference: crop and recognize plates.

        Invalid boxes (r <= l or b <= t) are skipped, so the outputs follow the
        order of the valid boxes only.

        Args:
            images: (B, 3, H, W) ImageNet-normalized images
            boxes: list of (N_i, 4) tensors, each row is [l, t, r, b] in pixels

        Returns:
            plate_texts: list of strings
            confidences: list of float confidence scores (product of the max
                softmax probability of every step before the first EOS; 0.0 if
                EOS is predicted first)
        """
        H, W = images.shape[2], images.shape[3]
        with torch.no_grad():
            crops, plate_counts = self.crop_plates(images, boxes, (H, W))
            if sum(plate_counts) == 0:
                return [], []
            crops = self.normalize_crops_for_parseq(crops)
            logits = self.parseq(crops)  # (N, max_len+1, 74)

        # Decode
        plate_texts = self.tokenizer.decode(logits)

        # Confidence: product of max softmax probs before EOS
        probs = logits.softmax(dim=-1)
        max_probs = probs.max(dim=-1).values  # (N, L)
        preds = logits.argmax(dim=-1)  # (N, L)
        confidences = []
        for i in range(probs.shape[0]):
            eos_pos = (preds[i] == self.tokenizer.eos_id).nonzero(as_tuple=True)[0]
            end = eos_pos[0].item() if len(eos_pos) > 0 else probs.shape[1]
            conf = max_probs[i, :end].prod().item() if end > 0 else 0.0
            confidences.append(conf)

        return plate_texts, confidences


# ============================================================================
# Convenience function to create model
# ============================================================================

def create_clprnet_parseq(max_label_length=8, parseq_pretrained_path=None):
    """Create CLPRNet with PARSeq Tiny backbone.

    Args:
        max_label_length: Maximum plate string length (8 for Chinese plates)
        parseq_pretrained_path: Path to pretrained PARSeq Tiny weights (optional)

    Returns:
        CLPRNetPARSeq model
    """
    model = CLPRNetPARSeq(
        max_label_length=max_label_length,
        parseq_pretrained_path=parseq_pretrained_path,
    )
    return model


if __name__ == '__main__':
    # Quick test
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = create_clprnet_parseq().to(device)

    # Simulate input
    B = 2
    x = torch.rand((B, 3, 1024, 1024)).to(device)

    # Simulate GT boxes: 1 plate per image, [l, t, r, b] in pixels
    boxes = [
        torch.tensor([[200, 400, 500, 480]], dtype=torch.float32, device=device),
        torch.tensor([[300, 500, 600, 580]], dtype=torch.float32, device=device),
    ]
    plate_labels = ["京A12345", "沪B67890"]

    # Training forward pass
    model.train()
    y_det, y_rec, at_lp, plate_counts = model(x, boxes_lurd=boxes, plate_labels=plate_labels)
    print(f"Detection output: {y_det.shape}")  # (2, 64, 64, 5)
    print(f"Recognition output: {y_rec.shape}")  # (2, 9, 74)
    print(f"Attention map: {at_lp.shape}")  # (2, 1, 256, 256)
    print(f"Plate counts: {plate_counts}")  # [1, 1]

    # Inference
    model.eval()
    with torch.no_grad():
        y_det_inf, _, at_lp_inf, _ = model(x)
        print(f"\nInference detection: {y_det_inf.shape}")

        # Test recognize_plates
        plates, confs = model.recognize_plates(x, boxes)
        print(f"Recognized plates: {plates}")
        print(f"Confidences: {confs}")

    # Parameter count comparison
    total_params = sum(p.numel() for p in model.parameters())
    parseq_params = sum(p.numel() for p in model.parseq.parameters())
    det_params = total_params - parseq_params
    print(f"\nTotal parameters: {total_params:,}")
    print(f"PARSeq Tiny parameters: {parseq_params:,}")
    print(f"Detection backbone parameters: {det_params:,}")