Spaces:
Runtime error
Runtime error
| import os | |
| import sys | |
| import warnings | |
| import subprocess | |
| import argparse | |
| import json | |
| import pandas as pd | |
| import gradio as gr | |
| import datamol as dm | |
| from rdkit import RDLogger | |
| from typing import Dict, Any, Optional | |
| from transformers import GenerationConfig | |
# Silence tokenizer fork warnings and noisy RDKit / fingerprint deprecation output.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore", message="DEPRECATION WARNING: please use MorganGenerator")
RDLogger.DisableLog('rdApp.*')
from boring_utils.utils import cprint, tprint, get_device
from boring_utils.helpers import DEBUG
# ==============================
# Config
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, default="checkpoint/fraglm_llama_240710/checkpoint-500000", help='Path to the model')
parser.add_argument('--tokenizer_path', type=str, default="tokenizer/fraglm_2406_bpe_8k.json", help='Path to the tokenizer')
args = parser.parse_args()
# HF_SPACE switches to the Hub checkpoint and /data paths. Any non-empty value enables it
# except the usual "off" spellings, so HF_SPACE=false / 0 really disables it (a raw
# os.getenv string would be truthy for those too). The result is a real bool, which is
# also what demo.launch(share=...) expects.
HF_SPACE = os.getenv('HF_SPACE', '').strip().lower() not in {'', '0', 'false', 'no', 'off'}
SHARE_SPACE = HF_SPACE
# Email gating is on unless REQUIRE_EMAIL is explicitly something other than "true".
REQUIRE_EMAIL = os.getenv('REQUIRE_EMAIL', 'True').lower() == 'true'
HF_MODEL = "YDS-Pharmatech/FragLlama-base"
HF_TOKENIZER_PATH = "/data/fraglm/tokenizer/fraglm_2406_bpe_8k.json"
LOCAL_MODEL = args.model_path
LOCAL_TOKENIZER_PATH = args.tokenizer_path
device = get_device()
# ==============================
# Load Model
# ==============================
def install_and_import(package, base_dir="/data"):
    """Import *package*, installing it from a local checkout if it is not importable.

    The checkout is expected at ``{base_dir}/{package}``. If that directory contains a
    ``.git`` entry it is ``git pull``-ed first. This update is best effort: any failure,
    including a missing ``git`` executable, only prints a warning. If the import then
    fails, the checkout is installed with ``pip install --no-deps -e`` and the import
    is retried.

    Args:
        package: importable module name; also the checkout directory name.
        base_dir: directory that contains the checkout (default ``/data``).

    Returns:
        The imported module object.

    Raises:
        subprocess.CalledProcessError: if the pip install fails.
        ImportError: if the module is still not importable after installing.
    """
    import importlib
    package_path = os.path.join(base_dir, package)
    # Always try to update the repository first if it exists
    if os.path.exists(os.path.join(package_path, '.git')):
        print(f"Updating {package} repository...")
        try:
            subprocess.check_call(['git', '-C', package_path, 'pull'])
            print(f"Successfully updated {package}")
        except (subprocess.CalledProcessError, OSError) as e:
            # OSError covers a missing `git` binary, which must not abort start-up.
            print(f"Warning: Failed to update {package}: {e}")
    try:
        # Try to import after potential update
        return importlib.import_module(package)
    except ImportError:
        print(f"{package} not found, attempting to install...")
    # Install the package
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-deps", "-e", package_path])
    print(f"{package} installed successfully")
    # Path finders cache directory listings; without this the freshly installed
    # package can stay invisible to the import system in this process.
    importlib.invalidate_caches()
    return importlib.import_module(package)
# Make the `fraglm` package importable before anything below imports from it.
if HF_SPACE:
    # TODO: move the tmp csv to the docker temp folder
    # NOTE(review): /data/tmp is created here but not referenced elsewhere in this file;
    # generated CSVs are currently written to the working directory.
    os.makedirs("/data/tmp", exist_ok=True)
    # The checkout lives at /data/fraglm: put it on sys.path, then update/install it.
    sys.path.append("/data/fraglm")
    # os.chdir("/data/fraglm")
    fraglm = install_and_import("fraglm")
else:
    # Locally, run from the project root - presumably so the relative default
    # --model_path / --tokenizer_path values resolve against it (TODO confirm).
    from fraglm.constants import PROJECT_HOME_DIR; os.chdir(PROJECT_HOME_DIR)
