# Landmark / app.py
# Last commit by mujtaba1212: "Update app.py" (9c90ae6, verified)
import sys, os, subprocess, io, json, base64
import numpy as np
from PIL import Image
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
# ── Clone official HRNet repo ─────────────────────────────────────────────────
# On first start, make a shallow clone of the upstream HRNet repository
# (skipped whenever REPO already exists), then put its lib/ directory at the
# front of sys.path so the HRNet `config` and `models` packages imported
# further down resolve to that checkout.
REPO = '/app/hrnet_repo'
if not os.path.exists(REPO):
    print("Cloning HRNet...")
    clone_cmd = [
        'git', 'clone', '--depth=1',
        'https://github.com/HRNet/HRNet-Human-Pose-Estimation.git',
        REPO,
    ]
    # check=True: a failed clone aborts startup instead of failing later on import.
    subprocess.run(clone_cmd, check=True)
    print("Cloned.")
sys.path.insert(0, os.path.join(REPO, 'lib'))
# ── Write config file and load model ─────────────────────────────────────────
# HRNet-W32 architecture settings, written to disk and merged into the HRNet
# repo's global config object via defrost/merge_from_file/freeze (yacs
# CfgNode API). Nesting is significant: merge_from_file only accepts keys that
# already exist at the level where they appear. So the CUDNN.*, MODEL.* and
# MODEL.EXTRA.* entries must stay indented under their parents.
# NOTE(review): IMAGE_SIZE / HEATMAP_SIZE look like [width, height] (256x320
# matches INPUT_W/INPUT_H used by preprocess). Keep them in sync if changed.
cfg_path = '/app/hrnet_w32.yaml'
with open(cfg_path, 'w', encoding='utf-8') as f:
    f.write("""
AUTO_RESUME: false
CUDNN:
  BENCHMARK: true
  DETERMINISTIC: false
  ENABLED: true
DATA_DIR: ''
GPUS: (0,)
OUTPUT_DIR: 'output'
LOG_DIR: 'log'
WORKERS: 4
PRINT_FREQ: 100
MODEL:
  NAME: pose_hrnet
  NUM_JOINTS: 19
  PRETRAINED: ''
  TARGET_TYPE: gaussian
  IMAGE_SIZE:
  - 256
  - 320
  HEATMAP_SIZE:
  - 64
  - 80
  SIGMA: 2
  EXTRA:
    PRETRAINED_LAYERS:
    - 'conv1'
    - 'bn1'
    - 'conv2'
    - 'bn2'
    - 'layer1'
    - 'transition1'
    - 'stage2'
    - 'transition2'
    - 'stage3'
    - 'transition3'
    - 'stage4'
    FINAL_CONV_KERNEL: 1
    STAGE2:
      NUM_MODULES: 1
      NUM_BRANCHES: 2
      BLOCK: BASIC
      NUM_BLOCKS:
      - 4
      - 4
      NUM_CHANNELS:
      - 32
      - 64
      FUSE_METHOD: SUM
    STAGE3:
      NUM_MODULES: 4
      NUM_BRANCHES: 3
      BLOCK: BASIC
      NUM_BLOCKS:
      - 4
      - 4
      - 4
      NUM_CHANNELS:
      - 32
      - 64
      - 128
      FUSE_METHOD: SUM
    STAGE4:
      NUM_MODULES: 3
      NUM_BRANCHES: 4
      BLOCK: BASIC
      NUM_BLOCKS:
      - 4
      - 4
      - 4
      - 4
      NUM_CHANNELS:
      - 32
      - 64
      - 128
      - 256
      FUSE_METHOD: SUM
""")
from config import cfg as _cfg
_cfg.defrost()
_cfg.merge_from_file(cfg_path)
_cfg.freeze()
from models.pose_hrnet import get_pose_net
model = get_pose_net(_cfg, is_train=False)
print("Built model with official HRNet.")


