# wstest / app.py
# Author: zazaman
# Commit 6771c3d (verified): Patch Space for WoundDoc mobile app API integration
import base64
import io
import os
import re
import uuid
from typing import Any, Dict
import cv2
import gradio as gr
import numpy as np
import tensorflow as tf
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from PIL import Image
from pydantic import BaseModel
# --- CONFIGURATION ---
MODEL_PATH = "wound_segmentation_model.h5"
# Resolution the segmentation model is fed; images are resized to this first.
IMG_HEIGHT = 256
IMG_WIDTH = 256
# OpenCV 8-bit HSV bounds (H in 0-179) used to find the blue reference square.
LOWER_BLUE = np.array([95, 60, 100])
UPPER_BLUE = np.array([120, 255, 255])
# --- ARUCO SETTINGS ---
ARUCO_DICT_TYPE = cv2.aruco.DICT_4X4_50
MARKER_SIZE_CM = 2.0  # Physical side length of the printed marker (2x2 cm).
# Physical area (cm^2) of the reference object: 2x2 cm = 4.0. The blue-square
# fallback assumes a reference of this same size. Defined exactly once.
REFERENCE_AREA_CM2 = MARKER_SIZE_CM * MARKER_SIZE_CM
# Tissue classes used by analyze_wound_image to label wound pixels by colour.
# Keys of each entry:
#   name         - unique human-readable label; also the key looked up by
#                  calculate_infection_risk.
#   display_name - key reported in the API's "tissue_composition"; several
#                  entries share one (e.g. four 'eschar_necrotic_tissue' rows).
#   id           - value written into the classification mask (0 = unlabelled).
#   color        - outline colour drawn on the overlay; the overlay is a BGR
#                  image, so this tuple is (B, G, R).
#   lower/upper  - inclusive cv2.inRange bounds in OpenCV 8-bit HSV
#                  (H in 0-179, S and V in 0-255).
# ORDER MATTERS: a pixel receives the id of the FIRST entry whose range contains
# it, and wound pixels that match no entry receive the id of the LAST entry.
# NOTE(review): ids 20 and 23 lie entirely inside the ranges of earlier entries
# (ids 14 and 13 respectively), so with first-match-wins they can never be
# assigned. Ids 26 and 29 also share the overlay colour (255, 0, 128).
# Confirm the intended ordering and colours.
TISSUE_TYPES = [
{'name': 'Eschar / Necrotic Tissue (Pure Black & Near Blacks)', 'display_name': 'eschar_necrotic_tissue', 'id': 1, 'color': (255, 255, 255), 'lower': np.array([0, 0, 0]), 'upper': np.array([179, 255, 67])},
{'name': 'Aged Granulation / Old Blood (Very Dark Reds)', 'display_name': 'aged_granulation_old_blood', 'id': 2, 'color': (0, 255, 0), 'lower': np.array([0, 205, 0]), 'upper': np.array([10, 255, 100])},
{'name': 'Eschar / Necrotic Tissue (Off-Blacks)', 'display_name': 'eschar_necrotic_tissue', 'id': 3, 'color': (128, 0, 128), 'lower': np.array([110, 205, 0]), 'upper': np.array([130, 255, 100])},
{'name': 'Eschar / Necrotic Tissue (Very Dark Browns)', 'display_name': 'eschar_necrotic_tissue', 'id': 4, 'color': (255, 255, 0), 'lower': np.array([20, 155, 0]), 'upper': np.array([40, 255, 100])},
{'name': 'Healthy Granulation Tissue (Dark Reds)', 'display_name': 'healthy_granulation_tissue', 'id': 5, 'color': (0, 255, 100), 'lower': np.array([0, 205, 100]), 'upper': np.array([10, 255, 200])},
{'name': 'Eschar / Necrotic Tissue (Dark Browns)', 'display_name': 'eschar_necrotic_tissue', 'id': 6, 'color': (255, 0, 100), 'lower': np.array([20, 155, 50]), 'upper': np.array([40, 255, 150])},
{'name': 'Hematoma / Ischemia (Bruising) (Dark Purples)', 'display_name': 'hematoma_ischemia_bruising', 'id': 7, 'color': (255, 0, 255), 'lower': np.array([140, 136, 100]), 'upper': np.array([160, 236, 200])},
{'name': 'Necrotic Tissue with Bruising (Near Blacks (Brown/Purple Tinge))', 'display_name': 'necrotic_tissue_with_bruising', 'id': 8, 'color': (200, 255, 0), 'lower': np.array([159, 126, 0]), 'upper': np.array([179, 226, 142])},
{'name': 'Hematoma / Ischemia (Severe Bruising) (Very Dark Purples)', 'display_name': 'hematoma_ischemia_severe_bruising', 'id': 9, 'color': (128, 128, 128), 'lower': np.array([145, 205, 50]), 'upper': np.array([165, 255, 175])},
{'name': 'Fibrin / Devitalized Connective Tissue (Dark Greys)', 'display_name': 'fibrin_devitalized_connective_tissue', 'id': 10, 'color': (0, 255, 255), 'lower': np.array([0, 0, 50]), 'upper': np.array([179, 55, 150])},
{'name': 'Infected Tissue (e.g., Pseudomonas) (Dark Olive)', 'display_name': 'infected_tissue_eg_pseudomonas', 'id': 11, 'color': (0, 200, 0), 'lower': np.array([20, 105, 0]), 'upper': np.array([40, 205, 150])},
{'name': 'Hemosiderin Staining / Drying Exudate (Browns / Dark Oranges)', 'display_name': 'hemosiderin_staining_drying_exudate', 'id': 12, 'color': (255, 0, 200), 'lower': np.array([10, 205, 100]), 'upper': np.array([30, 255, 200])},
{'name': 'Healthy Granulation Tissue (Deep Reds / Maroons)', 'display_name': 'healthy_granulation_tissue', 'id': 13, 'color': (0, 200, 255), 'lower': np.array([0, 205, 150]), 'upper': np.array([10, 255, 255])},
{'name': 'Hemosiderin Staining / Drying Exudate (Burnt Oranges)', 'display_name': 'hemosiderin_staining_drying_exudate', 'id': 14, 'color': (255, 255, 100), 'lower': np.array([10, 205, 150]), 'upper': np.array([30, 255, 255])},
{'name': 'Early Bruising / Poor Perfusion (Medium Purples)', 'display_name': 'early_bruising_poor_perfusion', 'id': 15, 'color': (200, 0, 255), 'lower': np.array([140, 126, 150]), 'upper': np.array([160, 226, 255])},
{'name': 'Slough with Potential Infection (Olive / Dark Yellows)', 'display_name': 'slough_with_potential_infection', 'id': 16, 'color': (100, 255, 0), 'lower': np.array([20, 205, 100]), 'upper': np.array([40, 255, 200])},
{'name': 'Early / Fragile Epithelializing Tissue (Dusty Pinks)', 'display_name': 'early_fragile_epithelializing_tissue', 'id': 17, 'color': (0, 255, 200), 'lower': np.array([140, 55, 150]), 'upper': np.array([160, 155, 255])},
{'name': 'Fibrin / Devitalized Tissue (Mid-Tone Greys)', 'display_name': 'fibrin_devitalized_tissue', 'id': 18, 'color': (200, 200, 200), 'lower': np.array([0, 0, 140]), 'upper': np.array([179, 65, 240])},
{'name': 'Mixed Devitalized Tissue (Slough/Eschar) (Taupe / Greyish Browns)', 'display_name': 'mixed_devitalized_tissue_slough_eschar', 'id': 19, 'color': (64, 224, 208), 'lower': np.array([20, 45, 50]), 'upper': np.array([40, 145, 175])},
{'name': 'Serosanguinous Exudate (Blood/Serum Mix) (Oranges)', 'display_name': 'serosanguinous_exudate_blood_serum_mix', 'id': 20, 'color': (100, 255, 100), 'lower': np.array([10, 205, 203]), 'upper': np.array([30, 255, 255])},
{'name': 'Healthy Granulation Tissue (Medium Reds)', 'display_name': 'healthy_granulation_tissue', 'id': 21, 'color': (0, 255, 128), 'lower': np.array([0, 111, 202]), 'upper': np.array([10, 211, 255])},
{'name': 'Serosanguinous Exudate (Dilute) (Light Oranges / Peaches)', 'display_name': 'serosanguinous_exudate_dilute', 'id': 22, 'color': (255, 100, 255), 'lower': np.array([10, 92, 203]), 'upper': np.array([30, 192, 255])},
{'name': 'Healthy Granulation / Fresh Bleeding (Bright Reds)', 'display_name': 'healthy_granulation_fresh_bleeding', 'id': 23, 'color': (0, 200, 200), 'lower': np.array([0, 205, 202]), 'upper': np.array([10, 255, 255])},
{'name': 'Fibrin (Light Greys / Fibrin Tones)', 'display_name': 'fibrin', 'id': 24, 'color': (105, 105, 105), 'lower': np.array([0, 0, 175]), 'upper': np.array([179, 80, 255])},
{'name': 'Slough / Fibrinous Exudate (Pale Yellowish Tones)', 'display_name': 'slough_fibrinous_exudate', 'id': 25, 'color': (128, 255, 0), 'lower': np.array([50, 55, 150]), 'upper': np.array([70, 155, 255])},
{'name': 'Fragile Granulation / Epithelial Tissue (Light Magenta / Lavenders)', 'display_name': 'fragile_granulation_epithelial_tissue', 'id': 26, 'color': (255, 0, 128), 'lower': np.array([140, 92, 202]), 'upper': np.array([160, 192, 255])},
{'name': 'Hypergranulation / Irritated Tissue (Bright Pinks / Magenta)', 'display_name': 'hypergranulation_irritated_tissue', 'id': 27, 'color': (128, 0, 255), 'lower': np.array([140, 176, 202]), 'upper': np.array([160, 255, 255])},
{'name': 'Slough / Serous Exudate (Light Yellows)', 'display_name': 'slough_serous_exudate', 'id': 28, 'color': (0, 255, 192), 'lower': np.array([20, 80, 204]), 'upper': np.array([40, 180, 255])},
{'name': 'Epithelializing Tissue (Pale Pinks)', 'display_name': 'epithelializing_tissue', 'id': 29, 'color': (255, 0, 128), 'lower': np.array([140, 1, 203]), 'upper': np.array([160, 101, 255])},
{'name': 'Purulent Exudate / Infected Slough (Bright Yellows)', 'display_name': 'purulent_exudate_infected_slough', 'id': 30, 'color': (50, 255, 255), 'lower': np.array([20, 204, 204]), 'upper': np.array([40, 255, 255])},
# Last entry: also the catch-all id for wound pixels that match no range above.
{'name': 'Fibrin / Macerated Skin (Pure White & Off-Whites)', 'display_name': 'fibrin_macerated_skin', 'id': 31, 'color': (220, 220, 220), 'lower': np.array([0, 0, 202]), 'upper': np.array([179, 103, 255])},
]
# --- HELPERS ---
def iou(y_true, y_pred, smooth=1e-6):
    """Soft intersection-over-union between two tensors.

    Both inputs are flattened, so any matching shapes work. ``smooth`` is
    added to numerator and denominator to avoid division by zero when both
    tensors are empty. Passed as a custom object when loading the model.
    """
    backend = tf.keras.backend
    truth = backend.flatten(y_true)
    pred = backend.flatten(y_pred)
    overlap = backend.sum(truth * pred)
    combined = backend.sum(truth) + backend.sum(pred) - overlap
    return (overlap + smooth) / (combined + smooth)
