"""Gradio app: extract reactions (SMILES, conditions) from reaction-scheme images.

This fork runs the local MolScribe + OCR model only — no external API calls.

NOTE(review): the source arrived whitespace-mangled (whole program on a few
physical lines, HTML tags stripped from f-strings).  This file is a
reconstruction; the HTML markup emitted below is a best guess and should be
confirmed against the upstream RxnIM app if exact rendering matters.
"""

import os
import json
import re

import gradio as gr
import torch

from rxn.reaction import Reaction
from getReaction import generate_combined_image

# Directory holding the predefined task prompt files.
PROMPT_DIR = "prompts/"

# Local model checkpoint; inference is forced onto CPU.
ckpt_path = "./rxn/model/model.ckpt"
model = Reaction(ckpt_path, device=torch.device('cpu'))

# Static schematic image shown next to the results.
example_diagram = "examples/exp.png"

# Friendly display names for known prompt files; unknown files fall back to
# "Task: <stem>" in list_prompt_files_with_names().
PROMPT_NAMES = {
    "2_RxnOCR.txt": "Reaction Image Parsing Workflow",
}


def list_prompt_files_with_names():
    """Return {friendly task name: prompt filename} for all .txt prompts.

    Scans PROMPT_DIR; names come from PROMPT_NAMES when known, otherwise
    a "Task: <basename>" fallback is generated.
    """
    prompt_files = {}
    for fname in os.listdir(PROMPT_DIR):
        if fname.endswith(".txt"):
            friendly_name = PROMPT_NAMES.get(
                fname, f"Task: {os.path.splitext(fname)[0]}"
            )
            prompt_files[friendly_name] = fname
    return prompt_files


def format_local_predictions(predictions):
    """Convert local model predictions into the structured RXNIM-style format.

    Parameters
    ----------
    predictions : list of dict
        Per-reaction dicts with optional "reactants", "conditions" and
        "products" lists, as produced by ``model.predict_image_file``.
        (Assumed schema — inferred from the accesses below; confirm
        against the rxn package.)

    Returns
    -------
    (dict, list of str)
        ``reactions_data`` — {"reactions": [...]} with per-reaction
        reactant/condition/product entries, and ``smiles_list`` — one
        "reactants>>products" reaction SMILES per reaction that has both
        sides non-empty.
    """
    reactions_data = {"reactions": []}
    smiles_list = []

    for i, reaction in enumerate(predictions):
        rxn_entry = {
            "reaction_id": str(i + 1),
            "reactants": [],
            "conditions": [],
            "products": [],
        }

        # Reactants: keep only entries that actually carry a SMILES string.
        reactant_smiles = []
        for r in reaction.get("reactants", []):
            smiles = r.get("smiles", "")
            if smiles:
                rxn_entry["reactants"].append({"smiles": smiles, "label": "None"})
                reactant_smiles.append(smiles)

        # Conditions: OCR text may arrive as a list of fragments or a scalar.
        for c in reaction.get("conditions", []):
            text_parts = c.get("text", [])
            if isinstance(text_parts, list):
                text = " ".join(text_parts)
            else:
                text = str(text_parts)
            if text.strip():
                rxn_entry["conditions"].append(
                    {"role": "reagent", "text": text.strip()}
                )

        # Products: same filtering as reactants.
        product_smiles = []
        for p in reaction.get("products", []):
            smiles = p.get("smiles", "")
            if smiles:
                rxn_entry["products"].append({"smiles": smiles, "label": "None"})
                product_smiles.append(smiles)

        reactions_data["reactions"].append(rxn_entry)

        # A reaction SMILES is only meaningful with both sides present.
        if reactant_smiles and product_smiles:
            rxn_smiles = (
                f"{'.'.join(reactant_smiles)}>>{'.'.join(product_smiles)}"
            )
            smiles_list.append(rxn_smiles)

    return reactions_data, smiles_list


def format_reactions_html(reactions_data):
    """Render each reaction in ``reactions_data`` as an HTML snippet.

    Returns a list with one HTML string per reaction.
    NOTE(review): the <b>/<br> markup is reconstructed — the original tags
    were lost in extraction; verify against the upstream app.
    """
    detailed_output = []
    for reaction in reactions_data.get("reactions", []):
        reaction_id = reaction.get("reaction_id", "?")
        reactants = [r.get("smiles", "?") for r in reaction.get("reactants", [])]
        conditions = [
            f"{c.get('text', c.get('smiles', '?'))} [{c.get('role', '?')}]"
            for c in reaction.get("conditions", [])
        ]
        products = [f"{p.get('smiles', '?')}" for p in reaction.get("products", [])]

        html = f"<b>Reaction: {reaction_id}</b><br>"
        html += f"&nbsp;&nbsp;Reactants: {', '.join(reactants)}<br>"
        html += (
            f"&nbsp;&nbsp;Conditions: "
            f"{', '.join(conditions) if conditions else 'N/A'}<br>"
        )
        html += f"&nbsp;&nbsp;Products: {', '.join(products)}<br><br>"
        detailed_output.append(html)
    return detailed_output


