File size: 9,860 Bytes
783a8bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 |
import base64
import time
from abc import ABC
from typing import Any, Dict, Optional, Union
from urllib.parse import urlparse
from huggingface_hub import constants
from huggingface_hub.hf_api import InferenceProviderMapping
from huggingface_hub.inference._common import RequestParameters, _as_dict, _as_url
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
from huggingface_hub.utils import get_session, hf_raise_for_status
from huggingface_hub.utils.logging import get_logger
# Module-level logger, named after this module.
logger = get_logger(__name__)
# Delay, in seconds, passed to `time.sleep` between two status polls of a queued
# request. The value is arbitrary.
_POLLING_INTERVAL = 0.5
class FalAITask(TaskProviderHelper, ABC):
    """Base helper for fal-ai tasks that are sent to ``https://fal.run``."""

    def __init__(self, task: str):
        """Register the helper for `task` under the "fal-ai" provider."""
        super().__init__(provider="fal-ai", base_url="https://fal.run", task=task)

    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
        """Return the parent's prepared headers, adjusted for the kind of key.

        Keys that do not start with ``hf_`` (presumably provider keys rather
        than Hugging Face tokens) get an ``authorization: Key <api_key>`` header.
        """
        prepared = super()._prepare_headers(headers, api_key)
        if api_key.startswith("hf_"):
            return prepared
        prepared["authorization"] = f"Key {api_key}"
        return prepared

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        """Return the request route: the mapped model id prefixed with a slash."""
        route = f"/{mapped_model}"
        return route
class FalAIQueueTask(TaskProviderHelper, ABC):
    """Base helper for fal-ai tasks that go through ``https://queue.fal.run``.

    The initial call returns a request descriptor rather than the final
    output; `get_response` polls the status URL until the request is
    ``COMPLETED`` and then fetches the result.
    """

    def __init__(self, task: str):
        """Register the helper for `task` under the "fal-ai" provider."""
        super().__init__(provider="fal-ai", base_url="https://queue.fal.run", task=task)

    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
        """Return the parent's prepared headers, adjusted for the kind of key.

        Keys that do not start with ``hf_`` get an
        ``authorization: Key <api_key>`` header.
        """
        headers = super()._prepare_headers(headers, api_key)
        if not api_key.startswith("hf_"):
            headers["authorization"] = f"Key {api_key}"
        return headers

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        """Return the request route for `mapped_model`.

        With an ``hf_`` key the route carries ``?_subdomain=queue`` so that HF
        routing targets the queue subdomain.
        """
        if api_key.startswith("hf_"):
            # Use the queue subdomain for HF routing
            return f"/{mapped_model}?_subdomain=queue"
        return f"/{mapped_model}"

    def get_response(
        self,
        response: Union[bytes, Dict],
        request_params: Optional[RequestParameters] = None,
    ) -> Any:
        """Poll a queued request until completion and return its JSON result.

        Args:
            response: Initial response of the queue submission. It is read for
                the ``request_id``, ``response_url`` and ``status`` keys.
            request_params: Parameters of the initial request; its URL and
                headers are reused for the status and result calls.

        Returns:
            The decoded JSON body of the result endpoint.

        Raises:
            ValueError: If ``request_id`` or ``response_url`` is missing from
                `response`, or if `request_params` is not provided.
            Exception: Whatever `hf_raise_for_status` raises when a status or
                result call returns an HTTP error.
        """
        response_dict = _as_dict(response)

        request_id = response_dict.get("request_id")
        if not request_id:
            raise ValueError("No request ID found in the response")
        if request_params is None:
            raise ValueError(
                f"A `RequestParameters` object should be provided to get {self.task} responses with Fal AI."
            )
        # Without a response URL the status/result URLs below would silently be
        # built from an empty path, so fail early with an explicit error.
        response_url = response_dict.get("response_url")
        if not response_url:
            raise ValueError("No response URL found in the response")

        # extract the base url and query params
        parsed_url = urlparse(request_params.url)
        # a bit hacky way to concatenate the provider name without parsing `parsed_url.path`
        provider_prefix = "/fal-ai" if parsed_url.netloc == "router.huggingface.co" else ""
        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}{provider_prefix}"
        query_param = f"?{parsed_url.query}" if parsed_url.query else ""

        # extracting the provider model id for status and result urls
        # from the response as it might be different from the mapped model in `request_params.url`
        model_path = urlparse(response_url).path
        status_url = f"{base_url}{model_path}/status{query_param}"
        result_url = f"{base_url}{model_path}{query_param}"

        status = response_dict.get("status")
        logger.info("Generating the output.. this can take several minutes.")
        while status != "COMPLETED":
            time.sleep(_POLLING_INTERVAL)
            status_response = get_session().get(status_url, headers=request_params.headers)
            hf_raise_for_status(status_response)
            status = status_response.json().get("status")

        # Check the result call the same way as the status calls, so an HTTP
        # error is reported as such instead of failing later on a missing key.
        result_response = get_session().get(result_url, headers=request_params.headers)
        hf_raise_for_status(result_response)
        return result_response.json()
