Buckets:
| import os | |
| from dataclasses import dataclass | |
| from typing import Any | |
| import torch | |
| from dotenv import load_dotenv | |
| from langchain_core.messages import AIMessage | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Load variables from a local .env file (if present) into the process
# environment so the os.getenv lookups in get_llm can see them.
load_dotenv()
# Lazily created, module-wide chat client; get_llm assigns it on first use.
_client = None
@dataclass
class LocalHFConfig:
    """Settings used to load a local Hugging Face causal LM and generate with it.

    Declared as a dataclass so it can be built with keyword arguments, which
    is how ``get_llm`` constructs it.
    """

    # Model identifier handed to ``from_pretrained`` for tokenizer and model.
    model_id: str
    # Target device; the loader compares it against the literal "cuda".
    device: str
    # Name of the torch dtype to load with (resolved via ``_resolve_dtype``).
    dtype: str
    # Request 4-bit loading; only honoured when ``device`` is "cuda".
    load_in_4bit: bool
    # Upper bound on the number of tokens generated per call.
    max_new_tokens: int
    # Sampling temperature; a value <= 0 disables sampling.
    temperature: float
    # Nucleus-sampling threshold, forwarded only when sampling is enabled.
    top_p: float
    # Repetition penalty forwarded to ``generate``.
    repetition_penalty: float
class LocalHFChat:
    """Small chat wrapper around a locally loaded Hugging Face causal LM.

    Accepts messages either as ``{"role": ..., "content": ...}`` dictionaries
    or as objects exposing ``type`` and ``content`` attributes, renders them
    into a prompt, and returns the generated continuation.
    """

    _KNOWN_ROLES = frozenset({"system", "user", "assistant"})

    def __init__(self, config: "LocalHFConfig") -> None:
        """Load the tokenizer and model described by *config*.

        When 4-bit loading is requested on CUDA and that load fails, the model
        is loaded again without quantisation. Any other load failure
        propagates unchanged.
        """
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_id, use_fast=True)

        wants_cuda = config.device == "cuda"
        # NOTE(review): trust_remote_code lets the model repository run its own
        # Python code at load time; keep model_id limited to trusted sources.
        load_kwargs: dict[str, Any] = {"trust_remote_code": True}
        if wants_cuda:
            load_kwargs["device_map"] = "auto"

        if wants_cuda and config.load_in_4bit:
            # Only the 4-bit attempt is guarded, so failures of an ordinary
            # load are never masked or retried.
            try:
                self.model = AutoModelForCausalLM.from_pretrained(
                    config.model_id, load_in_4bit=True, **load_kwargs
                )
            except Exception:
                # Graceful fallback if bitsandbytes / 4-bit setup is unavailable.
                self.model = AutoModelForCausalLM.from_pretrained(
                    config.model_id,
                    torch_dtype=_resolve_dtype(config.dtype),
                    **load_kwargs,
                )
        else:
            if wants_cuda:
                load_kwargs["torch_dtype"] = _resolve_dtype(config.dtype)
            self.model = AutoModelForCausalLM.from_pretrained(config.model_id, **load_kwargs)

        # generate() is given pad_token_id below; reuse EOS when none is set.
        if self.tokenizer.pad_token_id is None and self.tokenizer.eos_token_id is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self._device = "cuda" if torch.cuda.is_available() and wants_cuda else "cpu"
        if self._device == "cpu":
            self.model.to("cpu")

    def _to_chat_messages(self, messages: list) -> list[dict[str, str]]:
        """Normalise *messages* into ``{"role", "content"}`` dictionaries.

        Dictionaries keep their role when it is system/user/assistant
        (case-insensitive) and otherwise become "user". Other objects are
        mapped by their ``type`` attribute: "system" -> system,
        "ai"/"assistant" -> assistant, anything else -> user.
        """
        chat = []
        for msg in messages:
            if isinstance(msg, dict):
                role = str(msg.get("role", "user")).strip().lower() or "user"
                if role not in self._KNOWN_ROLES:
                    role = "user"
                chat.append({"role": role, "content": str(msg.get("content", ""))})
                continue

            msg_type = getattr(msg, "type", "human")
            if msg_type == "system":
                role = "system"
            elif msg_type in ("ai", "assistant"):
                role = "assistant"
            else:
                role = "user"
            chat.append({"role": role, "content": str(msg.content)})
        return chat

    def _build_prompt(self, messages: list) -> str:
        """Render *messages* into a single prompt string.

        Uses the tokenizer's chat template when one is configured; otherwise
        falls back to ``ROLE: content`` lines ending with ``ASSISTANT:``.
        """
        chat = self._to_chat_messages(messages)
        apply_template = getattr(self.tokenizer, "apply_chat_template", None)
        # Having the method is not enough: it exists on tokenizers that ship
        # no template, so the template itself must be present too.
        if apply_template is not None and getattr(self.tokenizer, "chat_template", None):
            return apply_template(
                chat,
                tokenize=False,
                add_generation_prompt=True,
            )
        # Fallback prompt format if tokenizer has no chat template
        lines = [f"{m['role'].upper()}: {m['content']}" for m in chat]
        lines.append("ASSISTANT:")
        return "\n".join(lines)

    def _generate_text(self, messages: list) -> str:
        """Generate a reply for *messages* and return only the new text."""
        prompt = self._build_prompt(messages)
        inputs = self.tokenizer(prompt, return_tensors="pt")
        if self._device == "cuda":
            inputs = {k: v.to("cuda") for k, v in inputs.items()}

        do_sample = self.config.temperature > 0
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=self.config.max_new_tokens,
                do_sample=do_sample,
                # Sampling knobs are passed as None when decoding greedily.
                temperature=self.config.temperature if do_sample else None,
                top_p=self.config.top_p if do_sample else None,
                repetition_penalty=self.config.repetition_penalty,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # The output starts with the prompt tokens; drop them before decoding.
        prompt_len = inputs["input_ids"].shape[1]
        generated_ids = output_ids[0][prompt_len:]
        return self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    def invoke(self, messages: list) -> "AIMessage":
        """Generate a reply and wrap it in an ``AIMessage``."""
        return AIMessage(content=self._generate_text(messages))

    def complete(self, messages: list[dict[str, str]]) -> str:
        """Return plain-text completion for role/content message dictionaries."""
        return self._generate_text(messages)

    async def astream(self, messages: list):
        """Yield the reply as a sequence of ``AIMessage`` chunks.

        Lightweight streaming fallback: the full text is generated first and
        then split on single spaces; each non-empty piece is yielded with a
        trailing space.
        """
        import asyncio  # local import keeps the module's import block untouched

        # Generation is blocking; run it in a worker thread so the event loop
        # keeps serving other tasks. Concurrent streams may therefore overlap.
        text = await asyncio.to_thread(self._generate_text, messages)
        for token in text.split(" "):
            if token:
                yield AIMessage(content=f"{token} ")
