# Hajime / hajimu.py
# Source: Hugging Face repo page for user Ai128474 — "Create hajimu.py" (commit f7c0512, verified).
# NOTE(review): these header lines were raw page text (not valid Python) and have been
# commented out so the module can be parsed.
import torch
import torch.nn as nn
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import os
import subprocess
# ---------------------------
# Model Skeleton (CPU)
# ---------------------------
class TextEncoder(nn.Module):
    """Token-sequence encoder: an embedding table followed by a 4-layer
    transformer encoder stack (4 attention heads per layer)."""

    def __init__(self, vocab_size=20000, dim=256):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=4)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=4)

    def forward(self, x):
        """Embed integer token ids *x* and run them through the transformer.

        NOTE(review): batch_first is left at its default (False), so the
        transformer interprets dim 0 as the sequence axis — confirm callers
        pass (seq, batch) rather than (batch, seq).
        """
        embedded = self.embed(x)
        return self.transformer(embedded)
class ImageEncoder(nn.Module):
    """Two-layer convolutional encoder mapping RGB images (N, 3, H, W)
    to a (N, dim, H, W) feature map; spatial size is preserved."""

    def __init__(self, dim=256):
        super().__init__()
        half = dim // 2
        self.conv = nn.Sequential(
            nn.Conv2d(3, half, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(half, dim, 3, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        """Return the (non-negative, post-ReLU) feature map for *x*."""
        return self.conv(x)
class VideoDiffusion(nn.Module):
    """Two stacked 3-D convolutions with ReLU activations.

    Shape-preserving: input and output are both (N, dim, T, H, W).
    """

    def __init__(self, dim=256):
        super().__init__()
        self.conv3d_1 = nn.Conv3d(dim, dim, 3, padding=1)
        self.conv3d_2 = nn.Conv3d(dim, dim, 3, padding=1)

    def forward(self, x):
        # Two conv+ReLU passes; padding=1 keeps every dimension unchanged.
        hidden = torch.relu(self.conv3d_1(x))
        return torch.relu(self.conv3d_2(hidden))
class VideoDecoder(nn.Module):
    """Decode a latent video (N, dim, T, H, W) to RGB frames in [0, 1].

    The first transposed conv uses stride=2 with output_padding=1, which
    doubles T, H and W; the second conv keeps size and maps to 3 channels,
    with a final Sigmoid squashing values into [0, 1].
    """

    def __init__(self, dim=256):
        super().__init__()
        half = dim // 2
        self.deconv = nn.Sequential(
            nn.ConvTranspose3d(dim, half, 3, stride=2, padding=1,
                               output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose3d(half, 3, 3, stride=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return decoded frames of shape (N, 3, 2T, 2H, 2W)."""
        return self.deconv(x)
# ---------------------------
# Hajime Pipeline
# ---------------------------
video_length = 15                   # number of latent frames to generate
fps = 10                            # playback rate for the assembled video
frame_dir = "frames"                # directory that receives rendered PNGs
output_file = "hajime_output.avi"   # target container for the final video


def _assemble_video(frames_dir, out_path, rate):
    """Best-effort: stitch ``frames_dir/frame_%03d.png`` into *out_path* with ffmpeg.

    Uses a list argv (shell=False). Silently skips assembly when ffmpeg is
    absent or fails, so the script still succeeds with just PNG frames on disk.
    """
    cmd = [
        "ffmpeg", "-y",
        "-framerate", str(rate),
        "-i", f"{frames_dir}/frame_%03d.png",
        out_path,
    ]
    try:
        subprocess.run(cmd, check=True,
                       stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    except (FileNotFoundError, subprocess.CalledProcessError):
        pass  # ffmpeg missing/failed; the frames are still available


def main():
    """Run the dummy diffusion/decode pipeline, render watermarked frames,
    append an ending frame, and (best-effort) assemble the video.

    Returns the list of written frame paths.
    """
    os.makedirs(frame_dir, exist_ok=True)

    # Initialize models. NOTE(review): the text/image encoders are built but
    # never wired into the diffusion input — placeholder for conditioning.
    text_enc = TextEncoder()
    img_enc = ImageEncoder()
    diff = VideoDiffusion()
    dec = VideoDecoder()

    # Inference on dummy noise. no_grad() is required: without it the decoder
    # output tracks gradients and frame.numpy() below raises
    # "Can't call numpy() on Tensor that requires grad".
    with torch.no_grad():
        latent_video = torch.rand(1, 256, video_length, 64, 64)
        latent_video = diff(latent_video)
        video_frames = dec(latent_video)  # [1, 3, frames, H, W]

    # Generate frames and add watermark.
    frame_paths = []
    font = ImageFont.load_default()
    frame_size = None
    for i, frame in enumerate(video_frames[0].permute(1, 2, 3, 0)):  # F x H x W x C
        img = Image.fromarray((frame.numpy() * 255).astype(np.uint8))
        frame_size = img.size  # remember decoded size for the ending frame
        draw = ImageDraw.Draw(img)
        draw.text((5, 5), "HAJIME WATERMARK", font=font, fill=(255, 0, 0))
        path = f"{frame_dir}/frame_{i:03d}.png"
        img.save(path)
        frame_paths.append(path)

    # Ending frame with guy screaming "HAJIME!!". Sized to match the decoded
    # frames — the decoder doubles H/W, so a hard-coded 64x64 frame would not
    # match the rest of the video.
    end_img = Image.new('RGB', frame_size or (64, 64), color=(0, 0, 0))
    draw = ImageDraw.Draw(end_img)
    draw.text((5, 25), "HAJIME!!", font=font, fill=(255, 255, 0))
    end_path = f"{frame_dir}/frame_end.png"
    end_img.save(end_path)
    frame_paths.append(end_path)

    _assemble_video(frame_dir, output_file, fps)
    print("Hajime video generation pipeline complete!")
    return frame_paths


if __name__ == "__main__":
    main()