ChantaroNtw commited on
Commit
5d8e5dc
·
verified ·
1 Parent(s): 4c28389

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +311 -0
app.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, io, json, logging
2
+ from typing import List, Dict, Any
3
+
4
+ import numpy as np
5
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Request
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from fastapi.responses import JSONResponse
8
+ from PIL import Image
9
+ import tensorflow as tf
10
+ from huggingface_hub import snapshot_download,hf_hub_download
11
+
12
+ # optional gatekeep
13
+ try:
14
+ import cv2
15
+ HAS_OPENCV = True
16
+ except Exception:
17
+ HAS_OPENCV = False
18
+
19
+ logging.basicConfig(level=logging.INFO)
20
+ logger = logging.getLogger("skinclassify")
21
+
22
+ # ---------------------- Config ----------------------
23
+ DERM_MODEL_ID = os.getenv("DERM_MODEL_ID", "google/derm-foundation")
24
+ DERM_LOCAL_DIR = os.getenv("DERM_LOCAL_DIR", "")
25
+
26
+ MODEL_REPO = "ChantaroNtw/Skin-model"
27
+
28
+ HEAD_PATH = hf_hub_download(
29
+ repo_id=MODEL_REPO,
30
+ filename="mlp_best.keras"
31
+ )
32
+
33
+ MU_PATH = hf_hub_download(
34
+ repo_id=MODEL_REPO,
35
+ filename="mu.npy"
36
+ )
37
+
38
+ SD_PATH = hf_hub_download(
39
+ repo_id=MODEL_REPO,
40
+ filename="sd.npy"
41
+ )
42
+
43
+ THRESHOLDS_PATH = hf_hub_download(
44
+ repo_id=MODEL_REPO,
45
+ filename="mlp_thresholds.npy"
46
+ )
47
+
48
+ LABELS_PATH = hf_hub_download(
49
+ repo_id=MODEL_REPO,
50
+ filename="class_names.json"
51
+ )
52
+
53
+ NPZ_PATH = os.getenv("NPZ_PATH", "")
54
+
55
+ TOPK = int(os.getenv("TOPK", "5"))
56
+
57
+ # Gate keep params
58
+ MIN_W, MIN_H = int(os.getenv("MIN_W", "128")), int(os.getenv("MIN_H", "128"))
59
+ MIN_ASPECT, MAX_ASPECT = float(os.getenv("MIN_ASPECT", "0.5")), float(os.getenv("MAX_ASPECT", "2.0"))
60
+ MIN_BRIGHT, MAX_BRIGHT = float(os.getenv("MIN_BRIGHT", "20")), float(os.getenv("MAX_BRIGHT", "235"))
61
+ MIN_SKIN_RATIO = float(os.getenv("MIN_SKIN_RATIO", "0.15"))
62
+ MIN_SHARPNESS = float(os.getenv("MIN_SHARPNESS", "30.0"))
63
+
64
+ # Performance: กัน OOM บน Free Space
65
+ os.environ.setdefault("TF_NUM_INTRAOP_THREADS", "1")
66
+ os.environ.setdefault("TF_NUM_INTEROP_THREADS", "1")
67
+ os.environ.setdefault("OMP_NUM_THREADS", "1")
68
+ os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
69
+
70
+ MAX_UPLOAD = int(os.getenv("MAX_UPLOAD", str(6 * 1024 * 1024))) # 6MB
71
+ DF_SIZE = (448, 448)
72
+
73
+ app = FastAPI(title="SkinClassify API (Derm-Foundation)", version="2.0.0")
74
+ app.add_middleware(
75
+ CORSMiddleware,
76
+ allow_origins=os.getenv("ALLOW_ORIGINS", "*").split(","),
77
+ allow_credentials=True,
78
+ allow_methods=["*"],
79
+ allow_headers=["*"],
80
+ )
81
+
82
+ # ---------------------- Load labels ----------------------
83
+ def _load_json(path):
84
+ with open(path, "r", encoding="utf-8") as f:
85
+ return json.load(f)
86
+
87
+ if os.path.exists(LABELS_PATH):
88
+ CLASS_NAMES: List[str] = _load_json(LABELS_PATH)
89
+ logger.info(f"Loaded class_names from {LABELS_PATH}")
90
+ elif NPZ_PATH and os.path.exists(NPZ_PATH):
91
+ arr = np.load(NPZ_PATH, allow_pickle=True)
92
+ if "class_names" in arr:
93
+ CLASS_NAMES = list(arr["class_names"])
94
+ logger.info(f"Loaded class_names from {NPZ_PATH}:class_names")
95
+ else:
96
+ raise RuntimeError("No LABELS_PATH and class_names not found in NPZ")
97
+ else:
98
+ raise RuntimeError("LABELS_PATH not found and NPZ_PATH not provided.")
99
+ C = len(CLASS_NAMES)
100
+
101
+ # ---------------------- Load head (.keras via Keras3) ----------------------
102
+ def load_head_keras3(path: str):
103
+ import keras
104
+ logger.info(f"Loading head (.keras) via Keras3 from {path}")
105
+ return keras.saving.load_model(path, compile=False)
106
+
107
+ head = load_head_keras3(HEAD_PATH)
108
+
109
+ # ---------------------- Load mu/sd ----------------------
110
+ def _load_mu_sd():
111
+ if os.path.exists(MU_PATH) and os.path.exists(SD_PATH):
112
+ mu_ = np.load(MU_PATH).astype("float32")
113
+ sd_ = np.load(SD_PATH).astype("float32")
114
+ return mu_, sd_
115
+ if NPZ_PATH and os.path.exists(NPZ_PATH):
116
+ arr = np.load(NPZ_PATH, allow_pickle=True)
117
+ mu_ = arr["mu"].astype("float32")
118
+ sd_ = arr["sd"].astype("float32")
119
+ return mu_, sd_
120
+ raise RuntimeError("mu/sd not found (MU_PATH/SD_PATH or NPZ_PATH).")
121
+
122
+ mu, sd = _load_mu_sd()
123
+ logger.info("Loaded mu/sd")
124
+
125
+ # ---------------------- Load thresholds ----------------------
126
+ if os.path.exists(THRESHOLDS_PATH):
127
+ best_th = np.load(THRESHOLDS_PATH).astype("float32")
128
+ if best_th.shape[0] != C:
129
+ raise RuntimeError(f"thresholds size {best_th.shape[0]} != #classes {C}")
130
+ else:
131
+ logger.warning("THRESHOLDS_PATH not found -> default 0.5 for all classes")
132
+ best_th = np.full(C, 0.5, dtype="float32")
133
+
134
+ # ---------------------- Load derm-foundation ----------------------
135
+
136
+ from huggingface_hub import snapshot_download
137
+
138
+ HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
139
+ CACHE_DIR = os.getenv("HF_HOME", "/app/.cache")
140
+ LOCAL_DERM = os.getenv("DERM_LOCAL_DIR", "/app/derm-foundation")
141
+
142
+ os.makedirs(CACHE_DIR, exist_ok=True)
143
+ os.makedirs(LOCAL_DERM, exist_ok=True)
144
+
145
+ logger.info("Loading Derm Foundation (first time may take a while)...")
146
+ try:
147
+ if os.path.isdir(LOCAL_DERM) and os.path.exists(os.path.join(LOCAL_DERM, "saved_model.pb")):
148
+ derm_dir = LOCAL_DERM
149
+ logger.info(f"Loaded Derm Foundation from local: {derm_dir}")
150
+ else:
151
+ logger.info(f"Downloading derm-foundation from hub: {DERM_MODEL_ID}")
152
+ derm_dir = snapshot_download(
153
+ repo_id=DERM_MODEL_ID,
154
+ repo_type="model",
155
+ allow_patterns=["saved_model.pb", "variables/*"],
156
+ token=HF_TOKEN,
157
+ cache_dir=CACHE_DIR,
158
+ local_dir=LOCAL_DERM,
159
+ local_dir_use_symlinks=False,
160
+ )
161
+ logger.info(f"Derm Foundation downloaded to: {derm_dir}")
162
+
163
+ derm = tf.saved_model.load(derm_dir)
164
+ infer = derm.signatures["serving_default"]
165
+ except Exception as e:
166
+ raise RuntimeError(
167
+ f"Failed to load derm-foundation: {e}. "
168
+ "Make sure you accepted the model terms and set HF_TOKEN in Space Settings."
169
+ )
170
+
171
+
172
+ # ---------------------- Utils ----------------------
173
+ def pil_to_png_bytes_448(pil_img: Image.Image) -> bytes:
174
+ pil_img = pil_img.convert("RGB").resize(DF_SIZE)
175
+ arr = np.array(pil_img, dtype=np.uint8)
176
+ return tf.io.encode_png(arr).numpy()
177
+
178
+ def _brightness(np_img_rgb: np.ndarray) -> float:
179
+ r,g,b = np_img_rgb[...,0], np_img_rgb[...,1], np_img_rgb[...,2]
180
+ y = 0.2126*r + 0.7152*g + 0.0722*b
181
+ return float(y.mean())
182
+
183
+ def _sharpness(np_img_rgb: np.ndarray) -> float:
184
+ if not HAS_OPENCV:
185
+ return 100.0
186
+ gray = cv2.cvtColor(np_img_rgb, cv2.COLOR_RGB2GRAY)
187
+ return float(cv2.Laplacian(gray, cv2.CV_64F).var())
188
+
189
+ def _skin_ratio(np_img_rgb: np.ndarray) -> float:
190
+ img = Image.fromarray(np_img_rgb).convert("YCbCr")
191
+ ycbcr = np.array(img)
192
+ Cb = ycbcr[...,1]; Cr = ycbcr[...,2]
193
+ mask = (Cb >= 77) & (Cb <= 127) & (Cr >= 133) & (Cr <= 173)
194
+ return float(mask.mean())
195
+
196
+ def gatekeep_image(img_bytes: bytes) -> Dict[str, Any]:
197
+ try:
198
+ img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
199
+ except Exception:
200
+ return {"ok": False, "reasons": ["invalid_image"], "metrics": {}}
201
+ w,h = img.size
202
+ metrics = {"width": w, "height": h}
203
+ reasons = []
204
+ if w < MIN_W or h < MIN_H:
205
+ reasons.append("too_small")
206
+ aspect = w / h
207
+ metrics["aspect"] = float(aspect)
208
+ if not (MIN_ASPECT <= aspect <= MAX_ASPECT):
209
+ reasons.append("weird_aspect")
210
+ np_img = np.array(img)
211
+ bright = _brightness(np_img)
212
+ metrics["brightness"] = bright
213
+ if bright < MIN_BRIGHT: reasons.append("too_dark")
214
+ if bright > MAX_BRIGHT: reasons.append("too_bright")
215
+ if HAS_OPENCV:
216
+ sharp = _sharpness(np_img)
217
+ metrics["sharpness"] = sharp
218
+ if sharp < MIN_SHARPNESS: reasons.append("too_blurry")
219
+ ratio = _skin_ratio(np_img)
220
+ metrics["skin_ratio"] = ratio
221
+ if ratio < MIN_SKIN_RATIO: reasons.append("not_skin_like")
222
+ return {"ok": len(reasons)==0, "reasons": reasons, "metrics": metrics}
223
+
224
+ def predict_probs(img_bytes: bytes) -> np.ndarray:
225
+ pil = Image.open(io.BytesIO(img_bytes)).convert("RGB").resize(DF_SIZE)
226
+ by = pil_to_png_bytes_448(pil)
227
+ ex = tf.train.Example(features=tf.train.Features(
228
+ feature={'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[by]))}
229
+ )).SerializeToString()
230
+ out = infer(inputs=tf.constant([ex]))
231
+ if "embedding" not in out:
232
+ raise RuntimeError(f"Unexpected derm-foundation outputs: {list(out.keys())}")
233
+ emb = out["embedding"].numpy().astype("float32") # (1, 6144)
234
+ z = (emb - mu) / (sd + 1e-6)
235
+ probs = head.predict(z, verbose=0)[0] # head (.keras) โดยตรง
236
+ return probs
237
+
238
+ # ---------------------- Endpoints ----------------------
239
+ @app.get("/health")
240
+ def health():
241
+ return {
242
+ "ok": True,
243
+ "classes": len(CLASS_NAMES),
244
+ "derm": DERM_MODEL_ID or DERM_LOCAL_DIR,
245
+ "has_opencv": HAS_OPENCV
246
+ }
247
+
248
+ @app.post("/predict")
249
+ async def predict(request: Request, file: UploadFile = File(...)):
250
+ cl = request.headers.get("content-length")
251
+ if cl and int(cl) > MAX_UPLOAD:
252
+ raise HTTPException(413, "File too large")
253
+ img_bytes = await file.read()
254
+ if len(img_bytes) > MAX_UPLOAD:
255
+ raise HTTPException(413, "File too large")
256
+
257
+ gate = gatekeep_image(img_bytes)
258
+ if not gate["ok"]:
259
+ return JSONResponse(status_code=200, content={"ok": False, "reason": "gate_reject", "gate": gate})
260
+
261
+ probs = predict_probs(img_bytes)
262
+ order = np.argsort(probs)[::-1]
263
+ top = [{"label": CLASS_NAMES[i], "prob": float(probs[i])} for i in order[:TOPK]]
264
+
265
+ preds = (probs >= best_th).astype(np.int32)
266
+ positives = [{"label": CLASS_NAMES[i], "prob": float(probs[i])} for i in range(C) if preds[i] == 1]
267
+
268
+ return {
269
+ "ok": True,
270
+ "gate": gate,
271
+ "result": {
272
+ "type": "multilabel",
273
+ "thresholds_used": {CLASS_NAMES[i]: float(best_th[i]) for i in range(C)},
274
+ "positives": positives,
275
+ "topk": top,
276
+ "probs": {CLASS_NAMES[i]: float(probs[i]) for i in range(C)}
277
+ }
278
+ }
279
+
280
+ #------------------------------UI-----------------------------------
281
+ import gradio as gr
282
+ import io
283
+
284
+ def gradio_predict(image):
285
+ buf = io.BytesIO()
286
+ image.save(buf, format="PNG")
287
+ img_bytes = buf.getvalue()
288
+
289
+ gate = gatekeep_image(img_bytes)
290
+ if not gate["ok"]:
291
+ return {"Error": "Image rejected"}
292
+
293
+ probs = predict_probs(img_bytes)
294
+ order = np.argsort(probs)[::-1]
295
+
296
+ return {
297
+ CLASS_NAMES[i]: float(probs[i])
298
+ for i in order[:5]
299
+ }
300
+
301
+ with gr.Blocks() as demo:
302
+ gr.Markdown("# 🧠 Skin Disease Classifier")
303
+
304
+ image_input = gr.Image(type="pil")
305
+ output = gr.Label()
306
+
307
+ btn = gr.Button("Predict")
308
+ btn.click(gradio_predict, inputs=image_input, outputs=output)
309
+
310
+ demo.launch()
311
+