| import os |
| import gradio as gr |
| import json |
| import torch |
| from rxn.reaction import Reaction |
| import re |
| from getReaction import generate_combined_image |
|
|
|
|
# --- Static configuration -------------------------------------------------

# Directory that list_prompt_files_with_names() scans for "*.txt" files.
PROMPT_DIR = "prompts/"

# Display names for known prompt files; files not listed here are shown
# with a generated "Task: <file stem>" label instead.
PROMPT_NAMES = {"2_RxnOCR.txt": "Reaction Image Parsing Workflow"}

# Image used as the initial value of the "Schematic Diagram" panel and
# returned again by process_chem_image() on every run.
example_diagram = "examples/exp.png"

# --- Model ----------------------------------------------------------------

# Checkpoint handed to the Reaction constructor below.
ckpt_path = "./rxn/model/model.ckpt"

# Built once at import time on the CPU device and shared by all requests.
model = Reaction(ckpt_path, device=torch.device("cpu"))
|
|
|
|
def list_prompt_files_with_names(prompt_dir=None, names=None):
    """Map a display name to every ``.txt`` prompt file in a directory.

    Args:
        prompt_dir: Directory to scan. Defaults to the module-level
            ``PROMPT_DIR``.
        names: Mapping of file name to display name. Defaults to the
            module-level ``PROMPT_NAMES``. Files without an entry are
            labelled ``"Task: <file name without extension>"``.

    Returns:
        A dict of display name -> file name, built in sorted file-name
        order so the result is the same on every run. An empty dict is
        returned when the directory does not exist.
    """
    if prompt_dir is None:
        prompt_dir = PROMPT_DIR
    if names is None:
        names = PROMPT_NAMES

    try:
        # os.listdir() returns entries in arbitrary order; sort them so the
        # choices offered to the user do not shuffle between runs.
        entries = sorted(os.listdir(prompt_dir))
    except FileNotFoundError:
        # No prompt directory simply means there are no tasks to offer.
        return {}

    return {
        names.get(entry, f"Task: {os.path.splitext(entry)[0]}"): entry
        for entry in entries
        if entry.endswith(".txt")
    }
|
|
|
|
def format_local_predictions(predictions):
    """Reshape local model predictions into the structured reaction format.

    The result mirrors the layout the original RXNIM/GPT-4o pipeline would
    produce.

    Args:
        predictions: Iterable of dicts that may carry "reactants",
            "conditions" and "products" lists of dicts.

    Returns:
        A ``(reactions_data, smiles_list)`` tuple. ``reactions_data`` is
        ``{"reactions": [...]}`` with one entry per prediction, numbered
        from "1". ``smiles_list`` holds a ``reactants>>products`` string for
        each prediction that has at least one reactant and one product
        SMILES.
    """

    def _molecules(items):
        # Keep only entries with a non-empty SMILES; return both the
        # structured entries and the bare SMILES strings.
        kept = [s for s in (item.get("smiles", "") for item in items) if s]
        return [{"smiles": s, "label": "None"} for s in kept], kept

    def _condition_text(condition):
        # "text" may be a list of fragments or a single value.
        raw = condition.get("text", [])
        joined = " ".join(raw) if isinstance(raw, list) else str(raw)
        return joined.strip()

    structured = []
    reaction_smiles = []

    for number, prediction in enumerate(predictions, start=1):
        reactants, lhs = _molecules(prediction.get("reactants", []))

        texts = [_condition_text(c) for c in prediction.get("conditions", [])]
        conditions = [{"role": "reagent", "text": t} for t in texts if t]

        products, rhs = _molecules(prediction.get("products", []))

        structured.append(
            {
                "reaction_id": str(number),
                "reactants": reactants,
                "conditions": conditions,
                "products": products,
            }
        )

        # A reaction SMILES needs both sides of the arrow.
        if lhs and rhs:
            reaction_smiles.append(".".join(lhs) + ">>" + ".".join(rhs))

    return {"reactions": structured}, reaction_smiles
|
|
|
|
def format_reactions_html(reactions_data):
    """Render structured reaction data as HTML fragments for display.

    Every value taken from ``reactions_data`` (reaction id, SMILES,
    condition text, role) is HTML-escaped before being interpolated, so
    characters such as ``<``, ``>`` and ``&`` in recognised text are shown
    literally instead of being interpreted as markup.

    Args:
        reactions_data: Dict with a "reactions" list; each reaction may
            carry "reaction_id", "reactants", "conditions" and "products".

    Returns:
        A list with one HTML string per reaction.
    """
    from html import escape

    def _safe(value):
        # Values are not guaranteed to be strings; escape() requires one.
        return escape(str(value))

    detailed_output = []
    for reaction in reactions_data.get("reactions", []):
        reaction_id = _safe(reaction.get("reaction_id", "?"))
        reactants = [
            _safe(r.get("smiles", "?")) for r in reaction.get("reactants", [])
        ]
        conditions = [
            "<span style='color:red'>"
            # A condition is shown by its text, falling back to its SMILES.
            f"{_safe(c.get('text', c.get('smiles', '?')))} [{_safe(c.get('role', '?'))}]"
            "</span>"
            for c in reaction.get("conditions", [])
        ]
        products = [
            f"<span style='color:orange'>{_safe(p.get('smiles', '?'))}</span>"
            for p in reaction.get("products", [])
        ]

        fragment = f"<b>Reaction: </b> {reaction_id}<br>"
        fragment += f" Reactants: <span style='color:blue'>{', '.join(reactants)}</span><br>"
        fragment += f" Conditions: {', '.join(conditions) if conditions else 'N/A'}<br>"
        fragment += f" Products: {', '.join(products)}<br><br>"
        detailed_output.append(fragment)

    return detailed_output
