| import spaces |
| import argparse |
| from ast import parse |
| import datetime |
| import json |
| import os |
| import time |
| import hashlib |
| import re |
|
|
| import gradio as gr |
| import requests |
| import random |
| from filelock import FileLock |
| from io import BytesIO |
| from PIL import Image, ImageDraw, ImageFont |
|
|
| from constants import LOGDIR |
| from utils import ( |
| build_logger, |
| server_error_msg, |
| violates_moderation, |
| moderation_msg, |
| load_image_from_base64, |
| get_log_filename, |
| ) |
| from conversation import Conversation |
|
|
# Module-wide logger created by the project helper ``build_logger``.
logger = build_logger("gradio_web_server", "gradio_web_server.log")


# Client identification headers.
# NOTE(review): http_bot builds its own local ``headers`` dict, so this
# module-level value looks unused within this file — confirm before removing.
headers = {"User-Agent": "InternVL-Chat Client"}


# Shared button objects that the event handlers below return as output values
# for the five action buttons (upvote, downvote, flag, regenerate, clear).
no_change_btn = gr.Button()  # no properties set; presumably leaves the button as it is
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)
|
|
|
|
@spaces.GPU(duration=10)
def make_zerogpu_happy():
    """Do nothing; this function exists only to carry the ``spaces.GPU`` decorator.

    NOTE(review): presumably the ZeroGPU runtime requires at least one
    GPU-decorated function in the app — confirm against the Spaces docs.
    """
|
|
|
|
def write2file(path, content):
    """Append ``content`` to the file at ``path`` while holding ``<path>.lock``.

    Args:
        path: Destination file; it is opened in append mode.
        content: Text to append.
    """
    # Lock first, then open: the file handle is closed before the lock is released.
    with FileLock(f"{path}.lock"), open(path, "a") as sink:
        sink.write(content)
