| |
|
| | |
| | import datetime |
| | import os |
| | from threading import Event, Thread |
| | from uuid import uuid4 |
| |
|
| | import gradio as gr |
| | import requests |
| | import torch |
| | from transformers import ( |
| | AutoModelForCausalLM, |
| | AutoTokenizer, |
| | StoppingCriteria, |
| | StoppingCriteriaList, |
| | TextIteratorStreamer, |
| | ) |
| |
|
| |
|
# --- Configuration -----------------------------------------------------------
# Hugging Face repo to load.  MPT-7B-chat uses ChatML-style <|im_start|> /
# <|im_end|> turn markers (see start_message below).
model_name = "mosaicml/mpt-7b-chat"
# Upper bound on tokens generated per reply.  NOTE(review): 7000 new tokens on
# top of a long prompt can approach max_seq_len=10000 below — confirm headroom.
max_new_tokens = 7000

# Optional Hugging Face access token for gated/private repos (None = anonymous).
auth_token = os.getenv("HF_TOKEN", None)

print(f"Starting to load the model {model_name} into memory")

# trust_remote_code is required because MPT ships custom modeling code in the
# repo; max_seq_len overrides the model's default context window.
m = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    use_auth_token=auth_token,
    max_seq_len=10000,
).cuda()  # requires a CUDA-capable GPU; fails hard on CPU-only hosts
tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_auth_token=auth_token)

# Token ids that should terminate generation (ChatML end-of-turn, end-of-text).
stop_token_ids = tok.convert_tokens_to_ids(["<|im_end|>", "<|endoftext|>"])

print(f"Successfully loaded the model {model_name} into memory")