from fraglm.inference import FragLMDesign
# NOTE(review): the two star imports presumably provide Chem, execute_function,
# send_result and the UI text constants (VIDEO_MESSAGE, PLACEHOLDER_*, ABOUT_MESSAGE,
# EMAIL_VERIFIED_MESSAGE) used below, none of which are defined in this file - confirm.
from fraglm.utils import *
from fraglm.trainer.model import FragLMLlamaModel
from fraglm.inference.post_processing import PostProcessMode, PostProcessConfig
from fraglm.ui_tools import *
if DEBUG:
    # Show where fraglm.inference was resolved from (useful when install paths clash).
    import importlib.util
    spec = importlib.util.find_spec("fraglm.inference")
    print(f"fraglm.inference spec: {spec}")
    # print(f"Installed packages: {subprocess.check_output([sys.executable, '-m', 'pip', 'list']).decode()}")
# Load the start-up model (Hub checkpoint on HF_SPACE, local checkpoint otherwise), move
# it to the selected device, and wrap it in the default designer.
model = (
    FragLMLlamaModel.from_pretrained(HF_MODEL, token=os.getenv('HF_TOKEN'))
    if HF_SPACE
    else FragLMLlamaModel.from_pretrained(LOCAL_MODEL)
).to(device)
designer = FragLMDesign(
    model=model,
    tokenizer=HF_TOKENIZER_PATH if HF_SPACE else LOCAL_TOKENIZER_PATH,
)
# Baseline generation settings derived from the model config; user overrides are merged onto this.
DEFAULT_GEN_CONFIG = GenerationConfig.from_model_config(model.config).to_dict()
def parse_generation_config(config_str: str, default_config: Optional[Dict[str, Any]] = None) -> GenerationConfig:
    """
    Parse the generation config string and create a GenerationConfig object.
    Allows partial overwrite of the default config.

    Args:
        config_str: JSON object whose keys override the defaults; empty or None means
            "no overrides".
        default_config: base settings as a plain dict. Defaults to DEFAULT_GEN_CONFIG
            (derived from the loaded model's config). It is copied, never mutated.

    Returns:
        A GenerationConfig built from the defaults merged with the overrides. Invalid
        JSON, a JSON value that is not an object, or overrides that GenerationConfig
        rejects with ValueError all fall back to the defaults, with a printed warning.
    """
    # Copy so neither DEFAULT_GEN_CONFIG nor the caller's dict is modified.
    base = dict(DEFAULT_GEN_CONFIG if default_config is None else default_config)
    if not config_str:
        return GenerationConfig(**base)
    try:
        overrides = json.loads(config_str)
    except json.JSONDecodeError as err:
        # If parsing fails, return the default config
        print(f"Warning: invalid generation config JSON, using defaults: {err}")
        return GenerationConfig(**base)
    if not isinstance(overrides, dict):
        # e.g. "null", "[1]" or a bare string: nothing sensible to merge.
        print(f"Warning: generation config must be a JSON object, using defaults: {config_str!r}")
        return GenerationConfig(**base)
    try:
        return GenerationConfig(**{**base, **overrides})
    except ValueError as err:
        print(f"Warning: generation config overrides rejected, using defaults: {err}")
        return GenerationConfig(**base)
# ==============================
# Inference Code
# ==============================
def create_designer(gen_config_str):
    """Build a new FragLMDesign whose generation settings come from *gen_config_str*.

    Args:
        gen_config_str: JSON generation-config overrides, parsed with
            parse_generation_config (which always returns a GenerationConfig, so it is
            always passed through to the designer).

    Returns:
        A FragLMDesign wrapping a freshly loaded model: the Hub checkpoint on HF_SPACE,
        the local checkpoint otherwise.
    """
    gen_config = parse_generation_config(gen_config_str)
    # NOTE(review): this reloads the weights on every call; reusing the module-level
    # `model` would be much cheaper if FragLMDesign does not mutate it - confirm.
    if HF_SPACE:
        model = FragLMLlamaModel.from_pretrained(HF_MODEL, token=os.getenv('HF_TOKEN')).to(device)
        tokenizer_path = HF_TOKENIZER_PATH
    else:
        model = FragLMLlamaModel.from_pretrained(LOCAL_MODEL).to(device)
        tokenizer_path = LOCAL_TOKENIZER_PATH
    return FragLMDesign(model=model, tokenizer=tokenizer_path, generation_config=gen_config)
_GENERATION_FAILED_MESSAGE = "Generation failed - no valid molecules produced"
# Generation-config string the global `designer` was last rebuilt from; None means it is
# still the designer created at start-up.
_designer_config_str: Optional[str] = None


def _canonical_smiles(smiles: Optional[str]) -> Optional[str]:
    """Return *smiles* as canonical non-isomeric SMILES, or None if it is blank or unparsable."""
    if not smiles or not smiles.strip():
        return None
    mol = Chem.MolFromSmiles(smiles.strip())
    if mol is None:
        return None
    return Chem.MolToSmiles(mol, isomericSmiles=False)


