# Hugging Face Hub repository id of the chat model loaded below.
model_name = "berkeley-nest/Starling-LM-7B-alpha"

# Markdown strings rendered at the top of the Gradio UI (see gr.Markdown below).
title = """# 👋🏻Welcome to Tonic's 💫🌠Starling 7B"""
description = """You can use [💫🌠Starling 7B](https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha) or duplicate it for local use or on Hugging Face! [Join me on Discord to build together](https://discord.gg/VqTxc76K3u)."""
| |
|
| | import transformers |
| | from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM |
| | import torch |
| | import gradio as gr |
| | import json |
| | import os |
| | import shutil |
| | import requests |
| | import accelerate |
| | import bitsandbytes |
| | import gc |
| |
|
# Run on GPU when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Special-token ids for the Starling/OpenChat tokenizer.
# BUG FIX: the original read `bos_token_id = 1,` — the trailing comma made it
# a 1-tuple, which would later be passed as an invalid value to model.generate().
bos_token_id = 1
eos_token_id = 32000
pad_token_id = 32001

# Module-level default sampling hyper-parameters (the UI sliders supply their
# own values per request; these remain for backward compatibility).
temperature = 0.4
max_new_tokens = 240
top_p = 0.92
repetition_penalty = 1.7

# Cap the CUDA caching-allocator split size to reduce fragmentation on small
# GPUs. BUG FIX: set this BEFORE the model is loaded — the allocator reads the
# variable at first CUDA allocation, so setting it afterwards had no effect.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:50'

# Load tokenizer and model; device_map="auto" lets accelerate place the weights.
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
model.eval()  # inference only: disables dropout and other training-time behavior
| |
|
class StarlingBot:
    """Thin wrapper around the module-level Starling model/tokenizer that
    builds the chat prompt and runs a single generation step."""

    def __init__(self, assistant_message="I am Starling-7B by Tonic-AI, I am ready to do anything to help my user."):
        # Default system/assistant message. NOTE(review): predict() takes its
        # own assistant_message argument and never reads this attribute —
        # confirm whether it is still needed.
        self.assistant_message = assistant_message

    def predict(self, user_message, assistant_message, mode, do_sample, temperature=0.4, max_new_tokens=700, top_p=0.99, repetition_penalty=1.9):
        """Generate a reply to *user_message* and return the decoded text.

        Parameters mirror the Gradio controls: *mode* selects the prompt
        template ("Assistant" vs coder), *do_sample* toggles sampling, and the
        remaining arguments are passed straight to model.generate().

        BUG FIXES vs the original:
        * do_sample was accepted but ignored (generate() hard-coded
          do_sample=True); it is now honored.
        * the finally clause unconditionally ran `del input_ids`, which raised
          NameError (masking the real exception) if tokenization itself failed.
        * "Code User::" double-colon typo in the coder prompt template.
        * torch.cuda.empty_cache() is now only called when CUDA is available.
        """
        input_ids = None
        try:
            if mode == "Assistant":
                conversation = f"GPT4 Correct Assistant: {assistant_message if assistant_message else ''} GPT4 Correct User: {user_message} GPT4 Correct Assistant:"
            else:
                conversation = f"Code Assistant: {assistant_message if assistant_message else ''} Code User: {user_message} Code Assistant:"
            input_ids = tokenizer.encode(conversation, return_tensors="pt", add_special_tokens=True)
            input_ids = input_ids.to(device)
            # Inference only: skip autograd graph construction to save memory.
            with torch.no_grad():
                response = model.generate(
                    input_ids=input_ids,
                    use_cache=True,
                    early_stopping=False,
                    bos_token_id=bos_token_id,
                    eos_token_id=eos_token_id,
                    pad_token_id=pad_token_id,
                    temperature=temperature,
                    do_sample=do_sample,
                    max_new_tokens=max_new_tokens,
                    top_p=top_p,
                    repetition_penalty=repetition_penalty,
                )
            # NOTE: the decoded text includes the prompt as well as the reply;
            # callers receive the full decoded sequence, as before.
            return tokenizer.decode(response[0], skip_special_tokens=True)
        finally:
            # Release the input tensor and prod the allocators so repeated
            # requests don't accumulate GPU memory.
            if input_ids is not None:
                del input_ids
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
| |
|
# One canned example conversation for the demo.
# NOTE(review): this list is never attached to a gr.Examples component in this
# file, and it carries 6 values while the submit callback takes 8 inputs
# (mode and do_sample are missing) — confirm the intended wiring before use.
examples = [
    [
        "The following dialogue is a conversation between Emmanuel Macron and Elon Musk:",
        "[Emmanuel Macron]: Hello Mr. Musk. Thank you for receiving me today.",
        0.9,   # temperature
        450,   # max_new_tokens
        0.9,   # top_p
        1.9,   # repetition_penalty
    ],
]
| |
|
# Single shared bot instance used by the Gradio submit callback below.
starling_bot = StarlingBot()
| |
|
def gradio_starling(user_message, assistant_message, mode, do_sample, temperature, max_new_tokens, top_p, repetition_penalty):
    """Gradio submit callback: forward the UI control values to the shared
    StarlingBot instance and return the generated text."""
    return starling_bot.predict(
        user_message,
        assistant_message,
        mode,
        do_sample,
        temperature,
        max_new_tokens,
        top_p,
        repetition_penalty,
    )
| |
|
# Build and launch the Gradio UI.
with gr.Blocks(theme="ParityError/Anime") as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        assistant_message = gr.Textbox(label="Optional💫🌠Starling Assistant Message", lines=2)
        user_message = gr.Textbox(label="Your Message", lines=3)
    with gr.Row():
        mode = gr.Radio(choices=["Assistant", "Coder"], value="Assistant", label="Mode")
        do_sample = gr.Checkbox(label="Advanced", value=True)
    # BUG FIX: Accordion's `open` parameter expects a bool; the original passed
    # a lambda, which is always truthy and is not a supported dynamic value.
    with gr.Accordion("Advanced Settings", open=True):
        with gr.Row():
            temperature = gr.Slider(label="Temperature", value=0.4, minimum=0.05, maximum=1.0, step=0.05)
            max_new_tokens = gr.Slider(label="Max new tokens", value=100, minimum=25, maximum=800, step=1)
            # BUG FIX: top-p is a nucleus-sampling probability and must lie in
            # (0, 1]; the original slider defaulted to 3.6 with range 1.0-4.0,
            # which model.generate() would reject.
            top_p = gr.Slider(label="Top-p (nucleus sampling)", value=0.92, minimum=0.01, maximum=1.0, step=0.01)
            repetition_penalty = gr.Slider(label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0, step=0.05)

    submit_button = gr.Button("Submit")
    output_text = gr.Textbox(label="💫🌠Starling Response")

    # Wire the submit button to the generation callback.
    submit_button.click(
        gradio_starling,
        inputs=[user_message, assistant_message, mode, do_sample, temperature, max_new_tokens, top_p, repetition_penalty],
        outputs=output_text,
    )

demo.launch()