from __future__ import annotations
from typing import Any, Optional, Protocol, runtime_checkable
from mcp.server.fastmcp import Context
from mcp.server.session import ServerSession
from mcp.types import SamplingMessage, TextContent, ModelHint, ModelPreferences
from open_storyline.utils.emoji import EmojiManager
@runtime_checkable
class BaseLLMSampling(Protocol):
    """Low-level protocol: a single sampling call shared across multiple tools.

    Marked ``@runtime_checkable`` for consistency with ``LLMClient`` so that
    ``isinstance(obj, BaseLLMSampling)`` structural checks work at runtime
    instead of raising ``TypeError``.
    """

    async def sampling(
        self,
        *,
        system_prompt: str | None,
        messages: list[SamplingMessage],
        temperature: float = 0.3,
        top_p: float = 0.9,
        max_tokens: int = 4096,
        model_preferences: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        stop_sequences: list[str] | None = None,
    ) -> str:
        """Send ``messages`` to an LLM and return the completion as plain text."""
        ...
@runtime_checkable
class LLMClient(Protocol):
    """High-level protocol: tools are distinguished only by whether they need
    multimodal input; everything else goes through a single ``complete`` call.
    """
    # High-level protocol: Tools are distinguished only by multimodal capability requirement
    async def complete(
        self,
        *,
        system_prompt: str | None,
        user_prompt: str,
        media: list[dict[str, Any]] | None = None,
        temperature: float = 0.3,
        top_p: float = 0.9,
        max_tokens: int = 2048,
        model_preferences: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        stop_sequences: list[str] | None = None,
    ) -> str:
        """Return the LLM's text completion for ``user_prompt`` (plus optional media)."""
        ...
class MCPSampler(BaseLLMSampling):
    """Concrete ``BaseLLMSampling`` backed by an MCP server session.

    Delegates sampling to the connected MCP client via
    ``ctx.session.create_message`` and post-processes the response into
    emoji-free plain text.
    """
    def __init__(self, mcp_ctx: Context[ServerSession, object]):
        # Keep the request context so `sampling` can reach the live session.
        self._mcp_ctx = mcp_ctx
    def _to_mcp_model_preferences(
        self,
        model_preferences: dict[str, Any] | None,
    ) -> Optional[ModelPreferences]:
        """Convert a loose preferences dict into an MCP ``ModelPreferences``.

        ``hints`` entries may be ``ModelHint`` instances, dicts of
        ``ModelHint`` fields, or bare model-name strings; any other entry
        type is silently skipped. Returns ``None`` when no preferences
        were supplied.
        """
        if not model_preferences:
            return None
        raw_hints = model_preferences.get("hints")
        hints: list[ModelHint] | None = None
        if isinstance(raw_hints, list):
            hints = []
            for h in raw_hints:
                if isinstance(h, ModelHint):
                    hints.append(h)
                elif isinstance(h, dict):
                    hints.append(ModelHint(**h))
                elif isinstance(h, str):
                    hints.append(ModelHint(name=h))
        return ModelPreferences(
            hints=hints,
            costPriority=model_preferences.get("costPriority"),
            speedPriority=model_preferences.get("speedPriority"),
            intelligencePriority=model_preferences.get("intelligencePriority"),
        )
    def _extract_text(self, content: Any) -> str:
        """Flatten MCP response content into emoji-stripped plain text."""
        emoji_manager = EmojiManager()
        # MCP returns content as either a single block or array; here we only extract text blocks
        if isinstance(content, list):
            texts: list[str] = []
            for block in content:
                if getattr(block, "type", None) == "text":
                    texts.append(block.text)
            return emoji_manager.remove_emoji("\n".join(texts).strip())
        if getattr(content, "type", None) == "text":
            return emoji_manager.remove_emoji(content.text.strip())
        # Non-text content: fall back to its string representation.
        return emoji_manager.remove_emoji(str(content))
    async def sampling(self,
        *,
        system_prompt: str | None,
        messages: list[SamplingMessage],
        temperature: float = 0.3,
        top_p: float = 0.9,
        max_tokens: int = 4096,
        model_preferences: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        stop_sequences: list[str] | None = None
    ) -> str:
        """Request a completion from the connected MCP client and return its text.

        ``top_p`` is passed through ``metadata`` because ``create_message``
        has no dedicated parameter for it.

        NOTE(review): ``stop_sequences`` and ``model_preferences`` are
        accepted but deliberately not forwarded (see the commented-out
        arguments below) — presumably the connected client does not support
        them yet; confirm before enabling.
        """
        merged_metadata = dict(metadata or {})
        merged_metadata["top_p"] = top_p
        result = await self._mcp_ctx.session.create_message(
            messages=messages,
            max_tokens=max_tokens,
            system_prompt=system_prompt,
            temperature=temperature,
            # stop_sequences=stop_sequences,
            metadata=merged_metadata,
            # model_preferences=self._to_mcp_model_preferences(model_preferences),
        )
        return self._extract_text(result.content)
class SamplingLLMClient(LLMClient):
    """
    Only differentiate based on presence of media input.
    Server passes media paths and timestamps to Client, Client handles base64 conversion.
    """

    def __init__(self, sampler: BaseLLMSampling):
        self._sampler = sampler

    async def complete(self,
        *,
        system_prompt: str | None,
        user_prompt: str,
        media: list[dict[str, Any]] | None = None,
        temperature: float = 0.3,
        top_p: float = 0.9,
        max_tokens: int = 2048,
        model_preferences: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        stop_sequences: list[str] | None = None
    ) -> str:
        """Wrap the prompt as one user message and delegate to the sampler."""
        prompt_message = SamplingMessage(
            role="user",
            content=TextContent(type="text", text=user_prompt),
        )

        # Tag the request with its modality; media descriptors are forwarded
        # untouched so the client side can do the base64 conversion itself.
        request_metadata: dict[str, Any] = {**(metadata or {})}
        if media:
            request_metadata["modality"] = "multimodal"
            request_metadata["media"] = media  # Critical: Pass media paths and timestamps through transparently
        else:
            request_metadata["modality"] = "text"

        return await self._sampler.sampling(
            system_prompt=system_prompt,
            messages=[prompt_message],
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            model_preferences=model_preferences,
            metadata=request_metadata,
            stop_sequences=stop_sequences,
        )
def make_llm(mcp_ctx: Context[ServerSession, object]) -> LLMClient:
    """Build a ready-to-use LLM client backed by MCP sampling.

    Tools can directly call ``llm.complete()`` via ``llm = make_llm(ctx)``.
    """
    sampler = MCPSampler(mcp_ctx)
    return SamplingLLMClient(sampler)