# CLPRNet-PARSeq / inference_parseq.py
# Uploaded by theakhilshukla ("Upload inference_parseq.py"), commit bef897d.
"""
Inference script for CLPRNet with PARSeq Tiny backbone.
Two-stage inference:
1. Detection: CLPRNet backbone + detection head -> boxes (with NMS)
2. Recognition: Crop detected plates -> PARSeq Tiny -> plate strings
"""
from model_parseq import CLPRNetPARSeq, Tokenizer
import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
import os
import numpy as np
import cv2
from PIL import Image, ImageDraw, ImageFont
# Character set exposed by the PARSeq tokenizer (not referenced by the code below).
CHARACTER = Tokenizer.CHARSET
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Size the letterboxed image is resized to before detection. It is passed to
# cv2.resize as dsize, so the order is (width, height).
img_size = (1024, 1024)
# Label font size in pixels; also used as the height of the label background box.
font_size = 30
try:
    font = ImageFont.truetype('resource/msyh.ttc', font_size, encoding='utf-8')
except (OSError, ImportError):
    # OSError: the font file is missing or unreadable.
    # ImportError: Pillow was built without FreeType support.
    # In both cases fall back to Pillow's built-in default font instead of
    # silently swallowing every exception (including KeyboardInterrupt).
    font = ImageFont.load_default()
# Build the two-stage detector/recognizer, place it on DEVICE, restore the
# trained weights and switch to evaluation mode. map_location remaps stored
# tensors onto DEVICE, so a GPU-saved checkpoint also loads on a CPU-only host.
model = CLPRNetPARSeq(max_label_length=8).to(DEVICE)
model.load_state_dict(
    torch.load('resource/CLPRNet_PARSeq.pth', map_location=DEVICE)
)
model.eval()
# Annotated images are saved under output/; exist_ok avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs('output', exist_ok=True)
# Detector preprocessing: uint8 HxWx3 RGB array -> float tensor (3, H, W)
# scaled to [0, 1], then normalized with the ImageNet channel mean/std.
tran = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def IOU(box, other_boxes):
    """Intersection-over-union of ``box`` against each row of ``other_boxes``.

    Args:
        box: 1-D tensor ``[x1, y1, x2, y2]`` (left, top, right, bottom).
        other_boxes: 2-D tensor whose rows use the same ``[x1, y1, x2, y2]``
            layout.

    Returns:
        1-D tensor with one IoU value per row of ``other_boxes``. A ``1e-6``
        term in the denominator keeps a zero-area union from dividing by zero.
    """
    # Floor for the intersection extents; negative width/height means "no overlap".
    zero = torch.zeros(1, device=box.device)

    def area(x1, y1, x2, y2):
        return (x2 - x1) * (y2 - y1)

    ox1, oy1, ox2, oy2 = (other_boxes[:, k] for k in range(4))

    inter_w = torch.max(zero, torch.min(box[2], ox2) - torch.max(box[0], ox1))
    inter_h = torch.max(zero, torch.min(box[3], oy2) - torch.max(box[1], oy1))
    inter = inter_w * inter_h

    union = area(box[0], box[1], box[2], box[3]) + area(ox1, oy1, ox2, oy2) - inter
    return inter / (union + 1e-6)
def NMS(boxes, C=0.3):
    """Greedy non-maximum suppression.

    Args:
        boxes: 2-D tensor whose column 0 is the score and columns 1-4 are the
            box ``[x1, y1, x2, y2]``.
        C: IoU threshold. A candidate is discarded once its IoU with an
            already-kept, higher-scoring box is ``>= C``.

    Returns:
        Tensor of kept rows in descending score order, or ``[]`` when
        ``boxes`` is empty.
    """
    if len(boxes) == 0:
        return []

    remaining = boxes[torch.argsort(boxes[:, 0], descending=True)]
    kept = []
    while len(remaining) > 0:
        best, rest = remaining[0], remaining[1:]
        kept.append(best)
        if len(rest) == 0:
            break
        # Keep only the candidates that overlap the chosen box by less than C.
        remaining = rest[IOU(best[1:5], rest[:, 1:5]) < C]
    return torch.stack(kept)
def _grid_centers(grid):
    """Return pixel centers of a ``grid x grid`` lattice over the network input.

    The decoded l/t/r/b edges are offsets from these centers. Both returned
    tensors have shape ``(grid, grid, 1)`` and live on DEVICE, so they
    broadcast against one channel of a per-image detection map.
    """
    # meshgrid's default 'xy' indexing: cols[i, j] == j, rows[i, j] == i.
    cols, rows = np.meshgrid(np.arange(grid), np.arange(grid))
    mask_x = (cols + 0.5) * img_size[0] / grid
    mask_y = (rows + 0.5) * img_size[1] / grid
    x_mask = torch.from_numpy(mask_x).to(DEVICE).unsqueeze(dim=2)
    y_mask = torch.from_numpy(mask_y).to(DEVICE).unsqueeze(dim=2)
    return x_mask, y_mask


def _letterbox(org_img):
    """Pad ``org_img`` (H x W x 3, uint8) to a centered square and resize it to ``img_size``.

    Returns:
        (resized, size, x, y): the resized square image, the side length of
        the padded square, and the column/row offset at which the original
        image sits inside it (needed to map boxes back to original pixels).
    """
    height, width = org_img.shape[:2]
    size = max(height, width)
    x = (size - width) // 2
    y = (size - height) // 2
    padded = np.zeros((size, size, 3), dtype=np.uint8)
    padded[y:y + height, x:x + width, :] = org_img
    return cv2.resize(padded, img_size), size, x, y


