| import subprocess |
| import os |
| import glob |
| import streamlit as st |
| import streamlit.components.v1 as components |
| import base64 |
|
|
| |
def find_latest_file(base_path, extension):
    """Return the most recently created ``*.<extension>`` file directly inside *base_path*.

    Args:
        base_path: Directory to search (not recursive).
        extension: File extension without the leading dot, e.g. ``"html"``.

    Returns:
        The path of the matching file with the newest ``os.path.getctime``,
        or ``None`` when no file matches.
    """
    # Escape the directory part so characters such as "[" or "*" in a real
    # directory name are matched literally instead of as glob wildcards.
    pattern = os.path.join(glob.escape(base_path), f"*.{extension}")
    list_of_files = glob.glob(pattern)
    if not list_of_files:
        return None
    return max(list_of_files, key=os.path.getctime)
|
|
| |
# Default checkpoints handed to generate.py. The directory names suggest a
# VQ-VAE checkpoint (--resume-pth) and a text-to-motion transformer checkpoint
# (--resume-trans) — NOTE(review): confirm against generate.py's argparse help.
DEFAULT_VQ_CHECKPOINT = "output/vq/2023-07-19-04-17-17_12_VQVAE_20batchResetNRandom_8192_32/net_last.pth"
DEFAULT_TRANS_CHECKPOINT = "output/t2m/2023-10-10-03-17-01_HML3D_44_crsAtt2lyr_mask0.5-1/net_last.pth"


