File size: 4,870 Bytes
3b11c41
a1afbba
3b11c41
 
2190f31
a94505a
5bea503
a94505a
a1afbba
3b11c41
a1afbba
3b11c41
a1afbba
 
a94505a
3b11c41
a94505a
 
 
 
a1afbba
a94505a
3b11c41
a1afbba
 
 
 
 
 
a94505a
 
 
 
 
 
a1afbba
3b11c41
 
 
 
a1afbba
 
a94505a
a1afbba
3b11c41
a1afbba
a94505a
3b11c41
a94505a
 
3b11c41
a94505a
 
3b11c41
 
2190f31
 
 
 
3b11c41
a94505a
a1afbba
a94505a
 
3b11c41
a1afbba
 
a94505a
 
2190f31
 
 
 
 
 
69c9716
 
 
1abadd0
69c9716
1abadd0
 
 
 
 
 
 
 
 
 
69c9716
1abadd0
69c9716
 
1abadd0
69c9716
1abadd0
69c9716
1abadd0
69c9716
1abadd0
 
 
 
2190f31
1abadd0
 
 
69c9716
3b11c41
a1afbba
2190f31
1abadd0
69c9716
1abadd0
69c9716
2190f31
 
69c9716
 
3b11c41
1abadd0
a94505a
3b11c41
a1afbba
a94505a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# server.py
import sys, json, asyncio
import yfinance as yf
from statsmodels.tsa.holtwinters import ExponentialSmoothing
from pydantic import BaseModel
from langchain_core.tools import tool, BaseTool

# --- MCP Server Logic ---
class MCPToolServer:
    """Minimal JSON-RPC 2.0 server exposing LangChain tools over stdio.

    Each non-blank line read from stdin is parsed as one JSON-RPC request;
    each response is written to stdout as one JSON object plus a newline.

    Supported methods:
        * ``discover`` -- list registered tools as
          ``{"name", "description", "args_schema"}`` dicts.
        * ``execute`` -- run a tool; ``params`` is
          ``{"tool_name": str, "tool_args": <tool input>}`` (``tool_args``
          defaults to ``{}`` when missing or null).

    Error codes follow JSON-RPC 2.0: -32700 parse error, -32600 request is
    not an object, -32602 bad params / unknown tool, -32601 unknown method,
    -32603 the tool raised or its result could not be serialized.
    """

    def __init__(self, tools: "list[BaseTool]"):
        # Index tools by name for lookup in "execute"; a later tool with a
        # duplicate name replaces an earlier one.
        self.tools = {t.name: t for t in tools}

    @staticmethod
    def _error(code: int, message: str, req_id=None) -> dict:
        """Build a JSON-RPC 2.0 error response object."""
        return {"jsonrpc": "2.0", "error": {"code": code, "message": message}, "id": req_id}

    async def _handle_request(self, req: object) -> dict:
        """Dispatch one decoded JSON-RPC request and return its response object."""
        # A valid JSON line that is not an object (e.g. an array) must not crash
        # the serve loop with AttributeError on .get().
        if not isinstance(req, dict):
            return self._error(-32600, "Invalid Request")

        method, req_id = req.get("method"), req.get("id")
        # "params" may be omitted or explicitly null; treat both as {}.
        params = req.get("params")
        if params is None:
            params = {}
        if not isinstance(params, dict):
            return self._error(-32602, "Invalid params: expected an object", req_id)

        if method == "discover":
            result = [
                {"name": t.name, "description": t.description, "args_schema": t.args}
                for t in self.tools.values()
            ]
            return {"jsonrpc": "2.0", "result": result, "id": req_id}

        if method == "execute":
            tool_name = params.get("tool_name")
            tool_args = params.get("tool_args")
            if tool_args is None:
                tool_args = {}
            # Only strings can name a tool; an unhashable value (list/dict)
            # would otherwise raise TypeError on the dict lookup.
            tool_to_exec = self.tools.get(tool_name) if isinstance(tool_name, str) else None
            if tool_to_exec is None:
                # The method exists; it is the tool that is unknown, so report
                # invalid params rather than "Method not found".
                return self._error(-32602, f"Unknown tool: {tool_name!r}", req_id)
            try:
                result = await tool_to_exec.ainvoke(tool_args)
            except Exception as e:  # report tool failures instead of crashing the server
                return self._error(-32603, str(e), req_id)
            return {"jsonrpc": "2.0", "result": result, "id": req_id}

        return self._error(-32601, "Method not found", req_id)

    def _encode(self, response: dict) -> bytes:
        """Serialize *response* as one newline-terminated JSON line.

        If the payload is not JSON-serializable (e.g. an exotic tool result),
        an internal-error response with the same id is sent instead so a single
        bad result cannot terminate the serve loop.
        """
        try:
            return json.dumps(response).encode() + b"\n"
        except (TypeError, ValueError) as e:
            fallback = self._error(
                -32603, f"Result is not JSON-serializable: {e}", response.get("id")
            )
            return json.dumps(fallback).encode() + b"\n"

    async def serve(self):
        """Read requests from stdin until EOF, writing one response line per request."""
        loop = asyncio.get_running_loop()
        reader = asyncio.StreamReader()
        await loop.connect_read_pipe(lambda: asyncio.StreamReaderProtocol(reader), sys.stdin)
        transport, protocol = await loop.connect_write_pipe(
            asyncio.streams.FlowControlMixin, sys.stdout
        )
        writer = asyncio.StreamWriter(transport, protocol, None, loop)

        while line := await reader.readline():
            if not line.strip():
                continue  # skip blank lines
            try:
                req = json.loads(line)
            except ValueError:
                # Covers JSONDecodeError and UnicodeDecodeError (invalid UTF-8).
                # JSON-RPC 2.0 specifies a Parse error reply with a null id.
                response = self._error(-32700, "Parse error")
            else:
                response = await self._handle_request(req)
            writer.write(self._encode(response))
            await writer.drain()