def _failure_outputs(message: str = _GENERATION_FAILED_MESSAGE):
    """Return the five Gradio outputs for a failed run: no image, *message*, send disabled, no CSV."""
    # There is no CSV to send, so the "Send Results" button is disabled.
    return None, message, gr.Button(interactive=False), gr.Textbox(value=""), None


def _refresh_designer(gen_config_str: Optional[str]) -> None:
    """Rebuild the global `designer` when *gen_config_str* differs from the one it was built from."""
    global designer, _designer_config_str
    # The UI sends '{}' on every click by default. Rebuilding only when the string
    # changes avoids reloading the model weights on every request, while switching
    # back to '{}' after a custom config still restores the defaults.
    if gen_config_str and gen_config_str != _designer_config_str:
        designer = create_designer(gen_config_str)
        _designer_config_str = gen_config_str


def _split_post_process_params(extra_params_dict: Optional[dict]):
    """Separate the post-processing settings from the remaining designer kwargs.

    Works on a copy, so the caller's dict is never mutated. Returns
    ``(post_process_mode, remaining_params)``, where ``post_process_mode`` is a
    PostProcessConfig for "AGGRESSIVE_CONNECT" and PostProcessMode.SELECT_LONGEST otherwise.
    """
    params = dict(extra_params_dict or {})
    mode = params.pop("post_process_mode", "SELECT_LONGEST")
    if mode == "AGGRESSIVE_CONNECT":
        config = PostProcessConfig(
            mode=PostProcessMode.AGGRESSIVE_CONNECT,
            scaffold=params.pop("post_process_scaffold", None),
            num_attempts=params.pop("post_process_num_attempts", 5),
        )
        return config, params
    return PostProcessMode.SELECT_LONGEST, params


def _valid_smiles(generated) -> list:
    """Return the non-empty, RDKit-parsable SMILES from *generated* (a list/tuple), else []."""
    if not generated or not isinstance(generated, (list, tuple)):
        return []
    return [s for s in generated if s and Chem.MolFromSmiles(s) is not None]


def _render_molecules(smiles: list, highlight_smiles: list):
    """Draw *smiles* as a grid image, lasso-highlighting *highlight_smiles* (parsed as SMARTS).

    Falls back to a plain grid image if highlighting fails for any reason.
    """
    try:
        mols = [dm.to_mol(s) for s in smiles]
        queries = [dm.from_smarts(q) for q in highlight_smiles]
        return dm.viz.lasso_highlight_image(
            mols,
            queries[0] if len(queries) == 1 else queries,
            mol_size=(350, 200),
            color_list=["#ff80b5"],
            scale_padding=0.1,
            use_svg=False,
            n_cols=4,
        )
    except Exception as e:
        print(f"Visualization error: {e}")
        return dm.to_image(smiles, mol_size=(350, 200), use_svg=False)


def _save_results_csv(task: str, smiles: list, labels: list) -> str:
    """Write *smiles* to ``generated_{task}_smiles_{timestamp}_{labels}.csv`` and return the path."""
    timestamp = pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')
    label_part = "_".join(label[:20] for label in labels)
    # NOTE(review): written to the current working directory both locally and on HF_SPACE.
    csv_path = f'generated_{task}_smiles_{timestamp}_{label_part}.csv'
    pd.DataFrame({'SMILES': smiles}).to_csv(csv_path, index=False)
    return csv_path


def _finish_generation(generated, n_samples_per_trial, highlight_smiles: list, task: str, labels: list):
    """Validate designer output and build the five Gradio outputs.

    Returns (image, statistics message, send-button update, send-status reset, CSV path),
    or the failure outputs when no valid molecule was produced.
    """
    if DEBUG:
        tprint(f"UI Results Debug Info")
        cprint(f"Generated SMILES: {generated}")
        cprint(f"Type: {type(generated)}")
        cprint(f"Length: {len(generated) if generated else 0}")
    valid = _valid_smiles(generated)
    if not valid:
        tprint(f"UI Generation failed - no valid molecules produced", sep="*")
        return _failure_outputs()
    success_rate = len(valid) / n_samples_per_trial
    success_message = f"Success Rate: {success_rate:.1%} ({len(valid)}/{n_samples_per_trial})"
    img = _render_molecules(valid, highlight_smiles)
    csv_path = _save_results_csv(task, valid, labels)
    # Sending requires a verified email. When email is not required the send buttons
    # are meant to stay disabled, so only enable them when REQUIRE_EMAIL is on.
    return img, success_message, gr.Button(interactive=REQUIRE_EMAIL), gr.Textbox(value=""), csv_path


