Spaces:
Running
Running
File size: 5,678 Bytes
import base64
import io
import json
import os
import subprocess
import sys
import threading
import traceback

import numpy as np
import torch
import torch.nn as nn
import uvicorn
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from huggingface_hub import hf_hub_download
from PIL import Image
# -- Fetch the official HRNet code -------------------------------------------
# The network definition and its config system come from the official HRNet
# repository. It is shallow-cloned once (skipped when REPO already exists) and
# its lib/ directory is put first on sys.path so `config` and `models` import.
REPO = '/app/hrnet_repo'


def _clone_hrnet_repo(target):
    """Shallow-clone the HRNet repo into *target* unless that path exists.

    Raises subprocess.CalledProcessError if git exits non-zero.
    """
    if os.path.exists(target):
        return
    print("Cloning HRNet...")
    clone_cmd = [
        'git', 'clone', '--depth=1',
        'https://github.com/HRNet/HRNet-Human-Pose-Estimation.git',
        target,
    ]
    subprocess.run(clone_cmd, check=True)
    print("Cloned.")


_clone_hrnet_repo(REPO)
sys.path.insert(0, os.path.join(REPO, 'lib'))
# -- Write the HRNet-W32 config file and build the network ------------------
# The YAML must keep its nesting. HRNet's config is merged key-by-path
# (e.g. MODEL.EXTRA.STAGE2.NUM_CHANNELS). A flattened file would put keys such
# as NAME, EXTRA or STAGE2 at the top level, and yacs raises on keys that do
# not exist in the default config.
# MODEL.IMAGE_SIZE must stay in sync with the preprocessing size used at
# inference time (INPUT_W, INPUT_H); NUM_JOINTS with the number of LM_IDS.
cfg_path = '/app/hrnet_w32.yaml'
with open(cfg_path, 'w', encoding='utf-8') as f:
    f.write("""
AUTO_RESUME: false
CUDNN:
  BENCHMARK: true
  DETERMINISTIC: false
  ENABLED: true
DATA_DIR: ''
GPUS: (0,)
OUTPUT_DIR: 'output'
LOG_DIR: 'log'
WORKERS: 4
PRINT_FREQ: 100
MODEL:
  NAME: pose_hrnet
  NUM_JOINTS: 19
  PRETRAINED: ''
  TARGET_TYPE: gaussian
  IMAGE_SIZE:
  - 256
  - 320
  HEATMAP_SIZE:
  - 64
  - 80
  SIGMA: 2
  EXTRA:
    PRETRAINED_LAYERS:
    - 'conv1'
    - 'bn1'
    - 'conv2'
    - 'bn2'
    - 'layer1'
    - 'transition1'
    - 'stage2'
    - 'transition2'
    - 'stage3'
    - 'transition3'
    - 'stage4'
    FINAL_CONV_KERNEL: 1
    STAGE2:
      NUM_MODULES: 1
      NUM_BRANCHES: 2
      BLOCK: BASIC
      NUM_BLOCKS:
      - 4
      - 4
      NUM_CHANNELS:
      - 32
      - 64
      FUSE_METHOD: SUM
    STAGE3:
      NUM_MODULES: 4
      NUM_BRANCHES: 3
      BLOCK: BASIC
      NUM_BLOCKS:
      - 4
      - 4
      - 4
      NUM_CHANNELS:
      - 32
      - 64
      - 128
      FUSE_METHOD: SUM
    STAGE4:
      NUM_MODULES: 3
      NUM_BRANCHES: 4
      BLOCK: BASIC
      NUM_BLOCKS:
      - 4
      - 4
      - 4
      - 4
      NUM_CHANNELS:
      - 32
      - 64
      - 128
      - 256
      FUSE_METHOD: SUM
""")

# These imports live mid-module on purpose: they only resolve after the
# sys.path.insert() that exposes the cloned repo's lib/ directory.
from config import cfg as _cfg  # noqa: E402
_cfg.defrost()
_cfg.merge_from_file(cfg_path)
_cfg.freeze()

