ALYYAN committed on
Commit
b383602
·
1 Parent(s): 4821854

Gradio UI added

Browse files
.nicegui/storage-user-86d297d3-3ea0-4fc7-835a-6c59d3b4ba3a.json ADDED
File without changes
app.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# app.py (in the root directory)

import gradio as gr
from pathlib import Path
from huggingface_hub import snapshot_download
import asyncio
from PIL import Image

# --- Import and Initialize Backend Components from the 'app' folder ---
from app.prediction import PredictionPipeline
from app.database import add_patient_record, get_all_records

# Initialize components once, at import time: model loading is expensive and
# the pipeline is shared by every request.
prediction_pipeline = PredictionPipeline()

# Public Hub dataset holding sample chest X-ray images for the demo.
HF_DATASET_REPO = "ALYYAN/chest-xray-pneumonia-samples"
try:
    # snapshot_download caches the dataset locally and returns its root dir.
    SAMPLE_IMAGE_DIR = Path(snapshot_download(repo_id=HF_DATASET_REPO, repo_type="dataset"))
    # Create a list of sample image paths for the Gradio component
    # (capped at 10 so the Examples widget stays small).
    SAMPLE_IMAGES = [str(p) for p in list(SAMPLE_IMAGE_DIR.glob('*/*.jpeg'))[:10]]
except Exception as e:
    # The samples are a nice-to-have; the app still works without them.
    print(f"Could not download sample images: {e}")
    SAMPLE_IMAGES = []
23
+
24
# --- Core Prediction Logic for Gradio ---
async def classify_images(patient_name, patient_age, image_list):
    """Validate the form inputs, classify the uploaded X-rays, persist the
    result, and return a label->confidence mapping for the gr.Label output.

    Args:
        patient_name: Free-text patient name (required).
        patient_age: Patient age in years (required, must be numeric).
        image_list: List of temp-file paths supplied by the gr.File input.

    Returns:
        dict: {"NORMAL": p, "PNEUMONIA": q} confidences for gr.Label.

    Raises:
        gr.Error: On missing/invalid inputs or a failed prediction.
    """
    # 1. Input validation — fail fast on ALL inputs (including the age
    #    conversion) BEFORE running the expensive model call. The original
    #    validated the age only after prediction.
    if not patient_name or patient_age is None:
        raise gr.Error("Patient Name and Age are required.")
    if not image_list:
        raise gr.Error("Please upload at least one image.")
    try:
        age = int(patient_age)
    except (ValueError, TypeError):
        raise gr.Error("Patient Age must be a valid number.")

    # 2. Run prediction. Gradio stores uploads in a temp directory and hands
    #    us the file paths, which the pipeline consumes directly.
    result = prediction_pipeline.predict(image_list)
    prediction = result.get("prediction", "Error")
    confidence = result.get("confidence", 0)

    if prediction == "Error":
        raise gr.Error(result.get("details", "An unknown error occurred during prediction."))

    # 3. Save to database.
    await add_patient_record(
        name=str(patient_name),
        age=age,
        result=prediction,
        confidence=confidence
    )

    # 4. Format the output for gr.Label. The classifier is binary over these
    #    two labels (confidence is the max of a 2-class softmax in the
    #    pipeline), so the other label gets the complementary probability
    #    instead of a misleading flat 0.0 as before.
    confidences = {"NORMAL": 0.0, "PNEUMONIA": 0.0}
    confidences[prediction] = confidence
    if prediction in ("NORMAL", "PNEUMONIA"):
        other = "PNEUMONIA" if prediction == "NORMAL" else "NORMAL"
        confidences[other] = 1.0 - confidence

    return confidences
