| import gradio as gr |
| from urllib.parse import urlparse |
| import requests |
| import time |
| from PIL import Image |
| import base64 |
| import io |
| import uuid |
| import os |
|
|
|
|
def extract_property_info(prop):
    """Flatten a JSON-schema property that may be split across combinators.

    Sub-schemas listed under ``allOf``, ``anyOf`` and ``oneOf`` are merged
    in that order, later keys overwriting earlier ones, into a single dict.
    If the property has no combinator (or only empty ones), a shallow copy
    of the property without the combinator keys is used instead. The
    property's own ``description`` and ``default`` always take precedence
    over values inherited from the merged sub-schemas.

    The input ``prop`` is never modified, so the same schema can safely be
    flattened more than once.

    Args:
        prop: A JSON-schema property dict.

    Returns:
        A new dict describing the flattened property.
    """
    merge_keywords = ("allOf", "anyOf", "oneOf")

    combined_prop = {}
    for keyword in merge_keywords:
        for subprop in prop.get(keyword, ()):
            combined_prop.update(subprop)

    if not combined_prop:
        combined_prop = {
            key: value for key, value in prop.items() if key not in merge_keywords
        }

    # The outer property's own metadata is more specific than whatever the
    # referenced sub-schemas carry, so it wins.
    for key in ("description", "default"):
        if key in prop:
            combined_prop[key] = prop[key]

    return combined_prop
|
|
|
|
def detect_file_type(filename):
    """Guess which kind of media a file name, path or URL points to.

    The decision is based on the extension of the last path segment. A URL
    query string or fragment is ignored, so signed URLs such as
    ``https://host/out.png?token=...`` are still recognised. ``data:`` URIs
    are classified by the top-level part of their MIME type.

    Args:
        filename: A path/URL string, or a list of them.

    Returns:
        ``"audio"``, ``"image"`` or ``"video"`` for recognised media,
        ``"string"`` for any other string, ``"list"`` for a list, and
        ``None`` for any other type.
    """
    audio_extensions = {".mp3", ".wav", ".flac", ".aac", ".ogg", ".m4a"}
    image_extensions = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".svg", ".webp"}
    video_extensions = {".mp4", ".mov", ".wmv", ".flv", ".avi", ".avchd", ".mkv", ".webm"}

    if isinstance(filename, list):
        return "list"
    if not isinstance(filename, str):
        return None

    if filename.startswith("data:"):
        # e.g. "data:image/png;base64,...." -> "image"
        media_type = filename[len("data:"):].split("/", 1)[0].lower()
        return media_type if media_type in ("audio", "image", "video") else "string"

    try:
        path = urlparse(filename).path
    except ValueError:
        # urlparse rejects some malformed inputs (e.g. an unbalanced "[" in the host).
        path = filename

    basename = path.rsplit("/", 1)[-1]
    dot = basename.rfind(".")
    extension = basename[dot:].lower() if dot != -1 else ""

    if extension in audio_extensions:
        return "audio"
    if extension in image_extensions:
        return "image"
    if extension in video_extensions:
        return "video"
    return "string"
