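# Multi-agent problem solver: a Gemini agent describes the problem shown in the
# uploaded images, a second Gemini agent drafts a Python solution, and a Groq
# Llama agent reviews it. LangGraph routes messages between the agents and
# Gradio provides the UI.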
import base64
import functools
import io
import operator
from typing import Annotated, List, Literal, Sequence, TypedDict

import PIL.Image
import gradio as gr
from dotenv import load_dotenv
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import tool
from langchain_experimental.utilities import PythonREPL
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langgraph.graph import END, START, StateGraph
from langgraph.prebuilt import ToolNode
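# The Groq and Google clients read GROQ_API_KEY and GOOGLE_API_KEY from the
# environment; load_dotenv() below lets a local .env file supply them.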
load_dotenv()
llm_coder = ChatGroq(temperature=0, model_name="llama-3.1-8b-instant")
llm_image = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash",
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
)
search_tool = DuckDuckGoSearchRun()
repl_tool = PythonREPL()
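# Warning: PythonREPL executes model-generated code in the host process with no
# sandboxing, so run this app only in an isolated environment.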
@tool
def python_repl(
    code: Annotated[str, "The python code to execute to answer the question."],
):
    """Use this to execute python code. If you want to see the output of a value,
    you should print it out with `print(...)`. This is visible to the user."""
    try:
        result = repl_tool.run(code)
    except BaseException as e:
        return f"Failed to execute. Error: {repr(e)}"
    result_str = f"Successfully executed:\n```python\n{code}\n```\nStdout: {result}"
    return (
        result_str + "\n\nIf you have completed all tasks, respond with FINAL ANSWER."
    )
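# Factory for a role agent: a shared collaboration preamble plus a per-role
# system message, piped into the model with its tools bound.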
def create_agent(llm, tools, system_message: str):
    """Create an agent."""
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a helpful AI assistant, collaborating with other assistants."
                " Use the provided tools to progress towards answering the question."
                " If you are unable to fully answer, that's OK; another assistant with"
                " different tools will help where you left off. Execute what you can"
                " to make progress."
                " If you or any of the other assistants have the final answer or deliverable,"
                " prefix your response with FINAL ANSWER so the team knows to stop."
                " You have access to the following tools: {tool_names}.\n{system_message}",
            ),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )
    prompt = prompt.partial(system_message=system_message)
    prompt = prompt.partial(tool_names=", ".join([t.name for t in tools]))
    # Some providers reject an empty tool list, so only bind tools when given any.
    return prompt | (llm.bind_tools(tools) if tools else llm)
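# Shared graph state: `operator.add` tells LangGraph to append each node's
# messages to the running transcript, and `sender` records which agent acted
# last so tool output can be routed back to it.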
class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
    sender: str
def agent_node(state, agent, name):
    result = agent.invoke(state)
    # Tool messages pass through untouched; model replies are re-wrapped so the
    # AIMessage carries this agent's name.
    if not isinstance(result, ToolMessage):
        result = AIMessage(**result.dict(exclude={"type", "name"}), name=name)
    return {
        "messages": [result],
        "sender": name,
    }
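# The three role agents. Note that none of them binds any tools, so the
# call_tool branch of the graph below is effectively dormant; in practice the
# loop ends via the FINAL ANSWER convention or the recursion limit.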
problem_agent = create_agent(
    llm_image,
    [],
    system_message=(
        "Understand the problem thoroughly and provide a clear description of it,"
        " including the edge cases. Do not provide the solution."
    ),
)
# The node's `name` must match the graph node name so the sender-based routing
# after call_tool can find its way back to it.
problem_node = functools.partial(agent_node, agent=problem_agent, name="problem_creator")
solution_agent = create_agent(
    llm_image,
    [],
    system_message=(
        "After understanding the problem, provide a clear and concise Python"
        " solution that handles all edge cases, along with the intuition behind it."
    ),
)
solution_node = functools.partial(agent_node, agent=solution_agent, name="solution_generator")
checker_agent = create_agent(
    llm_coder,
    [],
    system_message=(
        "Critically analyze the solution provided by the solution agent, checking"
        " for correctness, efficiency, and edge cases. If the solution is correct,"
        " say so; otherwise, describe the error and suggest a fix."
    ),
)
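# checker_node is a hand-rolled variant of agent_node: the Groq Llama model is
# text-only, so image parts are stripped from multimodal messages before the
# transcript is handed to it.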
def checker_node(state):
    text_only_messages = []
    for msg in state["messages"]:
        if isinstance(msg.content, list):
            # Multimodal content: keep only the text parts.
            text_content = [item["text"] for item in msg.content if item["type"] == "text"]
            new_msg = msg.copy()
            new_msg.content = " ".join(text_content)
            text_only_messages.append(new_msg)
        else:
            text_only_messages.append(msg)
    text_only_state = {
        "messages": text_only_messages,
        "sender": state["sender"],
    }
    result = checker_agent.invoke(text_only_state)
    if not isinstance(result, ToolMessage):
        result = AIMessage(**result.dict(exclude={"type", "name"}), name="checker_agent")
    return {
        "messages": [result],
        "sender": "checker_agent",
    }
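# The tools the graph can dispatch to. ToolNode executes whichever tool the
# last AIMessage requested and appends the result as a ToolMessage.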
tools = [search_tool, python_repl]
tool_node = ToolNode(tools)
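# Router shared by all agent nodes: detour to the tool node on a tool call,
# stop once someone declares FINAL ANSWER, otherwise continue around the loop.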
def router(state) -> Literal["call_tool", "__end__", "continue"]:
    messages = state["messages"]
    last_message = messages[-1]
    if last_message.tool_calls:
        return "call_tool"
    if "FINAL ANSWER" in last_message.content:
        return "__end__"
    return "continue"
workflow = StateGraph(AgentState)
workflow.add_node("problem_creator", problem_node)
workflow.add_node("solution_generator", solution_node)
workflow.add_node("checker_agent", checker_node)
workflow.add_node("call_tool", tool_node)
workflow.add_conditional_edges(
    "problem_creator",
    router,
    {"continue": "solution_generator", "call_tool": "call_tool", "__end__": END},
)
workflow.add_conditional_edges(
    "solution_generator",
    router,
    {"continue": "checker_agent", "call_tool": "call_tool", "__end__": END},
)
workflow.add_conditional_edges(
    "checker_agent",
    router,
    {"continue": "problem_creator", "call_tool": "call_tool", "__end__": END},
)
workflow.add_conditional_edges(
    "call_tool",
    lambda x: x["sender"],
    {
        "problem_creator": "problem_creator",
        "solution_generator": "solution_generator",
        "checker_agent": "checker_agent",
    },
)
workflow.add_edge(START, "problem_creator")
graph = workflow.compile()
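# A minimal text-only smoke test (assumes valid API keys in the environment):
#   graph.invoke(
#       {"messages": [HumanMessage(content="Solve Two Sum in Python.")]},
#       {"recursion_limit": 10},
#   )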
def process_images(images: List[tuple[PIL.Image.Image, str | None]]):
    if not images:
        return "No images uploaded"
    # Convert all images to base64
    image_contents = []
    for image, _ in images:
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        image_contents.append({
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{img_str}"},
        })
    # Create the input for the workflow
    input_data = {
        "messages": [
            HumanMessage(
                content=[
                    {"type": "text", "text": "answer the question about the following images"},
                    *image_contents,
                ]
            )
        ]
    }
    # Run the workflow
    output = []
    try:
        for chunk in graph.stream(input_data, {"recursion_limit": 10}, stream_mode="values"):
            message = chunk["messages"][-1]
            output.append(f"{message.name}: {message.content}")
    except Exception as e:
        output.append(f"Error: {repr(e)}")
    return "\n\n".join(output)
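# gr.Gallery with type="pil" passes process_images a list of (PIL image, caption) tuples.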
# Create Gradio interface
iface = gr.Interface(
    fn=process_images,
    inputs=[gr.Gallery(label="Upload images", type="pil")],
    outputs=[gr.Markdown(label="Output", show_copy_button=True)],
    title="Image Question Answering",
    description="Upload one or more images of a problem; the agents will describe, solve, and review it.",
)
# Launch the interface
iface.launch()