"""
Training script for CLPRNet with PARSeq Tiny backbone.

Changes from original train.py:
- Uses CLPRNetPARSeq model instead of CLPRNet
- Recognition loss uses PARSeq's cross-entropy on sequence output
- No character attention maps (at_ch removed)
- Ground-truth boxes are passed to the model for plate cropping during training
- PARSeq uses teacher forcing during training
"""
from common import BaseExperiment
from model_parseq import CLPRNetPARSeq as Model, Tokenizer
from dataset import MyDataset, CHARACTER
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
import os
import matplotlib.pyplot as plt
import numpy as np
import random
from utils import provinces1, provinces2


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and every CUDA device) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def IoU_multi(pred_boxes, target_boxes, eps=1e-6):
    """Element-wise IoU between two broadcastable sets of ``[x1, y1, x2, y2]`` boxes.

    The last dimension of each input holds the 4 coordinates. The result keeps
    the leading shape and has a trailing dimension of size 1. ``eps`` keeps the
    division finite for degenerate (zero-area) boxes.
    """
    pred_x1, pred_y1, pred_x2, pred_y2 = torch.split(pred_boxes, 1, dim=-1)
    target_x1, target_y1, target_x2, target_y2 = torch.split(target_boxes, 1, dim=-1)
    inter_x1 = torch.max(pred_x1, target_x1)
    inter_y1 = torch.max(pred_y1, target_y1)
    inter_x2 = torch.min(pred_x2, target_x2)
    inter_y2 = torch.min(pred_y2, target_y2)
    inter_area = torch.clamp(inter_x2 - inter_x1, min=0) * torch.clamp(inter_y2 - inter_y1, min=0)
    pred_area = (pred_x2 - pred_x1) * (pred_y2 - pred_y1)
    target_area = (target_x2 - target_x1) * (target_y2 - target_y1)
    union_area = pred_area + target_area - inter_area
    iou = inter_area / (union_area + eps)
    return iou


def IOU(box, other_boxes):
    """IoU between one ``[x1, y1, x2, y2]`` box and every row of ``other_boxes`` (shape ``(K, 4)``)."""
    box_area = (box[2] - box[0]) * (box[3] - box[1])
    other_boxes_area = (other_boxes[:, 2] - other_boxes[:, 0]) * (other_boxes[:, 3] - other_boxes[:, 1])
    x1 = torch.max(box[0], other_boxes[:, 0])
    y1 = torch.max(box[1], other_boxes[:, 1])
    x2 = torch.min(box[2], other_boxes[:, 2])
    y2 = torch.min(box[3], other_boxes[:, 3])
    Min = torch.zeros(1, device=box.device)
    w, h = torch.max(Min, x2 - x1), torch.max(Min, y2 - y1)
    overlap_area = w * h
    iou = overlap_area / (box_area + other_boxes_area - overlap_area + 1e-6)
    return iou


def NMS(boxes, C=0.5):
    """Greedy non-maximum suppression.

    ``boxes`` is a 2-D tensor whose column 0 is the score and columns 1-4 are
    ``[l, t, r, b]``. Rows are visited in descending score order; after a row is
    kept, every remaining row whose IoU with it is >= ``C`` is discarded.

    Returns a stacked tensor of kept rows, or an empty ``list`` for empty input.
    """
    if len(boxes) == 0:
        return []
    sort_boxes = boxes[boxes[:, 0].argsort(descending=True)]
    keep = []
    while len(sort_boxes) > 0:
        ref_box = sort_boxes[0]
        keep.append(ref_box)
        if len(sort_boxes) > 1:
            other_boxes = sort_boxes[1:]
            sort_boxes = other_boxes[torch.where(IOU(ref_box[1:5], other_boxes[:, 1:5]) < C)]
        else:
            break
    return torch.stack(keep)


