meet4150's picture
download
raw
6.44 kB
import asyncio
import logging
import os
import re
import threading
from dataclasses import dataclass
from typing import Any

import torch
from dotenv import load_dotenv
from langchain_core.messages import AIMessage
from transformers import AutoModelForCausalLM, AutoTokenizer
# Pull variables from a local .env file into the process environment before
# get_llm() reads the HF_LLM_* settings with os.getenv().
load_dotenv()
# Module-level cache for the single LocalHFChat instance; get_llm() fills it on
# first use and returns the same object afterwards.
_client = None
@dataclass
class LocalHFConfig:
    """Settings bundle consumed by ``LocalHFChat``.

    Field order is part of the positional constructor signature and must not
    be rearranged.
    """

    # --- model loading ---
    model_id: str  # identifier handed to ``from_pretrained``
    device: str  # compared against "cuda" by LocalHFChat; anything else runs on CPU
    dtype: str  # dtype name resolved by ``_resolve_dtype``
    load_in_4bit: bool  # request 4-bit quantized weights (only honoured on CUDA)

    # --- text generation ---
    max_new_tokens: int  # forwarded to ``model.generate``
    temperature: float  # values <= 0 switch generation to greedy decoding
    top_p: float  # nucleus-sampling threshold, used only when sampling
    repetition_penalty: float  # forwarded to ``model.generate``
