Spaces:
Sleeping
Sleeping
| """ | |
| MNIST GAN Digit Generator Application | |
| A production-ready Gradio interface for generating handwritten digits using a trained GAN. | |
| Author: Vikranth Reddimasu | |
| License: MIT | |
| """ | |
| import logging | |
| import random | |
| import sys | |
| from pathlib import Path | |
| from typing import Tuple | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import matplotlib | |
| matplotlib.use('Agg') # Non-interactive backend for server deployment | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| import io | |
# Configure logging: one stdout handler with timestamped, level-tagged lines.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=[_stdout_handler])
logger = logging.getLogger(__name__)
# Constants
# Generator architecture; these must match the shapes stored in the checkpoint.
NOISE_DIM = 100
HIDDEN_DIM = 256
IMAGE_SIZE = 28
OUTPUT_DIM = IMAGE_SIZE * IMAGE_SIZE  # 784 = one flattened 28x28 image
# Checkpoint file, resolved relative to the working directory.
MODEL_PATH = "generator_model.pth"
DEFAULT_SEED = 42
# Bounds for the UI controls and for input validation.
MIN_IMAGES, MAX_IMAGES = 1, 16
MIN_TEMPERATURE, MAX_TEMPERATURE = 0.5, 2.0
# Hero banner HTML shown at the top of the page (rendered via gr.Markdown with elem_id="hero").
HERO_TEXT = """
<h1 style="margin-bottom:0.3rem;">MNIST Digit Generator</h1>
<p style="font-size:1.05rem;margin:0;color:var(--body-text-color);">
Explore a trained GAN that synthesizes handwriting on demand. Set your parameters, tap generate, and download a fresh grid of digits.
</p>
"""
# (label, value, hint) triples, each rendered as one metric card below the hero.
# Noise/hidden/temperature values are derived from the constants; "Epochs",
# "Generator Loss" and "Params" are hard-coded display strings that are NOT read
# from the checkpoint, so they must be kept in sync with the trained model by hand.
# ("1.49M" matches the Generator layout: 1,489,936 trainable parameters.)
METRIC_CARDS = [
    ("Noise Dim", f"{NOISE_DIM}", "Latent space size"),
    ("Hidden Units", f"{HIDDEN_DIM}", "Per-layer width"),
    ("Epochs", "200", "Training duration"),
    ("Temp Range", f"{MIN_TEMPERATURE} – {MAX_TEMPERATURE}", "Diversity control"),
    ("Generator Loss", "0.981", "Final epoch"),
    ("Params", "1.49M", "Trainable weights"),
]
# Named presets for the Playground "Choose a vibe" dropdown. Each maps to the
# slider values it applies (samples, seed, temperature) plus a short caption.
STYLE_PRESETS = {
    preset_name: dict(
        description=caption,
        samples=count,
        seed=preset_seed,
        temperature=temp,
    )
    for preset_name, caption, count, preset_seed, temp in (
        ("Precise", "Clean, focused digits", 6, 123, 0.8),
        ("Balanced", "Default training vibe", 9, 512, 1.0),
        ("Playful", "Add variety & contrast", 12, 777, 1.3),
    )
}
DEFAULT_STYLE = "Balanced"
# Stylesheet passed to gr.Blocks(css=...): dark page gradient, "surface" panels,
# metric cards, preset pills and the style dropdown.
# .tip-text previously declared `color` twice; only the last declaration ever
# applied, so the dead theme-variable line was dropped.
CUSTOM_CSS = """
.gradio-container {
    max-width: 1100px !important;
    margin: auto;
    background: radial-gradient(circle at top, #111827, #0b1120 45%, #05070f 90%);
    color: #F8FAFC;
}
#hero {
    padding: 0.5rem 0 1rem 0;
}
#hero h1 {
    font-size: 2.2rem;
    font-weight: 700;
    color: #F8FAFC;
}
#hero p {
    color: #CBD5F5;
}
.surface {
    background: rgba(15, 23, 42, 0.9);
    border: 1px solid rgba(148, 163, 184, 0.2);
    border-radius: 14px;
    padding: 1.2rem;
    box-shadow: 0 18px 40px rgba(2, 6, 23, 0.6);
}
.metric-grid {
    display: grid;
    grid-template-columns: repeat(auto-fit, minmax(160px, 1fr));
    gap: 0.75rem;
    margin-bottom: 0.5rem;
}
.metric-card {
    border-radius: 12px;
    padding: 0.8rem 1rem;
    background: linear-gradient(135deg, rgba(59,130,246,0.2), rgba(147,197,253,0.08));
    border: 1px solid rgba(59, 130, 246, 0.25);
}
.metric-label {
    margin: 0;
    font-size: 0.8rem;
    letter-spacing: 0.06em;
    text-transform: uppercase;
    color: #2563EB;
}
.metric-value {
    margin: 0.15rem 0 0 0;
    font-size: 1.2rem;
    font-weight: 600;
    color: #F8FAFC;
}
.metric-hint {
    margin: 0.1rem 0 0 0;
    font-size: 0.8rem;
    color: rgba(248, 250, 252, 0.75);
}
.pill-row > button {
    flex: 1;
}
.tip-text {
    font-size: 0.92rem;
    color: rgba(226, 232, 240, 0.9);
}
.quick-presets .table {
    background: rgba(15, 23, 42, 0.65);
}
.gr-box {
    background: rgba(15, 23, 42, 0.8) !important;
}
.style-select .wrap {
    display: flex;
    flex-direction: column;
    gap: 0.35rem;
}
.style-select select {
    background: rgba(30, 64, 175, 0.2);
    border: 1px solid rgba(79, 70, 229, 0.6);
    border-radius: 999px;
    color: #E0E7FF;
}
"""
class Generator(nn.Module):
    """
    Fully connected GAN generator that turns latent noise into MNIST-sized images.

    Layer stack (the order matters: ``nn.Sequential`` indices become the keys of
    the saved ``state_dict``):
        three blocks of Linear -> LeakyReLU(0.2) -> BatchNorm1d, widening
        noise_dim -> hidden_dim -> 2*hidden_dim -> 4*hidden_dim,
        followed by Linear(4*hidden_dim -> output_dim) -> Tanh.

    Args:
        noise_dim: Size of the input noise vector (default: NOISE_DIM).
        hidden_dim: Width of the first hidden layer; later layers are 2x and 4x (default: HIDDEN_DIM).
        output_dim: Size of the flattened output image (default: OUTPUT_DIM).
    """

    def __init__(
        self,
        noise_dim: int = NOISE_DIM,
        hidden_dim: int = HIDDEN_DIM,
        output_dim: int = OUTPUT_DIM
    ):
        super().__init__()
        widths = [noise_dim, hidden_dim, hidden_dim * 2, hidden_dim * 4]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend([
                nn.Linear(fan_in, fan_out),
                nn.LeakyReLU(0.2),
                nn.BatchNorm1d(fan_out),
            ])
        # Tanh bounds outputs to [-1, 1], the range of the normalized training images.
        layers.extend([nn.Linear(widths[-1], output_dim), nn.Tanh()])
        self.model = nn.Sequential(*layers)
        logger.info(f"Generator initialized with {self._count_parameters():,} parameters")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of noise vectors to a batch of flattened images."""
        return self.model(x)

    def _count_parameters(self) -> int:
        """Return the number of trainable parameter elements."""
        total = 0
        for param in self.parameters():
            if param.requires_grad:
                total += param.numel()
        return total
class ModelManager:
    """Build the generator once, load trained weights, and run inference.

    Construction does not raise for checkpoint problems: failures are logged and
    reported through ``model_info`` so the UI can start and display the error.

    Attributes:
        model_path: Path of the checkpoint file.
        device: Torch device used for inference (CUDA > MPS > CPU).
        generator: The ``Generator`` instance, or ``None`` if it could not be built.
        model_info: Human-readable status string shown in the UI.
    """

    def __init__(self, model_path: str = MODEL_PATH):
        self.model_path = Path(model_path)
        self.device = self._select_device()
        self.generator = None
        self.model_info = None
        logger.info(f"Using device: {self.device}")
        self._load_model()

    @staticmethod
    def _select_device() -> torch.device:
        """Return the best available device: CUDA, then MPS, then CPU."""
        try:
            if torch.cuda.is_available():
                return torch.device('cuda')
            # torch.backends.mps does not exist on every build/platform.
            if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                return torch.device('mps')
        except Exception:
            # Device probing failed; CPU is always a safe fallback.
            logger.warning("Device detection failed; falling back to CPU", exc_info=True)
        return torch.device('cpu')

    @staticmethod
    def _describe_checkpoint(checkpoint: dict) -> str:
        """Build the status line from the checkpoint's optional metadata."""
        epochs = checkpoint.get('epoch', 'N/A')
        g_loss = checkpoint.get('generator_loss', 'N/A')
        # Accept a one-element tensor too, in case the loss was saved without .item().
        if isinstance(g_loss, torch.Tensor) and g_loss.numel() == 1:
            g_loss = g_loss.item()
        if isinstance(g_loss, (int, float)) and not isinstance(g_loss, bool):
            return f"Model trained for {epochs} epochs | Generator Loss: {g_loss:.4f}"
        return f"Model trained for {epochs} epochs"

    def _load_model(self) -> None:
        """Build the generator and load weights from ``self.model_path`` if present.

        Sets ``self.generator`` and ``self.model_info``. Exceptions are logged and
        turned into an error message in ``model_info`` rather than re-raised.
        Whatever happens, a constructed generator is left in eval mode.
        """
        try:
            # Resolve absolute path for better debugging
            abs_path = self.model_path.resolve()
            logger.info(f"Loading model from {abs_path}")
            logger.info(f"Current working directory: {Path.cwd()}")
            logger.info(f"Model file exists: {self.model_path.exists()}")

            self.generator = Generator(
                noise_dim=NOISE_DIM,
                hidden_dim=HIDDEN_DIM,
                output_dim=OUTPUT_DIM
            ).to(self.device)
            logger.info("Generator initialized successfully")

            if self.model_path.exists():
                logger.info(f"Loading checkpoint from {abs_path}")
                checkpoint = torch.load(
                    str(self.model_path),
                    map_location=self.device,
                    weights_only=True  # refuse arbitrary pickled objects
                )
                logger.info("Checkpoint loaded, extracting state dict...")
                self.generator.load_state_dict(checkpoint['generator_state_dict'])
                logger.info("Model state dict loaded successfully")
                self.model_info = self._describe_checkpoint(checkpoint)
                logger.info(f"Model loaded successfully: {self.model_info}")
            else:
                logger.warning(f"Model file not found at {abs_path}. Using untrained model.")
                logger.warning(f"Files in current directory: {list(Path('.').glob('*.pth'))}")
                self.model_info = "Warning: Model weights not found. Using untrained model."
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}", exc_info=True)
            self.model_info = f"Error loading model: {str(e)}"
            # Don't raise - let the app start and show error in UI
            logger.error("Model loading failed, but continuing to start app...")
        finally:
            # BatchNorm must use running statistics at inference time; left in
            # train mode it also raises on a batch containing a single sample.
            if self.generator is not None:
                self.generator.eval()

    def generate(
        self,
        num_images: int,
        seed: int,
        temperature: float
    ) -> torch.Tensor:
        """
        Generate digit images.

        Args:
            num_images: Number of images to generate.
            seed: Random seed; the same seed and temperature give the same images.
            temperature: Scale applied to the noise vectors (sampling diversity).

        Returns:
            Generated images as a CPU tensor of shape (N, 1, 28, 28), values in [-1, 1],
            with no autograd history.

        Raises:
            RuntimeError: If the generator could not be constructed.
        """
        if self.generator is None:
            raise RuntimeError("Generator is not available. Please check the logs for details.")
        # A private CPU RNG yields the same noise as seeding the global RNG, but
        # other code touching the global RNG (e.g. concurrent requests) cannot
        # interleave with it and break reproducibility.
        rng = torch.Generator().manual_seed(seed)
        noise = torch.randn(num_images, NOISE_DIM, generator=rng).to(self.device) * temperature
        # Inference only: without no_grad the output requires grad and
        # Tensor.numpy() on it raises.
        with torch.no_grad():
            generated = self.generator(noise)
        return generated.view(-1, 1, IMAGE_SIZE, IMAGE_SIZE).cpu()
