Anton Bushuiev committed on
Commit
e0dc24a
·
1 Parent(s): c765e79

Optimize image generation

Browse files
Files changed (1) hide show
  1. app.py +44 -10
app.py CHANGED
@@ -36,9 +36,9 @@ from dreams.definitions import *
36
  # CONSTANTS AND CONFIGURATION
37
  # =============================================================================
38
 
39
- # Default image sizes for different components
40
- SMILES_IMG_SIZE = 200
41
- SPECTRUM_IMG_SIZE = 1500
42
 
43
  # Library and data paths
44
  LIBRARY_PATH = Path('DreaMS/data/MassSpecGym_DreaMS.hdf5')
@@ -48,6 +48,15 @@ EXAMPLE_PATH = Path('./data')
48
  # Similarity threshold for filtering results
49
  SIMILARITY_THRESHOLD = 0.75
50
 
 
 
 
 
 
 
 
 
 
51
  # =============================================================================
52
  # UTILITY FUNCTIONS FOR IMAGE CONVERSION
53
  # =============================================================================
@@ -83,7 +92,7 @@ def _convert_pil_to_base64(img, format='PNG'):
83
  str: Base64 encoded image string
84
  """
85
  buffered = io.BytesIO()
86
- img.save(buffered, format=format)
87
  img_str = base64.b64encode(buffered.getvalue())
88
  return f"data:image/{format.lower()};base64,{repr(img_str)[2:-1]}"
89
 
@@ -114,6 +123,7 @@ def _crop_transparent_edges(img):
114
  def smiles_to_html_img(smiles, img_size=SMILES_IMG_SIZE):
115
  """
116
  Convert SMILES string to HTML image for display in Gradio dataframe
 
117
 
118
  Args:
119
  smiles: SMILES string representation of molecule
@@ -122,18 +132,25 @@ def smiles_to_html_img(smiles, img_size=SMILES_IMG_SIZE):
122
  Returns:
123
  str: HTML img tag with base64 encoded image
124
  """
 
 
 
 
 
125
  try:
126
  # Parse SMILES to RDKit molecule
127
  mol = Chem.MolFromSmiles(smiles)
128
  if mol is None:
129
- return f"<div style='text-align: center; color: red;'>Invalid SMILES</div>"
 
 
130
 
131
  # Create PNG drawing with Cairo backend for better control
132
  d2d = rdMolDraw2D.MolDraw2DCairo(img_size, img_size)
133
  opts = d2d.drawOptions()
134
  opts.clearBackground = False
135
  opts.padding = 0.05 # Minimal padding
136
- opts.bondLineWidth = 2.0 # Make bonds more visible
137
 
138
  # Draw the molecule
139
  d2d.DrawMolecule(mol)
@@ -147,15 +164,22 @@ def smiles_to_html_img(smiles, img_size=SMILES_IMG_SIZE):
147
  img = _crop_transparent_edges(img)
148
  img_str = _convert_pil_to_base64(img)
149
 
150
- return f"<img src='{img_str}' style='max-width: 100%; height: auto;' title='{smiles}' />"
 
 
 
 
151
 
152
  except Exception as e:
153
- return f"<div style='text-align: center; color: red;'>Error: {str(e)}</div>"
 
 
154
 
155
 
156
  def spectrum_to_html_img(spec1, spec2, img_size=SPECTRUM_IMG_SIZE):
157
  """
158
  Convert spectrum plot to HTML image for display in Gradio dataframe
 
159
 
160
  Args:
161
  spec1: First spectrum data
@@ -170,11 +194,11 @@ def spectrum_to_html_img(spec1, spec2, img_size=SPECTRUM_IMG_SIZE):
170
  matplotlib.use('Agg')
171
 
172
  # Create the spectrum plot using DreaMS utility function
173
- su.plot_spectrum(spec=spec1, mirror_spec=spec2, figsize=(2, 1))
174
 
175
  # Save figure to buffer with transparent background
176
  buffered = BytesIO()
177
- plt.savefig(buffered, format='png', bbox_inches='tight', dpi=100, transparent=True)
178
  buffered.seek(0)
179
 
180
  # Convert to PIL Image, crop edges, and convert to base64
@@ -226,6 +250,9 @@ def setup():
226
  print("Setting up DreaMS application...")
227
  print("=" * 60)
228
 
 
 
 
229
  try:
230
  # Download spectral library
231
  library_url = 'https://huggingface.co/datasets/roman-bushuiev/GeMS/resolve/main/data/auxiliary/MassSpecGym_DreaMS.hdf5'
