ReallyFloppyPenguin committed on
Commit
8150f4d
·
verified ·
1 Parent(s): b957fc7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +386 -0
app.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import logging
import re
import tempfile
from typing import List, Optional, Tuple

import gradio as gr
import pandas as pd
import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
7
+
8
# Emit INFO-and-above records through the root logger's default handler and
# expose a module-scoped logger for everything defined in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
11
+
12
class Qwen3Reranker:
    """Score and rank documents against a query with a Qwen3 reranker model.

    Every (instruction, query, document) triple is rendered into a fixed
    chat-style prompt. The relevance score is the probability assigned to the
    "yes" token, normalised against the "no" token, at the last position of
    the model output.
    """

    def __init__(self, model_name="Qwen/Qwen3-Reranker-0.6B"):
        """Store configuration and load the tokenizer and model right away.

        Args:
            model_name: Checkpoint identifier handed to ``from_pretrained``.

        Raises:
            Exception: Whatever ``_load_model`` raises when loading fails.
        """
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.token_false_id = None
        self.token_true_id = None
        # Upper bound on tokens per prompt, prefix and suffix included.
        self.max_length = 8192
        self.prefix_tokens = None
        self.suffix_tokens = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self._load_model()

    def _load_model(self):
        """Load the tokenizer and model, then precompute prompt token ids.

        Raises:
            Exception: Any loading error is logged with its traceback and
                re-raised unchanged.
        """
        try:
            logger.info("Loading %s...", self.model_name)
            # Left padding keeps the final position of every row on a real
            # token, which is where compute_scores reads the logits.
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                padding_side='left',
            )

            # Half precision with automatic placement on GPU; library
            # defaults otherwise.
            if self.device == "cuda":
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    torch_dtype=torch.float16,
                    device_map="auto",
                ).eval()
            else:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name
                ).eval()

            # Vocabulary ids of the two answer tokens compared when scoring.
            self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
            self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")

            # Fixed prompt text wrapped around every formatted pair; tokenised
            # once here so process_inputs can splice the ids directly.
            prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
            suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
            self.prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
            self.suffix_tokens = self.tokenizer.encode(suffix, add_special_tokens=False)

            logger.info("Model loaded successfully!")

        except Exception as e:
            # logger.exception records the traceback; the bare raise keeps the
            # original raise site intact for the caller.
            logger.exception("Error loading model: %s", e)
            raise

    def format_instruction(self, instruction: Optional[str], query: str, doc: str) -> str:
        """Render one instruction/query/document triple as prompt text.

        Args:
            instruction: Task description; ``None`` or a blank string selects
                the default web-search instruction.
            query: The search query.
            doc: The candidate document.

        Returns:
            The three fields on separate lines, each with its tag.
        """
        if not instruction or not instruction.strip():
            instruction = 'Given a web search query, retrieve relevant passages that answer the query'
        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

    def process_inputs(self, pairs: List[str]) -> dict:
        """Tokenise formatted pairs into a padded batch on the model's device.

        Args:
            pairs: Prompt bodies produced by ``format_instruction``.

        Returns:
            The padded tokenizer output with every tensor moved to
            ``self.model.device``.
        """
        # Reserve room for the fixed prefix and suffix so the assembled prompt
        # never exceeds max_length.
        body_budget = self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
        inputs = self.tokenizer(
            pairs,
            padding=False,
            truncation='longest_first',
            return_attention_mask=False,
            max_length=body_budget,
        )

        # Wrap each truncated body with the precomputed prompt token ids.
        for i, ele in enumerate(inputs['input_ids']):
            inputs['input_ids'][i] = self.prefix_tokens + ele + self.suffix_tokens

        # Pad only after wrapping so all rows share one length.
        inputs = self.tokenizer.pad(
            inputs,
            padding=True,
            return_tensors="pt",
            max_length=self.max_length,
        )

        for key in inputs:
            inputs[key] = inputs[key].to(self.model.device)

        return inputs

    @torch.no_grad()
    def compute_scores(self, inputs: dict) -> List[float]:
        """Return one relevance probability per row of ``inputs``.

        Args:
            inputs: Model inputs as produced by ``process_inputs``.

        Returns:
            For each row, the softmax probability of the "yes" token computed
            over the ("no", "yes") logit pair at the last position.
        """
        last_logits = self.model(**inputs).logits[:, -1, :]
        true_vector = last_logits[:, self.token_true_id]
        false_vector = last_logits[:, self.token_false_id]
        # Column 0 is "no", column 1 is "yes".
        pair_logits = torch.stack([false_vector, true_vector], dim=1)
        pair_log_probs = torch.nn.functional.log_softmax(pair_logits, dim=1)
        return pair_log_probs[:, 1].exp().tolist()

    def rank_documents(self, query: str, documents: List[str], instruction: Optional[str] = None) -> List[Tuple[str, float]]:
        """Rank documents by relevance to a query, best first.

        Args:
            query: The search query.
            documents: Candidate documents.
            instruction: Optional task description; the default instruction
                is used when it is ``None`` or blank.

        Returns:
            ``(document, score)`` tuples sorted by descending score, or an
            empty list when there are no documents or the query is ``None``
            or blank.
        """
        # Guard against None as well as blank input so callers get an empty
        # result instead of an AttributeError.
        if not documents or not query or not query.strip():
            return []

        pairs = [
            self.format_instruction(instruction, query, doc)
            for doc in documents
        ]

        scores = self.compute_scores(self.process_inputs(pairs))

        # sorted() is stable, so documents with equal scores keep input order.
        return sorted(zip(documents, scores), key=lambda item: item[1], reverse=True)