62
+
63
# --- Function to fetch and format database records ---
async def get_records_html():
    """Fetch all patient records and render them as an HTML table.

    Returns:
        str: An HTML <table> of records, newest first, or a placeholder
        <p> element when the database is empty.
    """
    import html as _html  # stdlib; escapes user-entered fields for safe HTML

    records = await get_all_records()
    if not records:
        return "<p>No records found in the database.</p>"

    # Build the table with join() instead of repeated `+=` concatenation.
    # User-entered fields (name) are HTML-escaped so a malicious name cannot
    # inject markup or script into the page. All fields are read with .get()
    # — the original mixed r['key'] direct indexing (KeyError on a missing
    # field) with .get() on the same records.
    rows = ["<table><tr><th>Name</th><th>Age</th><th>Prediction</th><th>Confidence</th><th>Date</th></tr>"]
    for r in records:
        score = r.get('confidence_score')
        confidence_percent = f"{score:.2%}" if score is not None else "N/A"
        ts = r.get('timestamp')
        timestamp = ts.strftime('%Y-%m-%d %H:%M') if ts else "N/A"
        name = _html.escape(str(r.get('name', 'N/A')))
        prediction = _html.escape(str(r.get('prediction_result', 'N/A')))
        rows.append(
            f"<tr><td>{name}</td><td>{r.get('age', 'N/A')}</td>"
            f"<td>{prediction}</td><td>{confidence_percent}</td><td>{timestamp}</td></tr>"
        )
    rows.append("</table>")
    return "".join(rows)
77
+
78
# --- Build the Gradio Interface ---
# Two-column layout: patient form + uploads on the left, results and record
# history on the right. The CSS only styles the history table rendered by
# get_records_html().
with gr.Blocks(theme=gr.themes.Soft(), css="table { width: 100%; border-collapse: collapse; } th, td { padding: 8px; text-align: left; border-bottom: 1px solid #ddd; }") as demo:
    gr.Markdown("# 🩺 Pneumonia Detection AI")
    gr.Markdown("Upload one or more chest X-ray images for a patient to classify them as **Normal** or **Pneumonia**.")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. Patient Information")
            patient_name = gr.Textbox(label="Patient Name", placeholder="e.g., John Doe")
            patient_age = gr.Number(label="Patient Age", minimum=0, maximum=120, step=1)

            gr.Markdown("### 2. Upload Images")
            # Using type="filepath" is simpler and avoids memory issues with large images
            image_input = gr.File(
                label="Upload up to 3 X-Rays",
                file_count="multiple",
                file_types=["image"],
                type="filepath"  # Gradio will save uploads to a temp dir and give us the path
            )

            submit_btn = gr.Button("Analyze Images", variant="primary")

            # Only render the Examples widget when the sample-dataset
            # download at import time succeeded (SAMPLE_IMAGES non-empty).
            if SAMPLE_IMAGES:
                gr.Examples(
                    examples=SAMPLE_IMAGES,
                    inputs=image_input,
                    label="Sample Images (Click one, then click Analyze)",
                    examples_per_page=5
                )

        with gr.Column(scale=1):
            gr.Markdown("### 3. Analysis Results")
            output_label = gr.Label(label="Prediction", num_top_classes=2)
            gr.Markdown("---")
            with gr.Accordion("View Patient Record History", open=False):
                records_html = gr.HTML("Loading records...")
                demo.load(get_records_html, None, records_html)  # Load records when the app starts
                refresh_btn = gr.Button("Refresh History")

    # --- Link Components to the Function ---
    # classify_images validates the form, runs the model and writes the
    # record; its return value feeds the gr.Label component.
    submit_btn.click(
        fn=classify_images,
        inputs=[patient_name, patient_age, image_input],
        outputs=[output_label]
    )

    # When the refresh button is clicked, re-run the get_records_html function
    refresh_btn.click(fn=get_records_html, inputs=None, outputs=records_html)

# --- Launch the App ---
if __name__ == "__main__":
    demo.launch()
app/__init__.py ADDED
File without changes
app/database.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# app/database.py

import os
from motor.motor_asyncio import AsyncIOMotorClient
from dotenv import load_dotenv
import datetime
from typing import List, Dict

# Load environment variables from .env file
load_dotenv()

# --- Database Connection ---
# Get the connection string from the environment variables
MONGODB_URL = os.getenv("MONGODB_CONNECTION_STRING")

# Fail fast at import time: every DB operation in this module depends on it.
if not MONGODB_URL:
    raise ValueError("MONGODB_CONNECTION_STRING not found in environment variables. Please check your .env file.")