def scaffold_hopping(scaffold1, scaffold2, n_samples_per_trial, extra_params_dict: dict, gen_config_str: Optional[str] = None):
    """Scaffold hopping function using scaffold morphing.

    Args:
        scaffold1, scaffold2: scaffold SMILES. They are canonicalised (non-isomeric) and
            joined with "." as the designer's ``side_chains``.
        n_samples_per_trial: number of molecules requested.
        extra_params_dict: extra designer kwargs from the UI (not mutated). Post-processing
            keys are turned into a post-processing mode.
        gen_config_str: JSON generation-config overrides; a changed value rebuilds the designer.

    Returns:
        (image, statistics message, send-button update, send-status reset, CSV path or None).
    """
    tprint(f"UI Scaffold Hopping Debug Info")
    cprint(f"Input scaffold1: {scaffold1}")
    cprint(f"Input scaffold2: {scaffold2}")
    cprint(f"Samples requested: {n_samples_per_trial}")
    cprint(f"Extra params: {json.dumps(extra_params_dict, indent=2)}")
    cprint(f"Generation config: {gen_config_str}")
    canonical1 = _canonical_smiles(scaffold1)
    canonical2 = _canonical_smiles(scaffold2)
    if canonical1 is None or canonical2 is None:
        return _failure_outputs("Invalid input SMILES - please check both scaffolds")
    _refresh_designer(gen_config_str)
    post_process_mode, params = _split_post_process_params(extra_params_dict)
    # Dict merge (not direct kwargs) so user params may override e.g. `sanitize`.
    kwargs = {
        'side_chains': f"{canonical1}.{canonical2}",
        'n_samples_per_trial': n_samples_per_trial,
        'sanitize': True,
        'post_process_mode': post_process_mode,
        **params,
    }
    generated_smiles = execute_function(designer, 'scaffold_hopping', **kwargs)
    return _finish_generation(generated_smiles, n_samples_per_trial, [canonical1], 'scaffold', [canonical1, canonical2])


def fragment_growth(motif, n_samples_per_trial, extra_params_dict: dict, gen_config_str: Optional[str] = None):
    """Fragment growth function: grow complete molecules from a single fragment.

    Args:
        motif: fragment SMILES, canonicalised (non-isomeric) before use.
        n_samples_per_trial: number of molecules requested.
        extra_params_dict: extra designer kwargs from the UI (not mutated).
        gen_config_str: JSON generation-config overrides; a changed value rebuilds the designer.

    Returns:
        (image, statistics message, send-button update, send-status reset, CSV path or None).
    """
    tprint(f"UI Fragment Growth Debug Info")
    cprint(f"Input motif: {motif}")
    cprint(f"Samples requested: {n_samples_per_trial}")
    cprint(f"Extra params: {json.dumps(extra_params_dict, indent=2)}")
    cprint(f"Generation config: {gen_config_str}")
    canonical_motif = _canonical_smiles(motif)
    if canonical_motif is None:
        return _failure_outputs("Invalid input SMILES - please check the fragment")
    _refresh_designer(gen_config_str)
    post_process_mode, params = _split_post_process_params(extra_params_dict)
    kwargs = {
        'motif': canonical_motif,
        'n_samples_per_trial': n_samples_per_trial,
        'sanitize': True,
        'post_process_mode': post_process_mode,
        **params,
    }
    generated_smiles = execute_function(designer, 'fragment_growth', **kwargs)
    return _finish_generation(generated_smiles, n_samples_per_trial, [canonical_motif], 'motif', [canonical_motif])


def linker_design(linker1, linker2, n_samples_per_trial, extra_params_dict: dict, gen_config_str: Optional[str] = None):
    """Linker design function: generate molecules that connect two fragments.

    Args:
        linker1, linker2: fragment SMILES to be linked, canonicalised (non-isomeric).
        n_samples_per_trial: number of molecules requested.
        extra_params_dict: extra designer kwargs from the UI (not mutated). Its string
            post_process_mode is translated instead of overriding the mode object.
        gen_config_str: JSON generation-config overrides; a changed value rebuilds the designer.

    Returns:
        (image, statistics message, send-button update, send-status reset, CSV path or None).
    """
    tprint(f"UI Linker Design Debug Info")
    cprint(f"Input linker1: {linker1}")
    cprint(f"Input linker2: {linker2}")
    cprint(f"Samples requested: {n_samples_per_trial}")
    cprint(f"Extra params: {json.dumps(extra_params_dict, indent=2)}")
    cprint(f"Generation config: {gen_config_str}")
    canonical1 = _canonical_smiles(linker1)
    canonical2 = _canonical_smiles(linker2)
    if canonical1 is None or canonical2 is None:
        return _failure_outputs("Invalid input SMILES - please check both fragments")
    _refresh_designer(gen_config_str)
    post_process_mode, params = _split_post_process_params(extra_params_dict)
    kwargs = {
        'n_samples_per_trial': n_samples_per_trial,
        'sanitize': True,
        'random_seed': 100,
        'post_process_mode': post_process_mode,
        **params,
    }
    # The two fragments are passed via the `groups` keyword.
    generated_smiles = execute_function(designer, 'linker_design', groups=[canonical1, canonical2], **kwargs)
    return _finish_generation(generated_smiles, n_samples_per_trial, [canonical1, canonical2], 'linker', [canonical1, canonical2])
