# Spaces:
# Runtime error
# Runtime error
# Module setup: imports, environment configuration, and generation constants.
import os
import random
import uuid
import json
import time
import asyncio
import io
import base64
from threading import Thread
from typing import Union, Iterable, cast

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
import cv2
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    MllamaForConditionalGeneration,
    GenerationMixin,
    AutoModel,
    AutoTokenizer,
    AutoProcessor,
    TextIteratorStreamer,
)
from transformers.image_utils import load_image
from transformers.generation import GenerationConfig
import huggingface_hub
import openai

from llm_json import parse_llm_json
from data_experiments import all_products, all_experiments, filter_experiments, get_experiment
from data_comments import all_comments, filter_comments
import llm_messages_v1
from dotenv import load_dotenv

# Load .env so the API keys read below are available via os.getenv.
load_dotenv()
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
openai_api_key = os.getenv("OPENAI_API_KEY")
huggingface_api_key = os.getenv("MEL_HUGGING_FACE_API_KEY")
# Authenticate with the Hugging Face Hub (presumably required for gated models
# such as Llama 3.2 — confirm).
huggingface_hub.login(token=huggingface_api_key)

# Prompt-template module (versioned); swap the alias here to change prompts.
llm_messages = llm_messages_v1

# Constants for text generation
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Single-device placement for the locally loaded transformers models.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
MODEL_ID_QWEN_7B = "Qwen/Qwen2.5-VL-7B-Instruct"
# Lazily populated module-level singletons for the Qwen 7B backend.
processor_qwen_7b = None
model_qwen_7b = None


def load_qwen_7b_model():
    """Load the Qwen2.5-VL 7B processor and model into module globals (idempotent)."""
    global processor_qwen_7b, model_qwen_7b
    if processor_qwen_7b is not None and model_qwen_7b is not None:
        return  # already loaded
    processor_qwen_7b = AutoProcessor.from_pretrained(MODEL_ID_QWEN_7B, trust_remote_code=True)
    model_qwen_7b = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID_QWEN_7B,
        trust_remote_code=True,
        torch_dtype=torch.float16,
    ).to(device).eval()
MODEL_ID_QWEN_3B = "Qwen/Qwen2.5-VL-3B-Instruct"
# Lazily populated module-level singletons for the Qwen 3B backend.
processor_qwen_3b = None
model_qwen_3b = None


def load_qwen_3b_model():
    """Load the Qwen2.5-VL 3B processor and model into module globals (idempotent)."""
    global processor_qwen_3b, model_qwen_3b
    if processor_qwen_3b is not None and model_qwen_3b is not None:
        return  # already loaded
    processor_qwen_3b = AutoProcessor.from_pretrained(MODEL_ID_QWEN_3B, trust_remote_code=True)
    model_qwen_3b = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID_QWEN_3B,
        trust_remote_code=True,
        torch_dtype=torch.float16,
    ).to(device).eval()
MODEL_ID_LLAMA = "meta-llama/Llama-3.2-11B-Vision-Instruct"
# Lazily populated module-level singletons for the Llama 3.2 Vision backend.
processor_llama = None
model_llama = None


def load_llama_model():
    """Load the Llama 3.2 11B Vision processor and model into module globals.

    Fix: the model is loaded with device_map="auto", which makes accelerate
    dispatch the weights across available devices; calling .to(device) on a
    dispatched model is rejected by accelerate at runtime, so the explicit
    move has been removed.
    """
    global processor_llama, model_llama
    if processor_llama is None or model_llama is None:
        processor_llama = AutoProcessor.from_pretrained(MODEL_ID_LLAMA, trust_remote_code=True)
        model_llama = MllamaForConditionalGeneration.from_pretrained(
            MODEL_ID_LLAMA,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        ).eval()
# load_qwen_3b_model()

# Remote-backend model identifiers used by the UI radio selector:
# "huggingface/endpoint" routes to a dedicated inference endpoint; the
# "openai/..." ids route to the OpenAI API.
MODEL_ID_ENDPOINT = "huggingface/endpoint"
MODEL_ID_OPENAI_GPT_41_NANO = "openai/gpt-4.1-nano"
MODEL_ID_OPENAI_GPT_41_MINI = "openai/gpt-4.1-mini"
MODEL_ID_OPENAI_GPT_41 = "openai/gpt-4.1"

