import os
from typing import Iterator, List, Tuple


def get_default_local_model() -> str:
    return os.getenv("LOCAL_MODEL", "TinyLlama/TinyLlama-1.1B-Chat-v1.0")


class LocalHFBackend:
    def __init__(self, model_name: str):
        self.model_name = model_name
        # Lazy imports so the module can be loaded without transformers/torch installed.
        from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
        import torch

        self.torch = torch
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            # Half precision on GPU, full precision on CPU.
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
        )
        self.streamer_cls = TextIteratorStreamer

    def _build_prompt(self, system_prompt: str, history: List[Tuple[str, str]], user_msg: str) -> str:
        # Zephyr/TinyLlama-style chat template: <|system|>/<|user|>/<|assistant|>
        # turns, each terminated by the </s> end-of-sequence marker.
        parts = [f"<|system|>\n{system_prompt}\n</s>"]
        for u, a in history:
            if u:
                parts.append(f"<|user|>\n{u}\n</s>")
            if a:
                parts.append(f"<|assistant|>\n{a}\n</s>")
        parts.append(f"<|user|>\n{user_msg}\n</s>\n<|assistant|>\n")
        return "".join(parts)
    def generate_stream(
        self,
        system_prompt: str,
        history: List[Tuple[str, str]],
        user_msg: str,
        temperature: float,
        max_new_tokens: int,
    ) -> Iterator[str]:
        from threading import Thread
        from transformers import StoppingCriteria, StoppingCriteriaList

        prompt = self._build_prompt(system_prompt, history, user_msg)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # Capture the tokenizer in the closure: inside the nested class, `self`
        # is the StoppingCriteria instance, which has no tokenizer attribute.
        tokenizer = self.tokenizer

        class StopOnAssistantTag(StoppingCriteria):
            def __call__(self, input_ids, scores, **kwargs):
                # Stop once the model starts a new user turn.
                text = tokenizer.decode(input_ids[0].tolist()[-20:])
                return "</s><|user|>" in text

        streamer = self.streamer_cls(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=temperature > 0,  # greedy decoding when temperature is 0
            eos_token_id=self.tokenizer.eos_token_id,
            stopping_criteria=StoppingCriteriaList([StopOnAssistantTag()]),
        )
        # Run generation in a background thread; the streamer yields text
        # chunks on this thread as they are produced.
        thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
        thread.start()
        buf = []
        for token in streamer:
            buf.append(token)
            # Yield the full response so far (convenient for chat UIs that
            # re-render the whole message on each update).
            yield "".join(buf)
        thread.join()