Spaces:
Sleeping
Sleeping
| import argparse | |
| import csv | |
| import json | |
| import os | |
| import time | |
| from itertools import product | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torchvision.transforms as standard_transforms | |
| from PIL import Image | |
| from scipy.spatial import cKDTree | |
| from models import build_model | |
class Args:
    """Fixed configuration object handed to ``build_model``.

    The values are class attributes, so ``Args()`` needs no constructor
    arguments.
    """

    # Name of the backbone network requested from ``build_model``.
    backbone = "vgg16_bn"
    # NOTE(review): presumably the anchor-point grid size (rows x lines per
    # feature cell) expected by the model builder - confirm against ``models``.
    row = line = 2
def load_model(weight_path):
    """Build the model, load its checkpoint and create the input transform.

    Args:
        weight_path: Path to a checkpoint file. The state dict is read from
            the checkpoint's ``"model"`` entry.

    Returns:
        A ``(model, device, transform)`` tuple: the model in eval mode on the
        selected device, the ``torch.device`` in use (CUDA when available,
        otherwise CPU), and the tensor/normalisation transform to apply to
        each input image.

    Warns:
        RuntimeWarning: If ``weight_path`` does not exist. The model is still
            returned, but without the checkpoint applied, so any counts it
            produces do not reflect the intended weights.
    """
    import warnings  # local import: keeps the file's import block untouched

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type == "cuda":
        # Let cuDNN pick convolution algorithms by benchmarking them.
        torch.backends.cudnn.benchmark = True

    model = build_model(Args()).to(device).eval()
    if os.path.exists(weight_path):
        checkpoint = torch.load(weight_path, map_location=device)
        model.load_state_dict(checkpoint["model"])
    else:
        # Previously this case was silent, which made a mistyped path look
        # like a successful run. Keep going (best effort) but say so.
        warnings.warn(
            f"Weights file not found: {weight_path!r}; "
            "continuing without loading a checkpoint.",
            RuntimeWarning,
            stacklevel=2,
        )

    # Per-channel mean/std are the usual ImageNet normalisation statistics.
    transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225]),
    ])
    return model, device, transform
def merge_points(points, radius=8.0):
    """Collapse points that lie within ``radius`` of each other.

    Args:
        points: Sequence of coordinate pairs (converted to a float32 array).
        radius: Maximum distance at which two points count as duplicates.

    Returns:
        A list of coordinate lists, in input order, with suppressed
        duplicates removed. An empty input yields an empty list.
    """
    if not points:
        return []

    coords = np.array(points, dtype=np.float32)
    dropped = set()
    # Greedy suppression: for every close pair where both members are still
    # alive, discard the second member of the pair.
    for first, second in cKDTree(coords).query_pairs(r=radius):
        if first in dropped or second in dropped:
            continue
        dropped.add(second)

    return [
        row.tolist()
        for index, row in enumerate(coords)
        if index not in dropped
    ]
def _patch_origins(padded_len, content_end, patch_size, stride):
    """Return patch start offsets along one axis of the padded canvas."""
    origins = list(range(0, padded_len - patch_size + 1, stride))
    # A stride that does not divide the canvas leaves a strip at the far edge
    # uncovered. Add one edge-aligned patch, but only when that strip still
    # contains image content (not just padding).
    if origins and origins[-1] + patch_size < content_end:
        origins.append(padded_len - patch_size)
    return origins


