from __future__ import annotations import base64 import hashlib import json import re import time from dataclasses import dataclass, field from pathlib import Path from typing import Any, Iterable, Iterator import tiktoken from services.account_service import account_service from services.config import config from services.openai_backend_api import OpenAIBackendAPI from utils.helper import IMAGE_MODELS, extract_image_from_message_content from utils.log import logger class ImageGenerationError(Exception): def __init__( self, message: str, status_code: int = 502, error_type: str = "server_error", code: str | None = "upstream_error", param: str | None = None, ) -> None: super().__init__(message) self.status_code = status_code self.error_type = error_type self.code = code self.param = param def to_openai_error(self) -> dict[str, Any]: return { "error": { "message": str(self), "type": self.error_type, "param": self.param, "code": self.code, } } def is_token_invalid_error(message: str) -> bool: text = str(message or "").lower() return ( "token_invalidated" in text or "token_revoked" in text or "authentication token has been invalidated" in text or "invalidated oauth token" in text ) def image_stream_error_message(message: str) -> str: text = str(message or "") lower = text.lower() if "curl: (35)" in lower or "tls connect error" in lower or "openssl_internal" in lower: return "upstream image connection failed, please retry later" return text or "image generation failed" def encode_images(images: Iterable[tuple[bytes, str, str]]) -> list[str]: return [base64.b64encode(data).decode("ascii") for data, _, _ in images if data] def save_image_bytes(image_data: bytes, base_url: str | None = None) -> str: config.cleanup_old_images() file_hash = hashlib.md5(image_data).hexdigest() filename = f"{int(time.time())}_{file_hash}.png" relative_dir = Path(time.strftime("%Y"), time.strftime("%m"), time.strftime("%d")) file_path = config.images_dir / relative_dir / filename file_path.parent.mkdir(parents=True, exist_ok=True) file_path.write_bytes(image_data) return f"{(base_url or config.base_url)}/images/{relative_dir.as_posix()}/{filename}" def message_text(content: Any) -> str: if isinstance(content, str): return content if isinstance(content, list): parts = [] for item in content: if isinstance(item, str): parts.append(item) elif isinstance(item, dict) and str(item.get("type") or "") in {"text", "input_text", "output_text"}: parts.append(str(item.get("text") or "")) return "".join(parts) return "" def normalize_messages(messages: object, system: Any = None) -> list[dict[str, Any]]: normalized = [] if config.global_system_prompt: normalized.append({"role": "system", "content": config.global_system_prompt}) system_text = message_text(system) if system_text: normalized.append({"role": "system", "content": system_text}) if isinstance(messages, list): for message in messages: if not isinstance(message, dict): continue role = message.get("role", "user") content = message.get("content", "") text = message_text(content) images: list[tuple[bytes, str]] = [] if role == "user": images.extend(extract_image_from_message_content(content)) if isinstance(content, list): for part in content: if not isinstance(part, dict) or part.get("type") != "image": continue data = part.get("data") if isinstance(data, (bytes, bytearray)): images.append((bytes(data), str(part.get("mime") or "image/png"))) if images: parts: list[Any] = [] if text: parts.append({"type": "text", "text": text}) for data, mime in images: parts.append({"type": "image", "data": data, "mime": mime}) normalized.append({"role": role, "content": parts}) else: normalized.append({"role": role, "content": text}) return normalized def prompt_with_global_system(prompt: str) -> str: return f"{config.global_system_prompt}\n\n{prompt}" if config.global_system_prompt else prompt def assistant_history_text(messages: list[dict[str, Any]]) -> str: return "".join(str(item.get("content") or "") for item in messages if item.get("role") == "assistant") def assistant_history_messages(messages: list[dict[str, Any]]) -> list[str]: return [str(item.get("content") or "") for item in messages if item.get("role") == "assistant" and item.get("content")] def build_image_prompt(prompt: str, size: str | None) -> str: if not size: return prompt if size not in {"1:1", "16:9", "9:16", "4:3", "3:4"}: return f"{prompt.strip()}\n\n输出图片,宽高比为 {size}。" hint = { "1:1": "输出为 1:1 正方形构图,主体居中,适合正方形画幅。", "16:9": "输出为 16:9 横屏构图,适合宽画幅展示。", "9:16": "输出为 9:16 竖屏构图,适合竖版画幅展示。", "4:3": "输出为 4:3 比例,兼顾宽度与高度,适合展示画面细节。", "3:4": "输出为 3:4 比例,纵向构图,适合人物肖像或竖向场景。", }[size] return f"{prompt.strip()}\n\n{hint}" def encoding_for_model(model: str): try: return tiktoken.encoding_for_model(model) except KeyError: try: return tiktoken.get_encoding("o200k_base") except KeyError: return tiktoken.get_encoding("cl100k_base") def count_message_tokens(messages: list[dict[str, Any]], model: str) -> int: encoding = encoding_for_model(model) total = 0 for message in messages: total += 3 for key, value in message.items(): if not isinstance(value, str): continue total += len(encoding.encode(value)) if key == "name": total += 1 return total + 3 def count_text_tokens(text: str, model: str) -> int: return len(encoding_for_model(model).encode(text)) def format_image_result( items: list[dict[str, Any]], prompt: str, response_format: str, base_url: str | None = None, created: int | None = None, message: str = "", ) -> dict[str, Any]: data: list[dict[str, Any]] = [] for item in items: b64_json = str(item.get("b64_json") or "").strip() if not b64_json: continue revised_prompt = str(item.get("revised_prompt") or prompt).strip() or prompt if response_format == "b64_json": data.append({ "b64_json": b64_json, "url": save_image_bytes(base64.b64decode(b64_json), base_url), "revised_prompt": revised_prompt, }) else: data.append({ "url": save_image_bytes(base64.b64decode(b64_json), base_url), "revised_prompt": revised_prompt, }) result: dict[str, Any] = {"created": created or int(time.time()), "data": data} if message and not data: result["message"] = message return result @dataclass class ConversationRequest: model: str = "auto" prompt: str = "" messages: list[dict[str, Any]] | None = None images: list[str] | None = None n: int = 1 size: str | None = None response_format: str = "b64_json" base_url: str | None = None message_as_error: bool = False @dataclass class ConversationState: text: str = "" conversation_id: str = "" file_ids: list[str] = field(default_factory=list) sediment_ids: list[str] = field(default_factory=list) blocked: bool = False tool_invoked: bool | None = None turn_use_case: str = "" @dataclass class ImageOutput: kind: str model: str index: int total: int created: int = field(default_factory=lambda: int(time.time())) text: str = "" upstream_event_type: str = "" data: list[dict[str, Any]] = field(default_factory=list) def to_chunk(self) -> dict[str, Any]: chunk: dict[str, Any] = { "object": "image.generation.chunk", "created": self.created, "model": self.model, "index": self.index, "total": self.total, "progress_text": self.text, "upstream_event_type": self.upstream_event_type, "data": [], } if self.kind == "message": chunk.update({ "object": "image.generation.message", "message": self.text, }) chunk.pop("progress_text", None) chunk.pop("upstream_event_type", None) elif self.kind == "result": chunk.update({ "object": "image.generation.result", "data": self.data, }) chunk.pop("progress_text", None) chunk.pop("upstream_event_type", None) return chunk def assistant_message_text(message: dict[str, Any]) -> str: content = message.get("content") or {} parts = content.get("parts") or [] if not isinstance(parts, list): return "" return "".join(part for part in parts if isinstance(part, str)) def strip_history(text: str, history_text: str = "") -> str: text = str(text or "") history_text = str(history_text or "") while history_text and text.startswith(history_text): text = text[len(history_text):] return text def assistant_text(event: dict[str, Any], current_text: str = "", history_text: str = "") -> str: for candidate in (event, event.get("v")): if not isinstance(candidate, dict): continue message = candidate.get("message") if not isinstance(message, dict): continue role = str((message.get("author") or {}).get("role") or "").strip().lower() if role != "assistant": continue text = assistant_message_text(message) if text: return strip_history(text, history_text) return apply_text_patch(event, current_text, history_text) def event_assistant_text(event: dict[str, Any], history_text: str = "") -> str: for candidate in (event, event.get("v")): if not isinstance(candidate, dict): continue message = candidate.get("message") if isinstance(message, dict) and (message.get("author") or {}).get("role") == "assistant": return strip_history(assistant_message_text(message), history_text) return "" def apply_text_patch(event: dict[str, Any], current_text: str = "", history_text: str = "") -> str: if event.get("p") == "/message/content/parts/0": return apply_patch_op(event, current_text, history_text) operations = event.get("v") if isinstance(operations, str) and current_text and not event.get("p") and not event.get("o"): return current_text + operations if event.get("o") == "patch" and isinstance(operations, list): text = current_text for item in operations: if isinstance(item, dict): text = apply_text_patch(item, text, history_text) return text if not isinstance(operations, list): return current_text text = current_text for item in operations: if isinstance(item, dict): text = apply_text_patch(item, text, history_text) return text def apply_patch_op(operation: dict[str, Any], current_text: str, history_text: str = "") -> str: op = operation.get("o") value = str(operation.get("v") or "") if op == "append": return current_text + value if op == "replace": return strip_history(value, history_text) return current_text def add_unique(values: list[str], candidates: list[str]) -> None: for candidate in candidates: if candidate and candidate not in values: values.append(candidate) def extract_conversation_ids(payload: str) -> tuple[str, list[str], list[str]]: conversation_match = re.search(r'"conversation_id"\s*:\s*"([^"]+)"', payload) conversation_id = conversation_match.group(1) if conversation_match else "" file_ids = re.findall(r"(file[-_][A-Za-z0-9]+)", payload) sediment_ids = re.findall(r"sediment://([A-Za-z0-9_-]+)", payload) return conversation_id, file_ids, sediment_ids def is_image_tool_event(event: dict[str, Any]) -> bool: value = event.get("v") message = event.get("message") or (value.get("message") if isinstance(value, dict) else None) if not isinstance(message, dict): return False metadata = message.get("metadata") or {} author = message.get("author") or {} return author.get("role") == "tool" and metadata.get("async_task_type") == "image_gen" def update_conversation_state(state: ConversationState, payload: str, event: dict[str, Any] | None = None) -> None: conversation_id, file_ids, sediment_ids = extract_conversation_ids(payload) if conversation_id and not state.conversation_id: state.conversation_id = conversation_id if isinstance(event, dict) and is_image_tool_event(event): add_unique(state.file_ids, file_ids) add_unique(state.sediment_ids, sediment_ids) if not isinstance(event, dict): return state.conversation_id = str(event.get("conversation_id") or state.conversation_id) value = event.get("v") if isinstance(value, dict): state.conversation_id = str(value.get("conversation_id") or state.conversation_id) if event.get("type") == "moderation": moderation = event.get("moderation_response") if isinstance(moderation, dict) and moderation.get("blocked") is True: state.blocked = True if event.get("type") == "server_ste_metadata": metadata = event.get("metadata") if isinstance(metadata, dict): if isinstance(metadata.get("tool_invoked"), bool): state.tool_invoked = metadata["tool_invoked"] state.turn_use_case = str(metadata.get("turn_use_case") or state.turn_use_case) def conversation_base_event(event_type: str, state: ConversationState, **extra: Any) -> dict[str, Any]: return { "type": event_type, "text": state.text, "conversation_id": state.conversation_id, "file_ids": list(state.file_ids), "sediment_ids": list(state.sediment_ids), "blocked": state.blocked, "tool_invoked": state.tool_invoked, "turn_use_case": state.turn_use_case, **extra, } def iter_conversation_payloads(payloads: Iterator[str], history_text: str = "", history_messages: list[str] | None = None) -> Iterator[dict[str, Any]]: state = ConversationState() history_messages = history_messages or [] history_index = 0 for payload in payloads: # print(f"[upstream_sse] {payload}", flush=True) if not payload: continue if payload == "[DONE]": yield conversation_base_event("conversation.done", state, done=True) break try: event = json.loads(payload) except json.JSONDecodeError: update_conversation_state(state, payload) yield conversation_base_event("conversation.raw", state, payload=payload) continue if not isinstance(event, dict): yield conversation_base_event("conversation.event", state, raw=event) continue update_conversation_state(state, payload, event) if history_index < len(history_messages) and event_assistant_text(event, history_text) == history_messages[history_index]: history_index += 1 state.text = "" continue next_text = assistant_text(event, state.text, history_text) if next_text != state.text: delta = next_text[len(state.text):] if next_text.startswith(state.text) else next_text state.text = next_text yield conversation_base_event("conversation.delta", state, raw=event, delta=delta) continue yield conversation_base_event("conversation.event", state, raw=event) def conversation_events( backend: OpenAIBackendAPI, messages: list[dict[str, Any]] | None = None, model: str = "auto", prompt: str = "", images: list[str] | None = None, size: str | None = None, ) -> Iterator[dict[str, Any]]: normalized = normalize_messages(messages or ([{"role": "user", "content": prompt}] if prompt else [])) image_model = str(model or "").strip() in IMAGE_MODELS history_text = "" if image_model else assistant_history_text(normalized) history_messages = [] if image_model else assistant_history_messages(normalized) final_prompt = prompt_with_global_system(build_image_prompt(prompt, size)) if image_model else prompt payloads = backend.stream_conversation( messages=normalized, model=model, prompt=final_prompt, images=images if image_model else None, system_hints=["picture_v2"] if image_model else None, ) yield from iter_conversation_payloads(payloads, history_text, history_messages) def text_backend() -> OpenAIBackendAPI: return OpenAIBackendAPI(access_token=account_service.get_text_access_token()) def stream_text_deltas(backend: OpenAIBackendAPI, request: ConversationRequest) -> Iterator[str]: attempted_tokens: set[str] = set() token = getattr(backend, "access_token", "") emitted = False while True: if token and token in attempted_tokens: raise RuntimeError("no available text account") if token: attempted_tokens.add(token) try: active_backend = OpenAIBackendAPI(access_token=token) for event in conversation_events(active_backend, messages=request.messages, model=request.model, prompt=request.prompt): if event.get("type") != "conversation.delta": continue delta = str(event.get("delta") or "") if delta: emitted = True yield delta account_service.mark_text_used(token) return except Exception as exc: error_message = str(exc) if token and not emitted and is_token_invalid_error(error_message): account_service.remove_invalid_token(token, "text_stream") token = account_service.get_text_access_token(attempted_tokens) if token: continue raise def collect_text(backend: OpenAIBackendAPI, request: ConversationRequest) -> str: return "".join(stream_text_deltas(backend, request)) def stream_image_outputs( backend: OpenAIBackendAPI, request: ConversationRequest, index: int = 1, total: int = 1, ) -> Iterator[ImageOutput]: last: dict[str, Any] = {} for event in conversation_events( backend, prompt=request.prompt, model=request.model, images=request.images or [], size=request.size, ): last = event if event.get("type") == "conversation.delta": yield ImageOutput( kind="progress", model=request.model, index=index, total=total, text=str(event.get("delta") or ""), upstream_event_type="conversation.delta", ) continue if event.get("type") == "conversation.event": raw = event.get("raw") raw_type = str(raw.get("type") or "") if isinstance(raw, dict) else "" yield ImageOutput( kind="progress", model=request.model, index=index, total=total, upstream_event_type=raw_type, ) conversation_id = str(last.get("conversation_id") or "") file_ids = [str(item) for item in last.get("file_ids") or []] sediment_ids = [str(item) for item in last.get("sediment_ids") or []] message = str(last.get("text") or "").strip() is_text_response = last.get("tool_invoked") is False or last.get("turn_use_case") == "text" logger.info({ "event": "image_stream_resolve_start", "conversation_id": conversation_id, "file_ids": file_ids, "sediment_ids": sediment_ids, "tool_invoked": last.get("tool_invoked"), "turn_use_case": last.get("turn_use_case"), }) if message and not file_ids and not sediment_ids and (last.get("blocked") or is_text_response): yield ImageOutput(kind="message", model=request.model, index=index, total=total, text=message) return image_urls = backend.resolve_conversation_image_urls(conversation_id, file_ids, sediment_ids) if image_urls: image_items = [ {"b64_json": base64.b64encode(image_data).decode("ascii")} for image_data in backend.download_image_bytes(image_urls) ] data = format_image_result( image_items, request.prompt, request.response_format, request.base_url, int(time.time()), )["data"] if data: yield ImageOutput(kind="result", model=request.model, index=index, total=total, data=data) return if message: yield ImageOutput(kind="message", model=request.model, index=index, total=total, text=message) def stream_image_outputs_with_pool(request: ConversationRequest) -> Iterator[ImageOutput]: if str(request.model or "").strip() not in IMAGE_MODELS: raise ImageGenerationError("unsupported image model,supported models: " + ", ".join(IMAGE_MODELS)) emitted = False last_error = "" for index in range(1, request.n + 1): while True: try: token = account_service.get_available_access_token() except RuntimeError as exc: if emitted: return raise ImageGenerationError(str(exc) or "image generation failed") from exc emitted_for_token = False returned_message = False returned_result = False try: backend = OpenAIBackendAPI(access_token=token) for output in stream_image_outputs(backend, request, index, request.n): if output.kind == "message" and request.message_as_error: raise ImageGenerationError( output.text or "Image generation was rejected by upstream policy.", status_code=400, error_type="invalid_request_error", code="content_policy_violation", ) emitted = True emitted_for_token = True returned_message = output.kind == "message" returned_result = returned_result or output.kind == "result" yield output if returned_message or not returned_result: account_service.mark_image_result(token, False) return account_service.mark_image_result(token, True) break except ImageGenerationError: account_service.mark_image_result(token, False) raise except Exception as exc: account_service.mark_image_result(token, False) last_error = str(exc) logger.warning({"event": "image_stream_fail", "request_token": token, "error": last_error}) if not emitted_for_token and is_token_invalid_error(last_error): account_service.remove_invalid_token(token, "image_stream") continue raise ImageGenerationError(image_stream_error_message(last_error)) from exc if not emitted: raise ImageGenerationError(image_stream_error_message(last_error)) def stream_image_chunks(outputs: Iterable[ImageOutput]) -> Iterator[dict[str, Any]]: for output in outputs: yield output.to_chunk() def collect_image_outputs(outputs: Iterable[ImageOutput]) -> dict[str, Any]: created = None data: list[dict[str, Any]] = [] message = "" progress_parts: list[str] = [] for output in outputs: created = created or output.created if output.kind == "progress" and output.text: progress_parts.append(output.text) elif output.kind == "message": message = output.text elif output.kind == "result": data.extend(output.data) result: dict[str, Any] = {"created": created or int(time.time()), "data": data} if not data: text = message or "".join(progress_parts).strip() if text: result["message"] = text return result