import os
from typing import Iterator, List, Tuple


def get_default_local_model() -> str:
    """Return the model id from ``LOCAL_MODEL``, or the TinyLlama chat default."""
    return os.getenv("LOCAL_MODEL", "TinyLlama/TinyLlama-1.1B-Chat-v1.0")


class LocalHFBackend:
    """Chat backend that runs a Hugging Face causal language model in-process.

    The tokenizer and model are loaded in the constructor; replies are
    produced by :meth:`generate_stream`.
    """

    def __init__(self, model_name: str):
        """Load the tokenizer and model identified by *model_name*.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        self.model_name = model_name

        # Imported here rather than at module level, so importing this module
        # does not itself import transformers or torch.
        from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
        import torch

        self.torch = torch
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        # float16 with automatic device placement when CUDA is available,
        # otherwise float32 with no device map.
        has_cuda = torch.cuda.is_available()
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16 if has_cuda else torch.float32,
            device_map="auto" if has_cuda else None,
        )
        self.streamer_cls = TextIteratorStreamer

    def _build_prompt(
        self,
        system_prompt: str,
        history: List[Tuple[str, str]],
        user_msg: str,
    ) -> str:
        """Render the conversation as a tagged plain-text prompt.

        Args:
            system_prompt: Text placed under the ``<|system|>`` tag.
            history: Earlier ``(user, assistant)`` turns; an empty side of a
                turn is omitted from the prompt.
            user_msg: The new user message.

        Returns:
            The prompt, ending with an open ``<|assistant|>`` tag for the
            model to continue.
        """
        parts = [f"<|system|>\n{system_prompt}\n"]
        for u, a in history:
            if u:
                parts.append(f"<|user|>\n{u}\n")
            if a:
                parts.append(f"<|assistant|>\n{a}\n")
        parts.append(f"<|user|>\n{user_msg}\n\n<|assistant|>\n")
        return "".join(parts)

    def generate_stream(
        self,
        system_prompt: str,
        history: List[Tuple[str, str]],
        user_msg: str,
        temperature: float,
        max_new_tokens: int,
    ) -> Iterator[str]:
        """Generate a reply and stream it as it is produced.

        Generation runs in a background thread that feeds a text streamer;
        this generator consumes the streamer.

        Args:
            system_prompt: System instruction for the prompt.
            history: Earlier ``(user, assistant)`` turns.
            user_msg: The new user message.
            temperature: Sampling temperature; a value ``<= 0`` selects
                greedy decoding.
            max_new_tokens: Upper bound on the number of generated tokens.

        Yields:
            The accumulated reply text so far (not just the newest chunk).

        Raises:
            Exception: Whatever ``model.generate`` raised in the worker
                thread, re-raised here once the stream has ended.
        """
        from threading import Thread
        from transformers import StoppingCriteria, StoppingCriteriaList

        stop_tag = "<|user|>"
        prompt = self._build_prompt(system_prompt, history, user_msg)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # Captured by the stopping criterion below.  Inside the nested class
        # ``self`` is the criterion object, not this backend, so the tokenizer
        # must be reached through the closure.
        tokenizer = self.tokenizer
        prompt_len = inputs["input_ids"].shape[1]

        class StopOnUserTag(StoppingCriteria):
            """Stop once the model starts writing a new user turn."""

            def __call__(self, input_ids, scores, **kwargs):
                # Look only at generated tokens: the prompt itself ends with a
                # "<|user|>" turn and would otherwise trigger an immediate stop.
                generated = input_ids[0].tolist()[prompt_len:]
                return stop_tag in tokenizer.decode(generated[-20:])

        streamer = self.streamer_cls(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )

        sampling = temperature > 0
        gen_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=sampling,
            eos_token_id=self.tokenizer.eos_token_id,
            stopping_criteria=StoppingCriteriaList([StopOnUserTag()]),
        )
        if sampling:
            # Temperature only applies when sampling; it is left out for
            # greedy decoding, where a value such as 0 is not meaningful.
            gen_kwargs["temperature"] = temperature

        errors = []

        def _run_generation():
            try:
                self.model.generate(**gen_kwargs)
            except Exception as exc:  # thread boundary: re-raised in the consumer
                errors.append(exc)
                # Close the stream so the consumer loop below does not block
                # forever waiting for text that will never arrive.
                streamer.end()

        thread = Thread(target=_run_generation)
        thread.start()

        buf = []
        for token in streamer:
            buf.append(token)
            text = "".join(buf)
            cut = text.find(stop_tag)
            if cut != -1:
                # Do not leak the stop marker (or anything after it) to the caller.
                yield text[:cut]
                break
            yield text

        thread.join()
        if errors:
            raise errors[0]