def generate_html(
    text_input,
    length,
    vq_checkpoint=DEFAULT_VQ_CHECKPOINT,
    trans_checkpoint=DEFAULT_TRANS_CHECKPOINT,
):
    """Run ``generate.py`` for a text prompt and locate the files it produced.

    Args:
        text_input: Prompt passed as ``--text``.
        length: Motion length passed (stringified) as ``--length``.
        vq_checkpoint: Path passed as ``--resume-pth``.
        trans_checkpoint: Path passed as ``--resume-trans``.

    Returns:
        ``(html_path, npy_path)``: the newest ``*.html`` and ``*.npy`` files in
        ``output/`` after the run. Either may be ``None`` if no such file exists.
        Returns ``(None, None)`` and shows the subprocess stderr via
        ``st.error`` when ``generate.py`` exits with a non-zero status.

    NOTE(review): "newest file" may predate this run if generate.py exits 0
    without writing output.
    """
    import sys  # local import keeps this block self-contained

    command = [
        # Use the interpreter running this app rather than whatever "python"
        # happens to be first on PATH (it may lack the model's dependencies).
        sys.executable, "generate.py",
        "--resume-pth", vq_checkpoint,
        "--resume-trans", trans_checkpoint,
        "--text", text_input,
        "--length", str(length),
    ]
    try:
        subprocess.run(command, check=True, text=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        st.error(f"Error: {e.stderr}")
        return None, None
    html_file = find_latest_file('output', 'html')
    npy_file = find_latest_file('output', 'npy')
    return html_file, npy_file
|
|
| |
def run_render_final(npy_file_path):
    """Run ``render_final.py`` on a generated ``.npy`` motion file.

    Args:
        npy_file_path: Path passed as the script's single positional argument.

    Returns:
        The newest ``*.mp4`` file in ``output/`` after the run (``None`` if
        there is none). Returns ``None`` and shows the subprocess stderr via
        ``st.error`` when the script exits with a non-zero status.

    NOTE(review): "newest file" may predate this run if render_final.py exits 0
    without writing a video.
    """
    import sys  # local import keeps this block self-contained

    # Same interpreter as the app, not the first "python" on PATH.
    command = [sys.executable, "render_final.py", npy_file_path]
    try:
        subprocess.run(command, check=True, text=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        st.error(f"Error: {e.stderr}")
        return None
    return find_latest_file('output', 'mp4')
|
|
| |
def gif_to_base64(gif_file_path):
    """Read a file in binary mode and return its bytes as a Base64 text string.

    Args:
        gif_file_path: Path of the file to encode (intended for GIFs, but any
            file works since the contents are not inspected).

    Returns:
        The standard Base64 encoding of the file contents, decoded to ``str``.
    """
    with open(gif_file_path, "rb") as handle:
        encoded = base64.b64encode(handle.read())
    return encoded.decode("utf-8")
|
|
| |
# Seed the widget-backed session keys only on the first run of the script;
# later reruns keep whatever the user or an example-prompt button stored.
for _state_key, _default in (("text_input", ""), ("length", 156)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default
|
|
| |
def select_prompt(prompt, prompt_length):
    """Copy an example prompt and its motion length into session state.

    Called when one of the example-prompt buttons is clicked. The text area and
    number input keyed ``"text_input"`` and ``"length"`` pick the new values up
    when they are created later in the same run.
    """
    st.session_state["text_input"] = prompt
    st.session_state["length"] = prompt_length
|
|
|
|
| |
# Page title. components.html renders this markup in its own iframe.
# NOTE(review): the white text may be invisible on Streamlit's light theme.
_TITLE_HTML = "<h1 style='text-align: center; color: white;'>MMM Model Demo</h1>"
components.html(_TITLE_HTML, height=100)

# Example prompts offered as one-click buttons: (prompt text, motion length).
prompts = [
    (
        "A person walks forward then turns completely around and does a cartwheel",
        196,
    ),
    (
        "A person bouncing around while throwing jabs and upper cuts.",
        196,
    ),
    (
        "A person start to dance with legs",
        176,
    ),
    (
        "A person steps forward and leans over; they grab a cup with their left hand and empty it before putting it down and stepping back to their original position.",
        156,
    ),
    (
        "Walking forward and kicking foot.",
        68,
    ),
    (
        "A man walks forward, stumbles to the right, and then regains his balance and keeps walking forwards.",
        92,
    ),
]
|
|
# Two-column layout: prompt editor on the left, example prompts on the right.
col1, col2 = st.columns([6, 5])

# Reserve a slot in the left column now, but fill it only after the example
# buttons below have run. A click calls select_prompt, which writes the
# widgets' session-state keys, and Streamlit only allows that before the keyed
# widgets are instantiated in the current run.
with col1:
    input_placeholder = st.empty()

with col2:
    st.write("Or choose a prompt:")
    for prompt, prompt_length in prompts:
        if st.button(prompt):
            select_prompt(prompt, prompt_length)

# Both widgets take their values from st.session_state through their keys.
# Passing `value=` as well makes Streamlit warn that the widget was given a
# default value while its state was also set via the Session State API.
with input_placeholder.container():
    text_input = st.text_area("Enter text here:", key="text_input", height=300)
    # An integer step keeps this an int widget now that no int `value=` is
    # passed; generate_html forwards it as str(length), so a float would
    # arrive at generate.py as e.g. "156.0".
    length = st.number_input("Length of the generated motion:", step=1, key="length")
|
|
| |
# Action buttons: generate the motion (HTML + .npy), then optionally render it.
button_col1, button_col2 = st.columns(2)

with button_col1:
    if st.button("Generate HTML"):
        if st.session_state.text_input and st.session_state.length:
            with st.spinner("Generating motion..."):
                html_file_path, npy_file_path = generate_html(
                    st.session_state.text_input, st.session_state.length
                )
            if html_file_path and npy_file_path:
                st.session_state.html_file_path = html_file_path
                st.session_state.npy_file_path = npy_file_path
                # A video rendered from the previous motion no longer matches
                # the new result, so drop it instead of showing it side by side.
                if 'vid_file_path' in st.session_state:
                    del st.session_state['vid_file_path']
                with open(html_file_path, 'r', encoding='utf-8') as file:
                    st.session_state.html_content = file.read()
            else:
                st.error("Error generating files. Please try again.")
        else:
            st.warning("Please enter a prompt and a non-zero motion length.")

with button_col2:
    if st.button("Render Skeleton"):
        if 'npy_file_path' in st.session_state and st.session_state.npy_file_path:
            with st.spinner("Rendering skeleton video..."):
                vid_file_path = run_render_final(st.session_state.npy_file_path)
            if vid_file_path:
                st.session_state.vid_file_path = vid_file_path
            else:
                st.error("Rendering failed: no video file was produced.")
        else:
            st.error("No npy file found. Please generate HTML first.")
|
|
|
|
| |
# Show the generated HTML and, once rendered, the skeleton video side by side.
# The second check was 'gif_base64', a key nothing sets; the video path is the
# state the right-hand column actually depends on.
if 'html_content' in st.session_state or 'vid_file_path' in st.session_state:
    html_content = st.session_state.html_content if 'html_content' in st.session_state else ""
    video_path = st.session_state.vid_file_path if 'vid_file_path' in st.session_state else ""

    disp_col1, disp_col2 = st.columns([1, 1])

    with disp_col1:
        components.html(html_content, height=800, scrolling=True)

    with disp_col2:
        if video_path and os.path.isfile(video_path):
            # Close the handle as soon as the bytes are read (it previously leaked).
            with open(video_path, 'rb') as video_file:
                video_bytes = video_file.read()
            st.video(video_bytes, format='video/mp4', loop=True)
        elif video_path:
            st.warning(f"Rendered video not found: {video_path}")