Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| from rdkit import Chem | |
| from rdkit.Chem import Draw | |
| from streamlit_ketcher import st_ketcher | |
| import torch | |
# --- Page Configuration ---
# Browser-tab title/icon plus a wide layout with the sidebar opened on load.
_page_settings = {
    "page_title": "Chemical Reaction Predictor",
    "page_icon": "🧪",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_page_settings)
# --- Model Loading ---
# st.cache_resource keeps a single model/tokenizer pair alive across script
# reruns, so the weights are not reloaded on every widget interaction.
# show_spinner=False because the caller already wraps the load in st.spinner.
@st.cache_resource(show_spinner=False)
def _load_model_cached(model_name):
    """Load and return ``(model, tokenizer)`` for *model_name*.

    Errors are deliberately not handled here: an exception propagates to the
    caller and no return value is stored, so a later rerun tries again.
    """
    # Use Auto* classes for robustness
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return model, tokenizer


def load_model():
    """
    Loads the T5 model and tokenizer from Hugging Face.

    The expensive load is delegated to a cached helper so it runs only once.

    Returns:
        A ``(model, tokenizer)`` tuple on success. On any failure the error
        is reported in the Streamlit UI and ``(None, None)`` is returned.
    """
    model_name = "sagawa/ReactionT5v2-forward-USPTO_MIT"
    try:
        return _load_model_cached(model_name)
    except Exception as e:
        # Top-level UI boundary: report the problem instead of crashing the app.
        st.error("An error occurred while loading the model.")
        st.error(f"Error Type: {type(e).__name__}")
        st.error(f"Error Details: {e}")
        # Add a hint about potential memory issues on Hugging Face Spaces
        st.info("Hint: Free tiers on Hugging Face Spaces have limited memory (RAM). "
                "If the app fails to load the model, it might be due to an Out-of-Memory error. "
                "Consider upgrading your Space for more resources.")
        return None, None
# --- Core Functions ---
def predict_product(reactants, reagents, model, tokenizer, num_predictions):
    """Predicts the reaction product using the T5 model.

    Args:
        reactants: SMILES string of the reactants.
        reagents: SMILES string of the reagents; may be empty or ``None``.
        model: Object exposing ``generate(input_ids, **kwargs)``.
        tokenizer: Object exposing ``encode(text, return_tensors=...)`` and
            ``decode(ids, skip_special_tokens=...)``.
        num_predictions: Number of candidate products to return (>= 1).

    Returns:
        A list of ``num_predictions`` decoded strings with all whitespace
        removed.

    Raises:
        ValueError: If ``num_predictions`` is smaller than 1.
    """
    if num_predictions < 1:
        raise ValueError("num_predictions must be at least 1")

    # Surrounding whitespace from text inputs is noise, not part of a SMILES.
    reactants = (reactants or "").strip()
    reagents = (reagents or "").strip()

    # Format the input string; the reagents section is omitted when empty.
    # NOTE(review): the model card appears to document a
    # "REACTANT:...REAGENT:..." prompt -- confirm this format against it.
    if reagents:
        input_text = f"reactants>{reactants}.reagents>{reagents}>products>"
    else:
        input_text = f"reactants>{reactants}>products>"

    input_ids = tokenizer.encode(input_text, return_tensors='pt')

    # Generate predictions using beam search
    outputs = model.generate(
        input_ids,
        max_length=512,
        num_beams=num_predictions * 2,  # Generate more beams for better diversity
        num_return_sequences=num_predictions,
        early_stopping=True,
    )

    # SMILES never contain whitespace, so drop any that decoding inserts
    # between tokens; otherwise a parser may only read the first fragment.
    return [
        "".join(tokenizer.decode(output, skip_special_tokens=True).split())
        for output in outputs
    ]
def display_molecule(smiles_string, legend):
    """Render a molecule image for a SMILES string in the Streamlit app.

    Args:
        smiles_string: SMILES text to parse and draw.
        legend: Caption passed to the drawing call.

    A Streamlit warning is shown instead of an image when the string is
    empty, when parsing yields a falsy result, or when drawing fails.
    """
    # Guard: nothing to draw.
    if not smiles_string:
        st.warning("Received an empty SMILES string.")
        return

    molecule = Chem.MolFromSmiles(smiles_string)

    # Guard: the parser produced no usable molecule.
    if not molecule:
        st.warning(f"Invalid SMILES string provided: {smiles_string}")
        return

    try:
        rendered = Draw.MolToImage(molecule, size=(300, 300), legend=legend)
        st.image(rendered, use_column_width='auto')
    except Exception as err:
        st.warning(f"Could not generate image for SMILES: {smiles_string}. Error: {err}")
