afriddev commited on
Commit
25766a9
Β·
verified Β·
1 Parent(s): cd97e60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -31
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from sentence_transformers import SentenceTransformer, CrossEncoder, util
3
  import torch
4
  import requests
5
  import ast
@@ -7,28 +7,22 @@ import ast
7
  # -------------------------------
8
  # MODELS
9
  # -------------------------------
10
- BI_ENCODER = "sentence-transformers/all-MiniLM-L6-v2"
11
  CROSS_ENCODER_RERANK = "cross-encoder/ms-marco-MiniLM-L-12-v2"
12
- CROSS_ENCODER_STS = "cross-encoder/stsb-roberta-large"
13
- CROSS_ENCODER_NLI = "cross-encoder/nli-deberta-v3-base"
14
  JINA_MODEL = "jina-reranker-m0"
15
  JINA_API_KEY = "jina_4075150fa702471c85ddea0a9ad4b306ouE7ymhrCpvxTxX3mScUv5LLDPKQ"
16
  JINA_ENDPOINT = "https://api.jina.ai/v1/rerank"
 
17
 
18
  # -------------------------------
19
  # Load models
20
  # -------------------------------
21
- bi_encoder = SentenceTransformer(BI_ENCODER)
22
  ce_rerank = CrossEncoder(CROSS_ENCODER_RERANK)
23
- ce_sts = CrossEncoder(CROSS_ENCODER_STS)
24
- ce_nli = CrossEncoder(CROSS_ENCODER_NLI, num_labels=3)
25
 
26
  # -------------------------------
27
  # Pipeline Function
28
  # -------------------------------
29
  def evaluate_models(query, docs_str):
30
  try:
31
- # Parse docs string as Python list
32
  docs = ast.literal_eval(docs_str)
33
  assert isinstance(docs, list), "Input must be a Python list of strings"
34
  except Exception as e:
@@ -36,39 +30,33 @@ def evaluate_models(query, docs_str):
36
 
37
  results = {}
38
 
39
- # 1. Bi-encoder cosine similarity
40
- query_emb = bi_encoder.encode(query, convert_to_tensor=True)
41
- doc_embs = bi_encoder.encode(docs, convert_to_tensor=True)
42
- cos_scores = util.cos_sim(query_emb, doc_embs)[0].cpu().tolist()
43
- results["1. Bi-encoder similarity"] = sorted(zip(docs, cos_scores), key=lambda x: x[1], reverse=True)
44
-
45
- # 2. CrossEncoder reranker (MS MARCO)
46
  ce_rerank_scores = ce_rerank.predict([(query, d) for d in docs])
47
  ce_rerank_scores = [torch.sigmoid(torch.tensor(s)).item() for s in ce_rerank_scores]
48
- results["2. CrossEncoder Reranker (MS MARCO)"] = sorted(zip(docs, ce_rerank_scores), key=lambda x: x[1], reverse=True)
49
 
50
- # 3. Jina Reranker
51
  headers = {"Authorization": f"Bearer {JINA_API_KEY}", "Content-Type": "application/json"}
52
  payload = {"model": JINA_MODEL, "query": query, "documents": docs}
53
  try:
54
  r = requests.post(JINA_ENDPOINT, headers=headers, json=payload, timeout=30)
55
  r.raise_for_status()
56
  jina_scores = [res["relevance_score"] for res in r.json()["results"]]
57
- results["3. Jina Reranker"] = sorted(zip(docs, jina_scores), key=lambda x: x[1], reverse=True)
58
  except Exception as e:
59
- results["3. Jina Reranker"] = [(f"Error: {e}", 0)]
60
-
61
- # 4. CrossEncoder STS
62
- ce_sts_scores = ce_sts.predict([(query, d) for d in docs])
63
- results["4. CrossEncoder STS"] = sorted(zip(docs, ce_sts_scores), key=lambda x: x[1], reverse=True)
64
-
65
- # 5. CrossEncoder NLI
66
- ce_nli_probs = ce_nli.predict([(query, d) for d in docs], apply_softmax=True)
67
- ce_nli_scores = [float(p[1] + p[2]) for p in ce_nli_probs] # neutral + entailment
68
- results["5. CrossEncoder NLI"] = sorted(zip(docs, ce_nli_scores), key=lambda x: x[1], reverse=True)
69
 
70
- # 6. Bi-encoder raw similarity (duplicate for clarity)
71
- results["6. Bi-encoder baseline"] = sorted(zip(docs, cos_scores), key=lambda x: x[1], reverse=True)
 
 
 
 
 
 
 
 
 
