girish00 committed on
Commit
eb75868
·
verified ·
1 Parent(s): c29fe9d

add cloud structured inference wrapper

Browse files
Files changed (1) hide show
  1. infer_cloud.py +110 -0
infer_cloud.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import time
5
+
6
+ import requests
7
+ from huggingface_hub import InferenceClient, get_token
8
+
9
+ from infer_local import build_instruction_prompt, build_structured_result
10
+
11
+
12
def normalize_hf_response(response):
    """Coerce a text-generation response into a plain string.

    Handles the response shapes this module produces or receives:

    * a ``str`` is returned unchanged;
    * an object exposing a non-``None`` ``generated_text`` attribute yields
      that attribute;
    * a ``list`` is represented by its first element, which is normalized
      with these same rules;
    * a ``dict`` yields its ``"generated_text"`` value, falling back to
      ``"text"``;
    * anything else is passed through ``str()``.

    Empty or absent payloads (``None``, ``[]``, or a dict whose text fields
    are missing or ``None``) produce ``""`` so that placeholder reprs such as
    ``"[]"`` or ``"None"`` are never mistaken for generated text.

    Args:
        response: The raw value returned by the inference client or by the
            direct HTTP call.

    Returns:
        The generated text, or ``""`` when the payload carries none.
    """
    if response is None:
        return ""
    if isinstance(response, str):
        return response

    generated_text = getattr(response, "generated_text", None)
    if generated_text is not None:
        return generated_text

    if isinstance(response, list):
        # Only the first candidate is used; an empty batch has no text at all.
        return normalize_hf_response(response[0]) if response else ""

    if isinstance(response, dict):
        for key in ("generated_text", "text"):
            value = response.get(key)
            if value is not None:
                return str(value)
        return ""

    return str(response)
30
+
31
+
32
def call_direct_inference_api(repo_id, token, prompt_text, generation_kwargs):
    """POST a prompt to the hosted inference endpoint for ``repo_id``.

    Args:
        repo_id: Model repository identifier, interpolated into the URL.
        token: Bearer token; when falsy, no ``Authorization`` header is sent.
        prompt_text: Text sent as the request's ``inputs`` field.
        generation_kwargs: Mapping sent as the request's ``parameters`` field.

    Returns:
        The response body parsed as JSON, or the raw response text when the
        body is not valid JSON.

    Raises:
        RuntimeError: If the HTTP status is 400 or above, or if the JSON body
            is an object with a truthy ``"error"`` entry.
    """
    auth_headers = {"Authorization": f"Bearer {token}"} if token else {}
    endpoint = f"https://api-inference.huggingface.co/models/{repo_id}"
    request_body = {
        "inputs": prompt_text,
        "parameters": generation_kwargs,
        "options": {"wait_for_model": True},
    }

    reply = requests.post(
        endpoint,
        headers=auth_headers,
        json=request_body,
        timeout=120,
    )

    # Decode before checking the status so error messages can include the body.
    try:
        decoded = reply.json()
    except ValueError:
        decoded = reply.text

    status = reply.status_code
    if status >= 400:
        raise RuntimeError(f"Hugging Face API error {status}: {decoded}")
    if isinstance(decoded, dict) and decoded.get("error"):
        raise RuntimeError(f"Hugging Face API error: {decoded['error']}")
    return decoded
58
+
59
+
60
def main():
    """Run one cloud inference request and print a structured JSON result.

    Parses the command-line options, builds the instruction prompt with
    ``build_instruction_prompt``, and requests a completion through
    ``InferenceClient.text_generation``. If the client call fails, the request
    is re-sent with ``call_direct_inference_api``. The generated text is
    cleaned (echoed prompt and ``<|im_end|>`` markers removed), passed to
    ``build_structured_result`` together with the measured latency, and the
    result is printed to stdout as indented JSON.

    The measured latency spans every attempt, including failed ones.
    """
    import sys  # local import: only needed to report the fallback on stderr

    parser = argparse.ArgumentParser()
    parser.add_argument("--repo-id", type=str, required=True)
    parser.add_argument("--prompt", type=str, required=True)
    parser.add_argument("--token", type=str, default=os.getenv("HF_TOKEN"))
    parser.add_argument("--max-new-tokens", type=int, default=320)
    parser.add_argument("--temperature", type=float, default=0.25)
    parser.add_argument("--top-p", type=float, default=0.9)
    parser.add_argument("--do-sample", action="store_true")
    args = parser.parse_args()

    # Explicit flag / HF_TOKEN first, then whatever huggingface_hub has stored.
    token = args.token or get_token()
    client = InferenceClient(model=args.repo_id, token=token)
    prompt_text = build_instruction_prompt(args.prompt)

    generation_kwargs = {
        "max_new_tokens": args.max_new_tokens,
        "return_full_text": False,
    }
    if args.do_sample:
        generation_kwargs["temperature"] = args.temperature
        generation_kwargs["top_p"] = args.top_p
    else:
        # NOTE(review): presumably a near-zero temperature stands in for
        # greedy decoding on backends that reject 0 -- confirm.
        generation_kwargs["temperature"] = 0.01

    start_time = time.perf_counter()
    try:
        try:
            response = client.text_generation(prompt_text, **generation_kwargs)
        except TypeError:
            # NOTE(review): presumably some client versions do not accept
            # ``return_full_text`` -- confirm. Retry without it, using a copy
            # so the direct-API fallback below still sends the full parameters.
            retry_kwargs = {
                key: value
                for key, value in generation_kwargs.items()
                if key != "return_full_text"
            }
            response = client.text_generation(prompt_text, **retry_kwargs)
    except Exception as exc:
        # The retry sits inside this try so that its failure also reaches the
        # fallback; handlers of a single try statement do not cover each other.
        # Report on stderr so stdout remains pure JSON.
        print(
            f"InferenceClient request failed ({exc!r}); "
            "falling back to the direct inference API.",
            file=sys.stderr,
        )
        response = call_direct_inference_api(
            args.repo_id, token, prompt_text, generation_kwargs
        )
    latency_ms = int((time.perf_counter() - start_time) * 1000)

    generated_text = normalize_hf_response(response).strip()
    # Some backends echo the prompt ahead of the completion; drop it.
    if generated_text.startswith(prompt_text):
        generated_text = generated_text[len(prompt_text):].strip()
    generated_text = generated_text.replace("<|im_end|>", "").strip()

    result = build_structured_result(
        args.prompt,
        generated_text,
        latency_ms,
        default_confidence=0.0,
    )
    print(json.dumps(result, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    main()