class BCEFocalLoss(torch.nn.Module):
    """Binary focal loss on probabilities.

    Positives are weighted by ``(1 - alpha) * (1 - p) ** gamma``. Negatives are
    weighted by ``alpha * p ** gamma``.
    """

    def __init__(self, gamma=2, alpha=0.25, reduction='mean'):
        super(BCEFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, predict, target):
        # `predict` is used directly inside log(), so it must already be a probability.
        pt = predict
        loss = -((1 - self.alpha) * ((1 - pt + 1e-5) ** self.gamma) * (target * torch.log(pt + 1e-5))
                 + self.alpha * ((pt + 1e-5) ** self.gamma) * ((1 - target) * torch.log(1 - pt + 1e-5)))
        if self.reduction == 'mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss


class BCEWithWeightLoss(torch.nn.Module):
    """Binary cross-entropy on probabilities with separate class weights.

    ``weight[0]`` scales the negative (target == 0) term and ``weight[1]`` scales
    the positive (target == 1) term.
    """

    def __init__(self, weight=(1, 1), reduction='mean'):
        # Tuple default: an immutable default avoids the shared-mutable-argument pitfall.
        super(BCEWithWeightLoss, self).__init__()
        self.weight = weight
        self.reduction = reduction

    def forward(self, inputs, target):
        loss = -(self.weight[1] * target * torch.log(inputs + 1e-7)
                 + self.weight[0] * (1 - target) * torch.log(1 - inputs + 1e-7))
        if self.reduction == 'mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss


def _iter_lp_lurd(lp_lurd_str):
    """Yield ``(lurd_vals, lp_indices)`` for every well-formed entry of an annotation string.

    Entries are separated by ``';'`` and look like ``"l,t,r,b-i1,i2,..."``.
    Entries without a ``'-'`` (e.g. the empty string a trailing ``';'`` produces)
    are skipped instead of raising.
    """
    for info in lp_lurd_str.split(';'):
        if '-' not in info:
            continue
        lurd_str, lp_str = info.split('-')
        lurd_vals = [int(v) for v in lurd_str.split(',')]
        lp_indices = [int(v) for v in lp_str.split(',')]
        yield lurd_vals, lp_indices


def _is_valid_lurd(lurd_vals):
    """True when ``lurd_vals`` is ``[l, t, r, b]`` with strictly positive width and height."""
    return len(lurd_vals) == 4 and lurd_vals[2] > lurd_vals[0] and lurd_vals[3] > lurd_vals[1]


def _indices_to_plate(lp_indices):
    """Convert character indices to a plate string, dropping indices >= ``len(CHARACTER) - 1``.

    NOTE(review): the last CHARACTER entry is presumably a blank/padding symbol;
    confirm against dataset.CHARACTER.
    """
    return ''.join(CHARACTER[idx] for idx in lp_indices if idx < len(CHARACTER) - 1)


def _valid_gt_plates(lp_lurd_str):
    """Return ``(boxes, labels)`` for the annotation entries that have a valid box.

    ``boxes`` is a list of ``[l, t, r, b]`` int lists. ``labels`` holds the
    matching plate strings, in the same order.
    """
    boxes, labels = [], []
    for lurd_vals, lp_indices in _iter_lp_lurd(lp_lurd_str):
        if _is_valid_lurd(lurd_vals):
            boxes.append(lurd_vals)
            labels.append(_indices_to_plate(lp_indices))
    return boxes, labels


