Anton Bushuiev committed on
Commit
bbf2542
·
1 Parent(s): 1dd40c0

Add basic images

Browse files
Files changed (1) hide show
  1. app.py +89 -11
app.py CHANGED
@@ -2,13 +2,18 @@ import gradio as gr
2
  import urllib.request
3
  import os
4
  from functools import partial
5
-
 
6
  import pandas as pd
7
  import numpy as np
8
  from pathlib import Path
9
  from tqdm import tqdm
10
  from sklearn.metrics.pairwise import cosine_similarity
11
  from rdkit import Chem
 
 
 
 
12
  import dreams.utils.spectra as su
13
  import dreams.utils.io as io
14
  from dreams.utils.spectra import PeakListModifiedCosine
@@ -17,7 +22,60 @@ from dreams.api import dreams_embeddings
17
  from dreams.definitions import *
18
 
19
 
20
- def setup():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # Download spectral library
22
  data_path = Path('./DreaMS/data')
23
  data_path.mkdir(parents=True, exist_ok=True)
@@ -38,7 +96,7 @@ def setup():
38
 
39
  def predict(lib_pth, in_pth):
40
  in_pth = Path(in_pth)
41
- # in_pth = Path('DreaMS/data/MSV000086206/peak/mzml/S_N1.mzML') # Example dataset
42
 
43
  msdata_lib = MSData.load(lib_pth)
44
  embs_lib = msdata_lib[DREAMS_EMBEDDING]
@@ -60,17 +118,21 @@ def predict(lib_pth, in_pth):
60
  cos_sim = su.PeakListModifiedCosine()
61
  for i, topk in enumerate(tqdm(topk_cands)):
62
  for n, j in enumerate(topk):
 
 
 
63
  df.append({
64
  'feature_id': i + 1,
65
  'topk': n + 1,
66
  'library_j': j,
67
- 'library_SMILES': msdata_lib.get_smiles(j),
 
68
  'library_ID': msdata_lib.get_values('IDENTIFIER', j),
69
  'DreaMS_similarity': sims[i, j],
70
  'Modified_cosine_similarity': cos_sim(
71
- spec1=msdata.get_spectra(i),
72
  prec_mz1=msdata.get_prec_mzs(i),
73
- spec2=msdata_lib.get_spectra(j),
74
  prec_mz2=msdata_lib.get_prec_mzs(j),
75
  ),
76
  'i': i,
@@ -78,7 +140,7 @@ def predict(lib_pth, in_pth):
78
  })
79
  df = pd.DataFrame(df)
80
 
81
- # TODO Add some (random) name to the output file
82
  df_path = io.append_to_stem(in_pth, 'MassSpecGym_hits').with_suffix('.csv')
83
  df.to_csv(df_path, index=False)
84
 
@@ -119,9 +181,18 @@ with app:
119
  with gr.Row(equal_height=True):
120
  in_pth = gr.File(
121
  file_count="single",
122
- label=".mzML file (TODO Extend to other formats)"
123
  )
124
  lib_pth = Path('DreaMS/data/MassSpecGym_DreaMS.hdf5') # MassSpecGym library
 
 
 
 
 
 
 
 
 
125
 
126
  # Predict GUI
127
  predict_button = gr.Button(value="Run library matching", variant="primary")
@@ -130,9 +201,16 @@ with app:
130
  gr.Markdown("## Predictions")
131
  df_file = gr.File(label="Download predictions as .csv", interactive=False, visible=True)
132
  df = gr.Dataframe(
133
- headers=["feature_id", "topk", "library_j", "library_SMILES", "library_ID", "DreaMS_similarity", "Modified_cosine_similarity", "i", "j"],
134
- datatype=["number", "number", "number", "str", "str", "number", "number", "number", "number"],
135
- col_count=(9, "fixed"),
 
 
 
 
 
 
 
136
  )
137
 
138
  # Main logic
 
2
  import urllib.request
3
  import os
4
  from functools import partial
5
+ import matplotlib.pyplot as plt
6
+ import matplotlib
7
  import pandas as pd
8
  import numpy as np
