Spaces:
Running
Running
| ## All Generation Gradio Interface | |
| import uuid | |
| import time | |
| from .utils import * | |
| from .vote_utils import t2s_logger, t2s_multi_logger, i2s_logger, i2s_multi_logger | |
| from .constants import IMAGE_DIR, OFFLINE_DIR, TEXT_PROMPT_PATH | |
# Load the curated text prompts once at import time. They are indexed by
# position (the index names the cached offline videos) and searched with
# ``list.index``; presumably a JSON list of strings — verify against the file.
# JSON text is UTF-8, so decode explicitly instead of relying on the platform's
# default locale encoding.
with open(TEXT_PROMPT_PATH, 'r', encoding='utf-8') as f:
    prompt_list = json.load(f)
class State:
    """Per-panel state for one model output in the generation arena.

    Records which model produced the output, the input it was given (text
    prompt or image), whether the output is served from the pre-rendered
    offline cache (and which cached sample), and the rendered video paths.
    """

    def __init__(self,
                 model_name, i2s_mode=False, offline=False,
                 prompt=None, image=None, offline_idx=None,
                 normal_video=None, rgb_video=None):
        self.conv_id = uuid.uuid4().hex  # unique id, included in every logged record
        self.model_name = model_name
        self.i2s_mode = i2s_mode        # True for image-to-shape, False for text-to-shape
        self.offline = offline          # True when videos come from the offline cache
        self.prompt = prompt
        self.image = image
        self.offline_idx = offline_idx  # index of the cached sample; 0 is a valid index
        self.normal_video = normal_video
        self.rgb_video = rgb_video

    def dict(self):
        """Return a JSON-serialisable summary of this state for logging.

        ``offline_idx`` is included only for offline (cached) results, where
        it identifies which pre-rendered sample was shown.
        """
        base = {
            "conv_id": self.conv_id,
            "model_name": self.model_name,
            "i2s_mode": self.i2s_mode,
            "offline": self.offline,
            "prompt": self.prompt,
        }
        # Compare against None rather than truthiness so index 0 is kept.
        if self.offline and self.offline_idx is not None:
            base['offline_idx'] = self.offline_idx
        return base
def sample_t2s_model(state_0, state_1, model_list):
    """Pick two distinct models at random for a text-to-shape battle.

    Args:
        state_0, state_1: existing ``State`` objects for panels A and B, or
            ``None`` to create fresh ones.
        model_list: candidate model names, either a list/tuple or the string
            form of one (e.g. ``"['model_a', 'model_b']"``).

    Returns:
        ``(state_0, state_1, model_name_0, model_name_1)`` with both states
        switched to the sampled models and to text-to-shape mode.

    Raises:
        ValueError / SyntaxError: if ``model_list`` is a string that is not a
            Python literal. ValueError: if it holds fewer than two models.
    """
    import ast

    # Parse the string as a literal only; ``eval`` would execute arbitrary
    # expressions supplied in ``model_list``.
    names = ast.literal_eval(model_list) if isinstance(model_list, str) else model_list
    model_name_0, model_name_1 = random.sample(list(names), 2)
    if state_0 is None:
        state_0 = State(model_name_0, i2s_mode=False)
    if state_1 is None:
        state_1 = State(model_name_1, i2s_mode=False)
    state_0.model_name = model_name_0
    state_0.i2s_mode = False
    state_1.model_name = model_name_1
    state_1.i2s_mode = False
    return state_0, state_1, model_name_0, model_name_1
