Spaces:
Runtime error
Runtime error
| """ | |
| Reverse interface: Imagine WebSocket image stream. | |
| """ | |
| import asyncio | |
| import base64 | |
| import binascii | |
| import orjson | |
| import re | |
| import time | |
| import uuid | |
| from typing import AsyncGenerator, Dict, Optional | |
| import aiohttp | |
| from app.core.config import get_config | |
| from app.core.logger import logger | |
| from app.services.reverse.utils.headers import build_ws_headers | |
| from app.services.reverse.utils.websocket import WebSocketClient | |
# WebSocket endpoint that ImagineWebSocketReverse._stream_once connects to.
WS_IMAGINE_URL = "wss://grok.com/ws/imagine/listen"
class _BlockedError(Exception):
    """Raised internally when a stream produced no final image in time.

    ``_stream_once`` raises it after a medium-stage image was seen but no
    final image followed within the configured grace/final timeout;
    ``stream`` catches it to decide whether to retry.
    """
class ImagineWebSocketReverse:
    """Imagine WebSocket reverse interface.

    Sends an image-generation request over the ``WS_IMAGINE_URL`` WebSocket
    and yields the image frames received, each classified as ``preview``,
    ``medium`` or ``final`` according to the decoded size of its blob.
    """

    # Keys copied verbatim from an incoming image message into the record
    # yielded to callers (only when the value is not None).
    _METADATA_KEYS = (
        "width",
        "height",
        "model_name",
        "percentage_complete",
        "job_id",
        "request_id",
        "order",
        "full_prompt",
    )

    def __init__(self) -> None:
        self._url_pattern = re.compile(r"/images/([a-f0-9-]+)\.(png|jpg|jpeg)")
        self._client = WebSocketClient()

    def _parse_image_url(self, url: str) -> tuple[Optional[str], Optional[str]]:
        """Extract ``(image_id, extension)`` from an image URL.

        Returns ``(None, None)`` when the URL does not contain an
        ``/images/<id>.<png|jpg|jpeg>`` segment. The extension is lowercased.
        """
        match = self._url_pattern.search(url or "")
        if not match:
            return None, None
        return match.group(1), match.group(2).lower()

    def _decoded_blob_size(self, blob: str) -> int:
        """Return the byte length of *blob* after base64 decoding.

        A ``data:...;base64,`` style prefix (anything before the first comma
        that mentions ``base64``) is stripped first. Empty or undecodable
        input yields ``0``.
        """
        if not blob:
            return 0
        header, sep, payload = blob.partition(",")
        data = payload if sep and "base64" in header else blob
        try:
            return len(base64.b64decode(data, validate=False))
        except (ValueError, binascii.Error):
            return 0

    def _is_final_image(self, decoded_size: int, final_min_bytes: int) -> bool:
        """Tell whether an image of *decoded_size* bytes counts as final."""
        # Final image must satisfy byte-size threshold to avoid tiny preview
        # images being treated as final outputs.
        return decoded_size >= final_min_bytes

    @staticmethod
    def _copy_metadata(message: Dict[str, object]) -> Dict[str, object]:
        """Return the known metadata keys of *message* whose value is set.

        Declared as a static method: it takes only the message, yet is
        invoked through ``self`` by ``_classify_image``.
        """
        metadata: Dict[str, object] = {}
        for key in ImagineWebSocketReverse._METADATA_KEYS:
            value = message.get(key)
            if value is not None:
                metadata[key] = value
        return metadata

    def _classify_image(
        self,
        message: Dict[str, object],
        final_min_bytes: int,
        medium_min_bytes: int,
    ) -> Optional[Dict[str, object]]:
        """Build the image record for an incoming ``image`` message.

        Returns ``None`` when the message lacks a ``url`` or a ``blob``.
        Otherwise the record carries the blob, its decoded and encoded
        sizes, the stage (``final`` when the decoded size reaches
        *final_min_bytes*, ``medium`` when it exceeds *medium_min_bytes*,
        else ``preview``) plus any metadata copied from the message.
        """
        url = str(message.get("url") or "")
        blob = str(message.get("blob") or "")
        if not url or not blob:
            return None

        image_id, ext = self._parse_image_url(url)
        # Fall back to a random id when the URL does not expose one.
        image_id = image_id or uuid.uuid4().hex

        encoded_size = len(blob)
        blob_size = self._decoded_blob_size(blob)
        is_final = self._is_final_image(blob_size, final_min_bytes)
        if is_final:
            stage = "final"
        elif blob_size > medium_min_bytes:
            stage = "medium"
        else:
            stage = "preview"

        result: Dict[str, object] = {
            "type": "image",
            "image_id": image_id,
            "ext": ext,
            "stage": stage,
            "blob": blob,
            "blob_size": blob_size,
            "encoded_blob_size": encoded_size,
            "url": url,
            "is_final": is_final,
        }
        result.update(self._copy_metadata(message))
        return result

    def _build_request_message(
        self,
        request_id: str,
        prompt: str,
        aspect_ratio: str,
        enable_nsfw: bool,
    ) -> Dict[str, object]:
        """Build the ``conversation.item.create`` payload for one request."""
        return {
            "type": "conversation.item.create",
            "timestamp": int(time.time() * 1000),
            "item": {
                "type": "message",
                "content": [
                    {
                        "requestId": request_id,
                        "text": prompt,
                        "type": "input_scroll",
                        "properties": {
                            "section_count": 0,
                            "is_kids_mode": False,
                            "enable_nsfw": enable_nsfw,
                            "skip_upsampler": False,
                            "enable_side_by_side": True,
                            "is_initial": False,
                            "last_prompt": prompt,
                            "aspect_ratio": aspect_ratio,
                        },
                    }
                ],
            },
        }

    async def stream(
        self,
        token: str,
        prompt: str,
        aspect_ratio: str = "2:3",
        n: int = 1,
        enable_nsfw: bool = True,
        max_retries: Optional[int] = None,
    ) -> AsyncGenerator[Dict[str, object], None]:
        """Generate images, retrying when the stream looks blocked.

        Each attempt is fully collected before its items are yielded, so a
        blocked attempt never leaks partial output. When an attempt raises
        ``_BlockedError`` and ``image.blocked_parallel_enabled`` is truthy,
        all remaining attempts are launched concurrently and the first
        result containing a final image is yielded; otherwise attempts are
        retried one at a time.

        Args:
            token: Credential passed to ``build_ws_headers``.
            prompt: Text prompt for the image request.
            aspect_ratio: Aspect ratio string sent with the request.
            n: Number of final images to wait for per attempt.
            enable_nsfw: Value of the ``enable_nsfw`` request property.
            max_retries: Total number of attempts; ``None`` or values below
                1 mean a single attempt.

        Yields:
            Image records from ``_stream_once`` or a single ``error`` record.
        """
        retries = max(1, max_retries if max_retries is not None else 1)
        parallel_enabled = bool(get_config("image.blocked_parallel_enabled", True))
        logger.info(
            f"Image generation: prompt='{prompt[:50]}...', n={n}, ratio={aspect_ratio}, nsfw={enable_nsfw}"
        )

        async def _collect_once() -> list[Dict[str, object]]:
            items: list[Dict[str, object]] = []
            async for item in self._stream_once(
                token, prompt, aspect_ratio, n, enable_nsfw
            ):
                items.append(item)
            return items

        for attempt in range(retries):
            try:
                items = await _collect_once()
                for item in items:
                    yield item
                return
            except _BlockedError:
                retries_left = retries - (attempt + 1)
                if retries_left > 0 and parallel_enabled:
                    logger.warning(
                        f"WebSocket blocked/reviewed, launching {retries_left} parallel retries"
                    )
                    tasks = [
                        asyncio.create_task(_collect_once())
                        for _ in range(retries_left)
                    ]
                    results = await asyncio.gather(*tasks, return_exceptions=True)
                    for result in results:
                        # gather(return_exceptions=True) may also hand back
                        # CancelledError, which derives from BaseException,
                        # not Exception; it must not be iterated as a list.
                        if isinstance(result, BaseException):
                            logger.debug(f"Parallel retry failed: {result!r}")
                            continue
                        has_final = any(
                            isinstance(item, dict)
                            and item.get("type") == "image"
                            and item.get("is_final")
                            for item in result
                        )
                        if has_final:
                            for item in result:
                                yield item
                            return
                    yield {
                        "type": "error",
                        "error_code": "blocked",
                        "error": "blocked_no_final_image",
                        "parallel_attempts": retries_left,
                    }
                    return
                if attempt + 1 < retries:
                    logger.warning(
                        f"WebSocket blocked/reviewed, retry {attempt + 1}/{retries}"
                    )
                    continue
                yield {
                    "type": "error",
                    "error_code": "blocked",
                    "error": "blocked_no_final_image",
                }
                return
            except Exception as e:
                logger.error(f"WebSocket stream failed: {e}")
                yield {
                    "type": "error",
                    "error_code": "ws_stream_failed",
                    "error": str(e),
                }
                return

    async def _stream_once(
        self,
        token: str,
        prompt: str,
        aspect_ratio: str,
        n: int,
        enable_nsfw: bool,
    ) -> AsyncGenerator[Dict[str, object], None]:
        """Run a single WebSocket request and yield what it produces.

        Yields every classified image record (preview, medium and final)
        and stops once *n* distinct final images were seen, the connection
        closes, the server reports an error, or ``image.timeout`` elapses.

        Raises:
            _BlockedError: A medium-stage image arrived but no final image
                followed within the blocked-grace or final timeout.
        """
        request_id = str(uuid.uuid4())
        headers = build_ws_headers(token=token)
        timeout = float(get_config("image.timeout"))
        stream_timeout = float(get_config("image.stream_timeout"))
        final_timeout = float(get_config("image.final_timeout"))
        blocked_grace_cfg = get_config("image.blocked_grace_seconds")
        blocked_grace = (
            float(blocked_grace_cfg) if blocked_grace_cfg is not None else 10.0
        )
        # Keep the grace period within [1s, final_timeout].
        blocked_grace = max(1.0, min(blocked_grace, final_timeout))
        final_min_bytes = int(get_config("image.final_min_bytes"))
        medium_min_bytes = int(get_config("image.medium_min_bytes"))

        try:
            conn = await self._client.connect(
                WS_IMAGINE_URL,
                headers=headers,
                timeout=timeout,
                ws_kwargs={
                    "heartbeat": 20,
                    "receive_timeout": stream_timeout,
                },
            )
        except Exception as e:
            status = getattr(e, "status", None)
            error_code = (
                "rate_limit_exceeded" if status == 429 else "connection_failed"
            )
            logger.error(f"WebSocket connect failed: {e}")
            yield {
                "type": "error",
                "error_code": error_code,
                "status": status,
                "error": str(e),
            }
            return

        try:
            async with conn as ws:
                message = self._build_request_message(
                    request_id, prompt, aspect_ratio, enable_nsfw
                )
                await ws.send_json(message)
                logger.info(f"WebSocket request sent: {prompt[:80]}...")

                final_ids: set[str] = set()
                completed = 0
                start_time = last_activity = time.monotonic()
                medium_received_time: Optional[float] = None

                while time.monotonic() - start_time < timeout:
                    try:
                        # Poll in 5s slices so the blocked/idle checks below
                        # run even when the server stays silent.
                        ws_msg = await asyncio.wait_for(ws.receive(), timeout=5.0)
                    except asyncio.TimeoutError:
                        now = time.monotonic()
                        if (
                            medium_received_time is not None
                            and completed == 0
                            and now - medium_received_time > blocked_grace
                        ):
                            logger.warning(
                                "Imagine stream blocked suspected: received medium preview but no valid final image "
                                f"within {blocked_grace:.1f}s (request_id={request_id})"
                            )
                            raise _BlockedError()
                        if completed > 0 and now - last_activity > 10:
                            logger.info(
                                f"WebSocket idle timeout, collected {completed} images"
                            )
                            break
                        continue

                    if ws_msg.type == aiohttp.WSMsgType.TEXT:
                        last_activity = time.monotonic()
                        try:
                            msg = orjson.loads(ws_msg.data)
                        except orjson.JSONDecodeError as e:
                            logger.warning(f"WebSocket message decode failed: {e}")
                            continue
                        # Valid JSON that is not an object (list, string,
                        # number) has no ``.get``; skip it instead of crashing.
                        if not isinstance(msg, dict):
                            logger.warning(
                                "WebSocket message ignored: payload is not a JSON object"
                            )
                            continue

                        msg_type = msg.get("type")
                        if msg_type == "image":
                            info = self._classify_image(
                                msg, final_min_bytes, medium_min_bytes
                            )
                            if not info:
                                continue
                            image_id = info["image_id"]
                            if (
                                info["stage"] == "medium"
                                and medium_received_time is None
                            ):
                                medium_received_time = time.monotonic()
                            if info["is_final"] and image_id not in final_ids:
                                final_ids.add(image_id)
                                completed += 1
                                logger.debug(
                                    "Final image received: "
                                    f"id={image_id}, decoded_size={info['blob_size']}, "
                                    f"encoded_size={info.get('encoded_blob_size', 0)}"
                                )
                            # Every stage is forwarded; callers filter on
                            # ``is_final`` / ``stage`` themselves.
                            yield info
                        elif msg_type == "error":
                            logger.warning(
                                f"WebSocket error: {msg.get('err_code', '')} - {msg.get('err_msg', '')}"
                            )
                            yield {
                                "type": "error",
                                "error_code": msg.get("err_code", ""),
                                "error": msg.get("err_msg", ""),
                            }
                            return

                        if completed >= n:
                            logger.info(
                                f"WebSocket collected {completed} final images"
                            )
                            break

                        if (
                            medium_received_time is not None
                            and completed == 0
                            and time.monotonic() - medium_received_time
                            > final_timeout
                        ):
                            logger.warning(
                                "Imagine stream final-timeout suspected review/block: "
                                f"no final image reached threshold in {final_timeout:.1f}s "
                                f"(request_id={request_id})"
                            )
                            raise _BlockedError()
                    elif ws_msg.type in (
                        aiohttp.WSMsgType.CLOSED,
                        aiohttp.WSMsgType.ERROR,
                    ):
                        logger.warning(f"WebSocket closed/error: {ws_msg.type}")
                        yield {
                            "type": "error",
                            "error_code": "ws_closed",
                            "error": f"websocket closed: {ws_msg.type}",
                        }
                        break
        except aiohttp.ClientError as e:
            logger.error(f"WebSocket connection error: {e}")
            yield {"type": "error", "error_code": "connection_failed", "error": str(e)}
# Names exported by ``from <module> import *``.
__all__ = ["ImagineWebSocketReverse", "WS_IMAGINE_URL"]