File size: 3,610 Bytes
6f54a59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad17e30
6f54a59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bec8a1d
e5220ed
6f54a59
 
 
 
ad17e30
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#!/usr/bin/env python3
import asyncio
import logging
import os
from typing import Annotated, TypedDict

import gradio as gr
from dotenv import load_dotenv

from langgraph.graph import StateGraph, START
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph.message import add_messages
from langchain_core.messages import HumanMessage, SystemMessage, AnyMessage

from langchain_mcp_adapters.resources import load_mcp_resources
from langchain_mcp_adapters.tools import load_mcp_tools
from langchain_mcp_adapters.prompts import load_mcp_prompt
from langchain_google_genai import ChatGoogleGenerativeAI

from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client

# Pull variables from a local .env file (if present) into the environment
# before the settings below are read.
load_dotenv()

# Streamable-HTTP endpoint of the MCP server; override with MCP_SERVER_URL.
# NOTE(review): this default targets port 7860, which is also the default PORT
# this Gradio app listens on at the bottom of the file — confirm the MCP server
# is expected on a different host/port when both run with defaults.
MCP_URL = os.environ.get("MCP_SERVER_URL", "http://localhost:7860/mcp")
# Name of the MCP prompt whose text becomes the agent's system message.
PROMPT_NAME = os.environ.get("MCP_PROMPT_NAME", "ship_meme_for_commit")


class AgentState(TypedDict):
    """State passed between the LangGraph nodes built in ``run_for_commit``."""

    # Conversation history. The `add_messages` reducer merges the messages a
    # node returns into this list (appending, or updating by message id)
    # rather than overwriting the whole list.
    messages: Annotated[list[AnyMessage], add_messages]


def _content_text(content) -> str:
    """Flatten a LangChain message ``content`` value into plain text.

    ``content`` may be a ``str`` or a list of parts. Each part is either a
    ``str`` or a dict, and only dict parts with a string ``"text"`` entry
    contribute. Some chat-model integrations return the list form, which would
    otherwise leak out of ``run_for_commit`` as a non-``str`` value. ``None``
    becomes ``""``; anything else is passed through ``str()``.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                pieces.append(part["text"])
        return "".join(pieces)
    return "" if content is None else str(content)


def _resource_text(blob) -> str:
    """Return the textual payload of one loaded MCP resource blob.

    ``str`` data is returned as-is. ``bytes`` data is decoded with the blob's
    ``encoding`` attribute, falling back to UTF-8, with undecodable bytes
    replaced. Previously such data was silently dropped. Any other data
    (e.g. ``None``) yields ``""``.
    """
    data = getattr(blob, "data", None)
    if isinstance(data, str):
        return data
    if isinstance(data, (bytes, bytearray)):
        encoding = getattr(blob, "encoding", None) or "utf-8"
        return bytes(data).decode(encoding, errors="replace")
    return ""


async def run_for_commit(commit_id: str) -> str:
    """Run the meme-shipping agent for one commit against the MCP server.

    The function does the following:

    1. Opens a streamable-HTTP MCP session at ``MCP_URL``.
    2. Reads the ``gitdiff://{commit_id}`` resource.
    3. Loads the server's tools and renders the ``PROMPT_NAME`` prompt with
       the commit id and diff.
    4. Runs a Gemini tool-calling loop (assistant ⇄ tools) in a LangGraph
       graph until the model replies without requesting a tool.

    Args:
        commit_id: Commit identifier used in the resource URI, the prompt
            arguments and the initial user message.

    Returns:
        Text of the last message produced by the graph (the model's final
        reply).

    Raises:
        Any error from the MCP transport/session, the model or the graph;
        nothing is caught here.
    """
    async with streamablehttp_client(MCP_URL) as (read, write, _):
        async with ClientSession(read, write) as session:
            await session.initialize()

            # 1) Resource: gitdiff://{commit_id}
            resources = await load_mcp_resources(
                session, uris=[f"gitdiff://{commit_id}"]
            )
            git_diff_content = "".join(_resource_text(r) for r in resources)

            # 2) Tools — they call back through `session`, so the graph must run
            #    before this `async with` block exits.
            tools = await load_mcp_tools(session)

            # 3) Prompt (only the first returned message is used as system text)
            prompts = await load_mcp_prompt(
                session,
                PROMPT_NAME,
                arguments={
                    "commit_id": commit_id,
                    "git_diff_content": git_diff_content,
                },
            )
            sys_text = (
                _content_text(prompts[0].content)
                if prompts
                else "You are an agent. Use tools as needed."
            )

            # 4) LLM + tools
            llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0)
            llm_with_tools = llm.bind_tools(tools)

            # 5) Assistant node (async)
            async def assistant(state: AgentState):
                """Call the model on the system prompt plus the conversation so far."""
                sys_msg = SystemMessage(content=sys_text)
                msg = await llm_with_tools.ainvoke([sys_msg] + state["messages"])
                return {"messages": [msg]}

            # 6) Graph compile + run inside the same context
            g = StateGraph(AgentState)
            g.add_node("assistant", assistant)
            g.add_node("tools", ToolNode(tools))

            g.add_edge(START, "assistant")
            # Route to "tools" when the model requested tool calls, else finish.
            g.add_conditional_edges("assistant", tools_condition)
            # Feed tool results back to the model. Without this edge the run
            # stops right after the first tool call and the "answer" would be
            # the raw tool output.
            g.add_edge("tools", "assistant")

            agent = g.compile()

            result = await agent.ainvoke(
                {
                    "messages": [
                        HumanMessage(content=f"Ship a meme for commit {commit_id}")
                    ]
                }
            )
            return _content_text(result["messages"][-1].content)


# -------- Gradio --------
def _describe_error(exc: BaseException) -> str:
    """Format ``exc`` as ``"Type: message"``, expanding exception groups.

    Async transports may wrap the real failure in an exception group, and
    the group's own message says nothing about the cause. Groups (anything
    with a non-empty ``exceptions`` tuple) are therefore flattened
    recursively to their leaf exceptions, joined with ``"; "``.
    """
    nested = getattr(exc, "exceptions", None)
    if isinstance(nested, tuple) and nested:
        return "; ".join(_describe_error(sub) for sub in nested)
    return f"{type(exc).__name__}: {exc}"


def ui_fn(message, _history):
    """Gradio chat handler: run the agent for the commit id typed by the user.

    Args:
        message: User text. It is stripped and used as the commit id, with
            ``"demo-42"`` used when it is blank or ``None``.
        _history: Chat history supplied by Gradio; ignored.

    Returns:
        The agent's reply, or a ``"Client error: <Type>: <message>"`` string
        when the run fails.
    """
    commit = (message or "").strip() or "demo-42"
    try:
        # asyncio.run needs a thread with no running event loop; this assumes
        # Gradio calls this sync handler from a worker thread.
        return asyncio.run(run_for_commit(commit))
    except Exception as e:
        # UI boundary: report the failure in the chat, but keep the full
        # traceback in the server log instead of discarding it.
        logging.getLogger(__name__).exception(
            "Agent run failed for commit %r", commit
        )
        return f"Client error: {_describe_error(e)}"


# Chat UI: each user message is treated as a commit id and handed to `ui_fn`.
demo = gr.ChatInterface(
    ui_fn,
    type="messages",
    title="MemeOps - MCP + LangGraph",
    description="Give your commit ID.",
    examples=["demo-42"],
)

if __name__ == "__main__":
    # Listen on all interfaces, on $PORT when set and 7860 otherwise.
    listen_port = int(os.getenv("PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=listen_port)