longdiyao commited on
Commit
d85b784
·
verified ·
1 Parent(s): 8c35de8

Update src/models/chatglm2.py

Browse files
Files changed (1) hide show
  1. src/models/chatglm2.py +0 -24
src/models/chatglm2.py CHANGED
@@ -1,24 +0,0 @@
1
- # src/models/chatglm2.py
2
- from transformers import AutoTokenizer, AutoModel
3
- import torch
4
- import os
5
-
6
- class ChatGLM2:
7
- def __init__(self):
8
- model_id = "THUDM/chatglm2-6b-int4"
9
- offload_path = "./offload_chatglm2"
10
- os.makedirs(offload_path, exist_ok=True)
11
-
12
- self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
13
- self.model = AutoModel.from_pretrained(
14
- model_id,
15
- trust_remote_code=True,
16
- torch_dtype=torch.float16,
17
- device_map="auto",
18
- offload_folder=offload_path
19
- ).eval()
20
-
21
- def generate(self, prompt):
22
- inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
23
- outputs = self.model.generate(**inputs, max_new_tokens=256)
24
- return self.tokenizer.decode(outputs[0], skip_special_tokens=True)