Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| import numpy as np | |
| import os | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| from captum.attr import Saliency, GuidedBackprop, NoiseTunnel | |
| from src.model import TrashNetClassifier | |
| from src import config | |
@st.cache_resource
def load_model():
    """Build the TrashNet classifier and load its trained weights.

    Decorated with ``st.cache_resource`` so the checkpoint is read from disk
    once per server process. Without it, the weights are reloaded on every
    Streamlit rerun, and every widget interaction re-executes the whole script.

    Returns:
        The ``TrashNetClassifier`` in eval mode, moved to ``config.DEVICE``.
    """
    model = TrashNetClassifier()
    state_dict = torch.load(config.MODEL_SAVE_PATH, map_location=config.DEVICE)
    model.load_state_dict(state_dict)
    # eval() switches layers such as dropout / batch-norm to inference mode.
    model.eval().to(config.DEVICE)
    # NOTE(review): the cached instance is shared by all sessions. Attribution
    # methods that temporarily hook the model (e.g. GuidedBackprop) might
    # interfere under concurrent requests -- confirm this is acceptable here.
    return model
def compute_saliency_map(model, input_tensor, method):
    """Classify one image and attribute the prediction to its input pixels.

    Args:
        model: Classifier returning raw class scores of shape
            (1, num_classes); softmax is applied here for the confidence.
        input_tensor: Preprocessed batch holding a single image,
            (1, C, H, W) as built by ``preprocess_image``. It is not modified.
        method: ``"saliency"`` (vanilla gradients), ``"smoothgrad"``
            (noise-averaged gradients via NoiseTunnel) or ``"guided"``
            (guided backpropagation).

    Returns:
        ``(pred_class, confidence, saliency)``. ``pred_class`` is an int,
        ``confidence`` is its softmax probability, and ``saliency`` is an
        (H, W) array holding the per-pixel maximum absolute attribution
        across channels.

    Raises:
        ValueError: if ``method`` is not one of the supported names.
    """
    # Validate before doing any expensive model work.
    if method not in ("saliency", "smoothgrad", "guided"):
        raise ValueError(f"Unsupported method: {method!r}")

    model.zero_grad()
    # Use a detached alias so the caller's tensor never gets requires_grad set
    # (``.to`` returns the very same object when it is already on DEVICE).
    x = input_tensor.detach().to(config.DEVICE).requires_grad_(True)

    # The prediction needs no autograd graph; Captum runs its own passes.
    with torch.no_grad():
        logits = model(x)
    pred_class = logits.argmax(dim=1)[0].item()
    confidence = torch.softmax(logits, dim=1)[0, pred_class].item()

    if method == "saliency":
        attributions = Saliency(model).attribute(x, target=pred_class)
    elif method == "smoothgrad":
        attributions = NoiseTunnel(Saliency(model)).attribute(
            x,
            nt_type="smoothgrad",
            target=pred_class,
            nt_samples=config.SMOOTHGRAD_SAMPLES,
            stdevs=config.SMOOTHGRAD_STDEV,
        )
    else:  # "guided" -- already validated above
        attributions = GuidedBackprop(model).attribute(x, target=pred_class)

    # Index the single image explicitly; squeeze() would also drop any
    # other size-1 dimension and break the channel reduction below.
    saliency = attributions[0].detach().abs().cpu().numpy()
    saliency = np.max(saliency, axis=0)  # collapse channels -> (H, W)
    return pred_class, confidence, saliency
def preprocess_image(uploaded_file):
    """Load an uploaded image and turn it into a normalized model input.

    Args:
        uploaded_file: A path or binary file-like object readable by PIL.

    Returns:
        ``(pil_image, batch)``: the RGB ``PIL.Image`` and a float tensor of
        shape (1, 3, IMAGE_SIZE, IMAGE_SIZE) with values in [-1, 1].
    """
    # The context manager releases any file handle PIL opened itself;
    # convert() returns an independent, fully loaded image.
    with Image.open(uploaded_file) as img:
        pil_image = img.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
        transforms.ToTensor(),  # HWC uint8 -> CHW float in [0, 1]
        # mean = std = 0.5 maps [0, 1] to [-1, 1]; presumably this matches
        # the normalization used in training -- confirm against the trainer.
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])
    return pil_image, transform(pil_image).unsqueeze(0)  # add batch dim
