# NOTE(review): the lines below were page-scrape residue ("Spaces: / Runtime
# error / Runtime error") from the HuggingFace Spaces listing, not source code;
# converted to a comment so the module parses.
| import base64 | |
| from io import BytesIO | |
| from pathlib import Path | |
| from urllib.parse import urlparse | |
| import dotenv | |
| import gradio as gr | |
| import requests | |
| from clients import get_client_module | |
| from hf_datasets import dataset_rootdir | |
| from omegaconf import DictConfig, OmegaConf | |
| from PIL import Image | |
| from prompts import get_prompt_module | |
# Load environment variables (e.g. default API keys) from a local .env file.
dotenv.load_dotenv()

# Discover available prompt versions: one per non-underscore-prefixed file in ./prompts.
prompt_versions = [d.stem for d in Path("./prompts").iterdir() if d.is_file() and not d.name.startswith("_")]
class ConfigManager:
    """Load and cache model configurations for API-backed and HF-backed clients."""

    def __init__(self):
        # Per-client configuration cache, keyed by client name.
        self.configs: dict = {}
        # Top-level YAML keys that are metadata rather than client entries.
        self.ignore_keys = ["type", "client_name", "model_name"]
        # Populate the cache immediately.
        self.update()

    def update(self):
        """Reload configs"""
        # Drop any cached entries before reloading from disk.
        self.configs.clear()
        # API-based models: every non-ignored top-level key is a client config.
        api_cfg = OmegaConf.load("./model/api.yaml")
        self.configs.update({k: api_cfg[k] for k in api_cfg if k not in self.ignore_keys})
        # HF-based models are grouped under a single "huggingface" client entry.
        hf_cfg = OmegaConf.load("./model/hf.yaml")
        self.configs.update({"huggingface": DictConfig({k: hf_cfg[k] for k in hf_cfg if k not in self.ignore_keys})})

    def clients(self):
        """Display all available clients"""
        return [*self.configs]

    def models(self, client=None):
        """List the available model names for *client* (defaults to the first client)."""
        chosen = self.clients()[0] if client is None else client
        return list(self.configs[chosen].available_models)
# Module-level singleton shared by all UI callbacks below.
config_manager = ConfigManager()
def link_client_and_model(client, model):  # noqa
    """Refresh the model dropdown with the models of the newly selected client."""
    choices = config_manager.models(client)
    # Preselect the first model of the new client.
    return gr.Dropdown(choices=choices, value=choices[0])
def display_prompt(prompt_version):
    """Return the human-readable description of the selected prompt version."""
    return get_prompt_module(prompt_version).description()
def encode_image(image):
    """PNG-encode *image* (a PIL image) and return the bytes as a base64 string."""
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    raw_bytes = png_buffer.getvalue()
    return base64.b64encode(raw_bytes).decode("utf-8")
def load_image(image_url_or_path, timeout=None):
    """Load a PIL image from an http(s) URL, a local file path, or a base64 string.

    Args:
        image_url_or_path: An ``http(s)`` URL, a filesystem path, a raw base64
            payload, or a ``data:image/...;base64,`` data URI.
        timeout: Optional timeout in seconds for the HTTP request.

    Returns:
        The decoded ``PIL.Image.Image``.

    Raises:
        gr.Error: If the source matches none of the supported formats.
    """
    result = urlparse(image_url_or_path)
    # URL branch: a path component is NOT required — "https://host" is a valid
    # URL too (the original check rejected path-less URLs).
    if result.scheme in ("http", "https") and result.netloc:
        image = Image.open(BytesIO(requests.get(image_url_or_path, timeout=timeout).content))
    elif Path(image_url_or_path).is_file():
        image = Image.open(image_url_or_path)
    else:
        # Strip an optional data-URI prefix, keeping only the base64 payload.
        if image_url_or_path.startswith("data:image/"):
            image_url_or_path = image_url_or_path.split(",")[1]
        # Try to load as base64
        try:
            base64_image = base64.decodebytes(image_url_or_path.encode())
            image = Image.open(BytesIO(base64_image))
        except Exception:
            # `from None`: the internal decode traceback is noise to the user —
            # show only the clean, actionable Gradio error.
            raise gr.Error(
                "Incorrect image source. Must be a valid URL starting with `http://` or `https://`, "
                "a valid path to an image file, or a base64 encoded string."
            ) from None
    return image
