Spaces:
Sleeping
Sleeping
| from typing import Tuple | |
| import streamlit as st | |
| import os | |
| import torch | |
| from .app_env import SOD_MODEL_TYPE | |
| from .app_utils import count_parameters | |
| from .smultimae_model import RGBDSMultiMAEModel | |
| from .base_model import BaseRGBDModel | |
| from .device import device | |
| from s_multimae.da.dav6 import DataAugmentationV6 | |
| from s_multimae.configs.base_config import base_cfg | |
| from s_multimae.configs.experiment_config import arg_cfg | |
| from s_multimae.model_pl import ModelPL | |
| # from spnet_model import SPNetModel | |
def load_smultimae_model(
    sod_model_config_key: str, top: int
) -> Tuple[BaseRGBDModel, base_cfg, DataAugmentationV6]:
    """Build an S-MultiMAE SOD model and load its pretrained weights.

    Steps:
    1. Instantiate the experiment config selected by ``sod_model_config_key``
       and construct a ``ModelPL`` from it.
    2. Locate the checkpoint ``weights/s-multimae-<experiment_name>-top<top>.pth``,
       downloading it from the ``RGBD-SOD/S-MultiMAE`` Hugging Face repo into
       ``weights/`` when it is not already on disk.
    3. Load the checkpoint tensors (mapped to CPU) into ``sod_model.model`` and
       wrap the result in ``RGBDSMultiMAEModel``.

    No device placement happens in this function itself.
    NOTE(review): presumably ``RGBDSMultiMAEModel`` moves the model to
    ``device`` — confirm.

    Args:
        sod_model_config_key: Key into ``arg_cfg`` selecting the experiment
            config (e.g. ``"cfgv4_0_2006"``).
        top: Checkpoint rank embedded in the weights filename (``top{top}``).

    Returns:
        A ``(model, cfg, da)`` tuple: the wrapped model, the experiment config
        and a ``DataAugmentationV6`` built from that config.

    Raises:
        FileNotFoundError: If the checkpoint is still absent after the
            download attempt.
    """
    import warnings

    cfg = arg_cfg[sod_model_config_key]()
    weights_fname = f"s-multimae-{cfg.experiment_name}-top{top}.pth"
    ckpt_path = os.path.join("weights", weights_fname)
    print(ckpt_path)

    if not os.path.isfile(ckpt_path):
        # Imported lazily so huggingface_hub is only needed when a download
        # is actually required.
        from huggingface_hub import hf_hub_download

        hf_hub_download(
            repo_id="RGBD-SOD/S-MultiMAE",
            filename=weights_fname,
            local_dir="weights",
        )

    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would turn a missing file into a confusing
    # failure inside torch.load.
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    sod_model = ModelPL(cfg)
    # strict=False tolerates key mismatches between checkpoint and model.
    # Report them instead of silently running a partially initialised model.
    # getattr keeps this safe if load_state_dict is overridden to return None.
    load_result = sod_model.model.load_state_dict(
        torch.load(ckpt_path, map_location="cpu"), strict=False
    )
    missing = list(getattr(load_result, "missing_keys", []) or [])
    unexpected = list(getattr(load_result, "unexpected_keys", []) or [])
    if missing or unexpected:
        warnings.warn(
            f"Loading {ckpt_path}: {len(missing)} missing key(s), "
            f"{len(unexpected)} unexpected key(s) "
            f"(missing sample: {missing[:5]}, unexpected sample: {unexpected[:5]})",
            RuntimeWarning,
            stacklevel=2,
        )

    da = DataAugmentationV6(cfg)
    return RGBDSMultiMAEModel(cfg, sod_model), cfg, da
| # @st.cache_resource | |
| # def load_spnet_model() -> BaseRGBDModel: | |
| # """ | |
| # 1. Construct model | |
| # 2. Load pretrained weights | |
| # 3. Load model into device | |
| # """ | |
| # sod_model = SPNetModel() | |
| # return sod_model | |
| # @st.cache_resource | |
| # def load_bbsnet_model() -> BaseRGBDModel: | |
| # """ | |
| # 1. Construct model | |
| # 2. Load pretrained weights | |
| # 3. Load model into device | |
| # """ | |
| # sod_model = BBSNetModel() | |
| # return sod_model | |
def sod_selection_ui() -> Tuple[BaseRGBDModel, DataAugmentationV6]:
    """Render the SOD model picker, load the chosen model and show its size.

    Shows a selectbox for the model family. For S-MultiMAE it shows a second
    selectbox for the concrete config, loads that model via
    ``load_smultimae_model`` and writes its parameter count to the page.

    Returns:
        ``(sod_model, da)``: the loaded model wrapper and the
        ``DataAugmentationV6`` built for its config.

    Raises:
        ValueError: If the selected model type has no loader branch (e.g. an
            option is re-enabled in the selectbox without a matching branch).
    """
    sod_model_type = st.selectbox(
        "Choose SOD model",
        (
            SOD_MODEL_TYPE.S_MULTIMAE,
            # SOD_MODEL_TYPE.SPNET,
            # SOD_MODEL_TYPE.BBSNET,
        ),
        key="sod_model_type",
    )

    if sod_model_type == SOD_MODEL_TYPE.S_MULTIMAE:
        # Display name -> config key in ``arg_cfg`` and checkpoint rank ("top").
        smultimae_configs = {
            "S-MultiMAE [ViT-L] Multi-GT": {"top": 1, "cfg": "cfgv4_0_2006"},
            "S-MultiMAE [ViT-B] Multi-GT": {"top": 1, "cfg": "cfgv4_0_2007"},
        }
        sod_model_config_key = st.selectbox(
            "Choose config",
            list(smultimae_configs.keys()),
            key="sod_model_config_key",
        )
        selected = smultimae_configs[sod_model_config_key]
        sod_model, cfg, da = load_smultimae_model(selected["cfg"], selected["top"])
        # st.text(f"Model description: {cfg.description}")
    # elif sod_model_type == SOD_MODEL_TYPE.SPNET:
    #     sod_model = load_spnet_model()
    #     st.text(f"Model description: SPNet (https://github.com/taozh2017/SPNet)")
    # elif sod_model_type == SOD_MODEL_TYPE.BBSNET:
    #     sod_model = load_bbsnet_model()
    #     st.text(f"Model description: BBSNet (https://github.com/DengPingFan/BBS-Net)")
    else:
        # Without this branch an unhandled option would crash below with an
        # UnboundLocalError on ``sod_model``.
        raise ValueError(f"Unsupported SOD model type: {sod_model_type}")

    st.text(f"Number of parameters {count_parameters(sod_model)}")
    return sod_model, da