def calculate_infection_risk(tissue_percentages):
    """Map tissue percentages to a coarse infection-risk label.

    Args:
        tissue_percentages: mapping of TISSUE_TYPES ``name`` -> percentage of
            wound area (0-100). Missing names count as 0.

    Returns:
        "High" if the high-risk tissues together exceed 5%; "Medium" if they
        exceed 1% or the medium-risk tissue exceeds 10%; otherwise "Low".
    """
    def share_of(names):
        return sum(tissue_percentages.get(n, 0) for n in names)

    high_share = share_of((
        "Purulent Exudate / Infected Slough (Bright Yellows)",
        "Infected Tissue (e.g., Pseudomonas) (Dark Olive)",
    ))
    medium_share = share_of((
        "Slough with Potential Infection (Olive / Dark Yellows)",
    ))

    if high_share > 5.0:
        return "High"
    if high_share > 1.0 or medium_share > 10.0:
        return "Medium"
    return "Low"
def _decode_base64_to_rgb_np(base64_string: str) -> np.ndarray:
    """Decode a base64-encoded image into an RGB ``uint8`` NumPy array.

    Accepted input forms:
      * raw base64, with or without trailing ``=`` padding (some mobile
        clients strip it);
      * a data URL such as ``data:image/jpeg;base64,<payload>`` (any media
        type; only the part after the first comma is decoded);
      * payloads containing whitespace or line breaks (e.g. wrapped output).

    Raises:
        binascii.Error: the payload is not valid base64.
        PIL.UnidentifiedImageError: the decoded bytes are not a readable image.
    """
    header, sep, payload = base64_string.partition(",")
    if sep and header.strip().lower().startswith("data:"):
        base64_string = payload
    # Remove whitespace first so the padding below is computed on data chars only.
    compact = "".join(base64_string.split())
    # b64decode rejects unpadded input ("Incorrect padding"); restore the '='s.
    compact += "=" * (-len(compact) % 4)
    image_data = base64.b64decode(compact)
    with Image.open(io.BytesIO(image_data)) as image:
        # convert() loads the pixels; np.array copies them before the file closes.
        return np.array(image.convert("RGB"))
