ALYYAN commited on
Commit
65ab7ab
·
unverified ·
1 Parent(s): 4c814c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -47
app.py CHANGED
@@ -1,35 +1,30 @@
1
- # app.py (Final, Definitive, and Working Version)
2
 
3
  import gradio as gr
4
  from pathlib import Path
5
  import asyncio
 
6
 
7
- # Import backend components from the 'app' folder
8
  from app.prediction import PredictionPipeline
9
  from app.database import add_patient_record, get_all_records
10
 
11
  # --- Initialization ---
12
  prediction_pipeline = PredictionPipeline()
13
-
14
- # --- Point to the locally cloned sample images directory from setup.sh ---
15
  SAMPLE_IMAGE_DIR = Path("sample_images")
16
  try:
17
- # Ensure the directory exists and has images before creating the list
18
  if SAMPLE_IMAGE_DIR.is_dir():
19
  SAMPLE_IMAGES = [str(p) for p in sorted(list(SAMPLE_IMAGE_DIR.glob('*/*.jpeg')))]
20
- if not SAMPLE_IMAGES:
21
- print("Warning: 'sample_images' directory found, but it's empty.")
22
  else:
23
  raise FileNotFoundError
24
  except FileNotFoundError:
25
- print("Warning: 'sample_images' directory not found. Please check setup.sh. Samples will be unavailable.")
26
  SAMPLE_IMAGES = []
27
 
28
- # --- Core Logic (Async Functions) ---
29
  async def process_analysis(patient_name, patient_age, image_list, is_sample=False):
30
- """
31
- Handles the core logic: validates input, gets prediction, saves to DB, and returns UI updates.
32
- """
33
  if not is_sample and (not patient_name or patient_age is None):
34
  raise gr.Error("Patient Name and Age are required.")
35
  if not image_list:
@@ -50,14 +45,13 @@ async def process_analysis(patient_name, patient_age, image_list, is_sample=Fals
50
  confidences["NORMAL" if final_pred == "PNEUMONIA" else "PNEUMONIA"] = 1 - final_conf
51
 
52
  return [
53
- gr.update(visible=False), # uploader_column
54
- gr.update(visible=True), # results_column
55
- gr.update(value=result["watermarked_images"]), # result_images
56
- gr.update(value=confidences) # result_label
57
  ]
58
 
59
  async def refresh_history_table():
60
- """Fetches records from the DB and formats them for the DataFrame."""
61
  records = await get_all_records()
62
  data_for_df = []
63
  if records:
@@ -87,18 +81,15 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")
87
  with gr.Column(elem_id="app_header"):
88
  gr.Markdown("# 🩺 Pneumonia Detection AI", elem_id="app_title")
89
  gr.Markdown("An AI-powered tool to assist in the diagnosis of pneumonia.", elem_id="app_subtitle")
90
-
91
  with gr.Row(elem_id="main_container"):
92
  with gr.Column(scale=1) as uploader_column:
93
  gr.Markdown("### Upload Patient X-Rays")
94
  image_input = gr.File(label="Upload up to 3 Images", file_count="multiple", file_types=["image"], type="filepath")
95
-
96
  with gr.Column(scale=2, visible=False) as results_column:
97
  gr.Markdown("### Analysis Results")
98
  result_images = gr.Gallery(label="Analyzed Images", columns=3, object_fit="contain", height=350, elem_id="results_gallery")
99
  result_label = gr.Label(label="Overall Prediction", num_top_classes=2)
100
  start_over_btn = gr.Button("Start New Analysis", variant="secondary")
101
-
102
  with gr.Group(visible=False) as patient_info_modal:
103
  gr.Markdown("## Enter Patient Details", elem_classes="text-center")
104
  patient_name_modal = gr.Textbox(label="Patient Name", placeholder="e.g., John Doe")
@@ -106,7 +97,6 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")
106
  with gr.Row():
107
  submit_analysis_btn = gr.Button("Analyze Images", variant="primary")
108
  cancel_btn = gr.Button("Cancel", variant="stop")
109
-
110
  with gr.Column(elem_id="bottom_controls"):
111
  with gr.Accordion("About this Tool", open=False):
112
  gr.Markdown(
@@ -126,7 +116,7 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")
126
  with gr.Row():
127
  samples_btn = gr.Button("Try Sample Images")
128
  history_btn = gr.Button("View Patient History")
129
-
130
  with gr.Column(visible=False) as history_page:
131
  gr.Markdown("# 📜 Patient Record History", elem_classes="app_title")
132
  with gr.Row():
@@ -137,20 +127,15 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")
137
  with gr.Column(visible=False) as samples_page:
138
  gr.Markdown("# 🖼️ Sample Image Library", elem_classes="app_title")
139
  gr.Markdown("Select up to 3 images by clicking on them, then click 'Analyze'.")
140
-
141
  sample_gallery = gr.Gallery(value=SAMPLE_IMAGES, label="Sample Images", columns=5, height=400, elem_id="sample_gallery")
142
-
143
- # This hidden textbox will store the list of selected file paths
144
- selected_samples_textbox = gr.Textbox(label="selected", visible=False, elem_id="selected_samples_textbox")
145
-
146
  with gr.Row():
147
  analyze_samples_btn = gr.Button("Analyze Selected Samples", variant="primary")
148
  back_to_main_btn_samp = gr.Button("⬅️ Back to Main App")
149
 
150
  # --- Event Handling Logic ---
151
 
152
- def show_patient_info(files):
153
- return gr.update(visible=True) if files else gr.update(visible=False)
154
  image_input.upload(fn=show_patient_info, inputs=image_input, outputs=patient_info_modal)
155
 
156
  async def submit_and_hide_modal(name, age, files):
@@ -165,26 +150,26 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")
165
  select_js = """
166
  (evt) => {
167
  const gallery = document.querySelector('#sample_gallery .grid-container');
168
- const clicked_img_container = gallery.children[evt.index];
169
- const selected_paths_input = document.querySelector('#selected_samples_textbox textarea');
170
- let selected_paths = selected_paths_input.value ? selected_paths_input.value.split(',').filter(p => p.trim()) : [];
171
- const current_path = clicked_img_container.querySelector('img').alt;
172
-
173
- if (clicked_img_container.classList.contains('selected')) {
174
- clicked_img_container.classList.remove('selected');
175
- selected_paths = selected_paths.filter(p => p !== current_path);
176
  } else {
177
- if (selected_paths.length < 3) {
178
- clicked_img_container.classList.add('selected');
179
- selected_paths.push(current_path);
180
  } else {
181
- alert("You can select a maximum of 3 images.");
182
  }
183
  }
184
- return selected_paths.join(',');
185
  }
186
  """
187
- sample_gallery.select(fn=None, _js=select_js, outputs=[selected_samples_textbox])
188
 
189
  async def handle_sample_analysis(selected_paths_str: str):
190
  selected_images = [path for path in selected_paths_str.split(',') if path]
@@ -192,13 +177,14 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")
192
  raise gr.Error("Please select at least one sample image to analyze.")
193
 
194
  analysis_results = await process_analysis("Sample User", 0, selected_images, is_sample=True)
 
195
  return [
196
  gr.update(visible=True), # main_app
197
  gr.update(visible=False), # samples_page
198
  *analysis_results
199
  ]
200
  analyze_samples_btn.click(fn=handle_sample_analysis, inputs=[selected_samples_textbox], outputs=[main_app, samples_page, uploader_column, results_column, result_images, result_label])
201
-
202
  # --- Page Navigation ---
203
  all_pages = [main_app, history_page, samples_page]
204
 
@@ -207,14 +193,13 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")
207
  return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), records_update]
208
 
209
  def show_samples_page():
