Spaces:
Sleeping
Sleeping
| """Utility function for gradio/external.py, designed for internal use.""" | |
| from __future__ import annotations | |
| import base64 | |
| import inspect | |
| import json | |
| import math | |
| import re | |
| import warnings | |
| import httpx | |
| import yaml | |
| from huggingface_hub import HfApi, ImageClassificationOutputElement, InferenceClient | |
| from gradio import components | |
| from gradio.exceptions import Error, TooManyRequestsError | |
def get_model_info(model_name: str, hf_token: str | None = None):
    """Fetch a model's pipeline tag and tags from the Hugging Face Hub.

    Parameters:
        model_name: Repo id of the model on the Hub (e.g. "gpt2").
        hf_token: Optional Hugging Face access token for the Hub API call
            (needed for private or gated models).
    Returns:
        Tuple of (pipeline_tag, tags) as reported by the Hub.
    """
    hf_api = HfApi(token=hf_token)
    # Informational log so users can see which repo is being resolved.
    print(f"Fetching model from: https://huggingface.co/{model_name}")
    model_info = hf_api.model_info(model_name)
    pipeline = model_info.pipeline_tag
    tags = model_info.tags
    return pipeline, tags
| ################## | |
| # Helper functions for processing tabular data | |
| ################## | |
def get_tabular_examples(model_name: str) -> dict[str, list[float]]:
    """Pull example tabular data from the model card's YAML front matter.

    Reads ``widget.structuredData`` from the README of ``model_name`` and
    normalizes NaN floats to the string "NaN" (the form accepted by
    Inference Endpoints). Raises ValueError when no example data is found.
    """
    response = httpx.get(
        f"https://huggingface.co/{model_name}/resolve/main/README.md"
    )
    example_data = {}
    if response.status_code != 200:
        warnings.warn(
            f"Cannot load examples from README for {model_name}", UserWarning
        )
    else:
        # Locate the `--- ... ---` YAML front-matter delimiters.
        match = re.search(
            "(?:^|[\r\n])---[\n\r]+([\\S\\s]*?)[\n\r]+---([\n\r]|$)",
            response.text,
        )
        if match is not None:
            # Parse only the first YAML document up to the closing `---`.
            front_matter = next(
                yaml.safe_load_all(response.text[: match.span()[-1]])
            )
            example_data = front_matter.get("widget", {}).get("structuredData", {})
    if not example_data:
        raise ValueError(
            f"No example data found in README.md of {model_name} - Cannot build gradio demo. "
            "See the README.md here: https://huggingface.co/scikit-learn/tabular-playground/blob/main/README.md "
            "for a reference on how to provide example data to your model."
        )
    # replace nan with string NaN for inference Endpoints
    for column in example_data.values():
        for idx, cell in enumerate(column):
            if isinstance(cell, float) and math.isnan(cell):
                column[idx] = "NaN"
    return example_data
def cols_to_rows(
    example_data: dict[str, list[float | str] | None],
) -> tuple[list[str], list[list[float]]]:
    """Transpose column-oriented example data into (headers, rows).

    Parameters:
        example_data: Mapping of column name -> list of cell values; a
            column may be None, which is treated as empty.
    Returns:
        Tuple of (headers, data) where ``data`` is row-major. Columns
        shorter than the longest one are padded with the string "NaN".
    """
    headers = list(example_data.keys())
    # default=0 handles empty example_data gracefully (returns ([], []))
    # instead of raising ValueError from max() on an empty sequence.
    n_rows = max(
        (len(example_data[header] or []) for header in headers), default=0
    )
    data = []
    for row_index in range(n_rows):
        row_data = []
        for header in headers:
            col = example_data[header] or []
            if row_index >= len(col):
                # Pad ragged columns so every row has one cell per header.
                row_data.append("NaN")
            else:
                row_data.append(col[row_index])
        data.append(row_data)
    return headers, data
def rows_to_cols(incoming_data: dict) -> dict[str, dict[str, dict[str, list[str]]]]:
    """Transpose row-major {"headers": ..., "data": ...} into the
    column-wise ``{"inputs": {"data": {...}}}`` payload expected by the
    tabular inference API; every cell is stringified."""
    columns = {
        header: [str(row[col_idx]) for row in incoming_data["data"]]
        for col_idx, header in enumerate(incoming_data["headers"])
    }
    return {"inputs": {"data": columns}}
