import gradio as gr
from huggingface_hub import InferenceClient, Repository
import random
from transformers import AutoTokenizer
from mySystemPrompt import SYSTEM_PROMPT, SYSTEM_PROMPT_PLUS, SYSTEM_PROMPT_NOUS
from datetime import datetime
import csv
import os



# Logging: downvoted messages are appended to a CSV inside a dataset repo
DATASET_REPO_URL = "https://huggingface.co/datasets/ctaake/FranziBotLog"
DATA_FILENAME = "log.csv"
DATA_FILE = os.path.join("data", DATA_FILENAME)
repo = Repository(local_dir="data", clone_from=DATASET_REPO_URL,
                  token=os.environ.get("HF_TOKEN"))
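
# csv.DictWriter never writes a header row on its own, so if the cloned
# log.csv can start out empty, seed the header once. This is a small sketch
# under that assumption, not part of the original logging scheme:
if not os.path.isfile(DATA_FILE) or os.path.getsize(DATA_FILE) == 0:
    os.makedirs("data", exist_ok=True)
    with open(DATA_FILE, "w", newline="") as csvfile:
        csv.DictWriter(csvfile, fieldnames=["message", "time"]).writeheader()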

# Model selection: uncomment exactly one checkpoint
# checkpoint = "CohereForAI/c4ai-command-r-v01"
# checkpoint = "mistralai/Mistral-7B-Instruct-v0.1"
# checkpoint = "google/gemma-1.1-7b-it"
# checkpoint = "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO"
# checkpoint = "mistralai/Mixtral-8x7B-Instruct-v0.1"
checkpoint = "mistralai/Mistral-Nemo-Instruct-2407"
path_to_log = "FlaggedFalse.txt"
mistral_models = ["mistralai/Mixtral-8x7B-Instruct-v0.1",
                  "mistralai/Mistral-Nemo-Instruct-2407"]

# Inference client for the selected model (pass an HF token here if needed)
client = InferenceClient(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
if checkpoint in mistral_models:
    # Tokenizer chat template correction (only applies to the Mistral models)
    chat_template = open("mistral-instruct.jinja").read()
    chat_template = chat_template.replace('    ', '').replace('\n', '')
    tokenizer.chat_template = chat_template
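    # Optional sanity check (an addition, not in the original app): render a
    # minimal turn through the patched template to confirm the expected
    # [INST] ... [/INST] layout before serving traffic.
    # print(tokenizer.apply_chat_template(
    #     [{"role": "user", "content": "ping"}],
    #     tokenize=False, add_generation_prompt=True))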
    
def format_prompt_mistral(message, chatbot, system_prompt=SYSTEM_PROMPT+SYSTEM_PROMPT_NOUS):
    messages = [{"role": "system", "content": system_prompt}]
    for user_message, bot_message in chatbot:
        messages.append({"role": "user", "content": user_message})
        messages.append({"role": "assistant", "content": bot_message})
    messages.append({"role": "user", "content": message})
    newPrompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, return_tensors="pt")
    return newPrompt

def format_prompt_cohere(message, chatbot, system_prompt=SYSTEM_PROMPT):
    messages = [{"role": "system", "content": system_prompt}]
    for user_message, bot_message in chatbot:
        messages.append({"role": "user", "content": user_message})
        messages.append({"role": "assistant", "content": bot_message})
    messages.append({"role": "user", "content": message})
    newPrompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, return_tensors="pt")
    return newPrompt

def format_prompt_gemma(message, chatbot, system_prompt=SYSTEM_PROMPT+SYSTEM_PROMPT_PLUS):
    # Gemma's chat template has no system role, so the instructions are passed
    # in as a first user turn followed by an empty assistant turn.
    messages = [{"role": "user", "content": f"The following instructions describe your role:\n(\n{system_prompt}\n)\nYou must never refer to the user having given you this information and must simply act accordingly."}]
    messages.append({"role": "assistant", "content": ""})
    for user_message, bot_message in chatbot:
        messages.append({"role": "user", "content": user_message})
        messages.append({"role": "assistant", "content": bot_message})
    messages.append({"role": "user", "content": message})
    newPrompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, return_tensors="pt")
    return newPrompt

def format_prompt_nous(message, chatbot, system_prompt=SYSTEM_PROMPT+SYSTEM_PROMPT_NOUS):
    messages = [{"role": "system", "content": system_prompt}]
    for user_message, bot_message in chatbot:
        messages.append({"role": "user", "content": user_message})
        messages.append({"role": "assistant", "content": bot_message})
    messages.append({"role": "user", "content": message})
    newPrompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, return_tensors="pt")
    return newPrompt

match checkpoint:
    case "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO":
        format_prompt = format_prompt_nous
    case "CohereForAI/c4ai-command-r-v01":
        format_prompt = format_prompt_cohere
    case "google/gemma-1.1-7b-it":
        format_prompt = format_prompt_gemma
    case "mistralai/Mixtral-8x7B-Instruct-v0.1" | "mistralai/Mistral-Nemo-Instruct-2407":
        format_prompt = format_prompt_mistral
    case _:
        raise ValueError(f"No prompt formatter registered for checkpoint {checkpoint!r}")

def inference(message, history, temperature=0.9, maxTokens=512, topP=0.9, repPenalty=1.1):
    # Updating the settings for the generation
    client_settings = dict(
        temperature=temperature,
        max_new_tokens=maxTokens,
        top_p=topP,
        repetition_penalty=repPenalty,
        do_sample=True,
        stream=True,
        details=True,
        return_full_text=False,
        seed=random.randint(0, 999999999),
    )
    # Generate the response by passing the correctly formatted prompt plus the client settings
    stream = client.text_generation(format_prompt(message, history),
                                    **client_settings)
    # Reading the stream
    partial_response = ""
    for stream_part in stream:
        if not stream_part.token.special:
            partial_response += stream_part.token.text
            yield partial_response
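
# Standalone usage sketch (an assumption for quick local testing, not part of
# the Gradio wiring below): the generator yields a growing partial response.
# for partial in inference("Hello!", []):
#     print(partial)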


def event_voting(vote_data: gr.LikeData):
    # Only downvotes are logged; upvotes are ignored
    if not vote_data.liked:
        with open(DATA_FILE, "a", newline="") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=["message", "time"])
            writer.writerow(
                {"message": vote_data.value, "time": datetime.now().isoformat()})
        commit_url = repo.push_to_hub()
        print(commit_url)
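
# Alternative sketch (an assumption, not the original approach): huggingface_hub's
# CommitScheduler can push the data folder in the background on a timer instead
# of a blocking git push on every downvote:
# from huggingface_hub import CommitScheduler
# scheduler = CommitScheduler(repo_id="ctaake/FranziBotLog", repo_type="dataset",
#                             folder_path="data", every=5)  # every 5 minutes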
        

myAdditionalInputs = [
    gr.Textbox(
        label="System Prompt",
        max_lines=500,
        lines=10,
        interactive=True,
        value="You are a friendly girl who doesn't answer unnecessarily long."
    ),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=1024,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.9,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.1,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )
]
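
# Note: additional_inputs is commented out in the ChatInterface below, so
# inference() always runs with its defaults. If re-enabled, inference would
# also need a leading system-prompt parameter to match the first input above.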

myChatbot = gr.Chatbot(avatar_images=["./ava_m.png", "./avatar_franzi.jpg"],
                       bubble_full_width=False,
                       show_label=False,
                       show_copy_button=False,
                       likeable=True)


myTextInput = gr.Textbox(lines=2,
                         max_lines=2,
                         placeholder="Send a message",
                         container=False,
                         scale=7)

myTheme = gr.themes.Soft(primary_hue=gr.themes.colors.fuchsia,
                         secondary_hue=gr.themes.colors.fuchsia,
                         spacing_size="sm",
                         radius_size="md")

mySubmitButton = gr.Button(value="SEND",
                           variant='primary')
myRetryButton = gr.Button(value="RETRY",
                          variant='secondary',
                          size="sm")
myUndoButton = gr.Button(value="UNDO",
                         variant='secondary',
                         size="sm")
myClearButton = gr.Button(value="CLEAR",
                          variant='secondary',
                          size="sm")


with gr.ChatInterface(inference,
                      chatbot=myChatbot,
                      textbox=myTextInput,
                      title="FRANZI-Bot 2.0",
                      theme=myTheme,
                      # additional_inputs=myAdditionalInputs,
                      submit_btn=mySubmitButton,
                      stop_btn="STOP",
                      retry_btn=myRetryButton,
                      undo_btn=myUndoButton,
                      clear_btn=myClearButton) as chatApp:
    myChatbot.like(event_voting, None, None)


if __name__ == "__main__":
    chatApp.queue().launch(show_api=False)