"""Gradio demo that runs Sapiens Goliath TorchScript segmentation checkpoints."""

import os
import tempfile

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from gradio.themes.utils import sizes
from PIL import Image
from torchvision import transforms

from classes_and_palettes import GOLIATH_CLASSES


# =========================================================
# Config
# =========================================================
class Config:
    """Static paths, checkpoint registry and runtime settings."""

    ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
    CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")
    # Dropdown label -> TorchScript checkpoint filename inside CHECKPOINTS_DIR.
    CHECKPOINTS = {
        "0.3b": "sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2",
        "0.6b": "sapiens_0.6b_goliath_best_goliath_mIoU_7777_epoch_178_torchscript.pt2",
        "1b": "sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2",
    }
    # Model input resolution as (height, width), handed to transforms.Resize.
    INPUT_SIZE = (1024, 768)
    # Prefer the GPU; fall back to CPU so the app still starts without CUDA.
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# =========================================================
# Model
# =========================================================
class ModelManager:
    """Loads TorchScript checkpoints once and runs inference with them."""

    # Loaded models keyed by checkpoint name (a key of Config.CHECKPOINTS).
    _cache = {}

    @staticmethod
    def load_model(name: str):
        """Return the model for checkpoint *name*, loading it on first use.

        Raises:
            KeyError: if *name* is not a key of ``Config.CHECKPOINTS``.
        """
        if name in ModelManager._cache:
            return ModelManager._cache[name]
        path = os.path.join(Config.CHECKPOINTS_DIR, Config.CHECKPOINTS[name])
        # map_location lets a checkpoint serialized on GPU load on a CPU-only host.
        model = torch.jit.load(path, map_location=Config.DEVICE)
        model.eval().to(Config.DEVICE)
        ModelManager._cache[name] = model
        return model

    @staticmethod
    @torch.inference_mode()
    def run(model, x, h, w):
        """Run *model* on batch *x* and return per-pixel class ids at size (h, w).

        The model output (presumably per-class scores, channel dim 1 — confirm
        against the checkpoint) is bilinearly resized to (h, w), then reduced
        with argmax over dim 1, giving a (batch, h, w) integer tensor.
        """
        out = model(x)
        out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=False)
        return out.argmax(1)


# =========================================================
# Image Processing
# =========================================================
class ImageProcessor:
    """Preprocesses PIL images, runs segmentation and packages results for the UI."""

    def __init__(self):
        self.tf = transforms.Compose([
            transforms.Resize(Config.INPUT_SIZE),
            transforms.ToTensor(),
            # Mean/std are specified on the 0-255 scale and divided by 255
            # because ToTensor produces values in [0, 1].
            transforms.Normalize(
                mean=[123.5 / 255, 116.5 / 255, 103.5 / 255],
                std=[58.5 / 255, 57.0 / 255, 57.5 / 255],
            ),
        ])

    def process(self, image: Image.Image, model_name: str):
        """Segment *image* with the checkpoint selected by *model_name*.

        Returns:
            ``((image, annotations), npy_path)`` — the first element is the
            value for ``gr.AnnotatedImage``; the second is the path of a
            ``.npy`` file holding the raw (H, W) class-id mask.

        Raises:
            gr.Error: if no image was supplied or *model_name* is unknown.
        """
        if image is None:
            raise gr.Error("Please upload an input image.")
        if model_name not in Config.CHECKPOINTS:
            raise gr.Error(f"Unknown model size: {model_name!r}")

        # Uploads can be RGBA, palette or grayscale; the 3-channel Normalize
        # above only accepts RGB tensors.
        image = image.convert("RGB")

        model = ModelManager.load_model(model_name)
        x = self.tf(image).unsqueeze(0).to(Config.DEVICE)
        pred = ModelManager.run(model, x, image.height, image.width)
        mask = pred.squeeze(0).cpu().numpy()

        npy_path = self._save_mask(mask)
        annotations = self._build_annotations(mask)
        return (image, annotations), npy_path

    @staticmethod
    def _save_mask(mask: np.ndarray) -> str:
        """Write *mask* to a new ``.npy`` temp file and return its path."""
        # NamedTemporaryFile creates the file atomically (unlike the deprecated,
        # race-prone tempfile.mktemp); delete=False keeps it for Gradio to serve.
        with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as f:
            np.save(f, mask)
        return f.name

    def _build_annotations(self, mask: np.ndarray):
        """Convert a class-id mask into ``(binary_mask, label)`` pairs.

        One uint8 0/1 mask is produced per class id present in *mask*; ids
        with no entry in GOLIATH_CLASSES are skipped.
        """
        annotations = []
        # np.unique only yields ids that occur, so every binary mask is non-empty.
        for class_id in np.unique(mask):
            if class_id >= len(GOLIATH_CLASSES):
                continue
            binary_mask = (mask == class_id).astype(np.uint8)
            annotations.append((binary_mask, GOLIATH_CLASSES[int(class_id)]))
        return annotations


# =========================================================
# UI
# =========================================================
class GradioInterface:
    """Builds the Gradio Blocks UI around an ImageProcessor."""

    def __init__(self):
        self.processor = ImageProcessor()

    def create(self):
        """Construct and return the (not yet launched) ``gr.Blocks`` demo."""

        def run(image, model):
            return self.processor.process(image, model)

        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column(scale=1):
                    input_image = gr.Image(
                        label="Input Image",
                        type="pil",
                    )
                    model_name = gr.Dropdown(
                        label="Model Size",
                        choices=list(Config.CHECKPOINTS.keys()),
                        value="1b",
                    )
                    run_btn = gr.Button("Run Segmentation", variant="primary")

                with gr.Column(scale=2):
                    annotated = gr.AnnotatedImage(
                        label="Segmentation Result",
                        show_legend=True,
                        height=512,
                    )
                    mask_file = gr.File(label="Raw Mask (.npy)")

            run_btn.click(
                fn=run,
                inputs=[input_image, model_name],
                outputs=[annotated, mask_file],
            )

        return demo


# =========================================================
# Entrypoint
# =========================================================
def main():
    """Enable TF32 on compute-capability 8.x+ GPUs, then build and launch the app."""
    if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    app = GradioInterface().create()
    app.launch(server_name="0.0.0.0", share=False)


if __name__ == "__main__":
    main()