| ################## | |
| # Helper functions for processing other kinds of data | |
| ################## | |
def postprocess_label(scores: list[ImageClassificationOutputElement]) -> dict:
    """Convert image-classification output elements to a {label: score} dict."""
    result = {}
    for element in scores:
        result[element.label] = element.score
    return result
def postprocess_mask_tokens(scores: list[dict[str, str | float]]) -> dict:
    """Map each fill-mask candidate's token string to its score."""
    output = {}
    for entry in scores:
        output[entry["token_str"]] = entry["score"]
    return output
def postprocess_question_answering(answer: dict) -> tuple[str, dict]:
    """Split a QA result into the answer text and an {answer: score} dict."""
    text = answer["answer"]
    return text, {text: answer["score"]}
def postprocess_visual_question_answering(scores: list[dict[str, str | float]]) -> dict:
    """Map each VQA candidate answer to its confidence score."""
    output = {}
    for candidate in scores:
        output[candidate["answer"]] = candidate["score"]
    return output
def zero_shot_classification_wrapper(client: InferenceClient):
    """Return a fn that splits the comma-separated label string before
    delegating to the client's zero-shot classification endpoint."""
    def zero_shot_classification_inner(input: str, labels: str, multi_label: bool):
        candidate_labels = labels.split(",")
        return client.zero_shot_classification(
            input, candidate_labels, multi_label=multi_label
        )
    return zero_shot_classification_inner
def sentence_similarity_wrapper(client: InferenceClient):
    """Return a fn that splits newline-separated sentences before delegating
    to the client's sentence-similarity endpoint."""
    def sentence_similarity_inner(input: str, sentences: str):
        other_sentences = sentences.split("\n")
        return client.sentence_similarity(input, other_sentences)
    return sentence_similarity_inner
def text_generation_wrapper(client: InferenceClient):
    """Return a fn that prepends the prompt to the generated continuation."""
    def text_generation_inner(input: str):
        continuation = client.text_generation(input)
        return input + continuation
    return text_generation_inner
def conversational_wrapper(client: InferenceClient):
    """Return a streaming chat fn suitable for a gradio chat interface.

    The inner generator appends the user turn to the message history,
    streams the chat completion, and yields the accumulated reply after
    each chunk. API failures are translated via handle_hf_error.
    """
    def chat_fn(message, history):
        history = history or []
        history.append({"role": "user", "content": message})
        try:
            accumulated = ""
            stream = client.chat_completion(messages=history, stream=True)
            for chunk in stream:
                piece = chunk.choices[0].delta.content
                accumulated += piece or ""
                yield accumulated
        except Exception as e:
            handle_hf_error(e)
    return chat_fn
def encode_to_base64(r: httpx.Response) -> str:
    """Encode an HF API response body as a base64 data-URI string.

    Handles the different ways the HF API returns a prediction: a body
    that already decodes to a data URI, a JSON body whose first element
    carries "content-type" and "blob" keys, and raw bytes typed by the
    content-type response header.
    """
    base64_repr = base64.b64encode(r.content).decode("utf-8")
    data_prefix = ";base64,"
    # Case 1: base64 representation already includes data prefix
    if data_prefix in base64_repr:
        return base64_repr
    content_type = r.headers.get("content-type")
    if content_type == "application/json":
        # Case 2: the data prefix is a key in the response
        try:
            payload = r.json()[0]
            content_type = payload["content-type"]
            base64_repr = payload["blob"]
        except KeyError as ke:
            raise ValueError(
                "Cannot determine content type returned by external API."
            ) from ke
    # Case 3: the data prefix comes from the response headers
    return f"data:{content_type};base64,{base64_repr}"
def format_ner_list(input_string: str, ner_groups: list[dict[str, str | int]]):
    """Interleave entity and non-entity spans for highlighted-text output.

    Returns a list of (substring, label-or-None) tuples covering the whole
    input string; unlabeled gaps between entities carry a None label.
    """
    if not ner_groups:
        return [(input_string, None)]
    spans = []
    cursor = 0
    for group in ner_groups:
        label, start, stop = group["entity_group"], group["start"], group["end"]
        spans.append((input_string[cursor:start], None))
        spans.append((input_string[start:stop], label))
        cursor = stop
    spans.append((input_string[cursor:], None))
    return spans
