|
|
import streamlit as st |
|
|
import requests |
|
|
import json |
|
|
import os |
|
|
import datetime |
|
|
from huggingface_hub import InferenceClient |
|
|
|
|
|
|
|
|
# Hugging Face Inference Endpoint that serves the chat model.
SPACE_URL = "https://qf6hn3tcwcf7pc7p.us-east-1.aws.endpoints.huggingface.cloud"

# API token read from the environment; None when unset (calls would then be
# unauthenticated).
HF_API_KEY = os.getenv("HF_API_KEY")

# End-of-turn token used by the model's chat template.
EOS_TOKEN = "<|end|>"

# Directory where per-session chat transcripts are persisted as JSON files.
CHAT_HISTORY_DIR = "chat_histories"

# UI image assets (paths relative to the working directory).
IMAGE_PATH = "DubsChat.png"
IMAGE_PATH_2 = "Reboot AI.png"
DUBS_PATH = "Dubs.png"
|
|
|
|
|
|
|
|
# Ensure the chat-history directory exists up front; surface (rather than
# raise) any filesystem error so the app still renders.
try:
    os.makedirs(CHAT_HISTORY_DIR, exist_ok=True)
except OSError as e:
    st.error(f"Failed to create chat history directory: {e}")
|
|
|
|
|
|
|
|
# Page chrome: browser-tab title/icon, wide layout, and the RebootAI logo.
st.set_page_config(page_title="DUBSChat", page_icon=IMAGE_PATH, layout="wide")
st.logo(IMAGE_PATH_2,size="large")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def format_chat_template(history, user_input):
    """
    Build the full model prompt from the prior history and the new user input.

    The f-string below interpolates ``history`` and ``user_input`` directly.
    The previous implementation additionally called ``str.format`` on the
    already-interpolated string, which was redundant and raised
    KeyError/ValueError whenever the user's message contained literal
    ``{`` or ``}`` characters (common in code snippets).
    """
    return f"""
<|system|> You are Dubs, a helpful assistant created by RebootAI.<|end|> \n {history} <|user|> \n {user_input}<|end|> \n <|assistant|> """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def format_chat_history(messages):
    """
    Serialize chat messages into the string form the chat template expects.

    Only user and assistant turns are emitted (system entries are skipped),
    each wrapped in its role tag and terminated with <|end|>. Surrounding
    whitespace is stripped so no duplicate <|assistant|> tokens leak into
    the final prompt.
    """
    tag_by_role = {"user": "<|user|>", "assistant": "<|assistant|>"}
    parts = []
    for entry in messages:
        tag = tag_by_role.get(entry["role"])
        if tag is not None:
            parts.append(f"{tag}{entry['content']}<|end|>\n")
    return "".join(parts).strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_chat_history(session_name, messages):
    """
    Persist a session's messages as JSON under CHAT_HISTORY_DIR.

    Any write failure is reported in the Streamlit UI rather than raised,
    so a disk problem never crashes the chat.
    """
    target = os.path.join(CHAT_HISTORY_DIR, f"{session_name}.json")
    try:
        with open(target, "w") as fh:
            json.dump(messages, fh)
    except IOError as e:
        st.error(f"Failed to save chat history: {e}")
|
|
|
|
|
|
|
|
def load_chat_history(file_name):
    """
    Read a saved chat session from CHAT_HISTORY_DIR.

    Returns the parsed message list; on a missing file or invalid JSON an
    error is shown in the UI and an empty list is returned so the caller
    can start a fresh session.
    """
    source = os.path.join(CHAT_HISTORY_DIR, file_name)
    try:
        with open(source, "r") as fh:
            messages = json.load(fh)
    except (FileNotFoundError, json.JSONDecodeError):
        st.error("Failed to load chat history. Starting with a new session.")
        return []
    return messages
|
|
|
|
|
|
|
|
def get_saved_sessions():
    """
    Return the names of all saved chat sessions (JSON files with the
    extension removed), in directory-listing order.
    """
    sessions = []
    for entry in os.listdir(CHAT_HISTORY_DIR):
        if entry.endswith(".json"):
            sessions.append(entry.replace(".json", ""))
    return sessions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Sidebar: start a new session or reload a previously saved one.
