Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| from model import Generator | |
| import torchvision.utils as vutils | |
| import os | |
| from math import log2 | |
| # Function to generate images | |
def generate_images(num_images=6):
    """Sample images from the pretrained generator.

    Loads the weights stored under the ``'state_dict'`` key of
    ``generator.pth`` onto the CPU, draws ``num_images`` random latent
    vectors and runs them through the generator at 256x256 resolution.

    Args:
        num_images: Number of images to sample. Defaults to 6, the batch
            size that was previously hard-coded.

    Returns:
        A list of ``num_images`` tensors, each of shape
        ``(3, 256, 256)`` with values clamped to ``[0, 1]``.

    Raises:
        ValueError: If ``num_images`` is smaller than 1.
        RuntimeError: If the generator returns a batch of unexpected shape.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")

    Z_DIM = 256
    IN_CHANNELS = 256

    # NOTE(review): torch.load unpickles the file, so only trusted
    # checkpoints must be loaded here. Consider weights_only=True if the
    # installed torch version supports it.
    checkpoint = torch.load("generator.pth", map_location=torch.device('cpu'))
    # The checkpoint holds more than the weights; only the model's
    # state dict is needed for inference.
    state_dict = checkpoint['state_dict']

    generator = Generator(Z_DIM, IN_CHANNELS, img_channels=3)
    generator.load_state_dict(state_dict)
    generator.eval()

    # Kept from the original implementation; nothing in this function
    # writes into the directory.
    output_dir = "generated_images"
    os.makedirs(output_dir, exist_ok=True)

    img_sizes = [256]
    images = []
    for img_size in img_sizes:
        # Number of resolution doublings needed to go from 4x4 to img_size.
        num_steps = int(log2(img_size / 4))
        noise = torch.randn((num_images, Z_DIM, 1, 1))
        with torch.no_grad():
            batch = generator(noise, alpha=0.5, steps=num_steps)

        expected_shape = (num_images, 3, img_size, img_size)
        if tuple(batch.shape) != expected_shape:
            raise RuntimeError(
                f"Generator returned shape {tuple(batch.shape)}, "
                f"expected {expected_shape}"
            )

        # Rescale from [-1, 1] to [0, 1]. The clamp guards against outputs
        # slightly outside that range, which the image display would
        # otherwise reject.
        batch = ((batch + 1) / 2).clamp(0.0, 1.0)
        images.extend(batch[i].detach() for i in range(num_images))
    return images
| # Main function to create Streamlit web app | |
def main():
    """Render the Streamlit page and show generated images on request."""
    st.title('Image Generation with pro-gan 🤖')

    intro_lines = (
        "Click the buttons below to generate images.",
        "Trained on CelebHQ dataset.",
        # Heads-up for the user about the output resolution.
        "Note: Due to limited resources, the model has been trained to generate 256x256 size images. They are still awesome!",
    )
    for text in intro_lines:
        st.write(text)

    # Nothing more to render until the user asks for images.
    if not st.button('Generate Images'):
        return

    for number, tensor in enumerate(generate_images(), start=1):
        # Move the channel axis last before handing the array to st.image.
        pixels = tensor.permute(1, 2, 0).cpu().numpy()
        st.image(
            pixels,
            caption=f'Generated Image {number}',
            use_column_width=True,
        )
# Build the app only when this file is run as the main module.
if __name__ == "__main__":
    main()