# File size: 6,340 Bytes
# bf07f10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import json
from typing import Any, Dict, Optional

import numpy as np

# --- Helper: synthetic heatmap generator for testing ---
def generate_random_heatmap(size: int = 224, seed: Optional[int] = None) -> np.ndarray:
    """
    Return a synthetic square heatmap with one randomly placed hot rectangle.

    Intended for exercising the explainability agent without a real model.
    A rectangle of random centre, extent (30-79 px per side, clipped to the
    map) and strength (0.6-1.0) is drawn on a zero background. Uniform noise
    in [0, 0.1) is then added everywhere, and the result is clipped to [0, 1].

    Args:
        size: Side length of the square heatmap in pixels. Must be >= 2.
        seed: Optional seed for a private ``np.random.RandomState`` so the
            output is reproducible. When ``None`` (default), the module-level
            ``np.random`` functions (global state) are used.

    Returns:
        A ``float32`` array of shape ``(size, size)`` with values in [0, 1].

    Raises:
        ValueError: If ``size`` is less than 2.
    """
    if size < 2:
        # randint(size // 4, size * 3 // 4) needs a non-empty range, which requires size >= 2.
        raise ValueError(f"size must be at least 2, got {size}")

    # A private RandomState makes seeded calls reproducible without touching global state.
    rng = np.random if seed is None else np.random.RandomState(seed)

    cam_array = np.zeros((size, size), dtype=np.float32)

    # 1. Random centre (kept within the middle half of the map) and extent of the hot zone.
    center_y = rng.randint(size // 4, size * 3 // 4)
    center_x = rng.randint(size // 4, size * 3 // 4)
    height = rng.randint(30, 80)
    width = rng.randint(30, 80)

    # Derive the far edge from the near edge so odd heights/widths keep their full extent.
    # Clamping the near edge at 0 matters: a negative slice start would wrap around
    # to the opposite side of the array instead of clipping.
    top = center_y - height // 2
    left = center_x - width // 2
    y_min, y_max = max(0, top), min(size, top + height)
    x_min, x_max = max(0, left), min(size, left + width)

    # 2. Fill the zone with a single random strength.
    cam_array[y_min:y_max, x_min:x_max] = rng.uniform(0.6, 1.0)

    # 3. Add low-level noise so the map is less blocky. uniform() yields float64,
    #    so cast it first; otherwise the whole map is silently upcast to float64.
    cam_array += rng.uniform(0, 0.1, (size, size)).astype(np.float32)
    return np.clip(cam_array, 0.0, 1.0)

# --- Helper: heatmap localization ---

def calculate_heatmap_centroid(cam_array: np.ndarray, threshold: float = 0.5) -> tuple:
    """
    Return the activation-weighted centroid of the hot region of a 2-D heatmap.

    Pixels strictly greater than ``threshold`` form the hot region. Their
    (row, column) indices are averaged, weighted by activation. The result
    is normalized to [0, 1] using pixel centres, so a hot region covering the
    whole map yields exactly (0.5, 0.5).

    Args:
        cam_array: 2-D heatmap indexed as (row, column). Any array-like is accepted.
        threshold: Activation cutoff; only values ``> threshold`` contribute.

    Returns:
        ``(norm_x, norm_y, max_activation)`` as Python floats, where x follows
        columns and y follows rows. If no pixel exceeds ``threshold``, the
        sentinel ``(0.5, 0.5, 0.0)`` is returned.

    Raises:
        ValueError: If ``cam_array`` is not two-dimensional.
    """
    cam_array = np.asarray(cam_array)
    if cam_array.ndim != 2:
        raise ValueError(f"cam_array must be 2-D, got shape {cam_array.shape}")

    # 1. Isolate the 'hot' region.
    hot = cam_array > threshold
    if not hot.any():
        return (0.5, 0.5, 0.0)

    # 2. Indices of hot pixels and their activations. nonzero() and boolean
    #    indexing both walk the array in C order, so entries line up.
    rows, cols = np.nonzero(hot)
    weights = cam_array[hot].astype(np.float64)

    # 3. Weighted centroid. A zero/negative total is only possible with a
    #    non-positive threshold; fall back to the unweighted centroid rather
    #    than dividing by zero.
    total = weights.sum()
    if total > 0:
        centroid_y = float(np.dot(rows, weights) / total)
        centroid_x = float(np.dot(cols, weights) / total)
    else:
        centroid_y = float(rows.mean())
        centroid_x = float(cols.mean())

    # Pixel i covers [i, i + 1), so its centre sits at i + 0.5 in map units.
    h, w = cam_array.shape
    norm_x = (centroid_x + 0.5) / w
    norm_y = (centroid_y + 0.5) / h

    return (norm_x, norm_y, float(weights.max()))

# --- Explainability agent core ---

class ExplainabilityAgent:
    """
    Turn a classifier prediction plus a Grad-CAM heatmap into a short text explanation.

    Attributes:
        class_names: Label names supplied by the caller. Stored only; not
            currently consulted by ``generate_explanation``.
        body_part: Anatomical name interpolated into the generated sentences.
    """

    def __init__(self, class_names: list, body_part: str = "bone"):
        self.class_names = class_names
        self.body_part = body_part

    def generate_explanation(self, diagnosis_result: Dict[str, Any], cam_array: np.ndarray) -> str:
        """
        Convert the Grad-CAM heatmap and prediction result into a textual explanation.

        Args:
            diagnosis_result: Mapping read for ``"predicted_class"`` (default
                ``"Unknown"``), ``"confidence_score"`` (default ``0.0``; must be
                numeric because it is compared and ``:.2f``-formatted) and
                ``"fracture_detected"`` (a missing key is treated as True).
            cam_array: 2-D heatmap passed to ``calculate_heatmap_centroid``
                with a threshold of 0.4.

        Returns:
            A Markdown-flavoured sentence or two describing the prediction,
            its confidence and, when the heatmap has a hot region, where the
            model focused.
        """
        predicted_class = diagnosis_result.get("predicted_class", "Unknown")
        confidence = diagnosis_result.get("confidence_score", 0.0)

        # 1. Analyze the heatmap. calculate_heatmap_centroid reports strength 0.0
        #    (with a placeholder centre) when no pixel exceeds the threshold; any
        #    real hot region has strength above the 0.4 threshold.
        norm_x, norm_y, strength = calculate_heatmap_centroid(cam_array, threshold=0.4)
        has_focus = strength > 0.0

        # Coarse location buckets. NOTE(review): mapping image rows to
        # proximal (top) / distal (bottom) assumes a fixed image orientation - confirm.
        x_loc = "right side" if norm_x > 0.65 else ("left side" if norm_x < 0.35 else "center")
        y_loc = "distal end" if norm_y > 0.65 else ("proximal end" if norm_y < 0.35 else "middle region")

        # 2. Healthy prediction.
        if predicted_class == "Healthy":
            if confidence > 0.90:
                return f"The {self.body_part} appears **healthy** with high confidence ({confidence:.2f}). No fracture pattern was detected."
            if not has_focus:
                # Without a hot region the centroid is a placeholder; don't report it as a location.
                return f"The {self.body_part} is likely **healthy** ({confidence:.2f}), and the heatmap shows no localized activation."
            return f"The {self.body_part} is likely **healthy** ({confidence:.2f}), though there is some low activation in the {y_loc} of the {x_loc} that warrants a closer look."

        # A missing key defaults to True, so this only fires on an explicit falsy value.
        if not diagnosis_result.get("fracture_detected", True):
            return "Diagnosis is **inconclusive** or data is missing."

        # 3. Explanation for a detected fracture.
        intro = f"A fracture pattern consistent with a **{predicted_class}** type is detected"
        confidence_stmt = f"(Confidence: {confidence:.2f})"

        if not has_focus:
            # No hot region: any location or "linear focus" claim would be fabricated.
            return f"{intro} {confidence_stmt}. The heatmap shows no localized region of focus, so no location can be given."

        # Strength adjective from the peak activation of the hot region.
        if strength > 0.7:
            strength_adj = "strong"
        elif strength > 0.5:
            strength_adj = "clear"
        else:
            strength_adj = "mild"

        location_stmt = f"near the **{y_loc}** of the {self.body_part} in the {x_loc}."
        explanation = f"{intro} {confidence_stmt}. The model's focus is {strength_adj} {location_stmt}"

        # Add a note on the type based on visual focus.
        if predicted_class in ["Transverse", "Oblique"]:
            explanation += " This is based on a distinct linear focus."

        return explanation


# --- 4. EXAMPLE USAGE ---

if __name__ == '__main__':

    # --- SIMULATED INPUT ---
    SIMULATED_RESULT = {
        "image_path": "test_image.jpg",
        "fracture_detected": True,
        "predicted_class": "Spiral",
        "severity_type": "Spiral",
        "confidence_score": 0.95,
        "uncertainty_score": 0.05,
    }

    CLASS_NAMES = ["Comminuted", "Greenstick", "Healthy", "Oblique", "Oblique Displaced", "Spiral", "Transverse", "Transverse Displaced"]
    explainer = ExplainabilityAgent(class_names=CLASS_NAMES, body_part="humerus")

    # Explain one freshly randomized heatmap per label (two runs) to show that
    # the text follows the heatmap rather than being fixed.
    run_labels = ("Random Heatmap", "Another Random Heatmap")
    for run, label in enumerate(run_labels, start=1):
        print(f"\n--- Testing Dynamic Output (Run {run}: {label}) ---")
        dynamic_cam = generate_random_heatmap()
        explanation_text = explainer.generate_explanation(SIMULATED_RESULT, dynamic_cam)
        print(f"Explanation {run}: {explanation_text}")

    print("--------------------------------------------------\n")