---
language:
- en
license: apache-2.0
library_name: transformers
tags:
- cross-encoder
- reranker
- thread-matching
- conversational-ai
- lora
- peft
- onnx
pipeline_tag: text-classification
datasets:
- Algokruti/thread-reranker-data
base_model: nreimers/MiniLM-L6-H384-uncased
model-index:
- name: thread-reranker
  results:
  - task:
      type: text-classification
      name: Thread Relevance Ranking
    dataset:
      name: thread-reranker-data
      type: Algokruti/thread-reranker-data
      split: test
    metrics:
    - name: Hit Rate @ 1 (Overall)
      type: accuracy
      value: 0.9049
    - name: Hit Rate @ 1 (Easy)
      type: accuracy
      value: 1.0000
    - name: Hit Rate @ 1 (Medium)
      type: accuracy
      value: 0.8211
    - name: Hit Rate @ 1 (Hard)
      type: accuracy
      value: 0.8413
---

# Thread Reranker

A cross-encoder reranker that scores how relevant a conversation thread is to a new user message. It is designed for unified conversation architectures, where a single chat stream replaces explicit thread management: the model determines which internal thread a message belongs to so that the right context can be retrieved automatically.

## How It Works

In a unified conversation system, users interact through a single continuous chat. Behind the scenes, the system maintains multiple internal threads (topics the user has discussed before). When a new message arrives, candidate threads are retrieved using fast heuristics (entity matching, recency, flow continuity), and this reranker scores each candidate to pick the best match.

The model takes two inputs at once: the text pair (user message + thread summary), which is processed by the encoder, and structured retrieval features computed by the upstream pipeline. It fuses both signals to produce a single relevance score.

### Architecture

```
User Message + Thread Summary ──► MiniLM-L6 (frozen + LoRA r=8) ──► CLS token ──┐
                                                                                ├──► MLP Head ──► Score
Step 3 Structured Features ─────► Feature Projection (Linear→ReLU→Linear) ──────┘
```

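
To make the fusion in the diagram concrete, here is a minimal, dependency-free sketch of the forward pass that follows the encoder. Only the 384-wide CLS vector, the five structured features, and the Linear→ReLU→Linear projection come from this card. The projection width, the concatenation step, the hidden layer inside the MLP head, and all weights (random here) are assumptions, and `fuse` is a hypothetical name. The actual `ThreadReranker` class is defined in the training notebook and would normally be built from `torch.nn` modules.

```python
import math
import random

HIDDEN = 384       # CLS width of MiniLM-L6-H384 (from the base model name)
N_FEATURES = 5     # structured features listed in this card
PROJ = 32          # ASSUMED width of the feature projection
MLP_HIDDEN = 64    # ASSUMED width of the MLP head's hidden layer

rng = random.Random(0)


def init_linear(n_in, n_out):
    """Random weights and zero biases for a dense layer (illustration only)."""
    scale = 1.0 / math.sqrt(n_in)
    weights = [[rng.uniform(-scale, scale) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out


def linear(x, layer):
    """Apply y = Wx + b to a plain Python list."""
    weights, bias = layer
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weights, bias)]


def relu(x):
    return [max(0.0, v) for v in x]


# Feature projection: Linear -> ReLU -> Linear, as in the diagram.
proj_in = init_linear(N_FEATURES, PROJ)
proj_out = init_linear(PROJ, PROJ)
# MLP head over the concatenated [CLS ; projected features] vector (assumed layout).
head_hidden = init_linear(HIDDEN + PROJ, MLP_HIDDEN)
head_out = init_linear(MLP_HIDDEN, 1)


def fuse(cls_embedding, features):
    """Return the raw relevance logit for one (message, thread) pair."""
    projected = linear(relu(linear(features, proj_in)), proj_out)
    fused = cls_embedding + projected  # list concatenation, length HIDDEN + PROJ
    return linear(relu(linear(fused, head_hidden)), head_out)[0]


# Stand-in for the encoder's CLS output; a real run would take it from MiniLM.
cls_embedding = [rng.gauss(0.0, 1.0) for _ in range(HIDDEN)]
logit = fuse(cls_embedding, [1.0, 1.0, 1.0, 0.92, 2.0])
score = 1.0 / (1.0 + math.exp(-logit))  # sigmoid, as in the usage examples
print(f"logit={logit:.4f} score={score:.4f}")
```

The point of the sketch is the data flow: the text branch and the feature branch are computed independently and only meet at the head, so the structured features can shift the score even when the text pair alone is ambiguous.
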
**Base model:** nreimers/MiniLM-L6-H384-uncased (22M parameters, encoder-only)

**LoRA configuration:** rank 8, alpha 16, dropout 0.1, applied to the query and value projections

**Structured features (5 inputs):**

- `entity_overlap`: count of thread entities found in the user message
- `keyword_matches`: keyword overlap between message and thread content
- `flow_continuity`: 1.0 if this thread was the most recently active, 0.0 otherwise
- `recency_score`: exponential decay score based on hours since thread was last active
- `hours_since_active`: raw hours since thread was last active

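
As a rough illustration of how an upstream pipeline might compute these five values, the sketch below builds the feature vector for one candidate thread. The `Thread` record, the matching rules (substring match for entities, token match for keywords), and the feature order are assumptions made for the example. The 24-hour decay constant is also an assumption; it was chosen because `exp(-2 / 24)` is about 0.92, which matches the `0.92` at `2.0` hours used in the usage example in this card, but the real constant is not stated here.

```python
import math
import re
from dataclasses import dataclass

DECAY_HOURS = 24.0  # ASSUMED time constant for the exponential decay


@dataclass
class Thread:
    """Hypothetical thread record; field names are illustrative only."""
    entities: list
    keywords: list
    hours_since_active: float
    is_most_recent: bool


def structured_features(message, thread):
    """Return [entity_overlap, keyword_matches, flow_continuity,
    recency_score, hours_since_active] for one (message, thread) pair."""
    text = message.lower()
    tokens = set(re.findall(r"[a-z0-9]+", text))
    # Count thread entities that appear anywhere in the message (substring match).
    entity_overlap = sum(1 for entity in thread.entities if entity.lower() in text)
    # Count thread keywords that appear as whole tokens in the message.
    keyword_matches = sum(1 for keyword in thread.keywords if keyword.lower() in tokens)
    flow_continuity = 1.0 if thread.is_most_recent else 0.0
    recency_score = math.exp(-thread.hours_since_active / DECAY_HOURS)
    return [
        float(entity_overlap),
        float(keyword_matches),
        flow_continuity,
        recency_score,
        float(thread.hours_since_active),
    ]


thread = Thread(
    entities=["React", "Chart.js"],
    keywords=["dashboard", "chart", "overflows", "mobile"],
    hours_since_active=2.0,
    is_most_recent=True,
)
features = structured_features("the Chart.js bar chart still overflows on mobile", thread)
print(features)
```

Note that `hours_since_active` is unbounded while the other features are small, so a production pipeline may want to clip or normalize it; whether the original training did so is not documented in this card.
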
## Intended Use

This model is one component in a 7-step unified conversation pipeline:

1. **User sends message**: single chat stream, no thread selector
2. **Entity & signal extraction**: lightweight NER and pattern matching (no ML)
3. **Layered context retrieval**: database queries using entity match, recency, flow continuity
4. **Reranker (this model)**: scores candidate threads from Step 3
5. **Confidence threshold**: auto-select if confident, ask user if ambiguous
6. **LLM responds**: with the correct thread context injected
7. **Update thread store**: extract new entities and facts, write back to database

The model runs only when the deterministic heuristics in Step 3 produce multiple plausible candidates. Clear-cut cases (a unique entity match plus high recency) are resolved without the model.

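
Steps 4 and 5 can be sketched as a small decision function. This is a minimal illustration, not the pipeline's actual code: `choose_thread`, `Decision`, and both threshold values are hypothetical, and the card does not say how "confident" is defined. Here it is assumed to mean a high top score plus a clear margin over the runner-up.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Sequence

AUTO_SELECT_MIN_SCORE = 0.80   # ASSUMED: minimum top score to auto-select
AUTO_SELECT_MIN_MARGIN = 0.15  # ASSUMED: minimum gap between the top two scores


@dataclass
class Decision:
    thread_id: Optional[str]
    action: str       # "auto_select", "ask_user", or "new_thread"
    used_model: bool  # whether the reranker had to be called


def choose_thread(candidate_ids: Sequence[str],
                  score_fn: Callable[[str], float]) -> Decision:
    """Pick a thread from the Step 3 candidates, calling the reranker only if needed."""
    if not candidate_ids:
        # Nothing to rank: treat the message as the start of a new thread.
        return Decision(None, "new_thread", used_model=False)
    if len(candidate_ids) == 1:
        # Clear-cut case resolved by the heuristics alone.
        return Decision(candidate_ids[0], "auto_select", used_model=False)

    # Step 4: score every candidate with the reranker and sort best-first.
    ranked = sorted(((score_fn(tid), tid) for tid in candidate_ids), reverse=True)
    (best_score, best_id), (second_score, _) = ranked[0], ranked[1]

    # Step 5: auto-select only when the winner is both strong and well separated.
    confident = (best_score >= AUTO_SELECT_MIN_SCORE
                 and best_score - second_score >= AUTO_SELECT_MIN_MARGIN)
    return Decision(best_id, "auto_select" if confident else "ask_user", used_model=True)


# Usage with made-up scores standing in for the model's sigmoid outputs.
clear_scores = {"dashboard": 0.95, "meal-plan": 0.30}
close_scores = {"dashboard": 0.85, "auth": 0.80}
clear = choose_thread(list(clear_scores), clear_scores.get)
close = choose_thread(list(close_scores), close_scores.get)
print(clear)
print(close)
```

Using a margin in addition to an absolute threshold is one way to catch the case where two threads both look plausible; a single threshold would auto-select in that situation.
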
## Performance

Evaluated on synthetic test data with three difficulty tiers:

| Difficulty | Hit Rate @ 1 | Description |
|---|---|---|
| **Easy** | 100.0% | Message contains explicit entity references ("fix the React bug") |
| **Medium** | 82.1% | Indirect references ("that bug we were debugging") |
| **Hard** | 84.1% | No entity signal; relies on recency and flow ("let's keep going") |
| **Overall** | 90.5% | Weighted across all tiers |

**Note:** In the hybrid pipeline, easy cases are handled by deterministic heuristics without calling the model. The model's effective contribution is therefore on medium and hard cases, where the combined system (heuristic pre-filtering plus the model) achieves 95%+ accuracy.

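
For readers unfamiliar with the metric, Hit Rate @ 1 is the fraction of messages for which the highest-scored candidate is the correct thread. The sketch below computes it overall and per tier. The record layout and the four example records are invented for illustration and are not taken from the evaluation set.

```python
from collections import defaultdict


def top1(scores):
    """Return the thread id with the highest score."""
    return max(scores, key=scores.get)


def hit_rate_at_1(examples):
    """Fraction of examples whose top-scored thread is the gold thread."""
    return sum(top1(ex["scores"]) == ex["gold"] for ex in examples) / len(examples)


def hit_rate_by_tier(examples):
    """Hit Rate @ 1 computed separately for each difficulty tier."""
    by_tier = defaultdict(list)
    for ex in examples:
        by_tier[ex["tier"]].append(ex)
    return {tier: hit_rate_at_1(group) for tier, group in by_tier.items()}


# Made-up evaluation records: one message each, with scores per candidate thread.
examples = [
    {"tier": "easy", "gold": "t1", "scores": {"t1": 0.97, "t2": 0.10}},
    {"tier": "medium", "gold": "t2", "scores": {"t1": 0.55, "t2": 0.71}},
    {"tier": "hard", "gold": "t3", "scores": {"t1": 0.62, "t3": 0.48}},  # miss
    {"tier": "hard", "gold": "t1", "scores": {"t1": 0.66, "t2": 0.41}},
]
overall = hit_rate_at_1(examples)
per_tier = hit_rate_by_tier(examples)
print(overall, per_tier)
```

Because the overall figure is computed over all messages at once, it is automatically weighted by how many messages each tier contributes.
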
## Training

**Dataset:** Algokruti/thread-reranker-data, 50,543 synthetic examples (12,500 positive, 38,043 negative) generated from 500 simulated user profiles across 12 topic types in 5 domains.

**Training strategy:** Curriculum learning. Epochs 1-2 train on easy examples only; epochs 3-5 train on all difficulty tiers. The loss is binary cross-entropy, with a cosine learning rate schedule and warmup.

**Hyperparameters:**

- Batch size: 64
- Learning rate: 2e-4
- Epochs: 5 (2 curriculum + 3 full)
- Max sequence length: 256
- LoRA rank: 8, alpha: 16
- Optimizer: AdamW with weight decay 0.01
- Gradient clipping: max norm 1.0

**Training domains covered:**

- Web Development (React Dashboard, Authentication, CSS Grid)
- Backend Development (Python API, Docker Deployment)
- Personal (Meal Planning, Job Search, Fitness)
- Data Science (ML Training, Data Pipeline)
- Mobile Development (iOS/Swift, Android/Kotlin)

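
The curriculum and learning-rate schedule described in the training strategy can be sketched as follows. The peak learning rate, epoch counts, and tier names come from this card; the linear warmup shape, the warmup length, the decay to zero, and the example step counts are assumptions, and `tiers_for_epoch` and `lr_at` are hypothetical helper names.

```python
import math

PEAK_LR = 2e-4          # from the hyperparameter list
EPOCHS = 5              # 2 curriculum + 3 full
CURRICULUM_EPOCHS = 2   # epochs that see easy examples only


def tiers_for_epoch(epoch):
    """Difficulty tiers used in a given 1-indexed epoch."""
    if epoch <= CURRICULUM_EPOCHS:
        return {"easy"}
    return {"easy", "medium", "hard"}


def lr_at(step, total_steps, warmup_steps):
    """Linear warmup to PEAK_LR, then cosine decay to zero (assumed shape)."""
    if step < warmup_steps:
        return PEAK_LR * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return PEAK_LR * 0.5 * (1.0 + math.cos(math.pi * progress))


# Example: 1000 optimizer steps with 100 warmup steps (both numbers are made up).
TOTAL_STEPS, WARMUP_STEPS = 1000, 100
for epoch in range(1, EPOCHS + 1):
    print(epoch, sorted(tiers_for_epoch(epoch)))
for step in (0, 99, 550, 999):
    print(step, lr_at(step, TOTAL_STEPS, WARMUP_STEPS))
```

In a real run the number of steps per epoch changes when the curriculum phase ends, since the easy-only epochs iterate over fewer examples, so `total_steps` should be computed from the actual per-epoch dataset sizes.
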
## Limitations

- **Trained on synthetic data only.** Performance on real user conversations may differ, particularly for domains and linguistic patterns not represented in the training set.
- **Limited domain coverage.** 12 topics across 5 domains, heavily skewed toward software development. Non-technical topics (travel, health, education, finance, creative writing) are underrepresented.
- **English only.** Not tested on multilingual conversations.
- **Cold start.** With no conversation history, the model has nothing to rank. The system falls back to treating each message as a new thread.
- **Ambiguity resolution.** On genuinely ambiguous messages with no entity, recency, or flow signal, the model may select incorrectly. The confidence threshold mechanism is designed to catch these cases and ask the user instead.

## How to Use

### PyTorch Inference

```python
import torch
from transformers import AutoTokenizer

# ThreadReranker is a custom class, not part of transformers: define or import it
# first (see the training notebook for the class definition).

# Load the tokenizer of the base model
tokenizer = AutoTokenizer.from_pretrained("nreimers/MiniLM-L6-H384-uncased")

# Load the full ThreadReranker weights from model.pt
model = ThreadReranker()
model.load_state_dict(torch.load("model.pt", map_location="cpu"))
model.eval()

# Score a message against a candidate thread
message = "can you fix that chart rendering issue"
thread_text = "Building a metrics dashboard with Chart.js | the bar chart overflows on mobile | React, Chart.js"
encoding = tokenizer(
    message, thread_text, max_length=256,
    padding="max_length", truncation=True, return_tensors="pt",
)
features = torch.tensor([[1.0, 1.0, 1.0, 0.92, 2.0]])  # Step 3 features

with torch.no_grad():
    logit = model(encoding["input_ids"], encoding["attention_mask"], features)
    score = torch.sigmoid(logit)  # map the raw logit to a 0-1 relevance score

print(f"Relevance score: {score.item():.4f}")
```

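### ONNX Inference (On-Device)

```python
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

session = ort.InferenceSession("thread_reranker.onnx")
tokenizer = AutoTokenizer.from_pretrained("nreimers/MiniLM-L6-H384-uncased")

# Prepare inputs (tokenized text + structured features)
message = "can you fix that chart rendering issue"
thread_text = "Building a metrics dashboard with Chart.js | the bar chart overflows on mobile | React, Chart.js"
encoding = tokenizer(message, thread_text, max_length=256,
                     padding="max_length", truncation=True, return_tensors="np")
# Assumes the export expects int64 token inputs and float32 features.
input_ids_np = encoding["input_ids"].astype(np.int64)
attention_mask_np = encoding["attention_mask"].astype(np.int64)
features_np = np.array([[1.0, 1.0, 1.0, 0.92, 2.0]], dtype=np.float32)  # Step 3 features

result = session.run(None, {
    "input_ids": input_ids_np,
    "attention_mask": attention_mask_np,
    "structured_features": features_np,
})
score = 1 / (1 + np.exp(-result[0]))  # sigmoid
```
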
## Files

| File | Description |
|---|---|
| `model.pt` | PyTorch model weights (base + LoRA merged + classification head) |
| `thread_reranker.onnx` | ONNX export for on-device inference |
| `config.json` | Model configuration and feature definitions |
| `training_history.json` | Per-epoch training and validation metrics |
| `tokenizer.json` | Tokenizer files |

## Citation

If you use this model, please cite it as:

```
@misc{thread-reranker-2026,
  title={Thread Reranker: Cross-Encoder for Unified Conversation Thread Matching},
  author={Algokruti},
  year={2026},
  publisher={Hugging Face},
  url={https://huggingface.co/Algokruti/thread-reranker}
}
```