def validate_inputs(
    num_images: int,
    seed: int,
    temperature: float
) -> Tuple[int, int, float]:
    """
    Coerce and clamp the user-supplied generation parameters.

    Args:
        num_images: Requested number of images; truncated to int and clamped to
            [MIN_IMAGES, MAX_IMAGES].
        seed: Random seed; truncated to int and clamped to [0, 2**32 - 1].
        temperature: Sampling temperature; clamped to [MIN_TEMPERATURE, MAX_TEMPERATURE].

    Returns:
        Validated (num_images, seed, temperature).

    Raises:
        ValueError / TypeError / OverflowError: If a value cannot be converted
            to a number (e.g. None, NaN or infinity for an integer field).
    """
    # 2**32 - 1 is the largest seed np.random.seed accepts; torch accepts it too,
    # so clamping here keeps any API-supplied seed usable by either RNG.
    max_seed = 2**32 - 1
    num_images = max(MIN_IMAGES, min(MAX_IMAGES, int(num_images)))
    seed = max(0, min(max_seed, int(seed)))
    temperature = max(MIN_TEMPERATURE, min(MAX_TEMPERATURE, float(temperature)))
    return num_images, seed, temperature
def create_image_grid(
    images: np.ndarray,
    num_images: int
) -> Image.Image:
    """
    Render generated images as a titled grid (up to 4 columns) and return it as a PNG image.

    Uses ``matplotlib.figure.Figure`` directly instead of pyplot: pyplot keeps
    global "current figure" state, so ``plt.savefig`` in concurrent requests
    can save another request's figure, and un-closed pyplot figures leak.

    Args:
        images: Array of images shaped (N, 1, 28, 28) with values in [-1, 1].
        num_images: Number of images to display; capped at N.

    Returns:
        PIL Image containing the grid.

    Raises:
        ValueError: If there is nothing to display (num_images < 1 or N == 0).
    """
    from matplotlib.figure import Figure

    # Never index past what was actually generated.
    num_images = min(int(num_images), len(images))
    if num_images < 1:
        raise ValueError("create_image_grid needs at least one image to display")

    # Denormalize from [-1, 1] to [0, 1]
    images = np.clip((images + 1) / 2, 0, 1)

    n_cols = min(4, num_images)
    n_rows = (num_images + n_cols - 1) // n_cols  # ceiling division

    fig = Figure(figsize=(n_cols * 2, n_rows * 2), facecolor='white')
    # squeeze=False always yields a 2-D array, so the 1-image case needs no special handling.
    axes = fig.subplots(n_rows, n_cols, squeeze=False).ravel()
    for idx, ax in enumerate(axes):
        if idx < num_images:
            ax.imshow(images[idx].squeeze(), cmap='gray', vmin=0, vmax=1)
            ax.axis('off')
            ax.set_title(f'Sample {idx+1}', fontsize=10, pad=5)
        else:
            ax.axis('off')  # unused trailing cells
    fig.tight_layout(pad=0.5)

    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=100, bbox_inches='tight', facecolor='white')
    buf.seek(0)
    image = Image.open(buf)
    image.load()  # decode now rather than lazily on first access
    return image
