Spaces:
Runtime error
Runtime error
Upload src/app/pipeline.py with huggingface_hub
Browse files- src/app/pipeline.py +96 -0
src/app/pipeline.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import pickle
|
| 4 |
+
import shutil
|
| 5 |
+
from threading import Timer
|
| 6 |
+
|
| 7 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 8 |
+
app_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
| 9 |
+
|
| 10 |
+
from utils.inference_utils import gen_prog_ind
|
| 11 |
+
from utils.constants import TO_24
|
| 12 |
+
from inference import inference
|
| 13 |
+
|
| 14 |
+
from omegaconf import OmegaConf
|
| 15 |
+
from inference.joint2smplx import process_file
|
| 16 |
+
|
| 17 |
+
def delete_folder(folder_path):
|
| 18 |
+
if os.path.exists(folder_path):
|
| 19 |
+
shutil.rmtree(folder_path)
|
| 20 |
+
|
| 21 |
+
def rand_folder_name():
|
| 22 |
+
import time
|
| 23 |
+
return str(time.time()).replace('.', '')
|
| 24 |
+
|
| 25 |
+
def pipeline(data, models, device, diffuser, **kwargs):
|
| 26 |
+
from app.process_data import get_a_sample
|
| 27 |
+
from app.setup_models import text_embedder, test_configs, normalize, denormalize
|
| 28 |
+
|
| 29 |
+
len_data = min(data['source']['transl'].shape[0]//((kwargs['SEQLEN']-2)*2), 4)
|
| 30 |
+
if len_data < 4:
|
| 31 |
+
return None # not enough data
|
| 32 |
+
|
| 33 |
+
joints_orig = get_a_sample(data['source'],
|
| 34 |
+
len_data,
|
| 35 |
+
kwargs['SEQLEN'],
|
| 36 |
+
smplx_pth=os.path.abspath(os.path.join(app_root, 'deps/smplx/models'))
|
| 37 |
+
).to(device)
|
| 38 |
+
|
| 39 |
+
joints_orig = normalize(joints_orig)
|
| 40 |
+
|
| 41 |
+
hint_text = data['text']
|
| 42 |
+
|
| 43 |
+
if data['prog_ind'] is None:
|
| 44 |
+
prog_ind = gen_prog_ind(num_cases=1, sublist_length=len_data)[0]
|
| 45 |
+
else:
|
| 46 |
+
prog_ind = data['prog_ind']
|
| 47 |
+
generated_samples, orig = inference.test_model(
|
| 48 |
+
models=models,
|
| 49 |
+
diffuser=diffuser,
|
| 50 |
+
normalizer=(normalize, denormalize),
|
| 51 |
+
configs=test_configs,
|
| 52 |
+
text_embedder=text_embedder,
|
| 53 |
+
hint_text=hint_text,
|
| 54 |
+
prog_ind=prog_ind,
|
| 55 |
+
joint_orig=joints_orig,
|
| 56 |
+
All_one_model=data['All_one_model'],
|
| 57 |
+
model_type=data['model_type']
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
generated_samples = generated_samples.reshape(1, -1, 28, 3)[..., TO_24, :].reshape(1, -1, 72)
|
| 61 |
+
orig = orig.reshape(1, -1, 28, 3)[..., TO_24, :].reshape(1, -1, 72)
|
| 62 |
+
|
| 63 |
+
combined_dict = {
|
| 64 |
+
'generated_samples': generated_samples,
|
| 65 |
+
'original_samples': orig,
|
| 66 |
+
'text' : hint_text,
|
| 67 |
+
}
|
| 68 |
+
# return combined_dict
|
| 69 |
+
|
| 70 |
+
input_folder = os.path.join(app_root, rand_folder_name())
|
| 71 |
+
output_folder = os.path.join(app_root, rand_folder_name())
|
| 72 |
+
|
| 73 |
+
if not os.path.exists(input_folder):
|
| 74 |
+
os.makedirs(input_folder)
|
| 75 |
+
with open(os.path.join(input_folder, 'temp.pkl'), 'wb') as file:
|
| 76 |
+
pickle.dump(combined_dict, file)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
j2s_config = OmegaConf.load(os.path.join(app_root, "configs/j2s.yaml"))
|
| 80 |
+
|
| 81 |
+
for file_name in os.listdir(input_folder):
|
| 82 |
+
if file_name.endswith('.pkl'):
|
| 83 |
+
process_file(file_path=input_folder,
|
| 84 |
+
file_name=file_name,
|
| 85 |
+
save_path=output_folder,
|
| 86 |
+
JointsToSMPLX_model_path=os.path.abspath(os.path.join(app_root, '..', j2s_config.JointsToSMPLX_model_path)),
|
| 87 |
+
smplx_path=os.path.abspath(os.path.join(app_root, '..', j2s_config.smplx_path)),
|
| 88 |
+
key_list = ['generated_samples'],
|
| 89 |
+
# remenber to remove original samples when using app
|
| 90 |
+
interp_s=j2s_config.interp_s,
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
Timer(100, delete_folder, [input_folder]).start()
|
| 94 |
+
Timer(100, delete_folder, [output_folder]).start()
|
| 95 |
+
|
| 96 |
+
return os.path.join(output_folder, 'generated_samples/temp.pkl')
|