Spaces:
Sleeping
Sleeping
| # app.py | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, BitsAndBytesConfig | |
| from rdkit import Chem | |
| from rdkit.Chem import Draw, rdFMCS | |
| from rdkit.Chem.Draw import MolToImage | |
| # PIL is imported as Image by rdkit.Chem.Draw.MolToImage, but explicit import is good practice if used directly. | |
| # from PIL import Image | |
| import pandas as pd | |
| import io | |
| import base64 | |
| import logging | |
# --- Logging ---
# INFO level so the device / dtype / quantization decisions logged while the
# model is loaded show up in the application log.
_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_LOG_LEVEL)
logger = logging.getLogger(__name__)
# --- Quantization Configuration ---
def get_quantization_config():
    """Build an 8-bit bitsandbytes quantization config, or return ``None``.

    ``None`` tells the caller to load the model without quantization. It is
    returned when the ``bitsandbytes`` package is not installed, or when
    constructing the config raises.

    Only ``load_in_8bit`` is set: ``BitsAndBytesConfig`` has no
    ``bnb_8bit_*`` options (compute-dtype / double-quant settings exist only
    as ``bnb_4bit_*`` for 4-bit loading). Unknown keyword arguments are
    accepted and then ignored, which would make such options a silent no-op.

    Returns:
        A ``BitsAndBytesConfig`` instance, or ``None``.
    """
    import importlib.util

    # Building the config object is not what imports bitsandbytes, so an
    # ImportError is not a reliable availability signal; check explicitly.
    if importlib.util.find_spec("bitsandbytes") is None:
        logger.warning("bitsandbytes not available, falling back to standard loading")
        return None
    try:
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    except ImportError:
        logger.warning("bitsandbytes not available, falling back to standard loading")
        return None
    except Exception as e:  # e.g. incompatible transformers/bitsandbytes versions
        logger.warning(f"Quantization setup failed: {e}, using standard loading")
        return None
    logger.info("8-bit quantization configuration loaded successfully")
    return quantization_config
def get_torch_dtype():
    """Return the dtype to load model weights in.

    ``torch.float16`` (half precision) when a CUDA device is available,
    otherwise ``torch.float32`` (full precision on CPU).
    """
    return torch.float16 if torch.cuda.is_available() else torch.float32
# --- Optimized Model Loading ---
def load_optimized_models():
    """Load the ChemBERTa tokenizer, masked-LM model and fill-mask pipeline.

    On CUDA the model is loaded with ``device_map="auto"`` in the dtype from
    ``get_torch_dtype()``, plus 8-bit quantization when
    ``get_quantization_config()`` returns a config. On CPU it is loaded
    without a device map. If loading the model or building the pipeline
    raises, falls back to ``load_standard_models``.

    Returns:
        tuple: ``(tokenizer, model, fill_mask_pipeline)``.
    """
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    torch_dtype = get_torch_dtype()
    quantization_config = get_quantization_config()
    logger.info(f"Loading models on device: {device} with dtype: {torch_dtype}")

    model_name = "seyonec/PubChem10M_SMILES_BPE_450k"
    # The tokenizer is not quantized.
    fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)

    model_kwargs = {"torch_dtype": torch_dtype}
    if use_cuda:
        # Let accelerate place the (possibly quantized) weights on the GPU.
        model_kwargs["device_map"] = "auto"
        if quantization_config is not None:
            model_kwargs["quantization_config"] = quantization_config

    try:
        fill_mask_model = AutoModelForMaskedLM.from_pretrained(model_name, **model_kwargs)
        fill_mask_model.eval()  # inference only

        # transformers refuses an explicit `device` for a model that was
        # dispatched with `device_map` (it cannot be moved), so in that case
        # let the pipeline take the model's existing placement. Without a
        # device map the model is on CPU, selected with device=-1.
        pipeline_kwargs = {} if "device_map" in model_kwargs else {"device": -1}
        fill_mask_pipeline = pipeline(
            "fill-mask",
            model=fill_mask_model,
            tokenizer=fill_mask_tokenizer,
            **pipeline_kwargs,
        )
        logger.info("Models loaded successfully with optimizations")
        return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
    except Exception:
        # logger.exception keeps the traceback for diagnosing the failure.
        logger.exception("Error loading optimized models")
        logger.info("Falling back to standard model loading...")
        return load_standard_models(model_name)
def load_standard_models(model_name):
    """Load tokenizer, model and fill-mask pipeline with default settings.

    Fallback path: no quantization, no dtype or device map passed to
    ``from_pretrained``. When CUDA is available the pipeline is bound to
    device 0 and the model is moved to ``"cuda"``; otherwise CPU is used.

    Args:
        model_name: Hub id or local path of the checkpoint.

    Returns:
        tuple: ``(tokenizer, model, fill_mask_pipeline)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)
    on_gpu = torch.cuda.is_available()
    fm_pipeline = pipeline(
        'fill-mask',
        model=model,
        tokenizer=tokenizer,
        device=0 if on_gpu else -1,  # -1 selects CPU in the pipeline API
    )
    if on_gpu:
        model.to("cuda")
    return tokenizer, model, fm_pipeline
# Load tokenizer, model and pipeline once at import time; the Gradio
# callbacks below all use these module-level objects.
(
    fill_mask_tokenizer,
    fill_mask_model,
    fill_mask_pipeline,
) = load_optimized_models()
# --- Memory Management Utilities ---
def clear_gpu_cache():
    """Release PyTorch's cached CUDA memory; does nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
