|
|
import argparse |
|
|
import json |
|
|
import os |
|
|
from pathlib import Path |
|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
from PIL import Image, ImageDraw, ImageFont |
|
|
import torch |
|
|
import torchvision.transforms as standard_transforms |
|
|
|
|
|
import util.misc as utils |
|
|
from models import build_model |
|
|
|
|
|
# Preprocessing applied to every input image before it reaches the model:
# ToTensor turns the PIL image into a CHW float tensor, then each channel is
# normalised with the standard ImageNet mean / std values below.
PET_TRANSFORM = standard_transforms.Compose(
    [
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
|
|
|
|
|
def get_args_parser() -> argparse.ArgumentParser:
    """Build the option parser for single-image PET inference.

    The parser is created with ``add_help=False`` so that it can be used as a
    ``parents=[...]`` entry of the real command-line parser.

    Returns:
        An ``argparse.ArgumentParser`` with input/runtime options, model
        options (the whole namespace is handed to ``build_model``), dataset
        options, and visualisation / output options.
    """
    runtime_options = (
        ('--image_path', {'required': True, 'type': str,
                          'help': 'Path to a single input image.'}),
        ('--resume', {'default': 'PET_Finetuned.safetensors', 'type': str,
                      'help': 'Path to model weights (.safetensors or .pth).'}),
        ('--device', {'default': 'cuda', 'type': str,
                      'help': 'Device for inference, e.g. cuda or cpu.'}),
    )
    architecture_options = (
        ('--backbone', {'default': 'vgg16_bn', 'type': str}),
        ('--position_embedding', {'default': 'sine', 'type': str,
                                  'choices': ('sine', 'learned', 'fourier')}),
        ('--dec_layers', {'default': 2, 'type': int}),
        ('--dim_feedforward', {'default': 512, 'type': int}),
        ('--hidden_dim', {'default': 256, 'type': int}),
        ('--dropout', {'default': 0.0, 'type': float}),
        ('--nheads', {'default': 8, 'type': int}),
    )
    cost_and_loss_options = (
        ('--set_cost_class', {'default': 1, 'type': float}),
        ('--set_cost_point', {'default': 0.05, 'type': float}),
        ('--ce_loss_coef', {'default': 1.0, 'type': float}),
        ('--point_loss_coef', {'default': 5.0, 'type': float}),
        ('--eos_coef', {'default': 0.5, 'type': float}),
    )
    dataset_options = (
        ('--dataset_file', {'default': 'SHA'}),
        ('--data_path', {'default': './data/ShanghaiTech/PartA', 'type': str}),
    )
    output_options = (
        ('--upper_bound', {'default': -1, 'type': int,
                           'help': 'Max image side for inference; -1 means only cap at 2560 (same as compare_models).'}),
        ('--output_image', {'default': '', 'type': str,
                            'help': 'Optional path to save annotated image panel.'}),
        ('--title_text', {'default': 'PET-Finetuned', 'type': str,
                          'help': 'Title prefix used in top panel text.'}),
        ('--radius', {'default': 3, 'type': int}),
        ('--point_color', {'default': '0,255,0', 'type': str,
                           'help': 'BGR color for points, e.g., 0,255,0'}),
        ('--panel_long_side', {'default': 1600, 'type': int,
                               'help': 'Resize annotated panel long side to this value.'}),
        ('--panel_pad', {'default': 24, 'type': int,
                         'help': 'Panel padding around the image and title area.'}),
        ('--panel_font_size', {'default': 48, 'type': int,
                               'help': 'Font size for panel title text.'}),
        ('--output_json', {'default': '', 'type': str,
                           'help': 'Optional output JSON path for prediction details.'}),
        ('--seed', {'default': 42, 'type': int}),
    )

    parser = argparse.ArgumentParser('PET single-image inference (HF release)', add_help=False)
    # Registration order matters for --help output, so sections are walked in order.
    for section in (runtime_options, architecture_options, cost_and_loss_options,
                    dataset_options, output_options):
        for flag, spec in section:
            parser.add_argument(flag, **spec)
    return parser
|
|
|
|
|
|
|
|
def parse_color(color_str: str) -> tuple:
    """Parse a ``"B,G,R"`` string into a tuple of three 8-bit integers.

    Whitespace around each component is ignored (``" 0, 255 ,0"`` is fine).

    Args:
        color_str: Comma-separated blue, green and red components.

    Returns:
        ``(b, g, r)`` as Python ints, each in ``0..255``.

    Raises:
        ValueError: if there are not exactly three components, a component
            is not an integer, or a component lies outside ``0..255``.
    """
    parts = color_str.split(',')
    if len(parts) != 3:
        raise ValueError('color must be B,G,R like 0,255,0')
    try:
        channels = tuple(int(p.strip()) for p in parts)
    except ValueError as exc:
        # Re-raise with the offending input instead of int()'s bare message.
        raise ValueError(f'color components must be integers, got {color_str!r}') from exc
    # Out-of-range values are not valid 8-bit colour channels; reject them here
    # rather than letting the drawing code receive them.
    if any(not 0 <= c <= 255 for c in channels):
        raise ValueError(f'color components must be in 0..255, got {color_str!r}')
    return channels
|
|
|
|
|
|
|
|
def resolve_device(device_str: str) -> torch.device:
    """Convert ``device_str`` to a ``torch.device``, degrading to CPU if needed.

    A string that starts with ``'cuda'`` yields the CPU device (after printing
    a notice) when CUDA is unavailable. When an explicit CUDA index is given
    (e.g. ``'cuda:1'``), that index also becomes the current CUDA device.
    """
    wants_cuda = device_str.startswith('cuda')
    if wants_cuda and not torch.cuda.is_available():
        print('CUDA not available. Falling back to CPU.')
        return torch.device('cpu')

    resolved = torch.device(device_str)
    has_cuda_index = resolved.type == 'cuda' and resolved.index is not None
    if has_cuda_index:
        torch.cuda.set_device(resolved.index)
    return resolved
|
|
|
|
|
|
|
|
def resize_for_eval(frame_rgb, upper_bound, hard_cap=2560):
    """Downscale an image whose long side exceeds the inference limit.

    Images are only ever shrunk, never enlarged.

    Args:
        frame_rgb: Image array whose first two dimensions are (height, width).
        upper_bound: User long-side bound. ``-1`` means "no user bound", and
            so does any other non-positive value or ``None``.
        hard_cap: Long-side cap applied when the user bound does not trigger
            (default 2560, the historical constant). ``None`` or a
            non-positive value disables it.

    Returns:
        ``(image, scale)`` where ``scale`` is the factor applied to both
        sides. When ``scale == 1.0`` the input array is returned unchanged.
    """
    h, w = frame_rgb.shape[:2]
    max_size = max(h, w)

    # Only a positive bound is meaningful. A bound of 0 or -5 would give
    # scale <= 0, i.e. a degenerate 1x1 resize and a division by zero when
    # predicted points are mapped back to the original frame.
    has_bound = upper_bound is not None and upper_bound > 0
    has_cap = hard_cap is not None and hard_cap > 0

    # NOTE(review): when upper_bound > hard_cap, images whose long side lies
    # between hard_cap and upper_bound are capped at hard_cap, while larger
    # images are capped at upper_bound. This branch order is kept as-is
    # (the CLI help says it matches compare_models). Confirm whether
    # min(upper_bound, hard_cap) was intended.
    if has_bound and max_size > upper_bound:
        scale = float(upper_bound) / float(max_size)
    elif has_cap and max_size > hard_cap:
        scale = float(hard_cap) / float(max_size)
    else:
        scale = 1.0

    if scale == 1.0:
        return frame_rgb, scale

    new_w = max(1, int(round(w * scale)))
    new_h = max(1, int(round(h * scale)))
    resized = cv2.resize(frame_rgb, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    return resized, scale
|
|
|
|
|
|
|
|
def load_font(font_size=40, bold=False, font_paths=None):
    """Return a PIL font, preferring well-known system TrueType files.

    Args:
        font_size: Size passed to ``ImageFont.truetype``.
        bold: Select the bold candidates and bold fallback name when truthy.
        font_paths: Explicit candidate file paths. When ``None``, a built-in
            list of DejaVu / Liberation / FreeFont paths under
            ``/usr/share/fonts/truetype`` is used.

    Returns:
        The first candidate that exists and loads. Otherwise a DejaVu font
        looked up by bare file name via ``ImageFont.truetype``. As a last
        resort, ``ImageFont.load_default()``, which is called without
        ``font_size``.
    """
    if font_paths is None:
        # (regular, bold) pairs, in order of preference.
        regular_and_bold = (
            ('/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',
             '/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf'),
            ('/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf',
             '/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf'),
            ('/usr/share/fonts/truetype/freefont/FreeSans.ttf',
             '/usr/share/fonts/truetype/freefont/FreeSansBold.ttf'),
        )
        font_paths = [pair[1] if bold else pair[0] for pair in regular_and_bold]

    for candidate in font_paths:
        if not os.path.exists(candidate):
            continue
        try:
            return ImageFont.truetype(candidate, font_size)
        except OSError:
            # Unreadable / corrupt font file: try the next candidate.
            pass

    fallback_name = 'DejaVuSans-Bold.ttf' if bold else 'DejaVuSans.ttf'
    try:
        return ImageFont.truetype(fallback_name, font_size)
    except OSError:
        return ImageFont.load_default()
|
|
|
|
|
|
|
|
def draw_text(draw, xy, text, font, fill, bold=False, stroke_width=0):
    """Render ``text`` at ``xy`` using ``draw``, optionally in faux-bold.

    Bold is produced by stroking the glyphs in the fill colour; a
    non-positive ``stroke_width`` becomes 2 when ``bold`` is set. If
    ``draw.text`` raises ``TypeError`` (presumably a Pillow version without
    stroke support), bold is emulated by drawing the text four times at
    1-px offsets. Non-bold text is simply drawn again.
    """
    if bold and stroke_width <= 0:
        stroke_width = 2

    try:
        if not bold:
            draw.text(xy, text, fill=fill, font=font)
        else:
            draw.text(
                xy,
                text,
                fill=fill,
                font=font,
                stroke_width=stroke_width,
                stroke_fill=fill,
            )
    except TypeError:
        if not bold:
            draw.text(xy, text, fill=fill, font=font)
            return
        base_x, base_y = xy[0], xy[1]
        for dx, dy in ((0, 0), (1, 0), (0, 1), (1, 1)):
            draw.text((base_x + dx, base_y + dy), text, fill=fill, font=font)
|
|
|
|
|
|
|
|
def _get_text_size(draw, text, font, bold=False, stroke_width=0): |
|
|
if hasattr(draw, 'textbbox'): |
|
|
try: |
|
|
x0, y0, x1, y1 = draw.textbbox( |
|
|
(0, 0), |
|
|
text, |
|
|
font=font, |
|
|
stroke_width=stroke_width if bold else 0, |
|
|
) |
|
|
except TypeError: |
|
|
x0, y0, x1, y1 = draw.textbbox((0, 0), text, font=font) |
|
|
return x1 - x0, y1 - y0 |
|
|
w, h = draw.textsize(text, font=font) |
|
|
if bold: |
|
|
w += stroke_width * 2 |
|
|
h += stroke_width * 2 |
|
|
return w, h |
|
|
|
|
|
|
|
|
def fit_text_to_width(draw, text, font, max_w, bold=False, stroke_width=0):
    """Shorten ``text`` with a trailing ``'...'`` so it fits within ``max_w`` px.

    ``None`` is treated as the empty string. Widths are measured with
    ``_get_text_size`` using the given font and bold/stroke settings.

    Returns:
        * ``''`` if ``max_w <= 0`` or even ``'...'`` alone is too wide;
        * ``text`` unchanged if it already fits;
        * otherwise the longest prefix of ``text`` followed by ``'...'``
          whose width is at most ``max_w``.
    """
    text = text or ''
    if max_w <= 0:
        return ''

    def measure(s):
        width, _ = _get_text_size(draw, s, font, bold=bold, stroke_width=stroke_width)
        return width

    if measure(text) <= max_w:
        return text

    suffix = '...'
    if measure(suffix) > max_w:
        return ''

    # Drop one trailing character at a time, longest prefix first.
    for cut in range(len(text) - 1, -1, -1):
        candidate = text[:cut] + suffix
        if measure(candidate) <= max_w:
            return candidate
    return suffix
|
|
|
|
|
|
|
|
def bgr_to_rgb(color):
    """Reorder an indexable ``(b, g, r, ...)`` colour into an ``(r, g, b)`` tuple."""
    blue, green, red = color[0], color[1], color[2]
    return (red, green, blue)
|
|
|
|
|
|
|
|
def resize_with_points(img, pts, target_long_side):
    """Resize a PIL image so its long side equals ``target_long_side``.

    Point coordinates are multiplied by the same uniform factor so they stay
    aligned with the resized image. The image may be enlarged or shrunk.

    Args:
        img: Image exposing ``.size`` as ``(width, height)`` and ``.resize``.
        pts: Numpy point array or ``None``; ``None`` and empty arrays are
            returned unchanged.
        target_long_side: Desired long side in pixels; ``None`` or a
            non-positive value disables resizing.

    Returns:
        ``(img, pts)``, which are the original objects when no resize is needed.
    """
    if target_long_side is None or target_long_side <= 0:
        return img, pts

    width, height = img.size
    long_side = max(width, height)
    if long_side <= 0 or long_side == target_long_side:
        return img, pts

    factor = float(target_long_side) / float(long_side)
    target_size = (
        max(1, int(round(width * factor))),
        max(1, int(round(height * factor))),
    )
    resized = img.resize(target_size, Image.BILINEAR)

    has_points = pts is not None and pts.size > 0
    return resized, (pts * factor if has_points else pts)
|
|
|
|
|
|
|
|
def add_padding_with_text(img, text, pad, font, text_color, bg_color, bold, stroke_width):
    """Surround ``img`` with a solid border and write ``text`` in the top band.

    The border has the same width on all four sides. It is at least ``pad``
    px, and is widened if needed to the text height plus a 24 px gap above
    and below. The title starts at the image's left edge, sits roughly
    centred in the top band, and is truncated with ``'...'`` to the image
    width.

    Returns:
        A new RGB canvas of ``bg_color``, or ``img`` itself when ``pad`` is
        ``None`` or ``<= 0``.
    """
    if pad is None or pad <= 0:
        return img

    measuring_draw = ImageDraw.Draw(img)
    text = text or ''
    gap = 24
    _, raw_h = _get_text_size(measuring_draw, text, font, bold=bold, stroke_width=stroke_width)
    border = max(pad, raw_h + 2 * gap)

    canvas_w = img.width + 2 * border
    canvas_h = img.height + 2 * border
    canvas = Image.new('RGB', (canvas_w, canvas_h), color=bg_color)
    canvas.paste(img, (border, border))

    painter = ImageDraw.Draw(canvas)
    # Canvas width minus both borders is exactly the image width.
    shown = fit_text_to_width(
        painter, text, font, max(0, canvas_w - 2 * border),
        bold=bold, stroke_width=stroke_width,
    )
    _, shown_h = _get_text_size(painter, shown, font, bold=bold, stroke_width=stroke_width)

    top = max(gap, (border - shown_h) // 2)
    top = min(top, max(0, border - shown_h - gap))
    draw_text(painter, (border, top), shown, font, text_color, bold=bold, stroke_width=stroke_width)
    return canvas
|
|
|
|
|
|
|
|
def annotate_panel(
    img_bgr,
    pts,
    title_text,
    point_color_bgr,
    radius,
    font,
    text_color,
    title_bg,
    target_long_side,
    pad,
):
    """Draw predicted points on an image and frame it with a titled border.

    Steps:
      1. Convert the OpenCV BGR array to a PIL RGB image.
      2. Resize so the long side is ``target_long_side``, rescaling points to match.
      3. Draw each ``(x, y)`` point as a filled disc.
      4. Add the padded title band via ``add_padding_with_text``.

    The disc radius is never below an automatic minimum: 0.4% of the long
    side after resizing, and at least 3 px. A ``None`` or smaller ``radius``
    is raised to that minimum.

    Returns:
        The annotated PIL RGB image.
    """
    panel = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    panel, pts = resize_with_points(panel, pts, target_long_side)
    canvas = ImageDraw.Draw(panel)

    min_radius = max(3, int(round(max(panel.width, panel.height) * 0.004)))
    if radius is None or radius < min_radius:
        radius = min_radius

    if pts is not None and pts.size > 0:
        fill_rgb = bgr_to_rgb(point_color_bgr)
        for px, py in pts:
            box = (px - radius, py - radius, px + radius, py + radius)
            canvas.ellipse(box, fill=fill_rgb, outline=fill_rgb)

    return add_padding_with_text(
        panel,
        title_text or '',
        pad,
        font,
        text_color,
        title_bg,
        bold=False,
        stroke_width=0,
    )
|
|
|
|
|
|
|
|
def _load_state_dict(weight_path: Path): |
|
|
if not weight_path.exists(): |
|
|
raise FileNotFoundError(f'Weights file not found: {weight_path}') |
|
|
|
|
|
if weight_path.suffix == '.safetensors': |
|
|
try: |
|
|
from safetensors.torch import load_file as load_safetensors |
|
|
except ImportError as exc: |
|
|
raise ImportError( |
|
|
'safetensors is required to load .safetensors weights. Install with: pip install safetensors' |
|
|
) from exc |
|
|
return load_safetensors(str(weight_path), device='cpu') |
|
|
|
|
|
checkpoint = torch.load(str(weight_path), map_location='cpu') |
|
|
if isinstance(checkpoint, dict) and 'model' in checkpoint and isinstance(checkpoint['model'], dict): |
|
|
return checkpoint['model'] |
|
|
if isinstance(checkpoint, dict) and checkpoint and all(torch.is_tensor(v) for v in checkpoint.values()): |
|
|
return checkpoint |
|
|
raise ValueError( |
|
|
'Unsupported checkpoint format. Expected .safetensors or .pth containing a model state_dict.' |
|
|
) |
|
|
|
|
|
|
|
|
@torch.no_grad()
def infer_pet_points(model, frame_bgr, device, upper_bound):
    """Run PET on one BGR image and return predicted points in image pixels.

    Args:
        model: Model from ``build_model``, called as ``model(samples, test=True)``.
        frame_bgr: Image array in OpenCV BGR channel order.
        device: Device the input batch is moved to.
        upper_bound: Long-side bound forwarded to ``resize_for_eval``.

    Returns:
        ``(points_xy, scale)``:
          * ``points_xy`` is an ``(N, 2)`` float array of ``(x, y)``
            coordinates in the ORIGINAL image frame, clipped to its bounds;
          * ``scale`` is the factor ``resize_for_eval`` applied before
            inference.
    """
    orig_h, orig_w = frame_bgr.shape[:2]
    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    resized_rgb, scale = resize_for_eval(frame_rgb, upper_bound)
    resized_h, resized_w = resized_rgb.shape[:2]

    tensor = PET_TRANSFORM(Image.fromarray(resized_rgb))
    samples = utils.nested_tensor_from_tensor_list([tensor]).to(device)
    # Predictions are presumably normalised to the batch tensor's spatial
    # size (which may include padding), so de-normalise with that size.
    img_h, img_w = samples.tensors.shape[-2:]

    outputs = model(samples, test=True)
    pred = outputs['pred_points']
    if pred.dim() == 3:
        pred = pred[0]  # batch of one: drop the batch dimension
    pred_points = pred.detach().cpu().numpy()

    if pred_points.size == 0:
        return np.zeros((0, 2), dtype=np.float32), scale

    # Column 0 is scaled by height and column 1 by width, i.e. (y, x) order.
    # Clip to the unpadded resized image.
    pred_points[:, 0] = np.clip(pred_points[:, 0] * float(img_h), 0.0, float(resized_h - 1))
    pred_points[:, 1] = np.clip(pred_points[:, 1] * float(img_w), 0.0, float(resized_w - 1))

    # Map back with the per-axis ratios the resize actually produced. The
    # rounded output size can differ slightly from `scale` on each axis, so
    # dividing by `scale` alone leaves a sub-pixel offset. When no resize
    # happened both ratios are exactly 1.0.
    pred_points[:, 0] *= float(orig_h) / float(resized_h)
    pred_points[:, 1] *= float(orig_w) / float(resized_w)

    pred_points[:, 0] = np.clip(pred_points[:, 0], 0.0, float(orig_h - 1))
    pred_points[:, 1] = np.clip(pred_points[:, 1], 0.0, float(orig_w - 1))

    points_xy = np.stack([pred_points[:, 1], pred_points[:, 0]], axis=1)
    return points_xy, scale
|
|
|
|
|
|
|
|
def main(args) -> None:
    """Run single-image PET inference from parsed command-line ``args``.

    Prints the image path and predicted count. When requested, it also
    writes a JSON summary (``--output_json``) and an annotated panel image
    (``--output_image``).

    Raises:
        ValueError: the input image cannot be read, ``--point_color`` is
            malformed (only checked when ``--output_image`` is set), or the
            checkpoint format is unsupported.
        FileNotFoundError / ImportError: propagated from weight loading.
    """
    import random  # only needed for seeding below

    # --seed was parsed but never applied. Seed every RNG the script or the
    # model code may draw from so that repeated runs are reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed % (2 ** 32))  # numpy requires 0 <= seed < 2**32
    torch.manual_seed(args.seed)

    device = resolve_device(args.device)

    # Cheap input checks first, so mistakes surface before the (expensive)
    # model build, weight load and inference.
    image_path = Path(args.image_path)
    frame_bgr = cv2.imread(str(image_path))
    if frame_bgr is None:
        raise ValueError(f'Failed to read image: {image_path}')
    point_color = parse_color(args.point_color) if args.output_image else None

    model, _ = build_model(args)
    model.to(device)
    model.eval()

    state_dict = _load_state_dict(Path(args.resume))
    model.load_state_dict(state_dict, strict=True)

    points_xy, scale = infer_pet_points(model, frame_bgr, device, args.upper_bound)
    count = int(points_xy.shape[0])  # (0, 2) when nothing is detected

    result = {
        'image': str(image_path),
        'count': count,
        'points_xy': points_xy.tolist(),
        'scale': scale,
    }

    print(f'image: {result["image"]}')
    print(f'predicted_count: {result["count"]}')

    if args.output_json:
        output_json = Path(args.output_json)
        output_json.parent.mkdir(parents=True, exist_ok=True)
        output_json.write_text(json.dumps(result, indent=2), encoding='utf-8')
        print(f'json_saved_to: {output_json}')

    if args.output_image:
        output_image = Path(args.output_image)
        output_image.parent.mkdir(parents=True, exist_ok=True)

        panel = annotate_panel(
            frame_bgr,
            points_xy,
            f'{args.title_text} Count : {count}',
            point_color,
            args.radius,
            load_font(font_size=args.panel_font_size, bold=False),
            text_color=(0, 0, 0),
            title_bg=(255, 255, 255),
            target_long_side=args.panel_long_side,
            pad=args.panel_pad,
        )
        panel.save(str(output_image))
        print(f'annotated_image_saved_to: {output_image}')
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # The real CLI parser inherits every option from get_args_parser(), which
    # was built with add_help=False, and adds the standard -h/--help flag.
    cli_parser = argparse.ArgumentParser(
        'PET single-image inference', parents=[get_args_parser()]
    )
    cli_args = cli_parser.parse_args()
    main(cli_args)
|
|
|