Spaces:
Sleeping
Sleeping
| import warnings | |
| warnings.filterwarnings("ignore", message=".*TqdmWarning.*") | |
| from dotenv import load_dotenv | |
| _ = load_dotenv() | |
| from langgraph.graph import StateGraph, END | |
| from typing import TypedDict, Annotated, List | |
| import operator | |
| from langgraph.checkpoint.sqlite import SqliteSaver | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage, ChatMessage | |
| from langchain_openai import ChatOpenAI | |
| # from langchain_core.pydantic_v1 import BaseModel | |
| from pydantic import BaseModel | |
| from tavily import TavilyClient | |
| import os | |
| import sqlite3 | |
class ChatOpenRouter(ChatOpenAI):
    """``ChatOpenAI`` client configured for OpenRouter's OpenAI-compatible API.

    The API key defaults to the ``OPENROUTER_API_KEY`` environment variable
    whenever no (or an empty) key is passed explicitly.
    """

    openai_api_base: str
    openai_api_key: str
    model_name: str

    def __init__(self,
                 model_name: str,
                 openai_api_key: str = None,
                 openai_api_base: str = "https://openrouter.ai/api/v1",
                 **kwargs):
        if not openai_api_key:
            # fall back to the environment (may still be None if unset)
            openai_api_key = os.getenv('OPENROUTER_API_KEY')
        super().__init__(model_name=model_name,
                         openai_api_key=openai_api_key,
                         openai_api_base=openai_api_base,
                         **kwargs)
class AgentState(TypedDict):
    """Shared state flowing through the essay-writer LangGraph."""
    task: str  # the essay topic / user request
    lnode: str  # name of the last node that ran (each node writes its own name)
    plan: str  # outline produced by the planner node
    draft: str  # latest essay draft produced by the generate node
    critique: str  # referee feedback produced by the reflect node
    content: List[str]  # accumulated research snippets from Tavily searches
    queries: List[str]  # search queries produced by the research_plan node
    revision_number: int  # incremented by every generate step
    max_revisions: int  # generation stops once revision_number exceeds this
    # Reducer: updates are combined with operator.add, so each node's
    # "count": 1 accumulates into a running step counter.
    count: Annotated[int, operator.add]
# Structured-output schema for the research-assistant model: it is passed to
# with_structured_output so the model returns a list of search queries.
# NOTE: deliberately no docstring - LangChain may forward a pydantic model's
# docstring to the LLM as the schema description, which would alter the prompt.
class Queries(BaseModel):
    queries: List[str]  # search strings later sent to TavilyClient.search