def infer_frame(image, model, device, transform, confidence, magnification, batch_size, patch_overlap):
    """Detect points in one frame by tiling it into fixed-size patches.

    The frame is resized by ``magnification`` (with the longer side capped at
    3840 px), pasted onto a black canvas with a 256 px border, cut into
    512x512 patches and run through ``model`` in batches. Detections are
    mapped back to original-frame coordinates, clipped to the frame and
    de-duplicated with ``merge_points``.

    Args:
        image: PIL image of the frame (its ``size`` gives the original
            width and height).
        model: Callable returning a mapping with ``"pred_logits"`` and
            ``"pred_points"`` for a batch of patches.
        device: ``torch.device`` the batch is moved to; autocast is used
            when it is a CUDA device.
        transform: Callable turning a PIL patch into a tensor.
        confidence: Score threshold; a point is kept when its softmax score
            (index 1 of the last logits axis) exceeds it.
        magnification: Requested upscaling factor for the frame.
        batch_size: Number of patches per forward pass.
        patch_overlap: Fraction of a patch shared with its neighbour; the
            stride is ``512 * (1 - patch_overlap)``, kept within [64, 512].

    Returns:
        List of ``[x, y]`` points in original-frame pixel coordinates.
    """
    orig_w, orig_h = image.size
    patch_size = 512
    pad = 256

    # Clamp to at least one pixel so a tiny magnification cannot produce a
    # zero-sized working image (which previously divided by zero below).
    work_w = max(1, int(orig_w * magnification))
    work_h = max(1, int(orig_h * magnification))
    scale = min(1.0, 3840 / float(max(work_w, work_h)))
    work_w = max(1, int(work_w * scale))
    work_h = max(1, int(work_h * scale))

    # Integer truncation can make the two axes scale slightly differently,
    # so keep one factor per axis for mapping detections back.
    scale_x = work_w / float(orig_w)
    scale_y = work_h / float(orig_h)

    # Pillow >= 9.1 exposes filters under Image.Resampling; older releases
    # use Image.ANTIALIAS (numeric value 1 as last resort).
    resampling = getattr(Image, "Resampling", None)
    if resampling is not None:
        resample_filter = resampling.LANCZOS
    else:
        resample_filter = getattr(Image, "ANTIALIAS", 1)
    image = image.resize((work_w, work_h), resample_filter)

    # Canvas = image + border on every side, rounded up to whole patches.
    padded_w = ((work_w + pad * 2 + patch_size - 1) // patch_size) * patch_size
    padded_h = ((work_h + pad * 2 + patch_size - 1) // patch_size) * patch_size
    padded = Image.new("RGB", (padded_w, padded_h), (0, 0, 0))
    padded.paste(image, (pad, pad))

    # Never step further than one patch, otherwise gaps appear between tiles.
    stride = min(patch_size, max(64, int(patch_size * (1.0 - patch_overlap))))
    xs = _patch_origins(padded_w, pad + work_w, patch_size, stride)
    ys = _patch_origins(padded_h, pad + work_h, patch_size, stride)
    origins = [(x, y) for y in ys for x in xs]

    all_points = []
    for start in range(0, len(origins), batch_size):
        batch = origins[start:start + batch_size]
        # Crop lazily per batch instead of holding every patch in memory.
        samples = torch.stack([
            transform(padded.crop((x, y, x + patch_size, y + patch_size)))
            for x, y in batch
        ]).to(device)

        with torch.inference_mode():
            if device.type == "cuda":
                with torch.cuda.amp.autocast():
                    out = model(samples)
            else:
                out = model(samples)
            scores = torch.nn.functional.softmax(out["pred_logits"].float(), -1)[:, :, 1]
            points = out["pred_points"].float()

            for idx, (x, y) in enumerate(batch):
                selected = points[idx][scores[idx] > confidence].detach().cpu().numpy()
                if not len(selected):
                    continue
                # NOTE(review): assumes pred_points are (x, y) pixel positions
                # relative to the patch origin - confirm against the model.
                selected[:, 0] += x - pad
                selected[:, 1] += y - pad
                selected[:, 0] /= scale_x
                selected[:, 1] /= scale_y
                inside = (
                    (selected[:, 0] >= 0) & (selected[:, 0] < orig_w)
                    & (selected[:, 1] >= 0) & (selected[:, 1] < orig_h)
                )
                all_points.extend(selected[inside].tolist())

    return merge_points(all_points)
def run_config(video, model, device, transform, cfg, max_frames):
    """Run one benchmark configuration over a video and summarise it.

    Args:
        video: Path (or other source) accepted by ``cv2.VideoCapture``.
        model: Model passed through to ``infer_frame``.
        device: Torch device passed through to ``infer_frame``.
        transform: Patch transform passed through to ``infer_frame``.
        cfg: Mapping with the keys ``frame_skip``, ``confidence``,
            ``magnification``, ``batch_size`` and ``patch_overlap``.
        max_frames: Stop after this many frames have been read; a falsy
            value (``0``/``None``) means no limit.

    Returns:
        A dict containing every entry of ``cfg`` plus frame totals, elapsed
        seconds, frames-per-second figures and count statistics (average,
        maximum, standard deviation) over the analysed frames.
    """
    cap = cv2.VideoCapture(video)
    frames_read = 0
    frames_analyzed = 0
    counts = []
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        start = time.perf_counter()
        while cap.isOpened():
            # Check the limit first so no frame is decoded just to be dropped.
            if max_frames and frames_read >= max_frames:
                break
            ret, frame = cap.read()
            if not ret:
                break
            # Only every ``frame_skip``-th frame is sent through the model.
            if frames_read % cfg["frame_skip"] == 0:
                image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                points = infer_frame(
                    image, model, device, transform,
                    cfg["confidence"], cfg["magnification"],
                    cfg["batch_size"], cfg["patch_overlap"],
                )
                counts.append(len(points))
                frames_analyzed += 1
            frames_read += 1
    finally:
        # Release the capture even when inference raises.
        cap.release()
    elapsed = time.perf_counter() - start

    return {
        **cfg,
        "video_frames": total_frames,
        "frames_read": frames_read,
        "frames_analyzed": frames_analyzed,
        "elapsed_sec": round(elapsed, 4),
        "effective_fps": round(frames_read / elapsed, 4) if elapsed else 0,
        "analysis_fps": round(frames_analyzed / elapsed, 4) if elapsed else 0,
        "avg_count": round(float(np.mean(counts)), 4) if counts else 0,
        "max_count": int(max(counts)) if counts else 0,
        "std_count": round(float(np.std(counts)), 4) if counts else 0,
    }
def recommendations(rows):
    """Pick one result row for each of three presets.

    Args:
        rows: Non-empty sequence of result dicts with the keys
            ``effective_fps``, ``std_count``, ``patch_overlap``,
            ``magnification`` and ``frame_skip``.

    Returns:
        Dict with the keys ``"fast"`` (highest effective FPS), ``"balanced"``
        (lowest count standard deviation, ties broken by higher FPS) and
        ``"accurate"`` (largest overlap, then largest magnification, then
        smallest frame skip). Values are the row objects themselves.

    Raises:
        ValueError: If ``rows`` is empty.
    """
    def speed(entry):
        return entry["effective_fps"]

    def stability(entry):
        return entry["std_count"], -entry["effective_fps"]

    def fidelity(entry):
        return entry["patch_overlap"], entry["magnification"], -entry["frame_skip"]

    fastest = max(rows, key=speed)
    steadiest = min(rows, key=stability)
    finest = max(rows, key=fidelity)
    return {"fast": fastest, "balanced": steadiest, "accurate": finest}
def main():
    """Benchmark a grid of inference settings on one video and save results.

    Command-line options: ``--video`` (required), ``--weights``,
    ``--output_dir`` and ``--max_frames``. Every combination of frame skip,
    magnification, batch size and confidence is run; the per-configuration
    rows are written as CSV and JSON into the output directory, and a JSON
    summary with the output paths and recommended rows is printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--video", required=True)
    parser.add_argument("--weights", default=os.path.join("weights", "SHTechA.pth"))
    parser.add_argument("--output_dir", default="benchmark_results")
    parser.add_argument("--max_frames", type=int, default=120)
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    model, device, transform = load_model(args.weights)

    # Patch overlap is tied to the frame skip: denser temporal sampling gets
    # denser spatial tiling; any other skip value uses no overlap.
    overlap_by_skip = {1: 0.5, 2: 0.25}
    grid = product([1, 2, 5], [1.0, 1.5, 2.0], [4, 8, 16], [0.45, 0.5, 0.55])
    configs = [
        {
            "frame_skip": frame_skip,
            "magnification": mag,
            "batch_size": batch,
            "confidence": conf,
            "patch_overlap": overlap_by_skip.get(frame_skip, 0.0),
        }
        for frame_skip, mag, batch, conf in grid
    ]

    rows = []
    for cfg in configs:
        rows.append(run_config(args.video, model, device, transform, cfg, args.max_frames))
    recs = recommendations(rows)

    paths = {
        "csv": os.path.join(args.output_dir, "benchmark_results.csv"),
        "json": os.path.join(args.output_dir, "benchmark_results.json"),
    }

    with open(paths["csv"], "w", newline="", encoding="utf-8") as f:
        # Column order follows the key order of the first result row.
        writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)

    with open(paths["json"], "w", encoding="utf-8") as f:
        json.dump({"results": rows, "recommendations": recs}, f, indent=2)

    print(json.dumps({**paths, "recommendations": recs}, indent=2))
# Run the benchmark only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()