# --- Tool Definitions ---
# Supported commodities (lower-case keys) mapped to their Yahoo Finance tickers.
COMMODITY_TICKERS = dict(gold="GC=F", silver="SI=F")

class PriceInput(BaseModel):
    # Argument schema for the get_current_price tool (passed via @tool(args_schema=...)).
    commodity_name: str  # required; the tool lower-cases it and supports "gold" or "silver"

def _latest_close(ticker: str):
    """Return the most recent non-missing closing price for *ticker*, or None.

    Performs blocking network I/O through yfinance; call it via
    ``asyncio.to_thread`` from async code.
    """
    # Look back a few days rather than one so a gap with no trading still
    # yields the latest close instead of an empty frame (and an IndexError).
    closes = yf.Ticker(ticker).history(period="5d")["Close"].dropna()
    return None if closes.empty else float(closes.iloc[-1])


@tool(args_schema=PriceInput)
async def get_current_price(commodity_name: str) -> str:
    """Gets the most recent price for gold or silver."""
    # The docstring above doubles as the model-facing tool description, so it
    # is kept verbatim. The name is matched ignoring case and surrounding spaces.
    ticker = COMMODITY_TICKERS.get(commodity_name.strip().lower())
    if not ticker:
        return f"Error: '{commodity_name}' is not supported."
    try:
        # yfinance is synchronous; run it in a worker thread so the event loop
        # can keep serving other requests in the meantime.
        price = await asyncio.to_thread(_latest_close, ticker)
    except Exception as e:
        return f"Error fetching price: {e}"
    if price is None:
        return f"Error fetching price: no recent data for {commodity_name}."
    return f"The current price of {commodity_name} is approx. ${price:.2f} USD."

class ForecastInput(BaseModel):
    # Argument schema for the get_price_forecast tool (passed via @tool(args_schema=...)).
    # Neither field has a default, so both are required by this schema.
    commodity_name: str  # "gold" or "silver"; other values get an error string from the tool
    forecast_days: int  # forecast horizon; the tool accepts only 3, 4, or 5

def _forecast_closes(ticker: str, days: int):
    """Forecast the next *days* closing prices for *ticker*.

    Downloads about six months of closes with yfinance and fits an
    additive-trend exponential smoothing model. Returns the forecast values,
    or ``None`` when no usable closing prices were downloaded. Blocking
    (network I/O plus model fitting); call it via ``asyncio.to_thread``.
    """
    close = yf.download(ticker, period="6mo", progress=False)["Close"]
    # Depending on the yfinance version, columns may be a MultiIndex even for a
    # single ticker, which makes "Close" a one-column DataFrame; flatten it.
    if getattr(close, "ndim", 1) > 1:
        close = close.squeeze(axis="columns")
    close = close.dropna()
    if close.empty:
        return None
    # Fit on a plain array: the trading-day date index has no fixed frequency,
    # which statsmodels would otherwise warn about. Values are unchanged.
    model = ExponentialSmoothing(close.to_numpy(), trend="add").fit()
    return model.forecast(steps=days)


@tool(args_schema=ForecastInput)
async def get_price_forecast(commodity_name: str = None, forecast_days: int = None, **kwargs) -> str:
    """
    Generates a 3 to 5 day price forecast for gold or silver.
    Also handles positional-style input like ['gold', 5].
    """
    # The docstring above doubles as the model-facing tool description, so it
    # is kept verbatim. Every failure is reported as a returned "Error: ..."
    # string rather than an exception.
    # NOTE(review): args_schema=ForecastInput requires both fields, so the
    # positional fallback below presumably only matters when the coroutine is
    # called directly -- confirm before removing it.

    # Step 1: positional-style input passed as args=['gold', 5].
    positional = kwargs.get("args")
    if not commodity_name and isinstance(positional, (list, tuple)) and len(positional) >= 2:
        commodity_name, forecast_days = positional[0], positional[1]

    # Step 2: validate presence. (JSON-style keyword input binds directly to the
    # named parameters; it can never appear in **kwargs.)
    if not commodity_name:
        return "Error: Please provide a commodity (gold or silver)."
    if forecast_days is None or forecast_days == "":
        return "Error: Please provide forecast_days (3, 4, or 5)."

    # Step 3: normalize; a non-numeric day count is a user error, not a crash.
    commodity_name = str(commodity_name).strip().lower()
    try:
        forecast_days = int(forecast_days)
    except (TypeError, ValueError):
        return "Error: forecast_days must be 3, 4, or 5."

    ticker = COMMODITY_TICKERS.get(commodity_name)
    if not ticker:
        return f"Error: '{commodity_name}' is not supported. Use 'gold' or 'silver'."
    if not 3 <= forecast_days <= 5:
        return "Error: forecast_days must be 3, 4, or 5."

    try:
        # Download and fitting are blocking; keep them off the event loop.
        forecast = await asyncio.to_thread(_forecast_closes, ticker, forecast_days)
    except Exception as e:
        return f"Error during forecasting: {e}"
    if forecast is None:
        return "Not enough historical data for forecasting."

    return "\n".join(f"Day {day}: ${val:.2f} USD" for day, val in enumerate(forecast, start=1))



# --- EntryPoint ---
if __name__ == "__main__":
    # Expose both commodity tools over stdin/stdout; serve() returns once stdin hits EOF.
    asyncio.run(
        MCPToolServer(tools=[get_current_price, get_price_forecast]).serve()
    )