Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import T5ForConditionalGeneration, AutoTokenizer | |
| import torch | |
# Hugging Face Hub repository id of the fine-tuned T5 forward-prediction checkpoint.
model_name = "smitathkr1/ord-forward-t5"
print(f"Loading model: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# Use a CUDA GPU when PyTorch can see one; otherwise run on the CPU.
_use_gpu = torch.cuda.is_available()
device = torch.device("cuda") if _use_gpu else torch.device("cpu")

# nn.Module.to() moves the parameters in place and returns the module itself,
# so eval() (inference mode: dropout disabled) can be chained onto it.
model.to(device).eval()
def predict_reaction(reactants_smiles, max_length=150, num_beams=5, temperature=1.0):
    """
    Predict the product of a chemical reaction given reactants in SMILES format.

    Args:
        reactants_smiles: SMILES string of reactants (multiple molecules may be
            joined with '.'). Surrounding whitespace is stripped; ``None`` or a
            blank string is rejected without calling the model.
        max_length: Maximum length (in tokens) of the generated sequence.
            Coerced to ``int`` because UI sliders may deliver float values.
        num_beams: Number of beams for beam search. Coerced to ``int`` for the
            same reason; ``1`` means greedy decoding.
        temperature: Accepted for interface compatibility only. Generation runs
            with ``do_sample=False`` (deterministic beam/greedy search), where a
            temperature has no effect, so it is intentionally not forwarded to
            ``model.generate``.

    Returns:
        The predicted product SMILES, or a string starting with ``"Error: "``
        when the input is empty or generation fails.
    """
    try:
        input_text = (reactants_smiles or "").strip()
        if not input_text:
            # Nothing to predict; avoid running the model on an empty prompt.
            return "Error: Please enter reactant SMILES."

        beams = int(num_beams)

        # Tokenize the single input; anything beyond 512 tokens is truncated.
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        ).to(device)

        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                # Pass the mask explicitly instead of letting generate() infer it.
                attention_mask=inputs.get("attention_mask"),
                max_length=int(max_length),
                num_beams=beams,
                # early_stopping only matters for beam search; with a single
                # beam transformers merely warns about it.
                early_stopping=beams > 1,
                do_sample=False,
            )

        # Batch of one input: take the first returned sequence (the
        # highest-scoring beam when beam search is used).
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # UI boundary: report the failure in the output box, but also log the
        # traceback to the server logs rather than swallowing it silently.
        import traceback

        traceback.print_exc()
        return f"Error: {e}"
# Example inputs from ORD dataset. Each row is
# [reactants, max_length, num_beams, temperature], in the same order as the
# parameters of predict_reaction; every row reuses that function's defaults.
_EXAMPLE_GEN_SETTINGS = (150, 5, 1.0)
examples = [
    [reactants, *_EXAMPLE_GEN_SETTINGS]
    for reactants in (
        "CC(C)N1CCN(C)CC1.Brc1ccccc1",  # Buchwald-Hartwig amination with an aryl bromide
        "CCN1CCNCC1.Ic1ccccc1",  # coupling with an aryl iodide
        "CC(=O)N1CCNCC1.Clc1ccccc1",  # coupling with an aryl chloride
    )
]
# Markdown rendered above the interface.
_DESCRIPTION = """
## Forward Reaction Prediction using T5
This model predicts chemical reaction products from reactants using a T5 model trained on 2.3M reactions from the Open Reaction Database (ORD).
**Model:** `smitathkr1/ord-forward-t5` (5 epochs completed)
**Dataset:** [smitathkr1/ord-reactions](https://huggingface.co/datasets/smitathkr1/ord-reactions)
### How to use:
1. Enter reactants in SMILES format (separate multiple reactants with '.')
2. Adjust generation parameters if needed
3. Click Submit to get the predicted product
### Example reactions:
- Buchwald-Hartwig amination reactions
- Various coupling reactions from the ORD database
"""

# Markdown rendered below the interface.
_ARTICLE = """
### About the Model
This T5 model was trained on 2.3 million reactions from the Open Reaction Database.
The training has completed 5 epochs so far.
### Citation
If you use this model, please cite the Open Reaction Database:
- [Open Reaction Database](https://open-reaction-database.org/)
### Notes
- Input should be valid SMILES strings
- The model predicts forward reactions (reactants → products)
- Adjust beam search parameters for different prediction strategies
"""


def _build_interface():
    """Create the Gradio Interface that wraps ``predict_reaction``.

    The input components are listed in the same order as the parameters of
    ``predict_reaction`` (reactants, max_length, num_beams, temperature),
    because ``gr.Interface`` passes their values to ``fn`` positionally.
    """
    reactants_box = gr.Textbox(
        label="Reactants (SMILES)",
        placeholder=(
            "Enter reactants in SMILES format "
            "(e.g., CC(C)N1CCN(C)CC1.Brc1ccccc1)"
        ),
        lines=2,
    )
    max_length_slider = gr.Slider(
        minimum=50, maximum=300, value=150, step=10, label="Max Length"
    )
    num_beams_slider = gr.Slider(
        minimum=1, maximum=10, value=5, step=1, label="Num Beams"
    )
    temperature_slider = gr.Slider(
        minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature"
    )
    product_box = gr.Textbox(label="Predicted Product (SMILES)", lines=2)

    return gr.Interface(
        fn=predict_reaction,
        inputs=[reactants_box, max_length_slider, num_beams_slider, temperature_slider],
        outputs=product_box,
        examples=examples,
        title="🧪 ORD Forward Reaction Prediction - T5 Model",
        description=_DESCRIPTION,
        article=_ARTICLE,
        theme=gr.themes.Soft(),
        allow_flagging="never",
    )


# Create Gradio interface
iface = _build_interface()
def _main():
    """Launch the app: serve ``iface`` with Gradio's default launch settings."""
    iface.launch()


if __name__ == "__main__":
    _main()