ALYYAN commited on
Commit
4c814c5
·
unverified ·
1 Parent(s): 6515a1e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -54
app.py CHANGED
@@ -1,27 +1,35 @@
1
- # app.py (Final Version with Syntax Fix)
2
 
3
  import gradio as gr
4
  from pathlib import Path
5
  import asyncio
6
- from PIL import Image
7
 
8
- # Import backend components
9
  from app.prediction import PredictionPipeline
10
  from app.database import add_patient_record, get_all_records
11
 
12
  # --- Initialization ---
13
  prediction_pipeline = PredictionPipeline()
 
 
14
  SAMPLE_IMAGE_DIR = Path("sample_images")
15
  try:
16
- SAMPLE_IMAGES = [str(p) for p in sorted(list(SAMPLE_IMAGE_DIR.glob('*/*.jpeg')))]
17
- if not SAMPLE_IMAGES:
 
 
 
 
18
  raise FileNotFoundError
19
  except FileNotFoundError:
20
- print("Warning: 'sample_images' directory not found or is empty. Please check setup.sh.")
21
  SAMPLE_IMAGES = []
22
 
23
- # --- Core Logic Functions ---
24
  async def process_analysis(patient_name, patient_age, image_list, is_sample=False):
 
 
 
25
  if not is_sample and (not patient_name or patient_age is None):
26
  raise gr.Error("Patient Name and Age are required.")
27
  if not image_list:
@@ -29,7 +37,7 @@ async def process_analysis(patient_name, patient_age, image_list, is_sample=Fals
29
 
30
  result = prediction_pipeline.predict(image_list)
31
  if "error" in result:
32
- raise gr.Error(result["error"])
33
 
34
  final_pred = result["final_prediction"]
35
  final_conf = result["final_confidence"]
@@ -49,9 +57,12 @@ async def process_analysis(patient_name, patient_age, image_list, is_sample=Fals
49
  ]
50
 
51
  async def refresh_history_table():
 
52
  records = await get_all_records()
53
- data = [[r.get('name'), r.get('age'), r.get('prediction_result'), f"{r.get('confidence_score', 0):.2%}", r.get('timestamp').strftime('%Y-%m-%d %H:%M')] for r in records] if records else []
54
- return gr.update(value=data)
 
 
55
 
56
  # --- Gradio UI Definition ---
57
  css = """
@@ -66,30 +77,28 @@ css = """
66
  #results_gallery .gallery-item { padding: 0.25rem !important; background-color: #374151; border: 1px solid #374151 !important; }
67
  #bottom_controls { max-width: 600px; margin: 2.5rem auto 1rem auto; }
68
  #bottom_controls .gr-accordion > .gr-block-label { text-align: center !important; display: block !important; }
69
- #sample_gallery .gallery-item { box-shadow: 0 0 5px rgba(0,0,0,0.5); border-radius: 8px !important; }
70
- #sample_gallery .gallery-item.selected { border: 4px solid var(--primary-500) !important; }
 
71
  """
72
-
73
- # --- THIS IS THE FIX ---
74
- with gr.Blocks(
75
- theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue"),
76
- css=css,
77
- title="Pneumonia Detection AI"
78
- ) as demo:
79
 
80
  with gr.Column() as main_app:
81
  with gr.Column(elem_id="app_header"):
82
  gr.Markdown("# 🩺 Pneumonia Detection AI", elem_id="app_title")
83
  gr.Markdown("An AI-powered tool to assist in the diagnosis of pneumonia.", elem_id="app_subtitle")
 
84
  with gr.Row(elem_id="main_container"):
85
  with gr.Column(scale=1) as uploader_column:
86
  gr.Markdown("### Upload Patient X-Rays")
87
  image_input = gr.File(label="Upload up to 3 Images", file_count="multiple", file_types=["image"], type="filepath")
 
88
  with gr.Column(scale=2, visible=False) as results_column:
89
  gr.Markdown("### Analysis Results")
90
  result_images = gr.Gallery(label="Analyzed Images", columns=3, object_fit="contain", height=350, elem_id="results_gallery")
91
  result_label = gr.Label(label="Overall Prediction", num_top_classes=2)
92
  start_over_btn = gr.Button("Start New Analysis", variant="secondary")
 
93
  with gr.Group(visible=False) as patient_info_modal:
94
  gr.Markdown("## Enter Patient Details", elem_classes="text-center")
95
  patient_name_modal = gr.Textbox(label="Patient Name", placeholder="e.g., John Doe")
@@ -97,20 +106,27 @@ with gr.Blocks(
97
  with gr.Row():
98
  submit_analysis_btn = gr.Button("Analyze Images", variant="primary")
99
  cancel_btn = gr.Button("Cancel", variant="stop")
 
100
  with gr.Column(elem_id="bottom_controls"):
101
  with gr.Accordion("About this Tool", open=False):
102
  gr.Markdown(
103
  """
104
  ### MLOps-Powered Pneumonia Detection
105
- (Your professional description here)
 
 
 
106
  ---
107
- **Project Team:** Alyyan Ahmed & Munim Akbar
 
 
 
108
  """
109
  )
110
  with gr.Row():
111
  samples_btn = gr.Button("Try Sample Images")
112
  history_btn = gr.Button("View Patient History")
113
-
114
  with gr.Column(visible=False) as history_page:
115
  gr.Markdown("# 📜 Patient Record History", elem_classes="app_title")
116
  with gr.Row():
@@ -121,32 +137,45 @@ with gr.Blocks(
121
  with gr.Column(visible=False) as samples_page:
122
  gr.Markdown("# 🖼️ Sample Image Library", elem_classes="app_title")
123
  gr.Markdown("Select up to 3 images by clicking on them, then click 'Analyze'.")
124
- sample_gallery = gr.Gallery(
125
- value=SAMPLE_IMAGES if SAMPLE_IMAGES else [],
126
- label="Sample Images",
127
- columns=5, height=400,
128
- elem_id="sample_gallery"
129
- )
130
- selected_samples_textbox = gr.Textbox(visible=False, elem_id="selected_samples_textbox")
131
  with gr.Row():
132
  analyze_samples_btn = gr.Button("Analyze Selected Samples", variant="primary")
133
  back_to_main_btn_samp = gr.Button("⬅️ Back to Main App")
134
 
135
  # --- Event Handling Logic ---
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  select_js = """
138
  (evt) => {
139
  const gallery = document.querySelector('#sample_gallery .grid-container');
140
- const clicked_img = gallery.children[evt.index];
141
  const selected_paths_input = document.querySelector('#selected_samples_textbox textarea');
142
- let selected_paths = selected_paths_input.value ? selected_paths_input.value.split(',') : [];
143
- const current_path = clicked_img.querySelector('img').alt;
144
- if (clicked_img.classList.contains('selected')) {
145
- clicked_img.classList.remove('selected');
 
146
  selected_paths = selected_paths.filter(p => p !== current_path);
147
  } else {
148
  if (selected_paths.length < 3) {
149
- clicked_img.classList.add('selected');
150
  selected_paths.push(current_path);
151
  } else {
152
  alert("You can select a maximum of 3 images.");
@@ -155,37 +184,37 @@ with gr.Blocks(
155
  return selected_paths.join(',');
156
  }
157
  """
158
- sample_gallery.select(fn=None, js=select_js, outputs=[selected_samples_textbox])
159
-
160
- def show_patient_info(files): return gr.update(visible=True) if files else gr.update(visible=False)
161
- image_input.upload(fn=show_patient_info, inputs=image_input, outputs=patient_info_modal)
162
-
163
- async def submit_and_hide_modal(name, age, files):
164
- analysis_results = await process_analysis(name, age, files)
165
- return [*analysis_results, gr.update(visible=False)]
166
- submit_analysis_btn.click(fn=submit_and_hide_modal, inputs=[patient_name_modal, patient_age_modal, image_input], outputs=[uploader_column, results_column, result_images, result_label, patient_info_modal])
167
-
168
- cancel_btn.click(lambda: (gr.update(visible=False), None), None, [patient_info_modal, image_input])
169
- start_over_btn.click(fn=None, js="() => { window.location.reload(); }")
170
 
171
  async def handle_sample_analysis(selected_paths_str: str):
172
  selected_images = [path for path in selected_paths_str.split(',') if path]
173
- if not selected_images: raise gr.Error("Please select at least one sample image.")
174
- if len(selected_images) > 3: raise gr.Error("Please select no more than 3 sample images.")
175
 
176
  analysis_results = await process_analysis("Sample User", 0, selected_images, is_sample=True)
177
- return [gr.update(visible=True), gr.update(visible=False), *analysis_results]
 
 
 
 
178
  analyze_samples_btn.click(fn=handle_sample_analysis, inputs=[selected_samples_textbox], outputs=[main_app, samples_page, uploader_column, results_column, result_images, result_label])
179
-
 
180
  all_pages = [main_app, history_page, samples_page]
 
181
  async def show_history_page_and_refresh():
182
  records_update = await refresh_history_table()
183
  return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), records_update]
184
- def show_samples_page(): return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
185
- def show_main_page(): return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)]
186
 
 
 
 
 
 
 
 
187
  history_btn.click(fn=show_history_page_and_refresh, outputs=all_pages + [history_df])
188
- samples_btn.click(fn=show_samples_page, outputs=all_pages)
189
  back_to_main_btn_hist.click(fn=show_main_page, outputs=all_pages)
190
  back_to_main_btn_samp.click(fn=show_main_page, outputs=all_pages)
191
 
 
1
+ # app.py (Final, Definitive, and Working Version)
2
 
3
  import gradio as gr
4
  from pathlib import Path
5
  import asyncio
 
6
 
7
+ # Import backend components from the 'app' folder
8
  from app.prediction import PredictionPipeline
9
  from app.database import add_patient_record, get_all_records
10
 
11
  # --- Initialization ---
12
  prediction_pipeline = PredictionPipeline()
13
+
14
+ # --- Point to the locally cloned sample images directory from setup.sh ---
15
  SAMPLE_IMAGE_DIR = Path("sample_images")
16
  try:
17
+ # Ensure the directory exists and has images before creating the list
18
+ if SAMPLE_IMAGE_DIR.is_dir():
19
+ SAMPLE_IMAGES = [str(p) for p in sorted(list(SAMPLE_IMAGE_DIR.glob('*/*.jpeg')))]
20
+ if not SAMPLE_IMAGES:
21
+ print("Warning: 'sample_images' directory found, but it's empty.")
22
+ else:
23
  raise FileNotFoundError
24
  except FileNotFoundError:
25
+ print("Warning: 'sample_images' directory not found. Please check setup.sh. Samples will be unavailable.")
26
  SAMPLE_IMAGES = []
27
 
28
+ # --- Core Logic (Async Functions) ---
29
  async def process_analysis(patient_name, patient_age, image_list, is_sample=False):
30
+ """
31
+ Handles the core logic: validates input, gets prediction, saves to DB, and returns UI updates.
32
+ """
33
  if not is_sample and (not patient_name or patient_age is None):
34
  raise gr.Error("Patient Name and Age are required.")
35
  if not image_list:
 
37
 
38
  result = prediction_pipeline.predict(image_list)
39
  if "error" in result:
40
+ raise gr.Error(result.get("details", result["error"]))
41
 
42
  final_pred = result["final_prediction"]
43
  final_conf = result["final_confidence"]
 
57
  ]
58
 
59
  async def refresh_history_table():
60
+ """Fetches records from the DB and formats them for the DataFrame."""
61
  records = await get_all_records()
62
+ data_for_df = []
63
+ if records:
64
+ data_for_df = [[r.get('name'), r.get('age'), r.get('prediction_result'), f"{r.get('confidence_score', 0):.2%}", r.get('timestamp').strftime('%Y-%m-%d %H:%M')] for r in records]
65
+ return gr.update(value=data_for_df)
66
 
67
  # --- Gradio UI Definition ---
68
  css = """
 
77
  #results_gallery .gallery-item { padding: 0.25rem !important; background-color: #374151; border: 1px solid #374151 !important; }
78
  #bottom_controls { max-width: 600px; margin: 2.5rem auto 1rem auto; }
79
  #bottom_controls .gr-accordion > .gr-block-label { text-align: center !important; display: block !important; }
80
+ /* --- Sample Gallery Selection Styling --- */
81
+ #sample_gallery .gallery-item { box-shadow: 0 0 5px rgba(0,0,0,0.5); border-radius: 8px !important; border: 4px solid transparent; transition: border-color 0.3s ease; }
82
+ #sample_gallery .gallery-item.selected { border-color: var(--primary-500) !important; }
83
  """
84
+ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue"), css=css, title="Pneumonia Detection AI") as demo:
 
 
 
 
 
 
85
 
86
  with gr.Column() as main_app:
87
  with gr.Column(elem_id="app_header"):
88
  gr.Markdown("# 🩺 Pneumonia Detection AI", elem_id="app_title")
89
  gr.Markdown("An AI-powered tool to assist in the diagnosis of pneumonia.", elem_id="app_subtitle")
90
+
91
  with gr.Row(elem_id="main_container"):
92
  with gr.Column(scale=1) as uploader_column:
93
  gr.Markdown("### Upload Patient X-Rays")
94
  image_input = gr.File(label="Upload up to 3 Images", file_count="multiple", file_types=["image"], type="filepath")
95
+
96
  with gr.Column(scale=2, visible=False) as results_column:
97
  gr.Markdown("### Analysis Results")
98
  result_images = gr.Gallery(label="Analyzed Images", columns=3, object_fit="contain", height=350, elem_id="results_gallery")
99
  result_label = gr.Label(label="Overall Prediction", num_top_classes=2)
100
  start_over_btn = gr.Button("Start New Analysis", variant="secondary")
101
+
102
  with gr.Group(visible=False) as patient_info_modal:
103
  gr.Markdown("## Enter Patient Details", elem_classes="text-center")
104
  patient_name_modal = gr.Textbox(label="Patient Name", placeholder="e.g., John Doe")
 
106
  with gr.Row():
107
  submit_analysis_btn = gr.Button("Analyze Images", variant="primary")
108
  cancel_btn = gr.Button("Cancel", variant="stop")
109
+
110
  with gr.Column(elem_id="bottom_controls"):
111
  with gr.Accordion("About this Tool", open=False):
112
  gr.Markdown(
113
  """
114
  ### MLOps-Powered Pneumonia Detection
115
+ This application demonstrates a complete, end-to-end MLOps pipeline for medical image classification. It leverages a state-of-the-art **Vision Transformer (ViT)** model, fine-tuned on a public dataset of chest X-ray images to distinguish between Normal and Pneumonia cases.
116
+
117
+ **Disclaimer:** This tool is for demonstration and educational purposes only and is **not a substitute for professional medical advice.**
118
+
119
  ---
120
+
121
+ **Project Team:**
122
+ * **Alyyan Ahmed** - Lead ML Engineer & Developer
123
+ * **Munim Akbar** - Project Contributor & Reviewer
124
  """
125
  )
126
  with gr.Row():
127
  samples_btn = gr.Button("Try Sample Images")
128
  history_btn = gr.Button("View Patient History")
129
+
130
  with gr.Column(visible=False) as history_page:
131
  gr.Markdown("# 📜 Patient Record History", elem_classes="app_title")
132
  with gr.Row():
 
137
  with gr.Column(visible=False) as samples_page:
138
  gr.Markdown("# 🖼️ Sample Image Library", elem_classes="app_title")
139
  gr.Markdown("Select up to 3 images by clicking on them, then click 'Analyze'.")
140
+
141
+ sample_gallery = gr.Gallery(value=SAMPLE_IMAGES, label="Sample Images", columns=5, height=400, elem_id="sample_gallery")
142
+
143
+ # This hidden textbox will store the list of selected file paths
144
+ selected_samples_textbox = gr.Textbox(label="selected", visible=False, elem_id="selected_samples_textbox")
145
+
 
146
  with gr.Row():
147
  analyze_samples_btn = gr.Button("Analyze Selected Samples", variant="primary")
148
  back_to_main_btn_samp = gr.Button("⬅️ Back to Main App")
149
 
150
  # --- Event Handling Logic ---
151
 
152
+ def show_patient_info(files):
153
+ return gr.update(visible=True) if files else gr.update(visible=False)
154
+ image_input.upload(fn=show_patient_info, inputs=image_input, outputs=patient_info_modal)
155
+
156
+ async def submit_and_hide_modal(name, age, files):
157
+ analysis_results = await process_analysis(name, age, files)
158
+ return [*analysis_results, gr.update(visible=False)]
159
+ submit_analysis_btn.click(fn=submit_and_hide_modal, inputs=[patient_name_modal, patient_age_modal, image_input], outputs=[uploader_column, results_column, result_images, result_label, patient_info_modal])
160
+
161
+ cancel_btn.click(lambda: (gr.update(visible=False), None), None, [patient_info_modal, image_input])
162
+ start_over_btn.click(fn=None, js="() => { window.location.reload(); }")
163
+
164
+ # --- Sample Page Logic with JavaScript ---
165
  select_js = """
166
  (evt) => {
167
  const gallery = document.querySelector('#sample_gallery .grid-container');
168
+ const clicked_img_container = gallery.children[evt.index];
169
  const selected_paths_input = document.querySelector('#selected_samples_textbox textarea');
170
+ let selected_paths = selected_paths_input.value ? selected_paths_input.value.split(',').filter(p => p.trim()) : [];
171
+ const current_path = clicked_img_container.querySelector('img').alt;
172
+
173
+ if (clicked_img_container.classList.contains('selected')) {
174
+ clicked_img_container.classList.remove('selected');
175
  selected_paths = selected_paths.filter(p => p !== current_path);
176
  } else {
177
  if (selected_paths.length < 3) {
178
+ clicked_img_container.classList.add('selected');
179
  selected_paths.push(current_path);
180
  } else {
181
  alert("You can select a maximum of 3 images.");
 
184
  return selected_paths.join(',');
185
  }
186
  """
187
+ sample_gallery.select(fn=None, _js=select_js, outputs=[selected_samples_textbox])
 
 
 
 
 
 
 
 
 
 
 
188
 
189
  async def handle_sample_analysis(selected_paths_str: str):
190
  selected_images = [path for path in selected_paths_str.split(',') if path]
191
+ if not selected_images:
192
+ raise gr.Error("Please select at least one sample image to analyze.")
193
 
194
  analysis_results = await process_analysis("Sample User", 0, selected_images, is_sample=True)
195
+ return [
196
+ gr.update(visible=True), # main_app
197
+ gr.update(visible=False), # samples_page
198
+ *analysis_results
199
+ ]
200
  analyze_samples_btn.click(fn=handle_sample_analysis, inputs=[selected_samples_textbox], outputs=[main_app, samples_page, uploader_column, results_column, result_images, result_label])
201
+
202
+ # --- Page Navigation ---
203
  all_pages = [main_app, history_page, samples_page]
204
+
205
  async def show_history_page_and_refresh():
206
  records_update = await refresh_history_table()
207
  return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), records_update]
 
 
208
 
209
+ def show_samples_page():
210
+ # Also clear selections when navigating to the samples page
211
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value="")]
212
+
213
+ def show_main_page():
214
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)]
215
+
216
  history_btn.click(fn=show_history_page_and_refresh, outputs=all_pages + [history_df])
217
+ samples_btn.click(fn=show_samples_page, outputs=all_pages + [selected_samples_textbox])
218
  back_to_main_btn_hist.click(fn=show_main_page, outputs=all_pages)
219
  back_to_main_btn_samp.click(fn=show_main_page, outputs=all_pages)
220