whitney0507 committed on
Commit
7c50066
·
verified ·
1 Parent(s): 5c5832b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +445 -442
app.py CHANGED
@@ -1,442 +1,445 @@
1
- import gradio as gr
2
- import numpy as np
3
- import torch
4
- import torch.nn as nn
5
- from torchvision import transforms
6
- import requests
7
- import os
8
- from PIL import Image
9
- from collections import OrderedDict
10
- from torchvision import models
11
- import torch.nn.functional as F
12
- import matplotlib.pyplot as plt
13
- import cv2
14
- import io
15
- # Import CSS and URL File
16
- css_file_path = os.path.join(os.path.dirname(__file__), "ui.css")
17
- with open(css_file_path,"r") as f:
18
- custom_css = f.read()
19
- # HTML Design
20
- html_welcome_page = """
21
- <div class="container">
22
- <div class="inner-container">
23
- <h1 class="title-text">Welcome to RemoveWeed Weed Detection System</h1>
24
- <img src="https://i.ibb.co/fY1nk315/image-2.png" alt="RemoveWeed Logo" class="logo-container"/>
25
- <p class="description-text">
26
- Project Aim: This system is designed to optimize rice planting schedules with broad-leaved weed detection using machine learning.
27
- </p>
28
- <p class="description-text">
29
- Designed by: Whitney Lim Wan Yee (TP068221)
30
- </p>
31
- </div>
32
- </div>
33
- """
34
- html_system_page ="""
35
- <div class="container">
36
- <img src="https://i.ibb.co/KxMMTmxG/Screenshot-2025-03-28-224907.png" alt="RemoveWeed Logo" class="logo-container-system"/>
37
- <h1 class="system-page-title">RemoveWeed System Overview</h1>
38
- <p class="system-page-description">
39
- This system is designed to help farmers detect broad-leaved weeds in rice fields using machine learning techniques.
40
- The aim is to optimize rice planting schedules and improve crop yield.
41
- </p>
42
- </div>
43
- """
44
- html_project_description = """
45
- <div class="project-container">
46
- <h1 class="project-title">- 🌿 About Project 🌿 -</h1>
47
-
48
- <div class="upper-content">
49
- <div class="left-upper-column">
50
- <div class="chart">
51
- <img src="https://i.ibb.co/j9Ch3xnC/1312103.png" alt="Agricultural consumption of herbicides worldwide from 1990 to 2022" class="chart-image">
52
- <p class="chart-caption">Resource: Statista (2024) - Agricultural consumption of herbicides worldwide from 1990 to 2022 (in 1,000 metric tons)</p>
53
- </div>
54
- </div>
55
- <div class="right-upper-column">
56
- <div class="herbicide-description">
57
- <h2 class="herbicide-title">Herbicide Use Soars: A Shocking Yearly Increase!</h2>
58
- <p class="herbicide-text">
59
- Statista (2024) revealed that global herbicide consumption has reached <span class="bold-red">1.94 million</span> metric tons. To control dock weed in farming fields, the application of herbicides can cause <span class="bold-red">delays</span> in rice planting schedules ranging from <span class="bold-red">7 to 30 days</span>.
60
- </p>
61
- </div>
62
- </div>
63
- </div>
64
-
65
- <div class="middle-content">
66
- <div class="left-middle-column">
67
- <div class="objective-description">
68
- <h2 class="objective-title">Why Choose RemoveWeed?</h2>
69
- <p class="objective-text">
70
- RemoveWeed is a system designed to detect broad-leaved dock weed in paddy fields. It uses object detection like <span class="bold-red">Single Shot Detection (SSD)</span> model, along with instance segmentation models like <span class="bold-red">U-Net</span> and <span class="bold-red">Fully Convolutional Neural Network (FCNN</span>, to predict the presence of dock weed.
71
- </p>
72
- </div>
73
- </div>
74
- <div class="right-middle-column">
75
- <div class="carousel-wrapper">
76
- <div class="carousel-container">
77
- <p class="carousel-title">Broad-leaved Dock Weed in Paddy Field</p>
78
- <div class="carousel">
79
- <div class="image-one"></div>
80
- <div class="image-two"></div>
81
- <div class="image-three"></div>
82
- </div>
83
- </div>
84
- </div>
85
- </div>
86
- </div>
87
-
88
- <div class="bottom-content">
89
- <div class="left-bottom-column">
90
- <div class="Proceed-To-Detection">
91
- <img src="https://i.ibb.co/Txb9LFf5/agriculture-tan.jpg" alt="Model Training" class="model-image">
92
- </div>
93
- </div>
94
- <div class="right-bottom-column">
95
- <div class="benefits-description">
96
- <h2 class="benefits-title">Potential Benefits</h2>
97
- <ul class="benefits-list">
98
- <li>Cost Savings 💰</li>
99
- <li>Reduce Labor and Manual Monitoring Cost 💹</li>
100
- <li>Increase Profitability by Rice Planting Scheduling Advice 📈</li>
101
- <li>Provide Sustainable Practices in Agriculture 🧑‍🌾</li>
102
- <li>Reduce Herbicide Pollution ☢️</li>
103
- </ul>
104
- </div>
105
- </div>
106
- </div>
107
- </div>
108
- """
109
- html_author_review_page = """
110
- <div class="author-section">
111
- <h1 class="author-title">- Project Owner Introduction -</h1>
112
-
113
- <div class="author-content">
114
- <div class="author-image-container">
115
- <img src="https://i.ibb.co/4RZW1Pq4/Wanyu.jpg" alt="Whitney Lim Wan Yee" class="author-image">
116
- </div>
117
-
118
- <div class="author-bio">
119
- <p class="author-text">
120
- Whitney Lim Wan Yee is a student at Asia Pacific University (APU), pursuing Year 3 Computer Science specialization in Data Analytics. She is passionate about machine learning and its applications in agriculture.
121
- </p>
122
-
123
- <div class="social-links">
124
- <a href="https://www.linkedin.com/in/whitneylimwanyee/" target="_blank" class="social-link">
125
- <img src="https://images.rawpixel.com/image_png_800/czNmcy1wcml2YXRlL3Jhd3BpeGVsX2ltYWdlcy93ZWJzaXRlX2NvbnRlbnQvbHIvdjk4Mi1kMy0xMC5wbmc.png" alt="LinkedIn" class="social-icon">
126
- <span>LinkedIn Profile</span>
127
- </a>
128
-
129
- <a href="https://www.kaggle.com/whitneylimwanyee" target="_blank" class="social-link">
130
- <img src="https://cdn4.iconfinder.com/data/icons/logos-and-brands/512/189_Kaggle_logo_logos-512.png" alt="Kaggle" class="social-icon">
131
- <span>Kaggle Profile</span>
132
- </a>
133
-
134
- <button onclick="window.location.href='mailto:whitneylim0719@gmail.com'" class="social-link">
135
- <img src="https://static.vecteezy.com/system/resources/previews/016/716/465/non_2x/gmail-icon-free-png.png" alt="Email" class="social-icon">
136
- <span>Email Me</span>
137
- </button>
138
- <a href="https://drive.google.com/file/d/1SvbvzLpFQJjzX6_VPGS3NddzK0ksXE8r/view" target="_blank" class="social-link">
139
- <img src="https://cdn-icons-png.flaticon.com/512/8347/8347432.png" alt="Kaggle" class="social-icon">
140
- <span>My Resume</span>
141
- </a>
142
- </div>
143
- </div>
144
- </div>
145
- </div>
146
- """
147
-
148
-
149
- js_func = """
150
- function refresh() {
151
- const url = new URL(window.location);
152
-
153
- if (url.searchParams.get('__theme') !== 'light') {
154
- url.searchParams.set('__theme', 'light');
155
- window.location.href = url.href;
156
- }
157
- }
158
- """
159
- def choose_model(choice):
160
- if choice == "Instance Segmentation Model (U-Net)":
161
- return "You have selected U-Net"
162
- else:
163
- return "Invalid selection"
164
- # Gradio Interface
165
- def gradio_interface(selected_model, uploaded_image):
166
- # This will call the predict function and display the results
167
- return predict(selected_model, uploaded_image)
168
-
169
- with gr.Blocks(css=custom_css,js=js_func) as demo:
170
- # State to track current page
171
- page = gr.State(value="welcome")
172
-
173
- # Welcome page container
174
- with gr.Group(visible=True, elem_classes="gradio-container") as welcome_page:
175
- gr.HTML(html_welcome_page) # Insert HTML structure
176
- start_trial_button = gr.Button("Start Trial", variant="primary", elem_classes="trial-button")
177
-
178
- # System description page container (initially hidden)
179
- with gr.Group(visible=False) as system_page:
180
- gr.HTML(html_system_page)
181
- tabs = gr.Tabs()
182
- with tabs:
183
- with gr.TabItem("Project Description"):
184
- tab_state = gr.State(value=0)
185
- gr.HTML(html_project_description)
186
- with gr.TabItem("Model Playground"):
187
- gr.Markdown("""
188
- ### Model Playground:
189
- This section allows users to interact with the model and test its capabilities.
190
- """)
191
- # Model selection radio buttons
192
- radio = gr.Radio(choices=["Instance Segmentation Model (U-Net)"], label="Select Model", elem_classes="model-selection")
193
-
194
- # Output for model selection result
195
- output = gr.Textbox(label="Model Selection Result", elem_classes="model-selection-output")
196
-
197
- # Trigger the choose_model function when a model is selected
198
- radio.change(fn=choose_model, inputs=radio, outputs=output)
199
-
200
- # Image input and upload button
201
- img_input = gr.Image(type="numpy", label="Upload Image", elem_classes="image-input")
202
- upload_image_button = gr.Button("Start Prediction", variant="primary", elem_classes="upload-button")
203
-
204
- # Predicted output
205
- img_output = gr.Image(label="Predicted Image", elem_classes="image-output")
206
-
207
- # Predict and show output when image is uploaded
208
- upload_image_button.click(fn=gradio_interface, inputs=[radio, img_input], outputs=img_output)
209
-
210
-
211
- with gr.TabItem("Open Source API Link"):
212
- gr.Markdown("""
213
- ### Open Source API Link:
214
- This section provides access to the open-source API for the weed detection model.
215
- """)
216
- gr.Markdown("### API Documentation:")
217
- with gr.TabItem("Contact and Review"):
218
- gr.HTML(html_author_review_page)
219
- back_button = gr.Button("Back", variant="secondary",elem_classes="back-button")
220
-
221
-
222
- # Navigation functions
223
- def go_to_system_page():
224
- print("Going to system page")
225
- return gr.update(visible=False), gr.update(visible=True)
226
-
227
- def go_to_welcome_page():
228
- print("Going to welcome page")
229
- return gr.update(visible=True), gr.update(visible=False)
230
-
231
- def process_image(uploaded_image):
232
- # If the image is passed as a numpy array, convert it to a PIL image
233
- if isinstance(uploaded_image, np.ndarray):
234
- image = Image.fromarray(uploaded_image)
235
- elif isinstance(uploaded_image, Image.Image):
236
- image = uploaded_image
237
- else:
238
- raise ValueError("Uploaded image must be either a numpy array or a PIL Image.")
239
-
240
- # Define the necessary transformations
241
- transform = transforms.Compose([
242
- # transforms.Resize((256, 256)), # Resize according to your model's input size
243
- transforms.ToTensor(),
244
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
245
- ])
246
-
247
- # Apply transformations and add batch dimension
248
- image = transform(image).unsqueeze(0)
249
-
250
- return image
251
-
252
-
253
- class DoubleConv(nn.Module):
254
- def __init__(self, in_channels, out_channels):
255
- super().__init__()
256
- self.double_conv = nn.Sequential(
257
- nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
258
- nn.BatchNorm2d(out_channels),
259
- nn.ReLU(inplace=True),
260
- nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
261
- nn.BatchNorm2d(out_channels),
262
- nn.ReLU(inplace=True)
263
- )
264
-
265
- def forward(self, x):
266
- return self.double_conv(x)
267
-
268
- class Down(nn.Module):
269
- def __init__(self, in_channels, out_channels):
270
- super().__init__()
271
- self.maxpool_conv = nn.Sequential(
272
- nn.MaxPool2d(2),
273
- DoubleConv(in_channels, out_channels)
274
- )
275
-
276
- def forward(self, x):
277
- return self.maxpool_conv(x)
278
-
279
- class Up(nn.Module):
280
- def __init__(self, in_channels, out_channels, bilinear=True):
281
- super().__init__()
282
- if bilinear:
283
- self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
284
- else:
285
- self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
286
-
287
- self.conv = DoubleConv(in_channels, out_channels)
288
-
289
- def forward(self, x1, x2):
290
- x1 = self.up(x1)
291
- # Resize x1 to match x2
292
- diffY = x2.size()[2] - x1.size()[2]
293
- diffX = x2.size()[3] - x1.size()[3]
294
- x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
295
- diffY // 2, diffY - diffY // 2])
296
- x = torch.cat([x2, x1], dim=1)
297
- return self.conv(x)
298
-
299
- class OutConv(nn.Module):
300
- def __init__(self, in_channels, out_channels):
301
- super().__init__()
302
- self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
303
-
304
- def forward(self, x):
305
- return self.conv(x)
306
-
307
- class UNet(nn.Module):
308
- def __init__(self, n_channels=3, n_classes=1, bilinear=True):
309
- super().__init__()
310
- self.n_channels = n_channels
311
- self.n_classes = n_classes
312
- self.bilinear = bilinear
313
-
314
- # Encoder
315
- self.inc = DoubleConv(n_channels, 64)
316
- self.down1 = Down(64, 128)
317
- self.down2 = Down(128, 256)
318
- self.down3 = Down(256, 512)
319
- factor = 2 if bilinear else 1
320
- self.down4 = Down(512, 1024 // factor)
321
-
322
- # Decoder
323
- self.up1 = Up(1024, 512 // factor, bilinear)
324
- self.up2 = Up(512, 256 // factor, bilinear)
325
- self.up3 = Up(256, 128 // factor, bilinear)
326
- self.up4 = Up(128, 64, bilinear)
327
- self.outc = OutConv(64, n_classes)
328
-
329
- def forward(self, x):
330
- x1 = self.inc(x)
331
- x2 = self.down1(x1)
332
- x3 = self.down2(x2)
333
- x4 = self.down3(x3)
334
- x5 = self.down4(x4)
335
-
336
- x = self.up1(x5, x4)
337
- x = self.up2(x, x3)
338
- x = self.up3(x, x2)
339
- x = self.up4(x, x1)
340
- logits = self.outc(x)
341
- return torch.sigmoid(logits)
342
- def init_weights(self):
343
- # Initialize with Kaiming initialization
344
- def init_fn(m):
345
- if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
346
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
347
-
348
- self.apply(init_fn)
349
-
350
- def load_UNet_model(model_path):
351
- print(f"Loading model from {model_path}")
352
- model = torch.load(model_path, weights_only=False, map_location=torch.device('cpu')) # Load the model (entire model saved with torch.save)
353
- model.eval() # Set the model to evaluation mode
354
- return model
355
-
356
- def predict(selected_model, uploaded_image):
357
- if selected_model == "Instance Segmentation Model (U-Net)":
358
- print("Predicting using U-Net")
359
- model_path = "UNet_Model.pth" # Path to your trained model
360
- else:
361
- print("Invalid model selected")
362
- return None
363
-
364
- # Visualize predictions (call visualize_predictions)
365
- return visualize_predictions(uploaded_image, model_path)
366
- # Visualization function for contours and IoU
367
-
368
- def visualize_predictions(uploaded_image, model_path="UNet.pth"):
369
- model = load_UNet_model(model_path)
370
- image = process_image(uploaded_image)
371
-
372
- # Make prediction
373
- with torch.no_grad():
374
- output = model(image)
375
- binary_pred = (output > 0.5).float().cpu().numpy() # Prediction as a binary mask
376
- pred_prob = output.squeeze().cpu().numpy() # Prediction probabilities (for heatmap)
377
-
378
- # Visualization part (assumes ground truth is available)
379
- fig, axes = plt.subplots(1, 4, figsize=(16, 4))
380
-
381
- # Original image
382
- img = np.array(uploaded_image) / 255.0 # Normalize the image to [0, 1]
383
- axes[0].imshow(img)
384
- axes[0].set_title('Original Image')
385
- axes[0].axis('off')
386
-
387
- # Ground truth (this is just an example, you should provide the actual mask)
388
- # For the sake of demonstration, we use a dummy mask
389
- ground_truth = np.zeros_like(binary_pred[0, 0])
390
- axes[1].imshow(ground_truth, cmap='gray')
391
- axes[1].set_title('Ground Truth')
392
- axes[1].axis('off')
393
-
394
- # Prediction Probability
395
- axes[2].imshow(pred_prob, cmap='jet', vmin=0, vmax=1)
396
- axes[2].set_title('Prediction Probability')
397
- axes[2].axis('off')
398
-
399
- # Calculate IoU (Intersection over Union)
400
- intersection = np.logical_and(binary_pred[0, 0] > 0.5, ground_truth > 0.5).sum()
401
- union = np.logical_or(binary_pred[0, 0] > 0.5, ground_truth > 0.5).sum()
402
- iou = intersection / union if union > 0 else 0
403
- axes[3].imshow(img)
404
- contours, _ = cv2.findContours(binary_pred[0, 0].astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
405
- contour_img = np.zeros_like(binary_pred[0, 0])
406
- cv2.drawContours(contour_img, contours, -1, 1, 2)
407
-
408
- # Add the contour overlay with IoU text
409
- axes[3].imshow(contour_img, cmap='Reds', alpha=0.5)
410
- axes[3].set_title(f'Prediction Contour')
411
- axes[3].axis('off')
412
-
413
- plt.tight_layout()
414
-
415
- # Save the figure to a BytesIO object and return it as an image
416
- buf = io.BytesIO()
417
- plt.savefig(buf, format='png')
418
- buf.seek(0)
419
- img = Image.open(buf)
420
- return img
421
-
422
- # Connect buttons to navigation functions
423
- start_trial_button.click(
424
- fn=go_to_system_page,
425
- inputs=None, # Pass the current page state
426
- outputs=[welcome_page, system_page]
427
- )
428
-
429
- back_button.click(
430
- fn=go_to_welcome_page,
431
- inputs=None, # Pass the current page state
432
- outputs=[welcome_page, system_page]
433
- )
434
- upload_image_button.click(
435
- fn=predict,
436
- inputs=[radio, img_input],
437
- outputs=img_output
438
- )
439
-
440
-
441
-
442
- demo.launch(share=True)
 
 
 
 
"""Gradio app for the RemoveWeed broad-leaved dock weed detection system."""
import io
import os
import subprocess
import sys
from collections import OrderedDict

# Fix: the original ran `os.system("pip install torch")` BEFORE `import os`,
# which raised NameError on startup. Install torch only when it is actually
# missing, via subprocess with an argument list (no shell string).
try:
    import torch
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "torch"])
    import torch

import gradio as gr
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision import models
import requests
from PIL import Image
import matplotlib.pyplot as plt
import cv2

# Import CSS and URL File: load the stylesheet that ships next to this script.
css_file_path = os.path.join(os.path.dirname(__file__), "ui.css")
with open(css_file_path, "r", encoding="utf-8") as f:
    custom_css = f.read()
# HTML Design: static markup for the landing page, system page, project
# description tab, and author tab, plus a JS snippet forcing the light theme.
html_welcome_page = """
<div class="container">
<div class="inner-container">
<h1 class="title-text">Welcome to RemoveWeed Weed Detection System</h1>
<img src="https://i.ibb.co/fY1nk315/image-2.png" alt="RemoveWeed Logo" class="logo-container"/>
<p class="description-text">
Project Aim: This system is designed to optimize rice planting schedules with broad-leaved weed detection using machine learning.
</p>
<p class="description-text">
Designed by: Whitney Lim Wan Yee (TP068221)
</p>
</div>
</div>
"""
html_system_page = """
<div class="container">
<img src="https://i.ibb.co/KxMMTmxG/Screenshot-2025-03-28-224907.png" alt="RemoveWeed Logo" class="logo-container-system"/>
<h1 class="system-page-title">RemoveWeed System Overview</h1>
<p class="system-page-description">
This system is designed to help farmers detect broad-leaved weeds in rice fields using machine learning techniques.
The aim is to optimize rice planting schedules and improve crop yield.
</p>
</div>
"""
# Fix below: "(FCNN" was missing its closing parenthesis in the visible text.
html_project_description = """
<div class="project-container">
<h1 class="project-title">- 🌿 About Project 🌿 -</h1>

<div class="upper-content">
<div class="left-upper-column">
<div class="chart">
<img src="https://i.ibb.co/j9Ch3xnC/1312103.png" alt="Agricultural consumption of herbicides worldwide from 1990 to 2022" class="chart-image">
<p class="chart-caption">Resource: Statista (2024) - Agricultural consumption of herbicides worldwide from 1990 to 2022 (in 1,000 metric tons)</p>
</div>
</div>
<div class="right-upper-column">
<div class="herbicide-description">
<h2 class="herbicide-title">Herbicide Use Soars: A Shocking Yearly Increase!</h2>
<p class="herbicide-text">
Statista (2024) revealed that global herbicide consumption has reached <span class="bold-red">1.94 million</span> metric tons. To control dock weed in farming fields, the application of herbicides can cause <span class="bold-red">delays</span> in rice planting schedules ranging from <span class="bold-red">7 to 30 days</span>.
</p>
</div>
</div>
</div>

<div class="middle-content">
<div class="left-middle-column">
<div class="objective-description">
<h2 class="objective-title">Why Choose RemoveWeed?</h2>
<p class="objective-text">
RemoveWeed is a system designed to detect broad-leaved dock weed in paddy fields. It uses object detection like <span class="bold-red">Single Shot Detection (SSD)</span> model, along with instance segmentation models like <span class="bold-red">U-Net</span> and <span class="bold-red">Fully Convolutional Neural Network (FCNN)</span>, to predict the presence of dock weed.
</p>
</div>
</div>
<div class="right-middle-column">
<div class="carousel-wrapper">
<div class="carousel-container">
<p class="carousel-title">Broad-leaved Dock Weed in Paddy Field</p>
<div class="carousel">
<div class="image-one"></div>
<div class="image-two"></div>
<div class="image-three"></div>
</div>
</div>
</div>
</div>
</div>

<div class="bottom-content">
<div class="left-bottom-column">
<div class="Proceed-To-Detection">
<img src="https://i.ibb.co/Txb9LFf5/agriculture-tan.jpg" alt="Model Training" class="model-image">
</div>
</div>
<div class="right-bottom-column">
<div class="benefits-description">
<h2 class="benefits-title">Potential Benefits</h2>
<ul class="benefits-list">
<li>Cost Savings 💰</li>
<li>Reduce Labor and Manual Monitoring Cost 💹</li>
<li>Increase Profitability by Rice Planting Scheduling Advice 📈</li>
<li>Provide Sustainable Practices in Agriculture 🧑‍🌾</li>
<li>Reduce Herbicide Pollution ☢️</li>
</ul>
</div>
</div>
</div>
</div>
"""
html_author_review_page = """
<div class="author-section">
<h1 class="author-title">- Project Owner Introduction -</h1>

<div class="author-content">
<div class="author-image-container">
<img src="https://i.ibb.co/4RZW1Pq4/Wanyu.jpg" alt="Whitney Lim Wan Yee" class="author-image">
</div>

<div class="author-bio">
<p class="author-text">
Whitney Lim Wan Yee is a student at Asia Pacific University (APU), pursuing Year 3 Computer Science specialization in Data Analytics. She is passionate about machine learning and its applications in agriculture.
</p>

<div class="social-links">
<a href="https://www.linkedin.com/in/whitneylimwanyee/" target="_blank" class="social-link">
<img src="https://images.rawpixel.com/image_png_800/czNmcy1wcml2YXRlL3Jhd3BpeGVsX2ltYWdlcy93ZWJzaXRlX2NvbnRlbnQvbHIvdjk4Mi1kMy0xMC5wbmc.png" alt="LinkedIn" class="social-icon">
<span>LinkedIn Profile</span>
</a>

<a href="https://www.kaggle.com/whitneylimwanyee" target="_blank" class="social-link">
<img src="https://cdn4.iconfinder.com/data/icons/logos-and-brands/512/189_Kaggle_logo_logos-512.png" alt="Kaggle" class="social-icon">
<span>Kaggle Profile</span>
</a>

<button onclick="window.location.href='mailto:whitneylim0719@gmail.com'" class="social-link">
<img src="https://static.vecteezy.com/system/resources/previews/016/716/465/non_2x/gmail-icon-free-png.png" alt="Email" class="social-icon">
<span>Email Me</span>
</button>
<a href="https://drive.google.com/file/d/1SvbvzLpFQJjzX6_VPGS3NddzK0ksXE8r/view" target="_blank" class="social-link">
<img src="https://cdn-icons-png.flaticon.com/512/8347/8347432.png" alt="Kaggle" class="social-icon">
<span>My Resume</span>
</a>
</div>
</div>
</div>
</div>
"""


# Force the Gradio light theme by rewriting the URL query parameter on load.
js_func = """
function refresh() {
const url = new URL(window.location);

if (url.searchParams.get('__theme') !== 'light') {
url.searchParams.set('__theme', 'light');
window.location.href = url.href;
}
}
"""
def choose_model(choice):
    """Report which segmentation model the user picked.

    Only U-Net is currently offered; any other value is rejected.
    """
    return (
        "You have selected U-Net"
        if choice == "Instance Segmentation Model (U-Net)"
        else "Invalid selection"
    )
# Gradio Interface
def gradio_interface(selected_model, uploaded_image):
    """Thin adapter between the Gradio click event and `predict`."""
    prediction_figure = predict(selected_model, uploaded_image)
    return prediction_figure
# ---------------------------------------------------------------------------
# UI layout. NOTE(review): indentation reconstructed from a flattened paste —
# every component below must be created inside the `with gr.Blocks()` context
# for Gradio to register it.
# ---------------------------------------------------------------------------
with gr.Blocks(css=custom_css,js=js_func) as demo:
    # State to track current page
    # NOTE(review): navigation toggles group visibility directly and never
    # reads this state; it appears to be vestigial.
    page = gr.State(value="welcome")

    # Welcome page container
    with gr.Group(visible=True, elem_classes="gradio-container") as welcome_page:
        gr.HTML(html_welcome_page) # Insert HTML structure
        start_trial_button = gr.Button("Start Trial", variant="primary", elem_classes="trial-button")

    # System description page container (initially hidden)
    with gr.Group(visible=False) as system_page:
        gr.HTML(html_system_page)
        tabs = gr.Tabs()
        with tabs:
            with gr.TabItem("Project Description"):
                # NOTE(review): tab_state is created but never read or updated.
                tab_state = gr.State(value=0)
                gr.HTML(html_project_description)
            with gr.TabItem("Model Playground"):
                gr.Markdown("""
### Model Playground:
This section allows users to interact with the model and test its capabilities.
""")
                # Model selection radio buttons
                radio = gr.Radio(choices=["Instance Segmentation Model (U-Net)"], label="Select Model", elem_classes="model-selection")

                # Output for model selection result
                output = gr.Textbox(label="Model Selection Result", elem_classes="model-selection-output")

                # Trigger the choose_model function when a model is selected
                radio.change(fn=choose_model, inputs=radio, outputs=output)

                # Image input and upload button
                img_input = gr.Image(type="numpy", label="Upload Image", elem_classes="image-input")
                upload_image_button = gr.Button("Start Prediction", variant="primary", elem_classes="upload-button")

                # Predicted output
                img_output = gr.Image(label="Predicted Image", elem_classes="image-output")

                # Predict and show output when image is uploaded
                # NOTE(review): confirm this button ends up with exactly ONE
                # registered handler — the flattened source shows a second
                # `.click(fn=predict, ...)` binding near the end of the file,
                # which would run the model twice per click.
                upload_image_button.click(fn=gradio_interface, inputs=[radio, img_input], outputs=img_output)

            with gr.TabItem("Open Source API Link"):
                gr.Markdown("""
### Open Source API Link:
This section provides access to the open-source API for the weed detection model.
""")
                gr.Markdown("### API Documentation:")
            with gr.TabItem("Contact and Review"):
                gr.HTML(html_author_review_page)
        # "Back" returns to the welcome page; hidden together with system_page.
        back_button = gr.Button("Back", variant="secondary",elem_classes="back-button")
# Navigation functions: each returns a pair of visibility updates for
# (welcome_page, system_page), matching the click handlers' `outputs` order.
def go_to_system_page():
    print("Going to system page")
    hide_welcome = gr.update(visible=False)
    show_system = gr.update(visible=True)
    return hide_welcome, show_system

def go_to_welcome_page():
    print("Going to welcome page")
    show_welcome = gr.update(visible=True)
    hide_system = gr.update(visible=False)
    return show_welcome, hide_system
def process_image(uploaded_image):
    """Convert an uploaded image into a normalized, batched tensor.

    Accepts either a numpy array (Gradio's default image type) or a PIL
    image and returns a float tensor of shape (1, C, H, W) normalized with
    the ImageNet mean/std.
    """
    if isinstance(uploaded_image, Image.Image):
        pil_image = uploaded_image
    elif isinstance(uploaded_image, np.ndarray):
        pil_image = Image.fromarray(uploaded_image)
    else:
        raise ValueError("Uploaded image must be either a numpy array or a PIL Image.")

    # ImageNet normalization. Resizing is intentionally left disabled, as in
    # the original (commented-out Resize) — presumably the model is fully
    # convolutional and tolerates variable input sizes.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Add the leading batch dimension expected by the model.
    return preprocess(pil_image).unsqueeze(0)
class DoubleConv(nn.Module):
    """Two stacked 3x3 conv -> batch-norm -> ReLU stages (the U-Net staple).

    `padding=1` with a 3x3 kernel preserves the spatial size, so only the
    channel count changes: in_channels -> out_channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Attribute name kept identical to the original so pickled/state-dict
        # checkpoints still resolve.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to a (N, C, H, W) tensor."""
        return self.double_conv(x)
class Down(nn.Module):
    """Encoder downscaling stage: 2x2 max-pool, then a DoubleConv.

    Halves the spatial resolution and maps in_channels -> out_channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute name kept identical to the original for checkpoint
        # compatibility.
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels),
        )

    def forward(self, x):
        """Pool then convolve a (N, C, H, W) tensor."""
        return self.maxpool_conv(x)
class Up(nn.Module):
    """Decoder upscaling stage: upsample, pad to the skip tensor, DoubleConv.

    `in_channels` is the channel count AFTER concatenating the upsampled
    tensor with the encoder skip tensor.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        # Bilinear upsampling is parameter-free; the transposed-conv branch
        # keeps the incoming channel count (in_channels // 2).
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        # x1: tensor from the deeper decoder level; x2: encoder skip tensor.
        x1 = self.up(x1)
        # Resize x1 to match x2: with odd input sizes the upsampled map can be
        # smaller than the skip tensor, so pad symmetrically to align H and W.
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # Channel-wise concatenation of skip features and upsampled features.
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature maps to per-class logits."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 1x1 kernel: a pure channel-wise projection, spatial size unchanged.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project (N, in_channels, H, W) to (N, out_channels, H, W)."""
        return self.conv(x)
class UNet(nn.Module):
    """U-Net for binary segmentation: 4 encoder stages, 4 decoder stages.

    forward() passes the final logits through a sigmoid, so it returns
    per-pixel probabilities in [0, 1] with shape (N, n_classes, H, W).
    """

    def __init__(self, n_channels=3, n_classes=1, bilinear=True):
        super().__init__()
        self.n_channels = n_channels  # input image channels (3 = RGB)
        self.n_classes = n_classes    # output mask channels (1 = binary)
        self.bilinear = bilinear      # bilinear upsampling vs transposed conv

        # Encoder
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        # With bilinear upsampling the bottleneck (and each decoder output) is
        # halved so that concatenation with the skip tensors keeps Up()'s
        # channel arithmetic consistent.
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)

        # Decoder
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder path; keep intermediate maps for the skip connections.
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Decoder path, each stage consuming the matching skip tensor.
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        # Sigmoid -> per-pixel foreground probability.
        return torch.sigmoid(logits)

    def init_weights(self):
        """Apply Kaiming (He) initialization to every convolution layer."""
        # Initialize with Kaiming initialization
        def init_fn(m):
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

        self.apply(init_fn)
def load_UNet_model(model_path):
    """Load a fully-pickled U-Net checkpoint onto the CPU in eval mode.

    NOTE(review): `weights_only=False` unpickles arbitrary Python objects —
    this is only safe for checkpoint files from a trusted source.
    """
    print(f"Loading model from {model_path}")
    model = torch.load(model_path, weights_only=False, map_location=torch.device('cpu')) # Load the model (entire model saved with torch.save)
    model.eval() # Set the model to evaluation mode
    return model
def predict(selected_model, uploaded_image):
    """Run the selected model on the uploaded image.

    Returns the rendered prediction figure, or None when the selection does
    not name a supported model.
    """
    # Guard clause: only the U-Net option is supported.
    if selected_model != "Instance Segmentation Model (U-Net)":
        print("Invalid model selected")
        return None

    print("Predicting using U-Net")
    model_path = "UNet_Model.pth"  # Path to your trained model
    # Render the 4-panel prediction figure.
    return visualize_predictions(uploaded_image, model_path)
# Visualization function for contours and IoU
def visualize_predictions(uploaded_image, model_path="UNet.pth"):
    """Run the U-Net on `uploaded_image` and render a 4-panel summary figure.

    Panels: original image, (placeholder) ground truth, prediction
    probability heat map, and the predicted contour overlaid on the image.
    Returns the rendered figure as a PIL image.
    """
    model = load_UNet_model(model_path)
    image = process_image(uploaded_image)

    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        output = model(image)
        binary_pred = (output > 0.5).float().cpu().numpy()  # Binary mask
        pred_prob = output.squeeze().cpu().numpy()  # Probabilities for heatmap

    fig, axes = plt.subplots(1, 4, figsize=(16, 4))

    # Original image, scaled to [0, 1] for imshow.
    img = np.array(uploaded_image) / 255.0
    axes[0].imshow(img)
    axes[0].set_title('Original Image')
    axes[0].axis('off')

    # Ground truth placeholder: no annotation is available at inference time,
    # so an all-zero mask is shown (this also pins the IoU below to 0 unless
    # the prediction is empty as well).
    ground_truth = np.zeros_like(binary_pred[0, 0])
    axes[1].imshow(ground_truth, cmap='gray')
    axes[1].set_title('Ground Truth')
    axes[1].axis('off')

    # Prediction probability heat map.
    axes[2].imshow(pred_prob, cmap='jet', vmin=0, vmax=1)
    axes[2].set_title('Prediction Probability')
    axes[2].axis('off')

    # Intersection over Union against the (placeholder) ground truth.
    intersection = np.logical_and(binary_pred[0, 0] > 0.5, ground_truth > 0.5).sum()
    union = np.logical_or(binary_pred[0, 0] > 0.5, ground_truth > 0.5).sum()
    iou = intersection / union if union > 0 else 0

    # Contour of the predicted mask overlaid on the original image.
    axes[3].imshow(img)
    contours, _ = cv2.findContours(binary_pred[0, 0].astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    contour_img = np.zeros_like(binary_pred[0, 0])
    cv2.drawContours(contour_img, contours, -1, 1, 2)

    # Add the contour overlay with IoU text.
    # Fix: the original used an f-string title with NO placeholder, so the
    # computed IoU was never displayed; include it as intended.
    axes[3].imshow(contour_img, cmap='Reds', alpha=0.5)
    axes[3].set_title(f'Prediction Contour (IoU: {iou:.2f})')
    axes[3].axis('off')

    plt.tight_layout()

    # Render the figure into an in-memory PNG and hand it back as a PIL image.
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    # Fix: close the figure so repeated predictions do not leak matplotlib
    # figures for the lifetime of the process.
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
# Connect buttons to navigation functions.
# NOTE(review): in the running app these bindings execute inside the
# `with gr.Blocks() as demo:` context so Gradio can register them (the
# flattened paste lost the original indentation).
start_trial_button.click(
    fn=go_to_system_page,
    inputs=None,
    outputs=[welcome_page, system_page]
)

back_button.click(
    fn=go_to_welcome_page,
    inputs=None,
    outputs=[welcome_page, system_page]
)

# Fix: a second `upload_image_button.click(fn=predict, ...)` binding was
# registered here on top of the `gradio_interface` binding made where the
# button is created, so every click ran the model twice. The duplicate
# handler has been removed.

demo.launch(share=True)