afriddev committed on
Commit
5f51df4
·
verified ·
1 Parent(s): 25766a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -19
app.py CHANGED
@@ -32,8 +32,8 @@ def evaluate_models(query, docs_str):
32
 
33
  # 1. CrossEncoder reranker (MS MARCO)
34
  ce_rerank_scores = ce_rerank.predict([(query, d) for d in docs])
35
- ce_rerank_scores = [torch.sigmoid(torch.tensor(s)).item() for s in ce_rerank_scores]
36
- results["CrossEncoder (MS MARCO)"] = sorted(zip(docs, ce_rerank_scores), key=lambda x: x[1], reverse=True)
37
 
38
  # 2. Jina Reranker
39
  headers = {"Authorization": f"Bearer {JINA_API_KEY}", "Content-Type": "application/json"}
@@ -41,10 +41,10 @@ def evaluate_models(query, docs_str):
41
  try:
42
  r = requests.post(JINA_ENDPOINT, headers=headers, json=payload, timeout=30)
43
  r.raise_for_status()
44
- jina_scores = [res["relevance_score"] for res in r.json()["results"]]
45
- results["Jina Reranker"] = sorted(zip(docs, jina_scores), key=lambda x: x[1], reverse=True)
46
  except Exception as e:
47
- results["Jina Reranker"] = [(f"Error: {e}", 0)]
48
 
49
  # 3. NV RerankQA Mistral-4B-v3 (HF Inference API)
50
  try:
@@ -53,26 +53,18 @@ def evaluate_models(query, docs_str):
53
  payload = {"inputs": {"query": query, "documents": docs}}
54
  r = requests.post(hf_endpoint, headers=headers, json=payload, timeout=60)
55
  r.raise_for_status()
56
- nv_scores = [res["score"] for res in r.json()]
57
- results["NV-RerankQA-Mistral-4B-v3"] = sorted(zip(docs, nv_scores), key=lambda x: x[1], reverse=True)
58
  except Exception as e:
59
- results["NV-RerankQA-Mistral-4B-v3"] = [(f"Error: {e}", 0)]
60
 
61
- # -------------------------------
62
- # Format output
63
- # -------------------------------
64
- out = ""
65
- for model_name, ranked in results.items():
66
- out += f"\n### {model_name}\n"
67
- for doc, score in ranked:
68
- out += f"- ({round(score,4)}) {doc}\n"
69
- return out
70
 
71
  # -------------------------------
72
  # Gradio UI
73
  # -------------------------------
74
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
75
- gr.Markdown("## 👑 Ranking Battle (3 Models)\nCompare **NV-RerankQA-Mistral-4B-v3**, **Jina**, and **CrossEncoder**.")
76
 
77
  query = gr.Textbox(label="Query", lines=2, placeholder="Enter your search query...")
78
  docs = gr.Textbox(
@@ -80,7 +72,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
80
  lines=6,
81
  placeholder='Example: ["Doc one text", "Doc two text", "Doc three text"]'
82
  )
83
- out = gr.Textbox(label="Ranked Results", lines=20)
84
 
85
  btn = gr.Button("Evaluate 🚀")
86
  btn.click(evaluate_models, inputs=[query, docs], outputs=out)
 
32
 
33
  # 1. CrossEncoder reranker (MS MARCO)
34
  ce_rerank_scores = ce_rerank.predict([(query, d) for d in docs])
35
+ ce_rerank_scores = [round(torch.sigmoid(torch.tensor(s)).item(), 4) for s in ce_rerank_scores]
36
+ results["CrossEncoder (MS MARCO)"] = ce_rerank_scores
37
 
38
  # 2. Jina Reranker
39
  headers = {"Authorization": f"Bearer {JINA_API_KEY}", "Content-Type": "application/json"}
 
41
  try:
42
  r = requests.post(JINA_ENDPOINT, headers=headers, json=payload, timeout=30)
43
  r.raise_for_status()
44
+ jina_scores = [round(res["relevance_score"], 4) for res in r.json()["results"]]
45
+ results["Jina Reranker"] = jina_scores
46
  except Exception as e:
47
+ results["Jina Reranker"] = [f"Error: {e}"]
48
 
49
  # 3. NV RerankQA Mistral-4B-v3 (HF Inference API)
50
  try:
 
53
  payload = {"inputs": {"query": query, "documents": docs}}
54
  r = requests.post(hf_endpoint, headers=headers, json=payload, timeout=60)
55
  r.raise_for_status()
56
+ nv_scores = [round(res["score"], 4) for res in r.json()]
57
+ results["NV-RerankQA-Mistral-4B-v3"] = nv_scores
58
  except Exception as e:
59
+ results["NV-RerankQA-Mistral-4B-v3"] = [f"Error: {e}"]
60
 
61
+ return results
 
 
 
 
 
 
 
 
62
 
63
  # -------------------------------
64
  # Gradio UI
65
  # -------------------------------
66
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
67
+ gr.Markdown("## 👑 Ranking Battle (Scores Only)\nOutputs only **scores** from the 3 models.")
68
 
69
  query = gr.Textbox(label="Query", lines=2, placeholder="Enter your search query...")
70
  docs = gr.Textbox(
 
72
  lines=6,
73
  placeholder='Example: ["Doc one text", "Doc two text", "Doc three text"]'
74
  )
75
+ out = gr.JSON(label="Model Scores")
76
 
77
  btn = gr.Button("Evaluate 🚀")
78
  btn.click(evaluate_models, inputs=[query, docs], outputs=out)