# SkinClassify Hugging Face Space: FastAPI + Gradio app serving a
# derm-foundation embedding backbone with an MLP classification head.
| import os, io, json, logging | |
| from typing import List, Dict, Any | |
| import numpy as np | |
| from fastapi import FastAPI, UploadFile, File, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from PIL import Image | |
| import tensorflow as tf | |
| from huggingface_hub import snapshot_download,hf_hub_download | |
| import cv2 | |
| import gradio as gr | |
# Globals reserved for a CNN + Grad-CAM path (not populated in this file).
cnn_model = None
last_conv_layer_name = None

# Optional gatekeeping (the blur check) needs OpenCV. The original wrapped a
# bare `HAS_OPENCV = True` assignment in try/except, so the except branch was
# dead code and the flag never reflected cv2's availability; probe cv2 here
# so a missing/broken install degrades gracefully.
try:
    HAS_OPENCV = cv2 is not None
except NameError:  # cv2 import failed or was removed
    HAS_OPENCV = False

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("skinclassify")
# ---------------------- Config ----------------------
# Backbone repo id; gated model -- requires accepting terms + HF_TOKEN.
DERM_MODEL_ID = os.getenv("DERM_MODEL_ID", "google/derm-foundation")
DERM_LOCAL_DIR = os.getenv("DERM_LOCAL_DIR", "")
# Repo holding the trained classifier head and its preprocessing artifacts.
MODEL_REPO = "ChantaroNtw/Skin-model"
# NOTE(review): these downloads run at import time, so startup requires
# network access to the hub and fails hard if it is unreachable.
HEAD_PATH = hf_hub_download(
    repo_id=MODEL_REPO,
    filename="mlp_best.keras"  # MLP head applied to derm-foundation embeddings
)
MU_PATH = hf_hub_download(
    repo_id=MODEL_REPO,
    filename="mu.npy"  # per-dimension embedding mean (standardization)
)
SD_PATH = hf_hub_download(
    repo_id=MODEL_REPO,
    filename="sd.npy"  # per-dimension embedding std (standardization)
)
THRESHOLDS_PATH = hf_hub_download(
    repo_id=MODEL_REPO,
    filename="mlp_thresholds.npy"  # per-class decision thresholds
)
LABELS_PATH = hf_hub_download(
    repo_id=MODEL_REPO,
    filename="class_names.json"  # ordered class-name list
)
# Optional NPZ bundle that can supply class_names/mu/sd as a fallback.
NPZ_PATH = os.getenv("NPZ_PATH", "")
TOPK = int(os.getenv("TOPK", "5"))
# Gate keep params (image-quality thresholds; all overridable via env).
MIN_W, MIN_H = int(os.getenv("MIN_W", "128")), int(os.getenv("MIN_H", "128"))
MIN_ASPECT, MAX_ASPECT = float(os.getenv("MIN_ASPECT", "0.5")), float(os.getenv("MAX_ASPECT", "2.0"))
MIN_BRIGHT, MAX_BRIGHT = float(os.getenv("MIN_BRIGHT", "20")), float(os.getenv("MAX_BRIGHT", "235"))
MIN_SKIN_RATIO = float(os.getenv("MIN_SKIN_RATIO", "0.15"))
MIN_SHARPNESS = float(os.getenv("MIN_SHARPNESS", "30.0"))
# Performance: limit threading to avoid OOM on the free Space tier.
# NOTE(review): TensorFlow is already imported at the top of this file, so
# the TF_* thread env vars below are read too late to take effect -- confirm
# and move them before the `import tensorflow` if thread limiting is needed.
os.environ.setdefault("TF_NUM_INTRAOP_THREADS", "1")
os.environ.setdefault("TF_NUM_INTEROP_THREADS", "1")
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
MAX_UPLOAD = int(os.getenv("MAX_UPLOAD", str(6 * 1024 * 1024)))  # 6MB
DF_SIZE = (448, 448)  # input size expected by derm-foundation
# FastAPI application; handlers are defined below.
app = FastAPI(title="SkinClassify API (Derm-Foundation)", version="2.0.0")
# NOTE(review): allow_credentials=True combined with the default wildcard
# origin is rejected by browsers per the CORS spec -- confirm ALLOW_ORIGINS
# is set to explicit origins when credentials are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=os.getenv("ALLOW_ORIGINS", "*").split(","),
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
| # ---------------------- Load labels ---------------------- | |
| def _load_json(path): | |
| with open(path, "r", encoding="utf-8") as f: | |
| return json.load(f) | |
# Resolve CLASS_NAMES: prefer the JSON downloaded from the model repo, fall
# back to a "class_names" array inside the optional NPZ bundle.
if os.path.exists(LABELS_PATH):
    CLASS_NAMES: List[str] = _load_json(LABELS_PATH)
    logger.info(f"Loaded class_names from {LABELS_PATH}")
