import os
import time
from threading import Thread
from typing import Iterator

import gradio as gr
from langfuse import Langfuse
from langfuse.decorators import observe
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

DESCRIPTION = """\
# Dorna-Llama3-8B-Instruct Chat
"""

PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
    <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">Test Dorna-Llama3-8B-Instruct</h1>
</div>
"""

# Vazirmatn covers Persian script; code blocks are forced LTR so snippets
# inside RTL chat messages stay readable.
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Vazirmatn&display=swap');

body, .gradio-container, .gr-button, .gr-input, .gr-slider, .gr-dropdown, .gr-markdown {
    font-family: 'Vazirmatn', sans-serif !important;
}

._button {
    font-size: 20px;
}

pre, code {
    direction: ltr !important;
    unicode-bidi: plaintext !important;
}
"""

# A missing variable stays None here; wrapping os.getenv() in str() would
# turn it into the truthy string "None" and send it as a real value.
system_prompt = os.getenv("SYSTEM_PROMPT")

secret_key = os.getenv("LANGFUSE_SECRET_KEY")
public_key = os.getenv("LANGFUSE_PUBLIC_KEY")
host = os.getenv("LANGFUSE_HOST")

langfuse = Langfuse(
    secret_key=secret_key,
    public_key=public_key,
    host=host,
)
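
# The @observe() decorator used later in this file configures its own client
# from the same LANGFUSE_* environment variables; the explicit client above is
# only needed for direct Langfuse API calls.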


def execution_time_calculator(start_time, log=True):
    delta = time.time() - start_time
    if log:
        print(f"--- {delta} seconds ---")
    return delta


def token_per_second_calculator(tokens_count, time_delta):
    return tokens_count / time_delta
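
# Example: token_per_second_calculator(256, 8.0) -> 32.0 tokens per second.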


if not torch.cuda.is_available():
    # Append rather than overwrite, so the title in DESCRIPTION is kept.
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

if torch.cuda.is_available():
    model_id = "PartAI/Dorna-Llama3-8B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

# Last measured generation speed in tokens/second, updated after each request.
generation_speed = 0


def get_generation_speed():
    return generation_speed
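
# log_to_langfuse exists mainly for its @observe() side effect: Langfuse
# records the decorated function's arguments and return value, so each call
# attaches the sampling settings and measured tokens/second to a trace.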
@observe()
def log_to_langfuse(message, chat_history, max_new_tokens, temperature, top_p, top_k, repetition_penalty, do_sample, generation_speed, model_outputs):
    print(f"generation_speed: {generation_speed}")
    return "".join(model_outputs)


@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    do_sample: bool = True,
) -> Iterator[str]:
    global generation_speed

    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    # add_generation_prompt=True appends the assistant header so the model
    # answers the last user turn instead of continuing it.
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens so the latest turns survive trimming.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
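
    # model.generate() blocks until generation finishes, so it runs on a
    # background thread while this function consumes the streamer and yields
    # partial text to Gradio.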
    start_time = time.time()
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    sum_tokens = 0
    for text in streamer:
        # Count the tokens in each streamed chunk for the tokens/second figure.
        sum_tokens += len(tokenizer.tokenize(text))
        outputs.append(text)
        yield "".join(outputs)

    time_delta = execution_time_calculator(start_time, log=False)
    generation_speed = token_per_second_calculator(sum_tokens, time_delta)

    # Called for its @observe() side effect; the returned string is not needed.
    log_to_langfuse(message, chat_history, max_new_tokens, temperature, top_p, top_k, repetition_penalty, do_sample, generation_speed, outputs)


chatbot = gr.Chatbot(placeholder=PLACEHOLDER, scale=1, show_copy_button=True, height="5%", rtl=True)
chat_input = gr.Textbox(show_label=False, lines=2, rtl=True, placeholder="ورودی", show_copy_button=True, scale=4)  # placeholder: "Input"
submit_btn = gr.Button(variant="primary", value="ارسال", size="sm", scale=1, elem_classes=["_button"])  # label: "Send"


chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs_accordion=gr.Accordion(label="ورودی‌های اضافی", open=False),  # "Additional inputs"
    additional_inputs=[
        gr.Slider(
            label="حداکثر تعداد توکن‌ها",  # "Maximum number of new tokens"
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.01,
            maximum=4.0,
            step=0.01,
            value=0.5,
        ),
        gr.Slider(
            label="Top-p",
            minimum=0.05,
            maximum=1.0,
            step=0.01,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=20,
        ),
        gr.Slider(
            label="جریمه تکرار",  # "Repetition penalty"
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
        gr.Dropdown(
            label="نمونه‌گیری",  # "Sampling" (toggles do_sample)
            choices=[False, True],
            value=True,
        ),
    ],
    stop_btn="توقف",  # "Stop"
    chatbot=chatbot,
    textbox=chat_input,
    submit_btn=submit_btn,
    retry_btn="🔄 تلاش مجدد",  # "Retry"
    undo_btn="↩️ بازگشت",  # "Undo"
    clear_btn="🗑️ پاک کردن",  # "Clear"
    title="تست llama3",  # "llama3 test"
)
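
# Note: retry_btn, undo_btn, and clear_btn are Gradio 4.x ChatInterface
# parameters; Gradio 5 removed them, so pin gradio<5 to keep these kwargs.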


with gr.Blocks(css=custom_css, fill_height=False) as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()


if __name__ == "__main__":
    # max_size=20 caps the request queue; further requests are turned away
    # until the queue drains.
    demo.queue(max_size=20).launch()