# Shi-Jie
# update files
# adc0e4e
import base64
from io import BytesIO
from pathlib import Path
from urllib.parse import urlparse
import dotenv
import gradio as gr
import requests
from clients import get_client_module
from hf_datasets import dataset_rootdir
from omegaconf import DictConfig, OmegaConf
from PIL import Image
from prompts import get_prompt_module
# Load environment variables (e.g. default API keys) from a local .env file.
dotenv.load_dotenv()

# Discover available prompt versions: the stem of every regular file in
# ./prompts whose name does not start with an underscore.
prompt_versions = [d.stem for d in Path("./prompts").iterdir() if d.is_file() and not d.name.startswith("_")]
class ConfigManager:
    """Load and cache client/model configurations from the YAML files in ./model."""

    def __init__(self):
        # Mapping of client name -> client configuration; filled by update().
        self.configs: dict = {}
        # Top-level YAML keys that are metadata rather than client entries.
        self.ignore_keys = ["type", "client_name", "model_name"]
        # Populate the cache immediately so clients()/models() work right away.
        self.update()

    def update(self):
        """Reload configs from disk, discarding any previously cached entries."""
        self.configs.clear()
        # API-based clients: every (non-ignored) top-level key is its own client.
        api_cfg = OmegaConf.load("./model/api.yaml")
        self.configs.update({key: api_cfg[key] for key in api_cfg if key not in self.ignore_keys})
        # HF-based models are grouped under a single "huggingface" client entry.
        hf_cfg = OmegaConf.load("./model/hf.yaml")
        filtered = {key: hf_cfg[key] for key in hf_cfg if key not in self.ignore_keys}
        self.configs.update({"huggingface": DictConfig(filtered)})

    def clients(self):
        """Return the names of all available clients."""
        return list(self.configs.keys())

    def models(self, client=None):
        """Return the model names offered by `client` (defaults to the first client)."""
        client = self.clients()[0] if client is None else client
        return list(self.configs[client].available_models)
config_manager = ConfigManager()
def link_client_and_model(client, model):  # noqa
    """Refresh the model dropdown so it only lists models for `client`.

    The `model` argument is passed by the gradio event wiring but unused.
    """
    choices = config_manager.models(client)
    return gr.Dropdown(choices=choices, value=choices[0])
def display_prompt(prompt_version):
    """Return the human-readable description of the selected prompt version."""
    return get_prompt_module(prompt_version).description()
def encode_image(image):
    """Serialize a PIL image to PNG and return the base64 text of those bytes."""
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode("utf-8")
def load_image(image_url_or_path, timeout=None):
    """Load a PIL image from an HTTP(S) URL, a local file path, or a base64 string.

    Args:
        image_url_or_path: URL, filesystem path, data-URI, or raw base64 payload.
        timeout: Optional timeout in seconds forwarded to the HTTP request.

    Returns:
        The loaded PIL image.

    Raises:
        gr.Error: If the input cannot be resolved by any of the three strategies.
    """
    parsed = urlparse(image_url_or_path)
    # Remote image. (The original also required a non-empty path, which
    # needlessly rejected URLs such as "http://host" — netloc alone suffices.)
    if parsed.scheme in ("http", "https") and parsed.netloc:
        response = requests.get(image_url_or_path, timeout=timeout)
        # Fail fast on HTTP errors instead of handing an error page to PIL.
        response.raise_for_status()
        return Image.open(BytesIO(response.content))
    # Local file. Guard the check: Path.is_file() can raise OSError/ValueError
    # for pathological inputs (very long base64 strings, embedded NUL bytes);
    # treat those as "not a file" and fall through to the base64 branch.
    try:
        is_file = Path(image_url_or_path).is_file()
    except (OSError, ValueError):
        is_file = False
    if is_file:
        return Image.open(image_url_or_path)
    # Base64: strip a data-URI header ("data:image/...;base64,") if present.
    if image_url_or_path.startswith("data:image/"):
        image_url_or_path = image_url_or_path.split(",")[1]
    try:
        decoded = base64.decodebytes(image_url_or_path.encode())
        return Image.open(BytesIO(decoded))
    except Exception:
        raise gr.Error(
            "Incorrect image source. Must be a valid URL starting with `http://` or `https://`, "
            "a valid path to an image file, or a base64 encoded string."
        ) from None
