| from mmseg.apis import MMSegInferencer | |
| from glob import glob | |
| from vegseg.datasets import L8BIOMEDataset | |
| import numpy as np | |
| from typing import List | |
| import os | |
| from PIL import Image | |
| from vegseg import models | |
def get_palette() -> List[int]:
    """Flatten the L8BIOME dataset palette into one flat list.

    ``L8BIOMEDataset.METAINFO["palette"]`` holds one color entry per class.
    The entries are concatenated in class order. Each entry is presumably an
    (R, G, B) triple, which gives the flat ``[r0, g0, b0, r1, g1, b1, ...]``
    layout that ``PIL.Image.Image.putpalette`` accepts.

    Returns:
        A new list containing every channel value of every palette entry.
    """
    return [
        channel
        for color in L8BIOMEDataset.METAINFO["palette"]
        for channel in color
    ]
def give_color_to_mask(
    mask: Image.Image | np.ndarray, palette: List[int]
) -> Image.Image:
    """Render a class-index mask as a palette ("P" mode) image.

    The pixel values of ``mask`` are kept unchanged and act as indices into
    ``palette``, so class ``i`` is drawn with the ``i``-th palette color.

    Args:
        mask: Per-pixel class indices, as a NumPy array or a PIL image.
            Integer and boolean data of any width is narrowed to ``uint8``.
        palette: Flat list of palette entries, in the format accepted by
            ``PIL.Image.Image.putpalette``.

    Returns:
        The colorized mask as a mode ``"P"`` image.

    Raises:
        ValueError: If integer mask data lies outside ``[0, 255]``. Such
            values cannot be stored in an 8-bit palette image.
    """
    array = np.asarray(mask)
    if array.dtype.kind in "biu" and array.dtype != np.uint8:
        # Pillow cannot build an image from 64-bit integer arrays (the dtype
        # segmentation outputs often have). Wider ints become mode "I", and
        # booleans become a 0/255 mode "1" image, which shifts class 1 to
        # palette entry 255. Narrowing to uint8 keeps values as literal
        # indices; values that would silently wrap are rejected instead.
        if array.size and (array.min() < 0 or array.max() > 255):
            raise ValueError(
                "mask values must lie in [0, 255] to index an 8-bit palette, "
                f"got range [{array.min()}, {array.max()}]"
            )
        array = array.astype(np.uint8)
    color_mask = Image.fromarray(array).convert("P")
    color_mask.putpalette(palette)
    return color_mask
def main(
    config_path: str = "work_dirs/experiment_p_l8/experiment_p_l8.py",
    weight_path: str = "work_dirs/experiment_p_l8/best_mIoU_iter_20000.pth",
    input_pattern: str = "data/vis/input/*.png",
    output_dir: str = "data/vis/ktda",
    device: str = "cuda:1",
    batch_size: int = 16,
) -> None:
    """Segment a set of images and save their colorized masks.

    The model is built from ``config_path``/``weight_path``. Every image
    matching ``input_pattern`` is segmented with it, and each prediction is
    colored with the L8BIOME palette. The result is written to ``output_dir``
    under the input image's file name. Every parameter defaults to the value
    this script previously hard-coded, so ``main()`` keeps its old behavior.

    Args:
        config_path: Model config file passed to ``MMSegInferencer``.
        weight_path: Checkpoint passed to ``MMSegInferencer``.
        input_pattern: ``glob`` pattern selecting the images to segment.
        output_dir: Directory receiving the colorized masks; created if
            missing.
        device: Device string passed to ``MMSegInferencer``.
        batch_size: Batch size used for inference.
    """
    # glob's ordering depends on the filesystem; sort for reproducible runs.
    images = sorted(glob(input_pattern))
    if not images:
        # Avoid loading the model (and calling the inferencer with an
        # empty batch) when there is nothing to process.
        print(f"No images matched {input_pattern!r}; nothing to do.")
        return

    inferencer = MMSegInferencer(
        model=config_path,
        weights=weight_path,
        device=device,
        classes=L8BIOMEDataset.METAINFO["classes"],
        palette=L8BIOMEDataset.METAINFO["palette"],
    )
    predictions = inferencer(images, batch_size=batch_size)["predictions"]
    # NOTE(review): MMSegInferencer.postprocess (mmseg 1.x) returns a bare
    # array instead of a one-element list when given a single input.
    # Without re-wrapping, zip() below would pair that image with the
    # mask's first row. Confirm for the installed mmseg version.
    if not isinstance(predictions, list):
        predictions = [predictions]

    # Image.save does not create parent directories.
    os.makedirs(output_dir, exist_ok=True)
    palette = get_palette()
    for image_path, prediction in zip(images, predictions):
        out_path = os.path.join(output_dir, os.path.basename(image_path))
        color_mask = give_color_to_mask(
            prediction.astype(np.uint8), palette=palette
        )
        color_mask.save(out_path)
if __name__ == "__main__":
    # Run the batch visualization only when executed as a script,
    # not when this module is imported.
    main()