# SmartHeal-Agentic-AI — src/ai_processor.py
# (Hugging Face Hub page header retained as comments: "Update src/ai_processor.py",
#  commit 9eef931 verified, raw/history/blame, 10.7 kB)
import os
import logging
import cv2
import numpy as np
from PIL import Image
import torch
import json
from datetime import datetime
import tensorflow as tf
from transformers import pipeline
from ultralytics import YOLO
from tensorflow.keras.models import load_model
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from huggingface_hub import HfApi, HfFolder
import spaces
from src.config import Config
class AIProcessor:
    """End-to-end wound-analysis engine for SmartHeal.

    On construction this loads every model the pipeline needs into
    ``self.models_cache`` (YOLO wound detector, Keras segmentation net,
    MedGemma report generator, sentence-transformer embeddings, and a
    wound-type classifier) and builds a FAISS knowledge base from the
    guideline PDFs listed in ``Config``. The public entry point is
    :meth:`full_analysis_pipeline`.
    """

    def __init__(self):
        # Caches populated once at startup and shared by all methods.
        self.models_cache = {}
        self.knowledge_base_cache = {}
        self.config = Config()
        # Pixel-to-centimetre calibration factor used to convert mask
        # measurements into physical units.
        # NOTE(review): assumes a fixed capture distance/resolution — confirm
        # against the camera setup this app is deployed with.
        self.px_per_cm = 38
        self._initialize_models()

    def _initialize_models(self):
        """Load all ML models into ``self.models_cache`` (best effort).

        Failures are logged rather than raised so the hosting app can still
        start; downstream methods will then fail with ``KeyError`` on the
        missing cache entries.
        """
        try:
            HfFolder.save_token(self.config.HF_TOKEN)
            self.models_cache['yolo'] = YOLO(self.config.YOLO_MODEL_PATH)
            # compile=False: inference only, no optimizer/loss state needed.
            self.models_cache['segmentation'] = load_model(self.config.SEG_MODEL_PATH, compile=False)
            self.models_cache['medgemma_pipe'] = pipeline(
                "image-text-to-text",
                model="google/medgemma-4b-it",
                torch_dtype=torch.bfloat16,
                device_map="auto",
                token=self.config.HF_TOKEN
            )
            self.models_cache['embedding_model'] = HuggingFaceEmbeddings(
                model_name="sentence-transformers/all-MiniLM-L6-v2",
                model_kwargs={'device': 'cpu'}
            )
            self.models_cache['cls'] = pipeline(
                "image-classification",
                model="Hemg/Wound-classification",
                token=self.config.HF_TOKEN,
                device="cpu"
            )
            logging.info("✅ All models loaded.")
            # The knowledge base needs the embedding model loaded above.
            self._load_knowledge_base()
        except Exception as e:
            logging.error(f"Error initializing AI models: {e}")

    def _load_knowledge_base(self):
        """Build a FAISS vector store from the configured guideline PDFs.

        Missing PDFs are skipped; if none are found the cache entry is set to
        ``None`` and :meth:`query_guidelines` degrades gracefully.
        """
        try:
            docs = []
            for pdf in self.config.GUIDELINE_PDFS:
                if os.path.exists(pdf):
                    loader = PyPDFLoader(pdf)
                    docs.extend(loader.load())
            if docs:
                splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
                chunks = splitter.split_documents(docs)
                vectorstore = FAISS.from_documents(chunks, self.models_cache['embedding_model'])
                self.knowledge_base_cache['vectorstore'] = vectorstore
                logging.info("✅ Knowledge base loaded.")
            else:
                self.knowledge_base_cache['vectorstore'] = None
        except Exception as e:
            logging.warning(f"Knowledge base error: {e}")

    def perform_visual_analysis(self, image_pil):
        """Detect, segment, measure and classify the wound in a PIL image.

        Returns a dict with ``wound_type``, ``length_cm``, ``breadth_cm`` and
        ``surface_area_cm2``. Raises ``ValueError`` when YOLO finds no wound.
        """
        image_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
        results = self.models_cache['yolo'].predict(image_cv, verbose=False, device="cpu")
        if not results or not results[0].boxes:
            raise ValueError("No wound detected.")
        # Use only the first detection; crop the wound region for the
        # downstream segmentation and classification models.
        box = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
        region_cv = image_cv[box[1]:box[3], box[0]:box[2]]
        # Keras input_shape is (batch, H, W, C); cv2.resize wants (W, H).
        input_size = self.models_cache['segmentation'].input_shape[1:3]
        resized = cv2.resize(region_cv, (input_size[1], input_size[0]))
        mask = self.models_cache['segmentation'].predict(np.expand_dims(resized / 255.0, 0), verbose=0)[0]
        mask_np = (mask[:, :, 0] > 0.5).astype(np.uint8)
        contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        length = breadth = area = 0
        if contours:
            # Measure only the largest contour: the wound itself.
            cnt = max(contours, key=cv2.contourArea)
            x, y, w, h = cv2.boundingRect(cnt)
            length = round(h / self.px_per_cm, 2)
            breadth = round(w / self.px_per_cm, 2)
            area = round(cv2.contourArea(cnt) / (self.px_per_cm ** 2), 2)
        # Classify the cropped region (converted back to RGB for the HF
        # pipeline) and keep the top-scoring label.
        wound_type = max(
            self.models_cache['cls'](Image.fromarray(cv2.cvtColor(region_cv, cv2.COLOR_BGR2RGB))),
            key=lambda x: x['score']
        )['label']
        return {
            'wound_type': wound_type,
            'length_cm': length,
            'breadth_cm': breadth,
            'surface_area_cm2': area
        }

    def query_guidelines(self, query: str):
        """Retrieve the top-10 guideline chunks relevant to *query*.

        Returns a newline-joined string of source/page/content entries, or a
        fixed fallback message when no knowledge base was loaded.
        """
        vector_store = self.knowledge_base_cache.get("vectorstore")
        if not vector_store:
            return "Knowledge base unavailable."
        retriever = vector_store.as_retriever(search_kwargs={"k": 10})
        docs = retriever.invoke(query)
        return "\n\n".join([
            f"Source: {doc.metadata.get('source', 'N/A')}, Page: {doc.metadata.get('page', 'N/A')}\nContent: {doc.page_content}"
            for doc in docs
        ])

    def generate_final_report(self, patient_info, visual_results, guideline_context, image_pil, max_new_tokens=2048):
        """Ask MedGemma for the final clinical report.

        Combines patient info, measured wound details, retrieved guideline
        evidence and the wound image into one multimodal chat request.
        Returns the generated markdown report, or an error string on failure.
        """
        prompt = f"""
🩺 You are SmartHeal-AI, a world-class wound care AI specialist trained in clinical wound assessment and guideline-based treatment planning.
Your task is to process the following structured inputs (patient data, wound measurements, clinical guidelines, and image) and perform **clinical reasoning and decision-making** to generate a complete wound care report.
---
🔍 **YOUR PROCESS — FOLLOW STRICTLY:**
### Step 1: Clinical Reasoning (Chain-of-Thought)
Use the provided information to think step-by-step about:
- Patient’s risk factors (e.g. diabetes, age, healing limitations)
- Wound characteristics (size, tissue appearance, moisture, infection signs)
- Visual clues from the image (location, granulation, maceration, inflammation, surrounding skin)
- Clinical guidelines provided — selectively choose the ones most relevant to this case
Do NOT list all guidelines verbatim. Use judgment: apply them where relevant. Explain why or why not.
Also assess whether this wound appears:
- Acute vs chronic
- Surgical vs traumatic
- Inflammatory vs proliferative healing phase
---
### Step 2: Structured Clinical Report
Generate the following report sections using markdown and medical terminology:
#### **1. Clinical Summary**
- Describe wound appearance and tissue types (e.g., slough, necrotic, granulating, epithelializing)
- Include size, wound bed condition, peri-wound skin, and signs of infection or biofilm
- Mention inferred location (e.g., heel, forefoot) if image allows
- Summarize patient's systemic risk profile
#### **2. Medicinal & Dressing Recommendations**
Based on your analysis:
- Recommend specific **wound care dressings** (e.g., hydrocolloid, alginate, foam, antimicrobial silver, etc.) suitable to wound moisture level and infection risk
- Propose **topical or systemic agents** ONLY if relevant — include name classes (e.g., antiseptic: povidone iodine, antibiotic ointments, enzymatic debriders)
- Mention **techniques** (e.g., sharp debridement, NPWT, moisture balance, pressure offloading, dressing frequency)
- Avoid repeating guidelines — **apply them**
#### **3. Key Risk Factors**
Explain how the patient’s condition (e.g., diabetic, poor circulation, advanced age, poor hygiene) may affect wound healing
#### **4. Prognosis & Monitoring Advice**
- Mention how often wound should be reassessed
- Indicate signs to monitor for deterioration or improvement
- Include when escalation to specialist is necessary
#### **5. Disclaimer**
This is an AI-generated summary based on available data. It is not a substitute for clinical evaluation by a wound care professional.
**Note:** Every dressing change is a chance for wound reassessment. Always perform a thorough wound evaluation at each dressing change.
---
🧾 **INPUT DATA**
**Patient Info:**
{patient_info}
**Wound Details:**
- Type: {visual_results['wound_type']}
- Size: {visual_results['length_cm']} × {visual_results['breadth_cm']} cm
- Area: {visual_results['surface_area_cm2']} cm²
**Clinical Guideline Evidence:**
{guideline_context}
You may now begin your analysis and generate the two-part report.
"""
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a world-class medical AI assistant..."}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_pil},
                    {"type": "text", "text": prompt},
                ]
            }
        ]
        try:
            # do_sample=False: deterministic (greedy) decoding for clinical text.
            output = self.models_cache['medgemma_pipe'](
                text=messages,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )
            # The pipeline returns the whole chat; the last message is the
            # assistant's reply.
            return output[0]['generated_text'][-1].get('content', '').strip()
        except Exception as e:
            logging.error(f"MedGemma error: {e}", exc_info=True)
            return f"❌ Failed to generate report: {e}"

    def save_and_commit_image(self, image_pil):
        """Persist the uploaded image locally and mirror it to the HF dataset.

        Upload failures are logged but never fatal — the analysis pipeline
        must keep working offline.
        """
        filename = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.png"
        # Robustness: the uploads directory may not exist on a fresh container.
        os.makedirs(self.config.UPLOADS_DIR, exist_ok=True)
        local_path = os.path.join(self.config.UPLOADS_DIR, filename)
        image_pil.convert("RGB").save(local_path)
        logging.info(f"Image saved locally: {local_path}")
        if self.config.HF_TOKEN and self.config.DATASET_ID:
            try:
                api = HfApi()
                # BUG FIX: the repo path and commit message previously held a
                # literal placeholder instead of the generated filename, so
                # every upload targeted the same repo key.
                api.upload_file(
                    path_or_fileobj=local_path,
                    path_in_repo=f"images/{filename}",
                    repo_id=self.config.DATASET_ID,
                    repo_type="dataset",
                    commit_message=f"Upload wound image: {filename}"
                )
                logging.info("✅ Image uploaded to HF dataset.")
            except Exception as e:
                logging.warning(f"Upload failed: {e}")

    @spaces.GPU(enable_queue=True, duration=120)
    def full_analysis_pipeline(self, image, questionnaire_data):
        """Run the whole pipeline: archive → measure → retrieve → report.

        ``image`` is a PIL image; ``questionnaire_data`` is a dict of patient
        answers. Returns the generated report, or an error string on failure.
        """
        try:
            self.save_and_commit_image(image)
            visual = self.perform_visual_analysis(image)
            patient_info = ", ".join([f"{k}: {v}" for k, v in questionnaire_data.items()])
            query = f"best practices for managing a {visual['wound_type']} with moisture level '{questionnaire_data.get('moisture')}' and signs of infection '{questionnaire_data.get('infection')}' in a patient who is diabetic '{questionnaire_data.get('diabetic')}'"
            guideline_context = self.query_guidelines(query)
            return self.generate_final_report(patient_info, visual, guideline_context, image)
        except Exception as e:
            logging.error(f"Pipeline error: {e}", exc_info=True)
            return f"❌ Error: {e}"