ALYYAN commited on
Commit
48b3884
·
1 Parent(s): b383602

Prepare application for deployment

Browse files
Files changed (5) hide show
  1. README.md +35 -1
  2. app.py +152 -100
  3. app/image_utils.py +65 -0
  4. app/prediction.py +42 -15
  5. requirements.txt +10 -21
README.md CHANGED
@@ -1 +1,35 @@
1
- # End-to-End-Chest-X-ray-Pneumonia-Detection-with-ViT
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Pneumonia Detection AI
3
+ emoji: 🩺
4
+ colorFrom: blue
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 4.19.1
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ # 🩺 Pneumonia Detection AI
13
+
14
+ This Space demonstrates a complete, end-to-end MLOps pipeline for medical image classification.
15
+
16
+ ## ✨ Features
17
+
18
+ - **AI-Powered Diagnosis:** Upload one or more chest X-ray images to get an instant classification of **Normal** or **Pneumonia**.
19
+ - **Advanced Model:** Powered by a fine-tuned **Vision Transformer (ViT)** for high accuracy.
20
+ - **Multi-Image Analysis:** The AI provides both an overall prediction for the patient and individual watermarked results for each image.
21
+ - **Patient History:** All analyses are logged to a **MongoDB** database and can be reviewed.
22
+ - **Sample Library:** Test the app instantly with a library of sample X-ray images.
23
+
24
+ ## 🛠️ Tech Stack
25
+
26
+ - **Model:** Google's `vit-base-patch16-224-in21k`
27
+ - **MLOps Pipeline:** DVC & MLflow
28
+ - **Frontend:** Gradio
29
+ - **Database:** MongoDB Atlas
30
+ - **Hosting:** Hugging Face Spaces
31
+
32
+ This project was developed by **Alyyan Ahmed** and **Munim Akbar**.
33
+
34
+ ---
35
+ **Disclaimer:** This is a demo application for educational and portfolio purposes. It is **not a certified medical device** and should not be used for actual medical diagnosis.
app.py CHANGED
@@ -1,129 +1,181 @@
1
- # app.py (in the root directory)
2
 
3
  import gradio as gr
4
  from pathlib import Path
5
  from huggingface_hub import snapshot_download
6
  import asyncio
7
- from PIL import Image
8
 
9
- # --- Import and Initialize Backend Components from the 'app' folder ---
10
  from app.prediction import PredictionPipeline
11
  from app.database import add_patient_record, get_all_records
12
 
13
- # Initialize components once
14
  prediction_pipeline = PredictionPipeline()
15
  HF_DATASET_REPO = "ALYYAN/chest-xray-pneumonia-samples"
16
  try:
17
  SAMPLE_IMAGE_DIR = Path(snapshot_download(repo_id=HF_DATASET_REPO, repo_type="dataset"))
18
- # Create a list of sample image paths for the Gradio component
19
- SAMPLE_IMAGES = [str(p) for p in list(SAMPLE_IMAGE_DIR.glob('*/*.jpeg'))[:10]]
20
  except Exception as e:
21
  print(f"Could not download sample images: {e}")
22
  SAMPLE_IMAGES = []
23
 
24
- # --- Core Prediction Logic for Gradio ---
25
- async def classify_images(patient_name, patient_age, image_list):
26
- # 1. Input Validation
27
- if not patient_name or patient_age is None:
28
  raise gr.Error("Patient Name and Age are required.")
29
  if not image_list:
30
- raise gr.Error("Please upload at least one image.")
31
 
