briankchan committed
Commit 4aaf91a · Parent: 9ecdc3d

Change outputs to use streaming

Files changed (2):
  1. app.py  +62 -15
  2. util.py +25 -0
app.py CHANGED
@@ -2,15 +2,20 @@
 import collections
 import os
 from itertools import islice
+from queue import Queue
 
+from anyio.from_thread import start_blocking_portal
 import gradio as gr
 from diff_match_patch import diff_match_patch
 from langchain.chains import LLMChain
-from langchain.chat_models import PromptLayerChatOpenAI
+from langchain.chat_models import PromptLayerChatOpenAI, ChatOpenAI
 from langchain.memory import ConversationBufferMemory
 from langchain.prompts import PromptTemplate
 from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
 from langchain.schema import HumanMessage
+from langchain.callbacks.manager import AsyncCallbackManager
+
+from util import StreamingLLMCallbackHandler, concatenate_generators
 
 GRAMMAR_PROMPT = "Proofread for grammar and spelling without adding new paragraphs:\n{content}"
 
@@ -71,6 +76,7 @@ def load_chain(api_key, api_type):
         "model_name": "gpt-3.5-turbo",
         "api_key": api_key,  # deliberately not use "openai_api_key" and other openai args since those apply globally
         "pl_tags": ["grammar"],
+        "streaming": True,
     }
     if api_type == "OpenAI":
         llm = PromptLayerChatOpenAI(**shared_args)
@@ -106,16 +112,35 @@ def load_chain(api_key, api_type):
     return chain, llm, chain_intro, chain_body1
 
 
-def run_diff(content, chain):
+def run_diff(content, chain: LLMChain):
     chain.memory.clear()
-    edited = "\n".join([(chain.run(x) if should_check else x) for x, should_check in split_paragraphs(content)])
+    edited = chain.run(content)
     return diff_words(content, edited) + (edited,)
 
-def run(content, chain):
+# https://github.com/hwchase17/langchain/issues/2428#issuecomment-1512280045
+def run(content, chain: LLMChain):
     chain.memory.clear()
-    return chain.run(content)
+
+    q = Queue()
+    job_done = object()
+    async def task():
+        result = await chain.arun(content, callbacks=[StreamingLLMCallbackHandler(q)])
+        q.put(job_done)
+        return result
+
+    with start_blocking_portal() as portal:
+        portal.start_task_soon(task)
+
+        content = ""
+        while True:
+            next_token = q.get(True, timeout=10)
+            if next_token is job_done:
+                break
+            content += next_token
+            yield content
 
-def run_followup(followup_question, input_vars, chain, chat):
+# TODO share code with above
+def run_followup(followup_question, input_vars, chain, chat: ChatOpenAI):
 
     history = [HumanMessage(content=chain.prompt.format(content=m.content)) if isinstance(m, HumanMessage) else m
                for m in chain.memory.chat_memory.messages]
@@ -123,16 +148,37 @@ def run_followup(followup_question, input_vars, chain, chat):
         *history,
         HumanMessagePromptTemplate.from_template(followup_question)])
     messages = prompt.format_prompt(**input_vars).to_messages()
-    return chat(messages).content
+
+    q = Queue()
+    job_done = object()
+    async def task():
+        result = await chat.agenerate([messages], callbacks=[StreamingLLMCallbackHandler(q)])
+        q.put(job_done)
+        return result.generations[0][0].message.content
+
+    with start_blocking_portal() as portal:
+        portal.start_task_soon(task)
+
+        content = ""
+        while True:
+            next_token = q.get(True, timeout=10)
+            if next_token is job_done:
+                break
+            content += next_token
+            yield content
+
 
 def run_body(content, title, chain, llm):
     if not title:
         return "Please enter the book title."
-    output1 = run(content, chain)
-    output2 = run_followup(BODY_PROMPT2, {}, chain, llm)
-    output3 = run_followup(BODY_PROMPT3, {"title": title}, chain, llm)
-    output3 = output3.split("----")[-1]
-    return output1 + "\n\n" + output2 + "\n\n7. Whether supporting evidence is from the book:" + output3
+
+    yield from concatenate_generators(
+        run(content, chain),
+        "\n\n",
+        run_followup(BODY_PROMPT2, {}, chain, llm),
+        "\n\n7. Whether supporting evidence is from the book:",
+        (output.split("----")[-1] for output in run_followup(BODY_PROMPT3, {"title": title}, chain, llm))
+    )
 
 def run_custom(content, llm, prompt):
     chain = LLMChain(llm=llm,
@@ -143,9 +189,9 @@ def run_custom(content, llm, prompt):
         ))
     return chain.run(content), chain
 
+# not currently used
 def split_paragraphs(text):
-    # return [(x, x != "" and not x.startswith("#") and not x.isspace()) for x in text.split("\n")]
-    return [(text, True)]
+    return [(x, x != "" and not x.startswith("#") and not x.isspace()) for x in text.split("\n")]
 
 def sliding_window(iterable, n):
     # sliding_window('ABCDEFG', 4) --> ABCD BCDE CDEF DEFG
@@ -167,7 +213,7 @@ def diff_words(content, edited):
     diff = dmp.diff_main(content, edited)
     dmp.diff_cleanupSemantic(diff)
     diff += [(None, None)]
-    # print(diff)
+
     for [(change, text), (next_change, next_text)] in sliding_window(diff, 2):
         if change == 0:
             before.append((text, None))
@@ -364,4 +410,5 @@ with demo:
 port = os.environ.get("SERVER_PORT", None)
 if port:
     port = int(port)
+demo.queue()
 demo.launch(debug=True, server_port=port)
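
The streaming pattern in run() and run_followup() (taken from the langchain issue linked above) works like this: the async chain call runs on a background event-loop thread via anyio's blocking portal, StreamingLLMCallbackHandler pushes each generated token onto a thread-safe Queue, and the synchronous generator drains the queue and yields the accumulated text so Gradio can redraw the output on every yield. Below is a minimal self-contained sketch of the same pattern; stream_tokens and its stub producer are illustrative stand-ins for chain.arun() plus the callback handler, not part of the commit.

import anyio
from queue import Queue
from anyio.from_thread import start_blocking_portal

def stream_tokens(tokens):
    """Yield cumulative text as tokens arrive from a background async task."""
    q = Queue()
    job_done = object()  # unique sentinel marking the end of the stream

    async def task():
        # Stand-in for chain.arun(); in app.py the queue is fed by
        # StreamingLLMCallbackHandler.on_llm_new_token() instead.
        for t in tokens:
            await anyio.sleep(0.01)
            q.put(t)
        q.put(job_done)  # tell the consumer the stream is finished

    # The portal runs the coroutine on a separate event-loop thread while
    # this plain synchronous generator keeps running in the caller's thread.
    with start_blocking_portal() as portal:
        portal.start_task_soon(task)

        text = ""
        while True:
            next_token = q.get(True, timeout=10)  # block, but bail out if the producer stalls
            if next_token is job_done:
                break
            text += next_token
            yield text  # cumulative output, as Gradio expects when streaming

for partial in stream_tokens(["Streaming", " looks", " like", " this."]):
    print(partial)

Gradio only streams generator outputs when queuing is enabled, which is why the commit also adds demo.queue() before demo.launch().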
util.py ADDED
@@ -0,0 +1,25 @@
+from typing import Any
+from types import GeneratorType
+from langchain.callbacks.base import AsyncCallbackHandler
+
+class StreamingLLMCallbackHandler(AsyncCallbackHandler):
+    """Callback handler for streaming LLM responses to a queue."""
+
+    def __init__(self, q):
+        self.q = q
+
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        self.q.put(token)
+
+
+def concatenate_generators(*args):
+    final_outputs = ""
+    for g in args:
+        if isinstance(g, GeneratorType):
+            for v in g:
+                yield final_outputs + v
+            result = v
+        else:
+            yield final_outputs + g
+            result = g
+        final_outputs += result
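
concatenate_generators() stitches a mix of cumulative-text generators and plain strings into a single cumulative stream: while one piece is still generating, its partial text is yielded appended to everything already finished, and once it completes, its last value is folded into the accumulated prefix (final_outputs). A small usage sketch follows; the cumulative() helper is a toy stand-in for run()/run_followup(), not part of the commit.

from util import concatenate_generators

def cumulative(*parts):
    # Yields growing prefixes, the way the streaming generators in app.py do.
    acc = ""
    for p in parts:
        acc += p
        yield acc

for partial in concatenate_generators(cumulative("a", "b"), "-", cumulative("c", "d")):
    print(partial)
# a
# ab
# ab-
# ab-c
# ab-cd

Note that the helper assumes each generator yields cumulative text and yields at least once: only its final value is kept in final_outputs, so a generator that emitted per-token deltas instead of growing prefixes would render incorrectly mid-stream.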