ALYYAN commited on
Commit
dda6312
·
unverified ·
1 Parent(s): b5e2348

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -38
app.py CHANGED
@@ -1,29 +1,26 @@
1
- # app.py (Final Version with Local Samples, Checkbox Selector, and UI Fixes)
2
 
3
  import gradio as gr
4
  from pathlib import Path
 
5
  import asyncio
6
- from PIL import Image
7
 
8
  from app.prediction import PredictionPipeline
9
  from app.database import add_patient_record, get_all_records
10
 
11
  # --- Initialization ---
12
  prediction_pipeline = PredictionPipeline()
13
- # --- FIX: Point to the locally cloned sample images directory from setup.sh ---
14
- SAMPLE_IMAGE_DIR = Path("sample_images")
15
  try:
 
16
  SAMPLE_IMAGES = [str(p) for p in sorted(list(SAMPLE_IMAGE_DIR.glob('*/*.jpeg')))]
17
- if not SAMPLE_IMAGES:
18
- raise FileNotFoundError
19
- except FileNotFoundError:
20
- print("Warning: 'sample_images' directory not found or is empty. Samples will be unavailable.")
21
  SAMPLE_IMAGES = []
22
 
23
-
24
  # --- Core Logic Functions (Unchanged and Correct) ---
 
25
  async def process_analysis(patient_name, patient_age, image_list, is_sample=False):
26
- # ... (code is the same)
27
  if not is_sample and (not patient_name or patient_age is None): raise gr.Error("Patient Name and Age are required.")
28
  if not image_list: raise gr.Error("At least one image is required.")
29
  result = prediction_pipeline.predict(image_list)
@@ -32,9 +29,7 @@ async def process_analysis(patient_name, patient_age, image_list, is_sample=Fals
32
  if not is_sample: await add_patient_record(str(patient_name), int(patient_age), final_pred, final_conf)
33
  confidences = {"NORMAL": 0.0, "PNEUMONIA": 0.0}; confidences[final_pred] = final_conf; confidences["NORMAL" if final_pred == "PNEUMONIA" else "PNEUMONIA"] = 1 - final_conf
34
  return [gr.update(visible=False), gr.update(visible=True), gr.update(value=result["watermarked_images"]), gr.update(value=confidences)]
35
-
36
  async def refresh_history_table():
37
- # ... (code is the same)
38
  records = await get_all_records()
39
  data = [[r.get('name'), r.get('age'), r.get('prediction_result'), f"{r.get('confidence_score', 0):.2%}", r.get('timestamp').strftime('%Y-%m-%d %H:%M')] for r in records] if records else []
40
  return gr.update(value=data)
@@ -58,6 +53,10 @@ css = """
58
  """
59
  with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue"), css=css, title="Pneumonia Detection AI") as demo:
60
 
 
 
 
 
61
  with gr.Column() as main_app:
62
  # ... (Main page layout is the same)
63
  with gr.Column(elem_id="app_header"):
@@ -85,9 +84,7 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")
85
  with gr.Row():
86
  samples_btn = gr.Button("Try Sample Images")
87
  history_btn = gr.Button("View Patient History")
88
-
89
  with gr.Column(visible=False) as history_page:
90
- # ... (History page layout is the same)
91
  gr.Markdown("# 📜 Patient Record History", elem_classes="app_title")
92
  with gr.Row():
93
  back_to_main_btn_hist = gr.Button("⬅️ Back to Main App")
@@ -97,54 +94,95 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")
97
  # --- SAMPLES PAGE (THE DEFINITIVE FIX) ---
98
  with gr.Column(visible=False) as samples_page:
99
  gr.Markdown("# 🖼️ Sample Image Library", elem_classes="app_title")
100
- gr.Markdown("Select up to 3 images, then click 'Analyze'.")
101
 
102
- # Use a CheckboxGroup with images for selection
103
- sample_checkboxes = gr.CheckboxGroup(
 
104
  label="Sample Images",
105
- # A choice is a tuple: (Image for display, file path for value)
106
- choices=[(Image.open(p), p) for p in SAMPLE_IMAGES],
107
- type="value",
108
- elem_id="sample_gallery" # Use the gallery CSS
109
  )
110
 
 
 
 
111
  with gr.Row():
112
  analyze_samples_btn = gr.Button("Analyze Selected Samples", variant="primary")
113
  back_to_main_btn_samp = gr.Button("⬅️ Back to Main App")
114
-
115
  # --- Event Handling Logic ---
116
-
117
- # ... (upload, modal, start over handlers are correct)
118
  def show_patient_info(files): return gr.update(visible=True) if files else gr.update(visible=False)
119
  image_input.upload(fn=show_patient_info, inputs=image_input, outputs=patient_info_modal)
120
  async def submit_and_hide_modal(name, age, files):
121
- analysis_results = await process_analysis(name, age, files)
122
- return [*analysis_results, gr.update(visible=False)]
123
  submit_analysis_btn.click(fn=submit_and_hide_modal, inputs=[patient_name_modal, patient_age_modal, image_input], outputs=[uploader_column, results_column, result_images, result_label, patient_info_modal])
124
  cancel_btn.click(lambda: (gr.update(visible=False), None), None, [patient_info_modal, image_input])
125
  start_over_btn.click(fn=None, js="() => { window.location.reload(); }")
126
 
127
  # --- SAMPLE PAGE LOGIC (THE FIX) ---
128
- async def handle_sample_analysis(selected_images: list):
129
- # selected_images is now a list of file paths from the checkbox group
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  if not selected_images: raise gr.Error("Please select at least one sample image.")
131
  if len(selected_images) > 3: raise gr.Error("Please select no more than 3 sample images.")
132
 
133
  analysis_results = await process_analysis("Sample User", 0, selected_images, is_sample=True)
134
 
135
- # Return updates to show the results on the main page and hide this page
136
- return [
137
- gr.update(visible=True), # main_app
138
- gr.update(visible=False), # samples_page
139
- *analysis_results
140
- ]
141
- analyze_samples_btn.click(fn=handle_sample_analysis, inputs=[sample_checkboxes], outputs=[main_app, samples_page, uploader_column, results_column, result_images, result_label])
 
 
 
142
 
143
  # ... (Page Navigation is correct)
144
  all_pages = [main_app, history_page, samples_page]
145
- async def show_history_page_and_refresh():
146
- records_update = await refresh_history_table()
147
- return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), records_update]
148
  def show_samples_page(): return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