# Create a client instance (motor's client is async/non-blocking; it does not
# actually connect until the first operation is awaited)
client = AsyncIOMotorClient(MONGODB_URL)

# Get a handle to the database (it will be created if it doesn't exist)
# The database name 'pneumonia_db' should match the one in your connection string
database = client.pneumonia_db

# Get a handle to the collection (like a table in SQL)
patient_collection = database.get_collection("patient_records")
30
# --- Database Operations (now async) ---

async def add_patient_record(name: str, age: int, result: str, confidence: float) -> Dict:
    """
    Inserts a new patient record into the MongoDB collection.

    Args:
        name: Patient name.
        age: Patient age in years.
        result: Prediction label (e.g. "NORMAL" / "PNEUMONIA").
        confidence: Model confidence score in [0, 1].

    Returns:
        The inserted document, including its generated `_id`.
    """
    record_document = {
        "name": name,
        "age": age,
        "prediction_result": result,
        "confidence_score": confidence,
        # Timezone-aware UTC timestamp. datetime.utcnow() is deprecated
        # (since Python 3.12) and returns a naive datetime that is easy to
        # misinterpret; now(timezone.utc) is the recommended replacement.
        "timestamp": datetime.datetime.now(datetime.timezone.utc)
    }

    # .insert_one is an async operation, so we must 'await' it.
    # (Named `insert_result` so it no longer shadows the `result` parameter.)
    insert_result = await patient_collection.insert_one(record_document)

    # Find the newly created document to return it
    new_record = await patient_collection.find_one({"_id": insert_result.inserted_id})
    return new_record
52
+
53
+
54
async def get_all_records() -> List[Dict]:
    """
    Return every patient record, sorted newest-first by timestamp.
    """
    # Motor cursors support async iteration, so an async comprehension
    # materialises the documents without an explicit accumulator loop.
    newest_first = patient_collection.find({}).sort("timestamp", -1)  # -1 = descending
    return [document async for document in newest_first]
app/prediction.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/prediction.py
2
+
3
+ import torch
4
+ from transformers import ViTImageProcessor, ViTForImageClassification
5
+ from PIL import Image
6
+ from pathlib import Path
7
+ import numpy as np
8
+ from typing import List, Dict, Union
9
+
10
# Type hint for a single image source: a filesystem path or raw encoded bytes.
ImageType = Union[str, Path, bytes]

class PredictionPipeline:
    """ViT-based chest X-ray classifier that averages logits over a batch of
    images belonging to one patient and returns a single label + confidence."""

    def __init__(self, model_path: Path = Path("artifacts/model_training/model")):
        # Prefer GPU when torch can see one; all tensors go to this device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = ViTImageProcessor.from_pretrained(model_path)
        self.model = ViTForImageClassification.from_pretrained(model_path).to(self.device)
        self.model.eval()  # inference mode: disables dropout etc.
        # Class index -> human-readable label, stored in the model config.
        self.id2label = self.model.config.id2label

    def predict(self, image_sources: List[ImageType]) -> Dict[str, Union[str, float]]:
        """Classify one patient from one or more images.

        Args:
            image_sources: File paths (str/Path) or raw encoded image bytes.

        Returns:
            {"prediction": label, "confidence": float} on success, or
            {"prediction": "Error", "confidence": 0.0, "details": ...} when no
            usable image was provided.
        """
        import io  # local stdlib import: only needed for the bytes case

        if not image_sources:
            return {"prediction": "Error", "confidence": 0.0, "details": "No images provided."}

        all_logits = []
        for source in image_sources:
            try:
                # BUGFIX: PIL's Image.open accepts a path or a *file-like*
                # object — passing raw bytes raises. Since ImageType allows
                # bytes, wrap them in BytesIO before opening. (The original
                # comment claimed Image.open handled bytes directly.)
                if isinstance(source, bytes):
                    image = Image.open(io.BytesIO(source)).convert("RGB")
                else:
                    image = Image.open(source).convert("RGB")

                inputs = self.processor(images=image, return_tensors="pt").to(self.device)

                with torch.no_grad():
                    outputs = self.model(**inputs)
                all_logits.append(outputs.logits)
            except Exception as e:
                # Best-effort: skip unreadable images, classify the rest.
                print(f"Skipping a corrupted or invalid image file. Error: {e}")
                continue

        if not all_logits:
            return {"prediction": "Error", "confidence": 0.0, "details": "All provided images were invalid."}

        # Average per-image logits, then softmax for class probabilities.
        avg_logits = torch.mean(torch.stack(all_logits), dim=0)
        probabilities = torch.nn.functional.softmax(avg_logits, dim=-1)
        confidence_score, predicted_class_idx = torch.max(probabilities, dim=-1)
        predicted_label = self.id2label[predicted_class_idx.item()]

        return {
            "prediction": predicted_label,
            "confidence": confidence_score.item()
        }
