import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

class TinyLlama:
    def __init__(self) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        )
        # Configure 4-bit quantization through BitsAndBytesConfig; passing
        # load_in_4bit / bnb_4bit_compute_dtype directly to from_pretrained
        # is deprecated in recent transformers releases.
        self.model = AutoModelForCausalLM.from_pretrained(
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
            ),
            device_map="auto",
        )
        print(f"LLM loaded to {self.model.device}")
        self._messages = []  # placeholder for conversation history (unused in this snippet)

    def __call__(self, messages):
        # Render the chat messages into the model's expected prompt format.
        tokenized_chat = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.tokenizer(tokenized_chat, return_tensors="pt").to(
            self.model.device
        )
        outputs = self.model.generate(
            **inputs,
            use_cache=True,
            max_new_tokens=1000,  # caps generated tokens only; max_length would count the prompt too
            min_new_tokens=10,
            temperature=0.7,
            num_return_sequences=1,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,  # Llama tokenizers lack a pad token
        )
        # Decode only the newly generated tokens, skipping the echoed prompt.
        generated_text = self.tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
        )
        return generated_text