149
  def show_main_page(): return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)]
150
  history_btn.click(fn=show_history_page_and_refresh, outputs=all_pages + [history_df])
 
1
+ # app.py (Final Version with Working Sample Gallery)
2
 
3
  import gradio as gr
4
  from pathlib import Path
5
+ from huggingface_hub import snapshot_download
6
  import asyncio
 
7
 
8
  from app.prediction import PredictionPipeline
9
  from app.database import add_patient_record, get_all_records
10
 
11
  # --- Initialization ---
12
  prediction_pipeline = PredictionPipeline()
13
+ HF_DATASET_REPO = "ALYYAN/chest-xray-pneumonia-samples"
 
14
  try:
15
+ SAMPLE_IMAGE_DIR = Path(snapshot_download(repo_id=HF_DATASET_REPO, repo_type="dataset"))
16
  SAMPLE_IMAGES = [str(p) for p in sorted(list(SAMPLE_IMAGE_DIR.glob('*/*.jpeg')))]
17
+ except Exception as e:
18
+ print(f"Could not download sample images: {e}")
 
 
19
  SAMPLE_IMAGES = []
20
 
 
21
  # --- Core Logic Functions (Unchanged and Correct) ---
22
+ # ... (process_analysis and refresh_history_table are the same as the last working version)
23
  async def process_analysis(patient_name, patient_age, image_list, is_sample=False):
 
24
  if not is_sample and (not patient_name or patient_age is None): raise gr.Error("Patient Name and Age are required.")
25
  if not image_list: raise gr.Error("At least one image is required.")
26
  result = prediction_pipeline.predict(image_list)
 
29
  if not is_sample: await add_patient_record(str(patient_name), int(patient_age), final_pred, final_conf)
30
  confidences = {"NORMAL": 0.0, "PNEUMONIA": 0.0}; confidences[final_pred] = final_conf; confidences["NORMAL" if final_pred == "PNEUMONIA" else "PNEUMONIA"] = 1 - final_conf
31
  return [gr.update(visible=False), gr.update(visible=True), gr.update(value=result["watermarked_images"]), gr.update(value=confidences)]
 
32
  async def refresh_history_table():
 
33
  records = await get_all_records()
34
  data = [[r.get('name'), r.get('age'), r.get('prediction_result'), f"{r.get('confidence_score', 0):.2%}", r.get('timestamp').strftime('%Y-%m-%d %H:%M')] for r in records] if records else []
35
  return gr.update(value=data)
 
