cmboulanger committed on
Commit
22d2534
·
1 Parent(s): fe052dd

Add timeout parameter

Browse files
Files changed (1) hide show
  1. scripts/evaluate_llm.py +19 -7
scripts/evaluate_llm.py CHANGED
@@ -70,11 +70,11 @@ _load_env(_REPO / ".env")
70
  # ---------------------------------------------------------------------------
71
 
72
 
73
- def _post_json(url: str, payload: dict, headers: dict) -> dict:
74
  body = json.dumps(payload).encode()
75
  req = urllib.request.Request(url, data=body, headers=headers, method="POST")
76
  try:
77
- with urllib.request.urlopen(req, timeout=120) as resp:
78
  return json.loads(resp.read())
79
  except urllib.error.HTTPError as exc:
80
  detail = exc.read().decode(errors="replace")
@@ -86,7 +86,7 @@ def _post_json(url: str, payload: dict, headers: dict) -> dict:
86
  # ---------------------------------------------------------------------------
87
 
88
 
89
- def make_gemini_call_fn(api_key: str, model: str = "gemini-2.0-flash"):
90
  """Return a call_fn that sends a prompt to Gemini and returns the text reply."""
91
  url = (
92
  f"https://generativelanguage.googleapis.com/v1beta/models"
@@ -98,7 +98,7 @@ def make_gemini_call_fn(api_key: str, model: str = "gemini-2.0-flash"):
98
  "contents": [{"parts": [{"text": prompt}]}],
99
  "generationConfig": {"temperature": 0.1},
100
  }
101
- result = _post_json(url, payload, {"Content-Type": "application/json"})
102
  return result["candidates"][0]["content"]["parts"][0]["text"]
103
 
104
  call_fn.__name__ = f"gemini/{model}"
@@ -109,6 +109,7 @@ def make_kisski_call_fn(
109
  api_key: str,
110
  base_url: str = "https://chat-ai.academiccloud.de/v1",
111
  model: str = "llama-3.3-70b-instruct",
 
112
  ):
113
  """Return a call_fn for a KISSKI-hosted OpenAI-compatible model."""
114
  url = f"{base_url}/chat/completions"
@@ -123,7 +124,7 @@ def make_kisski_call_fn(
123
  "messages": [{"role": "user", "content": prompt}],
124
  "temperature": 0.1,
125
  }
126
- result = _post_json(url, payload, headers)
127
  return result["choices"][0]["message"]["content"]
128
 
129
  call_fn.__name__ = f"kisski/{model}"
@@ -735,6 +736,17 @@ def _parse_args() -> argparse.Namespace:
735
  "(\"lost in the middle\" effect for items in large batches)."
736
  ),
737
  )
 
 
 
 
 
 
 
 
 
 
 
738
  return p.parse_args()
739
 
740
 
@@ -752,7 +764,7 @@ def main() -> int:
752
  if args.provider == "gemini":
753
  return 1
754
  else:
755
- providers.append(("Gemini 2.0 Flash", make_gemini_call_fn(gemini_key)))
756
 
757
  if args.provider in ("kisski", "all"):
758
  if not kisski_key:
@@ -761,7 +773,7 @@ def main() -> int:
761
  return 1
762
  else:
763
  providers.append(
764
- ("KISSKI / llama-3.3-70b-instruct", make_kisski_call_fn(kisski_key))
765
  )
766
 
767
  if not providers:
 
70
  # ---------------------------------------------------------------------------
71
 
72
 
73
+ def _post_json(url: str, payload: dict, headers: dict, timeout: int = 120) -> dict:
74
  body = json.dumps(payload).encode()
75
  req = urllib.request.Request(url, data=body, headers=headers, method="POST")
76
  try:
77
+ with urllib.request.urlopen(req, timeout=timeout) as resp:
78
  return json.loads(resp.read())
79
  except urllib.error.HTTPError as exc:
80
  detail = exc.read().decode(errors="replace")
 
86
  # ---------------------------------------------------------------------------
87
 
88
 
89
+ def make_gemini_call_fn(api_key: str, model: str = "gemini-2.0-flash", timeout: int = 120):
90
  """Return a call_fn that sends a prompt to Gemini and returns the text reply."""
91
  url = (
92
  f"https://generativelanguage.googleapis.com/v1beta/models"
 
98
  "contents": [{"parts": [{"text": prompt}]}],
99
  "generationConfig": {"temperature": 0.1},
100
  }
101
+ result = _post_json(url, payload, {"Content-Type": "application/json"}, timeout)
102
  return result["candidates"][0]["content"]["parts"][0]["text"]
103
 
104
  call_fn.__name__ = f"gemini/{model}"
 
109
  api_key: str,
110
  base_url: str = "https://chat-ai.academiccloud.de/v1",
111
  model: str = "llama-3.3-70b-instruct",
112
+ timeout: int = 120,
113
  ):
114
  """Return a call_fn for a KISSKI-hosted OpenAI-compatible model."""
115
  url = f"{base_url}/chat/completions"
 
124
  "messages": [{"role": "user", "content": prompt}],
125
  "temperature": 0.1,
126
  }
127
+ result = _post_json(url, payload, headers, timeout)
128
  return result["choices"][0]["message"]["content"]
129
 
130
  call_fn.__name__ = f"kisski/{model}"
 
736
  "(\"lost in the middle\" effect for items in large batches)."
737
  ),
738
  )
739
+ p.add_argument(
740
+ "--timeout",
741
+ type=int,
742
+ default=120,
743
+ metavar="SECONDS",
744
+ help=(
745
+ "HTTP read timeout in seconds for each LLM API call. "
746
+ "Default=120. Increase when using large batch sizes with slow models "
747
+ "(e.g. --timeout 600 --batch-size 10 for KISSKI Llama)."
748
+ ),
749
+ )
750
  return p.parse_args()
751
 
752
 
 
764
  if args.provider == "gemini":
765
  return 1
766
  else:
767
+ providers.append(("Gemini 2.0 Flash", make_gemini_call_fn(gemini_key, timeout=args.timeout)))
768
 
769
  if args.provider in ("kisski", "all"):
770
  if not kisski_key:
 
773
  return 1
774
  else:
775
  providers.append(
776
+ ("KISSKI / llama-3.3-70b-instruct", make_kisski_call_fn(kisski_key, timeout=args.timeout))
777
  )
778
 
779
  if not providers: