zsc / app_patch_img.py
Bao Tran
Initial commit: upload full project
5633819
# app_patch_img.py
# Streamlit demo for GDCount (FSC147-style): upload an image + prompt -> predicted count
# Run: streamlit run app_patch_img.py
import os
import re
import time
from typing import Any, Dict, Tuple, Optional
import streamlit as st
import numpy as np
from PIL import Image, ImageDraw
import torch
try:
from torchvision.ops import nms as tv_nms
except Exception:
tv_nms = None
# ===== Import từ project của bạn (đảm bảo đúng path/module) =====
from models.gdcount_model import GDCountConfig, build_gdcount_model
# =========================
# Helpers (giống train/test)
# =========================
def sanitize_caption(p: str) -> str:
    """Normalize a text prompt into a one-line caption that ends with a period.

    ``None`` is treated as an empty prompt; any other value is converted with
    ``str``. Runs of whitespace collapse to a single space, whitespace that
    directly precedes a period is removed, and an empty prompt (or a bare
    ``"."``) becomes ``"object."``.
    """
    text = str(p).strip() if p is not None else ""
    text = re.sub(r"\s+", " ", text)
    text = re.sub(r"\s+\.", ".", text)
    if text in ("", "."):
        return "object."
    return text if text.endswith(".") else text + "."