|
|
|
|
# JavaScript snippet that returns the page's query-string parameters as a
# plain object. ``url_params`` is declared with ``const`` so the snippet does
# not create an implicit global (which throws in strict mode).
get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    const url_params = Object.fromEntries(params);
    console.log(url_params);
    return url_params;
}
"""
|
|
|
|
def init_state(state=None):
    """Return a brand-new ``Conversation``.

    Args:
        state: Previous conversation, if any. It is ignored: the caller is
            expected to rebind its own reference to the returned object.
            (The former ``del state`` only unbound the local name and had no
            effect on the caller's object, so it was dropped.)

    Returns:
        A fresh ``Conversation`` instance.
    """
    return Conversation()
|
|
|
|
def find_bounding_boxes(state, response):
    """Draw the boxes referenced in ``response`` onto the latest user image.

    The response is scanned for ``<ref>label</ref><box>[[x0, y0, x1, y1], ...]</box>``
    groups. Coordinates are divided by 1000 and scaled by the image width and
    height before drawing.

    Args:
        state: Conversation object; must provide ``get_images(source=...)``
            and a ``USER`` attribute.
        response: Model output text to scan.

    Returns:
        A copy of the most recent user image with boxes and labels drawn on
        it, or ``None`` when the response has no usable box or the
        conversation holds no user image.
    """
    # Local import: the file only imports ``parse`` from ``ast``.
    import ast

    pattern = re.compile(r"<ref>\s*(.*?)\s*</ref>\s*<box>\s*(\[\[.*?\]\])\s*</box>")
    matches = pattern.findall(response)
    if not matches:
        return None

    user_images = state.get_images(source=state.USER)
    if not user_images:
        return None

    results = []
    for category_name, raw_boxes in matches:
        # The box list comes from model output, so parse it as a literal
        # instead of evaluating it as arbitrary Python.
        try:
            parsed = ast.literal_eval(raw_boxes)
        except (ValueError, SyntaxError, TypeError):
            continue  # malformed box list: skip this reference
        boxes = []
        for raw in parsed:
            try:
                fractions = [float(v) / 1000 for v in raw[:4]]
            except (TypeError, ValueError):
                continue
            if len(fractions) == 4:
                boxes.append(fractions)
        if boxes:
            results.append((category_name, boxes))
    if not results:
        return None

    returned_image = user_images[-1].copy()
    width, height = returned_image.size
    draw = ImageDraw.Draw(returned_image)
    # Line width and font size depend only on the image, so compute them once.
    line_width = max(1, int(min(width, height) / 200))
    font = ImageFont.truetype("assets/SimHei.ttf", int(20 * line_width / 2))

    for category_name, boxes in results:
        # One dark-ish random colour per referenced object (white label text stays readable).
        color = (
            random.randint(0, 128),
            random.randint(0, 128),
            random.randint(0, 128),
        )
        text_bbox = font.getbbox(category_name)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]
        for x0, y0, x1, y1 in boxes:
            # Order the corners so a reversed box does not make Pillow raise.
            left, right = sorted((int(x0 * width), int(x1 * width)))
            top, bottom = sorted((int(y0 * height), int(y1 * height)))
            draw.rectangle((left, top, right, bottom), outline=color, width=line_width)
            # Label sits just above the box, clamped to the top edge of the image.
            label_origin = (left, max(0, top - text_height))
            draw.rectangle(
                [
                    label_origin,
                    (label_origin[0] + text_width, label_origin[1] + text_height),
                ],
                fill=color,
            )
            draw.text(label_origin, category_name, fill="white", font=font)
    return returned_image
|
|
|
|
def vote_last_response(state, liked, request: gr.Request):
    """Append one feedback record for the conversation to the log file.

    Args:
        state: Conversation whose ``dict()`` snapshot is stored in the record.
        liked: Feedback value stored under ``"like"``.
        request: Gradio request; the client host is stored under ``"ip"``.
    """
    record = {
        "tstamp": round(time.time(), 4),
        "like": liked,
        "model": 'InternVL2.5-78B',
        "state": state.dict(),
        "ip": request.client.host,
    }
    # One JSON object per line.
    line = json.dumps(record) + "\n"
    write2file(get_log_filename(), line)
|
|
|
|
def upvote_last_response(state, request: gr.Request):
    """Log an upvote for the conversation.

    Returns:
        A 4-tuple: an interactive, empty textbox followed by ``disable_btn``
        three times (one per voting button wired to this handler).
    """
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, True, request)
    return (
        gr.MultimodalTextbox(value=None, interactive=True),
        disable_btn,
        disable_btn,
        disable_btn,
    )
|
|
|
|
def downvote_last_response(state, request: gr.Request):
    """Log a downvote for the conversation.

    Returns:
        A 4-tuple: an interactive, empty textbox followed by ``disable_btn``
        three times (one per voting button wired to this handler).
    """
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, False, request)
    return (
        gr.MultimodalTextbox(value=None, interactive=True),
        disable_btn,
        disable_btn,
        disable_btn,
    )
|
|
|
|
def vote_selected_response(state, request: gr.Request, data: gr.LikeData):
    """Log a like/dislike placed on one specific chatbot message.

    Args:
        state: Conversation whose ``dict()`` snapshot is stored in the record.
        request: Gradio request; the client host is stored under ``"ip"``.
        data: Gradio like-event payload; ``liked`` and ``index`` are stored,
            ``value`` is only written to the application log.
    """
    logger.info(
        f"Vote: {data.liked}, index: {data.index}, value: {data.value} , ip: {request.client.host}"
    )
    record = {
        "tstamp": round(time.time(), 4),
        "like": data.liked,
        "index": data.index,
        "model": 'InternVL2.5-78B',
        "state": state.dict(),
        "ip": request.client.host,
    }
    line = json.dumps(record) + "\n"
    write2file(get_log_filename(), line)
|
|
|
|
def flag_last_response(state, request: gr.Request):
    """Log a "flag" report for the conversation.

    Returns:
        A 4-tuple: an interactive, empty textbox followed by ``disable_btn``
        three times (one per voting button wired to this handler).
    """
    logger.info(f"flag. ip: {request.client.host}")
    vote_last_response(state, "flag", request)
    return (
        gr.MultimodalTextbox(value=None, interactive=True),
        disable_btn,
        disable_btn,
        disable_btn,
    )
|
|
|
|
def regenerate(state, image_process_mode, request: gr.Request):
    """Blank the last assistant message so the next step can generate it again.

    Args:
        state: Conversation to modify in place.
        image_process_mode: Value written as the third element of the
            previous message's payload when that payload is a tuple or list.
        request: Gradio request (used for logging the client host).

    Returns:
        ``(state, chatbot_value, textbox)`` followed by ``disable_btn`` five times.
    """
    logger.info(f"regenerate. ip: {request.client.host}")

    state.update_message(Conversation.ASSISTANT, content='', image=None, idx=-1)

    # Second-to-last entry is the message that preceded the blanked reply.
    previous_turn = state.messages[-2]
    payload = previous_turn[1]
    if type(payload) in (tuple, list):
        previous_turn[1] = (*payload[:2], image_process_mode)

    state.skip_next = False
    textbox = gr.MultimodalTextbox(value=None, interactive=True)
    return (
        state,
        state.to_gradio_chatbot(),
        textbox,
        disable_btn,
        disable_btn,
        disable_btn,
        disable_btn,
        disable_btn,
    )
|
|
|
|
def clear_history(request: gr.Request):
    """Start over with a fresh conversation.

    Returns:
        ``(state, chatbot_value, textbox)`` followed by ``disable_btn`` five
        times, where ``state`` is a new conversation from ``init_state``.
    """
    logger.info(f"clear_history. ip: {request.client.host}")
    fresh_state = init_state()
    textbox = gr.MultimodalTextbox(value=None, interactive=True)
    return (
        fresh_state,
        fresh_state.to_gradio_chatbot(),
        textbox,
        disable_btn,
        disable_btn,
        disable_btn,
        disable_btn,
        disable_btn,
    )
|
|
|
|
def add_text(state, message, system_prompt, request: gr.Request):
    """Append the user's submission (text and/or images) to the conversation.

    Args:
        state: Current conversation, or a falsy value to start a new one.
        message: Multimodal textbox value; read through its ``"files"`` and
            ``"text"`` keys.
        system_prompt: System message applied to the conversation.
        request: Gradio request (used for logging the client host).

    Returns:
        ``(state, chatbot_value, textbox)`` followed by five button values.
        ``state.skip_next`` is set to ``True`` when there is nothing to send
        (empty submission or moderation hit) so ``http_bot`` skips the turn.
    """
    if not state:
        state = init_state()
    # ``or`` guards against keys that are present but set to None.
    image_paths = message.get("files") or []
    text = (message.get("text") or "").strip()
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")

    textbox = gr.MultimodalTextbox(value=None, interactive=False)
    if not text and not image_paths:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), textbox) + (no_change_btn,) * 5

    # ``args`` only exists when the file is run as a script; treat a missing
    # value as "moderation off" instead of raising NameError.
    moderate = getattr(globals().get("args"), "moderate", False)
    if moderate and violates_moderation(text):
        state.skip_next = True
        textbox = gr.MultimodalTextbox(
            value={"text": moderation_msg}, interactive=True
        )
        return (state, state.to_gradio_chatbot(), textbox) + (no_change_btn,) * 5

    images = []
    for path in image_paths:
        # convert() returns an independent in-memory copy, so the file handle
        # can be closed right away.
        with Image.open(path) as opened:
            images.append(opened.convert("RGB"))

    # A new image on top of a conversation that already has user images
    # starts a fresh conversation.
    if len(images) > 0 and len(state.get_images(source=state.USER)) > 0:
        state = init_state(state)
    state.set_system_message(system_prompt)
    state.append_message(Conversation.USER, text, images)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), textbox) + (disable_btn,) * 5
|
|
|
|
def _parse_stream_line(raw_line):
    """Extract the text delta carried by one streamed response line.

    Returns:
        ``None`` when the line is the ``[DONE]`` terminator, ``""`` when the
        line carries no text (blank line, ``:`` comment line, empty or
        missing ``content``), otherwise the delta text.

    Raises:
        ValueError: the line is not valid UTF-8/JSON or lacks the expected
            ``choices[0].delta`` structure.
    """
    line = raw_line.decode().strip()
    if not line or line.startswith(":"):
        return ""
    if line.startswith("data:"):
        line = line[5:].strip()
    if not line:
        return ""
    if line == "[DONE]":
        return None
    payload = json.loads(line)  # JSONDecodeError is a ValueError
    try:
        delta = payload["choices"][0]["delta"]
        return delta.get("content") or ""
    except (LookupError, TypeError, AttributeError) as exc:
        raise ValueError(f"unexpected stream payload: {line!r}") from exc


def http_bot(
    state,
    temperature,
    top_p,
    repetition_penalty,
    max_new_tokens,
    max_input_tiles,
    request: gr.Request,
):
    """Stream the model's reply for the current conversation into the UI.

    This is a generator: every yielded value is
    ``(state, chatbot_value, textbox)`` followed by five button values.

    Args:
        state: Conversation to extend with the assistant reply.
        temperature: Sampling temperature sent to the worker.
        top_p: Nucleus-sampling value sent to the worker.
        repetition_penalty: Repetition penalty sent to the worker.
        max_new_tokens: Sent to the worker as ``max_tokens``.
        max_input_tiles: Passed to ``state.get_prompt_v2`` as ``max_dynamic_patch``.
        request: Gradio request (client host is logged).

    Environment:
        ``WORKER_ADDR``: URL the request is POSTed to; when empty the server
        error message is shown instead.
        ``API_TOKEN``: value of the ``Authorization`` header.
    """
    model_name = 'InternVL2.5-78B'
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    if getattr(state, "skip_next", False):
        # add_text flagged this turn as "nothing to send".
        yield (
            state,
            state.to_gradio_chatbot(),
            gr.MultimodalTextbox(interactive=False),
        ) + (no_change_btn,) * 5
        return

    worker_addr = os.environ.get("WORKER_ADDR", "")
    api_token = os.environ.get("API_TOKEN", "")
    # Named differently from the module-level ``headers`` to avoid shadowing it.
    request_headers = {"Authorization": f"{api_token}", "Content-Type": "application/json"}

    if worker_addr == "":
        state.update_message(Conversation.ASSISTANT, server_error_msg)
        yield (
            state,
            state.to_gradio_chatbot(),
            gr.MultimodalTextbox(interactive=False),
            disable_btn,
            disable_btn,
            disable_btn,
            enable_btn,
            enable_btn,
        )
        return

    all_images = state.get_images(source=state.USER)
    all_image_paths = [state.save_image(image) for image in all_images]

    pload = {
        "model": model_name,
        "messages": state.get_prompt_v2(inlude_image=True, max_dynamic_patch=max_input_tiles),
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "stream": True
    }
    logger.info(f"==== request ====\n{pload}")
    state.append_message(Conversation.ASSISTANT, state.streaming_placeholder)
    yield (
        state,
        state.to_gradio_chatbot(),
        gr.MultimodalTextbox(interactive=False),
    ) + (disable_btn,) * 5

    final_output = ''
    try:
        # The context manager releases the connection even if the stream is
        # abandoned midway (break, error, or the generator being closed).
        with requests.post(
            worker_addr, json=pload, headers=request_headers, stream=True, timeout=40
        ) as response:
            # An HTTP error status becomes a RequestException handled below.
            response.raise_for_status()
            for raw_line in response.iter_lines(decode_unicode=False, delimiter=b"\n"):
                delta_text = _parse_stream_line(raw_line)
                if delta_text is None:
                    break
                if not delta_text:
                    continue
                final_output += delta_text
                state.update_message(
                    Conversation.ASSISTANT, final_output + state.streaming_placeholder, None
                )
                yield (
                    state,
                    state.to_gradio_chatbot(),
                    gr.MultimodalTextbox(interactive=False),
                ) + (disable_btn,) * 5
    except (requests.exceptions.RequestException, ValueError) as e:
        # Transport failures and malformed stream payloads both end the turn
        # with the generic server error message.
        logger.error(f"http_bot. worker request failed: {e}")
        state.update_message(Conversation.ASSISTANT, server_error_msg, None)
        yield (
            state,
            state.to_gradio_chatbot(),
            gr.MultimodalTextbox(interactive=True),
        ) + (
            disable_btn,
            disable_btn,
            disable_btn,
            enable_btn,
            enable_btn,
        )
        return

    ai_response = state.return_last_message()
    if "<ref>" in ai_response:
        # Grounding output: attach the user image annotated with the boxes.
        returned_image = find_bounding_boxes(state, ai_response)
        returned_image = [returned_image] if returned_image else []
        state.update_message(Conversation.ASSISTANT, ai_response, returned_image)

    state.end_of_current_turn()

    yield (
        state,
        state.to_gradio_chatbot(),
        gr.MultimodalTextbox(interactive=True),
    ) + (enable_btn,) * 5

    # Runs when the consumer advances the generator past the final yield.
    finish_tstamp = time.time()
    logger.info(f"{final_output}")
    data = {
        "tstamp": round(finish_tstamp, 4),
        "like": None,
        "model": model_name,
        "start": round(start_tstamp, 4),
        "finish": round(finish_tstamp, 4),
        "state": state.dict(),
        "images": all_image_paths,
        "ip": request.client.host,
    }
    write2file(get_log_filename(), json.dumps(data) + "\n")
|
|
| |
# HTML banner (logo, tagline and project links) that build_demo renders with
# gr.HTML in the left-hand column.
title_html = """
<img src="https://internvl.opengvlab.com/assets/logo-47b364d3.jpg" style="width: 280px; height: 70px;">
<p>InternVL2.5 Expanding Performance Boundaries of Open-Source Multimodal Models with Model, Data, and Test-Time Scaling</p>
<a href="https://internvl.github.io/blog/2024-12-05-InternVL-2.5/">[📜 InternVL Blog]</a>
<a href="https://internvl.opengvlab.com/">[🌟 Official Demo]</a>
<a href="https://github.com/OpenGVLab/InternVL?tab=readme-ov-file#quick-start-with-huggingface">[🚀 Quick Start]</a>
"""
|
|
|
|
| |
# Custom CSS passed to gr.Blocks(css=...) in build_demo.
# The first rule must not be followed by a ';': a stray semicolon after the
# closing brace becomes part of the next rule's selector and invalidates the
# "#buttons button" rule.
block_css = """
.gradio-container {margin: 0.1% 1% 0 1% !important; max-width: 98% !important;}
#buttons button {
    min-width: min(120px,100%);
}

