Navya-Sree commited on
Commit
e53a03c
·
verified ·
1 Parent(s): 18c9c26

Create src/macg/llm_hf.py

Browse files
Files changed (1) hide show
  1. src/macg/llm_hf.py +70 -0
src/macg/llm_hf.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import os
3
+ import time
4
+ import requests
5
+ from macg.llm_base import LLMClient
6
+
7
+ def _strip_code_fences(text: str) -> str:
8
+ t = text.strip()
9
+ if "```" not in t:
10
+ return t
11
+ for fence in ("```python", "```py", "```"):
12
+ if fence in t:
13
+ start = t.find(fence) + len(fence)
14
+ end = t.find("```", start)
15
+ if end != -1:
16
+ return t[start:end].strip()
17
+ return t.replace("```", "").strip()
18
+
19
class HuggingFaceInferenceLLM(LLMClient):
    """LLM client backed by the Hugging Face serverless Inference API.

    Sends text-generation requests to
    ``https://api-inference.huggingface.co/models/{model}`` with retry
    and backoff, and returns the generated text with Markdown code
    fences stripped.
    """

    def __init__(
        self,
        model: str,
        token: str | None = None,
        max_new_tokens: int = 900,
        temperature: float = 0.2,
        retries: int = 4,
        timeout_s: int = 90,
    ) -> None:
        """Configure the client.

        Args:
            model: HF model repo id (e.g. "org/model-name").
            token: HF API token; falls back to the HF_TOKEN env var.
            max_new_tokens: Generation cap passed to the API.
            temperature: Sampling temperature passed to the API.
            retries: Number of request attempts before giving up.
            timeout_s: Per-request timeout in seconds.

        Raises:
            ValueError: If no token is given and HF_TOKEN is unset.
        """
        self.model = model
        self.token = token or os.getenv("HF_TOKEN")
        if not self.token:
            raise ValueError("HF_TOKEN is not set.")
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.retries = retries
        self.timeout_s = timeout_s

    def complete(self, system: str, prompt: str) -> str:
        """Run one completion and return the de-fenced generated text.

        Args:
            system: System instruction, prepended to the prompt.
            prompt: User prompt.

        Returns:
            The model's generated text with code fences stripped.

        Raises:
            RuntimeError: If all attempts fail; the message includes
                the last observed error.
        """
        url = f"https://api-inference.huggingface.co/models/{self.model}"
        headers = {"Authorization": f"Bearer {self.token}"}
        payload = {
            "inputs": f"{system}\n\n{prompt}".strip(),
            "parameters": {
                "max_new_tokens": self.max_new_tokens,
                "temperature": self.temperature,
                # Only the continuation, not the echoed prompt.
                "return_full_text": False,
            },
            # Block while HF cold-starts the model instead of 503ing.
            "options": {"wait_for_model": True},
        }

        last_err: Exception | None = None
        for attempt in range(1, self.retries + 1):
            try:
                r = requests.post(
                    url, headers=headers, json=payload, timeout=self.timeout_s
                )
                if r.status_code in (503, 504):
                    # Model loading / gateway timeout: transient, so back
                    # off and retry. Record it so the final RuntimeError
                    # reports a cause instead of "Last error: None".
                    last_err = RuntimeError(
                        f"HF transient status {r.status_code} on attempt {attempt}"
                    )
                    if attempt < self.retries:
                        time.sleep(min(2 * attempt, 8))  # capped linear backoff
                    continue
                r.raise_for_status()
                data = r.json()

                # Successful generations arrive as [{"generated_text": ...}].
                if isinstance(data, list) and data and "generated_text" in data[0]:
                    return _strip_code_fences(data[0]["generated_text"])
                if isinstance(data, dict) and "error" in data:
                    raise RuntimeError(f"HF error: {data['error']}")
                raise RuntimeError(f"Unexpected HF response: {data}")
            except Exception as e:
                last_err = e
                # Don't sleep after the final attempt — just fall through
                # to the terminal RuntimeError below.
                if attempt < self.retries:
                    time.sleep(min(2 * attempt, 8))

        raise RuntimeError(f"HF inference failed. Last error: {last_err}")