elif NPZ_PATH and os.path.exists(NPZ_PATH):
    # NOTE(review): allow_pickle on an operator-supplied path -- trusted input only.
    arr = np.load(NPZ_PATH, allow_pickle=True)
    if "class_names" in arr:
        CLASS_NAMES = list(arr["class_names"])
        logger.info(f"Loaded class_names from {NPZ_PATH}:class_names")
    else:
        raise RuntimeError("No LABELS_PATH and class_names not found in NPZ")
else:
    raise RuntimeError("LABELS_PATH not found and NPZ_PATH not provided.")
C = len(CLASS_NAMES)  # number of classes; used to validate artifact shapes
| # ---------------------- Load head (.keras via Keras3) ---------------------- | |
def load_head_keras3(path: str):
    """Deserialize the classifier head from a Keras 3 `.keras` archive.

    Loaded with compile=False since only inference is performed here (no
    optimizer state is needed).
    """
    import keras

    logger.info(f"Loading head (.keras) via Keras3 from {path}")
    model = keras.saving.load_model(path, compile=False)
    return model

head = load_head_keras3(HEAD_PATH)
| # ---------------------- Load mu/sd ---------------------- | |
def _load_mu_sd():
    """Fetch the embedding standardization stats (mean, std) as float32.

    Tries the dedicated mu.npy/sd.npy artifacts first, then falls back to
    the optional NPZ bundle; raises when neither source is available.
    """
    if os.path.exists(MU_PATH) and os.path.exists(SD_PATH):
        return (
            np.load(MU_PATH).astype("float32"),
            np.load(SD_PATH).astype("float32"),
        )
    if NPZ_PATH and os.path.exists(NPZ_PATH):
        bundle = np.load(NPZ_PATH, allow_pickle=True)
        return bundle["mu"].astype("float32"), bundle["sd"].astype("float32")
    raise RuntimeError("mu/sd not found (MU_PATH/SD_PATH or NPZ_PATH).")

mu, sd = _load_mu_sd()
logger.info("Loaded mu/sd")
| # ---------------------- Load thresholds ---------------------- | |
# Per-class decision thresholds for the multilabel head; defaults to 0.5.
if os.path.exists(THRESHOLDS_PATH):
    best_th = np.load(THRESHOLDS_PATH).astype("float32")
    # Guard against stale artifacts: thresholds must align with the class list.
    if best_th.shape[0] != C:
        raise RuntimeError(f"thresholds size {best_th.shape[0]} != #classes {C}")
else:
    logger.warning("THRESHOLDS_PATH not found -> default 0.5 for all classes")
    best_th = np.full(C, 0.5, dtype="float32")
# ---------------------- Load derm-foundation ----------------------
# (snapshot_download is already imported at the top of the file; the original
# re-imported it here redundantly.)
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
CACHE_DIR = os.getenv("HF_HOME", "/app/.cache")
LOCAL_DERM = os.getenv("DERM_LOCAL_DIR", "/app/derm-foundation")
os.makedirs(CACHE_DIR, exist_ok=True)
os.makedirs(LOCAL_DERM, exist_ok=True)
logger.info("Loading Derm Foundation (first time may take a while)...")
try:
    # Reuse a local copy when a SavedModel is already present; otherwise pull
    # only the SavedModel files (not the whole repo) from the hub.
    if os.path.isdir(LOCAL_DERM) and os.path.exists(os.path.join(LOCAL_DERM, "saved_model.pb")):
        derm_dir = LOCAL_DERM
        logger.info(f"Loaded Derm Foundation from local: {derm_dir}")
    else:
        logger.info(f"Downloading derm-foundation from hub: {DERM_MODEL_ID}")
        derm_dir = snapshot_download(
            repo_id=DERM_MODEL_ID,
            repo_type="model",
            allow_patterns=["saved_model.pb", "variables/*"],
            token=HF_TOKEN,
            cache_dir=CACHE_DIR,
            local_dir=LOCAL_DERM,
            # NOTE(review): deprecated in recent huggingface_hub releases -- confirm.
            local_dir_use_symlinks=False,
        )
        logger.info(f"Derm Foundation downloaded to: {derm_dir}")
    derm = tf.saved_model.load(derm_dir)
    infer = derm.signatures["serving_default"]
