Spaces:
Paused
Paused
| import base64 | |
| import hashlib | |
| import json | |
| import re | |
| import time | |
| import uuid | |
| from pathlib import Path | |
| from typing import Any, Iterator | |
| from curl_cffi import requests | |
| from fastapi import HTTPException | |
| from utils.log import logger | |
# Model names that, on their own, mark a chat request as an image request
# (checked by is_image_chat_request).
IMAGE_MODELS = {"codex-gpt-image-2", "gpt-image-2"}
# Directory next to this module where save_images_from_text writes decoded images.
OUTPUT_DIR = Path(__file__).resolve().parent.joinpath("output")
def new_uuid() -> str:
    """Return a freshly generated random (version 4) UUID as a canonical string."""
    generated = uuid.uuid4()
    return f"{generated}"
def is_image_chat_request(body: dict[str, object]) -> bool:
    """Decide whether a chat request should be served as an image request.

    Returns True when the whitespace-trimmed ``model`` is one of
    ``IMAGE_MODELS``, or when ``modalities`` is a list containing ``"image"``
    (compared case-insensitively after trimming each entry).
    """
    model_name = str(body.get("model") or "").strip()
    requested = body.get("modalities")
    if model_name in IMAGE_MODELS:
        return True
    if not isinstance(requested, list):
        return False
    # None/empty entries normalize to "" and therefore never match.
    normalized = set(str(entry or "").strip().lower() for entry in requested)
    return "image" in normalized
def ensure_ok(response: requests.Response, context: str) -> None:
    """Raise ``RuntimeError`` unless ``response`` carries a 2xx status code.

    The error message includes ``context``, the status code and the body:
    the decoded JSON when the body parses as JSON, otherwise the raw text.
    """
    status = response.status_code
    if 200 <= status < 300:
        return
    raw_text: Any = response.text
    try:
        parsed: Any = response.json()
    except Exception:
        # Body is not JSON (or cannot be parsed): report the raw text instead.
        parsed = raw_text
    raise RuntimeError(f"{context} failed: status={status}, body={parsed}")
def sse_json_stream(items) -> Iterator[str]:
    """Serialize ``items`` as a Server-Sent Events stream of JSON ``data:`` frames.

    Emits an opening SSE comment, one ``data:`` frame per item (JSON with
    non-ASCII characters kept as-is), and a final ``data: [DONE]`` frame.
    If producing or serializing an item raises, the error is logged as a
    warning and sent as a ``data:`` frame before ``[DONE]``. The frame uses
    the exception's ``to_openai_error()`` result when that method exists,
    otherwise a ``{"error": {"message", "type"}}`` dict.
    """
    def frame(payload) -> str:
        return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

    yield ": stream-open\n\n"
    try:
        for item in items:
            yield frame(item)
    except Exception as exc:
        logger.warning({
            "event": "sse_stream_error",
            "error_type": exc.__class__.__name__,
            "error": str(exc),
        })
        if hasattr(exc, "to_openai_error"):
            failure = exc.to_openai_error()
        else:
            failure = {"error": {"message": str(exc), "type": exc.__class__.__name__}}
        yield frame(failure)
    yield "data: [DONE]\n\n"
def anthropic_sse_stream(items) -> Iterator[str]:
    """Serialize ``items`` as a named-event Server-Sent Events stream.

    Each item yields an ``event:`` line followed by a JSON ``data:`` frame.
    The event name is the item's ``type`` for dict items, falling back to
    ``message_delta`` when that key is missing or empty. Non-dict items
    always use ``message_delta``. If producing or serializing an item raises,
    the error is logged and sent as an ``event: error`` frame. No terminator
    frame is emitted.
    """
    try:
        for item in items:
            event_name = "message_delta"
            if isinstance(item, dict):
                event_name = str(item.get("type") or "message_delta")
            yield f"event: {event_name}\n"
            yield f"data: {json.dumps(item, ensure_ascii=False)}\n\n"
    except Exception as exc:
        logger.warning({
            "event": "anthropic_sse_stream_error",
            "error_type": exc.__class__.__name__,
            "error": str(exc),
        })
        failure = {
            "type": "error",
            "error": {"type": exc.__class__.__name__, "message": str(exc)},
        }
        yield "event: error\n"
        yield f"data: {json.dumps(failure, ensure_ascii=False)}\n\n"
