# Spaces: Build error
import os
import tempfile

import numpy as np
import streamlit as st
import torch
from osgeo import gdal
from PIL import Image
from torchvision import transforms
# Load the pretrained model
@st.cache_resource
def load_model():
    """Load the pretrained brain-segmentation U-Net from PyTorch Hub.

    Streamlit re-executes this whole script on every widget interaction, so the
    loader is wrapped in ``st.cache_resource``.  The model is built once per
    server process and the same instance is reused by later reruns and sessions,
    instead of being re-fetched and re-initialised each time.

    Returns:
        torch.nn.Module: the U-Net, switched to eval mode.
    """
    # According to the repository's hubconf/README, extra keyword arguments are
    # forwarded to ``UNet(...)``, which only accepts in_channels / out_channels /
    # init_features -- so the former ``progress=True`` was not a valid option.
    # The explicit values below are the architecture the published weights use.
    model = torch.hub.load(
        "mateuszbuda/brain-segmentation-pytorch",
        "unet",
        in_channels=3,
        out_channels=1,
        init_features=32,
        pretrained=True,
    )
    model.eval()  # inference only: freeze batch-norm statistics
    return model


# Function to load large TIFF images
def load_tiff_image(tiff_path):
    """Read the first raster band of a TIFF file with GDAL.

    Failures are reported to the user with ``st.error`` instead of being raised,
    so the Streamlit page keeps rendering.

    Args:
        tiff_path: path of the TIFF file on disk.

    Returns:
        numpy.ndarray | None: the pixels of band 1, or ``None`` if the file
        could not be opened or its pixel data could not be read.
    """
    dataset = None
    try:
        dataset = gdal.Open(tiff_path)
        if dataset is None:
            st.error("Failed to load the TIFF image. Please check the file format.")
            return None
        if dataset.RasterCount < 1:
            st.error("The TIFF image contains no raster bands.")
            return None
        band = dataset.GetRasterBand(1)  # Assuming grayscale or single band
        image = band.ReadAsArray()
        # When GDAL exceptions are not enabled, a failed read returns None
        # instead of raising; surface it rather than returning None silently.
        if image is None:
            st.error("Failed to read pixel data from the TIFF image.")
            return None
        return image
    except Exception as e:
        st.error(f"Error loading image: {e}")
        return None
    finally:
        # GDAL's Python bindings close a dataset when its last reference is
        # dropped; do it explicitly so the file is released on every exit path.
        dataset = None


# Preprocess image
def _rescale_to_unit(array):
    """Linearly map the finite values of ``array`` onto [0, 1] as float32; non-finite values become 0."""
    data = array.astype(np.float32)
    finite = np.isfinite(data)
    if not finite.any():
        return np.zeros(data.shape, dtype=np.float32)
    lo = data[finite].min()
    hi = data[finite].max()
    if hi <= lo:
        # Constant image: there is no range to stretch.
        return np.zeros(data.shape, dtype=np.float32)
    return np.where(finite, (data - lo) / (hi - lo), 0.0).astype(np.float32)


def preprocess_image(image, size=(256, 256), mean=(0.485,), std=(0.229,), num_channels=3):
    """Convert an image into a normalized ``(1, C, H, W)`` float32 tensor for the model.

    Args:
        image: a PIL image or numpy array shaped ``(H, W)`` or ``(H, W, C)``.
        size: output ``(height, width)``; 256x256 by default.
        mean: per-channel normalization means; a single value applies to all channels.
        std: per-channel normalization standard deviations; same broadcasting as ``mean``.
        num_channels: single-channel input is replicated to this many channels.
            The default of 3 matches the 3-channel input the brain-segmentation
            U-Net expects; a 1-channel tensor makes its first convolution fail.

    Returns:
        torch.Tensor: float32 tensor of shape ``(1, C, size[0], size[1])``.
    """
    array = np.asarray(image)
    if array.dtype == np.uint8:
        # Same scaling transforms.ToTensor applies to 8-bit images.
        data = array.astype(np.float32) / 255.0
    else:
        # ToTensor leaves 16/32-bit and float images unscaled (and reads uint16
        # through int16), which then breaks Normalize; stretch them to [0, 1].
        data = _rescale_to_unit(array)

    tensor = torch.from_numpy(np.ascontiguousarray(data))
    if tensor.ndim == 2:
        tensor = tensor.unsqueeze(0)  # (H, W) -> (1, H, W)
    else:
        tensor = tensor.permute(2, 0, 1)  # (H, W, C) -> (C, H, W)
    if tensor.shape[0] == 1 and num_channels > 1:
        tensor = tensor.repeat(num_channels, 1, 1)

    # Bilinear resize as transforms.Resize does for tensors, with antialias pinned
    # explicitly so downscaling does not change with the torchvision version.
    batch = torch.nn.functional.interpolate(
        tensor.unsqueeze(0),  # add batch dimension
        size=tuple(size),
        mode="bilinear",
        align_corners=False,
        antialias=True,
    )
    mean_t = torch.as_tensor(mean, dtype=batch.dtype).view(1, -1, 1, 1)
    std_t = torch.as_tensor(std, dtype=batch.dtype).view(1, -1, 1, 1)
    return (batch - mean_t) / std_t


# Post-process prediction to display
def postprocess_prediction(pred, threshold=0.5, apply_sigmoid=False):
    """Convert a model output tensor into a binary ``uint8`` mask of 0s and 1s.

    Args:
        pred: model output tensor (e.g. ``(1, 1, H, W)``); every size-1
            dimension is squeezed away.
        threshold: pixels whose probability is strictly greater than this
            become 1.
        apply_sigmoid: set to True when ``pred`` holds raw logits.  It defaults
            to False because the brain-segmentation U-Net loaded from PyTorch Hub
            already ends its forward pass with a sigmoid.  A second sigmoid maps
            every probability into [0.5, 0.73], which made almost every pixel
            exceed the 0.5 threshold.

    Returns:
        numpy.ndarray: ``uint8`` mask, 1 where the probability exceeds ``threshold``.
    """
    probs = torch.sigmoid(pred) if apply_sigmoid else pred
    # detach/cpu so this also works on tensors that track grads or live on a GPU
    probs = probs.squeeze().detach().cpu().numpy()
    return (probs > threshold).astype(np.uint8)


# Streamlit app
def _to_uint8(array):
    """Linearly rescale a raster array to ``uint8`` [0, 255]; non-finite pixels become 0.

    ``st.image`` expects 8-bit data (or floats in [0, 1]), and the RGB conversion
    below needs an 8-bit image.  16-bit or float TIFF rasters are therefore
    stretched over their own finite min/max range.
    """
    if array.dtype == np.uint8:
        return array
    data = array.astype(np.float32)
    finite = np.isfinite(data)
    if not finite.any():
        return np.zeros(data.shape, dtype=np.uint8)
    lo = data[finite].min()
    hi = data[finite].max()
    if hi <= lo:
        return np.zeros(data.shape, dtype=np.uint8)
    scaled = np.where(finite, (data - lo) / (hi - lo), 0.0)
    return np.round(scaled * 255).astype(np.uint8)


st.title("TIFF Image Upload and Model Prediction")
# Upload image. GeoTIFFs commonly use ".tif", so accept both extensions.
# NOTE(review): Streamlit's default upload limit (server.maxUploadSize) is 200 MB;
# the advertised 5GB also requires raising that config option.
uploaded_file = st.file_uploader("Upload a large TIFF image (up to 5GB)", type=["tif", "tiff"])
if uploaded_file is not None:
    # Use a unique temp file per upload. The old fixed "temp_image.tiff" was shared
    # by all concurrent sessions and was left behind whenever processing raised.
    suffix = os.path.splitext(uploaded_file.name)[1] or ".tiff"
    fd, tmp_path = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as tmp:
            tmp.write(uploaded_file.getbuffer())
        tiff_image = load_tiff_image(tmp_path)
    finally:
        os.remove(tmp_path)  # the pixels are in memory now; the file is no longer needed

    if tiff_image is not None:
        display_image = _to_uint8(tiff_image)
        st.write("Original Image")
        st.image(display_image, caption="Uploaded Image", use_column_width=True)

        model = load_model()
        # The U-Net takes 3-channel input, so replicate a grayscale band into RGB.
        image = Image.fromarray(display_image).convert("RGB")
        image_tensor = preprocess_image(image)
        with torch.no_grad():
            prediction = model(image_tensor)
        pred_image = postprocess_prediction(prediction)

        st.write("Model Prediction")
        # The mask holds 0/1; scale to 0/255 so the foreground is visible rather than near-black.
        st.image(pred_image * 255, caption="Predicted Image", use_column_width=True)