alidenewade commited on
Commit
ee612c3
·
verified ·
1 Parent(s): bdc69d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -103
app.py CHANGED
@@ -1,17 +1,13 @@
1
  # app.py
2
  import gradio as gr
3
  import torch
4
- from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, RobertaModel, RobertaTokenizer, BitsAndBytesConfig
5
  from rdkit import Chem
6
  from rdkit.Chem import Draw, rdFMCS
7
  from rdkit.Chem.Draw import MolToImage
8
  # PIL is imported as Image by rdkit.Chem.Draw.MolToImage, but explicit import is good practice if used directly.
9
  # from PIL import Image
10
  import pandas as pd
11
- from bertviz import head_view # For potential future use or if other parts rely on it
12
- from bertviz import neuron_view as neuron_view_function # Specific import for neuron_view function
13
- # IPython.core.display.HTML is generally for notebooks. Gradio's gr.HTML handles HTML strings directly.
14
- # from IPython.core.display import HTML
15
  import io
16
  import base64
17
  import logging
@@ -58,14 +54,13 @@ def load_optimized_models():
58
 
59
  logger.info(f"Loading models on device: {device} with dtype: {torch_dtype}")
60
 
61
- # Model names
62
  model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
63
 
64
- # Load tokenizers (these don't need quantization)
65
  fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
66
- attention_tokenizer = RobertaTokenizer.from_pretrained(model_name)
67
 
68
- # Load models with quantization if available
69
  model_kwargs = {
70
  "torch_dtype": torch_dtype,
71
  }
@@ -85,35 +80,21 @@ def load_optimized_models():
85
  model_name,
86
  **model_kwargs
87
  )
88
-
89
- # RoBERTa model for attention
90
- attention_model_kwargs = model_kwargs.copy()
91
- attention_model_kwargs["output_attentions"] = True
92
-
93
- attention_model = RobertaModel.from_pretrained(
94
- model_name,
95
- **attention_model_kwargs
96
- )
97
-
98
- # Set models to evaluation mode for inference
99
- fill_mask_model.eval()
100
- attention_model.eval()
101
 
102
  # Create optimized pipeline
103
  # Let pipeline infer device from model if possible, or set based on model's device
104
  pipeline_device = fill_mask_model.device.index if hasattr(fill_mask_model.device, 'type') and fill_mask_model.device.type == "cuda" else -1
105
 
106
-
107
  fill_mask_pipeline = pipeline(
108
  'fill-mask',
109
  model=fill_mask_model,
110
  tokenizer=fill_mask_tokenizer,
111
  device=pipeline_device, # Use model's device
112
- # torch_dtype=torch_dtype # Pipeline might infer this or it might conflict
113
  )
114
 
115
  logger.info("Models loaded successfully with optimizations")
116
- return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline, attention_model, attention_tokenizer
117
 
118
  except Exception as e:
119
  logger.error(f"Error loading optimized models: {e}")
@@ -129,17 +110,13 @@ def load_standard_models(model_name):
129
  device_idx = 0 if torch.cuda.is_available() else -1
130
  fill_mask_pipeline = pipeline('fill-mask', model=fill_mask_model, tokenizer=fill_mask_tokenizer, device=device_idx)
131
 
132
- attention_model = RobertaModel.from_pretrained(model_name, output_attentions=True)
133
- attention_tokenizer = RobertaTokenizer.from_pretrained(model_name)
134
-
135
  if torch.cuda.is_available():
136
  fill_mask_model.to("cuda")
137
- attention_model.to("cuda")
138
 
139
- return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline, attention_model, attention_tokenizer
140
 
141
  # Load models with optimizations
142
- fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline, attention_model, attention_tokenizer = load_optimized_models()
143
 
144
  # --- Memory Management Utilities ---
145
  def clear_gpu_cache():