def _encode_rgb_np_to_data_url(image_np: np.ndarray, fmt: str = "JPEG") -> str:
    """Encode an RGB array as a base64 ``data:`` URL.

    Args:
        image_np: RGB image array accepted by ``PIL.Image.fromarray``.
        fmt: Pillow format name, case-insensitive (e.g. "JPEG", "PNG",
            "WEBP"). "JPG" is accepted as an alias for "JPEG".

    Returns:
        ``data:<mime>;base64,<payload>``, where the MIME type matches ``fmt``.
        Previously every non-JPEG format was labelled ``image/png``.
    """
    mime_by_format = {
        "JPEG": "image/jpeg",
        "PNG": "image/png",
        "WEBP": "image/webp",
        "GIF": "image/gif",
        "BMP": "image/bmp",
        "TIFF": "image/tiff",
    }
    pil_format = fmt.upper()
    if pil_format == "JPG":
        pil_format = "JPEG"  # Pillow registers the JPEG writer only as "JPEG"

    buff = io.BytesIO()
    Image.fromarray(image_np).save(buff, format=pil_format)
    mime = mime_by_format.get(pil_format, f"image/{pil_format.lower()}")
    encoded = base64.b64encode(buff.getvalue()).decode("utf-8")
    return f"data:{mime};base64,{encoded}"
