ReallyFloppyPenguin's picture
Update app.py
b9f7748 verified
import logging
import os
import re
from typing import List, Tuple

import gradio as gr
import pandas as pd
# Configure root logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# torch/transformers are optional at import time: when they are missing the
# names below are set to None and DEPENDENCIES_AVAILABLE is False, so the
# reranker constructor refuses to build and the UI can still start.
try:
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM
except ImportError as e:
    logger.warning(f"Dependencies not available: {e}")
    torch = None
    AutoTokenizer = None
    AutoModelForCausalLM = None
    DEPENDENCIES_AVAILABLE = False
else:
    DEPENDENCIES_AVAILABLE = True
class Qwen3Reranker:
    """Relevance scorer built on a Qwen3 reranker checkpoint.

    Each (instruction, query, document) triple is wrapped in a chat prompt that
    asks the model to answer "yes" or "no". The relevance score is the
    probability of "yes" relative to "no" at the final prompt position.

    Raises:
        ImportError: on construction when torch/transformers are unavailable.
    """

    def __init__(self, model_name="Qwen/Qwen3-Reranker-0.6B"):
        """Create the reranker and load *model_name* immediately."""
        if not DEPENDENCIES_AVAILABLE:
            raise ImportError("Required dependencies (torch, transformers) are not available")
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.token_false_id = None  # vocabulary id of the "no" token
        self.token_true_id = None  # vocabulary id of the "yes" token
        self.max_length = 8192  # token budget for the whole prompt, template included
        self.prefix_tokens = None  # pre-tokenized system + user-turn header
        self.suffix_tokens = None  # pre-tokenized assistant-turn opener
        # Informational only: inference moves inputs to ``self.model.device``.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self._load_model()

    def _load_model(self):
        """Load tokenizer and model, resolve yes/no token ids, pre-tokenize the template.

        Raises:
            Exception: any loading error is logged and re-raised unchanged.
        """
        try:
            logger.info(f"Loading {self.model_name}...")
            # Left padding keeps the last position aligned with the end of each prompt.
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                padding_side='left'
            )
            # fp16 + automatic device placement on GPU; default dtype on CPU.
            if torch.cuda.is_available():
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    torch_dtype=torch.float16,
                    device_map="auto"
                ).eval()
            else:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name
                ).eval()
            self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
            self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
            prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
            suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
            self.prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
            self.suffix_tokens = self.tokenizer.encode(suffix, add_special_tokens=False)
            logger.info("Model loaded successfully!")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

    def format_instruction(self, instruction: str, query: str, doc: str) -> str:
        """Render the user-turn body; ``None``/blank *instruction* uses the default one."""
        if instruction is None or instruction.strip() == "":
            instruction = 'Given a web search query, retrieve relevant passages that answer the query'
        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

    def process_inputs(self, pairs: List[str]) -> dict:
        """Tokenize formatted pairs and wrap each one in the chat prompt template.

        Each string is truncated so that, after the prefix and suffix tokens are
        attached, the sequence fits in ``self.max_length``. The batch is then
        left-padded and moved to the model's device.

        Args:
            pairs: strings produced by ``format_instruction``.

        Returns:
            Mapping of tensor name (``input_ids``, ``attention_mask``) to tensor.
        """
        budget = self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
        encoded = self.tokenizer(
            pairs,
            padding=False,
            truncation='longest_first',
            return_attention_mask=False,
            max_length=budget,
        )
        # The template tokens must surround the body; without them the model is
        # not being asked the yes/no question it was trained to answer.
        encoded['input_ids'] = [
            self.prefix_tokens + ids + self.suffix_tokens for ids in encoded['input_ids']
        ]
        batch = self.tokenizer.pad(
            encoded, padding=True, return_tensors='pt', max_length=self.max_length
        )
        return {name: tensor.to(self.model.device) for name, tensor in batch.items()}

    @staticmethod
    def _yes_probability(last_logits, true_id, false_id):
        """Return P(yes) / (P(yes) + P(no)) for each row of last-position logits.

        A two-way softmax over only the yes/no logits equals the ratio of the
        full-vocabulary probabilities. Unlike that ratio, it cannot divide by
        zero when both probabilities underflow (e.g. under fp16).
        """
        pair = torch.stack([last_logits[..., false_id], last_logits[..., true_id]], dim=-1)
        return torch.log_softmax(pair.float(), dim=-1)[..., 1].exp()

    @staticmethod
    def _display_text(doc, limit=200):
        """Shorten *doc* to *limit* characters plus an ellipsis for display."""
        return doc[:limit] + "..." if len(doc) > limit else doc

    def rank_documents(self, instruction: str, query: str, documents: List[str]) -> List[Tuple[str, float, str]]:
        """Score every document against *query* and return them best-first.

        Args:
            instruction: task instruction; ``None``/blank uses the default one.
            query: the search query.
            documents: candidate document texts.

        Returns:
            ``(display_text, score, label)`` tuples sorted by descending score.
            ``label`` is ``"Document N"`` (1-based input position). If scoring a
            document raised, that entry instead has score 0.0 and a truncated
            error message as its label.
        """
        if not DEPENDENCIES_AVAILABLE:
            return [(doc[:100] + "...", 0.5, "Dependencies not available") for doc in documents]
        results = []
        for i, doc in enumerate(documents):
            display_doc = self._display_text(doc)
            try:
                inputs = self.process_inputs([self.format_instruction(instruction, query, doc)])
                with torch.no_grad():
                    last_logits = self.model(**inputs).logits[:, -1, :]
                relevance_score = self._yes_probability(
                    last_logits, self.token_true_id, self.token_false_id
                )[0].item()
            except Exception as e:
                logger.error(f"Error processing document {i+1}: {e}")
                results.append((display_doc, 0.0, f"Error: {str(e)[:50]}..."))
            else:
                results.append((display_doc, relevance_score, f"Document {i+1}"))
        # list.sort is stable, so documents with equal scores keep their input order.
        results.sort(key=lambda item: item[1], reverse=True)
        return results
