Spaces:
Runtime error
Runtime error
| import json | |
| import torch | |
| from huggingnft.lightweight_gan.train import timestamped_filename | |
| from streamlit_option_menu import option_menu | |
| from huggingface_hub import hf_hub_download, file_download | |
| from PIL import Image | |
| from huggingface_hub.hf_api import HfApi | |
| import streamlit as st | |
| from huggingnft.lightweight_gan.lightweight_gan import Generator, LightweightGAN, evaluate_in_chunks, Trainer | |
| from accelerate import Accelerator | |
| from huggan.pytorch.cyclegan.modeling_cyclegan import GeneratorResNet | |
| from torchvision import transforms as T | |
| from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomCrop, RandomHorizontalFlip | |
| from torchvision.utils import make_grid | |
| import requests | |
# Client used to query the Hugging Face Hub.
hfapi = HfApi()


def _strip_owner(repo_id):
    """Return the part of *repo_id* that follows its first "/"."""
    slash_at = repo_id.index("/")
    return repo_id[slash_at + 1:]


# Names of every model published by the "huggingnft" author, with the
# leading "<owner>/" part of the repository id removed.
model_names = [
    _strip_owner(hub_model.modelId)
    for hub_model in hfapi.list_models(author="huggingnft")
]
| # streamlit-option-menu | |
| # st.set_page_config(page_title="Streamlit App Gallery", page_icon="", layout="wide") | |
| # sysmenu = ''' | |
| # <style> | |
| # #MainMenu {visibility:hidden;} | |
| # footer {visibility:hidden;} | |
| # ''' | |
| # st.markdown(sysmenu,unsafe_allow_html=True) | |
| # # Add a logo (optional) in the sidebar | |
| # logo = Image.open(r'C:\Users\13525\Desktop\Insights_Bees_logo.png') | |
| # profile = Image.open(r'C:\Users\13525\Desktop\medium_profile.png') | |
# --- Page copy -------------------------------------------------------------
# Markdown shown under the title of the corresponding page.
ABOUT_TEXT = "🤗 Hugging NFT - Generate NFT by OpenSea collection name."
CONTACT_TEXT = "Here is some contact info"
GENERATE_IMAGE_TEXT = "Text about generation"
INTERPOLATION_TEXT = "Text about Interpolation"
COLLECTION2COLLECTION_TEXT = "Text about Collection2Collection"

# --- Model-name filters ----------------------------------------------------
# Model names containing any of these substrings are dropped from the model
# list altogether.
STOPWORDS = ["-old"]

# Substrings that mark a model as a Collection2Collection model; such models
# are hidden on the generation pages and are the only ones offered on the
# Collection2Collection page.
COLLECTION2COLLECTION_KEYS = ["__2__"]
def load_lightweight_model(model_name):
    """Download a model's config from the Hub and return a loaded Trainer.

    Args:
        model_name: Full repository id in the form "<organization>/<name>".

    Returns:
        A ``Trainer`` built from the repository's ``config.json``, loaded with
        ``use_cpu=True`` and given a fresh ``Accelerator``.

    Raises:
        ValueError: If *model_name* does not contain exactly one "/".
    """
    file_path = file_download.hf_hub_download(
        repo_id=model_name,
        filename="config.json",
    )
    # Use a context manager so the config file handle is closed right away
    # instead of being left to the garbage collector.
    with open(file_path, encoding="utf-8") as config_file:
        config = json.load(config_file)
    organization_name, name = model_name.split("/")
    model = Trainer(**config, organization_name=organization_name, name=name)
    model.load(use_cpu=True)
    model.accelerator = Accelerator()
    return model
def clean_models(model_names, stopwords):
    """Return the model names that contain none of the given stopwords.

    Args:
        model_names: Iterable of model name strings.
        stopwords: Iterable of substrings; a name containing any of them is
            dropped.

    Returns:
        A new list with the surviving names in their original order.
    """
    return [
        model_name
        for model_name in model_names
        if not any(stopword in model_name for stopword in stopwords)
    ]
def get_concat_h(im1, im2):
    """Return a new RGB image with *im1* on the left and *im2* to its right.

    The canvas is as wide as both images together and as tall as *im1*; both
    images are pasted with their top edge at y = 0.
    """
    left_width = im1.width
    canvas = Image.new('RGB', (left_width + im2.width, im1.height))
    for picture, x_offset in ((im1, 0), (im2, left_width)):
        canvas.paste(picture, (x_offset, 0))
    return canvas
# Drop every model whose name contains a stopword before any page lists it.
model_names = clean_models(model_names, STOPWORDS)

# Sidebar navigation: the entries and the icon names handed to option_menu.
_menu_pages = ["About", "Generate image", "Interpolation", "Collection2Collection", "Contact"]
_menu_icons = ['house', 'camera fill', 'bi bi-youtube', 'book', 'person lines fill']

with st.sidebar:
    # `choose` holds the selected entry; the page sections below compare
    # against it to decide what to render.
    choose = option_menu(
        "Hugging NFT",
        _menu_pages,
        icons=_menu_icons,
        menu_icon="app-indicator",
        default_index=0,
        styles={"container": {"border-radius": ".0rem"}},
    )