# Shared OpenAI client, authenticated with OPENAI_API_KEY from the environment.
openai_client = openai.Client(
    api_key=openai_api_key,
)
class LlmGenerator:
    """Abstract interface for a streaming chat-completion backend.

    Subclasses yield the accumulated response text, growing with each
    streamed chunk.
    """

    def generate(self, messages: list[dict], generation_config: GenerationConfig) -> Iterable[str]:
        """Stream a completion for chat-style messages; overridden by subclasses."""
        pass
class TransformersLlmGenerator(LlmGenerator):
    """Streams generations from a locally loaded transformers model."""

    def __init__(self, model_id: str, processor: AutoProcessor, model: GenerationMixin):
        self.model_id = model_id
        self.processor = processor
        self.model = model

    def generate(self, messages: list[dict], generation_config: GenerationConfig) -> Iterable[str]:
        """Run generation on a worker thread and yield the growing output buffer."""
        if "Llama-3.2" in self.model_id:
            # https://github.com/huggingface/transformers/issues/34304
            generation_config.repetition_penalty = 1.0
        prompt_full = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # Collect the PIL images referenced by user messages, in order.
        images = []
        for message in messages:
            if message["role"] == "user":
                for content in message["content"]:
                    if content["type"] == "image":
                        images.append(content["image"])
        if not images:
            # Pass None rather than an empty list for text-only prompts.
            images = None
        inputs = self.processor(
            text=[prompt_full],
            images=images,
            return_tensors="pt",
            padding=True,
            truncation=False,
            max_length=MAX_INPUT_TOKEN_LENGTH
        ).to(device)
        streamer = TextIteratorStreamer(self.processor, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = {
            **inputs,
            "streamer": streamer,
            "generation_config": generation_config,
        }
        # model.generate blocks, so run it on a background thread and consume
        # the streamer from this generator.
        thread = Thread(
            target=self.model.generate,
            kwargs=generation_kwargs
        )
        thread.start()
        buffer = ""
        for new_text in streamer:
            buffer += new_text
            time.sleep(0.01)  # brief pause between streamed UI updates
            yield buffer
class OpenAiLlmGenerator(LlmGenerator):
    """Streams chat completions from the OpenAI API.

    Fixes over the previous version:
    - the caller's ``messages`` structure is no longer mutated in place when
      PIL image parts are converted to ``image_url`` parts;
    - images are converted to RGB before JPEG encoding, since PIL cannot
      write JPEG from modes with alpha or a palette (e.g. RGBA, P).
    """

    def __init__(self, openai_model_id: str):
        self.model_id = openai_model_id

    def generate(self, messages: list[dict], generation_config: GenerationConfig) -> Iterable[str]:
        """Yield the accumulated response text as chunks stream in."""
        api_messages = self._convert_messages(messages)
        response = openai_client.chat.completions.create(
            model=self.model_id,
            messages=api_messages,
            max_tokens=generation_config.max_new_tokens,
            temperature=generation_config.temperature,
            top_p=generation_config.top_p,
            stream=True
        )
        buffer = ""
        for chunk in response:
            if chunk.choices and chunk.choices[0].delta.content:
                token = chunk.choices[0].delta.content
                buffer += token
                time.sleep(0.01)
                yield buffer

    def _convert_messages(self, messages: list[dict]) -> list[dict]:
        """Return a copy of messages with PIL image parts replaced by image_url parts.

        The input structure is left untouched so callers can reuse it.
        """
        converted = []
        for message in messages:
            new_content = []
            for content in message["content"]:
                if content.get("type") == "image":
                    new_content.append({
                        "type": "image_url",
                        "image_url": {"url": self._convert_image_to_data_url(content["image"])},
                    })
                else:
                    new_content.append(content)
            converted.append({**message, "content": new_content})
        return converted

    def _convert_image_to_data_url(self, image: Image.Image) -> str:
        """Convert a PIL Image to a base64 JPEG data URL."""
        if image.mode != "RGB":
            # JPEG cannot encode alpha/palette modes; normalize first.
            image = image.convert("RGB")
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return f"data:image/jpeg;base64,{img_str}"
class HuggingfaceEndpointLlmGenerator(LlmGenerator):
    """Streams chat completions from an OpenAI-compatible Hugging Face endpoint.

    Fixes over the previous version (kept consistent with OpenAiLlmGenerator):
    - the caller's ``messages`` structure is no longer mutated in place;
    - images are converted to RGB before JPEG encoding, since PIL cannot
      write JPEG from modes with alpha or a palette (e.g. RGBA, P).
    """

    def __init__(self, url: str, model_id: str):
        self.url = url
        self.model_id = model_id
        # The endpoint speaks the OpenAI protocol; authenticate with the HF key.
        self.client = openai.Client(
            base_url=url,
            api_key=huggingface_api_key,
        )

    def generate(self, messages: list[dict], generation_config: GenerationConfig) -> Iterable[str]:
        """Yield the accumulated response text as chunks stream in."""
        api_messages = self._convert_messages(messages)
        response = self.client.chat.completions.create(
            model=self.model_id,
            messages=api_messages,
            max_tokens=generation_config.max_new_tokens,
            temperature=generation_config.temperature,
            top_p=generation_config.top_p,
            stream=True
        )
        buffer = ""
        for chunk in response:
            if chunk.choices and chunk.choices[0].delta.content:
                token = chunk.choices[0].delta.content
                buffer += token
                time.sleep(0.01)
                yield buffer

    def _convert_messages(self, messages: list[dict]) -> list[dict]:
        """Return a copy of messages with PIL image parts replaced by image_url parts."""
        converted = []
        for message in messages:
            new_content = []
            for content in message["content"]:
                if content.get("type") == "image":
                    new_content.append({
                        "type": "image_url",
                        "image_url": {"url": self._convert_image_to_data_url(content["image"])},
                    })
                else:
                    new_content.append(content)
            converted.append({**message, "content": new_content})
        return converted

    def _convert_image_to_data_url(self, image: Image.Image) -> str:
        """Convert a PIL Image to a base64 JPEG data URL."""
        if image.mode != "RGB":
            # JPEG cannot encode alpha/palette modes; normalize first.
            image = image.convert("RGB")
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return f"data:image/jpeg;base64,{img_str}"
def process_image(
    model_name: str,
    experiment_id: int,
    image: Image.Image,
    language: str = "en",
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2
) -> Iterable[tuple[str, str]]:
    """Moderate an uploaded photo against a stored experiment, streaming results.

    Yields (text, json) tuples produced by the moderation pipeline; error
    messages are yielded once for invalid inputs.
    """
    generator = _get_llm_generator(model_name)
    if generator is None:
        yield "Invalid model selected.", "Invalid model selected."
        return
    if image is None:
        yield "Please upload an image.", "Please upload an image."
        return
    experiment = get_experiment(experiment_id=experiment_id)
    if not experiment:
        yield "Experiment not found.", "Experiment not found."
        return
    config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        max_new_tokens=max_new_tokens,
    )
    yield from _process_image(
        generator=generator,
        experiment=experiment,
        image=image,
        comments=None,
        generation_config=config,
        language=language,
    )
