papsofts commited on
Commit
bc25bfd
·
verified ·
1 Parent(s): 63fafea

Upload 6 files

Browse files
app.py CHANGED
@@ -1,657 +1,735 @@
1
- """Gradio Inference App for Transfer Learning Models - HuggingFace Spaces Version
2
-
3
- Features:
4
- - Auto-scan models directory or HuggingFace Hub for available models and approaches.
5
- - Dropdown selection of Model and Approach.
6
- - Dynamic architecture detection from filename (e.g., resnet50, densenet121, inception_v3, efficientnet_b0, resnet34).
7
- - Image upload and preprocessing (ImageNet normalization).
8
- - Top-K prediction display (configurable class labels).
9
- - Optional Grad-CAM visualization for interpretability.
10
- - Environment variable configuration for HuggingFace deployment.
11
- - Graceful error handling and clear user feedback.
12
-
13
- Environment Variables:
14
- - HF_TOKEN: HuggingFace API token for private repositories
15
- - MODEL_REPO_ID: HuggingFace repository containing models
16
- - NUM_CLASSES: Number of output classes (default: 2)
17
- - DEBUG: Enable debug logging
18
  """
19
- import numpy as np
20
- import pandas as pd
21
- import os
22
- import re
23
- import logging
24
- from typing import Dict, Tuple, List, Optional, Union
25
- from pathlib import Path
26
 
27
- import pydicom
28
  import torch
29
- import torch.nn.functional as F
30
- from torchvision import models, transforms
 
31
  from PIL import Image
32
- import gradio as gr
33
- from huggingface_hub import hf_hub_download, list_repo_files
34
- from dotenv import load_dotenv
 
 
35
 
36
- try:
37
- import timm # EfficientNet etc.
38
- except ImportError:
39
- timm = None
 
 
 
 
 
40
 
 
41
  try:
42
- import cv2
 
43
  except ImportError:
44
- cv2 = None
45
 
46
- import gradio_client.utils as gc_utils
 
 
47
 
48
- # Hot-patch for TypeError: argument of type 'bool' is not iterable
49
- def _safe_json_schema_to_python_type(schema, defs=None):
50
- try:
51
- # Handle if schema is a boolean (Gradio 5.x bug)
52
- if isinstance(schema, bool):
53
- # Return readable type name to avoid crash
54
- return "bool" if schema else "NoneType"
55
- # If schema is a dict, proceed as usual
56
- if isinstance(schema, dict) and "const" in schema:
57
- pass
58
- # Fall back to original internal logic if it exists
59
- if hasattr(gc_utils, "_json_schema_to_python_type_original"):
60
- return gc_utils._json_schema_to_python_type_original(schema, defs)
61
- else:
62
- # If no original stored yet, call the default one
63
- return gc_utils._json_schema_to_python_type(schema, defs)
64
- except Exception:
65
- return "Any"
66
-
67
- # Save original before replacing (only once)
68
- if not hasattr(gc_utils, "_json_schema_to_python_type_original"):
69
- gc_utils._json_schema_to_python_type_original = gc_utils._json_schema_to_python_type
70
-
71
- # Apply the patch
72
- gc_utils._json_schema_to_python_type = _safe_json_schema_to_python_type
73
-
74
- # Load environment variables
75
- load_dotenv()
76
-
77
- # Configuration from environment variables
78
- HF_TOKEN = os.getenv("HF_TOKEN")
79
- MODEL_REPO_ID = os.getenv("MODEL_REPO_ID")
80
- NUM_CLASSES = int(os.getenv("NUM_CLASSES", "2"))
81
- DEBUG = os.getenv("DEBUG", "False").lower() == "true"
82
-
83
- # Setup logging
84
- logging.basicConfig(level=logging.DEBUG if DEBUG else logging.INFO)
85
- logger = logging.getLogger(__name__)
86
-
87
- # Directory paths
88
- MODELS_DIR = Path("models")
89
- MODELS_DIR.mkdir(exist_ok=True)
90
-
91
- # Placeholder class labels; customize based on your dataset
92
- # CLASS_LABELS = {i: f"Class_{i}" for i in range(NUM_CLASSES)}
93
- CLASS_LABELS = {0:'No Pneumonia',1:'Pneumonia'}
94
-
95
- # Architecture-specific input sizes
96
- _ARCH_INPUT_SIZE = {
97
- "inception_v3": 299,
98
- # Most others default to 224
99
- }
100
 
101
- # Regex to parse weight filenames: Appr_<Approach>_<arch>.pt
102
- WEIGHT_PATTERN = re.compile(r"Appr_([A-Za-z0-9]+)_([A-Za-z0-9_]+)\.pt")
103
 
104
- ModelMap = Dict[str, Dict[str, Dict[str, str]]]
105
- _model_map_cache: Optional[ModelMap] = None
106
- # Cache for instantiated models to avoid recreation per inference
107
- _model_instance_cache: Dict[Tuple[str, str], Tuple[torch.nn.Module, str]] = {}
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
- def download_models_from_hub() -> None:
111
- """Download models from HuggingFace Hub if MODEL_REPO_ID is set."""
112
- if not MODEL_REPO_ID:
113
- logger.info("No MODEL_REPO_ID set, skipping Hub download")
114
- return
115
-
116
  try:
