|
|
import gradio as gr |
|
|
import logging |
|
|
from typing import List, Tuple |
|
|
import pandas as pd |
|
|
import os |
|
|
|
|
|
|
|
|
# Module-wide logging: INFO level for the root config, plus a named
# logger for this module so messages are attributable.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
# The heavy ML dependencies are optional: when torch/transformers are
# missing the app still starts in a degraded mode (the UI renders and
# reports that the model is unavailable instead of crashing on import).
try:
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM
    DEPENDENCIES_AVAILABLE = True
except ImportError as e:
    logger.warning(f"Dependencies not available: {e}")
    DEPENDENCIES_AVAILABLE = False
    # Placeholders keep later module-level references well-defined; any
    # actual use is guarded by DEPENDENCIES_AVAILABLE checks.
    torch = None
    AutoTokenizer = None
    AutoModelForCausalLM = None
|
|
|
|
|
class Qwen3Reranker:
    """Document reranker built on the Qwen/Qwen3-Reranker-0.6B causal LM.

    The model is prompted (via a fixed chat template) to answer "yes" or
    "no" to whether a document satisfies a query; the relevance score is
    P(yes) normalized against P(yes) + P(no) at the final token position.
    """

    def __init__(self, model_name: str = "Qwen/Qwen3-Reranker-0.6B"):
        """Create the reranker and eagerly load the tokenizer and model.

        Args:
            model_name: Hugging Face model id to load.

        Raises:
            ImportError: if torch/transformers are not importable.
        """
        if not DEPENDENCIES_AVAILABLE:
            raise ImportError("Required dependencies (torch, transformers) are not available")

        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.token_false_id = None  # vocabulary id of the token "no"
        self.token_true_id = None   # vocabulary id of the token "yes"
        self.max_length = 8192      # hard cap on prompt length, in tokens
        self.prefix_tokens = None   # encoded chat-template prefix (system + user turn open)
        self.suffix_tokens = None   # encoded chat-template suffix (assistant turn open)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self._load_model()

    def _load_model(self):
        """Load the tokenizer and model and precompute the prompt-template tokens.

        Raises:
            Exception: re-raises whatever the Hugging Face loaders raise.
        """
        try:
            logger.info(f"Loading {self.model_name}...")
            # Left padding so the final position of each row is the last
            # prompt token, where the yes/no logits are read.
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                padding_side='left'
            )

            if torch.cuda.is_available():
                # fp16 + device_map="auto" to fit the model on the GPU(s).
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    torch_dtype=torch.float16,
                    device_map="auto"
                ).eval()
            else:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name
                ).eval()

            self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
            self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")

            # Chat-template wrapper expected by the Qwen3 reranker: a system
            # instruction, the user turn (filled per document), and an empty
            # <think> block opening the assistant turn.
            prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
            suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
            self.prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
            self.suffix_tokens = self.tokenizer.encode(suffix, add_special_tokens=False)

            logger.info("Model loaded successfully!")

        except Exception as e:
            logger.error(f"Error loading model: {e}")
            # FIX: bare raise preserves the original traceback context.
            raise

    def format_instruction(self, instruction: str, query: str, doc: str) -> str:
        """Format the instruction/query/document triple for the reranker.

        A default web-search instruction is substituted when none is given.
        """
        if instruction is None or instruction.strip() == "":
            instruction = 'Given a web search query, retrieve relevant passages that answer the query'
        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

    def process_inputs(self, pairs: List[str]) -> dict:
        """Tokenize a batch of formatted pairs into padded model inputs.

        NOTE(review): currently unused by rank_documents (which tokenizes
        per-document); kept for batch-scoring callers.
        """
        inputs = self.tokenizer(
            pairs,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=self.max_length
        )
        return inputs

    def rank_documents(self, instruction: str, query: str, documents: List[str]) -> List[Tuple[str, float, str]]:
        """Score every document against the query and return them best-first.

        Args:
            instruction: Task description; a default is used when empty.
            query: The search query.
            documents: Raw document texts.

        Returns:
            List of (display text truncated to 200 chars, relevance score
            in [0, 1], label) tuples sorted by descending score. Documents
            that fail to score are kept with score 0.0 and an error label.
        """
        if not DEPENDENCIES_AVAILABLE:
            # Degraded mode: neutral scores so the UI still renders.
            return [(doc[:100] + "...", 0.5, "Dependencies not available") for doc in documents]

        results = []

        for i, doc in enumerate(documents):
            try:
                formatted_instruction = self.format_instruction(instruction, query, doc)

                # FIX: wrap the pair in the precomputed chat-template
                # prefix/suffix (previously built in _load_model but never
                # used), truncating only the middle section so the template
                # always survives within max_length.
                body_ids = self.tokenizer.encode(formatted_instruction, add_special_tokens=False)
                budget = self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
                input_ids = self.prefix_tokens + body_ids[:budget] + self.suffix_tokens

                inputs = {
                    "input_ids": torch.tensor([input_ids], dtype=torch.long, device=self.device),
                    "attention_mask": torch.ones(1, len(input_ids), dtype=torch.long, device=self.device),
                }

                with torch.no_grad():
                    outputs = self.model(**inputs)
                    # Next-token logits at the final prompt position.
                    logits = outputs.logits[0, -1, :]

                # FIX: compute the softmax once (was computed twice per doc).
                probs = torch.softmax(logits, dim=-1)
                yes_prob = probs[self.token_true_id].item()
                no_prob = probs[self.token_false_id].item()

                # Normalize so the score compares only "yes" vs "no" mass;
                # guard against both probabilities underflowing to 0 in fp16.
                denom = yes_prob + no_prob
                relevance_score = yes_prob / denom if denom > 0.0 else 0.0

                display_doc = doc[:200] + "..." if len(doc) > 200 else doc
                results.append((display_doc, relevance_score, f"Document {i+1}"))

            except Exception as e:
                logger.error(f"Error processing document {i+1}: {e}")
                display_doc = doc[:200] + "..." if len(doc) > 200 else doc
                results.append((display_doc, 0.0, f"Error: {str(e)[:50]}..."))

        # Best match first.
        results.sort(key=lambda x: x[1], reverse=True)
        return results