def token_classification_wrapper(client: InferenceClient):
    """Return a fn that runs NER and formats the spans for highlighting."""
    def token_classification_inner(input: str):
        entities = client.token_classification(input)
        return format_ner_list(input, entities)  # type: ignore
    return token_classification_inner
def object_detection_wrapper(client: InferenceClient):
    """Return a fn producing (image, annotations) for annotated-image output."""
    def object_detection_inner(input: str):
        detections = client.object_detection(input)
        formatted_annotations = []
        for det in detections:
            box = det["box"]
            coords = (box["xmin"], box["ymin"], box["xmax"], box["ymax"])
            formatted_annotations.append((coords, det["label"]))
        return (input, formatted_annotations)
    return object_detection_inner
def chatbot_preprocess(text, state):
    """Unpack stored conversation state into (text, generated, past) lists."""
    if not state:
        return text, [], []
    conversation = state["conversation"]
    return (
        text,
        conversation["generated_responses"],
        conversation["past_user_inputs"],
    )
| def chatbot_postprocess(response): | |
| chatbot_history = list( | |
| zip( | |
| response["conversation"]["past_user_inputs"], | |
| response["conversation"]["generated_responses"], | |
| strict=False, | |
| ) | |
| ) | |
| return chatbot_history, response | |
def tabular_wrapper(client: InferenceClient, pipeline: str):
    """Wrap tabular inference so the model name is passed explicitly.

    This wrapper is needed to handle an issue in the InferenceClient where
    the model name is not automatically loaded when using the
    tabular_classification and tabular_regression methods.
    See: https://github.com/huggingface/huggingface_hub/issues/2015

    Parameters:
        client: InferenceClient already bound to a model.
        pipeline: Either "tabular_classification" or "tabular_regression".
    """
    def tabular_inner(data):
        if pipeline not in ("tabular_classification", "tabular_regression"):
            raise TypeError(f"pipeline type {pipeline!r} not supported")
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would silently pass model=None downstream.
        if not client.model:
            raise ValueError(
                "InferenceClient must have a model set for tabular inference"
            )
        if pipeline == "tabular_classification":
            return client.tabular_classification(data, model=client.model)
        else:
            return client.tabular_regression(data, model=client.model)
    return tabular_inner
| ################## | |
| # Helper function for cleaning up an Interface loaded from HF Spaces | |
| ################## | |
def streamline_spaces_interface(config: dict) -> dict:
    """Streamlines the interface config dictionary to remove unnecessary keys."""
    config["inputs"] = [
        components.get_component_instance(c) for c in config["input_components"]
    ]
    config["outputs"] = [
        components.get_component_instance(c) for c in config["output_components"]
    ]
    # Keep only the keys the Interface constructor actually consumes.
    keep = {
        "article",
        "description",
        "flagging_options",
        "inputs",
        "outputs",
        "title",
    }
    return {key: config[key] for key in keep}
def handle_hf_error(e: Exception):
    """Translate raw HF API errors into user-facing gradio errors.

    Raises TooManyRequestsError on rate limiting (429), a sign-in Error on
    auth failures (401 / missing api_key), and a generic Error otherwise.
    """
    message = str(e)
    if "429" in message:
        raise TooManyRequestsError() from e
    if "401" in message or "You must provide an api_key" in message:
        raise Error("Unauthorized, please make sure you are signed in.") from e
    raise Error(str(e)) from e
