"""Abstract base class for cloud generation providers.""" from __future__ import annotations from abc import ABC, abstractmethod from dataclasses import dataclass @dataclass class CloudGenerationResult: """Result from a cloud provider generation.""" job_id: str image_bytes: bytes generation_time_seconds: float class CloudProvider(ABC): """Abstract interface for cloud GPU providers. Implement this for each provider (Replicate, RunPod, fal.ai, etc.). """ @property @abstractmethod def name(self) -> str: """Provider name (e.g., 'replicate', 'runpod').""" @abstractmethod async def submit_generation( self, *, positive_prompt: str, negative_prompt: str, checkpoint: str, lora_name: str | None = None, lora_strength: float = 0.85, seed: int = -1, steps: int = 28, cfg: float = 7.0, width: int = 832, height: int = 1216, ) -> str: """Submit a generation job. Returns a job ID for tracking.""" @abstractmethod async def check_status(self, job_id: str) -> str: """Check job status. Returns: 'pending', 'running', 'completed', 'failed'.""" @abstractmethod async def get_result(self, job_id: str) -> CloudGenerationResult: """Download the completed generation result.""" @abstractmethod async def is_available(self) -> bool: """Check if this provider is configured and reachable."""