MahatirTusher committed on
Commit
4793156
·
verified ·
1 Parent(s): 900f653

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -118
app.py CHANGED
@@ -7,12 +7,16 @@ from datasets import load_dataset, DownloadConfig
7
import matplotlib.pyplot as plt
import numpy as np
import cv2

# Model and processor configuration.
# ViT backbone pre-trained on ImageNet-21k; the processor handles the
# resizing/normalization expected by the checkpoint.
model_name_or_path = "google/vit-base-patch16-224-in21k"
processor = ViTImageProcessor.from_pretrained(model_name_or_path)

# Load dataset (adjust dataset_path accordingly).
# NOTE(review): the full training split is downloaded at app startup; it
# appears to be used only to derive `class_names` further down — confirm
# whether loading the split is actually needed at serve time.
dataset_path = "pawlo2013/chest_xray"
download_config = DownloadConfig(max_retries=10)  # retry flaky hub downloads
train_dataset = load_dataset(dataset_path, split="train", download_config=download_config)
@@ -26,126 +30,139 @@ model = ViTForImageClassification.from_pretrained(
26
  label2id={label: i for i, label in enumerate(class_names)},
27
  )
28
 
29
- # Set model to evaluation mode
30
  model.eval()
31
 
32
-
33
# Define the classification function
def classify_and_visualize(img, device="cpu", discard_ratio=0.9, head_fusion="mean"):
    """Classify a chest X-ray image and build an attention-rollout heatmap.

    Args:
        img: PIL image in any mode; converted to RGB before preprocessing.
        device: torch device string used for inference (default "cpu").
        discard_ratio: fraction of lowest attention weights dropped in rollout.
        head_fusion: attention-head fusion strategy ("mean", "max" or "min").

    Returns:
        dict with "probabilities" (class name -> softmax probability) and
        "heatmap" (PIL image of the attention overlay).
    """
    img = img.convert("RGB")
    # The processor returns a batch dict; only the pixel tensor is needed.
    # (Fixed: the original moved data to `device` twice — once is enough.)
    processed_input = processor(images=img, return_tensors="pt")["pixel_values"].to(device)

    with torch.no_grad():
        outputs = model(processed_input, output_attentions=True)
        logits = outputs.logits
        probabilities = torch.softmax(logits, dim=1)[0].tolist()
        # (Removed an unused `predicted_class` local that was computed here
        # but never returned.)

    result = {class_name: prob for class_name, prob in zip(class_names, probabilities)}

    # Generate attention heatmap over all transformer layers.
    heatmap_img = show_final_layer_attention_maps(
        outputs, processed_input, device, discard_ratio, head_fusion
    )

    return {"probabilities": result, "heatmap": heatmap_img}
54
-
55
-
56
def format_output(output):
    """Unpack a classification result dict into the (probabilities, heatmap)
    tuple expected by the Gradio output components."""
    probabilities = output["probabilities"]
    heatmap = output["heatmap"]
    return probabilities, heatmap
58
-
59
-
60
# Function to load examples from a folder
def load_examples_from_folder(folder_path):
    """Load example images (.png/.jpg/.jpeg) from *folder_path*.

    Files are sorted by name so the example gallery order is deterministic
    (`os.listdir` order is filesystem-dependent), and extensions are matched
    case-insensitively so e.g. "scan.JPG" is also picked up.

    Returns:
        list of opened PIL images (empty if no matching files).
    """
    examples = []
    for file in sorted(os.listdir(folder_path)):
        if file.lower().endswith((".png", ".jpg", ".jpeg")):
            examples.append(Image.open(os.path.join(folder_path, file)))
    return examples
