# Provenance (scrape artifact): commit 57d41d5 by neecat — "add modified files"
import streamlit as st
import torch
import numpy as np
import os
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
from captum.attr import Saliency, GuidedBackprop, NoiseTunnel
from src.model import TrashNetClassifier
from src import config
@st.cache_resource
def load_model():
    """Load the trained TrashNet classifier once per Streamlit session.

    The ``st.cache_resource`` decorator keeps the loaded model in memory,
    so the checkpoint is read from disk only on the first call.

    Returns:
        The classifier in eval mode, placed on ``config.DEVICE``.
    """
    classifier = TrashNetClassifier()
    state_dict = torch.load(
        config.MODEL_SAVE_PATH, map_location=config.DEVICE)
    classifier.load_state_dict(state_dict)
    classifier.eval().to(config.DEVICE)
    return classifier
def compute_saliency_map(model, input_tensor, method):
    """Classify one image and compute a pixel-attribution map via captum.

    Args:
        model: classifier to explain (gradients flow through it).
        input_tensor: image batch, presumably shape (1, C, H, W) —
            TODO confirm against preprocess_image's output.
        method: "saliency", "smoothgrad", or "guided".

    Returns:
        Tuple of (predicted class index, softmax confidence,
        2-D numpy saliency map reduced over channels).

    Raises:
        ValueError: if ``method`` is not one of the supported names.
    """
    model.zero_grad()
    image = input_tensor.to(config.DEVICE)
    image.requires_grad_()  # captum needs gradients w.r.t. the input

    logits = model(image)
    pred_class = logits.argmax(dim=1).item()
    confidence = torch.softmax(logits, dim=1)[0][pred_class].item()

    if method == "saliency":
        attributions = Saliency(model).attribute(image, target=pred_class)
    elif method == "smoothgrad":
        attributions = NoiseTunnel(Saliency(model)).attribute(
            image,
            nt_type="smoothgrad",
            target=pred_class,
            nt_samples=config.SMOOTHGRAD_SAMPLES,
            stdevs=config.SMOOTHGRAD_STDEV,
        )
    elif method == "guided":
        attributions = GuidedBackprop(model).attribute(
            image, target=pred_class)
    else:
        raise ValueError("Unsupported method")

    # Collapse channels: per-pixel max of the absolute attribution.
    per_channel = attributions.squeeze().abs().cpu().detach().numpy()
    return pred_class, confidence, np.max(per_channel, axis=0)
def preprocess_image(uploaded_file):
    """Decode an uploaded image and build the normalized model input.

    Args:
        uploaded_file: file-like object from ``st.file_uploader``.

    Returns:
        Tuple of (PIL RGB image for display, (1, 3, H, W) float tensor
        normalized with mean=std=0.5, i.e. values in [-1, 1]).
    """
    pipeline = transforms.Compose([
        transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])
    image = Image.open(uploaded_file).convert("RGB")
    batch = pipeline(image).unsqueeze(0)  # add batch dimension
    return image, batch
def run_saliency(model, input_tensor):
    """Compute a vanilla-gradient saliency map without captum.

    Args:
        model: classifier to explain.
        input_tensor: image batch, presumably shape (1, C, H, W) —
            TODO confirm against preprocess_image's output.

    Returns:
        Tuple of (predicted class index, softmax confidence,
        2-D numpy saliency map reduced over channels).
    """
    # Clear stale parameter gradients so repeated calls do not accumulate —
    # matches the behavior of compute_saliency_map above.
    model.zero_grad()
    input_tensor = input_tensor.to(config.DEVICE)
    input_tensor.requires_grad_()
    output = model(input_tensor)
    pred_class = output.argmax(dim=1).item()
    confidence = torch.softmax(output, dim=1)[0][pred_class].item()
    # Backprop the winning logit; the input's gradient is the saliency signal.
    output[0, pred_class].backward()
    saliency = input_tensor.grad.abs().squeeze().cpu().numpy()
    saliency = np.max(saliency, axis=0)  # per-pixel max over channels
    return pred_class, confidence, saliency
def get_saliency_figure(input_tensor, saliency_map):
    """Build a side-by-side figure: original image and its saliency map.

    Args:
        input_tensor: the normalized (1, 3, H, W) model input.
        saliency_map: 2-D numpy array of attribution magnitudes.
            The caller's array is NOT modified.

    Returns:
        A matplotlib Figure with two panels.
    """
    # Normalize a COPY to [0, 1]; the original code mutated the caller's
    # array in place via -= and /=, a surprising side effect.
    heat = saliency_map - saliency_map.min()
    heat = heat / (heat.max() + 1e-10)  # epsilon guards an all-zero map

    img_np = input_tensor.squeeze().detach().cpu().numpy()
    img_np = np.transpose(img_np, (1, 2, 0))  # CHW -> HWC for imshow
    # Undo Normalize(mean=0.5, std=0.5): x * 0.5 + 0.5 maps [-1, 1] -> [0, 1].
    img_np = (img_np * 0.5 + 0.5).clip(0, 1)

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].imshow(img_np)
    axs[0].set_title("Original Image")
    axs[0].axis("off")
    # Pass the 2-D map directly: for an RGB array imshow ignores cmap, so the
    # original stacked-RGB call rendered the same image but misleadingly.
    axs[1].imshow(heat, cmap="gray")
    axs[1].set_title("Saliency Map")
    axs[1].axis("off")
    fig.tight_layout()
    return fig
# ---- Page chrome -----------------------------------------------------------
st.set_page_config(page_title="Saliency Demo", layout="centered")
st.title("🧠 Trash Classifier with Clean Saliency Visualization")
st.markdown(
    "Upload a trash image. The model will classify it and show pixel-level attention.")

# ---- User controls ---------------------------------------------------------
explanation_method = st.radio("🧠 Select Explanation Method", [
    "saliency", "smoothgrad", "guided"])
uploaded_file = st.file_uploader(
    "📤 Upload an image", type=["jpg", "jpeg", "png"])

# ---- Inference + visualization (runs once an image is uploaded) ------------
if uploaded_file:
    display_image, batch = preprocess_image(uploaded_file)
    st.image(display_image, caption="Uploaded Image", use_column_width=True)

    with st.spinner(f"Computing {explanation_method} map..."):
        classifier = load_model()
        pred_class, confidence, saliency_map = compute_saliency_map(
            classifier, batch, explanation_method)

    # Class labels are the training folder names, sorted alphabetically —
    # presumably matching the label order used at training time; verify.
    class_names = sorted(os.listdir(os.path.join(config.DATA_DIR, "train")))
    pred_label = class_names[pred_class]

    figure = get_saliency_figure(batch, saliency_map)
    st.markdown(f"### 🧠 Prediction: **{pred_label}** ({confidence*100:.2f}%)")
    st.pyplot(figure)