def _scores_from_pred_logits(outputs: Dict[str, Any]) -> torch.Tensor:
"""
scores (B,Q) = sigmoid(max_token_logit over valid tokens) hoặc sigmoid(logit) nếu 2D.
"""
logits = outputs["pred_logits"].float() # (B,Q,T) hoặc (B,Q)
if logits.dim() == 2:
logits = torch.where(torch.isfinite(logits), logits, torch.full_like(logits, -1e4))
return torch.sigmoid(logits)
B, Q, T = logits.shape
token_mask = outputs.get("text_mask", None)
if token_mask is None:
token_mask = torch.ones((B, T), device=logits.device, dtype=torch.bool)
else:
token_mask = token_mask.to(device=logits.device, dtype=torch.bool)
if token_mask.shape[-1] < T:
pad = torch.zeros((B, T - token_mask.shape[-1]), device=logits.device, dtype=torch.bool)
token_mask = torch.cat([token_mask, pad], dim=-1)
token_mask = token_mask[:, :T]
input_ids = outputs.get("input_ids", None)
if input_ids is not None:
ids = input_ids.to(device=logits.device)
if ids.shape[-1] < T:
pad = torch.zeros((B, T - ids.shape[-1]), device=logits.device, dtype=ids.dtype)
ids = torch.cat([ids, pad], dim=-1)
ids = ids[:, :T]
specials = (ids == 0) | (ids == 101) | (ids == 102)
token_mask = token_mask & (~specials)
logits = torch.where(torch.isfinite(logits), logits, torch.full_like(logits, -1e4))
logits = logits.masked_fill(~token_mask[:, None, :], -1e4)
per_q = logits.max(dim=-1).values # (B,Q)
return torch.sigmoid(per_q)
def _cxcywh_to_xyxy(boxes: torch.Tensor) -> torch.Tensor:
cx, cy, w, h = boxes.unbind(-1)
x1 = cx - w / 2
y1 = cy - h / 2
x2 = cx + w / 2
y2 = cy + h / 2
return torch.stack([x1, y1, x2, y2], dim=-1)
def _pick_boxes_after_thresh_nms(
    outputs: Dict[str, Any],
    threshold: float,
    nms_iou: float
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Select the detections of the first batch item by score threshold, then NMS.

    Returns ``(boxes_xyxy_norm (K, 4), scores (K,))``. Boxes are xyxy clamped
    to ``[0, 1]``, i.e. normalized to the model input (384x384 according to the
    rest of this file).

    When ``outputs`` has no ``"pred_boxes"`` the boxes tensor is empty while
    the scores above the threshold are still returned. NMS is skipped when
    ``tv_nms`` is unavailable or only one box survives the threshold.
    """
    query_scores = _scores_from_pred_logits(outputs)[0]  # (Q,)
    dev = query_scores.device
    selected = (query_scores > threshold).nonzero(as_tuple=False).flatten()

    if "pred_boxes" not in outputs:
        return torch.zeros((0, 4), device=dev), query_scores[selected]

    # (Q, 4) cxcywh -> xyxy, kept inside the unit square.
    xyxy = _cxcywh_to_xyxy(outputs["pred_boxes"].float()[0]).clamp(0, 1)

    n_selected = selected.numel()
    if n_selected == 0:
        return torch.zeros((0, 4), device=dev), torch.zeros((0,), device=dev)

    cand_boxes = xyxy[selected]
    cand_scores = query_scores[selected]
    if n_selected == 1 or tv_nms is None:
        return cand_boxes, cand_scores

    survivors = tv_nms(cand_boxes, cand_scores, nms_iou)
    return cand_boxes[survivors], cand_scores[survivors]
def load_model_checkpoint(model_ckpt_path: str, model: torch.nn.Module, device: str) -> Dict[str, Any]:
    """Load trained weights from ``model_ckpt_path`` into ``model`` (strictly).

    Two checkpoint layouts are accepted:
      * a dict with a ``"model"`` entry that is itself a dict (the state dict);
      * a plain dict whose keys are all strings, used directly as the state dict.

    Returns ``{"epoch": ...}``, where the value is the checkpoint's ``"epoch"``
    entry or ``None`` when it has none.

    Raises:
        ValueError: if the file content matches neither layout.
    """
    # NOTE(review): torch.load unpickles the file -- only open trusted checkpoints.
    ckpt = torch.load(model_ckpt_path, map_location=device)
    if not isinstance(ckpt, dict):
        raise ValueError(f"Unrecognized checkpoint format: {model_ckpt_path}")

    if "model" in ckpt and isinstance(ckpt["model"], dict):
        state = ckpt["model"]
    elif all(isinstance(key, str) for key in ckpt):
        state = ckpt
    else:
        raise ValueError(f"Unrecognized checkpoint format: {model_ckpt_path}")

    model.load_state_dict(state, strict=True)
    return {"epoch": ckpt.get("epoch", None)}
def preprocess_image_for_model(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image into the normalized tensor fed to the model.

    Steps:
      - convert to RGB and resize to 384x384 with bilinear resampling;
      - scale pixel values to [0, 1] and reorder to channel-first;
      - normalize with the ImageNet per-channel mean and std.

    Returns a float32 tensor of shape (3, 384, 384).
    """
    rgb = img.convert("RGB").resize((384, 384), Image.BILINEAR)
    pixels = np.asarray(rgb).astype(np.float32) / 255.0  # (H, W, 3)
    chw = torch.from_numpy(pixels.transpose(2, 0, 1))  # (3, H, W)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    return (chw - mean) / std
def draw_boxes_on_pil_norm(
    image: Image.Image,
    boxes_xyxy_norm: np.ndarray,
    scores: Optional[np.ndarray] = None,
    score_threshold_to_show: float = 0.0
) -> Image.Image:
    """Draw normalized xyxy boxes on an RGB copy of ``image``.

    Box coordinates are multiplied by the image width/height and clipped to
    the image bounds. When ``scores`` is given, the score of box ``i`` is
    printed near its top-left corner if it is at least
    ``score_threshold_to_show``. The input image is not modified.
    """
    canvas = image.copy().convert("RGB")
    width, height = canvas.size
    pen = ImageDraw.Draw(canvas)

    def _to_px(frac, extent):
        # Scale a normalized coordinate and clip it to [0, extent - 1].
        return int(max(0, min(extent - 1, frac * extent)))

    for idx, box in enumerate(boxes_xyxy_norm):
        left = _to_px(box[0], width)
        top = _to_px(box[1], height)
        right = _to_px(box[2], width)
        bottom = _to_px(box[3], height)
        pen.rectangle([left, top, right, bottom], width=2)
        if scores is None:
            continue
        value = float(scores[idx])
        if value >= score_threshold_to_show:
            pen.text((left + 3, top + 3), f"{value:.2f}")
    return canvas
def draw_boxes_on_pil_px(
    image: Image.Image,
    boxes_xyxy_px: np.ndarray,
    scores: Optional[np.ndarray] = None,
    score_threshold_to_show: float = 0.0
) -> Image.Image:
    """Draw pixel-space xyxy boxes on an RGB copy of ``image``.

    Coordinates are clipped to the image bounds before drawing. When
    ``scores`` is given, the score of box ``i`` is printed near its top-left
    corner if it is at least ``score_threshold_to_show``. The input image is
    not modified.
    """
    canvas = image.copy().convert("RGB")
    width, height = canvas.size
    pen = ImageDraw.Draw(canvas)

    def _clip(value, extent):
        # Clip a pixel coordinate to [0, extent - 1] and truncate to int.
        return int(max(0, min(extent - 1, value)))

    for idx, box in enumerate(boxes_xyxy_px):
        left = _clip(box[0], width)
        top = _clip(box[1], height)
        right = _clip(box[2], width)
        bottom = _clip(box[3], height)
        pen.rectangle([left, top, right, bottom], width=2)
        if scores is None:
            continue
        value = float(scores[idx])
        if value >= score_threshold_to_show:
            pen.text((left + 3, top + 3), f"{value:.2f}")
    return canvas
def _tile_starts(length: int, patch: int, stride: int) -> list:
    """Return tile start offsets along one axis; the last tile touches the far edge."""
    starts = list(range(0, max(1, length - patch + 1), stride)) or [0]
    last = max(0, length - patch)
    if starts[-1] != last:
        starts.append(last)
    return starts


def infer_tiled_boxes(
    img_pil: Image.Image,
    model,
    cap: str,
    device: str,
    threshold: float,
    nms_iou: float,
    patch: int = 384,
    stride: int = 256,
    border_ignore: int = 24,  # 0 disables the border filter
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the model on overlapping tiles and merge the detections.

    Procedure:
      - cut the original image into ``patch`` x ``patch`` tiles spaced by
        ``stride`` (an extra tile is added so the far edges are covered);
      - tiles smaller than ``patch`` (image smaller than a tile) are pasted
        onto a black canvas at the top-left corner;
      - run the model on every tile and keep detections above ``threshold``
        (per-tile NMS included);
      - map the boxes back to original-image pixel coordinates;
      - run a global NMS so objects seen by several tiles are counted once.

    Args:
        img_pil: source image.
        model: callable invoked as ``model(x, captions=[cap])``.
        cap: text caption passed to the model.
        device: torch device string the tile tensor is moved to.
        threshold: score threshold for keeping a detection.
        nms_iou: IoU threshold for both per-tile and global NMS.
        patch: tile side length in pixels.
        stride: step between consecutive tiles in pixels.
        border_ignore: detections whose centre lies within this many pixels
            of an *interior* tile edge are dropped (a neighbouring tile sees
            them un-truncated). Tile edges that coincide with the image border
            are exempt; otherwise objects near the image boundary could never
            be counted by any tile.

    Returns:
        ``(boxes, scores)``: pixel xyxy boxes ``(K, 4)`` and scores ``(K,)``.
    """
    W, H = img_pil.size
    all_boxes = []
    all_scores = []

    for y0 in _tile_starts(H, patch, stride):
        for x0 in _tile_starts(W, patch, stride):
            x1 = min(W, x0 + patch)
            y1 = min(H, y0 + patch)
            crop = img_pil.crop((x0, y0, x1, y1)).convert("RGB")

            # Pad undersized tiles so the model always sees patch x patch pixels.
            padded = crop.size != (patch, patch)
            if padded:
                canvas = Image.new("RGB", (patch, patch), (0, 0, 0))
                canvas.paste(crop, (0, 0))
                crop = canvas

            x = preprocess_image_for_model(crop).unsqueeze(0).to(device)
            with torch.no_grad():
                outputs = model(x, captions=[cap])
            boxes_t, scores_t = _pick_boxes_after_thresh_nms(
                outputs, threshold=threshold, nms_iou=nms_iou
            )
            if boxes_t is None or boxes_t.numel() == 0:
                continue

            # boxes_t is normalized xyxy relative to the tile -> tile pixels.
            boxes_px = boxes_t * patch
            cx = 0.5 * (boxes_px[:, 0] + boxes_px[:, 2])
            cy = 0.5 * (boxes_px[:, 1] + boxes_px[:, 3])
            keep = torch.ones(boxes_px.shape[0], dtype=torch.bool, device=boxes_px.device)

            if padded:
                # A centre inside the black padding lies outside the real image.
                keep &= (cx < (x1 - x0)) & (cy < (y1 - y0))

            if border_ignore and border_ignore > 0:
                # Only interior tile edges get a margin: boxes cut by such an
                # edge are seen whole by the neighbouring tile, whereas nothing
                # else covers the true image border.
                if x0 > 0:
                    keep &= cx > border_ignore
                if x0 + patch < W:
                    keep &= cx < (patch - border_ignore)
                if y0 > 0:
                    keep &= cy > border_ignore
                if y0 + patch < H:
                    keep &= cy < (patch - border_ignore)

            boxes_px = boxes_px[keep]
            scores_t = scores_t[keep]
            if boxes_px.numel() == 0:
                continue

            # Shift into original-image coordinates.
            boxes_px[:, [0, 2]] += x0
            boxes_px[:, [1, 3]] += y0

            # Clip to the image bounds.
            boxes_px[:, 0].clamp_(0, W - 1)
            boxes_px[:, 2].clamp_(0, W - 1)
            boxes_px[:, 1].clamp_(0, H - 1)
            boxes_px[:, 3].clamp_(0, H - 1)

            all_boxes.append(boxes_px)
            all_scores.append(scores_t)

    if not all_boxes:
        return torch.zeros((0, 4), device=device), torch.zeros((0,), device=device)

    boxes = torch.cat(all_boxes, dim=0)  # pixel xyxy
    scores = torch.cat(all_scores, dim=0)

    # Global NMS removes duplicates produced by overlapping tiles.
    if tv_nms is not None and boxes.shape[0] > 1:
        keep = tv_nms(boxes, scores, nms_iou)
        boxes = boxes[keep]
        scores = scores[keep]
    return boxes, scores
# =========================
# Streamlit UI
# =========================
# Page setup and sidebar controls. Streamlit renders widgets in call order,
# so the statement order below is also the on-screen order.
st.set_page_config(page_title="GDCount Streamlit", layout="wide")
st.title("GDCount – Demo đếm theo prompt (FSC147)")
with st.sidebar:
    st.header("Cấu hình")
    device_choice = st.selectbox("Device", ["cuda", "cpu"], index=0)
    # Fall back to CPU when CUDA was requested but is not available.
    device = device_choice if (device_choice == "cpu" or torch.cuda.is_available()) else "cpu"
    if device != device_choice:
        # Surface the fallback instead of switching devices silently.
        st.warning("CUDA không khả dụng, đang chạy trên CPU.")

    # Default paths can be overridden through environment variables so the app
    # also starts on machines other than the original author's; the fallbacks
    # are the previous hard-coded locations.
    default_config = os.environ.get(
        "GDCOUNT_GDINO_CONFIG",
        r"C:\Users\PC\Documents\college\CV\gdcount\groundingdino\groundingdino\config\GroundingDINO_SwinT_OGC.py",
    )
    default_gdino_ckpt = os.environ.get(
        "GDCOUNT_GDINO_CKPT",
        r"C:\Users\PC\Documents\college\CV\gdcount\weights\groundingdino_swint_ogc.pth",
    )
    default_model_ckpt = os.environ.get(
        "GDCOUNT_MODEL_CKPT",
        r"C:\Users\PC\Documents\college\CV\gdcount\checkpoints_gdcount\best\gdcount_epoch_011_best.pth",
    )
    config_path = st.text_input("GroundingDINO config", value=default_config)
    gdino_ckpt_path = st.text_input("GroundingDINO checkpoint", value=default_gdino_ckpt)
    model_ckpt_path = st.text_input("GDCount trained checkpoint", value=default_model_ckpt)
    st.divider()
    threshold = st.slider("Threshold", min_value=0.0, max_value=1.0, value=0.23, step=0.01)
    nms_iou = st.slider("NMS IoU", min_value=0.0, max_value=1.0, value=0.50, step=0.01)
    st.divider()
    use_tiling = st.checkbox("Tile inference (cắt patch 384×384 + overlap)", value=True)
    if use_tiling:
        stride = st.slider("Stride", min_value=64, max_value=384, value=256, step=32)
        border_ignore = st.slider("Border ignore (px)", min_value=0, max_value=64, value=24, step=4)
    else:
        # Values are unused in resize mode but keep the names defined.
        stride = 256
        border_ignore = 24
    st.divider()
    show_boxes = st.checkbox("Hiển thị bounding boxes", value=True)
    show_scores = st.checkbox("Hiển thị score trên box", value=False)
    st.divider()
    prompt = st.text_input("Prompt", value="object")
    run_btn = st.button("Chạy đếm", type="primary")
@st.cache_resource(show_spinner=True)
def load_model_cached(
    config_path: str,
    gdino_ckpt_path: str,
    model_ckpt_path: str,
    device: str,
    threshold: float
):
    """Build the GDCount model, load the trained checkpoint and cache the result.

    Returns ``(model, meta)``: the model switched to eval mode and the metadata
    dict produced by ``load_model_checkpoint``.

    NOTE(review): ``threshold`` is an argument of this cached function, so a
    different slider value presumably yields a new cache entry and a full model
    rebuild -- confirm against Streamlit's ``st.cache_resource`` docs whether
    that is intended.
    """
    net = build_gdcount_model(
        config_path=config_path,
        checkpoint_path=gdino_ckpt_path,
        device=device,
        gdcount_cfg=GDCountConfig(
            threshold=threshold,
            soa_level=-1,
            feature_dim=256,
            freeze_keywords=["backbone.0", "bert"],
        ),
    )
    ckpt_meta = load_model_checkpoint(model_ckpt_path, net, device)
    net.eval()
    return net, ckpt_meta
# Two equal-width columns: upload/preview on the left, results on the right.
col_left, col_right = st.columns([1, 1])
with col_left:
    up = st.file_uploader("Upload ảnh (jpg/png)", type=["jpg", "jpeg", "png"])
    # `img` stays None until a file has been uploaded.
    img = Image.open(up) if up is not None else None
    if img is not None:
        st.image(img, caption="Ảnh gốc", use_container_width=True)
with col_right:
    st.subheader("Kết quả")
    if img is None:
        st.info("Upload ảnh để bắt đầu.")
    elif run_btn:
        # Stop this script run as soon as one of the required files is missing.
        if not os.path.isfile(config_path):
            st.error(f"Không tìm thấy config: {config_path}")
            st.stop()
        if not os.path.isfile(gdino_ckpt_path):
            st.error(f"Không tìm thấy GroundingDINO ckpt: {gdino_ckpt_path}")
            st.stop()
        if not os.path.isfile(model_ckpt_path):
            st.error(f"Không tìm thấy GDCount ckpt: {model_ckpt_path}")
            st.stop()

        with st.spinner("Đang load model (lần đầu có thể lâu)..."):
            model, meta = load_model_cached(
                config_path=config_path,
                gdino_ckpt_path=gdino_ckpt_path,
                model_ckpt_path=model_ckpt_path,
                device=device,
                threshold=threshold,
            )

        cap = sanitize_caption(prompt)
        t0 = time.time()

        if not use_tiling:
            # Legacy mode: resize the whole image to 384x384 and run once.
            x = preprocess_image_for_model(img).unsqueeze(0).to(device)
            with torch.no_grad():
                outputs: Dict[str, Any] = model(x, captions=[cap])
            # Timed before thresholding/NMS, i.e. the forward pass only.
            dt = (time.time() - t0) * 1000.0
            boxes_t, scores_t = _pick_boxes_after_thresh_nms(outputs, threshold=threshold, nms_iou=nms_iou)

            pred_count = 0 if boxes_t is None else int(boxes_t.shape[0])
            st.metric("Predicted count", pred_count)
            summary = (
                f"- Mode: `resize384`\n"
                f"- Prompt: `{cap}`\n"
                f"- Device: `{device}`\n"
                f"- Checkpoint epoch: `{meta.get('epoch', '')}`\n"
                f"- Inference time: `{dt:.1f} ms`"
            )
            st.write(summary)

            if show_boxes:
                if boxes_t is not None:
                    boxes_np = boxes_t.detach().cpu().numpy()
                else:
                    boxes_np = np.zeros((0, 4), dtype=np.float32)
                scores_np = None if scores_t is None else scores_t.detach().cpu().numpy()
                # Boxes are normalized, so they are drawn on the resized image.
                vis = draw_boxes_on_pil_norm(
                    image=img.resize((384, 384), Image.BILINEAR),
                    boxes_xyxy_norm=boxes_np,
                    scores=scores_np if show_scores else None,
                    score_threshold_to_show=0.0,
                )
                st.image(vis, caption="Ảnh 384×384 + boxes (resize384)", use_container_width=True)

            with st.expander("Debug (tensors)"):
                st.write("outputs keys:", list(outputs.keys()))
                if "pred_boxes" in outputs:
                    st.write("pred_boxes shape:", tuple(outputs["pred_boxes"].shape))
                if "pred_logits" in outputs:
                    st.write("pred_logits shape:", tuple(outputs["pred_logits"].shape))
                st.write("kept boxes:", int(boxes_t.shape[0]))
        else:
            # Tile mode: overlapping 384x384 patches merged by a global NMS.
            boxes_px_t, scores_t = infer_tiled_boxes(
                img_pil=img,
                model=model,
                cap=cap,
                device=device,
                threshold=threshold,
                nms_iou=nms_iou,
                patch=384,
                stride=stride,
                border_ignore=border_ignore,
            )
            # Timed after all tiles and the global NMS have finished.
            dt = (time.time() - t0) * 1000.0

            pred_count = 0 if boxes_px_t is None else int(boxes_px_t.shape[0])
            st.metric("Predicted count", pred_count)
            summary = (
                f"- Mode: `tile`\n"
                f"- Patch: `384`, stride: `{stride}`, border_ignore: `{border_ignore}`\n"
                f"- Prompt: `{cap}`\n"
                f"- Device: `{device}`\n"
                f"- Checkpoint epoch: `{meta.get('epoch', '')}`\n"
                f"- Inference time: `{dt:.1f} ms`"
            )
            st.write(summary)

            if show_boxes:
                if boxes_px_t is not None:
                    boxes_np = boxes_px_t.detach().cpu().numpy()
                else:
                    boxes_np = np.zeros((0, 4), dtype=np.float32)
                scores_np = None if scores_t is None else scores_t.detach().cpu().numpy()
                # Boxes are already in original-image pixels.
                vis = draw_boxes_on_pil_px(
                    image=img,
                    boxes_xyxy_px=boxes_np,
                    scores=scores_np if show_scores else None,
                    score_threshold_to_show=0.0,
                )
                st.image(vis, caption="Ảnh gốc + boxes (tile + global NMS)", use_container_width=True)

            with st.expander("Debug"):
                st.write("boxes_px shape:", tuple(boxes_px_t.shape))
                st.write("kept boxes:", int(boxes_px_t.shape[0]))