#!/usr/bin/env python3
"""
================================================================================
    ███╗   ██╗███████╗██╗  ██╗██╗   ██╗███████╗
    ████╗  ██║██╔════╝╚██╗██╔╝██║   ██║██╔════╝
    ██╔██╗ ██║█████╗   ╚███╔╝ ██║   ██║███████╗
    ██║╚██╗██║██╔══╝   ██╔██╗ ██║   ██║╚════██║
    ██║ ╚████║███████╗██╔╝ ██╗╚██████╔╝███████║
    ╚═╝  ╚═══╝╚══════╝╚═╝  ╚═╝ ╚═════╝ ╚══════╝

    ██╗   ██╗██╗██████╗ ███████╗ ██████╗ ███╗   ███╗ ██████╗ ██████╗ ███████╗██╗
    ██║   ██║██║██╔══██╗██╔════╝██╔═══██╗████╗ ████║██╔═══██╗██╔══██╗██╔════╝██║
    ██║   ██║██║██║  ██║█████╗  ██║   ██║██╔████╔██║██║   ██║██║  ██║█████╗  ██║
    ╚██╗ ██╔╝██║██║  ██║██╔══╝  ██║   ██║██║╚██╔╝██║██║   ██║██║  ██║██╔══╝  ██║
     ╚████╔╝ ██║██████╔╝███████╗╚██████╔╝██║ ╚═╝ ██║╚██████╔╝██████╔╝███████╗███████╗
      ╚═══╝  ╚═╝╚═════╝ ╚══════╝ ╚═════╝ ╚═╝     ╚═╝ ╚═════╝ ╚═════╝ ╚══════╝╚══════╝

                              NUTATA v2.0 COMPLETE
           "Learning to Understand and Generate Video with Cognition"

    Full NEXUS Cognitive Architecture for Video:
      • Causal 3D VAE (Spatial + Temporal compression)
      • EARCP with Temporal Attention
      • LPOL Memory (Video-specific domains)
      • GQA for efficient long-sequence processing
      • Neurogenesis for adaptive capacity
      • Temporal Coherence Module
      • Frame Prediction & Generation
      • Flow Prediction (NEW in v2.0)
      • Hierarchical Memory (NEW in v2.0)
================================================================================
Author: Mike Amega (Logo) - Ame Web Studio
License: Proprietary (Ame Web Studio)
Version: 2.0 - Complete with Architecture Fixes
CRITICAL ARCHITECTURE FIXES (for checkpoint compatibility):
- EncoderStage/DecoderStage as nn.ModuleList (creates encoders.0.0, encoders.0.1, encoders.0.2)
- temporal_coherence.smooth as depthwise Conv1d (groups=d_model)
- LPOL memory_attn as nn.ModuleDict with q_proj, k_proj, v_proj, o_proj
- VideoExpert with fc1/fc2 Linear layers (not SwiGLU)
- attn_norm as nn.LayerNorm with elementwise_affine=False
- All projection layers with bias=False where needed
Changes from v1.0:
- Fixed RoPE dimension mismatch in VideoGQA
- Fixed LPOL memory attention mechanism
- Added FlashAttention-style efficient attention
- Added Flow Prediction module
- Improved temporal coherence
- Better gradient flow with residual scaling
- Added video quality metrics (PSNR, SSIM)
================================================================================
"""
import os
import sys
import math
import json
import random
import logging
import shutil
import tempfile
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Any, Union
from dataclasses import dataclass, field
from collections import deque
from enum import Enum
import numpy as np
# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
from torch.optim import AdamW
from tqdm import tqdm
# HuggingFace Hub
try:
from huggingface_hub import HfApi, hf_hub_download, snapshot_download, create_repo
HF_HUB_AVAILABLE = True
except ImportError:
HF_HUB_AVAILABLE = False
print(
"Warning: huggingface_hub not installed. Install with: pip install huggingface_hub"
)
import warnings
warnings.filterwarnings("ignore")
# Logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s"
)
logger = logging.getLogger(__name__)
# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ==============================================================================
# SECTION 1: CONFIGURATION
# ==============================================================================
@dataclass
class VideoConfig:
"""Video-specific configuration - v2.0 with high-resolution support"""
height: int = 256 # Upgraded from 64
width: int = 256 # Upgraded from 64
channels: int = 3
n_frames: int = 16
fps: int = 8
temporal_downsample: int = 4
spatial_downsample: int = 8
@classmethod
def from_dict(cls, d: dict) -> "VideoConfig":
valid_keys = {
"height",
"width",
"channels",
"n_frames",
"fps",
"temporal_downsample",
"spatial_downsample",
}
return cls(**{k: v for k, v in d.items() if k in valid_keys})
@dataclass
class NutataModelConfig:
"""
NUTATA v2.0 Configuration - HIGH RESOLUTION + CONDITIONING
Major improvements:
- Progressive resolution training (64 β†’ 512)
- Perceptual loss (VGG + LPIPS)
- Multi-level conditioning (text/action/scene)
- Hierarchical temporal memory
- Optical flow consistency
- Anti-blur sharpening
"""
# === Model Identity ===
model_type: str = "nutata-videomodel"
version: str = "2.0"
codename: str = "VideoSim-Cognitive-HD"
architecture_type: str = "cognitive-video-conditioned"
# === Core Dimensions ===
d_model: int = 512
d_ff: int = 2048
n_layers: int = 8
n_heads: int = 8
dropout: float = 0.1
# === Video Latent Space ===
latent_dim: int = 256
temporal_latent_dim: int = 512
latent_channels: int = 4
# === Sequence ===
max_frames: int = 64
context_frames: int = 16
prediction_frames: int = 8
# === 3D VAE (upgraded for higher resolution) ===
encoder_channels: List[int] = field(default_factory=lambda: [64, 128, 256, 512])
decoder_channels: List[int] = field(default_factory=lambda: [512, 256, 128, 64])
kl_weight: float = 0.0001
use_skip_connections: bool = True # NEW: encoder-decoder skip connections
# === Progressive Resolution Training (NEW) ===
progressive_resolution: List[int] = field(
default_factory=lambda: [64, 128, 256, 384, 512]
)
current_resolution_idx: int = 0 # Start at 64, increase during training
resolution_warmup_epochs: int = 5 # Epochs per resolution step
# === LPOL Memory (Video domains) ===
use_lpol: bool = True
memory_size: int = 256
memory_slots_per_domain: int = 32
memory_k: int = 8
domain_types: List[str] = field(
default_factory=lambda: [
"motion",
"appearance",
"temporal",
"spatial",
"object",
"scene",
"action",
"causality",
"physics",
]
)
# === GQA (Grouped Query Attention) ===
use_gqa: bool = True
gqa_num_heads: int = 8
gqa_num_kv_groups: int = 2
# === EARCP ===
expert_types: List[str] = field(
default_factory=lambda: [
"Motion",
"Appearance",
"Temporal",
"Spatial",
"Prediction",
"Generation",
]
)
max_experts: int = 12
growth_threshold_coherence: float = 0.3
growth_patience: int = 10
# === Neurogenesis ===
neurogenesis_enabled: bool = True
min_neurons: int = 32
max_neurons: int = 256
neuron_birth_threshold: float = 0.8
neuron_death_threshold: float = 0.05
# === Energy System ===
energy_enabled: bool = True
energy_cost_encode: float = 0.01
energy_cost_decode: float = 0.02
energy_cost_predict: float = 0.03
energy_regeneration: float = 0.05
# === Dream Phase ===
dream_enabled: bool = True
dream_cycle_length: int = 100
dream_duration: int = 20
# === Temporal Coherence ===
temporal_coherence_weight: float = 0.1
flow_prediction: bool = True
# === Perceptual Loss (NEW - v2.0) ===
use_perceptual_loss: bool = True
perceptual_loss_weight: float = 0.1
perceptual_adaptive: bool = True # Adaptive weight based on epoch
perceptual_max_weight: float = 0.2
perceptual_warmup_epochs: int = 10
lpips_weight: float = 0.05
vgg_layers: List[str] = field(
default_factory=lambda: ["relu1_2", "relu2_2", "relu3_4", "relu4_4"]
)
# === Optical Flow Consistency (NEW - v2.0) ===
use_optical_flow: bool = True
optical_flow_weight: float = 0.05
# === Conditioning System (NEW - v2.0) ===
use_conditioning: bool = True
text_condition_dim: int = 768 # CLIP embedding dimension
num_action_classes: int = 400 # Kinetics-400
num_scene_classes: int = 365 # Places365
condition_injection_levels: List[str] = field(
default_factory=lambda: ["latent", "decoder", "temporal", "memory"]
    )  # Conditioning is injected at each of these pipeline levels
# === Hierarchical Memory (NEW - v2.0) ===
use_hierarchical_memory: bool = True
memory_scales: List[int] = field(default_factory=lambda: [1, 4, 16])
frame_memory_slots: int = 64
clip_memory_slots: int = 32
scene_memory_slots: int = 16
# === Multi-Scale Temporal (NEW - v2.0) ===
use_multiscale_temporal: bool = True
temporal_kernel_sizes: List[int] = field(default_factory=lambda: [3, 7, 15])
# === Anti-Blur Sharpening (NEW - v2.0) ===
use_sharpening: bool = True
sharpening_weight: float = 0.1
# === Training ===
batch_size: int = 2 # Reduced for higher resolution
learning_rate: float = 1e-4
epochs: int = 50
gradient_accumulation: int = 8 # Increased for smaller batch
warmup_steps: int = 500
use_gradient_checkpointing: bool = True # Save memory
# === HuggingFace ===
push_to_hub: bool = True
hub_model_id: str = "amewebstudio/nutata-videomodel-v2.0"
# === Video Config ===
video: VideoConfig = field(default_factory=VideoConfig)
def to_dict(self) -> Dict:
d = {}
for key, value in self.__dict__.items():
if key == "video":
d[key] = {k: v for k, v in value.__dict__.items()}
elif isinstance(value, (list, dict, str, int, float, bool, type(None))):
d[key] = value
return d
@classmethod
def from_dict(cls, d: Dict) -> "NutataModelConfig":
d = d.copy()
if "video" in d and isinstance(d["video"], dict):
d["video"] = VideoConfig.from_dict(d["video"])
# Remove internal keys
for k in ["_dynamic_state", "_class_name", "_version", "_architecture"]:
d.pop(k, None)
known = set(cls.__dataclass_fields__.keys())
return cls(**{k: v for k, v in d.items() if k in known})
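# Minimal usage sketch (illustrative only; the _example_* helpers in this file
# are not part of the NUTATA API and are never called by the model). Shows the
# to_dict()/from_dict() round trip used when saving and loading configs; the
# overridden values are arbitrary.
def _example_config_roundtrip() -> None:
    cfg = NutataModelConfig(d_model=256, n_layers=4)
    restored = NutataModelConfig.from_dict(cfg.to_dict())
    assert restored.d_model == 256
    # The nested "video" dict is re-hydrated back into a VideoConfig.
    assert isinstance(restored.video, VideoConfig)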
# ==============================================================================
# SECTION 2: BASIC BUILDING BLOCKS
# ==============================================================================
class RMSNorm(nn.Module):
"""Root Mean Square Layer Normalization"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
return x * rms * self.weight
class LayerNormCustom(nn.Module):
"""Layer Normalization with optional bias"""
def __init__(self, dim: int, eps: float = 1e-6, bias: bool = True):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
mean = x.mean(-1, keepdim=True)
var = x.var(-1, keepdim=True, unbiased=False)
x = (x - mean) / torch.sqrt(var + self.eps)
x = x * self.weight
if self.bias is not None:
x = x + self.bias
return x
class SwiGLU(nn.Module):
"""SwiGLU activation"""
def __init__(
self, in_features: int, hidden_features: int = None, out_features: int = None
):
super().__init__()
hidden_features = hidden_features or in_features * 4
out_features = out_features or in_features
self.w1 = nn.Linear(in_features, hidden_features, bias=False)
self.w2 = nn.Linear(in_features, hidden_features, bias=False)
self.w3 = nn.Linear(hidden_features, out_features, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w3(F.silu(self.w1(x)) * self.w2(x))
class RotaryPositionalEmbedding(nn.Module):
"""Rotary Position Embedding (RoPE) - Fixed version"""
def __init__(self, dim: int, max_seq_len: int = 1024, base: int = 10000):
super().__init__()
self.dim = dim
self.max_seq_len = max_seq_len
self.base = base
# Precompute frequencies
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Precompute cos/sin cache
self._build_cache(max_seq_len)
def _build_cache(self, seq_len: int):
t = torch.arange(
seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype
)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat([freqs, freqs], dim=-1)
self.register_buffer("cos_cache", emb.cos(), persistent=False)
self.register_buffer("sin_cache", emb.sin(), persistent=False)
def forward(
self, seq_len: int, device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
if seq_len > self.cos_cache.shape[0]:
self._build_cache(seq_len)
return (
self.cos_cache[:seq_len].to(device),
self.sin_cache[:seq_len].to(device),
)
def rotate_half(x: torch.Tensor) -> torch.Tensor:
"""Rotate half the hidden dims of the input"""
x1, x2 = x.chunk(2, dim=-1)
return torch.cat([-x2, x1], dim=-1)
def apply_rotary_pos_emb(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary positional embedding to Q and K
Args:
q: [B, H, T, D] query tensor
k: [B, H, T, D] key tensor
cos: [T, D] cosine embeddings
sin: [T, D] sine embeddings
Returns:
q_embed, k_embed with rotary position encoding applied
"""
# Reshape cos/sin for broadcasting: [T, D] -> [1, 1, T, D]
cos = cos.unsqueeze(0).unsqueeze(0)
sin = sin.unsqueeze(0).unsqueeze(0)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
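# Minimal usage sketch (illustrative only): how the RoPE helpers above fit
# together. Per the docstrings, cos/sin come out as [T, head_dim] and broadcast
# against [B, H, T, head_dim] queries/keys; the dimensions here are arbitrary.
def _example_rope() -> None:
    head_dim, T = 64, 16
    rope = RotaryPositionalEmbedding(head_dim, max_seq_len=128)
    q = torch.randn(2, 8, T, head_dim)  # [B, H, T, D]
    k = torch.randn(2, 8, T, head_dim)
    cos, sin = rope(T, q.device)        # each [T, head_dim]
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape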
# ==============================================================================
# SECTION 3: GROUPED QUERY ATTENTION (GQA) - CHECKPOINT COMPATIBLE
# ==============================================================================
class VideoGQA(nn.Module):
"""
Grouped Query Attention for Video sequences - CHECKPOINT COMPATIBLE
Key features:
- Correct RoPE application with proper dimension handling
- Efficient memory usage with proper KV expansion
- bias=False on all projections to match checkpoint
"""
def __init__(self, config: NutataModelConfig):
super().__init__()
self.d_model = config.d_model
self.num_heads = config.gqa_num_heads
self.num_kv_groups = config.gqa_num_kv_groups
self.head_dim = config.d_model // self.num_heads
self.heads_per_group = self.num_heads // self.num_kv_groups
self.scale = self.head_dim**-0.5
# Projections - NO bias to match checkpoint
self.q_proj = nn.Linear(
config.d_model, self.num_heads * self.head_dim, bias=False
)
self.k_proj = nn.Linear(
config.d_model, self.num_kv_groups * self.head_dim, bias=False
)
self.v_proj = nn.Linear(
config.d_model, self.num_kv_groups * self.head_dim, bias=False
)
self.o_proj = nn.Linear(
self.num_heads * self.head_dim, config.d_model, bias=False
)
# RoPE
self.rope = RotaryPositionalEmbedding(
self.head_dim, max_seq_len=config.max_frames * 2
)
self.dropout = nn.Dropout(config.dropout)
        # Residual scale parameter (currently unused in forward(); presumably
        # retained so existing checkpoints that contain it still load)
        self.residual_scale = nn.Parameter(torch.ones(1) * 0.1)
def forward(
self, x: torch.Tensor, causal: bool = True, use_rope: bool = True
) -> torch.Tensor:
"""
Args:
x: [B, T, D] input tensor
causal: Whether to apply causal masking
use_rope: Whether to apply rotary position embeddings
Returns:
output: [B, T, D]
"""
B, T, D = x.shape
# Project to Q, K, V
q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim)
k = self.k_proj(x).view(B, T, self.num_kv_groups, self.head_dim)
v = self.v_proj(x).view(B, T, self.num_kv_groups, self.head_dim)
# Transpose to [B, H, T, D]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# Apply RoPE
if use_rope:
cos, sin = self.rope(T, x.device)
# Apply to Q (all heads)
q = (q * cos.unsqueeze(0).unsqueeze(0)) + (
rotate_half(q) * sin.unsqueeze(0).unsqueeze(0)
)
# Apply to K (kv groups)
k = (k * cos.unsqueeze(0).unsqueeze(0)) + (
rotate_half(k) * sin.unsqueeze(0).unsqueeze(0)
)
# Expand K, V for grouped attention
# [B, num_kv_groups, T, D] -> [B, num_heads, T, D]
k = k.repeat_interleave(self.heads_per_group, dim=1)
v = v.repeat_interleave(self.heads_per_group, dim=1)
# Compute attention scores
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
# Apply causal mask if needed
if causal:
causal_mask = torch.triu(
torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1
)
attn_weights = attn_weights.masked_fill(causal_mask, float("-inf"))
# Softmax and dropout
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_weights = self.dropout(attn_weights)
# Apply attention to values
attn_output = torch.matmul(attn_weights, v)
# Reshape back
attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, -1)
# Output projection
output = self.o_proj(attn_output)
return output
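# Minimal usage sketch (illustrative only): a causal VideoGQA pass under the
# default config (d_model=512, 8 query heads sharing 2 KV groups). The output
# shape matches the input, so the caller adds its own residual connection.
def _example_video_gqa() -> None:
    cfg = NutataModelConfig()
    attn = VideoGQA(cfg)
    x = torch.randn(2, 16, cfg.d_model)  # [B, T, D]
    out = attn(x, causal=True)
    assert out.shape == x.shape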
class MultiHeadAttention(nn.Module):
"""Standard Multi-Head Attention for memory operations"""
def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.scale = self.head_dim**-0.5
self.q_proj = nn.Linear(d_model, d_model, bias=False)
self.k_proj = nn.Linear(d_model, d_model, bias=False)
self.v_proj = nn.Linear(d_model, d_model, bias=False)
self.o_proj = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
mask: torch.Tensor = None,
) -> torch.Tensor:
B, T_q, _ = q.shape
T_k = k.shape[1]
# Project
q = self.q_proj(q).view(B, T_q, self.n_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(k).view(B, T_k, self.n_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(v).view(B, T_k, self.n_heads, self.head_dim).transpose(1, 2)
# Attention
attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
if mask is not None:
attn = attn.masked_fill(mask, float("-inf"))
attn = F.softmax(attn, dim=-1)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = out.transpose(1, 2).contiguous().view(B, T_q, -1)
return self.o_proj(out)
# ==============================================================================
# SECTION 3.5: NEW V2.0 MODULES - PERCEPTUAL LOSS, CONDITIONING, HIERARCHICAL MEMORY
# ==============================================================================
class PerceptualLossModule(nn.Module):
"""
Multi-scale perceptual loss using VGG19 features.
Computes feature-level differences for sharper, more perceptually accurate reconstructions.
Features:
- Uses pretrained VGG19 (frozen)
- Extracts features at multiple layers
- Adaptive weighting based on training epoch
"""
VGG_MEAN = [0.485, 0.456, 0.406]
VGG_STD = [0.229, 0.224, 0.225]
def __init__(self, config: NutataModelConfig):
super().__init__()
self.config = config
self.layers = config.vgg_layers
# VGG19 features (will be loaded on first use to avoid download issues)
self.vgg = None
self.layer_indices = {
"relu1_2": 4,
"relu2_2": 9,
"relu3_4": 18,
"relu4_4": 27,
"relu5_4": 36,
}
# Layer weights (higher layers = more semantic, lower = more texture)
self.layer_weights = {
"relu1_2": 1.0,
"relu2_2": 1.0,
"relu3_4": 1.0,
"relu4_4": 1.0,
"relu5_4": 0.5,
}
# Epoch tracking for adaptive weighting
self.current_epoch = 0
def _load_vgg(self, device):
"""Lazy load VGG19 to avoid issues in offline environments"""
if self.vgg is not None:
return
try:
from torchvision import models
vgg19 = models.vgg19(weights="IMAGENET1K_V1").features
vgg19.eval()
for param in vgg19.parameters():
param.requires_grad = False
self.vgg = vgg19.to(device)
logger.info("πŸ“¦ VGG19 loaded for perceptual loss")
except Exception as e:
logger.warning(f"⚠️ Could not load VGG19: {e}. Perceptual loss disabled.")
self.vgg = None
def _normalize(self, x: torch.Tensor) -> torch.Tensor:
"""Normalize for VGG (ImageNet stats)"""
mean = torch.tensor(self.VGG_MEAN, device=x.device).view(1, 3, 1, 1)
std = torch.tensor(self.VGG_STD, device=x.device).view(1, 3, 1, 1)
return (x - mean) / std
def get_adaptive_weight(self, epoch: int = None) -> float:
"""Compute adaptive weight: min(max_weight, epoch / warmup * max_weight)"""
if not self.config.perceptual_adaptive:
return self.config.perceptual_loss_weight
epoch = epoch if epoch is not None else self.current_epoch
warmup = self.config.perceptual_warmup_epochs
max_w = self.config.perceptual_max_weight
return min(max_w, (epoch / warmup) * max_w) if warmup > 0 else max_w
def forward(
self, pred: torch.Tensor, target: torch.Tensor, epoch: int = None
) -> torch.Tensor:
"""
Compute perceptual loss between predicted and target videos.
Args:
pred: [B, C, T, H, W] predicted video
target: [B, C, T, H, W] target video
epoch: current training epoch (for adaptive weighting)
Returns:
Perceptual loss (scalar)
"""
self._load_vgg(pred.device)
if self.vgg is None:
return torch.tensor(0.0, device=pred.device)
B, C, T, H, W = pred.shape
# Flatten time into batch for frame-wise processing
pred_frames = pred.transpose(1, 2).reshape(B * T, C, H, W)
target_frames = target.transpose(1, 2).reshape(B * T, C, H, W)
# Resize to VGG input size if needed
if H < 64 or W < 64:
pred_frames = F.interpolate(
pred_frames, size=(224, 224), mode="bilinear", align_corners=False
)
target_frames = F.interpolate(
target_frames, size=(224, 224), mode="bilinear", align_corners=False
)
# Normalize
pred_norm = self._normalize(pred_frames)
target_norm = self._normalize(target_frames)
# Extract features and compute loss
loss = 0.0
pred_feat = pred_norm
target_feat = target_norm
for i, layer in enumerate(self.vgg):
pred_feat = layer(pred_feat)
target_feat = layer(target_feat)
# Check if this is a layer we want
for layer_name, layer_idx in self.layer_indices.items():
if i == layer_idx and layer_name in self.layers:
weight = self.layer_weights.get(layer_name, 1.0)
loss = loss + weight * F.l1_loss(pred_feat, target_feat)
return loss * self.get_adaptive_weight(epoch)
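# Minimal usage sketch (illustrative only): the adaptive schedule ramps the
# perceptual weight linearly from 0 to perceptual_max_weight over the warmup
# epochs, then stays flat. Note that constructing the module does not download
# VGG19; that happens lazily on the first forward() call.
def _example_adaptive_perceptual_weight() -> None:
    cfg = NutataModelConfig()  # defaults: warmup=10 epochs, cap=0.2
    ploss = PerceptualLossModule(cfg)
    ramp = [ploss.get_adaptive_weight(epoch=e) for e in (0, 5, 10, 20)]
    assert ramp[0] == 0.0
    assert ramp[-1] == cfg.perceptual_max_weight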
class ConditionalEncoder(nn.Module):
"""
Multi-modal conditioning system for controllable video generation.
Supports:
- Text β†’ video (CLIP-compatible embeddings)
- Action β†’ movement (action class embeddings)
- Scene β†’ style (scene descriptor embeddings)
    Injects conditions at multiple levels of the pipeline:
- Latent space
- Decoder
- Temporal processor
- Memory
"""
def __init__(self, config: NutataModelConfig):
super().__init__()
self.config = config
d = config.d_model
# === Text Conditioning (CLIP-compatible) ===
self.text_proj = nn.Sequential(
nn.Linear(config.text_condition_dim, d), nn.SiLU(), nn.Linear(d, d)
)
# === Action Conditioning ===
self.action_embed = nn.Embedding(config.num_action_classes, d)
self.action_proj = nn.Linear(d, d)
# === Scene Conditioning ===
self.scene_embed = nn.Embedding(config.num_scene_classes, d)
self.scene_proj = nn.Linear(d, d)
# === Cross-attention for condition injection ===
self.cross_attn = MultiHeadAttention(d, 8, dropout=config.dropout)
# === Condition fusion ===
self.condition_fusion = nn.Sequential(
nn.Linear(d * 3, d), nn.SiLU(), nn.Dropout(config.dropout), nn.Linear(d, d)
)
# === Gating for residual ===
self.gate = nn.Sequential(nn.Linear(d * 2, d), nn.Sigmoid())
# === Level-specific projections ===
self.level_projs = nn.ModuleDict(
{
"latent": nn.Linear(d, d),
"decoder": nn.Linear(d, d),
"temporal": nn.Linear(d, d),
"memory": nn.Linear(d, d),
}
)
def encode_conditions(
self,
text_emb: torch.Tensor = None,
action_id: torch.Tensor = None,
scene_id: torch.Tensor = None,
batch_size: int = 1,
device: torch.device = None,
) -> torch.Tensor:
"""
Encode all conditions into a unified conditioning vector.
Args:
text_emb: [B, text_dim] or [B, T, text_dim] text embeddings (e.g., from CLIP)
action_id: [B] action class indices
scene_id: [B] scene class indices
batch_size: batch size for zero-initialization if no conditions
device: device for tensors
Returns:
condition: [B, D] unified condition vector
"""
device = device or (
text_emb.device
if text_emb is not None
else action_id.device
if action_id is not None
else scene_id.device
if scene_id is not None
else "cpu"
)
# Initialize with zeros
text_cond = torch.zeros(batch_size, self.config.d_model, device=device)
action_cond = torch.zeros(batch_size, self.config.d_model, device=device)
scene_cond = torch.zeros(batch_size, self.config.d_model, device=device)
# Encode text
if text_emb is not None:
if text_emb.dim() == 3: # [B, T, D] β†’ mean pool
text_emb = text_emb.mean(dim=1)
text_cond = self.text_proj(text_emb)
# Encode action
if action_id is not None:
action_emb = self.action_embed(action_id) # [B, D]
action_cond = self.action_proj(action_emb)
# Encode scene
if scene_id is not None:
scene_emb = self.scene_embed(scene_id) # [B, D]
scene_cond = self.scene_proj(scene_emb)
# Fuse conditions
combined = torch.cat([text_cond, action_cond, scene_cond], dim=-1)
condition = self.condition_fusion(combined) # [B, D]
return condition
def inject(
self, z: torch.Tensor, condition: torch.Tensor, level: str = "latent"
) -> torch.Tensor:
"""
Inject condition into latent representation at specified level.
Args:
z: [B, T, D] latent sequence
condition: [B, D] condition vector
level: one of 'latent', 'decoder', 'temporal', 'memory'
Returns:
z_conditioned: [B, T, D] conditioned latent
"""
B, T, D = z.shape
# Get level-specific projection (ModuleDict doesn't have .get())
proj = (
self.level_projs[level]
if level in self.level_projs
else self.level_projs["latent"]
)
cond_proj = proj(condition) # [B, D]
# Expand condition to sequence length
cond_expanded = cond_proj.unsqueeze(1).expand(-1, T, -1) # [B, T, D]
# Cross-attention: z attends to condition
z_attended = self.cross_attn(z, cond_expanded, cond_expanded)
# Gated residual
concat = torch.cat([z, z_attended], dim=-1)
gate = self.gate(concat)
z_conditioned = z + gate * z_attended
return z_conditioned
def forward(
self,
z: torch.Tensor,
text_emb: torch.Tensor = None,
action_id: torch.Tensor = None,
scene_id: torch.Tensor = None,
level: str = "latent",
) -> torch.Tensor:
"""
Full forward: encode conditions and inject into z.
"""
B = z.shape[0]
condition = self.encode_conditions(text_emb, action_id, scene_id, B, z.device)
return self.inject(z, condition, level)
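# Minimal usage sketch (illustrative only): conditioning a latent sequence on
# a text embedding and an action class. The text vector is random noise
# standing in for a real CLIP embedding, and the action ids are arbitrary
# Kinetics-400 indices.
def _example_conditioning() -> None:
    cfg = NutataModelConfig()
    cond_enc = ConditionalEncoder(cfg)
    z = torch.randn(2, 16, cfg.d_model)            # [B, T, D] latents
    text = torch.randn(2, cfg.text_condition_dim)  # stand-in CLIP embedding
    action = torch.tensor([3, 41])                 # arbitrary class ids
    z_cond = cond_enc(z, text_emb=text, action_id=action, level="decoder")
    assert z_cond.shape == z.shape  # gated residual keeps the shape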
class HierarchicalTemporalMemory(nn.Module):
"""
3-level hierarchical memory for long-term temporal coherence.
Levels:
- Frame level: Fine details, short-term (1 frame)
- Clip level: Medium-term patterns (4-8 frames)
- Scene level: Long-term context (16+ frames)
Each level maintains a memory bank that is updated and queried
to provide multi-scale temporal context.
"""
def __init__(self, config: NutataModelConfig):
super().__init__()
self.config = config
d = config.d_model
# Memory scales: [1, 4, 16] frames
self.scales = config.memory_scales
# Memory banks (learnable)
self.frame_memory = nn.Parameter(
torch.randn(config.frame_memory_slots, d) * 0.02
)
self.clip_memory = nn.Parameter(torch.randn(config.clip_memory_slots, d) * 0.02)
self.scene_memory = nn.Parameter(
torch.randn(config.scene_memory_slots, d) * 0.02
)
# Temporal pooling for different scales
self.clip_pool = nn.AvgPool1d(kernel_size=4, stride=4, padding=0)
self.scene_pool = nn.AvgPool1d(kernel_size=16, stride=16, padding=0)
# Hierarchical attention
self.frame_attn = MultiHeadAttention(d, 8, dropout=config.dropout)
self.clip_attn = MultiHeadAttention(d, 8, dropout=config.dropout)
self.scene_attn = MultiHeadAttention(d, 8, dropout=config.dropout)
# Fusion across scales
self.scale_fusion = nn.Sequential(
nn.Linear(d * 3, d), nn.SiLU(), nn.Linear(d, d)
)
# Gating
self.gate = nn.Sequential(nn.Linear(d * 2, d), nn.Sigmoid())
def forward(self, z_seq: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
"""
Process sequence through hierarchical memory.
Args:
z_seq: [B, T, D] latent sequence
Returns:
z_enhanced: [B, T, D] memory-enhanced sequence
info: Dict with memory statistics
"""
B, T, D = z_seq.shape
# === Frame-level memory (full resolution) ===
frame_mem = self.frame_memory.unsqueeze(0).expand(B, -1, -1) # [B, slots, D]
frame_ctx = self.frame_attn(z_seq, frame_mem, frame_mem) # [B, T, D]
# === Clip-level memory (pooled) ===
# Pool temporally
z_t = z_seq.transpose(1, 2) # [B, D, T]
T_clip = max(1, T // 4)
if T >= 4:
z_clip = F.adaptive_avg_pool1d(z_t, T_clip).transpose(
1, 2
) # [B, T_clip, D]
else:
z_clip = z_seq.mean(dim=1, keepdim=True) # [B, 1, D]
clip_mem = self.clip_memory.unsqueeze(0).expand(B, -1, -1)
clip_ctx_pooled = self.clip_attn(z_clip, clip_mem, clip_mem) # [B, T_clip, D]
# Upsample back to original resolution
clip_ctx = F.interpolate(
clip_ctx_pooled.transpose(1, 2), size=T, mode="linear", align_corners=False
).transpose(1, 2) # [B, T, D]
# === Scene-level memory (heavily pooled) ===
T_scene = max(1, T // 16)
if T >= 16:
z_scene = F.adaptive_avg_pool1d(z_t, T_scene).transpose(
1, 2
) # [B, T_scene, D]
else:
z_scene = z_seq.mean(dim=1, keepdim=True) # [B, 1, D]
scene_mem = self.scene_memory.unsqueeze(0).expand(B, -1, -1)
scene_ctx_pooled = self.scene_attn(
z_scene, scene_mem, scene_mem
) # [B, T_scene, D]
# Upsample back
scene_ctx = F.interpolate(
scene_ctx_pooled.transpose(1, 2), size=T, mode="linear", align_corners=False
).transpose(1, 2) # [B, T, D]
# === Fuse all scales ===
multi_scale = torch.cat([frame_ctx, clip_ctx, scene_ctx], dim=-1) # [B, T, 3D]
fused = self.scale_fusion(multi_scale) # [B, T, D]
# === Gated residual ===
concat = torch.cat([z_seq, fused], dim=-1)
gate = self.gate(concat)
z_enhanced = z_seq + gate * fused
info = {
"frame_ctx_norm": frame_ctx.norm().item(),
"clip_ctx_norm": clip_ctx.norm().item(),
"scene_ctx_norm": scene_ctx.norm().item(),
}
return z_enhanced, info
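# Minimal usage sketch (illustrative only): with T=32 latent frames the clip
# level attends at T/4=8 steps and the scene level at T/16=2 steps, each
# upsampled back to T before fusion, so the output keeps the input shape.
def _example_hierarchical_memory() -> None:
    cfg = NutataModelConfig()
    mem = HierarchicalTemporalMemory(cfg)
    z = torch.randn(2, 32, cfg.d_model)  # [B, T, D]
    z_out, info = mem(z)
    assert z_out.shape == z.shape
    assert {"frame_ctx_norm", "clip_ctx_norm", "scene_ctx_norm"} <= info.keys()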
class MultiScaleTemporalEncoder(nn.Module):
"""
Extract features at multiple temporal scales.
Captures both fast motion (3 frames) and slow motion (15 frames) patterns.
"""
def __init__(self, channels: int, config: NutataModelConfig):
super().__init__()
self.config = config
kernel_sizes = config.temporal_kernel_sizes # [3, 7, 15]
# Fast path (captures rapid motion)
self.fast_conv = CausalConv3d(channels, channels, (kernel_sizes[0], 1, 1))
# Medium path
self.medium_conv = CausalConv3d(channels, channels, (kernel_sizes[1], 1, 1))
# Slow path (captures gradual changes)
self.slow_conv = CausalConv3d(channels, channels, (kernel_sizes[2], 1, 1))
# Fusion
self.fusion = nn.Conv3d(channels * 3, channels, kernel_size=1)
# Scale parameter
self.scale = nn.Parameter(torch.tensor(0.1))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: [B, C, T, H, W] video features
Returns:
x_enhanced: [B, C, T, H, W] multi-scale enhanced features
"""
fast = self.fast_conv(x)
medium = self.medium_conv(x)
slow = self.slow_conv(x)
# Handle temporal dimension mismatches due to causal padding
T_min = min(fast.shape[2], medium.shape[2], slow.shape[2])
fast = fast[:, :, :T_min]
medium = medium[:, :, :T_min]
slow = slow[:, :, :T_min]
combined = torch.cat([fast, medium, slow], dim=1) # [B, 3C, T, H, W]
fused = self.fusion(combined) # [B, C, T, H, W]
# Pad back to original size if needed
if fused.shape[2] < x.shape[2]:
pad = x.shape[2] - fused.shape[2]
fused = F.pad(fused, (0, 0, 0, 0, pad, 0), mode="replicate")
return x + self.scale * fused
class SharpnessEnhancer(nn.Module):
"""
Learned sharpening module to counteract blur.
Uses learned high-pass filtering to enhance edges and details.
"""
def __init__(self, channels: int):
super().__init__()
# High-frequency detail extractor
self.detail_conv = nn.Sequential(
nn.Conv3d(
channels, channels, (1, 3, 3), padding=(0, 1, 1), groups=channels
),
nn.SiLU(),
nn.Conv3d(channels, channels, (1, 3, 3), padding=(0, 1, 1)),
)
# Learned Laplacian-like kernel for edge enhancement (depthwise conv)
self.edge_conv = nn.Conv3d(
channels,
channels,
(1, 3, 3),
padding=(0, 1, 1),
bias=False,
groups=channels,
)
# Initialize with Laplacian-like weights
with torch.no_grad():
laplacian = torch.tensor(
[[[0, -1, 0], [-1, 4, -1], [0, -1, 0]]], dtype=torch.float32
)
# Broadcast to all channels
self.edge_conv.weight.data = (
laplacian.unsqueeze(0).unsqueeze(0).repeat(channels, 1, 1, 1, 1) * 0.1
)
# Learnable blend factor
self.detail_scale = nn.Parameter(torch.tensor(0.1))
self.edge_scale = nn.Parameter(torch.tensor(0.05))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Apply sharpening to video.
Args:
x: [B, C, T, H, W] video tensor
Returns:
x_sharp: [B, C, T, H, W] sharpened video
"""
# Extract high-frequency details
details = self.detail_conv(x)
# Edge enhancement
edges = self.edge_conv(x)
# Blend with original
x_sharp = x + self.detail_scale * details + self.edge_scale * edges
return x_sharp
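# Minimal usage sketch (illustrative only): the enhancer is shape-preserving,
# adding scaled high-frequency detail and Laplacian edges on top of the input.
# Channel and size values below are arbitrary.
def _example_sharpness_enhancer() -> None:
    sharpen = SharpnessEnhancer(channels=3)
    x = torch.rand(1, 3, 8, 32, 32)  # [B, C, T, H, W]
    y = sharpen(x)
    assert y.shape == x.shape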
# ==============================================================================
# SECTION 4: 3D CONVOLUTION BLOCKS
# ==============================================================================
class CausalConv3d(nn.Module):
"""Causal 3D convolution - doesn't look into future frames"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Tuple[int, int, int] = (3, 3, 3),
stride: Tuple[int, int, int] = (1, 1, 1),
):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
if isinstance(stride, int):
stride = (stride, stride, stride)
self.kernel_size = kernel_size
self.stride = stride
self.temporal_pad = kernel_size[0] - 1
self.conv = nn.Conv3d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=(0, kernel_size[1] // 2, kernel_size[2] // 2),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Causal padding: only pad past frames
if self.temporal_pad > 0:
x = F.pad(x, (0, 0, 0, 0, self.temporal_pad, 0), mode="replicate")
return self.conv(x)
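# Minimal usage sketch (illustrative only): because replicate padding is
# applied only on the past side of the time axis, a stride-1 CausalConv3d
# preserves T exactly, and frame t never sees frames t+1..T. The channel
# counts here are arbitrary.
def _example_causal_conv3d() -> None:
    conv = CausalConv3d(3, 8, kernel_size=(3, 3, 3))
    x = torch.randn(1, 3, 16, 32, 32)  # [B, C, T, H, W]
    y = conv(x)
    assert y.shape == (1, 8, 16, 32, 32)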
class ResBlock3D(nn.Module):
"""3D Residual block with spatial and temporal processing"""
def __init__(self, channels: int, use_temporal: bool = True):
super().__init__()
self.use_temporal = use_temporal
# Spatial processing
self.spatial = nn.Sequential(
nn.GroupNorm(min(32, channels), channels),
nn.SiLU(),
nn.Conv3d(channels, channels, (1, 3, 3), padding=(0, 1, 1)),
nn.GroupNorm(min(32, channels), channels),
nn.SiLU(),
nn.Conv3d(channels, channels, (1, 3, 3), padding=(0, 1, 1)),
)
# Temporal processing (causal)
if use_temporal:
self.temporal = nn.Sequential(
nn.GroupNorm(min(32, channels), channels),
nn.SiLU(),
CausalConv3d(channels, channels, (3, 1, 1)),
)
# Learnable residual scale
self.scale = nn.Parameter(torch.ones(1) * 0.1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.spatial(x)
if self.use_temporal:
h = h + self.temporal(x)
return x + self.scale * h
class DownsampleBlock3D(nn.Module):
"""Downsample in space and optionally time"""
def __init__(
self, in_channels: int, out_channels: int, temporal_downsample: bool = True
):
super().__init__()
stride = (2, 2, 2) if temporal_downsample else (1, 2, 2)
self.conv = nn.Conv3d(in_channels, out_channels, 3, stride=stride, padding=1)
self.norm = nn.GroupNorm(min(32, out_channels), out_channels)
self.act = nn.SiLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.act(self.norm(self.conv(x)))
class UpsampleBlock3D(nn.Module):
"""Upsample in space and optionally time"""
def __init__(
self, in_channels: int, out_channels: int, temporal_upsample: bool = True
):
super().__init__()
self.temporal_upsample = temporal_upsample
self.conv = nn.Conv3d(in_channels, out_channels, 3, padding=1)
self.norm = nn.GroupNorm(min(32, out_channels), out_channels)
self.act = nn.SiLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
scale = (2, 2, 2) if self.temporal_upsample else (1, 2, 2)
x = F.interpolate(x, scale_factor=scale, mode="trilinear", align_corners=False)
return self.act(self.norm(self.conv(x)))
class SpatialAttention2D(nn.Module):
"""2D spatial attention within each frame"""
def __init__(self, channels: int, reduction: int = 8):
super().__init__()
reduced = max(channels // reduction, 8)
self.query = nn.Conv2d(channels, reduced, 1)
self.key = nn.Conv2d(channels, reduced, 1)
self.value = nn.Conv2d(channels, channels, 1)
self.gamma = nn.Parameter(torch.zeros(1))
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, H, W = x.shape
q = self.query(x).view(B, -1, H * W).permute(0, 2, 1)
k = self.key(x).view(B, -1, H * W)
v = self.value(x).view(B, -1, H * W)
attn = F.softmax(torch.bmm(q, k) / math.sqrt(q.size(-1)), dim=-1)
out = torch.bmm(v, attn.permute(0, 2, 1)).view(B, C, H, W)
return self.gamma * out + x
class SpatioTemporalBlock(nn.Module):
"""
Spatio-temporal processing block - CHECKPOINT COMPATIBLE
Creates structure: spatial.0 (GroupNorm), spatial.1 (SiLU), spatial.2 (Conv3d),
temporal.0, temporal.1, temporal.2
"""
def __init__(self, channels: int):
super().__init__()
self.scale = nn.Parameter(torch.ones(1))
# Spatial path: GroupNorm -> SiLU -> Conv3d
self.spatial = nn.Sequential(
nn.GroupNorm(min(8, channels), channels),
nn.SiLU(),
nn.Conv3d(channels, channels, kernel_size=(1, 3, 3), padding=(0, 1, 1)),
)
# Temporal path: GroupNorm -> SiLU -> CausalConv3d
self.temporal = nn.Sequential(
nn.GroupNorm(min(8, channels), channels),
nn.SiLU(),
CausalConv3d(channels, channels, kernel_size=(3, 1, 1)),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.spatial(x)
h = self.temporal(h)
return x + self.scale * h
# ==============================================================================
# SECTION 5: ENCODER/DECODER STAGES - CHECKPOINT COMPATIBLE (ModuleList)
# ==============================================================================
class EncoderStage(nn.ModuleList):
"""
Encoder stage as nn.ModuleList - CRITICAL FOR CHECKPOINT COMPATIBILITY
This creates the exact key structure:
- encoders.0.0 -> SpatioTemporalBlock
- encoders.0.1 -> SpatioTemporalBlock
- encoders.0.2 -> Conv3d (downsample)
"""
def __init__(self, in_channels: int, out_channels: int, temporal_down: bool = True):
stride = (2, 2, 2) if temporal_down else (1, 2, 2)
super().__init__(
[
SpatioTemporalBlock(in_channels),
SpatioTemporalBlock(in_channels),
nn.Conv3d(
in_channels, out_channels, kernel_size=3, stride=stride, padding=1
),
]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self[0](x) # First ST block
x = self[1](x) # Second ST block
x = self[2](x) # Downsample conv
return x
class DecoderStage(nn.ModuleList):
"""
Decoder stage as nn.ModuleList - CRITICAL FOR CHECKPOINT COMPATIBILITY
This creates the exact key structure:
- decoders.0.0 -> SpatioTemporalBlock
- decoders.0.1 -> SpatioTemporalBlock
- decoders.0.2 -> ConvTranspose3d (upsample)
"""
def __init__(self, in_channels: int, out_channels: int, temporal_up: bool = True):
stride = (2, 2, 2) if temporal_up else (1, 2, 2)
output_padding = (1, 1, 1) if temporal_up else (0, 1, 1)
super().__init__(
[
SpatioTemporalBlock(in_channels),
SpatioTemporalBlock(in_channels),
nn.ConvTranspose3d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
output_padding=output_padding,
),
]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self[0](x) # First ST block
x = self[1](x) # Second ST block
x = self[2](x) # Upsample conv
return x
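# Minimal usage sketch (illustrative only): EncoderStage halves T/H/W with its
# stride-2 conv (the two SpatioTemporalBlocks are shape-preserving), and
# DecoderStage inverts that with ConvTranspose3d (output_padding makes the
# doubling exact). Sizes below are arbitrary but divisible by 2.
def _example_stage_shapes() -> None:
    enc = EncoderStage(64, 128, temporal_down=True)
    dec = DecoderStage(128, 64, temporal_up=True)
    x = torch.randn(1, 64, 8, 32, 32)  # [B, C, T, H, W]
    h = enc(x)
    assert h.shape == (1, 128, 4, 16, 16)
    assert dec(h).shape == x.shape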
# ==============================================================================
# SECTION 6: VIDEO VAE ENCODER - CHECKPOINT COMPATIBLE
# ==============================================================================
class VideoVAEEncoder(nn.Module):
"""
3D VAE Encoder for video sequences - CHECKPOINT COMPATIBLE
Creates exact structure:
- input_conv
- encoders (ModuleList of EncoderStage)
- final_blocks (ModuleList of SpatioTemporalBlock)
- to_mu, to_logvar
"""
def __init__(self, config: NutataModelConfig):
super().__init__()
self.config = config
channels = config.encoder_channels # [64, 128, 256, 512]
# v2.0 Multi-Scale Temporal Encoder (Applied before entry)
self.temporal_encoder = (
MultiScaleTemporalEncoder(config.video.channels, config)
if hasattr(config, "use_multiscale_temporal")
and config.use_multiscale_temporal
else None
)
# Initial projection
self.input_conv = nn.Conv3d(
config.video.channels, channels[0], kernel_size=(1, 3, 3), padding=(0, 1, 1)
)
# Encoder stages as ModuleList of EncoderStage (which are also ModuleLists)
self.encoders = nn.ModuleList(
[
EncoderStage(channels[0], channels[1], temporal_down=True),
EncoderStage(channels[1], channels[2], temporal_down=True),
EncoderStage(channels[2], channels[3], temporal_down=False),
]
)
# Final processing blocks as ModuleList
self.final_blocks = nn.ModuleList(
[SpatioTemporalBlock(channels[-1]), SpatioTemporalBlock(channels[-1])]
)
# To latent distribution
self.to_mu = nn.Conv3d(channels[-1], config.latent_channels, 1)
self.to_logvar = nn.Conv3d(channels[-1], config.latent_channels, 1)
# Project to d_model for transformer
self.to_d_model = nn.Linear(config.latent_channels, config.d_model)
# Adaptive projection for flattened spatial dims
# Calculate expected flattened size: C_latent * H_latent * W_latent
# After 3 encoder stages with downsample: H/8, W/8, T/4
H_latent = config.video.height // 8
W_latent = config.video.width // 8
flat_size = config.latent_channels * H_latent * W_latent
self.adaptive_proj = nn.Linear(flat_size, config.d_model)
def forward(
self, x: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x: [B, C, T, H, W] or [B, T, C, H, W]
Returns:
z: [B, T', D] sampled latent sequence
mu: [B, T', D] mean
logvar: [B, T', D] log variance
z_spatial: [B, C_latent, T', H', W'] spatial latent
"""
# Handle different input formats
if x.dim() == 5:
# Check if axis 1 is channels (standard) or axis 2 (sequence first)
if (
x.shape[1] != self.config.video.channels
and x.shape[2] == self.config.video.channels
):
# [B, T, C, H, W] -> [B, C, T, H, W]
x = x.permute(0, 2, 1, 3, 4)
# v2.0 Multi-Scale Temporal Encoding (Pre-process)
if hasattr(self, "temporal_encoder") and self.temporal_encoder is not None:
x = self.temporal_encoder(x)
B, C, T, H, W = x.shape
# Encode
h = self.input_conv(x)
for encoder_stage in self.encoders:
h = encoder_stage(h)
for block in self.final_blocks:
h = block(h)
# To latent distribution
mu_spatial = self.to_mu(h)
logvar_spatial = self.to_logvar(h).clamp(-10, 10)
# Reparameterization
std = torch.exp(0.5 * logvar_spatial)
eps = torch.randn_like(std)
z_spatial = mu_spatial + eps * std
# Flatten spatial dims and project to d_model
B, C_lat, T_lat, H_lat, W_lat = z_spatial.shape
# Pool spatial dims: [B, C, T, H, W] -> [B, T, C*H*W] -> [B, T, D]
z_flat = z_spatial.permute(0, 2, 1, 3, 4) # [B, T, C, H, W]
z_flat = z_flat.reshape(B, T_lat, -1) # [B, T, C*H*W]
# Use pre-initialized adaptive projection
# Handle size mismatch gracefully (can happen with different video sizes)
if z_flat.shape[-1] != self.adaptive_proj.in_features:
            # Create a compatible projection on the fly (inference with a
            # different video size); note this fallback layer is freshly
            # initialized, i.e. untrained
proj = nn.Linear(z_flat.shape[-1], self.config.d_model).to(z_flat.device)
z = proj(z_flat)
mu_flat = mu_spatial.permute(0, 2, 1, 3, 4).reshape(B, T_lat, -1)
mu = proj(mu_flat)
logvar_flat = logvar_spatial.permute(0, 2, 1, 3, 4).reshape(B, T_lat, -1)
logvar = proj(logvar_flat)
else:
z = self.adaptive_proj(z_flat)
mu_flat = mu_spatial.permute(0, 2, 1, 3, 4).reshape(B, T_lat, -1)
mu = self.adaptive_proj(mu_flat)
logvar_flat = logvar_spatial.permute(0, 2, 1, 3, 4).reshape(B, T_lat, -1)
logvar = self.adaptive_proj(logvar_flat)
return z, mu, logvar, z_spatial
# ==============================================================================
# SECTION 7: VIDEO VAE DECODER - CHECKPOINT COMPATIBLE
# ==============================================================================
class VideoVAEDecoder(nn.Module):
"""
3D VAE Decoder for video generation - CHECKPOINT COMPATIBLE
Creates exact structure:
- from_latent
- init_blocks (ModuleList)
- decoders (ModuleList of DecoderStage)
- final_blocks (ModuleList)
- to_rgb (Sequential)
- temporal_refine (Sequential)
- refine_scale
"""
def __init__(self, config: NutataModelConfig):
super().__init__()
self.config = config
channels = config.decoder_channels # [512, 256, 128, 64]
# From latent
self.from_latent = nn.Conv3d(config.latent_channels, channels[0], 1)
# Initial processing as ModuleList
self.init_blocks = nn.ModuleList(
[SpatioTemporalBlock(channels[0]), SpatioTemporalBlock(channels[0])]
)
# Decoder stages as ModuleList of DecoderStage (which are also ModuleLists)
self.decoders = nn.ModuleList(
[
DecoderStage(channels[0], channels[1], temporal_up=True),
DecoderStage(channels[1], channels[2], temporal_up=True),
DecoderStage(channels[2], channels[3], temporal_up=False),
]
)
# Final refinement as ModuleList
self.final_blocks = nn.ModuleList(
[SpatioTemporalBlock(channels[-1]), SpatioTemporalBlock(channels[-1])]
)
# To RGB - Sequential with specific structure
self.to_rgb = nn.Sequential(
nn.Conv3d(channels[-1], channels[-1] // 2, (1, 3, 3), padding=(0, 1, 1)),
nn.SiLU(),
nn.Conv3d(
channels[-1] // 2, config.video.channels, (1, 3, 3), padding=(0, 1, 1)
),
nn.Sigmoid(),
)
# Temporal refinement - Sequential
self.temporal_refine = nn.Sequential(
CausalConv3d(config.video.channels, 32, (3, 3, 3)),
nn.SiLU(),
nn.Conv3d(32, config.video.channels, 1),
nn.Tanh(),
)
self.refine_scale = nn.Parameter(torch.tensor(0.05))
def forward(self, z: torch.Tensor) -> torch.Tensor:
"""
Args:
z: [B, C_latent, T', H', W'] spatial latent
Returns:
video: [B, C, T, H, W] reconstructed video
"""
h = self.from_latent(z)
for block in self.init_blocks:
h = block(h)
for decoder_stage in self.decoders:
h = decoder_stage(h)
for block in self.final_blocks:
h = block(h)
video = self.to_rgb(h)
# Temporal refinement
refine = self.temporal_refine(video) * self.refine_scale
video = torch.clamp(video + refine, 0, 1)
return video
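# Minimal usage sketch (illustrative only): a full encode/decode round trip at
# a small 64x64, 8-frame setting. The encoder compresses to T/4, H/8, W/8 (two
# temporal-stride stages, three spatial-stride stages) and the decoder mirrors
# it, so the reconstruction matches the input shape. With random weights the
# output is noise; this only checks the plumbing.
def _example_vae_roundtrip() -> None:
    cfg = NutataModelConfig()
    cfg.video = VideoConfig(height=64, width=64, n_frames=8)
    encoder, decoder = VideoVAEEncoder(cfg), VideoVAEDecoder(cfg)
    video = torch.rand(1, 3, 8, 64, 64)  # [B, C, T, H, W] in [0, 1]
    z_seq, mu, logvar, z_spatial = encoder(video)
    assert z_spatial.shape == (1, cfg.latent_channels, 2, 8, 8)
    recon = decoder(z_spatial)
    assert recon.shape == video.shape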
# ==============================================================================
# SECTION 8: LPOL MEMORY FOR VIDEO - CHECKPOINT COMPATIBLE
# ==============================================================================
class VideoLPOL(nn.Module):
"""
LPOL Memory for Video - CHECKPOINT COMPATIBLE
Creates exact structure:
- memories (ParameterDict with domain names)
- memory_attn (ModuleDict with q_proj, k_proj, v_proj, o_proj)
- domain_clf (Sequential)
- fusion (Sequential)
- gate (Sequential)
"""
def __init__(self, config: NutataModelConfig):
super().__init__()
self.config = config
self.n_domains = len(config.domain_types)
self.slots_per_domain = config.memory_slots_per_domain
# Domain-specific memory banks as ParameterDict
self.memories = nn.ParameterDict(
{
d: nn.Parameter(
torch.randn(self.slots_per_domain, config.d_model) * 0.02
)
for d in config.domain_types
}
)
# Memory attention as ModuleDict - CRITICAL for checkpoint compatibility
# Creates: memory_attn.q_proj, memory_attn.k_proj, memory_attn.v_proj, memory_attn.o_proj
self.memory_attn = nn.ModuleDict(
{
"q_proj": nn.Linear(config.d_model, config.d_model, bias=False),
"k_proj": nn.Linear(config.d_model, config.d_model, bias=False),
"v_proj": nn.Linear(config.d_model, config.d_model, bias=False),
"o_proj": nn.Linear(config.d_model, config.d_model, bias=False),
}
)
# Domain classifier
self.domain_clf = nn.Sequential(
nn.Linear(config.d_model, config.d_model // 2),
nn.GELU(),
nn.Dropout(config.dropout),
nn.Linear(config.d_model // 2, self.n_domains),
)
# Output fusion
self.fusion = nn.Sequential(
nn.Linear(config.d_model * 2, config.d_model),
nn.GELU(),
nn.Linear(config.d_model, config.d_model),
)
# Gating mechanism
self.gate = nn.Sequential(
nn.Linear(config.d_model * 2, config.d_model), nn.Sigmoid()
)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
"""
Args:
x: [B, T, D] temporal sequence
Returns:
output: [B, T, D] memory-augmented features
info: Dict with domain activations
"""
B, T, D = x.shape
# Classify domains based on temporal pooling
x_pooled = x.mean(dim=1) # [B, D]
domain_logits = self.domain_clf(x_pooled)
domain_probs = F.softmax(domain_logits, dim=-1) # [B, n_domains]
# Collect memories weighted by domain probability
all_memories = []
for i, domain_name in enumerate(self.config.domain_types):
mem = self.memories[domain_name] # [slots, D]
weight = domain_probs[:, i : i + 1] # [B, 1]
weighted_mem = mem.unsqueeze(0) * weight.unsqueeze(-1) # [B, slots, D]
all_memories.append(weighted_mem)
# Stack all domain memories: [B, n_domains * slots, D]
memory_bank = torch.cat(all_memories, dim=1)
# Cross-attention using ModuleDict projections
q = self.memory_attn["q_proj"](x) # [B, T, D]
k = self.memory_attn["k_proj"](memory_bank) # [B, M, D]
v = self.memory_attn["v_proj"](memory_bank) # [B, M, D]
# Compute attention
n_heads = 8
head_dim = D // n_heads
q = q.view(B, T, n_heads, head_dim).transpose(1, 2)
k = k.view(B, -1, n_heads, head_dim).transpose(1, 2)
v = v.view(B, -1, n_heads, head_dim).transpose(1, 2)
scale = head_dim**-0.5
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
attn_weights = F.softmax(attn_weights, dim=-1)
retrieved = torch.matmul(attn_weights, v)
retrieved = retrieved.transpose(1, 2).contiguous().view(B, T, D)
retrieved = self.memory_attn["o_proj"](retrieved)
# Gated fusion
concat = torch.cat([x, retrieved], dim=-1)
gate = self.gate(concat)
fused = self.fusion(concat)
output = x + gate * fused
return output, {
"domain_probs": domain_probs,
"top_domain": domain_probs.argmax(dim=-1),
"domain_names": list(self.memories.keys()),
}
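# Minimal usage sketch (illustrative only): the memory bank seen by the
# cross-attention holds all 9 domains x 32 slots = 288 entries, softly
# weighted by the domain classifier, and the gated fusion keeps the sequence
# shape unchanged.
def _example_lpol_memory() -> None:
    cfg = NutataModelConfig()
    lpol = VideoLPOL(cfg)
    x = torch.randn(2, 16, cfg.d_model)  # [B, T, D]
    out, info = lpol(x)
    assert out.shape == x.shape
    assert info["domain_probs"].shape == (2, len(cfg.domain_types))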
# ==============================================================================
# SECTION 9: TEMPORAL TRANSFORMER WITH EARCP - CHECKPOINT COMPATIBLE
# ==============================================================================
class VideoExpert(nn.Module):
"""
Expert network for video processing - CHECKPOINT COMPATIBLE
Creates exact structure:
- confidence (Sequential)
- gate (Linear)
- fc1 (Linear)
- fc2 (Linear)
- dropout
"""
def __init__(self, config: NutataModelConfig, expert_type: str):
super().__init__()
self.expert_type = expert_type
# Confidence estimation
self.confidence = nn.Sequential(
nn.Linear(config.d_model, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()
)
# Gate
self.gate = nn.Linear(config.d_model, config.d_model)
# FFN with fc1/fc2 (NOT SwiGLU) - CRITICAL for checkpoint
self.fc1 = nn.Linear(config.d_model, config.d_ff)
self.fc2 = nn.Linear(config.d_ff, config.d_model)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x: [B, T, D]
Returns:
output: [B, T, D]
confidence: [B, 1]
"""
# Confidence from pooled representation
conf = self.confidence(x.mean(dim=1)) # [B, 1]
# Gated FFN
gate_val = torch.sigmoid(self.gate(x)) # [B, T, D]
h = self.dropout(F.gelu(self.fc1(x))) # [B, T, d_ff]
h = self.fc2(h) # [B, T, D]
out = h * gate_val
return out, conf
class VideoEARCPLayer(nn.Module):
"""
EARCP Layer for video with temporal attention and expert routing - CHECKPOINT COMPATIBLE
Creates exact structure:
- attn_norm (LayerNorm with elementwise_affine=False)
- attn_scale (Parameter)
- temporal_attn (VideoGQA)
- experts (ModuleList of VideoExpert)
- router (Linear)
- low_coh_count (buffer)
"""
def __init__(
self, config: NutataModelConfig, layer_idx: int, n_experts: int = None
):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if n_experts is None:
n_experts = len(config.expert_types)
# Attention norm - elementwise_affine=False for checkpoint compatibility
self.attn_norm = nn.LayerNorm(config.d_model, elementwise_affine=False)
# Attention scale parameter
self.attn_scale = nn.Parameter(torch.ones(1))
# Temporal attention
self.temporal_attn = VideoGQA(config)
# Experts as ModuleList
self.experts = nn.ModuleList(
[
VideoExpert(
config,
config.expert_types[i]
if i < len(config.expert_types)
else f"Hybrid_{i}",
)
for i in range(n_experts)
]
)
# Router
self.router = nn.Linear(config.d_model, n_experts)
# Growth tracking
self.register_buffer("low_coh_count", torch.tensor(0))
# Store coherence for diagnostics
self.coherence_score = 0.5
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, float, bool]:
"""
Args:
x: [B, T, D]
Returns:
output: [B, T, D]
coherence: float
grew: bool (whether new expert was added)
"""
# Temporal attention with residual
h = self.attn_norm(x)
attn_out = self.temporal_attn(h, causal=True)
x = x + self.attn_scale * attn_out
# Expert routing based on temporal mean
router_input = x.mean(dim=1) # [B, D]
router_logits = self.router(router_input) # [B, n_experts]
weights = F.softmax(router_logits, dim=-1) # [B, n_experts]
# Apply experts
expert_outputs = []
confs = []
for expert in self.experts:
out, conf = expert(x)
expert_outputs.append(out)
confs.append(conf)
# Stack and weight
expert_outputs = torch.stack(expert_outputs, dim=1) # [B, n_experts, T, D]
weighted_out = torch.einsum(
"be,betd->btd", weights, expert_outputs
) # [B, T, D]
x = x + weighted_out
# Compute coherence
confs_tensor = torch.stack(confs, dim=1) # [B, n_experts, 1]
expert_conf = confs_tensor.mean().item()
# Routing focus (inverse entropy)
entropy = -(weights * weights.log().clamp(min=-100)).sum(dim=-1).mean().item()
max_entropy = math.log(len(self.experts)) if len(self.experts) > 1 else 1.0
routing_focus = 1 - (entropy / max_entropy)
coherence = 0.5 * expert_conf + 0.5 * routing_focus
self.coherence_score = coherence
# Dynamic growth
grew = False
if coherence < self.config.growth_threshold_coherence:
self.low_coh_count += 1
if (
self.low_coh_count >= self.config.growth_patience
and len(self.experts) < self.config.max_experts
):
new_expert = VideoExpert(self.config, f"Hybrid_{len(self.experts)}").to(
x.device
)
self.experts.append(new_expert)
# Expand router
old_router = self.router
self.router = nn.Linear(self.config.d_model, len(self.experts)).to(
x.device
)
with torch.no_grad():
self.router.weight[: old_router.out_features] = old_router.weight
self.router.bias[: old_router.out_features] = old_router.bias
self.low_coh_count.zero_()
grew = True
else:
self.low_coh_count.zero_()
return x, coherence, grew
def get_expert_count(self) -> int:
return len(self.experts)
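# Minimal usage sketch (illustrative only): one EARCP layer pass. Coherence
# blends mean expert confidence with routing focus, so it lands in [0, 1];
# growth only triggers after growth_patience consecutive low-coherence steps,
# so a single call always returns grew=False.
def _example_earcp_layer() -> None:
    cfg = NutataModelConfig()
    layer = VideoEARCPLayer(cfg, layer_idx=0)
    x = torch.randn(2, 16, cfg.d_model)
    out, coherence, grew = layer(x)
    assert out.shape == x.shape
    assert 0.0 <= coherence <= 1.0 and grew is False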
# ==============================================================================
# SECTION 10: NEUROGENESIS FOR VIDEO - CHECKPOINT COMPATIBLE
# ==============================================================================
class VideoNeurogenesis(nn.Module):
"""
Neurogenesis layer - adapts capacity based on temporal complexity
CHECKPOINT COMPATIBLE
Creates exact structure:
- weights (Parameter)
- bias (Parameter)
- temporal_gate (Linear)
- n_neurons, usage, lifetime, births, deaths (buffers)
"""
def __init__(self, input_dim: int, n_neurons: int, config: NutataModelConfig):
super().__init__()
self.config = config
self.input_dim = input_dim
self.weights = nn.Parameter(torch.randn(n_neurons, input_dim) * 0.02)
self.bias = nn.Parameter(torch.zeros(n_neurons))
self.temporal_gate = nn.Linear(input_dim, n_neurons)
self.register_buffer("n_neurons", torch.tensor(n_neurons))
self.register_buffer("usage", torch.ones(n_neurons))
self.register_buffer("lifetime", torch.zeros(n_neurons))
self.register_buffer("births", torch.tensor(0))
self.register_buffer("deaths", torch.tensor(0))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: [B, T, D]
Returns:
output: [B, T, n_neurons]
"""
n = self.n_neurons.item()
# Temporal gating
gate = torch.sigmoid(self.temporal_gate(x))
gate = gate[..., :n]
# Neuron activation
out = torch.tanh(F.linear(x, self.weights[:n], self.bias[:n]))
out = out * gate
# Track usage
with torch.no_grad():
act = out.abs().mean(dim=(0, 1))
if act.size(0) == n:
self.usage[:n] = 0.99 * self.usage[:n] + 0.01 * act
self.lifetime[:n] += 1
return out
def maybe_grow(self, coherence: float) -> int:
"""Attempt to grow neurons based on coherence"""
if not self.config.neurogenesis_enabled:
return 0
n = self.n_neurons.item()
if n >= self.config.max_neurons:
return 0
if coherence < self.config.neuron_birth_threshold:
return 0
device = self.weights.device
with torch.no_grad():
new_w = torch.randn(1, self.input_dim, device=device) * 0.02
new_b = torch.zeros(1, device=device)
self.weights = nn.Parameter(torch.cat([self.weights.data, new_w], dim=0))
self.bias = nn.Parameter(torch.cat([self.bias.data, new_b]))
# Expand temporal gate
old_gate = self.temporal_gate
self.temporal_gate = nn.Linear(self.input_dim, n + 1).to(device)
self.temporal_gate.weight.data[:n] = old_gate.weight.data
self.temporal_gate.weight.data[n:] = (
torch.randn(1, self.input_dim, device=device) * 0.02
)
self.temporal_gate.bias.data[:n] = old_gate.bias.data
self.temporal_gate.bias.data[n:] = 0
self.usage = torch.cat([self.usage, torch.ones(1, device=device)])
self.lifetime = torch.cat([self.lifetime, torch.zeros(1, device=device)])
self.n_neurons += 1
self.births += 1
return 1
def maybe_prune(self) -> int:
"""Prune low-usage neurons"""
if not self.config.neurogenesis_enabled:
return 0
n = self.n_neurons.item()
if n <= self.config.min_neurons:
return 0
# Find neurons to prune
threshold = self.config.neuron_death_threshold
prune_mask = self.usage[:n] < threshold
if not prune_mask.any():
return 0
keep_mask = ~prune_mask
n_keep = keep_mask.sum().item()
if n_keep < self.config.min_neurons:
return 0
device = self.weights.device
with torch.no_grad():
self.weights = nn.Parameter(self.weights.data[keep_mask])
self.bias = nn.Parameter(self.bias.data[keep_mask])
old_gate = self.temporal_gate
self.temporal_gate = nn.Linear(self.input_dim, n_keep).to(device)
self.temporal_gate.weight.data = old_gate.weight.data[keep_mask]
self.temporal_gate.bias.data = old_gate.bias.data[keep_mask]
self.usage = self.usage[keep_mask]
self.lifetime = self.lifetime[keep_mask]
pruned = n - n_keep
self.n_neurons.fill_(n_keep)
self.deaths += pruned
return pruned
def resize(self, target_neurons: int):
"""
Resize the neurogenesis layer to a specific number of neurons.
Used when loading pretrained models with different neuron counts.
"""
current = self.n_neurons.item()
if target_neurons == current:
return
device = self.weights.device
if target_neurons > current:
# Grow neurons
extra = target_neurons - current
new_w = torch.randn(extra, self.input_dim, device=device) * 0.02
new_b = torch.zeros(extra, device=device)
self.weights = nn.Parameter(torch.cat([self.weights.data, new_w], dim=0))
self.bias = nn.Parameter(torch.cat([self.bias.data, new_b]))
# Expand temporal gate
old_gate = self.temporal_gate
self.temporal_gate = nn.Linear(self.input_dim, target_neurons).to(device)
with torch.no_grad():
self.temporal_gate.weight[:current] = old_gate.weight
self.temporal_gate.bias[:current] = old_gate.bias
self.usage = torch.cat([self.usage, torch.ones(extra, device=device)])
self.lifetime = torch.cat(
[self.lifetime, torch.zeros(extra, device=device)]
)
else:
# Shrink neurons (keep most active ones)
keep_indices = torch.argsort(self.usage, descending=True)[:target_neurons]
self.weights = nn.Parameter(self.weights.data[keep_indices])
self.bias = nn.Parameter(self.bias.data[keep_indices])
# Shrink temporal gate
old_gate = self.temporal_gate
self.temporal_gate = nn.Linear(self.input_dim, target_neurons).to(device)
with torch.no_grad():
self.temporal_gate.weight[:] = old_gate.weight[keep_indices]
self.temporal_gate.bias[:] = old_gate.bias[keep_indices]
self.usage = self.usage[keep_indices]
self.lifetime = self.lifetime[keep_indices]
self.n_neurons.fill_(target_neurons)
def get_stats(self) -> Dict:
return {
"total_neurons": self.n_neurons.item(),
"total_births": self.births.item(),
"total_deaths": self.deaths.item(),
"avg_usage": self.usage[: self.n_neurons.item()].mean().item(),
"max_lifetime": self.lifetime[: self.n_neurons.item()].max().item(),
}
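# Minimal usage sketch for VideoNeurogenesis (illustrative only; NutataModel
# wires this up in process_temporal, and the shapes below are assumptions).
def _example_neurogenesis_cycle(config: "NutataModelConfig"):
    layer = VideoNeurogenesis(input_dim=config.d_model, n_neurons=64, config=config)
    x = torch.randn(2, 8, config.d_model)   # [B, T, D]
    out = layer(x)                           # [B, T, n_neurons]
    born = layer.maybe_grow(coherence=0.9)   # grows only at/above birth threshold
    pruned = layer.maybe_prune()             # drops neurons below usage threshold
    return out, born, pruned, layer.get_stats()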
# ==============================================================================
# SECTION 11: ENERGY SYSTEM - CHECKPOINT COMPATIBLE
# ==============================================================================
class VideoEnergySystem(nn.Module):
"""
Energy system for cognitive load tracking - CHECKPOINT COMPATIBLE
Creates exact structure:
- energy (buffer)
- consumed (buffer)
"""
def __init__(self, config: NutataModelConfig):
super().__init__()
self.config = config
self.register_buffer("energy", torch.tensor(1.0))
self.register_buffer("consumed", torch.tensor(0.0))
self.costs = {
"encode": config.energy_cost_encode,
"decode": config.energy_cost_decode,
"predict": config.energy_cost_predict,
"process": 0.01,
"memory": 0.005,
"attention": 0.008,
}
def consume(self, operation: str, amount: float = None) -> bool:
"""Consume energy for an operation"""
        cost = amount if amount is not None else self.costs.get(operation, 0.01)
if self.energy.item() >= cost:
self.energy -= cost
self.consumed += cost
return True
return False
def regenerate(self):
"""Regenerate energy"""
regen = min(self.config.energy_regeneration, 1.0 - self.energy.item())
self.energy += regen
def reset(self):
"""Reset energy to full"""
self.energy.fill_(1.0)
def get_stats(self) -> Dict:
return {
"energy": self.energy.item(),
"consumed": self.consumed.item(),
"efficiency": 1.0
- self.consumed.item()
/ max(1.0, self.consumed.item() + self.energy.item()),
}
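# Usage sketch for the energy budget (illustrative): consume() succeeds only
# while enough energy remains; regenerate() restores it between steps.
def _example_energy_budget(config: "NutataModelConfig"):
    energy = VideoEnergySystem(config)
    ok = energy.consume("encode")  # deducts config.energy_cost_encode
    energy.regenerate()            # restores up to config.energy_regeneration
    return ok, energy.get_stats()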
# ==============================================================================
# SECTION 12: TEMPORAL COHERENCE MODULE - CHECKPOINT COMPATIBLE
# ==============================================================================
class TemporalCoherenceModule(nn.Module):
"""
Ensures temporal coherence across frames - CHECKPOINT COMPATIBLE
Creates exact structure:
- diff_predictor (Sequential)
- smooth (Conv1d with groups=d_model) - DEPTHWISE CONV
- alpha (Parameter)
- coherence_history, history_idx (buffers)
"""
def __init__(self, config: NutataModelConfig):
super().__init__()
self.config = config
d = config.d_model
# Frame difference predictor
self.diff_predictor = nn.Sequential(
nn.Linear(d * 2, d), nn.SiLU(), nn.Linear(d, d)
)
# Temporal smoothing - DEPTHWISE Conv1d (groups=d_model) - CRITICAL!
self.smooth = nn.Conv1d(d, d, kernel_size=3, padding=1, groups=d)
# Learnable blending parameter
self.alpha = nn.Parameter(torch.tensor(0.2))
# History tracking
self.register_buffer("coherence_history", torch.zeros(100))
self.register_buffer("history_idx", torch.tensor(0))
def forward(self, z_seq: torch.Tensor) -> Tuple[torch.Tensor, float]:
"""
Args:
z_seq: [B, T, D] latent sequence
Returns:
output: [B, T, D] temporally coherent sequence
coherence: float score
"""
B, T, D = z_seq.shape
# Compute frame differences
if T > 1:
diffs = z_seq[:, 1:] - z_seq[:, :-1]
pairs = torch.cat([z_seq[:, :-1], z_seq[:, 1:]], dim=-1)
pred_diffs = self.diff_predictor(pairs)
coherence = 1 - F.mse_loss(pred_diffs, diffs).item()
coherence = max(0, min(1, coherence))
else:
coherence = 1.0
# Temporal smoothing with depthwise conv
z_t = z_seq.transpose(1, 2) # [B, D, T]
smoothed = self.smooth(z_t).transpose(1, 2) # [B, T, D]
# Adaptive blending
alpha = torch.sigmoid(self.alpha)
output = (1 - alpha) * z_seq + alpha * smoothed
# Update history
idx = self.history_idx.item() % 100
self.coherence_history[idx] = coherence
self.history_idx += 1
return output, coherence
def get_average_coherence(self) -> float:
valid = min(self.history_idx.item(), 100)
if valid == 0:
return 0.0
return self.coherence_history[:valid].mean().item()
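# Usage sketch: smoothing a latent sequence with the depthwise temporal conv
# and reading the rolling coherence average (shapes are assumptions).
def _example_temporal_coherence(config: "NutataModelConfig"):
    module = TemporalCoherenceModule(config)
    z = torch.randn(2, 16, config.d_model)  # [B, T, D]
    z_smooth, coherence = module(z)
    return z_smooth, coherence, module.get_average_coherence()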
# ==============================================================================
# SECTION 13: FLOW PREDICTION MODULE - CHECKPOINT COMPATIBLE
# ==============================================================================
class OpticalFlowConsistencyModule(nn.Module):
"""
Real Optical Flow Consistency - v2.0
Computes dense optical flow between frames and enforces temporal consistency.
"""
def __init__(self, config: NutataModelConfig):
super().__init__()
self.config = config
channels = config.video.channels
# Simple FlowNet (Learnable Flow)
self.flow_net = nn.Sequential(
nn.Conv2d(channels * 2, 32, kernel_size=7, padding=3),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(32, 2, kernel_size=3, padding=1),
            nn.Tanh(),  # bound the raw flow to [-1, 1]; warp() scales it to pixels
        )
def warp(self, x: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
"""
Warp image x using flow.
x: [B, C, H, W]
flow: [B, 2, H, W]
"""
B, C, H, W = x.shape
# Create meshgrid
base_grid = torch.meshgrid(
torch.arange(H, device=x.device),
torch.arange(W, device=x.device),
indexing="ij",
)
base_grid = torch.stack(base_grid[::-1], dim=0).float() # [2, H, W] (x, y)
base_grid = base_grid.unsqueeze(0).expand(B, -1, -1, -1) # [B, 2, H, W]
        # Add flow: flow_net outputs a normalized shift in [-1, 1] (Tanh above),
        # so scale it to a pixel displacement (up to ~20 px) before normalizing
        # the grid to the [-1, 1] coordinates grid_sample expects.
        final_grid = base_grid + flow * 20.0
# Normalize to [-1, 1]
final_grid[:, 0] = 2.0 * final_grid[:, 0] / max(W - 1, 1) - 1.0
final_grid[:, 1] = 2.0 * final_grid[:, 1] / max(H - 1, 1) - 1.0
final_grid = final_grid.permute(0, 2, 3, 1) # [B, H, W, 2]
return F.grid_sample(
x, final_grid, mode="bilinear", padding_mode="border", align_corners=True
)
def forward(self, video: torch.Tensor) -> Dict:
"""
Args:
video: [B, C, T, H, W]
Returns:
Dict with flow_loss
"""
        if video.dim() == 3:  # [B, T, D] latent input; flow consistency is pixel-space only
return {}
B, C, T, H, W = video.shape
if T < 2:
return {"flow_loss": torch.tensor(0.0, device=video.device)}
# Pairwise flow
# Reshape to [B*(T-1), C*2, H, W]
frames_t = video[:, :, :-1].transpose(1, 2).reshape(-1, C, H, W)
frames_t1 = video[:, :, 1:].transpose(1, 2).reshape(-1, C, H, W)
flow_input = torch.cat([frames_t, frames_t1], dim=1)
        # Flow is computed at full resolution; downsample here if speed becomes an issue
flow = self.flow_net(flow_input) # [B*(T-1), 2, H, W]
# Warp t to t+1
warped_t = self.warp(frames_t, flow)
        # Photometric loss: plain L1 between the warped frame t and frame t+1
diff = (warped_t - frames_t1).abs()
loss = diff.mean()
return {
"flow_loss": loss,
"flow": flow,
"warped": warped_t.reshape(B, T - 1, C, H, W).transpose(1, 2),
}
# Alias for compatibility (NutataModel uses this name)
FlowPredictionModule = OpticalFlowConsistencyModule
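# Usage sketch for the flow consistency term (illustrative; the 32x32 clip
# below is an assumption). Latent [B, T, D] input returns an empty dict.
def _example_flow_consistency(config: "NutataModelConfig"):
    flow_module = OpticalFlowConsistencyModule(config)
    video = torch.rand(1, config.video.channels, 4, 32, 32)  # [B, C, T, H, W]
    out = flow_module(video)
    return out["flow_loss"]  # photometric L1 between warped frame t and frame t+1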
# ==============================================================================
# SECTION 14: VIDEO QUALITY METRICS
# ==============================================================================
def compute_psnr(pred: torch.Tensor, target: torch.Tensor) -> float:
"""Compute Peak Signal-to-Noise Ratio"""
mse = F.mse_loss(pred, target)
if mse == 0:
return float("inf")
return 10 * math.log10(1.0 / mse.item())
def compute_ssim(pred: torch.Tensor, target: torch.Tensor) -> float:
"""Compute simplified SSIM for videos"""
# Flatten spatial dims
pred_flat = pred.flatten(2)
target_flat = target.flatten(2)
mu_pred = pred_flat.mean(dim=-1)
mu_target = target_flat.mean(dim=-1)
sigma_pred = pred_flat.var(dim=-1)
sigma_target = target_flat.var(dim=-1)
sigma_pred_target = (
(pred_flat - mu_pred.unsqueeze(-1)) * (target_flat - mu_target.unsqueeze(-1))
).mean(dim=-1)
    C1 = 0.01**2  # SSIM stability constants, assuming inputs in [0, 1]
    C2 = 0.03**2
ssim = ((2 * mu_pred * mu_target + C1) * (2 * sigma_pred_target + C2)) / (
(mu_pred**2 + mu_target**2 + C1) * (sigma_pred + sigma_target + C2)
)
return ssim.mean().item()
def compute_temporal_consistency(video: torch.Tensor) -> float:
"""Compute temporal consistency score"""
if video.shape[2] < 2: # T dimension
return 1.0
# Compute frame differences
diffs = (video[:, :, 1:] - video[:, :, :-1]).abs()
# Consistency is inverse of average difference
consistency = 1.0 - diffs.mean().item()
return max(0, min(1, consistency))
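# Usage sketch: scoring a reconstruction with the metrics above (illustrative
# random tensors; both inputs are assumed to lie in [0, 1]).
def _example_video_metrics():
    pred = torch.rand(1, 3, 8, 32, 32)    # [B, C, T, H, W]
    target = torch.rand(1, 3, 8, 32, 32)
    return {
        "psnr": compute_psnr(pred, target),
        "ssim": compute_ssim(pred, target),
        "temporal_consistency": compute_temporal_consistency(pred),
    }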
# ==============================================================================
# SECTION 14b: SPARSE COMPRESSION SYSTEM
# ==============================================================================
class SparseCompressor(nn.Module):
"""Learned top-k feature selection with straight-through estimator."""
def __init__(self, input_dim, output_dim, top_k=64, temporal_aware=False, use_ste=True):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.top_k = min(top_k, input_dim)
self.temporal_aware = temporal_aware
self.use_ste = use_ste
bottleneck = max(input_dim // 4, 32)
self.importance_net = nn.Sequential(
nn.Linear(input_dim, bottleneck), nn.SiLU(),
nn.Linear(bottleneck, input_dim), nn.Sigmoid(),
)
if temporal_aware:
self.temporal_context = nn.Sequential(
nn.Linear(input_dim, bottleneck), nn.SiLU(),
nn.Linear(bottleneck, input_dim), nn.Sigmoid(),
)
self.compress = nn.Linear(input_dim, output_dim)
self.norm = nn.LayerNorm(output_dim)
self.residual_scale = nn.Parameter(torch.tensor(0.1))
self.register_buffer("selection_freq", torch.zeros(input_dim), persistent=False)
self.register_buffer("forward_count", torch.tensor(0, dtype=torch.long), persistent=False)
def forward(self, x):
is_sequence = x.dim() == 3
if is_sequence:
B, T, D = x.shape
x_flat = x.reshape(B * T, D)
else:
x_flat = x
B, T = x.shape[0], 1
importance = self.importance_net(x_flat)
if self.temporal_aware and is_sequence:
x_ctx = x.mean(dim=1, keepdim=True).expand_as(x).reshape(B * T, D)
temporal_mod = self.temporal_context(x_ctx)
importance = importance * temporal_mod
_, top_indices = importance.topk(self.top_k, dim=-1)
mask = torch.zeros_like(importance).scatter_(-1, top_indices, 1.0)
if self.use_ste and self.training:
sparse_x = x_flat * (mask - importance).detach() + x_flat * importance
else:
sparse_x = x_flat * mask
compressed = self.compress(sparse_x)
output = self.norm(compressed)
if self.training:
with torch.no_grad():
self.selection_freq += mask.sum(dim=0)
self.forward_count += 1
if is_sequence:
output = output.reshape(B, T, self.output_dim)
return output
def get_selection_stats(self):
if self.forward_count == 0:
return {"avg_freq": 0.0}
freq = self.selection_freq / self.forward_count
return {
"avg_selection_freq": freq.mean().item(),
"sparsity_ratio": (self.input_dim - self.top_k) / self.input_dim,
}
def sparsity_loss(self, x):
importance = self.importance_net(x.reshape(-1, self.input_dim))
entropy = -(
importance * importance.clamp(min=1e-8).log()
+ (1 - importance) * (1 - importance).clamp(min=1e-8).log()
)
return entropy.mean()
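# Usage sketch for SparseCompressor (dimensions are assumptions): in training
# mode the straight-through estimator keeps the hard top-k mask in the forward
# value while gradients flow through the soft importance scores.
def _example_sparse_compressor():
    comp = SparseCompressor(input_dim=512, output_dim=512, top_k=64, temporal_aware=True)
    z = torch.randn(2, 8, 512)   # [B, T, D]
    z_compressed = comp(z)        # [B, T, 512], built from 64 selected dims per token
    reg = comp.sparsity_loss(z)   # entropy penalty pushing scores toward 0/1
    return z_compressed, reg, comp.get_selection_stats()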
class PostVAESparseCompressor(nn.Module):
"""Post-VAE compression: forces selection of most informative latent dims."""
def __init__(self, d_model, top_k=128):
super().__init__()
self.compressor = SparseCompressor(d_model, d_model, top_k=top_k, temporal_aware=True, use_ste=True)
self.gate = nn.Sequential(nn.Linear(d_model * 2, d_model), nn.Sigmoid())
def forward(self, z):
z_sparse = self.compressor(z)
gate_val = self.gate(torch.cat([z, z_sparse], dim=-1))
return z * (1 - gate_val) + z_sparse * gate_val
class InterEARCPSparseCompressor(nn.Module):
"""Between EARCP layers: adaptive compression, deeper = more aggressive."""
def __init__(self, d_model, top_k=96, layer_idx=0):
super().__init__()
self.layer_idx = layer_idx
adaptive_k = max(top_k - (layer_idx * 8), 48)
self.compressor = SparseCompressor(d_model, d_model, top_k=adaptive_k, temporal_aware=False, use_ste=True)
self.scale = nn.Parameter(torch.tensor(0.3))
def forward(self, z):
z_sparse = self.compressor(z)
return z + self.scale * (z_sparse - z)
class MemoryGateSparseCompressor(nn.Module):
"""Memory retrieval compression: extracts only truly relevant memory signals."""
def __init__(self, d_model, top_k=64):
super().__init__()
self.compressor = SparseCompressor(d_model, d_model, top_k=top_k, temporal_aware=True, use_ste=True)
def forward(self, retrieved):
return self.compressor(retrieved)
class PreDecodeSparseCompressor(nn.Module):
"""Pre-decoder compression: clean decisive latent code for decoder."""
def __init__(self, d_model, latent_channels, top_k=96):
super().__init__()
self.compressor = SparseCompressor(d_model, d_model, top_k=top_k, temporal_aware=False, use_ste=True)
self.gate = nn.Sequential(nn.Linear(d_model * 2, d_model), nn.Sigmoid())
nn.init.constant_(self.gate[0].bias, 0.85)
def forward(self, z):
z_sparse = self.compressor(z)
gate_val = self.gate(torch.cat([z, z_sparse], dim=-1))
return z * gate_val + z_sparse * (1 - gate_val)
class SparseCompressionManager(nn.Module):
"""Manages all 4 SparseCompressor instances across the architecture."""
def __init__(self, config, earcp_compress_every=2):
super().__init__()
self.config = config
self.earcp_compress_every = earcp_compress_every
d = config.d_model
self.post_vae = PostVAESparseCompressor(d, top_k=d // 4)
n_compress_points = config.n_layers // earcp_compress_every
self.inter_earcp = nn.ModuleList([
InterEARCPSparseCompressor(d, top_k=d // 4, layer_idx=i)
for i in range(n_compress_points)
])
self.memory_gate = MemoryGateSparseCompressor(d, top_k=d // 8)
self.pre_decode = PreDecodeSparseCompressor(d, config.latent_channels, top_k=d // 4)
self.enable_post_vae = True
self.enable_inter_earcp = True
self.enable_memory = True
self.enable_pre_decode = True
def compress_post_vae(self, z):
return self.post_vae(z) if self.enable_post_vae else z
def compress_inter_earcp(self, z, layer_idx):
if not self.enable_inter_earcp:
return z
if layer_idx % self.earcp_compress_every != 0:
return z
compress_idx = layer_idx // self.earcp_compress_every
if compress_idx < len(self.inter_earcp):
return self.inter_earcp[compress_idx](z)
return z
def compress_memory(self, retrieved):
return self.memory_gate(retrieved) if self.enable_memory else retrieved
def compress_pre_decode(self, z):
return self.pre_decode(z) if self.enable_pre_decode else z
def total_sparsity_loss(self, z):
loss = self.post_vae.compressor.sparsity_loss(z)
loss = loss + self.memory_gate.compressor.sparsity_loss(z)
loss = loss + self.pre_decode.compressor.sparsity_loss(z)
return loss / 3.0
def count_params(self):
return sum(p.numel() for p in self.parameters())
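# Usage sketch: where the manager's four compression hooks sit in the pipeline
# (illustrative call order; NutataModel wires these internally).
def _example_sparse_manager(config: "NutataModelConfig"):
    manager = SparseCompressionManager(config, earcp_compress_every=2)
    z = torch.randn(2, 8, config.d_model)
    z = manager.compress_post_vae(z)                   # right after VAE encoding
    z = manager.compress_inter_earcp(z, layer_idx=0)   # between EARCP layers
    z = manager.compress_pre_decode(z)                 # just before decoding
    return z, manager.total_sparsity_loss(z)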
# ==============================================================================
# SECTION 15: MAIN NEXUS-VIDEOMODEL CLASS
# ==============================================================================
class NutataModel(nn.Module):
"""
NUTATA v2.0 - HIGH RESOLUTION + CONDITIONING
Complete cognitive architecture for video understanding and generation.
v2.0 Features:
- Progressive resolution training (64 β†’ 512)
- Perceptual loss (VGG19 + adaptive weighting)
- Multi-level conditioning (text/action/scene β†’ latent/decoder/temporal/memory)
- Hierarchical temporal memory (frame/clip/scene)
- Anti-blur sharpening
- All components are checkpoint-compatible.
"""
def __init__(
self, config: NutataModelConfig = None, expert_counts: List[int] = None
):
super().__init__()
self.config = config or NutataModelConfig()
# Default expert counts per layer
if expert_counts is None:
expert_counts = [len(self.config.expert_types)] * self.config.n_layers
# === 3D VAE ===
self.encoder = VideoVAEEncoder(self.config)
self.decoder = VideoVAEDecoder(self.config)
# === LPOL Memory ===
self.lpol = VideoLPOL(self.config) if self.config.use_lpol else None
# === EARCP Layers ===
self.layers = nn.ModuleList(
[
VideoEARCPLayer(self.config, i, n_experts=expert_counts[i])
for i in range(self.config.n_layers)
]
)
# === Neurogenesis ===
self.neurogenesis = VideoNeurogenesis(self.config.d_model, 64, self.config)
# Initialize neuro_proj with max_neurons to handle dynamic growth during training
# This avoids optimizer issues when neurons grow
self.neuro_proj = nn.Linear(self.config.max_neurons, self.config.d_model)
# === Temporal Coherence ===
self.temporal_coherence = TemporalCoherenceModule(self.config)
# === Flow Prediction ===
self.flow_module = (
FlowPredictionModule(self.config) if self.config.flow_prediction else None
)
# === Energy System ===
self.energy = VideoEnergySystem(self.config)
# === Frame Prediction Head ===
self.frame_predictor = nn.Sequential(
nn.Linear(self.config.d_model, self.config.d_model),
nn.SiLU(),
nn.Dropout(self.config.dropout),
nn.Linear(self.config.d_model, self.config.latent_channels * 8 * 8),
)
# =====================================================================
# NEW V2.0 MODULES
# =====================================================================
# === Perceptual Loss (VGG19 + LPIPS) ===
self.perceptual_module = (
PerceptualLossModule(self.config)
if self.config.use_perceptual_loss
else None
)
# === Multi-Modal Conditioning (Text/Action/Scene) ===
self.conditioner = (
ConditionalEncoder(self.config) if self.config.use_conditioning else None
)
# === Hierarchical Temporal Memory (Frame/Clip/Scene) ===
self.hierarchical_memory = (
HierarchicalTemporalMemory(self.config)
if self.config.use_hierarchical_memory
else None
)
# === Anti-Blur Sharpening ===
self.sharpener = (
SharpnessEnhancer(self.config.video.channels)
if self.config.use_sharpening
else None
)
# === Sparse Compression (replaces gradient checkpointing) ===
self.sparse_manager = SparseCompressionManager(
self.config, earcp_compress_every=2
)
# === Epoch tracking for adaptive perceptual loss ===
self.current_epoch = 0
# Initialize weights
self._init_weights()
# Log initialization
logger.info(f"πŸ“¦ NUTATA v{self.config.version} initialized")
logger.info(f" Architecture: {self.config.codename}")
logger.info(
f" Input: {self.config.video.channels}x{self.config.video.n_frames}x{self.config.video.height}x{self.config.video.width}"
)
logger.info(f" Parameters: {self.count_params():,}")
logger.info(f" LPOL Domains: {len(self.config.domain_types)}")
logger.info(f" EARCP Layers: {self.config.n_layers}")
logger.info(f" Expert Types: {len(self.config.expert_types)}")
logger.info(f" Expert Counts: {[l.get_expert_count() for l in self.layers]}")
if self.config.use_gqa:
savings = int(
(1 - self.config.gqa_num_kv_groups / self.config.gqa_num_heads) * 100
)
logger.info(
f" GQA: {self.config.gqa_num_heads} heads, {self.config.gqa_num_kv_groups} KV groups ({savings}% savings)"
)
if self.config.flow_prediction:
logger.info(" Flow Prediction: Enabled")
# v2.0 feature logging
if self.config.use_perceptual_loss:
logger.info(
f" [v2.0] Perceptual Loss: Enabled (adaptive={self.config.perceptual_adaptive})"
)
if self.config.use_conditioning:
logger.info(
f" [v2.0] Conditioning: Enabled (text/action/scene at {self.config.condition_injection_levels})"
)
if self.config.use_hierarchical_memory:
logger.info(
f" [v2.0] Hierarchical Memory: Enabled (scales={self.config.memory_scales})"
)
if self.config.use_sharpening:
logger.info(
f" [v2.0] Sharpening: Enabled (weight={self.config.sharpening_weight})"
)
if hasattr(self, 'sparse_manager'):
logger.info(
f" [NUTATA] SparseCompressor: Enabled ({len(self.sparse_manager.inter_earcp)} bottleneck points, compress_every={self.sparse_manager.earcp_compress_every})"
)
logger.info(
f" [NUTATA] SparseCompressor params: {self.sparse_manager.count_params():,}"
)
def _init_weights(self):
"""Initialize weights with better defaults"""
for module in self.modules():
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Conv3d):
nn.init.kaiming_normal_(
module.weight, mode="fan_out", nonlinearity="relu"
)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.ConvTranspose3d):
nn.init.kaiming_normal_(
module.weight, mode="fan_out", nonlinearity="relu"
)
if module.bias is not None:
nn.init.zeros_(module.bias)
def encode(
self, video: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Encode video to latent space
Args:
video: [B, C, T, H, W] or [B, T, C, H, W]
Returns:
z, mu, logvar, z_spatial
"""
self.energy.consume("encode")
return self.encoder(video)
def decode(self, z_spatial: torch.Tensor) -> torch.Tensor:
"""
Decode latent to video
Args:
z_spatial: [B, C_latent, T', H', W']
Returns:
video: [B, C, T, H, W]
"""
self.energy.consume("decode")
return self.decoder(z_spatial)
def process_temporal(self, z: torch.Tensor, conditions: dict = None) -> Dict:
"""
Process latent through cognitive systems - v2.0
Args:
z: [B, T', D] latent sequence
conditions: Encoded conditions (text/action/scene)
Returns:
Dict with processed latent and metadata
"""
self.energy.consume("process")
# === v2.0 Conditioning (Temporal Level) ===
if self.conditioner is not None and conditions is not None:
# Inject into temporal processing start
z = self.conditioner.inject(z, conditions, level="temporal")
# LPOL Memory
lpol_info = {}
if self.lpol is not None:
z, lpol_info = self.lpol(z)
# === v2.0 Hierarchical Memory ===
hierarchical_info = {}
if self.hierarchical_memory is not None:
# Inject condition before memory if available
if self.conditioner is not None and conditions is not None:
z_mem_in = self.conditioner.inject(z, conditions, level="memory")
else:
z_mem_in = z
z, hierarchical_info = self.hierarchical_memory(z_mem_in)
# EARCP Layers
coherences = []
total_growth = 0
for layer in self.layers:
z, coh, grew = layer(z)
coherences.append(coh)
if grew:
total_growth += 1
# Neurogenesis
neuro_out = self.neurogenesis(z)
# Pad neuro_out to max_neurons for neuro_proj (initialized with max_neurons)
current_neurons = neuro_out.shape[-1]
if current_neurons < self.config.max_neurons:
padding = torch.zeros(
*neuro_out.shape[:-1],
self.config.max_neurons - current_neurons,
device=neuro_out.device,
dtype=neuro_out.dtype,
)
neuro_out_padded = torch.cat([neuro_out, padding], dim=-1)
else:
neuro_out_padded = neuro_out[..., : self.config.max_neurons]
neuro_proj = self.neuro_proj(neuro_out_padded)
z = z + 0.1 * neuro_proj.mean(dim=-1, keepdim=True).expand_as(z)
avg_coherence = sum(coherences) / len(coherences) if coherences else 0.0
neuro_growth = self.neurogenesis.maybe_grow(avg_coherence)
# Temporal coherence
z, temp_coherence = self.temporal_coherence(z)
# Flow prediction (Moved to forward in v2.0 for optical flow consistency)
flow_info = {}
# if self.flow_module is not None:
# flow_info = self.flow_module(z)
# Energy regeneration
self.energy.regenerate()
# Consolidate memory info
memory_stats = {}
if lpol_info:
memory_stats.update(lpol_info)
if hierarchical_info:
memory_stats.update(hierarchical_info)
return {
"z": z,
"coherence": avg_coherence,
"temporal_coherence": temp_coherence,
"expert_growth": total_growth,
"neuro_growth": neuro_growth,
"lpol_info": lpol_info,
"flow_info": flow_info,
"hierarchical_info": hierarchical_info,
"energy": self.energy.get_stats(),
}
def predict_frames(self, z: torch.Tensor, n_frames: int = 1) -> torch.Tensor:
"""
Predict future frames in latent space
Args:
z: [B, T, D] current latent sequence
n_frames: number of frames to predict
Returns:
z_future: [B, n_frames, D] predicted latents
"""
self.energy.consume("predict")
predictions = []
current_z = z
for _ in range(n_frames):
last = current_z[:, -1:]
for layer in self.layers:
last, _, _ = layer(last)
predictions.append(last)
current_z = torch.cat([current_z, last], dim=1)
return torch.cat(predictions, dim=1)
def forward(
self,
video: torch.Tensor,
text_emb: torch.Tensor = None,
action_id: torch.Tensor = None,
scene_id: torch.Tensor = None,
epoch: int = None,
) -> Dict:
"""
Full forward pass - v2.0
Args:
video: [B, C, T, H, W] input video
text_emb: [B, D_text] optional text embeddings
action_id: [B] optional action class ids
scene_id: [B] optional scene class ids
epoch: current training epoch (for adaptive scheduling)
Returns:
Dict with loss, reconstruction, and metadata
"""
# Update current epoch
if epoch is not None:
self.current_epoch = epoch
B = video.shape[0]
# === v2.0 Conditioning Encoding ===
encoded_conditions = None
if self.conditioner is not None:
# Check if any conditions are provided
has_conditions = (
text_emb is not None or action_id is not None or scene_id is not None
)
if has_conditions:
encoded_conditions = self.conditioner.encode_conditions(
text_emb, action_id, scene_id, batch_size=B, device=video.device
)
# Encode
z, mu, logvar, z_spatial = self.encode(video)
# === v2.0 Conditioning (Latent Level) ===
if self.conditioner is not None and encoded_conditions is not None:
z = self.conditioner.inject(z, encoded_conditions, level="latent")
# Process through cognitive systems (with optional temporal conditioning)
proc_out = self.process_temporal(z, conditions=encoded_conditions)
z_processed = proc_out["z"]
# Decode
        # Note: decoder-level conditioning would ideally be injected into z_spatial here; not yet wired in
recon = self.decode(z_spatial)
# === v2.0 Anti-Blur Sharpening ===
if self.sharpener is not None:
recon = self.sharpener(recon)
# === Losses ===
# Ensure video is in [B, C, T, H, W] format for comparison
if video.shape[2] == self.config.video.channels:
video_compare = video.permute(0, 2, 1, 3, 4)
else:
video_compare = video
# Handle size mismatch by interpolating reconstruction
if recon.shape != video_compare.shape:
recon = F.interpolate(
recon,
size=video_compare.shape[2:],
mode="trilinear",
align_corners=False,
)
recon_loss = F.mse_loss(recon, video_compare)
kl_loss = 0.5 * torch.mean(mu.pow(2) + logvar.exp() - logvar - 1)
        coherence_loss = float(1.0 - proc_out["temporal_coherence"])  # scalar penalty; non-differentiable by construction
# Flow loss (Optical Flow Consistency - v2.0)
flow_loss_tensor = torch.tensor(0.0, device=recon.device)
if self.flow_module is not None:
# v2.0: Compute flow consistency on reconstructed video
flow_out = self.flow_module(recon)
if "flow_loss" in flow_out:
flow_loss_tensor = flow_out["flow_loss"]
# Add flow info to proc_out for logging
if "flow" in flow_out:
proc_out["flow_info"] = flow_out
# === v2.0 Perceptual Loss ===
perceptual_loss = torch.tensor(0.0, device=recon.device)
if self.perceptual_module is not None:
# Only compute if weight > 0
if (
self.config.perceptual_adaptive
and self.current_epoch < self.config.perceptual_warmup_epochs
):
# Scaled warmup
weight = self.perceptual_module.get_adaptive_weight(self.current_epoch)
else:
weight = self.config.perceptual_loss_weight
if weight > 0:
perceptual_loss = self.perceptual_module(
recon, video_compare, self.current_epoch
)
total_loss = (
recon_loss
+ self.config.kl_weight * kl_loss
+ self.config.temporal_coherence_weight * coherence_loss
+ 0.01 * flow_loss_tensor
            + perceptual_loss  # weighted inside the module when adaptive; `weight` above only gates computation
)
# Convert simple tensors to float for logging
flow_loss = (
flow_loss_tensor.item()
if hasattr(flow_loss_tensor, "item")
else float(flow_loss_tensor)
)
p_loss_val = (
perceptual_loss.item()
if hasattr(perceptual_loss, "item")
else float(perceptual_loss)
)
return {
"loss": total_loss,
"recon_loss": recon_loss,
"kl_loss": kl_loss,
"coherence_loss": coherence_loss,
"flow_loss": flow_loss,
"perceptual_loss": p_loss_val,
"recon": recon,
"z": z,
"z_spatial": z_spatial,
"coherence": proc_out["coherence"],
"temporal_coherence": proc_out["temporal_coherence"],
"neurogenesis": proc_out["neuro_growth"],
"expert_growth": proc_out["expert_growth"],
"energy": proc_out["energy"],
}
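    # Illustrative training-step sketch (comments only; `optimizer` is an
    # assumption, not part of this module):
    #   >>> out = model(video)   # video: [B, C, T, H, W] in [0, 1]
    #   >>> out["loss"].backward()
    #   >>> optimizer.step(); optimizer.zero_grad()
    #   >>> print(out["recon_loss"].item(), out["coherence"])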
def generate(
self,
n_frames: int = 16,
z_init: torch.Tensor = None,
temperature: float = 1.0,
batch_size: int = 1,
) -> torch.Tensor:
"""
Generate video from scratch or initial latent
Args:
n_frames: number of frames to generate
z_init: optional initial latent
temperature: sampling temperature
batch_size: batch size if z_init is None
Returns:
video: [B, C, T, H, W] generated video
"""
self.eval()
device = next(self.parameters()).device
if z_init is None:
B = batch_size
            T = max(1, n_frames // 4)  # latent frames after 4x temporal downsample
H = self.config.video.height // 8
W = self.config.video.width // 8
z_init = (
torch.randn(B, self.config.latent_channels, T, H, W, device=device)
* temperature
)
with torch.no_grad():
video = self.decode(z_init)
# Ensure correct number of frames
if video.shape[2] != n_frames:
video = F.interpolate(
video,
size=(n_frames, video.shape[3], video.shape[4]),
mode="trilinear",
align_corners=False,
)
return video.clamp(0, 1)
def continue_video(
self, video: torch.Tensor, n_frames: int = 8, temperature: float = 0.8
) -> torch.Tensor:
"""
Continue a video by predicting future frames
Args:
video: [B, C, T, H, W] input video
n_frames: number of frames to predict
temperature: sampling temperature
Returns:
continued: [B, C, T+n_frames, H, W] continued video
"""
self.eval()
with torch.no_grad():
# Encode input
z, mu, logvar, z_spatial = self.encode(video)
# Predict future latents
z_future = self.predict_frames(z, n_frames=n_frames)
# Decode predictions to frames
B, T_fut, D = z_future.shape
spatial = self.frame_predictor(z_future)
spatial = spatial.view(B, T_fut, self.config.latent_channels, 8, 8)
spatial = spatial.permute(0, 2, 1, 3, 4) # [B, C, T, H, W]
# Scale up to full resolution
future_frames = self.decode(spatial)
# Concatenate with input
continued = torch.cat([video, future_frames], dim=2)
return continued.clamp(0, 1)
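    # Illustrative sampling sketch (comments only; shapes are assumptions):
    #   >>> clip = model.generate(n_frames=16, batch_size=1)   # [1, C, 16, H, W]
    #   >>> longer = model.continue_video(clip, n_frames=8)    # [1, C, 24, H, W]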
def count_params(self) -> int:
"""Count total parameters"""
return sum(p.numel() for p in self.parameters())
def count_trainable_params(self) -> int:
"""Count trainable parameters"""
return sum(p.numel() for p in self.parameters() if p.requires_grad)
def diagnostics(self) -> Dict:
"""Get comprehensive diagnostics"""
total_experts = sum(layer.get_expert_count() for layer in self.layers)
return {
"model_version": self.config.version,
"total_params": self.count_params(),
"trainable_params": self.count_trainable_params(),
"total_experts": total_experts,
"expert_counts": [layer.get_expert_count() for layer in self.layers],
"neurogenesis": self.neurogenesis.get_stats(),
"energy": self.energy.get_stats(),
"temporal_coherence": self.temporal_coherence.get_average_coherence(),
"gqa_enabled": self.config.use_gqa,
"flow_prediction": self.config.flow_prediction,
"lpol_enabled": self.config.use_lpol,
"n_domains": len(self.config.domain_types) if self.config.use_lpol else 0,
}
# ===========================================================================
# HUGGINGFACE HUB INTEGRATION
# ===========================================================================
def save_pretrained(self, save_directory: str, save_config: bool = True):
"""
Save model weights and configuration to a directory.
Args:
save_directory: Directory to save the model
save_config: Whether to save configuration files
"""
os.makedirs(save_directory, exist_ok=True)
# 1. Save model weights
model_path = os.path.join(save_directory, "pytorch_model.bin")
torch.save(self.state_dict(), model_path)
logger.info(f"πŸ’Ύ Model weights saved to {model_path}")
# 2. Save configuration
if save_config:
config_path = os.path.join(save_directory, "config.json")
config_dict = self.config.to_dict()
config_dict["_class_name"] = "NutataModel"
config_dict["_version"] = self.config.version
config_dict["_architecture"] = "cognitive-video-world-model"
# Save dynamic state: expert counts per layer (CRITICAL for loading)
config_dict["_dynamic_state"] = {
"expert_counts": [layer.get_expert_count() for layer in self.layers],
"neuron_count": self.neurogenesis.n_neurons.item(),
"neuro_proj_in": self.neuro_proj.in_features,
"neuro_proj_out": self.neuro_proj.out_features,
"adaptive_proj_in": self.encoder.adaptive_proj.in_features
if self.encoder.adaptive_proj
else None,
"adaptive_proj_out": self.encoder.adaptive_proj.out_features
if self.encoder.adaptive_proj
else None,
}
with open(config_path, "w") as f:
json.dump(config_dict, f, indent=2)
logger.info(f"πŸ“„ Config saved to {config_path}")
# 3. Save model architecture info
arch_info = {
"model_type": "nutata-videomodel",
"version": self.config.version,
"total_params": self.count_params(),
"diagnostics": self.diagnostics(),
"architecture": {
"encoder": "VideoVAEEncoder (3D Causal VAE)",
"decoder": "VideoVAEDecoder (3D Upsampling)",
"memory": f"VideoLPOL ({len(self.config.domain_types)} domains)",
"attention": "VideoGQA (Grouped Query Attention with RoPE)",
"experts": "VideoEARCP (Dynamic Expert Routing)",
"neurogenesis": "VideoNeurogenesis (Adaptive Capacity)",
"temporal": "TemporalCoherenceModule + FlowPrediction",
},
"input_spec": {
"shape": f"[B, {self.config.video.channels}, {self.config.video.n_frames}, {self.config.video.height}, {self.config.video.width}]",
"dtype": "torch.float32",
"range": "[0, 1]",
},
"output_spec": {
"reconstruction": f"[B, {self.config.video.channels}, {self.config.video.n_frames}, {self.config.video.height}, {self.config.video.width}]",
"latent": f"[B, T', {self.config.d_model}]",
},
}
arch_path = os.path.join(save_directory, "architecture.json")
with open(arch_path, "w") as f:
json.dump(arch_info, f, indent=2)
logger.info(f"πŸ“„ Architecture info saved to {arch_path}")
# 4. Save training state (for resuming)
training_state = {
"neurogenesis_stats": self.neurogenesis.get_stats(),
"energy_stats": self.energy.get_stats(),
"expert_counts": [layer.get_expert_count() for layer in self.layers],
"temporal_coherence_avg": self.temporal_coherence.get_average_coherence(),
}
training_path = os.path.join(save_directory, "training_state.json")
with open(training_path, "w") as f:
json.dump(training_state, f, indent=2)
# 5. Save LICENSE file
license_content = self._generate_license()
license_path = os.path.join(save_directory, "LICENSE")
with open(license_path, "w") as f:
f.write(license_content)
# 6. Save modeling_nutata.py for HuggingFace Hub compatibility
modeling_code = self._generate_modeling_code()
modeling_path = os.path.join(save_directory, "modeling_nutata.py")
with open(modeling_path, "w") as f:
f.write(modeling_code)
logger.info(f"πŸ“„ modeling_nutata.py saved to {modeling_path}")
logger.info(f"βœ… Model saved to {save_directory}")
return save_directory
@classmethod
def from_pretrained(
cls, pretrained_path: str, device: str = None, **kwargs
) -> "NutataModel":
"""
Load a pretrained NutataModel from local path or HuggingFace Hub.
Args:
pretrained_path: Local directory or HuggingFace repo ID
device: Device to load model to
**kwargs: Additional config overrides
Returns:
Loaded NutataModel instance
"""
# Determine if local path or HF Hub
if os.path.isdir(pretrained_path):
load_dir = pretrained_path
elif HF_HUB_AVAILABLE:
logger.info(f"πŸ“₯ Downloading from HuggingFace Hub: {pretrained_path}")
load_dir = snapshot_download(repo_id=pretrained_path)
else:
raise ValueError(
f"Path {pretrained_path} not found and huggingface_hub not installed"
)
# Load config
config_path = os.path.join(load_dir, "config.json")
dynamic_state = None
if os.path.exists(config_path):
with open(config_path, "r") as f:
config_dict = json.load(f)
# Extract dynamic state before removing internal keys
dynamic_state = config_dict.pop("_dynamic_state", None)
# Remove internal keys
config_dict.pop("_class_name", None)
config_dict.pop("_version", None)
config_dict.pop("_architecture", None)
# Apply overrides
config_dict.update(kwargs)
config = NutataModelConfig.from_dict(config_dict)
else:
logger.warning("No config.json found, using default config")
config = NutataModelConfig(**kwargs)
# Determine expert counts
expert_counts = None
neuron_count = 64
neuro_proj_in = 64
neuro_proj_out = None
adaptive_proj_in = None
adaptive_proj_out = None
if dynamic_state is not None:
expert_counts = dynamic_state.get("expert_counts")
neuron_count = dynamic_state.get("neuron_count", 64)
neuro_proj_in = dynamic_state.get("neuro_proj_in", neuron_count)
neuro_proj_out = dynamic_state.get("neuro_proj_out")
adaptive_proj_in = dynamic_state.get("adaptive_proj_in")
adaptive_proj_out = dynamic_state.get("adaptive_proj_out")
logger.info(
f"πŸ”„ Restoring dynamic state: experts={expert_counts}, neurons={neuron_count}, neuro_proj={neuro_proj_in}"
)
# Create model with correct structure
model = cls(config, expert_counts=expert_counts)
# Resize neurogenesis if needed
if neuron_count != model.neurogenesis.n_neurons.item():
model.neurogenesis.resize(neuron_count)
# Note: neuro_proj is now initialized with max_neurons, so no resize needed
# For backward compatibility with old checkpoints that had smaller neuro_proj:
if neuro_proj_in and neuro_proj_in != model.neuro_proj.in_features:
# Old checkpoint had different size - create matching layer for loading
model.neuro_proj = nn.Linear(neuro_proj_in, config.d_model)
logger.info(
f"πŸ“ Adjusted neuro_proj for checkpoint: {neuro_proj_in} -> {config.d_model}"
)
# Create adaptive_proj if dimensions are known from config
if adaptive_proj_in is not None and adaptive_proj_out is not None:
model.encoder.adaptive_proj = nn.Linear(adaptive_proj_in, adaptive_proj_out)
logger.info(
f"πŸ“ Pre-created adaptive_proj from config: {adaptive_proj_in} -> {adaptive_proj_out}"
)
# Load weights
model_path = os.path.join(load_dir, "pytorch_model.bin")
if os.path.exists(model_path):
state_dict = torch.load(model_path, map_location="cpu")
# Handle adaptive_proj - it's created dynamically but we need to load it
adaptive_proj_keys = [k for k in state_dict.keys() if "adaptive_proj" in k]
if adaptive_proj_keys:
# Get the input dimension from the saved weights
weight_key = "encoder.adaptive_proj.weight"
if weight_key in state_dict:
in_features = state_dict[weight_key].shape[1]
out_features = state_dict[weight_key].shape[0]
# Create the adaptive_proj with correct dimensions
model.encoder.adaptive_proj = nn.Linear(in_features, out_features)
logger.info(
f"πŸ“ Created adaptive_proj: {in_features} -> {out_features}"
)
# ===== FIX: Reconcile neuro_proj from actual state_dict weights =====
# dynamic_state may have wrong neuro_proj_in (e.g. 64) while checkpoint
# actually trained with a larger size (e.g. 512). Always trust the weights.
neuro_proj_wkey = "neuro_proj.weight"
if neuro_proj_wkey in state_dict:
sd_in = state_dict[neuro_proj_wkey].shape[1]
sd_out = state_dict[neuro_proj_wkey].shape[0]
if (sd_in != model.neuro_proj.in_features or
sd_out != model.neuro_proj.out_features):
model.neuro_proj = nn.Linear(sd_in, sd_out)
logger.info(
f"πŸ“ Reconciled neuro_proj from weights: {sd_in} -> {sd_out}"
)
# ===== General shape mismatch handler =====
# strict=False only handles missing/unexpected keys, NOT size mismatches.
# Pre-filter any remaining shape conflicts to avoid RuntimeError.
model_state = model.state_dict()
keys_to_skip = []
for key in list(state_dict.keys()):
if key in model_state:
if state_dict[key].shape != model_state[key].shape:
logger.warning(
f" ⚠️ Shape mismatch for {key}: "
f"checkpoint={list(state_dict[key].shape)} vs "
f"model={list(model_state[key].shape)}, skipping"
)
keys_to_skip.append(key)
for key in keys_to_skip:
del state_dict[key]
# Load with strict=False to handle any remaining mismatches gracefully
missing, unexpected = model.load_state_dict(state_dict, strict=False)
# Filter out expected "missing" keys that are handled dynamically
# Include shape-mismatched keys that were skipped
all_missing = list(missing) + keys_to_skip
if all_missing:
logger.warning(f" Missing/skipped keys: {len(all_missing)}")
if keys_to_skip:
logger.warning(f" ({len(keys_to_skip)} skipped due to shape mismatch)")
for k in all_missing[:5]:
logger.warning(f" - {k}")
if len(all_missing) > 5:
logger.warning(f" ... and {len(all_missing) - 5} more")
if unexpected:
# Filter out keys we intentionally handle separately
truly_unexpected = [k for k in unexpected if "adaptive_proj" not in k]
if truly_unexpected:
logger.warning(f" Unexpected keys: {len(truly_unexpected)}")
for k in truly_unexpected[:5]:
logger.warning(f" - {k}")
if len(truly_unexpected) > 5:
logger.warning(
f" ... and {len(truly_unexpected) - 5} more"
)
if not all_missing and not unexpected:
logger.info(f"βœ… Loaded weights perfectly (0 missing, 0 unexpected)")
else:
logger.info(f"βœ… Loaded weights from {model_path}")
else:
logger.warning("No pytorch_model.bin found, using random weights")
# Move to device
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
logger.info(f"πŸ“¦ Model loaded on {device}")
return model
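    # Illustrative load/save round trip (comments only; the repo id below is
    # the configured hub_model_id, shown here purely as an example):
    #   >>> model = NutataModel.from_pretrained("amewebstudio/nutata-videomodel-v1.1")
    #   >>> model.save_pretrained("./nutata-checkpoint")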
def push_to_hub(
self,
repo_id: str,
token: str = None,
private: bool = True,
commit_message: str = "Upload NUTATA",
save_directory: str = None,
) -> str:
"""
Push model to HuggingFace Hub.
Args:
repo_id: HuggingFace repo ID (e.g., "username/model-name")
token: HuggingFace API token
private: Whether to create private repo
commit_message: Commit message
save_directory: Temporary directory for saving
Returns:
URL of the uploaded model
"""
if not HF_HUB_AVAILABLE:
raise ImportError(
"huggingface_hub not installed. Install with: pip install huggingface_hub"
)
# Create save directory
if save_directory is None:
save_directory = tempfile.mkdtemp()
os.makedirs(save_directory, exist_ok=True)
# Save model locally first
self.save_pretrained(save_directory)
# Generate README
readme_content = self._generate_readme()
readme_path = os.path.join(save_directory, "README.md")
with open(readme_path, "w") as f:
f.write(readme_content)
logger.info(f"πŸ“„ README.md generated")
# Generate LICENSE
license_content = self._generate_license()
license_path = os.path.join(save_directory, "LICENSE")
with open(license_path, "w") as f:
f.write(license_content)
logger.info(f"πŸ“„ LICENSE generated")
# Generate modeling_nutata.py - CRITICAL for HuggingFace Hub loading
modeling_code = self._generate_modeling_code()
modeling_path = os.path.join(save_directory, "modeling_nutata.py")
with open(modeling_path, "w") as f:
f.write(modeling_code)
logger.info(f"πŸ“„ modeling_nutata.py generated ({len(modeling_code)} chars)")
# Create requirements.txt
requirements = """# NUTATA v1.1 Requirements
torch>=2.0.0
torchvision>=0.15.0
tqdm>=4.65.0
huggingface_hub>=0.19.0
numpy>=1.24.0
"""
req_path = os.path.join(save_directory, "requirements.txt")
with open(req_path, "w") as f:
f.write(requirements)
logger.info("πŸ“„ requirements.txt generated")
# Create .gitattributes for LFS
gitattributes = """*.bin filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
"""
gitattr_path = os.path.join(save_directory, ".gitattributes")
with open(gitattr_path, "w") as f:
f.write(gitattributes)
# List all files to upload
logger.info("\nπŸ“ Files to upload:")
for f in os.listdir(save_directory):
fpath = os.path.join(save_directory, f)
size = os.path.getsize(fpath) / (1024 * 1024) # MB
logger.info(f" {f} ({size:.2f} MB)")
# Create repo if needed
api = HfApi(token=token)
try:
create_repo(repo_id=repo_id, private=private, token=token, exist_ok=True)
logger.info(f"πŸ“ Repository created/verified: {repo_id}")
except Exception as e:
logger.warning(f"Could not create repo: {e}")
# Upload all files
logger.info("\n⬆️ Uploading to HuggingFace Hub...")
api.upload_folder(
folder_path=save_directory,
repo_id=repo_id,
commit_message=commit_message,
token=token,
)
url = f"https://huggingface.co/{repo_id}"
logger.info(f"πŸš€ Model pushed to {url}")
return url
def _generate_modeling_code(self) -> str:
"""
Generate modeling_nutata.py for HuggingFace Hub.
This file allows dynamic loading without needing the original codebase.
"""
# Get current expert counts and neuron count for embedding in code
expert_counts = [layer.get_expert_count() for layer in self.layers]
neuron_count = self.neurogenesis.n_neurons.item()
code = '''#!/usr/bin/env python3
"""
NUTATA v1.1 - HuggingFace Hub Modeling File
Auto-generated for dynamic loading and fine-tuning.
This file contains the complete model architecture for loading from HuggingFace Hub.
"""
import os
import math
import json
import logging
from typing import Dict, List, Optional, Tuple, Any, Union
from dataclasses import dataclass, field
import torch
import torch.nn as nn
import torch.nn.functional as F
logger = logging.getLogger(__name__)
# ==============================================================================
# CONFIGURATION
# ==============================================================================
@dataclass
class VideoConfig:
height: int = 64
width: int = 64
channels: int = 3
n_frames: int = 16
fps: int = 8
temporal_downsample: int = 4
spatial_downsample: int = 8
@classmethod
def from_dict(cls, d: dict) -> "VideoConfig":
valid_keys = {\'height\', \'width\', \'channels\', \'n_frames\', \'fps\', \'temporal_downsample\', \'spatial_downsample\'}
return cls(**{k: v for k, v in d.items() if k in valid_keys})
@dataclass
class NutataModelConfig:
model_type: str = "nutata-videomodel"
version: str = "1.1"
codename: str = "VideoSim-Cognitive"
architecture_type: str = "cognitive-video"
d_model: int = 512
d_ff: int = 2048
n_layers: int = 8
n_heads: int = 8
dropout: float = 0.1
latent_dim: int = 256
temporal_latent_dim: int = 512
latent_channels: int = 4
max_frames: int = 64
context_frames: int = 16
prediction_frames: int = 8
encoder_channels: List[int] = field(default_factory=lambda: [64, 128, 256, 512])
decoder_channels: List[int] = field(default_factory=lambda: [512, 256, 128, 64])
kl_weight: float = 0.0001
use_lpol: bool = True
memory_size: int = 256
memory_slots_per_domain: int = 32
memory_k: int = 8
domain_types: List[str] = field(default_factory=lambda: [
\'motion\', \'appearance\', \'temporal\', \'spatial\', \'object\',
\'scene\', \'action\', \'causality\', \'physics\'
])
use_gqa: bool = True
gqa_num_heads: int = 8
gqa_num_kv_groups: int = 2
expert_types: List[str] = field(default_factory=lambda: [
\'Motion\', \'Appearance\', \'Temporal\', \'Spatial\', \'Prediction\', \'Generation\'
])
max_experts: int = 12
growth_threshold_coherence: float = 0.3
growth_patience: int = 10
neurogenesis_enabled: bool = True
min_neurons: int = 32
max_neurons: int = 256
neuron_birth_threshold: float = 0.8
neuron_death_threshold: float = 0.05
energy_enabled: bool = True
energy_cost_encode: float = 0.01
energy_cost_decode: float = 0.02
energy_cost_predict: float = 0.03
energy_regeneration: float = 0.05
dream_enabled: bool = True
dream_cycle_length: int = 100
dream_duration: int = 20
temporal_coherence_weight: float = 0.1
flow_prediction: bool = True
perceptual_loss_weight: float = 0.05
batch_size: int = 8
learning_rate: float = 1e-4
epochs: int = 50
gradient_accumulation: int = 4
warmup_steps: int = 500
push_to_hub: bool = True
hub_model_id: str = "amewebstudio/nutata-videomodel-v1.1"
video: VideoConfig = field(default_factory=VideoConfig)
def to_dict(self) -> Dict:
d = {}
for key, value in self.__dict__.items():
if key == \'video\':
d[key] = {k: v for k, v in value.__dict__.items()}
elif isinstance(value, (list, dict, str, int, float, bool, type(None))):
d[key] = value
return d
@classmethod
def from_dict(cls, d: Dict) -> "NutataModelConfig":
d = d.copy()
if \'video\' in d and isinstance(d[\'video\'], dict):
d[\'video\'] = VideoConfig.from_dict(d[\'video\'])
for k in [\'_dynamic_state\', \'_class_name\', \'_version\', \'_architecture\']:
d.pop(k, None)
known = set(cls.__dataclass_fields__.keys())
return cls(**{k: v for k, v in d.items() if k in known})
# ==============================================================================
# BUILDING BLOCKS
# ==============================================================================
class CausalConv3d(nn.Module):
def __init__(self, in_channels: int, out_channels: int,
kernel_size: Tuple[int, int, int] = (3, 3, 3),
stride: Tuple[int, int, int] = (1, 1, 1)):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
if isinstance(stride, int):
stride = (stride, stride, stride)
self.kernel_size = kernel_size
self.stride = stride
self.temporal_pad = kernel_size[0] - 1
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size,
stride=stride, padding=(0, kernel_size[1] // 2, kernel_size[2] // 2))
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.temporal_pad > 0:
x = F.pad(x, (0, 0, 0, 0, self.temporal_pad, 0), mode=\'replicate\')
return self.conv(x)
class SpatioTemporalBlock(nn.Module):
def __init__(self, channels: int):
super().__init__()
self.scale = nn.Parameter(torch.ones(1))
self.spatial = nn.Sequential(
nn.GroupNorm(min(8, channels), channels),
nn.SiLU(),
nn.Conv3d(channels, channels, kernel_size=(1, 3, 3), padding=(0, 1, 1))
)
self.temporal = nn.Sequential(
nn.GroupNorm(min(8, channels), channels),
nn.SiLU(),
CausalConv3d(channels, channels, kernel_size=(3, 1, 1))
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.spatial(x)
h = self.temporal(h)
return x + self.scale * h
class EncoderStage(nn.ModuleList):
def __init__(self, in_channels: int, out_channels: int, temporal_down: bool = True):
stride = (2, 2, 2) if temporal_down else (1, 2, 2)
super().__init__([
SpatioTemporalBlock(in_channels),
SpatioTemporalBlock(in_channels),
nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
])
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self[0](x)
x = self[1](x)
x = self[2](x)
return x
class DecoderStage(nn.ModuleList):
def __init__(self, in_channels: int, out_channels: int, temporal_up: bool = True):
stride = (2, 2, 2) if temporal_up else (1, 2, 2)
output_padding = (1, 1, 1) if temporal_up else (0, 1, 1)
super().__init__([
SpatioTemporalBlock(in_channels),
SpatioTemporalBlock(in_channels),
nn.ConvTranspose3d(in_channels, out_channels, kernel_size=3,
stride=stride, padding=1, output_padding=output_padding)
])
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self[0](x)
x = self[1](x)
x = self[2](x)
return x
class RotaryPositionalEmbedding(nn.Module):
def __init__(self, dim: int, max_seq_len: int = 1024, base: int = 10000):
super().__init__()
self.dim = dim
self.max_seq_len = max_seq_len
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer(\'inv_freq\', inv_freq, persistent=False)
self._build_cache(max_seq_len)
def _build_cache(self, seq_len: int):
t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum(\'i,j->ij\', t, self.inv_freq)
emb = torch.cat([freqs, freqs], dim=-1)
self.register_buffer(\'cos_cache\', emb.cos(), persistent=False)
self.register_buffer(\'sin_cache\', emb.sin(), persistent=False)
def forward(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
if seq_len > self.cos_cache.shape[0]:
self._build_cache(seq_len)
return self.cos_cache[:seq_len].to(device), self.sin_cache[:seq_len].to(device)
def rotate_half(x: torch.Tensor) -> torch.Tensor:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat([-x2, x1], dim=-1)
# ==============================================================================
# VIDEO VAE
# ==============================================================================
class VideoVAEEncoder(nn.Module):
def __init__(self, config: NutataModelConfig):
super().__init__()
self.config = config
channels = config.encoder_channels
self.input_conv = nn.Conv3d(config.video.channels, channels[0], kernel_size=(1, 3, 3), padding=(0, 1, 1))
self.encoders = nn.ModuleList([
EncoderStage(channels[0], channels[1], temporal_down=True),
EncoderStage(channels[1], channels[2], temporal_down=True),
EncoderStage(channels[2], channels[3], temporal_down=False),
])
self.final_blocks = nn.ModuleList([SpatioTemporalBlock(channels[-1]), SpatioTemporalBlock(channels[-1])])
self.to_mu = nn.Conv3d(channels[-1], config.latent_channels, 1)
self.to_logvar = nn.Conv3d(channels[-1], config.latent_channels, 1)
self.to_d_model = nn.Linear(config.latent_channels, config.d_model)
self.adaptive_proj = None
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
if x.dim() == 5 and x.shape[2] == self.config.video.channels:
x = x.permute(0, 2, 1, 3, 4)
B, C, T, H, W = x.shape
h = self.input_conv(x)
for encoder_stage in self.encoders:
h = encoder_stage(h)
for block in self.final_blocks:
h = block(h)
mu_spatial = self.to_mu(h)
logvar_spatial = self.to_logvar(h).clamp(-10, 10)
std = torch.exp(0.5 * logvar_spatial)
eps = torch.randn_like(std)
z_spatial = mu_spatial + eps * std
B, C_lat, T_lat, H_lat, W_lat = z_spatial.shape
z_flat = z_spatial.permute(0, 2, 1, 3, 4).reshape(B, T_lat, -1)
if self.adaptive_proj is None or self.adaptive_proj.in_features != z_flat.shape[-1]:
self.adaptive_proj = nn.Linear(z_flat.shape[-1], self.config.d_model).to(z_flat.device)
z = self.adaptive_proj(z_flat)
mu_flat = mu_spatial.permute(0, 2, 1, 3, 4).reshape(B, T_lat, -1)
mu = self.adaptive_proj(mu_flat)
logvar_flat = logvar_spatial.permute(0, 2, 1, 3, 4).reshape(B, T_lat, -1)
logvar = self.adaptive_proj(logvar_flat)
return z, mu, logvar, z_spatial
class VideoVAEDecoder(nn.Module):
def __init__(self, config: NutataModelConfig):
super().__init__()
self.config = config
channels = config.decoder_channels
self.from_latent = nn.Conv3d(config.latent_channels, channels[0], 1)
self.init_blocks = nn.ModuleList([SpatioTemporalBlock(channels[0]), SpatioTemporalBlock(channels[0])])
self.decoders = nn.ModuleList([
DecoderStage(channels[0], channels[1], temporal_up=True),
DecoderStage(channels[1], channels[2], temporal_up=True),
DecoderStage(channels[2], channels[3], temporal_up=False),
])
self.final_blocks = nn.ModuleList([SpatioTemporalBlock(channels[-1]), SpatioTemporalBlock(channels[-1])])
self.to_rgb = nn.Sequential(
nn.Conv3d(channels[-1], channels[-1] // 2, (1, 3, 3), padding=(0, 1, 1)),
nn.SiLU(),
nn.Conv3d(channels[-1] // 2, config.video.channels, (1, 3, 3), padding=(0, 1, 1)),
nn.Sigmoid()
)
self.temporal_refine = nn.Sequential(
CausalConv3d(config.video.channels, 32, (3, 3, 3)),
nn.SiLU(),
nn.Conv3d(32, config.video.channels, 1),
nn.Tanh()
)
self.refine_scale = nn.Parameter(torch.tensor(0.05))
def forward(self, z: torch.Tensor) -> torch.Tensor:
h = self.from_latent(z)
for block in self.init_blocks:
h = block(h)
for decoder_stage in self.decoders:
h = decoder_stage(h)
for block in self.final_blocks:
h = block(h)
video = self.to_rgb(h)
refine = self.temporal_refine(video) * self.refine_scale
video = torch.clamp(video + refine, 0, 1)
return video
# ==============================================================================
# GQA ATTENTION
# ==============================================================================
class VideoGQA(nn.Module):
def __init__(self, config: NutataModelConfig):
super().__init__()
self.d_model = config.d_model
self.num_heads = config.gqa_num_heads
self.num_kv_groups = config.gqa_num_kv_groups
self.head_dim = config.d_model // self.num_heads
self.heads_per_group = self.num_heads // self.num_kv_groups
self.scale = self.head_dim ** -0.5
self.q_proj = nn.Linear(config.d_model, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(config.d_model, self.num_kv_groups * self.head_dim, bias=False)
self.v_proj = nn.Linear(config.d_model, self.num_kv_groups * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.d_model, bias=False)
self.rope = RotaryPositionalEmbedding(self.head_dim, max_seq_len=config.max_frames * 2)
self.dropout = nn.Dropout(config.dropout)
self.residual_scale = nn.Parameter(torch.ones(1) * 0.1)
def forward(self, x: torch.Tensor, causal: bool = True, use_rope: bool = True) -> torch.Tensor:
B, T, D = x.shape
q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(x).view(B, T, self.num_kv_groups, self.head_dim).transpose(1, 2)
v = self.v_proj(x).view(B, T, self.num_kv_groups, self.head_dim).transpose(1, 2)
if use_rope:
cos, sin = self.rope(T, x.device)
q = (q * cos.unsqueeze(0).unsqueeze(0)) + (rotate_half(q) * sin.unsqueeze(0).unsqueeze(0))
k = (k * cos.unsqueeze(0).unsqueeze(0)) + (rotate_half(k) * sin.unsqueeze(0).unsqueeze(0))
k = k.repeat_interleave(self.heads_per_group, dim=1)
v = v.repeat_interleave(self.heads_per_group, dim=1)
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
if causal:
causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
attn_weights = attn_weights.masked_fill(causal_mask, float('-inf'))
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_weights = self.dropout(attn_weights)
attn_output = torch.matmul(attn_weights, v)
attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, -1)
return self.o_proj(attn_output)
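# --- Hedged usage sketch (illustrative only; helper name is hypothetical).
# Causal GQA over a latent frame sequence: `gqa_num_heads` query heads share
# `gqa_num_kv_groups` K/V projections. Assumes the default config values.
def _demo_video_gqa():
    config = NutataModelConfig()
    attn = VideoGQA(config)
    x = torch.randn(2, 16, config.d_model)  # [B, T, d_model]
    out = attn(x, causal=True)              # output keeps the input shape
    return out.shape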
# ==============================================================================
# LPOL MEMORY
# ==============================================================================
class VideoLPOL(nn.Module):
def __init__(self, config: NutataModelConfig):
super().__init__()
self.config = config
self.n_domains = len(config.domain_types)
self.slots_per_domain = config.memory_slots_per_domain
self.memories = nn.ParameterDict({
d: nn.Parameter(torch.randn(self.slots_per_domain, config.d_model) * 0.02)
for d in config.domain_types
})
self.memory_attn = nn.ModuleDict({
'q_proj': nn.Linear(config.d_model, config.d_model, bias=False),
'k_proj': nn.Linear(config.d_model, config.d_model, bias=False),
'v_proj': nn.Linear(config.d_model, config.d_model, bias=False),
'o_proj': nn.Linear(config.d_model, config.d_model, bias=False),
})
self.domain_clf = nn.Sequential(
nn.Linear(config.d_model, config.d_model // 2), nn.GELU(),
nn.Dropout(config.dropout), nn.Linear(config.d_model // 2, self.n_domains)
)
self.fusion = nn.Sequential(nn.Linear(config.d_model * 2, config.d_model), nn.GELU(), nn.Linear(config.d_model, config.d_model))
self.gate = nn.Sequential(nn.Linear(config.d_model * 2, config.d_model), nn.Sigmoid())
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
B, T, D = x.shape
x_pooled = x.mean(dim=1)
domain_logits = self.domain_clf(x_pooled)
domain_probs = F.softmax(domain_logits, dim=-1)
all_memories = []
for i, domain_name in enumerate(self.config.domain_types):
mem = self.memories[domain_name]
weight = domain_probs[:, i:i+1]
weighted_mem = mem.unsqueeze(0) * weight.unsqueeze(-1)
all_memories.append(weighted_mem)
memory_bank = torch.cat(all_memories, dim=1)
q = self.memory_attn['q_proj'](x)
k = self.memory_attn['k_proj'](memory_bank)
v = self.memory_attn['v_proj'](memory_bank)
n_heads = 8
head_dim = D // n_heads
q = q.view(B, T, n_heads, head_dim).transpose(1, 2)
k = k.view(B, -1, n_heads, head_dim).transpose(1, 2)
v = v.view(B, -1, n_heads, head_dim).transpose(1, 2)
scale = head_dim ** -0.5
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
attn_weights = F.softmax(attn_weights, dim=-1)
retrieved = torch.matmul(attn_weights, v)
retrieved = retrieved.transpose(1, 2).contiguous().view(B, T, D)
retrieved = self.memory_attn['o_proj'](retrieved)
concat = torch.cat([x, retrieved], dim=-1)
gate = self.gate(concat)
fused = self.fusion(concat)
output = x + gate * fused
return output, {'domain_probs': domain_probs, 'top_domain': domain_probs.argmax(dim=-1)}
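# --- Hedged usage sketch (illustrative only; helper name is hypothetical).
# LPOL soft-routes the sequence over per-domain memory banks and returns the
# gated output plus the predicted domain distribution.
def _demo_video_lpol():
    config = NutataModelConfig()
    lpol = VideoLPOL(config)
    x = torch.randn(2, 16, config.d_model)
    out, info = lpol(x)
    return out.shape, info['domain_probs'].shape  # probs: [B, n_domains]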
# ==============================================================================
# EXPERTS AND EARCP
# ==============================================================================
class VideoExpert(nn.Module):
def __init__(self, config: NutataModelConfig, expert_type: str):
super().__init__()
self.expert_type = expert_type
self.confidence = nn.Sequential(nn.Linear(config.d_model, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid())
self.gate = nn.Linear(config.d_model, config.d_model)
self.fc1 = nn.Linear(config.d_model, config.d_ff)
self.fc2 = nn.Linear(config.d_ff, config.d_model)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
conf = self.confidence(x.mean(dim=1))
gate_val = torch.sigmoid(self.gate(x))
h = self.dropout(F.gelu(self.fc1(x)))
h = self.fc2(h)
out = h * gate_val
return out, conf
class VideoEARCPLayer(nn.Module):
def __init__(self, config: NutataModelConfig, layer_idx: int, n_experts: int = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if n_experts is None:
n_experts = len(config.expert_types)
self.attn_norm = nn.LayerNorm(config.d_model, elementwise_affine=False)
self.attn_scale = nn.Parameter(torch.ones(1))
self.temporal_attn = VideoGQA(config)
self.experts = nn.ModuleList([
VideoExpert(config, config.expert_types[i] if i < len(config.expert_types) else f"Hybrid_{i}")
for i in range(n_experts)
])
self.router = nn.Linear(config.d_model, n_experts)
self.register_buffer('low_coh_count', torch.tensor(0))
self.coherence_score = 0.5
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, float, bool]:
h = self.attn_norm(x)
attn_out = self.temporal_attn(h, causal=True)
x = x + self.attn_scale * attn_out
router_input = x.mean(dim=1)
router_logits = self.router(router_input)
weights = F.softmax(router_logits, dim=-1)
expert_outputs = []
confs = []
for expert in self.experts:
out, conf = expert(x)
expert_outputs.append(out)
confs.append(conf)
expert_outputs = torch.stack(expert_outputs, dim=1)
weighted_out = torch.einsum('be,betd->btd', weights, expert_outputs)
x = x + weighted_out
confs_tensor = torch.stack(confs, dim=1)
expert_conf = confs_tensor.mean().item()
entropy = -(weights * weights.log().clamp(min=-100)).sum(dim=-1).mean().item()
max_entropy = math.log(len(self.experts)) if len(self.experts) > 1 else 1.0
routing_focus = 1 - (entropy / max_entropy)
coherence = 0.5 * expert_conf + 0.5 * routing_focus
self.coherence_score = coherence
grew = False
if coherence < self.config.growth_threshold_coherence:
self.low_coh_count += 1
if self.low_coh_count >= self.config.growth_patience and len(self.experts) < self.config.max_experts:
new_expert = VideoExpert(self.config, f"Hybrid_{len(self.experts)}").to(x.device)
self.experts.append(new_expert)
old_router = self.router
self.router = nn.Linear(self.config.d_model, len(self.experts)).to(x.device)
with torch.no_grad():
self.router.weight[:old_router.out_features] = old_router.weight
self.router.bias[:old_router.out_features] = old_router.bias
self.low_coh_count.zero_()
grew = True
else:
self.low_coh_count.zero_()
return x, coherence, grew
def get_expert_count(self) -> int:
return len(self.experts)
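# --- Hedged usage sketch (illustrative only; helper name is hypothetical).
# One EARCP layer applies causal attention, mixes expert outputs by router
# weights, and reports a coherence score that can trigger expert growth.
def _demo_earcp_layer():
    config = NutataModelConfig()
    layer = VideoEARCPLayer(config, layer_idx=0)
    x = torch.randn(2, 16, config.d_model)
    out, coherence, grew = layer(x)
    return out.shape, coherence, layer.get_expert_count()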
# ==============================================================================
# NEUROGENESIS
# ==============================================================================
class VideoNeurogenesis(nn.Module):
def __init__(self, input_dim: int, n_neurons: int, config: NutataModelConfig):
super().__init__()
self.config = config
self.input_dim = input_dim
self.weights = nn.Parameter(torch.randn(n_neurons, input_dim) * 0.02)
self.bias = nn.Parameter(torch.zeros(n_neurons))
self.temporal_gate = nn.Linear(input_dim, n_neurons)
self.register_buffer('n_neurons', torch.tensor(n_neurons))
self.register_buffer('usage', torch.ones(n_neurons))
self.register_buffer('lifetime', torch.zeros(n_neurons))
self.register_buffer('births', torch.tensor(0))
self.register_buffer('deaths', torch.tensor(0))
def forward(self, x: torch.Tensor) -> torch.Tensor:
n = self.n_neurons.item()
gate = torch.sigmoid(self.temporal_gate(x))
gate = gate[..., :n]
out = torch.tanh(F.linear(x, self.weights[:n], self.bias[:n]))
out = out * gate
with torch.no_grad():
act = out.abs().mean(dim=(0, 1))
if act.size(0) == n:
self.usage[:n] = 0.99 * self.usage[:n] + 0.01 * act
self.lifetime[:n] += 1
return out
def maybe_grow(self, coherence: float) -> int:
if not self.config.neurogenesis_enabled:
return 0
n = self.n_neurons.item()
if n >= self.config.max_neurons or coherence < self.config.neuron_birth_threshold:
return 0
device = self.weights.device
with torch.no_grad():
new_w = torch.randn(1, self.input_dim, device=device) * 0.02
new_b = torch.zeros(1, device=device)
self.weights = nn.Parameter(torch.cat([self.weights.data, new_w], dim=0))
self.bias = nn.Parameter(torch.cat([self.bias.data, new_b]))
old_gate = self.temporal_gate
self.temporal_gate = nn.Linear(self.input_dim, n + 1).to(device)
self.temporal_gate.weight.data[:n] = old_gate.weight.data
self.temporal_gate.weight.data[n:] = torch.randn(1, self.input_dim, device=device) * 0.02
self.temporal_gate.bias.data[:n] = old_gate.bias.data
self.temporal_gate.bias.data[n:] = 0
self.usage = torch.cat([self.usage, torch.ones(1, device=device)])
self.lifetime = torch.cat([self.lifetime, torch.zeros(1, device=device)])
self.n_neurons += 1
self.births += 1
return 1
def resize(self, target_neurons: int):
current = self.n_neurons.item()
if target_neurons == current:
return
device = self.weights.device
if target_neurons > current:
extra = target_neurons - current
new_w = torch.randn(extra, self.input_dim, device=device) * 0.02
new_b = torch.zeros(extra, device=device)
self.weights = nn.Parameter(torch.cat([self.weights.data, new_w], dim=0))
self.bias = nn.Parameter(torch.cat([self.bias.data, new_b]))
old_gate = self.temporal_gate
self.temporal_gate = nn.Linear(self.input_dim, target_neurons).to(device)
with torch.no_grad():
self.temporal_gate.weight[:current] = old_gate.weight
self.temporal_gate.bias[:current] = old_gate.bias
self.usage = torch.cat([self.usage, torch.ones(extra, device=device)])
self.lifetime = torch.cat([self.lifetime, torch.zeros(extra, device=device)])
else:
keep_indices = torch.argsort(self.usage, descending=True)[:target_neurons]
self.weights = nn.Parameter(self.weights.data[keep_indices])
self.bias = nn.Parameter(self.bias.data[keep_indices])
old_gate = self.temporal_gate
self.temporal_gate = nn.Linear(self.input_dim, target_neurons).to(device)
with torch.no_grad():
self.temporal_gate.weight[:] = old_gate.weight[keep_indices]
self.temporal_gate.bias[:] = old_gate.bias[keep_indices]
self.usage = self.usage[keep_indices]
self.lifetime = self.lifetime[keep_indices]
self.n_neurons.fill_(target_neurons)
def get_stats(self) -> Dict:
return {'total_neurons': self.n_neurons.item(), 'total_births': self.births.item(),
'total_deaths': self.deaths.item(), 'avg_usage': self.usage[:self.n_neurons.item()].mean().item(),
'max_lifetime': self.lifetime[:self.n_neurons.item()].max().item()}
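# --- Hedged usage sketch (illustrative only; helper name is hypothetical,
# and it assumes `neurogenesis_enabled` is set in the config). The forward
# pass tracks per-neuron usage; maybe_grow() adds a neuron when coherence
# clears the birth threshold and capacity remains.
def _demo_neurogenesis():
    config = NutataModelConfig()
    neuro = VideoNeurogenesis(config.d_model, n_neurons=64, config=config)
    x = torch.randn(2, 16, config.d_model)
    _ = neuro(x)
    born = neuro.maybe_grow(coherence=0.99)  # 1 if a neuron was added
    return born, neuro.get_stats()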
# ==============================================================================
# TEMPORAL COHERENCE & FLOW
# ==============================================================================
class TemporalCoherenceModule(nn.Module):
def __init__(self, config: NutataModelConfig):
super().__init__()
d = config.d_model
self.diff_predictor = nn.Sequential(nn.Linear(d * 2, d), nn.SiLU(), nn.Linear(d, d))
self.smooth = nn.Conv1d(d, d, kernel_size=3, padding=1, groups=d)
self.alpha = nn.Parameter(torch.tensor(0.2))
self.register_buffer('coherence_history', torch.zeros(100))
self.register_buffer('history_idx', torch.tensor(0))
def forward(self, z_seq: torch.Tensor) -> Tuple[torch.Tensor, float]:
B, T, D = z_seq.shape
if T > 1:
diffs = z_seq[:, 1:] - z_seq[:, :-1]
pairs = torch.cat([z_seq[:, :-1], z_seq[:, 1:]], dim=-1)
pred_diffs = self.diff_predictor(pairs)
coherence = 1 - F.mse_loss(pred_diffs, diffs).item()
coherence = max(0, min(1, coherence))
else:
coherence = 1.0
z_t = z_seq.transpose(1, 2)
smoothed = self.smooth(z_t).transpose(1, 2)
alpha = torch.sigmoid(self.alpha)
output = (1 - alpha) * z_seq + alpha * smoothed
idx = self.history_idx.item() % 100
self.coherence_history[idx] = coherence
self.history_idx += 1
return output, coherence
def get_average_coherence(self) -> float:
valid = min(self.history_idx.item(), 100)
return self.coherence_history[:valid].mean().item() if valid > 0 else 0.0
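# --- Hedged usage sketch (illustrative only; helper name is hypothetical).
# The module scores how predictable frame-to-frame latent deltas are and
# blends in a depthwise temporal smoothing of the sequence.
def _demo_temporal_coherence():
    config = NutataModelConfig()
    tcm = TemporalCoherenceModule(config)
    z = torch.randn(1, 16, config.d_model)
    z_smooth, coherence = tcm(z)  # coherence lies in [0, 1]
    return z_smooth.shape, coherence, tcm.get_average_coherence()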
class FlowPredictionModule(nn.Module):
def __init__(self, config: NutataModelConfig):
super().__init__()
d = config.d_model
self.flow_encoder = nn.Sequential(nn.Linear(d * 2, d), nn.SiLU(), nn.Linear(d, d // 2), nn.SiLU(), nn.Linear(d // 2, d))
self.warp_net = nn.Sequential(nn.Linear(d * 2, d), nn.Tanh())
self.motion_magnitude = nn.Sequential(nn.Linear(d, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid())
def forward(self, z_seq: torch.Tensor) -> Dict:
B, T, D = z_seq.shape
if T < 2:
return {'flow': None, 'warped': z_seq, 'motion_magnitude': torch.zeros(B, 1, device=z_seq.device),
'flow_loss': torch.tensor(0.0, device=z_seq.device)}
z_t = z_seq[:, :-1]
z_t1 = z_seq[:, 1:]
pairs = torch.cat([z_t, z_t1], dim=-1)
flow = self.flow_encoder(pairs)
warp_input = torch.cat([z_t, flow], dim=-1)
warped = self.warp_net(warp_input)
motion = self.motion_magnitude(flow.mean(dim=1))
flow_loss = F.mse_loss(warped, z_t1)
return {'flow': flow, 'warped': warped, 'motion_magnitude': motion, 'flow_loss': flow_loss}
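# --- Hedged usage sketch (illustrative only; helper name is hypothetical).
# Latent flow: predict a flow code between consecutive latents, warp frame t
# toward t+1, and score the warp with an MSE flow loss.
def _demo_flow_prediction():
    config = NutataModelConfig()
    flow = FlowPredictionModule(config)
    z = torch.randn(1, 16, config.d_model)
    out = flow(z)
    return out['warped'].shape, out['flow_loss'].item()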
class VideoEnergySystem(nn.Module):
def __init__(self, config: NutataModelConfig):
super().__init__()
self.config = config
self.register_buffer(\'energy\', torch.tensor(1.0))
self.register_buffer(\'consumed\', torch.tensor(0.0))
self.costs = {'encode': config.energy_cost_encode, 'decode': config.energy_cost_decode,
'predict': config.energy_cost_predict, 'process': 0.01, 'memory': 0.005, 'attention': 0.008}
def consume(self, operation: str, amount: float = None) -> bool:
cost = amount if amount is not None else self.costs.get(operation, 0.01)
if self.energy.item() >= cost:
self.energy -= cost
self.consumed += cost
return True
return False
def regenerate(self):
regen = min(self.config.energy_regeneration, 1.0 - self.energy.item())
self.energy += regen
def reset(self):
self.energy.fill_(1.0)
def get_stats(self) -> Dict:
return {'energy': self.energy.item(), 'consumed': self.consumed.item()}
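# --- Hedged usage sketch (illustrative only; helper name is hypothetical).
# The energy budget gates operations: consume() debits a per-op cost and
# returns False once the budget is exhausted; regenerate() refills it.
def _demo_energy_system():
    config = NutataModelConfig()
    energy = VideoEnergySystem(config)
    ok = energy.consume('encode')  # True while budget remains
    energy.regenerate()
    return ok, energy.get_stats()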
# ==============================================================================
# MAIN MODEL
# ==============================================================================
class NutataModel(nn.Module):
def __init__(self, config: NutataModelConfig = None, expert_counts: List[int] = None):
super().__init__()
self.config = config or NutataModelConfig()
if expert_counts is None:
expert_counts = [len(self.config.expert_types)] * self.config.n_layers
self.encoder = VideoVAEEncoder(self.config)
self.decoder = VideoVAEDecoder(self.config)
self.lpol = VideoLPOL(self.config) if self.config.use_lpol else None
self.layers = nn.ModuleList([
VideoEARCPLayer(self.config, i, n_experts=expert_counts[i])
for i in range(self.config.n_layers)
])
self.neurogenesis = VideoNeurogenesis(self.config.d_model, 64, self.config)
self.neuro_proj = nn.Linear(self.config.max_neurons, self.config.d_model)
self.temporal_coherence = TemporalCoherenceModule(self.config)
self.flow_module = FlowPredictionModule(self.config) if self.config.flow_prediction else None
self.energy = VideoEnergySystem(self.config)
self.frame_predictor = nn.Sequential(
nn.Linear(self.config.d_model, self.config.d_model), nn.SiLU(),
nn.Dropout(self.config.dropout), nn.Linear(self.config.d_model, self.config.latent_channels * 8 * 8)
)
def encode(self, video: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
self.energy.consume('encode')
return self.encoder(video)
def decode(self, z_spatial: torch.Tensor) -> torch.Tensor:
self.energy.consume('decode')
return self.decoder(z_spatial)
def process_temporal(self, z: torch.Tensor) -> Dict:
self.energy.consume('process')
lpol_info = {}
if self.lpol is not None:
z, lpol_info = self.lpol(z)
coherences = []
total_growth = 0
for layer in self.layers:
z, coh, grew = layer(z)
coherences.append(coh)
if grew:
total_growth += 1
neuro_out = self.neurogenesis(z)
current_neurons = neuro_out.shape[-1]
if current_neurons < self.config.max_neurons:
padding = torch.zeros(*neuro_out.shape[:-1], self.config.max_neurons - current_neurons, device=neuro_out.device, dtype=neuro_out.dtype)
neuro_out_padded = torch.cat([neuro_out, padding], dim=-1)
else:
neuro_out_padded = neuro_out[..., :self.config.max_neurons]
neuro_proj = self.neuro_proj(neuro_out_padded)
z = z + 0.1 * neuro_proj.mean(dim=-1, keepdim=True).expand_as(z)
avg_coherence = sum(coherences) / len(coherences) if coherences else 0.0
neuro_growth = self.neurogenesis.maybe_grow(avg_coherence)
z, temp_coherence = self.temporal_coherence(z)
flow_info = {}
if self.flow_module is not None:
flow_info = self.flow_module(z)
self.energy.regenerate()
return {'z': z, 'coherence': avg_coherence, 'temporal_coherence': temp_coherence,
'expert_growth': total_growth, 'neuro_growth': neuro_growth, 'lpol_info': lpol_info,
'flow_info': flow_info, 'energy': self.energy.get_stats()}
def forward(self, video: torch.Tensor) -> Dict:
B = video.shape[0]
z, mu, logvar, z_spatial = self.encode(video)
proc_out = self.process_temporal(z)
z_processed = proc_out['z']
recon = self.decode(z_spatial)
if video.shape[2] == self.config.video.channels:
video_compare = video.permute(0, 2, 1, 3, 4)
else:
video_compare = video
if recon.shape != video_compare.shape:
recon = F.interpolate(recon, size=video_compare.shape[2:], mode='trilinear', align_corners=False)
recon_loss = F.mse_loss(recon, video_compare)
kl_loss = 0.5 * torch.mean(mu.pow(2) + logvar.exp() - logvar - 1)
coherence_loss = float(1.0 - proc_out['temporal_coherence'])  # plain float: logged and added as a constant penalty (non-differentiable)
flow_loss_tensor = torch.tensor(0.0, device=recon.device)
if 'flow_info' in proc_out and proc_out['flow_info'].get('flow_loss') is not None:
flow_loss_tensor = proc_out['flow_info']['flow_loss']
total_loss = recon_loss + self.config.kl_weight * kl_loss + self.config.temporal_coherence_weight * coherence_loss + 0.01 * flow_loss_tensor
flow_loss = flow_loss_tensor.item() if hasattr(flow_loss_tensor, 'item') else float(flow_loss_tensor)
return {'loss': total_loss, 'recon_loss': recon_loss, 'kl_loss': kl_loss, 'coherence_loss': coherence_loss,
'flow_loss': flow_loss, 'recon': recon, 'z': z, 'z_spatial': z_spatial,
'coherence': proc_out['coherence'], 'temporal_coherence': proc_out['temporal_coherence'],
'neurogenesis': proc_out['neuro_growth'], 'expert_growth': proc_out['expert_growth'], 'energy': proc_out['energy']}
def generate(self, n_frames: int = 16, z_init: torch.Tensor = None, temperature: float = 1.0, batch_size: int = 1) -> torch.Tensor:
self.eval()
device = next(self.parameters()).device
if z_init is None:
B = batch_size
T = max(1, n_frames // 4)
H = self.config.video.height // 8
W = self.config.video.width // 8
z_init = torch.randn(B, self.config.latent_channels, T, H, W, device=device) * temperature
with torch.no_grad():
video = self.decode(z_init)
if video.shape[2] != n_frames:
video = F.interpolate(video, size=(n_frames, video.shape[3], video.shape[4]), mode='trilinear', align_corners=False)
return video.clamp(0, 1)
def count_params(self) -> int:
return sum(p.numel() for p in self.parameters())
def diagnostics(self) -> Dict:
total_experts = sum(layer.get_expert_count() for layer in self.layers)
return {\'model_version\': self.config.version, \'total_params\': self.count_params(),
\'total_experts\': total_experts, \'expert_counts\': [layer.get_expert_count() for layer in self.layers],
\'neurogenesis\': self.neurogenesis.get_stats(), \'energy\': self.energy.get_stats()}
@classmethod
def from_pretrained(cls, pretrained_path: str, device: str = None, **kwargs) -> "NutataModel":
from huggingface_hub import snapshot_download
if os.path.isdir(pretrained_path):
load_dir = pretrained_path
else:
load_dir = snapshot_download(repo_id=pretrained_path)
config_path = os.path.join(load_dir, "config.json")
dynamic_state = None
if os.path.exists(config_path):
with open(config_path, 'r') as f:
config_dict = json.load(f)
dynamic_state = config_dict.pop('_dynamic_state', None)
config_dict.pop('_class_name', None)
config_dict.pop('_version', None)
config_dict.pop('_architecture', None)
config_dict.update(kwargs)
config = NutataModelConfig.from_dict(config_dict)
else:
config = NutataModelConfig(**kwargs)
expert_counts = None
neuron_count = 64
neuro_proj_in = None # Will use max_neurons by default
adaptive_proj_in = None
adaptive_proj_out = None
if dynamic_state is not None:
expert_counts = dynamic_state.get('expert_counts')
neuron_count = dynamic_state.get('neuron_count', 64)
neuro_proj_in = dynamic_state.get('neuro_proj_in', neuron_count)
adaptive_proj_in = dynamic_state.get('adaptive_proj_in')
adaptive_proj_out = dynamic_state.get('adaptive_proj_out')
model = cls(config, expert_counts=expert_counts)
if neuron_count != model.neurogenesis.n_neurons.item():
model.neurogenesis.resize(neuron_count)
# For backward compatibility with old checkpoints that had smaller neuro_proj
if neuro_proj_in and neuro_proj_in != model.neuro_proj.in_features:
model.neuro_proj = nn.Linear(neuro_proj_in, config.d_model)
if adaptive_proj_in and adaptive_proj_out:
model.encoder.adaptive_proj = nn.Linear(adaptive_proj_in, adaptive_proj_out)
model_path = os.path.join(load_dir, "pytorch_model.bin")
if os.path.exists(model_path):
state_dict = torch.load(model_path, map_location='cpu')
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing:
logger.warning(f"Missing keys: {len(missing)}")
if unexpected:
logger.warning(f"Unexpected keys: {len(unexpected)}")
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
return model.to(device)
'''
return code
def _generate_readme(self) -> str:
"""Generate README.md for HuggingFace Hub"""
diag = self.diagnostics()
readme = f'''---
license: other
license_name: ame-web-studio-proprietary
license_link: LICENSE
library_name: pytorch
tags:
- video
- video-generation
- video-understanding
- cognitive-architecture
- world-model
- nexus
- earcp
- lpol
- neurogenesis
- 3d-vae
- gqa
language:
- en
pipeline_tag: video-classification
---
# NUTATA v{self.config.version}
<p align="center">
<img src="https://img.shields.io/badge/NUTATA-blue?style=for-the-badge" alt="NUTATA"/>
<img src="https://img.shields.io/badge/Version-{self.config.version}-green?style=for-the-badge" alt="Version"/>
<img src="https://img.shields.io/badge/License-Proprietary-red?style=for-the-badge" alt="License"/>
</p>
## 🧠 Cognitive Video World Model
**NUTATA** is a cognitive architecture for video understanding and generation, developed by **Mike Amega (Logo)** at **Ame Web Studio**.
It combines several novel components into a unified framework for learning video representations with cognitive capabilities.
## 🏗️ Architecture Overview
```
┌──────────────────────────────────────────────────────────────────────────┐
│                         NUTATA v{self.config.version}                    │
├──────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│   ┌──────────────┐    ┌──────────────────┐    ┌──────────────────┐      │
│   │    3D VAE    │    │    Cognitive     │    │      3D VAE      │      │
│   │   Encoder    │───▶│    Processor     │───▶│     Decoder      │      │
│   │   (Causal)   │    │ (EARCP+LPOL+GQA) │    │                  │      │
│   └──────────────┘    └──────────────────┘    └──────────────────┘      │
│                                                                          │
│   ┌──────────────────────────────────────────────────────────────┐      │
│   │                    Cognitive Components                      │      │
│   │  ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌────────┐  │      │
│   │  │  LPOL   │ │  EARCP  │ │   GQA   │ │ Neuro-  │ │  Flow  │  │      │
│   │  │ Memory  │ │ Experts │ │  Attn   │ │ genesis │ │  Pred. │  │      │
│   │  └─────────┘ └─────────┘ └─────────┘ └─────────┘ └────────┘  │      │
│   └──────────────────────────────────────────────────────────────┘      │
│                                                                          │
└──────────────────────────────────────────────────────────────────────────┘
```
### Key Components
| Component | Description |
|-----------|-------------|
| **3D Causal VAE** | Spatiotemporal encoding with causal convolutions for video compression |
| **VideoLPOL** | 9-domain Long-term Procedural & Operational Learning memory system |
| **VideoEARCP** | Ensemble Auto-Regulated by Coherence & Performance with dynamic experts |
| **VideoGQA** | Grouped Query Attention with RoPE for efficient long-sequence processing |
| **Neurogenesis** | Dynamic neural capacity adaptation based on task complexity |
| **Flow Prediction** | Latent optical flow for temporal coherence and motion understanding |
| **Energy System** | Cognitive load management and resource allocation |
### LPOL Memory Domains
1. **Motion** - Movement patterns and trajectories
2. **Appearance** - Visual features and textures
3. **Temporal** - Time dynamics and sequences
4. **Spatial** - Spatial relationships and layouts
5. **Object** - Object persistence and tracking
6. **Scene** - Scene context and semantics
7. **Action** - Action recognition patterns
8. **Causality** - Cause-effect relationships
9. **Physics** - Physical dynamics understanding
## 📊 Model Specifications
| Specification | Value |
|--------------|-------|
| **Parameters** | {diag["total_params"]:,} |
| **Architecture** | {self.config.codename} |
| **Input Shape** | `[B, {self.config.video.channels}, {self.config.video.n_frames}, {self.config.video.height}, {self.config.video.width}]` |
| **Latent Dim** | {self.config.d_model} |
| **EARCP Layers** | {self.config.n_layers} |
| **Total Experts** | {diag["total_experts"]} |
| **Expert Counts** | {diag["expert_counts"]} |
| **GQA Heads** | {self.config.gqa_num_heads} ({self.config.gqa_num_kv_groups} KV groups) |
| **Memory Domains** | {len(self.config.domain_types)} |
| **Neurons** | {diag["neurogenesis"]["total_neurons"]} |
## 🚀 Quick Start
### Installation
```bash
pip install torch torchvision huggingface_hub tqdm
```
### Loading the Model
```python
from nutata_model import NutataModel
# Load from HuggingFace Hub
model = NutataModel.from_pretrained("{self.config.hub_model_id}")
# Or load locally
model = NutataModel.from_pretrained("./path/to/model")
```
### Video Reconstruction
```python
import torch
# Create sample video: [B, C, T, H, W]
video = torch.randn(1, 3, 16, 64, 64).cuda()
# Forward pass
model.eval()
with torch.no_grad():
outputs = model(video)
print(f"Reconstruction shape: {{outputs['recon'].shape}}")
print(f"Loss: {{outputs['loss'].item():.4f}}")
print(f"EARCP Coherence: {{outputs['coherence']:.3f}}")
print(f"Temporal Coherence: {{outputs['temporal_coherence']:.3f}}")
```
### Video Generation
```python
# Generate video from scratch
model.eval()
with torch.no_grad():
generated = model.generate(n_frames=16, temperature=0.8)
print(f"Generated shape: {{generated.shape}}")
# Generate from initial latent
z_init = torch.randn(1, {self.config.latent_channels}, 4, 8, 8).cuda()
with torch.no_grad():
generated = model.generate(n_frames=16, z_init=z_init)
```
### Future Frame Prediction
```python
# Encode context frames
video_context = torch.randn(1, 3, 8, 64, 64).cuda()
model.eval()
with torch.no_grad():
z, _, _, _ = model.encode(video_context)
# Predict future frames
z_future = model.predict_frames(z, n_frames=8)
print(f"Predicted latent shape: {{z_future.shape}}")
```
### Video Continuation
```python
# Continue an existing video
video_input = torch.randn(1, 3, 8, 64, 64).cuda()
model.eval()
with torch.no_grad():
continued = model.continue_video(video_input, n_frames=8)
print(f"Continued video shape: {{continued.shape}}") # [1, 3, 16, 64, 64]
```
### Fine-tuning
```python
from torch.optim import AdamW
# Load pretrained model
model = NutataModel.from_pretrained("{self.config.hub_model_id}")
model.train()
# Setup optimizer
optimizer = AdamW(model.parameters(), lr=1e-5)
# Fine-tune on your dataset
for batch in dataloader:
video = batch['video'].cuda()
outputs = model(video)
loss = outputs['loss']
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Save fine-tuned model
model.save_pretrained("./my_finetuned_model")
```
## 📁 Repository Structure
```
.
├── README.md              # This file
├── config.json            # Model configuration with dynamic state
├── pytorch_model.bin      # Model weights
├── architecture.json      # Architecture details
├── training_state.json    # Training state (for resuming)
├── nutata_model.py        # Model source code
├── requirements.txt       # Dependencies
└── LICENSE # Proprietary license
```
## 🔬 Technical Details
### Grouped Query Attention (GQA)
GQA reduces memory usage by sharing Key-Value heads across query heads (see the sketch after this list):
- **Query heads**: {self.config.gqa_num_heads}
- **KV groups**: {self.config.gqa_num_kv_groups}
- **Memory savings**: {int((1 - self.config.gqa_num_kv_groups / self.config.gqa_num_heads) * 100)}%
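A hedged sketch of the K/V sharing this implies (head counts and dimensions are interpolated from this config; the tensors are illustrative, not model weights):
```python
import torch
num_heads = {self.config.gqa_num_heads}
num_kv_groups = {self.config.gqa_num_kv_groups}
head_dim = {self.config.d_model // self.config.gqa_num_heads}
q = torch.randn(1, 16, num_heads, head_dim)       # one projection per query head
kv = torch.randn(1, 16, num_kv_groups, head_dim)  # K/V projected once per group
# each group is repeated across its query heads before attention
kv = kv.repeat_interleave(num_heads // num_kv_groups, dim=2)
assert kv.shape == q.shape
```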
### EARCP (Ensemble Auto-Regulated by Coherence & Performance)
Dynamic expert routing with automatic growth:
- Initial experts: {len(self.config.expert_types)}
- Maximum experts: {self.config.max_experts}
- Growth threshold: coherence < {self.config.growth_threshold_coherence}
### Neurogenesis
Adaptive neural capacity (a usage sketch follows this list):
- Initial neurons: 64
- Maximum neurons: {self.config.max_neurons}
- Birth threshold: coherence > {self.config.neuron_birth_threshold}
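A hedged sketch of how growth is driven at runtime (these attributes exist in `nutata_model.py`; the surrounding step is illustrative):
```python
# after a forward pass, each EARCP layer exposes a coherence score
avg_coherence = sum(layer.coherence_score for layer in model.layers) / len(model.layers)
born = model.neurogenesis.maybe_grow(avg_coherence)  # returns 1 if a neuron was added
print(model.neurogenesis.get_stats())
```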
### Checkpoint Compatibility
This model uses a specific architecture structure for checkpoint compatibility:
- `EncoderStage`/`DecoderStage` as `nn.ModuleList` (creates `encoders.0.0`, `encoders.0.1`, etc.)
- `temporal_coherence.smooth` as depthwise `Conv1d` (groups=d_model)
- `lpol.memory_attn` as `nn.ModuleDict` with q/k/v/o_proj
- `VideoExpert` with `fc1`/`fc2` Linear layers
- `attn_norm` as `LayerNorm` with `elementwise_affine=False`
## 📈 Training Metrics
The model was trained with the following target metrics:
- **PSNR**: > 25 dB
- **SSIM**: > 0.95
- **Temporal Coherence**: > 0.35
## 🔧 Model Diagnostics
```python
# Get comprehensive diagnostics
diag = model.diagnostics()
print(f"Total parameters: {{diag['total_params']:,}}")
print(f"Expert counts: {{diag['expert_counts']}}")
print(f"Neurogenesis stats: {{diag['neurogenesis']}}")
print(f"Energy stats: {{diag['energy']}}")
```
## ⚠️ License
This model is released under the **Ame Web Studio Proprietary License**.
**IMPORTANT**: This is NOT an open-source model. Usage is restricted to:
- Personal research and evaluation
- Non-commercial academic use
- Commercial use requires explicit licensing agreement
For licensing inquiries, contact: **contact@amewebstudio.com**
## 📚 Citation
If you use this model in your research, please cite:
```bibtex
@software{{nutata_model,
author = {{Amega, Mike (Logo)}},
title = {{NUTATA: Cognitive Video World Model}},
version = {{{self.config.version}}},
year = {{2026}},
publisher = {{Ame Web Studio}},
url = {{https://huggingface.co/{self.config.hub_model_id}}}
}}
```
## 🔗 Links
- **Author**: Mike Amega (Logo)
- **Organization**: Ame Web Studio
- **Website**: [amewebstudio.com](https://amewebstudio.com)
- **Related Projects**: NEXUS-Core, EARCP Library, LPOL Architecture
---
<p align="center">
<b>Built with 🧠 by Ame Web Studio</b><br>
<i>"Learning to Understand and Generate Video with Cognition"</i>
</p>
'''
return readme
def _generate_license(self) -> str:
"""Generate proprietary LICENSE file"""
return """AME WEB STUDIO PROPRIETARY LICENSE
Version 1.0, January 2026
Copyright (c) 2026 Mike Amega (Logo) - Ame Web Studio
All Rights Reserved.
================================================================================
NUTATA License
================================================================================
This software and associated documentation files (the "Software") are the
proprietary property of Ame Web Studio and Mike Amega ("Logo").
GRANT OF LICENSE
----------------
Subject to the terms of this license, you are granted a limited, non-exclusive,
non-transferable license to:
1. PERSONAL USE: Use the Software for personal, non-commercial research and
evaluation purposes.
2. ACADEMIC USE: Use the Software for non-commercial academic research,
provided that:
- Proper attribution is given in any publications
- The Software is not redistributed
- No commercial benefit is derived
RESTRICTIONS
------------
You may NOT:
1. Use the Software for any commercial purpose without a separate commercial
license agreement with Ame Web Studio.
2. Modify, adapt, translate, reverse engineer, decompile, disassemble, or
create derivative works based on the Software for commercial purposes.
3. Redistribute, sublicense, rent, lease, or lend the Software to any third
party.
4. Remove or alter any proprietary notices, labels, or marks on the Software.
5. Use the Software to train competing AI models for commercial deployment.
6. Claim ownership or authorship of the Software or its components.
COMMERCIAL LICENSING
--------------------
For commercial use, including but not limited to:
- Integration into commercial products or services
- Use in production environments
- Offering services based on the Software
Please contact: contact@amewebstudio.com
INTELLECTUAL PROPERTY
---------------------
The Software embodies proprietary algorithms and architectures including but
not limited to:
- NEXUS Cognitive Architecture
- EARCP (Ensemble Auto-Regulated by Coherence & Performance)
- LPOL (Long-term Procedural & Operational Learning)
- VideoGQA (Video Grouped Query Attention)
- Neurogenesis mechanisms
- Temporal Coherence algorithms
- Flow Prediction systems
These components are trade secrets and proprietary technology of Ame Web Studio.
WARRANTY DISCLAIMER
-------------------
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
LIMITATION OF LIABILITY
-----------------------
IN NO EVENT SHALL AME WEB STUDIO, MIKE AMEGA, OR ANY CONTRIBUTORS BE LIABLE
FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
TERMINATION
-----------
This license is effective until terminated. Your rights under this license
will terminate automatically without notice if you fail to comply with any
of its terms.
GOVERNING LAW
-------------
This license shall be governed by and construed in accordance with the laws
of Canada, without regard to its conflict of law provisions.
================================================================================
For licensing inquiries:
- Email: contact@amewebstudio.com
- Website: https://amewebstudio.com
================================================================================
"""
# ==============================================================================
# SECTION 16: DATASETS
# ==============================================================================
class SyntheticVideoDataset(Dataset):
"""Generate synthetic videos with moving shapes for testing"""
def __init__(
self,
n_videos: int = 1000,
n_frames: int = 16,
height: int = 64,
width: int = 64,
):
self.n_videos = n_videos
self.n_frames = n_frames
self.height = height
self.width = width
def __len__(self) -> int:
return self.n_videos
def __getitem__(self, idx: int) -> Dict:
np.random.seed(idx)
video = np.zeros((self.n_frames, self.height, self.width, 3), dtype=np.float32)
# Background gradient
bg = np.linspace(0, 0.3, self.height).reshape(-1, 1, 1)
video += bg
n_objects = np.random.randint(2, 5)
for obj in range(n_objects):
color = np.random.rand(3) * 0.7 + 0.3
x = np.random.rand() * (self.width - 15) + 5
y = np.random.rand() * (self.height - 15) + 5
vx = np.random.randn() * 2
vy = np.random.randn() * 2
size = np.random.randint(5, 12)
shape = np.random.choice(["circle", "square", "triangle"])
for t in range(self.n_frames):
x += vx
y += vy
# Bounce
if x < size or x > self.width - size:
vx = -vx * 0.9
x = np.clip(x, size, self.width - size)
if y < size or y > self.height - size:
vy = -vy * 0.9
y = np.clip(y, size, self.height - size)
ix, iy = int(x), int(y)
if shape == "circle":
for dx in range(-size, size + 1):
for dy in range(-size, size + 1):
if dx * dx + dy * dy <= size * size:
px, py = ix + dx, iy + dy
if 0 <= px < self.width and 0 <= py < self.height:
video[t, py, px] = color
elif shape == "square":
x1, x2 = max(0, ix - size // 2), min(self.width, ix + size // 2)
y1, y2 = max(0, iy - size // 2), min(self.height, iy + size // 2)
video[t, y1:y2, x1:x2] = color
else: # triangle
for dy in range(size):
width_at_y = (size - dy) // 2
for dx in range(-width_at_y, width_at_y + 1):
px, py = ix + dx, iy + dy - size // 2
if 0 <= px < self.width and 0 <= py < self.height:
video[t, py, px] = color
video = np.clip(video, 0, 1)
video = torch.tensor(video).permute(0, 3, 1, 2) # [T, C, H, W]
return {"video": video}
class VideoFolderDataset(Dataset):
"""Load videos from a folder (placeholder for real datasets)"""
def __init__(
self,
folder_path: str,
n_frames: int = 16,
height: int = 64,
width: int = 64,
transform=None,
):
self.folder_path = folder_path
self.n_frames = n_frames
self.height = height
self.width = width
self.transform = transform
# Find all video files
self.video_files = []
if os.path.exists(folder_path):
import glob
for ext in ["*.mp4", "*.avi", "*.mov", "*.mkv"]:
self.video_files.extend(glob.glob(os.path.join(folder_path, ext)))
if not self.video_files:
logger.warning(f"No video files found in {folder_path}")
def __len__(self) -> int:
return max(1, len(self.video_files))
def __getitem__(self, idx: int) -> Dict:
# Placeholder - real video decoding (e.g. torchvision.io) would go here
video = torch.randn(self.n_frames, 3, self.height, self.width)
video = (video - video.min()) / (video.max() - video.min() + 1e-8)
return {"video": video}
# ==============================================================================
# SECTION 17: TRAINING
# ==============================================================================
def train_video_model(
model: NutataModel,
dataset: Dataset,
config: NutataModelConfig,
save_path: str = "./nutata_model_v1_1.pt",
push_to_hub: bool = False,
hub_token: str = None,
):
"""
Train NUTATA v1.1
Args:
model: NutataModel instance
dataset: Training dataset
config: Model configuration
save_path: Path to save checkpoint
push_to_hub: Whether to push to HuggingFace Hub after training
hub_token: HuggingFace API token (required if push_to_hub=True)
"""
logger.info("\n" + "=" * 70)
logger.info("πŸš€ NUTATA v1.1 TRAINING")
logger.info("=" * 70)
dataloader = DataLoader(
dataset,
batch_size=config.batch_size,
shuffle=True,
num_workers=2,
drop_last=True,
pin_memory=True,
)
logger.info(f"πŸ“Š Dataset: {len(dataset)} videos")
logger.info(f"πŸ“Š Batches per epoch: {len(dataloader)}")
logger.info(f"πŸ“Š Batch size: {config.batch_size}")
logger.info(f"πŸ“Š Epochs: {config.epochs}")
optimizer = AdamW(
model.parameters(),
lr=config.learning_rate,
weight_decay=0.01,
betas=(0.9, 0.95),
)
# Cosine annealing with warmup
total_steps = config.epochs * len(dataloader)
def lr_lambda(step):
if step < config.warmup_steps:
return step / max(1, config.warmup_steps)
progress = (step - config.warmup_steps) / max(1, total_steps - config.warmup_steps)
return 0.5 * (1 + math.cos(math.pi * progress))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
scaler = GradScaler()
model.train()
best_loss = float("inf")
global_step = 0
for epoch in range(config.epochs):
epoch_metrics = {
"loss": 0,
"recon": 0,
"kl": 0,
"coherence": 0,
"psnr": 0,
"ssim": 0,
}
total_coherence = 0
total_neuro = 0
total_experts = 0
pbar = tqdm(
enumerate(dataloader),
total=len(dataloader),
desc=f"Epoch {epoch + 1}/{config.epochs}",
ncols=120,
)
for batch_idx, batch in pbar:
video = batch["video"].to(device)
# Rearrange to [B, C, T, H, W]
if video.dim() == 5 and video.shape[1] != 3:
video = video.permute(0, 2, 1, 3, 4)
optimizer.zero_grad()
with autocast():
outputs = model(video)
loss = outputs["loss"]
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)
scaler.update()
scheduler.step()
global_step += 1
# Metrics
with torch.no_grad():
video_comp = (
video if video.shape[1] == 3 else video.permute(0, 2, 1, 3, 4)
)
psnr = compute_psnr(outputs["recon"], video_comp)
ssim = compute_ssim(outputs["recon"], video_comp)
epoch_metrics["loss"] += outputs["loss"].item()
epoch_metrics["recon"] += outputs["recon_loss"].item()
epoch_metrics["kl"] += outputs["kl_loss"].item()
coh_loss = outputs["coherence_loss"]
epoch_metrics["coherence"] += (
float(coh_loss) if not hasattr(coh_loss, "item") else coh_loss
)
epoch_metrics["psnr"] += psnr
epoch_metrics["ssim"] += ssim
total_coherence += outputs["coherence"]
total_neuro += outputs["neurogenesis"]
total_experts += outputs["expert_growth"]
if batch_idx % 5 == 0:
n = batch_idx + 1
pbar.set_postfix(
{
"loss": f"{epoch_metrics['loss'] / n:.4f}",
"psnr": f"{epoch_metrics['psnr'] / n:.1f}",
"ssim": f"{epoch_metrics['ssim'] / n:.3f}",
"coh": f"{total_coherence / n:.3f}",
"lr": f"{scheduler.get_last_lr()[0]:.6f}",
}
)
n_batches = len(dataloader)
for k in epoch_metrics:
epoch_metrics[k] /= n_batches
logger.info(
f"Epoch {epoch + 1}/{config.epochs} | "
f"Loss: {epoch_metrics['loss']:.4f} | "
f"PSNR: {epoch_metrics['psnr']:.2f} dB | "
f"SSIM: {epoch_metrics['ssim']:.4f} | "
f"Coherence: {total_coherence / n_batches:.3f} | "
f"Neuro: +{total_neuro} | "
f"Experts: +{total_experts}"
)
if epoch_metrics["loss"] < best_loss:
best_loss = epoch_metrics["loss"]
# Save checkpoint
checkpoint = {
"model": model.state_dict(),
"config": config.to_dict(),
"epoch": epoch,
"loss": best_loss,
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"metrics": epoch_metrics,
}
torch.save(checkpoint, save_path)
logger.info(f" πŸ’Ύ Best model saved (loss: {best_loss:.4f})")
# Also save to directory for Hub
save_dir = save_path.replace(".pt", "_best")
model.save_pretrained(save_dir)
logger.info("\nπŸŽ‰ Training complete!")
logger.info(f"πŸ“Š Best loss: {best_loss:.4f}")
logger.info(f"πŸ“Š Final PSNR: {epoch_metrics['psnr']:.2f} dB")
logger.info(f"πŸ“Š Final SSIM: {epoch_metrics['ssim']:.4f}")
# Save final model
save_dir = save_path.replace(".pt", "_final")
model.save_pretrained(save_dir)
# Push to HuggingFace Hub if requested
if push_to_hub and config.push_to_hub:
if hub_token is None:
hub_token = os.environ.get("HF_TOKEN")
if hub_token is None:
logger.warning(
"⚠️ No HuggingFace token provided. Set HF_TOKEN environment variable or pass hub_token argument."
)
if hub_token:
try:
url = model.push_to_hub(
repo_id=config.hub_model_id,
token=hub_token,
private=True,
commit_message=f"Upload NUTATA v{config.version} - Loss: {best_loss:.4f}, PSNR: {epoch_metrics['psnr']:.2f}",
save_directory=save_dir,
)
logger.info(f"πŸš€ Model pushed to HuggingFace Hub: {url}")
except Exception as e:
logger.error(f"❌ Failed to push to HuggingFace Hub: {e}")
import traceback
traceback.print_exc()
else:
logger.warning("⚠️ Skipping HuggingFace Hub push - no token available")
return model
# ==============================================================================
# SECTION 18: UTILITY FUNCTIONS
# ==============================================================================
def load_nutata_model(
pretrained_path: str = "amewebstudio/nutata-videomodel-v1.1", device: str = None
) -> NutataModel:
"""
Convenience function to load NUTATA.
Args:
pretrained_path: HuggingFace repo ID or local path
device: Device to load to (auto-detected if None)
Returns:
Loaded NutataModel
Example:
>>> model = load_nutata_model()
>>> video = torch.randn(1, 3, 16, 64, 64).cuda()
>>> outputs = model(video)
"""
return NutataModel.from_pretrained(pretrained_path, device=device)
def create_nutata_model(
d_model: int = 512,
n_layers: int = 4,
n_frames: int = 16,
height: int = 64,
width: int = 64,
device: str = None,
**kwargs,
) -> NutataModel:
"""
Create a new NUTATA with custom configuration.
Args:
d_model: Model dimension
n_layers: Number of EARCP layers
n_frames: Number of video frames
height: Frame height
width: Frame width
device: Device to create on
**kwargs: Additional config parameters
Returns:
New NutataModel instance
Example:
>>> model = create_nutata_model(d_model=768, n_layers=6)
"""
video_config = VideoConfig(height=height, width=width, n_frames=n_frames)
config = NutataModelConfig(
d_model=d_model, n_layers=n_layers, video=video_config, **kwargs
)
model = NutataModel(config)
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
return model.to(device)
def benchmark_model(model: NutataModel, n_iterations: int = 10) -> Dict:
"""
Benchmark model performance
Args:
model: NutataModel instance
n_iterations: Number of iterations for benchmarking
Returns:
Dict with timing results
"""
import time
model.eval()
device = next(model.parameters()).device
B = 1
T = model.config.video.n_frames
H = model.config.video.height
W = model.config.video.width
C = model.config.video.channels
video = torch.randn(B, C, T, H, W, device=device)
# Warmup
with torch.no_grad():
for _ in range(3):
_ = model(video)
if device.type == "cuda":
torch.cuda.synchronize()
# Benchmark forward pass
forward_times = []
for _ in range(n_iterations):
start = time.time()
with torch.no_grad():
_ = model(video)
if device.type == "cuda":
torch.cuda.synchronize()
forward_times.append(time.time() - start)
# Benchmark generation
gen_times = []
for _ in range(n_iterations):
start = time.time()
with torch.no_grad():
_ = model.generate(n_frames=T)
if device.type == "cuda":
torch.cuda.synchronize()
gen_times.append(time.time() - start)
return {
"forward_mean": np.mean(forward_times),
"forward_std": np.std(forward_times),
"forward_fps": 1.0 / np.mean(forward_times),
"generation_mean": np.mean(gen_times),
"generation_std": np.std(gen_times),
"generation_fps": 1.0 / np.mean(gen_times),
"device": str(device),
}
# ==============================================================================
# SECTION 19: MAIN
# ==============================================================================
def main():
print("\n" + "=" * 70)
print("""
╔═══════════════════════════════════════════════════════════════════╗
║                                                                   ║
║                      NUTATA v1.1 COMPLETE                         ║
║                   Cognitive Video World Model                     ║
║                                                                   ║
║   • 3D VAE (Causal Spatiotemporal)                                ║
║   • EARCP with Temporal Attention                                 ║
║   • LPOL Memory (9 Video Domains)                                 ║
║   • GQA for Long Sequences                                        ║
║   • Neurogenesis                                                  ║
║   • Temporal Coherence Module                                     ║
║   • Flow Prediction                                               ║
║   • HuggingFace Hub Integration                                   ║
║                                                                   ║
║   Author: Mike Amega - Ame Web Studio                             ║
║   Version: 1.1 (Complete with Architecture Fixes)                 ║
║                                                                   ║
╚═══════════════════════════════════════════════════════════════════╝
""")
print("=" * 70 + "\n")
# =========================================================================
# HUGGINGFACE HUB CONFIGURATION
# =========================================================================
HF_TOKEN = os.environ.get("HF_TOKEN", "")
HF_REPO_ID = "amewebstudio/nutata-v1.0-finetuned"
# Set token in environment for huggingface_hub library
if HF_TOKEN:
os.environ["HF_TOKEN"] = HF_TOKEN
logger.info(f"πŸ”‘ HuggingFace Token: Configured")
logger.info(f"πŸ“¦ Target Repository: {HF_REPO_ID}")
# Create config
config = NutataModelConfig(
epochs=20,
batch_size=4,
n_layers=4,
d_model=512,
learning_rate=1e-4,
warmup_steps=100,
push_to_hub=True,
hub_model_id=HF_REPO_ID,
)
# Device info
logger.info(f"\nπŸ”§ Device: {device}")
if torch.cuda.is_available():
logger.info(f"πŸ”§ GPU: {torch.cuda.get_device_name(0)}")
logger.info(
f"πŸ”§ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB"
)
# Create model
logger.info("\nπŸ“¦ Creating model...")
model = NutataModel(config)
model = model.to(device)
# Diagnostics
logger.info("\nπŸ“Š Model Diagnostics:")
diag = model.diagnostics()
for k, v in diag.items():
logger.info(f" {k}: {v}")
# Create synthetic dataset
logger.info("\nπŸ“š Creating synthetic video dataset...")
dataset = SyntheticVideoDataset(
n_videos=500,
n_frames=config.video.n_frames,
height=config.video.height,
width=config.video.width,
)
logger.info(f" Videos: {len(dataset)}")
logger.info(
f" Shape: [{config.video.n_frames}, {config.video.channels}, {config.video.height}, {config.video.width}]"
)
# Quick forward pass test
logger.info("\nπŸ§ͺ Testing forward pass...")
test_batch = dataset[0]["video"].unsqueeze(0).to(device)
if test_batch.shape[1] != 3:
test_batch = test_batch.permute(0, 2, 1, 3, 4)
with torch.no_grad():
test_out = model(test_batch)
logger.info(f" βœ… Forward pass successful!")
logger.info(f" Input shape: {test_batch.shape}")
logger.info(f" Recon shape: {test_out['recon'].shape}")
logger.info(f" Loss: {test_out['loss'].item():.4f}")
logger.info(f" Coherence: {test_out['coherence']:.3f}")
# Train
logger.info("\n" + "=" * 70)
model = train_video_model(
model,
dataset,
config,
save_path="./nutata_model_v1_1.pt",
push_to_hub=True,
hub_token=HF_TOKEN,
)
# Final diagnostics
logger.info("\nπŸ“Š Final Diagnostics:")
diag = model.diagnostics()
for k, v in diag.items():
logger.info(f" {k}: {v}")
# Test generation
logger.info("\n🎬 Testing video generation...")
model.eval()
with torch.no_grad():
generated = model.generate(n_frames=16, temperature=0.8)
logger.info(f" Generated shape: {generated.shape}")
logger.info(
f" Value range: [{generated.min().item():.3f}, {generated.max().item():.3f}]"
)
# Test video continuation
logger.info("\n🎬 Testing video continuation...")
with torch.no_grad():
continued = model.continue_video(test_batch, n_frames=8)
logger.info(f" Input shape: {test_batch.shape}")
logger.info(f" Continued shape: {continued.shape}")
# Benchmark
logger.info("\n⏱️ Benchmarking model...")
bench = benchmark_model(model, n_iterations=5)
for k, v in bench.items():
if isinstance(v, float):
logger.info(f" {k}: {v:.4f}")
else:
logger.info(f" {k}: {v}")
# Demonstrate save/load
logger.info("\nπŸ’Ύ Testing save/load functionality...")
save_dir = "./nutata_model_test_save"
model.save_pretrained(save_dir)
# Test loading
logger.info("\nπŸ“₯ Testing model loading...")
loaded_model = NutataModel.from_pretrained(save_dir, device=str(device))
logger.info(f" βœ… Model loaded successfully!")
logger.info(f" Loaded params: {loaded_model.count_params():,}")
# Verify loaded model works
with torch.no_grad():
loaded_out = loaded_model(test_batch)
logger.info(f" Loaded model loss: {loaded_out['loss'].item():.4f}")
logger.info("\n" + "=" * 70)
logger.info("βœ… ALL TESTS PASSED!")
logger.info("=" * 70)
return model, config
# ==============================================================================
# ENTRY POINT
# ==============================================================================
if __name__ == "__main__":
model, config = main()