Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,17 +1,13 @@
|
|
| 1 |
# app.py
|
| 2 |
import gradio as gr
|
| 3 |
import torch
|
| 4 |
-
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline,
|
| 5 |
from rdkit import Chem
|
| 6 |
from rdkit.Chem import Draw, rdFMCS
|
| 7 |
from rdkit.Chem.Draw import MolToImage
|
| 8 |
# PIL is imported as Image by rdkit.Chem.Draw.MolToImage, but explicit import is good practice if used directly.
|
| 9 |
# from PIL import Image
|
| 10 |
import pandas as pd
|
| 11 |
-
from bertviz import head_view # For potential future use or if other parts rely on it
|
| 12 |
-
from bertviz import neuron_view as neuron_view_function # Specific import for neuron_view function
|
| 13 |
-
# IPython.core.display.HTML is generally for notebooks. Gradio's gr.HTML handles HTML strings directly.
|
| 14 |
-
# from IPython.core.display import HTML
|
| 15 |
import io
|
| 16 |
import base64
|
| 17 |
import logging
|
|
@@ -58,14 +54,13 @@ def load_optimized_models():
|
|
| 58 |
|
| 59 |
logger.info(f"Loading models on device: {device} with dtype: {torch_dtype}")
|
| 60 |
|
| 61 |
-
# Model
|
| 62 |
model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
|
| 63 |
|
| 64 |
-
# Load
|
| 65 |
fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 66 |
-
attention_tokenizer = RobertaTokenizer.from_pretrained(model_name)
|
| 67 |
|
| 68 |
-
# Load
|
| 69 |
model_kwargs = {
|
| 70 |
"torch_dtype": torch_dtype,
|
| 71 |
}
|
|
@@ -85,35 +80,21 @@ def load_optimized_models():
|
|
| 85 |
model_name,
|
| 86 |
**model_kwargs
|
| 87 |
)
|
| 88 |
-
|
| 89 |
-
# RoBERTa model for attention
|
| 90 |
-
attention_model_kwargs = model_kwargs.copy()
|
| 91 |
-
attention_model_kwargs["output_attentions"] = True
|
| 92 |
-
|
| 93 |
-
attention_model = RobertaModel.from_pretrained(
|
| 94 |
-
model_name,
|
| 95 |
-
**attention_model_kwargs
|
| 96 |
-
)
|
| 97 |
-
|
| 98 |
-
# Set models to evaluation mode for inference
|
| 99 |
-
fill_mask_model.eval()
|
| 100 |
-
attention_model.eval()
|
| 101 |
|
| 102 |
# Create optimized pipeline
|
| 103 |
# Let pipeline infer device from model if possible, or set based on model's device
|
| 104 |
pipeline_device = fill_mask_model.device.index if hasattr(fill_mask_model.device, 'type') and fill_mask_model.device.type == "cuda" else -1
|
| 105 |
|
| 106 |
-
|
| 107 |
fill_mask_pipeline = pipeline(
|
| 108 |
'fill-mask',
|
| 109 |
model=fill_mask_model,
|
| 110 |
tokenizer=fill_mask_tokenizer,
|
| 111 |
device=pipeline_device, # Use model's device
|
| 112 |
-
# torch_dtype=torch_dtype # Pipeline might infer this or it might conflict
|
| 113 |
)
|
| 114 |
|
| 115 |
logger.info("Models loaded successfully with optimizations")
|
| 116 |
-
return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
|
| 117 |
|
| 118 |
except Exception as e:
|
| 119 |
logger.error(f"Error loading optimized models: {e}")
|
|
@@ -129,17 +110,13 @@ def load_standard_models(model_name):
|
|
| 129 |
device_idx = 0 if torch.cuda.is_available() else -1
|
| 130 |
fill_mask_pipeline = pipeline('fill-mask', model=fill_mask_model, tokenizer=fill_mask_tokenizer, device=device_idx)
|
| 131 |
|
| 132 |
-
attention_model = RobertaModel.from_pretrained(model_name, output_attentions=True)
|
| 133 |
-
attention_tokenizer = RobertaTokenizer.from_pretrained(model_name)
|
| 134 |
-
|
| 135 |
if torch.cuda.is_available():
|
| 136 |
fill_mask_model.to("cuda")
|
| 137 |
-
attention_model.to("cuda")
|
| 138 |
|
| 139 |
-
return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
|
| 140 |
|
| 141 |
# Load models with optimizations
|
| 142 |
-
fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
|
| 143 |
|
| 144 |
# --- Memory Management Utilities ---
|
| 145 |
def clear_gpu_cache():
|
|
@@ -249,57 +226,6 @@ def predict_and_visualize_masked_smiles(smiles_mask, substructure_smarts_highlig
|
|
| 249 |
return df_results, image_list[0], image_list[1], image_list[2], image_list[3], image_list[4], status_message
|
| 250 |
|
| 251 |
|
| 252 |
-
def visualize_attention_bertviz(sentence_a, sentence_b):
|
| 253 |
-
"""
|
| 254 |
-
Generates and displays BertViz neuron-by-neuron attention view as HTML.
|
| 255 |
-
Optimized with memory management and mixed precision.
|
| 256 |
-
"""
|
| 257 |
-
if not sentence_a or not sentence_b:
|
| 258 |
-
return "<p style='color:red;'>Please provide two SMILES strings.</p>"
|
| 259 |
-
try:
|
| 260 |
-
inputs = attention_tokenizer.encode_plus(sentence_a, sentence_b, return_tensors='pt', add_special_tokens=True)
|
| 261 |
-
input_ids = inputs['input_ids']
|
| 262 |
-
|
| 263 |
-
# Move to appropriate device if using GPU
|
| 264 |
-
if torch.cuda.is_available() and hasattr(attention_model, 'device'):
|
| 265 |
-
input_ids = input_ids.to(attention_model.device)
|
| 266 |
-
|
| 267 |
-
# Ensure model is in eval mode and use no_grad for inference
|
| 268 |
-
attention_model.eval()
|
| 269 |
-
with torch.no_grad():
|
| 270 |
-
# Use autocast for mixed precision if on CUDA
|
| 271 |
-
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'): # Check for amp
|
| 272 |
-
with torch.cuda.amp.autocast(dtype=torch.float16 if get_torch_dtype() == torch.float16 else None):
|
| 273 |
-
attention_outputs = attention_model(input_ids)
|
| 274 |
-
else:
|
| 275 |
-
attention_outputs = attention_model(input_ids)
|
| 276 |
-
|
| 277 |
-
attention = attention_outputs[-1] # Last item in the tuple is attentions
|
| 278 |
-
input_id_list = input_ids[0].tolist()
|
| 279 |
-
tokens = attention_tokenizer.convert_ids_to_tokens(input_id_list)
|
| 280 |
-
|
| 281 |
-
# Using the specifically imported neuron_view_function
|
| 282 |
-
html_object = neuron_view_function(attention, tokens)
|
| 283 |
-
|
| 284 |
-
# Extract HTML string from the IPython.core.display.HTML object
|
| 285 |
-
html_string = html_object.data # .data should provide the HTML string
|
| 286 |
-
|
| 287 |
-
# Add D3 and jQuery CDN links to the HTML string for better rendering in Gradio
|
| 288 |
-
html_with_deps = f"""
|
| 289 |
-
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
|
| 290 |
-
<script src="https://cdnjs.cloudflare.com/ajax/libs/d3/5.16.0/d3.min.js"></script>
|
| 291 |
-
{html_string}
|
| 292 |
-
"""
|
| 293 |
-
|
| 294 |
-
# Clear cache after attention computation
|
| 295 |
-
clear_gpu_cache()
|
| 296 |
-
|
| 297 |
-
return html_with_deps
|
| 298 |
-
except Exception as e:
|
| 299 |
-
clear_gpu_cache() # Clear cache on error
|
| 300 |
-
logger.error(f"Error in visualize_attention_bertviz: {e}", exc_info=True)
|
| 301 |
-
return f"<p style='color:red;'>Error generating attention visualization: {str(e)}</p>"
|
| 302 |
-
|
| 303 |
def display_molecule_image(smiles_string):
|
| 304 |
"""
|
| 305 |
Displays a 2D image of a molecule from its SMILES string.
|
|
@@ -346,26 +272,6 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
|
|
| 346 |
outputs=[predictions_table, img_out_1, img_out_2, img_out_3, img_out_4, img_out_5, status_masked]
|
| 347 |
)
|
| 348 |
|
| 349 |
-
with gr.Tab("Attention Visualization"):
|
| 350 |
-
gr.Markdown("Enter two SMILES strings to visualize **neuron-by-neuron attention** between them using BertViz. This may take a moment to render.")
|
| 351 |
-
with gr.Row():
|
| 352 |
-
smiles_a_input_attn = gr.Textbox(label="SMILES String A", value="CCCCC[C@@H](Br)CC")
|
| 353 |
-
smiles_b_input_attn = gr.Textbox(label="SMILES String B", value="CCCCC[C@H](Br)CC")
|
| 354 |
-
visualize_button_attn = gr.Button("Visualize Attention")
|
| 355 |
-
attention_html_output = gr.HTML(label="Attention Neuron View") # Changed label for clarity
|
| 356 |
-
|
| 357 |
-
# Automatically populate on load for the default example
|
| 358 |
-
demo.load(
|
| 359 |
-
lambda: visualize_attention_bertviz("CCCCC[C@@H](Br)CC", "CCCCC[C@H](Br)CC"),
|
| 360 |
-
inputs=None,
|
| 361 |
-
outputs=[attention_html_output]
|
| 362 |
-
)
|
| 363 |
-
visualize_button_attn.click(
|
| 364 |
-
visualize_attention_bertviz,
|
| 365 |
-
inputs=[smiles_a_input_attn, smiles_b_input_attn],
|
| 366 |
-
outputs=[attention_html_output]
|
| 367 |
-
)
|
| 368 |
-
|
| 369 |
with gr.Tab("Molecule Viewer"):
|
| 370 |
gr.Markdown("Enter a SMILES string to display its 2D structure.")
|
| 371 |
smiles_input_viewer = gr.Textbox(label="SMILES String", value="C1=CC=CC=C1")
|
|
@@ -386,4 +292,4 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
|
|
| 386 |
)
|
| 387 |
|
| 388 |
if __name__ == "__main__":
|
| 389 |
-
demo.launch()
|
|
|
|
| 1 |
# app.py
|
| 2 |
import gradio as gr
|
| 3 |
import torch
|
| 4 |
+
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig
|
| 5 |
from rdkit import Chem
|
| 6 |
from rdkit.Chem import Draw, rdFMCS
|
| 7 |
from rdkit.Chem.Draw import MolToImage
|
| 8 |
# PIL is imported as Image by rdkit.Chem.Draw.MolToImage, but explicit import is good practice if used directly.
|
| 9 |
# from PIL import Image
|
| 10 |
import pandas as pd
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
import io
|
| 12 |
import base64
|
| 13 |
import logging
|
|
|
|
| 54 |
|
| 55 |
logger.info(f"Loading models on device: {device} with dtype: {torch_dtype}")
|
| 56 |
|
| 57 |
+
# Model name
|
| 58 |
model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
|
| 59 |
|
| 60 |
+
# Load tokenizer (doesn't need quantization)
|
| 61 |
fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
|
| 62 |
|
| 63 |
+
# Load model with quantization if available
|
| 64 |
model_kwargs = {
|
| 65 |
"torch_dtype": torch_dtype,
|
| 66 |
}
|
|
|
|
| 80 |
model_name,
|
| 81 |
**model_kwargs
|
| 82 |
)
|
| 83 |
+
fill_mask_model.eval() # Set model to evaluation mode for inference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
# Create optimized pipeline
|
| 86 |
# Let pipeline infer device from model if possible, or set based on model's device
|
| 87 |
pipeline_device = fill_mask_model.device.index if hasattr(fill_mask_model.device, 'type') and fill_mask_model.device.type == "cuda" else -1
|
| 88 |
|
|
|
|
| 89 |
fill_mask_pipeline = pipeline(
|
| 90 |
'fill-mask',
|
| 91 |
model=fill_mask_model,
|
| 92 |
tokenizer=fill_mask_tokenizer,
|
| 93 |
device=pipeline_device, # Use model's device
|
|
|
|
| 94 |
)
|
| 95 |
|
| 96 |
logger.info("Models loaded successfully with optimizations")
|
| 97 |
+
return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
|
| 98 |
|
| 99 |
except Exception as e:
|
| 100 |
logger.error(f"Error loading optimized models: {e}")
|
|
|
|
| 110 |
device_idx = 0 if torch.cuda.is_available() else -1
|
| 111 |
fill_mask_pipeline = pipeline('fill-mask', model=fill_mask_model, tokenizer=fill_mask_tokenizer, device=device_idx)
|
| 112 |
|
|
|
|
|
|
|
|
|
|
| 113 |
if torch.cuda.is_available():
|
| 114 |
fill_mask_model.to("cuda")
|
|
|
|
| 115 |
|
| 116 |
+
return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
|
| 117 |
|
| 118 |
# Load models with optimizations
|
| 119 |
+
fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline = load_optimized_models()
|
| 120 |
|
| 121 |
# --- Memory Management Utilities ---
|
| 122 |
def clear_gpu_cache():
|
|
|
|
| 226 |
return df_results, image_list[0], image_list[1], image_list[2], image_list[3], image_list[4], status_message
|
| 227 |
|
| 228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
def display_molecule_image(smiles_string):
|
| 230 |
"""
|
| 231 |
Displays a 2D image of a molecule from its SMILES string.
|
|
|
|
| 272 |
outputs=[predictions_table, img_out_1, img_out_2, img_out_3, img_out_4, img_out_5, status_masked]
|
| 273 |
)
|
| 274 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
with gr.Tab("Molecule Viewer"):
|
| 276 |
gr.Markdown("Enter a SMILES string to display its 2D structure.")
|
| 277 |
smiles_input_viewer = gr.Textbox(label="SMILES String", value="C1=CC=CC=C1")
|
|
|
|
| 292 |
)
|
| 293 |
|
| 294 |
if __name__ == "__main__":
|
| 295 |
+
demo.launch()
|