chatgpt2api / utils /helper.py
tx1538's picture
Upload 179 files
9d7ddb9 verified
Raw
History Blame
9.15 kB
import base64
import hashlib
import json
import re
import time
import uuid
from pathlib import Path
from typing import Any, Iterator
from curl_cffi import requests
from fastapi import HTTPException
from utils.log import logger
# Model names that `is_image_chat_request` treats as image requests outright.
IMAGE_MODELS = {"codex-gpt-image-2", "gpt-image-2"}
# Folder named "output" beside this module; `save_images_from_text` writes here.
OUTPUT_DIR = Path(__file__).resolve().parent.joinpath("output")
def new_uuid() -> str:
    """Return a freshly generated random (version 4) UUID as a canonical string."""
    generated = uuid.uuid4()
    return f"{generated}"
def is_image_chat_request(body: dict[str, object]) -> bool:
    """Decide whether a chat request body asks for image output.

    The request qualifies when its ``model`` (stripped) is listed in
    ``IMAGE_MODELS``. It also qualifies when ``modalities`` is a list that
    contains ``"image"``; each entry is compared after ``str()``, strip and
    lower-case, and falsy entries count as ``""``.
    """
    if str(body.get("model") or "").strip() in IMAGE_MODELS:
        return True
    requested = body.get("modalities")
    if not isinstance(requested, list):
        return False
    return any(str(entry or "").strip().lower() == "image" for entry in requested)
def ensure_ok(response: requests.Response, context: str) -> None:
    """Raise ``RuntimeError`` unless ``response`` has a 2xx status code.

    Args:
        response: The HTTP response to check.
        context: Short label for the failing operation, used as the message prefix.

    Raises:
        RuntimeError: On any non-2xx status. The message embeds the status
            code and the body, using the decoded JSON when the body parses
            and the raw text otherwise.
    """
    status_code = response.status_code
    if not 200 <= status_code < 300:
        raw_text = response.text
        try:
            detail: Any = response.json()
        except Exception:
            # Best effort: bodies that are not valid JSON are reported verbatim.
            detail = raw_text
        raise RuntimeError(f"{context} failed: status={status_code}, body={detail}")
def sse_json_stream(items) -> Iterator[str]:
    """Serialize ``items`` as Server-Sent Events terminated by ``data: [DONE]``.

    The stream starts with a ``: stream-open`` comment frame. Each item then
    becomes one ``data:`` frame of JSON, with non-ASCII characters kept
    verbatim.

    If iterating ``items`` or serializing an item raises, the exception is
    logged and emitted as one more JSON frame before the ``[DONE]`` sentinel.
    The exception's own ``to_openai_error()`` payload is used when it has one;
    otherwise a generic ``{"error": {...}}`` payload is emitted.
    """
    def frame(payload) -> str:
        return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

    yield ": stream-open\n\n"
    try:
        for item in items:
            yield frame(item)
    except Exception as exc:
        logger.warning({
            "event": "sse_stream_error",
            "error_type": exc.__class__.__name__,
            "error": str(exc),
        })
        if hasattr(exc, "to_openai_error"):
            payload = exc.to_openai_error()
        else:
            payload = {"error": {"message": str(exc), "type": exc.__class__.__name__}}
        yield frame(payload)
    yield "data: [DONE]\n\n"
def anthropic_sse_stream(items) -> Iterator[str]:
    """Serialize ``items`` as named Server-Sent Events (``event:`` + ``data:``).

    The event name is the item's ``type`` when the item is a dict with a
    truthy ``type``. Otherwise it is ``message_delta``. The ``event:`` line
    is yielded before the item is JSON-encoded.

    If iteration or serialization raises, the exception is logged and emitted
    as an ``event: error`` frame. No terminating frame is added by this
    function.
    """
    try:
        for item in items:
            event_name = "message_delta"
            if isinstance(item, dict):
                event_name = str(item.get("type") or "message_delta")
            yield f"event: {event_name}\n"
            yield f"data: {json.dumps(item, ensure_ascii=False)}\n\n"
    except Exception as exc:
        logger.warning({
            "event": "anthropic_sse_stream_error",
            "error_type": exc.__class__.__name__,
            "error": str(exc),
        })
        payload = {"type": "error", "error": {"type": exc.__class__.__name__, "message": str(exc)}}
        yield "event: error\n"
        yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
def iter_sse_payloads(response: requests.Response) -> Iterator[str]:
    """Yield the stripped, non-empty payload of every ``data:`` line in ``response``.

    Byte lines are decoded as UTF-8, and undecodable bytes are dropped. Empty
    lines, non-``data:`` lines (``event:``, comments, ...) and blank payloads
    are skipped. Each ``data:`` line is yielded on its own; multi-line events
    are not joined.
    """
    prefix = "data:"
    for raw_line in response.iter_lines():
        if not raw_line:
            continue
        if isinstance(raw_line, bytes):
            text = raw_line.decode("utf-8", errors="ignore")
        else:
            text = str(raw_line)
        if text.startswith(prefix):
            payload = text[len(prefix):].strip()
            if payload:
                yield payload
def save_images_from_text(text: str, prefix: str, output_dir: "Path | None" = None) -> list[Path]:
    """Decode every base64 ``data:image/...`` URL found in ``text`` and write it to disk.

    Files are named ``{prefix}_{timestamp_ms}_{index}.{ext}``. ``index`` is
    the 1-based position of the match in ``text``. ``ext`` is derived from
    the MIME subtype: ``jpeg`` becomes ``jpg``, ``svg+xml`` becomes ``svg``,
    and only alphanumerics are kept, with ``png`` as the fallback.

    Args:
        text: Text that may contain inline data-URL images; ``None``/empty is allowed.
        prefix: Filename prefix for the saved images.
        output_dir: Target directory, created if missing. Defaults to ``OUTPUT_DIR``.

    Returns:
        Paths of the files written, in order of appearance. Images whose
        base64 payload cannot be decoded are logged and skipped.
    """
    target_dir = output_dir if output_dir is not None else OUTPUT_DIR
    target_dir.mkdir(parents=True, exist_ok=True)
    # Restrict the subtype to MIME token characters. A looser pattern like
    # ``[^;]+`` would let arbitrary text (spaces, newlines, "/" and "..")
    # flow into the filename extension and escape ``target_dir``.
    pattern = re.compile(r"data:image/([A-Za-z0-9.+-]+);base64,([A-Za-z0-9+/=]+)")
    saved_paths: list[Path] = []
    timestamp = int(time.time() * 1000)
    for index, match in enumerate(pattern.finditer(text or ""), start=1):
        subtype = match.group(1).lower()
        # Drop structured-syntax suffixes ("svg+xml" -> "svg") and any dots so
        # the extension can never contain path components.
        extension = re.sub(r"[^a-z0-9]", "", subtype.split("+", 1)[0]) or "png"
        if extension == "jpeg":
            extension = "jpg"
        try:
            payload = base64.b64decode(match.group(2))
        except ValueError as exc:  # binascii.Error, e.g. truncated padding
            logger.warning({
                "event": "save_image_decode_error",
                "index": index,
                "error": str(exc),
            })
            continue
        output_path = target_dir / f"{prefix}_{timestamp}_{index}.{extension}"
        output_path.write_bytes(payload)
        saved_paths.append(output_path)
    return saved_paths
def anonymize_token(token: object) -> str:
    """Return a short, log-safe label derived from ``token``.

    The token is converted to text and stripped. Falsy or blank input yields
    ``"token:empty"``. Anything else yields ``"token:"`` followed by the first
    10 hex digits of the SHA-256 digest of its UTF-8 encoding.
    """
    normalized = str(token or "").strip()
    if normalized:
        fingerprint = hashlib.sha256(normalized.encode("utf-8")).hexdigest()
        return "token:" + fingerprint[:10]
    return "token:empty"
