| import base64 |
| import time |
| from abc import ABC |
| from typing import Any, Dict, Optional, Union |
| from urllib.parse import urlparse |
|
|
| from huggingface_hub import constants |
| from huggingface_hub.hf_api import InferenceProviderMapping |
| from huggingface_hub.inference._common import RequestParameters, _as_dict, _as_url |
| from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none |
| from huggingface_hub.utils import get_session, hf_raise_for_status |
| from huggingface_hub.utils.logging import get_logger |
|
|
|
|
# Module-level logger named after this module.
logger = get_logger(__name__)


# Delay, in seconds, slept between two status requests while a queued request is not yet completed.
_POLLING_INTERVAL = 0.5
|
|
|
|
class FalAITask(TaskProviderHelper, ABC):
    """Base helper for fal.ai tasks sent to the ``https://fal.run`` endpoint."""

    def __init__(self, task: str):
        super().__init__(provider="fal-ai", base_url="https://fal.run", task=task)

    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict[str, Any]:
        """Build the request headers for the given API key.

        The parent class prepares the headers first. When the key does not start
        with ``hf_``, the ``authorization`` header is replaced by ``Key <api_key>``.

        Args:
            headers: Headers forwarded to the parent implementation.
            api_key: Key used to authenticate the request.

        Returns:
            The headers to send with the request.
        """
        prepared = super()._prepare_headers(headers, api_key)
        if api_key.startswith("hf_"):
            # Keys prefixed with "hf_" keep whatever the parent class produced.
            return prepared
        prepared["authorization"] = f"Key {api_key}"
        return prepared

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        """Return the request path, i.e. the mapped model id prefixed with ``/``."""
        return f"/{mapped_model}"
|
|
|
|
class FalAIQueueTask(TaskProviderHelper, ABC):
    """Base helper for fal.ai tasks submitted to the ``https://queue.fal.run`` endpoint.

    The submission response only identifies the queued request. `get_response`
    polls its status until it is ``COMPLETED`` and then downloads the result.
    """

    def __init__(self, task: str):
        super().__init__(provider="fal-ai", base_url="https://queue.fal.run", task=task)

    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict[str, Any]:
        """Build the request headers for the given API key.

        The parent class prepares the headers first. When the key does not start
        with ``hf_``, the ``authorization`` header is replaced by ``Key <api_key>``.
        """
        headers = super()._prepare_headers(headers, api_key)
        if not api_key.startswith("hf_"):
            headers["authorization"] = f"Key {api_key}"
        return headers

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        """Return the request path for the mapped model.

        Keys starting with ``hf_`` get the ``_subdomain=queue`` query parameter
        appended; any other key gets the bare ``/<mapped_model>`` path.
        """
        if api_key.startswith("hf_"):
            return f"/{mapped_model}?_subdomain=queue"
        return f"/{mapped_model}"

    def get_response(
        self,
        response: Union[bytes, Dict],
        request_params: Optional[RequestParameters] = None,
    ) -> Any:
        """Wait for a queued request to complete and return its JSON result.

        Args:
            response: Submission response; it must contain ``request_id`` and
                ``response_url`` and may contain an initial ``status``.
            request_params: Parameters of the submission request. Its ``url`` and
                ``headers`` are reused to build and send the polling requests.

        Returns:
            The decoded JSON body of the result endpoint.

        Raises:
            ValueError: If ``request_id`` or ``response_url`` is missing from the
                response, or if ``request_params`` is not provided.
            Exception: Whatever `hf_raise_for_status` raises when a status or
                result request is answered with an HTTP error.
        """
        response_dict = _as_dict(response)

        request_id = response_dict.get("request_id")
        if not request_id:
            raise ValueError("No request ID found in the response")
        if request_params is None:
            raise ValueError(
                f"A `RequestParameters` object should be provided to get {self.task} responses with Fal AI."
            )
        response_url = response_dict.get("response_url")
        if not response_url:
            # Without it the polling URLs cannot be built; fail early instead of requesting a malformed URL.
            raise ValueError("No response URL found in the response")

        # Poll on the same host (and with the same query string) as the submission request.
        parsed_url = urlparse(request_params.url)
        # Requests sent through router.huggingface.co keep the "/fal-ai" path prefix.
        router_prefix = "/fal-ai" if parsed_url.netloc == "router.huggingface.co" else ""
        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}{router_prefix}"
        query_param = f"?{parsed_url.query}" if parsed_url.query else ""

        # Only the path of `response_url` is reused; scheme and host come from the submission request.
        request_path = urlparse(response_url).path
        status_url = f"{base_url}{request_path}/status{query_param}"
        result_url = f"{base_url}{request_path}{query_param}"

        status = response_dict.get("status")
        logger.info("Generating the output.. this can take several minutes.")
        while status != "COMPLETED":
            time.sleep(_POLLING_INTERVAL)
            status_response = get_session().get(status_url, headers=request_params.headers)
            hf_raise_for_status(status_response)
            status = status_response.json().get("status")

        # Check the result request too, so an HTTP error is reported as such
        # rather than as a JSON decoding error or a missing key further down.
        result_response = get_session().get(result_url, headers=request_params.headers)
        hf_raise_for_status(result_response)
        return result_response.json()