|
|
|
|
| def build_gradio_inputs(ordered_input_schema, example_inputs=None): |
| inputs = [] |
| input_field_strings = """inputs = []\n""" |
| names = [] |
| for index, (name, prop) in enumerate(ordered_input_schema): |
| names.append(name) |
| prop = extract_property_info(prop) |
| if "enum" in prop: |
| input_field = gr.Dropdown( |
| choices=prop["enum"], |
| label=prop.get("title"), |
| info=prop.get("description"), |
| value=prop.get("default"), |
| ) |
| input_field_string = f"""inputs.append(gr.Dropdown( |
| choices={prop["enum"]}, label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}, value="{prop.get("default")}" |
| ))\n""" |
| elif prop["type"] == "integer": |
| if prop.get("minimum") and prop.get("maximum"): |
| input_field = gr.Slider( |
| label=prop.get("title"), |
| info=prop.get("description"), |
| value=prop.get("default"), |
| minimum=prop.get("minimum"), |
| maximum=prop.get("maximum"), |
| step=1, |
| ) |
| input_field_string = f"""inputs.append(gr.Slider( |
| label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}, value={prop.get("default")}, |
| minimum={prop.get("minimum")}, maximum={prop.get("maximum")}, step=1, |
| ))\n""" |
| else: |
| input_field = gr.Number( |
| label=prop.get("title"), |
| info=prop.get("description"), |
| value=prop.get("default"), |
| ) |
| input_field_string = f"""inputs.append(gr.Number( |
| label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}, value={prop.get("default")} |
| ))\n""" |
| elif prop["type"] == "number": |
| if prop.get("minimum") and prop.get("maximum"): |
| input_field = gr.Slider( |
| label=prop.get("title"), |
| info=prop.get("description"), |
| value=prop.get("default"), |
| minimum=prop.get("minimum"), |
| maximum=prop.get("maximum"), |
| ) |
| input_field_string = f"""inputs.append(gr.Slider( |
| label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}, value={prop.get("default")}, |
| minimum={prop.get("minimum")}, maximum={prop.get("maximum")} |
| ))\n""" |
| else: |
| input_field = gr.Number( |
| label=prop.get("title"), |
| info=prop.get("description"), |
| value=prop.get("default"), |
| ) |
| input_field_string = f"""inputs.append(gr.Number( |
| label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}, value={prop.get("default")} |
| ))\n""" |
| elif prop["type"] == "boolean": |
| input_field = gr.Checkbox( |
| label=prop.get("title"), |
| info=prop.get("description"), |
| value=prop.get("default"), |
| ) |
| input_field_string = f"""inputs.append(gr.Checkbox( |
| label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}, value={prop.get("default")} |
| ))\n""" |
| elif ( |
| prop["type"] == "string" and prop.get("format") == "uri" and example_inputs |
| ): |
| input_type_example = example_inputs.get(name, None) |
| if input_type_example: |
| input_type = detect_file_type(input_type_example) |
| else: |
| input_type = None |
| if input_type == "image": |
| input_field = gr.Image(label=prop.get("title"), type="filepath") |
| input_field_string = f"""inputs.append(gr.Image( |
| label="{prop.get("title")}", type="filepath" |
| ))\n""" |
| elif input_type == "audio": |
| input_field = gr.Audio(label=prop.get("title"), type="filepath") |
| input_field_string = f"""inputs.append(gr.Audio( |
| label="{prop.get("title")}", type="filepath" |
| ))\n""" |
| elif input_type == "video": |
| input_field = gr.Video(label=prop.get("title")) |
| input_field_string = f"""inputs.append(gr.Video( |
| label="{prop.get("title")}" |
| ))\n""" |
| else: |
| input_field = gr.File(label=prop.get("title")) |
| input_field_string = f"""inputs.append(gr.File( |
| label="{prop.get("title")}" |
| ))\n""" |
| else: |
| input_field = gr.Textbox( |
| label=prop.get("title"), |
| info=prop.get("description"), |
| ) |
| input_field_string = f"""inputs.append(gr.Textbox( |
| label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'} |
| ))\n""" |
| inputs.append(input_field) |
| input_field_strings += f"{input_field_string}\n" |
|
|
| input_field_strings += f"names = {names}\n" |
|
|
| return inputs, input_field_strings, names |
|
|
|
|
| def build_gradio_outputs_replicate(output_types): |
| outputs = [] |
| output_field_strings = """outputs = []\n""" |
| if output_types: |
| for output in output_types: |
| if output == "image": |
| output_field = gr.Image() |
| output_field_string = "outputs.append(gr.Image())" |
| elif output == "audio": |
| output_field = gr.Audio(type="filepath") |
| output_field_string = "outputs.append(gr.Audio(type='filepath'))" |
| elif output == "video": |
| output_field = gr.Video() |
| output_field_string = "outputs.append(gr.Video())" |
| elif output == "string": |
| output_field = gr.Textbox() |
| output_field_string = "outputs.append(gr.Textbox())" |
| elif output == "json": |
| output_field = gr.JSON() |
| output_field_string = "outputs.append(gr.JSON())" |
| elif output == "list": |
| output_field = gr.JSON() |
| output_field_string = "outputs.append(gr.JSON())" |
| outputs.append(output_field) |
| output_field_strings += f"{output_field_string}\n" |
| else: |
| output_field = gr.JSON() |
| output_field_string = "outputs.append(gr.JSON())" |
| outputs.append(output_field) |
|
|
| return outputs, output_field_strings |
|
|
|
|
def build_gradio_outputs_cog():
    """Placeholder for building output components from a Cog model.

    Not implemented yet: it takes no arguments, does nothing and returns
    ``None``.
    """
    return None