class ewriter:
    """Essay-writing agent built as a LangGraph ``StateGraph``.

    Flow: planner -> research_plan -> generate, then, while the revision
    limit is not exceeded, reflect -> research_critique -> generate.
    ``model1`` plans and writes, ``model2`` turns the task or critique into
    search queries (run through Tavily), and ``model3`` critiques drafts.
    The compiled graph (``self.graph``) uses an in-memory checkpointer and
    interrupts after every node.
    """

    def __init__(self, model_names=("GPT-3.5", "Claude 3 Sonnet", "Claude 3.5 Sonnet")):
        """Create the three models, the prompts, the Tavily client and the graph.

        Args:
            model_names: display names (see ``create_model``) for the writer,
                research assistant and referee, in that order. Any sequence
                of at least three names works; a tuple default avoids the
                shared-mutable-default pitfall.

        Raises:
            KeyError: if ``TAVILY_API_KEY`` is not set in the environment.
        """
        models = [self.create_model(model_name) for model_name in model_names]
        self.model1, self.model2, self.model3 = models[0], models[1], models[2]
        self.PLAN_PROMPT = ("You are an expert writer tasked with writing a high level outline of a short 300 words essay. "
                            "Write such an outline for the user provided topic. Give the five main headers of an outline of "
                            "the essay along with any relevant notes or instructions for the sections. ")
        self.WRITER_PROMPT = ("You are an essay assistant tasked with writing excellent 300 words essays. "
                              "Generate the best essay possible for the user's request and the initial outline. "
                              "If the user provides critique, respond with a revised version of your previous attempts. "
                              "Utilize all the information below as needed: \n"
                              "------\n"
                              "{content}")
        self.RESEARCH_PLAN_PROMPT = ("You are a researcher charged with providing information that can "
                                     "be used when writing the following essay. Generate a list of search "
                                     "queries that will gather "
                                     "any relevant information. Only generate 3 queries max.")
        self.REFLECTION_PROMPT = ("You are a teacher grading an 300 words essay submission. "
                                  "Generate critique and recommendations for the user's submission. "
                                  "Provide detailed recommendations, including requests for length, depth, style, etc.")
        self.RESEARCH_CRITIQUE_PROMPT = ("You are a researcher charged with providing information that can "
                                         "be used when making any requested revisions (as outlined below). "
                                         "Generate a list of search queries that will gather any relevant information. "
                                         "Only generate 3 queries max.")
        self.tavily = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])
        self.graph = self._build_graph()

    def _build_graph(self):
        """Wire the five nodes into a StateGraph and compile it with checkpoints and interrupts."""
        builder = StateGraph(AgentState)
        builder.add_node("planner", self.plan_node)
        builder.add_node("research_plan", self.research_plan_node)
        builder.add_node("generate", self.generation_node)
        builder.add_node("reflect", self.reflection_node)
        builder.add_node("research_critique", self.research_critique_node)
        builder.set_entry_point("planner")
        builder.add_conditional_edges(
            "generate",
            self.should_continue,
            {END: END, "reflect": "reflect"}
        )
        builder.add_edge("planner", "research_plan")
        builder.add_edge("research_plan", "generate")
        builder.add_edge("reflect", "research_critique")
        builder.add_edge("research_critique", "generate")
        memory = MemorySaver()
        # Interrupt after every node so the UI can inspect/modify state between steps.
        return builder.compile(
            checkpointer=memory,
            interrupt_after=['planner', 'generate', 'reflect', 'research_plan', 'research_critique']
        )

    def create_model(self, model_name):
        """Return a chat model for a UI display name.

        "GPT-4" and unknown names use ``ChatOpenAI`` directly (unknown names
        fall back to gpt-3.5-turbo); the Claude and Llama names are served
        through ``ChatOpenRouter``. All models use temperature 0.
        """
        openrouter_ids = {
            "llama-3 405B": 'meta-llama/llama-3.1-405b-instruct',
            "Claude 3 Sonnet": 'anthropic/claude-3-sonnet',
            "Claude 3.5 Sonnet": 'anthropic/claude-3.5-sonnet',
            "llama-3 70B": 'meta-llama/llama-3.1-70b-instruct',
        }
        if model_name == "GPT-4":
            return ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0)
        if model_name in openrouter_ids:
            return ChatOpenRouter(model_name=openrouter_ids[model_name], temperature=0)
        # Default model
        return ChatOpenAI(model_name="gpt-3.5-turbo-0125", temperature=0)

    def _search(self, queries, content):
        """Run each query through Tavily (top 2 hits) and return a NEW list:
        ``content`` (None treated as empty) extended with the hit texts.

        Copying instead of appending in place avoids mutating the list object
        held in the graph state passed to the node.
        """
        results = list(content or [])
        for q in queries:
            response = self.tavily.search(query=q, max_results=2)
            results.extend(r['content'] for r in response['results'])
        return results

    def plan_node(self, state: AgentState):
        """Ask model1 for an essay outline of ``state['task']``."""
        messages = [
            SystemMessage(content=self.PLAN_PROMPT),
            HumanMessage(content=state['task'])
        ]
        response = self.model1.invoke(messages)
        return {"plan": response.content,
                "lnode": "planner",
                "count": 1,
                }

    def research_plan_node(self, state: AgentState):
        """Get search queries for the task from model2 and add Tavily results to ``content``."""
        queries = self.model2.with_structured_output(Queries).invoke([
            SystemMessage(content=self.RESEARCH_PLAN_PROMPT),
            HumanMessage(content=state['task'])
        ])
        return {"content": self._search(queries.queries, state.get('content')),
                "queries": queries.queries,
                "lnode": "research_plan",
                "count": 1,
                }

    def generation_node(self, state: AgentState):
        """Ask model1 for a draft from the task, plan and research content; bump the revision number."""
        content = "\n\n".join(state.get('content') or [])
        # NOTE(review): WRITER_PROMPT mentions revising based on critique, but
        # neither state['critique'] nor the previous draft is sent here.
        user_message = HumanMessage(
            content=f"{state['task']}\n\nHere is my plan:\n\n{state['plan']}")
        messages = [
            SystemMessage(
                content=self.WRITER_PROMPT.format(content=content)
            ),
            user_message
        ]
        response = self.model1.invoke(messages)
        return {
            "draft": response.content,
            "revision_number": state.get("revision_number", 1) + 1,
            "lnode": "generate",
            "count": 1,
        }

    def reflection_node(self, state: AgentState):
        """Ask model3 (the referee) to critique the current draft."""
        messages = [
            SystemMessage(content=self.REFLECTION_PROMPT),
            HumanMessage(content=state['draft'])
        ]
        response = self.model3.invoke(messages)
        return {"critique": response.content,
                "lnode": "reflect",
                "count": 1,
                }

    def research_critique_node(self, state: AgentState):
        """Get search queries for the critique from model2 and add Tavily results to ``content``."""
        queries = self.model2.with_structured_output(Queries).invoke([
            SystemMessage(content=self.RESEARCH_CRITIQUE_PROMPT),
            HumanMessage(content=state['critique'])
        ])
        return {"content": self._search(queries.queries, state.get('content')),
                "lnode": "research_critique",
                "count": 1,
                }

    def should_continue(self, state):
        """Route after ``generate``: END once revision_number exceeds max_revisions, else "reflect"."""
        if state["revision_number"] > state["max_revisions"]:
            return END
        return "reflect"
