VReason-Demo / render.py
EvidenceAIResearch's picture
Upload RadGenome demo Space
a3feb0e verified
Raw
History Blame Contribute Delete
10.7 kB
import cv2
import imagesize
import itertools
import templates
import numpy as np
from pathlib import Path
from utils import mask_to_svg
from jinja2 import Environment, FileSystemLoader
# def get_center_point(mask):
# M = cv2.moments(mask)
# if M["m00"] == 0:
# return None # empty mask
# cx = M["m10"] / M["m00"]
# cy = M["m01"] / M["m00"]
# return {"x": cx, "y": cy}
def get_area(mask):
    """Return how many pixels of ``mask`` are strictly greater than zero, as a plain ``int``."""
    foreground = mask > 0
    return int(np.count_nonzero(foreground))
from scipy.ndimage import binary_erosion
def mask_border(mask: np.ndarray) -> np.ndarray:
    """Return border pixels of a binary mask.

    A pixel belongs to the border when it is foreground (non-zero) but does
    not survive a 3x3 binary erosion, i.e. at least one of its 8 neighbours
    (or the image edge) is background.

    Args:
        mask: 2-D array; every non-zero entry is treated as foreground.

    Returns:
        Array of the same shape and dtype as ``mask`` holding 1 (``True`` for
        boolean input) on border pixels and 0 elsewhere.
    """
    mask = np.asarray(mask)
    # Work on an explicit boolean foreground. Combining the raw mask with the
    # boolean erosion result via ``&`` is a *bitwise* AND with 0/1, which
    # silently drops every even foreground value (2, 254, ...) and raises for
    # float masks.
    foreground = mask != 0
    eroded = binary_erosion(foreground, structure=np.ones((3, 3)))
    border = foreground & ~eroded
    # Cast back so callers that relied on the input dtype keep getting it.
    return border.astype(mask.dtype)
