SmartHeal committed
Commit bfd9991 · verified · 1 parent: 75ebbe5

Update src/ai_processor.py

Files changed (1): src/ai_processor.py (+84, -121)
src/ai_processor.py CHANGED
@@ -1,5 +1,6 @@
 # smartheal_ai_processor.py
-# Full, functional module with conditional Spaces GPU support and CPU fallbacks.
+# Full, functional module with an always-present @spaces.GPU function (if `spaces` is importable)
+# and robust CPU fallbacks to avoid crashes when GPU isn't actually available yet.
 
 import os
 import time
@@ -11,32 +12,32 @@ import cv2
 import numpy as np
 from PIL import Image
 
-# =============== LOGGING SETUP ===============
+# =============== LOGGING ===============
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
-# =============== CONFIGURATION ===============
+# =============== CONFIG ===============
 UPLOADS_DIR = "uploads"
 os.makedirs(UPLOADS_DIR, exist_ok=True)
 
 HF_TOKEN = os.getenv("HF_TOKEN", None)
 YOLO_MODEL_PATH = "src/best.pt"
 SEG_MODEL_PATH = "src/segmentation_model.h5"  # optional
 GUIDELINE_PDFS = ["src/eHealth in Wound Care.pdf", "src/IWGDF Guideline.pdf", "src/evaluation.pdf"]
-DATASET_ID = "SmartHeal/wound-image-uploads"  # optional (set HF_TOKEN too)
-PIXELS_PER_CM = 38  # heuristic
+DATASET_ID = "SmartHeal/wound-image-uploads"  # optional (requires HF_TOKEN)
+PIXELS_PER_CM = 38
 
-# =============== GLOBAL CACHES ===============
+# =============== CACHES ===============
 models_cache: Dict[str, object] = {}
 knowledge_base_cache: Dict[str, object] = {}
 
-# ---------- Optional imports guarded ----------
+# =============== Optional imports (lazy) ===============
 def _import_ultralytics():
     from ultralytics import YOLO
     return YOLO
 
 def _import_tf_loader():
     import tensorflow as tf
-    tf.config.set_visible_devices([], "GPU")  # force CPU
+    tf.config.set_visible_devices([], "GPU")  # force CPU for TF
     from tensorflow.keras.models import load_model
     return load_model
 
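PIXELS_PER_CM is a fixed calibration heuristic, and the measurement code that consumes it sits outside this diff. A minimal sketch of how such a constant typically converts pixel measurements to centimetres (function name and values illustrative, not part of the commit):

def box_size_cm(w_px: float, h_px: float, ppcm: float = 38.0) -> tuple:
    # Width and height scale linearly with the calibration constant;
    # area scales with its square. Assumes roughly constant camera distance.
    return w_px / ppcm, h_px / ppcm, (w_px * h_px) / (ppcm ** 2)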
@@ -60,39 +61,33 @@ def _import_hf_hub():
     from huggingface_hub import HfApi, HfFolder
     return HfApi, HfFolder
 
-# =============== SPACES GPU CONDITIONAL ===============
-def _spaces_gpu_available() -> bool:
-    try:
-        import torch
-        return bool(torch.cuda.is_available())
-    except Exception:
-        return False
-
-def _spaces_lib_available() -> bool:
-    try:
-        import spaces  # noqa
-        return True
-    except Exception:
-        return False
-
-HAVE_SPACES_GPU = _spaces_gpu_available() and _spaces_lib_available()
-
-if HAVE_SPACES_GPU:
-    import spaces  # define only if available & GPU present
+# =============== Spaces GPU function (always defined if `spaces` import works) ===============
+try:
+    import spaces
 
     @spaces.GPU(enable_queue=True, duration=90)
     def generate_medgemma_report_with_timeout(
         patient_info: str,
         visual_results: Dict,
         guideline_context: str,
         image_pil: Image.Image,
         max_new_tokens: Optional[int] = None,
     ) -> str:
-        """Runs on Spaces GPU only; callers keep one signature on both paths."""
-        import torch
-        from transformers import pipeline
+        """
+        This function MUST exist at import time so Spaces Zero detects it.
+        It is guarded internally so that if anything fails (no GPU yet, model load error),
+        it returns a warning and the pipeline uses the fallback report.
+        """
         try:
-            torch.cuda.empty_cache()
+            import torch
+            from transformers import pipeline
+
+            # Best-effort cache cleanup; skipped quietly when CUDA is unavailable.
+            try:
+                if hasattr(torch, "cuda") and torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+            except Exception:
+                pass
 
             prompt = f"""
 You are a medical AI assistant. Analyze this wound image and patient data.
@@ -110,21 +105,16 @@ Provide a structured report with:
             pipe = pipeline(
                 "image-text-to-text",
                 model="google/medgemma-4b-it",
-                torch_dtype=torch.bfloat16,
+                torch_dtype=getattr(torch, "bfloat16", None),
                 device_map="auto",
                 token=HF_TOKEN,
                 model_kwargs={"low_cpu_mem_usage": True, "use_cache": True},
             )
 
-            messages = [
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "image", "image": image_pil},
-                        {"type": "text", "text": prompt},
-                    ],
-                }
-            ]
+            messages = [{"role": "user", "content": [
+                {"type": "image", "image": image_pil},
+                {"type": "text", "text": prompt},
+            ]}]
 
             t0 = time.time()
             out = pipe(
@@ -134,10 +124,10 @@ Provide a structured report with:
                 temperature=0.7,
                 pad_token_id=pipe.tokenizer.eos_token_id,
             )
-            logging.info(f"✅ MedGemma completed in {time.time() - t0:.2f}s")
+            logging.info(f"✅ MedGemma finished in {time.time()-t0:.2f}s")
 
             if out and len(out) > 0:
-                # Defensive extraction
+                # Defensive extraction (different transformers versions)
                 try:
                     return out[0]["generated_text"][-1].get("content", "").strip() or "⚠️ Empty response"
                 except Exception:
@@ -145,13 +135,9 @@ Provide a structured report with:
             return "⚠️ No output generated"
         except Exception as e:
             logging.error(f"❌ MedGemma generation error: {e}")
-            return f"❌ Report generation failed: {str(e)}"
-        finally:
-            try:
-                torch.cuda.empty_cache()
-            except Exception:
-                pass
-else:
+            return "⚠️ GPU worker unavailable"
+except Exception:
+    # If `spaces` cannot be imported locally, expose a CPU-safe stub with the same signature.
     def generate_medgemma_report_with_timeout(
         patient_info: str,
         visual_results: Dict,
@@ -159,10 +145,9 @@ else:
         image_pil: Image.Image,
         max_new_tokens: Optional[int] = None,
     ) -> str:
-        """CPU-only path: return a warning so caller uses fallback."""
         return "⚠️ GPU not available"
 
-# =============== MODEL INITIALIZATION (CPU-SAFE) ===============
+# =============== Model init (CPU-safe) ===============
 def load_yolo_model():
     YOLO = _import_ultralytics()
     return YOLO(YOLO_MODEL_PATH)
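The hunks above are the heart of the change: ZeroGPU Spaces detect @spaces.GPU-decorated functions at import time, so the decorated function must always be defined whenever `spaces` imports, with a same-signature CPU stub otherwise. A minimal sketch of the pattern in isolation (names illustrative):

try:
    import spaces  # present on Hugging Face Spaces

    @spaces.GPU(duration=90)
    def heavy_fn(x: str) -> str:
        # The decorator only schedules GPU time; it does not guarantee CUDA
        # is usable, so failures degrade to a warning string.
        try:
            import torch
            if not torch.cuda.is_available():
                return "warning: no CUDA device"
            return f"ok: {x}"
        except Exception as e:
            return f"warning: {e}"
except Exception:
    def heavy_fn(x: str) -> str:
        # Same signature in CPU-only environments, so callers never branch.
        return "warning: GPU runtime unavailable"

Callers treat any warning string as the cue to fall back, which is exactly how generate_final_report consumes generate_medgemma_report_with_timeout below.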
@@ -173,32 +158,25 @@ def load_segmentation_model():
 
 def load_classification_pipeline():
     pipe = _import_hf_cls()
-    return pipe(
-        "image-classification",
-        model="Hemg/Wound-classification",
-        token=HF_TOKEN,
-        device="cpu",
-    )
+    return pipe("image-classification", model="Hemg/Wound-classification", token=HF_TOKEN, device="cpu")
 
 def load_embedding_model():
     Emb = _import_embeddings()
     return Emb(model_name="sentence-transformers/all-MiniLM-L6-v2", model_kwargs={"device": "cpu"})
 
 def initialize_cpu_models() -> None:
-    """Initialize all CPU-only models once with robust fallbacks."""
-    # Hugging Face auth (optional)
     if HF_TOKEN:
         try:
             HfApi, HfFolder = _import_hf_hub()
             HfFolder.save_token(HF_TOKEN)
-            logging.info("✅ HuggingFace token set")
+            logging.info("✅ HF token set")
         except Exception as e:
             logging.warning(f"HF token save failed: {e}")
 
     if "det" not in models_cache:
         try:
             models_cache["det"] = load_yolo_model()
-            logging.info("✅ YOLO model loaded (CPU)")
+            logging.info("✅ YOLO loaded (CPU)")
         except Exception as e:
             logging.error(f"YOLO load failed: {e}")
 
@@ -209,43 +187,41 @@ def initialize_cpu_models() -> None:
                 logging.info("✅ Segmentation model loaded (CPU)")
             else:
                 models_cache["seg"] = None
-                logging.warning("Segmentation model file not found; skipping seg.")
+                logging.warning("Segmentation model file missing; skipping.")
         except Exception as e:
             models_cache["seg"] = None
-            logging.warning(f"Segmentation model not available: {e}")
+            logging.warning(f"Segmentation unavailable: {e}")
 
     if "cls" not in models_cache:
         try:
             models_cache["cls"] = load_classification_pipeline()
-            logging.info("✅ Classification pipeline loaded (CPU)")
+            logging.info("✅ Classifier loaded (CPU)")
         except Exception as e:
             models_cache["cls"] = None
-            logging.warning(f"Classification pipeline not available: {e}")
+            logging.warning(f"Classifier unavailable: {e}")
 
     if "embedding_model" not in models_cache:
         try:
             models_cache["embedding_model"] = load_embedding_model()
-            logging.info("✅ Embedding model loaded (CPU)")
+            logging.info("✅ Embeddings loaded (CPU)")
         except Exception as e:
             models_cache["embedding_model"] = None
-            logging.warning(f"Embedding model not available: {e}")
+            logging.warning(f"Embeddings unavailable: {e}")
 
 def setup_knowledge_base() -> None:
-    """Load PDFs and create FAISS vector store (optional)."""
     if "vector_store" in knowledge_base_cache:
         return
 
-    docs = []
+    docs: List = []
    try:
         PyPDFLoader = _import_langchain_pdf()
         for pdf in GUIDELINE_PDFS:
             if os.path.exists(pdf):
                 try:
-                    loader = PyPDFLoader(pdf)
-                    docs.extend(loader.load())
+                    docs.extend(PyPDFLoader(pdf).load())
                     logging.info(f"Loaded PDF: {pdf}")
                 except Exception as e:
-                    logging.warning(f"Failed to load PDF {pdf}: {e}")
+                    logging.warning(f"PDF load failed ({pdf}): {e}")
     except Exception as e:
         logging.warning(f"LangChain PDF loader unavailable: {e}")
 
@@ -253,18 +229,17 @@ def setup_knowledge_base() -> None:
         try:
             from langchain.text_splitter import RecursiveCharacterTextSplitter
             FAISS = _import_langchain_faiss()
-            splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
-            chunks = splitter.split_documents(docs)
+            chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)
             knowledge_base_cache["vector_store"] = FAISS.from_documents(chunks, models_cache["embedding_model"])
-            logging.info(f"✅ Knowledge base ready with {len(chunks)} chunks")
+            logging.info(f"✅ Knowledge base ready ({len(chunks)} chunks)")
         except Exception as e:
             knowledge_base_cache["vector_store"] = None
-            logging.warning(f"Knowledge base unavailable: {e}")
+            logging.warning(f"KB build failed: {e}")
     else:
         knowledge_base_cache["vector_store"] = None
-        logging.warning("Knowledge base disabled (no docs or embeddings).")
+        logging.warning("KB disabled (no docs or embeddings).")
 
-# Initialize on import
+# Initialize on import so app is ready
 initialize_cpu_models()
 setup_knowledge_base()
 
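For orientation, the knowledge-base flow that setup_knowledge_base implements, condensed into one place; a sketch assuming the split-out langchain_community package layout, whereas the module itself goes through its lazy _import_* helpers:

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

docs = PyPDFLoader("src/IWGDF Guideline.pdf").load()  # one Document per page
chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)
emb = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
                            model_kwargs={"device": "cpu"})
vs = FAISS.from_documents(chunks, emb)
hits = vs.similarity_search("offloading for diabetic foot ulcers", k=5)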
@@ -278,14 +253,13 @@ class AIProcessor:
         self.dataset_id = DATASET_ID
         self.hf_token = HF_TOKEN
 
-    # ---------- Image utilities ----------
     def _ensure_analysis_dir(self) -> str:
         out_dir = os.path.join(self.uploads_dir, "analysis")
         os.makedirs(out_dir, exist_ok=True)
         return out_dir
 
     def perform_visual_analysis(self, image_pil: Image.Image) -> Dict:
-        """YOLO detect → (optional) Keras seg → (optional) HF classifier → save visuals."""
+        """YOLO detect → (optional) Keras seg → (optional) HF classify → save visuals."""
         try:
             image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)
 
@@ -338,7 +312,7 @@ class AIProcessor:
                     seg_path = os.path.join(out_dir, f"segmentation_{ts}.png")
                     cv2.imwrite(seg_path, seg_vis)
             except Exception as e:
-                logging.warning(f"Segmentation step skipped: {e}")
+                logging.warning(f"Segmentation skipped: {e}")
 
             # Optional classification
             wound_type = "Unknown"
@@ -350,7 +324,7 @@ class AIProcessor:
                 if preds:
                     wound_type = max(preds, key=lambda x: x.get("score", 0)).get("label", "Unknown")
             except Exception as e:
-                logging.warning(f"Classification step failed: {e}")
+                logging.warning(f"Classification failed: {e}")
 
             # Save detection & original
             out_dir = self._ensure_analysis_dir()
@@ -380,19 +354,17 @@ class AIProcessor:
             raise
 
     def query_guidelines(self, query: str) -> str:
-        """Query the knowledge base (optional)."""
+        """Query the (optional) guideline knowledge base."""
         try:
             vs = self.knowledge_base_cache.get("vector_store")
             if not vs:
                 return "Knowledge base is not available."
-            # support both old and new retriever APIs
             try:
                 retriever = vs.as_retriever(search_kwargs={"k": 5})
                 docs = retriever.get_relevant_documents(query)  # LC >= 0.2
             except Exception:
                 retriever = vs.as_retriever(search_kwargs={"k": 5})
-                # older invoke API
-                docs = retriever.invoke(query)
+                docs = retriever.invoke(query)  # older LC
             lines: List[str] = []
             for d in docs:
                 src = (d.metadata or {}).get("source", "N/A")
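The try/except above papers over LangChain's retriever API churn; for the record, `retriever.invoke` is the newer Runnable-style call and `get_relevant_documents` the legacy one. A version-agnostic wrapper could hide the branch (helper name illustrative):

def _retrieve(vs, query: str, k: int = 5):
    retriever = vs.as_retriever(search_kwargs={"k": k})
    try:
        return retriever.invoke(query)  # Runnable API (newer LangChain)
    except AttributeError:
        return retriever.get_relevant_documents(query)  # legacy API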
@@ -403,9 +375,7 @@ class AIProcessor:
             logging.warning(f"Guidelines query failed: {e}")
             return f"Guidelines query failed: {str(e)}"
 
-    # ---------- Report builders ----------
     def _generate_fallback_report(self, patient_info: str, visual_results: Dict, guideline_context: str) -> str:
-        """Plaintext/markdown fallback when MedGemma is unavailable."""
         return f"""# 🩺 SmartHeal AI - Comprehensive Wound Analysis Report
 
 ## 📋 Patient Information
@@ -431,11 +401,11 @@ Automated analysis provides quantitative measurements; verify via clinical exami
 - Document with serial photos and measurements
 
 ## 📅 Monitoring
-- Daily in week 1, then every 2-3 days (or as indicated)
+- Daily in week 1, then every 2–3 days (or as indicated)
 - Weekly progress review
 
 ## 📚 Guideline Context
-{(guideline_context or '')[:800]}{'...' if guideline_context and len(guideline_context) > 800 else ''}
+{(guideline_context or '')[:800]}{"..." if guideline_context and len(guideline_context) > 800 else ''}
 
 **Disclaimer:** Automated, for decision support only. Verify clinically.
 """
@@ -448,7 +418,7 @@ Automated analysis provides quantitative measurements; verify via clinical exami
         image_pil: Image.Image,
         max_new_tokens: Optional[int] = None,
     ) -> str:
-        """Try MedGemma (GPU) → fallback report."""
+        """Use GPU path when available, fallback otherwise."""
         try:
             report = generate_medgemma_report_with_timeout(
                 patient_info, visual_results, guideline_context, image_pil, max_new_tokens
@@ -461,9 +431,8 @@ Automated analysis provides quantitative measurements; verify via clinical exami
             logging.error(f"Report generation failed: {e}")
             return self._generate_fallback_report(patient_info, visual_results, guideline_context)
 
-    # ---------- HF dataset commit ----------
     def save_and_commit_image(self, image_pil: Image.Image) -> str:
-        """Save image locally and optionally upload to HF dataset."""
+        """Save locally and (optionally) upload to HF dataset."""
         try:
             os.makedirs(self.uploads_dir, exist_ok=True)
             ts = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -472,17 +441,17 @@ Automated analysis provides quantitative measurements; verify via clinical exami
             image_pil.convert("RGB").save(path)
             logging.info(f"✅ Image saved locally: {path}")
 
-            if self.hf_token and self.dataset_id:
+            if HF_TOKEN and DATASET_ID:
                 try:
                     HfApi, HfFolder = _import_hf_hub()
-                    HfFolder.save_token(self.hf_token)
+                    HfFolder.save_token(HF_TOKEN)
                     api = HfApi()
                     api.upload_file(
                         path_or_fileobj=path,
                         path_in_repo=f"images/{filename}",
-                        repo_id=self.dataset_id,
+                        repo_id=DATASET_ID,
                         repo_type="dataset",
-                        token=self.hf_token,
+                        token=HF_TOKEN,
                         commit_message=f"Upload wound image: {filename}",
                     )
                     logging.info("✅ Image committed to HF dataset")
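Each upload_file call above produces one dataset commit. If uploads become frequent, huggingface_hub also supports batching several files into a single commit; a hedged sketch (the pending_uploads list is hypothetical):

from huggingface_hub import HfApi, CommitOperationAdd

api = HfApi(token=HF_TOKEN)
ops = [
    CommitOperationAdd(path_in_repo=f"images/{name}", path_or_fileobj=local_path)
    for name, local_path in pending_uploads  # hypothetical (name, path) pairs
]
api.create_commit(
    repo_id=DATASET_ID,
    repo_type="dataset",
    operations=ops,
    commit_message="Batch upload wound images",
)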
@@ -494,28 +463,24 @@ Automated analysis provides quantitative measurements; verify via clinical exami
             logging.error(f"Failed to save/commit image: {e}")
             return ""
 
-    # ---------- Orchestrator ----------
     def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: Dict) -> Dict:
-        """End-to-end analysis with robust fallbacks."""
+        """End-to-end analysis."""
         try:
             saved_path = self.save_and_commit_image(image_pil)
-
             visual_results = self.perform_visual_analysis(image_pil)
 
-            # Patient info summary text
             pi = questionnaire_data or {}
             patient_info = (
-                f"Age: {pi.get('age', 'N/A')}, "
-                f"Diabetic: {pi.get('diabetic', 'N/A')}, "
-                f"Allergies: {pi.get('allergies', 'N/A')}, "
-                f"Date of Wound: {pi.get('date_of_injury', 'N/A')}, "
-                f"Professional Care: {pi.get('professional_care', 'N/A')}, "
-                f"Oozing/Bleeding: {pi.get('oozing_bleeding', 'N/A')}, "
-                f"Infection: {pi.get('infection', 'N/A')}, "
-                f"Moisture: {pi.get('moisture', 'N/A')}"
+                f"Age: {pi.get('age','N/A')}, "
+                f"Diabetic: {pi.get('diabetic','N/A')}, "
+                f"Allergies: {pi.get('allergies','N/A')}, "
+                f"Date of Wound: {pi.get('date_of_injury','N/A')}, "
+                f"Professional Care: {pi.get('professional_care','N/A')}, "
+                f"Oozing/Bleeding: {pi.get('oozing_bleeding','N/A')}, "
+                f"Infection: {pi.get('infection','N/A')}, "
+                f"Moisture: {pi.get('moisture','N/A')}"
             )
 
-            # Query guidelines
             query = (
                 f"best practices for managing a {visual_results.get('wound_type','Unknown')} "
                 f"with moisture '{pi.get('moisture','unknown')}' and infection '{pi.get('infection','unknown')}' "
@@ -523,18 +488,16 @@ Automated analysis provides quantitative measurements; verify via clinical exami
             )
             guideline_context = self.query_guidelines(query)
 
-            # Generate final report
-            report = self.generate_final_report(patient_info=patient_info,
-                                                visual_results=visual_results,
-                                                guideline_context=guideline_context,
-                                                image_pil=image_pil)
+            report = self.generate_final_report(patient_info, visual_results, guideline_context, image_pil)
 
             return {
                 "success": True,
                 "visual_analysis": visual_results,
                 "report": report,
                 "saved_image_path": saved_path,
-                "guideline_context": (guideline_context or "")[:500] + ("..." if guideline_context and len(guideline_context) > 500 else ""),
+                "guideline_context": (guideline_context or "")[:500] + (
+                    "..." if guideline_context and len(guideline_context) > 500 else ""
+                ),
             }
         except Exception as e:
             logging.error(f"Pipeline error: {e}")
@@ -548,7 +511,7 @@ Automated analysis provides quantitative measurements; verify via clinical exami
             }
 
     def analyze_wound(self, image, questionnaire_data: Dict) -> Dict:
-        """Public entrypoint used by your UI."""
+        """Public entrypoint used by UI."""
         try:
             if isinstance(image, str):
                 if not os.path.exists(image):
@@ -571,4 +534,4 @@ Automated analysis provides quantitative measurements; verify via clinical exami
                 "report": f"Analysis initialization failed: {str(e)}",
                 "saved_image_path": None,
                 "guideline_context": "",
-            }
+            }
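The public surface is unchanged by this commit; a minimal usage sketch (image path and questionnaire values hypothetical):

from src.ai_processor import AIProcessor

processor = AIProcessor()
result = processor.analyze_wound(
    "uploads/example_wound.jpg",
    {"age": 63, "diabetic": "yes", "moisture": "high", "infection": "suspected"},
)
if result.get("success"):
    print(result["report"])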