Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -224,6 +224,28 @@ def compute_gc_content(sequence):
|
|
| 224 |
# 7. MAIN ANALYSIS STEP (Gradio Step 1)
|
| 225 |
###############################################################################
|
| 226 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
|
| 228 |
if fasta_text.strip():
|
| 229 |
text = fasta_text.strip()
|
|
@@ -232,13 +254,13 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
|
|
| 232 |
with open(file_obj, 'r') as f:
|
| 233 |
text = f.read()
|
| 234 |
except Exception as e:
|
| 235 |
-
return (f"Error reading file: {str(e)}", None, None, None, None, None)
|
| 236 |
else:
|
| 237 |
-
return ("Please provide a FASTA sequence.", None, None, None, None, None)
|
| 238 |
|
| 239 |
sequences = parse_fasta(text)
|
| 240 |
if not sequences:
|
| 241 |
-
return ("No valid FASTA sequences found.", None, None, None, None, None)
|
| 242 |
header, seq = sequences[0]
|
| 243 |
|
| 244 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
@@ -249,7 +271,7 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
|
|
| 249 |
model.load_state_dict(state_dict)
|
| 250 |
scaler = joblib.load('scaler.pkl')
|
| 251 |
except Exception as e:
|
| 252 |
-
return (f"Error loading model/scaler: {str(e)}", None, None, None, None, None)
|
| 253 |
|
| 254 |
freq_vector = sequence_to_kmer_vector(seq)
|
| 255 |
scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
|
|
@@ -284,11 +306,13 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
|
|
| 284 |
heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
|
| 285 |
heatmap_img = fig_to_image(heatmap_fig)
|
| 286 |
|
| 287 |
-
#
|
| 288 |
-
|
|
|
|
|
|
|
| 289 |
state_dict_out = {"seq": seq, "shap_means": shap_means}
|
| 290 |
|
| 291 |
-
return (results_text, bar_img, heatmap_img, state_dict_out, header, None)
|
| 292 |
|
| 293 |
###############################################################################
|
| 294 |
# 8. SUBREGION ANALYSIS (Gradio Step 2)
|
|
@@ -963,9 +987,22 @@ def prepare_csv_download(data, filename="analysis_results.csv"):
|
|
| 963 |
return output.getvalue().encode(), filename
|
| 964 |
else:
|
| 965 |
raise ValueError("Unsupported data type for CSV download")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 966 |
|
| 967 |
###############################################################################
|
| 968 |
-
#
|
| 969 |
###############################################################################
|
| 970 |
|
| 971 |
css = """
|
|
@@ -993,6 +1030,10 @@ with gr.Blocks(css=css) as iface:
|
|
| 993 |
with gr.Column(scale=1):
|
| 994 |
file_input = gr.File(label="Upload FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
|
| 995 |
text_input = gr.Textbox(label="Or paste FASTA sequence", placeholder=">sequence_name\nACGTACGT...", lines=5)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 996 |
top_k = gr.Slider(minimum=5, maximum=30, value=10, step=1, label="Number of top k-mers to display")
|
| 997 |
win_size = gr.Slider(minimum=100, maximum=5000, value=500, step=100, label="Window size for 'most pushing' subregions")
|
| 998 |
analyze_btn = gr.Button("Analyze Sequence", variant="primary")
|
|
@@ -1000,14 +1041,25 @@ with gr.Blocks(css=css) as iface:
|
|
| 1000 |
results_box = gr.Textbox(label="Classification Results", lines=12, interactive=False)
|
| 1001 |
kmer_img = gr.Image(label="Top k-mer SHAP")
|
| 1002 |
genome_img = gr.Image(label="Genome-wide SHAP Heatmap (Blue=neg, White=0, Red=pos)")
|
| 1003 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1004 |
seq_state = gr.State()
|
| 1005 |
header_state = gr.State()
|
| 1006 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1007 |
analyze_btn.click(
|
| 1008 |
analyze_sequence,
|
| 1009 |
inputs=[file_input, top_k, text_input, win_size],
|
| 1010 |
-
outputs=[results_box, kmer_img, genome_img, seq_state, header_state, download_results]
|
| 1011 |
)
|
| 1012 |
|
| 1013 |
with gr.Tab("2) Subregion Exploration"):
|
|
@@ -1114,8 +1166,9 @@ with gr.Blocks(css=css) as iface:
|
|
| 1114 |
- Statistical summary of differences
|
| 1115 |
- **Data Export**:
|
| 1116 |
- Download results as CSV files
|
|
|
|
| 1117 |
- Save analysis outputs for further processing
|
| 1118 |
""")
|
| 1119 |
|
| 1120 |
if __name__ == "__main__":
|
| 1121 |
-
iface.launch()
|
|
|
|
| 224 |
# 7. MAIN ANALYSIS STEP (Gradio Step 1)
|
| 225 |
###############################################################################
|
| 226 |
|
| 227 |
+
def create_kmer_shap_csv(kmers, shap_values):
    """Write k-mer SHAP values to a temporary CSV file and return its path.

    Rows are ordered by absolute SHAP value, largest first; rows with equal
    magnitude keep their input order. The CSV contains exactly two columns,
    ``kmer`` and ``shap_value``, and no index column.

    Args:
        kmers: Sequence of k-mer labels.
        shap_values: Sequence of numeric SHAP values aligned with ``kmers``
            (both become columns of the same DataFrame).

    Returns:
        Path of the CSV file created in the system temporary directory.
        This function does not delete the file.
    """
    kmer_df = pd.DataFrame({
        'kmer': kmers,
        'shap_value': shap_values,
        'abs_shap': np.abs(shap_values),
    })

    # Most influential k-mers first. mergesort is stable, so ties are
    # emitted deterministically in their original order.
    kmer_df = kmer_df.sort_values('abs_shap', ascending=False, kind='mergesort')

    # abs_shap only exists to drive the sort; keep it out of the export.
    kmer_df = kmer_df[['kmer', 'shap_value']]

    # mkstemp creates the file atomically under a guaranteed-unique name,
    # unlike a hand-built random filename, which may collide or be
    # pre-created by another process.
    fd, temp_path = tempfile.mkstemp(prefix="kmer_shap_values_", suffix=".csv")
    os.close(fd)  # pandas reopens the file by path below
    kmer_df.to_csv(temp_path, index=False)

    return temp_path
