papsofts commited on
Commit
91326ca
·
verified ·
1 Parent(s): 17cd064

Initial commit of Model_D_Appr_G1 and Model_I_Appr_H1

Browse files
Files changed (4) hide show
  1. .gitignore +177 -0
  2. README.md +65 -13
  3. app.py +570 -0
  4. requirements.txt +10 -0
.gitignore ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be added to the global gitignore or merged into this project gitignore. For a PyTorch
158
+ # research project template, see https://github.com/PyTorchLightning/deep-learning-project-template
159
+ # PyCharm: File | Settings | File Templates
160
+ .idea/
161
+
162
+ # VS Code
163
+ .vscode/
164
+
165
+ # Gradio temporary files
166
+ gradio_cached_examples/
167
+ flagged/
168
+
169
+ # Model files (too large for git)
170
+ *.pt
171
+ *.pth
172
+ *.bin
173
+ *.safetensors
174
+
175
+ # Temporary files
176
+ *.tmp
177
+ *.temp
README.md CHANGED
@@ -1,13 +1,65 @@
1
- ---
2
- title: Pneumonia Detection
3
- emoji: 🔥
4
- colorFrom: pink
5
- colorTo: red
6
- sdk: gradio
7
- sdk_version: 5.49.1
8
- app_file: app.py
9
- pinned: false
10
- short_description: Space for Pneumonia Detection Capstone project from GL
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Transfer Learning Inference
3
+ emoji: 🧪
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 4.44.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ # Transfer Learning Inference App
14
+
15
+ A Gradio-based inference application for transfer learning models supporting multiple architectures (ResNet, DenseNet, Inception, EfficientNet).
16
+
17
+ ## Features
18
+
19
+ - **Multi-Model Support**: Supports ResNet, DenseNet, Inception V3, and EfficientNet architectures
20
+ - **Dynamic Model Loading**: Automatically detects and loads available models from HuggingFace Hub
21
+ - **Interactive Interface**: User-friendly dropdowns for model and approach selection
22
+ - **Image Classification**: Upload images and get top-K predictions
23
+ - **Grad-CAM Visualization**: Optional attention heatmaps for model interpretability
24
+ - **Professional Results**: Clean tabular display of predictions with confidence scores
25
+
26
+ ## Supported Architectures
27
+
28
+ - ResNet (resnet18, resnet34, resnet50, resnet101, resnet152)
29
+ - DenseNet (densenet121, densenet161, densenet169, densenet201)
30
+ - Inception V3
31
+ - EfficientNet (via timm library)
32
+
33
+ ## Model Configuration
34
+
35
+ The app expects models to be available either:
36
+ 1. Uploaded as files in this repository under `models/` directory
37
+ 2. Referenced from HuggingFace Hub repositories (set via environment variables)
38
+
39
+ ## Environment Variables
40
+
41
+ - `HF_TOKEN`: HuggingFace API token for private model access (optional)
42
+ - `MODEL_REPO_ID`: HuggingFace repository ID containing model weights (optional)
43
+ - `NUM_CLASSES`: Number of output classes (default: 2)
44
+
45
+ ## Usage
46
+
47
+ 1. Select a model from the dropdown
48
+ 2. Choose an approach/variant
49
+ 3. Upload an image (JPG, PNG, etc.)
50
+ 4. Optionally enable Grad-CAM visualization
51
+ 5. Click "Run Inference" to see results
52
+
53
+ ## Local Development
54
+
55
+ ```bash
56
+ pip install -r requirements.txt
57
+ python app.py
58
+ ```
59
+
60
+ ## Model Format
61
+
62
+ Models should be saved as PyTorch state dictionaries (.pt files) with filenames following the pattern:
63
+ `Appr_{approach}_{architecture}.pt`
64
+
65
+ Example: `Appr_A_resnet50.pt`, `Appr_B_densenet121.pt`
app.py ADDED
@@ -0,0 +1,570 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Gradio Inference App for Transfer Learning Models - HuggingFace Spaces Version
2
+
3
+ Features:
4
+ - Auto-scan models directory or HuggingFace Hub for available models and approaches.
5
+ - Dropdown selection of Model and Approach.
6
+ - Dynamic architecture detection from filename (e.g., resnet50, densenet121, inception_v3, efficientnet_b0, resnet34).
7
+ - Image upload and preprocessing (ImageNet normalization).
8
+ - Top-K prediction display (configurable class labels).
9
+ - Optional Grad-CAM visualization for interpretability.
10
+ - Environment variable configuration for HuggingFace deployment.
11
+ - Graceful error handling and clear user feedback.
12
+
13
+ Environment Variables:
14
+ - HF_TOKEN: HuggingFace API token for private repositories
15
+ - MODEL_REPO_ID: HuggingFace repository containing models
16
+ - NUM_CLASSES: Number of output classes (default: 2)
17
+ - DEBUG: Enable debug logging
18
+ """
19
+ import os
20
+ import re
21
+ import logging
22
+ from typing import Dict, Tuple, List, Optional, Union
23
+ from pathlib import Path
24
+
25
+ import torch
26
+ import torch.nn.functional as F
27
+ from torchvision import models, transforms
28
+ from PIL import Image
29
+ import gradio as gr
30
+ from huggingface_hub import hf_hub_download, list_repo_files
31
+ from dotenv import load_dotenv
32
+
33
+ try:
34
+ import timm # EfficientNet etc.
35
+ except ImportError:
36
+ timm = None
37
+
38
+ try:
39
+ import cv2
40
+ except ImportError:
41
+ cv2 = None
42
+
43
+ # Load environment variables
44
+ load_dotenv()
45
+
46
+ # Configuration from environment variables
47
+ HF_TOKEN = os.getenv("HF_TOKEN")
48
+ MODEL_REPO_ID = os.getenv("MODEL_REPO_ID")
49
+ NUM_CLASSES = int(os.getenv("NUM_CLASSES", "2"))
50
+ DEBUG = os.getenv("DEBUG", "False").lower() == "true"
51
+
52
+ # Setup logging
53
+ logging.basicConfig(level=logging.DEBUG if DEBUG else logging.INFO)
54
+ logger = logging.getLogger(__name__)
55
+
56
+ # Directory paths
57
+ MODELS_DIR = Path("models")
58
+ MODELS_DIR.mkdir(exist_ok=True)
59
+
60
+ # Placeholder class labels; customize based on your dataset
61
+ CLASS_LABELS = {i: f"Class_{i}" for i in range(NUM_CLASSES)}
62
+
63
+ # Architecture-specific input sizes
64
+ _ARCH_INPUT_SIZE = {
65
+ "inception_v3": 299,
66
+ # Most others default to 224
67
+ }
68
+
69
+ # Regex to parse weight filenames: Appr_<Approach>_<arch>.pt
70
+ WEIGHT_PATTERN = re.compile(r"Appr_([A-Za-z0-9]+)_([A-Za-z0-9_]+)\.pt")
71
+
72
+ ModelMap = Dict[str, Dict[str, Dict[str, str]]]
73
+ _model_map_cache: Optional[ModelMap] = None
74
+ # Cache for instantiated models to avoid recreation per inference
75
+ _model_instance_cache: Dict[Tuple[str, str], Tuple[torch.nn.Module, str]] = {}
76
+
77
+
78
def download_models_from_hub() -> None:
    """Fetch ``*.pt`` weight files from the configured HuggingFace Hub repo.

    No-op when ``MODEL_REPO_ID`` is unset. Files already present under
    ``MODELS_DIR`` are skipped. Any failure is logged as a warning and never
    raised, so the app can still start with local-only (or demo) models.
    """
    if not MODEL_REPO_ID:
        logger.info("No MODEL_REPO_ID set, skipping Hub download")
        return

    try:
        logger.info(f"Downloading models from {MODEL_REPO_ID}")
        repo_files = list_repo_files(MODEL_REPO_ID, token=HF_TOKEN)
        model_files = [f for f in repo_files if f.endswith('.pt')]

        for model_file in model_files:
            local_path = MODELS_DIR / model_file
            if local_path.exists():
                continue
            logger.info(f"Downloading {model_file}")
            # Use local_dir so the real file lands directly under MODELS_DIR
            # (mirroring any repo sub-folders). The previous cache_dir +
            # Path.rename approach moved only the cache's returned path (often
            # a symlink) and left a HF cache tree inside models/.
            hf_hub_download(
                MODEL_REPO_ID,
                model_file,
                local_dir=str(MODELS_DIR),
                token=HF_TOKEN,
            )

    except Exception as e:
        logger.warning(f"Failed to download from Hub: {e}")
105
+
106
+
107
def scan_local_models() -> ModelMap:
    """Build the model map from weight files on disk (after an optional Hub sync).

    Recognizes two layouts:
      * nested: ``models/<ModelName>/Appr_<code>_<arch>.pt``
      * flat:   ``models/Appr_<code>_<arch>.pt`` (model named ``Model_<arch>``)
    """
    discovered: ModelMap = {}

    # Pull any configured remote weights down first so they are visible below.
    download_models_from_hub()

    if not MODELS_DIR.exists():
        return discovered

    for entry in MODELS_DIR.iterdir():
        if entry.is_dir():
            # Nested layout: one sub-directory per logical model.
            for weight_file in entry.glob("*.pt"):
                parsed = WEIGHT_PATTERN.match(weight_file.name)
                if not parsed:
                    continue
                code, arch_name = parsed.groups()
                discovered.setdefault(entry.name, {})[code] = {
                    "path": str(weight_file),
                    "arch": arch_name.lower(),
                }
        elif entry.suffix == ".pt":
            # Flat layout: derive the model name from the architecture.
            parsed = WEIGHT_PATTERN.match(entry.name)
            if parsed:
                code, arch_name = parsed.groups()
                discovered.setdefault(f"Model_{arch_name}", {})[code] = {
                    "path": str(entry),
                    "arch": arch_name.lower(),
                }

    return discovered
140
+
141
+
142
def scan_weights(refresh: bool = False) -> ModelMap:
    """Return the cached model map, rebuilding it on first use or on refresh.

    Args:
        refresh: When True, bypass the cache and rescan disk/Hub.
    """
    global _model_map_cache
    if not refresh and _model_map_cache is not None:
        return _model_map_cache

    mapping = scan_local_models()

    if not mapping:
        # Keep the UI usable even with zero weights on disk: expose two demo
        # entries that fall back to randomly initialized networks.
        logger.warning("No models found, creating demo entries")
        mapping = {
            "ResNet34": {"G1": {"path": "", "arch": "resnet34"}},
            "InceptionV3": {"H1": {"path": "", "arch": "inception_v3"}},
        }

    _model_map_cache = mapping
    logger.info(f"Found models: {list(mapping.keys())}")
    return mapping
165
+
166
+
167
def _create_model(arch: str, num_classes: int) -> torch.nn.Module:
    """Build an untrained network for *arch* with a *num_classes*-way head.

    Raises:
        RuntimeError: EfficientNet requested but timm is missing.
        ValueError: Architecture is not supported by torchvision or timm.
    """
    arch = arch.lower()
    logger.debug(f"Creating model: {arch} with {num_classes} classes")

    if arch.startswith("resnet"):
        net = getattr(models, arch)(weights=None)
        net.fc = torch.nn.Linear(net.fc.in_features, num_classes)
        return net

    if arch.startswith("densenet"):
        net = getattr(models, arch)(weights=None)
        net.classifier = torch.nn.Linear(net.classifier.in_features, num_classes)
        return net

    if arch.startswith("inception_v3"):
        # Disable aux logits for lighter inference path
        net = models.inception_v3(weights=None, aux_logits=False)
        net.fc = torch.nn.Linear(net.fc.in_features, num_classes)
        return net

    if arch.startswith("efficientnet"):
        if timm is None:
            raise RuntimeError("timm not installed; cannot create EfficientNet model.")
        return timm.create_model(arch, pretrained=False, num_classes=num_classes)

    # Last resort: let timm try to resolve any other architecture name.
    if timm is not None:
        try:
            return timm.create_model(arch, pretrained=False, num_classes=num_classes)
        except Exception as e:
            logger.warning(f"timm failed to create {arch}: {e}")

    raise ValueError(f"Unsupported architecture: {arch}")
199
+
200
+
201
+ def _resolve_device() -> torch.device:
202
+ """Resolve device based on DEVICE env var (cpu|cuda|auto)."""
203
+ device_pref = os.getenv("DEVICE", "auto").lower()
204
+ if device_pref == "cpu":
205
+ return torch.device("cpu")
206
+ if device_pref == "cuda":
207
+ if torch.cuda.is_available():
208
+ return torch.device("cuda")
209
+ logger.warning("DEVICE=cuda requested but CUDA not available; falling back to CPU")
210
+ return torch.device("cpu")
211
+ # auto
212
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
213
+
214
+
215
def load_model(model_name: str, approach: str, use_cache: bool = True) -> Tuple[torch.nn.Module, str]:
    """Resolve, build, and (if weights exist) load the requested model.

    Args:
        model_name: Key in the scanned model map.
        approach: Approach code within that model's entry.
        use_cache: Reuse a previously instantiated (model, arch) for this pair.

    Returns:
        Tuple of (model in eval mode on the resolved device, architecture name).

    Raises:
        ValueError: If the model name or approach is unknown.
    """
    cache_key = (model_name, approach)
    if use_cache and cache_key in _model_instance_cache:
        return _model_instance_cache[cache_key]

    mapping = scan_weights()
    if model_name not in mapping:
        raise ValueError(f"Model '{model_name}' not found in {list(mapping.keys())}")
    if approach not in mapping[model_name]:
        raise ValueError(
            f"Approach '{approach}' not found for model '{model_name}'. Available: {list(mapping[model_name].keys())}"
        )

    info = mapping[model_name][approach]
    arch = info["arch"]
    weight_path = info["path"]

    device = _resolve_device()
    logger.info(f"Loading {model_name}/{approach} ({arch}) on {device}")

    model = _create_model(arch, NUM_CLASSES)

    # Load weights if path exists; demo entries have an empty path.
    if weight_path and Path(weight_path).exists():
        try:
            # Prefer weights_only=True: checkpoints may come from a remote Hub
            # repo, and an unrestricted torch.load can execute arbitrary code
            # via pickle. Fall back (with a warning) for legacy checkpoints
            # containing non-tensor objects.
            try:
                state = torch.load(weight_path, map_location=device, weights_only=True)
            except Exception:
                logger.warning(
                    "weights_only load failed for %s; retrying unrestricted (file assumed trusted)",
                    weight_path,
                )
                state = torch.load(weight_path, map_location=device)

            if isinstance(state, dict) and "state_dict" in state:
                state = state["state_dict"]

            # Strip DataParallel/DistributedDataParallel "module." prefixes.
            new_state = {k.replace("module.", ""): v for k, v in state.items()}

            missing, unexpected = model.load_state_dict(new_state, strict=False)
            if missing:
                logger.warning(f"Missing keys: {missing[:5]}...")
            if unexpected:
                logger.warning(f"Unexpected keys: {unexpected[:5]}...")

        except Exception as e:
            logger.warning(f"Failed to load weights from {weight_path}: {e}")
    else:
        logger.warning(f"No weights file found at {weight_path}, using random weights")

    model.to(device)
    model.eval()

    if use_cache:
        _model_instance_cache[cache_key] = (model, arch)

    return model, arch
267
+
268
+
269
+ _image_cache_transform: Dict[int, transforms.Compose] = {}
270
+
271
+
272
def get_transform(arch: str) -> transforms.Compose:
    """Return (memoized by input size) the preprocessing pipeline for *arch*."""
    size = _ARCH_INPUT_SIZE.get(arch, 224)
    cached = _image_cache_transform.get(size)
    if cached is not None:
        return cached

    pipeline = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        # Standard ImageNet normalization constants.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    _image_cache_transform[size] = pipeline
    return pipeline
285
+
286
+
287
+ # Grad-CAM utilities
288
class GradCAM:
    """Grad-CAM implementation for generating attention heatmaps.

    Fix over the original: hooks are now registered only for the duration of
    :meth:`generate` and removed afterwards. Previously they were attached in
    ``__init__`` and never removed, so building a fresh ``GradCAM`` per
    request on a long-lived cached model accumulated stale hooks forever.
    """

    def __init__(self, model: torch.nn.Module, target_layer: Optional[str] = None):
        self.model = model
        self.model.eval()
        self.target_layer = target_layer
        # Captured during generate() by the forward/backward hooks.
        self.activations = None
        self.gradients = None

        if target_layer is None:
            # Heuristically pick the last conv stage for common architectures
            # (ResNet layer4, DenseNet denseblock4, timm EfficientNet blocks).
            layer = None
            layer_candidates = ["layer4", "features.denseblock4", "blocks.6", "conv_head", "features"]
            for cand in layer_candidates:
                current = model
                try:
                    for part in cand.split('.'):
                        current = getattr(current, part)
                except AttributeError:
                    continue
                layer = current
                break
            self.target_module = layer if layer is not None else model
        else:
            self.target_module = dict(model.named_modules()).get(target_layer, model)

    def _register_hooks(self) -> list:
        """Attach capture hooks to the target module; return the handles."""
        def fwd_hook(_, __, output):
            self.activations = output.detach()

        def bwd_hook(_, grad_input, grad_output):
            if grad_output[0] is not None:
                self.gradients = grad_output[0].detach()

        handles = [self.target_module.register_forward_hook(fwd_hook)]
        try:
            # Non-deprecated full backward hook where available.
            handles.append(self.target_module.register_full_backward_hook(bwd_hook))
        except AttributeError:
            # Fallback for older torch without full backward hooks.
            handles.append(self.target_module.register_backward_hook(bwd_hook))
        return handles

    def generate(self, tensor: torch.Tensor, class_idx: Optional[int] = None) -> torch.Tensor:
        """Generate an [H, W] Grad-CAM heatmap in [0, 1].

        Args:
            tensor: Input batch of shape [1, C, H, W].
            class_idx: Class to explain; defaults to the argmax prediction.

        Raises:
            RuntimeError: If hooks failed to capture activations/gradients.
        """
        handles = self._register_hooks()
        try:
            tensor = tensor.requires_grad_(True)
            logits = self.model(tensor)
            if isinstance(logits, tuple):  # Inception may return (logits, aux)
                logits = logits[0]

            if class_idx is None:
                class_idx = logits.argmax(dim=1).item()

            score = logits[:, class_idx]
            score.backward(retain_graph=True)

            grads = self.gradients  # [B, C, H, W]
            acts = self.activations

            if grads is None or acts is None:
                raise RuntimeError("GradCAM hooks did not capture activations/gradients")

            # Channel weights = global-average-pooled gradients.
            weights = grads.mean(dim=(2, 3), keepdim=True)  # [B, C, 1, 1]
            cam = (weights * acts).sum(dim=1, keepdim=True)
            cam = F.relu(cam)
            cam = F.interpolate(cam, size=tensor.shape[2:], mode="bilinear", align_corners=False)

            # Normalize to [0, 1] (epsilon guards an all-zero map).
            cam_min, cam_max = cam.min(), cam.max()
            cam = (cam - cam_min) / (cam_max - cam_min + 1e-8)
            return cam.squeeze(0).squeeze(0)  # [H, W]
        finally:
            for handle in handles:
                handle.remove()
361
+
362
+
363
def predict(image: Image.Image, model_name: str, approach: str, grad_cam: bool = False, top_k: int = 5):
    """Run inference on the uploaded image.

    Returns:
        (rows, original_image, cam_image) where ``rows`` is a list of
        [label, probability] pairs matching the 2-column results Dataframe,
        and ``cam_image`` is None unless Grad-CAM was requested and succeeded.

    Fix over the original: error rows previously had 3 columns
    (["Error", 0.0, message]) while the Dataframe declares only 2 headers;
    the message is now folded into the label column.
    """
    if image is None:
        return [], None, None

    try:
        model, arch = load_model(model_name, approach, use_cache=True)
    except Exception as e:
        logger.error(f"Model loading failed: {e}")
        return [[f"Error: {e}", 0.0]], image, None

    try:
        transform = get_transform(arch)
        tensor = transform(image).unsqueeze(0)
        device = next(model.parameters()).device
        tensor = tensor.to(device)

        with torch.no_grad():
            out = model(tensor)
            if isinstance(out, tuple):  # Inception
                out = out[0]
            probs = F.softmax(out, dim=1).cpu().squeeze(0)

        top_k = min(top_k, probs.shape[0])
        top_probs, top_indices = torch.topk(probs, top_k)

        results = []
        for p, idx in zip(top_probs.tolist(), top_indices.tolist()):
            label = CLASS_LABELS.get(idx, f"Class_{idx}")
            results.append([label, round(p, 4)])

        cam_img = None
        if grad_cam:
            try:
                gcam = GradCAM(model)
                cam = gcam.generate(tensor)

                import numpy as np
                if cv2 is not None:
                    # PIL resize takes (width, height); cam is [H, W].
                    base_img = image.resize((cam.shape[1], cam.shape[0]))
                    base_arr = np.array(base_img)
                    heatmap = (cam.cpu().numpy() * 255).astype('uint8')
                    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
                    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
                    overlay = (0.4 * heatmap + 0.6 * base_arr).astype('uint8')
                    cam_img = Image.fromarray(overlay)
                else:
                    # Fallback without OpenCV: plain grayscale heatmap.
                    cam_np = cam.cpu().numpy()
                    cam_img = Image.fromarray((cam_np * 255).astype('uint8'))

            except Exception as e:
                logger.warning(f"Grad-CAM failed: {e}")
                # Solid red placeholder signals a Grad-CAM failure to the user.
                cam_img = Image.new('RGB', (224, 224), color=(255, 100, 100))

        return results, image, cam_img

    except Exception as e:
        logger.error(f"Prediction failed: {e}")
        return [[f"Error: {e}", 0.0]], image, None
427
+
428
+
429
def build_interface():
    """Build the Gradio Blocks UI and wire the dropdown/inference callbacks.

    Fix over the original: ``gr.Image`` does not accept an ``info`` kwarg
    (unlike form components such as Dropdown/Checkbox), which raised a
    TypeError at startup; the hint is moved into the image's label.
    """
    mapping = scan_weights()
    model_choices = sorted(mapping.keys())

    with gr.Blocks(
        title="Transfer Learning Inference",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            max-width: 1200px;
            margin: auto;
        }
        """
    ) as demo:
        gr.Markdown(
            """
            # 🧪 Transfer Learning Inference

            Select a pre-trained model and approach, upload an image, and view predictions with optional attention visualization.

            **Available Models:** ResNet, DenseNet, Inception V3, EfficientNet
            """
        )

        with gr.Row():
            with gr.Column(scale=1):
                model_dd = gr.Dropdown(
                    choices=model_choices,
                    label="Model",
                    value=model_choices[0] if model_choices else None,
                    info="Select the model architecture"
                )
                approach_dd = gr.Dropdown(
                    choices=[],
                    label="Approach",
                    info="Select the training approach/variant"
                )
                grad_cam_cb = gr.Checkbox(
                    label="Generate Grad-CAM",
                    value=False,
                    info="Show attention heatmap overlay"
                )

            with gr.Column(scale=2):
                # NOTE: gr.Image has no `info` parameter; hint lives in the label.
                image_in = gr.Image(
                    type="pil",
                    label="Input Image (upload an image for classification)"
                )

        submit_btn = gr.Button("🚀 Run Inference", variant="primary", size="lg")

        with gr.Row():
            with gr.Column():
                results_out = gr.Dataframe(
                    headers=["Label", "Probability"],
                    datatype=["str", "number"],
                    label="🎯 Top Predictions",
                    interactive=False
                )

            with gr.Column():
                with gr.Row():
                    original_out = gr.Image(label="📷 Original Image", interactive=False)
                    cam_out = gr.Image(label="🔥 Grad-CAM Overlay", interactive=False)

        # Add model info
        with gr.Accordion("ℹ️ Model Information", open=False):
            gr.Markdown(f"""
            - **Number of Classes:** {NUM_CLASSES}
            - **Available Models:** {len(model_choices)}
            - **Environment:** {'HuggingFace Spaces' if MODEL_REPO_ID else 'Local'}
            """)

        def update_approaches(selected_model):
            # Refresh the approach dropdown for the chosen model.
            if not selected_model:
                return gr.update(choices=[], value=None)
            mapping_local = scan_weights()
            apprs = sorted(mapping_local.get(selected_model, {}).keys())
            value = apprs[0] if apprs else None
            return gr.update(choices=apprs, value=value)

        model_dd.change(fn=update_approaches, inputs=model_dd, outputs=approach_dd)

        submit_btn.click(
            fn=predict,
            inputs=[image_in, model_dd, approach_dd, grad_cam_cb],
            outputs=[results_out, original_out, cam_out],
        )

        # Initialize approaches for the first model on page load.
        if model_choices:
            demo.load(
                fn=lambda: update_approaches(model_choices[0]),
                inputs=None,
                outputs=approach_dd
            )

    return demo