class FalAIAutomaticSpeechRecognitionTask(FalAITask):
    """fal-ai helper for the "automatic-speech-recognition" task."""

    def __init__(self):
        super().__init__("automatic-speech-recognition")

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        """Build the payload: an ``audio_url`` entry plus the non-None parameters.

        An http(s) URL string is forwarded unchanged. Any other string is
        opened as a file path and read in binary mode. The resulting bytes (or
        the non-string input itself) are base64-encoded into an ``audio/mpeg``
        data URL.
        """
        is_remote = isinstance(inputs, str) and inputs.startswith(("http://", "https://"))
        if is_remote:
            return {"audio_url": inputs, **filter_none(parameters)}

        raw_audio = inputs
        if isinstance(raw_audio, str):
            # A non-URL string is treated as a local file path.
            with open(raw_audio, "rb") as audio_file:
                raw_audio = audio_file.read()

        encoded = base64.b64encode(raw_audio).decode()
        mime_type = "audio/mpeg"
        return {"audio_url": f"data:{mime_type};base64,{encoded}", **filter_none(parameters)}

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        """Return the ``text`` field of the response.

        Raises:
            ValueError: If the ``text`` field is not a string.
        """
        transcript = _as_dict(response)["text"]
        if isinstance(transcript, str):
            return transcript
        raise ValueError(f"Unexpected output format from FalAI API. Expected string, got {type(transcript)}.")
class FalAITextToImageTask(FalAITask):
    """fal-ai helper for the "text-to-image" task."""

    def __init__(self):
        super().__init__("text-to-image")

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        """Build the payload from the prompt and the non-None parameters.

        When both ``width`` and ``height`` are present they are moved into a
        nested ``image_size`` entry. When the mapping declares adapter weights,
        a ``loras`` entry pointing at the weights file on the "main" revision
        is added.
        """
        payload: Dict[str, Any] = dict(prompt=inputs)
        payload.update(filter_none(parameters))

        if {"width", "height"} <= payload.keys():
            width = payload.pop("width")
            height = payload.pop("height")
            payload["image_size"] = {"width": width, "height": height}

        adapter_path = provider_mapping_info.adapter_weights_path
        if adapter_path is None:
            return payload

        weights_url = constants.HUGGINGFACE_CO_URL_TEMPLATE.format(
            repo_id=provider_mapping_info.hf_model_id,
            revision="main",
            filename=adapter_path,
        )
        payload["loras"] = [{"path": weights_url, "scale": 1}]
        if provider_mapping_info.provider_id == "fal-ai/lora":
            # Small hack: fal needs the base model for stable-diffusion-based
            # loras, but not for flux-based ones.
            # See payloads in https://fal.ai/models/fal-ai/lora/api vs https://fal.ai/models/fal-ai/flux-lora/api
            payload["model_name"] = "stabilityai/stable-diffusion-xl-base-1.0"
        return payload

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        """Fetch the URL of the first image in the response and return its content."""
        first_image = _as_dict(response)["images"][0]
        download = get_session().get(first_image["url"])
        return download.content
