ALYYAN commited on
Commit
1063b82
·
unverified ·
1 Parent(s): 9ae30e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -51
app.py CHANGED
@@ -1,18 +1,16 @@
1
- # app.py (Final Version - No Downloads, Modern JS)
2
 
3
  import gradio as gr
4
  from pathlib import Path
5
  import asyncio
 
6
 
 
7
  from app.prediction import PredictionPipeline
8
  from app.database import add_patient_record, get_all_records
9
 
10
  # --- Initialization ---
11
  prediction_pipeline = PredictionPipeline()
12
-
13
- # --- FIX 1: Remove Hugging Face Hub download logic ---
14
- # The setup.sh script already clones the 'sample_images' directory.
15
- # We just need to point to it.
16
  SAMPLE_IMAGE_DIR = Path("sample_images")
17
  try:
18
  SAMPLE_IMAGES = [str(p) for p in sorted(list(SAMPLE_IMAGE_DIR.glob('*/*.jpeg')))]
@@ -22,20 +20,35 @@ except FileNotFoundError:
22
  print("Warning: 'sample_images' directory not found or is empty. Please check setup.sh.")
23
  SAMPLE_IMAGES = []
24
 
25
- # --- Core Logic Functions (Unchanged and Correct) ---
26
  async def process_analysis(patient_name, patient_age, image_list, is_sample=False):
27
- # ... (code is the same)
28
- if not is_sample and (not patient_name or patient_age is None): raise gr.Error("Patient Name and Age are required.")
29
- if not image_list: raise gr.Error("At least one image is required.")
 
 
30
  result = prediction_pipeline.predict(image_list)
31
- if "error" in result: raise gr.Error(result["error"])
32
- final_pred, final_conf = result["final_prediction"], result["final_confidence"]
33
- if not is_sample: await add_patient_record(str(patient_name), int(patient_age), final_pred, final_conf)
34
- confidences = {"NORMAL": 0.0, "PNEUMONIA": 0.0}; confidences[final_pred] = final_conf; confidences["NORMAL" if final_pred == "PNEUMONIA" else "PNEUMONIA"] = 1 - final_conf
35
- return [gr.update(visible=False), gr.update(visible=True), gr.update(value=result["watermarked_images"]), gr.update(value=confidences)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  async def refresh_history_table():
38
- # ... (code is the same)
39
  records = await get_all_records()
40
  data = [[r.get('name'), r.get('age'), r.get('prediction_result'), f"{r.get('confidence_score', 0):.2%}", r.get('timestamp').strftime('%Y-%m-%d %H:%M')] for r in records] if records else []
41
  return gr.update(value=data)
@@ -53,50 +66,81 @@ css = """
53
  #results_gallery .gallery-item { padding: 0.25rem !important; background-color: #374151; border: 1px solid #374151 !important; }
54
  #bottom_controls { max-width: 600px; margin: 2.5rem auto 1rem auto; }
55
  #bottom_controls .gr-accordion > .gr-block-label { text-align: center !important; display: block !important; }
56
- /* --- FIX: Style the sample gallery for a cleaner look --- */
57
- #sample_gallery { background-color: transparent !important; border: none !important; }
58
  #sample_gallery .gallery-item { box-shadow: 0 0 5px rgba(0,0,0,0.5); border-radius: 8px !important; }
 
59
  """
60
- with gr.Blocks(theme=gr.themes.Default(primary_hue="blue"), secondary_hue="blue"), css=css, title="Pneumonia Detection AI") as demo:
 
 
 
 
 
 
61
 
62
- # ... (UI Layout is the same)
63
  with gr.Column() as main_app:
64
- # ...
 
 
65
  with gr.Row(elem_id="main_container"):
66
  with gr.Column(scale=1) as uploader_column:
67
- gr.Markdown("### Upload Patient X-Rays"); image_input = gr.File(label="Upload up to 3 Images", file_count="multiple", file_types=["image"], type="filepath")
 
68
  with gr.Column(scale=2, visible=False) as results_column:
69
- gr.Markdown("### Analysis Results"); result_images = gr.Gallery(label="Analyzed Images", columns=3, object_fit="contain", height=350, elem_id="results_gallery"); result_label = gr.Label(label="Overall Prediction", num_top_classes=2); start_over_btn = gr.Button("Start New Analysis", variant="secondary")
 
 
 
70
  with gr.Group(visible=False) as patient_info_modal:
71
- gr.Markdown("## Enter Patient Details", elem_classes="text-center"); patient_name_modal = gr.Textbox(label="Patient Name", placeholder="e.g., John Doe"); patient_age_modal = gr.Number(label="Patient Age", minimum=0, maximum=120, step=1)
72
- with gr.Row(): submit_analysis_btn = gr.Button("Analyze Images", variant="primary"); cancel_btn = gr.Button("Cancel", variant="stop")
 
 
 
 
73
  with gr.Column(elem_id="bottom_controls"):
74
- with gr.Accordion("About this Tool", open=False): gr.Markdown("...")
75
- with gr.Row(): samples_btn = gr.Button("Try Sample Images"); history_btn = gr.Button("View Patient History")
 
 
 
 
 
 
 
 
 
 
 
76
  with gr.Column(visible=False) as history_page:
77
  gr.Markdown("# 📜 Patient Record History", elem_classes="app_title")
78
- with gr.Row(): back_to_main_btn_hist = gr.Button("⬅️ Back to Main App"); refresh_history_btn = gr.Button("Refresh History")
 
 
79
  history_df = gr.DataFrame(headers=["Name", "Age", "Prediction", "Confidence", "Date"], row_count=10, interactive=False)
 
80
  with gr.Column(visible=False) as samples_page:
81
  gr.Markdown("# 🖼️ Sample Image Library", elem_classes="app_title")
82
  gr.Markdown("Select up to 3 images by clicking on them, then click 'Analyze'.")
83
- sample_gallery = gr.Gallery(value=SAMPLE_IMAGES, label="Sample Images", columns=5, height=400, elem_id="sample_gallery")
 
 
 
 
 
84
  selected_samples_textbox = gr.Textbox(visible=False, elem_id="selected_samples_textbox")
85
- with gr.Row(): analyze_samples_btn = gr.Button("Analyze Selected Samples", variant="primary"); back_to_main_btn_samp = gr.Button("⬅️ Back to Main App")
 
 
86
 
87
  # --- Event Handling Logic ---
88
-
89
- # --- FIX 2: Use the modern gr.js() function for custom JavaScript ---
90
  select_js = """
91
  (evt) => {
92
- // This JS code runs in the browser when a sample image is clicked.
93
- // It's the same logic as before.
94
  const gallery = document.querySelector('#sample_gallery .grid-container');
95
  const clicked_img = gallery.children[evt.index];
96
  const selected_paths_input = document.querySelector('#selected_samples_textbox textarea');
97
  let selected_paths = selected_paths_input.value ? selected_paths_input.value.split(',') : [];
98
  const current_path = clicked_img.querySelector('img').alt;
99
-
100
  if (clicked_img.classList.contains('selected')) {
101
  clicked_img.classList.remove('selected');
102
  selected_paths = selected_paths.filter(p => p !== current_path);
@@ -108,51 +152,46 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue"), secondary_hue="blue"
108
  alert("You can select a maximum of 3 images.");
109
  }
110
  }
111
-
112
- // The return value of a gr.js function is passed to the next .then()
113
  return selected_paths.join(',');
114
  }
115
  """
 
116
 
117
- # Add the CSS for the selection border
118
- demo.css += "#sample_gallery .gallery-item.selected { border: 4px solid var(--primary-500) !important; }"
119
-
120
- # The modern way to link JS to an event:
121
- sample_gallery.select(
122
- fn=None, # No Python function runs on click
123
- js=select_js, # The JS function to run
124
- outputs=[selected_samples_textbox] # The JS function's return value updates this component
125
- )
126
-
127
- # ... (the rest of the event handlers are correct)
128
  def show_patient_info(files): return gr.update(visible=True) if files else gr.update(visible=False)
129
  image_input.upload(fn=show_patient_info, inputs=image_input, outputs=patient_info_modal)
 
