eshwar06 commited on
Commit
6f18a74
·
verified ·
1 Parent(s): 337cff7

Upload Linly.py

Browse files
Files changed (1) hide show
  1. LLM/Linly.py +89 -0
LLM/Linly.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import requests
4
+ import json
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
+ from configs import ip, api_port, model_path
7
+ os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
8
+
9
+ class Linly:
10
+ def __init__(self, mode='api', model_path="Linly-AI/Chinese-LLaMA-2-7B-hf", prefix_prompt = '''请用少于25个字回答以下问题\n\n'''):
11
+ # mode = api need
12
+ # 定义设置的api的服务器,首先记得运行Linly-api-fast.py 填入ip地址和端口号
13
+ self.url = f"http://{ip}:{api_port}" # local server: http://ip:port
14
+ self.headers = {
15
+ "Content-Type": "application/json"
16
+ }
17
+ self.data = {
18
+ "question": "北京有什么好玩的地方?"
19
+ }
20
+ # 全局设定的prompt
21
+ self.prefix_prompt = prefix_prompt
22
+ self.mode = mode
23
+ if mode != 'api':
24
+ self.model, self.tokenizer = self.init_model(model_path)
25
+ self.history = []
26
+
27
+ def init_model(self, path = "Linly-AI/Chinese-LLaMA-2-7B-hf"):
28
+ model = AutoModelForCausalLM.from_pretrained(path, device_map="cuda:0",
29
+ torch_dtype=torch.bfloat16, trust_remote_code=True)
30
+ tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False, trust_remote_code=True)
31
+ return model, tokenizer
32
+
33
+ def generate(self, question, system_prompt=""):
34
+ if self.mode != 'api':
35
+ self.data["question"] = self.message_to_prompt(question, system_prompt)
36
+ inputs = self.tokenizer(self.data["question"], return_tensors="pt").to("cuda:0")
37
+ try:
38
+ generate_ids = self.model.generate(inputs.input_ids,
39
+ max_new_tokens=2048,
40
+ do_sample=True,
41
+ top_k=20,
42
+ top_p=0.84,
43
+ temperature=1,
44
+ repetition_penalty=1.15,
45
+ eos_token_id=2,
46
+ bos_token_id=1,
47
+ pad_token_id=0)
48
+ response = self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
49
+ response = response.split("### Response:")[-1]
50
+ return response
51
+ except:
52
+ return "对不起,你的请求出错了,请再次尝试。\nSorry, your request has encountered an error. Please try again.\n"
53
+ elif self.mode == 'api':
54
+ return self.predict_api(question)
55
+
56
+ def message_to_prompt(self, message, system_prompt=""):
57
+ system_prompt = self.prefix_prompt + system_prompt
58
+ for interaction in self.history:
59
+ user_prompt, bot_prompt = str(interaction[0]).strip(' '), str(interaction[1]).strip(' ')
60
+ system_prompt = f"{system_prompt} User: {user_prompt} Bot: {bot_prompt}"
61
+ prompt = f"{system_prompt} ### Instruction:{message.strip()} ### Response:"
62
+ return prompt
63
+
64
+ def predict_api(self, question):
65
+ # FastAPI Predict 调用API来进行预测
66
+ self.data["question"] = question
67
+ headers = {'Content-Type': 'application/json'}
68
+ data = {"prompt": question}
69
+ response = requests.post(url=self.url, headers=headers, data=json.dumps(data))
70
+ return response.json()['response']
71
+
72
+ def chat(self, system_prompt, message, history):
73
+ self.history = history
74
+ prompt = self.message_to_prompt(message, system_prompt)
75
+ response = self.generate(prompt)
76
+ self.history.append([message, response])
77
+ return response, self.history
78
+
79
+ def clear_history(self):
80
+ # 清空历史记录
81
+ self.history = []
82
+
83
+ def test():
84
+ llm = Linly(mode='offline',model_path='../Linly-AI/Chinese-LLaMA-2-7B-hf')
85
+ answer = llm.generate("如何应对压力?")
86
+ print(answer)
87
+
88
+ if __name__ == '__main__':
89
+ test()