def _resolve_dtype(dtype: str) -> torch.dtype:
    """Translate a dtype name into the corresponding ``torch.dtype``.

    Recognised names are "float16", "bfloat16" and "float32", matched
    case-insensitively and ignoring surrounding whitespace so values taken
    from environment variables behave as expected. "auto", any unrecognised
    name, or a non-string value resolves to float16 when CUDA is available
    and float32 otherwise.
    """
    auto_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    named = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    key = dtype.strip().lower() if isinstance(dtype, str) else "auto"
    return named.get(key, auto_dtype)
def _resolve_device(device: str) -> str:
    """Return *device* unchanged unless it is "auto".

    "auto" becomes "cuda" when torch reports CUDA as available, else "cpu".
    """
    if device != "auto":
        return device
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
def get_llm(streaming: bool = False) -> LocalHFChat:
    """Return the module-wide ``LocalHFChat``, creating it on the first call.

    The first call reads the ``HF_LLM_*`` environment variables, builds a
    ``LocalHFConfig`` from them and stores the resulting client in
    ``_client``; later calls return that same object.

    Args:
        streaming: Unused; accepted for API compatibility with existing
            endpoints.
    """
    global _client
    _ = streaming

    if _client is not None:
        return _client

    env = os.getenv
    settings = LocalHFConfig(
        model_id=env("HF_LLM_MODEL", "Qwen/Qwen2.5-1.5B-Instruct"),
        device=_resolve_device(env("HF_LLM_DEVICE", "auto")),
        dtype=env("HF_LLM_DTYPE", "auto"),
        # Only the literal string "true" (any casing) enables 4-bit loading.
        load_in_4bit=env("HF_LLM_LOAD_IN_4BIT", "false").lower() == "true",
        max_new_tokens=int(env("HF_LLM_MAX_NEW_TOKENS", 512)),
        temperature=float(env("HF_LLM_TEMPERATURE", 0.1)),
        top_p=float(env("HF_LLM_TOP_P", 0.9)),
        repetition_penalty=float(env("HF_LLM_REPETITION_PENALTY", 1.05)),
    )
    _client = LocalHFChat(settings)
    return _client
Xet Storage Details
- Size:
- 6.44 kB
- Xet hash:
- f3b885b05ddb54dd75ac7d3118cf171426a00e0ef4faf1a558d5e47ef559e113
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.