| """ |
Sentinel Tiny Video Space — frame interpolation with a temporal Sentinel convolution
| """ |
| import gradio as gr |
| import torch |
| import torch.nn as nn |
| import numpy as np |
| from PIL import Image |
| from diffusers import UNet2DModel |
| import json |
|
|
| |
class SentinelAct(nn.Module):
    """Parameter-free Sentinel activation, computing ``x * sech(x / e)``.

    The module registers no parameters or buffers, so adding it to a model
    does not change that model's ``state_dict``.
    """

    def __init__(self):
        super().__init__()
        # Scale applied to the input before the hyperbolic secant (1 / e).
        self.inv_e = 1.0 / np.e

    def forward(self, x):
        """Return ``x`` gated elementwise by ``sech(x / e)``."""
        scaled = x * self.inv_e
        # sech(z) = 1 / cosh(z); reciprocal() is what ``1.0 / tensor`` evaluates to.
        sech = torch.cosh(scaled).reciprocal()
        return x * sech
|
|
class SentinelVideoModel(nn.Module):
    """Per-frame 2D UNet followed by a tiny temporal 3D convolution.

    Each frame of the input clip goes through ``self.spatial`` (a small
    ``UNet2DModel``) independently. The per-frame outputs are restacked along
    the frame axis, mixed by ``self.temporal`` (a ``(3, 1, 1)`` ``Conv3d`` with
    3 -> 3 channels: 27 weights + 3 biases = 30 parameters), and finally passed
    through :class:`SentinelAct`.

    NOTE(review): the UI describes the UNet as frozen, but this class does not
    itself disable gradients on ``self.spatial``; any freezing happens in code
    not shown here.
    """

    def __init__(self):
        super().__init__()
        self.spatial = UNet2DModel(
            sample_size=32,
            in_channels=3,
            out_channels=3,
            layers_per_block=1,
            block_out_channels=(32, 64),
            down_block_types=("DownBlock2D",) * 2,
            up_block_types=("UpBlock2D",) * 2,
            time_embedding_type="positional",
        )
        # The kernel spans 3 frames and 1x1 pixels, so each output pixel mixes only
        # the same pixel across neighbouring frames; padding keeps the frame count.
        self.temporal = nn.Conv3d(3, 3, kernel_size=(3, 1, 1), padding=(1, 0, 0))
        # Shared activation (x * sech(x / e)) instead of an inline copy. It has no
        # parameters, so checkpoint state_dict keys are unchanged.
        self.act = SentinelAct()

    def forward(self, x, t):
        """Refine a clip frame-by-frame, then mix information across frames.

        Args:
            x: clip tensor unpacked as ``(batch, channels, frames, height, width)``.
            t: diffusion timestep(s), forwarded unchanged to the UNet for every frame.

        Returns:
            Tensor of shape ``(batch, 3, frames, H', W')``, where ``H', W'`` are the
            UNet's output size (presumably equal to the input size for 32x32
            frames — TODO confirm for other sizes).
        """
        # x[:, :, f] is one frame for the whole batch: (batch, channels, H, W).
        per_frame = [self.spatial(x[:, :, f], t).sample for f in range(x.shape[2])]
        stacked = torch.stack(per_frame, dim=2)
        return self.act(self.temporal(stacked))
|
|
| |
# Module-level state read by interpolate_video and shown in the UI status line.
model_status = "⏳ Loading model..."
video_model = None

# Try to fetch pretrained weights from the Hugging Face Hub. Any failure
# (e.g. package missing, download error, checkpoint/architecture mismatch)
# falls back to a freshly initialised model so the demo still starts; the
# status string records which path was taken and a truncated error message.
try:
    from huggingface_hub import hf_hub_download
    model_path = hf_hub_download(repo_id="5dimension/sentinel-tiny-video", filename="model.pt")
    video_model = SentinelVideoModel()
    # weights_only=True asks torch.load to refuse arbitrary pickled objects
    # from the downloaded checkpoint.
    video_model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
    video_model.eval()
    model_status = "✅ Model loaded — 30 trainable params on frozen 2D UNet"
except Exception as e:
    video_model = SentinelVideoModel()
    video_model.eval()
    model_status = f"⚠️ Using fresh weights: {str(e)[:100]}"
|
|
| |
def _pil_to_tensor(img):
    """Return *img* as a float32 ``(3, 32, 32)`` tensor with values in [0, 1].

    The image is converted to RGB before resizing, so every input mode is
    resampled in the same colour space.
    """
    arr = np.asarray(img.convert("RGB").resize((32, 32)), dtype=np.float32) / 255.0
    return torch.from_numpy(arr).permute(2, 0, 1)


