# Inference provider helpers for fal.ai, used by `huggingface_hub` inference clients.
import base64
import time
from abc import ABC
from typing import Any, Dict, Optional, Union
from urllib.parse import urlparse
from huggingface_hub import constants
from huggingface_hub.hf_api import InferenceProviderMapping
from huggingface_hub.inference._common import RequestParameters, _as_dict, _as_url
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
from huggingface_hub.utils import get_session, hf_raise_for_status
from huggingface_hub.utils.logging import get_logger
# Module-level logger, namespaced to this module.
logger = get_logger(__name__)
# Arbitrary polling interval, in seconds (passed to `time.sleep` while waiting for queued jobs)
_POLLING_INTERVAL = 0.5
class FalAITask(TaskProviderHelper, ABC):
    """Base helper for fal.ai tasks served synchronously via https://fal.run."""

    def __init__(self, task: str):
        super().__init__(provider="fal-ai", base_url="https://fal.run", task=task)

    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
        prepared = super()._prepare_headers(headers, api_key)
        # Direct fal.ai keys use the provider's `Key` auth scheme; Hugging Face
        # tokens (prefixed "hf_") keep the headers set by the parent class.
        if not api_key.startswith("hf_"):
            prepared["authorization"] = f"Key {api_key}"
        return prepared

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        # Route is simply the provider model id.
        return "/" + mapped_model
class FalAIQueueTask(TaskProviderHelper, ABC):
    """Base helper for fal.ai tasks that run asynchronously via the queue API.

    Requests are submitted to https://queue.fal.run; `get_response` then polls
    the job status until completion and fetches the final result.
    """

    def __init__(self, task: str):
        super().__init__(provider="fal-ai", base_url="https://queue.fal.run", task=task)

    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
        headers = super()._prepare_headers(headers, api_key)
        # Direct fal.ai keys use the provider's `Key` auth scheme; Hugging Face
        # tokens (prefixed "hf_") keep the headers set by the parent class.
        if not api_key.startswith("hf_"):
            headers["authorization"] = f"Key {api_key}"
        return headers

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        if api_key.startswith("hf_"):
            # Use the queue subdomain for HF routing
            return f"/{mapped_model}?_subdomain=queue"
        return f"/{mapped_model}"

    def get_response(
        self,
        response: Union[bytes, Dict],
        request_params: Optional[RequestParameters] = None,
    ) -> Any:
        """Poll the fal.ai queue until the job completes, then return the result as a dict.

        Args:
            response: Initial queue submission response; must contain `request_id`,
                `response_url` and `status`.
            request_params: Parameters of the original request. Required to rebuild
                the status/result URLs and to reuse the auth headers.

        Raises:
            ValueError: If `request_id` or `response_url` is missing from the
                response, or if `request_params` is not provided.
        """
        response_dict = _as_dict(response)
        request_id = response_dict.get("request_id")
        if not request_id:
            raise ValueError("No request ID found in the response")
        if request_params is None:
            raise ValueError(
                f"A `RequestParameters` object should be provided to get {self.task} responses with Fal AI."
            )
        response_url = response_dict.get("response_url")
        if not response_url:
            # Fix: previously a missing `response_url` caused an opaque
            # `TypeError` inside `urlparse(None)`. Fail with a clear error
            # instead, consistent with the `request_id` check above.
            raise ValueError("No response URL found in the response")

        # extract the base url and query params
        parsed_url = urlparse(request_params.url)
        # a bit hacky way to concatenate the provider name without parsing `parsed_url.path`
        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}{'/fal-ai' if parsed_url.netloc == 'router.huggingface.co' else ''}"
        query_param = f"?{parsed_url.query}" if parsed_url.query else ""

        # extracting the provider model id for status and result urls
        # from the response as it might be different from the mapped model in `request_params.url`
        model_id = urlparse(response_url).path
        status_url = f"{base_url}{str(model_id)}/status{query_param}"
        result_url = f"{base_url}{str(model_id)}{query_param}"

        status = response_dict.get("status")
        logger.info("Generating the output.. this can take several minutes.")
        # Poll at a fixed interval until the job reports COMPLETED.
        while status != "COMPLETED":
            time.sleep(_POLLING_INTERVAL)
            status_response = get_session().get(status_url, headers=request_params.headers)
            hf_raise_for_status(status_response)
            status = status_response.json().get("status")

        return get_session().get(result_url, headers=request_params.headers).json()
class FalAIAutomaticSpeechRecognitionTask(FalAITask):
    """Helper for fal.ai automatic speech recognition (audio in, transcript out)."""

    def __init__(self):
        super().__init__("automatic-speech-recognition")

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        if isinstance(inputs, str) and inputs.startswith(("http://", "https://")):
            # Remote audio: forward the URL untouched.
            return {"audio_url": inputs, **filter_none(parameters)}
        # Local audio: a string is treated as a file path and read as raw bytes.
        raw = inputs
        if isinstance(raw, str):
            with open(raw, "rb") as audio_file:
                raw = audio_file.read()
        # Inline the audio as a base64-encoded data URI.
        encoded = base64.b64encode(raw).decode()
        data_uri = f"data:audio/mpeg;base64,{encoded}"
        return {"audio_url": data_uri, **filter_none(parameters)}

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        """Return the transcription text from the provider response."""
        text = _as_dict(response)["text"]
        if isinstance(text, str):
            return text
        raise ValueError(f"Unexpected output format from FalAI API. Expected string, got {type(text)}.")
