| import sys |
| sys.path.append('./rxn/') |
| import torch |
| from rxn.reaction import Reaction |
| import json |
| from matplotlib import pyplot as plt |
| import numpy as np |
|
|
# Module-level model setup. Building `Reaction` here means the model is
# constructed (presumably loading the checkpoint) as soon as this module is
# imported. The relative checkpoint path resolves against the current working
# directory, not this file's location.
ckpt_path = "./rxn/model/model.ckpt"
# Define the device once and hand it to the model so the two cannot drift apart.
device = torch.device('cpu')
model = Reaction(ckpt_path, device=device)
|
|
def get_reaction(image_path: str) -> str:
    """Run reaction prediction on an image file and return the result as JSON.

    Calls ``model.predict_image_file(image_path, molscribe=True, ocr=True)``
    on the module-level ``model``. The result is serialized with
    ``json.dumps``, so it must be JSON-serializable.

    Args:
        image_path: Path to the image file to analyse.

    Returns:
        The predictions encoded as a JSON string. (The previous ``-> list``
        annotation was wrong: ``json.dumps`` always returns ``str``.)
    """
    predictions = model.predict_image_file(image_path, molscribe=True, ocr=True)
    return json.dumps(predictions)
|
|
|
|
|
|
def _to_rgb_uint8(img):
    """Normalize one rendered image to a 3-channel array suitable for ``imshow``.

    - 2-D (grayscale) arrays and single-channel ``(H, W, 1)`` arrays are
      replicated into three channels.
    - Arrays with more than three channels keep only the first three
      (e.g. an alpha channel is dropped).
    - Floating-point arrays are scaled by 255, clipped to ``[0, 255]`` and
      cast to ``uint8``. This assumes float images are in ``[0, 1]``;
      NOTE(review): confirm against ``model.draw_predictions``.
    """
    img = np.asarray(img)
    if img.ndim == 3 and img.shape[2] == 1:
        # imshow rejects (H, W, 1); treat it as plain grayscale.
        img = img[:, :, 0]
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    elif img.shape[2] > 3:
        img = img[:, :, :3]
    if np.issubdtype(img.dtype, np.floating):
        # Clip before casting: out-of-range floats would otherwise wrap
        # around (or be undefined) when converted to uint8.
        img = np.clip(img * 255, 0, 255).astype(np.uint8)
    return img


def generate_combined_image(predictions, image_file, *,
                            output_path="combined_output.png", n_cols=1):
    """Render reaction predictions and tile them into one saved figure.

    Each image returned by
    ``model.draw_predictions(predictions, image_file=image_file)`` is drawn
    in its own subplot titled ``"### Reaction k ###"`` (1-based). The subplots
    are laid out on a grid with ``n_cols`` columns, and the figure is written
    to ``output_path``.

    Args:
        predictions: Predictions forwarded to ``model.draw_predictions``
            (presumably the output of ``model.predict_image_file`` — confirm).
        image_file: Path of the source image, forwarded as ``image_file=``.
        output_path: Where to save the combined image. Defaults to
            ``"combined_output.png"``, relative to the current working directory.
        n_cols: Number of subplot columns (default 1: one reaction per row).

    Returns:
        The path the combined image was saved to (``output_path``).

    Raises:
        ValueError: If ``n_cols`` < 1 or ``draw_predictions`` returns no images.
    """
    if n_cols < 1:
        raise ValueError(f"n_cols must be >= 1, got {n_cols}")

    images = [
        _to_rgb_uint8(img)
        for img in model.draw_predictions(predictions, image_file=image_file)
    ]
    n_images = len(images)
    if n_images == 0:
        # plt.subplots(0, ...) would fail with an opaque GridSpec error.
        raise ValueError("draw_predictions returned no images to combine")

    n_rows = -(-n_images // n_cols)  # ceiling division

    # squeeze=False always yields a 2-D array of Axes, whatever the grid size.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(20 * n_cols, 12 * n_rows), squeeze=False
    )
    try:
        axes = axes.ravel()
        for idx, (ax, img) in enumerate(zip(axes, images), start=1):
            ax.imshow(img)
            ax.axis('off')
            ax.set_title(f"### Reaction {idx} ###", fontsize=42)
        # Drop grid cells left empty when n_images is not a multiple of n_cols.
        for ax in axes[n_images:]:
            fig.delaxes(ax)
        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # Always release the figure so failures don't leak pyplot figures.
        plt.close(fig)
    return output_path
|
|