---
license: apache-2.0
library_name: transformers
tags:
- query-auto-completion
- search
- pytorch
- custom-model
datasets:
- rexoscare/autocomplete-search-dataset
language:
- en
pipeline_tag: text-classification
base_model:
- google/byt5-small
---

# Query Auto-Completion Model

A CNN+Transformer model for ranking query auto-completion suggestions.

## Model Description

This model scores how well a candidate completion matches a given query prefix. It has three components:

- **Prefix Encoder**: Multi-scale CNN + Transformer that extracts search intent from partial queries
- **Candidate Encoder**: Transformer that encodes candidate completions
- **Match Predictor**: MLP that scores the compatibility between prefix and candidate

The model uses pretrained ByT5 byte-level embeddings for robust character-level understanding.

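
To make "byte-level" concrete, the sketch below imitates how a ByT5-style vocabulary turns text into ids without loading any library. It assumes the ByT5 convention that ids 0, 1, and 2 are reserved for padding, end-of-sequence, and unknown, so each UTF-8 byte `b` becomes `b + 3`; the function name `byte_level_ids` is illustrative and not part of this model's code.

```python
from typing import List

# Assumption: ids 0-2 are reserved (pad, end-of-sequence, unknown),
# so every UTF-8 byte is shifted by 3, as in the ByT5 vocabulary.
SPECIAL_TOKEN_OFFSET = 3
EOS_ID = 1


def byte_level_ids(text: str) -> List[int]:
    """Map text to one id per UTF-8 byte, followed by an end-of-sequence id."""
    return [b + SPECIAL_TOKEN_OFFSET for b in text.encode("utf-8")] + [EOS_ID]


# 'h' = 104, 'o' = 111, 'w' = 119, each shifted by 3, then EOS
print(byte_level_ids("how"))  # [107, 114, 122, 1]

# "caf\u00e9" is 4 characters but 5 bytes, so it yields 6 ids including EOS
print(len(byte_level_ids("caf\u00e9")))  # 6
```

Because there is no word or subword vocabulary, typos and partial words such as `"resta"` still map to a well-defined id sequence.
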
## Model Architecture

- Embedding Dimension: 256
- CNN Filters: 64 (filter sizes: [3, 4, 5])
- Transformer Heads: 4
- Transformer Layers: 2
- Base Embeddings: google/byt5-small

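
The hyperparameters above can be collected in one place and sanity-checked, as in the sketch below. The class and field names are hypothetical and do not come from the model's configuration file. The two derived values rest on assumptions the card does not confirm: that the Transformer width equals the embedding dimension, and that the CNN runs one branch per filter size and concatenates their outputs.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class QACHyperparams:
    """Hypothetical container for the values listed in the model card."""

    embedding_dim: int = 256
    cnn_filters: int = 64
    filter_sizes: Tuple[int, ...] = (3, 4, 5)
    transformer_heads: int = 4
    transformer_layers: int = 2
    base_embeddings: str = "google/byt5-small"

    @property
    def head_dim(self) -> int:
        # Assumption: attention operates at the embedding width, which
        # must then divide evenly across the heads.
        if self.embedding_dim % self.transformer_heads != 0:
            raise ValueError("embedding_dim must be divisible by transformer_heads")
        return self.embedding_dim // self.transformer_heads

    @property
    def cnn_output_width(self) -> int:
        # Assumption: one convolution branch per filter size, concatenated.
        return self.cnn_filters * len(self.filter_sizes)


hp = QACHyperparams()
print(hp.head_dim)          # 64
print(hp.cnn_output_width)  # 192
```
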
## Usage

### Installation

```bash
pip install transformers torch
```

### Loading the Model

```python
from transformers import AutoTokenizer, AutoConfig, AutoModel

# Load model (trust_remote_code required for custom architecture)
config = AutoConfig.from_pretrained("lv12/sin-qac-model", trust_remote_code=True)
model = AutoModel.from_pretrained("lv12/sin-qac-model", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")
```

### Scoring Candidates

```python
import torch


def score_completion(model, tokenizer, prefix: str, candidate: str, max_length: int = 20) -> float:
    """Score how well a candidate matches a prefix."""
    model.eval()

    # Pad or truncate both inputs to the same fixed length
    prefix_encoding = tokenizer(
        prefix,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    candidate_encoding = tokenizer(
        candidate,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # Inference only, so skip gradient tracking
    with torch.no_grad():
        score = model(
            prefix_ids=prefix_encoding["input_ids"],
            candidate_ids=candidate_encoding["input_ids"],
        )

    return score.squeeze().item()


# Example usage
prefix = "how to"
candidates = ["how to cook pasta", "how to learn python", "weather today"]

scores = []
for candidate in candidates:
    score = score_completion(model, tokenizer, prefix, candidate)
    scores.append((candidate, score))

# Sort by score (higher is better match)
scores.sort(key=lambda x: x[1], reverse=True)
for candidate, score in scores:
    print(f"{score:.4f} - {candidate}")
```

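
When only the best few suggestions are needed, the sort can be replaced by a top-k selection. The sketch below is independent of the model: it takes any scoring callable, so `top_k_completions` and the stand-in `toy_score` are hypothetical names used purely for illustration.

```python
import heapq
from typing import Callable, Iterable, List, Tuple


def top_k_completions(
    score_fn: Callable[[str, str], float],
    prefix: str,
    candidates: Iterable[str],
    k: int = 3,
) -> List[Tuple[str, float]]:
    """Return the k highest-scoring (candidate, score) pairs, best first."""
    scored = ((candidate, score_fn(prefix, candidate)) for candidate in candidates)
    return heapq.nlargest(k, scored, key=lambda pair: pair[1])


def toy_score(prefix: str, candidate: str) -> float:
    """Stand-in scorer for demonstration only; not the model's behavior."""
    if not candidate.startswith(prefix):
        return 0.0
    return len(prefix) / len(candidate)


candidates = ["how to cook pasta", "how to learn python", "weather today"]
for candidate, score in top_k_completions(toy_score, "how to", candidates, k=2):
    print(f"{score:.4f} - {candidate}")
```

With the real model, the scorer could be passed as `lambda p, c: score_completion(model, tokenizer, p, c)`.
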
### Batch Scoring

```python
import torch


def batch_score(model, tokenizer, prefix: str, candidates: list, max_length: int = 20):
    """Score multiple candidates against one prefix, tokenizing them in a single call."""
    model.eval()

    # The prefix is encoded once and reused for every candidate
    prefix_encoding = tokenizer(
        prefix,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    prefix_ids = prefix_encoding["input_ids"]

    # All candidates are tokenized together into one padded tensor
    candidate_encodings = tokenizer(
        candidates,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    scores = []
    with torch.no_grad():
        # Each candidate row is scored against the prefix in turn
        for i in range(len(candidates)):
            score = model(
                prefix_ids=prefix_ids,
                candidate_ids=candidate_encodings["input_ids"][i:i+1],
            )
            scores.append(score.squeeze().item())

    return list(zip(candidates, scores))


# Example
results = batch_score(model, tokenizer, "best resta", [
    "best restaurants near me",
    "best restaurant in new york",
    "best resume templates",
    "weather forecast",
])
for candidate, score in sorted(results, key=lambda x: -x[1]):
    print(f"{score:.4f} - {candidate}")
```

## Training Details

- **Dataset**: [rexoscare/autocomplete-search-dataset](https://huggingface.co/datasets/rexoscare/autocomplete-search-dataset)
- **Checkpoint**: `final.ckpt`
- **Training Framework**: PyTorch Lightning
- **Validation Loss**: 0.0459

## Evaluation

The model outputs a score between 0 and 1, which can be read as follows:

- **> 0.7**: Strong match (candidate is a likely completion)
- **0.4 - 0.7**: Moderate match
- **< 0.4**: Weak match (candidate is unlikely to be what the user wants)

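
The bands above translate directly into a small helper. This is a minimal sketch: the name `match_label` and the label strings are illustrative, and scores of exactly 0.4 and 0.7 are placed in the moderate band, following the ranges as written.

```python
def match_label(score: float) -> str:
    """Map a score in [0, 1] to the band described in the model card."""
    if not 0.0 <= score <= 1.0:
        raise ValueError(f"score must be between 0 and 1, got {score}")
    if score > 0.7:
        return "strong"
    if score >= 0.4:
        return "moderate"
    return "weak"


for value in (0.85, 0.55, 0.12):
    print(f"{value:.2f} -> {match_label(value)}")
```
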
## Limitations

- Optimized for English queries
- Best performance on short prefixes (< 20 characters)
- Trained on search autocomplete data; may not generalize to other domains

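
The usage examples in this card pass `max_length=20` with `truncation=True`, so longer inputs are cut off before scoring. The sketch below estimates whether a string survives intact. It assumes the ByT5 tokenizer emits one token per UTF-8 byte plus one end-of-sequence token, which makes the effective limit 19 bytes rather than 20 characters. The helper name `fits_without_truncation` is hypothetical.

```python
def fits_without_truncation(text: str, max_length: int = 20) -> bool:
    """Check whether text fits in max_length tokens.

    Assumes one token per UTF-8 byte plus one end-of-sequence token.
    """
    return len(text.encode("utf-8")) + 1 <= max_length


print(fits_without_truncation("best resta"))                # True: 10 bytes
print(fits_without_truncation("best restaurants near me"))  # False: 24 bytes

# Accented characters take two bytes each: this query ("creme brulee recipe"
# with its accents) is 19 characters but 22 bytes, so it does not fit.
print(fits_without_truncation("cr\u00e8me br\u00fbl\u00e9e recipe"))  # False
```

Under that assumption, a candidate such as `"best restaurants near me"` is truncated at the default setting; passing a larger `max_length` to the scoring functions avoids this, though the card does not state which length was used in training.
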
## Citation

If you use this model, please cite:

```bibtex
@misc{query-completion-model,
  title={Query Auto-Completion Model},
  year={2024},
  publisher={HuggingFace},
  url={https://huggingface.co/lv12/sin-qac-model}
}
```