File size: 8,575 Bytes
0cdb35f
 
 
 
 
 
f74cf62
0cdb35f
 
 
 
 
8794644
 
0cdb35f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f74cf62
0cdb35f
f74cf62
0cdb35f
 
f74cf62
0cdb35f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f74cf62
 
0cdb35f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd0da6b
0cdb35f
 
 
 
 
 
 
 
 
 
bd0da6b
 
0cdb35f
 
 
 
 
 
 
 
 
 
 
44d9568
0cdb35f
 
 
 
 
 
 
8794644
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e95037
8794644
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e95037
8794644
 
 
 
fb6a0d1
8794644
fb6a0d1
8794644
fb6a0d1
8794644
fb6a0d1
 
8794644
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e95037
8794644
 
 
 
 
 
 
 
 
 
 
fb6a0d1
8794644
 
 
 
 
44d9568
0cdb35f
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
import io
import logging

import cv2
import numpy as np
import streamlit as st
from PIL import Image

from model_server import get_predictor

# Page config: browser-tab title and icon, wide layout, sidebar open on load.
# The icon must be a real emoji character; Streamlit treats other strings as an
# image path/URL.
st.set_page_config(
    page_title="VREyeSAM - Non-frontal Iris Segmentation",
    page_icon="👁️",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS: padded main area, full-width green buttons, and a bordered
# ".result-box" class (not referenced elsewhere in the visible code).
st.markdown("""
    <style>
    .main {
        padding: 2rem;
    }
    .stButton>button {
        width: 100%;
        background-color: #4CAF50;
        color: white;
        padding: 0.5rem;
        font-size: 16px;
    }
    .result-box {
        border: 2px solid #ddd;
        border-radius: 10px;
        padding: 1rem;
        margin: 1rem 0;
    }
    </style>
""", unsafe_allow_html=True)

@st.cache_resource
def _load_predictor_cached():
    """Build the predictor via ``get_predictor()`` and cache it for reuse.

    Any exception propagates to the caller. Streamlit does not cache a call
    that raised, so a failed load is retried on the next rerun instead of
    being remembered as a permanent failure.
    """
    return get_predictor()


def load_model():
    """Return the shared VREyeSAM predictor, or ``None`` if loading fails.

    The predictor comes from ``model_server.get_predictor()`` and is cached
    by ``_load_predictor_cached``. Only successful loads are cached.

    On failure, a generic "Error loading model" message is shown in the page.
    The full traceback is written to the server log rather than the UI.

    Returns:
        The predictor object, or ``None`` when loading raised an exception.
    """
    try:
        return _load_predictor_cached()
    except Exception:
        # Keep the on-page message generic; the details go to the log.
        logging.getLogger(__name__).exception("Error loading model")
        st.error("Error loading model")
        return None

def read_and_resize_image(image, max_size=1024):
    """Convert an image to a 3-channel RGB array and downscale it if too large.

    Args:
        image: A PIL image or anything ``np.array`` accepts (e.g. an ndarray).
            PIL images in modes other than "RGB", "RGBA" or "L" (palette "P",
            "LA", "CMYK", "1", ...) are first converted with
            ``convert("RGB")``. Without this, palette indices would be read as
            grey levels, and 2- or 4-channel data would reach the model
            unconverted.
        max_size: Largest allowed width/height in pixels. Bigger images are
            shrunk with their aspect ratio preserved. Smaller images are
            never enlarged.

    Returns:
        The image as a NumPy array. Grayscale, RGB and RGBA inputs come back
        with shape ``(H, W, 3)``.
    """
    # Modes the array path below already handles correctly; every other PIL
    # mode is normalised by Pillow itself.
    array_safe_modes = ("RGB", "RGBA", "L")
    if hasattr(image, "convert") and getattr(image, "mode", "RGB") not in array_safe_modes:
        image = image.convert("RGB")

    img = np.array(image)
    if img.ndim == 2:  # Grayscale
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.shape[2] == 4:  # RGBA
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)

    # Resize if needed (downscale only).
    height, width = img.shape[:2]
    scale = min(max_size / width, max_size / height)
    if scale < 1:
        # max(1, ...) keeps a very thin image from collapsing to 0 pixels.
        new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
        img = cv2.resize(img, new_size)

    return img

def segment_iris(predictor, image, num_samples=30):
    """Run iris segmentation on a preprocessed image via the model server.

    Args:
        predictor: Object returned by ``load_model()``. Only its
            ``predict(image, num_samples=...)`` method is used.
        image: Image array, as produced by ``read_and_resize_image``.
        num_samples: Forwarded unchanged as ``predictor.predict``'s
            ``num_samples`` argument. It defaults to 30, the value this
            function always used before. Its exact meaning is defined by the
            model server (presumably a sample/prompt count; TODO confirm).

    Returns:
        Whatever ``predictor.predict`` returns. ``main`` unpacks it as
        ``(binary_mask, prob_mask)``.
    """
    return predictor.predict(image, num_samples=num_samples)

