---
library_name: litert
tags:
- vision
- image-segmentation
---

# fcn resnet50

## Model description

The model was converted from a checkpoint from PyTorch Vision.

## Use

```
$ pip install huggingface-hub matplotlib ai-edge-litert pillow numpy
```

Download https://github.com/pytorch/vision/raw/main/gallery/assets/dog1.jpg then run the following script as:

```
$ python your_script.py --image dog1.jpg
```

```python
import argparse

import matplotlib.pyplot as plt
import numpy as np
from ai_edge_litert.compiled_model import CompiledModel
from ai_edge_litert.hardware_accelerator import HardwareAccelerator
from huggingface_hub import hf_hub_download
from PIL import Image

# Pascal VOC class names. The list index matches the channel index of the
# model's per-class logits.
VOC_CATEGORIES = [
    "__background__",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]


def _download_model_from_hf(repo_id: str, filename: str | None = None) -> str:
    """Download the .tflite model from the Hub and return its local path.

    If *filename* is given, only that file is tried. Otherwise a couple of
    common names for this model are attempted in order.

    Raises:
        FileNotFoundError: if none of the candidate filenames exists in the
            repo; the last download error is chained as the cause.
    """
    if filename:
        return hf_hub_download(repo_id=repo_id, filename=filename)
    candidates = ("fcn_resnet50_nchw.tflite", "fcn_resnet50.tflite")
    last_error: Exception | None = None
    for name in candidates:
        try:
            return hf_hub_download(repo_id=repo_id, filename=name)
        except Exception as err:  # pylint: disable=broad-except
            # Best-effort probing: remember the failure, try the next name.
            last_error = err
    raise FileNotFoundError(
        f"Could not find expected model in {repo_id}: {', '.join(candidates)}"
    ) from last_error


def _load_cpu_model(model_path: str) -> CompiledModel:
    """Compile the model for CPU execution."""
    return CompiledModel.from_file(model_path, hardware_accel=HardwareAccelerator.CPU)


def _infer_nchw_input_hw(model: CompiledModel) -> tuple[int, int]:
    """Return the (height, width) expected by the model's first input.

    Falls back to 520x520 (the torchvision FCN default) when the buffer
    requirements do not expose usable dimensions.
    """
    req = model.get_input_buffer_requirements(0, 0)
    dims = req.get("dimensions") or req.get("shape") or req.get("dims")
    if not dims:
        return 520, 520
    dims = [int(v) for v in dims]
    if len(dims) == 4 and dims[1] == 3:  # NCHW layout
        return dims[2], dims[3]
    if len(dims) == 3 and dims[0] == 3:  # CHW layout
        return dims[1], dims[2]
    return 520, 520


def _preprocess_nchw(image: Image.Image, input_h: int, input_w: int) -> np.ndarray:
    """Resize and normalize *image* into a float32 CHW array.

    Applies the standard torchvision ImageNet normalization
    (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).
    """
    image = image.convert("RGB")
    # Resize directly to the network input size. A previous version first
    # resized the short side to 520 while preserving aspect ratio, but that
    # intermediate image was immediately resized again to (input_w, input_h),
    # so the extra pass only added work and a second lossy interpolation.
    image = image.resize((input_w, input_h), Image.BILINEAR)
    x = np.asarray(image, dtype=np.float32) / 255.0
    x = (x - np.array([0.485, 0.456, 0.406], dtype=np.float32)) / np.array(
        [0.229, 0.224, 0.225], dtype=np.float32
    )
    return np.transpose(x, (2, 0, 1))


def _run_logits_hwc(
    model: CompiledModel, nchw_input: np.ndarray, num_classes: int
) -> np.ndarray:
    """Run inference and return per-pixel class logits as an (H, W, C) array.

    Raises:
        ValueError: if the flat output size does not match
            num_classes * H * W of the given input.
    """
    inp = model.create_input_buffers(0)
    out = model.create_output_buffers(0)
    inp[0].write(nchw_input)
    model.run_by_index(0, inp, out)
    req = model.get_output_buffer_requirements(0, 0)
    # Read the raw float32 output; buffer_size is in bytes.
    y = out[0].read(
        req["buffer_size"] // np.dtype(np.float32).itemsize, np.float32
    ).reshape(-1)
    h, w = nchw_input.shape[1], nchw_input.shape[2]
    if y.size != num_classes * h * w:
        raise ValueError(f"Unexpected output size {y.size}; expected {num_classes * h * w}")
    chw = y.reshape(num_classes, h, w)
    return np.transpose(chw, (1, 2, 0))


def _softmax_last_axis(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    logits = logits.astype(np.float32, copy=False)
    max_logits = np.max(logits, axis=-1, keepdims=True)
    exps = np.exp(logits - max_logits)
    return exps / np.sum(exps, axis=-1, keepdims=True)


def _normalize_mask(mask_prob: np.ndarray) -> np.ndarray:
    """Stretch the probability mask to [0, 1] using robust percentiles."""
    mask = np.clip(mask_prob.astype(np.float32), 0.0, 1.0)
    lo = float(np.percentile(mask, 5.0))
    hi = float(np.percentile(mask, 99.0))
    if hi <= lo:
        # Degenerate percentile range: fall back to the full min/max span.
        hi = float(mask.max())
        lo = float(mask.min())
    if hi <= lo:
        # Constant mask: nothing to highlight.
        return np.zeros_like(mask, dtype=np.float32)
    return np.clip((mask - lo) / (hi - lo), 0.0, 1.0)


def _build_overlay_rgb(base_rgb: np.ndarray, mask_prob: np.ndarray) -> np.ndarray:
    """Returns RGB overlay image in uint8."""
    mask = _normalize_mask(mask_prob)
    heat = np.zeros_like(base_rgb, dtype=np.float32)
    heat[..., 0] = 255.0
    heat[..., 1] = 180.0 * mask
    # Alpha ramps in above ~0.35 so weak activations stay transparent.
    alpha = 0.75 * (np.clip((mask - 0.35) / 0.65, 0.0, 1.0) ** 1.5)[..., None]
    out = (1.0 - alpha) * base_rgb.astype(np.float32) + alpha * heat
    return np.clip(out, 0, 255).astype(np.uint8)


def main() -> int:
    """Download, run, and visualize the FCN ResNet50 segmentation model."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--image", required=True)
    ap.add_argument("--repo_id", default="litert-community/fcn_resnet50")
    ap.add_argument(
        "--model_file",
        default=None,
        help="Optional model filename in repo. If omitted, common names are tried.",
    )
    ap.add_argument("--class_name", default="dog", choices=VOC_CATEGORIES)
    ap.add_argument(
        "--save_figure",
        default=None,
        help="Optional output PNG path for the matplotlib figure.",
    )
    ap.add_argument(
        "--no_show",
        action="store_true",
        help="Do not open interactive window (use with --save_figure).",
    )
    args = ap.parse_args()

    model_path = _download_model_from_hf(args.repo_id, filename=args.model_file)
    model = _load_cpu_model(model_path)

    image = Image.open(args.image).convert("RGB")
    input_h, input_w = _infer_nchw_input_hw(model)
    x_nchw = _preprocess_nchw(image, input_h, input_w)

    logits_hwc = _run_logits_hwc(model, x_nchw, len(VOC_CATEGORIES))
    probs_hwc = _softmax_last_axis(logits_hwc)
    class_index = VOC_CATEGORIES.index(args.class_name)
    mask_prob = probs_hwc[..., class_index]

    # Resize mask to original image for display.
    mask_img = Image.fromarray(
        (np.clip(mask_prob, 0.0, 1.0) * 255.0).astype(np.uint8), mode="L"
    )
    mask_img = mask_img.resize(image.size, Image.BILINEAR)
    mask_prob_resized = np.asarray(mask_img, dtype=np.float32) / 255.0

    base_rgb = np.asarray(image, dtype=np.uint8)
    overlay_rgb = _build_overlay_rgb(base_rgb, mask_prob_resized)

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    axs[0].imshow(base_rgb)
    axs[0].set_title("Input")
    axs[0].axis("off")
    im = axs[1].imshow(mask_prob_resized, cmap="magma", vmin=0.0, vmax=1.0)
    axs[1].set_title(f"Class Prob: {args.class_name}")
    axs[1].axis("off")
    fig.colorbar(im, ax=axs[1], fraction=0.046, pad=0.04)
    axs[2].imshow(overlay_rgb)
    axs[2].set_title("Overlay")
    axs[2].axis("off")
    fig.suptitle(f"Model: {model_path}\nClass: {args.class_name}")
    fig.tight_layout()

    if args.save_figure:
        fig.savefig(args.save_figure, dpi=180, bbox_inches="tight")
        print(f"Saved figure to: {args.save_figure}")
    if not args.no_show:
        plt.show()
    else:
        plt.close(fig)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
```

### BibTeX entry and citation info

```bibtex
@article{DBLP:journals/corr/LongSD14,
  author    = {Jonathan Long and Evan Shelhamer and Trevor Darrell},
  title     = {Fully Convolutional Networks for Semantic Segmentation},
  journal   = {CoRR},
  volume    = {abs/1411.4038},
  year      = {2014},
  url       = {http://arxiv.org/abs/1411.4038},
  eprinttype = {arXiv},
  eprint    = {1411.4038},
  timestamp = {Mon, 13 Aug 2018 16:48:17 +0200},
  biburl    = {https://dblp.org/rec/journals/corr/LongSD14.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
```