File size: 14,736 Bytes
9e554f6
 
 
 
 
 
bf59780
 
 
 
9e554f6
 
bf59780
 
9e554f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b112c43
 
9e554f6
 
 
b112c43
9e554f6
 
b112c43
9e554f6
 
 
 
 
 
bf59780
 
9e554f6
 
 
 
 
 
 
 
 
4336dcc
33451d1
9e554f6
 
 
4336dcc
9e554f6
 
 
 
4336dcc
9e554f6
 
 
 
 
 
 
 
 
 
 
 
 
 
33451d1
9e554f6
 
 
 
 
 
4336dcc
9e554f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33451d1
9e554f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33451d1
b85941b
9e554f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
"""
DiffusionPen: Hindi Handwriting Generation Demo
Inference-focused Gradio application with CANINE text encoding
"""

import os
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import CanineTokenizer, CanineModel

from unet import UNetModel


class DiffusionPenDemo:
    """
    Hindi Handwriting Generation Demo using DiffusionPen UNet

    Features:
    - CANINE text encoder for character-level Hindi encoding
    - 339 different writer styles
    - Configurable diffusion steps and guidance
    - GPU/CPU automatic detection
    - Checkpoint loading support

    Attributes:
        device: Device passed by the caller (str or torch.device), or
            'cuda'/'cpu' auto-detected when none is given.
        checkpoint_path: Optional path to saved UNet weights.
        model: The UNetModel (None until load_models() succeeds).
        text_encoder: CANINE model handed to the UNet.
        tokenizer: CANINE tokenizer used by encode_text().
        checkpoint_loaded: True only when checkpoint weights were applied.
    """

    def __init__(self, checkpoint_path=None, device=None):
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.checkpoint_path = checkpoint_path
        self.model = None
        self.text_encoder = None
        self.tokenizer = None
        self.checkpoint_loaded = False
        self.load_models()

    def load_models(self):
        """Load the CANINE text encoder, build the UNet, and optionally load weights.

        Raises:
            Exception: any error while downloading/constructing models is
                printed and re-raised so the app fails fast at startup.
        """
        try:
            print(f"\n{'='*60}")
            print("🔧 DiffusionPen Initialization")
            print(f"{'='*60}")
            # str() so callers may pass a torch.device as well as a plain string.
            print(f"📱 Device: {str(self.device).upper()}")

            # Load CANINE text encoder
            print("\n📝 Loading CANINE text encoder...")
            self.tokenizer = CanineTokenizer.from_pretrained('google/canine-s')
            self.text_encoder = CanineModel.from_pretrained('google/canine-s').to(self.device)
            self.text_encoder.eval()
            print("   ✓ CANINE loaded (768-dim embeddings)")

            # Initialize UNet model
            print("\n🧠 Initializing UNet model...")

            # Minimal stand-in for the training-script args object that
            # UNetModel receives; presumably only these attributes are read
            # at inference time — confirm against unet.py.
            class Args:
                interpolation = False
                mix_rate = 0.5

            self.model = UNetModel(
                image_size=64,
                in_channels=1,
                model_channels=128,
                out_channels=1,
                num_res_blocks=2,
                # NOTE(review): if UNetModel follows guided-diffusion and treats
                # these as downsample factors, channel_mult=(1, 2, 4) only reaches
                # factors 1, 2 and 4, so 16/8 would add no attention blocks.
                # Confirm in unet.py; changing it breaks checkpoint compatibility.
                attention_resolutions=[16, 8],
                dropout=0.1,
                channel_mult=(1, 2, 4),
                dims=2,
                num_classes=339,  # Hindi writer styles
                use_checkpoint=True,
                num_heads=8,
                num_head_channels=-1,
                use_scale_shift_norm=True,
                resblock_updown=False,
                use_spatial_transformer=True,
                transformer_depth=1,
                context_dim=768,
                text_encoder=self.text_encoder,
                args=Args()
            ).to(self.device)
            self.model.eval()

            # Count parameters
            total_params = sum(p.numel() for p in self.model.parameters())
            print(f"   ✓ UNet initialized ({total_params/1e6:.1f}M parameters)")

            # Load checkpoint if available
            if self.checkpoint_path and Path(self.checkpoint_path).exists():
                self._load_checkpoint()
            elif self.checkpoint_path:
                print(f"\n⚠️  No checkpoint found at: {self.checkpoint_path}")
                print("   Using random initialization")
            else:
                print("\n⚠️  No checkpoint path configured")
                print("   Using random initialization")

            print(f"\n{'='*60}")
            print("✅ Ready for inference!")
            print(f"{'='*60}\n")

        except Exception as e:
            print(f"\n❌ Error during initialization: {str(e)}")
            raise

    def _load_checkpoint(self):
        """Load UNet weights from ``self.checkpoint_path`` (best effort).

        Accepted formats: a dict wrapping the weights under ``model_state_dict``
        or ``state_dict``, a raw state dict, or a pickled ``nn.Module``. A
        ``module.`` prefix on every key (left by DataParallel/DDP training) is
        stripped. Any failure, including a checkpoint whose keys match none of
        the model's, is printed and leaves ``checkpoint_loaded`` False instead
        of raising, so the demo still starts with random weights.
        """
        try:
            print(f"\n📂 Loading checkpoint: {self.checkpoint_path}")
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(self.checkpoint_path, map_location=self.device)

            # Handle different checkpoint formats
            if isinstance(checkpoint, dict):
                if 'model_state_dict' in checkpoint:
                    state_dict = checkpoint['model_state_dict']
                    print("   Format: Standard (model_state_dict)")
                elif 'state_dict' in checkpoint:
                    state_dict = checkpoint['state_dict']
                    print("   Format: Alternative (state_dict)")
                else:
                    state_dict = checkpoint
                    print("   Format: Raw state dict")
            elif isinstance(checkpoint, torch.nn.Module):
                state_dict = checkpoint.state_dict()
                print("   Format: Pickled nn.Module")
            else:
                # load_state_dict needs a mapping; anything else cannot load.
                raise TypeError(
                    f"unsupported checkpoint type: {type(checkpoint).__name__}"
                )

            # DataParallel/DDP wrappers prefix every key with 'module.'; with
            # strict=False those keys would otherwise be silently ignored.
            prefix = 'module.'
            if state_dict and all(k.startswith(prefix) for k in state_dict):
                state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
                print("   Stripped 'module.' key prefix")

            # Load state dict with strict=False to handle minor mismatches
            missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)

            if missing_keys:
                print(f"   ⚠️  Missing keys: {len(missing_keys)}")
            if unexpected_keys:
                print(f"   ⚠️  Unexpected keys: {len(unexpected_keys)}")

            # strict=False never fails on key mismatches, so verify that at
            # least one model tensor actually received checkpoint weights.
            matched = len(self.model.state_dict()) - len(missing_keys)
            if matched <= 0:
                raise RuntimeError("no checkpoint keys matched the model")

            self.checkpoint_loaded = True
            print("   ✓ Checkpoint loaded successfully")

        except Exception as e:
            print(f"   ❌ Failed to load checkpoint: {str(e)}")
            self.checkpoint_loaded = False

    def encode_text(self, text):
        """Tokenize Hindi text with the CANINE tokenizer.

        Returns:
            dict of tokenizer tensors (moved to ``self.device``) that is passed
            to the UNet as ``context``, or None if tokenization fails.
        """
        try:
            # CANINE handles character-level encoding natively
            inputs = self.tokenizer(
                text,
                return_tensors='pt',
                padding=True,
                truncation=True,
                max_length=512
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            return inputs
        except Exception as e:
            print(f"❌ Text encoding error: {e}")
            return None

    @torch.no_grad()
    def generate(self, text, writer_id=0, num_steps=50, guidance_scale=7.5, seed=None):
        """
        Generate Hindi handwriting from text

        Args:
            text: Hindi text in Devanagari script
            writer_id: Writer style ID (clamped to 0-338)
            num_steps: Number of diffusion steps (clamped to 10-100)
            guidance_scale: Step-size multiplier (clamped to 1.0-15.0)
            seed: Optional int; when given, the initial noise is reproducible.

        Returns:
            Tuple[PIL.Image | None, str]: Generated image (None on failure)
            and a status message.
        """
        if self.model is None:
            return None, "❌ Model not initialized"

        try:
            # Input validation; None (e.g. from an API call) counts as empty.
            if not text or not text.strip():
                return None, "⚠️  Please enter Hindi text"

            writer_id = max(0, min(int(writer_id), 338))
            num_steps = max(10, min(int(num_steps), 100))
            guidance_scale = max(1.0, min(float(guidance_scale), 15.0))

            print("\n🎨 Generating handwriting...")
            print(f"   Text: '{text}'")
            print(f"   Writer: {writer_id}/338")
            print(f"   Steps: {num_steps}")
            print(f"   Guidance: {guidance_scale}")

            # Encode text with CANINE
            context = self.encode_text(text)
            if context is None:
                return None, "❌ Text encoding failed"

            batch_size = 1

            # Initialize from noise (seeded generator only when requested).
            generator = None
            if seed is not None:
                generator = torch.Generator(device=self.device).manual_seed(int(seed))
            x = torch.randn(batch_size, 1, 64, 64, device=self.device, generator=generator)

            # Writer conditioning is constant across steps, so build it once.
            y = torch.full((batch_size,), writer_id, dtype=torch.long, device=self.device)

            # NOTE(review): this update rule is not a standard DDPM/DDIM step.
            # Timesteps span only [0, num_steps), and guidance_scale merely
            # scales the step size; it is not classifier-free guidance.
            # Replace it with the noise schedule used at training time.
            for step in range(num_steps - 1, -1, -1):
                t = torch.full((batch_size,), step, dtype=torch.long, device=self.device)

                noise_pred = self.model(
                    x,
                    timesteps=t,
                    context=context,
                    y=y
                )

                # Denoising step with adaptive scaling
                alpha_t = 1.0 - (step / num_steps)
                scale = guidance_scale * alpha_t
                x = x - 0.01 * scale * noise_pred

                # Progress indicator (roughly every 20%)
                if (num_steps - step) % max(1, num_steps // 5) == 0:
                    progress = ((num_steps - step) / num_steps) * 100
                    print(f"   Progress: {progress:.0f}%")

            # Post-processing: [-1, 1] -> [0, 1] -> uint8 grayscale
            x = torch.clamp(x, -1, 1)
            x = (x + 1) / 2
            img_array = (x[0, 0].cpu().numpy() * 255).astype(np.uint8)

            # A 2-D uint8 array is inferred as mode 'L'; the mode= argument
            # of Image.fromarray is deprecated in recent Pillow releases.
            img = Image.fromarray(img_array)

            status = f"✅ Generated with writer {writer_id}, {num_steps} steps"
            print(f"   {status}\n")
            return img, status

        except Exception as e:
            error_msg = f"❌ Generation error: {str(e)}"
            print(f"   {error_msg}")
            return None, error_msg


# ==============================================================================
# CONFIGURATION
# ==============================================================================

# Path to your trained checkpoint. Set the DIFFUSIONPEN_CHECKPOINT environment
# variable to point elsewhere without editing this file.
CHECKPOINT_PATH = os.environ.get("DIFFUSIONPEN_CHECKPOINT", "./checkpoints/model.pt")

# Initialize demo at import time so the Gradio callbacks below can use it.
print("\n🚀 Initializing DiffusionPen...")
demo_instance = DiffusionPenDemo(
    checkpoint_path=CHECKPOINT_PATH,
    device=None  # Auto-detect GPU/CPU
)


def gradio_generate(text, writer_id, num_steps, guidance_scale):
    """Gradio callback: forward the UI values to the shared demo instance.

    Returns:
        (image, status message) pair exactly as produced by
        ``demo_instance.generate``.
    """
    settings = dict(
        text=text,
        writer_id=writer_id,
        num_steps=num_steps,
        guidance_scale=guidance_scale,
    )
    generated_image, status_message = demo_instance.generate(**settings)
    return generated_image, status_message


# ==============================================================================
# GRADIO INTERFACE
# ==============================================================================

theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="amber",
)

# Two-column layout: inputs on the left, generated image + status on the right.
with gr.Blocks(title="DiffusionPen - Hindi Handwriting Generation", theme=theme) as demo:
    
    # Header
    gr.Markdown("""
    # 🎨 DiffusionPen: Hindi Handwriting Generation
    
    Generate authentic Hindi handwriting using diffusion models with CANINE text encoding.
    """)
    
    # Main content
    with gr.Row():
        # Input panel
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### ✏️ Input Settings")
            
            text_input = gr.Textbox(
                label="Hindi Text (Devanagari)",
                placeholder="नमस्ते",
                lines=2,
                info="Enter text in Devanagari script"
            )
            
            # Slider ranges mirror the clamping done in DiffusionPenDemo.generate.
            writer_id = gr.Slider(
                label="Writer ID",
                minimum=0,
                maximum=338,
                value=0,
                step=1,
                info="0-338: Different writing styles"
            )
            
            num_steps = gr.Slider(
                label="Diffusion Steps",
                minimum=10,
                maximum=100,
                value=50,
                step=10,
                info="10=fast, 100=quality"
            )
            
            guidance_scale = gr.Slider(
                label="Guidance Scale",
                minimum=1.0,
                maximum=15.0,
                value=7.5,
                step=0.5,
                info="1=ignore text, 15=strict"
            )
            
            generate_btn = gr.Button(
                "✨ Generate Handwriting",
                variant="primary",
                size="lg"
            )
        
        # Output panel
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### 📊 Output")
            
            output_image = gr.Image(
                label="Generated Handwriting",
                type='pil',
                interactive=False,
                show_download_button=True
            )
            
            status_text = gr.Textbox(
                label="Status",
                interactive=False,
                info="Generation progress and results"
            )
    
    # Examples (run_on_click=False: clicking only fills the inputs)
    gr.Markdown("### 📚 Examples to Try")
    gr.Examples(
        examples=[
            ["नमस्ते", 0, 50, 7.5],
            ["हिंदी", 50, 50, 7.5],
            ["आईआईआईटी", 100, 50, 7.5],
            ["लिपि", 150, 50, 7.5],
            ["भाषा", 200, 50, 7.5],
            ["नई लिखावट", 250, 60, 7.5],
        ],
        inputs=[text_input, writer_id, num_steps, guidance_scale],
        outputs=[output_image, status_text],
        fn=gradio_generate,
        cache_examples=False,
        run_on_click=False
    )
    
    # Information
    gr.Markdown("""
    ---
    
    ### 📖 About This Demo
    
    **Model Architecture:**
    - **Base**: UNet with 128 channels, 3 levels
    - **Attention**: Spatial transformers (`attention_resolutions=[16, 8]`)
    - **Text Encoding**: CANINE (768-dim, character-level)
    - **Writer Styles**: 339 different writing styles
    - **Input/Output**: 64×64 grayscale images
    
    **CANINE Text Encoder:**
    - ✓ Character-level (no subword tokenization)
    - ✓ Native Devanagari support
    - ✓ Pre-trained on 104 languages
    - ✓ 768-dimensional contextual embeddings
    
    **Performance:**
    - CPU: ~2 minutes per image
    - GPU: ~20 seconds per image
    - Memory: 6-8 GB
    
    ### 💡 Tips
    1. Keep text short (5-10 characters) for faster generation
    2. Try different Writer IDs for style variation
    3. Increase steps from 50→100 for better quality
    4. Guidance scale 5-10 works best for most cases
    5. Use CPU to generate demos, GPU for production
    
    ### 🔗 Resources
    - [CANINE Paper](https://arxiv.org/abs/2103.06367)
    - [Diffusion Models Course](https://huggingface.co/learn/diffusion-course)
    - [UNet Architecture](https://en.wikipedia.org/wiki/U-Net)
    """)
    
    # Connect button (also exposed as the "generate" API endpoint)
    generate_btn.click(
        fn=gradio_generate,
        inputs=[text_input, writer_id, num_steps, guidance_scale],
        outputs=[output_image, status_text],
        api_name="generate"
    )


if __name__ == "__main__":
    # Startup summary. checkpoint_loaded is False both when the file is
    # missing and when loading it failed, so report "Not loaded".
    print(f"\n{'='*60}")
    print("🚀 Starting DiffusionPen Gradio Demo")
    print(f"{'='*60}")
    print(f"Device: {demo_instance.device}")
    print(f"Checkpoint: {'✓ Loaded' if demo_instance.checkpoint_loaded else '✗ Not loaded'}")
    print(f"Models: {'✓ Ready' if demo_instance.model is not None else '✗ Error'}")
    print(f"{'='*60}\n")
    
    # Bind to all interfaces on port 7860 (e.g. for container/Spaces hosting).
    demo.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )