File size: 3,368 Bytes
371de10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt

# Define the Custom U-Net Model
# Define the Custom U-Net Model
class UNet(nn.Module):
    """Minimal U-Net-style encoder/decoder for binary segmentation.

    Takes an ``(N, in_channels, H, W)`` batch and returns an
    ``(N, out_channels, H, W)`` map of per-pixel probabilities in [0, 1].
    H and W must be divisible by 4 (two 2x max-pool stages).
    """

    def __init__(self, in_channels=3, out_channels=1):
        super(UNet, self).__init__()

        # Encoder: two 3x3 convs, then 2x downsample (H, W -> H/2, W/2).
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        # Bottleneck: widen to 128 channels, then 2x downsample (-> H/4, W/4).
        self.middle = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        # Decoder. BUG FIX: the original upsampled only once (2x) after two
        # 2x pools, so the predicted mask came out at half the input
        # resolution (256 in -> 128 out). Upsample twice so the output
        # spatial size matches the input.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),  # H/4 -> H/2
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, out_channels, kernel_size=2, stride=2)  # H/2 -> H
        )

        # 1x1 projection (kept for interface compatibility) and sigmoid to
        # squash logits into probabilities.
        self.final_conv = nn.Conv2d(out_channels, out_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        enc = self.encoder(x)
        mid = self.middle(enc)
        dec = self.decoder(mid)
        output = self.final_conv(dec)
        return self.sigmoid(output)

# Initialize Model in inference mode (Module.eval() returns self).
# NOTE(review): no trained checkpoint is loaded here, so the network runs
# with randomly initialized weights and its predictions are placeholders —
# confirm where the trained state_dict should come from.
model = UNet(in_channels=3, out_channels=1).eval()

# Preprocess Images
def preprocess_image(image):
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return preprocess(image).unsqueeze(0)

# Prediction Function
def predict_flood(image_terrain, image_rainfall):
    image_terrain = Image.open(image_terrain).convert("RGB")
    image_rainfall = Image.open(image_rainfall).convert("RGB")

    terrain_tensor = preprocess_image(image_terrain)
    rainfall_tensor = preprocess_image(image_rainfall)

    combined_tensor = (terrain_tensor + rainfall_tensor) / 2

    with torch.no_grad():
        output = model(combined_tensor)

    output_predictions = (output.squeeze().cpu().numpy() > 0.5).astype(np.uint8)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(output_predictions, cmap='jet', alpha=0.5)
    ax.set_title("Predicted Flooded Area")
    ax.axis("off")

    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    ax.margins(0, 0)
    fig.canvas.draw()

    output_image = Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
    output_image = output_image.convert("RGB")

    return output_image

# Gradio Interface
def create_gradio_interface():
    inputs = [
        gr.Image(type="pil", label="Upload Terrain Image (RGB)"),
        gr.Image(type="pil", label="Upload Rainfall Image (RGB)")
    ]
    outputs = gr.Image(type="pil", label="Flood Prediction Output")

    gr.Interface(
        fn=predict_flood,
        inputs=inputs,
        outputs=outputs,
        live=True,
        description="Upload terrain and rainfall images to predict flood areas."
    ).launch()

# Script entry point: build and launch the Gradio app only when run directly.
if __name__ == "__main__":
    create_gradio_interface()