| import gradio as gr |
| import torch |
| import torch.nn as nn |
| import numpy as np |
| from PIL import Image |
| import os |
|
|
| |
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Generator hyper-parameters. They determine layer shapes, so they must match
# the values the checkpoint was trained with or load_state_dict() will fail.
nz = 100  # length of the latent noise vector
ngf = 64  # base channel width of the generator's feature maps
num_classes = 10  # number of conditioning labels (digits 0-9)
|
|
| |
class Generator(nn.Module):
    """Conditional generator producing single-channel 28x28 images.

    A class label is embedded into a ``num_classes``-dim vector. That vector is
    concatenated with the latent noise along the channel axis and upsampled by
    a stack of transposed convolutions. A final ``tanh`` squashes values into
    [-1, 1], and the 32x32 result is average-pooled down to 28x28.
    """

    def __init__(self):
        super().__init__()

        # One learned num_classes-dim vector per label; it enters the network
        # as extra input channels next to the noise.
        self.label_embedding = nn.Embedding(num_classes, num_classes)

        def upsample(in_ch, out_ch, stride, padding):
            # ConvTranspose (kernel 4, no bias) -> BatchNorm -> ReLU.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]

        # Spatial size: 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32.
        # Keep the layer order: the Sequential indices are the state_dict keys
        # that saved checkpoints are loaded against.
        self.main = nn.Sequential(
            *upsample(nz + num_classes, ngf * 8, 1, 0),
            *upsample(ngf * 8, ngf * 4, 2, 1),
            *upsample(ngf * 4, ngf * 2, 2, 1),
            nn.ConvTranspose2d(ngf * 2, 1, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

        # Shrink the 32x32 map to MNIST's 28x28.
        self.resize = nn.AdaptiveAvgPool2d((28, 28))

    def forward(self, noise, labels):
        """Generate one image per label from the given latent noise.

        Args:
            noise: Latent tensor, concatenated with the label embedding along
                dim 1. Presumably shaped (batch, nz, 1, 1), which is how the
                caller in this file builds it.
            labels: Integer class indices, one per sample (shape (batch,) as
                built by the caller in this file).

        Returns:
            Tensor of shape (batch, 1, 28, 28) with values in [-1, 1].
        """
        cond = self.label_embedding(labels)
        cond = cond.view(cond.shape[0], num_classes, 1, 1)
        return self.resize(self.main(torch.cat((noise, cond), dim=1)))
|
|
| |
def load_model(checkpoint_path="mnist_gan_model.pth"):
    """Build a Generator and load trained weights from a checkpoint file.

    Args:
        checkpoint_path: Path to a file written with ``torch.save`` that holds
            a dict with a ``"generator_state_dict"`` entry. Relative paths are
            resolved against the current working directory. The default is
            the file name this app has always used.

    Returns:
        The generator on ``device`` in eval mode. Returns ``None`` if the
        checkpoint cannot be read or does not fit the architecture. Returning
        ``None`` instead of raising lets the UI start anyway;
        generate_digit_images() then serves placeholder images.
    """
    generator = Generator().to(device)

    try:
        # NOTE(review): torch.load unpickles the file, so only load checkpoints
        # you trust. On torch versions that support it, weights_only=True
        # restricts what may be deserialized.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        generator.load_state_dict(checkpoint["generator_state_dict"])
    except Exception as e:  # deliberate best-effort: fall back to placeholders
        print(f"❌ Error loading model: {e}")
        return None

    # Inference mode: BatchNorm uses its running statistics.
    generator.eval()
    print("✅ Model loaded successfully!")
    return generator


# Loaded once at import time and shared by every request; None means the
# checkpoint failed to load.
generator = load_model()
|
|
| |
def generate_digit_images(digit, num_images=5):
    """Generate ``num_images`` upscaled images of the requested digit.

    Args:
        digit: Class label to condition on. Anything ``int()`` accepts works:
            the Gradio dropdown passes an int, and a string like ``"3"`` is
            also fine. It must lie in ``[0, num_classes)``.
        num_images: Number of samples to draw. Defaults to 5, the number the
            UI gallery shows.

    Returns:
        A list of ``num_images`` 112x112 grayscale (mode ``"L"``) PIL images.
        If no model is loaded, uniform mid-gray placeholders are returned so
        the UI still renders.

    Raises:
        ValueError: If ``digit`` is not an in-range integer label. This is
            checked up front because an out-of-range index would otherwise
            fail inside ``nn.Embedding``: an IndexError on CPU, a device-side
            assert on CUDA.
    """
    display_size = (112, 112)

    if generator is None:
        # Distinct objects, rather than one image repeated via ``[img] * n``.
        return [Image.new("L", display_size, 128) for _ in range(num_images)]

    try:
        digit = int(digit)
    except (TypeError, ValueError) as err:
        raise ValueError(f"digit must be an integer, got {digit!r}") from err
    if not 0 <= digit < num_classes:
        raise ValueError(f"digit must be in [0, {num_classes - 1}], got {digit}")

    with torch.no_grad():
        noise = torch.randn(num_images, nz, 1, 1, device=device)
        labels = torch.full((num_images,), digit, dtype=torch.long, device=device)
        generated = generator(noise, labels)

    # The generator returns (N, 1, 28, 28) in [-1, 1]. Drop the channel axis
    # explicitly: np.squeeze would also drop the batch axis when N == 1, and
    # the loop below would then iterate over pixel rows.
    batch = generated[:, 0].cpu().numpy()
    # Map [-1, 1] -> [0, 255]; clip guards against float round-off.
    batch = np.clip((batch + 1.0) * 127.5, 0, 255).round().astype(np.uint8)

    # fromarray infers mode "L" from a 2-D uint8 array. NEAREST keeps the
    # upscaled pixels crisp instead of blurring them.
    return [
        Image.fromarray(img).resize(display_size, Image.NEAREST)
        for img in batch
    ]
|
|
| |
def create_app():
    """Build the Gradio Blocks UI for the digit generator.

    Layout: a digit dropdown and a generate button on the left, and a
    5-column image gallery on the right. The gallery is filled by
    generate_digit_images() on each page load and on every button click.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    default_digit = 8

    with gr.Blocks(title="Handwritten Digit Generator", theme=gr.themes.Soft()) as app:
        gr.Markdown("# 🔢 Handwritten Digit Generator")
        gr.Markdown("Generate synthetic MNIST-like digit images using a trained GAN model.")

        with gr.Row():
            with gr.Column(scale=1):
                digit_input = gr.Dropdown(
                    choices=list(range(10)),
                    value=default_digit,
                    label="Choose a digit (0-9)",
                )
                generate_btn = gr.Button("🎨 Generate Images", variant="primary")

            with gr.Column(scale=2):
                gr.Markdown("### Generated Images")
                image_gallery = gr.Gallery(
                    label="Generated Digit Images",
                    show_label=False,
                    columns=5,
                    rows=1,
                    height=200,
                )

        generate_btn.click(
            fn=generate_digit_images,
            inputs=[digit_input],
            outputs=[image_gallery],
        )

        # Fill the gallery on page load from the dropdown's current value, so
        # the initial images always match the selected digit.
        app.load(
            fn=generate_digit_images,
            inputs=[digit_input],
            outputs=[image_gallery],
        )

        gr.Markdown("---")
        gr.Markdown("**🤖 Model**: Conditional GAN | **⚡ Framework**: PyTorch + Gradio")

    return app


if __name__ == "__main__":
    # Build the UI and serve it with Gradio's default launch settings.
    app = create_app()
    app.launch()