|
|
|
|
def process_outputs(outputs):
    """Convert parsed prediction outputs into values Gradio components accept.

    ``data:`` URIs are decoded:
    - images become PIL images;
    - audio and video are written to uniquely named files in the current
      working directory and replaced by the file name.

    Every other value is passed through unchanged. Empty results (``None``,
    ``""`` and empty lists/tuples/dicts) are dropped, but falsy scalars such
    as ``0``, ``0.0`` or ``False`` are kept because they are real outputs.

    Args:
        outputs: Iterable of output values, e.g. the result of ``parse_outputs``.

    Returns:
        A new list of processed values.
    """
    import mimetypes

    def _decode_data_uri(data_uri):
        # "data:<mime>[;base64],<payload>" -> ("<mime>", decoded bytes)
        header, _, encoded = data_uri.partition(",")
        mime_type = header[len("data:"):].split(";", 1)[0].strip().lower()
        return mime_type, base64.b64decode(encoded)

    def _save_media(data_uri, fallback_extension):
        # Name the file after the declared MIME type when it is known, so that
        # e.g. an mp3 is not saved with a ".wav" extension.
        mime_type, payload = _decode_data_uri(data_uri)
        extension = mimetypes.guess_extension(mime_type) or fallback_extension
        filename = f"{uuid.uuid4()}{extension}"
        with open(filename, "wb") as media_file:
            media_file.write(payload)
        return filename

    output_values = []
    for output in outputs:
        # Skip only genuinely empty values; dropping 0/False would silently
        # shift every later output into the wrong component.
        if output is None or (isinstance(output, (str, list, tuple, dict)) and not output):
            continue
        if not isinstance(output, str):
            output_values.append(output)
        elif output.startswith("data:image"):
            _, image_bytes = _decode_data_uri(output)
            output_values.append(Image.open(io.BytesIO(image_bytes)))
        elif output.startswith("data:audio"):
            output_values.append(_save_media(output, ".wav"))
        elif output.startswith("data:video"):
            output_values.append(_save_media(output, ".mp4"))
        else:
            output_values.append(output)
    return output_values
|
|
|
|
def parse_outputs(data):
    """Flatten a prediction's JSON output into a list of values.

    - Lists are flattened recursively into the result.
    - A dict contributes its values in insertion order. Each value is
      flattened recursively, but a value that is itself a list is kept as
      one nested (already flattened) list entry rather than being spliced in.
    - Any other value becomes a single-element list.
    """
    if isinstance(data, list):
        return [leaf for element in data for leaf in parse_outputs(element)]

    if not isinstance(data, dict):
        return [data]

    flattened = []
    for child in data.values():
        parsed_child = parse_outputs(child)
        if isinstance(child, list):
            flattened.append(parsed_child)
        else:
            flattened.extend(parsed_child)
    return flattened