def sample_i2s_model(state_0, state_1, model_list):
    """Pick two distinct models at random for an image-to-shape battle.

    Args:
        state_0, state_1: existing ``State`` objects for panels A and B, or
            ``None`` to create fresh ones.
        model_list: candidate model names, either a list/tuple or the string
            form of one (e.g. ``"['model_a', 'model_b']"``).

    Returns:
        ``(state_0, state_1, model_name_0, model_name_1)`` with both states
        switched to the sampled models and to image-to-shape mode.

    Raises:
        ValueError / SyntaxError: if ``model_list`` is a string that is not a
            Python literal. ValueError: if it holds fewer than two models.
    """
    import ast

    # Parse the string as a literal only; ``eval`` would execute arbitrary
    # expressions supplied in ``model_list``.
    names = ast.literal_eval(model_list) if isinstance(model_list, str) else model_list
    model_name_0, model_name_1 = random.sample(list(names), 2)
    if state_0 is None:
        state_0 = State(model_name_0, i2s_mode=True)
    if state_1 is None:
        state_1 = State(model_name_1, i2s_mode=True)
    state_0.model_name = model_name_0
    state_0.i2s_mode = True
    state_1.model_name = model_name_1
    state_1.i2s_mode = True
    return state_0, state_1, model_name_0, model_name_1
def sample_prompt(state, model_name):
    """Draw a random curated text prompt for a single-model panel.

    A fresh ``State`` is created for ``model_name`` when ``state`` is None.
    The model name and the drawn prompt are stored on the state.

    Returns:
        ``(state, prompt)``.
    """
    state = State(model_name) if state is None else state
    chosen = prompt_list[random.randrange(len(prompt_list))]
    state.model_name, state.prompt = model_name, chosen
    return state, chosen
def sample_prompt_side_by_side(state_0, state_1, model_name_0, model_name_1):
    """Draw one curated text prompt and assign it to both battle panels.

    Missing states are created for their model names. Both states are marked
    offline with the same prompt index, so the cached videos for that index
    can be served.

    Returns:
        ``(state_0, state_1, prompt)``.
    """
    if state_0 is None:
        state_0 = State(model_name_0)
    if state_1 is None:
        state_1 = State(model_name_1)
    shared_idx = random.randrange(len(prompt_list))
    shared_prompt = prompt_list[shared_idx]
    for panel in (state_0, state_1):
        panel.offline = True
        panel.offline_idx = shared_idx
        panel.prompt = shared_prompt
    return state_0, state_1, shared_prompt
def sample_image(state, model_name):
    """Draw a random curated sample for a single image-to-shape panel.

    NOTE(review): this draws from ``prompt_list`` (the text prompts) and
    stores the result in ``state.prompt``; it does not touch ``state.image``
    or ``i2s_mode``. Confirm this is intended for the image arena.

    Returns:
        ``(state, prompt)``.
    """
    state = State(model_name) if state is None else state
    chosen = prompt_list[random.randrange(len(prompt_list))]
    state.model_name, state.prompt = model_name, chosen
    return state, chosen
def sample_image_side_by_side(state_0, state_1, model_name_0, model_name_1):
    """Draw one curated sample and assign it to both image-battle panels.

    Missing states are created for their model names. Both states are marked
    offline with the same index.

    NOTE(review): the sample is drawn from ``prompt_list`` (the text prompts)
    and stored in ``.prompt``; confirm this matches the image arena's needs.

    Returns:
        ``(state_0, state_1, prompt)``.
    """
    if state_0 is None:
        state_0 = State(model_name_0)
    if state_1 is None:
        state_1 = State(model_name_1)
    shared_idx = random.randrange(len(prompt_list))
    shared_prompt = prompt_list[shared_idx]
    for panel in (state_0, state_1):
        panel.offline = True
        panel.offline_idx = shared_idx
        panel.prompt = shared_prompt
    return state_0, state_1, shared_prompt