9
  from pathlib import Path
10
  from tqdm import tqdm
11
  from sklearn.metrics.pairwise import cosine_similarity
12
  from rdkit import Chem
13
+ from rdkit.Chem import Draw
14
+ from rdkit.Chem.Draw import rdMolDraw2D
15
+ import base64
16
+ from io import BytesIO
17
  import dreams.utils.spectra as su
18
  import dreams.utils.io as io
19
  from dreams.utils.spectra import PeakListModifiedCosine
 
22
  from dreams.definitions import *
23
 
24
 
25
def smiles_to_html_img(smiles, svg_size=1500):
    """
    Convert SMILES to HTML image string for display in Gradio dataframe.

    The molecule is drawn with RDKit onto a square SVG canvas (the background
    is not cleared) and embedded inline as a base64 ``data:`` URI.

    Args:
        smiles: SMILES string of the molecule to draw.
        svg_size: Side length in pixels; used for the SVG canvas and for the
            CSS width/height of the returned ``<img>`` tag.

    Returns:
        An HTML string: an ``<img>`` tag on success, or a red ``<div>`` with a
        message when the SMILES cannot be parsed or an exception is raised
        while drawing (the exception is reported in the HTML, not re-raised).
    """
    # Local import: the file's top-level import block does not provide `html`.
    import html

    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return "<div style='text-align: center; color: red;'>Invalid SMILES</div>"

        # Create SVG drawing
        d2d = rdMolDraw2D.MolDraw2DSVG(svg_size, svg_size)
        opts = d2d.drawOptions()
        opts.clearBackground = False
        d2d.DrawMolecule(mol)
        d2d.FinishDrawing()
        svg_str = d2d.GetDrawingText()

        # Convert to base64 for HTML embedding. base64 output is pure ASCII,
        # so decode it explicitly instead of slicing the bytes repr.
        encoded = base64.b64encode(svg_str.encode()).decode('ascii')
        img_src = f"data:image/svg+xml;base64,{encoded}"

        # Escape the SMILES so quotes or angle brackets cannot break out of
        # the single-quoted title attribute.
        title = html.escape(str(smiles), quote=True)

        return f"<img src='{img_src}' style='width: {svg_size}px; height: {svg_size}px;' title='{title}' />"
    except Exception as e:
        # Best-effort rendering: a failing row must not abort the whole table.
        # The message is escaped because it is rendered as HTML.
        return f"<div style='text-align: center; color: red;'>Error: {html.escape(str(e))}</div>"
51
+
52
+
53
def spectrum_to_html_img(spec1, spec2, img_size=1500):
    """
    Convert spectrum plot to HTML image string for display in Gradio dataframe.

    ``spec1`` is plotted with ``spec2`` passed as the mirror spectrum via
    ``su.plot_spectrum``; the current matplotlib figure is saved as PNG and
    embedded inline as a base64 ``data:`` URI.

    Args:
        spec1: Spectrum passed to ``su.plot_spectrum`` as ``spec``.
        spec2: Spectrum passed to ``su.plot_spectrum`` as ``mirror_spec``.
        img_size: CSS width in pixels of the returned ``<img>`` tag (height is
            ``auto``).

    Returns:
        An HTML string: an ``<img>`` tag on success, or a red ``<div>`` with
        the error message when an exception is raised while plotting or
        saving (the exception is reported in the HTML, not re-raised).
    """
    # Local import: the file's top-level import block does not provide `html`.
    import html

    try:
        matplotlib.use('Agg')  # Use non-interactive backend

        buffered = BytesIO()
        try:
            # Create the plot using the existing function
            su.plot_spectrum(spec=spec1, mirror_spec=spec2, figsize=(8, 4))

            # Save the current figure to a buffer
            plt.savefig(buffered, format='png', bbox_inches='tight', dpi=100)
        finally:
            # Always close the current figure, even when plotting or saving
            # fails; this function runs once per hit, so leaked figures
            # would accumulate.
            plt.close()

        # base64 output is pure ASCII, so decode it explicitly instead of
        # slicing the bytes repr.
        encoded = base64.b64encode(buffered.getvalue()).decode('ascii')
        img_src = f"data:image/png;base64,{encoded}"

        return f"<img src='{img_src}' style='width: {img_size}px; height: auto;' title='Spectrum comparison' />"
    except Exception as e:
        # Best-effort rendering: a failing row must not abort the whole table.
        # The message is escaped because it is rendered as HTML.
        return f"<div style='text-align: center; color: red;'>Error: {html.escape(str(e))}</div>"
