lingbionlp committed on
Commit
0acea94
·
verified ·
1 Parent(s): 3c51fbe

Upload 2 files

Browse files
Files changed (2) hide show
  1. requirements.txt +6 -0
  2. taiyi2_chat.py +59 -0
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch==2.4.0
2
+ ms_swift==2.6.1
3
+ transformers==4.44.0
4
+ transformers-stream-generator==0.0.5
5
+ vllm==0.6.0
6
+ vllm-flash-attn==2.6.1
taiyi2_chat.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import os
3
+ import torch
4
+ torch.cuda.empty_cache()
5
+ import time
6
+ import json
7
+ from swift.llm import (
8
+ ModelType, get_vllm_engine, get_default_template_type,
9
+ get_template, inference_vllm,VllmGenerationConfig,inference_vllm
10
+ )
11
+
12
# Location of the Taiyi2-chat checkpoint, relative to the working directory.
model_path = "../Models/Taiyi2-chat"  # model path
# NOTE(review): `device` is not referenced anywhere else in the visible code.
device = "cuda:0"

# The checkpoint is loaded using swift's GLM-4-9B-chat model type.
model_type = ModelType.glm4_9b_chat

# Build the vLLM engine once at import time; generate() reuses it on every call.
llm_engine = get_vllm_engine(
    model_type,
    torch.bfloat16,
    model_id_or_path=model_path,
    gpu_memory_utilization=0.9,
    max_model_len=8192,
)

# Chat template matching the model type, bound to the engine's tokenizer.
template_type = get_default_template_type(model_type)
template = get_template(template_type, llm_engine.hf_tokenizer)
19
+
20
+
21
+
22
+ #Chat
23
def generate(message, history=None, repetition_penalty=1.05, max_tokens=500, temperature=0.3,
             top_p=0.7, top_k=20, presence_penalty=1.0):
    """Run a single chat turn through the module-level vLLM engine.

    Args:
        message: The user's query for this turn.
        history: Conversation history from previous turns, as returned under
            the ``'history'`` key of an earlier result. ``None`` (the
            default) starts a fresh conversation.
        repetition_penalty: Forwarded to ``VllmGenerationConfig``.
        max_tokens: Forwarded to ``VllmGenerationConfig``.
        temperature: Forwarded to ``VllmGenerationConfig``.
        top_p: Forwarded to ``VllmGenerationConfig``.
        top_k: Forwarded to ``VllmGenerationConfig``.
        presence_penalty: Forwarded to ``VllmGenerationConfig``. Defaults to
            ``1.0``, the numeric value of the ``True`` that was previously
            hard-coded here.

    Returns:
        The first (and only) element of the list returned by
        ``inference_vllm``; the caller in this file reads its ``'response'``
        and ``'history'`` entries.
    """
    # Normalise a missing history so the request always carries a list.
    turns = [] if history is None else history
    request_list = [{'query': message, 'history': turns}]

    sampling = VllmGenerationConfig(
        repetition_penalty=repetition_penalty,
        # A float penalty rather than a bool flag; True was being used as 1.0.
        presence_penalty=presence_penalty,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
    )

    responses = inference_vllm(
        llm_engine,
        template,
        request_list,
        generation_config=sampling,
    )

    # Exactly one request was submitted, so only the first result is relevant.
    return responses[0]
44
+
45
+
46
if __name__ == '__main__':
    # Interactive multi-turn chat: read a line, send it to the model together
    # with the accumulated history, print the reply, and carry the updated
    # history into the next turn. Typing END stops the loop.
    history = []
    while True:
        try:
            message = input("Input: ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C / closed stdin: leave the loop the same way as
            # the END command instead of dying with a traceback.
            print()
            message = "END"

        if message == "END":
            print("END!")
            break

        response = generate(message, history)
        print("Output:" + response['response'])
        # The result carries the history including this turn; reuse it so the
        # next query is answered in context.
        history = response['history']
57
+
58
+
59
+