53
  """
54
  with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue"), css=css, title="Pneumonia Detection AI") as demo:
55
 
56
+ # --- State to track selected sample images ---
57
+ selected_samples = gr.State([])
58
+
59
+ # --- UI LAYOUT (Unchanged) ---
60
  with gr.Column() as main_app:
61
  # ... (Main page layout is the same)
62
  with gr.Column(elem_id="app_header"):
 
84
  with gr.Row():
85
  samples_btn = gr.Button("Try Sample Images")
86
  history_btn = gr.Button("View Patient History")
 
87
  with gr.Column(visible=False) as history_page:
 
88
  gr.Markdown("# 📜 Patient Record History", elem_classes="app_title")
89
  with gr.Row():
90
  back_to_main_btn_hist = gr.Button("⬅️ Back to Main App")
 
94
  # --- SAMPLES PAGE (THE DEFINITIVE FIX) ---
95
  with gr.Column(visible=False) as samples_page:
96
  gr.Markdown("# 🖼️ Sample Image Library", elem_classes="app_title")
97
+ gr.Markdown("Select up to 3 images by clicking on them, then click 'Analyze'.")
98
 
99
+ # This gallery will show the images
100
+ sample_gallery = gr.Gallery(
101
+ value=SAMPLE_IMAGES,
102
  label="Sample Images",
103
+ columns=5, height=400,
104
+ elem_id="sample_gallery"
 
 
105
  )
106
 
107
+ # This hidden textbox will store the list of selected file paths
108
+ selected_samples_textbox = gr.Textbox(visible=False)
109
+
110
  with gr.Row():
111
  analyze_samples_btn = gr.Button("Analyze Selected Samples", variant="primary")
112
  back_to_main_btn_samp = gr.Button("⬅️ Back to Main App")
113
+
114
  # --- Event Handling Logic ---
115
+ # ... (handlers for main upload workflow are correct)
 
116
  def show_patient_info(files): return gr.update(visible=True) if files else gr.update(visible=False)
117
  image_input.upload(fn=show_patient_info, inputs=image_input, outputs=patient_info_modal)
118
  async def submit_and_hide_modal(name, age, files):
119
+ analysis_results = await process_analysis(name, age, files); return [*analysis_results, gr.update(visible=False)]
 
120
  submit_analysis_btn.click(fn=submit_and_hide_modal, inputs=[patient_name_modal, patient_age_modal, image_input], outputs=[uploader_column, results_column, result_images, result_label, patient_info_modal])
121
  cancel_btn.click(lambda: (gr.update(visible=False), None), None, [patient_info_modal, image_input])
122
  start_over_btn.click(fn=None, js="() => { window.location.reload(); }")
123
 
124
  # --- SAMPLE PAGE LOGIC (THE FIX) ---
125
+
126
+ # JavaScript to handle multi-select on the gallery
127
+ # When an image is clicked, this JS will add/remove its path from the hidden textbox
128
+ # and add/remove a 'selected' class for a visual border.
129
+ select_js = """
130
+ (evt) => {
131
+ const gallery = document.querySelector('#sample_gallery .grid-container');
132
+ const clicked_img = gallery.children[evt.index];
133
+ const selected_paths_input = document.querySelector('#selected_samples_textbox textarea');
134
+ let selected_paths = selected_paths_input.value ? selected_paths_input.value.split(',') : [];
135
+ const current_path = clicked_img.querySelector('img').alt;
136
+
137
+ if (clicked_img.classList.contains('selected')) {
138
+ clicked_img.classList.remove('selected');
139
+ selected_paths = selected_paths.filter(p => p !== current_path);
140
+ } else {
141
+ if (selected_paths.length < 3) {
142
+ clicked_img.classList.add('selected');
143
+ selected_paths.push(current_path);
144
+ } else {
145
+ // This is a simple browser alert. Gradio's gr.Warning is better for the final check.
146
+ alert("You can select a maximum of 3 images.");
147
+ }
148
+ }
149
+
150
+ // Return the updated list of paths to the hidden textbox
151
+ return selected_paths.join(',');
152
+ }
153
+ """
154
+
155
+ # We need to add a little CSS for the selection border
156
+ demo.css += "#sample_gallery .gallery-item.selected { border: 4px solid var(--primary-500) !important; }"
157
+
158
+ # Hidden textbox to store the paths
159
+ selected_samples_textbox = gr.Textbox(value="", visible=False, elem_id="selected_samples_textbox")
160
+
161
+ sample_gallery.select(fn=None, _js=select_js, outputs=[selected_samples_textbox])
162
+
163
+ async def handle_sample_analysis(selected_paths_str: str):
164
+ # The input is now a comma-separated string of paths from our hidden textbox
165
+ selected_images = selected_paths_str.split(',') if selected_paths_str else []
166
+
167
  if not selected_images: raise gr.Error("Please select at least one sample image.")
168
  if len(selected_images) > 3: raise gr.Error("Please select no more than 3 sample images.")
169
 
170
  analysis_results = await process_analysis("Sample User", 0, selected_images, is_sample=True)
171
 
172
+ return {
173
+ main_app: gr.update(visible=True),
174
+ samples_page: gr.update(visible=False),
175
+ # Unpack dictionary updates for specific components
176
+ uploader_column: analysis_results[0],
177
+ results_column: analysis_results[1],
178
+ result_images: analysis_results[2],
179
+ result_label: analysis_results[3],
180
+ }
181
+ analyze_samples_btn.click(fn=handle_sample_analysis, inputs=[selected_samples_textbox], outputs=[main_app, samples_page, uploader_column, results_column, result_images, result_label])
182
 
183
  # ... (Page Navigation is correct)
184
  all_pages = [main_app, history_page, samples_page]
185
+ async def show_history_page_and_refresh(): records_update = await refresh_history_table(); return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), records_update]
 
 
186
  def show_samples_page(): return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
187
  def show_main_page(): return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)]
188
  history_btn.click(fn=show_history_page_and_refresh, outputs=all_pages + [history_df])