# app.py (in the root directory) import gradio as gr from pathlib import Path from huggingface_hub import snapshot_download import asyncio from PIL import Image # --- Import and Initialize Backend Components from the 'app' folder --- from app.prediction import PredictionPipeline from app.database import add_patient_record, get_all_records # Initialize components once prediction_pipeline = PredictionPipeline() HF_DATASET_REPO = "ALYYAN/chest-xray-pneumonia-samples" try: SAMPLE_IMAGE_DIR = Path(snapshot_download(repo_id=HF_DATASET_REPO, repo_type="dataset")) # Create a list of sample image paths for the Gradio component SAMPLE_IMAGES = [str(p) for p in list(SAMPLE_IMAGE_DIR.glob('*/*.jpeg'))[:10]] except Exception as e: print(f"Could not download sample images: {e}") SAMPLE_IMAGES = [] # --- Core Prediction Logic for Gradio --- async def classify_images(patient_name, patient_age, image_list): # 1. Input Validation if not patient_name or patient_age is None: raise gr.Error("Patient Name and Age are required.") if not image_list: raise gr.Error("Please upload at least one image.") # Gradio provides file paths for uploaded files in a temp directory # Our prediction pipeline can handle these paths directly. # 2. Run Prediction result = prediction_pipeline.predict(image_list) # Pass the list of temp file paths prediction = result.get("prediction", "Error") confidence = result.get("confidence", 0) if prediction == "Error": raise gr.Error(result.get("details", "An unknown error occurred during prediction.")) # 3. Save to Database # Ensure age is an integer try: age = int(patient_age) except (ValueError, TypeError): raise gr.Error("Patient Age must be a valid number.") await add_patient_record( name=str(patient_name), age=age, result=prediction, confidence=confidence ) # 4. Format the Output for Gradio confidences = {"NORMAL": 0.0, "PNEUMONIA": 0.0} # Initialize both labels confidences[prediction] = confidence return confidences # --- Function to fetch and format database records --- async def get_records_html(): records = await get_all_records() if not records: return "
No records found in the database.
" # Create an HTML table from the records html = "| Name | Age | Prediction | Confidence | Date |
|---|---|---|---|---|
| {r.get('name', 'N/A')} | {r.get('age', 'N/A')} | {r.get('prediction_result', 'N/A')} | {confidence_percent} | {timestamp} |