# --- MODEL LOADING ---
# Best-effort load at import time: on any failure the server still starts with
# `model = None`; analyze_wound_image then returns an error and /api/health
# reports model_loaded = False. The `iou` custom object is supplied so a model
# saved with that metric can be deserialised.
model = None
try:
    if not os.path.exists(MODEL_PATH):
        print(f"--- WARNING: Model file not found at {MODEL_PATH} ---")
    else:
        model = tf.keras.models.load_model(MODEL_PATH, custom_objects={"iou": iou})
        print("--- Segmentation model loaded successfully. ---")
except Exception as e:
    model = None
    print(f"--- WARNING: Error loading model: {e} ---")
def _predict_wound_mask(bgr_img: np.ndarray) -> np.ndarray:
    """Run the segmentation model; return a uint8 mask (0 = background) at ``bgr_img``'s size."""
    resized = cv2.resize(bgr_img, (IMG_WIDTH, IMG_HEIGHT))
    batch = np.expand_dims(resized, axis=0) / 255.0  # add batch axis, scale to [0, 1]
    probabilities = model.predict(batch, verbose=0)[0]
    binary = (probabilities > 0.5).astype(np.uint8) * 255
    # Back to the original resolution (cv2.resize takes (width, height)).
    return cv2.resize(binary, (bgr_img.shape[1], bgr_img.shape[0]))


