Update server.py
Browse files
server.py
CHANGED
|
@@ -2,23 +2,24 @@
|
|
| 2 |
import sys, json, asyncio
|
| 3 |
import yfinance as yf
|
| 4 |
from statsmodels.tsa.holtwinters import ExponentialSmoothing
|
| 5 |
-
from langchain_core.tools import StructuredTool
|
| 6 |
from pydantic import BaseModel
|
| 7 |
-
from langchain_core.tools import BaseTool
|
| 8 |
-
from langchain_core.tools import tool
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
# --- Compact MCP Server Logic ---
|
| 13 |
class MCPToolServer:
|
| 14 |
def __init__(self, tools: list[BaseTool]):
|
| 15 |
self.tools = {t.name: t for t in tools}
|
| 16 |
|
| 17 |
async def _handle_request(self, req: dict):
|
| 18 |
method, params, req_id = req.get("method"), req.get("params", {}), req.get("id")
|
|
|
|
| 19 |
if method == "discover":
|
| 20 |
-
result = [
|
|
|
|
|
|
|
|
|
|
| 21 |
return {"jsonrpc": "2.0", "result": result, "id": req_id}
|
|
|
|
| 22 |
if method == "execute":
|
| 23 |
tool_name, tool_args = params.get("tool_name"), params.get("tool_args", {})
|
| 24 |
if tool_to_exec := self.tools.get(tool_name):
|
|
@@ -26,7 +27,12 @@ class MCPToolServer:
|
|
| 26 |
result = await tool_to_exec.ainvoke(tool_args)
|
| 27 |
return {"jsonrpc": "2.0", "result": result, "id": req_id}
|
| 28 |
except Exception as e:
|
| 29 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
return {"jsonrpc": "2.0", "error": {"code": -32601, "message": "Method not found"}, "id": req_id}
|
| 31 |
|
| 32 |
async def serve(self):
|
|
@@ -34,14 +40,17 @@ class MCPToolServer:
|
|
| 34 |
await asyncio.get_event_loop().connect_read_pipe(lambda: asyncio.StreamReaderProtocol(reader), sys.stdin)
|
| 35 |
writer_transport, _ = await asyncio.get_event_loop().connect_write_pipe(asyncio.streams.FlowControlMixin, sys.stdout)
|
| 36 |
writer = asyncio.StreamWriter(writer_transport, _, None, asyncio.get_event_loop())
|
|
|
|
| 37 |
while line := await reader.readline():
|
| 38 |
try:
|
| 39 |
response = await self._handle_request(json.loads(line))
|
| 40 |
-
writer.write(json.dumps(response).encode() + b
|
| 41 |
await writer.drain()
|
| 42 |
-
except json.JSONDecodeError:
|
|
|
|
| 43 |
|
| 44 |
-
|
|
|
|
| 45 |
COMMODITY_TICKERS = {"gold": "GC=F", "silver": "SI=F"}
|
| 46 |
|
| 47 |
class PriceInput(BaseModel):
|
|
@@ -49,15 +58,15 @@ class PriceInput(BaseModel):
|
|
| 49 |
|
| 50 |
@tool(args_schema=PriceInput)
|
| 51 |
async def get_current_price(commodity_name: str) -> str:
|
| 52 |
-
"""Gets the most recent
|
| 53 |
ticker = COMMODITY_TICKERS.get(commodity_name.lower())
|
| 54 |
-
if not ticker:
|
|
|
|
| 55 |
try:
|
| 56 |
price = yf.Ticker(ticker).history(period="1d")['Close'].iloc[-1]
|
| 57 |
return f"The current price of {commodity_name} is approx. ${price:.2f} USD."
|
| 58 |
-
except Exception as e:
|
| 59 |
-
|
| 60 |
-
|
| 61 |
|
| 62 |
class ForecastInput(BaseModel):
|
| 63 |
commodity_name: str
|
|
@@ -65,12 +74,12 @@ class ForecastInput(BaseModel):
|
|
| 65 |
|
| 66 |
@tool(args_schema=ForecastInput)
|
| 67 |
async def get_price_forecast(commodity_name: str, forecast_days: int) -> str:
|
| 68 |
-
"""Generates a 3 to 5 day price forecast
|
| 69 |
ticker = COMMODITY_TICKERS.get(commodity_name.lower())
|
| 70 |
if not ticker:
|
| 71 |
-
return
|
| 72 |
if not 3 <= forecast_days <= 5:
|
| 73 |
-
return "Error: Forecast must be
|
| 74 |
try:
|
| 75 |
data = yf.download(ticker, period="6mo", progress=False)['Close']
|
| 76 |
if data.empty:
|
|
@@ -80,7 +89,7 @@ async def get_price_forecast(commodity_name: str, forecast_days: int) -> str:
|
|
| 80 |
except Exception as e:
|
| 81 |
return f"Error during forecast: {e}"
|
| 82 |
|
| 83 |
-
# ---
|
| 84 |
if __name__ == "__main__":
|
| 85 |
server = MCPToolServer(tools=[get_current_price, get_price_forecast])
|
| 86 |
-
asyncio.run(server.serve())
|
|
|
|
| 2 |
import sys, json, asyncio
|
| 3 |
import yfinance as yf
|
| 4 |
from statsmodels.tsa.holtwinters import ExponentialSmoothing
|
|
|
|
| 5 |
from pydantic import BaseModel
|
| 6 |
+
from langchain_core.tools import tool, BaseTool
|
|
|
|
| 7 |
|
| 8 |
+
# --- MCP Server Logic ---
|
|
|
|
|
|
|
| 9 |
class MCPToolServer:
|
| 10 |
def __init__(self, tools: list[BaseTool]):
|
| 11 |
self.tools = {t.name: t for t in tools}
|
| 12 |
|
| 13 |
async def _handle_request(self, req: dict):
|
| 14 |
method, params, req_id = req.get("method"), req.get("params", {}), req.get("id")
|
| 15 |
+
|
| 16 |
if method == "discover":
|
| 17 |
+
result = [
|
| 18 |
+
{"name": t.name, "description": t.description, "args_schema": t.args}
|
| 19 |
+
for t in self.tools.values()
|
| 20 |
+
]
|
| 21 |
return {"jsonrpc": "2.0", "result": result, "id": req_id}
|
| 22 |
+
|
| 23 |
if method == "execute":
|
| 24 |
tool_name, tool_args = params.get("tool_name"), params.get("tool_args", {})
|
| 25 |
if tool_to_exec := self.tools.get(tool_name):
|
|
|
|
| 27 |
result = await tool_to_exec.ainvoke(tool_args)
|
| 28 |
return {"jsonrpc": "2.0", "result": result, "id": req_id}
|
| 29 |
except Exception as e:
|
| 30 |
+
return {
|
| 31 |
+
"jsonrpc": "2.0",
|
| 32 |
+
"error": {"code": -32603, "message": str(e)},
|
| 33 |
+
"id": req_id,
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
return {"jsonrpc": "2.0", "error": {"code": -32601, "message": "Method not found"}, "id": req_id}
|
| 37 |
|
| 38 |
async def serve(self):
|
|
|
|
| 40 |
await asyncio.get_event_loop().connect_read_pipe(lambda: asyncio.StreamReaderProtocol(reader), sys.stdin)
|
| 41 |
writer_transport, _ = await asyncio.get_event_loop().connect_write_pipe(asyncio.streams.FlowControlMixin, sys.stdout)
|
| 42 |
writer = asyncio.StreamWriter(writer_transport, _, None, asyncio.get_event_loop())
|
| 43 |
+
|
| 44 |
while line := await reader.readline():
|
| 45 |
try:
|
| 46 |
response = await self._handle_request(json.loads(line))
|
| 47 |
+
writer.write(json.dumps(response).encode() + b"\n")
|
| 48 |
await writer.drain()
|
| 49 |
+
except json.JSONDecodeError:
|
| 50 |
+
continue
|
| 51 |
|
| 52 |
+
|
| 53 |
+
# --- Tool Definitions ---
|
| 54 |
COMMODITY_TICKERS = {"gold": "GC=F", "silver": "SI=F"}
|
| 55 |
|
| 56 |
class PriceInput(BaseModel):
|
|
|
|
| 58 |
|
| 59 |
@tool(args_schema=PriceInput)
|
| 60 |
async def get_current_price(commodity_name: str) -> str:
|
| 61 |
+
"""Gets the most recent price for gold or silver."""
|
| 62 |
ticker = COMMODITY_TICKERS.get(commodity_name.lower())
|
| 63 |
+
if not ticker:
|
| 64 |
+
return f"Error: '{commodity_name}' is not supported."
|
| 65 |
try:
|
| 66 |
price = yf.Ticker(ticker).history(period="1d")['Close'].iloc[-1]
|
| 67 |
return f"The current price of {commodity_name} is approx. ${price:.2f} USD."
|
| 68 |
+
except Exception as e:
|
| 69 |
+
return f"Error fetching price: {e}"
|
|
|
|
| 70 |
|
| 71 |
class ForecastInput(BaseModel):
|
| 72 |
commodity_name: str
|
|
|
|
| 74 |
|
| 75 |
@tool(args_schema=ForecastInput)
|
| 76 |
async def get_price_forecast(commodity_name: str, forecast_days: int) -> str:
|
| 77 |
+
"""Generates a 3 to 5 day price forecast."""
|
| 78 |
ticker = COMMODITY_TICKERS.get(commodity_name.lower())
|
| 79 |
if not ticker:
|
| 80 |
+
return "Error: Unknown commodity. Use 'gold' or 'silver'."
|
| 81 |
if not 3 <= forecast_days <= 5:
|
| 82 |
+
return "Error: Forecast must be 3, 4, or 5 days."
|
| 83 |
try:
|
| 84 |
data = yf.download(ticker, period="6mo", progress=False)['Close']
|
| 85 |
if data.empty:
|
|
|
|
| 89 |
except Exception as e:
|
| 90 |
return f"Error during forecast: {e}"
|
| 91 |
|
| 92 |
+
# --- EntryPoint ---
|
| 93 |
if __name__ == "__main__":
|
| 94 |
server = MCPToolServer(tools=[get_current_price, get_price_forecast])
|
| 95 |
+
asyncio.run(server.serve())
|