|
| 248 |
+
|
| 249 |
def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
|
| 250 |
if fasta_text.strip():
|
| 251 |
text = fasta_text.strip()
|
|
|
|
| 254 |
with open(file_obj, 'r') as f:
|
| 255 |
text = f.read()
|
| 256 |
except Exception as e:
|
| 257 |
+
return (f"Error reading file: {str(e)}", None, None, None, None, None, None)
|
| 258 |
else:
|
| 259 |
+
return ("Please provide a FASTA sequence.", None, None, None, None, None, None)
|
| 260 |
|
| 261 |
sequences = parse_fasta(text)
|
| 262 |
if not sequences:
|
| 263 |
+
return ("No valid FASTA sequences found.", None, None, None, None, None, None)
|
| 264 |
header, seq = sequences[0]
|
| 265 |
|
| 266 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
| 271 |
model.load_state_dict(state_dict)
|
| 272 |
scaler = joblib.load('scaler.pkl')
|
| 273 |
except Exception as e:
|
| 274 |
+
return (f"Error loading model/scaler: {str(e)}", None, None, None, None, None, None)
|
| 275 |
|
| 276 |
freq_vector = sequence_to_kmer_vector(seq)
|
| 277 |
scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
|
|
|
|
| 306 |
heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
|
| 307 |
heatmap_img = fig_to_image(heatmap_fig)
|
| 308 |
|
| 309 |
+
# Create CSV with k-mer SHAP values and return the file path
|
| 310 |
+
kmer_shap_csv = create_kmer_shap_csv(kmers, shap_values)
|
| 311 |
+
|
| 312 |
+
# State dictionary for subregion analysis
|
| 313 |
state_dict_out = {"seq": seq, "shap_means": shap_means}
|
| 314 |
|
| 315 |
+
return (results_text, bar_img, heatmap_img, state_dict_out, header, None, kmer_shap_csv)
|
| 316 |
|
| 317 |
###############################################################################
|
| 318 |
# 8. SUBREGION ANALYSIS (Gradio Step 2)
|
|
|
|
| 987 |
return output.getvalue().encode(), filename
|
| 988 |
else:
|
| 989 |
raise ValueError("Unsupported data type for CSV download")
|
| 990 |
+
|
| 991 |
+
###############################################################################
|
| 992 |
+
# 13. EXAMPLE FASTA LOADER
|
| 993 |
+
###############################################################################
|
| 994 |
+
|
| 995 |
+
def load_example_fasta():
    """Return the text of ``example.fasta`` from the current working directory.

    Returns:
        The file contents as a string. If the file cannot be opened or
        decoded, a placeholder FASTA snippet is returned instead, with a note
        describing the error appended, so the caller always gets text.
    """
    try:
        # Explicit encoding: do not depend on the platform default locale.
        with open('example.fasta', 'r', encoding='utf-8') as f:
            return f.read()
    except (OSError, UnicodeError) as e:
        # Only I/O and decoding failures are expected here; anything else is
        # a programming error and should surface rather than be masked.
        return f">example_sequence\nACGTACGT...\n\n(Note: Could not load example.fasta: {str(e)})"
