Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image, ImageDraw, ImageFont | |
| import math | |
| import yaml | |
| import torch | |
| import torch.nn.functional as F | |
| import torchvision.transforms as T | |
| import subprocess | |
| import sys | |
| import os | |
# --- APGCC Setup ---
# One-time bootstrap executed at import: clone the APGCC crowd-counting
# repository, patch it for this runtime, fetch pretrained weights, and
# build the inference model in eval mode.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
APGCC_REPO = os.path.join(BASE_DIR, "APGCC")
APGCC_DIR = os.path.join(APGCC_REPO, "apgcc")
if not os.path.exists(APGCC_REPO):
    # Clone only on first launch; check=True aborts startup on failure.
    subprocess.run(
        ["git", "clone", "https://github.com/AaronCIH/APGCC.git", APGCC_REPO],
        check=True,
    )
# Patch 1: vgg.py loads backbone weights from a hard-coded local path;
# rewrite that line to download the official vgg16_bn checkpoint via
# torch.hub so the backbone resolves on a fresh machine.
VGG_PY = os.path.join(APGCC_DIR, "models", "backbones", "vgg.py")
with open(VGG_PY, "r") as f:
    _vgg_src = f.read()
if "model_paths[arch]" in _vgg_src:  # marker present only before patching
    _vgg_src = _vgg_src.replace(
        "state_dict = torch.load(model_paths[arch])",
        "state_dict = torch.hub.load_state_dict_from_url("
        "{'vgg16_bn':'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'}[arch])",
    )
    with open(VGG_PY, "w") as f:
        f.write(_vgg_src)
# Patch 2: modules.py calls .cuda() unconditionally; retarget those
# tensors to the device of the surrounding computation (`res`) so the
# model also runs on CPU-only hosts.
MODULES_PY = os.path.join(APGCC_DIR, "models", "modules.py")
with open(MODULES_PY, "r") as f:
    _mod_src = f.read()
if ".cuda()" in _mod_src:  # marker present only before patching
    _mod_src = _mod_src.replace(".cuda()", ".to(res.device)")
    with open(MODULES_PY, "w") as f:
        f.write(_mod_src)
# Download the pretrained checkpoint from Google Drive if not cached.
WEIGHT_PATH = os.path.join(APGCC_DIR, "outputs", "best.pth")
if not os.path.exists(WEIGHT_PATH):
    os.makedirs(os.path.dirname(WEIGHT_PATH), exist_ok=True)
    import gdown  # local import: only needed on first download
    gdown.download(id="1pEvn5RrvmDqVJUDZ4c9-rCJcl2I7bRhu", output=WEIGHT_PATH, quiet=False)
# Make the cloned package importable, then build and configure the model.
sys.path.insert(0, APGCC_DIR)
from config import cfg as _cfg, merge_from_file
from models import build_model as apgcc_build_model
_cfg = merge_from_file(_cfg, os.path.join(APGCC_DIR, "configs", "SHHA_test.yml"))
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
APGCC_MODEL = apgcc_build_model(cfg=_cfg, training=False)
APGCC_MODEL.to(DEVICE)
# NOTE(review): weights_only=False unpickles arbitrary objects — acceptable
# only because the checkpoint comes from the fixed Drive ID above.
checkpoint = torch.load(WEIGHT_PATH, map_location="cpu", weights_only=False)
# Load leniently: keep only checkpoint entries whose keys exist in the
# freshly built model, so minor key drift does not hard-fail startup.
model_state = APGCC_MODEL.state_dict()
filtered = {k: v for k, v in checkpoint.items() if k in model_state}
model_state.update(filtered)
APGCC_MODEL.load_state_dict(model_state)
APGCC_MODEL.eval()
# Standard ImageNet normalization — presumably matches APGCC's training
# preprocessing; TODO confirm against the upstream repo.
IMG_TRANSFORM = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Confidence cutoff above which a predicted point counts as a person.
THRESHOLD = 0.5
# --- Locations ---
# locations.yaml is a mapping whose keys are the location display names
# shown in the dropdown; the values are not used anywhere in this file.
LOCATIONS_PATH = os.path.join(BASE_DIR, "locations.yaml")
with open(LOCATIONS_PATH, "r") as f:
    _raw = yaml.safe_load(f)