def create_endpoint_fn(
    endpoint_path: str,
    endpoint_method: str,
    endpoint_operation: dict,
    base_url: str,
):
    """Build a callable for one OpenAPI endpoint.

    The returned function maps its positional args onto the operation's
    query/path parameters (plus an optional trailing request-body arg),
    issues the HTTP request, and carries a generated docstring and a named
    signature derived from the OpenAPI operation metadata.
    """
    # Get request body info for docstring generation
    request_body = endpoint_operation.get("requestBody", {})

    def endpoint_fn(*args):
        url = f"{base_url.rstrip('/')}{endpoint_path}"
        headers = {"Content-Type": "application/json"}
        params = {}
        body_data = {}
        operation_params = endpoint_operation.get("parameters", [])
        # NOTE(review): intentionally shadows the outer `request_body` with
        # the same lookup so the closure does not depend on the outer name.
        request_body = endpoint_operation.get("requestBody", {})
        param_index = 0
        for param in operation_params:
            if param_index < len(args):
                if param.get("in") == "query":
                    params[param["name"]] = args[param_index]
                elif param.get("in") == "path":
                    # Substitute {name} placeholders in the URL template.
                    url = url.replace(f"{{{param['name']}}}", str(args[param_index]))
                # NOTE(review): the index advances even for params that are
                # neither query nor path (e.g. header params) — confirm intended.
                param_index += 1
        # A body is a file upload when the spec declares a binary content type.
        is_file_upload = False
        if request_body and param_index < len(args):
            content = request_body.get("content", {})
            for content_type in content:
                if content_type in ["application/octet-stream", "multipart/form-data"]:
                    is_file_upload = True
                    break
        if request_body and param_index < len(args):
            if is_file_upload:
                file_data = args[param_index]
                if file_data:
                    # Raw bytes upload: replace the default JSON content type.
                    headers = {"Content-Type": "application/octet-stream"}
                    body_data = file_data
                else:
                    body_data = b""
            else:
                # JSON body arrives as a string (e.g. from a textbox input).
                body_data = json.loads(args[param_index])
        try:
            if endpoint_method.lower() == "get":
                response = httpx.get(url, params=params, headers=headers)
            elif endpoint_method.lower() == "post":
                response = httpx.post(
                    url,
                    params=params,
                    content=body_data if is_file_upload else None,
                    json=body_data if not is_file_upload else None,
                    headers=headers,
                )
            elif endpoint_method.lower() == "put":
                response = httpx.put(
                    url,
                    params=params,
                    content=body_data if is_file_upload else None,
                    json=body_data if not is_file_upload else None,
                    headers=headers,
                )
            elif endpoint_method.lower() == "patch":
                response = httpx.patch(
                    url,
                    params=params,
                    content=body_data if is_file_upload else None,
                    json=body_data if not is_file_upload else None,
                    headers=headers,
                )
            elif endpoint_method.lower() == "delete":
                response = httpx.delete(url, params=params, headers=headers)
            else:
                raise ValueError(f"Unsupported HTTP method: {endpoint_method}")
            if response.status_code in [200, 201, 202, 204]:
                return response.json()
            else:
                # Surface HTTP errors as a structured dict rather than raising.
                return {
                    "__status__": "error",
                    "status_code": response.status_code,
                    "message": response.text,
                }
        except Exception as e:
            # Transport/parsing failures are reported as a plain string.
            return f"Error: {str(e)}"

    # Build a human-readable docstring from the OpenAPI operation metadata.
    summary = endpoint_operation.get("summary", "")
    description = endpoint_operation.get("description", "")
    param_docs = []
    param_names = []
    for param in endpoint_operation.get("parameters", []):
        param_name = param.get("name", "")
        param_desc = param.get("description", "")
        param_schema = param.get("schema", {})
        param_enum = param_schema.get("enum", [])
        if param_enum:
            # Advertise the allowed enum values in the parameter docs.
            param_desc += f" (Choices: {', '.join(param_enum)})"
        param_names.append(param_name)
        param_docs.append(f" {param_name}: {param_desc}")
    if request_body:
        body_desc = request_body.get("description", "URL of file")
        param_docs.append(f" request_body: {body_desc}")
        param_names.append("request_body")
    docstring_parts = []
    if description or summary:
        docstring_parts.append(description or summary)
    if param_docs:
        docstring_parts.append("Parameters:")
        docstring_parts.extend(param_docs)
    endpoint_fn.__doc__ = "\n".join(docstring_parts)
    # Give the generic *args function a named signature so inspect-based
    # tooling can render one input per documented parameter.
    if param_names:
        sig_params = []
        for name in param_names:
            sig_params.append(
                inspect.Parameter(
                    name=name, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD
                )
            )
        # Trailing *args keeps the signature permissive for extra inputs.
        sig_params.append(
            inspect.Parameter(name="args", kind=inspect.Parameter.VAR_POSITIONAL)
        )
        new_sig = inspect.Signature(parameters=sig_params)
        endpoint_fn.__signature__ = new_sig  # type: ignore
    return endpoint_fn