def overlay_mask_on_image(image, binary_mask, color=(0, 255, 0), alpha=0.5):
    """Tint the masked pixels of *image* with *color*; leave the rest as-is.

    Each pixel where ``binary_mask > 0`` becomes
    ``(1 - alpha) * pixel + alpha * color``. Pixels outside the mask are
    copied unchanged. The previous whole-image blend against a black canvas
    darkened them by ``(1 - alpha)``.

    Args:
        image: ``(H, W, C)`` array, where ``C`` matches ``len(color)``.
        binary_mask: ``(H, W)`` array; any value ``> 0`` marks a mask pixel.
        color: Tint colour, in the same channel order as *image*.
        alpha: Weight of *color* in masked pixels (0 = no tint, 1 = solid).

    Returns:
        A new array with the same shape and dtype as *image*. The input is
        not modified. Integer results are rounded and clipped to the dtype's
        range.
    """
    result = image.copy()
    region = np.asarray(binary_mask) > 0
    if not region.any():
        return result

    tint = np.asarray(color, dtype=np.float64)
    blended = image[region].astype(np.float64) * (1 - alpha) + tint * alpha
    if np.issubdtype(result.dtype, np.integer):
        # Match cv2.addWeighted's round-and-saturate behaviour for uint8 data.
        limits = np.iinfo(result.dtype)
        blended = np.clip(np.rint(blended), limits.min, limits.max)
    result[region] = blended.astype(result.dtype)

    return result

# Main App

# Session-state key holding the latest segmentation result. Results must
# survive reruns: Streamlit reruns the script on every download-button click
# or sidebar toggle, and st.button is only True on the click rerun itself.
_RESULT_STATE_KEY = "vreyesam_segmentation_result"


def _render_sidebar():
    """Draw the sidebar "About" text and settings; return the two toggles.

    Returns:
        ``(show_overlay, show_probabilistic)`` from the sidebar checkboxes.
    """
    with st.sidebar:
        st.header("About VREyeSAM")
        st.markdown("""
        **VREyeSAM** is a robust non-frontal iris segmentation framework designed for images captured under:
        - Varying gaze directions
        - Partial occlusions
        - Inconsistent lighting conditions
        
        **Model Performance:**
        - Recall: 0.870
        - F1-Score: 0.806
        """)

        st.header("Settings")
        show_overlay = st.checkbox("Show Mask Overlay", value=True)
        show_probabilistic = st.checkbox("Show Probabilistic Mask", value=False)
    return show_overlay, show_probabilistic


def _png_bytes(array):
    """Encode an image array as PNG and return the encoded bytes."""
    buf = io.BytesIO()
    Image.fromarray(array).save(buf, format="PNG")
    return buf.getvalue()


def _upload_key(uploaded_file):
    """Return a value identifying *uploaded_file* across reruns."""
    file_id = getattr(uploaded_file, "file_id", None)
    if file_id is not None:
        return file_id
    return (uploaded_file.name, getattr(uploaded_file, "size", None))


def _render_results(img_array, binary_mask, prob_mask, mask_column,
                    show_overlay, show_probabilistic):
    """Show the masks, overlay, download buttons and statistics for one result.

    Args:
        img_array: Preprocessed RGB image that was segmented.
        binary_mask: Mask array; pixels ``> 0`` are iris.
        prob_mask: Per-pixel probabilities (assumed in [0, 1]; clipped here).
        mask_column: Streamlit column next to the original image, used for
            the binary mask.
        show_overlay: Whether to render and offer the overlay image.
        show_probabilistic: Whether to render the probabilistic mask.
    """
    # Use the same "> 0" test as the statistics, so bool, 0/1 and 0/255
    # masks all render as a clean 0/255 image.
    binary_mask_img = np.where(binary_mask > 0, 255, 0).astype(np.uint8)

    with mask_column:
        st.subheader("🎯 Binary Mask")
        st.image(binary_mask_img, use_container_width=True)

    # Additional results
    st.markdown("---")
    st.subheader("📊 Segmentation Results")

    overlay = None
    result_cols = st.columns(2)

    with result_cols[0]:
        if show_overlay:
            st.markdown("**Overlay View**")
            overlay = overlay_mask_on_image(img_array, binary_mask)
            st.image(overlay, use_container_width=True)

    with result_cols[1]:
        if show_probabilistic:
            st.markdown("**Probabilistic Mask**")
            # Clip first so out-of-range values cannot wrap around in uint8.
            prob_mask_img = (np.clip(prob_mask, 0.0, 1.0) * 255).astype(np.uint8)
            st.image(prob_mask_img, use_container_width=True)

    # Download options
    st.markdown("---")
    st.subheader("💾 Download Results")

    download_cols = st.columns(2)

    with download_cols[0]:
        st.download_button(
            label="Download Binary Mask",
            data=_png_bytes(binary_mask_img),
            file_name="binary_mask.png",
            mime="image/png"
        )

    with download_cols[1]:
        if overlay is not None:
            # The overlay is already RGB (same order as the PIL upload), so it
            # is encoded without any channel conversion.
            st.download_button(
                label="Download Overlay",
                data=_png_bytes(overlay),
                file_name="overlay.png",
                mime="image/png"
            )

    # Statistics
    st.markdown("---")
    st.subheader("📈 Segmentation Statistics")
    stats_cols = st.columns(3)

    mask_area = int(np.count_nonzero(binary_mask > 0))
    total_area = binary_mask.shape[0] * binary_mask.shape[1]
    coverage = (mask_area / total_area) * 100 if total_area else 0.0

    with stats_cols[0]:
        st.metric("Mask Coverage", f"{coverage:.2f}%")
    with stats_cols[1]:
        st.metric("Image Size", f"{img_array.shape[1]}x{img_array.shape[0]}")
    with stats_cols[2]:
        st.metric("Mask Area (pixels)", f"{mask_area:,}")


