""" GeoViG Multi-Task Gradio Demo """ import sys, os, json, urllib.request import torch import torch.nn.functional as F from torchvision import transforms from PIL import Image, ImageDraw from huggingface_hub import hf_hub_download import numpy as np import gradio as gr print(f"DEBUG: System Python: {sys.version}") print(f"DEBUG: PyTorch version: {torch.__version__}") # ── Gradio Theme ────────────────────────────────────────────────────────────── _THEME = gr.themes.Soft(primary_hue="indigo") from geovig import geovig_ti, geovig_s, geovig_m, geovig_b # ── Paths ────────────────────────────────────────────────────────────────────── REPO_ID = "OmarAlasqa/GeoViG" CKPT_FILES = { "GeoViG-Ti": "pth/geovig_ti_5e4_8G_300_75_22/checkpoint.pth", "GeoViG-S": "pth/geovig_s_5e4_8G_300_77_48/checkpoint.pth", "GeoViG-M": "pth/geovig_m_5e4_8G_300_80_70/checkpoint.pth", "GeoViG-B": "pth/geovig_b_5e4_8G_300_82_38/checkpoint.pth", "det_m": "coco_det_seg_pth/geovig_m_det_seg/epoch_12.pth", "det_b": "coco_det_seg_pth/geovig_b_det_seg/epoch_12.pth", "kvasir_m": "medical/kvasir_geovig_m/checkpoint.pth", "dsb_m": "medical/dsb_geovig_m/checkpoint.pth" } os.environ["HF_HUB_DISABLE_INTERACTIVE_FLOW"] = "1" def get_ckpt_path(key): try: return hf_hub_download(repo_id=REPO_ID, filename=CKPT_FILES[key]) except Exception as e: print(f"Error downloading {key}: {e}") return None # ── ImageNet labels ──────────────────────────────────────────────────────────── LABELS_FILE = "imagenet_labels.json" if not os.path.exists(LABELS_FILE): try: urllib.request.urlretrieve( "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json", LABELS_FILE) except: pass IMAGENET_LABELS = [] if os.path.exists(LABELS_FILE): with open(LABELS_FILE) as f: IMAGENET_LABELS = json.load(f) # ── Pre-processing ───────────────────────────────────────────────────────────── IMAGENET_TF = transforms.Compose([ transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) # ── Model cache ──────────────────────────────────────────────────────────────── MODEL_CACHE = {} BUILDERS = {"GeoViG-Ti": geovig_ti, "GeoViG-S": geovig_s, "GeoViG-M": geovig_m, "GeoViG-B": geovig_b} def _load_cls_model(variant: str): if variant in MODEL_CACHE: return MODEL_CACHE[variant] model = BUILDERS[variant]() path = get_ckpt_path(variant) if not path: return None sd = torch.load(path, map_location="cpu") sd = sd.get("model", sd.get("state_dict", sd)) model.load_state_dict(sd, strict=False) model.eval() MODEL_CACHE[variant] = model return model def classify(image: Image.Image, variant: str): if image is None: return {} model = _load_cls_model(variant) if not model: return {"Error: Weights not found": 1.0} x = IMAGENET_TF(image.convert("RGB")).unsqueeze(0) with torch.no_grad(): probs = F.softmax(model(x), dim=-1)[0] top_probs, top_idxs = probs.topk(5) return {IMAGENET_LABELS[i.item()] if i < len(IMAGENET_LABELS) else f"Class {i}": float(p) for p, i in zip(top_probs, top_idxs)} # ── Detection Logic ─────────────────────────────────────────────────────────── COCO_NAMES = [ "person","bicycle","car","motorcycle","airplane","bus","train","truck","boat", "traffic light","fire hydrant","stop sign","parking meter","bench","bird","cat", "dog","horse","sheep","cow","elephant","bear","zebra","giraffe","backpack", "umbrella","handbag","tie","suitcase","frisbee","skis","snowboard","sports ball", "kite","baseball bat","baseball glove","skateboard","surfboard","tennis racket", "bottle","wine glass","cup","fork","knife","spoon","bowl","banana","apple", "sandwich","orange","broccoli","carrot","hot dog","pizza","donut","cake","chair", "couch","potted plant","bed","dining table","toilet","tv","laptop","mouse", "remote","keyboard","cell phone","microwave","oven","toaster","sink","refrigerator", "book","clock","vase","scissors","teddy bear","hair drier","toothbrush" ] def _draw_detections(image: Image.Image, result, score_thr=0.3): img_draw = image.convert("RGBA") draw = ImageDraw.Draw(img_draw) seg_np_img = np.array(image.convert("RGB")) # Robust unpacking for mmdet 2.x if isinstance(result, (list, tuple)) and len(result) == 2 and isinstance(result[0], (list, np.ndarray)): bbox_result, segm_result = result else: bbox_result, segm_result = result, None count = 0 if bbox_result is not None: for cls_id, bboxes in enumerate(bbox_result): if len(bboxes) == 0: continue color = tuple(np.random.randint(100, 255, 3).tolist()) for i, bbox in enumerate(bboxes): if bbox[4] < score_thr: continue count += 1 x1, y1, x2, y2, score = bbox draw.rectangle([x1, y1, x2, y2], outline=color, width=4) label = f"{COCO_NAMES[cls_id]} {score:.2f}" if cls_id < len(COCO_NAMES) else f"cls {cls_id} {score:.2f}" draw.text((x1 + 2, y1 + 2), label, fill=(255, 255, 0)) if segm_result is not None and len(segm_result) > cls_id and segm_result[cls_id] is not None: try: mask = segm_result[cls_id][i] if isinstance(mask, dict): import pycocotools.mask as mask_util mask = mask_util.decode(mask) mask_bool = mask.astype(bool) seg_sample = np.array(color, dtype=np.uint8) seg_np_img[mask_bool] = (seg_np_img[mask_bool] * 0.5 + seg_sample * 0.5).astype(np.uint8) except Exception as e: print(f"DEBUG: Mask blend failed: {e}") print(f"DEBUG: Total detections drawn: {count}") return img_draw.convert("RGB"), Image.fromarray(seg_np_img) DET_MODEL_CACHE = {} def detect(image: Image.Image, model_size: str, score_thr: float): if image is None: return None, None try: import mmcv print(f"DEBUG: MMCV version: {mmcv.__version__}") import geovig_det_backbone from mmdet.apis import init_detector, inference_detector import cv2 except Exception as e: print(f"DEBUG: Detection setup failed: {e}") return image, image key = f"det_{model_size.lower()}" if key not in DET_MODEL_CACHE: cfg = {"M": "configs/mask_rcnn_geovig_m_fpn_1x_coco.py", "B": "configs/mask_rcnn_geovig_b_fpn_1x_coco.py"}[model_size] path = get_ckpt_path(key) if not path: return image, image DET_MODEL_CACHE[key] = init_detector(cfg, path, device="cpu") cv_img = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR) result = inference_detector(DET_MODEL_CACHE[key], cv_img) return _draw_detections(image, result, score_thr=score_thr) # ── Medical logic ───────────────────────────────────────────────────────────── MED_MODEL_CACHE = {} def segment_medical(image: Image.Image, dataset: str, score_thr: float): if image is None: return None try: import geovig_det_backbone from mmdet.apis import init_detector, inference_detector import cv2 except Exception as e: print(f"DEBUG: Medical setup failed: {e}") return image key = f"med_{dataset.lower().replace(' ', '_')}" if key not in MED_MODEL_CACHE: cfg = {"Kvasir-SEG": "configs/kvasir/mask_rcnn_geovig_m_fpn_1x_kvasir.py", "DSB 2018": "configs/dsb/mask_rcnn_geovig_m_fpn_1x_dsb.py"}[dataset] v_key = "kvasir_m" if dataset == "Kvasir-SEG" else "dsb_m" path = get_ckpt_path(v_key) if not path: return image MED_MODEL_CACHE[key] = init_detector(cfg, path, device="cpu") cv_img = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR) result = inference_detector(MED_MODEL_CACHE[key], cv_img) if isinstance(result, (list, tuple)) and len(result) == 2: bbox_result, segm_result = result else: bbox_result, segm_result = result, None combined_mask = np.zeros(cv_img.shape[:2], dtype=np.float32) if segm_result: for cls_id, cls_masks in enumerate(segm_result): for i, mask in enumerate(cls_masks): if i < len(bbox_result[cls_id]) and bbox_result[cls_id][i, 4] >= score_thr: if isinstance(mask, dict): import pycocotools.mask as mask_util mask = mask_util.decode(mask) combined_mask = np.maximum(combined_mask, mask.astype(np.float32)) img_arr = np.array(image.convert("RGB")).astype(float) overlay = img_arr.copy() overlay[combined_mask > 0.5] = np.array([0, 200, 150], dtype=float) blended = (img_arr * 0.4 + overlay * 0.6).astype(np.uint8) img_pil = Image.fromarray(blended) draw = ImageDraw.Draw(img_pil) if bbox_result is not None: for cls_id, bboxes in enumerate(bbox_result): if len(bboxes) == 0: continue for bbox in bboxes: if bbox[4] < score_thr: continue x1, y1, x2, y2, score = bbox draw.rectangle([x1, y1, x2, y2], outline=(0, 255, 255), width=3) label = f"{score:.2f}" draw.text((x1 + 2, y1 + 2), label, fill=(255, 255, 0)) return img_pil # ── UI ──────────────────────────────────────────────────────────────────────── with gr.Blocks(title="GeoViG Full Demo", theme=_THEME) as demo: gr.HTML("