|
|
|
|
class FalAIAutomaticSpeechRecognitionTask(FalAITask):
    """fal.ai helper for the ``automatic-speech-recognition`` task."""

    def __init__(self):
        super().__init__("automatic-speech-recognition")

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        """Build the request payload around an ``audio_url`` entry.

        Args:
            inputs: Either an ``http://`` / ``https://`` URL (sent as is), any other
                string (opened as a file path and read in binary mode), or a
                bytes-like object accepted by `base64.b64encode`.
            parameters: Extra parameters, merged in through `filter_none`.
            provider_mapping_info: Unused here.

        Returns:
            A dict with ``audio_url`` followed by the filtered parameters.
        """
        is_remote = isinstance(inputs, str) and inputs.startswith(("http://", "https://"))
        if is_remote:
            audio_url = inputs
        else:
            raw_audio = inputs
            if isinstance(raw_audio, str):
                # Any other string is treated as a path to a local file.
                with open(raw_audio, "rb") as audio_file:
                    raw_audio = audio_file.read()

            # Local audio is embedded as a base64 data URL with a fixed MIME type.
            mime_type = "audio/mpeg"
            encoded_audio = base64.b64encode(raw_audio).decode()
            audio_url = f"data:{mime_type};base64,{encoded_audio}"

        return {"audio_url": audio_url, **filter_none(parameters)}

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        """Return the ``text`` field of the response.

        Raises:
            ValueError: If the ``text`` field is not a string.
        """
        text = _as_dict(response)["text"]
        if isinstance(text, str):
            return text
        raise ValueError(f"Unexpected output format from FalAI API. Expected string, got {type(text)}.")
|
|
|
|
class FalAITextToImageTask(FalAITask):
    """fal.ai helper for the ``text-to-image`` task."""

    def __init__(self):
        super().__init__("text-to-image")

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        """Build the request payload from the prompt and the extra parameters.

        Args:
            inputs: Sent as the ``prompt`` entry.
            parameters: Extra parameters, merged in through `filter_none`.
            provider_mapping_info: Mapping whose ``adapter_weights_path``,
                ``hf_model_id`` and ``provider_id`` drive the LoRA entries.

        Returns:
            The payload dict. ``width`` and ``height`` are moved into an
            ``image_size`` dict when both are present, and a ``loras`` entry is
            added when adapter weights are set.
        """
        payload: Dict[str, Any] = {"prompt": inputs}
        payload.update(filter_none(parameters))

        # The two dimensions are only regrouped when both were provided.
        if "width" in payload and "height" in payload:
            width = payload.pop("width")
            height = payload.pop("height")
            payload["image_size"] = {"width": width, "height": height}

        adapter_path = provider_mapping_info.adapter_weights_path
        if adapter_path is None:
            return payload

        # Point the provider at the adapter weights file on the "main" revision of the model repo.
        lora_path = constants.HUGGINGFACE_CO_URL_TEMPLATE.format(
            repo_id=provider_mapping_info.hf_model_id,
            revision="main",
            filename=adapter_path,
        )
        payload["loras"] = [{"path": lora_path, "scale": 1}]

        # NOTE(review): assumed to apply only when adapter weights are set — confirm nesting.
        # The "fal-ai/lora" provider id additionally gets a fixed `model_name`.
        if provider_mapping_info.provider_id == "fal-ai/lora":
            payload["model_name"] = "stabilityai/stable-diffusion-xl-base-1.0"

        return payload

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        """Download and return the content behind the first image URL of the response."""
        first_image = _as_dict(response)["images"][0]
        return get_session().get(first_image["url"]).content