210
- # Also clear selections when navigating to the samples page
211
- return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value="")]
212
 
213
  def show_main_page():
214
  return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)]
215
 
216
  history_btn.click(fn=show_history_page_and_refresh, outputs=all_pages + [history_df])
217
- samples_btn.click(fn=show_samples_page, outputs=all_pages + [selected_samples_textbox])
218
  back_to_main_btn_hist.click(fn=show_main_page, outputs=all_pages)
219
  back_to_main_btn_samp.click(fn=show_main_page, outputs=all_pages)
220
 
 
1
+ # app.py (Definitive Final Version)
2
 
3
  import gradio as gr
4
  from pathlib import Path
5
  import asyncio
6
+ from PIL import Image
7
 
8
+ # Import backend components
9
  from app.prediction import PredictionPipeline
10
  from app.database import add_patient_record, get_all_records
11
 
12
  # --- Initialization ---
13
  prediction_pipeline = PredictionPipeline()
14
+ # Point to the locally cloned sample images directory from setup.sh
 
15
  SAMPLE_IMAGE_DIR = Path("sample_images")
16
  try:
 
17
  if SAMPLE_IMAGE_DIR.is_dir():
18
  SAMPLE_IMAGES = [str(p) for p in sorted(list(SAMPLE_IMAGE_DIR.glob('*/*.jpeg')))]
19
+ if not SAMPLE_IMAGES: raise FileNotFoundError
 
20
  else:
21
  raise FileNotFoundError
22
  except FileNotFoundError:
23
+ print("Warning: 'sample_images' directory not found or empty. Please check setup.sh. Samples will be unavailable.")
24
  SAMPLE_IMAGES = []
25
 
26
+ # --- Core Logic Functions ---
27
  async def process_analysis(patient_name, patient_age, image_list, is_sample=False):
 
 
 
28
  if not is_sample and (not patient_name or patient_age is None):
29
  raise gr.Error("Patient Name and Age are required.")
30
  if not image_list:
 
45
  confidences["NORMAL" if final_pred == "PNEUMONIA" else "PNEUMONIA"] = 1 - final_conf
46
 
47
  return [
48
+ gr.update(visible=False),
49
+ gr.update(visible=True),
50
+ gr.update(value=result["watermarked_images"]),
51
+ gr.update(value=confidences)
52
  ]
53
 
54
  async def refresh_history_table():
 
55
  records = await get_all_records()
56
  data_for_df = []
57
  if records:
 
81
  with gr.Column(elem_id="app_header"):
82
  gr.Markdown("# 🩺 Pneumonia Detection AI", elem_id="app_title")
83
  gr.Markdown("An AI-powered tool to assist in the diagnosis of pneumonia.", elem_id="app_subtitle")
 
84
  with gr.Row(elem_id="main_container"):
85
  with gr.Column(scale=1) as uploader_column:
86
  gr.Markdown("### Upload Patient X-Rays")
87
  image_input = gr.File(label="Upload up to 3 Images", file_count="multiple", file_types=["image"], type="filepath")
 
88
  with gr.Column(scale=2, visible=False) as results_column:
89
  gr.Markdown("### Analysis Results")
90
  result_images = gr.Gallery(label="Analyzed Images", columns=3, object_fit="contain", height=350, elem_id="results_gallery")
91
  result_label = gr.Label(label="Overall Prediction", num_top_classes=2)
92
  start_over_btn = gr.Button("Start New Analysis", variant="secondary")
 
93
  with gr.Group(visible=False) as patient_info_modal:
94
  gr.Markdown("## Enter Patient Details", elem_classes="text-center")
95
  patient_name_modal = gr.Textbox(label="Patient Name", placeholder="e.g., John Doe")
 
97
  with gr.Row():
98
  submit_analysis_btn = gr.Button("Analyze Images", variant="primary")
99
  cancel_btn = gr.Button("Cancel", variant="stop")
 
100
  with gr.Column(elem_id="bottom_controls"):
