Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -93,7 +93,11 @@ def create_importance_plot(shap_values, kmers, top_k=10):
|
|
| 93 |
"""
|
| 94 |
Create horizontal bar plot of feature importance.
|
| 95 |
"""
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
fig = plt.figure(figsize=(10, 8))
|
| 98 |
|
| 99 |
# Sort by absolute importance
|
|
@@ -115,8 +119,13 @@ def create_contribution_plot(important_kmers, final_prob):
|
|
| 115 |
"""
|
| 116 |
Create waterfall plot showing cumulative feature contributions.
|
| 117 |
"""
|
| 118 |
-
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
base_prob = 0.5
|
| 122 |
cumulative = [base_prob]
|
|
@@ -126,15 +135,36 @@ def create_contribution_plot(important_kmers, final_prob):
|
|
| 126 |
cumulative.append(cumulative[-1] + kmer_info['impact'])
|
| 127 |
labels.append(kmer_info['kmer'])
|
| 128 |
|
| 129 |
-
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
| 137 |
|
|
|
|
|
|
|
| 138 |
return fig
|
| 139 |
|
| 140 |
def predict(file_obj, top_kmers=10, fasta_text=""):
|
|
@@ -165,7 +195,8 @@ def predict(file_obj, top_kmers=10, fasta_text=""):
|
|
| 165 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 166 |
try:
|
| 167 |
model = VirusClassifier(256).to(device)
|
| 168 |
-
model
|
|
|
|
| 169 |
scaler = joblib.load('scaler.pkl')
|
| 170 |
except Exception as e:
|
| 171 |
return f"Error loading model: {str(e)}", None, None
|
|
|
|
| 93 |
"""
|
| 94 |
Create horizontal bar plot of feature importance.
|
| 95 |
"""
|
| 96 |
+
# Set style directly instead of using seaborn
|
| 97 |
+
plt.rcParams['figure.facecolor'] = '#ffffff'
|
| 98 |
+
plt.rcParams['axes.facecolor'] = '#ffffff'
|
| 99 |
+
plt.rcParams['axes.grid'] = True
|
| 100 |
+
plt.rcParams['grid.alpha'] = 0.3
|
| 101 |
fig = plt.figure(figsize=(10, 8))
|
| 102 |
|
| 103 |
# Sort by absolute importance
|
|
|
|
| 119 |
"""
|
| 120 |
Create waterfall plot showing cumulative feature contributions.
|
| 121 |
"""
|
| 122 |
+
# Set style parameters
|
| 123 |
+
plt.rcParams['figure.facecolor'] = '#ffffff'
|
| 124 |
+
plt.rcParams['axes.facecolor'] = '#ffffff'
|
| 125 |
+
plt.rcParams['axes.grid'] = True
|
| 126 |
+
plt.rcParams['grid.alpha'] = 0.3
|
| 127 |
+
|
| 128 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
| 129 |
|
| 130 |
base_prob = 0.5
|
| 131 |
cumulative = [base_prob]
|
|
|
|
| 135 |
cumulative.append(cumulative[-1] + kmer_info['impact'])
|
| 136 |
labels.append(kmer_info['kmer'])
|
| 137 |
|
| 138 |
+
# Plot cumulative line with markers
|
| 139 |
+
line = ax.plot(range(len(cumulative)), cumulative, '-o',
|
| 140 |
+
color='#3498db', linewidth=2,
|
| 141 |
+
marker='o', markersize=8,
|
| 142 |
+
markerfacecolor='white',
|
| 143 |
+
markeredgecolor='#3498db',
|
| 144 |
+
markeredgewidth=2)
|
| 145 |
+
|
| 146 |
+
# Add reference line at 0.5
|
| 147 |
+
ax.axhline(y=0.5, color='#95a5a6', linestyle='--', alpha=0.5)
|
| 148 |
+
|
| 149 |
+
# Customize plot
|
| 150 |
+
ax.set_xticks(range(len(labels)))
|
| 151 |
+
ax.set_xticklabels(labels, rotation=45, ha='right')
|
| 152 |
+
ax.set_ylim(0, 1)
|
| 153 |
+
ax.grid(True, axis='y', linestyle='--', alpha=0.3)
|
| 154 |
+
ax.set_title('Cumulative Feature Contributions')
|
| 155 |
+
ax.set_ylabel('Probability of Human Origin')
|
| 156 |
|
| 157 |
+
# Add value labels
|
| 158 |
+
for i, prob in enumerate(cumulative):
|
| 159 |
+
ax.annotate(f'{prob:.3f}',
|
| 160 |
+
(i, prob),
|
| 161 |
+
xytext=(0, 10),
|
| 162 |
+
textcoords='offset points',
|
| 163 |
+
ha='center',
|
| 164 |
+
va='bottom')
|
| 165 |
|
| 166 |
+
# Adjust layout to prevent label cutoff
|
| 167 |
+
plt.tight_layout()
|
| 168 |
return fig
|
| 169 |
|
| 170 |
def predict(file_obj, top_kmers=10, fasta_text=""):
|
|
|
|
| 195 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 196 |
try:
|
| 197 |
model = VirusClassifier(256).to(device)
|
| 198 |
+
# Load model weights safely
|
| 199 |
+
model.load_state_dict(torch.load('model.pt', map_location=device, weights_only=True))
|
| 200 |
scaler = joblib.load('scaler.pkl')
|
| 201 |
except Exception as e:
|
| 202 |
return f"Error loading model: {str(e)}", None, None
|