import logging
import os
from typing import List, Tuple

import gradio as gr
import pandas as pd

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Try to import torch and transformers with a fallback so the UI can still
# start in a degraded mode when the heavy ML dependencies are missing.
try:
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    DEPENDENCIES_AVAILABLE = True
except ImportError as e:
    logger.warning(f"Dependencies not available: {e}")
    DEPENDENCIES_AVAILABLE = False
    torch = None
    AutoTokenizer = None
    AutoModelForCausalLM = None


class Qwen3Reranker:
    """Wrapper around the Qwen3-Reranker causal LM used as a yes/no scorer.

    The model is prompted with a chat template asking it to judge whether a
    document satisfies a query; the relevance score is P("yes") renormalized
    over the {"yes", "no"} token pair at the final position.
    """

    def __init__(self, model_name: str = "Qwen/Qwen3-Reranker-0.6B"):
        """Load tokenizer and model; raises ImportError if torch/transformers are absent."""
        if not DEPENDENCIES_AVAILABLE:
            raise ImportError("Required dependencies (torch, transformers) are not available")
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.token_false_id = None   # vocabulary id of the "no" token
        self.token_true_id = None    # vocabulary id of the "yes" token
        self.max_length = 8192       # truncation limit for prompt tokenization
        self.prefix_tokens = None
        self.suffix_tokens = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self._load_model()

    def _load_model(self):
        """Load the tokenizer and model and pre-encode the chat prefix/suffix."""
        try:
            logger.info(f"Loading {self.model_name}...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                padding_side='left'
            )

            # Load model with settings appropriate for the available hardware.
            if torch.cuda.is_available():
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    torch_dtype=torch.float16,
                    device_map="auto"
                ).eval()
            else:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name
                ).eval()

            # Token ids whose final-position logits encode the binary decision.
            self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
            self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")

            # Chat-template prefix/suffix from the Qwen3-Reranker model card.
            # FIX: the empty <think>...</think> block in the suffix had been
            # stripped to bare newlines (HTML-tag mangling); restored so the
            # model skips reasoning output as intended.
            prefix = (
                "<|im_start|>system\nJudge whether the Document meets the "
                "requirements based on the Query and the Instruct provided. "
                "Note that the answer can only be \"yes\" or \"no\"."
                "<|im_end|>\n<|im_start|>user\n"
            )
            suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
            self.prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
            self.suffix_tokens = self.tokenizer.encode(suffix, add_special_tokens=False)

            logger.info("Model loaded successfully!")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

    def format_instruction(self, instruction: str, query: str, doc: str) -> str:
        """Format one (instruction, query, document) triple for the reranker.

        FIX: the <Instruct>/<Query>/<Document> field labels had been stripped
        from the template, producing bare ": ..." lines the model was never
        trained on; restored per the Qwen3-Reranker usage example.
        """
        if instruction is None or instruction.strip() == "":
            instruction = 'Given a web search query, retrieve relevant passages that answer the query'
        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

    def process_inputs(self, pairs: List[str]) -> dict:
        """Tokenize a batch of formatted pairs into padded, truncated tensors."""
        inputs = self.tokenizer(
            pairs,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=self.max_length
        )
        return inputs

    def rank_documents(self, instruction: str, query: str, documents: List[str]) -> List[Tuple[str, float, str]]:
        """Score each document against the query and return them sorted.

        Returns a list of (display_doc, relevance_score, label) tuples sorted
        by descending relevance. A document that fails to score gets 0.0 and
        an error label rather than aborting the whole batch.
        """
        if not DEPENDENCIES_AVAILABLE:
            return [(doc[:100] + "...", 0.5, "Dependencies not available") for doc in documents]

        results = []
        for i, doc in enumerate(documents):
            try:
                formatted_instruction = self.format_instruction(instruction, query, doc)

                inputs = self.tokenizer(
                    formatted_instruction,
                    return_tensors='pt',
                    max_length=self.max_length,
                    truncation=True
                )
                if torch.cuda.is_available():
                    inputs = {k: v.cuda() for k, v in inputs.items()}

                with torch.no_grad():
                    outputs = self.model(**inputs)
                    logits = outputs.logits[0, -1, :]

                # FIX: softmax was computed twice over the full vocabulary;
                # compute it once and index both tokens.
                probs = torch.softmax(logits, dim=-1)
                yes_prob = probs[self.token_true_id].item()
                no_prob = probs[self.token_false_id].item()

                # Relevance = P("yes") renormalized over the {yes, no} pair.
                relevance_score = yes_prob / (yes_prob + no_prob)

                # Truncate long documents for display.
                display_doc = doc[:200] + "..." if len(doc) > 200 else doc
                results.append((display_doc, relevance_score, f"Document {i+1}"))
            except Exception as e:
                logger.error(f"Error processing document {i+1}: {e}")
                display_doc = doc[:200] + "..." if len(doc) > 200 else doc
                results.append((display_doc, 0.0, f"Error: {str(e)[:50]}..."))

        # Sort by relevance score (highest first).
        results.sort(key=lambda x: x[1], reverse=True)
        return results