def _handle_upload(predictor, uploaded_file, show_overlay, show_probabilistic):
    """Show the uploaded image, run segmentation on click, and render results.

    The result is stored in ``st.session_state`` together with a key for the
    uploaded file. It is re-rendered on later reruns for the same file and
    discarded once a different file is uploaded.
    """
    try:
        # Display original image
        image = Image.open(uploaded_file)

        col1, col2 = st.columns(2)

        with col1:
            st.subheader("📷 Original Image")
            st.image(image, use_container_width=True)

        file_key = _upload_key(uploaded_file)

        # Process button
        if st.button("🔍 Segment Iris", type="primary"):
            with st.spinner("Segmenting iris..."):
                try:
                    img_array = read_and_resize_image(image)
                    binary_mask, prob_mask = segment_iris(predictor, img_array)
                except Exception as e:
                    st.session_state.pop(_RESULT_STATE_KEY, None)
                    st.error(f"❌ Error during segmentation: {str(e)}")
                else:
                    st.session_state[_RESULT_STATE_KEY] = {
                        "file_key": file_key,
                        "img_array": img_array,
                        "binary_mask": binary_mask,
                        "prob_mask": prob_mask,
                    }

        result = st.session_state.get(_RESULT_STATE_KEY)
        if result is not None and result["file_key"] == file_key:
            try:
                _render_results(
                    result["img_array"],
                    result["binary_mask"],
                    result["prob_mask"],
                    col2,
                    show_overlay,
                    show_probabilistic,
                )
            except Exception as e:
                st.error(f"❌ Error during segmentation: {str(e)}")

    except Exception as e:
        st.error(f"❌ Error loading image: {str(e)}")
        st.info("Please try uploading a different image or reducing the file size.")


def _render_footer():
    """Draw the centred footer with the GitHub and contact links."""
    st.markdown("---")
    st.markdown("""
    <div style='text-align: center'>
        <p><strong>VREyeSAM</strong> - Virtual Reality Non-Frontal Iris Segmentation</p>
        <p>🔗 <a href='https://github.com/GeetanjaliGTZ/VREyeSAM'>GitHub</a> | 
        📧 <a href='mailto:geetanjalisharma546@gmail.com'>Contact</a></p>
    </div>
    """, unsafe_allow_html=True)


def main():
    """Render the VREyeSAM Streamlit page.

    Steps:
    1. Draw the header and the sidebar.
    2. Load the model; stop early, without the footer, if it is unavailable.
    3. Accept an image upload and delegate segmentation and display to
       ``_handle_upload``.
    4. Draw the footer.
    """
    st.title("👁️ VREyeSAM: Non-Frontal Iris Segmentation")
    st.markdown("""
    Upload a non-frontal iris image captured in VR/AR environments, and VREyeSAM will segment the iris region 
    using a fine-tuned SAM2 model with uncertainty-weighted loss.
    """)

    show_overlay, show_probabilistic = _render_sidebar()

    # Load model
    with st.spinner("Loading VREyeSAM model..."):
        predictor = load_model()

    if predictor is None:
        st.error("Failed to load model. Please check the setup.")
        return

    st.success("✅ Model loaded successfully!")

    # No size limit is set in this call; Streamlit's server.maxUploadSize
    # configuration governs the maximum upload size.
    uploaded_file = st.file_uploader(
        "Upload an iris image (JPG, PNG, JPEG)",
        type=["jpg", "png", "jpeg"],
        help="Upload a non-frontal iris image for segmentation"
    )

    if uploaded_file is None:
        # Forget any result that belonged to a previously uploaded file.
        st.session_state.pop(_RESULT_STATE_KEY, None)
    else:
        _handle_upload(predictor, uploaded_file, show_overlay, show_probabilistic)

    _render_footer()

# Script entry point: call main() when this file is executed directly.
if __name__ == "__main__":
    main()