def process_image_json(
    model_name: str,
    experiment_json_text: str,
    image: Image.Image,
    language: str = "en",
    rewrite_comment: bool = False,
    translate_comment: bool = False,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2
) -> Iterable[tuple[str, str]]:
    """Moderate a photo against an experiment supplied as raw JSON text.

    Note: ``rewrite_comment`` and ``translate_comment`` are accepted from the
    UI but are not referenced by the current pipeline.
    """
    generator = _get_llm_generator(model_name)
    if generator is None:
        yield "Invalid model selected.", "Invalid model selected."
        return
    if image is None:
        yield "Please upload an image.", "Please upload an image."
        return
    experiment = parse_llm_json(experiment_json_text)
    if not (experiment and isinstance(experiment, dict) and "id" in experiment):
        yield "Could not parse experiment JSON.", "Could not parse experiment JSON."
        return
    config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        max_new_tokens=max_new_tokens,
    )
    yield from _process_image(
        generator=generator,
        experiment=experiment,
        image=image,
        comments=None,
        generation_config=config,
        language=language,
    )
def process_image_json_and_comments_json(
    model_name: str,
    experiment_json_text: str,
    image: Image.Image,
    language: str = "en",
    comments_json_text: str = "",
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2
) -> Iterable[tuple[str, str]]:
    """Moderate a photo against an experiment and comment set, both given as JSON text.

    When the comments JSON does not parse to a list, the stored comments for
    the experiment are used instead.
    """
    generator = _get_llm_generator(model_name)
    if generator is None:
        yield "Invalid model selected.", "Invalid model selected."
        return
    if image is None:
        yield "Please upload an image.", "Please upload an image."
        return
    experiment = parse_llm_json(experiment_json_text)
    if not (experiment and isinstance(experiment, dict) and "id" in experiment):
        yield "Could not parse experiment JSON.", "Could not parse experiment JSON."
        return
    parsed_comments = parse_llm_json(comments_json_text)
    if isinstance(parsed_comments, list):
        comments = parsed_comments
    else:
        # Fall back to the stored comments for this experiment.
        comments = filter_comments(
            experiment_id=experiment["id"],
        )
    config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        max_new_tokens=max_new_tokens,
    )
    yield from _process_image(
        generator=generator,
        experiment=experiment,
        image=image,
        comments=comments,
        generation_config=config,
        language=language,
    )
