SmartHeal-Agentic-AI / src /ai_processor.py
SmartHeal's picture
Update src/ai_processor.py
68507da verified
raw
history blame
22.1 kB
"""
SmartHeal AI Processor - Zero GPU Compatible Version
Designed specifically for Hugging Face Spaces with Zero GPU architecture
"""
import os
import logging
import cv2
import numpy as np
from PIL import Image
import json
from datetime import datetime
from typing import Optional, Dict, List, Tuple, Any
from contextlib import contextmanager
# Configure root logging once at import time: INFO level, with timestamp and
# level name prefixed to every message.
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s"
)
# Environment setup for Zero GPU compatibility.
# setdefault() only applies a value when the variable is not already set, so
# values provided by the hosting environment are left untouched.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "-1") # Hide GPU from main process
# Import the Hugging Face Spaces GPU decorator; fall back to a no-op decorator
# when the `spaces` package is not installed (e.g. local CPU-only testing).
try:
    import spaces
    _SPACES_GPU = spaces.GPU
except ImportError:
    logging.warning("spaces package not available - running in CPU mode")

    def _SPACES_GPU_dummy(*args, **kwargs):
        """No-op stand-in for ``spaces.GPU``.

        Supports both decorator forms so code written against the real
        decorator keeps working locally:

        * ``@_SPACES_GPU`` (bare) - the function itself is the only argument
          and is returned unchanged.
        * ``@_SPACES_GPU(enable_queue=True, duration=120)`` - the options are
          ignored and a pass-through decorator is returned.
        """
        # Bare usage: the decorated function arrives as the only positional arg.
        if len(args) == 1 and not kwargs and callable(args[0]):
            return args[0]

        def decorator(func):
            return func

        return decorator

    _SPACES_GPU = _SPACES_GPU_dummy
@contextmanager
def _no_cuda_env():
    """Temporarily set ``CUDA_VISIBLE_DEVICES`` to ``"-1"``.

    On exit (normal or via exception) the variable is restored to its previous
    value, or removed again if it was not set before entering.
    """
    key = "CUDA_VISIBLE_DEVICES"
    saved = os.environ.get(key)
    os.environ[key] = "-1"
    try:
        yield
    finally:
        if saved is not None:
            os.environ[key] = saved
        else:
            # The variable did not exist before; do not leave our mask behind.
            os.environ.pop(key, None)
# ---- Paths / constants ----
# Directory where uploaded images are written; created at import time if missing.
UPLOADS_DIR = "uploads"
os.makedirs(UPLOADS_DIR, exist_ok=True)
# Hugging Face access token read from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN", None)
# Model weight locations (relative paths).
YOLO_MODEL_PATH = "src/best.pt"
SEG_MODEL_PATH = "src/segmentation_model.h5" # optional
# Guideline PDFs indexed into the knowledge base; files that do not exist are skipped.
GUIDELINE_PDFS = ["src/eHealth in Wound Care.pdf", "src/IWGDF Guideline.pdf", "src/evaluation.pdf"]
# Hugging Face dataset repository that receives uploaded images.
DATASET_ID = "SmartHeal/wound-image-uploads"
# Pixel-per-centimetre calibration: fallback value, and the clamp range applied
# to EXIF-derived estimates.
DEFAULT_PX_PER_CM = 38.0
PX_PER_CM_MIN, PX_PER_CM_MAX = 5.0, 1200.0
# Process-wide caches shared by the loader functions and AIProcessor.
models_cache: Dict[str, object] = {}
knowledge_base_cache: Dict[str, object] = {}
# ---------- Lazy imports (wrapped where needed) ----------
def _import_ultralytics():
    """Lazily import and return the Ultralytics ``YOLO`` class.

    The import runs with CUDA masked so that Ultralytics does not probe the
    GPU from the main process.
    """
    with _no_cuda_env():
        from ultralytics import YOLO as yolo_cls
        return yolo_cls
def _import_tf_loader():
    """Lazily import TensorFlow and return Keras' ``load_model`` function.

    TensorFlow is imported with CUDA masked and its visible GPU list is set to
    empty, so the main process does not try to use a GPU.
    """
    with _no_cuda_env():
        import tensorflow as tf

        # Explicitly hide every GPU from TensorFlow as well.
        tf.config.set_visible_devices([], "GPU")
        from tensorflow.keras.models import load_model as keras_load_model
        return keras_load_model
def _import_hf_cls():
    """Lazily import and return the ``transformers.pipeline`` factory (CUDA masked)."""
    with _no_cuda_env():
        from transformers import pipeline as pipeline_factory
        return pipeline_factory
