|
|
""" |
|
|
StackNet API Client |
|
|
|
|
|
Handles all communication with the StackNet task network.
|
|
SSE parsing and progress tracking are handled internally. |
|
|
""" |
|
|
|
|
|
import json |
|
|
import tempfile |
|
|
import os |
|
|
from typing import AsyncGenerator, Optional, Any, Callable |
|
|
from dataclasses import dataclass |
|
|
from enum import Enum |
|
|
|
|
|
import httpx |
|
|
|
|
|
from ..config import config |
|
|
|
|
|
|
|
|
class MediaAction(str, Enum):
    """
    Media orchestration actions understood by the task API.

    Because members also subclass ``str``, each one compares equal to its
    wire value (e.g. ``MediaAction.EXTRACT_STEMS == "extract_stems"``).
    The ``.value`` is sent verbatim as the ``action`` field of a media task.
    """

    GENERATE_MUSIC = "generate_music"
    CREATE_COVER = "create_cover"
    EXTRACT_STEMS = "extract_stems"
    ANALYZE_VISUAL = "analyze_visual"
    DESCRIBE_VIDEO = "describe_video"
    CREATE_COMPOSITE = "create_composite"
|
|
|
|
|
|
|
|
@dataclass()
class TaskProgress:
    """A single progress notification emitted while a task is running."""

    # Progress value; presumably a 0.0-1.0 fraction (not enforced here).
    progress: float
    # Status label reported for the task.
    status: str
    # Human-readable description of the current step.
    message: str
|
|
|
|
|
|
|
|
@dataclass()
class TaskResult:
    """Outcome of a finished task: output data on success, an error message on failure."""

    # Whether the task produced a usable result.
    success: bool
    # Task output payload (the client in this module passes {} on failure).
    data: dict
    # Failure description; left as None for successful results.
    error: Optional[str] = None
