import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt

# Side length (pixels) every input image is resized to before inference.
IMAGE_SIZE = 256
# Per-pixel probability above which a pixel is classified as flooded.
FLOOD_THRESHOLD = 0.5


# Define the Custom U-Net Model
class UNet(nn.Module):
    """Small encoder/decoder network producing a per-pixel probability map.

    NOTE(review): despite the name there are no skip connections between the
    encoder and decoder, so this is a plain encoder-decoder rather than a true
    U-Net.

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels in the output probability map.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        # Conv-conv-pool: halves the spatial resolution.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Conv-conv-pool: halves the spatial resolution again.
        self.middle = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Conv-conv then a single 2x transposed conv (only one upsampling step).
        self.decoder = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, out_channels, kernel_size=2, stride=2),
        )
        self.final_conv = nn.Conv2d(out_channels, out_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-pixel probabilities in [0, 1] at the same H x W as ``x``.

        Args:
            x: Input tensor of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, out_channels, H, W).
        """
        enc = self.encoder(x)
        mid = self.middle(enc)
        dec = self.decoder(mid)
        logits = self.final_conv(dec)
        # Two 2x poolings but only one 2x upsampling leave the decoder output at
        # about half the input resolution. Resample the logits so the mask lines
        # up pixel-for-pixel with the input (this adds no parameters, so the
        # state_dict layout is unchanged).
        if logits.shape[-2:] != x.shape[-2:]:
            logits = nn.functional.interpolate(
                logits, size=x.shape[-2:], mode="bilinear", align_corners=False
            )
        return self.sigmoid(logits)


# Initialize Model
# NOTE(review): no weights are loaded anywhere in this file, so the model runs
# with random initialisation. Its predictions are not meaningful until a trained
# state_dict is loaded (e.g. via model.load_state_dict(...)).
model = UNet(in_channels=3, out_channels=1)
model.eval()

# Preprocessing pipeline, built once at import time rather than on every call.
_PREPROCESS = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    # Standard ImageNet per-channel mean/std.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# Preprocess Images
def preprocess_image(image):
    """Resize, tensorise and normalise a PIL image into a batch of one.

    Args:
        image: An RGB ``PIL.Image.Image``.

    Returns:
        Float tensor of shape (1, 3, IMAGE_SIZE, IMAGE_SIZE).
    """
    return _PREPROCESS(image).unsqueeze(0)


def _to_rgb_image(image):
    """Return *image* as an RGB PIL image.

    Accepts a PIL image (what ``gr.Image(type="pil")`` passes), a NumPy array,
    or a path / file object readable by ``PIL.Image.open``.
    """
    if isinstance(image, Image.Image):
        return image.convert("RGB")
    if isinstance(image, np.ndarray):
        return Image.fromarray(image).convert("RGB")
    # convert() loads the pixel data, so the file can be closed right after.
    with Image.open(image) as opened:
        return opened.convert("RGB")


def _render_mask(mask):
    """Render a binary (0/1) mask as an RGB PIL image via Matplotlib."""
    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        # Pin the colour scale. Without vmin/vmax an all-0 and an all-1 mask
        # auto-scale to the same colour and become indistinguishable.
        ax.imshow(mask, cmap="jet", alpha=0.5, vmin=0, vmax=1)
        ax.set_title("Predicted Flooded Area")
        ax.axis("off")
        # Leave a strip at the top; with top=1 the title falls outside the canvas.
        fig.subplots_adjust(left=0, right=1, top=0.92, bottom=0)
        ax.margins(0, 0)
        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb(), which was deprecated in
        # Matplotlib 3.8 and removed in 3.10. convert() copies the pixels, so
        # the result stays valid after the figure is closed.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return Image.fromarray(rgba).convert("RGB")
    finally:
        # pyplot keeps every figure registered until closed; without this each
        # prediction would leak one.
        plt.close(fig)


# Prediction Function
def predict_flood(image_terrain, image_rainfall):
    """Predict flooded pixels from a terrain image and a rainfall image.

    Args:
        image_terrain: Terrain image (PIL image, NumPy array, or path).
        image_rainfall: Rainfall image (PIL image, NumPy array, or path).

    Returns:
        An RGB ``PIL.Image.Image`` visualising the thresholded flood mask, or
        ``None`` if either input is missing.
    """
    # The interface uses live=True, so this may run before both images are
    # uploaded.
    if image_terrain is None or image_rainfall is None:
        return None

    terrain_tensor = preprocess_image(_to_rgb_image(image_terrain))
    rainfall_tensor = preprocess_image(_to_rgb_image(image_rainfall))
    # The model takes a single 3-channel input, so the two sources are fused by
    # averaging their normalised tensors.
    combined_tensor = (terrain_tensor + rainfall_tensor) / 2

    with torch.no_grad():
        output = model(combined_tensor)
    flood_mask = (output.squeeze().cpu().numpy() > FLOOD_THRESHOLD).astype(np.uint8)
    return _render_mask(flood_mask)


# Gradio Interface
def create_gradio_interface():
    """Build the Gradio UI around ``predict_flood`` and launch it (blocks)."""
    inputs = [
        gr.Image(type="pil", label="Upload Terrain Image (RGB)"),
        gr.Image(type="pil", label="Upload Rainfall Image (RGB)"),
    ]
    outputs = gr.Image(type="pil", label="Flood Prediction Output")
    gr.Interface(
        fn=predict_flood,
        inputs=inputs,
        outputs=outputs,
        live=True,
        description="Upload terrain and rainfall images to predict flood areas.",
    ).launch()


if __name__ == "__main__":
    create_gradio_interface()