# Build the shared reranker once at import time. Any exception raised by the
# constructor (including its ImportError when dependencies are missing) is
# logged and leaves ``reranker`` as None.
reranker = None
try:
    reranker = Qwen3Reranker()
except Exception as e:
    logger.error(f"Failed to initialize reranker: {e}")
# Column order shared by result rows and error rows, so the table layout is stable.
_RESULT_COLUMNS = ["Rank", "Document", "Relevance Score"]
# A line starting with a list number such as "1." or "2)" begins a new document.
_NUMBERED_ITEM = re.compile(r"\s*\d+[.)]")
# A blank (empty or whitespace-only) line separates documents.
_BLANK_LINE_SEPARATOR = re.compile(r"\n\s*\n")


def _error_frame(message):
    """Return a one-row results table with "Error" as rank and *message* as document."""
    return pd.DataFrame([["Error", message, 0.0]], columns=_RESULT_COLUMNS)


def _split_documents(documents_text):
    """Split raw textbox input into a list of document strings.

    Documents are separated by blank lines. Without any blank line, each line
    starting with a list number begins a new document and following lines are
    appended to it. If nothing is found, the whole input is one document.
    """
    text = documents_text.replace("\r\n", "\n").replace("\r", "\n")
    if _BLANK_LINE_SEPARATOR.search(text):
        chunks = _BLANK_LINE_SEPARATOR.split(text)
    else:
        chunks = []
        current = []
        for line in text.strip().split("\n"):
            if _NUMBERED_ITEM.match(line) and current:
                chunks.append("\n".join(current))
                current = []
            current.append(line)
        chunks.append("\n".join(current))
    documents = [chunk.strip() for chunk in chunks if chunk.strip()]
    return documents or [documents_text]


def rerank_documents(instruction, query, documents_text):
    """Gradio callback: rank the documents in *documents_text* against *query*.

    Args:
        instruction: optional task instruction; blank uses the reranker default.
        query: search query text.
        documents_text: raw textbox content holding one or more documents.

    Returns:
        DataFrame with columns Rank / Document / Relevance Score, sorted by
        descending score. Validation failures return a single "Error" row that
        uses the same columns.
    """
    if reranker is None:
        return _error_frame("Model not loaded")
    # Gradio may hand over None for an untouched component.
    query = query or ""
    documents_text = documents_text or ""
    if not query.strip():
        return _error_frame("Please provide a query")
    if not documents_text.strip():
        return _error_frame("Please provide documents")
    documents = _split_documents(documents_text)
    results = reranker.rank_documents(instruction, query, documents)
    rows = [
        [f"#{rank}", doc, f"{score:.4f}"]
        for rank, (doc, score, _label) in enumerate(results, start=1)
    ]
    return pd.DataFrame(rows, columns=_RESULT_COLUMNS)
def create_gradio_interface():
    """Build the Gradio Blocks UI for the document reranker.

    Returns:
        The ``gr.Blocks`` app. The caller is responsible for launching it.
    """
    with gr.Blocks(title="Qwen3 Document Reranker", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
# 🔍 Qwen3 Document Reranker
This tool uses the **Qwen3-Reranker-0.6B** model to rank documents by their relevance to your query.
## How to use:
1. **Instruction** (optional): Provide context for the ranking task
2. **Query**: Enter your search query
3. **Documents**: Enter multiple documents separated by double newlines (\\n\\n) or as a numbered list
""")
        with gr.Row():
            # Left column: inputs and an example.
            with gr.Column(scale=1):
                instruction_input = gr.Textbox(
                    label="Instruction (Optional)",
                    placeholder="Given a web search query, retrieve relevant passages that answer the query",
                    value="Given a web search query, retrieve relevant passages that answer the query",
                    lines=2
                )
                query_input = gr.Textbox(
                    label="Query",
                    placeholder="Enter your search query here...",
                    lines=2
                )
                documents_input = gr.Textbox(
                    label="Documents to Rank",
                    placeholder="Enter documents separated by double newlines...\n\nDocument 1 content here\n\nDocument 2 content here\n\nDocument 3 content here",
                    lines=10
                )
                rank_button = gr.Button("🔍 Rank Documents", variant="primary")
                gr.Markdown("### Example:")
                gr.Examples(
                    examples=[
                        [
                            "Given a web search query, retrieve relevant passages that answer the query",
                            "What is machine learning?",
                            "Machine learning is a subset of artificial intelligence that focuses on algorithms that can learn from data.\n\nPython is a programming language commonly used for web development.\n\nDeep learning uses neural networks with multiple layers to model complex patterns in data."
                        ]
                    ],
                    inputs=[instruction_input, query_input, documents_input]
                )
            # Right column: ranking output.
            with gr.Column(scale=1):
                # NOTE(review): no `height` argument here, presumably because the
                # installed Gradio version rejected it — confirm before re-adding.
                results_display = gr.DataFrame(
                    label="Ranking Results",
                    headers=["Rank", "Document", "Relevance Score"],
                    interactive=False
                )
        rank_button.click(
            fn=rerank_documents,
            inputs=[instruction_input, query_input, documents_input],
            outputs=[results_display]
        )
        gr.Markdown("""
### About the Model
- **Model**: Qwen/Qwen3-Reranker-0.6B
- **Task**: Document reranking based on query relevance
- **Output**: Relevance scores between 0 and 1 (higher = more relevant)
""")
    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks app, then serve it with Gradio's
    # default launch settings.
    demo = create_gradio_interface()
    demo.launch()