|
|
|
|
def process_chem_image(image, selected_task):
    """Run the local model on an uploaded image and build the UI outputs.

    Args:
        image: The uploaded image object (must provide ``save(path)``), or
            ``None`` when nothing was uploaded.
        selected_task: The task chosen in the UI. It is accepted to match
            the event signature but is not used by this function.

    Returns:
        A 5-tuple ``(reaction_html, reaction_smiles, visualization_path,
        schematic_path, json_path)``. On any failure the first element is an
        HTML error report (message plus traceback, HTML-escaped) and the
        visualization and JSON entries are ``None``.
    """
    import html
    import traceback

    try:
        if image is None:
            raise ValueError("No image provided")

        # NOTE(review): the input image and the JSON output use fixed file
        # names in the working directory, so they are shared by every
        # request and overwritten on each run.
        image_path = "temp_image.png"
        image.save(image_path)

        predictions = model.predict_image_file(image_path, molscribe=True, ocr=True)

        reactions_data, smiles_list = format_local_predictions(predictions)
        detailed_reactions = format_reactions_html(reactions_data)

        # The overlay image is best-effort: a failure here must not discard
        # the extracted data, but it is logged rather than silently dropped.
        try:
            combined_image_path = generate_combined_image(predictions, image_path)
        except Exception:
            traceback.print_exc()
            combined_image_path = None

        json_file_path = "output.json"
        with open(json_file_path, "w", encoding="utf-8") as f:
            json.dump(reactions_data, f, indent=4)

        return (
            "\n\n".join(detailed_reactions),
            "\n".join(smiles_list),
            combined_image_path,
            example_diagram,
            json_file_path,
        )
    except Exception as e:
        tb = traceback.format_exc()
        # Escape both parts: tracebacks routinely contain "<module>" and
        # similar text that would otherwise be parsed as HTML tags.
        error_html = (
            f"<b>Error:</b> {html.escape(str(e))}<br>"
            f"<pre>{html.escape(tb)}</pre>"
        )
        return (error_html, "", None, example_diagram, None)
|
|
|
|
# Display name -> prompt file name, computed once when the module loads;
# its keys become the choices of the task selector.
prompts_with_names = list_prompt_files_with_names()

# Rows offered by the examples widget: [image path, task display name].
_EXAMPLE_TASK = "Reaction Image Parsing Workflow"
examples = [
    [example_image, _EXAMPLE_TASK]
    for example_image in (
        "examples/reaction1.png",
        "examples/reaction2.png",
        "examples/reaction3.png",
        "examples/reaction4.png",
    )
]
|
|
# Styling for the "Run" button, matched through its elem_id. Supplied to the
# Blocks constructor rather than assigned to the instance afterwards.
CUSTOM_CSS = """
#submit-btn {
    background-color: #FF914D;
    color: white;
    font-weight: bold;
}
"""

# NOTE(review): the nesting below is reconstructed from the order of the
# context managers; confirm it matches the intended three-column layout.
with gr.Blocks(css=CUSTOM_CSS) as demo:
    gr.Markdown(
        "\n"
        "<center><h1>RxnIM — Chemical Reaction Image Mining</h1></center>\n"
        "\n"
        "Upload a reaction image and extract SMILES, conditions, and structural data.\n"
        "This fork uses the local MolScribe + OCR model (no external API required).\n"
    )

    with gr.Row(equal_height=False):
        # Column 1: inputs, action buttons and the per-reaction HTML report.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Reaction Image")
            task_radio = gr.Radio(
                choices=list(prompts_with_names.keys()),
                label="Select a predefined task",
            )
            with gr.Row():
                clear_button = gr.Button("Clear")
                process_button = gr.Button("Run", elem_id="submit-btn")

            gr.Markdown("### Reaction Image Parsing Output")
            reaction_output = gr.HTML(label="Reaction outputs")

        # Column 2: annotated image plus the static schematic.
        with gr.Column(scale=1):
            gr.Markdown("### Reaction Extraction Output")
            visualization_output = gr.Image(label="Visualization Output")
            schematic_diagram = gr.Image(value=example_diagram, label="Schematic Diagram")

        # Column 3: machine-readable results.
        with gr.Column(scale=1):
            gr.Markdown("### Machine-readable Data Output")
            smiles_output = gr.Textbox(
                label="Reaction SMILES",
                show_copy_button=True,
                interactive=False,
            )
            download_json = gr.File(label="Download JSON File")

    gr.Examples(
        examples=examples,
        inputs=[image_input, task_radio],
        outputs=[reaction_output, smiles_output, visualization_output],
    )

    # Reset the inputs and every per-run output, including the JSON download
    # so a file from a previous run is not left on offer. The schematic is
    # static and is left untouched.
    clear_button.click(
        lambda: (None, None, None, None, None, None),
        inputs=[],
        outputs=[
            image_input,
            task_radio,
            reaction_output,
            smiles_output,
            visualization_output,
            download_json,
        ],
    )

    process_button.click(
        process_chem_image,
        inputs=[image_input, task_radio],
        outputs=[reaction_output, smiles_output, visualization_output, schematic_diagram, download_json],
    )

# Started at import time: listens on all interfaces, port 7860, with a
# share link requested.
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
|
|