# TODO: change verify email to submit?
def verify_email(email):
    """Check that *email* looks like a single address of the form ``local@domain.tld``.

    This is a format check only; nothing is sent.

    Returns:
        ``(True, EMAIL_VERIFIED_MESSAGE)`` for a plausible address, otherwise
        ``(False, "Invalid email format")``. Empty or None input counts as invalid.
    """
    import re
    # Exactly one "@", no whitespace, and a dot inside the domain part with text on both
    # sides. A bare "." anywhere (e.g. "a.b@host") is not enough.
    if email and re.fullmatch(r"[^@\s]+@[^@\s]+\.[^@\s]+", email.strip()):
        return True, EMAIL_VERIFIED_MESSAGE
    return False, "Invalid email format"
# ==============================
# UI
# ==============================
# The whole Gradio app; components are laid out in creation order.
with gr.Blocks(theme=gr.themes.Citrus()) as demo:
    gr.Markdown("# FragLlama Demo")
    gr.HTML(VIDEO_MESSAGE)
    # Email gate: the row is shown only when REQUIRE_EMAIL is set. Otherwise the hidden
    # textbox is pre-filled with a placeholder address.
    with gr.Row(visible=REQUIRE_EMAIL):
        # NOTE(review): the submit button label says "Send result", but submitting this box
        # only runs update_button_states (email check / button enabling).
        email_input = gr.Textbox(
            label="",
            placeholder="Enter your email to unlock generation",
            type="email",
            submit_btn="Send result to my Email",
            value="" if REQUIRE_EMAIL else "disabled@example.com"
        )
    # Global generation config
    # Hidden holder for the generation-config JSON passed to every task callback. The
    # visible editor in the "Advanced Global Settings" tab copies its text into it.
    gen_config_input = gr.Textbox(
        label="Generation Config (JSON format)",
        placeholder='{"max_length": 200}',
        value='{}',
        visible=False
    )
# Common parameter creation function
def create_common_params(show_aggressive_gen=False):
    """Create the sample-count slider and the "Advanced Options" accordion for a task tab.

    Must be called inside a Gradio layout context; the components are created in place.

    Args:
        show_aggressive_gen: currently unused - the long-molecule checkbox is always hidden.

    Returns:
        (n_samples_per_trial, do_not_fragment, min_length, max_length, extra_params,
        extra_dict, aggressive_gen) Gradio components.
    """
    # Number of molecules to generate in one run
    n_samples_per_trial = gr.Slider(1, 100, 20, step=1, label="Number of generated molecules")
    with gr.Accordion("Advanced Options", open=False):
        # Minimum number of atoms in generated molecules
        min_length = gr.Number(
            value=10,
            label="Min Length",
            info="Minimum number of atoms in generated molecules",
            maximum=50
        )
        # Maximum number of atoms in generated molecules
        max_length = gr.Number(
            value=80,
            label="Max Length",
            info="Maximum number of atoms in generated molecules",
            maximum=120
        )
        # Whether to keep input fragments intact without further fragmentation
        do_not_fragment = gr.Checkbox(
            label="Keep Input Fragments Intact",
            value=False,
            info="If checked, input fragments will be kept intact without further breaking down",
            visible=False
        )
        # Experimental option for generating longer molecules (hidden for now)
        aggressive_gen = gr.Checkbox(
            label="(Experimental) Long Molecule Generation",
            value=False,
            info="Enable aggressive connection mode for generating longer molecules",
            # visible=show_aggressive_gen,
            visible=False
        )
        # Additional parameters in JSON format
        extra_params = gr.Textbox(
            label="Extra Parameters (JSON format)",
            # The example must be valid JSON: booleans are bare false/true. A quoted
            # "False" is a truthy string, and an unquoted `value` would fail to parse.
            placeholder='{"sanitize": false, "other_param": "value"}',
            info="Additional parameters in JSON format for advanced control"
        )
        # Hidden JSON field for storing combined parameters
        extra_dict = gr.JSON(
            value={},  # Empty until update_extra_dict fills it from the widgets above
            visible=False  # Hide this from UI
        )
    return n_samples_per_trial, do_not_fragment, min_length, max_length, extra_params, extra_dict, aggressive_gen