requirements.txt CHANGED
@@ -18,4 +18,10 @@ dvc
18
  matplotlib
19
  Pillow
20
  kaggle
21
- python-dotenv
 
 
 
 
 
 
 
18
  matplotlib
19
  Pillow
20
  kaggle
21
+ python-dotenv
22
+ nicegui
23
+ sqlalchemy
24
+ pymongo
25
+ motor
26
+ huggingface_hub
27
+ gradio
src/vitClassifier/pipeline/prediction.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # prediction.py
2
+
3
+ import torch
4
+ from transformers import ViTImageProcessor, ViTForImageClassification
5
+ from PIL import Image
6
+ import argparse
7
+ import os
8
+ from pathlib import Path
9
+
10
class PredictionPipeline:
    """Single-image inference pipeline for the fine-tuned ViT classifier."""

    def __init__(self, model_path: str = "artifacts/model_training/model"):
        """Load the fine-tuned ViT model and its image processor.

        Args:
            model_path: Directory containing the saved model and processor.
        """
        # Prefer the GPU whenever torch can see one.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Restore the preprocessing pipeline and the weights from disk, then
        # switch to evaluation mode (disables dropout etc.).
        self.processor = ViTImageProcessor.from_pretrained(model_path)
        self.model = ViTForImageClassification.from_pretrained(model_path).to(self.device)
        self.model.eval()

        # Class index -> human-readable label, taken from the model config.
        self.id2label = self.model.config.id2label

    def predict(self, image_path: str):
        """Classify a single chest X-ray image.

        Args:
            image_path: File path of the image to classify.

        Returns:
            dict: {"predicted_label": ..., "confidence_score": ...} on
            success, or {"error": ...} when the image cannot be opened.
        """
        try:
            image = Image.open(image_path).convert("RGB")
        except FileNotFoundError:
            return {"error": f"Image not found at path: {image_path}"}
        except Exception as e:
            return {"error": f"Failed to open image: {e}"}

        # Resize/normalize and convert to a tensor batch on the right device.
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)

        # Forward pass with gradient tracking disabled for faster inference.
        with torch.no_grad():
            logits = self.model(**inputs).logits

        # Softmax turns raw logits into class probabilities; the argmax index
        # identifies both the winning label and its confidence.
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        winner = logits.argmax(-1).item()

        return {
            "predicted_label": self.id2label[winner],
            "confidence_score": f"{probabilities[0][winner].item():.4f}",
        }
74
if __name__ == '__main__':
    # CLI usage examples:
    #   python prediction.py --image "artifacts/data_ingestion/chest_xray/test/PNEUMONIA/person100_bacteria_475.jpeg"
    #   python prediction.py --image "artifacts/data_ingestion/chest_xray/test/NORMAL/IM-0001-0001.jpeg"

    # Read the image path from the command line.
    arg_parser = argparse.ArgumentParser(description="Chest X-ray Pneumonia Detection")
    arg_parser.add_argument("--image", type=str, required=True, help="Path to the input image")
    cli_args = arg_parser.parse_args()

    # Build the pipeline and run a single prediction.
    outcome = PredictionPipeline().predict(cli_args.image)

    # Report the result (or the failure reason) to the console.
    print("\n--- Prediction Result ---")
    if "error" in outcome:
        print(f"Error: {outcome['error']}")
    else:
        print(f"The model predicts this is a '{outcome['predicted_label']}' case.")
        print(f"Confidence: {outcome['confidence_score']}")
    print("-------------------------\n")