32
- # Gradio provides file paths for uploaded files in a temp directory
33
- # Our prediction pipeline can handle these paths directly.
34
-
35
- # 2. Run Prediction
36
- result = prediction_pipeline.predict(image_list) # Pass the list of temp file paths
37
- prediction = result.get("prediction", "Error")
38
- confidence = result.get("confidence", 0)
39
-
40
- if prediction == "Error":
41
- raise gr.Error(result.get("details", "An unknown error occurred during prediction."))
42
-
43
- # 3. Save to Database
44
- # Ensure age is an integer
45
- try:
46
- age = int(patient_age)
47
- except (ValueError, TypeError):
48
- raise gr.Error("Patient Age must be a valid number.")
49
-
50
- await add_patient_record(
51
- name=str(patient_name),
52
- age=age,
53
- result=prediction,
54
- confidence=confidence
55
- )
56
-
57
- # 4. Format the Output for Gradio
58
- confidences = {"NORMAL": 0.0, "PNEUMONIA": 0.0} # Initialize both labels
59
- confidences[prediction] = confidence
60
 
61
- return confidences
 
62
 
63
- # --- Function to fetch and format database records ---
64
- async def get_records_html():
 
 
 
 
 
 
 
 
 
 
65
  records = await get_all_records()
66
- if not records:
67
- return "<p>No records found in the database.</p>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- # Create an HTML table from the records
70
- html = "<table><tr><th>Name</th><th>Age</th><th>Prediction</th><th>Confidence</th><th>Date</th></tr>"
71
- for r in records:
72
- confidence_percent = f"{r['confidence_score']:.2%}" if r['confidence_score'] is not None else "N/A"
73
- timestamp = r['timestamp'].strftime('%Y-%m-%d %H:%M') if r['timestamp'] else "N/A"
74
- html += f"<tr><td>{r.get('name', 'N/A')}</td><td>{r.get('age', 'N/A')}</td><td>{r.get('prediction_result', 'N/A')}</td><td>{confidence_percent}</td><td>{timestamp}</td></tr>"
75
- html += "</table>"
76
- return html
77
-
78
- # --- Build the Gradio Interface ---
79
- with gr.Blocks(theme=gr.themes.Soft(), css="table { width: 100%; border-collapse: collapse; } th, td { padding: 8px; text-align: left; border-bottom: 1px solid #ddd; }") as demo:
80
- gr.Markdown("# 🩺 Pneumonia Detection AI")
81
- gr.Markdown("Upload one or more chest X-ray images for a patient to classify them as **Normal** or **Pneumonia**.")
82
-
83
- with gr.Row():
84
- with gr.Column(scale=1):
85
- gr.Markdown("### 1. Patient Information")
86
- patient_name = gr.Textbox(label="Patient Name", placeholder="e.g., John Doe")
87
- patient_age = gr.Number(label="Patient Age", minimum=0, maximum=120, step=1)
88
-
89
- gr.Markdown("### 2. Upload Images")
90
- # Using type="filepath" is simpler and avoids memory issues with large images
91
- image_input = gr.File(
92
- label="Upload up to 3 X-Rays",
93
- file_count="multiple",
94
- file_types=["image"],
95
- type="filepath" # Gradio will save uploads to a temp dir and give us the path
96
- )
97
-
98
- submit_btn = gr.Button("Analyze Images", variant="primary")
99
-
100
- if SAMPLE_IMAGES:
101
- gr.Examples(
102
- examples=SAMPLE_IMAGES,
103
- inputs=image_input,
104
- label="Sample Images (Click one, then click Analyze)",
105
- examples_per_page=5
 
 
 
 
 
 
 
 
 
 
106
  )