def _decode_detections(det_map, x_mask, y_mask, in_w, in_h, conf_thresh=0.3):
    """Convert one image's dense detection map into ``[conf, l, t, r, b]`` rows.

    The first five channels of ``det_map`` are read as ``(l, t, r, b, conf)``.
    The edge values are distances from each cell center, scaled by the input
    width ``in_w`` / height ``in_h``. ``det_map`` is assumed to be
    ``(grid, grid, C>=5)`` matching ``x_mask``/``y_mask``; TODO confirm against
    the model head. Rows with ``conf <= conf_thresh`` are dropped.
    """
    l, t, r, b, c = torch.split(det_map[:, :, :5], 1, dim=-1)
    l = x_mask - l * in_w
    t = y_mask - t * in_h
    r = x_mask + r * in_w
    b = y_mask + b * in_h
    out = torch.flatten(torch.concat([c, l, t, r, b], dim=2), start_dim=0, end_dim=1)
    return out[out[:, 0] > conf_thresh]


def _draw_results(org_img, boxes, labels, scores):
    """Draw boxes and ``"<plate>_<score>"`` labels; return the annotated PIL image.

    ``org_img`` is the BGR array from cv2 and receives the box outlines in
    place. The text labels are drawn only on the returned RGB PIL image.
    """
    for box in boxes:
        # OpenCV's bindings are strict about point types; plain Python-int
        # tuples are always accepted, numpy integer arrays are not in every version.
        cv2.rectangle(org_img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 0, 255), 2)
    pil_img = Image.fromarray(org_img[:, :, ::-1].astype('uint8')).convert('RGB')
    draw = ImageDraw.Draw(pil_img)
    for box, label, score in zip(boxes, labels, scores):
        label_text = f"{label}_{score}"
        label_w = int(draw.textlength(label_text, font))
        left, top = int(box[0]), int(box[1])
        # Red background strip directly above the box, then white text on it.
        draw.rectangle([(left, top - font_size), (left + label_w, top)], fill='red')
        draw.text(
            xy=(left, top - int(font_size * 1.25)),
            text=label_text,
            fill=(255, 255, 255),
            font=font,
        )
        print(f" Plate: {label}, Conf: {score}")
    return pil_img


def inference(src, image_list):
    """Detect and read license plates in each image and save annotated copies.

    For every name in ``image_list`` (resolved relative to ``src``), the image
    is processed as follows:

    1. Letterbox it to ``img_size`` and run the detector.
    2. Decode and NMS-filter the boxes.
    3. Recognize the plate text in each kept box.
    4. Write the annotated image to ``output/<img_name>``.

    Entries that cv2 cannot decode (non-images, sub-directories) are skipped.
    Images with no detection above the 0.3 confidence threshold are reported
    but not saved.

    Args:
        src: directory containing the images.
        image_list: file names inside ``src``.
    """
    x_mask, y_mask = _grid_centers(64)
    for img_name in image_list:
        print(img_name)
        org_img = cv2.imread(os.path.join(src, img_name))
        if org_img is None:
            # cv2.imread returns None instead of raising when it cannot read/decode the file.
            print(" Skipped: not a readable image")
            continue

        img, size, x, y = _letterbox(org_img)
        inputs = tran(img[:, :, ::-1])  # BGR -> RGB, then normalize
        inputs = inputs.unsqueeze(dim=0).to(DEVICE)
        with torch.no_grad():
            # Stage 1: detection only (no boxes provided)
            y_detection, _, _, _ = model(inputs)

        for index in range(y_detection.shape[0]):
            out = _decode_detections(
                y_detection[index], x_mask, y_mask, inputs.shape[3], inputs.shape[2]
            )
            if len(out) == 0:
                print(" No plates detected")
                continue
            out = NMS(out, 0.3)

            # Stage 2: crop the kept boxes and recognize them with PARSeq.
            boxes_tensor = torch.stack([det[1:5] for det in out])
            with torch.no_grad():
                plate_texts, confidences = model.recognize_plates(
                    inputs[index:index + 1], [boxes_tensor]
                )

            boxes, labels, scores = [], [], []
            for i, det in enumerate(out):
                # .copy() so the in-place rescaling below never writes back into `out`.
                coords = det[1:5].detach().cpu().numpy().copy()
                # Undo the resize (img_size -> size) and the letterbox offset.
                coords[0::2] = coords[0::2] * size / img_size[0] - x
                coords[1::2] = coords[1::2] * size / img_size[1] - y
                boxes.append(coords.astype('int32'))
                if i < len(plate_texts):
                    labels.append(plate_texts[i])
                    det_conf = float(det[0])
                    rec_conf = confidences[i] if i < len(confidences) else 0.0
                    scores.append(round(det_conf * rec_conf, 3))
                else:
                    labels.append("???")
                    scores.append(0.0)

            annotated = _draw_results(org_img, boxes, labels, scores)
            annotated.save(os.path.join('output', img_name))
if __name__ == '__main__':
    src = 'image'
    # os.listdir returns entries in arbitrary order; sort so images are
    # processed (and logged) in a deterministic order across runs.
    inference(src, sorted(os.listdir(src)))