#!/usr/bin/env python3
from __future__ import annotations

import json
import os
import re
from typing import Any, Dict, Optional

import requests


class LocalLLMClient:
    """
    Local LLM client abstraction (NO OpenAI/Claude).

    Providers:
      - ollama        : localhost Ollama (Windows/local dev)
      - transformers  : in-process HF Transformers (Hugging Face Spaces)

    Env:
      LOCAL_LLM_PROVIDER=ollama|transformers

    Transformers:
      HF_LLM_MODEL=Qwen/Qwen2.5-0.5B-Instruct  (recommended)
      HF_MAX_NEW_TOKENS=220
    """

    def __init__(
        self,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        host: Optional[str] = None,
        timeout_sec: int = 120,
    ):
        self.provider = (provider or os.getenv("LOCAL_LLM_PROVIDER", "ollama")).lower().strip()
        self.timeout_sec = int(os.getenv("LLM_TIMEOUT_SEC", str(timeout_sec)))

        # Ollama
        self.host = (host or os.getenv("OLLAMA_HOST", "http://localhost:11434")).strip()
        self.model = (model or os.getenv("OLLAMA_MODEL", "llama3.2:1b")).strip()

        # Transformers
        self.hf_model_id = (os.getenv("HF_LLM_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")).strip()
        self.hf_max_new_tokens = int(os.getenv("HF_MAX_NEW_TOKENS", "220"))

        self._tok = None
        self._mdl = None

        if self.provider not in {"ollama", "transformers"}:
            raise ValueError(f"Unsupported LOCAL_LLM_PROVIDER='{self.provider}'. Use ollama or transformers.")

    def generate(self, prompt: str, temperature: float = 0.2, max_tokens: int = 900) -> str:
        prompt = (prompt or "").strip()
        if not prompt:
            return ""

        if self.provider == "ollama":
            return self._generate_ollama(prompt, temperature=temperature, max_tokens=max_tokens)

        return self._generate_transformers(prompt, temperature=temperature, max_tokens=max_tokens)

    # ---------------- Ollama ----------------
    def _generate_ollama(self, prompt: str, temperature: float, max_tokens: int) -> str:
        url = self.host.rstrip("/") + "/api/generate"
        payload: Dict[str, Any] = {
            "model": self.model,
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": float(temperature),
                "num_predict": int(max_tokens),
            },
        }
        r = requests.post(url, json=payload, timeout=self.timeout_sec)
        r.raise_for_status()
        data = r.json()
        return (data.get("response") or "").strip()

    # -------------- Transformers (HF) --------------
    def _lazy_init_hf(self):
        if self._tok is not None and self._mdl is not None:
            return

        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        try:
            torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "2")))
        except Exception:
            pass

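        # Load tokenizer and model once and cache them on the instance.
        # float32 with device_map=None keeps everything on CPU (the class
        # docstring assumes a CPU-only Hugging Face Spaces deployment).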
        self._tok = AutoTokenizer.from_pretrained(self.hf_model_id, use_fast=True)
        self._mdl = AutoModelForCausalLM.from_pretrained(
            self.hf_model_id,
            torch_dtype=torch.float32,
            device_map=None,
        )
        self._mdl.eval()

    def _chat_wrap(self, prompt: str) -> str:
        if self._tok is None:
            return prompt

        if hasattr(self._tok, "apply_chat_template"):
            msgs = [
                {"role": "system", "content": "You are a helpful, precise medical aesthetics research assistant."},
                {"role": "user", "content": prompt},
            ]
            return self._tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)

        return "System: You are a helpful assistant.\nUser: " + prompt + "\nAssistant:"

    def _generate_transformers(self, prompt: str, temperature: float, max_tokens: int) -> str:
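        # NOTE: decoding below is greedy (do_sample=False), so the `temperature`
        # argument is accepted for interface parity with the Ollama path but has
        # no effect in this provider.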
        self._lazy_init_hf()

        import torch

        max_new = min(int(max_tokens), int(self.hf_max_new_tokens))
        wrapped = self._chat_wrap(prompt)

        # Tokenize and remember prompt token length so we only decode NEW tokens
        inp = self._tok(wrapped, return_tensors="pt", truncation=True, max_length=2048)
        prompt_len = int(inp["input_ids"].shape[-1])

        with torch.inference_mode():
            out = self._mdl.generate(
                **inp,
                do_sample=False,  # deterministic -> less garbage
                max_new_tokens=max_new,
                repetition_penalty=1.08,
                eos_token_id=self._tok.eos_token_id,
            )

        gen_ids = out[0][prompt_len:]  # only the new tokens
        text = self._tok.decode(gen_ids, skip_special_tokens=True).strip()

        # Final cleanup: strip any accidental role labels
        text = re.sub(r"^\s*(assistant|system|user)\s*[:\-]\s*", "", text, flags=re.IGNORECASE)
        return text.strip()

    # ---------------- JSON helpers ----------------
    @staticmethod
    def _strip_code_fences(text: str) -> str:
        t = text.strip()
        t = re.sub(r"^```(?:json)?\s*", "", t, flags=re.IGNORECASE)
        t = re.sub(r"\s*```$", "", t)
        return t.strip()

    def safe_json_loads(self, text: str) -> Dict[str, Any]:
        if not text:
            return {}
        t = self._strip_code_fences(text)
        try:
            obj = json.loads(t)
            return obj if isinstance(obj, dict) else {}
        except Exception:
            m = re.search(r"\{.*\}", t, flags=re.DOTALL)
            if m:
                try:
                    obj = json.loads(m.group(0))
                    return obj if isinstance(obj, dict) else {}
                except Exception:
                    return {}
        return {}
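

# ---------------- Usage sketch ----------------
# A minimal example of how this client is meant to be driven. It assumes Ollama
# is running locally with the default "llama3.2:1b" model already pulled; set
# LOCAL_LLM_PROVIDER=transformers to exercise the in-process HF path instead.
if __name__ == "__main__":
    client = LocalLLMClient()  # provider/model resolved from the env vars above

    # Plain text generation.
    print(client.generate("List two common dermal filler ingredients.", max_tokens=200))

    # JSON-style generation, parsed defensively with safe_json_loads.
    raw = client.generate(
        'Return JSON only, e.g. {"ingredients": ["...", "..."]}.',
        max_tokens=200,
    )
    print(client.safe_json_loads(raw))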