import asyncio
import os
from threading import Thread

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
|
|
class HuggingFaceLLM:
    """Thin async-streaming wrapper around a Hugging Face causal LM.

    Loads the tokenizer and model once at construction time (fp16,
    device placement via ``device_map="auto"``) and exposes ``astream``,
    an async generator that yields decoded text chunks as they are
    produced by ``model.generate``.
    """

    def __init__(self, model_name="mistralai/Mistral-7B-Instruct-v0.1"):
        """Load tokenizer and model.

        Args:
            model_name: Hub model id to load (gated models require auth).

        Raises:
            ValueError: if ``HUGGINGFACE_TOKEN`` is not set in the environment.
        """
        hf_token = os.getenv("HUGGINGFACE_TOKEN")
        if not hf_token:
            raise ValueError("HUGGINGFACE_TOKEN not found in environment")

        # ``token=`` replaces the deprecated ``use_auth_token=`` kwarg
        # (transformers >= 4.32).
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            token=hf_token,
        )

    async def astream(self, messages):
        """Stream generated text for ``messages``.

        Args:
            messages: iterable of dicts; only each message's ``"content"``
                is used, joined with newlines into a flat prompt.
                NOTE(review): roles are ignored — for an Instruct model,
                ``tokenizer.apply_chat_template`` may be more appropriate;
                kept as-is to preserve the existing prompt format.

        Yields:
            Decoded text chunks (prompt and special tokens skipped).
        """
        # O(n) prompt build instead of repeated string concatenation;
        # preserves the original "content\n" per message layout exactly.
        prompt = "".join(f"{msg['content']}\n" for msg in messages)

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(
            **inputs, streamer=streamer, max_new_tokens=500, do_sample=True
        )

        # generate() blocks, so it runs on a worker thread while the
        # streamer feeds chunks back. daemon=True so an abandoned
        # generator cannot keep the process alive.
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs, daemon=True)
        thread.start()

        try:
            # Iterating the streamer directly would block the event loop:
            # each __next__ waits on an internal queue. Pull each chunk on
            # a worker thread instead so other tasks keep running.
            sentinel = object()
            chunks = iter(streamer)
            while True:
                chunk = await asyncio.to_thread(next, chunks, sentinel)
                if chunk is sentinel:
                    break
                yield chunk
        finally:
            # Ensure the generation thread is reaped even if the consumer
            # abandons the generator early.
            thread.join()