# FloodPrediction / app.py
# Author: EngrKashifKhan — added in commit 378a667 ("Create app.py").
import streamlit as st
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
from unet_model import UNet # Make sure unet_model.py is also uploaded
# Load the model once per server process. Streamlit re-executes this whole
# script on every widget interaction; st.cache_resource keeps the loaded
# network in memory instead of rebuilding it and re-reading model.pth on
# every rerun.
@st.cache_resource
def load_model():
    """Build the UNet, load its weights from ``model.pth`` onto the CPU and
    switch it to inference (eval) mode.

    Returns:
        The ready-to-use ``UNet`` instance.
    """
    net = UNet()
    # weights_only=True limits unpickling to tensors and plain containers,
    # which is all a state_dict contains, so a tampered checkpoint file cannot
    # execute arbitrary code while being loaded.
    state_dict = torch.load("model.pth", map_location="cpu", weights_only=True)
    net.load_state_dict(state_dict)
    net.eval()
    return net


model = load_model()
# Page header. The emoji was previously stored as mis-decoded bytes
# ("๐ŸŒŠ"), which rendered as garbage in the title.
st.title("🌊 Flood Prediction App")
st.write("Upload a satellite image, and the model will predict flood-affected areas.")

# Image uploader: returns None until the user picks a file.
uploaded_file = st.file_uploader("Choose a satellite image", type=["jpg", "png", "jpeg"])
# Preprocessing: resize to the fixed 256x256 network input and convert the
# PIL image to a float tensor of shape (3, 256, 256) scaled to [0, 1].
PREPROCESS = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

if uploaded_file is not None:
    # Load and display the input image. convert("RGB") drops any alpha
    # channel and expands greyscale/palette images to the 3 channels the
    # preprocessing above produces.
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:  # PIL's UnidentifiedImageError is a subclass of OSError
        st.error("Could not read the uploaded file as an image.")
        st.stop()
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Add a batch dimension: (3, 256, 256) -> (1, 3, 256, 256).
    input_tensor = PREPROCESS(image).unsqueeze(0)

    # Predict the flood mask. [0, 0] takes the first (only) image in the
    # batch and its first output channel, giving a 2-D (H, W) array.
    with torch.no_grad():
        output = model(input_tensor)[0, 0].numpy()

    # Binarize to a 0/255 uint8 mask.
    # NOTE(review): the 0.5 cut-off assumes the model's last layer is a
    # sigmoid (probabilities); if UNet returns raw logits the threshold should
    # be 0 instead — confirm against unet_model.py.
    mask = (output > 0.5).astype(np.uint8) * 255

    # Scale the mask back to the original image size. NEAREST keeps it
    # strictly binary; Pillow's default filter for "L" images is bicubic,
    # which would blur the edges into intermediate grey values.
    mask_image = Image.fromarray(mask).resize(image.size, resample=Image.NEAREST)
    st.image(mask_image, caption="Predicted Flood Mask", use_column_width=True)