def _rectified_pixel_area(mask: np.ndarray, homography: np.ndarray) -> float:
    """Return the area covered by ``mask``'s non-zero pixels after applying ``homography``.

    For a homography H, the local area scale at source pixel (x, y) is
    |det(H)| / |w|**3 with w = H[2,0]*x + H[2,1]*y + H[2,2]. Summing that over
    the mask gives the rectified area without rendering a warped image. Unlike
    warping onto a fixed-size canvas, wound pixels that map to negative
    coordinates (above/left of the marker) or far away are not clipped.
    """
    ys, xs = np.nonzero(mask)
    if xs.size == 0:
        return 0.0
    w = homography[2, 0] * xs + homography[2, 1] * ys + homography[2, 2]
    return float(np.sum(abs(np.linalg.det(homography)) / np.abs(w) ** 3))


def _area_via_aruco(bgr_img: np.ndarray, wound_mask: np.ndarray) -> tuple:
    """Measure the wound in the plane of the first detected ArUco marker.

    Returns ``(area_cm2, marker_contour)``, or ``(0.0, None)`` when no marker is
    found, the ArUco API is unavailable, or the homography is degenerate.
    """
    try:
        dictionary = cv2.aruco.getPredefinedDictionary(ARUCO_DICT_TYPE)
        detector = cv2.aruco.ArucoDetector(dictionary, cv2.aruco.DetectorParameters())
        corners, ids, _rejected = detector.detectMarkers(bgr_img)
        if ids is None or len(ids) == 0:
            return 0.0, None
        # First marker's corners: top-left, top-right, bottom-right, bottom-left.
        src_pts = corners[0][0].astype(np.float32)
        side = 100.0  # marker side length in the rectified plane, in pixels
        dst_pts = np.array([[0, 0], [side, 0], [side, side], [0, side]], dtype=np.float32)
        homography = cv2.getPerspectiveTransform(src_pts, dst_pts)
        rectified_px = _rectified_pixel_area(wound_mask, homography)
        if not np.isfinite(rectified_px):
            print("ArUco/Perspective error: degenerate homography")
            return 0.0, None
        # `side` rectified pixels span MARKER_SIZE_CM centimetres.
        pixels_per_cm2 = (side / MARKER_SIZE_CM) ** 2
        marker_contour = src_pts.astype(np.int32).reshape((-1, 1, 2))
        return rectified_px / pixels_per_cm2, marker_contour
    except Exception as e:  # best effort: the caller falls back to the blue square
        print(f"ArUco/Perspective error: {e}")
        return 0.0, None