@@ -396,6 +423,9 @@ def _predict_core(lib_pth, in_pth, progress):
396
  """
397
  in_pth = Path(in_pth)
398
 
 
 
 
399
  # Load library data
400
  progress(0, desc="Loading library data...")
401
  msdata_lib = MSData.load(lib_pth)
@@ -431,6 +461,10 @@ def _predict_core(lib_pth, in_pth, progress):
431
  for n, j in enumerate(topk):
432
  row_data = _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs)
433
  df.append(row_data)
 
 
 
 
434
 
435
  df = pd.DataFrame(df)
436
 
 
36
  # CONSTANTS AND CONFIGURATION
37
  # =============================================================================
38
 
39
+ # Optimized image sizes for better performance
40
+ SMILES_IMG_SIZE = 120 # Reduced from 200 for faster rendering
41
+ SPECTRUM_IMG_SIZE = 800 # Reduced from 1500 for faster generation
42
 
43
  # Library and data paths
44
  LIBRARY_PATH = Path('DreaMS/data/MassSpecGym_DreaMS.hdf5')
 
48
  # Similarity threshold for filtering results
49
  SIMILARITY_THRESHOLD = 0.75
50
 
51
+ # Cache for SMILES images to avoid regeneration
52
+ _smiles_cache = {}
53
+
54
+ def clear_smiles_cache():
55
+ """Clear the SMILES image cache to free memory"""
56
+ global _smiles_cache
57
+ _smiles_cache.clear()
58
+ print("SMILES image cache cleared")
59
+
60
  # =============================================================================
61
  # UTILITY FUNCTIONS FOR IMAGE CONVERSION
62
  # =============================================================================
 
92
  str: Base64 encoded image string
93
  """
94
  buffered = io.BytesIO()
95
+ img.save(buffered, format=format, optimize=True) # Added optimize=True
96
  img_str = base64.b64encode(buffered.getvalue())
97
  return f"data:image/{format.lower()};base64,{repr(img_str)[2:-1]}"
98
 
 
123
  def smiles_to_html_img(smiles, img_size=SMILES_IMG_SIZE):
124
  """
125
  Convert SMILES string to HTML image for display in Gradio dataframe
126
+ Uses caching to avoid regenerating the same molecule images
127
 
128
  Args:
129
  smiles: SMILES string representation of molecule
 
132
  Returns:
133
  str: HTML img tag with base64 encoded image
134
  """
135
+ # Check cache first
136
+ cache_key = f"{smiles}_{img_size}"
137
+ if cache_key in _smiles_cache:
138
+ return _smiles_cache[cache_key]
139
+
140
  try:
141
  # Parse SMILES to RDKit molecule
142
  mol = Chem.MolFromSmiles(smiles)
143
  if mol is None:
144
+ result = f"<div style='text-align: center; color: red;'>Invalid SMILES</div>"
145
+ _smiles_cache[cache_key] = result
146
+ return result
147
 
148
  # Create PNG drawing with Cairo backend for better control
149
  d2d = rdMolDraw2D.MolDraw2DCairo(img_size, img_size)
150
  opts = d2d.drawOptions()
151
  opts.clearBackground = False
152
  opts.padding = 0.05 # Minimal padding
153
+ opts.bondLineWidth = 1.5 # Reduced from 2.0 for smaller images
154
 
155
  # Draw the molecule
156
  d2d.DrawMolecule(mol)
 
164
  img = _crop_transparent_edges(img)
165
  img_str = _convert_pil_to_base64(img)
166
 
167
+ result = f"<img src='{img_str}' style='max-width: 100%; height: auto;' title='{smiles}' />"
168
+
169
+ # Cache the result
170
+ _smiles_cache[cache_key] = result
171
+ return result
172
 
173
  except Exception as e:
174
+ result = f"<div style='text-align: center; color: red;'>Error: {str(e)}</div>"
175
+ _smiles_cache[cache_key] = result
176
+ return result
177
 
178
 
179
  def spectrum_to_html_img(spec1, spec2, img_size=SPECTRUM_IMG_SIZE):
180
  """
181
  Convert spectrum plot to HTML image for display in Gradio dataframe
182
+ Optimized version based on working code
183
 
184
  Args:
185
  spec1: First spectrum data
 
194
  matplotlib.use('Agg')
195
 
196
  # Create the spectrum plot using DreaMS utility function
197
+ su.plot_spectrum(spec=spec1, mirror_spec=spec2, figsize=(1.6, 0.8)) # Reduced size for performance
198
 
199
  # Save figure to buffer with transparent background
200
  buffered = BytesIO()
201
+ plt.savefig(buffered, format='png', bbox_inches='tight', dpi=80, transparent=True)
202
  buffered.seek(0)
203
 
204
  # Convert to PIL Image, crop edges, and convert to base64
 
250
  print("Setting up DreaMS application...")
251
  print("=" * 60)
252
 
253
+ # Clear any existing cache
254
+ clear_smiles_cache()
255
+
256
  try:
257
  # Download spectral library
258
  library_url = 'https://huggingface.co/datasets/roman-bushuiev/GeMS/resolve/main/data/auxiliary/MassSpecGym_DreaMS.hdf5'
 
423
  """
424
  in_pth = Path(in_pth)
425
 
426
+ # Clear cache at start to prevent memory buildup
427
+ clear_smiles_cache()
428
+
429
  # Load library data
430
  progress(0, desc="Loading library data...")
431
  msdata_lib = MSData.load(lib_pth)
 
461
  for n, j in enumerate(topk):
462
  row_data = _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs)
463
  df.append(row_data)
464
+
465
+ # Clear cache every 100 spectra to prevent memory buildup
466
+ if (i + 1) % 100 == 0:
467
+ clear_smiles_cache()
468
 
469
  df = pd.DataFrame(df)
470