# grok2api-hf/app/services/reverse/ws_imagine.py
# Codex
# Add root Dockerfile for HF Space build
# 6bff6a1
"""
Reverse interface: Imagine WebSocket image stream.
"""
import asyncio
import base64
import binascii
import orjson
import re
import time
import uuid
from typing import AsyncGenerator, Dict, Optional
import aiohttp
from app.core.config import get_config
from app.core.logger import logger
from app.services.reverse.utils.headers import build_ws_headers
from app.services.reverse.utils.websocket import WebSocketClient
# WebSocket endpoint that ImagineWebSocketReverse._stream_once() connects to.
WS_IMAGINE_URL = "wss://grok.com/ws/imagine/listen"
class _BlockedError(Exception):
    """Signal that an Imagine stream looks blocked or held for review.

    Raised by ``ImagineWebSocketReverse._stream_once`` when a medium-stage
    image was received but no final image followed within the configured
    window, and caught by ``ImagineWebSocketReverse.stream`` to drive its
    retry logic.
    """
class ImagineWebSocketReverse:
    """Imagine WebSocket reverse interface.

    Opens a WebSocket to ``WS_IMAGINE_URL``, sends one image-generation
    request and turns the incoming ``image`` / ``error`` messages into plain
    dict events. ``stream`` is the public entry point; ``_stream_once`` runs a
    single connection attempt.
    """

    # Compiled once for the class and still reachable as ``self._url_pattern``.
    # Captures the image id and the file extension from an image URL.
    _url_pattern = re.compile(r"/images/([a-f0-9-]+)\.(png|jpg|jpeg)")

    def __init__(self) -> None:
        """Create the WebSocket client used for every connection attempt."""
        self._client = WebSocketClient()

    def _parse_image_url(self, url: str) -> tuple[Optional[str], Optional[str]]:
        """Extract ``(image_id, extension)`` from an image URL.

        Returns ``(None, None)`` when ``url`` is empty/``None`` or does not
        contain an ``/images/<id>.<ext>`` segment.
        """
        match = self._url_pattern.search(url or "")
        if not match:
            return None, None
        return match.group(1), match.group(2).lower()

    def _decoded_blob_size(self, blob: str) -> int:
        """Return the byte length of ``blob`` after base64 decoding.

        A leading ``...base64,`` prefix (data-URI form) is stripped before
        decoding. Empty or undecodable input yields ``0``.
        """
        if not blob:
            return 0
        # Split once: only treat the part before the first comma as a header
        # when it actually announces base64 content.
        header, sep, payload = blob.partition(",")
        data = payload if sep and "base64" in header else blob
        try:
            return len(base64.b64decode(data, validate=False))
        except (ValueError, binascii.Error):
            return 0

    def _is_final_image(self, decoded_size: int, final_min_bytes: int) -> bool:
        """Return whether an image of ``decoded_size`` bytes counts as final."""
        # Final image must satisfy byte-size threshold to avoid tiny preview
        # images being treated as final outputs.
        return decoded_size >= final_min_bytes

    @staticmethod
    def _copy_metadata(message: Dict[str, object]) -> Dict[str, object]:
        """Copy the known metadata keys of ``message`` whose value is not None."""
        metadata: Dict[str, object] = {}
        for key in (
            "width",
            "height",
            "model_name",
            "percentage_complete",
            "job_id",
            "request_id",
            "order",
            "full_prompt",
        ):
            value = message.get(key)
            if value is not None:
                metadata[key] = value
        return metadata

    def _classify_image(
        self,
        message: Dict[str, object],
        final_min_bytes: int,
        medium_min_bytes: int,
    ) -> Optional[Dict[str, object]]:
        """Build an ``image`` event from a raw image message.

        Args:
            message: Decoded WebSocket message; ``url`` and ``blob`` are read.
            final_min_bytes: Decoded size at or above which the image is final.
            medium_min_bytes: Decoded size strictly above which a non-final
                image is staged as ``"medium"`` rather than ``"preview"``.

        Returns:
            The event dict, or ``None`` when ``url`` or ``blob`` is missing.
        """
        url = str(message.get("url") or "")
        blob = str(message.get("blob") or "")
        if not url or not blob:
            return None

        image_id, ext = self._parse_image_url(url)
        # Fall back to a random id so callers can still de-duplicate events.
        image_id = image_id or uuid.uuid4().hex

        encoded_size = len(blob)
        blob_size = self._decoded_blob_size(blob)
        is_final = self._is_final_image(blob_size, final_min_bytes)
        if is_final:
            stage = "final"
        elif blob_size > medium_min_bytes:
            stage = "medium"
        else:
            stage = "preview"

        result: Dict[str, object] = {
            "type": "image",
            "image_id": image_id,
            "ext": ext,
            "stage": stage,
            "blob": blob,
            "blob_size": blob_size,
            "encoded_blob_size": encoded_size,
            "url": url,
            "is_final": is_final,
        }
        result.update(self._copy_metadata(message))
        return result

    def _build_request_message(
        self,
        request_id: str,
        prompt: str,
        aspect_ratio: str,
        enable_nsfw: bool,
    ) -> Dict[str, object]:
        """Build the ``conversation.item.create`` payload for one request."""
        return {
            "type": "conversation.item.create",
            "timestamp": int(time.time() * 1000),
            "item": {
                "type": "message",
                "content": [
                    {
                        "requestId": request_id,
                        "text": prompt,
                        "type": "input_scroll",
                        "properties": {
                            "section_count": 0,
                            "is_kids_mode": False,
                            "enable_nsfw": enable_nsfw,
                            "skip_upsampler": False,
                            "enable_side_by_side": True,
                            "is_initial": False,
                            "last_prompt": prompt,
                            "aspect_ratio": aspect_ratio,
                        },
                    }
                ],
            },
        }

    async def stream(
        self,
        token: str,
        prompt: str,
        aspect_ratio: str = "2:3",
        n: int = 1,
        enable_nsfw: bool = True,
        max_retries: Optional[int] = None,
    ) -> AsyncGenerator[Dict[str, object], None]:
        """Generate images and yield the resulting events, retrying on blocks.

        Each attempt is fully collected before any of its events are yielded,
        so a blocked attempt never leaks partial output to the caller.

        Args:
            token: Credential passed to ``build_ws_headers``.
            prompt: Text prompt for the image request.
            aspect_ratio: Aspect ratio forwarded in the request properties.
            n: Number of final images after which an attempt stops.
            enable_nsfw: Forwarded in the request properties.
            max_retries: Total attempts allowed; values below 1 (or ``None``)
                mean a single attempt.

        Yields:
            ``image`` events from the first successful attempt, or a single
            ``error`` event (``blocked`` / ``ws_stream_failed``).
        """
        retries = max(1, max_retries if max_retries is not None else 1)
        parallel_enabled = bool(get_config("image.blocked_parallel_enabled", True))
        logger.info(
            f"Image generation: prompt='{prompt[:50]}...', n={n}, ratio={aspect_ratio}, nsfw={enable_nsfw}"
        )

        async def _collect_once() -> list[Dict[str, object]]:
            # Buffer one whole attempt; _BlockedError propagates to the caller.
            return [
                item
                async for item in self._stream_once(
                    token, prompt, aspect_ratio, n, enable_nsfw
                )
            ]

        for attempt in range(retries):
            try:
                items = await _collect_once()
                for item in items:
                    yield item
                return
            except _BlockedError:
                retries_left = retries - (attempt + 1)
                if retries_left > 0 and parallel_enabled:
                    # Spend all remaining attempts at once and keep the first
                    # result that contains a final image.
                    logger.warning(
                        f"WebSocket blocked/reviewed, launching {retries_left} parallel retries"
                    )
                    tasks = [
                        asyncio.create_task(_collect_once())
                        for _ in range(retries_left)
                    ]
                    results = await asyncio.gather(*tasks, return_exceptions=True)
                    for result in results:
                        # gather() may hand back CancelledError, which derives
                        # from BaseException rather than Exception.
                        if isinstance(result, BaseException):
                            continue
                        has_final = any(
                            isinstance(item, dict)
                            and item.get("type") == "image"
                            and item.get("is_final")
                            for item in result
                        )
                        if has_final:
                            for item in result:
                                yield item
                            return
                    yield {
                        "type": "error",
                        "error_code": "blocked",
                        "error": "blocked_no_final_image",
                        "parallel_attempts": retries_left,
                    }
                    return
                if attempt + 1 < retries:
                    logger.warning(
                        f"WebSocket blocked/reviewed, retry {attempt + 1}/{retries}"
                    )
                    continue
                yield {
                    "type": "error",
                    "error_code": "blocked",
                    "error": "blocked_no_final_image",
                }
                return
            except Exception as e:
                logger.error(f"WebSocket stream failed: {e}")
                yield {
                    "type": "error",
                    "error_code": "ws_stream_failed",
                    "error": str(e),
                }
                return

    async def _stream_once(
        self,
        token: str,
        prompt: str,
        aspect_ratio: str,
        n: int,
        enable_nsfw: bool,
    ) -> AsyncGenerator[Dict[str, object], None]:
        """Run one WebSocket attempt and yield its events.

        Reads its timeouts and size thresholds from ``get_config`` under the
        ``image.*`` keys. Stops after ``n`` distinct final images, after an
        idle period once at least one final image arrived, when the overall
        ``image.timeout`` elapses, or when the socket closes.

        Yields:
            ``image`` events (every stage) and ``error`` events for connect
            failures, server-reported errors, socket closure and client errors.

        Raises:
            _BlockedError: A medium-stage image arrived but no final image
                followed within the grace / final timeout.
        """
        request_id = str(uuid.uuid4())
        headers = build_ws_headers(token=token)
        timeout = float(get_config("image.timeout"))
        stream_timeout = float(get_config("image.stream_timeout"))
        final_timeout = float(get_config("image.final_timeout"))
        blocked_grace_cfg = get_config("image.blocked_grace_seconds")
        blocked_grace = (
            float(blocked_grace_cfg) if blocked_grace_cfg is not None else 10.0
        )
        # Keep the grace period within [1s, final_timeout].
        blocked_grace = max(1.0, min(blocked_grace, final_timeout))
        final_min_bytes = int(get_config("image.final_min_bytes"))
        medium_min_bytes = int(get_config("image.medium_min_bytes"))

        try:
            conn = await self._client.connect(
                WS_IMAGINE_URL,
                headers=headers,
                timeout=timeout,
                ws_kwargs={
                    "heartbeat": 20,
                    "receive_timeout": stream_timeout,
                },
            )
        except Exception as e:
            status = getattr(e, "status", None)
            error_code = (
                "rate_limit_exceeded" if status == 429 else "connection_failed"
            )
            logger.error(f"WebSocket connect failed: {e}")
            yield {
                "type": "error",
                "error_code": error_code,
                "status": status,
                "error": str(e),
            }
            return

        try:
            async with conn as ws:
                message = self._build_request_message(
                    request_id, prompt, aspect_ratio, enable_nsfw
                )
                await ws.send_json(message)
                logger.info(f"WebSocket request sent: {prompt[:80]}...")

                final_ids: set[str] = set()
                completed = 0
                start_time = last_activity = time.monotonic()
                medium_received_time: Optional[float] = None

                while time.monotonic() - start_time < timeout:
                    try:
                        # Poll in short slices so the idle/blocked checks below
                        # run even when the server stays silent.
                        ws_msg = await asyncio.wait_for(ws.receive(), timeout=5.0)
                    except asyncio.TimeoutError:
                        now = time.monotonic()
                        if (
                            medium_received_time is not None
                            and completed == 0
                            and now - medium_received_time > blocked_grace
                        ):
                            logger.warning(
                                "Imagine stream blocked suspected: received medium preview but no valid final image "
                                f"within {blocked_grace:.1f}s (request_id={request_id})"
                            )
                            raise _BlockedError()
                        if completed > 0 and now - last_activity > 10:
                            logger.info(
                                f"WebSocket idle timeout, collected {completed} images"
                            )
                            break
                        continue

                    if ws_msg.type == aiohttp.WSMsgType.TEXT:
                        last_activity = time.monotonic()
                        try:
                            msg = orjson.loads(ws_msg.data)
                        except orjson.JSONDecodeError as e:
                            logger.warning(f"WebSocket message decode failed: {e}")
                            continue
                        if not isinstance(msg, dict):
                            # Valid JSON that is not an object has no "type";
                            # skip it like an undecodable frame.
                            logger.warning(
                                "WebSocket message ignored: expected JSON object, "
                                f"got {type(msg).__name__}"
                            )
                            continue

                        msg_type = msg.get("type")
                        if msg_type == "image":
                            info = self._classify_image(
                                msg, final_min_bytes, medium_min_bytes
                            )
                            if not info:
                                continue
                            image_id = info["image_id"]
                            if (
                                info["stage"] == "medium"
                                and medium_received_time is None
                            ):
                                medium_received_time = time.monotonic()
                            # Count each final image id only once.
                            if info["is_final"] and image_id not in final_ids:
                                final_ids.add(image_id)
                                completed += 1
                                logger.debug(
                                    "Final image received: "
                                    f"id={image_id}, decoded_size={info['blob_size']}, "
                                    f"encoded_size={info.get('encoded_blob_size', 0)}"
                                )
                            yield info
                        elif msg_type == "error":
                            logger.warning(
                                f"WebSocket error: {msg.get('err_code', '')} - {msg.get('err_msg', '')}"
                            )
                            yield {
                                "type": "error",
                                "error_code": msg.get("err_code", ""),
                                "error": msg.get("err_msg", ""),
                            }
                            return

                        if completed >= n:
                            logger.info(
                                f"WebSocket collected {completed} final images"
                            )
                            break

                        if (
                            medium_received_time is not None
                            and completed == 0
                            and time.monotonic() - medium_received_time
                            > final_timeout
                        ):
                            logger.warning(
                                "Imagine stream final-timeout suspected review/block: "
                                f"no final image reached threshold in {final_timeout:.1f}s "
                                f"(request_id={request_id})"
                            )
                            raise _BlockedError()
                    elif ws_msg.type in (
                        aiohttp.WSMsgType.CLOSED,
                        aiohttp.WSMsgType.ERROR,
                    ):
                        logger.warning(f"WebSocket closed/error: {ws_msg.type}")
                        yield {
                            "type": "error",
                            "error_code": "ws_closed",
                            "error": f"websocket closed: {ws_msg.type}",
                        }
                        break
        except aiohttp.ClientError as e:
            logger.error(f"WebSocket connection error: {e}")
            yield {"type": "error", "error_code": "connection_failed", "error": str(e)}
# Names exported by `from ... import *` for this module.
__all__ = ["ImagineWebSocketReverse", "WS_IMAGINE_URL"]