Spaces:
Runtime error
Runtime error
File size: 6,340 Bytes
bf07f10 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import numpy as np
import json
from typing import Dict, Any
# --- New Helper Function for Dynamic Testing ---
def generate_random_heatmap(size: int = 224, *, seed: "int | None" = None) -> np.ndarray:
    """
    Generate a randomized, plausible square heatmap for exercising the agent.

    The map is low-level noise in [0, 0.1) plus one rectangular high-intensity
    patch. The patch centre lies in the middle half of the map. Its height and
    width are drawn from [30, 80) pixels and clipped to the map bounds, and its
    strength is drawn from [0.6, 1.0).

    Args:
        size: Side length of the square map, in pixels. Must be >= 1.
        seed: Optional seed (keyword-only). When given, a private
            ``np.random.RandomState`` is used so the output is reproducible.
            When omitted, values are drawn from the global ``np.random`` state.

    Returns:
        A ``(size, size)`` float32 array with values in [0, 1].

    Raises:
        ValueError: If ``size`` is less than 1.
    """
    if size < 1:
        raise ValueError(f"size must be a positive integer, got {size}")

    # Both the np.random module and RandomState expose randint/uniform, so the
    # draws below are identical in form whichever source is used.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Start from an all-zero base map.
    cam_array = np.zeros((size, size), dtype=np.float32)

    # 1. Random centre (middle half of the map) and random patch size.
    #    Keep the upper bound strictly above the lower one, otherwise randint
    #    raises on an empty range for size == 1.
    center_lo = size // 4
    center_hi = max(center_lo + 1, size * 3 // 4)
    center_y = rng.randint(center_lo, center_hi)
    center_x = rng.randint(center_lo, center_hi)
    height = rng.randint(30, 80)
    width = rng.randint(30, 80)

    # Patch bounds, clamped so slicing stays inside the array.
    y_min = max(0, center_y - height // 2)
    y_max = min(size, center_y + height // 2)
    x_min = max(0, center_x - width // 2)
    x_max = min(size, center_x + width // 2)

    # 2. Fill the patch with a single random strength.
    cam_array[y_min:y_max, x_min:x_max] = rng.uniform(0.6, 1.0)

    # Light noise so the map is less blocky. uniform() yields float64, so cast
    # back afterwards to keep the float32 dtype the base array was built with.
    noise = rng.uniform(0, 0.1, (size, size))
    return np.clip(cam_array + noise, 0, 1).astype(np.float32)
# --- Helper function: locate the centroid of the heatmap's activation region ---
def calculate_heatmap_centroid(cam_array: np.ndarray, threshold: float = 0.5) -> tuple:
    """
    Calculate the activation-weighted centroid of the 'hot' region of a heatmap.

    Args:
        cam_array: Heatmap array. Must be 2-D, or 2-D with leading singleton
            axes such as (1, H, W); those axes are stripped. Array-likes
            (nested lists) are accepted too.
        threshold: Only pixels strictly greater than this value take part.

    Returns:
        ``(norm_x, norm_y, max_activation)`` as Python floats.
        - ``norm_x`` and ``norm_y`` are the weighted mean of the hot pixels'
          centres (index + 0.5), divided by the map width and height. They lie
          in (0, 1), and a region symmetric about the map centre gives exactly 0.5.
        - ``max_activation`` is the largest value in the hot region.
        If no pixel exceeds ``threshold``, or the hot weights do not sum to a
        positive value, the placeholder ``(0.5, 0.5, 0.0)`` is returned.

    Raises:
        ValueError: If the array is not 2-D after stripping leading singleton axes.
    """
    cam = np.asarray(cam_array)
    # Strip leading size-1 axes (e.g. a batch/channel axis) so (1, H, W) works.
    while cam.ndim > 2 and cam.shape[0] == 1:
        cam = cam[0]
    if cam.ndim != 2:
        raise ValueError(
            "cam_array must be 2-D (optionally with leading singleton axes), "
            f"got shape {np.shape(cam_array)}"
        )

    fallback = (0.5, 0.5, 0.0)

    # 1. Apply the threshold to isolate the 'hot' region.
    hot = cam > threshold
    if not hot.any():
        return fallback

    # 2. Row (y) / column (x) indices of hot pixels and their activations.
    #    Both nonzero() and boolean indexing walk the array in row-major
    #    order, so ys/xs line up element-for-element with weights.
    ys, xs = np.nonzero(hot)
    weights = cam[hot]

    total = float(weights.sum())
    if total <= 0.0:
        # Only reachable with a negative threshold; a weighted mean is
        # undefined (or meaningless) without positive total weight.
        return fallback

    # 3. Weighted mean of pixel centres (index + 0.5), so normalization is
    #    unbiased and agrees with the (0.5, 0.5) "centre" fallback above.
    centroid_x = float((xs * weights).sum()) / total + 0.5
    centroid_y = float((ys * weights).sum()) / total + 0.5

    # Normalize to [0, 1] by the map size.
    h, w = cam.shape
    return (centroid_x / w, centroid_y / h, float(weights.max()))
# --- Explainability Agent Core: turn a prediction + heatmap into text ---
class ExplainabilityAgent:
    """
    Turn a classifier's prediction and its Grad-CAM heatmap into a short
    human-readable explanation.

    Args:
        class_names: Class labels of the classifier. Stored on the instance;
            not read by any method of this class.
        body_part: Name of the imaged body part, used in the generated text.
        heatmap_threshold: Activation threshold passed to
            ``calculate_heatmap_centroid`` to define the model's focus region.
        healthy_confidence_threshold: A "Healthy" prediction with confidence
            strictly above this value gets the high-confidence message.
    """

    def __init__(self, class_names: list, body_part: str = "bone",
                 heatmap_threshold: float = 0.4,
                 healthy_confidence_threshold: float = 0.90):
        self.class_names = class_names
        self.body_part = body_part
        self.heatmap_threshold = heatmap_threshold
        self.healthy_confidence_threshold = healthy_confidence_threshold

    def generate_explanation(self, diagnosis_result: Dict[str, Any], cam_array: np.ndarray) -> str:
        """
        Convert the Grad-CAM heatmap and prediction result into a textual explanation.

        Args:
            diagnosis_result: Mapping that is read for ``"predicted_class"``
                (default "Unknown"), ``"confidence_score"`` (default 0.0,
                formatted with ``:.2f`` so it must be numeric), and
                ``"fracture_detected"`` (default True).
            cam_array: Heatmap passed to ``calculate_heatmap_centroid``.

        Returns:
            A Markdown-flavoured sentence or two describing the prediction and,
            when the heatmap has a region above ``self.heatmap_threshold``,
            where and how strongly the model focused.
        """
        predicted_class = diagnosis_result.get("predicted_class", "Unknown")
        confidence = diagnosis_result.get("confidence_score", 0.0)

        # 1. Analyze the heatmap.
        norm_x, norm_y, strength = calculate_heatmap_centroid(
            cam_array, threshold=self.heatmap_threshold
        )
        # When nothing exceeds the threshold, calculate_heatmap_centroid returns
        # a placeholder (0.5, 0.5, 0.0). That centroid is not a real location,
        # so it must not be described as one. For a non-negative threshold a
        # real hot region always has strength > 0.
        has_focus = strength > 0.0

        # Coarse location from the normalized centroid. Left/right are image
        # coordinates. NOTE(review): "proximal" = top rows and "distal" = bottom
        # rows assumes a fixed image orientation; confirm for the input data.
        x_loc = "right side" if norm_x > 0.65 else ("left side" if norm_x < 0.35 else "center")
        y_loc = "distal end" if norm_y > 0.65 else ("proximal end" if norm_y < 0.35 else "middle region")

        # 2. Healthy predictions.
        if predicted_class == "Healthy":
            if confidence > self.healthy_confidence_threshold:
                return f"The {self.body_part} appears **healthy** with high confidence ({confidence:.2f}). No fracture pattern was detected."
            if not has_focus:
                return f"The {self.body_part} is likely **healthy** ({confidence:.2f}); the heatmap shows no distinct activation region."
            return f"The {self.body_part} is likely **healthy** ({confidence:.2f}), though there is some low activation in the {y_loc} of the {x_loc} that warrants a closer look."

        # A missing key is treated as "fracture detected".
        if not diagnosis_result.get("fracture_detected", True):
            return "Diagnosis is **inconclusive** or data is missing."

        # 3. Explanation for a detected fracture.
        intro = f"A fracture pattern consistent with a **{predicted_class}** type is detected"
        confidence_stmt = f"(Confidence: {confidence:.2f})"

        if not has_focus:
            # Do not invent a location or a "linear focus" from the placeholder centroid.
            return (
                f"{intro} {confidence_stmt}. The model's heatmap shows no distinct focus region "
                f"(no activation above {self.heatmap_threshold:.2f}), so the location cannot be determined."
            )

        # Strength adjective, based on the peak activation inside the hot region.
        if strength > 0.7:
            strength_adj = "strong"
        elif strength > 0.5:
            strength_adj = "clear"
        else:
            strength_adj = "mild"

        location_stmt = f"near the **{y_loc}** of the {self.body_part} in the {x_loc}."
        explanation = f"{intro} {confidence_stmt}. The model's focus is {strength_adj} {location_stmt}"

        # Extra note for classes whose name suggests a linear fracture line.
        if predicted_class in ("Transverse", "Oblique"):
            explanation += " This is based on a distinct linear focus."
        return explanation
# --- 4. EXAMPLE USAGE ---
if __name__ == '__main__':
    # --- SIMULATED INPUT ---
    # Hand-written stand-in for a classifier's diagnosis result.
    # generate_explanation reads "predicted_class", "confidence_score" and
    # "fracture_detected"; the remaining keys are unused by this demo.
    SIMULATED_RESULT = {
        "image_path": "test_image.jpg",
        "fracture_detected": True,
        "predicted_class": "Spiral",
        "severity_type": "Spiral",
        "confidence_score": 0.95,
        "uncertainty_score": 0.05,
    }
    CLASS_NAMES = ["Comminuted", "Greenstick", "Healthy", "Oblique", "Oblique Displaced", "Spiral", "Transverse", "Transverse Displaced"]
    explainer = ExplainabilityAgent(class_names=CLASS_NAMES, body_part="humerus")

    # Run twice with a fresh random heatmap each time. The prediction stays
    # fixed, so any change in the text comes from the heatmap alone.
    run_labels = ["Random Heatmap", "Another Random Heatmap"]
    for run_number, label in enumerate(run_labels, start=1):
        print(f"\n--- Testing Dynamic Output (Run {run_number}: {label}) ---")
        dynamic_cam = generate_random_heatmap()
        explanation_text = explainer.generate_explanation(SIMULATED_RESULT, dynamic_cam)
        print(f"Explanation {run_number}: {explanation_text}")
    print("--------------------------------------------------\n")