# web2api/core/plugin/base.py
# (hosting-page residue) ohmyapi's picture
# feat: align hosted Space deployment with latest upstream
# 77169b4
"""
插件抽象与注册表:type_name -> 插件实现。
三层设计:
AbstractPlugin — 最底层接口,理论上支持任意协议(非 Cookie、非 SSE 的站点也能接)。
BaseSitePlugin — Cookie 认证 + SSE 流式站点的通用编排,插件开发者继承它只需实现 5 个 hook。
PluginRegistry — 全局注册表。
"""
import json
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, AsyncIterator
from playwright.async_api import BrowserContext, Page
from core.api.schemas import InputAttachment
from core.config.settings import get
from core.plugin.errors import ( # noqa: F401 — re-export for backward compat
AccountFrozenError,
BrowserResourceInvalidError,
)
from core.plugin.helpers import (
apply_cookie_auth,
create_page_for_site,
stream_completion_via_sse,
)
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class ResolvedModel:
    """Immutable (public model id, upstream model id) pair.

    Produced by a plugin's ``resolve_model``: ``public_model`` is a key of the
    plugin's model mapping and ``upstream_model`` is the value stored under it.
    """

    # Key side of the model mapping.
    public_model: str
    # Value side of the model mapping.
    upstream_model: str
# ---------------------------------------------------------------------------
# SiteConfig:纯声明式站点配置
# ---------------------------------------------------------------------------
@dataclass
class SiteConfig:
    """Declarative configuration for a cookie-authenticated site.

    Plugin authors only fill in the fields; the class defines no methods.
    """

    # Default start URL; BaseSitePlugin.start_url returns it unless the
    # config section supplies an override.
    start_url: str
    # Default API base; BaseSitePlugin.api_base returns it unless the config
    # section supplies an override.
    api_base: str
    # Cookie name handed to apply_cookie_auth by BaseSitePlugin.apply_auth.
    cookie_name: str
    # Cookie domain handed to apply_cookie_auth by BaseSitePlugin.apply_auth.
    cookie_domain: str
    # Key names handed to apply_cookie_auth together with the auth dict.
    auth_keys: list[str]
    # Section name in config.yaml (e.g. "claude") whose values override
    # start_url / api_base; an empty string disables the lookup.
    config_section: str = ""