# Initialize the reranker at import time; fall back to None so the UI can
# report the failure instead of crashing on startup.
try:
    reranker = Qwen3Reranker()
except Exception as e:
    logger.error(f"Failed to initialize reranker: {e}")
    reranker = None


def _error_frame(message: str) -> pd.DataFrame:
    """Build a one-row result frame for a user-facing error.

    FIX: error frames previously used a different column order
    (Document/Score/Rank) than the success path (Rank/Document/Score),
    with row values misaligned against the headers.
    """
    return pd.DataFrame(
        [["—", message, "0.0000"]],
        columns=["Rank", "Document", "Relevance Score"]
    )


def _split_documents(documents_text: str) -> List[str]:
    """Split raw textarea input into individual documents.

    Prefers blank-line separation; otherwise falls back to numbered-list
    format ("1. ...", "2. ..."); finally treats the text as one document.
    """
    if '\n\n' in documents_text:
        return [doc.strip() for doc in documents_text.split('\n\n') if doc.strip()]

    documents = []
    current_doc = ""
    for line in documents_text.strip().split('\n'):
        # A new document starts at a line beginning like "1." / "12.".
        if line.strip() and line.strip()[0].isdigit() and '.' in line[:5]:
            if current_doc:
                documents.append(current_doc.strip())
            current_doc = line
        else:
            current_doc += "\n" + line
    if current_doc:
        documents.append(current_doc.strip())

    # Treat as a single document if no structure was detected.
    return documents or [documents_text]


def rerank_documents(instruction, query, documents_text):
    """Gradio callback: validate inputs, split, rank, and tabulate results."""
    if not reranker:
        return _error_frame("Model not loaded")
    if not query.strip():
        return _error_frame("Please provide a query")
    if not documents_text.strip():
        return _error_frame("Please provide documents")

    documents = _split_documents(documents_text)

    # Rank and format for display.
    results = reranker.rank_documents(instruction, query, documents)
    df_data = [
        [f"#{i+1}", doc, f"{score:.4f}"]
        for i, (doc, score, _label) in enumerate(results)
    ]
    return pd.DataFrame(df_data, columns=["Rank", "Document", "Relevance Score"])


def create_gradio_interface():
    """Create the Gradio Blocks interface."""
    with gr.Blocks(title="Qwen3 Document Reranker", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🔍 Qwen3 Document Reranker

        This tool uses the **Qwen3-Reranker-0.6B** model to rank documents by their relevance to your query.

        ## How to use:
        1. **Instruction** (optional): Provide context for the ranking task
        2. **Query**: Enter your search query
        3. **Documents**: Enter multiple documents separated by double newlines (\\n\\n) or as a numbered list
        """)

        with gr.Row():
            with gr.Column(scale=1):
                instruction_input = gr.Textbox(
                    label="Instruction (Optional)",
                    placeholder="Given a web search query, retrieve relevant passages that answer the query",
                    value="Given a web search query, retrieve relevant passages that answer the query",
                    lines=2
                )
                query_input = gr.Textbox(
                    label="Query",
                    placeholder="Enter your search query here...",
                    lines=2
                )
                documents_input = gr.Textbox(
                    label="Documents to Rank",
                    placeholder="Enter documents separated by double newlines...\n\nDocument 1 content here\n\nDocument 2 content here\n\nDocument 3 content here",
                    lines=10
                )
                rank_button = gr.Button("🔍 Rank Documents", variant="primary")

                gr.Markdown("### Example:")
                gr.Examples(
                    examples=[
                        [
                            "Given a web search query, retrieve relevant passages that answer the query",
                            "What is machine learning?",
                            "Machine learning is a subset of artificial intelligence that focuses on algorithms that can learn from data.\n\nPython is a programming language commonly used for web development.\n\nDeep learning uses neural networks with multiple layers to model complex patterns in data."
                        ]
                    ],
                    inputs=[instruction_input, query_input, documents_input]
                )

            with gr.Column(scale=1):
                # NOTE: no height parameter — not supported by this gr.DataFrame version.
                results_display = gr.DataFrame(
                    label="Ranking Results",
                    headers=["Rank", "Document", "Relevance Score"],
                    interactive=False
                )

        rank_button.click(
            fn=rerank_documents,
            inputs=[instruction_input, query_input, documents_input],
            outputs=[results_display]
        )

        gr.Markdown("""
        ### About the Model
        - **Model**: Qwen/Qwen3-Reranker-0.6B
        - **Task**: Document reranking based on query relevance
        - **Output**: Relevance scores between 0 and 1 (higher = more relevant)
        """)

    return demo


if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch()