# grok2api-hf: app/api/v1/function/imagine.py
# Codex — "Add root Dockerfile for HF Space build" (commit 6bff6a1)
import asyncio
import time
import uuid
from typing import Optional, List, Dict, Any
import orjson
from fastapi import APIRouter, Depends, HTTPException, Query, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from app.core.auth import (
verify_function_key,
get_function_api_key,
is_function_enabled,
)
from app.core.config import get_config
from app.core.logger import logger
from app.api.v1.image import resolve_aspect_ratio
from app.services.grok.services.image import ImageGenerationService
from app.services.grok.services.model import ModelService
from app.services.token.manager import get_token_manager
# Router carrying the /imagine/* function endpoints defined below.
router = APIRouter()

# Lifetime, in seconds, of a session created by POST /imagine/start.
IMAGINE_SESSION_TTL = 10 * 60

# In-process session store: task_id -> {"prompt", "aspect_ratio", "nsfw", "created_at"}.
# Every helper in this module touches it only while holding _IMAGINE_SESSIONS_LOCK.
_IMAGINE_SESSIONS: Dict[str, Dict[str, Any]] = {}
_IMAGINE_SESSIONS_LOCK = asyncio.Lock()
async def _clean_sessions(now: float) -> None:
    """Evict every session whose age at ``now`` exceeds IMAGINE_SESSION_TTL.

    Mutates the module-level ``_IMAGINE_SESSIONS`` in place. Callers in this
    module invoke it while holding ``_IMAGINE_SESSIONS_LOCK``.
    """
    # Collect first, delete second: popping while iterating a dict raises.
    stale_keys = []
    for session_key, session in _IMAGINE_SESSIONS.items():
        age = now - float(session.get("created_at") or 0)
        if age > IMAGINE_SESSION_TTL:
            stale_keys.append(session_key)
    for session_key in stale_keys:
        _IMAGINE_SESSIONS.pop(session_key, None)
def _parse_sse_chunk(chunk: str) -> Optional[Dict[str, Any]]:
    """Decode one Server-Sent-Events frame into its JSON payload.

    All ``data:`` lines are joined with newlines; the last ``event:`` name
    seen is remembered. Returns ``None`` for empty input, frames with no
    data, the ``[DONE]`` sentinel, and data that is not valid JSON.
    Otherwise returns the decoded JSON value. It is normally a dict, but
    non-object JSON is returned unchanged. When it is a dict without a
    ``type`` key, the event name (if any) is stored under ``type``.
    """
    if not chunk:
        return None

    event_name = None
    data_parts: List[str] = []
    for line in (raw.strip() for raw in str(chunk).splitlines()):
        if line.startswith("event:"):
            event_name = line[len("event:"):].strip()
        elif line.startswith("data:"):
            data_parts.append(line[len("data:"):].strip())

    if not data_parts:
        return None
    body = "\n".join(data_parts)
    if body == "[DONE]":
        return None

    try:
        decoded = orjson.loads(body)
    except orjson.JSONDecodeError:
        return None

    if event_name and isinstance(decoded, dict) and "type" not in decoded:
        decoded["type"] = event_name
    return decoded
async def _new_session(prompt: str, aspect_ratio: str, nsfw: Optional[bool]) -> str:
    """Register a new imagine session and return its random hex task id.

    Expired sessions are purged in the same critical section, so the store
    only grows while sessions are still inside their TTL.
    """
    created = time.time()
    session_id = uuid.uuid4().hex
    record = {
        "prompt": prompt,
        "aspect_ratio": aspect_ratio,
        "nsfw": nsfw,
        "created_at": created,
    }
    async with _IMAGINE_SESSIONS_LOCK:
        await _clean_sessions(created)
        _IMAGINE_SESSIONS[session_id] = record
    return session_id