def generate_t2s(gen_func, render_func,
                 state,
                 text,
                 model_name,
                 request: gr.Request):
    """Produce (or fetch from cache) a text-to-shape result for one model.

    If ``text`` is one of the curated prompts in ``prompt_list``, the
    pre-rendered videos under ``OFFLINE_DIR/text2shape/<model>/`` are served.
    Otherwise ``gen_func`` builds a shape and ``render_func`` renders it.
    A JSON record of the request is appended to the conversation log.

    Args:
        gen_func: callable ``(text, model_name) -> shape``.
        render_func: callable ``(shape, model_name) -> (normal_video, rgb_video)``.
        state: this panel's ``State``, or ``None`` to create one.
        text: the text prompt.
        model_name: the model to run.
        request: the Gradio request, used to log the client IP.

    Yields:
        ``(state, normal_video, rgb_video)``.

    Raises:
        gr.Error: if ``text`` or ``model_name`` is empty.
    """
    # gr.Warning only shows a toast and returns None, so it cannot be raised;
    # gr.Error is the exception Gradio turns into a user-visible message.
    if not text:
        raise gr.Error("Prompt cannot be empty.")
    if not model_name:
        raise gr.Error("Model name cannot be empty.")
    if state is None:
        state = State(model_name, i2s_mode=False, offline=False)
    ip = get_ip(request)
    t2s_logger.info(f"generate. ip: {ip}")
    state.model_name = model_name
    state.prompt = text
    try:
        idx = prompt_list.index(text)
    except ValueError:
        # Not a curated prompt: no cached videos exist, so generate online.
        idx = None
    state.offline = idx is not None
    state.offline_idx = idx

    start_time = time.time()
    if state.offline and state.offline_idx is not None:
        # Curated prompt: serve the pre-rendered videos for this index.
        model_dir = os.path.join(OFFLINE_DIR, "text2shape", model_name)
        normal_video = os.path.join(model_dir, "normal", f"{state.offline_idx}.mp4")
        rgb_video = os.path.join(model_dir, "rgb", f"{state.offline_idx}.mp4")
        state.normal_video = normal_video
        state.rgb_video = rgb_video
        yield state, normal_video, rgb_video
        data = {
            "ip": ip,
            "model": model_name,
            "type": "offline",
            "gen_params": {},
            "state": state.dict(),
            "start": round(start_time, 4),
        }
    else:
        shape = gen_func(text, model_name)
        generate_time = time.time() - start_time
        normal_video, rgb_video = render_func(shape, model_name)
        finish_time = time.time()
        render_time = finish_time - start_time - generate_time
        state.normal_video = normal_video
        state.rgb_video = rgb_video
        yield state, normal_video, rgb_video
        data = {
            "ip": ip,
            "model": model_name,
            "type": "online",
            "gen_params": {},
            "state": state.dict(),
            "start": round(start_time, 4),
            "time": round(finish_time - start_time, 4),
            "generate_time": round(generate_time, 4),
            "render_time": round(render_time, 4),
        }
    log_file = get_conv_log_filename()
    with open(log_file, "a") as fout:
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, log_file)
def generate_t2s_multi(gen_func, render_func,
                       state_0, state_1,
                       text,
                       model_name_0, model_name_1,
                       request: gr.Request):
    """Produce text-to-shape results for two named models side by side.

    Curated prompts (members of ``prompt_list``) are served from the
    pre-rendered videos under ``OFFLINE_DIR/text2shape/<model>/``. Any other
    prompt is generated with ``gen_func`` and rendered with ``render_func``.
    One JSON record per model is appended to the conversation log.

    Args:
        gen_func: callable ``(text, model_name_0, model_name_1) -> (shape_0, shape_1)``.
        render_func: callable ``(shape_0, model_name_0, shape_1, model_name_1)``
            returning ``(normal_video_0, rgb_video_0, normal_video_1, rgb_video_1)``.
        state_0, state_1: panel states for models A and B, or ``None``.
        text: the prompt shared by both models.
        model_name_0, model_name_1: models A and B.
        request: the Gradio request, used to log the client IP.

    Yields:
        ``(state_0, state_1, normal_video_0, rgb_video_0, normal_video_1, rgb_video_1)``.

    Raises:
        gr.Error: if the prompt or either model name is empty.
    """
    # gr.Warning only shows a toast and returns None, so it cannot be raised.
    if not text:
        raise gr.Error("Prompt cannot be empty.")
    if not model_name_0:
        raise gr.Error("Model name A cannot be empty.")
    if not model_name_1:
        raise gr.Error("Model name B cannot be empty.")
    if state_0 is None:
        state_0 = State(model_name_0, i2s_mode=False, offline=False)
    if state_1 is None:
        state_1 = State(model_name_1, i2s_mode=False, offline=False)
    ip = get_ip(request)
    t2s_multi_logger.info(f"generate. ip: {ip}")
    state_0.model_name, state_1.model_name = model_name_0, model_name_1
    state_0.prompt, state_1.prompt = text, text
    try:
        idx = prompt_list.index(text)
    except ValueError:
        # Not a curated prompt: no cached videos exist, so generate online.
        idx = None
    for state in (state_0, state_1):
        state.offline = idx is not None
        state.offline_idx = idx

    start_time = time.time()

    def _record(model_name, state, log_type, timing=None):
        # One log line per model; timing fields are present only for online runs.
        record = {
            "ip": ip,
            "model": model_name,
            "type": log_type,
            "gen_params": {},
            "state": state.dict(),
            "start": round(start_time, 4),
        }
        record.update(timing or {})
        return record

    if state_0.offline and state_0.offline_idx is not None \
            and state_1.offline and state_1.offline_idx is not None:
        # Curated prompt: serve the pre-rendered videos for this index.
        dir_0 = os.path.join(OFFLINE_DIR, "text2shape", model_name_0)
        dir_1 = os.path.join(OFFLINE_DIR, "text2shape", model_name_1)
        normal_video_0 = os.path.join(dir_0, "normal", f"{state_0.offline_idx}.mp4")
        rgb_video_0 = os.path.join(dir_0, "rgb", f"{state_0.offline_idx}.mp4")
        normal_video_1 = os.path.join(dir_1, "normal", f"{state_1.offline_idx}.mp4")
        rgb_video_1 = os.path.join(dir_1, "rgb", f"{state_1.offline_idx}.mp4")
        state_0.normal_video, state_0.rgb_video = normal_video_0, rgb_video_0
        state_1.normal_video, state_1.rgb_video = normal_video_1, rgb_video_1
        yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_1, rgb_video_1
        data_0 = _record(model_name_0, state_0, "offline")
        data_1 = _record(model_name_1, state_1, "offline")
    else:
        shape_0, shape_1 = gen_func(text, model_name_0, model_name_1)
        generate_time = time.time() - start_time
        normal_video_0, rgb_video_0, normal_video_1, rgb_video_1 = render_func(shape_0, model_name_0,
                                                                               shape_1, model_name_1)
        finish_time = time.time()
        render_time = finish_time - start_time - generate_time
        state_0.normal_video, state_0.rgb_video = normal_video_0, rgb_video_0
        state_1.normal_video, state_1.rgb_video = normal_video_1, rgb_video_1
        yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_1, rgb_video_1
        timing = {
            "time": round(finish_time - start_time, 4),
            "generate_time": round(generate_time, 4),
            "render_time": round(render_time, 4),
        }
        data_0 = _record(model_name_0, state_0, "online", timing)
        data_1 = _record(model_name_1, state_1, "online", timing)
    log_file = get_conv_log_filename()
    with open(log_file, "a") as fout:
        fout.write(json.dumps(data_0) + "\n")
        fout.write(json.dumps(data_1) + "\n")
    append_json_item_on_log_server(data_0, log_file)
    append_json_item_on_log_server(data_1, log_file)
def generate_t2s_multi_annoy(gen_func, render_func,
                             state_0, state_1,
                             text,
                             model_name_0, model_name_1,
                             request: gr.Request):
    """Anonymous-battle variant of ``generate_t2s_multi``.

    Behaves like the named battle. In addition, it yields two Markdown
    headers that reveal the model names ("Model A" / "Model B") alongside
    the videos.

    Args:
        gen_func: callable ``(text, model_name_0, model_name_1) -> (shape_0, shape_1)``.
        render_func: callable ``(shape_0, model_name_0, shape_1, model_name_1)``
            returning ``(normal_video_0, rgb_video_0, normal_video_1, rgb_video_1)``.
        state_0, state_1: panel states for models A and B, or ``None``.
        text: the prompt shared by both models.
        model_name_0, model_name_1: models A and B.
        request: the Gradio request, used to log the client IP.

    Yields:
        ``(state_0, state_1, normal_video_0, rgb_video_0, normal_video_1,
        rgb_video_1, markdown_a, markdown_b)``.

    Raises:
        gr.Error: if the prompt is empty.
    """
    # gr.Warning only shows a toast and returns None, so it cannot be raised.
    if not text:
        raise gr.Error("Prompt cannot be empty.")
    if state_0 is None:
        state_0 = State(model_name_0, i2s_mode=False, offline=False)
    if state_1 is None:
        state_1 = State(model_name_1, i2s_mode=False, offline=False)
    ip = get_ip(request)
    t2s_multi_logger.info(f"generate. ip: {ip}")
    state_0.model_name, state_1.model_name = model_name_0, model_name_1
    state_0.prompt, state_1.prompt = text, text
    try:
        idx = prompt_list.index(text)
    except ValueError:
        # Not a curated prompt: no cached videos exist, so generate online.
        idx = None
    for state in (state_0, state_1):
        state.offline = idx is not None
        state.offline_idx = idx

    start_time = time.time()

    def _record(model_name, state, log_type, timing=None):
        # One log line per model; timing fields are present only for online runs.
        record = {
            "ip": ip,
            "model": model_name,
            "type": log_type,
            "gen_params": {},
            "state": state.dict(),
            "start": round(start_time, 4),
        }
        record.update(timing or {})
        return record

    if state_0.offline and state_0.offline_idx is not None \
            and state_1.offline and state_1.offline_idx is not None:
        # Curated prompt: serve the pre-rendered videos for this index.
        dir_0 = os.path.join(OFFLINE_DIR, "text2shape", model_name_0)
        dir_1 = os.path.join(OFFLINE_DIR, "text2shape", model_name_1)
        normal_video_0 = os.path.join(dir_0, "normal", f"{state_0.offline_idx}.mp4")
        rgb_video_0 = os.path.join(dir_0, "rgb", f"{state_0.offline_idx}.mp4")
        normal_video_1 = os.path.join(dir_1, "normal", f"{state_1.offline_idx}.mp4")
        rgb_video_1 = os.path.join(dir_1, "rgb", f"{state_1.offline_idx}.mp4")
        state_0.normal_video, state_0.rgb_video = normal_video_0, rgb_video_0
        state_1.normal_video, state_1.rgb_video = normal_video_1, rgb_video_1
        yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_1, rgb_video_1, \
            gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
        data_0 = _record(model_name_0, state_0, "offline")
        data_1 = _record(model_name_1, state_1, "offline")
    else:
        shape_0, shape_1 = gen_func(text, model_name_0, model_name_1)
        generate_time = time.time() - start_time
        normal_video_0, rgb_video_0, normal_video_1, rgb_video_1 = render_func(shape_0, model_name_0,
                                                                               shape_1, model_name_1)
        finish_time = time.time()
        render_time = finish_time - start_time - generate_time
        state_0.normal_video, state_0.rgb_video = normal_video_0, rgb_video_0
        state_1.normal_video, state_1.rgb_video = normal_video_1, rgb_video_1
        yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_1, rgb_video_1, \
            gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
        timing = {
            "time": round(finish_time - start_time, 4),
            "generate_time": round(generate_time, 4),
            "render_time": round(render_time, 4),
        }
        data_0 = _record(model_name_0, state_0, "online", timing)
        data_1 = _record(model_name_1, state_1, "online", timing)
    log_file = get_conv_log_filename()
    with open(log_file, "a") as fout:
        fout.write(json.dumps(data_0) + "\n")
        fout.write(json.dumps(data_1) + "\n")
    append_json_item_on_log_server(data_0, log_file)
    append_json_item_on_log_server(data_1, log_file)
