Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
from sentence_transformers import
|
| 3 |
import torch
|
| 4 |
import requests
|
| 5 |
import ast
|
|
@@ -7,28 +7,22 @@ import ast
|
|
| 7 |
# -------------------------------
|
| 8 |
# MODELS
|
| 9 |
# -------------------------------
|
| 10 |
-
BI_ENCODER = "sentence-transformers/all-MiniLM-L6-v2"
|
| 11 |
CROSS_ENCODER_RERANK = "cross-encoder/ms-marco-MiniLM-L-12-v2"
|
| 12 |
-
CROSS_ENCODER_STS = "cross-encoder/stsb-roberta-large"
|
| 13 |
-
CROSS_ENCODER_NLI = "cross-encoder/nli-deberta-v3-base"
|
| 14 |
JINA_MODEL = "jina-reranker-m0"
|
| 15 |
JINA_API_KEY = "jina_4075150fa702471c85ddea0a9ad4b306ouE7ymhrCpvxTxX3mScUv5LLDPKQ"
|
| 16 |
JINA_ENDPOINT = "https://api.jina.ai/v1/rerank"
|
|
|
|
| 17 |
|
| 18 |
# -------------------------------
|
| 19 |
# Load models
|
| 20 |
# -------------------------------
|
| 21 |
-
bi_encoder = SentenceTransformer(BI_ENCODER)
|
| 22 |
ce_rerank = CrossEncoder(CROSS_ENCODER_RERANK)
|
| 23 |
-
ce_sts = CrossEncoder(CROSS_ENCODER_STS)
|
| 24 |
-
ce_nli = CrossEncoder(CROSS_ENCODER_NLI, num_labels=3)
|
| 25 |
|
| 26 |
# -------------------------------
|
| 27 |
# Pipeline Function
|
| 28 |
# -------------------------------
|
| 29 |
def evaluate_models(query, docs_str):
|
| 30 |
try:
|
| 31 |
-
# Parse docs string as Python list
|
| 32 |
docs = ast.literal_eval(docs_str)
|
| 33 |
assert isinstance(docs, list), "Input must be a Python list of strings"
|
| 34 |
except Exception as e:
|
|
@@ -36,39 +30,33 @@ def evaluate_models(query, docs_str):
|
|
| 36 |
|
| 37 |
results = {}
|
| 38 |
|
| 39 |
-
# 1.
|
| 40 |
-
query_emb = bi_encoder.encode(query, convert_to_tensor=True)
|
| 41 |
-
doc_embs = bi_encoder.encode(docs, convert_to_tensor=True)
|
| 42 |
-
cos_scores = util.cos_sim(query_emb, doc_embs)[0].cpu().tolist()
|
| 43 |
-
results["1. Bi-encoder similarity"] = sorted(zip(docs, cos_scores), key=lambda x: x[1], reverse=True)
|
| 44 |
-
|
| 45 |
-
# 2. CrossEncoder reranker (MS MARCO)
|
| 46 |
ce_rerank_scores = ce_rerank.predict([(query, d) for d in docs])
|
| 47 |
ce_rerank_scores = [torch.sigmoid(torch.tensor(s)).item() for s in ce_rerank_scores]
|
| 48 |
-
results["
|
| 49 |
|
| 50 |
-
#
|
| 51 |
headers = {"Authorization": f"Bearer {JINA_API_KEY}", "Content-Type": "application/json"}
|
| 52 |
payload = {"model": JINA_MODEL, "query": query, "documents": docs}
|
| 53 |
try:
|
| 54 |
r = requests.post(JINA_ENDPOINT, headers=headers, json=payload, timeout=30)
|
| 55 |
r.raise_for_status()
|
| 56 |
jina_scores = [res["relevance_score"] for res in r.json()["results"]]
|
| 57 |
-
results["
|
| 58 |
except Exception as e:
|
| 59 |
-
results["
|
| 60 |
-
|
| 61 |
-
# 4. CrossEncoder STS
|
| 62 |
-
ce_sts_scores = ce_sts.predict([(query, d) for d in docs])
|
| 63 |
-
results["4. CrossEncoder STS"] = sorted(zip(docs, ce_sts_scores), key=lambda x: x[1], reverse=True)
|
| 64 |
-
|
| 65 |
-
# 5. CrossEncoder NLI
|
| 66 |
-
ce_nli_probs = ce_nli.predict([(query, d) for d in docs], apply_softmax=True)
|
| 67 |
-
ce_nli_scores = [float(p[1] + p[2]) for p in ce_nli_probs] # neutral + entailment
|
| 68 |
-
results["5. CrossEncoder NLI"] = sorted(zip(docs, ce_nli_scores), key=lambda x: x[1], reverse=True)
|
| 69 |
|
| 70 |
-
#
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
# -------------------------------
|
| 74 |
# Format output
|
|
@@ -84,7 +72,7 @@ def evaluate_models(query, docs_str):
|
|
| 84 |
# Gradio UI
|
| 85 |
# -------------------------------
|
| 86 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 87 |
-
gr.Markdown("##
|
| 88 |
|
| 89 |
query = gr.Textbox(label="Query", lines=2, placeholder="Enter your search query...")
|
| 90 |
docs = gr.Textbox(
|
|
|
|
import ast
import os

import gradio as gr
import requests
import torch
from sentence_transformers import CrossEncoder

# -------------------------------
# MODELS
# -------------------------------
CROSS_ENCODER_RERANK = "cross-encoder/ms-marco-MiniLM-L-12-v2"
JINA_MODEL = "jina-reranker-m0"
# SECURITY: an API key was previously hard-coded on this line and is now
# leaked in git history — revoke it. Read the key from the environment
# instead; on HF Spaces set it as a repository secret named JINA_API_KEY.
JINA_API_KEY = os.environ.get("JINA_API_KEY", "")
JINA_ENDPOINT = "https://api.jina.ai/v1/rerank"
NV_MODEL = "NV-RerankQA-Mistral-4B-v3"  # Hugging Face hosted

# -------------------------------
# Load models
# -------------------------------
# Loaded once at import time so every request reuses the same weights.
ce_rerank = CrossEncoder(CROSS_ENCODER_RERANK)
|
|
|
|
|
|
|
| 20 |
|
| 21 |
# -------------------------------
|
| 22 |
# Pipeline Function
|
| 23 |
# -------------------------------
|
| 24 |
def evaluate_models(query, docs_str):
|
| 25 |
try:
|
|
|
|
| 26 |
docs = ast.literal_eval(docs_str)
|
| 27 |
assert isinstance(docs, list), "Input must be a Python list of strings"
|
| 28 |
except Exception as e:
|
|
|
|
| 30 |
|
| 31 |
results = {}
|
| 32 |
|
| 33 |
+
# 1. CrossEncoder reranker (MS MARCO)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
ce_rerank_scores = ce_rerank.predict([(query, d) for d in docs])
|
| 35 |
ce_rerank_scores = [torch.sigmoid(torch.tensor(s)).item() for s in ce_rerank_scores]
|
| 36 |
+
results["CrossEncoder (MS MARCO)"] = sorted(zip(docs, ce_rerank_scores), key=lambda x: x[1], reverse=True)
|
| 37 |
|
| 38 |
+
# 2. Jina Reranker
|
| 39 |
headers = {"Authorization": f"Bearer {JINA_API_KEY}", "Content-Type": "application/json"}
|
| 40 |
payload = {"model": JINA_MODEL, "query": query, "documents": docs}
|
| 41 |
try:
|
| 42 |
r = requests.post(JINA_ENDPOINT, headers=headers, json=payload, timeout=30)
|
| 43 |
r.raise_for_status()
|
| 44 |
jina_scores = [res["relevance_score"] for res in r.json()["results"]]
|
| 45 |
+
results["Jina Reranker"] = sorted(zip(docs, jina_scores), key=lambda x: x[1], reverse=True)
|
| 46 |
except Exception as e:
|
| 47 |
+
results["Jina Reranker"] = [(f"Error: {e}", 0)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
+
# 3. NV RerankQA Mistral-4B-v3 (HF Inference API)
|
| 50 |
+
try:
|
| 51 |
+
hf_endpoint = f"https://api-inference.huggingface.co/models/{NV_MODEL}"
|
| 52 |
+
headers = {"Authorization": f"Bearer YOUR_HF_API_KEY"}
|
| 53 |
+
payload = {"inputs": {"query": query, "documents": docs}}
|
| 54 |
+
r = requests.post(hf_endpoint, headers=headers, json=payload, timeout=60)
|
| 55 |
+
r.raise_for_status()
|
| 56 |
+
nv_scores = [res["score"] for res in r.json()]
|
| 57 |
+
results["NV-RerankQA-Mistral-4B-v3"] = sorted(zip(docs, nv_scores), key=lambda x: x[1], reverse=True)
|
| 58 |
+
except Exception as e:
|
| 59 |
+
results["NV-RerankQA-Mistral-4B-v3"] = [(f"Error: {e}", 0)]
|
| 60 |
|
| 61 |
# -------------------------------
|
| 62 |
# Format output
|
|
|
|
| 72 |
# Gradio UI
|
| 73 |
# -------------------------------
|
| 74 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 75 |
+
gr.Markdown("## 🏆 Ranking Battle (3 Models)\nCompare **NV-RerankQA-Mistral-4B-v3**, **Jina**, and **CrossEncoder**.")
|
| 76 |
|
| 77 |
query = gr.Textbox(label="Query", lines=2, placeholder="Enter your search query...")
|
| 78 |
docs = gr.Textbox(
|