| |
| """ |
| ================================================================================ |
| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ |
| β β |
| β ββββ ββββββββββββββ ββββββ βββββββββββ β |
| β βββββ ββββββββββββββββββββββ βββββββββββ β |
| β ββββββ βββββββββ ββββββ βββ βββββββββββ β |
| β ββββββββββββββββ ββββββ βββ βββββββββββ β |
| β βββ ββββββββββββββββββ ββββββββββββββββββββ β |
| β βββ ββββββββββββββββ βββ βββββββ ββββββββ β |
| β β |
| β βββ βββββββββββββ ββββββββ βββββββ ββββ ββββ βββββββ βββββββ βββββββββββ |
| β βββ ββββββββββββββββββββββββββββββββββββ βββββββββββββββββββββββββββββββββ |
| β βββ βββββββββ βββββββββ βββ βββββββββββββββββ ββββββ βββββββββ βββ |
| β ββββ ββββββββββ βββββββββ βββ βββββββββββββββββ ββββββ βββββββββ βββ |
| β βββββββ βββββββββββββββββββββββββββββββ βββ ββββββββββββββββββββββββββββββββ |
| β βββββ ββββββββββ ββββββββ βββββββ βββ βββ βββββββ βββββββ ββββββββββββ |
| β β |
| β NUTATA v1.1 COMPLETE β |
| β "Learning to Understand and Generate Video with Cognition" β |
| β β |
| β Full NEXUS Cognitive Architecture for Video: β |
| β β’ Causal 3D VAE (Spatial + Temporal compression) β |
| β β’ EARCP with Temporal Attention β |
| β β’ LPOL Memory (Video-specific domains) β |
| β β’ GQA for efficient long-sequence processing β |
| β β’ Neurogenesis for adaptive capacity β |
| β β’ Temporal Coherence Module β |
| β β’ Frame Prediction & Generation β |
| β β’ Flow Prediction (NEW in v1.1) β |
| β β’ Hierarchical Memory (NEW in v1.1) β |
| β β |
| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ |
| ================================================================================ |
| |
| Author: Mike Amega (Logo) - Ame Web Studio |
| License: Proprietary (Ame Web Studio) |
| Version: 1.1 - Complete with Architecture Fixes |
| |
| CRITICAL ARCHITECTURE FIXES (for checkpoint compatibility): |
| - EncoderStage/DecoderStage as nn.ModuleList (creates encoders.0.0, encoders.0.1, encoders.0.2) |
| - temporal_coherence.smooth as depthwise Conv1d (groups=d_model) |
| - LPOL memory_attn as nn.ModuleDict with q_proj, k_proj, v_proj, o_proj |
| - VideoExpert with fc1/fc2 Linear layers (not SwiGLU) |
| - attn_norm as nn.LayerNorm with elementwise_affine=False |
| - All projection layers with bias=False where needed |
| |
| Changes from v1.0: |
| - Fixed RoPE dimension mismatch in VideoGQA |
| - Fixed LPOL memory attention mechanism |
| - Added FlashAttention-style efficient attention |
| - Added Flow Prediction module |
| - Improved temporal coherence |
| - Better gradient flow with residual scaling |
| - Added video quality metrics (PSNR, SSIM) |
| |
| ================================================================================ |
| """ |
|
|
| import os |
| import sys |
| import math |
| import json |
| import random |
| import logging |
| import shutil |
| import tempfile |
| from datetime import datetime |
| from typing import Dict, List, Optional, Tuple, Any, Union |
| from dataclasses import dataclass, field |
| from collections import deque |
| from enum import Enum |
|
|
| import numpy as np |
|
|
| |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
| from torch.cuda.amp import GradScaler, autocast |
| from torch.optim import AdamW |
| from tqdm import tqdm |
|
|
| |
| try: |
| from huggingface_hub import HfApi, hf_hub_download, snapshot_download, create_repo |
|
|
| HF_HUB_AVAILABLE = True |
| except ImportError: |
| HF_HUB_AVAILABLE = False |
| print( |
| "Warning: huggingface_hub not installed. Install with: pip install huggingface_hub" |
| ) |
|
|
| import warnings |
|
|
| warnings.filterwarnings("ignore") |
|
|
| |
| logging.basicConfig( |
| level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s" |
| ) |
| logger = logging.getLogger(__name__) |
|
|
| |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
|
| |
| |
| |
|
|
|
|
@dataclass
class VideoConfig:
    """Video-specific configuration - v2.0 with high-resolution support"""

    height: int = 256
    width: int = 256
    channels: int = 3
    n_frames: int = 16
    fps: int = 8
    temporal_downsample: int = 4
    spatial_downsample: int = 8

    @classmethod
    def from_dict(cls, d: dict) -> "VideoConfig":
        """Build a VideoConfig from a dict, silently dropping unknown keys."""
        known = set(cls.__dataclass_fields__.keys())
        filtered = {key: val for key, val in d.items() if key in known}
        return cls(**filtered)
|
|
|
|
@dataclass
class NutataModelConfig:
    """
    NUTATA v2.0 Configuration - HIGH RESOLUTION + CONDITIONING

    Major improvements:
    - Progressive resolution training (64 → 512)
    - Perceptual loss (VGG + LPIPS)
    - Multi-level conditioning (text/action/scene)
    - Hierarchical temporal memory
    - Optical flow consistency
    - Anti-blur sharpening
    """

    # --- Model identity / versioning ---
    model_type: str = "nutata-videomodel"
    version: str = "2.0"
    codename: str = "VideoSim-Cognitive-HD"
    architecture_type: str = "cognitive-video-conditioned"

    # --- Core transformer dimensions ---
    d_model: int = 512
    d_ff: int = 2048
    n_layers: int = 8
    n_heads: int = 8
    dropout: float = 0.1

    # --- Latent space sizes ---
    latent_dim: int = 256
    temporal_latent_dim: int = 512
    latent_channels: int = 4

    # --- Temporal window lengths (in frames) ---
    max_frames: int = 64
    context_frames: int = 16
    prediction_frames: int = 8

    # --- VAE encoder/decoder channel plan ---
    encoder_channels: List[int] = field(default_factory=lambda: [64, 128, 256, 512])
    decoder_channels: List[int] = field(default_factory=lambda: [512, 256, 128, 64])
    kl_weight: float = 0.0001
    use_skip_connections: bool = True

    # --- Progressive-resolution curriculum (pixel sizes per stage) ---
    progressive_resolution: List[int] = field(
        default_factory=lambda: [64, 128, 256, 384, 512]
    )
    current_resolution_idx: int = 0
    resolution_warmup_epochs: int = 5

    # --- LPOL domain memory ---
    use_lpol: bool = True
    memory_size: int = 256
    memory_slots_per_domain: int = 32
    memory_k: int = 8
    domain_types: List[str] = field(
        default_factory=lambda: [
            "motion",
            "appearance",
            "temporal",
            "spatial",
            "object",
            "scene",
            "action",
            "causality",
            "physics",
        ]
    )

    # --- Grouped Query Attention ---
    use_gqa: bool = True
    gqa_num_heads: int = 8
    gqa_num_kv_groups: int = 2

    # --- Expert pool (EARCP) and growth policy ---
    expert_types: List[str] = field(
        default_factory=lambda: [
            "Motion",
            "Appearance",
            "Temporal",
            "Spatial",
            "Prediction",
            "Generation",
        ]
    )
    max_experts: int = 12
    growth_threshold_coherence: float = 0.3
    growth_patience: int = 10

    # --- Neurogenesis (adaptive capacity) bounds and thresholds ---
    neurogenesis_enabled: bool = True
    min_neurons: int = 32
    max_neurons: int = 256
    neuron_birth_threshold: float = 0.8
    neuron_death_threshold: float = 0.05

    # --- Energy budget per operation ---
    energy_enabled: bool = True
    energy_cost_encode: float = 0.01
    energy_cost_decode: float = 0.02
    energy_cost_predict: float = 0.03
    energy_regeneration: float = 0.05

    # --- Dream/replay cycle (in steps) ---
    dream_enabled: bool = True
    dream_cycle_length: int = 100
    dream_duration: int = 20

    # --- Temporal coherence / flow objectives ---
    temporal_coherence_weight: float = 0.1
    flow_prediction: bool = True

    # --- Perceptual (VGG) loss with adaptive warm-up weighting ---
    use_perceptual_loss: bool = True
    perceptual_loss_weight: float = 0.1
    perceptual_adaptive: bool = True
    perceptual_max_weight: float = 0.2
    perceptual_warmup_epochs: int = 10
    lpips_weight: float = 0.05
    vgg_layers: List[str] = field(
        default_factory=lambda: ["relu1_2", "relu2_2", "relu3_4", "relu4_4"]
    )

    # --- Optical-flow consistency loss ---
    use_optical_flow: bool = True
    optical_flow_weight: float = 0.05

    # --- Multi-modal conditioning (text/action/scene) ---
    use_conditioning: bool = True
    text_condition_dim: int = 768
    num_action_classes: int = 400
    num_scene_classes: int = 365
    condition_injection_levels: List[str] = field(
        default_factory=lambda: ["latent", "decoder", "temporal", "memory"]
    )

    # --- Hierarchical (frame/clip/scene) memory ---
    use_hierarchical_memory: bool = True
    memory_scales: List[int] = field(default_factory=lambda: [1, 4, 16])
    frame_memory_slots: int = 64
    clip_memory_slots: int = 32
    scene_memory_slots: int = 16

    # --- Multi-scale temporal convolutions (kernel sizes in frames) ---
    use_multiscale_temporal: bool = True
    temporal_kernel_sizes: List[int] = field(default_factory=lambda: [3, 7, 15])

    # --- Anti-blur sharpening ---
    use_sharpening: bool = True
    sharpening_weight: float = 0.1

    # --- Training hyperparameters ---
    batch_size: int = 2
    learning_rate: float = 1e-4
    epochs: int = 50
    gradient_accumulation: int = 8
    warmup_steps: int = 500
    use_gradient_checkpointing: bool = True

    # --- Hugging Face Hub publishing ---
    push_to_hub: bool = True
    hub_model_id: str = "amewebstudio/nutata-videomodel-v2.0"

    # --- Nested video I/O configuration ---
    video: VideoConfig = field(default_factory=VideoConfig)

    def to_dict(self) -> Dict:
        """Serialize to a plain dict.

        The nested VideoConfig is flattened to a dict; any attribute that is
        not a JSON-friendly primitive/container is silently skipped.
        """
        d = {}
        for key, value in self.__dict__.items():
            if key == "video":
                d[key] = {k: v for k, v in value.__dict__.items()}
            elif isinstance(value, (list, dict, str, int, float, bool, type(None))):
                d[key] = value
        return d

    @classmethod
    def from_dict(cls, d: Dict) -> "NutataModelConfig":
        """Rebuild a config from `to_dict()` output; unknown keys are dropped."""
        d = d.copy()
        if "video" in d and isinstance(d["video"], dict):
            d["video"] = VideoConfig.from_dict(d["video"])
        # Strip bookkeeping keys that serialized checkpoints may carry.
        for k in ["_dynamic_state", "_class_name", "_version", "_architecture"]:
            d.pop(k, None)
        known = set(cls.__dataclass_fields__.keys())
        return cls(**{k: v for k, v in d.items() if k in known})
|
|
|
|
| |
| |
| |
|
|
|
|
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (no mean-centering, no bias)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Per-channel learnable gain, initialized to identity.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale each feature vector by the reciprocal of its root-mean-square.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x * inv_rms)
|
|
|
|
class LayerNormCustom(nn.Module):
    """Layer normalization over the last dimension with an optional bias term."""

    def __init__(self, dim: int, eps: float = 1e-6, bias: bool = True):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        # Bias is optional; None disables the additive term entirely.
        self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mu = x.mean(dim=-1, keepdim=True)
        # Biased (population) variance, matching standard LayerNorm.
        sigma2 = x.var(dim=-1, keepdim=True, unbiased=False)
        normed = (x - mu) / torch.sqrt(sigma2 + self.eps)
        scaled = normed * self.weight
        return scaled if self.bias is None else scaled + self.bias
|
|
|
|
class SwiGLU(nn.Module):
    """SwiGLU feed-forward: w3(silu(w1(x)) * w2(x))."""

    def __init__(
        self, in_features: int, hidden_features: int = None, out_features: int = None
    ):
        super().__init__()
        # Defaults: 4x expansion, same output width as input.
        hidden_features = hidden_features or in_features * 4
        out_features = out_features or in_features

        self.w1 = nn.Linear(in_features, hidden_features, bias=False)
        self.w2 = nn.Linear(in_features, hidden_features, bias=False)
        self.w3 = nn.Linear(hidden_features, out_features, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)
|
|
|
|
class RotaryPositionalEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) with a lazily-grown cos/sin cache."""

    def __init__(self, dim: int, max_seq_len: int = 1024, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Inverse frequencies for the even channel indices: base^(-2i/dim).
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (base**exponents), persistent=False)

        # Precompute tables up to the declared maximum length.
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        """(Re)build cos/sin tables covering positions [0, seq_len)."""
        positions = torch.arange(
            seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        freqs = torch.einsum("i,j->ij", positions, self.inv_freq)
        # Duplicate so the table spans the full `dim` channels.
        angles = torch.cat([freqs, freqs], dim=-1)

        self.register_buffer("cos_cache", angles.cos(), persistent=False)
        self.register_buffer("sin_cache", angles.sin(), persistent=False)

    def forward(
        self, seq_len: int, device: torch.device
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin) tables of shape [seq_len, dim] on `device`."""
        # Grow the cache on demand for sequences longer than the current table.
        if seq_len > self.cos_cache.shape[0]:
            self._build_cache(seq_len)

        cos = self.cos_cache[:seq_len].to(device)
        sin = self.sin_cache[:seq_len].to(device)
        return cos, sin
|
|
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate the last dimension by half: (a, b) -> (-b, a)."""
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat((-second, first), dim=-1)
|
|
|
|
def apply_rotary_pos_emb(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary positional embedding to Q and K.

    Args:
        q: [B, H, T, D] query tensor
        k: [B, H, T, D] key tensor
        cos: [T, D] cosine embeddings
        sin: [T, D] sine embeddings

    Returns:
        q_embed, k_embed with rotary position encoding applied
    """

    def _rotated(t: torch.Tensor) -> torch.Tensor:
        # Half-rotation on the last dim: (a, b) -> (-b, a).
        a, b = torch.chunk(t, 2, dim=-1)
        return torch.cat((-b, a), dim=-1)

    # Broadcast the [T, D] tables over the batch and head dimensions.
    cos_b = cos[None, None, :, :]
    sin_b = sin[None, None, :, :]

    q_embed = q * cos_b + _rotated(q) * sin_b
    k_embed = k * cos_b + _rotated(k) * sin_b
    return q_embed, k_embed
|
|
|
|
| |
| |
| |
|
|
|
|
class VideoGQA(nn.Module):
    """
    Grouped Query Attention for Video sequences - CHECKPOINT COMPATIBLE

    Key features:
    - Correct RoPE application with proper dimension handling
    - Efficient memory usage with proper KV expansion
    - bias=False on all projections to match checkpoint
    """

    def __init__(self, config: NutataModelConfig):
        super().__init__()
        self.d_model = config.d_model
        self.num_heads = config.gqa_num_heads
        self.num_kv_groups = config.gqa_num_kv_groups
        self.head_dim = config.d_model // self.num_heads
        self.heads_per_group = self.num_heads // self.num_kv_groups
        self.scale = self.head_dim**-0.5

        # Q gets all heads; K/V only get num_kv_groups heads (shared per group).
        # bias=False everywhere to match the checkpoint layout.
        self.q_proj = nn.Linear(
            config.d_model, self.num_heads * self.head_dim, bias=False
        )
        self.k_proj = nn.Linear(
            config.d_model, self.num_kv_groups * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            config.d_model, self.num_kv_groups * self.head_dim, bias=False
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, config.d_model, bias=False
        )

        # RoPE table sized for twice the maximum frame count.
        self.rope = RotaryPositionalEmbedding(
            self.head_dim, max_seq_len=config.max_frames * 2
        )

        self.dropout = nn.Dropout(config.dropout)

        # Kept for checkpoint compatibility; not applied inside forward().
        self.residual_scale = nn.Parameter(torch.ones(1) * 0.1)

    def forward(
        self, x: torch.Tensor, causal: bool = True, use_rope: bool = True
    ) -> torch.Tensor:
        """
        Args:
            x: [B, T, D] input tensor
            causal: Whether to apply causal masking
            use_rope: Whether to apply rotary position embeddings

        Returns:
            output: [B, T, D]
        """
        bsz, seq, _ = x.shape

        # Project and move heads to dim 1: [B, heads, T, head_dim].
        q = self.q_proj(x).view(bsz, seq, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(bsz, seq, self.num_kv_groups, self.head_dim)
        v = self.v_proj(x).view(bsz, seq, self.num_kv_groups, self.head_dim)
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))

        if use_rope:
            cos, sin = self.rope(seq, x.device)
            cos = cos[None, None]
            sin = sin[None, None]
            q = q * cos + rotate_half(q) * sin
            k = k * cos + rotate_half(k) * sin

        # Expand the grouped KV heads so each query head has a matching pair.
        k = k.repeat_interleave(self.heads_per_group, dim=1)
        v = v.repeat_interleave(self.heads_per_group, dim=1)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if causal:
            # Mask out attention to future timesteps.
            future = torch.triu(
                torch.ones(seq, seq, dtype=torch.bool, device=x.device), diagonal=1
            )
            scores = scores.masked_fill(future, float("-inf"))

        # Softmax in fp32 for numerical stability, then cast back.
        probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
        probs = self.dropout(probs)

        ctx = torch.matmul(probs, v)
        ctx = ctx.transpose(1, 2).contiguous().view(bsz, seq, -1)

        return self.o_proj(ctx)
|
|
|
|
class MultiHeadAttention(nn.Module):
    """Standard Multi-Head Attention for memory operations"""

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.scale = self.head_dim**-0.5

        # bias-free projections for Q, K, V and the output merge.
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Cross/self attention. q: [B, Tq, D], k/v: [B, Tk, D] -> [B, Tq, D]."""
        bsz, len_q = q.shape[0], q.shape[1]
        len_k = k.shape[1]

        def _split_heads(t: torch.Tensor, length: int, proj: nn.Linear) -> torch.Tensor:
            # [B, T, D] -> [B, H, T, head_dim]
            return proj(t).view(bsz, length, self.n_heads, self.head_dim).transpose(1, 2)

        qh = _split_heads(q, len_q, self.q_proj)
        kh = _split_heads(k, len_k, self.k_proj)
        vh = _split_heads(v, len_k, self.v_proj)

        scores = qh @ kh.transpose(-2, -1) * self.scale
        if mask is not None:
            scores = scores.masked_fill(mask, float("-inf"))

        weights = self.dropout(F.softmax(scores, dim=-1))

        merged = (weights @ vh).transpose(1, 2).contiguous().view(bsz, len_q, -1)
        return self.o_proj(merged)
|
|
|
|
| |
| |
| |
|
|
|
|
class PerceptualLossModule(nn.Module):
    """
    Multi-scale perceptual loss using VGG19 features.
    Computes feature-level differences for sharper, more perceptually accurate reconstructions.

    Features:
    - Uses pretrained VGG19 (frozen)
    - Extracts features at multiple layers
    - Adaptive weighting based on training epoch
    - Stops the VGG forward pass at the deepest requested layer (no wasted compute)
    """

    # ImageNet normalization statistics expected by torchvision's VGG19.
    VGG_MEAN = [0.485, 0.456, 0.406]
    VGG_STD = [0.229, 0.224, 0.225]

    def __init__(self, config: NutataModelConfig):
        super().__init__()
        self.config = config
        # Names of the VGG activations used for the loss (subset of layer_indices).
        self.layers = config.vgg_layers

        # VGG is loaded lazily on first forward() call (see _load_vgg).
        self.vgg = None
        # Index of each ReLU activation inside torchvision's vgg19().features.
        self.layer_indices = {
            "relu1_2": 4,
            "relu2_2": 9,
            "relu3_4": 18,
            "relu4_4": 27,
            "relu5_4": 36,
        }

        # Per-layer contribution weights (deepest layer down-weighted).
        self.layer_weights = {
            "relu1_2": 1.0,
            "relu2_2": 1.0,
            "relu3_4": 1.0,
            "relu4_4": 1.0,
            "relu5_4": 0.5,
        }

        # Updated externally by the trainer for adaptive weighting.
        self.current_epoch = 0

    def _load_vgg(self, device):
        """Lazy load VGG19 to avoid issues in offline environments"""
        if self.vgg is not None:
            return

        try:
            from torchvision import models

            vgg19 = models.vgg19(weights="IMAGENET1K_V1").features
            vgg19.eval()
            # Freeze: the loss network is never trained.
            for param in vgg19.parameters():
                param.requires_grad = False
            self.vgg = vgg19.to(device)
            logger.info("π¦ VGG19 loaded for perceptual loss")
        except Exception as e:
            # Best-effort: perceptual loss degrades to zero if VGG is unavailable.
            logger.warning(f"β οΈ Could not load VGG19: {e}. Perceptual loss disabled.")
            self.vgg = None

    def _normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize for VGG (ImageNet stats)"""
        mean = torch.tensor(self.VGG_MEAN, device=x.device).view(1, 3, 1, 1)
        std = torch.tensor(self.VGG_STD, device=x.device).view(1, 3, 1, 1)
        return (x - mean) / std

    def get_adaptive_weight(self, epoch: int = None) -> float:
        """Compute adaptive weight: min(max_weight, epoch / warmup * max_weight)"""
        if not self.config.perceptual_adaptive:
            return self.config.perceptual_loss_weight

        epoch = epoch if epoch is not None else self.current_epoch
        warmup = self.config.perceptual_warmup_epochs
        max_w = self.config.perceptual_max_weight
        return min(max_w, (epoch / warmup) * max_w) if warmup > 0 else max_w

    def forward(
        self, pred: torch.Tensor, target: torch.Tensor, epoch: int = None
    ) -> torch.Tensor:
        """
        Compute perceptual loss between predicted and target videos.

        Args:
            pred: [B, C, T, H, W] predicted video
            target: [B, C, T, H, W] target video
            epoch: current training epoch (for adaptive weighting)

        Returns:
            Perceptual loss (scalar tensor)
        """
        self._load_vgg(pred.device)

        if self.vgg is None:
            return torch.tensor(0.0, device=pred.device)

        B, C, T, H, W = pred.shape

        # Fold the time axis into the batch: VGG processes individual frames.
        pred_frames = pred.transpose(1, 2).reshape(B * T, C, H, W)
        target_frames = target.transpose(1, 2).reshape(B * T, C, H, W)

        # Very small frames: upsample so the VGG conv stack has enough resolution.
        if H < 64 or W < 64:
            pred_frames = F.interpolate(
                pred_frames, size=(224, 224), mode="bilinear", align_corners=False
            )
            target_frames = F.interpolate(
                target_frames, size=(224, 224), mode="bilinear", align_corners=False
            )

        pred_norm = self._normalize(pred_frames)
        target_norm = self._normalize(target_frames)

        # Map feature index -> layer name, restricted to the requested layers.
        wanted = {
            idx: name
            for name, idx in self.layer_indices.items()
            if name in self.layers
        }
        if not wanted:
            return torch.tensor(0.0, device=pred.device)
        last_idx = max(wanted)

        loss = 0.0
        pred_feat = pred_norm
        target_feat = target_norm

        for i, layer in enumerate(self.vgg):
            pred_feat = layer(pred_feat)
            target_feat = layer(target_feat)

            name = wanted.get(i)
            if name is not None:
                weight = self.layer_weights.get(name, 1.0)
                loss = loss + weight * F.l1_loss(pred_feat, target_feat)

            # FIX: previously the loop ran through ALL VGG layers (37 of them)
            # even though the deepest requested activation is much earlier.
            if i >= last_idx:
                break

        return loss * self.get_adaptive_weight(epoch)
