from typing import Dict

import numpy as np
import torch
from matplotlib import cm
from PIL import Image
from torch import Tensor
from transformers import AutoImageProcessor, AutoModel
# Load the BBS-Net RGB-D salient-object-detection model and its matching
# image processor from the Hugging Face Hub.  trust_remote_code=True is
# required because the model/processor classes are defined inside the
# "RGBD-SOD/bbsnet" repository itself; cache_dir pins where each download
# is stored locally.
model = AutoModel.from_pretrained(
    "RGBD-SOD/bbsnet", trust_remote_code=True, cache_dir="model_cache"
)
image_processor = AutoImageProcessor.from_pretrained(
    "RGBD-SOD/bbsnet", trust_remote_code=True, cache_dir="image_processor_cache"
)
def inference(rgb: Image.Image, depth: Image.Image) -> Image.Image:
    """Run BBS-Net salient-object detection on an RGB / depth image pair.

    Args:
        rgb: Input color image; converted to 3-channel ``"RGB"`` mode.
        depth: Aligned depth map; converted to single-channel ``"L"`` mode.

    Returns:
        A PIL image of the predicted saliency map at ``rgb``'s original
        resolution, color-mapped with matplotlib's ``gist_earth`` palette.
    """
    rgb = rgb.convert(mode="RGB")
    depth = depth.convert(mode="L")
    preprocessed_sample: Dict[str, Tensor] = image_processor.preprocess(
        {
            "rgb": rgb,
            "depth": depth,
        }
    )
    # Inference only: disable autograd so no graph is built and no
    # intermediate activations are kept alive.
    with torch.no_grad():
        output: Dict[str, Tensor] = model(
            preprocessed_sample["rgb"], preprocessed_sample["depth"]
        )
    # PIL's Image.size is (width, height); postprocess takes
    # [height, width] to restore the input resolution.
    postprocessed_sample: np.ndarray = image_processor.postprocess(
        output["logits"], [rgb.size[1], rgb.size[0]]
    )
    # Map saliency values through the colormap (assumes postprocess
    # returns values in [0, 1] — TODO confirm; the colormap clips outside
    # that range) and scale the float RGBA output to uint8 for PIL.
    prediction = Image.fromarray(np.uint8(cm.gist_earth(postprocessed_sample) * 255))
    return prediction
if __name__ == "__main__":
    # Placeholder entry point; no command-line behavior is defined yet.
    pass