File size: 3,479 Bytes
3bf3346
25766a9
4ffa42b
7ec684a
cd97e60
7ec684a
 
cd97e60
7ec684a
cd97e60
7ec684a
 
 
25766a9
7ec684a
 
cd97e60
7ec684a
cd97e60
7ec684a
cd97e60
 
 
 
 
 
 
 
 
7ec684a
cd97e60
ec4ed0d
25766a9
cd97e60
 
25766a9
cd97e60
25766a9
cd97e60
 
7ec684a
cd97e60
7ec684a
cd97e60
25766a9
7ec684a
25766a9
cd97e60
25766a9
 
 
 
 
 
 
 
 
 
 
ec4ed0d
 
 
 
cd97e60
 
 
 
 
ec4ed0d
7ec684a
 
cd97e60
7ec684a
cd97e60
25766a9
cd97e60
 
ec4ed0d
cd97e60
 
 
ec4ed0d
cd97e60
063bf3b
cd97e60
 
33e4eda
7ff30bb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import ast
import os

import gradio as gr
import requests
import torch
from sentence_transformers import CrossEncoder

# -------------------------------
# MODELS
# -------------------------------
CROSS_ENCODER_RERANK = "cross-encoder/ms-marco-MiniLM-L-12-v2"
JINA_MODEL = "jina-reranker-m0"
# SECURITY: never commit API keys to source. The key previously hard-coded
# here must be treated as leaked and revoked. Supply it via the environment:
#   export JINA_API_KEY="..."
JINA_API_KEY = os.environ.get("JINA_API_KEY", "")
JINA_ENDPOINT = "https://api.jina.ai/v1/rerank"
NV_MODEL = "NV-RerankQA-Mistral-4B-v3"   # Hugging Face hosted

# -------------------------------
# Load models
# -------------------------------
# Loaded once at import time; downloads the checkpoint on first run.
ce_rerank = CrossEncoder(CROSS_ENCODER_RERANK)

# -------------------------------
# Pipeline Function
# -------------------------------
def evaluate_models(query, docs_str):
    """Rank documents against *query* with three rerankers and format results.

    Parameters
    ----------
    query : str
        The search query.
    docs_str : str
        A Python-literal list of document strings, e.g. '["doc a", "doc b"]'.

    Returns
    -------
    str
        Markdown text with one section per model, documents sorted by
        descending relevance score. Per-model API failures are reported
        inline as an "Error: ..." row rather than aborting the whole run.
    """
    # literal_eval is safe for untrusted input (no code execution), unlike eval().
    try:
        docs = ast.literal_eval(docs_str)
        # Explicit check instead of `assert` — asserts are stripped under -O.
        if not isinstance(docs, list):
            raise TypeError("Input must be a Python list of strings")
    except Exception as e:
        return f"⚠️ Error parsing documents list: {e}"

    results = {}

    # 1. CrossEncoder reranker (MS MARCO) — local model. Raw logits are mapped
    # to (0, 1) with a single batched sigmoid so scores are comparable.
    ce_rerank_scores = ce_rerank.predict([(query, d) for d in docs])
    ce_rerank_scores = torch.sigmoid(torch.as_tensor(ce_rerank_scores)).tolist()
    results["CrossEncoder (MS MARCO)"] = sorted(
        zip(docs, ce_rerank_scores), key=lambda x: x[1], reverse=True
    )

    # 2. Jina Reranker — remote API; any failure degrades to an error row.
    jina_headers = {
        "Authorization": f"Bearer {JINA_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {"model": JINA_MODEL, "query": query, "documents": docs}
    try:
        r = requests.post(JINA_ENDPOINT, headers=jina_headers, json=payload, timeout=30)
        r.raise_for_status()
        jina_scores = [res["relevance_score"] for res in r.json()["results"]]
        results["Jina Reranker"] = sorted(zip(docs, jina_scores), key=lambda x: x[1], reverse=True)
    except Exception as e:
        results["Jina Reranker"] = [(f"Error: {e}", 0)]

    # 3. NV RerankQA Mistral-4B-v3 (HF Inference API). The token is read from
    # the environment — the original sent the literal placeholder
    # "YOUR_HF_API_KEY", which could never authenticate.
    try:
        hf_endpoint = f"https://api-inference.huggingface.co/models/{NV_MODEL}"
        hf_headers = {"Authorization": f"Bearer {os.environ.get('HF_API_KEY', '')}"}
        payload = {"inputs": {"query": query, "documents": docs}}
        r = requests.post(hf_endpoint, headers=hf_headers, json=payload, timeout=60)
        r.raise_for_status()
        nv_scores = [res["score"] for res in r.json()]
        results["NV-RerankQA-Mistral-4B-v3"] = sorted(zip(docs, nv_scores), key=lambda x: x[1], reverse=True)
    except Exception as e:
        results["NV-RerankQA-Mistral-4B-v3"] = [(f"Error: {e}", 0)]

    # -------------------------------
    # Format output
    # -------------------------------
    # "".join avoids quadratic string concatenation.
    lines = []
    for model_name, ranked in results.items():
        lines.append(f"\n### {model_name}\n")
        for doc, score in ranked:
            lines.append(f"- ({round(score, 4)}) {doc}\n")
    return "".join(lines)

# -------------------------------
# Gradio UI
# -------------------------------
# Declarative layout: two text inputs feed evaluate_models, whose formatted
# ranking text is shown in a single large output box.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## πŸ‘‘ Ranking Battle (3 Models)\nCompare **NV-RerankQA-Mistral-4B-v3**, **Jina**, and **CrossEncoder**.")

    # Inputs: free-text query and the candidate documents as a Python list.
    query_box = gr.Textbox(label="Query", lines=2, placeholder="Enter your search query...")
    docs_box = gr.Textbox(
        label="Documents (Python list)",
        lines=6,
        placeholder='Example: ["Doc one text", "Doc two text", "Doc three text"]',
    )

    # Output: one markdown-ish section per model with scored documents.
    results_box = gr.Textbox(label="Ranked Results", lines=20)

    # Wire the button: both inputs go to evaluate_models, result to the box.
    run_button = gr.Button("Evaluate πŸš€")
    run_button.click(evaluate_models, inputs=[query_box, docs_box], outputs=results_box)

demo.launch()