async def _get_session(task_id: str) -> Optional[dict]:
    """Return a shallow copy of a live session, or None if absent or expired.

    A blank ``task_id`` short-circuits to None without taking the lock.
    """
    if not task_id:
        return None
    checked_at = time.time()
    async with _IMAGINE_SESSIONS_LOCK:
        await _clean_sessions(checked_at)
        entry = _IMAGINE_SESSIONS.get(task_id)
        if entry:
            # Defensive per-entry TTL check on top of the bulk purge above.
            if checked_at - float(entry.get("created_at") or 0) > IMAGINE_SESSION_TTL:
                _IMAGINE_SESSIONS.pop(task_id, None)
            else:
                # Copy so callers cannot mutate the shared store.
                return dict(entry)
    return None
async def _drop_session(task_id: str) -> None:
    """Forget one session; a blank or unknown task id is a no-op."""
    if task_id:
        async with _IMAGINE_SESSIONS_LOCK:
            _IMAGINE_SESSIONS.pop(task_id, None)
async def _drop_sessions(task_ids: List[str]) -> int:
    """Forget several sessions at once.

    Blank ids are skipped. Returns how many of the given ids were actually
    present (and therefore removed).
    """
    if not task_ids:
        return 0
    count = 0
    async with _IMAGINE_SESSIONS_LOCK:
        for tid in filter(None, task_ids):
            if tid in _IMAGINE_SESSIONS:
                del _IMAGINE_SESSIONS[tid]
                count += 1
    return count
@router.websocket("/imagine/ws")
async def function_imagine_ws(websocket: WebSocket):
    """Imagine image waterfall over a WebSocket.

    Authentication happens before the handshake is accepted. A ``task_id``
    query parameter naming a live session (see ``POST /imagine/start``) is
    trusted as-is. Otherwise the ``function_key`` query parameter must equal
    the configured function key; when no key is configured, function access
    must be enabled. Rejected connections are closed with code 1008.

    The client drives generation with JSON objects:

    * ``{"type": "start", "prompt": str, "aspect_ratio": str, "nsfw": bool}``
      stops any active run and starts a new generation loop;
    * ``{"type": "stop"}`` stops the active run.

    The server replies with JSON objects: ``status`` (``running`` /
    ``stopped``), image payloads, and ``error`` objects carrying a ``code``.
    On disconnect the active run is stopped and the bound session (if any)
    is dropped.
    """
    session_id = None
    task_id = websocket.query_params.get("task_id")
    if task_id and await _get_session(task_id):
        session_id = task_id

    if session_id is None:
        function_key = get_function_api_key()
        function_enabled = is_function_enabled()
        if not function_key:
            ok = function_enabled
        else:
            ok = websocket.query_params.get("function_key") == function_key
        if not ok:
            await websocket.close(code=1008)
            return

    await websocket.accept()
    # Tells the generation loop to exit after its current iteration.
    stop_event = asyncio.Event()
    run_task: Optional[asyncio.Task] = None

    async def _send(payload: dict) -> bool:
        """Send *payload* as JSON text; return False instead of raising on failure."""
        try:
            await websocket.send_text(orjson.dumps(payload).decode())
            return True
        except Exception:
            return False

    async def _stop_run() -> None:
        """Cancel the active generation task (if any) and wait until it has finished."""
        nonlocal run_task
        stop_event.set()
        task = run_task
        if task is not None:
            if not task.done():
                task.cancel()
            # asyncio.wait() never re-raises the task's own outcome. A run that
            # ends with CancelledError (a BaseException, e.g. cancelled while
            # awaiting a send outside its try block) therefore cannot escape
            # into, and kill, this handler. Cancellation of the handler itself
            # still propagates from the wait.
            await asyncio.wait({task})
            if not task.cancelled() and task.exception() is not None:
                logger.debug(f"Imagine run ended with error: {task.exception()}")
        run_task = None
        stop_event.clear()

    async def _run(prompt: str, aspect_ratio: str, nsfw: Optional[bool]):
        """Generate image batches in a loop until stopped or cancelled."""
        model_id = "grok-imagine-1.0"
        model_info = ModelService.get(model_id)
        if not model_info or not model_info.is_image:
            await _send(
                {
                    "type": "error",
                    "message": "Image model is not available.",
                    "code": "model_not_supported",
                }
            )
            return

        token_mgr = await get_token_manager()
        run_id = uuid.uuid4().hex
        await _send(
            {
                "type": "status",
                "status": "running",
                "prompt": prompt,
                "aspect_ratio": aspect_ratio,
                "run_id": run_id,
            }
        )

        while not stop_event.is_set():
            try:
                await token_mgr.reload_if_stale()
                # Take the first token available from the model's candidate pools.
                token = None
                for pool_name in ModelService.pool_candidates_for_model(
                    model_info.model_id
                ):
                    token = token_mgr.get_token(pool_name)
                    if token:
                        break
                if not token:
                    await _send(
                        {
                            "type": "error",
                            "message": "No available tokens. Please try again later.",
                            "code": "rate_limit_exceeded",
                        }
                    )
                    await asyncio.sleep(2)
                    continue

                result = await ImageGenerationService().generate(
                    token_mgr=token_mgr,
                    token=token,
                    model_info=model_info,
                    prompt=prompt,
                    n=6,
                    response_format="b64_json",
                    size="1024x1024",
                    aspect_ratio=aspect_ratio,
                    stream=True,
                    enable_nsfw=nsfw,
                )
                if result.stream:
                    # Forward each upstream SSE event, tagged with this run's id.
                    async for chunk in result.data:
                        payload = _parse_sse_chunk(chunk)
                        if not payload:
                            continue
                        if isinstance(payload, dict):
                            payload.setdefault("run_id", run_id)
                        await _send(payload)
                else:
                    images = [img for img in result.data if img and img != "error"]
                    if images:
                        for img_b64 in images:
                            await _send(
                                {
                                    "type": "image",
                                    "b64_json": img_b64,
                                    "created_at": int(time.time() * 1000),
                                    "aspect_ratio": aspect_ratio,
                                    "run_id": run_id,
                                }
                            )
                    else:
                        await _send(
                            {
                                "type": "error",
                                "message": "Image generation returned empty data.",
                                "code": "empty_image",
                            }
                        )
            except asyncio.CancelledError:
                # Cancellation is this loop's stop signal; fall through so the
                # client still receives the "stopped" status below.
                break
            except Exception as e:
                logger.warning(f"Imagine stream error: {e}")
                await _send(
                    {
                        "type": "error",
                        "message": str(e),
                        "code": "internal_error",
                    }
                )
                await asyncio.sleep(1.5)

        await _send({"type": "status", "status": "stopped", "run_id": run_id})

    try:
        while True:
            try:
                raw = await websocket.receive_text()
            except (RuntimeError, WebSocketDisconnect):
                break

            try:
                payload = orjson.loads(raw)
            except Exception:
                payload = None
            # Valid JSON that is not an object (list, string, null...) would
            # crash on .get(); treat it as an invalid message instead.
            if not isinstance(payload, dict):
                await _send(
                    {
                        "type": "error",
                        "message": "Invalid message format.",
                        "code": "invalid_payload",
                    }
                )
                continue

            action = payload.get("type")
            if action == "start":
                prompt = str(payload.get("prompt") or "").strip()
                if not prompt:
                    await _send(
                        {
                            "type": "error",
                            "message": "Prompt cannot be empty.",
                            "code": "invalid_prompt",
                        }
                    )
                    continue
                aspect_ratio = resolve_aspect_ratio(
                    str(payload.get("aspect_ratio") or "2:3").strip() or "2:3"
                )
                nsfw = payload.get("nsfw")
                if isinstance(nsfw, str):
                    # Same parsing as the SSE endpoint; bool("false") would be True.
                    nsfw = nsfw.strip().lower() in ("1", "true", "yes", "on")
                elif nsfw is not None:
                    nsfw = bool(nsfw)
                await _stop_run()
                run_task = asyncio.create_task(_run(prompt, aspect_ratio, nsfw))
            elif action == "stop":
                await _stop_run()
            else:
                await _send(
                    {
                        "type": "error",
                        "message": "Unknown action.",
                        "code": "invalid_action",
                    }
                )
    except WebSocketDisconnect:
        logger.debug("WebSocket disconnected by client")
    except Exception as e:
        logger.warning(f"WebSocket error: {e}")
    finally:
        await _stop_run()
        try:
            from starlette.websockets import WebSocketState

            if websocket.client_state == WebSocketState.CONNECTED:
                await websocket.close(code=1000, reason="Server closing connection")
        except Exception as e:
            logger.debug(f"WebSocket close ignored: {e}")
        if session_id:
            await _drop_session(session_id)
@router.get("/imagine/sse")
async def function_imagine_sse(
    request: Request,
    task_id: str = Query(""),
    prompt: str = Query(""),
    aspect_ratio: str = Query("2:3"),
):
    """Imagine image waterfall (SSE fallback).

    The caller chooses one of two modes:

    * Session mode: ``task_id`` names a live session created by
      ``POST /imagine/start``, and its prompt, aspect ratio and NSFW flag are
      used. Responds 404 if the session is missing or expired.
    * Direct mode: the caller authenticates with the ``function_key`` query
      parameter and passes ``prompt`` / ``aspect_ratio`` / ``nsfw`` directly.
      Responds 401 on bad credentials and 400 on an empty prompt.

    The stream repeats generation rounds until the client disconnects or the
    bound session is dropped (e.g. via ``POST /imagine/stop``), then emits a
    ``stopped`` status frame. The bound session is dropped when the stream
    ends, including on cancellation.
    """
    session = None
    if task_id:
        session = await _get_session(task_id)
        if not session:
            raise HTTPException(status_code=404, detail="Task not found")
    else:
        function_key = get_function_api_key()
        function_enabled = is_function_enabled()
        if not function_key:
            if not function_enabled:
                raise HTTPException(status_code=401, detail="Function access is disabled")
        else:
            key = request.query_params.get("function_key")
            if key != function_key:
                raise HTTPException(status_code=401, detail="Invalid authentication token")

    if session:
        prompt = str(session.get("prompt") or "").strip()
        ratio = str(session.get("aspect_ratio") or "2:3").strip() or "2:3"
        nsfw = session.get("nsfw")
    else:
        prompt = (prompt or "").strip()
        if not prompt:
            raise HTTPException(status_code=400, detail="Prompt cannot be empty")
        ratio = resolve_aspect_ratio(str(aspect_ratio or "2:3").strip() or "2:3")
        nsfw = request.query_params.get("nsfw")
        if nsfw is not None:
            nsfw = str(nsfw).lower() in ("1", "true", "yes", "on")

    def _frame(payload: Any) -> str:
        # One SSE event: a single "data:" line of JSON, terminated by a blank line.
        return f"data: {orjson.dumps(payload).decode()}\n\n"

    async def event_stream():
        try:
            model_id = "grok-imagine-1.0"
            model_info = ModelService.get(model_id)
            if not model_info or not model_info.is_image:
                yield _frame(
                    {
                        "type": "error",
                        "message": "Image model is not available.",
                        "code": "model_not_supported",
                    }
                )
                return

            token_mgr = await get_token_manager()
            sequence = 0
            run_id = uuid.uuid4().hex
            yield _frame(
                {
                    "type": "status",
                    "status": "running",
                    "prompt": prompt,
                    "aspect_ratio": ratio,
                    "run_id": run_id,
                }
            )

            while True:
                if await request.is_disconnected():
                    break
                # In session mode, POST /imagine/stop (or TTL expiry) ends the stream.
                if task_id and not await _get_session(task_id):
                    break
                try:
                    await token_mgr.reload_if_stale()
                    # Take the first token available from the model's candidate pools.
                    token = None
                    for pool_name in ModelService.pool_candidates_for_model(
                        model_info.model_id
                    ):
                        token = token_mgr.get_token(pool_name)
                        if token:
                            break
                    if not token:
                        yield _frame(
                            {
                                "type": "error",
                                "message": "No available tokens. Please try again later.",
                                "code": "rate_limit_exceeded",
                            }
                        )
                        await asyncio.sleep(2)
                        continue

                    result = await ImageGenerationService().generate(
                        token_mgr=token_mgr,
                        token=token,
                        model_info=model_info,
                        prompt=prompt,
                        n=6,
                        response_format="b64_json",
                        size="1024x1024",
                        aspect_ratio=ratio,
                        stream=True,
                        enable_nsfw=nsfw,
                    )
                    if result.stream:
                        # Re-emit each upstream SSE event, tagged with this run's id.
                        async for chunk in result.data:
                            payload = _parse_sse_chunk(chunk)
                            if not payload:
                                continue
                            if isinstance(payload, dict):
                                payload.setdefault("run_id", run_id)
                            yield _frame(payload)
                    else:
                        images = [img for img in result.data if img and img != "error"]
                        if images:
                            for img_b64 in images:
                                sequence += 1
                                yield _frame(
                                    {
                                        "type": "image",
                                        "b64_json": img_b64,
                                        "sequence": sequence,
                                        "created_at": int(time.time() * 1000),
                                        "aspect_ratio": ratio,
                                        "run_id": run_id,
                                    }
                                )
                        else:
                            yield _frame(
                                {
                                    "type": "error",
                                    "message": "Image generation returned empty data.",
                                    "code": "empty_image",
                                }
                            )
                except asyncio.CancelledError:
                    # The response task is being cancelled (client gone or
                    # server shutting down). Propagate instead of yielding
                    # more frames; the finally below still drops the session.
                    raise
                except Exception as e:
                    logger.warning(f"Imagine SSE error: {e}")
                    yield _frame(
                        {
                            "type": "error",
                            "message": str(e),
                            "code": "internal_error",
                        }
                    )
                    await asyncio.sleep(1.5)

            yield _frame({"type": "status", "status": "stopped", "run_id": run_id})
        finally:
            if task_id:
                await _drop_session(task_id)

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )
@router.get("/imagine/config")
async def function_imagine_config():
    """Return the image-related config values exposed to imagine clients.

    Missing or falsy byte thresholds are reported as 0; ``nsfw`` is coerced
    to a bool.
    """

    def _as_int(name: str) -> int:
        return int(get_config(name) or 0)

    return {
        "final_min_bytes": _as_int("image.final_min_bytes"),
        "medium_min_bytes": _as_int("image.medium_min_bytes"),
        "nsfw": bool(get_config("image.nsfw")),
    }