125
+
126
# Initialize the reranker once at import time. A failure is tolerated so the
# UI can still start; rerank_documents checks `model_loaded` before it touches
# `reranker`.
try:
    reranker = Qwen3Reranker()
except Exception as e:
    # Top-level boundary: keep the full traceback in the logs rather than
    # only the exception message.
    logger.exception("Failed to initialize reranker: %s", e)
    model_loaded = False
    reranker = None
else:
    model_loaded = True
134
+
135
# Leading list numbering such as "12. " at the very start of a line.
_NUMBERING_PREFIX = re.compile(r"^\d+\.\s+")


def _parse_documents(documents_text: str) -> List[str]:
    """Split raw text into one document per non-empty line, minus numbering."""
    documents = []
    for raw_line in documents_text.strip().split('\n'):
        line = raw_line.strip()
        if not line:
            continue
        # Only an actual "<digits>. " prefix is removed. Cutting at the first
        # ". " anywhere in a digit-leading line would truncate text such as
        # "2023 was a good year. It rained."
        documents.append(_NUMBERING_PREFIX.sub("", line, count=1))
    return documents


def rerank_documents(query: str, documents_text: str, instruction: Optional[str] = None) -> tuple:
    """
    Rerank documents based on query relevance

    Args:
        query: The search query
        documents_text: Documents separated by newlines, optionally numbered
            ("1. text")
        instruction: Custom instruction (optional)

    Returns:
        Tuple of (display DataFrame, download DataFrame) on success. On any
        failure the first element is an error message string and the second
        is None. The display table truncates each document to 200 characters;
        the download table keeps the full text.
    """
    if not model_loaded:
        return "❌ Model not loaded. Please check the logs.", None

    # `not value` also covers None, which has no .strip().
    if not query or not query.strip():
        return "❌ Please enter a query.", None

    if not documents_text or not documents_text.strip():
        return "❌ Please enter at least one document.", None

    try:
        documents = _parse_documents(documents_text)
        if not documents:
            return "❌ No valid documents found.", None

        ranked_docs = reranker.rank_documents(query, documents, instruction)

        results_data = [
            {
                "Rank": rank,
                "Score": f"{score:.4f}",
                "Document": doc[:200] + "..." if len(doc) > 200 else doc,
                "Full Document": doc,
            }
            for rank, (doc, score) in enumerate(ranked_docs, 1)
        ]

        # Table shown in the UI: truncated document text.
        df_display = pd.DataFrame([
            {"Rank": item["Rank"], "Score": item["Score"], "Document": item["Document"]}
            for item in results_data
        ])

        # Table offered for download: full document text.
        df_download = pd.DataFrame([
            {"Rank": item["Rank"], "Score": item["Score"], "Document": item["Full Document"]}
            for item in results_data
        ])

        return df_display, df_download

    except Exception as e:
        # UI boundary: report the failure to the user and keep the traceback
        # in the logs.
        logger.exception("Error in reranking: %s", e)
        return f"❌ Error during reranking: {str(e)}", None
