|
|
import base64 |
|
|
import time |
|
|
from abc import ABC |
|
|
from typing import Any, Dict, Optional, Union |
|
|
from urllib.parse import urlparse |
|
|
|
|
|
from huggingface_hub import constants |
|
|
from huggingface_hub.hf_api import InferenceProviderMapping |
|
|
from huggingface_hub.inference._common import RequestParameters, _as_dict, _as_url |
|
|
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none |
|
|
from huggingface_hub.utils import get_session, hf_raise_for_status |
|
|
from huggingface_hub.utils.logging import get_logger |
|
|
|
|
|
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
|
|
|
_POLLING_INTERVAL = 0.5 |
|
|
|
|
|
|
|
|
class FalAITask(TaskProviderHelper, ABC): |
|
|
def __init__(self, task: str): |
|
|
super().__init__(provider="fal-ai", base_url="https://fal.run", task=task) |
|
|
|
|
|
def _prepare_headers(self, headers: Dict, api_key: str) -> Dict: |
|
|
headers = super()._prepare_headers(headers, api_key) |
|
|
if not api_key.startswith("hf_"): |
|
|
headers["authorization"] = f"Key {api_key}" |
|
|
return headers |
|
|
|
|
|
def _prepare_route(self, mapped_model: str, api_key: str) -> str: |
|
|
return f"/{mapped_model}" |
|
|
|
|
|
|
|
|
class FalAIQueueTask(TaskProviderHelper, ABC): |
|
|
def __init__(self, task: str): |
|
|
super().__init__(provider="fal-ai", base_url="https://queue.fal.run", task=task) |
|
|
|
|
|
def _prepare_headers(self, headers: Dict, api_key: str) -> Dict: |
|
|
headers = super()._prepare_headers(headers, api_key) |
|
|
if not api_key.startswith("hf_"): |
|
|
headers["authorization"] = f"Key {api_key}" |
|
|
return headers |
|
|
|
|
|
def _prepare_route(self, mapped_model: str, api_key: str) -> str: |
|
|
if api_key.startswith("hf_"): |
|
|
|
|
|
return f"/{mapped_model}?_subdomain=queue" |
|
|
return f"/{mapped_model}" |
|
|
|
|
|
def get_response( |
|
|
self, |
|
|
response: Union[bytes, Dict], |
|
|
request_params: Optional[RequestParameters] = None, |
|
|
) -> Any: |
|
|
response_dict = _as_dict(response) |
|
|
|
|
|
request_id = response_dict.get("request_id") |
|
|
if not request_id: |
|
|
raise ValueError("No request ID found in the response") |
|
|
if request_params is None: |
|
|
raise ValueError( |
|
|
f"A `RequestParameters` object should be provided to get {self.task} responses with Fal AI." |
|
|
) |
|
|
|
|
|
|
|
|
parsed_url = urlparse(request_params.url) |
|
|
|
|
|
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}{'/fal-ai' if parsed_url.netloc == 'router.huggingface.co' else ''}" |
|
|
query_param = f"?{parsed_url.query}" if parsed_url.query else "" |
|
|
|
|
|
|
|
|
|
|
|
model_id = urlparse(response_dict.get("response_url")).path |
|
|
status_url = f"{base_url}{str(model_id)}/status{query_param}" |
|
|
result_url = f"{base_url}{str(model_id)}{query_param}" |
|
|
|
|
|
status = response_dict.get("status") |
|
|
logger.info("Generating the output.. this can take several minutes.") |
|
|
while status != "COMPLETED": |
|
|
time.sleep(_POLLING_INTERVAL) |
|
|
status_response = get_session().get(status_url, headers=request_params.headers) |
|
|
hf_raise_for_status(status_response) |
|
|
status = status_response.json().get("status") |
|
|
|
|
|
return get_session().get(result_url, headers=request_params.headers).json() |
|
|
|
|
|
|
|
|
class FalAIAutomaticSpeechRecognitionTask(FalAITask): |
|
|
def __init__(self): |
|
|
super().__init__("automatic-speech-recognition") |
|
|
|
|
|
def _prepare_payload_as_dict( |
|
|
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping |
|
|
) -> Optional[Dict]: |
|
|
if isinstance(inputs, str) and inputs.startswith(("http://", "https://")): |
|
|
|
|
|
audio_url = inputs |
|
|
else: |
|
|
|
|
|
if isinstance(inputs, str): |
|
|
with open(inputs, "rb") as f: |
|
|
inputs = f.read() |
|
|
|
|
|
audio_b64 = base64.b64encode(inputs).decode() |
|
|
content_type = "audio/mpeg" |
|
|
audio_url = f"data:{content_type};base64,{audio_b64}" |
|
|
|
|
|
return {"audio_url": audio_url, **filter_none(parameters)} |
|
|
|
|
|
def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: |
|
|
text = _as_dict(response)["text"] |
|
|
if not isinstance(text, str): |
|
|
raise ValueError(f"Unexpected output format from FalAI API. Expected string, got {type(text)}.") |
|
|
return text |
|
|
|
|
|
|
|
|
class FalAITextToImageTask(FalAITask): |
|
|
def __init__(self): |
|
|
super().__init__("text-to-image") |
|
|
|
|
|
def _prepare_payload_as_dict( |
|
|
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping |
|
|
) -> Optional[Dict]: |
|
|
payload: Dict[str, Any] = { |
|
|
"prompt": inputs, |
|
|
**filter_none(parameters), |
|
|
} |
|
|
if "width" in payload and "height" in payload: |
|
|
payload["image_size"] = { |
|
|
"width": payload.pop("width"), |
|
|
"height": payload.pop("height"), |
|
|
} |
|
|
if provider_mapping_info.adapter_weights_path is not None: |
|
|
lora_path = constants.HUGGINGFACE_CO_URL_TEMPLATE.format( |
|
|
repo_id=provider_mapping_info.hf_model_id, |
|
|
revision="main", |
|
|
filename=provider_mapping_info.adapter_weights_path, |
|
|
) |
|
|
payload["loras"] = [{"path": lora_path, "scale": 1}] |
|
|
if provider_mapping_info.provider_id == "fal-ai/lora": |
|
|
|
|
|
|
|
|
payload["model_name"] = "stabilityai/stable-diffusion-xl-base-1.0" |
|
|
|
|
|
return payload |
|
|
|
|
|
def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: |
|
|
url = _as_dict(response)["images"][0]["url"] |
|
|
return get_session().get(url).content |
|
|
|
|
|
|
|
|
class FalAITextToSpeechTask(FalAITask): |
|
|
def __init__(self): |
|
|
super().__init__("text-to-speech") |
|
|
|
|
|
def _prepare_payload_as_dict( |
|
|
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping |
|
|
) -> Optional[Dict]: |
|
|
return {"text": inputs, **filter_none(parameters)} |
|
|
|
|
|
def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: |
|
|
url = _as_dict(response)["audio"]["url"] |
|
|
return get_session().get(url).content |
|
|
|
|
|
|
|
|
class FalAITextToVideoTask(FalAIQueueTask): |
|
|
def __init__(self): |
|
|
super().__init__("text-to-video") |
|
|
|
|
|
def _prepare_payload_as_dict( |
|
|
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping |
|
|
) -> Optional[Dict]: |
|
|
return {"prompt": inputs, **filter_none(parameters)} |
|
|
|
|
|
def get_response( |
|
|
self, |
|
|
response: Union[bytes, Dict], |
|
|
request_params: Optional[RequestParameters] = None, |
|
|
) -> Any: |
|
|
output = super().get_response(response, request_params) |
|
|
url = _as_dict(output)["video"]["url"] |
|
|
return get_session().get(url).content |
|
|
|
|
|
|
|
|
class FalAIImageToImageTask(FalAIQueueTask): |
|
|
def __init__(self): |
|
|
super().__init__("image-to-image") |
|
|
|
|
|
def _prepare_payload_as_dict( |
|
|
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping |
|
|
) -> Optional[Dict]: |
|
|
image_url = _as_url(inputs, default_mime_type="image/jpeg") |
|
|
payload: Dict[str, Any] = { |
|
|
"image_url": image_url, |
|
|
**filter_none(parameters), |
|
|
} |
|
|
if provider_mapping_info.adapter_weights_path is not None: |
|
|
lora_path = constants.HUGGINGFACE_CO_URL_TEMPLATE.format( |
|
|
repo_id=provider_mapping_info.hf_model_id, |
|
|
revision="main", |
|
|
filename=provider_mapping_info.adapter_weights_path, |
|
|
) |
|
|
payload["loras"] = [{"path": lora_path, "scale": 1}] |
|
|
|
|
|
return payload |
|
|
|
|
|
def get_response( |
|
|
self, |
|
|
response: Union[bytes, Dict], |
|
|
request_params: Optional[RequestParameters] = None, |
|
|
) -> Any: |
|
|
output = super().get_response(response, request_params) |
|
|
url = _as_dict(output)["images"][0]["url"] |
|
|
return get_session().get(url).content |
|
|
|
|
|
|
|
|
class FalAIImageToVideoTask(FalAIQueueTask): |
|
|
def __init__(self): |
|
|
super().__init__("image-to-video") |
|
|
|
|
|
def _prepare_payload_as_dict( |
|
|
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping |
|
|
) -> Optional[Dict]: |
|
|
image_url = _as_url(inputs, default_mime_type="image/jpeg") |
|
|
payload: Dict[str, Any] = { |
|
|
"image_url": image_url, |
|
|
**filter_none(parameters), |
|
|
} |
|
|
if provider_mapping_info.adapter_weights_path is not None: |
|
|
lora_path = constants.HUGGINGFACE_CO_URL_TEMPLATE.format( |
|
|
repo_id=provider_mapping_info.hf_model_id, |
|
|
revision="main", |
|
|
filename=provider_mapping_info.adapter_weights_path, |
|
|
) |
|
|
payload["loras"] = [{"path": lora_path, "scale": 1}] |
|
|
return payload |
|
|
|
|
|
def get_response( |
|
|
self, |
|
|
response: Union[bytes, Dict], |
|
|
request_params: Optional[RequestParameters] = None, |
|
|
) -> Any: |
|
|
output = super().get_response(response, request_params) |
|
|
url = _as_dict(output)["video"]["url"] |
|
|
return get_session().get(url).content |
|
|
|