Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,12 +7,16 @@ from datasets import load_dataset, DownloadConfig
|
|
| 7 |
import matplotlib.pyplot as plt
|
| 8 |
import numpy as np
|
| 9 |
import cv2
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
# Model and processor configuration
|
| 12 |
model_name_or_path = "google/vit-base-patch16-224-in21k"
|
| 13 |
processor = ViTImageProcessor.from_pretrained(model_name_or_path)
|
| 14 |
|
| 15 |
-
# Load dataset
|
| 16 |
dataset_path = "pawlo2013/chest_xray"
|
| 17 |
download_config = DownloadConfig(max_retries=10)
|
| 18 |
train_dataset = load_dataset(dataset_path, split="train", download_config=download_config)
|
|
@@ -26,126 +30,139 @@ model = ViTForImageClassification.from_pretrained(
|
|
| 26 |
label2id={label: i for i, label in enumerate(class_names)},
|
| 27 |
)
|
| 28 |
|
| 29 |
-
# Set model to evaluation mode
|
| 30 |
model.eval()
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
)
|
| 52 |
-
|
| 53 |
-
return
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
#
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
# Launch the app
|
| 150 |
if __name__ == "__main__":
|
| 151 |
-
|
|
|
|
|
|
| 7 |
import matplotlib.pyplot as plt
|
| 8 |
import numpy as np
|
| 9 |
import cv2
|
| 10 |
+
from groq import Groq
|
| 11 |
+
|
| 12 |
+
# Initialize Groq client
|
| 13 |
+
client = Groq(api_key="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")
|
| 14 |
|
| 15 |
# Model and processor configuration
|
| 16 |
model_name_or_path = "google/vit-base-patch16-224-in21k"
|
| 17 |
processor = ViTImageProcessor.from_pretrained(model_name_or_path)
|
| 18 |
|
| 19 |
+
# Load dataset
|
| 20 |
dataset_path = "pawlo2013/chest_xray"
|
| 21 |
download_config = DownloadConfig(max_retries=10)
|
| 22 |
train_dataset = load_dataset(dataset_path, split="train", download_config=download_config)
|
|
|
|
| 30 |
label2id={label: i for i, label in enumerate(class_names)},
|
| 31 |
)
|
| 32 |
|
|
|
|
| 33 |
model.eval()
|
| 34 |
|
| 35 |
+
def get_ai_explanation(diagnosis, probabilities):
|
| 36 |
+
if diagnosis == "normal":
|
| 37 |
+
prompt = f"""Given a chest X-ray analysis showing NORMAL results with {probabilities['normal']:.2%} confidence:
|
| 38 |
+
1. Explain what this means
|
| 39 |
+
2. Suggest when they should still consider consulting a doctor
|
| 40 |
+
3. List key symptoms that would warrant medical attention
|
| 41 |
+
Keep the tone professional yet reassuring."""
|
| 42 |
+
else:
|
| 43 |
+
prompt = f"""Given a chest X-ray analysis showing {diagnosis} pneumonia with {probabilities[diagnosis]:.2%} confidence:
|
| 44 |
+
1. Explain what {diagnosis} pneumonia is
|
| 45 |
+
2. List immediate steps the patient should take
|
| 46 |
+
3. Provide care recommendations
|
| 47 |
+
4. Mention warning signs to watch for
|
| 48 |
+
Keep the tone informative and caring but emphasize the importance of professional medical consultation."""
|
| 49 |
+
|
| 50 |
+
completion = client.chat.completions.create(
|
| 51 |
+
messages=[{"role": "user", "content": prompt}],
|
| 52 |
+
model="mixtral-8x7b-32768",
|
| 53 |
+
temperature=0.7,
|
| 54 |
)
|
| 55 |
+
|
| 56 |
+
return completion.choices[0].message.content
|
| 57 |
+
|
| 58 |
+
# Rest of your existing functions (classify_and_visualize, show_final_layer_attention_maps, etc.) remain the same
|
| 59 |
+
[Previous functions remain unchanged...]
|
| 60 |
+
|
| 61 |
+
def create_interface():
|
| 62 |
+
# Custom CSS
|
| 63 |
+
custom_css = """
|
| 64 |
+
.logo-container { text-align: center; margin-bottom: 20px; }
|
| 65 |
+
.logo-container img { max-width: 300px; }
|
| 66 |
+
.welcome-message { text-align: center; margin: 20px 0; padding: 20px; background-color: #f5f5f5; border-radius: 10px; }
|
| 67 |
+
.model-explanation { margin: 20px 0; padding: 20px; background-color: #f0f7ff; border-radius: 10px; }
|
| 68 |
+
.pneumonia-info { margin: 20px 0; padding: 20px; background-color: #fff5f5; border-radius: 10px; }
|
| 69 |
+
.disclaimer { margin-top: 20px; padding: 20px; background-color: #f5f5f5; border-radius: 10px; font-size: 0.9em; }
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
# HTML Components
|
| 73 |
+
logo_html = """
|
| 74 |
+
<div class="logo-container">
|
| 75 |
+
<img src="file/logo.png" alt="PneumoInsight Logo">
|
| 76 |
+
</div>
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
welcome_message = """
|
| 80 |
+
<div class="welcome-message">
|
| 81 |
+
<h1>Welcome to PneumoInsight</h1>
|
| 82 |
+
<p>PneumoInsight is a side project of EarlyMed—an initiative by our team at VIT-AP University dedicated to empowering you with early health insights.
|
| 83 |
+
Leveraging AI for early detection, our mission is simple: "Early Detection, Smarter Decision."
|
| 84 |
+
This project is one of our key efforts to help you stay informed before visiting a doctor.</p>
|
| 85 |
+
</div>
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
model_explanation = """
|
| 89 |
+
<div class="model-explanation">
|
| 90 |
+
<h2>How Our Model Works</h2>
|
| 91 |
+
<p>Our system uses a Vision Transformer (ViT) model to analyze chest X-ray images. The attention heatmap visualizes
|
| 92 |
+
areas the AI focuses on while making its diagnosis, helping make the decision-making process more transparent.
|
| 93 |
+
The warmer colors (red/yellow) indicate areas of higher attention.</p>
|
| 94 |
+
<p>Credits: The attention heatmap visualization is implemented using the attention rollout technique by
|
| 95 |
+
<a href="https://github.com/jacobgil/vit-explain" target="_blank">jacobgil</a>.</p>
|
| 96 |
+
</div>
|
| 97 |
+
"""
|
| 98 |
+
|
| 99 |
+
pneumonia_info = """
|
| 100 |
+
<div class="pneumonia-info">
|
| 101 |
+
<h2>Understanding Pneumonia</h2>
|
| 102 |
+
<p>Pneumonia is an infection that inflames the air sacs in one or both lungs. Common symptoms include:</p>
|
| 103 |
+
<ul>
|
| 104 |
+
<li>Chest pain when breathing or coughing</li>
|
| 105 |
+
<li>Cough with phlegm or pus</li>
|
| 106 |
+
<li>Fatigue and difficulty breathing</li>
|
| 107 |
+
<li>Fever, sweating, and shaking chills</li>
|
| 108 |
+
</ul>
|
| 109 |
+
<p>Prevention tips:</p>
|
| 110 |
+
<ul>
|
| 111 |
+
<li>Get vaccinated</li>
|
| 112 |
+
<li>Practice good hygiene</li>
|
| 113 |
+
<li>Don't smoke</li>
|
| 114 |
+
<li>Maintain a strong immune system</li>
|
| 115 |
+
</ul>
|
| 116 |
+
</div>
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
disclaimer = """
|
| 120 |
+
<div class="disclaimer">
|
| 121 |
+
<h3>Disclaimer</h3>
|
| 122 |
+
<p>This tool is for educational purposes only and should not be used as a substitute for professional medical advice,
|
| 123 |
+
diagnosis, or treatment. Always seek the advice of your physician or other qualified health provider.</p>
|
| 124 |
+
<p>Created by the team at VIT-AP University. View the source code on
|
| 125 |
+
<a href="https://github.com/Mahatir-Ahmed-Tusher/PneumoInsight" target="_blank">GitHub</a>.</p>
|
| 126 |
+
</div>
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
def enhanced_classification(img):
|
| 130 |
+
if img is None:
|
| 131 |
+
return None, None, "Please upload an image to proceed."
|
| 132 |
+
|
| 133 |
+
result = classify_and_visualize(img)
|
| 134 |
+
probabilities = result["probabilities"]
|
| 135 |
+
heatmap = result["heatmap"]
|
| 136 |
+
|
| 137 |
+
# Get the predicted class
|
| 138 |
+
predicted_class = max(probabilities.items(), key=lambda x: x[1])[0]
|
| 139 |
+
|
| 140 |
+
# Get AI explanation
|
| 141 |
+
ai_explanation = get_ai_explanation(predicted_class, probabilities)
|
| 142 |
+
|
| 143 |
+
return probabilities, heatmap, ai_explanation
|
| 144 |
+
|
| 145 |
+
# Create the Gradio interface
|
| 146 |
+
iface = gr.Interface(
|
| 147 |
+
fn=enhanced_classification,
|
| 148 |
+
inputs=gr.Image(type="pil", label="Upload Chest X-Ray Image"),
|
| 149 |
+
outputs=[
|
| 150 |
+
gr.Label(label="Diagnosis Probabilities"),
|
| 151 |
+
gr.Image(label="Attention Heatmap"),
|
| 152 |
+
gr.Textbox(label="AI Analysis and Recommendations", lines=10)
|
| 153 |
+
],
|
| 154 |
+
css=custom_css,
|
| 155 |
+
examples=load_examples_from_folder("./Examples"),
|
| 156 |
+
cache_examples=False,
|
| 157 |
+
article=model_explanation + pneumonia_info + disclaimer,
|
| 158 |
+
description=welcome_message,
|
| 159 |
+
title=logo_html,
|
| 160 |
+
theme=gr.themes.Soft()
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
return iface
|
| 164 |
|
| 165 |
# Launch the app
|
| 166 |
if __name__ == "__main__":
|
| 167 |
+
demo = create_interface()
|
| 168 |
+
demo.launch(debug=True)
|