FCN ResNet-50

Model description

The model was converted from a checkpoint from PyTorch Vision.

Use

$ pip install huggingface-hub matplotlib ai-edge-litert pillow numpy

Download https://github.com/pytorch/vision/raw/main/gallery/assets/dog1.jpg then run the following script as:

$ python your_script.py --image dog1.jpg

import argparse

import matplotlib.pyplot as plt
import numpy as np
from ai_edge_litert.compiled_model import CompiledModel
from ai_edge_litert.hardware_accelerator import HardwareAccelerator
from huggingface_hub import hf_hub_download
from PIL import Image

# The 21 PASCAL VOC semantic-segmentation classes. List index corresponds to
# the model's output channel (index 0 is the background class); the list
# length is passed to the model runner as the number of classes.
VOC_CATEGORIES = [
    "__background__",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]


def _download_model_from_hf(repo_id: str, filename: str | None = None) -> str:
  """Download the .tflite model from the Hub and return its local path.

  When *filename* is given it is fetched directly (download errors
  propagate unchanged). Otherwise a list of known default filenames is
  tried in order; if none can be fetched, a FileNotFoundError chained to
  the last download error is raised.
  """
  if filename:
    return hf_hub_download(repo_id=repo_id, filename=filename)

  default_names = ("fcn_resnet50_nchw.tflite", "fcn_resnet50.tflite")
  failure: Exception | None = None
  for candidate in default_names:
    try:
      return hf_hub_download(repo_id=repo_id, filename=candidate)
    except Exception as err:  # pylint: disable=broad-except
      failure = err
  raise FileNotFoundError(
      f"Could not find expected model in {repo_id}: {', '.join(default_names)}"
  ) from failure


def _load_cpu_model(model_path: str) -> CompiledModel:
  """Compile the model file for CPU-only execution."""
  accelerator = HardwareAccelerator.CPU
  return CompiledModel.from_file(model_path, hardware_accel=accelerator)


def _infer_nchw_input_hw(model: CompiledModel) -> tuple[int, int]:
  """Return the (height, width) expected by the model's first input.

  The buffer-requirement dict key naming varies, so several candidate
  keys are probed. Dims are interpreted as NCHW (4-D with 3 channels at
  axis 1) or CHW (3-D with 3 channels first); anything else falls back
  to the FCN default of 520x520.
  """
  requirements = model.get_input_buffer_requirements(0, 0)
  raw_dims = (
      requirements.get("dimensions")
      or requirements.get("shape")
      or requirements.get("dims")
  )
  if not raw_dims:
    return 520, 520
  dims = tuple(int(d) for d in raw_dims)
  if len(dims) == 4 and dims[1] == 3:
    return dims[2], dims[3]
  if len(dims) == 3 and dims[0] == 3:
    return dims[1], dims[2]
  return 520, 520


def _preprocess_nchw(image: Image.Image, input_h: int, input_w: int) -> np.ndarray:
  """Convert a PIL image into a normalized float32 CHW tensor.

  The image is resized directly to the model's input resolution. (The
  previous implementation first resized the short side to 520 and then
  resized again to (input_w, input_h); the intermediate resize was
  always discarded by the second one, so it only added an extra
  interpolation pass and blur.) Pixels are scaled to [0, 1] and
  normalized with the ImageNet mean/std used by torchvision models.

  Args:
    image: Source image; converted to RGB internally.
    input_h: Model input height in pixels.
    input_w: Model input width in pixels.

  Returns:
    Array of shape (3, input_h, input_w), dtype float32, channel-first.
  """
  image = image.convert("RGB").resize((input_w, input_h), Image.BILINEAR)
  x = np.asarray(image, dtype=np.float32) / 255.0
  mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
  std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
  x = (x - mean) / std
  return np.transpose(x, (2, 0, 1))


def _run_logits_hwc(model: CompiledModel, nchw_input: np.ndarray, num_classes: int) -> np.ndarray:
  """Run one inference and return the logits reshaped to (H, W, C).

  H and W are taken from the CHW input tensor. Raises ValueError if the
  flat output element count does not equal num_classes * H * W.
  """
  input_buffers = model.create_input_buffers(0)
  output_buffers = model.create_output_buffers(0)
  input_buffers[0].write(nchw_input)
  model.run_by_index(0, input_buffers, output_buffers)

  out_req = model.get_output_buffer_requirements(0, 0)
  n_floats = out_req["buffer_size"] // np.dtype(np.float32).itemsize
  flat = output_buffers[0].read(n_floats, np.float32).reshape(-1)
  height, width = nchw_input.shape[1], nchw_input.shape[2]
  if flat.size != num_classes * height * width:
    raise ValueError(f"Unexpected output size {flat.size}; expected {num_classes * height * width}")
  logits_chw = flat.reshape(num_classes, height, width)
  return np.transpose(logits_chw, (1, 2, 0))


def _softmax_last_axis(logits: np.ndarray) -> np.ndarray:
  logits = logits.astype(np.float32, copy=False)
  max_logits = np.max(logits, axis=-1, keepdims=True)
  exps = np.exp(logits - max_logits)
  return exps / np.sum(exps, axis=-1, keepdims=True)