def run_saliency(model, input_tensor):
    """Compute vanilla-gradient saliency with plain autograd (no Captum).

    Args:
        model: Classifier returning raw class scores of shape
            (1, num_classes).
        input_tensor: Preprocessed single-image batch (1, C, H, W). It is
            not modified.

    Returns:
        ``(pred_class, confidence, saliency)``, where ``saliency`` is the
        (H, W) per-pixel maximum over channels of |d score / d input|.
    """
    # Fresh leaf: the caller's tensor is untouched and no stale .grad from
    # an earlier call can be accumulated into this result.
    x = input_tensor.detach().to(config.DEVICE).requires_grad_(True)
    logits = model(x)
    pred_class = logits.argmax(dim=1)[0].item()
    confidence = torch.softmax(logits, dim=1)[0, pred_class].item()
    # autograd.grad returns the gradient directly instead of accumulating it
    # into x.grad and every parameter's .grad, as backward() would.
    (grad,) = torch.autograd.grad(logits[0, pred_class], x)
    saliency = grad[0].abs().cpu().numpy()
    saliency = np.max(saliency, axis=0)  # collapse channels -> (H, W)
    return pred_class, confidence, saliency
def get_saliency_figure(input_tensor, saliency_map):
    """Draw the input image next to its min-max normalized saliency map.

    Args:
        input_tensor: Normalized single-image batch (1, 3, H, W), as produced
            by ``preprocess_image`` (values in [-1, 1]).
        saliency_map: (H, W) array of attribution magnitudes. It is not
            modified.

    Returns:
        A matplotlib ``Figure`` with two panels. The caller owns it and
        should ``plt.close(fig)`` after rendering.
    """
    # Normalize a float copy so the caller's array is left untouched
    # (in-place ops would also fail on integer dtypes).
    heat = np.asarray(saliency_map, dtype=np.float64)
    heat = heat - heat.min()
    heat = heat / (heat.max() + 1e-10)  # epsilon guards an all-constant map

    img_np = input_tensor.squeeze().detach().cpu().numpy()
    img_np = np.transpose(img_np, (1, 2, 0))  # CHW -> HWC for imshow
    img_np = (img_np * 0.5 + 0.5).clip(0, 1)  # undo Normalize(0.5, 0.5)

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].imshow(img_np)
    axs[0].set_title("Original Image")
    axs[0].axis("off")
    # Plot the 2-D map so the gray colormap applies (a cmap is ignored for
    # RGB input); fix the range so 0 is black and 1 is white.
    axs[1].imshow(heat, cmap="gray", vmin=0.0, vmax=1.0)
    axs[1].set_title("Saliency Map")
    axs[1].axis("off")
    fig.tight_layout()
    return fig
from PIL import UnidentifiedImageError

st.set_page_config(page_title="Saliency Demo", layout="centered")
st.title("🧠 Trash Classifier with Clean Saliency Visualization")
st.markdown(
    "Upload a trash image. The model will classify it and show pixel-level attention.")

method = st.radio("🧠 Select Explanation Method", [
    "saliency", "smoothgrad", "guided"])
uploaded_file = st.file_uploader(
    "📤 Upload an image", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        pil_img, input_tensor = preprocess_image(uploaded_file)
    except UnidentifiedImageError:
        # Corrupt or mislabeled uploads: show a message instead of a traceback.
        st.error("Could not read the uploaded file as an image.")
        st.stop()
    st.image(pil_img, caption="Uploaded Image", use_column_width=True)

    with st.spinner(f"Computing {method} map..."):
        model = load_model()
        pred_class, confidence, saliency_map = compute_saliency_map(
            model, input_tensor, method)
        # Assumes training used torchvision ImageFolder on this directory,
        # which indexes classes by *sub-directory* name in sorted order.
        # Plain os.listdir would also include stray files (e.g. .DS_Store)
        # and shift every label.
        train_dir = os.path.join(config.DATA_DIR, "train")
        with os.scandir(train_dir) as entries:
            class_names = sorted(e.name for e in entries if e.is_dir())
        pred_label = class_names[pred_class]
        fig = get_saliency_figure(input_tensor, saliency_map)

    st.markdown(f"### 🧠 Prediction: **{pred_label}** ({confidence*100:.2f}%)")
    st.pyplot(fig)
    # Drop the figure from pyplot's global registry; otherwise each rerun
    # leaves one more figure open for the life of the server.
    plt.close(fig)