"""Single-image crowd-counting inference for PET (Point-quEry Transformer).

Loads a PET checkpoint (``.safetensors`` or ``.pth``), runs one image through
the model, prints the predicted head count, and optionally writes:

* a JSON file with the predicted (x, y) point coordinates, and
* an annotated image panel (points drawn on the image, count in the title).
"""

import argparse
import json
import os
from pathlib import Path

import cv2
import numpy as np
import torch
import torchvision.transforms as standard_transforms
from PIL import Image, ImageDraw, ImageFont

import util.misc as utils
from models import build_model

# ImageNet normalization matching the PET backbone's training preprocessing.
PET_TRANSFORM = standard_transforms.Compose([
    standard_transforms.ToTensor(),
    standard_transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225]),
])


def get_args_parser() -> argparse.ArgumentParser:
    """Build the CLI parser.

    The model-architecture arguments (backbone, transformer dims, loss
    coefficients, dataset fields) must mirror the configuration used to
    train the checkpoint passed via ``--resume``; they are consumed by
    ``build_model``.
    """
    parser = argparse.ArgumentParser('PET single-image inference (HF release)', add_help=False)
    # I/O
    parser.add_argument('--image_path', required=True, type=str,
                        help='Path to a single input image.')
    parser.add_argument('--resume', default='PET_Finetuned.safetensors', type=str,
                        help='Path to model weights (.safetensors or .pth).')
    parser.add_argument('--device', default='cuda', type=str,
                        help='Device for inference, e.g. cuda or cpu.')
    # Model architecture (must match the trained checkpoint).
    parser.add_argument('--backbone', default='vgg16_bn', type=str)
    parser.add_argument('--position_embedding', default='sine', type=str,
                        choices=('sine', 'learned', 'fourier'))
    parser.add_argument('--dec_layers', default=2, type=int)
    parser.add_argument('--dim_feedforward', default=512, type=int)
    parser.add_argument('--hidden_dim', default=256, type=int)
    parser.add_argument('--dropout', default=0.0, type=float)
    parser.add_argument('--nheads', default=8, type=int)
    # Matcher / loss coefficients (required by build_model even at test time).
    parser.add_argument('--set_cost_class', default=1, type=float)
    parser.add_argument('--set_cost_point', default=0.05, type=float)
    parser.add_argument('--ce_loss_coef', default=1.0, type=float)
    parser.add_argument('--point_loss_coef', default=5.0, type=float)
    parser.add_argument('--eos_coef', default=0.5, type=float)
    parser.add_argument('--dataset_file', default='SHA')
    parser.add_argument('--data_path', default='./data/ShanghaiTech/PartA', type=str)
    parser.add_argument('--upper_bound', default=-1, type=int,
                        help='Max image side for inference; -1 means only cap at 2560 (same as compare_models).')
    # Visualization / output options.
    parser.add_argument('--output_image', default='', type=str,
                        help='Optional path to save annotated image panel.')
    parser.add_argument('--title_text', default='PET-Finetuned', type=str,
                        help='Title prefix used in top panel text.')
    parser.add_argument('--radius', default=3, type=int)
    parser.add_argument('--point_color', default='0,255,0', type=str,
                        help='BGR color for points, e.g., 0,255,0')
    parser.add_argument('--panel_long_side', default=1600, type=int,
                        help='Resize annotated panel long side to this value.')
    parser.add_argument('--panel_pad', default=24, type=int,
                        help='Panel padding around the image and title area.')
    parser.add_argument('--panel_font_size', default=48, type=int,
                        help='Font size for panel title text.')
    parser.add_argument('--output_json', default='', type=str,
                        help='Optional output JSON path for prediction details.')
    parser.add_argument('--seed', default=42, type=int)
    return parser


def parse_color(color_str: str):
    """Parse a 'B,G,R' string into a tuple of three ints in [0, 255].

    Raises:
        ValueError: if the string does not have exactly three components
            or any channel is outside the valid byte range.
    """
    parts = color_str.split(',')
    if len(parts) != 3:
        raise ValueError('color must be B,G,R like 0,255,0')
    color = tuple(int(p.strip()) for p in parts)
    # Guard against out-of-range channels before handing the color to
    # PIL/OpenCV, where they would silently misbehave.
    if any(c < 0 or c > 255 for c in color):
        raise ValueError('color channels must be integers in [0, 255]')
    return color


def resolve_device(device_str: str) -> torch.device:
    """Resolve a device string, falling back to CPU when CUDA is missing.

    Also makes the chosen CUDA device current when an explicit index
    (e.g. ``cuda:1``) is requested.
    """
    if device_str.startswith('cuda') and not torch.cuda.is_available():
        print('CUDA not available. Falling back to CPU.')
        return torch.device('cpu')
    device = torch.device(device_str)
    if device.type == 'cuda' and device.index is not None:
        torch.cuda.set_device(device.index)
    return device


