|
|
from typing import Any, Dict, Optional, Union |
|
|
|
|
|
from huggingface_hub import constants |
|
|
from huggingface_hub.inference._common import RequestParameters, TaskProviderHelper, _as_dict |
|
|
from huggingface_hub.utils import build_hf_headers, get_session, get_token, logging |
|
|
|
|
|
|
|
|
# Module-level logger, obtained through the logging helpers imported from huggingface_hub.utils.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
# Root of the Replicate HTTP API. Used when the api key does not start with "hf_",
# i.e. when the request is sent to Replicate directly rather than through the proxy template.
BASE_URL = "https://api.replicate.com"
|
|
|
|
|
# Lookup table: task name -> {caller-supplied model id -> Replicate model identifier}.
# The keys of the inner dicts are what `ReplicateTask._map_model` receives as `model`
# (presumably Hugging Face Hub model ids -- confirm against callers).
#
# Replicate identifiers come in two shapes:
#   * "owner/name"            -> no ":" in the string
#   * "owner/name:<version>"  -> the part after ":" is sent as the payload's "version" field
# `_build_url` and `_prepare_payload` branch on the presence of ":".
SUPPORTED_MODELS = {
    "text-to-image": {
        "black-forest-labs/FLUX.1-dev": "black-forest-labs/flux-dev",
        "black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
        "ByteDance/Hyper-SD": "bytedance/hyper-flux-16step:382cf8959fb0f0d665b26e7e80b8d6dc3faaef1510f14ce017e8c732bb3d1eb7",
        "ByteDance/SDXL-Lightning": "bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637",
        "playgroundai/playground-v2.5-1024px-aesthetic": "playgroundai/playground-v2.5-1024px-aesthetic:a45f82a1382bed5c7aeb861dac7c7d191b0fdf74d8d57c4a0e6ed7d4d0bf7d24",
        "stabilityai/stable-diffusion-3.5-large-turbo": "stability-ai/stable-diffusion-3.5-large-turbo",
        "stabilityai/stable-diffusion-3.5-large": "stability-ai/stable-diffusion-3.5-large",
        "stabilityai/stable-diffusion-3.5-medium": "stability-ai/stable-diffusion-3.5-medium",
        "stabilityai/stable-diffusion-xl-base-1.0": "stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc",
    },
    "text-to-speech": {
        "OuteAI/OuteTTS-0.3-500M": "jbilcke/oute-tts:39a59319327b27327fa3095149c5a746e7f2aee18c75055c3368237a6503cd26",
    },
    "text-to-video": {
        "genmo/mochi-1-preview": "genmoai/mochi-1:1944af04d098ef69bed7f9d335d102e652203f268ec4aaa2d836f6217217e460",
    },
}
|
|
|
|
|
|
|
|
def _build_url(base_url: str, model: str) -> str: |
|
|
if ":" in model: |
|
|
return f"{base_url}/v1/predictions" |
|
|
return f"{base_url}/v1/models/{model}/predictions" |
|
|
|
|
|
|
|
|
class ReplicateTask(TaskProviderHelper):
    """Build Replicate prediction requests for one task and decode their responses."""

    def __init__(self, task: str):
        """Store the task name used to look up models in ``SUPPORTED_MODELS``."""
        self.task = task

    def prepare_request(
        self,
        *,
        inputs: Any,
        parameters: Dict[str, Any],
        headers: Dict,
        model: Optional[str],
        api_key: Optional[str],
        extra_payload: Optional[Dict[str, Any]] = None,
    ) -> RequestParameters:
        """Assemble the URL, headers and JSON body for a Replicate prediction.

        Args:
            inputs: Value sent as the model input (see ``_prepare_payload``).
            parameters: Extra model inputs; entries whose value is ``None`` are dropped.
            headers: Caller headers, merged on top of the default auth headers.
            model: Caller-supplied model id; must be a key of ``SUPPORTED_MODELS[self.task]``.
            api_key: Token to authenticate with. When ``None``, ``get_token()`` is tried.
            extra_payload: Accepted for signature compatibility; not read by this method.

        Returns:
            A ``RequestParameters`` describing the request to send.

        Raises:
            ValueError: If no api key can be resolved, or the model/task is not supported.
        """
        if api_key is None:
            api_key = get_token()
        if api_key is None:
            raise ValueError(
                "You must provide an api_key to work with Replicate API or log in with `huggingface-cli login`."
            )

        # Keys starting with "hf_" go through the proxy URL template; any other key
        # is sent straight to BASE_URL.
        if api_key.startswith("hf_"):
            base_url = constants.INFERENCE_PROXY_TEMPLATE.format(provider="replicate")
            logger.info("Calling Replicate provider through Hugging Face proxy.")
        else:
            base_url = BASE_URL
            logger.info("Calling Replicate provider directly.")
        mapped_model = self._map_model(model)
        url = _build_url(base_url, mapped_model)

        # Merge order matters: caller headers override the defaults, and "Prefer: wait"
        # is always set last so it cannot be overridden.
        headers = {
            **build_hf_headers(token=api_key),
            **headers,
            "Prefer": "wait",
        }

        payload = self._prepare_payload(inputs, parameters=parameters, model=mapped_model)

        return RequestParameters(
            url=url,
            task=self.task,
            model=mapped_model,
            json=payload,
            data=None,
            headers=headers,
        )

    def _map_model(self, model: Optional[str]) -> str:
        """Translate a caller-supplied model id into its Replicate identifier.

        Raises:
            ValueError: If ``model`` is ``None``, if ``self.task`` has no entry in
                ``SUPPORTED_MODELS``, or if ``model`` is not listed for that task.
        """
        if model is None:
            raise ValueError("Please provide a model available on Replicate.")
        if self.task not in SUPPORTED_MODELS:
            raise ValueError(f"Task {self.task} not supported with Replicate.")
        mapped_model = SUPPORTED_MODELS[self.task].get(model)
        if mapped_model is None:
            raise ValueError(f"Model {model} is not supported with Replicate for task {self.task}.")
        return mapped_model

    def _prepare_payload(
        self,
        inputs: Any,
        parameters: Dict[str, Any],
        model: str,
    ) -> Dict[str, Any]:
        """Build the JSON body: ``inputs`` under ``"prompt"`` plus non-``None`` parameters.

        When ``model`` has the form ``"name:version"``, the part after the first
        ":" is added as the top-level ``"version"`` field.
        """
        payload: Dict[str, Any] = {
            "input": {
                "prompt": inputs,
                **{k: v for k, v in parameters.items() if v is not None},
            }
        }
        if ":" in model:
            version = model.split(":", 1)[1]
            payload["version"] = version
        return payload

    def get_response(self, response: Union[bytes, Dict]) -> Any:
        """Download and return the bytes of the generated output.

        The response's ``"output"`` field is treated as either a URL string or a
        sequence whose first element is a URL; that URL is fetched and its body
        returned.

        Raises:
            TimeoutError: If the response carries no ``"output"`` value.
        """
        response_dict = _as_dict(response)
        output = response_dict.get("output")
        if output is None:
            # The trailing ". " separates the two adjacent literals, which Python
            # concatenates verbatim; without it the model name runs into "The model".
            raise TimeoutError(
                f"Inference request timed out after 60 seconds. No output generated for model {response_dict.get('model')}. "
                "The model might be in cold state or starting up. Please try again later."
            )
        output_url = output if isinstance(output, str) else output[0]
        return get_session().get(output_url).content
|
|
|
|
|
|
|
|
class ReplicateTextToSpeechTask(ReplicateTask):
    """Replicate helper bound to the "text-to-speech" task.

    Differs from the base helper only in the payload: the text is sent under
    the ``"inputs"`` key.
    """

    def __init__(self):
        super().__init__("text-to-speech")

    def _prepare_payload(
        self,
        inputs: Any,
        parameters: Dict[str, Any],
        model: str,
    ) -> Dict[str, Any]:
        """Build the JSON body: ``inputs`` under ``"inputs"`` plus non-``None`` parameters.

        A ``"version"`` field is added when ``model`` contains ":"; its value is
        everything after the first ":".
        """
        model_input: Dict[str, Any] = {"inputs": inputs}
        for name, value in parameters.items():
            if value is not None:
                model_input[name] = value

        payload: Dict[str, Any] = {"input": model_input}

        _, separator, version = model.partition(":")
        if separator:
            payload["version"] = version
        return payload
|
|
|