# AutismLens / app.py
# Uploaded by JanaAlbader
# Commit: "Update app.py" (31f73de, verified)
import torch
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
import gradio as gr
import numpy as np
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
# Class names, in the order of the model's output logits.
labels = ['Healthy', 'Autistic', 'NDD']

# Build a ConvNeXt-Tiny network without pretrained weights; the trained
# parameters come from the checkpoint loaded below.
model = models.convnext_tiny(weights=None)

# Swap the last classifier layer for a small head: a 512-unit hidden layer
# with GELU + dropout, followed by a 3-way output layer.
_head_in_features = model.classifier[2].in_features
_head = torch.nn.Sequential(
    torch.nn.Linear(_head_in_features, 512),
    torch.nn.GELU(),
    torch.nn.Dropout(p=0.5),
    torch.nn.Linear(512, 3),
)
model.classifier[2] = _head

# Load the checkpoint onto the CPU, then put the model in evaluation mode.
_cpu = torch.device("cpu")
_state_dict = torch.load("autism_model_weights.pth", map_location=_cpu)
model.load_state_dict(_state_dict)
model.eval()

# Grad-CAM setup: attach to the last module of the feature extractor.
target_layers = [model.features[-1]]
cam = GradCAM(model=model, target_layers=target_layers)
# Image preprocessing
def preprocess(img):
    """Turn a PIL image into a normalized, batched model input.

    The image is forced to three-channel RGB, resized to 224x224, converted
    to a tensor and normalized per channel with the ImageNet mean/std values.

    Args:
        img: A PIL image in any mode that ``Image.convert("RGB")`` accepts.

    Returns:
        The transformed image tensor with a leading batch dimension of
        size 1 (added by ``unsqueeze(0)``).
    """
    # Normalize below is configured for exactly three channels, so grayscale
    # ("L") or alpha-carrying ("RGBA") inputs must be converted first.
    if img.mode != "RGB":
        img = img.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])
    return transform(img).unsqueeze(0)
# Prediction function
def classify(image):
    """Classify an uploaded image and render a Grad-CAM overlay for it.

    Args:
        image: PIL image supplied by the Gradio image input, or ``None``
            when the user submits without providing an image.

    Returns:
        A tuple ``(preds, overlay)``: ``preds`` maps each label to its
        softmax probability, and ``overlay`` is a PIL image of the Grad-CAM
        heatmap blended over the input resized to 224x224.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Show a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image before submitting.")

    # Work on a three-channel copy so that both the model input and the
    # overlay base image have an (H, W, 3) layout.
    image = image.convert("RGB")

    input_tensor = preprocess(image)
    with torch.no_grad():
        outputs = model(input_tensor)
        probs = F.softmax(outputs[0], dim=0)

    # Never ask topk for more entries than there are classes.
    top_k = min(3, probs.numel())
    top_probs, top_idxs = torch.topk(probs, top_k)
    preds = {labels[int(i)]: float(p) for p, i in zip(top_probs, top_idxs)}

    # Grad-CAM backpropagates through the model, so it runs outside no_grad.
    grayscale_cam = cam(input_tensor=input_tensor)[0]

    # show_cam_on_image expects a float32 image scaled to [0, 1].
    resized = image.resize((224, 224))
    rgb_img = np.asarray(resized, dtype=np.float32) / 255.0
    cam_img = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
    return preds, Image.fromarray(cam_img)
# Gradio interface: one image input; outputs are the class probabilities
# and the Grad-CAM overlay returned by classify().
image_input = gr.Image(type="pil")
label_output = gr.Label(num_top_classes=3)
overlay_output = gr.Image(type="pil")

demo = gr.Interface(
    fn=classify,
    inputs=image_input,
    outputs=[label_output, overlay_output],
    title="AutismLens",
    description="Generate new test",
)

# Start the web server for the app.
demo.launch()