def llm_analyse(client, model, api_key, image, prompt):
    """Run the selected prompt against *image* using the chosen client/model.

    Args:
        client: Client name (a key known to the config manager).
        model: Model identifier for that client.
        api_key: User-supplied API key; an empty string falls back to the
            server-side default credentials.
        image: PIL image to analyse.
        prompt: Prompt version name, resolved via ``get_prompt_module``.

    Returns:
        The model's text output.

    Raises:
        gr.Error: If any step of the pipeline fails.
    """
    try:
        prompt_module = get_prompt_module(prompt)
        client_module = get_client_module(client)
        base64_image = f"data:image/png;base64,{encode_image(image)}"
        # An empty key means "use the default key configured on the server".
        if api_key == "":
            api_key = None
        result = client_module.sync_generate(base64_image, prompt_module.messages_encoder, model, api_key=api_key)
        return result
    except Exception as e:
        # BUG FIX: raise the error instead of returning it. Returning a
        # gr.Error object makes Gradio render its repr into the output
        # Textbox; raising it shows the proper error modal in the UI.
        raise gr.Error(f"Error processing image: {e}") from e
# --------------------------- Gradio Web UI --------------------------- #
with gr.Blocks(
    theme=gr.themes.Default(primary_hue="orange"),
    css="""
    #app-container { max-width: 1400px; margin: auto; padding: 10px; }
    #title { text-align: center; margin-bottom: 10px; font-size: 24px; }
    #groq-badge { text-align: center; margin-top: 10px; }
    .gr-button { border-radius: 15px; }
    .gr-input, .gr-box { border-radius: 10px; }
    .gr-form { gap: 5px; }
    .gr-block.gr-box { padding: 10px; }
    .gr-paddle { height: auto; }
    """,
) as demo:
    gr.Markdown("# Image Moderation WebUI", elem_id="title")
    # --------------- Client and Model Selection Block --------------- #
    with gr.Row(equal_height=True):
        with gr.Column(scale=3):
            # allow_custom_value lets the "-- Please Select --" placeholder stand
            # until the user picks a real prompt version from ./prompts.
            prompt_version_input = gr.Dropdown(
                prompt_versions,
                value="-- Please Select --",
                allow_custom_value=True,
                label="Choose Prompt:",
            )
            client_input = gr.Dropdown(
                config_manager.clients(), label="Choose Client:", info="HuggingFace Requires a GPU"
            )
            model_input = gr.Dropdown(config_manager.models(), label="Choose Model:")
            api_input = gr.Textbox(
                type="password",
                label="API Key:",
                info="Leave this field blank to use the default key, or if you are using HuggingFace",
            )
            image_input = gr.Image(type="pil", label="Upload Image:", height=300, sources=["upload"])
            url_input = gr.Textbox(
                label="or Paste Image URL, Local File Path, or Base64 String:",
                info="Press Enter to load the image",
                lines=1,
            )
            with gr.Row():
                with gr.Column(scale=1, min_width=160):
                    pos_button = gr.Button("π Positive Demo")
                with gr.Column(scale=1, min_width=160):
                    neg_button = gr.Button("π Negative Demo")
        with gr.Column(scale=5):
            # NOTE(review): this textbox displays / accepts prompt text, but it is
            # never wired into llm_analyse (which receives prompt_version_input
            # below) — pasted prompt text appears to be ignored; confirm intended.
            prompt_text_input = gr.Textbox(label="or Paste Prompt Here:", lines=18)
            model_output = gr.Textbox(label="Model Output:", lines=18)
            with gr.Row():
                with gr.Column(scale=1, min_width=120):
                    analyze_button = gr.Button("π Analyze Image", variant="primary")
                with gr.Column(scale=1, min_width=120):
                    clean_button = gr.Button("π§Ή Clean Output", variant="primary")

    # ----------------------------- Event wiring ----------------------------- #
    # Changing the client repopulates the model dropdown with that client's models.
    client_input.change(fn=link_client_and_model, inputs=[client_input, model_input], outputs=model_input)
    # Selecting a prompt version shows its description in the prompt textbox.
    prompt_version_input.input(fn=display_prompt, inputs=prompt_version_input, outputs=prompt_text_input)
    # Clear the model output textbox.
    clean_button.click(fn=lambda: gr.Textbox(value=""), inputs=None, outputs=model_output)
    # Pressing Enter in the URL box loads the image from URL / path / base64.
    url_input.submit(fn=load_image, inputs=url_input, outputs=image_input)
    # Demo buttons load bundled positive/negative example images from the dataset dir.
    pos_button.click(
        fn=lambda: load_image(Path(dataset_rootdir, "semeval2022/demo-pos.jpg").as_posix()),
        inputs=None,
        outputs=image_input,
    )
    neg_button.click(
        fn=lambda: load_image(Path(dataset_rootdir, "semeval2022/demo-neg.jpg").as_posix()),
        inputs=None,
        outputs=image_input,
    )
    # ------------------------- Image Analysis Block ------------------------- #
    analyze_button.click(
        fn=llm_analyse,
        inputs=[client_input, model_input, api_input, image_input, prompt_version_input],
        outputs=model_output,
    )

# share=False: serve locally only (Spaces provides its own public URL).
demo.launch(share=False)