rdz-falcon committed on
Commit
1eea1ba
·
verified ·
1 Parent(s): 0d41680

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +166 -166
app.py CHANGED
@@ -1,166 +1,166 @@
1
- import gradio as gr
2
- import torch
3
- import os
4
- import sys
5
- import warnings
6
- from pathlib import Path
7
-
8
- # Add root to path to allow imports from project root when running from demo-code/
9
- # or when running from root
10
- current_dir = os.path.dirname(os.path.abspath(__file__))
11
- parent_dir = os.path.dirname(current_dir)
12
- sys.path.append(current_dir)
13
- sys.path.append(parent_dir)
14
-
15
- # Import project modules
16
- try:
17
- from inference import load_trained_model, inference as run_inference_cmd
18
- from visualize import visualize
19
- from model import setup_model_and_tokenizer, get_motion_token_info
20
- from generate import generate_t2m
21
- from data import compute_length_stats, build_prompt_vocab, check_has_participant_id, load_dataset
22
- import config
23
- except ImportError as e:
24
- print(f"Error importing project modules: {e}")
25
- print("Make sure you are running this from the project root or have the project structure intact.")
26
-
27
- # Constants
28
- HF_REPO_ID = "rdz-falcon/SignMotionGPT"
29
- EPOCH_SUBFOLDER = "stage2/epoch-030"
30
-
31
- def load_model_from_hf(repo_id, subfolder, token=None):
32
- from transformers import AutoModelForCausalLM, AutoTokenizer
33
- print(f"Loading model from HF: {repo_id}/{subfolder}")
34
- try:
35
- tokenizer = AutoTokenizer.from_pretrained(repo_id, subfolder=subfolder, token=token, trust_remote_code=True)
36
- model = AutoModelForCausalLM.from_pretrained(repo_id, subfolder=subfolder, token=token, trust_remote_code=True)
37
- return model, tokenizer
38
- except Exception as e:
39
- print(f"Error loading model: {e}")
40
- return None, None
41
-
42
- # Global model cache
43
- MODEL = None
44
- TOKENIZER = None
45
- MOTION_TOKEN_IDS = None
46
- MOT_BEGIN_ID = None
47
- MOT_END_ID = None
48
- CODEBOOK_SIZE = 512
49
-
50
- def init_model():
51
- global MODEL, TOKENIZER, MOTION_TOKEN_IDS, MOT_BEGIN_ID, MOT_END_ID
52
- if MODEL is not None:
53
- return
54
-
55
- token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
56
-
57
- # Load model/tokenizer
58
- MODEL, TOKENIZER = load_model_from_hf(HF_REPO_ID, EPOCH_SUBFOLDER, token)
59
-
60
- if MODEL is None:
61
- raise RuntimeError(f"Failed to load model from {HF_REPO_ID}/{EPOCH_SUBFOLDER}")
62
-
63
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
64
- MODEL.to(device)
65
- MODEL.eval()
66
-
67
- # Setup token info
68
- motion_token_ids = []
69
- for i in range(CODEBOOK_SIZE):
70
- t = f"<motion_{i}>"
71
- if t in TOKENIZER.get_vocab():
72
- motion_token_ids.append(TOKENIZER.convert_tokens_to_ids(t))
73
-
74
- MOTION_TOKEN_IDS = motion_token_ids
75
- MOT_BEGIN_ID = TOKENIZER.convert_tokens_to_ids("<MOT_BEGIN>") if "<MOT_BEGIN>" in TOKENIZER.get_vocab() else None
76
- MOT_END_ID = TOKENIZER.convert_tokens_to_ids("<MOT_END>") if "<MOT_END>" in TOKENIZER.get_vocab() else None
77
-
78
- print("Model initialized.")
79
-
80
- def generate_motion_app(text_prompt):
81
- if not text_prompt:
82
- return None, "Please enter a prompt."
83
-
84
- if MODEL is None:
85
- try:
86
- init_model()
87
- except Exception as e:
88
- return None, f"Model Initialization Failed: {e}"
89
-
90
- device = MODEL.device
91
- print(f"Generating for: {text_prompt}")
92
-
93
- try:
94
- generated_tokens = generate_t2m(
95
- model=MODEL,
96
- tokenizer=TOKENIZER,
97
- prompt_text=text_prompt,
98
- mot_begin_id=MOT_BEGIN_ID,
99
- mot_end_id=MOT_END_ID,
100
- motion_token_ids=MOTION_TOKEN_IDS,
101
- length_stats_by_text={}, # Fallback to global_median_len
102
- global_median_len=100, # Reasonable default
103
- prompt_vocab=None,
104
- has_pid=False,
105
- per_prompt_vocab=False # Allow all tokens
106
- )
107
- except Exception as e:
108
- return None, f"Generation Error: {e}"
109
-
110
- # Visualization
111
- try:
112
- # Ensure paths for VQ-VAE and SMPL-X
113
- # In HF Spaces, we assume these are in the repo (e.g., ./data)
114
- data_dir = os.environ.get("DATA_DIR", "data")
115
- vqvae_ckpt = os.path.join(data_dir, "vqvae_model.pt")
116
- stats_path = os.path.join(data_dir, "vqvae_stats.pt")
117
- smplx_dir = os.path.join(data_dir, "smplx_models")
118
-
119
- # Check existence
120
- missing = []
121
- if not os.path.exists(vqvae_ckpt): missing.append(vqvae_ckpt)
122
- if not os.path.exists(stats_path): missing.append(stats_path)
123
- if not os.path.exists(smplx_dir): missing.append(smplx_dir)
124
-
125
- if missing:
126
- return None, f"Missing visualization files in {data_dir}: {missing}. Please ensure they are uploaded to the Space."
127
-
128
- # Output to a temporary file
129
- # Gradio needs a file path or HTML string. visualize returns a Figure.
130
- output_html = "temp_viz.html"
131
-
132
- fig = visualize(
133
- tokens=generated_tokens,
134
- vqvae_ckpt=vqvae_ckpt,
135
- stats_path=stats_path,
136
- smplx_dir=smplx_dir,
137
- output_html=output_html,
138
- title=f"Motion: {text_prompt}",
139
- fps=20
140
- )
141
-
142
- if fig is None:
143
- return None, "Visualization failed (no frames produced)."
144
-
145
- return fig, f"Success! Generated tokens length: {len(generated_tokens.split())}"
146
-
147
- except Exception as e:
148
- return None, f"Visualization Error: {e}"
149
-
150
-
151
- # Gradio UI
152
- with gr.Interface(
153
- fn=generate_motion_app,
154
- inputs=gr.Textbox(label="Enter Motion Prompt", placeholder="e.g. walking forward"),
155
- outputs=[
156
- gr.Plot(label="Motion Visualization"),
157
- gr.Textbox(label="Status/Output")
158
- ],
159
- title="SignMotionGPT Demo",
160
- description="Generate Sign Language/Motion Avatars from Text. Using model checkpoint: epoch 30."
161
- ) as demo:
162
- pass
163
-
164
- if __name__ == "__main__":
165
- demo.launch()
166
-
 
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ import sys
5
+ import warnings
6
+ from pathlib import Path
7
+
8
# Add root to path to allow imports from project root when running from demo-code/
# or when running from root
current_dir = os.path.dirname(os.path.abspath(__file__))  # directory containing this file
parent_dir = os.path.dirname(current_dir)  # one level up (assumed project root)
sys.path.append(current_dir)
sys.path.append(parent_dir)
14
+
15
# Import project modules. An ImportError is reported but not re-raised so the
# message is visible in the Space logs; any later call that needs these names
# will still fail loudly.
try:
    from inference import load_trained_model, inference as run_inference_cmd
    from visualize import visualize
    from model import setup_model_and_tokenizer, get_motion_token_info
    from generate import generate_t2m
    from data import compute_length_stats, build_prompt_vocab, check_has_participant_id, load_dataset
    import config