@@ -249,57 +226,6 @@ def predict_and_visualize_masked_smiles(smiles_mask, substructure_smarts_highlig
249
  return df_results, image_list[0], image_list[1], image_list[2], image_list[3], image_list[4], status_message
250
 
251
 
252
- def visualize_attention_bertviz(sentence_a, sentence_b):
253
- """
254
- Generates and displays BertViz neuron-by-neuron attention view as HTML.
255
- Optimized with memory management and mixed precision.
256
- """
257
- if not sentence_a or not sentence_b:
258
- return "<p style='color:red;'>Please provide two SMILES strings.</p>"
259
- try:
260
- inputs = attention_tokenizer.encode_plus(sentence_a, sentence_b, return_tensors='pt', add_special_tokens=True)
261
- input_ids = inputs['input_ids']
262
-
263
- # Move to appropriate device if using GPU
264
- if torch.cuda.is_available() and hasattr(attention_model, 'device'):
265
- input_ids = input_ids.to(attention_model.device)
266
-
267
- # Ensure model is in eval mode and use no_grad for inference
268
- attention_model.eval()
269
- with torch.no_grad():
270
- # Use autocast for mixed precision if on CUDA
271
- if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'): # Check for amp
272
- with torch.cuda.amp.autocast(dtype=torch.float16 if get_torch_dtype() == torch.float16 else None):
273
- attention_outputs = attention_model(input_ids)
274
- else:
275
- attention_outputs = attention_model(input_ids)
276
-
277
- attention = attention_outputs[-1] # Last item in the tuple is attentions
278
- input_id_list = input_ids[0].tolist()
279
- tokens = attention_tokenizer.convert_ids_to_tokens(input_id_list)
280
-
281
- # Using the specifically imported neuron_view_function
282
- html_object = neuron_view_function(attention, tokens)
283
-
284
- # Extract HTML string from the IPython.core.display.HTML object
285
- html_string = html_object.data # .data should provide the HTML string
286
-
287
- # Add D3 and jQuery CDN links to the HTML string for better rendering in Gradio
288
- html_with_deps = f"""
289
- <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
290
- <script src="https://cdnjs.cloudflare.com/ajax/libs/d3/5.16.0/d3.min.js"></script>
291
- {html_string}
292
- """
293
-
294
- # Clear cache after attention computation
295
- clear_gpu_cache()
296
-
297
- return html_with_deps
298
- except Exception as e:
299
- clear_gpu_cache() # Clear cache on error
300
- logger.error(f"Error in visualize_attention_bertviz: {e}", exc_info=True)
301
- return f"<p style='color:red;'>Error generating attention visualization: {str(e)}</p>"
302
-
303
  def display_molecule_image(smiles_string):
304
  """
305
  Displays a 2D image of a molecule from its SMILES string.
@@ -346,26 +272,6 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
346
  outputs=[predictions_table, img_out_1, img_out_2, img_out_3, img_out_4, img_out_5, status_masked]
347
  )
348
 
349
- with gr.Tab("Attention Visualization"):
350
- gr.Markdown("Enter two SMILES strings to visualize **neuron-by-neuron attention** between them using BertViz. This may take a moment to render.")
351
- with gr.Row():
352
- smiles_a_input_attn = gr.Textbox(label="SMILES String A", value="CCCCC[C@@H](Br)CC")
353
- smiles_b_input_attn = gr.Textbox(label="SMILES String B", value="CCCCC[C@H](Br)CC")
354
- visualize_button_attn = gr.Button("Visualize Attention")
355
- attention_html_output = gr.HTML(label="Attention Neuron View") # Changed label for clarity
356
-
357
- # Automatically populate on load for the default example
358
- demo.load(
359
- lambda: visualize_attention_bertviz("CCCCC[C@@H](Br)CC", "CCCCC[C@H](Br)CC"),
360
- inputs=None,
361
- outputs=[attention_html_output]
362
- )
363
- visualize_button_attn.click(
364
- visualize_attention_bertviz,
365
- inputs=[smiles_a_input_attn, smiles_b_input_attn],
366
- outputs=[attention_html_output]
367
- )
368
-
369
  with gr.Tab("Molecule Viewer"):
370
  gr.Markdown("Enter a SMILES string to display its 2D structure.")
371
  smiles_input_viewer = gr.Textbox(label="SMILES String", value="C1=CC=CC=C1")
@@ -386,4 +292,4 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
386
  )
387
 
388
  if __name__ == "__main__":
389
- demo.launch()
 
1
  # app.py
2
  import gradio as gr
3
  import torch
4
+ from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig
5
  from rdkit import Chem
6
  from rdkit.Chem import Draw, rdFMCS
7
  from rdkit.Chem.Draw import MolToImage
8
  # PIL is imported as Image by rdkit.Chem.Draw.MolToImage, but explicit import is good practice if used directly.
9
  # from PIL import Image
10
  import pandas as pd
 
 
 
 
11
  import io
12
  import base64
13
  import logging
 
54
 
55
  logger.info(f"Loading models on device: {device} with dtype: {torch_dtype}")
56
 
57
+ # Model name
58
  model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
59
 
60
+ # Load tokenizer (doesn't need quantization)
61
  fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
 
62
 
63
+ # Load model with quantization if available
64
  model_kwargs = {
65
  "torch_dtype": torch_dtype,
66
  }
 
80
  model_name,
81
  **model_kwargs
82
  )
83
+ fill_mask_model.eval() # Set model to evaluation mode for inference
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  # Create optimized pipeline
86
  # Let pipeline infer device from model if possible, or set based on model's device
87
  pipeline_device = fill_mask_model.device.index if hasattr(fill_mask_model.device, 'type') and fill_mask_model.device.type == "cuda" else -1
88
 
 
89
  fill_mask_pipeline = pipeline(
90
  'fill-mask',
91
  model=fill_mask_model,
92
  tokenizer=fill_mask_tokenizer,
93
  device=pipeline_device, # Use model's device
 
94
  )
95
 
96
  logger.info("Models loaded successfully with optimizations")
97
+ return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
98
 
99
  except Exception as e:
100
  logger.error(f"Error loading optimized models: {e}")
 
110
  device_idx = 0 if torch.cuda.is_available() else -1
111
  fill_mask_pipeline = pipeline('fill-mask', model=fill_mask_model, tokenizer=fill_mask_tokenizer, device=device_idx)
112
 
 
 
 
113
  if torch.cuda.is_available():
114
  fill_mask_model.to("cuda")
 
115
 
116
+ return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
117
 
118
  # Load models with optimizations
119
+ fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline = load_optimized_models()
120
 
121
  # --- Memory Management Utilities ---
122
  def clear_gpu_cache():
 
226
  return df_results, image_list[0], image_list[1], image_list[2], image_list[3], image_list[4], status_message
227
 
228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  def display_molecule_image(smiles_string):
230
  """
231
  Displays a 2D image of a molecule from its SMILES string.
 
272
  outputs=[predictions_table, img_out_1, img_out_2, img_out_3, img_out_4, img_out_5, status_masked]
273
  )
274
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
  with gr.Tab("Molecule Viewer"):
276
  gr.Markdown("Enter a SMILES string to display its 2D structure.")
277
  smiles_input_viewer = gr.Textbox(label="SMILES String", value="C1=CC=CC=C1")
 
292
  )
293
 
294
  if __name__ == "__main__":
295
+ demo.launch()