# Vortex-13b-V1 / modeling_vortex.py
# Uploaded by Zandy-Wandy — commit 5c43f61 "Upload Vortex model" (verified)
"""
Vortex model implementation for HuggingFace.
Integrates with transformers library.
"""
from typing import Optional, Tuple, List, Dict, Any

import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig, GenerationConfig
from transformers import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

from configuration_vortex import VortexConfig
from models.vortex_model import VortexModel
class VortexPreTrainedModel(PreTrainedModel):
    """
    Base class for Vortex models.

    Handles loading/saving in HF format and exposes the embedding / LM-head
    accessors that ``PreTrainedModel`` uses for weight tying and resizing.
    Subclasses are expected to store the core model as ``self.vortex``.
    """

    config_class = VortexConfig
    base_model_prefix = "vortex"
    supports_gradient_checkpointing = True
    # Regex patterns: a missing LM head is tolerated on load (e.g. when it is
    # tied to the input embeddings and therefore not serialized separately).
    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def _init_weights(self, module):
        """Initialize weights of ``module`` in place.

        Linear and embedding weights are drawn from N(0, initializer_range);
        biases and the padding embedding row are zeroed; LayerNorm is reset
        to the identity transform (weight 1, bias 0).
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm has no affine parameters when elementwise_affine=False,
            # and no bias when constructed with bias=False (torch >= 2.1).
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

    def get_input_embeddings(self):
        """Return the token embedding module of the core model."""
        return self.vortex.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the core model."""
        self.vortex.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head (output projection) of the core model."""
        return self.vortex.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head (output projection) of the core model."""
        self.vortex.lm_head = new_embeddings