107
-
108
- with gr.Column(scale=1):
109
- gr.Markdown("### 3. Analysis Results")
110
- output_label = gr.Label(label="Prediction", num_top_classes=2)
111
- gr.Markdown("---")
112
- with gr.Accordion("View Patient Record History", open=False):
113
- records_html = gr.HTML("Loading records...")
114
- demo.load(get_records_html, None, records_html) # Load records when the app starts
115
- refresh_btn = gr.Button("Refresh History")
116
-
117
-
118
- # --- Link Components to the Function ---
119
- submit_btn.click(
120
- fn=classify_images,
121
- inputs=[patient_name, patient_age, image_input],
122
- outputs=[output_label]
123
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
- # When the refresh button is clicked, re-run the get_records_html function
126
- refresh_btn.click(fn=get_records_html, inputs=None, outputs=records_html)
127
 
128
  # --- Launch the App ---
129
  if __name__ == "__main__":
 
1
+ # app.py (Final UI Polish Version)
2
 
3
  import gradio as gr
4
  from pathlib import Path
5
  from huggingface_hub import snapshot_download
6
  import asyncio
 
7
 
 
8
  from app.prediction import PredictionPipeline
9
  from app.database import add_patient_record, get_all_records
10
 
11
+ # --- Initialization ---
12
  prediction_pipeline = PredictionPipeline()
13
  HF_DATASET_REPO = "ALYYAN/chest-xray-pneumonia-samples"
14
  try:
15
  SAMPLE_IMAGE_DIR = Path(snapshot_download(repo_id=HF_DATASET_REPO, repo_type="dataset"))
16
+ SAMPLE_IMAGES = [str(p) for p in list(SAMPLE_IMAGE_DIR.glob('*/*.jpeg'))]
 
17
  except Exception as e:
18
  print(f"Could not download sample images: {e}")
19
  SAMPLE_IMAGES = []
20
 
21
+ # --- Core Logic (Async Functions) ---
22
+ async def process_analysis(patient_name, patient_age, image_list, is_sample=False):
23
+ if not is_sample and (not patient_name or patient_age is None or str(patient_age).strip() == ""):
 
24
  raise gr.Error("Patient Name and Age are required.")
25
  if not image_list:
26
+ raise gr.Error("At least one image is required.")
27
 
28
+ result = prediction_pipeline.predict(image_list)
29
+ if "error" in result:
30
+ raise gr.Error(result["error"])
31
+
32
+ final_pred = result["final_prediction"]
33
+ final_conf = result["final_confidence"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ if not is_sample:
36
+ await add_patient_record(str(patient_name), int(patient_age), final_pred, final_conf)
37
 
38
+ confidences = {"NORMAL": 0.0, "PNEUMONIA": 0.0}
39
+ confidences[final_pred] = final_conf
40
+ confidences["NORMAL" if final_pred == "PNEUMONIA" else "PNEUMONIA"] = 1 - final_conf
41
+
42
+ return [
43
+ gr.update(visible=False), # uploader_column
44
+ gr.update(visible=True), # results_column
45
+ gr.update(value=result["watermarked_images"]), # result_images
46
+ gr.update(value=confidences) # result_label
47
+ ]
48
+
49
+ async def refresh_history_table():
50
  records = await get_all_records()
51
+ data_for_df = []
52
+ if records:
53
+ data_for_df = [[r.get('name'), r.get('age'), r.get('prediction_result'), f"{r.get('confidence_score', 0):.2%}", r.get('timestamp').strftime('%Y-%m-%d %H:%M')] for r in records]
54
+ return gr.update(value=data_for_df)
55
+
56
+ # --- Gradio UI Definition ---
57
+ css = """
58
+ /* --- Professional Dark Theme & Fonts --- */
59
+ :root { --primary-hue: 220 !important; --secondary-hue: 210 !important; --neutral-hue: 210 !important; --body-background-fill: #111827 !important; --block-background-fill: #1F2937 !important; --block-border-width: 1px !important; --border-color-accent: #374151 !important; --background-fill-secondary: #1F2937 !important;}
60
+ /* --- Header & Title Styling --- */
61
+ #app_header { text-align: center; }
62
+ #app_title { font-size: 2.8rem !important; font-weight: 700 !important; color: #FFFFFF !important; padding-top: 1rem; }
63
+ #app_subtitle { font-size: 1.2rem !important; color: #9CA3AF !important; margin-bottom: 2rem; }
64
+ /* --- Layout, Spacing, and Component Styling --- */
65
+ #main_container { gap: 2rem; }
66
+ #results_gallery { height: 350px !important; }
67
+ #results_gallery .gallery-item { height: 330px !important; max-height: 330px !important; padding: 0.25rem !important; background-color: #374151; border: 1px solid #374151 !important; }
68
+ #results_gallery .gallery-item img { object-fit: contain !important; }
69
+ #bottom_controls { max-width: 600px; margin: 2.5rem auto 1rem auto; }
70
+ #bottom_controls .gr-accordion > .gr-block-label { text-align: center !important; display: block !important; }
71
+ """
72
+ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue"), css=css, title="Pneumonia Detection AI") as demo:
73
 
74
+ with gr.Column() as main_app:
75
+ with gr.Column(elem_id="app_header"):
76
+ gr.Markdown("# 🩺 Pneumonia Detection AI", elem_id="app_title")
77
+ gr.Markdown("An AI-powered tool to assist in the diagnosis of pneumonia.", elem_id="app_subtitle")
78
+ with gr.Row(elem_id="main_container"):
79
+ with gr.Column(scale=1) as uploader_column:
80
+ gr.Markdown("### Upload Patient X-Rays")
81
+ image_input = gr.File(label="Upload up to 3 Images", file_count="multiple", file_types=["image"], type="filepath")
82
+ with gr.Column(scale=2, visible=False) as results_column:
83
+ gr.Markdown("### Analysis Results")
84
+ result_images = gr.Gallery(label="Analyzed Images", columns=3, object_fit="contain", height=350, elem_id="results_gallery")
85
+ result_label = gr.Label(label="Overall Prediction", num_top_classes=2)
86
+ start_over_btn = gr.Button("Start New Analysis", variant="secondary")
87
+ with gr.Group(visible=False) as patient_info_modal:
88
+ gr.Markdown("## Enter Patient Details", elem_classes="text-center")
89
+ patient_name_modal = gr.Textbox(label="Patient Name", placeholder="e.g., John Doe")
90
+ patient_age_modal = gr.Number(label="Patient Age", minimum=0, maximum=120, step=1)
91
+ with gr.Row():
92
+ submit_analysis_btn = gr.Button("Analyze Images", variant="primary")
93
+ cancel_btn = gr.Button("Cancel", variant="stop")
94
+ with gr.Column(elem_id="bottom_controls"):
95
+ with gr.Accordion("About this Tool", open=False):
96
+ gr.Markdown(
97
+ """
98
+ ### MLOps-Powered Pneumonia Detection
99
+
100
+ This application demonstrates a complete, end-to-end MLOps pipeline for medical image classification. It leverages a state-of-the-art **Vision Transformer (ViT)** model, fine-tuned on a public dataset of chest X-ray images to distinguish between Normal and Pneumonia cases.
101
+
102
+ ---
103
+
104
+ **Key Features & Technologies:**
105
+
106
+ * **Model:** Google's `vit-base-patch16-224-in21k`, fine-tuned for high accuracy.
107
+ * **MLOps Pipeline:** Reproducible workflow managed by **DVC** for data versioning and **MLflow** for experiment tracking.
108
+ * **Database:** Patient and prediction data is stored and managed in a **MongoDB** database for scalability.
109
+ * **Frontend:** A responsive and interactive user interface built with **Gradio**.
110
+ * **Deployment Ready:** The entire project is containerized and ready for deployment on platforms like Hugging Face Spaces.
111
+
112
+ **Disclaimer:** This tool is for demonstration and educational purposes only and is **not a substitute for professional medical advice.**
113
+
114
+ ---
115
+
116
+ **Project Team:**
117
+
118
+ * **Alyyan Ahmed** - (roles)
119
+ * **Munim Akbar** - (roles)
120
+ """
121
  )
122
+ with gr.Row():
123
+ samples_btn = gr.Button("Try Sample Images")
124
+ history_btn = gr.Button("View Patient History")
125
+ with gr.Column(visible=False) as history_page:
126
+ gr.Markdown("# 📜 Patient Record History", elem_classes="app_title")
127
+ with gr.Row():
128
+ back_to_main_btn_hist = gr.Button("⬅️ Back to Main App")
129
+ refresh_history_btn = gr.Button("Refresh History")
130
+ history_df = gr.DataFrame(headers=["Name", "Age", "Prediction", "Confidence", "Date"], row_count=10, interactive=False)
131
+ with gr.Column(visible=False) as samples_page:
132
+ gr.Markdown("# 🖼️ Sample Image Library", elem_classes="app_title")
133
+ gr.Markdown("Click an image to run an anonymous analysis.")
134
+ back_to_main_btn_samp = gr.Button("⬅️ Back to Main App")
135
+ sample_gallery = gr.Gallery(value=SAMPLE_IMAGES, label="Sample Images", columns=5, height=400)
136
+
137
+ # --- Event Handling Logic ---
138
+ def show_patient_info(files):
139
+ return gr.update(visible=True) if files else gr.update(visible=False)
140
+ image_input.upload(fn=show_patient_info, inputs=image_input, outputs=patient_info_modal)
141
+
142
+ async def submit_and_hide_modal(name, age, files):
143
+ analysis_results = await process_analysis(name, age, files)
144
+ return [
145
+ *analysis_results,
146
+ gr.update(visible=False) # Hide the modal
147
+ ]
148
+ submit_analysis_btn.click(fn=submit_and_hide_modal, inputs=[patient_name_modal, patient_age_modal, image_input], outputs=[uploader_column, results_column, result_images, result_label, patient_info_modal])
149
+
150
+ cancel_btn.click(lambda: (gr.update(visible=False), None), None, [patient_info_modal, image_input])
151
+ start_over_btn.click(fn=None, js="() => { window.location.reload(); }")
152
+
153
+ async def handle_sample_click(evt: gr.SelectData):
154
+ selected_path = evt.value
155
+ analysis_results = await process_analysis("Sample User", 0, [selected_path], is_sample=True)
156
+ return [
157
+ gr.update(visible=True), # main_app
158
+ gr.update(visible=False), # samples_page
159
+ *analysis_results
160
+ ]
161
+ sample_gallery.select(handle_sample_click, None, [main_app, samples_page, uploader_column, results_column, result_images, result_label])
162
+
163
+ all_pages = [main_app, history_page, samples_page]
164
+ async def show_history_page_and_refresh():
165
+ records_update = await refresh_history_table()
166
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), records_update]
167
+ def show_samples_page():
168
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
169
+ def show_main_page():
170
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)]
171
+
172
+ history_btn.click(fn=show_history_page_and_refresh, outputs=all_pages + [history_df])
173
+ samples_btn.click(fn=show_samples_page, outputs=all_pages)
174
+ back_to_main_btn_hist.click(fn=show_main_page, outputs=all_pages)
175
+ back_to_main_btn_samp.click(fn=show_main_page, outputs=all_pages)
176
 
