carminezacc committed on
Commit 23758fa · verified · 1 Parent(s): acb9229

Upload folder using huggingface_hub

Files changed (3)
  1. README.md +131 -100
  2. configuration_eruku.py +52 -0
  3. modeling_eruku.py +418 -0
README.md CHANGED
@@ -1,127 +1,159 @@
  ---
- license: mit
  tags:
  - handwriting-generation
  - text-to-image
  - autoregressive
  - pytorch
- - eruku
- - text-image-generation
- pipeline_tag: image-to-text
  ---

  # Eruku - Autoregressive Styled Text Image Generation

- **Eruku** is a state-of-the-art autoregressive model for styled text image generation, particularly excelling at handwritten text generation (HTG).

- 📄 **Paper**: ["Autoregressive Styled Text Image Generation, but Make it Reliable"](https://arxiv.org/abs/2510.23240)
- 🎮 **Demo**: [HuggingFace Space](https://huggingface.co/spaces/blowing-up-groundhogs/eruku-demo)

- ## Model Description

- Eruku addresses key limitations of previous handwriting generation methods while maintaining their strengths:

- - **No Style Text Required**: Unlike previous methods, Eruku doesn't require transcriptions of style images
- - 🎯 **Reliable Generation**: Proper stop mechanism prevents repetition loops and visual artifacts
- - 🔤 **Special Token Alignment**: Introduces special textual tokens (SOG/EOG) for better alignment between text and visual representations
- - ⚡ **Classifier-Free Guidance**: Implements CFG for improved control over style adherence and text fidelity
- - 📏 **Arbitrary Length**: Can generate text images of any length without architectural constraints

- ## Architecture
-
- The model combines:
-
- - **T5 Transformer**: Autoregressive text encoder for understanding and generation control
- - **VAE (Variational Autoencoder)**: Efficient image tokenizer (from `blowing-up-groundhogs/emuru_vae`)
- - **OrigamiNet OCR**: For auxiliary OCR loss during training
-
- ## Model Files
-
- - `000073688.pth` - Main trained model weights (8.0 GB)
- - `origami.pth` - OCR model checkpoint (OrigamiNet, 41 MB)

- ## Usage

  ```python
  import torch
- from huggingface_hub import hf_hub_download
- from eruku_continuous_inf import Emuru

- # Download checkpoints
- model_checkpoint = hf_hub_download(
-     repo_id="blowing-up-groundhogs/eruku",
-     filename="000073688.pth"
  )

- ocr_checkpoint = hf_hub_download(
-     repo_id="blowing-up-groundhogs/eruku",
-     filename="origami.pth"
- )

- # Initialize model
- model = Emuru(
-     t5_checkpoint='google-t5/t5-base',
-     vae_checkpoint='blowing-up-groundhogs/emuru_vae',
-     ocr_checkpoint=ocr_checkpoint,
-     slices_per_query=1,
-     channels=1
  )

- # Load trained weights
- checkpoint = torch.load(model_checkpoint, map_location='cpu')
- model.load_state_dict(checkpoint, strict=False)
- model.eval()

- # Generate handwriting
- style_text = ""  # Optional
- gen_text = "Hello World!"

- # Prepare inputs
- inputs = model.get_model_inputs(
-     style_img=[torch.ones(1, 1, 64)],  # Minimal style image
-     gen_img=None,
-     style_len=64,
-     gen_len=None,
-     max_img_len=128 * 8
- )

- # Generate
- output_img, _ = model.generate(
-     decoder_inputs_embeds_vae=inputs['decoder_inputs_embeds'],
-     style_text=[style_text],
-     gen_text=[gen_text],
-     cfg_scale=1.5,
-     max_new_tokens=128
- )
- ```

- ## Performance Highlights

- From the paper, Eruku demonstrates:

- - **Superior Text Adherence**: Lower Character Error Rate (CER) compared to previous methods
- - **Better Generalization**: Excellent performance on both handwritten and typewritten styles
- - **Style Consistency**: High-fidelity style replication while maintaining readability
- - **Efficient Training**: Simpler training process without requiring auxiliary networks

- ## Training

- The model was trained in two stages:

- 1. **Stage 1**: Pre-training on large-scale synthetic and real handwriting datasets
- 2. **Stage 2**: Fine-tuning with longer text sequences and dropout strategies

- Training details are available in the paper.

- ## Limitations

- - Best performance on English text (model trained primarily on English datasets)
- - Very long texts (>200 tokens) may require chunking
- - Style transfer quality depends on the style reference provided

- ## Citation

- If you use Eruku in your research, please cite:

  ```bibtex
  @InProceedings{pippi2025zeroshot,
@@ -137,26 +169,25 @@ If you use Eruku in your research, please cite:
  author = {Carmine Zaccagnino and Fabio Quattrini and Vittorio Pippi and Silvia Cascianelli and Alessio Tonioni and Rita Cucchiara},
  title = {Autoregressive Styled Text Image Generation, but Make it Reliable},
  booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision},
- month=3,
- year = 2026
  }
  ```

- ## Authors

- - Carmine Zaccagnino (University of Modena and Reggio Emilia)
- - Fabio Quattrini (University of Modena and Reggio Emilia)
- - Vittorio Pippi (University of Modena and Reggio Emilia)
- - Silvia Cascianelli (University of Modena and Reggio Emilia)
- - Alessio Tonioni (Google)
- - Rita Cucchiara (University of Modena and Reggio Emilia)

- ## License

- MIT License

- ## Related

- - **VAE Model**: [blowing-up-groundhogs/emuru_vae](https://huggingface.co/blowing-up-groundhogs/emuru_vae)
- - **Demo Space**: [HuggingFace Space](https://huggingface.co/spaces/blowing-up-groundhogs/eruku-demo)
- - **Paper**: [arXiv:2510.23240](https://arxiv.org/abs/2510.23240)
  ---
+ license: apache-2.0
  tags:
  - handwriting-generation
+ - styled-text-generation
  - text-to-image
  - autoregressive
+ - vision
+ - transformers
  - pytorch
+ language:
+ - en
+ pipeline_tag: image-to-image
+ library_name: transformers
  ---

  # Eruku - Autoregressive Styled Text Image Generation

+ <p align="center">
+   <img src="https://img.shields.io/badge/CVPR-2025-blue" alt="CVPR 2025">
+   <img src="https://img.shields.io/badge/WACV-2026-green" alt="WACV 2026">
+   <img src="https://img.shields.io/badge/License-Apache%202.0-yellow" alt="License">
+ </p>

+ **Eruku** is a state-of-the-art autoregressive model for styled handwritten and typewritten text image generation. Given a style reference image and text to generate, it produces high-quality text images that faithfully replicate the input style.

+ ## 🌟 Key Features

+ - **Zero-shot style transfer**: No training required for new styles
+ - **No transcription required**: Works with just a style image (transcription optional but helps)
+ - **Reliable generation**: Proper EOG (End of Generation) mechanism prevents artifacts
+ - **Arbitrary length**: Generate text of any length
+ - **High fidelity**: Excellent style consistency and text readability
+ - **Classifier-Free Guidance**: Fine control over generation quality
+
+ ## 📦 Installation
+
+ ```bash
+ pip install torch torchvision transformers diffusers einops pillow
+ ```
+
+ ## 🚀 Quick Start
+
  ```python
+ from transformers import AutoModel
+ from PIL import Image
  import torch
+
+ # Load model
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = AutoModel.from_pretrained(
+     "blowing-up-groundhogs/eruku",
+     trust_remote_code=True
  )
+ model.to(device)
+ model.eval()
+
+ # Load a style image (handwritten/typewritten text sample)
+ style_image = Image.open("style_sample.png")
+
+ # Generate text in that style
+ result = model.generate_handwriting(
+     style_image=style_image,
+     gen_text="Hello, World!",
+     style_text="",   # Optional: transcription of style image
+     cfg_scale=1.25,  # Classifier-free guidance scale
  )
+
+ # Save the result
+ result.save("generated.png")
+ ```
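+
+ `generate_handwriting` returns only the newly generated portion: internally it crops away the style-image region and trims surrounding whitespace before converting the result to a PIL image (see `modeling_eruku.py` below).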
 
+ ## 📖 Detailed Usage
+
+ ### Input Format
+
+ The model takes three inputs:
+
+ 1. **Style Image** (`style_image`): A PIL Image containing handwritten or typewritten text that serves as the style reference. The model will replicate this style.
+
+ 2. **Generation Text** (`gen_text`): The text you want to render in the extracted style.
+
+ 3. **Style Text** (`style_text`, optional): The transcription of the text in the style image. Providing this helps the model better understand the style, but it's not required (see the sketch below).
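+
+ For instance, a minimal sketch passing all three inputs (file names and texts are placeholders; `model` as loaded in the Quick Start):
+
+ ```python
+ from PIL import Image
+
+ style = Image.open("note_sample.png")       # style reference whose text is known
+ result = model.generate_handwriting(
+     style_image=style,
+     gen_text="meet me at noon",             # text to render in the reference style
+     style_text="Dear John, thank you for",  # transcription of the style image
+ )
+ result.save("styled_note.png")
+ ```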
+
+ ### Parameters
+
+ | Parameter | Type | Default | Description |
+ |-----------|------|---------|-------------|
+ | `style_image` | PIL.Image | Required | Reference style image |
+ | `gen_text` | str | Required | Text to generate |
+ | `style_text` | str | `""` | Optional transcription of style image |
+ | `cfg_scale` | float | `1.25` | Classifier-free guidance scale |
+ | `max_new_tokens` | int | `512` | Maximum generation tokens |
+
+ ### CFG Scale Guide
+
+ - `1.0`: No guidance (faster but may drift from prompt)
+ - `1.25`: Recommended default - good balance
+ - `1.5-2.0`: Stronger adherence to prompt
+ - `>2.0`: May cause artifacts
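+
+ To choose a value empirically, a quick sweep can help (a sketch reusing `model` and `style_image` from the Quick Start; file names are illustrative):
+
+ ```python
+ # Render the same text at several guidance scales and compare the outputs
+ for scale in (1.0, 1.25, 1.5, 2.0):
+     img = model.generate_handwriting(
+         style_image=style_image,
+         gen_text="The quick brown fox",
+         cfg_scale=scale,
+     )
+     img.save(f"sample_cfg_{scale}.png")
+ ```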
+
+ ## 🖼️ Example Results
+
+ The model excels at:
+ - Handwritten text in various styles (cursive, print, mixed)
+ - Typewritten text with different fonts
+ - Multi-language text (trained primarily on English)
+ - Long text sequences
+
+ ## 📊 Model Architecture
+
+ Eruku combines:
+ - **T5-Large encoder-decoder** for text understanding and autoregressive generation
+ - **VAE (Variational Autoencoder)** for image encoding and decoding
+ - **Custom embeddings** for style transfer and special tokens (SOS, SOG, EOG)
+
+ The model generates images autoregressively, predicting one latent slice at a time until it produces an EOG (End of Generation) token, as sketched below.
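+
+ A simplified sketch of that loop, condensed from `generate()` in `modeling_eruku.py` below (helper names here are illustrative, not the actual API):
+
+ ```python
+ # The T5 decoder emits one hidden state per step; a small linear head
+ # classifies it as SOG (0), EOG (1), or a regular image slice (2).
+ z_sequence = [style_latents]                  # VAE latents of the style image
+ for _ in range(max_new_tokens):
+     hidden = t5_decode(text, z_sequence)      # last decoder hidden state
+     token = special_head(hidden).argmax(-1)   # 0 = SOG, 1 = EOG, 2 = IMG
+     z_sequence.append(to_vae_latent(hidden))  # project back to latent space
+     if token == 1:                            # EOG: generation is complete
+         break
+ image = vae.decode(concat(z_sequence))        # decode all latents to pixels
+ ```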
+
+ ## 🔧 Advanced Usage
+
+ ### Lower-level API
+
+ For more control, you can use the lower-level methods (`model` and `device` as in the Quick Start):
+
+ ```python
+ import torch
+ from PIL import Image
+ from torchvision import transforms as T
+
+ # Prepare style image manually: resize to height 64, keeping the aspect ratio
+ style_img = Image.open("style.png").convert('RGB')
+ width, height = style_img.size
+ new_width = int(64 * width / height)
+ style_img = style_img.resize((new_width, 64), Image.LANCZOS)
+ style_tensor = T.ToTensor()(style_img).to(device)
+
+ # Get model inputs (VAE embeddings of the style image)
+ inputs = model.get_model_inputs(
+     style_img=[style_tensor],
+     style_len=style_tensor.shape[-1],
+     max_img_len=1024 * 1024
+ )
+
+ # Generate with full control
+ with torch.inference_mode():
+     output_img, special_sequence = model.generate(
+         decoder_inputs_embeds_vae=inputs['decoder_inputs_embeds'],
+         style_text=["Style text here"],
+         gen_text=["Text to generate"],
+         cfg_scale=1.25,
+         max_new_tokens=512
+     )
+ ```
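+
+ Here `output_img` is an image tensor in `[-1, 1]` that still includes the style region, and `special_sequence` records the predicted class of each latent slice; the high-level `generate_handwriting` wrapper performs the cropping, whitespace trimming, and PIL conversion for you.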
+
+ ## 📚 Citation
+
+ If you use Eruku in your research, please cite both papers:

  ```bibtex
  @InProceedings{pippi2025zeroshot,
@@ -137,26 +169,25 @@
  author = {Carmine Zaccagnino and Fabio Quattrini and Vittorio Pippi and Silvia Cascianelli and Alessio Tonioni and Rita Cucchiara},
  title = {Autoregressive Styled Text Image Generation, but Make it Reliable},
  booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision},
+ month = {March},
+ year = {2026}
  }
  ```

+ ## 🔗 Links
+
+ - 📄 **Paper**: [arXiv:2510.23240](https://arxiv.org/abs/2510.23240)
+ - 🌐 **Project Website**: [eruku.carminezacc.com](https://eruku.carminezacc.com)
+ - 🤗 **Demo**: [Hugging Face Space](https://huggingface.co/spaces/carminezacc/eruku)
+ - 🎨 **VAE Model**: [blowing-up-groundhogs/emuru_vae](https://huggingface.co/blowing-up-groundhogs/emuru_vae)
+
+ ## 📜 License
+
+ This model is released under the Apache 2.0 License.
+
+ ## 🙏 Acknowledgments
+
+ - T5: google-t5/t5-large
+ - VAE: blowing-up-groundhogs/emuru_vae
+ - Training datasets: IAM, CVL, RIMES, FontSquare
configuration_eruku.py ADDED
@@ -0,0 +1,52 @@
+ """
+ Eruku Configuration
+
+ Configuration class for the Eruku autoregressive styled text image generation model.
+ """
+
+ from transformers import PretrainedConfig
+
+
+ class ErukuConfig(PretrainedConfig):
+     """
+     Configuration class for Eruku model.
+
+     Args:
+         t5_name_or_path (`str`, *optional*, defaults to `"google-t5/t5-large"`):
+             The name or path of the T5 model to use as the backbone.
+         vae_name_or_path (`str`, *optional*, defaults to `"blowing-up-groundhogs/emuru_vae"`):
+             The name or path of the VAE model for image encoding/decoding.
+         tokenizer_name_or_path (`str`, *optional*, defaults to `"google/byt5-small"`):
+             The name or path of the tokenizer (character-level).
+         slices_per_query (`int`, *optional*, defaults to 1):
+             Number of VAE latent slices per query token.
+         channels (`int`, *optional*, defaults to 1):
+             Number of channels in the VAE latent space.
+         vae_latent_dim (`int`, *optional*, defaults to 8):
+             Dimension of the VAE latent space.
+         cfg_scale (`float`, *optional*, defaults to 1.25):
+             Default classifier-free guidance scale for generation.
+     """
+
+     model_type = "eruku"
+
+     def __init__(
+         self,
+         t5_name_or_path: str = "google-t5/t5-large",
+         vae_name_or_path: str = "blowing-up-groundhogs/emuru_vae",
+         tokenizer_name_or_path: str = "google/byt5-small",
+         slices_per_query: int = 1,
+         channels: int = 1,
+         vae_latent_dim: int = 8,
+         cfg_scale: float = 1.25,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.t5_name_or_path = t5_name_or_path
+         self.vae_name_or_path = vae_name_or_path
+         self.tokenizer_name_or_path = tokenizer_name_or_path
+         self.slices_per_query = slices_per_query
+         self.channels = channels
+         self.vae_latent_dim = vae_latent_dim
+         self.cfg_scale = cfg_scale
+
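
A quick sketch of instantiating this configuration (standard `PretrainedConfig` behavior; the local import assumes the file is on your `PYTHONPATH`):

```python
from configuration_eruku import ErukuConfig

# Defaults mirror the signature above
config = ErukuConfig()
assert config.t5_name_or_path == "google-t5/t5-large"
assert config.cfg_scale == 1.25

# Fields can be overridden and round-tripped to disk (inherited from PretrainedConfig)
config = ErukuConfig(slices_per_query=2)
config.save_pretrained("./eruku-config")
config = ErukuConfig.from_pretrained("./eruku-config")
```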
modeling_eruku.py ADDED
@@ -0,0 +1,418 @@
+ """
+ Eruku Model - Autoregressive Styled Text Image Generation
+
+ This module implements the Eruku model for autoregressive styled text image generation.
+ Based on the papers:
+ - "Zero-Shot Styled Text Image Generation, but Make It Autoregressive" (CVPR 2025)
+ - "Autoregressive Styled Text Image Generation, but Make it Reliable" (WACV 2026)
+ """
+
+ import torch
+ import torch.nn as nn
+ from typing import Optional, Tuple, List, Union
+ from transformers import PreTrainedModel, T5ForConditionalGeneration, T5Config, AutoTokenizer
+ from diffusers import AutoencoderKL
+ from einops import rearrange, repeat
+ from torch.nn.utils.rnn import pad_sequence
+ from torchvision.transforms import Normalize
+ from PIL import Image
+ import numpy as np
+
+ from .configuration_eruku import ErukuConfig
+
+
+ # Number of special tokens: SOG, EOG, IMG
+ SPECIAL_TOKEN_COUNT = 3
+
+
+ def pad_images(images: List[torch.Tensor], padding_value: float = 1.0) -> torch.Tensor:
+     """Pad a list of images to the same width."""
+     images = [rearrange(img, 'c h w -> w c h') for img in images]
+     padded = rearrange(pad_sequence(images, padding_value=padding_value), 'w b c h -> b c h w')
+     return padded.contiguous()
+
+
+ class ErukuPreTrainedModel(PreTrainedModel):
+     """
+     Base class for Eruku models.
+     """
+     config_class = ErukuConfig
+     base_model_prefix = "eruku"
+     supports_gradient_checkpointing = True
+
+     def _init_weights(self, module):
+         """Initialize weights - handled by sub-components."""
+         pass
+
+
+ class ErukuForConditionalGeneration(ErukuPreTrainedModel):
+     """
+     Eruku model for conditional styled text image generation.
+
+     The model takes a style image (handwritten/typewritten text sample),
+     optional style text (transcription of the style image), and generation
+     text (text to render), and produces an image of the generation text
+     in the style of the reference image.
+
+     Example usage:
+     ```python
+     from transformers import AutoModel
+     from PIL import Image
+     import torch
+
+     # Load model
+     model = AutoModel.from_pretrained("blowing-up-groundhogs/eruku", trust_remote_code=True)
+     model.eval()
+
+     # Generate handwriting
+     result = model.generate_handwriting(
+         style_image=Image.open("style.png"),
+         style_text="Hello",  # optional - text in style image
+         gen_text="World",    # text to generate
+     )
+     result.save("output.png")
+     ```
+     """
+
+     def __init__(self, config: ErukuConfig):
+         super().__init__(config)
+         self.config = config
+
+         # Character-level tokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path)
+         self.tokenizer.add_tokens(["<sog>"])
+
+         # T5 backbone
+         t5_config = T5Config.from_pretrained(config.t5_name_or_path)
+         t5_config.vocab_size = len(self.tokenizer)
+         self.T5 = T5ForConditionalGeneration(t5_config)
+         self.T5.lm_head = nn.Identity()
+
+         # Image normalization
+         self.normalize = Normalize(0.5, 0.5)
+
+         # Special token embeddings
+         self.sos = nn.Embedding(1, t5_config.d_model)
+         self.sog = nn.Embedding(1, t5_config.d_model)
+         self.eog = nn.Embedding(1, t5_config.d_model)
+
+         # VAE for image encoding/decoding
+         self.vae = AutoencoderKL.from_pretrained(config.vae_name_or_path)
+         self._freeze_module(self.vae)
+
+         # Projection layers
+         vae_dim = config.vae_latent_dim * config.channels * config.slices_per_query
+         self.query_emb = nn.Linear(vae_dim, t5_config.d_model)
+         self.t5_to_vae = nn.Linear(t5_config.d_model, vae_dim)
+         self.t5_to_special = nn.Linear(t5_config.d_model, SPECIAL_TOKEN_COUNT)
+
+         # Unconditional embedding for CFG
+         self.uncond_embedding = nn.Embedding(1, t5_config.d_model)
+
+         # CFG configuration
+         self.drop_text = False
+         self.drop_img = False
+
+         # Reassemble a sequence of latent slices into a VAE latent image
+         self.z_rearrange = lambda x: rearrange(x, 'b w (q c h) -> b c h (w q)',
+                                                c=config.channels, q=config.slices_per_query)
+
+         self.post_init()
+
+     def _freeze_module(self, module: nn.Module):
+         """Freeze all parameters in a module."""
+         module.eval()
+         for param in module.parameters():
+             param.requires_grad = False
+
+     def _img_encode(self, img: torch.Tensor) -> torch.Tensor:
+         """Encode image to VAE latent space."""
+         img = self.normalize(img)
+         img = img.contiguous()
+         return self.vae.encode(img.float()).latent_dist.sample()
+
+     @torch.no_grad()
+     def get_model_inputs(
+         self,
+         style_img: List[torch.Tensor],
+         style_len: Union[int, List[int]],
+         max_img_len: int = 1024 * 1024
+     ) -> dict:
+         """
+         Prepare model inputs from style images.
+
+         Args:
+             style_img: List of style image tensors [C, H, W]
+             style_len: Width(s) of style images
+             max_img_len: Maximum image length in pixels
+
+         Returns:
+             Dictionary with decoder_inputs_embeds
+         """
+         bs = len(style_img)
+         decoder_inputs_embeds_list = []
+
+         # Pad images to same width
+         style_img_padded = pad_images([el.to(self.T5.device) for el in style_img])
+         style_img_embeds = self._img_encode(style_img_padded)
+
+         for el in range(bs):
+             if isinstance(style_len, int):
+                 sl = style_len
+             else:
+                 sl = int(style_len[el])
+
+             # Ensure width is within bounds
+             sl = max(64, min(sl, style_img_embeds.shape[-1] * 8))
+
+             # Style image embeddings + SOG marker
+             sample_embeds = torch.cat([
+                 style_img_embeds[el, :, :, :sl // 8],
+                 torch.ones(1, 8, 1).to(self.T5.device),  # SOG placeholder
+             ], dim=-1)
+
+             sample_embeds = rearrange(sample_embeds, 'c h w -> w (h c)', h=8, c=1)
+             decoder_inputs_embeds_list.append(sample_embeds)
+
+         decoder_inputs_embeds = pad_sequence(
+             decoder_inputs_embeds_list,
+             padding_value=1,
+             batch_first=True
+         )[:, :max_img_len // 8]
+
+         return {'decoder_inputs_embeds': decoder_inputs_embeds}
+
+     @torch.inference_mode()
+     def generate(
+         self,
+         decoder_inputs_embeds_vae: torch.Tensor,
+         style_text: List[str],
+         gen_text: List[str],
+         cfg_scale: float = 1.25,
+         max_new_tokens: int = 512
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Generate styled text image autoregressively.
+
+         Args:
+             decoder_inputs_embeds_vae: VAE embeddings of style image
+             style_text: List of style text strings (can be empty)
+             gen_text: List of generation text strings
+             cfg_scale: Classifier-free guidance scale (1.25 recommended)
+             max_new_tokens: Maximum tokens to generate
+
+         Returns:
+             Tuple of (generated_image, special_sequence)
+         """
+         # Encode text
+         encoded_text = self.tokenizer(
+             [f"{style}<sog>{gen}" for style, gen in zip(style_text, gen_text)],
+             padding=True,
+             return_tensors="pt"
+         )
+         text_input_ids = encoded_text['input_ids'].to(self.T5.device)
+         text_mask = encoded_text['attention_mask'].to(self.T5.device)
+
+         # Initialize generation
+         sog = repeat(self.sog.weight, '1 d -> b 1 d', b=1)
+         sos = repeat(self.sos.weight, '1 d -> b 1 d', b=1)
+
+         z_sequence = [decoder_inputs_embeds_vae]
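+         # Track the predicted class of each latent slice in the output sequence:
+         # 3 = style-image slice, 0 = SOG, 1 = EOG, 2 = generated image slice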
+         special_sequence = torch.ones(decoder_inputs_embeds_vae.size(1)) * 3
+
+         # Build initial decoder inputs
+         decoder_inputs_embeds = self.query_emb(torch.cat(z_sequence, dim=1))
+         if len(style_text[0]) != 0:
+             decoder_inputs_embeds = torch.cat([sos, decoder_inputs_embeds], dim=1)
+         else:
+             decoder_inputs_embeds = torch.cat([sos, decoder_inputs_embeds, sog], dim=1)
+             vae_latent = self.t5_to_vae(sog)
+             special_sequence = torch.cat([special_sequence, torch.zeros(1)])
+             z_sequence.append(vae_latent)
+
+         # Autoregressive generation
+         for i in range(max_new_tokens):
+             if cfg_scale != 1.0:
+                 # Classifier-free guidance
+                 conditional_text_embeds = self.T5.shared(text_input_ids)
+                 if self.drop_text:
+                     unconditional_text_embeds = self.uncond_embedding.weight.expand_as(conditional_text_embeds)
+                 else:
+                     unconditional_text_embeds = conditional_text_embeds
+
+                 if self.drop_img:
+                     unconditional_decoder_embeds = self.uncond_embedding.weight.expand_as(decoder_inputs_embeds)
+                 else:
+                     unconditional_decoder_embeds = decoder_inputs_embeds
+
+                 output_uncond = self.T5(
+                     inputs_embeds=unconditional_text_embeds,
+                     attention_mask=text_mask,
+                     decoder_inputs_embeds=unconditional_decoder_embeds
+                 ).logits[:, -1:]
+
+                 output_cond = self.T5(
+                     input_ids=text_input_ids,
+                     attention_mask=text_mask,
+                     decoder_inputs_embeds=decoder_inputs_embeds
+                 ).logits[:, -1:]
+
+                 output = output_uncond + (output_cond - output_uncond) * cfg_scale
+             else:
+                 output = self.T5(
+                     input_ids=text_input_ids,
+                     attention_mask=text_mask,
+                     decoder_inputs_embeds=decoder_inputs_embeds
+                 ).logits[:, -1:]
+
+             # Predict special token
+             special_prediction = self.t5_to_special(output)
+             predicted_special = torch.argmax(special_prediction, dim=-1).item()
+
+             if predicted_special == 0:  # SOG
+                 decoder_inputs_embeds = torch.cat([decoder_inputs_embeds, sog], dim=1)
+                 vae_latent = self.t5_to_vae(output)
+                 special_sequence = torch.cat([special_sequence, torch.zeros(1)])
+             elif predicted_special == 1:  # EOG - stop generation
+                 special_sequence = torch.cat([special_sequence, torch.ones(1)])
+                 vae_latent = self.t5_to_vae(output)
+                 z_sequence.append(vae_latent)
+                 break
+             else:  # IMG token
+                 vae_latent = self.t5_to_vae(output)
+                 decoder_inputs_embeds = torch.cat([
+                     decoder_inputs_embeds,
+                     self.query_emb(vae_latent)
+                 ], dim=1)
+                 special_sequence = torch.cat([special_sequence, torch.ones(1) * 2])
+
+             z_sequence.append(vae_latent)
+
+         # Decode to image
+         z_sequence = [el.to(self.vae.device) for el in z_sequence]
+         z_sequence = torch.cat(z_sequence, dim=1)
+         z_sequence = self.z_rearrange(z_sequence)
+         img = torch.clamp(self.vae.decode(z_sequence).sample, -1, 1)
+
+         return img, special_sequence.to(self.T5.device)
+
+     def generate_handwriting(
+         self,
+         style_image: Image.Image,
+         gen_text: str,
+         style_text: str = "",
+         cfg_scale: float = 1.25,
+         max_new_tokens: int = 512,
+         device: Optional[str] = None
+     ) -> Image.Image:
+         """
+         High-level API for generating handwriting.
+
+         This is the recommended entry point for inference.
+
+         Args:
+             style_image: PIL Image containing handwriting style reference
+             gen_text: Text to generate in the style
+             style_text: Optional transcription of text in style_image
+             cfg_scale: Classifier-free guidance scale (default: 1.25)
+             max_new_tokens: Maximum generation length
+             device: Device to use (auto-detected if None)
+
+         Returns:
+             PIL Image of generated handwriting
+         """
+         import torchvision.transforms as T
+
+         if device is None:
+             device = next(self.parameters()).device
+
+         # Preprocess style image
+         style_img = style_image.convert('RGB')
+
+         # Resize to height 64 maintaining aspect ratio
+         width, height = style_img.size
+         new_width = int(64 * width / height)
+         style_img = style_img.resize((new_width, 64), Image.LANCZOS)
+
+         # Convert to tensor
+         style_tensor = T.ToTensor()(style_img).to(device)
+         style_len = style_tensor.shape[-1]
+
+         # Get model inputs
+         inputs = self.get_model_inputs(
+             style_img=[style_tensor],
+             style_len=style_len,
+             max_img_len=1024 * 1024
+         )
+
+         # Generate
+         output_img, _ = self.generate(
+             decoder_inputs_embeds_vae=inputs['decoder_inputs_embeds'],
+             style_text=[style_text],
+             gen_text=[gen_text],
+             cfg_scale=cfg_scale,
+             max_new_tokens=max_new_tokens
+         )
+
+         # Crop out the style image part (keep only generated portion)
+         style_width_latent = style_len // 8 + 1  # +1 for SOG token
+         output_img = output_img[:, :, :, style_width_latent * 8:]
+
+         # Trim whitespace
+         output_img = self._trim_white(output_img)
+
+         # Convert to PIL
+         output_img = (torch.clamp(output_img, -1, 1) + 1) * 127.5
+         output_img = output_img.byte().squeeze().cpu().numpy()
+
+         if len(output_img.shape) == 2:
+             return Image.fromarray(output_img, mode='L')
+         elif output_img.shape[0] == 3:
+             output_img = np.transpose(output_img, (1, 2, 0))
+             return Image.fromarray(output_img, mode='RGB')
+         else:
+             return Image.fromarray(output_img[0], mode='L')
+
+     @staticmethod
+     def _trim_white(img: torch.Tensor, threshold: float = 0.9, padding: int = 8) -> torch.Tensor:
+         """Trim white margins from generated image."""
+         start_idx, end_idx = 0, img.size(-1)
+         vertical_min = img[0, 0].min(-2).values.tolist()
+
+         # Skip initial non-white columns
+         for v in vertical_min:
+             if v >= threshold:
+                 break
+             start_idx += 1
+
+         # Skip initial white columns, continuing from where the previous scan stopped
+         for v in vertical_min[start_idx:]:
+             if v < threshold:
+                 break
+             start_idx += 1
+
+         # Skip trailing white columns
+         for v in vertical_min[::-1]:
+             if v < threshold:
+                 break
+             end_idx -= 1
+
+         start_idx = max(start_idx - padding, 0)
+         end_idx = min(end_idx + padding, img.size(-1))
+
+         if start_idx >= end_idx:
+             return img
+
+         return img[..., start_idx:end_idx]
+
+     def forward(self, **kwargs):
+         """Forward pass - mainly for training compatibility."""
+         raise NotImplementedError(
+             "Direct forward() is not supported. Use generate_handwriting() for inference."
+         )
+
+
+ # Register for AutoModel
+ ErukuConfig.register_for_auto_class()
+ ErukuForConditionalGeneration.register_for_auto_class("AutoModel")
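
The two `register_for_auto_class` calls are what make this custom architecture loadable through the Auto classes with `trust_remote_code=True`, for example:

```python
from transformers import AutoConfig, AutoModel

# Both resolve to the custom classes defined in this repository
config = AutoConfig.from_pretrained("blowing-up-groundhogs/eruku", trust_remote_code=True)
model = AutoModel.from_pretrained("blowing-up-groundhogs/eruku", trust_remote_code=True)
print(type(model).__name__)  # ErukuForConditionalGeneration
```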