def iter_sse_payloads(response: requests.Response) -> Iterator[str]:
    """Yield the trimmed, non-empty payload of every ``data:`` line of a response.

    Lines come from ``response.iter_lines()``. Byte lines are decoded as
    UTF-8 with undecodable bytes dropped. Blank lines and lines that do not
    start with ``data:`` are skipped.
    """
    marker = "data:"
    for chunk in response.iter_lines():
        if not chunk:
            continue
        if isinstance(chunk, bytes):
            text = chunk.decode("utf-8", errors="ignore")
        else:
            text = str(chunk)
        if text.startswith(marker):
            body = text[len(marker):].strip()
            if body:
                yield body
# Base64 image data URL. The MIME subtype (which becomes the file extension)
# is limited to letters, digits and ". _ + -" so it can never contain a path
# separator, and it cannot swallow unrelated text up to a distant ";".
_IMAGE_DATA_URL_RE = re.compile(r"data:image/([A-Za-z0-9._+-]+);base64,([A-Za-z0-9+/=]+)")


def save_images_from_text(text: str, prefix: str, *, output_dir: "Path | None" = None) -> list[Path]:
    """Decode every base64 ``data:image/...`` URL found in ``text`` to a file.

    Files are named ``{prefix}_{timestamp_ms}_{index}.{extension}``, where
    ``index`` is the 1-based position of the match in ``text`` and
    ``extension`` is the image MIME subtype (``jpeg`` is written as ``jpg``).

    Args:
        text: Text to scan; ``None`` or empty means no images.
        prefix: File-name prefix for the saved images.
        output_dir: Directory to write into; defaults to ``OUTPUT_DIR``.
            It is created (with parents) if missing.

    Returns:
        Paths of the files written, in match order. Matches whose payload
        is not valid base64 are skipped.
    """
    target_dir = OUTPUT_DIR if output_dir is None else Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    saved_paths: list[Path] = []
    timestamp = int(time.time() * 1000)
    for index, match in enumerate(_IMAGE_DATA_URL_RE.finditer(text or ""), start=1):
        image_type, encoded = match.group(1), match.group(2)
        try:
            payload = base64.b64decode(encoded)
        except ValueError:
            # binascii.Error (a ValueError subclass) on bad padding: skip this
            # match instead of aborting the remaining images.
            continue
        extension = "jpg" if image_type == "jpeg" else image_type
        output_path = target_dir / f"{prefix}_{timestamp}_{index}.{extension}"
        output_path.write_bytes(payload)
        saved_paths.append(output_path)
    return saved_paths
def anonymize_token(token: object) -> str:
    """Return a short, non-reversible label for ``token`` that is safe to log.

    The label is ``token:`` plus the first 10 hex digits of the SHA-256 of the
    whitespace-trimmed token. Falsy or blank tokens give ``token:empty``.
    """
    cleaned = str(token or "").strip()
    if cleaned:
        fingerprint = hashlib.sha256(cleaned.encode("utf-8")).hexdigest()
        return "token:" + fingerprint[:10]
    return "token:empty"
def extract_response_prompt(input_value: object) -> str:
    """Collect the user prompt text from a Responses-style ``input`` value.

    - A string is returned whitespace-trimmed.
    - A single message dict contributes its content only when its ``role``
      is missing/empty or ``user``.
    - A list is scanned item by item. ``input_text`` dicts contribute their
      trimmed ``text``; other dicts are treated as messages and filtered by
      role as above; non-dict items are ignored. Non-empty pieces are joined
      with newlines.

    Any other value yields an empty string.
    """
    if isinstance(input_value, str):
        return input_value.strip()

    def is_user_or_unset(role_value: object) -> bool:
        role_name = str(role_value or "").strip().lower()
        return not role_name or role_name == "user"

    if isinstance(input_value, dict):
        if not is_user_or_unset(input_value.get("role")):
            return ""
        return extract_prompt_from_message_content(input_value.get("content"))
    if not isinstance(input_value, list):
        return ""

    collected: list[str] = []
    for entry in input_value:
        if not isinstance(entry, dict):
            continue
        if str(entry.get("type") or "").strip() == "input_text":
            # Bare text parts are taken as-is, without a role check.
            piece = str(entry.get("text") or "").strip()
        elif is_user_or_unset(entry.get("role")):
            piece = extract_prompt_from_message_content(entry.get("content"))
        else:
            continue
        if piece:
            collected.append(piece)
    return "\n".join(collected).strip()
def has_response_image_generation_tool(body: dict[str, object]) -> bool:
    """Report whether a request asks for the ``image_generation`` tool.

    True when any dict in ``tools`` has a (whitespace-trimmed) ``type`` of
    ``image_generation``, or when ``tool_choice`` is a dict with that type.
    """
    def is_image_tool(spec: object) -> bool:
        return isinstance(spec, dict) and str(spec.get("type") or "").strip() == "image_generation"

    declared = body.get("tools")
    if isinstance(declared, list) and any(is_image_tool(spec) for spec in declared):
        return True
    return is_image_tool(body.get("tool_choice"))
def extract_prompt_from_message_content(content: object) -> str:
    """Flatten a message ``content`` value into plain prompt text.

    A string is returned whitespace-trimmed. For a list, every ``text`` part
    contributes its trimmed ``text``. Every ``input_text`` part contributes
    its trimmed ``text``, falling back to an ``input_text`` key. Other parts
    and non-dict entries are ignored, and the non-empty pieces are joined
    with newlines. Any other value yields an empty string.
    """
    if isinstance(content, str):
        return content.strip()
    if not isinstance(content, list):
        return ""
    pieces: list[str] = []
    for part in content:
        if not isinstance(part, dict):
            continue
        kind = str(part.get("type") or "").strip()
        if kind == "text":
            raw = part.get("text")
        elif kind == "input_text":
            raw = part.get("text") or part.get("input_text")
        else:
            continue
        snippet = str(raw or "").strip()
        if snippet:
            pieces.append(snippet)
    return "\n".join(pieces).strip()
def _image_url_string(value: object) -> str:
    """Return the URL held by ``value``: either a bare string or an ``{"url": ...}`` dict."""
    if isinstance(value, dict):
        return str(value.get("url") or "")
    return str(value or "")


def _decode_image_data_url(url: str) -> "tuple[bytes, str] | None":
    """Decode a ``data:`` URL into ``(bytes, mime)``; return None for any other URL.

    Everything after the first comma is base64-decoded; invalid padding
    raises ``binascii.Error``. The MIME type comes from the header and
    defaults to ``image/png`` when the header has none.
    """
    if not url.startswith("data:"):
        return None
    header, _, data = url.partition(",")
    mime = header.split(";")[0].removeprefix("data:")
    return base64.b64decode(data), mime or "image/png"


def extract_image_from_message_content(content: object) -> list[tuple[bytes, str]]:
    """Return ``(bytes, mime)`` for every inline (``data:``) image in ``content``.

    ``content`` is expected to be a list of message parts; anything else
    yields ``[]``. Parts of type ``image_url`` read their URL from
    ``image_url`` (a string or ``{"url": ...}``), falling back to a ``url``
    key on the part itself. Parts of type ``input_image`` read ``image_url``
    as a string or ``{"url": ...}``. Non-``data:`` URLs and other part types
    are ignored.

    Raises:
        binascii.Error: if a data URL's base64 payload has invalid padding.
    """
    if not isinstance(content, list):
        return []
    images: list[tuple[bytes, str]] = []
    for item in content:
        if not isinstance(item, dict):
            continue
        item_type = str(item.get("type") or "").strip()
        if item_type == "image_url":
            # Missing/empty "image_url" falls back to the part itself, so a
            # top-level "url" key is honoured.
            url = _image_url_string(item.get("image_url") or item)
        elif item_type == "input_image":
            url = _image_url_string(item.get("image_url"))
        else:
            continue
        decoded = _decode_image_data_url(url)
        if decoded is not None:
            images.append(decoded)
    return images
