# based on https://github.com/hwchase17/langchain-gradio-template/blob/master/app.py
import collections
import os
from itertools import islice
from queue import Queue

from anyio.from_thread import start_blocking_portal
import gradio as gr
from diff_match_patch import diff_match_patch
from langchain.chains import LLMChain
from langchain.chat_models import PromptLayerChatOpenAI, ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.schema import HumanMessage

from util import SyncStreamingLLMCallbackHandler, concatenate_generators

GRAMMAR_PROMPT = "Proofread for grammar and spelling without adding new paragraphs:\n{content}"

INTRO_PROMPT = """These are the parts of a good introductory paragraph:
1. Introductory information
2. The stage of human development of the main character
3. Summary of the story
4. Thesis statement (this should also provide an overview of the essay structure or the topics that may be covered in each paragraph)
For each part, quote the sentences from the following paragraph that fulfil that part and say how confident you are (as a percentage). If you're not confident, explain why.
---
Example output format:
Thesis statement and outline:
"Sentence A. Sentence B"
Score: X%. Feedback goes here.
---
Intro paragraph:
{content}"""

BODY_PROMPT1 = """You are a university English teacher. Complete the following tasks for the following essay paragraph about a book:
1. Topic sentence: Identify the topic sentence and determine whether it introduces an argument
2. Key points: Outline a bullet list of key points
3. Supporting evidence: Give a bullet list of the parts of the paragraph that use quotes or other textual evidence from the book
{content}"""

BODY_PROMPT2 = """4. Give advice on how the topic sentence could be made stronger or clearer
5. In a bullet list, state how each key point supports the topic (or note if any doesn't support it)
6. In a bullet list, state which key point each piece of supporting evidence supports.
"""

BODY_PROMPT3 = """Briefly summarize "{title}". Then, in a bullet list for each piece of supporting evidence you listed above, state whether it describes an event/detail from "{title}" or comes from outside sources.
Use this output format:
[summary]
----
- [supporting evidence 1] - book
- [supporting evidence 2] - outside source"""

BODY_DESCRIPTION = """1. identifies the topic sentence
2. outlines key points
3. checks for supporting evidence (e.g., quotes, summaries, and concrete details)
4. suggests topic sentence improvements
5. checks that the key points match the paragraph topic
6. determines which key point each piece of evidence supports
7. checks whether each piece of evidence is from the book or from an outside source"""


def is_empty(s: str):
    return len(s) == 0 or s.isspace()


def check_content(s: str):
    if is_empty(s):
        raise gr.exceptions.Error("Please input some text before running.")


def load_chain(api_key, api_type):
    if api_key == "" or api_key.isspace():
        if api_type == "OpenAI":
            api_key = os.environ.get("OPENAI_API_KEY", None)
        elif api_type == "Azure OpenAI":
            api_key = os.environ.get("AZURE_OPENAI_API_KEY", None)
        else:
            raise RuntimeError("Unknown API type: " + api_type)
    if api_key:
        shared_args = {
            "temperature": 0,
            "model_name": "gpt-3.5-turbo",
            "api_key": api_key,  # deliberately not using "openai_api_key" and other openai args since those apply globally
            "pl_tags": ["grammar"],
            "streaming": True,
        }
        if api_type == "OpenAI":
            llm = PromptLayerChatOpenAI(**shared_args)
        elif api_type == "Azure OpenAI":
            llm = PromptLayerChatOpenAI(
                api_type="azure",
                api_base=os.environ.get("AZURE_OPENAI_API_BASE", None),
                api_version=os.environ.get("AZURE_OPENAI_API_VERSION", "2023-03-15-preview"),
                engine=os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", None),
                **shared_args
            )
        prompt1 = PromptTemplate(
            input_variables=["content"],
            template=GRAMMAR_PROMPT
        )
        chain = LLMChain(llm=llm,
                         prompt=prompt1,
                         memory=ConversationBufferMemory())
        chain_intro = LLMChain(llm=llm,
                               prompt=PromptTemplate(
                                   input_variables=["content"],
                                   template=INTRO_PROMPT
                               ),
                               memory=ConversationBufferMemory())
        chain_body1 = LLMChain(llm=llm,
                               prompt=PromptTemplate(
                                   input_variables=["content"],
                                   template=BODY_PROMPT1
                               ),
                               memory=ConversationBufferMemory())
        return chain, llm, chain_intro, chain_body1
    # if no API key is available, nothing is returned and the chains stay unloaded
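
# A sketch of the environment configuration the Azure branch above reads; the
# variable names come from this file, but the values here are placeholders:
#   export AZURE_OPENAI_API_KEY=...
#   export AZURE_OPENAI_API_BASE=https://<resource>.openai.azure.com
#   export AZURE_OPENAI_API_VERSION=2023-03-15-preview
#   export AZURE_OPENAI_DEPLOYMENT_NAME=<deployment>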


def run_diff(content, chain: LLMChain):
    check_content(content)
    chain.memory.clear()
    edited = chain.run(content)
    return diff_words(content, edited) + (edited,)


# https://github.com/hwchase17/langchain/issues/2428#issuecomment-1512280045
def run(content, chain: LLMChain):
    check_content(content)
    chain.memory.clear()
    q = Queue()
    job_done = object()

    def task():
        result = chain.run(content, callbacks=[SyncStreamingLLMCallbackHandler(q)])
        q.put(job_done)
        return result

    with start_blocking_portal() as portal:
        portal.start_task_soon(task)
        output = ""
        while True:
            next_token = q.get(True, timeout=10)
            if next_token is job_done:
                break
            output += next_token
            yield output
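
# How the streaming above works: the chain runs on a portal-managed worker
# thread, the callback handler pushes each token onto the queue, and this
# thread drains the queue and yields the growing string for Gradio to render.
# Consuming it outside Gradio would look like this (hypothetical input):
#
#   for partial in run("Some paragraph.", chain):
#       print(partial)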


# TODO: share code with run() above
def run_followup(followup_question, input_vars, chain, chat: ChatOpenAI):
    check_content(followup_question)
    # rebuild the history with each human turn re-wrapped in the chain's prompt template
    history = [HumanMessage(content=chain.prompt.format(content=m.content)) if isinstance(m, HumanMessage) else m
               for m in chain.memory.chat_memory.messages]
    prompt = ChatPromptTemplate.from_messages([
        *history,
        HumanMessagePromptTemplate.from_template(followup_question)])
    messages = prompt.format_prompt(**input_vars).to_messages()
    q = Queue()
    job_done = object()

    def task():
        result = chat.generate([messages], callbacks=[SyncStreamingLLMCallbackHandler(q)])
        q.put(job_done)
        return result.generations[0][0].message.content

    with start_blocking_portal() as portal:
        portal.start_task_soon(task)
        output = ""
        while True:
            next_token = q.get(True, timeout=10)
            if next_token is job_done:
                break
            output += next_token
            yield output


def run_body(content, title, chain, llm):
    check_content(content)  # note: run() also checks, but that error doesn't get shown in the UI
    if not title:
        # in a generator, a plain `return value` is swallowed, so yield the message instead
        yield "Please enter the book title."
        return
    yield from concatenate_generators(
        run(content, chain),
        "\n\n",
        run_followup(BODY_PROMPT2, {}, chain, llm),
        "\n\n7. Whether supporting evidence is from the book:",
        (output.split("----")[-1] for output in run_followup(BODY_PROMPT3, {"title": title}, chain, llm))
    )


def run_custom(content, llm, prompt):
    chain = LLMChain(llm=llm,
                     memory=ConversationBufferMemory(),
                     prompt=PromptTemplate(
                         input_variables=["content"],
                         template=prompt
                     ))
    return chain.run(content), chain


# not currently used
def split_paragraphs(text):
    return [(x, x != "" and not x.startswith("#") and not x.isspace()) for x in text.split("\n")]


def sliding_window(iterable, n):
    # sliding_window('ABCDEFG', 4) --> ABCD BCDE CDEF DEFG
    it = iter(iterable)
    window = collections.deque(islice(it, n), maxlen=n)
    if len(window) == n:
        yield tuple(window)
    for x in it:
        window.append(x)
        yield tuple(window)


dmp = diff_match_patch()


def diff_words(content, edited):
    before = []
    after = []
    changes = []
    change_count = 0
    changed = False
    diff = dmp.diff_main(content, edited)
    dmp.diff_cleanupSemantic(diff)
    diff += [(None, None)]  # sentinel so the last real diff entry still gets a "next" pair
    for [(change, text), (next_change, next_text)] in sliding_window(diff, 2):
        if change == 0:
            before.append((text, None))
            after.append((text, None))
        else:
            if change == -1 and next_change == 1:
                # a deletion immediately followed by an insertion is a replacement
                change_count += 1
                before.append((text, str(change_count)))
                after.append((next_text, str(change_count)))
                changes.append((text, next_text))
                changed = True
            elif change == -1:
                before.append((text, "-"))
            elif change == 1:
                if changed:
                    # this insertion was already consumed by the replacement above
                    changed = False
                else:
                    after.append((text, "+"))
            else:
                raise Exception("Unknown change type: " + str(change))
    return before, after, changes
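
# A sketch of the shapes diff_words returns (illustrative values only; the
# exact split points depend on diff_match_patch's semantic cleanup):
#   before  -> [("I ha", None), ("s", "1"), (" a cat", None)]  # (text, label) pairs for gr.HighlightedText
#   after   -> [("I ha", None), ("ve", "1"), (" a cat", None)]
#   changes -> [("s", "ve")]                                   # indexed by the numeric labels above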


def get_parts(arr, start, end):
    return "".join(arr[start:end])


CHANGES = {
    "-": "remove",
    "+": "add",
    # "→": "change"
}


def select_diff(evt: gr.SelectData, changes):
    text, change = evt.value
    if not change:
        return
    change_text = CHANGES.get(change, None)
    if change_text:
        return f"Why is it better to {change_text} [{text}]?"
    # if change == "→":
    else:
        # clicked = evt.target
        # if clicked.label == "Before":
        #     original = text
        # else:
        #     edited = text
        original, edited = changes[int(change) - 1]
        # original, edited = text.split("→")
        return f"Why is it better to change [{original}] to [{edited}]?"


demo = gr.Blocks(css="""
.diff-component {
    white-space: pre-wrap;
}
.diff-component .textspan.hl {
    white-space: normal;
}
""")

with demo:
    # api_key = gr.Textbox(
    #     placeholder="Paste your OpenAI API key here (sk-...)",
    #     show_label=False,
    #     lines=1,
    #     type="password"
    # )
    api_key = gr.State("")
    gr.HTML("""<div style="display: flex; justify-content: center; align-items: center"><a href="https://thinkcol.com/"><img src="./file=thinkcol-logo.png" alt="ThinkCol" width="357" height="87" /></a></div>""")
    gr.Markdown("""Paste a paragraph below, and then choose one of the modes to generate feedback.""")
    content = gr.Textbox(
        label="Paragraph"
    )
| with gr.Tab("Grammar/Spelling"): | |
| gr.Markdown("Suggests grammar and spelling revisions.") | |
| submit = gr.Button( | |
| value="Revise", | |
| ).style(full_width=False) | |
| with gr.Row(): | |
| output_before = gr.HighlightedText( | |
| label="Before", | |
| combine_adjacent=True, | |
| elem_classes="diff-component" | |
| ).style(color_map={ | |
| "-": "red", | |
| # "→": "yellow", | |
| }) | |
| output_after = gr.HighlightedText( | |
| label="After", | |
| combine_adjacent=True, | |
| elem_classes="diff-component" | |
| ).style(color_map={ | |
| "+": "green", | |
| # "→": "yellow", | |
| }) | |
| followup_question = gr.Textbox( | |
| label="Follow-up Question", | |
| ) | |
| followup_submit = gr.Button( | |
| value="Ask" | |
| ).style(full_width=False) | |
| followup_answer = gr.Textbox( | |
| label="Answer" | |
| ) | |
| with gr.Tab("Intro"): | |
| gr.Markdown("Checks for the key components of an introductory paragraph.") | |
| submit_intro = gr.Button( | |
| value="Run" | |
| ).style(full_width=False) | |
| output_intro = gr.Textbox( | |
| label="Output", | |
| lines=1000, | |
| max_lines=1000 | |
| ) | |
| with gr.Tab("Body Paragraph"): | |
| gr.Markdown(BODY_DESCRIPTION) | |
| title = gr.Textbox( | |
| label="Book Title" | |
| ) | |
| submit_body = gr.Button( | |
| value="Run" | |
| ).style(full_width=False) | |
| output_body = gr.Textbox( | |
| label="Output", | |
| lines=1000, | |
| max_lines=1000 | |
| ) | |
| # with gr.Tab("Custom prompt"): | |
| # gr.Markdown("This mode is for testing and debugging.") | |
| # prompt = gr.Textbox( | |
| # label="Prompt", | |
| # value=GRAMMAR_PROMPT, | |
| # lines=2 | |
| # ) | |
| # submit_custom = gr.Button( | |
| # value="Run" | |
| # ).style(full_width=False) | |
| # output_custom = gr.Textbox( | |
| # label="Output" | |
| # ) | |
| # followup_custom = gr.Textbox( | |
| # label="Follow-up Question" | |
| # ) | |
| # followup_answer_custom = gr.Textbox( | |
| # label="Answer" | |
| # ) | |
| with gr.Tab("Settings"): | |
| api_type = gr.Radio( | |
| ["OpenAI", "Azure OpenAI"], | |
| value="OpenAI", | |
| label="Server", | |
| info="You can try changing this if responses are slow." | |
| ) | |

    changes = gr.State()
    edited = gr.State()
    chain = gr.State()
    llm = gr.State()
    chain_intro = gr.State()
    chain_body1 = gr.State()
    chain_custom = gr.State()

    # api_key.change(load_chain, [api_key, api_type], [chain, llm, chain_intro, chain_body1])
    api_type.change(load_chain, [api_key, api_type], [chain, llm, chain_intro, chain_body1])

    inputs = [content, chain]
    outputs = [output_before, output_after, changes, edited]
    # content.submit(run_diff, inputs=inputs, outputs=outputs)
    submit.click(run_diff, inputs=inputs, outputs=outputs)
    output_before.select(select_diff, changes, followup_question)
    output_after.select(select_diff, changes, followup_question)

    empty_input = gr.State({})
    inputs2 = [followup_question, empty_input, chain, llm]
    outputs2 = followup_answer
    followup_question.submit(run_followup, inputs2, outputs2)
    followup_submit.click(run_followup, inputs2, outputs2)

    submit_intro.click(run, [content, chain_intro], output_intro)
    submit_body.click(run_body, [content, title, chain_body1, llm], output_body)  # body part A only
    # submit_custom.click(run_custom, [content, llm, prompt], [output_custom, chain_custom])  # TODO: standardize API -- return memory instead of using chain?
    # followup_custom.submit(run_followup, [followup_custom, empty_input, chain_custom, llm], followup_answer_custom)

    demo.load(load_chain, [api_key, api_type], [chain, llm, chain_intro, chain_body1])

port = os.environ.get("SERVER_PORT", None)
if port:
    port = int(port)

demo.queue()
demo.launch(debug=True, server_port=port, prevent_thread_lock=True)
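
# A hypothetical way to run this file locally on a fixed port (SERVER_PORT is
# the env var read above; 7860 is just an example value, and util.py must be
# present alongside this file):
#   SERVER_PORT=7860 python app.py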