# System prompt prepended to every conversation as a ChatML "system" turn.
start_message = """<|im_start|>system
- You are a helpful assistant chatbot trained by Debasish.
- You answer questions.
- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>
"""
| |
|
| |
|
class StopOnTokens(StoppingCriteria):
    """Stopping criterion: halt generation once the newest token is a stop id."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only the most recently generated token (batch 0) needs checking.
        last_token = input_ids[0][-1]
        return any(last_token == stop_id for stop_id in stop_token_ids)
| |
|
| |
|
def convert_history_to_text(history):
    """Render the chat history into the ChatML prompt string the model expects.

    Parameters
    ----------
    history : list of [user_message, assistant_message] pairs; the last pair's
        assistant slot may be partial or empty (it is being generated).

    Returns
    -------
    str
        The system prompt followed by each turn wrapped in <|im_start|> /
        <|im_end|> markers.  The final assistant turn is intentionally left
        unterminated so the model continues it.
    """
    # Guard: the original indexed history[-1] unconditionally and raised
    # IndexError when called with no turns; now we just return the system prompt.
    if not history:
        return start_message

    text = start_message
    # All completed turns get fully closed user/assistant blocks.
    for user_msg, assistant_msg in history[:-1]:
        text += f"<|im_start|>user\n{user_msg}<|im_end|>"
        text += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>"
    # The final assistant block is left open for the model to complete.
    last_user, last_assistant = history[-1]
    text += f"<|im_start|>user\n{last_user}<|im_end|>"
    text += f"<|im_start|>assistant\n{last_assistant}"
    return text
| |
|
| |
|
def log_conversation(conversation_id, history, messages, generate_kwargs):
    """Best-effort POST of a finished conversation to an external endpoint.

    Reads the endpoint URL from the LOGGING_URL environment variable and is a
    silent no-op when it is unset.  Network failures are printed, never raised,
    so logging can never break the chat flow.

    Parameters
    ----------
    conversation_id : str  — per-session UUID.
    history : list        — [user, assistant] message pairs.
    messages : str        — the rendered ChatML prompt that was sent.
    generate_kwargs : dict — sampling parameters used for this reply.
    """
    logging_url = os.getenv("LOGGING_URL", None)
    if logging_url is None:
        return

    timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")

    data = {
        "conversation_id": conversation_id,
        "timestamp": timestamp,
        "history": history,
        "messages": messages,
        "generate_kwargs": generate_kwargs,
    }

    try:
        # Fix: the original had no timeout, so a hung logging endpoint would
        # block the logging thread forever.  5s is generous for a fire-and-forget log.
        requests.post(logging_url, json=data, timeout=5)
    except requests.exceptions.RequestException as e:
        print(f"Error logging conversation: {e}")
| |
|
| |
|
def user(message, history):
    """Gradio callback: record the user's message and clear the input box.

    Returns a pair: an empty string (resets the Textbox) and a new history
    list with [message, ""] appended — the empty assistant slot is filled in
    later by `bot`.  The caller's history list is not mutated.
    """
    updated_history = list(history)
    updated_history.append([message, ""])
    return "", updated_history
| |
|
| |
|
def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
    """Stream the assistant's reply for the last turn in `history`.

    Generator used as a Gradio callback: yields the updated `history` after
    every decoded chunk so the Chatbot widget renders incrementally.
    Generation runs in a background thread; a second thread logs the
    conversation once the stream completes.
    """
    print(f"history: {history}")

    # Stop generation when a ChatML end-of-turn / end-of-text token appears.
    stop = StopOnTokens()

    # Flatten the chat history into a single ChatML prompt string.
    messages = convert_history_to_text(history)

    # Tokenize the prompt and move it onto the model's device.
    input_ids = tok(messages, return_tensors="pt").input_ids
    input_ids = input_ids.to(m.device)
    # The streamer yields decoded text pieces as tokens arrive; the timeout
    # keeps this loop from hanging forever if the generation thread stalls.
    streamer = TextIteratorStreamer(tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=temperature > 0.0,  # temperature == 0 falls back to greedy decoding
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
        stopping_criteria=StoppingCriteriaList([stop]),
    )

    # Set once generation finishes so logging happens strictly afterwards.
    stream_complete = Event()

    def generate_and_signal_complete():
        # Worker thread: run generation (feeds the streamer), then signal done.
        m.generate(**generate_kwargs)
        stream_complete.set()

    def log_after_stream_complete():
        # Worker thread: wait for generation, then fire-and-forget the log.
        stream_complete.wait()
        log_conversation(
            conversation_id,
            history,
            messages,
            {
                "top_k": top_k,
                "top_p": top_p,
                "temperature": temperature,
                "repetition_penalty": repetition_penalty,
            },
        )

    t1 = Thread(target=generate_and_signal_complete)
    t1.start()

    t2 = Thread(target=log_after_stream_complete)
    t2.start()

    # Accumulate streamed text into the last assistant slot and re-yield the
    # whole history so Gradio updates the chat window as tokens arrive.
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        history[-1][1] = partial_text
        yield history
| |
|
| |
|
def get_uuid():
    """Return a fresh random UUID4 string, used as a per-session conversation id."""
    fresh = uuid4()
    return f"{fresh}"
| |
|
| |
|
# --- Gradio UI ---------------------------------------------------------------
# Builds the chat interface: message box, submit/stop/clear buttons, and an
# "Advanced Options" accordion exposing the sampling parameters passed to bot().
with gr.Blocks(
    theme=gr.themes.Soft(),
    css=".disclaimer {font-variant-caps: all-small-caps;}",
) as demo:
    # Per-browser-session state: a fresh UUID identifying this conversation.
    conversation_id = gr.State(get_uuid)
    gr.Markdown(
        """<h1><center>Deb-Ona-Chat</center></h1>

    This demo is of Ona-7B. It is based on various open source models fine-tuned with approximately [171,000 conversation samples from this dataset](https://huggingface.co/datasets/sam-mosaic/vicuna_alpaca_hc3_chatml) and another [217,000] samples from another model.
"""
    )
    chatbot = gr.Chatbot().style(height=500)
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Chat Message Box",
                show_label=False,
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Submit")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear")
    # Sampling controls — defaults mirror conservative chat settings.
    with gr.Row():
        with gr.Accordion("Advanced Options:", open=False):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        temperature = gr.Slider(
                            label="Temperature",
                            value=0.1,
                            minimum=0.0,
                            maximum=1.0,
                            step=0.1,
                            interactive=True,
                            info="Higher values produce more diverse outputs",
                        )
                with gr.Column():
                    with gr.Row():
                        top_p = gr.Slider(
                            label="Top-p (nucleus sampling)",
                            value=1.0,
                            minimum=0.0,
                            maximum=1,
                            step=0.01,
                            interactive=True,
                            info=(
                                "Sample from the smallest possible set of tokens whose cumulative probability "
                                "exceeds top_p. Set to 1 to disable and sample from all tokens."
                            ),
                        )
                with gr.Column():
                    with gr.Row():
                        top_k = gr.Slider(
                            label="Top-k",
                            value=0,
                            minimum=0.0,
                            maximum=200,
                            step=1,
                            interactive=True,
                            info="Sample from a shortlist of top-k tokens β 0 to disable and sample from all tokens.",
                        )
                with gr.Column():
                    with gr.Row():
                        repetition_penalty = gr.Slider(
                            label="Repetition Penalty",
                            value=1.1,
                            minimum=1.0,
                            maximum=2.0,
                            step=0.1,
                            interactive=True,
                            info="Penalize repetition β 1.0 to disable.",
                        )
    with gr.Row():
        gr.Markdown(
            "Disclaimer: Deb-Ona can produce factually incorrect output, and should not be relied on to produce "
            "factually accurate information. Ona-7B was trained on various public datasets; while great efforts "
            "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
            "biased, or otherwise offensive outputs.",
            elem_classes=["disclaimer"],
        )
    with gr.Row():
        gr.Markdown(
            "Test Chat",
            elem_classes=["disclaimer"],
        )

    # Enter in the textbox: append the user turn (fast, unqueued), then stream
    # the bot reply (queued so long generations don't block other sessions).
    submit_event = msg.submit(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=[
            chatbot,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            conversation_id,
        ],
        outputs=chatbot,
        queue=True,
    )
    # Submit button: identical chain to pressing Enter.
    submit_click_event = submit.click(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=[
            chatbot,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            conversation_id,
        ],
        outputs=chatbot,
        queue=True,
    )
    # Stop cancels any in-flight generation from either trigger chain.
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )
    # Clear wipes the chat window (history component set back to None).
    clear.click(lambda: None, None, chatbot, queue=False)

# Queue bounds concurrent requests; two generations may run at once.
demo.queue(max_size=128, concurrency_count=2)
demo.launch()
| |
|