# lnyf_pcs_mobile / app.py
# (uploaded by samw212 — "Upload 5 files", commit c0ab594, verified)
import gradio as gr
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import math
import yaml
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import subprocess
import sys
import os
# --- APGCC Setup ---
# One-time bootstrap: clone the APGCC crowd-counting repository, patch two of
# its source files for portability, fetch pretrained weights, and build the
# inference model. All of this runs at import time.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
APGCC_REPO = os.path.join(BASE_DIR, "APGCC")
APGCC_DIR = os.path.join(APGCC_REPO, "apgcc")
if not os.path.exists(APGCC_REPO):
    # List form (shell=False) — no shell-injection surface.
    subprocess.run(
        ["git", "clone", "https://github.com/AaronCIH/APGCC.git", APGCC_REPO],
        check=True,
    )
# Patch 1: the repo's vgg.py loads backbone weights via a hard-coded local
# path table; rewrite it to download the vgg16_bn weights from torch.hub so
# the code runs on any machine.
VGG_PY = os.path.join(APGCC_DIR, "models", "backbones", "vgg.py")
with open(VGG_PY, "r") as f:
    _vgg_src = f.read()
if "model_paths[arch]" in _vgg_src:
    _vgg_src = _vgg_src.replace(
        "state_dict = torch.load(model_paths[arch])",
        "state_dict = torch.hub.load_state_dict_from_url("
        "{'vgg16_bn':'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'}[arch])",
    )
# NOTE(review): the file is rewritten even when no substitution occurred —
# harmless (same bytes), but the write could live inside the guard above.
with open(VGG_PY, "w") as f:
    f.write(_vgg_src)
# Patch 2: modules.py hard-codes .cuda(); retarget to the device of the
# surrounding tensor so CPU-only hosts work.
MODULES_PY = os.path.join(APGCC_DIR, "models", "modules.py")
with open(MODULES_PY, "r") as f:
    _mod_src = f.read()
if ".cuda()" in _mod_src:
    _mod_src = _mod_src.replace(".cuda()", ".to(res.device)")
with open(MODULES_PY, "w") as f:
    f.write(_mod_src)
# Download the pretrained checkpoint from Google Drive on first run.
WEIGHT_PATH = os.path.join(APGCC_DIR, "outputs", "best.pth")
if not os.path.exists(WEIGHT_PATH):
    os.makedirs(os.path.dirname(WEIGHT_PATH), exist_ok=True)
    import gdown
    gdown.download(id="1pEvn5RrvmDqVJUDZ4c9-rCJcl2I7bRhu", output=WEIGHT_PATH, quiet=False)
# Make the cloned repo's modules (config, models) importable.
sys.path.insert(0, APGCC_DIR)
from config import cfg as _cfg, merge_from_file
from models import build_model as apgcc_build_model
_cfg = merge_from_file(_cfg, os.path.join(APGCC_DIR, "configs", "SHHA_test.yml"))
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
APGCC_MODEL = apgcc_build_model(cfg=_cfg, training=False)
APGCC_MODEL.to(DEVICE)
# SECURITY: weights_only=False unpickles arbitrary objects (can execute code
# on load). Acceptable only because the checkpoint source is pinned above —
# never point this at untrusted files.
checkpoint = torch.load(WEIGHT_PATH, map_location="cpu", weights_only=False)
# Load only the checkpoint entries whose keys exist in the freshly built
# model, so minor architecture drift in the repo does not break loading.
model_state = APGCC_MODEL.state_dict()
filtered = {k: v for k, v in checkpoint.items() if k in model_state}
model_state.update(filtered)
APGCC_MODEL.load_state_dict(model_state)
APGCC_MODEL.eval()
# ImageNet mean/std normalization — matches the VGG backbone's pretraining.
IMG_TRANSFORM = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Softmax score above which a predicted point counts as a person.
THRESHOLD = 0.5
# --- Locations ---
# locations.yaml: only the top-level keys are used (dropdown choices).
LOCATIONS_PATH = os.path.join(BASE_DIR, "locations.yaml")
with open(LOCATIONS_PATH, "r") as f:
    _raw = yaml.safe_load(f)
