Docgenie-API / docgenie /generation /handwriting_diffusion /generate_handwriting_diffusion_raw.py
Ahadhassan-2003
deploy: update HF Space
dc4e6da
#!/usr/bin/env python3
"""
Diffusion-based handwriting token generator with intelligent word splitting and stitching.
This script:
- Reads handwriting bbox JSON files with format: "x1,y1,x2,y2,text,block_no,line_no,word_no"
- Intelligently splits long words internally based on --split-length parameter
- Splits numeric sequences within tokens into configurable chunk sizes (default: 2)
- Generates handwriting using HuggingFace diffusion model with text conditioning
- Stitches split word segments horizontally with baseline alignment
- Supports sentence-level reconstruction using line metadata
- Outputs transparent RGBA images with tight cropping
- Maintains consistent writer styles per document
- Supports batched generation for GPU efficiency
Usage example:
python scripts/generate_handwriting_diffusion_raw.py \
--input-dir docvqa-handwritten-sizes4/handwriting_bbox \
--output-dir docvqa-handwritten-sizes4/handwriting_raw_tokens \
--run-dir model/experiments/hf_conditional_latent \
--checkpoint latest.pt \
--steps 30 --split-length 7 --batch-size 8 --device cuda
With sentence stitching and custom baseline:
python scripts/generate_handwriting_diffusion_raw.py \
--input-dir docvqa-handwritten-sizes4/handwriting_bbox \
--output-dir docvqa-handwritten-sizes4/handwriting_raw_tokens \
--run-dir model/experiments/hf_conditional_latent \
--checkpoint latest.pt \
--steps 30 --split-length 7 --stitch-sentences \
--baseline-percentile 85.0 --device cuda
Install requirements:
pip install torch diffusers transformers Pillow PyYAML
Mapping file (raw_token_map.json) structure:
{
"backend": "diffusion-hf",
"split_length": 7,
"entries": [
{
"source_json": "example.json",
"hw_id": "hw0",
"author_id": "author1",
"words": [
{
"block_no": 22,
"line_no": 0,
"word_no": 0,
"image": "example/hw0_0.png",
"style_id": 123,
"width": 250,
"height": 64,
"segments": [
{"token": "genera", "bbox": [x1,y1,x2,y2]},
{"token": "tion", "bbox": [x1,y1,x2,y2]}
]
}
]
}
],
"file_author_styles": {"example.json": {"author1": {"style_id": 123}}}
}
"""
from __future__ import annotations
import argparse
import json
import math
import random
import sys
from copy import deepcopy
from datetime import datetime
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from collections import defaultdict
from .tokenizer import CharTokenizer
from .text_encoder import TextEncoder
try:
import torch
import torch.nn as nn
from diffusers import (
AutoencoderKL,
DDPMScheduler,
DPMSolverMultistepScheduler,
UNet2DConditionModel,
)
from diffusers.training_utils import EMAModel
import numpy as np
from PIL import Image
import yaml
from rich.progress import Progress
except Exception as e:
print(
"[ERROR] Missing dependencies. Install: torch diffusers transformers Pillow PyYAML",
file=sys.stderr,
)
raise
# Axis-aligned box stored as (x1, y1, x2, y2); parse_bbox_record produces
# these values as floats.
BBox = Tuple[float, float, float, float]


@dataclass
class WordSegment:
    """A single piece of a word produced by space/length splitting."""

    # Text to render for this piece.
    token: str
    # Box associated with the piece (the containing word's box, or all zeros
    # for words that had no bbox in the input).
    bbox: BBox
    # Position of this piece inside its word: 0 for the first piece, 1 for the
    # second, and so on.
    original_index: int
    # True when whitespace separated this piece from the preceding one in the
    # original text.
    space_before: bool = False
@dataclass
class WordTask:
    """A complete word to render, possibly split into several segments.

    ``hw_id`` and ``author_id`` are filled from ``dict.get`` lookups on the
    input JSON entries, so either may be ``None`` when the field is absent.
    """

    # File name of the JSON document the word was read from.
    source_json: str
    # Handwriting entry id ("id" field of the JSON entry); may be None.
    hw_id: Optional[str]
    # Author id ("author-id"/"author_id" field); may be None.
    author_id: Optional[str]
    # Layout indices from the bbox record; -1 for words generated without a bbox.
    block_no: int
    line_no: int
    word_no: int
    # Pieces to generate, in reading order.
    segments: List[WordSegment]
    # Bbox of the whole word before splitting.
    original_bbox: BBox
    # False when the word must not take part in sentence stitching.
    include_in_sentence: bool = True
    # Why the word was excluded from sentence stitching, if it was.
    sentence_exclusion_reason: Optional[str] = None
# ---------------------------- util ----------------------------
def list_json_files(p: Path) -> List[Path]:
    """Return the regular ``*.json`` files directly inside *p*, sorted by path."""
    candidates = (entry for entry in p.glob("*.json") if entry.is_file())
    return sorted(candidates)