with st.sidebar:
    # "New Chat" resets the transcript, names the session by timestamp, and
    # persists it immediately so it appears in the session list.
    if st.button("New Chat"):
        st.session_state["messages"] = [
            # Consistency fix: this system prompt previously read "Your name
            # is Dubs..." while first-load init and the chat template use
            # "You are Dubs, a helpful assistant created by RebootAI." —
            # all three now agree.
            {"role": "system", "content": "You are Dubs, a helpful assistant created by RebootAI."},
        ]
        st.session_state["session_name"] = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        save_chat_history(st.session_state["session_name"], st.session_state["messages"])
        st.success("Chat reset and new session started.")

    # Offer every saved session for reloading into the current state.
    saved_sessions = get_saved_sessions()
    if saved_sessions:
        selected_session = st.radio("Past Sessions:", saved_sessions)
        if st.button("Load Session"):
            st.session_state["messages"] = load_chat_history(f"{selected_session}.json")
            st.session_state["session_name"] = selected_session
            st.success(f"Loaded session: {selected_session}")
    else:
        st.write("No past sessions available.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# First-load initialization of Streamlit session state: seed the transcript
# with the system prompt and name the session by the current timestamp.
# (These run once per browser session; reruns keep the existing values.)
if "messages" not in st.session_state:
    st.session_state["messages"] = [
        {"role": "system", "content": "You are Dubs, a helpful assistant created by RebootAI"}
    ]
if "session_name" not in st.session_state:
    st.session_state["session_name"] = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Page header: app logo and tagline.
st.image(IMAGE_PATH, width=250)
st.markdown("Empowering you with a Sustainable AI")
st.markdown("DubsChat is currently best suited at assisting you with Coding Problems")
|
|
|
|
|
|
|
|
# Replay the stored conversation so the transcript survives Streamlit reruns.
# System entries are internal context and are never rendered.
for entry in st.session_state["messages"]:
    role = entry["role"]
    if role == "user":
        st.chat_message("user").write(entry["content"])
    elif role == "assistant":
        st.chat_message("assistant", avatar=DUBS_PATH).write(entry["content"])
|
|
|
|
|
# Client bound to the dedicated HF Inference Endpoint; token may be None
# if HF_API_KEY is not set in the environment.
client = InferenceClient(SPACE_URL, token=HF_API_KEY)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def stream_response(prompt_text):
    """
    Stream generated text from the HF Inference Endpoint.

    Opens a token-level stream via the module-level InferenceClient and
    yields each non-special token's text as it arrives. Generation halts
    at the <|end|> stop sequence.
    """
    generation_params = dict(
        max_new_tokens=4096,
        top_k=30,
        top_p=0.9,
        temperature=0.2,
        repetition_penalty=1.02,
        stop_sequences=["<|end|>"],
    )

    token_stream = client.text_generation(
        prompt_text, stream=True, details=True, **generation_params
    )

    for event in token_stream:
        # Skip special tokens (e.g. EOS) so only user-visible text is yielded.
        if not event.token.special:
            yield event.token.text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Main chat loop -------------------------------------------------------
# Read the next user message; Streamlit reruns the whole script on submit.
prompt = st.chat_input()


if prompt:

    # Record and display the user's turn.
    st.session_state["messages"].append({"role": "user", "content": prompt})
    st.chat_message("user").write(prompt)

    # Build the model prompt: all prior turns (excluding the one just
    # appended, hence [:-1]) plus the new user input in template form.
    chat_history = format_chat_history(st.session_state["messages"][:-1])
    model_input = format_chat_template(chat_history, prompt)

    with st.spinner("Dubs is thinking... Woof Woof! 🐾"):
        msg = ""
        with st.chat_message("assistant", avatar=DUBS_PATH):
            # st.write_stream renders chunks as they arrive and returns the
            # full concatenated response text.
            response_stream = stream_response(model_input)
            msg = st.write_stream(response_stream)

    # Persist the assistant's reply in session state...
    st.session_state["messages"].append({"role": "assistant", "content": msg})

    # ...and save the complete session transcript to disk.
    save_chat_history(st.session_state["session_name"], st.session_state["messages"])
|
|
|
|
|
|
|
|
|
|
|
|