76
+
77
+
78
+ def setup():
79
  # Download spectral library
80
  data_path = Path('./DreaMS/data')
81
  data_path.mkdir(parents=True, exist_ok=True)
 
96
 
97
  def predict(lib_pth, in_pth):
98
  in_pth = Path(in_pth)
99
+ # # in_pth = Path('DreaMS/data/MSV000086206/peak/mzml/S_N1.mzML') # Example dataset
100
 
101
  msdata_lib = MSData.load(lib_pth)
102
  embs_lib = msdata_lib[DREAMS_EMBEDDING]
 
118
  cos_sim = su.PeakListModifiedCosine()
119
  for i, topk in enumerate(tqdm(topk_cands)):
120
  for n, j in enumerate(topk):
121
+ smiles = msdata_lib.get_smiles(j)
122
+ spec1 = msdata.get_spectra(i)
123
+ spec2 = msdata_lib.get_spectra(j)
124
  df.append({
125
  'feature_id': i + 1,
126
  'topk': n + 1,
127
  'library_j': j,
128
+ 'library_SMILES': smiles_to_html_img(smiles),
129
+ 'Spectrum': spectrum_to_html_img(spec1, spec2),
130
  'library_ID': msdata_lib.get_values('IDENTIFIER', j),
131
  'DreaMS_similarity': sims[i, j],
132
  'Modified_cosine_similarity': cos_sim(
133
+ spec1=spec1,
134
  prec_mz1=msdata.get_prec_mzs(i),
135
+ spec2=spec2,
136
  prec_mz2=msdata_lib.get_prec_mzs(j),
137
  ),
138
  'i': i,
 
140
  })
141
  df = pd.DataFrame(df)
142
 
143
+ # # TODO Add some (random) name to the output file
144
  df_path = io.append_to_stem(in_pth, 'MassSpecGym_hits').with_suffix('.csv')
145
  df.to_csv(df_path, index=False)
146
 
 
181
  with gr.Row(equal_height=True):
182
  in_pth = gr.File(
183
  file_count="single",
184
+ label=".mzML file (TODO Extend to other formats)",
185
  )
186
  lib_pth = Path('DreaMS/data/MassSpecGym_DreaMS.hdf5') # MassSpecGym library
187
+ examples = gr.Examples(
188
+ examples=["./data/S_N1.mzML", "./data/example_5_spectra.mgf"],
189
+ inputs=[in_pth],
190
+ label="Examples (click on a line to pre-fill the inputs)",
191
+ # TODO
192
+ # cache_examples=True
193
+ # outputs=[df, df_file],
194
+ # fn=predict,
195
+ )
196
 
197
  # Predict GUI
198
  predict_button = gr.Button(value="Run library matching", variant="primary")
 
201
  gr.Markdown("## Predictions")
202
  df_file = gr.File(label="Download predictions as .csv", interactive=False, visible=True)
203
  df = gr.Dataframe(
204
+ headers=["feature_id", "topk", "library_j", "library_SMILES", "Spectrum", "library_ID", "DreaMS_similarity", "Modified_cosine_similarity", "i", "j"],
205
+ datatype=["number", "number", "number", "html", "html", "str", "number", "number", "number", "number"],
206
+ col_count=(10, "fixed"),
207
+ wrap=True,
208
+ column_widths=["80px", "60px", "80px", "400px", "800px", "120px", "120px", "150px", "60px", "60px"],
209
+ max_height=1000,
210
+ show_fullscreen_button=True,
211
+ show_row_numbers=True,
212
+ show_search='filter',
213
+ # pinned_columns= # TODO
214
  )
215
 
216
  # Main logic