def load_json(path: Path):
    """Read *path* as UTF-8 text and return the decoded JSON value."""
    with path.open("r", encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed
def parse_bbox_record(rec: str) -> Tuple[BBox, str, int, int, int]:
    """
    Parse a record of the form ``x1,y1,x2,y2,text,block_no,line_no,word_no``.

    The text field may itself contain commas: everything between the four
    coordinates and the three trailing indices is rejoined with ``","``.

    Returns:
        ``((x1, y1, x2, y2), text, block_no, line_no, word_no)`` with float
        coordinates and int indices.

    Raises:
        ValueError: if the record has fewer than 8 comma-separated fields or a
            numeric field does not parse.
    """
    fields = rec.split(",")
    if len(fields) < 8:
        raise ValueError(f"Invalid bbox record (expected at least 8 parts): {rec}")
    coords = tuple(float(value) for value in fields[:4])
    *text_fields, block_str, line_str, word_str = fields[4:]
    text = ",".join(text_fields)
    return coords, text, int(block_str), int(line_str), int(word_str)
def split_word(word: str, split_length: int) -> List[str]:
    """
    Split *word* into the fewest pieces of at most ``split_length`` characters.

    Pieces are as even as possible. When the length does not divide evenly,
    the leading pieces are one character longer than the trailing ones.

    Args:
        word: The word to split.
        split_length: Maximum length of each piece; values <= 0 disable splitting.

    Returns:
        The pieces in order, with ``"".join(result) == word``. A word that
        already fits, including the empty string, is returned as a one-item list.

    Examples:
        split_word("generation", 4)    -> ["gene", "rat", "ion"]     (4, 3, 3)
        split_word("generation", 5)    -> ["gener", "ation"]         (5, 5)
        split_word("extraordinary", 7) -> ["extraor", "dinary"]      (7, 6)
        split_word("extraordinary", 5) -> ["extra", "ordi", "nary"]  (5, 4, 4)
        split_word("hello", 10)        -> ["hello"]                  (5)
    """
    word_len = len(word)
    if split_length <= 0 or word_len <= split_length:
        return [word]

    # Ceiling division: the minimum number of pieces that can respect the cap.
    # Spreading the characters evenly over that many pieces makes the longest
    # piece ceil(word_len / num_segments) <= split_length, so no piece exceeds it.
    num_segments = -(-word_len // split_length)
    base_length, remainder = divmod(word_len, num_segments)

    segments: List[str] = []
    start = 0
    for i in range(num_segments):
        # The first `remainder` pieces absorb the leftover characters.
        seg_length = base_length + (1 if i < remainder else 0)
        segments.append(word[start : start + seg_length])
        start += seg_length
    return segments
def split_token_preserving_digit_chunks(
    token: str, split_length_words: int, split_length_numeric: int
) -> List[str]:
    """
    Split *token* into digit and non-digit runs, then cap each run's length.

    Non-digit runs go through ``split_word(run, split_length_words)``. Digit
    runs are cut left to right into chunks of ``split_length_numeric``
    characters. When ``split_length_words`` is positive and smaller, it caps
    the chunk size instead.

    Args:
        token: Text to split.
        split_length_words: Maximum length of a non-numeric piece (<= 0 disables).
        split_length_numeric: Chunk size for digit runs. When <= 0 the whole
            token is passed to ``split_word`` with no digit handling at all.

    Returns:
        The pieces in original order; an empty token yields ``[token]``.
    """
    if split_length_numeric <= 0:
        return split_word(token, split_length_words)

    digit_chunk = split_length_numeric
    if 0 < split_length_words < digit_chunk:
        digit_chunk = split_length_words

    pieces: List[str] = []
    pos = 0
    end = len(token)
    while pos < end:
        # Extend the run while characters keep the same digit/non-digit kind.
        run_is_digit = token[pos].isdigit()
        run_end = pos + 1
        while run_end < end and token[run_end].isdigit() == run_is_digit:
            run_end += 1
        run = token[pos:run_end]
        if run_is_digit:
            pieces.extend(
                run[offset : offset + digit_chunk]
                for offset in range(0, len(run), digit_chunk)
            )
        else:
            pieces.extend(split_word(run, split_length_words))
        pos = run_end
    return pieces if pieces else [token]
def split_word_with_spaces(
    word: str, split_length_words: int, split_length_numeric: int
) -> List[Tuple[str, bool]]:
    """
    Split *word* at whitespace, then split each part by length.

    Each whitespace-separated part goes through
    ``split_token_preserving_digit_chunks`` (length splitting with digit
    chunking).

    Args:
        word: The text to split (may contain whitespace).
        split_length_words: Maximum length for each non-numeric segment.
        split_length_numeric: Maximum length for numeric sequences within each
            part (<= 0 disables special handling).

    Returns:
        List of ``(segment_text, space_before)`` pairs. ``space_before`` is True
        only for the first segment of a part that follows an earlier segment.
        It is never True for the first pair, even when *word* starts with
        whitespace.

    Examples (split_length_numeric=2; no digits involved):
        split_word_with_spaces("hello world", 10, 2)
            -> [("hello", False), ("world", True)]
        split_word_with_spaces("very long phrase", 5, 2)
            -> [("very", False), ("long", True), ("phr", True), ("ase", False)]
        split_word_with_spaces("hello", 3, 2)
            -> [("hel", False), ("lo", False)]
    """
    result: List[Tuple[str, bool]] = []
    # str.split() with no argument splits on any whitespace run and drops empty
    # strings, so leading, trailing and repeated whitespace yield no parts.
    # This matches how extract_tasks splits fallback text.
    for part in word.split():
        pieces = split_token_preserving_digit_chunks(
            part, split_length_words, split_length_numeric
        )
        for seg_idx, seg in enumerate(pieces):
            # A gap is only meaningful when some segment precedes this part.
            result.append((seg, seg_idx == 0 and bool(result)))
    return result
def _segments_from_text(
    text: str, bbox: BBox, split_length_words: int, split_length_numeric: int
) -> List[WordSegment]:
    """Split *text* via ``split_word_with_spaces`` into WordSegments sharing *bbox*."""
    return [
        WordSegment(
            token=seg_text,
            bbox=bbox,
            original_index=seg_idx,
            space_before=space_before,
        )
        for seg_idx, (seg_text, space_before) in enumerate(
            split_word_with_spaces(text, split_length_words, split_length_numeric)
        )
    ]


def extract_tasks(
    json_path: Path,
    data: List[Dict[str, Any]],
    split_length_words: int,
    split_length_numeric: int,
) -> Tuple[List[WordTask], List[Dict[str, Any]]]:
    """
    Turn one handwriting-bbox JSON document into word tasks.

    Each non-None entry of *data* is read through its ``id``,
    ``author-id``/``author_id``, ``bboxes`` and ``text`` fields.

    * With bboxes: every record is parsed with ``parse_bbox_record``, and its
      text is split into segments that share the record's bbox. A record that
      fails to parse (``ValueError``) is skipped and logged. So is a record
      whose text yields no segments (empty or whitespace-only), because an
      empty segment list cannot be rendered or stitched.
    * Without bboxes but with text: each whitespace-separated word becomes a
      task with a zero bbox, block/line ``-1``, and word numbers counting up
      from 100000 per ``hw_id``. These tasks get ``include_in_sentence=False``.
    * With neither: the entry is skipped and logged.

    Args:
        json_path: Path to the JSON file (only its name is recorded).
        data: Parsed JSON data.
        split_length_words: Maximum word length before splitting.
        split_length_numeric: Maximum length for numeric sequences within tokens
            (<= 0 disables special handling).

    Returns:
        ``(tasks, logs)``. Each log entry is a dict with ``type``,
        ``source_json``, ``hw_id`` and ``reason``, plus optional extra fields.
    """
    source_name = json_path.name
    tasks: List[WordTask] = []
    extraction_logs: List[Dict[str, Any]] = []
    fallback_counters: Dict[Any, int] = defaultdict(int)
    zero_bbox: BBox = (0.0, 0.0, 0.0, 0.0)

    def log(entry_type: str, hw_id: Any, reason: str, **extra: Any) -> None:
        """Append one extraction log entry for this file."""
        extraction_logs.append(
            {
                "type": entry_type,
                "source_json": source_name,
                "hw_id": hw_id,
                "reason": reason,
                **extra,
            }
        )

    for obj in data:
        if obj is None:
            continue
        hw_id = obj.get("id")
        author_id = obj.get("author-id") or obj.get("author_id")
        bboxes = obj.get("bboxes")
        text_content = (obj.get("text") or "").strip()

        if not bboxes:
            # No layout information: fall back to the free text, if any.
            if not text_content:
                log("extraction_skip", hw_id, "missing_bbox_no_text")
                continue
            # text_content is non-empty after strip(), so split() yields at
            # least one word.
            fallback_words = text_content.split()
            for raw_word in fallback_words:
                segments = _segments_from_text(
                    raw_word, zero_bbox, split_length_words, split_length_numeric
                )
                if not segments:
                    continue
                # Large offset keeps synthetic word numbers clear of real ones.
                word_no = 100000 + fallback_counters[hw_id]
                fallback_counters[hw_id] += 1
                tasks.append(
                    WordTask(
                        source_json=source_name,
                        hw_id=hw_id,
                        author_id=author_id,
                        block_no=-1,
                        line_no=-1,
                        word_no=word_no,
                        segments=segments,
                        original_bbox=zero_bbox,
                        include_in_sentence=False,
                        sentence_exclusion_reason="missing_bbox",
                    )
                )
            log(
                "extraction_notice",
                hw_id,
                "missing_bbox_generated",
                num_words=len(fallback_words),
            )
            continue

        for rec in bboxes:
            try:
                bbox, token, block_no, line_no, word_no = parse_bbox_record(rec)
            except ValueError as exc:
                log(
                    "extraction_skip",
                    hw_id,
                    "invalid_bbox_record",
                    record=rec,
                    error=str(exc),
                )
                continue
            # All segments share the word's bbox; stitching happens downstream.
            segments = _segments_from_text(
                token, bbox, split_length_words, split_length_numeric
            )
            if not segments:
                log("extraction_skip", hw_id, "empty_token", record=rec)
                continue
            tasks.append(
                WordTask(
                    source_json=source_name,
                    hw_id=hw_id,
                    author_id=author_id,
                    block_no=block_no,
                    line_no=line_no,
                    word_no=word_no,
                    segments=segments,
                    original_bbox=bbox,
                )
            )
    return tasks, extraction_logs
def style_id_for_file(json_name: str, author_id: str, seed: int, vocab: int) -> int:
    """
    Deterministically derive a style id in ``[0, vocab)`` for (json_name, author_id).

    The built-in ``hash()`` is salted per interpreter process for ``str``
    (see ``PYTHONHASHSEED``), so it would give different ids on every run.
    A SHA-256 digest of the composite key is stable across runs and machines.

    Args:
        json_name: Source JSON file name.
        author_id: Author identifier; it is formatted with ``str()``, so None
            becomes ``"None"``.
        seed: Integer mixed into the digest with XOR.
        vocab: Number of available styles; must be positive.

    Returns:
        A style id in ``[0, vocab)``.
    """
    import hashlib

    composite = f"{json_name}::{author_id}"
    digest = hashlib.sha256(composite.encode("utf-8")).digest()
    # 64 bits of the digest are plenty; Python's % keeps the result in
    # [0, vocab) even when a negative seed makes the XOR negative.
    return (int.from_bytes(digest[:8], "big") ^ seed) % vocab
def build_word_filename(task: WordTask) -> str:
    """
    Build ``<hw_id>_b<block>_l<line>_w<word>.png`` for *task*.

    A block or line number of ``None`` is written as ``X``. Negative numbers
    (used for words without a bbox) are kept verbatim, e.g. ``b-1``.
    """
    block_tag = "bX" if task.block_no is None else f"b{task.block_no}"
    line_tag = "lX" if task.line_no is None else f"l{task.line_no}"
    stem = "_".join((f"{task.hw_id}", block_tag, line_tag, f"w{task.word_no}"))
    return stem + ".png"
# ------------------------ generation -------------------------
def load_experiment(
    run_dir: Path, checkpoint_name: str, device: torch.device
) -> Dict[str, Any]:
    """
    Build all inference components from an experiment run directory.

    Based on the ``load_experiment`` function of ``inference_hf.ipynb``.

    Files read from *run_dir*:
        config.yaml          model / training / data configuration
        writer_id_map.json   writer key -> class index mapping
        <checkpoint_name>    checkpoint with "text_encoder" and "unet" state
                             dicts and an optional "ema" state. If it is not
                             found under *run_dir*, *checkpoint_name* is tried
                             as a path on its own.
        cached_vae/          VAE cache, created on first use in latent mode

    Args:
        run_dir: Experiment directory.
        checkpoint_name: Checkpoint file name inside *run_dir*, or a path.
        device: Device all modules are moved to.

    Returns:
        Dict with keys "tokenizer", "text_encoder", "unet", "noise_scheduler",
        "vae" (None unless mode is "latent"), "vae_scale_factor",
        "writer_id_map", "device", "config", "sample_shape" and "mode".

    Raises:
        FileNotFoundError: the run dir, config, writer map or checkpoint is missing.
        KeyError: a required configuration entry is missing.
    """
    run_dir = run_dir.expanduser().resolve()
    if not run_dir.exists():
        raise FileNotFoundError(f"Run directory {run_dir} does not exist.")
    config_path = run_dir / "config.yaml"
    if not config_path.exists():
        raise FileNotFoundError(f"Expected config at {config_path}.")
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    model_cfg = config["model"]
    # A missing (or empty) "training" section is tolerated everywhere. Previously
    # the "mode" lookup raised KeyError here while the EMA setup tolerated it.
    training_cfg = config.get("training") or {}
    mode = training_cfg.get("mode", "latent")

    # Resolve the sample shape before any heavy loading, so a bad config fails
    # fast with a clear message instead of tuple(None) at the very end.
    shape_key = "latent_shape" if mode == "latent" else "image_shape"
    raw_shape = model_cfg.get(shape_key)
    if raw_shape is None:
        raise KeyError(f"Missing 'model.{shape_key}' in {config_path}.")
    sample_shape = tuple(raw_shape)

    # Tokenizer: a relative vocab path is tried under run_dir, then under its parent.
    vocab_setting = config["data"]["vocab_path"]
    vocab_path = Path(vocab_setting)
    if not vocab_path.is_absolute():
        vocab_path = run_dir / vocab_path
    if not vocab_path.exists():
        vocab_path = run_dir.parent / vocab_setting
    tokenizer = CharTokenizer.load(str(vocab_path))

    # Writer map: its size sets the number of UNet class embeddings.
    writer_map_path = run_dir / "writer_id_map.json"
    if not writer_map_path.exists():
        raise FileNotFoundError(f"Expected writer mapping at {writer_map_path}.")
    with open(writer_map_path, "r", encoding="utf-8") as f:
        raw_writer_map = json.load(f)
    writer_id_map = {str(k): int(v) for k, v in raw_writer_map.items()}
    num_writers = len(writer_id_map)

    # Text encoder
    text_cfg = model_cfg["text_encoder"]
    text_encoder = TextEncoder(
        vocab_size=len(tokenizer),
        d_model=text_cfg["d_model"],
        num_layers=text_cfg["num_layers"],
        num_heads=text_cfg["num_heads"],
        d_ff=text_cfg["d_ff"],
        dropout=text_cfg["dropout"],
        max_length=text_cfg["max_length"],
        output_dim=text_cfg.get("output_dim", text_cfg["d_model"]),
    ).to(device)
    text_encoder.eval()

    # UNet
    unet_cfg = deepcopy(model_cfg["unet"])
    pretrained_path = unet_cfg.pop("pretrained_model_name_or_path", None)
    # YAML yields lists; these constructor arguments are converted to tuples.
    for key in ("down_block_types", "up_block_types", "block_out_channels", "sample_size"):
        if isinstance(unet_cfg.get(key), list):
            unet_cfg[key] = tuple(unet_cfg[key])
    unet_cfg["num_class_embeds"] = num_writers
    if pretrained_path:
        unet = UNet2DConditionModel.from_pretrained(
            pretrained_path, num_class_embeds=num_writers
        ).to(device)
    else:
        unet = UNet2DConditionModel(**unet_cfg).to(device)
    unet.eval()

    # Scheduler: DPM-Solver++ (order 3) for fast, high-quality sampling.
    scheduler_cfg = model_cfg["scheduler"]
    noise_scheduler = DPMSolverMultistepScheduler(
        num_train_timesteps=scheduler_cfg["num_train_timesteps"],
        beta_start=scheduler_cfg["beta_start"],
        beta_end=scheduler_cfg["beta_end"],
        beta_schedule=scheduler_cfg["beta_schedule"],
        prediction_type=scheduler_cfg.get("prediction_type", "epsilon"),
        algorithm_type="dpmsolver++",
        solver_order=3,  # higher order = better quality per step
        use_karras_sigmas=scheduler_cfg.get("use_karras_sigmas", False),
    )
    if "timestep_spacing" in scheduler_cfg:
        # NOTE(review): this sets an attribute on the already-built scheduler
        # config. Passing timestep_spacing to the constructor may be cleaner if
        # the installed diffusers version supports it; confirm before changing.
        noise_scheduler.config.timestep_spacing = scheduler_cfg["timestep_spacing"]

    # VAE (latent mode only)
    vae = None
    vae_scale_factor = 0.18215  # fixed here, not read from the VAE config
    if mode == "latent":
        vae_cfg = model_cfg.get("vae")
        if vae_cfg is None:
            raise KeyError("Latent mode requires 'model.vae' configuration.")
        vae_model_name = vae_cfg["model_name"]
        vae_cache_dir = run_dir / "cached_vae"
        if vae_cache_dir.exists():
            vae = AutoencoderKL.from_pretrained(vae_cache_dir).to(device)
        else:
            vae = AutoencoderKL.from_pretrained(vae_model_name).to(device)
            vae_cache_dir.mkdir(parents=True, exist_ok=True)
            vae.save_pretrained(vae_cache_dir)
        vae.eval()

    # Checkpoint
    checkpoint_path = run_dir / checkpoint_name
    if not checkpoint_path.exists():
        checkpoint_path = Path(checkpoint_name)
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint {checkpoint_name} not found.")
    print(f"Loading checkpoint {checkpoint_path}")
    # torch.load deserializes pickled data; only point this at trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    text_encoder.load_state_dict(checkpoint["text_encoder"])
    unet.load_state_dict(checkpoint["unet"], strict=False)

    if "ema" in checkpoint:
        ema_model = EMAModel(
            unet.parameters(),
            decay=training_cfg.get("ema_decay", 0.9999),
            use_ema_warmup=training_cfg.get("ema_use_warmup", False),
            inv_gamma=training_cfg.get("ema_inv_gamma", 1.0),
            power=training_cfg.get("ema_power", 1.0),
            min_decay=training_cfg.get("ema_min_decay", 0.0),
            device=device,
            model_cls=UNet2DConditionModel,
            model_config=unet.config,
        )
        ema_model.load_state_dict(checkpoint["ema"])
        ema_model.to(device)
        # Run inference with the EMA-averaged weights.
        ema_model.copy_to(unet.parameters())

    return {
        "tokenizer": tokenizer,
        "text_encoder": text_encoder,
        "unet": unet,
        "noise_scheduler": noise_scheduler,
        "vae": vae,
        "vae_scale_factor": vae_scale_factor,
        "writer_id_map": writer_id_map,
        "device": device,
        "config": config,
        "sample_shape": sample_shape,
        "mode": mode,
    }
def diffusion_generate_batch(
    tokens: List[str],
    style_ids: List[int],
    components: Dict[str, Any],
    steps: int,
    temperature: float = 1.0,
) -> List[Image.Image]:
    """
    Generate one handwriting image per token with the diffusion model.

    Based on ``sample_diffusion`` from ``inference_hf.ipynb``.

    Args:
        tokens: Texts to render, one image each.
        style_ids: Writer class index per token, passed to the UNet as
            ``class_labels``; must have the same length as *tokens*.
        components: Dict as returned by ``load_experiment``.
        steps: Number of scheduler inference steps.
        temperature: Multiplier applied to the initial Gaussian noise.

    Returns:
        Binarized, tightly cropped RGBA images in token order (ink opaque
        black, background transparent). Each image's ``info["crop_box"]``
        holds the ``(x1, y1, x2, y2)`` crop in the decoded sample's pixels.

    Raises:
        ValueError: if ``len(style_ids) != len(tokens)``.
    """
    if not tokens:
        return []
    if len(style_ids) != len(tokens):
        raise ValueError(
            f"Expected one style id per token, got {len(style_ids)} "
            f"for {len(tokens)} tokens."
        )

    device = components["device"]
    tokenizer = components["tokenizer"]
    text_encoder = components["text_encoder"]
    unet = components["unet"]
    noise_scheduler = components["noise_scheduler"]
    sample_shape = components["sample_shape"]
    mode = components["mode"]
    vae = components.get("vae")
    vae_scale_factor = components.get("vae_scale_factor", 0.18215)
    batch_size = len(tokens)

    # Pure inference. eval() alone does not disable autograd, so without
    # no_grad every UNet step and the VAE decode would record a graph and
    # hold its activations in memory.
    with torch.no_grad():
        encodings = tokenizer.encode_batch(tokens)
        input_ids = torch.tensor(
            encodings["input_ids"], device=device, dtype=torch.long
        )
        attention_mask = torch.tensor(
            encodings["attention_mask"], device=device, dtype=torch.float32
        )
        writer_indices = torch.tensor(style_ids, device=device, dtype=torch.long)

        noise_scheduler.set_timesteps(steps, device=device)
        latents = torch.randn((batch_size,) + tuple(sample_shape), device=device)
        latents = latents * temperature
        # Start at the scheduler's initial noise level (1.0 for the
        # DPMSolverMultistepScheduler built by load_experiment).
        latents = latents * noise_scheduler.init_noise_sigma

        text_features = text_encoder(input_ids, attention_mask=attention_mask)

        for timestep in noise_scheduler.timesteps:
            t_batch = torch.full(
                (batch_size,), int(timestep), device=device, dtype=torch.long
            )
            # Identity for DPM-Solver++, but part of the generic diffusers
            # sampling contract that other schedulers rely on.
            model_input = noise_scheduler.scale_model_input(latents, timestep)
            model_output = unet(
                model_input,
                t_batch,
                encoder_hidden_states=text_features,
                encoder_attention_mask=attention_mask,
                class_labels=writer_indices,
            )
            noise_pred = (
                model_output.sample if hasattr(model_output, "sample") else model_output
            )
            latents = noise_scheduler.step(noise_pred, int(timestep), latents).prev_sample

        if mode == "latent" and vae is not None:
            decoded = vae.decode(latents / vae_scale_factor).sample
        else:
            decoded = latents
        # Model output is mapped from [-1, 1] to [0, 1].
        images = (decoded / 2 + 0.5).clamp(0.0, 1.0)
        # float() first: numpy cannot represent e.g. bfloat16 tensors.
        imgs = images.float().cpu().numpy()

    # Convert to binarized, cropped RGBA PIL images.
    results: List[Image.Image] = []
    for arr in imgs:
        if arr.shape[0] == 1:
            arr = arr[0]  # drop the single channel axis
        else:
            arr = arr.transpose(1, 2, 0)  # CHW -> HWC
        arr8 = (arr * 255).round().astype("uint8")
        if arr8.ndim == 3:
            arr8 = arr8.mean(axis=2).astype("uint8")
        thresh = otsu_threshold(arr8)
        bin_arr = (arr8 > thresh).astype("uint8") * 255
        cropped, crop_box = crop_to_content(bin_arr)
        rgba = binary_to_rgba(cropped)
        rgba.info["crop_box"] = crop_box
        results.append(rgba)
    return results
# ---------------------- binarization utils -------------------
def otsu_threshold(arr8):
    """
    Return Otsu's global threshold for an 8-bit grayscale array.

    Every grey level ``t`` is scanned, and the first one that maximises the
    between-class variance of ``{<= t}`` versus ``{> t}`` is kept. The result
    is 0 when no level separates two non-empty classes (e.g. a constant image).
    Values are expected in 0..255; the 256-bin arithmetic assumes that range.
    """
    histogram = np.bincount(arr8.ravel(), minlength=256).astype(np.float64)
    pixel_count = arr8.size
    intensity_total = (histogram * np.arange(256)).sum()

    best_level = 0
    best_score = -1.0
    low_weight = 0.0
    low_sum = 0.0
    for level in range(256):
        count = histogram[level]
        low_weight += count
        if low_weight == 0:
            continue  # no pixels at or below this level yet
        high_weight = pixel_count - low_weight
        if high_weight == 0:
            break  # every pixel is already in the low class
        low_sum += level * count
        low_mean = low_sum / low_weight
        high_mean = (intensity_total - low_sum) / high_weight
        score = low_weight * high_weight * (low_mean - high_mean) ** 2
        if score > best_score:
            best_score = score
            best_level = level
    return best_level
# ---------------------- cropping & alpha --------------------
def crop_to_content(bin_arr: np.ndarray, pad: int = 0):
    """
    Crop a 2-D mask (0 = ink, 255 = background) to its ink bounding box.

    Any value below 255 counts as ink. ``pad`` extra pixels are kept on each
    side, clamped to the array bounds.

    Returns:
        ``(cropped_array, (x1, y1, x2, y2))`` with exclusive ``x2``/``y2``.
        Without any ink, the top-left 1x1 slice and ``(0, 0, 1, 1)`` are returned.
    """
    height, width = bin_arr.shape
    ink = bin_arr < 255
    if not ink.any():
        return bin_arr[:1, :1], (0, 0, 1, 1)

    ink_rows = np.flatnonzero(ink.any(axis=1))
    ink_cols = np.flatnonzero(ink.any(axis=0))
    top, bottom = ink_rows[0], ink_rows[-1]
    left, right = ink_cols[0], ink_cols[-1]
    if pad:
        top = max(0, top - pad)
        left = max(0, left - pad)
        bottom = min(height - 1, bottom + pad)
        right = min(width - 1, right + pad)

    box = (int(left), int(top), int(right) + 1, int(bottom) + 1)
    return bin_arr[top : bottom + 1, left : right + 1], box
def binary_to_rgba(bin_arr: np.ndarray) -> Image.Image:
    """
    Convert a 2-D binary mask (0 = ink, 255 = background) into an RGBA image.

    Pixels equal to 0 become opaque black; every other value becomes fully
    transparent black.
    """
    height, width = bin_arr.shape
    rgba = np.zeros((height, width, 4), dtype=np.uint8)  # black, transparent
    rgba[..., 3] = np.where(bin_arr == 0, 255, 0)
    # Pillow infers "RGBA" from an (h, w, 4) uint8 array. The explicit `mode`
    # argument of Image.fromarray is deprecated (Pillow 11.3), so it is omitted.
    return Image.fromarray(rgba)
def pad_tokens_to_equal_length(tokens: List[str]) -> List[str]:
    """
    Right-pad every token with spaces to the length of the longest one.

    An empty input is returned as-is; otherwise a new list is returned.
    Nothing is printed. The previous version dumped the padded list to stdout
    on every call.
    """
    if not tokens:
        return tokens
    width = max(map(len, tokens))
    return [t.ljust(width) for t in tokens]
def calculate_baseline_info(
    img: Image.Image, baseline_percentile: float = 85.0
) -> Dict[str, Any]:
    """
    Calculate baseline information for an RGBA image.

    Pixels with alpha > 200 count as ink. An image without an alpha channel
    (including 2-D grayscale data) is treated as fully opaque. For every column
    containing ink, the row of its lowest ink pixel is found. The baseline is
    the given percentile of those rows.

    Args:
        img: PIL image (normally RGBA).
        baseline_percentile: Percentile for baseline detection (default: 85.0).

    Returns:
        Dictionary with baseline metrics:
            - baseline_y: Absolute baseline position (pixels from top);
              ``height - 1`` when there is no ink
            - baseline_ratio: Baseline as ratio of height (0.0-1.0)
            - height_above: Pixels above baseline
            - height_below: Pixels below baseline
            - ascender_ratio: Ratio of height above baseline
            - descender_ratio: Ratio of height below baseline
    """
    arr = np.array(img)
    height = img.height
    if arr.ndim == 3 and arr.shape[2] == 4:  # RGBA
        alpha = arr[:, :, 3]
    else:
        # No alpha channel (RGB, or 2-D grayscale that used to raise IndexError).
        alpha = np.full((height, img.width), 255, dtype=np.uint8)

    ink_mask = alpha > 200
    ink_cols = ink_mask.any(axis=0)
    if ink_cols.any():
        # argmax on the vertically flipped mask finds the last ink row per column.
        bottoms = (height - 1) - np.argmax(ink_mask[::-1, :], axis=0)
        baseline_y = int(np.percentile(bottoms[ink_cols], baseline_percentile))
    else:
        baseline_y = height - 1

    height_above = baseline_y
    height_below = height - 1 - baseline_y
    return {
        "baseline_y": baseline_y,
        "baseline_ratio": baseline_y / height if height > 0 else 0.0,
        "height_above": height_above,
        "height_below": height_below,
        "ascender_ratio": height_above / height if height > 0 else 0.0,
        "descender_ratio": height_below / height if height > 0 else 0.0,
    }
def concatenate_images_horizontal(
    images: List[Image.Image],
    gap: int = 0,
    baseline_align: bool = True,
    baseline_percentile: float = 75.0,
) -> Image.Image:
    """
    Place RGBA images side by side on a transparent canvas.

    Args:
        images: RGBA images, left to right.
        gap: Horizontal spacing between neighbouring images, in pixels.
        baseline_align: If True, shift images vertically so their baselines
            (from ``calculate_baseline_info``) share one row. If False, centre
            each image vertically.
        baseline_percentile: Percentile passed to ``calculate_baseline_info``
            (default: 75.0).

    Returns:
        The only input image itself when a single image is given, otherwise a
        new RGBA image.

    Raises:
        ValueError: if *images* is empty.
    """
    if not images:
        raise ValueError("Cannot concatenate empty image list")
    if len(images) == 1:
        return images[0]

    total_width = sum(img.width for img in images) + gap * (len(images) - 1)
    if baseline_align:
        baselines = [
            calculate_baseline_info(img, baseline_percentile)["baseline_y"]
            for img in images
        ]
        # The canvas must fit the tallest extent above and below the shared baseline.
        max_above = max([0, *baselines])
        max_below = max(
            [0, *(img.height - 1 - b for img, b in zip(images, baselines))]
        )
        canvas_height = max_above + 1 + max_below
        y_offsets = [max_above - b for b in baselines]
    else:
        canvas_height = max(img.height for img in images)
        y_offsets = [(canvas_height - img.height) // 2 for img in images]

    result = Image.new("RGBA", (total_width, canvas_height), (0, 0, 0, 0))
    x_offset = 0
    for img, y_offset in zip(images, y_offsets):
        # The image doubles as its own paste mask, so transparency is preserved.
        result.paste(img, (x_offset, y_offset), img)
        x_offset += img.width + gap
    return result
def concatenate_segments_with_variable_gaps(
    images: List[Image.Image],
    segments: List[WordSegment],
    segment_gap: int = 2,
    word_gap: int = 20,
    baseline_percentile: float = 75.0,
) -> Image.Image:
    """
    Stitch word-segment images left to right with baseline alignment.

    The gap before each image after the first is ``word_gap`` when its segment
    has ``space_before`` set (whitespace in the original text). Otherwise it is
    ``segment_gap`` (a length-based split inside one word).

    Args:
        images: RGBA segment images, one per entry of *segments*.
        segments: WordSegment objects providing the ``space_before`` flags.
        segment_gap: Gap for length-split segments (no space in original).
        word_gap: Gap for space-separated segments.
        baseline_percentile: Percentile passed to ``calculate_baseline_info``.

    Returns:
        The only image itself for a single segment, otherwise a new RGBA image.

    Raises:
        ValueError: if *images* is empty or its length differs from *segments*.
            The length check now also applies to single-image input.
    """
    if not images:
        raise ValueError("Cannot concatenate empty image list")
    if len(images) != len(segments):
        raise ValueError(f"Mismatch: {len(images)} images but {len(segments)} segments")
    if len(images) == 1:
        return images[0]

    baselines = [
        calculate_baseline_info(img, baseline_percentile)["baseline_y"]
        for img in images
    ]
    max_above = max([0, *baselines])
    max_below = max([0, *(img.height - 1 - b for img, b in zip(images, baselines))])
    canvas_height = max_above + 1 + max_below

    # gaps[i] is the spacing inserted before images[i]; there is none before the first.
    gaps = [0] + [word_gap if seg.space_before else segment_gap for seg in segments[1:]]
    total_width = sum(img.width for img in images) + sum(gaps)

    result = Image.new("RGBA", (total_width, canvas_height), (0, 0, 0, 0))
    x_offset = 0
    for img, baseline, gap in zip(images, baselines, gaps):
        x_offset += gap
        # The image doubles as its own paste mask, so transparency is preserved.
        result.paste(img, (x_offset, max_above - baseline), img)
        x_offset += img.width
    return result
# -------------------------- main -----------------------------
def generate_handwriting(
    input_dir: Path,
    output_dir: Path,
    run_dir: Path,
    checkpoint: str = "latest.pt",
    progress: Progress | None = None,
    steps: int = 30,
    split_length_words: int = 6,
    split_length_numeric: int = 2,
    temperature: float = 0.5,
    seed: int = 42,
    device: str = "cuda",
    overwrite: bool = False,
    mapping_file: Optional[Path] = None,
    log_file: Optional[Path] = None,
    batch_size: int = 32,
    stitch_sentences: bool = True,
    segment_gap: int = 2,
    word_gap: int = 20,
    baseline_percentile: float = 75.0,
    allowed_writers: Optional[List[str]] = None,
) -> None:
    """Generate handwriting images and metadata using configured diffusion models.

    Pipeline:
      1. Load the experiment from ``run_dir`` / ``checkpoint`` via ``load_experiment``.
      2. Turn every JSON file in ``input_dir`` into word tasks via ``extract_tasks``.
      3. Assign one writer style per (source JSON, author_id) pair.
      4. Render each word (all of its segments in one ``diffusion_generate_batch``
         call), stitch its segments and save it under ``output_dir/<json stem>/``.
         Existing images are reused unless ``overwrite`` is set.
      5. Optionally stitch the words of each line into sentence images.
      6. Write the word mapping JSON and a JSON run log.

    Args:
        input_dir: Directory with handwriting bbox JSON files.
        output_dir: Root directory for word images, sentences and default outputs.
        run_dir: Experiment directory passed to ``load_experiment``.
        checkpoint: Checkpoint filename passed to ``load_experiment``.
        progress: Optional caller-managed rich ``Progress``. When omitted, a
            transient one is created, started and stopped by this function.
        steps: Number of diffusion sampling steps.
        split_length_words: Word split length forwarded to ``extract_tasks``
            (0 disables splitting).
        split_length_numeric: Digit chunk length forwarded to ``extract_tasks``.
        temperature: Sampling temperature for ``diffusion_generate_batch``.
        seed: Seed for ``random``/``torch`` and for writer-style selection.
        device: Requested torch device; anything other than "cpu" falls back
            to CPU when CUDA is unavailable.
        overwrite: Regenerate word and sentence images that already exist.
        mapping_file: Mapping output path (default ``output_dir/raw_token_map.json``).
        log_file: Log output path (default ``output_dir/generation_log.json``).
        batch_size: Number of words handled between progress updates. NOTE: words
            are still rendered one at a time; only a word's segments share a call.
        stitch_sentences: Also build per-line sentence images in ``output_dir/sentences``.
        segment_gap: Pixel gap between length-split segments of a word.
        word_gap: Pixel gap between space-separated segments and between words
            in sentence images.
        baseline_percentile: Percentile used for baseline detection.
        allowed_writers: Optional writer indices (as strings) that styles are
            restricted to.

    Exits the process with status 1 when no JSON files are found or when
    ``allowed_writers`` contains no valid writer index.
    """
    import traceback
    from datetime import timezone

    random.seed(seed)
    torch.manual_seed(seed)
    device_obj = torch.device(
        device if torch.cuda.is_available() or device == "cpu" else "cpu"
    )
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    run_dir = Path(run_dir)
    mapping_file = Path(mapping_file) if mapping_file is not None else None
    log_file = Path(log_file) if log_file is not None else None

    # Load model components
    print(f"Loading model from {run_dir}...")
    components = load_experiment(run_dir, checkpoint, device_obj)
    print("✓ Model loaded successfully")
    print(f" Mode: {components['mode']}")
    print(f" Sample shape: {components['sample_shape']}")
    print(f" Writers: {len(components['writer_id_map'])}")

    # A Progress we create ourselves must also be started/stopped by us,
    # otherwise it never renders.
    owns_progress = progress is None
    if progress is None:
        progress = Progress(transient=True)

    try:
        output_dir.mkdir(parents=True, exist_ok=True)

        # Load JSON files
        json_files = list_json_files(input_dir)
        if not json_files:
            print("[ERROR] No JSON files found.", file=sys.stderr)
            sys.exit(1)
        print(f"Found {len(json_files)} JSON files")

        # Extract tasks with word splitting
        tasks: List[WordTask] = []
        extraction_logs: List[Dict[str, Any]] = []
        for jf in json_files:
            extracted_tasks, extracted_log_entries = extract_tasks(
                jf, load_json(jf), split_length_words, split_length_numeric
            )
            tasks.extend(extracted_tasks)
            extraction_logs.extend(extracted_log_entries)
        print(f"Extracted {len(tasks)} word tasks")
        if split_length_words > 0:
            total_segments = sum(len(t.segments) for t in tasks)
            print(
                f" Split into {total_segments} segments (split_length={split_length_words}, digit_chunk_length={split_length_numeric})"
            )

        # Per-file author style mapping
        writer_id_map = components["writer_id_map"]
        allowed_writer_ids = _hw_parse_allowed_writers(
            allowed_writers, len(writer_id_map)
        )
        file_author_style_ids = _hw_assign_styles(
            tasks, writer_id_map, allowed_writer_ids, seed
        )

        results: List[Dict[str, Any]] = []
        generation_skip_log: List[Dict[str, Any]] = []
        generation_error_log: List[Dict[str, Any]] = []
        sentence_exclusion_log: List[Dict[str, Any]] = []

        total_words = len(tasks)
        effective_batch_size = max(1, batch_size)
        if owns_progress:
            progress.start()
        generation_task_id = progress.add_task("Generating words", total=total_words)

        for start in range(0, total_words, effective_batch_size):
            batch_tasks = tasks[start : start + effective_batch_size]
            for task in batch_tasks:
                json_stem = Path(task.source_json).stem
                doc_dir = output_dir / json_stem
                doc_dir.mkdir(parents=True, exist_ok=True)
                # Filename includes block and line numbers to avoid collisions across lines
                out_name = build_word_filename(task)
                relative_image_path = f"{json_stem}/{out_name}"
                out_path = doc_dir / out_name
                style_id = file_author_style_ids[task.source_json][task.author_id]
                location = _hw_task_location(task)

                if out_path.exists() and not overwrite:
                    # Reuse the existing image; only its metadata is recomputed.
                    try:
                        with Image.open(out_path) as existing_img:
                            entry = _hw_word_result(
                                task,
                                relative_image_path,
                                style_id,
                                existing_img,
                                baseline_percentile,
                                skip_reason="existing_output",
                            )
                    except Exception as e:
                        print(f"[WARN] Could not load existing {out_path}: {e}")
                        continue
                    results.append(entry)
                    generation_skip_log.append(
                        {
                            "type": "existing_output",
                            **location,
                            "image": relative_image_path,
                        }
                    )
                else:
                    try:
                        final_image = _hw_render_word(
                            task,
                            style_id,
                            components,
                            steps,
                            temperature,
                            segment_gap,
                            word_gap,
                            baseline_percentile,
                        )
                        final_image.save(out_path)
                        entry = _hw_word_result(
                            task,
                            relative_image_path,
                            style_id,
                            final_image,
                            baseline_percentile,
                            skip_reason=None,
                        )
                    except Exception as e:
                        print(
                            f"[ERROR] Generation failed for {task.hw_id} word {task.word_no}: {e}",
                            file=sys.stderr,
                        )
                        traceback.print_exc()
                        generation_error_log.append(
                            {
                                "type": "generation_error",
                                **location,
                                "reason": str(e),
                                "traceback": traceback.format_exc(),
                            }
                        )
                        continue
                    results.append(entry)

                if not task.include_in_sentence:
                    sentence_exclusion_log.append(
                        {
                            **location,
                            "image": relative_image_path,
                            "reason": task.sentence_exclusion_reason
                            or "manual_exclusion",
                        }
                    )
            progress.advance(generation_task_id, len(batch_tasks))

        # Sentence-level stitching (if requested)
        if stitch_sentences:
            _hw_stitch_sentences(
                results,
                output_dir,
                progress,
                overwrite=overwrite,
                word_gap=word_gap,
                baseline_percentile=baseline_percentile,
            )

        # Export file author styles
        file_author_styles_export = {
            fname: {aid: {"style_id": sid} for aid, sid in inner.items()}
            for fname, inner in sorted(file_author_style_ids.items())
        }
        consolidated = {
            "backend": "diffusion-hf",
            "split_length": split_length_words,
            "digit_chunk_length": split_length_numeric,
            "temperature": temperature,
            "steps": steps,
            "segment_gap": segment_gap,
            "word_gap": word_gap if stitch_sentences else None,
            "baseline_percentile": baseline_percentile,
            "entries": _hw_mapping_entries(results),
            "file_author_styles": file_author_styles_export,
        }
        mapping_path = mapping_file or (output_dir / "raw_token_map.json")
        with mapping_path.open("w", encoding="utf-8") as f:
            json.dump(consolidated, f, ensure_ascii=False, indent=2)

        generated_count = sum(1 for r in results if not r["skipped"])
        reused_count = len(results) - generated_count
        log_file_path = log_file or (output_dir / "generation_log.json")
        log_payload = {
            # Same "...Z" format as the former datetime.utcnow().isoformat() + "Z".
            "timestamp": datetime.now(timezone.utc)
            .isoformat()
            .replace("+00:00", "Z"),
            "summary": {
                "total_tasks": len(tasks),
                "extraction_skips": sum(
                    1
                    for entry in extraction_logs
                    if entry.get("type") == "extraction_skip"
                ),
                "words_generated": generated_count,
                "words_reused": reused_count,
                "generation_errors": len(generation_error_log),
                "sentence_exclusions": len(sentence_exclusion_log),
            },
            "details": {
                "extraction": extraction_logs,
                "generation_skips": generation_skip_log,
                "generation_errors": generation_error_log,
                "sentence_exclusions": sentence_exclusion_log,
            },
        }
        with log_file_path.open("w", encoding="utf-8") as log_fp:
            json.dump(log_payload, log_fp, ensure_ascii=False, indent=2)

        print(
            f"\n✓ Generated {generated_count} word images ({reused_count} reused from disk)"
        )
        print(f"✓ Mapping saved: {mapping_path}")
        print(f"✓ Log saved: {log_file_path}")
        print("[DONE] Freeing up memory..")
    finally:
        if owns_progress:
            progress.stop()  # no-op if it was never started
        # Drop the model references so the CUDA caching allocator can release them,
        # also when generation aborted with an exception or sys.exit().
        if isinstance(components, dict):
            components.clear()
        del components
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def _hw_task_location(task: "WordTask") -> Dict[str, Any]:
    """Return the fields that identify a word task in log entries."""
    return {
        "source_json": task.source_json,
        "hw_id": task.hw_id,
        "word_no": task.word_no,
        "block_no": task.block_no,
        "line_no": task.line_no,
    }


def _hw_word_result(
    task: "WordTask",
    relative_image_path: str,
    style_id: int,
    image: "Image.Image",
    baseline_percentile: float,
    skip_reason: Optional[str],
) -> Dict[str, Any]:
    """Build the per-word result record (used for the mapping file and stitching)."""
    width, height = image.size
    return {
        "image": relative_image_path,
        "hw_id": task.hw_id,
        "author_id": task.author_id,
        "style_id": style_id,
        "source_json": task.source_json,
        "block_no": task.block_no,
        "line_no": task.line_no,
        "word_no": task.word_no,
        "segments": [
            {
                "token": seg.token,
                "bbox": list(seg.bbox),
                "space_before": seg.space_before,
            }
            for seg in task.segments
        ],
        "skipped": skip_reason is not None,
        "skip_reason": skip_reason,
        "include_in_sentence": task.include_in_sentence,
        "sentence_exclusion_reason": task.sentence_exclusion_reason,
        "width": width,
        "height": height,
        "baseline": calculate_baseline_info(
            image, baseline_percentile=baseline_percentile
        ),
    }


def _hw_parse_allowed_writers(
    allowed_writers: Optional[List[str]], num_writers: int
) -> Optional[List[int]]:
    """Validate ``--allowed-writers`` values; exit(1) if none is usable.

    Returns None when no restriction was requested.
    """
    if allowed_writers is None:
        return None
    allowed_writer_ids: List[int] = []
    for w in allowed_writers:
        try:
            writer_id = int(w)
        except ValueError:
            print(f"[WARNING] Invalid writer ID '{w}', must be integer, ignoring")
            continue
        if 0 <= writer_id < num_writers:
            allowed_writer_ids.append(writer_id)
        else:
            print(
                f"[WARNING] Writer ID {writer_id} out of range (0-{num_writers - 1}), ignoring"
            )
    if not allowed_writer_ids:
        print(
            "[ERROR] No valid writer IDs provided in --allowed-writers",
            file=sys.stderr,
        )
        sys.exit(1)
    print(
        f"Using {len(allowed_writer_ids)} allowed writer(s): {sorted(allowed_writer_ids)}"
    )
    return allowed_writer_ids


def _hw_assign_styles(
    tasks: List["WordTask"],
    writer_id_map: Dict[str, int],
    allowed_writer_ids: Optional[List[int]],
    seed: int,
) -> Dict[str, Dict[str, int]]:
    """Pick one style id per (source_json, author_id), in task order.

    Known authors use their ``writer_id_map`` index, unless an allow-list
    excludes it. Unknown authors get a seeded random allowed writer, or else
    ``style_id_for_file``.
    """
    rng = random.Random(seed)
    styles: Dict[str, Dict[str, int]] = {}
    for t in tasks:
        per_file = styles.setdefault(t.source_json, {})
        if t.author_id in per_file:
            continue
        if t.author_id in writer_id_map:
            style_id = writer_id_map[t.author_id]
            if allowed_writer_ids is not None and style_id not in allowed_writer_ids:
                style_id = rng.choice(allowed_writer_ids)
        elif allowed_writer_ids is not None:
            style_id = rng.choice(allowed_writer_ids)
        else:
            style_id = style_id_for_file(
                t.source_json, t.author_id, seed, len(writer_id_map)
            )
        per_file[t.author_id] = style_id
    return styles


def _hw_render_word(
    task: "WordTask",
    style_id: int,
    components: Dict[str, Any],
    steps: int,
    temperature: float,
    segment_gap: int,
    word_gap: int,
    baseline_percentile: float,
) -> "Image.Image":
    """Render all segments of one word in a single model call and stitch them."""
    tokens = [seg.token for seg in task.segments]
    segment_images = diffusion_generate_batch(
        tokens,
        [style_id] * len(tokens),
        components,
        steps,
        temperature=temperature,
    )
    if not segment_images:
        raise RuntimeError(
            f"Model returned no images for {task.hw_id} word {task.word_no}"
        )
    if len(segment_images) == 1:
        return segment_images[0]
    # word_gap for segments preceded by a space, segment_gap for length splits
    return concatenate_segments_with_variable_gaps(
        segment_images,
        task.segments,
        segment_gap=segment_gap,
        word_gap=word_gap,
        baseline_percentile=baseline_percentile,
    )


def _hw_stitch_line(
    output_dir: Path,
    sentences_dir: Path,
    key: Tuple[str, str, int, int],
    word_list: List[Dict[str, Any]],
    overwrite: bool,
    word_gap: int,
    baseline_percentile: float,
) -> Optional[Dict[str, Any]]:
    """Create (or reuse) the sentence image of one line and return its map entry.

    Returns None when none of the line's word images exists on disk.
    """
    source_json, hw_id, block_no, line_no = key
    json_stem = Path(source_json).stem
    sent_doc_dir = sentences_dir / json_stem
    sent_doc_dir.mkdir(parents=True, exist_ok=True)
    sent_name = f"{hw_id}_block{block_no}_line{line_no}.png"
    sent_relative_path = f"sentences/{json_stem}/{sent_name}"
    sent_path = sent_doc_dir / sent_name

    # Only words whose image exists end up in the picture, so text and
    # num_words must be derived from the same subset.
    ordered = sorted(word_list, key=lambda w: w["word_no"])
    present = [w for w in ordered if (output_dir / w["image"]).exists()]
    if len(present) < len(ordered):
        print(
            f"[WARN] {len(ordered) - len(present)} word image(s) missing for {hw_id} block{block_no} line{line_no}",
            file=sys.stderr,
        )
    if not present:
        return None

    if sent_path.exists() and not overwrite:
        # Keep the existing sentence but still record it in the sentence map.
        with Image.open(sent_path) as existing:
            width, height = existing.size
    else:
        word_images = []
        for word_data in present:
            # copy() loads the pixels so the file handle can be closed right away
            with Image.open(output_dir / word_data["image"]) as img:
                word_images.append(img.copy())
        sentence_image = concatenate_images_horizontal(
            word_images,
            gap=word_gap,
            baseline_align=True,
            baseline_percentile=baseline_percentile,
        )
        sentence_image.save(sent_path)
        width, height = sentence_image.size

    line_text = " ".join(
        "".join(seg["token"] for seg in w["segments"]) for w in present
    )
    return {
        "image": sent_relative_path,
        "source_json": source_json,
        "hw_id": hw_id,
        "block_no": block_no,
        "line_no": line_no,
        "text": line_text,
        "num_words": len(present),
        "width": width,
        "height": height,
    }


def _hw_stitch_sentences(
    results: List[Dict[str, Any]],
    output_dir: Path,
    progress: "Progress",
    overwrite: bool,
    word_gap: int,
    baseline_percentile: float,
) -> None:
    """Stitch word images of each line into sentences and write sentence_map.json."""
    print("\nStitching words into sentences...")
    sentences_dir = output_dir / "sentences"
    sentences_dir.mkdir(parents=True, exist_ok=True)

    # Group by (source_json, hw_id, block_no, line_no). Reused (skipped) words are
    # included on purpose: after an interrupted run their lines still need sentences.
    line_groups: Dict[Tuple[str, str, int, int], List[Dict[str, Any]]] = {}
    for r in results:
        if not r.get("include_in_sentence", True):
            continue
        key = (r["source_json"], r["hw_id"], r["block_no"], r["line_no"])
        line_groups.setdefault(key, []).append(r)

    sentence_results: List[Dict[str, Any]] = []
    sentence_task_id = progress.add_task("Stitching sentences", total=len(line_groups))
    for key, word_list in line_groups.items():
        try:
            entry = _hw_stitch_line(
                output_dir,
                sentences_dir,
                key,
                word_list,
                overwrite,
                word_gap,
                baseline_percentile,
            )
            if entry is not None:
                sentence_results.append(entry)
        except Exception as e:
            _, hw_id, block_no, line_no = key
            print(
                f"[ERROR] Failed to stitch sentence {hw_id} block{block_no} line{line_no}: {e}",
                file=sys.stderr,
            )
        finally:
            progress.advance(sentence_task_id, 1)

    sentence_mapping_file = sentences_dir / "sentence_map.json"
    with sentence_mapping_file.open("w", encoding="utf-8") as f:
        json.dump(
            {
                "backend": "diffusion-hf-sentences",
                "word_gap": word_gap,
                "sentences": sentence_results,
            },
            f,
            ensure_ascii=False,
            indent=2,
        )
    print(f"✓ Mapped {len(sentence_results)} sentence images")
    print(f"✓ Sentence mapping saved: {sentence_mapping_file}")


def _hw_mapping_entries(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Group word results per (source_json, hw_id), sorted for stable output."""
    entries_map: Dict[Tuple[str, str], List[Dict[str, Any]]] = {}
    for r in results:
        entries_map.setdefault((r["source_json"], r["hw_id"]), []).append(r)
    entries: List[Dict[str, Any]] = []
    for (src, hw), words in sorted(entries_map.items(), key=lambda kv: kv[0]):
        ordered = sorted(
            words, key=lambda x: (x["block_no"], x["line_no"], x["word_no"])
        )
        entries.append(
            {
                "source_json": src,
                "hw_id": hw,
                "author_id": words[0]["author_id"] if words else None,
                "words": [
                    {
                        "block_no": w["block_no"],
                        "line_no": w["line_no"],
                        "word_no": w["word_no"],
                        "image": w["image"],
                        "style_id": w["style_id"],
                        "width": w["width"],
                        "height": w["height"],
                        "baseline": w["baseline"],
                        "segments": w["segments"],
                    }
                    for w in ordered
                ],
            }
        )
    return entries
def main() -> None:
ap = argparse.ArgumentParser(
description="Diffusion-based handwriting token generator with intelligent word splitting."
)
ap.add_argument(
"--input-dir",
type=Path,
required=True,
help="Directory containing bbox JSON files",
)
ap.add_argument(
"--output-dir",
type=Path,
required=True,
help="Output directory for generated images",
)
ap.add_argument(
"--run-dir",
type=Path,
required=True,
help="Model experiment directory (e.g., model/experiments/hf_conditional_latent)",
)
ap.add_argument(
"--checkpoint", type=str, default="latest.pt", help="Checkpoint filename"
)
ap.add_argument("--steps", type=int, default=30, help="Number of diffusion steps")
ap.add_argument(
"--split-length-words",
type=int,
default=6,
help="Maximum word length before splitting (0 = no splitting)",
)
ap.add_argument(
"--temperature", type=float, default=0.5, help="Sampling temperature"
)
ap.add_argument("--seed", type=int, default=42, help="Random seed")
ap.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)")
ap.add_argument(
"--overwrite", action="store_true", help="Overwrite existing images"
)
ap.add_argument(
"--mapping-file", type=Path, default=None, help="Output mapping JSON path"
)
ap.add_argument(
"--log-file",
type=Path,
default=None,
help="Optional path for JSON log output (default: output_dir/generation_log.json)",
)
ap.add_argument(
"--batch-size", type=int, default=32, help="Batch size for generation"
)
ap.add_argument(
"--stitch-sentences",
default=True,
action="store_true",
help="Generate sentence-level stitched images in separate folder",
)
ap.add_argument(
"--segment-gap",
type=int,
default=2,
help="Gap between word segments (split parts) in pixels",
)
ap.add_argument(
"--word-gap",
type=int,
default=20,
help="Gap between words in sentence stitching in pixels",
)
ap.add_argument(
"--baseline-percentile",
type=float,
default=75.0,
help="Percentile for baseline detection (0-100, default: 85.0)",
)
ap.add_argument(
"--allowed-writers",
type=str,
nargs="+",
default=None,
help="List of allowed writer IDs to choose from (e.g., --allowed-writers 0 5 10 25)",
)
args = ap.parse_args()
generate_handwriting(**vars(args))
if __name__ == "__main__":
main()