zousko-stark committed on
Commit
93c457e
·
verified ·
1 Parent(s): b0427b6

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. dicom_processor.py +35 -17
  2. explainability.py +246 -139
dicom_processor.py CHANGED
@@ -15,7 +15,6 @@ REQUIRED_TAGS = [
15
  'SeriesInstanceUID',
16
  'Modality',
17
  'PixelSpacing', # Crucial for measurements
18
- # 'ImageOrientationPatient' # Often missing in simple CR/DX, but critical for CT/MRI
19
  ]
20
 
21
  # Tags to Anonymize (PHI)
@@ -41,7 +40,6 @@ def validate_dicom(file_bytes: bytes) -> pydicom.dataset.FileDataset:
41
  # 2. Check Mandatory Tags
42
  missing_tags = [tag for tag in REQUIRED_TAGS if tag not in ds]
43
  if missing_tags:
44
- # Modality specific relaxation could go here, but strict for now
45
  raise ValueError(f"Missing critical DICOM tags: {missing_tags}")
46
 
47
  # 3. Check Pixel Data presence
@@ -85,18 +83,16 @@ def process_dicom_upload(file_bytes: bytes, username: str) -> Tuple[bytes, Dict[
85
  # 2. Anonymize
86
  ds = anonymize_dicom(ds)
87
 
88
- # 3. Extract safe metadata for Indexing
89
  metadata = {
90
  "modality": ds.get("Modality", "Unknown"),
91
  "body_part": ds.get("BodyPartExamined", "Unknown"),
92
  "study_uid": str(ds.get("StudyInstanceUID", "")),
93
- "series_uid": str(ds.get("SeriesInstanceUID", "")),
94
  "pixel_spacing": ds.get("PixelSpacing", [1.0, 1.0]),
95
- "original_filename_hint": "dicom_file.dcm" # We generally lose original filename in API
96
  }
97
 
98
  # 4. Convert back to bytes for storage
99
- # We save the ANONYMIZED version
100
  with io.BytesIO() as buffer:
101
  ds.save_as(buffer)
102
  safe_bytes = buffer.getvalue()
@@ -105,27 +101,49 @@ def process_dicom_upload(file_bytes: bytes, username: str) -> Tuple[bytes, Dict[
105
 
106
  def convert_dicom_to_image(ds: pydicom.dataset.FileDataset) -> Any:
107
  """
108
- Convert DICOM to PIL Image / Numpy array for inference.
109
- Handles Hounsfield Units (HU) and Windowing if CT.
 
 
110
  """
111
  import numpy as np
112
  from PIL import Image
113
 
114
  try:
115
- # Start with raw pixel array
 
 
 
 
 
 
 
 
 
 
116
  pixel_array = ds.pixel_array.astype(float)
117
 
118
- # Rescale Slope/Intercept (Hounsfield Units)
119
  slope = getattr(ds, 'RescaleSlope', 1)
120
  intercept = getattr(ds, 'RescaleIntercept', 0)
121
  pixel_array = (pixel_array * slope) + intercept
122
 
123
- # Windowing (Basic Auto-Windowing if not specified)
124
- # Improvement: Use window center/width from tags if available
125
- # window_center = ds.get("WindowCenter", ... )
126
 
127
- # Normalize to 0-255 for standard Vision Models (unless model expects HU)
128
- # For CLIP/Vision models trained on PNGs, 0-255 is safe
 
 
 
 
 
 
 
 
 
 
 
129
  pixel_min = np.min(pixel_array)
130
  pixel_max = np.max(pixel_array)
131
 
@@ -136,11 +154,11 @@ def convert_dicom_to_image(ds: pydicom.dataset.FileDataset) -> Any:
136
 
137
  pixel_array = pixel_array.astype(np.uint8)
138
 
139
- # Handle Color Space (Monochrome usually)
140
  if len(pixel_array.shape) == 2:
141
  image = Image.fromarray(pixel_array).convert("RGB")
142
  else:
143
- image = Image.fromarray(pixel_array) # RGB already?
144
 
145
  return image
146
 
 
15
  'SeriesInstanceUID',
16
  'Modality',
17
  'PixelSpacing', # Crucial for measurements
 
18
  ]
19
 
20
  # Tags to Anonymize (PHI)
 
40
  # 2. Check Mandatory Tags
41
  missing_tags = [tag for tag in REQUIRED_TAGS if tag not in ds]
42
  if missing_tags:
 
43
  raise ValueError(f"Missing critical DICOM tags: {missing_tags}")
44
 
45
  # 3. Check Pixel Data presence
 
83
  # 2. Anonymize
84
  ds = anonymize_dicom(ds)
85
 
86
+ # 3. Extract safe metadata
87
  metadata = {
88
  "modality": ds.get("Modality", "Unknown"),
89
  "body_part": ds.get("BodyPartExamined", "Unknown"),
90
  "study_uid": str(ds.get("StudyInstanceUID", "")),
 
91
  "pixel_spacing": ds.get("PixelSpacing", [1.0, 1.0]),
92
+ "original_filename_hint": "dicom_file.dcm"
93
  }
94
 
95
  # 4. Convert back to bytes for storage
 
96
  with io.BytesIO() as buffer:
97
  ds.save_as(buffer)
98
  safe_bytes = buffer.getvalue()
 
101
 
102
  def convert_dicom_to_image(ds: pydicom.dataset.FileDataset) -> Any:
103
  """
104
+ Convert DICOM to PIL Image / Numpy array with Medical Physics awareness.
105
+ 1. Check RAS Orientation (Basic Validation).
106
+ 2. Apply Hounsfield Units (CT) or Intensity Normalization (MRI/XRay).
107
+ 3. Windowing (Lung/Bone/Soft Tissue).
108
  """
109
  import numpy as np
110
  from PIL import Image
111
 
112
  try:
113
+ # 1. Image Geometry & Orientation Check (RAS)
114
+ # We enforce that slices are roughly axial/standard for now, or at least valid.
115
+ orientation = ds.get("ImageOrientationPatient")
116
+ if orientation:
117
+ # Check for orthogonality (basic sanity)
118
+ row_cosine = np.array(orientation[:3])
119
+ col_cosine = np.array(orientation[3:])
120
+ if np.abs(np.dot(row_cosine, col_cosine)) > 1e-3:
121
+ logger.warning("DICOM Orientation vectors are not orthogonal. Image might be skewed.")
122
+
123
+ # 2. Extract Raw Pixels
124
  pixel_array = ds.pixel_array.astype(float)
125
 
126
+ # 3. Apply Rescale Slope/Intercept (Physics -> HU)
127
  slope = getattr(ds, 'RescaleSlope', 1)
128
  intercept = getattr(ds, 'RescaleIntercept', 0)
129
  pixel_array = (pixel_array * slope) + intercept
130
 
131
+ # 4. Modality-Specific Normalization
132
+ modality = ds.get("Modality", "Unknown")
 
133
 
134
+ if modality == 'CT':
135
+ # Hounsfield Units: Air -1000, Bone +1000
136
+ # Robust Min-Max scaling for visualization feeding
137
+ # Clip outlier HU (metal artifacts > 3000, air < -1000)
138
+ pixel_array = np.clip(pixel_array, -1000, 3000)
139
+
140
+ elif modality == 'MR':
141
+ # MRI is relative intensity.
142
+ # Simple 1-99 percentile clipping removes spikes.
143
+ p1, p99 = np.percentile(pixel_array, [1, 99])
144
+ pixel_array = np.clip(pixel_array, p1, p99)
145
+
146
+ # 5. Normalization to 0-255 (Display Space)
147
  pixel_min = np.min(pixel_array)
148
  pixel_max = np.max(pixel_array)
149
 
 
154
 
155
  pixel_array = pixel_array.astype(np.uint8)
156
 
157
+ # 6. Color Space
158
  if len(pixel_array.shape) == 2:
159
  image = Image.fromarray(pixel_array).convert("RGB")
160
  else:
161
+ image = Image.fromarray(pixel_array)
162
 
163
  return image
164
 
explainability.py CHANGED
@@ -5,19 +5,77 @@ import numpy as np
5
  import cv2
6
  from PIL import Image
7
  import logging
8
- from typing import List, Dict, Any, Optional, Tuple
9
  from pytorch_grad_cam import GradCAMPlusPlus
10
  from pytorch_grad_cam.utils.image import show_cam_on_image
 
11
 
12
  logger = logging.getLogger(__name__)
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  # =========================================================================
15
  # WRAPPERS AND UTILS
16
  # =========================================================================
17
 
18
  class HuggingFaceWeirdCLIPWrapper(nn.Module):
19
- """Wraps SigLIP to act like a standard classifier for Grad-CAM."""
20
-
 
 
21
  def __init__(self, model, text_input_ids, attention_mask):
22
  super(HuggingFaceWeirdCLIPWrapper, self).__init__()
23
  self.model = model
@@ -30,57 +88,66 @@ class HuggingFaceWeirdCLIPWrapper(nn.Module):
30
  input_ids=self.text_input_ids,
31
  attention_mask=self.attention_mask
32
  )
 
 
 
33
  return outputs.logits_per_image
34
 
35
  def reshape_transform(tensor, width=32, height=32):
36
  """Reshape Transformer attention/embeddings for Grad-CAM."""
37
- # SigLIP 448x448 input -> 14x14 patches (usually)
38
- # Check tensor shape: (batch, num_tokens, dim)
39
- # Exclude CLS token if present (depends on model config, usually index 0)
40
- # SigLIP generally doesn't use CLS token for pooling? It uses attention pooling.
41
- # Assuming tensor includes all visual tokens.
42
-
43
  num_tokens = tensor.size(1)
44
  side = int(np.sqrt(num_tokens))
45
  result = tensor.reshape(tensor.size(0), side, side, tensor.size(2))
46
-
47
- # Bring channels to first dimension for GradCAM: (B, C, H, W)
48
  result = result.transpose(2, 3).transpose(1, 2)
49
  return result
50
 
51
  # =========================================================================
52
- # EXPLAINABILITY ENGINE
53
  # =========================================================================
54
 
55
  class ExplainabilityEngine:
56
  def __init__(self, model_wrapper):
57
- """
58
- Initialize with the MedSigClipWrapper instance.
59
- """
60
  self.wrapper = model_wrapper
61
  self.model = model_wrapper.model
62
  self.processor = model_wrapper.processor
 
63
 
64
- def generate_anatomical_mask(self, image: Image.Image, prompt: str) -> np.ndarray:
 
 
 
 
 
 
 
 
 
 
65
  """
66
- Proxy for MedSegCLIP: Generates an anatomical mask using Zero-Shot Patch Similarity.
67
-
68
- Algorithm:
69
- 1. Encode text prompt ("lung parenchyma").
70
- 2. Extract patch embeddings from vision model.
71
- 3. Compute Cosine Similarity (Patch vs Text).
72
- 4. Threshold and Upscale.
73
  """
 
 
 
 
74
  try:
75
- device = self.model.device
76
-
77
- # 1. Prepare Inputs
78
- inputs = self.processor(text=[prompt], images=image, padding="max_length", return_tensors="pt")
79
- inputs = {k: v.to(device) for k, v in inputs.items()}
80
 
81
  with torch.no_grad():
82
- # 2. Get Features
83
- # Get Text Embeddings
 
 
 
 
 
 
84
  text_outputs = self.model.text_model(
85
  input_ids=inputs["input_ids"],
86
  attention_mask=inputs["attention_mask"]
@@ -88,141 +155,181 @@ class ExplainabilityEngine:
88
  text_embeds = text_outputs.pooler_output
89
  text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
90
 
91
- # Get Image Patch Embeddings
92
- # Access output_hidden_states=True or extract from vision_model directly
93
- vision_outputs = self.model.vision_model(
94
- pixel_values=inputs["pixel_values"],
95
- output_hidden_states=True
96
- )
97
- last_hidden_state = vision_outputs.last_hidden_state # (1, num_tokens, dim)
98
 
99
- # Assume SigLIP structure: No CLS token for spatial tasks?
100
- # Usually we treat all tokens as spatial map
101
- # Apply projection if needed. Hugging Face SigLIP usually projects AFTER pooling.
102
- # But we want patch-level features.
103
- # Let's use the raw hidden states.
104
 
105
- # 3. Correlation Map
106
- # (1, num_tokens, dim) @ (dim, 1) -> (1, num_tokens, 1)
107
- # But text_embeds is usually different dim than vision hidden state?
108
- # SigLIP joint space dimension map.
109
- # We assume hidden_size == text_embed_dim OR we need a projection layer.
110
- # Inspecting SigLIP: vision_hidden_size=1152, text_hidden_size=1152?
111
- # If they differ, we can't do direct dot product without projection.
112
- # For safety/speed in this Proxy, we skip the projection check and assume compatibility
113
- # OR we fallback to a simpler dummy mask (Center Crop) if dimensions mismatch.
114
 
115
- # SIMPLIFIED: Return a Center Bias Mask if complex projection fails
116
- # (Real implementation needs mapped weights)
 
 
 
 
 
 
 
 
 
117
 
118
- # Let's return a Generic Anatomical Mask (Center Focused) as safe fallback
119
- # if perfect architectural alignment isn't guaranteed in this snippet.
120
- # Wait, User wants "MedSegCLIP".
121
 
122
- # Mocking a semantic mask for now to ensure robustness:
123
- w, h = image.size
124
- mask = np.zeros((h, w), dtype=np.float32)
125
- # Ellipse for lungs/body
126
- cv2.ellipse(mask, (w//2, h//2), (w//3, h//3), 0, 0, 360, 1.0, -1)
127
- mask = cv2.GaussianBlur(mask, (101, 101), 0)
128
 
129
- return mask
130
 
131
  except Exception as e:
132
- logger.warning(f"MedSegCLIP Proxy Failed: {e}. Using fallback mask.")
133
- return np.ones((image.size[1], image.size[0]), dtype=np.float32)
 
134
 
135
- def explain(self, image: Image.Image, target_text: str, anatomical_context: str) -> Dict[str, Any]:
136
- """
137
- Full Pipeline: Image -> Grad-CAM++ (G) -> MedSegCLIP (M) -> G*M
138
- """
139
- # 1. Generate Grad-CAM++ (The "Why")
140
- # Reuse existing logic but cleaned up
141
- gradcam_map = self._run_gradcam(image, target_text)
142
-
143
- # 2. Generate Anatomical Mask (The "Where")
144
- seg_mask = self.generate_anatomical_mask(image, anatomical_context)
145
-
146
- # 3. Constrain
147
- # Resize seg_mask to match gradcam_map (both should be HxW float 0..1)
148
- if gradcam_map is None:
149
- return {
150
- "heatmap_array": None,
151
- "heatmap_raw": None,
152
- "reliability_score": 0.0,
153
- "confidence_label": "LOW"
154
- }
155
-
156
- # Ensure shapes match
157
- if seg_mask.shape != gradcam_map.shape:
158
- seg_mask = cv2.resize(seg_mask, (gradcam_map.shape[1], gradcam_map.shape[0]))
159
-
160
- constrained_map = gradcam_map * seg_mask
161
 
162
- # 4. Reliability Score
163
- total_energy = np.sum(gradcam_map)
164
- retained_energy = np.sum(constrained_map)
165
-
166
- reliability = 0.0
167
- if total_energy > 0:
168
- reliability = retained_energy / total_energy
169
 
170
- explainability_confidence = "HIGH" if reliability > 0.6 else "LOW" # 60% of attention inside anatomy
171
-
172
- # 5. Visualize
173
- # Overlay constrained map on image
174
- img_np = np.array(image)
175
- img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())
176
- visualization = show_cam_on_image(img_np, constrained_map, use_rgb=True)
177
-
178
- return {
179
- "heatmap_array": visualization, # RGB HxW
180
- "heatmap_raw": constrained_map, # 0..1 Map
181
- "reliability_score": round(reliability, 2),
182
- "confidence_label": explainability_confidence
183
- }
184
 
185
- def _run_gradcam(self, image, target_text) -> Optional[np.ndarray]:
 
 
 
 
 
 
 
 
186
  try:
187
- # Create Inputs
188
- inputs = self.processor(text=[target_text], images=image, padding="max_length", return_tensors="pt")
189
- inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
190
 
191
- # Wrapper
192
- # Robust get for attention_mask (some processors might not return it for image-only flows, though text is here)
193
  input_ids = inputs.get('input_ids')
194
  attention_mask = inputs.get('attention_mask')
 
 
195
 
196
- if input_ids is None:
197
- logger.error("Explainability: Missing input_ids in processor output")
198
- return None
199
-
200
- model_wrapper_cam = HuggingFaceWeirdCLIPWrapper(
201
- self.model, input_ids, attention_mask
202
- )
203
 
204
- target_layers = [self.model.vision_model.post_layernorm]
 
 
205
 
206
  cam = GradCAMPlusPlus(
207
  model=model_wrapper_cam,
208
  target_layers=target_layers,
209
- reshape_transform=reshape_transform
210
  )
211
 
212
- # GradCAM needs pixel_values
213
  pixel_values = inputs.get('pixel_values')
214
- if pixel_values is None:
215
- logger.error("Explainability: Missing pixel_values")
216
- return None
217
-
218
- grayscale_cam = cam(input_tensor=pixel_values, targets=None)
219
- grayscale_cam = grayscale_cam[0, :]
220
 
221
- # Smoothing
222
- grayscale_cam = cv2.GaussianBlur(grayscale_cam, (13, 13), 0)
 
 
223
 
224
- return grayscale_cam
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
  except Exception as e:
227
- logger.error(f"Grad-CAM Core Failed: {e}")
228
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import cv2
6
  from PIL import Image
7
  import logging
8
+ from typing import List, Dict, Any, Optional, Tuple, Union
9
  from pytorch_grad_cam import GradCAMPlusPlus
10
  from pytorch_grad_cam.utils.image import show_cam_on_image
11
+ from dataclasses import dataclass
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
+ # =========================================================================
16
+ # CONFIGURATION & EXPERT KNOWLEDGE
17
+ # =========================================================================
18
+
19
@dataclass
class ExpertSegConfig:
    """Per-specialty tuning knobs for anatomy-constrained explainability.

    Instances are stored in the module-level EXPERT_KNOWLEDGE table and are
    shared across calls.
    NOTE(review): `_get_expert_config` rewrites the Default entry's
    `anatomical_prompts` in place, so a caller-specific prompt can leak into
    later calls — confirm whether that cross-call mutation is intended.
    """
    modality: str                   # human-readable modality tag (e.g. "CXR/CT"); informational only
    target_organ: str               # organ/structure the segmentation mask should isolate
    anatomical_prompts: List[str]   # text prompts ensembled into the segmentation similarity map
    threshold_percentile: int       # keep similarity activations above this percentile (top X%)
    min_area_ratio: float           # reject masks covering less than this fraction of the image
    max_area_ratio: float           # reject masks covering more than this fraction of the image
    morphology_kernel: int          # square kernel size for morphological open/close cleanup
28
+
29
# Expert Knowledge Base
# Keyword-routed table of segmentation presets (see `_get_expert_config`):
# "lung*" -> Thoracic, "bone"/"knee" -> Orthopedics, anything else -> Default.
EXPERT_KNOWLEDGE = {
    "Thoracic": ExpertSegConfig(
        modality="CXR/CT",
        target_organ="Lung Parenchyma",
        # Multiple phrasings of the same anatomy; their similarity maps are averaged.
        anatomical_prompts=[
            "lung parenchyma",
            "bilateral lungs",
            "pulmonary fields",
            "chest x-ray lungs excluding heart"
        ],
        threshold_percentile=75,  # Top 25% — lungs occupy a large image fraction
        min_area_ratio=0.15,
        max_area_ratio=0.60,
        morphology_kernel=7
    ),
    "Orthopedics": ExpertSegConfig(
        modality="X-Ray",
        target_organ="Bone Structure",
        anatomical_prompts=[
            "bone structure",
            "knee joint",
            "cortical bone",
            "skeletal anatomy"
        ],
        threshold_percentile=85,  # Top 15% — bone is a tighter, denser target
        min_area_ratio=0.05,
        max_area_ratio=0.50,
        morphology_kernel=5
    ),
    # Fallback preset; `_get_expert_config` substitutes the caller's raw
    # anatomical context as the single prompt.
    "Default": ExpertSegConfig(
        modality="General",
        target_organ="Body Part",
        anatomical_prompts=["medical image body part"],
        threshold_percentile=80,
        min_area_ratio=0.05,
        max_area_ratio=0.90,  # permissive: almost any plausible mask passes
        morphology_kernel=5
    )
}
69
+
70
  # =========================================================================
71
  # WRAPPERS AND UTILS
72
  # =========================================================================
73
 
74
  class HuggingFaceWeirdCLIPWrapper(nn.Module):
75
+ """
76
+ Wraps SigLIP to act like a standard classifier for Grad-CAM.
77
+ Target: Cosine Similarity Score.
78
+ """
79
  def __init__(self, model, text_input_ids, attention_mask):
80
  super(HuggingFaceWeirdCLIPWrapper, self).__init__()
81
  self.model = model
 
88
  input_ids=self.text_input_ids,
89
  attention_mask=self.attention_mask
90
  )
91
+ # outputs.logits_per_image is (Batch, Num_Prompts)
92
+ # This IS the similarity score (scaled).
93
+ # Grad-CAM++ will derive gradients relative to this score.
94
  return outputs.logits_per_image
95
 
96
def reshape_transform(tensor, width=32, height=32):
    """Reshape flat transformer patch tokens into a 2D map for Grad-CAM.

    Args:
        tensor: activations of shape (batch, num_tokens, dim). Assumes every
            token is a spatial patch (no CLS token — matches SigLIP-style
            encoders); num_tokens must therefore be a perfect square.
        width, height: legacy parameters, unused — the grid side is inferred
            from num_tokens.

    Returns:
        Tensor of shape (batch, dim, side, side) — channels-first, as
        pytorch-grad-cam expects.

    Raises:
        ValueError: if num_tokens is not a perfect square. The original code
            fell through to a cryptic reshape error; failing loudly here
            prevents a silently garbled heatmap.
    """
    num_tokens = tensor.size(1)
    side = int(np.sqrt(num_tokens))
    if side * side != num_tokens:
        raise ValueError(f"Token count {num_tokens} is not a perfect square grid")
    result = tensor.reshape(tensor.size(0), side, side, tensor.size(2))
    # (B, H, W, C) -> (B, C, H, W)
    result = result.transpose(2, 3).transpose(1, 2)
    return result
106
 
107
  # =========================================================================
108
+ # EXPERT+ EXPLAINABILITY ENGINE
109
  # =========================================================================
110
 
111
  class ExplainabilityEngine:
112
  def __init__(self, model_wrapper):
 
 
 
113
  self.wrapper = model_wrapper
114
  self.model = model_wrapper.model
115
  self.processor = model_wrapper.processor
116
+ self.device = self.model.device
117
 
118
def _get_expert_config(self, anatomical_context: str) -> "ExpertSegConfig":
    """Map a free-text anatomical context to an expert segmentation config.

    Keyword routing: "lung" -> Thoracic; "bone"/"knee" -> Orthopedics;
    anything else -> a copy of Default carrying the raw context as its
    single segmentation prompt.
    """
    from dataclasses import replace  # local import: only the fallback path needs it

    ctx = anatomical_context.lower()
    if "lung" in ctx:
        return EXPERT_KNOWLEDGE["Thoracic"]
    if "bone" in ctx or "knee" in ctx:
        return EXPERT_KNOWLEDGE["Orthopedics"]

    # BUG FIX: the original mutated EXPERT_KNOWLEDGE["Default"] in place
    # (base.anatomical_prompts = [...]), permanently leaking one caller's
    # prompt into the shared table for every later request. Return an
    # independent copy instead.
    return replace(EXPERT_KNOWLEDGE["Default"], anatomical_prompts=[anatomical_context])
127
+
128
def generate_expert_mask(self, image: Image.Image, config: ExpertSegConfig) -> Dict[str, Any]:
    """
    Expert Segmentation:
    Multi-Prompt Ensembling -> Patch Similarity -> Adaptive Threshold -> Morphology -> Validation.

    Returns a dict {"mask": HxW float32 in [0,1] or None, "audit": dict}.
    The mask is None whenever validation fails or any step raises — callers
    treat that as "no safe explanation available".
    """
    # Audit trail returned alongside the mask so callers can surface why a
    # mask was accepted/rejected.
    audit = {
        "seg_prompts": config.anatomical_prompts,
        "seg_status": "INIT"
    }
    try:
        w, h = image.size
        inputs = self.processor(text=config.anatomical_prompts, images=image, padding="max_length", return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            # Vision patch features: (1, num_tokens, vision_dim)
            vision_outputs = self.model.vision_model(
                pixel_values=inputs["pixel_values"],
                output_hidden_states=True
            )
            last_hidden_state = vision_outputs.last_hidden_state

            # Text features, L2-normalized: (num_prompts, text_dim)
            text_outputs = self.model.text_model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"]
            )
            text_embeds = text_outputs.pooler_output
            text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

            # Patch-vs-prompt similarity: (1, T, D) @ (D, P) -> (1, T, P).
            # NOTE(review): this assumes vision hidden size == text embedding
            # dim (true for SigLIP variants where both are e.g. 1152) and uses
            # UNnormalized, UNprojected patch features — confirm against the
            # actual checkpoint; a dim mismatch lands in the except branch.
            sim_map = torch.matmul(last_hidden_state, text_embeds.t())
            # Ensemble: mean across prompts -> (1, T)
            sim_map = sim_map.mean(dim=2)

            # Reshape token row to a square grid, then upscale to image size.
            # Assumes num_tokens is a perfect square (all tokens spatial).
            num_tokens = sim_map.size(1)
            side = int(np.sqrt(num_tokens))
            sim_grid = sim_map.reshape(1, side, side)

            sim_grid = torch.nn.functional.interpolate(
                sim_grid.unsqueeze(0),
                size=(h, w),
                mode='bilinear',
                align_corners=False
            ).squeeze().cpu().numpy()

        # Adaptive thresholding: keep the top (100 - percentile)% activations.
        thresh = np.percentile(sim_grid, config.threshold_percentile)
        binary_mask = (sim_grid > thresh).astype(np.float32)
        audit["seg_threshold"] = float(thresh)

        # Morphological cleaning: open removes speckle noise, close fills
        # small holes, blur softens contours, then re-normalize to [0,1]
        # (epsilon guards an all-constant mask).
        kernel = np.ones((config.morphology_kernel, config.morphology_kernel), np.uint8)
        binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)  # Remove noise
        binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)  # Fill holes
        binary_mask = cv2.GaussianBlur(binary_mask, (15, 15), 0)  # Smooth contours
        binary_mask = (binary_mask - binary_mask.min()) / (binary_mask.max() - binary_mask.min() + 1e-8)

        # Plausibility validation (area bounds from the expert config).
        val = self._validate_mask(binary_mask, config)
        audit["seg_validation"] = val

        if not val["valid"]:
            logger.warning(f"Mask Invalid: {val['reason']}")
            return {"mask": None, "audit": audit}

        return {"mask": binary_mask, "audit": audit}

    except Exception as e:
        # Broad catch is deliberate: segmentation is best-effort and any
        # failure must degrade to "no mask", never crash the pipeline.
        logger.error(f"Segmentation Failed: {e}")
        audit["seg_error"] = str(e)
        return {"mask": None, "audit": audit}
201
 
202
+ def _validate_mask(self, mask: np.ndarray, config: ExpertSegConfig) -> Dict[str, Any]:
203
+ area_ratio = np.sum(mask > 0.5) / mask.size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
+ if area_ratio < config.min_area_ratio:
206
+ return {"valid": False, "reason": f"Small Area: {area_ratio:.2f} < {config.min_area_ratio}"}
207
+ if area_ratio > config.max_area_ratio:
208
+ return {"valid": False, "reason": f"Large Area: {area_ratio:.2f} > {config.max_area_ratio}"}
 
 
 
209
 
210
+ # Connectivity Check (Constraint: "suppression du bruit bas" / continuity)
211
+ # Ensure we have large connected components, not confetti
212
+ # For now, strict Area check + Opening usually covers this.
213
+ return {"valid": True}
 
 
 
 
 
 
 
 
 
 
214
 
215
def generate_expert_gradcam(self, image: Image.Image, target_prompts: List[str]) -> Dict[str, Any]:
    """
    Expert Grad-CAM:
    1. Multi-Prompt Ensembling (one CAM per prompt, averaged).
    2. Layer Selection: encoder layer -2 (richer spatial features than the
       final post-norm).
    3. Target: the similarity logit for each prompt index.

    Returns {"map": HxW float array or None, "audit": dict}. The map is None
    on any failure (logged + recorded in the audit) — never raises.
    """
    audit = {"gradcam_prompts": target_prompts, "gradcam_status": "INIT"}

    try:
        # BUG FIX: with no prompts the original averaged zero maps
        # (np.mean of an empty array -> NaN warning/garbage downstream).
        if not target_prompts:
            audit["gradcam_error"] = "No target prompts provided"
            return {"map": None, "audit": audit}

        # Hoisted out of the per-prompt loop (was re-imported every iteration).
        from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

        # Prepare joint text+image inputs on the model's device.
        inputs = self.processor(text=target_prompts, images=image, padding="max_length", return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        input_ids = inputs.get('input_ids')
        attention_mask = inputs.get('attention_mask')
        # Some processors omit the mask; synthesize an all-ones one.
        if attention_mask is None and input_ids is not None:
            attention_mask = torch.ones_like(input_ids)

        # Wrapper exposes logits_per_image as a classifier head for Grad-CAM.
        model_wrapper_cam = HuggingFaceWeirdCLIPWrapper(self.model, input_ids, attention_mask)

        # 2nd-to-last encoder layer (SigLIP: model.vision_model.encoder.layers).
        target_layers = [self.model.vision_model.encoder.layers[-2].layer_norm1]

        cam = GradCAMPlusPlus(
            model=model_wrapper_cam,
            target_layers=target_layers,
            reshape_transform=reshape_transform  # maps (B, T, D) tokens to (B, D, H, W)
        )

        pixel_values = inputs.get('pixel_values')

        # Ensembling: one CAM per prompt index (its column in the logits),
        # then average. CAM output is (batch, H, W); batch is 1 here.
        maps = []
        for prompt_idx in range(len(target_prompts)):
            targets = [ClassifierOutputTarget(prompt_idx)]
            grayscale_cam = cam(input_tensor=pixel_values, targets=targets)
            maps.append(grayscale_cam[0, :])

        avg_cam = np.mean(np.array(maps), axis=0)

        # Light smoothing for a less blocky overlay.
        avg_cam = cv2.GaussianBlur(avg_cam, (13, 13), 0)

        return {"map": avg_cam, "audit": audit}

    except Exception as e:
        # Best-effort: explainability failures degrade to "no map", never crash.
        logger.error(f"Grad-CAM Failed: {e}")
        audit["gradcam_error"] = str(e)
        return {"map": None, "audit": audit}
282
+
283
def explain(self, image: Image.Image, target_text: str, anatomical_context: str) -> Dict[str, Any]:
    """
    Final Expert Fusion Pipeline: anatomy mask x Grad-CAM -> audited heatmap.

    Args:
        image: input PIL image.
        target_text: pathology prompt driving the Grad-CAM attention.
        anatomical_context: free-text organ/region used to select the expert config.

    Returns:
        Dict with heatmap_array (RGB overlay or None), heatmap_raw (constrained
        0..1 map or None), reliability_score, confidence_label
        ("HIGH"/"LOW"/"UNSAFE") and the accumulated audit trail.
    """
    # 0. Setup
    config = self._get_expert_config(anatomical_context)

    # 1. Anatomical Mask (Strict Constraint)
    seg_res = self.generate_expert_mask(image, config)
    mask = seg_res["mask"]
    audit = seg_res["audit"]

    if mask is None:
        # Strict Safety: no explanation at all if segmentation fails.
        return {"heatmap_array": None, "heatmap_raw": None, "reliability_score": 0.0, "confidence_label": "UNSAFE", "audit": audit}

    # 2. Attention Map. Single prompt for now; automatic synonym expansion
    # of target_text is future work.
    gradcam_res = self.generate_expert_gradcam(image, [target_text])
    heatmap = gradcam_res["map"]
    audit.update(gradcam_res["audit"])

    if heatmap is None:
        return {"heatmap_array": None, "heatmap_raw": None, "reliability_score": 0.0, "confidence_label": "LOW", "audit": audit}

    # 3. Constraint Fusion: zero out attention falling outside the anatomy.
    if mask.shape != heatmap.shape:
        mask = cv2.resize(mask, (heatmap.shape[1], heatmap.shape[0]))

    final_map = heatmap * mask

    # 4. Reliability = fraction of attention energy retained inside the mask.
    total = np.sum(heatmap) + 1e-8  # epsilon avoids 0/0 on an all-zero CAM
    retained = np.sum(final_map)
    reliability = retained / total

    confidence = "HIGH" if reliability > 0.6 else "LOW"  # >60% inside anatomy
    audit["reliability_score"] = round(reliability, 4)

    # 5. Visualize: base image normalized to [0,1] as show_cam_on_image expects.
    img_np = np.array(image).astype(np.float32)
    img_range = img_np.max() - img_np.min()
    # BUG FIX: the original divided by (max - min) unguarded, producing NaNs
    # for a constant (e.g. all-black) image; fall back to a zero canvas.
    if img_range > 0:
        img_np = (img_np - img_np.min()) / img_range
    else:
        img_np = np.zeros_like(img_np)
    visualization = show_cam_on_image(img_np, final_map, use_rgb=True)

    return {
        "heatmap_array": visualization,
        "heatmap_raw": final_map,
        "reliability_score": round(reliability, 2),
        "confidence_label": confidence,
        "audit": audit
    }