File size: 16,679 Bytes
77169b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
"""
插件抽象与注册表:type_name -> 插件实现。

三层设计:
  AbstractPlugin   — 最底层接口,理论上支持任意协议(非 Cookie、非 SSE 的站点也能接)。
  BaseSitePlugin   — Cookie 认证 + SSE 流式站点的通用编排,插件开发者继承它只需实现 5 个 hook。
  PluginRegistry   — 全局注册表。
"""

import json
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, AsyncIterator

from playwright.async_api import BrowserContext, Page

from core.api.schemas import InputAttachment
from core.config.settings import get
from core.plugin.errors import (  # noqa: F401  — re-export for backward compat
    AccountFrozenError,
    BrowserResourceInvalidError,
)
from core.plugin.helpers import (
    apply_cookie_auth,
    create_page_for_site,
    stream_completion_via_sse,
)

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class ResolvedModel:
    """Immutable pairing of a public model id with its upstream model id."""

    # Model id as listed in a plugin's model mapping keys.
    public_model: str
    # Model id the public id maps to in that mapping.
    upstream_model: str


# ---------------------------------------------------------------------------
# SiteConfig:纯声明式站点配置
# ---------------------------------------------------------------------------


@dataclass
class SiteConfig:
    """Declarative configuration for a cookie-authenticated site.

    Plugin authors only fill in the fields; no methods need to be written.
    """

    # Default page URL and API base URL for the site.
    start_url: str
    api_base: str
    # Cookie name and cookie domain handed to the cookie-auth helper.
    cookie_name: str
    cookie_domain: str
    # Passed to the cookie-auth helper together with the auth dict;
    # presumably the auth-dict keys that may hold the cookie value — confirm
    # against apply_cookie_auth.
    auth_keys: list[str]
    # Section name in config.yaml (e.g. "claude") used to override
    # start_url / api_base; an empty string means no override.
    config_section: str = ""


# ---------------------------------------------------------------------------
# AbstractPlugin — 最底层抽象接口
# ---------------------------------------------------------------------------


