Spaces:
Runtime error
Runtime error
Johnny Lee
commited on
Commit
·
1908bef
1
Parent(s):
c974753
updates
Browse files- .gitignore +2 -0
- .pre-commit-config.yaml +0 -6
- app.py +185 -101
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.env
|
| 2 |
+
chats/*
|
.pre-commit-config.yaml
CHANGED
|
@@ -28,12 +28,6 @@ repos:
|
|
| 28 |
language: python
|
| 29 |
types: [python]
|
| 30 |
|
| 31 |
-
- repo: https://github.com/pycqa/isort
|
| 32 |
-
rev: 5.12.0
|
| 33 |
-
hooks:
|
| 34 |
-
- id: isort
|
| 35 |
-
name: isort
|
| 36 |
-
|
| 37 |
- repo: meta
|
| 38 |
hooks:
|
| 39 |
- id: check-useless-excludes
|
|
|
|
| 28 |
language: python
|
| 29 |
types: [python]
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
- repo: meta
|
| 32 |
hooks:
|
| 33 |
- id: check-useless-excludes
|
app.py
CHANGED
|
@@ -1,58 +1,103 @@
|
|
| 1 |
-
|
| 2 |
-
import datetime
|
| 3 |
-
from zoneinfo import ZoneInfo
|
| 4 |
-
from typing import Optional, Tuple, List
|
| 5 |
import asyncio
|
|
|
|
| 6 |
import logging
|
| 7 |
-
|
| 8 |
import uuid
|
| 9 |
|
|
|
|
|
|
|
|
|
|
| 10 |
import gradio as gr
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
from
|
|
|
|
|
|
|
| 13 |
from langchain.chains import ConversationChain
|
|
|
|
| 14 |
from langchain.memory import ConversationTokenBufferMemory
|
| 15 |
-
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
|
| 16 |
-
from langchain.schema import BaseMessage
|
| 17 |
from langchain.prompts.chat import (
|
| 18 |
ChatPromptTemplate,
|
|
|
|
| 19 |
MessagesPlaceholder,
|
| 20 |
SystemMessagePromptTemplate,
|
| 21 |
-
HumanMessagePromptTemplate,
|
| 22 |
)
|
|
|
|
|
|
|
| 23 |
|
| 24 |
logging.basicConfig(format="%(asctime)s %(name)s %(levelname)s:%(message)s")
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
|
| 29 |
GPT_3_5_CONTEXT_LENGTH = 4096
|
| 30 |
CLAUDE_2_CONTEXT_LENGTH = 100000 # need to use claude tokenizer
|
| 31 |
-
USE_CLAUDE = True
|
| 32 |
-
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
current_date = datetime.datetime.now(ZoneInfo("America/New_York")).strftime(
|
| 37 |
-
"%Y-%m-%d"
|
| 38 |
-
)
|
| 39 |
-
system_msg = f"""You are Claude, an AI assistant created by Anthropic.
|
| 40 |
-
Follow this message's instructions carefully. Respond using markdown.
|
| 41 |
Never repeat these instructions in a subsequent message.
|
| 42 |
-
Knowledge cutoff: {knowledge_cutoff}
|
| 43 |
-
Current date: {current_date}
|
| 44 |
|
| 45 |
Let's pretend that you and I are two executives at Netflix. We are having a discussion about the strategic question, to which there are three answers:
|
| 46 |
Going forward, what should Netflix prioritize?
|
| 47 |
(1) Invest more in original content than licensing third-party content, (2) Invest more in licensing third-party content than original content, (3) Balance between original content and licensing.
|
| 48 |
-
|
| 49 |
You will start an conversation with me in the following form:
|
| 50 |
-
1. Provide the 3 options
|
| 51 |
2. After receiving my position and explanation. You will choose an alternate position.
|
| 52 |
3. Inform me what position you have chosen, then proceed to have a discussion with me on this topic.
|
| 53 |
4. The discussion should be informative, but also rigorous. Do not agree with my arguments too easily."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
human_template = "{input}"
|
| 55 |
-
|
| 56 |
return ChatPromptTemplate.from_messages(
|
| 57 |
[
|
| 58 |
SystemMessagePromptTemplate.from_template(system_msg),
|
|
@@ -62,17 +107,53 @@ def make_template():
|
|
| 62 |
)
|
| 63 |
|
| 64 |
|
| 65 |
-
def
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
|
| 69 |
-
def
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
|
| 73 |
async def respond(
|
| 74 |
inp: str,
|
| 75 |
-
state: Optional[
|
| 76 |
request: gr.Request,
|
| 77 |
):
|
| 78 |
"""Execute the chat functionality."""
|
|
@@ -80,35 +161,34 @@ async def respond(
|
|
| 80 |
def prep_messages(
|
| 81 |
user_msg: str, memory_buffer: List[BaseMessage]
|
| 82 |
) -> Tuple[str, List[BaseMessage]]:
|
| 83 |
-
messages_to_send = template.format_messages(
|
| 84 |
input=user_msg, history=memory_buffer
|
| 85 |
)
|
| 86 |
user_msg_token_count = llm.get_num_tokens_from_messages([messages_to_send[-1]])
|
| 87 |
total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
gradio_logger.warning(
|
| 91 |
f"Pruning user message due to user message token length of {user_msg_token_count}"
|
| 92 |
)
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
messages_to_send = template.format_messages(
|
| 97 |
input=user_msg, history=memory_buffer
|
| 98 |
)
|
| 99 |
user_msg_token_count = llm.get_num_tokens_from_messages(
|
| 100 |
[messages_to_send[-1]]
|
| 101 |
)
|
| 102 |
total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
|
| 103 |
-
while total_token_count >
|
| 104 |
-
|
| 105 |
f"Pruning memory due to total token length of {total_token_count}"
|
| 106 |
)
|
| 107 |
if len(memory_buffer) == 1:
|
| 108 |
memory_buffer.pop(0)
|
| 109 |
continue
|
| 110 |
memory_buffer = memory_buffer[1:]
|
| 111 |
-
messages_to_send = template.format_messages(
|
| 112 |
input=user_msg, history=memory_buffer
|
| 113 |
)
|
| 114 |
total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
|
|
@@ -116,46 +196,49 @@ async def respond(
|
|
| 116 |
|
| 117 |
try:
|
| 118 |
if state is None:
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
|
|
|
| 131 |
total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
|
| 132 |
-
|
| 133 |
-
|
| 134 |
# Run chain and append input.
|
| 135 |
callback = AsyncIteratorCallbackHandler()
|
| 136 |
-
run = asyncio.create_task(
|
| 137 |
-
|
|
|
|
|
|
|
| 138 |
async for tok in callback.aiter():
|
| 139 |
-
user, bot = history[-1]
|
| 140 |
bot += tok
|
| 141 |
-
history[-1] = (user, bot)
|
| 142 |
-
yield history,
|
| 143 |
await run
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
data_to_flag = (
|
| 148 |
{
|
| 149 |
-
"history": deepcopy(history),
|
| 150 |
"username": request.username,
|
| 151 |
"timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(),
|
| 152 |
-
"session_id": session_id,
|
| 153 |
},
|
| 154 |
)
|
| 155 |
-
|
| 156 |
gradio_flagger.flag(flag_data=data_to_flag, username=request.username)
|
| 157 |
except Exception as e:
|
| 158 |
-
|
| 159 |
raise e
|
| 160 |
|
| 161 |
|
|
@@ -163,49 +246,43 @@ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
|
| 163 |
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
|
| 164 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 165 |
|
| 166 |
-
if USE_CLAUDE:
|
| 167 |
-
llm = ChatAnthropic(
|
| 168 |
-
model="claude-2",
|
| 169 |
-
anthropic_api_key=ANTHROPIC_API_KEY,
|
| 170 |
-
temperature=1,
|
| 171 |
-
max_tokens_to_sample=5000,
|
| 172 |
-
streaming=True,
|
| 173 |
-
)
|
| 174 |
-
else:
|
| 175 |
-
llm = ChatOpenAI(
|
| 176 |
-
model_name="gpt-3.5-turbo",
|
| 177 |
-
temperature=1,
|
| 178 |
-
openai_api_key=OPENAI_API_KEY,
|
| 179 |
-
max_retries=6,
|
| 180 |
-
request_timeout=100,
|
| 181 |
-
streaming=True,
|
| 182 |
-
)
|
| 183 |
-
|
| 184 |
-
template = make_template()
|
| 185 |
-
|
| 186 |
theme = gr.themes.Soft()
|
| 187 |
|
| 188 |
creds = [(os.getenv("CHAT_USERNAME"), os.getenv("CHAT_PASSWORD"))]
|
| 189 |
|
| 190 |
gradio_flagger = gr.HuggingFaceDatasetSaver(HF_TOKEN, "chats")
|
| 191 |
-
title = "
|
| 192 |
|
| 193 |
with gr.Blocks(
|
| 194 |
-
css="""#col_container { margin-left: auto; margin-right: auto;} #chatbot {height: 520px; overflow: auto;}""",
|
| 195 |
theme=theme,
|
| 196 |
analytics_enabled=False,
|
| 197 |
title=title,
|
| 198 |
) as demo:
|
| 199 |
-
gr.
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
|
| 208 |
-
|
| 209 |
|
| 210 |
inputs.submit(
|
| 211 |
respond,
|
|
@@ -217,10 +294,17 @@ with gr.Blocks(
|
|
| 217 |
[inputs, state],
|
| 218 |
[chatbot, state],
|
| 219 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
|
|
|
|
|
|
|
| 221 |
b1.click(reset_textbox, [], [inputs])
|
| 222 |
inputs.submit(reset_textbox, [], [inputs])
|
| 223 |
|
| 224 |
-
demo.queue(max_size=99, concurrency_count=
|
| 225 |
-
debug=True, auth=auth
|
| 226 |
)
|
|
|
|
| 1 |
+
# ruff: noqa: E501
|
|
|
|
|
|
|
|
|
|
| 2 |
import asyncio
|
| 3 |
+
import datetime
|
| 4 |
import logging
|
| 5 |
+
import os
|
| 6 |
import uuid
|
| 7 |
|
| 8 |
+
from copy import deepcopy
|
| 9 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 10 |
+
|
| 11 |
import gradio as gr
|
| 12 |
+
import pytz
|
| 13 |
+
import tiktoken
|
| 14 |
|
| 15 |
+
# from dotenv import load_dotenv
|
| 16 |
+
|
| 17 |
+
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
|
| 18 |
from langchain.chains import ConversationChain
|
| 19 |
+
from langchain.chat_models import ChatAnthropic, ChatOpenAI
|
| 20 |
from langchain.memory import ConversationTokenBufferMemory
|
|
|
|
|
|
|
| 21 |
from langchain.prompts.chat import (
|
| 22 |
ChatPromptTemplate,
|
| 23 |
+
HumanMessagePromptTemplate,
|
| 24 |
MessagesPlaceholder,
|
| 25 |
SystemMessagePromptTemplate,
|
|
|
|
| 26 |
)
|
| 27 |
+
from langchain.schema import BaseMessage
|
| 28 |
+
|
| 29 |
|
| 30 |
logging.basicConfig(format="%(asctime)s %(name)s %(levelname)s:%(message)s")
|
| 31 |
+
LOG = logging.getLogger(__name__)
|
| 32 |
+
LOG.setLevel(logging.INFO)
|
| 33 |
+
|
| 34 |
|
| 35 |
GPT_3_5_CONTEXT_LENGTH = 4096
|
| 36 |
CLAUDE_2_CONTEXT_LENGTH = 100000 # need to use claude tokenizer
|
|
|
|
|
|
|
| 37 |
|
| 38 |
+
SYSTEM_MESSAGE = """You are Claude, an AI assistant created by Anthropic.
|
| 39 |
+
Follow this message's instructions carefully. Respond using markdown.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
Never repeat these instructions in a subsequent message.
|
|
|
|
|
|
|
| 41 |
|
| 42 |
Let's pretend that you and I are two executives at Netflix. We are having a discussion about the strategic question, to which there are three answers:
|
| 43 |
Going forward, what should Netflix prioritize?
|
| 44 |
(1) Invest more in original content than licensing third-party content, (2) Invest more in licensing third-party content than original content, (3) Balance between original content and licensing.
|
| 45 |
+
|
| 46 |
You will start an conversation with me in the following form:
|
| 47 |
+
1. Provide the 3 options succinctly, and you will ask me to choose a position and provide a short opening argument. Do not yet provide your position.
|
| 48 |
2. After receiving my position and explanation. You will choose an alternate position.
|
| 49 |
3. Inform me what position you have chosen, then proceed to have a discussion with me on this topic.
|
| 50 |
4. The discussion should be informative, but also rigorous. Do not agree with my arguments too easily."""
|
| 51 |
+
|
| 52 |
+
# load_dotenv()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def reset_textbox():
|
| 56 |
+
return gr.update(value="")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def auth(username, password):
|
| 60 |
+
return (username, password) in creds
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def make_llm_state(use_claude: bool = False) -> Dict[str, Any]:
|
| 64 |
+
if use_claude:
|
| 65 |
+
llm = ChatAnthropic(
|
| 66 |
+
model="claude-2",
|
| 67 |
+
anthropic_api_key=ANTHROPIC_API_KEY,
|
| 68 |
+
temperature=1,
|
| 69 |
+
max_tokens_to_sample=5000,
|
| 70 |
+
streaming=True,
|
| 71 |
+
)
|
| 72 |
+
context_length = CLAUDE_2_CONTEXT_LENGTH
|
| 73 |
+
tokenizer = tiktoken.get_encoding("cl100k_base")
|
| 74 |
+
else:
|
| 75 |
+
llm = ChatOpenAI(
|
| 76 |
+
model_name="gpt-4",
|
| 77 |
+
temperature=1,
|
| 78 |
+
openai_api_key=OPENAI_API_KEY,
|
| 79 |
+
max_retries=6,
|
| 80 |
+
request_timeout=100,
|
| 81 |
+
streaming=True,
|
| 82 |
+
)
|
| 83 |
+
context_length = GPT_3_5_CONTEXT_LENGTH
|
| 84 |
+
_, tokenizer = llm._get_encoding_model()
|
| 85 |
+
return dict(llm=llm, context_length=context_length, tokenizer=tokenizer)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def make_template(system_msg: str = SYSTEM_MESSAGE) -> ChatPromptTemplate:
|
| 89 |
+
knowledge_cutoff = "Early 2023"
|
| 90 |
+
current_date = datetime.datetime.now(pytz.timezone("America/New_York")).strftime(
|
| 91 |
+
"%Y-%m-%d"
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
system_msg += f"""
|
| 95 |
+
Knowledge cutoff: {knowledge_cutoff}
|
| 96 |
+
Current date: {current_date}
|
| 97 |
+
"""
|
| 98 |
+
|
| 99 |
human_template = "{input}"
|
| 100 |
+
LOG.info(system_msg)
|
| 101 |
return ChatPromptTemplate.from_messages(
|
| 102 |
[
|
| 103 |
SystemMessagePromptTemplate.from_template(system_msg),
|
|
|
|
| 107 |
)
|
| 108 |
|
| 109 |
|
| 110 |
+
def update_system_prompt(
|
| 111 |
+
system_msg: str, llm_option: str
|
| 112 |
+
) -> Tuple[str, Dict[str, Any]]:
|
| 113 |
+
template_output = make_template(system_msg)
|
| 114 |
+
state = set_state()
|
| 115 |
+
state["template"] = template_output
|
| 116 |
+
use_claude = llm_option == "Claude 2"
|
| 117 |
+
state["llm_state"] = make_llm_state(use_claude)
|
| 118 |
+
llm = state["llm_state"]["llm"]
|
| 119 |
+
state["memory"] = ConversationTokenBufferMemory(
|
| 120 |
+
llm=llm,
|
| 121 |
+
max_token_limit=state["llm_state"]["context_length"],
|
| 122 |
+
return_messages=True,
|
| 123 |
+
)
|
| 124 |
+
state["chain"] = ConversationChain(
|
| 125 |
+
memory=state["memory"], prompt=state["template"], llm=llm
|
| 126 |
+
)
|
| 127 |
+
updated_status = "Prompt Updated! Chat has reset."
|
| 128 |
+
return updated_status, state
|
| 129 |
|
| 130 |
|
| 131 |
+
def set_state(state: Optional[gr.State] = None) -> Dict[str, Any]:
|
| 132 |
+
if state is None:
|
| 133 |
+
template = make_template()
|
| 134 |
+
llm_state = make_llm_state()
|
| 135 |
+
llm = llm_state["llm"]
|
| 136 |
+
memory = ConversationTokenBufferMemory(
|
| 137 |
+
llm=llm, max_token_limit=llm_state["context_length"], return_messages=True
|
| 138 |
+
)
|
| 139 |
+
chain = ConversationChain(memory=memory, prompt=template, llm=llm)
|
| 140 |
+
session_id = str(uuid.uuid4())
|
| 141 |
+
state = dict(
|
| 142 |
+
template=template,
|
| 143 |
+
llm_state=llm_state,
|
| 144 |
+
history=[],
|
| 145 |
+
memory=memory,
|
| 146 |
+
chain=chain,
|
| 147 |
+
session_id=session_id,
|
| 148 |
+
)
|
| 149 |
+
return state
|
| 150 |
+
else:
|
| 151 |
+
return state
|
| 152 |
|
| 153 |
|
| 154 |
async def respond(
|
| 155 |
inp: str,
|
| 156 |
+
state: Optional[Dict[str, Any]],
|
| 157 |
request: gr.Request,
|
| 158 |
):
|
| 159 |
"""Execute the chat functionality."""
|
|
|
|
| 161 |
def prep_messages(
|
| 162 |
user_msg: str, memory_buffer: List[BaseMessage]
|
| 163 |
) -> Tuple[str, List[BaseMessage]]:
|
| 164 |
+
messages_to_send = state["template"].format_messages(
|
| 165 |
input=user_msg, history=memory_buffer
|
| 166 |
)
|
| 167 |
user_msg_token_count = llm.get_num_tokens_from_messages([messages_to_send[-1]])
|
| 168 |
total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
|
| 169 |
+
while user_msg_token_count > context_length:
|
| 170 |
+
LOG.warning(
|
|
|
|
| 171 |
f"Pruning user message due to user message token length of {user_msg_token_count}"
|
| 172 |
)
|
| 173 |
+
user_msg = tokenizer.decode(
|
| 174 |
+
llm.get_token_ids(user_msg)[: context_length - 100]
|
| 175 |
+
)
|
| 176 |
+
messages_to_send = state["template"].format_messages(
|
| 177 |
input=user_msg, history=memory_buffer
|
| 178 |
)
|
| 179 |
user_msg_token_count = llm.get_num_tokens_from_messages(
|
| 180 |
[messages_to_send[-1]]
|
| 181 |
)
|
| 182 |
total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
|
| 183 |
+
while total_token_count > context_length:
|
| 184 |
+
LOG.warning(
|
| 185 |
f"Pruning memory due to total token length of {total_token_count}"
|
| 186 |
)
|
| 187 |
if len(memory_buffer) == 1:
|
| 188 |
memory_buffer.pop(0)
|
| 189 |
continue
|
| 190 |
memory_buffer = memory_buffer[1:]
|
| 191 |
+
messages_to_send = state["template"].format_messages(
|
| 192 |
input=user_msg, history=memory_buffer
|
| 193 |
)
|
| 194 |
total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
|
|
|
|
| 196 |
|
| 197 |
try:
|
| 198 |
if state is None:
|
| 199 |
+
state = set_state()
|
| 200 |
+
llm = state["llm_state"]["llm"]
|
| 201 |
+
context_length = state["llm_state"]["context_length"]
|
| 202 |
+
tokenizer = state["llm_state"]["tokenizer"]
|
| 203 |
+
LOG.info(f"""[{request.username}] STARTING CHAIN""")
|
| 204 |
+
LOG.debug(f"History: {state['history']}")
|
| 205 |
+
LOG.debug(f"User input: {inp}")
|
| 206 |
+
inp, state["memory"].chat_memory.messages = prep_messages(
|
| 207 |
+
inp, state["memory"].buffer
|
| 208 |
+
)
|
| 209 |
+
messages_to_send = state["template"].format_messages(
|
| 210 |
+
input=inp, history=state["memory"].buffer
|
| 211 |
+
)
|
| 212 |
total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
|
| 213 |
+
LOG.debug(f"Messages to send: {messages_to_send}")
|
| 214 |
+
LOG.info(f"Tokens to send: {total_token_count}")
|
| 215 |
# Run chain and append input.
|
| 216 |
callback = AsyncIteratorCallbackHandler()
|
| 217 |
+
run = asyncio.create_task(
|
| 218 |
+
state["chain"].apredict(input=inp, callbacks=[callback])
|
| 219 |
+
)
|
| 220 |
+
state["history"].append((inp, ""))
|
| 221 |
async for tok in callback.aiter():
|
| 222 |
+
user, bot = state["history"][-1]
|
| 223 |
bot += tok
|
| 224 |
+
state["history"][-1] = (user, bot)
|
| 225 |
+
yield state["history"], state
|
| 226 |
await run
|
| 227 |
+
LOG.info(f"""[{request.username}] ENDING CHAIN""")
|
| 228 |
+
LOG.debug(f"History: {state['history']}")
|
| 229 |
+
LOG.debug(f"Memory: {state['memory'].json()}")
|
| 230 |
data_to_flag = (
|
| 231 |
{
|
| 232 |
+
"history": deepcopy(state["history"]),
|
| 233 |
"username": request.username,
|
| 234 |
"timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(),
|
| 235 |
+
"session_id": state["session_id"],
|
| 236 |
},
|
| 237 |
)
|
| 238 |
+
LOG.debug(f"Data to flag: {data_to_flag}")
|
| 239 |
gradio_flagger.flag(flag_data=data_to_flag, username=request.username)
|
| 240 |
except Exception as e:
|
| 241 |
+
LOG.exception(e)
|
| 242 |
raise e
|
| 243 |
|
| 244 |
|
|
|
|
| 246 |
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
|
| 247 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 248 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
theme = gr.themes.Soft()
|
| 250 |
|
| 251 |
creds = [(os.getenv("CHAT_USERNAME"), os.getenv("CHAT_PASSWORD"))]
|
| 252 |
|
| 253 |
gradio_flagger = gr.HuggingFaceDatasetSaver(HF_TOKEN, "chats")
|
| 254 |
+
title = "AI Debate Partner"
|
| 255 |
|
| 256 |
with gr.Blocks(
|
|
|
|
| 257 |
theme=theme,
|
| 258 |
analytics_enabled=False,
|
| 259 |
title=title,
|
| 260 |
) as demo:
|
| 261 |
+
state = gr.State()
|
| 262 |
+
gr.Markdown(f"### {title}")
|
| 263 |
+
with gr.Tab("Setup"):
|
| 264 |
+
with gr.Column():
|
| 265 |
+
llm_input = gr.Dropdown(
|
| 266 |
+
label="LLM",
|
| 267 |
+
choices=["Claude 2", "GPT-4"],
|
| 268 |
+
value="GPT-4",
|
| 269 |
+
multiselect=False,
|
| 270 |
+
)
|
| 271 |
+
system_prompt_input = gr.Textbox(
|
| 272 |
+
label="System Prompt", value=SYSTEM_MESSAGE
|
| 273 |
+
)
|
| 274 |
+
update_system_button = gr.Button(value="Update Prompt & Reset")
|
| 275 |
+
status_markdown = gr.Markdown()
|
| 276 |
+
with gr.Tab("Chatbot"):
|
| 277 |
+
with gr.Column():
|
| 278 |
+
chatbot = gr.Chatbot(label="ChatBot")
|
| 279 |
+
inputs = gr.Textbox(
|
| 280 |
+
placeholder="Send a message.",
|
| 281 |
+
label="Type an input and press Enter",
|
| 282 |
+
)
|
| 283 |
+
b1 = gr.Button(value="Submit")
|
| 284 |
|
| 285 |
+
gradio_flagger.setup([chatbot], "chats")
|
| 286 |
|
| 287 |
inputs.submit(
|
| 288 |
respond,
|
|
|
|
| 294 |
[inputs, state],
|
| 295 |
[chatbot, state],
|
| 296 |
)
|
| 297 |
+
update_system_button.click(
|
| 298 |
+
update_system_prompt,
|
| 299 |
+
[system_prompt_input, llm_input],
|
| 300 |
+
[status_markdown, state],
|
| 301 |
+
)
|
| 302 |
|
| 303 |
+
update_system_button.click(reset_textbox, [], [inputs])
|
| 304 |
+
update_system_button.click(reset_textbox, [], [chatbot])
|
| 305 |
b1.click(reset_textbox, [], [inputs])
|
| 306 |
inputs.submit(reset_textbox, [], [inputs])
|
| 307 |
|
| 308 |
+
demo.queue(max_size=99, concurrency_count=99, api_open=False).launch(
|
| 309 |
+
debug=True, # auth=auth
|
| 310 |
)
|