""" Vortex model implementation for HuggingFace. Integrates with transformers library. """ from typing import Optional, Tuple, List, Dict, Any import torch import torch.nn as nn from transformers import PreTrainedModel, PretrainedConfig, GenerationConfig from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions from configuration_vortex import VortexConfig from models.vortex_model import VortexModel class VortexPreTrainedModel(PreTrainedModel): """ Base class for Vortex models. Handles loading/saving in HF format. """ config_class = VortexConfig base_model_prefix = "vortex" supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"lm_head.weight"] def _init_weights(self, module): """Initialize weights.""" if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) def get_input_embeddings(self): return self.vortex.embed_tokens def set_input_embeddings(self, value): self.vortex.embed_tokens = value def get_output_embeddings(self): return self.vortex.lm_head def set_output_embeddings(self, new_embeddings): self.vortex.lm_head = new_embeddings class VortexForCausalLM(VortexPreTrainedModel): """ Vortex model for causal language modeling. """ _tied_weights_keys = ["vortex.lm_head.weight"] def __init__(self, config: VortexConfig): super().__init__(config) self.config = config # Build core model self.vortex = VortexModel(config.to_dict()) # Initialize weights self.apply(self._init_weights) # Tie weights if configured if self.config.tie_word_embeddings: self.tie_weights() def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, domain_ids: Optional[torch.LongTensor] = None, domain_tags: Optional[torch.Tensor] = None, text: Optional[List[str]] = None, ) -> CausalLMOutputWithCrossAttentions: """ Forward pass. Args: input_ids: Token IDs (batch, seq_len) attention_mask: Attention mask (batch, seq_len) labels: Labels for LM loss (batch, seq_len) domain_ids: Domain IDs (batch,) domain_tags: Domain tag masks (batch, seq_len, num_domains) text: Original text strings (for science modules) """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Pass through Vortex model outputs = self.vortex( input_ids=input_ids, attention_mask=attention_mask, domain_ids=domain_ids, domain_tags=domain_tags, text=text, return_dict=True, ) logits = outputs["logits"] last_hidden_state = outputs["last_hidden_state"] loss = None if labels is not None: # Compute cross-entropy loss shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss_fct = nn.CrossEntropyLoss(ignore_index=-100) loss = loss_fct( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ) if not return_dict: output = (logits,) + (last_hidden_state,) return (loss,) + output if loss is not None else output return CausalLMOutputWithCrossAttentions( loss=loss, logits=logits, hidden_states=last_hidden_state, attentions=None, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, **kwargs, ): """Prepare inputs for text generation.""" # Omit tokens that are already past if past_key_values: input_ids = input_ids[:, -1:] return { "input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache", True), } def generate( self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, **kwargs, ): """Generate text.""" from transformers import GenerationConfig generation_config = kwargs.pop("generation_config", None) if generation_config is None: generation_config = GenerationConfig.from_model_config(self.config) return super().generate( input_ids=input_ids, inputs_embeds=inputs_embeds, generation_config=generation_config, **kwargs, ) # Register model for AutoModel from transformers import AutoConfig, AutoModelForCausalLM AutoConfig.register("vortex", VortexConfig) AutoModelForCausalLM.register(VortexConfig, VortexForCausalLM) def test_hf_integration(): """Test HuggingFace integration.""" from transformers import AutoConfig, AutoModelForCausalLM # Create config config = VortexConfig( d_model=512, num_layers=2, num_heads=8, vocab_size=1000, ) # Create model model = VortexForCausalLM(config) print(f"Model parameters: {model.get_num_parameters():,}") # Test forward batch_size = 2 seq_len = 32 input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len)) labels = torch.randint(0, config.vocab_size, (batch_size, seq_len)) outputs = model(input_ids=input_ids, labels=labels) print(f"Loss: {outputs.loss.item():.4f}") print(f"Logits shape: {outputs.logits.shape}") # Test save/load model.save_pretrained("./test_vortex_model") config.save_pretrained("./test_vortex_model") loaded_config = AutoConfig.from_pretrained("./test_vortex_model") loaded_model = AutoModelForCausalLM.from_pretrained("./test_vortex_model") print(f"Loaded model type: {type(loaded_model)}") print("HF integration test passed!") if __name__ == "__main__": test_hf_integration()