|
|
|
|
class FalAITextToSpeechTask(FalAITask):
    """fal.ai helper for the ``text-to-speech`` task."""

    def __init__(self):
        super().__init__("text-to-speech")

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        """Return a payload with ``text`` set to the inputs, followed by the filtered parameters."""
        payload = {"text": inputs}
        payload.update(filter_none(parameters))
        return payload

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        """Download and return the content behind the ``audio`` URL of the response."""
        audio = _as_dict(response)["audio"]
        return get_session().get(audio["url"]).content
|
|
|
|
class FalAITextToVideoTask(FalAIQueueTask):
    """fal.ai queue helper for the ``text-to-video`` task."""

    def __init__(self):
        super().__init__("text-to-video")

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        """Return a payload with ``prompt`` set to the inputs, followed by the filtered parameters."""
        payload = {"prompt": inputs}
        payload.update(filter_none(parameters))
        return payload

    def get_response(
        self,
        response: Union[bytes, Dict],
        request_params: Optional[RequestParameters] = None,
    ) -> Any:
        """Resolve the queued request through the parent class, then download the ``video`` URL content."""
        queue_output = super().get_response(response, request_params)
        video = _as_dict(queue_output)["video"]
        return get_session().get(video["url"]).content
|
|
|
|
class FalAIImageToImageTask(FalAIQueueTask):
    """fal.ai queue helper for the ``image-to-image`` task."""

    def __init__(self):
        super().__init__("image-to-image")

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        """Build the request payload around an ``image_url`` entry.

        Args:
            inputs: Image input, converted with `_as_url` using ``image/jpeg`` as
                the default MIME type.
            parameters: Extra parameters, merged in through `filter_none`.
            provider_mapping_info: Mapping whose ``adapter_weights_path`` and
                ``hf_model_id`` drive the optional ``loras`` entry.

        Returns:
            The payload dict, with a ``loras`` entry when adapter weights are set.
        """
        payload: Dict[str, Any] = {"image_url": _as_url(inputs, default_mime_type="image/jpeg")}
        payload.update(filter_none(parameters))

        adapter_path = provider_mapping_info.adapter_weights_path
        if adapter_path is None:
            return payload

        # Point the provider at the adapter weights file on the "main" revision of the model repo.
        lora_path = constants.HUGGINGFACE_CO_URL_TEMPLATE.format(
            repo_id=provider_mapping_info.hf_model_id,
            revision="main",
            filename=adapter_path,
        )
        payload["loras"] = [{"path": lora_path, "scale": 1}]
        return payload

    def get_response(
        self,
        response: Union[bytes, Dict],
        request_params: Optional[RequestParameters] = None,
    ) -> Any:
        """Resolve the queued request through the parent class, then download the first image URL content."""
        queue_output = super().get_response(response, request_params)
        first_image = _as_dict(queue_output)["images"][0]
        return get_session().get(first_image["url"]).content
|
|
|
|
class FalAIImageToVideoTask(FalAIQueueTask):
    """fal.ai queue helper for the ``image-to-video`` task."""

    def __init__(self):
        super().__init__("image-to-video")

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        """Build the request payload around an ``image_url`` entry.

        Args:
            inputs: Image input, converted with `_as_url` using ``image/jpeg`` as
                the default MIME type.
            parameters: Extra parameters, merged in through `filter_none`.
            provider_mapping_info: Mapping whose ``adapter_weights_path`` and
                ``hf_model_id`` drive the optional ``loras`` entry.

        Returns:
            The payload dict, with a ``loras`` entry when adapter weights are set.
        """
        source_image_url = _as_url(inputs, default_mime_type="image/jpeg")
        payload: Dict[str, Any] = {"image_url": source_image_url}
        payload.update(filter_none(parameters))

        adapter_path = provider_mapping_info.adapter_weights_path
        if adapter_path is not None:
            # Point the provider at the adapter weights file on the "main" revision of the model repo.
            weights_url = constants.HUGGINGFACE_CO_URL_TEMPLATE.format(
                repo_id=provider_mapping_info.hf_model_id,
                revision="main",
                filename=adapter_path,
            )
            payload["loras"] = [{"path": weights_url, "scale": 1}]
        return payload

    def get_response(
        self,
        response: Union[bytes, Dict],
        request_params: Optional[RequestParameters] = None,
    ) -> Any:
        """Resolve the queued request through the parent class, then download the ``video`` URL content."""
        queue_output = super().get_response(response, request_params)
        video = _as_dict(queue_output)["video"]
        return get_session().get(video["url"]).content
|
|