def extract_chat_image(body: dict[str, object]) -> list[tuple[bytes, str]]:
    """Return the images attached to the most recent user message that has any.

    Walks ``messages`` from newest to oldest, considering only dict entries
    whose ``role`` is ``user`` (trimmed, case-insensitive). It returns the
    first non-empty result of ``extract_image_from_message_content``.
    Returns ``[]`` when ``messages`` is not a list or no user message
    carries images.
    """
    messages = body.get("messages")
    if not isinstance(messages, list):
        return []
    newest_user_messages = (
        message
        for message in reversed(messages)
        if isinstance(message, dict) and str(message.get("role") or "").strip().lower() == "user"
    )
    for message in newest_user_messages:
        found = extract_image_from_message_content(message.get("content"))
        if found:
            return found
    return []
def extract_chat_prompt(body: dict[str, object]) -> str:
    """Build the prompt text for a chat-style request.

    A non-empty (trimmed) top-level ``prompt`` takes precedence. Otherwise,
    the text of every ``user`` message in ``messages`` (oldest first) is
    joined with newlines. Returns ``""`` when neither source yields text.
    """
    explicit = str(body.get("prompt") or "").strip()
    if explicit:
        return explicit
    messages = body.get("messages")
    if not isinstance(messages, list):
        return ""
    texts = [
        extract_prompt_from_message_content(message.get("content"))
        for message in messages
        if isinstance(message, dict) and str(message.get("role") or "").strip().lower() == "user"
    ]
    return "\n".join(text for text in texts if text).strip()
def parse_image_count(raw_value: object) -> int:
    """Parse and validate the requested number of images (``n``).

    ``None`` or an empty string means "not specified" and yields the default
    of 1. Any other value must convert to an integer between 1 and 4
    inclusive.

    Raises:
        HTTPException: 400 when the value is not an integer (including
            non-integral or infinite floats) or lies outside 1..4.
    """
    # Only a missing value defaults to 1. The previous `raw_value or 1` also
    # turned an explicit 0 into 1, bypassing the range check below.
    if raw_value is None or raw_value == "":
        return 1
    if isinstance(raw_value, float) and not raw_value.is_integer():
        # int() would silently truncate 2.5 to 2, and would raise an uncaught
        # OverflowError for +/-inf.
        raise HTTPException(status_code=400, detail={"error": "n must be an integer"})
    try:
        value = int(raw_value)
    except (TypeError, ValueError) as exc:
        raise HTTPException(status_code=400, detail={"error": "n must be an integer"}) from exc
    if value < 1 or value > 4:
        raise HTTPException(status_code=400, detail={"error": "n must be between 1 and 4"})
    return value
def build_chat_image_markdown_content(image_result: dict[str, object]) -> str:
    """Render generated images as inline Markdown for a chat message.

    Each dict in ``image_result["data"]`` with a non-blank ``b64_json``
    becomes a Markdown image whose source is a base64 data URL. The alt text
    ``image_{index}`` uses the 1-based position in ``data``, and images are
    separated by blank lines. When nothing renderable is present, a plain
    completion notice is returned instead.
    """
    raw_items = image_result.get("data")
    image_items = raw_items if isinstance(raw_items, list) else []
    markdown_images: list[str] = []
    for index, item in enumerate(image_items, start=1):
        if not isinstance(item, dict):
            continue
        b64_json = str(item.get("b64_json") or "").strip()
        if b64_json:
            # Previously an empty f-string was appended here, so successful
            # generations rendered as blank content.
            # NOTE(review): assumes PNG payloads; confirm the backend's output format.
            markdown_images.append(f"![image_{index}](data:image/png;base64,{b64_json})")
    return "\n\n".join(markdown_images) if markdown_images else "Image generation completed."