File size: 4,173 Bytes
3bf3346
cd97e60
4ffa42b
7ec684a
cd97e60
7ec684a
 
cd97e60
7ec684a
cd97e60
 
 
 
7ec684a
 
 
 
 
cd97e60
7ec684a
cd97e60
 
 
 
7ec684a
cd97e60
 
 
 
 
 
 
 
 
 
7ec684a
cd97e60
ec4ed0d
cd97e60
 
 
 
 
 
 
 
 
 
 
 
 
 
7ec684a
cd97e60
7ec684a
cd97e60
 
7ec684a
cd97e60
 
 
 
 
 
 
 
 
 
 
 
 
ec4ed0d
 
 
 
cd97e60
 
 
 
 
ec4ed0d
7ec684a
 
cd97e60
7ec684a
cd97e60
 
 
 
ec4ed0d
cd97e60
 
 
ec4ed0d
cd97e60
063bf3b
cd97e60
 
33e4eda
7ff30bb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import ast
import os

import gradio as gr
import requests
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder, util

# -------------------------------
# MODELS
# -------------------------------
# Hugging Face model identifiers for the locally loaded encoders.
BI_ENCODER = "sentence-transformers/all-MiniLM-L6-v2"
CROSS_ENCODER_RERANK = "cross-encoder/ms-marco-MiniLM-L-12-v2"
CROSS_ENCODER_STS = "cross-encoder/stsb-roberta-large"
CROSS_ENCODER_NLI = "cross-encoder/nli-deberta-v3-base"

# Hosted Jina reranker: model name, credential and endpoint.
JINA_MODEL = "jina-reranker-m0"
# SECURITY: the API key is a secret and must never be committed to source
# control. Supply it through the JINA_API_KEY environment variable (for
# example a Space/CI secret). Any key that was ever committed to this file
# should be considered leaked and rotated. When the variable is unset the key
# is an empty string and the Jina request will be rejected by the API.
JINA_API_KEY = os.environ.get("JINA_API_KEY", "")
JINA_ENDPOINT = "https://api.jina.ai/v1/rerank"

# -------------------------------
# Load models
# -------------------------------
# Each model is instantiated once, at import time, and stored in a module-level
# name; evaluate_models() reads these globals on every call instead of
# constructing the models again.
bi_encoder = SentenceTransformer(BI_ENCODER)
ce_rerank = CrossEncoder(CROSS_ENCODER_RERANK)
ce_sts = CrossEncoder(CROSS_ENCODER_STS)
# Three output classes: evaluate_models() indexes positions 1 and 2 of each
# probability row returned by this model.
ce_nli = CrossEncoder(CROSS_ENCODER_NLI, num_labels=3)

# -------------------------------
# Pipeline Function
# -------------------------------
def _parse_docs(docs_str):
    """Parse *docs_str* as a Python literal and return a list of strings.

    Raises:
        ValueError: if the text is malformed, is not a list of strings, or is
            an empty list.
        SyntaxError: if the text is not valid Python literal syntax.
    """
    docs = ast.literal_eval(docs_str)
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not isinstance(docs, list) or not all(isinstance(d, str) for d in docs):
        raise ValueError("Input must be a Python list of strings")
    if not docs:
        raise ValueError("Input list must contain at least one document")
    return docs


def _rank(docs, scores):
    """Pair each document with its score and sort by score, highest first."""
    pairs = zip(docs, (float(s) for s in scores))
    return sorted(pairs, key=lambda pair: pair[1], reverse=True)


def _jina_rerank(query, docs):
    """Score *docs* against *query* with the hosted Jina reranker.

    Returns ranked ``(doc, score)`` pairs. This step is best-effort: when the
    API key is missing, the request fails, or the response is malformed, a
    single ``("Error: ...", 0)`` entry is returned instead of raising.
    """
    if not JINA_API_KEY:
        return [("Error: JINA_API_KEY is not configured", 0)]

    headers = {"Authorization": f"Bearer {JINA_API_KEY}", "Content-Type": "application/json"}
    payload = {"model": JINA_MODEL, "query": query, "documents": docs}
    try:
        r = requests.post(JINA_ENDPOINT, headers=headers, json=payload, timeout=30)
        r.raise_for_status()
        # The API returns results ordered by relevance, not in input order, so
        # each score must be mapped back to its document through "index"
        # rather than zipped positionally with `docs`.
        scored = [
            (docs[res["index"]], res["relevance_score"])
            for res in r.json()["results"]
        ]
    except (requests.RequestException, KeyError, IndexError, TypeError, ValueError) as e:
        return [(f"Error: {e}", 0)]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