def process_chem_image(image, selected_task):
    """Run the local extraction pipeline on an uploaded image.

    Parameters
    ----------
    image : PIL.Image.Image or None
        Uploaded reaction image.
    selected_task : str
        Selected task name (unused by the local pipeline; kept for
        interface compatibility with the Gradio inputs).

    Returns
    -------
    tuple
        (reactions HTML, reaction-SMILES text, visualization image path or
        None, schematic diagram path, JSON file path or None) — one entry
        per Gradio output component.  On failure, an error HTML snippet is
        returned in the first slot and the remaining slots are emptied.
    """
    import traceback

    try:
        if image is None:
            raise ValueError("No image provided")

        image_path = "temp_image.png"
        image.save(image_path)

        # Run local model (MolScribe + OCR) — no external API needed.
        predictions = model.predict_image_file(
            image_path, molscribe=True, ocr=True
        )

        # Structured data + display formats.
        reactions_data, smiles_list = format_local_predictions(predictions)
        detailed_reactions = format_reactions_html(reactions_data)

        # Visualization is best-effort; a drawing failure must not abort
        # the whole request.
        try:
            combined_image_path = generate_combined_image(predictions, image_path)
        except Exception:
            combined_image_path = None

        # Persist machine-readable output for the download widget.
        json_file_path = "output.json"
        with open(json_file_path, "w") as f:
            json.dump(reactions_data, f, indent=4)

        return (
            "\n\n".join(detailed_reactions),
            "\n".join(smiles_list),
            combined_image_path,
            example_diagram,
            json_file_path,
        )
    except Exception as e:
        # Surface the traceback in the UI instead of failing silently.
        tb = traceback.format_exc()
        error_html = f"<b>Error:</b> {str(e)}<br><pre>{tb}</pre>"
        return (error_html, "", None, example_diagram, None)


prompts_with_names = list_prompt_files_with_names()

examples = [
    ["examples/reaction1.png", "Reaction Image Parsing Workflow"],
    ["examples/reaction2.png", "Reaction Image Parsing Workflow"],
    ["examples/reaction3.png", "Reaction Image Parsing Workflow"],
    ["examples/reaction4.png", "Reaction Image Parsing Workflow"],
]

with gr.Blocks() as demo:
    gr.Markdown(
        """
# RxnIM — Chemical Reaction Image Mining

Upload a reaction image and extract SMILES, conditions, and structural data.
This fork uses the local MolScribe + OCR model (no external API required).
"""
    )

    with gr.Row(equal_height=False):
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Reaction Image")
            task_radio = gr.Radio(
                choices=list(prompts_with_names.keys()),
                label="Select a predefined task",
            )
            with gr.Row():
                clear_button = gr.Button("Clear")
                process_button = gr.Button("Run", elem_id="submit-btn")
            gr.Markdown("### Reaction Image Parsing Output")
            reaction_output = gr.HTML(label="Reaction outputs")

        with gr.Column(scale=1):
            gr.Markdown("### Reaction Extraction Output")
            visualization_output = gr.Image(label="Visualization Output")
            schematic_diagram = gr.Image(
                value=example_diagram, label="Schematic Diagram"
            )

        with gr.Column(scale=1):
            gr.Markdown("### Machine-readable Data Output")
            smiles_output = gr.Textbox(
                label="Reaction SMILES",
                show_copy_button=True,
                interactive=False,
            )
            download_json = gr.File(label="Download JSON File")

    gr.Examples(
        examples=examples,
        inputs=[image_input, task_radio],
        outputs=[reaction_output, smiles_output, visualization_output],
    )

    clear_button.click(
        lambda: (None, None, None, None, None),
        inputs=[],
        outputs=[
            image_input,
            task_radio,
            reaction_output,
            smiles_output,
            visualization_output,
        ],
    )

    process_button.click(
        process_chem_image,
        inputs=[image_input, task_radio],
        outputs=[
            reaction_output,
            smiles_output,
            visualization_output,
            schematic_diagram,
            download_json,
        ],
    )

demo.css = """
#submit-btn { background-color: #FF914D; color: white; font-weight: bold; }
"""

demo.launch(server_name="0.0.0.0", server_port=7860, share=True)