# seg/app.py
# Uploaded by carpedm20 via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 2f05cc7 (verified).
import os
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import tempfile
from gradio.themes.utils import sizes
from classes_and_palettes import GOLIATH_CLASSES
# =========================================================
# Config
# =========================================================
class Config:
    """Static paths and the table of available TorchScript checkpoints."""

    # Asset folder resolved relative to this source file's directory.
    ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
    # Checkpoint files live in a sub-folder of the asset folder.
    CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")
    # Model-size label (shown in the UI dropdown) -> file name in CHECKPOINTS_DIR.
    CHECKPOINTS = dict(
        [
            ("0.3b", "sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2"),
            ("0.6b", "sapiens_0.6b_goliath_best_goliath_mIoU_7777_epoch_178_torchscript.pt2"),
            ("1b", "sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2"),
        ]
    )
# =========================================================
# Model
# =========================================================
class ModelManager:
    """Load TorchScript checkpoints lazily and run segmentation inference."""

    # Checkpoint label -> loaded model. Shared by every caller for the lifetime
    # of the process, so each checkpoint is read from disk at most once.
    _cache = {}

    @staticmethod
    def load_model(name: str):
        """Return the model registered under ``name`` in ``Config.CHECKPOINTS``.

        On first request the file is loaded with ``torch.jit.load``, put in eval
        mode, moved to the "cuda" device and stored in the cache. Subsequent
        calls return the cached instance. An unknown ``name`` raises KeyError.
        """
        cache = ModelManager._cache
        if name not in cache:
            weights = os.path.join(Config.CHECKPOINTS_DIR, Config.CHECKPOINTS[name])
            net = torch.jit.load(weights)
            net.eval().to("cuda")
            cache[name] = net
        return cache[name]

    @staticmethod
    @torch.inference_mode()
    def run(model, x, h, w):
        """Run ``model`` on ``x`` and return per-pixel argmax labels at (h, w).

        The model output must be 4-D, (N, C, H', W'), because it is resized
        bilinearly to the spatial size ``(h, w)``. The argmax is taken over
        dim 1, so the result has shape (N, h, w). Autograd is disabled via
        ``torch.inference_mode``.
        """
        logits = model(x)
        upsampled = F.interpolate(
            logits,
            size=(h, w),
            mode="bilinear",
            align_corners=False,
        )
        return torch.argmax(upsampled, dim=1)
# =========================================================
# Image Processing
# =========================================================
class ImageProcessor:
    """Preprocess images, run segmentation and package results for the UI."""

    def __init__(self):
        # Resize to a fixed (H, W) = (1024, 768) model input, convert to a float
        # tensor in [0, 1], then normalise with per-channel statistics that are
        # written on the 0-255 scale (hence the division by 255).
        self.tf = transforms.Compose([
            transforms.Resize((1024, 768)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[123.5 / 255, 116.5 / 255, 103.5 / 255],
                std=[58.5 / 255, 57.0 / 255, 57.5 / 255],
            ),
        ])

    def process(self, image: "Image.Image", model_name: str):
        """Segment ``image`` with the checkpoint selected by ``model_name``.

        Args:
            image: Input PIL image. Any mode is accepted; it is converted to RGB
                for the model, while the original is returned for display.
            model_name: Key into ``Config.CHECKPOINTS`` (e.g. "1b").

        Returns:
            ``((image, annotations), npy_path)``. The first element is the value
            for ``gr.AnnotatedImage`` (original image plus per-class binary
            masks). The second is the path of a ``.npy`` file containing the
            full-resolution (H, W) label map.

        Raises:
            ValueError: if ``image`` is None (e.g. nothing was uploaded).
        """
        if image is None:
            raise ValueError("No input image provided.")
        model = ModelManager.load_model(model_name)
        # Normalize uses 3-channel mean/std. RGBA, grayscale or palette images
        # would not match that channel count, so feed the model an RGB copy.
        rgb = image if image.mode == "RGB" else image.convert("RGB")
        x = self.tf(rgb).unsqueeze(0).to("cuda")
        pred = ModelManager.run(model, x, image.height, image.width)
        mask = pred.squeeze(0).cpu().numpy()
        npy_path = self._save_mask(mask)
        annotations = self._build_annotations(mask)
        return (image, annotations), npy_path

    @staticmethod
    def _save_mask(mask: np.ndarray) -> str:
        """Write ``mask`` to a new temporary ``.npy`` file and return its path.

        ``tempfile.mkstemp`` creates the file atomically, avoiding the
        name-reuse race of ``tempfile.mktemp``. The file is deliberately left
        on disk so the UI can offer it for download.
        """
        fd, npy_path = tempfile.mkstemp(suffix=".npy")
        with os.fdopen(fd, "wb") as fh:
            np.save(fh, mask)
        return npy_path

    def _build_annotations(self, mask: np.ndarray):
        """Split a label map into ``(binary_mask, class_name)`` pairs.

        One uint8 0/1 mask is produced for each label that occurs in ``mask``
        and indexes into ``GOLIATH_CLASSES``. Other labels are skipped. Because
        ``np.unique`` yields only labels that are present, every mask is
        non-empty.
        """
        return [
            ((mask == class_id).astype(np.uint8), GOLIATH_CLASSES[int(class_id)])
            for class_id in np.unique(mask)
            if 0 <= class_id < len(GOLIATH_CLASSES)
        ]
# =========================================================
# UI
# =========================================================
class GradioInterface:
    """Gradio Blocks front-end wrapping an ImageProcessor."""

    def __init__(self):
        self.processor = ImageProcessor()

    def create(self):
        """Build and return the ``gr.Blocks`` demo (the caller launches it)."""

        def run(image, checkpoint):
            # Clicking the button before uploading passes image=None. Without a
            # guard that surfaces as an opaque traceback; gr.Error shows a
            # readable message in the UI instead.
            if image is None:
                raise gr.Error("Please upload an input image first.")
            return self.processor.process(image, checkpoint)

        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column(scale=1):
                    input_image = gr.Image(
                        label="Input Image",
                        type="pil",
                    )
                    model_name = gr.Dropdown(
                        label="Model Size",
                        choices=list(Config.CHECKPOINTS.keys()),
                        value="1b",
                    )
                    run_btn = gr.Button("Run Segmentation", variant="primary")
                with gr.Column(scale=2):
                    annotated = gr.AnnotatedImage(
                        label="Segmentation Result",
                        show_legend=True,
                        height=512,
                    )
                    mask_file = gr.File(label="Raw Mask (.npy)")
            # Outputs map positionally onto the (annotated value, npy path)
            # pair returned by ImageProcessor.process.
            run_btn.click(
                fn=run,
                inputs=[input_image, model_name],
                outputs=[annotated, mask_file],
            )
        return demo
# =========================================================
# Entrypoint
# =========================================================
def main():
    """Set PyTorch numerics options and launch the Gradio server.

    If CUDA device 0 reports compute capability major >= 8 (Ampere or newer),
    TF32 is enabled for CUDA matmuls and cuDNN. The server binds to 0.0.0.0
    and no public share link is created.
    """
    tf32_capable = (
        torch.cuda.is_available()
        and torch.cuda.get_device_properties(0).major >= 8
    )
    if tf32_capable:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    demo = GradioInterface().create()
    demo.launch(server_name="0.0.0.0", share=False)
if __name__ == "__main__":
    # Start the demo only when run as a script, not when imported.
    main()