Spaces:
Sleeping
Sleeping
File size: 3,479 Bytes
3bf3346 25766a9 4ffa42b 7ec684a cd97e60 7ec684a cd97e60 7ec684a cd97e60 7ec684a 25766a9 7ec684a cd97e60 7ec684a cd97e60 7ec684a cd97e60 7ec684a cd97e60 ec4ed0d 25766a9 cd97e60 25766a9 cd97e60 25766a9 cd97e60 7ec684a cd97e60 7ec684a cd97e60 25766a9 7ec684a 25766a9 cd97e60 25766a9 ec4ed0d cd97e60 ec4ed0d 7ec684a cd97e60 7ec684a cd97e60 25766a9 cd97e60 ec4ed0d cd97e60 ec4ed0d cd97e60 063bf3b cd97e60 33e4eda 7ff30bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import ast
import os

import gradio as gr
import requests
import torch
from sentence_transformers import CrossEncoder
# -------------------------------
# MODELS
# -------------------------------
CROSS_ENCODER_RERANK = "cross-encoder/ms-marco-MiniLM-L-12-v2"
JINA_MODEL = "jina-reranker-m0"
# SECURITY: never commit API keys to source control. The previous revision
# embedded a live Jina key here; it must be rotated and supplied via the
# JINA_API_KEY environment variable (e.g. a HF Spaces secret) instead.
JINA_API_KEY = os.environ.get("JINA_API_KEY", "")
JINA_ENDPOINT = "https://api.jina.ai/v1/rerank"
NV_MODEL = "NV-RerankQA-Mistral-4B-v3"  # Hugging Face hosted
# -------------------------------
# Load models
# -------------------------------
# Loaded once at import time so every request reuses the same weights.
ce_rerank = CrossEncoder(CROSS_ENCODER_RERANK)
# -------------------------------
# Pipeline Function
# -------------------------------
def _rank(docs, scores):
    """Pair each document with its score and sort best-first."""
    return sorted(zip(docs, scores), key=lambda pair: pair[1], reverse=True)


def _score_cross_encoder(query, docs):
    """Score with the local MS MARCO CrossEncoder; sigmoid maps logits to (0, 1)."""
    logits = ce_rerank.predict([(query, d) for d in docs])
    return torch.sigmoid(torch.as_tensor(logits)).tolist()


def _score_jina(query, docs):
    """Fetch relevance scores from the hosted Jina rerank API. Raises on HTTP error."""
    headers = {"Authorization": f"Bearer {JINA_API_KEY}", "Content-Type": "application/json"}
    payload = {"model": JINA_MODEL, "query": query, "documents": docs}
    r = requests.post(JINA_ENDPOINT, headers=headers, json=payload, timeout=30)
    r.raise_for_status()
    return [res["relevance_score"] for res in r.json()["results"]]


def _score_nv(query, docs):
    """Fetch scores from the HF Inference API for the NV reranker. Raises on HTTP error."""
    hf_endpoint = f"https://api-inference.huggingface.co/models/{NV_MODEL}"
    # SECURITY: the token comes from the environment (HF Spaces secret), never
    # from source. The previous revision shipped a non-functional placeholder.
    headers = {"Authorization": f"Bearer {os.environ.get('HF_API_KEY', '')}"}
    payload = {"inputs": {"query": query, "documents": docs}}
    r = requests.post(hf_endpoint, headers=headers, json=payload, timeout=60)
    r.raise_for_status()
    return [res["score"] for res in r.json()]


def evaluate_models(query, docs_str):
    """Rank the documents in *docs_str* against *query* with three rerankers.

    Args:
        query: Free-text search query.
        docs_str: A Python-list literal of document strings,
            e.g. '["doc one", "doc two"]'.

    Returns:
        A markdown-formatted string with one "### <model>" section per model,
        each listing "(score) document" lines best-first, or an error message
        when *docs_str* cannot be parsed as a list.
    """
    # ast.literal_eval only evaluates literals, so untrusted UI input cannot
    # execute code. A ValueError (not assert, which -O strips) enforces the
    # list contract; any failure is reported back to the UI, not raised.
    try:
        docs = ast.literal_eval(docs_str)
        if not isinstance(docs, list):
            raise ValueError("Input must be a Python list of strings")
    except Exception as e:
        return f"⚠️ Error parsing documents list: {e}"
    results = {}
    # 1. CrossEncoder reranker (MS MARco runs locally; no network to fail).
    results["CrossEncoder (MS MARCO)"] = _rank(docs, _score_cross_encoder(query, docs))
    # 2. Jina Reranker — remote API; degrade to an inline error row on failure.
    try:
        results["Jina Reranker"] = _rank(docs, _score_jina(query, docs))
    except Exception as e:
        results["Jina Reranker"] = [(f"Error: {e}", 0)]
    # 3. NV RerankQA Mistral-4B-v3 (HF Inference API) — same degradation.
    try:
        results["NV-RerankQA-Mistral-4B-v3"] = _rank(docs, _score_nv(query, docs))
    except Exception as e:
        results["NV-RerankQA-Mistral-4B-v3"] = [(f"Error: {e}", 0)]
    # -------------------------------
    # Format output — join once instead of quadratic string +=.
    # -------------------------------
    lines = []
    for model_name, ranked in results.items():
        lines.append(f"\n### {model_name}\n")
        for doc, score in ranked:
            lines.append(f"- ({round(score,4)}) {doc}\n")
    return "".join(lines)
# -------------------------------
# Gradio UI
# -------------------------------
# Wire a two-input / one-output form around evaluate_models.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## π Ranking Battle (3 Models)\nCompare **NV-RerankQA-Mistral-4B-v3**, **Jina**, and **CrossEncoder**.")
    # Inputs: the free-text query and a Python-list literal of documents.
    query_box = gr.Textbox(label="Query", lines=2, placeholder="Enter your search query...")
    docs_box = gr.Textbox(
        label="Documents (Python list)",
        lines=6,
        placeholder='Example: ["Doc one text", "Doc two text", "Doc three text"]',
    )
    # Output: one markdown-style section per model with ranked documents.
    results_box = gr.Textbox(label="Ranked Results", lines=20)
    run_btn = gr.Button("Evaluate π")
    run_btn.click(evaluate_models, inputs=[query_box, docs_box], outputs=results_box)
demo.launch()
|