def _get_llm_generator(model_name: str) -> Union[LlmGenerator, None]:
    """Resolve a UI model selector string to a ready-to-use LlmGenerator.

    Local models are lazily loaded on first use.  Returns None for an
    unrecognized selector.
    """
    if model_name == MODEL_ID_QWEN_7B:
        load_qwen_7b_model()
        return TransformersLlmGenerator(model_name, processor_qwen_7b, model_qwen_7b)
    if model_name == MODEL_ID_QWEN_3B:
        load_qwen_3b_model()
        return TransformersLlmGenerator(model_name, processor_qwen_3b, model_qwen_3b)
    if model_name == MODEL_ID_LLAMA:
        load_llama_model()
        return TransformersLlmGenerator(model_name, processor_llama, model_llama)
    if model_name.startswith("openai/"):
        # Strip the "openai/" routing prefix to get the API model id.
        return OpenAiLlmGenerator(model_name.split("/")[1])
    if model_name == MODEL_ID_ENDPOINT:
        return HuggingfaceEndpointLlmGenerator(
            url="https://syrapffqxq1kmues.us-east-1.aws.endpoints.huggingface.cloud/v1/",
            model_id="ggml-org/Qwen2.5-VL-3B-Instruct-GGUF",
        )
    return None
def _process_image(
    generator: LlmGenerator,
    experiment: dict,
    image: Image.Image,
    comments: list[dict] | None,
    generation_config: GenerationConfig,
    language: str = "en",
) -> Iterable[tuple[str, str]]:
    """Run the full moderation pipeline: safety -> sharing -> approval -> comment.

    Yields (accumulated_text, None) tuples while each stage streams; the final
    yield carries the accumulated transcript plus a combined JSON summary.

    Fixes over the previous version:
    - the "Sharing" and "Approval" transcript sections were gated on the
      *safety* stage's output (copy-paste bug), which could also append None
      into the joined transcript; each section now checks its own stage output;
    - the approval result is guarded before indexing, so a stage whose JSON
      fails to parse no longer raises TypeError/KeyError;
    - local "aprroval" typos renamed.
    """
    full_text_lines = []
    full_text = ""

    # --- Safety stage ---
    # NOTE: the stage helpers spell their keyword "generateion_config" (sic);
    # kept here to match their signatures.
    safety_raw_text = None
    for raw_text in process_image_safety_state(
        generator=generator,
        image=image,
        generateion_config=generation_config,
    ):
        yield full_text + raw_text, None
        safety_raw_text = raw_text
    if safety_raw_text:
        full_text_lines.append("Safety:")
        full_text_lines.append(safety_raw_text)
        full_text_lines.append("\n")
        full_text = "\n".join(full_text_lines)
    safety_json = parse_llm_json(safety_raw_text) if safety_raw_text else None

    # --- Sharing stage ---
    sharing_raw_text = None
    for raw_text in process_image_sharing_state(
        generator=generator,
        image=image,
        generateion_config=generation_config,
    ):
        yield full_text + raw_text, None
        sharing_raw_text = raw_text
    if sharing_raw_text:  # was gated on safety_raw_text (bug)
        full_text_lines.append("Sharing:")
        full_text_lines.append(sharing_raw_text)
        full_text_lines.append("\n")
        full_text = "\n".join(full_text_lines)
    sharing_json = parse_llm_json(sharing_raw_text) if sharing_raw_text else None

    # --- Approval stage ---
    approval_raw_text = None
    for raw_text in process_image_approval_state(
        generator=generator,
        experiment=experiment,
        image=image,
        comments=comments,
        additional_text=None,
        generation_config=generation_config,
    ):
        yield full_text + raw_text, None
        approval_raw_text = raw_text
    if approval_raw_text:  # was gated on safety_raw_text (bug)
        full_text_lines.append("Approval:")
        full_text_lines.append(approval_raw_text)
        full_text_lines.append("\n")
        full_text = "\n".join(full_text_lines)
    approval_json = parse_llm_json(approval_raw_text) if approval_raw_text else None

    # --- Comment selection ---
    comment_json = None
    if (
        approval_json
        and approval_json.get("result") in ["approved", "partial"]
        and "state" in approval_json
    ):
        comment = None
        selected_comment = select_comment_from_comments(
            all_comments=comments or all_comments,
            experiment_id=experiment["id"],
            language=language,
            reason=approval_json["state"],
        )
        if selected_comment:
            comment = selected_comment["comment"]
        if comment:
            full_text_lines.append("Comment:")
            full_text_lines.append(comment)
            full_text_lines.append("\n")
            full_text = "\n".join(full_text_lines)
            comment_json = {
                "text": comment,
                "reason": approval_json["state"],
                "original_text": selected_comment["comment"],
            }
    else:
        # TODO: comment for failed
        pass

    full_json = {
        "safety": safety_json,
        "sharing": sharing_json,
        "approval": approval_json,
        "comment": comment_json,
    }
    full_json_text = json.dumps(full_json, indent=2, ensure_ascii=False)
    yield full_text, full_json_text