LOCATION_NAMES = list(_raw.keys())
@torch.no_grad()
def count_people_in_image(image_path):
    """Count people in one image with the APGCC model.

    Args:
        image_path: Filesystem path to an image file, or None.

    Returns:
        Tuple ``(person_count, annotated_rgb)`` where ``annotated_rgb`` is an
        RGB numpy array with a dot per detected person and a count caption,
        or None when ``image_path`` is None or OpenCV cannot decode the file.
    """
    if image_path is None:
        return 0, None
    # Context manager closes the file handle promptly; a bare Image.open()
    # leaks it until garbage collection.
    with Image.open(image_path) as pil_img:
        img = pil_img.convert("RGB")
    img_tensor = IMG_TRANSFORM(img)
    _, h, w = img_tensor.shape
    # Downscale very large inputs so the longest side is at most 2560 px;
    # `scale` is remembered to map predictions back to original coordinates.
    scale = 1.0
    if max(h, w) > 2560:
        scale = 2560.0 / max(h, w)
        new_h, new_w = int(h * scale), int(w * scale)
        img_tensor = F.interpolate(
            img_tensor.unsqueeze(0), size=(new_h, new_w),
            mode="bilinear", align_corners=False,
        ).squeeze(0)
        _, h, w = img_tensor.shape
    # Zero-pad bottom/right so both sides are multiples of 128 — presumably
    # the network's required stride alignment (TODO confirm against APGCC).
    pad_h = ((h - 1) // 128 + 1) * 128 - h
    pad_w = ((w - 1) // 128 + 1) * 128 - w
    if pad_h > 0 or pad_w > 0:
        img_tensor = F.pad(img_tensor, (0, pad_w, 0, pad_h), value=0)
    outputs = APGCC_MODEL(img_tensor.unsqueeze(0).to(DEVICE))
    # Index 1 of the softmaxed logits is the "person" score per proposal.
    scores = F.softmax(outputs["pred_logits"], dim=-1)[:, :, 1][0]
    pred_points = outputs["pred_points"][0]
    mask = scores > THRESHOLD
    points = pred_points[mask].detach().cpu().numpy()
    person_count = int(mask.sum().item())
    original = cv2.imread(image_path)
    if original is None:
        # OpenCV cannot decode this file; still report the count.
        return person_count, None
    for x, y in points:
        # Undo the resize so markers land on the original-resolution image.
        cx, cy = int(x / scale), int(y / scale)
        cv2.circle(original, (cx, cy), 6, (0, 0, 0), -1)        # outline
        cv2.circle(original, (cx, cy), 4, (0, 255, 255), -1)    # yellow dot
    cv2.putText(original, f"Count: {person_count}", (10, 40),
                cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 255), 3)
    # cv2 works in BGR; convert to RGB for PIL/Gradio consumers.
    return person_count, cv2.cvtColor(original, cv2.COLOR_BGR2RGB)
# Candidate fonts able to render the app's Traditional Chinese labels,
# tried in order: Linux Noto CJK paths first, then macOS system fonts.
CJK_FONT_PATHS = [
    "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
    "/usr/share/fonts/opentype/noto/NotoSansCJKtc-Regular.otf",
    "/System/Library/Fonts/PingFang.ttc",
    "/System/Library/Fonts/STHeiti Medium.ttc",
]
def _load_cjk_font(size):
    """Return a font able to render CJK text at *size* pixels.

    Tries each path in CJK_FONT_PATHS in order and falls back to Pillow's
    built-in default font when none of them can be opened.
    """
    for path in CJK_FONT_PATHS:
        try:
            return ImageFont.truetype(path, size)
        except OSError:
            continue
    try:
        # Pillow >= 10.1 ships a scalable default font.
        return ImageFont.load_default(size=size)
    except TypeError:
        # Older Pillow: load_default() accepts no size argument.
        return ImageFont.load_default()
def create_result_collage(annotated_images, counts, location_name, total_count):
    """Assemble annotated images into one labelled collage.

    Every image is resized to a fixed cell width and laid out on a white
    canvas in up to two columns. Each cell carries a black caption strip with
    its per-photo count; a header shows the location and a footer shows the
    grand total. Returns a PIL Image, or None when no images were supplied.
    """
    if not annotated_images:
        return None
    cell_width, padding = 640, 10
    header_height, footer_height, label_height = 80, 80, 40
    n = len(annotated_images)
    cols = 1 if n <= 1 else 2
    rows = math.ceil(n / cols)
    # Normalize inputs (numpy arrays or PIL images) to PIL at uniform width.
    resized = []
    for item in annotated_images:
        pic = Image.fromarray(item) if isinstance(item, np.ndarray) else item
        src_w, src_h = pic.size
        scaled_h = int(src_h * cell_width / src_w)
        resized.append(pic.resize((cell_width, scaled_h), Image.LANCZOS))
    max_cell_h = label_height + max(pic.size[1] for pic in resized)
    canvas_w = cols * cell_width + (cols + 1) * padding
    canvas_h = header_height + rows * max_cell_h + (rows + 1) * padding + footer_height
    canvas = Image.new("RGB", (canvas_w, canvas_h), (255, 255, 255))
    draw = ImageDraw.Draw(canvas)
    font_large = _load_cjk_font(36)
    font_medium = _load_cjk_font(24)

    def _center(text, font, x_left, band_w, y_top, band_h, fill):
        # Draw *text* centered inside the given rectangle band.
        box = draw.textbbox((0, 0), text, font=font)
        tw, th = box[2] - box[0], box[3] - box[1]
        draw.text((x_left + (band_w - tw) / 2, y_top + (band_h - th) / 2),
                  text, fill=fill, font=font)

    _center(f"地點: {location_name}", font_large, 0, canvas_w, 0, header_height, (0, 0, 0))
    for idx, (pic, count) in enumerate(zip(resized, counts)):
        row, col = divmod(idx, cols)
        x = padding + col * (cell_width + padding)
        y = header_height + padding + row * (max_cell_h + padding)
        canvas.paste(pic, (x, y))
        # Black caption strip directly under the photo.
        strip_top = y + pic.size[1]
        draw.rectangle([x, strip_top, x + cell_width, strip_top + label_height], fill=(0, 0, 0))
        _center(f"照片 {idx + 1}: {count} 人", font_medium,
                x, cell_width, strip_top, label_height, (255, 255, 255))
    footer_y = canvas_h - footer_height
    _center(f"估計總人數: {total_count} 人", font_large,
            0, canvas_w, footer_y, footer_height, (0, 0, 0))
    return canvas
def process_files(files, location_choice):
    """Run people counting over every uploaded file and format the results.

    Args:
        files: List of uploaded file paths, or None.
        location_choice: Location name chosen in the dropdown.

    Returns:
        Tuple ``(total_html, details_markdown, collage_image)``; empty
        placeholders ``("", "", None)`` when no valid image was supplied.
    """
    empty = ("", "", None)
    if not files:
        return empty
    location_name = location_choice
    # Keep only files Pillow recognizes as images; skip anything unreadable.
    valid_paths = []
    for file_path in files:
        try:
            # Context manager closes the handle (a bare Image.open leaks it);
            # verify() checks integrity without decoding pixel data.
            with Image.open(file_path) as im:
                im.verify()
            valid_paths.append(file_path)
        except Exception:
            continue
    if not valid_paths:
        return empty
    total_count = 0
    annotated_arrays = []
    collage_counts = []
    detail_lines = []
    for idx, fpath in enumerate(valid_paths, 1):
        count, annotated_img = count_people_in_image(fpath)
        total_count += count
        if annotated_img is not None:
            # Keep counts parallel to images so the collage labels match.
            annotated_arrays.append(annotated_img)
            collage_counts.append(count)
        detail_lines.append(f"| {idx} | {count} |")
    details_md = f"**地點: {location_name}** · 共 {len(valid_paths)} 張照片\n\n"
    details_md += "| 照片 | 人數 |\n|------|------|\n" + "\n".join(detail_lines)
    total_html = (
        f"<div style='text-align:center;padding:24px 0'>"
        f"<div style='font-size:56px;font-weight:700;line-height:1'>{total_count}</div>"
        f"<div style='font-size:14px;opacity:0.6;margin-top:4px'>估計總人數</div>"
        f"</div>"
    )
    collage = create_result_collage(annotated_arrays, collage_counts, location_name, total_count)
    return total_html, details_md, collage
# --- Gradio UI ---
# Two-step flow: pick a location first, then the upload group is revealed.
with gr.Blocks(
    title="流動人數估計系統",
    theme=gr.themes.Default(primary_hue="blue"),
) as demo:
    gr.Markdown("# 流動人數估計系統\n")
    # Step 1: location dropdown (choices are the locations.yaml keys).
    location_dropdown = gr.Dropdown(
        label="1. 選擇地點",
        choices=LOCATION_NAMES,
        value=None,
        info="請先選擇地點,然後上傳照片。",
    )
    # Step 2: upload widgets, hidden until a location is chosen.
    upload_group = gr.Group(visible=False)
    with upload_group:
        image_input = gr.File(
            label="2. 上傳照片(可同時選取多張)",
            file_count="multiple",
            file_types=["image"],
            type="filepath",
        )
        count_button = gr.Button("開始計算", variant="primary", size="lg")
    # Result widgets stay visible regardless of the group toggle.
    total_display = gr.Markdown()
    collage_output = gr.Image(label="結果拼圖", type="pil")
    with gr.Accordion("詳細結果", open=False):
        details_md = gr.Markdown()
    # Reveal the upload group once any location is selected.
    location_dropdown.change(
        fn=lambda loc: gr.Group(visible=loc is not None),
        inputs=[location_dropdown],
        outputs=[upload_group],
    )
    count_button.click(
        fn=process_files,
        inputs=[image_input, location_dropdown],
        outputs=[total_display, details_md, collage_output],
        show_progress="full",
    )
if __name__ == "__main__":
    demo.launch()