| | """
|
| | Vortex model implementation for HuggingFace.
|
| | Integrates with transformers library.
|
| | """
|
| |
|
| | from typing import Optional, Tuple, List, Dict, Any
|
| | import torch
|
| | import torch.nn as nn
|
| | from transformers import PreTrainedModel, PretrainedConfig, GenerationConfig
|
| | from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
| |
|
| | from configuration_vortex import VortexConfig
|
| | from models.vortex_model import VortexModel
|
| |
|
| |
|
class VortexPreTrainedModel(PreTrainedModel):
    """Base class for Vortex models.

    Supplies the shared HuggingFace plumbing — config class, weight
    initialization, and input/output embedding accessors — so that
    subclasses load and save in standard HF format.
    """

    config_class = VortexConfig
    base_model_prefix = "vortex"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def _init_weights(self, module):
        """Initialize a single submodule's weights in place.

        Linear/Embedding weights get a normal init with the configured
        std; LayerNorm is reset to identity (weight=1, bias=0).
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            # Keep the padding row at exactly zero after the normal init.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped core model."""
        return self.vortex.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped core model."""
        self.vortex.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head of the wrapped core model."""
        return self.vortex.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head of the wrapped core model."""
        self.vortex.lm_head = new_embeddings
|
class VortexForCausalLM(VortexPreTrainedModel):
    """Vortex model with a causal language-modeling head.

    Wraps the core :class:`VortexModel` and adds the standard HF
    CausalLM surface: shifted cross-entropy loss from ``labels``,
    ``generate()`` support, and optional input/output weight tying.
    """

    _tied_weights_keys = ["vortex.lm_head.weight"]

    def __init__(self, config: VortexConfig):
        super().__init__(config)
        self.config = config

        # The core model takes a plain dict, not a PretrainedConfig.
        self.vortex = VortexModel(config.to_dict())

        self.apply(self._init_weights)

        if self.config.tie_word_embeddings:
            self.tie_weights()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        domain_ids: Optional[torch.LongTensor] = None,
        domain_tags: Optional[torch.Tensor] = None,
        text: Optional[List[str]] = None,
    ) -> CausalLMOutputWithCrossAttentions:
        """Run the core model and optionally compute the LM loss.

        Args:
            input_ids: Token IDs (batch, seq_len).
            attention_mask: Attention mask (batch, seq_len).
            labels: Labels for the LM loss (batch, seq_len); positions
                with value -100 are ignored.
            domain_ids: Domain IDs (batch,).
            domain_tags: Domain tag masks (batch, seq_len, num_domains).
            text: Original text strings (for science modules).

        Returns:
            ``CausalLMOutputWithCrossAttentions`` with ``loss`` (if labels
            were given), ``logits``, and the core model's last hidden
            state in ``hidden_states``; or a plain tuple when
            ``return_dict`` is falsy.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # NOTE(review): position_ids, past_key_values, inputs_embeds,
        # use_cache, output_attentions and output_hidden_states are
        # accepted for HF API compatibility but are NOT forwarded to the
        # core model — confirm whether VortexModel supports them.
        outputs = self.vortex(
            input_ids=input_ids,
            attention_mask=attention_mask,
            domain_ids=domain_ids,
            domain_tags=domain_tags,
            text=text,
            return_dict=True,
        )

        logits = outputs["logits"]
        last_hidden_state = outputs["last_hidden_state"]

        loss = None
        if labels is not None:
            # Standard causal-LM shift: predict token t+1 from position t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if not return_dict:
            output = (logits, last_hidden_state)
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            hidden_states=last_hidden_state,
            attentions=None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        """Prepare inputs for one decoding step during text generation."""
        # With a populated cache, only the newest token needs processing.
        if past_key_values:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", True),
        }

    def generate(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs,
    ):
        """Generate text, defaulting the generation config from the model config.

        Delegates to ``PreTrainedModel.generate`` after ensuring a
        ``GenerationConfig`` is present.
        """
        # GenerationConfig is imported at module level; the previous
        # function-local re-import was redundant.
        generation_config = kwargs.pop("generation_config", None)
        if generation_config is None:
            generation_config = GenerationConfig.from_model_config(self.config)

        return super().generate(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            generation_config=generation_config,
            **kwargs,
        )
|
| |
|
# Register Vortex with the Auto* factories so that
# AutoConfig/AutoModelForCausalLM.from_pretrained(...) resolve to the
# Vortex classes for model_type "vortex".
from transformers import AutoConfig, AutoModelForCausalLM

AutoConfig.register("vortex", VortexConfig)
AutoModelForCausalLM.register(VortexConfig, VortexForCausalLM)
|
| |
|
def test_hf_integration():
    """Smoke-test the HuggingFace integration: forward pass, save, reload."""
    from transformers import AutoConfig, AutoModelForCausalLM

    config = VortexConfig(
        d_model=512,
        num_layers=2,
        num_heads=8,
        vocab_size=1000,
    )

    model = VortexForCausalLM(config)
    # BUGFIX: PreTrainedModel exposes num_parameters(), not
    # get_num_parameters() — the old call raised AttributeError.
    print(f"Model parameters: {model.num_parameters():,}")

    # Random batch; labels drawn from the vocab range so the CE loss is finite.
    batch_size = 2
    seq_len = 32
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    labels = torch.randint(0, config.vocab_size, (batch_size, seq_len))

    outputs = model(input_ids=input_ids, labels=labels)
    print(f"Loss: {outputs.loss.item():.4f}")
    print(f"Logits shape: {outputs.logits.shape}")

    # Round-trip through save_pretrained / Auto* from_pretrained.
    model.save_pretrained("./test_vortex_model")
    config.save_pretrained("./test_vortex_model")

    loaded_config = AutoConfig.from_pretrained("./test_vortex_model")
    loaded_model = AutoModelForCausalLM.from_pretrained("./test_vortex_model")
    print(f"Loaded config type: {type(loaded_config)}")
    print(f"Loaded model type: {type(loaded_model)}")

    print("HF integration test passed!")
|
| |
|
# Script entry point: run the integration smoke test directly.
if __name__ == "__main__":
    test_hf_integration()
|
| |
|