def process_image_safety_state(
    generator: LlmGenerator,
    image: Image.Image,
    generateion_config: GenerationConfig,
) -> Iterable[str]:
    """Stream the safety-check stage for a single image.

    Note: the parameter name "generateion_config" (sic) is passed by keyword
    from _process_image and is kept for compatibility.
    """
    user_message = {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": format_llm_safety_human_message()},
        ]
    }
    return generate(
        generator=generator,
        messages=[user_message],
        generateion_config=generateion_config,
    )
def process_image_sharing_state(
    generator: LlmGenerator,
    image: Image.Image,
    generateion_config: GenerationConfig,
) -> Iterable[str]:
    """Stream the sharing-check stage for a single image.

    Note: the parameter name "generateion_config" (sic) is passed by keyword
    from _process_image and is kept for compatibility.
    """
    user_message = {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": format_llm_sharing_human_message()},
        ]
    }
    return generate(
        generator=generator,
        messages=[user_message],
        generateion_config=generateion_config,
    )
def process_image_approval_state(
    generator: LlmGenerator,
    experiment: dict,
    image: Image.Image,
    comments: list[dict] | None,
    additional_text: Union[str, None],
    generation_config: GenerationConfig,
) -> Iterable[str]:
    """Stream the approval stage: does the image satisfy the experiment's criteria?

    The system message carries the experiment description; the user message
    carries the image plus the formatted approval question.
    """
    system_message = {
        "role": "system",
        "content": [
            {"type": "text", "text": format_llm_approval_system_message(experiment)},
        ]
    }
    user_message = {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": format_llm_approval_human_message(experiment, comments, additional_text)},
        ]
    }
    return generate(
        generator=generator,
        messages=[system_message, user_message],
        generateion_config=generation_config,
    )
def generate(
    generator: LlmGenerator,
    messages: list[dict],
    generateion_config: GenerationConfig,
) -> Iterable[str]:
    """Thin dispatch wrapper: forward to the backend's streaming generate().

    The parameter name "generateion_config" (sic) is kept because callers
    pass it by keyword.
    """
    return generator.generate(messages=messages, generation_config=generateion_config)
