girish00 committed on
Commit
b330ff5
·
verified ·
1 Parent(s): f204dba

update endpoint helper files

Browse files
Files changed (1) hide show
  1. infer_cloud.py +38 -4
infer_cloud.py CHANGED
@@ -4,6 +4,7 @@ import os
4
  import subprocess
5
  import sys
6
  import time
 
7
 
8
  import requests
9
  from huggingface_hub import InferenceClient, get_token
@@ -11,7 +12,26 @@ from huggingface_hub import InferenceClient, get_token
11
  from infer_local import build_instruction_prompt, build_structured_result
12
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def normalize_hf_response(response):
 
 
 
15
  if isinstance(response, str):
16
  return response
17
 
@@ -26,6 +46,8 @@ def normalize_hf_response(response):
26
  return str(first)
27
 
28
  if isinstance(response, dict):
 
 
29
  return str(response.get("generated_text", response.get("text", "")))
30
 
31
  return str(response)
@@ -59,13 +81,13 @@ def call_direct_inference_api(repo_id, token, prompt_text, generation_kwargs):
59
  return body
60
 
61
 
62
- def call_endpoint_url(endpoint_url, token, prompt_text, generation_kwargs):
63
  headers = {"Content-Type": "application/json"}
64
  if token:
65
  headers["Authorization"] = f"Bearer {token}"
66
 
67
  payload = {
68
- "inputs": prompt_text,
69
  "parameters": generation_kwargs,
70
  "options": {"wait_for_model": True},
71
  }
@@ -97,9 +119,10 @@ def run_local_fallback(args, reason):
97
  ),
98
  file=sys.stderr,
99
  )
 
100
  cmd = [
101
  sys.executable,
102
- "infer_local.py",
103
  "--model-path",
104
  args.fallback_model_path,
105
  "--prompt",
@@ -117,6 +140,8 @@ def run_local_fallback(args, reason):
117
  str(args.top_p),
118
  ]
119
  )
 
 
120
  completed = subprocess.run(cmd, check=True, text=True, capture_output=True)
121
  if completed.stderr:
122
  print(completed.stderr, file=sys.stderr, end="")
@@ -149,6 +174,11 @@ def main():
149
  parser.add_argument("--temperature", type=float, default=0.25)
150
  parser.add_argument("--top-p", type=float, default=0.9)
151
  parser.add_argument("--do-sample", action="store_true")
 
 
 
 
 
152
  args = parser.parse_args()
153
  if args.no_local_fallback:
154
  args.fallback_model_path = ""
@@ -171,7 +201,7 @@ def main():
171
  start_time = time.perf_counter()
172
  if args.endpoint_url:
173
  try:
174
- response = call_endpoint_url(args.endpoint_url, token, prompt_text, generation_kwargs)
175
  except Exception as exc:
176
  run_local_fallback(args, str(exc))
177
  return
@@ -199,6 +229,10 @@ def main():
199
  return
200
  latency_ms = int((time.perf_counter() - start_time) * 1000)
201
 
 
 
 
 
202
  generated_text = normalize_hf_response(response).strip()
203
  if generated_text.startswith(prompt_text):
204
  generated_text = generated_text[len(prompt_text) :].strip()
 
4
  import subprocess
5
  import sys
6
  import time
7
+ from pathlib import Path
8
 
9
  import requests
10
  from huggingface_hub import InferenceClient, get_token
 
12
  from infer_local import build_instruction_prompt, build_structured_result
13
 
14
 
15
+ REQUIRED_OUTPUT_KEYS = {
16
+ "code",
17
+ "explanation",
18
+ "confidence",
19
+ "important_tokens",
20
+ "relevancy_score",
21
+ "hallucination",
22
+ "hallucination_check_reason",
23
+ "latency_ms",
24
+ }
25
+
26
+
27
+ def is_structured_result(payload):
28
+ return isinstance(payload, dict) and REQUIRED_OUTPUT_KEYS.issubset(payload.keys())
29
+
30
+
31
  def normalize_hf_response(response):
32
+ if is_structured_result(response):
33
+ return json.dumps(response, ensure_ascii=False)
34
+
35
  if isinstance(response, str):
36
  return response
37
 
 
46
  return str(first)
47
 
48
  if isinstance(response, dict):
49
+ if "code" in response and "explanation" in response:
50
+ return json.dumps(response, ensure_ascii=False)
51
  return str(response.get("generated_text", response.get("text", "")))
52
 
53
  return str(response)
 
81
  return body
82
 
83
 
84
+ def call_endpoint_url(endpoint_url, token, user_prompt, generation_kwargs):
85
  headers = {"Content-Type": "application/json"}
86
  if token:
87
  headers["Authorization"] = f"Bearer {token}"
88
 
89
  payload = {
90
+ "inputs": user_prompt,
91
  "parameters": generation_kwargs,
92
  "options": {"wait_for_model": True},
93
  }
 
119
  ),
120
  file=sys.stderr,
121
  )
122
+ script_path = Path(__file__).resolve().with_name("infer_local.py")
123
  cmd = [
124
  sys.executable,
125
+ str(script_path),
126
  "--model-path",
127
  args.fallback_model_path,
128
  "--prompt",
 
140
  str(args.top_p),
141
  ]
142
  )
143
+ if args.allow_downloads:
144
+ cmd.append("--allow-downloads")
145
  completed = subprocess.run(cmd, check=True, text=True, capture_output=True)
146
  if completed.stderr:
147
  print(completed.stderr, file=sys.stderr, end="")
 
174
  parser.add_argument("--temperature", type=float, default=0.25)
175
  parser.add_argument("--top-p", type=float, default=0.9)
176
  parser.add_argument("--do-sample", action="store_true")
177
+ parser.add_argument(
178
+ "--allow-downloads",
179
+ action="store_true",
180
+ help="Allow local fallback inference to download missing model files.",
181
+ )
182
  args = parser.parse_args()
183
  if args.no_local_fallback:
184
  args.fallback_model_path = ""
 
201
  start_time = time.perf_counter()
202
  if args.endpoint_url:
203
  try:
204
+ response = call_endpoint_url(args.endpoint_url, token, args.prompt, generation_kwargs)
205
  except Exception as exc:
206
  run_local_fallback(args, str(exc))
207
  return
 
229
  return
230
  latency_ms = int((time.perf_counter() - start_time) * 1000)
231
 
232
+ if is_structured_result(response):
233
+ print(json.dumps(response, indent=2, ensure_ascii=False))
234
+ return
235
+
236
  generated_text = normalize_hf_response(response).strip()
237
  if generated_text.startswith(prompt_text):
238
  generated_text = generated_text[len(prompt_text) :].strip()