Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from src.model.config import ModelConfig | |
| from src.model.backbone import Backbone | |
| from src.model.plastic import PlasticLayer | |
| from src.model.branches import BranchRouter | |
| from src.model.verifier import Verifier | |
class ABPTModel(nn.Module):
    """Language model assembled from a backbone plus optional components.

    Depending on the flags in ``cfg``, the backbone output is optionally
    passed through a ``PlasticLayer`` and then turned into logits either by
    a ``BranchRouter`` (optionally followed by a ``Verifier``) or by a plain
    bias-free linear head.
    """

    def __init__(self, cfg: ModelConfig):
        """Build the sub-modules selected by ``cfg``.

        Args:
            cfg: Model configuration. This class reads ``use_plastic``,
                ``use_branches``, ``use_verifier``, ``d_model``,
                ``vocab_size`` and ``diversity_weight`` from it and passes
                the whole object on to each sub-module constructor.
        """
        super().__init__()
        self.cfg = cfg
        self.backbone = Backbone(cfg)

        if cfg.use_plastic:
            self.plastic = PlasticLayer(cfg)

        if cfg.use_branches:
            self.branch_router = BranchRouter(cfg)
        else:
            # Without branches a single linear projection produces the logits.
            self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)

        # The verifier consumes branch logits, so it is only built when
        # branches are enabled as well.
        if cfg.use_verifier and cfg.use_branches:
            self.verifier = Verifier(cfg)

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: torch.Tensor | None = None,
    ) -> dict:
        """Run the model and, when ``targets`` is given, compute the loss.

        Args:
            input_ids: Tensor handed unchanged to the backbone.
            targets: Optional class indices for the cross-entropy loss. Must
                hold one entry per logit row once flattened; it does not
                need to be contiguous.

        Returns:
            A dict that always contains ``backbone_hidden``,
            ``layer_outputs``, ``hidden`` and ``logits``. With branches
            enabled it also contains ``diversity_loss`` and
            ``branch_logits``; with the verifier enabled as well,
            ``confidence`` and ``branch_weights``. When ``targets`` is not
            ``None`` it additionally contains ``ce_loss`` and ``loss``.
        """
        result = {}

        backbone_out = self.backbone(input_ids)
        hidden = backbone_out["hidden"]
        result["backbone_hidden"] = hidden
        result["layer_outputs"] = backbone_out["layer_outputs"]

        if self.cfg.use_plastic:
            hidden = self.plastic(hidden)

        if self.cfg.use_branches:
            branch_out = self.branch_router(hidden)
            result["diversity_loss"] = branch_out["diversity_loss"]
            result["branch_logits"] = branch_out["branch_logits"]
            if self.cfg.use_verifier:
                # The verifier turns the per-branch logits into the final ones.
                verifier_out = self.verifier(branch_out["branch_logits"])
                logits = verifier_out["logits"]
                result["confidence"] = verifier_out["confidence"]
                result["branch_weights"] = verifier_out["branch_weights"]
            else:
                logits = branch_out["logits"]
        else:
            logits = self.lm_head(hidden)

        # "hidden" is the post-plastic state; the raw backbone output stays
        # available under "backbone_hidden".
        result["hidden"] = hidden
        result["logits"] = logits

        if targets is not None:
            # reshape (not view): targets are frequently a slice such as
            # tokens[:, 1:], which is non-contiguous and makes view() raise.
            vocab = logits.size(-1)
            ce_loss = F.cross_entropy(
                logits.reshape(-1, vocab),
                targets.reshape(-1),
            )
            total_loss = ce_loss
            if self.cfg.use_branches:
                total_loss = (
                    total_loss
                    + self.cfg.diversity_weight * result["diversity_loss"]
                )
            result["loss"] = total_loss
            result["ce_loss"] = ce_loss

        return result

    def param_count(self) -> int:
        """Return the total number of elements across all parameters."""
        return sum(p.numel() for p in self.parameters())

    def param_count_str(self) -> str:
        """Return the parameter count as ``"<x.x>M"`` or ``"<x.x>K"``.

        Counts of one million or more are shown in millions; anything
        smaller is shown in thousands.
        """
        n = self.param_count()
        if n >= 1_000_000:
            return f"{n / 1_000_000:.1f}M"
        return f"{n / 1_000:.1f}K"