# Controller/training/embeddings.py
from __future__ import annotations
from typing import Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
# ============================================================
# 1. MLP HEAD
# ============================================================
class MLPHead(nn.Module):
    """Small projection head: Linear -> GELU -> LayerNorm -> Linear -> GELU -> Linear."""

    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int = 512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
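# Illustrative shape note: MLPHead(64, 21) maps (B, L, 64) -> (B, L, 21),
# since every nn.Linear in the Sequential acts on the last dimension only.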
# ============================================================
# 2. DECISION TRANSFORMER
# ============================================================
class GeneralistComfortDT(nn.Module):
    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        d_model = config["D_MODEL"]
        vocab_size = config["VOCAB_SIZE"]
        max_zones = config["MAX_ZONES"]
        context_dim = config.get("CONTEXT_DIM", 10)
        rtg_dim = config.get("RTG_DIM", 2)
        self.rtg_dim = rtg_dim  # kept for reshaping RTG sequences in forward()

        # Token embeddings: feature identity, zone identity, and a FiLM-style
        # value embedding (per-feature gamma/beta modulate the projected value).
        self.feat_embed = nn.Embedding(vocab_size, d_model)
        self.zone_embed = nn.Embedding(max_zones, d_model)
        self.val_proj = nn.Linear(1, d_model)
        self.val_gamma = nn.Embedding(vocab_size, d_model)
        self.val_beta = nn.Embedding(vocab_size, d_model)

        # Conditioning inputs: static building context and return-to-go.
        self.ctx_proj = nn.Linear(context_dim, d_model)
        self.rtg_embed = nn.Linear(rtg_dim, d_model)

        # Learned positional embedding, shared by all K tokens of a timestep.
        self.pos_embed = nn.Parameter(torch.zeros(1, config["CONTEXT_LEN"], d_model))

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=config["N_HEADS"],
            dim_feedforward=4 * d_model,
            dropout=config["DROPOUT"],
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        self.backbone = nn.TransformerEncoder(enc_layer, num_layers=config["N_LAYERS"])
        self.ln_out = nn.LayerNorm(d_model)

        # Prediction heads: per-token action logits, current and 4h-ahead
        # state regression, and a per-timestep return estimate.
        self.action_head = MLPHead(d_model, config["NUM_ACTION_BINS"])
        self.state_head = nn.Linear(d_model, 1)
        self.state_head_4h = nn.Linear(d_model, 1)
        self.return_head = MLPHead(d_model, rtg_dim, hidden_dim=256)
        self._init_weights()
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
        nn.init.normal_(self.pos_embed, std=0.02)
        # gamma=1 / beta=0 makes the value modulation start as the identity.
        nn.init.ones_(self.val_gamma.weight)
        nn.init.zeros_(self.val_beta.weight)
    @staticmethod
    def _build_time_causal_mask(T: int, K: int, device: torch.device) -> torch.Tensor:
        # Block-causal mask over T*K tokens: token i may attend to token j only
        # if j's timestep is not later than i's (True entries are blocked).
        L = T * K
        ti = torch.arange(L, device=device) // K
        return ti[None, :] > ti[:, None]
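    # Worked example: with T=3, K=2, ti = [0, 0, 1, 1, 2, 2], so row 0 blocks
    # columns 2..5 -- tokens of timestep 0 attend only within timestep 0,
    # while tokens of timestep 2 attend to all six tokens.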
    def forward(
        self,
        feature_ids: torch.Tensor,   # (B, T, K) integer feature ids
        feature_vals: torch.Tensor,  # (B, T, K) scalar feature values
        zone_ids: torch.Tensor,      # (B, T, K) integer zone ids
        attn_mask: torch.Tensor,     # (B, T, K), 1 = valid token, 0 = padding
        rtg: Optional[torch.Tensor] = None,      # (B, T, rtg_dim) return-to-go
        context: Optional[torch.Tensor] = None,  # (B, context_dim) building context
        rtg_dropout_prob: float = 0.0,
    ) -> Dict[str, torch.Tensor]:
        B, T, K = feature_ids.shape
        d_model = self.config["D_MODEL"]

        # Flatten the (T, K) grid into a single sequence of T*K tokens.
        flat_fids = feature_ids.reshape(B, -1)
        flat_vals = feature_vals.reshape(B, -1, 1)
        flat_zids = zone_ids.reshape(B, -1)

        # FiLM-style value embedding: per-feature gamma/beta modulate the value.
        val_emb = self.val_proj(flat_vals)
        val_emb = self.val_gamma(flat_fids) * val_emb + self.val_beta(flat_fids)
        x_base = self.feat_embed(flat_fids) + self.zone_embed(flat_zids) + val_emb

        # Positional embedding is per timestep, broadcast across the K tokens.
        pos = self.pos_embed[:, :T, :].unsqueeze(2).expand(-1, -1, K, -1).reshape(1, -1, d_model)
        x_base = x_base + pos

        # Static building context is added to every token.
        if context is not None:
            x_base = x_base + self.ctx_proj(context).unsqueeze(1)
        # Return-to-go conditioning, broadcast across the K tokens of each step.
        rtg_emb = torch.zeros_like(x_base)
        if rtg is not None:
            flat_rtg = rtg.unsqueeze(2).expand(-1, -1, K, -1).reshape(B, -1, self.rtg_dim)
            if self.training:
                # Small Gaussian noise regularizes the RTG conditioning.
                flat_rtg = flat_rtg + torch.randn_like(flat_rtg) * 0.005
            rtg_emb = self.rtg_embed(flat_rtg)
            if self.training:
                rtg_emb = F.dropout(rtg_emb, p=0.1)
                if rtg_dropout_prob > 0.0:
                    # Randomly zero the entire RTG conditioning for some samples.
                    keep = torch.bernoulli(
                        torch.full((B, 1, 1), 1.0 - rtg_dropout_prob, device=x_base.device)
                    )
                    rtg_emb = rtg_emb * keep
        x = x_base + rtg_emb
        # Padding mask (True = ignore) plus the block-causal attention mask.
        flat_mask = attn_mask.reshape(B, -1)
        key_padding_mask = flat_mask == 0
        attn_mask_2d = self._build_time_causal_mask(T, K, device=x.device)
        x_latent = self.backbone(x, mask=attn_mask_2d, src_key_padding_mask=key_padding_mask)
        x_latent = self.ln_out(x_latent)
        action_logits = self.action_head(x_latent).reshape(B, T, K, -1)

        # State and return heads read the latent with the RTG embedding
        # subtracted back out.
        x_phys = x_latent - rtg_emb
        state_preds = self.state_head(x_phys).reshape(B, T, K)
        state_preds_4h = self.state_head_4h(x_phys).reshape(B, T, K)
        return_preds_raw = self.return_head(x_phys).reshape(B, T, K, -1)
        return_preds = return_preds_raw.mean(dim=2)  # average over the K tokens
        return {
            "action_logits": action_logits,    # (B, T, K, NUM_ACTION_BINS)
            "state_preds": state_preds,        # (B, T, K)
            "state_preds_4h": state_preds_4h,  # (B, T, K)
            "return_preds": return_preds,      # (B, T, rtg_dim)
            "building_latent": x_latent.mean(dim=1),  # (B, d_model) pooled latent
        }
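

# ============================================================
# 3. SMOKE TEST (illustrative; the config values below are
#    assumptions for shape-checking, not the training config)
# ============================================================
if __name__ == "__main__":
    config = {
        "D_MODEL": 64,
        "VOCAB_SIZE": 100,
        "MAX_ZONES": 8,
        "CONTEXT_DIM": 10,
        "RTG_DIM": 2,
        "CONTEXT_LEN": 16,
        "N_HEADS": 4,
        "DROPOUT": 0.1,
        "N_LAYERS": 2,
        "NUM_ACTION_BINS": 21,
    }
    model = GeneralistComfortDT(config)
    B, T, K = 2, 4, 6  # batch, timesteps, tokens per timestep
    out = model(
        feature_ids=torch.randint(0, config["VOCAB_SIZE"], (B, T, K)),
        feature_vals=torch.randn(B, T, K),
        zone_ids=torch.randint(0, config["MAX_ZONES"], (B, T, K)),
        attn_mask=torch.ones(B, T, K),
        rtg=torch.randn(B, T, config["RTG_DIM"]),
        context=torch.randn(B, config["CONTEXT_DIM"]),
        rtg_dropout_prob=0.1,
    )
    for name, tensor in out.items():
        print(name, tuple(tensor.shape))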