File size: 4,263 Bytes
e6748e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73176e5
e6748e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107

import streamlit as st
from PIL import Image
from src.ui.drawable_canvas import drawable_canvas
from src.ui.streamlit_ui import streamlit_ui
from src.segmentation import segment_everything
from src.utils import calculate_parameters, plot_distribution, calculate_pixel_length, plot_cumulative_frequency
from ultralytics import YOLO
import torch
import cv2


# Cache the model and device once per process (shared across reruns/sessions).
@st.cache_resource
def load_model_and_initialize(model_path="src/model/FastSAM-x.pt"):
    """Load the FastSAM/YOLO weights and choose the torch device.

    ``st.cache_resource`` is used instead of ``st.cache_data``: the return
    value is a live model object. ``cache_data`` pickles its return value and
    hands each caller a fresh deserialized copy, which is wasteful for model
    weights. ``cache_resource`` returns the same object every time.

    Args:
        model_path: Path to the weights file passed to ``YOLO``. Defaults to
            ``src/model/FastSAM-x.pt``, the path used previously.

    Returns:
        A ``(model, device)`` tuple: the ``YOLO`` model and a ``torch.device``
        that is ``cuda`` when CUDA is available, otherwise ``cpu``. The model
        is not moved to ``device`` here; the caller passes ``device`` along
        separately.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = YOLO(model_path)
    return model, device

def _pixel_length_from_canvas(canvas_result):
    """Return the length of the first object drawn on the canvas, or None.

    Reads ``canvas_result.json_data["objects"][0]`` and takes its
    ``x1``/``y1``/``x2``/``y2`` keys as the line's endpoints. Returns None when
    ``json_data`` is missing or contains no drawn objects.
    """
    json_data = canvas_result.json_data
    if json_data is None or not json_data.get("objects"):
        return None

    line_object = json_data["objects"][0]
    start_point = [line_object['x1'], line_object['y1']]
    end_point = [line_object['x2'], line_object['y2']]
    # NOTE(review): the length is in canvas coordinates. If drawable_canvas
    # shows the image resized to input_size while segment_everything reports
    # annotations in another pixel space, this length needs rescaling. The
    # earlier, never-applied image-size scale factor hinted at that — confirm.
    return calculate_pixel_length(start_point, end_point)


def _display_parameters(df):
    """Render the per-object table, a CSV download, and distribution plots."""
    if df.empty:
        st.write("No objects detected.")
        return

    st.write("Summary of Object Parameters:")
    st.dataframe(df)

    csv = df.to_csv(index=False)
    st.download_button(
        label="Download data as CSV",
        data=csv,
        file_name='grain_parameters.csv',
        mime='text/csv',
    )

    plot_cumulative_frequency(df)
    # The 'Object' column is excluded from the parameters offered for plotting.
    filtered_columns = [col for col in df.columns.tolist() if col != 'Object']
    selected_parameter = st.selectbox("Select a parameter to see its distribution:", filtered_columns)

    if selected_parameter:
        plot_distribution(df, selected_parameter)
    else:
        st.write("No parameter selected for plotting.")


def main():
    """Main application logic.

    1. Read user settings from ``streamlit_ui``.
    2. Show the uploaded image on a drawing canvas. The first drawn object's
       length, in canvas pixels, calibrates the scale.
    3. Compute ``scale_factor = real_world_length / pixel_length``.
    4. Segment the image with the cached model and display it.
    5. Compute per-object parameters with that scale and display them.

    Any unexpected exception is logged with its traceback and reported to the
    user through ``st.error``.
    """
    import logging  # local import keeps this block self-contained

    (
        uploaded_image,
        input_size,
        iou_threshold,
        conf_threshold,
        better_quality,
        contour_thickness,
        real_world_length,
        max_det,
    ) = streamlit_ui()

    if uploaded_image is None:
        st.write("Please upload an image.")
        return

    try:
        canvas_result = drawable_canvas(uploaded_image, input_size)
        pixel_length = _pixel_length_from_canvas(canvas_result)
        if pixel_length is None:
            st.write("Please draw a line to set the scale or enter the real-world length.")
        else:
            st.write(f"Pixel length of the line: {pixel_length}")

        # A missing or non-positive real-world length cannot yield a usable scale.
        if pixel_length is None or real_world_length is None or real_world_length <= 0:
            st.write("Scale factor could not be calculated. Make sure to draw a line and enter the real-world length.")
            return
        # A zero-length line (a click without dragging) would divide by zero below.
        if pixel_length <= 0:
            st.write("The drawn line has zero length. Please draw a longer line to set the scale.")
            return

        scale_factor = real_world_length / pixel_length

        input_image = Image.open(uploaded_image)

        # Load the model and device from cache
        model, device = load_model_and_initialize()

        segmented_image, annotations = segment_everything(
            input_image,
            model=model,
            device=device,
            input_size=input_size,
            iou_threshold=iou_threshold,
            conf_threshold=conf_threshold,
            better_quality=better_quality,
            contour_thickness=contour_thickness,
            max_det=max_det
        )

        st.image(segmented_image, caption="Segmented Image", use_column_width=True)

        # Calculate and display object parameters
        df = calculate_parameters(annotations, scale_factor)
        _display_parameters(df)

    except Exception:
        # Top-level UI boundary: record the traceback so "check the logs" is
        # actionable, then show a generic message to the user.
        logging.getLogger(__name__).exception("Error while processing the uploaded image")
        st.error("An error occurred during processing. Please check the logs for details.")

# Entry point when this module is executed as a script (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()