def resize_for_eval(frame_rgb, upper_bound):
    """Downscale an RGB frame for evaluation.

    When ``upper_bound`` != -1, the longest side is capped at that value;
    otherwise it is only capped at 2560 px. Returns ``(image, scale)``
    where ``scale`` is the applied factor (1.0 means no resize).
    """
    h, w = frame_rgb.shape[:2]
    max_size = max(h, w)
    if upper_bound != -1 and max_size > upper_bound:
        scale = float(upper_bound) / float(max_size)
    elif max_size > 2560:
        scale = 2560.0 / float(max_size)
    else:
        scale = 1.0
    if scale == 1.0:
        return frame_rgb, scale
    new_w = max(1, int(round(w * scale)))
    new_h = max(1, int(round(h * scale)))
    resized = cv2.resize(frame_rgb, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    return resized, scale


def load_font(font_size=40, bold=False, font_paths=None):
    """Load a TrueType font, trying common Linux font paths first.

    Falls back to Pillow's built-in bitmap font when no TTF can be loaded,
    so panel rendering never fails outright.
    """
    if font_paths is None:
        if bold:
            font_paths = [
                '/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf',
                '/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf',
                '/usr/share/fonts/truetype/freefont/FreeSansBold.ttf',
            ]
        else:
            font_paths = [
                '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',
                '/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf',
                '/usr/share/fonts/truetype/freefont/FreeSans.ttf',
            ]
    for font_path in font_paths:
        if os.path.exists(font_path):
            try:
                return ImageFont.truetype(font_path, font_size)
            except OSError:
                continue
    # Let Pillow search its own bundled/system font locations by name.
    try:
        fallback = 'DejaVuSans-Bold.ttf' if bold else 'DejaVuSans.ttf'
        return ImageFont.truetype(fallback, font_size)
    except OSError:
        return ImageFont.load_default()


def draw_text(draw, xy, text, font, fill, bold=False, stroke_width=0):
    """Draw text, emulating bold via stroke (or offsets on old Pillow).

    Older Pillow versions raise TypeError for ``stroke_width``; in that
    case bold is approximated by drawing the text at 4 pixel offsets.
    """
    if bold and stroke_width <= 0:
        stroke_width = 2
    try:
        if bold:
            draw.text(xy, text, fill=fill, font=font,
                      stroke_width=stroke_width, stroke_fill=fill)
        else:
            draw.text(xy, text, fill=fill, font=font)
    except TypeError:
        if bold:
            offsets = [(0, 0), (1, 0), (0, 1), (1, 1)]
            for dx, dy in offsets:
                draw.text((xy[0] + dx, xy[1] + dy), text, fill=fill, font=font)
        else:
            draw.text(xy, text, fill=fill, font=font)


def _get_text_size(draw, text, font, bold=False, stroke_width=0):
    """Return (width, height) of rendered text across Pillow versions.

    Prefers ``textbbox`` (Pillow >= 8); falls back to the removed-in-10
    ``textsize`` API, padding for the bold stroke manually.
    """
    if hasattr(draw, 'textbbox'):
        try:
            x0, y0, x1, y1 = draw.textbbox(
                (0, 0), text, font=font,
                stroke_width=stroke_width if bold else 0,
            )
        except TypeError:
            x0, y0, x1, y1 = draw.textbbox((0, 0), text, font=font)
        return x1 - x0, y1 - y0
    w, h = draw.textsize(text, font=font)
    if bold:
        w += stroke_width * 2
        h += stroke_width * 2
    return w, h


def fit_text_to_width(draw, text, font, max_w, bold=False, stroke_width=0):
    """Truncate ``text`` with a trailing '...' so it fits within ``max_w`` px.

    Returns the original text when it already fits, and '' when even the
    ellipsis alone would overflow.
    """
    text = text or ''
    if max_w <= 0:
        return ''
    text_w, _ = _get_text_size(draw, text, font, bold=bold, stroke_width=stroke_width)
    if text_w <= max_w:
        return text
    ellipsis = '...'
    ellipsis_w, _ = _get_text_size(draw, ellipsis, font, bold=bold, stroke_width=stroke_width)
    if ellipsis_w > max_w:
        return ''
    trimmed = text
    while trimmed:
        trimmed = trimmed[:-1]
        candidate = trimmed + ellipsis
        cand_w, _ = _get_text_size(draw, candidate, font, bold=bold, stroke_width=stroke_width)
        if cand_w <= max_w:
            return candidate
    return ellipsis


def bgr_to_rgb(color):
    """Swap a BGR tuple into RGB order (PIL expects RGB)."""
    return (color[2], color[1], color[0])


def resize_with_points(img, pts, target_long_side):
    """Resize a PIL image so its long side equals ``target_long_side``,
    scaling the (x, y) point array by the same factor.

    Returns the inputs unchanged when the target is non-positive or
    already matched.
    """
    if target_long_side is None or target_long_side <= 0:
        return img, pts
    w, h = img.size
    max_dim = max(w, h)
    if max_dim <= 0 or max_dim == target_long_side:
        return img, pts
    scale = float(target_long_side) / float(max_dim)
    new_w = max(1, int(round(w * scale)))
    new_h = max(1, int(round(h * scale)))
    img = img.resize((new_w, new_h), Image.BILINEAR)
    if pts is not None and pts.size > 0:
        pts = pts * scale
    return img, pts


def add_padding_with_text(img, text, pad, font, text_color, bg_color, bold, stroke_width):
    """Pad the image on all sides and draw a title in the top margin.

    The padding is enlarged as needed so the title (plus a minimum gap
    above and below it) always fits; the title is ellipsized to the
    canvas width.
    """
    if pad is None or pad <= 0:
        return img
    draw_tmp = ImageDraw.Draw(img)
    text = text or ''
    text_w, text_h = _get_text_size(draw_tmp, text, font, bold=bold, stroke_width=stroke_width)
    min_text_gap = 24
    # Ensure the top margin can hold the text with a gap on both sides.
    min_pad = text_h + (2 * min_text_gap)
    pad = max(pad, min_pad)
    new_w = img.width + pad * 2
    new_h = img.height + pad * 2
    canvas = Image.new('RGB', (new_w, new_h), color=bg_color)
    canvas.paste(img, (pad, pad))
    draw = ImageDraw.Draw(canvas)
    max_text_w = max(0, new_w - (2 * pad))
    text = fit_text_to_width(draw, text, font, max_text_w, bold=bold, stroke_width=stroke_width)
    text_w, text_h = _get_text_size(draw, text, font, bold=bold, stroke_width=stroke_width)
    text_x = pad
    # Center vertically in the top margin, but keep at least the minimum gap.
    text_y = max(min_text_gap, (pad - text_h) // 2)
    text_y = min(text_y, max(0, pad - text_h - min_text_gap))
    draw_text(draw, (text_x, text_y), text, font, text_color, bold=bold, stroke_width=stroke_width)
    return canvas


def annotate_panel(
    img_bgr,
    pts,
    title_text,
    point_color_bgr,
    radius,
    font,
    text_color,
    title_bg,
    target_long_side,
    pad,
):
    """Render the annotated output panel.

    Converts the BGR frame to a PIL image, resizes it (scaling the points
    with it), draws one filled circle per predicted point, and wraps the
    result in a padded canvas with a title. Returns a PIL.Image.
    """
    rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    img = Image.fromarray(rgb)
    img, pts = resize_with_points(img, pts, target_long_side)
    draw = ImageDraw.Draw(img)
    # Scale the dot radius with image size so points stay visible.
    max_dim = max(img.width, img.height)
    auto_radius = max(3, int(round(max_dim * 0.004)))
    if radius is None or radius < auto_radius:
        radius = auto_radius
    if pts is not None and pts.size > 0:
        color = bgr_to_rgb(point_color_bgr)
        for x, y in pts:
            x0 = x - radius
            y0 = y - radius
            x1 = x + radius
            y1 = y + radius
            draw.ellipse((x0, y0, x1, y1), fill=color, outline=color)
    return add_padding_with_text(
        img,
        title_text or '',
        pad,
        font,
        text_color,
        title_bg,
        bold=False,
        stroke_width=0,
    )


def _load_state_dict(weight_path: Path):
    """Load model weights from ``.safetensors`` or a ``.pth`` checkpoint.

    Accepts either a raw tensor-only state dict or a checkpoint dict with
    a ``'model'`` entry. Always maps to CPU; the caller moves the model.

    Raises:
        FileNotFoundError: when the file does not exist.
        ImportError: when safetensors is needed but not installed.
        ValueError: for unrecognized checkpoint formats.
    """
    if not weight_path.exists():
        raise FileNotFoundError(f'Weights file not found: {weight_path}')
    if weight_path.suffix == '.safetensors':
        try:
            from safetensors.torch import load_file as load_safetensors
        except ImportError as exc:
            raise ImportError(
                'safetensors is required to load .safetensors weights. '
                'Install with: pip install safetensors'
            ) from exc
        return load_safetensors(str(weight_path), device='cpu')
    checkpoint = torch.load(str(weight_path), map_location='cpu')
    if isinstance(checkpoint, dict) and 'model' in checkpoint and isinstance(checkpoint['model'], dict):
        return checkpoint['model']
    if isinstance(checkpoint, dict) and checkpoint and all(torch.is_tensor(v) for v in checkpoint.values()):
        return checkpoint
    raise ValueError(
        'Unsupported checkpoint format. Expected .safetensors or .pth containing a model state_dict.'
    )


@torch.no_grad()
def infer_pet_points(model, frame_bgr, device, upper_bound):
    """Run PET on one BGR frame and return point predictions.

    Returns:
        points_xy: float array of shape (N, 2) in (x, y) pixel coordinates
            of the ORIGINAL frame (predictions are rescaled back).
        scale: the resize factor that was applied before inference.
    """
    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    resized_rgb, scale = resize_for_eval(frame_rgb, upper_bound)
    resized_h, resized_w = resized_rgb.shape[:2]
    img = Image.fromarray(resized_rgb)
    img = PET_TRANSFORM(img)
    samples = utils.nested_tensor_from_tensor_list([img]).to(device)
    # Padded tensor dims; may exceed the resized frame after batching.
    img_h, img_w = samples.tensors.shape[-2:]
    outputs = model(samples, test=True)
    outputs_points = outputs['pred_points']
    if outputs_points.dim() == 3:
        outputs_points = outputs_points[0]
    pred_points = outputs_points.detach().cpu().numpy()
    if pred_points.size == 0:
        return np.zeros((0, 2), dtype=np.float32), scale
    # Points come out normalized as (row, col) — scale by the padded tensor
    # size, then clip to the actual resized frame. (TODO confirm against
    # PET's own eval code.)
    pred_points[:, 0] *= float(img_h)
    pred_points[:, 1] *= float(img_w)
    pred_points[:, 0] = np.clip(pred_points[:, 0], 0.0, float(resized_h - 1))
    pred_points[:, 1] = np.clip(pred_points[:, 1], 0.0, float(resized_w - 1))
    if scale != 1.0:
        # Map back to original-resolution coordinates.
        pred_points = pred_points / float(scale)
    orig_h, orig_w = frame_bgr.shape[:2]
    pred_points[:, 0] = np.clip(pred_points[:, 0], 0.0, float(orig_h - 1))
    pred_points[:, 1] = np.clip(pred_points[:, 1], 0.0, float(orig_w - 1))
    # Reorder (row, col) -> (x, y).
    points_xy = np.stack([pred_points[:, 1], pred_points[:, 0]], axis=1)
    return points_xy, scale


def main(args) -> None:
    """Build the model, run inference on one image, and emit outputs."""
    # Seed for reproducibility of any stochastic ops (the --seed flag was
    # previously parsed but never applied).
    torch.manual_seed(args.seed)
    device = resolve_device(args.device)
    model, _ = build_model(args)
    model.to(device)
    model.eval()
    state_dict = _load_state_dict(Path(args.resume))
    model.load_state_dict(state_dict, strict=True)

    image_path = Path(args.image_path)
    frame_bgr = cv2.imread(str(image_path))
    if frame_bgr is None:
        raise ValueError(f'Failed to read image: {image_path}')

    points_xy, scale = infer_pet_points(model, frame_bgr, device, args.upper_bound)
    count = int(points_xy.shape[0])  # shape[0] is already 0 for empty arrays
    result = {
        'image': str(image_path),
        'count': count,
        'points_xy': points_xy.tolist(),
        'scale': scale,
    }
    print(f'image: {result["image"]}')
    print(f'predicted_count: {result["count"]}')

    if args.output_json:
        output_json = Path(args.output_json)
        output_json.parent.mkdir(parents=True, exist_ok=True)
        output_json.write_text(json.dumps(result, indent=2), encoding='utf-8')
        print(f'json_saved_to: {output_json}')

    if args.output_image:
        output_image = Path(args.output_image)
        output_image.parent.mkdir(parents=True, exist_ok=True)
        panel = annotate_panel(
            frame_bgr,
            points_xy,
            f'{args.title_text} Count : {count}',
            parse_color(args.point_color),
            args.radius,
            load_font(font_size=args.panel_font_size, bold=False),
            text_color=(0, 0, 0),
            title_bg=(255, 255, 255),
            target_long_side=args.panel_long_side,
            pad=args.panel_pad,
        )
        panel.save(str(output_image))
        print(f'annotated_image_saved_to: {output_image}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        'PET single-image inference',
        parents=[get_args_parser()],
    )
    main(parser.parse_args())