177
+ refresh_history_btn.click(fn=refresh_history_table, outputs=history_df)
178
+ demo.load(fn=refresh_history_table, outputs=history_df)
179
 
180
  # --- Launch the App ---
181
  if __name__ == "__main__":
app/image_utils.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/image_utils.py
2
+
3
+ from PIL import Image, ImageDraw, ImageFont
4
+ import numpy as np
5
+
6
+ def add_watermark(image_array: np.ndarray, text: str, confidence: float) -> Image.Image:
7
+ """
8
+ Adds a translucent watermark to an image with the prediction result and confidence.
9
+
10
+ Args:
11
+ image_array: The input image as a NumPy array.
12
+ text: The prediction text (e.g., "NORMAL" or "PNEUMONIA").
13
+ confidence: The confidence score of the prediction.
14
+
15
+ Returns:
16
+ A PIL Image object with the watermark applied.
17
+ """
18
+ # Convert NumPy array to PIL Image
19
+ image = Image.fromarray(image_array).convert("RGBA")
20
+
21
+ # Create a transparent overlay for the text
22
+ txt_overlay = Image.new("RGBA", image.size, (255, 255, 255, 0))
23
+ draw = ImageDraw.Draw(txt_overlay)
24
+
25
+ # Define watermark properties
26
+ is_pneumonia = (text == "PNEUMONIA")
27
+ box_color = (220, 53, 69, 180) if is_pneumonia else (25, 135, 84, 180) # Red for Pneumonia, Green for Normal
28
+ text_color = (255, 255, 255, 255)
29
+
30
+ # Define font (uses a default if a specific .ttf is not found)
31
+ try:
32
+ font_size = int(image.height / 8)
33
+ font = ImageFont.truetype("arialbd.ttf", font_size)
34
+ except IOError:
35
+ print("Arial Bold font not found, using default. Watermark quality may be lower.")
36
+ font_size = int(image.height / 8)
37
+ font = ImageFont.load_default()
38
+
39
+ # Define text and box position
40
+ text_to_draw = f"{text}\n{confidence:.1%}"
41
+
42
+ # Get text size
43
+ try:
44
+ # Use getbbox for modern Pillow versions
45
+ _, _, text_width, text_height = draw.textbbox((0, 0), text_to_draw, font=font)
46
+ except AttributeError:
47
+ # Fallback for older Pillow versions
48
+ text_width, text_height = draw.textsize(text_to_draw, font=font)
49
+
50
+ position = (20, 20) # Top-left corner with some padding
51
+ box_position = [
52
+ position[0] - 10,
53
+ position[1] - 10,
54
+ position[0] + text_width + 10,
55
+ position[1] + text_height + 10
56
+ ]
57
+
58
+ # Draw the semi-transparent rectangle and the text
59
+ draw.rectangle(box_position, fill=box_color)
60
+ draw.text(position, text_to_draw, font=font, fill=text_color)
61
+
62
+ # Combine the overlay with the original image
63
+ watermarked_image = Image.alpha_composite(image, txt_overlay)
64
+
65
+ return watermarked_image.convert("RGB")
app/prediction.py CHANGED
@@ -5,10 +5,10 @@ from transformers import ViTImageProcessor, ViTForImageClassification
5
  from PIL import Image
