Spaces:
Sleeping
Sleeping
Prevent Gradio startup crash via deferred DB loading
Browse files
app.py
CHANGED
|
@@ -91,7 +91,7 @@ custom_css = """
|
|
| 91 |
height:190px !important;
|
| 92 |
}
|
| 93 |
.gallery-height {
|
| 94 |
-
height:
|
| 95 |
}
|
| 96 |
#custom_plot {
|
| 97 |
height: 300px !important;
|
|
@@ -128,11 +128,25 @@ def search_structure_from_mass(structureDB,mass, ppm):
|
|
| 128 |
structures = structureDB[(structureDB['MonoisotopicMass'] >= mmin) & (structureDB['MonoisotopicMass'] <= mmax)]
|
| 129 |
return structures
|
| 130 |
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
conn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
device='cpu'
|
| 138 |
pretrain_model_path_low,pretrain_model_path_median,pretrain_model_path_high='model/low_energy/checkpoints/model.pth','model/median_energy/checkpoints/model.pth','model/high_energy/checkpoints/model.pth'
|
|
@@ -192,8 +206,8 @@ def MS2Embedding(spectra):
|
|
| 192 |
def calculate_cosine_similarity(vector1, vector2):
|
| 193 |
return cosine_similarity(vector1.reshape(1, -1), vector2.reshape(1, -1))[0][0]
|
| 194 |
|
| 195 |
-
def retrieve_similarity_scores( table_name, target_mass,collision_energy, ms2_embedding_low, ms2_embedding_median, ms2_embedding_high):
|
| 196 |
-
cur =
|
| 197 |
if table_name == "CSU_MS2_DB":
|
| 198 |
table_name = 'ConSSDB'
|
| 199 |
if table_name == "BloodExp: blood exposome database":
|
|
@@ -312,14 +326,14 @@ def get_topK_result(library,ms_feature, smiles_feature, topK):
|
|
| 312 |
return indices, scores, candidates
|
| 313 |
|
| 314 |
|
| 315 |
-
def rank_lib(database_name,spectrum_path,instrument_type,adduct,parent_Mass,collision_energy):
|
| 316 |
ms2 = list(load_from_msp(spectrum_path.name))[0]
|
| 317 |
ms2 = spectrum_processing(ms2)
|
| 318 |
collision_energy=float(collision_energy)
|
| 319 |
parent_Mass=float(parent_Mass)
|
| 320 |
ms2_embedding_low,ms2_embedding_median,ms2_embedding_high = MS2Embedding(ms2)
|
| 321 |
ms2_embedding_low,ms2_embedding_median,ms2_embedding_high = torch.tensor(ms2_embedding_low).float(),torch.tensor(ms2_embedding_median).float(),torch.tensor(ms2_embedding_high).float()
|
| 322 |
-
top_10_smiles = retrieve_similarity_scores(database_name,parent_Mass,collision_energy,ms2_embedding_low,ms2_embedding_median,ms2_embedding_high)
|
| 323 |
smis = [x[0][0] for x in top_10_smiles]
|
| 324 |
scores = [x[1] for x in top_10_smiles]
|
| 325 |
images,image_descrips=[],[]
|
|
@@ -411,6 +425,7 @@ def rank_user_lib(candidate_file,spectrum_path,instrument_type,adduct,parent_Mas
|
|
| 411 |
with gr.Blocks(theme=seafoam) as demo:
|
| 412 |
gr.HTML(custom_css)
|
| 413 |
gr.Markdown('<div style="font-size:50px; font-weight:bold;">🔍 CSU-MS2 web server </div>')
|
|
|
|
| 414 |
with gr.Row():
|
| 415 |
with gr.Column():
|
| 416 |
peak_data = gr.File(file_count="single", label="Upload MS/MS spectrum file in .msp format", elem_classes=".file-upload-height")
|
|
@@ -472,7 +487,8 @@ with gr.Blocks(theme=seafoam) as demo:
|
|
| 472 |
with gr.Column():
|
| 473 |
user_button = gr.Button("Cross-Modal Retrieval")
|
| 474 |
user_output = gr.Gallery(height='auto',columns=4,elem_classes="gallery-height",label='Cross-modal retrieval results')
|
| 475 |
-
|
|
|
|
| 476 |
user_button.click(rank_user_lib, inputs=[user_dataset,peak_data,instru,ionmode,par_ion_mass,collision_e], outputs=user_output)
|
| 477 |
demo.launch(share=True)
|
| 478 |
|
|
|
|
| 91 |
height:190px !important;
|
| 92 |
}
|
| 93 |
.gallery-height {
|
| 94 |
+
height: 350px !important;
|
| 95 |
}
|
| 96 |
#custom_plot {
|
| 97 |
height: 300px !important;
|
|
|
|
| 128 |
structures = structureDB[(structureDB['MonoisotopicMass'] >= mmin) & (structureDB['MonoisotopicMass'] <= mmax)]
|
| 129 |
return structures
|
| 130 |
|
| 131 |
+
conn = None
|
| 132 |
+
|
| 133 |
+
def initialize_db():
|
| 134 |
+
global conn
|
| 135 |
+
if conn is None:
|
| 136 |
+
dataset_repo = "Tingxie/CSU-MS2-DB"
|
| 137 |
+
db_filename = "csu_ms2_db.db"
|
| 138 |
+
token = os.getenv("HF_TOKEN")
|
| 139 |
+
print("Starting large file download and DB connection...")
|
| 140 |
+
db_path = hf_hub_download(repo_id=dataset_repo, filename=db_filename, repo_type="dataset", token=token)
|
| 141 |
+
conn = sqlite3.connect(db_path, check_same_thread=False)
|
| 142 |
+
print("DB initialization complete.")
|
| 143 |
+
return conn
|
| 144 |
+
|
| 145 |
+
#dataset_repo = "Tingxie/CSU-MS2-DB"
|
| 146 |
+
#db_filename = "csu_ms2_db.db"
|
| 147 |
+
#token = os.getenv("HF_TOKEN")
|
| 148 |
+
#db_path = hf_hub_download(repo_id=dataset_repo, filename=db_filename, repo_type="dataset", token=token)
|
| 149 |
+
#conn = sqlite3.connect(db_path, check_same_thread=False)
|
| 150 |
|
| 151 |
device='cpu'
|
| 152 |
pretrain_model_path_low,pretrain_model_path_median,pretrain_model_path_high='model/low_energy/checkpoints/model.pth','model/median_energy/checkpoints/model.pth','model/high_energy/checkpoints/model.pth'
|
|
|
|
| 206 |
def calculate_cosine_similarity(vector1, vector2):
|
| 207 |
return cosine_similarity(vector1.reshape(1, -1), vector2.reshape(1, -1))[0][0]
|
| 208 |
|
| 209 |
+
def retrieve_similarity_scores( conn_obj, table_name, target_mass,collision_energy, ms2_embedding_low, ms2_embedding_median, ms2_embedding_high):
|
| 210 |
+
cur = conn_obj.cursor()
|
| 211 |
if table_name == "CSU_MS2_DB":
|
| 212 |
table_name = 'ConSSDB'
|
| 213 |
if table_name == "BloodExp: blood exposome database":
|
|
|
|
| 326 |
return indices, scores, candidates
|
| 327 |
|
| 328 |
|
| 329 |
+
def rank_lib(conn_obj, database_name,spectrum_path,instrument_type,adduct,parent_Mass,collision_energy):
|
| 330 |
ms2 = list(load_from_msp(spectrum_path.name))[0]
|
| 331 |
ms2 = spectrum_processing(ms2)
|
| 332 |
collision_energy=float(collision_energy)
|
| 333 |
parent_Mass=float(parent_Mass)
|
| 334 |
ms2_embedding_low,ms2_embedding_median,ms2_embedding_high = MS2Embedding(ms2)
|
| 335 |
ms2_embedding_low,ms2_embedding_median,ms2_embedding_high = torch.tensor(ms2_embedding_low).float(),torch.tensor(ms2_embedding_median).float(),torch.tensor(ms2_embedding_high).float()
|
| 336 |
+
top_10_smiles = retrieve_similarity_scores(conn_obj, database_name,parent_Mass,collision_energy,ms2_embedding_low,ms2_embedding_median,ms2_embedding_high)
|
| 337 |
smis = [x[0][0] for x in top_10_smiles]
|
| 338 |
scores = [x[1] for x in top_10_smiles]
|
| 339 |
images,image_descrips=[],[]
|
|
|
|
| 425 |
with gr.Blocks(theme=seafoam) as demo:
|
| 426 |
gr.HTML(custom_css)
|
| 427 |
gr.Markdown('<div style="font-size:50px; font-weight:bold;">🔍 CSU-MS2 web server </div>')
|
| 428 |
+
db_conn_state = gr.State(None)
|
| 429 |
with gr.Row():
|
| 430 |
with gr.Column():
|
| 431 |
peak_data = gr.File(file_count="single", label="Upload MS/MS spectrum file in .msp format", elem_classes=".file-upload-height")
|
|
|
|
| 487 |
with gr.Column():
|
| 488 |
user_button = gr.Button("Cross-Modal Retrieval")
|
| 489 |
user_output = gr.Gallery(height='auto',columns=4,elem_classes="gallery-height",label='Cross-modal retrieval results')
|
| 490 |
+
demo.load(fn=initialize_db, inputs=None, outputs=db_conn_state, queue=True, show_progress="full")
|
| 491 |
+
lib_button.click(rank_lib, inputs=[db_conn_state, dataset,peak_data,instru,ionmode,par_ion_mass,collision_e], outputs=lib_output)
|
| 492 |
user_button.click(rank_user_lib, inputs=[user_dataset,peak_data,instru,ionmode,par_ion_mass,collision_e], outputs=user_output)
|
| 493 |
demo.launch(share=True)
|
| 494 |
|