130
  async def submit_and_hide_modal(name, age, files):
131
- analysis_results = await process_analysis(name, age, files); return [*analysis_results, gr.update(visible=False)]
 
132
  submit_analysis_btn.click(fn=submit_and_hide_modal, inputs=[patient_name_modal, patient_age_modal, image_input], outputs=[uploader_column, results_column, result_images, result_label, patient_info_modal])
 
133
  cancel_btn.click(lambda: (gr.update(visible=False), None), None, [patient_info_modal, image_input])
134
  start_over_btn.click(fn=None, js="() => { window.location.reload(); }")
135
 
136
  async def handle_sample_analysis(selected_paths_str: str):
137
- selected_images = selected_paths_str.split(',') if selected_paths_str.strip() else []
138
  if not selected_images: raise gr.Error("Please select at least one sample image.")
139
  if len(selected_images) > 3: raise gr.Error("Please select no more than 3 sample images.")
 
140
  analysis_results = await process_analysis("Sample User", 0, selected_images, is_sample=True)
141
  return [gr.update(visible=True), gr.update(visible=False), *analysis_results]
142
  analyze_samples_btn.click(fn=handle_sample_analysis, inputs=[selected_samples_textbox], outputs=[main_app, samples_page, uploader_column, results_column, result_images, result_label])
143
 
144
  all_pages = [main_app, history_page, samples_page]
145
- async def show_history_page_and_refresh(): records_update = await refresh_history_table(); return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), records_update]
 
 
146
  def show_samples_page(): return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
147
  def show_main_page(): return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)]
 
148
  history_btn.click(fn=show_history_page_and_refresh, outputs=all_pages + [history_df])
149
  samples_btn.click(fn=show_samples_page, outputs=all_pages)
150
  back_to_main_btn_hist.click(fn=show_main_page, outputs=all_pages)
151
  back_to_main_btn_samp.click(fn=show_main_page, outputs=all_pages)
 
152
  refresh_history_btn.click(fn=refresh_history_table, outputs=history_df)
153
  demo.load(fn=refresh_history_table, outputs=history_df)
154
 
155
-
156
  # --- Launch the App ---
157
  if __name__ == "__main__":
158
  demo.launch()
 
1
+ # app.py (Final Version with Syntax Fix)
2
 
3
  import gradio as gr
4
  from pathlib import Path
5
  import asyncio
6
+ from PIL import Image
7
 
8
+ # Import backend components
9
  from app.prediction import PredictionPipeline
10
  from app.database import add_patient_record, get_all_records
11
 
12
  # --- Initialization ---
13
  prediction_pipeline = PredictionPipeline()
 
 
 
 
14
  SAMPLE_IMAGE_DIR = Path("sample_images")
15
  try:
16
  SAMPLE_IMAGES = [str(p) for p in sorted(list(SAMPLE_IMAGE_DIR.glob('*/*.jpeg')))]
 
20
  print("Warning: 'sample_images' directory not found or is empty. Please check setup.sh.")
21
  SAMPLE_IMAGES = []
22
 
23
+ # --- Core Logic Functions ---
24
  async def process_analysis(patient_name, patient_age, image_list, is_sample=False):
25
+ if not is_sample and (not patient_name or patient_age is None):
26
+ raise gr.Error("Patient Name and Age are required.")
27
+ if not image_list:
28
+ raise gr.Error("At least one image is required.")
29
+
30
  result = prediction_pipeline.predict(image_list)
31
+ if "error" in result:
32
+ raise gr.Error(result["error"])
33
+
34
+ final_pred = result["final_prediction"]
35
+ final_conf = result["final_confidence"]
36
+
37
+ if not is_sample:
38
+ await add_patient_record(str(patient_name), int(patient_age), final_pred, final_conf)
39
+
40
+ confidences = {"NORMAL": 0.0, "PNEUMONIA": 0.0}
41
+ confidences[final_pred] = final_conf
42
+ confidences["NORMAL" if final_pred == "PNEUMONIA" else "PNEUMONIA"] = 1 - final_conf
43
+
44
+ return [
45
+ gr.update(visible=False), # uploader_column
46
+ gr.update(visible=True), # results_column
47
+ gr.update(value=result["watermarked_images"]), # result_images
48
+ gr.update(value=confidences) # result_label
49
+ ]
50
 
