| """HuggingFace-compatible model classes for SwipeTransformer.""" |
|
|
| from dataclasses import dataclass |
| from typing import Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import ( |
| BaseModelOutput, |
| BaseModelOutputWithPooling, |
| ModelOutput, |
| SequenceClassifierOutput, |
| ) |
|
|
| from .configuration_swipe import SwipeCrossEncoderConfig, SwipeTransformerConfig |
|
|
|
|
@dataclass
class SwipeTransformerOutput(ModelOutput):
    """
    Output type for SwipeTransformerModel.

    Field order matters: ``ModelOutput.to_tuple()`` yields the non-None
    fields in declaration order, so do not reorder these.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (character prediction).
        char_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, vocab_size)`):
            Prediction scores of the character prediction head.
        path_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 3)`, *optional*):
            Prediction scores of the path prediction head (if enabled).
        length_logits (`torch.FloatTensor` of shape `(batch_size, max_length+1)`, *optional*):
            Prediction scores of the length prediction head (if enabled).
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            SEP token embeddings for similarity/embedding tasks.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of `torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`.
    """

    loss: Optional[torch.FloatTensor] = None
    # Optional[...] because instances may be built without logits populated.
    char_logits: Optional[torch.FloatTensor] = None
    path_logits: Optional[torch.FloatTensor] = None
    length_logits: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
|
|
|
class SwipeTransformerPreTrainedModel(PreTrainedModel):
    """
    Abstract base class wiring SwipeTransformer models into the HuggingFace
    ecosystem: weight initialization plus the standard interface for
    downloading and loading pretrained checkpoints.
    """

    config_class = SwipeTransformerConfig
    base_model_prefix = "swipe_transformer"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """Initialize the weights of a single submodule."""
        if isinstance(module, nn.LayerNorm):
            # Identity-like start: unit scale, zero shift.
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.padding_idx is not None:
                # Padding rows must stay zero so they contribute nothing.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
