---
license: apache-2.0
library_name: transformers
tags:
- query-auto-completion
- search
- pytorch
- custom-model
datasets:
- rexoscare/autocomplete-search-dataset
language:
- en
pipeline_tag: text-classification
base_model:
- google/byt5-small
---

# Query Auto-Completion Model

A CNN+Transformer model for ranking query auto-completion suggestions.

## Model Description

This model scores how well a candidate completion matches a given query prefix. It uses:
- **Prefix Encoder**: Multi-scale CNN + Transformer that extracts search intent from partial queries
- **Candidate Encoder**: Transformer for encoding candidate completions
- **Match Predictor**: MLP that scores the compatibility between prefix and candidate

The model uses pretrained ByT5 byte-level embeddings for robust character-level understanding.
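To make the byte-level input concrete, here is a minimal pure-Python sketch of how a ByT5-style tokenizer maps text to ids: each UTF-8 byte is shifted by 3, because ids 0, 1, and 2 are reserved for padding, end-of-sequence, and unknown. The helper name `byte_encode` is hypothetical and only approximates what `AutoTokenizer.from_pretrained("google/byt5-small")` does with `padding="max_length"` and `truncation=True`; use the real tokenizer for inference.

```python
from typing import List

PAD_ID = 0       # ByT5 padding id
EOS_ID = 1       # ByT5 end-of-sequence id
BYTE_OFFSET = 3  # ids 0-2 are reserved, so byte b maps to id b + 3


def byte_encode(text: str, max_length: int = 20) -> List[int]:
    """Illustrative byte-level encoding with truncation and padding (hypothetical helper)."""
    ids = [b + BYTE_OFFSET for b in text.encode("utf-8")]
    # Truncate so the end-of-sequence id still fits within max_length.
    ids = ids[: max_length - 1] + [EOS_ID]
    # Right-pad to a fixed length.
    return ids + [PAD_ID] * (max_length - len(ids))


print(byte_encode("how to", max_length=10))
# [107, 114, 122, 35, 119, 114, 1, 0, 0, 0]
```

Because the vocabulary is raw bytes, typos and partial words still map to valid ids, which suits prefixes that stop mid-word.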

## Model Architecture

- Embedding Dimension: 256
- CNN Filters: 64 (filter sizes: [3, 4, 5])
- Transformer Heads: 4
- Transformer Layers: 2
- Base Embeddings: google/byt5-small
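
As a rough illustration of what "multi-scale" means for filter sizes 3, 4, and 5, the sketch below lists the byte windows that a width-`k` convolution would slide over for a given prefix. It assumes an unpadded convolution with stride 1; the actual model may pad its inputs, and `byte_windows` is a hypothetical helper, not part of the released code.

```python
from typing import List


def byte_windows(text: str, width: int) -> List[bytes]:
    """Return every contiguous byte window of the given width (stride 1, no padding)."""
    data = text.encode("utf-8")
    return [data[i:i + width] for i in range(len(data) - width + 1)]


# Each filter size looks at the same prefix at a different granularity.
for width in (3, 4, 5):
    windows = byte_windows("best resta", width)
    print(width, len(windows), windows[:2])
```

Under these assumptions a 10-byte prefix yields 8, 7, and 6 windows for widths 3, 4, and 5 respectively.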

## Usage

### Installation

```bash
pip install transformers torch
```

### Loading the Model

```python
from transformers import AutoTokenizer, AutoConfig, AutoModel

# Load model (trust_remote_code required for custom architecture)
config = AutoConfig.from_pretrained("lv12/sin-qac-model", trust_remote_code=True)
model = AutoModel.from_pretrained("lv12/sin-qac-model", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")
```

### Scoring Candidates

```python
import torch

def score_completion(model, tokenizer, prefix: str, candidate: str, max_length: int = 20):
    """Score how well a candidate matches a prefix."""
    model.eval()

    prefix_encoding = tokenizer(
        prefix,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )
    candidate_encoding = tokenizer(
        candidate,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )

    with torch.no_grad():
        score = model(
            prefix_ids=prefix_encoding["input_ids"],
            candidate_ids=candidate_encoding["input_ids"]
        )

    return score.squeeze().item()

# Example usage
prefix = "how to"
candidates = ["how to cook pasta", "how to learn python", "weather today"]

scores = []
for candidate in candidates:
    score = score_completion(model, tokenizer, prefix, candidate)
    scores.append((candidate, score))

# Sort by score (higher is better match)
scores.sort(key=lambda x: x[1], reverse=True)
for candidate, score in scores:
    print(f"{score:.4f} - {candidate}")
```

### Batch Scoring

```python
import torch

def batch_score(model, tokenizer, prefix: str, candidates: list, max_length: int = 20):
    """Score several candidates against one prefix.

    Candidates are tokenized in a single call, then scored one at a time.
    """
    model.eval()

    prefix_encoding = tokenizer(
        prefix,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )
    prefix_ids = prefix_encoding["input_ids"]

    candidate_encodings = tokenizer(
        candidates,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )

    scores = []
    with torch.no_grad():
        for i in range(len(candidates)):
            score = model(
                prefix_ids=prefix_ids,
                candidate_ids=candidate_encodings["input_ids"][i:i+1]
            )
            scores.append(score.squeeze().item())

    return list(zip(candidates, scores))

# Example
results = batch_score(model, tokenizer, "best resta", [
    "best restaurants near me",
    "best restaurant in new york",
    "best resume templates",
    "weather forecast"
])
for candidate, score in sorted(results, key=lambda x: -x[1]):
    print(f"{score:.4f} - {candidate}")
```
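
In an autocomplete UI you usually want only the top few suggestions. The sketch below shows one way to do that with any scoring callable. `rank_candidates` and `toy_score` are hypothetical names introduced for illustration; `toy_score` is a stand-in so the example runs without downloading the model.

```python
from typing import Callable, List, Tuple


def rank_candidates(
    score_fn: Callable[[str, str], float],
    prefix: str,
    candidates: List[str],
    top_k: int = 3,
) -> List[Tuple[str, float]]:
    """Score each candidate with score_fn and return the top_k, best first."""
    scored = [(candidate, score_fn(prefix, candidate)) for candidate in candidates]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]


def toy_score(prefix: str, candidate: str) -> float:
    """Stand-in scorer: 1.0 if the candidate starts with the prefix, else 0.0."""
    return 1.0 if candidate.startswith(prefix) else 0.0


top = rank_candidates(
    toy_score,
    "best resta",
    ["weather forecast", "best restaurants near me",
     "best resume templates", "best restaurant in new york"],
    top_k=2,
)
print(top)
```

To rank with the real model, pass a wrapper around the scoring function defined earlier, for example `lambda p, c: score_completion(model, tokenizer, p, c)`.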

## Training Details

- **Dataset**: [rexoscare/autocomplete-search-dataset](https://huggingface.co/datasets/rexoscare/autocomplete-search-dataset)
- **Checkpoint**: `final.ckpt`
- **Training Framework**: PyTorch Lightning
- **Validation Loss**: 0.0459

## Evaluation

The model outputs scores between 0 and 1:
- **> 0.7**: Strong match (candidate is a likely completion)
- **0.4 - 0.7**: Moderate match
- **< 0.4**: Weak match (candidate unlikely to be what user wants)
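
These bands can be applied in code with a small helper. This is a sketch: `match_label` is a hypothetical name, and it assumes the boundary values are handled as "above 0.7 is strong, 0.4 up to and including 0.7 is moderate", since the list does not say which band owns the exact boundaries.

```python
def match_label(score: float) -> str:
    """Map a model score in [0, 1] to the match bands listed above."""
    if score > 0.7:
        return "strong"
    if score >= 0.4:
        return "moderate"
    return "weak"


for value in (0.92, 0.55, 0.10):
    print(f"{value:.2f} -> {match_label(value)}")
```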

## Limitations

- Optimized for English queries
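- Best performance on short prefixes (< 20 characters); the usage examples above truncate each input to `max_length=20` byte-level tokens, so longer text is cut off before scoring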
- Trained on search autocomplete data; may not generalize to other domains

## Citation

If you use this model, please cite:

```bibtex
@misc{query-completion-model,
  title={Query Auto-Completion Model},
  year={2024},
  publisher={HuggingFace},
  url={https://huggingface.co/lv12/sin-qac-model}
}
```