def visualize_input(smiles):
    """Render a preview image for the SMILES typed into an input box.

    Returns the image from dm.to_image, or None when the box is empty, the SMILES does
    not parse, or drawing fails. The preview is best effort and must never break the UI.
    """
    if not smiles:
        return None
    try:
        mol = dm.to_mol(smiles)
        if mol is None:
            return None
        return dm.to_image(mol, mol_size=(350, 200), use_svg=False)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
        return None
# Update extra_dict whenever advanced parameters change
def update_extra_dict(do_not_fragment, min_length, max_length, extra_params, aggressive_gen=False, scaffold=None):
    """Combine the advanced-option widget values into the kwargs dict for a task callback.

    Args:
        do_not_fragment: stored as ``do_not_fragment_further``.
        min_length, max_length: stored unchanged.
        extra_params: JSON object string whose keys override everything above. Invalid
            JSON or a non-object value is ignored with a printed warning.
        aggressive_gen: if truthy, request AGGRESSIVE_CONNECT post-processing.
        scaffold: optional SMILES recorded as ``post_process_scaffold`` in aggressive mode.

    Returns:
        A new dict with do_not_fragment_further, min_length, max_length and
        post_process_mode (plus post_process_scaffold / post_process_num_attempts in
        aggressive mode), updated with the extra_params overrides.
    """
    extra_dict = {
        "do_not_fragment_further": do_not_fragment,
        "min_length": min_length,
        "max_length": max_length,
    }
    # Add post_process_mode based on aggressive_gen
    if aggressive_gen:
        extra_dict["post_process_mode"] = "AGGRESSIVE_CONNECT"
        if scaffold:
            extra_dict["post_process_scaffold"] = scaffold
        extra_dict["post_process_num_attempts"] = 5
    else:
        extra_dict["post_process_mode"] = "SELECT_LONGEST"
    if not extra_params:
        return extra_dict
    # Update with any additional parameters from extra_params
    try:
        overrides = json.loads(extra_params)
    except json.JSONDecodeError as err:
        print(f"Warning: ignoring extra parameters that are not valid JSON: {err}")
        return extra_dict
    if isinstance(overrides, dict):
        extra_dict.update(overrides)
    else:
        # e.g. "[1, 2]" or "null" would otherwise crash dict.update
        print(f"Warning: ignoring extra parameters that are not a JSON object: {extra_params!r}")
    return extra_dict
# Scaffold Hopping tab
with gr.Tab("Scaffold Hopping"):
    # Two scaffold SMILES inputs, each with a live structure preview beside it.
    with gr.Row():
        with gr.Column():
            scaffold1_input = gr.Textbox(label="Scaffold 1")
            scaffold1_input.placeholder = PLACEHOLDER_SCAFFOLD1
        with gr.Column():
            scaffold1_preview = gr.Image(label="Input Preview", type="pil")
    scaffold1_input.change(
        fn=visualize_input,
        inputs=[scaffold1_input],
        outputs=[scaffold1_preview]
    )
    with gr.Row():
        with gr.Column():
            scaffold2_input = gr.Textbox(label="Scaffold 2")
            scaffold2_input.placeholder = PLACEHOLDER_SCAFFOLD2
        with gr.Column():
            scaffold2_preview = gr.Image(label="Input Preview", type="pil")
    scaffold2_input.change(
        fn=visualize_input,
        inputs=[scaffold2_input],
        outputs=[scaffold2_preview]
    )
    # Shared slider + "Advanced Options" accordion for this tab.
    (n_samples_per_trial,
     do_not_fragment, min_length, max_length,
     extra_params, extra_dict, aggressive_gen) = create_common_params(show_aggressive_gen=True)
    # Buttons start disabled; update_button_states (on email submit or page load)
    # enables them.
    scaffold_button = gr.Button("Generate", interactive=False)
    scaffold_output = gr.Image(type="pil", label="Examples of Generated Molecules")
    scaffold_success = gr.Textbox(label="Generation Statistics")
    scaffold_send = gr.Button("Send Results", interactive=False)
    scaffold_send_status = gr.Textbox(label="Send Status", value="")
    # Hidden: path of the most recent results CSV, read by send_result.
    scaffold_csv_path = gr.Textbox(visible=False)
    # Connect the update function
    # aggressive_gen is not an input here, so update_extra_dict always selects
    # SELECT_LONGEST for this tab.
    # NOTE(review): extra_dict starts as {} and is filled only when one of these widgets
    # changes, so untouched defaults (e.g. Min/Max Length) are not sent - confirm intended.
    for param in [do_not_fragment, min_length, max_length, extra_params]:
        param.change(
            fn=update_extra_dict,
            inputs=[do_not_fragment, min_length, max_length, extra_params],
            outputs=[extra_dict]
        )
    scaffold_button.click(
        scaffold_hopping,
        inputs=[
            scaffold1_input,
            scaffold2_input,
            n_samples_per_trial,
            extra_dict,
            gen_config_input,
        ],
        outputs=[
            scaffold_output,
            scaffold_success,
            scaffold_send,
            scaffold_send_status,
            scaffold_csv_path
        ]
    )
    # The hidden constant textbox tells send_result which task produced the CSV.
    scaffold_send.click(
        fn=send_result,
        inputs=[
            email_input,
            scaffold_csv_path,
            gr.Textbox(value="Scaffold Hopping", visible=False)
        ],
        outputs=[scaffold_send_status]
    )
