"""
Vortex model implementation for HuggingFace.

Integrates with the transformers library.
"""
from typing import Optional, Tuple, List, Dict, Any

import torch
import torch.nn as nn
from transformers import (
    GenerationConfig,
    GenerationMixin,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

from configuration_vortex import VortexConfig
from models.vortex_model import VortexModel
class VortexPreTrainedModel(PreTrainedModel):
    """
    Base class for Vortex models.

    Declares the HuggingFace loading/saving hooks (config class, base-model
    prefix, checkpoint keys that may be missing), the weight-initialization
    scheme, and the embedding accessors.

    The accessors read and write ``self.vortex.embed_tokens`` and
    ``self.vortex.lm_head``, so a concrete subclass must assign the core
    network to ``self.vortex`` before they are used.
    """

    config_class = VortexConfig
    base_model_prefix = "vortex"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def _init_weights(self, module):
        """Initialize the parameters of a single submodule in place.

        * ``nn.Linear``: weights drawn from N(0, ``config.initializer_range``),
          bias (if any) zeroed.
        * ``nn.Embedding``: weights drawn from the same normal distribution,
          with the padding row (if any) zeroed.
        * ``nn.LayerNorm``: weight set to 1 and bias set to 0.

        Any other module type is left untouched.

        Args:
            module: The submodule to initialize.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm built with elementwise_affine=False (or bias=False)
            # has None for these parameters; skip instead of crashing.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

    def get_input_embeddings(self):
        """Return the token-embedding module of the core model."""
        return self.vortex.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token-embedding module of the core model."""
        self.vortex.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head of the core model."""
        return self.vortex.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head of the core model."""
        self.vortex.lm_head = new_embeddings
class VortexForCausalLM(VortexPreTrainedModel, GenerationMixin):
    """
    Vortex model for causal language modeling.

    Wraps ``VortexModel`` (stored as ``self.vortex``) and adds the
    next-token cross-entropy loss plus the hooks used for text generation.

    ``GenerationMixin`` is listed explicitly as a base class so that
    ``generate`` is available regardless of whether ``PreTrainedModel``
    itself provides it.
    """

    _tied_weights_keys = ["vortex.lm_head.weight"]

    def __init__(self, config: VortexConfig):
        """Build the core model, initialize its weights and tie embeddings.

        Args:
            config: Model configuration. It is handed to ``VortexModel`` as a
                plain dict (``config.to_dict()``).
        """
        super().__init__(config)
        self.config = config
        # Build core model
        self.vortex = VortexModel(config.to_dict())
        # Initialize weights
        self.apply(self._init_weights)
        # Tie weights if configured
        if self.config.tie_word_embeddings:
            self.tie_weights()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        domain_ids: Optional[torch.LongTensor] = None,
        domain_tags: Optional[torch.Tensor] = None,
        text: Optional[List[str]] = None,
    ) -> CausalLMOutputWithCrossAttentions:
        """
        Forward pass.

        Only ``input_ids``, ``attention_mask``, ``domain_ids``,
        ``domain_tags`` and ``text`` are passed on to the core model.
        ``position_ids``, ``past_key_values``, ``inputs_embeds``,
        ``use_cache``, ``output_attentions`` and ``output_hidden_states`` are
        accepted for signature compatibility but are not used here.

        Args:
            input_ids: Token IDs (batch, seq_len)
            attention_mask: Attention mask (batch, seq_len)
            labels: Labels for LM loss (batch, seq_len); positions set to
                -100 are ignored by the loss.
            return_dict: Return an output object instead of a tuple. Falls
                back to ``config.use_return_dict`` when ``None``.
            domain_ids: Domain IDs (batch,)
            domain_tags: Domain tag masks (batch, seq_len, num_domains)
            text: Original text strings (for science modules)

        Returns:
            With ``return_dict``: a ``CausalLMOutputWithCrossAttentions``
            whose ``hidden_states`` field holds the core model's
            ``last_hidden_state`` and whose ``attentions`` is ``None``.
            Otherwise a tuple ``(logits, last_hidden_state)``, prefixed with
            the loss when ``labels`` is given.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # The core model is always asked for a dict so the keys below exist.
        outputs = self.vortex(
            input_ids=input_ids,
            attention_mask=attention_mask,
            domain_ids=domain_ids,
            domain_tags=domain_tags,
            text=text,
            return_dict=True,
        )
        logits = outputs["logits"]
        last_hidden_state = outputs["last_hidden_state"]

        loss = None
        if labels is not None:
            # Position t predicts token t + 1: drop the last logit and the
            # first label so the two line up.
            shift_logits = logits[..., :-1, :].contiguous()
            # Labels may live on a different device than the logits (for
            # example when the model is sharded); align before the loss.
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if not return_dict:
            output = (logits, last_hidden_state)
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            hidden_states=last_hidden_state,
            attentions=None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        """Prepare inputs for one text-generation step.

        The full ``input_ids`` sequence is always passed on. ``forward``
        neither hands ``past_key_values`` to the core model nor returns a
        cache, so trimming the input to its last token would make the model
        predict from a single token with no context.

        Args:
            input_ids: Token IDs generated so far (batch, seq_len).
            past_key_values: Passed through unchanged; unused by ``forward``.
            attention_mask: Attention mask matching ``input_ids``.
            **kwargs: Only ``use_cache`` is read (default ``True``).

        Returns:
            Dict of keyword arguments for ``forward``.
        """
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", True),
        }

    def generate(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs,
    ):
        """Generate text.

        Resolves the generation config in this order: an explicit
        ``generation_config`` keyword argument, then the model's own
        ``generation_config`` attribute if it is set, and finally one derived
        from the model config.

        Args:
            input_ids: Prompt token IDs.
            inputs_embeds: Forwarded to the parent ``generate``; note that
                ``forward`` in this class does not use ``inputs_embeds``.
            **kwargs: Forwarded to the parent ``generate``.

        Returns:
            Whatever the parent ``generate`` returns.
        """
        generation_config = kwargs.pop("generation_config", None)
        if generation_config is None:
            # Prefer the config attached to the model over rebuilding
            # defaults from the model config on every call.
            generation_config = getattr(self, "generation_config", None)
        if generation_config is None:
            generation_config = GenerationConfig.from_model_config(self.config)

        return super().generate(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            generation_config=generation_config,
            **kwargs,
        )
# Register the Vortex classes with the transformers Auto* factories.
# These statements are at module top level, so they run whenever this module
# is imported.
from transformers import AutoConfig, AutoModelForCausalLM
# Associate the "vortex" model-type string with VortexConfig.
AutoConfig.register("vortex", VortexConfig)
# Associate VortexConfig with VortexForCausalLM for AutoModelForCausalLM.
AutoModelForCausalLM.register(VortexConfig, VortexForCausalLM)
def test_hf_integration():
    """Smoke-test the HuggingFace integration end to end.

    Builds a small ``VortexForCausalLM``, runs one forward pass with labels,
    then saves the model and config and reloads them through ``AutoConfig``
    and ``AutoModelForCausalLM``. Progress is reported with ``print``.

    The checkpoint is written to a temporary directory that is removed when
    the test finishes, so nothing is left behind in the working directory.
    """
    import tempfile

    from transformers import AutoConfig, AutoModelForCausalLM

    # Create config
    config = VortexConfig(
        d_model=512,
        num_layers=2,
        num_heads=8,
        vocab_size=1000,
    )

    # Create model
    model = VortexForCausalLM(config)
    # The transformers API for the parameter count is num_parameters();
    # get_num_parameters() is not defined on the model classes in this file.
    print(f"Model parameters: {model.num_parameters():,}")

    # Test forward
    batch_size = 2
    seq_len = 32
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    labels = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    outputs = model(input_ids=input_ids, labels=labels)
    print(f"Loss: {outputs.loss.item():.4f}")
    print(f"Logits shape: {outputs.logits.shape}")

    # Test save/load
    with tempfile.TemporaryDirectory() as save_dir:
        model.save_pretrained(save_dir)
        config.save_pretrained(save_dir)
        # Loading the config is part of the round trip even though the
        # object itself is not inspected further.
        loaded_config = AutoConfig.from_pretrained(save_dir)
        loaded_model = AutoModelForCausalLM.from_pretrained(save_dir)

    print(f"Loaded model type: {type(loaded_model)}")
    print("HF integration test passed!")
# Run the integration smoke test when this file is executed as a script.
if __name__ == "__main__":
    test_hf_integration()
|