|
|
import streamlit as st |
|
|
from huggingface_hub import InferenceClient |
|
|
import os |
|
|
|
|
|
|
|
|
# Hugging Face dedicated Inference Endpoint that serves the chat model.
SPACE_URL = "https://z7svds7k42bwhhgm.us-east-1.aws.endpoints.huggingface.cloud"

# Auth token read from the environment. os.getenv returns None when unset,
# in which case requests go out unauthenticated — the endpoint must permit that.
HF_API_KEY = os.getenv("HF_API_KEY")

# Avatar shown next to assistant chat messages.
# NOTE: was the mojibake "๐พ" — the UTF-8 bytes of the paw-print emoji decoded
# through a Thai codepage; restored to the intended emoji.
DUBS_PATH = "🐾"

# set_page_config must be the first Streamlit command executed in the script.
# page_icon was likewise mojibake ("๐ค"); restored to the robot emoji.
st.set_page_config(page_title="Chatbot Test", page_icon="🤖", layout="centered")

# Single shared client for the endpoint; token may be None (see HF_API_KEY above).
client = InferenceClient(SPACE_URL, token=HF_API_KEY)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fetch_response(prompt):
    """Stream generated tokens from the HF Inference Endpoint.

    Yields the text of each non-special token as it arrives, stopping
    early if the model emits one of the configured stop sequences.

    Args:
        prompt: Fully formatted prompt string sent to the model.

    Yields:
        str: The next chunk of generated text.
    """
    gen_kwargs = {
        "max_new_tokens": 512,
        "top_k": 30,
        "top_p": 0.9,
        "temperature": 0.2,
        "repetition_penalty": 1.02,
        "stop_sequences": ["<|endoftext|>"],
    }

    # details=True exposes per-token metadata (token.special, token.text)
    # on each streamed response object.
    stream = client.text_generation(prompt, stream=True, details=True, **gen_kwargs)

    for response in stream:
        # Special tokens (BOS/EOS markers etc.) are not user-visible text.
        if response.token.special:
            continue

        # The endpoint may surface the stop sequence as a token; don't yield it.
        if response.token.text in gen_kwargs["stop_sequences"]:
            break

        yield response.token.text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
st.title("Chatbot Testing Interface")

# chat_input returns None until the user submits a message, so the block
# below runs only on submission.
prompt = st.chat_input("Enter your message...")

if prompt:
    # Echo the user's message into the chat transcript.
    st.chat_message("user").write(prompt)

    # Wrap the single-turn message in the model's expected chat template.
    # NOTE(review): despite the name, no prior turns are included — each
    # submission is an independent single-turn prompt.
    chat_history = f"<|user|>{prompt}<|end|> \n <|assistant|> "

    # Spinner text was mojibake ("๐พ"); restored to the paw-print emoji.
    with st.spinner("Dubs is thinking... Woof Woof! 🐾"):
        with st.chat_message("assistant", avatar=DUBS_PATH):
            # write_stream consumes the generator and renders tokens live.
            full_response = fetch_response(chat_history)
            st.write_stream(full_response)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|