def _area_via_blue_reference(bgr_img: np.ndarray, wound_mask: np.ndarray) -> tuple:
    """Measure the wound against the largest blue blob, assumed to be REFERENCE_AREA_CM2.

    Returns ``(area_cm2, reference_contour)``. The area is 0.0 when no usable
    blue region is found; the contour is None when none is found at all.
    """
    try:
        blurred = cv2.GaussianBlur(bgr_img, (5, 5), 0)
        hsv = cv2.cvtColor(blurred, cv2.COLOR_BGR2HSV)
        blue_mask = cv2.inRange(hsv, LOWER_BLUE, UPPER_BLUE)
        kernel = np.ones((5, 5), np.uint8)
        # Opening removes speckles, closing fills small holes in the square.
        blue_mask = cv2.morphologyEx(blue_mask, cv2.MORPH_OPEN, kernel)
        blue_mask = cv2.morphologyEx(blue_mask, cv2.MORPH_CLOSE, kernel)
        contours, _ = cv2.findContours(blue_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    except Exception as e:  # best effort: report area 0.0 rather than failing
        print(f"Fallback Area calc warning: {e}")
        return 0.0, None
    if not contours:
        return 0.0, None
    ref_contour = max(contours, key=cv2.contourArea)
    ref_pixel_area = cv2.contourArea(ref_contour)
    if ref_pixel_area <= 0:
        return 0.0, ref_contour
    pixels_per_cm2 = ref_pixel_area / REFERENCE_AREA_CM2
    return float(cv2.countNonZero(wound_mask) / pixels_per_cm2), ref_contour


def _classify_tissue(bgr_img: np.ndarray, wound_mask: np.ndarray) -> np.ndarray:
    """Label each wound pixel with a TISSUE_TYPES id (first matching HSV range wins).

    Wound pixels that match no range get the last tissue's id; pixels outside
    the wound stay 0.
    """
    wound_only = cv2.bitwise_and(bgr_img, bgr_img, mask=wound_mask)
    hsv_wound = cv2.cvtColor(wound_only, cv2.COLOR_BGR2HSV)
    inside = wound_mask > 0
    labels = np.zeros(wound_mask.shape, dtype=np.uint8)
    for tissue in TISSUE_TYPES:
        in_range = cv2.inRange(hsv_wound, tissue["lower"], tissue["upper"]) > 0
        labels[in_range & inside & (labels == 0)] = tissue["id"]
    labels[inside & (labels == 0)] = TISSUE_TYPES[-1]["id"]
    return labels


def _tissue_percentages(labels: np.ndarray, total_wound_pixels: int) -> tuple:
    """Compute tissue shares of the wound area.

    Returns ``(per_name, per_display_name)``:
      * ``per_name``: every TISSUE_TYPES ``name`` -> percentage (unrounded).
      * ``per_display_name``: percentages summed over entries that share a
        ``display_name``, rounded to 0.1, keeping only totals above 0.1%.
        Summing (rather than overwriting) keeps e.g. all four eschar subtypes.
    """
    per_name = {}
    grouped = {}
    if total_wound_pixels <= 0:
        return per_name, {}
    for tissue in TISSUE_TYPES:
        share = np.count_nonzero(labels == tissue["id"]) / total_wound_pixels * 100
        per_name[tissue["name"]] = share
        grouped[tissue["display_name"]] = grouped.get(tissue["display_name"], 0.0) + share
    per_display_name = {key: round(float(pct), 1) for key, pct in grouped.items() if pct > 0.1}
    return per_name, per_display_name


def _draw_overlay(bgr_img, wound_mask, labels, tissue_percentages, ref_contour, applied_perspective):
    """Return a BGR copy of ``bgr_img`` with wound, tissue and reference outlines drawn."""
    overlay = bgr_img.copy()
    wound_contours, _ = cv2.findContours(wound_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if wound_contours:
        cv2.drawContours(overlay, wound_contours, -1, (255, 0, 0), 3)  # blue wound outline
    for tissue in TISSUE_TYPES:
        if tissue_percentages.get(tissue["name"], 0) > 0.1:
            single_tissue = np.where(labels == tissue["id"], 255, 0).astype(np.uint8)
            contours, _ = cv2.findContours(single_tissue, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(overlay, contours, -1, tissue["color"], 2)
    if ref_contour is not None:
        # BGR: red outline for the ArUco marker, green for the blue square.
        color = (0, 0, 255) if applied_perspective else (0, 255, 0)
        cv2.drawContours(overlay, [ref_contour], -1, color, 2)
    return overlay


def analyze_wound_image(input_image_np: np.ndarray) -> Dict[str, Any]:
    """Segment a wound photo, measure its area and classify its tissue.

    Pipeline:
      1. Predict a binary wound mask at the original resolution.
      2. Measure the area in cm^2 against an ArUco marker, with perspective
         correction. If no marker is found, fall back to a blue reference
         square. The area stays 0.0 if neither reference is found.
      3. Label wound pixels by HSV range (TISSUE_TYPES) and compute shares.
      4. Estimate infection risk and render an overlay image.

    Args:
        input_image_np: RGB image array (as produced by
            ``_decode_base64_to_rgb_np`` or Gradio's ``type="numpy"`` input).

    Returns:
        ``{"status": "error", "message": ...}`` if the model is not loaded or
        no wound is segmented. Otherwise ``{"status": "success",
        "analysis": {...}, "overlay_image_base64": ...,
        "processed_image_base64": ...}``, where both image fields hold the same
        JPEG data URL.
    """
    if model is None:
        return {"status": "error", "message": "Model is not loaded. Check server logs."}

    original_img = cv2.cvtColor(input_image_np, cv2.COLOR_RGB2BGR)
    wound_mask = _predict_wound_mask(original_img)
    total_wound_pixels = cv2.countNonZero(wound_mask)
    if total_wound_pixels == 0:
        return {"status": "error", "message": "Segmentation failed. No wound detected."}

    wound_area_cm2, ref_contour = _area_via_aruco(original_img, wound_mask)
    applied_perspective = ref_contour is not None
    if not applied_perspective:
        wound_area_cm2, ref_contour = _area_via_blue_reference(original_img, wound_mask)

    labels = _classify_tissue(original_img, wound_mask)
    tissue_percentages, display_percentages = _tissue_percentages(labels, total_wound_pixels)
    infection_score = calculate_infection_risk(tissue_percentages)

    overlay_bgr = _draw_overlay(
        original_img, wound_mask, labels, tissue_percentages, ref_contour, applied_perspective
    )
    overlay_rgb = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB)
    overlay_data_url = _encode_rgb_np_to_data_url(overlay_rgb, fmt="JPEG")
    return {
        "status": "success",
        "analysis": {
            "total_area_cm2": round(wound_area_cm2, 2),
            "infection_risk_score": infection_score,
            "tissue_composition": display_percentages,
        },
        "overlay_image_base64": overlay_data_url,
        "processed_image_base64": overlay_data_url,
    }
# JSON API used by the mobile client; the Gradio UI is mounted onto it further down.
api = FastAPI(version="1.0.0", title="Wound Segmentation API")
# Any origin, method and header may call the API; credentials are not allowed cross-origin.
api.add_middleware(
    CORSMiddleware,
    allow_credentials=False,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
class AnalyzeRequest(BaseModel):
    """JSON body accepted by ``/api/segment`` and ``/analyze``.

    ``image_base64`` is passed to ``_decode_base64_to_rgb_np``: either raw
    base64 or a ``data:image/...;base64,`` URL.
    """
    image_base64: str  # base64-encoded image (raw or data URL)
@api.get("/")
def root():
    """Redirect the bare service URL to the Gradio UI mounted at /gradio."""
    target = "/gradio"
    return RedirectResponse(target)
@api.get("/api/health")
def health():
    """Liveness and capability probe.

    Returns a static service description, whether the segmentation model
    loaded at startup (``model_loaded``), and the list of supported features.
    Each capability appears exactly once in ``supports``.
    """
    return {
        "status": "ok",
        "service": "wound-segmentation",
        "model_loaded": model is not None,
        "supports": [
            "segmentation_overlay",
            "tissue_composition",
            "infection_risk_score",
            "area_cm2_via_aruco_with_perspective_correction",
            "area_cm2_via_blue_reference_fallback",
        ],
    }
def _run_segment(request: AnalyzeRequest):
    """Shared handler for ``/api/segment`` and ``/analyze``.

    Decodes ``request.image_base64``, runs ``analyze_wound_image`` and attaches
    a fresh ``request_id`` and ``model_info`` to the successful result.

    Raises:
        HTTPException(400): the payload is not valid base64 or not a decodable
            image, or the analysis reported an error (model not loaded, no
            wound detected).
        HTTPException(500): an unexpected failure occurred during analysis.
    """
    import binascii  # only used to classify base64 decoding failures

    request_id = str(uuid.uuid4())

    # Decoding failures are client errors. Previously only "Incorrect padding"
    # mapped to 400, and e.g. non-image bytes surfaced as a 500.
    try:
        image_np = _decode_base64_to_rgb_np(request.image_base64)
    except binascii.Error:
        raise HTTPException(
            status_code=400, detail="Invalid Base64 string. Please check the input."
        ) from None
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Could not decode image: {e}") from None

    try:
        result = analyze_wound_image(image_np)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal error: {e}") from e

    if result.get("status") == "error":
        raise HTTPException(status_code=400, detail=result["message"])
    result["request_id"] = request_id
    result["model_info"] = {
        "name": MODEL_PATH,
        "input_size": [IMG_HEIGHT, IMG_WIDTH],
    }
    return result
@api.post("/api/segment")
def segment_api(request: AnalyzeRequest):
    """POST /api/segment: run the full wound analysis on a base64 image."""
    response = _run_segment(request=request)
    return response
@api.post("/analyze")
def analyze_api(request: AnalyzeRequest):
    """POST /analyze: alias of /api/segment with an identical response."""
    response = _run_segment(request=request)
    return response
def gradio_interface_with_image(image_input):
    """Gradio callback: analyse an uploaded RGB array.

    Returns ``(overlay_pil_image, result_dict)`` on success. Returns
    ``(None, error_dict)`` when nothing was uploaded, when the analysis
    reports an error, or when the overlay data URL cannot be decoded.
    """
    if image_input is None:
        return None, {"status": "error", "message": "Please upload an image."}
    result = analyze_wound_image(image_input)
    if result.get("status") == "error":
        return None, result
    try:
        data_url = result.get("overlay_image_base64") or result["processed_image_base64"]
        # Reuse the API decoder so data-URL handling lives in one place.
        overlay = Image.fromarray(_decode_base64_to_rgb_np(data_url))
    except Exception as e:
        return None, {"status": "error", "message": f"Failed to decode processed image: {e}"}
    return overlay, result
# Browser UI for manual testing; it shares analyze_wound_image with the JSON API.
demo = gr.Interface(
    fn=gradio_interface_with_image,
    inputs=gr.Image(type="numpy", label="Upload Wound Image"),
    outputs=[
        gr.Image(type="pil", label="Segmentation Overlay"),
        gr.JSON(label="Analysis Result"),
    ],
    title="Wound Size Analysis Test Space (wstest)",
    # The description names the ArUco marker, which is the preferred reference
    # (tried first); the blue square is only the fallback.
    description=(
        "Upload a wound image (ideally with a 2x2 cm ArUco marker from the 4x4_50 "
        "dictionary, or a blue 2x2 cm reference square as a fallback) to run "
        "segmentation, tissue composition analysis, infection risk estimation, and overlay generation."
    ),
    allow_flagging="never",
)
# Serve the Gradio UI under /gradio on the same FastAPI app as the JSON endpoints.
app = gr.mount_gradio_app(api, demo, path="/gradio")
if __name__ == "__main__":
    # Serve the combined FastAPI + Gradio app when executed as a script.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")