Spaces:
Sleeping
Sleeping
File size: 4,134 Bytes
84d0c9e 57d41d5 84d0c9e 57d41d5 84d0c9e 57d41d5 84d0c9e 57d41d5 84d0c9e 57d41d5 84d0c9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import streamlit as st
import torch
import numpy as np
import os
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
from captum.attr import Saliency, GuidedBackprop, NoiseTunnel
from src.model import TrashNetClassifier
from src import config
@st.cache_resource
def load_model():
model = TrashNetClassifier()
model.load_state_dict(torch.load(
config.MODEL_SAVE_PATH, map_location=config.DEVICE))
model.eval().to(config.DEVICE)
return model
def compute_saliency_map(model, input_tensor, method):
model.zero_grad()
input_tensor = input_tensor.to(config.DEVICE)
input_tensor.requires_grad_()
output = model(input_tensor)
pred_class = output.argmax(dim=1).item()
confidence = torch.softmax(output, dim=1)[0][pred_class].item()
if method == "saliency":
attr = Saliency(model)
attributions = attr.attribute(input_tensor, target=pred_class)
elif method == "smoothgrad":
attr = NoiseTunnel(Saliency(model))
attributions = attr.attribute(
input_tensor,
nt_type="smoothgrad",
target=pred_class,
nt_samples=config.SMOOTHGRAD_SAMPLES,
stdevs=config.SMOOTHGRAD_STDEV
)
elif method == "guided":
attr = GuidedBackprop(model)
attributions = attr.attribute(input_tensor, target=pred_class)
else:
raise ValueError("Unsupported method")
saliency = attributions.squeeze().abs().cpu().detach().numpy()
saliency = np.max(saliency, axis=0)
return pred_class, confidence, saliency
def preprocess_image(uploaded_file):
pil_image = Image.open(uploaded_file).convert("RGB")
transform = transforms.Compose([
transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
transforms.ToTensor(),
transforms.Normalize([0.5]*3, [0.5]*3)
])
return pil_image, transform(pil_image).unsqueeze(0)
def run_saliency(model, input_tensor):
input_tensor = input_tensor.to(config.DEVICE)
input_tensor.requires_grad_()
output = model(input_tensor)
pred_class = output.argmax(dim=1).item()
confidence = torch.softmax(output, dim=1)[0][pred_class].item()
output[0, pred_class].backward()
saliency = input_tensor.grad.abs().squeeze().cpu().numpy()
saliency = np.max(saliency, axis=0)
return pred_class, confidence, saliency
def get_saliency_figure(input_tensor, saliency_map):
saliency_map -= saliency_map.min()
saliency_map /= saliency_map.max() + 1e-10
img_np = input_tensor.squeeze().detach().cpu().numpy()
img_np = np.transpose(img_np, (1, 2, 0))
img_np = (img_np * 0.5 + 0.5).clip(0, 1)
saliency_rgb = np.stack([saliency_map]*3, axis=-1)
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].imshow(img_np)
axs[0].set_title("Original Image")
axs[0].axis("off")
axs[1].imshow(saliency_rgb, cmap="gray")
axs[1].set_title("Saliency Map")
axs[1].axis("off")
fig.tight_layout()
return fig
st.set_page_config(page_title="Saliency Demo", layout="centered")
st.title("🧠 Trash Classifier with Clean Saliency Visualization")
st.markdown(
"Upload a trash image. The model will classify it and show pixel-level attention.")
method = st.radio("🧠 Select Explanation Method", [
"saliency", "smoothgrad", "guided"])
uploaded_file = st.file_uploader(
"📤 Upload an image", type=["jpg", "jpeg", "png"])
if uploaded_file:
pil_img, input_tensor = preprocess_image(uploaded_file)
st.image(pil_img, caption="Uploaded Image", use_column_width=True)
with st.spinner(f"Computing {method} map..."):
model = load_model()
pred_class, confidence, saliency_map = compute_saliency_map(
model, input_tensor, method)
class_names = sorted(os.listdir(
os.path.join(config.DATA_DIR, "train")))
pred_label = class_names[pred_class]
fig = get_saliency_figure(input_tensor, saliency_map)
st.markdown(f"### 🧠 Prediction: **{pred_label}** ({confidence*100:.2f}%)")
st.pyplot(fig)
|