omm7 commited on
Commit
8f8253e
·
verified ·
1 Parent(s): 036c810

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -54
app.py CHANGED
@@ -3,84 +3,90 @@ import numpy as np
3
  from PIL import Image
4
  import tensorflow as tf
5
  from tensorflow.keras.models import load_model
6
- import os
7
- import time
8
- from io import BytesIO
9
 
10
- # Define the correct class names for output
 
11
  CLASS_NAMES = {0: 'Normal', 1: 'Viral Pneumonia', 2: 'Covid'}
12
- IMG_SIZE = 224
 
13
 
14
- # Function to load the model (cached for efficiency)
 
15
  @st.cache_resource
16
  def load_tuned_model():
17
- # Use custom_objects for maximum robustness when loading models with
18
- # pre-trained components like VGG16 in specific environments.
19
  return tf.keras.models.load_model(
20
  "tuned_ai_model_best_lat.keras",
21
  custom_objects={'VGG16': tf.keras.applications.VGG16}
22
  )
23
 
24
- # Function to run the prediction and show progress
25
- def run_prediction(image_file, model, img_size):
26
- # Progress visualization (for user experience)
27
- progress_bar = st.progress(0)
28
- status_text = st.empty()
29
- for i in range(100):
30
- progress_bar.progress(i + 1)
31
- status_text.text(f"Processing... {i+1}%")
32
- time.sleep(0.01)
33
-
34
  try:
35
- # Read and process the image using PIL (Image)
36
  image = Image.open(image_file).convert("RGB")
37
- img_array = np.array(image.resize((img_size, img_size)))
38
- img_array = np.expand_dims(img_array, axis=0)
 
 
 
 
39
  img_array = img_array / 255.0
40
-
41
- # Make the prediction
42
- prediction = model.predict(img_array).flatten()
43
-
44
- # Find the predicted class index and name
45
- class_predicted_idx = np.argmax(prediction)
46
- predicted_class_name = CLASS_NAMES[class_predicted_idx]
47
-
48
- status_text.success("Prediction complete!")
49
 
50
- # Display results
51
- st.subheader("Prediction Probabilities:")
52
- for i, prob in enumerate(prediction):
53
- class_name = CLASS_NAMES[i]
54
- if i == class_predicted_idx:
55
- st.markdown(f"**{class_name}: {prob*100:.2f}%** (Predicted)", unsafe_allow_html=True)
56
- else:
57
- st.markdown(f"{class_name}: {prob*100:.2f}%")
58
-
59
- if class_predicted_idx == 2:
60
- st.error(f"**Result: ❗ HIGH LIKELIHOOD OF COVID DETECTED ❗**")
61
- else:
62
- st.success(f"**Result: ✅ {predicted_class_name} Detected**")
63
 
64
  except Exception as e:
65
- status_text.error(f"An error occurred during prediction: {e}")
66
- st.exception(e) # Show full traceback for debugging
 
67
 
68
- # Load the model and ensure the app stops if loading fails
 
 
 
 
 
69
  try:
70
  model = load_tuned_model()
71
  except Exception as e:
72
- st.error("Model Loading Failed. Check TF/Keras versions and model file name (tuned_ai_model_best_lat.keras).")
73
- st.exception(e)
74
  st.stop()
75
 
76
- st.title("COVID Detection from Chest X-ray")
77
- st.write("Upload a chest X-ray image to predict if it shows signs of Normal, Viral Pneumonia, or COVID.")
78
-
79
  uploaded_file = st.file_uploader("Choose an X-ray image...", type=["jpg", "jpeg", "png"])
80
 
81
  if uploaded_file is not None:
 
82
  image = Image.open(uploaded_file)
83
  st.image(image, caption="Uploaded X-ray Image", use_container_width=True)
