Engineer786 committed on
Commit
371de10
·
verified ·
1 Parent(s): 9150bcc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -0
app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ import gradio as gr
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+
9
+ # Define the Custom U-Net Model
10
class UNet(nn.Module):
    """Minimal encoder/decoder segmentation network producing a 1-channel mask.

    NOTE(review): despite the name, this model has no skip connections, so it
    is a plain encoder-decoder rather than a true U-Net.

    Args:
        in_channels: number of input image channels (default 3 for RGB).
        out_channels: number of output mask channels (default 1).
    """

    def __init__(self, in_channels=3, out_channels=1):
        super(UNet, self).__init__()

        # Encoder: two 3x3 convs, then downsample by 2x.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Middle: widen to 128 channels, downsample by another 2x (4x total).
        self.middle = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # BUG FIX: the original decoder upsampled only once (a single
        # stride-2 ConvTranspose2d) after two stride-2 max-pools, so the
        # output mask came out at HALF the input resolution. Upsample twice
        # so the mask's spatial size matches the input image.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, out_channels, kernel_size=2, stride=2),
        )

        # 1x1 conv over the logits, then sigmoid to map into [0, 1].
        self.final_conv = nn.Conv2d(out_channels, out_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a per-pixel probability map in [0, 1] with x's H and W.

        Args:
            x: float tensor of shape (N, in_channels, H, W); H and W must be
               divisible by 4 (two 2x downsampling stages).

        Returns:
            Tensor of shape (N, out_channels, H, W) with values in [0, 1].
        """
        enc = self.encoder(x)
        mid = self.middle(enc)
        dec = self.decoder(mid)
        output = self.final_conv(dec)
        return self.sigmoid(output)
47
+
48
# Initialize Model
# NOTE(review): no checkpoint is loaded here, so the network runs with
# randomly initialized weights — its predictions are not meaningful until
# trained weights are restored (e.g. model.load_state_dict(...)).
# TODO: confirm whether a weights file was meant to accompany this app.
model = UNet(in_channels=3, out_channels=1)
model.eval()  # inference mode for the Gradio handler below
51
+
52
+ # Preprocess Images
53
def preprocess_image(image):
    """Resize, tensorize, and ImageNet-normalize a PIL image.

    Args:
        image: a PIL image (any size, RGB).

    Returns:
        A (1, 3, 256, 256) float tensor, batched and ready for the model.
    """
    pipeline = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    tensor = pipeline(image)
    # Add the leading batch dimension expected by the network.
    return tensor.unsqueeze(0)
60
+
61
+ # Prediction Function
62
def predict_flood(image_terrain, image_rainfall):
    """Run the flood model on terrain + rainfall imagery and render the mask.

    Args:
        image_terrain: terrain image — a PIL image (as delivered by
            ``gr.Image(type="pil")``) or a path/file-like object.
        image_rainfall: rainfall image, same accepted types.

    Returns:
        A PIL RGB image visualizing the thresholded flood prediction.
    """
    def _to_rgb(img):
        # BUG FIX: gr.Image(type="pil") passes a PIL image, and the original
        # called Image.open() on it, which raises. Only open when we are
        # given a path or file-like object.
        if isinstance(img, Image.Image):
            return img.convert("RGB")
        return Image.open(img).convert("RGB")

    image_terrain = _to_rgb(image_terrain)
    image_rainfall = _to_rgb(image_rainfall)

    terrain_tensor = preprocess_image(image_terrain)
    rainfall_tensor = preprocess_image(image_rainfall)

    # Naive pixelwise average of the two normalized inputs.
    # NOTE(review): presumably a placeholder fusion strategy — confirm.
    combined_tensor = (terrain_tensor + rainfall_tensor) / 2

    with torch.no_grad():
        output = model(combined_tensor)

    # Binarize the probability map at 0.5.
    output_predictions = (output.squeeze().cpu().numpy() > 0.5).astype(np.uint8)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(output_predictions, cmap='jet', alpha=0.5)
    ax.set_title("Predicted Flooded Area")
    ax.axis("off")

    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    ax.margins(0, 0)
    fig.canvas.draw()

    # BUG FIX: canvas.tostring_rgb() was deprecated in Matplotlib 3.8 and
    # removed in 3.10; buffer_rgba() is the supported way to grab the
    # rendered canvas. Convert RGBA -> RGB for the returned image.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    plt.close(fig)  # BUG FIX: avoid leaking one figure per request

    return Image.fromarray(rgba).convert("RGB")
89
+
90
+ # Gradio Interface
91
def create_gradio_interface():
    """Build the Gradio UI for flood prediction and launch it."""
    terrain_input = gr.Image(type="pil", label="Upload Terrain Image (RGB)")
    rainfall_input = gr.Image(type="pil", label="Upload Rainfall Image (RGB)")
    prediction_output = gr.Image(type="pil", label="Flood Prediction Output")

    # live=True re-runs the model whenever either input changes.
    interface = gr.Interface(
        fn=predict_flood,
        inputs=[terrain_input, rainfall_input],
        outputs=prediction_output,
        live=True,
        description="Upload terrain and rainfall images to predict flood areas."
    )
    interface.launch()
105
+
106
# Launch the web app only when executed as a script (not on import).
if __name__ == "__main__":
    create_gradio_interface()