# Hugging Face Spaces status: Runtime error
| """ | |
| DiffusionPen: Hindi Handwriting Generation Demo | |
| Inference-focused Gradio application with CANINE text encoding | |
| """ | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from unet import UNetModel | |
| from transformers import CanineTokenizer, CanineModel | |
| from pathlib import Path | |
class DiffusionPenDemo:
    """
    Hindi Handwriting Generation Demo using DiffusionPen UNet

    Features:
    - CANINE text encoder for character-level Hindi encoding
    - 339 different writer styles
    - Configurable step count and step-size ("guidance") scale
    - GPU/CPU automatic detection
    - Checkpoint loading support
    """

    # Key prefix that nn.DataParallel / DistributedDataParallel adds to every
    # entry of a wrapped model's state dict.
    _DATA_PARALLEL_PREFIX = "module."

    def __init__(self, checkpoint_path=None, device=None):
        """Create the demo and eagerly load all models.

        Args:
            checkpoint_path: Optional path to a UNet checkpoint. When it is
                unset or the file does not exist, the UNet keeps its random
                initialization.
            device: Torch device string; defaults to 'cuda' when available,
                otherwise 'cpu'.
        """
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.checkpoint_path = checkpoint_path
        self.model = None
        self.text_encoder = None
        self.tokenizer = None
        # True only after _load_checkpoint matched at least one model tensor.
        self.checkpoint_loaded = False
        self.load_models()

    def load_models(self):
        """Load the CANINE tokenizer/encoder, build the UNet, load weights.

        Raises:
            Exception: any download/construction error is printed and
                re-raised, so creating the demo fails loudly.
        """
        try:
            print(f"\n{'='*60}")
            print("🔧 DiffusionPen Initialization")
            print(f"{'='*60}")
            print(f"📱 Device: {self.device.upper()}")

            # Load CANINE text encoder
            print("\n📝 Loading CANINE text encoder...")
            self.tokenizer = CanineTokenizer.from_pretrained('google/canine-s')
            self.text_encoder = CanineModel.from_pretrained('google/canine-s').to(self.device)
            self.text_encoder.eval()
            print(" ✓ CANINE loaded (768-dim embeddings)")

            # Initialize UNet model
            print("\n🧠 Initializing UNet model...")

            class Args:
                # Stand-in for the training-time argument namespace; presumably
                # UNetModel reads these options from it — confirm in unet.py.
                interpolation = False
                mix_rate = 0.5

            # NOTE(review): in guided-diffusion-style UNets, attention_resolutions
            # holds downsample factors, not pixel sizes. With channel_mult=(1, 2, 4)
            # those factors only reach 1, 2 and 4, so [16, 8] may add no
            # attention/context layers at all. Confirm against unet.py; changing
            # it would invalidate existing checkpoints.
            self.model = UNetModel(
                image_size=64,
                in_channels=1,
                model_channels=128,
                out_channels=1,
                num_res_blocks=2,
                attention_resolutions=[16, 8],
                dropout=0.1,
                channel_mult=(1, 2, 4),
                dims=2,
                num_classes=339,  # Hindi writer styles
                use_checkpoint=True,
                num_heads=8,
                num_head_channels=-1,
                use_scale_shift_norm=True,
                resblock_updown=False,
                use_spatial_transformer=True,
                transformer_depth=1,
                context_dim=768,
                text_encoder=self.text_encoder,
                args=Args()
            ).to(self.device)
            self.model.eval()

            # Count parameters.
            # NOTE(review): if UNetModel registers text_encoder as a submodule,
            # this total includes CANINE's weights as well.
            total_params = sum(p.numel() for p in self.model.parameters())
            print(f" ✓ UNet initialized ({total_params/1e6:.1f}M parameters)")

            # Load checkpoint if available
            if self.checkpoint_path and Path(self.checkpoint_path).exists():
                self._load_checkpoint()
            elif self.checkpoint_path:
                print(f"\n⚠️ No checkpoint found at: {self.checkpoint_path}")
                print(" Using random initialization")
            else:
                print("\n⚠️ No checkpoint path configured")
                print(" Using random initialization")

            print(f"\n{'='*60}")
            print("✅ Ready for inference!")
            print(f"{'='*60}\n")
        except Exception as e:
            print(f"\n❌ Error during initialization: {str(e)}")
            raise

    def _load_checkpoint(self):
        """Load UNet weights from ``self.checkpoint_path`` into ``self.model``.

        Supported formats:
        - a dict wrapping the weights under 'model_state_dict' or 'state_dict',
        - a raw state dict,
        - a pickled nn.Module.

        A uniform 'module.' key prefix (left by DataParallel/DDP) is stripped.
        Loading is non-strict, but if no checkpoint key matches a model tensor
        the load is reported as a failure instead of silently keeping the
        random initialization.

        Sets ``self.checkpoint_loaded``. Errors are printed, never raised.
        """
        self.checkpoint_loaded = False
        try:
            print(f"\n📂 Loading checkpoint: {self.checkpoint_path}")
            checkpoint = torch.load(self.checkpoint_path, map_location=self.device)

            # Handle different checkpoint formats
            if isinstance(checkpoint, torch.nn.Module):
                # torch.save(model) output; only unpicklable when torch.load is
                # allowed to load arbitrary objects (e.g. weights_only=False).
                state_dict = checkpoint.state_dict()
                print(" Format: Pickled nn.Module")
            elif isinstance(checkpoint, dict):
                if 'model_state_dict' in checkpoint:
                    state_dict = checkpoint['model_state_dict']
                    print(" Format: Standard (model_state_dict)")
                elif 'state_dict' in checkpoint:
                    state_dict = checkpoint['state_dict']
                    print(" Format: Alternative (state_dict)")
                else:
                    state_dict = checkpoint
                    print(" Format: Raw state dict")
            else:
                print(f" ❌ Failed to load checkpoint: unsupported type {type(checkpoint).__name__}")
                return

            # DataParallel/DDP checkpoints prefix every key with 'module.'.
            prefix = self._DATA_PARALLEL_PREFIX
            if state_dict and all(isinstance(k, str) and k.startswith(prefix) for k in state_dict):
                state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
                print(f" Stripped '{prefix}' key prefix (DataParallel/DDP checkpoint)")

            # strict=False tolerates minor mismatches, but a checkpoint whose
            # keys match nothing would "load" while changing no weights.
            model_keys = set(self.model.state_dict())
            matched = model_keys.intersection(state_dict)
            if not matched:
                print(" ❌ Failed to load checkpoint: no keys match the model")
                return

            missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
            if missing_keys:
                print(f" ⚠️ Missing keys: {len(missing_keys)}")
            if unexpected_keys:
                print(f" ⚠️ Unexpected keys: {len(unexpected_keys)}")

            self.checkpoint_loaded = True
            print(" ✓ Checkpoint loaded successfully")
            print(f" Matched {len(matched)}/{len(model_keys)} model tensors")
        except Exception as e:
            print(f" ❌ Failed to load checkpoint: {str(e)}")
            self.checkpoint_loaded = False

    def encode_text(self, text):
        """Tokenize ``text`` for CANINE and move the tensors to ``self.device``.

        This returns the tokenizer output (a dict of tensors), not embeddings.
        The CANINE model itself is passed to UNetModel as ``text_encoder``,
        which presumably runs it on this dict during the forward pass
        (confirm against unet.py).

        Returns:
            dict of tensors, or None if tokenization raised (the error is
            printed).
        """
        try:
            # CANINE handles character-level encoding natively
            inputs = self.tokenizer(
                text,
                return_tensors='pt',
                padding=True,
                truncation=True,
                max_length=512
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            return inputs
        except Exception as e:
            print(f"❌ Text encoding error: {e}")
            return None

    def generate(self, text, writer_id=0, num_steps=50, guidance_scale=7.5):
        """
        Generate Hindi handwriting from text.

        Args:
            text: Hindi text in Devanagari script. None or blank text is
                rejected with a warning message.
            writer_id: Writer style ID, clamped to 0-338.
            num_steps: Number of denoising iterations, clamped to 10-100.
            guidance_scale: Multiplier on the per-step update size, clamped
                to 1.0-15.0 (see the note below).

        Returns:
            Tuple[PIL.Image | None, str]: a 64x64 grayscale image (None on
            failure) and a status message. Errors are reported through the
            message, never raised.

        NOTE(review): the loop below is a heuristic, not a standard DDPM/DDIM
        sampler:
        - the model is queried at timesteps num_steps-1 .. 0 rather than on
          its training noise schedule;
        - ``guidance_scale`` only scales the update step and is not
          classifier-free guidance.
        It is kept unchanged so outputs match earlier runs.
        """
        if self.model is None:
            return None, "❌ Model not initialized"

        try:
            # Input validation (a cleared Gradio textbox may arrive as None)
            if not text or not text.strip():
                return None, "⚠️ Please enter Hindi text"
            writer_id = max(0, min(int(writer_id), 338))
            num_steps = max(10, min(int(num_steps), 100))
            guidance_scale = max(1.0, min(float(guidance_scale), 15.0))

            print("\n🎨 Generating handwriting...")
            print(f" Text: '{text}'")
            print(f" Writer: {writer_id}/338")
            print(f" Steps: {num_steps}")
            print(f" Guidance: {guidance_scale}")

            # Tokenize text for CANINE
            context = self.encode_text(text)
            if context is None:
                return None, "❌ Text encoding failed"

            batch_size = 1
            # Initialize from noise
            x = torch.randn(batch_size, 1, 64, 64, device=self.device)
            # Writer conditioning is the same at every step; build it once.
            y = torch.full((batch_size,), writer_id, dtype=torch.long, device=self.device)
            report_every = max(1, num_steps // 5)

            # Reverse (denoising) loop; inference only, so no autograd graph.
            with torch.no_grad():
                for step in range(num_steps - 1, -1, -1):
                    t = torch.full((batch_size,), step, dtype=torch.long, device=self.device)
                    noise_pred = self.model(
                        x,
                        timesteps=t,
                        context=context,
                        y=y
                    )

                    # Heuristic update whose size grows as step approaches 0
                    alpha_t = 1.0 - (step / num_steps)
                    scale = guidance_scale * alpha_t
                    x = x - 0.01 * scale * noise_pred

                    # Progress indicator (about five reports per run)
                    completed = num_steps - step
                    if completed % report_every == 0:
                        progress = (completed / num_steps) * 100
                        print(f" Progress: {progress:.0f}%")

            # Post-processing: [-1, 1] -> [0, 1] -> uint8 grayscale
            x = torch.clamp(x, -1, 1)
            x = (x + 1) / 2
            x = x.squeeze(0).squeeze(0).cpu().numpy()
            img_array = (x * 255).astype(np.uint8)
            # A 2-D uint8 array is inferred as mode 'L'; passing `mode=` is
            # deprecated in recent Pillow releases.
            img = Image.fromarray(img_array)

            status = f"✅ Generated with writer {writer_id}, {num_steps} steps"
            print(f" {status}\n")
            return img, status
        except Exception as e:
            error_msg = f"❌ Generation error: {str(e)}"
            print(f" {error_msg}")
            return None, error_msg
# ==============================================================================
# CONFIGURATION
# ==============================================================================

# Location of the trained UNet weights — point this at your own checkpoint.
# If the file is missing, the demo keeps the UNet's random initialization.
CHECKPOINT_PATH = "./checkpoints/model.pt"

# One shared model instance, built at import time, serves every callback below.
print("\n🚀 Initializing DiffusionPen...")
demo_instance = DiffusionPenDemo(checkpoint_path=CHECKPOINT_PATH, device=None)  # None -> auto-detect GPU/CPU
def gradio_generate(text, writer_id, num_steps, guidance_scale):
    """Forward the Gradio widget values to the shared demo instance.

    Returns:
        The (image, status message) pair that fills the two output components.
    """
    image, status = demo_instance.generate(text, writer_id, num_steps, guidance_scale)
    return image, status
# ==============================================================================
# GRADIO INTERFACE
# ==============================================================================

theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="amber",
)

with gr.Blocks(title="DiffusionPen - Hindi Handwriting Generation", theme=theme) as demo:
    # Header
    gr.Markdown("""
    # 🎨 DiffusionPen: Hindi Handwriting Generation
    Generate authentic Hindi handwriting using diffusion models with CANINE text encoding.
    """)

    # Main content
    with gr.Row():
        # Input panel
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### ✏️ Input Settings")
            text_input = gr.Textbox(
                label="Hindi Text (Devanagari)",
                placeholder="नमस्ते",
                lines=2,
                info="Enter text in Devanagari script"
            )
            writer_id = gr.Slider(
                label="Writer ID",
                minimum=0,
                maximum=338,
                value=0,
                step=1,
                info="0-338: Different writing styles"
            )
            num_steps = gr.Slider(
                label="Diffusion Steps",
                minimum=10,
                maximum=100,
                value=50,
                step=10,
                info="10=fast, 100=quality"
            )
            # NOTE(review): DiffusionPenDemo.generate uses this value as a
            # step-size multiplier (x -= 0.01 * scale * noise_pred), not as
            # classifier-free guidance, so the "1=ignore text" hint may be
            # inaccurate.
            guidance_scale = gr.Slider(
                label="Guidance Scale",
                minimum=1.0,
                maximum=15.0,
                value=7.5,
                step=0.5,
                info="1=ignore text, 15=strict"
            )
            generate_btn = gr.Button(
                "✨ Generate Handwriting",
                variant="primary",
                size="lg"
            )

        # Output panel
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### 📊 Output")
            # The download button is shown by default. The explicit
            # `show_download_button` flag is omitted because newer Gradio
            # releases deprecate it.
            output_image = gr.Image(
                label="Generated Handwriting",
                type='pil',
                interactive=False
            )
            status_text = gr.Textbox(
                label="Status",
                interactive=False,
                info="Generation progress and results"
            )

    # Examples
    gr.Markdown("### 📚 Examples to Try")
    gr.Examples(
        examples=[
            ["नमस्ते", 0, 50, 7.5],
            ["हिंदी", 50, 50, 7.5],
            ["आईआईआईटी", 100, 50, 7.5],  # "IIIT"
            ["लिपि", 150, 50, 7.5],
            ["भाषा", 200, 50, 7.5],
            ["नई लिखावट", 250, 60, 7.5],
        ],
        inputs=[text_input, writer_id, num_steps, guidance_scale],
        outputs=[output_image, status_text],
        fn=gradio_generate,
        cache_examples=False,
        run_on_click=False
    )

    # Information (blank lines keep the bold labels from being merged into
    # the preceding list items when the markdown is rendered)
    gr.Markdown("""
    ---
    ### 📖 About This Demo

    **Model Architecture:**
    - **Base**: UNet with 128 channels, 3 levels
    - **Attention**: Spatial transformers at resolutions 16×8
    - **Text Encoding**: CANINE (768-dim, character-level)
    - **Writer Styles**: 339 different writing styles
    - **Input/Output**: 64×64 grayscale images

    **CANINE Text Encoder:**
    - ✓ Character-level (no subword tokenization)
    - ✓ Native Devanagari support
    - ✓ Pre-trained on 104 languages
    - ✓ 768-dimensional contextual embeddings

    **Performance:**
    - CPU: ~2 minutes per image
    - GPU: ~20 seconds per image
    - Memory: 6-8 GB

    ### 💡 Tips
    1. Keep text short (5-10 characters) for faster generation
    2. Try different Writer IDs for style variation
    3. Increase steps from 50→100 for better quality
    4. Guidance scale 5-10 works best for most cases
    5. Use CPU to generate demos, GPU for production

    ### 🔗 Resources
    - [CANINE Paper](https://arxiv.org/abs/2103.06367)
    - [Diffusion Models Course](https://huggingface.co/course)
    - [UNet Architecture](https://en.wikipedia.org/wiki/U-Net)
    """)

    # Connect button
    generate_btn.click(
        fn=gradio_generate,
        inputs=[text_input, writer_id, num_steps, guidance_scale],
        outputs=[output_image, status_text],
        api_name="generate"
    )
if __name__ == "__main__":
    print(f"\n{'='*60}")
    print("🚀 Starting DiffusionPen Gradio Demo")
    print(f"{'='*60}")
    print(f"Device: {demo_instance.device}")
    # checkpoint_loaded is False both when no checkpoint file exists and when
    # loading an existing file failed, so report "Not loaded" for either case.
    print(f"Checkpoint: {'✓ Loaded' if demo_instance.checkpoint_loaded else '✗ Not loaded'}")
    print(f"Models: {'✓ Ready' if demo_instance.model is not None else '✗ Error'}")
    print(f"{'='*60}\n")

    demo.launch(
        share=False,
        server_name="0.0.0.0",  # bind all interfaces, not only localhost
        server_port=7860,
        show_error=True
    )