Adarsh R Shenoy committed on
Commit
9a55284
·
unverified ·
1 Parent(s): 74e1ea2

Added .dcm format

Browse files
Files changed (1) hide show
  1. cerebAI.py +48 -23
cerebAI.py CHANGED
@@ -10,21 +10,26 @@ from typing import Tuple, Optional
10
  import albumentations as A
11
  from albumentations.pytorch import ToTensorV2
12
  import os
13
- import requests # REQUIRED FOR DOWNLOADING MODEL
 
 
14
 
15
- # --- CONFIGURATION ---
 
 
 
 
16
  HF_MODEL_URL = "https://huggingface.co/arshenoy/cerebAI-stroke-model/resolve/main/best_model.pth"
17
- DOWNLOAD_MODEL_PATH = "best_model_cache.pth"
18
  CLASS_LABELS = ['No Stroke', 'Ischemic Stroke', 'Hemorrhagic Stroke']
19
  IMAGE_SIZE = 224
20
- DEVICE = torch.device("cpu")
21
 
22
- # --- MODEL LOADING (UPDATED FOR DOWNLOAD) ---
23
  @st.cache_resource
24
  def load_model(model_url, local_path):
25
  """Downloads model from URL if not cached, and loads the weights."""
26
 
27
- # 1. Check if the file is already downloaded
28
  if not os.path.exists(local_path):
29
  st.info(f"Model not found locally. Downloading from remote repository...")
30
  try:
@@ -36,10 +41,9 @@ def load_model(model_url, local_path):
36
  f.write(chunk)
37
  st.success("Model download complete!")
38
  except Exception as e:
39
- st.error(f"FATAL ERROR: Could not download model from {model_url}. Check the URL. Error: {e}")
40
  return None
41
 
42
- # 2. Load the model weights
43
  try:
44
  model = timm.create_model('convnext_base', pretrained=False)
45
  model.reset_classifier(num_classes=len(CLASS_LABELS))
@@ -51,23 +55,44 @@ def load_model(model_url, local_path):
51
  st.error(f"Failed to load model weights from cache. Error: {e}")
52
  return None
53
 
54
- # --- HELPER FUNCTIONS ---
 
 
55
 
56
  def denormalize_image(tensor: torch.Tensor) -> np.ndarray:
57
  """Denormalizes a PyTorch tensor for matplotlib visualization."""
58
  if tensor.ndim == 4:
59
- tensor = tensor.squeeze(0)
 
 
60
 
61
  mean, std = np.array([0.5, 0.5, 0.5]), np.array([0.5, 0.5, 0.5])
62
  img = tensor.cpu().permute(1, 2, 0).numpy()
63
  img = (img * std) + mean
64
  return np.clip(img, 0, 1)
65
 
66
- def preprocess_image(image_bytes: bytes) -> Tuple[Optional[torch.Tensor], Optional[np.ndarray]]:
67
- """Loads, resizes, and normalizes the image for model input."""
68
- image = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_GRAYSCALE)
69
- if image is None: return None, None
70
- image_rgb = cv2.cvtColor(cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_GRAY2RGB)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  image_norm = (image_rgb.astype(np.float32) / 255.0 - 0.5) / 0.5
73
  input_tensor = torch.tensor(image_norm, dtype=torch.float).permute(2, 0, 1).unsqueeze(0)
