# Install missing dependencies on first run, then import.
try:
    import torch
    import torchvision
except ImportError:
    import subprocess
    import sys
    print("Attempting to install missing packages...")
    subprocess.check_call([sys.executable, "-m", "pip", "install",
                           "torch", "torchvision", "matplotlib", "numpy",
                           "opencv-python", "Pillow", "gradio"])
    import torch
    import torchvision

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import requests
import os
from PIL import Image
import matplotlib.pyplot as plt
import cv2
import io
# Load the custom CSS that styles the UI
css_file_path = os.path.join(os.path.dirname(__file__), "ui.css")
with open(css_file_path, "r") as f:
    custom_css = f.read()
# HTML Design
html_welcome_page = """
<div class="container">
    <div class="inner-container">
        <h1 class="title-text">Welcome to RemoveWeed Weed Detection System</h1>
        <img src="https://i.ibb.co/fY1nk315/image-2.png" alt="RemoveWeed Logo" class="logo-container"/>
        <p class="description-text">
            Project Aim: This system is designed to optimize rice planting schedules with broad-leaved weed detection using machine learning.
        </p>
        <p class="description-text">
            Designed by: Whitney Lim Wan Yee (TP068221)
        </p>
    </div>
</div>
"""
html_system_page ="""
<div class="container">
    <img src="https://i.ibb.co/KxMMTmxG/Screenshot-2025-03-28-224907.png" alt="RemoveWeed Logo" class="logo-container-system"/>
    <h1 class="system-page-title">RemoveWeed System Overview</h1>
    <p class="system-page-description">
        This system is designed to help farmers detect broad-leaved weeds in rice fields using machine learning techniques. 
        The aim is to optimize rice planting schedules and improve crop yield.
    </p>
</div>
"""
html_project_description = """
<div class="project-container">
    <h1 class="project-title">- 🌿 About Project 🌿 -</h1>
    
    <div class="upper-content">
        <div class="left-upper-column">
            <div class="chart">
                <img src="https://i.ibb.co/j9Ch3xnC/1312103.png" alt="Agricultural consumption of herbicides worldwide from 1990 to 2022" class="chart-image">
                <p class="chart-caption">Resource: Statista (2024) - Agricultural consumption of herbicides worldwide from 1990 to 2022 (in 1,000 metric tons)</p>
            </div>
        </div>
        <div class="right-upper-column">
            <div class="herbicide-description">
                <h2 class="herbicide-title">Herbicide Use Soars: A Shocking Yearly Increase!</h2>
                <p class="herbicide-text">
                    Statista (2024) revealed that global herbicide consumption has reached <span class="bold-red">1.94 million</span> metric tons. To control dock weed in farming fields, 
                    the application of herbicides can cause <span class="bold-red">delays</span> in rice planting schedules ranging from <span class="bold-red">7 to 30 days</span>.
                </p>
            </div>
        </div>
    </div>
    
    <div class="middle-content">
        <div class="left-middle-column">
            <div class="objective-description">
                <h2 class="objective-title">Why Choose RemoveWeed?</h2>
                <p class="objective-text">
                    RemoveWeed is a specialized system that detects broad-leaved dock weed in paddy fields with <span class="bold-red">92%</span> accuracy, 
                    enabling timely interventions that can increase crop yields by up to <span class="bold-red">15%</span>. 
                    Our lightweight U-Net model, built from scratch, processes field images in seconds, allowing farmers to save up to <span class="bold-red">30%</span> on 
                    herbicide costs through precise application. The system integrates seamlessly with existing agricultural technology, 
                    offering a return on investment within a single growing season through reduced labor costs and optimized planting schedules.
                    
                </p>
            </div>
        </div>
        <div class="right-middle-column">
            <div class="carousel-wrapper">
                <div class="carousel-container">
                    <p class="carousel-title">Broad-leaved Dock Weed in Paddy Field</p>
                    <div class="carousel">
                        <div class="image-one"></div>
                        <div class="image-two"></div>
                        <div class="image-three"></div>
                    </div>
                </div>
            </div>      
        </div>
    </div>
    
    <div class="bottom-content">
    <div class="left-bottom-column">
        <div class="Proceed-To-Detection">
            <img src="https://i.ibb.co/Txb9LFf5/agriculture-tan.jpg" alt="Model Training" class="model-image">
        </div>
    </div>
    <div class="right-bottom-column">
         <div class="benefits-description">
            <h2 class="benefits-title">Potential Benefits</h2>
            <ul class="benefits-list">
                <li>Cost Savings 💰</li>
                <li>Reduce Labor and Manual Monitoring Costs 💹</li>
                <li>Increase Profitability through Rice Planting Scheduling Advice 📈</li>
                <li>Provide Sustainable Practices in Agriculture 🧑‍🌾</li>
                <li>Reduce Herbicide Pollution ☢️</li>
            </ul>
        </div>
    </div>
</div>
</div>
"""
html_author_review_page = """
<div class="author-section">
    <h1 class="author-title">- Project Owner Introduction -</h1>
    
    <div class="author-content">
        <div class="author-image-container">
            <img src="https://i.ibb.co/4RZW1Pq4/Wanyu.jpg" alt="Whitney Lim Wan Yee" class="author-image">
        </div>
        
        <div class="author-bio">
            <p class="author-text">
                Whitney Lim Wan Yee is a student at Asia Pacific University (APU), pursuing Year 3 Computer Science specialization in Data Analytics. She is passionate about machine learning and its applications in agriculture.
            </p>
            
            <div class="social-links">
                <a href="https://www.linkedin.com/in/whitneylimwanyee/" target="_blank" class="social-link">
                    <img src="https://images.rawpixel.com/image_png_800/czNmcy1wcml2YXRlL3Jhd3BpeGVsX2ltYWdlcy93ZWJzaXRlX2NvbnRlbnQvbHIvdjk4Mi1kMy0xMC5wbmc.png" alt="LinkedIn" class="social-icon">
                    <span>LinkedIn Profile</span>
                </a>
                
                <a href="https://www.kaggle.com/whitneylimwanyee" target="_blank" class="social-link">
                    <img src="https://cdn4.iconfinder.com/data/icons/logos-and-brands/512/189_Kaggle_logo_logos-512.png" alt="Kaggle" class="social-icon">
                    <span>Kaggle Profile</span>
                </a>
                
                <button onclick="window.location.href='mailto:whitneylim0719@gmail.com'" class="social-link">
                    <img src="https://static.vecteezy.com/system/resources/previews/016/716/465/non_2x/gmail-icon-free-png.png" alt="Email" class="social-icon">
                    <span>Email Me</span>
                </button>
                <a href="https://drive.google.com/file/d/1SvbvzLpFQJjzX6_VPGS3NddzK0ksXE8r/view" target="_blank" class="social-link">
                    <img src="https://cdn-icons-png.flaticon.com/512/8347/8347432.png" alt="Kaggle" class="social-icon">
                    <span>My Resume</span>
                </a>
            </div>
        </div>
    </div>
</div>
"""
"""
"""
html_api_page = """
<div class="api-introduction-section">
    <h1 class="api-page-title">- API Usage Introduction -</h1>
    <p class="api-page-description">
        As U-Net is the most stable and accurate model for detecting dock weed leave in paddy field, 
        this API link is provided to any agriculture research and industries who currently work on 
        IoT base weed detection system could simply use for model creation purpose.
    </p>
    
    <div class="needed-tools">
        <h3>~ The tools that need to be pre-installed ~</h3>
        <div class="tool-icons">
            <img src="https://miro.medium.com/v2/resize:fit:1400/0*adyeTInZ7lebNANK.png" alt="Hugging Face" class="tool-icon">
            <img src="https://wolke.img.univie.ac.at/documentation/general/mkdocs/img/jupyter-logo.png" alt="Jupyter" class="tool-icon">
            <img src="https://miro.medium.com/v2/resize:fit:512/1*IMGOKBIN8qkOBt5CH55NSw.png" alt="TensorFlow" class="tool-icon">
        </div>
    </div>
    
    <div class="api-link-section">
        <h3>API Endpoint:</h3>
        <a href="https://endpoints.huggingface.co/whitney0507/endpoints/unet-model-fky" class="api-link-button">
            Access API Endpoint
        </a>
        <p class="api-link-description">
        Click the API Endpoint button above. Then you will see there is a "Playground" section that allow you to copy and use the model.
        </p>
        <img src="https://i.ibb.co/JFqT40nj/Screenshot-2025-04-15-172825.png" alt="Hugging Face Playground ScreenShot" class="api-link-section">
        
    </div>
    
    <div class="how-to-use">
        <h3>How to Use:</h3>
        <ol>
            <li>Create a Hugging Face Account.</li>
            <li>Create new token.</li>
            <li>Copy the code and replace "hf_XXXXX" with your actual token.</li>
            <li>Install the requests library if you haven't already (pip install requests)</li>
            <li>Modify input <"Hello World"> to  "inputs": image_base64 </li>

        </ol>
    </div>
</div>
"""



js_func = """
function refresh() {
    const url = new URL(window.location);

    if (url.searchParams.get('__theme') !== 'light') {
        url.searchParams.set('__theme', 'light');
        window.location.href = url.href;
    }
}
"""
def choose_model(choice):
    if choice == "Instance Segmentation Model (U-Net)":
        return "You have selected U-Net"
    else:
        return "Invalid selection"
# Gradio Interface

def predict(selected_model, uploaded_image):
    if selected_model == "Instance Segmentation Model (U-Net)":
        print("Predicting using U-Net")
        model_path = "UNet_Model (1).pth"  # Path to the trained model
    else:
        print("Invalid model selected")
        return None, None

    # Get the visualization of the model's prediction
    viz_image = visualize_predictions(uploaded_image, model_path)

    # Get the average confidence of the predicted weed pixels
    # (simplified; a calibrated model would give a more reliable score)
    confidence = get_weed_confidence(uploaded_image, model_path)

    # Generate planting-schedule advice based on that confidence
    advice = generate_advice(confidence)

    return viz_image, advice
def get_weed_confidence(uploaded_image, model_path):
    model = load_UNet_model(model_path)
    image = process_image(uploaded_image)

    # Make prediction
    with torch.no_grad():
        output = model(image)
        pred_prob = output.squeeze().cpu().numpy()

    # Average probability over the pixels classified as weed (prob > 0.5);
    # 0.0 if no pixel crosses the threshold
    confidence = np.mean(pred_prob[pred_prob > 0.5]) if np.any(pred_prob > 0.5) else 0.0

    return confidence
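# Worked example (illustrative numbers, not real model output): if the pixels above
# the 0.5 threshold have probabilities [0.9, 0.8, 0.7], then
#     confidence = np.mean([0.9, 0.8, 0.7]) = 0.8
# which generate_advice() below maps to the high-risk card (confidence > 0.7).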

# Add this function to generate advice based on confidence
def generate_advice(confidence):
    if confidence > 0.7:  # High confidence of weed detection
        advice = """
        <div class="advice-card high-risk">
            <h3>🚨 High Dock Weed Infestation Detected</h3>
            <div class="advice-content">
                <div class="advice-section">
                    <h4>Planting Schedule Impact:</h4>
                    <p>Delay rice planting by <strong>14-21 days</strong> to allow for proper weed control</p>
                </div>
                <div class="advice-section">
                    <h4>Recommended Actions:</h4>
                    <ul>
                        <li>Apply targeted herbicide treatment within 3-5 days</li>
                        <li>Consider mechanical removal for dense areas</li>
                        <li>Schedule follow-up inspection after 10 days</li>
                    </ul>
                </div>
                <div class="advice-section">
                    <h4>Long-term Strategy:</h4>
                    <p>Implement crop rotation plan next season to break weed cycle</p>
                </div>
            </div>
        </div>
        """
    elif confidence > 0.3:  # Medium confidence
        advice = """
        <div class="advice-card medium-risk">
            <h3>⚠️ Moderate Dock Weed Presence Detected</h3>
            <div class="advice-content">
                <div class="advice-section">
                    <h4>Planting Schedule Impact:</h4>
                    <p>Consider delaying rice planting by <strong>7-10 days</strong> for weed control</p>
                </div>
                <div class="advice-section">
                    <h4>Recommended Actions:</h4>
                    <ul>
                        <li>Spot treatment with selective herbicide</li>
                        <li>Monitor field closely during next 2 weeks</li>
                        <li>Apply pre-emergent herbicide before planting</li>
                    </ul>
                </div>
                <div class="advice-section">
                    <h4>Long-term Strategy:</h4>
                    <p>Evaluate field drainage and soil pH to reduce favorable conditions for dock weed</p>
                </div>
            </div>
        </div>
        """
    else:  # Low confidence
        advice = """
        <div class="advice-card low-risk">
            <h3>✅ Minimal/No Dock Weed Detected</h3>
            <div class="advice-content">
                <div class="advice-section">
                    <h4>Planting Schedule Impact:</h4>
                    <p>Proceed with <strong>normal rice planting schedule</strong></p>
                </div>
                <div class="advice-section">
                    <h4>Recommended Actions:</h4>
                    <ul>
                        <li>Continue regular field monitoring</li>
                        <li>Apply standard pre-planting herbicide as preventative measure</li>
                        <li>Maintain good field hygiene practices</li>
                    </ul>
                </div>
                <div class="advice-section">
                    <h4>Long-term Strategy:</h4>
                    <p>Implement regular crop rotation and field monitoring to prevent future weed issues</p>
                </div>
            </div>
        </div>
        """
    
    return advice


with gr.Blocks(css=custom_css, js=js_func) as demo:
    # State to track current page
    page = gr.State(value="welcome")
    
    # Welcome page container
    with gr.Group(visible=True, elem_classes="gradio-container") as welcome_page:
         gr.HTML(html_welcome_page)  # Insert HTML structure
         start_trial_button = gr.Button("Start Trial", variant="primary", elem_classes="trial-button")
    
    # System description page container (initially hidden)
    with gr.Group(visible=False) as system_page:
        gr.HTML(html_system_page)
        tabs = gr.Tabs()
        with tabs:
            with gr.TabItem("Project Description"):
                tab_state = gr.State(value=0)
                gr.HTML(html_project_description) 
            with gr.TabItem("Model Playground"):
                with gr.Column(elem_classes="model-playground-container"):
                    gr.Markdown("""
                        <h2 class="model-playground-header">- Model Playground -</h2>
                        <p class="model-playground-description">This section allows users to interact with the model and test its capabilities. Before attempting model training, please follow the guidelines in the user manual to prevent any issues.</p>
                        """, elem_classes="")
                    
                    gr.Image(
                        value="https://i.ibb.co/4nzB4NH5/Group-15.png",
                        label="User Manual",
                        show_download_button=False,
                        show_label=False,
                        container=False,
                        height=300  # Adjust height as needed
                    )
                                    
                    # For sections 1 and 2 side by side
                    with gr.Row():
                        # Left column - Download Image
                        with gr.Column(elem_classes="section-container"):
                            gr.Markdown("""
                                <h2 class="download-image-header">1. Download Image</h2>
                                <p class="download-image-description">Download these sample images to test with the model.</p>
                                <a href="https://github.com/whitney0507/FYP/blob/main/Sample%20Images%20for%20Download.zip" 
                                   download="Sample Images for Download.zip"
                                   class="download-button"
                                   style="display: inline-block; padding: 10px 20px; background-color: #4CAF50; color: white; text-decoration: none; border-radius: 4px; margin-top: 10px;">
                                   Download Sample Images
                                </a>
                            """)
                        # Right column - Select Model
                        with gr.Column(elem_classes="section-container"):
                            gr.Markdown("""
                                <h2 class="select-model-header">2. Select Model</h2>
                                <p class="select-model-description">Choose the model you want to use for prediction.</p>
                            """)
                            with gr.Column(elem_classes="model-selection-container"):
                                radio = gr.Radio(
                                    choices=["Instance Segmentation Model (U-Net)"], 
                                    label="Click the Model", 
                                    elem_classes="model-selection-radio"
                                )
                                radio.change(fn=choose_model, inputs=radio)

                    # Section 3 below the side-by-side layout
                    gr.Markdown("""
                        <h2 class="download-image-header">3. Drop an image and Click "Start Prediction"</h2>
                        <p class="download-image-description">Sometimes the page may load slowly or the output may be missing. If this happens, please click the button again.</p>
                    """)
                                
                    with gr.Row():
                        # Left column for input image
                        with gr.Column(scale=1):
                            img_input = gr.Image(
                                type="numpy", 
                                label="Upload Image", 
                                elem_classes="image-input"
                            )
                            upload_image_button = gr.Button("Start Prediction", variant="primary", elem_classes="upload-button")
                            
                        
                        # Right column for output image
                        with gr.Column(scale=1):
                            img_output = gr.Image(
                                label="Predicted Image", 
                                elem_classes="image-output"
                            )
                    # Predict and show output when image is uploaded
                    
                     
                    with gr.Column(elem_classes="advice-container"):
                        gr.Markdown("""
                            <h2 class="advice-header">Planting Schedule Recommendation</h2>
                            <p class="advice-description">Based on weed detection results, get personalized advice for your rice planting schedule.</p>
                        """)    # System description page container (initially hidden)
                        advice_output = gr.HTML(
                            label="", 
                            elem_classes="advice-output"
                        )
                upload_image_button.click(
                    fn=predict,
                    inputs=[radio, img_input],
                    outputs=[img_output, advice_output]
                )
                
            with gr.TabItem("Open Source API Link"):
                gr.HTML(html_api_page) 
            with gr.TabItem("Contact and Review"):
                gr.HTML(html_author_review_page)
        back_button = gr.Button("Back", variant="secondary",elem_classes="back-button")


    # Navigation functions
    def go_to_system_page():
        print("Going to system page")
        return gr.update(visible=False), gr.update(visible=True)

    def go_to_welcome_page():
        print("Going to welcome page")
        return gr.update(visible=True), gr.update(visible=False)
   
    def process_image(uploaded_image):
        # If the image is passed as a numpy array, convert it to a PIL image
        if isinstance(uploaded_image, np.ndarray):
            image = Image.fromarray(uploaded_image)
        elif isinstance(uploaded_image, Image.Image):
            image = uploaded_image
        else:
            raise ValueError("Uploaded image must be either a numpy array or a PIL Image.")

        # Drop any alpha channel so the 3-channel normalization below always applies
        image = image.convert("RGB")

        # ImageNet normalization statistics, matching the training pipeline
        transform = transforms.Compose([
            # transforms.Resize((256, 256)),  # Enable if the model expects a fixed input size
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Apply transformations and add batch dimension: (1, 3, H, W)
        image = transform(image).unsqueeze(0)

        return image
    
    
    class DoubleConv(nn.Module):
        def __init__(self, in_channels, out_channels):
            super().__init__()
            self.double_conv = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )

        def forward(self, x):
            return self.double_conv(x)

    class Down(nn.Module):
        def __init__(self, in_channels, out_channels):
            super().__init__()
            self.maxpool_conv = nn.Sequential(
                nn.MaxPool2d(2),
                DoubleConv(in_channels, out_channels)
            )

        def forward(self, x):
            return self.maxpool_conv(x)

    class Up(nn.Module):
        def __init__(self, in_channels, out_channels, bilinear=True):
            super().__init__()
            if bilinear:
                self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            else:
                self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
            
            self.conv = DoubleConv(in_channels, out_channels)

        def forward(self, x1, x2):
            x1 = self.up(x1)
            # Resize x1 to match x2
            diffY = x2.size()[2] - x1.size()[2]
            diffX = x2.size()[3] - x1.size()[3]
            x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                            diffY // 2, diffY - diffY // 2])
            x = torch.cat([x2, x1], dim=1)
            return self.conv(x)

    class OutConv(nn.Module):
        def __init__(self, in_channels, out_channels):
            super().__init__()
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        def forward(self, x):
            return self.conv(x)

    class UNet(nn.Module):
        def __init__(self, n_channels=3, n_classes=1, bilinear=True):
            super().__init__()
            self.n_channels = n_channels
            self.n_classes = n_classes
            self.bilinear = bilinear

            # Encoder
            self.inc = DoubleConv(n_channels, 64)
            self.down1 = Down(64, 128)
            self.down2 = Down(128, 256)
            self.down3 = Down(256, 512)
            factor = 2 if bilinear else 1
            self.down4 = Down(512, 1024 // factor)
            
            # Decoder
            self.up1 = Up(1024, 512 // factor, bilinear)
            self.up2 = Up(512, 256 // factor, bilinear)
            self.up3 = Up(256, 128 // factor, bilinear)
            self.up4 = Up(128, 64, bilinear)
            self.outc = OutConv(64, n_classes)

        def forward(self, x):
            x1 = self.inc(x)
            x2 = self.down1(x1)
            x3 = self.down2(x2)
            x4 = self.down3(x3)
            x5 = self.down4(x4)

            x = self.up1(x5, x4)
            x = self.up2(x, x3)
            x = self.up3(x, x2)
            x = self.up4(x, x1)
            logits = self.outc(x)
            return torch.sigmoid(logits)
        def init_weights(self):
            # Initialize with Kaiming initialization
            def init_fn(m):
                if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            
            self.apply(init_fn)
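    # Shape sanity check (a sketch only; not executed at app startup).
    # Assumed input: a 256x256 RGB batch. With bilinear=True, factor=2 halves the
    # deepest encoder output (down4 emits 512 channels, not 1024), so each Up block
    # concatenates an upsampled tensor with its same-width skip connection.
    #     net = UNet(n_channels=3, n_classes=1)
    #     x = torch.randn(1, 3, 256, 256)
    #     assert net(x).shape == (1, 1, 256, 256)  # per-pixel sigmoid probabilities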
    
    def load_UNet_model(model_path):
        print(f"Loading model from {model_path}")
        model = torch.load(model_path, weights_only=False, map_location=torch.device('cpu'))  # Load the model (entire model saved with torch.save)
        model.eval()  # Set the model to evaluation mode
        return model
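    # Note: torch.load here expects the file to contain the entire pickled model
    # (saved with torch.save(model, path)), which is presumably why the UNet classes
    # are (re)defined above. If only a state_dict had been saved, a sketch of the
    # alternative loading path would be:
    #     model = UNet(n_channels=3, n_classes=1)
    #     model.load_state_dict(torch.load(model_path, map_location="cpu"))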
    
    
    def visualize_predictions(uploaded_image, model_path="UNet_Model (1).pth"):
        model = load_UNet_model(model_path)
        image = process_image(uploaded_image)

        # Make prediction
        with torch.no_grad():
            output = model(image)
            binary_pred = (output > 0.5).float().cpu().numpy()  # Binary weed mask
            pred_prob = output.squeeze().cpu().numpy()  # Per-pixel probabilities (for heatmap)

        fig, axes = plt.subplots(1, 4, figsize=(16, 4))

        # Original image
        img = np.array(uploaded_image) / 255.0  # Normalize the image to [0, 1]
        axes[0].imshow(img)
        axes[0].set_title('Original Image')
        axes[0].axis('off')

        # Ground truth placeholder: no annotation is available at inference time,
        # so an all-zero mask is shown. Any IoU against it would always be 0 and
        # is therefore not reported.
        ground_truth = np.zeros_like(binary_pred[0, 0])
        axes[1].imshow(ground_truth, cmap='gray')
        axes[1].set_title('Ground Truth (placeholder)')
        axes[1].axis('off')

        # Prediction probability heatmap
        axes[2].imshow(pred_prob, cmap='jet', vmin=0, vmax=1)
        axes[2].set_title('Prediction Probability')
        axes[2].axis('off')

        # Outline the predicted weed regions on top of the original image
        axes[3].imshow(img)
        contours, _ = cv2.findContours(binary_pred[0, 0].astype(np.uint8),
                                       cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contour_img = np.zeros_like(binary_pred[0, 0])
        cv2.drawContours(contour_img, contours, -1, 1, 2)
        axes[3].imshow(contour_img, cmap='Reds', alpha=0.5)
        axes[3].set_title('Prediction Contour')
        axes[3].axis('off')

        plt.tight_layout()

        # Render the figure to a PNG in memory and return it as a PIL image
        buf = io.BytesIO()
        plt.savefig(buf, format='png')
        plt.close(fig)  # Free the figure to avoid leaking memory across requests
        buf.seek(0)
        return Image.open(buf)
   
    # Connect buttons to navigation functions
    start_trial_button.click(
        fn=go_to_system_page,
        inputs=None,
        outputs=[welcome_page, system_page]
    )

    back_button.click(
        fn=go_to_welcome_page,
        inputs=None,
        outputs=[welcome_page, system_page]
    )
    
    

demo.launch(share=True)