| import gradio as gr |
| import json |
| from main import ChemEagle |
| from rdkit import Chem |
| from rdkit.Chem import rdChemReactions, Draw, AllChem |
| from rdkit.Chem.Draw import rdMolDraw2D |
| import cairosvg |
| import re |
| import os |
|
|
# Image shown in the "Schematic Diagram" panel; the processing callbacks also
# return this same path as their diagram output.
example_diagram = "examples/exp.png"
# Placeholder picture used wherever no RDKit rendering is available
# (before any result exists, or when a SMILES string cannot be drawn).
rdkit_image = "examples/rdkit.png"
|
|
|
|
def parse_reactions(output_json):
    """Format ChemEagle reaction output as HTML snippets and reaction SMILES.

    Args:
        output_json: The ChemEagle result, either a dict or a JSON string
            that decodes to one. Reactions are read from its "reactions" key.

    Returns:
        A ``(detailed_output, smiles_output)`` tuple of equal-length lists.
        ``detailed_output`` holds one HTML snippet per reaction;
        ``smiles_output`` holds the matching plain ``reactants>>products``
        string (never HTML-escaped, so it stays machine-readable).

    Raises:
        json.JSONDecodeError: If ``output_json`` is a string that is not
            valid JSON.
    """
    # Imported locally so the function carries its own dependency.
    from html import escape

    def as_text(value, default="Unknown"):
        """Return ``value`` as text, treating a JSON null like a missing key."""
        return default if value is None else str(value)

    def as_html(text):
        """Escape ``&``, ``<`` and ``>`` so extracted text cannot inject markup."""
        return escape(text, quote=False)

    def spans(color, items):
        """Wrap every item in a ``<span>`` of the given colour."""
        return [f"<span style='color:{color}'>{item}</span>" for item in items]

    if isinstance(output_json, str):
        reactions_data = json.loads(output_json)
    else:
        reactions_data = output_json

    detailed_output = []
    smiles_output = []

    for reaction in reactions_data.get("reactions") or []:
        reaction_id = as_text(reaction.get("reaction_id"), "Unknown ID")
        reactants = [as_text(r.get("smiles")) for r in reaction.get("reactants") or []]
        products = [as_text(p.get("smiles")) for p in reaction.get("products") or []]

        # Results may use either the plural or the singular key.
        conds = reaction.get("conditions")
        if conds is None:
            conds = reaction.get("condition") or []
        condition_labels = []
        for cond in conds:
            # Prefer the SMILES; fall back to the free-text description.
            name = cond.get("smiles")
            if name is None:
                name = cond.get("text")
            label = f"{as_text(name)}[{as_text(cond.get('role'))}]"
            condition_labels.append(as_html(label))

        additional = reaction.get("additional_info") or []
        if isinstance(additional, str):
            # A bare string would otherwise be iterated character by character.
            additional = [additional]
        additional_html = [as_html(str(x)) for x in additional if x]

        reactants_html = [as_html(s) for s in reactants]
        products_html = [as_html(s) for s in products]

        tail_str = ", ".join(spans("black", condition_labels) + additional_html)
        full_reaction = (
            ".".join(reactants_html)
            + ">>"
            + ".".join(spans("black", products_html))
            + " | "
            + tail_str
        )
        full_reaction = f"<span style='color:black'>{full_reaction}</span>"

        reaction_output = f"<b>Reaction: </b> {as_html(reaction_id)}<br>"
        reaction_output += f" Reactants: <span style='color:blue'>{', '.join(reactants_html)}</span><br>"
        reaction_output += f" Conditions: {', '.join(spans('red', condition_labels))}<br>"
        reaction_output += f" Products: {', '.join(spans('orange', products_html))}<br>"
        reaction_output += f" additional_info: {', '.join(additional_html)}<br>"
        reaction_output += f" <b>Full Reaction:</b> {full_reaction}<br><br>"
        detailed_output.append(reaction_output)

        smiles_output.append(".".join(reactants) + ">>" + ".".join(products))

    return detailed_output, smiles_output