51
  async def refresh_history_table():
 
52
  records = await get_all_records()
53
  data = [[r.get('name'), r.get('age'), r.get('prediction_result'), f"{r.get('confidence_score', 0):.2%}", r.get('timestamp').strftime('%Y-%m-%d %H:%M')] for r in records] if records else []
54
  return gr.update(value=data)
 
66
  #results_gallery .gallery-item { padding: 0.25rem !important; background-color: #374151; border: 1px solid #374151 !important; }
67
  #bottom_controls { max-width: 600px; margin: 2.5rem auto 1rem auto; }
68
  #bottom_controls .gr-accordion > .gr-block-label { text-align: center !important; display: block !important; }
 
 
69
  #sample_gallery .gallery-item { box-shadow: 0 0 5px rgba(0,0,0,0.5); border-radius: 8px !important; }
70
+ #sample_gallery .gallery-item.selected { border: 4px solid var(--primary-500) !important; }
71
  """
72
+
73
+ # --- THIS IS THE FIX ---
74
+ with gr.Blocks(
75
+ theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue"),
76
+ css=css,
77
+ title="Pneumonia Detection AI"
78
+ ) as demo:
79
 
 
80
  with gr.Column() as main_app:
81
+ with gr.Column(elem_id="app_header"):
82
+ gr.Markdown("# 🩺 Pneumonia Detection AI", elem_id="app_title")
83
+ gr.Markdown("An AI-powered tool to assist in the diagnosis of pneumonia.", elem_id="app_subtitle")
84
  with gr.Row(elem_id="main_container"):
85
  with gr.Column(scale=1) as uploader_column:
86
+ gr.Markdown("### Upload Patient X-Rays")
87
+ image_input = gr.File(label="Upload up to 3 Images", file_count="multiple", file_types=["image"], type="filepath")
88
  with gr.Column(scale=2, visible=False) as results_column:
89
+ gr.Markdown("### Analysis Results")
90
+ result_images = gr.Gallery(label="Analyzed Images", columns=3, object_fit="contain", height=350, elem_id="results_gallery")
91
+ result_label = gr.Label(label="Overall Prediction", num_top_classes=2)
92
+ start_over_btn = gr.Button("Start New Analysis", variant="secondary")
93
  with gr.Group(visible=False) as patient_info_modal:
94
+ gr.Markdown("## Enter Patient Details", elem_classes="text-center")
95
+ patient_name_modal = gr.Textbox(label="Patient Name", placeholder="e.g., John Doe")
96
+ patient_age_modal = gr.Number(label="Patient Age", minimum=0, maximum=120, step=1)
97
+ with gr.Row():
98
+ submit_analysis_btn = gr.Button("Analyze Images", variant="primary")
99
+ cancel_btn = gr.Button("Cancel", variant="stop")
100
  with gr.Column(elem_id="bottom_controls"):
101
+ with gr.Accordion("About this Tool", open=False):
102
+ gr.Markdown(
103
+ """
104
+ ### MLOps-Powered Pneumonia Detection
105
+ (Your professional description here)
106
+ ---
107
+ **Project Team:** Alyyan Ahmed & Munim Akbar
108
+ """
109
+ )
110
+ with gr.Row():
111
+ samples_btn = gr.Button("Try Sample Images")
112
+ history_btn = gr.Button("View Patient History")
113
+
114
  with gr.Column(visible=False) as history_page:
115
  gr.Markdown("# 📜 Patient Record History", elem_classes="app_title")
116
+ with gr.Row():
117
+ back_to_main_btn_hist = gr.Button("⬅️ Back to Main App")
118
+ refresh_history_btn = gr.Button("Refresh History")
119
  history_df = gr.DataFrame(headers=["Name", "Age", "Prediction", "Confidence", "Date"], row_count=10, interactive=False)
120
+
121
  with gr.Column(visible=False) as samples_page:
122
  gr.Markdown("# 🖼️ Sample Image Library", elem_classes="app_title")
123
  gr.Markdown("Select up to 3 images by clicking on them, then click 'Analyze'.")
124
+ sample_gallery = gr.Gallery(
125
+ value=SAMPLE_IMAGES if SAMPLE_IMAGES else [],
126
+ label="Sample Images",
127
+ columns=5, height=400,
128
+ elem_id="sample_gallery"
129
+ )
130
  selected_samples_textbox = gr.Textbox(visible=False, elem_id="selected_samples_textbox")
131
+ with gr.Row():
132
+ analyze_samples_btn = gr.Button("Analyze Selected Samples", variant="primary")
133
+ back_to_main_btn_samp = gr.Button("⬅️ Back to Main App")
134
 
135
  # --- Event Handling Logic ---
136
+
 
137
  select_js = """
