File size: 5,702 Bytes
06ba7ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
from __future__ import annotations

from typing import Any, Optional, Protocol, runtime_checkable

from mcp.server.fastmcp import Context
from mcp.server.session import ServerSession
from mcp.types import SamplingMessage, TextContent, ModelHint, ModelPreferences

from open_storyline.utils.emoji import EmojiManager


@runtime_checkable
class BaseLLMSampling(Protocol):
    """Low-level protocol: the single ``sampling`` call shared across tools.

    Marked ``runtime_checkable`` for consistency with ``LLMClient`` so
    implementations can be validated with ``isinstance`` at wiring time.
    """

    async def sampling(
        self,
        *,
        system_prompt: str | None,
        messages: list[SamplingMessage],
        temperature: float = 0.3,
        top_p: float = 0.9,
        max_tokens: int = 4096,
        model_preferences: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        stop_sequences: list[str] | None = None,
    ) -> str:
        """Send ``messages`` to the model and return the generated text."""
        ...

@runtime_checkable
class LLMClient(Protocol):
    """High-level protocol: tools differ only by whether they attach media.

    An implementation renders one user prompt (optionally with media
    attachments) into a single completion string.
    """

    async def complete(
        self,
        *,
        system_prompt: str | None,
        user_prompt: str,
        media: list[dict[str, Any]] | None = None,
        temperature: float = 0.3,
        top_p: float = 0.9,
        max_tokens: int = 2048,
        model_preferences: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        stop_sequences: list[str] | None = None,
    ) -> str:
        """Return the model's text completion for ``user_prompt``."""
        ...

class MCPSampler(BaseLLMSampling):
    """Fulfils ``BaseLLMSampling`` by delegating to MCP client sampling.

    The actual model call is performed by the connected MCP client via
    ``ctx.session.create_message``.
    """

    def __init__(self, mcp_ctx: Context[ServerSession, object]):
        self._mcp_ctx = mcp_ctx

    def _to_mcp_model_preferences(
        self,
        model_preferences: dict[str, Any] | None,
    ) -> Optional[ModelPreferences]:
        """Convert a plain dict into MCP ``ModelPreferences``.

        ``hints`` entries may be ``ModelHint`` instances, dicts of
        ``ModelHint`` fields, or bare model-name strings; entries of any
        other type are silently skipped. Returns ``None`` when no
        preferences were supplied.
        """
        if not model_preferences:
            return None

        raw_hints = model_preferences.get("hints")
        hints: list[ModelHint] | None = None
        if isinstance(raw_hints, list):
            hints = []
            for h in raw_hints:
                if isinstance(h, ModelHint):
                    hints.append(h)
                elif isinstance(h, dict):
                    hints.append(ModelHint(**h))
                elif isinstance(h, str):
                    hints.append(ModelHint(name=h))

        return ModelPreferences(
            hints=hints,
            costPriority=model_preferences.get("costPriority"),
            speedPriority=model_preferences.get("speedPriority"),
            intelligencePriority=model_preferences.get("intelligencePriority"),
        )

    def _extract_text(self, content: Any) -> str:
        """Extract plain text from an MCP content block (or list of blocks),
        stripping emoji from the result."""
        emoji_manager = EmojiManager()

        # MCP returns content as either a single block or a list; only text
        # blocks are extracted, all other block types are dropped.
        if isinstance(content, list):
            texts: list[str] = []
            for block in content:
                if getattr(block, "type", None) == "text":
                    texts.append(block.text)
            return emoji_manager.remove_emoji("\n".join(texts).strip())

        if getattr(content, "type", None) == "text":
            return emoji_manager.remove_emoji(content.text.strip())

        # Fallback: stringify unknown content rather than raising.
        return emoji_manager.remove_emoji(str(content))

    async def sampling(
        self,
        *,
        system_prompt: str | None,
        messages: list[SamplingMessage],
        temperature: float = 0.3,
        top_p: float = 0.9,
        max_tokens: int = 4096,
        model_preferences: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        stop_sequences: list[str] | None = None,
    ) -> str:
        """Forward a sampling request to the MCP client and return its text.

        ``create_message`` has no first-class ``top_p`` field, so it is
        passed through ``metadata`` instead.
        """
        merged_metadata = dict(metadata or {})
        merged_metadata["top_p"] = top_p

        # Fix: stop_sequences and model_preferences were previously accepted
        # but silently dropped (forwarding lines were commented out), leaving
        # _to_mcp_model_preferences dead code. Forward them; None values keep
        # the old behavior for callers that never supplied them.
        # NOTE(review): if the commenting-out was a workaround for a client
        # that rejects these fields, confirm against that client.
        result = await self._mcp_ctx.session.create_message(
            messages=messages,
            max_tokens=max_tokens,
            system_prompt=system_prompt,
            temperature=temperature,
            stop_sequences=stop_sequences,
            metadata=merged_metadata,
            model_preferences=self._to_mcp_model_preferences(model_preferences),
        )
        return self._extract_text(result.content)
    
class SamplingLLMClient(LLMClient):
    """LLM client built on top of a ``BaseLLMSampling`` implementation.

    Requests differ only by the presence of media input. The server passes
    media paths and timestamps through to the client, which handles any
    base64 conversion itself.
    """

    def __init__(self, sampler: BaseLLMSampling):
        self._sampler = sampler

    async def complete(
        self,
        *,
        system_prompt: str | None,
        user_prompt: str,
        media: list[dict[str, Any]] | None = None,
        temperature: float = 0.3,
        top_p: float = 0.9,
        max_tokens: int = 2048,
        model_preferences: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        stop_sequences: list[str] | None = None,
    ) -> str:
        """Wrap ``user_prompt`` in a single user message and sample it."""
        prompt_message = SamplingMessage(
            role="user",
            content=TextContent(type="text", text=user_prompt),
        )

        meta = {} if metadata is None else dict(metadata)
        if media:
            meta["modality"] = "multimodal"
            # Critical: pass media paths and timestamps through transparently.
            meta["media"] = media
        else:
            meta["modality"] = "text"

        return await self._sampler.sampling(
            system_prompt=system_prompt,
            messages=[prompt_message],
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            model_preferences=model_preferences,
            metadata=meta,
            stop_sequences=stop_sequences,
        )

def make_llm(mcp_ctx: Context[ServerSession, object]) -> LLMClient:
    """Build the default LLM client for a tool: ``llm = make_llm(ctx)``."""
    sampler = MCPSampler(mcp_ctx)
    return SamplingLLMClient(sampler)