DrDavis's picture
Upload folder using huggingface_hub
17c6d62 verified
from typing import Any, Dict, Optional, Union
from huggingface_hub import constants
from huggingface_hub.inference._common import RequestParameters, TaskProviderHelper, _as_dict
from huggingface_hub.utils import build_hf_headers, get_session, get_token, logging
logger = logging.get_logger(__name__)
BASE_URL = "https://api.replicate.com"
SUPPORTED_MODELS = {
"text-to-image": {
"black-forest-labs/FLUX.1-dev": "black-forest-labs/flux-dev",
"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
"ByteDance/Hyper-SD": "bytedance/hyper-flux-16step:382cf8959fb0f0d665b26e7e80b8d6dc3faaef1510f14ce017e8c732bb3d1eb7",
"ByteDance/SDXL-Lightning": "bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637",
"playgroundai/playground-v2.5-1024px-aesthetic": "playgroundai/playground-v2.5-1024px-aesthetic:a45f82a1382bed5c7aeb861dac7c7d191b0fdf74d8d57c4a0e6ed7d4d0bf7d24",
"stabilityai/stable-diffusion-3.5-large-turbo": "stability-ai/stable-diffusion-3.5-large-turbo",
"stabilityai/stable-diffusion-3.5-large": "stability-ai/stable-diffusion-3.5-large",
"stabilityai/stable-diffusion-3.5-medium": "stability-ai/stable-diffusion-3.5-medium",
"stabilityai/stable-diffusion-xl-base-1.0": "stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc",
},
"text-to-speech": {
"OuteAI/OuteTTS-0.3-500M": "jbilcke/oute-tts:39a59319327b27327fa3095149c5a746e7f2aee18c75055c3368237a6503cd26",
},
"text-to-video": {
"genmo/mochi-1-preview": "genmoai/mochi-1:1944af04d098ef69bed7f9d335d102e652203f268ec4aaa2d836f6217217e460",
},
}
def _build_url(base_url: str, model: str) -> str:
if ":" in model:
return f"{base_url}/v1/predictions"
return f"{base_url}/v1/models/{model}/predictions"
class ReplicateTask(TaskProviderHelper):
def __init__(self, task: str):
self.task = task
def prepare_request(
self,
*,
inputs: Any,
parameters: Dict[str, Any],
headers: Dict,
model: Optional[str],
api_key: Optional[str],
extra_payload: Optional[Dict[str, Any]] = None,
) -> RequestParameters:
if api_key is None:
api_key = get_token()
if api_key is None:
raise ValueError(
"You must provide an api_key to work with Replicate API or log in with `huggingface-cli login`."
)
# Route to the proxy if the api_key is a HF TOKEN
if api_key.startswith("hf_"):
base_url = constants.INFERENCE_PROXY_TEMPLATE.format(provider="replicate")
logger.info("Calling Replicate provider through Hugging Face proxy.")
else:
base_url = BASE_URL
logger.info("Calling Replicate provider directly.")
mapped_model = self._map_model(model)
url = _build_url(base_url, mapped_model)
headers = {
**build_hf_headers(token=api_key),
**headers,
"Prefer": "wait",
}
payload = self._prepare_payload(inputs, parameters=parameters, model=mapped_model)
return RequestParameters(
url=url,
task=self.task,
model=mapped_model,
json=payload,
data=None,
headers=headers,
)
def _map_model(self, model: Optional[str]) -> str:
if model is None:
raise ValueError("Please provide a model available on Replicate.")
if self.task not in SUPPORTED_MODELS:
raise ValueError(f"Task {self.task} not supported with Replicate.")
mapped_model = SUPPORTED_MODELS[self.task].get(model)
if mapped_model is None:
raise ValueError(f"Model {model} is not supported with Replicate for task {self.task}.")
return mapped_model
def _prepare_payload(
self,
inputs: Any,
parameters: Dict[str, Any],
model: str,
) -> Dict[str, Any]:
payload: Dict[str, Any] = {
"input": {
"prompt": inputs,
**{k: v for k, v in parameters.items() if v is not None},
}
}
if ":" in model:
version = model.split(":", 1)[1]
payload["version"] = version
return payload
def get_response(self, response: Union[bytes, Dict]) -> Any:
response_dict = _as_dict(response)
if response_dict.get("output") is None:
raise TimeoutError(
f"Inference request timed out after 60 seconds. No output generated for model {response_dict.get('model')}"
"The model might be in cold state or starting up. Please try again later."
)
output_url = (
response_dict["output"] if isinstance(response_dict["output"], str) else response_dict["output"][0]
)
return get_session().get(output_url).content
class ReplicateTextToSpeechTask(ReplicateTask):
def __init__(self):
super().__init__("text-to-speech")
def _prepare_payload(
self,
inputs: Any,
parameters: Dict[str, Any],
model: str,
) -> Dict[str, Any]:
# The following payload might work only for a subset of text-to-speech Replicate models.
payload: Dict[str, Any] = {
"input": {
"inputs": inputs,
**{k: v for k, v in parameters.items() if v is not None},
},
}
if ":" in model:
version = model.split(":", 1)[1]
payload["version"] = version
return payload