class LocalHFChat:
    """Minimal chat interface over a locally loaded Hugging Face causal LM.

    The object exposes three entry points: ``invoke`` (returns an
    ``AIMessage``), ``complete`` (returns plain text) and ``astream`` (async
    generator of ``AIMessage`` chunks).  Messages may be ``role``/``content``
    dictionaries or message objects exposing ``type`` and ``content``.
    """

    _VALID_ROLES = frozenset({"system", "user", "assistant"})

    # Generation is serialized: ``astream`` runs it in a worker thread, so
    # without the lock overlapping requests would call ``generate`` on the
    # same model concurrently.
    _generation_lock = threading.Lock()

    def __init__(self, config: LocalHFConfig) -> None:
        """Load the tokenizer and model described by *config*.

        Raises:
            Exception: whatever ``from_pretrained`` raises when the tokenizer
                or the (non-quantized) model cannot be loaded.
        """
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_id, use_fast=True)
        self.model = self._load_model(config)

        # generate() needs a pad token id; reuse EOS when the tokenizer has none.
        if self.tokenizer.pad_token_id is None and self.tokenizer.eos_token_id is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self._device = "cuda" if torch.cuda.is_available() and config.device == "cuda" else "cpu"
        if self._device == "cpu":
            self.model.to("cpu")

    @staticmethod
    def _load_model(config: LocalHFConfig):
        """Load the causal LM, trying 4-bit first when requested on CUDA.

        If the quantized load fails (for example bitsandbytes is missing) the
        failure is logged and the model is loaded again in the configured
        dtype.  Errors from that final load propagate unchanged.
        """
        # NOTE(security): trust_remote_code=True lets the model repository run
        # its own Python code at load time; only point model_id at trusted repos.
        load_kwargs: dict[str, Any] = {"trust_remote_code": True}
        wants_cuda = config.device == "cuda"
        if wants_cuda:
            load_kwargs["device_map"] = "auto"

        if wants_cuda and config.load_in_4bit:
            try:
                # Imported lazily so an installation without quantization
                # support still takes the fallback path below.
                from transformers import BitsAndBytesConfig

                # quantization_config replaces the deprecated bare
                # ``load_in_4bit=True`` keyword of from_pretrained.
                return AutoModelForCausalLM.from_pretrained(
                    config.model_id,
                    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
                    **load_kwargs,
                )
            except Exception:
                # Deliberate best-effort: keep serving without quantization,
                # but leave a trace of why the 4-bit load did not work.
                logging.getLogger(__name__).warning(
                    "4-bit load of %s failed; retrying with dtype=%s",
                    config.model_id,
                    config.dtype,
                    exc_info=True,
                )

        if wants_cuda:
            load_kwargs["torch_dtype"] = _resolve_dtype(config.dtype)
        return AutoModelForCausalLM.from_pretrained(config.model_id, **load_kwargs)

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten message content into plain text.

        ``None`` becomes an empty string.  A list of content blocks is reduced
        to its text parts (plain strings and dicts carrying a string ``text``
        entry); other blocks are skipped.  Anything else goes through ``str``.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, (list, tuple)):
            pieces = []
            for block in content:
                if isinstance(block, str):
                    pieces.append(block)
                elif isinstance(block, dict) and isinstance(block.get("text"), str):
                    pieces.append(block["text"])
            return "".join(pieces)
        return str(content)

    def _to_chat_messages(self, messages: list) -> list[dict[str, str]]:
        """Normalize *messages* to ``{"role", "content"}`` dictionaries.

        Dict messages keep their role when it is ``system``, ``user`` or
        ``assistant`` (case-insensitive); any other role becomes ``user``.
        Message objects are mapped through their ``type`` attribute:
        ``system`` -> system, ``ai``/``assistant`` -> assistant, else user.
        """
        chat = []
        for msg in messages:
            if isinstance(msg, dict):
                role = str(msg.get("role", "user")).strip().lower()
                if role not in self._VALID_ROLES:
                    role = "user"
                content = self._content_to_text(msg.get("content", ""))
            else:
                msg_type = getattr(msg, "type", "human")
                if msg_type == "system":
                    role = "system"
                elif msg_type in ("ai", "assistant"):
                    role = "assistant"
                else:
                    role = "user"
                content = self._content_to_text(msg.content)
            chat.append({"role": role, "content": content})
        return chat

    def _has_chat_template(self) -> bool:
        """Return True when the tokenizer can render a chat template."""
        tokenizer = self.tokenizer
        return callable(getattr(tokenizer, "apply_chat_template", None)) and bool(
            getattr(tokenizer, "chat_template", None)
        )

    def _build_prompt(self, messages: list) -> str:
        """Render *messages* into the prompt string fed to the model.

        Uses the tokenizer's chat template when one is configured.  Otherwise
        falls back to ``ROLE: content`` lines followed by ``ASSISTANT:``.
        Checking only for the ``apply_chat_template`` method is not enough:
        the method exists on tokenizers without a template and raises there.
        """
        chat = self._to_chat_messages(messages)
        if self._has_chat_template():
            return self.tokenizer.apply_chat_template(
                chat,
                tokenize=False,
                add_generation_prompt=True,
            )
        lines = [f"{m['role'].upper()}: {m['content']}" for m in chat]
        lines.append("ASSISTANT:")
        return "\n".join(lines)

    def _generate_text(self, messages: list) -> str:
        """Generate the assistant reply for *messages* and return it stripped."""
        prompt = self._build_prompt(messages)
        encoded = self.tokenizer(
            prompt,
            return_tensors="pt",
            # A rendered chat template already contains its special tokens;
            # adding them again would duplicate e.g. the BOS token.
            add_special_tokens=not self._has_chat_template(),
        )
        # Follow the model's real placement (device_map="auto" decides it)
        # instead of assuming the default CUDA device.
        target_device = self.model.device
        inputs = {name: tensor.to(target_device) for name, tensor in encoded.items()}

        do_sample = self.config.temperature > 0
        with self._generation_lock, torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=self.config.max_new_tokens,
                do_sample=do_sample,
                # Sampling-only knobs are explicitly unset for greedy decoding.
                temperature=self.config.temperature if do_sample else None,
                top_p=self.config.top_p if do_sample else None,
                repetition_penalty=self.config.repetition_penalty,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # generate() returns prompt + continuation; keep only the new tokens.
        prompt_len = inputs["input_ids"].shape[1]
        generated_ids = output_ids[0][prompt_len:]
        return self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    def invoke(self, messages: list) -> AIMessage:
        """Return the model reply wrapped in an ``AIMessage``."""
        return AIMessage(content=self._generate_text(messages))

    def complete(self, messages: list[dict[str, str]]) -> str:
        """Return plain-text completion for role/content message dictionaries."""
        return self._generate_text(messages)

    async def astream(self, messages: list):
        """Yield the reply as ``AIMessage`` chunks.

        This is not token-level streaming: the full reply is generated first
        (in a worker thread, so the event loop stays responsive) and then cut
        into word-sized chunks.  Each chunk keeps the whitespace that followed
        the word, so concatenating all chunks reproduces the reply exactly,
        including newlines and indentation.
        """
        text = await asyncio.to_thread(self._generate_text, messages)
        for chunk in re.findall(r"\S+\s*", text):
            yield AIMessage(content=chunk)
