Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,82 +3,82 @@ import torch
|
|
| 3 |
import matplotlib.pyplot as plt
|
| 4 |
import numpy as np
|
| 5 |
from fastai.vision.all import load_learner, PILImage
|
| 6 |
-
from fastai.vision.utils import show_image
|
| 7 |
import io
|
| 8 |
-
from torchvision.transforms.functional import to_pil_image
|
| 9 |
|
| 10 |
-
#
|
| 11 |
class Hook:
|
| 12 |
-
def __init__(self,
|
| 13 |
-
self.hook =
|
| 14 |
def remove(self): self.hook.remove()
|
| 15 |
|
| 16 |
class HookBwd:
|
| 17 |
-
def __init__(self,
|
| 18 |
-
self.hook =
|
| 19 |
def remove(self): self.hook.remove()
|
| 20 |
|
| 21 |
# Load the learner
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
# Function to predict + generate CAM
|
| 25 |
def predict_with_cam(img):
|
| 26 |
img = PILImage.create(img)
|
| 27 |
|
| 28 |
-
# Get
|
| 29 |
model = learn.model
|
| 30 |
-
target_layer = model[0][-1] #
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
# Hook functions
|
| 37 |
def hook_activations(out): activations.append(out)
|
| 38 |
def hook_gradients(grad): gradients.append(grad)
|
| 39 |
-
|
| 40 |
-
#
|
| 41 |
h1 = Hook(target_layer, hook_activations)
|
| 42 |
h2 = HookBwd(target_layer, hook_gradients)
|
| 43 |
-
|
| 44 |
-
#
|
| 45 |
pred_class, pred_idx, probs = learn.predict(img)
|
| 46 |
-
|
| 47 |
-
#
|
| 48 |
-
|
|
|
|
|
|
|
| 49 |
output[0, pred_idx].backward()
|
| 50 |
-
|
| 51 |
# Remove hooks
|
| 52 |
h1.remove()
|
| 53 |
h2.remove()
|
| 54 |
-
|
| 55 |
-
# Generate CAM
|
| 56 |
-
act = activations[0].detach().cpu()
|
| 57 |
-
grad = gradients[0].detach().cpu()
|
| 58 |
-
|
| 59 |
weights = grad.mean(dim=(1, 2), keepdim=True)
|
| 60 |
cam = (weights * act).sum(0)
|
| 61 |
cam = cam.clamp(min=0).numpy()
|
| 62 |
-
|
| 63 |
# Normalize CAM
|
| 64 |
-
cam
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
# Convert to image
|
| 68 |
fig, ax = plt.subplots()
|
| 69 |
ax.imshow(img)
|
| 70 |
ax.imshow(cam, alpha=0.5, cmap='jet')
|
| 71 |
ax.axis('off')
|
| 72 |
-
|
| 73 |
-
# Save CAM
|
| 74 |
buf = io.BytesIO()
|
| 75 |
-
plt.savefig(buf, format='png')
|
| 76 |
buf.seek(0)
|
| 77 |
-
|
| 78 |
-
# Return predictions + CAM image
|
| 79 |
return {learn.dls.vocab[i]: float(probs[i]) for i in range(len(probs))}, buf
|
| 80 |
|
| 81 |
-
# Gradio interface
|
| 82 |
interface = gr.Interface(
|
| 83 |
fn=predict_with_cam,
|
| 84 |
inputs=gr.Image(type='pil'),
|
|
|
|
| 3 |
import matplotlib.pyplot as plt
|
| 4 |
import numpy as np
|
| 5 |
from fastai.vision.all import load_learner, PILImage
|
|
|
|
| 6 |
import io
|
|
|
|
| 7 |
|
| 8 |
+
# Ensure custom classes exist before loading the model
|
| 9 |
class Hook:
|
| 10 |
+
def __init__(self, module, func):
|
| 11 |
+
self.hook = module.register_forward_hook(lambda mod, inp, out: func(out))
|
| 12 |
def remove(self): self.hook.remove()
|
| 13 |
|
| 14 |
class HookBwd:
|
| 15 |
+
def __init__(self, module, func):
|
| 16 |
+
self.hook = module.register_full_backward_hook(lambda mod, grad_input, grad_output: func(grad_output[0]))
|
| 17 |
def remove(self): self.hook.remove()
|
| 18 |
|
| 19 |
# Load the learner
|
| 20 |
+
try:
|
| 21 |
+
learn = load_learner('export.pkl')
|
| 22 |
+
print("Model loaded successfully!")
|
| 23 |
+
except Exception as e:
|
| 24 |
+
print(f"Error loading model: {e}")
|
| 25 |
|
| 26 |
+
# Function to predict + generate Class Activation Map (CAM)
|
| 27 |
def predict_with_cam(img):
|
| 28 |
img = PILImage.create(img)
|
| 29 |
|
| 30 |
+
# Get model and target layer (modify as needed for your architecture)
|
| 31 |
model = learn.model
|
| 32 |
+
target_layer = model[0][-1] # Adjust based on model architecture
|
| 33 |
+
|
| 34 |
+
activations, gradients = [], []
|
| 35 |
+
|
| 36 |
+
# Define hook functions
|
|
|
|
|
|
|
| 37 |
def hook_activations(out): activations.append(out)
|
| 38 |
def hook_gradients(grad): gradients.append(grad)
|
| 39 |
+
|
| 40 |
+
# Attach hooks
|
| 41 |
h1 = Hook(target_layer, hook_activations)
|
| 42 |
h2 = HookBwd(target_layer, hook_gradients)
|
| 43 |
+
|
| 44 |
+
# Run prediction
|
| 45 |
pred_class, pred_idx, probs = learn.predict(img)
|
| 46 |
+
|
| 47 |
+
# Perform backward pass for gradients
|
| 48 |
+
img_tensor = learn.dls.test_dl([img]).one_batch()[0]
|
| 49 |
+
img_tensor.requires_grad_()
|
| 50 |
+
output = model(img_tensor)
|
| 51 |
output[0, pred_idx].backward()
|
| 52 |
+
|
| 53 |
# Remove hooks
|
| 54 |
h1.remove()
|
| 55 |
h2.remove()
|
| 56 |
+
|
| 57 |
+
# Generate Class Activation Map (CAM)
|
| 58 |
+
act = activations[0].detach().cpu().squeeze(0)
|
| 59 |
+
grad = gradients[0].detach().cpu().squeeze(0)
|
| 60 |
+
|
| 61 |
weights = grad.mean(dim=(1, 2), keepdim=True)
|
| 62 |
cam = (weights * act).sum(0)
|
| 63 |
cam = cam.clamp(min=0).numpy()
|
| 64 |
+
|
| 65 |
# Normalize CAM
|
| 66 |
+
cam = (cam - cam.min()) / (cam.max() - cam.min())
|
| 67 |
+
|
| 68 |
+
# Plot CAM
|
|
|
|
| 69 |
fig, ax = plt.subplots()
|
| 70 |
ax.imshow(img)
|
| 71 |
ax.imshow(cam, alpha=0.5, cmap='jet')
|
| 72 |
ax.axis('off')
|
| 73 |
+
|
| 74 |
+
# Save CAM image
|
| 75 |
buf = io.BytesIO()
|
| 76 |
+
plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
|
| 77 |
buf.seek(0)
|
| 78 |
+
|
|
|
|
| 79 |
return {learn.dls.vocab[i]: float(probs[i]) for i in range(len(probs))}, buf
|
| 80 |
|
| 81 |
+
# Create Gradio interface
|
| 82 |
interface = gr.Interface(
|
| 83 |
fn=predict_with_cam,
|
| 84 |
inputs=gr.Image(type='pil'),
|