def _tensor_to_pil(chw, size):
    """Clip a ``(3, H, W)`` tensor to [0, 1] and return it as an RGB image resized to *size*."""
    hwc = np.clip(chw.permute(1, 2, 0).cpu().numpy(), 0, 1)
    return Image.fromarray((hwc * 255).astype(np.uint8)).resize(size, Image.BILINEAR)


def interpolate_video(frame1, frame2, steps=8):
    """Interpolate between two keyframes with the Sentinel video model.

    Args:
        frame1: first keyframe, a PIL image (the UI uses ``gr.Image(type="pil")``).
        frame2: last keyframe, a PIL image.
        steps: number of intervals between the keyframes. Gradio sliders may
            deliver a float, so it is coerced to ``int`` (minimum 1).

    Returns:
        ``steps + 1`` RGB PIL images of size 256x256: ``frame1``, then
        ``steps - 1`` model-refined intermediate frames, then ``frame2``.

    Raises:
        gr.Error: if either keyframe is missing; Gradio shows the message to the user.
    """
    steps = max(int(steps), 1)
    out_size = (256, 256)

    if video_model is None:
        # Placeholder with the same frame count as a real result.
        return [Image.new('RGB', out_size, color='gray') for _ in range(steps + 1)]
    if frame1 is None or frame2 is None:
        raise gr.Error("Please provide both Frame 1 and Frame 2.")

    f1 = _pil_to_tensor(frame1)
    f2 = _pil_to_tensor(frame2)

    middle = []
    if steps > 1:
        # Blend weights i / steps for i = 1..steps-1, shaped (N, 1, 1, 1) so they
        # broadcast over (C, H, W).
        alpha = (torch.arange(1, steps, dtype=torch.float32) / steps).view(-1, 1, 1, 1)
        blends = (1 - alpha) * f1 + alpha * f2  # (N, 3, 32, 32)
        n = blends.shape[0]
        # One batched forward pass for all intermediate frames. Each clip is
        # [frame1, blend, frame2] stacked on dim 2, giving (N, C, 3, H, W): the
        # (batch, channels, frames, H, W) layout SentinelVideoModel.forward unpacks.
        clips = torch.stack(
            [f1.expand(n, -1, -1, -1), blends, f2.expand(n, -1, -1, -1)], dim=2
        )
        # Same fixed timestep (500) for every clip, one entry per batch element.
        t = torch.full((n,), 500, dtype=torch.long)
        with torch.no_grad():
            out = video_model(clips, t)
        # Keep only the refined middle frame (index 1 on the frame axis) of each clip.
        middle = [_tensor_to_pil(frame, out_size) for frame in out[:, :, 1]]

    # Endpoints are shown as given, converted to RGB to match the generated frames.
    first = frame1.convert("RGB").resize(out_size, Image.BILINEAR)
    last = frame2.convert("RGB").resize(out_size, Image.BILINEAR)
    return [first, *middle, last]
|
|
| |
# Gradio UI: two keyframe inputs, a step slider, and a gallery of results.
with gr.Blocks(title="🎬 Sentinel Tiny Video", css="""
.gradio-container { max-width: 900px; margin: 0 auto; }
.title { text-align: center; font-size: 2em; font-weight: bold; color: #6b4c9a; }
.subtitle { text-align: center; color: #888; margin-bottom: 1em; }
""") as demo:
    gr.Markdown("""
<div class="title">🎬 Sentinel Tiny Video</div>
<div class="subtitle">30-parameter frame interpolation with Sentinel temporal convolution</div>
""")

    # Reports whether pretrained weights loaded or fresh weights are in use.
    gr.Markdown(f"**Status**: {model_status}")

    with gr.Row():
        with gr.Column(scale=1):
            frame1 = gr.Image(label="Frame 1", type="pil", value=None)
            frame2 = gr.Image(label="Frame 2", type="pil", value=None)
        with gr.Column(scale=1):
            steps = gr.Slider(3, 16, value=8, step=1, label="Interpolation Steps")
            generate_btn = gr.Button("🎬 Interpolate", variant="primary")

    gallery = gr.Gallery(label="Interpolated Frames", columns=4, height=200)

    with gr.Row():
        gr.Markdown("""
### About
- **Architecture**: Frozen 2D UNet + trainable 3D temporal conv
- **Trainable Params**: 30 (yes, thirty)
- **Activation**: Sentinel sech on temporal features
- **Dataset**: CIFAR-10 frame triplets
- **Input**: Two 32×32 keyframes → smooth transition
- **Full model**: [sentinel-tiny-video](https://huggingface.co/5dimension/sentinel-tiny-video)
""")

    generate_btn.click(
        fn=interpolate_video,
        inputs=[frame1, frame2, steps],
        outputs=gallery,
    )


# Launched at import time (the file is run directly as the Space entry point).
demo.launch()
|
|