def _resolve_dtype(dtype: str):
    """Map a dtype name from configuration to a ``torch.dtype``.

    Matching ignores case, surrounding whitespace and an optional ``torch.``
    prefix, and accepts the short aliases ``fp16``/``half``, ``bf16`` and
    ``fp32``/``float``.  ``"auto"`` (also used for ``None`` or a blank value)
    selects float16 when CUDA is available and float32 otherwise.

    Unrecognized names resolve to the ``"auto"`` dtype as well, but a warning
    is logged so a typo in the configuration does not go unnoticed.
    """
    auto_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    known = {
        "auto": auto_dtype,
        "float16": torch.float16,
        "fp16": torch.float16,
        "half": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float32": torch.float32,
        "fp32": torch.float32,
        "float": torch.float32,
    }

    key = "" if dtype is None else str(dtype).strip().lower().removeprefix("torch.")
    if not key:
        return auto_dtype
    if key not in known:
        logging.getLogger(__name__).warning(
            "Unknown dtype %r; using the automatic choice %s", dtype, auto_dtype
        )
        return auto_dtype
    return known[key]
def _resolve_device(device: str) -> str:
    """Turn the ``"auto"`` placeholder into a concrete device name.

    Any value other than ``"auto"`` is returned untouched.  ``"auto"`` becomes
    ``"cuda"`` when PyTorch reports CUDA as available and ``"cpu"`` otherwise.
    """
    if device != "auto":
        return device
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
_TRUTHY_ENV_VALUES = frozenset({"1", "true", "yes", "on"})


def _env_value(name: str, default: str) -> str:
    """Return the stripped value of env var *name*, or *default* if unset/blank."""
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip() or default


def _env_number(name: str, default: str, cast):
    """Parse env var *name* with *cast*, naming the variable when parsing fails."""
    raw = _env_value(name, default)
    try:
        return cast(raw)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {name} must be a valid {cast.__name__}, got {raw!r}"
        ) from err


def get_llm(streaming: bool = False) -> LocalHFChat:
    """Return the process-wide ``LocalHFChat``, creating it on first call.

    The instance is configured from ``HF_LLM_*`` environment variables:
    ``HF_LLM_MODEL``, ``HF_LLM_DEVICE``, ``HF_LLM_DTYPE``,
    ``HF_LLM_LOAD_IN_4BIT``, ``HF_LLM_MAX_NEW_TOKENS``, ``HF_LLM_TEMPERATURE``,
    ``HF_LLM_TOP_P`` and ``HF_LLM_REPETITION_PENALTY``.  A variable that is
    unset or blank (as often happens in ``.env`` files) uses its default.
    The 4-bit flag is enabled by ``1``, ``true``, ``yes`` or ``on``
    (case-insensitive).

    Args:
        streaming: Accepted for API compatibility with existing endpoints;
            it does not influence the returned object.

    Raises:
        ValueError: if a numeric variable holds a value that cannot be parsed.
    """
    del streaming  # intentionally unused, see docstring
    global _client
    if _client is None:
        cfg = LocalHFConfig(
            model_id=_env_value("HF_LLM_MODEL", "Qwen/Qwen2.5-1.5B-Instruct"),
            device=_resolve_device(_env_value("HF_LLM_DEVICE", "auto")),
            dtype=_env_value("HF_LLM_DTYPE", "auto"),
            load_in_4bit=_env_value("HF_LLM_LOAD_IN_4BIT", "false").lower()
            in _TRUTHY_ENV_VALUES,
            max_new_tokens=_env_number("HF_LLM_MAX_NEW_TOKENS", "512", int),
            temperature=_env_number("HF_LLM_TEMPERATURE", "0.1", float),
            top_p=_env_number("HF_LLM_TOP_P", "0.9", float),
            repetition_penalty=_env_number("HF_LLM_REPETITION_PENALTY", "1.05", float),
        )
        # Assign only after construction succeeds so a failed load is retried
        # on the next call instead of caching a broken state.
        _client = LocalHFChat(cfg)
    return _client

Xet Storage Details

Size:
6.44 kB
·
Xet hash:
f3b885b05ddb54dd75ac7d3118cf171426a00e0ef4faf1a558d5e47ef559e113

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.