# --- Helper Functions from Notebook (adapted) ---
def get_mol(smiles):
    """Parse a SMILES string into an RDKit ``Mol`` and try to kekulize it.

    Args:
        smiles: SMILES string to parse.

    Returns:
        The parsed ``Mol`` (kekulized when ``Chem.Kekulize`` succeeds), or
        ``None`` if ``Chem.MolFromSmiles`` cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    try:
        Chem.Kekulize(mol)
    except Exception as exc:
        # Kekulization can fail for some structures; keep the molecule as
        # parsed. Catch Exception (not a bare except) so KeyboardInterrupt /
        # SystemExit still propagate.
        logger.debug(f"Kekulization failed for {smiles!r}: {exc}")
    return mol
def find_matches_one(mol, submol_smarts):
    """Return all substructure matches of a SMARTS pattern in ``mol``.

    Args:
        mol: RDKit ``Mol`` to search.
        submol_smarts: SMARTS query string.

    Returns:
        The result of ``mol.GetSubstructMatches`` (one tuple of atom indices
        per match), or ``[]`` when ``mol`` or ``submol_smarts`` is falsy or
        the SMARTS cannot be parsed.
    """
    if mol and submol_smarts:
        pattern = Chem.MolFromSmarts(submol_smarts)
        if pattern:
            return mol.GetSubstructMatches(pattern)
    return []
def _coerce_atom_indices(atoms):
    """Return the entries of ``atoms`` that convert to a non-negative int, as ints."""
    indices = []
    for atom in atoms:
        try:
            index = int(atom)
        except (TypeError, ValueError):
            continue
        if index >= 0:
            indices.append(index)
    return indices


def get_image_with_highlight(mol, atomset=None, size=(300, 300)):
    """Draw ``mol`` with ``MolToImage``, optionally highlighting atoms in green.

    Args:
        mol: RDKit ``Mol`` to draw; ``None`` yields ``None``.
        atomset: Iterable of atom indices to highlight. Entries that cannot be
            converted to a non-negative ``int`` (e.g. ``None``, ``"x"``, ``-1``)
            are skipped with a warning instead of raising.
        size: ``(width, height)`` passed to ``MolToImage``.

    Returns:
        The image returned by ``MolToImage``, or ``None`` if ``mol`` is ``None``.
    """
    if mol is None:
        return None
    highlight_color = (0, 1, 0, 0.5)  # RGBA: green with some transparency
    # Materialize once so a one-shot iterator is not consumed twice.
    atoms = list(atomset) if atomset else []
    valid_atomset = _coerce_atom_indices(atoms)
    if len(valid_atomset) != len(atoms):
        logger.warning(f"Invalid atom in atomset: {atomset}. Proceeding without highlighting problematic atoms.")
    return MolToImage(
        mol,
        size=size,
        fitImage=True,
        highlightAtoms=valid_atomset,
        highlightAtomColors={i: highlight_color for i in valid_atomset},
    )
# --- Optimized Gradio Interface Functions ---
_NUM_IMAGE_SLOTS = 5  # number of gr.Image outputs wired to this callback


def _prediction_error(message):
    """Return the 7 Gradio outputs (empty table, no images, status) for a failure."""
    return (pd.DataFrame(),) + (None,) * _NUM_IMAGE_SLOTS + (message,)


def predict_and_visualize_masked_smiles(smiles_mask, substructure_smarts_highlight="CC=CC"):
    """Predict the masked token in a SMILES string and draw the valid completions.

    Args:
        smiles_mask: SMILES string containing exactly one tokenizer mask token
            (``fill_mask_tokenizer.mask_token``).
        substructure_smarts_highlight: SMARTS pattern whose first match is
            highlighted in each drawn molecule; falsy disables highlighting.

    Returns:
        A 7-tuple for the Gradio outputs: a DataFrame with the "Predicted
        SMILES" and "Score" columns for up to 5 RDKit-parsable predictions,
        five images (``None`` for unused slots), and a status message.
    """
    mask_token = fill_mask_tokenizer.mask_token
    if not smiles_mask or mask_token not in smiles_mask:
        return _prediction_error("Error: Input SMILES must contain a mask token (e.g., <mask>).")
    if smiles_mask.count(mask_token) > 1:
        # With several masks the pipeline returns one candidate list per mask
        # (a list of lists), which the loop below cannot consume.
        return _prediction_error(f"Error: Input SMILES must contain exactly one mask token ({mask_token}).")

    try:
        with torch.no_grad():
            # Ask for 10 candidates because unparsable SMILES are skipped below.
            predictions = fill_mask_pipeline(smiles_mask, top_k=10)
    except Exception as e:
        return _prediction_error(f"Error during prediction: {str(e)}")
    finally:
        # Model inference is finished here (the rest is RDKit work), so free
        # cached CUDA memory on both the success and the error path.
        clear_gpu_cache()

    results_data = []
    image_list = []
    for pred in predictions:
        if len(image_list) >= _NUM_IMAGE_SLOTS:
            break
        predicted_smiles = pred['sequence']
        mol = get_mol(predicted_smiles)
        if mol is None:
            continue
        results_data.append({"Predicted SMILES": predicted_smiles, "Score": f"{pred['score']:.4f}"})
        atom_matches_indices = []
        if substructure_smarts_highlight:
            matches = find_matches_one(mol, substructure_smarts_highlight)
            if matches:
                atom_matches_indices = list(matches[0])  # highlight first match only
        image_list.append(get_image_with_highlight(mol, atomset=atom_matches_indices))

    status_message = "Prediction successful." if results_data else "No valid molecules found for top predictions."
    # Pad to a fixed number of image outputs.
    image_list.extend([None] * (_NUM_IMAGE_SLOTS - len(image_list)))
    return (pd.DataFrame(results_data), *image_list, status_message)
def display_molecule_image(smiles_string):
    """Render a SMILES string as a 400x400 2D structure image.

    Args:
        smiles_string: SMILES to draw.

    Returns:
        ``(image, status)``, where ``image`` is ``None`` when the input is
        empty or cannot be parsed.
    """
    image = None
    status = "Please enter a SMILES string."
    if smiles_string:
        molecule = get_mol(smiles_string)
        if molecule is None:
            status = "Invalid SMILES string."
        else:
            image = MolToImage(molecule, size=(400, 400), fitImage=True)
            status = "Molecule displayed."
    return image, status
# --- Gradio Interface Definition ---
# Example inputs pre-filled in the textboxes. The page-load handlers read the
# textboxes, so each example is defined in exactly one place.
DEFAULT_MASKED_SMILES = "C1=CC=CC<mask>C1"
DEFAULT_HIGHLIGHT_SMARTS = "C=C"
DEFAULT_VIEWER_SMILES = "C1=CC=CC=C1"

with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown("# ChemBERTa SMILES Utilities Dashboard")

    with gr.Tab("Masked SMILES Prediction"):
        gr.Markdown("Enter a SMILES string with a `<mask>` token (e.g., `C1=CC=CC<mask>C1`) to predict possible completions.")
        with gr.Row():
            smiles_input_masked = gr.Textbox(label="SMILES String with Mask", value=DEFAULT_MASKED_SMILES)
            substructure_input = gr.Textbox(label="Substructure to Highlight (SMARTS)", value=DEFAULT_HIGHLIGHT_SMARTS)
        predict_button_masked = gr.Button("Predict and Visualize")
        status_masked = gr.Textbox(label="Status", interactive=False)
        predictions_table = gr.DataFrame(label="Top Predictions & Scores")
        gr.Markdown("### Predicted Molecule Visualizations (Top 5 Valid)")
        with gr.Row():
            img_out_1 = gr.Image(label="Prediction 1", type="pil", interactive=False)
            img_out_2 = gr.Image(label="Prediction 2", type="pil", interactive=False)
            img_out_3 = gr.Image(label="Prediction 3", type="pil", interactive=False)
            img_out_4 = gr.Image(label="Prediction 4", type="pil", interactive=False)
            img_out_5 = gr.Image(label="Prediction 5", type="pil", interactive=False)

        # Order must match the 7-tuple returned by predict_and_visualize_masked_smiles.
        masked_inputs = [smiles_input_masked, substructure_input]
        masked_outputs = [predictions_table, img_out_1, img_out_2, img_out_3, img_out_4, img_out_5, status_masked]
        # Populate on page load from the textboxes' current (default) values.
        demo.load(predict_and_visualize_masked_smiles, inputs=masked_inputs, outputs=masked_outputs)
        predict_button_masked.click(predict_and_visualize_masked_smiles, inputs=masked_inputs, outputs=masked_outputs)

    with gr.Tab("Molecule Viewer"):
        gr.Markdown("Enter a SMILES string to display its 2D structure.")
        smiles_input_viewer = gr.Textbox(label="SMILES String", value=DEFAULT_VIEWER_SMILES)
        view_button_molecule = gr.Button("View Molecule")
        status_viewer = gr.Textbox(label="Status", interactive=False)
        molecule_image_output = gr.Image(label="Molecule Structure", type="pil", interactive=False)

        viewer_outputs = [molecule_image_output, status_viewer]
        # Populate on page load from the textbox's current (default) value.
        demo.load(display_molecule_image, inputs=[smiles_input_viewer], outputs=viewer_outputs)
        view_button_molecule.click(display_molecule_image, inputs=[smiles_input_viewer], outputs=viewer_outputs)
def _main():
    """Start the Gradio server for ``demo`` (script entry point)."""
    demo.launch()


if __name__ == "__main__":
    _main()