def llm_analyse(client, model, api_key, image, prompt):
    """Run the selected prompt against `image` using the chosen client/model.

    Args:
        client: Name of the client backend (key into the config).
        model: Model identifier understood by the client.
        api_key: API key; an empty string means "use the default key".
        image: PIL image to analyse.
        prompt: Prompt version name, resolved via get_prompt_module.

    Returns:
        The model's text output.

    Raises:
        gr.Error: If anything fails while preparing or running the request.
    """
    try:
        prompt_module = get_prompt_module(prompt)
        client_module = get_client_module(client)
        base64_image = f"data:image/png;base64,{encode_image(image)}"
        # An empty textbox means "fall back to the server-side default key".
        if api_key == "":
            api_key = None
        return client_module.sync_generate(base64_image, prompt_module.messages_encoder, model, api_key=api_key)
    except Exception as e:
        # Bug fix: the original *returned* gr.Error, which gradio renders as the
        # object's repr in the output textbox; raising it shows a proper error
        # toast in the UI instead.
        raise gr.Error(f"Error processing image: {e}") from e
# ------------------------------- Gradio UI ------------------------------- #
with gr.Blocks(
    theme=gr.themes.Default(primary_hue="orange"),
    # Inline CSS: page width, rounded widgets, tighter spacing.
    css="""
    #app-container { max-width: 1400px; margin: auto; padding: 10px; }
    #title { text-align: center; margin-bottom: 10px; font-size: 24px; }
    #groq-badge { text-align: center; margin-top: 10px; }
    .gr-button { border-radius: 15px; }
    .gr-input, .gr-box { border-radius: 10px; }
    .gr-form { gap: 5px; }
    .gr-block.gr-box { padding: 10px; }
    .gr-paddle { height: auto; }
    """,
) as demo:
    gr.Markdown("# Image Moderation WebUI", elem_id="title")

    # --------------- Client and Model Selection Block --------------- #
    with gr.Row(equal_height=True):
        with gr.Column(scale=3):
            # Prompt version selector; custom values are allowed so users may
            # type a version name that is not in the discovered list.
            prompt_version_input = gr.Dropdown(
                prompt_versions,
                value="-- Please Select --",
                allow_custom_value=True,
                label="Choose Prompt:",
            )
            client_input = gr.Dropdown(
                config_manager.clients(), label="Choose Client:", info="HuggingFace Requires a GPU"
            )
            # Defaults to the first client's models; repopulated on client change.
            model_input = gr.Dropdown(config_manager.models(), label="Choose Model:")
            api_input = gr.Textbox(
                type="password",
                label="API Key:",
                info="Leave this field blank to use the default key, or if you are using HuggingFace",
            )
            image_input = gr.Image(type="pil", label="Upload Image:", height=300, sources=["upload"])
            url_input = gr.Textbox(
                label="or Paste Image URL, Local File Path, or Base64 String:",
                info="Press Enter to load the image",
                lines=1,
            )
            # Demo buttons load canned positive/negative example images.
            with gr.Row():
                with gr.Column(scale=1, min_width=160):
                    pos_button = gr.Button("πŸ‘ Positive Demo")
                with gr.Column(scale=1, min_width=160):
                    neg_button = gr.Button("πŸ‘Ž Negative Demo")
        with gr.Column(scale=5):
            prompt_text_input = gr.Textbox(label="or Paste Prompt Here:", lines=18)
            model_output = gr.Textbox(label="Model Output:", lines=18)
            with gr.Row():
                with gr.Column(scale=1, min_width=120):
                    analyze_button = gr.Button("πŸš€ Analyze Image", variant="primary")
                with gr.Column(scale=1, min_width=120):
                    clean_button = gr.Button("🧹 Clean Output", variant="primary")

    # ------------------------- Event wiring ------------------------- #
    # Changing the client repopulates the model dropdown.
    client_input.change(fn=link_client_and_model, inputs=[client_input, model_input], outputs=model_input)
    # Selecting a prompt version shows its description in the prompt textbox.
    prompt_version_input.input(fn=display_prompt, inputs=prompt_version_input, outputs=prompt_text_input)
    # Clear the model output textbox.
    clean_button.click(fn=lambda: gr.Textbox(value=""), inputs=None, outputs=model_output)
    # Pressing Enter in the URL box loads the image into the image widget.
    url_input.submit(fn=load_image, inputs=url_input, outputs=image_input)
    pos_button.click(
        fn=lambda: load_image(Path(dataset_rootdir, "semeval2022/demo-pos.jpg").as_posix()),
        inputs=None,
        outputs=image_input,
    )
    neg_button.click(
        fn=lambda: load_image(Path(dataset_rootdir, "semeval2022/demo-neg.jpg").as_posix()),
        inputs=None,
        outputs=image_input,
    )

    # ------------------------- Image Analysis Block ------------------------- #
    analyze_button.click(
        fn=llm_analyse,
        inputs=[client_input, model_input, api_input, image_input, prompt_version_input],
        outputs=model_output,
    )

# NOTE(review): launched at import time (no __main__ guard); share=False keeps
# the app local-only.
demo.launch(share=False)