# Fragment Growth tab
with gr.Tab("Fragment Growth"):
    # Single fragment SMILES input with a live structure preview.
    with gr.Row():
        with gr.Column():
            motif_input = gr.Textbox(label="Fragment")
            motif_input.placeholder = PLACEHOLDER_MOTIF
        with gr.Column():
            motif_preview = gr.Image(label="Input Preview", type="pil")
    motif_input.change(
        fn=visualize_input,
        inputs=[motif_input],
        outputs=[motif_preview]
    )
    # Shared slider + "Advanced Options" accordion for this tab.
    (n_samples_per_trial,
     do_not_fragment, min_length, max_length,
     extra_params, extra_dict, aggressive_gen) = create_common_params()
    # Buttons start disabled; update_button_states (on email submit or page load)
    # enables them.
    motif_button = gr.Button("Generate", interactive=False)
    motif_output = gr.Image(type="pil", label="Examples of Generated Molecules")
    motif_success = gr.Textbox(label="Generation Statistics")
    motif_send = gr.Button("Send Results", interactive=False)
    motif_send_status = gr.Textbox(label="Send Status", value="")
    # Hidden: path of the most recent results CSV, read by send_result.
    motif_csv_path = gr.Textbox(visible=False)
    # Connect the update function
    # Here aggressive_gen is wired in and the fragment itself is passed as the
    # post-processing scaffold.
    # NOTE(review): extra_dict is only refreshed when these widgets change (not when
    # motif_input changes), so the stored scaffold can be stale - confirm intended.
    for param in [do_not_fragment, min_length, max_length, extra_params, aggressive_gen]:
        param.change(
            fn=update_extra_dict,
            inputs=[do_not_fragment, min_length, max_length, extra_params, aggressive_gen, motif_input],
            outputs=[extra_dict]
        )
    motif_button.click(
        fragment_growth,
        inputs=[
            motif_input,
            n_samples_per_trial,
            extra_dict,
            gen_config_input,
        ],
        outputs=[
            motif_output,
            motif_success,
            motif_send,
            motif_send_status,
            motif_csv_path
        ]
    )
    # The hidden constant textbox tells send_result which task produced the CSV.
    motif_send.click(
        fn=send_result,
        inputs=[
            email_input,
            motif_csv_path,
            gr.Textbox(value="Fragment Growth", visible=False)
        ],
        outputs=[motif_send_status]
    )
# Linker Design tab
with gr.Tab("Linker Design"):
    # Two fragment SMILES inputs to be linked, each with a live structure preview.
    with gr.Row():
        with gr.Column():
            linker1_input = gr.Textbox(label="Linker 1")
            linker1_input.placeholder = PLACEHOLDER_LINKER1
        with gr.Column():
            linker1_preview = gr.Image(label="Input Preview", type="pil")
    linker1_input.change(
        fn=visualize_input,
        inputs=[linker1_input],
        outputs=[linker1_preview]
    )
    with gr.Row():
        with gr.Column():
            linker2_input = gr.Textbox(label="Linker 2")
            linker2_input.placeholder = PLACEHOLDER_LINKER2
        with gr.Column():
            linker2_preview = gr.Image(label="Input Preview", type="pil")
    linker2_input.change(
        fn=visualize_input,
        inputs=[linker2_input],
        outputs=[linker2_preview]
    )
    # Shared slider + "Advanced Options" accordion for this tab.
    (n_samples_per_trial,
     do_not_fragment, min_length, max_length,
     extra_params, extra_dict, aggressive_gen) = create_common_params()
    # Buttons start disabled; update_button_states (on email submit or page load)
    # enables them.
    linker_button = gr.Button("Generate", interactive=False)
    linker_output = gr.Image(type="pil", label="Examples of Generated Molecules")
    linker_success = gr.Textbox(label="Generation Statistics")
    linker_send = gr.Button("Send Results", interactive=False)
    linker_send_status = gr.Textbox(label="Send Status", value="")
    # Hidden: path of the most recent results CSV, read by send_result.
    linker_csv_path = gr.Textbox(visible=False)
    # Connect the update function
    # aggressive_gen is not an input here, so update_extra_dict always writes
    # post_process_mode "SELECT_LONGEST" for this tab.
    # NOTE(review): extra_dict starts as {} and is filled only when one of these widgets
    # changes, so untouched defaults are not sent - confirm intended.
    for param in [do_not_fragment, min_length, max_length, extra_params]:
        param.change(
            fn=update_extra_dict,
            inputs=[do_not_fragment, min_length, max_length, extra_params],
            outputs=[extra_dict]
        )
    linker_button.click(
        linker_design,
        inputs=[
            linker1_input,
            linker2_input,
            n_samples_per_trial,
            extra_dict,
            gen_config_input,
        ],
        outputs=[
            linker_output,
            linker_success,
            linker_send,
            linker_send_status,
            linker_csv_path
        ]
    )
    # The hidden constant textbox tells send_result which task produced the CSV.
    linker_send.click(
        fn=send_result,
        inputs=[
            email_input,
            linker_csv_path,
            gr.Textbox(value="Linker Design", visible=False)
        ],
        outputs=[linker_send_status]
    )
# Global generation-config editor. Its text is mirrored into the hidden gen_config_input,
# which every task callback receives.
with gr.Tab("Advanced Global Settings"):
    gr.Markdown("""
# Generation Config Settings
- Default config will be used if not specified
- You can partially override specific parameters
- Example: {"max_length": 200} will only override max_length
- Reference: https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
## Available Parameters
- max_length: Maximum length of generated sequence
- min_length: Minimum length of generated sequence
- temperature: Higher values produce more diverse outputs
- top_p: Nucleus sampling threshold
- top_k: Top-k sampling threshold
- ...
""")
    # gen_config_input.render()
    # Create a new textbox and store the reference
    config_editor = gr.Textbox(
        label="Generation Config (JSON format)",
        placeholder='{"max_length": 200}',
        value='{}',
        interactive=True,
    )
    # Use the reference in change event: copy the editor text verbatim into the hidden holder.
    config_editor.change(
        lambda x: x,
        inputs=[config_editor],
        outputs=[gen_config_input]
    )
# Static contact / about information.
with gr.Tab("Contact Us"):
    gr.Markdown(ABOUT_MESSAGE)
def update_button_states(email):
    """Enable or disable the generate and send buttons of all three task tabs.

    When REQUIRE_EMAIL is off, generation is always enabled and sending is always
    disabled. Otherwise *email* is checked with verify_email, the resulting message is
    shown with gr.Info, and all six buttons follow the verification result.

    Returns:
        Six gr.Button updates, in order: scaffold, motif and linker generate buttons,
        then scaffold, motif and linker send buttons.
    """
    if REQUIRE_EMAIL:
        is_valid, message = verify_email(email)
        gr.Info(message)
        generate_enabled = send_enabled = is_valid
    else:
        # Email verification disabled: allow generation, force-disable sending.
        generate_enabled, send_enabled = True, False
    generate_buttons = [gr.Button(interactive=generate_enabled) for _ in range(3)]
    send_buttons = [gr.Button(interactive=send_enabled) for _ in range(3)]
    return generate_buttons + send_buttons
# All six buttons governed by update_button_states, in the order it returns them.
_button_outputs = [scaffold_button, motif_button, linker_button, scaffold_send, motif_send, linker_send]
if not REQUIRE_EMAIL:
    # No email gate: apply the "verification disabled" button states on page load.
    demo.load(
        fn=lambda: update_button_states("disabled@example.com"),
        outputs=_button_outputs,
    )
# Submitting the email box re-runs the gate and updates every button.
email_input.submit(
    fn=update_button_states,
    inputs=[email_input],
    outputs=_button_outputs,
)
# Script entry point: start the Gradio server. SHARE_SPACE (set from HF_SPACE) is passed
# as launch()'s `share` flag, which requests a public share link.
if __name__ == "__main__":
    demo.launch(share=SHARE_SPACE)