class ImagineStartRequest(BaseModel):
    # Request body for POST /imagine/start.
    # Free-text prompt; the endpoint strips it and rejects it (400) if empty.
    prompt: str
    # Passed through resolve_aspect_ratio(); None or blank falls back to "2:3".
    aspect_ratio: Optional[str] = "2:3"
    # Stored in the session and later forwarded as enable_nsfw. How the image
    # service treats None is not visible here (presumably the configured
    # default; confirm).
    nsfw: Optional[bool] = None
@router.post("/imagine/start", dependencies=[Depends(verify_function_key)])
async def function_imagine_start(data: ImagineStartRequest):
    """Create an imagine session for a later /imagine/ws or /imagine/sse connection.

    Returns the new ``task_id`` together with the resolved aspect ratio.
    Raises HTTP 400 when the prompt is empty after stripping.
    """
    cleaned_prompt = (data.prompt or "").strip()
    if not cleaned_prompt:
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    requested_ratio = str(data.aspect_ratio or "2:3").strip() or "2:3"
    resolved_ratio = resolve_aspect_ratio(requested_ratio)
    new_task_id = await _new_session(cleaned_prompt, resolved_ratio, data.nsfw)
    return {"task_id": new_task_id, "aspect_ratio": resolved_ratio}
class ImagineStopRequest(BaseModel):
    # Request body for POST /imagine/stop: task ids (as returned by
    # POST /imagine/start) whose sessions should be dropped.
    task_ids: List[str]
@router.post("/imagine/stop", dependencies=[Depends(verify_function_key)])
async def function_imagine_stop(data: ImagineStopRequest):
    """Drop the given sessions and report how many existed.

    SSE streams bound to a dropped session end on their next loop iteration.
    """
    ids = data.task_ids or []
    dropped = await _drop_sessions(ids)
    return {"status": "success", "removed": dropped}