ALYYAN commited on
Commit
9ae30e6
·
unverified ·
1 Parent(s): dda6312

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -89
app.py CHANGED
@@ -1,8 +1,7 @@
1
- # app.py (Final Version with Working Sample Gallery)
2
 
3
  import gradio as gr
4
  from pathlib import Path
5
- from huggingface_hub import snapshot_download
6
  import asyncio
7
 
8
  from app.prediction import PredictionPipeline
@@ -10,17 +9,22 @@ from app.database import add_patient_record, get_all_records
10
 
11
  # --- Initialization ---
12
  prediction_pipeline = PredictionPipeline()
13
- HF_DATASET_REPO = "ALYYAN/chest-xray-pneumonia-samples"
 
 
 
 
14
  try:
15
- SAMPLE_IMAGE_DIR = Path(snapshot_download(repo_id=HF_DATASET_REPO, repo_type="dataset"))
16
  SAMPLE_IMAGES = [str(p) for p in sorted(list(SAMPLE_IMAGE_DIR.glob('*/*.jpeg')))]
17
- except Exception as e:
18
- print(f"Could not download sample images: {e}")
 
 
19
  SAMPLE_IMAGES = []
20
 
21
  # --- Core Logic Functions (Unchanged and Correct) ---
22
- # ... (process_analysis and refresh_history_table are the same as the last working version)
23
  async def process_analysis(patient_name, patient_age, image_list, is_sample=False):
 
24
  if not is_sample and (not patient_name or patient_age is None): raise gr.Error("Patient Name and Age are required.")
25
  if not image_list: raise gr.Error("At least one image is required.")
26
  result = prediction_pipeline.predict(image_list)