|
|
|
|
class SwipeTransformerModel(SwipeTransformerPreTrainedModel):
    """
    HuggingFace-compatible SwipeTransformerModel.

    Reuses the existing components from src/swipealot/models/ behind a
    HuggingFace-style interface. The encoder consumes a sequence laid out
    as [CLS] + path points + [SEP] + character tokens and feeds the result
    to character / path / length prediction heads.

    Args:
        config (SwipeTransformerConfig): Model configuration
    """

    def __init__(self, config: SwipeTransformerConfig):
        super().__init__(config)
        self.config = config

        # Imported lazily inside __init__ (matches the original module's
        # convention; presumably avoids circular imports — keep local).
        from .embeddings import MixedEmbedding
        from .heads import CharacterPredictionHead, LengthPredictionHead, PathPredictionHead

        # Joint embedding of swipe-path coordinates and character tokens.
        self.embeddings = MixedEmbedding(
            vocab_size=config.vocab_size,
            max_path_len=config.max_path_len,
            max_char_len=config.max_char_len,
            d_model=config.d_model,
            dropout=config.dropout,
        )

        # Pre-norm (norm_first=True) Transformer encoder stack.
        layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.n_heads,
            dim_feedforward=config.d_ff,
            dropout=config.dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            layer,
            num_layers=config.n_layers,
            enable_nested_tensor=False,
        )

        # Per-position character prediction head.
        self.char_head = CharacterPredictionHead(
            d_model=config.d_model,
            vocab_size=config.vocab_size,
        )

        # Path reconstruction head is optional.
        self.path_head = PathPredictionHead(d_model=config.d_model) if config.predict_path else None

        # Word-length prediction head, fed from the CLS position.
        self.length_head = LengthPredictionHead(
            d_model=config.d_model,
            max_length=config.max_char_len,
        )

        self.post_init()

    def forward(
        self,
        path_coords: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        return_dict: bool | None = None,
        output_hidden_states: bool | None = None,
    ):
        """
        Run the encoder and all prediction heads.

        Args:
            path_coords (torch.Tensor): Path coordinates [batch, path_len, 3]
            input_ids (torch.Tensor): Character token IDs [batch, char_len]
            attention_mask (torch.Tensor, optional): Mask [batch, seq_len];
                positions holding 0 are treated as padding.
            labels (torch.Tensor, optional): Character labels [batch, char_len]
                for the cross-entropy loss (pad tokens are ignored).
            return_dict (bool, optional): Whether to return a ModelOutput object
            output_hidden_states (bool, optional): Whether to also return
                hidden states in the output object

        Returns:
            SwipeTransformerOutput or tuple with the optional loss, character
            logits, optional path logits, length logits, the final hidden
            states, and the SEP-token pooler output for similarity tasks.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        n_batch = path_coords.shape[0]
        dev = path_coords.device

        # Special-token columns; the embedding module arranges the full
        # sequence as [CLS] + path + [SEP] + chars.
        cls_col = torch.full(
            (n_batch, 1), fill_value=self.config.cls_token_id, dtype=torch.long, device=dev
        )
        sep_col = torch.full(
            (n_batch, 1), fill_value=self.config.sep_token_id, dtype=torch.long, device=dev
        )

        embedded = self.embeddings(path_coords, input_ids, cls_col, sep_col)

        # nn.TransformerEncoder expects True at the positions to ignore.
        padding_mask = None if attention_mask is None else attention_mask == 0

        encoded = self.encoder(embedded, src_key_padding_mask=padding_mask)

        char_logits = self.char_head(encoded)
        path_logits = self.path_head(encoded) if self.path_head is not None else None

        # Length is predicted from the CLS position (index 0).
        length_logits = self.length_head(encoded[:, 0, :])

        # SEP sits right after the path segment; its hidden state doubles as
        # the pooled representation for similarity/embedding tasks.
        sep_position = 1 + path_coords.shape[1]
        pooler_output = encoded[:, sep_position, :]

        loss = None
        if labels is not None:
            # Character positions begin immediately after the SEP token.
            char_start = sep_position + 1
            char_hidden = encoded[:, char_start : char_start + labels.shape[1], :]
            char_pred = self.char_head(char_hidden)
            criterion = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
            loss = criterion(char_pred.reshape(-1, self.config.vocab_size), labels.reshape(-1))

        if not return_dict:
            output = (encoded, char_logits, length_logits, pooler_output)
            if path_logits is not None:
                output = output + (path_logits,)
            return output if loss is None else (loss,) + output

        return SwipeTransformerOutput(
            loss=loss,
            char_logits=char_logits,
            path_logits=path_logits,
            length_logits=length_logits,
            last_hidden_state=encoded,
            pooler_output=pooler_output,
            hidden_states=(encoded,) if output_hidden_states else None,
        )
|
|
|
|
class SwipeCrossEncoderForSequenceClassification(SwipeTransformerPreTrainedModel):
    """
    HuggingFace-compatible cross-encoder for sequence classification.

    This model is designed for similarity scoring between swipe paths and words.
    It extracts the SEP token embedding and passes it through a classification head.

    All three standard HuggingFace problem types are supported
    ("regression", "single_label_classification",
    "multi_label_classification"); when ``config.problem_type`` is unset it
    is inferred from ``num_labels`` and the label dtype, following the
    transformers convention.

    Args:
        config (SwipeCrossEncoderConfig): Model configuration
    """

    config_class = SwipeCrossEncoderConfig
    base_model_prefix = "swipe_cross_encoder"

    def __init__(self, config: SwipeCrossEncoderConfig):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels

        # Imported lazily inside __init__ (matches the module's convention).
        from .embeddings import MixedEmbedding
        from .heads import ClassificationHead

        # Joint embedding of swipe-path coordinates and character tokens.
        self.embeddings = MixedEmbedding(
            vocab_size=config.vocab_size,
            max_path_len=config.max_path_len,
            max_char_len=config.max_char_len,
            d_model=config.d_model,
            dropout=config.dropout,
        )

        # Pre-norm (norm_first=True) Transformer encoder stack.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.n_heads,
            dim_feedforward=config.d_ff,
            dropout=config.dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=config.n_layers,
            enable_nested_tensor=False,
        )

        # Classification head applied to the SEP-token embedding.
        self.classifier = ClassificationHead(
            d_model=config.d_model,
            num_labels=config.num_labels,
        )

        self.post_init()

    def forward(
        self,
        path_coords: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        return_dict: bool | None = None,
    ):
        """
        Forward pass for cross-encoder.

        Args:
            path_coords (torch.Tensor): Path coordinates [batch, path_len, 3]
            input_ids (torch.Tensor): Character token IDs [batch, char_len]
            attention_mask (torch.Tensor, optional): Attention mask [batch, seq_len];
                positions holding 0 are treated as padding.
            labels (torch.Tensor, optional): Labels for loss calculation.
                Integer class indices for single-label classification,
                floats for regression / multi-label classification.
            return_dict (bool, optional): Whether to return ModelOutput object

        Returns:
            SequenceClassifierOutput or tuple: Model outputs with logits and optional loss
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size = path_coords.shape[0]
        device = path_coords.device

        # Special-token columns; the embedding module arranges the full
        # sequence as [CLS] + path + [SEP] + chars.
        cls_token = torch.full(
            (batch_size, 1), fill_value=self.config.cls_token_id, dtype=torch.long, device=device
        )
        sep_token = torch.full(
            (batch_size, 1), fill_value=self.config.sep_token_id, dtype=torch.long, device=device
        )

        embeddings = self.embeddings(path_coords, input_ids, cls_token, sep_token)

        # nn.TransformerEncoder expects True at the positions to ignore.
        if attention_mask is not None:
            src_key_padding_mask = attention_mask == 0
        else:
            src_key_padding_mask = None

        hidden_states = self.encoder(embeddings, src_key_padding_mask=src_key_padding_mask)

        # SEP sits right after the path segment; its hidden state is the
        # joint path+text representation fed to the classifier.
        path_len = path_coords.shape[1]
        sep_position = 1 + path_len
        sep_embedding = hidden_states[:, sep_position, :]

        logits = self.classifier(sep_embedding)

        loss = None
        if labels is not None:
            # Infer problem_type once, following the transformers convention:
            # integer labels with num_labels > 1 -> single-label classification,
            # float labels with num_labels > 1 -> multi-label classification.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                # Previously this case silently produced loss=None.
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels.float())

        if not return_dict:
            output = (logits,) + (hidden_states,)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=(hidden_states,),
        )