138
  (evt) => {
 
 
139
  const gallery = document.querySelector('#sample_gallery .grid-container');
140
  const clicked_img = gallery.children[evt.index];
141
  const selected_paths_input = document.querySelector('#selected_samples_textbox textarea');
142
  let selected_paths = selected_paths_input.value ? selected_paths_input.value.split(',') : [];
143
  const current_path = clicked_img.querySelector('img').alt;
 
144
  if (clicked_img.classList.contains('selected')) {
145
  clicked_img.classList.remove('selected');
146
  selected_paths = selected_paths.filter(p => p !== current_path);
 
152
  alert("You can select a maximum of 3 images.");
153
  }
154
  }
 
 
155
  return selected_paths.join(',');
156
  }
157
  """
158
+ sample_gallery.select(fn=None, js=select_js, outputs=[selected_samples_textbox])
159
 
 
 
 
 
 
 
 
 
 
 
 
160
  def show_patient_info(files): return gr.update(visible=True) if files else gr.update(visible=False)
161
  image_input.upload(fn=show_patient_info, inputs=image_input, outputs=patient_info_modal)
162
+
163
  async def submit_and_hide_modal(name, age, files):
164
+ analysis_results = await process_analysis(name, age, files)
165
+ return [*analysis_results, gr.update(visible=False)]
166
  submit_analysis_btn.click(fn=submit_and_hide_modal, inputs=[patient_name_modal, patient_age_modal, image_input], outputs=[uploader_column, results_column, result_images, result_label, patient_info_modal])
167
+
168
  cancel_btn.click(lambda: (gr.update(visible=False), None), None, [patient_info_modal, image_input])
169
  start_over_btn.click(fn=None, js="() => { window.location.reload(); }")
170
 
171
  async def handle_sample_analysis(selected_paths_str: str):
172
+ selected_images = [path for path in selected_paths_str.split(',') if path]
173
  if not selected_images: raise gr.Error("Please select at least one sample image.")
174
  if len(selected_images) > 3: raise gr.Error("Please select no more than 3 sample images.")
175
+
176
  analysis_results = await process_analysis("Sample User", 0, selected_images, is_sample=True)
177
  return [gr.update(visible=True), gr.update(visible=False), *analysis_results]
178
  analyze_samples_btn.click(fn=handle_sample_analysis, inputs=[selected_samples_textbox], outputs=[main_app, samples_page, uploader_column, results_column, result_images, result_label])
179
 
180
  all_pages = [main_app, history_page, samples_page]
181
+ async def show_history_page_and_refresh():
182
+ records_update = await refresh_history_table()
183
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), records_update]
184
  def show_samples_page(): return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
185
  def show_main_page(): return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)]
186
+
187
  history_btn.click(fn=show_history_page_and_refresh, outputs=all_pages + [history_df])
188
  samples_btn.click(fn=show_samples_page, outputs=all_pages)
189
  back_to_main_btn_hist.click(fn=show_main_page, outputs=all_pages)
190
  back_to_main_btn_samp.click(fn=show_main_page, outputs=all_pages)
191
+
192
  refresh_history_btn.click(fn=refresh_history_table, outputs=history_df)
193
  demo.load(fn=refresh_history_table, outputs=history_df)
194
 
 
195
  # --- Launch the App ---
196
  if __name__ == "__main__":
197
  demo.launch()