from __future__ import annotations

from typing import Any, Optional, Protocol, runtime_checkable

from mcp.server.fastmcp import Context
from mcp.server.session import ServerSession
from mcp.types import SamplingMessage, TextContent, ModelHint, ModelPreferences

from open_storyline.utils.emoji import EmojiManager


class BaseLLMSampling(Protocol):
    # Low-level protocol: sampling logic shared across multiple tools.
    async def sampling(
        self,
        *,
        system_prompt: str | None,
        messages: list[SamplingMessage],
        temperature: float = 0.3,
        top_p: float = 0.9,
        max_tokens: int = 4096,
        model_preferences: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        stop_sequences: list[str] | None = None,
    ) -> str:
        ...


@runtime_checkable
class LLMClient(Protocol):
    # High-level protocol: tools differ only in whether they require multimodal input.
    async def complete(
        self,
        *,
        system_prompt: str | None,
        user_prompt: str,
        media: list[dict[str, Any]] | None = None,
        temperature: float = 0.3,
        top_p: float = 0.9,
        max_tokens: int = 2048,
        model_preferences: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        stop_sequences: list[str] | None = None,
    ) -> str:
        ...


class MCPSampler(BaseLLMSampling):
    def __init__(self, mcp_ctx: Context[ServerSession, object]):
        self._mcp_ctx = mcp_ctx

    def _to_mcp_model_preferences(
        self,
        model_preferences: dict[str, Any] | None,
    ) -> Optional[ModelPreferences]:
        # Normalize plain-dict preferences into mcp.types.ModelPreferences;
        # hints may arrive as ModelHint objects, dicts, or bare model-name strings.
        if not model_preferences:
            return None
        raw_hints = model_preferences.get("hints")
        hints: list[ModelHint] | None = None
        if isinstance(raw_hints, list):
            hints = []
            for h in raw_hints:
                if isinstance(h, ModelHint):
                    hints.append(h)
                elif isinstance(h, dict):
                    hints.append(ModelHint(**h))
                elif isinstance(h, str):
                    hints.append(ModelHint(name=h))
        return ModelPreferences(
            hints=hints,
            costPriority=model_preferences.get("costPriority"),
            speedPriority=model_preferences.get("speedPriority"),
            intelligencePriority=model_preferences.get("intelligencePriority"),
        )

    def _extract_text(self, content: Any) -> str:
        emoji_manager = EmojiManager()
        # MCP returns content as either a single block or a list of blocks;
        # only text blocks are extracted here.
        if isinstance(content, list):
            texts: list[str] = []
            for block in content:
                if getattr(block, "type", None) == "text":
                    texts.append(block.text)
            return emoji_manager.remove_emoji("\n".join(texts).strip())
        if getattr(content, "type", None) == "text":
            return emoji_manager.remove_emoji(content.text.strip())
        return emoji_manager.remove_emoji(str(content))

    async def sampling(
        self,
        *,
        system_prompt: str | None,
        messages: list[SamplingMessage],
        temperature: float = 0.3,
        top_p: float = 0.9,
        max_tokens: int = 4096,
        model_preferences: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        stop_sequences: list[str] | None = None,
    ) -> str:
        # top_p is not a first-class MCP sampling parameter, so it rides along
        # in metadata for clients that choose to honor it.
        merged_metadata = dict(metadata or {})
        merged_metadata["top_p"] = top_p
        result = await self._mcp_ctx.session.create_message(
            messages=messages,
            max_tokens=max_tokens,
            system_prompt=system_prompt,
            temperature=temperature,
            stop_sequences=stop_sequences,
            metadata=merged_metadata,
            model_preferences=self._to_mcp_model_preferences(model_preferences),
        )
        return self._extract_text(result.content)


class SamplingLLMClient(LLMClient):
    """
    Differentiates only on the presence of media input. The server passes media
    paths and timestamps to the client; the client handles base64 conversion.
    """

    def __init__(self, sampler: BaseLLMSampling):
        self._sampler = sampler

    async def complete(
        self,
        *,
        system_prompt: str | None,
        user_prompt: str,
        media: list[dict[str, Any]] | None = None,
        temperature: float = 0.3,
        top_p: float = 0.9,
        max_tokens: int = 2048,
        model_preferences: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        stop_sequences: list[str] | None = None,
    ) -> str:
        messages = [
            SamplingMessage(
                role="user",
                content=TextContent(type="text", text=user_prompt),
            )
        ]
        merged_metadata = dict(metadata or {})
        merged_metadata["modality"] = "multimodal" if media else "text"
        if media:
            # Critical: pass media paths and timestamps through unchanged.
            merged_metadata["media"] = media
        return await self._sampler.sampling(
            system_prompt=system_prompt,
            messages=messages,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            model_preferences=model_preferences,
            metadata=merged_metadata,
            stop_sequences=stop_sequences,
        )


def make_llm(mcp_ctx: Context[ServerSession, object]) -> LLMClient:
    # Tools obtain a client with llm = make_llm(ctx) and call llm.complete() directly.
    return SamplingLLMClient(MCPSampler(mcp_ctx))