Spaces:
Build error
Build error
| import gradio as gr | |
| import os, glob | |
| from functools import partial | |
| import glob | |
| import torch | |
| from torch import nn | |
| from PIL import Image | |
| import numpy as np | |
# Default compute device: the first CUDA GPU when one is available, else the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
class RuleCA(nn.Module):
    """Neural cellular automaton whose update is conditioned on extra "rule" channels.

    Each step perceives the 4-channel state through four fixed 3x3 filters
    (identity, Sobel-x, Sobel-y and a Laplacian-like kernel), appends the rule
    channels, runs the result through a two-layer 1x1-conv network and adds
    that output to the state on a random subset of cells.

    Args:
        hidden_n: width of the hidden 1x1-conv layer.
        rule_channels: number of rule channels concatenated to the perception.
        zero_w2: zero the last layer so a freshly built model's update is a no-op.
        device: device the filters and layers are created on (kept in self.device).
    """

    def __init__(self, hidden_n=6, rule_channels=4, zero_w2=True, device=device):
        super().__init__()
        self.chn = 4
        self.rule_channels = rule_channels
        self.device = device
        # The hard-coded filters:
        sobel_x = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]])
        filters = torch.stack([
            torch.tensor([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]),  # identity
            sobel_x,
            sobel_x.T,  # Sobel-y
            torch.tensor([[1.0, 2.0, 1.0], [2.0, -12, 2.0], [1.0, 2.0, 1.0]]),  # Laplacian-like (sums to 0)
        ]).to(device)
        # Non-persistent buffer: follows module.to()/.cuda(), but stays out of
        # state_dict so existing checkpoints (which lack it) still load strictly.
        self.register_buffer('filters', filters, persistent=False)
        # Perception has chn * 4 channels (every filter on every state channel).
        self.w1 = nn.Conv2d(self.chn * 4 + rule_channels, hidden_n, 1).to(device)
        self.relu = nn.ReLU()
        self.w2 = nn.Conv2d(hidden_n, self.chn, 1, bias=False).to(device)
        if zero_w2:
            self.w2.weight.data.zero_()

    def perchannel_conv(self, x, filters):
        """Apply every filter to every channel of x independently.

        Args:
            x: state tensor [b, ch, h, w].
            filters: [filter_n, fh, fw], with odd fh and fw.

        Returns:
            [b, ch * filter_n, h, w]; output channel c * filter_n + f holds
            filter f applied to input channel c.
        """
        b, ch, h, w = x.shape
        pad_h, pad_w = filters.shape[-2] // 2, filters.shape[-1] // 2
        y = x.reshape(b * ch, 1, h, w)
        # Circular padding wraps the grid around its edges and keeps h x w.
        y = torch.nn.functional.pad(y, [pad_w, pad_w, pad_h, pad_h], 'circular')
        y = torch.nn.functional.conv2d(y, filters[:, None])
        return y.reshape(b, -1, h, w)

    def forward(self, x, rule=0, update_rate=0.5):
        """Run one CA step with the same rule everywhere.

        Args:
            x: state tensor [b, chn, h, w].
            rule: index of the rule channel set to 1 on every cell (others 0).
            update_rate: probability that each cell is updated this step.
        """
        b, _, h, w = x.shape
        rule_grid = torch.zeros(b, self.rule_channels, h, w, device=x.device, dtype=x.dtype)
        rule_grid[:, rule] = 1
        return self.forward_w_rule_grid(x, rule_grid, update_rate)

    def forward_w_rule_grid(self, x, rule_grid, update_rate=0.5):
        """Run one CA step with a per-cell rule grid.

        Args:
            x: state tensor [b, chn, h, w].
            rule_grid: [b, rule_channels, h, w]; cast to the state's dtype/device.
            update_rate: probability that each cell is updated this step.

        Returns:
            x plus the network's output, applied only on the randomly chosen cells.
        """
        y = self.perchannel_conv(x, self.filters)  # Apply the filters
        y = torch.cat([y, rule_grid.to(device=y.device, dtype=y.dtype)], dim=1)
        y = self.w2(self.relu(self.w1(y)))  # pass the result through our 'brain'
        b, _, h, w = y.shape
        # floor(U[0, 1) + p) is 1 with probability p: a per-cell Bernoulli mask.
        update_mask = (torch.rand(b, 1, h, w, device=y.device) + update_rate).floor()
        return x + y * update_mask

    def to_rgb(self, x):
        """Map a state to RGB: its first 3 channels shifted by +0.5.

        The shift puts an all-zero state at mid-grey; no clipping is applied.
        """
        return x[..., :3, :, :] + 0.5

    def seed(self, n, sz=128):
        """Return n all-zero grids of shape [chn, sz, sz] on self.device."""
        return torch.zeros(n, self.chn, sz, sz, device=self.device)
