Nurisslam commited on
Commit
b3765b1
·
verified ·
1 Parent(s): c024ceb

Create huggingface_utils/chatmodel.py

Browse files
Files changed (1) hide show
  1. huggingface_utils/chatmodel.py +33 -0
huggingface_utils/chatmodel.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import asyncio
import os
from threading import Thread

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+
6
class HuggingFaceLLM:
    """Thin async streaming wrapper around a Hugging Face causal LM.

    Loads a gated/private model using the ``HUGGINGFACE_TOKEN`` environment
    variable and exposes ``astream`` — an async generator yielding generated
    text chunks as they are produced.
    """

    def __init__(self, model_name: str = "mistralai/Mistral-7B-Instruct-v0.1"):
        """Load tokenizer and model.

        Args:
            model_name: Hub repo id of the causal LM to load.

        Raises:
            ValueError: if ``HUGGINGFACE_TOKEN`` is not set in the environment.
        """
        hf_token = os.getenv("HUGGINGFACE_TOKEN")
        if not hf_token:
            raise ValueError("HUGGINGFACE_TOKEN not found in environment")

        # `token=` replaces the deprecated `use_auth_token=` kwarg in
        # recent transformers releases.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            token=hf_token,
        )

    async def astream(self, messages):
        """Stream generated text for a list of chat messages.

        Args:
            messages: iterable of dicts with a ``"content"`` key; contents are
                joined with newlines into a single flat prompt.
                NOTE(review): roles are ignored — no chat template is applied;
                confirm this is intended for instruct-tuned models.

        Yields:
            str: decoded text chunks (prompt and special tokens skipped).
        """
        # Join once instead of repeated string concatenation; trailing
        # newline preserves the original prompt format exactly.
        prompt = "\n".join(msg["content"] for msg in messages) + "\n"

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=500, do_sample=True)

        # daemon=True so a stalled generate() cannot hang interpreter exit.
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs, daemon=True)
        thread.start()

        # The streamer blocks while waiting for the next chunk; pulling each
        # item via asyncio.to_thread keeps the event loop responsive instead
        # of blocking it for the entire generation (the original iterated the
        # streamer synchronously inside this async generator).
        chunks = iter(streamer)
        done = object()
        while True:
            new_text = await asyncio.to_thread(next, chunks, done)
            if new_text is done:
                break
            yield new_text