| |
| |
| |
|
|
| import os |
| import re |
| import time |
| from typing import Any, Dict, Tuple, Optional |
|
|
| import streamlit as st |
| import numpy as np |
| from PIL import Image, ImageDraw |
|
|
| import torch |
|
|
| try: |
| from torchvision.ops import nms as tv_nms |
| except Exception: |
| tv_nms = None |
|
|
| |
| from models.gdcount_model import GDCountConfig, build_gdcount_model |
|
|
|
|
| |
| |
| |
|
|
def sanitize_caption(p: str) -> str:
    """Normalize a user prompt into a caption that ends with a period.

    Steps: coerce to ``str`` (``None`` becomes the empty string), trim,
    collapse every whitespace run to a single space, and remove whitespace
    that directly precedes a period.

    Args:
        p: Raw prompt text. ``None`` and non-string values are accepted.

    Returns:
        The cleaned caption, always terminated by ``"."``. An empty prompt,
        or a prompt that is only a period, yields ``"object."``.
    """
    text = str(p).strip() if p is not None else ""
    text = re.sub(r"\s+", " ", text)
    text = re.sub(r"\s+\.", ".", text)

    # Nothing meaningful left: fall back to the generic class name.
    if text in ("", "."):
        return "object."

    return text if text.endswith(".") else text + "."
|
|
|
|
def _scores_from_pred_logits(outputs: Dict[str, Any]) -> torch.Tensor:
    """Turn ``outputs["pred_logits"]`` into one sigmoid score per query.

    * 2-D logits ``(B, Q)``: returns ``sigmoid(logits)`` element-wise.
    * 3-D logits ``(B, Q, T)``: takes the maximum logit over the tokens that
      are still valid after masking, then applies ``sigmoid``.

    Token validity comes from the optional ``outputs["text_mask"]`` (all
    tokens valid when absent) and the optional ``outputs["input_ids"]``.
    Both are right-padded or truncated to ``T`` columns; padding counts as
    "invalid" for the mask and as id ``0`` for the ids.

    Non-finite logits and masked-out tokens are replaced by ``-1e4`` so they
    contribute a score of ~0.

    Args:
        outputs: Model output mapping; must contain ``"pred_logits"``.

    Returns:
        Float tensor of shape ``(B, Q)`` with values in ``[0, 1]``.
    """
    floor = -1e4

    def _finite_or_floor(t: torch.Tensor) -> torch.Tensor:
        # Replace NaN / +-inf with a very negative logit.
        return torch.where(torch.isfinite(t), t, torch.full_like(t, floor))

    raw = outputs["pred_logits"].float()
    if raw.dim() == 2:
        return torch.sigmoid(_finite_or_floor(raw))

    batch, _, n_tok = raw.shape
    dev = raw.device

    def _fit_tokens(t: torch.Tensor) -> torch.Tensor:
        # Right-pad with zeros (False for bool masks) up to n_tok columns,
        # then cut anything beyond n_tok.
        missing = n_tok - t.shape[-1]
        if missing > 0:
            filler = torch.zeros((batch, missing), device=dev, dtype=t.dtype)
            t = torch.cat([t, filler], dim=-1)
        return t[:, :n_tok]

    mask = outputs.get("text_mask", None)
    if mask is None:
        valid = torch.ones((batch, n_tok), device=dev, dtype=torch.bool)
    else:
        valid = _fit_tokens(mask.to(device=dev, dtype=torch.bool))

    ids = outputs.get("input_ids", None)
    if ids is not None:
        ids = _fit_tokens(ids.to(device=dev))
        # NOTE(review): 0 / 101 / 102 are presumably the BERT [PAD] / [CLS] /
        # [SEP] ids -- confirm against the tokenizer actually in use.
        is_special = (ids == 0) | (ids == 101) | (ids == 102)
        valid = valid & (~is_special)

    cleaned = _finite_or_floor(raw).masked_fill(~valid[:, None, :], floor)
    return torch.sigmoid(cleaned.max(dim=-1).values)
