warleagle committed on
Commit
fe83af5
·
verified ·
1 Parent(s): 521a0d9

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +59 -0
  2. conversation.py +43 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import gradio as gr
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from conversation import Conversation

# Hugging Face Hub id of the LoRA adapter repo (also hosts tokenizer and
# generation config, per the from_pretrained calls below).
MODEL_NAME = "warleagle/medical_chat_saiga"

# The PEFT config records which base model the adapter was trained on.
config = PeftConfig.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    load_in_8bit=True,          # 8-bit quantization to reduce GPU memory use
    torch_dtype=torch.float16,
    device_map="auto"           # let accelerate place layers on available devices
)

# Attach the LoRA adapter weights on top of the quantized base model.
model = PeftModel.from_pretrained(
    model,
    MODEL_NAME,
    torch_dtype=torch.float16
)
model.eval()  # inference mode: disables dropout and similar training behavior

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
generation_config.max_new_tokens = 70  # cap the length of each generated answer
27
+
28
+
29
def generate(model, tokenizer, prompt, generation_config):
    """Run *model* on *prompt* and return only the newly generated text.

    Tokenizes the prompt without special tokens, moves every input tensor
    to the model's device, and strips the prompt tokens from the output
    before decoding, so the caller receives just the completion.
    """
    encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}

    generated = model.generate(**encoded, generation_config=generation_config)[0]

    # Drop the echoed prompt tokens; keep only the model's continuation.
    prompt_length = len(encoded["input_ids"][0])
    completion_ids = generated[prompt_length:]

    text = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return text.strip()
39
+
40
def predict(input_data, temp):
    """Gradio handler: answer *input_data* sampled at temperature *temp*."""
    # NOTE: mutates the shared module-level generation_config in place.
    generation_config.temperature = temp

    dialog = Conversation()
    dialog.add_user_message(input_data)
    rendered_prompt = dialog.get_prompt(tokenizer)

    return generate(model, tokenizer, rendered_prompt, generation_config)
49
+
50
# Simple Gradio UI: a text box for the user question and a slider for the
# sampling temperature; the single output box shows the model's answer.
io = gr.Interface(
    predict,
    inputs=[
        gr.Text(value="Как записаться к стоматологу?"),
        gr.Slider(minimum=0.01,
                  maximum=1,
                  value=0.3,
                  step=0.1),
    ],
    outputs=[gr.Text()],
)

# BUG FIX: the original compared __name__ against 'main'.  The interpreter
# sets __name__ to '__main__' when a file is run directly, so the app was
# never launched.  Compare against '__main__' instead.
if __name__ == '__main__':
    io.launch()
conversation.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class Conversation:
    """Accumulates chat messages and renders them into a Saiga-style prompt.

    Each message is wrapped as "<s>{role}\n{content}</s>", and the rendered
    prompt ends with an open "<s>bot\n" tag so the model continues in the
    bot role.  The history always starts with a fixed system instruction.
    """

    def __init__(
        self,
    ):

        self.message_template = "<s>{role}\n{content}</s>"
        self.response_template = "<s>bot\n"
        self.messages = [{
            "role": "system",
            "content": "Предложи ответ оператора технической поддержки на вопрос пользователя из чата."
        }]

    def add_user_message(self, message):
        """Append a user turn to the history."""
        self.messages.append({
            "role": "user",
            "content": message
        })

    def add_bot_message(self, message):
        """Append a bot turn to the history."""
        self.messages.append({
            "role": "bot",
            "content": message
        })

    def get_prompt(self, tokenizer=None):
        """Render the message history into a single prompt string.

        BUG FIX: app.py calls ``conversation.get_prompt(tokenizer)``, but
        the original signature took no argument, so every call raised
        TypeError.  ``tokenizer`` is accepted (and currently unused) to
        stay compatible with both call styles.
        """
        final_text = ""
        for message in self.messages:
            message_text = self.message_template.format(**message)
            final_text += message_text
        # Open the bot tag (same literal as self.response_template) so the
        # model generates the bot's reply next.
        final_text += "<s>bot\n"
        return final_text.strip()
32
+
33
+
34
def generate(model, tokenizer, prompt, generation_config):
    """Generate a completion for *prompt* and return only the new tokens' text.

    Mirrors the helper in app.py: tokenize without special tokens, move
    inputs onto the model's device, generate, then decode everything after
    the prompt tokens.
    """
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    inputs = {key: value.to(model.device) for key, value in inputs.items()}

    sequence = model.generate(**inputs, generation_config=generation_config)[0]

    # Skip the prompt portion of the output sequence before decoding.
    new_token_ids = sequence[len(inputs["input_ids"][0]):]
    return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()