|
|
|
|
def create_dynamic_gradio_app(
    inputs,
    outputs,
    api_url,
    api_id=None,
    replicate_token=None,
    title="",
    model_description="",
    names=None,
    local_base=False,
    hostname="0.0.0.0",
):
    """Create a live ``gr.Interface`` that forwards submissions to a prediction API.

    Each submission becomes a JSON payload ``{"input": {name: value}}``, plus
    ``"version"`` when ``api_id`` is set, which is POSTed to ``api_url``.
    - A ``201`` response is polled via its ``urls.get`` link until its status
      is ``succeeded``, ``failed`` or ``canceled``.
    - A ``200`` response is used as the finished prediction.

    Args:
        inputs: Gradio input components, in the same order as ``names``.
        outputs: Gradio output components (must be non-empty).
        api_url: Prediction endpoint to POST to.
        api_id: Optional model version, sent as ``payload["version"]``.
        replicate_token: Optional token, sent as ``Authorization: Token ...``.
        title: Interface title.
        model_description: Interface description.
        names: API input names, one per component in ``inputs``.
        local_base: If true, uploaded files are exposed to the API through
            ``http://{hostname}:7860`` instead of the incoming request's URL.
        hostname: Host used when ``local_base`` is true.

    Returns:
        The (not yet launched) ``gr.Interface``.
    """
    names = list(names) if names is not None else []
    expected_outputs = len(outputs)
    headers = {"Content-Type": "application/json"}
    if replicate_token:
        headers["Authorization"] = f"Token {replicate_token}"

    def predict(request: gr.Request, *args, progress=gr.Progress(track_tqdm=True)):
        payload = {"input": {}}
        if api_id:
            payload["version"] = api_id

        if local_base:
            base_url = f"http://{hostname}:7860"
        else:
            parsed_url = urlparse(str(request.url))
            base_url = parsed_url.scheme + "://" + parsed_url.netloc

        for key, value in zip(names, args):
            # Uploaded files arrive as local paths; expose them through Gradio's
            # file route. Only existing files are rewritten so that ordinary
            # text values (e.g. ".") are never mistaken for paths.
            if isinstance(value, str) and os.path.isfile(value):
                value = f"{base_url}/file=" + value
            if value is not None and value != "":
                payload["input"][key] = value
        # Headers are deliberately not printed: they carry the API token.
        print(payload)

        response = requests.post(api_url, headers=headers, json=payload)
        if response.status_code == 201:
            follow_up_url = response.json()["urls"]["get"]
            response = requests.get(follow_up_url, headers=headers)
            # Stop on every terminal status; without "canceled" here the loop
            # would poll forever. A non-200 poll falls through to the error below.
            while response.status_code == 200:
                status = response.json()["status"]
                if status == "succeeded":
                    break
                if status in ("failed", "canceled"):
                    raise gr.Error("The submission failed!")
                time.sleep(1)
                response = requests.get(follow_up_url, headers=headers)

        if response.status_code == 200:
            json_response = response.json()
            # A JSON output component shows the raw output as-is.
            if outputs[0].get_config()["name"] == "json":
                return json_response["output"]
            processed_outputs = process_outputs(parse_outputs(json_response["output"]))
            missing = expected_outputs - len(processed_outputs)
            if missing > 0:
                # Fewer results than components: hide the unused components.
                processed_outputs.extend([gr.update(visible=False)] * missing)
            else:
                # More results than components: keep only what can be shown.
                processed_outputs = processed_outputs[:expected_outputs]
            return (
                tuple(processed_outputs)
                if len(processed_outputs) > 1
                else processed_outputs[0]
            )

        if response.status_code == 409:
            raise gr.Error("Sorry, the Cog image is still processing. Try again in a bit.")
        raise gr.Error(f"The submission failed! Error: {response.status_code}")

    app = gr.Interface(
        fn=predict,
        inputs=inputs,
        outputs=outputs,
        title=title,
        description=model_description,
        allow_flagging="never",
    )
    return app
