File size: 3,817 Bytes
58af589 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
from __future__ import annotations
import argparse
import json
import sys
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Tuple
import numpy as np
from PIL import Image
import torch
from transformers import AutoImageProcessor, SegformerForSemanticSegmentation, SegformerImageProcessor
ModelBundle = Tuple[SegformerForSemanticSegmentation, AutoImageProcessor]
def load_image(frame: Any, base_dir: Path) -> Image.Image:
    """Decode *frame* into an RGB PIL image.

    *frame* is either raw encoded image bytes (decoded in-memory) or a
    filesystem path, tried first as given / relative to the CWD and then
    relative to *base_dir*. Raises whatever ``Image.open`` raises when
    neither location holds a readable image.
    """
    if isinstance(frame, (bytes, bytearray, memoryview)):
        return Image.open(BytesIO(frame)).convert("RGB")
    raw = str(frame)
    candidate = Path(raw)
    if not candidate.is_absolute():
        candidate = (Path.cwd() / candidate).resolve()
    if not candidate.exists():
        # Fall back to interpreting the path relative to the script's dir.
        alternative = (base_dir / raw).resolve()
        candidate = alternative if alternative.exists() else candidate
    return Image.open(candidate).convert("RGB")
def load_model(*_args: Any, **_kwargs: Any) -> ModelBundle | None:
    """Load the SegFormer checkpoint sitting next to this script.

    Returns ``None`` when no ``config.json`` is present (no checkpoint
    shipped); otherwise ``(model, processor)`` with the model in eval mode.
    Extra positional/keyword arguments are accepted and ignored.
    """
    checkpoint_dir = Path(__file__).resolve().parent
    if not (checkpoint_dir / "config.json").exists():
        return None
    model = SegformerForSemanticSegmentation.from_pretrained(str(checkpoint_dir))
    try:
        processor = AutoImageProcessor.from_pretrained(str(checkpoint_dir))
    except OSError:
        # Checkpoint ships no preprocessor config: synthesize one from the
        # model's own input resolution (int or {"height","width"} mapping).
        resolution = getattr(model.config, "image_size", 224)
        size_spec = (
            {"height": resolution, "width": resolution}
            if isinstance(resolution, int)
            else resolution
        )
        processor = SegformerImageProcessor(size=size_spec)
    model.eval()
    return model, processor
def resolve_person_id(model: SegformerForSemanticSegmentation, num_labels: int) -> int:
    """Pick the class index corresponding to the "person" label.

    Prefers the checkpoint's own ``label2id`` mapping when it contains an
    in-range integer entry for "person"; otherwise falls back to index 1
    for multi-class heads and 0 for single-class heads.
    """
    mapping = getattr(model.config, "label2id", {}) or {}
    candidate = mapping.get("person")
    if isinstance(candidate, int) and 0 <= candidate < num_labels:
        return candidate
    return 1 if num_labels >= 2 else 0
def run_model(model_bundle: ModelBundle, frame: "np.ndarray") -> List[Dict[str, Any]]:
image = Image.fromarray(frame)
model, processor = model_bundle
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
num_labels = logits.shape[1]
person_id = resolve_person_id(model, num_labels)
upsampled_logits = torch.nn.functional.interpolate(
logits,
size=image.size[::-1],
mode="bilinear",
align_corners=False,
)
probs = upsampled_logits.softmax(dim=1)
pred = probs.argmax(dim=1)[0]
mask = (pred == person_id).cpu().numpy()
if not mask.any():
return []
ys, xs = np.where(mask)
x_min = float(xs.min())
y_min = float(ys.min())
x_max = float(xs.max())
y_max = float(ys.max())
person_prob = probs[0, person_id].cpu().numpy()
score = float(person_prob[mask].mean())
return [
{
"frame_idx": 0,
"class": "person",
"bbox": [x_min, y_min, x_max, y_max],
"score": score,
"track_id": "f0-d0",
}
]
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    NOTE(review): ``--stdin-raw`` is ``store_true`` with ``default=True``,
    so passing the flag has no observable effect; the default is kept
    as-is for interface compatibility with existing callers.
    """
    cli = argparse.ArgumentParser(description="Run SegFormer person segmentation.")
    cli.add_argument(
        "--stdin-raw",
        help="Read raw image bytes from stdin.",
        action="store_true",
        default=True,
    )
    return cli
if __name__ == "__main__":
args = build_parser().parse_args()
base_dir = Path(__file__).resolve().parent
model_bundle = load_model()
if model_bundle is None:
print("[]")
sys.exit(0)
try:
image = load_image(sys.stdin.buffer.read(), base_dir)
except Exception:
print("[]")
sys.exit(0)
frame = np.array(image)
output = run_model(model_bundle, frame)
print(json.dumps(output))
|