girish00 committed on
Commit
dc14a91
·
verified ·
1 Parent(s): 7e079b1

add dedicated endpoint cloud mode

Browse files
Files changed (1) hide show
  1. infer_cloud.py +53 -15
infer_cloud.py CHANGED
@@ -59,6 +59,29 @@ def call_direct_inference_api(repo_id, token, prompt_text, generation_kwargs):
59
  return body
60
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def run_local_fallback(args, reason):
63
  if not args.fallback_model_path:
64
  raise RuntimeError(reason)
@@ -102,7 +125,13 @@ def run_local_fallback(args, reason):
102
 
103
  def main():
104
  parser = argparse.ArgumentParser()
105
- parser.add_argument("--repo-id", type=str, required=True)
 
 
 
 
 
 
106
  parser.add_argument("--prompt", type=str, required=True)
107
  parser.add_argument("--token", type=str, default=os.getenv("HF_TOKEN"))
108
  parser.add_argument(
@@ -123,9 +152,10 @@ def main():
123
  args = parser.parse_args()
124
  if args.no_local_fallback:
125
  args.fallback_model_path = ""
 
 
126
 
127
  token = args.token or get_token()
128
- client = InferenceClient(model=args.repo_id, token=token)
129
  prompt_text = build_instruction_prompt(args.prompt)
130
 
131
  generation_kwargs = {
@@ -139,26 +169,34 @@ def main():
139
  generation_kwargs["temperature"] = 0.01
140
 
141
  start_time = time.perf_counter()
142
- try:
143
- response = client.text_generation(prompt_text, **generation_kwargs)
144
- except TypeError:
145
- generation_kwargs.pop("return_full_text", None)
 
 
 
 
146
  try:
147
  response = client.text_generation(prompt_text, **generation_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
148
  except Exception as exc:
149
  try:
150
- response = call_direct_inference_api(
151
- args.repo_id, token, prompt_text, generation_kwargs
152
- )
153
  except Exception as direct_exc:
154
  run_local_fallback(args, f"{exc}; direct API fallback failed: {direct_exc}")
155
  return
156
- except Exception as exc:
157
- try:
158
- response = call_direct_inference_api(args.repo_id, token, prompt_text, generation_kwargs)
159
- except Exception as direct_exc:
160
- run_local_fallback(args, f"{exc}; direct API fallback failed: {direct_exc}")
161
- return
162
  latency_ms = int((time.perf_counter() - start_time) * 1000)
163
 
164
  generated_text = normalize_hf_response(response).strip()
 
59
  return body
60
 
61
 
62
def call_endpoint_url(endpoint_url, token, prompt_text, generation_kwargs):
    """POST a text-generation request to a dedicated inference endpoint.

    The request body is JSON with ``prompt_text`` under ``inputs``,
    ``generation_kwargs`` under ``parameters``, and ``wait_for_model`` enabled
    under ``options``. A bearer ``Authorization`` header is attached only when
    ``token`` is truthy.

    Returns:
        The decoded JSON body of the response, or the raw response text when
        the body cannot be decoded as JSON.

    Raises:
        RuntimeError: if the HTTP status code is 400 or higher, or if the
            decoded body is a dict with a truthy ``"error"`` entry.
    """
    request_headers = {"Content-Type": "application/json"}
    if token:
        request_headers["Authorization"] = f"Bearer {token}"

    reply = requests.post(
        endpoint_url,
        headers=request_headers,
        json={
            "inputs": prompt_text,
            "parameters": generation_kwargs,
            "options": {"wait_for_model": True},
        },
        timeout=180,
    )

    # Error responses are not guaranteed to be JSON; keep the raw text so it
    # can still be surfaced in the exception message below.
    try:
        decoded = reply.json()
    except ValueError:
        decoded = reply.text

    if reply.status_code >= 400:
        raise RuntimeError(f"Endpoint API error {reply.status_code}: {decoded}")
    # A response below 400 may still report a failure inside the JSON body.
    if isinstance(decoded, dict) and decoded.get("error"):
        raise RuntimeError(f"Endpoint API error: {decoded['error']}")
    return decoded
83
+
84
+
85
  def run_local_fallback(args, reason):
86
  if not args.fallback_model_path:
87
  raise RuntimeError(reason)
 
125
 
126
  def main():
127
  parser = argparse.ArgumentParser()
128
+ parser.add_argument("--repo-id", type=str, default="")
129
+ parser.add_argument(
130
+ "--endpoint-url",
131
+ type=str,
132
+ default=os.getenv("HF_ENDPOINT_URL", ""),
133
+ help="Dedicated inference endpoint URL. Use this for true cloud inference.",
134
+ )
135
  parser.add_argument("--prompt", type=str, required=True)
136
  parser.add_argument("--token", type=str, default=os.getenv("HF_TOKEN"))
137
  parser.add_argument(
 
152
  args = parser.parse_args()
153
  if args.no_local_fallback:
154
  args.fallback_model_path = ""
155
+ if not args.repo_id and not args.endpoint_url:
156
+ raise ValueError("Pass --repo-id or --endpoint-url.")
157
 
158
  token = args.token or get_token()
 
159
  prompt_text = build_instruction_prompt(args.prompt)
160
 
161
  generation_kwargs = {
 
169
  generation_kwargs["temperature"] = 0.01
170
 
171
  start_time = time.perf_counter()
172
+ if args.endpoint_url:
173
+ try:
174
+ response = call_endpoint_url(args.endpoint_url, token, prompt_text, generation_kwargs)
175
+ except Exception as exc:
176
+ run_local_fallback(args, str(exc))
177
+ return
178
+ else:
179
+ client = InferenceClient(model=args.repo_id, token=token)
180
  try:
181
  response = client.text_generation(prompt_text, **generation_kwargs)
182
+ except TypeError:
183
+ generation_kwargs.pop("return_full_text", None)
184
+ try:
185
+ response = client.text_generation(prompt_text, **generation_kwargs)
186
+ except Exception as exc:
187
+ try:
188
+ response = call_direct_inference_api(
189
+ args.repo_id, token, prompt_text, generation_kwargs
190
+ )
191
+ except Exception as direct_exc:
192
+ run_local_fallback(args, f"{exc}; direct API fallback failed: {direct_exc}")
193
+ return
194
  except Exception as exc:
195
  try:
196
+ response = call_direct_inference_api(args.repo_id, token, prompt_text, generation_kwargs)
 
 
197
  except Exception as direct_exc:
198
  run_local_fallback(args, f"{exc}; direct API fallback failed: {direct_exc}")
199
  return
 
 
 
 
 
 
200
  latency_ms = int((time.perf_counter() - start_time) * 1000)
201
 
202
  generated_text = normalize_hf_response(response).strip()