def generate_i2s(gen_func, render_func, state, image, model_name, request: gr.Request):
    """Produce (or fetch from cache) an image-to-shape result for one model.

    If ``state`` is already marked offline with a sample index (set by the
    sampling helpers), the pre-rendered videos under
    ``OFFLINE_DIR/image2shape/<model>/`` are served. Otherwise ``gen_func``
    builds a shape from ``image`` and ``render_func`` renders it. A JSON
    record of the request is appended to the conversation log.

    NOTE(review): the offline flag is not re-derived from ``image``, so an
    image uploaded after sampling still serves the sampled cache entry.
    Confirm this is intended.

    Args:
        gen_func: callable ``(image, model_name) -> shape``.
        render_func: callable ``(shape, model_name) -> (normal_video, rgb_video)``.
        state: this panel's ``State``, or ``None`` to create one.
        image: the input image (type depends on the Gradio component).
        model_name: the model to run.
        request: the Gradio request, used to log the client IP.

    Yields:
        ``(state, normal_video, rgb_video)``.

    Raises:
        gr.Error: if ``image`` is missing or ``model_name`` is empty.
    """
    # ``not image`` raises for numpy arrays, so test for absence explicitly.
    if image is None or (isinstance(image, str) and not image):
        raise gr.Error("Image cannot be empty.")
    if not model_name:
        raise gr.Error("Model name cannot be empty.")
    if state is None:
        state = State(model_name, i2s_mode=True, offline=False)
    ip = get_ip(request)
    i2s_logger.info(f"generate. ip: {ip}")
    state.model_name = model_name
    state.image = image

    start_time = time.time()
    if state.offline and state.offline_idx is not None:
        # Cached sample: serve the pre-rendered videos for this index.
        model_dir = os.path.join(OFFLINE_DIR, "image2shape", model_name)
        normal_video = os.path.join(model_dir, "normal", f"{state.offline_idx}.mp4")
        rgb_video = os.path.join(model_dir, "rgb", f"{state.offline_idx}.mp4")
        state.normal_video = normal_video
        state.rgb_video = rgb_video
        yield state, normal_video, rgb_video
        data = {
            "ip": ip,
            "model": model_name,
            "type": "offline",
            "gen_params": {},
            "state": state.dict(),
            "start": round(start_time, 4),
        }
    else:
        shape = gen_func(image, model_name)
        generate_time = time.time() - start_time
        normal_video, rgb_video = render_func(shape, model_name)
        finish_time = time.time()
        render_time = finish_time - start_time - generate_time
        state.normal_video = normal_video
        state.rgb_video = rgb_video
        yield state, normal_video, rgb_video
        data = {
            "ip": ip,
            "model": model_name,
            "type": "online",
            "gen_params": {},
            "state": state.dict(),
            "start": round(start_time, 4),
            "time": round(finish_time - start_time, 4),
            "generate_time": round(generate_time, 4),
            "render_time": round(render_time, 4),
        }
    log_file = get_conv_log_filename()
    with open(log_file, "a") as fout:
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, log_file)
def generate_i2s_multi(gen_func, render_func,
                       state_0, state_1,
                       image,
                       model_name_0, model_name_1,
                       request: gr.Request):
    """Produce image-to-shape results for two named models side by side.

    If both states are marked offline with a sample index (set by the
    sampling helpers), the pre-rendered videos under
    ``OFFLINE_DIR/image2shape/<model>/`` are served. Otherwise ``gen_func``
    builds both shapes and ``render_func`` renders them. One JSON record per
    model is appended to the conversation log.

    Args:
        gen_func: callable ``(image, model_name_0, model_name_1) -> (shape_0, shape_1)``.
        render_func: callable ``(shape_0, model_name_0, shape_1, model_name_1)``
            returning ``(normal_video_0, rgb_video_0, normal_video_1, rgb_video_1)``.
        state_0, state_1: panel states for models A and B, or ``None``.
        image: the input image shared by both models.
        model_name_0, model_name_1: models A and B.
        request: the Gradio request, used to log the client IP.

    Yields:
        ``(state_0, state_1, normal_video_0, rgb_video_0, normal_video_1, rgb_video_1)``.
        Both branches yield the same shape, matching ``generate_t2s_multi``.

    Raises:
        gr.Error: if the image is missing or either model name is empty.
    """
    # ``not image`` raises for numpy arrays, so test for absence explicitly.
    if image is None or (isinstance(image, str) and not image):
        raise gr.Error("Image cannot be empty.")
    if not model_name_0:
        raise gr.Error("Model name A cannot be empty.")
    if not model_name_1:
        raise gr.Error("Model name B cannot be empty.")
    if state_0 is None:
        state_0 = State(model_name_0, i2s_mode=True, offline=False)
    if state_1 is None:
        state_1 = State(model_name_1, i2s_mode=True, offline=False)
    ip = get_ip(request)
    i2s_multi_logger.info(f"generate. ip: {ip}")
    state_0.model_name, state_1.model_name = model_name_0, model_name_1
    state_0.image, state_1.image = image, image

    start_time = time.time()

    def _record(model_name, state, log_type, timing=None):
        # One log line per model; timing fields are present only for online runs.
        record = {
            "ip": ip,
            "model": model_name,
            "type": log_type,
            "gen_params": {},
            "state": state.dict(),
            "start": round(start_time, 4),
        }
        record.update(timing or {})
        return record

    if state_0.offline and state_0.offline_idx is not None \
            and state_1.offline and state_1.offline_idx is not None:
        # Cached sample: serve the pre-rendered videos for each panel's index.
        dir_0 = os.path.join(OFFLINE_DIR, "image2shape", model_name_0)
        dir_1 = os.path.join(OFFLINE_DIR, "image2shape", model_name_1)
        normal_video_0 = os.path.join(dir_0, "normal", f"{state_0.offline_idx}.mp4")
        rgb_video_0 = os.path.join(dir_0, "rgb", f"{state_0.offline_idx}.mp4")
        normal_video_1 = os.path.join(dir_1, "normal", f"{state_1.offline_idx}.mp4")
        rgb_video_1 = os.path.join(dir_1, "rgb", f"{state_1.offline_idx}.mp4")
        state_0.normal_video, state_0.rgb_video = normal_video_0, rgb_video_0
        state_1.normal_video, state_1.rgb_video = normal_video_1, rgb_video_1
        yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_1, rgb_video_1
        data_0 = _record(model_name_0, state_0, "offline")
        data_1 = _record(model_name_1, state_1, "offline")
    else:
        shape_0, shape_1 = gen_func(image, model_name_0, model_name_1)
        generate_time = time.time() - start_time
        normal_video_0, rgb_video_0, normal_video_1, rgb_video_1 = render_func(shape_0, model_name_0,
                                                                               shape_1, model_name_1)
        finish_time = time.time()
        render_time = finish_time - start_time - generate_time
        state_0.normal_video, state_0.rgb_video = normal_video_0, rgb_video_0
        state_1.normal_video, state_1.rgb_video = normal_video_1, rgb_video_1
        yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_1, rgb_video_1
        timing = {
            "time": round(finish_time - start_time, 4),
            "generate_time": round(generate_time, 4),
            "render_time": round(render_time, 4),
        }
        data_0 = _record(model_name_0, state_0, "online", timing)
        data_1 = _record(model_name_1, state_1, "online", timing)
    log_file = get_conv_log_filename()
    with open(log_file, "a") as fout:
        fout.write(json.dumps(data_0) + "\n")
        fout.write(json.dumps(data_1) + "\n")
    append_json_item_on_log_server(data_0, log_file)
    append_json_item_on_log_server(data_1, log_file)
