Spaces:
Sleeping
Sleeping
File size: 4,173 Bytes
3bf3346 cd97e60 4ffa42b 7ec684a cd97e60 7ec684a cd97e60 7ec684a cd97e60 7ec684a cd97e60 7ec684a cd97e60 7ec684a cd97e60 7ec684a cd97e60 ec4ed0d cd97e60 7ec684a cd97e60 7ec684a cd97e60 7ec684a cd97e60 ec4ed0d cd97e60 ec4ed0d 7ec684a cd97e60 7ec684a cd97e60 ec4ed0d cd97e60 ec4ed0d cd97e60 063bf3b cd97e60 33e4eda 7ff30bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import ast
import os

import gradio as gr
import requests
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder, util
# -------------------------------
# MODELS
# -------------------------------
# Model identifiers for the encoders that are loaded in-process below.
BI_ENCODER = "sentence-transformers/all-MiniLM-L6-v2"
CROSS_ENCODER_RERANK = "cross-encoder/ms-marco-MiniLM-L-12-v2"
CROSS_ENCODER_STS = "cross-encoder/stsb-roberta-large"
CROSS_ENCODER_NLI = "cross-encoder/nli-deberta-v3-base"

# Settings for the hosted Jina reranker, which is called over HTTP.
JINA_MODEL = "jina-reranker-m0"
# SECURITY: credentials must not live in source control. The key is read from
# the JINA_API_KEY environment variable; it defaults to an empty string when
# the variable is unset, in which case the Jina request is sent without a
# usable token.
# NOTE: any key that was previously committed in this file should be treated
# as compromised and revoked.
JINA_API_KEY = os.environ.get("JINA_API_KEY", "")
JINA_ENDPOINT = "https://api.jina.ai/v1/rerank"

# -------------------------------
# Load models
# -------------------------------
# Instantiated once at module import so every request reuses the same objects.
bi_encoder = SentenceTransformer(BI_ENCODER)
ce_rerank = CrossEncoder(CROSS_ENCODER_RERANK)
ce_sts = CrossEncoder(CROSS_ENCODER_STS)
# The NLI cross-encoder is configured with three output labels.
ce_nli = CrossEncoder(CROSS_ENCODER_NLI, num_labels=3)
# -------------------------------
# Pipeline Function
# -------------------------------
def _rank(docs, scores):
    """Pair each document with its score and sort the pairs, best score first."""
    return sorted(zip(docs, scores), key=lambda pair: pair[1], reverse=True)


def _jina_rerank(query, docs):
    """Score ``docs`` against ``query`` with the hosted Jina reranker.

    Returns a list of ``(document, relevance_score)`` pairs sorted by score,
    highest first. Raises whatever the HTTP call or response parsing raises;
    the caller decides how to report it.
    """
    headers = {
        "Authorization": f"Bearer {JINA_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {"model": JINA_MODEL, "query": query, "documents": docs}
    response = requests.post(JINA_ENDPOINT, headers=headers, json=payload, timeout=30)
    response.raise_for_status()
    # Each result carries an ``index`` into the submitted documents. Pairing
    # by that index (rather than by position in the response) keeps every
    # score attached to the document it was computed for, whatever order the
    # API returns results in.
    scored = [
        (docs[item["index"]], item["relevance_score"])
        for item in response.json()["results"]
    ]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


def evaluate_models(query, docs_str):
    """Rank a list of documents against a query with several models.

    Args:
        query: The search query text.
        docs_str: Source text of a Python list literal of strings, e.g.
            ``'["first doc", "second doc"]'``. It is parsed with
            ``ast.literal_eval``, so only literals are accepted.

    Returns:
        A Markdown-style string with one ``### <model name>`` section per
        scoring method and one ``- (<score>) <document>`` line per document,
        best score first. If ``docs_str`` cannot be parsed into a non-empty
        list of strings, a single-line error message is returned instead.
        A failure of the remote Jina call is reported inside its own section
        and does not affect the other sections.
    """
    error_prefix = "⚠️ Error parsing documents list: "
    try:
        docs = ast.literal_eval(docs_str)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
        return f"{error_prefix}{e}"
    # Explicit checks instead of ``assert``: assertions are stripped under -O.
    if not isinstance(docs, list) or not all(isinstance(doc, str) for doc in docs):
        return f"{error_prefix}Input must be a Python list of strings"
    if not docs:
        return f"{error_prefix}Input list must contain at least one document"

    results = {}
    # Every cross-encoder consumes the same (query, document) pairs.
    pairs = [(query, doc) for doc in docs]

    # 1. Bi-encoder cosine similarity
    query_emb = bi_encoder.encode(query, convert_to_tensor=True)
    doc_embs = bi_encoder.encode(docs, convert_to_tensor=True)
    cos_scores = util.cos_sim(query_emb, doc_embs)[0].cpu().tolist()
    bi_encoder_ranking = _rank(docs, cos_scores)
    results["1. Bi-encoder similarity"] = bi_encoder_ranking

    # 2. CrossEncoder reranker (MS MARCO); a sigmoid is applied to its scores
    rerank_raw = ce_rerank.predict(pairs)
    ce_rerank_scores = torch.sigmoid(torch.tensor(rerank_raw)).tolist()
    results["2. CrossEncoder Reranker (MS MARCO)"] = _rank(docs, ce_rerank_scores)

    # 3. Jina Reranker (remote call)
    try:
        results["3. Jina Reranker"] = _jina_rerank(query, docs)
    except Exception as e:
        # Deliberately broad: this is a best-effort remote call at the UI
        # boundary, and the failure is shown to the user rather than hidden.
        results["3. Jina Reranker"] = [(f"Error: {e}", 0)]

    # 4. CrossEncoder STS
    ce_sts_scores = ce_sts.predict(pairs)
    results["4. CrossEncoder STS"] = _rank(docs, ce_sts_scores)

    # 5. CrossEncoder NLI
    ce_nli_probs = ce_nli.predict(pairs, apply_softmax=True)
    ce_nli_scores = [float(p[1] + p[2]) for p in ce_nli_probs]  # neutral + entailment
    results["5. CrossEncoder NLI"] = _rank(docs, ce_nli_scores)

    # 6. Bi-encoder raw similarity (duplicate for clarity)
    results["6. Bi-encoder baseline"] = list(bi_encoder_ranking)

    # -------------------------------
    # Format output
    # -------------------------------
    parts = []
    for model_name, ranked in results.items():
        parts.append(f"\n### {model_name}\n")
        parts.extend(f"- ({round(score, 4)}) {doc}\n" for doc, score in ranked)
    return "".join(parts)
# -------------------------------
# Gradio UI
# -------------------------------
# Assemble the page: a header, the two inputs, the results box and the button
# that runs evaluate_models on the inputs.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## π Multi-Model Reranker (HF + Jina)\nPass a **query** and a **list of documents (Python list of strings)**.")

    # Free-text query.
    query_box = gr.Textbox(
        placeholder="Enter your search query...",
        lines=2,
        label="Query",
    )
    # Documents are typed as the source text of a Python list literal.
    docs_box = gr.Textbox(
        placeholder='Example: ["Doc one text", "Doc two text", "Doc three text"]',
        lines=6,
        label="Documents (Python list)",
    )

    results_box = gr.Textbox(lines=20, label="Ranked Results")
    run_button = gr.Button("Evaluate π")
    # Clicking the button feeds both inputs to evaluate_models and shows the
    # returned text in the results box.
    run_button.click(fn=evaluate_models, inputs=[query_box, docs_box], outputs=results_box)

demo.launch()
|