@@ -76,7 +101,6 @@ def preprocess_image(image_bytes: bytes) -> Tuple[Optional[torch.Tensor], Option
76
 
77
  def generate_attribution(model: nn.Module, input_tensor: torch.Tensor, predicted_class_idx: int, n_steps: int = 20) -> np.ndarray:
78
  """Computes Integrated Gradients for the given input and class."""
79
-
80
  target_class_int = int(predicted_class_idx)
81
  input_tensor.requires_grad_(True)
82
 
@@ -99,7 +123,6 @@ def generate_attribution(model: nn.Module, input_tensor: torch.Tensor, predicted
99
 
100
  def plot_heatmap_and_original(original_image: np.ndarray, heatmap: np.ndarray, predicted_label: str):
101
  """Creates a Matplotlib figure for visualization."""
102
-
103
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
104
  original_image_vis = (original_image.astype(np.float32) / 255.0)
105
 
@@ -117,9 +140,9 @@ def plot_heatmap_and_original(original_image: np.ndarray, heatmap: np.ndarray, p
117
  plt.tight_layout()
118
  return fig
119
 
120
- # ==============================================================================
121
  # -------------------- STREAMLIT FRONTEND --------------------
122
- # ==============================================================================
123
 
124
  st.set_page_config(page_title="CerebAI: Stroke Prediction Dashboard", layout="wide")
125
  st.title("CerebAI: AI-Powered Stroke Detection")
@@ -136,7 +159,7 @@ if model is not None:
136
  'Integration Steps (Affects Accuracy & Speed)',
137
  min_value=5,
138
  max_value=50,
139
- value=20, # Default to a safe, medium-speed value
140
  step=5,
141
  help="Higher steps (up to 50) provide a smoother, more accurate heatmap but use more CPU."
142
  )
@@ -146,12 +169,13 @@ if model is not None:
146
  # --- FILE UPLOAD ---
147
  st.markdown("### Upload CT Scan Image")
148
  uploaded_file = st.file_uploader(
149
- "Choose a PNG, JPG, or JPEG file",
150
- type=["png", "jpg", "jpeg"]
151
  )
152
 
153
  if uploaded_file is not None:
154
  image_bytes = uploaded_file.read()
 
155
 
156
  # --- DISPLAY AND RESULTS LAYOUT ---
157
  col1, col2 = st.columns(2)
@@ -161,7 +185,8 @@ if model is not None:
161
  st.image(image_bytes, use_container_width=True)
162
 
163
  # Run Prediction and Attribution
164
- input_tensor, original_image_rgb = preprocess_image(image_bytes)
 
165
 
166
  if input_tensor is not None:
167
  # Predict
 
10
  import albumentations as A
11
  from albumentations.pytorch import ToTensorV2
12
  import os
13
+ import requests
14
+ import pydicom # REQUIRED FOR DICOM SUPPORT
15
+ import io # REQUIRED for reading image bytes as a file
16
 
17
+
18
+ # -------------------- CONFIGURATION & MODEL LOADING --------------------
19
+
20
+
21
+ # --- CONFIG ---
22
  HF_MODEL_URL = "https://huggingface.co/arshenoy/cerebAI-stroke-model/resolve/main/best_model.pth"
23
+ DOWNLOAD_MODEL_PATH = "best_model_cache.pth"
24
  CLASS_LABELS = ['No Stroke', 'Ischemic Stroke', 'Hemorrhagic Stroke']
25
  IMAGE_SIZE = 224
26
+ DEVICE = torch.device("cpu") # For Streamlit Cloud stability
27
 
28
+ # --- MODEL LOADING ---
29
  @st.cache_resource
30
  def load_model(model_url, local_path):
31
  """Downloads model from URL if not cached, and loads the weights."""
32
 
 
33
  if not os.path.exists(local_path):
34
  st.info(f"Model not found locally. Downloading from remote repository...")
35
  try:
 
41
  f.write(chunk)
42
  st.success("Model download complete!")
43
  except Exception as e:
44
+ st.error(f"FATAL ERROR: Could not download model. Error: {e}")
45
  return None
46
 
 
47
  try:
48
  model = timm.create_model('convnext_base', pretrained=False)
49
  model.reset_classifier(num_classes=len(CLASS_LABELS))
 
55
  st.error(f"Failed to load model weights from cache. Error: {e}")
56
  return None
57
 
58
+
59
+ # -------------------- HELPER FUNCTIONS --------------------
60
+
61
 
62
def denormalize_image(tensor: torch.Tensor) -> np.ndarray:
    """Reverse the (x - 0.5) / 0.5 normalization and return an HWC array in [0, 1].

    Accepts either a batched (1, C, H, W) or unbatched (C, H, W) tensor and
    produces a float (H, W, C) numpy image suitable for matplotlib display.
    """
    # Drop a leading batch axis when present; detach in both paths so that
    # .numpy() works even if the tensor still carries gradient history.
    chw = tensor.squeeze(0).detach() if tensor.ndim == 4 else tensor.detach()

    # Undo normalization: img * std + mean, with mean = std = 0.5 per channel.
    mean = np.array([0.5, 0.5, 0.5])
    std = np.array([0.5, 0.5, 0.5])
    img = chw.cpu().permute(1, 2, 0).numpy() * std + mean

    # Floating-point round-off can push values slightly outside [0, 1].
    return np.clip(img, 0, 1)
73
 
74
+ def preprocess_image(image_bytes: bytes, file_name: str) -> Tuple[Optional[torch.Tensor], Optional[np.ndarray]]:
75
+ """Loads, processes, and normalizes image, handling DICOM or JPG/PNG."""
76
+
77
+ # 1. READ IMAGE DATA (Handles DICOM vs Standard formats)
78
+ if file_name.lower().endswith(('.dcm', '.dicom')):
79
+ try:
80
+ dcm = pydicom.dcmread(io.BytesIO(image_bytes))
81
+ pixel_array = dcm.pixel_array.astype(np.float32)
82
+
83
+ # Simple intensity scaling for visualization/processing
84
+ pixel_array = (pixel_array - np.min(pixel_array)) / (np.max(pixel_array) - np.min(pixel_array))
85
+ pixel_array = (pixel_array * 255).astype(np.uint8)
86
+ image_grayscale = pixel_array
87
+ except Exception:
88
+ return None, None
89
+ else:
90
+ # Read standard image (PNG/JPG)
91
+ image_grayscale = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_GRAYSCALE)
92
+ if image_grayscale is None: return None, None
93
+
94
+ # 2. STANDARD PREPROCESSING (The rest of your original logic)
95
+ image_rgb = cv2.cvtColor(cv2.resize(image_grayscale, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_GRAY2RGB)
96
 
97
  image_norm = (image_rgb.astype(np.float32) / 255.0 - 0.5) / 0.5
98
  input_tensor = torch.tensor(image_norm, dtype=torch.float).permute(2, 0, 1).unsqueeze(0)
 
101
 
102
  def generate_attribution(model: nn.Module, input_tensor: torch.Tensor, predicted_class_idx: int, n_steps: int = 20) -> np.ndarray:
103
  """Computes Integrated Gradients for the given input and class."""
 
104
  target_class_int = int(predicted_class_idx)
105
  input_tensor.requires_grad_(True)
106
 
 
123
 
124
  def plot_heatmap_and_original(original_image: np.ndarray, heatmap: np.ndarray, predicted_label: str):
125
  """Creates a Matplotlib figure for visualization."""
 
126
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
127
  original_image_vis = (original_image.astype(np.float32) / 255.0)
128
 
 
140
  plt.tight_layout()
141
  return fig
142
 
143
+
144
  # -------------------- STREAMLIT FRONTEND --------------------
145
+
146
 
147
  st.set_page_config(page_title="CerebAI: Stroke Prediction Dashboard", layout="wide")
148
  st.title("CerebAI: AI-Powered Stroke Detection")
 
159
  'Integration Steps (Affects Accuracy & Speed)',
160
  min_value=5,
161
  max_value=50,
162
+ value=20,
163
  step=5,
164
  help="Higher steps (up to 50) provide a smoother, more accurate heatmap but use more CPU."
165
  )
 
169
  # --- FILE UPLOAD ---
170
  st.markdown("### Upload CT Scan Image")
171
  uploaded_file = st.file_uploader(
172
+ "Choose a Dicom, PNG, JPG, or JPEG file",
173
+ type=["dcm", "dicom", "png", "jpg", "jpeg"]
174
  )
175
 
176
  if uploaded_file is not None:
177
  image_bytes = uploaded_file.read()
178
+ file_name = uploaded_file.name # Get file name for DICOM check
179
 
180
  # --- DISPLAY AND RESULTS LAYOUT ---
181
  col1, col2 = st.columns(2)
 
185
  st.image(image_bytes, use_container_width=True)
186
 
187
  # Run Prediction and Attribution
188
+ # FIX: Pass file_name to the preprocessing function
189
+ input_tensor, original_image_rgb = preprocess_image(image_bytes, file_name)
190
 
191
  if input_tensor is not None:
192
  # Predict