Muhammadidrees committed
Commit 7a855fe · verified · 1 Parent(s): 2856d25

Upload chat.py

Files changed (1)
chat.py +95 -0
chat.py ADDED
@@ -0,0 +1,95 @@
import torch
import transformers

model = None
tokenizer = None
generator = None

def load_model(model_name, eight_bit=0, device_map="auto"):
    global model, tokenizer, generator

    print("Loading " + model_name + "...")

    if device_map == "zero":
        device_map = "balanced_low_0"

    gpu_count = torch.cuda.device_count()
    print("gpu_count", gpu_count)

    tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name)
    model = transformers.LlamaForCausalLM.from_pretrained(
        model_name,
        # device_map=device_map,
        # max_memory={0: "14GB", 1: "14GB", 2: "14GB", 3: "14GB", 4: "14GB", 5: "14GB", 6: "14GB", 7: "14GB"},
        # load_in_8bit=eight_bit,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        load_in_8bit=False,
        cache_dir="cache",
    ).cuda()

    generator = model.generate

load_model("./pretrained")

history = []

def go():
    invitation = "Assistant: "
    human_invitation = "Human: "

    # Read the user's turn and append it to the running transcript.
    msg = input(human_invitation)
    print("")

    history.append(human_invitation + msg)

    # The prompt is the whole conversation so far, ending with the assistant cue.
    fulltext = "\n\n".join(history) + "\n\n" + invitation

    gen_in = tokenizer(fulltext, return_tensors="pt").input_ids.cuda()
    with torch.no_grad():
        generated_ids = generator(
            gen_in,
            max_new_tokens=200,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1,
            do_sample=True,
            repetition_penalty=1.1,  # 1.0 means "off"; too strong a penalty suppresses the speaker tags
            temperature=0.5,  # default: 1.0
            top_k=50,  # default: 50
            top_p=1.0,  # default: 1.0
        )
    # batch_decode returns one string per sequence; we generate a single sequence.
    generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Drop the prompt, then cut the reply off at the model's next "Human: " turn.
    response = generated_text[len(fulltext):]
    response = response.split(human_invitation)[0]
    response = response.strip()

    print(invitation + response)
    print("")

    history.append(invitation + response)

while True:
    go()