GeoViG-Demo / app.py
OmarAlasqa's picture
Show BBox for medical
ba7be42
"""
GeoViG Multi-Task Gradio Demo
"""
import sys, os, json, urllib.request
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image, ImageDraw
from huggingface_hub import hf_hub_download
import numpy as np
import gradio as gr
# Log the interpreter and PyTorch versions once at startup.
print(f"DEBUG: System Python: {sys.version}")
print(f"DEBUG: PyTorch version: {torch.__version__}")
# ── Gradio Theme ──────────────────────────────────────────────────────────────
# Built-in "Soft" theme with an indigo primary hue; handed to gr.Blocks in the UI section.
_THEME = gr.themes.Soft(primary_hue="indigo")
from geovig import geovig_ti, geovig_s, geovig_m, geovig_b
# ── Paths ──────────────────────────────────────────────────────────────────────
# Hugging Face Hub repository that hosts every checkpoint listed below.
REPO_ID = "OmarAlasqa/GeoViG"

# Checkpoint key -> file path inside REPO_ID (looked up by get_ckpt_path).
CKPT_FILES = {
    # Classification weights, keyed by the variant name used in the UI.
    "GeoViG-Ti": "pth/geovig_ti_5e4_8G_300_75_22/checkpoint.pth",
    "GeoViG-S": "pth/geovig_s_5e4_8G_300_77_48/checkpoint.pth",
    "GeoViG-M": "pth/geovig_m_5e4_8G_300_80_70/checkpoint.pth",
    "GeoViG-B": "pth/geovig_b_5e4_8G_300_82_38/checkpoint.pth",
    # Detection / segmentation weights ("det_<backbone size>").
    "det_m": "coco_det_seg_pth/geovig_m_det_seg/epoch_12.pth",
    "det_b": "coco_det_seg_pth/geovig_b_det_seg/epoch_12.pth",
    # Medical segmentation weights, one per dataset.
    "kvasir_m": "medical/kvasir_geovig_m/checkpoint.pth",
    "dsb_m": "medical/dsb_geovig_m/checkpoint.pth",
}

# NOTE(review): presumably keeps huggingface_hub from prompting interactively
# in the hosted environment -- confirm against the huggingface_hub docs.
os.environ["HF_HUB_DISABLE_INTERACTIVE_FLOW"] = "1"
def get_ckpt_path(key):
    """Return a local file path for the checkpoint registered under ``key``.

    The file name is looked up in ``CKPT_FILES`` and fetched from ``REPO_ID``
    with ``hf_hub_download``. Any failure along the way (unknown key,
    download problem, ...) is printed and reported to the caller as ``None``
    instead of being raised.
    """
    try:
        remote_name = CKPT_FILES[key]
        local_path = hf_hub_download(repo_id=REPO_ID, filename=remote_name)
    except Exception as e:
        print(f"Error downloading {key}: {e}")
        return None
    return local_path
# ── ImageNet labels ────────────────────────────────────────────────────────────
# Local cache file for the human-readable ImageNet class names.
LABELS_FILE = "imagenet_labels.json"
_LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"

# Best-effort download: the demo still starts (with an empty label list) when
# the file cannot be fetched.
if not os.path.exists(LABELS_FILE):
    _labels_tmp = LABELS_FILE + ".part"
    try:
        # Download under a temporary name and rename afterwards so that an
        # interrupted transfer never leaves a truncated LABELS_FILE behind.
        urllib.request.urlretrieve(_LABELS_URL, _labels_tmp)
        os.replace(_labels_tmp, LABELS_FILE)
    except Exception as exc:  # startup boundary: report and carry on
        print(f"DEBUG: Could not download ImageNet labels: {exc}")