def _render_error_image(message: str) -> Image.Image:
    """Render ``message`` as red centered text on a small PNG image."""
    from matplotlib.figure import Figure

    # A standalone Figure avoids pyplot's shared "current figure" state.
    fig = Figure(figsize=(6, 2))
    ax = fig.subplots()
    ax.text(
        0.5, 0.5,
        message,
        ha='center', va='center',
        fontsize=12, color='red'
    )
    ax.axis('off')
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight')
    buf.seek(0)
    error_img = Image.open(buf)
    error_img.load()
    return error_img


def generate_digits(
    num_images: int,
    seed: int,
    temperature: float
) -> Image.Image:
    """
    Main generation function for the Gradio interface.

    Args:
        num_images: Number of digits to generate (clamped to 1-16).
        seed: Random seed for reproducibility.
        temperature: Diversity control (clamped to 0.5-2.0).

    Returns:
        PIL Image containing the generated digits. On any failure an image with
        the error message is returned instead of raising, so the UI always has
        something to show.
    """
    try:
        if model_manager is None:
            raise RuntimeError("Model not loaded. Please check the logs for details.")

        num_images, seed, temperature = validate_inputs(num_images, seed, temperature)
        logger.info(
            f"Generating {num_images} images with seed={seed}, "
            f"temperature={temperature:.2f}"
        )

        images = model_manager.generate(num_images, seed, temperature)
        # detach(): if the forward pass was recorded by autograd, Tensor.numpy()
        # refuses a tensor that requires grad.
        image_grid = create_image_grid(images.detach().numpy(), num_images)
        logger.info("Generation completed successfully")
        return image_grid
    except Exception as e:
        logger.error(f"Error during generation: {str(e)}", exc_info=True)
        return _render_error_image(f"Error: {str(e)}\nPlease check logs.")