def _import_embeddings():
    """Lazily import and return the ``HuggingFaceEmbeddings`` class (CUDA masked)."""
    with _no_cuda_env():
        from langchain_community.embeddings import HuggingFaceEmbeddings as embeddings_cls
        return embeddings_cls
def _import_langchain_pdf():
    """Lazily import and return the ``PyPDFLoader`` class (CUDA masked)."""
    with _no_cuda_env():
        from langchain_community.document_loaders import PyPDFLoader as pdf_loader_cls
        return pdf_loader_cls
def _import_langchain_faiss():
    """Lazily import and return the ``FAISS`` vector-store class (CUDA masked)."""
    with _no_cuda_env():
        from langchain_community.vectorstores import FAISS as faiss_cls
        return faiss_cls
def _import_hf_hub():
    """Lazily import ``huggingface_hub`` and return ``(HfApi, HfFolder)`` (CUDA masked)."""
    with _no_cuda_env():
        from huggingface_hub import HfApi as api_cls, HfFolder as folder_cls
        return api_cls, folder_cls
# ---------- SmartHeal prompts (system + user prefix) ----------
# System prompt: sent as the "system" message of every VLM request.
SMARTHEAL_SYSTEM_PROMPT = """
You are SmartHeal Clinical Assistant, a wound-care decision-support system.
You analyze wound photographs and brief patient context to produce careful,
specific, guideline-informed recommendations WITHOUT diagnosing. You always:
- Use the measurements calculated by the vision pipeline as ground truth.
- Prefer concise, actionable steps tailored to exudate level, infection risk, and pain.
- Flag uncertainties and red flags that need escalation to a clinician.
- Avoid contraindicated advice; do not infer unseen comorbidities.
- Keep under 300 words and use the requested headings exactly.
- Tone: professional, clear, and conservative; no definitive medical claims.
- Safety: remind the user to seek clinician review for changes or red flags.
"""
# User prompt template, filled with str.format(). Required fields:
# patient_info, wound_type, length_cm, breadth_cm, area_cm2, det_conf
# (formatted with :.2f, so it must be numeric), px_per_cm, guideline_context.
SMARTHEAL_USER_PREFIX = """
Patient: {patient_info}
Visual findings: type={wound_type}, size={length_cm}x{breadth_cm} cm, area={area_cm2} cm^2,
detection_conf={det_conf:.2f}, calibration={px_per_cm} px/cm.
Guideline context (snippets you can draw principles from; do not quote at length):
{guideline_context}
Write a structured answer with these headings exactly:
1. Clinical Summary (max 4 bullet points)
2. Likely Stage/Type (if uncertain, say 'uncertain')
3. Treatment Plan (specific dressing choices and frequency based on exudate/infection risk)
4. Red Flags (what to escalate and when)
5. Follow-up Cadence (days)
6. Notes (assumptions/uncertainties)
Keep to 220–300 words. Do NOT provide diagnosis. Avoid contraindicated advice.
"""
# ---------- VLM (MedGemma replaced with Qwen2-VL) ----------
@_SPACES_GPU(enable_queue=True)
def _vlm_infer_gpu(messages, model_id: str, max_new_tokens: int, token: Optional[str]):
    """Run one vision-language generation inside the Spaces GPU worker.

    This is the only function in the module that is allowed to initialise
    CUDA: the pipeline is built here with ``device_map="auto"``.

    Args:
        messages: Chat-style message list passed to the pipeline as ``text``.
        model_id: Hugging Face model identifier to load.
        max_new_tokens: Generation budget.
        token: Optional Hugging Face access token.

    Returns:
        The stripped generated text, or ``"⚠️ Empty response"`` when the model
        produced nothing.
    """
    from transformers import pipeline
    import torch  # imported here so torch is only touched inside the GPU worker

    vlm = pipeline(
        task="image-text-to-text",
        model=model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",  # CUDA init happens here, safely in the GPU worker
        token=token,
        trust_remote_code=True,
        model_kwargs={"low_cpu_mem_usage": True},
    )
    generation_args = {
        "max_new_tokens": max_new_tokens,
        "do_sample": False,
        "temperature": 0.2,
    }
    result = vlm(text=messages, **generation_args)
    first = result[0]
    try:
        # Chat-style output: the last message in the list carries the reply.
        reply = first["generated_text"][-1].get("content", "")
    except Exception:
        # Otherwise fall back to whatever is stored under "generated_text".
        reply = first.get("generated_text", "")
    cleaned = (reply or "").strip()
    return cleaned or "⚠️ Empty response"
