| """Hugging Face model wrapper for HydrAMP.""" |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
| import torch |
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
| from transformers.utils import ModelOutput |
|
|
| from .config import HydrAMPConfig |
| from .hydramp import HydrAMPDecoder, HydrAMPEncoder |
|
|
|
|
| @dataclass |
| class HydrAMPOutput(ModelOutput): |
| """HydrAMP forward outputs.""" |
|
|
| logits: torch.Tensor | None = None |
| mean: torch.Tensor | None = None |
| log_std: torch.Tensor | None = None |
|
|
|
|
| class HydrAMPModel(PreTrainedModel): |
| """HydrAMP model with HF `AutoModel` compatibility.""" |
|
|
| config_class = HydrAMPConfig |
| base_model_prefix = "hydramp" |
|
|
| def __init__(self, config: HydrAMPConfig) -> None: |
| super().__init__(config) |
| if len(config.default_condition) != config.condition_dim: |
| raise ValueError( |
| f"default_condition must contain {config.condition_dim} values, got {len(config.default_condition)}." |
| ) |
|
|
| self.encoder = HydrAMPEncoder( |
| vocab_size=config.vocab_size, |
| embedding_dim=config.embedding_dim, |
| latent_dim=config.latent_dim, |
| sequence_length=config.sequence_length, |
| gru_hidden_size=config.encoder_hidden_size, |
| ) |
| self.decoder = HydrAMPDecoder( |
| sequence_length=config.sequence_length, |
| latent_dim=config.latent_dim, |
| condition_dim=config.condition_dim, |
| hidden_size=config.decoder_hidden_size, |
| vocab_size=config.vocab_size, |
| ) |
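        # Keep the default condition on the model's device/dtype without
        # writing it to checkpoints: persistent=False leaves the buffer
        # out of state_dict().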
        self.register_buffer(
            "default_condition",
            torch.tensor(config.default_condition, dtype=torch.float32),
            persistent=False,
        )
        self.post_init()

    def forward_latent_positions(
        self,
        z: torch.Tensor,
        num_steps: int | None = None,
        condition: torch.Tensor | None = None,
        *,
        return_logits: bool = True,
    ) -> CausalLMOutputWithPast:
        """Decode latent vectors to sequence distributions (GRUVAE-style API).

        Output length is fixed at ``config.sequence_length``. If ``num_steps``
        is passed, it must equal that value.
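
        Example (illustrative shapes only; assumes a constructed ``model``
        and ``z`` of shape ``(batch, latent_dim)``)::

            out = model.forward_latent_positions(z)
            out.logits  # (batch, sequence_length, vocab_size)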
| """ |
| fixed = self.config.sequence_length |
| if num_steps is None: |
| num_steps = fixed |
| elif num_steps != fixed: |
| msg = f"HydrAMP decoder length is fixed at {fixed}; got num_steps={num_steps}." |
| raise ValueError(msg) |
|
|
| if condition is None: |
| condition = self.default_condition.unsqueeze(0).expand(z.shape[0], -1) |
| condition = condition.to(device=z.device, dtype=z.dtype) |
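        # The decoder consumes the latent and the condition as one
        # concatenated vector: [z | condition].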
        decoder_input = torch.cat([z, condition], dim=-1)
        out = self.decoder(
            decoder_input,
            return_logits=return_logits,
            gumbel_temperature=self.config.temperature,
        )
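        # The HF output container is reused only for API compatibility; this
        # decoder keeps no KV cache, so past_key_values is always None.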
        return CausalLMOutputWithPast(logits=out, past_key_values=None)

    def decode_to_token_ids(
        self,
        z: torch.Tensor,
        num_steps: int | None = None,
        condition: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Greedy token IDs from latent ``z`` (argmax over vocabulary per position)."""
        logits = self.forward_latent_positions(
            z, num_steps=num_steps, condition=condition, return_logits=True
        ).logits
        assert logits is not None
        return logits.argmax(dim=-1)
    def forward(self, input_ids: torch.Tensor, **_: object) -> HydrAMPOutput:
        """Encode ``input_ids`` and decode deterministically from the posterior mean (no latent sampling)."""
        mean, log_std = self.encoder.encode(input_ids)
        logits = self.forward_latent_positions(mean, return_logits=True).logits
        return HydrAMPOutput(logits=logits, mean=mean, log_std=log_std)
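

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the library API. It assumes
    # HydrAMPConfig() provides self-consistent defaults (in particular,
    # len(default_condition) == condition_dim); override keywords as needed.
    config = HydrAMPConfig()
    model = HydrAMPModel(config).eval()

    # Random token IDs standing in for tokenized peptide sequences.
    input_ids = torch.randint(0, config.vocab_size, (2, config.sequence_length))
    with torch.no_grad():
        outputs = model(input_ids)
        token_ids = model.decode_to_token_ids(outputs.mean)

    print(outputs.logits.shape)  # (2, sequence_length, vocab_size)
    print(token_ids.shape)       # (2, sequence_length)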