Spaces: Running
| import sys, os, subprocess, io, json, base64 | |
| import numpy as np | |
| from PIL import Image | |
| from fastapi import FastAPI, Request | |
| from fastapi.responses import JSONResponse, HTMLResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import uvicorn | |
| import torch | |
| import torch.nn as nn | |
| from huggingface_hub import hf_hub_download | |
# ── Clone the official HRNet pose-estimation repo (first boot only) ──────────
REPO = '/app/hrnet_repo'
if not os.path.exists(REPO):
    print("Cloning HRNet...")
    clone_cmd = [
        'git', 'clone', '--depth=1',
        'https://github.com/HRNet/HRNet-Human-Pose-Estimation.git',
        REPO,
    ]
    # List form (shell=False) — no shell injection surface.
    subprocess.run(clone_cmd, check=True)
    print("Cloned.")
# Make the repo's lib/ directory (config, models packages) importable.
sys.path.insert(0, os.path.join(REPO, 'lib'))
# ── Write config file and load model ─────────────────────────────────────────
# NOTE(review): the pasted source had lost all YAML indentation, which makes
# the config flat and unusable by yacs.  The nesting below is reconstructed to
# match the official HRNet experiment-config layout (CUDNN / MODEL / EXTRA /
# STAGE* hierarchy) — confirm against the training-time config.
cfg_path = '/app/hrnet_w32.yaml'
with open(cfg_path, 'w') as f:
    f.write("""
AUTO_RESUME: false
CUDNN:
  BENCHMARK: true
  DETERMINISTIC: false
  ENABLED: true
DATA_DIR: ''
GPUS: (0,)
OUTPUT_DIR: 'output'
LOG_DIR: 'log'
WORKERS: 4
PRINT_FREQ: 100
MODEL:
  NAME: pose_hrnet
  NUM_JOINTS: 19
  PRETRAINED: ''
  TARGET_TYPE: gaussian
  IMAGE_SIZE:
  - 256
  - 320
  HEATMAP_SIZE:
  - 64
  - 80
  SIGMA: 2
  EXTRA:
    PRETRAINED_LAYERS:
    - 'conv1'
    - 'bn1'
    - 'conv2'
    - 'bn2'
    - 'layer1'
    - 'transition1'
    - 'stage2'
    - 'transition2'
    - 'stage3'
    - 'transition3'
    - 'stage4'
    FINAL_CONV_KERNEL: 1
    STAGE2:
      NUM_MODULES: 1
      NUM_BRANCHES: 2
      BLOCK: BASIC
      NUM_BLOCKS:
      - 4
      - 4
      NUM_CHANNELS:
      - 32
      - 64
      FUSE_METHOD: SUM
    STAGE3:
      NUM_MODULES: 4
      NUM_BRANCHES: 3
      BLOCK: BASIC
      NUM_BLOCKS:
      - 4
      - 4
      - 4
      NUM_CHANNELS:
      - 32
      - 64
      - 128
      FUSE_METHOD: SUM
    STAGE4:
      NUM_MODULES: 3
      NUM_BRANCHES: 4
      BLOCK: BASIC
      NUM_BLOCKS:
      - 4
      - 4
      - 4
      - 4
      NUM_CHANNELS:
      - 32
      - 64
      - 128
      - 256
      FUSE_METHOD: SUM
""")
# yacs config node from the cloned repo's lib/config package (on sys.path).
from config import cfg as _cfg
_cfg.defrost()
_cfg.merge_from_file(cfg_path)
_cfg.freeze()
from models.pose_hrnet import get_pose_net

# Build the network from the frozen config; is_train=False skips the
# pretrained-weight initialization path.
model = get_pose_net(_cfg, is_train=False)
print("Built model with official HRNet.")

# Load weights
print("Loading weights...")
model_path = hf_hub_download(
    repo_id="cwlachap/hrnet-cephalometric-landmark-detection",
    filename="best_model.pth"
)
# SECURITY: weights_only=False unpickles arbitrary objects — acceptable only
# because the checkpoint comes from this fixed, known HF repo.
checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
# Checkpoints wrap the weights under different keys; unwrap the first match,
# otherwise assume the checkpoint IS the state dict.
state_dict = checkpoint
for wrapper_key in ("model_state_dict", "state_dict"):
    if wrapper_key in checkpoint:
        state_dict = checkpoint[wrapper_key]
        break