def _format_results(results):
    """Render ``{model name: [(doc, score), ...]}`` as a markdown-style listing."""
    lines = []
    for model_name, ranked in results.items():
        lines.append(f"\n### {model_name}\n")
        lines.extend(f"- ({round(score, 4)}) {doc}\n" for doc, score in ranked)
    return "".join(lines)


def evaluate_models(query, docs_str):
    """Rank documents against a query with every configured model.

    Args:
        query: The search query text.
        docs_str: Text of a Python list literal of document strings, e.g.
            ``'["Doc one", "Doc two"]'``.

    Returns:
        A string with one ``### <model>`` section per model, each listing the
        documents as ``- (<score>) <doc>`` from highest to lowest score. If
        *docs_str* cannot be parsed into a non-empty list of strings, a single
        line starting with ``⚠️ Error parsing documents list:`` is returned.
    """
    try:
        docs = _parse_docs(docs_str)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
        return f"⚠️ Error parsing documents list: {e}"

    # Every cross-encoder scores the same (query, document) pairs.
    pairs = [(query, d) for d in docs]
    results = {}

    # 1. Bi-encoder cosine similarity
    query_emb = bi_encoder.encode(query, convert_to_tensor=True)
    doc_embs = bi_encoder.encode(docs, convert_to_tensor=True)
    cos_scores = util.cos_sim(query_emb, doc_embs)[0].cpu().tolist()
    results["1. Bi-encoder similarity"] = _rank(docs, cos_scores)

    # 2. CrossEncoder reranker (MS MARCO); scores are passed through a sigmoid
    ce_rerank_scores = ce_rerank.predict(pairs)
    ce_rerank_scores = [torch.sigmoid(torch.tensor(s)).item() for s in ce_rerank_scores]
    results["2. CrossEncoder Reranker (MS MARCO)"] = _rank(docs, ce_rerank_scores)

    # 3. Jina Reranker (remote API, best-effort)
    results["3. Jina Reranker"] = _jina_rerank(query, docs)

    # 4. CrossEncoder STS
    results["4. CrossEncoder STS"] = _rank(docs, ce_sts.predict(pairs))

    # 5. CrossEncoder NLI: sum of class probabilities at indices 1 and 2
    # (intended as "neutral + entailment", i.e. everything but index 0).
    # NOTE(review): confirm the label order against the model card.
    ce_nli_probs = ce_nli.predict(pairs, apply_softmax=True)
    ce_nli_scores = [float(p[1] + p[2]) for p in ce_nli_probs]
    results["5. CrossEncoder NLI"] = _rank(docs, ce_nli_scores)

    # 6. Bi-encoder raw similarity (duplicate for clarity)
    results["6. Bi-encoder baseline"] = _rank(docs, cos_scores)

    return _format_results(results)

# -------------------------------
# Gradio UI
# -------------------------------
# Single-page Blocks app: a query box and a documents box feed
# evaluate_models(), whose returned text is shown in the results box.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # The emoji in the heading and button label were mojibake (UTF-8 bytes
    # decoded with the wrong codec); they are restored to the intended glyphs.
    gr.Markdown("## 🔎 Multi-Model Reranker (HF + Jina)\nPass a **query** and a **list of documents (Python list of strings)**.")

    query = gr.Textbox(label="Query", lines=2, placeholder="Enter your search query...")
    docs = gr.Textbox(
        label="Documents (Python list)",
        lines=6,
        placeholder='Example: ["Doc one text", "Doc two text", "Doc three text"]',
    )
    out = gr.Textbox(label="Ranked Results", lines=20)

    btn = gr.Button("Evaluate 🚀")
    # Clicking the button passes both textbox values to evaluate_models and
    # writes its return value into the results box.
    btn.click(evaluate_models, inputs=[query, docs], outputs=out)

# Start the app as soon as the module is executed.
demo.launch()