117
- logger.info(f"Downloading models from {MODEL_REPO_ID}")
118
- repo_files = list_repo_files(MODEL_REPO_ID, token=HF_TOKEN)
119
- model_files = [f for f in repo_files if f.endswith('.pt')]
120
-
121
- for model_file in model_files:
122
- local_path = MODELS_DIR / model_file
123
- if not local_path.exists():
124
- logger.info(f"Downloading {model_file}")
125
- downloaded_path = hf_hub_download(
126
- MODEL_REPO_ID,
127
- model_file,
128
- cache_dir=str(MODELS_DIR),
129
- token=HF_TOKEN
130
- )
131
- # Move to our models directory structure
132
- local_path.parent.mkdir(parents=True, exist_ok=True)
133
- Path(downloaded_path).rename(local_path)
134
-
135
- except Exception as e:
136
- logger.warning(f"Failed to download from Hub: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
 
 
 
138
 
139
- def scan_local_models() -> ModelMap:
140
- """Scan local models directory for available models."""
141
- mapping: ModelMap = {}
142
-
143
- # First, try to download from Hub
144
- download_models_from_hub()
145
-
146
- # Scan the models directory
147
- if MODELS_DIR.exists():
148
- for item in MODELS_DIR.iterdir():
149
- if item.is_dir():
150
- # Model directory structure: models/Model_X/Appr_Y_arch.pt
151
- model_name = item.name
152
- for model_file in item.glob("*.pt"):
153
- match = WEIGHT_PATTERN.match(model_file.name)
154
- if match:
155
- appr_code, arch = match.groups()
156
- mapping.setdefault(model_name, {})[appr_code] = {
157
- "path": str(model_file),
158
- "arch": arch.lower(),
159
- }
160
- elif item.suffix == ".pt":
161
- # Flat structure: models/Appr_Y_arch.pt
162
- match = WEIGHT_PATTERN.match(item.name)
163
- if match:
164
- appr_code, arch = match.groups()
165
- model_name = f"Model_{arch}"
166
- mapping.setdefault(model_name, {})[appr_code] = {
167
- "path": str(item),
168
- "arch": arch.lower(),
169
- }
170
-
171
- return mapping
172
 
 
173
 
174
- def scan_weights(refresh: bool = False) -> ModelMap:
175
- """Scan for available models, with caching."""
176
- global _model_map_cache
177
- if _model_map_cache is not None and not refresh:
178
- return _model_map_cache
 
 
 
 
 
 
 
179
 
180
- mapping = scan_local_models()
 
 
 
 
 
 
 
181
 
182
- # Fallback: create demo models if none found
183
- if not mapping:
184
- logger.warning("No models found, creating demo entries")
185
- mapping = {
186
- "ResNet34": {
187
- "G1": {"path": "", "arch": "resnet34"},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  },
189
- "InceptionV3": {
190
- "H1": {"path": "", "arch": "inception_v3"},
 
 
 
 
 
 
 
191
  }
192
  }
 
 
 
 
 
 
 
 
 
193
 
194
- _model_map_cache = mapping
195
- logger.info(f"Found models: {list(mapping.keys())}")
196
- return mapping
197
-
198
-
199
- def _create_model(arch: str, num_classes: int) -> torch.nn.Module:
200
- """Instantiate a model given architecture string."""
201
- arch = arch.lower()
202
- logger.debug(f"Creating model: {arch} with {num_classes} classes")
203
 
204
- if arch.startswith("resnet"):
205
- base = getattr(models, arch)(weights=None)
206
- base.fc = torch.nn.Linear(base.fc.in_features, num_classes)
207
- return base
208
- elif arch.startswith("densenet"):
209
- base = getattr(models, arch)(weights=None)
210
- base.classifier = torch.nn.Linear(base.classifier.in_features, num_classes)
211
- return base
212
- elif arch.startswith("inception_v3"):
213
- # Disable aux logits for lighter inference path
214
- base = models.inception_v3(weights=None, aux_logits=False)
215
- base.fc = torch.nn.Linear(base.fc.in_features, num_classes)
216
- return base
217
- elif arch.startswith("efficientnet"):
218
- if timm is None:
219
- raise RuntimeError("timm not installed; cannot create EfficientNet model.")
220
- base = timm.create_model(arch, pretrained=False, num_classes=num_classes)
221
- return base
222
 
223
- # Fallback using timm if available
224
- if timm is not None:
225
- try:
226
- return timm.create_model(arch, pretrained=False, num_classes=num_classes)
227
- except Exception as e:
228
- logger.warning(f"timm failed to create {arch}: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
- raise ValueError(f"Unsupported architecture: {arch}")
231
-
232
-
233
- def _resolve_device() -> torch.device:
234
- """Resolve device based on DEVICE env var (cpu|cuda|auto)."""
235
- device_pref = os.getenv("DEVICE", "auto").lower()
236
- if device_pref == "cpu":
237
- return torch.device("cpu")
238
- if device_pref == "cuda":
239
- if torch.cuda.is_available():
240
- return torch.device("cuda")
241
- logger.warning("DEVICE=cuda requested but CUDA not available; falling back to CPU")
242
- return torch.device("cpu")
243
- # auto
244
- return torch.device("cuda" if torch.cuda.is_available() else "cpu")
245
-
246
-
247
- def load_model(model_name: str, approach: str, use_cache: bool = True) -> Tuple[torch.nn.Module, str]:
248
- """Load a model given model name and approach, optionally using cache."""
249
- if use_cache:
250
- cache_key = (model_name, approach)
251
- if cache_key in _model_instance_cache:
252
- return _model_instance_cache[cache_key]
253
-
254
- mapping = scan_weights()
255
- if model_name not in mapping:
256
- raise ValueError(f"Model '{model_name}' not found in {list(mapping.keys())}")
257
- if approach not in mapping[model_name]:
258
- raise ValueError(
259
- f"Approach '{approach}' not found for model '{model_name}'. Available: {list(mapping[model_name].keys())}"
260
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
 
262
- info = mapping[model_name][approach]
263
- arch = info["arch"]
264
- weight_path = info["path"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
 
266
- device = _resolve_device()
267
- logger.info(f"Loading {model_name}/{approach} ({arch}) on {device}")
 
268
 
269
- model = _create_model(arch, NUM_CLASSES)
 
270
 
271
- # Load weights if path exists
272
- if weight_path and Path(weight_path).exists():
273
- try:
274
- state = torch.load(weight_path, map_location=device)
275
- if isinstance(state, dict) and "state_dict" in state:
276
- state = state["state_dict"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
278
- # Remove any DistributedDataParallel prefixes
279
- new_state = {k.replace("module.", ""): v for k, v in state.items()}
280
 
281
- missing, unexpected = model.load_state_dict(new_state, strict=False)
282
- if missing:
283
- logger.warning(f"Missing keys: {missing[:5]}...")
284
- if unexpected:
285
- logger.warning(f"Unexpected keys: {unexpected[:5]}...")
286
 
287
- except Exception as e:
288
- logger.warning(f"Failed to load weights from {weight_path}: {e}")
289
- else:
290
- logger.warning(f"No weights file found at {weight_path}, using random weights")
291
 
292
- model.to(device)
293
- model.eval()
 
 
 
 
 
 
 
294
 
295
- if use_cache:
296
- _model_instance_cache[cache_key] = (model, arch)
297
 
298
- return model, arch
 
 
299
 
300
- def load_image_any(file_obj):
301
  """
302
- Accepts .dcm, .png, .jpg, .jpeg, etc.
303
- Returns a Pillow RGB image.
 
 
 
 
 
 
304
  """
305
- # Handle case when file is removed
306
- if file_obj is None:
307
- return None
308
 
309
- if isinstance(file_obj, str):
310
- filepath = file_obj
311
- else:
312
- filepath = file_obj.name # if it's a tempfile
313
-
314
- ext = os.path.splitext(filepath)[1].lower()
315
-
316
- if ext == ".dcm":
317
- # Read DICOM file
318
- ds = pydicom.dcmread(filepath)
319
- img_array = ds.pixel_array.astype(float)
320
-
321
- # Normalize to 0-255
322
- img_array = (np.maximum(img_array, 0) / img_array.max()) * 255.0
323
- img_array = np.uint8(img_array)
324
- # Convert grayscale to RGB
325
- img_rgb = Image.fromarray(img_array).convert("RGB")
326
- return img_rgb
327
- else:
328
- # Standard formats
329
- return Image.open(filepath).convert("RGB")
330
-
331
-
332
- _image_cache_transform: Dict[int, transforms.Compose] = {}
333
-
334
-
335
- def get_transform(arch: str) -> transforms.Compose:
336
- """Get image transforms for the given architecture."""
337
- size = _ARCH_INPUT_SIZE.get(arch, 224)
338
- if size in _image_cache_transform:
339
- return _image_cache_transform[size]
340
-
341
- t = transforms.Compose([
342
- transforms.Resize((size, size)),
343
- transforms.ToTensor(),
344
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
345
- ])
346
- _image_cache_transform[size] = t
347
- return t
348
-
349
-
350
- # Grad-CAM utilities
351
- class GradCAM:
352
- """Grad-CAM implementation for generating attention heatmaps."""
353
-
354
- def __init__(self, model: torch.nn.Module, target_layer: Optional[str] = None):
355
- self.model = model
356
- self.model.eval()
357
- self.target_layer = target_layer
358
- self.activations = None
359
- self.gradients = None
360
-
361
- # Try to automatically pick a layer if not provided
362
- if target_layer is None:
363
- layer = None
364
- # Common patterns for different architectures
365
- layer_candidates = ["layer4", "features.denseblock4", "blocks.6", "conv_head", "features"]
366
- for cand in layer_candidates:
367
- parts = cand.split('.')
368
- current = model
369
- try:
370
- for part in parts:
371
- current = getattr(current, part)
372
- layer = current
373
- break
374
- except AttributeError:
375
- continue
376
- self.target_module = layer if layer is not None else model
377
  else:
378
- self.target_module = dict(model.named_modules()).get(target_layer, model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
 
380
- def fwd_hook(_, __, output):
381
- self.activations = output.detach()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
 
383
- def bwd_hook(_, grad_input, grad_output):
384
- if grad_output[0] is not None:
385
- self.gradients = grad_output[0].detach()
386
 
387
- self.target_module.register_forward_hook(fwd_hook)
388
- # Use full backward hook (non-deprecated) for gradient capture
389
- try:
390
- self.target_module.register_full_backward_hook(bwd_hook)
391
- except AttributeError:
392
- # Fallback if running older torch without full hook
393
- self.target_module.register_backward_hook(bwd_hook)
394
-
395
- def generate(self, tensor: torch.Tensor, class_idx: Optional[int] = None) -> torch.Tensor:
396
- """Generate Grad-CAM heatmap."""
397
- tensor = tensor.requires_grad_(True)
398
- logits = self.model(tensor)
399
- if isinstance(logits, tuple): # Inception may return (logits, aux)
400
- logits = logits[0]
401
-
402
- if class_idx is None:
403
- class_idx = logits.argmax(dim=1).item()
404
-
405
- score = logits[:, class_idx]
406
- score.backward(retain_graph=True)
407
-
408
- # Compute weights
409
- grads = self.gradients # [B, C, H, W]
410
- acts = self.activations
411
-
412
- if grads is None or acts is None:
413
- raise RuntimeError("GradCAM hooks did not capture activations/gradients")
414
-
415
- weights = grads.mean(dim=(2, 3), keepdim=True) # [B, C, 1, 1]
416
- cam = (weights * acts).sum(dim=1, keepdim=True)
417
- cam = F.relu(cam)
418
- cam = F.interpolate(cam, size=tensor.shape[2:], mode="bilinear", align_corners=False)
419
-
420
- # Normalize
421
- cam_min, cam_max = cam.min(), cam.max()
422
- cam = (cam - cam_min) / (cam_max - cam_min + 1e-8)
423
- return cam.squeeze(0).squeeze(0) # [H, W]
424
-
425
-
426
- def predict(image: Image.Image, model_name: str, approach: str, grad_cam: bool = False, top_k: int = 5):
427
- """Run inference on the uploaded image."""
428
- if image is None:
429
- return [], None, None
430
- else:
431
- image = load_image_any(image)
432
- try:
433
- model, arch = load_model(model_name, approach, use_cache=True)
434
  except Exception as e:
435
- logger.error(f"Model loading failed: {e}")
436
- error_df = [["Error", 0.0, str(e)]]
437
- return error_df, image, None
438
 
439
- try:
440
- transform = get_transform(arch)
441
- tensor = transform(image).unsqueeze(0)
442
- device = next(model.parameters()).device
443
- tensor = tensor.to(device)
444
 
445
- with torch.no_grad():
446
- out = model(tensor)
447
- if isinstance(out, tuple): # Inception
448
- out = out[0]
449
- probs = F.softmax(out, dim=1).cpu().squeeze(0)
450
-
451
- top_k = min(top_k, probs.shape[0])
452
- top_probs, top_indices = torch.topk(probs, top_k)
453
-
454
- results = []
455
- for p, idx in zip(top_probs.tolist(), top_indices.tolist()):
456
- label = CLASS_LABELS.get(idx, f"Class_{idx}")
457
- results.append([label, f"{p * 100:.2f}%"])
458
-
459
- cam_img = None
460
- if grad_cam:
461
- try:
462
- gcam = GradCAM(model)
463
- cam = gcam.generate(tensor)
464
-
465
- # Convert cam to PIL heatmap overlay
466
- if cv2 is not None:
467
- import numpy as np
468
- base_img = image.resize((cam.shape[1], cam.shape[0]))
469
- base_arr = np.array(base_img)
470
- heatmap = (cam.cpu().numpy() * 255).astype('uint8')
471
- heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
472
- heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
473
- overlay = (0.4 * heatmap + 0.6 * base_arr).astype('uint8')
474
- cam_img = Image.fromarray(overlay)
475
- else:
476
- # Fallback without OpenCV
477
- import numpy as np
478
- cam_np = cam.cpu().numpy()
479
- cam_img = Image.fromarray((cam_np * 255).astype('uint8'))
480
-
481
- except Exception as e:
482
- logger.warning(f"Grad-CAM failed: {e}")
483
- cam_img = Image.new('RGB', (224, 224), color=(255, 100, 100))
484
-
485
- return results, image, cam_img
486
 
487
- except Exception as e:
488
- logger.error(f"Prediction failed: {e}")
489
- error_df = [["Error", 0.0, str(e)]]
490
- return error_df, image, None
491
-
492
- def build_interface():
493
- """Build the Gradio interface."""
494
- mapping = scan_weights()
495
- model_choices = sorted(mapping.keys())
496
-
497
- with gr.Blocks(
498
- title="Transfer Learning Inference",
499
- theme=gr.themes.Soft(),
500
- css="""
501
- .gradio-container {
502
- max-width: 1200px;
503
- margin: auto;
504
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
505
  """
506
- ) as demo:
507
- gr.Markdown(
508
- """
509
- # Transfer Learning Inference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
510
 
511
- Select a pre-trained model and approach, upload an image, and view predictions with optional attention visualization.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
512
 
513
- **Available Models:** ResNet, Inception V3, and Ensemble
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
514
  """
515
- )
516
-
517
- with gr.Row():
518
- with gr.Column(scale=1):
519
- model_dd = gr.Dropdown(
520
- choices=model_choices,
521
- label="Model",
522
- value=model_choices[0] if model_choices else None,
523
- info="Select the model architecture"
524
- )
525
- approach_dd = gr.Dropdown(
526
- choices=[],
527
- label="Approach",
528
- info="Select the training approach/variant"
529
- )
530
- grad_cam_cb = gr.Checkbox(
531
- label="Generate Grad-CAM",
532
- value=False,
533
- info="Show attention heatmap overlay"
534
- )
535
-
536
- with gr.Column(scale=2):
537
- image_in = gr.File(
538
- file_types=["image", ".dcm"],
539
- label="Upload a chest x-Ray image for classification (.png, .jpg, .jpeg, .dcm)",
540
- )
541
- image_preview = gr.Image(interactive=False, size=(224, 224))
542
- original_out = gr.Image(interactive=False, size=(224, 224))
543
- cam_out = gr.Image(interactive=False, size=(224, 224))
544
- #image_in.change(fn=load_image_any, inputs=image_in, outputs=image_preview)
545
- image_in.change(
546
- fn=load_image_any,
547
- inputs=image_in,
548
- outputs=[image_preview, original_out, cam_out],
549
- show_progress=False
550
- )
551
-
552
- submit_btn = gr.Button("Run Inference", variant="primary", size="lg")
553
-
554
- with gr.Row():
555
- with gr.Column():
556
- results_out = gr.Dataframe(
557
- headers=["Label", "Probability"],
558
- datatype=["str", "number"],
559
- label="Top Predictions",
560
- interactive=False
561
- )
562
-
563
- with gr.Column():
564
- with gr.Row():
565
- original_out = gr.Image(label="Original Image", interactive=False)
566
- cam_out = gr.Image(label="Grad-CAM Overlay", interactive=False)
567
-
568
- # Add model info
569
- with gr.Accordion("Model Information", open=False):
570
- gr.Markdown(f"""
571
- - **Number of Classes:** {NUM_CLASSES}
572
- - **Available Models:** {len(model_choices)}
573
- - **Environment:** {'HuggingFace Spaces' if MODEL_REPO_ID else 'Local'}
574
- """)
575
-
576
- def update_approaches(selected_model):
577
- if not selected_model:
578
- return gr.update(choices=[], value=None)
579
- mapping_local = scan_weights()
580
- apprs = sorted(mapping_local.get(selected_model, {}).keys())
581
- value = apprs[0] if apprs else None
582
- return gr.update(choices=apprs, value=value)
583
-
584
- model_dd.change(fn=update_approaches, inputs=model_dd, outputs=approach_dd)
585
-
586
- submit_btn.click(
587
- fn=predict,
588
- inputs=[image_in, model_dd, approach_dd, grad_cam_cb],
589
- outputs=[results_out, original_out, cam_out],
590
  )
591
 
592
- def init_approaches():
593
- if not model_choices:
594
- return gr.update(choices=[], value=None)
595
- mapping_local = scan_weights()
596
- apprs = sorted(mapping_local.get(model_choices[0], {}).keys())
597
- value = apprs[0] if apprs else None
598
- return gr.update(choices=apprs, value=value)
599
-
600
- # Initialize approaches for first model
601
- if model_choices:
602
- demo.load(fn=init_approaches, inputs=[], outputs=approach_dd)
603
-
604
- return demo
605
 
606
- def safe_gradio_launch(demo):
607
- try:
608
- demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)
609
- except TypeError as e:
610
- if "argument of type 'bool' is not iterable" in str(e):
611
- logger.warning("Gradio schema bug detected, restarting with share=True fallback.")
612
- demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=False)
613
- else:
614
- raise
615
 
616
- def main():
617
- """Main function to launch the Gradio app."""
618
- logger.info("Starting Transfer Learning Inference App")
619
- demo = build_interface()
620
-
621
- # Configuration for different environments
622
- server_name = os.getenv("GRADIO_SERVER_NAME", "0.0.0.0")
623
- server_port = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
624
- share = os.getenv("GRADIO_SHARE", "False").lower() == "true"
625
-
626
- # demo.launch(
627
- # server_name=server_name,
628
- # server_port=server_port,
629
- # share=share,
630
- # show_error=True
631
- # )
632
 
633
- safe_gradio_launch(demo)
 
 
 
634
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
635
 
636
  if __name__ == "__main__":
637
- # Quick internal test if requested
638
- if os.environ.get("HEADLESS_TEST") == "1":
639
- logger.info("Running headless test")
640
- mapping = scan_weights()
641
- if mapping:
642
- first_model = next(iter(mapping))
643
- first_appr = next(iter(mapping[first_model]))
644
- try:
645
- model, arch = load_model(first_model, first_appr)
646
- size = _ARCH_INPUT_SIZE.get(arch, 224)
647
- x = torch.randn(1, 3, size, size)
648
- out = model(x)
649
- if isinstance(out, tuple):
650
- out = out[0]
651
- logger.info(f"Test forward output shape: {out.shape}")
652
- except Exception as e:
653
- logger.error(f"Test failed: {e}")
654
- else:
655
- logger.warning("No models found for testing")
656
- else:
657
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ Hugging Face Gradio App for Pneumonia Detection Ensemble
3
+ Supports JPEG, PNG, and DICOM image formats
4
+ """
 
 
 
 
5
 
6
+ import gradio as gr
7
  import torch
8
+ import torch.nn as nn
9
+ import torchvision.transforms as transforms
10
+ import torchvision.models as models
11
  from PIL import Image
12
+ import numpy as np
13
+ import json
14
+ from pathlib import Path
15
+ import io
16
+ import os
17
 
18
+ # ----------------------------------------------------------------------------
19
+ # Debug / Diagnostics Configuration
20
+ # Set environment variable CLINICAL_DEBUG=1 to enable verbose logging
21
+ # ----------------------------------------------------------------------------
22
+ DEBUG = os.getenv("CLINICAL_DEBUG", "0") in ("1", "true", "True")
23
+
24
+ def _dbg(msg):
25
+ if DEBUG:
26
+ print(f"[DEBUG] {msg}")
27
 
28
+ # DICOM support
29
  try:
30
+ import pydicom
31
+ DICOM_AVAILABLE = True
32
  except ImportError:
33
+ DICOM_AVAILABLE = False
34
 
35
+ # ============================================================================
36
+ # Model Architectures (simplified versions for deployment)
37
+ # ============================================================================
38
 
39
+ class MobileNetV2Model(nn.Module):
40
+ def __init__(self, num_classes=2):
41
+ super(MobileNetV2Model, self).__init__()
42
+ self.model = models.mobilenet_v2(weights=None)
43
+ self.model.classifier[1] = nn.Linear(1280, num_classes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
+ def forward(self, x):
46
+ return self.model(x)
47
 
 
 
 
 
48
 
49
+ class ResNet50Model(nn.Module):
50
+ def __init__(self, num_classes=2):
51
+ super(ResNet50Model, self).__init__()
52
+ self.model = models.resnet50(weights=None)
53
+ self.model.fc = nn.Linear(2048, num_classes)
54
+
55
+ def forward(self, x):
56
+ return self.model(x)
57
+
58
+
59
+ class EfficientNetB0Model(nn.Module):
60
+ def __init__(self, num_classes=2):
61
+ super(EfficientNetB0Model, self).__init__()
62
+ try:
63
+ from torchvision.models import efficientnet_b0
64
+ self.model = efficientnet_b0(weights=None)
65
+ except:
66
+ self.model = models.efficientnet_b0(weights=None)
67
+ num_features = self.model.classifier[1].in_features
68
+ self.model.classifier[1] = nn.Linear(num_features, num_classes)
69
+
70
+ def forward(self, x):
71
+ return self.model(x)
72
+
73
+
74
+ class VGG19Model(nn.Module):
75
+ def __init__(self, num_classes=2):
76
+ super(VGG19Model, self).__init__()
77
+ self.model = models.vgg19(weights=None)
78
+ self.model.classifier[6] = nn.Linear(4096, num_classes)
79
+
80
+ def forward(self, x):
81
+ return self.model(x)
82
+
83
+
84
+ class DenseNet101Model(nn.Module):
85
+ def __init__(self, num_classes=2):
86
+ super(DenseNet101Model, self).__init__()
87
+ self.model = models.densenet101(weights=None)
88
+ self.model.classifier = nn.Linear(1024, num_classes)
89
+
90
+ def forward(self, x):
91
+ return self.model(x)
92
+
93
+
94
+ # ============================================================================
95
+ # DICOM Processing Functions
96
+ # ============================================================================
97
+
98
+ def process_dicom_file(file_obj):
99
+ """Process DICOM file and convert to PIL Image with improved medical handling.
100
+
101
+ Adds support for:
102
+ - RescaleSlope / RescaleIntercept
103
+ - PhotometricInterpretation inversion (MONOCHROME1)
104
+ - Float window center/width handling
105
+ - Detailed pixel statistics for debugging
106
+ """
107
+ if not DICOM_AVAILABLE:
108
+ raise ValueError("DICOM support not available. Please install pydicom.")
109
 
 
 
 
 
 
 
110
  try:
111
+ # Read DICOM
112
+ ds = pydicom.dcmread(file_obj.name if hasattr(file_obj, 'name') else file_obj)
113
+ pixel_array = ds.pixel_array.astype(np.float32)
114
+
115
+ _dbg(f"DICOM original shape={pixel_array.shape} dtype={pixel_array.dtype} min={pixel_array.min():.2f} max={pixel_array.max():.2f}")
116
+
117
+ # Apply rescale if present
118
+ slope = float(getattr(ds, 'RescaleSlope', 1.0))
119
+ intercept = float(getattr(ds, 'RescaleIntercept', 0.0))
120
+ if slope != 1.0 or intercept != 0.0:
121
+ pixel_array = pixel_array * slope + intercept
122
+ _dbg(f"Applied rescale slope={slope} intercept={intercept} new_min={pixel_array.min():.2f} new_max={pixel_array.max():.2f}")
123
+
124
+ # Photometric interpretation inversion (MONOCHROME1 means high values = dark)
125
+ photometric = getattr(ds, 'PhotometricInterpretation', '').upper()
126
+ if photometric == 'MONOCHROME1':
127
+ max_val = pixel_array.max()
128
+ pixel_array = max_val - pixel_array
129
+ _dbg("Applied MONOCHROME1 inversion")
130
+
131
+ # Windowing
132
+ window_center = getattr(ds, 'WindowCenter', None)
133
+ window_width = getattr(ds, 'WindowWidth', None)
134
+ if window_center is not None and window_width is not None:
135
+ if isinstance(window_center, (list, tuple)): window_center = float(window_center[0])
136
+ if isinstance(window_width, (list, tuple)): window_width = float(window_width[0])
137
+ window_min = window_center - window_width / 2.0
138
+ window_max = window_center + window_width / 2.0
139
+ pixel_array = np.clip(pixel_array, window_min, window_max)
140
+ pixel_array = (pixel_array - window_min) / max(window_max - window_min, 1e-6)
141
+ _dbg(f"Applied windowing center={window_center} width={window_width} -> min={window_min:.2f} max={window_max:.2f}")
142
+ else:
143
+ # Min-max normalize
144
+ pmin, pmax = pixel_array.min(), pixel_array.max()
145
+ pixel_array = (pixel_array - pmin) / max(pmax - pmin, 1e-6)
146
+ _dbg("Applied min-max normalization (no window tags)")
147
 
148
+ # Scale to 0-255
149
+ pixel_array = (pixel_array * 255.0).clip(0, 255).astype(np.uint8)
150
+ image = Image.fromarray(pixel_array, mode='L')
151
 
152
+ # Optional mild contrast enhancement
153
+ try:
154
+ from PIL import ImageEnhance
155
+ enhancer = ImageEnhance.Contrast(image)
156
+ image = enhancer.enhance(1.15)
157
+ except Exception:
158
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
+ image = image.convert('RGB')
161
 
162
+ # Log summary stats
163
+ arr = np.array(image) # RGB
164
+ _dbg(f"Post-process RGB stats: mean={arr.mean():.2f} std={arr.std():.2f} min={arr.min()} max={arr.max()}")
165
+
166
+ return image
167
+ except Exception as e:
168
+ raise ValueError(f"Error processing DICOM file: {str(e)}")
169
+
170
+
171
+ def process_uploaded_image(file_obj):
172
+ """
173
+ Process uploaded image file (JPEG, PNG, or DICOM)
174
 
175
+ Args:
176
+ file_obj: File object from Gradio upload
177
+
178
+ Returns:
179
+ PIL Image object
180
+ """
181
+ if file_obj is None:
182
+ return None
183
 
184
+ try:
185
+ # Check file extension
186
+ file_name = getattr(file_obj, 'name', '').lower()
187
+
188
+ if file_name.endswith(('.dcm', '.dicom')):
189
+ # Process as DICOM
190
+ return process_dicom_file(file_obj)
191
+ else:
192
+ # Process as regular image
193
+ if hasattr(file_obj, 'name'):
194
+ # File path provided
195
+ image = Image.open(file_obj.name)
196
+ else:
197
+ # File object provided
198
+ image = Image.open(file_obj)
199
+
200
+ # Ensure RGB
201
+ if image.mode != 'RGB':
202
+ image = image.convert('RGB')
203
+
204
+ return image
205
+
206
+ except Exception as e:
207
+ raise ValueError(f"Error processing image file: {str(e)}")
208
+
209
+
210
+ # ============================================================================
211
+ # Model Classes
212
+ # ============================================================================
213
+
214
+ class PneumoniaModelSystem:
215
+ """Flexible model system supporting both individual models and ensemble"""
216
+ def __init__(self, device='cpu'):
217
+ self.device = device
218
+ self.models = {}
219
+ self.transform = transforms.Compose([
220
+ transforms.Resize((224, 224)),
221
+ transforms.ToTensor(),
222
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
223
+ std=[0.229, 0.224, 0.225])
224
+ ])
225
+
226
+ # Model definitions with their architectures and weights
227
+ self.model_definitions = {
228
+ 'Model_A1_7CB_Appr_D': {
229
+ 'architecture': 'VGG19',
230
+ 'file': 'Model_A1_7CB_Appr_D.pt',
231
+ 'description': 'VGG19 - Model A1 with 7CB approach D'
232
+ },
233
+ 'Model_C_Appr_B': {
234
+ 'architecture': 'MobileNetV2',
235
+ 'file': 'Model_C_Appr_B.pt',
236
+ 'description': 'MobileNetV2 - Model C with approach B'
237
+ },
238
+ 'Model_F_Appr_B': {
239
+ 'architecture': 'ResNet50',
240
+ 'file': 'Model_F_Appr_B.pt',
241
+ 'description': 'ResNet50 - Model F with approach B'
242
  },
243
+ 'Model_G_Appr_B': {
244
+ 'architecture': 'EfficientNet-B0',
245
+ 'file': 'Model_G_Appr_B.pt',
246
+ 'description': 'EfficientNet-B0 - Model G with approach B'
247
+ },
248
+ 'Model_H_Appr_B': {
249
+ 'architecture': 'DenseNet101',
250
+ 'file': 'Model_H_Appr_B.pt',
251
+ 'description': 'DenseNet101 - Model H with approach B'
252
  }
253
  }
254
+
255
+ # Ensemble configuration
256
+ self.ensemble_weights = {
257
+ 'Model_A1_7CB_Appr_D': 0.30, # VGG19 - Higher weight
258
+ 'Model_C_Appr_B': 0.175, # MobileNetV2
259
+ 'Model_F_Appr_B': 0.175, # ResNet50
260
+ 'Model_G_Appr_B': 0.175, # EfficientNet-B0
261
+ 'Model_H_Appr_B': 0.175 # DenseNet101
262
+ }
263
 
 
 
 
 
 
 
 
 
 
264
 
265
+ def _create_model(self, architecture):
266
+ """Create a model instance based on architecture type"""
267
+ if architecture == 'MobileNetV2':
268
+ return MobileNetV2Model(num_classes=2).to(self.device)
269
+ elif architecture == 'ResNet50':
270
+ return ResNet50Model(num_classes=2).to(self.device)
271
+ elif architecture == 'EfficientNet-B0':
272
+ return EfficientNetB0Model(num_classes=2).to(self.device)
273
+ elif architecture == 'VGG19':
274
+ return VGG19Model(num_classes=2).to(self.device)
275
+ elif architecture == 'DenseNet101':
276
+ return DenseNet101Model(num_classes=2).to(self.device)
277
+ else:
278
+ raise ValueError(f"Unknown architecture: {architecture}")
 
 
 
 
279
 
280
+ def load_models(self, model_dir='models'):
281
+ """Load all available models from directory"""
282
+ model_dir = Path(model_dir)
283
+ loaded_models = {}
284
+
285
+ load_kwargs = {"map_location": self.device, "weights_only": False}
286
+
287
+ for model_name, model_info in self.model_definitions.items():
288
+ model_path = model_dir / model_info['file']
289
+ if model_path.exists():
290
+ try:
291
+ model = self._create_model(model_info['architecture'])
292
+ model.load_state_dict(torch.load(model_path, **load_kwargs))
293
+ model.eval()
294
+ loaded_models[model_name] = {
295
+ 'model': model,
296
+ 'info': model_info
297
+ }
298
+ _dbg(f"Loaded {model_name} ({model_info['architecture']})")
299
+ except Exception as e:
300
+ print(f"Warning: Could not load {model_name}: {str(e)}")
301
+ else:
302
+ print(f"Warning: Model file not found: {model_path}")
303
+
304
+ self.models = loaded_models
305
+ return self
306
 
307
+ def get_available_models(self):
308
+ """Get list of available model names"""
309
+ return list(self.models.keys())
310
+
311
+ def predict_single_model(self, image, model_name):
312
+ """
313
+ Predict using a single specified model
314
+
315
+ Args:
316
+ image: PIL Image
317
+ model_name: Name of the model to use
318
+
319
+ Returns:
320
+ dict with predictions and probabilities
321
+ """
322
+ if model_name not in self.models:
323
+ raise ValueError(f"Model {model_name} not available")
324
+
325
+ # Convert to PIL if needed
326
+ if isinstance(image, np.ndarray):
327
+ image = Image.fromarray(image)
328
+
329
+ # Ensure RGB
330
+ if image.mode != 'RGB':
331
+ image = image.convert('RGB')
332
+
333
+ # Transform
334
+ img_tensor = self.transform(image).unsqueeze(0).to(self.device)
335
+
336
+ model = self.models[model_name]['model']
337
+ model_info = self.models[model_name]['info']
338
+
339
+ with torch.no_grad():
340
+ outputs = model(img_tensor)
341
+ probs = torch.softmax(outputs, dim=1)
342
+
343
+ probs_numpy = probs[0].cpu().numpy()
344
+ prediction_index = int(probs_numpy.argmax())
345
+ predicted_label = 'PNEUMONIA' if prediction_index == 1 else 'NORMAL'
346
+
347
+ _dbg(f"{model_name} logits={outputs.cpu().numpy()} probs={probs_numpy} label={predicted_label}")
348
+
349
+ result = {
350
+ 'prediction': predicted_label,
351
+ 'confidence': float(probs_numpy[prediction_index]),
352
+ 'pneumonia_probability': float(probs_numpy[1]),
353
+ 'normal_probability': float(probs_numpy[0]),
354
+ 'model_used': model_name,
355
+ 'model_architecture': model_info['architecture'],
356
+ 'model_description': model_info['description']
357
+ }
358
+ return result
359
 
360
+ def predict_ensemble(self, image, selected_models=None):
361
+ """
362
+ Predict using ensemble of models
363
+
364
+ Args:
365
+ image: PIL Image
366
+ selected_models: List of model names to include in ensemble, or None for all
367
+
368
+ Returns:
369
+ dict with predictions and probabilities
370
+ """
371
+ if selected_models is None:
372
+ selected_models = list(self.ensemble_weights.keys())
373
+
374
+ # Filter to only available models
375
+ available_models = [m for m in selected_models if m in self.models]
376
+ if not available_models:
377
+ raise ValueError("No valid models available for ensemble")
378
+
379
+ # Convert to PIL if needed
380
+ if isinstance(image, np.ndarray):
381
+ image = Image.fromarray(image)
382
 
383
+ # Ensure RGB
384
+ if image.mode != 'RGB':
385
+ image = image.convert('RGB')
386
 
387
+ # Transform
388
+ img_tensor = self.transform(image).unsqueeze(0).to(self.device)
389
 
390
+ # Get predictions from each model
391
+ per_model = []
392
+ with torch.no_grad():
393
+ ensemble_probs = torch.zeros(1, 2).to(self.device)
394
+ total_weight = 0
395
+
396
+ for model_name in available_models:
397
+ model = self.models[model_name]['model']
398
+ weight = self.ensemble_weights.get(model_name, 1.0)
399
+
400
+ outputs = model(img_tensor)
401
+ probs = torch.softmax(outputs, dim=1)
402
+ ensemble_probs += weight * probs
403
+ total_weight += weight
404
+
405
+ per_model.append({
406
+ 'model_name': model_name,
407
+ 'architecture': self.models[model_name]['info']['architecture'],
408
+ 'weight': weight,
409
+ 'logits': outputs.detach().cpu().numpy().tolist(),
410
+ 'probs': probs.detach().cpu().numpy().tolist()
411
+ })
412
+ _dbg(f"{model_name} weight={weight} logits={outputs.cpu().numpy()} probs={probs.cpu().numpy()}")
413
+
414
+ # Normalize by total weight
415
+ if total_weight > 0:
416
+ ensemble_probs /= total_weight
417
+
418
+ probs_numpy = ensemble_probs[0].cpu().numpy()
419
+ prediction_index = int(probs_numpy.argmax())
420
+ predicted_label = 'PNEUMONIA' if prediction_index == 1 else 'NORMAL'
421
+ _dbg(f"Ensemble probs={probs_numpy} predicted_index={prediction_index} label={predicted_label}")
422
+
423
+ result = {
424
+ 'prediction': predicted_label,
425
+ 'confidence': float(probs_numpy[prediction_index]),
426
+ 'pneumonia_probability': float(probs_numpy[1]),
427
+ 'normal_probability': float(probs_numpy[0]),
428
+ 'models_used': available_models,
429
+ 'per_model': per_model,
430
+ 'prediction_type': 'ensemble'
431
+ }
432
+ return result
433
 
 
 
434
 
435
+ # ============================================================================
436
+ # Global Model System Instance
437
+ # ============================================================================
 
 
438
 
439
+ # Initialize model system
440
+ model_system = PneumoniaModelSystem(device='cpu')
 
 
441
 
442
+ # Try to load models
443
+ try:
444
+ model_system.load_models('models')
445
+ available_models = model_system.get_available_models()
446
+ print(f"Loaded {len(available_models)} models: {available_models}")
447
+ except Exception as e:
448
+ print(f"Warning: Could not load models - {str(e)}")
449
+ print(" Running in demo mode")
450
+ available_models = []
451
 
 
 
452
 
453
+ # ============================================================================
454
+ # Gradio Interface Functions
455
+ # ============================================================================
456
 
457
+ def predict_pneumonia(file_path, selected_model_option):
458
  """
459
+ Main prediction function for Gradio
460
+
461
+ Args:
462
+ file_path: File path from Gradio file upload
463
+ selected_model_option: Selected model from radio button
464
+
465
+ Returns:
466
+ tuple: (processed_image, result_text, probability_dict, confidence_html)
467
  """
468
+ if file_path is None:
469
+ return None, "Please upload an X-ray image", {}, ""
 
470
 
471
+ try:
472
+ # Process uploaded image (handles JPEG, PNG, DICOM)
473
+ class FileObj:
474
+ def __init__(self, path):
475
+ self.name = path
476
+
477
+ file_obj = FileObj(file_path)
478
+ processed_image = process_uploaded_image(file_obj)
479
+
480
+ if processed_image is None:
481
+ return None, "Error: Could not process the uploaded image", {}, ""
482
+
483
+ # Get prediction based on selected model option
484
+ if selected_model_option == "Ensemble (All Models)":
485
+ result = model_system.predict_ensemble(processed_image)
486
+ model_info = f"Ensemble of {len(result['models_used'])} models"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
487
  else:
488
+ # Individual model prediction
489
+ result = model_system.predict_single_model(processed_image, selected_model_option)
490
+ model_info = f"{result['model_architecture']} - {result['model_description']}"
491
+
492
+ # Format result text
493
+ result_text = f"## Prediction: {result['prediction']}\n\n"
494
+ result_text += f"**Confidence:** {result['confidence']*100:.2f}%\n\n"
495
+ result_text += f"**Model Used:** {model_info}\n\n"
496
+
497
+ if result['prediction'] == 'PNEUMONIA':
498
+ result_text += "⚠️ **Pneumonia detected**\n\n"
499
+ result_text += "This X-ray shows signs consistent with pneumonia. "
500
+ result_text += "Please consult a qualified radiologist for confirmation."
501
+ else:
502
+ result_text += "✓ **No pneumonia detected**\n\n"
503
+ result_text += "This X-ray appears normal. "
504
+ result_text += "However, always consult a healthcare professional for accurate diagnosis."
505
+
506
+ # Create probability dictionary for bar chart
507
+ prob_dict = {
508
+ "Normal": result['normal_probability'],
509
+ "Pneumonia": result['pneumonia_probability']
510
+ }
511
 
512
+ # Create confidence HTML with color coding
513
+ confidence_pct = result['confidence'] * 100
514
+ if confidence_pct >= 90:
515
+ color = "green"
516
+ level = "Very High"
517
+ elif confidence_pct >= 75:
518
+ color = "blue"
519
+ level = "High"
520
+ elif confidence_pct >= 60:
521
+ color = "orange"
522
+ level = "Moderate"
523
+ else:
524
+ color = "red"
525
+ level = "Low"
526
+
527
+ confidence_html = f"""
528
+ <div style="padding: 20px; border-radius: 10px; background-color: #f0f0f0;">
529
+ <h3 style="color: {color};">Confidence Level: {level}</h3>
530
+ <p style="font-size: 24px; color: {color}; font-weight: bold;">{confidence_pct:.1f}%</p>
531
+ <p style="font-size: 12px; color: #666;">
532
+ Model: {model_info}
533
+ </p>
534
+ </div>
535
+ """
536
 
537
+ return processed_image, result_text, prob_dict, confidence_html
 
 
538
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
539
  except Exception as e:
540
+ error_msg = f"Error processing image: {str(e)}"
541
+ return None, error_msg, {}, f"<p style='color: red;'>{error_msg}</p>"
 
542
 
 
 
 
 
 
543
 
544
+ def get_model_system_info():
545
+ """Return information about the model system"""
546
+ info = f"""
547
+ # Model System Information
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
548
 
549
+ **Available Models:** {len(available_models)}
550
+
551
+ **Individual Models:**
552
+ """
553
+ for model_name in available_models:
554
+ model_info = model_system.model_definitions[model_name]
555
+ info += f"\n- **{model_name}**: {model_info['architecture']} - {model_info['description']}"
556
+
557
+ info += f"""
558
+
559
+ **Ensemble Configuration:**
560
+ - Uses weighted voting from multiple models
561
+ - Ensemble weights: {model_system.ensemble_weights}
562
+
563
+ **Performance Expectations:**
564
+ - Individual models may have varying strengths
565
+ - Ensemble typically provides more robust predictions
566
+ - Model selection allows comparing different approaches
567
+
568
+ **Best Practices:**
569
+ - Use ensemble for most reliable results
570
+ - Compare individual models to understand prediction confidence
571
+ - Consider model architecture strengths for specific cases
572
+ """
573
+
574
+ return info
575
+
576
+
577
+ # ============================================================================
578
+ # Gradio Interface
579
+ # ============================================================================
580
+
581
+ # Custom CSS
582
+ custom_css = """
583
+ .gradio-container {
584
+ font-family: 'Arial', sans-serif;
585
+ }
586
+ .output-markdown h2 {
587
+ color: #2c3e50;
588
+ }
589
+ """
590
+
591
+ # Create Gradio interface
592
+ with gr.Blocks(css=custom_css, title="Pneumonia Detection AI") as demo:
593
+ gr.Markdown(
594
  """
595
+ # Pneumonia Detection from Chest X-rays
596
+ ### AI-Powered Individual Models and Ensemble for Medical Screening
597
+
598
+ Upload a chest X-ray image and select a model to detect signs of pneumonia using our
599
+ state-of-the-art deep learning models.
600
+
601
+ **DISCLAIMER:** This tool is for research and educational purposes only.
602
+ It should not be used as a substitute for professional medical diagnosis.
603
+ Always consult qualified healthcare professionals for medical advice.
604
+ """
605
+ )
606
+
607
+ with gr.Row():
608
+ with gr.Column(scale=1):
609
+ # Input section - Use File upload to handle DICOM properly
610
+ input_file = gr.File(
611
+ label="Upload Chest X-Ray Image (JPEG, PNG, DICOM)",
612
+ file_types=[".jpg", ".jpeg", ".png", ".dcm", ".dicom"],
613
+ type="filepath"
614
+ )
615
 
616
+ # Model selection radio button
617
+ model_options = ["Ensemble (All Models)"] + available_models
618
+ model_selector = gr.Radio(
619
+ choices=model_options,
620
+ value="Ensemble (All Models)",
621
+ label="Select Model",
622
+ info="Choose an individual model or use the ensemble of all models"
623
+ )
624
+
625
+ predict_btn = gr.Button("Analyze X-Ray", variant="primary", size="lg")
626
+
627
+ gr.Markdown("### Supported Formats")
628
+ gr.Markdown("**JPEG, PNG, DICOM** (.dcm, .dicom files)")
629
+ gr.Markdown("Try with your own chest X-ray images")
630
+
631
+ with gr.Column(scale=1):
632
+ # Preview section for processed image
633
+ preview_image = gr.Image(
634
+ label="Processed Image Preview",
635
+ height=400,
636
+ interactive=False
637
+ )
638
 
639
+ gr.Markdown("*This shows how the image appears to the AI model*")
640
+
641
+ with gr.Column(scale=1):
642
+ # Output section
643
+ output_text = gr.Markdown(label="Diagnosis Result")
644
+ confidence_html = gr.HTML(label="Confidence Level")
645
+ prob_chart = gr.Label(label="Probability Distribution", num_top_classes=2)
646
+
647
+ # Model system info accordion
648
+ with gr.Accordion("🤖 About These Models", open=False):
649
+ model_info = gr.Markdown(get_model_system_info())
650
+
651
+ # Technical details accordion
652
+ with gr.Accordion("🔬 Technical Details", open=False):
653
+ gr.Markdown(
654
+ """
655
+ ### Model Architecture
656
+
657
+ This system uses an ensemble of multiple deep learning models:
658
+ - **VGG19** (Transfer Learning): Deep convolutional network with 19 layers
659
+ - **VGG16** (Transfer Learning): 16-layer convolutional network
660
+ - **EfficientNet-B0**: Efficient architecture with compound scaling
661
+
662
+ ### How It Works
663
+
664
+ 1. **Image Preprocessing**: X-ray image is resized to 224×224 and normalized
665
+ 2. **Ensemble Prediction**: Each model independently analyzes the image
666
+ 3. **Weighted Voting**: Predictions are combined using learned weights
667
+ 4. **Confidence Score**: Final probability based on ensemble agreement
668
+
669
+ ### Performance Metrics
670
+
671
+ The ensemble has been validated on 33 different model configurations
672
+ and ranked based on clinical utility metrics.
673
+
674
+ ### Dataset
675
+
676
+ Models trained on chest X-ray dataset with thousands of images
677
+ from actual clinical cases.
678
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
679
  )
680
 
681
+ # Citation accordion
682
+ with gr.Accordion("📚 Citation & Credits", open=False):
683
+ gr.Markdown(
684
+ """
685
+ ### Citation
 
 
 
 
 
 
 
 
686
 
687
+ If you use this model in your research, please cite:
 
 
 
 
 
 
 
 
688
 
689
+ ```
690
+ @software{pneumonia_ensemble_2025,
691
+ title={Deep Learning Ensemble for Pneumonia Detection},
692
+ author={Prabakaran Thangamani},
693
+ year={2025},
694
+ url={https://huggingface.co/spaces/papsofts/pneumonia-detection}
695
+ }
696
+ ```
697
+
698
+ ### Credits
 
 
 
 
 
 
699
 
700
+ - **Model Development**: Based on 33 model architectures including VGG, EfficientNet, ResNet families
701
+ - **Framework**: PyTorch, torchvision
702
+ - **Interface**: Gradio
703
+ - **Deployment**: Hugging Face Spaces
704
 
705
+ ### License
706
+
707
+ This model is released for research and educational purposes.
708
+ """
709
+ )
710
+
711
+ # Connect the prediction function
712
+ predict_btn.click(
713
+ fn=predict_pneumonia,
714
+ inputs=[input_file, model_selector],
715
+ outputs=[preview_image, output_text, prob_chart, confidence_html]
716
+ )
717
+
718
+ # Examples (you can add actual example images here)
719
+ gr.Examples(
720
+ examples=[],
721
+ inputs=input_file,
722
+ label="Example X-Ray Images"
723
+ )
724
+
725
+
726
+ # ============================================================================
727
+ # Launch
728
+ # ============================================================================
729
 
730
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 — the defaults expected by
    # Hugging Face Spaces' container runtime.
    demo.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/Model_A1_7CB_Appr_D.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ee1e63db1ee140d6dd425400e3745918ff0aed42d1fe65c85a520d1967b5e2e
3
+ size 558329434
models/Model_C_Appr_B.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4018a7db58cf03fa3965b5195fdaef3ffee034eb445d98b6ee70ecbf4e104153
3
+ size 9163946
models/Model_F_Appr_B.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:69bf003475478056556f558cc9f3f511f99869c0ba57617a2c918776207f1dab
3
+ size 94379802
models/Model_G_Appr_B.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:05bb0b26d6bc63fe82f47e3c5ffc4c6dc0d34bd483d17b74480930f9c6583fa8
3
+ size 16359314
models/Model_H_Appr_B.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a6c44c009fde3eb81e55dd87d1a8606a325ed5fb82bcc5635f6a7c0a7fbdde2
3
+ size 28459058