#!/usr/bin/env python3
from __future__ import annotations

import json
import os
import re
from typing import Any, Dict, Optional

import requests

class LocalLLMClient:
    """
    Local LLM client abstraction (NO OpenAI/Claude).

    Providers:
      - ollama       : localhost Ollama (Windows/local dev)
      - transformers : in-process HF Transformers (Hugging Face Spaces)

    Env:
      LOCAL_LLM_PROVIDER=ollama|transformers

      Transformers:
        HF_LLM_MODEL=Qwen/Qwen2.5-0.5B-Instruct  (recommended)
        HF_MAX_NEW_TOKENS=220
    """

    def __init__(
        self,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        host: Optional[str] = None,
        timeout_sec: int = 120,
    ):
        self.provider = (provider or os.getenv("LOCAL_LLM_PROVIDER", "ollama")).lower().strip()
        self.timeout_sec = int(os.getenv("LLM_TIMEOUT_SEC", str(timeout_sec)))

        # Ollama
        self.host = (host or os.getenv("OLLAMA_HOST", "http://localhost:11434")).strip()
        self.model = (model or os.getenv("OLLAMA_MODEL", "llama3.2:1b")).strip()

        # Transformers
        self.hf_model_id = (os.getenv("HF_LLM_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")).strip()
        self.hf_max_new_tokens = int(os.getenv("HF_MAX_NEW_TOKENS", "220"))
        self._tok = None
        self._mdl = None

        if self.provider not in {"ollama", "transformers"}:
            raise ValueError(f"Unsupported LOCAL_LLM_PROVIDER='{self.provider}'. Use ollama or transformers.")
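
    # Example environment configuration (illustrative; these lines mirror the
    # defaults read above rather than introducing new settings):
    #   Local dev (Ollama):  LOCAL_LLM_PROVIDER=ollama OLLAMA_HOST=http://localhost:11434 OLLAMA_MODEL=llama3.2:1b
    #   HF Spaces (CPU):     LOCAL_LLM_PROVIDER=transformers HF_LLM_MODEL=Qwen/Qwen2.5-0.5B-Instruct HF_MAX_NEW_TOKENS=220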

    def generate(self, prompt: str, temperature: float = 0.2, max_tokens: int = 900) -> str:
        prompt = (prompt or "").strip()
        if not prompt:
            return ""
        if self.provider == "ollama":
            return self._generate_ollama(prompt, temperature=temperature, max_tokens=max_tokens)
        return self._generate_transformers(prompt, temperature=temperature, max_tokens=max_tokens)

    # ---------------- Ollama ----------------

    def _generate_ollama(self, prompt: str, temperature: float, max_tokens: int) -> str:
        url = self.host.rstrip("/") + "/api/generate"
        payload: Dict[str, Any] = {
            "model": self.model,
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": float(temperature),
                "num_predict": int(max_tokens),
            },
        }
        r = requests.post(url, json=payload, timeout=self.timeout_sec)
        r.raise_for_status()
        data = r.json()
        return (data.get("response") or "").strip()

    # -------------- Transformers (HF) --------------

    def _lazy_init_hf(self):
        if self._tok is not None and self._mdl is not None:
            return
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        try:
            torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "2")))
        except Exception:
            pass

        self._tok = AutoTokenizer.from_pretrained(self.hf_model_id, use_fast=True)
        self._mdl = AutoModelForCausalLM.from_pretrained(
            self.hf_model_id,
            torch_dtype=torch.float32,
            device_map=None,
        )
        self._mdl.eval()

    def _chat_wrap(self, prompt: str) -> str:
        if self._tok is None:
            return prompt
        if hasattr(self._tok, "apply_chat_template"):
            msgs = [
                {"role": "system", "content": "You are a helpful, precise medical aesthetics research assistant."},
                {"role": "user", "content": prompt},
            ]
            return self._tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        return "System: You are a helpful assistant.\nUser: " + prompt + "\nAssistant:"

    def _generate_transformers(self, prompt: str, temperature: float, max_tokens: int) -> str:
        self._lazy_init_hf()
        import torch

        max_new = min(int(max_tokens), int(self.hf_max_new_tokens))
        wrapped = self._chat_wrap(prompt)

        # Tokenize and remember prompt token length so we only decode NEW tokens
        inp = self._tok(wrapped, return_tensors="pt", truncation=True, max_length=2048)
        prompt_len = int(inp["input_ids"].shape[-1])

        with torch.inference_mode():
            out = self._mdl.generate(
                **inp,
                do_sample=False,  # deterministic (greedy) -> less garbage; `temperature` is ignored here
                max_new_tokens=max_new,
                repetition_penalty=1.08,
                eos_token_id=self._tok.eos_token_id,
            )

        gen_ids = out[0][prompt_len:]  # only the new tokens
        text = self._tok.decode(gen_ids, skip_special_tokens=True).strip()

        # Final cleanup: strip any accidental role labels
        text = re.sub(r"^\s*(assistant|system|user)\s*[:\-]\s*", "", text, flags=re.IGNORECASE)
        return text.strip()

    # ---------------- JSON helpers ----------------

    @staticmethod
    def _strip_code_fences(text: str) -> str:
        t = text.strip()
        t = re.sub(r"^```(?:json)?\s*", "", t, flags=re.IGNORECASE)
        t = re.sub(r"\s*```$", "", t)
        return t.strip()

    def safe_json_loads(self, text: str) -> Dict[str, Any]:
        if not text:
            return {}
        t = self._strip_code_fences(text)
        try:
            obj = json.loads(t)
            return obj if isinstance(obj, dict) else {}
        except Exception:
            m = re.search(r"\{.*\}", t, flags=re.DOTALL)
            if m:
                try:
                    obj = json.loads(m.group(0))
                    return obj if isinstance(obj, dict) else {}
                except Exception:
                    return {}
            return {}
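

# --- Usage sketch (illustrative; not part of the original module) ---
# Assumes either a reachable Ollama server (default http://localhost:11434) or
# LOCAL_LLM_PROVIDER=transformers with the HF model available locally/cached.
if __name__ == "__main__":
    client = LocalLLMClient()  # provider resolved from LOCAL_LLM_PROVIDER (default: ollama)
    print(client.generate("List three common dermal filler ingredients.", max_tokens=120))

    # Structured output: ask for JSON only, then parse it defensively.
    raw = client.generate('Reply with JSON only: {"status": "ok"}', max_tokens=60)
    print(client.safe_json_loads(raw))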