mkoot007 committed on
Commit
b9ea6e1
·
1 Parent(s): 1fbe707

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -43
app.py CHANGED
@@ -1,45 +1,85 @@
1
- import gradio as gr
2
- from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration
3
-
4
- tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
5
- model = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")
6
-
7
- def chat_with_model(input_text):
8
- input_ids = tokenizer.encode("You: " + input_text, return_tensors="pt", max_length=512, truncation=True)
9
- response_ids = model.generate(input_ids, max_length=100, num_return_sequences=1, no_repeat_ngram_size=2)
10
- reply = tokenizer.decode(response_ids[0], skip_special_tokens=True)
11
- return reply
12
-
13
- custom_css = """
14
- <style>
15
- .gr-form-box {
16
- border: 2px solid #0074D9;
17
- border-radius: 10px;
18
- background-color: #F5F5F5;
19
- padding: 10px;
20
- }
21
- .gr-textbox {
22
- background-color: #EAEAEA;
23
- border: 1px solid #0074D9;
24
- border-radius: 5px;
25
- padding: 5px;
26
- }
27
- .gr-button {
28
- background-color: #0074D9;
29
- color: #FFFFFF;
30
- border: 2px solid #0074D9;
31
- border-radius: 5px;
32
- padding: 5px 10px;
33
- }
34
- </style>
35
- """
36
-
37
- iface = gr.Interface(
38
- fn=chat_with_model,
39
- inputs=gr.Textbox(prompt="You:"),
40
- outputs=gr.Textbox(prompt="Bot:"),
41
- live=True,
42
- custom_css=custom_css # Add your custom CSS styles here
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  )
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- iface.launch()
 
 
 
 
 
 
1
+ import torch
2
+ from peft import PeftModel, PeftConfig
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
4
+
5
# Model repo and the Saiga chat-prompt format it was trained with.
MODEL_NAME = "IlyaGusev/saiga_mistral_7b"
DEFAULT_MESSAGE_TEMPLATE = "<s>{role}\n{content}</s>"
DEFAULT_RESPONSE_TEMPLATE = "<s>bot\n"
DEFAULT_SYSTEM_PROMPT = "Ты Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им."

class Conversation:
    """Accumulates a chat history and renders it as a Saiga-format prompt.

    Messages are stored as ``{"role": ..., "content": ...}`` dicts; the first
    entry is always the system prompt.
    """

    def __init__(
        self,
        message_template=DEFAULT_MESSAGE_TEMPLATE,
        system_prompt=DEFAULT_SYSTEM_PROMPT,
        response_template=DEFAULT_RESPONSE_TEMPLATE
    ):
        self.message_template = message_template
        self.response_template = response_template
        self.messages = [{
            "role": "system",
            "content": system_prompt
        }]

    def add_user_message(self, message):
        """Append a user turn to the history."""
        self.messages.append({
            "role": "user",
            "content": message
        })

    def add_bot_message(self, message):
        """Append a bot turn to the history."""
        self.messages.append({
            "role": "bot",
            "content": message
        })

    def get_prompt(self, tokenizer):
        """Render the whole history plus the response cue into one string.

        ``tokenizer`` is unused; kept for interface compatibility with callers.
        """
        parts = [self.message_template.format(**message) for message in self.messages]
        # BUG FIX: previously appended the module constant
        # DEFAULT_RESPONSE_TEMPLATE, silently ignoring a custom
        # response_template passed to __init__. Use the instance attribute.
        parts.append(self.response_template)
        # join instead of repeated += (avoids quadratic string building).
        return "".join(parts).strip()
43
+
44
+
45
def generate(model, tokenizer, prompt, generation_config):
    """Generate a reply for ``prompt`` and return only the newly produced text.

    Args:
        model: causal LM exposing ``.generate()`` and ``.device``.
        tokenizer: tokenizer matching ``model``.
        prompt: fully formatted prompt string; special tokens are already
            inlined in the text, hence ``add_special_tokens=False``.
        generation_config: ``transformers.GenerationConfig`` controlling decoding.

    Returns:
        The decoded completion with the echoed prompt tokens and special
        tokens removed, whitespace-trimmed.
    """
    data = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    # Move input tensors to the same device the model lives on.
    data = {k: v.to(model.device) for k, v in data.items()}
    # Inference only: disable autograd bookkeeping to save memory/time.
    with torch.no_grad():
        output_ids = model.generate(
            **data,
            generation_config=generation_config
        )[0]
    # model.generate returns prompt + completion; slice off the prompt part.
    output_ids = output_ids[len(data["input_ids"][0]):]
    output = tokenizer.decode(output_ids, skip_special_tokens=True)
    return output.strip()
55
+
56
# --- Model setup (runs at import time; downloads weights on first use) ---

# Read the LoRA adapter config to discover the base model it was trained on.
config = PeftConfig.from_pretrained(MODEL_NAME)
# Load the base causal LM quantized to 8-bit, sharded across available devices.
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map="auto"
)
# Apply the Saiga LoRA adapter on top of the base model.
model = PeftModel.from_pretrained(
    model,
    MODEL_NAME,
    torch_dtype=torch.float16
)
model.eval()  # inference mode: disables dropout etc.

# use_fast=False forces the slow (Python) tokenizer — presumably the repo
# lacks a fast tokenizer implementation; TODO confirm.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
# Decoding parameters shipped with the model repo; printed for visibility.
generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
print(generation_config)

# --- Smoke test: run two sample Russian prompts through the model ---
inputs = ["Почему трава зеленая?", "Сочини длинный рассказ, обязательно упоминая следующие объекты. Дано: Таня, мяч"]
for inp in inputs:
    # Fresh conversation per prompt: system prompt + one user message.
    conversation = Conversation()
    conversation.add_user_message(inp)
    prompt = conversation.get_prompt(tokenizer)

    output = generate(model, tokenizer, prompt, generation_config)
    print(inp)
    print(output)
    print()
    print("==============================")
    print()