|
|
|
|
def _cxcywh_to_xyxy(boxes: torch.Tensor) -> torch.Tensor:
    """Convert boxes from (center-x, center-y, width, height) to corners.

    Args:
        boxes: Tensor whose last dimension has exactly four entries
            ``(cx, cy, w, h)``; any leading dimensions are preserved.

    Returns:
        Tensor of the same shape holding ``(x1, y1, x2, y2)``.
    """
    cx, cy, w, h = boxes.unbind(dim=-1)
    half_w, half_h = w / 2, h / 2
    corners = (cx - half_w, cy - half_h, cx + half_w, cy + half_h)
    return torch.stack(corners, dim=-1)
|
|
|
|
def _pick_boxes_after_thresh_nms(
    outputs: Dict[str, Any],
    threshold: float,
    nms_iou: float
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Select the detections of the first batch item by score, then NMS.

    Queries whose score (see ``_scores_from_pred_logits``) is strictly above
    ``threshold`` are kept. Their boxes are converted from cxcywh to xyxy
    and clamped to ``[0, 1]``. NMS is applied only when
    ``torchvision.ops.nms`` was importable and more than one candidate
    remains.

    Args:
        outputs: Model output mapping with ``"pred_logits"`` and, optionally,
            ``"pred_boxes"``. Only index 0 of the batch is read.
        threshold: Minimum score (exclusive) for a query to be kept.
        nms_iou: IoU threshold passed to NMS.

    Returns:
        ``(boxes_xyxy, scores)`` with shapes ``(K, 4)`` and ``(K,)``.
        When ``"pred_boxes"`` is missing the boxes tensor is empty ``(0, 4)``
        while the scores of all thresholded queries are still returned, so
        the two lengths differ in that case.
    """
    query_scores = _scores_from_pred_logits(outputs)[0]
    dev = query_scores.device
    selected = torch.nonzero(query_scores > threshold, as_tuple=False).flatten()

    # No geometry available: hand back the scores only.
    if "pred_boxes" not in outputs:
        return torch.zeros((0, 4), device=dev), query_scores[selected]

    all_xyxy = _cxcywh_to_xyxy(outputs["pred_boxes"].float()[0]).clamp(0, 1)

    if selected.numel() == 0:
        return torch.zeros((0, 4), device=dev), torch.zeros((0,), device=dev)

    cand_boxes = all_xyxy[selected]
    cand_scores = query_scores[selected]

    # NMS needs torchvision and is pointless for a single candidate.
    if tv_nms is not None and selected.numel() != 1:
        survivors = tv_nms(cand_boxes, cand_scores, nms_iou)
        return cand_boxes[survivors], cand_scores[survivors]

    return cand_boxes, cand_scores
|
|
|
|
def load_model_checkpoint(model_ckpt_path: str, model: torch.nn.Module, device: str) -> Dict[str, Any]:
    """Load trained weights from ``model_ckpt_path`` into ``model`` (strict).

    Accepted checkpoint layouts:

    * a dict wrapping the weights under ``"model"``, ``"state_dict"`` or
      ``"model_state_dict"`` (checked in that order);
    * a bare state dict (a dict whose keys are all strings).

    Args:
        model_ckpt_path: Path of the checkpoint file.
        model: Module that receives the weights, modified in place.
        device: ``map_location`` used when deserializing the checkpoint.

    Returns:
        ``{"epoch": <value or None>}`` taken from the checkpoint's top level.

    Raises:
        ValueError: If the file does not contain one of the layouts above.
        RuntimeError: Propagated from ``load_state_dict`` when keys or shapes
            do not match (``strict=True``).
    """
    # NOTE(security): torch.load unpickles the file. Depending on the
    # installed PyTorch version/defaults this can execute arbitrary code, so
    # only point this at checkpoints from a trusted source.
    ckpt = torch.load(model_ckpt_path, map_location=device)

    if not isinstance(ckpt, dict):
        raise ValueError(f"Unrecognized checkpoint format: {model_ckpt_path}")

    # Prefer an explicit wrapper key; "model" keeps its original priority.
    state = None
    for wrapper_key in ("model", "state_dict", "model_state_dict"):
        candidate = ckpt.get(wrapper_key)
        if isinstance(candidate, dict):
            state = candidate
            break

    if state is None:
        # Fall back to treating the whole file as a bare state dict.
        if not all(isinstance(k, str) for k in ckpt):
            raise ValueError(f"Unrecognized checkpoint format: {model_ckpt_path}")
        state = ckpt

    model.load_state_dict(state, strict=True)
    return {"epoch": ckpt.get("epoch", None)}
|
|
|
|
def preprocess_image_for_model(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image into the normalized tensor fed to the model.

    Pipeline:
      1. convert to RGB and resize to 384x384 with bilinear resampling;
      2. scale pixel values to ``[0, 1]`` as float32 and reorder HWC -> CHW;
      3. normalize per channel with mean ``(0.485, 0.456, 0.406)`` and std
         ``(0.229, 0.224, 0.225)``.

    Args:
        img: Input image in any PIL mode.

    Returns:
        Float tensor of shape ``(3, 384, 384)`` (no batch dimension).
    """
    rgb = img.convert("RGB").resize((384, 384), Image.BILINEAR)

    pixels = np.asarray(rgb).astype(np.float32) / 255.0
    chw = torch.from_numpy(pixels.transpose(2, 0, 1))

    channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    channel_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    return (chw - channel_mean) / channel_std
|
|
|
|
def draw_boxes_on_pil_norm(
    image: Image.Image,
    boxes_xyxy_norm: np.ndarray,
    scores: Optional[np.ndarray] = None,
    score_threshold_to_show: float = 0.0
) -> Image.Image:
    """Draw normalized xyxy boxes on a copy of ``image``.

    Each coordinate is multiplied by the image width/height and clamped to
    the valid pixel range before drawing.

    Args:
        image: Image to annotate; it is copied, never modified.
        boxes_xyxy_norm: Iterable of boxes ``(x1, y1, x2, y2)`` expressed as
            fractions of the image size.
        scores: Optional per-box scores, indexed in the same order as the
            boxes. When given, a two-decimal label is drawn inside each box.
        score_threshold_to_show: Labels are drawn only for scores at or above
            this value; the rectangle is drawn regardless.

    Returns:
        A new RGB image with the rectangles (and labels) drawn.
    """
    canvas = image.copy().convert("RGB")
    width, height = canvas.size
    pen = ImageDraw.Draw(canvas)

    def _to_pixel(value, upper):
        # Clamp into [0, upper] and truncate to an integer pixel index.
        return int(max(0, min(upper, value)))

    for idx, box in enumerate(boxes_xyxy_norm):
        left = _to_pixel(box[0] * width, width - 1)
        top = _to_pixel(box[1] * height, height - 1)
        right = _to_pixel(box[2] * width, width - 1)
        bottom = _to_pixel(box[3] * height, height - 1)

        pen.rectangle([left, top, right, bottom], width=2)

        if scores is None:
            continue
        value = float(scores[idx])
        if value >= score_threshold_to_show:
            pen.text((left + 3, top + 3), f"{value:.2f}")

    return canvas
|
|
|
|
def draw_boxes_on_pil_px(
    image: Image.Image,
    boxes_xyxy_px: np.ndarray,
    scores: Optional[np.ndarray] = None,
    score_threshold_to_show: float = 0.0
) -> Image.Image:
    """Draw pixel-space xyxy boxes on a copy of ``image``.

    Coordinates are clamped to the image bounds and truncated to integers
    before drawing.

    Args:
        image: Image to annotate; it is copied, never modified.
        boxes_xyxy_px: Iterable of boxes ``(x1, y1, x2, y2)`` in pixels of
            ``image``.
        scores: Optional per-box scores, indexed in the same order as the
            boxes. When given, a two-decimal label is drawn inside each box.
        score_threshold_to_show: Labels are drawn only for scores at or above
            this value; the rectangle is drawn regardless.

    Returns:
        A new RGB image with the rectangles (and labels) drawn.
    """
    annotated = image.copy().convert("RGB")
    img_w, img_h = annotated.size
    painter = ImageDraw.Draw(annotated)

    # Upper bound for each of x1, y1, x2, y2 in that order.
    upper_bounds = (img_w - 1, img_h - 1, img_w - 1, img_h - 1)
    label_wanted = scores is not None

    for k, box in enumerate(boxes_xyxy_px):
        left, top, right, bottom = (
            int(max(0, min(bound, box[j]))) for j, bound in enumerate(upper_bounds)
        )
        painter.rectangle([left, top, right, bottom], width=2)

        if label_wanted:
            score_value = float(scores[k])
            if score_value >= score_threshold_to_show:
                painter.text((left + 3, top + 3), f"{score_value:.2f}")

    return annotated
|
|
|
|
def _tile_origins(extent: int, patch: int, stride: int) -> list:
    """Return the start offsets of tiles covering ``[0, extent)`` on one axis.

    Offsets advance by ``stride``; a final tile flush with the far edge
    (``extent - patch``, or 0 when the image is smaller than a patch) is
    appended when the regular grid does not already end there.
    """
    origins = list(range(0, max(1, extent - patch + 1), stride)) or [0]
    last = max(0, extent - patch)
    if origins[-1] != last:
        origins.append(last)
    return origins


def infer_tiled_boxes(
    img_pil: Image.Image,
    model,
    cap: str,
    device: str,
    threshold: float,
    nms_iou: float,
    patch: int = 384,
    stride: int = 256,
    border_ignore: int = 24,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the model on overlapping tiles and merge detections.

    For every tile the function:
      * crops a ``patch`` x ``patch`` window (zero-padded at the bottom/right
        when the image is smaller than the window);
      * runs the model and keeps boxes above ``threshold`` (per-tile NMS);
      * drops boxes whose center lies within ``border_ignore`` pixels of a
        tile side that has a neighbouring tile beyond it;
      * shifts the boxes into original-image pixel coordinates and clamps
        them to the image.
    A final global NMS (when torchvision is available) removes duplicates
    produced by overlapping tiles.

    The border filter is skipped on tile sides that coincide with the image
    boundary (or extend past it because of padding): nothing overlaps there,
    so filtering would silently discard every object near the image edge.

    Note: objects centered in a seam are only recovered by the neighbouring
    tile when ``stride < patch - 2 * border_ignore``.

    Args:
        img_pil: Source image at its original resolution.
        model: Callable invoked as ``model(x, captions=[cap])`` returning the
            output mapping expected by ``_pick_boxes_after_thresh_nms``.
        cap: Caption/prompt passed to the model.
        device: Device the input tensor is moved to.
        threshold: Score threshold for keeping a detection.
        nms_iou: IoU threshold for both per-tile and global NMS.
        patch: Tile side length in pixels.
        stride: Step between consecutive tile origins in pixels.
        border_ignore: Width in pixels of the ignored band along interior
            tile sides; ``0`` disables the filter.

    Returns:
        ``(boxes, scores)``: pixel xyxy boxes ``(K, 4)`` in original-image
        coordinates and their scores ``(K,)``.
    """
    W, H = img_pil.size

    all_boxes = []
    all_scores = []

    xs = _tile_origins(W, patch, stride)
    ys = _tile_origins(H, patch, stride)

    neg_inf = float("-inf")
    pos_inf = float("inf")

    for y0 in ys:
        for x0 in xs:
            x1 = min(W, x0 + patch)
            y1 = min(H, y0 + patch)

            crop = img_pil.crop((x0, y0, x1, y1)).convert("RGB")

            # Image smaller than a patch: pad bottom/right with black.
            if crop.size != (patch, patch):
                canvas = Image.new("RGB", (patch, patch), (0, 0, 0))
                canvas.paste(crop, (0, 0))
                crop = canvas

            x = preprocess_image_for_model(crop).unsqueeze(0).to(device)

            with torch.no_grad():
                outputs = model(x, captions=[cap])

            boxes_t, scores_t = _pick_boxes_after_thresh_nms(outputs, threshold=threshold, nms_iou=nms_iou)
            if boxes_t.numel() == 0:
                continue

            # Normalized tile coordinates -> tile pixels.
            boxes_px = boxes_t * patch

            if border_ignore and border_ignore > 0:
                cx = 0.5 * (boxes_px[:, 0] + boxes_px[:, 2])
                cy = 0.5 * (boxes_px[:, 1] + boxes_px[:, 3])

                # Only ignore the band on sides where another tile lies
                # beyond this one; image-boundary sides keep everything.
                left = border_ignore if x0 > 0 else neg_inf
                top = border_ignore if y0 > 0 else neg_inf
                right = (patch - border_ignore) if (x0 + patch) < W else pos_inf
                bottom = (patch - border_ignore) if (y0 + patch) < H else pos_inf

                keep = (cx > left) & (cx < right) & (cy > top) & (cy < bottom)
                boxes_px = boxes_px[keep]
                scores_t = scores_t[keep]
                if boxes_px.numel() == 0:
                    continue

            # Tile pixels -> original-image pixels.
            boxes_px[:, [0, 2]] += x0
            boxes_px[:, [1, 3]] += y0

            boxes_px[:, 0].clamp_(0, W - 1)
            boxes_px[:, 2].clamp_(0, W - 1)
            boxes_px[:, 1].clamp_(0, H - 1)
            boxes_px[:, 3].clamp_(0, H - 1)

            all_boxes.append(boxes_px)
            all_scores.append(scores_t)

    if not all_boxes:
        return torch.zeros((0, 4), device=device), torch.zeros((0,), device=device)

    boxes = torch.cat(all_boxes, dim=0)
    scores = torch.cat(all_scores, dim=0)

    # Global NMS removes duplicates found by overlapping tiles.
    if tv_nms is not None and boxes.shape[0] > 1:
        keep = tv_nms(boxes, scores, nms_iou)
        boxes = boxes[keep]
        scores = scores[keep]

    return boxes, scores
|
|
|
|
| |
| |
| |
|
|
# ---------------------------------------------------------------------------
# Page setup and sidebar controls. Streamlit renders widgets in call order,
# so the statement order below is the on-screen order.
# ---------------------------------------------------------------------------
st.set_page_config(page_title="GDCount Streamlit", layout="wide")
st.title("GDCount – Demo đếm theo prompt (FSC147)")

with st.sidebar:
    st.header("Cấu hình")

    # Use the requested device, falling back to CPU when CUDA is requested
    # but not available.
    device_choice = st.selectbox("Device", ["cuda", "cpu"], index=0)
    if device_choice == "cpu" or torch.cuda.is_available():
        device = device_choice
    else:
        device = "cpu"

    # NOTE(review): machine-specific default paths; they only pre-fill the
    # text inputs below and can be edited in the UI.
    default_config = r"C:\Users\PC\Documents\college\CV\gdcount\groundingdino\groundingdino\config\GroundingDINO_SwinT_OGC.py"
    default_gdino_ckpt = r"C:\Users\PC\Documents\college\CV\gdcount\weights\groundingdino_swint_ogc.pth"
    default_model_ckpt = r"C:\Users\PC\Documents\college\CV\gdcount\checkpoints_gdcount\best\gdcount_epoch_011_best.pth"

    config_path = st.text_input("GroundingDINO config", value=default_config)
    gdino_ckpt_path = st.text_input("GroundingDINO checkpoint", value=default_gdino_ckpt)
    model_ckpt_path = st.text_input("GDCount trained checkpoint", value=default_model_ckpt)

    st.divider()

    # Detection post-processing knobs.
    threshold = st.slider(
        "Threshold", min_value=0.0, max_value=1.0, value=0.23, step=0.01
    )
    nms_iou = st.slider(
        "NMS IoU", min_value=0.0, max_value=1.0, value=0.50, step=0.01
    )

    st.divider()

    # Tiling options; the two sliders exist only while tiling is enabled,
    # otherwise the fixed defaults are used.
    use_tiling = st.checkbox("Tile inference (cắt patch 384×384 + overlap)", value=True)
    stride = 256
    border_ignore = 24
    if use_tiling:
        stride = st.slider("Stride", min_value=64, max_value=384, value=256, step=32)
        border_ignore = st.slider("Border ignore (px)", min_value=0, max_value=64, value=24, step=4)

    st.divider()
    show_boxes = st.checkbox("Hiển thị bounding boxes", value=True)
    show_scores = st.checkbox("Hiển thị score trên box", value=False)

    st.divider()
    prompt = st.text_input("Prompt", value="object")
    run_btn = st.button("Chạy đếm", type="primary")
|
|
|
|
@st.cache_resource(show_spinner=True)
def load_model_cached(
    config_path: str,
    gdino_ckpt_path: str,
    model_ckpt_path: str,
    device: str,
    threshold: float
):
    """Build the GDCount model, load the trained weights, and cache it.

    The model is built from the GroundingDINO config/checkpoint with a
    ``GDCountConfig`` (``soa_level=-1``, ``feature_dim=256``, freeze keywords
    ``"backbone.0"`` and ``"bert"``), then the trained checkpoint is loaded
    and the model is switched to eval mode.

    NOTE(review): every argument, including ``threshold``, is presumably part
    of the ``st.cache_resource`` cache key, so changing any of them would
    trigger a fresh load -- confirm against the Streamlit documentation.

    Args:
        config_path: GroundingDINO config file.
        gdino_ckpt_path: GroundingDINO base checkpoint.
        model_ckpt_path: Trained GDCount checkpoint.
        device: Device the model is built on / weights are mapped to.
        threshold: Value forwarded to ``GDCountConfig(threshold=...)``.

    Returns:
        ``(model, meta)`` where ``meta`` is the dict returned by
        ``load_model_checkpoint``.
    """
    model = build_gdcount_model(
        config_path=config_path,
        checkpoint_path=gdino_ckpt_path,
        device=device,
        gdcount_cfg=GDCountConfig(
            threshold=threshold,
            soa_level=-1,
            feature_dim=256,
            freeze_keywords=["backbone.0", "bert"],
        ),
    )
    checkpoint_meta = load_model_checkpoint(model_ckpt_path, model, device)
    model.eval()
    return model, checkpoint_meta
|
|
|
|
# ---------------------------------------------------------------------------
# Main page: left column = upload + preview, right column = inference result.
# ---------------------------------------------------------------------------
col_left, col_right = st.columns([1, 1])

with col_left:
    up = st.file_uploader("Upload ảnh (jpg/png)", type=["jpg", "jpeg", "png"])
    if up is not None:
        img = Image.open(up)
        st.image(img, caption="Ảnh gốc", use_container_width=True)
    else:
        img = None

with col_right:
    st.subheader("Kết quả")
    if img is None:
        st.info("Upload ảnh để bắt đầu.")
    elif run_btn:
        # Abort early with a readable message when any required file is missing.
        required_files = (
            ("Không tìm thấy config", config_path),
            ("Không tìm thấy GroundingDINO ckpt", gdino_ckpt_path),
            ("Không tìm thấy GDCount ckpt", model_ckpt_path),
        )
        for missing_label, required_path in required_files:
            if not os.path.isfile(required_path):
                st.error(f"{missing_label}: {required_path}")
                st.stop()

        with st.spinner("Đang load model (lần đầu có thể lâu)..."):
            model, meta = load_model_cached(
                config_path=config_path,
                gdino_ckpt_path=gdino_ckpt_path,
                model_ckpt_path=model_ckpt_path,
                device=device,
                threshold=threshold,
            )

        cap = sanitize_caption(prompt)

        # perf_counter is monotonic and meant for measuring intervals.
        t0 = time.perf_counter()

        if use_tiling:
            boxes_px_t, scores_t = infer_tiled_boxes(
                img_pil=img,
                model=model,
                cap=cap,
                device=device,
                threshold=threshold,
                nms_iou=nms_iou,
                patch=384,
                stride=stride,
                border_ignore=border_ignore,
            )
            dt = (time.perf_counter() - t0) * 1000.0
            pred_count = int(boxes_px_t.shape[0])

            st.metric("Predicted count", pred_count)
            st.write(
                f"- Mode: `tile`\n"
                f"- Patch: `384`, stride: `{stride}`, border_ignore: `{border_ignore}`\n"
                f"- Prompt: `{cap}`\n"
                f"- Device: `{device}`\n"
                f"- Checkpoint epoch: `{meta.get('epoch', '')}`\n"
                f"- Inference time: `{dt:.1f} ms`"
            )

            if show_boxes:
                boxes_np = boxes_px_t.detach().cpu().numpy()
                scores_np = scores_t.detach().cpu().numpy()
                vis = draw_boxes_on_pil_px(
                    image=img,
                    boxes_xyxy_px=boxes_np,
                    scores=scores_np if show_scores else None,
                    score_threshold_to_show=0.0,
                )
                st.image(vis, caption="Ảnh gốc + boxes (tile + global NMS)", use_container_width=True)

            with st.expander("Debug"):
                st.write("boxes_px shape:", tuple(boxes_px_t.shape))
                st.write("kept boxes:", int(boxes_px_t.shape[0]))

        else:
            # Single pass on the whole image resized to 384x384.
            x = preprocess_image_for_model(img).unsqueeze(0).to(device)

            with torch.no_grad():
                outputs: Dict[str, Any] = model(x, captions=[cap])

            boxes_t, scores_t = _pick_boxes_after_thresh_nms(outputs, threshold=threshold, nms_iou=nms_iou)

            # Stop the timer after thresholding/NMS: CUDA work is queued
            # asynchronously, so timing right after the forward call can
            # under-report, and this keeps the figure comparable with the
            # tiled mode, which also includes post-processing.
            dt = (time.perf_counter() - t0) * 1000.0
            pred_count = int(boxes_t.shape[0])

            st.metric("Predicted count", pred_count)
            st.write(
                f"- Mode: `resize384`\n"
                f"- Prompt: `{cap}`\n"
                f"- Device: `{device}`\n"
                f"- Checkpoint epoch: `{meta.get('epoch', '')}`\n"
                f"- Inference time: `{dt:.1f} ms`"
            )

            if show_boxes:
                boxes_np = boxes_t.detach().cpu().numpy()
                scores_np = scores_t.detach().cpu().numpy()

                # Boxes are normalized, so they are drawn on the same
                # 384x384 view that the model received.
                vis = draw_boxes_on_pil_norm(
                    image=img.resize((384, 384), Image.BILINEAR),
                    boxes_xyxy_norm=boxes_np,
                    scores=scores_np if show_scores else None,
                    score_threshold_to_show=0.0,
                )
                st.image(vis, caption="Ảnh 384×384 + boxes (resize384)", use_container_width=True)

            with st.expander("Debug (tensors)"):
                st.write("outputs keys:", list(outputs.keys()))
                if "pred_boxes" in outputs:
                    st.write("pred_boxes shape:", tuple(outputs["pred_boxes"].shape))
                if "pred_logits" in outputs:
                    st.write("pred_logits shape:", tuple(outputs["pred_logits"].shape))
                st.write("kept boxes:", int(boxes_t.shape[0]))
|
|