feature: added rendering
Browse files- app.py +78 -26
- generate.py +12 -5
- render_final.py +1 -2
- requirements.txt +67 -3
app.py
CHANGED
|
@@ -1,10 +1,19 @@
|
|
| 1 |
-
import streamlit as st
|
| 2 |
-
import streamlit.components.v1 as components
|
| 3 |
import subprocess
|
| 4 |
import os
|
| 5 |
import glob
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
# Function to run the
|
| 8 |
def generate_html(text_input, length):
|
| 9 |
command = [
|
| 10 |
"python", "generate.py",
|
|
@@ -15,18 +24,30 @@ def generate_html(text_input, length):
|
|
| 15 |
]
|
| 16 |
try:
|
| 17 |
result = subprocess.run(command, check=True, text=True, capture_output=True)
|
| 18 |
-
|
|
|
|
|
|
|
| 19 |
except subprocess.CalledProcessError as e:
|
| 20 |
-
|
| 21 |
-
return None
|
| 22 |
|
| 23 |
-
# Function to
|
| 24 |
-
def
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
return None
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
# Initialize session state
|
| 32 |
if 'text_input' not in st.session_state:
|
|
@@ -34,16 +55,14 @@ if 'text_input' not in st.session_state:
|
|
| 34 |
if 'length' not in st.session_state:
|
| 35 |
st.session_state.length = 156
|
| 36 |
|
| 37 |
-
#
|
| 38 |
def select_prompt(prompt, prompt_length):
|
| 39 |
st.session_state.text_input = prompt
|
| 40 |
st.session_state.length = prompt_length
|
| 41 |
|
| 42 |
-
# app layout
|
| 43 |
-
components.html("""
|
| 44 |
-
<h1 style='text-align: center; color: white;'>MMM Model Demo</h1>
|
| 45 |
-
""", height=100)
|
| 46 |
|
|
|
|
|
|
|
| 47 |
|
| 48 |
prompts = [
|
| 49 |
("A person walks forward then turns completely around and does a cartwheel", 196),
|
|
@@ -73,13 +92,46 @@ with input_placeholder.container():
|
|
| 73 |
text_input = st.text_area("Enter text here:", value=st.session_state.text_input, key="text_input", height=300)
|
| 74 |
length = st.number_input("Length of the generated text:", value=st.session_state.length, key="length")
|
| 75 |
|
| 76 |
-
#
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
else:
|
| 85 |
-
st.error("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import subprocess
|
| 2 |
import os
|
| 3 |
import glob
|
| 4 |
+
import streamlit as st
|
| 5 |
+
import streamlit.components.v1 as components
|
| 6 |
+
import base64
|
| 7 |
+
|
| 8 |
+
# Function to find the latest file with a given extension in a specified directory
|
| 9 |
+
def find_latest_file(base_path, extension):
|
| 10 |
+
list_of_files = glob.glob(f'{base_path}/*.{extension}')
|
| 11 |
+
if not list_of_files:
|
| 12 |
+
return None
|
| 13 |
+
latest_file = max(list_of_files, key=os.path.getctime)
|
| 14 |
+
return latest_file
|
| 15 |
|
| 16 |
+
# Function to run the generate.py script and return paths of generated HTML and NPY files
|
| 17 |
def generate_html(text_input, length):
|
| 18 |
command = [
|
| 19 |
"python", "generate.py",
|
|
|
|
| 24 |
]
|
| 25 |
try:
|
| 26 |
result = subprocess.run(command, check=True, text=True, capture_output=True)
|
| 27 |
+
html_file = find_latest_file('output', 'html')
|
| 28 |
+
npy_file = find_latest_file('output', 'npy')
|
| 29 |
+
return html_file, npy_file
|
| 30 |
except subprocess.CalledProcessError as e:
|
| 31 |
+
st.error(f"Error: {e.stderr}")
|
| 32 |
+
return None, None
|
| 33 |
|
| 34 |
+
# Function to run render_final.py script with the generated NPY file
|
| 35 |
+
def run_render_final(npy_file_path):
|
| 36 |
+
command = ["python", "render_final.py", npy_file_path]
|
| 37 |
+
try:
|
| 38 |
+
gif_res = subprocess.run(command, check=True, text=True, capture_output=True)
|
| 39 |
+
gif_file_path = find_latest_file('output', 'gif')
|
| 40 |
+
return gif_file_path
|
| 41 |
+
except subprocess.CalledProcessError as e:
|
| 42 |
+
st.error(f"Error: {e.stderr}")
|
| 43 |
return None
|
| 44 |
+
|
| 45 |
+
# Function to convert GIF to base64
|
| 46 |
+
def gif_to_base64(gif_file_path):
|
| 47 |
+
with open(gif_file_path, "rb") as gif_file:
|
| 48 |
+
gif_bytes = gif_file.read()
|
| 49 |
+
base64_gif = base64.b64encode(gif_bytes).decode("utf-8")
|
| 50 |
+
return base64_gif
|
| 51 |
|
| 52 |
# Initialize session state
|
| 53 |
if 'text_input' not in st.session_state:
|
|
|
|
| 55 |
if 'length' not in st.session_state:
|
| 56 |
st.session_state.length = 156
|
| 57 |
|
| 58 |
+
# Handler to update session state and rerun the app
|
| 59 |
def select_prompt(prompt, prompt_length):
|
| 60 |
st.session_state.text_input = prompt
|
| 61 |
st.session_state.length = prompt_length
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
+
# App layout
|
| 65 |
+
components.html("<h1 style='text-align: center; color: white;'>MMM Model Demo</h1>", height=100)
|
| 66 |
|
| 67 |
prompts = [
|
| 68 |
("A person walks forward then turns completely around and does a cartwheel", 196),
|
|
|
|
| 92 |
text_input = st.text_area("Enter text here:", value=st.session_state.text_input, key="text_input", height=300)
|
| 93 |
length = st.number_input("Length of the generated text:", value=st.session_state.length, key="length")
|
| 94 |
|
| 95 |
+
# Place the buttons side by side
|
| 96 |
+
button_col1, button_col2 = st.columns(2)
|
| 97 |
+
|
| 98 |
+
with button_col1:
|
| 99 |
+
if st.button("Generate HTML"):
|
| 100 |
+
if st.session_state.text_input and st.session_state.length:
|
| 101 |
+
html_file_path, npy_file_path = generate_html(st.session_state.text_input, st.session_state.length)
|
| 102 |
+
if html_file_path and npy_file_path:
|
| 103 |
+
st.session_state.html_file_path = html_file_path
|
| 104 |
+
st.session_state.npy_file_path = npy_file_path
|
| 105 |
+
|
| 106 |
+
# Display the HTML file content
|
| 107 |
+
with open(html_file_path, 'r') as file:
|
| 108 |
+
html_content = file.read()
|
| 109 |
+
st.session_state.html_content = html_content
|
| 110 |
+
else:
|
| 111 |
+
st.error("Error generating files. Please try again.")
|
| 112 |
+
|
| 113 |
+
with button_col2:
|
| 114 |
+
if st.button("Render Skeleton"):
|
| 115 |
+
if 'npy_file_path' in st.session_state and st.session_state.npy_file_path:
|
| 116 |
+
gif_file_path = run_render_final(st.session_state.npy_file_path)
|
| 117 |
+
if gif_file_path:
|
| 118 |
+
st.session_state.gif_file_path = gif_file_path
|
| 119 |
+
st.session_state.gif_base64 = gif_to_base64(gif_file_path)
|
| 120 |
else:
|
| 121 |
+
st.error("No npy file found. Please generate HTML first.")
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# Display the results side by side using HTML components
|
| 125 |
+
if 'html_content' in st.session_state or 'gif_base64' in st.session_state:
|
| 126 |
+
html_content = st.session_state.html_content if 'html_content' in st.session_state else ""
|
| 127 |
+
gif_base64 = st.session_state.gif_base64 if 'gif_base64' in st.session_state else ""
|
| 128 |
+
|
| 129 |
+
disp_col1, disp_col2 = st.columns([1, 1])
|
| 130 |
+
|
| 131 |
+
with disp_col1:
|
| 132 |
+
components.html(html_content, height=800, scrolling=True)
|
| 133 |
+
|
| 134 |
+
with disp_col2:
|
| 135 |
+
if gif_base64:
|
| 136 |
+
gif_html = f'<img src="data:image/gif;base64,{gif_base64}" style="width:100%;">'
|
| 137 |
+
components.html(gif_html, height=800, scrolling=True)
|
generate.py
CHANGED
|
@@ -291,15 +291,22 @@ if __name__ == '__main__':
|
|
| 291 |
pred_pose = mmm([args.text], torch.tensor([args.length]), rand_pos=False)
|
| 292 |
num_joints = 22
|
| 293 |
|
| 294 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
|
| 296 |
-
|
|
|
|
| 297 |
|
| 298 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
print('File saved successfully')
|
| 300 |
|
| 301 |
-
std = np.load('./exit/t2m-std.npy')
|
| 302 |
-
mean = np.load('./exit/t2m-mean.npy')
|
| 303 |
file_name = '_'.join(args.text.split(' '))+'_'+str(args.length)
|
| 304 |
visualize_2motions(pred_pose[0].detach().cpu().numpy(), std, mean, 't2m', args.length, save_path='./output/'+file_name+'.html')
|
| 305 |
|
|
|
|
| 291 |
pred_pose = mmm([args.text], torch.tensor([args.length]), rand_pos=False)
|
| 292 |
num_joints = 22
|
| 293 |
|
| 294 |
+
std = np.load('./exit/t2m-std.npy')
|
| 295 |
+
mean = np.load('./exit/t2m-mean.npy')
|
| 296 |
+
|
| 297 |
+
norm_pose = pred_pose[0].detach().cpu().numpy() * std + mean
|
| 298 |
+
norm_pose = torch.tensor(norm_pose)
|
| 299 |
|
| 300 |
+
trimmed_pose = norm_pose[:args.length, :].unsqueeze(0).float()
|
| 301 |
+
print(trimmed_pose.shape)
|
| 302 |
|
| 303 |
+
converted_pose = recover_from_ric(trimmed_pose[0].detach().cpu(), num_joints).unsqueeze(0).numpy()
|
| 304 |
+
print(converted_pose.shape)
|
| 305 |
+
|
| 306 |
+
filename = '_'.join(args.text.split(' '))+'_'+str(args.length)
|
| 307 |
+
np.save('./output/'+filename+'.npy', converted_pose)
|
| 308 |
print('File saved successfully')
|
| 309 |
|
|
|
|
|
|
|
| 310 |
file_name = '_'.join(args.text.split(' '))+'_'+str(args.length)
|
| 311 |
visualize_2motions(pred_pose[0].detach().cpu().numpy(), std, mean, 't2m', args.length, save_path='./output/'+file_name+'.html')
|
| 312 |
|
render_final.py
CHANGED
|
@@ -164,8 +164,7 @@ def render(motions, outdir='test_vis', device_id=0, name=None, pred=True):
|
|
| 164 |
gif_path = os.path.join(outdir, f'{name}.gif')
|
| 165 |
imageio.mimsave(gif_path, out, fps=20)
|
| 166 |
|
| 167 |
-
|
| 168 |
-
|
| 169 |
|
| 170 |
|
| 171 |
if __name__ == "__main__":
|
|
|
|
| 164 |
gif_path = os.path.join(outdir, f'{name}.gif')
|
| 165 |
imageio.mimsave(gif_path, out, fps=20)
|
| 166 |
|
| 167 |
+
|
|
|
|
| 168 |
|
| 169 |
|
| 170 |
if __name__ == "__main__":
|
requirements.txt
CHANGED
|
@@ -1,37 +1,101 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
beautifulsoup4==4.12.3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
einops==0.8.0
|
| 5 |
fastjsonschema==2.20.0
|
|
|
|
|
|
|
| 6 |
freetype-py==2.4.0
|
| 7 |
fsspec==2024.6.1
|
| 8 |
ftfy==6.2.0
|
| 9 |
gdown==5.2.0
|
|
|
|
|
|
|
| 10 |
h5py==3.11.0
|
|
|
|
| 11 |
imageio==2.34.2
|
|
|
|
|
|
|
| 12 |
jsonschema==4.22.0
|
| 13 |
jsonschema-specifications==2023.12.1
|
| 14 |
jupyter_core==5.7.2
|
|
|
|
| 15 |
mapbox-earcut==1.0.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
nbformat==5.10.4
|
|
|
|
| 17 |
numpy==1.23.3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
pkgutil_resolve_name==1.3.10
|
| 19 |
-
|
|
|
|
| 20 |
protobuf==3.20.3
|
|
|
|
|
|
|
| 21 |
pyasn1-modules==0.2.8
|
|
|
|
| 22 |
pyglet==2.0.15
|
|
|
|
|
|
|
| 23 |
pyrender==0.1.45
|
|
|
|
|
|
|
|
|
|
| 24 |
referencing==0.35.1
|
| 25 |
regex==2024.5.15
|
|
|
|
| 26 |
requests-oauthlib==1.3.0
|
|
|
|
| 27 |
rpds-py==0.18.1
|
| 28 |
scipy==1.10.1
|
| 29 |
shapely==2.0.4
|
|
|
|
|
|
|
|
|
|
| 30 |
soupsieve==2.5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
torch==2.3.1
|
| 32 |
torchaudio==2.3.1
|
| 33 |
torchvision==0.18.1
|
|
|
|
| 34 |
tqdm==4.66.4
|
| 35 |
traitlets==5.14.3
|
| 36 |
trimesh==4.4.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
wcwidth==0.2.13
|
|
|
|
|
|
| 1 |
+
altair==5.3.0
|
| 2 |
+
attrs==23.2.0
|
| 3 |
beautifulsoup4==4.12.3
|
| 4 |
+
blinker==1.8.2
|
| 5 |
+
cachetools==5.3.3
|
| 6 |
+
certifi==2024.6.2
|
| 7 |
+
charset-normalizer==3.3.2
|
| 8 |
+
chumpy==0.70
|
| 9 |
+
click==8.1.7
|
| 10 |
+
contourpy==1.1.1
|
| 11 |
+
cycler==0.12.1
|
| 12 |
einops==0.8.0
|
| 13 |
fastjsonschema==2.20.0
|
| 14 |
+
filelock==3.15.4
|
| 15 |
+
fonttools==4.53.0
|
| 16 |
freetype-py==2.4.0
|
| 17 |
fsspec==2024.6.1
|
| 18 |
ftfy==6.2.0
|
| 19 |
gdown==5.2.0
|
| 20 |
+
gitdb==4.0.11
|
| 21 |
+
GitPython==3.1.43
|
| 22 |
h5py==3.11.0
|
| 23 |
+
idna==3.7
|
| 24 |
imageio==2.34.2
|
| 25 |
+
importlib_resources==6.4.0
|
| 26 |
+
Jinja2==3.1.4
|
| 27 |
jsonschema==4.22.0
|
| 28 |
jsonschema-specifications==2023.12.1
|
| 29 |
jupyter_core==5.7.2
|
| 30 |
+
kiwisolver==1.4.5
|
| 31 |
mapbox-earcut==1.0.1
|
| 32 |
+
markdown-it-py==3.0.0
|
| 33 |
+
MarkupSafe==2.1.5
|
| 34 |
+
matplotlib==3.7.5
|
| 35 |
+
mdurl==0.1.2
|
| 36 |
+
mpmath==1.3.0
|
| 37 |
nbformat==5.10.4
|
| 38 |
+
networkx==3.1
|
| 39 |
numpy==1.23.3
|
| 40 |
+
nvidia-cublas-cu12==12.1.3.1
|
| 41 |
+
nvidia-cuda-cupti-cu12==12.1.105
|
| 42 |
+
nvidia-cuda-nvrtc-cu12==12.1.105
|
| 43 |
+
nvidia-cuda-runtime-cu12==12.1.105
|
| 44 |
+
nvidia-cudnn-cu12==8.9.2.26
|
| 45 |
+
nvidia-cufft-cu12==11.0.2.54
|
| 46 |
+
nvidia-curand-cu12==10.3.2.106
|
| 47 |
+
nvidia-cusolver-cu12==11.4.5.107
|
| 48 |
+
nvidia-cusparse-cu12==12.1.0.106
|
| 49 |
+
nvidia-nccl-cu12==2.20.5
|
| 50 |
+
nvidia-nvjitlink-cu12==12.5.40
|
| 51 |
+
nvidia-nvtx-cu12==12.1.105
|
| 52 |
+
oauthlib==3.2.2
|
| 53 |
+
packaging==24.1
|
| 54 |
+
pandas==2.0.3
|
| 55 |
+
pillow==10.3.0
|
| 56 |
pkgutil_resolve_name==1.3.10
|
| 57 |
+
platformdirs==4.2.2
|
| 58 |
+
plotly==5.22.0
|
| 59 |
protobuf==3.20.3
|
| 60 |
+
pyarrow==16.1.0
|
| 61 |
+
pyasn1==0.4.8
|
| 62 |
pyasn1-modules==0.2.8
|
| 63 |
+
pydeck==0.9.1
|
| 64 |
pyglet==2.0.15
|
| 65 |
+
Pygments==2.18.0
|
| 66 |
+
pyparsing==3.1.2
|
| 67 |
pyrender==0.1.45
|
| 68 |
+
PySocks==1.7.1
|
| 69 |
+
python-dateutil==2.9.0.post0
|
| 70 |
+
pytz==2024.1
|
| 71 |
referencing==0.35.1
|
| 72 |
regex==2024.5.15
|
| 73 |
+
requests==2.32.3
|
| 74 |
requests-oauthlib==1.3.0
|
| 75 |
+
rich==13.7.1
|
| 76 |
rpds-py==0.18.1
|
| 77 |
scipy==1.10.1
|
| 78 |
shapely==2.0.4
|
| 79 |
+
six==1.16.0
|
| 80 |
+
smmap==5.0.1
|
| 81 |
+
smplx==0.1.28
|
| 82 |
soupsieve==2.5
|
| 83 |
+
streamlit==1.36.0
|
| 84 |
+
sympy==1.12.1
|
| 85 |
+
tenacity==8.4.2
|
| 86 |
+
toml==0.10.2
|
| 87 |
+
toolz==0.12.1
|
| 88 |
torch==2.3.1
|
| 89 |
torchaudio==2.3.1
|
| 90 |
torchvision==0.18.1
|
| 91 |
+
tornado==6.4.1
|
| 92 |
tqdm==4.66.4
|
| 93 |
traitlets==5.14.3
|
| 94 |
trimesh==4.4.1
|
| 95 |
+
triton==2.3.1
|
| 96 |
+
typing_extensions==4.12.2
|
| 97 |
+
tzdata==2024.1
|
| 98 |
+
urllib3==2.2.2
|
| 99 |
+
watchdog==4.0.1
|
| 100 |
wcwidth==0.2.13
|
| 101 |
+
zipp==3.19.2
|