Fully Convolutional Networks for Semantic Segmentation
Paper
•
1411.4038
•
Published
The model was converted from a PyTorch Vision checkpoint.
$ pip install huggingface-hub matplotlib ai-edge-litert
Download https://github.com/pytorch/vision/raw/main/gallery/assets/dog1.jpg then run the following script as:
$ python your_script.py --image dog1.jpg
import argparse
import matplotlib.pyplot as plt
import numpy as np
from ai_edge_litert.compiled_model import CompiledModel
from ai_edge_litert.hardware_accelerator import HardwareAccelerator
from huggingface_hub import hf_hub_download
from PIL import Image
# Pascal VOC class labels. The list index corresponds to the model's output
# channel for that class (channel 0 is the background class).
VOC_CATEGORIES = [
    "__background__",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]
def _download_model_from_hf(repo_id: str, filename: str | None = None) -> str:
    """Fetch the model file from the Hugging Face Hub and return its local path.

    When *filename* is given it is downloaded directly; otherwise a tuple of
    well-known model filenames is tried in order.

    Raises:
        FileNotFoundError: if none of the candidate filenames could be fetched
            from the repo (chained to the last download error).
    """
    if filename:
        return hf_hub_download(repo_id=repo_id, filename=filename)
    candidates = ("fcn_resnet50_nchw.tflite", "fcn_resnet50.tflite")
    last_error: Exception | None = None
    for candidate in candidates:
        try:
            return hf_hub_download(repo_id=repo_id, filename=candidate)
        except Exception as exc:  # pylint: disable=broad-except
            # Best-effort probing: remember the failure and try the next name.
            last_error = exc
    raise FileNotFoundError(
        f"Could not find expected model in {repo_id}: {', '.join(candidates)}"
    ) from last_error
def _load_cpu_model(model_path: str) -> CompiledModel:
    """Compile the .tflite model at *model_path* for CPU execution."""
    accelerator = HardwareAccelerator.CPU
    return CompiledModel.from_file(model_path, hardware_accel=accelerator)
def _infer_nchw_input_hw(model: CompiledModel) -> tuple[int, int]:
    """Return the model's expected input (height, width).

    Reads the first input buffer's dimension metadata (trying several key
    names). Falls back to 520x520 when no dimensions are reported or the
    layout is not a recognizable 3-channel NCHW/CHW shape.
    """
    requirements = model.get_input_buffer_requirements(0, 0)
    raw_dims = (
        requirements.get("dimensions")
        or requirements.get("shape")
        or requirements.get("dims")
    )
    if not raw_dims:
        return 520, 520
    shape = [int(d) for d in raw_dims]
    # NCHW: (batch, channels, height, width)
    if len(shape) == 4 and shape[1] == 3:
        return shape[2], shape[3]
    # CHW: (channels, height, width)
    if len(shape) == 3 and shape[0] == 3:
        return shape[1], shape[2]
    return 520, 520