6
  from pathlib import Path
7
  import numpy as np
8
- from typing import List, Dict, Union
 
9
 
10
- # Define a type hint for the input, which can be a path or bytes
11
- ImageType = Union[str, Path, bytes]
12
 
13
  class PredictionPipeline:
14
  def __init__(self, model_path: Path = Path("artifacts/model_training/model")):
@@ -18,36 +18,63 @@ class PredictionPipeline:
18
  self.model.eval()
19
  self.id2label = self.model.config.id2label
20
 
21
- def predict(self, image_sources: List[ImageType]) -> Dict[str, Union[str, float]]:
22
  if not image_sources:
23
- return {"prediction": "Error", "confidence": 0.0, "details": "No images provided."}
24
 
 
25
  all_logits = []
 
 
26
  for source in image_sources:
27
  try:
28
- # --- THIS IS THE FIX ---
29
- # The Image.open() function can handle both paths and byte streams.
30
- # No special handling is needed.
31
- image = Image.open(source).convert("RGB")
32
 
33
- inputs = self.processor(images=image, return_tensors="pt").to(self.device)
34
 
 
35
  with torch.no_grad():
36
  outputs = self.model(**inputs)
37
- all_logits.append(outputs.logits)
 
 
 
 
 
 
 
 
 
 
38
  except Exception as e:
