Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| from types import SimpleNamespace | |
| import torch | |
| import torch.nn.functional as F | |
| import wandb | |
| import streamlit as st | |
| from utilities import ContextUnet, setup_ddpm | |
| from constants import WANDB_API_KEY | |
def load_model():
    """Load the trained ContextUnet from the W&B model registry.

    Downloads the registered ``Feature2Sprite:v1`` artifact, rebuilds the
    network from the hyperparameters logged by the run that produced it,
    and loads the saved weights.

    Returns:
        ContextUnet: the model in eval mode, moved to GPU when available.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    api = wandb.Api(api_key=WANDB_API_KEY)
    artifact = api.artifact(
        "teamaditya/model-registry/Feature2Sprite:v1", type="model"
    )
    model_path = Path(artifact.download())
    # The producing run's config carries the hyperparameters needed to
    # rebuild the architecture before the weights can be loaded.
    producer_run = artifact.logged_by()
    # Load weights on CPU first; the whole model is moved to `device` below.
    model_weights = torch.load(model_path / "context_model.pth", map_location="cpu")
    model = ContextUnet(
        in_channels=3,
        n_feat=producer_run.config["n_feat"],
        n_cfeat=producer_run.config["n_cfeat"],
        height=producer_run.config["height"],
    )
    model.load_state_dict(model_weights)
    # Inference-only usage: disable dropout/batchnorm training behavior.
    model.eval()
    return model.to(device)
def show_image(img):
    """Display a generated image tensor in the Streamlit app.

    Converts a (C, H, W) tensor to an (H, W, C) numpy array, clipping
    values to [-1, 1] and rescaling them to [0, 1] before rendering.
    Returns the rescaled array.
    """
    channels_last = img.permute(1, 2, 0)
    as_numpy = channels_last.clip(-1, 1).detach().cpu().numpy()
    rescaled = (as_numpy + 1) / 2
    st.image(rescaled, clamp=True)
    return rescaled
def generate_sprites(feature_vector):
    """Generate sprites conditioned on a feature vector and show the first.

    Runs the DDPM sampler with the registry model, upscales the first
    sample from 16x16 to 256x256, displays it via Streamlit, and returns
    it as a numpy array in [0, 1].

    Args:
        feature_vector: sequence of conditioning weights; its length must
            match the model's ``n_cfeat`` (assumption — confirm against
            the caller).

    Returns:
        The displayed image as a (256, 256, 3) numpy array.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    config = SimpleNamespace(
        # hyperparameters
        num_samples=30,  # samples drawn per call; only the first is shown
        # ddpm sampler hyperparameters
        timesteps=500,
        beta1=1e-4,
        beta2=0.02,
        # network hyperparameters
        height=16,  # native sprite resolution
    )
    nn_model = load_model()
    _, sample_ddpm_context = setup_ddpm(
        config.beta1, config.beta2, config.timesteps, device
    )
    # Inference only: no_grad avoids building autograd graphs across the
    # 500 sampling steps, substantially reducing memory use.
    with torch.no_grad():
        noises = torch.randn(
            config.num_samples, 3, config.height, config.height
        ).to(device)
        feature_vector = torch.tensor([feature_vector]).to(device).float()
        ddpm_samples, _ = sample_ddpm_context(nn_model, noises, feature_vector)
        # upscale the 16*16 images to 256*256
        ddpm_samples = F.interpolate(ddpm_samples, size=(256, 256), mode="bilinear")
    # show the images
    img = show_image(ddpm_samples[0])
    return img