84
-
85
- if st.button("Predict"):
86
- run_prediction(uploaded_file, model, IMG_SIZE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from PIL import Image
4
  import tensorflow as tf
5
  from tensorflow.keras.models import load_model
 
 
 
6
 
7
+ # --- CONFIGURATION ---
8
+ # Define the three possible outcomes
9
  CLASS_NAMES = {0: 'Normal', 1: 'Viral Pneumonia', 2: 'Covid'}
10
+ # The image size the model expects (224x224 pixels)
11
+ IMAGE_SIZE = 224
12
 
13
+ # --- MODEL LOADING ---
14
+ # Use cache so the model only loads once
15
  @st.cache_resource
16
  def load_tuned_model():
17
+ # Load the Keras model file.
18
+ # We include custom_objects to correctly load the VGG16 base.
19
  return tf.keras.models.load_model(
20
  "tuned_ai_model_best_lat.keras",
21
  custom_objects={'VGG16': tf.keras.applications.VGG16}
22
  )
23
 
24
+ # --- PREDICTION LOGIC ---
25
+ def run_prediction(image_file, model):
26
+ """Processes the image and gets the diagnosis from the model."""
 
 
 
 
 
 
 
27
  try:
28
+ # 1. Load and prepare the image
29
  image = Image.open(image_file).convert("RGB")
30
+ img_array = np.array(image.resize((IMAGE_SIZE, IMAGE_SIZE)))
31
+
32
+ # Add a dimension for the batch (1, 224, 224, 3)
33
+ img_array = np.expand_dims(img_array, axis=0)
34
+
35
+ # Normalize pixel values (0 to 1)
36
  img_array = img_array / 255.0
 
 
 
 
 
 
 
 
 
37
 
38
+ # 2. Make prediction
39
+ # The result is an array of probabilities for all three classes
40
+ prediction_probabilities = model.predict(img_array).flatten()
41
+
42
+ # 3. Find the most likely class
43
+ class_index = np.argmax(prediction_probabilities)
44
+ predicted_name = CLASS_NAMES[class_index]
45
+ predicted_prob = prediction_probabilities[class_index]
46
+
47
+ return predicted_name, predicted_prob
 
 
 
48
 
49
  except Exception as e:
50
+ st.error(f"An error occurred during prediction: {e}")
51
+ # Return None if any error happens
52
+ return None, None
53
 
54
+ # --- STREAMLIT INTERFACE ---
55
+
56
+ st.title("COVID Detection from Chest X-ray")
57
+ st.markdown("Upload a chest X-ray image for diagnosis (Normal, Viral Pneumonia, or COVID).")
58
+
59
+ # Attempt to load the model and stop if it fails
60
  try:
61
  model = load_tuned_model()
62
  except Exception as e:
63
+ st.error("Model Loading Failed. Please check dependencies and model file.")
 
64
  st.stop()
65
 
66
+ # --- UPLOAD SECTION ---
 
 
67
  uploaded_file = st.file_uploader("Choose an X-ray image...", type=["jpg", "jpeg", "png"])
68
 
69
  if uploaded_file is not None:
70
+ # Display the uploaded image
71
  image = Image.open(uploaded_file)
72
  st.image(image, caption="Uploaded X-ray Image", use_container_width=True)
73
+
74
+ # Run prediction when the button is clicked
75
+ if st.button("Predict Diagnosis", type="primary"):
76
+
77
+ # Run the prediction logic
78
+ predicted_name, predicted_prob = run_prediction(uploaded_file, model)
79
+
80
+ if predicted_name:
81
+ st.markdown("---")
82
+ st.subheader("Predicted Diagnosis")
83
+
84
+ # Display the result simply (no emojis)
85
+ if predicted_name == 'Covid':
86
+ st.error(f"Result: **{predicted_name}**")
87
+ else:
88
+ st.success(f"Result: **{predicted_name}**")
89
+
90
+ # Use the expander to show probability on click
91
+ with st.expander(f"View Confidence Score for {predicted_name}"):
92
+ st.markdown(f"Confidence: **{predicted_prob*100:.2f}%**")