.gradient-text {
    font-size: 28px;
    width: auto;
    font-weight: bold;
    background: linear-gradient(45deg, red, orange, yellow, green, blue, indigo, violet);
    background-clip: text;
    -webkit-background-clip: text;
    color: transparent;
}

.plain-text {
    font-size: 22px;
    width: auto;
    font-weight: bold;
}
"""
|
|
# JavaScript that animates a rainbow gradient on the element with id "text"
# and returns the page's query-string parameters.
# NOTE(review): nothing in this file passes ``js`` to Gradio — confirm whether
# it is still needed.
# Changes from the previous snippet: the animation only starts when the
# element exists (otherwise every tick would throw on ``null``), and
# ``url_params`` is declared instead of leaking as an implicit global.
js = """
function createWaveAnimation() {
    const text = document.getElementById('text');
    var i = 0;
    if (text) {
        setInterval(function() {
            const colors = [
                'red, orange, yellow, green, blue, indigo, violet, purple',
                'orange, yellow, green, blue, indigo, violet, purple, red',
                'yellow, green, blue, indigo, violet, purple, red, orange',
                'green, blue, indigo, violet, purple, red, orange, yellow',
                'blue, indigo, violet, purple, red, orange, yellow, green',
                'indigo, violet, purple, red, orange, yellow, green, blue',
                'violet, purple, red, orange, yellow, green, blue, indigo',
                'purple, red, orange, yellow, green, blue, indigo, violet',
            ];
            const angle = 45;
            const colorIndex = i % colors.length;
            text.style.background = `linear-gradient(${angle}deg, ${colors[colorIndex]})`;
            text.style.webkitBackgroundClip = 'text';
            text.style.backgroundClip = 'text';
            text.style.color = 'transparent';
            text.style.fontSize = '28px';
            text.style.width = 'auto';
            text.textContent = 'InternVL2';
            text.style.fontWeight = 'bold';
            i += 1;
        }, 200);
    }
    const params = new URLSearchParams(window.location.search);
    const url_params = Object.fromEntries(params);
    // console.log(url_params);
    // console.log('hello world...');
    // console.log(window.location.search);
    // console.log('hello world...');
    // alert(window.location.search)
    // alert(url_params);
    return url_params;
}