67
-
68
-
69
# Function to show final layer attention maps
def show_final_layer_attention_maps(
    outputs,
    processed_input,
    device,
    discard_ratio=0.6,
    head_fusion="max",
    only_last_layer=False,
):
    """Render an attention-rollout heatmap blended over the input image.

    Args:
        outputs: model output carrying `attentions` (one tensor per layer).
        processed_input: the (1, C, H, W) pixel tensor fed to the model.
        device: torch device for the intermediate rollout matrices.
        discard_ratio: fraction of the lowest attention weights zeroed per layer.
        head_fusion: "mean", "max" or "min" fusion across attention heads.
        only_last_layer: if True, roll out only the final layer's attention.

    Returns:
        PIL image: jet-colormapped attention mask (40%) superimposed on the
        min-max normalized input image (60%).
    """

    with torch.no_grad():
        # Normalize the input image to [0, 1] for display.
        image = processed_input.squeeze(0)
        image = image - image.min()
        image = image / image.max()

        # Rollout accumulator, initialised to the identity over all tokens.
        result = torch.eye(outputs.attentions[0].size(-1)).to(device)
        if only_last_layer:
            attention_list = outputs.attentions[-1].unsqueeze(0).to(device)
        else:
            attention_list = outputs.attentions

        for attention in attention_list:
            # Fuse the per-head attention maps into a single (B, N, N) map.
            if head_fusion == "mean":
                attention_heads_fused = attention.mean(axis=1)
            elif head_fusion == "max":
                attention_heads_fused = attention.max(axis=1)[0]
            elif head_fusion == "min":
                attention_heads_fused = attention.min(axis=1)[0]

            # Zero out the weakest `discard_ratio` fraction of weights,
            # keeping index 0 (the CLS token) untouched.
            flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
            _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
            indices = indices[indices != 0]
            flat[0, indices] = 0

            # Add the identity to model the residual connection, then
            # multiply into the rollout (attention rollout, per vit-explain).
            I = torch.eye(attention_heads_fused.size(-1)).to(device)
            a = (attention_heads_fused + 1.0 * I) / 2
            # NOTE(review): no keepdim on the sum, so this broadcasts along
            # the last axis rather than row-normalizing; it mirrors the
            # upstream vit-explain code but is worth confirming.
            a = a / a.sum(dim=-1)
            result = torch.matmul(a, result)

        # Attention from the CLS token to every patch token (drop CLS itself).
        mask = result[0, 0, 1:]
        width = int(mask.size(-1) ** 0.5)
        mask = mask.reshape(width, width).cpu().numpy()
        mask = mask / np.max(mask)

        # Upsample the patch grid to the 224x224 display size and renormalize.
        mask = cv2.resize(mask, (224, 224))
        mask = (mask - np.min(mask)) / (np.max(mask) - np.min(mask))
        heatmap = plt.cm.jet(mask)[:, :, :3]

        # CHW -> HWC and min-max normalize the image for blending.
        showed_img = image.permute(1, 2, 0).detach().cpu().numpy()
        showed_img = (showed_img - np.min(showed_img)) / (
            np.max(showed_img) - np.min(showed_img)
        )
        superimposed_img = heatmap * 0.4 + showed_img * 0.6

        superimposed_img_pil = Image.fromarray(
            (superimposed_img * 255).astype(np.uint8)
        )
        return superimposed_img_pil
127
-
128
-
129
# Define the path to the examples folder
examples_folder = "./Examples"
examples = load_examples_from_folder(examples_folder)

# Create the Gradio interface.
# `fn` wraps the classifier so its result dict is unpacked into the two
# output components (label probabilities, heatmap image).
iface = gr.Interface(
    fn=lambda img: format_output(classify_and_visualize(img)),
    inputs=gr.Image(type="pil", label="Upload X-Ray Image"),
    outputs=[
        gr.Label(),
        gr.Image(label="Attention Heatmap"),
    ],
    examples=examples,
    cache_examples=False,
    flagging_mode="never",
    concurrency_limit=1,  # inference is not batched; serve one request at a time
    title="Pneumonia X-Ray 3-Class Classification with Vision Transformer (ViT) using data augmentation",
    # Fixed user-facing typos: "transfomer" -> "transformer",
    # "techinique" -> "technique".
    description="Upload an X-ray image to classify it as normal, viral or bacterial pneumonia. Checkout the model in more details [here](https://huggingface.co/pawlo2013/vit-pneumonia-x-ray_3_class). The examples presented are taken from the test set of [Kermany et al. (2018) dataset.](https://data.mendeley.com/datasets/rscbjbr9sj/2.) The attention heatmap over all layers of the transformer done by the attention rollout technique by the implementation of [jacobgil](https://github.com/jacobgil/vit-explain).",
)

