FastCDM / fastcdm /core.py
BinyangQiu
Impl Dockerfile
c5b5c64
Raw
History Blame Contribute Delete
13.6 kB
from fastcdm.render.render_worker import RenderWorker
from fastcdm.matcher import update_inliers, HungarianMatcher, SimpleAffineTransform
from fastcdm.clean import (
clean,
PATTERN_STRIP_START_BRACKET,
PATTERN_STRIP_END_BRACKET,
)
from fastcdm.tokenize import tokenize
from fastcdm.colorize import process_for_katex, generate_high_contrast_colors
from fastcdm.box import get_bboxes_from_array
import cv2
import numpy as np
from typing import List, Tuple
from pathlib import Path
from skimage.measure import ransac
import traceback
import subprocess
import shutil
import os
# Directory that contains this module; bundled resources are located relative to it.
root_dir = Path(__file__).parent
# HTML file handed to RenderWorker as its page template (see FastCDM.init_render_worker).
TEMPLATE_FILE = root_dir / "render" / "templates" / "formula.html"
def preprocess(s: str):
    """Clean, tokenize and colorize a LaTeX string for KaTeX rendering.

    Each token is assigned its own color so that it can later be located
    in the rendered image by that color.

    Args:
        s: Raw LaTeX expression.

    Returns:
        On success, a tuple ``(final_latex, color_map)`` where ``final_latex``
        is the colorized LaTeX and ``color_map`` is a list of
        ``(token, (r, g, b))`` pairs in token order.
        If tokenization fails, "Tokenization failed" is printed and the bare
        float ``0.0`` is returned instead of a tuple, so callers must check
        the result before unpacking it.
    """
    # Step 1: clean and tokenize.
    ok, tokens = tokenize(clean(s))
    if not ok:
        print("Tokenization failed")
        return 0.0

    # Step 2: build the KaTeX template and a palette with a few spare colors.
    latex, token_list = process_for_katex(tokens)
    palette = generate_high_contrast_colors(len(token_list) + 10)

    # Substitute every per-token color placeholder with its hex color.
    color_map = []  # list of (token, rgb_color)
    for position, token in enumerate(token_list):
        red, green, blue = palette[position % len(palette)]
        hex_color = f"#{red:02x}{green:02x}{blue:02x}"
        latex = latex.replace(f"__COLOR__{position}__", hex_color)
        color_map.append((token, (red, green, blue)))

    # Remove the leading and trailing brackets.
    latex = PATTERN_STRIP_START_BRACKET.sub("", latex)
    latex = PATTERN_STRIP_END_BRACKET.sub("", latex)
    return latex, color_map
def calculate_metrics(gt_len, pred_len, match_num):
    """Compute F1-score, recall and precision from match counts.

    Args:
        gt_len: Number of ground-truth items.
        pred_len: Number of predicted items.
        match_num: Number of matched items.

    Returns:
        Tuple ``(f1_score, recall, precision)``. A ratio whose denominator
        is not positive is reported as 0, and F1 is 0 when recall and
        precision sum to 0.
    """
    recall = 0
    if gt_len > 0:
        recall = match_num / gt_len

    precision = 0
    if pred_len > 0:
        precision = match_num / pred_len

    if recall + precision > 0:
        f1_score = 2 * (precision * recall) / (precision + recall)
    else:
        f1_score = 0

    return f1_score, recall, precision
def _pad_to_canvas(img: np.ndarray, height: int, width: int) -> np.ndarray:
    """Paste *img* at the top-left of a white ``height x width`` 3-channel canvas."""
    canvas = np.full((height, width, 3), 255, dtype=np.uint8)
    img_h, img_w = img.shape[:2]
    canvas[0:img_h, 0:img_w] = img
    return canvas


def _collect_token_boxes(bboxes_list, color_map):
    """Pair every non-empty bbox with the token at the same index of *color_map*."""
    return [
        {"bbox": bbox, "token": color_map[i][0]}
        for i, bbox in enumerate(bboxes_list)
        if bbox
    ]