def _extract_state_dict(checkpoint):
    """Return the parameter mapping held by a loaded checkpoint object.

    Accepts any of:
    - a bare state dict;
    - a dict that wraps one under ``"model_state_dict"`` or ``"state_dict"``
      (checked in that order);
    - a whole pickled ``nn.Module``.

    If every key carries a ``module.`` prefix (added when a model wrapped in
    ``nn.DataParallel`` / ``DistributedDataParallel`` is saved), the prefix is
    stripped so the keys line up with the bare HRNet. Otherwise the
    ``strict=False`` load below would silently skip every tensor.

    Raises:
        TypeError: if *checkpoint* is neither a dict nor an ``nn.Module``.
    """
    if isinstance(checkpoint, nn.Module):
        return checkpoint.state_dict()
    if not isinstance(checkpoint, dict):
        raise TypeError(
            f"Unsupported checkpoint type: {type(checkpoint).__name__}")
    state = checkpoint.get("model_state_dict",
                           checkpoint.get("state_dict", checkpoint))
    prefix = "module."
    if state and all(isinstance(k, str) and k.startswith(prefix) for k in state):
        state = {k[len(prefix):]: v for k, v in state.items()}
    return state


# Load weights
print("Loading weights...")
model_path = hf_hub_download(
    repo_id="cwlachap/hrnet-cephalometric-landmark-detection",
    filename="best_model.pth",
)
# SECURITY NOTE(review): weights_only=False fully unpickles the file.
# A tampered checkpoint in the remote Hub repo could therefore execute code
# at startup. Presumably kept because the checkpoint holds non-tensor entries.
# Consider pinning `revision=` in hf_hub_download, or switching to
# weights_only=True if the file loads that way.
checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
state_dict = _extract_state_dict(checkpoint)
# strict=False tolerates partial matches; mismatches are reported, not raised.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"Missing: {len(missing)}, Unexpected: {len(unexpected)}")
if missing:
    print(f"Missing sample: {missing[:2]}")
if len(missing) == len(model.state_dict()):
    # Every model key is missing, i.e. not a single checkpoint tensor was used.
    print("WARNING: no checkpoint tensors matched the model; "
          "predictions will come from randomly initialised weights.")
model.eval()
print("Model ready!")
# ── Constants ─────────────────────────────────────────────────────────────────
# Landmark names, one per model output channel: heatmap_to_coords pairs
# channel k with LM_IDS[k].
# NOTE(review): the order must match the joint order the checkpoint was
# trained with; confirm against the training data. The abbreviations look
# like standard cephalometric points (Sella, Nasion, Orbitale, Porion, ...),
# but that reading is inferred from the names only.
LM_IDS = (
    "S N Or Po ANS PNS A U1tip L1tip B "
    "Pog Me Gn Go Co L1ap U1ap U6 L6"
).split()
# Network input resolution in pixels; preprocess resizes every image to
# (INPUT_W, INPUT_H).
INPUT_W = 256
INPUT_H = 320
def preprocess(pil_img):
    """Turn a PIL image into a normalised model input batch of size one.

    The image is converted to RGB and bilinearly resized (without keeping the
    aspect ratio) to INPUT_W x INPUT_H. Pixels are scaled to [0, 1] and then
    standardised per channel with the usual ImageNet mean/std.

    Returns:
        A float32 tensor of shape (1, 3, INPUT_H, INPUT_W).
    """
    rgb = pil_img.convert('RGB')
    resized = rgb.resize((INPUT_W, INPUT_H), Image.BILINEAR)
    pixels = np.array(resized, dtype=np.float32) / 255.0
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])
    standardised = (pixels - channel_mean) / channel_std
    # (H, W, C) -> (C, H, W), then prepend the batch axis.
    chw = torch.from_numpy(standardised).permute(2, 0, 1)
    return chw.unsqueeze(0).float()
def heatmap_to_coords(hm_np):
    """Convert per-landmark heatmaps into normalised (x, y) positions.

    ``hm_np`` must be a 3-D array laid out as (channels, height, width); the
    shape is unpacked that way below. Channel ``k`` is paired with
    ``LM_IDS[k]``; surplus channels or surplus names are ignored.

    For each channel the peak cell is located. When the peak is not on the
    border, it is shifted a quarter cell toward the larger of its two
    neighbours along each axis (no shift when they are equal). Positions are
    divided by the heatmap width/height, clipped to [0, 1] and rounded to 4
    decimals. The peak value is reported as ``confidence``, rounded to 3
    decimals.

    Returns:
        dict mapping landmark name -> {"x", "y", "confidence"}.
    """
    _, height, width = hm_np.shape
    result = {}
    for name, channel in zip(LM_IDS, hm_np):
        row, col = divmod(int(channel.argmax()), width)
        interior = 1 <= col < width - 1 and 1 <= row < height - 1
        if interior:
            dx = float(channel[row, col + 1] - channel[row, col - 1])
            dy = float(channel[row + 1, col] - channel[row - 1, col])
            x = col + 0.25 * np.sign(dx)
            y = row + 0.25 * np.sign(dy)
        else:
            x, y = float(col), float(row)
        result[name] = {
            "x": round(float(np.clip(x / width, 0, 1)), 4),
            "y": round(float(np.clip(y / height, 0, 1)), 4),
            "confidence": round(float(channel.max()), 3),
        }
    return result