|
|
|
|
|
|
|
|
# Eagerly build one module-level reranker so the loaded model is shared
# across Gradio requests; fall back to None on failure so the UI can
# start and report "Model not loaded" instead of crashing at import.
try:
    reranker = Qwen3Reranker()
except Exception as e:
    logger.error(f"Failed to initialize reranker: {e}")
    reranker = None
|
|
|
|
|
def rerank_documents(instruction, query, documents_text):
    """Gradio callback: parse the documents box, rank them, return a table.

    Args:
        instruction: Optional task instruction passed to the reranker.
        query: Search query text.
        documents_text: Documents separated by blank lines, or a numbered list.

    Returns:
        pandas.DataFrame with columns ["Rank", "Document", "Relevance Score"]
        (both for results and for error messages).
    """

    def _error_frame(message):
        # FIX: error frames previously used column order
        # ["Document", "Relevance Score", "Rank"], mismatching the success
        # path and the gr.DataFrame headers, so the message rendered in
        # the wrong column.
        return pd.DataFrame([["-", f"Error: {message}", ""]],
                            columns=["Rank", "Document", "Relevance Score"])

    if not reranker:
        return _error_frame("Model not loaded")

    if not query.strip():
        return _error_frame("Please provide a query")

    if not documents_text.strip():
        return _error_frame("Please provide documents")

    # Prefer blank-line separation; otherwise fall back to a numbered-list
    # heuristic where "1. ...", "2. ..." each start a new document.
    documents = []
    if '\n\n' in documents_text:
        documents = [doc.strip() for doc in documents_text.split('\n\n') if doc.strip()]
    else:
        lines = documents_text.strip().split('\n')
        current_doc = ""
        for line in lines:
            # A line whose first character is a digit with a '.' in the
            # first five characters starts a new numbered item.
            if line.strip() and (line.strip()[0].isdigit() and '.' in line[:5]):
                if current_doc:
                    documents.append(current_doc.strip())
                current_doc = line
            else:
                current_doc += "\n" + line
        if current_doc:
            documents.append(current_doc.strip())

    # Last resort: treat the whole box as a single document.
    if not documents:
        documents = [documents_text]

    results = reranker.rank_documents(instruction, query, documents)

    # results are already sorted best-first; present 1-based ranks and
    # fixed-precision scores.
    df_data = []
    for i, (doc, score, label) in enumerate(results):
        df_data.append([f"#{i+1}", doc, f"{score:.4f}"])

    return pd.DataFrame(df_data, columns=["Rank", "Document", "Relevance Score"])
