BytArch committed
Commit 9703cbd (verified)
1 Parent(s): f8a37fb

Create app.py

Files changed (1)
  1. app.py +136 -0
app.py ADDED
@@ -0,0 +1,136 @@
+ import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch
+
+ model_path = "BytArch/source-mini"
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+ model = AutoModelForCausalLM.from_pretrained(model_path)
+
+ if tokenizer.pad_token is None:
+     tokenizer.pad_token = tokenizer.eos_token
+
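+ # Prompt template used throughout: turns are encoded as
+ # <|start|>User:<|message|>...<|end|> and <|start|>Assistant:<|message|>...<|end|>.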
+ def generate_response(
+     prompt,
+     system_message,
+     conversation_history=None,
+     max_tokens=75,
+     temperature=0.78,
+     top_p=0.85,
+     repetition_penalty=1.031,
+     top_k=55,
+ ):
+     # Seed the context with the system message as a synthetic opening exchange,
+     # so it is present even on the very first turn.
+     context = (
+         f"<|start|>User:<|message|>{system_message}<|end|>\n"
+         f"<|start|>Assistant:<|message|>Hello, nice to meet you!<|end|>\n"
+     )
+     if conversation_history:
+         # Keep only the 30 most recent turns.
+         recent = conversation_history[-30:]
+         for message in recent:
+             if message["role"] == "user":
+                 context += f"<|start|>User:<|message|>{message['content']}<|end|>\n"
+             else:
+                 context += f"<|start|>Assistant:<|message|>{message['content']}<|end|>\n"
+
+     formatted_input = (
+         f"{context}<|start|>User:<|message|>{prompt}<|end|>\n<|start|>Assistant:<|message|>"
+     )
+
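+     # Tokenize the assembled prompt; inputs longer than 512 tokens are truncated.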
+     inputs = tokenizer(
+         formatted_input,
+         return_tensors="pt",
+         padding=True,
+         truncation=True,
+         max_length=512,
+     )
+
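+     # Sample with the caller's decoding settings; generation stops at the
+     # first <|end|> token (assumed here to encode as a single token id).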
+     with torch.no_grad():
+         outputs = model.generate(
+             inputs.input_ids,
+             attention_mask=inputs.attention_mask,
+             max_new_tokens=max_tokens,
+             temperature=temperature,
+             top_p=top_p,
+             top_k=top_k,
+             do_sample=True,
+             pad_token_id=tokenizer.pad_token_id,
+             repetition_penalty=repetition_penalty,
+             eos_token_id=tokenizer.encode("<|end|>", add_special_tokens=False)[0],
+         )
+
+     # Decode only the newly generated tokens, not the echoed prompt.
+     new_tokens = outputs[0][inputs.input_ids.shape[-1]:]
+     response = tokenizer.decode(new_tokens, skip_special_tokens=False)
+
+     return response.strip()
+
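+ # Gradio callback: with type="messages", `history` arrives as a list of
+ # {"role": ..., "content": ...} dicts, matching what generate_response expects.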
+ def respond(
+     message,
+     history: list[dict[str, str]],
+     system_message,
+     max_tokens,
+     temperature,
+     top_p,
+     repetition_penalty,
+     top_k,
+ ):
+     response = generate_response(
+         message,
+         system_message,
+         history,
+         max_tokens=max_tokens,
+         temperature=temperature,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty,
+         top_k=top_k,
+     )
+
+     if "<|end|>" in response:
+         response = response.split("<|end|>")[0]
+
+     return response.strip()
+
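+ # UI: the extra controls map one-to-one onto generate_response's sampling knobs.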
+ chatbot = gr.ChatInterface(
+     respond,
+     type="messages",
+     title="Chat with source-mini",
+     description="Chat with BytArch/source-mini",
+     additional_inputs=[
+         gr.Textbox(
+             value="You are source-mini, a helpful medical/nursing assistant chatbot.",
+             label="System message",
+         ),
+         gr.Slider(minimum=10, maximum=150, value=75, step=5, label="Max new tokens"),
+         gr.Slider(minimum=0.01, maximum=1.2, value=0.7, step=0.01, label="Temperature"),
+         gr.Slider(
+             minimum=0.01,
+             maximum=1.0,
+             value=0.85,
+             step=0.01,
+             label="Top-p (nucleus sampling)",
+         ),
+         gr.Slider(
+             minimum=1.0,
+             maximum=1.5,
+             value=1.031,
+             step=0.001,
+             label="Repetition penalty",
+         ),
+         gr.Slider(
+             minimum=1,
+             maximum=100,
+             value=55,
+             step=1,
+             label="Top-k sampling",
+         ),
+     ],
+ )
+
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     chatbot.render()
+
+ if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)
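
For a quick smoke test outside the Gradio UI, the generator can be exercised directly after running app.py. A minimal sketch: the question and history below are made-up examples, and it assumes the model weights load successfully.

# Hypothetical smoke test: call generate_response directly, bypassing Gradio.
history = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello, nice to meet you!"},
]
reply = generate_response(
    "What are common early signs of dehydration?",  # made-up example prompt
    "You are source-mini, a helpful medical/nursing assistant chatbot.",
    conversation_history=history,
)
print(reply.split("<|end|>")[0].strip())  # trim at the stop marker, as respond() does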