workbench / models /transformers_text.py
GitHub Actions
Initial ZeroGPU deployment with spaces shim
7f9dfed
Raw
History Blame Contribute Delete
4.16 kB
from __future__ import annotations
import importlib.util
from dataclasses import dataclass
from typing import Any, cast
from models.base import BackendStatus
from models.hf_components import load_tokenizer_and_causal_lm
from models.model_catalog import ModelInfo
@dataclass(frozen=True)
class TransformersTextConfig:
    """Immutable loading and generation settings for the Transformers text backend."""

    # --- model loading: forwarded to load_tokenizer_and_causal_lm ---
    trust_remote_code: bool = False
    device_map: str = "auto"
    torch_dtype: str = "auto"

    # --- text generation: forwarded to model.generate via generation_kwargs() ---
    max_new_tokens: int = 256
    temperature: float = 0.7
    do_sample: bool = True
class TransformersTextService:
    """Optional Transformers text backend with lazy model loading.

    The model and tokenizer are loaded on the first ``chat`` call and then
    cached on the instance, so constructing the service is cheap.
    """

    def __init__(
        self,
        model: ModelInfo,
        config: TransformersTextConfig | None = None,
    ) -> None:
        """Store the model descriptor and settings without loading any weights.

        Args:
            model: Catalog entry describing the model to load.
            config: Backend settings. When omitted, defaults are used with
                ``trust_remote_code`` copied from ``model``.
        """
        self.model = model
        if config is None:
            config = TransformersTextConfig(
                trust_remote_code=model.trust_remote_code
            )
        self.config = config
        # Filled in lazily by _load_components().
        self._model = None
        self._tokenizer = None

    @staticmethod
    def status() -> BackendStatus:
        """Report whether the ``transformers`` package can be imported."""
        if importlib.util.find_spec("transformers") is None:
            return BackendStatus(
                "transformers",
                False,
                "Python package transformers is not installed in the current environment.",
            )
        return BackendStatus("transformers", True, "Transformers package is installed.")

    def chat(self, system_prompt: str, user_prompt: str) -> str:
        """Generate one assistant reply for the given prompts.

        Args:
            system_prompt: Optional system instructions; skipped when blank.
            user_prompt: The user's message.

        Returns:
            The generated completion with the prompt removed, or a
            human-readable notice when the backend is unavailable.
        """
        status = self.status()
        if not status.available:
            return (
                "[Transformers unavailable]\n\n"
                f"{status.detail}\n\n"
                "Install transformers/torch and select this backend only when local hardware "
                "can load the chosen model."
            )
        model, tokenizer = self._load_components()
        prompt = self._format_chat_prompt(tokenizer, system_prompt, user_prompt)
        # NOTE(review): chat templates usually embed their own special tokens;
        # tokenizing the rendered text with default settings may add a second
        # BOS. Consider add_special_tokens=False for templated prompts --
        # confirm against the tokenizers actually in use.
        encoded = tokenizer(prompt, return_tensors="pt")
        encoded = self._move_encoded_to_model_device(encoded, model)
        outputs = model.generate(**encoded, **self.generation_kwargs())
        return self._decode_completion(tokenizer, encoded, outputs, prompt)

    def stream_chat(self, system_prompt: str, user_prompt: str) -> list[str]:
        """Return the full reply split on single spaces, dropping empty pieces.

        This is not incremental streaming: the whole reply is generated first.
        """
        response = self.chat(system_prompt, user_prompt)
        return [token for token in response.split(" ") if token]

    def generation_kwargs(self) -> dict[str, Any]:
        """Build the keyword arguments passed to ``model.generate``.

        ``temperature`` is only included when sampling is enabled; it has no
        effect on greedy decoding and only triggers warnings there.
        """
        kwargs: dict[str, Any] = {"max_new_tokens": self.config.max_new_tokens}
        if self.config.do_sample:
            kwargs["temperature"] = self.config.temperature
        kwargs["do_sample"] = self.config.do_sample
        return kwargs

    def _load_components(self) -> tuple[Any, Any]:
        """Return ``(model, tokenizer)``, loading and caching them on first use."""
        if self._model is not None and self._tokenizer is not None:
            return self._model, self._tokenizer
        self._model, self._tokenizer = load_tokenizer_and_causal_lm(
            self.model,
            self.config.trust_remote_code,
            self.config.device_map,
            self.config.torch_dtype,
        )
        return self._model, self._tokenizer

    @staticmethod
    def _format_chat_prompt(tokenizer, system_prompt: str, user_prompt: str) -> str:
        """Render the conversation as a single prompt string.

        Uses the tokenizer's chat template when one is usable; otherwise
        falls back to a plain ``role: content`` transcript ending with
        ``assistant:``.
        """
        messages = []
        if system_prompt.strip():
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_prompt})

        if hasattr(tokenizer, "apply_chat_template"):
            try:
                rendered = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
            except ValueError:
                # Tokenizers can expose the method without having a template
                # configured (e.g. base models); use the plain format below.
                pass
            else:
                return str(rendered)

        parts = [f"{message['role']}: {message['content']}" for message in messages]
        parts.append("assistant:")
        return "\n".join(parts)

    @staticmethod
    def _decode_completion(tokenizer, encoded, outputs, prompt: str) -> str:
        """Decode only the newly generated part of the first output sequence.

        The prompt is removed by token count rather than by character count:
        decoding with ``skip_special_tokens=True`` drops template tokens, so
        the decoded text does not start with ``prompt`` verbatim and a
        character offset would cut at the wrong place. This assumes the
        output sequence starts with the input tokens, as causal LMs return.

        When the prompt length cannot be read from ``encoded``, the legacy
        character-based stripping is used as a best effort.
        """
        sequence = outputs[0]
        try:
            prompt_token_count = len(encoded["input_ids"][0])
        except (KeyError, IndexError, TypeError):
            prompt_token_count = None

        if prompt_token_count is not None:
            completion = tokenizer.decode(
                sequence[prompt_token_count:], skip_special_tokens=True
            )
            return cast(str, completion).strip()

        decoded = cast(str, tokenizer.decode(sequence, skip_special_tokens=True))
        return decoded[len(prompt) :].strip() or decoded.strip()

    @staticmethod
    def _move_encoded_to_model_device(encoded, model):
        """Move tokenizer output to the model's device when one is exposed.

        Returns ``encoded`` unchanged if the model has no ``device``
        attribute (or it is ``None``). Objects with a ``to`` method are moved
        as a whole; plain mappings are rebuilt with each movable value moved.
        """
        device = getattr(model, "device", None)
        if device is None:
            return encoded
        if hasattr(encoded, "to"):
            return encoded.to(device)
        return {
            key: value.to(device) if hasattr(value, "to") else value
            for key, value in encoded.items()
        }