Spaces:
Running
Running
| import datetime | |
| import time | |
| import json | |
| import uuid | |
| import gradio as gr | |
| import regex as re | |
| from .utils import * | |
| from .log_utils import build_logger | |
| from .constants import IMAGE_DIR | |
| from diffusers.utils import load_image | |
# igm = "image generation multi": one logger shared by the side-by-side
# (named) and battle (anonymous) views.
igm_logger = build_logger(
    "gradio_web_server_image_generation_multi",
    "gr_web_image_generation_multi.log",
)
def save_any_image(image_file, file_path):
    """Write an image to *file_path* in JPEG format.

    Args:
        image_file: Either a string, which is handed to
            ``diffusers.utils.load_image`` (presumably a local path or URL),
            or an already-loaded image object with PIL-style ``mode``,
            ``convert`` and ``save`` members.
        file_path: Destination forwarded unchanged to ``save`` (a path, or a
            file object opened in binary mode).
    """
    image = load_image(image_file) if isinstance(image_file, str) else image_file
    # JPEG cannot store alpha channels or palettes; PIL raises OSError when
    # asked to save e.g. an RGBA or P image as JPEG. Convert anything outside
    # the JPEG-writable modes to RGB first.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    mode = getattr(image, "mode", None)
    if mode is not None and mode not in jpeg_modes:
        image = image.convert("RGB")
    image.save(file_path, 'JPEG')
def vote_last_response_igm(states, vote_type, anony, request: gr.Request):
    """Append one vote/share record for a pair of generations to the conversation log.

    Args:
        states: The generation states involved (e.g. ``ImageStateIG``); each
            must expose ``model_name`` and ``dict()``.
        vote_type: Event label stored under ``"type"`` (e.g. ``"leftvote"``).
        anony: Stored verbatim under ``"anony"``.
        request: Gradio request, used to record the client IP.
    """
    data = {
        "tstamp": round(time.time(), 4),
        "type": vote_type,
        # The state classes in this module expose ``model_name``; they have
        # no ``name`` attribute, so ``x.name`` raised AttributeError.
        "models": [state.model_name for state in states],
        # Log each state's dict() (which includes conv_id) so a vote can be
        # joined to the "chat" records written when the images were generated.
        "states": [state.dict() for state in states],
        "anony": anony,
        "ip": get_ip(request),
    }
    with open(get_conv_log_filename(), "a") as fout:
        fout.write(json.dumps(data) + "\n")
    # NOTE(review): unlike the generate_* handlers, vote records are not
    # forwarded to the log server; re-enable the call below if that is wanted.
    # append_json_item_on_log_server(data, get_conv_log_filename())
| ## Image Generation Multi (IGM) Side-by-Side and Battle | |
def _vote_and_reveal_igm(state0, state1, vote_type, anony, arrows, request):
    """Log a vote on the current A/B pair and build the updates that reveal both models.

    Args:
        state0: State for Model A; must expose ``model_name``.
        state1: State for Model B; must expose ``model_name``.
        vote_type: Vote label used in the log line and passed to
            ``vote_last_response_igm`` (e.g. ``"leftvote"``).
        anony: ``True`` for the anonymous (battle) view, ``False`` for named.
        arrows: ``(arrow_a, arrow_b)`` glyphs shown before each model header.
        request: Gradio request, used for the client IP.

    Returns:
        ``(disable_btn,) * 3`` followed by two visible ``gr.Markdown`` headers,
        one for Model A (``state0``) and one for Model B (``state1``).
    """
    mode = "anony" if anony else "named"
    igm_logger.info(f"{vote_type} ({mode}). ip: {get_ip(request)}")
    vote_last_response_igm([state0, state1], vote_type, anony, request)
    arrow_a, arrow_b = arrows
    return (disable_btn,) * 3 + (
        gr.Markdown(f"### {arrow_a} Model A: {state0.model_name}", visible=True),
        gr.Markdown(f"### {arrow_b} Model B: {state1.model_name}", visible=True),
    )


def leftvote_last_response_igm(
    state0, state1, request: gr.Request
):
    """Named view: record "A is better" and reveal both models."""
    return _vote_and_reveal_igm(state0, state1, "leftvote", False, ("⬆", "⬇"), request)


def rightvote_last_response_igm(
    state0, state1, request: gr.Request
):
    """Named view: record "B is better" and reveal both models."""
    return _vote_and_reveal_igm(state0, state1, "rightvote", False, ("⬇", "⬆"), request)


def bothbad_vote_last_response_igm(
    state0, state1, request: gr.Request
):
    """Named view: record "both are bad" and reveal both models."""
    return _vote_and_reveal_igm(state0, state1, "bothbad_vote", False, ("⬇", "⬇"), request)


