Spaces:
Paused
Paused
| from __future__ import annotations | |
| import base64 | |
| import hashlib | |
| import json | |
| import re | |
| import time | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Any, Iterable, Iterator | |
| import tiktoken | |
| from services.account_service import account_service | |
| from services.config import config | |
| from services.openai_backend_api import OpenAIBackendAPI | |
| from utils.helper import IMAGE_MODELS, extract_image_from_message_content | |
| from utils.log import logger | |
class ImageGenerationError(Exception):
    """Image pipeline failure that carries OpenAI-style error metadata.

    Attributes:
        status_code: Status code attached to the failure (502 by default).
        error_type: Value emitted under ``error.type``.
        code: Value emitted under ``error.code``.
        param: Value emitted under ``error.param``.
    """

    def __init__(
        self,
        message: str,
        status_code: int = 502,
        error_type: str = "server_error",
        code: str | None = "upstream_error",
        param: str | None = None,
    ) -> None:
        super().__init__(message)
        self.param = param
        self.code = code
        self.error_type = error_type
        self.status_code = status_code

    def to_openai_error(self) -> dict[str, Any]:
        """Return the failure wrapped as an ``{"error": {...}}`` payload."""
        details: dict[str, Any] = {}
        details["message"] = str(self)
        details["type"] = self.error_type
        details["param"] = self.param
        details["code"] = self.code
        return {"error": details}
def is_token_invalid_error(message: str) -> bool:
    """Return True when *message* contains a known token-invalidation marker.

    Matching is case-insensitive; a falsy *message* never matches.
    """
    lowered = str(message or "").lower()
    markers = (
        "token_invalidated",
        "token_revoked",
        "authentication token has been invalidated",
        "invalidated oauth token",
    )
    return any(marker in lowered for marker in markers)
def image_stream_error_message(message: str) -> str:
    """Map a raw stream error onto the text reported to the caller.

    Messages containing one of the TLS/connection markers are replaced by a
    generic retry hint; other messages pass through unchanged, and an empty
    message becomes a generic failure text.
    """
    raw = str(message or "")
    lowered = raw.lower()
    connection_markers = ("curl: (35)", "tls connect error", "openssl_internal")
    if any(marker in lowered for marker in connection_markers):
        return "upstream image connection failed, please retry later"
    if raw:
        return raw
    return "image generation failed"
def encode_images(images: Iterable[tuple[bytes, str, str]]) -> list[str]:
    """Base64-encode the payload of every 3-tuple in *images*.

    Only the first element of each tuple is used; entries whose payload is
    empty are skipped.
    """
    encoded: list[str] = []
    for payload, _second, _third in images:
        if not payload:
            continue
        encoded.append(base64.b64encode(payload).decode("ascii"))
    return encoded
def save_image_bytes(image_data: bytes, base_url: str | None = None) -> str:
    """Write *image_data* under the configured images directory and return its URL.

    The file is stored as ``<images_dir>/<YYYY>/<MM>/<DD>/<epoch>_<md5>.png``.
    ``config.cleanup_old_images()`` is invoked before writing.

    Args:
        image_data: Raw image bytes to persist.
        base_url: URL prefix for the returned link; ``config.base_url`` is
            used when this is falsy.

    Returns:
        ``<base>/images/<YYYY>/<MM>/<DD>/<filename>``.
    """
    config.cleanup_old_images()
    # Read the clock once so the file name and the year/month/day directory
    # always describe the same instant, even across a midnight/year boundary.
    timestamp = time.time()
    local_time = time.localtime(timestamp)
    # The MD5 digest only disambiguates file names; it is not a security measure.
    file_hash = hashlib.md5(image_data).hexdigest()
    filename = f"{int(timestamp)}_{file_hash}.png"
    relative_dir = Path(
        time.strftime("%Y", local_time),
        time.strftime("%m", local_time),
        time.strftime("%d", local_time),
    )
    file_path = config.images_dir / relative_dir / filename
    file_path.parent.mkdir(parents=True, exist_ok=True)
    file_path.write_bytes(image_data)
    return f"{(base_url or config.base_url)}/images/{relative_dir.as_posix()}/{filename}"
