""" HuggingFace-compatible model definition for Query Auto-Completion. """ import torch from typing import Union, Tuple from transformers import PretrainedConfig, PreTrainedModel from .model import QueryCompletionModel as BaseQueryCompletionModel class QueryCompletionConfig(PretrainedConfig): """Configuration for Query Auto-Completion model.""" model_type = "query-completion" def __init__( self, vocab_size: int = 384, embed_dim: int = 256, num_filters: int = 64, filter_sizes: list = None, num_heads: int = 4, num_transformer_layers: int = 2, use_pretrained_embeddings: bool = True, pretrained_model_name: str = "google/byt5-small", dropout: float = 0.1, **kwargs, ): super().__init__(**kwargs) self.vocab_size = vocab_size self.embed_dim = embed_dim self.num_filters = num_filters self.filter_sizes = filter_sizes or [3, 4, 5] self.num_heads = num_heads self.num_transformer_layers = num_transformer_layers self.use_pretrained_embeddings = use_pretrained_embeddings self.pretrained_model_name = pretrained_model_name self.dropout = dropout class QueryCompletionModelForHub(PreTrainedModel): """HuggingFace wrapper around the Query Auto-Completion model.""" config_class = QueryCompletionConfig base_model_prefix = "query_completion" supports_gradient_checkpointing = False def __init__(self, config: QueryCompletionConfig): super().__init__(config) self.model = BaseQueryCompletionModel( vocab_size=config.vocab_size, embed_dim=config.embed_dim, num_filters=config.num_filters, num_heads=config.num_heads, num_transformer_layers=config.num_transformer_layers, use_pretrained_embeddings=config.use_pretrained_embeddings, pretrained_model_name=config.pretrained_model_name, ) self.post_init() def forward( self, prefix_ids: torch.Tensor, candidate_ids: torch.Tensor, return_dict: bool = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: return self.model(prefix_ids, candidate_ids)