# Hugging Face upload metadata (non-code page residue, kept as a comment):
# uploader MTerryJack, "Upload 10 files", commit 58af589 (verified)
from __future__ import annotations
import argparse
import json
import sys
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Tuple
import numpy as np
from PIL import Image
import torch
from transformers import AutoImageProcessor, SegformerForSemanticSegmentation, SegformerImageProcessor
# Pair returned by load_model(): (eval-mode SegFormer model, matching image processor).
ModelBundle = Tuple[SegformerForSemanticSegmentation, AutoImageProcessor]
def load_image(frame: Any, base_dir: Path) -> Image.Image:
    """Decode *frame* into an RGB PIL image.

    ``frame`` is either raw encoded image bytes (bytes/bytearray/memoryview)
    or something convertible to a filesystem path. A relative path is resolved
    against the current working directory first; if that file does not exist,
    ``base_dir`` is tried as a fallback.

    Raises whatever ``PIL.Image.open`` raises when the data/path is unusable.
    """
    if isinstance(frame, (bytes, bytearray, memoryview)):
        return Image.open(BytesIO(frame)).convert("RGB")
    raw = str(frame)
    target = Path(raw)
    if not target.is_absolute():
        target = (Path.cwd() / target).resolve()
    if not target.exists():
        # Fall back to a path relative to the script's directory.
        alternative = (base_dir / raw).resolve()
        target = alternative if alternative.exists() else target
    return Image.open(target).convert("RGB")
def load_model(*_args: Any, **_kwargs: Any) -> ModelBundle | None:
    """Load the SegFormer checkpoint stored alongside this script.

    Returns ``None`` when no ``config.json`` sits next to the script;
    otherwise returns an (eval-mode model, image processor) pair. If the
    checkpoint ships without processor files, a ``SegformerImageProcessor``
    is synthesized from the model's configured input size (default 224).
    Extra positional/keyword arguments are accepted and ignored.
    """
    checkpoint_dir = Path(__file__).resolve().parent
    if not (checkpoint_dir / "config.json").exists():
        return None
    model = SegformerForSemanticSegmentation.from_pretrained(str(checkpoint_dir))
    try:
        processor = AutoImageProcessor.from_pretrained(str(checkpoint_dir))
    except OSError:
        # No preprocessor config on disk: build one from the model config.
        configured = getattr(model.config, "image_size", 224)
        size = (
            {"height": configured, "width": configured}
            if isinstance(configured, int)
            else configured
        )
        processor = SegformerImageProcessor(size=size)
    model.eval()
    return model, processor
def resolve_person_id(model: SegformerForSemanticSegmentation, num_labels: int) -> int:
    """Choose the class index representing "person".

    Uses ``model.config.label2id["person"]`` when it is an int within
    ``[0, num_labels)``; otherwise falls back to class 1 (the usual
    foreground index of binary background/person heads), or 0 when the
    model has a single class.
    """
    mapping = getattr(model.config, "label2id", {}) or {}
    candidate = mapping.get("person")
    if isinstance(candidate, int) and candidate in range(num_labels):
        return candidate
    return 1 if num_labels >= 2 else 0
def run_model(model_bundle: ModelBundle, frame: "np.ndarray") -> List[Dict[str, Any]]:
    """Segment *frame* and report a single "person" detection, if any.

    The frame (an RGB image array) is run through the SegFormer model; all
    pixels whose argmax class is "person" form one mask. The result is a
    one-element list with the mask's bounding box and the mean person
    probability over the mask as the score, or an empty list when no pixel
    is classified as "person".
    """
    model, processor = model_bundle
    pil_image = Image.fromarray(frame)
    batch = processor(images=pil_image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**batch).logits
    person_id = resolve_person_id(model, logits.shape[1])
    # PIL .size is (width, height); interpolate expects (height, width).
    full_res = torch.nn.functional.interpolate(
        logits,
        size=pil_image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    probabilities = full_res.softmax(dim=1)
    predicted = probabilities.argmax(dim=1)[0]
    person_mask = (predicted == person_id).cpu().numpy()
    if not person_mask.any():
        return []
    rows, cols = np.where(person_mask)
    person_scores = probabilities[0, person_id].cpu().numpy()
    return [
        {
            "frame_idx": 0,
            "class": "person",
            "bbox": [
                float(cols.min()),
                float(rows.min()),
                float(cols.max()),
                float(rows.max()),
            ],
            "score": float(person_scores[person_mask].mean()),
            "track_id": "f0-d0",
        }
    ]
def build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    cli = argparse.ArgumentParser(description="Run SegFormer person segmentation.")
    # NOTE(review): store_true with default=True means the flag is always on
    # and cannot be disabled; it is kept so callers passing --stdin-raw keep
    # working unchanged.
    cli.add_argument(
        "--stdin-raw",
        action="store_true",
        default=True,
        help="Read raw image bytes from stdin.",
    )
    return cli
if __name__ == "__main__":
    # Parse argv for --help handling; --stdin-raw is effectively always on.
    build_parser().parse_args()
    script_dir = Path(__file__).resolve().parent
    bundle = load_model()
    if bundle is None:
        # Missing checkpoint: emit an empty detection list, exit cleanly.
        print("[]")
        sys.exit(0)
    try:
        picture = load_image(sys.stdin.buffer.read(), script_dir)
    except Exception:
        # Best-effort boundary: unreadable/undecodable input means no detections.
        print("[]")
        sys.exit(0)
    detections = run_model(bundle, np.array(picture))
    print(json.dumps(detections))