# embeddings.py
from __future__ import annotations

from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


# ============================================================
# 1. MLP HEAD
# ============================================================
class MLPHead(nn.Module):
    """Two-hidden-layer MLP projection head used for the action and return outputs."""

    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int = 512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


# ============================================================
# 2. DECISION TRANSFORMER
# ============================================================
class GeneralistComfortDT(nn.Module):
    """Return-conditioned decision transformer over per-timestep sets of feature tokens.

    Each timestep contributes K tokens (one per feature reading), so the sequence fed to
    the encoder has length T * K. Attention is causal at the timestep level: a token may
    attend to every token of its own or any earlier timestep.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        d_model = config["D_MODEL"]
        vocab_size = config["VOCAB_SIZE"]
        max_zones = config["MAX_ZONES"]
        context_dim = config.get("CONTEXT_DIM", 10)
        rtg_dim = config.get("RTG_DIM", 2)

        # Token embeddings: feature identity, zone identity, and a FiLM-style value
        # embedding (per-feature scale/shift applied to the projected scalar value).
        self.feat_embed = nn.Embedding(vocab_size, d_model)
        self.zone_embed = nn.Embedding(max_zones, d_model)
        self.val_proj = nn.Linear(1, d_model)
        self.val_gamma = nn.Embedding(vocab_size, d_model)
        self.val_beta = nn.Embedding(vocab_size, d_model)

        # Static building context and return-to-go conditioning.
        self.ctx_proj = nn.Linear(context_dim, d_model)
        self.rtg_embed = nn.Linear(rtg_dim, d_model)

        # Learned positional embedding: one vector per timestep, shared across the
        # K feature tokens of that timestep.
        self.pos_embed = nn.Parameter(torch.zeros(1, config["CONTEXT_LEN"], d_model))

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=config["N_HEADS"],
            dim_feedforward=4 * d_model,
            dropout=config["DROPOUT"],
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        self.backbone = nn.TransformerEncoder(enc_layer, num_layers=config["N_LAYERS"])
        self.ln_out = nn.LayerNorm(d_model)

        # Prediction heads.
        self.action_head = MLPHead(d_model, config["NUM_ACTION_BINS"])
        self.state_head = nn.Linear(d_model, 1)
        self.state_head_4h = nn.Linear(d_model, 1)
        self.return_head = MLPHead(d_model, rtg_dim, hidden_dim=256)

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
        nn.init.normal_(self.pos_embed, std=0.02)
        # FiLM parameters start as the identity transform (scale 1, shift 0).
        nn.init.ones_(self.val_gamma.weight)
        nn.init.zeros_(self.val_beta.weight)

    @staticmethod
    def _build_time_causal_mask(T: int, K: int, device: torch.device) -> torch.Tensor:
        """Boolean (T*K, T*K) mask that blocks attention to strictly future timesteps."""
        L = T * K
        ti = torch.arange(L, device=device) // K  # timestep index of each token
        return ti[None, :] > ti[:, None]

    def forward(
        self,
        feature_ids: torch.Tensor,               # (B, T, K) feature vocabulary ids
        feature_vals: torch.Tensor,              # (B, T, K) scalar readings
        zone_ids: torch.Tensor,                  # (B, T, K) zone ids
        attn_mask: torch.Tensor,                 # (B, T, K) 1 = valid token, 0 = padding
        rtg: Optional[torch.Tensor] = None,      # (B, T, rtg_dim) returns-to-go
        context: Optional[torch.Tensor] = None,  # (B, context_dim) static context
        rtg_dropout_prob: float = 0.0,
    ) -> Dict[str, torch.Tensor]:
        B, T, K = feature_ids.shape
        d_model = self.config["D_MODEL"]

        # Flatten the (T, K) token grid into a single sequence of length T*K.
        flat_fids = feature_ids.reshape(B, -1)
        flat_vals = feature_vals.reshape(B, -1, 1)
        flat_zids = zone_ids.reshape(B, -1)

        # FiLM-style value embedding: per-feature scale and shift of the projected value.
        val_emb = self.val_proj(flat_vals)
        val_emb = self.val_gamma(flat_fids) * val_emb + self.val_beta(flat_fids)

        x_base = self.feat_embed(flat_fids) + self.zone_embed(flat_zids) + val_emb

        # Broadcast the per-timestep positional embedding across the K tokens of each step.
        pos = self.pos_embed[:, :T, :].unsqueeze(2).expand(-1, -1, K, -1).reshape(1, -1, d_model)
        x_base = x_base + pos

        if context is not None:
            ctx_emb = self.ctx_proj(context).unsqueeze(1)
            x_base = x_base + ctx_emb

        # Return-to-go conditioning, with noise / dropout regularisation at train time.
        rtg_emb = torch.zeros_like(x_base)
        if rtg is not None:
            flat_rtg = rtg.unsqueeze(2).expand(-1, -1, K, -1).reshape(B, -1, rtg.size(-1))
            if self.training:
                flat_rtg = flat_rtg + torch.randn_like(flat_rtg) * 0.005  # small RTG noise
            rtg_emb = self.rtg_embed(flat_rtg)
            if self.training:
                rtg_emb = F.dropout(rtg_emb, p=0.1)
            if rtg_dropout_prob > 0.0:
                # Drop the entire RTG signal for a random subset of the batch.
                mask = torch.bernoulli(
                    torch.full((B, 1, 1), 1.0 - rtg_dropout_prob, device=x_base.device)
                )
                rtg_emb = rtg_emb * mask

        x = x_base + rtg_emb

        flat_mask = attn_mask.reshape(B, -1)
        key_padding_mask = flat_mask == 0
        attn_mask_2d = self._build_time_causal_mask(T, K, device=x.device)

        x_latent = self.backbone(x, mask=attn_mask_2d, src_key_padding_mask=key_padding_mask)
        x_latent = self.ln_out(x_latent)

        action_logits = self.action_head(x_latent).reshape(B, T, K, -1)

        # The state/return heads predict physical quantities, so the RTG conditioning is
        # subtracted back out before they are applied.
        x_phys = x_latent - rtg_emb
        state_preds = self.state_head(x_phys).reshape(B, T, K)
        state_preds_4h = self.state_head_4h(x_phys).reshape(B, T, K)
        return_preds_raw = self.return_head(x_phys).reshape(B, T, K, -1)
        return_preds = return_preds_raw.mean(dim=2)  # average over the K tokens per timestep

        return {
            "action_logits": action_logits,
            "state_preds": state_preds,
            "state_preds_4h": state_preds_4h,
            "return_preds": return_preds,
            "building_latent": x_latent.mean(dim=1),
        }
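
# ------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module). It shows how the
# config keys read above fit together and checks the output shapes. The concrete
# values (d_model=128, 4 heads, B=2, T=8, K=5, etc.) are illustrative assumptions,
# not values from the original code.
# ------------------------------------------------------------
if __name__ == "__main__":
    _config = {
        "D_MODEL": 128,
        "VOCAB_SIZE": 64,
        "MAX_ZONES": 16,
        "CONTEXT_DIM": 10,
        "RTG_DIM": 2,
        "CONTEXT_LEN": 48,
        "N_HEADS": 4,
        "DROPOUT": 0.1,
        "N_LAYERS": 2,
        "NUM_ACTION_BINS": 11,
    }
    model = GeneralistComfortDT(_config).eval()

    B, T, K = 2, 8, 5  # batch, timesteps, feature tokens per timestep
    feature_ids = torch.randint(0, _config["VOCAB_SIZE"], (B, T, K))
    feature_vals = torch.randn(B, T, K)
    zone_ids = torch.randint(0, _config["MAX_ZONES"], (B, T, K))
    attn_mask = torch.ones(B, T, K)
    rtg = torch.randn(B, T, _config["RTG_DIM"])
    context = torch.randn(B, _config["CONTEXT_DIM"])

    with torch.no_grad():
        out = model(feature_ids, feature_vals, zone_ids, attn_mask, rtg=rtg, context=context)

    # Expected: action_logits (B, T, K, NUM_ACTION_BINS), state_preds (B, T, K),
    # return_preds (B, T, RTG_DIM), building_latent (B, D_MODEL).
    for name, tensor in out.items():
        print(f"{name}: {tuple(tensor.shape)}")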