| | import streamlit as st |
| |
|
| | import numpy as np |
| | import torch |
| | from huggingface_hub import hf_hub_download |
| | import json |
| |
|
# Hugging Face Hub download settings read by ``load_model``.

# Name of the model configuration file inside the Hub repository.
CONFIG_NAME = "config.json"

# Optional download settings, all left unset (None).
revision = cache_dir = proxies = token = None

# Boolean download switches, all disabled.
force_download = resume_download = local_files_only = False
| |
|
| |
|
| | from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN |
| |
|
def load_model(model_name="ceyda/butterfly_cropped_uniq1K_512"):
    """
    Load a pre-trained LightweightGAN model from the Hugging Face Hub.

    The repository's ``config.json`` is downloaded and parsed first. It is then
    passed to ``LightweightGAN._from_pretrained`` together with the
    module-level download settings (``revision``, ``cache_dir``,
    ``force_download``, ``proxies``, ``resume_download``,
    ``local_files_only``, ``token``).

    Args:
        model_name (str): Hub repository id of the model to load.
            Defaults to "ceyda/butterfly_cropped_uniq1K_512".

    Returns:
        LightweightGAN: The loaded model, switched to evaluation mode.

    Raises:
        Nothing is caught here. Errors from ``hf_hub_download``, from
        reading/parsing the config file, or from ``_from_pretrained``
        propagate to the caller unchanged.
    """
    repo_id = str(model_name)

    config_file = hf_hub_download(
        repo_id=repo_id,
        filename=CONFIG_NAME,
        revision=revision,
        cache_dir=cache_dir,
        force_download=force_download,
        proxies=proxies,
        resume_download=resume_download,
        token=token,
        local_files_only=local_files_only,
    )
    with open(config_file, "r", encoding="utf-8") as f:
        config = json.load(f)

    # `_from_pretrained` is a classmethod of huggingface_hub's model-hub mixin
    # and constructs its own instance. Calling it on the class avoids building
    # (and then discarding) a complete 512px GAN up front.
    gan = LightweightGAN._from_pretrained(
        model_id=repo_id,
        revision=revision,
        cache_dir=cache_dir,
        force_download=force_download,
        proxies=proxies,
        resume_download=resume_download,
        local_files_only=local_files_only,
        token=token,
        # NOTE(review): passed alongside `token`, presumably for compatibility
        # with older huggingface_hub releases — confirm it is still needed.
        use_auth_token=False,
        config=config,
    )

    # Inference only: put the model in eval mode before returning it.
    gan.eval()
    return gan
| |
|
def _generator_device(generator):
    """Return the device of ``generator``'s first parameter, or CPU if it has none."""
    parameters = getattr(generator, "parameters", None)
    if parameters is None:
        return torch.device("cpu")
    first = next(parameters(), None)
    return first.device if first is not None else torch.device("cpu")


def generation(gan, batch_size=1):
    """
    Sample a batch of images from ``gan``'s generator.

    Draws ``batch_size`` standard-normal latent vectors of size
    ``gan.latent_dim`` and runs them through ``gan.G`` without tracking
    gradients. The output is clamped to [0, 1], scaled to [0, 255], rounded
    to the nearest integer and returned as a uint8 NumPy array.

    Args:
        gan: Object exposing a generator module ``G`` and an integer
            ``latent_dim``.
        batch_size (int): Number of images to generate. Defaults to 1.

    Returns:
        numpy.ndarray: uint8 array with axis 1 of the generator output moved
        to the end. Assuming ``G`` returns (batch, C, H, W), the result is
        (batch, H, W, C) — confirm against the generator.
    """
    # Sample the latent on the generator's device so a model moved to GPU works.
    device = _generator_device(gan.G)
    with torch.no_grad():
        latent = torch.randn(batch_size, gan.latent_dim, device=device)
        # Round instead of truncating so pixel values are not biased downward.
        ims = (gan.G(latent).clamp(0.0, 1.0) * 255).round()
    return ims.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)