Spaces:
Sleeping
Sleeping
| """ | |
| GeoViG Multi-Task Gradio Demo | |
| """ | |
| import sys, os, json, urllib.request | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from PIL import Image, ImageDraw | |
| from huggingface_hub import hf_hub_download | |
| import numpy as np | |
| import gradio as gr | |
# Startup diagnostics: report the interpreter and PyTorch versions in the logs.
print(
    "\n".join(
        (
            f"DEBUG: System Python: {sys.version}",
            f"DEBUG: PyTorch version: {torch.__version__}",
        )
    )
)
# ── Gradio Theme ──────────────────────────────────────────────────────────────
# Theme object handed to gr.Blocks when the UI is built.
_THEME = gr.themes.Soft(primary_hue="indigo")
| from geovig import geovig_ti, geovig_s, geovig_m, geovig_b | |
# ── Paths ─────────────────────────────────────────────────────────────────────
# Hugging Face Hub repository that hosts every checkpoint used by this demo.
REPO_ID = "OmarAlasqa/GeoViG"

# Logical checkpoint key -> file path inside REPO_ID.
CKPT_FILES = {
    # ImageNet classification backbones
    "GeoViG-Ti": "pth/geovig_ti_5e4_8G_300_75_22/checkpoint.pth",
    "GeoViG-S": "pth/geovig_s_5e4_8G_300_77_48/checkpoint.pth",
    "GeoViG-M": "pth/geovig_m_5e4_8G_300_80_70/checkpoint.pth",
    "GeoViG-B": "pth/geovig_b_5e4_8G_300_82_38/checkpoint.pth",
    # COCO detection / instance segmentation
    "det_m": "coco_det_seg_pth/geovig_m_det_seg/epoch_12.pth",
    "det_b": "coco_det_seg_pth/geovig_b_det_seg/epoch_12.pth",
    # Medical datasets
    "kvasir_m": "medical/kvasir_geovig_m/checkpoint.pth",
    "dsb_m": "medical/dsb_geovig_m/checkpoint.pth",
}

os.environ.update(HF_HUB_DISABLE_INTERACTIVE_FLOW="1")
def get_ckpt_path(key):
    """Fetch the checkpoint registered under *key* in ``CKPT_FILES``.

    Returns the local path produced by ``hf_hub_download``. On any failure
    (including an unknown *key*) the error is printed and ``None`` is
    returned; callers treat ``None`` as "weights unavailable".
    """
    try:
        local_path = hf_hub_download(repo_id=REPO_ID, filename=CKPT_FILES[key])
    except Exception as err:  # best effort: the UI degrades instead of crashing
        print(f"Error downloading {key}: {err}")
        local_path = None
    return local_path
