jbilcke-hf commited on
Commit
84a1d8e
·
verified ·
1 Parent(s): 3b5290e

Update demo.py

Browse files
Files changed (1) hide show
  1. demo.py +81 -43
demo.py CHANGED
@@ -1,14 +1,20 @@
1
  from huggingface_hub import InferenceClient
2
  import base64
3
  import os
 
4
  from pathlib import Path
5
  import time
6
 
7
  def save_video(base64_video: str, output_path: str):
8
  """Save base64 encoded video to a file"""
 
 
 
 
9
  video_bytes = base64.b64decode(base64_video)
10
  with open(output_path, "wb") as f:
11
  f.write(video_bytes)
 
12
 
13
  def generate_video(
14
  prompt: str,
@@ -16,11 +22,13 @@ def generate_video(
16
  token: str = None,
17
  resolution: str = "1280x720",
18
  video_length: int = 129,
19
- num_inference_steps: int = 50,
20
  seed: int = -1,
21
  guidance_scale: float = 1.0,
22
  flow_shift: float = 7.0,
23
- embedded_guidance_scale: float = 6.0
 
 
24
  ) -> str:
25
  """Generate a video using the custom inference endpoint.
26
 
@@ -29,12 +37,14 @@ def generate_video(
29
  endpoint_url: Full URL to the inference endpoint
30
  token: HuggingFace API token for authentication
31
  resolution: Video resolution (default: "1280x720")
32
- video_length: Number of frames (default: 129 for 5s)
33
- num_inference_steps: Number of inference steps (default: 50)
34
  seed: Random seed, -1 for random (default: -1)
35
  guidance_scale: Guidance scale value (default: 1.0)
36
  flow_shift: Flow shift value (default: 7.0)
37
  embedded_guidance_scale: Embedded guidance scale (default: 6.0)
 
 
38
 
39
  Returns:
40
  Path to the saved video file
@@ -42,6 +52,13 @@ def generate_video(
42
  # Initialize client
43
  client = InferenceClient(model=endpoint_url, token=token)
44
 
 
 
 
 
 
 
 
45
  # Prepare payload
46
  payload = {
47
  "inputs": prompt,
@@ -51,54 +68,75 @@ def generate_video(
51
  "seed": seed,
52
  "guidance_scale": guidance_scale,
53
  "flow_shift": flow_shift,
54
- "embedded_guidance_scale": embedded_guidance_scale
 
 
55
  }
56
 
57
  # Make request
58
- response = client.post(json=payload)
59
- result = response.json()
60
 
61
- # Save video
62
- timestamp = int(time.time())
63
- output_path = f"generated_video_{timestamp}.mp4"
64
- save_video(result["video_base64"], output_path)
65
-
66
- print(f"Video generated with seed {result['seed']}")
67
- return output_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  if __name__ == "__main__":
70
-
71
  hf_api_token = os.environ.get('HF_API_TOKEN', '')
72
  endpoint_url = os.environ.get('ENDPOINT_URL', '')
73
-
 
 
 
 
74
  video_path = generate_video(
75
  endpoint_url=endpoint_url,
76
  token=hf_api_token,
77
-
78
  prompt="A cat walks on the grass, realistic style.",
79
-
80
- # min resolution is 64x64, max is 4096x4096 (increment steps are by 16px)
81
- # however the model is designed for 1280x720
82
- resolution="1280x720",
83
-
84
- # numbers of frames plus one (max 1024?)
85
- # increments by 4 frames
86
- video_length=49, # 129,
87
-
88
- # number of denoising/sampling steps (default: 30)
89
- num_inference_steps: int = 15, # 50,
90
-
91
- seed: int = -1, # -1 to keep it random
92
-
93
- # not sure why we have two guidance scales
94
- guidance_scale = 1.0, # 3
95
-
96
- # strength of prompt guidance (default: 6.0)
97
- embedded_guidance_scale: float = 6.0
98
-
99
-
100
- # video length (larger values result in shorter videos, default: 9.0, max: 30)
101
- flow_shift: float = 9.0,
102
-
103
- )
104
- print(f"Video saved to: {video_path}")
 
1
  from huggingface_hub import InferenceClient
2
  import base64
3
  import os
4
+ import re
5
  from pathlib import Path
6
  import time
7
 
8
  def save_video(base64_video: str, output_path: str):
9
  """Save base64 encoded video to a file"""
10
+ # Handle data URI format if present
11
+ if base64_video.startswith('data:video/mp4;base64,'):
12
+ base64_video = base64_video.split('base64,')[1]
13
+
14
  video_bytes = base64.b64decode(base64_video)
15
  with open(output_path, "wb") as f:
16
  f.write(video_bytes)
17
+ print(f"Video saved to: {output_path}")
18
 
19
  def generate_video(
20
  prompt: str,
 
22
  token: str = None,
23
  resolution: str = "1280x720",
24
  video_length: int = 129,
25
+ num_inference_steps: int = 30,
26
  seed: int = -1,
27
  guidance_scale: float = 1.0,
28
  flow_shift: float = 7.0,
29
+ embedded_guidance_scale: float = 6.0,
30
+ enable_riflex: bool = True,
31
+ tea_cache: float = 0.0
32
  ) -> str:
33
  """Generate a video using the custom inference endpoint.
34
 
 
37
  endpoint_url: Full URL to the inference endpoint
38
  token: HuggingFace API token for authentication
39
  resolution: Video resolution (default: "1280x720")
40
+ video_length: Number of frames (default: 129)
41
+ num_inference_steps: Number of inference steps (default: 30)
42
  seed: Random seed, -1 for random (default: -1)
43
  guidance_scale: Guidance scale value (default: 1.0)
44
  flow_shift: Flow shift value (default: 7.0)
45
  embedded_guidance_scale: Embedded guidance scale (default: 6.0)
46
+ enable_riflex: Enable RIFLEx positional embedding for long videos (default: True)
47
+ tea_cache: TeaCache acceleration threshold, 0.0 to disable, 0.1 for 1.6x speedup, 0.15 for 2.1x speedup (default: 0.0)
48
 
49
  Returns:
50
  Path to the saved video file
 
52
  # Initialize client
53
  client = InferenceClient(model=endpoint_url, token=token)
54
 
55
+ print(f"Generating video with prompt: \"{prompt}\"")
56
+ print(f"Resolution: {resolution}, Length: {video_length} frames")
57
+ print(f"Steps: {num_inference_steps}, Seed: {'random' if seed == -1 else seed}")
58
+
59
+ # Sanitize filename from prompt
60
+ safe_prompt = re.sub(r'[^\w\s-]', '', prompt)[:50].strip().replace(' ', '_')
61
+
62
  # Prepare payload
63
  payload = {
64
  "inputs": prompt,
 
68
  "seed": seed,
69
  "guidance_scale": guidance_scale,
70
  "flow_shift": flow_shift,
71
+ "embedded_guidance_scale": embedded_guidance_scale,
72
+ "enable_riflex": enable_riflex,
73
+ "tea_cache": tea_cache
74
  }
75
 
76
  # Make request
77
+ start_time = time.time()
78
+ print("Sending request to endpoint...")
79
 
80
+ try:
81
+ response = client.post(json=payload)
82
+
83
+ # Check if the response is a string (data URI) or JSON
84
+ if response.headers.get('content-type') == 'application/json':
85
+ result = response.json()
86
+ video_data = result.get("video_base64", result)
87
+ else:
88
+ # The response might be directly the data URI
89
+ video_data = response.text
90
+
91
+ generation_time = time.time() - start_time
92
+ print(f"Video generated in {generation_time:.2f} seconds")
93
+
94
+ # Save video
95
+ timestamp = int(time.time())
96
+ output_path = f"{safe_prompt}_{timestamp}.mp4"
97
+
98
+ # If the response is a data URI, extract the base64 part
99
+ if isinstance(video_data, str) and video_data.startswith('data:video/mp4;base64,'):
100
+ save_video(video_data, output_path)
101
+ elif isinstance(video_data, str):
102
+ save_video(video_data, output_path)
103
+ else:
104
+ # Assume it's a dictionary with a base64 key
105
+ save_video(video_data.get("video_base64", ""), output_path)
106
+
107
+ return output_path
108
+
109
+ except Exception as e:
110
+ print(f"Error generating video: {e}")
111
+ raise
112
 
113
  if __name__ == "__main__":
 
114
  hf_api_token = os.environ.get('HF_API_TOKEN', '')
115
  endpoint_url = os.environ.get('ENDPOINT_URL', '')
116
+
117
+ if not endpoint_url:
118
+ print("Please set the ENDPOINT_URL environment variable")
119
+ exit(1)
120
+
121
  video_path = generate_video(
122
  endpoint_url=endpoint_url,
123
  token=hf_api_token,
 
124
  prompt="A cat walks on the grass, realistic style.",
125
+
126
+ # Video configuration
127
+ resolution="1280x720", # Standard HD resolution
128
+ video_length=97, # About 4 seconds at 24fps
129
+
130
+ # Generation parameters
131
+ num_inference_steps=22, # Default for standard model
132
+ seed=-1, # Random seed
133
+
134
+ # Advanced parameters
135
+ guidance_scale=1.0,
136
+ embedded_guidance_scale=6.0,
137
+ flow_shift=7.0,
138
+
139
+ # Optimizations
140
+ enable_riflex=True, # Better for videos longer than 4 seconds
141
+ tea_cache=0.0 # Set to 0.1 or 0.15 for faster generation with slight quality loss
142
+ )