File size: 1,503 Bytes
ed37502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""Abstract base class for cloud generation providers."""

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class CloudGenerationResult:
    """Outcome of one generation performed by a cloud provider.

    A plain mutable record with three required fields, declared in the
    order the generated ``__init__`` accepts them positionally.
    """

    # Identifier of the job this result belongs to.
    job_id: str
    # Payload of the generated image as bytes; the encoding/format is not
    # specified here -- NOTE(review): confirm against the provider in use.
    image_bytes: bytes
    # Generation time in seconds (per the field name); what exactly is
    # timed is up to whoever builds the record.
    generation_time_seconds: float


class CloudProvider(ABC):
    """Common contract every cloud GPU backend has to satisfy.

    Write one concrete subclass per backend (Replicate, RunPod, fal.ai,
    etc.). All members below are abstract, so a subclass can only be
    instantiated once it overrides each of them.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Short identifier of the provider, e.g. 'replicate' or 'runpod'."""

    @abstractmethod
    async def submit_generation(
        self,
        *,
        positive_prompt: str,
        negative_prompt: str,
        checkpoint: str,
        lora_name: str | None = None,
        lora_strength: float = 0.85,
        seed: int = -1,
        steps: int = 28,
        cfg: float = 7.0,
        width: int = 832,
        height: int = 1216,
    ) -> str:
        """Queue a new generation job with the provider.

        Every argument is keyword-only. The descriptions below are read
        off the parameter names; how each value is interpreted is decided
        by the concrete provider.

        Args:
            positive_prompt: Prompt text for the positive side.
            negative_prompt: Prompt text for the negative side.
            checkpoint: Which model checkpoint to use.
            lora_name: Optional LoRA to apply; ``None`` by default.
            lora_strength: Strength for the LoRA; defaults to 0.85.
            seed: Seed value; defaults to -1 (presumably "pick one at
                random" -- confirm with the implementation).
            steps: Number of steps; defaults to 28.
            cfg: CFG value; defaults to 7.0.
            width: Output width; defaults to 832.
            height: Output height; defaults to 1216.

        Returns:
            A job ID that can be used to track the submitted job.
        """

    @abstractmethod
    async def check_status(self, job_id: str) -> str:
        """Report the current state of a job.

        Returns:
            One of 'pending', 'running', 'completed', 'failed'.
        """

    @abstractmethod
    async def get_result(self, job_id: str) -> CloudGenerationResult:
        """Fetch the finished output of a completed generation job."""

    @abstractmethod
    async def is_available(self) -> bool:
        """Tell whether the provider is both configured and reachable."""