class Experiment(BaseExperiment):
    """Training / validation hooks for CLPRNet with a PARSeq recognition head."""

    def __init__(self, **parameter) -> None:
        super().__init__(**parameter)
        # NOTE(review): LR_STEP_SIZE, MOMENTUM and BETAS are stored but not used by the
        # visible optimizer/scheduler setup (AdamW default betas, fixed milestones).
        self.LR_STEP_SIZE = parameter['LR_STEP_SIZE']
        self.LR_STEP_GAMMA = parameter['LR_STEP_GAMMA']
        self.MOMENTUM = parameter['MOMENTUM']
        self.WEIGHT_DECAY = parameter['WEIGHT_DECAY']
        self.BETAS = parameter['BETAS']
        self.file_name = ''
        # NOTE(review): `mask` is not defined in the visible part of this file. It is
        # presumably a module-level per-cell (x, y) coordinate grid defined elsewhere; confirm.
        self.x_mask = mask[:, :, 0].to(self.DEVICE).unsqueeze_(dim=2)
        self.y_mask = mask[:, :, 1].to(self.DEVICE).unsqueeze_(dim=2)
        self.tokenizer = Tokenizer()

    def _load_checkpoint_file(self):
        """Load ``self.CHECKPOINT`` (relative to ``self.WORKSPACE``) directly onto ``self.DEVICE``."""
        return torch.load(os.path.join(self.WORKSPACE, self.CHECKPOINT), map_location=self.DEVICE)

    def load_model(self, pretrain: str = None):
        """Build the model and optionally load pretrained weights and/or a resume checkpoint.

        ``pretrain`` is loaded non-strictly. A checkpoint restores ``'parameter'``
        and sets ``START_EPOCH`` to the saved ``'epoch'`` + 1.
        """
        self.model = Model(max_label_length=8)
        self.model = self.model.to(self.DEVICE)
        if pretrain:
            self.print(f"pretrain: {pretrain}")
            self.model.load_state_dict(
                torch.load(os.path.join(self.WORKSPACE, pretrain), map_location=self.DEVICE),
                strict=False,
            )
        if self.CHECKPOINT:
            # Read the checkpoint once instead of once per key.
            checkpoint = self._load_checkpoint_file()
            self.model.load_state_dict(checkpoint['parameter'])
            self.START_EPOCH = checkpoint['epoch'] + 1

    def load_optimizer(self):
        """AdamW with two parameter groups; PARSeq parameters get 0.1x the base LR."""
        parseq_params = list(self.model.parseq.parameters())
        # Split by parameter identity rather than name prefix. Otherwise a non-PARSeq
        # attribute whose name merely starts with "parseq" would be silently left
        # untrained, and tied parameters could appear in both groups.
        parseq_ids = {id(p) for p in parseq_params}
        det_params = [p for p in self.model.parameters() if id(p) not in parseq_ids]
        self.optimizer = torch.optim.AdamW([
            {'params': det_params, 'lr': self.LR},
            {'params': parseq_params, 'lr': self.LR * 0.1},  # Lower LR for PARSeq
        ], weight_decay=self.WEIGHT_DECAY)
        if self.CHECKPOINT:
            self.optimizer.load_state_dict(self._load_checkpoint_file()['optimizer'])

    def load_scheduler(self):
        """MultiStepLR decaying by ``LR_STEP_GAMMA`` at epochs 80, 85 and 90."""
        self.scheduler = lr_scheduler.MultiStepLR(self.optimizer, milestones=[80, 85, 90], gamma=self.LR_STEP_GAMMA)
        if self.CHECKPOINT:
            self.scheduler.load_state_dict(self._load_checkpoint_file()['scheduler'])

    def forward(self, data):
        """Run the model, passing ground-truth plate boxes and labels for each image.

        Images without any valid box get a single all-zero placeholder box, and
        no label is added for them.
        """
        img, lp_at, ch_at, bboxs, lp, lurds, lp_at_rec_facal, lp_lurd = data
        x = img.to(self.DEVICE)

        batch_boxes = []
        batch_plate_labels = []
        for b_idx in range(x.shape[0]):
            boxes_b, labels_b = _valid_gt_plates(lp_lurd[b_idx])
            if boxes_b:
                batch_boxes.append(torch.tensor(boxes_b, dtype=torch.float32, device=self.DEVICE))
            else:
                batch_boxes.append(torch.zeros((1, 4), dtype=torch.float32, device=self.DEVICE))
            batch_plate_labels.extend(labels_b)

        return self.model(x, boxes_lurd=batch_boxes,
                          plate_labels=batch_plate_labels if batch_plate_labels else None)

    def loss(self, data, pred):
        """Combined loss.

        The total is
        ``0.2 * location + confidence + 0.5 * recognition + 10 * attention``.
        """
        img, lp_at, ch_at, bboxs, lp, lurds, lp_at_rec_facal, lp_lurd = data
        y_detection, y_recognition, pred_at_lp, plate_counts = pred

        # --- Attention loss (LP attention only, no char attention) ---
        lp_at = lp_at.to(self.DEVICE).unsqueeze(dim=1)
        loss_at = BCEWithWeightLoss(weight=[0.1, 0.9])(pred_at_lp, lp_at)

        # --- Location loss: -log(IoU) over cells that hold an object ---
        bboxs = bboxs.to(self.DEVICE)
        lurds = lurds.to(self.DEVICE)
        obj = torch.sum(bboxs, dim=3) > 0
        noobj = torch.sum(bboxs, dim=3) == 0
        # Predicted distances are scaled by image width/height and offset from the cell grid.
        l, t, r, b = torch.split(y_detection[:, :, :, :4], 1, dim=-1)
        l = self.x_mask - l * img.shape[3]
        t = self.y_mask - t * img.shape[2]
        r = self.x_mask + r * img.shape[3]
        b = self.y_mask + b * img.shape[2]
        iou = IoU_multi(
            torch.flatten(lurds, start_dim=0, end_dim=2),
            torch.flatten(torch.concat([l, t, r, b], dim=3), start_dim=0, end_dim=2),
        )
        iou = iou.view(bboxs.shape[:3])
        loss_location = -torch.log(iou + 1e-6) * obj
        loss_location = torch.sum(loss_location) / (torch.sum(obj) + 1e-6)

        # --- Confidence loss: regress to detached IoU on object cells, to 0 elsewhere ---
        confidence_location = iou.detach().float()
        loss_confidence_location = (
            nn.MSELoss(reduction='none')(y_detection[:, :, :, 4], confidence_location) * obj
            + 0.1 * nn.MSELoss(reduction='none')(
                y_detection[:, :, :, 4], torch.zeros_like(confidence_location, device=self.DEVICE)) * noobj
        )
        loss_confidence_location = torch.mean(loss_confidence_location)

        # --- Recognition loss (PARSeq cross-entropy) ---
        loss_classify = torch.tensor(0.0, device=self.DEVICE)
        if y_recognition is not None and sum(plate_counts) > 0:
            batch_plate_labels = []
            for b_idx in range(img.shape[0]):
                batch_plate_labels.extend(_valid_gt_plates(lp_lurd[b_idx])[1])

            # NOTE(review): the recognition loss is skipped silently when the number of
            # model outputs differs from the number of labels. This may happen if the
            # model also decodes the zero placeholder boxes that `forward` adds; verify.
            if batch_plate_labels and y_recognition.shape[0] == len(batch_plate_labels):
                # Encode targets: [BOS, c1, ..., cN, EOS, PAD...]
                targets = self.tokenizer.encode(batch_plate_labels, max_length=8, device=self.DEVICE)
                # Drop BOS so targets line up with the predicted positions.
                targets_shifted = targets[:, 1:]
                # Assumed PARSeq output: (N, max_len + 1, num_tokens); flatten for CE.
                logits_flat = y_recognition.reshape(-1, y_recognition.shape[-1])
                targets_flat = targets_shifted.reshape(-1)
                # Exclude PAD positions from the loss.
                pad_mask = targets_flat != self.tokenizer.pad_id
                if pad_mask.any():
                    loss_classify = F.cross_entropy(
                        logits_flat[pad_mask],
                        targets_flat[pad_mask],
                        ignore_index=self.tokenizer.pad_id,
                    )

        # --- Total loss ---
        loss = 0.2 * loss_location + loss_confidence_location + 0.5 * loss_classify + 10 * loss_at
        return loss

    def before_val(self):
        """Reset validation counters."""
        self.count_iou = 0
        # Start empty: a seeded zero would bias the reported average IoU downwards.
        self.iou_list = []
        self.sample_num = 0
        self.pred_num = 1e-5  # tiny non-zero value keeps precision finite
        self.count_lp = 0
        self.count_iou_lp = 0

    def evaluate(self, data, pred):
        """Accumulate detection / recognition statistics for one validation batch.

        For each image, the steps are:
        1. Decode boxes from ``y_detection`` and keep those with confidence > 0.3.
        2. Apply NMS (IoU 0.3).
        3. Recognise the crops via ``model.recognize_plates``.
        4. Compare every (GT plate, detection) pair:
           - ``count_iou`` counts pairs with IoU > 0.7;
           - ``count_lp`` counts exact text matches irrespective of IoU;
           - ``count_iou_lp`` counts exact text matches with IoU > 0.6.
        """
        img, _, _, _, _, _, _, lp_lurd = data
        y_detection, _, pred_at_lp, _ = pred
        for index in range(y_detection.shape[0]):
            # Same parser as training, so empty entries (trailing ';') no longer crash.
            gt_plates = list(_iter_lp_lurd(lp_lurd[index]))
            self.sample_num += len(gt_plates)

            l, t, r, b, c = torch.split(y_detection[index, :, :, :5], 1, dim=-1)
            l = self.x_mask - l * img.shape[3]
            t = self.y_mask - t * img.shape[2]
            r = self.x_mask + r * img.shape[3]
            b = self.y_mask + b * img.shape[2]

            # Rows are [conf, l, t, r, b], one per grid cell.
            out = torch.flatten(torch.concat([c, l, t, r, b], dim=2), start_dim=0, end_dim=1)
            out = out[torch.where(out[:, 0] > 0.3)]
            if len(out) == 0:
                continue
            out = NMS(out, 0.3)

            # Crop and recognise every surviving detection.
            input_img = img[index:index + 1].to(self.DEVICE)
            plate_texts, _ = self.model.recognize_plates(input_img, [out[:, 1:5].contiguous()])

            self.pred_num += len(out)
            for lurd_vals, lp_indices in gt_plates:
                gt_box = torch.tensor(lurd_vals, dtype=torch.int32, device=self.DEVICE)
                gt_str = _indices_to_plate(lp_indices)
                for j, det in enumerate(out):
                    iou = self.iou(gt_box, det[1:5])
                    self.iou_list.append(iou)
                    if iou > 0.7:
                        self.count_iou += 1
                    if j < len(plate_texts):
                        text_match = plate_texts[j] == gt_str
                        if text_match:
                            self.count_lp += 1
                        if text_match and iou > 0.6:
                            self.count_iou_lp += 1

    def after_val(self):
        """Print the validation metrics accumulated by ``evaluate``."""
        print(f"Val IoU Detection Accuracy:{self.count_iou / len(self.val_dataset):>7f}")
        if self.iou_list:
            self.iou_list = torch.concat(self.iou_list, dim=0)
        else:
            self.iou_list = torch.zeros(1, device=self.DEVICE)
        print(f"Ave IoU:{torch.mean(self.iou_list).item():>7f}")
        print(f"Val Recognition Accuracy:{self.count_lp / len(self.val_dataset):>7f}")
        print(f"Val Recognition and Detection Accuracy:{self.count_iou_lp / len(self.val_dataset):>7f}")
        print(f"Val sample_num:{self.sample_num:>7f}")
        print(f"Val pred_num:{self.pred_num:>7f}")
        # Guard: a validation set with no ground-truth plates would divide by zero.
        recall = self.count_iou_lp / self.sample_num if self.sample_num else 0.0
        print(f"Val recall:{recall:>7f}")
        print(f"Val precision:{self.count_iou_lp / self.pred_num:>7f}")

    def iou(self, box, other_boxe):
        """IoU between two single ``[x1, y1, x2, y2]`` boxes.

        The result has shape ``(1,)`` because the zero clamp tensor has shape ``(1,)``.
        """
        box_area = (box[2] - box[0]) * (box[3] - box[1])
        other_boxes_area = (other_boxe[2] - other_boxe[0]) * (other_boxe[3] - other_boxe[1])
        x1 = torch.max(box[0], other_boxe[0])
        y1 = torch.max(box[1], other_boxe[1])
        x2 = torch.min(box[2], other_boxe[2])
        y2 = torch.min(box[3], other_boxe[3])
        Min = torch.zeros(1, device=box.device)
        w, h = torch.max(Min, x2 - x1), torch.max(Min, y2 - y1)
        overlap_area = w * h
        # Epsilon (as in module-level IOU) keeps degenerate boxes from yielding NaN/inf.
        iou = overlap_area / (box_area + other_boxes_area - overlap_area + 1e-6)
        return iou