class VortexForCausalLM(VortexPreTrainedModel, GenerationMixin):
    """
    Vortex model for causal language modeling.

    Wraps the core ``VortexModel`` (stored as ``self.vortex``) and adds the
    shifted cross-entropy LM loss plus the hooks used by ``generate``.

    ``GenerationMixin`` is listed explicitly because ``transformers`` v4.50+
    no longer mixes it into ``PreTrainedModel``; without it ``generate`` (and
    the ``super().generate`` call below) would not exist.
    """

    _tied_weights_keys = ["vortex.lm_head.weight"]

    def __init__(self, config: VortexConfig):
        super().__init__(config)
        self.config = config
        # Build core model (VortexModel is constructed from a plain dict).
        self.vortex = VortexModel(config.to_dict())
        # Initialize weights
        self.apply(self._init_weights)
        # Tie weights if configured
        if self.config.tie_word_embeddings:
            self.tie_weights()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        domain_ids: Optional[torch.LongTensor] = None,
        domain_tags: Optional[torch.Tensor] = None,
        text: Optional[List[str]] = None,
    ) -> CausalLMOutputWithCrossAttentions:
        """
        Forward pass.

        Only ``input_ids``, ``attention_mask``, ``labels``, ``return_dict`` and
        the Vortex-specific ``domain_ids`` / ``domain_tags`` / ``text`` are
        used. ``position_ids``, ``past_key_values``, ``inputs_embeds``,
        ``use_cache``, ``output_attentions`` and ``output_hidden_states`` are
        accepted for HuggingFace API compatibility but ignored: no KV cache is
        consumed or returned.

        Args:
            input_ids: Token IDs (batch, seq_len)
            attention_mask: Attention mask (batch, seq_len)
            labels: Labels for LM loss (batch, seq_len). They are shifted
                internally (logits at position t are scored against
                labels[t + 1]); positions equal to -100 are ignored.
            domain_ids: Domain IDs (batch,)
            domain_tags: Domain tag masks (batch, seq_len, num_domains)
            text: Original text strings (for science modules)

        Returns:
            ``CausalLMOutputWithCrossAttentions`` when ``return_dict`` resolves
            to true, otherwise the tuple ``([loss,] logits, last_hidden_state)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Pass through Vortex model
        outputs = self.vortex(
            input_ids=input_ids,
            attention_mask=attention_mask,
            domain_ids=domain_ids,
            domain_tags=domain_tags,
            text=text,
            return_dict=True,
        )
        logits = outputs["logits"]
        last_hidden_state = outputs["last_hidden_state"]

        loss = None
        if labels is not None:
            # Next-token objective: drop the last logit and the first label.
            shift_logits = logits[..., :-1, :].contiguous()
            # Labels may live on a different device than the logits when the
            # model is split across devices; CrossEntropyLoss needs them together.
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if not return_dict:
            output = (logits,) + (last_hidden_state,)
            return (loss,) + output if loss is not None else output

        # NOTE(review): HF convention expects ``hidden_states`` to be a tuple of
        # per-layer tensors; here it is the single final hidden-state tensor.
        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            hidden_states=last_hidden_state,
            attentions=None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        """Prepare inputs for one decoding step of ``generate``.

        ``forward`` never passes ``past_key_values`` to the core model and
        never returns a cache, so every step must re-encode the whole
        sequence. ``input_ids`` is therefore never truncated to its last token:
        doing so would discard all prior context and leave ``attention_mask``
        (which covers the full sequence) mismatched with ``input_ids``.
        """
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", True),
        }

    def generate(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs,
    ):
        """Generate text.

        Thin wrapper over ``GenerationMixin.generate``.

        Args:
            input_ids: Prompt token IDs.
            inputs_embeds: Forwarded only when given. Passing ``None``
                explicitly would make ``transformers`` treat it as an
                embeddings request and reject it.
            **kwargs: Any other ``generate`` argument, including
                ``generation_config``.

        If ``generation_config`` is not supplied, the model's own
        ``self.generation_config`` is used (so a saved
        ``generation_config.json`` is respected). The config is derived from
        the model config only as a fallback when none exists.
        """
        generation_config = kwargs.pop("generation_config", None)
        if generation_config is None and getattr(self, "generation_config", None) is None:
            generation_config = GenerationConfig.from_model_config(self.config)
        if generation_config is not None:
            kwargs["generation_config"] = generation_config
        if input_ids is not None:
            kwargs["input_ids"] = input_ids
        if inputs_embeds is not None:
            kwargs["inputs_embeds"] = inputs_embeds
        return super().generate(**kwargs)
# Make the Vortex config/model pair discoverable through the Auto* factories,
# so AutoConfig / AutoModelForCausalLM can resolve the "vortex" model type.
from transformers import AutoConfig, AutoModelForCausalLM

AutoConfig.register(model_type="vortex", config=VortexConfig)
AutoModelForCausalLM.register(
    config_class=VortexConfig,
    model_class=VortexForCausalLM,
)
def test_hf_integration():
    """Smoke-test the HuggingFace integration of Vortex.

    Builds a tiny config and model, runs a forward pass with labels, then
    round-trips the model through ``save_pretrained`` and
    ``AutoConfig`` / ``AutoModelForCausalLM.from_pretrained``. The checkpoint
    is written to a temporary directory that is removed afterwards.
    """
    import tempfile

    from transformers import AutoConfig, AutoModelForCausalLM

    # Create config
    config = VortexConfig(
        d_model=512,
        num_layers=2,
        num_heads=8,
        vocab_size=1000,
    )
    # Create model
    model = VortexForCausalLM(config)
    # PreTrainedModel provides ``num_parameters()``; ``get_num_parameters``
    # does not exist and would raise AttributeError.
    print(f"Model parameters: {model.num_parameters():,}")

    # Test forward
    batch_size = 2
    seq_len = 32
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    labels = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    outputs = model(input_ids=input_ids, labels=labels)
    print(f"Loss: {outputs.loss.item():.4f}")
    print(f"Logits shape: {outputs.logits.shape}")

    # Test save/load in a throwaway directory so no artifacts are left behind.
    with tempfile.TemporaryDirectory() as save_dir:
        # save_pretrained also writes config.json, so no separate config save.
        model.save_pretrained(save_dir)
        loaded_config = AutoConfig.from_pretrained(save_dir)
        loaded_model = AutoModelForCausalLM.from_pretrained(save_dir, config=loaded_config)
    print(f"Loaded model type: {type(loaded_model)}")
    print("HF integration test passed!")
if __name__ == "__main__":
    # Executed as a script (not imported): run the integration smoke test.
    test_hf_integration()