Spaces:
Sleeping
Sleeping
File size: 5,293 Bytes
b383602 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
# app.py (in the root directory)
import gradio as gr
from pathlib import Path
from huggingface_hub import snapshot_download
import asyncio
from PIL import Image
# --- Import and Initialize Backend Components from the 'app' folder ---
from app.prediction import PredictionPipeline
from app.database import add_patient_record, get_all_records
# Initialize components once
prediction_pipeline = PredictionPipeline()

# Hugging Face dataset repo the sample images are downloaded from.
HF_DATASET_REPO = "ALYYAN/chest-xray-pneumonia-samples"
# Maximum number of sample images offered as clickable examples.
MAX_SAMPLE_IMAGES = 10

try:
    SAMPLE_IMAGE_DIR = Path(snapshot_download(repo_id=HF_DATASET_REPO, repo_type="dataset"))
    # Create a list of sample image paths for the Gradio component.
    # glob() order is filesystem-dependent, so sort first to show the same
    # samples on every start, then keep at most MAX_SAMPLE_IMAGES.
    SAMPLE_IMAGES = [
        str(p) for p in sorted(SAMPLE_IMAGE_DIR.glob('*/*.jpeg'))[:MAX_SAMPLE_IMAGES]
    ]
except Exception as e:
    # Samples are optional: with an empty list the UI skips the examples section.
    print(f"Could not download sample images: {e}")
    SAMPLE_IMAGES = []
# --- Core Prediction Logic for Gradio ---
async def classify_images(patient_name, patient_age, image_list):
    """Classify a patient's X-ray images, store the result and return label scores.

    Args:
        patient_name: Value of the "Patient Name" textbox.
        patient_age: Value of the "Patient Age" number field; ``None`` when empty.
        image_list: File path(s) from the ``gr.File`` component
            (``type="filepath"``). Normally a list of paths; a single path
            (e.g. coming from a clicked example) is also accepted.

    Returns:
        dict mapping each label ("NORMAL", "PNEUMONIA", or whatever label the
        pipeline predicted) to a confidence value, suitable for ``gr.Label``.

    Raises:
        gr.Error: if name/age/images are missing or invalid, or if the
            prediction pipeline reports an error.
    """
    # 1. Input Validation -- every check runs before the model so invalid
    #    input never costs an inference pass.
    name = str(patient_name).strip() if patient_name is not None else ""
    if not name or patient_age is None:
        raise gr.Error("Patient Name and Age are required.")
    # Normalize a lone path into a list; the pipeline is given a list of paths.
    if isinstance(image_list, (str, Path)):
        image_list = [image_list]
    if not image_list:
        raise gr.Error("Please upload at least one image.")
    # Ensure age is an integer. int() raises ValueError for NaN/non-numeric
    # strings, TypeError for odd types and OverflowError for +/-inf.
    try:
        age = int(patient_age)
    except (ValueError, TypeError, OverflowError):
        raise gr.Error("Patient Age must be a valid number.") from None

    # 2. Run Prediction
    # predict() is synchronous; run it in the default thread pool so the
    # event loop (and other requests, e.g. history refreshes) is not blocked.
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(None, prediction_pipeline.predict, image_list)
    prediction = result.get("prediction", "Error")
    confidence = result.get("confidence", 0)
    if prediction == "Error":
        raise gr.Error(result.get("details", "An unknown error occurred during prediction."))

    # 3. Save to Database
    await add_patient_record(
        name=name,
        age=age,
        result=prediction,
        confidence=confidence,
    )

    # 4. Format the Output for Gradio
    confidences = {"NORMAL": 0.0, "PNEUMONIA": 0.0}  # Initialize both labels
    confidences[prediction] = confidence
    return confidences
