shubhamofbce's picture
Upload folder using huggingface_hub
eeb9cbc verified
"""Utility function for gradio/external.py, designed for internal use."""
from __future__ import annotations
import base64
import inspect
import json
import math
import re
import warnings
import httpx
import yaml
from huggingface_hub import HfApi, ImageClassificationOutputElement, InferenceClient
from gradio import components
from gradio.exceptions import Error, TooManyRequestsError
def get_model_info(model_name, hf_token=None):
hf_api = HfApi(token=hf_token)
print(f"Fetching model from: https://huggingface.co/{model_name}")
model_info = hf_api.model_info(model_name)
pipeline = model_info.pipeline_tag
tags = model_info.tags
return pipeline, tags
##################
# Helper functions for processing tabular data
##################
def get_tabular_examples(model_name: str) -> dict[str, list[float]]:
readme = httpx.get(f"https://huggingface.co/{model_name}/resolve/main/README.md")
if readme.status_code != 200:
warnings.warn(f"Cannot load examples from README for {model_name}", UserWarning)
example_data = {}
else:
yaml_regex = re.search(
"(?:^|[\r\n])---[\n\r]+([\\S\\s]*?)[\n\r]+---([\n\r]|$)", readme.text
)
if yaml_regex is None:
example_data = {}
else:
example_yaml = next(
yaml.safe_load_all(readme.text[: yaml_regex.span()[-1]])
)
example_data = example_yaml.get("widget", {}).get("structuredData", {})
if not example_data:
raise ValueError(
f"No example data found in README.md of {model_name} - Cannot build gradio demo. "
"See the README.md here: https://huggingface.co/scikit-learn/tabular-playground/blob/main/README.md "
"for a reference on how to provide example data to your model."
)
# replace nan with string NaN for inference Endpoints
for data in example_data.values():
for i, val in enumerate(data):
if isinstance(val, float) and math.isnan(val):
data[i] = "NaN"
return example_data
def cols_to_rows(
example_data: dict[str, list[float | str] | None],
) -> tuple[list[str], list[list[float]]]:
headers = list(example_data.keys())
n_rows = max(len(example_data[header] or []) for header in headers)
data = []
for row_index in range(n_rows):
row_data = []
for header in headers:
col = example_data[header] or []
if row_index >= len(col):
row_data.append("NaN")
else:
row_data.append(col[row_index])
data.append(row_data)
return headers, data
def rows_to_cols(incoming_data: dict) -> dict[str, dict[str, dict[str, list[str]]]]:
data_column_wise = {}
for i, header in enumerate(incoming_data["headers"]):
data_column_wise[header] = [str(row[i]) for row in incoming_data["data"]]
return {"inputs": {"data": data_column_wise}}
##################
# Helper functions for processing other kinds of data
##################
def postprocess_label(scores: list[ImageClassificationOutputElement]) -> dict:
return {c.label: c.score for c in scores}
def postprocess_mask_tokens(scores: list[dict[str, str | float]]) -> dict:
return {c["token_str"]: c["score"] for c in scores}
def postprocess_question_answering(answer: dict) -> tuple[str, dict]:
return answer["answer"], {answer["answer"]: answer["score"]}
def postprocess_visual_question_answering(scores: list[dict[str, str | float]]) -> dict:
return {c["answer"]: c["score"] for c in scores}
def zero_shot_classification_wrapper(client: InferenceClient):
def zero_shot_classification_inner(input: str, labels: str, multi_label: bool):
return client.zero_shot_classification(
input, labels.split(","), multi_label=multi_label
)
return zero_shot_classification_inner
def sentence_similarity_wrapper(client: InferenceClient):
def sentence_similarity_inner(input: str, sentences: str):
return client.sentence_similarity(input, sentences.split("\n"))
return sentence_similarity_inner
def text_generation_wrapper(client: InferenceClient):
def text_generation_inner(input: str):
return input + client.text_generation(input)
return text_generation_inner
def conversational_wrapper(client: InferenceClient):
def chat_fn(message, history):
if not history:
history = []
history.append({"role": "user", "content": message})
try:
out = ""
for chunk in client.chat_completion(messages=history, stream=True):
out += chunk.choices[0].delta.content or ""
yield out
except Exception as e:
handle_hf_error(e)
return chat_fn
def encode_to_base64(r: httpx.Response) -> str:
# Handles the different ways HF API returns the prediction
base64_repr = base64.b64encode(r.content).decode("utf-8")
data_prefix = ";base64,"
# Case 1: base64 representation already includes data prefix
if data_prefix in base64_repr:
return base64_repr
else:
content_type = r.headers.get("content-type")
# Case 2: the data prefix is a key in the response
if content_type == "application/json":
try:
data = r.json()[0]
content_type = data["content-type"]
base64_repr = data["blob"]
except KeyError as ke:
raise ValueError(
"Cannot determine content type returned by external API."
) from ke
# Case 3: the data prefix is included in the response headers
else:
pass
new_base64 = f"data:{content_type};base64,{base64_repr}"
return new_base64
def format_ner_list(input_string: str, ner_groups: list[dict[str, str | int]]):
if len(ner_groups) == 0:
return [(input_string, None)]
output = []
end = 0
prev_end = 0
for group in ner_groups:
entity, start, end = group["entity_group"], group["start"], group["end"]
output.append((input_string[prev_end:start], None))
output.append((input_string[start:end], entity))
prev_end = end
output.append((input_string[end:], None))
return output
def token_classification_wrapper(client: InferenceClient):
def token_classification_inner(input: str):
ner_list = client.token_classification(input)
return format_ner_list(input, ner_list) # type: ignore
return token_classification_inner
def object_detection_wrapper(client: InferenceClient):
def object_detection_inner(input: str):
annotations = client.object_detection(input)
formatted_annotations = [
(
(
a["box"]["xmin"],
a["box"]["ymin"],
a["box"]["xmax"],
a["box"]["ymax"],
),
a["label"],
)
for a in annotations
]
return (input, formatted_annotations)
return object_detection_inner
def chatbot_preprocess(text, state):
if not state:
return text, [], []
return (
text,
state["conversation"]["generated_responses"],
state["conversation"]["past_user_inputs"],
)
def chatbot_postprocess(response):
chatbot_history = list(
zip(
response["conversation"]["past_user_inputs"],
response["conversation"]["generated_responses"],
strict=False,
)
)
return chatbot_history, response
def tabular_wrapper(client: InferenceClient, pipeline: str):
# This wrapper is needed to handle an issue in the InfereneClient where the model name is not
# automatically loaded when using the tabular_classification and tabular_regression methods.
# See: https://github.com/huggingface/huggingface_hub/issues/2015
def tabular_inner(data):
if pipeline not in ("tabular_classification", "tabular_regression"):
raise TypeError(f"pipeline type {pipeline!r} not supported")
assert client.model # noqa: S101
if pipeline == "tabular_classification":
return client.tabular_classification(data, model=client.model)
else:
return client.tabular_regression(data, model=client.model)
return tabular_inner
##################
# Helper function for cleaning up an Interface loaded from HF Spaces
##################
def streamline_spaces_interface(config: dict) -> dict:
"""Streamlines the interface config dictionary to remove unnecessary keys."""
config["inputs"] = [
components.get_component_instance(component)
for component in config["input_components"]
]
config["outputs"] = [
components.get_component_instance(component)
for component in config["output_components"]
]
parameters = {
"article",
"description",
"flagging_options",
"inputs",
"outputs",
"title",
}
config = {k: config[k] for k in parameters}
return config
def handle_hf_error(e: Exception):
if "429" in str(e):
raise TooManyRequestsError() from e
elif "401" in str(e) or "You must provide an api_key" in str(e):
raise Error("Unauthorized, please make sure you are signed in.") from e
else:
raise Error(str(e)) from e
def create_endpoint_fn(
endpoint_path: str,
endpoint_method: str,
endpoint_operation: dict,
base_url: str,
):
# Get request body info for docstring generation
request_body = endpoint_operation.get("requestBody", {})
def endpoint_fn(*args):
url = f"{base_url.rstrip('/')}{endpoint_path}"
headers = {"Content-Type": "application/json"}
params = {}
body_data = {}
operation_params = endpoint_operation.get("parameters", [])
request_body = endpoint_operation.get("requestBody", {})
param_index = 0
for param in operation_params:
if param_index < len(args):
if param.get("in") == "query":
params[param["name"]] = args[param_index]
elif param.get("in") == "path":
url = url.replace(f"{{{param['name']}}}", str(args[param_index]))
param_index += 1
is_file_upload = False
if request_body and param_index < len(args):
content = request_body.get("content", {})
for content_type in content:
if content_type in ["application/octet-stream", "multipart/form-data"]:
is_file_upload = True
break
if request_body and param_index < len(args):
if is_file_upload:
file_data = args[param_index]
if file_data:
headers = {"Content-Type": "application/octet-stream"}
body_data = file_data
else:
body_data = b""
else:
body_data = json.loads(args[param_index])
try:
if endpoint_method.lower() == "get":
response = httpx.get(url, params=params, headers=headers)
elif endpoint_method.lower() == "post":
response = httpx.post(
url,
params=params,
content=body_data if is_file_upload else None,
json=body_data if not is_file_upload else None,
headers=headers,
)
elif endpoint_method.lower() == "put":
response = httpx.put(
url,
params=params,
content=body_data if is_file_upload else None,
json=body_data if not is_file_upload else None,
headers=headers,
)
elif endpoint_method.lower() == "patch":
response = httpx.patch(
url,
params=params,
content=body_data if is_file_upload else None,
json=body_data if not is_file_upload else None,
headers=headers,
)
elif endpoint_method.lower() == "delete":
response = httpx.delete(url, params=params, headers=headers)
else:
raise ValueError(f"Unsupported HTTP method: {endpoint_method}")
if response.status_code in [200, 201, 202, 204]:
return response.json()
else:
return {
"__status__": "error",
"status_code": response.status_code,
"message": response.text,
}
except Exception as e:
return f"Error: {str(e)}"
summary = endpoint_operation.get("summary", "")
description = endpoint_operation.get("description", "")
param_docs = []
param_names = []
for param in endpoint_operation.get("parameters", []):
param_name = param.get("name", "")
param_desc = param.get("description", "")
param_schema = param.get("schema", {})
param_enum = param_schema.get("enum", [])
if param_enum:
param_desc += f" (Choices: {', '.join(param_enum)})"
param_names.append(param_name)
param_docs.append(f" {param_name}: {param_desc}")
if request_body:
body_desc = request_body.get("description", "URL of file")
param_docs.append(f" request_body: {body_desc}")
param_names.append("request_body")
docstring_parts = []
if description or summary:
docstring_parts.append(description or summary)
if param_docs:
docstring_parts.append("Parameters:")
docstring_parts.extend(param_docs)
endpoint_fn.__doc__ = "\n".join(docstring_parts)
if param_names:
sig_params = []
for name in param_names:
sig_params.append(
inspect.Parameter(
name=name, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD
)
)
sig_params.append(
inspect.Parameter(name="args", kind=inspect.Parameter.VAR_POSITIONAL)
)
new_sig = inspect.Signature(parameters=sig_params)
endpoint_fn.__signature__ = new_sig # type: ignore
return endpoint_fn
def component_from_parameter_schema(param_info: dict) -> components.Component:
import gradio as gr
param_name = param_info.get("name")
param_description = param_info.get("description")
param_schema = param_info.get("schema", {})
param_type = param_schema.get("type")
enum_values = param_schema.get("enum")
default_value = param_schema.get("default")
if enum_values is not None:
component = gr.Dropdown(
choices=enum_values,
label=param_name,
value=default_value,
allow_custom_value=False,
info=param_description,
)
elif param_type in ("number", "integer"):
component = gr.Number(
label=param_name,
value=default_value,
info=param_description,
)
elif param_type == "boolean":
component = gr.Checkbox(
label=param_name,
value=default_value,
info=param_description,
)
elif param_type == "array":
component = gr.Textbox(
label=f"{param_name} (JSON array)",
value="[]",
info=param_description,
)
else:
component = gr.Textbox(
label=param_name,
value=default_value,
info=param_description,
)
return component
def resolve_schema_ref(schema: dict, spec: dict) -> dict:
"""Resolve schema references in OpenAPI spec."""
if "$ref" in schema:
ref_path = schema["$ref"]
if ref_path.startswith("#/components/schemas/"):
schema_name = ref_path.split("/")[-1]
return spec.get("components", {}).get("schemas", {}).get(schema_name, {})
elif ref_path.startswith("#/"):
path_parts = ref_path.split("/")[1:]
current = spec
for part in path_parts:
current = current.get(part, {})
return current
return schema
def component_from_request_body_schema(
request_body: dict, spec: dict
) -> components.Component | None:
"""Create a Gradio component from an OpenAPI request body schema."""
import gradio as gr
if not request_body:
return None
content = request_body.get("content", {})
description = request_body.get("description", "Request Body")
for content_type, content_schema in content.items():
if content_type in ["application/octet-stream", "multipart/form-data"]:
schema = resolve_schema_ref(content_schema.get("schema", {}), spec)
if schema.get("type") == "string" and schema.get("format") == "binary":
return gr.File(label="File")
json_content = content.get("application/json", {})
if not json_content:
for content_type, content_schema in content.items():
if content_type.startswith("application/"):
json_content = content_schema
break
if not json_content:
return None
schema = resolve_schema_ref(json_content.get("schema", {}), spec)
default_value = schema.get("example", {})
if not default_value and schema.get("type") == "object":
properties = schema.get("properties", {})
default_value = {}
for prop_name, prop_schema in properties.items():
prop_schema = resolve_schema_ref(prop_schema, spec)
prop_type = prop_schema.get("type")
if prop_type == "string":
default_value[prop_name] = prop_schema.get("example", "")
elif prop_type in ("number", "integer"):
default_value[prop_name] = prop_schema.get("example", 0)
elif prop_type == "boolean":
default_value[prop_name] = prop_schema.get("example", False)
elif prop_type == "array":
default_value[prop_name] = prop_schema.get("example", [])
elif prop_type == "object":
default_value[prop_name] = prop_schema.get("example", {})
component = gr.Textbox(
label="Request Body",
value=json.dumps(default_value, indent=2),
info=description,
)
return component
def method_box(method: str) -> str:
color_map = {
"GET": "#61affe",
"POST": "#49cc90",
"PUT": "#fca130",
"DELETE": "#f93e3e",
"PATCH": "#50e3c2",
}
color = color_map.get(method.upper(), "#999")
return (
f"<span style='"
f"display:inline-block;min-width:48px;padding:2px 10px;border-radius:4px;"
f"background:{color};color:white;font-weight:bold;font-family:monospace;"
f"margin-right:8px;text-align:center;border:2px solid {color};"
f"box-shadow:0 1px 2px rgba(0,0,0,0.08);'"
f">{method.upper()}</span>"
)