class FalAITextToSpeechTask(FalAITask):
    """fal-ai helper for the "text-to-speech" task."""

    def __init__(self):
        super().__init__("text-to-speech")

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        """Build the payload: the input under ``text`` plus the non-None parameters."""
        payload = {"text": inputs}
        payload.update(filter_none(parameters))
        return payload

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        """Fetch the audio URL found in the response and return its content."""
        audio_url = _as_dict(response)["audio"]["url"]
        download = get_session().get(audio_url)
        return download.content
class FalAITextToVideoTask(FalAIQueueTask):
    """fal-ai queue helper for the "text-to-video" task."""

    def __init__(self):
        super().__init__("text-to-video")

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        """Build the payload: the input under ``prompt`` plus the non-None parameters."""
        payload = {"prompt": inputs}
        payload.update(filter_none(parameters))
        return payload

    def get_response(
        self,
        response: Union[bytes, Dict],
        request_params: Optional[RequestParameters] = None,
    ) -> Any:
        """Get the queue result from the parent class, then fetch its video URL and return the content."""
        queue_result = super().get_response(response, request_params)
        video_url = _as_dict(queue_result)["video"]["url"]
        download = get_session().get(video_url)
        return download.content
class FalAIImageToImageTask(FalAIQueueTask):
    """fal-ai queue helper for the "image-to-image" task."""

    def __init__(self):
        super().__init__("image-to-image")

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        """Build the payload: an ``image_url`` entry plus the non-None parameters.

        The input is passed through `_as_url` with ``image/jpeg`` as default
        MIME type. When the mapping declares adapter weights, a ``loras`` entry
        pointing at the weights file on the "main" revision is added.
        """
        payload: Dict[str, Any] = dict(image_url=_as_url(inputs, default_mime_type="image/jpeg"))
        payload.update(filter_none(parameters))

        adapter_path = provider_mapping_info.adapter_weights_path
        if adapter_path is None:
            return payload

        weights_url = constants.HUGGINGFACE_CO_URL_TEMPLATE.format(
            repo_id=provider_mapping_info.hf_model_id,
            revision="main",
            filename=adapter_path,
        )
        payload["loras"] = [{"path": weights_url, "scale": 1}]
        return payload

    def get_response(
        self,
        response: Union[bytes, Dict],
        request_params: Optional[RequestParameters] = None,
    ) -> Any:
        """Get the queue result from the parent class, then fetch its first image URL and return the content."""
        queue_result = super().get_response(response, request_params)
        first_image = _as_dict(queue_result)["images"][0]
        download = get_session().get(first_image["url"])
        return download.content
class FalAIImageToVideoTask(FalAIQueueTask):
    """fal-ai queue helper for the "image-to-video" task."""

    def __init__(self):
        super().__init__("image-to-video")

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        """Build the payload: an ``image_url`` entry plus the non-None parameters.

        The input is passed through `_as_url` with ``image/jpeg`` as default
        MIME type. When the mapping declares adapter weights, a ``loras`` entry
        pointing at the weights file on the "main" revision is added.
        """
        payload: Dict[str, Any] = dict(image_url=_as_url(inputs, default_mime_type="image/jpeg"))
        payload.update(filter_none(parameters))

        adapter_path = provider_mapping_info.adapter_weights_path
        if adapter_path is None:
            return payload

        weights_url = constants.HUGGINGFACE_CO_URL_TEMPLATE.format(
            repo_id=provider_mapping_info.hf_model_id,
            revision="main",
            filename=adapter_path,
        )
        payload["loras"] = [{"path": weights_url, "scale": 1}]
        return payload

    def get_response(
        self,
        response: Union[bytes, Dict],
        request_params: Optional[RequestParameters] = None,
    ) -> Any:
        """Get the queue result from the parent class, then fetch its video URL and return the content."""
        queue_result = super().get_response(response, request_params)
        video_url = _as_dict(queue_result)["video"]["url"]
        download = get_session().get(video_url)
        return download.content
|