|
| 1003 |
|
| 1004 |
###############################################################################
|
| 1005 |
+
# 14. BUILD GRADIO INTERFACE
|
| 1006 |
###############################################################################
|
| 1007 |
|
| 1008 |
css = """
|
|
|
|
| 1030 |
with gr.Column(scale=1):
|
| 1031 |
file_input = gr.File(label="Upload FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
|
| 1032 |
text_input = gr.Textbox(label="Or paste FASTA sequence", placeholder=">sequence_name\nACGTACGT...", lines=5)
|
| 1033 |
+
|
| 1034 |
+
with gr.Row():
|
| 1035 |
+
example_btn = gr.Button("Load Example FASTA", variant="secondary")
|
| 1036 |
+
|
| 1037 |
top_k = gr.Slider(minimum=5, maximum=30, value=10, step=1, label="Number of top k-mers to display")
|
| 1038 |
win_size = gr.Slider(minimum=100, maximum=5000, value=500, step=100, label="Window size for 'most pushing' subregions")
|
| 1039 |
analyze_btn = gr.Button("Analyze Sequence", variant="primary")
|
|
|
|
| 1041 |
results_box = gr.Textbox(label="Classification Results", lines=12, interactive=False)
|
| 1042 |
kmer_img = gr.Image(label="Top k-mer SHAP")
|
| 1043 |
genome_img = gr.Image(label="Genome-wide SHAP Heatmap (Blue=neg, White=0, Red=pos)")
|
| 1044 |
+
|
| 1045 |
+
with gr.Row():
|
| 1046 |
+
download_kmer_shap = gr.File(label="Download k-mer SHAP Values (CSV)", visible=True)
|
| 1047 |
+
download_results = gr.File(label="Download Results", visible=False, elem_classes="download-button")
|
| 1048 |
+
|
| 1049 |
seq_state = gr.State()
|
| 1050 |
header_state = gr.State()
|
| 1051 |
|
| 1052 |
+
# Event handlers
|
| 1053 |
+
example_btn.click(
|
| 1054 |
+
load_example_fasta,
|
| 1055 |
+
inputs=[],
|
| 1056 |
+
outputs=[text_input]
|
| 1057 |
+
)
|
| 1058 |
+
|
| 1059 |
analyze_btn.click(
|
| 1060 |
analyze_sequence,
|
| 1061 |
inputs=[file_input, top_k, text_input, win_size],
|
| 1062 |
+
outputs=[results_box, kmer_img, genome_img, seq_state, header_state, download_results, download_kmer_shap]
|
| 1063 |
)
|
| 1064 |
|
| 1065 |
with gr.Tab("2) Subregion Exploration"):
|
|
|
|
| 1166 |
- Statistical summary of differences
|
| 1167 |
- **Data Export**:
|
| 1168 |
- Download results as CSV files
|
| 1169 |
+
- Download k-mer SHAP values
|
| 1170 |
- Save analysis outputs for further processing
|
| 1171 |
""")
|
| 1172 |
|
| 1173 |
if __name__ == "__main__":
|
| 1174 |
+
iface.launch()
|