@@ -29,7 +33,9 @@ async def process_analysis(patient_name, patient_age, image_list, is_sample=Fals
29
  if not is_sample: await add_patient_record(str(patient_name), int(patient_age), final_pred, final_conf)
30
  confidences = {"NORMAL": 0.0, "PNEUMONIA": 0.0}; confidences[final_pred] = final_conf; confidences["NORMAL" if final_pred == "PNEUMONIA" else "PNEUMONIA"] = 1 - final_conf
31
  return [gr.update(visible=False), gr.update(visible=True), gr.update(value=result["watermarked_images"]), gr.update(value=confidences)]
 
32
  async def refresh_history_table():
 
33
  records = await get_all_records()
34
  data = [[r.get('name'), r.get('age'), r.get('prediction_result'), f"{r.get('confidence_score', 0):.2%}", r.get('timestamp').strftime('%Y-%m-%d %H:%M')] for r in records] if records else []
35
  return gr.update(value=data)
@@ -51,83 +57,40 @@ css = """
51
  #sample_gallery { background-color: transparent !important; border: none !important; }
52
  #sample_gallery .gallery-item { box-shadow: 0 0 5px rgba(0,0,0,0.5); border-radius: 8px !important; }
53
  """
54
- with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue"), css=css, title="Pneumonia Detection AI") as demo:
55
 
56
- # --- State to track selected sample images ---
57
- selected_samples = gr.State([])
58
-
59
- # --- UI LAYOUT (Unchanged) ---
60
  with gr.Column() as main_app:
61
- # ... (Main page layout is the same)
62
- with gr.Column(elem_id="app_header"):
63
- gr.Markdown("# 🩺 Pneumonia Detection AI", elem_id="app_title")
64
- gr.Markdown("An AI-powered tool to assist in the diagnosis of pneumonia.", elem_id="app_subtitle")
65
  with gr.Row(elem_id="main_container"):
66
  with gr.Column(scale=1) as uploader_column:
67
- gr.Markdown("### Upload Patient X-Rays")
68
- image_input = gr.File(label="Upload up to 3 Images", file_count="multiple", file_types=["image"], type="filepath")
69
  with gr.Column(scale=2, visible=False) as results_column:
70
- gr.Markdown("### Analysis Results")
71
- result_images = gr.Gallery(label="Analyzed Images", columns=3, object_fit="contain", height=350, elem_id="results_gallery")
72
- result_label = gr.Label(label="Overall Prediction", num_top_classes=2)
73
- start_over_btn = gr.Button("Start New Analysis", variant="secondary")
74
  with gr.Group(visible=False) as patient_info_modal:
75
- gr.Markdown("## Enter Patient Details", elem_classes="text-center")
76
- patient_name_modal = gr.Textbox(label="Patient Name", placeholder="e.g., John Doe")
77
- patient_age_modal = gr.Number(label="Patient Age", minimum=0, maximum=120, step=1)
78
- with gr.Row():
79
- submit_analysis_btn = gr.Button("Analyze Images", variant="primary")
80
- cancel_btn = gr.Button("Cancel", variant="stop")
81
  with gr.Column(elem_id="bottom_controls"):
82
- with gr.Accordion("About this Tool", open=False):
83
- gr.Markdown("...") # (Your professional description here)
84
- with gr.Row():
85
- samples_btn = gr.Button("Try Sample Images")
86
- history_btn = gr.Button("View Patient History")
87
  with gr.Column(visible=False) as history_page:
88
  gr.Markdown("# 📜 Patient Record History", elem_classes="app_title")
89
- with gr.Row():
90
- back_to_main_btn_hist = gr.Button("⬅️ Back to Main App")
91
- refresh_history_btn = gr.Button("Refresh History")
92
  history_df = gr.DataFrame(headers=["Name", "Age", "Prediction", "Confidence", "Date"], row_count=10, interactive=False)
93
-
94
- # --- SAMPLES PAGE (THE DEFINITIVE FIX) ---
95
  with gr.Column(visible=False) as samples_page:
96
  gr.Markdown("# 🖼️ Sample Image Library", elem_classes="app_title")
97
  gr.Markdown("Select up to 3 images by clicking on them, then click 'Analyze'.")
98
-
99
- # This gallery will show the images
100
- sample_gallery = gr.Gallery(
101
- value=SAMPLE_IMAGES,
102
- label="Sample Images",
103
- columns=5, height=400,
104
- elem_id="sample_gallery"
105
- )
106
-
107
- # This hidden textbox will store the list of selected file paths
108
- selected_samples_textbox = gr.Textbox(visible=False)
109
-
110
- with gr.Row():
111
- analyze_samples_btn = gr.Button("Analyze Selected Samples", variant="primary")
112
- back_to_main_btn_samp = gr.Button("⬅️ Back to Main App")
113
-
114
  # --- Event Handling Logic ---
115
- # ... (handlers for main upload workflow are correct)
116
- def show_patient_info(files): return gr.update(visible=True) if files else gr.update(visible=False)
117
- image_input.upload(fn=show_patient_info, inputs=image_input, outputs=patient_info_modal)
118
- async def submit_and_hide_modal(name, age, files):
119
- analysis_results = await process_analysis(name, age, files); return [*analysis_results, gr.update(visible=False)]
120
- submit_analysis_btn.click(fn=submit_and_hide_modal, inputs=[patient_name_modal, patient_age_modal, image_input], outputs=[uploader_column, results_column, result_images, result_label, patient_info_modal])
121
- cancel_btn.click(lambda: (gr.update(visible=False), None), None, [patient_info_modal, image_input])
122
- start_over_btn.click(fn=None, js="() => { window.location.reload(); }")
123
 
124
- # --- SAMPLE PAGE LOGIC (THE FIX) ---
125
-
126
- # JavaScript to handle multi-select on the gallery
127
- # When an image is clicked, this JS will add/remove its path from the hidden textbox
128
- # and add/remove a 'selected' class for a visual border.
129
  select_js = """
130
  (evt) => {
 
 
131
  const gallery = document.querySelector('#sample_gallery .grid-container');
132
  const clicked_img = gallery.children[evt.index];
133
  const selected_paths_input = document.querySelector('#selected_samples_textbox textarea');
@@ -142,45 +105,42 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")
142
  clicked_img.classList.add('selected');
143
  selected_paths.push(current_path);
144
  } else {
145
- // This is a simple browser alert. Gradio's gr.Warning is better for the final check.
146
  alert("You can select a maximum of 3 images.");
147
  }
148
  }
149
 
150
- // Return the updated list of paths to the hidden textbox
151
  return selected_paths.join(',');
152
  }
153
  """
154
 
155
- # We need to add a little CSS for the selection border
156
  demo.css += "#sample_gallery .gallery-item.selected { border: 4px solid var(--primary-500) !important; }"
157
-
158
- # Hidden textbox to store the paths
159
- selected_samples_textbox = gr.Textbox(value="", visible=False, elem_id="selected_samples_textbox")
160
 
161
- sample_gallery.select(fn=None, _js=select_js, outputs=[selected_samples_textbox])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
  async def handle_sample_analysis(selected_paths_str: str):
164
- # The input is now a comma-separated string of paths from our hidden textbox
165
- selected_images = selected_paths_str.split(',') if selected_paths_str else []
166
-
167
  if not selected_images: raise gr.Error("Please select at least one sample image.")
168
  if len(selected_images) > 3: raise gr.Error("Please select no more than 3 sample images.")
169
-
170
  analysis_results = await process_analysis("Sample User", 0, selected_images, is_sample=True)
171
-
172
- return {
173
- main_app: gr.update(visible=True),
174
- samples_page: gr.update(visible=False),
175
- # Unpack dictionary updates for specific components
176
- uploader_column: analysis_results[0],
177
- results_column: analysis_results[1],
178
- result_images: analysis_results[2],
179
- result_label: analysis_results[3],
180
- }
181
  analyze_samples_btn.click(fn=handle_sample_analysis, inputs=[selected_samples_textbox], outputs=[main_app, samples_page, uploader_column, results_column, result_images, result_label])
182
 
183
- # ... (Page Navigation is correct)
184
  all_pages = [main_app, history_page, samples_page]
185
  async def show_history_page_and_refresh(): records_update = await refresh_history_table(); return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), records_update]