LOCATION_NAMES = list(_raw.keys())
def count_people_in_image(image_path):
    """Run APGCC head detection on one image.

    Parameters
    ----------
    image_path : str | None
        Path to an image file; ``None`` yields ``(0, None)``.

    Returns
    -------
    tuple[int, numpy.ndarray | None]
        The number of detected people and an RGB copy of the original
        image with one dot per detection plus a count banner, or
        ``None`` for the image if OpenCV cannot re-read the file.
    """
    if image_path is None:
        return 0, None
    # Decode + normalize; close the PIL handle promptly (context manager)
    # instead of leaking it until garbage collection.
    with Image.open(image_path) as img:
        img_tensor = IMG_TRANSFORM(img.convert("RGB"))
    _, h, w = img_tensor.shape
    # Downscale very large inputs so the longest side is <= 2560 px;
    # `scale` is remembered to map predictions back to original pixels.
    scale = 1.0
    if max(h, w) > 2560:
        scale = 2560.0 / max(h, w)
        new_h, new_w = int(h * scale), int(w * scale)
        img_tensor = F.interpolate(
            img_tensor.unsqueeze(0), size=(new_h, new_w),
            mode="bilinear", align_corners=False,
        ).squeeze(0)
    _, h, w = img_tensor.shape
    # Zero-pad bottom/right up to the next multiple of 128 (model input
    # requirement). Padding only bottom/right means predicted coordinates
    # need no offset correction, only the 1/scale resize undo.
    pad_h = ((h - 1) // 128 + 1) * 128 - h
    pad_w = ((w - 1) // 128 + 1) * 128 - w
    if pad_h > 0 or pad_w > 0:
        img_tensor = F.pad(img_tensor, (0, pad_w, 0, pad_h), value=0)
    # Inference only: disable autograd so no graph is built (saves memory
    # and time; the original forward pass tracked gradients needlessly).
    with torch.no_grad():
        outputs = APGCC_MODEL(img_tensor.unsqueeze(0).to(DEVICE))
        # Class index 1 is treated as "person" probability after softmax.
        scores = F.softmax(outputs["pred_logits"], dim=-1)[:, :, 1][0]
        pred_points = outputs["pred_points"][0]
        mask = scores > THRESHOLD
        points = pred_points[mask].detach().cpu().numpy()
        person_count = int(mask.sum().item())
    # Annotate the full-resolution original (OpenCV works in BGR).
    original = cv2.imread(image_path)
    if original is None:
        return person_count, None
    for x, y in points:
        # Undo the resize so dots land on the original-resolution pixels.
        cx, cy = int(x / scale), int(y / scale)
        cv2.circle(original, (cx, cy), 6, (0, 0, 0), -1)      # dark rim
        cv2.circle(original, (cx, cy), 4, (0, 255, 255), -1)  # yellow core
    cv2.putText(original, f"Count: {person_count}", (10, 40),
                cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 255), 3)
    return person_count, cv2.cvtColor(original, cv2.COLOR_BGR2RGB)
# Candidate font files able to render the Chinese labels, covering common
# Linux (Noto CJK) and macOS system font locations; tried in order.
CJK_FONT_PATHS = [
    "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
    "/usr/share/fonts/opentype/noto/NotoSansCJKtc-Regular.otf",
    "/System/Library/Fonts/PingFang.ttc",
    "/System/Library/Fonts/STHeiti Medium.ttc",
]
def _load_cjk_font(size):
    """Return a font at *size* that can render CJK text.

    Each path in CJK_FONT_PATHS is tried in order and the first one
    Pillow can open wins; when none is usable (OSError on every path),
    Pillow's bundled default font at the requested size is returned.
    """
    for candidate in CJK_FONT_PATHS:
        try:
            font = ImageFont.truetype(candidate, size)
        except OSError:
            continue
        return font
    return ImageFont.load_default(size=size)
def create_result_collage(annotated_images, counts, location_name, total_count):
    """Compose the annotated photos into one labeled collage.

    Images are scaled to a fixed cell width and arranged on a white
    canvas in up to two columns, with a header showing the location, a
    black count band under each photo, and a footer with the grand
    total. Returns a PIL Image, or ``None`` when no images were given.
    """
    if not annotated_images:
        return None
    cell_w = 640
    pad = 10
    header_h = 80
    footer_h = 80
    label_h = 40
    total = len(annotated_images)
    cols = 1 if total <= 1 else 2
    rows = math.ceil(total / cols)
    # Normalize every photo to the cell width, keeping its aspect ratio.
    cells = []
    for item in annotated_images:
        pic = item if not isinstance(item, np.ndarray) else Image.fromarray(item)
        orig_w, orig_h = pic.size
        cells.append(pic.resize((cell_w, int(orig_h * cell_w / orig_w)), Image.LANCZOS))
    # Every row is as tall as the tallest photo plus its label band.
    row_h = max(c.size[1] for c in cells) + label_h
    canvas_w = cols * cell_w + (cols + 1) * pad
    canvas_h = header_h + rows * row_h + (rows + 1) * pad + footer_h
    canvas = Image.new("RGB", (canvas_w, canvas_h), (255, 255, 255))
    draw = ImageDraw.Draw(canvas)
    big_font = _load_cjk_font(36)
    mid_font = _load_cjk_font(24)

    def _draw_centered(text, font, x0, region_w, y0, region_h, fill):
        # Center `text` inside the given rectangle using its rendered bbox.
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        draw.text(
            (x0 + (region_w - (right - left)) / 2,
             y0 + (region_h - (bottom - top)) / 2),
            text, fill=fill, font=font,
        )

    _draw_centered(f"地點: {location_name}", big_font, 0, canvas_w, 0, header_h, (0, 0, 0))
    for i, (cell, n_people) in enumerate(zip(cells, counts)):
        x = pad + (i % cols) * (cell_w + pad)
        y = header_h + pad + (i // cols) * (row_h + pad)
        canvas.paste(cell, (x, y))
        # Black band directly under the photo carrying its count.
        band_top = y + cell.size[1]
        draw.rectangle([x, band_top, x + cell_w, band_top + label_h], fill=(0, 0, 0))
        _draw_centered(f"照片 {i + 1}: {n_people} 人", mid_font,
                       x, cell_w, band_top, label_h, (255, 255, 255))
    _draw_centered(f"估計總人數: {total_count} 人", big_font,
                   0, canvas_w, canvas_h - footer_h, footer_h, (0, 0, 0))
    return canvas
def process_files(files, location_choice):
    """Count people across all uploaded photos for one location.

    Parameters
    ----------
    files : list[str] | None
        Paths of the uploaded files (Gradio File component, filepath mode).
    location_choice : str
        Selected location name, echoed in the details and collage.

    Returns
    -------
    tuple[str, str, PIL.Image.Image | None]
        (total-count HTML, per-photo markdown table, collage image), or
        ("", "", None) when no valid image was provided.
    """
    empty = ("", "", None)
    if not files:
        return empty
    location_name = location_choice
    # Keep only files Pillow recognizes as images. verify() detects
    # corrupt data without fully decoding; the context manager closes the
    # underlying file handle immediately (the original leaked it).
    valid_paths = []
    for file_path in files:
        try:
            with Image.open(file_path) as probe:
                probe.verify()
            valid_paths.append(file_path)
        except Exception:
            continue  # deliberate best-effort: skip non-image uploads
    if not valid_paths:
        return empty
    total_count = 0
    annotated_arrays = []
    collage_counts = []
    detail_lines = []
    for idx, fpath in enumerate(valid_paths, 1):
        count, annotated_img = count_people_in_image(fpath)
        total_count += count
        # Annotated image can be None when OpenCV fails to re-read the
        # file; the count is still valid and listed in the table.
        if annotated_img is not None:
            annotated_arrays.append(annotated_img)
            collage_counts.append(count)
        detail_lines.append(f"| {idx} | {count} |")
    details_md = f"**地點: {location_name}** · 共 {len(valid_paths)} 張照片\n\n"
    details_md += "| 照片 | 人數 |\n|------|------|\n" + "\n".join(detail_lines)
    total_html = (
        f"<div style='text-align:center;padding:24px 0'>"
        f"<div style='font-size:56px;font-weight:700;line-height:1'>{total_count}</div>"
        f"<div style='font-size:14px;opacity:0.6;margin-top:4px'>估計總人數</div>"
        f"</div>"
    )
    collage = create_result_collage(annotated_arrays, collage_counts, location_name, total_count)
    return total_html, details_md, collage
# --- Gradio UI ---
# Two-step flow: pick a location, then upload photos and run the count.
with gr.Blocks(
    title="流動人數估計系統",
    theme=gr.themes.Default(primary_hue="blue"),
) as demo:
    gr.Markdown("# 流動人數估計系統\n")
    # Step 1: location dropdown; starts empty so the upload UI is hidden.
    location_dropdown = gr.Dropdown(
        label="1. 選擇地點",
        choices=LOCATION_NAMES,
        value=None,
        info="請先選擇地點,然後上傳照片。",
    )
    # Step 2: upload + run button, revealed once a location is chosen.
    upload_group = gr.Group(visible=False)
    with upload_group:
        image_input = gr.File(
            label="2. 上傳照片(可同時選取多張)",
            file_count="multiple",
            file_types=["image"],
            type="filepath",
        )
        count_button = gr.Button("開始計算", variant="primary", size="lg")
    # Outputs: big total number (HTML inside Markdown), collage, details.
    total_display = gr.Markdown()
    collage_output = gr.Image(label="結果拼圖", type="pil")
    with gr.Accordion("詳細結果", open=False):
        details_md = gr.Markdown()
    # Toggle the upload group's visibility off the dropdown value.
    location_dropdown.change(
        fn=lambda loc: gr.Group(visible=loc is not None),
        inputs=[location_dropdown],
        outputs=[upload_group],
    )
    count_button.click(
        fn=process_files,
        inputs=[image_input, location_dropdown],
        outputs=[total_display, details_md, collage_output],
        show_progress="full",
    )

if __name__ == "__main__":
    demo.launch()