except Exception as e:
    # Chain the original exception so the underlying hub/TF error is preserved
    # in the traceback (the original raise dropped the cause).
    raise RuntimeError(
        f"Failed to load derm-foundation: {e}. "
        "Make sure you accepted the model terms and set HF_TOKEN in Space Settings."
    ) from e
| import tempfile | |
def create_tf_example(img_arr):
    """Serialize an RGB float image (values in [0, 1]) into a tf.train.Example.

    The derm-foundation SavedModel signature consumes serialized Examples
    holding a JPEG-encoded image under the "image/encoded" key.
    """
    # Clip before the uint8 cast: values marginally outside [0, 1] (e.g. from
    # float rounding) would otherwise wrap around during the cast.
    img_uint8 = np.clip(img_arr * 255.0, 0, 255).astype(np.uint8)
    encoded = tf.io.encode_jpeg(img_uint8).numpy()
    feature = {
        "image/encoded": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[encoded])
        )
    }
    example = tf.train.Example(
        features=tf.train.Features(feature=feature)
    )
    return example.SerializeToString()
def get_embedding(img_arr):
    """Run the derm-foundation signature on one image; return its embedding vector."""
    serialized = create_tf_example(img_arr)
    batch = tf.constant([serialized])
    result = infer(inputs=batch)
    return result["embedding"].numpy()[0]
def cosine_similarity(a, b):
    """Cosine similarity between vectors a and b (epsilon-guarded denominator)."""
    denom = np.linalg.norm(a) * np.linalg.norm(b) + 1e-8
    return np.dot(a, b) / denom
def make_patch_heatmap(img, size=32, stride=16):
    """Occlusion-sensitivity heatmap: zero out square patches and measure how
    far the embedding moves from the unoccluded baseline.

    Args:
        img: PIL image; resized to 224x224 here.
        size: side length of the occluded square, in pixels.
        stride: step between patch origins; stride < size means patches overlap.

    Returns:
        A (224, 224) float map min-max normalized to [0, 1].
    """
    img = img.resize((224, 224))
    img_arr = np.array(img) / 255.0
    # Reference embedding of the unoccluded image.
    base_emb = get_embedding(img_arr)
    examples = []
    coords = []
    for y in range(0, 224, stride):
        for x in range(0, 224, stride):
            occluded = img_arr.copy()
            occluded[y:y+size, x:x+size] = 0
            examples.append(create_tf_example(occluded))
            coords.append((y, x))
    # One batched signature call for every occlusion (14*14 = 196 examples at
    # the defaults). NOTE(review): may be memory-heavy on small instances.
    tensor = tf.constant(examples)
    outputs = infer(inputs=tensor)["embedding"].numpy()
    heatmap = np.zeros((224, 224))
    for i, (y, x) in enumerate(coords):
        diff = np.linalg.norm(base_emb - outputs[i])
        # NOTE(review): overlapping patches overwrite (not accumulate) in the
        # overlap region, so the last-written patch wins -- confirm intended.
        heatmap[y:y+size, x:x+size] = diff
    heatmap = cv2.normalize(heatmap, None, 0, 1, cv2.NORM_MINMAX)
    return heatmap
| # ---------------------- Utils ---------------------- | |
def pil_to_png_bytes_448(pil_img: Image.Image) -> bytes:
    """Encode a PIL image as PNG bytes at the 448x448 size derm-foundation expects."""
    resized = pil_img.convert("RGB").resize(DF_SIZE)
    pixels = np.asarray(resized, dtype=np.uint8)
    return tf.io.encode_png(pixels).numpy()
| def _brightness(np_img_rgb: np.ndarray) -> float: | |
| r,g,b = np_img_rgb[...,0], np_img_rgb[...,1], np_img_rgb[...,2] | |
| y = 0.2126*r + 0.7152*g + 0.0722*b | |
| return float(y.mean()) | |
def _sharpness(np_img_rgb: np.ndarray) -> float:
    """Variance of the Laplacian as a focus measure (higher = sharper).

    Falls back to a permissive constant when OpenCV is unavailable so the
    gate never rejects for blur without a real measurement.
    """
    if not HAS_OPENCV:
        return 100.0
    gray = cv2.cvtColor(np_img_rgb, cv2.COLOR_RGB2GRAY)
    laplacian = cv2.Laplacian(gray, cv2.CV_64F)
    return float(laplacian.var())