def _bbox_center(bbox):
    """Return the ``(x, y)`` center of a ``[xmin, ymin, xmax, ymax]`` box."""
    return (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2


def _ransac_inliers(src: np.ndarray, dst: np.ndarray, num_matches: int) -> np.ndarray:
    """Return a boolean mask marking matches accepted by repeated RANSAC fits."""
    min_samples = 3
    # Too few points to fit a model: accept every match as-is.
    if src.shape[0] <= min_samples:
        return np.ones(num_matches, dtype=bool)

    inliers = np.zeros(num_matches, dtype=bool)
    # Up to five rounds, each one fitted only on matches not yet accepted.
    for _ in range(5):
        if np.sum(~inliers) <= min_samples:
            break
        try:
            _, round_inliers = ransac(
                (src[~inliers], dst[~inliers]),
                SimpleAffineTransform,
                min_samples=min_samples,
                residual_threshold=20,
                max_trials=50,
            )
            if round_inliers is not None and round_inliers.any():
                inliers = update_inliers(inliers, round_inliers)
            else:
                break
        except Exception:
            # RANSAC might fail if the data is degenerate.
            break
    return inliers


def _draw_matches(gt_canvas, pred_canvas, gt_data, pred_data, matched_idxes, inliers):
    """Stack GT above prediction and draw boxes plus connecting lines for inlier matches."""
    max_h, max_w = gt_canvas.shape[:2]
    # The prediction is drawn below the GT, separated by a 10-pixel gap.
    y_offset = max_h + 10
    vis_img = np.full((max_h * 2 + 10, max_w, 3), 255, dtype=np.uint8)
    vis_img[0:max_h, 0:max_w] = gt_canvas
    vis_img[y_offset : y_offset + max_h, 0:max_w] = pred_canvas

    for idx, (gt_idx, pred_idx) in enumerate(matched_idxes):
        if not inliers[idx]:
            continue
        gt_box = gt_data[gt_idx]["bbox"]
        pred_box = pred_data[pred_idx]["bbox"]

        # Draw boxes
        cv2.rectangle(
            vis_img,
            (gt_box[0], gt_box[1]),
            (gt_box[2], gt_box[3]),
            (0, 255, 0),
            1,
        )
        cv2.rectangle(
            vis_img,
            (pred_box[0], pred_box[1] + y_offset),
            (pred_box[2], pred_box[3] + y_offset),
            (0, 0, 255),
            1,
        )

        # Draw line between the two box centers
        pt1 = (
            int((gt_box[0] + gt_box[2]) / 2),
            int((gt_box[1] + gt_box[3]) / 2),
        )
        pt2 = (
            int((pred_box[0] + pred_box[2]) / 2),
            int((pred_box[1] + pred_box[3]) / 2) + y_offset,
        )
        cv2.line(vis_img, pt1, pt2, (255, 0, 0), 1)
    return vis_img


def postprocess(
    img_gt: np.ndarray,
    img_pred: np.ndarray,
    gt_color_map: List[Tuple[str, Tuple[int, int, int]]],
    pred_color_map: List[Tuple[str, Tuple[int, int, int]]],
    visualize: bool,
):
    """Match colored tokens between two rendered images and score the match.

    Both images are padded onto white canvases of a common size, token
    bounding boxes are extracted by color, the boxes are paired with
    ``HungarianMatcher``, and the pairs are filtered by RANSAC and by the
    matcher's token cost before the metrics are computed.

    Args:
        img_gt: Rendered ground-truth image.
        img_pred: Rendered prediction image.
        gt_color_map: ``(token, (r, g, b))`` pairs for the ground truth.
        pred_color_map: ``(token, (r, g, b))`` pairs for the prediction.
        visualize: When true, also build an image showing the accepted matches.

    Returns:
        ``(f1, recall, precision)``, or ``(f1, recall, precision, vis_img)``
        when ``visualize`` is true.
    """
    # Normalize image sizes: both images share a canvas of the max height/width.
    h_gt, w_gt = img_gt.shape[:2]
    h_pred, w_pred = img_pred.shape[:2]
    max_h = max(h_gt, h_pred)
    max_w = max(w_gt, w_pred)
    final_gt_img = _pad_to_canvas(img_gt, max_h, max_w)
    final_pred_img = _pad_to_canvas(img_pred, max_h, max_w)

    # bboxes format: [xmin, ymin, xmax, ymax]
    gt_bboxes_list = get_bboxes_from_array(final_gt_img, [c[1] for c in gt_color_map])
    pred_bboxes_list = get_bboxes_from_array(
        final_pred_img, [c[1] for c in pred_color_map]
    )

    # --- Match tokens ---
    # HungarianMatcher takes lists of {'bbox': [...], 'token': token_str} dicts.
    gt_data = _collect_token_boxes(gt_bboxes_list, gt_color_map)
    pred_data = _collect_token_boxes(pred_bboxes_list, pred_color_map)

    matcher = HungarianMatcher()
    size_tuple = (max_w, max_h)
    if gt_data and pred_data:
        matched_idxes = matcher(gt_data, pred_data, size_tuple, size_tuple)
    else:
        # Nothing can be paired when one side has no boxes; skip the matcher
        # rather than depend on how it treats empty input.
        matched_idxes = []

    # --- RANSAC verification on box centers, stored as (y, x) ---
    src, dst = [], []
    for gt_idx, pred_idx in matched_idxes:
        x1_c, y1_c = _bbox_center(gt_data[gt_idx]["bbox"])
        x2_c, y2_c = _bbox_center(pred_data[pred_idx]["bbox"])
        src.append([y1_c, x1_c])
        dst.append([y2_c, x2_c])
    inliers = _ransac_inliers(np.array(src), np.array(dst), len(matched_idxes))

    # Reject spatially consistent pairs whose token cost is 1.
    for idx, (a, b) in enumerate(matched_idxes):
        if inliers[idx] and matcher.cost["token"][a, b] == 1:
            inliers[idx] = False

    match_num = np.sum(inliers)
    # The denominators count every entry returned by get_bboxes_from_array,
    # including empty ones.
    num_gt = len(gt_bboxes_list)
    num_pred = len(pred_bboxes_list)
    f1, recall, precision = calculate_metrics(num_gt, num_pred, match_num)

    if not visualize:
        return f1, recall, precision

    vis_img = _draw_matches(
        final_gt_img, final_pred_img, gt_data, pred_data, matched_idxes, inliers
    )
    return f1, recall, precision, vis_img
class FastCDM:
    """Compute CDM scores for LaTeX expression pairs using a browser-based renderer."""

    def __init__(self, chromedriver: str = None) -> None:
        """Check the environment and start the render worker.

        Args:
            chromedriver: Path to the ChromeDriver executable. When omitted,
                ``chromedriver`` must be discoverable on ``PATH``.

        Raises:
            RuntimeError: Node.js is missing or not runnable, or ChromeDriver
                is not on ``PATH`` and no path was given.
            FileNotFoundError: The given ChromeDriver path does not exist.
        """
        self.chromedriver = chromedriver
        # Assigned before anything that can raise, so close()/__del__ always
        # find the attribute even when the environment check fails.
        self.render_worker = None
        self.check_environment()
        self.init_render_worker()

    def check_environment(self) -> None:
        """Check that the runtime environment is ready.

        1. Node.js is installed.
        2. ChromeDriver is available.
        """
        # 1. Check Node.js
        node_path = shutil.which("node")
        if not node_path:
            raise RuntimeError(
                "Node.js is not found. Please install Node.js for formula normalization. "
                "Visit https://nodejs.org/ to install it."
            )
        try:
            subprocess.check_output(["node", "--version"], text=True)
        except Exception as e:
            raise RuntimeError(f"Node.js is found but failed to execute: {e}") from e

        # 2. Check ChromeDriver
        if self.chromedriver:
            if not os.path.exists(self.chromedriver):
                raise FileNotFoundError(
                    f"Specified ChromeDriver not found at: {self.chromedriver}"
                )
        else:
            chromedriver_path = shutil.which("chromedriver")
            if not chromedriver_path:
                raise RuntimeError(
                    "ChromeDriver not found in PATH and no path was specified. "
                    "Please install ChromeDriver or provide the path via the 'chromedriver' parameter. "
                    "You can also use 'scripts/auto_install_chromedriver.py' to download it."
                )

    def init_render_worker(self) -> None:
        """Create the RenderWorker; on failure print the traceback and leave it unset."""
        try:
            self.render_worker = RenderWorker(
                template_file="file://" + str(TEMPLATE_FILE.resolve()),
                timeout=30,
                driver_path=self.chromedriver,
            )
        except Exception:
            print("Failed to init RenderWorker:")
            print("=" * 30)
            print(traceback.format_exc())

    def close(self):
        """Close the render worker if one exists; safe to call repeatedly."""
        # getattr: __del__ may run on an instance whose __init__ never completed.
        worker = getattr(self, "render_worker", None)
        if worker:
            worker.close()
        self.render_worker = None

    def __del__(self):
        self.close()

    def render(self, latex_list: list) -> list:
        """Render a list of LaTeX expressions.

        Expressions that do not already start with ``$$`` are wrapped in
        ``$$...$$`` before being passed to the render worker.

        Args:
            latex_list (list): LaTeX expressions.

        Returns:
            list: The rendered images, or an empty list if rendering failed
            (the failure is printed).
        """
        if self.render_worker is None:
            print("Rendering failed:")
            print("=" * 30)
            print("RenderWorker is not initialized.")
            return []
        try:
            latex_strings = [
                s if s.startswith("$$") else f"$${s}$$" for s in latex_list
            ]
            imgs = self.render_worker.render(latex_strings)
        except Exception:
            print("Rendering failed:")
            print("=" * 30)
            # format_exc() takes no exception argument; it formats the
            # exception currently being handled.
            print(traceback.format_exc())
            return []
        assert len(imgs) == len(
            latex_strings
        ), "Number of rendered images must match number of input strings"
        return imgs

    def compute(self, gt: str, pred: str, visualize: bool = False) -> tuple:
        """Compute the CDM score of a ground-truth / prediction LaTeX pair.

        Args:
            gt (str): Ground-truth LaTeX expression.
            pred (str): Predicted LaTeX expression.
            visualize (bool): When true, a match visualization image is
                appended to the result.

        Returns:
            tuple: ``(f1, recall, precision)``, or
            ``(f1, recall, precision, vis_img)`` when ``visualize`` is true.
            If either expression cannot be tokenized, all scores are ``0.0``
            and ``vis_img`` is ``None``.

        Raises:
            RuntimeError: The expressions could not be rendered.
        """
        gt_result = preprocess(gt)
        pred_result = preprocess(pred)
        # preprocess() returns a bare 0.0 instead of a tuple when tokenization fails.
        if not (isinstance(gt_result, tuple) and isinstance(pred_result, tuple)):
            return (0.0, 0.0, 0.0, None) if visualize else (0.0, 0.0, 0.0)
        gt_latex, gt_color_map = gt_result
        pred_latex, pred_color_map = pred_result

        imgs = self.render([gt_latex, pred_latex])
        # render() returns [] on failure; fail clearly instead of on imgs[0].
        if len(imgs) != 2:
            raise RuntimeError(
                "Rendering failed: no images were produced for the GT/prediction pair."
            )
        gt_img, pred_img = imgs[0], imgs[1]
        return postprocess(gt_img, pred_img, gt_color_map, pred_color_map, visualize)

    def batch_compute(self, gt_list: list, pred_list: list) -> list:
        """Compute CDM scores for lists of GT / prediction pairs (TODO: not implemented).

        Args:
            gt_list (list): Ground-truth LaTeX expressions.
            pred_list (list): Predicted LaTeX expressions.

        Returns:
            list: One ``(f1, recall, precision)`` tuple per expression pair.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("batch_compute is not implemented yet.")
if __name__ == "__main__":
    # Demo: score one GT/prediction pair and save the match visualization.
    # Alternative example pairs:
    # gt = r"A_{M123}=u\,A^{M}"
    # pred = r"A_M123 = \hat{u} A^M"
    # gt = r"r = \frac { \alpha } { \beta } \vert \sin \beta \left( \sigma _ { 1 } \pm \sigma _ { 2 } \right) \vert"
    # pred = r"r={\frac{\alpha}{\beta}}|\sin\beta\left(\sigma_{2}+\sigma_{1}\right)|"
    # gt = r"\frac{1}{2}"
    # pred = r"\frac{1}{2}"
    # gt = r"\tilde{\theta}_k(t)=\frac{\hat{\theta}_k(t+1)-\hat{\theta}_k(t)}{T_s}"
    # pred = r"\tilde{\theta}_k(t)=\frac{\hat{\theta}_k(t+1)-\hat{\theta}_k(t)}{T_s}"
    gt = r"\begin{bmatrix}(\mathbf{I}-\mathbf{A}^{\mathsf{DD}})&-\mathbf{A}^{\mathsf{DP }}&-\mathbf{A}^{\mathsf{DN}}\\ 0&\mathbf{I}&0\\ -\mathbf{A}^{\mathsf{ND}}&-\mathbf{A}^{\mathsf{NP}}&(\mathbf{I}-\mathbf{A}^{ \mathsf{NN}})\end{bmatrix}^{-1}=\begin{bmatrix}\mathbf{B}^{\mathsf{DD}}& \mathbf{B}^{\mathsf{DP}}&\mathbf{B}^{\mathsf{DN}}\\ \mathbf{B}^{\mathsf{PD}}&\mathbf{B}^{\mathsf{PP}}&\mathbf{B}^{\mathsf{PN}}\\ \mathbf{B}^{\mathsf{ND}}&\mathbf{B}^{\mathsf{NP}}&\mathbf{B}^{\mathsf{NN}} \end{bmatrix}"
    pred = r"\left[ \begin{array} { c c c } { ( I - A ^ { \mathrm { D D } } ) } & { - A ^ { \mathrm { D P } } } & { - A ^ { \mathrm { D N } } } \\ { 0 } & { \mathbf { I } } & { 0 } \\ { - A ^ { \mathrm { N D } } } & { - A ^ { \mathrm { N P } } } & { ( I - A ^ { \mathrm { N N } } ) } \end{array} \right] ^ { - 1 } = \left[ \begin{array} { c c c } { \mathbf { B } ^ { \mathrm { D D } } } & { \mathbf { B } ^ { \mathrm { D P } } } & { \mathbf { B } ^ { \mathrm { D N } } } \\ { \mathbf { B } ^ { \mathrm { P D } } } & { \mathbf { B } ^ { \mathrm { P P } } } & { \mathbf { B } ^ { \mathrm { P N } } } \\ { \mathbf { B } ^ { \mathrm { N D } } } & { \mathbf { B } ^ { \mathrm { N P } } } & { \mathbf { B } ^ { \mathrm { N N } } } \end{array} \right]"

    fastcdm = FastCDM(chromedriver="driver/chromedriver")
    try:
        res = fastcdm.compute(gt, pred, visualize=True)
    finally:
        # Release the render worker even if scoring fails.
        fastcdm.close()

    f1, recall, precision, vis_img = res
    print(f"CDM Score (F1): {f1:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"Precision: {precision:.4f}")

    # Saving the visualization is best-effort, but failures are reported.
    out_path = Path(__file__).parent.parent / "vis" / "match_vis.png"
    try:
        out_path.parent.mkdir(parents=True, exist_ok=True)
        # cv2.imwrite returns False when the image could not be written.
        if not cv2.imwrite(str(out_path), vis_img):
            print(f"Warning: could not write visualization to {out_path}")
    except Exception as exc:
        print(f"Warning: saving visualization failed: {exc}")