39
  print(f"Skipping a corrupted or invalid image file. Error: {e}")
 
40
  continue
41
 
42
  if not all_logits:
43
- return {"prediction": "Error", "confidence": 0.0, "details": "All provided images were invalid."}
44
 
 
45
  avg_logits = torch.mean(torch.stack(all_logits), dim=0)
46
  probabilities = torch.nn.functional.softmax(avg_logits, dim=-1)
47
  confidence_score, predicted_class_idx = torch.max(probabilities, dim=-1)
48
- predicted_label = self.id2label[predicted_class_idx.item()]
 
 
 
 
 
 
 
 
 
49
 
50
  return {
51
- "prediction": predicted_label,
52
- "confidence": confidence_score.item()
 
 
53
  }
 
5
  from PIL import Image
6
  from pathlib import Path
7
  import numpy as np
8
+ from typing import List, Dict, Union, Any
9
+ from .image_utils import add_watermark
10
 
11
+ ImageType = Union[str, Path, bytes, np.ndarray]
 
12
 
13
  class PredictionPipeline:
14
  def __init__(self, model_path: Path = Path("artifacts/model_training/model")):
 
18
  self.model.eval()
19
  self.id2label = self.model.config.id2label
20
 
21
+ def predict(self, image_sources: List[ImageType]) -> Dict[str, Any]:
22
  if not image_sources:
23
+ return {"error": "No images provided."}
24
 
25
+ individual_results = []
26
  all_logits = []
27
+ valid_images_as_np = []
28
+
29
  for source in image_sources:
30
  try:
31
+ if isinstance(source, np.ndarray):
32
+ image = Image.fromarray(source).convert("RGB")
33
+ else:
34
+ image = Image.open(source).convert("RGB")
35
 
36
+ valid_images_as_np.append(np.array(image))
37
 
38
+ inputs = self.processor(images=image, return_tensors="pt").to(self.device)
39
  with torch.no_grad():
40
  outputs = self.model(**inputs)
41
+ logits = outputs.logits
42
+ all_logits.append(logits)
43
+
44
+ # --- NEW: Calculate individual prediction ---
45
+ ind_probs = torch.nn.functional.softmax(logits, dim=-1)
46
+ ind_conf, ind_idx = torch.max(ind_probs, dim=-1)
47
+ individual_results.append({
48
+ "prediction": self.id2label[ind_idx.item()],
49
+ "confidence": ind_conf.item()
50
+ })
51
+
52
  except Exception as e:
53
  print(f"Skipping a corrupted or invalid image file. Error: {e}")
54
+ individual_results.append({"prediction": "Error", "confidence": 0})
55
  continue
56
 
57
  if not all_logits:
58
+ return {"error": "All images were invalid."}
59
 
60
+ # --- Aggregate Prediction (same as before) ---
61
  avg_logits = torch.mean(torch.stack(all_logits), dim=0)
62
  probabilities = torch.nn.functional.softmax(avg_logits, dim=-1)
63
  confidence_score, predicted_class_idx = torch.max(probabilities, dim=-1)
64
+
65
+ final_prediction = self.id2label[predicted_class_idx.item()]
66
+ final_confidence = confidence_score.item()
67
+
68
+ # --- NEW: Watermark images with their INDIVIDUAL results ---
69
+ watermarked_images = [
70
+ add_watermark(img_np, res["prediction"], res["confidence"])
71
+ for img_np, res in zip(valid_images_as_np, individual_results)
72
+ if res["prediction"] != "Error"
73
+ ]
74
 
75
  return {
76
+ "final_prediction": final_prediction,
77
+ "final_confidence": final_confidence,
78
+ "individual_results": individual_results,
79
+ "watermarked_images": watermarked_images
80
  }
requirements.txt CHANGED
@@ -1,27 +1,16 @@
1
- pandas
2
- numpy
3
- torch
4
- torchvision
 
 
 
 
5
  transformers
6
- datasets>=2.14.5
7
- evaluate
8
- accelerate>=0.27
9
- mlflow
10
  scikit-learn
11
  imblearn
12
  python-box
13
  PyYAML
14
  ensure
15
- tqdm
16
- pathlib
17
- dvc
18
- matplotlib
19
- Pillow
20
- kaggle
21
- python-dotenv
22
- nicegui
23
- sqlalchemy
24
- pymongo
25
- motor
26
- huggingface_hub
27
- gradio
 
1
+ gradio==4.19.1
2
+ pymongo
3
+ motor
4
+ python-dotenv
5
+ huggingface_hub
6
+ torch --index-url https://download.pytorch.org/whl/cpu
7
+ torchvision --index-url https://download.pytorch.org/whl/cpu
8
+ Pillow
9
  transformers
10
+ datasets
 
 
 
11
  scikit-learn
12
  imblearn
13
  python-box
14
  PyYAML
15
  ensure
16
+ dvc[gdrive] # Add dvc with gdrive support