def _skin_ratio(np_img_rgb: np.ndarray) -> float:
    """Fraction of pixels whose Cb/Cr values fall in a classic skin-tone box."""
    ycbcr = np.array(Image.fromarray(np_img_rgb).convert("YCbCr"))
    cb, cr = ycbcr[..., 1], ycbcr[..., 2]
    skin = (cb >= 77) & (cb <= 127) & (cr >= 133) & (cr <= 173)
    return float(skin.mean())
def gatekeep_image(img_bytes: bytes) -> Dict[str, Any]:
    """Run cheap image-quality checks before spending inference time.

    Returns a dict with:
        ok: True when no rejection reason fired.
        reasons: tags for each failed check (e.g. "too_small", "too_dark").
        metrics: measured values, for debugging / UI display.
    """
    try:
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except Exception:
        return {"ok": False, "reasons": ["invalid_image"], "metrics": {}}

    w, h = img.size
    metrics: Dict[str, Any] = {"width": w, "height": h}
    reasons: List[str] = []

    # Resolution check.
    if w < MIN_W or h < MIN_H:
        reasons.append("too_small")

    # Aspect-ratio check.
    aspect = w / h
    metrics["aspect"] = float(aspect)
    if not (MIN_ASPECT <= aspect <= MAX_ASPECT):
        reasons.append("weird_aspect")

    np_img = np.array(img)

    # Exposure check (Rec. 709 luma).
    bright = _brightness(np_img)
    metrics["brightness"] = bright
    if bright < MIN_BRIGHT:
        reasons.append("too_dark")
    if bright > MAX_BRIGHT:
        reasons.append("too_bright")

    # Blur check, only when OpenCV is present.
    if HAS_OPENCV:
        sharp = _sharpness(np_img)
        metrics["sharpness"] = sharp
        if sharp < MIN_SHARPNESS:
            reasons.append("too_blurry")

    # Skin-tone coverage check.
    ratio = _skin_ratio(np_img)
    metrics["skin_ratio"] = ratio
    if ratio < MIN_SKIN_RATIO:
        reasons.append("not_skin_like")

    return {"ok": not reasons, "reasons": reasons, "metrics": metrics}