529
+
530
+
531
def main():
    """Configure and launch the Gradio app using env-provided server settings."""
    logger.info("Starting Transfer Learning Inference App")
    app = build_interface()

    # Defaults suit HuggingFace Spaces (bind all interfaces, port 7860).
    host = os.getenv("GRADIO_SERVER_NAME", "0.0.0.0")
    port = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
    want_share = os.getenv("GRADIO_SHARE", "False").lower() == "true"

    app.launch(
        server_name=host,
        server_port=port,
        share=want_share,
        show_error=True,
    )
547
+
548
+
549
+ if __name__ == "__main__":
550
+ # Quick internal test if requested
551
+ if os.environ.get("HEADLESS_TEST") == "1":
552
+ logger.info("Running headless test")
553
+ mapping = scan_weights()
554
+ if mapping:
555
+ first_model = next(iter(mapping))
556
+ first_appr = next(iter(mapping[first_model]))
557
+ try:
558
+ model, arch = load_model(first_model, first_appr)
559
+ size = _ARCH_INPUT_SIZE.get(arch, 224)
560
+ x = torch.randn(1, 3, size, size)
561
+ out = model(x)
562
+ if isinstance(out, tuple):
563
+ out = out[0]
564
+ logger.info(f"Test forward output shape: {out.shape}")
565
+ except Exception as e:
566
+ logger.error(f"Test failed: {e}")
567
+ else:
568
+ logger.warning("No models found for testing")
569
+ else:
570
+ main()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.2.1
2
+ torchvision==0.17.1
3
+ timm==0.9.12
4
+ gradio==4.44.0
5
+ Pillow==10.4.0
6
+ opencv-python-headless==4.9.0.80
7
+ numpy==1.26.4
8
+ huggingface-hub==0.23.4
9
+ python-dotenv==1.0.1
10
+ requests==2.32.3