def to_frames(video_file):
    """Extract every frame of video_file into guide_frames/0001.jpg, 0002.jpg, ...

    The guide_frames directory is deleted and recreated first, so frames from a
    previous video never mix into this run.

    Raises:
        ValueError: if no video path was given (e.g. nothing was uploaded).
        subprocess.CalledProcessError: if ffmpeg exits with an error.
        FileNotFoundError: if the ffmpeg executable is not on PATH.
    """
    import shutil
    import subprocess

    if not video_file:
        raise ValueError('No input video was provided.')
    shutil.rmtree('guide_frames', ignore_errors=True)
    os.makedirs('guide_frames')
    # Argument list rather than a shell string: the uploaded file's path may
    # contain spaces or shell metacharacters.
    subprocess.run(['ffmpeg', '-i', os.fspath(video_file), 'guide_frames/%04d.jpg'], check=True)
# Preset name -> (hidden_n of its RuleCA, checkpoint file). All use 3 rule channels.
_PRESETS = {
    'Glowing Crystals': (32, 'glowing_crystals.pt'),
    'Rainbow Diamonds': (32, 'rainbow_diamonds.pt'),
    'Dark Diamonds': (32, 'dark_diamonds.pt'),
    'Dragon Scales': (16, 'dragon_scales.pt'),
}


def _load_preset(preset):
    """Return a RuleCA with the weights of the named preset loaded.

    Raises:
        ValueError: if preset is not a known preset name (e.g. None when no
            preset was chosen in the UI).
    """
    try:
        hidden_n, ca_fn = _PRESETS[preset]
    except (KeyError, TypeError):
        raise ValueError(
            f'Unknown preset {preset!r}; choose one of {", ".join(_PRESETS)}.'
        ) from None
    ca = RuleCA(hidden_n=hidden_n, rule_channels=3)
    ca.load_state_dict(torch.load(ca_fn, map_location=device))
    return ca


def _output_size(vid_size, scale2x):
    """Return the (width, height) to render at for a video of size vid_size.

    The larger side becomes 256px (512px with scale2x) and the aspect ratio is
    kept. Both sides are then rounded down to an even number (minimum 2),
    because H.264 encoding with 4:2:0 chroma subsampling rejects odd sizes.
    """
    width, height = vid_size
    if width > height:  # Change > to < to cap the smaller side at 256px instead.
        size = (256, int(256 * (height / width)))
    else:
        size = (int(256 * (width / height)), 256)
    if scale2x:
        size = (size[0] * 2, size[1] * 2)
    return tuple(max(2, side - side % 2) for side in size)


def _load_rule_grid(path, size, enhance, target_device):
    """Load the frame at path as a [1, 3, h, w] float32 rule grid on target_device.

    Pixel values are scaled to [0, 1]. When enhance is true they are then
    mapped with x * 2 - 0.3 to exaggerate the effect.
    """
    with Image.open(path) as im:
        # convert('RGB') guarantees three channels (e.g. for greyscale JPEGs),
        # matching the presets' rule_channels=3. The with-block closes the file.
        frame = im.convert('RGB').resize(size)
    rule_grid = torch.tensor(np.array(frame) / 255).permute(2, 0, 1).unsqueeze(0)
    if enhance:
        rule_grid = rule_grid * 2 - 0.3
    return rule_grid.float().to(target_device)