def _preprocess_nchw(image: Image.Image, input_h: int, input_w: int) -> np.ndarray:
    """Convert a PIL image into a normalized float32 CHW tensor.

    The image is first scaled so its short side is 520 px, then resized to
    the model's exact input size, mapped to [0, 1], and normalized with the
    ImageNet mean/std.
    """
    image = image.convert("RGB")
    width, height = image.size
    short_side = 520
    # Scale the short side to 520 while preserving aspect ratio.
    if width < height:
        target = (short_side, int(round(height * short_side / width)))
    else:
        target = (int(round(width * short_side / height)), short_side)
    image = image.resize(target, Image.BILINEAR)
    # Final resize to the exact model input resolution.
    image = image.resize((input_w, input_h), Image.BILINEAR)
    arr = np.asarray(image, dtype=np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    arr = (arr - mean) / std
    # HWC -> CHW for the NCHW model input.
    return np.transpose(arr, (2, 0, 1))
def _run_logits_hwc(model: CompiledModel, nchw_input: np.ndarray, num_classes: int) -> np.ndarray:
    """Run inference on signature 0 and return logits in HWC layout.

    Args:
        model: Compiled LiteRT model.
        nchw_input: float32 input tensor of shape (3, H, W).
        num_classes: expected number of output channels.

    Returns:
        float32 array of shape (H, W, num_classes).

    Raises:
        ValueError: if the flat output size does not match num_classes * H * W.
    """
    input_buffers = model.create_input_buffers(0)
    output_buffers = model.create_output_buffers(0)
    input_buffers[0].write(nchw_input)
    model.run_by_index(0, input_buffers, output_buffers)
    out_req = model.get_output_buffer_requirements(0, 0)
    n_floats = out_req["buffer_size"] // np.dtype(np.float32).itemsize
    flat = output_buffers[0].read(n_floats, np.float32).reshape(-1)
    height, width = nchw_input.shape[1], nchw_input.shape[2]
    expected = num_classes * height * width
    if flat.size != expected:
        raise ValueError(f"Unexpected output size {flat.size}; expected {expected}")
    # CHW logits -> HWC for per-pixel class indexing.
    return np.transpose(flat.reshape(num_classes, height, width), (1, 2, 0))
def _softmax_last_axis(logits: np.ndarray) -> np.ndarray:
logits = logits.astype(np.float32, copy=False)
max_logits = np.max(logits, axis=-1, keepdims=True)
exps = np.exp(logits - max_logits)
return exps / np.sum(exps, axis=-1, keepdims=True)
def _normalize_mask(mask_prob: np.ndarray) -> np.ndarray:
mask = np.clip(mask_prob.astype(np.float32), 0.0, 1.0)
lo = float(np.percentile(mask, 5.0))
hi = float(np.percentile(mask, 99.0))
if hi <= lo:
hi = float(mask.max())
lo = float(mask.min())
if hi <= lo:
return np.zeros_like(mask, dtype=np.float32)
return np.clip((mask - lo) / (hi - lo), 0.0, 1.0)
def _build_overlay_rgb(base_rgb: np.ndarray, mask_prob: np.ndarray) -> np.ndarray:
    """Alpha-composite a red/orange heat layer onto *base_rgb*; returns uint8 RGB.

    The mask is contrast-stretched, mapped to a red-to-orange ramp (red fixed
    at 255, green proportional to the mask), and blended with an opacity that
    is zero for stretched mask values below 0.35.
    """
    mask = _normalize_mask(mask_prob)
    heat = np.zeros_like(base_rgb, dtype=np.float32)
    heat[..., 0] = 255.0
    heat[..., 1] = 180.0 * mask
    # Gate low-confidence pixels out, then ease in with a 1.5 exponent.
    gate = np.clip((mask - 0.35) / 0.65, 0.0, 1.0)
    alpha = (0.75 * gate ** 1.5)[..., None]
    blended = base_rgb.astype(np.float32) * (1.0 - alpha) + heat * alpha
    return np.clip(blended, 0, 255).astype(np.uint8)
def main() -> int:
    """CLI entry point: segment one image with FCN-ResNet50 and visualize it.

    Downloads the model from the Hub, runs CPU inference, and shows and/or
    saves a three-panel figure (input, per-class probability map, overlay).
    Returns 0 on success.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", required=True)
    parser.add_argument("--repo_id", default="litert-community/fcn_resnet50")
    parser.add_argument(
        "--model_file",
        default=None,
        help="Optional model filename in repo. If omitted, common names are tried.",
    )
    parser.add_argument("--class_name", default="dog", choices=VOC_CATEGORIES)
    parser.add_argument(
        "--save_figure",
        default=None,
        help="Optional output PNG path for the matplotlib figure.",
    )
    parser.add_argument(
        "--no_show",
        action="store_true",
        help="Do not open interactive window (use with --save_figure).",
    )
    args = parser.parse_args()

    model_path = _download_model_from_hf(args.repo_id, filename=args.model_file)
    model = _load_cpu_model(model_path)

    image = Image.open(args.image).convert("RGB")
    input_h, input_w = _infer_nchw_input_hw(model)
    tensor = _preprocess_nchw(image, input_h, input_w)
    probs_hwc = _softmax_last_axis(_run_logits_hwc(model, tensor, len(VOC_CATEGORIES)))
    mask_prob = probs_hwc[..., VOC_CATEGORIES.index(args.class_name)]

    # Resize mask to original image for display.
    mask_u8 = (np.clip(mask_prob, 0.0, 1.0) * 255.0).astype(np.uint8)
    mask_img = Image.fromarray(mask_u8, mode="L").resize(image.size, Image.BILINEAR)
    mask_prob_resized = np.asarray(mask_img, dtype=np.float32) / 255.0

    base_rgb = np.asarray(image, dtype=np.uint8)
    overlay_rgb = _build_overlay_rgb(base_rgb, mask_prob_resized)

    fig, (ax_input, ax_mask, ax_overlay) = plt.subplots(1, 3, figsize=(15, 5))
    ax_input.imshow(base_rgb)
    ax_input.set_title("Input")
    ax_input.axis("off")
    heatmap = ax_mask.imshow(mask_prob_resized, cmap="magma", vmin=0.0, vmax=1.0)
    ax_mask.set_title(f"Class Prob: {args.class_name}")
    ax_mask.axis("off")
    fig.colorbar(heatmap, ax=ax_mask, fraction=0.046, pad=0.04)
    ax_overlay.imshow(overlay_rgb)
    ax_overlay.set_title("Overlay")
    ax_overlay.axis("off")
    fig.suptitle(f"Model: {model_path}\nClass: {args.class_name}")
    fig.tight_layout()

    if args.save_figure:
        fig.savefig(args.save_figure, dpi=180, bbox_inches="tight")
        print(f"Saved figure to: {args.save_figure}")
    if not args.no_show:
        plt.show()
    else:
        plt.close(fig)
    return 0
if __name__ == "__main__":
raise SystemExit(main())
@article{DBLP:journals/corr/LongSD14,
author = {Jonathan Long and
Evan Shelhamer and
Trevor Darrell},
title = {Fully Convolutional Networks for Semantic Segmentation},
journal = {CoRR},
volume = {abs/1411.4038},
year = {2014},
url = {http://arxiv.org/abs/1411.4038},
eprinttype = {arXiv},
eprint = {1411.4038},
timestamp = {Mon, 13 Aug 2018 16:48:17 +0200},
biburl = {https://dblp.org/rec/journals/corr/LongSD14.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}