|
|
|
|
|
|
|
|
class StackNetClient:
    """
    Client for StackNet task network API.

    All SSE parsing and polling is handled internally.
    Consumers receive clean progress updates and final results.

    The submit_* methods always return a ``TaskResult``: non-200 responses,
    httpx request errors and ``error`` stream events are reported through
    ``success=False`` / ``error`` instead of being raised. Files fetched with
    ``download_file`` are stored in a per-client temporary directory that
    ``cleanup`` removes.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        timeout: float = 300.0
    ):
        """
        Args:
            base_url: API root URL. Falls back to ``config.stacknet_url``.
            api_key: Optional API key, sent as ``Authorization: Bearer <key>``.
                A value that already starts with ``"Bearer "`` is sent as-is.
            timeout: httpx timeout in seconds for task submission requests.
        """
        self.base_url = base_url or config.stacknet_url
        self.api_key = api_key
        self.timeout = timeout
        # Created eagerly; holds files written by download_file() until cleanup().
        self._temp_dir = tempfile.mkdtemp(prefix="stacknet_")

    async def submit_tool_task(
        self,
        tool_name: str,
        parameters: dict,
        server_name: str = "geoff",
        on_progress: Optional[Callable[[float, str], None]] = None
    ) -> TaskResult:
        """
        Submit an MCP tool task and wait for completion.

        Args:
            tool_name: The tool to invoke (e.g., generate_image_5)
            parameters: Tool parameters
            server_name: MCP server name (default: geoff)
            on_progress: Callback for progress updates (progress: 0-1, message: str)

        Returns:
            TaskResult with success status and output data
        """
        payload = {
            "type": "mcp-tool",
            "serverName": server_name,
            "toolName": tool_name,
            "stream": True,
            "parameters": parameters,
        }
        return await self._submit_task(payload, on_progress)

    async def submit_media_task(
        self,
        action: MediaAction,
        prompt: Optional[str] = None,
        media_url: Optional[str] = None,
        audio_url: Optional[str] = None,
        video_url: Optional[str] = None,
        options: Optional[dict] = None,
        on_progress: Optional[Callable[[float, str], None]] = None
    ) -> TaskResult:
        """
        Submit a media orchestration task and wait for completion.

        Args:
            action: The media action to perform
            prompt: Text prompt for generation
            media_url: URL for image input
            audio_url: URL for audio input
            video_url: URL for video input
            options: Additional options (tags, title, etc.)
            on_progress: Callback for progress updates (progress: 0-1, message: str)

        Returns:
            TaskResult with success status and output data
        """
        payload = {
            "type": config.TASK_TYPE_MEDIA,
            "action": action.value,
            "stream": True,
        }

        optional_fields = {
            "prompt": prompt,
            "mediaUrl": media_url,
            "audioUrl": audio_url,
            "videoUrl": video_url,
            "options": options,
        }
        # Falsy values (None, "", {}) are left out of the payload entirely.
        payload.update({key: value for key, value in optional_fields.items() if value})

        return await self._submit_task(payload, on_progress)

    def _build_headers(self) -> dict:
        """Return JSON request headers, plus a Bearer Authorization header when an API key is set."""
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            token = self.api_key
            if not token.startswith("Bearer "):
                token = f"Bearer {token}"
            headers["Authorization"] = token
        return headers

    async def _submit_task(
        self,
        payload: dict,
        on_progress: Optional[Callable[[float, str], None]] = None
    ) -> TaskResult:
        """
        POST ``payload`` to ``{base_url}/tasks`` and consume the SSE response.

        Shared by every submit_* method so request and error handling live in
        one place.

        Returns:
            The TaskResult parsed from the stream, or a failed TaskResult for a
            non-200 status, a timeout, or any other httpx request error.
        """
        # Strip a trailing slash so "https://host/" doesn't produce "//tasks".
        url = f"{str(self.base_url).rstrip('/')}/tasks"

        async with httpx.AsyncClient(timeout=self.timeout) as client:
            try:
                async with client.stream(
                    "POST",
                    url,
                    json=payload,
                    headers=self._build_headers()
                ) as response:
                    if response.status_code != 200:
                        error_body = await response.aread()
                        # Decode leniently: a non-UTF-8 error page must not crash reporting.
                        snippet = error_body.decode("utf-8", errors="replace")[:200]
                        return TaskResult(
                            success=False,
                            data={},
                            error=f"API request failed ({response.status_code}): {snippet}"
                        )

                    return await self._process_sse_stream(response, on_progress)

            # httpx.TimeoutException subclasses RequestError, so it must come first.
            except httpx.TimeoutException:
                return TaskResult(
                    success=False,
                    data={},
                    error="Request timed out. The operation took too long."
                )
            except httpx.RequestError as e:
                return TaskResult(
                    success=False,
                    data={},
                    error=f"Network error: {str(e)}"
                )

    @staticmethod
    async def _iter_sse_lines(response: httpx.Response) -> AsyncGenerator[str, None]:
        """Yield text lines from a streamed response, including a final unterminated line."""
        pending = ""
        async for chunk in response.aiter_text():
            pending += chunk
            # The last piece may be an incomplete line; keep it for the next chunk.
            *complete, pending = pending.split("\n")
            for line in complete:
                yield line
        if pending:
            yield pending

    @staticmethod
    def _parse_sse_line(line: str) -> Optional[dict]:
        """
        Decode one SSE line into a JSON object, or return None if it carries no event.

        Non-``data:`` lines (comments, ``event:``/``id:`` fields, blank lines),
        the ``[DONE]`` sentinel, malformed JSON and non-object JSON are skipped.
        Both ``data: {...}`` and ``data:{...}`` are accepted.
        """
        if not line.startswith("data:"):
            return None
        raw_data = line[5:].strip()
        if not raw_data or raw_data == "[DONE]":
            return None
        try:
            event = json.loads(raw_data)
        except json.JSONDecodeError:
            return None
        return event if isinstance(event, dict) else None

    async def _process_sse_stream(
        self,
        response: httpx.Response,
        on_progress: Optional[Callable[[float, str], None]] = None
    ) -> TaskResult:
        """
        Process SSE stream and extract final result.

        Event objects are dispatched on their ``"type"`` field:
            progress -> ``on_progress(fraction, message)`` if a callback was given
            result   -> ``data["output"]`` (or ``data`` itself) becomes the result
            error    -> ``data["message"]`` becomes the failure message
        Other types are ignored. An error event takes precedence over a result.
        """
        final_result: Optional[Any] = None
        error_message: Optional[str] = None

        async for line in self._iter_sse_lines(response):
            event = self._parse_sse_line(line)
            if event is None:
                continue

            event_type = event.get("type", "")
            # The payload normally sits under "data"; fall back to the whole event.
            event_data = event.get("data", event)
            fields = event_data if isinstance(event_data, dict) else {}

            if event_type == "progress":
                if on_progress:
                    message = fields.get("message") or "Processing..."
                    on_progress(self._calculate_progress(fields), message)

            elif event_type == "result":
                if isinstance(event_data, dict):
                    final_result = event_data.get("output", event_data)
                else:
                    final_result = event_data

            elif event_type == "error":
                if isinstance(event_data, dict):
                    raw_message = event_data.get("message")
                else:
                    raw_message = event_data
                # An error event always fails the task, even with a blank message.
                error_message = str(raw_message) if raw_message else "Unknown error occurred"

        if error_message is not None:
            return TaskResult(success=False, data={}, error=error_message)

        # "is not None" so an empty-but-present output still counts as a result.
        if final_result is not None:
            return TaskResult(success=True, data=final_result)

        return TaskResult(
            success=False,
            data={},
            error="No result received from the API"
        )

    def _calculate_progress(self, data: dict) -> float:
        """
        Calculate normalized progress (0.0 to 1.0).

        Statuses: "submitted" -> 0.1, "processing" -> 0.5, "completed" -> 1.0,
        "polling" -> 0.2-0.8 scaled by attempt / maxAttempts (defaults 1 / 30).
        Empty data, an unknown status, or unusable polling counters yield 0.5.
        """
        if not data:
            return 0.5

        status = data.get("status", "")

        if status == "completed":
            return 1.0
        if status == "polling":
            attempt = data.get("attempt", 1)
            max_attempts = data.get("maxAttempts", 30)
            try:
                fraction = float(attempt) / float(max_attempts)
            except (TypeError, ValueError, ZeroDivisionError):
                # Missing/zero/non-numeric counters: report the neutral midpoint
                # instead of aborting the whole stream.
                return 0.5
            # Clamp so attempt > maxAttempts (or negative counters) stay in the polling band.
            return 0.2 + min(max(fraction, 0.0), 1.0) * 0.6
        if status == "processing":
            return 0.5
        if status == "submitted":
            return 0.1

        return 0.5

    async def download_file(
        self,
        url: str,
        filename: Optional[str] = None,
        timeout: float = 60.0
    ) -> str:
        """
        Download a file to the temp directory and return local path.

        Args:
            url: URL of the file to fetch.
            filename: Name to save under. Defaults to the last segment of the
                URL path (query string and fragment excluded), or "download".
                Only the final path component is kept, so the file is always
                written inside this client's temp directory.
            timeout: httpx timeout in seconds for the download request.

        Returns:
            Path of the written file.

        Raises:
            httpx.HTTPStatusError: the server responded with an error status.
            httpx.RequestError: the request itself failed (network error, timeout, ...).
        """
        from urllib.parse import urlsplit

        if not filename:
            # Parse the URL so a "/" inside the query or a "#fragment" can't
            # leak into the derived name.
            filename = urlsplit(url).path.rsplit("/", 1)[-1]
        # Drop any directory parts ("../x", "a\\b") so the file can't escape the temp dir.
        filename = os.path.basename(filename.replace("\\", "/"))
        if filename in ("", ".", ".."):
            filename = "download"

        # cleanup() may already have removed the directory; recreate it if so.
        os.makedirs(self._temp_dir, exist_ok=True)
        local_path = os.path.join(self._temp_dir, filename)

        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.get(url)
            response.raise_for_status()

        with open(local_path, "wb") as f:
            f.write(response.content)

        return local_path

    def cleanup(self):
        """Delete the temp directory and everything downloaded into it (errors are ignored)."""
        import shutil

        if os.path.exists(self._temp_dir):
            shutil.rmtree(self._temp_dir, ignore_errors=True)
|
|
|