186
  def show_samples_page(): return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
@@ -192,6 +152,7 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")
192
  refresh_history_btn.click(fn=refresh_history_table, outputs=history_df)
193
  demo.load(fn=refresh_history_table, outputs=history_df)
194
 
 
195
  # --- Launch the App ---
196
  if __name__ == "__main__":
197
  demo.launch()
 
1
+ # app.py (Final Version - No Downloads, Modern JS)
2
 
3
  import gradio as gr
4
  from pathlib import Path
 
5
  import asyncio
6
 
7
  from app.prediction import PredictionPipeline
 
9
 
10
  # --- Initialization ---
11
  prediction_pipeline = PredictionPipeline()
12
+
13
+ # --- FIX 1: Remove Hugging Face Hub download logic ---
14
+ # The setup.sh script already clones the 'sample_images' directory.
15
+ # We just need to point to it.
16
+ SAMPLE_IMAGE_DIR = Path("sample_images")
17
  try:
 
18
  SAMPLE_IMAGES = [str(p) for p in sorted(list(SAMPLE_IMAGE_DIR.glob('*/*.jpeg')))]
19
+ if not SAMPLE_IMAGES:
20
+ raise FileNotFoundError
21
+ except FileNotFoundError:
22
+ print("Warning: 'sample_images' directory not found or is empty. Please check setup.sh.")
23
  SAMPLE_IMAGES = []
24
 
25
  # --- Core Logic Functions (Unchanged and Correct) ---
 
26
  async def process_analysis(patient_name, patient_age, image_list, is_sample=False):
27
+ # ... (code is the same)
28
  if not is_sample and (not patient_name or patient_age is None): raise gr.Error("Patient Name and Age are required.")