|
|
|
|
class SwipeModel(SwipeTransformerPreTrainedModel):
    """
    Base Swipe model that only produces embeddings.

    .. deprecated::
        Use SwipeTransformerModel instead: it now exposes the same
        pooler_output for embeddings alongside its prediction heads, so a
        single model provides both predictions AND embeddings.

    The SEP token embedding returned here is a joint encoding of the path
    and text, usable for:
    - Vector databases
    - Semantic search
    - Similarity computation

    Usage (deprecated — prefer SwipeTransformerModel):
        ```python
        from transformers import AutoModel

        model = AutoModel.from_pretrained(
            "your-username/swipe-model",
            trust_remote_code=True
        )

        # Get embeddings
        outputs = model(path_coords=paths, input_ids=tokens)
        embeddings = outputs.pooler_output  # SEP token embeddings
        ```

    Args:
        config (SwipeTransformerConfig or SwipeCrossEncoderConfig): Model configuration
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Imported lazily inside __init__ (matches the module's convention).
        from .embeddings import MixedEmbedding

        # Joint embedding of swipe-path coordinates and character tokens.
        self.embeddings = MixedEmbedding(
            vocab_size=config.vocab_size,
            max_path_len=config.max_path_len,
            max_char_len=config.max_char_len,
            d_model=config.d_model,
            dropout=config.dropout,
        )

        # Pre-norm (norm_first=True) Transformer encoder stack.
        enc_layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.n_heads,
            dim_feedforward=config.d_ff,
            dropout=config.dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            enc_layer,
            num_layers=config.n_layers,
            enable_nested_tensor=False,
        )

        self.post_init()

    def forward(
        self,
        path_coords: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        return_dict: bool | None = None,
        output_hidden_states: bool | None = None,
    ):
        """
        Encode the inputs and return embeddings.

        Args:
            path_coords (torch.Tensor): Path coordinates [batch, path_len, 3]
            input_ids (torch.Tensor): Character token IDs [batch, char_len]
            attention_mask (torch.Tensor, optional): Mask [batch, seq_len];
                positions holding 0 are treated as padding.
            return_dict (bool, optional): Whether to return a ModelOutput object
            output_hidden_states (bool, optional): Whether to also return
                hidden states in the output object

        Returns:
            BaseModelOutputWithPooling (or tuple) with the full sequence of
            hidden states [batch, seq_len, d_model] and the SEP-token
            pooler output [batch, d_model].
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        n_batch = path_coords.shape[0]
        dev = path_coords.device

        # Special-token columns; the embedding module arranges the full
        # sequence as [CLS] + path + [SEP] + chars.
        cls_col = torch.full(
            (n_batch, 1), fill_value=self.config.cls_token_id, dtype=torch.long, device=dev
        )
        sep_col = torch.full(
            (n_batch, 1), fill_value=self.config.sep_token_id, dtype=torch.long, device=dev
        )

        embedded = self.embeddings(path_coords, input_ids, cls_col, sep_col)

        # nn.TransformerEncoder expects True at the positions to ignore.
        padding_mask = None if attention_mask is None else attention_mask == 0

        encoded = self.encoder(embedded, src_key_padding_mask=padding_mask)

        # SEP sits right after the path segment; its hidden state is the
        # pooled joint representation of path and text.
        pooled = encoded[:, 1 + path_coords.shape[1], :]

        if not return_dict:
            return (encoded, pooled)

        return BaseModelOutputWithPooling(
            last_hidden_state=encoded,
            pooler_output=pooled,
            hidden_states=(encoded,) if output_hidden_states else None,
        )
|
|