Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
-
from utils import validate_sequence, predict
|
| 3 |
from model import models
|
| 4 |
import pandas as pd
|
| 5 |
import matplotlib.pyplot as plt
|
|
@@ -56,29 +56,6 @@ def main():
|
|
| 56 |
st.write("## Graphs")
|
| 57 |
plot_prediction_graphs(all_data)
|
| 58 |
|
| 59 |
-
def plot_prediction_graphs(data):
|
| 60 |
-
# Create a color palette that is consistent across graphs
|
| 61 |
-
unique_sequences = sorted(set(seq for seq in data))
|
| 62 |
-
palette = sns.color_palette("hsv", len(unique_sequences))
|
| 63 |
-
color_dict = {seq: color for seq, color in zip(unique_sequences, palette)}
|
| 64 |
-
|
| 65 |
-
for model_name in models.keys():
|
| 66 |
-
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6), sharey=True)
|
| 67 |
-
for prediction_val in [0, 1]:
|
| 68 |
-
ax = ax1 if prediction_val == 0 else ax2
|
| 69 |
-
filtered_data = {seq: values[model_name] for seq, values in data.items() if values[model_name][0] == prediction_val}
|
| 70 |
-
# Sorting sequences based on confidence, descending
|
| 71 |
-
sorted_sequences = sorted(filtered_data.items(), key=lambda x: x[1][1], reverse=True)
|
| 72 |
-
sequences = [x[0] for x in sorted_sequences]
|
| 73 |
-
conf_values = [x[1][1] for x in sorted_sequences]
|
| 74 |
-
colors = [color_dict[seq] for seq in sequences]
|
| 75 |
-
sns.barplot(x=sequences, y=conf_values, palette=colors, ax=ax)
|
| 76 |
-
ax.set_title(f'Confidence Scores for {model_name.capitalize()} (Prediction {prediction_val})')
|
| 77 |
-
ax.set_xlabel('Sequences')
|
| 78 |
-
ax.set_ylabel('Confidence')
|
| 79 |
-
ax.tick_params(axis='x', rotation=45) # Rotate x labels for better visibility
|
| 80 |
-
|
| 81 |
-
st.pyplot(fig) # Display the plot with two subplots below the results table
|
| 82 |
|
| 83 |
if __name__ == "__main__":
|
| 84 |
main()
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
from utils import validate_sequence, predict, plot_prediction_graphs
|
| 3 |
from model import models
|
| 4 |
import pandas as pd
|
| 5 |
import matplotlib.pyplot as plt
|
|
|
|
| 56 |
st.write("## Graphs")
|
| 57 |
plot_prediction_graphs(all_data)
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
if __name__ == "__main__":
|
| 61 |
main()
|