101
  with gr.Accordion("About this Tool", open=False):
102
  gr.Markdown(
 
116
  with gr.Row():
117
  samples_btn = gr.Button("Try Sample Images")
118
  history_btn = gr.Button("View Patient History")
119
+
120
  with gr.Column(visible=False) as history_page:
121
  gr.Markdown("# 📜 Patient Record History", elem_classes="app_title")
122
  with gr.Row():
 
127
  with gr.Column(visible=False) as samples_page:
128
  gr.Markdown("# 🖼️ Sample Image Library", elem_classes="app_title")
129
  gr.Markdown("Select up to 3 images by clicking on them, then click 'Analyze'.")
 
130
  sample_gallery = gr.Gallery(value=SAMPLE_IMAGES, label="Sample Images", columns=5, height=400, elem_id="sample_gallery")
131
+ selected_samples_textbox = gr.Textbox(visible=False, elem_id="selected_samples_textbox")
 
 
 
132
  with gr.Row():
133
  analyze_samples_btn = gr.Button("Analyze Selected Samples", variant="primary")
134
  back_to_main_btn_samp = gr.Button("⬅️ Back to Main App")
135
 
136
  # --- Event Handling Logic ---
137
 
138
+ def show_patient_info(files): return gr.update(visible=True) if files else gr.update(visible=False)
 
139
  image_input.upload(fn=show_patient_info, inputs=image_input, outputs=patient_info_modal)
140
 
141
  async def submit_and_hide_modal(name, age, files):
 
150
  select_js = """
151
  (evt) => {
152
  const gallery = document.querySelector('#sample_gallery .grid-container');
153
+ const clicked_container = gallery.children[evt.index];
154
+ const hidden_input = document.querySelector('#selected_samples_textbox textarea');
155
+ let selections = hidden_input.value ? hidden_input.value.split(',').filter(p => p.trim()) : [];
156
+ const path = clicked_container.querySelector('img').alt;
157
+
158
+ if (clicked_container.classList.contains('selected')) {
159
+ clicked_container.classList.remove('selected');
160
+ selections = selections.filter(p => p !== path);
161
  } else {
162
+ if (selections.length < 3) {
163
+ clicked_container.classList.add('selected');
164
+ selections.push(path);
165
  } else {
166
+ alert("Maximum of 3 images can be selected.");
167
  }
168
  }
169
+ return [selections.join(',')]; // Return value must be a list/tuple for Gradio
170
  }
171
  """
172
+ sample_gallery.select(fn=None, js=select_js, outputs=[selected_samples_textbox])
173
 
174
  async def handle_sample_analysis(selected_paths_str: str):
175
  selected_images = [path for path in selected_paths_str.split(',') if path]
 
177
  raise gr.Error("Please select at least one sample image to analyze.")
178
 
179
  analysis_results = await process_analysis("Sample User", 0, selected_images, is_sample=True)
180
+ # We need to return an update for every output component
181
  return [
182
  gr.update(visible=True), # main_app
183
  gr.update(visible=False), # samples_page
184
  *analysis_results
185
  ]
186
  analyze_samples_btn.click(fn=handle_sample_analysis, inputs=[selected_samples_textbox], outputs=[main_app, samples_page, uploader_column, results_column, result_images, result_label])
187
+
188
  # --- Page Navigation ---
189
  all_pages = [main_app, history_page, samples_page]
190
 
 
193
  return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), records_update]
194
 
195
  def show_samples_page():
196
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
 
197
 
198
  def show_main_page():
199
  return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)]
200
 
201
  history_btn.click(fn=show_history_page_and_refresh, outputs=all_pages + [history_df])
202
+ samples_btn.click(fn=show_samples_page, outputs=all_pages)
203
  back_to_main_btn_hist.click(fn=show_main_page, outputs=all_pages)
204
  back_to_main_btn_samp.click(fn=show_main_page, outputs=all_pages)
205