|
|
|
|
|
def create_gradio_interface():
    """Build and return the Gradio Blocks UI for the reranker.

    Layout: a header, then a left column with instruction/query/documents
    inputs, the rank button and a worked example, and a right column with
    the results table. Component creation order defines on-screen layout.
    """

    with gr.Blocks(title="Qwen3 Document Reranker", theme=gr.themes.Soft()) as demo:
        # Header / usage notes shown above the two columns.
        # NOTE(review): the "๐" glyphs below look like mis-encoded emoji
        # from the original source — confirm the intended characters
        # before changing the user-facing text.
        gr.Markdown("""
        # ๐ Qwen3 Document Reranker

        This tool uses the **Qwen3-Reranker-0.6B** model to rank documents by their relevance to your query.

        ## How to use:
        1. **Instruction** (optional): Provide context for the ranking task
        2. **Query**: Enter your search query
        3. **Documents**: Enter multiple documents separated by double newlines (\\n\\n) or as a numbered list
        """)

        with gr.Row():
            with gr.Column(scale=1):
                # Left column: all user inputs.
                instruction_input = gr.Textbox(
                    label="Instruction (Optional)",
                    placeholder="Given a web search query, retrieve relevant passages that answer the query",
                    value="Given a web search query, retrieve relevant passages that answer the query",
                    lines=2
                )

                query_input = gr.Textbox(
                    label="Query",
                    placeholder="Enter your search query here...",
                    lines=2
                )

                documents_input = gr.Textbox(
                    label="Documents to Rank",
                    placeholder="Enter documents separated by double newlines...\n\nDocument 1 content here\n\nDocument 2 content here\n\nDocument 3 content here",
                    lines=10
                )

                rank_button = gr.Button("๐ Rank Documents", variant="primary")

                gr.Markdown("### Example:")
                # One-click example that fills all three inputs at once.
                gr.Examples(
                    examples=[
                        [
                            "Given a web search query, retrieve relevant passages that answer the query",
                            "What is machine learning?",
                            "Machine learning is a subset of artificial intelligence that focuses on algorithms that can learn from data.\n\nPython is a programming language commonly used for web development.\n\nDeep learning uses neural networks with multiple layers to model complex patterns in data."
                        ]
                    ],
                    inputs=[instruction_input, query_input, documents_input]
                )

            with gr.Column(scale=1):
                # Right column: read-only results table.
                results_display = gr.DataFrame(
                    label="Ranking Results",
                    headers=["Rank", "Document", "Relevance Score"],
                    interactive=False
                )

        # Wire the button to the module-level ranking callback.
        rank_button.click(
            fn=rerank_documents,
            inputs=[instruction_input, query_input, documents_input],
            outputs=[results_display]
        )

        # Footer with model details.
        gr.Markdown("""
        ### About the Model
        - **Model**: Qwen/Qwen3-Reranker-0.6B
        - **Task**: Document reranking based on query relevance
        - **Output**: Relevance scores between 0 and 1 (higher = more relevant)
        """)

    return demo
|
|
|
|
|
if __name__ == "__main__":
    # Build the UI and serve it with Gradio's default launch settings.
    create_gradio_interface().launch()