| # ββ ImageNet labels ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| LABELS_FILE = "imagenet_labels.json" | |
| if not os.path.exists(LABELS_FILE): | |
| try: | |
| urllib.request.urlretrieve( | |
| "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json", | |
| LABELS_FILE) | |
| except: pass | |
| IMAGENET_LABELS = [] | |
| if os.path.exists(LABELS_FILE): | |
| with open(LABELS_FILE) as f: | |
| IMAGENET_LABELS = json.load(f) | |
# ── Pre-processing ────────────────────────────────────────────────────────────
# Per-channel RGB statistics used for normalisation.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Resize (shorter side to 256, bicubic) -> centre-crop 224 -> tensor -> normalise.
IMAGENET_TF = transforms.Compose(
    [
        transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# ── Model cache ───────────────────────────────────────────────────────────────
# Classification models that have already been built and loaded, by variant name.
MODEL_CACHE = {}
# Variant name shown in the UI -> constructor for that GeoViG backbone.
BUILDERS = {
    "GeoViG-Ti": geovig_ti,
    "GeoViG-S": geovig_s,
    "GeoViG-M": geovig_m,
    "GeoViG-B": geovig_b,
}
def _load_cls_model(variant: str):
    """Return the classifier for *variant*, building and caching it on first use.

    Args:
        variant: a key of both ``BUILDERS`` and ``CKPT_FILES`` (e.g. ``"GeoViG-Ti"``).

    Returns:
        The model in eval mode, or ``None`` when the checkpoint could not be
        fetched. Failures are not cached, so a later call retries the download.

    Raises:
        KeyError: if *variant* has no builder.
    """
    cached = MODEL_CACHE.get(variant)
    if cached is not None:
        return cached
    builder = BUILDERS[variant]  # fail fast on an unknown variant
    # Resolve the checkpoint before constructing the network, so a failed
    # download doesn't pay for building a model that is thrown away.
    path = get_ckpt_path(variant)
    if not path:
        return None
    model = builder()
    # NOTE(review): torch.load unpickles the file; the checkpoint comes from
    # REPO_ID and is treated as trusted. On PyTorch releases where
    # weights_only defaults to True, checkpoints that also pickle non-tensor
    # objects may refuse to load — confirm against the published checkpoints.
    ckpt = torch.load(path, map_location="cpu")
    state_dict = ckpt.get("model", ckpt.get("state_dict", ckpt))
    # strict=False tolerates minor key differences, but a wholesale mismatch
    # (e.g. a "module." prefix) would otherwise leave the network silently
    # at its initial weights — surface it in the logs.
    load_info = model.load_state_dict(state_dict, strict=False)
    missing = getattr(load_info, "missing_keys", [])
    unexpected = getattr(load_info, "unexpected_keys", [])
    if missing or unexpected:
        print(
            f"WARNING: {variant}: {len(missing)} missing / {len(unexpected)} "
            f"unexpected keys when loading {path}"
        )
    model.eval()
    MODEL_CACHE[variant] = model
    return model
| def classify(image: Image.Image, variant: str): | |
| if image is None: return {} | |
| model = _load_cls_model(variant) | |
| if not model: return {"Error: Weights not found": 1.0} | |
| x = IMAGENET_TF(image.convert("RGB")).unsqueeze(0) | |
| with torch.no_grad(): | |
| probs = F.softmax(model(x), dim=-1)[0] | |
| top_probs, top_idxs = probs.topk(5) | |
| return {IMAGENET_LABELS[i.item()] if i < len(IMAGENET_LABELS) else f"Class {i}": float(p) for p, i in zip(top_probs, top_idxs)} | |
# ── Detection Logic ───────────────────────────────────────────────────────────
# The 80 COCO category names, ordered by class index (the detectors' per-class
# results are indexed the same way when a name is looked up).
COCO_NAMES = [
    "person",
    # vehicles
    "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
    # outdoor
    "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
    # animals
    "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
    # accessories
    "backpack", "umbrella", "handbag", "tie", "suitcase",
    # sports
    "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat",
    "baseball glove", "skateboard", "surfboard", "tennis racket",
    # kitchen
    "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
    # food
    "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog",
    "pizza", "donut", "cake",
    # furniture
    "chair", "couch", "potted plant", "bed", "dining table", "toilet",
    # electronics
    "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
    # appliances
    "microwave", "oven", "toaster", "sink", "refrigerator",
    # indoor
    "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
]
| def _draw_detections(image: Image.Image, result, score_thr=0.3): | |
| img_draw = image.convert("RGBA") | |
| draw = ImageDraw.Draw(img_draw) | |
| seg_np_img = np.array(image.convert("RGB")) | |
| # Robust unpacking for mmdet 2.x | |
| if isinstance(result, (list, tuple)) and len(result) == 2 and isinstance(result[0], (list, np.ndarray)): | |
| bbox_result, segm_result = result | |
| else: | |
| bbox_result, segm_result = result, None | |
| count = 0 | |
| if bbox_result is not None: | |
| for cls_id, bboxes in enumerate(bbox_result): | |
| if len(bboxes) == 0: continue | |
| color = tuple(np.random.randint(100, 255, 3).tolist()) | |
| for i, bbox in enumerate(bboxes): | |
| if bbox[4] < score_thr: continue | |
| count += 1 | |
| x1, y1, x2, y2, score = bbox | |
| draw.rectangle([x1, y1, x2, y2], outline=color, width=4) | |
| label = f"{COCO_NAMES[cls_id]} {score:.2f}" if cls_id < len(COCO_NAMES) else f"cls {cls_id} {score:.2f}" | |
| draw.text((x1 + 2, y1 + 2), label, fill=(255, 255, 0)) | |
| if segm_result is not None and len(segm_result) > cls_id and segm_result[cls_id] is not None: | |
| try: | |
| mask = segm_result[cls_id][i] | |
| if isinstance(mask, dict): | |
| import pycocotools.mask as mask_util | |
| mask = mask_util.decode(mask) | |
| mask_bool = mask.astype(bool) | |
| seg_sample = np.array(color, dtype=np.uint8) | |
| seg_np_img[mask_bool] = (seg_np_img[mask_bool] * 0.5 + seg_sample * 0.5).astype(np.uint8) | |
| except Exception as e: | |
| print(f"DEBUG: Mask blend failed: {e}") | |
| print(f"DEBUG: Total detections drawn: {count}") | |
| return img_draw.convert("RGB"), Image.fromarray(seg_np_img) | |
DET_MODEL_CACHE = {}  # "det_m" / "det_b" -> initialised mmdet detector


def detect(image: Image.Image, model_size: str, score_thr: float):
    """Run the COCO Mask R-CNN model with a GeoViG-M or -B backbone on *image*.

    Args:
        image: PIL image from the UI, or ``None``.
        model_size: ``"M"`` or ``"B"``.
        score_thr: minimum score for a detection to be drawn.

    Returns:
        ``(boxes_image, masks_image)`` from ``_draw_detections``. ``(None, None)``
        for a missing image; the input image twice when mmdet/mmcv cannot be
        imported or the checkpoint is unavailable.
    """
    if image is None:
        return None, None
    try:
        import mmcv
        print(f"DEBUG: MMCV version: {mmcv.__version__}")
        # Imported only for its side effects (presumably registering the
        # GeoViG backbone with mmdet) — TODO confirm.
        import geovig_det_backbone  # noqa: F401
        from mmdet.apis import init_detector, inference_detector
        import cv2
    except Exception as e:
        print(f"DEBUG: Detection setup failed: {e}")
        return image, image

    cache_key = "det_" + model_size.lower()
    detector = DET_MODEL_CACHE.get(cache_key)
    if detector is None:
        config_path = {
            "M": "configs/mask_rcnn_geovig_m_fpn_1x_coco.py",
            "B": "configs/mask_rcnn_geovig_b_fpn_1x_coco.py",
        }[model_size]
        ckpt_path = get_ckpt_path(cache_key)
        if not ckpt_path:
            return image, image
        detector = init_detector(config_path, ckpt_path, device="cpu")
        DET_MODEL_CACHE[cache_key] = detector

    # mmdet consumes BGR arrays (OpenCV channel order).
    bgr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
    raw = inference_detector(detector, bgr)
    return _draw_detections(image, raw, score_thr=score_thr)
| # ββ Medical logic βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| MED_MODEL_CACHE = {} | |
| def segment_medical(image: Image.Image, dataset: str, score_thr: float): | |
| if image is None: return None | |
| try: | |
| import geovig_det_backbone | |
| from mmdet.apis import init_detector, inference_detector | |
| import cv2 | |
| except Exception as e: | |
| print(f"DEBUG: Medical setup failed: {e}") | |
| return image | |
| key = f"med_{dataset.lower().replace(' ', '_')}" | |
| if key not in MED_MODEL_CACHE: | |
| cfg = {"Kvasir-SEG": "configs/kvasir/mask_rcnn_geovig_m_fpn_1x_kvasir.py", "DSB 2018": "configs/dsb/mask_rcnn_geovig_m_fpn_1x_dsb.py"}[dataset] | |
| v_key = "kvasir_m" if dataset == "Kvasir-SEG" else "dsb_m" | |
| path = get_ckpt_path(v_key) | |
| if not path: return image | |
| MED_MODEL_CACHE[key] = init_detector(cfg, path, device="cpu") | |
| cv_img = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR) | |
| result = inference_detector(MED_MODEL_CACHE[key], cv_img) | |
| if isinstance(result, (list, tuple)) and len(result) == 2: | |
| bbox_result, segm_result = result | |
| else: | |
| bbox_result, segm_result = result, None | |
| combined_mask = np.zeros(cv_img.shape[:2], dtype=np.float32) | |
| if segm_result: | |
| for cls_id, cls_masks in enumerate(segm_result): | |
| for i, mask in enumerate(cls_masks): | |
| if i < len(bbox_result[cls_id]) and bbox_result[cls_id][i, 4] >= score_thr: | |
| if isinstance(mask, dict): | |
| import pycocotools.mask as mask_util | |
| mask = mask_util.decode(mask) | |
| combined_mask = np.maximum(combined_mask, mask.astype(np.float32)) | |
| img_arr = np.array(image.convert("RGB")).astype(float) | |
| overlay = img_arr.copy() | |
| overlay[combined_mask > 0.5] = np.array([0, 200, 150], dtype=float) | |
| blended = (img_arr * 0.4 + overlay * 0.6).astype(np.uint8) | |
| img_pil = Image.fromarray(blended) | |
| draw = ImageDraw.Draw(img_pil) | |
| if bbox_result is not None: | |
| for cls_id, bboxes in enumerate(bbox_result): | |
| if len(bboxes) == 0: continue | |
| for bbox in bboxes: | |
| if bbox[4] < score_thr: continue | |
| x1, y1, x2, y2, score = bbox | |
| draw.rectangle([x1, y1, x2, y2], outline=(0, 255, 255), width=3) | |
| label = f"{score:.2f}" | |
| draw.text((x1 + 2, y1 + 2), label, fill=(255, 255, 0)) | |
| return img_pil | |
# ── UI ────────────────────────────────────────────────────────────────────────
# NOTE(review): the emoji in the heading and tab titles were reconstructed from
# mis-encoded text; confirm they match the intended icons.
with gr.Blocks(title="GeoViG Full Demo", theme=_THEME) as demo:
    gr.HTML("<h1 style='text-align:center'>📷 GeoViG Multi-Task</h1>")
    with gr.Tabs():
        # Tab 1: image classification with the four backbone sizes.
        with gr.TabItem("🖼️ Classification"):
            with gr.Row():
                with gr.Column():
                    cls_image = gr.Image(type="pil", label="Input Image")
                    cls_variant = gr.Radio(
                        ["GeoViG-Ti", "GeoViG-S", "GeoViG-M", "GeoViG-B"],
                        value="GeoViG-Ti",
                        label="Model Variant",
                    )
                    cls_btn = gr.Button("Predict", variant="primary")
                cls_output = gr.Label(num_top_classes=5)
            cls_btn.click(classify, [cls_image, cls_variant], cls_output)
            gr.Examples(
                examples=[
                    ["examples/n03394916_14162.JPEG", "GeoViG-Ti"],
                    ["examples/n03417042_2960.JPEG", "GeoViG-S"],
                ],
                inputs=[cls_image, cls_variant],
            )

        # Tab 2: COCO Mask R-CNN detection + instance segmentation.
        with gr.TabItem("📦 Natural Detection & Segmentation"):
            with gr.Row():
                with gr.Column():
                    det_image = gr.Image(type="pil", label="Input Image")
                    det_size = gr.Radio(["M", "B"], value="M", label="Backbone")
                    det_thr = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Score Threshold")
                    det_btn = gr.Button("Detect & Segment", variant="primary")
                with gr.Row():
                    det_out = gr.Image(label="Bboxes", type="pil")
                    seg_out = gr.Image(label="Masks", type="pil")
            det_btn.click(detect, [det_image, det_size, det_thr], [det_out, seg_out])
            gr.Examples(
                examples=[
                    ["examples/cat.png", "B", 0.8],
                    ["examples/000000032081.jpg", "M", 0.7],
                    ["examples/000000045472.jpg", "B", 0.7],
                ],
                inputs=[det_image, det_size, det_thr],
            )

        # Tab 3: medical-image Mask R-CNN models (Kvasir-SEG, DSB 2018).
        with gr.TabItem("🏥 Medical Detection & Segmentation"):
            with gr.Row():
                with gr.Column():
                    med_image = gr.Image(type="pil", label="Anatomy Image")
                    med_dataset = gr.Radio(["Kvasir-SEG", "DSB 2018"], value="Kvasir-SEG", label="Dataset")
                    med_thr = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Threshold")
                    med_btn = gr.Button("Detect & Segment", variant="primary")
                med_out = gr.Image(label="Detection & Segmentation Result", type="pil")
            med_btn.click(segment_medical, [med_image, med_dataset, med_thr], med_out)
            gr.Examples(
                examples=[
                    ["examples/cju5y84q3mdv50817eyp82xf3.jpg", "Kvasir-SEG", 0.7],
                    ["examples/0c2550a23b8a0f29a7575de8c61690d3c31bc897dd5ba66caec201d201a278c2.png", "DSB 2018", 0.7],
                ],
                inputs=[med_image, med_dataset, med_thr],
            )

# Launch only when run as a script, so importing this module (e.g. by tooling
# that looks up `demo`) does not start a server.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)