# RxnIM / app.py
# Source: Hugging Face Space (commit 1f4d8b3, "Add error handling to
# process_chem_image for debugging").
import os
import gradio as gr
import json
import torch
from rxn.reaction import Reaction
import re
from getReaction import generate_combined_image
# Directory containing the prompt .txt files that define the selectable tasks.
PROMPT_DIR = "prompts/"
# Pretrained reaction-parsing checkpoint; the model is loaded once at import
# time and shared by all requests. CPU-only by design (no GPU on the Space).
ckpt_path = "./rxn/model/model.ckpt"
model = Reaction(ckpt_path, device=torch.device('cpu'))
# Static schematic image shown next to the results in the UI.
example_diagram = "examples/exp.png"
# Maps prompt filename -> human-friendly task name shown in the task selector.
PROMPT_NAMES = {
    "2_RxnOCR.txt": "Reaction Image Parsing Workflow",
}
def list_prompt_files_with_names():
    """Return a mapping of friendly task name -> prompt filename.

    Scans ``PROMPT_DIR`` for ``.txt`` files. Files listed in
    ``PROMPT_NAMES`` get their curated display name; any other prompt
    falls back to ``"Task: <basename>"``.

    The directory listing is sorted so the radio-button choices built
    from this dict are stable across runs (``os.listdir`` order is
    otherwise arbitrary and filesystem-dependent).
    """
    prompt_files = {}
    for fname in sorted(os.listdir(PROMPT_DIR)):
        if fname.endswith(".txt"):
            friendly_name = PROMPT_NAMES.get(fname, f"Task: {os.path.splitext(fname)[0]}")
            prompt_files[friendly_name] = fname
    return prompt_files
def format_local_predictions(predictions):
    """
    Convert local model predictions to the same structured format
    that the original RXNIM/GPT-4o pipeline would produce.

    Returns a ``(reactions_data, smiles_list)`` pair: the structured
    dict used for JSON export / HTML rendering, and one reaction
    SMILES string ("r1.r2>>p1.p2") per reaction that has at least one
    reactant and one product.
    """
    reactions_data = {"reactions": []}
    smiles_list = []
    for idx, rxn in enumerate(predictions, start=1):
        entry = {
            "reaction_id": str(idx),
            "reactants": [],
            "conditions": [],
            "products": [],
        }

        # Reactants: keep only molecules with a non-empty SMILES.
        reactant_smiles = [
            m.get("smiles", "") for m in rxn.get("reactants", []) if m.get("smiles", "")
        ]
        entry["reactants"] = [{"smiles": s, "label": "None"} for s in reactant_smiles]

        # Conditions: OCR text may arrive as a list of fragments or a scalar.
        for cond in rxn.get("conditions", []):
            raw = cond.get("text", [])
            text = " ".join(raw) if isinstance(raw, list) else str(raw)
            text = text.strip()
            if text:
                entry["conditions"].append({"role": "reagent", "text": text})

        # Products: same filtering as reactants.
        product_smiles = [
            m.get("smiles", "") for m in rxn.get("products", []) if m.get("smiles", "")
        ]
        entry["products"] = [{"smiles": s, "label": "None"} for s in product_smiles]

        reactions_data["reactions"].append(entry)

        # Only emit a reaction SMILES when both sides are non-empty.
        if reactant_smiles and product_smiles:
            smiles_list.append(
                ".".join(reactant_smiles) + ">>" + ".".join(product_smiles)
            )
    return reactions_data, smiles_list
def format_reactions_html(reactions_data):
    """Format reactions data as HTML for display.

    Returns one HTML snippet per reaction: id, reactants (blue),
    conditions (red, or "N/A" when absent), and products (orange).
    """
    rendered = []
    for rxn in reactions_data.get("reactions", []):
        rid = rxn.get("reaction_id", "?")
        reactant_text = ", ".join(
            r.get("smiles", "?") for r in rxn.get("reactants", [])
        )
        condition_spans = [
            "<span style='color:red'>{} [{}]</span>".format(
                c.get("text", c.get("smiles", "?")), c.get("role", "?")
            )
            for c in rxn.get("conditions", [])
        ]
        product_spans = [
            "<span style='color:orange'>{}</span>".format(p.get("smiles", "?"))
            for p in rxn.get("products", [])
        ]
        parts = [
            f"<b>Reaction: </b> {rid}<br>",
            f" Reactants: <span style='color:blue'>{reactant_text}</span><br>",
            " Conditions: " + (", ".join(condition_spans) if condition_spans else "N/A") + "<br>",
            " Products: " + ", ".join(product_spans) + "<br><br>",
        ]
        rendered.append("".join(parts))
    return rendered
