Update utils.py
Browse filesmoved plotting function here
utils.py
CHANGED
|
@@ -20,4 +20,28 @@ def predict(model, sequence):
|
|
| 20 |
probabilities = F.softmax(output.logits, dim=-1)
|
| 21 |
predicted_label = torch.argmax(probabilities, dim=-1)
|
| 22 |
confidence = probabilities.max().item() * 0.85
|
| 23 |
-
return predicted_label.item(), confidence
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
probabilities = F.softmax(output.logits, dim=-1)
|
| 21 |
predicted_label = torch.argmax(probabilities, dim=-1)
|
| 22 |
confidence = probabilities.max().item() * 0.85
|
| 23 |
+
return predicted_label.item(), confidence
|
| 24 |
+
|
| 25 |
+
def plot_prediction_graphs(data):
|
| 26 |
+
# Create a color palette that is consistent across graphs
|
| 27 |
+
unique_sequences = sorted(set(seq for seq in data))
|
| 28 |
+
palette = sns.color_palette("hsv", len(unique_sequences))
|
| 29 |
+
color_dict = {seq: color for seq, color in zip(unique_sequences, palette)}
|
| 30 |
+
|
| 31 |
+
for model_name in models.keys():
|
| 32 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6), sharey=True)
|
| 33 |
+
for prediction_val in [0, 1]:
|
| 34 |
+
ax = ax1 if prediction_val == 0 else ax2
|
| 35 |
+
filtered_data = {seq: values[model_name] for seq, values in data.items() if values[model_name][0] == prediction_val}
|
| 36 |
+
# Sorting sequences based on confidence, descending
|
| 37 |
+
sorted_sequences = sorted(filtered_data.items(), key=lambda x: x[1][1], reverse=True)
|
| 38 |
+
sequences = [x[0] for x in sorted_sequences]
|
| 39 |
+
conf_values = [x[1][1] for x in sorted_sequences]
|
| 40 |
+
colors = [color_dict[seq] for seq in sequences]
|
| 41 |
+
sns.barplot(x=sequences, y=conf_values, palette=colors, ax=ax)
|
| 42 |
+
ax.set_title(f'Confidence Scores for {model_name.capitalize()} (Prediction {prediction_val})')
|
| 43 |
+
ax.set_xlabel('Sequences')
|
| 44 |
+
ax.set_ylabel('Confidence')
|
| 45 |
+
ax.tick_params(axis='x', rotation=45) # Rotate x labels for better visibility
|
| 46 |
+
|
| 47 |
+
st.pyplot(fig)
|