|
|
|
|
class ConditionalEncoder(nn.Module):
    """
    Multi-modal conditioning system for controllable video generation.

    Supports:
    - Text → video (CLIP-compatible embeddings)
    - Action → movement (action class embeddings)
    - Scene → style (scene descriptor embeddings)

    Injects conditions at multiple levels as per user request:
    - Latent space
    - Decoder
    - Temporal processor
    - Memory
    """

    def __init__(self, config: NutataModelConfig):
        super().__init__()
        self.config = config
        d = config.d_model

        # Projects external text embeddings (e.g. CLIP, width text_condition_dim)
        # into the model dimension.
        self.text_proj = nn.Sequential(
            nn.Linear(config.text_condition_dim, d), nn.SiLU(), nn.Linear(d, d)
        )

        # Learned embedding table + projection for discrete action classes.
        self.action_embed = nn.Embedding(config.num_action_classes, d)
        self.action_proj = nn.Linear(d, d)

        # Learned embedding table + projection for discrete scene classes.
        self.scene_embed = nn.Embedding(config.num_scene_classes, d)
        self.scene_proj = nn.Linear(d, d)

        # Cross-attention used in inject(): latents attend to the condition.
        self.cross_attn = MultiHeadAttention(d, 8, dropout=config.dropout)

        # Fuses the three per-modality vectors ([text|action|scene]) into one.
        self.condition_fusion = nn.Sequential(
            nn.Linear(d * 3, d), nn.SiLU(), nn.Dropout(config.dropout), nn.Linear(d, d)
        )

        # Sigmoid gate controlling how strongly the condition modifies z.
        self.gate = nn.Sequential(nn.Linear(d * 2, d), nn.Sigmoid())

        # One projection per injection level so each level can specialize.
        self.level_projs = nn.ModuleDict(
            {
                "latent": nn.Linear(d, d),
                "decoder": nn.Linear(d, d),
                "temporal": nn.Linear(d, d),
                "memory": nn.Linear(d, d),
            }
        )

    def encode_conditions(
        self,
        text_emb: torch.Tensor = None,
        action_id: torch.Tensor = None,
        scene_id: torch.Tensor = None,
        batch_size: int = 1,
        device: torch.device = None,
    ) -> torch.Tensor:
        """
        Encode all conditions into a unified conditioning vector.

        Args:
            text_emb: [B, text_dim] or [B, T, text_dim] text embeddings (e.g., from CLIP)
            action_id: [B] action class indices
            scene_id: [B] scene class indices
            batch_size: batch size for zero-initialization if no conditions
            device: device for tensors

        Returns:
            condition: [B, D] unified condition vector
        """
        # Infer the device from the first provided condition; fall back to CPU.
        device = device or (
            text_emb.device
            if text_emb is not None
            else action_id.device
            if action_id is not None
            else scene_id.device
            if scene_id is not None
            else "cpu"
        )

        # Missing modalities contribute zero vectors (unconditional behavior).
        text_cond = torch.zeros(batch_size, self.config.d_model, device=device)
        action_cond = torch.zeros(batch_size, self.config.d_model, device=device)
        scene_cond = torch.zeros(batch_size, self.config.d_model, device=device)

        if text_emb is not None:
            # Sequence-shaped text embeddings are mean-pooled over tokens.
            if text_emb.dim() == 3:
                text_emb = text_emb.mean(dim=1)
            text_cond = self.text_proj(text_emb)

        if action_id is not None:
            action_emb = self.action_embed(action_id)
            action_cond = self.action_proj(action_emb)

        if scene_id is not None:
            scene_emb = self.scene_embed(scene_id)
            scene_cond = self.scene_proj(scene_emb)

        # Concatenate the three modality vectors and fuse to a single [B, D].
        combined = torch.cat([text_cond, action_cond, scene_cond], dim=-1)
        condition = self.condition_fusion(combined)

        return condition

    def inject(
        self, z: torch.Tensor, condition: torch.Tensor, level: str = "latent"
    ) -> torch.Tensor:
        """
        Inject condition into latent representation at specified level.

        Args:
            z: [B, T, D] latent sequence
            condition: [B, D] condition vector
            level: one of 'latent', 'decoder', 'temporal', 'memory'

        Returns:
            z_conditioned: [B, T, D] conditioned latent
        """
        B, T, D = z.shape

        # Unknown levels fall back to the 'latent' projection.
        proj = (
            self.level_projs[level]
            if level in self.level_projs
            else self.level_projs["latent"]
        )
        cond_proj = proj(condition)

        # Broadcast the condition over every timestep: [B, D] -> [B, T, D].
        cond_expanded = cond_proj.unsqueeze(1).expand(-1, T, -1)

        # Latents attend to the (repeated) condition as keys/values.
        z_attended = self.cross_attn(z, cond_expanded, cond_expanded)

        # Gated residual update: the gate decides per-feature how much
        # conditioning signal to mix in.
        concat = torch.cat([z, z_attended], dim=-1)
        gate = self.gate(concat)

        z_conditioned = z + gate * z_attended

        return z_conditioned

    def forward(
        self,
        z: torch.Tensor,
        text_emb: torch.Tensor = None,
        action_id: torch.Tensor = None,
        scene_id: torch.Tensor = None,
        level: str = "latent",
    ) -> torch.Tensor:
        """
        Full forward: encode conditions and inject into z.
        """
        B = z.shape[0]
        condition = self.encode_conditions(text_emb, action_id, scene_id, B, z.device)
        return self.inject(z, condition, level)
