"""Visualization helpers: tokenization-grid overlay, foveation outline, prompt I/O."""

import os

import numpy as np
import pandas as pd


def load_prompt_dataset(prompt_dataset_path: str):
    """Load prompts from a CSV file with a ``prompt`` column.

    Args:
        prompt_dataset_path: Path to the CSV file.

    Returns:
        A list holding the ``prompt`` value of every row, in file order.

    Raises:
        KeyError: If the CSV has no ``prompt`` column.
    """
    df = pd.read_csv(prompt_dataset_path)
    # Column access + tolist() avoids materializing a Series per row (iterrows).
    return df["prompt"].tolist()


def create_tokenization_mask_vis(
    foveation_mask,
    height: int,
    width: int,
    lr_factor: int = 2,
    patch_size: int = 16,
):
    """Render the HR/LR token grid implied by a token-grid foveation mask.

    Each mask entry covers a ``patch_size`` x ``patch_size`` pixel token. The HR
    region (mask > 0.5) is drawn white with ``patch_size``-pixel token outlines;
    the LR region is drawn light gray (only when ``lr_factor > 1``) with
    ``(patch_size * lr_factor)``-pixel outlines. A black frame exactly
    4 pixels wide is drawn on every side of the image.

    Args:
        foveation_mask: 2D token-grid mask; either an object with ``.detach()``
            (e.g. a torch tensor) or anything ``np.asarray`` accepts.
        height: Output image height in pixels.
        width: Output image width in pixels.
        lr_factor: Side-length multiplier of LR tokens relative to HR tokens.
        patch_size: HR token side length in pixels (default 16).

    Returns:
        A ``(height, width, 3)`` uint8 RGB image.

    Raises:
        ValueError: If the upsampled mask does not cover ``height`` x ``width``.
    """
    if hasattr(foveation_mask, "detach"):
        mask_np = foveation_mask.detach().cpu().numpy()
    else:
        mask_np = np.asarray(foveation_mask)

    upsampled = np.repeat(np.repeat(mask_np, patch_size, axis=0), patch_size, axis=1)
    upsampled = upsampled[:height, :width]
    if upsampled.shape != (height, width):
        raise ValueError(
            f"mask of shape {mask_np.shape} upsampled by patch_size={patch_size} "
            f"gives {upsampled.shape}, which does not cover ({height}, {width})"
        )
    # Single threshold shared by the background fill and the outline choice.
    hr_region = upsampled > 0.5

    vis = np.full((height, width, 3), 255, dtype=np.uint8)
    if lr_factor > 1:
        vis[~hr_region] = [225, 225, 225]

    border_px = 4
    rows = np.arange(height)[:, None]
    cols = np.arange(width)[None, :]

    def _grid_lines(step: int) -> np.ndarray:
        # (height, width) bool mask of border_px-wide lines every `step` pixels.
        return (rows % step < border_px) | (cols % step < border_px)

    outline = (_grid_lines(patch_size) & hr_region) | (
        _grid_lines(patch_size * lr_factor) & ~hr_region
    )
    vis[outline] = [0, 0, 0]

    # Outer frame: exactly border_px pixels on each of the four sides.
    frame = (
        (rows < border_px)
        | (rows >= height - border_px)
        | (cols < border_px)
        | (cols >= width - border_px)
    )
    vis[frame] = [0, 0, 0]
    return vis


def draw_foveation_outline(
    image_np,
    mask_token_grid,
    height: int,
    width: int,
    outline_width_frac: float = 0.01,
    color=(255, 0, 0),
):
    """Draw a white outline with a thin black backing around the foveation region.

    ``image_np`` is modified in place and also returned. If OpenCV (``cv2``)
    cannot be imported, the image is returned unchanged.

    Args:
        image_np: Image to draw on; must be a writable array accepted by
            ``cv2.drawContours`` (presumably ``(height, width, 3)`` uint8 —
            confirm against callers).
        mask_token_grid: 2D token-grid mask (object with ``.detach()`` or
            array-like); entries > 0.5 mark the foveated region.
        height: Pixel height the mask is resized to before tracing contours.
        width: Pixel width the mask is resized to before tracing contours.
        outline_width_frac: Total outline width as a fraction of ``width``.
        color: Accepted for API compatibility but currently ignored; the
            outline is always white.

    Returns:
        ``image_np`` with the outline drawn.
    """
    try:
        import cv2
    except ImportError:
        # Best effort: OpenCV is an optional dependency for this overlay.
        return image_np

    if hasattr(mask_token_grid, "detach"):
        mask_np = mask_token_grid.detach().cpu().numpy()
    else:
        mask_np = np.asarray(mask_token_grid)
    mask_uint8 = (mask_np > 0.5).astype(np.uint8) * 255
    mask_high = cv2.resize(mask_uint8, (width, height), interpolation=cv2.INTER_NEAREST)
    # OpenCV 4 returns (contours, hierarchy); OpenCV 3 returns
    # (image, contours, hierarchy). Index -2 is the contours in both.
    contours = cv2.findContours(mask_high, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    border_width = max(1, int(outline_width_frac * width))
    black_stroke = max(1, int(0.10 * border_width))
    white_stroke = max(1, int(0.80 * border_width))
    # Thicker black stroke first, then the white stroke on top, so a thin black
    # rim stays visible on both sides of the white line.
    cv2.drawContours(image_np, contours, -1, (0, 0, 0), black_stroke + white_stroke)
    cv2.drawContours(image_np, contours, -1, (255, 255, 255), white_stroke)
    return image_np