File size: 4,163 Bytes
7f9dfed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from __future__ import annotations

import importlib.util
from dataclasses import dataclass
from typing import Any, cast

from models.base import BackendStatus
from models.hf_components import load_tokenizer_and_causal_lm
from models.model_catalog import ModelInfo


@dataclass(frozen=True)
class TransformersTextConfig:
    """Settings for loading a causal LM and generating text with it.

    ``trust_remote_code``, ``device_map`` and ``torch_dtype`` are passed to
    ``load_tokenizer_and_causal_lm``; ``max_new_tokens``, ``temperature`` and
    ``do_sample`` are passed to ``model.generate``.

    Raises:
        ValueError: if ``max_new_tokens`` is not positive, or if sampling is
            enabled with a non-positive ``temperature``.
    """

    # Loading options.
    trust_remote_code: bool = False
    device_map: str = "auto"
    torch_dtype: str = "auto"
    # Decoding options.
    max_new_tokens: int = 256
    temperature: float = 0.7
    do_sample: bool = True

    def __post_init__(self) -> None:
        # transformers rejects these values inside generate(); checking here
        # surfaces the mistake at configuration time instead of after a
        # potentially slow model load.
        if self.max_new_tokens <= 0:
            raise ValueError(
                f"max_new_tokens must be positive, got {self.max_new_tokens}"
            )
        # Temperature only matters when sampling; greedy decoding ignores it.
        if self.do_sample and self.temperature <= 0:
            raise ValueError(
                "temperature must be positive when do_sample=True, "
                f"got {self.temperature}; set do_sample=False for greedy decoding"
            )


class TransformersTextService:
    """Optional Transformers text backend with lazy model loading.

    The tokenizer/model pair is loaded on first use via
    ``load_tokenizer_and_causal_lm`` and cached on the instance.
    """

    def __init__(
        self,
        model: ModelInfo,
        config: TransformersTextConfig | None = None,
    ) -> None:
        self.model = model
        # Without an explicit config, inherit trust_remote_code from the
        # catalog entry.
        self.config = config or TransformersTextConfig(
            trust_remote_code=model.trust_remote_code
        )
        # Populated lazily by _load_components().
        self._model = None
        self._tokenizer = None

    @staticmethod
    def status() -> BackendStatus:
        """Report whether the ``transformers`` package is discoverable.

        Only checks ``importlib.util.find_spec``; it does not import the
        package or probe for torch / hardware.
        """
        if importlib.util.find_spec("transformers") is None:
            return BackendStatus(
                "transformers",
                False,
                "Python package transformers is not installed in the current environment.",
            )
        return BackendStatus("transformers", True, "Transformers package is installed.")

    def chat(self, system_prompt: str, user_prompt: str) -> str:
        """Generate one reply to ``user_prompt``.

        Args:
            system_prompt: Optional system message; omitted when blank.
            user_prompt: The user's message.

        Returns:
            Only the newly generated text, decoded without special tokens and
            stripped (empty if the model produced no new tokens), or an
            explanatory message when ``transformers`` is unavailable.
        """
        status = self.status()
        if not status.available:
            return (
                "[Transformers unavailable]\n\n"
                f"{status.detail}\n\n"
                "Install transformers/torch and select this backend only when local hardware "
                "can load the chosen model."
            )

        model, tokenizer = self._load_components()
        prompt, used_template = self._render_chat_prompt(
            tokenizer, system_prompt, user_prompt
        )
        # A rendered chat template already contains its own special tokens
        # (e.g. BOS); letting the tokenizer add them again duplicates them.
        encoded = tokenizer(
            prompt, return_tensors="pt", add_special_tokens=not used_template
        )
        encoded = self._move_encoded_to_model_device(encoded, model)
        prompt_length = len(encoded["input_ids"][0])
        outputs = model.generate(**encoded, **self.generation_kwargs())
        # For decoder-only models generate() returns the prompt ids followed
        # by the new ids. Drop the prompt at the token level: slicing the
        # decoded string by len(prompt) is wrong because skip_special_tokens
        # removes template tokens, so character offsets no longer line up.
        completion_ids = outputs[0][prompt_length:]
        completion = cast(
            str, tokenizer.decode(completion_ids, skip_special_tokens=True)
        )
        return completion.strip()

    def stream_chat(self, system_prompt: str, user_prompt: str) -> list[str]:
        """Return the full ``chat`` reply split on single spaces.

        This is not incremental streaming: generation completes first, and
        empty pieces (from repeated spaces) are dropped.
        """
        response = self.chat(system_prompt, user_prompt)
        return [token for token in response.split(" ") if token]

    def generation_kwargs(self) -> dict[str, Any]:
        """Keyword arguments forwarded to ``model.generate``."""
        return {
            "max_new_tokens": self.config.max_new_tokens,
            "temperature": self.config.temperature,
            "do_sample": self.config.do_sample,
        }

    def _load_components(self):
        """Return ``(model, tokenizer)``, loading and caching them on first call."""
        if self._model is not None and self._tokenizer is not None:
            return self._model, self._tokenizer

        self._model, self._tokenizer = load_tokenizer_and_causal_lm(
            self.model,
            self.config.trust_remote_code,
            self.config.device_map,
            self.config.torch_dtype,
        )
        return self._model, self._tokenizer

    @staticmethod
    def _render_chat_prompt(
        tokenizer, system_prompt: str, user_prompt: str
    ) -> tuple[str, bool]:
        """Build the prompt string and report whether a chat template was used.

        Returns:
            ``(prompt, used_template)``. When the tokenizer cannot apply a
            chat template, a plain ``role: content`` transcript ending in
            ``assistant:`` is returned with ``used_template=False``.
        """
        messages = []
        if system_prompt.strip():
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_prompt})

        if hasattr(tokenizer, "apply_chat_template"):
            try:
                rendered = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
            except ValueError:
                # Recent transformers raise ValueError when the tokenizer has
                # no chat_template (common for base models); fall back to the
                # plain-text format below instead of failing.
                pass
            else:
                return str(rendered), True

        parts = [f"{message['role']}: {message['content']}" for message in messages]
        parts.append("assistant:")
        return "\n".join(parts), False

    @staticmethod
    def _format_chat_prompt(tokenizer, system_prompt: str, user_prompt: str) -> str:
        """Return the prompt string for the given messages (see ``_render_chat_prompt``)."""
        prompt, _ = TransformersTextService._render_chat_prompt(
            tokenizer, system_prompt, user_prompt
        )
        return prompt

    @staticmethod
    def _move_encoded_to_model_device(encoded, model):
        """Move tokenized inputs to ``model.device`` when the model exposes one.

        Uses ``encoded.to`` when available; otherwise moves each value that
        has a ``.to`` method and leaves the rest unchanged.
        """
        device = getattr(model, "device", None)
        if device is None:
            return encoded
        if hasattr(encoded, "to"):
            return encoded.to(device)
        return {
            key: value.to(device) if hasattr(value, "to") else value
            for key, value in encoded.items()
        }