Spaces:
Build error
Build error
| import random | |
| import time | |
| import numpy as np | |
| import torch | |
| import torch.backends.cudnn as cudnn | |
| import matplotlib.pyplot as plt | |
| from glob import glob | |
| from PIL import Image | |
| from model.load_model import get_model | |
| from torchvision import transforms | |
| from pytorch_grad_cam import GradCAM, GuidedBackpropReLUModel | |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
| from pytorch_grad_cam.utils.image import show_cam_on_image, deprocess_image | |
| from ultralytics import YOLO | |
| from rembg import remove | |
| import uuid | |
# Static variables
model_path = "efficientnet-b2.pth"  # classifier state dict loaded into `model`
model_name = "efficientnet_b2"  # architecture name handed to get_model()
YOLO_MODEL_WEIGHTS = "yolo-v11-best.pt"  # detector weights for the optional OD step
# Class names by output index; presumably must match the classifier head order — confirm.
classes = ["Healthy", "Resistant", "Susceptible"]
# Crop used for the images the explanation overlays are drawn on.
resizing_transforms = transforms.Compose(transforms=[transforms.CenterCrop(size=256)])
# Function definitions
def reproduce(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: value handed to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    # Deterministic kernels and no autotuning: slower, but runs repeat exactly.
    cudnn.deterministic, cudnn.benchmark = True, False
def get_grad_cam_results(image, transformed_image, class_index=0):
    """Compute a Grad-CAM map for one class and overlay it on *image*.

    Uses the module-level ``model`` and ``target_layers``.

    Args:
        image: image the heat map is blended onto (converted with np.array and
            scaled by 1/255).
        transformed_image: un-batched model input tensor; a batch axis is added here.
        class_index: index of the class whose score is explained.

    Returns:
        (overlay, mask): the RGB overlay from show_cam_on_image and the Grad-CAM
        mask of the single batch item (presumably (H, W) — confirm).
    """
    batch = transformed_image.unsqueeze(0)
    with GradCAM(model=model, target_layers=target_layers) as cam:
        masks = cam(input_tensor=batch, targets=[ClassifierOutputTarget(class_index)])
    first_mask = masks[0, :]
    overlay = show_cam_on_image(np.array(image) / 255.0, first_mask, use_rgb=True)
    return overlay, first_mask
def get_backpropagation_results(transformed_image, class_index=0):
    """Run guided backpropagation (module-level ``gbp_model``) for one class.

    Args:
        transformed_image: un-batched model input tensor; a batch axis is added here.
        class_index: class passed to the guided-backprop model as target_category.

    Returns:
        (raw, displayable): the raw guided-backprop output and its
        deprocess_image() version for plotting.
    """
    batch = transformed_image.unsqueeze(0)
    guided_grads = gbp_model(batch, target_category=class_index)
    return guided_grads, deprocess_image(guided_grads)
def get_guided_gradcam(image, cam_grayscale, bp):
    """Overlay the guided Grad-CAM map (Grad-CAM mask x guided backprop) on *image*.

    Args:
        image: image the map is blended onto (converted with np.array and
            scaled by 1/255, like the Grad-CAM overlay).
        cam_grayscale: 2-D Grad-CAM mask from get_grad_cam_results().
        bp: raw guided-backpropagation output from get_backpropagation_results();
            presumably (H, W, 3) matching the mask's H and W — confirm.

    Returns:
        The RGB overlay produced by show_cam_on_image().
    """
    # Weight every channel of the guided-backprop map by the 2-D CAM mask.
    guided = deprocess_image(cam_grayscale[..., np.newaxis] * bp)  # uint8, 0..255
    # show_cam_on_image expects a float mask in [0, 1] and computes
    # np.uint8(255 * mask) itself; feeding it the uint8 map directly made that
    # product wrap around. Collapse to one channel and rescale to [0, 1].
    guided_mask = guided.mean(axis=-1) / 255.0
    # use_rgb=True: the base image is RGB, same as in get_grad_cam_results().
    return show_cam_on_image(np.array(image) / 255.0, guided_mask, use_rgb=True)
def explain_results(image, class_index=0):
    """Produce the three explanation images for one class.

    The model sees ``image_transform(image)``; the overlays are drawn on a
    ``resizing_transforms`` (CenterCrop 256) copy of *image*.

    Returns:
        (gradcam_overlay, guided_backprop_image, guided_gradcam_overlay)
    """
    model_input = image_transform(image)
    cropped = resizing_transforms(image)
    gradcam_overlay, cam_mask = get_grad_cam_results(cropped, model_input, class_index)
    raw_bp, bp_image = get_backpropagation_results(model_input, class_index)
    return gradcam_overlay, bp_image, get_guided_gradcam(cropped, cam_mask, raw_bp)
def make_prediction_and_explain(image):
    """Classify *image* and build explanation images for every class.

    Args:
        image: image accepted by the module-level ``image_transform``.

    Returns:
        dict mapping each name in ``classes`` to a dict with keys
        "original_image", "prediction" (label "<class> (<percent>%)"),
        "gradcam", "backpropagation" and "guided_gradcam".
    """
    batch = image_transform(image).unsqueeze(0)
    model.eval()
    with torch.no_grad():
        probabilities = torch.nn.functional.softmax(model(batch), dim=1)
    # Scale to percent *before* rounding: round(p, 4) * 100 reintroduced float
    # noise (e.g. 57.99999999999999) into the displayed label.
    predictions = [round(p * 100, 2) for p in probabilities[0].tolist()]
    results = {}
    for class_index, class_name in enumerate(classes):
        gradcam, bp_deprocessed, guided_gradcam = explain_results(
            image, class_index=class_index
        )
        results[class_name] = {
            "original_image": image,
            "prediction": f"{class_name} ({predictions[class_index]}%)",
            "gradcam": gradcam,
            "backpropagation": bp_deprocessed,
            "guided_gradcam": guided_gradcam,
        }
    return results
def save_explanation_results(res, path):
    """Plot one row of explanation panels per class and save the figure.

    Args:
        res: mapping of class name -> dict with keys "original_image",
            "prediction", "gradcam", "backpropagation" and "guided_gradcam",
            as built by make_prediction_and_explain().
        path: destination passed to plt.savefig().
    """
    # (result key, panel title) for columns 1..3; column 0 is the original image.
    panels = (
        ("gradcam", "GradCAM"),
        ("backpropagation", "Backpropagation"),
        ("guided_gradcam", "Guided GradCAM"),
    )
    # One row per entry instead of a hard-coded 3; squeeze=False keeps `ax`
    # 2-D even for a single row. figsize matches the old (15, 15) for 3 rows.
    n_rows = max(len(res), 1)
    fig, ax = plt.subplots(n_rows, 4, figsize=(15, 5 * n_rows), squeeze=False)
    try:
        for row, entry in enumerate(res.values()):
            ax[row, 0].imshow(entry["original_image"])
            ax[row, 0].set_title(f"Original Image (class: {entry['prediction']})")
            ax[row, 0].axis("off")
            for col, (key, title) in enumerate(panels, start=1):
                ax[row, col].imshow(entry[key])
                ax[row, col].set_title(title)
                ax[row, col].axis("off")
        plt.tight_layout()
        plt.savefig(path, bbox_inches="tight")
    finally:
        # Always release the figure, even when drawing or saving fails, so a
        # long-running process does not accumulate open figures.
        plt.close(fig)
# Load the classifier, explainers and detector once at import time so that
# get_results() can be called directly without further setup.
model, image_transform = get_model(model_name)
model.load_state_dict(
    torch.load(model_path, map_location="cpu")
)
# NOTE(review): the model is left in training mode here, while
# make_prediction_and_explain() calls model.eval() before its forward pass —
# confirm whether this train() call is still needed.
model.train()
# Grad-CAM target: the model's conv_head layer (presumably the last conv block — confirm).
target_layers = [model.conv_head]
gbp_model = GuidedBackpropReLUModel(model=model, device="cpu")
yolo_model = YOLO(YOLO_MODEL_WEIGHTS)
def get_results(img_path=None, img_for_testing=None, od=False, remove_bg=False):
    """Classify an image (optionally per detected object) and save explanation figures.

    Args:
        img_path: path of the image to load.
        img_for_testing: image as a numpy array; takes precedence over img_path.
        od: if True, run the YOLO detector and explain the crop of the first
            box of each detection result instead of the whole image. Results
            without any box are skipped.
        remove_bg: with od=True, strip the crop's background with rembg and
            paint it white before classification.

    Returns:
        list of paths of the saved PNG figures (one per explained image).

    Raises:
        ValueError: if neither img_path nor img_for_testing is provided.
    """
    if img_path is None and img_for_testing is None:
        raise ValueError("Either img_path or img_for_testing should be provided.")
    if img_for_testing is not None:
        image = Image.fromarray(img_for_testing)
    else:
        image = Image.open(img_path)
    # Normalise to 3-channel RGB: RGBA or grayscale inputs would otherwise break
    # the classifier transform and the RGB overlays.
    image = image.convert("RGB")

    if not od:
        return [_explain_and_save(image)]

    result_paths = []
    # Detect on the same PIL image we crop from so the boxes always refer to it
    # (ultralytics also treats raw numpy arrays as BGR, while PIL is RGB).
    for result in yolo_model(image):
        if len(result.boxes) == 0:
            continue  # nothing detected: boxes.xyxy[0] would raise IndexError
        x1, y1, x2, y2 = result.boxes.xyxy[0].cpu().numpy().astype(int)
        crop = image.crop((x1, y1, x2, y2))
        if remove_bg:
            crop = _remove_background(crop)
        result_paths.append(_explain_and_save(crop))
    return result_paths


def _remove_background(image):
    """Cut the foreground out with rembg and paint the removed background white."""
    pixels = np.array(remove(image).convert("RGB"))
    # The removed background is assumed to come out pure black after
    # convert("RGB"). Only whiten pixels that are 0 in *all* channels; the old
    # elementwise comparison with [0, 0, 0] also whitened any single zero
    # channel of a foreground pixel.
    background = np.all(pixels == 0, axis=-1, keepdims=True)
    return Image.fromarray(np.where(background, 255, pixels).astype(np.uint8))


def _explain_and_save(image):
    """Explain *image*, save the figure under /tmp and return its path."""
    save_path = f"/tmp/with-bg-result-{uuid.uuid4().hex}.png"
    save_explanation_results(make_prediction_and_explain(image), save_path)
    return save_path
if __name__ == "__main__":
    # Batch mode: detect, crop and explain every file under samples/, writing
    # one figure per detection result to sample-results/.
    from pathlib import Path

    reproduce()
    # model, image_transform, target_layers, gbp_model and yolo_model are already
    # loaded at module level (with map_location="cpu"). Reloading them here was
    # redundant, and the reload's torch.load without map_location could fail on
    # a CPU-only machine if the weights were saved from a GPU.
    output_dir = Path("sample-results")
    output_dir.mkdir(parents=True, exist_ok=True)  # savefig does not create directories
    for image_path in glob("samples/*"):
        start = time.perf_counter()
        results = yolo_model(image_path)
        # Close the file handle and normalise to RGB for the classifier/overlays.
        with Image.open(image_path) as opened:
            image = opened.convert("RGB")
        for i, result in enumerate(results):
            if len(result.boxes) == 0:
                continue  # no detection: boxes.xyxy[0] would raise IndexError
            # Same name the old "samples/" -> "sample-results/..." replace built on
            # POSIX, but also correct with Windows path separators.
            save_path = output_dir / f"with-white-bg-result-{i:02d}-{Path(image_path).name}"
            x1, y1, x2, y2 = result.boxes.xyxy[0].cpu().numpy().astype(int)
            bbox_image = image.crop((x1, y1, x2, y2))
            # NOTE: background removal is disabled here even though the file name
            # says "white-bg"; get_results(..., od=True, remove_bg=True) performs it.
            res = make_prediction_and_explain(bbox_image)
            save_explanation_results(res, str(save_path))
        end = time.perf_counter() - start
        print(f"Completed in {end}s")