def extract_response_prompt(input_value: object) -> str:
    """Collect the user-authored prompt text from a Responses-style ``input`` value.

    - ``str``: returned stripped.
    - ``dict``: treated as a single message. It yields ``""`` unless its role
      is empty or ``"user"``; otherwise its ``content`` is flattened.
    - ``list``: ``input_text`` dict entries contribute their ``text``. Other
      dict entries are treated as messages, as above. Non-dict entries are
      skipped.

    Any other type yields ``""``. Collected pieces are joined with newlines.
    """
    def _is_user(entry: dict) -> bool:
        role = str(entry.get("role") or "").strip().lower()
        return role in ("", "user")

    if isinstance(input_value, str):
        return input_value.strip()
    if isinstance(input_value, dict):
        if not _is_user(input_value):
            return ""
        return extract_prompt_from_message_content(input_value.get("content"))
    if not isinstance(input_value, list):
        return ""

    collected: list[str] = []
    for entry in input_value:
        if not isinstance(entry, dict):
            continue
        if str(entry.get("type") or "").strip() == "input_text":
            piece = str(entry.get("text") or "").strip()
        elif _is_user(entry):
            piece = extract_prompt_from_message_content(entry.get("content"))
        else:
            continue
        if piece:
            collected.append(piece)
    return "\n".join(collected).strip()
def has_response_image_generation_tool(body: dict[str, object]) -> bool:
    """Report whether the request enables the ``image_generation`` tool.

    Returns True when any dict in the ``tools`` list has a (stripped)
    ``type`` of ``"image_generation"``, or when ``tool_choice`` is a dict of
    that type.
    """
    def _is_image_tool(spec: object) -> bool:
        return isinstance(spec, dict) and str(spec.get("type") or "").strip() == "image_generation"

    declared = body.get("tools")
    if isinstance(declared, list) and any(_is_image_tool(spec) for spec in declared):
        return True
    return _is_image_tool(body.get("tool_choice"))
def extract_prompt_from_message_content(content: object) -> str:
    """Flatten chat message ``content`` into plain prompt text.

    A string is returned stripped. A list is scanned for dict parts of two
    types:
    - ``text`` parts, read from their ``text`` key;
    - ``input_text`` parts, read from ``text`` with ``input_text`` as the
      fallback.

    Their stripped, non-empty values are joined with newlines. Any other
    content yields ``""``.
    """
    if isinstance(content, str):
        return content.strip()
    if not isinstance(content, list):
        return ""

    pieces: list[str] = []
    for part in content:
        if not isinstance(part, dict):
            continue
        kind = str(part.get("type") or "").strip()
        if kind == "text":
            raw = part.get("text")
        elif kind == "input_text":
            raw = part.get("text") or part.get("input_text")
        else:
            continue
        piece = str(raw or "").strip()
        if piece:
            pieces.append(piece)
    return "\n".join(pieces).strip()
def _decode_image_data_url(url: str) -> tuple[bytes, str]:
    """Decode an RFC 2397 ``data:`` URL into ``(payload_bytes, mime_type)``.

    The MIME type defaults to ``image/png`` when the URL omits it. Payloads
    flagged ``;base64`` are base64-decoded, and malformed base64 raises
    ``binascii.Error``. All other payloads are percent-decoded.
    """
    from urllib.parse import unquote_to_bytes

    header, _, data = url.partition(",")
    media_type, *params = header.removeprefix("data:").split(";")
    if any(param.strip().lower() == "base64" for param in params):
        payload = base64.b64decode(data)
    else:
        # Non-base64 data URLs carry percent-encoded bytes. Decoding them as
        # base64 would silently produce garbage or fail on padding.
        payload = unquote_to_bytes(data)
    return payload, media_type or "image/png"


def extract_image_from_message_content(content: object) -> list[tuple[bytes, str]]:
    """Extract inline ``data:`` URL images from chat message ``content``.

    Recognised parts:
    - ``image_url`` parts. ``image_url`` may be a ``{"url": ...}`` dict or a
      bare string; when it is absent, the part itself is searched for ``url``.
    - ``input_image`` parts. ``image_url`` may be a string or a
      ``{"url": ...}`` dict.

    Non-``data:`` (remote) URLs are skipped.

    Returns:
        ``(image_bytes, mime_type)`` tuples in content order. Non-list content
        yields ``[]``.
    """
    if not isinstance(content, list):
        return []
    images: list[tuple[bytes, str]] = []
    for item in content:
        if not isinstance(item, dict):
            continue
        item_type = str(item.get("type") or "").strip()
        if item_type == "image_url":
            source = item.get("image_url") or item
        elif item_type == "input_image":
            source = item.get("image_url")
        else:
            continue
        if isinstance(source, dict):
            url = str(source.get("url") or "")
        else:
            url = str(source or "")
        if url.startswith("data:"):
            images.append(_decode_image_data_url(url))
    return images