# Centered links to the project repository plus GitHub/Twitter badges.
# The HTML is kept flush-left inside the string literal on purpose.
_sidebar_links_html = """
<style>
.aligncenter {
text-align: center;
}
</style>
<p style='text-align: center'>
<a href="https://github.com/AlekseyKorshuk/huggingnft" target="_blank">Project Repository</a>
</p>
<p class="aligncenter">
<a href="https://github.com/AlekseyKorshuk/huggingnft" target="_blank">
<img src="https://img.shields.io/github/stars/AlekseyKorshuk/huggingnft?style=social"/>
</a>
</p>
<p class="aligncenter">
<a href="https://twitter.com/alekseykorshuk" target="_blank">
<img src="https://img.shields.io/twitter/follow/alekseykorshuk?style=social"/>
</a>
</p>
"""
st.sidebar.markdown(_sidebar_links_html, unsafe_allow_html=True)
# "About" page: render the project README fetched from GitHub.
if choose == "About":
    try:
        # A timeout keeps the page from hanging forever on a stalled
        # connection; raise_for_status() stops an HTTP error body (e.g. a 404
        # page) from being rendered as if it were the README.
        _readme_response = requests.get(
            "https://raw.githubusercontent.com/AlekseyKorshuk/huggingnft/main/README.md",
            timeout=10,
        )
        _readme_response.raise_for_status()
        README = _readme_response.text
    except requests.RequestException:
        # Network problems must not crash the app; show the short blurb.
        README = ABOUT_TEXT
    # st.title(choose)
    st.markdown(README)

# "Contact" page: title plus static contact text.
if choose == "Contact":
    st.title(choose)
    st.markdown(CONTACT_TEXT)
# "Generate image" page: pick a model and generation type, then show the
# first element of what the model's generate_app() returns.
if choose == "Generate image":
    st.title(choose)
    st.markdown(GENERATE_IMAGE_TEXT)

    # Collection2Collection models are filtered out of this page's list.
    _generation_models = clean_models(model_names, COLLECTION2COLLECTION_KEYS)
    model_name = st.selectbox('Choose model:', _generation_models)
    generation_type = st.selectbox('Select generation type:', ["default", "ema"])
    nrows = st.number_input("Number of rows:", min_value=1, max_value=10, step=1, value=8)
    generate_image_button = st.button("Generate")

    if generate_image_button:
        with st.spinner(text="Downloading selected model..."):
            model = load_lightweight_model(f"huggingnft/{model_name}")

        with st.spinner(text="Generating..."):
            _generated = model.generate_app(
                num=timestamped_filename(),
                nrow=nrows,
                checkpoint=-1,
                types=generation_type,
            )
            st.image(_generated[0])
# "Interpolation" page: pick a model, then show the result of its
# generate_interpolation() call while reporting progress on a progress bar.
if choose == "Interpolation":
    st.title(choose)
    st.markdown(INTERPOLATION_TEXT)

    # Collection2Collection models are filtered out of this page's list.
    _interpolation_models = clean_models(model_names, COLLECTION2COLLECTION_KEYS)
    model_name = st.selectbox('Choose model:', _interpolation_models)
    nrows = st.number_input("Number of rows:", min_value=1, max_value=10, step=1, value=1)
    num_steps = st.number_input("Number of steps:", min_value=1, max_value=1000, step=1, value=100)
    generate_image_button = st.button("Generate")

    if generate_image_button:
        with st.spinner(text="Downloading selected model..."):
            model = load_lightweight_model(f"huggingnft/{model_name}")

        # The bar is handed to the model, then removed once the call returns.
        my_bar = st.progress(0)
        _interpolation_kwargs = dict(
            num=timestamped_filename(),
            num_image_tiles=nrows,
            num_steps=num_steps,
            save_frames=False,
            progress_bar=my_bar,
        )
        result = model.generate_interpolation(**_interpolation_kwargs)
        my_bar.empty()

        with st.spinner(text="Uploading result..."):
            st.image(result)
# "Collection2Collection" page: generate images with the source collection's
# model, translate them with the selected "<source>__2__<target>" model, and
# show each input next to its translation.
if choose == "Collection2Collection":
    st.title(choose)
    st.markdown(COLLECTION2COLLECTION_TEXT)

    # Offer only models whose name contains a Collection2Collection key.
    # Built as a list so the options keep the order of `model_names`; the
    # previous set difference produced an arbitrary option order.
    _translation_models = [
        name
        for name in model_names
        if any(key in name for key in COLLECTION2COLLECTION_KEYS)
    ]
    model_name = st.selectbox('Choose model:', _translation_models)
    nrows = st.number_input(
        "Number of images to generate:",
        min_value=1,
        max_value=10,
        step=1,
        value=1,
    )
    generate_image_button = st.button("Generate")

    if generate_image_button:
        n_channels = 3
        image_size = 256

        with st.spinner(text="Downloading selected model..."):
            translator = GeneratorResNet.from_pretrained(
                f'huggingnft/{model_name}',
                input_shape=(n_channels, image_size, image_size),
                num_residual_blocks=9,
            )

        with st.spinner(text="Downloading selected model..."):
            # The source collection's model name is the part before "__2__".
            model = load_lightweight_model(f"huggingnft/{model_name.split('__2__')[0]}")

        with st.spinner(text="Generating input images..."):
            punks = model.generate_app(
                num=timestamped_filename(),
                nrow=4,
                checkpoint=-1,
                types="default",
            )[1]

        # Resize to the translator's input resolution, then keep only as many
        # images as the user asked for (the request used to be ignored).
        # NOTE(review): assumes element [1] of generate_app() is a batch
        # tensor with images along dim 0, as the zip() below already implies.
        source_images = T.Resize((image_size, image_size))(punks)[:int(nrows)]

        with st.spinner(text="Generating output images..."):
            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                output = translator(source_images)

        to_pil = T.ToPILImage()
        for source_image, translated_image in zip(source_images, output):
            st.image(
                get_concat_h(
                    to_pil(make_grid(source_image, nrow=1, normalize=True)),
                    to_pil(make_grid(translated_image, nrow=1, normalize=True)),
                )
            )