Spaces:
Runtime error
Runtime error
| """Visualization helpers: tokenization-grid overlay, foveation outline, prompt I/O.""" | |
| import os | |
| import numpy as np | |
| import pandas as pd | |
def load_prompt_dataset(prompt_dataset_path: str) -> list:
    """Load prompts from a CSV with a `prompt` column.

    Args:
        prompt_dataset_path: Path (or any source accepted by ``pd.read_csv``)
            of a CSV file containing a ``prompt`` column.

    Returns:
        The values of the ``prompt`` column, in file order, as a plain list.
        Values are returned as parsed by pandas; nothing is filtered out.

    Raises:
        KeyError: If the CSV has no ``prompt`` column.
    """
    df = pd.read_csv(prompt_dataset_path)
    # Read the column directly: iterating with ``iterrows`` would build a
    # Series per row (and coerce each row to a common dtype) only to pick
    # out a single field.
    return df["prompt"].tolist()
def create_tokenization_mask_vis(
    foveation_mask,
    height: int,
    width: int,
    lr_factor: int = 2,
    *,
    token_px: int = 16,
    border_px: int = 4,
):
    """Render the HR/LR token grid implied by a token-grid foveation mask.

    HR region (mask=1) is white with ``token_px`` x ``token_px`` token
    outlines; LR region (mask=0) is light gray with
    ``(token_px*lr_factor)`` x ``(token_px*lr_factor)`` outlines. A black
    frame ``border_px`` pixels thick is drawn on all four image edges.

    Args:
        foveation_mask: 2-D token-grid mask; values > 0.5 mark HR tokens.
            May be an array-like or an object exposing ``detach()`` (it is
            then converted via ``.detach().cpu().numpy()``).
        height: Output image height in pixels.
        width: Output image width in pixels.
        lr_factor: Size multiplier of an LR token relative to an HR token.
            When it is not greater than 1 the LR region is not tinted gray.
        token_px: Side length in pixels of one HR token (default 16).
        border_px: Thickness in pixels of grid lines and of the outer frame
            (default 4).

    Returns:
        ``(height, width, 3)`` ``uint8`` array.

    Note:
        The upsampled mask is cropped to ``(height, width)`` but never
        padded, so the mask grid must cover at least that many pixels.
    """
    mask_np = (
        foveation_mask.detach().cpu().numpy()
        if hasattr(foveation_mask, "detach")
        else np.asarray(foveation_mask)
    )

    # Expand every mask cell to a token_px x token_px patch, then crop any
    # overshoot so the result lines up with the output canvas.
    upsampled = np.repeat(np.repeat(mask_np, token_px, axis=0), token_px, axis=1)
    upsampled = upsampled[:height, :width]

    vis = np.full((height, width, 3), 255, dtype=np.uint8)
    if lr_factor > 1:
        vis[upsampled < 0.5] = [225, 225, 225]

    yy = np.arange(height)
    xx = np.arange(width)

    # A pixel lies on a grid line when its row OR its column falls within
    # the first border_px pixels of a cell.
    hr_grid = (yy % token_px < border_px)[:, None] | (xx % token_px < border_px)[None, :]
    lr_cell = token_px * lr_factor
    lr_grid = (yy % lr_cell < border_px)[:, None] | (xx % lr_cell < border_px)[None, :]

    hr_region = upsampled > 0.5
    outline = (hr_grid & hr_region) | (lr_grid & ~hr_region)
    vis[outline] = [0, 0, 0]

    # Outer frame. ``>=`` (not ``>``) so the bottom/right edges are exactly
    # border_px thick, matching the top/left edges.
    vis[yy < border_px, :] = [0, 0, 0]
    vis[yy >= height - border_px, :] = [0, 0, 0]
    vis[:, xx < border_px] = [0, 0, 0]
    vis[:, xx >= width - border_px] = [0, 0, 0]
    return vis
def draw_foveation_outline(
    image_np,
    mask_token_grid,
    height: int,
    width: int,
    outline_width_frac: float = 0.01,
    color=(255, 0, 0),
):
    """Draw a white outline (with thin black backing) around the foveation region, in-place.

    Args:
        image_np: Image array to draw on; it is modified in place by OpenCV.
        mask_token_grid: 2-D token-grid mask; values > 0.5 mark the
            foveation region. May be an array-like or an object exposing
            ``detach()`` (converted via ``.detach().cpu().numpy()``).
        height: Pixel height the mask is resized to before contour tracing.
        width: Pixel width the mask is resized to; also the reference for
            the stroke thickness.
        outline_width_frac: Total outline thickness as a fraction of
            ``width`` (at least 1 px).
        color: Accepted for interface compatibility but currently ignored;
            the outline is always white over a black backing.

    Returns:
        ``image_np`` itself. If OpenCV (``cv2``) is not installed the image
        is returned untouched.
    """
    try:
        import cv2
    except ImportError:
        # Best-effort overlay: without OpenCV there is nothing to draw with.
        return image_np

    mask_np = (
        mask_token_grid.detach().cpu().numpy()
        if hasattr(mask_token_grid, "detach")
        else np.asarray(mask_token_grid)
    )
    mask_uint8 = (mask_np > 0.5).astype(np.uint8) * 255
    # Nearest-neighbour keeps the mask binary, so contours follow token edges.
    mask_high = cv2.resize(mask_uint8, (width, height), interpolation=cv2.INTER_NEAREST)
    contours, _ = cv2.findContours(mask_high, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    border_width = max(1, int(outline_width_frac * width))
    black_stroke = max(1, int(0.10 * border_width))
    white_stroke = max(1, int(0.80 * border_width))

    # Wider black stroke first, then the narrower white stroke on top, so a
    # thin black rim remains visible on both sides of the white line.
    cv2.drawContours(image_np, contours, -1, (0, 0, 0), black_stroke + white_stroke)
    cv2.drawContours(image_np, contours, -1, (255, 255, 255), white_stroke)
    return image_np