Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import json | |
| from typing import Dict, Any | |
| # --- New Helper Function for Dynamic Testing --- | |
| def generate_random_heatmap(size: int = 224) -> np.ndarray: | |
| """ | |
| Generates a randomized, plausible heatmap array for testing the agent's dynamism. | |
| The heatmap will have a focused, high-intensity area somewhere random. | |
| """ | |
| # Create a base array of zeros | |
| cam_array = np.zeros((size, size), dtype=np.float32) | |
| # 1. Define random center and size for the activation zone | |
| center_y = np.random.randint(size // 4, size * 3 // 4) | |
| center_x = np.random.randint(size // 4, size * 3 // 4) | |
| height = np.random.randint(30, 80) | |
| width = np.random.randint(30, 80) | |
| # Define activation bounds (ensure they stay within the array limits) | |
| y_min = max(0, center_y - height // 2) | |
| y_max = min(size, center_y + height // 2) | |
| x_min = max(0, center_x - width // 2) | |
| x_max = min(size, center_x + width // 2) | |
| # 2. Apply activation with random strength | |
| random_strength = np.random.uniform(0.6, 1.0) | |
| cam_array[y_min:y_max, x_min:x_max] = random_strength | |
| # Optional: Add minor noise to make it less blocky | |
| cam_array = cam_array + np.random.uniform(0, 0.1, (size, size)) | |
| cam_array = np.clip(cam_array, 0, 1) | |
| return cam_array | |
| # --- Helper function for localization (No changes needed, it is dynamic) --- | |
| def calculate_heatmap_centroid(cam_array: np.ndarray, threshold: float = 0.5) -> tuple: | |
| """ | |
| Calculates the centroid (center of mass) of the significant activation area | |
| in the Grad-CAM heatmap. | |
| """ | |
| # 1. Apply threshold to isolate the 'hot' region | |
| binary_map = cam_array > threshold | |
| if not np.any(binary_map): | |
| return (0.5, 0.5, 0.0) | |
| # 2. Calculate coordinates and weights (activation values) | |
| coords = np.argwhere(binary_map) | |
| weights = cam_array[binary_map] | |
| if len(weights) == 0: | |
| return (0.5, 0.5, 0.0) | |
| # 3. Calculate weighted average for the centroid | |
| y_coords = coords[:, 0] # Rows (Y) | |
| x_coords = coords[:, 1] # Columns (X) | |
| sum_weights = np.sum(weights) | |
| centroid_x = np.sum(x_coords * weights) / sum_weights | |
| centroid_y = np.sum(y_coords * weights) / sum_weights | |
| # Normalize to [0, 1] based on map size | |
| h, w = cam_array.shape | |
| norm_x = centroid_x / w | |
| norm_y = centroid_y / h | |
| max_activation = np.max(weights) | |
| return (norm_x, norm_y, max_activation) | |
| # --- Explainability Agent Core (No changes needed, logic is dynamic) --- | |
| class ExplainabilityAgent: | |
| def __init__(self, class_names: list, body_part: str = "bone"): | |
| self.class_names = class_names | |
| self.body_part = body_part | |
| def generate_explanation(self, diagnosis_result: Dict[str, Any], cam_array: np.ndarray) -> str: | |
| """ | |
| Converts the Grad-CAM heatmap and prediction result into a textual explanation. | |
| """ | |
| predicted_class = diagnosis_result.get("predicted_class", "Unknown") | |
| confidence = diagnosis_result.get("confidence_score", 0.0) | |
| # 1. Analyze Heatmap | |
| norm_x, norm_y, strength = calculate_heatmap_centroid(cam_array, threshold=0.4) | |
| # Determine general location (Simplified) | |
| x_loc = "right side" if norm_x > 0.65 else ("left side" if norm_x < 0.35 else "center") | |
| y_loc = "distal end" if norm_y > 0.65 else ("proximal end" if norm_y < 0.35 else "middle region") | |
| # 2. Build Textual Explanation based on Prediction | |
| if predicted_class == "Healthy": | |
| if confidence > 0.90: | |
| return f"The {self.body_part} appears **healthy** with high confidence ({confidence:.2f}). No fracture pattern was detected." | |
| else: | |
| return f"The {self.body_part} is likely **healthy** ({confidence:.2f}), though there is some low activation in the {y_loc} of the {x_loc} that warrants a closer look." | |
| if not diagnosis_result.get("fracture_detected", True): # Default to True if key missing | |
| return f"Diagnosis is **inconclusive** or data is missing." | |
| # 3. Explanation for Detected Fracture | |
| intro = f"A fracture pattern consistent with a **{predicted_class}** type is detected" | |
| # Strength adjective | |
| if strength > 0.7: | |
| strength_adj = "strong" | |
| elif strength > 0.5: | |
| strength_adj = "clear" | |
| else: | |
| strength_adj = "mild" | |
| # Confidence statement | |
| confidence_stmt = f"(Confidence: {confidence:.2f})" | |
| # Location statement | |
| location_stmt = f"near the **{y_loc}** of the {self.body_part} in the {x_loc}." | |
| # Final Assembly | |
| explanation = f"{intro} {confidence_stmt}. The model's focus is {strength_adj} {location_stmt}" | |
| # Add a note on the type based on visual focus | |
| if predicted_class in ["Transverse", "Oblique"]: | |
| explanation += " This is based on a distinct linear focus." | |
| return explanation | |
| # --- 4. EXAMPLE USAGE --- | |
| if __name__ == '__main__': | |
| # --- SIMULATED INPUT --- | |
| SIMULATED_RESULT = { | |
| "image_path": "test_image.jpg", | |
| "fracture_detected": True, | |
| "predicted_class": "Spiral", | |
| "severity_type": "Spiral", | |
| "confidence_score": 0.95, | |
| "uncertainty_score": 0.05, | |
| } | |
| CLASS_NAMES = ["Comminuted", "Greenstick", "Healthy", "Oblique", "Oblique Displaced", "Spiral", "Transverse", "Transverse Displaced"] | |
| explainer = ExplainabilityAgent(class_names=CLASS_NAMES, body_part="humerus") | |
| # Run 3 times to demonstrate dynamic output | |
| print("\n--- Testing Dynamic Output (Run 1: Random Heatmap) ---") | |
| # Use the new dynamic heatmap function! | |
| dynamic_cam_1 = generate_random_heatmap() | |
| explanation_text_1 = explainer.generate_explanation(SIMULATED_RESULT, dynamic_cam_1) | |
| print(f"Explanation 1: {explanation_text_1}") | |
| print("\n--- Testing Dynamic Output (Run 2: Another Random Heatmap) ---") | |
| dynamic_cam_2 = generate_random_heatmap() | |
| explanation_text_2 = explainer.generate_explanation(SIMULATED_RESULT, dynamic_cam_2) | |
| print(f"Explanation 2: {explanation_text_2}") | |
| print("--------------------------------------------------\n") |