except ImportError as e:
    print(f"Error importing project modules: {e}")
    print("Make sure you are running this from the project root or have the project structure intact.")
26
+
27
# Constants
# NOTE(review): repo id reads oddly ("...GPTfit-archive") — possibly a
# copy/paste artifact of "SignMotionGPT"; confirm against the actual HF repo.
HF_REPO_ID = "rdz-falcon/SignMotionGPTfit-archive"  # HF Hub repo holding trained checkpoints
EPOCH_SUBFOLDER = "stage2/epoch-030"  # checkpoint subfolder within the repo to serve
30
+
31
def load_model_from_hf(repo_id, subfolder, token=None):
    """Fetch a causal LM and its tokenizer from a Hugging Face Hub subfolder.

    Args:
        repo_id: HF Hub repository id (e.g. "user/repo").
        subfolder: path inside the repo holding the checkpoint.
        token: optional HF access token for private repos.

    Returns:
        (model, tokenizer) on success, or (None, None) if loading fails.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    print(f"Loading model from HF: {repo_id}/{subfolder}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            repo_id, subfolder=subfolder, token=token, trust_remote_code=True
        )
        model = AutoModelForCausalLM.from_pretrained(
            repo_id, subfolder=subfolder, token=token, trust_remote_code=True
        )
    except Exception as e:
        # Best-effort loader: report and signal failure to the caller.
        print(f"Error loading model: {e}")
        return None, None
    return model, tokenizer
41
+
42
# Global model cache (populated lazily by init_model on first request)
MODEL = None  # loaded causal LM, or None until initialized
TOKENIZER = None  # tokenizer matching MODEL
MOTION_TOKEN_IDS = None  # vocab ids of the <motion_i> tokens found in the tokenizer
MOT_BEGIN_ID = None  # vocab id of <MOT_BEGIN>, or None if absent
MOT_END_ID = None  # vocab id of <MOT_END>, or None if absent
CODEBOOK_SIZE = 512  # number of candidate <motion_i> tokens probed
49
+
50
def init_model():
    """Lazily load the model/tokenizer from HF and cache them in module globals.

    Idempotent: returns immediately if MODEL is already set. Raises
    RuntimeError when the checkpoint cannot be loaded.
    """
    global MODEL, TOKENIZER, MOTION_TOKEN_IDS, MOT_BEGIN_ID, MOT_END_ID

    if MODEL is not None:
        return  # already initialized

    hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")

    # Load model/tokenizer
    MODEL, TOKENIZER = load_model_from_hf(HF_REPO_ID, EPOCH_SUBFOLDER, hf_token)
    if MODEL is None:
        raise RuntimeError(f"Failed to load model from {HF_REPO_ID}/{EPOCH_SUBFOLDER}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    MODEL.to(device)
    MODEL.eval()

    # Collect the ids of every <motion_i> token actually present in the vocab.
    vocab = TOKENIZER.get_vocab()
    MOTION_TOKEN_IDS = [
        TOKENIZER.convert_tokens_to_ids(tok)
        for tok in (f"<motion_{i}>" for i in range(CODEBOOK_SIZE))
        if tok in vocab
    ]
    MOT_BEGIN_ID = (
        TOKENIZER.convert_tokens_to_ids("<MOT_BEGIN>") if "<MOT_BEGIN>" in vocab else None
    )
    MOT_END_ID = (
        TOKENIZER.convert_tokens_to_ids("<MOT_END>") if "<MOT_END>" in vocab else None
    )

    print("Model initialized.")
79
+
80
def generate_motion_app(text_prompt):
    """Gradio handler: generate motion tokens from text and visualize them.

    Args:
        text_prompt: free-text motion description from the UI textbox.

    Returns:
        Tuple of (figure or None, status message string). The figure is
        whatever `visualize` returns (presumably a plotly Figure — confirm
        against visualize.py); on any failure the first element is None and
        the message describes the error.
    """
    if not text_prompt:
        return None, "Please enter a prompt."

    # Lazy initialization so the Space starts fast and load errors surface
    # in the UI instead of crashing at import time.
    if MODEL is None:
        try:
            init_model()
        except Exception as e:
            return None, f"Model Initialization Failed: {e}"

    # (fix: removed unused `device = MODEL.device` local)
    print(f"Generating for: {text_prompt}")

    try:
        generated_tokens = generate_t2m(
            model=MODEL,
            tokenizer=TOKENIZER,
            prompt_text=text_prompt,
            mot_begin_id=MOT_BEGIN_ID,
            mot_end_id=MOT_END_ID,
            motion_token_ids=MOTION_TOKEN_IDS,
            length_stats_by_text={},  # Fallback to global_median_len
            global_median_len=100,  # Reasonable default
            prompt_vocab=None,
            has_pid=False,
            per_prompt_vocab=False,  # Allow all tokens
        )
    except Exception as e:
        return None, f"Generation Error: {e}"

    # Visualization
    try:
        # Ensure paths for VQ-VAE and SMPL-X.
        # In HF Spaces, we assume these are in the repo (e.g., ./data).
        data_dir = os.environ.get("DATA_DIR", "data")
        vqvae_ckpt = os.path.join(data_dir, "vqvae_model.pt")
        stats_path = os.path.join(data_dir, "vqvae_stats.pt")
        smplx_dir = os.path.join(data_dir, "smplx_models")

        # Report every missing asset at once rather than failing on the first.
        missing = [p for p in (vqvae_ckpt, stats_path, smplx_dir) if not os.path.exists(p)]
        if missing:
            return None, f"Missing visualization files in {data_dir}: {missing}. Please ensure they are uploaded to the Space."

        # Output to a temporary file.
        # Gradio needs a file path or HTML string. visualize returns a Figure.
        output_html = "temp_viz.html"

        fig = visualize(
            tokens=generated_tokens,
            vqvae_ckpt=vqvae_ckpt,
            stats_path=stats_path,
            smplx_dir=smplx_dir,
            output_html=output_html,
            title=f"Motion: {text_prompt}",
            fps=20,
        )

        if fig is None:
            return None, "Visualization failed (no frames produced)."

        return fig, f"Success! Generated tokens length: {len(generated_tokens.split())}"

    except Exception as e:
        return None, f"Visualization Error: {e}"
149
+
150
+
151
# Gradio UI.
# Fix: the original wrapped the Interface in `with gr.Interface(...) as demo:
# pass` — Interface builds its entire layout in its constructor, so the
# context-manager block was a no-op wrapper; plain assignment is the
# documented usage and keeps `demo` at module level for `demo.launch()`.
demo = gr.Interface(
    fn=generate_motion_app,
    inputs=gr.Textbox(label="Enter Motion Prompt", placeholder="e.g. walking forward"),
    outputs=[
        gr.Plot(label="Motion Visualization"),
        gr.Textbox(label="Status/Output"),
    ],
    title="SignMotionGPT Demo",
    description="Generate Sign Language/Motion Avatars from Text. Using model checkpoint: epoch 30.",
)

if __name__ == "__main__":
    demo.launch()
166
+