"""
|
|
|
|
def build_demo():
    """Assemble the Gradio UI and wire its event handlers.

    Returns:
        The configured ``gr.Blocks`` application.
    """
    # Created before the Blocks context and rendered later inside the chat
    # column, so gr.Examples can reference it before it is placed.
    textbox = gr.MultimodalTextbox(
        interactive=True,
        file_types=["image", "video"],
        placeholder="Enter message or upload file...",
        show_label=False,
    )

    with gr.Blocks(
        title="InternVL-Chat",
        theme=gr.themes.Default(),
        css=block_css,
    ) as demo:
        state = gr.State()

        with gr.Row():
            # Left column: banner, generation settings and example prompts.
            with gr.Column(scale=2):
                gr.HTML(title_html)

                with gr.Accordion("Settings", open=False):
                    system_prompt = gr.Textbox(
                        value="请尽可能详细地回答用户的问题。",
                        label="System Prompt",
                        interactive=True,
                    )
                    temperature = gr.Slider(
                        label="Temperature",
                        minimum=0.0,
                        maximum=1.0,
                        value=0.2,
                        step=0.1,
                        interactive=True,
                    )
                    top_p = gr.Slider(
                        label="Top P",
                        minimum=0.0,
                        maximum=1.0,
                        value=0.7,
                        step=0.1,
                        interactive=True,
                    )
                    repetition_penalty = gr.Slider(
                        label="Repetition penalty",
                        minimum=1.0,
                        maximum=1.5,
                        value=1.1,
                        step=0.02,
                        interactive=True,
                    )
                    max_output_tokens = gr.Slider(
                        label="Max output tokens",
                        minimum=0,
                        maximum=4096,
                        value=1024,
                        step=64,
                        interactive=True,
                    )
                    max_input_tiles = gr.Slider(
                        label="Max input tiles (control the image size)",
                        minimum=1,
                        maximum=32,
                        value=12,
                        step=1,
                        interactive=True,
                    )
                gr.Examples(
                    examples=[
                        [
                            {
                                "files": ["gallery/14.jfif"],
                                "text": "Please help me analyze this picture.",
                            }
                        ],
                        [
                            {
                                "files": ["gallery/1-2.PNG"],
                                "text": "Implement this flow chart using python",
                            }
                        ],
                        [
                            {
                                "files": ["gallery/15.PNG"],
                                "text": "Please help me analyze this picture.",
                            }
                        ],
                    ],
                    inputs=[textbox],
                )

            # Right column: chat window, input row and action buttons.
            with gr.Column(scale=8):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="InternVL",
                    height=580,
                    show_copy_button=True,
                    show_share_button=True,
                    avatar_images=[
                        "assets/human.png",
                        "assets/assistant.png",
                    ],
                    bubble_full_width=False,
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")
                with gr.Row(elem_id="buttons"):
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

        # Hidden component; not wired to any event in this function.
        gr.JSON(visible=False)

        # ---- Event wiring ----
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        vote_outputs = [textbox, upvote_btn, downvote_btn, flag_btn]
        chat_outputs = [state, chatbot, textbox] + btn_list
        generation_inputs = [
            state,
            temperature,
            top_p,
            repetition_penalty,
            max_output_tokens,
            max_input_tiles,
        ]

        upvote_btn.click(upvote_last_response, [state], vote_outputs)
        downvote_btn.click(downvote_last_response, [state], vote_outputs)
        chatbot.like(vote_selected_response, [state], [])
        flag_btn.click(flag_last_response, [state], vote_outputs)

        # The system prompt is what arrives as regenerate's second argument.
        regenerate_btn.click(
            regenerate,
            [state, system_prompt],
            chat_outputs,
        ).then(http_bot, generation_inputs, chat_outputs)

        clear_btn.click(clear_history, None, chat_outputs)

        # Pressing Enter and clicking "Send" run the same two-step pipeline.
        textbox.submit(
            add_text,
            [state, textbox, system_prompt],
            chat_outputs,
        ).then(http_bot, generation_inputs, chat_outputs)
        submit_btn.click(
            add_text,
            [state, textbox, system_prompt],
            chat_outputs,
        ).then(http_bot, generation_inputs, chat_outputs)

    return demo
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options, build the UI and serve it.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    # ``args`` stays a module-level name: add_text reads ``args.moderate``.
    args = parser.parse_args()
    # Logged once (the arguments used to be written to the log twice).
    logger.info(f"args: {args}")

    demo = build_demo()
    demo.queue(api_open=False).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        max_threads=args.concurrency_count,
    )
|
|
|
|