#!/usr/bin/env python3
from __future__ import annotations
import json
import os
import re
from typing import Any, Dict, Optional
import requests
class LocalLLMClient:
    """
    Local LLM client abstraction (no OpenAI/Claude API calls).

    Providers:
      - ollama       : local Ollama server (Windows/local dev)
      - transformers : in-process Hugging Face Transformers (Hugging Face Spaces)

    Environment variables:
      LOCAL_LLM_PROVIDER=ollama|transformers
      LLM_TIMEOUT_SEC=120
      Ollama:
        OLLAMA_HOST=http://localhost:11434
        OLLAMA_MODEL=llama3.2:1b
      Transformers:
        HF_LLM_MODEL=Qwen/Qwen2.5-0.5B-Instruct  (recommended)
        HF_MAX_NEW_TOKENS=220
    """
    def __init__(
        self,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        host: Optional[str] = None,
        timeout_sec: int = 120,
    ):
        self.provider = (provider or os.getenv("LOCAL_LLM_PROVIDER", "ollama")).lower().strip()
        self.timeout_sec = int(os.getenv("LLM_TIMEOUT_SEC", str(timeout_sec)))
        # Ollama
        self.host = (host or os.getenv("OLLAMA_HOST", "http://localhost:11434")).strip()
        self.model = (model or os.getenv("OLLAMA_MODEL", "llama3.2:1b")).strip()
        # Transformers
        self.hf_model_id = (os.getenv("HF_LLM_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")).strip()
        self.hf_max_new_tokens = int(os.getenv("HF_MAX_NEW_TOKENS", "220"))
        self._tok = None
        self._mdl = None
        if self.provider not in {"ollama", "transformers"}:
            raise ValueError(f"Unsupported LOCAL_LLM_PROVIDER='{self.provider}'. Use ollama or transformers.")

    def generate(self, prompt: str, temperature: float = 0.2, max_tokens: int = 900) -> str:
        """Generate a completion for `prompt` using the configured provider."""
        prompt = (prompt or "").strip()
        if not prompt:
            return ""
        if self.provider == "ollama":
            return self._generate_ollama(prompt, temperature=temperature, max_tokens=max_tokens)
        return self._generate_transformers(prompt, temperature=temperature, max_tokens=max_tokens)

    # ---------------- Ollama ----------------
    def _generate_ollama(self, prompt: str, temperature: float, max_tokens: int) -> str:
        url = self.host.rstrip("/") + "/api/generate"
        payload: Dict[str, Any] = {
            "model": self.model,
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": float(temperature),
                "num_predict": int(max_tokens),
            },
        }
        r = requests.post(url, json=payload, timeout=self.timeout_sec)
        r.raise_for_status()
        data = r.json()
        # Non-streaming /api/generate returns the full completion in the "response" field.
        return (data.get("response") or "").strip()

    # -------------- Transformers (HF) --------------
    def _lazy_init_hf(self):
        if self._tok is not None and self._mdl is not None:
            return
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
        try:
            # Cap CPU threads so generation stays well-behaved on small shared hardware.
            torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "2")))
        except Exception:
            pass
        self._tok = AutoTokenizer.from_pretrained(self.hf_model_id, use_fast=True)
        # CPU-only load: float32 weights, no device_map.
        self._mdl = AutoModelForCausalLM.from_pretrained(
            self.hf_model_id,
            torch_dtype=torch.float32,
            device_map=None,
        )
        self._mdl.eval()

    def _chat_wrap(self, prompt: str) -> str:
        if self._tok is None:
            return prompt
        # Use the model's own chat template only when the tokenizer actually defines one;
        # otherwise fall back to a plain-text prompt format.
        if getattr(self._tok, "chat_template", None):
            msgs = [
                {"role": "system", "content": "You are a helpful, precise medical aesthetics research assistant."},
                {"role": "user", "content": prompt},
            ]
            return self._tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        return "System: You are a helpful assistant.\nUser: " + prompt + "\nAssistant:"

    def _generate_transformers(self, prompt: str, temperature: float, max_tokens: int) -> str:
        self._lazy_init_hf()
        import torch
        max_new = min(int(max_tokens), int(self.hf_max_new_tokens))
        wrapped = self._chat_wrap(prompt)
        # Tokenize and remember the prompt token length so we only decode NEW tokens.
        inp = self._tok(wrapped, return_tensors="pt", truncation=True, max_length=2048)
        prompt_len = int(inp["input_ids"].shape[-1])
        with torch.inference_mode():
            out = self._mdl.generate(
                **inp,
                do_sample=False,  # greedy decoding -> deterministic output; `temperature` is ignored here
                max_new_tokens=max_new,
                repetition_penalty=1.08,
                eos_token_id=self._tok.eos_token_id,
                pad_token_id=self._tok.eos_token_id,  # avoid pad-token warnings for models without one
            )
        gen_ids = out[0][prompt_len:]  # only the newly generated tokens
        text = self._tok.decode(gen_ids, skip_special_tokens=True).strip()
        # Final cleanup: strip any accidental role labels the model may echo.
        text = re.sub(r"^\s*(assistant|system|user)\s*[:\-]\s*", "", text, flags=re.IGNORECASE)
        return text.strip()

    # ---------------- JSON helpers ----------------
    @staticmethod
    def _strip_code_fences(text: str) -> str:
        t = text.strip()
        t = re.sub(r"^```(?:json)?\s*", "", t, flags=re.IGNORECASE)
        t = re.sub(r"\s*```$", "", t)
        return t.strip()

    def safe_json_loads(self, text: str) -> Dict[str, Any]:
        """Parse model output as a JSON object; return {} instead of raising on failure."""
        if not text:
            return {}
        t = self._strip_code_fences(text)
        try:
            obj = json.loads(t)
            return obj if isinstance(obj, dict) else {}
        except Exception:
            # Fallback: grab the first {...} span in case the model wrapped JSON in prose.
            m = re.search(r"\{.*\}", t, flags=re.DOTALL)
            if m:
                try:
                    obj = json.loads(m.group(0))
                    return obj if isinstance(obj, dict) else {}
                except Exception:
                    return {}
            return {}
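

# A minimal smoke-test sketch (not part of the original API surface). It assumes the
# environment variables described in the class docstring are set appropriately and, for
# the "ollama" provider, that an Ollama server is reachable at OLLAMA_HOST with the
# configured model already pulled. The prompt below is illustrative only.
if __name__ == "__main__":
    client = LocalLLMClient()  # provider/model picked up from env vars or the defaults above
    reply = client.generate('Reply with a JSON object: {"ok": true}', max_tokens=100)
    print("raw reply:", reply)
    print("parsed:", client.safe_json_loads(reply))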