File size: 3,060 Bytes
d5ed229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08bd9f6
2c30e99
d5ed229
 
 
 
 
 
50b0082
d5ed229
 
 
 
 
336a3c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1fb5fc
2b340f4
f1fb5fc
336a3c8
 
f1fb5fc
2b340f4
f1fb5fc
336a3c8
 
f1fb5fc
2b340f4
f1fb5fc
336a3c8
 
f1fb5fc
2b340f4
f1fb5fc
336a3c8
 
 
d5ed229
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import streamlit as st
import tensorflow as tf
from io import BytesIO
import numpy as np
import cv2
import base64

def set_page_background(png_file):
    """Inject page-level CSS that uses an image file as the app background.

    The image is read from disk, base64-encoded and embedded directly in the
    stylesheet as a ``data:`` URI. The same stylesheet also hides Streamlit's
    main menu and footer.

    Args:
        png_file: Path to the background image. Despite the parameter name,
            the MIME type is inferred from the file extension (the app passes
            a ``.jpg``); unknown or non-image extensions fall back to
            ``image/png``.
    """
    import mimetypes

    @st.cache_data(show_spinner=False)
    def get_base64_of_bin_file(bin_file):
        # Cached by Streamlit so the file is not re-read and re-encoded on
        # every script rerun.
        with open(bin_file, 'rb') as f:
            data = f.read()
        return base64.b64encode(data).decode()

    # Declare the real image type in the data URI instead of always claiming
    # PNG (which relied on the browser content-sniffing e.g. a JPEG).
    mime_type = mimetypes.guess_type(png_file)[0]
    if mime_type is None or not mime_type.startswith('image/'):
        mime_type = 'image/png'

    bin_str = get_base64_of_bin_file(png_file)
    # NOTE(review): `icon`, `nav-link` and `nav-link-selected` below have no
    # leading `.`, so they are element selectors that only match tags with
    # those literal names; they look like keys copied from a component's
    # styles dict -- confirm whether class selectors were intended.
    custom_css = f'''
        <style>
            .stApp {{
                background-image: url("data:{mime_type};base64,{bin_str}");
                background-size: cover;
                background-repeat: no-repeat;
                background-attachment: scroll;
            }}
            
            #MainMenu {{visibility: hidden;}}
            footer {{visibility: hidden;}}
            icon {{color: white;}}
            nav-link {{--hover-color: grey; }}
            nav-link-selected {{background-color: #4ABF7E;}}
        </style>
    '''
    st.markdown(custom_css, unsafe_allow_html=True)

# Display names for the model's outputs, indexed by the argmax position of the
# prediction vector (same mapping as the original 0..3 if/elif chain).
_CLASS_LABELS = (
    "Mildly Demented",
    "Moderately Demented",
    "Not Demented",
    "Very Mildly Demented",
)

# Target size passed to cv2.resize, which takes (width, height); the image is
# square so the order does not matter here.
_IMG_SIZE = (176, 176)


@st.cache_resource(show_spinner=False)
def _load_model():
    """Build the Keras model and load its trained weights.

    Wrapped in ``st.cache_resource`` so the model is loaded once and reused
    across reruns instead of being re-read from disk on every interaction.
    """
    loaded = tf.keras.models.load_model('./model/model_1.h5')
    loaded.load_weights('./model/best_model_custom_1.h5')
    return loaded


def _preprocess(image_bytes):
    """Decode raw image bytes into a model-ready batch.

    Args:
        image_bytes: Encoded image file contents (e.g. JPEG/PNG bytes).

    Returns:
        A float32 array of shape (1, 176, 176, 3) with values in [0, 1], or
        ``None`` if the bytes are empty or OpenCV cannot decode them.
    """
    if not image_bytes:
        return None
    img = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_GRAYSCALE)
    if img is None:
        return None
    # IMREAD_GRAYSCALE always yields a single-channel uint8 image; replicate it
    # to 3 channels, presumably because the network expects 3-channel input.
    img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    img = cv2.resize(img, _IMG_SIZE)
    # Always scale uint8 to [0, 1]; the old `max() > 1` guard skipped scaling
    # for near-black images and fed the model inconsistently scaled input.
    img = img.astype(np.float32) / 255.0
    return np.expand_dims(img, axis=0)


set_page_background("./BG.jpg")
st.title("Stages of Alzheimer's Disease (AD) Prediction")

st.markdown("[Dataset Source](https://www.kaggle.com/datasets/tourist55/alzheimers-dataset-4-class-of-images)")

model = _load_model()

uploaded_file = st.file_uploader("Upload a brain MRI image here", type=["jpg", "png", "jpeg"])

if uploaded_file is not None:
    # getvalue() returns the full upload independent of the stream position,
    # so display and decoding cannot starve each other of bytes.
    image_bytes = uploaded_file.getvalue()
    st.image(BytesIO(image_bytes), use_column_width=True, clamp=True)

    predict_button = st.button("ㅤㅤPredictㅤㅤ")
    if predict_button:
        batch = _preprocess(image_bytes)
        if batch is None:
            st.error("Could not decode the uploaded image.")
        else:
            pred = model.predict(batch)
            class_idx = int(np.argmax(pred[0]))
            if class_idx < len(_CLASS_LABELS):
                # Reported as a probability, assuming the model's last layer
                # is a softmax -- TODO confirm against the model definition.
                probability = pred[0][class_idx]
                st.markdown(f" Stage: {_CLASS_LABELS[class_idx]}")
                st.markdown(f" Prediction Probability: {probability}")
            else:
                st.warning("Error!")