from typing import Any, Dict, Optional, Union
from huggingface_hub import constants
from huggingface_hub.inference._common import RequestParameters, TaskProviderHelper, _as_dict
from huggingface_hub.utils import build_hf_headers, get_session, get_token, logging
logger = logging.get_logger(__name__)
# Base URL for direct calls to the Replicate API. Only used when the caller
# supplies a native Replicate API key; HF tokens ("hf_...") are routed through
# the Hugging Face inference proxy instead (see ReplicateTask.prepare_request).
BASE_URL = "https://api.replicate.com"
# Mapping of task name -> {Hugging Face model ID -> Replicate model ID}.
# A Replicate ID may pin a specific version after a colon ("owner/name:version");
# such models are called through the generic predictions endpoint with the
# version sent in the payload (see _build_url / _prepare_payload).
SUPPORTED_MODELS = {
    "text-to-image": {
        "black-forest-labs/FLUX.1-dev": "black-forest-labs/flux-dev",
        "black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
        "ByteDance/Hyper-SD": "bytedance/hyper-flux-16step:382cf8959fb0f0d665b26e7e80b8d6dc3faaef1510f14ce017e8c732bb3d1eb7",
        "ByteDance/SDXL-Lightning": "bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637",
        "playgroundai/playground-v2.5-1024px-aesthetic": "playgroundai/playground-v2.5-1024px-aesthetic:a45f82a1382bed5c7aeb861dac7c7d191b0fdf74d8d57c4a0e6ed7d4d0bf7d24",
        "stabilityai/stable-diffusion-3.5-large-turbo": "stability-ai/stable-diffusion-3.5-large-turbo",
        "stabilityai/stable-diffusion-3.5-large": "stability-ai/stable-diffusion-3.5-large",
        "stabilityai/stable-diffusion-3.5-medium": "stability-ai/stable-diffusion-3.5-medium",
        "stabilityai/stable-diffusion-xl-base-1.0": "stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc",
    },
    "text-to-speech": {
        "OuteAI/OuteTTS-0.3-500M": "jbilcke/oute-tts:39a59319327b27327fa3095149c5a746e7f2aee18c75055c3368237a6503cd26",
    },
    "text-to-video": {
        "genmo/mochi-1-preview": "genmoai/mochi-1:1944af04d098ef69bed7f9d335d102e652203f268ec4aaa2d836f6217217e460",
    },
}
def _build_url(base_url: str, model: str) -> str:
if ":" in model:
return f"{base_url}/v1/predictions"
return f"{base_url}/v1/models/{model}/predictions"
class ReplicateTask(TaskProviderHelper):
    """Prepare requests to, and parse responses from, Replicate for a given task."""

    def __init__(self, task: str):
        self.task = task

    def prepare_request(
        self,
        *,
        inputs: Any,
        parameters: Dict[str, Any],
        headers: Dict,
        model: Optional[str],
        api_key: Optional[str],
        extra_payload: Optional[Dict[str, Any]] = None,
    ) -> RequestParameters:
        """Build the `RequestParameters` for a Replicate prediction call.

        Raises:
            ValueError: if no API key can be resolved, or if `model`/`self.task`
                is not supported on Replicate (see `_map_model`).
        """
        if api_key is None:
            api_key = get_token()
        if api_key is None:
            raise ValueError(
                "You must provide an api_key to work with Replicate API or log in with `huggingface-cli login`."
            )

        # Route to the proxy if the api_key is a HF TOKEN
        if api_key.startswith("hf_"):
            base_url = constants.INFERENCE_PROXY_TEMPLATE.format(provider="replicate")
            logger.info("Calling Replicate provider through Hugging Face proxy.")
        else:
            base_url = BASE_URL
            logger.info("Calling Replicate provider directly.")

        mapped_model = self._map_model(model)
        url = _build_url(base_url, mapped_model)

        headers = {
            **build_hf_headers(token=api_key),
            **headers,
            # Ask Replicate to hold the connection open until the prediction
            # completes (synchronous mode) instead of returning a polling URL.
            "Prefer": "wait",
        }

        payload = self._prepare_payload(inputs, parameters=parameters, model=mapped_model)

        return RequestParameters(
            url=url,
            task=self.task,
            model=mapped_model,
            json=payload,
            data=None,
            headers=headers,
        )

    def _map_model(self, model: Optional[str]) -> str:
        """Map a Hugging Face model ID to its Replicate counterpart.

        Raises:
            ValueError: if `model` is None, the task is unsupported, or the
                model has no Replicate mapping for this task.
        """
        if model is None:
            raise ValueError("Please provide a model available on Replicate.")
        if self.task not in SUPPORTED_MODELS:
            raise ValueError(f"Task {self.task} not supported with Replicate.")
        mapped_model = SUPPORTED_MODELS[self.task].get(model)
        if mapped_model is None:
            raise ValueError(f"Model {model} is not supported with Replicate for task {self.task}.")
        return mapped_model

    def _prepare_payload(
        self,
        inputs: Any,
        parameters: Dict[str, Any],
        model: str,
    ) -> Dict[str, Any]:
        """Build the JSON payload; `None`-valued parameters are dropped."""
        payload: Dict[str, Any] = {
            "input": {
                "prompt": inputs,
                **{k: v for k, v in parameters.items() if v is not None},
            }
        }
        # Version-pinned models ("owner/name:version") require the version to be
        # passed explicitly in the payload (they use the generic endpoint).
        if ":" in model:
            version = model.split(":", 1)[1]
            payload["version"] = version
        return payload

    def get_response(self, response: Union[bytes, Dict]) -> Any:
        """Download and return the raw bytes of the first generated output.

        Raises:
            TimeoutError: if the prediction returned no output (e.g. cold start).
        """
        response_dict = _as_dict(response)
        if response_dict.get("output") is None:
            raise TimeoutError(
                f"Inference request timed out after 60 seconds. No output generated for model {response_dict.get('model')}. "
                "The model might be in a cold state or starting up. Please try again later."
            )
        # "output" is either a single URL or a list of URLs; take the first one.
        output_url = (
            response_dict["output"] if isinstance(response_dict["output"], str) else response_dict["output"][0]
        )
        return get_session().get(output_url).content
class ReplicateTextToSpeechTask(ReplicateTask):
    """Replicate helper specialized for the text-to-speech task."""

    def __init__(self):
        super().__init__("text-to-speech")

    def _prepare_payload(
        self,
        inputs: Any,
        parameters: Dict[str, Any],
        model: str,
    ) -> Dict[str, Any]:
        # The following payload might work only for a subset of text-to-speech Replicate models.
        input_fields: Dict[str, Any] = {"inputs": inputs}
        for key, value in parameters.items():
            if value is not None:
                input_fields[key] = value
        payload: Dict[str, Any] = {"input": input_fields}
        # Version-pinned model IDs ("owner/name:version") carry the version in the payload.
        if ":" in model:
            payload["version"] = model.split(":", 1)[1]
        return payload
|