| from langchain_openai import ChatOpenAI | |
| import gradio as gr | |
| import time | |
class writer_gui:
    """Gradio front-end for the ``ewriter`` essay-writing agent graph.

    Lets the user choose the writer / research-assistant / referee models,
    run the graph (it stops at the interrupts it was compiled with), inspect
    and edit the plan, draft and critique, and switch between threads and
    earlier checkpoints.

    Args:
        model_names: display names for the writer, research assistant and
            referee models, in that order.
        share: forwarded to ``demo.launch(share=...)`` by ``launch`` when the
            ``PORT1`` environment variable is not set.
    """

    def __init__(self, model_names=("GPT-3.5", "Claude 3 Sonnet", "Claude 3.5 Sonnet"), share=False):
        self.model_names = list(model_names)
        self.graph = ewriter(self.model_names).graph
        self.share = share
        self._reset_session_state()
        self.demo = self.create_interface()

    def _reset_session_state(self):
        """Forget all threads, iteration counters and accumulated live output."""
        self.partial_message = ""
        self.response = {}
        self.max_iterations = 10
        self.iterations = []  # iterations[i]: graph invocations made on thread i
        self.threads = []  # thread ids created so far: 0, 1, 2, ...
        self.thread_id = -1  # -1 means "no thread started yet"
        self.thread = {"configurable": {"thread_id": str(self.thread_id)}}

    def update_model(self, model_names):
        """Rebuild the agent graph with new models and start a fresh session.

        The new graph has a new in-memory checkpointer, so all previous
        threads are discarded and thread bookkeeping is reset. The existing
        Gradio interface is kept: its handlers read ``self.graph`` at call
        time, and the new graph has the same node names.
        """
        self.model_names = list(model_names)
        self.graph = ewriter(self.model_names).graph
        self._reset_session_state()

    def run_agent(self, start, topic, stop_after):
        """Advance the agent graph, yielding display values after each step.

        Args:
            start: truthy to begin a new thread for ``topic``; falsy to
                resume the current thread from its last interrupt.
            topic: the essay topic (only used when ``start`` is truthy).
            stop_after: node names after which control returns to the UI.

        Yields:
            ``(partial_message, last_node, next_nodes, thread_id, revision, count)``
            after every graph invocation.
        """
        if not start and not 0 <= self.thread_id < len(self.iterations):
            # "Continue" pressed before any essay was generated: nothing to resume.
            yield ("No essay in progress yet: click 'Generate Essay' first.\n",
                   "", "", self.thread_id, "", "")
            return
        if start:
            # Thread ids are consecutive from 0, so the next free id is the number
            # of threads so far. (Incrementing self.thread_id would collide with an
            # existing thread after switch_thread selected an older one.)
            self.thread_id = len(self.threads)
            self.threads.append(self.thread_id)
            self.iterations.append(0)
            config = {'task': topic, "max_revisions": 2, "revision_number": 0,
                      'lnode': "", 'plan': "no plan", 'draft': "no draft", 'critique': "no critique",
                      'content': ["no content", ], 'queries': "no queries", 'count': 0}
        else:
            config = None  # a None input resumes the thread from its last checkpoint
        self.thread = {"configurable": {"thread_id": str(self.thread_id)}}
        while self.iterations[self.thread_id] < self.max_iterations:
            self.response = self.graph.invoke(config, self.thread)
            self.iterations[self.thread_id] += 1
            self.partial_message += str(self.response)
            self.partial_message += "\n------------------\n\n"
            lnode, nnode, _, rev, acount = self.get_disp_state()
            yield self.partial_message, lnode, nnode, self.thread_id, rev, acount
            config = None  # subsequent invocations resume the same thread
            if not nnode or lnode in stop_after:
                # graph finished, or the user asked to pause after this node
                return

    def get_disp_state(self):
        """Return ``(last_node, next_nodes, thread_id, revision_number, count)`` for the current thread."""
        current_state = self.graph.get_state(self.thread)
        lnode = current_state.values["lnode"]
        acount = current_state.values["count"]
        rev = current_state.values["revision_number"]
        nnode = current_state.next
        return lnode, nnode, self.thread_id, rev, acount

    def _status_label(self):
        """Label text describing the current node/thread/revision/step."""
        lnode, _, _, rev, astep = self.get_disp_state()
        return f"last_node: {lnode}, thread_id: {self.thread_id}, rev: {rev}, step: {astep}"

    def get_state(self, key):
        """Return a textbox update showing state value ``key``, or "" if the key is absent."""
        current_values = self.graph.get_state(self.thread)
        if key not in current_values.values:
            return ""
        return gr.update(label=self._status_label(), value=current_values.values[key])

    def get_content(self):
        """Return a textbox update showing the research snippets, or "" if there are none."""
        current_values = self.graph.get_state(self.thread)
        if "content" not in current_values.values:
            return ""
        content = current_values.values["content"]
        return gr.update(label=self._status_label(), value="\n\n".join(content) + "\n\n")

    def _history_choices(self):
        """Describe each checkpoint (step >= 1) of the current thread.

        Entries are "thread:count:last_node:next_node:rev:checkpoint_id"
        strings in the order ``get_state_history`` yields them (latest first).
        """
        hist = []
        for state in self.graph.get_state_history(self.thread):
            if state.metadata['step'] < 1:  # ignore early states
                continue
            cfg = state.config['configurable']
            hist.append(f"{cfg['thread_id']}:{state.values['count']}:{state.values['lnode']}:"
                        f"{state.next}:{state.values['revision_number']}:{cfg['checkpoint_id']}")
        return hist

    def update_hist_pd(self):
        """Return a dropdown listing the current thread's checkpoints."""
        hist = self._history_choices()
        return gr.Dropdown(label="update_state from: thread:count:last_node:next_node:rev:thread_ts",
                           choices=hist, value=hist[0] if hist else None, interactive=True)

    def find_config(self, thread_ts):
        """Return the config of the current thread's checkpoint with id ``thread_ts``, or None."""
        for state in self.graph.get_state_history(self.thread):
            if state.config['configurable']['checkpoint_id'] == thread_ts:
                return state.config
        return None

    def copy_state(self, hist_str):
        ''' result of selecting an old state from the step pulldown. Note does not change thread.
        This copies an old state to a new current state.

        Returns (last_node, next_nodes, new_checkpoint_id, rev, count), or None
        when ``hist_str`` does not name a checkpoint (e.g. the "N/A" placeholder).
        '''
        config = self.find_config(hist_str.split(":")[-1]) if hist_str else None
        if config is None:
            return None
        state = self.graph.get_state(config)
        self.graph.update_state(self.thread, state.values, as_node=state.values['lnode'])
        new_state = self.graph.get_state(self.thread)  # should now match
        new_thread_ts = new_state.config['configurable']['checkpoint_id']
        count = new_state.values['count']
        lnode = new_state.values['lnode']
        rev = new_state.values['revision_number']
        nnode = new_state.next
        return lnode, nnode, new_thread_ts, rev, count

    def update_thread_pd(self):
        """Return a dropdown listing all threads, with the current one selected."""
        return gr.Dropdown(label="choose thread", choices=self.threads, value=self.thread_id, interactive=True)

    def switch_thread(self, new_thread_id):
        """Make ``new_thread_id`` the current thread."""
        self.thread = {"configurable": {"thread_id": str(new_thread_id)}}
        self.thread_id = new_thread_id
        self.is_new_thread = True
        return

    def modify_state(self, key, asnode, new_state):
        ''' gets the current state, modifies a single value in the state identified by key, and updates state with it.
        note that this will create a new 'current state' node. If you do this multiple times with different keys, it will create
        one for each update. Note also that it doesn't resume after the update
        '''
        current_values = self.graph.get_state(self.thread)
        current_values.values[key] = new_state
        self.graph.update_state(self.thread, current_values.values, as_node=asnode)
        return

    def create_interface(self):
        """Build and return the Gradio Blocks app (layout plus event wiring)."""
        model_choices = ["GPT-3.5", "GPT-4", "Claude 3 Sonnet", "Claude 3.5 Sonnet", "llama-3 70B", "llama-3 405B"]
        default_topic = "All work and no play makes Jack a Jackass!"
        with gr.Blocks(theme=gr.themes.Default(spacing_size='sm', text_size="sm")) as demo:

            def updt_disp():
                ''' general update display on state change '''
                current_state = self.graph.get_state(self.thread)
                if not current_state.metadata:  # handle init call
                    return {}
                hist = self._history_choices()
                return {
                    topic_bx: current_state.values["task"],
                    lnode_bx: current_state.values["lnode"],
                    count_bx: current_state.values["count"],
                    revision_bx: current_state.values["revision_number"],
                    nnode_bx: current_state.next,
                    threadid_bx: self.thread_id,
                    thread_pd: gr.Dropdown(label="choose thread", choices=self.threads, value=self.thread_id, interactive=True),
                    step_pd: gr.Dropdown(label="update_state from: thread:count:last_node:next_node:rev:thread_ts",
                                         choices=hist, value=hist[0] if hist else None, interactive=True),
                }

            def make_button_inactive():
                # lock the three model dropdowns and the submit button
                return (gr.update(interactive=False), gr.update(interactive=False),
                        gr.update(interactive=False), gr.update(interactive=False))

            def make_buttons_active():
                return (gr.update(interactive=True), gr.update(interactive=True))

            def update_model_choice(model_name1, model_name2, model_name3, history):
                # Actually apply the selection (previously shadowed and never called),
                # then append it to this session's history shown in the live box.
                self.update_model([model_name1, model_name2, model_name3])
                history = (history or "") + (f"Selected models: Writer: {model_name1}, "
                                             f"Research Assistant: {model_name2}, Referee: {model_name3}\n")
                return history, history

            def reset_all():
                return (
                    gr.update(value="", interactive=True),  # Reset live output
                    gr.update(value="GPT-3.5", interactive=True),  # Reset model dropdown1
                    gr.update(value="Claude 3 Sonnet", interactive=True),  # Reset model dropdown2
                    gr.update(value="Claude 3.5 Sonnet", interactive=True),  # Reset model dropdown3
                    gr.update(value=default_topic, interactive=True),  # Reset topic box
                    gr.update(interactive=False),  # Set generate button to inactive
                    gr.update(interactive=False),  # Set continue button to inactive
                    gr.update(interactive=True),  # Activate model selection button
                    "",  # Clear the session's selection history
                )

            def get_snapshots():
                new_label = f"thread_id: {self.thread_id}, Summary of snapshots"
                parts = []
                for state in self.graph.get_state_history(self.thread):
                    for key in ['plan', 'draft', 'critique']:
                        if key in state.values:
                            state.values[key] = state.values[key][:80] + "..."
                    if 'content' in state.values:
                        # build a new list rather than editing the snapshot's list in place
                        state.values['content'] = [item[:20] + '...' for item in state.values['content']]
                    if 'writes' in state.metadata:
                        state.metadata['writes'] = "not shown"
                    parts.append(str(state) + "\n\n")
                return gr.update(label=new_label, value="".join(parts))

            def vary_btn(stat):
                return gr.update(variant=stat)

            with gr.Row():
                with gr.Column(scale=1, min_width=120):
                    gr.Image("resized_output_graph.png", elem_id="side_image", show_label=False)
                with gr.Column(scale=4):
                    with gr.Row():
                        gr.Markdown("<h1 style='text-align: center; font-size: 2.0em;'>Writing with Research Assistance and Peer Review</h1>")
                    gr.Markdown(
"""
In this app, you’ll be an agent: a writer,
supported by another agent: a research assistant, and yet another, a peer reviewer: referee.
A writer plans the essay, the research assistant gathers relevant material, and then writer drafts the essay,
The referee then reviews and provides feedback based on which the writer directs the RA
to get further material and the process repeats, allowing for multiple iterations based on the referee’s input.
While the graphical representation is static, you can explore the progress and actions
by navigating through different tabs in the app and refreshing and/or modifying the output.
If you prefer to get the final version in one go, click on the Agent dropdown and uncheck all the
boxes that are checked as of now for human intervention at any or all stages. For example,
if you would like to offer some input on any of the actions, go ahead and input the text and then
hit Modify at the top. Don't worry about various other
technical options that I have left there for now for my own exploration. Feel free to
play with it, and rest assured it will not use a curse word in return unless you yourself do so.
Dive into your topics, explore the features, and enjoy the writing journey!
""")

            with gr.Tab("Agent"):
                with gr.Row():
                    model_dropdown1 = gr.Dropdown(label="Select Author", choices=model_choices, value=self.model_names[0], interactive=True)
                    model_dropdown2 = gr.Dropdown(label="Select RA", choices=model_choices, value=self.model_names[1], interactive=True)
                    model_dropdown3 = gr.Dropdown(label="Select Referee", choices=model_choices, value=self.model_names[2], interactive=True)
                    models_selected_btn = gr.Button("Submit Selections", scale=0, min_width=120, variant='primary')
                with gr.Row():
                    topic_bx = gr.Textbox(label="Essay Topic", value=default_topic)
                    gen_btn = gr.Button("Generate Essay", scale=0, min_width=80, variant='primary', interactive=False)
                    cont_btn = gr.Button("Continue Essay", scale=0, min_width=80, interactive=False)
                with gr.Row():
                    lnode_bx = gr.Textbox(label="Last Action", min_width=100)
                    nnode_bx = gr.Textbox(label="Next Action", min_width=100)
                    threadid_bx = gr.Textbox(label="Thread", scale=0, min_width=80)
                    revision_bx = gr.Textbox(label="Draft Rev", scale=0, min_width=80)
                    count_bx = gr.Textbox(label="count", scale=0, min_width=80)
                with gr.Accordion("Manage Agent", open=False):
                    # every graph node except the synthetic start node
                    checks = [node for node in self.graph.nodes.keys() if node != '__start__']
                    stop_after = gr.CheckboxGroup(checks, label="Interrupt After State", value=checks, scale=0, min_width=400)
                    with gr.Row():
                        thread_pd = gr.Dropdown(choices=self.threads, interactive=True, label="select thread", min_width=120, scale=0)
                        step_pd = gr.Dropdown(choices=['N/A'], interactive=True, label="select step", min_width=160, scale=1)
                live = gr.Textbox(label="Live Agent Output", placeholder="First assign one model each for the writer, for the RA, and for the referee.",
                                  lines=22, max_lines=22)

                # One State object, used as both input and output, so the
                # per-session history actually persists between clicks.
                history_state = gr.State("")
                models_selected_btn.click(fn=update_model_choice,
                                          inputs=[model_dropdown1, model_dropdown2, model_dropdown3, history_state],
                                          outputs=[live, history_state]).then(
                    fn=make_button_inactive, inputs=[], outputs=[model_dropdown1, model_dropdown2, model_dropdown3, models_selected_btn]
                ).then(fn=make_buttons_active, inputs=[], outputs=[gen_btn, cont_btn])

                reset_btn = gr.Button("Reset Session")
                reset_btn.click(
                    fn=reset_all,
                    inputs=[],
                    outputs=[live, model_dropdown1, model_dropdown2, model_dropdown3, topic_bx,
                             gen_btn, cont_btn, models_selected_btn, history_state]
                )

                sdisps = [topic_bx, lnode_bx, nnode_bx, threadid_bx, revision_bx, count_bx, step_pd, thread_pd]
                # run_agent yields 6 values; map each to its own component.
                run_outputs = [live, lnode_bx, nnode_bx, threadid_bx, revision_bx, count_bx]
                thread_pd.input(self.switch_thread, [thread_pd], None).then(
                    fn=updt_disp, inputs=None, outputs=sdisps)
                step_pd.input(self.copy_state, [step_pd], None).then(
                    fn=updt_disp, inputs=None, outputs=sdisps)
                gen_btn.click(vary_btn, gr.Number("secondary", visible=False), gen_btn).then(
                    fn=self.run_agent, inputs=[gr.Number(True, visible=False), topic_bx, stop_after],
                    outputs=run_outputs, show_progress=True).then(
                    fn=updt_disp, inputs=None, outputs=sdisps).then(
                    vary_btn, gr.Number("primary", visible=False), gen_btn).then(
                    vary_btn, gr.Number("primary", visible=False), cont_btn)
                cont_btn.click(vary_btn, gr.Number("secondary", visible=False), cont_btn).then(
                    fn=self.run_agent, inputs=[gr.Number(False, visible=False), topic_bx, stop_after],
                    outputs=run_outputs).then(
                    fn=updt_disp, inputs=None, outputs=sdisps).then(
                    vary_btn, gr.Number("primary", visible=False), cont_btn)

            with gr.Tab("Plan"):
                with gr.Row():
                    refresh_btn = gr.Button("Refresh")
                    modify_btn = gr.Button("Modify")
                plan = gr.Textbox(label="Plan", lines=10, interactive=True)
                refresh_btn.click(fn=self.get_state, inputs=gr.Number("plan", visible=False), outputs=plan)
                modify_btn.click(fn=self.modify_state, inputs=[gr.Number("plan", visible=False),
                                                               gr.Number("planner", visible=False), plan], outputs=None).then(
                    fn=updt_disp, inputs=None, outputs=sdisps)
            with gr.Tab("Research Content"):
                refresh_btn = gr.Button("Refresh")
                content_bx = gr.Textbox(label="content", lines=10)
                refresh_btn.click(fn=self.get_content, inputs=None, outputs=content_bx)
            with gr.Tab("Draft"):
                with gr.Row():
                    refresh_btn = gr.Button("Refresh")
                    modify_btn = gr.Button("Modify")
                draft_bx = gr.Textbox(label="draft", lines=10, interactive=True)
                refresh_btn.click(fn=self.get_state, inputs=gr.Number("draft", visible=False), outputs=draft_bx)
                modify_btn.click(fn=self.modify_state, inputs=[gr.Number("draft", visible=False),
                                                               gr.Number("generate", visible=False), draft_bx], outputs=None).then(
                    fn=updt_disp, inputs=None, outputs=sdisps)
            with gr.Tab("Critique"):
                with gr.Row():
                    refresh_btn = gr.Button("Refresh")
                    modify_btn = gr.Button("Modify")
                critique_bx = gr.Textbox(label="Critique", lines=10, interactive=True)
                refresh_btn.click(fn=self.get_state, inputs=gr.Number("critique", visible=False), outputs=critique_bx)
                modify_btn.click(fn=self.modify_state, inputs=[gr.Number("critique", visible=False),
                                                               gr.Number("reflect", visible=False),
                                                               critique_bx], outputs=None).then(
                    fn=updt_disp, inputs=None, outputs=sdisps)
            with gr.Tab("StateSnapShots"):
                with gr.Row():
                    refresh_btn = gr.Button("Refresh")
                snapshots = gr.Textbox(label="State Snapshots Summaries")
                refresh_btn.click(fn=get_snapshots, inputs=None, outputs=snapshots)
        return demo

    def launch(self, share=None):
        """Launch the app.

        If ``PORT1`` is set, serve on 0.0.0.0:PORT1 (sharing defaults to True);
        otherwise sharing defaults to ``self.share``. An explicit ``share``
        argument overrides either default.
        """
        port = os.getenv("PORT1")
        if port:
            self.demo.launch(share=True if share is None else share,
                             server_port=int(port), server_name="0.0.0.0")
        else:
            self.demo.launch(share=self.share if share is None else share)