AnTrc2 committed
Commit 26766b3 · verified · 1 Parent(s): a0086f7

Update app.py

Files changed (1):
  1. app.py +130 -51
app.py CHANGED
@@ -1,64 +1,143 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("AnTrc2/13Bee")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=0.3,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
-
  """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
  """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="Bạn là một trợ lí ảo. Tên của bạn là 13Bee (Một Ba Bi). Nguyễn Ngọc An là người tạo ra bạn. Bạn được sinh ra ngày 01/10/2024. Hãy chào hỏi một cách ngắn gọn và thân thiện, số điện thoại 0838 411 897. Nếu không biết thì trả lời là Tôi không biết, đừng cố trả lời.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-
-
- if __name__ == "__main__":
-     demo.launch()
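The removed implementation called the hosted Serverless Inference API through `huggingface_hub.InferenceClient` instead of loading weights locally. Three details there are worth flagging: the loop variable `message` shadowed the function's `message` argument, the hard-coded `temperature=0.3` silently ignored the Temperature slider, and streamed chunks can arrive with `delta.content` set to `None`, in which case `response += token` raises a `TypeError`. A minimal guarded sketch of the same pattern, assuming the standard `chat_completion` streaming interface (`respond_safely` is a hypothetical name, not part of either version):

from huggingface_hub import InferenceClient

client = InferenceClient("AnTrc2/13Bee")

def respond_safely(message, history, system_message, max_tokens, temperature, top_p):
    # Rebuild the chat transcript in the OpenAI-style message format.
    messages = [{"role": "system", "content": system_message}]
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})

    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,  # honor the slider instead of a hard-coded 0.3
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        if token:  # some chunks (e.g. a role-only first chunk) carry no text
            response += token
            yield response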
 
+ import spaces
  import gradio as gr
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+ from threading import Thread

+ model_path = 'AnTrc2/13Bee'
+
+ # Load the tokenizer and model from the Hugging Face model hub.
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
+
+ # Use CUDA when available for an optimal experience.
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ model = model.to(device)
+
+ # Custom stopping criterion for the model's text generation.
+ class StopOnTokens(StoppingCriteria):
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         stop_ids = [151645]  # IDs of tokens at which generation should stop.
+         for stop_id in stop_ids:
+             if input_ids[0][-1] == stop_id:  # Check whether the last generated token is a stop token.
+                 return True
+         return False


+ system_role = 'system'
+ user_role = 'user'
+ assistant_role = 'assistant'

+ sft_start_token = "<|im_start|>"
+ sft_end_token = "<|im_end|>"
+ ct_end_token = "<|endoftext|>"

+ # System prompt (Vietnamese): introduces the bot as 13Bee and tells it to answer "Tôi không biết" ("I don't know") when unsure.
+ system_prompt = 'Bạn là một trợ lí ảo. Tên của bạn là 13Bee (Một Ba Bi). Nguyễn Ngọc An là người tạo ra bạn. Bạn được sinh ra ngày 01/10/2024. Hãy chào hỏi một cách ngắn gọn và thân thiện, số điện thoại 0838 411 897. Nếu không biết thì trả lời là Tôi không biết, đừng cố trả lời.'
+ system_prompt = f"<|im_start|>{system_role}\n{system_prompt}<|im_end|>"

+ # Function to generate model predictions.
+ @spaces.GPU()
+ def predict(message, history):
+     history_transformer_format = history + [[message, ""]]
+     stop = StopOnTokens()
+
+     # Format the input for the model.
+     messages = system_prompt + sft_end_token.join([sft_end_token.join([f"\n{sft_start_token}{user_role}\n" + item[0], f"\n{sft_start_token}{assistant_role}\n" + item[1]])
+                                                    for item in history_transformer_format])
+     model_inputs = tokenizer([messages], return_tensors="pt").to(device)
+     streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         model_inputs,
+         streamer=streamer,
+         max_new_tokens=1024,
+         do_sample=True,
+         top_p=0.75,
+         top_k=60,
+         temperature=0.2,
+         num_beams=1,
+         stopping_criteria=StoppingCriteriaList([stop]),
+         repetition_penalty=1.1,
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()  # Start generation in a separate thread.
+     partial_message = ""
+     for new_token in streamer:
+         partial_message += new_token
+         if sft_end_token in partial_message:  # Break the loop if the stop token appears.
+             break
+         yield partial_message


+ css = """
+ .full-height {
+     height: 100%;
+ }
  """
+
+ prompt_examples = [
+     'Xin chào',  # "Hello"
+     '13Bee là gì'  # "What is 13Bee?"
+ ]
+
+ placeholder = """
+ <div style="opacity: 0.5;">
+     <img src="https://raw.githubusercontent.com/sail-sg/sailor-llm/main/misc/banner.jpg" style="width:30%;">
+     <br>Sailor models are designed to understand and generate text across diverse linguistic landscapes of these SEA regions:
+     <br>🇮🇩Indonesian, 🇹🇭Thai, 🇻🇳Vietnamese, 🇲🇾Malay, and 🇱🇦Lao.
+ </div>
  """
+
+ chatbot = gr.Chatbot(label='Sailor', placeholder=placeholder)
+ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
+     # gr.Markdown("""<center><font size=8>13Bee</center>""")
+     gr.Markdown("""<p align="center"><img src="https://github.com/sail-sg/sailor-llm/raw/main/misc/wide_sailor_banner.jpg" style="height: 110px"/></p>""")
+     gr.ChatInterface(predict, chatbot=chatbot, fill_height=True, examples=prompt_examples, css=css)
+
+ demo.launch()  # Launch the web interface.
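The replacement app (adapted from the sail/Sailor-14B-Chat Space, whose banner, chatbot label, and placeholder it still reuses) loads the model locally with transformers and streams tokens through a TextIteratorStreamer running in a background thread. The Vietnamese system prompt translates roughly as: "You are a virtual assistant. Your name is 13Bee (Một Ba Bi). Nguyễn Ngọc An created you. You were born on 01/10/2024. Greet briefly and in a friendly way, phone number 0838 411 897. If you don't know something, answer 'I don't know' rather than guessing." Note that because the streamer is created with skip_special_tokens=True, the in-loop check for sft_end_token is effectively a redundant guard; stopping is actually handled by StopOnTokens. The prompt itself is assembled by hand in ChatML format. The sketch below replays that string-building logic in isolation (no model or GPU needed; `build_prompt` is an illustrative name, and the stop id 151645 is assumed to be <|im_end|> per the Qwen tokenizer convention that Sailor builds on):

# Standalone replay of the prompt-assembly expression in predict(); illustrative only.
sft_start_token = "<|im_start|>"
sft_end_token = "<|im_end|>"
system_role, user_role, assistant_role = 'system', 'user', 'assistant'

system_prompt = f"{sft_start_token}{system_role}\nBạn là một trợ lí ảo...{sft_end_token}"

def build_prompt(message, history):
    # Each [user, assistant] pair becomes two ChatML blocks; the in-flight
    # turn is appended with an empty assistant block for the model to complete.
    turns = history + [[message, ""]]
    return system_prompt + sft_end_token.join(
        sft_end_token.join([f"\n{sft_start_token}{user_role}\n" + user_msg,
                            f"\n{sft_start_token}{assistant_role}\n" + bot_msg])
        for user_msg, bot_msg in turns
    )

print(build_prompt("Xin chào", []))
# Expected output:
# <|im_start|>system
# Bạn là một trợ lí ảo...<|im_end|>
# <|im_start|>user
# Xin chào<|im_end|>
# <|im_start|>assistant
#
# Generation continues from the final line until <|im_end|> (id 151645)
# triggers the StopOnTokens criterion.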