# CLPRNet-PARSeq / train_parseq.py
# Uploaded by theakhilshukla ("Upload train_parseq.py", commit deda9ff, verified).
"""
Training script for CLPRNet with PARSeq Tiny backbone.
Changes from original train.py:
- Uses CLPRNetPARSeq model instead of CLPRNet
- Recognition loss uses PARSeq's cross-entropy on sequence output
- No character attention maps (at_ch removed)
- Ground-truth boxes are passed to the model for plate cropping during training
- PARSeq uses teacher forcing during training
"""
from common import BaseExperiment
from model_parseq import CLPRNetPARSeq as Model, Tokenizer
from dataset import MyDataset, CHARACTER
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
import os
import matplotlib.pyplot as plt
import numpy as np
import random
from utils import provinces1, provinces2
def set_seed(seed):
    """Seed every random number generator used by this script with ``seed``.

    Covers Python's ``random`` module, NumPy's global RNG, the torch CPU
    generator, and the CUDA generators (current device, then all devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
def IoU_multi(pred_boxes, target_boxes, eps=1e-6):
    """Element-wise IoU of two tensors of ``(x1, y1, x2, y2)`` boxes.

    Args:
        pred_boxes: tensor whose last dimension has exactly 4 entries.
        target_boxes: tensor whose last dimension has exactly 4 entries.
        eps: added to the union area to avoid division by zero.

    Returns:
        Tensor with the inputs' leading dimensions and a trailing dimension of 1.
    """
    px1, py1, px2, py2 = pred_boxes.split(1, dim=-1)
    tx1, ty1, tx2, ty2 = target_boxes.split(1, dim=-1)
    # Overlap width/height, clamped at 0 for disjoint boxes
    overlap_w = (torch.min(px2, tx2) - torch.max(px1, tx1)).clamp(min=0)
    overlap_h = (torch.min(py2, ty2) - torch.max(py1, ty1)).clamp(min=0)
    overlap = overlap_w * overlap_h
    union = (px2 - px1) * (py2 - py1) + (tx2 - tx1) * (ty2 - ty1) - overlap
    return overlap / (union + eps)
def IOU(box, other_boxes):
    """IoU of one ``(x1, y1, x2, y2)`` box against every row of ``other_boxes``.

    Args:
        box: 1-D tensor indexed as ``[x1, y1, x2, y2]``.
        other_boxes: 2-D tensor, one box per row in the same layout.

    Returns:
        1-D tensor with one IoU per row of ``other_boxes``; ``1e-6`` is added
        to the union to avoid division by zero.
    """
    floor = torch.zeros(1, device=box.device)
    inter_w = torch.max(
        floor, torch.min(box[2], other_boxes[:, 2]) - torch.max(box[0], other_boxes[:, 0])
    )
    inter_h = torch.max(
        floor, torch.min(box[3], other_boxes[:, 3]) - torch.max(box[1], other_boxes[:, 1])
    )
    inter = inter_w * inter_h
    area_ref = (box[2] - box[0]) * (box[3] - box[1])
    area_rest = (other_boxes[:, 2] - other_boxes[:, 0]) * (other_boxes[:, 3] - other_boxes[:, 1])
    return inter / (area_ref + area_rest - inter + 1e-6)
def NMS(boxes, C=0.5):
    """Greedy non-maximum suppression.

    Args:
        boxes: 2-D tensor whose rows are ``[score, x1, y1, x2, y2, ...]``;
            column 0 ranks the boxes and columns 1-4 are the box corners.
        C: IoU threshold. A box is dropped once its IoU with an already kept,
            higher-scoring box is ``>= C``.

    Returns:
        Tensor of the kept rows, highest score first. For empty input the
        result is ``boxes[:0]``: an empty tensor with the same columns when a
        tensor is passed, or ``[]`` when an empty list is passed. Callers
        therefore get the same type back as on the non-empty path.
    """
    if len(boxes) == 0:
        # Slice instead of returning a bare list so tensor input yields a tensor.
        return boxes[:0]
    remaining = boxes[boxes[:, 0].argsort(descending=True)]
    keep = []
    while len(remaining) > 0:
        best = remaining[0]
        keep.append(best)
        rest = remaining[1:]
        if len(rest) == 0:
            break
        # Keep only boxes that do not overlap the current best too much.
        remaining = rest[IOU(best[1:5], rest[:, 1:5]) < C]
    return torch.stack(keep)
class BCEFocalLoss(torch.nn.Module):
    """Binary focal loss computed on probabilities (not logits).

    For prediction ``p`` and binary target ``y`` the element-wise loss is::

        -[(1 - alpha) * (1 - p)^gamma * y * log(p)
          + alpha * p^gamma * (1 - y) * log(1 - p)]

    ``1e-5`` is added inside every power and log for numerical safety. As
    written, ``(1 - alpha)`` weights the positive term and ``alpha`` the
    negative term.

    Args:
        gamma: focusing exponent.
        alpha: class-balancing weight (see above).
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced element-wise loss.
    """

    def __init__(self, gamma=2, alpha=0.25, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, predict, target):
        eps = 1e-5
        pos_term = (1 - self.alpha) * ((1 - predict + eps) ** self.gamma) * (target * torch.log(predict + eps))
        neg_term = self.alpha * ((predict + eps) ** self.gamma) * ((1 - target) * torch.log(1 - predict + eps))
        elementwise = -(pos_term + neg_term)
        reducers = {'mean': torch.mean, 'sum': torch.sum}
        reduce_fn = reducers.get(self.reduction)
        return elementwise if reduce_fn is None else reduce_fn(elementwise)
class BCEWithWeightLoss(torch.nn.Module):
    """Binary cross-entropy on probabilities with separate class weights.

    Element-wise loss:
    ``-(weight[1] * y * log(p) + weight[0] * (1 - y) * log(1 - p))``,
    with ``1e-7`` added inside both logs for numerical safety.

    Args:
        weight: ``[negative_weight, positive_weight]``; defaults to ``[1, 1]``.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced element-wise loss.
    """

    def __init__(self, weight=None, reduction='mean'):
        super().__init__()
        # A ``[1, 1]`` default in the signature would be one list object shared
        # by every instance; build a fresh one per instance instead.
        self.weight = [1, 1] if weight is None else weight
        self.reduction = reduction

    def forward(self, inputs, target):
        loss = -(self.weight[1] * target * torch.log(inputs + 1e-7) +
                 self.weight[0] * (1 - target) * torch.log(1 - inputs + 1e-7))
        if self.reduction == 'mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss
class Experiment(BaseExperiment):
    """CLPRNet + PARSeq experiment: model/optimizer setup, losses and validation metrics.

    Every batch ``data`` is the 8-tuple
    ``(img, lp_at, ch_at, bboxs, lp, lurds, lp_at_rec_facal, lp_lurd)``.
    ``lp_lurd`` holds one string per image made of ``';'``-separated
    ``"l,t,r,b-i1,i2,..."`` entries. Here ``i1, i2, ...`` index into
    ``CHARACTER``.

    NOTE(review): ``mask`` (used for ``x_mask`` / ``y_mask``) is not defined in
    this class; it must exist at module level before the class is instantiated.
    """

    def __init__(self, **parameter) -> None:
        super().__init__(**parameter)
        self.LR_STEP_SIZE = parameter['LR_STEP_SIZE']
        self.LR_STEP_GAMMA = parameter['LR_STEP_GAMMA']
        # NOTE(review): MOMENTUM, BETAS and LR_STEP_SIZE are stored but not used
        # below: AdamW keeps its default betas and the scheduler uses fixed
        # milestones. Confirm this is intended.
        self.MOMENTUM = parameter['MOMENTUM']
        self.WEIGHT_DECAY = parameter['WEIGHT_DECAY']
        self.BETAS = parameter['BETAS']
        self.file_name = ''
        # Channels 0 / 1 of ``mask`` are the x / y reference positions that the
        # predicted l/t/r/b distances are decoded against (see loss / evaluate).
        self.x_mask = mask[:, :, 0].to(self.DEVICE).unsqueeze_(dim=2)
        self.y_mask = mask[:, :, 1].to(self.DEVICE).unsqueeze_(dim=2)
        self.tokenizer = Tokenizer()

    # ------------------------------------------------------------------ loading
    def _load_file(self, name):
        """Load an object saved with ``torch.save`` from ``self.WORKSPACE``.

        Tensors are mapped to CPU first. The ``load_state_dict`` calls that
        consume them copy the values onto the receiving module/optimizer
        device, so checkpoints saved on a GPU also load on CPU-only machines.
        """
        return torch.load(os.path.join(self.WORKSPACE, name), map_location='cpu')

    def load_model(self, pretrain: str = None):
        """Create the model on ``self.DEVICE`` and optionally restore weights.

        Args:
            pretrain: file inside ``self.WORKSPACE`` holding a state dict. It is
                loaded non-strictly, so missing/unexpected keys are ignored.

        When ``self.CHECKPOINT`` is set, its ``'parameter'`` entry is loaded
        strictly and ``self.START_EPOCH`` becomes its ``'epoch'`` + 1.
        """
        self.model = Model(max_label_length=8).to(self.DEVICE)
        if pretrain:
            self.print(f"pretrain: {pretrain}")
            self.model.load_state_dict(self._load_file(pretrain), strict=False)
        if self.CHECKPOINT:
            # Deserialise the checkpoint once and read both entries from it.
            checkpoint = self._load_file(self.CHECKPOINT)
            self.model.load_state_dict(checkpoint['parameter'])
            self.START_EPOCH = checkpoint['epoch'] + 1

    def load_optimizer(self):
        """Create AdamW with a reduced learning rate for the PARSeq sub-module.

        Parameters of ``self.model.parseq`` use ``0.1 * self.LR``; every other
        parameter uses ``self.LR``. With ``self.CHECKPOINT`` set, the saved
        ``'optimizer'`` state is restored.
        """
        # Use AdamW for PARSeq (better for transformers) with different LR groups
        parseq_params = list(self.model.parseq.parameters())
        # Split by identity rather than by name prefix so every parameter lands
        # in exactly one group. With a prefix test, a sibling module whose name
        # merely starts with "parseq" would be in no group and never trained.
        parseq_ids = {id(p) for p in parseq_params}
        det_params = [p for p in self.model.parameters() if id(p) not in parseq_ids]
        self.optimizer = torch.optim.AdamW([
            {'params': det_params, 'lr': self.LR},
            {'params': parseq_params, 'lr': self.LR * 0.1},  # Lower LR for PARSeq
        ], weight_decay=self.WEIGHT_DECAY)
        if self.CHECKPOINT:
            self.optimizer.load_state_dict(self._load_file(self.CHECKPOINT)['optimizer'])

    def load_scheduler(self):
        """Create the MultiStepLR schedule and optionally restore its state."""
        # NOTE(review): milestones are hard-coded (80/85/90 scheduler steps,
        # presumably epochs); self.LR_STEP_SIZE is not used.
        self.scheduler = lr_scheduler.MultiStepLR(self.optimizer, milestones=[80, 85, 90], gamma=self.LR_STEP_GAMMA)
        if self.CHECKPOINT:
            self.scheduler.load_state_dict(self._load_file(self.CHECKPOINT)['scheduler'])

    # ------------------------------------------------------------ label parsing
    @staticmethod
    def _indices_to_text(indices):
        """Map CHARACTER indices to a string, dropping indices >= len(CHARACTER) - 1."""
        last = len(CHARACTER) - 1
        return ''.join(CHARACTER[idx] for idx in indices if idx < last)

    def _parse_plates(self, lp_lurd_str):
        """Parse one ``lp_lurd`` string into its valid GT boxes and plate texts.

        Segments without ``'-'`` are skipped. A segment is kept only if its box
        has exactly four integers with ``r > l`` and ``b > t``.

        Returns:
            ``(boxes, labels)``: a list of ``[l, t, r, b]`` int lists and the
            matching plate strings, in input order.
        """
        boxes, labels = [], []
        for info in lp_lurd_str.split(';'):
            if '-' not in info:
                continue
            lurd_str, lp_str = info.split('-')
            lurd_vals = [int(v) for v in lurd_str.split(',')]
            lp_indices = [int(v) for v in lp_str.split(',')]
            if len(lurd_vals) == 4 and lurd_vals[2] > lurd_vals[0] and lurd_vals[3] > lurd_vals[1]:
                boxes.append(lurd_vals)  # [l, t, r, b]
                labels.append(self._indices_to_text(lp_indices))
        return boxes, labels

    # ------------------------------------------------------------------ training
    def forward(self, data):
        """Run the model on a batch, supplying ground-truth plate boxes and texts.

        For each image, the valid GT boxes are passed (in ``boxes_lurd``) for
        plate cropping. The plate strings of the whole batch are concatenated
        into ``plate_labels``, which is ``None`` when the batch has no valid
        plate.

        Returns:
            The model output. ``loss`` / ``evaluate`` unpack it as
            ``(y_detection, y_recognition, pred_at_lp, plate_counts)``.
        """
        img, _, _, _, _, _, _, lp_lurd = data
        x = img.to(self.DEVICE)
        batch_boxes = []
        batch_plate_labels = []
        for b_idx in range(x.shape[0]):
            boxes_b, labels_b = self._parse_plates(lp_lurd[b_idx])
            if boxes_b:
                batch_boxes.append(torch.tensor(boxes_b, dtype=torch.float32, device=self.DEVICE))
            else:
                # One entry is appended per image; images without a valid plate
                # get a single all-zero box.
                # NOTE(review): if the model also recognises this dummy box,
                # y_recognition gets a row with no label. loss() then skips the
                # recognition term for the whole batch. Confirm.
                batch_boxes.append(torch.zeros((1, 4), dtype=torch.float32, device=self.DEVICE))
            batch_plate_labels.extend(labels_b)
        # Forward pass with GT boxes and labels (PARSeq uses teacher forcing during training)
        return self.model(x, boxes_lurd=batch_boxes, plate_labels=batch_plate_labels or None)

    def _recognition_loss(self, y_recognition, plate_labels):
        """Token-level cross-entropy between PARSeq logits and the plate strings.

        Targets are encoded as ``[BOS, c1, ..., cN, EOS, PAD...]``. BOS is
        dropped, so logit position ``k`` is compared with target token
        ``k + 1``. PAD positions are excluded. Returns a zero tensor when no
        non-PAD target remains.
        """
        targets = self.tokenizer.encode(plate_labels, max_length=8, device=self.DEVICE)
        targets_shifted = targets[:, 1:]  # [c1, ..., cN, EOS, PAD...]
        # Logits and shifted targets should have equal length. Align them
        # defensively so flattening can never pair logits with another row's
        # targets, and the PAD mask always matches the logit rows.
        seq_len = min(y_recognition.shape[1], targets_shifted.shape[1])
        logits_flat = y_recognition[:, :seq_len].reshape(-1, y_recognition.shape[-1])
        targets_flat = targets_shifted[:, :seq_len].reshape(-1)
        keep = targets_flat != self.tokenizer.pad_id
        if not keep.any():
            return torch.tensor(0.0, device=self.DEVICE)
        return F.cross_entropy(logits_flat[keep], targets_flat[keep])

    def loss(self, data, pred):
        """Total training loss.

        Computed as ``0.2 * location + confidence + 0.5 * recognition + 10 * attention``:

        * attention: weighted BCE (weights 0.1 / 0.9) between ``pred_at_lp``
          and ``lp_at``;
        * location: mean ``-log(IoU)`` over cells with a non-zero target box;
        * confidence: MSE of channel 4 against the detached IoU on those
          cells, plus ``0.1 *`` MSE against 0 on the other cells, averaged
          over all cells;
        * recognition: see ``_recognition_loss``. It is 0 when
          ``y_recognition`` is None, ``sum(plate_counts) == 0``, or the number
          of recognised crops differs from the number of valid GT plates.
        """
        img, lp_at, ch_at, bboxs, lp, lurds, lp_at_rec_facal, lp_lurd = data
        y_detection, y_recognition, pred_at_lp, plate_counts = pred
        # --- Attention loss (LP attention only, no char attention) ---
        lp_at = lp_at.to(self.DEVICE).unsqueeze(dim=1)
        loss_at = BCEWithWeightLoss(weight=[0.1, 0.9])(pred_at_lp, lp_at)
        # --- Location loss ---
        bboxs = bboxs.to(self.DEVICE)
        lurds = lurds.to(self.DEVICE)
        box_sum = torch.sum(bboxs, dim=3)
        obj = box_sum > 0     # cells with a non-zero target box
        noobj = box_sum == 0  # all remaining cells
        l, t, r, b = torch.split(y_detection[:, :, :, :4], 1, dim=-1)
        # Decode predicted distances (scaled by image width/height) into corners.
        l = self.x_mask - l * img.shape[3]
        t = self.y_mask - t * img.shape[2]
        r = self.x_mask + r * img.shape[3]
        b = self.y_mask + b * img.shape[2]
        iou = IoU_multi(
            torch.flatten(lurds, start_dim=0, end_dim=2),
            torch.flatten(torch.concat([l, t, r, b], dim=3), start_dim=0, end_dim=2)
        )
        iou = iou.view(bboxs.shape[:3])
        loss_location = torch.sum(-torch.log(iou + 1e-6) * obj) / (torch.sum(obj) + 1e-6)
        # --- Confidence loss ---
        confidence_location = iou.detach().float()
        conf_pred = y_detection[:, :, :, 4]
        mse = nn.MSELoss(reduction='none')
        loss_confidence_location = torch.mean(
            mse(conf_pred, confidence_location) * obj
            + 0.1 * mse(conf_pred, torch.zeros_like(confidence_location, device=self.DEVICE)) * noobj
        )
        # --- Recognition loss (PARSeq cross-entropy) ---
        loss_classify = torch.tensor(0.0, device=self.DEVICE)
        if y_recognition is not None and sum(plate_counts) > 0:
            batch_plate_labels = []
            for b_idx in range(img.shape[0]):
                batch_plate_labels.extend(self._parse_plates(lp_lurd[b_idx])[1])
            # NOTE(review): a count mismatch silently drops the recognition term.
            if batch_plate_labels and y_recognition.shape[0] == len(batch_plate_labels):
                loss_classify = self._recognition_loss(y_recognition, batch_plate_labels)
        # --- Total loss ---
        return 0.2 * loss_location + loss_confidence_location + 0.5 * loss_classify + 10 * loss_at

    # ---------------------------------------------------------------- validation
    def before_val(self):
        """Reset the counters that ``evaluate`` accumulates and ``after_val`` reports."""
        self.count_iou = 0
        # Per-(GT, prediction) IoUs. This starts empty so the average is not
        # biased by a placeholder value.
        self.iou_list = []
        self.sample_num = 0
        # Tiny non-zero start keeps the precision division defined.
        self.pred_num = 1e-5
        self.count_lp = 0
        self.count_iou_lp = 0

    def evaluate(self, data, pred):
        """Accumulate detection / recognition counts for one validation batch.

        For each image:

        1. Decode the predicted boxes and keep those with confidence ``> 0.3``.
        2. Apply ``NMS`` with IoU threshold 0.3.
        3. Read the kept crops with ``self.model.recognize_plates``.
        4. Compare every GT plate with every kept prediction:

           * ``count_iou`` gets +1 per (GT, prediction) pair with IoU > 0.7;
           * ``count_lp`` gets +1 per pair whose texts match (box ignored);
           * ``count_iou_lp`` gets +1 per pair whose texts match and IoU > 0.6.

        NOTE(review): counts are per pair, so one GT plate can be counted more
        than once. ``after_val`` divides the accuracies by the number of
        images. Confirm these metric definitions are intended.
        """
        img, _, _, _, _, _, _, lp_lurd = data
        y_detection, _, _, _ = pred
        for index in range(y_detection.shape[0]):
            lp_list = []
            lurd_list = []
            for entry in lp_lurd[index].split(';'):
                if '-' not in entry:
                    # Skip empty / malformed segments (e.g. a trailing ';')
                    # instead of failing to unpack them, as _parse_plates does.
                    continue
                lurd, lp_str = entry.split('-')
                lp_list.append(np.array(lp_str.split(',')).astype('int32'))
                lurd_list.append(np.array(lurd.split(',')).astype('int32'))

            l, t, r, b, c = torch.split(y_detection[index, :, :, :5], 1, dim=-1)
            l = self.x_mask - l * img.shape[3]
            t = self.y_mask - t * img.shape[2]
            r = self.x_mask + r * img.shape[3]
            b = self.y_mask + b * img.shape[2]
            # One row per grid cell: [confidence, l, t, r, b]
            out = torch.flatten(torch.concat([c, l, t, r, b], dim=2), start_dim=0, end_dim=1)
            out = out[out[:, 0] > 0.3]

            self.sample_num += len(lp_list)
            if len(out) == 0:
                continue
            out = NMS(out, 0.3)
            self.pred_num += len(out)

            # Recognise the plate text inside every kept box
            input_img = img[index:index + 1].to(self.DEVICE)
            plate_texts, _ = self.model.recognize_plates(input_img, [out[:, 1:5].contiguous()])

            for gt_lurd, gt_lp in zip(lurd_list, lp_list):
                gt_box = torch.from_numpy(gt_lurd).to(self.DEVICE)
                gt_str = self._indices_to_text(gt_lp)
                for j, det in enumerate(out):
                    iou = self.iou(gt_box, det[1:5])
                    self.iou_list.append(iou)
                    if iou > 0.7:
                        self.count_iou += 1
                    if j < len(plate_texts) and plate_texts[j] == gt_str:
                        self.count_lp += 1
                        if iou > 0.6:
                            self.count_iou_lp += 1

    @staticmethod
    def _ratio(numerator, denominator):
        """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
        return numerator / denominator if denominator else 0.0

    def after_val(self):
        """Print the validation metrics accumulated by ``evaluate``.

        Accuracies are divided by ``len(self.val_dataset)``. Recall is divided
        by the number of GT plates and precision by the number of kept
        predictions. A zero denominator prints 0 instead of raising
        ZeroDivisionError.
        """
        num_images = len(self.val_dataset)
        print(f"Val IoU Detection Accuracy:{self._ratio(self.count_iou, num_images):>7f}")
        if len(self.iou_list) > 0:
            self.iou_list = torch.concat(self.iou_list, dim=0)
        else:
            self.iou_list = torch.zeros(1, device=self.DEVICE)
        print(f"Ave IoU:{torch.mean(self.iou_list):>7f}")
        print(f"Val Recognition Accuracy:{self._ratio(self.count_lp, num_images):>7f}")
        print(f"Val Recognition and Detection Accuracy:{self._ratio(self.count_iou_lp, num_images):>7f}")
        print(f"Val sample_num:{self.sample_num:>7f}")
        print(f"Val pred_num:{self.pred_num:>7f}")
        print(f"Val recall:{self._ratio(self.count_iou_lp, self.sample_num):>7f}")
        print(f"Val precision:{self.count_iou_lp / self.pred_num:>7f}")

    def iou(self, box, other_boxe):
        """IoU of two ``(x1, y1, x2, y2)`` boxes, returned as a 1-element tensor.

        ``1e-6`` is added to the union so two zero-area boxes give 0, not NaN.
        """
        box_area = (box[2] - box[0]) * (box[3] - box[1])
        other_area = (other_boxe[2] - other_boxe[0]) * (other_boxe[3] - other_boxe[1])
        x1 = torch.max(box[0], other_boxe[0])
        y1 = torch.max(box[1], other_boxe[1])
        x2 = torch.min(box[2], other_boxe[2])
        y2 = torch.min(box[3], other_boxe[3])
        zero = torch.zeros(1, device=box.device)
        overlap_area = torch.max(zero, x2 - x1) * torch.max(zero, y2 - y1)
        return overlap_area / (box_area + other_area - overlap_area + 1e-6)