29
  if not image_list: raise gr.Error("At least one image is required.")
30
  result = prediction_pipeline.predict(image_list)
 
33
  if not is_sample: await add_patient_record(str(patient_name), int(patient_age), final_pred, final_conf)
34
  confidences = {"NORMAL": 0.0, "PNEUMONIA": 0.0}; confidences[final_pred] = final_conf; confidences["NORMAL" if final_pred == "PNEUMONIA" else "PNEUMONIA"] = 1 - final_conf
35
  return [gr.update(visible=False), gr.update(visible=True), gr.update(value=result["watermarked_images"]), gr.update(value=confidences)]
36
+
37
  async def refresh_history_table():
38
+ # ... (code is the same)
39
  records = await get_all_records()
40
  data = [[r.get('name'), r.get('age'), r.get('prediction_result'), f"{r.get('confidence_score', 0):.2%}", r.get('timestamp').strftime('%Y-%m-%d %H:%M')] for r in records] if records else []
41
  return gr.update(value=data)
 
57
  #sample_gallery { background-color: transparent !important; border: none !important; }
58
  #sample_gallery .gallery-item { box-shadow: 0 0 5px rgba(0,0,0,0.5); border-radius: 8px !important; }
59
  """
60
+ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue"), secondary_hue="blue"), css=css, title="Pneumonia Detection AI") as demo:
61
 
62
+ # ... (UI Layout is the same)
 
 
 
63
  with gr.Column() as main_app:
64
+ # ...
 
 
 
65
  with gr.Row(elem_id="main_container"):
66
  with gr.Column(scale=1) as uploader_column:
67
+ gr.Markdown("### Upload Patient X-Rays"); image_input = gr.File(label="Upload up to 3 Images", file_count="multiple", file_types=["image"], type="filepath")
 
68
  with gr.Column(scale=2, visible=False) as results_column:
69
+ gr.Markdown("### Analysis Results"); result_images = gr.Gallery(label="Analyzed Images", columns=3, object_fit="contain", height=350, elem_id="results_gallery"); result_label = gr.Label(label="Overall Prediction", num_top_classes=2); start_over_btn = gr.Button("Start New Analysis", variant="secondary")
 
 
 
70
  with gr.Group(visible=False) as patient_info_modal:
71
+ gr.Markdown("## Enter Patient Details", elem_classes="text-center"); patient_name_modal = gr.Textbox(label="Patient Name", placeholder="e.g., John Doe"); patient_age_modal = gr.Number(label="Patient Age", minimum=0, maximum=120, step=1)
72
+ with gr.Row(): submit_analysis_btn = gr.Button("Analyze Images", variant="primary"); cancel_btn = gr.Button("Cancel", variant="stop")
 
 
 
 
73
  with gr.Column(elem_id="bottom_controls"):
74
+ with gr.Accordion("About this Tool", open=False): gr.Markdown("...")
75
+ with gr.Row(): samples_btn = gr.Button("Try Sample Images"); history_btn = gr.Button("View Patient History")
 
 
 
76
  with gr.Column(visible=False) as history_page:
77
  gr.Markdown("# 📜 Patient Record History", elem_classes="app_title")
78
+ with gr.Row(): back_to_main_btn_hist = gr.Button("⬅️ Back to Main App"); refresh_history_btn = gr.Button("Refresh History")
 
 
79
  history_df = gr.DataFrame(headers=["Name", "Age", "Prediction", "Confidence", "Date"], row_count=10, interactive=False)
 
 
80
  with gr.Column(visible=False) as samples_page:
81
  gr.Markdown("# 🖼️ Sample Image Library", elem_classes="app_title")
82
  gr.Markdown("Select up to 3 images by clicking on them, then click 'Analyze'.")
83
+ sample_gallery = gr.Gallery(value=SAMPLE_IMAGES, label="Sample Images", columns=5, height=400, elem_id="sample_gallery")
84
+ selected_samples_textbox = gr.Textbox(visible=False, elem_id="selected_samples_textbox")
85
+ with gr.Row(): analyze_samples_btn = gr.Button("Analyze Selected Samples", variant="primary"); back_to_main_btn_samp = gr.Button("⬅️ Back to Main App")
86
+
 
 
 
 
 
 
 
 
 
 
 
 
87
  # --- Event Handling Logic ---
 
 
 
 
 
 
 
 
88
 
89
+ # --- FIX 2: Use the modern gr.js() function for custom JavaScript ---
 
 
 
 
90
  select_js = """