def _save_state(ca, x, path):
    """Render CA state x (batch of one) to an RGB image, clipped to [0, 1], and save it."""
    img = ca.to_rgb(x).detach().cpu().clip(0, 1)[0].permute(1, 2, 0)
    Image.fromarray((img * 255).numpy().astype(np.uint8)).save(path)


def update(preset, enhance, scale2x, video_file):
    """Restyle video_file with the chosen NCA preset and return the output path.

    Frames are extracted to guide_frames/. Each resized frame is used as the
    CA's rule grid for two steps (the last frame for one). The CA state after
    each frame's first step is saved to steps/, and the steps are encoded at
    24 fps into video.mp4.

    Args:
        preset: name of a preset (see _PRESETS).
        enhance: if true, rescale rule-grid inputs for more extreme results.
        scale2x: if true, render at double resolution.
        video_file: path of the uploaded input video.

    Returns:
        'video.mp4', the path of the rendered video.

    Raises:
        ValueError: for an unknown preset or when no frames were extracted.
        subprocess.CalledProcessError: if ffmpeg fails to encode the output.
    """
    import shutil
    import subprocess

    ca = _load_preset(preset)

    # Get video frames
    to_frames(video_file)
    n_frames = len(glob.glob('guide_frames/*.jpg'))
    if n_frames == 0:
        raise ValueError('No frames could be extracted from the input video.')
    with Image.open('guide_frames/0001.jpg') as first_frame:
        size = _output_size(first_frame.size, scale2x)

    # NOTE(review): guide_frames/, steps/ and video.mp4 are fixed shared paths;
    # if requests ever run concurrently they will overwrite each other's files.
    shutil.rmtree('steps', ignore_errors=True)
    os.makedirs('steps')

    # Starting grid: all-zero state [1, chn, height, width].
    x = torch.zeros(1, ca.chn, size[1], size[0], device=ca.device)
    with torch.no_grad():
        for k in range(n_frames):
            rule_grid = _load_rule_grid(f'guide_frames/{k + 1:04}.jpg', size, enhance, ca.device)
            x = ca.forward_w_rule_grid(x, rule_grid)
            # One output image per input frame, taken after that frame's first step.
            _save_state(ca, x, f'steps/{k:05}.jpeg')
            # Every frame but the last drives a second step (2 * n_frames - 1 in total).
            if k < n_frames - 1:
                x = ca.forward_w_rule_grid(x, rule_grid)

    # Write the output video from the saved frames; yuv420p keeps it browser-playable.
    subprocess.run(
        ['ffmpeg', '-y', '-v', 'error', '-framerate', '24', '-i', 'steps/%05d.jpeg',
         '-pix_fmt', 'yuv420p', 'video.mp4'],
        check=True,
    )
    return 'video.mp4'
# Gradio UI: preset + options, input/output videos and a Run button wired to update().
# NOTE(review): `source=` on gr.Video and `enable_queue=` on launch() are gradio
# 3.x-era arguments; later gradio releases replaced them (`sources=[...]`,
# `demo.queue()`). Confirm against the gradio version this Space pins.
with gr.Blocks() as demo:
    gr.Markdown("Choose a preset below, upload a video and then click **Run** to see the output. Read [this report](https://wandb.ai/johnowhitaker/nca/reports/Fun-with-Neural-Cellular-Automata--VmlldzoyMDQ5Mjg0) for background on this project, or check out my [AI art course](https://github.com/johnowhitaker/aiaiart) for an in-depth lesson on Neural Cellular Automata like this.")

    with gr.Row():
        preset = gr.Dropdown(
            ['Glowing Crystals', 'Rainbow Diamonds', 'Dark Diamonds', 'Dragon Scales'],
            label='Preset',
        )
        with gr.Column():
            enhance = gr.Checkbox(label='Rescale inputs (more extreme results)')
            scale2x = gr.Checkbox(label='Larger output (slower)')

    with gr.Row():
        inp = gr.Video(format='mp4', source='upload', label="Input video (ideally <30s)")
        out = gr.Video(label="Output")

    btn = gr.Button("Run")
    btn.click(fn=update, inputs=[preset, enhance, scale2x, inp], outputs=out)

    with gr.Row():
        gr.Markdown("")

demo.launch(enable_queue=True)