def extract_chat_image(body: dict[str, object]) -> list[tuple[bytes, str]]:
    """Return the inline images of the most recent user message that has any.

    Walks ``body["messages"]`` from last to first and considers only dict
    messages whose role (stripped, lower-cased) is ``"user"``. The first such
    message whose content yields images wins.

    Returns ``[]`` when no message yields images or when ``messages`` is not
    a list.
    """
    messages = body.get("messages")
    if not isinstance(messages, list):
        return []
    user_messages = (
        message
        for message in reversed(messages)
        if isinstance(message, dict) and str(message.get("role") or "").strip().lower() == "user"
    )
    for message in user_messages:
        found = extract_image_from_message_content(message.get("content"))
        if found:
            return found
    return []
def extract_chat_prompt(body: dict[str, object]) -> str:
    """Return the prompt text for a chat-style request body.

    A non-blank top-level ``prompt`` takes precedence and is returned
    stripped. Otherwise the flattened content of every user-role dict message
    is joined with newlines, with empty pieces dropped. Returns ``""`` when
    ``messages`` is not a list.
    """
    explicit = str(body.get("prompt") or "").strip()
    if explicit:
        return explicit
    messages = body.get("messages")
    if not isinstance(messages, list):
        return ""
    texts = (
        extract_prompt_from_message_content(message.get("content"))
        for message in messages
        if isinstance(message, dict) and str(message.get("role") or "").strip().lower() == "user"
    )
    return "\n".join(text for text in texts if text).strip()
def parse_image_count(raw_value: object) -> int:
    """Validate the requested number of images (the ``n`` parameter).

    A missing value (``None`` or ``""``) defaults to 1. Otherwise the value
    must be an integer in the range 1..4. Ints, integral floats (``2.0``) and
    numeric strings (``"3"``) are accepted.

    Raises:
        HTTPException: 400 when the value is not an integer or is out of range.
    """
    # Only a genuinely absent value gets the default. ``raw_value or 1``
    # would silently turn an explicit ``0`` into 1, while ``"0"`` was rejected.
    if raw_value is None or raw_value == "":
        return 1
    # bool is an int subclass and int() truncates floats. Reject both rather
    # than silently turning ``True``/``2.5`` into a count.
    if isinstance(raw_value, bool) or (isinstance(raw_value, float) and not raw_value.is_integer()):
        raise HTTPException(status_code=400, detail={"error": "n must be an integer"})
    try:
        value = int(raw_value)
    except (TypeError, ValueError) as exc:
        raise HTTPException(status_code=400, detail={"error": "n must be an integer"}) from exc
    if not 1 <= value <= 4:
        raise HTTPException(status_code=400, detail={"error": "n must be between 1 and 4"})
    return value
def build_chat_image_markdown_content(image_result: dict[str, object]) -> str:
    """Render generated images as Markdown images with inline data URLs.

    Each dict in ``image_result["data"]`` with a non-blank ``b64_json``
    becomes ``![image_N](data:image/png;base64,...)``. N is the entry's
    1-based position in ``data``, so skipped entries still consume a number.

    Images are separated by blank lines. When there are none, a fixed
    completion message is returned instead.
    """
    entries = image_result.get("data")
    if not isinstance(entries, list):
        entries = []
    rendered = []
    for position, entry in enumerate(entries, start=1):
        encoded = str(entry.get("b64_json") or "").strip() if isinstance(entry, dict) else ""
        if encoded:
            # NOTE(review): MIME is hard-coded to PNG; confirm the upstream
            # image payload is always PNG.
            rendered.append(f"![image_{position}](data:image/png;base64,{encoded})")
    if not rendered:
        return "Image generation completed."
    return "\n\n".join(rendered)