def process_chem_image(image, selected_task):
    """Run the local reaction-parsing pipeline on an uploaded image.

    Parameters:
        image: PIL image from the Gradio input, or None if nothing uploaded.
        selected_task: friendly task name from the radio control.
            NOTE(review): currently unused — the local pipeline runs the
            same way regardless of the selected task; confirm intended.

    Returns a 5-tuple matching the Gradio outputs:
        (reactions HTML, newline-joined reaction SMILES, visualization
         image path or None, schematic diagram path, JSON file path).
    On any failure, returns an HTML-formatted error + traceback in the
    first slot instead of raising (so the UI shows the error).
    """
    import traceback
    try:
        # Persist the upload to disk because the model API takes a file path.
        # NOTE(review): fixed filename — concurrent requests overwrite each
        # other's temp image and output.json; confirm single-user assumption.
        image_path = "temp_image.png"
        if image is None:
            raise ValueError("No image provided")
        image.save(image_path)
        # Run local model (MolScribe + OCR) — no external API needed
        predictions = model.predict_image_file(image_path, molscribe=True, ocr=True)
        # Format predictions into structured data
        reactions_data, smiles_list = format_local_predictions(predictions)
        # Generate HTML output
        detailed_reactions = format_reactions_html(reactions_data)
        # Generate visualization — best-effort: textual results are still
        # returned even if drawing the combined image fails.
        try:
            combined_image_path = generate_combined_image(predictions, image_path)
        except Exception:
            combined_image_path = None
        # Save structured results as a downloadable JSON file.
        json_file_path = "output.json"
        with open(json_file_path, "w") as f:
            json.dump(reactions_data, f, indent=4)
        return (
            "\n\n".join(detailed_reactions),
            "\n".join(smiles_list),
            combined_image_path,
            example_diagram,
            json_file_path,
        )
    except Exception as e:
        # Top-level boundary: surface the full traceback in the UI for
        # debugging (this broad catch is deliberate, per commit message).
        tb = traceback.format_exc()
        error_html = f"<b>Error:</b> {str(e)}<br><pre>{tb}</pre>"
        return (error_html, "", None, example_diagram, None)
# Build the task-selector choices once at startup (friendly name -> filename).
prompts_with_names = list_prompt_files_with_names()
# (image path, task name) pairs shown in the Gradio examples gallery.
examples = [
    ["examples/reaction1.png", "Reaction Image Parsing Workflow"],
    ["examples/reaction2.png", "Reaction Image Parsing Workflow"],
    ["examples/reaction3.png", "Reaction Image Parsing Workflow"],
    ["examples/reaction4.png", "Reaction Image Parsing Workflow"],
]
# --- Gradio UI: three-column layout + event wiring -------------------------
with gr.Blocks() as demo:
    # Page header / usage notes.
    gr.Markdown("""
<center><h1>RxnIM — Chemical Reaction Image Mining</h1></center>
Upload a reaction image and extract SMILES, conditions, and structural data.
This fork uses the local MolScribe + OCR model (no external API required).
""")
    with gr.Row(equal_height=False):
        # Left column: inputs, run/clear buttons, textual parsing output.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Reaction Image")
            task_radio = gr.Radio(
                choices=list(prompts_with_names.keys()),
                label="Select a predefined task",
            )
            with gr.Row():
                clear_button = gr.Button("Clear")
                process_button = gr.Button("Run", elem_id="submit-btn")
            gr.Markdown("### Reaction Image Parsing Output")
            reaction_output = gr.HTML(label="Reaction outputs")
        # Middle column: model visualization plus the static schematic.
        with gr.Column(scale=1):
            gr.Markdown("### Reaction Extraction Output")
            visualization_output = gr.Image(label="Visualization Output")
            schematic_diagram = gr.Image(value=example_diagram, label="Schematic Diagram")
        # Right column: machine-readable outputs (SMILES text + JSON download).
        with gr.Column(scale=1):
            gr.Markdown("### Machine-readable Data Output")
            smiles_output = gr.Textbox(
                label="Reaction SMILES",
                show_copy_button=True,
                interactive=False,
            )
            download_json = gr.File(label="Download JSON File")
    # Clickable example inputs.
    # NOTE(review): without fn=/cache_examples, gr.Examples only pre-fills the
    # inputs; the listed outputs are not populated on click — confirm intended.
    gr.Examples(
        examples=examples,
        inputs=[image_input, task_radio],
        outputs=[reaction_output, smiles_output, visualization_output],
    )
    # Reset both inputs and the textual/visual outputs.
    clear_button.click(
        lambda: (None, None, None, None, None),
        inputs=[],
        outputs=[image_input, task_radio, reaction_output, smiles_output, visualization_output],
    )
    # Main pipeline trigger: 2 inputs -> 5 outputs of process_chem_image.
    process_button.click(
        process_chem_image,
        inputs=[image_input, task_radio],
        outputs=[reaction_output, smiles_output, visualization_output, schematic_diagram, download_json],
    )
# Minimal styling for the submit button.
# NOTE(review): recent Gradio versions expect css= passed to gr.Blocks(...);
# confirm this post-hoc attribute assignment still takes effect.
demo.css = """
#submit-btn {
    background-color: #FF914D;
    color: white;
    font-weight: bold;
}
"""
# Binds all interfaces on port 7860; share=True additionally opens a public
# gradio.live tunnel.
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)