Anton Bushuiev commited on
Commit ·
e0dc24a
1
Parent(s): c765e79
Optimize image generation
Browse files
app.py
CHANGED
|
@@ -36,9 +36,9 @@ from dreams.definitions import *
|
|
| 36 |
# CONSTANTS AND CONFIGURATION
|
| 37 |
# =============================================================================
|
| 38 |
|
| 39 |
-
#
|
| 40 |
-
SMILES_IMG_SIZE = 200
|
| 41 |
-
SPECTRUM_IMG_SIZE = 1500
|
| 42 |
|
| 43 |
# Library and data paths
|
| 44 |
LIBRARY_PATH = Path('DreaMS/data/MassSpecGym_DreaMS.hdf5')
|
|
@@ -48,6 +48,15 @@ EXAMPLE_PATH = Path('./data')
|
|
| 48 |
# Similarity threshold for filtering results
|
| 49 |
SIMILARITY_THRESHOLD = 0.75
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
# =============================================================================
|
| 52 |
# UTILITY FUNCTIONS FOR IMAGE CONVERSION
|
| 53 |
# =============================================================================
|
|
@@ -83,7 +92,7 @@ def _convert_pil_to_base64(img, format='PNG'):
|
|
| 83 |
str: Base64 encoded image string
|
| 84 |
"""
|
| 85 |
buffered = io.BytesIO()
|
| 86 |
-
img.save(buffered, format=format)
|
| 87 |
img_str = base64.b64encode(buffered.getvalue())
|
| 88 |
return f"data:image/{format.lower()};base64,{repr(img_str)[2:-1]}"
|
| 89 |
|
|
@@ -114,6 +123,7 @@ def _crop_transparent_edges(img):
|
|
| 114 |
def smiles_to_html_img(smiles, img_size=SMILES_IMG_SIZE):
|
| 115 |
"""
|
| 116 |
Convert SMILES string to HTML image for display in Gradio dataframe
|
|
|
|
| 117 |
|
| 118 |
Args:
|
| 119 |
smiles: SMILES string representation of molecule
|
|
@@ -122,18 +132,25 @@ def smiles_to_html_img(smiles, img_size=SMILES_IMG_SIZE):
|
|
| 122 |
Returns:
|
| 123 |
str: HTML img tag with base64 encoded image
|
| 124 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
try:
|
| 126 |
# Parse SMILES to RDKit molecule
|
| 127 |
mol = Chem.MolFromSmiles(smiles)
|
| 128 |
if mol is None:
|
| 129 |
-
|
|
|
|
|
|
|
| 130 |
|
| 131 |
# Create PNG drawing with Cairo backend for better control
|
| 132 |
d2d = rdMolDraw2D.MolDraw2DCairo(img_size, img_size)
|
| 133 |
opts = d2d.drawOptions()
|
| 134 |
opts.clearBackground = False
|
| 135 |
opts.padding = 0.05 # Minimal padding
|
| 136 |
-
opts.bondLineWidth =
|
| 137 |
|
| 138 |
# Draw the molecule
|
| 139 |
d2d.DrawMolecule(mol)
|
|
@@ -147,15 +164,22 @@ def smiles_to_html_img(smiles, img_size=SMILES_IMG_SIZE):
|
|
| 147 |
img = _crop_transparent_edges(img)
|
| 148 |
img_str = _convert_pil_to_base64(img)
|
| 149 |
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
except Exception as e:
|
| 153 |
-
|
|
|
|
|
|
|
| 154 |
|
| 155 |
|
| 156 |
def spectrum_to_html_img(spec1, spec2, img_size=SPECTRUM_IMG_SIZE):
|
| 157 |
"""
|
| 158 |
Convert spectrum plot to HTML image for display in Gradio dataframe
|
|
|
|
| 159 |
|
| 160 |
Args:
|
| 161 |
spec1: First spectrum data
|
|
@@ -170,11 +194,11 @@ def spectrum_to_html_img(spec1, spec2, img_size=SPECTRUM_IMG_SIZE):
|
|
| 170 |
matplotlib.use('Agg')
|
| 171 |
|
| 172 |
# Create the spectrum plot using DreaMS utility function
|
| 173 |
-
su.plot_spectrum(spec=spec1, mirror_spec=spec2, figsize=(
|
| 174 |
|
| 175 |
# Save figure to buffer with transparent background
|
| 176 |
buffered = BytesIO()
|
| 177 |
-
plt.savefig(buffered, format='png', bbox_inches='tight', dpi=
|
| 178 |
buffered.seek(0)
|
| 179 |
|
| 180 |
# Convert to PIL Image, crop edges, and convert to base64
|
|
@@ -226,6 +250,9 @@ def setup():
|
|
| 226 |
print("Setting up DreaMS application...")
|
| 227 |
print("=" * 60)
|
| 228 |
|
|
|
|
|
|
|
|
|
|
| 229 |
try:
|
| 230 |
# Download spectral library
|
| 231 |
library_url = 'https://huggingface.co/datasets/roman-bushuiev/GeMS/resolve/main/data/auxiliary/MassSpecGym_DreaMS.hdf5'
|
|
@@ -396,6 +423,9 @@ def _predict_core(lib_pth, in_pth, progress):
|
|
| 396 |
"""
|
| 397 |
in_pth = Path(in_pth)
|
| 398 |
|
|
|
|
|
|
|
|
|
|
| 399 |
# Load library data
|
| 400 |
progress(0, desc="Loading library data...")
|
| 401 |
msdata_lib = MSData.load(lib_pth)
|
|
@@ -431,6 +461,10 @@ def _predict_core(lib_pth, in_pth, progress):
|
|
| 431 |
for n, j in enumerate(topk):
|
| 432 |
row_data = _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs)
|
| 433 |
df.append(row_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
|
| 435 |
df = pd.DataFrame(df)
|
| 436 |
|
|
|
|
| 36 |
# CONSTANTS AND CONFIGURATION
|
| 37 |
# =============================================================================
|
| 38 |
|
| 39 |
+
# Optimized image sizes for better performance
|
| 40 |
+
SMILES_IMG_SIZE = 120 # Reduced from 200 for faster rendering
|
| 41 |
+
SPECTRUM_IMG_SIZE = 800 # Reduced from 1500 for faster generation
|
| 42 |
|
| 43 |
# Library and data paths
|
| 44 |
LIBRARY_PATH = Path('DreaMS/data/MassSpecGym_DreaMS.hdf5')
|
|
|
|
| 48 |
# Similarity threshold for filtering results
|
| 49 |
SIMILARITY_THRESHOLD = 0.75
|
| 50 |
|
| 51 |
+
# Cache for SMILES images to avoid regeneration
|
| 52 |
+
_smiles_cache = {}
|
| 53 |
+
|
| 54 |
+
def clear_smiles_cache():
|
| 55 |
+
"""Clear the SMILES image cache to free memory"""
|
| 56 |
+
global _smiles_cache
|
| 57 |
+
_smiles_cache.clear()
|
| 58 |
+
print("SMILES image cache cleared")
|
| 59 |
+
|
| 60 |
# =============================================================================
|
| 61 |
# UTILITY FUNCTIONS FOR IMAGE CONVERSION
|
| 62 |
# =============================================================================
|
|
|
|
| 92 |
str: Base64 encoded image string
|
| 93 |
"""
|
| 94 |
buffered = io.BytesIO()
|
| 95 |
+
img.save(buffered, format=format, optimize=True) # Added optimize=True
|
| 96 |
img_str = base64.b64encode(buffered.getvalue())
|
| 97 |
return f"data:image/{format.lower()};base64,{repr(img_str)[2:-1]}"
|
| 98 |
|
|
|
|
| 123 |
def smiles_to_html_img(smiles, img_size=SMILES_IMG_SIZE):
|
| 124 |
"""
|
| 125 |
Convert SMILES string to HTML image for display in Gradio dataframe
|
| 126 |
+
Uses caching to avoid regenerating the same molecule images
|
| 127 |
|
| 128 |
Args:
|
| 129 |
smiles: SMILES string representation of molecule
|
|
|
|
| 132 |
Returns:
|
| 133 |
str: HTML img tag with base64 encoded image
|
| 134 |
"""
|
| 135 |
+
# Check cache first
|
| 136 |
+
cache_key = f"{smiles}_{img_size}"
|
| 137 |
+
if cache_key in _smiles_cache:
|
| 138 |
+
return _smiles_cache[cache_key]
|
| 139 |
+
|
| 140 |
try:
|
| 141 |
# Parse SMILES to RDKit molecule
|
| 142 |
mol = Chem.MolFromSmiles(smiles)
|
| 143 |
if mol is None:
|
| 144 |
+
result = f"<div style='text-align: center; color: red;'>Invalid SMILES</div>"
|
| 145 |
+
_smiles_cache[cache_key] = result
|
| 146 |
+
return result
|
| 147 |
|
| 148 |
# Create PNG drawing with Cairo backend for better control
|
| 149 |
d2d = rdMolDraw2D.MolDraw2DCairo(img_size, img_size)
|
| 150 |
opts = d2d.drawOptions()
|
| 151 |
opts.clearBackground = False
|
| 152 |
opts.padding = 0.05 # Minimal padding
|
| 153 |
+
opts.bondLineWidth = 1.5 # Reduced from 2.0 for smaller images
|
| 154 |
|
| 155 |
# Draw the molecule
|
| 156 |
d2d.DrawMolecule(mol)
|
|
|
|
| 164 |
img = _crop_transparent_edges(img)
|
| 165 |
img_str = _convert_pil_to_base64(img)
|
| 166 |
|
| 167 |
+
result = f"<img src='{img_str}' style='max-width: 100%; height: auto;' title='{smiles}' />"
|
| 168 |
+
|
| 169 |
+
# Cache the result
|
| 170 |
+
_smiles_cache[cache_key] = result
|
| 171 |
+
return result
|
| 172 |
|
| 173 |
except Exception as e:
|
| 174 |
+
result = f"<div style='text-align: center; color: red;'>Error: {str(e)}</div>"
|
| 175 |
+
_smiles_cache[cache_key] = result
|
| 176 |
+
return result
|
| 177 |
|
| 178 |
|
| 179 |
def spectrum_to_html_img(spec1, spec2, img_size=SPECTRUM_IMG_SIZE):
|
| 180 |
"""
|
| 181 |
Convert spectrum plot to HTML image for display in Gradio dataframe
|
| 182 |
+
Optimized version based on working code
|
| 183 |
|
| 184 |
Args:
|
| 185 |
spec1: First spectrum data
|
|
|
|
| 194 |
matplotlib.use('Agg')
|
| 195 |
|
| 196 |
# Create the spectrum plot using DreaMS utility function
|
| 197 |
+
su.plot_spectrum(spec=spec1, mirror_spec=spec2, figsize=(1.6, 0.8)) # Reduced size for performance
|
| 198 |
|
| 199 |
# Save figure to buffer with transparent background
|
| 200 |
buffered = BytesIO()
|
| 201 |
+
plt.savefig(buffered, format='png', bbox_inches='tight', dpi=80, transparent=True)
|
| 202 |
buffered.seek(0)
|
| 203 |
|
| 204 |
# Convert to PIL Image, crop edges, and convert to base64
|
|
|
|
| 250 |
print("Setting up DreaMS application...")
|
| 251 |
print("=" * 60)
|
| 252 |
|
| 253 |
+
# Clear any existing cache
|
| 254 |
+
clear_smiles_cache()
|
| 255 |
+
|
| 256 |
try:
|
| 257 |
# Download spectral library
|
| 258 |
library_url = 'https://huggingface.co/datasets/roman-bushuiev/GeMS/resolve/main/data/auxiliary/MassSpecGym_DreaMS.hdf5'
|
|
|
|
| 423 |
"""
|
| 424 |
in_pth = Path(in_pth)
|
| 425 |
|
| 426 |
+
# Clear cache at start to prevent memory buildup
|
| 427 |
+
clear_smiles_cache()
|
| 428 |
+
|
| 429 |
# Load library data
|
| 430 |
progress(0, desc="Loading library data...")
|
| 431 |
msdata_lib = MSData.load(lib_pth)
|
|
|
|
| 461 |
for n, j in enumerate(topk):
|
| 462 |
row_data = _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs)
|
| 463 |
df.append(row_data)
|
| 464 |
+
|
| 465 |
+
# Clear cache every 100 spectra to prevent memory buildup
|
| 466 |
+
if (i + 1) % 100 == 0:
|
| 467 |
+
clear_smiles_cache()
|
| 468 |
|
| 469 |
df = pd.DataFrame(df)
|
| 470 |
|