def generate_medgemma_report(  # kept name so callers don't change
    patient_info: str,
    visual_results: Dict,
    guideline_context: str,
    image_pil: Image.Image,
    max_new_tokens: Optional[int] = None,
) -> str:
    """Build the SmartHeal prompt and obtain a report from the vision-language model.

    The model id comes from ``SMARTHEAL_VLM_MODEL`` (default
    ``Qwen/Qwen2-VL-2B-Instruct``). Loading and inference happen only inside
    the GPU worker function; this function never imports transformers.

    Args:
        patient_info: Free-text patient context inserted into the prompt.
        visual_results: Vision pipeline output; missing keys fall back to
            ``"Unknown"``, ``0``, ``0.0`` or ``"?"`` as appropriate.
        guideline_context: Retrieved guideline text; only the first 900
            characters are used.
        image_pil: The wound image sent alongside the prompt.
        max_new_tokens: Generation budget; when falsy, read from
            ``SMARTHEAL_VLM_MAX_TOKENS`` (default ``600``).

    Returns:
        The generated report, ``"⚠️ VLM disabled"`` when
        ``SMARTHEAL_ENABLE_VLM`` is not ``"1"``, or ``"⚠️ VLM error"`` if the
        GPU call raised.
    """
    vlm_enabled = os.getenv("SMARTHEAL_ENABLE_VLM", "1") == "1"
    if not vlm_enabled:
        return "⚠️ VLM disabled"

    model_id = os.getenv("SMARTHEAL_VLM_MODEL", "Qwen/Qwen2-VL-2B-Instruct")
    if not max_new_tokens:
        max_new_tokens = int(os.getenv("SMARTHEAL_VLM_MAX_TOKENS", "600"))

    prompt_fields = {
        "patient_info": patient_info,
        "wound_type": visual_results.get("wound_type", "Unknown"),
        "length_cm": visual_results.get("length_cm", 0),
        "breadth_cm": visual_results.get("breadth_cm", 0),
        "area_cm2": visual_results.get("surface_area_cm2", 0),
        "det_conf": float(visual_results.get("detection_confidence", 0.0)),
        "px_per_cm": visual_results.get("px_per_cm", "?"),
        "guideline_context": (guideline_context or "")[:900],
    }
    uprompt = SMARTHEAL_USER_PREFIX.format(**prompt_fields)

    system_message = {
        "role": "system",
        "content": [{"type": "text", "text": SMARTHEAL_SYSTEM_PROMPT}],
    }
    user_message = {
        "role": "user",
        "content": [
            {"type": "image", "image": image_pil},
            {"type": "text", "text": uprompt},
        ],
    }
    messages = [system_message, user_message]

    try:
        # IMPORTANT: do not import transformers or touch CUDA here. Only call the GPU worker.
        return _vlm_infer_gpu(messages, model_id, max_new_tokens, HF_TOKEN)
    except Exception as exc:
        logging.error(f"VLM call failed: {exc}")
        return "⚠️ VLM error"
# ---------- Initialize CPU models ----------
def load_yolo_model():
    """Instantiate the YOLO detector from ``YOLO_MODEL_PATH``.

    Construction happens with CUDA masked so the model does not auto-select
    ``cuda:0`` in the main process.
    """
    yolo_cls = _import_ultralytics()
    with _no_cuda_env():
        return yolo_cls(YOLO_MODEL_PATH)
def load_segmentation_model():
    """Load the Keras segmentation model from ``SEG_MODEL_PATH`` without compiling it.

    ``InputLayer`` is passed through ``custom_objects`` as in the original
    call. The class is imported from ``tensorflow.keras.layers`` directly:
    ``_import_tf_loader()`` returns the ``load_model`` *function*, which has no
    ``.layers`` attribute, so looking the class up on it always raised
    ``AttributeError`` and the model could never be loaded.

    Returns:
        The loaded Keras model.
    """
    load_model = _import_tf_loader()
    # TensorFlow is already imported (and its GPUs hidden) by the call above;
    # keep CUDA masked for this import as well, for consistency.
    with _no_cuda_env():
        from tensorflow.keras.layers import InputLayer
    return load_model(
        SEG_MODEL_PATH,
        compile=False,
        custom_objects={"InputLayer": InputLayer},
    )
def load_classification_pipeline():
    """Create the wound image-classification pipeline on CPU.

    Uses the ``Hemg/Wound-classification`` model and the module-level
    ``HF_TOKEN`` for authentication.
    """
    build_pipeline = _import_hf_cls()
    options = {
        "model": "Hemg/Wound-classification",
        "token": HF_TOKEN,
        "device": "cpu",
    }
    return build_pipeline("image-classification", **options)
def load_embedding_model():
    """Create the sentence-embedding model (``all-MiniLM-L6-v2``) on CPU."""
    embeddings_cls = _import_embeddings()
    return embeddings_cls(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": "cpu"},
    )