|
|
|
|
def parse_mol(output_json):
    """Format ChemEagle molecule output as HTML snippets and plain SMILES.

    Handles results describing one or several molecules and returns the same
    ``(detailed_output, smiles_output)`` pair shape as ``parse_reactions``.

    Args:
        output_json: The ChemEagle result, either a dict or a JSON string
            that decodes to one. Molecules are read from its "molecules" key.

    Returns:
        A ``(detailed_output, smiles_output)`` tuple of equal-length lists:
        one HTML snippet and one SMILES string per molecule.

    Raises:
        json.JSONDecodeError: If ``output_json`` is a string that is not
            valid JSON.
    """
    # Imported locally so the function carries its own dependency.
    from html import escape

    if isinstance(output_json, str):
        mols_data = json.loads(output_json)
    else:
        mols_data = output_json

    detailed_output = []
    smiles_output = []

    for i, mol in enumerate(mols_data.get("molecules") or []):
        # A JSON null is treated exactly like a missing key.
        smiles = mol.get("smiles")
        smiles = "Unknown" if smiles is None else str(smiles)
        label = mol.get("label")
        label = f"Mol {i+1}" if label is None else str(label)
        bbox = mol.get("bbox")
        if bbox is None:
            bbox = []

        # Escape extracted text so it cannot inject markup into the HTML panel.
        mol_output = (
            f"<b>Molecule:</b> {escape(label, quote=False)}<br>"
            f" SMILES: <span style='color:blue'>{escape(smiles, quote=False)}</span><br>"
            f" bbox: {escape(str(bbox), quote=False)}<br><br>"
        )
        detailed_output.append(mol_output)
        smiles_output.append(smiles)

    return detailed_output, smiles_output
|
|
|
|
|
|
def process_chem_image(image):
    """Run ChemEagle on an uploaded image and build the values for the UI.

    Args:
        image: The uploaded picture. It must provide ``save(path)`` (the UI
            passes a PIL image), or be ``None`` when nothing was uploaded.

    Returns:
        A 4-tuple ``(html, smiles, diagram, json_path)``: the per-item HTML
        snippets joined by blank lines, the list of SMILES strings, the
        schematic diagram path, and the path of a JSON file containing the
        raw ChemEagle result.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    # Imported locally so the function carries its own dependency.
    import tempfile

    if image is None:
        # Surface a readable message instead of an AttributeError on .save().
        raise gr.Error("Please upload a chemical graphic first.")

    # A unique file per call keeps concurrent requests from overwriting each
    # other's input image.
    fd, image_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        image.save(image_path)
        chemeagle_result = ChemEagle(image_path)
    finally:
        # NOTE(review): assumes ChemEagle no longer needs the file once it
        # has returned -- confirm against its implementation.
        try:
            os.remove(image_path)
        except OSError:
            pass  # best-effort cleanup; a leftover temp file is harmless

    if "molecules" in chemeagle_result:
        detailed, smiles = parse_mol(chemeagle_result)
    else:
        detailed, smiles = parse_reactions(chemeagle_result)

    # The download keeps the name "output.json" but lives in its own temp
    # directory, so simultaneous users never receive each other's results.
    # The file has to outlive this call so the UI can serve it.
    json_path = os.path.join(tempfile.mkdtemp(prefix="chemeagle_"), "output.json")
    with open(json_path, "w", encoding="utf-8") as jf:
        json.dump(chemeagle_result, jf, indent=2)

    return "\n\n".join(detailed), smiles, example_diagram, json_path
|
|
|
|
|
|
def process_chem_image_api(image, api_key, endpoint):
    """Store the submitted API settings, then run the image pipeline.

    Args:
        image: The uploaded picture, forwarded to ``process_chem_image``.
        api_key: Value of the "Azure API Key" field; ``None`` or empty is
            stored as an empty string.
        endpoint: Value of the "Azure Endpoint" field; ``None`` or empty is
            stored as an empty string.

    Returns:
        The 4-tuple produced by ``process_chem_image(image)``.
    """
    # The settings are handed over through process-wide environment
    # variables, so concurrent requests share whichever values were written
    # last. Presumably they are read by ChemEagle -- verify against main.py.
    os.environ["API_KEY"] = api_key or ""
    os.environ["AZURE_ENDPOINT"] = endpoint or ""

    # The remaining steps are the same as the key-less pipeline, so delegate
    # rather than keeping a second copy of it here.
    return process_chem_image(image)
|
|
# NOTE(review): the indentation of this section was reconstructed from the
# statement order (the source had lost its leading whitespace). Confirm the
# nesting of the columns, examples and event handlers against the running app.


def _split_smiles_field(field):
    """Split the hidden textbox value into individual SMILES strings.

    The value is expected to look like the text of a Python list, e.g.
    ``"['CCO>>CC=O', 'CCN']"``; list brackets and quotes are stripped from
    every comma-separated item.
    """
    items = []
    for raw in field.split(","):
        item = re.sub(r"^\s*\[?'?|']?\s*$", "", raw)
        items.append(item.replace('"', '').replace("'", ""))
    return items


def _draw_reaction_png(rxn, out_dir, index):
    """Draw ``rxn`` without atom-map numbers and return the PNG file path."""
    # Rebuild the reaction from mol blocks of its reactants and products.
    cleaned_rxn = AllChem.ChemicalReaction()
    for mol in rxn.GetReactants():
        mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))
        cleaned_rxn.AddReactantTemplate(mol)
    for mol in rxn.GetProducts():
        mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))
        cleaned_rxn.AddProductTemplate(mol)

    # Clear atom-map numbers before drawing.
    for react in cleaned_rxn.GetReactants():
        for atom in react.GetAtoms():
            atom.SetAtomMapNum(0)
    for prod in cleaned_rxn.GetProducts():
        for atom in prod.GetAtoms():
            atom.SetAtomMapNum(0)

    # Take the drawing scale from the first reactant that has bonds (only
    # the first two are considered); fall back to 1.0 otherwise.
    react0 = cleaned_rxn.GetReactantTemplate(0)
    react1 = cleaned_rxn.GetReactantTemplate(1) if cleaned_rxn.GetNumReactantTemplates() > 1 else None
    if react0.GetNumBonds() > 0:
        bond_len = Draw.MeanBondLength(react0)
    elif react1 and react1.GetNumBonds() > 0:
        bond_len = Draw.MeanBondLength(react1)
    else:
        bond_len = 1.0

    drawer = rdMolDraw2D.MolDraw2DSVG(-1, -1)
    dopts = drawer.drawOptions()
    dopts.padding = 0.1
    dopts.includeRadicals = True
    Draw.SetACS1996Mode(dopts, bond_len * 0.55)
    dopts.bondLineWidth = 1.5
    drawer.DrawReaction(cleaned_rxn)
    drawer.FinishDrawing()
    svg = drawer.GetDrawingText()

    # The SVG is written to disk first and then converted to PNG.
    svg_file = os.path.join(out_dir, f"reaction_{index}.svg")
    with open(svg_file, "w") as f:
        f.write(svg)
    png_file = os.path.join(out_dir, f"reaction_{index}.png")
    cairosvg.svg2png(url=svg_file, write_to=png_file)
    return png_file


def _draw_molecule_png(smiles, out_dir, index):
    """Draw one molecule; return the PNG path, or None if parsing fails."""
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        return None
    img = Draw.MolToImage(mol, size=(350, 150))
    img_file = os.path.join(out_dir, f"mol_{index}.png")
    img.save(img_file)
    return img_file


with gr.Blocks() as demo:
    gr.Markdown(
        """
