xavier-fuentes committed on
Commit
0a40b90
·
verified ·
1 Parent(s): 2f47f5d

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. README.md +4 -16
  2. app.py +73 -52
  3. requirements.txt +4 -3
README.md CHANGED
@@ -1,28 +1,16 @@
1
  ---
2
- title: Text Reranker - Cross-Encoder Reranking
3
  emoji: 🔎
4
  colorFrom: blue
5
  colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 5.17.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
  ---
12
 
13
- # Text Reranker - Cross-Encoder Reranking
14
 
15
- A lightweight Hugging Face Space that reranks passages for a given query using:
16
 
17
- - `cross-encoder/ms-marco-MiniLM-L-12-v2`
18
- - `sentence-transformers` `CrossEncoder`
19
- - Gradio UI with ZeroGPU support via `@spaces.GPU`
20
-
21
- ## Usage
22
-
23
- 1. Enter a query.
24
- 2. Paste passages, one per line.
25
- 3. Choose Top-K.
26
- 4. Click **Rerank**.
27
-
28
- The app returns a markdown table sorted by relevance score and displays inference time.
 
1
  ---
2
+ title: Qwen3-Reranker-8B Text Reranker
3
  emoji: 🔎
4
  colorFrom: blue
5
  colorTo: indigo
6
  sdk: gradio
 
7
  app_file: app.py
8
  pinned: false
9
  license: mit
10
  ---
11
 
12
+ # Qwen3-Reranker-8B Text Reranker
13
 
14
+ Fast text-only reranking Space powered by `Qwen/Qwen3-Reranker-8B`.
15
 
16
+ Enter a query and passages, one per line. The app returns a sorted relevance table and inference time.
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -3,10 +3,36 @@ from typing import List
3
 
4
  import gradio as gr
5
  import spaces
6
- from sentence_transformers import CrossEncoder
7
-
8
- MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-12-v2"
9
- model = CrossEncoder(MODEL_NAME)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
 
12
  def _parse_passages(text: str) -> List[str]:
@@ -15,6 +41,36 @@ def _parse_passages(text: str) -> List[str]:
15
  return [line.strip() for line in text.splitlines() if line.strip()]
16
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  @spaces.GPU
19
  def rerank(query: str, passages_text: str, top_k: int):
20
  start = time.perf_counter()
@@ -23,79 +79,44 @@ def rerank(query: str, passages_text: str, top_k: int):
23
  passages = _parse_passages(passages_text or "")
24
 
25
  if not query and not passages:
26
- return (
27
- "Please provide a query and at least one passage.",
28
- "Inference time: 0.000s",
29
- )
30
  if not query:
31
  return "Please provide a query.", "Inference time: 0.000s"
32
  if not passages:
33
  return "Please provide at least one passage.", "Inference time: 0.000s"
34
 
35
- top_k = max(1, min(int(top_k), 20, len(passages)))
 
 
 
36
 
37
- pairs = [[query, p] for p in passages]
38
- scores = model.predict(pairs)
39
-
40
- ranked = sorted(
41
- zip(passages, scores),
42
- key=lambda x: float(x[1]),
43
- reverse=True,
44
- )
45
 
46
- ranked = ranked[:top_k]
47
-
48
- lines = [
49
- "| Rank | Score | Passage |",
50
- "|---:|---:|---|",
51
- ]
52
  for i, (passage, score) in enumerate(ranked, start=1):
53
  safe_passage = passage.replace("|", "\\|").replace("\n", " ")
54
- lines.append(f"| {i} | {float(score):.4f} | {safe_passage} |")
55
 
56
  elapsed = time.perf_counter() - start
57
  return "\n".join(lines), f"Inference time: {elapsed:.3f}s"
58
 
59
 
60
- with gr.Blocks(title="Text Reranker - Cross-Encoder Reranking") as demo:
61
- gr.Markdown("# Text Reranker - Cross-Encoder Reranking")
62
-
63
- query = gr.Textbox(
64
- label="Query",
65
- placeholder="Enter your search query...",
66
- lines=1,
67
- )
68
 
 
69
  passages = gr.Textbox(
70
  label="Passages (one per line)",
71
  placeholder="Enter one passage per line...",
72
  lines=10,
73
  )
74
-
75
- top_k = gr.Slider(
76
- minimum=1,
77
- maximum=20,
78
- value=5,
79
- step=1,
80
- label="Top-K",
81
- )
82
-
83
  run_btn = gr.Button("Rerank")
84
 
85
  output_md = gr.Markdown(label="Ranked Results")
86
  inference_time = gr.Textbox(label="Inference Time", interactive=False)
87
 
88
- run_btn.click(
89
- fn=rerank,
90
- inputs=[query, passages, top_k],
91
- outputs=[output_md, inference_time],
92
- )
93
-
94
- gr.Markdown(
95
- "Built by [Xavier Fuentes](https://huggingface.co/xavier-fuentes) @ "
96
- "[AI Enablement Academy](https://enablement.academy) | "
97
- "[Buy me a coffee ☕](https://ko-fi.com/xavierfuentes)"
98
- )
99
 
100
 
101
  if __name__ == "__main__":
 
3
 
4
  import gradio as gr
5
  import spaces