class FalAITextToImageTask(FalAITask):
    """Helper for fal.ai text-to-image generation."""

    def __init__(self):
        super().__init__("text-to-image")

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        payload: Dict[str, Any] = {"prompt": inputs, **filter_none(parameters)}

        # fal.ai expects dimensions nested under `image_size`, not as
        # top-level width/height keys.
        if "width" in payload and "height" in payload:
            payload["image_size"] = {
                "width": payload.pop("width"),
                "height": payload.pop("height"),
            }

        adapter_path = provider_mapping_info.adapter_weights_path
        if adapter_path is not None:
            # Point fal.ai at the LoRA weights file hosted on the Hugging Face Hub.
            lora_url = constants.HUGGINGFACE_CO_URL_TEMPLATE.format(
                repo_id=provider_mapping_info.hf_model_id,
                revision="main",
                filename=adapter_path,
            )
            payload["loras"] = [{"path": lora_url, "scale": 1}]
            if provider_mapping_info.provider_id == "fal-ai/lora":
                # little hack: fal requires the base model for stable-diffusion-based loras but not for flux-based
                # See payloads in https://fal.ai/models/fal-ai/lora/api vs https://fal.ai/models/fal-ai/flux-lora/api
                payload["model_name"] = "stabilityai/stable-diffusion-xl-base-1.0"
        return payload

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        """Download and return the first generated image as raw bytes."""
        image_url = _as_dict(response)["images"][0]["url"]
        return get_session().get(image_url).content
class FalAITextToSpeechTask(FalAITask):
    """Helper for fal.ai text-to-speech generation."""

    def __init__(self):
        super().__init__("text-to-speech")

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        # The text to synthesize plus any non-None extra parameters.
        return {"text": inputs, **filter_none(parameters)}

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        """Download and return the synthesized audio as raw bytes."""
        audio_url = _as_dict(response)["audio"]["url"]
        return get_session().get(audio_url).content
class FalAITextToVideoTask(FalAIQueueTask):
    """Helper for fal.ai text-to-video generation (queue-based)."""

    def __init__(self):
        super().__init__("text-to-video")

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        # The prompt plus any non-None extra parameters.
        return {"prompt": inputs, **filter_none(parameters)}

    def get_response(
        self,
        response: Union[bytes, Dict],
        request_params: Optional[RequestParameters] = None,
    ) -> Any:
        """Wait for the queued job, then download and return the video as raw bytes."""
        result = super().get_response(response, request_params)
        video_url = _as_dict(result)["video"]["url"]
        return get_session().get(video_url).content
class FalAIImageToImageTask(FalAIQueueTask):
    """Helper for fal.ai image-to-image generation (queue-based)."""

    def __init__(self):
        super().__init__("image-to-image")

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        # Normalize the input image to a URL via the shared helper.
        payload: Dict[str, Any] = {
            "image_url": _as_url(inputs, default_mime_type="image/jpeg"),
            **filter_none(parameters),
        }
        adapter_path = provider_mapping_info.adapter_weights_path
        if adapter_path is not None:
            # Point fal.ai at the LoRA weights file hosted on the Hugging Face Hub.
            lora_url = constants.HUGGINGFACE_CO_URL_TEMPLATE.format(
                repo_id=provider_mapping_info.hf_model_id,
                revision="main",
                filename=adapter_path,
            )
            payload["loras"] = [{"path": lora_url, "scale": 1}]
        return payload

    def get_response(
        self,
        response: Union[bytes, Dict],
        request_params: Optional[RequestParameters] = None,
    ) -> Any:
        """Wait for the queued job, then download and return the first output image as raw bytes."""
        result = super().get_response(response, request_params)
        image_url = _as_dict(result)["images"][0]["url"]
        return get_session().get(image_url).content
class FalAIImageToVideoTask(FalAIQueueTask):
    """Helper for fal.ai image-to-video generation (queue-based)."""

    def __init__(self):
        super().__init__("image-to-video")

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        # Normalize the input image to a URL via the shared helper.
        payload: Dict[str, Any] = {
            "image_url": _as_url(inputs, default_mime_type="image/jpeg"),
            **filter_none(parameters),
        }
        adapter_path = provider_mapping_info.adapter_weights_path
        if adapter_path is not None:
            # Point fal.ai at the LoRA weights file hosted on the Hugging Face Hub.
            lora_url = constants.HUGGINGFACE_CO_URL_TEMPLATE.format(
                repo_id=provider_mapping_info.hf_model_id,
                revision="main",
                filename=adapter_path,
            )
            payload["loras"] = [{"path": lora_url, "scale": 1}]
        return payload

    def get_response(
        self,
        response: Union[bytes, Dict],
        request_params: Optional[RequestParameters] = None,
    ) -> Any:
        """Wait for the queued job, then download and return the video as raw bytes."""
        result = super().get_response(response, request_params)
        video_url = _as_dict(result)["video"]["url"]
        return get_session().get(video_url).content