from __future__ import annotations import argparse import json import sys from io import BytesIO from pathlib import Path from typing import Any, Dict, List, Tuple import numpy as np from PIL import Image import torch from transformers import AutoImageProcessor, SegformerForSemanticSegmentation, SegformerImageProcessor ModelBundle = Tuple[SegformerForSemanticSegmentation, AutoImageProcessor] def load_image(frame: Any, base_dir: Path) -> Image.Image: if isinstance(frame, (bytes, bytearray, memoryview)): return Image.open(BytesIO(frame)).convert("RGB") path = Path(str(frame)) if not path.is_absolute(): path = (Path.cwd() / path).resolve() if not path.exists(): candidate = (base_dir / str(frame)).resolve() if candidate.exists(): path = candidate return Image.open(path).convert("RGB") def load_model(*_args: Any, **_kwargs: Any) -> ModelBundle | None: base_dir = Path(__file__).resolve().parent if not (base_dir / "config.json").exists(): return None model = SegformerForSemanticSegmentation.from_pretrained(str(base_dir)) try: processor = AutoImageProcessor.from_pretrained(str(base_dir)) except OSError: image_size = getattr(model.config, "image_size", 224) if isinstance(image_size, int): size = {"height": image_size, "width": image_size} else: size = image_size processor = SegformerImageProcessor(size=size) model.eval() return model, processor def resolve_person_id(model: SegformerForSemanticSegmentation, num_labels: int) -> int: label2id = getattr(model.config, "label2id", {}) or {} person_id = label2id.get("person") if isinstance(person_id, int) and 0 <= person_id < num_labels: return person_id if num_labels >= 2: return 1 return 0 def run_model(model_bundle: ModelBundle, frame: "np.ndarray") -> List[Dict[str, Any]]: image = Image.fromarray(frame) model, processor = model_bundle inputs = processor(images=image, return_tensors="pt") with torch.no_grad(): outputs = model(**inputs) logits = outputs.logits num_labels = logits.shape[1] person_id = resolve_person_id(model, num_labels) upsampled_logits = torch.nn.functional.interpolate( logits, size=image.size[::-1], mode="bilinear", align_corners=False, ) probs = upsampled_logits.softmax(dim=1) pred = probs.argmax(dim=1)[0] mask = (pred == person_id).cpu().numpy() if not mask.any(): return [] ys, xs = np.where(mask) x_min = float(xs.min()) y_min = float(ys.min()) x_max = float(xs.max()) y_max = float(ys.max()) person_prob = probs[0, person_id].cpu().numpy() score = float(person_prob[mask].mean()) return [ { "frame_idx": 0, "class": "person", "bbox": [x_min, y_min, x_max, y_max], "score": score, "track_id": "f0-d0", } ] def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Run SegFormer person segmentation.") parser.add_argument( "--stdin-raw", action="store_true", default=True, help="Read raw image bytes from stdin.", ) return parser if __name__ == "__main__": args = build_parser().parse_args() base_dir = Path(__file__).resolve().parent model_bundle = load_model() if model_bundle is None: print("[]") sys.exit(0) try: image = load_image(sys.stdin.buffer.read(), base_dir) except Exception: print("[]") sys.exit(0) frame = np.array(image) output = run_model(model_bundle, frame) print(json.dumps(output))