6
+ import torch
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+
9
+ MODEL_NAME = "Qwen/Qwen3-Reranker-8B"
10
+ INSTRUCTION = "Given a web search query, retrieve relevant passages that answer the query"
11
+
12
+ # Load once at startup
13
+ model = AutoModelForCausalLM.from_pretrained(
14
+ MODEL_NAME,
15
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
16
+ trust_remote_code=True,
17
+ )
18
+ if torch.cuda.is_available():
19
+ model = model.cuda()
20
+ model.eval()
21
+
22
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left", trust_remote_code=True)
23
+ token_false_id = tokenizer.convert_tokens_to_ids("no")
24
+ token_true_id = tokenizer.convert_tokens_to_ids("yes")
25
+
26
+ max_length = 8192
27
+ prefix = (
28
+ "<|im_start|>system\n"
29
+ "Judge whether the Document meets the requirements based on the Query and the Instruct provided. "
30
+ "Note that the answer can only be \"yes\" or \"no\"."
31
+ "<|im_end|>\n<|im_start|>user\n"
32
+ )
33
+ suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
34
+ prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
35
+ suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)
36
 
37
 
38
  def _parse_passages(text: str) -> List[str]:
 
41
  return [line.strip() for line in text.splitlines() if line.strip()]
42
 
43
 
44
+ def _format_pair(query: str, doc: str) -> str:
45
+ return f"<Instruct>: {INSTRUCTION}\n<Query>: {query}\n<Document>: {doc}"
46
+
47
+
48
+ def _process_inputs(pairs: List[str]):
49
+ inputs = tokenizer(
50
+ pairs,
51
+ padding=False,
52
+ truncation="longest_first",
53
+ return_attention_mask=False,
54
+ max_length=max_length - len(prefix_tokens) - len(suffix_tokens),
55
+ )
56
+ for i, ids in enumerate(inputs["input_ids"]):
57
+ inputs["input_ids"][i] = prefix_tokens + ids + suffix_tokens
58
+ inputs = tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=max_length)
59
+ for key in inputs:
60
+ inputs[key] = inputs[key].to(model.device)
61
+ return inputs
62
+
63
+
64
+ @torch.no_grad()
65
+ def _compute_scores(inputs):
66
+ logits = model(**inputs).logits[:, -1, :]
67
+ true_vector = logits[:, token_true_id]
68
+ false_vector = logits[:, token_false_id]
69
+ score_2way = torch.stack([false_vector, true_vector], dim=1)
70
+ score_2way = torch.nn.functional.log_softmax(score_2way, dim=1)
71
+ return score_2way[:, 1].exp().tolist()
72
+
73
+
74
  @spaces.GPU
75
  def rerank(query: str, passages_text: str, top_k: int):
76
  start = time.perf_counter()
 
79
  passages = _parse_passages(passages_text or "")
80
 
81
  if not query and not passages:
82
+ return "Please provide a query and at least one passage.", "Inference time: 0.000s"
 
 
 
83
  if not query:
84
  return "Please provide a query.", "Inference time: 0.000s"
85
  if not passages:
86
  return "Please provide at least one passage.", "Inference time: 0.000s"
87
 
88
+ top_k = max(1, min(int(top_k), 50, len(passages)))
89
+ pairs = [_format_pair(query, p) for p in passages]
90
+ inputs = _process_inputs(pairs)
91
+ scores = _compute_scores(inputs)
92
 
93
+ ranked = sorted(zip(passages, scores), key=lambda x: float(x[1]), reverse=True)[:top_k]
 
 
 
 
 
 
 
94
 
95
+ lines = ["| Rank | Score | Passage |", "|---:|---:|---|"]
 
 
 
 
 
96
  for i, (passage, score) in enumerate(ranked, start=1):
97
  safe_passage = passage.replace("|", "\\|").replace("\n", " ")
98
+ lines.append(f"| {i} | {float(score):.6f} | {safe_passage} |")
99
 
100
  elapsed = time.perf_counter() - start
101
  return "\n".join(lines), f"Inference time: {elapsed:.3f}s"
102
 
103
 
104
+ with gr.Blocks(title="Qwen3-Reranker-8B Text Reranker") as demo:
105
+ gr.Markdown("# Qwen3-Reranker-8B Text Reranker")
 
 
 
 
 
 
106
 
107
+ query = gr.Textbox(label="Query", placeholder="Enter your search query...", lines=1)
108
  passages = gr.Textbox(
109
  label="Passages (one per line)",
110
  placeholder="Enter one passage per line...",
111
  lines=10,
112
  )
113
+ top_k = gr.Slider(minimum=1, maximum=50, value=5, step=1, label="Top-K")
 
 
 
 
 
 
 
 
114
  run_btn = gr.Button("Rerank")
115
 
116
  output_md = gr.Markdown(label="Ranked Results")
117
  inference_time = gr.Textbox(label="Inference Time", interactive=False)
118
 
119
+ run_btn.click(fn=rerank, inputs=[query, passages, top_k], outputs=[output_md, inference_time])
 
 
 
 
 
 
 
 
 
 
120
 
121
 
122
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
- gradio
2
- sentence-transformers
3
- torch
 
4
  accelerate
 
1
+ transformers>=4.57.0
2
+ torch>=2.0
3
+ gradio>=4.0
4
+ spaces
5
  accelerate