""" StackNet API Client Handles all communication with the StackNettask network. SSE parsing and progress tracking are handled internally. """ import json import tempfile import os from typing import AsyncGenerator, Optional, Any, Callable from dataclasses import dataclass from enum import Enum import httpx from ..config import config class MediaAction(str, Enum): """Supported media orchestration actions.""" GENERATE_MUSIC = "generate_music" CREATE_COVER = "create_cover" EXTRACT_STEMS = "extract_stems" ANALYZE_VISUAL = "analyze_visual" DESCRIBE_VIDEO = "describe_video" CREATE_COMPOSITE = "create_composite" @dataclass class TaskProgress: """Progress update from a running task.""" progress: float # 0.0 to 1.0 status: str message: str @dataclass class TaskResult: """Final result from a completed task.""" success: bool data: dict error: Optional[str] = None class StackNetClient: """ Client for StackNet task network API. All SSE parsing and polling is handled internally. Consumers receive clean progress updates and final results. """ def __init__( self, base_url: Optional[str] = None, api_key: Optional[str] = None, timeout: float = 300.0 ): self.base_url = base_url or config.stacknet_url self.api_key = api_key # Must be provided from UI, no env fallback self.timeout = timeout self._temp_dir = tempfile.mkdtemp(prefix="stacknet_") async def submit_tool_task( self, tool_name: str, parameters: dict, server_name: str = "geoff", on_progress: Optional[Callable[[float, str], None]] = None ) -> TaskResult: """ Submit an MCP tool task and wait for completion. Args: tool_name: The tool to invoke (e.g., generate_image_5) parameters: Tool parameters server_name: MCP server name (default: geoff) on_progress: Callback for progress updates Returns: TaskResult with success status and output data """ payload = { "type": "mcp-tool", "serverName": server_name, "toolName": tool_name, "stream": True, "parameters": parameters } headers = {"Content-Type": "application/json"} if self.api_key: auth_header = self.api_key if self.api_key.startswith("Bearer ") else f"Bearer {self.api_key}" headers["Authorization"] = auth_header async with httpx.AsyncClient(timeout=self.timeout) as client: try: async with client.stream( "POST", f"{self.base_url}/tasks", json=payload, headers=headers ) as response: if response.status_code != 200: error_text = await response.aread() return TaskResult( success=False, data={}, error=f"API request failed ({response.status_code}): {error_text.decode()[:200]}" ) return await self._process_sse_stream(response, on_progress) except httpx.TimeoutException: return TaskResult( success=False, data={}, error="Request timed out. The operation took too long." ) except httpx.RequestError as e: return TaskResult( success=False, data={}, error=f"Network error: {str(e)}" ) async def submit_media_task( self, action: MediaAction, prompt: Optional[str] = None, media_url: Optional[str] = None, audio_url: Optional[str] = None, video_url: Optional[str] = None, options: Optional[dict] = None, on_progress: Optional[Callable[[float, str], None]] = None ) -> TaskResult: """ Submit a media orchestration task and wait for completion. Args: action: The media action to perform prompt: Text prompt for generation media_url: URL for image input audio_url: URL for audio input video_url: URL for video input options: Additional options (tags, title, etc.) on_progress: Callback for progress updates (progress: 0-1, message: str) Returns: TaskResult with success status and output data """ payload = { "type": config.TASK_TYPE_MEDIA, "action": action.value, "stream": True, } if prompt: payload["prompt"] = prompt if media_url: payload["mediaUrl"] = media_url if audio_url: payload["audioUrl"] = audio_url if video_url: payload["videoUrl"] = video_url if options: payload["options"] = options headers = {"Content-Type": "application/json"} if self.api_key: auth_header = self.api_key if self.api_key.startswith("Bearer ") else f"Bearer {self.api_key}" headers["Authorization"] = auth_header async with httpx.AsyncClient(timeout=self.timeout) as client: try: async with client.stream( "POST", f"{self.base_url}/tasks", json=payload, headers=headers ) as response: if response.status_code != 200: error_text = await response.aread() return TaskResult( success=False, data={}, error=f"API request failed ({response.status_code}): {error_text.decode()[:200]}" ) return await self._process_sse_stream(response, on_progress) except httpx.TimeoutException: return TaskResult( success=False, data={}, error="Request timed out. The operation took too long." ) except httpx.RequestError as e: return TaskResult( success=False, data={}, error=f"Network error: {str(e)}" ) async def _process_sse_stream( self, response: httpx.Response, on_progress: Optional[Callable[[float, str], None]] = None ) -> TaskResult: """Process SSE stream and extract final result.""" buffer = "" final_result: Optional[dict] = None error_message: Optional[str] = None async for chunk in response.aiter_text(): buffer += chunk lines = buffer.split("\n") buffer = lines.pop() # Keep incomplete line for line in lines: if not line.startswith("data: "): continue raw_data = line[6:].strip() # Skip markers if raw_data == "[DONE]" or not raw_data: continue try: event = json.loads(raw_data) event_type = event.get("type", "") event_data = event.get("data", event) if event_type == "progress": if on_progress: progress = self._calculate_progress(event_data) message = event_data.get("message", "Processing...") on_progress(progress, message) elif event_type == "result": final_result = event_data.get("output", event_data) elif event_type == "error": error_message = event_data.get("message", "Unknown error occurred") elif event_type == "complete": # Task completed successfully pass except json.JSONDecodeError: continue # Process any remaining buffer if buffer.strip() and buffer.startswith("data: "): raw_data = buffer[6:].strip() if raw_data and raw_data != "[DONE]": try: event = json.loads(raw_data) if event.get("type") == "result": final_result = event.get("data", {}).get("output", event.get("data", {})) except json.JSONDecodeError: pass if error_message: return TaskResult(success=False, data={}, error=error_message) if final_result: return TaskResult(success=True, data=final_result) return TaskResult( success=False, data={}, error="No result received from the API" ) def _calculate_progress(self, data: dict) -> float: """Calculate normalized progress (0.0 to 1.0).""" if not data: return 0.5 status = data.get("status", "") if status == "completed": return 1.0 if status == "polling": attempt = data.get("attempt", 1) max_attempts = data.get("maxAttempts", 30) return 0.2 + (attempt / max_attempts) * 0.6 if status == "processing": return 0.5 if status == "submitted": return 0.1 return 0.5 async def download_file(self, url: str, filename: Optional[str] = None) -> str: """Download a file to the temp directory and return local path.""" if not filename: filename = url.split("/")[-1].split("?")[0] if not filename: filename = "download" local_path = os.path.join(self._temp_dir, filename) async with httpx.AsyncClient(timeout=60.0) as client: response = await client.get(url) response.raise_for_status() with open(local_path, "wb") as f: f.write(response.content) return local_path def cleanup(self): """Clean up temporary files.""" import shutil if os.path.exists(self._temp_dir): shutil.rmtree(self._temp_dir, ignore_errors=True)