class AbstractPlugin(ABC):
    """
    Lowest-level plugin interface.

    Each provider type (e.g. claude, kimi) implements this interface and
    registers an instance. For sites built on cookie auth + SSE streaming,
    subclassing BaseSitePlugin is recommended over subclassing this directly.
    """

    # Provider identifier; concrete plugins are expected to set it.
    type_name: str

    def __init__(self) -> None:
        # Per-session state, keyed by session id.
        self._session_state: dict[str, dict[str, Any]] = {}

    # ---- lifecycle operations that concrete plugins must provide ----

    async def create_page(
        self, context: BrowserContext, reuse_page: Page | None = None
    ) -> Page:
        """Return a page for this site. Not implemented at this level."""
        raise NotImplementedError

    async def apply_auth(
        self,
        context: BrowserContext,
        page: Page,
        auth: dict[str, Any],
        *,
        reload: bool = True,
        **kwargs: Any,
    ) -> None:
        """Apply credentials to the context/page. Not implemented at this level."""
        raise NotImplementedError

    async def ensure_request_ready(
        self,
        context: BrowserContext,
        page: Page,
        *,
        request_id: str = "",
        session_id: str | None = None,
        phase: str = "",
        account_id: str = "",
    ) -> None:
        """Optional pre-request hook; the default ignores its arguments and does nothing."""
        return None

    async def create_conversation(
        self,
        context: BrowserContext,
        page: Page,
        **kwargs: Any,
    ) -> str | None:
        """Create a conversation and return its id. Not implemented at this level."""
        raise NotImplementedError

    async def stream_completion(
        self,
        context: BrowserContext,
        page: Page,
        session_id: str,
        message: str,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        """Stream completion text chunks. Not implemented at this level.

        Raises:
            NotImplementedError: on the first iteration step.
        """
        raise NotImplementedError
        # Unreachable on purpose: the mere presence of ``yield`` makes this an
        # async generator, matching subclasses so ``async for`` works uniformly.
        yield

    # ---- simple overridable defaults ----

    def parse_session_id(self, messages: list[dict[str, Any]]) -> str | None:
        """Extract a session id from messages; the default finds none."""
        return None

    def is_stream_end_event(self, payload: str) -> bool:
        """Tell whether a stream payload marks a normal end of the response.

        The default recognises no payload as terminal.
        """
        return False

    # ---- session bookkeeping ----

    def has_session(self, session_id: str) -> bool:
        """Return True when state is stored for ``session_id``."""
        return session_id in self._session_state

    def drop_session(self, session_id: str) -> None:
        """Forget one session; unknown ids are ignored."""
        self._session_state.pop(session_id, None)

    def drop_sessions(self, session_ids: list[str] | set[str]) -> None:
        """Forget several sessions; unknown ids are ignored."""
        forget = self._session_state.pop
        for sid in session_ids:
            forget(sid, None)

    # ---- model mapping ----

    def model_mapping(self) -> dict[str, str] | None:
        """Return the raw public -> upstream model mapping, or None.

        Subclasses may override; the default returns None.
        """
        return None

    def normalized_model_mapping(self) -> dict[str, str]:
        """Return the model mapping with stripped string ids and blanks removed.

        Raises:
            ValueError: when ``model_mapping()`` is not a non-empty dict, or
                when no entry survives the clean-up.
        """
        raw = self.model_mapping()
        if not (isinstance(raw, dict) and raw):
            raise ValueError("model_mapping is not implemented")
        stripped_pairs = (
            (str(public or "").strip(), str(upstream or "").strip())
            for public, upstream in raw.items()
        )
        # Keep only entries where both ids are non-empty after stripping.
        cleaned: dict[str, str] = {
            public: upstream for public, upstream in stripped_pairs if public and upstream
        }
        if not cleaned:
            raise ValueError("model_mapping is not implemented")
        return cleaned

    def listed_model_mapping(self) -> dict[str, str]:
        """Return the mapping to list publicly; defaults to the normalized one."""
        return self.normalized_model_mapping()

    def default_public_model(self) -> str:
        """Return the first public model id of the listed mapping.

        Falls back to the normalized mapping when the listed one is empty.
        """
        candidates = self.listed_model_mapping() or self.normalized_model_mapping()
        return next(iter(candidates))

    def resolve_model(self, model: str | None) -> ResolvedModel:
        """Resolve a requested model name into a public/upstream pair.

        A blank request resolves to the default public model. Otherwise the
        name is matched against public ids first, then against upstream ids.

        Raises:
            ValueError: when the name matches neither a public nor an upstream id.
        """
        table = self.normalized_model_mapping()
        wanted = str(model or "").strip()

        if not wanted:
            fallback = self.default_public_model()
            return ResolvedModel(
                public_model=fallback,
                upstream_model=table[fallback],
            )

        if wanted in table:
            return ResolvedModel(
                public_model=wanted,
                upstream_model=table[wanted],
            )

        # Reverse lookup: the caller may have passed an upstream id directly.
        hit = next(
            ((public, upstream) for public, upstream in table.items() if upstream == wanted),
            None,
        )
        if hit is not None:
            public_id, upstream_id = hit
            return ResolvedModel(
                public_model=public_id,
                upstream_model=upstream_id,
            )

        supported = ", ".join(table)
        raise ValueError(f"Unknown model: {wanted}. Supported models: {supported}")

    # ---- transport ----

    def on_http_error(self, message: str, headers: dict[str, str] | None) -> int | None:
        """HTTP error hook; the default returns None."""
        return None

    def stream_transport(self) -> str:
        """Return the name of the stream transport; defaults to ``"page_fetch"``."""
        return "page_fetch"

    def stream_transport_options(
        self,
        context: BrowserContext,
        page: Page,
        session_id: str,
        state: dict[str, Any],
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Build transport options from the ``proxy_url`` / ``proxy_auth`` kwargs.

        ``proxy_url`` is included when non-blank after stripping; ``proxy_auth``
        is included only when it is a 2-tuple. The positional arguments are
        unused by this default implementation.
        """
        selected: dict[str, Any] = {}
        url = str(kwargs.get("proxy_url") or "").strip()
        if url:
            selected["proxy_url"] = url
        credentials = kwargs.get("proxy_auth")
        if isinstance(credentials, tuple) and len(credentials) == 2:
            selected["proxy_auth"] = credentials
        return selected


# ---------------------------------------------------------------------------
# BaseSitePlugin — Cookie + SSE 站点的通用编排
# ---------------------------------------------------------------------------


class BaseSitePlugin(AbstractPlugin):
    """
    Shared base class for sites that use cookie authentication and SSE streaming.

    A plugin author subclasses this and only needs to:
      1. declare ``site = SiteConfig(...)``         — site configuration
      2. implement ``fetch_site_context()``         — fetch site context (e.g. org/user info)
      3. implement ``create_session()``             — call the site API to create a session
      4. implement ``build_completion_url/body()``  — build the completion request URL and body
      5. implement ``parse_stream_event()``         — parse one streaming event (e.g. SSE data)

    create_page / apply_auth / create_conversation / stream_completion are
    orchestrated by this base class and normally need no override.
    """

    site: SiteConfig  # subclasses must assign this

    # ---- URL properties; config.yaml overrides the defaults when config_section is set ----

    @property
    def start_url(self) -> str:
        """Start URL: the config override when present, else ``site.start_url``."""
        if self.site.config_section:
            url = get(self.site.config_section, "start_url")
            if url:
                return str(url).strip()
        return self.site.start_url

    @property
    def api_base(self) -> str:
        """API base URL: the config override when present, else ``site.api_base``."""
        if self.site.config_section:
            base = get(self.site.config_section, "api_base")
            if base:
                return str(base).strip()
        return self.site.api_base

    def model_mapping(self) -> dict[str, str] | None:
        """Read ``model_mapping`` from the plugin's config section.

        Returns:
            The mapping with keys and values converted to ``str``, or None
            when nothing usable is configured.
        """
        if self.site.config_section:
            configured = get(self.site.config_section, "model_mapping")
            if isinstance(configured, dict) and configured:
                # Skip null entries instead of stringifying them: str(None)
                # would yield the literal "None", which then passes the
                # blank-entry filter in normalized_model_mapping().
                mapping = {
                    str(public): str(upstream)
                    for public, upstream in configured.items()
                    if public is not None and upstream is not None
                }
                if mapping:
                    return mapping
        return None

    # ---- fully implemented by the base class; subclasses need not touch these ----

    async def create_page(
        self,
        context: BrowserContext,
        reuse_page: Page | None = None,
    ) -> Page:
        """Create (or reuse) a page for this site's start URL."""
        return await create_page_for_site(
            context, self.start_url, reuse_page=reuse_page
        )

    async def apply_auth(
        self,
        context: BrowserContext,
        page: Page,
        auth: dict[str, Any],
        *,
        reload: bool = True,
        **kwargs: Any,
    ) -> None:
        """Apply cookie authentication using the declared site configuration.

        Extra keyword arguments are accepted so the signature stays compatible
        with ``AbstractPlugin.apply_auth``; they are not used here.
        """
        del kwargs
        await apply_cookie_auth(
            context,
            page,
            auth,
            self.site.cookie_name,
            self.site.auth_keys,
            self.site.cookie_domain,
            reload=reload,
        )

    async def create_conversation(
        self,
        context: BrowserContext,
        page: Page,
        **kwargs: Any,
    ) -> str | None:
        """Fetch the site context, create a session and record its state.

        ``request_id`` is taken out of ``kwargs`` and passed explicitly; the
        remaining kwargs are forwarded to ``create_session``. When present,
        ``timezone``, ``public_model`` and ``upstream_model`` are stored in
        the session state.

        Returns:
            The new conversation id, or None when the site context or the
            session could not be obtained.
        """
        extra_kwargs = dict(kwargs)
        request_id = str(extra_kwargs.pop("request_id", "") or "")
        # Ask the subclass for the site context.
        site_context = await self.fetch_site_context(
            context,
            page,
            request_id=request_id,
        )
        if site_context is None:
            logger.warning(
                "[%s] fetch_site_context 返回 None,请确认已登录", self.type_name
            )
            return None
        # Create the session from the site context.
        conv_id = await self.create_session(
            context,
            page,
            site_context,
            request_id=request_id,
            **extra_kwargs,
        )
        if conv_id is None:
            return None
        state: dict[str, Any] = {"site_context": site_context}
        if kwargs.get("timezone") is not None:
            state["timezone"] = kwargs["timezone"]
        public_model = str(kwargs.get("public_model") or "").strip()
        if public_model:
            state["public_model"] = public_model
        upstream_model = str(kwargs.get("upstream_model") or "").strip()
        if upstream_model:
            state["upstream_model"] = upstream_model
        self._session_state[conv_id] = state
        logger.info(
            "[%s] create_conversation done conv_id=%s sessions=%s",
            self.type_name,
            conv_id,
            list(self._session_state.keys()),
        )
        return conv_id

    async def stream_completion(
        self,
        context: BrowserContext,
        page: Page,
        session_id: str,
        message: str,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        """Stream completion text for an existing session.

        Builds the URL and JSON body through the subclass hooks, prepares
        attachments, then yields every chunk produced by the SSE helper.
        After the stream ends, ``on_stream_completion_finished`` is called if
        message ids were collected and the session still exists.

        Raises:
            RuntimeError: when no state is stored for ``session_id``.
        """
        state = self._session_state.get(session_id)
        if not state:
            raise RuntimeError(f"未知会话 ID: {session_id}")

        # Normalise like create_conversation does, so a missing or None
        # request_id never reaches the hooks or the SSE helper as a non-str.
        request_id = str(kwargs.get("request_id") or "")
        url = self.build_completion_url(session_id, state)
        attachments = list(kwargs.get("attachments") or [])
        prepared_attachments = await self.prepare_attachments(
            context,
            page,
            session_id,
            state,
            attachments,
            request_id=request_id,
        )
        body = self.build_completion_body(
            message,
            session_id,
            state,
            prepared_attachments,
        )
        body_json = json.dumps(body)

        logger.info(
            "[%s] stream_completion session_id=%s url=%s",
            self.type_name,
            session_id,
            url,
        )

        # Handed to the SSE helper, presumably so it can append message ids
        # as they arrive — confirm against stream_completion_via_sse.
        out_message_ids: list[str] = []
        transport_options = self.stream_transport_options(
            context,
            page,
            session_id,
            state,
            request_id=request_id,
            attachments=attachments,
            proxy_url=kwargs.get("proxy_url"),
            proxy_auth=kwargs.get("proxy_auth"),
        )

        async for text in stream_completion_via_sse(
            context,
            page,
            url,
            body_json,
            self.parse_stream_event,
            request_id,
            on_http_error=self.on_http_error,
            is_terminal_event=self.is_stream_end_event,
            collect_message_id=out_message_ids,
            transport=self.stream_transport(),
            transport_options=transport_options,
        ):
            yield text

        # The session may have been dropped while streaming; only notify the
        # hook when its state is still there.
        if out_message_ids and session_id in self._session_state:
            self.on_stream_completion_finished(session_id, out_message_ids)

    # ---- hooks that subclasses must implement ----

    @abstractmethod
    async def fetch_site_context(
        self,
        context: BrowserContext,
        page: Page,
        request_id: str = "",
    ) -> dict[str, Any] | None:
        """Fetch site context (e.g. org_uuid, user_id); return None on failure."""

    @abstractmethod
    async def create_session(
        self,
        context: BrowserContext,
        page: Page,
        site_context: dict[str, Any],
        **kwargs: Any,
    ) -> str | None:
        """Call the site API to create a session; return its id, or None on failure."""

    @abstractmethod
    def build_completion_url(self, session_id: str, state: dict[str, Any]) -> str:
        """Build the full completion request URL from the session state."""

    @abstractmethod
    def build_completion_body(
        self,
        message: str,
        session_id: str,
        state: dict[str, Any],
        prepared_attachments: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Build the completion request body as a dict (the base class does json.dumps)."""

    @abstractmethod
    def parse_stream_event(
        self,
        payload: str,
    ) -> tuple[list[str], str | None, str | None]:
        """
        Parse a single streaming event payload (e.g. an SSE data line).

        Returns (texts, message_id, error_message).
        """

    # ---- optional hooks with sensible defaults ----

    def on_stream_completion_finished(
        self,
        session_id: str,
        message_ids: list[str],
    ) -> None:
        """Hook called after a streamed completion ends.

        Subclasses may use ``message_ids`` to update the session state (e.g.
        remember the parent message id for a follow-up). The default does nothing.
        """

    async def prepare_attachments(
        self,
        context: BrowserContext,
        page: Page,
        session_id: str,
        state: dict[str, Any],
        attachments: list[InputAttachment],
        request_id: str = "",
    ) -> dict[str, Any]:
        """Prepare attachments for the completion body; the default returns an empty dict."""
        del context, page, session_id, state, attachments, request_id
        return {}


# ---------------------------------------------------------------------------
# PluginRegistry — 全局注册表
# ---------------------------------------------------------------------------


class PluginRegistry:
    """Global plugin registry mapping ``type_name`` to a plugin instance."""

    # Shared at class level on purpose: one registry for the whole process.
    _plugins: dict[str, AbstractPlugin] = {}

    @classmethod
    def register(cls, plugin: AbstractPlugin) -> None:
        """Store ``plugin`` under its ``type_name``, replacing any previous entry."""
        cls._plugins[plugin.type_name] = plugin

    @classmethod
    def get(cls, type_name: str) -> AbstractPlugin | None:
        """Return the plugin registered for ``type_name``, or None."""
        return cls._plugins.get(type_name)

    @classmethod
    def _require(cls, type_name: str) -> AbstractPlugin:
        """Return the plugin for ``type_name`` or raise ValueError when absent."""
        found = cls.get(type_name)
        if found is None:
            raise ValueError(f"Unknown provider: {type_name}")
        return found

    @classmethod
    def resolve_model(cls, type_name: str, model: str | None) -> ResolvedModel:
        """Resolve ``model`` through the plugin registered for ``type_name``.

        Raises:
            ValueError: when no plugin is registered for ``type_name``.
        """
        return cls._require(type_name).resolve_model(model)

    @classmethod
    def model_metadata(cls, type_name: str) -> dict[str, Any]:
        """Describe the models exposed by the plugin registered for ``type_name``.

        Raises:
            ValueError: when no plugin is registered for ``type_name``.
        """
        provider = cls._require(type_name)
        listed = provider.listed_model_mapping()
        metadata: dict[str, Any] = {"provider": type_name}
        metadata["public_models"] = list(listed)
        metadata["model_mapping"] = listed
        metadata["default_model"] = provider.default_public_model()
        return metadata

    @classmethod
    def all_types(cls) -> list[str]:
        """Return every registered ``type_name`` in registration order."""
        return [*cls._plugins]