def select_comment_from_comments(
    all_comments: list[dict] | None,
    experiment_id: int,
    language: str,
    reason: str,
) -> dict | None:
    """Pick a random comment matching (experiment, reason, language).

    Falls back to English when the requested language has no match, and
    returns None when nothing matches at all.

    NOTE(review): the fallback call omits ``comments=all_comments`` —
    presumably it searches the module-wide default comment set on purpose;
    confirm this is intended.  Also, this parameter shadows the module-level
    ``all_comments`` import.
    """
    comments = filter_comments(
        comments=all_comments,
        experiment_id=experiment_id,
        reason=reason,
        language=language,
    )
    if not comments:
        # Fallback: English comments from the default comment set.
        comments = filter_comments(
            experiment_id=experiment_id,
            reason=reason,
            language="en",
        )
    if not comments:
        return None
    comment = random.choice(comments)
    return comment
def format_llm_safety_human_message() -> str:
    """Build the user prompt for the safety stage (delegates to llm_messages)."""
    return llm_messages.format_safety_human_message()
def format_llm_sharing_human_message() -> str:
    """Build the user prompt for the sharing stage (delegates to llm_messages)."""
    return llm_messages.format_sharing_human_message()
def format_llm_approval_system_message(experiment: dict) -> str:
    """Build the system prompt for the approval stage (delegates to llm_messages)."""
    return llm_messages.format_approval_system_message(experiment)
def format_llm_approval_human_message(experiment: dict, comments: list[dict] | None, additional_query: Union[str, None]) -> str:
    """Compose the approval-stage user prompt, optionally prefixed by an extra query.

    The unique comment reasons are forwarded to the prompt template.
    """
    parts = []
    if additional_query:
        parts.append(f"{additional_query}")
    unique_reasons = list({comment["reason"] for comment in (comments or [])})
    parts.append(llm_messages.format_approval_human_message(experiment, unique_reasons))
    return "\n\n".join(parts)
def build_experiments_for_dropdown() -> list[dict]:
    """Return all experiments sorted by product (alphabetical), then by id.

    Improvement: product positions are precomputed into a dict, avoiding the
    O(n) ``list.index`` lookup inside the sort key on every comparison.
    """
    sorted_products = sorted(all_products)
    # Map product -> position of its first occurrence (matches list.index).
    product_order = {}
    for idx, product in enumerate(sorted_products):
        product_order.setdefault(product, idx)
    experiments = filter_experiments(
        experiments=all_experiments,
    )
    return sorted(experiments, key=lambda x: (product_order[x["product_id"]], x["id"]))
def format_experiment_dropdown_title(experiment: dict) -> str:
    """Format an experiment as "id. title (product)" for the dropdown label.

    Fix: the local previously named ``id`` is renamed to avoid shadowing the
    builtin.
    """
    experiment_id = str(experiment["id"])
    title = experiment.get("title", "")
    product = experiment.get("product_id", "")
    return f"{experiment_id}. {title} ({product})"
def build_experiemnts_dropdown_choices() -> list[tuple[str, int]]:
    """Build (label, experiment_id) pairs for the Gradio experiment dropdown.

    NOTE: the misspelled name ("experiemnts") is referenced by the UI wiring
    below and is kept for compatibility.
    """
    experiments = build_experiments_for_dropdown()
    return [(format_experiment_dropdown_title(exp), exp["id"]) for exp in experiments]
def format_llm_full_text_preview(experiment_id: int, additional_query: str = "") -> str:
    """Render a read-only preview of all stage prompts for the selected experiment.

    Fix: ``additional_query`` now defaults to "" because the Gradio wiring
    (``experiment_dropdown.change``) supplies only the experiment id; the
    previously required parameter made that callback fail.  Passing a query
    explicitly still works as before.
    """
    experiment = get_experiment(experiment_id=experiment_id)
    if not experiment:
        return "Experiment not found."
    parts = []
    parts.append("SAFETY:")
    parts.append(format_llm_safety_human_message())
    parts.append("-------\n")
    parts.append("SHARING:")
    parts.append(format_llm_sharing_human_message())
    parts.append("-------\n")
    parts.append("APPROVAL:")
    parts.append(format_llm_approval_system_message(experiment))
    parts.append(format_llm_approval_human_message(experiment, comments=[], additional_query=additional_query))
    parts.append("-------\n")
    combined = "\n\n".join(parts)
    return combined
