from __future__ import annotations import base64 import json import re import time from dataclasses import dataclass, field from typing import Any, Iterable, Iterator import tiktoken from services.account_service import account_service from services.config import config from services.image_storage_service import image_storage_service from services.openai_backend_api import ImagePollTimeoutError, OpenAIBackendAPI from utils.helper import IMAGE_MODELS, extract_image_from_message_content from utils.log import logger class ImageGenerationError(Exception): def __init__( self, message: str, status_code: int = 502, error_type: str = "server_error", code: str | None = "upstream_error", param: str | None = None, ) -> None: super().__init__(message) self.status_code = status_code self.error_type = error_type self.code = code self.param = param def to_openai_error(self) -> dict[str, Any]: return { "error": { "message": str(self), "type": self.error_type, "param": self.param, "code": self.code, } } def is_token_invalid_error(message: str) -> bool: text = str(message or "").lower() return ( "token_invalidated" in text or "token_revoked" in text or "authentication token has been invalidated" in text or "invalidated oauth token" in text ) def image_stream_error_message(message: str) -> str: text = str(message or "") lower = text.lower() if is_token_invalid_error(text): return "image generation failed" if "curl: (35)" in lower or "tls connect error" in lower or "openssl_internal" in lower: return "upstream image connection failed, please retry later" return text or "image generation failed" def encode_images(images: Iterable[tuple[bytes, str, str]]) -> list[str]: return [base64.b64encode(data).decode("ascii") for data, _, _ in images if data] def save_image_bytes(image_data: bytes, base_url: str | None = None) -> str: return image_storage_service.save(image_data, base_url).url def message_text(content: Any) -> str: if isinstance(content, str): return content if isinstance(content, list): parts = [] for item in content: if isinstance(item, str): parts.append(item) elif isinstance(item, dict) and str(item.get("type") or "") in {"text", "input_text", "output_text"}: parts.append(str(item.get("text") or "")) return "".join(parts) return "" def normalize_messages(messages: object, system: Any = None) -> list[dict[str, Any]]: normalized = [] if config.global_system_prompt: normalized.append({"role": "system", "content": config.global_system_prompt}) system_text = message_text(system) if system_text: normalized.append({"role": "system", "content": system_text}) if isinstance(messages, list): for message in messages: if not isinstance(message, dict): continue role = message.get("role", "user") content = message.get("content", "") text = message_text(content) images: list[tuple[bytes, str]] = [] if role == "user": images.extend(extract_image_from_message_content(content)) if isinstance(content, list): for part in content: if not isinstance(part, dict) or part.get("type") != "image": continue data = part.get("data") if isinstance(data, (bytes, bytearray)): images.append((bytes(data), str(part.get("mime") or "image/png"))) if images: parts: list[Any] = [] if text: parts.append({"type": "text", "text": text}) for data, mime in images: parts.append({"type": "image", "data": data, "mime": mime}) normalized.append({"role": role, "content": parts}) else: normalized.append({"role": role, "content": text}) return normalized def prompt_with_global_system(prompt: str) -> str: return f"{config.global_system_prompt}\n\n{prompt}" if config.global_system_prompt else prompt def assistant_history_text(messages: list[dict[str, Any]]) -> str: return "".join(str(item.get("content") or "") for item in messages if item.get("role") == "assistant") def assistant_history_messages(messages: list[dict[str, Any]]) -> list[str]: return [str(item.get("content") or "") for item in messages if item.get("role") == "assistant" and item.get("content")] def build_image_prompt(prompt: str, size: str | None) -> str: if not size: return prompt if size not in {"1:1", "16:9", "9:16", "4:3", "3:4"}: return f"{prompt.strip()}\n\n输出图片,宽高比为 {size}。" hint = { "1:1": "输出为 1:1 正方形构图,主体居中,适合正方形画幅。", "16:9": "输出为 16:9 横屏构图,适合宽画幅展示。", "9:16": "输出为 9:16 竖屏构图,适合竖版画幅展示。", "4:3": "输出为 4:3 比例,兼顾宽度与高度,适合展示画面细节。", "3:4": "输出为 3:4 比例,纵向构图,适合人物肖像或竖向场景。", }[size] return f"{prompt.strip()}\n\n{hint}" def encoding_for_model(model: str): try: return tiktoken.encoding_for_model(model) except KeyError: try: return tiktoken.get_encoding("o200k_base") except KeyError: return tiktoken.get_encoding("cl100k_base") def count_message_tokens(messages: list[dict[str, Any]], model: str) -> int: encoding = encoding_for_model(model) total = 0 for message in messages: total += 3 for key, value in message.items(): if not isinstance(value, str): continue total += len(encoding.encode(value)) if key == "name": total += 1 return total + 3 def count_text_tokens(text: str, model: str) -> int: return len(encoding_for_model(model).encode(text)) def format_image_result( items: list[dict[str, Any]], prompt: str, response_format: str, base_url: str | None = None, created: int | None = None, message: str = "", ) -> dict[str, Any]: data: list[dict[str, Any]] = [] for item in items: b64_json = str(item.get("b64_json") or "").strip() if not b64_json: continue revised_prompt = str(item.get("revised_prompt") or prompt).strip() or prompt if response_format == "b64_json": data.append({ "b64_json": b64_json, "url": save_image_bytes(base64.b64decode(b64_json), base_url), "revised_prompt": revised_prompt, }) else: data.append({ "url": save_image_bytes(base64.b64decode(b64_json), base_url), "revised_prompt": revised_prompt, }) result: dict[str, Any] = {"created": created or int(time.time()), "data": data} if message and not data: result["message"] = message return result @dataclass class ConversationRequest: model: str = "auto" prompt: str = "" messages: list[dict[str, Any]] | None = None images: list[str] | None = None n: int = 1 size: str | None = None response_format: str = "b64_json" base_url: str | None = None message_as_error: bool = False @dataclass class ConversationState: text: str = "" conversation_id: str = "" file_ids: list[str] = field(default_factory=list) sediment_ids: list[str] = field(default_factory=list) blocked: bool = False tool_invoked: bool | None = None turn_use_case: str = "" @dataclass class ImageOutput: kind: str model: str index: int total: int created: int = field(default_factory=lambda: int(time.time())) text: str = "" upstream_event_type: str = "" data: list[dict[str, Any]] = field(default_factory=list) def to_chunk(self) -> dict[str, Any]: chunk: dict[str, Any] = { "object": "image.generation.chunk", "created": self.created, "model": self.model, "index": self.index, "total": self.total, "progress_text": self.text, "upstream_event_type": self.upstream_event_type, "data": [], } if self.kind == "message": chunk.update({ "object": "image.generation.message", "message": self.text, }) chunk.pop("progress_text", None) chunk.pop("upstream_event_type", None) elif self.kind == "result": chunk.update({ "object": "image.generation.result", "data": self.data, }) chunk.pop("progress_text", None) chunk.pop("upstream_event_type", None) return chunk def assistant_message_text(message: dict[str, Any]) -> str: content = message.get("content") or {} parts = content.get("parts") or [] if not isinstance(parts, list): return "" return "".join(part for part in parts if isinstance(part, str)) def strip_history(text: str, history_text: str = "") -> str: text = str(text or "") history_text = str(history_text or "") while history_text and text.startswith(history_text): text = text[len(history_text):] return text def assistant_text(event: dict[str, Any], current_text: str = "", history_text: str = "") -> str: for candidate in (event, event.get("v")): if not isinstance(candidate, dict): continue message = candidate.get("message") if not isinstance(message, dict): continue role = str((message.get("author") or {}).get("role") or "").strip().lower() if role != "assistant": continue text = assistant_message_text(message) if text: return strip_history(text, history_text) return apply_text_patch(event, current_text, history_text) def event_assistant_text(event: dict[str, Any], history_text: str = "") -> str: for candidate in (event, event.get("v")): if not isinstance(candidate, dict): continue message = candidate.get("message") if isinstance(message, dict) and (message.get("author") or {}).get("role") == "assistant": return strip_history(assistant_message_text(message), history_text) return "" def apply_text_patch(event: dict[str, Any], current_text: str = "", history_text: str = "") -> str: if event.get("p") == "/message/content/parts/0": return apply_patch_op(event, current_text, history_text) operations = event.get("v") if isinstance(operations, str) and current_text and not event.get("p") and not event.get("o"): return current_text + operations if event.get("o") == "patch" and isinstance(operations, list): text = current_text for item in operations: if isinstance(item, dict): text = apply_text_patch(item, text, history_text) return text if not isinstance(operations, list): return current_text text = current_text for item in operations: if isinstance(item, dict): text = apply_text_patch(item, text, history_text) return text def apply_patch_op(operation: dict[str, Any], current_text: str, history_text: str = "") -> str: op = operation.get("o") value = str(operation.get("v") or "") if op == "append": return current_text + value if op == "replace": return strip_history(value, history_text) return current_text def add_unique(values: list[str], candidates: list[str]) -> None: for candidate in candidates: if candidate and candidate not in values: values.append(candidate) def extract_conversation_ids(payload: str) -> tuple[str, list[str], list[str]]: conversation_match = re.search(r'"conversation_id"\s*:\s*"([^"]+)"', payload) conversation_id = conversation_match.group(1) if conversation_match else "" # Negative lookahead excludes "file-service" (URI prefix, not a real id). file_ids = re.findall(r"(file[-_](?!service\b)[A-Za-z0-9]+)", payload) sediment_ids = re.findall(r"sediment://([A-Za-z0-9_-]+)", payload) return conversation_id, file_ids, sediment_ids def is_image_tool_event(event: dict[str, Any]) -> bool: value = event.get("v") message = event.get("message") or (value.get("message") if isinstance(value, dict) else None) if not isinstance(message, dict): return False metadata = message.get("metadata") or {} author = message.get("author") or {} content = message.get("content") or {} if author.get("role") != "tool": return False if metadata.get("async_task_type") == "image_gen": return True if content.get("content_type") != "multimodal_text": return False return any( isinstance(part, dict) and ( part.get("content_type") == "image_asset_pointer" or str(part.get("asset_pointer") or "").startswith(("file-service://", "sediment://")) ) for part in content.get("parts") or [] ) def update_conversation_state(state: ConversationState, payload: str, event: dict[str, Any] | None = None) -> None: conversation_id, file_ids, sediment_ids = extract_conversation_ids(payload) if conversation_id and not state.conversation_id: state.conversation_id = conversation_id # Accept file_id / sediment_id when any of: # 1) event is a complete image_gen tool message # 2) prior server_ste_metadata already flipped tool_invoked True (in an image_gen turn) # 3) patch event whose payload references asset_pointer / file-service:// # User messages (type=conversation.message) never satisfy these, so attacker-controlled # substrings in user input cannot inject file ids into state. is_patch_event = isinstance(event, dict) and event.get("o") == "patch" image_context = ( (isinstance(event, dict) and is_image_tool_event(event)) or state.tool_invoked is True or (is_patch_event and ("asset_pointer" in payload or "file-service://" in payload)) ) if image_context: add_unique(state.file_ids, file_ids) add_unique(state.sediment_ids, sediment_ids) if not isinstance(event, dict): return state.conversation_id = str(event.get("conversation_id") or state.conversation_id) value = event.get("v") if isinstance(value, dict): state.conversation_id = str(value.get("conversation_id") or state.conversation_id) if event.get("type") == "moderation": moderation = event.get("moderation_response") if isinstance(moderation, dict) and moderation.get("blocked") is True: state.blocked = True if event.get("type") == "server_ste_metadata": metadata = event.get("metadata") if isinstance(metadata, dict): if isinstance(metadata.get("tool_invoked"), bool): state.tool_invoked = metadata["tool_invoked"] state.turn_use_case = str(metadata.get("turn_use_case") or state.turn_use_case) def conversation_base_event(event_type: str, state: ConversationState, **extra: Any) -> dict[str, Any]: return { "type": event_type, "text": state.text, "conversation_id": state.conversation_id, "file_ids": list(state.file_ids), "sediment_ids": list(state.sediment_ids), "blocked": state.blocked, "tool_invoked": state.tool_invoked, "turn_use_case": state.turn_use_case, **extra, } def iter_conversation_payloads(payloads: Iterator[str], history_text: str = "", history_messages: list[str] | None = None) -> Iterator[dict[str, Any]]: state = ConversationState() history_messages = history_messages or [] history_index = 0 for payload in payloads: # print(f"[upstream_sse] {payload}", flush=True) if not payload: continue if payload == "[DONE]": yield conversation_base_event("conversation.done", state, done=True) break try: event = json.loads(payload) except json.JSONDecodeError: update_conversation_state(state, payload) yield conversation_base_event("conversation.raw", state, payload=payload) continue if not isinstance(event, dict): yield conversation_base_event("conversation.event", state, raw=event) continue update_conversation_state(state, payload, event) if history_index < len(history_messages) and event_assistant_text(event, history_text) == history_messages[history_index]: history_index += 1 state.text = "" continue next_text = assistant_text(event, state.text, history_text) if next_text != state.text: delta = next_text[len(state.text):] if next_text.startswith(state.text) else next_text state.text = next_text yield conversation_base_event("conversation.delta", state, raw=event, delta=delta) continue yield conversation_base_event("conversation.event", state, raw=event) def conversation_events( backend: OpenAIBackendAPI, messages: list[dict[str, Any]] | None = None, model: str = "auto", prompt: str = "", images: list[str] | None = None, size: str | None = None, ) -> Iterator[dict[str, Any]]: normalized = normalize_messages(messages or ([{"role": "user", "content": prompt}] if prompt else [])) image_model = str(model or "").strip() in IMAGE_MODELS history_text = "" if image_model else assistant_history_text(normalized) history_messages = [] if image_model else assistant_history_messages(normalized) final_prompt = prompt_with_global_system(build_image_prompt(prompt, size)) if image_model else prompt payloads = backend.stream_conversation( messages=normalized, model=model, prompt=final_prompt, images=images if image_model else None, system_hints=["picture_v2"] if image_model else None, ) yield from iter_conversation_payloads(payloads, history_text, history_messages) def text_backend() -> OpenAIBackendAPI: return OpenAIBackendAPI(access_token=account_service.get_text_access_token()) def stream_text_deltas(backend: OpenAIBackendAPI, request: ConversationRequest) -> Iterator[str]: attempted_tokens: set[str] = set() token = getattr(backend, "access_token", "") emitted = False while True: if token and token in attempted_tokens: raise RuntimeError("no available text account") if token: attempted_tokens.add(token) try: active_backend = OpenAIBackendAPI(access_token=token) for event in conversation_events(active_backend, messages=request.messages, model=request.model, prompt=request.prompt): if event.get("type") != "conversation.delta": continue delta = str(event.get("delta") or "") if delta: emitted = True yield delta account_service.mark_text_used(token) return except Exception as exc: error_message = str(exc) if token and not emitted and is_token_invalid_error(error_message): account_service.remove_invalid_token(token, "text_stream") token = account_service.get_text_access_token(attempted_tokens) if token: continue raise def collect_text(backend: OpenAIBackendAPI, request: ConversationRequest) -> str: return "".join(stream_text_deltas(backend, request)) def stream_image_outputs( backend: OpenAIBackendAPI, request: ConversationRequest, index: int = 1, total: int = 1, ) -> Iterator[ImageOutput]: last: dict[str, Any] = {} for event in conversation_events( backend, prompt=request.prompt, model=request.model, images=request.images or [], size=request.size, ): last = event if event.get("type") == "conversation.delta": yield ImageOutput( kind="progress", model=request.model, index=index, total=total, text=str(event.get("delta") or ""), upstream_event_type="conversation.delta", ) continue if event.get("type") == "conversation.event": raw = event.get("raw") raw_type = str(raw.get("type") or "") if isinstance(raw, dict) else "" yield ImageOutput( kind="progress", model=request.model, index=index, total=total, upstream_event_type=raw_type, ) conversation_id = str(last.get("conversation_id") or "") file_ids = [str(item) for item in last.get("file_ids") or []] sediment_ids = [str(item) for item in last.get("sediment_ids") or []] message = str(last.get("text") or "").strip() logger.info({ "event": "image_stream_resolve_start", "conversation_id": conversation_id, "file_ids": file_ids, "sediment_ids": sediment_ids, "tool_invoked": last.get("tool_invoked"), "turn_use_case": last.get("turn_use_case"), }) if message and not file_ids and not sediment_ids and last.get("blocked"): yield ImageOutput(kind="message", model=request.model, index=index, total=total, text=message) return should_poll_for_image = bool(request.images) or last.get("turn_use_case") == "image gen" if message and not file_ids and not sediment_ids and not should_poll_for_image: yield ImageOutput(kind="message", model=request.model, index=index, total=total, text=message) return image_urls = backend.resolve_conversation_image_urls(conversation_id, file_ids, sediment_ids) if image_urls: image_items = [ {"b64_json": base64.b64encode(image_data).decode("ascii")} for image_data in backend.download_image_bytes(image_urls) ] data = format_image_result( image_items, request.prompt, request.response_format, request.base_url, int(time.time()), )["data"] if data: yield ImageOutput(kind="result", model=request.model, index=index, total=total, data=data) return if message: yield ImageOutput(kind="message", model=request.model, index=index, total=total, text=message) def stream_image_outputs_with_pool(request: ConversationRequest) -> Iterator[ImageOutput]: if str(request.model or "").strip() not in IMAGE_MODELS: raise ImageGenerationError("unsupported image model,supported models: " + ", ".join(IMAGE_MODELS)) emitted = False last_error = "" for index in range(1, request.n + 1): while True: try: token = account_service.get_available_access_token() except RuntimeError as exc: if emitted: return raise ImageGenerationError(str(exc) or "image generation failed") from exc emitted_for_token = False returned_message = False returned_result = False try: backend = OpenAIBackendAPI(access_token=token) for output in stream_image_outputs(backend, request, index, request.n): if output.kind == "message" and request.message_as_error: raise ImageGenerationError( output.text or "Image generation was rejected by upstream policy.", status_code=400, error_type="invalid_request_error", code="content_policy_violation", ) emitted = True emitted_for_token = True returned_message = output.kind == "message" returned_result = returned_result or output.kind == "result" yield output if returned_message or not returned_result: account_service.mark_image_result(token, False) return account_service.mark_image_result(token, True) break except ImagePollTimeoutError: raise except ImageGenerationError: account_service.mark_image_result(token, False) raise except Exception as exc: account_service.mark_image_result(token, False) last_error = str(exc) logger.warning({"event": "image_stream_fail", "request_token": token, "error": last_error}) if not emitted_for_token and is_token_invalid_error(last_error): account_service.remove_invalid_token(token, "image_stream") continue raise ImageGenerationError(image_stream_error_message(last_error)) from exc if not emitted: if not last_error: last_error = "no account in the pool could generate images — check account quota and rate-limit status" raise ImageGenerationError(image_stream_error_message(last_error)) def stream_image_chunks(outputs: Iterable[ImageOutput]) -> Iterator[dict[str, Any]]: for output in outputs: yield output.to_chunk() def collect_image_outputs(outputs: Iterable[ImageOutput]) -> dict[str, Any]: created = None data: list[dict[str, Any]] = [] message = "" progress_parts: list[str] = [] for output in outputs: created = created or output.created if output.kind == "progress" and output.text: progress_parts.append(output.text) elif output.kind == "message": message = output.text elif output.kind == "result": data.extend(output.data) result: dict[str, Any] = {"created": created or int(time.time()), "data": data} if not data: text = message or "".join(progress_parts).strip() if text: result["message"] = text return result