import streamlit as st import torch import torchvision import torchmetrics import pytorch_lightning as pl import numpy as np import cv2 import time import pydicom import nibabel as nib import io from torchvision import transforms from PIL import Image # Load the trained model class PneumoniaModel(pl.LightningModule): def __init__(self): super(PneumoniaModel, self).__init__() self.model = torchvision.models.resnet18() self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) self.model.fc = torch.nn.Linear(in_features=512, out_features=1, bias=True) self.loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([5.0])) self.val_acc = torchmetrics.Accuracy(task="binary") self.val_auc = torchmetrics.AUROC(task="binary") self.val_outputs = [] def forward(self, data): return self.model(data) def validation_step(self, batch, batch_idx): x_ray, label = batch label = label.float() pred = self(x_ray)[:, 0] loss = self.loss_fn(pred, label) self.val_outputs.append({"preds": pred, "targets": label}) return loss def on_validation_epoch_end(self): all_preds = torch.cat([x["preds"] for x in self.val_outputs]).cpu().numpy() all_targets = torch.cat([x["targets"] for x in self.val_outputs]).cpu().numpy() self.val_outputs.clear() def configure_optimizers(self): return torch.optim.Adam(self.model.parameters(), lr=1e-4) # Load trained model weights model = PneumoniaModel() checkpoint = torch.load("weights_3.ckpt", map_location=torch.device('cpu'), weights_only=False) state_dict = checkpoint["state_dict"] model.load_state_dict(state_dict) model.eval() # Preprocessing function def preprocess_image(image): transform = transforms.Compose([ transforms.ToPILImage(), transforms.Resize((256, 256)), transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5]) ]) return transform(image).unsqueeze(0) # Function to load and preprocess different file types def load_image(file_path, file_type): file_type = file_type.lower() try: if file_type in ["png", "jpg", "jpeg"]: # For file objects from streamlit if hasattr(file_path, 'read'): image = Image.open(file_path).convert("L") # Convert to grayscale else: image = Image.open(file_path).convert("L") image = np.array(image) elif file_type == "dcm": # For file objects from streamlit if hasattr(file_path, 'read'): # Create a temporary BytesIO object temp_file = io.BytesIO(file_path.read()) file_path.seek(0) # Reset pointer for future reads dicom_data = pydicom.dcmread(temp_file) else: dicom_data = pydicom.dcmread(file_path) image = dicom_data.pixel_array elif file_type in ["nii", "nii.gz"]: # For file objects from streamlit if hasattr(file_path, 'read'): # We need to save temporarily for nibabel with open("temp_file." + file_type, "wb") as f: f.write(file_path.read()) file_path.seek(0) # Reset pointer for future reads nifti_data = nib.load("temp_file." + file_type) # Clean up the temp file import os try: os.remove("temp_file." + file_type) except: pass # Ignore cleanup errors else: nifti_data = nib.load(file_path) image = nifti_data.get_fdata() image = np.squeeze(image) # Only one squeeze needed else: return None # Common processing for all image types # Normalize to 0-255 range if needed if image.max() > 1.0 and image.max() <= 255: # Already in 0-255 range, no need to normalize pass else: # Normalize to 0-255 image = np.uint8(255 * (image - np.min(image)) / (np.max(image) - np.min(image) + 1e-10)) # Added small value to prevent division by zero # Resize to model's expected input size image = cv2.resize(image, (256, 256)) # Apply the preprocessing and return tensor return preprocess_image(image) except Exception as e: import traceback st.error(f"Error processing image: {str(e)}") st.code(traceback.format_exc()) return None # Streamlit Web App st.set_page_config( page_title="PneumoFind", page_icon="🫁", layout="wide", initial_sidebar_state="expanded" ) # Custom CSS st.markdown(""" """, unsafe_allow_html=True) # Header st.markdown("