from models.pose_hrnet import get_pose_net  # noqa: E402
model = get_pose_net(_cfg, is_train=False)
print("Built model with official HRNet.")
# -- Load the fine-tuned checkpoint -----------------------------------------
print("Loading weights...")
model_path = hf_hub_download(
    repo_id="cwlachap/hrnet-cephalometric-landmark-detection",
    filename="best_model.pth",
)
# SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
# i.e. execute code embedded in the downloaded file. This is only safe while
# the Hub repo above is trusted. Prefer weights_only=True if the checkpoint
# holds only tensors and plain containers.
checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
# Accept a training-checkpoint dict ("model_state_dict" / "state_dict") or a
# bare state_dict.
state_dict = checkpoint.get("model_state_dict", checkpoint.get("state_dict", checkpoint))
# A checkpoint saved from an nn.DataParallel / DistributedDataParallel wrapper
# prefixes every key with "module.". Under strict=False none of those keys
# would match, and the network would silently keep its initial weights, so
# strip the prefix when every key carries it.
if state_dict and all(isinstance(k, str) and k.startswith("module.") for k in state_dict):
    state_dict = {k[len("module."):]: v for k, v in state_dict.items()}
# strict=False tolerates key mismatches; the counts and samples printed below
# are the only signal that loading was partial.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"Missing: {len(missing)}, Unexpected: {len(unexpected)}")
if missing:
    print(f"Missing sample: {missing[:2]}")
if unexpected:
    print(f"Unexpected sample: {unexpected[:2]}")
model.eval()
print("Model ready!")
# -- Constants ---------------------------------------------------------------
# Landmark names, indexed by heatmap channel: channel j of the network output
# is reported under LM_IDS[j]. There are 19 of them, matching MODEL.NUM_JOINTS.
LM_IDS = [
    'S', 'N', 'Or', 'Po', 'ANS',
    'PNS', 'A', 'U1tip', 'L1tip', 'B',
    'Pog', 'Me', 'Gn', 'Go', 'Co',
    'L1ap', 'U1ap', 'U6', 'L6',
]

# Network input size in pixels (width, height); should agree with
# MODEL.IMAGE_SIZE in the HRNet config.
INPUT_W = 256
INPUT_H = 320
def preprocess(pil_img):
    """Convert a PIL image into the network's input tensor.

    Steps:
    - convert to RGB;
    - resize (bilinear) to exactly INPUT_W x INPUT_H, without preserving the
      aspect ratio;
    - scale to [0, 1];
    - normalise each channel with the standard ImageNet mean/std.

    Returns:
        float32 tensor of shape (1, 3, INPUT_H, INPUT_W).
    """
    rgb = pil_img.convert('RGB')
    resized = rgb.resize((INPUT_W, INPUT_H), Image.BILINEAR)
    pixels = np.array(resized, dtype=np.float32) / 255.0
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    normalized = (pixels - channel_mean) / channel_std
    # HWC -> CHW, then add a leading batch dimension of 1.
    chw = torch.from_numpy(normalized).permute(2, 0, 1)
    batched = chw.unsqueeze(0)
    return batched.float()
def heatmap_to_coords(hm_np):
    """Decode per-landmark heatmaps into normalised positions.

    Args:
        hm_np: array of shape (num_joints, height, width). Channel j is
            reported under LM_IDS[j]; channels beyond len(LM_IDS) are ignored.

    Returns:
        dict mapping landmark name -> {"x", "y", "confidence"}.
        - x and y are the peak position divided by the heatmap width/height,
          clipped to [0, 1] and rounded to 4 decimals.
        - confidence is the channel's maximum value, rounded to 3 decimals.
    """
    _, height, width = hm_np.shape
    result = {}
    for name, heat in zip(LM_IDS, hm_np):
        row, col = divmod(int(heat.argmax()), width)
        peak = heat.max()
        is_interior = 1 <= col < width - 1 and 1 <= row < height - 1
        if is_interior:
            # Sub-pixel refinement: move a quarter pixel toward whichever
            # horizontal/vertical neighbour has the larger response.
            dx = float(heat[row, col + 1] - heat[row, col - 1])
            dy = float(heat[row + 1, col] - heat[row - 1, col])
            x_px = col + 0.25 * np.sign(dx)
            y_px = row + 0.25 * np.sign(dy)
        else:
            # Border peaks lack a neighbour on one side; keep integer position.
            x_px, y_px = float(col), float(row)
        result[name] = {
            "x": round(float(np.clip(x_px / width, 0, 1)), 4),
            "y": round(float(np.clip(y_px / height, 0, 1)), 4),
            "confidence": round(float(peak), 3),
        }
    return result