# --- Initialize Session State ---
# Seed each key only when it is missing, so values set on earlier reruns
# are preserved. Reactants start with a default example.
for _state_key, _initial_value in (("reactants", "CCO.O=C(O)C"), ("reagents", "")):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
# --- Sidebar UI ---
with st.sidebar:
    st.title("🧪 Reaction Predictor")
    st.markdown("---")
    st.header("Controls and Information")

    # Example Reactions: display name -> (reactants SMILES, reagents SMILES)
    example_reactions = {
        "Esterification": ("CCO.O=C(O)C", ""),
        "Amide Formation": ("CCN.O=C(Cl)C", ""),
        "Suzuki Coupling": ("[B-](C1=CC=CC=C1)(F)(F)F.[K+].CC1=CC=C(Br)C=C1", "c1ccc(B(O)O)cc1"),
        "Clear Inputs": ("", ""),
    }

    def load_example():
        """Selectbox callback: copy the chosen example into session state."""
        chosen = example_reactions[st.session_state.example_select]
        st.session_state.reactants, st.session_state.reagents = chosen

    st.selectbox(
        "Load an Example Reaction",
        options=list(example_reactions),
        key="example_select",
        on_change=load_example,
    )

    st.markdown("---")
    st.subheader("Prediction Parameters")
    # Read later by the main page when the prediction is requested.
    num_predictions = st.slider(
        "Number of Predictions to Generate",
        min_value=1,
        max_value=5,
        value=1,
        help="How many potential products should the model suggest?",
    )

    st.markdown("---")
    st.subheader("About")
    st.info(
        "This app uses the sagawa/ReactionT5v2-forward-USPTO_MIT model to predict chemical reaction products."
    )
    st.markdown("[View Model on Hugging Face](https://huggingface.co/sagawa/ReactionT5v2-forward-USPTO_MIT)")
# --- Main Application UI ---
st.title("Chemical Reaction Predictor")
st.markdown("A tool to predict chemical reactions using a state-of-the-art Transformer model.")

# --- Model Loading and Main Logic ---
with st.spinner("Loading the prediction model... This may take a moment on first startup."):
    model, tokenizer = load_model()


def _sync_reactants_from_text():
    """Text-input callback: copy the typed reactants into session state."""
    st.session_state.reactants = st.session_state.reactant_text


def _sync_reagents_from_text():
    """Text-input callback: copy the typed reagents into session state."""
    st.session_state.reagents = st.session_state.reagent_text


# Only proceed if the model loaded successfully. load_model() signals failure
# with (None, None), so test for None explicitly instead of relying on the
# truthiness of the model/tokenizer objects.
if model is not None and tokenizer is not None:
    st.success("Model loaded successfully!")

    # Input Section
    st.header("1. Provide Reactants and Reagents")
    input_tab1, input_tab2 = st.tabs(["✍️ Chemical Drawing Tool", "⌨️ SMILES Text Input"])

    with input_tab1:
        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Reactants")
            # The editor is seeded from session state; when it returns a
            # different value, store it and rerun so the page reflects it.
            reactant_smiles_drawing = st_ketcher(st.session_state.reactants, key="ketcher_reactants")
            if reactant_smiles_drawing != st.session_state.reactants:
                st.session_state.reactants = reactant_smiles_drawing
                st.rerun()  # Use the modern rerun command
        with col2:
            st.subheader("Reagents (Optional)")
            reagent_smiles_drawing = st_ketcher(st.session_state.reagents, key="ketcher_reagents")
            if reagent_smiles_drawing != st.session_state.reagents:
                st.session_state.reagents = reagent_smiles_drawing
                st.rerun()

    with input_tab2:
        st.subheader("Enter SMILES Strings")
        # Text inputs update the shared session state through on_change.
        # NOTE(review): once the drawing tool has reported a value, its
        # comparison above may overwrite text typed here -- verify the
        # interaction between both input methods.
        st.text_input(
            "Reactants SMILES",
            key="reactant_text",
            value=st.session_state.reactants,
            on_change=_sync_reactants_from_text,
        )
        st.text_input(
            "Reagents SMILES",
            key="reagent_text",
            value=st.session_state.reagents,
            on_change=_sync_reagents_from_text,
        )

    # Display the current state clearly
    st.info(f"**Current Reactants:** `{st.session_state.reactants}`")
    st.info(f"**Current Reagents:** `{st.session_state.reagents or 'None'}`")

    # Prediction Button
    st.header("2. Generate Prediction")
    if st.button("Predict Product", type="primary", use_container_width=True):
        if not st.session_state.reactants or not st.session_state.reactants.strip():
            st.error("Error: Reactants field cannot be empty. Please provide a molecule.")
        else:
            with st.spinner("Running prediction..."):
                predictions = predict_product(
                    st.session_state.reactants,
                    st.session_state.reagents,
                    model,
                    tokenizer,
                    num_predictions,
                )

            st.header("3. Predicted Products")
            if not predictions:
                st.warning("The model did not return any predictions.")
            else:
                # Show each candidate as text and as a rendered structure.
                for i, product_smiles in enumerate(predictions):
                    st.subheader(f"Top Prediction #{i + 1}")
                    st.code(product_smiles, language="smiles")
                    display_molecule(product_smiles, f"Predicted Product #{i + 1}")
else:
    st.error("Application could not start because the model failed to load. Please check the error messages above.")