| from __future__ import annotations | |
| import argparse | |
| import json | |
| import sys | |
| from io import BytesIO | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Tuple | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| from transformers import AutoImageProcessor, SegformerForSemanticSegmentation, SegformerImageProcessor | |
# (model, image processor) pair as produced by load_model().
ModelBundle = Tuple[SegformerForSemanticSegmentation, AutoImageProcessor]
def load_image(frame: Any, base_dir: Path) -> Image.Image:
    """Decode *frame* into an RGB PIL image.

    *frame* is either raw encoded image bytes (decoded in memory) or a
    path-like value. A relative path is first resolved against the
    current working directory; if that file does not exist, *base_dir*
    is tried as a fallback anchor.

    Raises whatever ``PIL.Image.open`` raises for unreadable input.
    """
    if isinstance(frame, (bytes, bytearray, memoryview)):
        return Image.open(BytesIO(frame)).convert("RGB")
    source = Path(str(frame))
    if not source.is_absolute():
        source = (Path.cwd() / source).resolve()
    if not source.exists():
        # Fall back to resolving against the script's directory.
        fallback = (base_dir / str(frame)).resolve()
        if fallback.exists():
            source = fallback
    return Image.open(source).convert("RGB")
def load_model(*_args: Any, **_kwargs: Any) -> ModelBundle | None:
    """Load the SegFormer checkpoint stored alongside this script.

    Returns ``None`` when no ``config.json`` sits next to the script
    (i.e. no checkpoint is bundled); otherwise returns the model in
    eval mode together with its image processor. Positional/keyword
    arguments are accepted and ignored so callers with differing
    invocation signatures still work.
    """
    checkpoint_dir = Path(__file__).resolve().parent
    if not (checkpoint_dir / "config.json").exists():
        return None
    model = SegformerForSemanticSegmentation.from_pretrained(str(checkpoint_dir))
    try:
        processor = AutoImageProcessor.from_pretrained(str(checkpoint_dir))
    except OSError:
        # No preprocessor config shipped with the checkpoint: build a
        # default SegFormer processor sized from the model config
        # (224x224 when the config carries no image_size).
        configured = getattr(model.config, "image_size", 224)
        if isinstance(configured, int):
            size = {"height": configured, "width": configured}
        else:
            size = configured
        processor = SegformerImageProcessor(size=size)
    model.eval()
    return model, processor
def resolve_person_id(model: SegformerForSemanticSegmentation, num_labels: int) -> int:
    """Pick the logits-channel index for the "person" class.

    Prefers an explicit ``label2id["person"]`` entry from the model
    config when it is an int within ``[0, num_labels)``. Otherwise
    falls back to channel 1 (conventional foreground slot) when the
    head has at least two channels, else channel 0.
    """
    mapping = getattr(model.config, "label2id", {}) or {}
    candidate = mapping.get("person")
    if isinstance(candidate, int) and 0 <= candidate < num_labels:
        return candidate
    return 1 if num_labels >= 2 else 0
def run_model(model_bundle: ModelBundle, frame: "np.ndarray") -> List[Dict[str, Any]]:
    """Segment one frame and report a single "person" detection.

    Runs the SegFormer model over *frame* (an RGB image array —
    presumably HxWx3 uint8, as produced by ``np.array(PIL.Image)``),
    upsamples the class logits to the frame resolution, and returns a
    one-element list with the bounding box of all person-classified
    pixels and their mean probability as the score — or an empty list
    when no pixel is classified as person.
    """
    model, processor = model_bundle
    image = Image.fromarray(frame)
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    person_id = resolve_person_id(model, logits.shape[1])
    # PIL's size is (width, height); interpolate expects (height, width).
    full_res = torch.nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    probs = full_res.softmax(dim=1)
    labels = probs.argmax(dim=1)[0]
    person_mask = (labels == person_id).cpu().numpy()
    if not person_mask.any():
        return []
    rows, cols = np.where(person_mask)
    person_prob = probs[0, person_id].cpu().numpy()
    return [
        {
            "frame_idx": 0,
            "class": "person",
            "bbox": [
                float(cols.min()),
                float(rows.min()),
                float(cols.max()),
                float(rows.max()),
            ],
            "score": float(person_prob[person_mask].mean()),
            "track_id": "f0-d0",
        }
    ]
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    Returns a parser exposing ``--stdin-raw`` (default ``True``).
    The original declaration combined ``action="store_true"`` with
    ``default=True``, which made the flag inert — it was always True
    and could never be disabled. ``BooleanOptionalAction`` keeps
    ``--stdin-raw`` and the True default backward-compatible while
    adding ``--no-stdin-raw`` to actually turn it off.
    """
    parser = argparse.ArgumentParser(description="Run SegFormer person segmentation.")
    parser.add_argument(
        "--stdin-raw",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Read raw image bytes from stdin.",
    )
    return parser
if __name__ == "__main__":
    # CLI entry point: read one encoded frame from stdin and print JSON
    # detections to stdout. Every failure path degrades to "[]" so a
    # downstream consumer always receives valid JSON.
    build_parser().parse_args()
    script_dir = Path(__file__).resolve().parent
    bundle = load_model()
    if bundle is None:
        print("[]")
        sys.exit(0)
    try:
        decoded = load_image(sys.stdin.buffer.read(), script_dir)
    except Exception:
        # Best-effort: undecodable or missing input yields no detections.
        print("[]")
        sys.exit(0)
    print(json.dumps(run_model(bundle, np.array(decoded))))