def initialize_cpu_models() -> None:
    """Populate ``models_cache`` with the CPU-side models, once per key.

    Keys handled: ``"det"`` (YOLO), ``"seg"`` (optional segmentation model),
    ``"cls"`` (classifier) and ``"embedding_model"``. A key that is already
    present is left alone. Failures are logged rather than raised; for every
    key except ``"det"`` the cache entry is set to ``None`` on failure, while a
    failed YOLO load leaves ``"det"`` absent so a later call tries again.
    When ``HF_TOKEN`` is set it is also saved via ``HfFolder.save_token``.
    """
    cache = models_cache

    if HF_TOKEN:
        try:
            _, hf_folder = _import_hf_hub()
            hf_folder.save_token(HF_TOKEN)
            logging.info("✅ HF token set")
        except Exception as exc:
            logging.warning(f"HF token save failed: {exc}")

    if "det" not in cache:
        try:
            cache["det"] = load_yolo_model()
            logging.info("✅ YOLO loaded (CPU; CUDA masked in main)")
        except Exception as exc:
            logging.error(f"YOLO load failed: {exc}")

    if "seg" not in cache:
        try:
            if not os.path.exists(SEG_MODEL_PATH):
                cache["seg"] = None
                logging.warning("Segmentation model file missing; skipping.")
            else:
                cache["seg"] = load_segmentation_model()
                seg_model = cache["seg"]
                ishape = getattr(seg_model, "input_shape", None)
                oshape = getattr(seg_model, "output_shape", None)
                logging.info(f"✅ Segmentation model loaded (CPU) | input_shape={ishape} output_shape={oshape}")
        except Exception as exc:
            cache["seg"] = None
            logging.warning(f"Segmentation unavailable: {exc}")

    if "cls" not in cache:
        try:
            cache["cls"] = load_classification_pipeline()
            logging.info("✅ Classifier loaded (CPU)")
        except Exception as exc:
            cache["cls"] = None
            logging.warning(f"Classifier unavailable: {exc}")

    if "embedding_model" not in cache:
        try:
            cache["embedding_model"] = load_embedding_model()
            logging.info("✅ Embeddings loaded (CPU)")
        except Exception as exc:
            cache["embedding_model"] = None
            logging.warning(f"Embeddings unavailable: {exc}")
def setup_knowledge_base() -> None:
    """Build the guideline vector store and cache it under ``"vector_store"``.

    Does nothing when the key already exists. Each PDF in ``GUIDELINE_PDFS``
    that exists on disk is loaded; per-file and loader-import failures are
    logged and skipped. If documents were loaded and an embedding model is
    cached, the documents are split into 1000-character chunks (100 overlap)
    and indexed with FAISS. In every other case the cache entry is ``None``.
    """
    if "vector_store" in knowledge_base_cache:
        return

    docs: List = []
    try:
        pdf_loader_cls = _import_langchain_pdf()
        for pdf in GUIDELINE_PDFS:
            if not os.path.exists(pdf):
                continue
            try:
                docs.extend(pdf_loader_cls(pdf).load())
                logging.info(f"Loaded PDF: {pdf}")
            except Exception as exc:
                logging.warning(f"PDF load failed ({pdf}): {exc}")
    except Exception as exc:
        logging.warning(f"LangChain PDF loader unavailable: {exc}")

    # Without documents or an embedding model there is nothing to index.
    if not (docs and models_cache.get("embedding_model")):
        knowledge_base_cache["vector_store"] = None
        logging.warning("KB disabled (no docs or embeddings).")
        return

    try:
        from langchain.text_splitter import RecursiveCharacterTextSplitter
        faiss_cls = _import_langchain_faiss()
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        chunks = splitter.split_documents(docs)
        knowledge_base_cache["vector_store"] = faiss_cls.from_documents(chunks, models_cache["embedding_model"])
        logging.info(f"✅ Knowledge base ready ({len(chunks)} chunks)")
    except Exception as exc:
        knowledge_base_cache["vector_store"] = None
        logging.warning(f"KB build failed: {exc}")