72
 
73
  # -------------------------------
74
  # Format output
@@ -84,7 +72,7 @@ def evaluate_models(query, docs_str):
84
  # Gradio UI
85
  # -------------------------------
86
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
87
- gr.Markdown("## πŸ”Ž Multi-Model Reranker (HF + Jina)\nPass a **query** and a **list of documents (Python list of strings)**.")
88
 
89
  query = gr.Textbox(label="Query", lines=2, placeholder="Enter your search query...")
90
  docs = gr.Textbox(
 
1
  import gradio as gr
2
+ from sentence_transformers import CrossEncoder
3
  import torch
4
  import requests
5
  import ast
 
7
  # -------------------------------
8
  # MODELS
9
  # -------------------------------
 
10
  CROSS_ENCODER_RERANK = "cross-encoder/ms-marco-MiniLM-L-12-v2"
 
 
11
  JINA_MODEL = "jina-reranker-m0"
12
  JINA_API_KEY = "jina_4075150fa702471c85ddea0a9ad4b306ouE7ymhrCpvxTxX3mScUv5LLDPKQ"
13
  JINA_ENDPOINT = "https://api.jina.ai/v1/rerank"
14
+ NV_MODEL = "NV-RerankQA-Mistral-4B-v3" # Hugging Face hosted
15
 
16
  # -------------------------------
17
  # Load models
18
  # -------------------------------
 
19
  ce_rerank = CrossEncoder(CROSS_ENCODER_RERANK)
 
 
20
 
21
  # -------------------------------
22
  # Pipeline Function
23
  # -------------------------------
24
  def evaluate_models(query, docs_str):
25
  try:
 
26
  docs = ast.literal_eval(docs_str)
27
  assert isinstance(docs, list), "Input must be a Python list of strings"
28
  except Exception as e:
 
30
 
31
  results = {}
32
 
33
+ # 1. CrossEncoder reranker (MS MARCO)
 
 
 
 
 
 
34
  ce_rerank_scores = ce_rerank.predict([(query, d) for d in docs])
35
  ce_rerank_scores = [torch.sigmoid(torch.tensor(s)).item() for s in ce_rerank_scores]
36
+ results["CrossEncoder (MS MARCO)"] = sorted(zip(docs, ce_rerank_scores), key=lambda x: x[1], reverse=True)
37
 
38
+ # 2. Jina Reranker
39
  headers = {"Authorization": f"Bearer {JINA_API_KEY}", "Content-Type": "application/json"}
40
  payload = {"model": JINA_MODEL, "query": query, "documents": docs}
41
  try:
42
  r = requests.post(JINA_ENDPOINT, headers=headers, json=payload, timeout=30)
43
  r.raise_for_status()
44
  jina_scores = [res["relevance_score"] for res in r.json()["results"]]
45
+ results["Jina Reranker"] = sorted(zip(docs, jina_scores), key=lambda x: x[1], reverse=True)
46
  except Exception as e:
47
+ results["Jina Reranker"] = [(f"Error: {e}", 0)]
 
 
 
 
 
 
 
 
 
48
 
49
+ # 3. NV RerankQA Mistral-4B-v3 (HF Inference API)
50
+ try:
51
+ hf_endpoint = f"https://api-inference.huggingface.co/models/{NV_MODEL}"
52
+ headers = {"Authorization": f"Bearer YOUR_HF_API_KEY"}
53
+ payload = {"inputs": {"query": query, "documents": docs}}
54
+ r = requests.post(hf_endpoint, headers=headers, json=payload, timeout=60)
55
+ r.raise_for_status()
56
+ nv_scores = [res["score"] for res in r.json()]
57
+ results["NV-RerankQA-Mistral-4B-v3"] = sorted(zip(docs, nv_scores), key=lambda x: x[1], reverse=True)
58
+ except Exception as e:
59
+ results["NV-RerankQA-Mistral-4B-v3"] = [(f"Error: {e}", 0)]
60
 
61
  # -------------------------------
62
  # Format output
 
72
  # Gradio UI
73
  # -------------------------------
74
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
75
+ gr.Markdown("## πŸ‘‘ Ranking Battle (3 Models)\nCompare **NV-RerankQA-Mistral-4B-v3**, **Jina**, and **CrossEncoder**.")
76
 
77
  query = gr.Textbox(label="Query", lines=2, placeholder="Enter your search query...")
78
  docs = gr.Textbox(