|
|
|
|
class HierarchicalTemporalMemory(nn.Module):
    """
    3-level hierarchical memory for long-term temporal coherence.

    Levels:
    - Frame level: Fine details, short-term (1 frame)
    - Clip level: Medium-term patterns (4-8 frames)
    - Scene level: Long-term context (16+ frames)

    Each level maintains a memory bank that is updated and queried
    to provide multi-scale temporal context.
    """

    def __init__(self, config: NutataModelConfig):
        super().__init__()
        self.config = config
        d = config.d_model

        # Temporal pooling factors per level (e.g. [1, 4, 16]).
        self.scales = config.memory_scales

        # Learnable memory banks, one per level, small-init (std 0.02).
        self.frame_memory = nn.Parameter(
            torch.randn(config.frame_memory_slots, d) * 0.02
        )
        self.clip_memory = nn.Parameter(torch.randn(config.clip_memory_slots, d) * 0.02)
        self.scene_memory = nn.Parameter(
            torch.randn(config.scene_memory_slots, d) * 0.02
        )

        # NOTE(review): these fixed pools are not used by forward() below,
        # which pools with F.adaptive_avg_pool1d instead — presumably kept
        # for checkpoint compatibility; confirm before removing.
        self.clip_pool = nn.AvgPool1d(kernel_size=4, stride=4, padding=0)
        self.scene_pool = nn.AvgPool1d(kernel_size=16, stride=16, padding=0)

        # Per-level cross-attention: the sequence queries its memory bank.
        self.frame_attn = MultiHeadAttention(d, 8, dropout=config.dropout)
        self.clip_attn = MultiHeadAttention(d, 8, dropout=config.dropout)
        self.scene_attn = MultiHeadAttention(d, 8, dropout=config.dropout)

        # Fuses the three per-level context streams back to width d.
        self.scale_fusion = nn.Sequential(
            nn.Linear(d * 3, d), nn.SiLU(), nn.Linear(d, d)
        )

        # Sigmoid gate controlling how much fused context is mixed into z.
        self.gate = nn.Sequential(nn.Linear(d * 2, d), nn.Sigmoid())

    def forward(self, z_seq: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
        """
        Process sequence through hierarchical memory.

        Args:
            z_seq: [B, T, D] latent sequence

        Returns:
            z_enhanced: [B, T, D] memory-enhanced sequence
            info: Dict with memory statistics
        """
        B, T, D = z_seq.shape

        # Frame level: every timestep attends to the frame memory bank.
        frame_mem = self.frame_memory.unsqueeze(0).expand(B, -1, -1)
        frame_ctx = self.frame_attn(z_seq, frame_mem, frame_mem)

        # Clip level: pool the sequence ~4x in time before querying memory.
        z_t = z_seq.transpose(1, 2)  # [B, D, T] for 1d pooling
        T_clip = max(1, T // 4)
        if T >= 4:
            z_clip = F.adaptive_avg_pool1d(z_t, T_clip).transpose(
                1, 2
            )
        else:
            # Too short to pool: collapse to a single averaged step.
            z_clip = z_seq.mean(dim=1, keepdim=True)

        clip_mem = self.clip_memory.unsqueeze(0).expand(B, -1, -1)
        clip_ctx_pooled = self.clip_attn(z_clip, clip_mem, clip_mem)

        # Upsample clip context back to the full temporal resolution T.
        clip_ctx = F.interpolate(
            clip_ctx_pooled.transpose(1, 2), size=T, mode="linear", align_corners=False
        ).transpose(1, 2)

        # Scene level: same scheme with ~16x temporal pooling.
        T_scene = max(1, T // 16)
        if T >= 16:
            z_scene = F.adaptive_avg_pool1d(z_t, T_scene).transpose(
                1, 2
            )
        else:
            z_scene = z_seq.mean(dim=1, keepdim=True)

        scene_mem = self.scene_memory.unsqueeze(0).expand(B, -1, -1)
        scene_ctx_pooled = self.scene_attn(
            z_scene, scene_mem, scene_mem
        )

        # Upsample scene context back to T as well.
        scene_ctx = F.interpolate(
            scene_ctx_pooled.transpose(1, 2), size=T, mode="linear", align_corners=False
        ).transpose(1, 2)

        # Fuse the three context streams: [B, T, 3D] -> [B, T, D].
        multi_scale = torch.cat([frame_ctx, clip_ctx, scene_ctx], dim=-1)
        fused = self.scale_fusion(multi_scale)

        # Gated residual: z keeps its content, memory context is mixed in.
        concat = torch.cat([z_seq, fused], dim=-1)
        gate = self.gate(concat)
        z_enhanced = z_seq + gate * fused

        # Diagnostics only; not used in the loss.
        info = {
            "frame_ctx_norm": frame_ctx.norm().item(),
            "clip_ctx_norm": clip_ctx.norm().item(),
            "scene_ctx_norm": scene_ctx.norm().item(),
        }

        return z_enhanced, info
|
|
|
|
class MultiScaleTemporalEncoder(nn.Module):
    """
    Extract features at multiple temporal scales.
    Captures both fast motion (3 frames) and slow motion (15 frames) patterns.
    """

    def __init__(self, channels: int, config: NutataModelConfig):
        super().__init__()
        self.config = config
        kernel_sizes = config.temporal_kernel_sizes

        # Three causal temporal convolutions with increasing receptive field
        # (purely temporal kernels: (k, 1, 1) touches no spatial context).
        self.fast_conv = CausalConv3d(channels, channels, (kernel_sizes[0], 1, 1))
        self.medium_conv = CausalConv3d(channels, channels, (kernel_sizes[1], 1, 1))
        self.slow_conv = CausalConv3d(channels, channels, (kernel_sizes[2], 1, 1))

        # 1x1x1 conv merges the three branches back to `channels`.
        self.fusion = nn.Conv3d(channels * 3, channels, kernel_size=1)

        # Learnable residual strength, small init for stable training.
        self.scale = nn.Parameter(torch.tensor(0.1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, C, T, H, W] video features

        Returns:
            x_enhanced: [B, C, T, H, W] multi-scale enhanced features
        """
        branches = [self.fast_conv(x), self.medium_conv(x), self.slow_conv(x)]

        # Defensive trim: align temporal lengths before concatenation.
        t_min = min(branch.shape[2] for branch in branches)
        branches = [branch[:, :, :t_min] for branch in branches]

        fused = self.fusion(torch.cat(branches, dim=1))

        # If frames were lost, re-pad at the front (past side) by replication.
        missing = x.shape[2] - fused.shape[2]
        if missing > 0:
            fused = F.pad(fused, (0, 0, 0, 0, missing, 0), mode="replicate")

        return x + self.scale * fused
|
|
|
|
class SharpnessEnhancer(nn.Module):
    """
    Learned sharpening module to counteract blur.
    Uses learned high-pass filtering to enhance edges and details.
    """

    def __init__(self, channels: int):
        super().__init__()

        # Learned detail extractor: depthwise 1x3x3 then dense 1x3x3
        # (no temporal mixing — each frame is sharpened independently).
        self.detail_conv = nn.Sequential(
            nn.Conv3d(
                channels, channels, (1, 3, 3), padding=(0, 1, 1), groups=channels
            ),
            nn.SiLU(),
            nn.Conv3d(channels, channels, (1, 3, 3), padding=(0, 1, 1)),
        )

        # Depthwise edge detector, initialized (below) to a Laplacian kernel.
        self.edge_conv = nn.Conv3d(
            channels,
            channels,
            (1, 3, 3),
            padding=(0, 1, 1),
            bias=False,
            groups=channels,
        )

        # Seed the edge filter with a scaled Laplacian high-pass kernel.
        with torch.no_grad():
            laplacian = torch.tensor(
                [[[0, -1, 0], [-1, 4, -1], [0, -1, 0]]], dtype=torch.float32
            )
            self.edge_conv.weight.data = (
                laplacian.unsqueeze(0).unsqueeze(0).repeat(channels, 1, 1, 1, 1) * 0.1
            )

        # Learnable blend strengths for the two enhancement streams.
        self.detail_scale = nn.Parameter(torch.tensor(0.1))
        self.edge_scale = nn.Parameter(torch.tensor(0.05))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply sharpening to video.

        Args:
            x: [B, C, T, H, W] video tensor

        Returns:
            x_sharp: [B, C, T, H, W] sharpened video
        """
        return (
            x
            + self.detail_scale * self.detail_conv(x)
            + self.edge_scale * self.edge_conv(x)
        )
|
|
|
|
| |
| |
| |
|
|
|
|
class CausalConv3d(nn.Module):
    """Causal 3D convolution - doesn't look into future frames"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int, int] = (3, 3, 3),
        stride: Tuple[int, int, int] = (1, 1, 1),
    ):
        super().__init__()
        # Normalize scalar arguments to 3-tuples (T, H, W).
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3
        if isinstance(stride, int):
            stride = (stride,) * 3

        self.kernel_size = kernel_size
        self.stride = stride
        # All temporal padding goes on the past side, making the conv causal.
        self.temporal_pad = kernel_size[0] - 1

        # Spatial dims get symmetric "same" padding; temporal gets none here
        # (it is applied manually in forward).
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=(0, kernel_size[1] // 2, kernel_size[2] // 2),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Replicate the first frame backwards in time so no future frame leaks in.
        if self.temporal_pad:
            x = F.pad(x, (0, 0, 0, 0, self.temporal_pad, 0), mode="replicate")
        return self.conv(x)
|
|
|
|
class ResBlock3D(nn.Module):
    """3D Residual block with spatial and temporal processing"""

    def __init__(self, channels: int, use_temporal: bool = True):
        super().__init__()
        self.use_temporal = use_temporal
        groups = min(32, channels)

        # Per-frame (spatial-only) residual branch.
        self.spatial = nn.Sequential(
            nn.GroupNorm(groups, channels),
            nn.SiLU(),
            nn.Conv3d(channels, channels, (1, 3, 3), padding=(0, 1, 1)),
            nn.GroupNorm(groups, channels),
            nn.SiLU(),
            nn.Conv3d(channels, channels, (1, 3, 3), padding=(0, 1, 1)),
        )

        # Optional causal temporal branch added on top of the spatial one.
        if use_temporal:
            self.temporal = nn.Sequential(
                nn.GroupNorm(groups, channels),
                nn.SiLU(),
                CausalConv3d(channels, channels, (3, 1, 1)),
            )

        # Small residual gain keeps the block close to identity at init.
        self.scale = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.spatial(x)
        if self.use_temporal:
            residual = residual + self.temporal(x)
        return x + self.scale * residual
|
|
|
|
class DownsampleBlock3D(nn.Module):
    """Downsample in space and optionally time"""

    def __init__(
        self, in_channels: int, out_channels: int, temporal_downsample: bool = True
    ):
        super().__init__()
        # Always halve H/W; halve T only when requested.
        if temporal_downsample:
            stride = (2, 2, 2)
        else:
            stride = (1, 2, 2)

        self.conv = nn.Conv3d(in_channels, out_channels, 3, stride=stride, padding=1)
        self.norm = nn.GroupNorm(min(32, out_channels), out_channels)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.conv(x)
        h = self.norm(h)
        return self.act(h)
|
|
|
|
class UpsampleBlock3D(nn.Module):
    """Upsample in space and optionally time"""

    def __init__(
        self, in_channels: int, out_channels: int, temporal_upsample: bool = True
    ):
        super().__init__()
        self.temporal_upsample = temporal_upsample

        self.conv = nn.Conv3d(in_channels, out_channels, 3, padding=1)
        self.norm = nn.GroupNorm(min(32, out_channels), out_channels)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Trilinear resize first, then a conv to clean up interpolation artifacts.
        factor = (2, 2, 2) if self.temporal_upsample else (1, 2, 2)
        up = F.interpolate(
            x, scale_factor=factor, mode="trilinear", align_corners=False
        )
        return self.act(self.norm(self.conv(up)))
|
|
|
|
class SpatialAttention2D(nn.Module):
    """2D spatial attention within each frame"""

    def __init__(self, channels: int, reduction: int = 8):
        super().__init__()
        reduced = max(channels // reduction, 8)
        self.query = nn.Conv2d(channels, reduced, 1)
        self.key = nn.Conv2d(channels, reduced, 1)
        self.value = nn.Conv2d(channels, channels, 1)
        # Zero-initialised gate: the module is an identity mapping at init.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape

        # Flatten spatial dims for batched matrix attention.
        q = self.query(x).flatten(2).permute(0, 2, 1)  # [B, HW, reduced]
        k = self.key(x).flatten(2)                     # [B, reduced, HW]
        v = self.value(x).flatten(2)                   # [B, C, HW]

        scores = torch.bmm(q, k) / math.sqrt(q.size(-1))
        attn = F.softmax(scores, dim=-1)
        out = torch.bmm(v, attn.permute(0, 2, 1)).view(B, C, H, W)

        # Gated residual: gamma starts at zero, so attention is blended in
        # gradually during training.
        return self.gamma * out + x
|
|
|
|
class SpatioTemporalBlock(nn.Module):
    """
    Spatio-temporal processing block - CHECKPOINT COMPATIBLE
    Creates structure: spatial.0 (GroupNorm), spatial.1 (SiLU), spatial.2 (Conv3d),
    temporal.0, temporal.1, temporal.2
    """

    def __init__(self, channels: int):
        super().__init__()
        # Residual gain; initialised to 1 (value fixed by existing checkpoints).
        self.scale = nn.Parameter(torch.ones(1))
        norm_groups = min(8, channels)

        # Per-frame spatial branch (attribute names fixed by checkpoints).
        self.spatial = nn.Sequential(
            nn.GroupNorm(norm_groups, channels),
            nn.SiLU(),
            nn.Conv3d(channels, channels, kernel_size=(1, 3, 3), padding=(0, 1, 1)),
        )

        # Causal temporal branch applied after the spatial branch.
        self.temporal = nn.Sequential(
            nn.GroupNorm(norm_groups, channels),
            nn.SiLU(),
            CausalConv3d(channels, channels, kernel_size=(3, 1, 1)),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual update: spatial filtering, then causal temporal mixing.
        return x + self.scale * self.temporal(self.spatial(x))
|
|
|
|
| |
| |
| |
|
|
|
|
class EncoderStage(nn.ModuleList):
    """
    Encoder stage as nn.ModuleList - CRITICAL FOR CHECKPOINT COMPATIBILITY

    This creates the exact key structure:
      - encoders.0.0 -> SpatioTemporalBlock
      - encoders.0.1 -> SpatioTemporalBlock
      - encoders.0.2 -> Conv3d (downsample)
    """

    def __init__(self, in_channels: int, out_channels: int, temporal_down: bool = True):
        # Stride 2 in time only when this stage also downsamples temporally.
        stride = (2, 2, 2) if temporal_down else (1, 2, 2)
        members = [
            SpatioTemporalBlock(in_channels),
            SpatioTemporalBlock(in_channels),
            nn.Conv3d(
                in_channels, out_channels, kernel_size=3, stride=stride, padding=1
            ),
        ]
        super().__init__(members)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Two residual blocks, then the strided downsampling convolution.
        for module in self:
            x = module(x)
        return x
|
|
|
|
class DecoderStage(nn.ModuleList):
    """
    Decoder stage as nn.ModuleList - CRITICAL FOR CHECKPOINT COMPATIBILITY

    This creates the exact key structure:
      - decoders.0.0 -> SpatioTemporalBlock
      - decoders.0.1 -> SpatioTemporalBlock
      - decoders.0.2 -> ConvTranspose3d (upsample)
    """

    def __init__(self, in_channels: int, out_channels: int, temporal_up: bool = True):
        # output_padding resolves the stride-2 output-size ambiguity of the
        # transposed convolution; time is only doubled when temporal_up is set.
        if temporal_up:
            stride, output_padding = (2, 2, 2), (1, 1, 1)
        else:
            stride, output_padding = (1, 2, 2), (0, 1, 1)
        members = [
            SpatioTemporalBlock(in_channels),
            SpatioTemporalBlock(in_channels),
            nn.ConvTranspose3d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                padding=1,
                output_padding=output_padding,
            ),
        ]
        super().__init__(members)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Two residual blocks, then the upsampling transposed convolution.
        for module in self:
            x = module(x)
        return x
|
|
|
|
| |
| |
| |
|
|
|
|
class VideoVAEEncoder(nn.Module):
    """
    3D VAE Encoder for video sequences - CHECKPOINT COMPATIBLE

    Compresses a video into a spatial latent volume plus a flattened
    per-timestep latent sequence projected to d_model.

    Creates exact structure:
      - input_conv
      - encoders (ModuleList of EncoderStage)
      - final_blocks (ModuleList of SpatioTemporalBlock)
      - to_mu, to_logvar
    """

    def __init__(self, config: NutataModelConfig):
        super().__init__()
        self.config = config
        channels = config.encoder_channels

        # Optional multi-scale temporal pre-encoder, enabled via config flag.
        self.temporal_encoder = (
            MultiScaleTemporalEncoder(config.video.channels, config)
            if hasattr(config, "use_multiscale_temporal")
            and config.use_multiscale_temporal
            else None
        )

        # Per-frame stem convolution (temporal kernel size 1).
        self.input_conv = nn.Conv3d(
            config.video.channels, channels[0], kernel_size=(1, 3, 3), padding=(0, 1, 1)
        )

        # Three stages: two downsample time+space, the last space only.
        self.encoders = nn.ModuleList(
            [
                EncoderStage(channels[0], channels[1], temporal_down=True),
                EncoderStage(channels[1], channels[2], temporal_down=True),
                EncoderStage(channels[2], channels[3], temporal_down=False),
            ]
        )

        self.final_blocks = nn.ModuleList(
            [SpatioTemporalBlock(channels[-1]), SpatioTemporalBlock(channels[-1])]
        )

        # VAE heads: 1x1x1 convs producing per-voxel mean and log-variance.
        self.to_mu = nn.Conv3d(channels[-1], config.latent_channels, 1)
        self.to_logvar = nn.Conv3d(channels[-1], config.latent_channels, 1)

        # NOTE(review): to_d_model is not used in forward() below — presumably
        # kept for checkpoint key compatibility; confirm before removing.
        self.to_d_model = nn.Linear(config.latent_channels, config.d_model)

        # Projection from the flattened spatial latent to d_model. The //8
        # matches the 3 encoder stages each halving H and W (2**3 = 8).
        H_latent = config.video.height // 8
        W_latent = config.video.width // 8
        flat_size = config.latent_channels * H_latent * W_latent
        self.adaptive_proj = nn.Linear(flat_size, config.d_model)

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: [B, C, T, H, W] or [B, T, C, H, W]

        Returns:
            z: [B, T', D] sampled latent sequence
            mu: [B, T', D] mean
            logvar: [B, T', D] log variance
            z_spatial: [B, C_latent, T', H', W'] spatial latent
        """
        # Normalise layout to channels-first [B, C, T, H, W].
        if x.dim() == 5:
            if (
                x.shape[1] != self.config.video.channels
                and x.shape[2] == self.config.video.channels
            ):
                # Input looked like [B, T, C, H, W]; swap time and channels.
                x = x.permute(0, 2, 1, 3, 4)

        # Optional multi-scale temporal pre-processing.
        if hasattr(self, "temporal_encoder") and self.temporal_encoder is not None:
            x = self.temporal_encoder(x)

        B, C, T, H, W = x.shape

        h = self.input_conv(x)

        for encoder_stage in self.encoders:
            h = encoder_stage(h)

        for block in self.final_blocks:
            h = block(h)

        # VAE heads; logvar is clamped so exp() below stays numerically safe.
        mu_spatial = self.to_mu(h)
        logvar_spatial = self.to_logvar(h).clamp(-10, 10)

        # Reparameterisation trick: z = mu + eps * sigma.
        std = torch.exp(0.5 * logvar_spatial)
        eps = torch.randn_like(std)
        z_spatial = mu_spatial + eps * std

        B, C_lat, T_lat, H_lat, W_lat = z_spatial.shape

        # Flatten each timestep's spatial latent to one vector: [B, T', C*H*W].
        z_flat = z_spatial.permute(0, 2, 1, 3, 4)
        z_flat = z_flat.reshape(B, T_lat, -1)

        if z_flat.shape[-1] != self.adaptive_proj.in_features:
            # NOTE(review): this fallback for unexpected resolutions builds a
            # brand-new, untrained Linear on every call and never registers it
            # as a submodule, so its weights are random each forward and get no
            # gradient updates. Confirm whether this path is ever hit.
            proj = nn.Linear(z_flat.shape[-1], self.config.d_model).to(z_flat.device)
            z = proj(z_flat)
            mu_flat = mu_spatial.permute(0, 2, 1, 3, 4).reshape(B, T_lat, -1)
            mu = proj(mu_flat)
            logvar_flat = logvar_spatial.permute(0, 2, 1, 3, 4).reshape(B, T_lat, -1)
            logvar = proj(logvar_flat)
        else:
            # Shared trained projection keeps z/mu/logvar in a consistent space.
            z = self.adaptive_proj(z_flat)
            mu_flat = mu_spatial.permute(0, 2, 1, 3, 4).reshape(B, T_lat, -1)
            mu = self.adaptive_proj(mu_flat)
            logvar_flat = logvar_spatial.permute(0, 2, 1, 3, 4).reshape(B, T_lat, -1)
            logvar = self.adaptive_proj(logvar_flat)

        return z, mu, logvar, z_spatial
|
|
|
|
| |
| |
| |
|
|
|
|
class VideoVAEDecoder(nn.Module):
    """
    3D VAE Decoder for video generation - CHECKPOINT COMPATIBLE

    Creates exact structure:
      - from_latent
      - init_blocks (ModuleList)
      - decoders (ModuleList of DecoderStage)
      - final_blocks (ModuleList)
      - to_rgb (Sequential)
      - temporal_refine (Sequential)
      - refine_scale
    """

    def __init__(self, config: NutataModelConfig):
        super().__init__()
        self.config = config
        channels = config.decoder_channels

        # 1x1x1 projection from latent space to the widest decoder width.
        self.from_latent = nn.Conv3d(config.latent_channels, channels[0], 1)

        self.init_blocks = nn.ModuleList(
            [SpatioTemporalBlock(channels[0]), SpatioTemporalBlock(channels[0])]
        )

        # Three upsampling stages; the last keeps the frame count fixed.
        self.decoders = nn.ModuleList(
            [
                DecoderStage(channels[0], channels[1], temporal_up=True),
                DecoderStage(channels[1], channels[2], temporal_up=True),
                DecoderStage(channels[2], channels[3], temporal_up=False),
            ]
        )

        self.final_blocks = nn.ModuleList(
            [SpatioTemporalBlock(channels[-1]), SpatioTemporalBlock(channels[-1])]
        )

        # Project features down to RGB in [0, 1] via a sigmoid.
        self.to_rgb = nn.Sequential(
            nn.Conv3d(channels[-1], channels[-1] // 2, (1, 3, 3), padding=(0, 1, 1)),
            nn.SiLU(),
            nn.Conv3d(
                channels[-1] // 2, config.video.channels, (1, 3, 3), padding=(0, 1, 1)
            ),
            nn.Sigmoid(),
        )

        # Small causal refinement pass applied on top of the RGB output.
        self.temporal_refine = nn.Sequential(
            CausalConv3d(config.video.channels, 32, (3, 3, 3)),
            nn.SiLU(),
            nn.Conv3d(32, config.video.channels, 1),
            nn.Tanh(),
        )
        self.refine_scale = nn.Parameter(torch.tensor(0.05))

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Args:
            z: [B, C_latent, T', H', W'] spatial latent

        Returns:
            video: [B, C, T, H, W] reconstructed video
        """
        h = self.from_latent(z)

        # Initial blocks, three upsampling stages, then final blocks.
        for stage in (*self.init_blocks, *self.decoders, *self.final_blocks):
            h = stage(h)

        video = self.to_rgb(h)

        # Residual temporal touch-up, clamped back into valid range.
        correction = self.temporal_refine(video) * self.refine_scale
        return torch.clamp(video + correction, 0, 1)
|
|
|
|
| |
| |
| |
|
|
|
|
class VideoLPOL(nn.Module):
    """
    LPOL Memory for Video - CHECKPOINT COMPATIBLE

    Learnable per-domain memory banks, weighted by a soft domain classifier
    and queried via multi-head cross-attention, then fused back into the
    sequence through a gated residual.

    Creates exact structure:
      - memories (ParameterDict with domain names)
      - memory_attn (ModuleDict with q_proj, k_proj, v_proj, o_proj)
      - domain_clf (Sequential)
      - fusion (Sequential)
      - gate (Sequential)
    """

    def __init__(self, config: NutataModelConfig):
        super().__init__()
        self.config = config
        self.n_domains = len(config.domain_types)
        self.slots_per_domain = config.memory_slots_per_domain

        # One learnable memory bank per domain: [slots, d_model] each.
        self.memories = nn.ParameterDict(
            {
                d: nn.Parameter(
                    torch.randn(self.slots_per_domain, config.d_model) * 0.02
                )
                for d in config.domain_types
            }
        )

        # Cross-attention projections (ModuleDict keys fixed by checkpoints).
        self.memory_attn = nn.ModuleDict(
            {
                "q_proj": nn.Linear(config.d_model, config.d_model, bias=False),
                "k_proj": nn.Linear(config.d_model, config.d_model, bias=False),
                "v_proj": nn.Linear(config.d_model, config.d_model, bias=False),
                "o_proj": nn.Linear(config.d_model, config.d_model, bias=False),
            }
        )

        # Soft domain classifier over the time-pooled sequence.
        self.domain_clf = nn.Sequential(
            nn.Linear(config.d_model, config.d_model // 2),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_model // 2, self.n_domains),
        )

        # Combines the sequence with the retrieved memories.
        self.fusion = nn.Sequential(
            nn.Linear(config.d_model * 2, config.d_model),
            nn.GELU(),
            nn.Linear(config.d_model, config.d_model),
        )

        # Elementwise gate controlling how much fused memory is injected.
        self.gate = nn.Sequential(
            nn.Linear(config.d_model * 2, config.d_model), nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
        """
        Args:
            x: [B, T, D] temporal sequence

        Returns:
            output: [B, T, D] memory-augmented features
            info: Dict with domain activations
        """
        B, T, D = x.shape

        # Classify the whole clip into a soft mixture over domains.
        x_pooled = x.mean(dim=1)
        domain_logits = self.domain_clf(x_pooled)
        domain_probs = F.softmax(domain_logits, dim=-1)

        # Scale each domain's memory slots by its predicted probability.
        all_memories = []
        for i, domain_name in enumerate(self.config.domain_types):
            mem = self.memories[domain_name]
            weight = domain_probs[:, i : i + 1]
            weighted_mem = mem.unsqueeze(0) * weight.unsqueeze(-1)
            all_memories.append(weighted_mem)

        # Concatenated bank: [B, n_domains * slots_per_domain, D].
        memory_bank = torch.cat(all_memories, dim=1)

        # Cross-attention: the sequence queries the memory bank.
        q = self.memory_attn["q_proj"](x)
        k = self.memory_attn["k_proj"](memory_bank)
        v = self.memory_attn["v_proj"](memory_bank)

        # NOTE(review): head count is hard-coded; this assumes
        # config.d_model % 8 == 0 — confirm against the config defaults.
        n_heads = 8
        head_dim = D // n_heads

        # Split into heads: [B, n_heads, T or slots, head_dim].
        q = q.view(B, T, n_heads, head_dim).transpose(1, 2)
        k = k.view(B, -1, n_heads, head_dim).transpose(1, 2)
        v = v.view(B, -1, n_heads, head_dim).transpose(1, 2)

        # Scaled dot-product attention over the memory slots.
        scale = head_dim**-0.5
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
        attn_weights = F.softmax(attn_weights, dim=-1)

        retrieved = torch.matmul(attn_weights, v)
        retrieved = retrieved.transpose(1, 2).contiguous().view(B, T, D)
        retrieved = self.memory_attn["o_proj"](retrieved)

        # Gated residual fusion of retrieved memories into the sequence.
        concat = torch.cat([x, retrieved], dim=-1)
        gate = self.gate(concat)
        fused = self.fusion(concat)

        output = x + gate * fused

        return output, {
            "domain_probs": domain_probs,
            "top_domain": domain_probs.argmax(dim=-1),
            "domain_names": list(self.memories.keys()),
        }
|
|
|
|
| |
| |
| |
|
|
|
|
class VideoExpert(nn.Module):
    """
    Expert network for video processing - CHECKPOINT COMPATIBLE

    A gated feed-forward expert that also reports a scalar confidence for
    its contribution on the current sequence.

    Creates exact structure:
      - confidence (Sequential)
      - gate (Linear)
      - fc1 (Linear)
      - fc2 (Linear)
      - dropout
    """

    def __init__(self, config: NutataModelConfig, expert_type: str):
        super().__init__()
        self.expert_type = expert_type

        # Scalar confidence head over the time-pooled representation.
        self.confidence = nn.Sequential(
            nn.Linear(config.d_model, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()
        )

        # Elementwise output gate.
        self.gate = nn.Linear(config.d_model, config.d_model)

        # Two-layer feed-forward core.
        self.fc1 = nn.Linear(config.d_model, config.d_ff)
        self.fc2 = nn.Linear(config.d_ff, config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: [B, T, D]

        Returns:
            output: [B, T, D]
            confidence: [B, 1]
        """
        # Confidence is computed on the mean-pooled sequence.
        conf = self.confidence(x.mean(dim=1))

        hidden = self.fc1(x)
        hidden = self.dropout(F.gelu(hidden))
        hidden = self.fc2(hidden)

        # Sigmoid gate modulates the feed-forward output elementwise.
        out = hidden * torch.sigmoid(self.gate(x))
        return out, conf
|
|
|
|
class VideoEARCPLayer(nn.Module):
    """
    EARCP Layer for video with temporal attention and expert routing - CHECKPOINT COMPATIBLE

    Runs causal temporal attention, routes the result through a weighted
    mixture of experts, and can grow a new expert when coherence stays low.

    Creates exact structure:
      - attn_norm (LayerNorm with elementwise_affine=False)
      - attn_scale (Parameter)
      - temporal_attn (VideoGQA)
      - experts (ModuleList of VideoExpert)
      - router (Linear)
      - low_coh_count (buffer)
    """

    def __init__(
        self, config: NutataModelConfig, layer_idx: int, n_experts: int = None
    ):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        if n_experts is None:
            n_experts = len(config.expert_types)

        # Affine-free LayerNorm (no learnable params, per checkpoint layout).
        self.attn_norm = nn.LayerNorm(config.d_model, elementwise_affine=False)

        # Learnable residual gain on the attention branch.
        self.attn_scale = nn.Parameter(torch.ones(1))

        # Grouped-query attention over the temporal axis.
        self.temporal_attn = VideoGQA(config)

        # Experts beyond the configured types get generic "Hybrid_i" labels.
        self.experts = nn.ModuleList(
            [
                VideoExpert(
                    config,
                    config.expert_types[i]
                    if i < len(config.expert_types)
                    else f"Hybrid_{i}",
                )
                for i in range(n_experts)
            ]
        )

        # Sequence-level router producing per-expert mixture weights.
        self.router = nn.Linear(config.d_model, n_experts)

        # Counts consecutive low-coherence steps (persisted in checkpoints).
        self.register_buffer("low_coh_count", torch.tensor(0))

        # Plain attribute (not a buffer): last coherence score, for logging.
        self.coherence_score = 0.5

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, float, bool]:
        """
        Args:
            x: [B, T, D]

        Returns:
            output: [B, T, D]
            coherence: float
            grew: bool (whether new expert was added)
        """
        # Pre-norm causal temporal attention with a scaled residual.
        h = self.attn_norm(x)
        attn_out = self.temporal_attn(h, causal=True)
        x = x + self.attn_scale * attn_out

        # Route on the time-pooled sequence: one weight vector per batch item.
        router_input = x.mean(dim=1)
        router_logits = self.router(router_input)
        weights = F.softmax(router_logits, dim=-1)

        # Every expert runs on the full sequence (dense mixture, not top-k).
        expert_outputs = []
        confs = []

        for expert in self.experts:
            out, conf = expert(x)
            expert_outputs.append(out)
            confs.append(conf)

        # [B, n_experts, T, D] combined by the router weights per batch item.
        expert_outputs = torch.stack(expert_outputs, dim=1)
        weighted_out = torch.einsum(
            "be,betd->btd", weights, expert_outputs
        )

        x = x + weighted_out

        # Mean expert confidence across experts and batch.
        confs_tensor = torch.stack(confs, dim=1)
        expert_conf = confs_tensor.mean().item()

        # Routing focus: 1 when the router is peaked, 0 at uniform routing.
        # The clamp guards against -inf from log(0) probabilities.
        entropy = -(weights * weights.log().clamp(min=-100)).sum(dim=-1).mean().item()
        max_entropy = math.log(len(self.experts)) if len(self.experts) > 1 else 1.0
        routing_focus = 1 - (entropy / max_entropy)

        coherence = 0.5 * expert_conf + 0.5 * routing_focus
        self.coherence_score = coherence

        # Growth: after growth_patience consecutive low-coherence steps, add
        # one expert and widen the router (old rows are preserved).
        grew = False
        if coherence < self.config.growth_threshold_coherence:
            self.low_coh_count += 1
            if (
                self.low_coh_count >= self.config.growth_patience
                and len(self.experts) < self.config.max_experts
            ):
                new_expert = VideoExpert(self.config, f"Hybrid_{len(self.experts)}").to(
                    x.device
                )
                self.experts.append(new_expert)

                # NOTE(review): replacing self.router mid-training orphans any
                # optimizer state attached to the old router's parameters —
                # confirm the trainer re-registers params after growth.
                old_router = self.router
                self.router = nn.Linear(self.config.d_model, len(self.experts)).to(
                    x.device
                )
                with torch.no_grad():
                    self.router.weight[: old_router.out_features] = old_router.weight
                    self.router.bias[: old_router.out_features] = old_router.bias

                self.low_coh_count.zero_()
                grew = True
        else:
            # Any coherent step resets the patience counter.
            self.low_coh_count.zero_()

        return x, coherence, grew

    def get_expert_count(self) -> int:
        """Return the current number of experts in this layer."""
        return len(self.experts)
|
|
|
|
| |
| |
| |
|
|
|
|
class VideoNeurogenesis(nn.Module):
    """
    Neurogenesis layer - adapts capacity based on temporal complexity
    CHECKPOINT COMPATIBLE

    A single gated tanh layer whose neuron count can grow, shrink, or be
    resized at runtime; usage statistics drive pruning decisions.

    Creates exact structure:
      - weights (Parameter)
      - bias (Parameter)
      - temporal_gate (Linear)
      - n_neurons, usage, lifetime, births, deaths (buffers)
    """

    def __init__(self, input_dim: int, n_neurons: int, config: NutataModelConfig):
        super().__init__()
        self.config = config
        self.input_dim = input_dim

        # Manually managed weight matrix/bias so rows can be added/removed.
        self.weights = nn.Parameter(torch.randn(n_neurons, input_dim) * 0.02)
        self.bias = nn.Parameter(torch.zeros(n_neurons))
        self.temporal_gate = nn.Linear(input_dim, n_neurons)

        # Buffers persist the dynamic state across checkpoint save/load.
        self.register_buffer("n_neurons", torch.tensor(n_neurons))
        self.register_buffer("usage", torch.ones(n_neurons))       # EMA of activation magnitude
        self.register_buffer("lifetime", torch.zeros(n_neurons))   # forward passes survived
        self.register_buffer("births", torch.tensor(0))
        self.register_buffer("deaths", torch.tensor(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, T, D]

        Returns:
            output: [B, T, n_neurons]
        """
        n = self.n_neurons.item()

        # Sigmoid gate per neuron; sliced to the live neuron count.
        gate = torch.sigmoid(self.temporal_gate(x))
        gate = gate[..., :n]

        # Gated tanh projection using only the live rows.
        out = torch.tanh(F.linear(x, self.weights[:n], self.bias[:n]))
        out = out * gate

        # Track usage as an exponential moving average of |activation|.
        with torch.no_grad():
            act = out.abs().mean(dim=(0, 1))
            if act.size(0) == n:
                self.usage[:n] = 0.99 * self.usage[:n] + 0.01 * act
                self.lifetime[:n] += 1

        return out

    def maybe_grow(self, coherence: float) -> int:
        """Attempt to grow neurons based on coherence.

        Returns the number of neurons added (0 or 1).
        """
        if not self.config.neurogenesis_enabled:
            return 0

        n = self.n_neurons.item()
        if n >= self.config.max_neurons:
            return 0
        if coherence < self.config.neuron_birth_threshold:
            return 0

        device = self.weights.device

        with torch.no_grad():
            new_w = torch.randn(1, self.input_dim, device=device) * 0.02
            new_b = torch.zeros(1, device=device)

            # Re-wrap as Parameters so the new row is a leaf with grad.
            # NOTE(review): optimizer state for the old Parameters is lost —
            # confirm the trainer rebuilds its param groups after growth.
            self.weights = nn.Parameter(torch.cat([self.weights.data, new_w], dim=0))
            self.bias = nn.Parameter(torch.cat([self.bias.data, new_b]))

            # Widen the gate, copying the old rows and initialising the new one.
            old_gate = self.temporal_gate
            self.temporal_gate = nn.Linear(self.input_dim, n + 1).to(device)
            self.temporal_gate.weight.data[:n] = old_gate.weight.data
            self.temporal_gate.weight.data[n:] = (
                torch.randn(1, self.input_dim, device=device) * 0.02
            )
            self.temporal_gate.bias.data[:n] = old_gate.bias.data
            self.temporal_gate.bias.data[n:] = 0

            self.usage = torch.cat([self.usage, torch.ones(1, device=device)])
            self.lifetime = torch.cat([self.lifetime, torch.zeros(1, device=device)])
            self.n_neurons += 1
            self.births += 1

        return 1

    def maybe_prune(self) -> int:
        """Prune low-usage neurons.

        Returns the number of neurons removed.
        """
        if not self.config.neurogenesis_enabled:
            return 0

        n = self.n_neurons.item()
        if n <= self.config.min_neurons:
            return 0

        # Candidates: live neurons whose usage EMA fell below the threshold.
        threshold = self.config.neuron_death_threshold
        prune_mask = self.usage[:n] < threshold

        if not prune_mask.any():
            return 0

        keep_mask = ~prune_mask
        n_keep = keep_mask.sum().item()

        # Never prune below the configured floor.
        if n_keep < self.config.min_neurons:
            return 0

        device = self.weights.device

        with torch.no_grad():
            self.weights = nn.Parameter(self.weights.data[keep_mask])
            self.bias = nn.Parameter(self.bias.data[keep_mask])

            # Shrink the gate to the surviving rows.
            old_gate = self.temporal_gate
            self.temporal_gate = nn.Linear(self.input_dim, n_keep).to(device)
            self.temporal_gate.weight.data = old_gate.weight.data[keep_mask]
            self.temporal_gate.bias.data = old_gate.bias.data[keep_mask]

            self.usage = self.usage[keep_mask]
            self.lifetime = self.lifetime[keep_mask]

            pruned = n - n_keep
            self.n_neurons.fill_(n_keep)
            self.deaths += pruned

        return pruned

    def resize(self, target_neurons: int):
        """
        Resize the neurogenesis layer to a specific number of neurons.
        Used when loading pretrained models with different neuron counts.
        """
        current = self.n_neurons.item()

        if target_neurons == current:
            return

        device = self.weights.device

        if target_neurons > current:
            # Grow: append freshly initialised rows.
            extra = target_neurons - current

            new_w = torch.randn(extra, self.input_dim, device=device) * 0.02
            new_b = torch.zeros(extra, device=device)

            self.weights = nn.Parameter(torch.cat([self.weights.data, new_w], dim=0))
            self.bias = nn.Parameter(torch.cat([self.bias.data, new_b]))

            # Widen the gate, preserving trained rows.
            old_gate = self.temporal_gate
            self.temporal_gate = nn.Linear(self.input_dim, target_neurons).to(device)
            with torch.no_grad():
                self.temporal_gate.weight[:current] = old_gate.weight
                self.temporal_gate.bias[:current] = old_gate.bias

            self.usage = torch.cat([self.usage, torch.ones(extra, device=device)])
            self.lifetime = torch.cat(
                [self.lifetime, torch.zeros(extra, device=device)]
            )

        else:
            # Shrink: keep the highest-usage neurons.
            keep_indices = torch.argsort(self.usage, descending=True)[:target_neurons]

            self.weights = nn.Parameter(self.weights.data[keep_indices])
            self.bias = nn.Parameter(self.bias.data[keep_indices])

            # Shrink the gate to the kept rows, in the same order.
            old_gate = self.temporal_gate
            self.temporal_gate = nn.Linear(self.input_dim, target_neurons).to(device)
            with torch.no_grad():
                self.temporal_gate.weight[:] = old_gate.weight[keep_indices]
                self.temporal_gate.bias[:] = old_gate.bias[keep_indices]

            self.usage = self.usage[keep_indices]
            self.lifetime = self.lifetime[keep_indices]

        self.n_neurons.fill_(target_neurons)

    def get_stats(self) -> Dict:
        """Summarise the layer's dynamic state for logging."""
        return {
            "total_neurons": self.n_neurons.item(),
            "total_births": self.births.item(),
            "total_deaths": self.deaths.item(),
            "avg_usage": self.usage[: self.n_neurons.item()].mean().item(),
            "max_lifetime": self.lifetime[: self.n_neurons.item()].max().item(),
        }
|
|
|
|
| |
| |
| |
|
|
|
|
class VideoEnergySystem(nn.Module):
    """
    Energy system for cognitive load tracking - CHECKPOINT COMPATIBLE

    Tracks a [0, 1] energy level that operations deplete and that can be
    regenerated each step; total consumption is accumulated for stats.

    Creates exact structure:
      - energy (buffer)
      - consumed (buffer)
    """

    def __init__(self, config: "NutataModelConfig"):
        super().__init__()
        self.config = config
        # Buffers so the energy state rides along in checkpoints.
        self.register_buffer("energy", torch.tensor(1.0))
        self.register_buffer("consumed", torch.tensor(0.0))

        # Per-operation default costs; the first three come from config.
        self.costs = {
            "encode": config.energy_cost_encode,
            "decode": config.energy_cost_decode,
            "predict": config.energy_cost_predict,
            "process": 0.01,
            "memory": 0.005,
            "attention": 0.008,
        }

    def consume(self, operation: str, amount: float = None) -> bool:
        """Consume energy for an operation.

        Args:
            operation: key into the cost table; unknown operations cost 0.01.
            amount: explicit cost override. An explicit 0.0 is a valid
                (free) cost.

        Returns:
            True if enough energy was available and was consumed.
        """
        # BUGFIX: previously `amount if amount else ...` — a truthiness check
        # that silently ignored an explicit amount of 0.0 and fell back to the
        # table cost. Use an explicit None check instead.
        cost = self.costs.get(operation, 0.01) if amount is None else amount

        if self.energy.item() >= cost:
            self.energy -= cost
            self.consumed += cost
            return True
        return False

    def regenerate(self):
        """Regenerate energy, capped so the level never exceeds 1.0."""
        regen = min(self.config.energy_regeneration, 1.0 - self.energy.item())
        self.energy += regen

    def reset(self):
        """Reset energy to full."""
        self.energy.fill_(1.0)

    def get_stats(self) -> Dict:
        """Return current energy, total consumed, and an efficiency ratio."""
        total = max(1.0, self.consumed.item() + self.energy.item())
        return {
            "energy": self.energy.item(),
            "consumed": self.consumed.item(),
            "efficiency": 1.0 - self.consumed.item() / total,
        }
|
|
|
|
| |
| |
| |
|
|
|
|
class TemporalCoherenceModule(nn.Module):
    """
    Ensures temporal coherence across frames - CHECKPOINT COMPATIBLE

    Creates exact structure:
      - diff_predictor (Sequential)
      - smooth (Conv1d with groups=d_model) - DEPTHWISE CONV
      - alpha (Parameter)
      - coherence_history, history_idx (buffers)
    """

    def __init__(self, config: NutataModelConfig):
        super().__init__()
        self.config = config
        d = config.d_model

        # Predicts the frame-to-frame latent difference from adjacent pairs.
        self.diff_predictor = nn.Sequential(
            nn.Linear(d * 2, d), nn.SiLU(), nn.Linear(d, d)
        )

        # Depthwise smoothing along the temporal axis.
        self.smooth = nn.Conv1d(d, d, kernel_size=3, padding=1, groups=d)

        # Learnable blend factor between raw and smoothed sequences.
        self.alpha = nn.Parameter(torch.tensor(0.2))

        # Ring buffer of recent coherence scores, for monitoring.
        self.register_buffer("coherence_history", torch.zeros(100))
        self.register_buffer("history_idx", torch.tensor(0))

    def forward(self, z_seq: torch.Tensor) -> Tuple[torch.Tensor, float]:
        """
        Args:
            z_seq: [B, T, D] latent sequence

        Returns:
            output: [B, T, D] temporally coherent sequence
            coherence: float score
        """
        B, T, D = z_seq.shape

        if T <= 1:
            # A single frame is trivially coherent.
            coherence = 1.0
        else:
            actual_diffs = z_seq[:, 1:] - z_seq[:, :-1]
            adjacent_pairs = torch.cat([z_seq[:, :-1], z_seq[:, 1:]], dim=-1)
            predicted_diffs = self.diff_predictor(adjacent_pairs)

            # Coherence = 1 - prediction error, clipped to [0, 1].
            err = F.mse_loss(predicted_diffs, actual_diffs).item()
            coherence = max(0, min(1, 1 - err))

        # Depthwise Conv1d expects [B, D, T].
        smoothed = self.smooth(z_seq.transpose(1, 2)).transpose(1, 2)

        # Learned convex blend of original and smoothed sequences.
        blend = torch.sigmoid(self.alpha)
        output = (1 - blend) * z_seq + blend * smoothed

        # Record the score in the circular history buffer.
        slot = self.history_idx.item() % 100
        self.coherence_history[slot] = coherence
        self.history_idx += 1

        return output, coherence

    def get_average_coherence(self) -> float:
        """Mean of the recorded coherence scores (0.0 if none yet)."""
        filled = min(self.history_idx.item(), 100)
        if filled == 0:
            return 0.0
        return self.coherence_history[:filled].mean().item()
|
|
|
|
| |
| |
| |
|
|
|
|
class OpticalFlowConsistencyModule(nn.Module):
    """
    Real Optical Flow Consistency - v2.0
    Computes dense optical flow between frames and enforces temporal consistency.
    """

    def __init__(self, config: NutataModelConfig):
        super().__init__()
        self.config = config
        channels = config.video.channels

        # Small CNN that maps a concatenated frame pair to a 2-channel flow
        # field; tanh bounds each component to [-1, 1] (scaled in warp()).
        self.flow_net = nn.Sequential(
            nn.Conv2d(channels * 2, 32, kernel_size=7, padding=3),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 2, kernel_size=3, padding=1),
            nn.Tanh(),
        )

    def warp(self, x: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
        """
        Warp image x using flow.
        x: [B, C, H, W]
        flow: [B, 2, H, W]
        """
        B, C, H, W = x.shape
        # Identity sampling grid in pixel coordinates.
        base_grid = torch.meshgrid(
            torch.arange(H, device=x.device),
            torch.arange(W, device=x.device),
            indexing="ij",
        )
        # [::-1] reorders (row, col) -> (x, y) so channel 0 is horizontal.
        base_grid = torch.stack(base_grid[::-1], dim=0).float()
        base_grid = base_grid.unsqueeze(0).expand(B, -1, -1, -1)

        # Flow is in [-1, 1]; the factor of 20 caps displacement at
        # +/- 20 pixels.
        final_grid = (
            base_grid + flow * 20.0
        )
        # Normalise pixel coordinates to grid_sample's [-1, 1] range.
        final_grid[:, 0] = 2.0 * final_grid[:, 0] / max(W - 1, 1) - 1.0
        final_grid[:, 1] = 2.0 * final_grid[:, 1] / max(H - 1, 1) - 1.0

        # grid_sample expects the grid as [B, H, W, 2].
        final_grid = final_grid.permute(0, 2, 3, 1)

        return F.grid_sample(
            x, final_grid, mode="bilinear", padding_mode="border", align_corners=True
        )

    def forward(self, video: torch.Tensor) -> Dict:
        """
        Args:
            video: [B, C, T, H, W]

        Returns:
            Dict with flow_loss
        """
        # Non-video input (e.g. a latent sequence): nothing to do.
        if video.dim() == 3:
            return {}

        B, C, T, H, W = video.shape
        if T < 2:
            return {"flow_loss": torch.tensor(0.0, device=video.device)}

        # Fold every consecutive frame pair into one big 2D batch:
        # frames_t holds frames 0..T-2, frames_t1 holds frames 1..T-1.
        frames_t = video[:, :, :-1].transpose(1, 2).reshape(-1, C, H, W)
        frames_t1 = video[:, :, 1:].transpose(1, 2).reshape(-1, C, H, W)

        flow_input = torch.cat([frames_t, frames_t1], dim=1)

        flow = self.flow_net(flow_input)

        # Warp frame t forward by the predicted flow ...
        warped_t = self.warp(frames_t, flow)

        # ... and penalise photometric mismatch against frame t+1.
        diff = (warped_t - frames_t1).abs()
        loss = diff.mean()

        return {
            "flow_loss": loss,
            "flow": flow,
            "warped": warped_t.reshape(B, T - 1, C, H, W).transpose(1, 2),
        }
|
|
|
|
| |
# Backward-compatibility alias: the v1.1 "Flow Prediction" feature is
# implemented by OpticalFlowConsistencyModule.
FlowPredictionModule = OpticalFlowConsistencyModule
|
|
|
|
| |
| |
| |
|
|
|
|
def compute_psnr(pred: torch.Tensor, target: torch.Tensor) -> float:
    """Compute Peak Signal-to-Noise Ratio (peak signal value 1.0)."""
    mse = F.mse_loss(pred, target).item()
    # Identical inputs: infinite PSNR by convention.
    if mse == 0:
        return float("inf")
    return 10 * math.log10(1.0 / mse)
|
|
|
|
def compute_ssim(pred: torch.Tensor, target: torch.Tensor) -> float:
    """Compute simplified SSIM for videos.

    Uses global (per batch-and-channel slice) statistics instead of a
    sliding window, then averages the per-slice scores.

    Args:
        pred: [B, C, ...] prediction, expected in [0, 1].
        target: [B, C, ...] reference, same shape as pred.

    Returns:
        Mean simplified-SSIM score as a Python float.
    """
    pred_flat = pred.flatten(2)
    target_flat = target.flatten(2)

    mu_pred = pred_flat.mean(dim=-1)
    mu_target = target_flat.mean(dim=-1)

    # BUGFIX: use the biased (population) variance to match the biased
    # covariance computed with .mean() below. With the default unbiased
    # estimator the identity SSIM(x, x) == 1 does not hold.
    sigma_pred = pred_flat.var(dim=-1, unbiased=False)
    sigma_target = target_flat.var(dim=-1, unbiased=False)

    sigma_pred_target = (
        (pred_flat - mu_pred.unsqueeze(-1)) * (target_flat - mu_target.unsqueeze(-1))
    ).mean(dim=-1)

    # Standard SSIM stabilisation constants for unit dynamic range.
    C1 = 0.01**2
    C2 = 0.03**2

    ssim = ((2 * mu_pred * mu_target + C1) * (2 * sigma_pred_target + C2)) / (
        (mu_pred**2 + mu_target**2 + C1) * (sigma_pred + sigma_target + C2)
    )

    return ssim.mean().item()
|
|
|
|
def compute_temporal_consistency(video: torch.Tensor) -> float:
    """Score in [0, 1]; 1.0 means adjacent frames are identical."""
    # A single-frame clip is trivially consistent.
    if video.shape[2] < 2:
        return 1.0

    # Mean absolute frame-to-frame difference, inverted and clamped to [0, 1].
    frame_delta = (video[:, :, 1:] - video[:, :, :-1]).abs()
    score = 1.0 - frame_delta.mean().item()
    return max(0, min(1, score))
|
|
|
|
| |
| |
| |
|
|
|
|
class SparseCompressor(nn.Module):
    """Learned top-k feature selection with straight-through estimator.

    An importance network scores every input dimension in (0, 1); only the
    ``top_k`` highest-scoring dimensions are kept (hard binary mask). During
    training a straight-through estimator lets gradients flow through the
    soft importance scores while the forward pass uses the hard selection.
    """

    def __init__(self, input_dim, output_dim, top_k=64, temporal_aware=False, use_ste=True):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        # Cannot select more dimensions than exist.
        self.top_k = min(top_k, input_dim)
        self.temporal_aware = temporal_aware
        self.use_ste = use_ste

        # Bottlenecked MLP producing a per-dimension importance score.
        bottleneck = max(input_dim // 4, 32)
        self.importance_net = nn.Sequential(
            nn.Linear(input_dim, bottleneck), nn.SiLU(),
            nn.Linear(bottleneck, input_dim), nn.Sigmoid(),
        )
        if temporal_aware:
            # Secondary scorer driven by the sequence-mean context vector.
            self.temporal_context = nn.Sequential(
                nn.Linear(input_dim, bottleneck), nn.SiLU(),
                nn.Linear(bottleneck, input_dim), nn.Sigmoid(),
            )
        self.compress = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        # NOTE(review): residual_scale is registered but never read in forward();
        # presumably kept for checkpoint compatibility — confirm before removing.
        self.residual_scale = nn.Parameter(torch.tensor(0.1))
        # Running selection statistics; non-persistent so they are not saved
        # into checkpoints.
        self.register_buffer("selection_freq", torch.zeros(input_dim), persistent=False)
        self.register_buffer("forward_count", torch.tensor(0, dtype=torch.long), persistent=False)

    def forward(self, x):
        """Compress [B, D] or [B, T, D] input to ``output_dim`` features."""
        is_sequence = x.dim() == 3
        if is_sequence:
            B, T, D = x.shape
            x_flat = x.reshape(B * T, D)
        else:
            x_flat = x
            B, T = x.shape[0], 1

        importance = self.importance_net(x_flat)
        if self.temporal_aware and is_sequence:
            # Modulate per-token importance by the sequence-wide context.
            x_ctx = x.mean(dim=1, keepdim=True).expand_as(x).reshape(B * T, D)
            temporal_mod = self.temporal_context(x_ctx)
            importance = importance * temporal_mod

        # Hard binary mask over the top-k most important dimensions.
        _, top_indices = importance.topk(self.top_k, dim=-1)
        mask = torch.zeros_like(importance).scatter_(-1, top_indices, 1.0)

        if self.use_ste and self.training:
            # Straight-through: forward uses the hard mask, backward sees the
            # soft importance scores.
            sparse_x = x_flat * (mask - importance).detach() + x_flat * importance
        else:
            sparse_x = x_flat * mask

        compressed = self.compress(sparse_x)
        output = self.norm(compressed)

        if self.training:
            with torch.no_grad():
                self.selection_freq += mask.sum(dim=0)
                self.forward_count += 1

        if is_sequence:
            output = output.reshape(B, T, self.output_dim)
        return output

    def get_selection_stats(self):
        """Return running selection statistics.

        FIX: the zero-forward case previously returned {"avg_freq": 0.0},
        a different key set than the post-forward case — callers keyed on
        either shape would break. The keys are now consistent in both cases.
        """
        sparsity = (self.input_dim - self.top_k) / self.input_dim
        if self.forward_count == 0:
            return {"avg_selection_freq": 0.0, "sparsity_ratio": sparsity}
        freq = self.selection_freq / self.forward_count
        return {
            "avg_selection_freq": freq.mean().item(),
            "sparsity_ratio": sparsity,
        }

    def sparsity_loss(self, x):
        """Binary-entropy regularizer pushing importance scores toward 0/1."""
        importance = self.importance_net(x.reshape(-1, self.input_dim))
        entropy = -(
            importance * importance.clamp(min=1e-8).log()
            + (1 - importance) * (1 - importance).clamp(min=1e-8).log()
        )
        return entropy.mean()
|
|
|
|
class PostVAESparseCompressor(nn.Module):
    """Post-VAE compression: forces selection of most informative latent dims."""

    def __init__(self, d_model, top_k=128):
        super().__init__()
        self.compressor = SparseCompressor(
            d_model, d_model, top_k=top_k, temporal_aware=True, use_ste=True
        )
        self.gate = nn.Sequential(nn.Linear(d_model * 2, d_model), nn.Sigmoid())

    def forward(self, z):
        # Blend the dense latent with its sparsified version via a learned gate
        # (gate -> 1 favors the sparse path).
        z_sparse = self.compressor(z)
        blend = self.gate(torch.cat([z, z_sparse], dim=-1))
        return (1 - blend) * z + blend * z_sparse
|
|
|
|
class InterEARCPSparseCompressor(nn.Module):
    """Between EARCP layers: adaptive compression, deeper = more aggressive."""

    def __init__(self, d_model, top_k=96, layer_idx=0):
        super().__init__()
        self.layer_idx = layer_idx
        # Deeper layers keep fewer dimensions, floored at 48.
        adaptive_k = max(top_k - layer_idx * 8, 48)
        self.compressor = SparseCompressor(
            d_model, d_model, top_k=adaptive_k, temporal_aware=False, use_ste=True
        )
        self.scale = nn.Parameter(torch.tensor(0.3))

    def forward(self, z):
        # Residual interpolation toward the sparse representation.
        delta = self.compressor(z) - z
        return z + self.scale * delta
|
|
|
|
class MemoryGateSparseCompressor(nn.Module):
    """Memory retrieval compression: extracts only truly relevant memory signals."""

    def __init__(self, d_model, top_k=64):
        super().__init__()
        self.compressor = SparseCompressor(
            d_model, d_model, top_k=top_k, temporal_aware=True, use_ste=True
        )

    def forward(self, retrieved):
        # Thin pass-through into the shared sparse selector.
        return self.compressor(retrieved)
|
|
|
|
class PreDecodeSparseCompressor(nn.Module):
    """Pre-decoder compression: clean decisive latent code for decoder."""

    def __init__(self, d_model, latent_channels, top_k=96):
        super().__init__()
        # NOTE(review): latent_channels is accepted but unused here —
        # presumably kept for interface symmetry; confirm before removing.
        self.compressor = SparseCompressor(
            d_model, d_model, top_k=top_k, temporal_aware=False, use_ste=True
        )
        self.gate = nn.Sequential(nn.Linear(d_model * 2, d_model), nn.Sigmoid())
        # High initial bias keeps the dense path dominant at the start of training.
        nn.init.constant_(self.gate[0].bias, 0.85)

    def forward(self, z):
        z_sparse = self.compressor(z)
        keep = self.gate(torch.cat([z, z_sparse], dim=-1))
        # Note the gate semantics: a high gate value keeps the dense z.
        return keep * z + (1 - keep) * z_sparse
|
|
|
|
class SparseCompressionManager(nn.Module):
    """Manages all 4 SparseCompressor instances across the architecture."""

    def __init__(self, config, earcp_compress_every=2):
        super().__init__()
        self.config = config
        self.earcp_compress_every = earcp_compress_every
        d = config.d_model

        # Four compression sites: after the VAE, between EARCP layers,
        # on memory retrieval, and just before decoding.
        self.post_vae = PostVAESparseCompressor(d, top_k=d // 4)
        n_points = config.n_layers // earcp_compress_every
        self.inter_earcp = nn.ModuleList(
            InterEARCPSparseCompressor(d, top_k=d // 4, layer_idx=i)
            for i in range(n_points)
        )
        self.memory_gate = MemoryGateSparseCompressor(d, top_k=d // 8)
        self.pre_decode = PreDecodeSparseCompressor(d, config.latent_channels, top_k=d // 4)

        # Per-site enable switches (plain attributes, not parameters).
        self.enable_post_vae = True
        self.enable_inter_earcp = True
        self.enable_memory = True
        self.enable_pre_decode = True

    def compress_post_vae(self, z):
        if not self.enable_post_vae:
            return z
        return self.post_vae(z)

    def compress_inter_earcp(self, z, layer_idx):
        # Only every `earcp_compress_every`-th layer gets a bottleneck.
        if not self.enable_inter_earcp or layer_idx % self.earcp_compress_every != 0:
            return z
        idx = layer_idx // self.earcp_compress_every
        if idx >= len(self.inter_earcp):
            return z
        return self.inter_earcp[idx](z)

    def compress_memory(self, retrieved):
        if not self.enable_memory:
            return retrieved
        return self.memory_gate(retrieved)

    def compress_pre_decode(self, z):
        if not self.enable_pre_decode:
            return z
        return self.pre_decode(z)

    def total_sparsity_loss(self, z):
        # Average the sparsity regularizers of the three single-instance sites.
        combined = (
            self.post_vae.compressor.sparsity_loss(z)
            + self.memory_gate.compressor.sparsity_loss(z)
            + self.pre_decode.compressor.sparsity_loss(z)
        )
        return combined / 3.0

    def count_params(self):
        return sum(p.numel() for p in self.parameters())
|
|
|
|
| |
| |
| |
|
|
|
|
| class NutataModel(nn.Module): |
| """ |
| NUTATA v2.0 - HIGH RESOLUTION + CONDITIONING |
| |
| Complete cognitive architecture for video understanding and generation. |
| |
| v2.0 Features: |
| - Progressive resolution training (64 β 512) |
| - Perceptual loss (VGG19 + adaptive weighting) |
| - Multi-level conditioning (text/action/scene β latent/decoder/temporal/memory) |
| - Hierarchical temporal memory (frame/clip/scene) |
| - Anti-blur sharpening |
| - All components are checkpoint-compatible. |
| """ |
|
|
    def __init__(
        self, config: NutataModelConfig = None, expert_counts: List[int] = None
    ):
        """Build the full NUTATA model.

        Args:
            config: Model configuration; defaults to NutataModelConfig().
            expert_counts: Per-layer expert counts (used when restoring a
                checkpoint whose experts have grown). Defaults to one expert
                per configured expert type for every layer.
        """
        super().__init__()
        self.config = config or NutataModelConfig()

        # Default: every layer starts with one expert per configured type.
        if expert_counts is None:
            expert_counts = [len(self.config.expert_types)] * self.config.n_layers

        # Causal 3D VAE encoder/decoder pair.
        self.encoder = VideoVAEEncoder(self.config)
        self.decoder = VideoVAEDecoder(self.config)

        # Optional LPOL memory system.
        self.lpol = VideoLPOL(self.config) if self.config.use_lpol else None

        # EARCP layers with per-layer dynamic expert counts.
        self.layers = nn.ModuleList(
            [
                VideoEARCPLayer(self.config, i, n_experts=expert_counts[i])
                for i in range(self.config.n_layers)
            ]
        )

        # Adaptive-capacity neurogenesis (starts at 64 neurons) plus a
        # projection from the max_neurons-padded space back to d_model.
        self.neurogenesis = VideoNeurogenesis(self.config.d_model, 64, self.config)

        self.neuro_proj = nn.Linear(self.config.max_neurons, self.config.d_model)

        # Temporal smoothing / coherence scoring over the latent sequence.
        self.temporal_coherence = TemporalCoherenceModule(self.config)

        # Optional optical-flow consistency module.
        self.flow_module = (
            FlowPredictionModule(self.config) if self.config.flow_prediction else None
        )

        # Energy bookkeeping for cognitive operations (consume/regenerate).
        self.energy = VideoEnergySystem(self.config)

        # Maps a d_model latent vector to a latent_channels x 8 x 8 spatial
        # latent frame (used by continue_video for future-frame decoding).
        self.frame_predictor = nn.Sequential(
            nn.Linear(self.config.d_model, self.config.d_model),
            nn.SiLU(),
            nn.Dropout(self.config.dropout),
            nn.Linear(self.config.d_model, self.config.latent_channels * 8 * 8),
        )

        # Optional v2.0 perceptual loss.
        self.perceptual_module = (
            PerceptualLossModule(self.config)
            if self.config.use_perceptual_loss
            else None
        )

        # Optional v2.0 multi-level conditioning (text/action/scene).
        self.conditioner = (
            ConditionalEncoder(self.config) if self.config.use_conditioning else None
        )

        # Optional v2.0 hierarchical temporal memory.
        self.hierarchical_memory = (
            HierarchicalTemporalMemory(self.config)
            if self.config.use_hierarchical_memory
            else None
        )

        # Optional v2.0 anti-blur sharpening applied to reconstructions.
        self.sharpener = (
            SharpnessEnhancer(self.config.video.channels)
            if self.config.use_sharpening
            else None
        )

        # Four-site sparse compression (post-VAE / inter-EARCP / memory / pre-decode).
        self.sparse_manager = SparseCompressionManager(
            self.config, earcp_compress_every=2
        )

        # Tracked for adaptive schedules (e.g. perceptual-loss warmup).
        self.current_epoch = 0

        self._init_weights()

        # Startup summary logging.
        logger.info(f"π¦ NUTATA v{self.config.version} initialized")
        logger.info(f"   Architecture: {self.config.codename}")
        logger.info(
            f"   Input: {self.config.video.channels}x{self.config.video.n_frames}x{self.config.video.height}x{self.config.video.width}"
        )
        logger.info(f"   Parameters: {self.count_params():,}")
        logger.info(f"   LPOL Domains: {len(self.config.domain_types)}")
        logger.info(f"   EARCP Layers: {self.config.n_layers}")
        logger.info(f"   Expert Types: {len(self.config.expert_types)}")
        logger.info(f"   Expert Counts: {[l.get_expert_count() for l in self.layers]}")
        if self.config.use_gqa:
            savings = int(
                (1 - self.config.gqa_num_kv_groups / self.config.gqa_num_heads) * 100
            )
            logger.info(
                f"   GQA: {self.config.gqa_num_heads} heads, {self.config.gqa_num_kv_groups} KV groups ({savings}% savings)"
            )
        if self.config.flow_prediction:
            logger.info("   Flow Prediction: Enabled")

        if self.config.use_perceptual_loss:
            logger.info(
                f"   [v2.0] Perceptual Loss: Enabled (adaptive={self.config.perceptual_adaptive})"
            )
        if self.config.use_conditioning:
            logger.info(
                f"   [v2.0] Conditioning: Enabled (text/action/scene at {self.config.condition_injection_levels})"
            )
        if self.config.use_hierarchical_memory:
            logger.info(
                f"   [v2.0] Hierarchical Memory: Enabled (scales={self.config.memory_scales})"
            )
        if self.config.use_sharpening:
            logger.info(
                f"   [v2.0] Sharpening: Enabled (weight={self.config.sharpening_weight})"
            )
        if hasattr(self, 'sparse_manager'):
            logger.info(
                f"   [NUTATA] SparseCompressor: Enabled ({len(self.sparse_manager.inter_earcp)} bottleneck points, compress_every={self.sparse_manager.earcp_compress_every})"
            )
            logger.info(
                f"   [NUTATA] SparseCompressor params: {self.sparse_manager.count_params():,}"
            )
|
|
| def _init_weights(self): |
| """Initialize weights with better defaults""" |
| for module in self.modules(): |
| if isinstance(module, nn.Linear): |
| nn.init.xavier_uniform_(module.weight) |
| if module.bias is not None: |
| nn.init.zeros_(module.bias) |
| elif isinstance(module, nn.Conv3d): |
| nn.init.kaiming_normal_( |
| module.weight, mode="fan_out", nonlinearity="relu" |
| ) |
| if module.bias is not None: |
| nn.init.zeros_(module.bias) |
| elif isinstance(module, nn.ConvTranspose3d): |
| nn.init.kaiming_normal_( |
| module.weight, mode="fan_out", nonlinearity="relu" |
| ) |
| if module.bias is not None: |
| nn.init.zeros_(module.bias) |
|
|
| def encode( |
| self, video: torch.Tensor |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| """ |
| Encode video to latent space |
| |
| Args: |
| video: [B, C, T, H, W] or [B, T, C, H, W] |
| |
| Returns: |
| z, mu, logvar, z_spatial |
| """ |
| self.energy.consume("encode") |
| return self.encoder(video) |
|
|
| def decode(self, z_spatial: torch.Tensor) -> torch.Tensor: |
| """ |
| Decode latent to video |
| |
| Args: |
| z_spatial: [B, C_latent, T', H', W'] |
| |
| Returns: |
| video: [B, C, T, H, W] |
| """ |
| self.energy.consume("decode") |
| return self.decoder(z_spatial) |
|
|
    def process_temporal(self, z: torch.Tensor, conditions: dict = None) -> Dict:
        """
        Process latent through cognitive systems - v2.0

        Pipeline order: condition injection -> LPOL -> hierarchical memory ->
        EARCP layers -> neurogenesis blend -> temporal coherence.

        Args:
            z: [B, T', D] latent sequence
            conditions: Encoded conditions (text/action/scene)

        Returns:
            Dict with processed latent and metadata
        """
        self.energy.consume("process")

        # Inject conditioning at the "temporal" level before memory systems.
        if self.conditioner is not None and conditions is not None:
            z = self.conditioner.inject(z, conditions, level="temporal")

        # LPOL episodic memory pass (optional).
        lpol_info = {}
        if self.lpol is not None:
            z, lpol_info = self.lpol(z)

        # Hierarchical temporal memory (optional), with memory-level conditioning.
        hierarchical_info = {}
        if self.hierarchical_memory is not None:
            if self.conditioner is not None and conditions is not None:
                z_mem_in = self.conditioner.inject(z, conditions, level="memory")
            else:
                z_mem_in = z

            z, hierarchical_info = self.hierarchical_memory(z_mem_in)

        # Run every EARCP layer, tracking coherence and expert growth events.
        coherences = []
        total_growth = 0

        for layer in self.layers:
            z, coh, grew = layer(z)
            coherences.append(coh)
            if grew:
                total_growth += 1

        # Neurogenesis produces a variable-width activation; pad (or truncate)
        # to max_neurons so neuro_proj always sees a fixed input width.
        neuro_out = self.neurogenesis(z)

        current_neurons = neuro_out.shape[-1]
        if current_neurons < self.config.max_neurons:
            padding = torch.zeros(
                *neuro_out.shape[:-1],
                self.config.max_neurons - current_neurons,
                device=neuro_out.device,
                dtype=neuro_out.dtype,
            )
            neuro_out_padded = torch.cat([neuro_out, padding], dim=-1)
        else:
            neuro_out_padded = neuro_out[..., : self.config.max_neurons]

        # Blend a scalar summary of the neurogenesis signal into z (weight 0.1).
        neuro_proj = self.neuro_proj(neuro_out_padded)
        z = z + 0.1 * neuro_proj.mean(dim=-1, keepdim=True).expand_as(z)

        # Average layer coherence drives potential neuron growth.
        avg_coherence = sum(coherences) / len(coherences) if coherences else 0.0
        neuro_growth = self.neurogenesis.maybe_grow(avg_coherence)

        # Final temporal smoothing; returns a coherence score as well.
        z, temp_coherence = self.temporal_coherence(z)

        # Flow loss is computed in forward() on the reconstruction, not here.
        flow_info = {}

        self.energy.regenerate()

        # NOTE(review): memory_stats is assembled but never used or returned —
        # looks like dead code; confirm before removing.
        memory_stats = {}
        if lpol_info:
            memory_stats.update(lpol_info)
        if hierarchical_info:
            memory_stats.update(hierarchical_info)

        return {
            "z": z,
            "coherence": avg_coherence,
            "temporal_coherence": temp_coherence,
            "expert_growth": total_growth,
            "neuro_growth": neuro_growth,
            "lpol_info": lpol_info,
            "flow_info": flow_info,
            "hierarchical_info": hierarchical_info,
            "energy": self.energy.get_stats(),
        }
|
|
| def predict_frames(self, z: torch.Tensor, n_frames: int = 1) -> torch.Tensor: |
| """ |
| Predict future frames in latent space |
| |
| Args: |
| z: [B, T, D] current latent sequence |
| n_frames: number of frames to predict |
| |
| Returns: |
| z_future: [B, n_frames, D] predicted latents |
| """ |
| self.energy.consume("predict") |
|
|
| predictions = [] |
| current_z = z |
|
|
| for _ in range(n_frames): |
| last = current_z[:, -1:] |
|
|
| for layer in self.layers: |
| last, _, _ = layer(last) |
|
|
| predictions.append(last) |
| current_z = torch.cat([current_z, last], dim=1) |
|
|
| return torch.cat(predictions, dim=1) |
|
|
    def forward(
        self,
        video: torch.Tensor,
        text_emb: torch.Tensor = None,
        action_id: torch.Tensor = None,
        scene_id: torch.Tensor = None,
        epoch: int = None,
    ) -> Dict:
        """
        Full forward pass - v2.0

        Args:
            video: [B, C, T, H, W] input video
            text_emb: [B, D_text] optional text embeddings
            action_id: [B] optional action class ids
            scene_id: [B] optional scene class ids
            epoch: current training epoch (for adaptive scheduling)

        Returns:
            Dict with loss, reconstruction, and metadata
        """
        # Track the epoch for adaptive schedules (perceptual warmup).
        if epoch is not None:
            self.current_epoch = epoch

        B = video.shape[0]

        # Encode conditioning inputs only if any were actually provided.
        encoded_conditions = None
        if self.conditioner is not None:
            has_conditions = (
                text_emb is not None or action_id is not None or scene_id is not None
            )
            if has_conditions:
                encoded_conditions = self.conditioner.encode_conditions(
                    text_emb, action_id, scene_id, batch_size=B, device=video.device
                )

        # VAE encode.
        z, mu, logvar, z_spatial = self.encode(video)

        # Latent-level condition injection.
        if self.conditioner is not None and encoded_conditions is not None:
            z = self.conditioner.inject(z, encoded_conditions, level="latent")

        # Cognitive pipeline (LPOL, memory, EARCP, neurogenesis, coherence).
        proc_out = self.process_temporal(z, conditions=encoded_conditions)
        # NOTE(review): z_processed is never used below — the decoder runs on
        # z_spatial only; confirm whether this is intentional.
        z_processed = proc_out["z"]

        # Decode the spatial latent back into a video.
        recon = self.decode(z_spatial)

        # Optional anti-blur sharpening of the reconstruction.
        if self.sharpener is not None:
            recon = self.sharpener(recon)

        # If the input arrived as [B, T, C, H, W] (channels in dim 2), permute
        # to [B, C, T, H, W] for comparison with the reconstruction.
        if video.shape[2] == self.config.video.channels:
            video_compare = video.permute(0, 2, 1, 3, 4)
        else:
            video_compare = video

        # Resample the reconstruction if its T/H/W differ from the target.
        if recon.shape != video_compare.shape:
            recon = F.interpolate(
                recon,
                size=video_compare.shape[2:],
                mode="trilinear",
                align_corners=False,
            )

        recon_loss = F.mse_loss(recon, video_compare)
        kl_loss = 0.5 * torch.mean(mu.pow(2) + logvar.exp() - logvar - 1)
        # NOTE(review): float() detaches this term from the graph, so the
        # coherence component carries no gradient — confirm if intended.
        coherence_loss = float(1.0 - proc_out["temporal_coherence"])

        # Flow consistency loss computed on the reconstruction (weight 0.01).
        flow_loss_tensor = torch.tensor(0.0, device=recon.device)
        if self.flow_module is not None:
            flow_out = self.flow_module(recon)
            if "flow_loss" in flow_out:
                flow_loss_tensor = flow_out["flow_loss"]
            if "flow" in flow_out:
                proc_out["flow_info"] = flow_out

        # Perceptual loss with optional warmup-based adaptive weighting.
        perceptual_loss = torch.tensor(0.0, device=recon.device)
        if self.perceptual_module is not None:
            if (
                self.config.perceptual_adaptive
                and self.current_epoch < self.config.perceptual_warmup_epochs
            ):
                weight = self.perceptual_module.get_adaptive_weight(self.current_epoch)
            else:
                weight = self.config.perceptual_loss_weight

            # NOTE(review): `weight` only gates whether the loss is computed;
            # the module presumably applies its own weighting internally.
            if weight > 0:
                perceptual_loss = self.perceptual_module(
                    recon, video_compare, self.current_epoch
                )

        total_loss = (
            recon_loss
            + self.config.kl_weight * kl_loss
            + self.config.temporal_coherence_weight * coherence_loss
            + 0.01 * flow_loss_tensor
            + perceptual_loss
        )

        # Convert tensor metrics to plain floats for logging/reporting.
        flow_loss = (
            flow_loss_tensor.item()
            if hasattr(flow_loss_tensor, "item")
            else float(flow_loss_tensor)
        )

        p_loss_val = (
            perceptual_loss.item()
            if hasattr(perceptual_loss, "item")
            else float(perceptual_loss)
        )

        return {
            "loss": total_loss,
            "recon_loss": recon_loss,
            "kl_loss": kl_loss,
            "coherence_loss": coherence_loss,
            "flow_loss": flow_loss,
            "perceptual_loss": p_loss_val,
            "recon": recon,
            "z": z,
            "z_spatial": z_spatial,
            "coherence": proc_out["coherence"],
            "temporal_coherence": proc_out["temporal_coherence"],
            "neurogenesis": proc_out["neuro_growth"],
            "expert_growth": proc_out["expert_growth"],
            "energy": proc_out["energy"],
        }
|
|
| def generate( |
| self, |
| n_frames: int = 16, |
| z_init: torch.Tensor = None, |
| temperature: float = 1.0, |
| batch_size: int = 1, |
| ) -> torch.Tensor: |
| """ |
| Generate video from scratch or initial latent |
| |
| Args: |
| n_frames: number of frames to generate |
| z_init: optional initial latent |
| temperature: sampling temperature |
| batch_size: batch size if z_init is None |
| |
| Returns: |
| video: [B, C, T, H, W] generated video |
| """ |
| self.eval() |
| device = next(self.parameters()).device |
|
|
| if z_init is None: |
| B = batch_size |
| T = max(1, n_frames // 4) |
| H = self.config.video.height // 8 |
| W = self.config.video.width // 8 |
| z_init = ( |
| torch.randn(B, self.config.latent_channels, T, H, W, device=device) |
| * temperature |
| ) |
|
|
| with torch.no_grad(): |
| video = self.decode(z_init) |
|
|
| |
| if video.shape[2] != n_frames: |
| video = F.interpolate( |
| video, |
| size=(n_frames, video.shape[3], video.shape[4]), |
| mode="trilinear", |
| align_corners=False, |
| ) |
|
|
| return video.clamp(0, 1) |
|
|
| def continue_video( |
| self, video: torch.Tensor, n_frames: int = 8, temperature: float = 0.8 |
| ) -> torch.Tensor: |
| """ |
| Continue a video by predicting future frames |
| |
| Args: |
| video: [B, C, T, H, W] input video |
| n_frames: number of frames to predict |
| temperature: sampling temperature |
| |
| Returns: |
| continued: [B, C, T+n_frames, H, W] continued video |
| """ |
| self.eval() |
|
|
| with torch.no_grad(): |
| |
| z, mu, logvar, z_spatial = self.encode(video) |
|
|
| |
| z_future = self.predict_frames(z, n_frames=n_frames) |
|
|
| |
| B, T_fut, D = z_future.shape |
| spatial = self.frame_predictor(z_future) |
| spatial = spatial.view(B, T_fut, self.config.latent_channels, 8, 8) |
| spatial = spatial.permute(0, 2, 1, 3, 4) |
|
|
| |
| future_frames = self.decode(spatial) |
|
|
| |
| continued = torch.cat([video, future_frames], dim=2) |
|
|
| return continued.clamp(0, 1) |
|
|
| def count_params(self) -> int: |
| """Count total parameters""" |
| return sum(p.numel() for p in self.parameters()) |
|
|
| def count_trainable_params(self) -> int: |
| """Count trainable parameters""" |
| return sum(p.numel() for p in self.parameters() if p.requires_grad) |
|
|
| def diagnostics(self) -> Dict: |
| """Get comprehensive diagnostics""" |
| total_experts = sum(layer.get_expert_count() for layer in self.layers) |
|
|
| return { |
| "model_version": self.config.version, |
| "total_params": self.count_params(), |
| "trainable_params": self.count_trainable_params(), |
| "total_experts": total_experts, |
| "expert_counts": [layer.get_expert_count() for layer in self.layers], |
| "neurogenesis": self.neurogenesis.get_stats(), |
| "energy": self.energy.get_stats(), |
| "temporal_coherence": self.temporal_coherence.get_average_coherence(), |
| "gqa_enabled": self.config.use_gqa, |
| "flow_prediction": self.config.flow_prediction, |
| "lpol_enabled": self.config.use_lpol, |
| "n_domains": len(self.config.domain_types) if self.config.use_lpol else 0, |
| } |
|
|
| |
| |
| |
|
|
    def save_pretrained(self, save_directory: str, save_config: bool = True):
        """
        Save model weights and configuration to a directory.

        Writes: pytorch_model.bin, config.json (with dynamic state),
        architecture.json, training_state.json, LICENSE, modeling_nutata.py.

        Args:
            save_directory: Directory to save the model
            save_config: Whether to save configuration files
        """
        os.makedirs(save_directory, exist_ok=True)

        # 1) Model weights.
        model_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)
        logger.info(f"πΎ Model weights saved to {model_path}")

        # 2) Config, including the dynamic (grown) architecture state so that
        # from_pretrained can rebuild matching module shapes.
        if save_config:
            config_path = os.path.join(save_directory, "config.json")
            config_dict = self.config.to_dict()
            config_dict["_class_name"] = "NutataModel"
            config_dict["_version"] = self.config.version
            config_dict["_architecture"] = "cognitive-video-world-model"

            config_dict["_dynamic_state"] = {
                "expert_counts": [layer.get_expert_count() for layer in self.layers],
                "neuron_count": self.neurogenesis.n_neurons.item(),
                "neuro_proj_in": self.neuro_proj.in_features,
                "neuro_proj_out": self.neuro_proj.out_features,
                # adaptive_proj may not exist yet (created lazily by encoder).
                "adaptive_proj_in": self.encoder.adaptive_proj.in_features
                if self.encoder.adaptive_proj
                else None,
                "adaptive_proj_out": self.encoder.adaptive_proj.out_features
                if self.encoder.adaptive_proj
                else None,
            }

            with open(config_path, "w") as f:
                json.dump(config_dict, f, indent=2)
            logger.info(f"π Config saved to {config_path}")

        # 3) Human-readable architecture summary.
        arch_info = {
            "model_type": "nutata-videomodel",
            "version": self.config.version,
            "total_params": self.count_params(),
            "diagnostics": self.diagnostics(),
            "architecture": {
                "encoder": "VideoVAEEncoder (3D Causal VAE)",
                "decoder": "VideoVAEDecoder (3D Upsampling)",
                "memory": f"VideoLPOL ({len(self.config.domain_types)} domains)",
                "attention": "VideoGQA (Grouped Query Attention with RoPE)",
                "experts": "VideoEARCP (Dynamic Expert Routing)",
                "neurogenesis": "VideoNeurogenesis (Adaptive Capacity)",
                "temporal": "TemporalCoherenceModule + FlowPrediction",
            },
            "input_spec": {
                "shape": f"[B, {self.config.video.channels}, {self.config.video.n_frames}, {self.config.video.height}, {self.config.video.width}]",
                "dtype": "torch.float32",
                "range": "[0, 1]",
            },
            "output_spec": {
                "reconstruction": f"[B, {self.config.video.channels}, {self.config.video.n_frames}, {self.config.video.height}, {self.config.video.width}]",
                "latent": f"[B, T', {self.config.d_model}]",
            },
        }

        arch_path = os.path.join(save_directory, "architecture.json")
        with open(arch_path, "w") as f:
            json.dump(arch_info, f, indent=2)
        logger.info(f"π Architecture info saved to {arch_path}")

        # 4) Training-time statistics snapshot.
        training_state = {
            "neurogenesis_stats": self.neurogenesis.get_stats(),
            "energy_stats": self.energy.get_stats(),
            "expert_counts": [layer.get_expert_count() for layer in self.layers],
            "temporal_coherence_avg": self.temporal_coherence.get_average_coherence(),
        }

        training_path = os.path.join(save_directory, "training_state.json")
        with open(training_path, "w") as f:
            json.dump(training_state, f, indent=2)

        # 5) License file.
        license_content = self._generate_license()
        license_path = os.path.join(save_directory, "LICENSE")
        with open(license_path, "w") as f:
            f.write(license_content)

        # 6) Standalone modeling code for Hub-style loading.
        modeling_code = self._generate_modeling_code()
        modeling_path = os.path.join(save_directory, "modeling_nutata.py")
        with open(modeling_path, "w") as f:
            f.write(modeling_code)
        logger.info(f"π modeling_nutata.py saved to {modeling_path}")

        logger.info(f"β
 Model saved to {save_directory}")

        return save_directory
|
|
    @classmethod
    def from_pretrained(
        cls, pretrained_path: str, device: str = None, **kwargs
    ) -> "NutataModel":
        """
        Load a pretrained NutataModel from local path or HuggingFace Hub.

        Reconciles the dynamically-grown architecture (experts, neurons,
        lazily-created projections) with the checkpoint before loading
        weights, skipping any tensors whose shapes still mismatch.

        Args:
            pretrained_path: Local directory or HuggingFace repo ID
            device: Device to load model to
            **kwargs: Additional config overrides

        Returns:
            Loaded NutataModel instance
        """
        # Resolve the source: local directory, else Hub snapshot.
        if os.path.isdir(pretrained_path):
            load_dir = pretrained_path
        elif HF_HUB_AVAILABLE:
            logger.info(f"π₯ Downloading from HuggingFace Hub: {pretrained_path}")
            load_dir = snapshot_download(repo_id=pretrained_path)
        else:
            raise ValueError(
                f"Path {pretrained_path} not found and huggingface_hub not installed"
            )

        # Load config.json if present; strip metadata keys before building
        # the config object.
        config_path = os.path.join(load_dir, "config.json")
        dynamic_state = None

        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                config_dict = json.load(f)

            # Dynamic state (grown experts/neurons) is handled separately.
            dynamic_state = config_dict.pop("_dynamic_state", None)

            config_dict.pop("_class_name", None)
            config_dict.pop("_version", None)
            config_dict.pop("_architecture", None)

            # Caller-supplied kwargs override the saved config.
            config_dict.update(kwargs)

            config = NutataModelConfig.from_dict(config_dict)
        else:
            logger.warning("No config.json found, using default config")
            config = NutataModelConfig(**kwargs)

        # Defaults matching a freshly-constructed model.
        expert_counts = None
        neuron_count = 64
        neuro_proj_in = 64
        neuro_proj_out = None
        adaptive_proj_in = None
        adaptive_proj_out = None

        if dynamic_state is not None:
            expert_counts = dynamic_state.get("expert_counts")
            neuron_count = dynamic_state.get("neuron_count", 64)
            neuro_proj_in = dynamic_state.get("neuro_proj_in", neuron_count)
            neuro_proj_out = dynamic_state.get("neuro_proj_out")
            adaptive_proj_in = dynamic_state.get("adaptive_proj_in")
            adaptive_proj_out = dynamic_state.get("adaptive_proj_out")
            logger.info(
                f"π Restoring dynamic state: experts={expert_counts}, neurons={neuron_count}, neuro_proj={neuro_proj_in}"
            )

        # Build the model with the checkpoint's expert layout.
        model = cls(config, expert_counts=expert_counts)

        # Grow/shrink neurogenesis to the saved neuron count.
        if neuron_count != model.neurogenesis.n_neurons.item():
            model.neurogenesis.resize(neuron_count)

        # Recreate neuro_proj if the checkpoint used a different input width.
        if neuro_proj_in and neuro_proj_in != model.neuro_proj.in_features:
            model.neuro_proj = nn.Linear(neuro_proj_in, config.d_model)
            logger.info(
                f"π Adjusted neuro_proj for checkpoint: {neuro_proj_in} -> {config.d_model}"
            )

        # Pre-create the lazily-built encoder projection from saved dims.
        if adaptive_proj_in is not None and adaptive_proj_out is not None:
            model.encoder.adaptive_proj = nn.Linear(adaptive_proj_in, adaptive_proj_out)
            logger.info(
                f"π Pre-created adaptive_proj from config: {adaptive_proj_in} -> {adaptive_proj_out}"
            )

        # Load weights, reconciling any remaining shape differences.
        model_path = os.path.join(load_dir, "pytorch_model.bin")
        if os.path.exists(model_path):
            state_dict = torch.load(model_path, map_location="cpu")

            # If the checkpoint carries adaptive_proj weights, rebuild the
            # module from the tensor shapes (authoritative over config).
            adaptive_proj_keys = [k for k in state_dict.keys() if "adaptive_proj" in k]
            if adaptive_proj_keys:
                weight_key = "encoder.adaptive_proj.weight"
                if weight_key in state_dict:
                    in_features = state_dict[weight_key].shape[1]
                    out_features = state_dict[weight_key].shape[0]
                    model.encoder.adaptive_proj = nn.Linear(in_features, out_features)
                    logger.info(
                        f"π Created adaptive_proj: {in_features} -> {out_features}"
                    )

            # Likewise reconcile neuro_proj from the actual saved tensor.
            neuro_proj_wkey = "neuro_proj.weight"
            if neuro_proj_wkey in state_dict:
                sd_in = state_dict[neuro_proj_wkey].shape[1]
                sd_out = state_dict[neuro_proj_wkey].shape[0]
                if (sd_in != model.neuro_proj.in_features or
                    sd_out != model.neuro_proj.out_features):
                    model.neuro_proj = nn.Linear(sd_in, sd_out)
                    logger.info(
                        f"π Reconciled neuro_proj from weights: {sd_in} -> {sd_out}"
                    )

            # Drop any tensor whose shape still does not match the model —
            # those modules keep their fresh initialization.
            model_state = model.state_dict()
            keys_to_skip = []
            for key in list(state_dict.keys()):
                if key in model_state:
                    if state_dict[key].shape != model_state[key].shape:
                        logger.warning(
                            f"   β οΈ Shape mismatch for {key}: "
                            f"checkpoint={list(state_dict[key].shape)} vs "
                            f"model={list(model_state[key].shape)}, skipping"
                        )
                        keys_to_skip.append(key)
            for key in keys_to_skip:
                del state_dict[key]

            # strict=False: tolerate the skipped/missing keys reported below.
            missing, unexpected = model.load_state_dict(state_dict, strict=False)

            # Report missing + intentionally skipped keys (truncated to 5).
            all_missing = list(missing) + keys_to_skip
            if all_missing:
                logger.warning(f"   Missing/skipped keys: {len(all_missing)}")
                if keys_to_skip:
                    logger.warning(f"   ({len(keys_to_skip)} skipped due to shape mismatch)")
                for k in all_missing[:5]:
                    logger.warning(f"      - {k}")
                if len(all_missing) > 5:
                    logger.warning(f"      ... and {len(all_missing) - 5} more")
            if unexpected:
                # adaptive_proj keys were handled above, so exclude them.
                truly_unexpected = [k for k in unexpected if "adaptive_proj" not in k]
                if truly_unexpected:
                    logger.warning(f"   Unexpected keys: {len(truly_unexpected)}")
                    for k in truly_unexpected[:5]:
                        logger.warning(f"      - {k}")
                    if len(truly_unexpected) > 5:
                        logger.warning(
                            f"      ... and {len(truly_unexpected) - 5} more"
                        )

            if not all_missing and not unexpected:
                logger.info(f"β
 Loaded weights perfectly (0 missing, 0 unexpected)")
            else:
                logger.info(f"β
 Loaded weights from {model_path}")
        else:
            logger.warning("No pytorch_model.bin found, using random weights")

        # Move to the requested (or auto-detected) device.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)

        logger.info(f"π¦ Model loaded on {device}")

        return model
|
|
| def push_to_hub( |
| self, |
| repo_id: str, |
| token: str = None, |
| private: bool = True, |
| commit_message: str = "Upload NUTATA", |
| save_directory: str = None, |
| ) -> str: |
| """ |
| Push model to HuggingFace Hub. |
| |
| Args: |
| repo_id: HuggingFace repo ID (e.g., "username/model-name") |
| token: HuggingFace API token |
| private: Whether to create private repo |
| commit_message: Commit message |
| save_directory: Temporary directory for saving |
| |
| Returns: |
| URL of the uploaded model |
| """ |
| if not HF_HUB_AVAILABLE: |
| raise ImportError( |
| "huggingface_hub not installed. Install with: pip install huggingface_hub" |
| ) |
|
|
| |
| if save_directory is None: |
| save_directory = tempfile.mkdtemp() |
|
|
| os.makedirs(save_directory, exist_ok=True) |
|
|
| |
| self.save_pretrained(save_directory) |
|
|
| |
| readme_content = self._generate_readme() |
| readme_path = os.path.join(save_directory, "README.md") |
| with open(readme_path, "w") as f: |
| f.write(readme_content) |
| logger.info(f"π README.md generated") |
|
|
| |
| license_content = self._generate_license() |
| license_path = os.path.join(save_directory, "LICENSE") |
| with open(license_path, "w") as f: |
| f.write(license_content) |
| logger.info(f"π LICENSE generated") |
|
|
| |
| modeling_code = self._generate_modeling_code() |
| modeling_path = os.path.join(save_directory, "modeling_nutata.py") |
| with open(modeling_path, "w") as f: |
| f.write(modeling_code) |
| logger.info(f"π modeling_nutata.py generated ({len(modeling_code)} chars)") |
|
|
| |
| requirements = """# NUTATA v1.1 Requirements |
| torch>=2.0.0 |
| torchvision>=0.15.0 |
| tqdm>=4.65.0 |
| huggingface_hub>=0.19.0 |
| numpy>=1.24.0 |
| """ |
| req_path = os.path.join(save_directory, "requirements.txt") |
| with open(req_path, "w") as f: |
| f.write(requirements) |
| logger.info("π requirements.txt generated") |
|
|
| |
| gitattributes = """*.bin filter=lfs diff=lfs merge=lfs -text |
| *.pt filter=lfs diff=lfs merge=lfs -text |
| *.pth filter=lfs diff=lfs merge=lfs -text |
| *.ckpt filter=lfs diff=lfs merge=lfs -text |
| """ |
| gitattr_path = os.path.join(save_directory, ".gitattributes") |
| with open(gitattr_path, "w") as f: |
| f.write(gitattributes) |
|
|
| |
| logger.info("\nπ Files to upload:") |
| for f in os.listdir(save_directory): |
| fpath = os.path.join(save_directory, f) |
| size = os.path.getsize(fpath) / (1024 * 1024) |
| logger.info(f" {f} ({size:.2f} MB)") |
|
|
| |
| api = HfApi(token=token) |
|
|
| try: |
| create_repo(repo_id=repo_id, private=private, token=token, exist_ok=True) |
| logger.info(f"π Repository created/verified: {repo_id}") |
| except Exception as e: |
| logger.warning(f"Could not create repo: {e}") |
|
|
| |
| logger.info("\nβ¬οΈ Uploading to HuggingFace Hub...") |
| api.upload_folder( |
| folder_path=save_directory, |
| repo_id=repo_id, |
| commit_message=commit_message, |
| token=token, |
| ) |
|
|
| url = f"https://huggingface.co/{repo_id}" |
| logger.info(f"π Model pushed to {url}") |
|
|
| return url |
|
|
| def _generate_modeling_code(self) -> str: |
| """ |
| Generate modeling_nutata.py for HuggingFace Hub. |
| This file allows dynamic loading without needing the original codebase. |
| """ |
| |
| expert_counts = [layer.get_expert_count() for layer in self.layers] |
| neuron_count = self.neurogenesis.n_neurons.item() |
|
|
| code = '''#!/usr/bin/env python3 |
| """ |
| NUTATA v1.1 - HuggingFace Hub Modeling File |
| Auto-generated for dynamic loading and fine-tuning. |
| |
| This file contains the complete model architecture for loading from HuggingFace Hub. |
| """ |
| |
| import os |
| import math |
| import json |
| import logging |
| from typing import Dict, List, Optional, Tuple, Any, Union |
| from dataclasses import dataclass, field |
| |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| |
| logger = logging.getLogger(__name__) |
| |
| # ============================================================================== |
| # CONFIGURATION |
| # ============================================================================== |
| |
| @dataclass |
| class VideoConfig: |
| height: int = 64 |
| width: int = 64 |
| channels: int = 3 |
| n_frames: int = 16 |
| fps: int = 8 |
| temporal_downsample: int = 4 |
| spatial_downsample: int = 8 |
| |
| @classmethod |
| def from_dict(cls, d: dict) -> "VideoConfig": |
| valid_keys = {\'height\', \'width\', \'channels\', \'n_frames\', \'fps\', \'temporal_downsample\', \'spatial_downsample\'} |
| return cls(**{k: v for k, v in d.items() if k in valid_keys}) |
| |
| |
| @dataclass |
| class NutataModelConfig: |
| model_type: str = "nutata-videomodel" |
| version: str = "1.1" |
| codename: str = "VideoSim-Cognitive" |
| architecture_type: str = "cognitive-video" |
| d_model: int = 512 |
| d_ff: int = 2048 |
| n_layers: int = 8 |
| n_heads: int = 8 |
| dropout: float = 0.1 |
| latent_dim: int = 256 |
| temporal_latent_dim: int = 512 |
| latent_channels: int = 4 |
| max_frames: int = 64 |
| context_frames: int = 16 |
| prediction_frames: int = 8 |
| encoder_channels: List[int] = field(default_factory=lambda: [64, 128, 256, 512]) |
| decoder_channels: List[int] = field(default_factory=lambda: [512, 256, 128, 64]) |
| kl_weight: float = 0.0001 |
| use_lpol: bool = True |
| memory_size: int = 256 |
| memory_slots_per_domain: int = 32 |
| memory_k: int = 8 |
| domain_types: List[str] = field(default_factory=lambda: [ |
| \'motion\', \'appearance\', \'temporal\', \'spatial\', \'object\', |
| \'scene\', \'action\', \'causality\', \'physics\' |
| ]) |
| use_gqa: bool = True |
| gqa_num_heads: int = 8 |
| gqa_num_kv_groups: int = 2 |
| expert_types: List[str] = field(default_factory=lambda: [ |
| \'Motion\', \'Appearance\', \'Temporal\', \'Spatial\', \'Prediction\', \'Generation\' |
| ]) |
| max_experts: int = 12 |
| growth_threshold_coherence: float = 0.3 |
| growth_patience: int = 10 |
| neurogenesis_enabled: bool = True |
| min_neurons: int = 32 |
| max_neurons: int = 256 |
| neuron_birth_threshold: float = 0.8 |
| neuron_death_threshold: float = 0.05 |
| energy_enabled: bool = True |
| energy_cost_encode: float = 0.01 |
| energy_cost_decode: float = 0.02 |
| energy_cost_predict: float = 0.03 |
| energy_regeneration: float = 0.05 |
| dream_enabled: bool = True |
| dream_cycle_length: int = 100 |
| dream_duration: int = 20 |
| temporal_coherence_weight: float = 0.1 |
| flow_prediction: bool = True |
| perceptual_loss_weight: float = 0.05 |
| batch_size: int = 8 |
| learning_rate: float = 1e-4 |
| epochs: int = 50 |
| gradient_accumulation: int = 4 |
| warmup_steps: int = 500 |
| push_to_hub: bool = True |
| hub_model_id: str = "amewebstudio/nutata-videomodel-v1.1" |
| video: VideoConfig = field(default_factory=VideoConfig) |
| |
| def to_dict(self) -> Dict: |
| d = {} |
| for key, value in self.__dict__.items(): |
| if key == \'video\': |
| d[key] = {k: v for k, v in value.__dict__.items()} |
| elif isinstance(value, (list, dict, str, int, float, bool, type(None))): |
| d[key] = value |
| return d |
| |
| @classmethod |
| def from_dict(cls, d: Dict) -> "NutataModelConfig": |
| d = d.copy() |
| if \'video\' in d and isinstance(d[\'video\'], dict): |
| d[\'video\'] = VideoConfig.from_dict(d[\'video\']) |
| for k in [\'_dynamic_state\', \'_class_name\', \'_version\', \'_architecture\']: |
| d.pop(k, None) |
| known = set(cls.__dataclass_fields__.keys()) |
| return cls(**{k: v for k, v in d.items() if k in known}) |
| |
| |
| # ============================================================================== |
| # BUILDING BLOCKS |
| # ============================================================================== |
| |
| class CausalConv3d(nn.Module): |
| def __init__(self, in_channels: int, out_channels: int, |
| kernel_size: Tuple[int, int, int] = (3, 3, 3), |
| stride: Tuple[int, int, int] = (1, 1, 1)): |
| super().__init__() |
| if isinstance(kernel_size, int): |
| kernel_size = (kernel_size, kernel_size, kernel_size) |
| if isinstance(stride, int): |
| stride = (stride, stride, stride) |
| self.kernel_size = kernel_size |
| self.stride = stride |
| self.temporal_pad = kernel_size[0] - 1 |
| self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, |
| stride=stride, padding=(0, kernel_size[1] // 2, kernel_size[2] // 2)) |
| |
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| if self.temporal_pad > 0: |
| x = F.pad(x, (0, 0, 0, 0, self.temporal_pad, 0), mode=\'replicate\') |
| return self.conv(x) |
| |
| |
| class SpatioTemporalBlock(nn.Module): |
| def __init__(self, channels: int): |
| super().__init__() |
| self.scale = nn.Parameter(torch.ones(1)) |
| self.spatial = nn.Sequential( |
| nn.GroupNorm(min(8, channels), channels), |
| nn.SiLU(), |
| nn.Conv3d(channels, channels, kernel_size=(1, 3, 3), padding=(0, 1, 1)) |
| ) |
| self.temporal = nn.Sequential( |
| nn.GroupNorm(min(8, channels), channels), |
| nn.SiLU(), |
| CausalConv3d(channels, channels, kernel_size=(3, 1, 1)) |
| ) |
| |
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| h = self.spatial(x) |
| h = self.temporal(h) |
| return x + self.scale * h |
| |
| |
| class EncoderStage(nn.ModuleList): |
| def __init__(self, in_channels: int, out_channels: int, temporal_down: bool = True): |
| stride = (2, 2, 2) if temporal_down else (1, 2, 2) |
| super().__init__([ |
| SpatioTemporalBlock(in_channels), |
| SpatioTemporalBlock(in_channels), |
| nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1) |
| ]) |
| |
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| x = self[0](x) |
| x = self[1](x) |
| x = self[2](x) |
| return x |
| |
| |
| class DecoderStage(nn.ModuleList): |
| def __init__(self, in_channels: int, out_channels: int, temporal_up: bool = True): |
| stride = (2, 2, 2) if temporal_up else (1, 2, 2) |
| output_padding = (1, 1, 1) if temporal_up else (0, 1, 1) |
| super().__init__([ |
| SpatioTemporalBlock(in_channels), |
| SpatioTemporalBlock(in_channels), |
| nn.ConvTranspose3d(in_channels, out_channels, kernel_size=3, |
| stride=stride, padding=1, output_padding=output_padding) |
| ]) |
| |
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| x = self[0](x) |
| x = self[1](x) |
| x = self[2](x) |
| return x |
| |
| |
| class RotaryPositionalEmbedding(nn.Module): |
| def __init__(self, dim: int, max_seq_len: int = 1024, base: int = 10000): |
| super().__init__() |
| self.dim = dim |
| self.max_seq_len = max_seq_len |
| inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) |
| self.register_buffer(\'inv_freq\', inv_freq, persistent=False) |
| self._build_cache(max_seq_len) |
| |
| def _build_cache(self, seq_len: int): |
| t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) |
| freqs = torch.einsum(\'i,j->ij\', t, self.inv_freq) |
| emb = torch.cat([freqs, freqs], dim=-1) |
| self.register_buffer(\'cos_cache\', emb.cos(), persistent=False) |
| self.register_buffer(\'sin_cache\', emb.sin(), persistent=False) |
| |
| def forward(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]: |
| if seq_len > self.cos_cache.shape[0]: |
| self._build_cache(seq_len) |
| return self.cos_cache[:seq_len].to(device), self.sin_cache[:seq_len].to(device) |
| |
| |
| def rotate_half(x: torch.Tensor) -> torch.Tensor: |
| x1, x2 = x.chunk(2, dim=-1) |
| return torch.cat([-x2, x1], dim=-1) |
| |
| |
| # ============================================================================== |
| # VIDEO VAE |
| # ============================================================================== |
| |
| class VideoVAEEncoder(nn.Module): |
| def __init__(self, config: NutataModelConfig): |
| super().__init__() |
| self.config = config |
| channels = config.encoder_channels |
| self.input_conv = nn.Conv3d(config.video.channels, channels[0], kernel_size=(1, 3, 3), padding=(0, 1, 1)) |
| self.encoders = nn.ModuleList([ |
| EncoderStage(channels[0], channels[1], temporal_down=True), |
| EncoderStage(channels[1], channels[2], temporal_down=True), |
| EncoderStage(channels[2], channels[3], temporal_down=False), |
| ]) |
| self.final_blocks = nn.ModuleList([SpatioTemporalBlock(channels[-1]), SpatioTemporalBlock(channels[-1])]) |
| self.to_mu = nn.Conv3d(channels[-1], config.latent_channels, 1) |
| self.to_logvar = nn.Conv3d(channels[-1], config.latent_channels, 1) |
| self.to_d_model = nn.Linear(config.latent_channels, config.d_model) |
| self.adaptive_proj = None |
| |
| def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| if x.dim() == 5 and x.shape[2] == self.config.video.channels: |
| x = x.permute(0, 2, 1, 3, 4) |
| B, C, T, H, W = x.shape |
| h = self.input_conv(x) |
| for encoder_stage in self.encoders: |
| h = encoder_stage(h) |
| for block in self.final_blocks: |
| h = block(h) |
| mu_spatial = self.to_mu(h) |
| logvar_spatial = self.to_logvar(h).clamp(-10, 10) |
| std = torch.exp(0.5 * logvar_spatial) |
| eps = torch.randn_like(std) |
| z_spatial = mu_spatial + eps * std |
| B, C_lat, T_lat, H_lat, W_lat = z_spatial.shape |
| z_flat = z_spatial.permute(0, 2, 1, 3, 4).reshape(B, T_lat, -1) |
| if self.adaptive_proj is None or self.adaptive_proj.in_features != z_flat.shape[-1]: |
| self.adaptive_proj = nn.Linear(z_flat.shape[-1], self.config.d_model).to(z_flat.device) |
| z = self.adaptive_proj(z_flat) |
| mu_flat = mu_spatial.permute(0, 2, 1, 3, 4).reshape(B, T_lat, -1) |
| mu = self.adaptive_proj(mu_flat) |
| logvar_flat = logvar_spatial.permute(0, 2, 1, 3, 4).reshape(B, T_lat, -1) |
| logvar = self.adaptive_proj(logvar_flat) |
| return z, mu, logvar, z_spatial |
| |
| |
| class VideoVAEDecoder(nn.Module): |
| def __init__(self, config: NutataModelConfig): |
| super().__init__() |
| self.config = config |
| channels = config.decoder_channels |
| self.from_latent = nn.Conv3d(config.latent_channels, channels[0], 1) |
| self.init_blocks = nn.ModuleList([SpatioTemporalBlock(channels[0]), SpatioTemporalBlock(channels[0])]) |
| self.decoders = nn.ModuleList([ |
| DecoderStage(channels[0], channels[1], temporal_up=True), |
| DecoderStage(channels[1], channels[2], temporal_up=True), |
| DecoderStage(channels[2], channels[3], temporal_up=False), |
| ]) |
| self.final_blocks = nn.ModuleList([SpatioTemporalBlock(channels[-1]), SpatioTemporalBlock(channels[-1])]) |
| self.to_rgb = nn.Sequential( |
| nn.Conv3d(channels[-1], channels[-1] // 2, (1, 3, 3), padding=(0, 1, 1)), |
| nn.SiLU(), |
| nn.Conv3d(channels[-1] // 2, config.video.channels, (1, 3, 3), padding=(0, 1, 1)), |
| nn.Sigmoid() |
| ) |
| self.temporal_refine = nn.Sequential( |
| CausalConv3d(config.video.channels, 32, (3, 3, 3)), |
| nn.SiLU(), |
| nn.Conv3d(32, config.video.channels, 1), |
| nn.Tanh() |
| ) |
| self.refine_scale = nn.Parameter(torch.tensor(0.05)) |
| |
| def forward(self, z: torch.Tensor) -> torch.Tensor: |
| h = self.from_latent(z) |
| for block in self.init_blocks: |
| h = block(h) |
| for decoder_stage in self.decoders: |
| h = decoder_stage(h) |
| for block in self.final_blocks: |
| h = block(h) |
| video = self.to_rgb(h) |
| refine = self.temporal_refine(video) * self.refine_scale |
| video = torch.clamp(video + refine, 0, 1) |
| return video |
| |
| |
| # ============================================================================== |
| # GQA ATTENTION |
| # ============================================================================== |
| |
| class VideoGQA(nn.Module): |
| def __init__(self, config: NutataModelConfig): |
| super().__init__() |
| self.d_model = config.d_model |
| self.num_heads = config.gqa_num_heads |
| self.num_kv_groups = config.gqa_num_kv_groups |
| self.head_dim = config.d_model // self.num_heads |
| self.heads_per_group = self.num_heads // self.num_kv_groups |
| self.scale = self.head_dim ** -0.5 |
| self.q_proj = nn.Linear(config.d_model, self.num_heads * self.head_dim, bias=False) |
| self.k_proj = nn.Linear(config.d_model, self.num_kv_groups * self.head_dim, bias=False) |
| self.v_proj = nn.Linear(config.d_model, self.num_kv_groups * self.head_dim, bias=False) |
| self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.d_model, bias=False) |
| self.rope = RotaryPositionalEmbedding(self.head_dim, max_seq_len=config.max_frames * 2) |
| self.dropout = nn.Dropout(config.dropout) |
| self.residual_scale = nn.Parameter(torch.ones(1) * 0.1) |
| |
| def forward(self, x: torch.Tensor, causal: bool = True, use_rope: bool = True) -> torch.Tensor: |
| B, T, D = x.shape |
| q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2) |
| k = self.k_proj(x).view(B, T, self.num_kv_groups, self.head_dim).transpose(1, 2) |
| v = self.v_proj(x).view(B, T, self.num_kv_groups, self.head_dim).transpose(1, 2) |
| if use_rope: |
| cos, sin = self.rope(T, x.device) |
| q = (q * cos.unsqueeze(0).unsqueeze(0)) + (rotate_half(q) * sin.unsqueeze(0).unsqueeze(0)) |
| k = (k * cos.unsqueeze(0).unsqueeze(0)) + (rotate_half(k) * sin.unsqueeze(0).unsqueeze(0)) |
| k = k.repeat_interleave(self.heads_per_group, dim=1) |
| v = v.repeat_interleave(self.heads_per_group, dim=1) |
| attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale |
| if causal: |
| causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1) |
| attn_weights = attn_weights.masked_fill(causal_mask, float(\'-inf\')) |
| attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) |
| attn_weights = self.dropout(attn_weights) |
| attn_output = torch.matmul(attn_weights, v) |
| attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, -1) |
| return self.o_proj(attn_output) |
| |
| |
| # ============================================================================== |
| # LPOL MEMORY |
| # ============================================================================== |
| |
| class VideoLPOL(nn.Module): |
| def __init__(self, config: NutataModelConfig): |
| super().__init__() |
| self.config = config |
| self.n_domains = len(config.domain_types) |
| self.slots_per_domain = config.memory_slots_per_domain |
| self.memories = nn.ParameterDict({ |
| d: nn.Parameter(torch.randn(self.slots_per_domain, config.d_model) * 0.02) |
| for d in config.domain_types |
| }) |
| self.memory_attn = nn.ModuleDict({ |
| \'q_proj\': nn.Linear(config.d_model, config.d_model, bias=False), |
| \'k_proj\': nn.Linear(config.d_model, config.d_model, bias=False), |
| \'v_proj\': nn.Linear(config.d_model, config.d_model, bias=False), |
| \'o_proj\': nn.Linear(config.d_model, config.d_model, bias=False), |
| }) |
| self.domain_clf = nn.Sequential( |
| nn.Linear(config.d_model, config.d_model // 2), nn.GELU(), |
| nn.Dropout(config.dropout), nn.Linear(config.d_model // 2, self.n_domains) |
| ) |
| self.fusion = nn.Sequential(nn.Linear(config.d_model * 2, config.d_model), nn.GELU(), nn.Linear(config.d_model, config.d_model)) |
| self.gate = nn.Sequential(nn.Linear(config.d_model * 2, config.d_model), nn.Sigmoid()) |
| |
| def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict]: |
| B, T, D = x.shape |
| x_pooled = x.mean(dim=1) |
| domain_logits = self.domain_clf(x_pooled) |
| domain_probs = F.softmax(domain_logits, dim=-1) |
| all_memories = [] |
| for i, domain_name in enumerate(self.config.domain_types): |
| mem = self.memories[domain_name] |
| weight = domain_probs[:, i:i+1] |
| weighted_mem = mem.unsqueeze(0) * weight.unsqueeze(-1) |
| all_memories.append(weighted_mem) |
| memory_bank = torch.cat(all_memories, dim=1) |
| q = self.memory_attn[\'q_proj\'](x) |
| k = self.memory_attn[\'k_proj\'](memory_bank) |
| v = self.memory_attn[\'v_proj\'](memory_bank) |
| n_heads = 8 |
| head_dim = D // n_heads |
| q = q.view(B, T, n_heads, head_dim).transpose(1, 2) |
| k = k.view(B, -1, n_heads, head_dim).transpose(1, 2) |
| v = v.view(B, -1, n_heads, head_dim).transpose(1, 2) |
| scale = head_dim ** -0.5 |
| attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale |
| attn_weights = F.softmax(attn_weights, dim=-1) |
| retrieved = torch.matmul(attn_weights, v) |
| retrieved = retrieved.transpose(1, 2).contiguous().view(B, T, D) |
| retrieved = self.memory_attn[\'o_proj\'](retrieved) |
| concat = torch.cat([x, retrieved], dim=-1) |
| gate = self.gate(concat) |
| fused = self.fusion(concat) |
| output = x + gate * fused |
| return output, {\'domain_probs\': domain_probs, \'top_domain\': domain_probs.argmax(dim=-1)} |
| |
| |
| # ============================================================================== |
| # EXPERTS AND EARCP |
| # ============================================================================== |
| |
| class VideoExpert(nn.Module): |
| def __init__(self, config: NutataModelConfig, expert_type: str): |
| super().__init__() |
| self.expert_type = expert_type |
| self.confidence = nn.Sequential(nn.Linear(config.d_model, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()) |
| self.gate = nn.Linear(config.d_model, config.d_model) |
| self.fc1 = nn.Linear(config.d_model, config.d_ff) |
| self.fc2 = nn.Linear(config.d_ff, config.d_model) |
| self.dropout = nn.Dropout(config.dropout) |
| |
| def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
| conf = self.confidence(x.mean(dim=1)) |
| gate_val = torch.sigmoid(self.gate(x)) |
| h = self.dropout(F.gelu(self.fc1(x))) |
| h = self.fc2(h) |
| out = h * gate_val |
| return out, conf |
| |
| |
| class VideoEARCPLayer(nn.Module): |
| def __init__(self, config: NutataModelConfig, layer_idx: int, n_experts: int = None): |
| super().__init__() |
| self.config = config |
| self.layer_idx = layer_idx |
| if n_experts is None: |
| n_experts = len(config.expert_types) |
| self.attn_norm = nn.LayerNorm(config.d_model, elementwise_affine=False) |
| self.attn_scale = nn.Parameter(torch.ones(1)) |
| self.temporal_attn = VideoGQA(config) |
| self.experts = nn.ModuleList([ |
| VideoExpert(config, config.expert_types[i] if i < len(config.expert_types) else f"Hybrid_{i}") |
| for i in range(n_experts) |
| ]) |
| self.router = nn.Linear(config.d_model, n_experts) |
| self.register_buffer(\'low_coh_count\', torch.tensor(0)) |
| self.coherence_score = 0.5 |
| |
| def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, float, bool]: |
| h = self.attn_norm(x) |
| attn_out = self.temporal_attn(h, causal=True) |
| x = x + self.attn_scale * attn_out |
| router_input = x.mean(dim=1) |
| router_logits = self.router(router_input) |
| weights = F.softmax(router_logits, dim=-1) |
| expert_outputs = [] |
| confs = [] |
| for expert in self.experts: |
| out, conf = expert(x) |
| expert_outputs.append(out) |
| confs.append(conf) |
| expert_outputs = torch.stack(expert_outputs, dim=1) |
| weighted_out = torch.einsum(\'be,betd->btd\', weights, expert_outputs) |
| x = x + weighted_out |
| confs_tensor = torch.stack(confs, dim=1) |
| expert_conf = confs_tensor.mean().item() |
| entropy = -(weights * weights.log().clamp(min=-100)).sum(dim=-1).mean().item() |
| max_entropy = math.log(len(self.experts)) if len(self.experts) > 1 else 1.0 |
| routing_focus = 1 - (entropy / max_entropy) |
| coherence = 0.5 * expert_conf + 0.5 * routing_focus |
| self.coherence_score = coherence |
| grew = False |
| if coherence < self.config.growth_threshold_coherence: |
| self.low_coh_count += 1 |
| if self.low_coh_count >= self.config.growth_patience and len(self.experts) < self.config.max_experts: |
| new_expert = VideoExpert(self.config, f"Hybrid_{len(self.experts)}").to(x.device) |
| self.experts.append(new_expert) |
| old_router = self.router |
| self.router = nn.Linear(self.config.d_model, len(self.experts)).to(x.device) |
| with torch.no_grad(): |
| self.router.weight[:old_router.out_features] = old_router.weight |
| self.router.bias[:old_router.out_features] = old_router.bias |
| self.low_coh_count.zero_() |
| grew = True |
| else: |
| self.low_coh_count.zero_() |
| return x, coherence, grew |
| |
| def get_expert_count(self) -> int: |
| return len(self.experts) |
| |
| |
| # ============================================================================== |
| # NEUROGENESIS |
| # ============================================================================== |
| |
| class VideoNeurogenesis(nn.Module): |
| def __init__(self, input_dim: int, n_neurons: int, config: NutataModelConfig): |
| super().__init__() |
| self.config = config |
| self.input_dim = input_dim |
| self.weights = nn.Parameter(torch.randn(n_neurons, input_dim) * 0.02) |
| self.bias = nn.Parameter(torch.zeros(n_neurons)) |
| self.temporal_gate = nn.Linear(input_dim, n_neurons) |
| self.register_buffer(\'n_neurons\', torch.tensor(n_neurons)) |
| self.register_buffer(\'usage\', torch.ones(n_neurons)) |
| self.register_buffer(\'lifetime\', torch.zeros(n_neurons)) |
| self.register_buffer(\'births\', torch.tensor(0)) |
| self.register_buffer(\'deaths\', torch.tensor(0)) |
| |
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| n = self.n_neurons.item() |
| gate = torch.sigmoid(self.temporal_gate(x)) |
| gate = gate[..., :n] |
| out = torch.tanh(F.linear(x, self.weights[:n], self.bias[:n])) |
| out = out * gate |
| with torch.no_grad(): |
| act = out.abs().mean(dim=(0, 1)) |
| if act.size(0) == n: |
| self.usage[:n] = 0.99 * self.usage[:n] + 0.01 * act |
| self.lifetime[:n] += 1 |
| return out |
| |
| def maybe_grow(self, coherence: float) -> int: |
| if not self.config.neurogenesis_enabled: |
| return 0 |
| n = self.n_neurons.item() |
| if n >= self.config.max_neurons or coherence < self.config.neuron_birth_threshold: |
| return 0 |
| device = self.weights.device |
| with torch.no_grad(): |
| new_w = torch.randn(1, self.input_dim, device=device) * 0.02 |
| new_b = torch.zeros(1, device=device) |
| self.weights = nn.Parameter(torch.cat([self.weights.data, new_w], dim=0)) |
| self.bias = nn.Parameter(torch.cat([self.bias.data, new_b])) |
| old_gate = self.temporal_gate |
| self.temporal_gate = nn.Linear(self.input_dim, n + 1).to(device) |
| self.temporal_gate.weight.data[:n] = old_gate.weight.data |
| self.temporal_gate.weight.data[n:] = torch.randn(1, self.input_dim, device=device) * 0.02 |
| self.temporal_gate.bias.data[:n] = old_gate.bias.data |
| self.temporal_gate.bias.data[n:] = 0 |
| self.usage = torch.cat([self.usage, torch.ones(1, device=device)]) |
| self.lifetime = torch.cat([self.lifetime, torch.zeros(1, device=device)]) |
| self.n_neurons += 1 |
| self.births += 1 |
| return 1 |
| |
| def resize(self, target_neurons: int): |
| current = self.n_neurons.item() |
| if target_neurons == current: |
| return |
| device = self.weights.device |
| if target_neurons > current: |
| extra = target_neurons - current |
| new_w = torch.randn(extra, self.input_dim, device=device) * 0.02 |
| new_b = torch.zeros(extra, device=device) |
| self.weights = nn.Parameter(torch.cat([self.weights.data, new_w], dim=0)) |
| self.bias = nn.Parameter(torch.cat([self.bias.data, new_b])) |
| old_gate = self.temporal_gate |
| self.temporal_gate = nn.Linear(self.input_dim, target_neurons).to(device) |
| with torch.no_grad(): |
| self.temporal_gate.weight[:current] = old_gate.weight |
| self.temporal_gate.bias[:current] = old_gate.bias |
| self.usage = torch.cat([self.usage, torch.ones(extra, device=device)]) |
| self.lifetime = torch.cat([self.lifetime, torch.zeros(extra, device=device)]) |
| else: |
| keep_indices = torch.argsort(self.usage, descending=True)[:target_neurons] |
| self.weights = nn.Parameter(self.weights.data[keep_indices]) |
| self.bias = nn.Parameter(self.bias.data[keep_indices]) |
| old_gate = self.temporal_gate |
| self.temporal_gate = nn.Linear(self.input_dim, target_neurons).to(device) |
| with torch.no_grad(): |
| self.temporal_gate.weight[:] = old_gate.weight[keep_indices] |
| self.temporal_gate.bias[:] = old_gate.bias[keep_indices] |
| self.usage = self.usage[keep_indices] |
| self.lifetime = self.lifetime[keep_indices] |
| self.n_neurons.fill_(target_neurons) |
| |
| def get_stats(self) -> Dict: |
| return {\'total_neurons\': self.n_neurons.item(), \'total_births\': self.births.item(), |
| \'total_deaths\': self.deaths.item(), \'avg_usage\': self.usage[:self.n_neurons.item()].mean().item(), |
| \'max_lifetime\': self.lifetime[:self.n_neurons.item()].max().item()} |
| |
| |
| # ============================================================================== |
| # TEMPORAL COHERENCE & FLOW |
| # ============================================================================== |
| |
| class TemporalCoherenceModule(nn.Module): |
| def __init__(self, config: NutataModelConfig): |
| super().__init__() |
| d = config.d_model |
| self.diff_predictor = nn.Sequential(nn.Linear(d * 2, d), nn.SiLU(), nn.Linear(d, d)) |
| self.smooth = nn.Conv1d(d, d, kernel_size=3, padding=1, groups=d) |
| self.alpha = nn.Parameter(torch.tensor(0.2)) |
| self.register_buffer(\'coherence_history\', torch.zeros(100)) |
| self.register_buffer(\'history_idx\', torch.tensor(0)) |
| |
| def forward(self, z_seq: torch.Tensor) -> Tuple[torch.Tensor, float]: |
| B, T, D = z_seq.shape |
| if T > 1: |
| diffs = z_seq[:, 1:] - z_seq[:, :-1] |
| pairs = torch.cat([z_seq[:, :-1], z_seq[:, 1:]], dim=-1) |
| pred_diffs = self.diff_predictor(pairs) |
| coherence = 1 - F.mse_loss(pred_diffs, diffs).item() |
| coherence = max(0, min(1, coherence)) |
| else: |
| coherence = 1.0 |
| z_t = z_seq.transpose(1, 2) |
| smoothed = self.smooth(z_t).transpose(1, 2) |
| alpha = torch.sigmoid(self.alpha) |
| output = (1 - alpha) * z_seq + alpha * smoothed |
| idx = self.history_idx.item() % 100 |
| self.coherence_history[idx] = coherence |
| self.history_idx += 1 |
| return output, coherence |
| |
| def get_average_coherence(self) -> float: |
| valid = min(self.history_idx.item(), 100) |
| return self.coherence_history[:valid].mean().item() if valid > 0 else 0.0 |
| |
| |
| class FlowPredictionModule(nn.Module): |
| def __init__(self, config: NutataModelConfig): |
| super().__init__() |
| d = config.d_model |
| self.flow_encoder = nn.Sequential(nn.Linear(d * 2, d), nn.SiLU(), nn.Linear(d, d // 2), nn.SiLU(), nn.Linear(d // 2, d)) |
| self.warp_net = nn.Sequential(nn.Linear(d * 2, d), nn.Tanh()) |
| self.motion_magnitude = nn.Sequential(nn.Linear(d, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()) |
| |
| def forward(self, z_seq: torch.Tensor) -> Dict: |
| B, T, D = z_seq.shape |
| if T < 2: |
| return {\'flow\': None, \'warped\': z_seq, \'motion_magnitude\': torch.zeros(B, 1, device=z_seq.device), |
| \'flow_loss\': torch.tensor(0.0, device=z_seq.device)} |
| z_t = z_seq[:, :-1] |
| z_t1 = z_seq[:, 1:] |
| pairs = torch.cat([z_t, z_t1], dim=-1) |
| flow = self.flow_encoder(pairs) |
| warp_input = torch.cat([z_t, flow], dim=-1) |
| warped = self.warp_net(warp_input) |
| motion = self.motion_magnitude(flow.mean(dim=1)) |
| flow_loss = F.mse_loss(warped, z_t1) |
| return {\'flow\': flow, \'warped\': warped, \'motion_magnitude\': motion, \'flow_loss\': flow_loss} |
| |
| |
| class VideoEnergySystem(nn.Module): |
| def __init__(self, config: NutataModelConfig): |
| super().__init__() |
| self.config = config |
| self.register_buffer(\'energy\', torch.tensor(1.0)) |
| self.register_buffer(\'consumed\', torch.tensor(0.0)) |
| self.costs = {\'encode\': config.energy_cost_encode, \'decode\': config.energy_cost_decode, |
| \'predict\': config.energy_cost_predict, \'process\': 0.01, \'memory\': 0.005, \'attention\': 0.008} |
| |
| def consume(self, operation: str, amount: float = None) -> bool: |
| cost = amount if amount else self.costs.get(operation, 0.01) |
| if self.energy.item() >= cost: |
| self.energy -= cost |
| self.consumed += cost |
| return True |
| return False |
| |
| def regenerate(self): |
| regen = min(self.config.energy_regeneration, 1.0 - self.energy.item()) |
| self.energy += regen |
| |
| def reset(self): |
| self.energy.fill_(1.0) |
| |
| def get_stats(self) -> Dict: |
| return {\'energy\': self.energy.item(), \'consumed\': self.consumed.item()} |
| |
| |
| # ============================================================================== |
| # MAIN MODEL |
| # ============================================================================== |
| |
| class NutataModel(nn.Module): |
| def __init__(self, config: NutataModelConfig = None, expert_counts: List[int] = None): |
| super().__init__() |
| self.config = config or NutataModelConfig() |
| if expert_counts is None: |
| expert_counts = [len(self.config.expert_types)] * self.config.n_layers |
| self.encoder = VideoVAEEncoder(self.config) |
| self.decoder = VideoVAEDecoder(self.config) |
| self.lpol = VideoLPOL(self.config) if self.config.use_lpol else None |
| self.layers = nn.ModuleList([ |
| VideoEARCPLayer(self.config, i, n_experts=expert_counts[i]) |
| for i in range(self.config.n_layers) |
| ]) |
| self.neurogenesis = VideoNeurogenesis(self.config.d_model, 64, self.config) |
| self.neuro_proj = nn.Linear(self.config.max_neurons, self.config.d_model) |
| self.temporal_coherence = TemporalCoherenceModule(self.config) |
| self.flow_module = FlowPredictionModule(self.config) if self.config.flow_prediction else None |
| self.energy = VideoEnergySystem(self.config) |
| self.frame_predictor = nn.Sequential( |
| nn.Linear(self.config.d_model, self.config.d_model), nn.SiLU(), |
| nn.Dropout(self.config.dropout), nn.Linear(self.config.d_model, self.config.latent_channels * 8 * 8) |
| ) |
| |
| def encode(self, video: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| self.energy.consume(\'encode\') |
| return self.encoder(video) |
| |
| def decode(self, z_spatial: torch.Tensor) -> torch.Tensor: |
| self.energy.consume(\'decode\') |
| return self.decoder(z_spatial) |
| |
| def process_temporal(self, z: torch.Tensor) -> Dict: |
| self.energy.consume(\'process\') |
| lpol_info = {} |
| if self.lpol is not None: |
| z, lpol_info = self.lpol(z) |
| coherences = [] |
| total_growth = 0 |
| for layer in self.layers: |
| z, coh, grew = layer(z) |
| coherences.append(coh) |
| if grew: |
| total_growth += 1 |
| neuro_out = self.neurogenesis(z) |
| current_neurons = neuro_out.shape[-1] |
| if current_neurons < self.config.max_neurons: |
| padding = torch.zeros(*neuro_out.shape[:-1], self.config.max_neurons - current_neurons, device=neuro_out.device, dtype=neuro_out.dtype) |
| neuro_out_padded = torch.cat([neuro_out, padding], dim=-1) |
| else: |
| neuro_out_padded = neuro_out[..., :self.config.max_neurons] |
| neuro_proj = self.neuro_proj(neuro_out_padded) |
| z = z + 0.1 * neuro_proj.mean(dim=-1, keepdim=True).expand_as(z) |
| avg_coherence = sum(coherences) / len(coherences) if coherences else 0.0 |
| neuro_growth = self.neurogenesis.maybe_grow(avg_coherence) |
| z, temp_coherence = self.temporal_coherence(z) |
| flow_info = {} |
| if self.flow_module is not None: |
| flow_info = self.flow_module(z) |
| self.energy.regenerate() |
| return {\'z\': z, \'coherence\': avg_coherence, \'temporal_coherence\': temp_coherence, |
| \'expert_growth\': total_growth, \'neuro_growth\': neuro_growth, \'lpol_info\': lpol_info, |
| \'flow_info\': flow_info, \'energy\': self.energy.get_stats()} |
| |
| def forward(self, video: torch.Tensor) -> Dict: |
| B = video.shape[0] |
| z, mu, logvar, z_spatial = self.encode(video) |
| proc_out = self.process_temporal(z) |
| z_processed = proc_out[\'z\'] |
| recon = self.decode(z_spatial) |
| if video.shape[2] == self.config.video.channels: |
| video_compare = video.permute(0, 2, 1, 3, 4) |
| else: |
| video_compare = video |
| if recon.shape != video_compare.shape: |
| recon = F.interpolate(recon, size=video_compare.shape[2:], mode=\'trilinear\', align_corners=False) |
| recon_loss = F.mse_loss(recon, video_compare) |
| kl_loss = 0.5 * torch.mean(mu.pow(2) + logvar.exp() - logvar - 1) |
| coherence_loss = float(1.0 - proc_out[\'temporal_coherence\']) |
| flow_loss_tensor = torch.tensor(0.0, device=recon.device) |
| if \'flow_info\' in proc_out and proc_out[\'flow_info\'].get(\'flow_loss\') is not None: |
| flow_loss_tensor = proc_out[\'flow_info\'][\'flow_loss\'] |
| total_loss = recon_loss + self.config.kl_weight * kl_loss + self.config.temporal_coherence_weight * coherence_loss + 0.01 * flow_loss_tensor |
| flow_loss = flow_loss_tensor.item() if hasattr(flow_loss_tensor, \'item\') else float(flow_loss_tensor) |
| return {\'loss\': total_loss, \'recon_loss\': recon_loss, \'kl_loss\': kl_loss, \'coherence_loss\': coherence_loss, |
| \'flow_loss\': flow_loss, \'recon\': recon, \'z\': z, \'z_spatial\': z_spatial, |
| \'coherence\': proc_out[\'coherence\'], \'temporal_coherence\': proc_out[\'temporal_coherence\'], |
| \'neurogenesis\': proc_out[\'neuro_growth\'], \'expert_growth\': proc_out[\'expert_growth\'], \'energy\': proc_out[\'energy\']} |
| |
| def generate(self, n_frames: int = 16, z_init: torch.Tensor = None, temperature: float = 1.0, batch_size: int = 1) -> torch.Tensor: |
| self.eval() |
| device = next(self.parameters()).device |
| if z_init is None: |
| B = batch_size |
| T = max(1, n_frames // 4) |
| H = self.config.video.height // 8 |
| W = self.config.video.width // 8 |
| z_init = torch.randn(B, self.config.latent_channels, T, H, W, device=device) * temperature |
| with torch.no_grad(): |
| video = self.decode(z_init) |
| if video.shape[2] != n_frames: |
| video = F.interpolate(video, size=(n_frames, video.shape[3], video.shape[4]), mode=\'trilinear\', align_corners=False) |
| return video.clamp(0, 1) |
| |
| def count_params(self) -> int: |
| return sum(p.numel() for p in self.parameters()) |
| |
| def diagnostics(self) -> Dict: |
| total_experts = sum(layer.get_expert_count() for layer in self.layers) |
| return {\'model_version\': self.config.version, \'total_params\': self.count_params(), |
| \'total_experts\': total_experts, \'expert_counts\': [layer.get_expert_count() for layer in self.layers], |
| \'neurogenesis\': self.neurogenesis.get_stats(), \'energy\': self.energy.get_stats()} |
| |
| @classmethod |
| def from_pretrained(cls, pretrained_path: str, device: str = None, **kwargs) -> "NutataModel": |
| from huggingface_hub import snapshot_download |
| if os.path.isdir(pretrained_path): |
| load_dir = pretrained_path |
| else: |
| load_dir = snapshot_download(repo_id=pretrained_path) |
| config_path = os.path.join(load_dir, "config.json") |
| dynamic_state = None |
| if os.path.exists(config_path): |
| with open(config_path, \'r\') as f: |
| config_dict = json.load(f) |
| dynamic_state = config_dict.pop(\'_dynamic_state\', None) |
| config_dict.pop(\'_class_name\', None) |
| config_dict.pop(\'_version\', None) |
| config_dict.pop(\'_architecture\', None) |
| config_dict.update(kwargs) |
| config = NutataModelConfig.from_dict(config_dict) |
| else: |
| config = NutataModelConfig(**kwargs) |
| expert_counts = None |
| neuron_count = 64 |
| neuro_proj_in = None # Will use max_neurons by default |
| adaptive_proj_in = None |
| adaptive_proj_out = None |
| if dynamic_state is not None: |
| expert_counts = dynamic_state.get(\'expert_counts\') |
| neuron_count = dynamic_state.get(\'neuron_count\', 64) |
| neuro_proj_in = dynamic_state.get(\'neuro_proj_in\', neuron_count) |
| adaptive_proj_in = dynamic_state.get(\'adaptive_proj_in\') |
| adaptive_proj_out = dynamic_state.get(\'adaptive_proj_out\') |
| model = cls(config, expert_counts=expert_counts) |
| if neuron_count != model.neurogenesis.n_neurons.item(): |
| model.neurogenesis.resize(neuron_count) |
| # For backward compatibility with old checkpoints that had smaller neuro_proj |
| if neuro_proj_in and neuro_proj_in != model.neuro_proj.in_features: |
| model.neuro_proj = nn.Linear(neuro_proj_in, config.d_model) |
| if adaptive_proj_in and adaptive_proj_out: |
| model.encoder.adaptive_proj = nn.Linear(adaptive_proj_in, adaptive_proj_out) |
| model_path = os.path.join(load_dir, "pytorch_model.bin") |
| if os.path.exists(model_path): |
| state_dict = torch.load(model_path, map_location=\'cpu\') |
| missing, unexpected = model.load_state_dict(state_dict, strict=False) |
| if missing: |
| logger.warning(f"Missing keys: {len(missing)}") |
| if unexpected: |
| logger.warning(f"Unexpected keys: {len(unexpected)}") |
| if device is None: |
| device = \'cuda\' if torch.cuda.is_available() else \'cpu\' |
| return model.to(device) |
| ''' |
| return code |
|
|
def _generate_readme(self) -> str:
    """Generate the README.md model card for the HuggingFace Hub.

    Returns:
        str: The complete README.md contents (YAML front matter,
        architecture overview, usage examples, licensing), rendered
        from the live model config and diagnostics at save time.
    """
    # Snapshot live diagnostics (param count, expert counts, neurogenesis
    # stats) so the card reflects the model state at export time.
    diag = self.diagnostics()

    # NOTE: inside this f-string, doubled braces ({{ }}) render as literal
    # braces in the generated Markdown code examples.
    readme = f'''---
license: other
license_name: ame-web-studio-proprietary
license_link: LICENSE
library_name: pytorch
tags:
- video
- video-generation
- video-understanding
- cognitive-architecture
- world-model
- nexus
- earcp
- lpol
- neurogenesis
- 3d-vae
- gqa
language:
- en
pipeline_tag: video-classification
---

# NUTATA v{self.config.version}

<p align="center">
<img src="https://img.shields.io/badge/NUTATA-blue?style=for-the-badge" alt="NUTATA"/>
<img src="https://img.shields.io/badge/Version-{self.config.version}-green?style=for-the-badge" alt="Version"/>
<img src="https://img.shields.io/badge/License-Proprietary-red?style=for-the-badge" alt="License"/>
</p>

## 🧠 Cognitive Video World Model

**NUTATA** is a revolutionary cognitive architecture for video understanding and generation, developed by **Mike Amega (Logo)** at **Ame Web Studio**.

This model combines multiple novel components into a unified framework for learning video representations with cognitive capabilities.

## 🏗️ Architecture Overview

```
┌─────────────────────────────────────────────────────────────────────┐
│                        NUTATA v{self.config.version}                │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  ┌──────────────┐    ┌──────────────────┐    ┌──────────────────┐   │
│  │   3D VAE     │    │    Cognitive     │    │     3D VAE       │   │
│  │   Encoder    │───▶│    Processor     │───▶│    Decoder       │   │
│  │  (Causal)    │    │ (EARCP+LPOL+GQA) │    │                  │   │
│  └──────────────┘    └──────────────────┘    └──────────────────┘   │
│                                                                     │
│  ┌───────────────────────────────────────────────────────────────┐  │
│  │                     Cognitive Components                      │  │
│  │  ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌────────┐  │  │
│  │  │  LPOL   │ │  EARCP  │ │   GQA   │ │  Neuro- │ │  Flow  │  │  │
│  │  │ Memory  │ │ Experts │ │  Attn   │ │ genesis │ │ Pred.  │  │  │
│  │  └─────────┘ └─────────┘ └─────────┘ └─────────┘ └────────┘  │  │
│  └───────────────────────────────────────────────────────────────┘  │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘
```

### Key Components

| Component | Description |
|-----------|-------------|
| **3D Causal VAE** | Spatiotemporal encoding with causal convolutions for video compression |
| **VideoLPOL** | 9-domain Long-term Procedural & Operational Learning memory system |
| **VideoEARCP** | Ensemble Auto-Regulated by Coherence & Performance with dynamic experts |
| **VideoGQA** | Grouped Query Attention with RoPE for efficient long-sequence processing |
| **Neurogenesis** | Dynamic neural capacity adaptation based on task complexity |
| **Flow Prediction** | Latent optical flow for temporal coherence and motion understanding |
| **Energy System** | Cognitive load management and resource allocation |

### LPOL Memory Domains

1. **Motion** - Movement patterns and trajectories
2. **Appearance** - Visual features and textures
3. **Temporal** - Time dynamics and sequences
4. **Spatial** - Spatial relationships and layouts
5. **Object** - Object persistence and tracking
6. **Scene** - Scene context and semantics
7. **Action** - Action recognition patterns
8. **Causality** - Cause-effect relationships
9. **Physics** - Physical dynamics understanding

## 📊 Model Specifications

| Specification | Value |
|--------------|-------|
| **Parameters** | {diag["total_params"]:,} |
| **Architecture** | {self.config.codename} |
| **Input Shape** | `[B, {self.config.video.channels}, {self.config.video.n_frames}, {self.config.video.height}, {self.config.video.width}]` |
| **Latent Dim** | {self.config.d_model} |
| **EARCP Layers** | {self.config.n_layers} |
| **Total Experts** | {diag["total_experts"]} |
| **Expert Counts** | {diag["expert_counts"]} |
| **GQA Heads** | {self.config.gqa_num_heads} ({self.config.gqa_num_kv_groups} KV groups) |
| **Memory Domains** | {len(self.config.domain_types)} |
| **Neurons** | {diag["neurogenesis"]["total_neurons"]} |

## 🚀 Quick Start

### Installation

```bash
pip install torch torchvision huggingface_hub tqdm
```

### Loading the Model

```python
from nutata_model import NutataModel

# Load from HuggingFace Hub
model = NutataModel.from_pretrained("{self.config.hub_model_id}")

# Or load locally
model = NutataModel.from_pretrained("./path/to/model")
```

### Video Reconstruction

```python
import torch

# Create sample video: [B, C, T, H, W]
video = torch.randn(1, 3, 16, 64, 64).cuda()

# Forward pass
model.eval()
with torch.no_grad():
    outputs = model(video)

print(f"Reconstruction shape: {{outputs['recon'].shape}}")
print(f"Loss: {{outputs['loss'].item():.4f}}")
print(f"EARCP Coherence: {{outputs['coherence']:.3f}}")
print(f"Temporal Coherence: {{outputs['temporal_coherence']:.3f}}")
```

### Video Generation

```python
# Generate video from scratch
model.eval()
with torch.no_grad():
    generated = model.generate(n_frames=16, temperature=0.8)
print(f"Generated shape: {{generated.shape}}")

# Generate from initial latent
z_init = torch.randn(1, {self.config.latent_channels}, 4, 8, 8).cuda()
with torch.no_grad():
    generated = model.generate(n_frames=16, z_init=z_init)
```

### Future Frame Prediction

```python
# Encode context frames
video_context = torch.randn(1, 3, 8, 64, 64).cuda()
model.eval()

with torch.no_grad():
    z, _, _, _ = model.encode(video_context)

# Predict future frames
z_future = model.predict_frames(z, n_frames=8)
print(f"Predicted latent shape: {{z_future.shape}}")
```

### Video Continuation

```python
# Continue an existing video
video_input = torch.randn(1, 3, 8, 64, 64).cuda()

model.eval()
with torch.no_grad():
    continued = model.continue_video(video_input, n_frames=8)
print(f"Continued video shape: {{continued.shape}}")  # [1, 3, 16, 64, 64]
```

### Fine-tuning

```python
from torch.optim import AdamW

# Load pretrained model
model = NutataModel.from_pretrained("{self.config.hub_model_id}")
model.train()

# Setup optimizer
optimizer = AdamW(model.parameters(), lr=1e-5)

# Fine-tune on your dataset
for batch in dataloader:
    video = batch['video'].cuda()
    outputs = model(video)
    loss = outputs['loss']

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Save fine-tuned model
model.save_pretrained("./my_finetuned_model")
```

## 📁 Repository Structure

```
.
├── README.md              # This file
├── config.json            # Model configuration with dynamic state
├── pytorch_model.bin      # Model weights
├── architecture.json      # Architecture details
├── training_state.json    # Training state (for resuming)
├── nutata_model.py        # Model source code
├── requirements.txt       # Dependencies
└── LICENSE                # Proprietary license
```

## 🔬 Technical Details

### Grouped Query Attention (GQA)

GQA reduces memory usage by sharing Key-Value heads across query heads:
- **Query heads**: {self.config.gqa_num_heads}
- **KV groups**: {self.config.gqa_num_kv_groups}
- **Memory savings**: {int((1 - self.config.gqa_num_kv_groups / self.config.gqa_num_heads) * 100)}%

### EARCP (Ensemble Auto-Regulated by Coherence & Performance)

Dynamic expert routing with automatic growth:
- Initial experts: {len(self.config.expert_types)}
- Maximum experts: {self.config.max_experts}
- Growth threshold: coherence < {self.config.growth_threshold_coherence}

### Neurogenesis

Adaptive neural capacity:
- Initial neurons: 64
- Maximum neurons: {self.config.max_neurons}
- Birth threshold: coherence > {self.config.neuron_birth_threshold}

### Checkpoint Compatibility

This model uses a specific architecture structure for checkpoint compatibility:
- `EncoderStage`/`DecoderStage` as `nn.ModuleList` (creates `encoders.0.0`, `encoders.0.1`, etc.)
- `temporal_coherence.smooth` as depthwise `Conv1d` (groups=d_model)
- `lpol.memory_attn` as `nn.ModuleDict` with q/k/v/o_proj
- `VideoExpert` with `fc1`/`fc2` Linear layers
- `attn_norm` as `LayerNorm` with `elementwise_affine=False`

## 📈 Training Metrics

The model was trained with the following target metrics:
- **PSNR**: > 25 dB
- **SSIM**: > 0.95
- **Temporal Coherence**: > 0.35

## 🔧 Model Diagnostics

```python
# Get comprehensive diagnostics
diag = model.diagnostics()
print(f"Total parameters: {{diag['total_params']:,}}")
print(f"Expert counts: {{diag['expert_counts']}}")
print(f"Neurogenesis stats: {{diag['neurogenesis']}}")
print(f"Energy stats: {{diag['energy']}}")
```

## ⚠️ License

This model is released under the **Ame Web Studio Proprietary License**.

**IMPORTANT**: This is NOT an open-source model. Usage is restricted to:
- Personal research and evaluation
- Non-commercial academic use
- Commercial use requires explicit licensing agreement

For licensing inquiries, contact: **contact@amewebstudio.com**

## 📚 Citation

If you use this model in your research, please cite:

```bibtex
@software{{nutata_model,
  author = {{Amega, Mike (Logo)}},
  title = {{NUTATA: Cognitive Video World Model}},
  version = {{{self.config.version}}},
  year = {{2026}},
  publisher = {{Ame Web Studio}},
  url = {{https://huggingface.co/{self.config.hub_model_id}}}
}}
```

## 🔗 Links

- **Author**: Mike Amega (Logo)
- **Organization**: Ame Web Studio
- **Website**: [amewebstudio.com](https://amewebstudio.com)
- **Related Projects**: NEXUS-Core, EARCP Library, LPOL Architecture

---

<p align="center">
<b>Built with 🧠 by Ame Web Studio</b><br>
<i>"Learning to Understand and Generate Video with Cognition"</i>
</p>
'''
    return readme
|
|
def _generate_license(self) -> str:
    """Generate the proprietary LICENSE file text.

    Returns:
        str: The full Ame Web Studio proprietary license as a plain
        string; written verbatim alongside the model artifacts when the
        model is saved or pushed to the Hub.
    """
    # The license text is a fixed literal; nothing is interpolated.
    return """AME WEB STUDIO PROPRIETARY LICENSE
Version 1.0, January 2026

Copyright (c) 2026 Mike Amega (Logo) - Ame Web Studio
All Rights Reserved.

================================================================================
NUTATA License
================================================================================

This software and associated documentation files (the "Software") are the
proprietary property of Ame Web Studio and Mike Amega ("Logo").

GRANT OF LICENSE
----------------

Subject to the terms of this license, you are granted a limited, non-exclusive,
non-transferable license to:

1. PERSONAL USE: Use the Software for personal, non-commercial research and
   evaluation purposes.

2. ACADEMIC USE: Use the Software for non-commercial academic research,
   provided that:
   - Proper attribution is given in any publications
   - The Software is not redistributed
   - No commercial benefit is derived

RESTRICTIONS
------------

You may NOT:

1. Use the Software for any commercial purpose without a separate commercial
   license agreement with Ame Web Studio.

2. Modify, adapt, translate, reverse engineer, decompile, disassemble, or
   create derivative works based on the Software for commercial purposes.

3. Redistribute, sublicense, rent, lease, or lend the Software to any third
   party.

4. Remove or alter any proprietary notices, labels, or marks on the Software.

5. Use the Software to train competing AI models for commercial deployment.

6. Claim ownership or authorship of the Software or its components.

COMMERCIAL LICENSING
--------------------

For commercial use, including but not limited to:
- Integration into commercial products or services
- Use in production environments
- Offering services based on the Software

Please contact: contact@amewebstudio.com

INTELLECTUAL PROPERTY
---------------------

The Software embodies proprietary algorithms and architectures including but
not limited to:
- NEXUS Cognitive Architecture
- EARCP (Ensemble Auto-Regulated by Coherence & Performance)
- LPOL (Long-term Procedural & Operational Learning)
- VideoGQA (Video Grouped Query Attention)
- Neurogenesis mechanisms
- Temporal Coherence algorithms
- Flow Prediction systems

These components are trade secrets and proprietary technology of Ame Web Studio.

WARRANTY DISCLAIMER
-------------------

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.

LIMITATION OF LIABILITY
-----------------------

IN NO EVENT SHALL AME WEB STUDIO, MIKE AMEGA, OR ANY CONTRIBUTORS BE LIABLE
FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

TERMINATION
-----------

This license is effective until terminated. Your rights under this license
will terminate automatically without notice if you fail to comply with any
of its terms.

GOVERNING LAW
-------------

This license shall be governed by and construed in accordance with the laws
of Canada, without regard to its conflict of law provisions.

================================================================================

For licensing inquiries:
- Email: contact@amewebstudio.com
- Website: https://amewebstudio.com

================================================================================
"""
|
|
|
|
| |
| |
| |
|
|
|
|
class SyntheticVideoDataset(Dataset):
    """Generate synthetic videos with moving shapes for testing.

    Each sample renders 2-4 colored shapes (circle, square, triangle)
    bouncing inside the frame over a vertical gradient background.
    Samples are deterministic per index: item ``idx`` always yields the
    same video.
    """

    def __init__(
        self,
        n_videos: int = 1000,
        n_frames: int = 16,
        height: int = 64,
        width: int = 64,
    ):
        # Nothing is precomputed; videos are rendered on the fly in
        # __getitem__ from the index-seeded RNG.
        self.n_videos = n_videos
        self.n_frames = n_frames
        self.height = height
        self.width = width

    def __len__(self) -> int:
        return self.n_videos

    def __getitem__(self, idx: int) -> Dict:
        """Render video ``idx`` as a dict with key 'video' of shape [T, C, H, W]."""
        # FIX: use a local RandomState seeded by the index instead of
        # np.random.seed(idx). RandomState(idx) draws the identical MT19937
        # sequence the global functions would after seeding, so samples are
        # byte-for-byte unchanged -- but the process-global RNG is no longer
        # clobbered (which silently broke reproducibility for any other
        # consumer of np.random, e.g. in DataLoader workers).
        rng = np.random.RandomState(idx)

        video = np.zeros((self.n_frames, self.height, self.width, 3), dtype=np.float32)

        # Vertical gradient background, 0 at the top to 0.3 at the bottom.
        bg = np.linspace(0, 0.3, self.height).reshape(-1, 1, 1)
        video += bg

        n_objects = rng.randint(2, 5)

        for obj in range(n_objects):
            color = rng.rand(3) * 0.7 + 0.3
            x = rng.rand() * (self.width - 15) + 5
            y = rng.rand() * (self.height - 15) + 5
            vx = rng.randn() * 2
            vy = rng.randn() * 2
            size = rng.randint(5, 12)
            shape = rng.choice(["circle", "square", "triangle"])

            for t in range(self.n_frames):
                x += vx
                y += vy

                # Bounce off the frame edges with 10% velocity damping.
                if x < size or x > self.width - size:
                    vx = -vx * 0.9
                    x = np.clip(x, size, self.width - size)
                if y < size or y > self.height - size:
                    vy = -vy * 0.9
                    y = np.clip(y, size, self.height - size)

                ix, iy = int(x), int(y)

                if shape == "circle":
                    for dx in range(-size, size + 1):
                        for dy in range(-size, size + 1):
                            if dx * dx + dy * dy <= size * size:
                                px, py = ix + dx, iy + dy
                                if 0 <= px < self.width and 0 <= py < self.height:
                                    video[t, py, px] = color
                elif shape == "square":
                    x1, x2 = max(0, ix - size // 2), min(self.width, ix + size // 2)
                    y1, y2 = max(0, iy - size // 2), min(self.height, iy + size // 2)
                    video[t, y1:y2, x1:x2] = color
                else:
                    # Triangle: rows narrow linearly from base to apex.
                    for dy in range(size):
                        width_at_y = (size - dy) // 2
                        for dx in range(-width_at_y, width_at_y + 1):
                            px, py = ix + dx, iy + dy - size // 2
                            if 0 <= px < self.width and 0 <= py < self.height:
                                video[t, py, px] = color

        video = np.clip(video, 0, 1)
        # [T, H, W, C] -> [T, C, H, W] for the model's channel-first layout.
        video = torch.tensor(video).permute(0, 3, 1, 2)

        return {"video": video}
|
|
|
|
class VideoFolderDataset(Dataset):
    """Load videos from a folder (placeholder for real datasets).

    Only file discovery is implemented. __getitem__ currently returns
    min-max normalized random noise instead of decoded frames -- wire in
    a real decoder (e.g. torchvision.io) before using this for training.
    """

    def __init__(
        self,
        folder_path: str,
        n_frames: int = 16,
        height: int = 64,
        width: int = 64,
        transform=None,
    ):
        # transform is stored for API compatibility but is not applied by
        # the placeholder __getitem__ below.
        self.folder_path = folder_path
        self.n_frames = n_frames
        self.height = height
        self.width = width
        self.transform = transform

        # FIX: import glob once instead of re-executing `import glob`
        # on every iteration of the extension loop.
        import glob

        self.video_files = []
        if os.path.exists(folder_path):
            for ext in ["*.mp4", "*.avi", "*.mov", "*.mkv"]:
                self.video_files.extend(glob.glob(os.path.join(folder_path, ext)))

        if not self.video_files:
            logger.warning(f"No video files found in {folder_path}")

    def __len__(self) -> int:
        # Report at least 1 so an empty folder still yields one
        # (placeholder) sample instead of an empty dataset.
        return max(1, len(self.video_files))

    def __getitem__(self, idx: int) -> Dict:
        # Placeholder: random noise normalized to [0, 1]; real decoding
        # of self.video_files[idx] is not implemented yet.
        video = torch.randn(self.n_frames, 3, self.height, self.width)
        video = (video - video.min()) / (video.max() - video.min())
        return {"video": video}
|
|
|
|
| |
| |
| |
|
|
|
|
def train_video_model(
    model: NutataModel,
    dataset: Dataset,
    config: NutataModelConfig,
    save_path: str = "./nutata_model_v1_1.pt",
    push_to_hub: bool = False,
    hub_token: str = None,
):
    """
    Train NUTATA v1.1

    Args:
        model: NutataModel instance
        dataset: Training dataset
        config: Model configuration
        save_path: Path to save checkpoint
        push_to_hub: Whether to push to HuggingFace Hub after training
        hub_token: HuggingFace API token (required if push_to_hub=True)

    Returns:
        The trained model. Best and final checkpoints are written next to
        ``save_path`` (``*_best`` / ``*_final`` directories).
    """
    logger.info("\n" + "=" * 70)
    logger.info("🚀 NUTATA v1.1 TRAINING")
    logger.info("=" * 70)

    # FIX: derive the target device from the model itself instead of relying
    # on a module-level `device` global existing at call time.
    device = next(model.parameters()).device

    dataloader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=2,
        drop_last=True,
        pin_memory=True,
    )

    logger.info(f"📊 Dataset: {len(dataset)} videos")
    logger.info(f"📊 Batches per epoch: {len(dataloader)}")
    logger.info(f"📊 Batch size: {config.batch_size}")
    logger.info(f"📊 Epochs: {config.epochs}")

    optimizer = AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=0.01,
        betas=(0.9, 0.95),
    )

    total_steps = config.epochs * len(dataloader)

    def lr_lambda(step):
        # Linear warmup followed by cosine decay. The max(1, ...) guards
        # fix a ZeroDivisionError when warmup_steps == 0 or when the
        # schedule has no post-warmup steps.
        if step < config.warmup_steps:
            return step / max(1, config.warmup_steps)
        progress = (step - config.warmup_steps) / max(1, total_steps - config.warmup_steps)
        return 0.5 * (1 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    scaler = GradScaler()

    model.train()
    best_loss = float("inf")
    global_step = 0
    # FIX: initialized before the loop so the post-training summary does not
    # raise NameError when config.epochs == 0.
    epoch_metrics = {
        "loss": 0,
        "recon": 0,
        "kl": 0,
        "coherence": 0,
        "psnr": 0,
        "ssim": 0,
    }

    for epoch in range(config.epochs):
        epoch_metrics = {
            "loss": 0,
            "recon": 0,
            "kl": 0,
            "coherence": 0,
            "psnr": 0,
            "ssim": 0,
        }
        total_coherence = 0
        total_neuro = 0
        total_experts = 0

        pbar = tqdm(
            enumerate(dataloader),
            total=len(dataloader),
            desc=f"Epoch {epoch + 1}/{config.epochs}",
            ncols=120,
        )

        for batch_idx, batch in pbar:
            video = batch["video"].to(device)

            # Datasets may yield [B, T, C, H, W]; the model expects
            # channel-first [B, C, T, H, W]. Generalized from a
            # hard-coded 3 to the configured channel count.
            if video.dim() == 5 and video.shape[1] != config.video.channels:
                video = video.permute(0, 2, 1, 3, 4)

            optimizer.zero_grad()

            with autocast():
                outputs = model(video)
                loss = outputs["loss"]

            # AMP recipe: scale, backprop, unscale before clipping so the
            # clip threshold applies to true gradients, then step.
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            global_step += 1

            # Quality metrics computed against the same layout the model
            # reconstructed.
            with torch.no_grad():
                video_comp = (
                    video
                    if video.shape[1] == config.video.channels
                    else video.permute(0, 2, 1, 3, 4)
                )
                psnr = compute_psnr(outputs["recon"], video_comp)
                ssim = compute_ssim(outputs["recon"], video_comp)

            epoch_metrics["loss"] += outputs["loss"].item()
            epoch_metrics["recon"] += outputs["recon_loss"].item()
            epoch_metrics["kl"] += outputs["kl_loss"].item()
            coh_loss = outputs["coherence_loss"]
            # FIX: the original condition was inverted -- when coh_loss had
            # .item() it accumulated the tensor itself instead of calling it.
            epoch_metrics["coherence"] += (
                coh_loss.item() if hasattr(coh_loss, "item") else float(coh_loss)
            )
            epoch_metrics["psnr"] += psnr
            epoch_metrics["ssim"] += ssim
            total_coherence += outputs["coherence"]
            total_neuro += outputs["neurogenesis"]
            total_experts += outputs["expert_growth"]

            if batch_idx % 5 == 0:
                n = batch_idx + 1
                pbar.set_postfix(
                    {
                        "loss": f"{epoch_metrics['loss'] / n:.4f}",
                        "psnr": f"{epoch_metrics['psnr'] / n:.1f}",
                        "ssim": f"{epoch_metrics['ssim'] / n:.3f}",
                        "coh": f"{total_coherence / n:.3f}",
                        "lr": f"{scheduler.get_last_lr()[0]:.6f}",
                    }
                )

        n_batches = len(dataloader)
        for k in epoch_metrics:
            epoch_metrics[k] /= n_batches

        logger.info(
            f"Epoch {epoch + 1}/{config.epochs} | "
            f"Loss: {epoch_metrics['loss']:.4f} | "
            f"PSNR: {epoch_metrics['psnr']:.2f} dB | "
            f"SSIM: {epoch_metrics['ssim']:.4f} | "
            f"Coherence: {total_coherence / n_batches:.3f} | "
            f"Neuro: +{total_neuro} | "
            f"Experts: +{total_experts}"
        )

        if epoch_metrics["loss"] < best_loss:
            best_loss = epoch_metrics["loss"]

            # Full training-state checkpoint for resuming.
            checkpoint = {
                "model": model.state_dict(),
                "config": config.to_dict(),
                "epoch": epoch,
                "loss": best_loss,
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "metrics": epoch_metrics,
            }
            torch.save(checkpoint, save_path)
            logger.info(f"  💾 Best model saved (loss: {best_loss:.4f})")

            # Also export in HuggingFace layout.
            save_dir = save_path.replace(".pt", "_best")
            model.save_pretrained(save_dir)

    logger.info("\n🎉 Training complete!")
    logger.info(f"📊 Best loss: {best_loss:.4f}")
    logger.info(f"📊 Final PSNR: {epoch_metrics['psnr']:.2f} dB")
    logger.info(f"📊 Final SSIM: {epoch_metrics['ssim']:.4f}")

    # Final export in HuggingFace layout.
    save_dir = save_path.replace(".pt", "_final")
    model.save_pretrained(save_dir)

    # Optional Hub upload; both the call-site flag and the config flag
    # must be set.
    if push_to_hub and config.push_to_hub:
        if hub_token is None:
            logger.warning(
                "⚠️ No HuggingFace token provided. Set HF_TOKEN environment variable or pass hub_token argument."
            )
            hub_token = os.environ.get("HF_TOKEN")

        if hub_token:
            try:
                url = model.push_to_hub(
                    repo_id=config.hub_model_id,
                    token=hub_token,
                    private=True,
                    commit_message=f"Upload NUTATA v{config.version} - Loss: {best_loss:.4f}, PSNR: {epoch_metrics['psnr']:.2f}",
                    save_directory=save_dir,
                )
                logger.info(f"🚀 Model pushed to HuggingFace Hub: {url}")
            except Exception as e:
                logger.error(f"❌ Failed to push to HuggingFace Hub: {e}")
                import traceback

                traceback.print_exc()
        else:
            logger.warning("⚠️ Skipping HuggingFace Hub push - no token available")

    return model
|
|
|
|
| |
| |
| |
|
|
|
|
def load_nutata_model(
    pretrained_path: str = "amewebstudio/nutata-videomodel-v1.1", device: str = None
) -> NutataModel:
    """Load a pretrained NUTATA model.

    Thin convenience wrapper around :meth:`NutataModel.from_pretrained`.

    Args:
        pretrained_path: HuggingFace repo ID or local directory path.
        device: Target device; auto-detected when None.

    Returns:
        The loaded NutataModel, moved to the resolved device.

    Example:
        >>> model = load_nutata_model()
        >>> video = torch.randn(1, 3, 16, 64, 64).cuda()
        >>> outputs = model(video)
    """
    model = NutataModel.from_pretrained(pretrained_path, device=device)
    return model
|
|
|
|
def create_nutata_model(
    d_model: int = 512,
    n_layers: int = 4,
    n_frames: int = 16,
    height: int = 64,
    width: int = 64,
    device: str = None,
    **kwargs,
) -> NutataModel:
    """
    Build a fresh NUTATA model with a custom configuration.

    Args:
        d_model: Model (embedding) dimension.
        n_layers: Number of EARCP layers.
        n_frames: Number of video frames per clip.
        height: Frame height in pixels.
        width: Frame width in pixels.
        device: Device to place the model on; CUDA is used when available
            and ``device`` is ``None``.
        **kwargs: Extra keyword arguments forwarded to ``NutataModelConfig``.

    Returns:
        A new ``NutataModel`` instance, moved to the resolved device.

    Example:
        >>> model = create_nutata_model(d_model=768, n_layers=6)
    """
    # Resolve the target device up front: prefer CUDA when present.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Assemble the nested video config first, then the full model config.
    vid_cfg = VideoConfig(height=height, width=width, n_frames=n_frames)
    cfg = NutataModelConfig(
        d_model=d_model, n_layers=n_layers, video=vid_cfg, **kwargs
    )

    return NutataModel(cfg).to(device)
|
|
|
|
def benchmark_model(model: NutataModel, n_iterations: int = 10) -> Dict:
    """
    Benchmark forward-pass and generation speed of a NUTATA model.

    Runs a short CUDA warm-up, then times ``n_iterations`` forward passes
    and ``n_iterations`` generation calls on a random video tensor shaped
    from the model's own video config.

    Args:
        model: NutataModel instance to benchmark (put into eval mode here).
        n_iterations: Number of timed iterations per phase.

    Returns:
        Dict with mean/std timings, derived FPS figures, and the device name.
    """
    import time

    model.eval()
    device = next(model.parameters()).device

    # Input dimensions come straight from the model's video configuration.
    vid_cfg = model.config.video
    video = torch.randn(
        1, vid_cfg.channels, vid_cfg.n_frames, vid_cfg.height, vid_cfg.width,
        device=device,
    )

    def _sync():
        # CUDA kernels are asynchronous; synchronize so wall-clock timing is valid.
        if device.type == "cuda":
            torch.cuda.synchronize()

    # Warm-up: a few untimed forwards to trigger lazy init / kernel caching.
    with torch.no_grad():
        for _ in range(3):
            _ = model(video)
    _sync()

    def _timed(run_once):
        # Time `n_iterations` calls of `run_once`, syncing after each.
        samples = []
        for _ in range(n_iterations):
            t0 = time.time()
            with torch.no_grad():
                _ = run_once()
            _sync()
            samples.append(time.time() - t0)
        return samples

    forward_times = _timed(lambda: model(video))
    gen_times = _timed(lambda: model.generate(n_frames=vid_cfg.n_frames))

    return {
        "forward_mean": np.mean(forward_times),
        "forward_std": np.std(forward_times),
        "forward_fps": 1.0 / np.mean(forward_times),
        "generation_mean": np.mean(gen_times),
        "generation_std": np.std(gen_times),
        "generation_fps": 1.0 / np.mean(gen_times),
        "device": str(device),
    }
|
|
|
|
| |
| |
| |
|
|
|
|
def main():
    """
    End-to-end entry point: build, train, and exercise NUTATA v1.1.

    Pipeline: print banner -> configure Hub credentials -> build model ->
    sanity-check a forward pass on a synthetic dataset -> train -> test
    generation, continuation, benchmarking, and save/load round-tripping.

    Returns:
        Tuple ``(model, config)`` — the trained model and its config.
    """
    print("\n" + "=" * 70)
    print("""
    ╔═══════════════════════════════════════════════════════════════════╗
    ║                                                                   ║
    ║                    NUTATA v1.1 COMPLETE                           ║
    ║              Cognitive Video World Model                          ║
    ║                                                                   ║
    ║  • 3D VAE (Causal Spatiotemporal)                                 ║
    ║  • EARCP with Temporal Attention                                  ║
    ║  • LPOL Memory (9 Video Domains)                                  ║
    ║  • GQA for Long Sequences                                         ║
    ║  • Neurogenesis                                                   ║
    ║  • Temporal Coherence Module                                      ║
    ║  • Flow Prediction                                                ║
    ║  • HuggingFace Hub Integration                                    ║
    ║                                                                   ║
    ║  Author: Mike Amega - Ame Web Studio                              ║
    ║  Version: 1.1 (Complete with Architecture Fixes)                  ║
    ║                                                                   ║
    ╚═══════════════════════════════════════════════════════════════════╝
    """)
    print("=" * 70 + "\n")

    # --- HuggingFace Hub configuration -------------------------------------
    HF_TOKEN = os.environ.get("HF_TOKEN", "")
    # NOTE(review): repo id says v1.0 while the model is v1.1 — confirm target.
    HF_REPO_ID = "amewebstudio/nutata-v1.0-finetuned"

    # FIX: only claim the token is configured when one is actually present
    # (previously "Configured" was logged unconditionally). The redundant
    # re-export of HF_TOKEN into os.environ (a no-op) was removed.
    if HF_TOKEN:
        logger.info("🔑 HuggingFace Token: Configured")
    else:
        logger.warning("⚠️ HF_TOKEN not set - Hub push will be skipped")
    logger.info(f"📦 Target Repository: {HF_REPO_ID}")

    # --- Training configuration --------------------------------------------
    config = NutataModelConfig(
        epochs=20,
        batch_size=4,
        n_layers=4,
        d_model=512,
        learning_rate=1e-4,
        warmup_steps=100,
        push_to_hub=True,
        hub_model_id=HF_REPO_ID,
    )

    # FIX: `device` was used below without being defined in this scope.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"\n🔧 Device: {device}")
    if torch.cuda.is_available():
        logger.info(f"🔧 GPU: {torch.cuda.get_device_name(0)}")
        logger.info(
            f"🔧 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB"
        )

    # --- Model --------------------------------------------------------------
    logger.info("\n📦 Creating model...")
    model = NutataModel(config)
    model = model.to(device)

    logger.info("\n📊 Model Diagnostics:")
    diag = model.diagnostics()
    for k, v in diag.items():
        logger.info(f"  {k}: {v}")

    # --- Synthetic dataset ---------------------------------------------------
    logger.info("\n🎞 Creating synthetic video dataset...")
    dataset = SyntheticVideoDataset(
        n_videos=500,
        n_frames=config.video.n_frames,
        height=config.video.height,
        width=config.video.width,
    )
    logger.info(f"  Videos: {len(dataset)}")
    logger.info(
        f"  Shape: [{config.video.n_frames}, {config.video.channels}, {config.video.height}, {config.video.width}]"
    )

    # --- Sanity-check forward pass ------------------------------------------
    logger.info("\n🧪 Testing forward pass...")
    test_batch = dataset[0]["video"].unsqueeze(0).to(device)
    # Dataset may yield [B, T, C, H, W]; the model wants channels first.
    if test_batch.shape[1] != 3:
        test_batch = test_batch.permute(0, 2, 1, 3, 4)

    with torch.no_grad():
        test_out = model(test_batch)
    # FIX: log string was split across source lines (broken f-string).
    logger.info("  ✅ Forward pass successful!")
    logger.info(f"  Input shape: {test_batch.shape}")
    logger.info(f"  Recon shape: {test_out['recon'].shape}")
    logger.info(f"  Loss: {test_out['loss'].item():.4f}")
    logger.info(f"  Coherence: {test_out['coherence']:.3f}")

    # --- Training ------------------------------------------------------------
    logger.info("\n" + "=" * 70)
    model = train_video_model(
        model,
        dataset,
        config,
        save_path="./nutata_model_v1_1.pt",
        push_to_hub=True,
        hub_token=HF_TOKEN,
    )

    logger.info("\n📊 Final Diagnostics:")
    diag = model.diagnostics()
    for k, v in diag.items():
        logger.info(f"  {k}: {v}")

    # --- Generation ----------------------------------------------------------
    logger.info("\n🎬 Testing video generation...")
    model.eval()
    with torch.no_grad():
        generated = model.generate(n_frames=16, temperature=0.8)
    logger.info(f"  Generated shape: {generated.shape}")
    logger.info(
        f"  Value range: [{generated.min().item():.3f}, {generated.max().item():.3f}]"
    )

    # --- Continuation --------------------------------------------------------
    logger.info("\n🎬 Testing video continuation...")
    with torch.no_grad():
        continued = model.continue_video(test_batch, n_frames=8)
    logger.info(f"  Input shape: {test_batch.shape}")
    logger.info(f"  Continued shape: {continued.shape}")

    # --- Benchmark -----------------------------------------------------------
    logger.info("\n⏱️ Benchmarking model...")
    bench = benchmark_model(model, n_iterations=5)
    for k, v in bench.items():
        if isinstance(v, float):
            logger.info(f"  {k}: {v:.4f}")
        else:
            logger.info(f"  {k}: {v}")

    # --- Save / load round trip ----------------------------------------------
    logger.info("\n💾 Testing save/load functionality...")
    save_dir = "./nutata_model_test_save"
    model.save_pretrained(save_dir)

    logger.info("\n📥 Testing model loading...")
    loaded_model = NutataModel.from_pretrained(save_dir, device=str(device))
    # FIX: log string was split across source lines (broken f-string).
    logger.info("  ✅ Model loaded successfully!")
    logger.info(f"  Loaded params: {loaded_model.count_params():,}")

    # Verify the reloaded model produces a finite loss on the same batch.
    with torch.no_grad():
        loaded_out = loaded_model(test_batch)
    logger.info(f"  Loaded model loss: {loaded_out['loss'].item():.4f}")

    logger.info("\n" + "=" * 70)
    # FIX: log string was split across source lines (broken f-string).
    logger.info("✅ ALL TESTS PASSED!")
    logger.info("=" * 70)

    return model, config
|
|
|
|
| |
| |
| |
|
|
# Script entry point: run the full build/train/evaluate pipeline when this
# file is executed directly (not when imported as a module).
if __name__ == "__main__":
    model, config = main()