# Eagerly load the CPU models and build the guideline knowledge base when the
# module is imported. Both functions skip work already recorded in their
# caches, so the repeated calls in AIProcessor.__init__ do not reload them.
initialize_cpu_models()
setup_knowledge_base()
# ---------- Calibration helpers ----------
def _exif_to_dict(pil_img: Image.Image) -> Dict[str, object]:
    """Return the image's EXIF data as a ``{tag name: value}`` dictionary.

    Tags from the base IFD and from the Exif sub-IFD (0x8769) are merged; the
    sub-IFD is where camera tags such as ``FocalLength``,
    ``FocalLengthIn35mmFilm`` and ``SubjectDistance`` are stored. Numeric tag
    ids without a known name are kept as their integer id.

    Args:
        pil_img: Image to read EXIF data from.

    Returns:
        The tag dictionary, or an empty/partial dictionary when the image has
        no EXIF data or reading it fails (failures are logged at debug level).
    """
    # TAGS was referenced but never imported, so every lookup raised a
    # NameError that the blanket except swallowed, and the result was always {}.
    from PIL.ExifTags import TAGS

    out: Dict[str, object] = {}
    try:
        exif = pil_img.getexif()
        if not exif:
            return out
        for tag_id, value in exif.items():
            out[TAGS.get(tag_id, tag_id)] = value
        # getexif() does not flatten the Exif sub-IFD; read it explicitly when
        # the installed Pillow exposes get_ifd().
        get_ifd = getattr(exif, "get_ifd", None)
        if get_ifd is not None:
            for tag_id, value in get_ifd(0x8769).items():
                out[TAGS.get(tag_id, tag_id)] = value
    except Exception as exc:
        # Best effort: EXIF is optional metadata, so never fail the caller.
        logging.debug(f"EXIF read failed: {exc}")
    return out
def _to_float(val) -> Optional[float]:
    """Convert an EXIF-style value to ``float``, or ``None`` if that is impossible.

    A 2-tuple is treated as ``(numerator, denominator)``; a zero denominator
    is replaced by ``1.0``. Anything else goes through ``float()``. ``None``
    and unconvertible values yield ``None``.
    """
    if val is None:
        return None
    try:
        is_rational_pair = isinstance(val, tuple) and len(val) == 2
        if not is_rational_pair:
            return float(val)
        numerator = float(val[0])
        denominator = float(val[1])
        if denominator == 0:
            denominator = 1.0
        return numerator / denominator
    except Exception:
        return None
def _estimate_sensor_width_mm(f_mm: Optional[float], f35: Optional[float]) -> Optional[float]:
    """Estimate sensor width in millimetres as ``36.0 * f_mm / f35``.

    Args:
        f_mm: Actual focal length.
        f35: 35 mm-equivalent focal length.

    Returns:
        The estimate, or ``None`` when either value is missing/zero or
        ``f35`` is not positive.
    """
    if not f_mm or not f35 or not (f35 > 0):
        return None
    return 36.0 * f_mm / f35
def estimate_px_per_cm_from_exif(pil_img: Image.Image, default_px_per_cm: float = DEFAULT_PX_PER_CM) -> Tuple[float, Dict]:
    """Estimate the pixels-per-centimetre scale of an image from its EXIF data.

    The horizontal field width is computed as
    ``sensor_width * subject_distance / focal_length`` and the image width in
    pixels is divided by it. The result is clamped to
    ``[PX_PER_CM_MIN, PX_PER_CM_MAX]``.

    Args:
        pil_img: Image whose EXIF data and pixel width are used.
        default_px_per_cm: Value returned when the EXIF data is insufficient
            or any step fails.

    Returns:
        ``(px_per_cm, meta)`` where ``meta`` records the values read
        (``f_mm``, ``f35``, ``sensor_w_mm``, ``distance_m``) and ``used`` is
        ``"exif"`` or ``"default"``.
    """
    meta = {"used": "default", "f_mm": None, "f35": None, "sensor_w_mm": None, "distance_m": None}
    try:
        tags = _exif_to_dict(pil_img)
        focal_mm = _to_float(tags.get("FocalLength"))
        focal_35 = _to_float(tags.get("FocalLengthIn35mmFilm") or tags.get("FocalLengthIn35mm"))
        distance_m = _to_float(tags.get("SubjectDistance"))
        sensor_mm = _estimate_sensor_width_mm(focal_mm, focal_35)

        meta["f_mm"] = focal_mm
        meta["f35"] = focal_35
        meta["sensor_w_mm"] = sensor_mm
        meta["distance_m"] = distance_m

        have_all = focal_mm and sensor_mm and distance_m and distance_m > 0
        if not have_all:
            return float(default_px_per_cm), meta

        # Width of the scene covered by the sensor at the subject distance,
        # converted from millimetres to centimetres.
        scene_width_cm = (sensor_mm * (distance_m * 1000.0) / focal_mm) / 10.0
        estimate = pil_img.width / max(scene_width_cm, 1e-6)
        estimate = float(np.clip(estimate, PX_PER_CM_MIN, PX_PER_CM_MAX))
        meta["used"] = "exif"
        return estimate, meta
    except Exception:
        return float(default_px_per_cm), meta
