MahatirTusher commited on
Commit
8170a5d
·
verified ·
1 Parent(s): 4793156

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -2
app.py CHANGED
@@ -55,8 +55,91 @@ def get_ai_explanation(diagnosis, probabilities):
55
 
56
  return completion.choices[0].message.content
57
 
58
- # Rest of your existing functions (classify_and_visualize, show_final_layer_attention_maps, etc.) remain the same
59
- [Previous functions remain unchanged...]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  def create_interface():
62
  # Custom CSS
 
55
 
56
  return completion.choices[0].message.content
57
 
58
+ def classify_and_visualize(img, device="cpu", discard_ratio=0.9, head_fusion="mean"):
59
+ img = img.convert("RGB")
60
+ processed_input = processor(images=img, return_tensors="pt").to(device)
61
+ processed_input = processed_input["pixel_values"].to(device)
62
+
63
+ with torch.no_grad():
64
+ outputs = model(processed_input, output_attentions=True)
65
+ logits = outputs.logits
66
+ probabilities = torch.softmax(logits, dim=1)[0].tolist()
67
+ prediction = torch.argmax(logits, dim=-1).item()
68
+ predicted_class = class_names[prediction]
69
+
70
+ result = {class_name: prob for class_name, prob in zip(class_names, probabilities)}
71
+
72
+ # Generate attention heatmap
73
+ heatmap_img = show_final_layer_attention_maps(
74
+ outputs, processed_input, device, discard_ratio, head_fusion
75
+ )
76
+
77
+ return {"probabilities": result, "heatmap": heatmap_img}
78
+
79
+ def show_final_layer_attention_maps(
80
+ outputs,
81
+ processed_input,
82
+ device,
83
+ discard_ratio=0.6,
84
+ head_fusion="max",
85
+ only_last_layer=False,
86
+ ):
87
+ with torch.no_grad():
88
+ image = processed_input.squeeze(0)
89
+ image = image - image.min()
90
+ image = image / image.max()
91
+
92
+ result = torch.eye(outputs.attentions[0].size(-1)).to(device)
93
+ if only_last_layer:
94
+ attention_list = outputs.attentions[-1].unsqueeze(0).to(device)
95
+ else:
96
+ attention_list = outputs.attentions
97
+
98
+ for attention in attention_list:
99
+ if head_fusion == "mean":
100
+ attention_heads_fused = attention.mean(axis=1)
101
+ elif head_fusion == "max":
102
+ attention_heads_fused = attention.max(axis=1)[0]
103
+ elif head_fusion == "min":
104
+ attention_heads_fused = attention.min(axis=1)[0]
105
+
106
+ flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
107
+ _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
108
+ indices = indices[indices != 0]
109
+ flat[0, indices] = 0
110
+
111
+ I = torch.eye(attention_heads_fused.size(-1)).to(device)
112
+ a = (attention_heads_fused + 1.0 * I) / 2
113
+ a = a / a.sum(dim=-1)
114
+ result = torch.matmul(a, result)
115
+
116
+ mask = result[0, 0, 1:]
117
+ width = int(mask.size(-1) ** 0.5)
118
+ mask = mask.reshape(width, width).cpu().numpy()
119
+ mask = mask / np.max(mask)
120
+
121
+ mask = cv2.resize(mask, (224, 224))
122
+ mask = (mask - np.min(mask)) / (np.max(mask) - np.min(mask))
123
+ heatmap = plt.cm.jet(mask)[:, :, :3]
124
+
125
+ showed_img = image.permute(1, 2, 0).detach().cpu().numpy()
126
+ showed_img = (showed_img - np.min(showed_img)) / (
127
+ np.max(showed_img) - np.min(showed_img)
128
+ )
129
+ superimposed_img = heatmap * 0.4 + showed_img * 0.6
130
+
131
+ superimposed_img_pil = Image.fromarray(
132
+ (superimposed_img * 255).astype(np.uint8)
133
+ )
134
+ return superimposed_img_pil
135
+
136
+ def load_examples_from_folder(folder_path):
137
+ examples = []
138
+ if os.path.exists(folder_path):
139
+ for file in os.listdir(folder_path):
140
+ if file.endswith((".png", ".jpg", ".jpeg")):
141
+ examples.append(os.path.join(folder_path, file))
142
+ return examples
143
 
144
  def create_interface():
145
  # Custom CSS