Yzy00518 commited on
Commit
b275b5c
·
1 Parent(s): f632f68

Upload src/app/pipeline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/app/pipeline.py +96 -0
src/app/pipeline.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import pickle
4
+ import shutil
5
+ from threading import Timer
6
+
7
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
8
+ app_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
9
+
10
+ from utils.inference_utils import gen_prog_ind
11
+ from utils.constants import TO_24
12
+ from inference import inference
13
+
14
+ from omegaconf import OmegaConf
15
+ from inference.joint2smplx import process_file
16
+
17
+ def delete_folder(folder_path):
18
+ if os.path.exists(folder_path):
19
+ shutil.rmtree(folder_path)
20
+
21
+ def rand_folder_name():
22
+ import time
23
+ return str(time.time()).replace('.', '')
24
+
25
+ def pipeline(data, models, device, diffuser, **kwargs):
26
+ from app.process_data import get_a_sample
27
+ from app.setup_models import text_embedder, test_configs, normalize, denormalize
28
+
29
+ len_data = min(data['source']['transl'].shape[0]//((kwargs['SEQLEN']-2)*2), 4)
30
+ if len_data < 4:
31
+ return None # not enough data
32
+
33
+ joints_orig = get_a_sample(data['source'],
34
+ len_data,
35
+ kwargs['SEQLEN'],
36
+ smplx_pth=os.path.abspath(os.path.join(app_root, 'deps/smplx/models'))
37
+ ).to(device)
38
+
39
+ joints_orig = normalize(joints_orig)
40
+
41
+ hint_text = data['text']
42
+
43
+ if data['prog_ind'] is None:
44
+ prog_ind = gen_prog_ind(num_cases=1, sublist_length=len_data)[0]
45
+ else:
46
+ prog_ind = data['prog_ind']
47
+ generated_samples, orig = inference.test_model(
48
+ models=models,
49
+ diffuser=diffuser,
50
+ normalizer=(normalize, denormalize),
51
+ configs=test_configs,
52
+ text_embedder=text_embedder,
53
+ hint_text=hint_text,
54
+ prog_ind=prog_ind,
55
+ joint_orig=joints_orig,
56
+ All_one_model=data['All_one_model'],
57
+ model_type=data['model_type']
58
+ )
59
+
60
+ generated_samples = generated_samples.reshape(1, -1, 28, 3)[..., TO_24, :].reshape(1, -1, 72)
61
+ orig = orig.reshape(1, -1, 28, 3)[..., TO_24, :].reshape(1, -1, 72)
62
+
63
+ combined_dict = {
64
+ 'generated_samples': generated_samples,
65
+ 'original_samples': orig,
66
+ 'text' : hint_text,
67
+ }
68
+ # return combined_dict
69
+
70
+ input_folder = os.path.join(app_root, rand_folder_name())
71
+ output_folder = os.path.join(app_root, rand_folder_name())
72
+
73
+ if not os.path.exists(input_folder):
74
+ os.makedirs(input_folder)
75
+ with open(os.path.join(input_folder, 'temp.pkl'), 'wb') as file:
76
+ pickle.dump(combined_dict, file)
77
+
78
+
79
+ j2s_config = OmegaConf.load(os.path.join(app_root, "configs/j2s.yaml"))
80
+
81
+ for file_name in os.listdir(input_folder):
82
+ if file_name.endswith('.pkl'):
83
+ process_file(file_path=input_folder,
84
+ file_name=file_name,
85
+ save_path=output_folder,
86
+ JointsToSMPLX_model_path=os.path.abspath(os.path.join(app_root, '..', j2s_config.JointsToSMPLX_model_path)),
87
+ smplx_path=os.path.abspath(os.path.join(app_root, '..', j2s_config.smplx_path)),
88
+ key_list = ['generated_samples'],
89
+ # remenber to remove original samples when using app
90
+ interp_s=j2s_config.interp_s,
91
+ )
92
+
93
+ Timer(100, delete_folder, [input_folder]).start()
94
+ Timer(100, delete_folder, [output_folder]).start()
95
+
96
+ return os.path.join(output_folder, 'generated_samples/temp.pkl')