def message_text(content: Any) -> str:
    """Flatten message content into plain text.

    A string is returned as-is. For a list, bare strings and dict parts whose
    ``type`` is ``text``, ``input_text`` or ``output_text`` are concatenated
    without a separator. Any other content yields an empty string.
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return ""
    text_types = {"text", "input_text", "output_text"}
    pieces: list[str] = []
    for entry in content:
        if isinstance(entry, str):
            pieces.append(entry)
            continue
        if isinstance(entry, dict) and str(entry.get("type") or "") in text_types:
            pieces.append(str(entry.get("text") or ""))
    return "".join(pieces)
def normalize_messages(messages: object, system: Any = None) -> list[dict[str, Any]]:
    """Convert loosely-typed chat messages into ``{"role", "content"}`` dicts.

    The configured global system prompt (if any) comes first, followed by the
    text of *system* (if any). Each dict in *messages* then becomes one entry:
    plain text when no images were collected, otherwise a list of a text part
    (when there is text) plus one ``image`` part per attachment. Non-dict
    entries are skipped; a non-list *messages* contributes nothing.
    """
    result: list[dict[str, Any]] = []
    if config.global_system_prompt:
        result.append({"role": "system", "content": config.global_system_prompt})
    system_prompt = message_text(system)
    if system_prompt:
        result.append({"role": "system", "content": system_prompt})
    if not isinstance(messages, list):
        return result

    for entry in messages:
        if not isinstance(entry, dict):
            continue
        role = entry.get("role", "user")
        content = entry.get("content", "")
        text = message_text(content)
        attachments: list[tuple[bytes, str]] = []
        # NOTE(review): this nesting was reconstructed from a flattened source —
        # confirm that already-decoded ``image`` parts are only collected for
        # user messages.
        if role == "user":
            attachments.extend(extract_image_from_message_content(content))
            if isinstance(content, list):
                for part in content:
                    if isinstance(part, dict) and part.get("type") == "image":
                        blob = part.get("data")
                        if isinstance(blob, (bytes, bytearray)):
                            attachments.append((bytes(blob), str(part.get("mime") or "image/png")))
        if not attachments:
            result.append({"role": role, "content": text})
            continue
        parts: list[Any] = [{"type": "text", "text": text}] if text else []
        parts.extend({"type": "image", "data": blob, "mime": mime} for blob, mime in attachments)
        result.append({"role": role, "content": parts})
    return result
def prompt_with_global_system(prompt: str) -> str:
    """Prefix *prompt* with the configured global system prompt, if one is set.

    The two are separated by a blank line; without a global prompt the input
    is returned unchanged.
    """
    if not config.global_system_prompt:
        return prompt
    return f"{config.global_system_prompt}\n\n{prompt}"
def assistant_history_text(messages: list[dict[str, Any]]) -> str:
    """Concatenate the content of every assistant message, in order."""
    pieces: list[str] = []
    for entry in messages:
        if entry.get("role") != "assistant":
            continue
        pieces.append(str(entry.get("content") or ""))
    return "".join(pieces)
def assistant_history_messages(messages: list[dict[str, Any]]) -> list[str]:
    """Return the stringified content of each assistant message with truthy content."""
    collected: list[str] = []
    for entry in messages:
        if entry.get("role") != "assistant":
            continue
        if not entry.get("content"):
            continue
        collected.append(str(entry.get("content") or ""))
    return collected
def build_image_prompt(prompt: str, size: str | None) -> str:
    """Append an aspect-ratio instruction for *size* to *prompt*.

    With a falsy *size* the prompt is returned untouched. For the known
    ratios a dedicated hint is appended; any other value is embedded in a
    generic aspect-ratio sentence. Whenever a hint is added the prompt is
    stripped first and the hint follows after a blank line.
    """
    if not size:
        return prompt
    hints = {
        "1:1": "输出为 1:1 正方形构图,主体居中,适合正方形画幅。",
        "16:9": "输出为 16:9 横屏构图,适合宽画幅展示。",
        "9:16": "输出为 9:16 竖屏构图,适合竖版画幅展示。",
        "4:3": "输出为 4:3 比例,兼顾宽度与高度,适合展示画面细节。",
        "3:4": "输出为 3:4 比例,纵向构图,适合人物肖像或竖向场景。",
    }
    if size in hints:
        return f"{prompt.strip()}\n\n{hints[size]}"
    return f"{prompt.strip()}\n\n输出图片,宽高比为 {size}。"
def encoding_for_model(model: str):
    """Return the tiktoken encoding for *model*, with generic fallbacks.

    When tiktoken has no mapping for *model*, ``o200k_base`` is tried; if the
    installed tiktoken does not provide that encoding, ``cl100k_base`` is
    used instead.
    """
    try:
        return tiktoken.encoding_for_model(model)
    except KeyError:
        # tiktoken has no mapping for this model name; use a generic encoding.
        pass
    try:
        return tiktoken.get_encoding("o200k_base")
    except (KeyError, ValueError):
        # ``tiktoken.get_encoding`` reports an unknown encoding name with
        # ValueError, so catching KeyError alone never reached this fallback.
        return tiktoken.get_encoding("cl100k_base")
def count_message_tokens(messages: list[dict[str, Any]], model: str) -> int:
    """Estimate the token count of a chat message list for *model*.

    Every message contributes a fixed 3 tokens plus the encoded length of
    each string-valued field; a ``name`` field adds one more token. A final
    3 tokens are added to the total. Non-string values are ignored.
    """
    encoder = encoding_for_model(model)
    token_count = 3  # fixed amount added once to the overall total
    for message in messages:
        token_count += 3
        for key, value in message.items():
            if isinstance(value, str):
                token_count += len(encoder.encode(value))
                if key == "name":
                    token_count += 1
    return token_count
def count_text_tokens(text: str, model: str) -> int:
    """Return how many tokens *text* encodes to with the encoding chosen for *model*."""
    encoder = encoding_for_model(model)
    tokens = encoder.encode(text)
    return len(tokens)
def format_image_result(
    items: list[dict[str, Any]],
    prompt: str,
    response_format: str,
    base_url: str | None = None,
    created: int | None = None,
    message: str = "",
) -> dict[str, Any]:
    """Build an images-response dict from base64 image items.

    Each item with a non-blank ``b64_json`` is saved via ``save_image_bytes``
    and reported with its ``url`` and ``revised_prompt`` (falling back to
    *prompt*). The base64 payload itself is included only when
    *response_format* is ``"b64_json"``. *message* is attached only when no
    image entry was produced.
    """
    include_b64 = response_format == "b64_json"
    entries: list[dict[str, Any]] = []
    for item in items:
        encoded = str(item.get("b64_json") or "").strip()
        if not encoded:
            continue
        revised = str(item.get("revised_prompt") or prompt).strip() or prompt
        # Key order (b64_json, url, revised_prompt) is kept for serialized output.
        entry: dict[str, Any] = {"b64_json": encoded} if include_b64 else {}
        entry["url"] = save_image_bytes(base64.b64decode(encoded), base_url)
        entry["revised_prompt"] = revised
        entries.append(entry)
    payload: dict[str, Any] = {"created": created or int(time.time()), "data": entries}
    if message and not entries:
        payload["message"] = message
    return payload
# ``@dataclass`` supplies the keyword constructor; without it the annotated
# defaults are plain class attributes and per-request values cannot be passed.
@dataclass
class ConversationRequest:
    """Parameters of one text or image conversation request."""

    # Model identifier forwarded to the backend.
    model: str = "auto"
    # Prompt text used for single-prompt requests.
    prompt: str = ""
    # Chat messages, when the request carries a message list.
    messages: list[dict[str, Any]] | None = None
    # Input images forwarded for image models.
    images: list[str] | None = None
    # Number of images to generate.
    n: int = 1
    # Size / aspect-ratio hint for image prompts.
    size: str | None = None
    # "b64_json" includes the base64 payload in each result entry.
    response_format: str = "b64_json"
    # URL prefix for saved-image links.
    base_url: str | None = None
    # When true, a textual reply instead of an image is raised as an error.
    message_as_error: bool = False
# ``@dataclass`` is required here: it turns ``field(default_factory=list)``
# into a fresh list per instance instead of a shared ``Field`` placeholder.
@dataclass
class ConversationState:
    """Mutable state accumulated while consuming one conversation stream."""

    # Assistant text assembled so far.
    text: str = ""
    # Conversation id reported in the stream.
    conversation_id: str = ""
    # ``file-…`` / ``file_…`` ids collected from image tool events.
    file_ids: list[str] = field(default_factory=list)
    # Ids extracted from ``sediment://`` references in image tool events.
    sediment_ids: list[str] = field(default_factory=list)
    # Set once a moderation event reports ``blocked``.
    blocked: bool = False
    # ``tool_invoked`` flag from stream metadata; None until reported.
    tool_invoked: bool | None = None
    # ``turn_use_case`` value from stream metadata.
    turn_use_case: str = ""
# ``@dataclass`` generates the keyword constructor the image stream relies on
# and evaluates the ``created`` / ``data`` default factories per instance.
@dataclass
class ImageOutput:
    """One item emitted by the image stream: progress, message, or result."""

    # "progress", "message" or "result"; selects the chunk layout.
    kind: str
    # Model name echoed into the chunk.
    model: str
    # Position of this image within the request.
    index: int
    # Total number of images requested.
    total: int
    # Creation timestamp in whole seconds, taken at construction time.
    created: int = field(default_factory=lambda: int(time.time()))
    # Progress text, or the message text for ``kind == "message"``.
    text: str = ""
    # Type of the upstream event that produced a progress item.
    upstream_event_type: str = ""
    # Image entries for ``kind == "result"``.
    data: list[dict[str, Any]] = field(default_factory=list)

    def to_chunk(self) -> dict[str, Any]:
        """Render this output as a streaming chunk dict.

        Progress chunks carry ``progress_text`` and ``upstream_event_type``;
        message and result chunks drop both and carry ``message`` or
        ``data`` respectively.
        """
        chunk: dict[str, Any] = {
            "object": "image.generation.chunk",
            "created": self.created,
            "model": self.model,
            "index": self.index,
            "total": self.total,
            "progress_text": self.text,
            "upstream_event_type": self.upstream_event_type,
            "data": [],
        }
        if self.kind == "message":
            chunk.update({
                "object": "image.generation.message",
                "message": self.text,
            })
        elif self.kind == "result":
            chunk.update({
                "object": "image.generation.result",
                "data": self.data,
            })
        else:
            return chunk
        # Message and result chunks do not expose the progress-only fields.
        chunk.pop("progress_text", None)
        chunk.pop("upstream_event_type", None)
        return chunk
def assistant_message_text(message: dict[str, Any]) -> str:
    """Join the string entries of ``message["content"]["parts"]``.

    Returns an empty string when ``content`` is missing or not a dict, or
    when ``parts`` is missing or not a list. Non-string parts are ignored.
    """
    content = message.get("content")
    # A truthy non-dict ``content`` (e.g. a plain string) used to raise
    # AttributeError on ``.get``; treat it as "no text" instead.
    if not isinstance(content, dict):
        return ""
    parts = content.get("parts")
    if not isinstance(parts, list):
        return ""
    return "".join(part for part in parts if isinstance(part, str))
def strip_history(text: str, history_text: str = "") -> str:
    """Remove every leading repetition of *history_text* from *text*.

    Both arguments are coerced to ``str`` (falsy values become ""). An empty
    *history_text* leaves *text* unchanged.
    """
    remaining = str(text or "")
    prefix = str(history_text or "")
    if not prefix:
        return remaining
    prefix_length = len(prefix)
    while remaining.startswith(prefix):
        remaining = remaining[prefix_length:]
    return remaining
def assistant_text(event: dict[str, Any], current_text: str = "", history_text: str = "") -> str:
    """Return the assistant text after applying *event*.

    A full assistant message snapshot, found either on the event itself or
    under its ``"v"`` value, takes precedence and is returned with the
    history prefix stripped. Otherwise the event is treated as a text patch
    on *current_text*.
    """
    for candidate in (event, event.get("v")):
        message = candidate.get("message") if isinstance(candidate, dict) else None
        if not isinstance(message, dict):
            continue
        author = message.get("author") or {}
        if str(author.get("role") or "").strip().lower() != "assistant":
            continue
        snapshot = assistant_message_text(message)
        if snapshot:
            return strip_history(snapshot, history_text)
    return apply_text_patch(event, current_text, history_text)
def event_assistant_text(event: dict[str, Any], history_text: str = "") -> str:
    """Return the text of the first assistant message carried by *event*.

    The event itself and its ``"v"`` value are inspected in that order; the
    author role must equal ``"assistant"`` exactly. The history prefix is
    stripped from the result. Returns "" when no assistant message is found.
    """
    for candidate in (event, event.get("v")):
        if not isinstance(candidate, dict):
            continue
        message = candidate.get("message")
        if not isinstance(message, dict):
            continue
        author = message.get("author") or {}
        if author.get("role") == "assistant":
            return strip_history(assistant_message_text(message), history_text)
    return ""
def apply_text_patch(event: dict[str, Any], current_text: str = "", history_text: str = "") -> str:
    """Apply a streamed patch event to *current_text* and return the new text.

    Handled shapes:
      * an operation addressed at ``/message/content/parts/0``;
      * a bare string value with no path/op, appended when text already exists;
      * a list of nested operations, applied in order (non-dict items skipped).
    Anything else leaves the text unchanged.
    """
    if event.get("p") == "/message/content/parts/0":
        return apply_patch_op(event, current_text, history_text)

    value = event.get("v")
    is_bare_continuation = (
        isinstance(value, str) and current_text and not event.get("p") and not event.get("o")
    )
    if is_bare_continuation:
        return current_text + value

    # Explicit "patch" events and other list-valued events are walked the same way.
    if not isinstance(value, list):
        return current_text
    patched = current_text
    for operation in value:
        if isinstance(operation, dict):
            patched = apply_text_patch(operation, patched, history_text)
    return patched
def apply_patch_op(operation: dict[str, Any], current_text: str, history_text: str = "") -> str:
    """Apply one ``append`` / ``replace`` operation to *current_text*.

    ``append`` concatenates the operation value; ``replace`` substitutes it
    after stripping the history prefix. Any other operation leaves the text
    unchanged. A falsy value is treated as an empty string.
    """
    kind = operation.get("o")
    payload = str(operation.get("v") or "")
    if kind == "replace":
        return strip_history(payload, history_text)
    if kind == "append":
        return current_text + payload
    return current_text
def add_unique(values: list[str], candidates: list[str]) -> None:
    """Append to *values*, in place, each truthy candidate not already present."""
    for item in candidates:
        if not item or item in values:
            continue
        values.append(item)
def extract_conversation_ids(payload: str) -> tuple[str, list[str], list[str]]:
    """Scan a raw payload string for conversation, file and sediment ids.

    Returns:
        ``(conversation_id, file_ids, sediment_ids)`` where the conversation
        id is "" when absent and the lists keep match order, duplicates
        included.
    """
    match = re.search(r'"conversation_id"\s*:\s*"([^"]+)"', payload)
    if match is None:
        conversation_id = ""
    else:
        conversation_id = match.group(1)
    return (
        conversation_id,
        re.findall(r"(file[-_][A-Za-z0-9]+)", payload),
        re.findall(r"sediment://([A-Za-z0-9_-]+)", payload),
    )
def is_image_tool_event(event: dict[str, Any]) -> bool:
    """Return True when *event* carries a tool message of the image-generation task type.

    The message is read from the event itself or, failing that, from its
    ``"v"`` value. It qualifies when the author role is ``"tool"`` and the
    metadata ``async_task_type`` is ``"image_gen"``.
    """
    message = event.get("message")
    if not message:
        nested = event.get("v")
        message = nested.get("message") if isinstance(nested, dict) else None
    if not isinstance(message, dict):
        return False
    if (message.get("author") or {}).get("role") != "tool":
        return False
    return (message.get("metadata") or {}).get("async_task_type") == "image_gen"
def update_conversation_state(state: ConversationState, payload: str, event: dict[str, Any] | None = None) -> None:
    """Fold one stream payload (and its parsed event, if any) into *state*.

    Always: adopt a conversation id found in the raw payload when none is set.
    With a dict event: collect file/sediment ids from image tool events,
    take conversation ids from the event and its ``"v"`` value, record a
    moderation block, and copy ``tool_invoked`` / ``turn_use_case`` metadata.
    """
    found_id, found_files, found_sediments = extract_conversation_ids(payload)
    if found_id and not state.conversation_id:
        state.conversation_id = found_id
    if not isinstance(event, dict):
        return

    # Ids are only trusted when they come from an image tool message.
    if is_image_tool_event(event):
        add_unique(state.file_ids, found_files)
        add_unique(state.sediment_ids, found_sediments)

    state.conversation_id = str(event.get("conversation_id") or state.conversation_id)
    nested = event.get("v")
    if isinstance(nested, dict):
        state.conversation_id = str(nested.get("conversation_id") or state.conversation_id)

    if event.get("type") == "moderation":
        verdict = event.get("moderation_response")
        if isinstance(verdict, dict) and verdict.get("blocked") is True:
            state.blocked = True

    if event.get("type") == "server_ste_metadata":
        details = event.get("metadata")
        if isinstance(details, dict):
            if isinstance(details.get("tool_invoked"), bool):
                state.tool_invoked = details["tool_invoked"]
            state.turn_use_case = str(details.get("turn_use_case") or state.turn_use_case)
def conversation_base_event(event_type: str, state: ConversationState, **extra: Any) -> dict[str, Any]:
    """Build an event dict from a snapshot of *state* plus *extra* fields.

    The id lists are copied so later state mutations do not leak into the
    emitted event. Keys in *extra* override the snapshot values.
    """
    snapshot: dict[str, Any] = {
        "type": event_type,
        "text": state.text,
        "conversation_id": state.conversation_id,
        "file_ids": list(state.file_ids),
        "sediment_ids": list(state.sediment_ids),
        "blocked": state.blocked,
        "tool_invoked": state.tool_invoked,
        "turn_use_case": state.turn_use_case,
    }
    snapshot.update(extra)
    return snapshot
def iter_conversation_payloads(
    payloads: Iterator[str],
    history_text: str = "",
    history_messages: list[str] | None = None,
) -> Iterator[dict[str, Any]]:
    """Translate raw stream payloads into ``conversation.*`` events.

    Emits, per payload:
      * ``conversation.done`` for the ``[DONE]`` sentinel (iteration stops);
      * ``conversation.raw`` for payloads that are not valid JSON;
      * ``conversation.delta`` when the assistant text changed;
      * ``conversation.event`` for everything else.
    Events that merely replay an entry of *history_messages* (matched in
    order) are swallowed and reset the accumulated text.
    """
    state = ConversationState()
    pending_history = history_messages or []
    replay_position = 0
    for payload in payloads:
        if not payload:
            continue
        if payload == "[DONE]":
            yield conversation_base_event("conversation.done", state, done=True)
            return

        try:
            event = json.loads(payload)
        except json.JSONDecodeError:
            update_conversation_state(state, payload)
            yield conversation_base_event("conversation.raw", state, payload=payload)
            continue
        if not isinstance(event, dict):
            yield conversation_base_event("conversation.event", state, raw=event)
            continue

        update_conversation_state(state, payload, event)

        is_history_replay = (
            replay_position < len(pending_history)
            and event_assistant_text(event, history_text) == pending_history[replay_position]
        )
        if is_history_replay:
            replay_position += 1
            state.text = ""
            continue

        updated_text = assistant_text(event, state.text, history_text)
        if updated_text == state.text:
            yield conversation_base_event("conversation.event", state, raw=event)
            continue

        previous_text = state.text
        # When the text was rewritten rather than extended, the whole new text is the delta.
        if updated_text.startswith(previous_text):
            delta = updated_text[len(previous_text):]
        else:
            delta = updated_text
        state.text = updated_text
        yield conversation_base_event("conversation.delta", state, raw=event, delta=delta)
def conversation_events(
    backend: OpenAIBackendAPI,
    messages: list[dict[str, Any]] | None = None,
    model: str = "auto",
    prompt: str = "",
    images: list[str] | None = None,
    size: str | None = None,
) -> Iterator[dict[str, Any]]:
    """Start a backend conversation and yield its parsed events.

    *messages* wins over *prompt*; with neither, an empty conversation is
    sent. For image models the prompt gets the size hint and global system
    prompt, images are forwarded with the ``picture_v2`` system hint, and no
    assistant-history filtering is applied.
    """
    if messages:
        source_messages = messages
    elif prompt:
        source_messages = [{"role": "user", "content": prompt}]
    else:
        source_messages = []
    normalized = normalize_messages(source_messages)

    if str(model or "").strip() in IMAGE_MODELS:
        history_text = ""
        history_messages: list[str] = []
        final_prompt = prompt_with_global_system(build_image_prompt(prompt, size))
        forwarded_images = images
        hints = ["picture_v2"]
    else:
        history_text = assistant_history_text(normalized)
        history_messages = assistant_history_messages(normalized)
        final_prompt = prompt
        forwarded_images = None
        hints = None

    payloads = backend.stream_conversation(
        messages=normalized,
        model=model,
        prompt=final_prompt,
        images=forwarded_images,
        system_hints=hints,
    )
    yield from iter_conversation_payloads(payloads, history_text, history_messages)
def text_backend() -> OpenAIBackendAPI:
    """Create a backend client using the text access token from the account service."""
    token = account_service.get_text_access_token()
    return OpenAIBackendAPI(access_token=token)
def stream_text_deltas(backend: OpenAIBackendAPI, request: ConversationRequest) -> Iterator[str]:
    """Yield assistant text deltas for *request*, rotating tokens on invalidation.

    Starts with the access token of *backend*. If the stream fails with a
    token-invalid error before any delta was yielded, the token is removed
    and another one is requested; every other failure is re-raised.

    Raises:
        RuntimeError: when the next token to try was already attempted.
    """
    tried: set[str] = set()
    token = getattr(backend, "access_token", "")
    produced_output = False
    while True:
        if token:
            if token in tried:
                raise RuntimeError("no available text account")
            tried.add(token)
        try:
            client = OpenAIBackendAPI(access_token=token)
            events = conversation_events(
                client,
                messages=request.messages,
                model=request.model,
                prompt=request.prompt,
            )
            for event in events:
                if event.get("type") != "conversation.delta":
                    continue
                delta = str(event.get("delta") or "")
                if delta:
                    produced_output = True
                    yield delta
            account_service.mark_text_used(token)
            return
        except Exception as exc:
            reason = str(exc)
            # Retrying is only safe while nothing has reached the consumer yet.
            can_rotate = token and not produced_output and is_token_invalid_error(reason)
            if not can_rotate:
                raise
            account_service.remove_invalid_token(token, "text_stream")
            token = account_service.get_text_access_token(tried)
            if not token:
                raise
def collect_text(backend: OpenAIBackendAPI, request: ConversationRequest) -> str:
    """Run the text stream to completion and return the concatenated deltas."""
    pieces = list(stream_text_deltas(backend, request))
    return "".join(pieces)
def stream_image_outputs(
    backend: OpenAIBackendAPI,
    request: ConversationRequest,
    index: int = 1,
    total: int = 1,
) -> Iterator[ImageOutput]:
    """Stream one image generation: progress items, then a result or message.

    While the conversation runs, text deltas and generic events are relayed
    as ``progress`` outputs. Afterwards the last event's state decides the
    outcome: a ``message`` when upstream answered with text only (blocked or
    text turn, no image ids), a ``result`` when images could be resolved and
    downloaded, or a ``message`` as a last resort when text exists.
    """
    final_event: dict[str, Any] = {}
    events = conversation_events(
        backend,
        prompt=request.prompt,
        model=request.model,
        images=request.images or [],
        size=request.size,
    )
    for event in events:
        final_event = event
        event_kind = event.get("type")
        if event_kind == "conversation.delta":
            yield ImageOutput(
                kind="progress",
                model=request.model,
                index=index,
                total=total,
                text=str(event.get("delta") or ""),
                upstream_event_type="conversation.delta",
            )
        elif event_kind == "conversation.event":
            raw = event.get("raw")
            yield ImageOutput(
                kind="progress",
                model=request.model,
                index=index,
                total=total,
                upstream_event_type=str(raw.get("type") or "") if isinstance(raw, dict) else "",
            )

    conversation_id = str(final_event.get("conversation_id") or "")
    file_ids = [str(item) for item in final_event.get("file_ids") or []]
    sediment_ids = [str(item) for item in final_event.get("sediment_ids") or []]
    message = str(final_event.get("text") or "").strip()
    is_text_response = final_event.get("tool_invoked") is False or final_event.get("turn_use_case") == "text"
    logger.info({
        "event": "image_stream_resolve_start",
        "conversation_id": conversation_id,
        "file_ids": file_ids,
        "sediment_ids": sediment_ids,
        "tool_invoked": final_event.get("tool_invoked"),
        "turn_use_case": final_event.get("turn_use_case"),
    })

    has_image_ids = bool(file_ids) or bool(sediment_ids)
    if message and not has_image_ids and (final_event.get("blocked") or is_text_response):
        yield ImageOutput(kind="message", model=request.model, index=index, total=total, text=message)
        return

    image_urls = backend.resolve_conversation_image_urls(conversation_id, file_ids, sediment_ids)
    if image_urls:
        image_items = [
            {"b64_json": base64.b64encode(image_data).decode("ascii")}
            for image_data in backend.download_image_bytes(image_urls)
        ]
        formatted = format_image_result(
            image_items,
            request.prompt,
            request.response_format,
            request.base_url,
            int(time.time()),
        )
        data = formatted["data"]
        if data:
            yield ImageOutput(kind="result", model=request.model, index=index, total=total, data=data)
            return

    if message:
        yield ImageOutput(kind="message", model=request.model, index=index, total=total, text=message)
def stream_image_outputs_with_pool(request: ConversationRequest) -> Iterator[ImageOutput]:
    """Generate ``request.n`` images, drawing access tokens from the account pool.

    For each image a token is fetched and the stream relayed. A token that
    turns out to be invalid before producing any output is removed and the
    image retried with another token. A stream that ends with a message or
    without a result stops the whole generator.

    Raises:
        ImageGenerationError: for an unsupported model, when no token is
            available before anything was emitted, when a message is
            returned and ``request.message_as_error`` is set, on any other
            stream failure, or when nothing was emitted at all.
    """
    if str(request.model or "").strip() not in IMAGE_MODELS:
        raise ImageGenerationError("unsupported image model,supported models: " + ", ".join(IMAGE_MODELS))

    any_output = False
    latest_error = ""
    for index in range(1, request.n + 1):
        while True:
            try:
                token = account_service.get_available_access_token()
            except RuntimeError as exc:
                # Running out of tokens after partial output ends the stream quietly.
                if any_output:
                    return
                raise ImageGenerationError(str(exc) or "image generation failed") from exc

            token_output = False
            got_message = False
            got_result = False
            try:
                client = OpenAIBackendAPI(access_token=token)
                for output in stream_image_outputs(client, request, index, request.n):
                    is_message = output.kind == "message"
                    if is_message and request.message_as_error:
                        raise ImageGenerationError(
                            output.text or "Image generation was rejected by upstream policy.",
                            status_code=400,
                            error_type="invalid_request_error",
                            code="content_policy_violation",
                        )
                    any_output = True
                    token_output = True
                    got_message = is_message
                    got_result = got_result or output.kind == "result"
                    yield output
                if got_message or not got_result:
                    account_service.mark_image_result(token, False)
                    return
                account_service.mark_image_result(token, True)
                break
            except ImageGenerationError:
                account_service.mark_image_result(token, False)
                raise
            except Exception as exc:
                account_service.mark_image_result(token, False)
                latest_error = str(exc)
                logger.warning({"event": "image_stream_fail", "request_token": token, "error": latest_error})
                # Only rotate tokens while this attempt has not emitted anything.
                if not token_output and is_token_invalid_error(latest_error):
                    account_service.remove_invalid_token(token, "image_stream")
                    continue
                raise ImageGenerationError(image_stream_error_message(latest_error)) from exc

    if not any_output:
        raise ImageGenerationError(image_stream_error_message(latest_error))
def stream_image_chunks(outputs: Iterable[ImageOutput]) -> Iterator[dict[str, Any]]:
    """Lazily convert each output into its chunk dict via ``to_chunk()``."""
    yield from (item.to_chunk() for item in outputs)
def collect_image_outputs(outputs: Iterable[ImageOutput]) -> dict[str, Any]:
    """Aggregate a stream of image outputs into one response dict.

    ``created`` is the first truthy creation time seen (or the current time),
    ``data`` gathers all result entries. When there is no data, the last
    message — or, failing that, the joined progress text — is attached as
    ``message``.
    """
    created = None
    entries: list[dict[str, Any]] = []
    last_message = ""
    progress: list[str] = []
    for item in outputs:
        created = created or item.created
        kind = item.kind
        if kind == "progress":
            if item.text:
                progress.append(item.text)
        elif kind == "message":
            last_message = item.text
        elif kind == "result":
            entries.extend(item.data)

    summary: dict[str, Any] = {"created": created or int(time.time()), "data": entries}
    if entries:
        return summary
    fallback_text = last_message or "".join(progress).strip()
    if fallback_text:
        summary["message"] = fallback_text
    return summary