# Create the Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# **Experiments photo moderation**")
    with gr.Row():
        # Left column: experiment/model selection and generation settings.
        with gr.Column():
            experiment_dropdown = gr.Dropdown(
                choices=build_experiemnts_dropdown_choices(),
                value=None,
                label="Select Experiment",
                interactive=True,
            )
            language = gr.Radio(
                choices=["en", "fr", "ru", "de", "es", "zh", "id", "th", "ja"],
                value="en",
                label="Language",
            )
            # Forwarded to process_image_json (currently unused by the pipeline).
            rewrite_comment = gr.Checkbox(
                label="Rewrite comment",
                value=False
            )
            translate_comment = gr.Checkbox(
                label="Translate comment",
                value=False,
            )
            image_upload = gr.Image(type="pil", label="Image")
            submit_button = gr.Button("Submit", elem_classes="submit-btn")
            model_choice = gr.Radio(
                choices=[
                    MODEL_ID_ENDPOINT,
                    MODEL_ID_QWEN_7B,
                    MODEL_ID_QWEN_3B,
                    MODEL_ID_LLAMA,
                    MODEL_ID_OPENAI_GPT_41_NANO,
                    MODEL_ID_OPENAI_GPT_41_MINI,
                    MODEL_ID_OPENAI_GPT_41
                ],
                label="Select Model",
                value=MODEL_ID_ENDPOINT
            )
            # Sampling parameters forwarded into GenerationConfig.
            with gr.Accordion("Advanced options", open=False):
                max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
                top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
                top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
                repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
            # Read-only preview of the prompts for the selected experiment.
            with gr.Accordion("Full prompt", open=False):
                full_prompt_textbox = gr.Markdown()
            # Manual JSON entry for ad-hoc experiments/comments.
            with gr.Accordion("Experiment json", open=False):
                experiment_json_text = gr.Textbox(label="Experiment JSON", interactive=True, lines=10)
                comments_json_text = gr.Textbox(label="Comments JSON", interactive=True, lines=10)
                submit_json_button = gr.Button("Submit JSON", elem_classes="submit-btn")
                submit_json_2_button = gr.Button("Submit 2 JSON", elem_classes="submit-btn")
        # Right column: streamed model output.
        with gr.Column():
            with gr.Column(elem_classes="canvas-output"):
                gr.Markdown("## Output")
                output = gr.Textbox(label="Raw Output", interactive=False, lines=2, scale=2)
                with gr.Accordion("(Result.md)", open=False):
                    markdown_output = gr.Markdown()
    # NOTE(review): format_llm_full_text_preview takes (experiment_id,
    # additional_query) but only the dropdown value is wired here — confirm the
    # second parameter has a default, otherwise this callback fails.
    experiment_dropdown.change(
        fn=format_llm_full_text_preview,
        inputs=[experiment_dropdown],
        outputs=full_prompt_textbox,
    )
    # Stored-experiment flow.
    submit_button.click(
        fn=process_image,
        inputs=[
            model_choice,
            experiment_dropdown,
            image_upload,
            language,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
            repetition_penalty
        ],
        outputs=[
            output,
            markdown_output
        ]
    )
    # Ad-hoc experiment JSON flow.
    submit_json_button.click(
        fn=process_image_json,
        inputs=[
            model_choice,
            experiment_json_text,
            image_upload,
            language,
            rewrite_comment,
            translate_comment,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
            repetition_penalty
        ],
        outputs=[
            output,
            markdown_output
        ]
    )
    # Ad-hoc experiment + comments JSON flow.
    submit_json_2_button.click(
        fn=process_image_json_and_comments_json,
        inputs=[
            model_choice,
            experiment_json_text,
            image_upload,
            language,
            comments_json_text,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
            repetition_penalty
        ],
        outputs=[
            output,
            markdown_output
        ]
    )

if __name__ == "__main__":
    demo.queue(max_size=30).launch(share=True, mcp_server=False, ssr_mode=False, show_error=True)