mlbench123 committed
Commit b2f54cb · verified · 1 Parent(s): 3c0e14f

Update llm_client.py

Files changed (1):
  1. llm_client.py +184 -198
llm_client.py CHANGED
@@ -1,198 +1,184 @@
 #!/usr/bin/env python3
 """
 Local LLM client abstraction (NO OpenAI/Claude).

 Providers:
-  - ollama       : calls a local Ollama server (your Windows dev)
-  - transformers : runs a local HF model in-process (best for Hugging Face Spaces CPU)
+  - ollama       : calls a local Ollama server (for Windows/local dev)
+  - transformers : runs a local HF model in-process (required for Hugging Face Spaces)

 Env:
   LOCAL_LLM_PROVIDER=ollama|transformers

 Ollama:
   OLLAMA_HOST=http://localhost:11434
   OLLAMA_MODEL=llama3.2:1b

-Transformers:
-  HF_LLM_MODEL=TinyLlama/TinyLlama-1.1B-Chat-v1.0  (recommended CPU default)
+Transformers (HF Spaces):
+  HF_LLM_MODEL=TinyLlama/TinyLlama-1.1B-Chat-v1.0
   HF_MAX_NEW_TOKENS=450
 """

 from __future__ import annotations

 import json
 import os
 import re
 from typing import Any, Dict, Optional

 import requests


 class LocalLLMClient:
     def __init__(
         self,
         provider: Optional[str] = None,
         model: Optional[str] = None,
         host: Optional[str] = None,
         timeout_sec: int = 120,
     ):
         self.provider = (provider or os.getenv("LOCAL_LLM_PROVIDER", "ollama")).lower().strip()
         self.timeout_sec = int(os.getenv("LLM_TIMEOUT_SEC", str(timeout_sec)))

-        # Ollama settings
+        # Ollama settings (local)
         self.host = (host or os.getenv("OLLAMA_HOST", "http://localhost:11434")).strip()
         self.model = (model or os.getenv("OLLAMA_MODEL", "llama3.2:1b")).strip()

         # Transformers settings (HF Spaces)
         self.hf_model_id = (os.getenv("HF_LLM_MODEL", "TinyLlama/TinyLlama-1.1B-Chat-v1.0")).strip()
         self.hf_max_new_tokens = int(os.getenv("HF_MAX_NEW_TOKENS", "450"))

         self._hf_pipe = None  # lazy init

         if self.provider not in {"ollama", "transformers"}:
             raise ValueError(
-                f"Unsupported LOCAL_LLM_PROVIDER='{self.provider}'. "
-                "Supported: ollama, transformers."
+                f"Unsupported LOCAL_LLM_PROVIDER='{self.provider}'. Supported: ollama, transformers."
             )

-    # --------------------------- Public API ---------------------------
-
     def generate(self, prompt: str, temperature: float = 0.2, max_tokens: int = 900) -> str:
         prompt = (prompt or "").strip()
         if not prompt:
             return ""

         if self.provider == "ollama":
             return self._generate_ollama(prompt, temperature=temperature, max_tokens=max_tokens)

-        # transformers
         return self._generate_transformers(prompt, temperature=temperature, max_tokens=max_tokens)

     # --------------------------- Ollama ---------------------------

     def _generate_ollama(self, prompt: str, temperature: float, max_tokens: int) -> str:
         url = self.host.rstrip("/") + "/api/generate"
         payload: Dict[str, Any] = {
             "model": self.model,
             "prompt": prompt,
             "stream": False,
             "options": {
                 "temperature": float(temperature),
                 "num_predict": int(max_tokens),
             },
         }

         try:
             r = requests.post(url, json=payload, timeout=self.timeout_sec)
         except requests.RequestException as e:
             raise RuntimeError(
                 "Failed to connect to local Ollama.\n"
                 f"Tried: {url}\n"
                 "Fix:\n"
                 " - Ensure Ollama is running\n"
                 " - Confirm endpoint: iwr http://localhost:11434/api/tags -UseBasicParsing\n"
                 f"Error: {repr(e)}"
             ) from e

         if r.status_code != 200:
             body = (r.text or "").strip()
             msg = body
             try:
                 j = r.json()
                 if isinstance(j, dict):
                     msg = j.get("error") or j.get("message") or body
             except Exception:
                 pass
             raise RuntimeError(
                 "Ollama returned an error.\n"
                 f"URL: {url}\n"
                 f"HTTP: {r.status_code}\n"
                 f"Model: {self.model}\n"
                 f"Details: {msg}"
             )

         data = r.json()
         return (data.get("response") or "").strip()

     # --------------------------- Transformers (HF Spaces) ---------------------------

     def _lazy_init_hf(self):
         if self._hf_pipe is not None:
             return

-        # Lazy import to keep local installs lighter
         from transformers import pipeline

-        # CPU inference; use bfloat16 only if supported (some spaces may not)
-        # Keep it simple and robust.
         self._hf_pipe = pipeline(
             "text-generation",
             model=self.hf_model_id,
             device=-1,  # CPU
         )

     def _generate_transformers(self, prompt: str, temperature: float, max_tokens: int) -> str:
         self._lazy_init_hf()

-        # Cap generation for HF CPU
         max_new = min(int(max_tokens), int(self.hf_max_new_tokens))

-        # Many instruct/chat models work better with a simple instruction wrapper.
         wrapped = (
             "You are a helpful assistant.\n\n"
             f"{prompt}\n\n"
             "Answer:"
         )

         out = self._hf_pipe(
             wrapped,
             max_new_tokens=max_new,
             do_sample=True,
             temperature=float(max(0.05, temperature)),
             top_p=0.9,
             repetition_penalty=1.1,
         )

         if not out:
             return ""

-        # pipeline returns list[{"generated_text": "..."}]
-        text = out[0].get("generated_text", "")
-        text = (text or "").strip()
-
-        # Remove the prompt prefix if the model echoed it
+        text = (out[0].get("generated_text", "") or "").strip()
         if text.startswith(wrapped):
             text = text[len(wrapped):].strip()
-
         return text

     # --------------------------- JSON helpers ---------------------------

     @staticmethod
     def _strip_code_fences(text: str) -> str:
         t = text.strip()
         t = re.sub(r"^```(?:json)?\s*", "", t, flags=re.IGNORECASE)
         t = re.sub(r"\s*```$", "", t)
         return t.strip()

     def safe_json_loads(self, text: str) -> Dict[str, Any]:
         if not text:
             return {}

         t = self._strip_code_fences(text)

         try:
             out = json.loads(t)
             return out if isinstance(out, dict) else {}
         except Exception:
             pass

         m = re.search(r"\{.*\}", t, flags=re.DOTALL)
         if m:
             try:
                 out = json.loads(m.group(0))
                 return out if isinstance(out, dict) else {}
             except Exception:
                 return {}

         return {}
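
The committed module is configuration-driven, so the quickest way to sanity-check it is from the environment. Below is a minimal usage sketch, not part of the commit: it assumes llm_client.py is importable from the working directory, an Ollama server on the default http://localhost:11434 for the first call, and an installed transformers package for the second.

# usage_sketch.py -- hypothetical driver, exercises only the committed API
import os

from llm_client import LocalLLMClient

# Provider 1: Ollama (local dev). OLLAMA_HOST/OLLAMA_MODEL are read from env.
os.environ.setdefault("OLLAMA_MODEL", "llama3.2:1b")
ollama_client = LocalLLMClient(provider="ollama")
print(ollama_client.generate("Name three uses of a local LLM.", max_tokens=200))

# Provider 2: transformers (HF Spaces). Runs in-process on CPU; generation is
# capped at min(max_tokens, HF_MAX_NEW_TOKENS) new tokens.
hf_client = LocalLLMClient(provider="transformers")
print(hf_client.generate("Say hello in one sentence.", temperature=0.3))

Constructing the client is cheap in both cases: the transformers pipeline is only loaded on the first generate() call, which matters on a CPU Space where model load dominates startup time.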
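The JSON helpers also deserve a reviewer's note: safe_json_loads strips markdown code fences, tries a strict json.loads, falls back to the outermost {...} span, and always returns a dict (possibly empty). A small sketch of the observable behavior, derived from the committed code:

from llm_client import LocalLLMClient

client = LocalLLMClient(provider="transformers")  # provider choice is irrelevant here

# Fenced output: the ```json fence is stripped before parsing.
fenced = '```json\n{"status": "ok", "score": 0.9}\n```'
print(client.safe_json_loads(fenced))   # -> {'status': 'ok', 'score': 0.9}

# Chatty output: falls back to the outermost {...} match.
chatty = 'Sure! Here is the result: {"status": "ok"} Hope that helps.'
print(client.safe_json_loads(chatty))   # -> {'status': 'ok'}

# Non-JSON and non-dict payloads both collapse to {}.
print(client.safe_json_loads("no json here"))  # -> {}
print(client.safe_json_loads("[1, 2, 3]"))     # -> {}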