SmartHeal committed on
Commit 9a6c1cd · verified · 1 Parent(s): 071cf17

Update src/ai_processor.py

Files changed (1):
  1. src/ai_processor.py +556 -881
src/ai_processor.py CHANGED
@@ -1,435 +1,58 @@
- # smartheal_ai_processor.py
- # Verbose, instrumented version — preserves public class/function names
- # Turn on deep logging: export LOGLEVEL=DEBUG SMARTHEAL_DEBUG=1
-
  import os
  import logging
- from datetime import datetime
- from typing import Optional, Dict, List, Tuple
-
- # ---- Environment defaults (mask CUDA in main process) ----
- os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
- os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")  # ensure main never touches CUDA
- LOGLEVEL = os.getenv("LOGLEVEL", "INFO").upper()
- SMARTHEAL_DEBUG = os.getenv("SMARTHEAL_DEBUG", "0") == "1"
-
  import cv2
  import numpy as np
  from PIL import Image
- from PIL.ExifTags import TAGS
-
- # --- Logging config ---
- logging.basicConfig(
-     level=getattr(logging, LOGLEVEL, logging.INFO),
-     format="%(asctime)s - %(levelname)s - %(message)s",
- )
-
- def _log_kv(prefix: str, kv: Dict):
-     logging.debug(prefix + " | " + " | ".join(f"{k}={v}" for k, v in kv.items()))
-
- # --- Spaces GPU (non-optional) ---
- import spaces  # required; do not stub/optionalize
-
- UPLOADS_DIR = "uploads"
- os.makedirs(UPLOADS_DIR, exist_ok=True)
-
- HF_TOKEN = os.getenv("HF_TOKEN", None)
- YOLO_MODEL_PATH = "src/best.pt"
- SEG_MODEL_PATH = "src/segmentation_model.h5"  # optional
- GUIDELINE_PDFS = ["src/eHealth in Wound Care.pdf", "src/IWGDF Guideline.pdf", "src/evaluation.pdf"]
- DATASET_ID = "SmartHeal/wound-image-uploads"
- DEFAULT_PX_PER_CM = 38.0
- PX_PER_CM_MIN, PX_PER_CM_MAX = 5.0, 1200.0
-
- # Segmentation preprocessing knobs
- SEG_EXPECTS_RGB = os.getenv("SEG_EXPECTS_RGB", "1") == "1"  # most TF models trained on RGB
- SEG_NORM = os.getenv("SEG_NORM", "0to1")  # "0to1" | "imagenet"
- SEG_THRESH = float(os.getenv("SEG_THRESH", "0.5"))
-
- models_cache: Dict[str, object] = {}
- knowledge_base_cache: Dict[str, object] = {}
-
- # ---------- Lazy imports ----------
- def _import_ultralytics():
-     from ultralytics import YOLO
-     return YOLO
-
- def _import_tf_loader():
-     import tensorflow as tf
-     try:
-         tf.config.set_visible_devices([], "GPU")  # keep TF on CPU
-     except Exception:
-         pass
-     from tensorflow.keras.models import load_model
-     return load_model
-
- def _import_hf_cls():
-     from transformers import pipeline
-     return pipeline
-
- def _import_embeddings():
-     # updated per LangChain deprecations
-     from langchain_huggingface import HuggingFaceEmbeddings
-     return HuggingFaceEmbeddings
-
- def _import_langchain_pdf():
-     from langchain_community.document_loaders import PyPDFLoader
-     return PyPDFLoader
-
- def _import_langchain_faiss():
-     from langchain_community.vectorstores import FAISS
-     return FAISS
-
- def _import_hf_hub():
-     from huggingface_hub import HfApi, HfFolder
-     return HfApi, HfFolder
-
- # ---------- VLM (MedGemma replacement under the same public function name) ----------
- SMARTHEAL_VLM_ID = os.getenv("SMARTHEAL_VLM_ID", "Qwen/Qwen2-VL-2B-Instruct")
- SMARTHEAL_VLM_MAX_NEW_TOKENS = int(os.getenv("SMARTHEAL_VLM_MAX_NEW_TOKENS", "600"))
- SMARTHEAL_VLM_TEMPERATURE = float(os.getenv("SMARTHEAL_VLM_TEMPERATURE", "0.2"))
-
- SMARTHEAL_SYSTEM_PROMPT = """You are SmartHeal, a medical decision-support assistant specialized in wound assessment.
- You are given: (1) a wound photograph, (2) basic patient context, and (3) visual measurements (length, width, area)
- estimated from computer vision. You must:
-
- - Summarize clinically-relevant visual cues (tissue type, exudate amount, slough/necrosis, peri-wound condition).
- - Interpret in context of diabetes/infection/moisture/bleeding risks.
- - Provide clear next-step care: cleansing, debridement criteria, dressing selection, offloading, escalation triggers.
- - Include risk flags (ischemia, cellulitis, osteomyelitis suspicion) and monitoring frequency.
- - Be concise, structured, and avoid speculation beyond the image and given data.
- - Always add a short disclaimer: “Decision-support only; verify clinically.”"""
-
- def _build_vlm_messages(patient_info: str, visual_results: Dict, guideline_context: str) -> list:
-     wt = visual_results.get("wound_type", "Unknown")
-     L = visual_results.get("length_cm", 0)
-     W = visual_results.get("breadth_cm", 0)
-     A = visual_results.get("surface_area_cm2", 0)
-     ppcm = visual_results.get("px_per_cm", "?")
-
-     ctx = (guideline_context or "")
-     if ctx:
-         ctx = f"\n\nRelevant guideline snippets:\n{ctx[:1200]}{'...' if len(ctx)>1200 else ''}"
-
-     text = (
-         f"{SMARTHEAL_SYSTEM_PROMPT}\n\n"
-         f"Patient: {patient_info}\n"
-         f"Wound visual summary (from CV): type={wt}, length={L} cm, width={W} cm, "
-         f"area={A} cm² (calibration {ppcm} px/cm)."
-         f"{ctx}\n\n"
-         "Analyze the image and provide:\n"
-         "1) Clinical Summary\n2) Dressing & Treatment Plan\n"
-         "3) Risk/Red Flags\n4) Monitoring Plan\n"
-         "Format with short headings and bullets.\n"
-     )
-     return [{"role": "user", "content": [{"type": "text", "text": text}]}]
-
- @spaces.GPU  # non-optional: ensure CUDA work happens only inside the ZeroGPU worker
- def _vlm_infer_gpu(
-     image_pil: Image.Image,
-     messages: list,
-     max_new_tokens: int,
-     temperature: float,
-     model_id: str,
-     token: Optional[str],
- ) -> str:
-     import torch
-     from transformers import AutoProcessor, AutoModelForCausalLM
-
-     torch.backends.cuda.matmul.allow_tf32 = True
-     device = "cuda"
-
-     processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True, use_fast=True, token=token)
-     model = AutoModelForCausalLM.from_pretrained(
-         model_id,
-         torch_dtype=torch.float16,
-         trust_remote_code=True,
-         token=token,
-     ).to(device)
-
-     inputs = processor(messages=messages, images=[image_pil], return_tensors="pt").to(device)
-     gen_ids = model.generate(
-         **inputs,
-         max_new_tokens=max_new_tokens,
-         do_sample=False,
-         temperature=temperature,
-     )
-     out = processor.batch_decode(gen_ids, skip_special_tokens=True)[0]
-     return out.strip()
-
- def _vlm_infer_cpu(
-     image_pil: Image.Image,
-     messages: list,
-     max_new_tokens: int,
-     temperature: float,
-     model_id: str,
-     token: Optional[str],
- ) -> str:
-     from transformers import pipeline
-     pipe = pipeline(
-         task="image-text-to-text",
-         model=model_id,
-         device="cpu",
-         trust_remote_code=True,
-         token=token,
-     )
-     out = pipe(
-         text=[{"role": "user", "content": [{"type": "image", "image": image_pil}, *messages[0]["content"]]}],
-         max_new_tokens=max_new_tokens,
-         do_sample=False,
-         temperature=temperature,
-     )
-     try:
-         return (out[0]["generated_text"][-1].get("content", "") or "").strip()
-     except Exception:
-         return (out[0].get("generated_text", "") or "").strip()
-
- def generate_medgemma_report(  # <-- keep the original PUBLIC NAME
-     patient_info: str,
-     visual_results: Dict,
-     guideline_context: str,
-     image_pil: Image.Image,
-     max_new_tokens: Optional[int] = None,
- ) -> str:
-     """
-     Re-implemented to use Qwen/Qwen2-VL-* via ZeroGPU (@spaces.GPU) with CPU fallback.
-     Name preserved for compatibility with existing callers.
-     """
-     msgs = _build_vlm_messages(patient_info, visual_results, guideline_context)
-     try:
-         return _vlm_infer_gpu(
-             image_pil=image_pil,
-             messages=msgs,
-             max_new_tokens=max_new_tokens or SMARTHEAL_VLM_MAX_NEW_TOKENS,
-             temperature=SMARTHEAL_VLM_TEMPERATURE,
-             model_id=SMARTHEAL_VLM_ID,
-             token=HF_TOKEN,
-         )
-     except Exception as e:
-         logging.warning(f"GPU VLM failed; falling back to CPU: {e!r}")
-         return _vlm_infer_cpu(
-             image_pil=image_pil,
-             messages=msgs,
-             max_new_tokens=max_new_tokens or SMARTHEAL_VLM_MAX_NEW_TOKENS,
-             temperature=SMARTHEAL_VLM_TEMPERATURE,
-             model_id=SMARTHEAL_VLM_ID,
-             token=HF_TOKEN,
-         ) or "⚠️ VLM returned empty output"
-
- # ---------- Initialize CPU models ----------
- def load_yolo_model():
-     YOLO = _import_ultralytics()
-     return YOLO(YOLO_MODEL_PATH)
-
- def load_segmentation_model():
-     load_model = _import_tf_loader()
-     return load_model(SEG_MODEL_PATH, compile=False)
-
- def load_classification_pipeline():
-     pipe = _import_hf_cls()
-     return pipe("image-classification", model="Hemg/Wound-classification", token=HF_TOKEN, device="cpu")
-
- def load_embedding_model():
-     Emb = _import_embeddings()
-     return Emb(model_name="sentence-transformers/all-MiniLM-L6-v2", model_kwargs={"device": "cpu"})
-
- def initialize_cpu_models() -> None:
-     if HF_TOKEN:
-         try:
-             HfApi, HfFolder = _import_hf_hub()
-             HfFolder.save_token(HF_TOKEN)
-             logging.info("✅ HF token set")
-         except Exception as e:
-             logging.warning(f"HF token save failed: {e}")
 
-     if "det" not in models_cache:
-         try:
-             models_cache["det"] = load_yolo_model()
-             logging.info("✅ YOLO loaded (CPU; CUDA masked in main)")
-         except Exception as e:
-             logging.error(f"YOLO load failed: {e}")
 
-     if "seg" not in models_cache:
-         try:
-             if os.path.exists(SEG_MODEL_PATH):
-                 models_cache["seg"] = load_segmentation_model()
-                 m = models_cache["seg"]
-                 ishape = getattr(m, "input_shape", None)
-                 oshape = getattr(m, "output_shape", None)
-                 logging.info(f"✅ Segmentation model loaded (CPU) | input_shape={ishape} output_shape={oshape}")
-             else:
-                 models_cache["seg"] = None
-                 logging.warning("Segmentation model file missing; skipping.")
-         except Exception as e:
-             models_cache["seg"] = None
-             logging.warning(f"Segmentation unavailable: {e}")
 
-     if "cls" not in models_cache:
-         try:
-             models_cache["cls"] = load_classification_pipeline()
-             logging.info("✅ Classifier loaded (CPU)")
-         except Exception as e:
-             models_cache["cls"] = None
-             logging.warning(f"Classifier unavailable: {e}")
 
-     if "embedding_model" not in models_cache:
-         try:
-             models_cache["embedding_model"] = load_embedding_model()
-             logging.info("✅ Embeddings loaded (CPU)")
-         except Exception as e:
-             models_cache["embedding_model"] = None
-             logging.warning(f"Embeddings unavailable: {e}")
 
- def setup_knowledge_base() -> None:
-     if "vector_store" in knowledge_base_cache:
-         return
-     docs: List = []
-     try:
-         PyPDFLoader = _import_langchain_pdf()
-         for pdf in GUIDELINE_PDFS:
-             if os.path.exists(pdf):
-                 try:
-                     docs.extend(PyPDFLoader(pdf).load())
-                     logging.info(f"Loaded PDF: {pdf}")
-                 except Exception as e:
-                     logging.warning(f"PDF load failed ({pdf}): {e}")
-     except Exception as e:
-         logging.warning(f"LangChain PDF loader unavailable: {e}")
 
-     if docs and models_cache.get("embedding_model"):
-         try:
-             from langchain.text_splitter import RecursiveCharacterTextSplitter
-             FAISS = _import_langchain_faiss()
-             chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)
-             knowledge_base_cache["vector_store"] = FAISS.from_documents(chunks, models_cache["embedding_model"])
-             logging.info(f"✅ Knowledge base ready ({len(chunks)} chunks)")
-         except Exception as e:
-             knowledge_base_cache["vector_store"] = None
-             logging.warning(f"KB build failed: {e}")
-     else:
-         knowledge_base_cache["vector_store"] = None
-         logging.warning("KB disabled (no docs or embeddings).")
-
- initialize_cpu_models()
- setup_knowledge_base()
-
- # ---------- Calibration helpers ----------
- def _exif_to_dict(pil_img: Image.Image) -> Dict[str, object]:
-     out = {}
-     try:
-         exif = pil_img.getexif()
-         if not exif:
-             return out
-         for k, v in exif.items():
-             tag = TAGS.get(k, k)
-             out[tag] = v
-     except Exception:
-         pass
-     return out
 
- def _to_float(val) -> Optional[float]:
-     try:
-         if val is None:
-             return None
-         if isinstance(val, tuple) and len(val) == 2:
-             num, den = float(val[0]), float(val[1]) if float(val[1]) != 0 else 1.0
-             return num / den
-         return float(val)
-     except Exception:
-         return None
 
- def _estimate_sensor_width_mm(f_mm: Optional[float], f35: Optional[float]) -> Optional[float]:
-     if f_mm and f35 and f35 > 0:
-         return 36.0 * f_mm / f35
-     return None
 
- def estimate_px_per_cm_from_exif(pil_img: Image.Image, default_px_per_cm: float = DEFAULT_PX_PER_CM) -> Tuple[float, Dict]:
-     meta = {"used": "default", "f_mm": None, "f35": None, "sensor_w_mm": None, "distance_m": None}
-     try:
-         exif = _exif_to_dict(pil_img)
-         f_mm = _to_float(exif.get("FocalLength"))
-         f35 = _to_float(exif.get("FocalLengthIn35mmFilm") or exif.get("FocalLengthIn35mm"))
-         subj_dist_m = _to_float(exif.get("SubjectDistance"))
-         sensor_w_mm = _estimate_sensor_width_mm(f_mm, f35)
-         meta.update({"f_mm": f_mm, "f35": f35, "sensor_w_mm": sensor_w_mm, "distance_m": subj_dist_m})
-
-         if f_mm and sensor_w_mm and subj_dist_m and subj_dist_m > 0:
-             w_px = pil_img.width
-             field_w_mm = sensor_w_mm * (subj_dist_m * 1000.0) / f_mm
-             field_w_cm = field_w_mm / 10.0
-             px_per_cm = w_px / max(field_w_cm, 1e-6)
-             px_per_cm = float(np.clip(px_per_cm, PX_PER_CM_MIN, PX_PER_CM_MAX))
-             meta["used"] = "exif"
-             return px_per_cm, meta
-         return float(default_px_per_cm), meta
-     except Exception:
-         return float(default_px_per_cm), meta
-
- # ---------- Segmentation helpers ----------
- def _imagenet_norm(arr: np.ndarray) -> np.ndarray:
-     mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
-     std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
-     return (arr.astype(np.float32) - mean) / std
-
- def _preprocess_for_seg(bgr_roi: np.ndarray, target_hw: Tuple[int, int]) -> np.ndarray:
-     H, W = target_hw
-     resized = cv2.resize(bgr_roi, (W, H), interpolation=cv2.INTER_LINEAR)
-     if SEG_EXPECTS_RGB:
-         resized = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
-     if SEG_NORM.lower() == "imagenet":
-         x = _imagenet_norm(resized)
-     else:
-         x = resized.astype(np.float32) / 255.0
-     x = np.expand_dims(x, axis=0)  # (1,H,W,3)
-     return x
-
- def _to_prob(pred: np.ndarray) -> np.ndarray:
-     p = np.squeeze(pred)
-     pmin, pmax = float(p.min()), float(p.max())
-     if pmax > 1.0 or pmin < 0.0:
-         p = 1.0 / (1.0 + np.exp(-p))
-     return p.astype(np.float32)
-
- # ---- Adaptive threshold + GrabCut grow ----
- def _adaptive_prob_threshold(p: np.ndarray) -> float:
-     """
-     Choose a threshold that avoids tiny blobs while not swallowing skin.
-     Try Otsu and the 90th percentile, clamp to [0.25, 0.65], pick by area heuristic.
-     """
-     p01 = np.clip(p.astype(np.float32), 0, 1)
-     p255 = (p01 * 255).astype(np.uint8)
-
-     ret_otsu, _ = cv2.threshold(p255, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
-     thr_otsu = float(np.clip(ret_otsu / 255.0, 0.25, 0.65))
-     thr_pctl = float(np.clip(np.percentile(p01, 90), 0.25, 0.65))
-
-     def area_frac(thr: float) -> float:
-         return float((p01 >= thr).sum()) / float(p01.size)
-
-     af_otsu = area_frac(thr_otsu)
-     af_pctl = area_frac(thr_pctl)
-
-     def score(af: float) -> float:
-         target_low, target_high = 0.03, 0.10
-         if af < target_low: return abs(af - target_low) * 3.0
-         if af > target_high: return abs(af - target_high) * 1.5
-         return 0.0
-
-     return thr_otsu if score(af_otsu) <= score(af_pctl) else thr_pctl
 
- def _grabcut_refine(bgr: np.ndarray, seed01: np.ndarray, iters: int = 3) -> np.ndarray:
-     """Grow from a confident core into low-contrast margins."""
-     h, w = bgr.shape[:2]
-     gc = np.full((h, w), cv2.GC_PR_BGD, np.uint8)
-     k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
-     seed_dil = cv2.dilate(seed01, k, iterations=1)
-     gc[seed01.astype(bool)] = cv2.GC_PR_FGD
-     gc[seed_dil.astype(bool)] = cv2.GC_FGD
-     gc[0, :], gc[-1, :], gc[:, 0], gc[:, -1] = cv2.GC_BGD, cv2.GC_BGD, cv2.GC_BGD, cv2.GC_BGD
-     bgdModel = np.zeros((1, 65), np.float64)
-     fgdModel = np.zeros((1, 65), np.float64)
-     cv2.grabCut(bgr, gc, None, bgdModel, fgdModel, iters, cv2.GC_INIT_WITH_MASK)
-     return np.where((gc == cv2.GC_FGD) | (gc == cv2.GC_PR_FGD), 1, 0).astype(np.uint8)
 
  def _fill_holes(mask01: np.ndarray) -> np.ndarray:
      h, w = mask01.shape[:2]
@@ -441,558 +64,610 @@ def _fill_holes(mask01: np.ndarray) -> np.ndarray:
      return out.astype(np.uint8)
 
  def _clean_mask(mask01: np.ndarray) -> np.ndarray:
-     """Open → Close → Fill holes → Largest component (no dilation)."""
      mask01 = (mask01 > 0).astype(np.uint8)
      k3 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
      k5 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
      mask01 = cv2.morphologyEx(mask01, cv2.MORPH_OPEN, k3, iterations=1)
      mask01 = cv2.morphologyEx(mask01, cv2.MORPH_CLOSE, k5, iterations=1)
      mask01 = _fill_holes(mask01)
-     # Keep largest component only
-     num, labels, stats, _ = cv2.connectedComponentsWithStats(mask01, 8)
-     if num > 1:
-         areas = stats[1:, cv2.CC_STAT_AREA]
-         if areas.size:
-             largest_idx = 1 + int(np.argmax(areas))
-             mask01 = (labels == largest_idx).astype(np.uint8)
      return (mask01 > 0).astype(np.uint8)
 
- # ---------- Segmentation pipeline ----------
- def segment_wound(image_bgr: np.ndarray, ts: str, out_dir: str) -> Tuple[np.ndarray, Dict[str, object]]:
-     """
-     TF model adaptive threshold on prob → GrabCut grow → cleanup.
-     Fallback: KMeans-Lab.
-     Returns (mask_uint8_0_255, debug_dict)
-     """
-     debug = {"used": None, "reason": None, "positive_fraction": 0.0,
-              "thr": None, "heatmap_path": None, "roi_seen_by_model": None}
-
-     seg_model = models_cache.get("seg", None)
-
-     # --- Model path ---
-     if seg_model is not None:
-         try:
-             ishape = getattr(seg_model, "input_shape", None)
-             if not ishape or len(ishape) < 4:
-                 raise ValueError(f"Bad seg input_shape: {ishape}")
-             th, tw = int(ishape[1]), int(ishape[2])
-
-             x = _preprocess_for_seg(image_bgr, (th, tw))
-             roi_seen_path = None
-             if SMARTHEAL_DEBUG:
-                 roi_seen_path = os.path.join(out_dir, f"roi_for_seg_{ts}.png")
-                 cv2.imwrite(roi_seen_path, image_bgr)
-
-             pred = seg_model.predict(x, verbose=0)
-             if isinstance(pred, (list, tuple)): pred = pred[0]
-             p = _to_prob(pred)
-             p = cv2.resize(p, (image_bgr.shape[1], image_bgr.shape[0]), interpolation=cv2.INTER_LINEAR)
-
-             heatmap_path = None
-             if SMARTHEAL_DEBUG:
-                 hm = (np.clip(p, 0, 1) * 255).astype(np.uint8)
-                 heat = cv2.applyColorMap(hm, cv2.COLORMAP_JET)
-                 heatmap_path = os.path.join(out_dir, f"seg_pred_heatmap_{ts}.png")
-                 cv2.imwrite(heatmap_path, heat)
-
-             thr = _adaptive_prob_threshold(p)
-             core01 = (p >= thr).astype(np.uint8)
-             core_frac = float(core01.sum()) / float(core01.size)
-
-             if core_frac < 0.005:
-                 thr2 = max(thr - 0.10, 0.15)
-                 core01 = (p >= thr2).astype(np.uint8)
-                 thr = thr2
-                 core_frac = float(core01.sum()) / float(core01.size)
-
-             if core01.any():
-                 gc01 = _grabcut_refine(image_bgr, core01, iters=3)
-                 mask01 = _clean_mask(gc01)
-             else:
-                 mask01 = np.zeros(core01.shape, np.uint8)
-
-             pos_frac = float(mask01.sum()) / float(mask01.size)
-             logging.info(f"SegModel USED | thr={float(thr):.2f} core_frac={core_frac:.4f} final_frac={pos_frac:.4f}")
-
-             debug.update({
-                 "used": "tf_model",
-                 "reason": "ok",
-                 "positive_fraction": pos_frac,
-                 "thr": float(thr),
-                 "heatmap_path": heatmap_path,
-                 "roi_seen_by_model": roi_seen_path
-             })
-             return (mask01 * 255).astype(np.uint8), debug
-
-         except Exception as e:
-             logging.warning(f"⚠️ Segmentation model failed → fallback. Reason: {e}")
-             debug.update({"used": "fallback_kmeans", "reason": f"model_failed: {e}"})
 
-     # --- Fallback: KMeans in Lab (reddest cluster as wound) ---
-     Z = image_bgr.reshape((-1, 3)).astype(np.float32)
      criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
      _, labels, centers = cv2.kmeans(Z, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
      centers_u8 = centers.astype(np.uint8).reshape(1, 2, 3)
      centers_lab = cv2.cvtColor(centers_u8, cv2.COLOR_BGR2LAB)[0]
-     wound_idx = int(np.argmax(centers_lab[:, 1]))  # maximize a* (red)
-     mask01 = (labels.reshape(image_bgr.shape[:2]) == wound_idx).astype(np.uint8)
-     mask01 = _clean_mask(mask01)
-
-     pos_frac = float(mask01.sum()) / float(mask01.size)
-     logging.info(f"KMeans USED | final_frac={pos_frac:.4f}")
-
-     debug.update({
-         "used": "fallback_kmeans",
-         "reason": debug.get("reason") or "no_model",
-         "positive_fraction": pos_frac,
-         "thr": None
-     })
-     return (mask01 * 255).astype(np.uint8), debug
-
- # ---------- Measurement + overlay helpers ----------
- def largest_component_mask(binary01: np.ndarray, min_area_px: int = 50) -> np.ndarray:
-     num, labels, stats, _ = cv2.connectedComponentsWithStats(binary01.astype(np.uint8), connectivity=8)
-     if num <= 1:
-         return binary01.astype(np.uint8)
-     areas = stats[1:, cv2.CC_STAT_AREA]
-     if areas.size == 0 or areas.max() < min_area_px:
-         return binary01.astype(np.uint8)
-     largest_idx = 1 + int(np.argmax(areas))
-     return (labels == largest_idx).astype(np.uint8)
-
- def measure_min_area_rect(mask01: np.ndarray, px_per_cm: float) -> Tuple[float, float, Tuple]:
-     contours, _ = cv2.findContours(mask01.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-     if not contours:
-         return 0.0, 0.0, (None, None)
-     cnt = max(contours, key=cv2.contourArea)
-     rect = cv2.minAreaRect(cnt)
-     (w_px, h_px) = rect[1]
-     length_px, breadth_px = (max(w_px, h_px), min(w_px, h_px))
-     length_cm = round(length_px / max(px_per_cm, 1e-6), 2)
-     breadth_cm = round(breadth_px / max(px_per_cm, 1e-6), 2)
-     box = cv2.boxPoints(rect).astype(int)
-     return length_cm, breadth_cm, (box, rect[0])
-
- def area_cm2_from_contour(mask01: np.ndarray, px_per_cm: float) -> Tuple[float, Optional[np.ndarray]]:
-     """Area from largest polygon (sub-pixel); returns (area_cm2, contour)."""
-     m = (mask01 > 0).astype(np.uint8)
-     contours, _ = cv2.findContours(m, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-     if not contours:
-         return 0.0, None
-     cnt = max(contours, key=cv2.contourArea)
-     poly_area_px2 = float(cv2.contourArea(cnt))
-     area_cm2 = round(poly_area_px2 / (max(px_per_cm, 1e-6) ** 2), 2)
-     return area_cm2, cnt
-
- def clamp_area_with_minrect(cnt: np.ndarray, px_per_cm: float, area_cm2_poly: float) -> float:
-     rect = cv2.minAreaRect(cnt)
-     (w_px, h_px) = rect[1]
-     rect_area_px2 = float(max(w_px, 0.0) * max(h_px, 0.0))
-     rect_area_cm2 = rect_area_px2 / (max(px_per_cm, 1e-6) ** 2)
-     return round(min(area_cm2_poly, rect_area_cm2 * 1.05), 2)
-
- def draw_measurement_overlay(
-     base_bgr: np.ndarray,
-     mask01: np.ndarray,
-     rect_box: np.ndarray,
-     length_cm: float,
-     breadth_cm: float,
-     thickness: int = 2
- ) -> np.ndarray:
-     """
-     1) Strong red mask overlay + white contour
-     2) Min-area rectangle
-     3) Double-headed arrows labeled Length/Width
-     """
-     overlay = base_bgr.copy()
-
-     # Mask tint
-     mask255 = (mask01 * 255).astype(np.uint8)
-     mask3 = cv2.merge([mask255, mask255, mask255])
-     red = np.zeros_like(overlay); red[:] = (0, 0, 255)
-     alpha = 0.55
-     tinted = cv2.addWeighted(overlay, 1 - alpha, red, alpha, 0)
-     overlay = np.where(mask3 > 0, tinted, overlay)
-
-     # Contour
-     cnts, _ = cv2.findContours(mask255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-     if cnts:
-         cv2.drawContours(overlay, cnts, -1, (255, 255, 255), 2)
-
-     if rect_box is not None:
-         cv2.polylines(overlay, [rect_box], True, (255, 255, 255), thickness)
-         pts = rect_box.reshape(-1, 2)
-
-         def midpoint(a, b): return (int((a[0] + b[0]) / 2), int((a[1] + b[1]) / 2))
-         e = [np.linalg.norm(pts[i] - pts[(i + 1) % 4]) for i in range(4)]
-         long_edge_idx = int(np.argmax(e))
-         mids = [midpoint(pts[i], pts[(i + 1) % 4]) for i in range(4)]
-         long_pair = (long_edge_idx, (long_edge_idx + 2) % 4)
-         short_pair = ((long_edge_idx + 1) % 4, (long_edge_idx + 3) % 4)
-
-         def draw_double_arrow(img, p1, p2):
-             cv2.arrowedLine(img, p1, p2, (0, 0, 0), thickness + 2, tipLength=0.05)
-             cv2.arrowedLine(img, p2, p1, (0, 0, 0), thickness + 2, tipLength=0.05)
-             cv2.arrowedLine(img, p1, p2, (255, 255, 255), thickness, tipLength=0.05)
-             cv2.arrowedLine(img, p2, p1, (255, 255, 255), thickness, tipLength=0.05)
-
-         def put_label(text, anchor):
-             org = (anchor[0] + 6, anchor[1] - 6)
-             cv2.putText(overlay, text, org, cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 4, cv2.LINE_AA)
-             cv2.putText(overlay, text, org, cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
-
-         draw_double_arrow(overlay, mids[long_pair[0]], mids[long_pair[1]])
-         draw_double_arrow(overlay, mids[short_pair[0]], mids[short_pair[1]])
-         put_label(f"Length: {length_cm:.2f} cm", mids[long_pair[0]])
-         put_label(f"Width: {breadth_cm:.2f} cm", mids[short_pair[0]])
-
-     return overlay
-
- # ---------- AI PROCESSOR ----------
  class AIProcessor:
      def __init__(self):
-         self.models_cache = models_cache
-         self.knowledge_base_cache = knowledge_base_cache
-         self.uploads_dir = UPLOADS_DIR
-         self.dataset_id = DATASET_ID
-         self.hf_token = HF_TOKEN
-
-     def _ensure_analysis_dir(self) -> str:
-         out_dir = os.path.join(self.uploads_dir, "analysis")
-         os.makedirs(out_dir, exist_ok=True)
-         return out_dir
-
-     def perform_visual_analysis(self, image_pil: Image.Image) -> Dict:
-         """
-         YOLO detect → crop ROI → segment_wound(ROI) → clean mask →
-         minAreaRect measurement (cm) using EXIF px/cm → save outputs.
-         """
          try:
-             px_per_cm, exif_meta = estimate_px_per_cm_from_exif(image_pil, DEFAULT_PX_PER_CM)
-             # Guardrails for calibration to avoid huge area blow-ups
-             px_per_cm = float(np.clip(px_per_cm, 20.0, 350.0))
-             if (exif_meta or {}).get("used") != "exif":
-                 logging.warning(f"Calibration fallback used: px_per_cm={px_per_cm:.2f} (default). Prefer ruler/Aruco for accuracy.")
              image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)
 
-             # --- Detection ---
-             det_model = self.models_cache.get("det")
-             if det_model is None:
-                 raise RuntimeError("YOLO model not loaded")
-             results = det_model.predict(image_cv, verbose=False, device="cpu")
              if (not results) or (not getattr(results[0], "boxes", None)) or (len(results[0].boxes) == 0):
-                 try:
-                     import gradio as gr
-                     raise gr.Error("No wound could be detected.")
-                 except Exception:
-                     raise RuntimeError("No wound could be detected.")
 
-             box = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
              x1, y1, x2, y2 = [int(v) for v in box]
              x1, y1 = max(0, x1), max(0, y1)
              x2, y2 = min(image_cv.shape[1], x2), min(image_cv.shape[0], y2)
-             roi = image_cv[y1:y2, x1:x2].copy()
-             if roi.size == 0:
-                 try:
-                     import gradio as gr
-                     raise gr.Error("Detected ROI is empty.")
-                 except Exception:
-                     raise RuntimeError("Detected ROI is empty.")
 
-             out_dir = self._ensure_analysis_dir()
              ts = datetime.now().strftime("%Y%m%d_%H%M%S")
-
-             # --- Segmentation (model-first + KMeans fallback) ---
-             mask_u8_255, seg_debug = segment_wound(roi, ts, out_dir)
-             mask01 = (mask_u8_255 > 127).astype(np.uint8)
-
-             if mask01.any():
-                 mask01 = _clean_mask(mask01)
-                 logging.debug(f"Mask postproc: px_after={int(mask01.sum())}")
-
-             # --- Measurement (accurate & conservative) ---
-             if mask01.any():
-                 length_cm, breadth_cm, (box_pts, _) = measure_min_area_rect(mask01, px_per_cm)
-                 area_poly_cm2, largest_cnt = area_cm2_from_contour(mask01, px_per_cm)
-                 if largest_cnt is not None:
-                     surface_area_cm2 = clamp_area_with_minrect(largest_cnt, px_per_cm, area_poly_cm2)
-                 else:
-                     surface_area_cm2 = area_poly_cm2
-
-                 anno_roi = draw_measurement_overlay(roi, mask01, box_pts, length_cm, breadth_cm)
-                 segmentation_empty = False
-             else:
-                 # Fallback if seg failed: use ROI dimensions
-                 h_px = max(0, y2 - y1); w_px = max(0, x2 - x1)
-                 length_cm = round(max(h_px, w_px) / px_per_cm, 2)
-                 breadth_cm = round(min(h_px, w_px) / px_per_cm, 2)
-                 surface_area_cm2 = round((h_px * w_px) / (px_per_cm ** 2), 2)
-                 anno_roi = roi.copy()
-                 cv2.rectangle(anno_roi, (2, 2), (anno_roi.shape[1]-3, anno_roi.shape[0]-3), (0, 0, 255), 3)
-                 cv2.line(anno_roi, (0, 0), (anno_roi.shape[1]-1, anno_roi.shape[0]-1), (0, 0, 255), 2)
-                 cv2.line(anno_roi, (anno_roi.shape[1]-1, 0), (0, anno_roi.shape[0]-1), (0, 0, 255), 2)
-                 box_pts = None
-                 segmentation_empty = True
-
-             # --- Save visualizations ---
-             original_path = os.path.join(out_dir, f"original_{ts}.png")
-             cv2.imwrite(original_path, image_cv)
-
              det_vis = image_cv.copy()
              cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
-             detection_path = os.path.join(out_dir, f"detection_{ts}.png")
-             cv2.imwrite(detection_path, det_vis)
-
-             roi_mask_path = os.path.join(out_dir, f"roi_mask_{ts}.png")
-             cv2.imwrite(roi_mask_path, (mask01 * 255).astype(np.uint8))
-
-             # ROI overlay (mask tint + contour, without arrows)
-             mask255 = (mask01 * 255).astype(np.uint8)
-             mask3 = cv2.merge([mask255, mask255, mask255])
-             red = np.zeros_like(roi); red[:] = (0, 0, 255)
-             alpha = 0.55
-             tinted = cv2.addWeighted(roi, 1 - alpha, red, alpha, 0)
-             if mask255.any():
-                 roi_overlay = np.where(mask3 > 0, tinted, roi)
-                 cnts, _ = cv2.findContours(mask255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-                 cv2.drawContours(roi_overlay, cnts, -1, (255, 255, 255), 2)
             else:
-                 roi_overlay = anno_roi
-
             seg_full = image_cv.copy()
-             seg_full[y1:y2, x1:x2] = roi_overlay
-             segmentation_path = os.path.join(out_dir, f"segmentation_{ts}.png")
-             cv2.imwrite(segmentation_path, seg_full)
 
-             segmentation_roi_path = os.path.join(out_dir, f"segmentation_roi_{ts}.png")
-             cv2.imwrite(segmentation_roi_path, roi_overlay)
-
-             # Annotated (mask + arrows + labels) in full-frame
-             anno_full = image_cv.copy()
-             anno_full[y1:y2, x1:x2] = anno_roi
-             annotated_seg_path = os.path.join(out_dir, f"segmentation_annotated_{ts}.png")
-             cv2.imwrite(annotated_seg_path, anno_full)
-
-             # --- Optional classification ---
             wound_type = "Unknown"
-             cls_pipe = self.models_cache.get("cls")
-             if cls_pipe is not None:
                  try:
-                     preds = cls_pipe(Image.fromarray(cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)))
                      if preds:
                          wound_type = max(preds, key=lambda x: x.get("score", 0)).get("label", "Unknown")
                  except Exception as e:
-                     logging.warning(f"Classification failed: {e}")
-
-             # Log end-of-seg summary
-             seg_summary = {
-                 "seg_used": seg_debug.get("used"),
-                 "seg_reason": seg_debug.get("reason"),
-                 "positive_fraction": round(float(seg_debug.get("positive_fraction", 0.0)), 6),
-                 "threshold": seg_debug.get("thr"),
-                 "segmentation_empty": segmentation_empty,
-                 "exif_px_per_cm": round(px_per_cm, 3),
-             }
-             _log_kv("SEG_SUMMARY", seg_summary)
 
              return {
                  "wound_type": wound_type,
-                 "length_cm": length_cm,
-                 "breadth_cm": breadth_cm,
-                 "surface_area_cm2": surface_area_cm2,
-                 "px_per_cm": round(px_per_cm, 2),
-                 "calibration_meta": exif_meta,
-                 "detection_confidence": float(results[0].boxes.conf[0].cpu().item())
-                 if getattr(results[0].boxes, "conf", None) is not None else 0.0,
-                 "detection_image_path": detection_path,
-                 "segmentation_image_path": annotated_seg_path,
-                 "segmentation_annotated_path": annotated_seg_path,
-                 "segmentation_roi_path": segmentation_roi_path,
-                 "roi_mask_path": roi_mask_path,
-                 "segmentation_empty": segmentation_empty,
-                 "segmentation_debug": seg_debug,
-                 "original_image_path": original_path,
              }
 
          except Exception as e:
-             logging.error(f"Visual analysis failed: {e}", exc_info=True)
-             raise
 
-     # ---------- Knowledge base + reporting ----------
-     def query_guidelines(self, query: str) -> str:
          try:
-             vs = self.knowledge_base_cache.get("vector_store")
-             if not vs:
-                 return "Knowledge base is not available."
-             retriever = vs.as_retriever(search_kwargs={"k": 5})
-             # LangChain deprecation fix: use invoke()
-             docs = retriever.invoke(query)
-             lines: List[str] = []
              for d in docs:
-                 src = (d.metadata or {}).get("source", "N/A")
-                 txt = (d.page_content or "")[:300]
-                 lines.append(f"Source: {src}\nContent: {txt}...")
-             return "\n\n".join(lines) if lines else "No relevant guideline snippets found."
          except Exception as e:
-             logging.warning(f"Guidelines query failed: {e}")
-             return f"Guidelines query failed: {str(e)}"
 
-     def _generate_fallback_report(self, patient_info: str, visual_results: Dict, guideline_context: str) -> str:
-         return f"""# 🩺 SmartHeal AI - Comprehensive Wound Analysis Report
 
- ## 📋 Patient Information
  {patient_info}
 
- ## 🔍 Visual Analysis Results
- - **Wound Type**: {visual_results.get('wound_type', 'Unknown')}
- - **Dimensions**: {visual_results.get('length_cm', 0)} cm × {visual_results.get('breadth_cm', 0)} cm
- - **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cm²
- - **Detection Confidence**: {visual_results.get('detection_confidence', 0):.1%}
- - **Calibration**: {visual_results.get('px_per_cm','?')} px/cm ({(visual_results.get('calibration_meta') or {}).get('used','default')})
-
- ## 📊 Analysis Images
- - **Original**: {visual_results.get('original_image_path', 'N/A')}
- - **Detection**: {visual_results.get('detection_image_path', 'N/A')}
- - **Segmentation**: {visual_results.get('segmentation_image_path', 'N/A')}
- - **Annotated**: {visual_results.get('segmentation_annotated_path', 'N/A')}
-
- ## 🎯 Clinical Summary
- Automated analysis provides quantitative measurements; verify via clinical examination.
-
- ## 💊 Recommendations
- - Cleanse wound gently; select dressing per exudate/infection risk
- - Debride necrotic tissue if indicated (clinical decision)
- - Document with serial photos and measurements
-
- ## 📅 Monitoring
- - Daily in week 1, then every 2–3 days (or as indicated)
- - Weekly progress review
-
- ## 📚 Guideline Context
- {(guideline_context or '')[:800]}{"..." if guideline_context and len(guideline_context) > 800 else ''}
-
- **Disclaimer:** Automated, for decision support only. Verify clinically.
- """
-
-     def generate_final_report(
-         self,
-         patient_info: str,
-         visual_results: Dict,
-         guideline_context: str,
-         image_pil: Image.Image,
-         max_new_tokens: Optional[int] = None,
-     ) -> str:
-         try:
-             # call the preserved function name (now backed by Qwen2-VL)
-             report = generate_medgemma_report(
-                 patient_info, visual_results, guideline_context, image_pil, max_new_tokens
              )
-             if report and report.strip() and not report.startswith(("⚠️", "❌")):
-                 return report
-             logging.warning("VLM unavailable/invalid; using fallback.")
              return self._generate_fallback_report(patient_info, visual_results, guideline_context)
          except Exception as e:
-             logging.error(f"Report generation failed: {e}")
              return self._generate_fallback_report(patient_info, visual_results, guideline_context)
 
-     def save_and_commit_image(self, image_pil: Image.Image) -> str:
          try:
-             os.makedirs(self.uploads_dir, exist_ok=True)
-             ts = datetime.now().strftime("%Y%m%d_%H%M%S")
-             filename = f"{ts}.png"
-             path = os.path.join(self.uploads_dir, filename)
-             image_pil.convert("RGB").save(path)
-             logging.info(f"✅ Image saved locally: {path}")
 
-             if HF_TOKEN and DATASET_ID:
                  try:
-                     HfApi, HfFolder = _import_hf_hub()
-                     HfFolder.save_token(HF_TOKEN)
                      api = HfApi()
                      api.upload_file(
-                         path_or_fileobj=path,
                          path_in_repo=f"images/{filename}",
-                         repo_id=DATASET_ID,
                          repo_type="dataset",
-                         token=HF_TOKEN,
                          commit_message=f"Upload wound image: {filename}",
                      )
                      logging.info("✅ Image committed to HF dataset")
                  except Exception as e:
                      logging.warning(f"HF upload failed: {e}")
 
-             return path
          except Exception as e:
-             logging.error(f"Failed to save/commit image: {e}")
-             return ""
 
-     def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: Dict) -> Dict:
          try:
-             saved_path = self.save_and_commit_image(image_pil)
-             visual_results = self.perform_visual_analysis(image_pil)
-
              pi = questionnaire_data or {}
-             patient_info = (
-                 f"Age: {pi.get('age','N/A')}, "
-                 f"Diabetic: {pi.get('diabetic','N/A')}, "
-                 f"Allergies: {pi.get('allergies','N/A')}, "
-                 f"Date of Wound: {pi.get('date_of_injury','N/A')}, "
-                 f"Professional Care: {pi.get('professional_care','N/A')}, "
-                 f"Oozing/Bleeding: {pi.get('oozing_bleeding','N/A')}, "
-                 f"Infection: {pi.get('infection','N/A')}, "
-                 f"Moisture: {pi.get('moisture','N/A')}"
-             )
-
-             query = (
-                 f"best practices for managing a {visual_results.get('wound_type','Unknown')} "
-                 f"with moisture '{pi.get('moisture','unknown')}' and infection '{pi.get('infection','unknown')}' "
-                 f"in a diabetic status '{pi.get('diabetic','unknown')}'"
-             )
              guideline_context = self.query_guidelines(query)
 
-             report = self.generate_final_report(patient_info, visual_results, guideline_context, image_pil)
 
              return {
                  "success": True,
                  "visual_analysis": visual_results,
-                 "report": report,
                  "saved_image_path": saved_path,
-                 "guideline_context": (guideline_context or "")[:500] + (
-                     "..." if guideline_context and len(guideline_context) > 500 else ""
-                 ),
              }
 
          except Exception as e:
-             logging.error(f"Pipeline error: {e}")
              return {
                  "success": False,
                  "error": str(e),
-                 "visual_analysis": {},
-                 "report": f"Analysis failed: {str(e)}",
-                 "saved_image_path": None,
-                 "guideline_context": "",
              }
 
-     def analyze_wound(self, image, questionnaire_data: Dict) -> Dict:
          try:
              if isinstance(image, str):
-                 if not os.path.exists(image):
-                     raise ValueError(f"Image file not found: {image}")
-                 image_pil = Image.open(image)
-             elif isinstance(image, Image.Image):
-                 image_pil = image
-             elif isinstance(image, np.ndarray):
-                 image_pil = Image.fromarray(image)
             else:
-                 raise ValueError(f"Unsupported image type: {type(image)}")
 
-             return self.full_analysis_pipeline(image_pil, questionnaire_data or {})
          except Exception as e:
-             logging.error(f"Wound analysis error: {e}")
              return {
-                 "success": False,
-                 "error": str(e),
-                 "visual_analysis": {},
-                 "report": f"Analysis initialization failed: {str(e)}",
-                 "saved_image_path": None,
-                 "guideline_context": "",
-             }
+ # src/ai_processor.py
  import os
  import logging
  import cv2
  import numpy as np
  from PIL import Image
+ from datetime import datetime
 
+ # ---- Safe env defaults (do NOT init CUDA in main) ----
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+ os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")  # mask GPU in main
+ os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
 
+ import json
 
+ # Light imports that won't trigger CUDA
+ from transformers import pipeline
+ from ultralytics import YOLO
 
+ # TensorFlow: keep on CPU in main process
+ import tensorflow as tf
+ try:
+     tf.config.set_visible_devices([], "GPU")
+ except Exception:
+     pass
 
+ # LangChain bits (match your old code; no function name change)
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.vectorstores import FAISS
 
+ from huggingface_hub import HfApi, HfFolder
 
+ # Spaces (ZeroGPU)
+ try:
+     import spaces
+     SPACES_AVAILABLE = True
+ except Exception:
+     spaces = None
+     SPACES_AVAILABLE = False
 
+ from .config import Config
 
+ # ----------------------------- utils -----------------------------
+ def _largest_component(mask01: np.ndarray, min_area_px: int = 50) -> np.ndarray:
+     num, labels, stats, _ = cv2.connectedComponentsWithStats(mask01.astype(np.uint8), 8)
+     if num <= 1:
+         return (mask01 > 0).astype(np.uint8)
+     areas = stats[1:, cv2.CC_STAT_AREA]
+     if areas.size == 0 or areas.max() < min_area_px:
+         return (mask01 > 0).astype(np.uint8)
+     idx = 1 + int(np.argmax(areas))
+     return (labels == idx).astype(np.uint8)
 
  def _fill_holes(mask01: np.ndarray) -> np.ndarray:
      h, w = mask01.shape[:2]
      return out.astype(np.uint8)
 
  def _clean_mask(mask01: np.ndarray) -> np.ndarray:
      mask01 = (mask01 > 0).astype(np.uint8)
      k3 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
      k5 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
      mask01 = cv2.morphologyEx(mask01, cv2.MORPH_OPEN, k3, iterations=1)
      mask01 = cv2.morphologyEx(mask01, cv2.MORPH_CLOSE, k5, iterations=1)
      mask01 = _fill_holes(mask01)
+     mask01 = _largest_component(mask01)
      return (mask01 > 0).astype(np.uint8)
 
+ def _grabcut_refine(bgr: np.ndarray, seed01: np.ndarray, iters: int = 3) -> np.ndarray:
+     h, w = bgr.shape[:2]
+     gc = np.full((h, w), cv2.GC_PR_BGD, np.uint8)
+     k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
+     seed_dil = cv2.dilate(seed01, k, iterations=1)
+     gc[seed01.astype(bool)] = cv2.GC_PR_FGD
+     gc[seed_dil.astype(bool)] = cv2.GC_FGD
+     gc[0, :], gc[-1, :], gc[:, 0], gc[:, -1] = cv2.GC_BGD, cv2.GC_BGD, cv2.GC_BGD, cv2.GC_BGD
+     bgdModel = np.zeros((1, 65), np.float64)
+     fgdModel = np.zeros((1, 65), np.float64)
+     cv2.grabCut(bgr, gc, None, bgdModel, fgdModel, iters, cv2.GC_INIT_WITH_MASK)
+     return np.where((gc == cv2.GC_FGD) | (gc == cv2.GC_PR_FGD), 1, 0).astype(np.uint8)
 
+ def _fallback_segment(roi_bgr: np.ndarray) -> np.ndarray:
+     """Robust OpenCV fallback: Lab 2-cluster (maximize a*), then GrabCut grow + cleanup."""
+     Z = roi_bgr.reshape((-1, 3)).astype(np.float32)
      criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
      _, labels, centers = cv2.kmeans(Z, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
      centers_u8 = centers.astype(np.uint8).reshape(1, 2, 3)
      centers_lab = cv2.cvtColor(centers_u8, cv2.COLOR_BGR2LAB)[0]
+     wound_idx = int(np.argmax(centers_lab[:, 1]))  # reddest cluster
+     seed01 = (labels.reshape(roi_bgr.shape[:2]) == wound_idx).astype(np.uint8)
+     gc01 = _grabcut_refine(roi_bgr, seed01, iters=3)
+     return _clean_mask(gc01)
+
+ def _safe_load_seg_model(path: str):
+     """Try multiple loaders to survive Keras 3 / TF 2.15 / h5 mismatches."""
+     if not os.path.exists(path):
+         return None
+     try:
+         # Keras legacy API (present in TF 2.13+ with legacy shim)
+         from tensorflow import keras as tfk
+         if hasattr(tfk, "saving") and hasattr(tfk.saving, "legacy"):
+             return tfk.saving.legacy.load_model(path, compile=False)
+     except Exception:
+         pass
+     try:
+         # tf.keras standard loader
+         from tensorflow.keras.models import load_model as tf_load_model
+         return tf_load_model(path, compile=False)
+     except Exception as e:
+         logging.warning(f"Segmentation model failed to load with legacy + tf.keras: {e}")
+         return None
 
+ # ----------------------------- main class -----------------------------
122
  class AIProcessor:
123
  def __init__(self):
124
+ self.models_cache = {}
125
+ self.knowledge_base_cache = {}
126
+ self.config = Config()
127
+ self.px_per_cm = self.config.PIXELS_PER_CM
128
+ self._initialize_models()
129
+
130
+ def _initialize_models(self):
131
+ """Initialize all AI models except GPU VLM (that one loads inside the GPU worker)."""
 
 
 
 
 
 
 
 
132
  try:
133
+ # HF token
134
+ if self.config.HF_TOKEN:
135
+ HfFolder.save_token(self.config.HF_TOKEN)
136
+ logging.info(" HF token set")
137
+
138
+ # YOLO (CPU-only in main)
139
+ try:
140
+ self.models_cache["det"] = YOLO(self.config.YOLO_MODEL_PATH)
141
+ logging.info("✅ YOLO loaded (CPU; CUDA masked in main)")
142
+ except Exception as e:
143
+ logging.warning(f"YOLO not available: {e}")
144
+
145
+ # Segmentation (safe loader, stays on CPU)
146
+ try:
147
+ seg = _safe_load_seg_model(self.config.SEG_MODEL_PATH)
148
+ if seg is None:
149
+ raise RuntimeError("segmentation file missing or incompatible")
150
+ self.models_cache["seg"] = seg
151
+ logging.info("✅ Segmentation model loaded (CPU)")
152
+ except Exception as e:
153
+ self.models_cache["seg"] = None
154
+ logging.warning(f"Segmentation unavailable: {e}")
155
+
156
+ # Wound classifier (CPU)
157
+ try:
158
+ self.models_cache["cls"] = pipeline(
159
+ "image-classification",
160
+ model="Hemg/Wound-classification",
161
+ token=self.config.HF_TOKEN,
162
+ device="cpu",
163
+ )
164
+ logging.info("✅ Classifier loaded (CPU)")
165
+ except Exception as e:
166
+ self.models_cache["cls"] = None
167
+ logging.warning(f"Classifier unavailable: {e}")
168
+
169
+ # Embeddings for KB (CPU)
170
+ try:
171
+ self.models_cache["embedding_model"] = HuggingFaceEmbeddings(
172
+ model_name="sentence-transformers/all-MiniLM-L6-v2",
173
+ model_kwargs={"device": "cpu"},
174
+ )
175
+ logging.info("✅ Embeddings loaded (CPU)")
176
+ except Exception as e:
177
+ self.models_cache["embedding_model"] = None
178
+ logging.warning(f"Embeddings unavailable: {e}")
179
+
180
+ self._load_knowledge_base()
181
+ except Exception as e:
182
+ logging.error(f"Error initializing models: {e}")
183
 
184
+ def _load_knowledge_base(self):
185
+ """Load guideline PDFs into a FAISS vector store."""
186
+ try:
187
+ documents = []
188
+ for pdf_path in self.config.GUIDELINE_PDFS:
189
+ if os.path.exists(pdf_path):
190
+ try:
191
+ loader = PyPDFLoader(pdf_path)
192
+ docs = loader.load()
193
+ documents.extend(docs)
194
+ logging.info(f"Loaded PDF: {pdf_path}")
195
+ except Exception as e:
196
+ logging.warning(f"PDF load failed ({pdf_path}): {e}")
197
+
198
+ if documents and self.models_cache.get("embedding_model"):
199
+ splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
200
+ chunks = splitter.split_documents(documents)
201
+ vectorstore = FAISS.from_documents(chunks, self.models_cache["embedding_model"])
202
+ self.knowledge_base_cache["vectorstore"] = vectorstore
203
+ logging.info(f"✅ Knowledge base ready ({len(chunks)} chunks)")
204
+ else:
205
+ self.knowledge_base_cache["vectorstore"] = None
206
+ logging.warning("Knowledge base not available (no PDFs or embeddings).")
207
+ except Exception as e:
208
+ logging.warning(f"Knowledge base loading error: {e}")
209
+ self.knowledge_base_cache["vectorstore"] = None
210
+
211
+ # ------------------------ vision core ------------------------
212
+ def perform_visual_analysis(self, image_pil):
213
+ """Perform comprehensive visual analysis of wound image."""
214
+ try:
215
  image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)
216
 
217
+ # YOLO detection
218
+ if "det" not in self.models_cache or self.models_cache["det"] is None:
219
+ raise ValueError("YOLO detection model not available.")
220
+ results = self.models_cache["det"].predict(image_cv, verbose=False, device="cpu")
 
221
  if (not results) or (not getattr(results[0], "boxes", None)) or (len(results[0].boxes) == 0):
222
+ raise ValueError("No wound detected in the image.")
 
 
 
 
223
 
224
+ box = results[0].boxes.xyxy[0].cpu().numpy().astype(int)
225
  x1, y1, x2, y2 = [int(v) for v in box]
226
  x1, y1 = max(0, x1), max(0, y1)
227
  x2, y2 = min(image_cv.shape[1], x2), min(image_cv.shape[0], y2)
228
+ region_cv = image_cv[y1:y2, x1:x2].copy()
 
 
 
 
 
 
229
 
230
+ # Save detection vis
231
+ os.makedirs(os.path.join(self.config.UPLOADS_DIR, "analysis"), exist_ok=True)
232
  ts = datetime.now().strftime("%Y%m%d_%H%M%S")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  det_vis = image_cv.copy()
234
  cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
235
+ detection_image_path = os.path.join(self.config.UPLOADS_DIR, "analysis", f"detection_{ts}.png")
236
+ cv2.imwrite(detection_image_path, det_vis)
237
+ detection_image_pil = Image.fromarray(cv2.cvtColor(det_vis, cv2.COLOR_BGR2RGB))
238
+
239
+ # --- segmentation ---
240
+ length = breadth = area = 0.0
241
+ segmentation_image_pil = None
242
+ segmentation_image_path = None
243
+
244
+ mask01 = None
245
+ seg_model = self.models_cache.get("seg", None)
246
+ if seg_model is not None:
247
+ try:
248
+ ishape = getattr(seg_model, "input_shape", None)
249
+ th, tw = int(ishape[1]), int(ishape[2]) if ishape and len(ishape) >= 3 else (224, 224)
250
+ resized = cv2.resize(region_cv, (tw, th), interpolation=cv2.INTER_LINEAR)
251
+ x = np.expand_dims(resized.astype(np.float32) / 255.0, 0)
252
+ pred = seg_model.predict(x, verbose=0)
253
+ if isinstance(pred, (list, tuple)):
254
+ pred = pred[0]
255
+ p = np.squeeze(pred)
256
+ # sigmoid if raw
257
+ if p.max() > 1.0 or p.min() < 0.0:
258
+ p = 1.0 / (1.0 + np.exp(-p))
259
+ p = cv2.resize(p.astype(np.float32), (region_cv.shape[1], region_cv.shape[0]), interpolation=cv2.INTER_LINEAR)
260
+ # adaptive threshold
261
+ p255 = (np.clip(p, 0, 1) * 255).astype(np.uint8)
262
+ thr_val, _ = cv2.threshold(p255, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
263
+ thr = float(np.clip(thr_val / 255.0, 0.25, 0.65))
264
+ seed01 = (p >= thr).astype(np.uint8)
265
+ if seed01.sum() == 0:
266
+ seed01 = (p >= max(thr - 0.1, 0.15)).astype(np.uint8)
267
+ gc01 = _grabcut_refine(region_cv, seed01, iters=3)
268
+ mask01 = _clean_mask(gc01)
269
+ except Exception as e:
270
+ logging.warning(f"Segmentation model failed; using OpenCV fallback: {e}")
271
+ mask01 = _fallback_segment(region_cv)
272
  else:
273
+ mask01 = _fallback_segment(region_cv)
274
+
275
+ # overlay + measurements
276
+ overlay = region_cv.copy()
277
+ red = overlay.copy(); red[:] = (0, 0, 255)
278
+ if mask01 is not None and mask01.any():
279
+ mask255 = (mask01 * 255).astype(np.uint8)
280
+ mask3 = cv2.merge([mask255, mask255, mask255])
281
+ tinted = cv2.addWeighted(region_cv, 0.45, red, 0.55, 0)
282
+ overlay = np.where(mask3 > 0, tinted, region_cv)
283
+ cnts, _ = cv2.findContours(mask255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
284
+ if cnts:
285
+ cnt = max(cnts, key=cv2.contourArea)
286
+ x, y, w, h = cv2.boundingRect(cnt)
287
+ length = round(h / self.px_per_cm, 2)
288
+ breadth = round(w / self.px_per_cm, 2)
289
+ area = round(cv2.contourArea(cnt) / (self.px_per_cm ** 2), 2)
290
+ cv2.drawContours(overlay, [cnt], -1, (255, 255, 255), 2)
291
+
292
+ segmentation_image_path = os.path.join(self.config.UPLOADS_DIR, "analysis", f"segmentation_{ts}.png")
293
  seg_full = image_cv.copy()
294
+ seg_full[y1:y2, x1:x2] = overlay
295
+ cv2.imwrite(segmentation_image_path, seg_full)
296
+ segmentation_image_pil = Image.fromarray(cv2.cvtColor(seg_full, cv2.COLOR_BGR2RGB))
297
 
298
+ # classification
 
 
 
 
 
 
 
 
 
299
  wound_type = "Unknown"
300
+ if self.models_cache.get("cls") is not None:
 
301
  try:
302
+ region_pil = Image.fromarray(cv2.cvtColor(region_cv, cv2.COLOR_BGR2RGB))
303
+ preds = self.models_cache["cls"](region_pil)
304
  if preds:
305
  wound_type = max(preds, key=lambda x: x.get("score", 0)).get("label", "Unknown")
306
  except Exception as e:
307
+ logging.warning(f"Wound classification error: {e}")
308
+
309
+ conf = 0.0
310
+ try:
311
+ conf = float(results[0].boxes.conf[0].cpu().item())
312
+ except Exception:
313
+ pass
 
 
 
 
 
314
 
315
  return {
316
  "wound_type": wound_type,
317
+ "length_cm": float(length),
318
+ "breadth_cm": float(breadth),
319
+ "surface_area_cm2": float(area),
320
+ "detection_confidence": conf,
321
+ "bounding_box": [int(x1), int(y1), int(x2), int(y2)],
322
+ "detection_image_path": detection_image_path,
323
+ "detection_image_pil": detection_image_pil,
324
+ "segmentation_image_path": segmentation_image_path,
325
+ "segmentation_image_pil": segmentation_image_pil,
 
 
 
 
 
 
326
  }
327
+
328
  except Exception as e:
329
+ logging.error(f"Visual analysis error: {e}", exc_info=True)
330
+ raise ValueError(f"Visual analysis failed: {str(e)}")
 
+     # ------------------------ KB / RAG ------------------------
+     def query_guidelines(self, query: str):
+         """Query the knowledge base for relevant guidelines."""
          try:
+             vector_store = self.knowledge_base_cache.get("vectorstore")
+             if not vector_store:
+                 return "Knowledge base unavailable - clinical guidelines not loaded"
+ 
+             retriever = vector_store.as_retriever(search_kwargs={"k": 10})
+             try:
+                 docs = retriever.invoke(query)
+             except Exception:
+                 # fallback for the older LangChain retriever API
+                 docs = retriever.get_relevant_documents(query)
+ 
+             if not docs:
+                 return "No relevant guidelines found for the query"
+ 
+             out = []
              for d in docs:
+                 meta = d.metadata or {}
+                 src = meta.get("source", "Unknown")
+                 page = meta.get("page", "N/A")
+                 content = (d.page_content or "").strip()
+                 out.append(f"Source: {src}, Page: {page}\nContent: {content}")
+             return "\n\n".join(out)
          except Exception as e:
+             logging.error(f"Guidelines query error: {e}")
+             return f"Error querying guidelines: {str(e)}"
 
+     # ------------------------ Reporting (VLM + fallback) ------------------------
+     def generate_final_report(self, patient_info, visual_results, guideline_context, image_pil, max_new_tokens=None):
+         """Generate the final report with the VLM when available (it is loaded inside the GPU worker); otherwise use the template fallback."""
+         try:
+             # If no VLM pipeline was cached by the GPU worker, fall back right away.
+             if "medgemma_pipe" not in self.models_cache or self.models_cache["medgemma_pipe"] is None:
+                 return self._generate_fallback_report(patient_info, visual_results, guideline_context)
+ 
+             max_tokens = max_new_tokens or self.config.MAX_NEW_TOKENS
+             detection_path = visual_results.get("detection_image_path", "")
+             segmentation_path = visual_results.get("segmentation_image_path", "")
 
+             prompt = f"""
+ # Wound Care Report
 
+ ## Patient Information
  {patient_info}
 
+ ## Visual Analysis Summary
+ - Wound Type: {visual_results.get('wound_type', 'Unknown')}
+ - Length: {visual_results.get('length_cm', 0)} cm
+ - Breadth: {visual_results.get('breadth_cm', 0)} cm
+ - Surface Area: {visual_results.get('surface_area_cm2', 0)} cm²
+ - Detection Confidence: {visual_results.get('detection_confidence', 0):.2f}
+ 
+ ## Clinical Reference
+ {guideline_context}
+ 
+ You are SmartHeal-AI Agent, a world-class wound care AI specialist trained in clinical wound assessment and guideline-based treatment planning.
+ Generate a concise, actionable, evidence-based report with: Clinical Summary, Dressing/Medication Recommendations, Key Risk Factors, and Prognosis & Monitoring. Avoid generic advice; tailor everything to the data above.
+ """.strip()
+ 
+             content_list = [{"type": "text", "text": prompt}]
+             if image_pil:
+                 content_list.insert(0, {"type": "image", "image": image_pil})
+             if visual_results.get("detection_image_pil"):
+                 content_list.append({"type": "image", "image": visual_results["detection_image_pil"]})
+             if visual_results.get("segmentation_image_pil"):
+                 content_list.append({"type": "image", "image": visual_results["segmentation_image_pil"]})
+ 
+             messages = [
+                 {
+                     "role": "system",
+                     "content": [{"type": "text", "text": "You are a medical AI assistant specializing in wound care. Be precise, objective, and recommendation-focused."}],
+                 },
+                 {
+                     "role": "user",
+                     "content": content_list,
+                 },
+             ]
+ 
+             out = self.models_cache["medgemma_pipe"](
+                 text=messages,
+                 max_new_tokens=int(max_tokens),
+                 do_sample=False,
              )
+             generated = ""
+             try:
+                 # chat-style pipelines return the reply as the last message in "generated_text"
+                 generated = (out[0]["generated_text"][-1].get("content", "") or "").strip()
+             except Exception:
+                 # plain-text pipelines return a single string instead
+                 generated = (out[0].get("generated_text", "") or "").strip()
+ 
+             if generated:
+                 images_sec = f"\n\n## Analysis Images\n- Detection: {detection_path}\n- Segmentation: {segmentation_path}\n"
+                 return images_sec + generated
+ 
              return self._generate_fallback_report(patient_info, visual_results, guideline_context)
          except Exception as e:
+             logging.error(f"VLM report generation error: {e}")
              return self._generate_fallback_report(patient_info, visual_results, guideline_context)
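# --- Editor's sketch (illustrative; not part of this commit) ---
# Shape of the multimodal chat payload consumed above: transformers
# "image-text-to-text" pipelines accept OpenAI-style messages whose "content"
# is a list of typed parts, so text and several images can be interleaved.
from PIL import Image

img = Image.new("RGB", (64, 64))  # stand-in for the wound photo
messages_example = [
    {"role": "system", "content": [{"type": "text", "text": "Be precise."}]},
    {"role": "user", "content": [
        {"type": "image", "image": img},
        {"type": "text", "text": "Describe the wound."},
    ]},
]
# Depending on the model, pipe(text=messages_example, ...) returns either the
# full chat transcript (a list of role/content dicts under "generated_text")
# or a plain string -- which is why the parsing above tries
# out[0]["generated_text"][-1]["content"] first, then falls back to a string.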
 
+     def _generate_fallback_report(self, patient_info, visual_results, guideline_context):
+         detection_path = visual_results.get("detection_image_path", "Not available")
+         segmentation_path = visual_results.get("segmentation_image_path", "Not available")
+         return f"""
+ # Wound Analysis Report
+ ## Patient Information
+ {patient_info}
+ 
+ ## Visual Analysis Results
+ - **Wound Type**: {visual_results.get('wound_type', 'Unknown')}
+ - **Dimensions**: {visual_results.get('length_cm', 0)} cm × {visual_results.get('breadth_cm', 0)} cm
+ - **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cm²
+ - **Detection Confidence**: {visual_results.get('detection_confidence', 0):.2f}
+ 
+ ## Analysis Images
+ - **Detection Image**: {detection_path}
+ - **Segmentation Image**: {segmentation_path}
+ 
+ ## Assessment
+ Automated measurements provided. Verify via clinical exam.
+ 
+ ## Recommendations
+ - Cleanse wound; choose dressing per moisture/infection risk
+ - Consider debridement if indicated
+ - Document with serial photos & measurements
+ 
+ ## Clinical Guidelines
+ {(guideline_context or '')[:500]}...
+ 
+ *Note: Decision support only; not a diagnosis.*
+ """.strip()
+ 
+     # ------------------------ I/O ------------------------
+     def save_and_commit_image(self, image_pil):
+         """Save the image locally and optionally upload it to the HuggingFace dataset."""
          try:
+             os.makedirs(self.config.UPLOADS_DIR, exist_ok=True)
+             filename = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
+             local_path = os.path.join(self.config.UPLOADS_DIR, filename)
+             image_pil.convert("RGB").save(local_path)
+             logging.info(f"✅ Image saved locally: {local_path}")
 
+             if self.config.HF_TOKEN and self.config.DATASET_ID:
                  try:
+                     HfFolder.save_token(self.config.HF_TOKEN)
                      api = HfApi()
                      api.upload_file(
+                         path_or_fileobj=local_path,
                          path_in_repo=f"images/{filename}",
+                         repo_id=self.config.DATASET_ID,
                          repo_type="dataset",
                          commit_message=f"Upload wound image: {filename}",
                      )
                      logging.info("✅ Image committed to HF dataset")
                  except Exception as e:
                      logging.warning(f"HF upload failed: {e}")
 
+             return local_path
          except Exception as e:
+             logging.error(f"Image saving error: {e}")
+             return None
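# --- Editor's sketch (illustrative; not part of this commit) ---
# HfFolder.save_token writes the token to shared on-disk state; huggingface_hub
# also accepts the token per client, which avoids mutating global state:
from huggingface_hub import HfApi

def upload_image(local_path: str, filename: str, repo_id: str, token: str):
    api = HfApi(token=token)  # token scoped to this client only
    api.upload_file(
        path_or_fileobj=local_path,
        path_in_repo=f"images/{filename}",
        repo_id=repo_id,
        repo_type="dataset",
        commit_message=f"Upload wound image: {filename}",
    )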
 
+     # ------------------------ Pipeline (GPU-safe) ------------------------
+     # The decorator expression below is evaluated once at import time (PEP 614);
+     # only the selected branch is applied to the function.
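# --- Editor's sketch (illustrative; not part of this commit) ---
# Since PEP 614 (Python 3.9), any expression may follow "@", so a decorator
# can be chosen at import time. Stand-alone demo with hypothetical stand-ins:
HAS_GPU = False  # plays the role of SPACES_AVAILABLE

def fake_gpu(duration=120):  # plays the role of spaces.GPU
    def deco(f):
        f.gpu_duration = duration
        return f
    return deco

@fake_gpu(duration=120) if HAS_GPU else (lambda f: f)
def work():
    return "ok"

print(work(), getattr(work, "gpu_duration", None))  # -> ok None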
+     @spaces.GPU(enable_queue=True, duration=120) if SPACES_AVAILABLE else (lambda f: f)
+     def full_analysis_pipeline(self, image, questionnaire_data):
+         """Complete analysis pipeline. The VLM loads here (inside the GPU worker) to avoid CUDA init in the main process."""
          try:
+             # Build the VLM inside the worker; prefer a small VLM (Qwen2-VL 2B)
+             # for Spaces, and if the ZeroGPU path fails, retry on CPU.
+             if "medgemma_pipe" not in self.models_cache or self.models_cache["medgemma_pipe"] is None:
+                 try:
+                     self.models_cache["medgemma_pipe"] = pipeline(
+                         "image-text-to-text",
+                         model=os.environ.get("SMARTHEAL_VLM", "Qwen/Qwen2-VL-2B-Instruct"),
+                         token=self.config.HF_TOKEN,
+                         device_map="cuda",  # we're inside the ZeroGPU worker now
+                         torch_dtype="auto",
+                         max_new_tokens=self.config.MAX_NEW_TOKENS,
+                         trust_remote_code=True,
+                     )
+                     logging.info("✅ VLM loaded on GPU worker")
+                 except Exception as e:
+                     logging.warning(f"GPU VLM failed; falling back to CPU: {e}")
+                     try:
+                         self.models_cache["medgemma_pipe"] = pipeline(
+                             "image-text-to-text",
+                             model=os.environ.get("SMARTHEAL_VLM", "Qwen/Qwen2-VL-2B-Instruct"),
+                             token=self.config.HF_TOKEN,
+                             device_map="cpu",
+                             torch_dtype="auto",
+                             max_new_tokens=self.config.MAX_NEW_TOKENS,
+                             trust_remote_code=True,
+                         )
+                         logging.info("✅ VLM loaded on CPU")
+                     except Exception as e2:
+                         self.models_cache["medgemma_pipe"] = None
+                         logging.error(f"❌ Could not load any VLM: {e2}")
+ 
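# --- Editor's sketch (illustrative; not part of this commit) ---
# The two pipeline() calls above differ only in device_map; a small helper
# (hypothetical name) removes the duplication:
import logging
from transformers import pipeline

def load_vlm_with_fallback(model_id, token, max_new_tokens):
    for device_map in ("cuda", "cpu"):
        try:
            return pipeline(
                "image-text-to-text",
                model=model_id,
                token=token,
                device_map=device_map,
                torch_dtype="auto",
                max_new_tokens=max_new_tokens,
                trust_remote_code=True,
            )
        except Exception as e:
            logging.warning(f"VLM load failed on {device_map}: {e}")
    return None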
+             # Save image
+             saved_path = self.save_and_commit_image(image)
+ 
+             # Visual analysis
+             visual_results = self.perform_visual_analysis(image)
+ 
+             # Patient info string
              pi = questionnaire_data or {}
+             patient_info = ", ".join(f"{k}: {v}" for k, v in pi.items() if str(v).strip() != "")
+ 
+             # KB query
+             wound_type = visual_results.get("wound_type", "wound")
+             moisture = pi.get("moisture", "unknown")
+             infection = pi.get("infection", "unknown")
+             diabetic = pi.get("diabetic", "unknown")
+             query = f"best practices for managing a {wound_type} with moisture level '{moisture}' and signs of infection '{infection}' in a patient who is diabetic '{diabetic}'"
              guideline_context = self.query_guidelines(query)
 
+             # Report
+             final_report = self.generate_final_report(patient_info, visual_results, guideline_context, image)
 
              return {
                  "success": True,
                  "visual_analysis": visual_results,
+                 "report": final_report,
                  "saved_image_path": saved_path,
+                 "timestamp": datetime.now().isoformat(),
              }
+ 
          except Exception as e:
+             logging.error(f"Full analysis pipeline error: {e}", exc_info=True)
              return {
                  "success": False,
                  "error": str(e),
+                 "timestamp": datetime.now().isoformat(),
              }
 
+     # ------------------------ Legacy API ------------------------
+     def analyze_wound(self, image, questionnaire_data):
+         """Legacy method for backward compatibility."""
          try:
              if isinstance(image, str):
+                 try:
+                     image = Image.open(image)
+                     logging.info("Converted path to PIL Image")
+                 except Exception as e:
+                     logging.error(f"Error opening image: {e}")
+             if not isinstance(image, Image.Image):
+                 # file-like object?
+                 if hasattr(image, "read"):
+                     try:
+                         if hasattr(image, "seek"):
+                             image.seek(0)
+                         image = Image.open(image)
+                     except Exception as e:
+                         logging.error(f"Error reading file-like image: {e}")
+                 # raise only if the image still isn't a PIL Image
+                 if not isinstance(image, Image.Image):
+                     raise ValueError(f"Invalid image format: {type(image)}")
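# --- Editor's sketch (illustrative; not part of this commit) ---
# The path / file-like / PIL coercion above, as a hypothetical helper:
from PIL import Image

def coerce_to_pil(image):
    """Accept a filesystem path, a file-like object, or a PIL Image."""
    if isinstance(image, Image.Image):
        return image
    if isinstance(image, str):
        return Image.open(image)
    if hasattr(image, "read"):
        if hasattr(image, "seek"):
            image.seek(0)
        return Image.open(image)
    raise ValueError(f"Invalid image format: {type(image)}")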
+ 
+             result = self.full_analysis_pipeline(image, questionnaire_data)
+ 
+             if result.get("success"):
+                 return {
+                     "timestamp": result["timestamp"],
+                     "summary": f"Analysis completed for {questionnaire_data.get('patient_name', 'patient')}",
+                     "recommendations": result["report"],
+                     "wound_detection": {
+                         "status": "success",
+                         "detections": [result["visual_analysis"]],
+                         "total_wounds": 1,
+                     },
+                     "segmentation_result": {
+                         "status": "success",
+                         # legacy key name; the value is the surface area in cm², not a percentage
+                         "wound_area_percentage": result["visual_analysis"].get("surface_area_cm2", 0),
+                     },
+                     "risk_assessment": self._assess_risk_legacy(questionnaire_data),
+                     "guideline_recommendations": [result["report"][:200] + "..."],
+                 }
              else:
+                 return {
+                     "timestamp": result["timestamp"],
+                     "summary": f"Analysis failed: {result.get('error', 'unknown')}",
+                     "recommendations": "Please consult with a healthcare professional.",
+                     "wound_detection": {"status": "error", "message": result.get("error", "")},
+                     "segmentation_result": {"status": "error", "message": result.get("error", "")},
+                     "risk_assessment": {"risk_score": 0, "risk_level": "Unknown", "risk_factors": []},
+                     "guideline_recommendations": ["Analysis unavailable due to error"],
+                 }
 
          except Exception as e:
+             logging.error(f"Legacy analyze_wound error: {e}")
              return {
+                 "timestamp": datetime.now().isoformat(),
+                 "summary": f"Analysis error: {str(e)}",
+                 "recommendations": "Please consult with a healthcare professional.",
+                 "wound_detection": {"status": "error", "message": str(e)},
+                 "segmentation_result": {"status": "error", "message": str(e)},
+                 "risk_assessment": {"risk_score": 0, "risk_level": "Unknown", "risk_factors": []},
+                 "guideline_recommendations": ["Analysis unavailable due to error"],
+             }
+ 
+     def _assess_risk_legacy(self, questionnaire_data):
+         """Legacy risk assessment for backward compatibility."""
+         risk_factors = []
+         risk_score = 0
+         try:
+             age = int(questionnaire_data.get("patient_age", 0) or 0)
+             if age > 65:
+                 risk_factors.append("Advanced age (>65)")
+                 risk_score += 2
+             elif age > 50:
+                 risk_factors.append("Older adult (50-65)")
+                 risk_score += 1
+ 
+             duration = str(questionnaire_data.get("wound_duration", "")).lower()
+             if any(t in duration for t in ["month", "months", "year"]):
+                 risk_factors.append("Chronic wound (>4 weeks)")
+                 risk_score += 3
+ 
+             pain_level = int(questionnaire_data.get("pain_level", 0) or 0)
+             if pain_level >= 7:
+                 risk_factors.append("High pain level")
+                 risk_score += 2
+ 
+             medical_history = str(questionnaire_data.get("medical_history", "")).lower()
+             if "diabetes" in medical_history:
+                 risk_factors.append("Diabetes mellitus")
+                 risk_score += 3
+             if "circulation" in medical_history or "vascular" in medical_history:
+                 risk_factors.append("Vascular/circulation issues")
+                 risk_score += 2
+             if "immune" in medical_history:
+                 risk_factors.append("Immune system compromise")
+                 risk_score += 2
+ 
+             risk_level = "High" if risk_score >= 7 else ("Moderate" if risk_score >= 4 else "Low")
+             return {"risk_score": risk_score, "risk_level": risk_level, "risk_factors": risk_factors}
+         except Exception as e:
+             logging.error(f"Risk assessment error: {e}")
+             return {"risk_score": 0, "risk_level": "Unknown", "risk_factors": []}