Tingxie committed on
Commit
93c8da3
·
verified ·
1 Parent(s): 00c1ca5

Prevent Gradio startup crash via deferred DB loading

Browse files
Files changed (1) hide show
  1. app.py +27 -11
app.py CHANGED
@@ -91,7 +91,7 @@ custom_css = """
91
  height:190px !important;
92
  }
93
  .gallery-height {
94
- height: 380px !important;
95
  }
96
  #custom_plot {
97
  height: 300px !important;
@@ -128,11 +128,25 @@ def search_structure_from_mass(structureDB,mass, ppm):
128
  structures = structureDB[(structureDB['MonoisotopicMass'] >= mmin) & (structureDB['MonoisotopicMass'] <= mmax)]
129
  return structures
130
 
131
- dataset_repo = "Tingxie/CSU-MS2-DB"
132
- db_filename = "csu_ms2_db.db"
133
- token = os.getenv("HF_TOKEN")
134
- db_path = hf_hub_download(repo_id=dataset_repo, filename=db_filename, repo_type="dataset", token=token)
135
- conn = sqlite3.connect(db_path, check_same_thread=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
  device='cpu'
138
  pretrain_model_path_low,pretrain_model_path_median,pretrain_model_path_high='model/low_energy/checkpoints/model.pth','model/median_energy/checkpoints/model.pth','model/high_energy/checkpoints/model.pth'
@@ -192,8 +206,8 @@ def MS2Embedding(spectra):
192
  def calculate_cosine_similarity(vector1, vector2):
193
  return cosine_similarity(vector1.reshape(1, -1), vector2.reshape(1, -1))[0][0]
194
 
195
- def retrieve_similarity_scores( table_name, target_mass,collision_energy, ms2_embedding_low, ms2_embedding_median, ms2_embedding_high):
196
- cur = conn.cursor()
197
  if table_name == "CSU_MS2_DB":
198
  table_name = 'ConSSDB'
199
  if table_name == "BloodExp: blood exposome database":
@@ -312,14 +326,14 @@ def get_topK_result(library,ms_feature, smiles_feature, topK):
312
  return indices, scores, candidates
313
 
314
 
315
- def rank_lib(database_name,spectrum_path,instrument_type,adduct,parent_Mass,collision_energy):
316
  ms2 = list(load_from_msp(spectrum_path.name))[0]
317
  ms2 = spectrum_processing(ms2)
318
  collision_energy=float(collision_energy)
319
  parent_Mass=float(parent_Mass)
320
  ms2_embedding_low,ms2_embedding_median,ms2_embedding_high = MS2Embedding(ms2)
321
  ms2_embedding_low,ms2_embedding_median,ms2_embedding_high = torch.tensor(ms2_embedding_low).float(),torch.tensor(ms2_embedding_median).float(),torch.tensor(ms2_embedding_high).float()
322
- top_10_smiles = retrieve_similarity_scores(database_name,parent_Mass,collision_energy,ms2_embedding_low,ms2_embedding_median,ms2_embedding_high)
323
  smis = [x[0][0] for x in top_10_smiles]
324
  scores = [x[1] for x in top_10_smiles]
325
  images,image_descrips=[],[]
@@ -411,6 +425,7 @@ def rank_user_lib(candidate_file,spectrum_path,instrument_type,adduct,parent_Mas
411
  with gr.Blocks(theme=seafoam) as demo:
412
  gr.HTML(custom_css)
413
  gr.Markdown('<div style="font-size:50px; font-weight:bold;">🔍 CSU-MS2 web server </div>')
 
414
  with gr.Row():
415
  with gr.Column():
416
  peak_data = gr.File(file_count="single", label="Upload MS/MS spectrum file in .msp format", elem_classes=".file-upload-height")
@@ -472,7 +487,8 @@ with gr.Blocks(theme=seafoam) as demo:
472
  with gr.Column():
473
  user_button = gr.Button("Cross-Modal Retrieval")
474
  user_output = gr.Gallery(height='auto',columns=4,elem_classes="gallery-height",label='Cross-modal retrieval results')
475
- lib_button.click(rank_lib, inputs=[dataset,peak_data,instru,ionmode,par_ion_mass,collision_e], outputs=lib_output)
 
476
  user_button.click(rank_user_lib, inputs=[user_dataset,peak_data,instru,ionmode,par_ion_mass,collision_e], outputs=user_output)
477
  demo.launch(share=True)
478
 
 
91
  height:190px !important;
92
  }
93
  .gallery-height {
94
+ height: 350px !important;
95
  }
96
  #custom_plot {
97
  height: 300px !important;
 
128
  structures = structureDB[(structureDB['MonoisotopicMass'] >= mmin) & (structureDB['MonoisotopicMass'] <= mmax)]
129
  return structures
130
 
131
+ conn = None
132
+
133
+ def initialize_db():
134
+ global conn
135
+ if conn is None:
136
+ dataset_repo = "Tingxie/CSU-MS2-DB"
137
+ db_filename = "csu_ms2_db.db"
138
+ token = os.getenv("HF_TOKEN")
139
+ print("Starting large file download and DB connection...")
140
+ db_path = hf_hub_download(repo_id=dataset_repo, filename=db_filename, repo_type="dataset", token=token)
141
+ conn = sqlite3.connect(db_path, check_same_thread=False)
142
+ print("DB initialization complete.")
143
+ return conn
144
+
145
+ #dataset_repo = "Tingxie/CSU-MS2-DB"
146
+ #db_filename = "csu_ms2_db.db"
147
+ #token = os.getenv("HF_TOKEN")
148
+ #db_path = hf_hub_download(repo_id=dataset_repo, filename=db_filename, repo_type="dataset", token=token)
149
+ #conn = sqlite3.connect(db_path, check_same_thread=False)
150
 
151
  device='cpu'
152
  pretrain_model_path_low,pretrain_model_path_median,pretrain_model_path_high='model/low_energy/checkpoints/model.pth','model/median_energy/checkpoints/model.pth','model/high_energy/checkpoints/model.pth'
 
206
  def calculate_cosine_similarity(vector1, vector2):
207
  return cosine_similarity(vector1.reshape(1, -1), vector2.reshape(1, -1))[0][0]
208
 
209
+ def retrieve_similarity_scores( conn_obj, table_name, target_mass,collision_energy, ms2_embedding_low, ms2_embedding_median, ms2_embedding_high):
210
+ cur = conn_obj.cursor()
211
  if table_name == "CSU_MS2_DB":
212
  table_name = 'ConSSDB'
213
  if table_name == "BloodExp: blood exposome database":
 
326
  return indices, scores, candidates
327
 
328
 
329
+ def rank_lib(conn_obj, database_name,spectrum_path,instrument_type,adduct,parent_Mass,collision_energy):
330
  ms2 = list(load_from_msp(spectrum_path.name))[0]
331
  ms2 = spectrum_processing(ms2)
332
  collision_energy=float(collision_energy)
333
  parent_Mass=float(parent_Mass)
334
  ms2_embedding_low,ms2_embedding_median,ms2_embedding_high = MS2Embedding(ms2)
335
  ms2_embedding_low,ms2_embedding_median,ms2_embedding_high = torch.tensor(ms2_embedding_low).float(),torch.tensor(ms2_embedding_median).float(),torch.tensor(ms2_embedding_high).float()
336
+ top_10_smiles = retrieve_similarity_scores(conn_obj, database_name,parent_Mass,collision_energy,ms2_embedding_low,ms2_embedding_median,ms2_embedding_high)
337
  smis = [x[0][0] for x in top_10_smiles]
338
  scores = [x[1] for x in top_10_smiles]
339
  images,image_descrips=[],[]
 
425
  with gr.Blocks(theme=seafoam) as demo:
426
  gr.HTML(custom_css)
427
  gr.Markdown('<div style="font-size:50px; font-weight:bold;">🔍 CSU-MS2 web server </div>')
428
+ db_conn_state = gr.State(None)
429
  with gr.Row():
430
  with gr.Column():
431
  peak_data = gr.File(file_count="single", label="Upload MS/MS spectrum file in .msp format", elem_classes=".file-upload-height")
 
487
  with gr.Column():
488
  user_button = gr.Button("Cross-Modal Retrieval")
489
  user_output = gr.Gallery(height='auto',columns=4,elem_classes="gallery-height",label='Cross-modal retrieval results')
490
+ demo.load(fn=initialize_db, inputs=None, outputs=db_conn_state, queue=True, show_progress="full")
491
+ lib_button.click(rank_lib, inputs=[db_conn_state, dataset,peak_data,instru,ionmode,par_ion_mass,collision_e], outputs=lib_output)
492
  user_button.click(rank_user_lib, inputs=[user_dataset,peak_data,instru,ionmode,par_ion_mass,collision_e], outputs=user_output)
493
  demo.launch(share=True)
494