| | import sys |
| | import spaces |
| | import os |
| | import gradio as gr |
| | import torch |
| | import torch |
| | import threading |
| | from together import Together |
| | from transformers import ( |
| | AutoTokenizer, |
| | AutoModelForCausalLM, |
| | TextIteratorStreamer, |
| | ) |
| |
|
# Model identifier served through the Together inference API.
TOGETHER_MODEL_ID = "EssentialAI/RNJ-1-instruct"
# NOTE(review): declared but never referenced in this file — presumably an
# intended cap on prompt length; confirm before removing.
MAX_INPUT_TOKEN_LENGTH = 8192


# Together API client; presumably picks up credentials (TOGETHER_API_KEY)
# from the environment — TODO confirm the deployment sets it.
client = Together()
| |
|
@spaces.GPU
def chat_fn(message, history, system_message, max_tokens, temperature, top_p):
    """Generate one assistant reply via the Together chat-completions API.

    Args:
        message: The latest user message (plain string).
        history: Prior conversation turns. Accepts both the legacy Gradio
            tuple format ``[(user, assistant), ...]`` and the newer
            "messages" format ``[{"role": ..., "content": ...}, ...]``.
        system_message: System prompt prepended to the conversation.
        max_tokens: Maximum number of tokens the model may generate.
        temperature: Sampling temperature forwarded to the API.
        top_p: Nucleus-sampling probability mass forwarded to the API.

    Yields:
        The assistant's full reply as a single string (ChatInterface
        accepts a generator, so this streams one final chunk).
    """
    messages = [{"role": "system", "content": system_message}]

    # Replay prior turns. Newer Gradio versions pass history as role/content
    # dicts; older versions pass (user, assistant) pairs — support both so
    # the handler does not crash on unpacking.
    for turn in history:
        if isinstance(turn, dict):
            messages.append({"role": turn["role"], "content": turn["content"]})
        else:
            user_msg, bot_msg = turn
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": bot_msg})

    messages.append({"role": "user", "content": message})

    response = client.chat.completions.create(
        model=TOGETHER_MODEL_ID,
        messages=messages,
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
    )

    # Debug print of the raw reply removed — it leaked every response to logs.
    yield response.choices[0].message.content
| |
|
# Gradio chat UI wired to chat_fn; the additional_inputs appear in the
# interface and are passed positionally after (message, history).
demo = gr.ChatInterface(
    fn=chat_fn,
    title="Chat Demo with RNJ1",
    description=" Chat with RNJ1",
    additional_inputs=[
        # -> system_message
        gr.Textbox(value="You are a helpful AI assistant.", label="System message"),
        # -> max_tokens
        gr.Slider(minimum=1, maximum=10240, value=512, step=1, label="Max new tokens"),
        # -> temperature
        gr.Slider(minimum=0.1, maximum=4.0, value=0.2, step=0.1, label="Temperature"),
        # -> top_p
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    # Clickable example prompts shown beneath the chat box.
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you write a Python program that adds two numbers?"],
        ["Which one of these mountains is not located in Europe? Hoverla, Mont-Blanc, Gran Paradiso, Everest"],
    ]
)
| |
|
# Launch the Gradio server only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()
| |
|