# Launch the app
if __name__ == "__main__":
    iface.launch(debug=True)
 
 
7
  import matplotlib.pyplot as plt
8
  import numpy as np
9
  import cv2
10
import os

from groq import Groq

# Initialize the Groq client used for patient-facing explanations.
# SECURITY: the API key must be read from the environment (e.g. an HF Space
# secret named GROQ_API_KEY), never hard-coded in source control. Any key
# that was previously committed must be revoked and rotated.
client = Groq(api_key=os.environ.get("GROQ_API_KEY", ""))
14
 
15
  # Model and processor configuration
16
  model_name_or_path = "google/vit-base-patch16-224-in21k"
17
  processor = ViTImageProcessor.from_pretrained(model_name_or_path)
18
 
19
+ # Load dataset
20
  dataset_path = "pawlo2013/chest_xray"
21
  download_config = DownloadConfig(max_retries=10)
22
  train_dataset = load_dataset(dataset_path, split="train", download_config=download_config)
 
30
  label2id={label: i for i, label in enumerate(class_names)},
31
  )
32
 
 
33
  model.eval()
34
 
35
def get_ai_explanation(diagnosis, probabilities):
    """Ask the Groq LLM for a patient-friendly explanation of the result.

    Args:
        diagnosis: predicted class name ("normal" or a pneumonia type).
        probabilities: mapping of class name -> softmax probability.

    Returns:
        The generated explanation text.
    """
    is_normal = diagnosis == "normal"
    confidence = probabilities["normal"] if is_normal else probabilities[diagnosis]

    if is_normal:
        prompt = f"""Given a chest X-ray analysis showing NORMAL results with {confidence:.2%} confidence:
1. Explain what this means
2. Suggest when they should still consider consulting a doctor
3. List key symptoms that would warrant medical attention
Keep the tone professional yet reassuring."""
    else:
        prompt = f"""Given a chest X-ray analysis showing {diagnosis} pneumonia with {confidence:.2%} confidence:
1. Explain what {diagnosis} pneumonia is
2. List immediate steps the patient should take
3. Provide care recommendations
4. Mention warning signs to watch for
Keep the tone informative and caring but emphasize the importance of professional medical consultation."""

    completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="mixtral-8x7b-32768",
        temperature=0.7,
    )
    return completion.choices[0].message.content
