---
language:
  - en
license: apache-2.0
library_name: transformers
tags:
  - cross-encoder
  - reranker
  - thread-matching
  - conversational-ai
  - lora
  - peft
  - onnx
pipeline_tag: text-classification
datasets:
  - Algokruti/thread-reranker-data
base_model: nreimers/MiniLM-L6-H384-uncased
model-index:
  - name: thread-reranker
    results:
      - task:
          type: text-classification
          name: Thread Relevance Ranking
        dataset:
          name: thread-reranker-data
          type: Algokruti/thread-reranker-data
          split: test
        metrics:
          - name: Hit Rate @ 1 (Overall)
            type: accuracy
            value: 0.9049
          - name: Hit Rate @ 1 (Easy)
            type: accuracy
            value: 1.0000
          - name: Hit Rate @ 1 (Medium)
            type: accuracy
            value: 0.8211
          - name: Hit Rate @ 1 (Hard)
            type: accuracy
            value: 0.8413
---

# Thread Reranker

A cross-encoder reranker that scores how relevant a conversation thread is to a new user message. Designed for unified conversation architectures where a single chat stream replaces explicit thread management: the model determines which internal thread a message belongs to, so the right context can be retrieved automatically.

## How It Works

In a unified conversation system, users interact through a single continuous chat. Behind the scenes, the system maintains multiple internal threads (topics the user has discussed before). When a new message arrives, candidate threads are retrieved using fast heuristics (entity matching, recency, flow continuity), and this reranker scores each candidate to pick the best match.

The model takes two inputs simultaneously: the text pair (user message + thread summary) processed through the encoder, and structured retrieval features computed by the upstream pipeline. It fuses both signals to produce a relevance score.

### Architecture

```
User Message + Thread Summary ──► MiniLM-L6 (frozen + LoRA r=8) ──► CLS token ──┐
                                                                                ├──► MLP Head ──► Score
Step 3 Structured Features ──────► Feature Projection (Linear→ReLU→Linear) ─────┘
```

**Base model:** nreimers/MiniLM-L6-H384-uncased (22M parameters, encoder-only)

**LoRA configuration:** Rank 8, alpha 16, applied to query and value projections, dropout 0.1
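
To give a sense of scale, the worked arithmetic below counts the trainable LoRA parameters implied by this configuration. It assumes the standard MiniLM-L6-H384 shapes (6 transformer layers, hidden size 384, query and value as square 384×384 projections) and counts only the low-rank A and B matrices; the MLP head and feature projection are excluded.

```python
# Worked arithmetic: trainable LoRA parameters for rank-8 adapters on the
# query and value projections of a 6-layer, 384-dim encoder (assumed shapes).
HIDDEN = 384           # MiniLM-L6-H384 hidden size
NUM_LAYERS = 6         # MiniLM-L6 transformer layers
RANK = 8               # LoRA rank r
ALPHA = 16             # LoRA alpha
TARGETS_PER_LAYER = 2  # query and value projections


def lora_params_per_matrix(d_in: int, d_out: int, r: int) -> int:
    """LoRA adds A (r x d_in) and B (d_out x r); the frozen weight W is untouched."""
    return r * d_in + d_out * r


per_matrix = lora_params_per_matrix(HIDDEN, HIDDEN, RANK)
total = per_matrix * TARGETS_PER_LAYER * NUM_LAYERS
scaling = ALPHA / RANK  # the low-rank update B @ A is multiplied by alpha / r

print(f"per adapted matrix: {per_matrix:,}")
print(f"total LoRA params:  {total:,}")
print(f"update scaling:     {scaling}")
```

Under these assumptions the adapters add roughly 74K trainable parameters, about 0.3% of the 22M frozen base weights, and the alpha/r ratio of 2 scales the low-rank update before it is added to the frozen weight.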

**Structured features (5 inputs):**
- `entity_overlap`: count of thread entities found in the user message
- `keyword_matches`: keyword overlap between message and thread content
- `flow_continuity`: 1.0 if this thread was the most recently active, 0.0 otherwise
- `recency_score`: exponential decay score based on hours since the thread was last active
- `hours_since_active`: raw hours since the thread was last active
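
A minimal sketch of how the upstream retrieval step might compute these five features for one candidate thread. The `ThreadRecord` type, the keyword tokenizer (lowercased alphanumeric words of three or more characters), the verbatim entity match, and the 24-hour half-life used for `recency_score` are all illustrative assumptions; the card only specifies that recency is an exponential decay over hours since the thread was last active.

```python
from __future__ import annotations

import math
import re
from dataclasses import dataclass, field


@dataclass
class ThreadRecord:
    """Hypothetical thread record; field names are illustrative, not from the card."""
    thread_id: str
    summary: str
    entities: list[str] = field(default_factory=list)
    hours_since_active: float = 0.0


def _words(text: str) -> set[str]:
    # Crude keyword extraction: lowercase alphanumeric tokens of 3+ characters.
    return {w for w in re.findall(r"[a-z0-9]+", text.lower()) if len(w) >= 3}


def structured_features(message: str, thread: ThreadRecord,
                        most_recent_thread_id: str | None,
                        half_life_hours: float = 24.0) -> list[float]:
    """Return [entity_overlap, keyword_matches, flow_continuity,
    recency_score, hours_since_active] for one candidate thread."""
    msg_lower = message.lower()
    # Count thread entities that appear verbatim (case-insensitive) in the message.
    entity_overlap = sum(1 for e in thread.entities if e.lower() in msg_lower)
    # Shared keywords between the message and the thread summary.
    keyword_matches = len(_words(message) & _words(thread.summary))
    flow_continuity = 1.0 if thread.thread_id == most_recent_thread_id else 0.0
    # Exponential decay with an assumed half-life: 1.0 now, 0.5 after half_life_hours.
    recency_score = math.exp(-math.log(2) * thread.hours_since_active / half_life_hours)
    return [float(entity_overlap), float(keyword_matches), flow_continuity,
            recency_score, float(thread.hours_since_active)]


# Usage with a toy thread (values are illustrative only)
thread = ThreadRecord(
    "t1",
    "Building a metrics dashboard with Chart.js | the bar chart overflows on mobile",
    entities=["React", "Chart.js"],
    hours_since_active=2.0,
)
feats = structured_features("can you fix that chart rendering issue", thread,
                            most_recent_thread_id="t1")
print(feats)
```

The returned list follows the order of the feature list above; the production pipeline's exact order and scaling should be taken from `config.json`.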

## Intended Use

This model is one component in a 7-step unified conversation pipeline:

1. **User sends message**: single chat stream, no thread selector
2. **Entity & signal extraction**: lightweight NER and pattern matching (no ML)
3. **Layered context retrieval**: database queries using entity match, recency, and flow continuity
4. **Reranker (this model)**: scores candidate threads from Step 3
5. **Confidence threshold**: auto-select if confident, ask the user if ambiguous
6. **LLM responds**: with the correct thread context injected
7. **Update thread store**: extract new entities and facts, write back to the database

The model only fires when the deterministic heuristics in Step 3 produce multiple plausible candidates. Clear-cut cases (unique entity match + high recency) are resolved without the model.
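
Steps 3 through 5 can be sketched as a single routing function. The clear-cut rule (exactly one candidate with an entity match and a `recency_score` of at least 0.8), the 0.7 confidence threshold, and the 0.15 margin over the runner-up are hypothetical values chosen for illustration, not the pipeline's actual settings. The reranker is passed in as `score_fn`, so either the PyTorch or ONNX model could be plugged in; a stub stands in for it here.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Sequence


@dataclass
class Candidate:
    thread_id: str
    # [entity_overlap, keyword_matches, flow_continuity, recency_score, hours_since_active]
    features: list[float]


@dataclass
class Decision:
    action: str              # "new_thread", "auto_select", or "ask_user"
    thread_id: str | None
    score: float | None


def route(candidates: Sequence[Candidate],
          score_fn: Callable[[Candidate], float],
          min_recency: float = 0.8,   # hypothetical clear-cut recency cutoff
          threshold: float = 0.7,     # hypothetical confidence threshold
          margin: float = 0.15) -> Decision:  # hypothetical gap over runner-up
    # Cold start: nothing to rank, so treat the message as a new thread.
    if not candidates:
        return Decision("new_thread", None, None)
    # Heuristic shortcut: a unique entity match on a recently active thread.
    entity_hits = [c for c in candidates if c.features[0] > 0]
    if len(entity_hits) == 1 and entity_hits[0].features[3] >= min_recency:
        return Decision("auto_select", entity_hits[0].thread_id, None)
    # Otherwise score every candidate with the reranker (score in [0, 1]).
    scored = sorted(((score_fn(c), c) for c in candidates),
                    key=lambda pair: pair[0], reverse=True)
    best_score, best = scored[0]
    runner_up = scored[1][0] if len(scored) > 1 else 0.0
    # Step 5: auto-select only when the top score is high and clearly ahead.
    if best_score >= threshold and best_score - runner_up >= margin:
        return Decision("auto_select", best.thread_id, best_score)
    return Decision("ask_user", best.thread_id, best_score)


# Usage with a stub scorer standing in for the model
stub_scores = {"a": 0.91, "b": 0.40}
cands = [Candidate("a", [0, 1, 1, 0.95, 1.0]), Candidate("b", [0, 0, 0, 0.30, 30.0])]
decision = route(cands, lambda c: stub_scores[c.thread_id])
print(decision)
```

Requiring a margin as well as an absolute threshold is one way to catch the ambiguous case described under Limitations, where two threads look equally plausible; whether the real pipeline uses a margin is not stated.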

## Performance

Evaluated on synthetic test data with three difficulty tiers:

| Difficulty | Hit Rate @ 1 | Description |
|---|---|---|
| **Easy** | 100.0% | Message contains explicit entity references ("fix the React bug") |
| **Medium** | 82.1% | Indirect references ("that bug we were debugging") |
| **Hard** | 84.1% | No entity signal, relies on recency and flow ("let's keep going") |
| **Overall** | 90.5% | Weighted across all tiers |

**Note:** In the hybrid pipeline, easy cases are handled by deterministic heuristics without calling the model. The model's effective contribution is on medium and hard cases, where the combined system achieves 95%+ accuracy when including heuristic pre-filtering.
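
Hit Rate @ 1 is the fraction of test messages for which the highest-scoring candidate is the correct thread. The sketch below computes it per difficulty tier and overall, assuming the test set is grouped by user message into lists of `(score, is_correct_thread)` pairs; that data layout and the function names are assumptions, and the toy inputs are not drawn from the real test set.

```python
from __future__ import annotations

from collections import defaultdict
from typing import Iterable


def hit_rate_at_1(groups: Iterable[list[tuple[float, bool]]]) -> float:
    """groups: one list per user message of (model_score, is_correct_thread)."""
    hits = total = 0
    for candidates in groups:
        if not candidates:
            continue
        # The top-scored candidate counts as a hit if it is the correct thread.
        _, top_is_correct = max(candidates, key=lambda pair: pair[0])
        hits += int(top_is_correct)
        total += 1
    return hits / total if total else 0.0


def per_tier_hit_rate(
    examples: Iterable[tuple[str, list[tuple[float, bool]]]],
) -> dict[str, float]:
    """examples: (difficulty_tier, candidate list) per message; adds an 'overall' entry."""
    by_tier: dict[str, list] = defaultdict(list)
    for tier, candidates in examples:
        by_tier[tier].append(candidates)
    rates = {tier: hit_rate_at_1(groups) for tier, groups in by_tier.items()}
    # Overall pools every message, i.e. each tier is weighted by its message count.
    rates["overall"] = hit_rate_at_1(g for groups in by_tier.values() for g in groups)
    return rates


# Toy example: four messages across three tiers
examples = [
    ("easy", [(0.95, True), (0.10, False)]),
    ("easy", [(0.88, True), (0.20, False)]),
    ("medium", [(0.40, True), (0.60, False)]),
    ("hard", [(0.70, True), (0.65, False)]),
]
rates = per_tier_hit_rate(examples)
print(rates)
```

Pooling all messages for the overall figure means a tier with many messages pulls the overall rate toward its own value. Ties are broken by taking the first candidate with the maximum score.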

## Training

**Dataset:** Algokruti/thread-reranker-data, containing 50,543 synthetic examples (12,500 positive, 38,043 negative) generated from 500 simulated user profiles across 12 topic types in 5 domains.

**Training strategy:** Curriculum learning, in which epochs 1-2 train on easy examples only and epochs 3-5 train on all difficulty tiers. Training uses binary cross-entropy loss with a cosine learning-rate schedule and warmup.

**Hyperparameters:**
- Batch size: 64
- Learning rate: 2e-4
- Epochs: 5 (2 curriculum + 3 full)
- Max sequence length: 256
- LoRA rank: 8, alpha: 16
- Optimizer: AdamW with weight decay 0.01
- Gradient clipping: max norm 1.0
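
The curriculum and learning-rate schedule described above can be sketched with the standard library. Linear warmup followed by cosine decay to zero is similar in shape to `transformers.get_cosine_schedule_with_warmup`; the 10% warmup fraction and the 1,000-step total are assumptions, since the card does not state the warmup length or step count.

```python
import math

LEARNING_RATE = 2e-4
EPOCHS = 5
CURRICULUM_EPOCHS = 2   # epochs 1-2: easy examples only
WARMUP_FRACTION = 0.1   # assumption: the card does not state the warmup length


def tiers_for_epoch(epoch: int) -> tuple:
    """1-indexed epoch -> difficulty tiers included in that epoch's training data."""
    if epoch <= CURRICULUM_EPOCHS:
        return ("easy",)
    return ("easy", "medium", "hard")


def lr_at_step(step: int, total_steps: int, peak_lr: float = LEARNING_RATE,
               warmup_fraction: float = WARMUP_FRACTION) -> float:
    """Linear warmup to peak_lr, then cosine decay to 0 over the remaining steps."""
    warmup_steps = max(1, int(total_steps * warmup_fraction))
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# Usage: print the curriculum and a few points on the schedule
for epoch in range(1, EPOCHS + 1):
    print(epoch, tiers_for_epoch(epoch))

total_steps = 1000  # illustrative only; see the note below
for step in (0, 99, 550, 999):
    print(step, f"{lr_at_step(step, total_steps):.2e}")
```

Because epochs 1-2 see only easy examples, they contain fewer steps than epochs 3-5, so in practice `total_steps` would be summed over the actual per-epoch batch counts at batch size 64.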

**Training domains covered:**
- Web Development (React Dashboard, Authentication, CSS Grid)
- Backend Development (Python API, Docker Deployment)
- Personal (Meal Planning, Job Search, Fitness)
- Data Science (ML Training, Data Pipeline)
- Mobile Development (iOS/Swift, Android/Kotlin)

## Limitations

- **Trained on synthetic data only.** Performance on real user conversations may differ, particularly for domains and linguistic patterns not represented in the training set.
- **Limited domain coverage.** 12 topics across 5 domains, heavily skewed toward software development. Non-technical topics (travel, health, education, finance, creative writing) are underrepresented.
- **English only.** Not tested on multilingual conversations.
- **Cold start.** With no conversation history, the model has nothing to rank. The system falls back to treating each message as a new thread.
- **Ambiguity resolution.** On genuinely ambiguous messages with no entity, recency, or flow signal, the model may select incorrectly. The confidence threshold mechanism is designed to catch these cases and ask the user instead.

## How to Use

### PyTorch Inference

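```python
import torch
from transformers import AutoTokenizer

# ThreadReranker is a custom nn.Module (MiniLM encoder + LoRA, feature projection,
# MLP head) defined in the training notebook; it is not provided by transformers
# or peft. Import or paste that definition before running this snippet, e.g.:
# from thread_reranker import ThreadReranker  # hypothetical module name

# Load the tokenizer of the base encoder
tokenizer = AutoTokenizer.from_pretrained("nreimers/MiniLM-L6-H384-uncased")

# Load the full ThreadReranker weights and switch to inference mode
model = ThreadReranker()
model.load_state_dict(torch.load("model.pt", map_location="cpu"))
model.eval()

# Score a message against a candidate thread
message = "can you fix that chart rendering issue"
thread_text = "Building a metrics dashboard with Chart.js | the bar chart overflows on mobile | React, Chart.js"

encoding = tokenizer(message, thread_text, max_length=256,
                     padding="max_length", truncation=True, return_tensors="pt")

# Step 3 features, assumed to follow the order listed under "Structured features"
# (check config.json): entity_overlap, keyword_matches, flow_continuity,
# recency_score, hours_since_active
features = torch.tensor([[1.0, 1.0, 1.0, 0.92, 2.0]], dtype=torch.float32)

with torch.no_grad():
    logits = model(encoding["input_ids"], encoding["attention_mask"], features)
    score = torch.sigmoid(logits)  # the model outputs a logit; sigmoid maps it to [0, 1]

print(f"Relevance score: {score.item():.4f}")
```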

### ONNX Inference (On-Device)

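```python
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

session = ort.InferenceSession("thread_reranker.onnx", providers=["CPUExecutionProvider"])
tokenizer = AutoTokenizer.from_pretrained("nreimers/MiniLM-L6-H384-uncased")

message = "can you fix that chart rendering issue"
thread_text = "Building a metrics dashboard with Chart.js | the bar chart overflows on mobile | React, Chart.js"

# Tokenize the (message, thread) pair the same way as in the PyTorch example
encoding = tokenizer(message, thread_text, max_length=256,
                     padding="max_length", truncation=True, return_tensors="np")

# Step 3 structured features, shape (batch, 5)
features_np = np.array([[1.0, 1.0, 1.0, 0.92, 2.0]], dtype=np.float32)

# Token inputs cast to int64, assuming the export kept PyTorch's default index dtype
result = session.run(None, {
    "input_ids": encoding["input_ids"].astype(np.int64),
    "attention_mask": encoding["attention_mask"].astype(np.int64),
    "structured_features": features_np,
})

score = 1 / (1 + np.exp(-result[0]))  # sigmoid turns the logit into a relevance score
print(f"Relevance score: {float(score.ravel()[0]):.4f}")
```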

## Files

| File | Description |
|---|---|
| `model.pt` | PyTorch model weights (base + LoRA merged + classification head) |
| `thread_reranker.onnx` | ONNX export for on-device inference |
| `config.json` | Model configuration and feature definitions |
| `training_history.json` | Per-epoch training and validation metrics |
| `tokenizer.json` | Tokenizer files |

## Citation

If you use this model (trained on Algokruti/thread-reranker-data), please cite:

```bibtex
@misc{thread-reranker-2026,
  title={Thread Reranker: Cross-Encoder for Unified Conversation Thread Matching},
  author={Algokruti},
  year={2026},
  publisher={Hugging Face},
  url={https://huggingface.co/Algokruti/thread-reranker}
}
```