205
+
206
def create_gradio_interface():
    """Create the Gradio interface.

    Builds the Blocks layout (header, model info, inputs, results table,
    download button, examples, footer) and wires the rank and clear buttons.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    default_documents = "\n".join([
        "The capital of China is Beijing.",
        "China is a country in East Asia with a large population.",
        "Beijing is located in northern China and serves as the political center.",
        "Shanghai is the largest city in China by population.",
        "The Great Wall of China is a famous landmark.",
    ])

    with gr.Blocks(
        title="Qwen3-Reranker-0.6B",
        theme=gr.themes.Soft(),
        css="""
        .main-header {
            text-align: center;
            margin-bottom: 2rem;
        }
        .model-info {
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            padding: 1rem;
            border-radius: 10px;
            margin-bottom: 1rem;
        }
        .example-box {
            border: 1px solid #e0e0e0;
            padding: 1rem;
            border-radius: 8px;
            margin: 0.5rem 0;
        }
        """
    ) as demo:

        gr.HTML("""
        <div class="main-header">
            <h1>🔍 Qwen3-Reranker-0.6B</h1>
            <p>Advanced Text Reranking with Multilingual Support</p>
        </div>
        """)

        with gr.Row():
            with gr.Column():
                gr.HTML("""
                <div class="model-info">
                    <h3>🚀 Model Information</h3>
                    <ul>
                        <li><strong>Model:</strong> Qwen3-Reranker-0.6B</li>
                        <li><strong>Parameters:</strong> 0.6B</li>
                        <li><strong>Context Length:</strong> 32K tokens</li>
                        <li><strong>Languages:</strong> 100+ languages supported</li>
                        <li><strong>Use Case:</strong> Document ranking and relevance scoring</li>
                    </ul>
                </div>
                """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.HTML("<h3>📝 Input</h3>")

                query_input = gr.Textbox(
                    label="Search Query",
                    placeholder="Enter your search query here...",
                    lines=2,
                    value="What is the capital of China?"
                )

                instruction_input = gr.Textbox(
                    label="Custom Instruction (Optional)",
                    placeholder="Leave empty for default instruction...",
                    lines=2,
                    value=""
                )

                documents_input = gr.Textbox(
                    label="Documents to Rank",
                    placeholder="Enter documents, one per line or numbered...",
                    lines=8,
                    value=default_documents
                )

                with gr.Row():
                    rank_btn = gr.Button("🔍 Rank Documents", variant="primary", size="lg")
                    clear_btn = gr.Button("🗑️ Clear", variant="secondary")

            with gr.Column(scale=1):
                gr.HTML("<h3>📊 Results</h3>")

                results_display = gr.DataFrame(
                    label="Ranked Documents",
                    headers=["Rank", "Score", "Document"],
                    interactive=False,
                    height=400
                )

                # Holds the full-text results table between events.
                download_data = gr.State()

                download_btn = gr.DownloadButton(
                    "💾 Download Results (CSV)",
                    visible=False
                )

        # Examples section
        gr.HTML("<h3>💡 Examples</h3>")

        with gr.Row():
            with gr.Column():
                gr.HTML("""
                <div class="example-box">
                    <h4>Example 1: General Search</h4>
                    <p><strong>Query:</strong> "Python programming tutorials"</p>
                    <p><strong>Documents:</strong> Various programming resources</p>
                </div>
                """)

            with gr.Column():
                gr.HTML("""
                <div class="example-box">
                    <h4>Example 2: Scientific Research</h4>
                    <p><strong>Query:</strong> "Machine learning applications in healthcare"</p>
                    <p><strong>Documents:</strong> Research papers and articles</p>
                </div>
                """)

        def write_results_csv(results_df):
            """Write the results table to a temporary CSV file; return its path."""
            # The download button serves a file path, not CSV text.
            # delete=False keeps the file available after the handle closes,
            # so these files are not cleaned up by this function.
            with tempfile.NamedTemporaryFile(
                mode="w",
                suffix=".csv",
                prefix="rerank_results_",
                delete=False,
                newline="",
                encoding="utf-8",
            ) as handle:
                results_df.to_csv(handle, index=False)
                return handle.name

        def update_interface(query, documents, instruction):
            """Run reranking and map the outcome onto the three outputs."""
            if not model_loaded:
                # The results table cannot display a message string, so
                # errors are surfaced as a toast instead.
                gr.Warning("❌ Model not loaded")
                return None, None, gr.update(value=None, visible=False)

            results, download_df = rerank_documents(query, documents, instruction)

            if download_df is None:
                # On failure `results` is the error message from
                # rerank_documents.
                gr.Warning(results)
                return None, None, gr.update(value=None, visible=False)

            # Point the download button at a ready-made file so one click
            # downloads the current results.
            csv_path = write_results_csv(download_df)
            return results, download_df, gr.update(value=csv_path, visible=True)

        def clear_inputs():
            """Reset all inputs, the results table and the download button."""
            return "", "", "", None, None, gr.update(value=None, visible=False)

        # Event handlers
        rank_btn.click(
            fn=update_interface,
            inputs=[query_input, documents_input, instruction_input],
            outputs=[results_display, download_data, download_btn]
        )

        clear_btn.click(
            fn=clear_inputs,
            outputs=[query_input, documents_input, instruction_input, results_display, download_data, download_btn]
        )

        # No click handler is registered for download_btn: it downloads the
        # file path stored in its value, which update_interface sets.

        # Footer
        gr.HTML("""
        <div style="text-align: center; margin-top: 2rem; padding: 1rem; border-top: 1px solid #e0e0e0;">
            <p>🤗 <a href="https://huggingface.co/Qwen/Qwen3-Reranker-0.6B" target="_blank">Model on Hugging Face</a> |
            📖 <a href="https://arxiv.org/abs/2506.05176" target="_blank">Research Paper</a></p>
            <p><em>Powered by Qwen3-Reranker-0.6B - Advanced multilingual text reranking</em></p>
        </div>
        """)

    return demo
374
+
375
if __name__ == "__main__":
    # Build the app first, then serve it on all interfaces at port 7860 with
    # debug output and in-browser error display enabled; no public share link.
    demo = create_gradio_interface()

    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "debug": True,
        "show_error": True,
    }
    demo.launch(**launch_options)