"""Run cloud inference against a Hugging Face repo or a dedicated endpoint.

Falls back to the direct Inference API, and finally to the local
``infer_local.py`` script, when the primary route cannot serve the request.
"""

import argparse
import json
import os
import subprocess
import sys
import time
from pathlib import Path

import requests
from huggingface_hub import InferenceClient, get_token

from infer_local import build_instruction_prompt, build_structured_result

# Keys that a fully structured response payload must contain.
REQUIRED_OUTPUT_KEYS = {
    "code",
    "explanation",
    "confidence",
    "important_tokens",
    "relevancy_score",
    "hallucination",
    "hallucination_check_reason",
    "latency_ms",
}


def is_structured_result(payload):
    return isinstance(payload, dict) and REQUIRED_OUTPUT_KEYS.issubset(payload.keys())


def normalize_hf_response(response):
    """Coerce the many response shapes Hugging Face can return into a string."""
    if is_structured_result(response):
        return json.dumps(response, ensure_ascii=False)
    if isinstance(response, str):
        return response
    generated_text = getattr(response, "generated_text", None)
    if generated_text is not None:
        return generated_text
    if isinstance(response, list) and response:
        first = response[0]
        if isinstance(first, dict):
            return str(first.get("generated_text", ""))
        return str(first)
    if isinstance(response, dict):
        if "code" in response and "explanation" in response:
            return json.dumps(response, ensure_ascii=False)
        return str(response.get("generated_text", response.get("text", "")))
    return str(response)


def call_direct_inference_api(repo_id, token, prompt_text, generation_kwargs):
    """POST to the serverless Inference API directly, bypassing InferenceClient."""
    headers = {}
    if token:
        headers["Authorization"] = f"Bearer {token}"
    payload = {
        "inputs": prompt_text,
        "parameters": generation_kwargs,
        "options": {"wait_for_model": True},
    }
    response = requests.post(
        f"https://api-inference.huggingface.co/models/{repo_id}",
        headers=headers,
        json=payload,
        timeout=120,
    )
    try:
        body = response.json()
    except ValueError:
        body = response.text
    if response.status_code >= 400:
        raise RuntimeError(f"Hugging Face API error {response.status_code}: {body}")
    if isinstance(body, dict) and body.get("error"):
        raise RuntimeError(f"Hugging Face API error: {body['error']}")
    return body


def call_endpoint_url(endpoint_url, token, user_prompt, generation_kwargs):
    """POST to a dedicated inference endpoint.

    The endpoint receives the raw user prompt; a deployed handler is expected
    to apply its own prompt template and may return the structured payload
    directly.
    """
    headers = {"Content-Type": "application/json"}
    if token:
        headers["Authorization"] = f"Bearer {token}"
    payload = {
        "inputs": user_prompt,
        "parameters": generation_kwargs,
        "options": {"wait_for_model": True},
    }
    response = requests.post(endpoint_url, headers=headers, json=payload, timeout=180)
    try:
        body = response.json()
    except ValueError:
        body = response.text
    if response.status_code >= 400:
        raise RuntimeError(f"Endpoint API error {response.status_code}: {body}")
    if isinstance(body, dict) and body.get("error"):
        raise RuntimeError(f"Endpoint API error: {body['error']}")
    return body


def run_local_fallback(args, reason):
    if not args.fallback_model_path:
        raise RuntimeError(reason)
    if not os.path.exists(args.fallback_model_path):
        raise RuntimeError(
            f"{reason}\nLocal fallback model path not found: {args.fallback_model_path}"
        )
    print(
        (
            "Warning: Hugging Face cloud inference could not serve this repo. "
            f"Falling back to local model path '{args.fallback_model_path}'. "
            f"Reason: {reason}"
        ),
        file=sys.stderr,
    )
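    # Delegate to the sibling infer_local.py CLI in a subprocess so the
    # fallback prints exactly what a direct local run would.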
    script_path = Path(__file__).resolve().with_name("infer_local.py")
    cmd = [
        sys.executable,
        str(script_path),
        "--model-path",
        args.fallback_model_path,
        "--prompt",
        args.prompt,
        "--max-new-tokens",
        str(args.max_new_tokens),
    ]
    if args.do_sample:
        cmd.extend(
            [
                "--do-sample",
                "--temperature",
                str(args.temperature),
                "--top-p",
                str(args.top_p),
            ]
        )
    if args.allow_downloads:
        cmd.append("--allow-downloads")
    completed = subprocess.run(cmd, check=True, text=True, capture_output=True)
    if completed.stderr:
        print(completed.stderr, file=sys.stderr, end="")
    print(completed.stdout, end="")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--repo-id", type=str, default="")
    parser.add_argument(
        "--endpoint-url",
        type=str,
        default=os.getenv("HF_ENDPOINT_URL", ""),
        help="Dedicated inference endpoint URL. Use this for true cloud inference.",
    )
    parser.add_argument("--prompt", type=str, required=True)
    parser.add_argument("--token", type=str, default=os.getenv("HF_TOKEN"))
    parser.add_argument(
        "--fallback-model-path",
        type=str,
        default="model",
        help="Local model path used when Hugging Face cannot serve the repo.",
    )
    parser.add_argument(
        "--no-local-fallback",
        action="store_true",
        help="Fail instead of running local fallback when cloud inference is unavailable.",
    )
    parser.add_argument("--max-new-tokens", type=int, default=320)
    parser.add_argument("--temperature", type=float, default=0.25)
    parser.add_argument("--top-p", type=float, default=0.9)
    parser.add_argument("--do-sample", action="store_true")
    parser.add_argument(
        "--allow-downloads",
        action="store_true",
        help="Allow local fallback inference to download missing model files.",
    )
    args = parser.parse_args()

    if args.no_local_fallback:
        args.fallback_model_path = ""
    if not args.repo_id and not args.endpoint_url:
        raise ValueError("Pass --repo-id or --endpoint-url.")

    token = args.token or get_token()
    prompt_text = build_instruction_prompt(args.prompt)
    generation_kwargs = {
        "max_new_tokens": args.max_new_tokens,
        "return_full_text": False,
    }
    if args.do_sample:
        generation_kwargs["temperature"] = args.temperature
        generation_kwargs["top_p"] = args.top_p
    else:
        # Near-zero temperature approximates greedy decoding; some backends
        # reject temperature=0 outright.
        generation_kwargs["temperature"] = 0.01

    start_time = time.perf_counter()
    if args.endpoint_url:
        try:
            response = call_endpoint_url(args.endpoint_url, token, args.prompt, generation_kwargs)
        except Exception as exc:
            run_local_fallback(args, str(exc))
            return
    else:
        client = InferenceClient(model=args.repo_id, token=token)
        try:
            response = client.text_generation(prompt_text, **generation_kwargs)
        except TypeError:
            # Retry without return_full_text for client versions that do not
            # accept it as a keyword argument.
            generation_kwargs.pop("return_full_text", None)
            try:
                response = client.text_generation(prompt_text, **generation_kwargs)
            except Exception as exc:
                try:
                    response = call_direct_inference_api(
                        args.repo_id, token, prompt_text, generation_kwargs
                    )
                except Exception as direct_exc:
                    run_local_fallback(args, f"{exc}; direct API fallback failed: {direct_exc}")
                    return
        except Exception as exc:
            try:
                response = call_direct_inference_api(args.repo_id, token, prompt_text, generation_kwargs)
            except Exception as direct_exc:
                run_local_fallback(args, f"{exc}; direct API fallback failed: {direct_exc}")
                return
    latency_ms = int((time.perf_counter() - start_time) * 1000)

    if is_structured_result(response):
        print(json.dumps(response, indent=2, ensure_ascii=False))
        return

    generated_text = normalize_hf_response(response).strip()
    if generated_text.startswith(prompt_text):
        generated_text = generated_text[len(prompt_text) :].strip()
    generated_text = generated_text.replace("<|im_end|>", "").strip()
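    # At this point generated_text holds only the model completion: any echoed
    # prompt (returned when return_full_text is unsupported) and ChatML
    # <|im_end|> markers have been stripped.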
"").strip() result = build_structured_result( args.prompt, generated_text, latency_ms, default_confidence=0.0, ) print(json.dumps(result, indent=2, ensure_ascii=False)) if __name__ == "__main__": try: main() except (RuntimeError, ValueError) as exc: print( json.dumps( { "error": "Cloud inference request failed.", "reason": str(exc), "cloud_available": False, "hint": ( "Pass --repo-id for development fallback mode, or pass " "--endpoint-url for a deployed Hugging Face Dedicated " "Inference Endpoint." ), }, indent=2, ensure_ascii=False, ), file=sys.stderr, ) sys.exit(1)