<center><h1>ChemEAGLE: A Multi-Agent System Enables Versatile Information Extraction from the Chemical Literature</h1></center>
Upload a chemical graphic to extract machine-readable chemical data.
"""
    )

    with gr.Row():
        # Left column: inputs and action buttons.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload a chemical graphic")
            api_key_input = gr.Textbox(label="Azure API Key", type="password", placeholder="Enter your Azure API Key")
            endpoint_input = gr.Textbox(label="Azure Endpoint", placeholder="e.g. https://xxx.openai.azure.com/")

            with gr.Row():
                clear_btn = gr.Button("Clear")
                run_btn = gr.Button("Run", elem_id="submit-btn")

        # Middle column: human-readable result and schematic diagram.
        with gr.Column(scale=1):
            gr.Markdown("### Parsed Reactions")
            reaction_output = gr.HTML(label="Detailed Reaction Output")
            gr.Markdown("### Schematic Diagram")
            schematic_diagram = gr.Image(value=example_diagram, label="示意图")

        # Right column: machine-readable output.
        with gr.Column(scale=1):
            gr.Markdown("### Machine-readable Output")
            # Hidden carrier for the SMILES list; show_split below re-renders
            # from its value.
            smiles_output = gr.Textbox(
                label="Reaction SMILES",
                show_copy_button=True,
                interactive=False,
                visible=False
            )

            @gr.render(inputs=smiles_output)
            def show_split(inputs):
                """Show one SMILES textbox and one RDKit picture per item."""
                # Imported locally so the block carries its own dependency.
                import tempfile

                # "[]" is the text of an empty result list; treat it like
                # "no result" instead of rendering a bogus "]" item.
                if not inputs or (isinstance(inputs, str) and inputs.strip() in ("", "[]")):
                    return gr.Textbox(label="SMILES of Reaction or Molecule i"), gr.Image(value=rdkit_image, label="RDKit Image", height=100)

                # Per-render directory, so sessions never overwrite each
                # other's pictures.
                out_dir = tempfile.mkdtemp(prefix="chemeagle_render_")
                components = []
                for i, smiles_clean in enumerate(_split_smiles_field(inputs)):
                    components.append(gr.Textbox(value=smiles_clean, label=f"SMILES of Item {i}", show_copy_button=True, interactive=False))

                    # An item counts as a reaction only if it parses as one
                    # and has at least one product; otherwise it is handled
                    # as a single molecule.
                    try:
                        rxn = rdChemReactions.ReactionFromSmarts(smiles_clean, useSmiles=True)
                        is_rxn = rxn is not None and rxn.GetNumProductTemplates() > 0
                    except Exception:
                        is_rxn = False

                    if is_rxn:
                        try:
                            png_file = _draw_reaction_png(rxn, out_dir, i)
                            components.append(gr.Image(value=png_file, label=f"RDKit Image of Reaction {i}"))
                        except Exception as e:
                            print(f"Failed to draw reaction {i} for SMILES '{smiles_clean}': {e}")
                            # Keep textbox/picture pairs aligned, as the
                            # molecule branch does.
                            components.append(gr.Image(value=rdkit_image, label="Invalid Reaction"))
                    else:
                        try:
                            img_file = _draw_molecule_png(smiles_clean, out_dir, i)
                            if img_file:
                                components.append(gr.Image(value=img_file, label=f"RDKit Image of Molecule {i}"))
                            else:
                                components.append(gr.Image(value=rdkit_image, label="Invalid Molecule"))
                        except Exception as e:
                            print(f"Failed to draw molecule {i} for SMILES '{smiles_clean}': {e}")
                            components.append(gr.Image(value=rdkit_image, label="Invalid Molecule"))
                return components

            download_json = gr.File(label="Download JSON File")

    gr.Examples(
        examples=[
            ["examples/reaction0.png"],
            ["examples/reaction1.jpg"],
            ["examples/reaction2.png"],
            ["examples/reaction3.png"],
            ["examples/reaction4.png"],
            ["examples/reaction5.png"],
            ["examples/template1.png"],
            ["examples/molecules1.png"],
        ],
        inputs=[image_input],
        outputs=[reaction_output, smiles_output, schematic_diagram, download_json],
        cache_examples=False,
        examples_per_page=10,
    )

    # Reset the input and every result widget except the schematic diagram.
    clear_btn.click(
        lambda: (None, None, None, None),
        inputs=[],
        outputs=[image_input, reaction_output, smiles_output, download_json]
    )

    run_btn.click(
        process_chem_image_api,
        inputs=[image_input, api_key_input, endpoint_input],
        outputs=[reaction_output, smiles_output, schematic_diagram, download_json]
    )


# Styling for the "Run" button (elem_id "submit-btn").
demo.css = """
#submit-btn {
background-color: #FF914D;
color: white;
font-weight: bold;
}
"""


demo.launch()
|
|