File size: 4,132 Bytes
26766b3
ef57ade
26766b3
 
 
ef57ade
26766b3
 
 
 
91b7f66
26766b3
 
 
 
 
 
 
 
 
 
 
 
 
ef57ade
 
26766b3
 
 
ef57ade
26766b3
 
 
ef57ade
26766b3
 
ef57ade
26766b3
ef57ade
26766b3
 
 
 
 
ef57ade
26766b3
 
 
 
 
 
 
 
 
 
 
 
 
 
7e239f8
 
26766b3
 
 
 
 
 
 
 
 
 
 
ef57ade
 
26766b3
 
 
 
ef57ade
26766b3
 
 
 
 
 
 
 
 
 
 
 
ef57ade
26766b3
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import spaces
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread

# Hugging Face Hub repository that provides both the tokenizer and the weights.
model_path = 'AnTrc2/13Bee'

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# trust_remote_code=True permits code shipped in the model repository to run.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Load the weights as bfloat16 and move the model onto the selected device.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    ignore_mismatched_sizes=True,
    torch_dtype=torch.bfloat16,
).to(device)

# Defining a custom stopping criteria class for the model's text generation.
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that ends generation once a stop token is produced."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the newest token of the first sequence is a stop id.

        Only ``input_ids[0][-1]`` is inspected; ``scores`` and ``kwargs`` are
        accepted but not used.
        """
        stop_ids = [151645]  # Token ids at which generation should stop.
        newest_token = input_ids[0][-1]
        return any(newest_token == stop_id for stop_id in stop_ids)


# Speaker labels written right after the start-of-turn marker.
system_role = 'system'
user_role = 'user'
assistant_role = 'assistant'

# Marker strings used when assembling the chat prompt.
sft_start_token = "<|im_start|>"
sft_end_token = "<|im_end|>"
ct_end_token = "<|endoftext|>"

# Complete system turn that opens every prompt: start marker, system role,
# a newline, the Vietnamese instructions, then the end marker.
system_prompt = (
    f"{sft_start_token}{system_role}\n"
    'Bạn là một trợ lí ảo. Tên của bạn là 13Bee (Một Ba Bi). Nguyễn Ngọc An là người tạo ra bạn. Bạn được sinh ra ngày 01/10/2024. Hãy chào hỏi một cách ngắn gọn và thân thiện, số điện thoại 0838 411 897. Nếu không biết thì trả lời là Tôi không biết, đừng cố trả lời.'
    f"{sft_end_token}"
)

# Function to generate model predictions.

@spaces.GPU()
def predict(message, history):
    """Stream the model's reply to ``message`` given the earlier ``history``.

    Args:
        message: Text of the newest user turn.
        history: Earlier turns; each item is read as ``item[0]`` (user text)
            and ``item[1]`` (assistant text).

    Yields:
        The reply accumulated so far, once per chunk produced by the streamer.
    """
    # The new turn gets an empty assistant slot, so the prompt ends with an
    # open assistant header for the model to continue.
    conversation = history + [[message, ""]]

    # Render each turn as "<user header + text><end marker><assistant header + text>".
    rendered_turns = []
    for turn in conversation:
        user_part = f"\n{sft_start_token}{user_role}\n" + turn[0]
        assistant_part = f"\n{sft_start_token}{assistant_role}\n" + turn[1]
        rendered_turns.append(user_part + sft_end_token + assistant_part)
    # system_prompt already carries its own end marker; turns are separated by one.
    prompt = system_prompt + sft_end_token.join(rendered_turns)

    model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)

    # Tokenized inputs first, then the sampling and stopping configuration.
    generation_args = dict(model_inputs)
    generation_args.update(
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p= 0.75,
        top_k= 60,
        temperature=0.2,
        num_beams=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
        repetition_penalty=1.1,
    )

    # Generation runs in a background thread while this generator drains the streamer.
    worker = Thread(target=model.generate, kwargs=generation_args)
    worker.start()

    reply = ""
    for chunk in streamer:
        reply += chunk
        # Stop without yielding if the end marker shows up in the decoded text.
        # NOTE(review): the streamer is built with skip_special_tokens=True, so
        # this marker may never appear in the text - confirm.
        if sft_end_token in reply:
            return
        yield reply


# Stylesheet text handed to gr.ChatInterface through its `css` argument.
# NOTE(review): `full-height` has no leading "." or "#", so as written it is a
# type (element) selector - confirm whether a class selector was intended.
css = """
full-height {
    height: 100%;
}
"""

# Example prompts (Vietnamese) passed to gr.ChatInterface as `examples`.
prompt_examples = [
    'Xin chào',
    '13Bee là gì'
]

# HTML passed to gr.Chatbot as its `placeholder`.
# NOTE(review): the banner image and text describe the Sailor models, while the
# model loaded by this file is 'AnTrc2/13Bee' - confirm this branding is intended.
placeholder = """
<div style="opacity: 0.5;">
    <img src="https://raw.githubusercontent.com/sail-sg/sailor-llm/main/misc/banner.jpg" style="width:30%;">
    <br>Sailor models are designed to understand and generate text across diverse linguistic landscapes of these SEA regions:
    <br>🇮🇩Indonesian, 🇹🇭Thai, 🇻🇳Vietnamese, 🇲🇾Malay, and 🇱🇦Lao.
</div>
"""

# The Chatbot component is created up front so its label and placeholder can be
# set; gr.ChatInterface receives it through `chatbot=` below.
chatbot = gr.Chatbot(label='Sailor', placeholder=placeholder)

with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
    # Banner image inside a centered, properly closed paragraph.
    gr.Markdown("""<p align="center"><img src="https://github.com/sail-sg/sailor-llm/raw/main/misc/wide_sailor_banner.jpg" style="height: 110px"/></p>""")
    # Chat UI wired to the streaming `predict` generator.
    gr.ChatInterface(predict, chatbot=chatbot, fill_height=True, examples=prompt_examples, css=css)

# Launch only after the Blocks context has exited, so the layout is fully
# assembled before the web interface starts serving.
demo.launch()