Subhajit01 commited on
Commit
94bf006
·
verified ·
1 Parent(s): e517585

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +205 -205
app.py CHANGED
@@ -1,205 +1,205 @@
1
- import streamlit as st
2
- import google.generativeai as genai
3
- import os
4
- import torch
5
- from dotenv import load_dotenv
6
- import torch.nn as nn
7
- from torchvision import models, transforms
8
- import torch.nn.functional as F
9
- import matplotlib.pyplot as plt
10
- import seaborn as sns
11
- import numpy as np
12
- from PIL import Image
13
- from io import BytesIO
14
- from predictor import *
15
- from image_processor import *
16
- import google.generativeai as genai
17
- import gdown
18
- load_dotenv()
19
- api_key = os.getenv("API_KEY")
20
- genai.configure(api_key=api_key)
21
-
22
- gen_model = genai.GenerativeModel('gemini-1.5-flash-latest')
23
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
- @st.cache_resource
25
- def download_model():
26
- file_id = "1Ovlm72q3sa6BxobWb-6QKTAF-uv753kZ"
27
- url = f'https://drive.google.com/uc?export=download&id={file_id}'
28
- gdown.download(url, 'model.pth', quiet=False)
29
- model = models.vgg16(weights = 'VGG16_Weights.DEFAULT')
30
- model.classifier[6] = nn.Linear(4096,4)
31
- return model
32
-
33
- model = download_model()
34
-
35
- def createForm(prediction):
36
- st.markdown("<h2 style='text-align: center;'>To get the report fill the details</h2>", unsafe_allow_html=True)
37
- with st.form(key='user_info_form'):
38
- patient_name = st.text_input("Patient Name")
39
- patient_age = st.number_input("Age", min_value=0, max_value=120)
40
- patient_gender = st.selectbox("Gender", ["Male", "Female", "Other"])
41
- patient_symptoms = st.text_area("Other Symptoms", placeholder="Describe the symptoms...")
42
- submit_button = st.form_submit_button("Generate Report")
43
- if submit_button:
44
- if not patient_name:
45
- st.error('Patient name is required!!')
46
- elif not patient_age:
47
- st.error('Patient age is required!!')
48
- if patient_symptoms == "":
49
- patient_symptoms = 'NIL'
50
- user_data = {
51
- "Patient Name": patient_name,
52
- "Age": patient_age,
53
- "Gender": patient_gender,
54
- "Symptoms": patient_symptoms
55
- }
56
- st.markdown(f"### Report for {patient_name}")
57
- st.write("(The report generation will be solely based on the symptoms provided and prediction)")
58
- generate_and_display_report(prediction, user_data)
59
-
60
- def generate_report(prediction,patient_details):
61
- prompt = f"""You have to generate a medical report based on the predicted brain tumor by MRI: {prediction} and the patient details provided below.
62
- ### Patient details:
63
- 1. **Patient Name: {patient_details["Patient Name"]}**
64
- 2. **Patient Age: {patient_details["Age"]}**
65
- 3. **Patient Gender: {patient_details["Gender"]}**
66
- 4. **Patient Symptoms: {patient_details["Symptoms"]}**
67
-
68
- ### Report Instructions:
69
- 1. **Include the patient details in the header without patient symptoms** as listed above. Each piece of information must be on a separate line (e.g., "Patient Name" on its own line, followed by "Patient Age" on its own line).
70
- 2. **Do not place any of the details on the same line.** Each detail must appear separately as shown in the list.
71
- 3. After the patient details, generate a medical report in subsections with each section containing a maximum of 5 lines:
72
- - Diagnosis
73
- - Possible Cause of the condition based on patient details
74
- - Treatment options and recommendations
75
- - Prognosis
76
- 4. If Patient Symptoms is not empty then analyse those symptoms in diagnosis.
77
- 5. **Strictly follow these formatting rules**:
78
- - No bullet points or extra punctuation other than what's necessary for a medical report.
79
- """
80
-
81
- response = gen_model.generate_content(
82
- contents=prompt
83
- )
84
-
85
- if response._done and response._result and 'candidates' in response._result:
86
- report_content = response.text
87
- return report_content
88
- else:
89
- return "Error: Report generation failed."
90
-
91
- def generate_and_display_report(prediction,patient_details):
92
- report = generate_report(prediction, patient_details)
93
- st.markdown("<h2 style='text-align: center; background-color: #17253b'>Generated Report</h2>", unsafe_allow_html=True)
94
- st.write(report)
95
- st.download_button(
96
- label="Download Text Report",
97
- data=report,
98
- file_name="generated_report.txt",
99
- mime="text/plain"
100
- )
101
-
102
- def stats(logits):
103
-
104
- probabilities = F.softmax(logits, dim=-1).detach().cpu().numpy()
105
- if probabilities.ndim > 1:
106
- probabilities = probabilities[0]
107
- class_labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
108
- probabilities_normalized = probabilities / np.sum(probabilities)
109
- percentages = np.round(probabilities_normalized * 100, 2)
110
- fig, ax = plt.subplots(figsize=(10, 6))
111
- norm = plt.Normalize(vmin=np.min(probabilities_normalized), vmax=np.max(probabilities_normalized))
112
- colors = plt.cm.coolwarm(norm(probabilities_normalized))
113
- sns.barplot(x=probabilities_normalized, y=class_labels, palette=colors, orient='h', width=0.2, ax=ax)
114
- plt.title("Probabilities Window")
115
- plt.xlabel("Probability")
116
- plt.ylabel("Predictions")
117
- for i, percentage in enumerate(percentages):
118
- plt.text(probabilities_normalized[i] + 0.02, i, f'{percentage}%', va='center', fontsize=12)
119
- st.pyplot(fig)
120
- st.markdown(
121
- """
122
- <style>
123
- body {
124
- background-color: black;
125
- color: white; /* Set text color to white for visibility */
126
- }
127
- .stApp {
128
- background-color: black;
129
- color: white;
130
- }
131
- </style>
132
- """,
133
- unsafe_allow_html=True
134
- )
135
-
136
- st.markdown("""
137
- <style>
138
- .stForm {
139
- background-color: #ffffff;
140
- padding: 20px;
141
- border-radius: 10px;
142
- box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.1);
143
- }
144
- .stDownloadButton button{
145
- color: #eb4634;
146
- border: 1px solid #eb4634
147
- border-radius: 8px;
148
- }
149
- .stFormSubmitButton button{
150
- color: #eb4634;
151
- border: 1px solid #eb4634
152
- border-radius: 8px;
153
- }
154
- </style>
155
- """, unsafe_allow_html=True)
156
-
157
- st.title("Brain Tumor Classification")
158
- uploaded_file = st.file_uploader("Upload an MRI Image", type=["png", "jpg", "jpeg"])
159
- if uploaded_file is not None:
160
- prediction, outputs = predict(uploaded_file, model, device)
161
- image = Image.open(uploaded_file).convert('RGB')
162
- saliency_image, blue_image, grad_image = processor(model, uploaded_file, device)
163
- # st.image(image, caption="Uploaded Image", width=300)
164
- fig, axes = plt.subplots(1,2, figsize=(15,5))
165
-
166
- axes[0].imshow(image)
167
- axes[0].set_title("Original Image" , color="white")
168
- axes[0].axis('off')
169
- axes[1].imshow(blue_image)
170
- axes[1].set_title("BW transformation", color="white")
171
- axes[1].axis('off')
172
- fig.patch.set_facecolor("black")
173
- st.pyplot(fig)
174
- fig1, axes1 = plt.subplots(1,2, figsize=(15,5))
175
-
176
- axes1[0].imshow(grad_image)
177
- axes1[0].set_title("GRAD transformation", color="white")
178
- axes1[0].axis('off')
179
- axes1[1].imshow(saliency_image)
180
- axes1[1].set_title("SALIENT", color="white")
181
- axes1[1].axis('off')
182
- fig1.patch.set_facecolor("black")
183
- st.pyplot(fig1)
184
- st.markdown(
185
- f"""
186
- <div style='
187
- margin-top: 20px;
188
- padding: 15px;
189
- background-color: #28292b;
190
- color: white;
191
- border-radius: 5px;
192
- text-align: left;
193
- display: inline-block;
194
- width: 100%;'>
195
- <h2 style='margin: 0; padding: 0;'>Prediction: <b>{prediction}</b></h2>
196
- </div>
197
- """,
198
- unsafe_allow_html=True
199
- )
200
- # logits = torch.randn(1, 4)
201
- st.markdown("<h1 style='text-align: center;'>Analysis</h1>", unsafe_allow_html=True)
202
- stats(outputs)
203
- # Report Section
204
- createForm(prediction)
205
-
 