# strict=False tolerates key mismatches; report them for debugging.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"Missing: {len(missing)}, Unexpected: {len(unexpected)}")
if missing: print(f"Missing sample: {missing[:2]}")
model.eval()
print("Model ready!")
# ── Constants ────────────────────────────────────────────────────────────────
# 19 landmark identifiers, in the order of the model's heatmap channels
# (presumably standard cephalometric point abbreviations — S, N, Or, ... —
# matching MODEL.NUM_JOINTS: 19 in the config).
LM_IDS = ['S','N','Or','Po','ANS','PNS','A','U1tip','L1tip','B',
'Pog','Me','Gn','Go','Co','L1ap','U1ap','U6','L6']
# Network input resolution (width, height); must agree with MODEL.IMAGE_SIZE.
INPUT_W, INPUT_H = 256, 320
def preprocess(pil_img):
    """Convert a PIL image to a normalized (1, 3, H, W) float32 tensor."""
    resized = pil_img.convert('RGB').resize((INPUT_W, INPUT_H), Image.BILINEAR)
    pixels = np.array(resized, dtype=np.float32) / 255.0
    # ImageNet mean/std normalization.  The Python-list arithmetic promotes to
    # float64; the final .float() casts back to float32, matching the model.
    pixels = (pixels - [0.485,0.456,0.406]) / [0.229,0.224,0.225]
    chw = torch.from_numpy(pixels).permute(2, 0, 1)
    return chw.unsqueeze(0).float()
def heatmap_to_coords(hm_np, landmark_names=None):
    """Decode per-joint heatmaps into normalized landmark coordinates.

    Args:
        hm_np: numpy array of shape (num_joints, H, W) of heatmap scores.
        landmark_names: optional sequence of landmark identifiers; defaults
            to the module-level LM_IDS.  (Generalized from the hard-coded
            constant; default behavior is unchanged.)

    Returns:
        dict mapping name -> {"x", "y", "confidence"}; x/y are normalized to
        [0, 1] relative to heatmap size, confidence is the raw peak value.
    """
    names = LM_IDS if landmark_names is None else landmark_names
    coords = {}
    nj, hh, hw = hm_np.shape
    # Only decode as many joints as we have names for (and vice versa).
    for j in range(min(nj, len(names))):
        hm = hm_np[j]
        idx = int(hm.argmax())
        py, px = divmod(idx, hw)
        # Quarter-pixel refinement toward the larger neighbor — standard
        # heatmap decoding; skipped on the border where a neighbor is missing.
        if 1 <= px < hw - 1 and 1 <= py < hh - 1:
            spx = px + 0.25 * np.sign(float(hm[py, px + 1] - hm[py, px - 1]))
            spy = py + 0.25 * np.sign(float(hm[py + 1, px] - hm[py - 1, px]))
        else:
            spx, spy = float(px), float(py)
        coords[names[j]] = {
            "x": round(float(np.clip(spx / hw, 0, 1)), 4),
            "y": round(float(np.clip(spy / hh, 0, 1)), 4),
            "confidence": round(float(hm.max()), 3)
        }
    return coords
def run_detection(pil_img):
    """Run the HRNet model on a PIL image and return decoded landmarks.

    Returns the dict produced by heatmap_to_coords for the first (only)
    batch item.
    """
    tensor = preprocess(pil_img)
    with torch.no_grad():
        out = model(tensor)
    # BUG FIX: the original conditional had two identical branches
    # (`out[0].numpy()` either way), so a list/tuple output produced a 4-D
    # (batch, joints, H, W) array that heatmap_to_coords cannot unpack.
    # Select the output tensor first, then strip the batch dimension.
    raw = out[0] if isinstance(out, (list, tuple)) else out
    hm = raw[0].numpy()
    print(f"hm shape:{hm.shape} max:{hm.max():.3f}")
    return heatmap_to_coords(hm)
# ── FastAPI ──────────────────────────────────────────────────────────────────
app = FastAPI()
# Wide-open CORS so the browser front-end can call this Space from any origin.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

# NOTE(review): the handlers below had no route decorators, so FastAPI never
# registered them (and HTMLResponse was imported but unused).  Decorators
# reconstructed from each handler's behavior — TODO confirm the exact paths
# against the front-end client.

@app.get("/", response_class=HTMLResponse)
async def root():
    # "β" below is likely a mojibake'd dash from the paste — confirm original.
    return "<h2>OrthoTimes Landmark Detection β Running</h2>"

@app.get("/health")
async def health():
    # Lightweight liveness probe; reports how many landmarks the model emits.
    return {"status":"ok","landmarks":len(LM_IDS)}

@app.post("/detect")
async def detect(request: Request):
    """Accept {"image_b64": <base64 image>} and return detected landmarks."""
    try:
        body = await request.json()
        img_bytes = base64.b64decode(body.get("image_b64",""))
        pil_img = Image.open(io.BytesIO(img_bytes)).convert('RGB')
        coords = run_detection(pil_img)
        return JSONResponse({"landmarks": coords})
    except Exception as e:
        # Top-level boundary: log the full traceback, return a 500 payload.
        import traceback
        print(traceback.format_exc())
        return JSONResponse({"error": str(e)}, status_code=500)
if __name__ == "__main__":
    # Bind all interfaces on 7860 — presumably the port the hosting Space
    # expects (7860 is the Hugging Face Spaces convention).
    uvicorn.run(app, host="0.0.0.0", port=7860)