def leftvote_last_response_igm_anony(
    state0, state1, request: gr.Request
):
    """Battle view: record "A is better" and reveal both models."""
    return _vote_and_reveal_igm(state0, state1, "leftvote", True, ("⬆", "⬇"), request)


def rightvote_last_response_igm_anony(
    state0, state1, request: gr.Request
):
    """Battle view: record "B is better" and reveal both models."""
    return _vote_and_reveal_igm(state0, state1, "rightvote", True, ("⬇", "⬆"), request)


def bothbad_vote_last_response_igm_anony(
    state0, state1, request: gr.Request
):
    """Battle view: record "both are bad" and reveal both models."""
    return _vote_and_reveal_igm(state0, state1, "bothbad_vote", True, ("⬇", "⬇"), request)
# Browser-side JavaScript for the share button. It renders the element with
# id "share-region-named" to a canvas using html2canvas (assumed to be loaded
# on the page; TODO confirm) and triggers a download of it as
# "chatbot-arena.png". The html2canvas promise chain is not awaited, so the
# function returns its four arguments immediately. Presumably this is so that
# the Python handler bound to the same event receives them unchanged.
# Inside the second .then callback, `const a` shadows the parameter `a`; the
# final `return [a, b, c, d]` still refers to the parameters.
share_js = """
function (a, b, c, d) {
const captureElement = document.querySelector('#share-region-named');
html2canvas(captureElement)
.then(canvas => {
canvas.style.display = 'none'
document.body.appendChild(canvas)
return canvas
})
.then(canvas => {
const image = canvas.toDataURL('image/png')
const a = document.createElement('a')
a.setAttribute('download', 'chatbot-arena.png')
a.setAttribute('href', image)
a.click()
canvas.remove()
});
return [a, b, c, d];
}
"""
def share_click_igm(state0, state1, model_selector0, model_selector1, request: gr.Request):
    """Log a "share" event for the current A/B pair.

    The log line is always written. The share record is written only when
    both states are present (not ``None``).
    """
    igm_logger.info(f"share (anony). ip: {get_ip(request)}")
    if state0 is None or state1 is None:
        return
    # NOTE(review): the selector list is passed in the ``anony`` slot of
    # vote_last_response_igm, so the record's "anony" field holds these two
    # values rather than a bool. Confirm this is intended.
    selectors = [model_selector0, model_selector1]
    vote_last_response_igm([state0, state1], "share", selectors, request)
| ## All Generation Gradio Interface | |
class ImageStateIG:
    """Mutable state for one text-to-image generation (IG) request."""

    def __init__(self, model_name):
        # Random 32-character hex id for this generation; it identifies the
        # generation in the logs.
        self.conv_id = uuid.uuid4().hex
        self.model_name = model_name
        # Filled in once generation has run.
        self.prompt = None
        self.output = None

    def dict(self):
        """Return the loggable fields of this state; ``output`` is not included."""
        logged = ("conv_id", "model_name", "prompt")
        return {field: getattr(self, field) for field in logged}
class ImageStateIE:
    """Mutable state for one image-editing (IE) request."""

    def __init__(self, model_name):
        # Random 32-character hex id identifying this request.
        self.conv_id = uuid.uuid4().hex
        self.model_name = model_name
        # Prompt fields, filled in by the caller.
        self.source_prompt = None
        self.target_prompt = None
        self.instruct_prompt = None
        # Input and result images; neither is included in dict().
        self.source_image = None
        self.output = None

    def dict(self):
        """Return the loggable fields of this state (ids and prompts only)."""
        logged = (
            "conv_id",
            "model_name",
            "source_prompt",
            "target_prompt",
            "instruct_prompt",
        )
        return {field: getattr(self, field) for field in logged}
class VideoStateVG:
    """Mutable state for one video-generation (VG) request."""

    def __init__(self, model_name):
        # Random 32-character hex id identifying this request.
        self.conv_id = uuid.uuid4().hex
        self.model_name = model_name
        # Filled in once generation has run.
        self.prompt = None
        self.output = None

    def dict(self):
        """Return the loggable fields of this state; ``output`` is not included."""
        logged = ("conv_id", "model_name", "prompt")
        return {field: getattr(self, field) for field in logged}
def _apply_igm_results(states, prompt, outputs, model_names):
    """Store one round's prompt, generated outputs and final model names on *states*."""
    for state, output, model_name in zip(states, outputs, model_names):
        state.prompt = prompt
        state.output = output
        state.model_name = model_name