def run_detection(pil_img):
    """Run the landmark model on one PIL image and return its landmarks.

    Returns the mapping built by ``heatmap_to_coords``: landmark name ->
    {"x", "y", "confidence"}, with x/y normalised to [0, 1].
    """
    tensor = preprocess(pil_img)
    with torch.no_grad():
        out = model(tensor)
    # Some networks return several output tensors (list/tuple) instead of one.
    # Previously both arms of this choice were `out[0].numpy()`, which for a
    # list kept the batch axis and produced a 4-D array that
    # heatmap_to_coords cannot unpack. Select the output tensor first, then
    # drop the batch axis.
    # NOTE(review): for a multi-output model, confirm out[0] (the element the
    # original code picked), and not e.g. out[-1], is the LM_IDS heatmap head.
    heatmaps = out[0] if isinstance(out, (list, tuple)) else out
    hm = heatmaps[0].cpu().numpy()
    print(f"hm shape:{hm.shape} max:{hm.max():.3f}")
    return heatmap_to_coords(hm)
# ── FastAPI ───────────────────────────────────────────────────────────────────
app = FastAPI()
# Let browser clients on any origin call every route, with any HTTP method
# and any request header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve a one-line HTML banner confirming the service is running."""
    # The separator is a real em dash (U+2014). The literal previously held
    # its mis-decoded UTF-8 bytes ("β€”"), which rendered as garbage.
    return "<h2>OrthoTimes Landmark Detection — Running</h2>"
@app.get("/health")
async def health():
    """Liveness probe: report "ok" and how many landmark names are served."""
    landmark_count = len(LM_IDS)
    return dict(status="ok", landmarks=landmark_count)
@app.post("/detect")
async def detect(request: Request):
    """Detect landmarks in a base64-encoded image posted as JSON.

    Request body: a JSON object with an ``image_b64`` field holding the
    base64-encoded image bytes. A data-URL header such as
    ``data:image/png;base64,`` is accepted and stripped.

    Responses:
        200 ``{"landmarks": {...}}``: the mapping from ``run_detection``.
        400 ``{"error": msg}``: the body or image could not be decoded.
        500 ``{"error": msg}``: inference itself failed.
    """
    import traceback

    from fastapi.concurrency import run_in_threadpool

    # Stage 1: parse and decode client input. Failures here are caused by the
    # request, so they are reported as 400 rather than 500.
    try:
        body = await request.json()
        if not isinstance(body, dict):
            raise ValueError("request body must be a JSON object")
        image_b64 = body.get("image_b64", "")
        if not isinstance(image_b64, str) or not image_b64:
            raise ValueError("'image_b64' must be a non-empty base64 string")
        if image_b64.startswith("data:"):
            # b64decode would otherwise fold the header letters into the
            # payload and yield corrupt image bytes.
            image_b64 = image_b64.partition(",")[2]
        img_bytes = base64.b64decode(image_b64)
        pil_img = Image.open(io.BytesIO(img_bytes)).convert('RGB')
    except Exception as e:
        print(traceback.format_exc())
        return JSONResponse({"error": str(e)}, status_code=400)

    # Stage 2: inference is CPU-bound. Run it in a worker thread so the event
    # loop (and routes such as /health) stays responsive meanwhile.
    try:
        coords = await run_in_threadpool(run_detection, pil_img)
    except Exception as e:
        print(traceback.format_exc())
        return JSONResponse({"error": str(e)}, status_code=500)
    return JSONResponse({"landmarks": coords})
if __name__ == "__main__":
    # Script entry point: serve the app on every interface (0.0.0.0), port 7860
    # (presumably the port the hosting platform expects; confirm before changing).
    _HOST, _PORT = "0.0.0.0", 7860
    uvicorn.run(app, host=_HOST, port=_PORT)