# ---------- Segmentation helpers ----------
def _imagenet_norm(arr: np.ndarray) -> np.ndarray:
    """Subtract the per-channel mean and divide by the per-channel std.

    The constants are applied along the last axis, which must therefore have
    length 3 (or broadcast against it). The result is ``float32``.
    """
    channel_mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
    channel_std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
    centered = arr.astype(np.float32) - channel_mean
    return centered / channel_std
def _preprocess_for_seg(bgr_roi: np.ndarray, target_hw: Tuple[int, int]) -> np.ndarray:
    """Resize an ROI to the model's input size and scale it to ``[0, 1]``.

    Args:
        bgr_roi: Image region to prepare.
        target_hw: ``(height, width)`` expected by the segmentation model.

    Returns:
        A ``float32`` array with a leading batch axis of size 1.
    """
    height, width = target_hw
    # cv2.resize takes the size as (width, height).
    scaled = cv2.resize(bgr_roi, (width, height), interpolation=cv2.INTER_LINEAR)
    unit_range = scaled.astype(np.float32) / 255.0
    return unit_range[np.newaxis, ...]  # (1,H,W,3)
def _to_prob(pred: np.ndarray) -> np.ndarray:
    """Squeeze a model output and make sure it looks like probabilities.

    If any value lies outside ``[0, 1]`` the whole array is passed through a
    sigmoid; otherwise the values are kept. The result is ``float32``.
    """
    scores = np.squeeze(pred)
    lowest = float(scores.min())
    highest = float(scores.max())
    outside_unit_range = highest > 1.0 or lowest < 0.0
    if outside_unit_range:
        scores = 1.0 / (1.0 + np.exp(-scores))
    return scores.astype(np.float32)