def _edge_anchor(mask, side):
    """Pick a border pixel near the left or right extreme of ``mask``.

    Shared implementation behind ``get_left_point`` and ``get_right_point``.

    Args:
        mask: 2-D mask passed straight to ``mask_border``.
        side: ``"left"`` to anchor at the smallest x, anything else anchors
            at the largest x.

    Returns:
        ``{"x": int, "y": int}`` in pixel coordinates, or ``None`` when the
        mask has no border pixels.
    """
    border = mask_border(mask)
    ys, xs = np.where(border > 0)
    if xs.size == 0:
        return None

    min_x = xs.min()
    max_x = xs.max()
    offset = int((max_x - min_x) // 20)  # dynamic offset by mask width

    # Keep only border pixels within `offset` columns of the chosen extreme.
    # The extreme column itself always passes (offset >= 0), so the
    # candidate set can never be empty and needs no second emptiness check.
    if side == "left":
        near_edge = xs <= min_x + offset
    else:
        near_edge = xs >= max_x - offset
    xs = xs[near_edge]
    ys = ys[near_edge]

    # Aim for the vertical middle of the candidates and take the candidate
    # closest to it (first one wins on ties, following np.argmin).
    target_y = (ys.min() + ys.max()) / 2
    idx = np.argmin(np.abs(ys - target_y))
    return {"x": int(xs[idx]), "y": int(ys[idx])}


def get_left_point(mask):
    """Return a border pixel on the left-most side of ``mask``, vertically centred.

    Returns ``{"x": int, "y": int}`` or ``None`` if the mask has no border pixels.
    """
    return _edge_anchor(mask, "left")


def get_right_point(mask):
    """Return a border pixel on the right-most side of ``mask``, vertically centred.

    Returns ``{"x": int, "y": int}`` or ``None`` if the mask has no border pixels.
    """
    return _edge_anchor(mask, "right")
def prepare_data(data):
    """
    Build one record per ROI and split the records into a left and a right group.

    Each entry of ``data["explain"]`` must provide ``"roi"`` (label text),
    ``"reason"`` (body text) and ``"mask"`` (NumPy array). Records are sorted
    by mask area, largest first, and then cut at the index where the total
    text length on both sides is as balanced as possible. Finally each group
    is ordered top-to-bottom by the y coordinate of its anchor point.

    Returns:
        left_data, right_data -- two lists of dicts with the keys
        ``heading``, ``content``, ``mask``, ``area``, ``left`` and ``right``.
    """
    formated_data = []
    for roi in data["explain"]:
        heading = roi["roi"].title()
        content = roi["reason"]
        mask = roi["mask"]
        # An all-zero mask has no border to anchor a connector to, so fall
        # back to a mask that covers the whole frame.
        if not mask.any():
            mask = np.ones_like(mask, dtype=mask.dtype)
        formated_data.append({
            "heading": heading,
            "content": content,
            "mask": mask,
            "area": get_area(mask),
            "left": get_left_point(mask),
            "right": get_right_point(mask),
        })
    formated_data.sort(key=lambda x: x["area"], reverse=True)

    # --- split logic ---
    # Per-record cost: text length plus a flat 60 characters
    # (presumably the fixed overhead of one text box -- TODO confirm).
    lengths = [len(b["content"]) + 60 for b in formated_data]
    # forward_cumsum[i]  = cost of records [0, i)
    # backward_cumsum[i] = cost of records [i, end)
    forward_cumsum = [0] + list(itertools.accumulate(lengths))
    backward_cumsum = list(itertools.accumulate(reversed(lengths)))[::-1] + [0]
    # min() keeps the first index on ties, matching a strict "<" scan.
    cut_idx = min(
        range(len(forward_cumsum)),
        key=lambda i: abs(backward_cumsum[i] - forward_cumsum[i]),
    )

    left_data = formated_data[:cut_idx]
    right_data = formated_data[cut_idx:]
    left_data.sort(key=lambda x: x["left"]["y"])
    right_data.sort(key=lambda x: x["right"]["y"])
    return left_data, right_data
def hover_on_other(i, n, target_type, types=("svg", "tbox", "connector"), subtype=None):
    """
    Build the selector list that targets every *other* element while element ``i`` is hovered.

    i: index of the element being hovered
    n: total number of elements
    target_type: id prefix of the elements to affect (e.g., 'connector', 'svg', 'tbox')
    types: id prefixes that can act as the hover source
    subtype: optional CSS suffix appended to the target (e.g., ':before', '> span')

    Returns the rendered ``hover_a_set_b`` snippets joined with ", ".
    """
    suffix = "" if subtype is None else subtype
    return ", ".join(
        templates.diagram_html.hover_a_set_b.render(
            a=f"#{source}{i}",
            b=f"#{target_type}{other}{suffix}",
        )
        for source in types        # hover source types
        for other in range(n)
        if other != i              # skip self
    )
def hover_on_self(i, target_type, types=("svg", "tbox", "connector"), subtype=None):
    """
    Build the selector list that targets element ``i`` itself while one of its siblings is hovered.

    For every entry of ``types`` other than ``target_type`` a ``hover_a_set_b``
    snippet is rendered with
        a = #<type><i>
        b = #<target_type><i>{subtype}
    and the snippets are joined with ", ".
    """
    suffix = "" if subtype is None else subtype
    target = f"#{target_type}{i}{suffix}"
    return ", ".join(
        templates.diagram_html.hover_a_set_b.render(a=f"#{source}{i}", b=target)
        for source in types
        if source != target_type   # only other types with same index
    )
import base64
def to_b64(path):
    """Read the image at ``path`` and return it as a base64 ``data:`` URI.

    The MIME type is guessed from the file extension so that e.g. a ``.jpg``
    is labelled ``image/jpeg``. When the extension is unknown or does not map
    to an image type, ``image/png`` is used.

    Args:
        path: filesystem path (``str`` or path-like) of the image file.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``.
    """
    import mimetypes  # local import: only this helper needs it

    mime, _ = mimetypes.guess_type(str(path))
    if mime is None or not mime.startswith("image/"):
        mime = "image/png"
    with open(path, "rb") as f:
        payload = base64.b64encode(f.read()).decode()
    return f"data:{mime};base64,{payload}"
def _hover_selectors(i, n):
    """Return the hover-selector template variables for element ``i`` out of ``n``."""
    return {
        "svg_hover_on_self": hover_on_self(i, "svg"),
        "svg_hover_on_self_area": hover_on_self(i, "svg", subtype=" .level0"),
        "svg_hover_on_self_stroke": ", ".join([
            hover_on_self(i, "svg", subtype=" .outer"),
            hover_on_self(i, "svg", subtype=" .inner"),
        ]),
        "svg_hover_on_self_bg": hover_on_self(i, "svg", subtype=" .bg"),
        "svg_hover_on_other": hover_on_other(i, n, "svg"),
        "tbox_hover_on_self": hover_on_self(i, "tbox", subtype="::before"),
        "connector_hover_on_self": hover_on_self(i, "connector"),
        "connector_hover_on_other": hover_on_other(i, n, "connector"),
        "connector_hover_on_self_line": hover_on_self(i, "connector", subtype=" line"),
        "connector_hover_on_other_line": hover_on_other(i, n, "connector", subtype=" line"),
    }


def render_diagram_html(data):
    """Render the interactive diagram for one image as a single HTML string.

    Args:
        data: dict with ``"image_path"`` (path of the image file) and
            ``"explain"`` (list of ROI entries, consumed by ``prepare_data``).

    Returns:
        The rendered ``template.html`` with the stylesheet, the script and the
        base64-encoded image embedded.

    Note:
        Templates are loaded from the relative directory ``"templates"``, so
        the result depends on the current working directory.
    """
    img_width, img_height = imagesize.get(data["image_path"])

    env = Environment(loader=FileSystemLoader("templates"))
    html = env.get_template("template.html")
    css = env.get_template("template.css")
    js = env.get_template("script.js")

    # Setup
    svg_html_ls = []
    tbox_html_by_side = {"left": [], "right": []}
    svg_css_ls = []
    tbox_css_ls = []
    connector_css_ls = []

    # Split the ROIs into the left and right text columns.
    left_data, right_data = prepare_data(data)
    n = len(left_data) + len(right_data)

    # Left records get the lowest indices, right records continue the count;
    # the index drives element ids and the colour cycle.
    placed = [("left", dp) for dp in left_data] + [("right", dp) for dp in right_data]
    for i, (side, data_point) in enumerate(placed):
        color = templates.COLORS[i % len(templates.COLORS)]
        # The connector is anchored on the mask edge facing the text column,
        # expressed as a fraction of the image size.
        anchor = data_point[side]
        left = f'{anchor["x"] / img_width :.4f}'
        top = f'{anchor["y"] / img_height :.4f}'
        extra_data = f'data-left="{left}" data-top="{top}" data-color="{color}"'
        svg = mask_to_svg(data_point["mask"]).format(
            index=i,
            extra_data=extra_data,
            sub_class="",
            prefix="",
            hidden_style="",
            sub_svgs="",
        )

        data_point.update({"i": i, "side": side, "color": color})
        data_point.update(_hover_selectors(i, n))

        svg_html_ls.append(svg.strip("\n"))
        svg_css_ls.append(templates.diagram_html.svg_css.render(**data_point).strip("\n"))
        tbox_html_by_side[side].append(templates.diagram_html.tbox_html.render(**data_point).strip("\n"))
        tbox_css_ls.append(templates.diagram_html.tbox_css.render(**data_point).strip("\n"))
        connector_css_ls.append(templates.diagram_html.connector_css.render(**data_point).strip("\n"))

    css_content = css.render(
        svgs="\n".join(svg_css_ls),
        tboxs="\n".join(tbox_css_ls),
        connectors="\n".join(connector_css_ls),
    )
    js_content = js.render()
    html_content = html.render(
        stylesheet_content=css_content,
        javascript_content=js_content,
        image_path=to_b64(data["image_path"]),
        svgs="\n".join(svg_html_ls),
        tboxs_left="\n".join(tbox_html_by_side["left"]),
        tboxs_right="\n".join(tbox_html_by_side["right"]),
    )
    return html_content
import lorem
def random_mask(width, height, min_size=10):
    """Return a ``(height, width)`` uint8 mask with 1-14 random filled circles drawn as 255.

    Uses NumPy's global random state, so seed it beforehand for reproducible
    output.

    Args:
        width: mask width in pixels.
        height: mask height in pixels.
        min_size: smallest circle radius.

    Returns:
        ``np.ndarray`` of dtype uint8 with values 0 (background) and 255.
    """
    # empty mask
    mask = np.zeros((height, width), dtype=np.uint8)

    # random number of circles
    n_circles = np.random.randint(1, 15)

    # maximum radius based on sqrt(width*height)
    max_radius = int((width * height) ** 0.5 * 0.3)

    # np.random.randint needs high > low. Clamp the exclusive upper bounds so
    # tiny images (max_radius <= min_size, or a side of 1 pixel) no longer
    # raise ValueError; larger images draw exactly the same values as before.
    radius_high = max(max_radius, min_size + 1)
    cx_high = max(width - 1, 1)
    cy_high = max(height - 1, 1)

    for _ in range(n_circles):
        # random center inside the image
        cx = np.random.randint(0, cx_high)
        cy = np.random.randint(0, cy_high)

        # random radius within bounds
        radius = np.random.randint(min_size, radius_high)

        # draw circle on mask
        cv2.circle(mask, (cx, cy), radius, 255, -1)  # -1 fills the circle
    return mask
def random_data(image_path):
    """Build a demo payload for ``render_diagram_html`` with random ROIs.

    Reseeds NumPy's global random state with 4 on every call, so the masks
    and text lengths are reproducible. The lorem text itself comes from the
    ``lorem`` package, which is not seeded here.

    Args:
        image_path: path of the image the random masks are sized for.

    Returns:
        dict with ``"image_id"``, ``"image_path"`` and ``"explain"`` (a list
        of ``{"roi", "mask", "reason"}`` entries).
    """
    np.random.seed(4)
    n_boxes = 4  # np.random.randint(1, 10)

    # The image size is the same for every box, so read it once.
    width, height = imagesize.get(image_path)

    explain = []
    for _ in range(n_boxes):
        # Keep the draw order (label length, mask, sentence count) stable so
        # the seeded random stream yields the same demo data as before.
        roi = lorem.sentence()[:np.random.randint(3, 30)]
        mask = random_mask(width, height)
        n_sentences = np.random.randint(1, 3)
        reason = " ".join([lorem.sentence() for _ in range(n_sentences)])
        explain.append({"roi": roi, "mask": mask, "reason": reason})

    data = {
        "image_id": 0,
        "image_path": image_path,
        "explain": explain,
    }
    print(f"Randomly generated {n_boxes} boxes.")
    return data
# image_path = "tmp/img2.jpg"
# data = random_data(image_path)
# render_diagram_html(data)