afriddev commited on
Commit
c7ff2c3
·
verified ·
1 Parent(s): 5f51df4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -24
app.py CHANGED
@@ -3,15 +3,17 @@ from sentence_transformers import CrossEncoder
3
  import torch
4
  import requests
5
  import ast
 
6
 
7
  # -------------------------------
8
  # MODELS
9
  # -------------------------------
10
  CROSS_ENCODER_RERANK = "cross-encoder/ms-marco-MiniLM-L-12-v2"
11
  JINA_MODEL = "jina-reranker-m0"
12
- JINA_API_KEY = "jina_4075150fa702471c85ddea0a9ad4b306ouE7ymhrCpvxTxX3mScUv5LLDPKQ"
13
  JINA_ENDPOINT = "https://api.jina.ai/v1/rerank"
14
- NV_MODEL = "NV-RerankQA-Mistral-4B-v3" # Hugging Face hosted
 
15
 
16
  # -------------------------------
17
  # Load models
@@ -26,7 +28,7 @@ def evaluate_models(query, docs_str):
26
  docs = ast.literal_eval(docs_str)
27
  assert isinstance(docs, list), "Input must be a Python list of strings"
28
  except Exception as e:
29
- return f"⚠️ Error parsing documents list: {e}"
30
 
31
  results = {}
32
 
@@ -36,27 +38,35 @@ def evaluate_models(query, docs_str):
36
  results["CrossEncoder (MS MARCO)"] = ce_rerank_scores
37
 
38
  # 2. Jina Reranker
39
- headers = {"Authorization": f"Bearer {JINA_API_KEY}", "Content-Type": "application/json"}
40
- payload = {"model": JINA_MODEL, "query": query, "documents": docs}
41
- try:
42
- r = requests.post(JINA_ENDPOINT, headers=headers, json=payload, timeout=30)
43
- r.raise_for_status()
44
- jina_scores = [round(res["relevance_score"], 4) for res in r.json()["results"]]
45
- results["Jina Reranker"] = jina_scores
46
- except Exception as e:
47
- results["Jina Reranker"] = [f"Error: {e}"]
 
 
 
 
 
48
 
49
  # 3. NV RerankQA Mistral-4B-v3 (HF Inference API)
50
- try:
51
- hf_endpoint = f"https://api-inference.huggingface.co/models/{NV_MODEL}"
52
- headers = {"Authorization": f"Bearer YOUR_HF_API_KEY"}
53
- payload = {"inputs": {"query": query, "documents": docs}}
54
- r = requests.post(hf_endpoint, headers=headers, json=payload, timeout=60)
55
- r.raise_for_status()
56
- nv_scores = [round(res["score"], 4) for res in r.json()]
57
- results["NV-RerankQA-Mistral-4B-v3"] = nv_scores
58
- except Exception as e:
59
- results["NV-RerankQA-Mistral-4B-v3"] = [f"Error: {e}"]
 
 
 
60
 
61
  return results
62
 
@@ -64,13 +74,13 @@ def evaluate_models(query, docs_str):
64
  # Gradio UI
65
  # -------------------------------
66
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
67
- gr.Markdown("## 👑 Ranking Battle (Scores Only)\nOutputs only **scores** from the 3 models.")
68
 
69
  query = gr.Textbox(label="Query", lines=2, placeholder="Enter your search query...")
70
  docs = gr.Textbox(
71
  label="Documents (Python list)",
72
  lines=6,
73
- placeholder='Example: ["Doc one text", "Doc two text", "Doc three text"]'
74
  )
75
  out = gr.JSON(label="Model Scores")
76
 
 
3
  import torch
4
  import requests
5
  import ast
6
+ import os
7
 
8
  # -------------------------------
9
  # MODELS
10
  # -------------------------------
11
  CROSS_ENCODER_RERANK = "cross-encoder/ms-marco-MiniLM-L-12-v2"
12
  JINA_MODEL = "jina-reranker-m0"
13
+ JINA_API_KEY = os.getenv("JINA_API_KEY") # set in HF Space settings
14
  JINA_ENDPOINT = "https://api.jina.ai/v1/rerank"
15
+ NV_MODEL = "NV-RerankQA-Mistral-4B-v3"
16
+ HF_API_KEY = os.getenv("HF_API_KEY") # set in HF Space settings
17
 
18
  # -------------------------------
19
  # Load models
 
28
  docs = ast.literal_eval(docs_str)
29
  assert isinstance(docs, list), "Input must be a Python list of strings"
30
  except Exception as e:
31
+ return {"Error": f"⚠️ Error parsing documents list: {e}"}
32
 
33
  results = {}
34
 
 
38
  results["CrossEncoder (MS MARCO)"] = ce_rerank_scores
39
 
40
  # 2. Jina Reranker
41
+ if JINA_API_KEY:
42
+ headers = {"Authorization": f"Bearer {JINA_API_KEY}", "Content-Type": "application/json"}
43
+ payload = {"model": JINA_MODEL, "query": query, "documents": docs}
44
+ try:
45
+ r = requests.post(JINA_ENDPOINT, headers=headers, json=payload, timeout=30)
46
+ r.raise_for_status()
47
+ jina_scores = [0] * len(docs)
48
+ for res in r.json()["results"]:
49
+ jina_scores[res["index"]] = round(res["relevance_score"], 4)
50
+ results["Jina Reranker"] = jina_scores
51
+ except Exception as e:
52
+ results["Jina Reranker"] = [f"Error: {e}"]
53
+ else:
54
+ results["Jina Reranker"] = ["Error: Missing JINA_API_KEY"]
55
 
56
  # 3. NV RerankQA Mistral-4B-v3 (HF Inference API)
57
+ if HF_API_KEY:
58
+ try:
59
+ hf_endpoint = f"https://api-inference.huggingface.co/models/{NV_MODEL}"
60
+ headers = {"Authorization": f"Bearer {HF_API_KEY}"}
61
+ payload = {"inputs": {"query": query, "documents": docs}}
62
+ r = requests.post(hf_endpoint, headers=headers, json=payload, timeout=60)
63
+ r.raise_for_status()
64
+ nv_scores = [round(res["score"], 4) for res in r.json()]
65
+ results["NV-RerankQA-Mistral-4B-v3"] = nv_scores
66
+ except Exception as e:
67
+ results["NV-RerankQA-Mistral-4B-v3"] = [f"Error: {e}"]
68
+ else:
69
+ results["NV-RerankQA-Mistral-4B-v3"] = ["Error: Missing HF_API_KEY"]
70
 
71
  return results
72
 
 
74
  # Gradio UI
75
  # -------------------------------
76
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
77
+ gr.Markdown("## 👑 Ranking Battle (Aligned Scores)\nOutputs only **scores aligned to input docs** from 3 models.")
78
 
79
  query = gr.Textbox(label="Query", lines=2, placeholder="Enter your search query...")
80
  docs = gr.Textbox(
81
  label="Documents (Python list)",
82
  lines=6,
83
+ placeholder='Example: [\"Doc one text\", \"Doc two text\", \"Doc three text\"]'
84
  )
85
  out = gr.JSON(label="Model Scores")
86