1
+ import streamlit as st
2
+ import google.generativeai as genai
3
+ import os
4
+ import torch
5
+ from dotenv import load_dotenv
6
+ import torch.nn as nn
7
+ from torchvision import models, transforms
8
+ import torch.nn.functional as F
9
+ import matplotlib.pyplot as plt
10
+ import seaborn as sns
11
+ import numpy as np
12
+ from PIL import Image
13
+ from io import BytesIO
14
+ from predictor import *
15
+ from image_processor import *
16
+ import google.generativeai as genai
17
+ import gdown
18
+ # load_dotenv()
19
+ api_key = os.getenv("GOOGLE_API_KEY")
20
+ genai.configure(api_key=api_key)
21
+
22
+ gen_model = genai.GenerativeModel('gemini-1.5-flash-latest')
23
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+ @st.cache_resource
25
+ def download_model():
26
+ file_id = "1Ovlm72q3sa6BxobWb-6QKTAF-uv753kZ"
27
+ url = f'https://drive.google.com/uc?export=download&id={file_id}'
28
+ gdown.download(url, 'model.pth', quiet=False)
29
+ model = models.vgg16(weights = 'VGG16_Weights.DEFAULT')
30
+ model.classifier[6] = nn.Linear(4096,4)
31
+ return model
32
+
33
+ model = download_model()
34
+
35
+ def createForm(prediction):
36
+ st.markdown("<h2 style='text-align: center;'>To get the report fill the details</h2>", unsafe_allow_html=True)
37
+ with st.form(key='user_info_form'):
38
+ patient_name = st.text_input("Patient Name")
39
+ patient_age = st.number_input("Age", min_value=0, max_value=120)
40
+ patient_gender = st.selectbox("Gender", ["Male", "Female", "Other"])
41
+ patient_symptoms = st.text_area("Other Symptoms", placeholder="Describe the symptoms...")
42
+ submit_button = st.form_submit_button("Generate Report")
43
+ if submit_button:
44
+ if not patient_name:
45
+ st.error('Patient name is required!!')
46
+ elif not patient_age:
47
+ st.error('Patient age is required!!')
48
+ if patient_symptoms == "":
49
+ patient_symptoms = 'NIL'
50
+ user_data = {
51
+ "Patient Name": patient_name,
52
+ "Age": patient_age,
53
+ "Gender": patient_gender,
54
+ "Symptoms": patient_symptoms
55
+ }
56
+ st.markdown(f"### Report for {patient_name}")
57
+ st.write("(The report generation will be solely based on the symptoms provided and prediction)")
58
+ generate_and_display_report(prediction, user_data)
59
+
60
+ def generate_report(prediction,patient_details):
61
+ prompt = f"""You have to generate a medical report based on the predicted brain tumor by MRI: {prediction} and the patient details provided below.
62
+ ### Patient details:
63
+ 1. **Patient Name: {patient_details["Patient Name"]}**
64
+ 2. **Patient Age: {patient_details["Age"]}**
65
+ 3. **Patient Gender: {patient_details["Gender"]}**
66
+ 4. **Patient Symptoms: {patient_details["Symptoms"]}**
67
+
68
+ ### Report Instructions:
69
+ 1. **Include the patient details in the header without patient symptoms** as listed above. Each piece of information must be on a separate line (e.g., "Patient Name" on its own line, followed by "Patient Age" on its own line).
70
+ 2. **Do not place any of the details on the same line.** Each detail must appear separately as shown in the list.
71
+ 3. After the patient details, generate a medical report in subsections with each section containing a maximum of 5 lines:
72
+ - Diagnosis
73
+ - Possible Cause of the condition based on patient details
74
+ - Treatment options and recommendations
75
+ - Prognosis
76
+ 4. If Patient Symptoms is not empty then analyse those symptoms in diagnosis.
77
+ 5. **Strictly follow these formatting rules**:
78
+ - No bullet points or extra punctuation other than what's necessary for a medical report.
79
+ """
80
+
81
+ response = gen_model.generate_content(
82
+ contents=prompt
83
+ )
84
+
85
+ if response._done and response._result and 'candidates' in response._result:
86
+ report_content = response.text
87
+ return report_content
88
+ else:
89
+ return "Error: Report generation failed."
90
+
91
+ def generate_and_display_report(prediction,patient_details):
92
+ report = generate_report(prediction, patient_details)
93
+ st.markdown("<h2 style='text-align: center; background-color: #17253b'>Generated Report</h2>", unsafe_allow_html=True)
94
+ st.write(report)
95
+ st.download_button(
96
+ label="Download Text Report",
97
+ data=report,
98
+ file_name="generated_report.txt",
99
+ mime="text/plain"
100
+ )
101
+
102
+ def stats(logits):
103
+
104
+ probabilities = F.softmax(logits, dim=-1).detach().cpu().numpy()
105
+ if probabilities.ndim > 1:
106
+ probabilities = probabilities[0]
107
+ class_labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
108
+ probabilities_normalized = probabilities / np.sum(probabilities)
109
+ percentages = np.round(probabilities_normalized * 100, 2)
110
+ fig, ax = plt.subplots(figsize=(10, 6))
111
+ norm = plt.Normalize(vmin=np.min(probabilities_normalized), vmax=np.max(probabilities_normalized))
112
+ colors = plt.cm.coolwarm(norm(probabilities_normalized))
113
+ sns.barplot(x=probabilities_normalized, y=class_labels, palette=colors, orient='h', width=0.2, ax=ax)
114
+ plt.title("Probabilities Window")
115
+ plt.xlabel("Probability")
116
+ plt.ylabel("Predictions")
117
+ for i, percentage in enumerate(percentages):
118
+ plt.text(probabilities_normalized[i] + 0.02, i, f'{percentage}%', va='center', fontsize=12)
119
+ st.pyplot(fig)
120
+ st.markdown(
121
+ """
122
+ <style>
123
+ body {
124
+ background-color: black;
125
+ color: white; /* Set text color to white for visibility */
126
+ }
127
+ .stApp {
128
+ background-color: black;
129
+ color: white;
130
+ }
131
+ </style>
132
+ """,
133
+ unsafe_allow_html=True
134
+ )
135
+
136
+ st.markdown("""
137
+ <style>
138
+ .stForm {
139
+ background-color: #ffffff;
140
+ padding: 20px;
141
+ border-radius: 10px;
142
+ box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.1);
143
+ }
144
+ .stDownloadButton button{
145
+ color: #eb4634;
146
+ border: 1px solid #eb4634
147
+ border-radius: 8px;
148
+ }
149
+ .stFormSubmitButton button{
150
+ color: #eb4634;
151
+ border: 1px solid #eb4634
152
+ border-radius: 8px;
153
+ }
154
+ </style>
155
+ """, unsafe_allow_html=True)
156
+
157
+ st.title("Brain Tumor Classification")
158
+ uploaded_file = st.file_uploader("Upload an MRI Image", type=["png", "jpg", "jpeg"])
159
+ if uploaded_file is not None:
160
+ prediction, outputs = predict(uploaded_file, model, device)
161
+ image = Image.open(uploaded_file).convert('RGB')
162
+ saliency_image, blue_image, grad_image = processor(model, uploaded_file, device)
163
+ # st.image(image, caption="Uploaded Image", width=300)
164
+ fig, axes = plt.subplots(1,2, figsize=(15,5))
165
+
166
+ axes[0].imshow(image)
167
+ axes[0].set_title("Original Image" , color="white")
168
+ axes[0].axis('off')
169
+ axes[1].imshow(blue_image)
170
+ axes[1].set_title("BW transformation", color="white")
171
+ axes[1].axis('off')
172
+ fig.patch.set_facecolor("black")
173
+ st.pyplot(fig)
174
+ fig1, axes1 = plt.subplots(1,2, figsize=(15,5))
175
+
176
+ axes1[0].imshow(grad_image)
177
+ axes1[0].set_title("GRAD transformation", color="white")
178
+ axes1[0].axis('off')
179
+ axes1[1].imshow(saliency_image)
180
+ axes1[1].set_title("SALIENT", color="white")
181
+ axes1[1].axis('off')
182
+ fig1.patch.set_facecolor("black")
183
+ st.pyplot(fig1)
184
+ st.markdown(
185
+ f"""
186
+ <div style='
187
+ margin-top: 20px;
188
+ padding: 15px;
189
+ background-color: #28292b;
190
+ color: white;
191
+ border-radius: 5px;
192
+ text-align: left;
193
+ display: inline-block;
194
+ width: 100%;'>
195
+ <h2 style='margin: 0; padding: 0;'>Prediction: <b>{prediction}</b></h2>
196
+ </div>
197
+ """,
198
+ unsafe_allow_html=True
199
+ )
200
+ # logits = torch.randn(1, 4)
201
+ st.markdown("<h1 style='text-align: center;'>Analysis</h1>", unsafe_allow_html=True)
202
+ stats(outputs)
203
+ # Report Section
204
+ createForm(prediction)
205
+