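"""Generalist comfort-control model in the Decision-Transformer style.

Defines MLPHead (a small MLP projection head) and GeneralistComfortDT, a
time-causal Transformer encoder over per-timestep feature/zone/value tokens
with optional return-to-go and context conditioning, plus action, state, and
return prediction heads.
"""
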
from __future__ import annotations
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLPHead(nn.Module):
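    """Two-hidden-layer MLP head: Linear -> GELU -> LayerNorm -> Linear -> GELU -> Linear."""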
    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int = 512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, out_dim),
        )

    def forward(self, x):
        return self.net(x)


class GeneralistComfortDT(nn.Module):
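    """Transformer policy/value model over tokenized multi-zone comfort observations.

    Each timestep contributes K tokens (feature id + zone id + FiLM-modulated
    scalar value, sharing one positional embedding). Tokens may additionally be
    conditioned on a static context vector and a return-to-go (RTG) signal. A
    time-causal Transformer encoder produces latents from which the heads
    predict per-token action-bin logits, two scalar state estimates per token,
    and a per-timestep return estimate.
    """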
    def __init__(self, config: dict):
        super().__init__()
        self.config = config

        d_model = config["D_MODEL"]
        vocab_size = config["VOCAB_SIZE"]
        max_zones = config["MAX_ZONES"]
        context_dim = config.get("CONTEXT_DIM", 10)
        rtg_dim = config.get("RTG_DIM", 2)
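        # Token embeddings: feature-id and zone-id lookups plus a projection of the
        # scalar value that is FiLM-modulated by per-feature gamma/beta parameters.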
        self.feat_embed = nn.Embedding(vocab_size, d_model)
        self.zone_embed = nn.Embedding(max_zones, d_model)
        self.val_proj = nn.Linear(1, d_model)
        self.val_gamma = nn.Embedding(vocab_size, d_model)
        self.val_beta = nn.Embedding(vocab_size, d_model)
        self.ctx_proj = nn.Linear(context_dim, d_model)
        self.rtg_embed = nn.Linear(rtg_dim, d_model)
        self.pos_embed = nn.Parameter(torch.zeros(1, config["CONTEXT_LEN"], d_model))

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=config["N_HEADS"],
            dim_feedforward=4 * d_model,
            dropout=config["DROPOUT"],
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        self.backbone = nn.TransformerEncoder(enc_layer, num_layers=config["N_LAYERS"])
        self.ln_out = nn.LayerNorm(d_model)
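        # Prediction heads: action-bin logits per token, two scalar state estimates
        # per token, and a return regression head (one output per RTG dimension).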
        self.action_head = MLPHead(d_model, config["NUM_ACTION_BINS"])
        self.state_head = nn.Linear(d_model, 1)
        self.state_head_4h = nn.Linear(d_model, 1)
        self.return_head = MLPHead(d_model, rtg_dim, hidden_dim=256)

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

        nn.init.normal_(self.pos_embed, std=0.02)
        # FiLM parameters start at identity: gamma = 1, beta = 0.
        nn.init.ones_(self.val_gamma.weight)
        nn.init.zeros_(self.val_beta.weight)

    @staticmethod
    def _build_time_causal_mask(T: int, K: int, device: torch.device) -> torch.Tensor:
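        """Boolean (T*K, T*K) mask that is True where attention is disallowed.

        Tokens may attend to all tokens of the same or earlier timesteps; True
        entries block attention to strictly later timesteps, following the
        nn.TransformerEncoder convention (True = masked).
        """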
        L = T * K
        ti = torch.arange(L, device=device) // K
        return ti[None, :] > ti[:, None]

    def forward(
        self,
        feature_ids: torch.Tensor,
        feature_vals: torch.Tensor,
        zone_ids: torch.Tensor,
        attn_mask: torch.Tensor,
        rtg: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        rtg_dropout_prob: float = 0.0,
    ) -> Dict[str, torch.Tensor]:
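        """Run the model on a batch of tokenized zone/feature observations.

        Args:
            feature_ids: (B, T, K) integer feature ids for each token.
            feature_vals: (B, T, K) scalar values for each feature token.
            zone_ids: (B, T, K) integer zone ids for each token.
            attn_mask: (B, T, K) with 1 for real tokens and 0 for padding.
            rtg: optional (B, T, RTG_DIM) return-to-go targets.
            context: optional (B, CONTEXT_DIM) static context vector.
            rtg_dropout_prob: probability (training only) of zeroing the RTG
                conditioning for an entire sample.

        Returns:
            Dict with "action_logits" (B, T, K, NUM_ACTION_BINS),
            "state_preds" / "state_preds_4h" (B, T, K),
            "return_preds" (B, T, RTG_DIM), and "building_latent" (B, D_MODEL).
        """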
        B, T, K = feature_ids.shape
        d_model = self.config["D_MODEL"]

        # Flatten the per-timestep feature axis into the sequence axis.
        flat_fids = feature_ids.reshape(B, -1)
        flat_vals = feature_vals.reshape(B, -1, 1)
        flat_zids = zone_ids.reshape(B, -1)

        # FiLM-style conditioning: per-feature gamma/beta modulate the value projection.
        val_emb = self.val_proj(flat_vals)
        val_emb = self.val_gamma(flat_fids) * val_emb + self.val_beta(flat_fids)

        x_base = (
            self.feat_embed(flat_fids)
            + self.zone_embed(flat_zids)
            + val_emb
        )
        # One positional embedding per timestep, repeated across that step's K tokens.
        pos = self.pos_embed[:, :T, :].unsqueeze(2).expand(-1, -1, K, -1).reshape(1, -1, d_model)
        x_base = x_base + pos

        # Optional static context vector, broadcast to every token.
        if context is not None:
            ctx_emb = self.ctx_proj(context).unsqueeze(1)
            x_base = x_base + ctx_emb

        # Return-to-go conditioning (zeros when no RTG is provided).
        rtg_emb = torch.zeros_like(x_base)
        if rtg is not None:
            flat_rtg = rtg.unsqueeze(2).expand(-1, -1, K, -1).reshape(B, T * K, -1)
            if self.training:
                # Small Gaussian jitter on the RTG targets during training.
                flat_rtg = flat_rtg + torch.randn_like(flat_rtg) * 0.005

            rtg_emb = self.rtg_embed(flat_rtg)

            if self.training:
                rtg_emb = F.dropout(rtg_emb, p=0.1)
                if rtg_dropout_prob > 0.0:
                    # Randomly zero the RTG conditioning for a whole sample.
                    keep = torch.bernoulli(
                        torch.full((B, 1, 1), 1.0 - rtg_dropout_prob, device=x_base.device)
                    )
                    rtg_emb = rtg_emb * keep

        x = x_base + rtg_emb

        # Padding mask (True = padded token) plus the time-causal attention mask.
        flat_mask = attn_mask.reshape(B, -1)
        key_padding_mask = flat_mask == 0
        attn_mask_2d = self._build_time_causal_mask(T, K, device=x.device)

        x_latent = self.backbone(x, mask=attn_mask_2d, src_key_padding_mask=key_padding_mask)
        x_latent = self.ln_out(x_latent)

        action_logits = self.action_head(x_latent).reshape(B, T, K, -1)

        # State/return heads read the latent with the RTG embedding subtracted back out.
        x_phys = x_latent - rtg_emb
        state_preds = self.state_head(x_phys).reshape(B, T, K)
        state_preds_4h = self.state_head_4h(x_phys).reshape(B, T, K)
        return_preds_raw = self.return_head(x_phys).reshape(B, T, K, -1)
        return_preds = return_preds_raw.mean(dim=2)  # average over the K tokens of each timestep

        return {
            "action_logits": action_logits,
            "state_preds": state_preds,
            "state_preds_4h": state_preds_4h,
            "return_preds": return_preds,
            "building_latent": x_latent.mean(dim=1),
        }
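

# Minimal shape-check sketch (illustrative only): the config values and batch
# sizes below are assumptions for a smoke test, not the training configuration.
if __name__ == "__main__":
    cfg = {
        "D_MODEL": 128,
        "VOCAB_SIZE": 64,
        "MAX_ZONES": 8,
        "CONTEXT_DIM": 10,
        "RTG_DIM": 2,
        "CONTEXT_LEN": 32,
        "N_HEADS": 4,
        "N_LAYERS": 2,
        "DROPOUT": 0.1,
        "NUM_ACTION_BINS": 21,
    }
    model = GeneralistComfortDT(cfg).eval()

    B, T, K = 2, 8, 6  # batch, timesteps (<= CONTEXT_LEN), feature tokens per step
    out = model(
        feature_ids=torch.randint(0, cfg["VOCAB_SIZE"], (B, T, K)),
        feature_vals=torch.randn(B, T, K),
        zone_ids=torch.randint(0, cfg["MAX_ZONES"], (B, T, K)),
        attn_mask=torch.ones(B, T, K),
        rtg=torch.randn(B, T, cfg["RTG_DIM"]),
        context=torch.randn(B, cfg["CONTEXT_DIM"]),
    )
    for name, tensor in out.items():
        print(name, tuple(tensor.shape))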