def run_detection(pil_img):
    """Run the HRNet model on one image and decode its landmark heatmaps.

    Args:
        pil_img: a PIL image; ``preprocess`` converts it to RGB and resizes it.

    Returns:
        The landmark dict produced by ``heatmap_to_coords``.
    """
    tensor = preprocess(pil_img)  # batch of one image
    with torch.no_grad():
        out = model(tensor)
    # Select the heatmap tensor first: the first element if the network
    # returns a list/tuple (NOTE(review): confirm that is the intended output
    # for such models), otherwise the tensor itself. Then take batch element 0.
    # Assumes the output is (batch, joints, H, W), so hm is (joints, H, W)
    # either way; previously a list output produced a 4-D array.
    heatmaps = out[0] if isinstance(out, (list, tuple)) else out
    # .cpu() is a no-op while the model stays on CPU; it keeps .numpy() valid
    # if the model is ever moved to a GPU.
    hm = heatmaps[0].cpu().numpy()
    print(f"hm shape:{hm.shape} max:{hm.max():.3f}")
    return heatmap_to_coords(hm)
# -- FastAPI application ----------------------------------------------------
app = FastAPI()

# Permissive CORS: any origin, any method, any header may call this API
# (presumably for a browser front-end served from another origin).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/", response_class=HTMLResponse)
async def root():
    """Return a one-line HTML banner showing the service is running."""
    # The dash is written as an HTML entity: the literal character had been
    # mis-decoded into "β", and the entity renders correctly whatever the
    # source file's encoding.
    return "<h2>OrthoTimes Landmark Detection &mdash; Running</h2>"
@app.get("/health")
async def health():
    """Health probe: confirm the service is up and report how many landmarks it emits."""
    landmark_count = len(LM_IDS)
    payload = {"status": "ok", "landmarks": landmark_count}
    return payload
# Serialise forward passes. Each request runs inference in a worker thread so
# the event loop stays free; this lock still lets only one pass run at a time.
# That bounds peak memory, and torch typically parallelises a single pass
# across CPU cores already.
_inference_lock = threading.Lock()


def _run_detection_serialized(pil_img):
    """Call ``run_detection`` while holding ``_inference_lock``."""
    with _inference_lock:
        return run_detection(pil_img)


def _error_response(exc, status_code):
    """Print the active traceback and return ``{"error": str(exc)}`` with *status_code*."""
    print(traceback.format_exc())
    return JSONResponse({"error": str(exc)}, status_code=status_code)


@app.post("/detect")
async def detect(request: Request):
    """Detect landmarks in a base64-encoded image.

    Expects a JSON object body with a non-empty ``image_b64`` string holding
    the encoded image bytes. A ``data:<mime>;base64,`` prefix is accepted.

    Returns:
        200 ``{"landmarks": {...}}`` on success;
        400 ``{"error": ...}`` when the body, the base64 payload or the image
        cannot be decoded;
        500 ``{"error": ...}`` on any other failure, including inference.
    """
    try:
        body = await request.json()
        image_b64 = body.get("image_b64", "") if isinstance(body, dict) else None
        if not isinstance(image_b64, str) or not image_b64:
            raise ValueError(
                "request body must be a JSON object with a non-empty "
                "'image_b64' string"
            )
        # Non-validating b64decode drops ':', ';' and ',' but keeps letters
        # such as "dataimage/pngbase64". A data-URL header would therefore
        # corrupt the payload, so keep only what follows the first comma.
        if image_b64.startswith("data:"):
            image_b64 = image_b64.partition(",")[2]
        img_bytes = base64.b64decode(image_b64)
        pil_img = Image.open(io.BytesIO(img_bytes)).convert('RGB')
    except (ValueError, OSError) as e:
        # Client errors that land here:
        # - malformed JSON (JSONDecodeError is a ValueError);
        # - bad base64 (binascii.Error is a ValueError);
        # - undecodable images (PIL.UnidentifiedImageError is an OSError).
        return _error_response(e, 400)
    except Exception as e:
        return _error_response(e, 500)

    try:
        # Inference is CPU-bound; running it in the threadpool keeps the event
        # loop (and e.g. /health) responsive while it runs.
        coords = await run_in_threadpool(_run_detection_serialized, pil_img)
        return JSONResponse({"landmarks": coords})
    except Exception as e:
        return _error_response(e, 500)
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all network
    # interfaces, port 7860.
    bind_host, bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=bind_host, port=bind_port)