|
|
import os |
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import tempfile |
|
|
|
|
|
from gradio.themes.utils import sizes |
|
|
from classes_and_palettes import GOLIATH_CLASSES |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Config:
    """Static configuration: asset locations and the available model checkpoints."""

    # Layout on disk: <dir of this file>/assets/checkpoints/<checkpoint file>
    ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")

    CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")

    # User-facing model-size key -> TorchScript checkpoint filename.
    # Keys double as the choices shown in the UI dropdown.
    CHECKPOINTS = {

        "0.3b": "sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2",

        "0.6b": "sapiens_0.6b_goliath_best_goliath_mIoU_7777_epoch_178_torchscript.pt2",

        "1b": "sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2",

    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelManager:
    """Loads TorchScript segmentation models on demand and memoizes them."""

    # Class-level memo: model-size key -> loaded TorchScript module.
    _cache = {}

    @staticmethod
    def load_model(name: str):
        """Return the model for *name*, loading it onto the GPU on first use."""
        cached = ModelManager._cache.get(name)
        if cached is not None:
            return cached

        checkpoint_path = os.path.join(Config.CHECKPOINTS_DIR, Config.CHECKPOINTS[name])
        model = torch.jit.load(checkpoint_path)
        model.eval().to("cuda")
        ModelManager._cache[name] = model
        return model

    @staticmethod
    @torch.inference_mode()
    def run(model, x, h, w):
        """Run *model* on batch *x*; return per-pixel class ids resized to (h, w)."""
        logits = model(x)
        # Upsample class logits back to the caller-requested resolution before
        # taking the per-pixel argmax over the class dimension.
        resized = F.interpolate(logits, size=(h, w), mode="bilinear", align_corners=False)
        return resized.argmax(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ImageProcessor:
    """Preprocesses images, runs segmentation, and packages results for the UI."""

    def __init__(self):
        # Fixed 1024x768 input with per-channel normalization; the mean/std
        # values are pixel statistics expressed in [0, 1] range.
        # NOTE(review): assumes the Sapiens checkpoints were trained at this
        # resolution/normalization — confirm against the model card.
        self.tf = transforms.Compose([
            transforms.Resize((1024, 768)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[123.5 / 255, 116.5 / 255, 103.5 / 255],
                std=[58.5 / 255, 57.0 / 255, 57.5 / 255],
            ),
        ])

    def process(self, image: Image.Image, model_name: str):
        """Segment *image* with the checkpoint keyed by *model_name*.

        Returns:
            ((image, annotations), npy_path): the input image paired with
            per-class binary-mask annotations (for gr.AnnotatedImage), and
            the path of a .npy file holding the raw class-id mask.
        """
        model = ModelManager.load_model(model_name)
        x = self.tf(image).unsqueeze(0).to("cuda")

        # Predict at model resolution; run() upsamples back to the input size.
        pred = ModelManager.run(model, x, image.height, image.width)
        mask = pred.squeeze(0).cpu().numpy()

        # mkstemp creates the file atomically, unlike the deprecated and
        # race-prone tempfile.mktemp the code used before. Close the fd
        # immediately: np.save reopens the file by path.
        fd, npy_path = tempfile.mkstemp(suffix=".npy")
        os.close(fd)
        np.save(npy_path, mask)

        annotations = self._build_annotations(mask)

        return (image, annotations), npy_path

    def _build_annotations(self, mask: np.ndarray):
        """Turn a class-id mask into (binary_mask, label) pairs for display."""
        annotations = []
        for class_id in np.unique(mask):
            # Skip ids outside the known label set (defensive; should not
            # occur with a matching checkpoint/label-table pair).
            if class_id >= len(GOLIATH_CLASSES):
                continue

            binary_mask = (mask == class_id).astype(np.uint8)
            if binary_mask.sum() == 0:
                continue

            annotations.append(
                (binary_mask, GOLIATH_CLASSES[class_id])
            )

        return annotations
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GradioInterface:
    """Builds the Gradio Blocks UI wired to an ImageProcessor."""

    def __init__(self):
        self.processor = ImageProcessor()

    def create(self):
        """Assemble and return the Blocks demo (caller is responsible for launching)."""

        def on_submit(img, size_key):
            # Delegate straight to the processor; output shapes match the
            # (AnnotatedImage, File) component pair below.
            return self.processor.process(img, size_key)

        with gr.Blocks() as demo:
            with gr.Row():
                # Left column: inputs and the trigger button.
                with gr.Column(scale=1):
                    input_image = gr.Image(label="Input Image", type="pil")
                    model_name = gr.Dropdown(
                        label="Model Size",
                        choices=list(Config.CHECKPOINTS.keys()),
                        value="1b",
                    )
                    run_btn = gr.Button("Run Segmentation", variant="primary")

                # Right column: rendered segmentation plus raw-mask download.
                with gr.Column(scale=2):
                    annotated = gr.AnnotatedImage(
                        label="Segmentation Result",
                        show_legend=True,
                        height=512,
                    )
                    mask_file = gr.File(label="Raw Mask (.npy)")

            run_btn.click(
                fn=on_submit,
                inputs=[input_image, model_name],
                outputs=[annotated, mask_file],
            )

        return demo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Enable TF32 fast paths on capable GPUs, then build and serve the demo."""
    # TF32 matmul/cuDNN kernels require compute capability >= 8.0 (Ampere+).
    if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    demo = GradioInterface().create()
    # Binds on all interfaces; intended for containerized / hosted deployment.
    demo.launch(server_name="0.0.0.0", share=False)


if __name__ == "__main__":
    main()
|
|
|
|
|
|