import os
import gradio as gr
import json
import torch
from rxn.reaction import Reaction
import re
from getReaction import generate_combined_image
# Directory containing one .txt prompt file per supported workflow/task.
PROMPT_DIR = "prompts/"
# Checkpoint for the local reaction-extraction model (no external API needed).
ckpt_path = "./rxn/model/model.ckpt"
# Model is loaded once at import time and forced onto CPU.
model = Reaction(ckpt_path, device=torch.device('cpu'))
# Static explanatory diagram shown next to every result in the UI.
example_diagram = "examples/exp.png"
# Curated display names for known prompt files; anything else gets a
# generated "Task: <basename>" label (see list_prompt_files_with_names).
PROMPT_NAMES = {
    "2_RxnOCR.txt": "Reaction Image Parsing Workflow",
}
def list_prompt_files_with_names():
    """Map a human-friendly task name to each ``.txt`` prompt file.

    Scans ``PROMPT_DIR`` for ``.txt`` files. Files present in
    ``PROMPT_NAMES`` use their curated display name; any other file falls
    back to ``"Task: <basename>"``.

    Returns:
        dict[str, str]: friendly name -> prompt filename.
    """
    return {
        PROMPT_NAMES.get(fname, f"Task: {os.path.splitext(fname)[0]}"): fname
        for fname in os.listdir(PROMPT_DIR)
        if fname.endswith(".txt")
    }
def format_local_predictions(predictions):
    """
    Convert local model predictions to the same structured format
    that the original RXNIM/GPT-4o pipeline would produce.

    Args:
        predictions: list of per-reaction dicts with optional "reactants",
            "conditions" and "products" lists.

    Returns:
        tuple: (reactions_data, smiles_list) where reactions_data is a dict
        holding a "reactions" list, and smiles_list contains one reaction
        SMILES string ("r1.r2>>p1.p2") for every reaction that has at least
        one reactant and one product.
    """

    def _take_smiles(items, dest):
        # Copy non-empty SMILES into dest (with the fixed "None" label)
        # and return the kept strings for SMILES assembly.
        kept = []
        for item in items:
            smi = item.get("smiles", "")
            if smi:
                dest.append({"smiles": smi, "label": "None"})
                kept.append(smi)
        return kept

    reactions = []
    smiles_list = []
    for idx, rxn in enumerate(predictions, start=1):
        entry = {
            "reaction_id": str(idx),
            "reactants": [],
            "conditions": [],
            "products": [],
        }
        reactant_smiles = _take_smiles(rxn.get("reactants", []), entry["reactants"])
        # OCR may return condition text as a list of fragments or a scalar.
        for cond in rxn.get("conditions", []):
            raw = cond.get("text", [])
            text = " ".join(raw) if isinstance(raw, list) else str(raw)
            text = text.strip()
            if text:
                entry["conditions"].append({"role": "reagent", "text": text})
        product_smiles = _take_smiles(rxn.get("products", []), entry["products"])
        reactions.append(entry)
        # Only reactions with both sides yield a reaction SMILES.
        if reactant_smiles and product_smiles:
            smiles_list.append(
                ".".join(reactant_smiles) + ">>" + ".".join(product_smiles)
            )
    return {"reactions": reactions}, smiles_list
def format_reactions_html(reactions_data):
    """Format reactions data as HTML for display.

    Fix: the original string literals were split across physical lines
    (a syntax error — the HTML ``<br>`` line breaks had been mangled away);
    they are restored here as proper single-line f-strings.

    Args:
        reactions_data: dict with a "reactions" list, as produced by
            format_local_predictions.

    Returns:
        list[str]: one HTML snippet per reaction, with reactants,
        conditions and products each on their own line.
    """
    detailed_output = []
    for reaction in reactions_data.get("reactions", []):
        reaction_id = reaction.get("reaction_id", "?")
        reactants = [r.get("smiles", "?") for r in reaction.get("reactants", [])]
        # Conditions may carry OCR text or a SMILES; show whichever exists
        # plus the assigned role in brackets.
        conditions = [
            f"{c.get('text', c.get('smiles', '?'))} [{c.get('role', '?')}]"
            for c in reaction.get("conditions", [])
        ]
        products = [p.get("smiles", "?") for p in reaction.get("products", [])]
        html = f"Reaction: {reaction_id}<br>"
        html += f" Reactants: {', '.join(reactants)}<br>"
        html += f" Conditions: {', '.join(conditions) if conditions else 'N/A'}<br>"
        html += f" Products: {', '.join(products)}<br>"
        detailed_output.append(html)
    return detailed_output
def process_chem_image(image, selected_task):
    """Run the local reaction-parsing pipeline on an uploaded image.

    Fix: the error-path f-string was split across physical lines (a syntax
    error — its mangled HTML line break is restored as ``<br>``/``<pre>``).

    Args:
        image: PIL image from the Gradio upload widget, or None.
        selected_task: task name chosen in the UI; currently unused by the
            local pipeline but kept for interface compatibility.

    Returns:
        tuple: (reactions HTML, newline-joined reaction SMILES,
        annotated-image path or None, example diagram path,
        JSON output path or None on error).
    """
    import traceback
    try:
        if image is None:
            raise ValueError("No image provided")
        image_path = "temp_image.png"
        image.save(image_path)
        # Run local model (MolScribe + OCR) — no external API needed.
        predictions = model.predict_image_file(image_path, molscribe=True, ocr=True)
        # Format predictions into structured data, then into display HTML.
        reactions_data, smiles_list = format_local_predictions(predictions)
        detailed_reactions = format_reactions_html(reactions_data)
        # Visualization is best-effort: a drawing failure must not hide
        # the successfully parsed results.
        try:
            combined_image_path = generate_combined_image(predictions, image_path)
        except Exception:
            combined_image_path = None
        # Persist structured output so the UI can offer it as a download.
        json_file_path = "output.json"
        with open(json_file_path, "w") as f:
            json.dump(reactions_data, f, indent=4)
        return (
            "\n\n".join(detailed_reactions),
            "\n".join(smiles_list),
            combined_image_path,
            example_diagram,
            json_file_path,
        )
    except Exception as e:
        # Surface the full traceback in the UI so failures are debuggable.
        tb = traceback.format_exc()
        error_html = f"Error: {str(e)}<br><pre>{tb}</pre>"
        return (error_html, "", None, example_diagram, None)
# Friendly-name -> filename mapping used to populate the task selector.
prompts_with_names = list_prompt_files_with_names()
# Gallery examples: [image path, task display name] pairs for the demo UI.
examples = [
    ["examples/reaction1.png", "Reaction Image Parsing Workflow"],
    ["examples/reaction2.png", "Reaction Image Parsing Workflow"],
    ["examples/reaction3.png", "Reaction Image Parsing Workflow"],
    ["examples/reaction4.png", "Reaction Image Parsing Workflow"],
]
with gr.Blocks() as demo:
gr.Markdown("""