---
library_name: litert
tags:
- vision
- image-segmentation
---
# FCN ResNet-50

## Model description

The model was converted from a checkpoint from PyTorch Vision.
## Use

```shell
$ pip install huggingface-hub matplotlib ai-edge-litert
```

Download https://github.com/pytorch/vision/raw/main/gallery/assets/dog1.jpg, then run the following script as:

```shell
$ python your_script.py --image dog1.jpg
```

```python
import argparse

import matplotlib.pyplot as plt
import numpy as np
from ai_edge_litert.compiled_model import CompiledModel
from ai_edge_litert.hardware_accelerator import HardwareAccelerator
from huggingface_hub import hf_hub_download
from PIL import Image
| |
|
# The 21 Pascal VOC segmentation labels, in the index order emitted by the
# model (index 0 is the background class).
VOC_CATEGORIES = [
    "__background__",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]
| | |
| | |
def _download_model_from_hf(repo_id: str, filename: str | None = None) -> str:
    """Download the .tflite model from the Hub and return its local path.

    When *filename* is given it is fetched directly; otherwise a couple of
    conventional model filenames are tried in order, and the last download
    error is chained if none of them exists.
    """
    if filename:
        return hf_hub_download(repo_id=repo_id, filename=filename)
    candidates = ("fcn_resnet50_nchw.tflite", "fcn_resnet50.tflite")
    failure: Exception | None = None
    for candidate in candidates:
        try:
            return hf_hub_download(repo_id=repo_id, filename=candidate)
        except Exception as exc:  # pylint: disable=broad-except
            failure = exc
    raise FileNotFoundError(
        f"Could not find expected model in {repo_id}: {', '.join(candidates)}"
    ) from failure
| | |
| |
|
def _load_cpu_model(model_path: str) -> CompiledModel:
    """Compile the .tflite model at *model_path* for CPU execution."""
    accelerator = HardwareAccelerator.CPU
    return CompiledModel.from_file(model_path, hardware_accel=accelerator)
| | |
| | |
def _infer_nchw_input_hw(model: CompiledModel) -> tuple[int, int]:
    """Best-effort read of the model's expected input (height, width).

    Inspects the first input buffer requirements, accepting either an NCHW
    (1, 3, H, W) or CHW (3, H, W) shape; falls back to 520x520 — the FCN
    ResNet-50 default — when the shape cannot be determined.
    """
    fallback = (520, 520)
    requirements = model.get_input_buffer_requirements(0, 0)
    # Different runtime versions expose the shape under different keys.
    raw = requirements.get("dimensions") or requirements.get("shape") or requirements.get("dims")
    if not raw:
        return fallback
    dims = [int(d) for d in raw]
    if len(dims) == 4 and dims[1] == 3:
        return dims[2], dims[3]
    if len(dims) == 3 and dims[0] == 3:
        return dims[1], dims[2]
    return fallback
| | |
| |
|
def _preprocess_nchw(image: Image.Image, input_h: int, input_w: int) -> np.ndarray:
    """Convert a PIL image into a normalized float32 CHW array for the model.

    Mirrors the torchvision eval transform: resize so the short side is 520,
    then force the exact model input size, scale to [0, 1], and normalize
    with the ImageNet channel mean/std.
    """
    rgb = image.convert("RGB")
    width, height = rgb.size
    short = 520
    if width < height:
        resized = rgb.resize((short, int(round(height * short / width))), Image.BILINEAR)
    else:
        resized = rgb.resize((int(round(width * short / height)), short), Image.BILINEAR)
    resized = resized.resize((input_w, input_h), Image.BILINEAR)
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    hwc = (np.asarray(resized, dtype=np.float32) / 255.0 - mean) / std
    # HWC -> CHW for the NCHW-style model input.
    return np.transpose(hwc, (2, 0, 1))
| | |
| |
|
def _run_logits_hwc(model: CompiledModel, nchw_input: np.ndarray, num_classes: int) -> np.ndarray:
    """Execute the model once on a CHW input; return per-pixel logits as HWC.

    Raises ValueError if the flattened output does not match the expected
    num_classes * H * W element count.
    """
    in_bufs = model.create_input_buffers(0)
    out_bufs = model.create_output_buffers(0)
    in_bufs[0].write(nchw_input)
    model.run_by_index(0, in_bufs, out_bufs)

    req = model.get_output_buffer_requirements(0, 0)
    n_floats = req["buffer_size"] // np.dtype(np.float32).itemsize
    y = out_bufs[0].read(n_floats, np.float32).reshape(-1)
    h, w = nchw_input.shape[1], nchw_input.shape[2]
    if y.size != num_classes * h * w:
        raise ValueError(f"Unexpected output size {y.size}; expected {num_classes * h * w}")
    # CHW logits -> HWC for per-pixel class handling downstream.
    return y.reshape(num_classes, h, w).transpose(1, 2, 0)
| | |
| | |
| | def _softmax_last_axis(logits: np.ndarray) -> np.ndarray: |
| | logits = logits.astype(np.float32, copy=False) |
| | max_logits = np.max(logits, axis=-1, keepdims=True) |
| | exps = np.exp(logits - max_logits) |
| | return exps / np.sum(exps, axis=-1, keepdims=True) |
| |
|
| |
|
| | def _normalize_mask(mask_prob: np.ndarray) -> np.ndarray: |
| | mask = np.clip(mask_prob.astype(np.float32), 0.0, 1.0) |
| | lo = float(np.percentile(mask, 5.0)) |
| | hi = float(np.percentile(mask, 99.0)) |
| | if hi <= lo: |
| | hi = float(mask.max()) |
| | lo = float(mask.min()) |
| | if hi <= lo: |
| | return np.zeros_like(mask, dtype=np.float32) |
| | return np.clip((mask - lo) / (hi - lo), 0.0, 1.0) |
| | |
| |
|
def _build_overlay_rgb(base_rgb: np.ndarray, mask_prob: np.ndarray) -> np.ndarray:
    """Blend a red/orange heat layer onto *base_rgb*; returns uint8 RGB."""
    mask = _normalize_mask(mask_prob)
    # Heat color: full red with green rising with mask strength (red -> orange).
    heat = np.zeros_like(base_rgb, dtype=np.float32)
    heat[..., 0] = 255.0
    heat[..., 1] = 180.0 * mask
    # Opacity ramps in above ~0.35 with a soft knee; peak opacity is 0.75.
    ramp = np.clip((mask - 0.35) / 0.65, 0.0, 1.0) ** 1.5
    alpha = (0.75 * ramp)[..., None]
    blended = base_rgb.astype(np.float32) * (1.0 - alpha) + heat * alpha
    return np.clip(blended, 0, 255).astype(np.uint8)
| | |
| | |
def main() -> int:
    """CLI entry point: segment one image and visualize a single class mask.

    Downloads the model, runs inference on --image, and shows (and/or saves)
    a three-panel figure: input, per-pixel class probability, and overlay.
    Returns 0 on success.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", required=True)
    parser.add_argument("--repo_id", default="litert-community/fcn_resnet50")
    parser.add_argument(
        "--model_file",
        default=None,
        help="Optional model filename in repo. If omitted, common names are tried.",
    )
    parser.add_argument("--class_name", default="dog", choices=VOC_CATEGORIES)
    parser.add_argument(
        "--save_figure",
        default=None,
        help="Optional output PNG path for the matplotlib figure.",
    )
    parser.add_argument(
        "--no_show",
        action="store_true",
        help="Do not open interactive window (use with --save_figure).",
    )
    args = parser.parse_args()

    model_path = _download_model_from_hf(args.repo_id, filename=args.model_file)
    model = _load_cpu_model(model_path)

    # Preprocess, run the network, and pull out probabilities for one class.
    image = Image.open(args.image).convert("RGB")
    input_h, input_w = _infer_nchw_input_hw(model)
    x_nchw = _preprocess_nchw(image, input_h, input_w)
    logits_hwc = _run_logits_hwc(model, x_nchw, len(VOC_CATEGORIES))
    probs_hwc = _softmax_last_axis(logits_hwc)
    class_index = VOC_CATEGORIES.index(args.class_name)
    mask_prob = probs_hwc[..., class_index]

    # Resize mask to original image for display.
    mask_img = Image.fromarray((np.clip(mask_prob, 0.0, 1.0) * 255.0).astype(np.uint8), mode="L")
    mask_img = mask_img.resize(image.size, Image.BILINEAR)
    mask_prob_resized = np.asarray(mask_img, dtype=np.float32) / 255.0

    base_rgb = np.asarray(image, dtype=np.uint8)
    overlay_rgb = _build_overlay_rgb(base_rgb, mask_prob_resized)

    # Three panels: input, class probability heat map, blended overlay.
    fig, (ax_in, ax_prob, ax_overlay) = plt.subplots(1, 3, figsize=(15, 5))
    ax_in.imshow(base_rgb)
    ax_in.set_title("Input")
    ax_in.axis("off")

    prob_image = ax_prob.imshow(mask_prob_resized, cmap="magma", vmin=0.0, vmax=1.0)
    ax_prob.set_title(f"Class Prob: {args.class_name}")
    ax_prob.axis("off")
    fig.colorbar(prob_image, ax=ax_prob, fraction=0.046, pad=0.04)

    ax_overlay.imshow(overlay_rgb)
    ax_overlay.set_title("Overlay")
    ax_overlay.axis("off")

    fig.suptitle(f"Model: {model_path}\nClass: {args.class_name}")
    fig.tight_layout()

    if args.save_figure:
        fig.savefig(args.save_figure, dpi=180, bbox_inches="tight")
        print(f"Saved figure to: {args.save_figure}")

    if not args.no_show:
        plt.show()
    else:
        plt.close(fig)

    return 0
| | |
| | |
if __name__ == "__main__":
    # Exit the process with main()'s integer return code.
    raise SystemExit(main())
```
| | |
| | |
### BibTeX entry and citation info
```bibtex
@article{DBLP:journals/corr/LongSD14,
  author     = {Jonathan Long and
                Evan Shelhamer and
                Trevor Darrell},
  title      = {Fully Convolutional Networks for Semantic Segmentation},
  journal    = {CoRR},
  volume     = {abs/1411.4038},
  year       = {2014},
  url        = {http://arxiv.org/abs/1411.4038},
  eprinttype = {arXiv},
  eprint     = {1411.4038},
  timestamp  = {Mon, 13 Aug 2018 16:48:17 +0200},
  biburl     = {https://dblp.org/rec/journals/corr/LongSD14.bib},
  bibsource  = {dblp computer science bibliography, https://dblp.org}
}
```