# ---- Adaptive threshold + GrabCut grow ----
def _adaptive_prob_threshold(p: np.ndarray) -> float:
    """Pick a binarisation threshold for a probability map.

    Two candidates are computed, the Otsu threshold and the 90th percentile,
    each clamped to ``[0.25, 0.65]``. Each candidate is scored by how far the
    fraction of pixels it keeps falls outside the 3%-10% target band (too
    small is penalised twice as hard as too large). The lower score wins, and
    Otsu wins ties.
    """
    clamp_lo, clamp_hi = 0.25, 0.65
    band_lo, band_hi = 0.03, 0.10

    probs = np.clip(p.astype(np.float32), 0, 1)
    probs_u8 = (probs * 255).astype(np.uint8)

    otsu_level, _ = cv2.threshold(probs_u8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    thr_otsu = float(np.clip(otsu_level / 255.0, clamp_lo, clamp_hi))
    thr_pctl = float(np.clip(np.percentile(probs, 90), clamp_lo, clamp_hi))

    def penalty(thr: float) -> float:
        # Fraction of pixels at or above the threshold.
        kept = float((probs >= thr).sum()) / float(probs.size)
        if kept < band_lo:
            return abs(kept - band_lo) * 3.0
        if kept > band_hi:
            return abs(kept - band_hi) * 1.5
        return 0.0

    otsu_penalty = penalty(thr_otsu)
    pctl_penalty = penalty(thr_pctl)
    if otsu_penalty <= pctl_penalty:
        return thr_otsu
    return thr_pctl
def _grabcut_refine(bgr: np.ndarray, seed01: np.ndarray, iters: int = 3) -> np.ndarray:
    """Refine a binary seed mask with GrabCut, growing from a confident core.

    The seed pixels are marked as sure foreground, a one-step 5x5 elliptical
    dilation ring around them as probable foreground, the image border as
    sure background and everything else as probable background.

    Fixes relative to the previous version:
      * the dilated mask used to overwrite the seed, so the whole dilated
        area became sure foreground and the probable-foreground assignment
        was dead code;
      * the right border was written to column ``1`` instead of ``-1``;
      * a seed whose shape differs from the image is resized
        (nearest-neighbour) instead of failing on a boolean-index mismatch;
      * a seed with no foreground returns an empty mask instead of handing
        GrabCut a mask without foreground samples.

    Args:
        bgr: 8-bit 3-channel image passed to ``cv2.grabCut``.
        seed01: Binary mask (non-zero = foreground).
        iters: Number of GrabCut iterations.

    Returns:
        ``uint8`` mask with the same height/width as ``bgr``; 1 where GrabCut
        labelled the pixel (probable) foreground, else 0.
    """
    h, w = bgr.shape[:2]
    seed = (np.asarray(seed01) > 0).astype(np.uint8)
    if seed.shape[:2] != (h, w):
        # Nearest-neighbour keeps the mask strictly binary.
        seed = cv2.resize(seed, (w, h), interpolation=cv2.INTER_NEAREST)

    gc = np.full((h, w), cv2.GC_PR_BGD, np.uint8)
    k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    seed_dil = cv2.dilate(seed, k, iterations=1)
    # Order matters: mark the wider ring first, then the core on top of it.
    gc[seed_dil.astype(bool)] = cv2.GC_PR_FGD
    gc[seed.astype(bool)] = cv2.GC_FGD
    # The ROI border is assumed to be background.
    gc[0, :] = cv2.GC_BGD
    gc[-1, :] = cv2.GC_BGD
    gc[:, 0] = cv2.GC_BGD
    gc[:, -1] = cv2.GC_BGD

    has_foreground = bool(np.any((gc == cv2.GC_FGD) | (gc == cv2.GC_PR_FGD)))
    if not has_foreground:
        return np.zeros((h, w), np.uint8)

    bgdModel = np.zeros((1, 65), np.float64)
    fgdModel = np.zeros((1, 65), np.float64)
    cv2.grabCut(bgr, gc, None, bgdModel, fgdModel, iters, cv2.GC_INIT_WITH_MASK)
    return np.where((gc == cv2.GC_FGD) | (gc == cv2.GC_PR_FGD), 1, 0).astype(np.uint8)
# ---------- Main AIProcessor Class ----------
class AIProcessor:
    """Wound-analysis facade over the module-level models and knowledge base.

    Instances share the module-wide ``models_cache`` and
    ``knowledge_base_cache``; constructing one makes sure both have been
    initialised.
    """

    def __init__(self):
        """Bind configuration and shared caches, then ensure models and KB are loaded."""
        self.config = type("Config", (object,), {
            "HF_TOKEN": HF_TOKEN,
            "YOLO_MODEL_PATH": YOLO_MODEL_PATH,
            "SEG_MODEL_PATH": SEG_MODEL_PATH,
            "DATASET_ID": DATASET_ID,
            "UPLOADS_DIR": UPLOADS_DIR,
            "GUIDELINE_PDFS": GUIDELINE_PDFS
        })()
        self.models_cache = models_cache
        self.knowledge_base_cache = knowledge_base_cache
        self.px_per_cm = DEFAULT_PX_PER_CM  # Use default from constants
        # Ensure CPU models and KB are initialized (no-ops when already cached)
        initialize_cpu_models()
        setup_knowledge_base()

    def perform_visual_analysis(self, image_pil: Image.Image) -> Dict[str, Any]:
        """Detect, segment, measure and classify the wound in an image.

        Args:
            image_pil: Input image; it is converted to RGB first so that
                grayscale, palette and alpha images are handled too.

        Returns:
            Dict with ``wound_type``, ``length_cm``, ``breadth_cm``,
            ``surface_area_cm2``, ``detection_confidence`` and ``px_per_cm``.
            The measurements stay ``0`` and the type ``"Unknown"`` when the
            respective model is unavailable or fails (failures are logged).

        Raises:
            ValueError: If the YOLO model is not loaded or nothing is detected.
        """
        image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)

        detector = self.models_cache.get("det")
        if not detector:
            raise ValueError("YOLO model not initialized.")

        results = detector.predict(image_cv, verbose=False, device="cpu")
        if not results or not results[0].boxes:
            raise ValueError("No wound detected.")

        top_box = results[0].boxes[0]
        img_h, img_w = image_cv.shape[:2]
        x1, y1, x2, y2 = (int(v) for v in top_box.xyxy[0].cpu().numpy().astype(int))
        # Clip to the image so negative coordinates cannot wrap around when slicing.
        x1, x2 = min(max(x1, 0), img_w), min(max(x2, 0), img_w)
        y1, y2 = min(max(y1, 0), img_h), min(max(y2, 0), img_h)
        region_cv = image_cv[y1:y2, x1:x2]
        detection_confidence = float(top_box.conf[0].cpu().numpy())

        length = breadth = area = 0
        seg_model = self.models_cache.get("seg")
        if seg_model:
            try:
                input_size = seg_model.input_shape[1:3]
                preprocessed_roi = _preprocess_for_seg(region_cv, input_size)
                mask_pred = seg_model.predict(preprocessed_roi, verbose=0)[0]
                prob_mask = _to_prob(mask_pred)

                # The model predicts at its own input resolution. Bring the
                # probabilities back to the ROI size so the mask lines up with
                # region_cv and contour sizes are in original-image pixels,
                # which is what px_per_cm refers to.
                roi_h, roi_w = region_cv.shape[:2]
                if prob_mask.shape[:2] != (roi_h, roi_w):
                    prob_mask = cv2.resize(prob_mask, (roi_w, roi_h), interpolation=cv2.INTER_LINEAR)

                # Adaptive thresholding and GrabCut refinement
                initial_mask = (prob_mask >= _adaptive_prob_threshold(prob_mask)).astype(np.uint8)
                refined_mask = _grabcut_refine(region_cv, initial_mask)
                contours, _ = cv2.findContours(refined_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                if contours:
                    # Measure the largest connected region only.
                    cnt = max(contours, key=cv2.contourArea)
                    _x, _y, w, h = cv2.boundingRect(cnt)
                    length = round(h / self.px_per_cm, 2)
                    breadth = round(w / self.px_per_cm, 2)
                    area = round(cv2.contourArea(cnt) / (self.px_per_cm ** 2), 2)
            except Exception as e:
                logging.warning(f"Segmentation process failed: {e}")

        wound_type = "Unknown"
        classifier = self.models_cache.get("cls")
        if classifier:
            try:
                wound_region_pil = Image.fromarray(cv2.cvtColor(region_cv, cv2.COLOR_BGR2RGB))
                classification_results = classifier(wound_region_pil)
                wound_type = max(classification_results, key=lambda x: x["score"])["label"]
            except Exception as e:
                logging.warning(f"Classification process failed: {e}")

        return {
            "wound_type": wound_type,
            "length_cm": length,
            "breadth_cm": breadth,
            "surface_area_cm2": area,
            "detection_confidence": detection_confidence,
            "px_per_cm": self.px_per_cm
        }

    def query_guidelines(self, query: str) -> str:
        """Return the top 10 guideline passages for ``query`` as one string.

        Each passage is rendered with its source and page metadata. Returns
        ``"Knowledge base unavailable."`` when no vector store is cached.
        """
        vector_store = self.knowledge_base_cache.get("vector_store")
        if not vector_store:
            return "Knowledge base unavailable."
        retriever = vector_store.as_retriever(search_kwargs={"k": 10})
        docs = retriever.invoke(query)
        return "\n\n".join([
            f"Source: {doc.metadata.get('source', 'N/A')}, Page: {doc.metadata.get('page', 'N/A')}\nContent: {doc.page_content}"
            for doc in docs
        ])

    def generate_final_report(self, patient_info, visual_results, guideline_context, image_pil, max_new_tokens=2048):
        """Delegate report generation to ``generate_medgemma_report``."""
        return generate_medgemma_report(patient_info, visual_results, guideline_context, image_pil, max_new_tokens)

    def save_and_commit_image(self, image_pil):
        """Save the image locally as a timestamped PNG and upload it to the dataset repo.

        The upload is attempted only when both a token and a dataset id are
        configured; upload failures are logged and not raised.
        """
        filename = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.png"
        local_path = os.path.join(self.config.UPLOADS_DIR, filename)
        image_pil.convert("RGB").save(local_path)
        logging.info(f"Image saved locally: {local_path}")
        if self.config.HF_TOKEN and self.config.DATASET_ID:
            try:
                api = _import_hf_hub()[0]()  # HfApi
                api.upload_file(
                    path_or_fileobj=local_path,
                    path_in_repo=f"images/{filename}",
                    repo_id=self.config.DATASET_ID,
                    repo_type="dataset",
                    commit_message=f"Upload wound image: {filename}",
                    token=self.config.HF_TOKEN
                )
                logging.info("✅ Image uploaded to HF dataset.")
            except Exception as e:
                logging.warning(f"Upload failed: {e}")

    @_SPACES_GPU(enable_queue=True, duration=120)
    def full_analysis_pipeline(self, image, questionnaire_data):
        """Run save -> visual analysis -> guideline retrieval -> report generation.

        Args:
            image: PIL image of the wound.
            questionnaire_data: Mapping of questionnaire answers; ``None`` is
                treated as an empty mapping. The keys ``moisture``,
                ``infection`` and ``diabetic`` are used for the guideline query.

        Returns:
            The generated report, or ``"❌ Error: ..."`` if any step raised.
        """
        # Tolerate a missing questionnaire instead of failing on .items().
        questionnaire_data = questionnaire_data or {}
        try:
            self.save_and_commit_image(image)
            visual = self.perform_visual_analysis(image)
            patient_info = ", ".join([f"{k}: {v}" for k, v in questionnaire_data.items()])
            query = f"best practices for managing a {visual['wound_type']} with moisture level '{questionnaire_data.get('moisture')}' and signs of infection '{questionnaire_data.get('infection')}' in a patient who is diabetic '{questionnaire_data.get('diabetic')}'"
            guideline_context = self.query_guidelines(query)
            return self.generate_final_report(patient_info, visual, guideline_context, image)
        except Exception as e:
            logging.error(f"Pipeline error: {e}", exc_info=True)
            return f"❌ Error: {e}"