# ---------------------------------------------------------------------------
# AbstractPlugin — 最底层抽象接口
# ---------------------------------------------------------------------------
class AbstractPlugin(ABC):
"""
各 type(如 claude、kimi)需实现此接口并注册。
若站点基于 Cookie + SSE,推荐直接继承 BaseSitePlugin 而非此类。
"""
def __init__(self) -> None:
self._session_state: dict[str, dict[str, Any]] = {}
type_name: str
async def create_page(
self, context: BrowserContext, reuse_page: Page | None = None
) -> Page:
raise NotImplementedError
async def apply_auth(
self,
context: BrowserContext,
page: Page,
auth: dict[str, Any],
*,
reload: bool = True,
**kwargs: Any,
) -> None:
raise NotImplementedError
async def ensure_request_ready(
self,
context: BrowserContext,
page: Page,
*,
request_id: str = "",
session_id: str | None = None,
phase: str = "",
account_id: str = "",
) -> None:
del context, page, request_id, session_id, phase, account_id
async def create_conversation(
self,
context: BrowserContext,
page: Page,
**kwargs: Any,
) -> str | None:
raise NotImplementedError
async def stream_completion(
self,
context: BrowserContext,
page: Page,
session_id: str,
message: str,
**kwargs: Any,
) -> AsyncIterator[str]:
if False:
yield # 使抽象方法为 async generator,与子类一致,便于 async for 迭代
raise NotImplementedError
def parse_session_id(self, messages: list[dict[str, Any]]) -> str | None:
return None
def is_stream_end_event(self, payload: str) -> bool:
"""判断某条流式 payload 是否表示本轮响应已正常结束。默认不识别。"""
return False
def has_session(self, session_id: str) -> bool:
return session_id in self._session_state
def drop_session(self, session_id: str) -> None:
self._session_state.pop(session_id, None)
def drop_sessions(self, session_ids: list[str] | set[str]) -> None:
for session_id in session_ids:
self._session_state.pop(session_id, None)
def model_mapping(self) -> dict[str, str] | None:
"""子类可覆盖;BaseSitePlugin 会从 config_section 的 model_mapping 读取。"""
return None
def normalized_model_mapping(self) -> dict[str, str]:
mapping = self.model_mapping()
if not isinstance(mapping, dict) or not mapping:
raise ValueError("model_mapping is not implemented")
normalized: dict[str, str] = {}
for public_model, upstream_model in mapping.items():
public_id = str(public_model or "").strip()
upstream_id = str(upstream_model or "").strip()
if public_id and upstream_id:
normalized[public_id] = upstream_id
if not normalized:
raise ValueError("model_mapping is not implemented")
return normalized
def listed_model_mapping(self) -> dict[str, str]:
return self.normalized_model_mapping()
def default_public_model(self) -> str:
listed = self.listed_model_mapping()
if listed:
return next(iter(listed))
return next(iter(self.normalized_model_mapping()))
def resolve_model(self, model: str | None) -> ResolvedModel:
mapping = self.normalized_model_mapping()
requested = str(model or "").strip()
if not requested:
default_public = self.default_public_model()
return ResolvedModel(
public_model=default_public,
upstream_model=mapping[default_public],
)
if requested in mapping:
return ResolvedModel(
public_model=requested,
upstream_model=mapping[requested],
)
for public_model, upstream_model in mapping.items():
if requested == upstream_model:
return ResolvedModel(
public_model=public_model,
upstream_model=upstream_model,
)
supported = ", ".join(mapping.keys())
raise ValueError(f"Unknown model: {requested}. Supported models: {supported}")
def on_http_error(self, message: str, headers: dict[str, str] | None) -> int | None:
return None
def stream_transport(self) -> str:
return "page_fetch"
def stream_transport_options(
self,
context: BrowserContext,
page: Page,
session_id: str,
state: dict[str, Any],
**kwargs: Any,
) -> dict[str, Any]:
del context, page, session_id, state
options: dict[str, Any] = {}
proxy_url = str(kwargs.get("proxy_url") or "").strip()
proxy_auth = kwargs.get("proxy_auth")
if proxy_url:
options["proxy_url"] = proxy_url
if isinstance(proxy_auth, tuple) and len(proxy_auth) == 2:
options["proxy_auth"] = proxy_auth
return options
# ---------------------------------------------------------------------------
# BaseSitePlugin — Cookie + SSE 站点的通用编排
# ---------------------------------------------------------------------------
class BaseSitePlugin(AbstractPlugin):
    """
    Shared base class for sites using cookie authentication and SSE streaming.

    A plugin author subclasses this and only needs to:
      1. declare ``site = SiteConfig(...)``          - site configuration
      2. implement ``fetch_site_context()``           - site context (e.g. org/user info)
      3. implement ``create_session()``               - create a session via the site API
      4. implement ``build_completion_url/body()``    - completion request URL and body
      5. implement ``parse_stream_event()``           - parse one streamed event

    ``create_page`` / ``apply_auth`` / ``create_conversation`` /
    ``stream_completion`` are orchestrated here and normally need no override.
    """

    site: SiteConfig  # subclasses must assign this

    # ---- URL properties, overridable from config.yaml via config_section ----

    def _config_override(self, key: str, default: str) -> str:
        """Return the non-blank config value for ``key``, else ``default``."""
        section = self.site.config_section
        if section:
            value = get(section, key)
            if value:
                text = str(value).strip()
                # A whitespace-only override is treated as "not configured"
                # instead of yielding an empty URL.
                if text:
                    return text
        return default

    @property
    def start_url(self) -> str:
        """Start URL: the config override when present, else ``site.start_url``."""
        return self._config_override("start_url", self.site.start_url)

    @property
    def api_base(self) -> str:
        """API base: the config override when present, else ``site.api_base``."""
        return self._config_override("api_base", self.site.api_base)

    def model_mapping(self) -> dict[str, str] | None:
        """Read ``model_mapping`` from the config section; None when not configured."""
        section = self.site.config_section
        if not section:
            return None
        raw = get(section, "model_mapping")
        if not isinstance(raw, dict):
            return None
        # Skip null keys/values: str(None) would produce the literal "None",
        # which normalized_model_mapping() would then accept as a model id.
        mapping = {
            str(public): str(upstream)
            for public, upstream in raw.items()
            if public is not None and upstream is not None
        }
        return mapping or None

    # ---- Fully implemented by the base class ----

    async def create_page(
        self,
        context: BrowserContext,
        reuse_page: Page | None = None,
    ) -> Page:
        """Delegate to ``create_page_for_site`` with this plugin's ``start_url``."""
        return await create_page_for_site(
            context, self.start_url, reuse_page=reuse_page
        )

    async def apply_auth(
        self,
        context: BrowserContext,
        page: Page,
        auth: dict[str, Any],
        *,
        reload: bool = True,
        **kwargs: Any,
    ) -> None:
        """Delegate to ``apply_cookie_auth`` using the cookie fields of ``site``.

        Extra keyword arguments are accepted, matching
        ``AbstractPlugin.apply_auth``, and ignored, so callers written against
        the abstract signature do not fail with ``TypeError``.
        """
        del kwargs
        await apply_cookie_auth(
            context,
            page,
            auth,
            self.site.cookie_name,
            self.site.auth_keys,
            self.site.cookie_domain,
            reload=reload,
        )

    async def create_conversation(
        self,
        context: BrowserContext,
        page: Page,
        **kwargs: Any,
    ) -> str | None:
        """Fetch the site context, create a session and record its state.

        ``request_id`` is taken out of ``kwargs`` and passed explicitly; the
        remaining kwargs are forwarded to ``create_session``.

        Returns:
            The new conversation id, or None when either hook returns None.
        """
        extra_kwargs = dict(kwargs)
        request_id = str(extra_kwargs.pop("request_id", "") or "")

        site_context = await self.fetch_site_context(
            context,
            page,
            request_id=request_id,
        )
        if site_context is None:
            logger.warning(
                "[%s] fetch_site_context 返回 None,请确认已登录", self.type_name
            )
            return None

        conv_id = await self.create_session(
            context,
            page,
            site_context,
            request_id=request_id,
            **extra_kwargs,
        )
        if conv_id is None:
            return None

        # Keep what later completions need: the site context plus the optional
        # timezone and model ids supplied by the caller.
        state: dict[str, Any] = {"site_context": site_context}
        if kwargs.get("timezone") is not None:
            state["timezone"] = kwargs["timezone"]
        public_model = str(kwargs.get("public_model") or "").strip()
        if public_model:
            state["public_model"] = public_model
        upstream_model = str(kwargs.get("upstream_model") or "").strip()
        if upstream_model:
            state["upstream_model"] = upstream_model
        self._session_state[conv_id] = state

        logger.info(
            "[%s] create_conversation done conv_id=%s sessions=%s",
            self.type_name,
            conv_id,
            list(self._session_state.keys()),
        )
        return conv_id

    async def stream_completion(
        self,
        context: BrowserContext,
        page: Page,
        session_id: str,
        message: str,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        """Stream completion text for an existing session.

        Builds the URL and JSON body through the subclass hooks, then yields
        whatever ``stream_completion_via_sse`` yields.

        Raises:
            RuntimeError: if no state is stored for ``session_id``.
        """
        state = self._session_state.get(session_id)
        if not state:
            raise RuntimeError(f"未知会话 ID: {session_id}")

        # Normalise like create_conversation does, so an explicit
        # request_id=None never reaches the helpers as a non-string.
        request_id = str(kwargs.get("request_id") or "")
        url = self.build_completion_url(session_id, state)
        attachments = list(kwargs.get("attachments") or [])
        prepared_attachments = await self.prepare_attachments(
            context,
            page,
            session_id,
            state,
            attachments,
            request_id=request_id,
        )
        body = self.build_completion_body(
            message,
            session_id,
            state,
            prepared_attachments,
        )
        body_json = json.dumps(body)
        logger.info(
            "[%s] stream_completion session_id=%s url=%s",
            self.type_name,
            session_id,
            url,
        )

        # Handed to the SSE helper as collect_message_id; presumably the
        # helper appends message ids to it (verify against core.plugin.helpers).
        out_message_ids: list[str] = []
        transport_options = self.stream_transport_options(
            context,
            page,
            session_id,
            state,
            request_id=request_id,
            attachments=attachments,
            proxy_url=kwargs.get("proxy_url"),
            proxy_auth=kwargs.get("proxy_auth"),
        )
        async for text in stream_completion_via_sse(
            context,
            page,
            url,
            body_json,
            self.parse_stream_event,
            request_id,
            on_http_error=self.on_http_error,
            is_terminal_event=self.is_stream_end_event,
            collect_message_id=out_message_ids,
            transport=self.stream_transport(),
            transport_options=transport_options,
        ):
            yield text

        # Reached only when the stream ended without raising; skipped when the
        # session was dropped in the meantime.
        if out_message_ids and session_id in self._session_state:
            self.on_stream_completion_finished(session_id, out_message_ids)

    # ---- Hooks every subclass must implement ----

    @abstractmethod
    async def fetch_site_context(
        self,
        context: BrowserContext,
        page: Page,
        request_id: str = "",
    ) -> dict[str, Any] | None:
        """Fetch site context (e.g. org_uuid, user_id); return None on failure."""
        ...

    @abstractmethod
    async def create_session(
        self,
        context: BrowserContext,
        page: Page,
        site_context: dict[str, Any],
        **kwargs: Any,
    ) -> str | None:
        """Create a session through the site API; return its id, or None on failure."""
        ...

    @abstractmethod
    def build_completion_url(self, session_id: str, state: dict[str, Any]) -> str:
        """Build the full completion request URL from the session state."""
        ...

    @abstractmethod
    def build_completion_body(
        self,
        message: str,
        session_id: str,
        state: dict[str, Any],
        prepared_attachments: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Build the completion request body as a dict (the base class JSON-encodes it)."""
        ...

    @abstractmethod
    def parse_stream_event(
        self,
        payload: str,
    ) -> tuple[list[str], str | None, str | None]:
        """
        Parse one streamed event payload (e.g. an SSE data line).

        Returns ``(texts, message_id, error_message)``.
        """
        ...

    # ---- Optional hooks with sensible defaults ----

    def on_stream_completion_finished(
        self,
        session_id: str,
        message_ids: list[str],
    ) -> None:
        """Hook called after a stream finishes.

        Subclasses may use ``message_ids`` to update session state, e.g. to
        remember the parent message id for a follow-up turn. Default: no-op.
        """

    async def prepare_attachments(
        self,
        context: BrowserContext,
        page: Page,
        session_id: str,
        state: dict[str, Any],
        attachments: list[InputAttachment],
        request_id: str = "",
    ) -> dict[str, Any]:
        """Prepare attachments before the body is built; default returns ``{}``."""
        del context, page, session_id, state, attachments, request_id
        return {}
# ---------------------------------------------------------------------------
# PluginRegistry — 全局注册表
# ---------------------------------------------------------------------------
class PluginRegistry:
"""全局插件注册表:type_name -> AbstractPlugin。"""
_plugins: dict[str, AbstractPlugin] = {}
@classmethod
def register(cls, plugin: AbstractPlugin) -> None:
cls._plugins[plugin.type_name] = plugin
@classmethod
def get(cls, type_name: str) -> AbstractPlugin | None:
return cls._plugins.get(type_name)
@classmethod
def resolve_model(cls, type_name: str, model: str | None) -> ResolvedModel:
plugin = cls.get(type_name)
if plugin is None:
raise ValueError(f"Unknown provider: {type_name}")
return plugin.resolve_model(model)
@classmethod
def model_metadata(cls, type_name: str) -> dict[str, Any]:
plugin = cls.get(type_name)
if plugin is None:
raise ValueError(f"Unknown provider: {type_name}")
mapping = plugin.listed_model_mapping()
return {
"provider": type_name,
"public_models": list(mapping.keys()),
"model_mapping": mapping,
"default_model": plugin.default_public_model(),
}
@classmethod
def all_types(cls) -> list[str]:
return list(cls._plugins.keys())