|
|
|
|
def create_gradio_app_script(
    inputs_string,
    outputs_string,
    api_url,
    api_id=None,
    replicate_token=None,
    title="",
    model_description="",
    local_base=False,
    hostname="0.0.0.0",
):
    """Generate the source of a standalone Gradio app for a prediction API.

    The generated script defines a ``predict`` function with the same request
    and polling logic as ``create_dynamic_gradio_app`` and launches a
    ``gr.Interface`` with ``share=True``. Every user-supplied value is
    embedded with ``repr``, so quotes or newlines in e.g. the description
    cannot break the generated code.

    Args:
        inputs_string: Code that defines ``inputs`` and ``names``
            (as produced by ``build_gradio_inputs``).
        outputs_string: Code that defines ``outputs``.
        api_url: Prediction endpoint the script POSTs to.
        api_id: Optional model version, sent as ``payload["version"]``.
        replicate_token: Optional token, sent as ``Authorization: Token ...``.
        title: Interface title.
        model_description: Interface description.
        local_base: If true, the script exposes uploaded files through
            ``http://{hostname}:7860`` instead of the incoming request's URL.
        hostname: Host used when ``local_base`` is true.

    Returns:
        The generated script as a string.
    """
    headers = {"Content-Type": "application/json"}
    if replicate_token:
        # NOTE(review): the token ends up verbatim in the generated source;
        # anyone who can read the script can use it.
        headers["Authorization"] = f"Token {replicate_token}"

    if local_base:
        local_url = f"http://{hostname}:7860"
        base_url_lines = [f"    base_url = {local_url!r}"]
    else:
        base_url_lines = [
            "    parsed_url = urlparse(str(request.url))",
            '    base_url = parsed_url.scheme + "://" + parsed_url.netloc',
        ]

    version_lines = [f'    payload["version"] = {api_id!r}'] if api_id else []

    header_lines = [
        "import gradio as gr",
        "from urllib.parse import urlparse",
        "import requests",
        "import time",
        "import os",
        "",
        "from utils.gradio_helpers import parse_outputs, process_outputs",
        "",
    ]

    # Plain (non-f) strings below are emitted verbatim, braces included.
    predict_lines = [
        "expected_outputs = len(outputs)",
        "",
        "",
        "def predict(request: gr.Request, *args, progress=gr.Progress(track_tqdm=True)):",
        f"    headers = {headers!r}",
        '    payload = {"input": {}}',
        *version_lines,
        *base_url_lines,
        "    for key, value in zip(names, args):",
        "        if isinstance(value, str) and os.path.isfile(value):",
        '            value = f"{base_url}/file=" + value',
        '        if value is not None and value != "":',
        '            payload["input"][key] = value',
        "",
        f"    response = requests.post({api_url!r}, headers=headers, json=payload)",
        "    if response.status_code == 201:",
        '        follow_up_url = response.json()["urls"]["get"]',
        "        response = requests.get(follow_up_url, headers=headers)",
        "        while response.status_code == 200:",
        '            status = response.json()["status"]',
        '            if status == "succeeded":',
        "                break",
        '            if status in ("failed", "canceled"):',
        '                raise gr.Error("The submission failed!")',
        "            time.sleep(1)",
        "            response = requests.get(follow_up_url, headers=headers)",
        "    if response.status_code == 200:",
        "        json_response = response.json()",
        "        # If the output component is JSON return the entire output response",
        '        if outputs[0].get_config()["name"] == "json":',
        '            return json_response["output"]',
        '        predict_outputs = parse_outputs(json_response["output"])',
        "        processed_outputs = process_outputs(predict_outputs)",
        "        difference_outputs = expected_outputs - len(processed_outputs)",
        "        # If less outputs than expected, hide the extra ones",
        "        if difference_outputs > 0:",
        "            extra_outputs = [gr.update(visible=False)] * difference_outputs",
        "            processed_outputs.extend(extra_outputs)",
        "        # If more outputs than expected, cap the outputs to the expected number",
        "        elif difference_outputs < 0:",
        "            processed_outputs = processed_outputs[:difference_outputs]",
        "        return tuple(processed_outputs) if len(processed_outputs) > 1 else processed_outputs[0]",
        "    if response.status_code == 409:",
        '        raise gr.Error("Sorry, the Cog image is still processing. Try again in a bit.")',
        '    raise gr.Error(f"The submission failed! Error: {response.status_code}")',
        "",
    ]

    interface_lines = [
        f"title = {title!r}",
        f"model_description = {model_description!r}",
        "",
        "app = gr.Interface(",
        "    fn=predict,",
        "    inputs=inputs,",
        "    outputs=outputs,",
        "    title=title,",
        "    description=model_description,",
        '    allow_flagging="never",',
        ")",
        "app.launch(share=True)",
    ]

    sections = [
        "\n".join(header_lines),
        inputs_string,
        outputs_string,
        "\n".join(predict_lines),
        "\n".join(interface_lines),
    ]
    return "\n".join(sections) + "\n"
|
|