| import torch |
| import torch.nn as nn |
| from PIL import Image, ImageDraw, ImageFont |
| import numpy as np |
| import os |
| import subprocess |
|
|
| |
| |
| |
class TextEncoder(nn.Module):
    """Token embedding followed by a 4-layer, 4-head Transformer encoder.

    Args:
        vocab_size: Number of rows in the embedding table.
        dim: Embedding width, also used as the Transformer ``d_model``;
            PyTorch requires it to be divisible by the 4 attention heads.
    """

    def __init__(self, vocab_size=20000, dim=256):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        # The template layer is cloned num_layers times by TransformerEncoder.
        encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=4)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=4)

    def forward(self, x):
        """Embed integer token ids ``x`` and run them through the encoder.

        NOTE(review): ``batch_first`` is left at its default (False), so PyTorch
        interprets the embedded input as (seq, batch, dim); ``x`` is therefore
        presumably (seq, batch) — confirm against callers.
        """
        tokens = self.embed(x)
        return self.transformer(tokens)
|
|
class ImageEncoder(nn.Module):
    """Two 3x3 convolutions (3 -> dim//2 -> dim channels), each followed by ReLU.

    ``padding=1`` with stride 1 keeps the spatial size unchanged.

    Args:
        dim: Number of output channels.
    """

    def __init__(self, dim=256):
        super().__init__()
        hidden = dim // 2
        self.conv = nn.Sequential(
            nn.Conv2d(3, hidden, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden, dim, 3, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply the conv stack to a (N, 3, H, W) batch (or unbatched (3, H, W))."""
        features = self.conv(x)
        return features
|
|
class VideoDiffusion(nn.Module):
    """Two channel-preserving 3x3x3 Conv3d layers, each followed by ReLU.

    No noise schedule or denoising loop appears in this class; it is a plain
    convolutional stack. With stride 1 and ``padding=1`` it preserves the shape
    of its 5-D (N, dim, D, H, W) input. In this script D is used as the time
    axis.

    Args:
        dim: Channel count of both input and output.
    """

    def __init__(self, dim=256):
        super().__init__()
        self.conv3d_1 = nn.Conv3d(dim, dim, 3, padding=1)
        self.conv3d_2 = nn.Conv3d(dim, dim, 3, padding=1)

    def forward(self, x):
        """Apply conv3d_1 -> ReLU -> conv3d_2 -> ReLU."""
        for conv in (self.conv3d_1, self.conv3d_2):
            x = torch.relu(conv(x))
        return x
|
|
class VideoDecoder(nn.Module):
    """Decode a 5-D (N, dim, D, H, W) latent into a 3-channel volume in [0, 1].

    - The first ConvTranspose3d (stride 2, padding 1, output_padding 1) doubles
      each of D, H and W.
    - The second (stride 1, padding 1) keeps the size.
    - Sigmoid bounds the values to [0, 1].

    Resulting shape: (N, 3, 2*D, 2*H, 2*W).

    NOTE(review): the time axis D is upsampled as well, so a latent of
    ``video_length`` frames decodes to twice as many frames. Confirm this is
    intended rather than a spatial-only stride such as (1, 2, 2).

    Args:
        dim: Channel count of the incoming latent.
    """

    def __init__(self, dim=256):
        super().__init__()
        half = dim // 2
        self.deconv = nn.Sequential(
            nn.ConvTranspose3d(dim, half, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose3d(half, 3, 3, stride=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run the transposed-conv stack on ``x``."""
        decoded = self.deconv(x)
        return decoded
|
|
| |
| |
| |
# Run configuration.
video_length, fps = 15, 10  # temporal length of the latent; intended playback rate
frame_dir, output_file = "frames", "hajime_output.avi"
# NOTE(review): `fps` and `output_file` are not read anywhere in the visible
# script. Presumably a final step assembling the PNGs into the .avi (e.g. via
# the imported `subprocess`) was intended — confirm.
os.makedirs(frame_dir, exist_ok=True)  # PNG frames are written here
|
|
| |
# Build every pipeline stage with its default hyper-parameters. Construction
# order is kept (left to right) so random weight initialisation is unchanged.
# NOTE(review): text_enc and img_enc are never used by the visible script —
# presumably placeholders for text/image conditioning; confirm.
text_enc, img_enc = TextEncoder(), ImageEncoder()
diff, dec = VideoDiffusion(), VideoDecoder()
|
|
| |
# Push random latent noise through the diffusion stack and decoder. No weights
# are loaded in this script, so the models use their random initialisation.
# Latent shape: (1, 256, video_length, 64, 64); with the decoder as defined,
# video_frames comes out as (1, 3, 2 * video_length, 128, 128).
#
# Inference only: without no_grad, autograd records the graph, so video_frames
# has requires_grad=True. Tensor.numpy() then raises a RuntimeError when the
# frames are exported, and the large intermediate activations stay alive.
with torch.no_grad():
    latent_video = torch.rand(1, 256, video_length, 64, 64)
    latent_video = diff(latent_video)
    video_frames = dec(latent_video)
|
|
| |
# Export every decoded frame as an 8-bit RGB PNG stamped with a red watermark.
frame_paths = []
font = ImageFont.load_default()
# video_frames[0] is (3, T, H, W) (3 output channels from the decoder);
# permuting to (T, H, W, 3) yields one H x W x RGB frame per iteration, the
# layout Image.fromarray expects.
frames_thwc = video_frames[0].permute(1, 2, 3, 0)
for i, frame in enumerate(frames_thwc):
    # detach() is required if the tensor still tracks gradients (Tensor.numpy()
    # raises otherwise); cpu() keeps this working if the models run on a GPU.
    pixels = (frame.detach().cpu().numpy() * 255).astype(np.uint8)
    img = Image.fromarray(pixels)
    draw = ImageDraw.Draw(img)
    draw.text((5, 5), "HAJIME WATERMARK", font=font, fill=(255, 0, 0))
    path = f"{frame_dir}/frame_{i:03d}.png"
    img.save(path)
    frame_paths.append(path)
|
|
| |
# Append a closing title card. It is sized from the decoded frames so every
# entry in frame_paths has the same dimensions. With the decoder as defined,
# 64x64 latents decode to 128x128 frames, so the previous hard-coded 64x64
# card did not match them.
frame_h, frame_w = video_frames.shape[-2:]
end_img = Image.new('RGB', (frame_w, frame_h), color=(0, 0, 0))
draw = ImageDraw.Draw(end_img)
# Roughly vertically centred; equals the original y=25 for 64-px-tall frames.
draw.text((5, frame_h // 2 - 7), "HAJIME!!", font=font, fill=(255, 255, 0))
end_path = f"{frame_dir}/frame_end.png"
end_img.save(end_path)
frame_paths.append(end_path)

# NOTE(review): frame_paths, fps and output_file are never consumed, so no
# .avi is actually produced despite this message. An assembly step (e.g.
# ffmpeg via the imported `subprocess`) appears to be missing — confirm.
print("Hajime video generation pipeline complete!")