"""
HuggingFace-compatible model definition for Query Auto-Completion.
"""
import torch
from typing import Union, Tuple
from transformers import PretrainedConfig, PreTrainedModel
from .model import QueryCompletionModel as BaseQueryCompletionModel
class QueryCompletionConfig(PretrainedConfig):
    """Configuration for the Query Auto-Completion model.

    Args:
        vocab_size: Size of the token (byte) vocabulary.
        embed_dim: Dimension of the token embeddings.
        num_filters: Number of convolutional filters per filter size.
        filter_sizes: Convolution kernel widths; defaults to ``[3, 4, 5]``
            when ``None`` is given.
        num_heads: Number of attention heads per transformer layer.
        num_transformer_layers: Number of transformer encoder layers.
        use_pretrained_embeddings: Whether to initialize embeddings from a
            pretrained model.
        pretrained_model_name: HF model id used when
            ``use_pretrained_embeddings`` is True.
        dropout: Dropout probability.
        **kwargs: Forwarded to :class:`PretrainedConfig`.
    """

    model_type = "query-completion"

    def __init__(
        self,
        vocab_size: int = 384,
        embed_dim: int = 256,
        num_filters: int = 64,
        # Fixed annotation: the default is None, so the type is Optional,
        # not plain `list`. A None default also avoids the shared-mutable-
        # default pitfall.
        filter_sizes: Union[list, None] = None,
        num_heads: int = 4,
        num_transformer_layers: int = 2,
        use_pretrained_embeddings: bool = True,
        pretrained_model_name: str = "google/byt5-small",
        dropout: float = 0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_filters = num_filters
        # Resolve the None sentinel to the documented default here.
        self.filter_sizes = filter_sizes or [3, 4, 5]
        self.num_heads = num_heads
        self.num_transformer_layers = num_transformer_layers
        self.use_pretrained_embeddings = use_pretrained_embeddings
        self.pretrained_model_name = pretrained_model_name
        self.dropout = dropout
class QueryCompletionModelForHub(PreTrainedModel):
    """HuggingFace wrapper around the Query Auto-Completion model.

    Adapts :class:`BaseQueryCompletionModel` to the ``PreTrainedModel``
    interface so it can be saved/loaded via the HF Hub.
    """

    config_class = QueryCompletionConfig
    base_model_prefix = "query_completion"
    supports_gradient_checkpointing = False

    def __init__(self, config: QueryCompletionConfig):
        super().__init__(config)
        # NOTE(review): config.filter_sizes and config.dropout are stored on
        # the config but NOT forwarded here — confirm whether
        # BaseQueryCompletionModel should receive them or intentionally uses
        # its own defaults.
        self.model = BaseQueryCompletionModel(
            vocab_size=config.vocab_size,
            embed_dim=config.embed_dim,
            num_filters=config.num_filters,
            num_heads=config.num_heads,
            num_transformer_layers=config.num_transformer_layers,
            use_pretrained_embeddings=config.use_pretrained_embeddings,
            pretrained_model_name=config.pretrained_model_name,
        )
        # HF hook: runs weight init and any final config-driven setup.
        self.post_init()

    def forward(
        self,
        prefix_ids: torch.Tensor,
        candidate_ids: torch.Tensor,
        # Fixed annotation: the default is None, so the type is Optional,
        # not plain `bool`.
        return_dict: Union[bool, None] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """Score candidate completions for the given prefixes.

        Args:
            prefix_ids: Token ids of the query prefixes.
            candidate_ids: Token ids of the candidate completions.
            return_dict: Accepted for HF API compatibility but currently
                ignored — the base model's raw output is returned unchanged.

        Returns:
            Whatever ``BaseQueryCompletionModel.__call__`` produces for
            ``(prefix_ids, candidate_ids)``.
        """
        return self.model(prefix_ids, candidate_ids)