def _record_igm_generation(states, start_tstamp, ip):
    """Log a "chat" record per state, then save each state's output as JPEG.

    Each record is appended to the local conversation log and passed to
    ``append_json_item_on_log_server``. Each output is written to
    ``{IMAGE_DIR}/generation/{conv_id}.jpg``, and that path is passed to
    ``save_image_file_on_log_server``.
    """
    finish_tstamp = time.time()
    log_file = get_conv_log_filename()
    with open(log_file, "a") as fout:
        for state in states:
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": state.model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": ip,
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, log_file)
    for state in states:
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # Hand over the path rather than a text-mode ('w') file handle: the
        # JPEG encoder writes bytes.
        save_any_image(state.output, output_file)
        save_image_file_on_log_server(output_file)


def generate_igm(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
    """Named side-by-side text-to-image generation.

    ``gen_func(text, model_a, model_b)`` must return ``(image_a, image_b)``.
    The incoming *state0*/*state1* are ignored; fresh ``ImageStateIG``
    objects are created for every call.

    Yields:
        ``(state0, state1, image_a, image_b)`` once. Logging and image saving
        run when the consumer resumes the generator.

    Raises:
        gr.Error: if the prompt or either model name is empty.
    """
    # gr.Warning is a notification function, not an exception class, so
    # ``raise gr.Warning(...)`` fails with TypeError; gr.Error is the
    # exception Gradio shows to the user.
    if not text:
        raise gr.Error("Prompt cannot be empty.")
    if not model_name0:
        raise gr.Error("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Error("Model name B cannot be empty.")
    state0 = ImageStateIG(model_name0)
    state1 = ImageStateIG(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    # Remove "### Model A: " / "### Model B: " from the names, in case they
    # arrive as header text.
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    image0, image1 = gen_func(text, model_name0, model_name1)
    _apply_igm_results((state0, state1), text, (image0, image1), (model_name0, model_name1))
    yield state0, state1, image0, image1
    _record_igm_generation((state0, state1), start_tstamp, ip)


def generate_igm_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
    """Named side-by-side display of pre-generated ("museum") images.

    ``gen_func(model_a, model_b)`` must return ``(image_a, image_b, prompt)``.
    The incoming *state0*/*state1* are ignored.

    Yields:
        ``(state0, state1, image_a, image_b, prompt)`` once. Logging and image
        saving run when the consumer resumes the generator.

    Raises:
        gr.Error: if either model name is empty.
    """
    if not model_name0:
        raise gr.Error("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Error("Model name B cannot be empty.")
    state0 = ImageStateIG(model_name0)
    state1 = ImageStateIG(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    image0, image1, text = gen_func(model_name0, model_name1)
    _apply_igm_results((state0, state1), text, (image0, image1), (model_name0, model_name1))
    yield state0, state1, image0, image1, text
    _record_igm_generation((state0, state1), start_tstamp, ip)


def generate_igm_annoy(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
    """Anonymous (battle) text-to-image generation.

    The name presumably means "anony" (anonymous); it is kept for callers.
    ``gen_func(text, "", "")`` receives empty model names and must return
    ``(image_a, image_b, model_a, model_b)``.

    Yields:
        ``(state0, state1, image_a, image_b, header_a, header_b)`` once, with
        both model-name headers hidden (``visible=False``). Logging and image
        saving run when the consumer resumes the generator.

    Raises:
        gr.Error: if the prompt is empty.
    """
    if not text:
        raise gr.Error("Prompt cannot be empty.")
    state0 = ImageStateIG(model_name0)
    state1 = ImageStateIG(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    image0, image1, model_name0, model_name1 = gen_func(text, "", "")
    _apply_igm_results((state0, state1), text, (image0, image1), (model_name0, model_name1))
    yield state0, state1, image0, image1, \
        gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)
    _record_igm_generation((state0, state1), start_tstamp, ip)


def generate_igm_annoy_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request):
    """Anonymous (battle) display of pre-generated ("museum") images.

    ``gen_func(model_a, model_b)`` must return
    ``(image_a, image_b, model_a, model_b, prompt)``.

    Yields:
        ``(state0, state1, image_a, image_b, prompt, header_a, header_b)``
        once. Logging and image saving run when the consumer resumes the
        generator.
    """
    state0 = ImageStateIG(model_name0)
    state1 = ImageStateIG(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")
    start_tstamp = time.time()
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    image0, image1, model_name0, model_name1, text = gen_func(model_name0, model_name1)
    _apply_igm_results((state0, state1), text, (image0, image1), (model_name0, model_name1))
    # Hide the headers as generate_igm_annoy does. Without visible=False, a
    # header made visible by an earlier vote would show the new models'
    # names before the user votes.
    yield state0, state1, image0, image1, text, \
        gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)
    _record_igm_generation((state0, state1), start_tstamp, ip)