def generate_i2s_multi_annoy(gen_func,
                             state_0, state_1,
                             image,
                             model_name_0, model_name_1,
                             request: gr.Request,
                             *, render_func=None):
    """Anonymous-battle variant of ``generate_i2s_multi``.

    Behaves like the named image battle. In addition, it yields two Markdown
    headers that reveal the model names ("Model A" / "Model B") alongside
    the videos.

    Args:
        gen_func: callable ``(image, model_name_0, model_name_1) -> (shape_0, shape_1)``.
        state_0, state_1: panel states for models A and B, or ``None``.
        image: the input image shared by both models.
        model_name_0, model_name_1: models A and B.
        request: the Gradio request, used to log the client IP.
        render_func: callable ``(shape_0, model_name_0, shape_1, model_name_1)``
            returning ``(normal_video_0, rgb_video_0, normal_video_1, rgb_video_1)``.
            It is required for online generation. It is keyword-only so that
            existing positional callers keep working; previously the body
            referenced an undefined ``render_func``.

    Yields:
        ``(state_0, state_1, normal_video_0, rgb_video_0, normal_video_1,
        rgb_video_1, markdown_a, markdown_b)``.

    Raises:
        gr.Error: if the image is missing.
        TypeError: if online generation is needed but no ``render_func`` was given.
    """
    # ``not image`` raises for numpy arrays, so test for absence explicitly.
    if image is None or (isinstance(image, str) and not image):
        raise gr.Error("Image cannot be empty.")
    if state_0 is None:
        state_0 = State(model_name_0, i2s_mode=True, offline=False)
    if state_1 is None:
        state_1 = State(model_name_1, i2s_mode=True, offline=False)
    ip = get_ip(request)
    i2s_multi_logger.info(f"generate. ip: {ip}")
    state_0.model_name, state_1.model_name = model_name_0, model_name_1
    state_0.image, state_1.image = image, image

    start_time = time.time()

    def _record(model_name, state, log_type, timing=None):
        # One log line per model; timing fields are present only for online runs.
        record = {
            "ip": ip,
            "model": model_name,
            "type": log_type,
            "gen_params": {},
            "state": state.dict(),
            "start": round(start_time, 4),
        }
        record.update(timing or {})
        return record

    if state_0.offline and state_0.offline_idx is not None \
            and state_1.offline and state_1.offline_idx is not None:
        # Cached sample: serve the pre-rendered videos for each panel's index.
        dir_0 = os.path.join(OFFLINE_DIR, "image2shape", model_name_0)
        dir_1 = os.path.join(OFFLINE_DIR, "image2shape", model_name_1)
        normal_video_0 = os.path.join(dir_0, "normal", f"{state_0.offline_idx}.mp4")
        rgb_video_0 = os.path.join(dir_0, "rgb", f"{state_0.offline_idx}.mp4")
        normal_video_1 = os.path.join(dir_1, "normal", f"{state_1.offline_idx}.mp4")
        rgb_video_1 = os.path.join(dir_1, "rgb", f"{state_1.offline_idx}.mp4")
        state_0.normal_video, state_0.rgb_video = normal_video_0, rgb_video_0
        state_1.normal_video, state_1.rgb_video = normal_video_1, rgb_video_1
        yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_1, rgb_video_1, \
            gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
        data_0 = _record(model_name_0, state_0, "offline")
        data_1 = _record(model_name_1, state_1, "offline")
    else:
        if render_func is None:
            raise TypeError("generate_i2s_multi_annoy() requires render_func for online generation")
        shape_0, shape_1 = gen_func(image, model_name_0, model_name_1)
        generate_time = time.time() - start_time
        normal_video_0, rgb_video_0, normal_video_1, rgb_video_1 = render_func(shape_0, model_name_0,
                                                                               shape_1, model_name_1)
        finish_time = time.time()
        render_time = finish_time - start_time - generate_time
        state_0.normal_video, state_0.rgb_video = normal_video_0, rgb_video_0
        state_1.normal_video, state_1.rgb_video = normal_video_1, rgb_video_1
        yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_1, rgb_video_1, \
            gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
        timing = {
            "time": round(finish_time - start_time, 4),
            "generate_time": round(generate_time, 4),
            "render_time": round(render_time, 4),
        }
        data_0 = _record(model_name_0, state_0, "online", timing)
        data_1 = _record(model_name_1, state_1, "online", timing)
    log_file = get_conv_log_filename()
    with open(log_file, "a") as fout:
        fout.write(json.dumps(data_0) + "\n")
        fout.write(json.dumps(data_1) + "\n")
    append_json_item_on_log_server(data_0, log_file)
    append_json_item_on_log_server(data_1, log_file)