# --- Function to fetch and format database records ---
async def get_records_html():
    """Render all stored patient records as an HTML table.

    Returns:
        An HTML string: a ``<table>`` with one row per record returned by
        ``get_all_records()``, or a short paragraph when there are none.
    """
    # Local import keeps the file's import block untouched.
    from html import escape

    records = await get_all_records()
    if not records:
        return "<p>No records found in the database.</p>"

    def cell(value):
        # Patient names are free-text user input: escape every cell so a
        # stored value cannot inject markup or script into the page.
        return f"<td>{escape(str(value))}</td>"

    parts = ["<table><tr><th>Name</th><th>Age</th><th>Prediction</th><th>Confidence</th><th>Date</th></tr>"]
    for r in records:
        # Use .get() for every field (as the original did for name/age/result)
        # so a missing column renders "N/A" instead of raising KeyError.
        score = r.get('confidence_score')
        confidence_percent = f"{score:.2%}" if score is not None else "N/A"

        ts = r.get('timestamp')
        if not ts:
            timestamp = "N/A"
        elif hasattr(ts, "strftime"):
            timestamp = ts.strftime('%Y-%m-%d %H:%M')
        else:
            # Timestamp stored as plain text: show it verbatim rather than crash.
            timestamp = str(ts)

        parts.append(
            "<tr>"
            + cell(r.get('name', 'N/A'))
            + cell(r.get('age', 'N/A'))
            + cell(r.get('prediction_result', 'N/A'))
            + cell(confidence_percent)
            + cell(timestamp)
            + "</tr>"
        )
    parts.append("</table>")
    return "".join(parts)
# --- Build the Gradio Interface ---
with gr.Blocks(
    theme=gr.themes.Soft(),
    css="table { width: 100%; border-collapse: collapse; } th, td { padding: 8px; text-align: left; border-bottom: 1px solid #ddd; }",
) as demo:
    gr.Markdown("# 🩺 Pneumonia Detection AI")
    gr.Markdown("Upload one or more chest X-ray images for a patient to classify them as **Normal** or **Pneumonia**.")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. Patient Information")
            patient_name = gr.Textbox(label="Patient Name", placeholder="e.g., John Doe")
            patient_age = gr.Number(label="Patient Age", minimum=0, maximum=120, step=1)

            gr.Markdown("### 2. Upload Images")
            # Using type="filepath" is simpler and avoids memory issues with large images
            image_input = gr.File(
                label="Upload up to 3 X-Rays",
                file_count="multiple",
                file_types=["image"],
                type="filepath",  # Gradio will save uploads to a temp dir and give us the path
            )
            submit_btn = gr.Button("Analyze Images", variant="primary")

            # Only offer examples when the sample download succeeded.
            if SAMPLE_IMAGES:
                gr.Examples(
                    examples=SAMPLE_IMAGES,
                    inputs=image_input,
                    label="Sample Images (Click one, then click Analyze)",
                    examples_per_page=5,
                )

        with gr.Column(scale=1):
            gr.Markdown("### 3. Analysis Results")
            output_label = gr.Label(label="Prediction", num_top_classes=2)

    gr.Markdown("---")
    with gr.Accordion("View Patient Record History", open=False):
        records_html = gr.HTML("Loading records...")
        demo.load(get_records_html, None, records_html)  # Load records when the app starts
        refresh_btn = gr.Button("Refresh History")

    # --- Link Components to the Function ---
    # After an analysis (which saves a new record), re-render the history so
    # the table is not stale until the user clicks "Refresh History".
    submit_btn.click(
        fn=classify_images,
        inputs=[patient_name, patient_age, image_input],
        outputs=[output_label],
    ).then(fn=get_records_html, inputs=None, outputs=records_html)

    # When the refresh button is clicked, re-run the get_records_html function
    refresh_btn.click(fn=get_records_html, inputs=None, outputs=records_html)
# --- Launch the App ---
# Start the Gradio server only when run as a script, not when imported.
# (The trailing " |" in the scraped source was a syntax error and is removed.)
if __name__ == "__main__":
    demo.launch()