def predict_probs(img_bytes: bytes) -> np.ndarray:
    """Embed the image with derm-foundation and score it with the MLP head.

    Returns the per-class probability vector (shape (C,)).
    """
    # pil_to_png_bytes_448 already resizes to DF_SIZE; the extra resize the
    # original code did here was redundant work on the same target size.
    pil = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    by = pil_to_png_bytes_448(pil)
    ex = tf.train.Example(features=tf.train.Features(
        feature={'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[by]))}
    )).SerializeToString()
    out = infer(inputs=tf.constant([ex]))
    if "embedding" not in out:
        raise RuntimeError(f"Unexpected derm-foundation outputs: {list(out.keys())}")
    emb = out["embedding"].numpy().astype("float32")  # (1, 6144)
    # Standardize with training-time stats before applying the head.
    z = (emb - mu) / (sd + 1e-6)
    probs = head.predict(z, verbose=0)[0]  # run the .keras head directly
    return probs
| # ---------------------- Endpoints ---------------------- | |
def health():
    """Liveness payload: class count, backbone source, and cv2 availability.

    NOTE(review): not registered on the FastAPI app (no @app.get decorator);
    confirm whether the HTTP API or only the Gradio UI is in use.
    """
    derm_source = DERM_MODEL_ID or DERM_LOCAL_DIR
    return {
        "ok": True,
        "classes": len(CLASS_NAMES),
        "derm": derm_source,
        "has_opencv": HAS_OPENCV,
    }
async def predict(request: Request, file: UploadFile = File(...)):
    """HTTP prediction handler: size-check and gate-check the upload, then score it.

    Raises HTTP 413 when the upload exceeds MAX_UPLOAD. Gate rejections
    return a 200 with ok=False so clients can display the failed checks.
    NOTE(review): no @app.post decorator is attached, so this handler is not
    registered on the FastAPI app -- confirm whether the HTTP API is in use.
    """
    # Cheap early rejection via the declared Content-Length, then a hard
    # check on the bytes actually read (the header can lie or be absent).
    cl = request.headers.get("content-length")
    if cl and int(cl) > MAX_UPLOAD:
        raise HTTPException(413, "File too large")
    img_bytes = await file.read()
    if len(img_bytes) > MAX_UPLOAD:
        raise HTTPException(413, "File too large")

    image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    gate = gatekeep_image(img_bytes)
    if not gate["ok"]:
        return JSONResponse(status_code=200, content={
            "ok": False,
            "reason": "gate_reject",
            "gate": gate
        })

    probs = predict_probs(img_bytes)
    order = np.argsort(probs)[::-1]
    top = [{"label": CLASS_NAMES[i], "prob": float(probs[i])} for i in order[:TOPK]]
    # Multilabel decision: compare each class against its own threshold.
    preds = (probs >= best_th).astype(np.int32)
    positives = [{"label": CLASS_NAMES[i], "prob": float(probs[i])}
                 for i in range(C) if preds[i] == 1]

    # BUG FIX: the original called make_gradcam_heatmap, which is not defined
    # anywhere in this module and raised NameError on every successful
    # request; use the occlusion-based heatmap that is defined here.
    heatmap = make_patch_heatmap(image)
    overlay = overlay_heatmap(image, heatmap)
    return {
        "ok": True,
        "gate": gate,
        "result": {
            "type": "multilabel",
            "thresholds_used": {CLASS_NAMES[i]: float(best_th[i]) for i in range(C)},
            "positives": positives,
            "topk": top,
            "probs": {CLASS_NAMES[i]: float(probs[i]) for i in range(C)}
        },
        "has_heatmap": overlay is not None
    }
| #----------------------------Over_lay------------------------------- | |
def overlay_heatmap(img, heatmap):
    """Blend a JET-colored heatmap onto the image's hottest regions.

    Args:
        img: PIL image (resized to 224x224 to match the heatmap).
        heatmap: (224, 224) float map in [0, 1].

    Returns:
        (224, 224, 3) uint8 RGB array with the overlay applied.
    """
    img = img.resize((224, 224))
    img_np = np.array(img)
    # Smooth, sharpen contrast (gamma > 1), then renormalize to [0, 1].
    heatmap = cv2.GaussianBlur(heatmap, (21, 21), 0)
    heatmap = np.power(heatmap, 1.5)
    heatmap = cv2.normalize(heatmap, None, 0, 1, cv2.NORM_MINMAX)
    heatmap_uint8 = np.uint8(255 * heatmap)
    heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
    # BUG FIX: applyColorMap returns BGR, while the base image is RGB (PIL);
    # the original blend therefore swapped red and blue in the overlay colors.
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
    # Blend only where activation is notably above average.
    threshold = np.mean(heatmap) + np.std(heatmap)
    mask = heatmap > threshold
    overlay = img_np.copy()
    overlay[mask] = (
        0.6 * overlay[mask] +
        0.4 * heatmap_color[mask]
    ).astype(np.uint8)
    return overlay
| #------------------------------UI----------------------------------- | |
def gradio_predict(image):
    """Gradio callback: gate-check, classify, and build an explanation overlay.

    Returns (top-5 label->prob dict, overlay image or None, status markdown).
    """
    if image is None:
        return {}, None, "❗ Upload image first"

    # Re-encode to bytes so the same gate/predict path as the HTTP API runs.
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    img_bytes = buf.getvalue()

    gate = gatekeep_image(img_bytes)
    if not gate["ok"]:
        return {}, None, "❌ Image rejected"

    probs = predict_probs(img_bytes)
    order = np.argsort(probs)[::-1]
    result = {
        CLASS_NAMES[i]: float(probs[i])
        for i in order[:5]
    }

    # Occlusion-sensitivity heatmap, keeping only above-average regions.
    # BUG FIX: the original passed the mask to an undefined refine_heatmap()
    # and raised NameError on every successful prediction; apply the mask to
    # the coarse map directly instead.
    coarse_map = make_patch_heatmap(image)
    mask = coarse_map > np.mean(coarse_map)
    refined_map = np.where(mask, coarse_map, 0.0)
    overlay = overlay_heatmap(image, refined_map)
    return result, overlay, "✅ Done"
# Gradio UI: upload on the left, predictions + heatmap + status on the right.
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 Skin Disease Classifier")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            btn = gr.Button("🔍 Analyze", variant="primary")
        with gr.Column():
            output_label = gr.Label(num_top_classes=5, label="Prediction")
            output_image = gr.Image(label="Heatmap")
            status = gr.Markdown("")
    # BUG FIX: gradio_predict returns three values (label dict, overlay,
    # status text) but the original wired only two outputs, which makes
    # Gradio error at click time; wire the status component as well.
    btn.click(
        gradio_predict,
        inputs=image_input,
        outputs=[output_label, output_image, status],
    )

demo.launch(server_name="0.0.0.0", server_port=7860)