thebajajra committed on
Commit 8d034e7 · verified · 1 Parent(s): 461601f

Update README.md

Files changed (1): README.md (+289, -3)
---
license: apache-2.0
language:
- en
tags:
- ecommerce
- e-commerce
- retail
- marketplace
- shopping
- amazon
- ebay
- alibaba
- google
- rakuten
- bestbuy
- walmart
- flipkart
- wayfair
- shein
- target
- etsy
- shopify
- taobao
- asos
- carrefour
- costco
- overstock
- pretraining
- encoder
- language-modeling
- foundation-model
base_model:
- thebajajra/RexBERT-base
pipeline_tag: text-ranking
library_name: sentence-transformers
---

<p align="center">
  <img src="https://cdn-uploads.huggingface.co/production/uploads/6893dd21467f7d2f5f358a95/apOIbl5PdJuRk-tQMdDc8.png" alt="RexReranker">
</p>

# RexReranker Base

A distributional **e-commerce** neural reranker based on RexBERT-base that predicts relevance scores as a probability distribution, providing both accurate relevance predictions and uncertainty estimates.

## Features

- **Distributional Output**: Predicts a probability distribution over relevance bins (0.0 to 1.0)
- **Uncertainty Estimates**: Provides variance and entropy for confidence assessment
- **CrossEncoder Compatible**: Works directly with the Sentence Transformers CrossEncoder
- **Mean Pooling**: Uses mean pooling over all tokens for robust representations

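The bullets above translate directly into arithmetic: the head emits one logit per relevance bin, the softmax over them gives a distribution, and the scalar score is that distribution's expectation, with variance and entropy as confidence signals. A minimal, framework-free sketch, assuming evenly spaced bin centers in [0, 1] (the real bin layout comes from the model config, so treat this as an illustration only):

```python
import math

def distribution_stats(bin_logits):
    """Turn per-bin logits into (expected relevance, variance, entropy).

    Assumes bin centers are evenly spaced in [0, 1]; the actual model
    defines its own bin layout, so this is only an illustration.
    """
    m = max(bin_logits)
    exps = [math.exp(x - m) for x in bin_logits]  # numerically stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]
    centers = [i / (len(probs) - 1) for i in range(len(probs))]
    relevance = sum(p * c for p, c in zip(probs, centers))  # E[r]
    variance = sum(p * (c - relevance) ** 2 for p, c in zip(probs, centers))
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    return relevance, variance, entropy
```

A flat distribution yields a score of 0.5 with maximal entropy; a sharply peaked one yields a confident score near its peak bin.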
## Installation

```bash
pip install transformers sentence-transformers torch
```

## Quick Start

### 1. Using Hugging Face Transformers

```python
from transformers import AutoModel, AutoTokenizer
import torch

# Load model and tokenizer
model = AutoModel.from_pretrained(
    "thebajajra/RexReranker-base",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("thebajajra/RexReranker-base")

# Move to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()

# Prepare input (query-document pair)
query = "best laptop for programming"
title = "MacBook Pro M3"
description = "Powerful laptop with M3 chip, 16GB RAM, perfect for developers and creative professionals"

inputs = tokenizer(
    f"Query: {query}",
    f"Title: {title}\nDescription: {description}",
    return_tensors="pt",
    truncation=True,
    max_length=2048,
).to(device)

# Get the relevance score
with torch.no_grad():
    score = model.predict_relevance(**inputs)
print(f"Relevance Score: {score.item():.4f}")
```

### 2. Using Sentence Transformers CrossEncoder

```python
from sentence_transformers import CrossEncoder

# Load as a CrossEncoder
model = CrossEncoder(
    "thebajajra/RexReranker-base",
    trust_remote_code=True,
)

# Single prediction
query = "best laptop for programming"
document = "MacBook Pro M3 - Powerful laptop with M3 chip for developers"

score = model.predict([(query, document)])[0]
print(f"Score: {score:.4f}")
```

### 3. Batch Reranking with CrossEncoder

```python
from sentence_transformers import CrossEncoder

model = CrossEncoder("thebajajra/RexReranker-base", trust_remote_code=True)

query = "best laptop for programming"
documents = [
    "MacBook Pro M3 - Powerful laptop with M3 chip for developers",
    "Gaming Mouse RGB - High precision gaming mouse with 16000 DPI",
    "ThinkPad X1 Carbon - Business ultrabook with long battery life",
    "Mechanical Keyboard - Cherry MX switches for typing comfort",
    "Dell XPS 15 - Premium laptop with 4K OLED display",
]

# Score every document against the query
pairs = [(query, doc) for doc in documents]
scores = model.predict(pairs)

# Print results ranked by score
print(f"Query: {query}\n")
for doc, score in sorted(zip(documents, scores), key=lambda x: x[1], reverse=True):
    print(f"  {score:.4f} | {doc[:60]}")
```

### 4. Using CrossEncoder's rank() Method

```python
from sentence_transformers import CrossEncoder

model = CrossEncoder("thebajajra/RexReranker-base", trust_remote_code=True)

query = "wireless headphones with noise cancellation"
documents = [
    "Sony WH-1000XM5 - Industry-leading noise cancellation headphones",
    "Apple AirPods Max - Premium over-ear headphones with spatial audio",
    "Bose QuietComfort 45 - Comfortable wireless noise cancelling headphones",
    "JBL Tune 750BTNC - Affordable wireless headphones with ANC",
    "Logitech Gaming Headset - Wired gaming headphones with microphone",
]

# Rank documents and keep the top 3
results = model.rank(query, documents, top_k=3)

print(f"Query: {query}\n")
print("Top 3 Results:")
for result in results:
    idx = result["corpus_id"]
    score = result["score"]
    print(f"  {score:.4f} | {documents[idx][:60]}")
```

### 5. With Uncertainty Estimates

```python
from transformers import AutoModel, AutoTokenizer
import torch

model = AutoModel.from_pretrained("thebajajra/RexReranker-base", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("thebajajra/RexReranker-base")

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device).eval()

# Prepare inputs
inputs = tokenizer(
    "Query: best laptop for programming",
    "Title: MacBook Pro\nDescription: Great laptop for developers",
    return_tensors="pt",
    truncation=True,
).to(device)

# Get a prediction with uncertainty estimates
with torch.no_grad():
    result = model.predict_with_uncertainty(**inputs)

print(f"Relevance: {result['relevance'].item():.4f}")
print(f"Variance:  {result['variance'].item():.6f}")  # Higher = more uncertain
print(f"Entropy:   {result['entropy'].item():.4f}")   # Higher = more uncertain

# Access the full probability distribution
print("\nDistribution over bins:")
probs = result["probs"][0].cpu().numpy()
for i, p in enumerate(probs):
    bin_center = i / (len(probs) - 1)
    bar = "█" * int(p * 50)
    print(f"  {bin_center:.1f}: {bar} ({p:.3f})")
```

### 6. Batch Processing for Production

```python
from transformers import AutoModel, AutoTokenizer
import torch

model = AutoModel.from_pretrained("thebajajra/RexReranker-base", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("thebajajra/RexReranker-base")

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device).eval()

def rerank_batch(query: str, documents: list, batch_size: int = 32) -> list:
    """Rerank documents for a query with batched inference."""
    # Tokenize every query-document pair (defer padding to batch time)
    all_inputs = []
    for doc in documents:
        title = doc.get("title", "")
        description = doc.get("description", "")
        inputs = tokenizer(
            f"Query: {query}",
            f"Title: {title}\nDescription: {description}",
            truncation=True,
            max_length=2048,
            padding=False,
        )
        all_inputs.append(inputs)

    # Batched inference
    all_scores = []
    for i in range(0, len(all_inputs), batch_size):
        batch = all_inputs[i:i + batch_size]
        padded = tokenizer.pad(batch, return_tensors="pt").to(device)

        with torch.no_grad():
            scores = model.predict_relevance(**padded)
        all_scores.extend(scores.cpu().tolist())

    # Attach scores to documents and sort by descending relevance
    for doc, score in zip(documents, all_scores):
        doc["score"] = score

    return sorted(documents, key=lambda x: x["score"], reverse=True)

# Example usage
query = "best laptop for programming"
documents = [
    {"title": "MacBook Pro M3", "description": "Powerful laptop for developers"},
    {"title": "Gaming Mouse", "description": "High DPI gaming mouse"},
    {"title": "ThinkPad X1", "description": "Business laptop with long battery"},
]

ranked = rerank_batch(query, documents)
for doc in ranked:
    print(f"{doc['score']:.4f} | {doc['title']}")
```

## Input Format

The model expects query-document pairs formatted as:

| Field | Format |
|-------|--------|
| Text A (Query) | `Query: {your search query}` |
| Text B (Document) | `Title: {document title}\nDescription: {document description}` |
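The two fields can be assembled with a small helper mirroring the table above (a convenience sketch, not part of the model's API):

```python
def format_pair(query: str, title: str, description: str) -> tuple:
    """Build (text_a, text_b) in the reranker's expected input format."""
    text_a = f"Query: {query}"
    text_b = f"Title: {title}\nDescription: {description}"
    return text_a, text_b
```

Pass the two strings as the paired arguments to the tokenizer, or as a `(text_a, text_b)` tuple to `CrossEncoder.predict`.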

## Output Details

### Standard Output (CrossEncoder compatible)

- `outputs.logits`: Shape `[B, 1]` - a single relevance score per example
- `outputs.relevance`: Shape `[B]` - the same scores, squeezed to one dimension

### With Uncertainty (`output_distribution=True` or `predict_with_uncertainty()`)

- `relevance`: Expected relevance score in [0, 1]
- `variance`: Prediction variance (higher = less confident)
- `entropy`: Distribution entropy (higher = less confident)
- `probs`: Full probability distribution over bins
- `distribution_logits`: Raw logits before the softmax
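One practical use of the uncertainty fields is to drop low-confidence predictions before ranking. A sketch over the fields returned by `predict_with_uncertainty()` (the variance threshold here is an assumption to tune per application):

```python
def rank_with_confidence(items, max_variance=0.05):
    """Keep predictions whose variance is at or below the threshold,
    then sort the survivors by expected relevance, descending.

    Each item is a dict with 'relevance' and 'variance' keys, as
    returned by predict_with_uncertainty().
    """
    confident = [it for it in items if it["variance"] <= max_variance]
    return sorted(confident, key=lambda it: it["relevance"], reverse=True)
```

This trades recall for precision: a high-scoring but high-variance candidate is excluded rather than surfaced.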