---
license: apache-2.0
library_name: transformers
tags:
- query-auto-completion
- search
- pytorch
- custom-model
datasets:
- rexoscare/autocomplete-search-dataset
language:
- en
pipeline_tag: text-classification
base_model:
- google/byt5-small
---
# Query Auto-Completion Model
A CNN+Transformer model for ranking query auto-completion suggestions.
## Model Description
This model scores how well a candidate completion matches a given query prefix. It uses:
- **Prefix Encoder**: Multi-scale CNN + Transformer that extracts search intent from partial queries
- **Candidate Encoder**: Transformer for encoding candidate completions
- **Match Predictor**: MLP that scores the compatibility between prefix and candidate
The model uses pretrained ByT5 byte-level embeddings for robust character-level understanding.
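ByT5 operates directly on UTF-8 bytes: each byte maps to a token id offset by 3, since ids 0-2 are reserved for the `<pad>`, `</s>`, and `<unk>` special tokens. A minimal illustration of that scheme in plain Python (no model download needed; the real tokenizer additionally appends an end-of-sequence id):

```python
def byt5_ids(text: str) -> list[int]:
    """Map text to ByT5-style token ids: UTF-8 byte value + 3.

    Ids 0-2 are reserved for <pad>, </s>, and <unk>, so every byte
    is shifted up by 3. Non-ASCII characters simply occupy several bytes.
    """
    return [b + 3 for b in text.encode("utf-8")]

print(byt5_ids("how"))  # [107, 114, 122]
print(byt5_ids("é"))    # [198, 172]  (two UTF-8 bytes)
```

This byte-level view is what makes the model robust to typos and rare words: there is no out-of-vocabulary token, only unfamiliar byte sequences.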
## Model Architecture
- Embedding Dimension: 256
- CNN Filters: 64 (filter sizes: [3, 4, 5])
- Transformer Heads: 4
- Transformer Layers: 2
- Base Embeddings: google/byt5-small
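The actual model code ships with the checkpoint (loaded via `trust_remote_code`), but the hyperparameters above imply a prefix encoder along these lines. This is an illustrative sketch with hypothetical layer names, not the shipped implementation:

```python
import torch
import torch.nn as nn

class MultiScaleCNN(nn.Module):
    """Sketch of a multi-scale CNN over byte embeddings, matching the
    listed hyperparameters (256-dim embeddings, 64 filters at widths 3/4/5).
    Layer names are hypothetical; the real code comes with the checkpoint."""

    def __init__(self, embed_dim=256, n_filters=64, filter_sizes=(3, 4, 5)):
        super().__init__()
        self.convs = nn.ModuleList(
            nn.Conv1d(embed_dim, n_filters, k, padding=k // 2)
            for k in filter_sizes
        )

    def forward(self, x):          # x: (batch, seq_len, embed_dim)
        x = x.transpose(1, 2)      # Conv1d expects (batch, channels, seq)
        feats = [torch.relu(conv(x)) for conv in self.convs]
        # Max-pool each scale over time, then concatenate the scales
        pooled = [f.max(dim=2).values for f in feats]
        return torch.cat(pooled, dim=1)  # (batch, n_filters * num_scales)

enc = MultiScaleCNN()
out = enc(torch.randn(2, 20, 256))
print(out.shape)  # torch.Size([2, 192])
```

The three filter widths capture character n-gram patterns at different granularities, which is useful when the prefix is only a fragment of a word.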
## Usage
### Installation
```bash
pip install transformers torch
```
### Loading the Model
```python
from transformers import AutoTokenizer, AutoConfig, AutoModel
# Load model (trust_remote_code required for custom architecture)
config = AutoConfig.from_pretrained("lv12/sin-qac-model", trust_remote_code=True)
model = AutoModel.from_pretrained("lv12/sin-qac-model", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")
```
### Scoring Candidates
```python
import torch
def score_completion(model, tokenizer, prefix: str, candidate: str, max_length: int = 20):
    """Score how well a candidate matches a prefix."""
    model.eval()
    prefix_encoding = tokenizer(
        prefix,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    candidate_encoding = tokenizer(
        candidate,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        score = model(
            prefix_ids=prefix_encoding["input_ids"],
            candidate_ids=candidate_encoding["input_ids"],
        )
    return score.squeeze().item()

# Example usage
prefix = "how to"
candidates = ["how to cook pasta", "how to learn python", "weather today"]

scores = []
for candidate in candidates:
    score = score_completion(model, tokenizer, prefix, candidate)
    scores.append((candidate, score))

# Sort by score (higher is better match)
scores.sort(key=lambda x: x[1], reverse=True)
for candidate, score in scores:
    print(f"{score:.4f} - {candidate}")
```
### Batch Scoring
```python
def batch_score(model, tokenizer, prefix: str, candidates: list, max_length: int = 20):
    """Score multiple candidates against one prefix (one forward pass per candidate)."""
    model.eval()
    prefix_encoding = tokenizer(
        prefix,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    prefix_ids = prefix_encoding["input_ids"]
    candidate_encodings = tokenizer(
        candidates,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    scores = []
    with torch.no_grad():
        for i in range(len(candidates)):
            score = model(
                prefix_ids=prefix_ids,
                candidate_ids=candidate_encodings["input_ids"][i:i + 1],
            )
            scores.append(score.squeeze().item())
    return list(zip(candidates, scores))

# Example
results = batch_score(model, tokenizer, "best resta", [
    "best restaurants near me",
    "best restaurant in new york",
    "best resume templates",
    "weather forecast",
])
for candidate, score in sorted(results, key=lambda x: -x[1]):
    print(f"{score:.4f} - {candidate}")
```
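The loop above issues one forward pass per candidate. If the custom model's forward accepts batched inputs (an assumption; check the code shipped with the checkpoint), the single prefix can be expanded to the candidate batch size and scored in one call. A sketch using a stand-in model just to show the tensor shapes:

```python
import torch

class StubModel:
    """Stand-in with the same keyword signature, returning one score per row.
    Purely illustrative; the real model is loaded with trust_remote_code."""
    def __call__(self, prefix_ids, candidate_ids):
        return torch.rand(candidate_ids.size(0), 1)

model = StubModel()
prefix_ids = torch.randint(0, 259, (1, 20))     # one encoded prefix
candidate_ids = torch.randint(0, 259, (4, 20))  # four encoded candidates

# Broadcast the single prefix across all candidates: (1, 20) -> (4, 20)
scores = model(
    prefix_ids=prefix_ids.expand(candidate_ids.size(0), -1),
    candidate_ids=candidate_ids,
)
print(scores.shape)  # torch.Size([4, 1])
```

`expand` creates a view rather than copying the prefix tensor, so the only extra cost is the larger batch inside the model.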
## Training Details
- **Dataset**: [rexoscare/autocomplete-search-dataset](https://huggingface.co/datasets/rexoscare/autocomplete-search-dataset)
- **Checkpoint**: `final.ckpt`
- **Training Framework**: PyTorch Lightning
- **Validation Loss**: 0.0459
## Evaluation
The model outputs scores between 0 and 1:
- **> 0.7**: Strong match (candidate is a likely completion)
- **0.4 - 0.7**: Moderate match
- **< 0.4**: Weak match (candidate unlikely to be what user wants)
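The bands above can be applied with a small helper (the band names here are illustrative, not part of the model's output):

```python
def match_band(score: float) -> str:
    """Map a model score in [0, 1] to the match bands described above."""
    if score > 0.7:
        return "strong"
    if score >= 0.4:
        return "moderate"
    return "weak"

print(match_band(0.85))  # strong
print(match_band(0.55))  # moderate
print(match_band(0.10))  # weak
```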
## Limitations
- Optimized for English queries
- Best performance on short prefixes (< 20 characters)
- Trained on search autocomplete data; may not generalize to other domains
## Citation
If you use this model, please cite:
```bibtex
@misc{query-completion-model,
  title={Query Auto-Completion Model},
  year={2024},
  publisher={HuggingFace},
  url={https://huggingface.co/lv12/sin-qac-model}
}
```