# Class-index -> label list; stays empty when no usable labels file exists.
IMAGENET_LABELS = []
if os.path.exists(LABELS_FILE):
    try:
        with open(LABELS_FILE, encoding="utf-8") as f:
            IMAGENET_LABELS = json.load(f)
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"DEBUG: Could not read {LABELS_FILE}: {exc}")
# ── Pre-processing ─────────────────────────────────────────────────────────────
# Per-channel statistics handed to transforms.Normalize below.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Classification input pipeline: bicubic resize to 256, centre crop to 224,
# conversion to a tensor, then normalisation.
IMAGENET_TF = transforms.Compose(
    [
        transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# ── Model cache ────────────────────────────────────────────────────────────────
# Classification models that are already built and loaded, keyed by variant name.
MODEL_CACHE = {}
# Variant name -> model constructor imported from the geovig module.
BUILDERS = {
    "GeoViG-Ti": geovig_ti,
    "GeoViG-S": geovig_s,
    "GeoViG-M": geovig_m,
    "GeoViG-B": geovig_b,
}


def _load_cls_model(variant: str):
    """Return the eval-mode classifier for ``variant``, loading it on first use.

    Args:
        variant: A key of ``BUILDERS`` (and of ``CKPT_FILES``).

    Returns:
        The cached or freshly loaded model, or ``None`` when the checkpoint
        could not be obtained.

    Raises:
        KeyError: If ``variant`` is not a key of ``BUILDERS``.
    """
    if variant in MODEL_CACHE:
        return MODEL_CACHE[variant]

    # Look the constructor up first so an unknown variant still raises KeyError,
    # but only build the network once the weights are known to be available.
    builder = BUILDERS[variant]
    path = get_ckpt_path(variant)
    if not path:
        return None

    model = builder()
    sd = torch.load(path, map_location="cpu")
    # Checkpoints may wrap the weights under "model" or "state_dict"; otherwise
    # the loaded object itself is used as the state dict.
    sd = sd.get("model", sd.get("state_dict", sd))
    # strict=False: keys that do not match the model are ignored, not fatal.
    model.load_state_dict(sd, strict=False)
    model.eval()
    MODEL_CACHE[variant] = model
    return model
def classify(image: Image.Image, variant: str):
    """Classify ``image`` with the selected GeoViG variant.

    Args:
        image: Input picture, or ``None`` when nothing was provided.
        variant: Model variant name passed on to ``_load_cls_model``.

    Returns:
        A dict mapping label -> probability for the (up to) five
        highest-scoring classes. An empty dict is returned for a missing
        image, and a single ``"Error: Weights not found"`` entry when the
        model could not be loaded.
    """
    if image is None:
        return {}
    model = _load_cls_model(variant)
    if model is None:
        return {"Error: Weights not found": 1.0}

    batch = IMAGENET_TF(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        probs = F.softmax(model(batch), dim=-1)[0]

    # Never ask for more entries than the classifier head provides.
    top_probs, top_idxs = probs.topk(min(5, probs.numel()))

    scores = {}
    # Convert to plain Python numbers so the fallback label reads "Class 123"
    # rather than "Class tensor(123)".
    for prob, idx in zip(top_probs.tolist(), top_idxs.tolist()):
        if idx < len(IMAGENET_LABELS):
            label = IMAGENET_LABELS[idx]
        else:
            label = f"Class {idx}"
        scores[label] = prob
    return scores
# ── Detection Logic ───────────────────────────────────────────────────────────
# Class names indexed by detector class id; used to label drawn boxes.
COCO_NAMES = [
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench", "bird",
    "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack",
    "umbrella", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat",
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    "wine glass", "cup", "fork", "knife", "spoon",
    "bowl", "banana", "apple", "sandwich", "orange",
    "broccoli", "carrot", "hot dog", "pizza", "donut",
    "cake", "chair", "couch", "potted plant", "bed",
    "dining table", "toilet", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven",
    "toaster", "sink", "refrigerator", "book", "clock",
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
]
def _draw_detections(image: Image.Image, result, score_thr=0.3):
    """Render detector output as a box image and a mask-overlay image.

    Args:
        image: Source picture; it is converted to fresh copies and not modified.
        result: Detector output, either a ``(bbox_result, segm_result)`` pair
            or a bare per-class sequence of bbox arrays.
        score_thr: Detections whose score (fifth bbox value) is below this
            threshold are skipped.

    Returns:
        ``(boxes_image, masks_image)``: an RGB PIL image with rectangles and
        labels, and an RGB PIL image with the instance masks blended in.
    """
    box_canvas = image.convert("RGBA")
    pen = ImageDraw.Draw(box_canvas)
    mask_canvas = np.array(image.convert("RGB"))

    # Robust unpacking for mmdet 2.x
    looks_like_pair = (
        isinstance(result, (list, tuple))
        and len(result) == 2
        and isinstance(result[0], (list, np.ndarray))
    )
    if looks_like_pair:
        bbox_result, segm_result = result
    else:
        bbox_result, segm_result = result, None

    drawn = 0
    per_class_boxes = bbox_result if bbox_result is not None else ()
    for cls_id, bboxes in enumerate(per_class_boxes):
        if len(bboxes) == 0:
            continue
        # One random colour per class, shared by its boxes and its masks.
        color = tuple(np.random.randint(100, 255, 3).tolist())
        for det_idx, bbox in enumerate(bboxes):
            if bbox[4] < score_thr:
                continue
            drawn += 1
            x1, y1, x2, y2, score = bbox
            pen.rectangle([x1, y1, x2, y2], outline=color, width=4)
            if cls_id < len(COCO_NAMES):
                label = f"{COCO_NAMES[cls_id]} {score:.2f}"
            else:
                label = f"cls {cls_id} {score:.2f}"
            pen.text((x1 + 2, y1 + 2), label, fill=(255, 255, 0))

            # Nothing to blend when this class has no mask list.
            if segm_result is None or len(segm_result) <= cls_id or segm_result[cls_id] is None:
                continue
            try:
                mask = segm_result[cls_id][det_idx]
                if isinstance(mask, dict):
                    # Dict masks are decoded with pycocotools before use.
                    import pycocotools.mask as mask_util
                    mask = mask_util.decode(mask)
                inside = mask.astype(bool)
                tint = np.array(color, dtype=np.uint8)
                # 50/50 blend of the picture and the class colour inside the mask.
                mask_canvas[inside] = (mask_canvas[inside] * 0.5 + tint * 0.5).astype(np.uint8)
            except Exception as e:
                # A bad mask must not abort drawing of the remaining detections.
                print(f"DEBUG: Mask blend failed: {e}")

    print(f"DEBUG: Total detections drawn: {drawn}")
    return box_canvas.convert("RGB"), Image.fromarray(mask_canvas)
# Detectors that are already initialised, keyed by "det_<backbone size>".
DET_MODEL_CACHE = {}


def detect(image: Image.Image, model_size: str, score_thr: float):
    """Run detection and instance segmentation on ``image``.

    Args:
        image: Input picture, or ``None`` when nothing was provided.
        model_size: Backbone size, ``"M"`` or ``"B"``.
        score_thr: Minimum score for a detection to be drawn.

    Returns:
        ``(boxes_image, masks_image)``. ``(None, None)`` is returned for a
        missing image; the unmodified input is returned twice when the
        detection stack cannot be imported or the checkpoint is unavailable.
    """
    if image is None:
        return None, None

    # The detection stack is imported lazily; if any import fails the input
    # image is echoed back instead of raising.
    try:
        import mmcv
        print(f"DEBUG: MMCV version: {mmcv.__version__}")
        # NOTE(review): imported only for its side effects -- presumably it
        # registers the GeoViG backbone with mmdet; confirm.
        import geovig_det_backbone
        from mmdet.apis import init_detector, inference_detector
        import cv2
    except Exception as e:
        print(f"DEBUG: Detection setup failed: {e}")
        return image, image

    key = f"det_{model_size.lower()}"
    if key not in DET_MODEL_CACHE:
        config_by_size = {
            "M": "configs/mask_rcnn_geovig_m_fpn_1x_coco.py",
            "B": "configs/mask_rcnn_geovig_b_fpn_1x_coco.py",
        }
        cfg = config_by_size[model_size]
        path = get_ckpt_path(key)
        if not path:
            return image, image
        DET_MODEL_CACHE[key] = init_detector(cfg, path, device="cpu")
    detector = DET_MODEL_CACHE[key]

    # The detector is fed a BGR array, so convert from PIL's RGB order.
    rgb_pixels = np.array(image.convert("RGB"))
    cv_img = cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)
    result = inference_detector(detector, cv_img)
    return _draw_detections(image, result, score_thr=score_thr)
# ── Medical logic ─────────────────────────────────────────────────────────────
# Medical detectors that are already initialised, keyed by "med_<dataset>".
MED_MODEL_CACHE = {}


def segment_medical(image: Image.Image, dataset: str, score_thr: float):
    """Segment ``image`` with the model of the chosen medical dataset.

    Args:
        image: Input picture, or ``None`` when nothing was provided.
        dataset: ``"Kvasir-SEG"`` or ``"DSB 2018"``.
        score_thr: Minimum score for a detection to contribute its mask and box.

    Returns:
        A PIL image with the merged masks tinted and the boxes drawn on top.
        ``None`` is returned for a missing image; the unmodified input is
        returned when the detection stack cannot be imported or the checkpoint
        is unavailable.
    """
    if image is None:
        return None

    # Lazy imports: echo the input back if the detection stack is missing.
    try:
        # NOTE(review): imported only for its side effects -- presumably it
        # registers the GeoViG backbone with mmdet; confirm.
        import geovig_det_backbone
        from mmdet.apis import init_detector, inference_detector
        import cv2
    except Exception as e:
        print(f"DEBUG: Medical setup failed: {e}")
        return image

    key = f"med_{dataset.lower().replace(' ', '_')}"
    if key not in MED_MODEL_CACHE:
        config_by_dataset = {
            "Kvasir-SEG": "configs/kvasir/mask_rcnn_geovig_m_fpn_1x_kvasir.py",
            "DSB 2018": "configs/dsb/mask_rcnn_geovig_m_fpn_1x_dsb.py",
        }
        cfg = config_by_dataset[dataset]
        if dataset == "Kvasir-SEG":
            ckpt_key = "kvasir_m"
        else:
            ckpt_key = "dsb_m"
        path = get_ckpt_path(ckpt_key)
        if not path:
            return image
        MED_MODEL_CACHE[key] = init_detector(cfg, path, device="cpu")
    detector = MED_MODEL_CACHE[key]

    # The detector is fed a BGR array, so convert from PIL's RGB order.
    cv_img = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
    result = inference_detector(detector, cv_img)

    # A two-element result is treated as (bbox_result, segm_result).
    if isinstance(result, (list, tuple)) and len(result) == 2:
        bbox_result, segm_result = result
    else:
        bbox_result, segm_result = result, None

    # Union of every mask whose matching box reaches the score threshold.
    merged_mask = np.zeros(cv_img.shape[:2], dtype=np.float32)
    if segm_result:
        for cls_id, cls_masks in enumerate(segm_result):
            for det_idx, mask in enumerate(cls_masks):
                keep = (
                    det_idx < len(bbox_result[cls_id])
                    and bbox_result[cls_id][det_idx, 4] >= score_thr
                )
                if not keep:
                    continue
                if isinstance(mask, dict):
                    # Dict masks are decoded with pycocotools before use.
                    import pycocotools.mask as mask_util
                    mask = mask_util.decode(mask)
                merged_mask = np.maximum(merged_mask, mask.astype(np.float32))

    # Paint the masked pixels in the overlay colour, then mix 40 % picture
    # with 60 % overlay.
    pixels = np.array(image.convert("RGB")).astype(float)
    tinted = pixels.copy()
    tinted[merged_mask > 0.5] = np.array([0, 200, 150], dtype=float)
    blended = (pixels * 0.4 + tinted * 0.6).astype(np.uint8)

    img_pil = Image.fromarray(blended)
    pen = ImageDraw.Draw(img_pil)
    per_class_boxes = bbox_result if bbox_result is not None else ()
    for cls_id, bboxes in enumerate(per_class_boxes):
        if len(bboxes) == 0:
            continue
        for bbox in bboxes:
            if bbox[4] < score_thr:
                continue
            x1, y1, x2, y2, score = bbox
            pen.rectangle([x1, y1, x2, y2], outline=(0, 255, 255), width=3)
            label = f"{score:.2f}"
            pen.text((x1 + 2, y1 + 2), label, fill=(255, 255, 0))
    return img_pil
# ── UI ────────────────────────────────────────────────────────────────────────
# Three-tab Gradio app: classification, natural-image detection/segmentation,
# and medical detection/segmentation.
# NOTE(review): the source indentation was lost, so the layout nesting below is
# reconstructed from statement order -- in particular, confirm whether the
# detection outputs row belongs inside or below the inputs row.
# The emoji in the header and tab titles were mis-encoded (UTF-8 bytes shown as
# single-byte characters) and are restored here.
with gr.Blocks(title="GeoViG Full Demo", theme=_THEME) as demo:
    gr.HTML("<h1 style='text-align:center'>🔷 GeoViG Multi-Task</h1>")
    with gr.Tabs():
        with gr.TabItem("🖼️ Classification"):
            with gr.Row():
                with gr.Column():
                    cls_image = gr.Image(type="pil", label="Input Image")
                    cls_variant = gr.Radio(
                        ["GeoViG-Ti", "GeoViG-S", "GeoViG-M", "GeoViG-B"],
                        value="GeoViG-Ti",
                    )
                    cls_btn = gr.Button("Predict", variant="primary")
                cls_output = gr.Label(num_top_classes=5)
            cls_btn.click(classify, [cls_image, cls_variant], cls_output)
            gr.Examples(
                examples=[
                    ["examples/n03394916_14162.JPEG", "GeoViG-Ti"],
                    ["examples/n03417042_2960.JPEG", "GeoViG-S"],
                ],
                inputs=[cls_image, cls_variant],
            )

        with gr.TabItem("📦 Natural Detection & Segmentation"):
            with gr.Row():
                with gr.Column():
                    det_image = gr.Image(type="pil", label="Input Image")
                    det_size = gr.Radio(["M", "B"], value="M", label="Backbone")
                    det_thr = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Score Threshold")
                    det_btn = gr.Button("Detect & Segment", variant="primary")
            with gr.Row():
                det_out = gr.Image(label="Bboxes", type="pil")
                seg_out = gr.Image(label="Masks", type="pil")
            det_btn.click(detect, [det_image, det_size, det_thr], [det_out, seg_out])
            gr.Examples(
                examples=[
                    ["examples/cat.png", "B", 0.8],
                    ["examples/000000032081.jpg", "M", 0.7],
                    ["examples/000000045472.jpg", "B", 0.7],
                ],
                inputs=[det_image, det_size, det_thr],
            )

        with gr.TabItem("🏥 Medical Detection & Segmentation"):
            with gr.Row():
                with gr.Column():
                    med_image = gr.Image(type="pil", label="Anatomy Image")
                    med_dataset = gr.Radio(["Kvasir-SEG", "DSB 2018"], value="Kvasir-SEG")
                    med_med_thr = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Threshold")
                    med_btn = gr.Button("Detect & Segment", variant="primary")
                med_out = gr.Image(label="Detection & Segmentation Result", type="pil")
            med_btn.click(segment_medical, [med_image, med_dataset, med_med_thr], med_out)
            gr.Examples(
                examples=[
                    ["examples/cju5y84q3mdv50817eyp82xf3.jpg", "Kvasir-SEG", 0.7],
                    ["examples/0c2550a23b8a0f29a7575de8c61690d3c31bc897dd5ba66caec201d201a278c2.png", "DSB 2018", 0.7],
                ],
                inputs=[med_image, med_dataset, med_med_thr],
            )

# Listen on all interfaces on port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)