57
+
58
# NOTE: the helper functions defined above (classify_and_visualize,
# show_final_layer_attention_maps, format_output, load_examples_from_folder)
# are reused unchanged by the interface below.
# Fixed: the committed file contained the literal placeholder line
# "[Previous functions remain unchanged...]", which is a SyntaxError and
# prevented the module from importing at all.
60
+
61
def create_interface():
    """Assemble the Gradio interface for PneumoInsight.

    Wires the ViT classifier, the attention heatmap and the Groq-generated
    explanation into a single `gr.Interface`, wrapped in branded HTML panels.

    Returns:
        The configured (but not yet launched) `gr.Interface`.
    """
    # Custom CSS for the static HTML panels below.
    custom_css = """
    .logo-container { text-align: center; margin-bottom: 20px; }
    .logo-container img { max-width: 300px; }
    .welcome-message { text-align: center; margin: 20px 0; padding: 20px; background-color: #f5f5f5; border-radius: 10px; }
    .model-explanation { margin: 20px 0; padding: 20px; background-color: #f0f7ff; border-radius: 10px; }
    .pneumonia-info { margin: 20px 0; padding: 20px; background-color: #fff5f5; border-radius: 10px; }
    .disclaimer { margin-top: 20px; padding: 20px; background-color: #f5f5f5; border-radius: 10px; font-size: 0.9em; }
    """

    # HTML Components
    # NOTE(review): "file/logo.png" relies on Gradio's static file serving —
    # confirm logo.png exists next to app.py and is allowed to be served.
    logo_html = """
    <div class="logo-container">
        <img src="file/logo.png" alt="PneumoInsight Logo">
    </div>
    """

    welcome_message = """
    <div class="welcome-message">
        <h1>Welcome to PneumoInsight</h1>
        <p>PneumoInsight is a side project of EarlyMed—an initiative by our team at VIT-AP University dedicated to empowering you with early health insights.
        Leveraging AI for early detection, our mission is simple: "Early Detection, Smarter Decision."
        This project is one of our key efforts to help you stay informed before visiting a doctor.</p>
    </div>
    """

    model_explanation = """
    <div class="model-explanation">
        <h2>How Our Model Works</h2>
        <p>Our system uses a Vision Transformer (ViT) model to analyze chest X-ray images. The attention heatmap visualizes
        areas the AI focuses on while making its diagnosis, helping make the decision-making process more transparent.
        The warmer colors (red/yellow) indicate areas of higher attention.</p>
        <p>Credits: The attention heatmap visualization is implemented using the attention rollout technique by
        <a href="https://github.com/jacobgil/vit-explain" target="_blank">jacobgil</a>.</p>
    </div>
    """

    pneumonia_info = """
    <div class="pneumonia-info">
        <h2>Understanding Pneumonia</h2>
        <p>Pneumonia is an infection that inflames the air sacs in one or both lungs. Common symptoms include:</p>
        <ul>
            <li>Chest pain when breathing or coughing</li>
            <li>Cough with phlegm or pus</li>
            <li>Fatigue and difficulty breathing</li>
            <li>Fever, sweating, and shaking chills</li>
        </ul>
        <p>Prevention tips:</p>
        <ul>
            <li>Get vaccinated</li>
            <li>Practice good hygiene</li>
            <li>Don't smoke</li>
            <li>Maintain a strong immune system</li>
        </ul>
    </div>
    """

    disclaimer = """
    <div class="disclaimer">
        <h3>Disclaimer</h3>
        <p>This tool is for educational purposes only and should not be used as a substitute for professional medical advice,
        diagnosis, or treatment. Always seek the advice of your physician or other qualified health provider.</p>
        <p>Created by the team at VIT-AP University. View the source code on
        <a href="https://github.com/Mahatir-Ahmed-Tusher/PneumoInsight" target="_blank">GitHub</a>.</p>
    </div>
    """

    def enhanced_classification(img):
        # Classify, then augment the result with an LLM-generated explanation.
        # Gradio passes None when no image has been uploaded.
        if img is None:
            return None, None, "Please upload an image to proceed."

        result = classify_and_visualize(img)
        probabilities = result["probabilities"]
        heatmap = result["heatmap"]

        # Get the predicted class (the name with the highest probability).
        predicted_class = max(probabilities.items(), key=lambda x: x[1])[0]

        # Get AI explanation (network call to the Groq API).
        ai_explanation = get_ai_explanation(predicted_class, probabilities)

        return probabilities, heatmap, ai_explanation

    # Create the Gradio interface
    # NOTE(review): `title` is given raw HTML; gr.Interface may escape it, so
    # confirm the logo actually renders (gr.Blocks + gr.HTML is the safe path).
    # NOTE(review): load_examples_from_folder raises if ./Examples is missing —
    # confirm the folder ships with the app.
    iface = gr.Interface(
        fn=enhanced_classification,
        inputs=gr.Image(type="pil", label="Upload Chest X-Ray Image"),
        outputs=[
            gr.Label(label="Diagnosis Probabilities"),
            gr.Image(label="Attention Heatmap"),
            gr.Textbox(label="AI Analysis and Recommendations", lines=10)
        ],
        css=custom_css,
        examples=load_examples_from_folder("./Examples"),
        cache_examples=False,
        article=model_explanation + pneumonia_info + disclaimer,
        description=welcome_message,
        title=logo_html,
        theme=gr.themes.Soft()
    )

    return iface
164
 
165
# Launch the app only when executed as a script (not when imported).
if __name__ == "__main__":
    app = create_interface()
    app.launch(debug=True)