91
  (evt) => {
92
+ // This JS code runs in the browser when a sample image is clicked.
93
+ // It's the same logic as before.
94
  const gallery = document.querySelector('#sample_gallery .grid-container');
95
  const clicked_img = gallery.children[evt.index];
96
  const selected_paths_input = document.querySelector('#selected_samples_textbox textarea');
 
105
  clicked_img.classList.add('selected');
106
  selected_paths.push(current_path);
107
  } else {
 
108
  alert("You can select a maximum of 3 images.");
109
  }
110
  }
111
 
112
+ // The return value of a gr.js function is passed to the next .then()
113
  return selected_paths.join(',');
114
  }
115
  """
116
 
117
+ # Add the CSS for the selection border
118
  demo.css += "#sample_gallery .gallery-item.selected { border: 4px solid var(--primary-500) !important; }"
 
 
 
119
 
120
+ # The modern way to link JS to an event:
121
+ sample_gallery.select(
122
+ fn=None, # No Python function runs on click
123
+ js=select_js, # The JS function to run
124
+ outputs=[selected_samples_textbox] # The JS function's return value updates this component
125
+ )
126
+
127
+ # ... (the rest of the event handlers are correct)
128
+ def show_patient_info(files): return gr.update(visible=True) if files else gr.update(visible=False)
129
+ image_input.upload(fn=show_patient_info, inputs=image_input, outputs=patient_info_modal)
130
+ async def submit_and_hide_modal(name, age, files):
131
+ analysis_results = await process_analysis(name, age, files); return [*analysis_results, gr.update(visible=False)]
132
+ submit_analysis_btn.click(fn=submit_and_hide_modal, inputs=[patient_name_modal, patient_age_modal, image_input], outputs=[uploader_column, results_column, result_images, result_label, patient_info_modal])
133
+ cancel_btn.click(lambda: (gr.update(visible=False), None), None, [patient_info_modal, image_input])
134
+ start_over_btn.click(fn=None, js="() => { window.location.reload(); }")
135
 
136
  async def handle_sample_analysis(selected_paths_str: str):
137
+ selected_images = selected_paths_str.split(',') if selected_paths_str.strip() else []
 
 
138
  if not selected_images: raise gr.Error("Please select at least one sample image.")
139
  if len(selected_images) > 3: raise gr.Error("Please select no more than 3 sample images.")
 
140
  analysis_results = await process_analysis("Sample User", 0, selected_images, is_sample=True)
141
+ return [gr.update(visible=True), gr.update(visible=False), *analysis_results]
 
 
 
 
 
 
 
 
 
142
  analyze_samples_btn.click(fn=handle_sample_analysis, inputs=[selected_samples_textbox], outputs=[main_app, samples_page, uploader_column, results_column, result_images, result_label])
143
 
 
144
  all_pages = [main_app, history_page, samples_page]
145
  async def show_history_page_and_refresh(): records_update = await refresh_history_table(); return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), records_update]
146
  def show_samples_page(): return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
 
152
  refresh_history_btn.click(fn=refresh_history_table, outputs=history_df)
153
  demo.load(fn=refresh_history_table, outputs=history_df)
154
 
155
+
156
  # --- Launch the App ---
157
  if __name__ == "__main__":
158
  demo.launch()