def component_from_parameter_schema(param_info: dict) -> components.Component:
    """Pick a Gradio input component matching an OpenAPI parameter schema.

    Enum schemas get a Dropdown; number/integer a Number; boolean a
    Checkbox; array a JSON Textbox; anything else a plain Textbox.
    """
    import gradio as gr

    name = param_info.get("name")
    info_text = param_info.get("description")
    schema = param_info.get("schema", {})
    schema_type = schema.get("type")
    choices = schema.get("enum")
    default = schema.get("default")
    if choices is not None:
        return gr.Dropdown(
            choices=choices,
            label=name,
            value=default,
            allow_custom_value=False,
            info=info_text,
        )
    if schema_type in ("number", "integer"):
        return gr.Number(
            label=name,
            value=default,
            info=info_text,
        )
    if schema_type == "boolean":
        return gr.Checkbox(
            label=name,
            value=default,
            info=info_text,
        )
    if schema_type == "array":
        return gr.Textbox(
            label=f"{name} (JSON array)",
            value="[]",
            info=info_text,
        )
    return gr.Textbox(
        label=name,
        value=default,
        info=info_text,
    )
def resolve_schema_ref(schema: dict, spec: dict) -> dict:
    """Resolve schema references in OpenAPI spec.

    Follows a "$ref" pointer into ``spec`` and returns the referenced
    schema ({} when the path is missing); non-ref schemas pass through
    unchanged.
    """
    if "$ref" not in schema:
        return schema
    ref = schema["$ref"]
    if ref.startswith("#/components/schemas/"):
        # Fast path for the common components/schemas location.
        name = ref.split("/")[-1]
        return spec.get("components", {}).get("schemas", {}).get(name, {})
    if ref.startswith("#/"):
        # Generic walk for any other in-document reference.
        node = spec
        for segment in ref.split("/")[1:]:
            node = node.get(segment, {})
        return node
    return schema
def component_from_request_body_schema(
    request_body: dict, spec: dict
) -> components.Component | None:
    """Create a Gradio component from an OpenAPI request body schema.

    Returns a File component for binary bodies, a Textbox pre-filled with a
    JSON example for JSON-like bodies, or None when no usable content
    declaration exists.
    """
    import gradio as gr

    if not request_body:
        return None
    content = request_body.get("content", {})
    description = request_body.get("description", "Request Body")
    # Binary uploads (octet-stream / multipart) become a File input.
    for content_type, content_schema in content.items():
        if content_type in ["application/octet-stream", "multipart/form-data"]:
            schema = resolve_schema_ref(content_schema.get("schema", {}), spec)
            if schema.get("type") == "string" and schema.get("format") == "binary":
                return gr.File(label="File")
    json_content = content.get("application/json", {})
    if not json_content:
        # Fall back to the first application/* content type declared.
        for content_type, content_schema in content.items():
            if content_type.startswith("application/"):
                json_content = content_schema
                break
    if not json_content:
        return None
    schema = resolve_schema_ref(json_content.get("schema", {}), spec)
    # Seed the textbox with the schema's example, or synthesize one from
    # per-property examples / type-appropriate zero values.
    default_value = schema.get("example", {})
    if not default_value and schema.get("type") == "object":
        properties = schema.get("properties", {})
        default_value = {}
        for prop_name, prop_schema in properties.items():
            prop_schema = resolve_schema_ref(prop_schema, spec)
            prop_type = prop_schema.get("type")
            if prop_type == "string":
                default_value[prop_name] = prop_schema.get("example", "")
            elif prop_type in ("number", "integer"):
                default_value[prop_name] = prop_schema.get("example", 0)
            elif prop_type == "boolean":
                default_value[prop_name] = prop_schema.get("example", False)
            elif prop_type == "array":
                default_value[prop_name] = prop_schema.get("example", [])
            elif prop_type == "object":
                default_value[prop_name] = prop_schema.get("example", {})
    component = gr.Textbox(
        label="Request Body",
        value=json.dumps(default_value, indent=2),
        info=description,
    )
    return component
def method_box(method: str) -> str:
    """Render an HTTP method as a Swagger-style colored HTML badge."""
    colors = {
        "GET": "#61affe",
        "POST": "#49cc90",
        "PUT": "#fca130",
        "DELETE": "#f93e3e",
        "PATCH": "#50e3c2",
    }
    verb = method.upper()
    color = colors.get(verb, "#999")
    style = (
        f"display:inline-block;min-width:48px;padding:2px 10px;border-radius:4px;"
        f"background:{color};color:white;font-weight:bold;font-family:monospace;"
        f"margin-right:8px;text-align:center;border:2px solid {color};"
        f"box-shadow:0 1px 2px rgba(0,0,0,0.08);"
    )
    return f"<span style='{style}'>{verb}</span>"