| def _random_seed() -> int: | |
| """Return a random seed for the shuffle button.""" | |
| return random.randint(0, 10000) | |
def _apply_style(name: str) -> Tuple[int, int, float, str]:
    """Map a preset name to (samples, seed, temperature, caption) slider updates.

    Unknown names fall back to the DEFAULT_STYLE preset.
    """
    fallback = STYLE_PRESETS[DEFAULT_STYLE]
    chosen = STYLE_PRESETS.get(name, fallback)
    caption = f"Preset • {chosen['description']}"
    samples, preset_seed, preset_temperature = (
        chosen[key] for key in ("samples", "seed", "temperature")
    )
    return samples, preset_seed, preset_temperature, caption
def _sync_seed(link_enabled: bool, seed_value: int):
    """Copy the left seed into the right slider while mirroring is on; otherwise change nothing."""
    return gr.update(value=seed_value) if link_enabled else gr.update()
def _toggle_seed_input(link_enabled: bool, seed_value: int):
    """Lock or unlock the right-hand seed slider when the mirror checkbox changes.

    Args:
        link_enabled: New state of the "Mirror seed from left" checkbox.
        seed_value: Current left-hand seed.

    Returns:
        A ``gr.update``. Enabling the link copies ``seed_value`` into the right
        slider and makes it read-only. Disabling it only re-enables editing and
        keeps whatever value the right slider currently shows.
    """
    # The original `seed_value if link_enabled else seed_value` was a no-op
    # conditional that also overwrote the right seed when unlinking.
    if link_enabled:
        return gr.update(value=seed_value, interactive=False)
    return gr.update(interactive=True)
