# ConicAI_LLM_model/infer_cloud.py
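"""Cloud inference helper for the ConicAI model.

Sends a prompt to a Hugging Face dedicated Inference Endpoint (--endpoint-url)
or to the hosted Inference API for a repo (--repo-id) and prints a structured
JSON result. If cloud inference fails, it can fall back to running
infer_local.py against a local model path unless --no-local-fallback is set.

Example invocations (placeholder values):
    python infer_cloud.py --endpoint-url https://<your-endpoint> --prompt "..."
    python infer_cloud.py --repo-id <user>/<repo> --prompt "..."
"""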
import argparse
import json
import os
import subprocess
import sys
import time
from pathlib import Path

import requests
from huggingface_hub import InferenceClient, get_token

from infer_local import build_instruction_prompt, build_structured_result

# Keys a response must contain to be treated as an already-structured result.
REQUIRED_OUTPUT_KEYS = {
    "code",
    "explanation",
    "confidence",
    "important_tokens",
    "relevancy_score",
    "hallucination",
    "hallucination_check_reason",
    "latency_ms",
}
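# Illustrative shape of a structured result (placeholder values; the actual
# fields are produced by infer_local.build_structured_result):
# {
#   "code": "...",
#   "explanation": "...",
#   "confidence": 0.0,
#   "important_tokens": ["..."],
#   "relevancy_score": 0.0,
#   "hallucination": false,
#   "hallucination_check_reason": "...",
#   "latency_ms": 1234
# }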


def is_structured_result(payload):
    """Return True when the payload already contains every required output key."""
    return isinstance(payload, dict) and REQUIRED_OUTPUT_KEYS.issubset(payload.keys())


def normalize_hf_response(response):
    """Coerce the various Hugging Face response shapes into a plain text string."""
    if is_structured_result(response):
        return json.dumps(response, ensure_ascii=False)
    if isinstance(response, str):
        return response
    generated_text = getattr(response, "generated_text", None)
    if generated_text is not None:
        return generated_text
    if isinstance(response, list) and response:
        first = response[0]
        if isinstance(first, dict):
            return str(first.get("generated_text", ""))
        return str(first)
    if isinstance(response, dict):
        if "code" in response and "explanation" in response:
            return json.dumps(response, ensure_ascii=False)
        return str(response.get("generated_text", response.get("text", "")))
    return str(response)


def call_direct_inference_api(repo_id, token, prompt_text, generation_kwargs):
    """POST the prompt to the hosted Inference API for the repo and return the parsed body."""
    headers = {}
    if token:
        headers["Authorization"] = f"Bearer {token}"
    payload = {
        "inputs": prompt_text,
        "parameters": generation_kwargs,
        "options": {"wait_for_model": True},
    }
    response = requests.post(
        f"https://api-inference.huggingface.co/models/{repo_id}",
        headers=headers,
        json=payload,
        timeout=120,
    )
    try:
        body = response.json()
    except ValueError:
        body = response.text
    if response.status_code >= 400:
        raise RuntimeError(f"Hugging Face API error {response.status_code}: {body}")
    if isinstance(body, dict) and body.get("error"):
        raise RuntimeError(f"Hugging Face API error: {body['error']}")
    return body


def call_endpoint_url(endpoint_url, token, user_prompt, generation_kwargs):
    """POST the prompt to a dedicated Inference Endpoint and return the parsed body."""
    headers = {"Content-Type": "application/json"}
    if token:
        headers["Authorization"] = f"Bearer {token}"
    payload = {
        "inputs": user_prompt,
        "parameters": generation_kwargs,
        "options": {"wait_for_model": True},
    }
    response = requests.post(endpoint_url, headers=headers, json=payload, timeout=180)
    try:
        body = response.json()
    except ValueError:
        body = response.text
    if response.status_code >= 400:
        raise RuntimeError(f"Endpoint API error {response.status_code}: {body}")
    if isinstance(body, dict) and body.get("error"):
        raise RuntimeError(f"Endpoint API error: {body['error']}")
    return body


def run_local_fallback(args, reason):
    """Run infer_local.py against the fallback model path, or raise if no fallback is configured."""
    if not args.fallback_model_path:
        raise RuntimeError(reason)
    if not os.path.exists(args.fallback_model_path):
        raise RuntimeError(
            f"{reason}\nLocal fallback model path not found: {args.fallback_model_path}"
        )
    print(
        (
            "Warning: Hugging Face cloud inference could not serve this repo. "
            f"Falling back to local model path '{args.fallback_model_path}'. Reason: {reason}"
        ),
        file=sys.stderr,
    )
    script_path = Path(__file__).resolve().with_name("infer_local.py")
    cmd = [
        sys.executable,
        str(script_path),
        "--model-path",
        args.fallback_model_path,
        "--prompt",
        args.prompt,
        "--max-new-tokens",
        str(args.max_new_tokens),
    ]
    if args.do_sample:
        cmd.extend(
            [
                "--do-sample",
                "--temperature",
                str(args.temperature),
                "--top-p",
                str(args.top_p),
            ]
        )
    if args.allow_downloads:
        cmd.append("--allow-downloads")
    completed = subprocess.run(cmd, check=True, text=True, capture_output=True)
    if completed.stderr:
        print(completed.stderr, file=sys.stderr, end="")
    print(completed.stdout, end="")


def main():
    """Parse CLI arguments, run cloud inference, and print a structured JSON result."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--repo-id", type=str, default="")
    parser.add_argument(
        "--endpoint-url",
        type=str,
        default=os.getenv("HF_ENDPOINT_URL", ""),
        help="Dedicated inference endpoint URL. Use this for true cloud inference.",
    )
    parser.add_argument("--prompt", type=str, required=True)
    parser.add_argument("--token", type=str, default=os.getenv("HF_TOKEN"))
    parser.add_argument(
        "--fallback-model-path",
        type=str,
        default="model",
        help="Local model path used when Hugging Face cannot serve the repo.",
    )
    parser.add_argument(
        "--no-local-fallback",
        action="store_true",
        help="Fail instead of running local fallback when cloud inference is unavailable.",
    )
    parser.add_argument("--max-new-tokens", type=int, default=320)
    parser.add_argument("--temperature", type=float, default=0.25)
    parser.add_argument("--top-p", type=float, default=0.9)
    parser.add_argument("--do-sample", action="store_true")
    parser.add_argument(
        "--allow-downloads",
        action="store_true",
        help="Allow local fallback inference to download missing model files.",
    )
    args = parser.parse_args()
    # Clearing the fallback path disables local fallback inside run_local_fallback.
    if args.no_local_fallback:
        args.fallback_model_path = ""
    if not args.repo_id and not args.endpoint_url:
        raise ValueError("Pass --repo-id or --endpoint-url.")
    token = args.token or get_token()
    prompt_text = build_instruction_prompt(args.prompt)
    generation_kwargs = {
        "max_new_tokens": args.max_new_tokens,
        "return_full_text": False,
    }
    if args.do_sample:
        generation_kwargs["temperature"] = args.temperature
        generation_kwargs["top_p"] = args.top_p
    else:
        # Use a near-zero temperature when sampling is disabled.
        generation_kwargs["temperature"] = 0.01
    start_time = time.perf_counter()
    if args.endpoint_url:
        # Dedicated endpoint path: the raw user prompt is sent as-is.
        try:
            response = call_endpoint_url(args.endpoint_url, token, args.prompt, generation_kwargs)
        except Exception as exc:
            run_local_fallback(args, str(exc))
            return
    else:
        # Hosted Inference API via InferenceClient, with a direct HTTP call as backup.
        client = InferenceClient(model=args.repo_id, token=token)
        try:
            response = client.text_generation(prompt_text, **generation_kwargs)
        except TypeError:
            # Some client versions do not accept return_full_text; retry without it.
            generation_kwargs.pop("return_full_text", None)
            try:
                response = client.text_generation(prompt_text, **generation_kwargs)
            except Exception as exc:
                try:
                    response = call_direct_inference_api(
                        args.repo_id, token, prompt_text, generation_kwargs
                    )
                except Exception as direct_exc:
                    run_local_fallback(args, f"{exc}; direct API fallback failed: {direct_exc}")
                    return
        except Exception as exc:
            try:
                response = call_direct_inference_api(args.repo_id, token, prompt_text, generation_kwargs)
            except Exception as direct_exc:
                run_local_fallback(args, f"{exc}; direct API fallback failed: {direct_exc}")
                return
    latency_ms = int((time.perf_counter() - start_time) * 1000)
    if is_structured_result(response):
        # The cloud response is already a structured result; pass it through unchanged.
        print(json.dumps(response, indent=2, ensure_ascii=False))
        return
    generated_text = normalize_hf_response(response).strip()
    if generated_text.startswith(prompt_text):
        generated_text = generated_text[len(prompt_text) :].strip()
    generated_text = generated_text.replace("<|im_end|>", "").strip()
    result = build_structured_result(
        args.prompt,
        generated_text,
        latency_ms,
        default_confidence=0.0,
    )
    print(json.dumps(result, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    try:
        main()
    except (RuntimeError, ValueError) as exc:
        # Emit a machine-readable error payload on stderr and exit non-zero.
        print(
            json.dumps(
                {
                    "error": "Cloud inference request failed.",
                    "reason": str(exc),
                    "cloud_available": False,
                    "hint": (
                        "Pass --repo-id for development fallback mode, or pass "
                        "--endpoint-url for a deployed Hugging Face Dedicated "
                        "Inference Endpoint."
                    ),
                },
                indent=2,
                ensure_ascii=False,
            ),
            file=sys.stderr,
        )
        sys.exit(1)