def _normalize_mask(mask_prob: np.ndarray) -> np.ndarray:
  mask = np.clip(mask_prob.astype(np.float32), 0.0, 1.0)
  lo = float(np.percentile(mask, 5.0))
  hi = float(np.percentile(mask, 99.0))
  if hi <= lo:
    hi = float(mask.max())
    lo = float(mask.min())
  if hi <= lo:
    return np.zeros_like(mask, dtype=np.float32)
  return np.clip((mask - lo) / (hi - lo), 0.0, 1.0)


def _build_overlay_rgb(base_rgb: np.ndarray, mask_prob: np.ndarray) -> np.ndarray:
  """Alpha-blend a red/orange heat layer over *base_rgb*; returns uint8 RGB.

  The mask is percentile-normalized, then gated (values below 0.35
  contribute nothing) so low-probability regions leave the base image
  untouched.
  """
  weight = _normalize_mask(mask_prob)
  heat = np.zeros(base_rgb.shape, dtype=np.float32)
  heat[..., 0] = 255.0
  heat[..., 1] = 180.0 * weight
  gate = np.clip((weight - 0.35) / 0.65, 0.0, 1.0) ** 1.5
  alpha = (0.75 * gate)[..., None]
  blended = base_rgb.astype(np.float32) * (1.0 - alpha) + heat * alpha
  return np.clip(blended, 0, 255).astype(np.uint8)


def main() -> int:
  """CLI entry point: segment one image and display/save a 3-panel figure."""
  parser = argparse.ArgumentParser()
  parser.add_argument("--image", required=True)
  parser.add_argument("--repo_id", default="litert-community/fcn_resnet50")
  parser.add_argument(
      "--model_file",
      default=None,
      help="Optional model filename in repo. If omitted, common names are tried.",
  )
  parser.add_argument("--class_name", default="dog", choices=VOC_CATEGORIES)
  parser.add_argument(
      "--save_figure",
      default=None,
      help="Optional output PNG path for the matplotlib figure.",
  )
  parser.add_argument(
      "--no_show",
      action="store_true",
      help="Do not open interactive window (use with --save_figure).",
  )
  opts = parser.parse_args()

  # Fetch, compile, and run the model on the preprocessed image.
  model_path = _download_model_from_hf(opts.repo_id, filename=opts.model_file)
  model = _load_cpu_model(model_path)

  source_image = Image.open(opts.image).convert("RGB")
  in_h, in_w = _infer_nchw_input_hw(model)
  tensor = _preprocess_nchw(source_image, in_h, in_w)
  probabilities = _softmax_last_axis(
      _run_logits_hwc(model, tensor, len(VOC_CATEGORIES))
  )
  target_channel = VOC_CATEGORIES.index(opts.class_name)
  class_prob = probabilities[..., target_channel]

  # Resize mask to original image for display.
  gray = Image.fromarray(
      (np.clip(class_prob, 0.0, 1.0) * 255.0).astype(np.uint8), mode="L"
  )
  gray = gray.resize(source_image.size, Image.BILINEAR)
  full_prob = np.asarray(gray, dtype=np.float32) / 255.0

  base = np.asarray(source_image, dtype=np.uint8)
  overlay = _build_overlay_rgb(base, full_prob)

  fig, axes = plt.subplots(1, 3, figsize=(15, 5))
  for ax in axes:
    ax.axis("off")

  axes[0].imshow(base)
  axes[0].set_title("Input")

  heat = axes[1].imshow(full_prob, cmap="magma", vmin=0.0, vmax=1.0)
  axes[1].set_title(f"Class Prob: {opts.class_name}")
  fig.colorbar(heat, ax=axes[1], fraction=0.046, pad=0.04)

  axes[2].imshow(overlay)
  axes[2].set_title("Overlay")

  fig.suptitle(f"Model: {model_path}\nClass: {opts.class_name}")
  fig.tight_layout()

  if opts.save_figure:
    fig.savefig(opts.save_figure, dpi=180, bbox_inches="tight")
    print(f"Saved figure to: {opts.save_figure}")

  if opts.no_show:
    plt.close(fig)
  else:
    plt.show()

  return 0


if __name__ == "__main__":
  # Propagate main()'s integer exit status to the shell.
  raise SystemExit(main())

BibTeX entry and citation info

@article{DBLP:journals/corr/LongSD14,
  author       = {Jonathan Long and
                  Evan Shelhamer and
                  Trevor Darrell},
  title        = {Fully Convolutional Networks for Semantic Segmentation},
  journal      = {CoRR},
  volume       = {abs/1411.4038},
  year         = {2014},
  url          = {http://arxiv.org/abs/1411.4038},
  eprinttype    = {arXiv},
  eprint       = {1411.4038},
  timestamp    = {Mon, 13 Aug 2018 16:48:17 +0200},
  biburl       = {https://dblp.org/rec/journals/corr/LongSD14.bib},
  bibsource    = {dblp computer science bibliography, https://dblp.org}
}
Downloads last month
16
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Paper for litert-community/fcn_resnet50