| """ |
| Training script for CLPRNet with PARSeq Tiny backbone. |
| |
| Changes from original train.py: |
| - Uses CLPRNetPARSeq model instead of CLPRNet |
| - Recognition loss uses PARSeq's cross-entropy on sequence output |
| - No character attention maps (at_ch removed) |
| - Ground-truth boxes are passed to the model for plate cropping during training |
| - PARSeq uses teacher forcing during training |
| """ |
|
|
| from common import BaseExperiment |
| from model_parseq import CLPRNetPARSeq as Model, Tokenizer |
| from dataset import MyDataset, CHARACTER |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.optim import lr_scheduler |
| import os |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import random |
| from utils import provinces1, provinces2 |
|
|
|
|
def set_seed(seed):
    """Seed every random number generator used by training with ``seed``.

    Covers Python's ``random``, NumPy's global RNG, the PyTorch CPU generator,
    the current CUDA device and all CUDA devices (the CUDA calls are no-ops
    when CUDA is unavailable).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
|
|
|
|
def IoU_multi(pred_boxes, target_boxes, eps=1e-6):
    """Element-wise IoU of two broadcast-compatible box tensors.

    Both tensors hold boxes as ``(x1, y1, x2, y2)`` along the last axis (size 4).

    Args:
        pred_boxes: first set of boxes.
        target_boxes: second set of boxes (IoU is symmetric, so order does not matter).
        eps: added to the union to avoid division by zero.

    Returns:
        Tensor with the broadcast leading shape and a trailing axis of size 1.
    """
    pred_min, pred_max = pred_boxes[..., 0:2], pred_boxes[..., 2:4]
    tgt_min, tgt_max = target_boxes[..., 0:2], target_boxes[..., 2:4]

    # Width/height of the intersection rectangle, clamped at 0 when boxes are disjoint.
    overlap_wh = torch.clamp(torch.min(pred_max, tgt_max) - torch.max(pred_min, tgt_min), min=0)
    inter_area = overlap_wh[..., 0:1] * overlap_wh[..., 1:2]

    pred_wh = pred_max - pred_min
    tgt_wh = tgt_max - tgt_min
    pred_area = pred_wh[..., 0:1] * pred_wh[..., 1:2]
    tgt_area = tgt_wh[..., 0:1] * tgt_wh[..., 1:2]

    return inter_area / (pred_area + tgt_area - inter_area + eps)
|
|
|
|
def IOU(box, other_boxes):
    """IoU of one box against every row of ``other_boxes``.

    Args:
        box: 1-D tensor ``(x1, y1, x2, y2)``.
        other_boxes: 2-D tensor whose rows are ``(x1, y1, x2, y2)``.

    Returns:
        1-D tensor with one IoU per row of ``other_boxes`` (1e-6 is added to the
        union so degenerate boxes yield 0 instead of NaN).
    """
    lo_x = torch.max(box[0], other_boxes[:, 0])
    lo_y = torch.max(box[1], other_boxes[:, 1])
    hi_x = torch.min(box[2], other_boxes[:, 2])
    hi_y = torch.min(box[3], other_boxes[:, 3])

    floor = torch.zeros(1, device=box.device)
    overlap_area = torch.max(floor, hi_x - lo_x) * torch.max(floor, hi_y - lo_y)

    ref_area = (box[2] - box[0]) * (box[3] - box[1])
    cand_area = (other_boxes[:, 2] - other_boxes[:, 0]) * (other_boxes[:, 3] - other_boxes[:, 1])
    return overlap_area / (ref_area + cand_area - overlap_area + 1e-6)
|
|
|
|
def NMS(boxes, C=0.5):
    """Greedy non-maximum suppression.

    Args:
        boxes: 2-D tensor whose rows are ``[score, x1, y1, x2, y2, ...]``; column 0
            is the sort key and columns 1-4 are compared with ``IOU``.
        C: IoU threshold; a candidate whose IoU with an already-kept box is >= C
            is discarded.

    Returns:
        ``[]`` for empty input, otherwise a tensor of the kept rows in descending
        score order.
    """
    if len(boxes) == 0:
        return []
    queue = boxes[boxes[:, 0].argsort(descending=True)]
    survivors = []
    while len(queue) > 0:
        head = queue[0]
        survivors.append(head)
        rest = queue[1:]
        if len(rest) == 0:
            break
        # Keep only candidates that do not overlap the current best box too much.
        queue = rest[IOU(head[1:5], rest[:, 1:5]) < C]
    return torch.stack(survivors)
|
|
|
|
class BCEFocalLoss(torch.nn.Module):
    """Binary focal loss computed directly on probabilities (not logits).

    Positive targets are weighted by ``(1 - alpha) * (1 - p) ** gamma`` and negative
    targets by ``alpha * p ** gamma``. Note that this is the reverse of the RetinaNet
    convention, where ``alpha`` weights the positives; confirm this is intended.
    A constant 1e-5 is added inside the logs and powers for numerical safety.

    Args:
        gamma: focusing exponent.
        alpha: class-balance weight (applied to the negative term, see above).
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            element-wise loss.
    """

    def __init__(self, gamma=2, alpha=0.25, reduction='mean'):
        super().__init__()
        self.gamma, self.alpha, self.reduction = gamma, alpha, reduction

    def forward(self, predict, target):
        """Return the focal loss of probabilities ``predict`` against ``target``."""
        eps = 1e-5
        pos_term = (1 - self.alpha) * ((1 - predict + eps) ** self.gamma) * (target * torch.log(predict + eps))
        neg_term = self.alpha * ((predict + eps) ** self.gamma) * ((1 - target) * torch.log(1 - predict + eps))
        loss = -(pos_term + neg_term)
        if self.reduction == 'mean':
            return torch.mean(loss)
        if self.reduction == 'sum':
            return torch.sum(loss)
        return loss
|
|
|
|
class BCEWithWeightLoss(torch.nn.Module):
    """Binary cross-entropy on probabilities with separate weights per class.

    Args:
        weight: ``[w_neg, w_pos]``. ``weight[0]`` scales the ``target == 0`` term and
            ``weight[1]`` scales the ``target == 1`` term. Defaults to ``[1, 1]``.
            ``None`` is the default (instead of a list literal) so that each instance
            gets its own list; previously all default-constructed instances shared,
            and could mutate, one list object.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            element-wise loss.
    """

    def __init__(self, weight=None, reduction='mean'):
        super(BCEWithWeightLoss, self).__init__()
        self.weight = [1, 1] if weight is None else weight
        self.reduction = reduction

    def forward(self, inputs, target):
        """Return the weighted BCE of probabilities ``inputs`` against ``target``.

        1e-7 is added inside both logs to avoid ``log(0)``.
        """
        loss = -(self.weight[1] * target * torch.log(inputs + 1e-7) +
                 self.weight[0] * (1 - target) * torch.log(1 - inputs + 1e-7))
        if self.reduction == 'mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss
|
|
|
|
| class Experiment(BaseExperiment): |
|
|
| def __init__(self, **parameter) -> None: |
| super().__init__(**parameter) |
| self.LR_STEP_SIZE = parameter['LR_STEP_SIZE'] |
| self.LR_STEP_GAMMA = parameter['LR_STEP_GAMMA'] |
| self.MOMENTUM = parameter['MOMENTUM'] |
| self.WEIGHT_DECAY = parameter['WEIGHT_DECAY'] |
| self.BETAS = parameter['BETAS'] |
| self.file_name = '' |
| self.x_mask = mask[:, :, 0].to(self.DEVICE).unsqueeze_(dim=2) |
| self.y_mask = mask[:, :, 1].to(self.DEVICE).unsqueeze_(dim=2) |
| self.tokenizer = Tokenizer() |
|
|
| def load_model(self, pretrain: str = None): |
| self.model = Model(max_label_length=8) |
| self.model = self.model.to(self.DEVICE) |
| if pretrain: |
| self.print(f"pretrain: {pretrain}") |
| self.model.load_state_dict(torch.load(os.path.join(self.WORKSPACE, pretrain)), strict=False) |
| if self.CHECKPOINT: |
| self.model.load_state_dict(torch.load(os.path.join(self.WORKSPACE, self.CHECKPOINT))['parameter']) |
| self.START_EPOCH = torch.load(os.path.join(self.WORKSPACE, self.CHECKPOINT))['epoch'] + 1 |
|
|
| def load_optimizer(self): |
| |
| parseq_params = list(self.model.parseq.parameters()) |
| det_params = [p for n, p in self.model.named_parameters() if not n.startswith('parseq')] |
| |
| self.optimizer = torch.optim.AdamW([ |
| {'params': det_params, 'lr': self.LR}, |
| {'params': parseq_params, 'lr': self.LR * 0.1}, |
| ], weight_decay=self.WEIGHT_DECAY) |
| |
| if self.CHECKPOINT: |
| self.optimizer.load_state_dict(torch.load(os.path.join(self.WORKSPACE, self.CHECKPOINT))['optimizer']) |
|
|
| def load_scheduler(self): |
| self.scheduler = lr_scheduler.MultiStepLR(self.optimizer, milestones=[80, 85, 90], gamma=self.LR_STEP_GAMMA) |
| if self.CHECKPOINT: |
| state_dict = torch.load(os.path.join(self.WORKSPACE, self.CHECKPOINT))['scheduler'] |
| self.scheduler.load_state_dict(state_dict) |
|
|
| def forward(self, data): |
| x, _, _, _, _, _, _, _ = data |
| x = x.to(self.DEVICE) |
| |
| |
| img, lp_at, ch_at, bboxs, lp, lurds, lp_at_rec_facal, lp_lurd = data |
| |
| |
| batch_boxes = [] |
| batch_plate_labels = [] |
| |
| for b_idx in range(x.shape[0]): |
| lp_lurd_str = lp_lurd[b_idx] |
| plates_info = lp_lurd_str.split(';') |
| boxes_b = [] |
| labels_b = [] |
| |
| for info in plates_info: |
| if '-' not in info: |
| continue |
| lurd_str, lp_str = info.split('-') |
| lurd_vals = [int(v) for v in lurd_str.split(',')] |
| lp_indices = [int(v) for v in lp_str.split(',')] |
| |
| if len(lurd_vals) == 4 and lurd_vals[2] > lurd_vals[0] and lurd_vals[3] > lurd_vals[1]: |
| boxes_b.append(lurd_vals) |
| |
| plate_str = ''.join([CHARACTER[idx] for idx in lp_indices if idx < len(CHARACTER) - 1]) |
| labels_b.append(plate_str) |
| |
| if len(boxes_b) > 0: |
| batch_boxes.append(torch.tensor(boxes_b, dtype=torch.float32, device=self.DEVICE)) |
| else: |
| batch_boxes.append(torch.zeros((1, 4), dtype=torch.float32, device=self.DEVICE)) |
| batch_plate_labels.extend(labels_b) |
| |
| |
| pred = self.model(x, boxes_lurd=batch_boxes, plate_labels=batch_plate_labels if batch_plate_labels else None) |
| return pred |
|
|
    def loss(self, data, pred):
        """Compute the combined detection + recognition training loss for one batch.

        Args:
            data: the 8-element batch tuple (img, lp_at, ch_at, bboxs, lp, lurds,
                lp_at_rec_facal, lp_lurd). Only img, lp_at, bboxs, lurds and lp_lurd
                are used here.
            pred: the 4-tuple produced by the model
                (y_detection, y_recognition, pred_at_lp, plate_counts).

        Returns:
            Scalar tensor ``0.2 * loss_location + loss_confidence_location
            + 0.5 * loss_classify + 10 * loss_at``.
        """
        img, lp_at, ch_at, bboxs, lp, lurds, lp_at_rec_facal, lp_lurd = data
        y_detection, y_recognition, pred_at_lp, plate_counts = pred

        # Plate-attention loss: weighted BCE with positives weighted 0.9 and
        # negatives 0.1. BCEWithWeightLoss takes log() of its input directly, so
        # pred_at_lp presumably already holds probabilities — TODO confirm.
        lp_at = lp_at.to(self.DEVICE).unsqueeze(dim=1)
        loss_at = BCEWithWeightLoss(weight=[0.1, 0.9])(pred_at_lp, lp_at)

        # A cell is positive when its bboxs entries sum to > 0 (reduced over dim 3).
        bboxs = bboxs.to(self.DEVICE)
        lurds = lurds.to(self.DEVICE)
        obj = torch.sum(bboxs, dim=3) > 0
        noobj = torch.sum(bboxs, dim=3) == 0

        # Decode channels 0-3 (distances scaled by image width/height) into absolute
        # left/top/right/bottom coordinates around each cell's x/y reference.
        l, t, r, b = torch.split(y_detection[:, :, :, :4], 1, dim=-1)
        l = self.x_mask - l * img.shape[3]
        t = self.y_mask - t * img.shape[2]
        r = self.x_mask + r * img.shape[3]
        b = self.y_mask + b * img.shape[2]
        # IoU is symmetric, so passing the GT boxes as the first argument is harmless.
        iou = IoU_multi(
            torch.flatten(lurds, start_dim=0, end_dim=2),
            torch.flatten(torch.concat([l, t, r, b], dim=3), start_dim=0, end_dim=2)
        )
        iou = iou.view(bboxs.shape[:3])

        # Localisation loss: -log(IoU) averaged over positive cells only.
        loss_location = -torch.log(iou + 1e-6) * obj
        loss_location = torch.sum(loss_location) / (torch.sum(obj) + 1e-6)

        # Confidence (channel 4) regresses toward the detached IoU on positive cells
        # and toward 0 (weight 0.1) on negative cells, then is averaged over all cells.
        confidence_location = iou.detach().float()
        loss_confidence_location = nn.MSELoss(reduction='none')(y_detection[:, :, :, 4], confidence_location) * obj + \
            0.1 * nn.MSELoss(reduction='none')(y_detection[:, :, :, 4],
                                               torch.zeros_like(confidence_location, device=self.DEVICE)) * noobj
        loss_confidence_location = torch.mean(loss_confidence_location)

        # Recognition loss: cross-entropy over the PARSeq outputs.
        loss_classify = torch.tensor(0.0, device=self.DEVICE)

        if y_recognition is not None and sum(plate_counts) > 0:
            # Rebuild the label list with the same parsing and filters as forward(),
            # presumably so labels line up one-to-one with the recognised crops.
            batch_plate_labels = []
            for b_idx in range(img.shape[0]):
                lp_lurd_str = lp_lurd[b_idx]
                plates_info = lp_lurd_str.split(';')
                for info in plates_info:
                    if '-' not in info:
                        continue
                    lurd_str, lp_str = info.split('-')
                    lurd_vals = [int(v) for v in lurd_str.split(',')]
                    lp_indices = [int(v) for v in lp_str.split(',')]
                    if len(lurd_vals) == 4 and lurd_vals[2] > lurd_vals[0] and lurd_vals[3] > lurd_vals[1]:
                        plate_str = ''.join([CHARACTER[idx] for idx in lp_indices if idx < len(CHARACTER) - 1])
                        batch_plate_labels.append(plate_str)

            # NOTE(review): when the number of recognised crops differs from the number
            # of labels, loss_classify silently stays 0 for the whole batch.
            if len(batch_plate_labels) > 0 and y_recognition.shape[0] == len(batch_plate_labels):

                targets = self.tokenizer.encode(batch_plate_labels, max_length=8, device=self.DEVICE)
                # Drop the first token (presumably BOS) so targets align with the
                # decoder outputs — TODO confirm against Tokenizer.encode.
                targets_shifted = targets[:, 1:]

                # Assumes y_recognition is (num_plates, T, vocab) with
                # T == targets_shifted.shape[1]; otherwise the mask below will not
                # line up with logits_flat — TODO confirm.
                logits_flat = y_recognition.reshape(-1, y_recognition.shape[-1])
                targets_flat = targets_shifted.reshape(-1)

                # Padding positions are removed explicitly, so ignore_index below is
                # redundant but harmless.
                pad_mask = targets_flat != self.tokenizer.pad_id
                if pad_mask.any():
                    loss_classify = F.cross_entropy(
                        logits_flat[pad_mask],
                        targets_flat[pad_mask],
                        ignore_index=self.tokenizer.pad_id
                    )

        loss = 0.2 * loss_location + loss_confidence_location + 0.5 * loss_classify + 10 * loss_at
        return loss
|
|
| def before_val(self): |
| self.count_iou = 0 |
| self.iou_list = [torch.zeros(1, device=self.DEVICE)] |
| self.sample_num = 0 |
| self.pred_num = 1e-5 |
| self.count_lp = 0 |
| self.count_iou_lp = 0 |
|
|
| def evaluate(self, data, pred): |
| img, _, _, _, _, _, _, lp_lurd = data |
| y_detection, _, pred_at_lp, _ = pred |
|
|
| for index in range(y_detection.shape[0]): |
| lp_lurd_list = lp_lurd[index].split(';') |
| lp_list = [] |
| lurd_list = [] |
| for i in lp_lurd_list: |
| lurd, lp_str = i.split('-') |
| lp_list.append(np.array(lp_str.split(',')).astype('int32')) |
| lurd_list.append(np.array(lurd.split(',')).astype('int32')) |
|
|
| l, t, r, b, c = torch.split(y_detection[index, :, :, :5], 1, dim=-1) |
| l = self.x_mask - l * img.shape[3] |
| t = self.y_mask - t * img.shape[2] |
| r = self.x_mask + r * img.shape[3] |
| b = self.y_mask + b * img.shape[2] |
| |
| |
| out = torch.flatten(torch.concat([c, l, t, r, b], dim=2), start_dim=0, end_dim=1) |
| out = out[torch.where(out[:, 0] > 0.3)] |
| |
| if len(out) == 0: |
| self.sample_num += len(lp_list) |
| continue |
| |
| out = NMS(out.unsqueeze(0).squeeze(0) if out.dim() == 2 else out, 0.3) |
| |
| |
| boxes_for_rec = [] |
| for det in out: |
| boxes_for_rec.append(det[1:5]) |
| |
| if len(boxes_for_rec) > 0: |
| boxes_tensor = torch.stack(boxes_for_rec).unsqueeze(0) |
| input_img = img[index:index+1].to(self.DEVICE) |
| plate_texts, _ = self.model.recognize_plates(input_img, [boxes_tensor.squeeze(0)]) |
| else: |
| plate_texts = [] |
|
|
| self.sample_num += len(lp_list) |
| self.pred_num += len(out) |
|
|
| for i in range(len(lp_list)): |
| for j in range(len(out)): |
| iou = self.iou( |
| torch.from_numpy(lurd_list[i]).to(self.DEVICE), |
| torch.stack([out[j][1], out[j][2], out[j][3], out[j][4]]) |
| ) |
| self.iou_list.append(iou) |
| if iou > 0.7: |
| self.count_iou += 1 |
|
|
| if j < len(plate_texts): |
| gt_str = ''.join([CHARACTER[idx] for idx in lp_list[i] if idx < len(CHARACTER) - 1]) |
| pred_str = plate_texts[j] |
| if pred_str == gt_str: |
| self.count_lp += 1 |
| if pred_str == gt_str and iou > 0.6: |
| self.count_iou_lp += 1 |
|
|
| def after_val(self): |
| print(f"Val IoU Detection Accuracy:{self.count_iou / len(self.val_dataset):>7f}") |
| self.iou_list = torch.concat(self.iou_list, dim=0) |
| print(f"Ave IoU:{torch.mean(self.iou_list):>7f}") |
| print(f"Val Recognition Accuracy:{self.count_lp / len(self.val_dataset):>7f}") |
| print(f"Val Recognition and Detection Accuracy:{self.count_iou_lp / len(self.val_dataset):>7f}") |
| print(f"Val sample_num:{self.sample_num:>7f}") |
| print(f"Val pred_num:{self.pred_num:>7f}") |
| print(f"Val recall:{self.count_iou_lp / self.sample_num:>7f}") |
| print(f"Val precision:{self.count_iou_lp / self.pred_num:>7f}") |
|
|
| def iou(self, box, other_boxe): |
| box_area = (box[2] - box[0]) * (box[3] - box[1]) |
| other_boxes_area = (other_boxe[2] - other_boxe[0]) * (other_boxe[3] - other_boxe[1]) |
| x1 = torch.max(box[0], other_boxe[0]) |
| y1 = torch.max(box[1], other_boxe[1]) |
| x2 = torch.min(box[2], other_boxe[2]) |
| y2 = torch.min(box[3], other_boxe[3]) |
| Min = torch.zeros(1, device=box.device) |
| w, h = torch.max(Min, x2 - x1), torch.max(Min, y2 - y1) |
| overlap_area = w * h |
| iou = overlap_area / (box_area + other_boxes_area - overlap_area) |
| return iou |
|
|