def generate_comparison(
    num_images: int,
    seed_left: int,
    temperature_left: float,
    seed_right: int,
    temperature_right: float
) -> Tuple[Image.Image, Image.Image]:
    """Render the left and right parameter sets (same sample count) as two grids, left first."""
    settings = ((seed_left, temperature_left), (seed_right, temperature_right))
    grids = tuple(generate_digits(num_images, s, t) for s, t in settings)
    return grids
def create_interface() -> gr.Blocks:
    """Build the Gradio app with Playground, Compare and Insights tabs.

    Reads the module-level ``model_manager`` (which may be ``None``) for the
    status line and wires the controls to the generation helpers.

    Returns:
        The configured, not yet launched, ``gr.Blocks`` instance.
    """
    theme = gr.themes.Soft(
        primary_hue="indigo",
        secondary_hue="cyan",
        neutral_hue="slate"
    )
    # Every Playground control starts from the same preset as the dropdown.
    default_preset = STYLE_PRESETS[DEFAULT_STYLE]

    with gr.Blocks(
        title="MNIST GAN Digit Generator",
        theme=theme,
        css=CUSTOM_CSS
    ) as interface:
        gr.Markdown(HERO_TEXT, elem_id="hero")
        with gr.Row(elem_classes="metric-grid"):
            for label, value, hint in METRIC_CARDS:
                gr.HTML(
                    f"""
                    <div class='metric-card'>
                        <div class='metric-label'>{label}</div>
                        <div class='metric-value'>{value}</div>
                        <div class='metric-hint'>{hint}</div>
                    </div>
                    """
                )

        model_status = model_manager.model_info if model_manager else "Model unavailable"

        with gr.Tabs():
            with gr.Tab("Playground"):
                with gr.Row(equal_height=True):
                    with gr.Column(scale=1, elem_classes="surface"):
                        gr.Markdown("### Controls")
                        num_images = gr.Slider(
                            minimum=MIN_IMAGES,
                            maximum=MAX_IMAGES,
                            value=default_preset["samples"],
                            step=1,
                            label="Samples",
                            info=f"Generate {MIN_IMAGES}-{MAX_IMAGES} digits"
                        )
                        seed = gr.Slider(
                            minimum=0,
                            maximum=10000,
                            value=default_preset["seed"],
                            step=1,
                            label="Seed",
                            info="Same seed → same grid"
                        )
                        temperature = gr.Slider(
                            minimum=MIN_TEMPERATURE,
                            maximum=MAX_TEMPERATURE,
                            value=default_preset["temperature"],
                            step=0.1,
                            label="Temperature",
                            info="Higher = wilder digits"
                        )

                        gr.Markdown("#### Style presets")
                        with gr.Row():
                            style_selector = gr.Dropdown(
                                label="Choose a vibe",
                                choices=list(STYLE_PRESETS.keys()),
                                value=DEFAULT_STYLE,
                                elem_classes="style-select",
                                scale=1
                            )
                        preset_note = gr.Markdown(
                            f"Preset • {default_preset['description']}"
                        )
                        style_selector.change(
                            fn=_apply_style,
                            inputs=style_selector,
                            outputs=[num_images, seed, temperature, preset_note],
                            queue=False
                        )

                        with gr.Row(elem_classes="pill-row"):
                            random_btn = gr.Button("Shuffle Seed", variant="secondary")
                            generate_btn = gr.Button("Generate", variant="primary")

                    with gr.Column(scale=2, elem_classes="surface"):
                        gr.Markdown("### Output")
                        output_image = gr.Image(
                            label="",
                            type="pil",
                            height=430,
                            show_download_button=True,
                            interactive=False
                        )
                        gr.Markdown(
                            f"**Model status:** {model_status}",
                            elem_classes="tip-text"
                        )
                        gr.Markdown(
                            "Pro tip: Lower temperatures focus on clean digits, while higher values explore creative shapes.",
                            elem_classes="tip-text"
                        )

                with gr.Group(elem_classes="surface quick-presets"):
                    gr.Markdown("### Quick Presets")
                    gr.Examples(
                        label="Pick a vibe",
                        examples=[
                            [4, 21, 0.9],
                            [9, 512, 1.0],
                            [12, 777, 1.3],
                            [16, 999, 0.7],
                        ],
                        inputs=[num_images, seed, temperature],
                        outputs=output_image,
                        fn=generate_digits,
                        cache_examples=True
                    )

            with gr.Tab("Compare"):
                gr.Markdown("Explore two parameter sets side-by-side to understand how seeds and temperature shape samples.")
                num_images_cmp = gr.Slider(
                    minimum=MIN_IMAGES,
                    maximum=12,
                    value=6,
                    step=1,
                    label="Samples per grid"
                )
                with gr.Row(equal_height=True):
                    with gr.Column(elem_classes="surface"):
                        gr.Markdown("#### Left grid")
                        seed_left = gr.Slider(0, 10000, 42, step=1, label="Seed")
                        temp_left = gr.Slider(MIN_TEMPERATURE, MAX_TEMPERATURE, 0.9, step=0.1, label="Temperature")
                    with gr.Column(elem_classes="surface"):
                        gr.Markdown("#### Right grid")
                        seed_lock = gr.Checkbox(value=True, label="Mirror seed from left")
                        # Read-only while mirroring is on (the checkbox starts checked).
                        seed_right = gr.Slider(
                            0, 10000, 42, step=1, label="Seed", interactive=False
                        )
                        temp_right = gr.Slider(MIN_TEMPERATURE, MAX_TEMPERATURE, 1.3, step=0.1, label="Temperature")

                seed_left.change(
                    fn=_sync_seed,
                    inputs=[seed_lock, seed_left],
                    outputs=seed_right,
                    queue=False
                )
                seed_lock.change(
                    fn=_toggle_seed_input,
                    inputs=[seed_lock, seed_left],
                    outputs=seed_right,
                    queue=False
                )

                compare_btn = gr.Button("Generate Comparison", variant="primary")
                with gr.Row():
                    compare_left = gr.Image(type="pil", height=360, label="Left output", show_download_button=True)
                    compare_right = gr.Image(type="pil", height=360, label="Right output", show_download_button=True)

            with gr.Tab("Insights"):
                gr.Markdown(
                    "Training insights highlight how the GAN converged. These snapshots come directly from the notebook used to build the weights."
                )
                loss_path = Path("losses.png")
                if loss_path.exists():
                    gr.Image(
                        value=str(loss_path),
                        type="filepath",
                        height=280,
                        label="Training loss curves",
                        show_download_button=True
                    )
                else:
                    gr.Markdown("⚠️ Loss plot not found in repository.")
                gr.Markdown("**Key checkpoints**")
                gr.Dataframe(
                    headers=["Metric", "Value", "Notes"],
                    value=[
                        ["Generator Loss", "0.981", "Converged after steady decline"],
                        ["Discriminator Loss", "1.213", "Balanced with generator"],
                        ["Best Epoch", "200", "Used for deployment weights"],
                    ],
                    interactive=False,
                    wrap=True
                )
                gr.Markdown(
                    "Want to dig deeper? Clone the repo and open `GAN_MNIST_Assignment.ipynb` to replay the full experiment."
                )

        with gr.Accordion("What powers this demo?", open=False):
            # NOTE(review): the metric card and Insights table report a final
            # generator loss of 0.981, while this text says ≈ 0.97 — confirm which is current.
            gr.Markdown(
                "\n"
                "- **Generator:** Fully-connected GAN with 1.49M parameters (100-dim noise → 28×28 image).\n"
                "- **Training:** MNIST, 200 epochs, Adam (lr=2e-4), final generator loss ≈ 0.97.\n"
                "- **Deployment:** PyTorch + Gradio running on CPU in Hugging Face Spaces.\n"
            )

        random_btn.click(
            fn=_random_seed,
            outputs=seed,
            queue=False
        )
        generate_btn.click(
            fn=generate_digits,
            inputs=[num_images, seed, temperature],
            outputs=output_image
        )
        compare_btn.click(
            fn=generate_comparison,
            inputs=[num_images_cmp, seed_left, temp_left, seed_right, temp_right],
            outputs=[compare_left, compare_right]
        )
        # Render the initial grid from the sliders' own values, so the image
        # shown on load is exactly what pressing "Generate" would produce.
        # (Previously the load used DEFAULT_SEED while the slider showed the preset seed.)
        interface.load(
            fn=generate_digits,
            inputs=[num_images, seed, temperature],
            outputs=output_image
        )

    return interface
# Initialize model manager with error handling
logger.info("Initializing application...")


def _build_model_manager():
    """Create the ModelManager, or return None (after logging) if construction fails.

    A None manager is tolerated downstream: the UI shows "Model unavailable"
    and generation returns an error image.
    """
    try:
        return ModelManager()
    except Exception as exc:
        logger.error(f"Failed to initialize model manager: {str(exc)}", exc_info=True)
        return None


model_manager = _build_model_manager()

# Create interface
demo = create_interface()
if __name__ == "__main__":
    # Bind on all interfaces at port 7860 and surface handler errors in the UI.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    logger.info("Starting Gradio application...")
    demo.launch(**launch_options)