File size: 4,948 Bytes
b7658fb 6b64262 3f97053 4324db0 138e306 4324db0 6b64262 3f97053 6b64262 b338394 6b64262 d082ce1 6b64262 4324db0 8a1f431 6b64262 b338394 6b64262 88cdeea 4324db0 b338394 21a831f 4324db0 2877f43 4324db0 2877f43 b338394 138e306 b338394 88cdeea 2877f43 88cdeea 2877f43 88cdeea 2877f43 88cdeea 2877f43 88cdeea 2877f43 88cdeea 2877f43 88cdeea 2877f43 88cdeea 2877f43 b338394 21a831f 2877f43 b338394 4324db0 b338394 4324db0 21a831f 4324db0 6d6254a d082ce1 6b64262 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import os
import subprocess
import sys
from pathlib import Path
# --- 1. Clone the VibeVoice Repository ---
# Directory the upstream repository is cloned into; the install step changes
# the working directory to this path.
repo_dir = "VibeVoice"
if os.path.exists(repo_dir):
    print("Repository already exists. Skipping clone.")
else:
    print("Cloning the VibeVoice repository...")
    try:
        # Pass the destination explicitly so the existence check above and the
        # clone target can never drift apart if either name is changed.
        subprocess.run(
            ["git", "clone", "https://github.com/microsoft/VibeVoice.git", repo_dir],
            check=True, capture_output=True, text=True,
        )
    except FileNotFoundError:
        # Raised by subprocess when the 'git' binary itself cannot be started.
        print("Error cloning repository: the 'git' executable was not found on PATH.")
        sys.exit(1)
    except subprocess.CalledProcessError as e:
        # Output is captured, so surface git's stderr in the error message.
        print(f"Error cloning repository: {e.stderr}")
        sys.exit(1)
    else:
        print("Repository cloned successfully.")
# --- 2. Install Dependencies ---
# Work from inside the cloned repository: the editable install targets "."
# and the remaining steps use paths relative to the repository root.
os.chdir(repo_dir)
print(f"Changed directory to: {os.getcwd()}")
print("Installing the VibeVoice package in editable mode...")
# Run pip through the current interpreter so the package is installed into
# the environment this script itself is running in.
pip_command = [sys.executable, "-m", "pip", "install", "-e", "."]
install_result = subprocess.run(pip_command, capture_output=True, text=True)
if install_result.returncode != 0:
    # pip's output is captured, so report its stderr before aborting.
    print(f"Error installing package: {install_result.stderr}")
    sys.exit(1)
print("Package installed successfully.")
# --- 3. Refactor the demo script using a robust state-machine patcher ---
demo_script_path = Path("demo/gradio_demo.py")

# Marker comment written into the patched file. Its presence means a previous
# run already rewrote the file, so the patch must not be applied a second time.
_PATCH_MARKER = "# Patched: Lazy-load model and processor on the GPU worker"

# Code injected at the top of the generation method body. It is stored with no
# base indentation; the patcher prefixes every line with the method-body indent.
_LAZY_LOAD_CODE = _PATCH_MARKER + """
if self.model is None or self.processor is None:
    print("Loading processor & model for the first time on GPU worker...")
    self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)
    self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
        self.model_path,
        torch_dtype=torch.bfloat16,  # Use 16-bit precision for quality
        device_map="auto",
    )
    self.model.eval()
    self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
        self.model.model.noise_scheduler.config,
        algorithm_type='sde-dpmsolver++',
        beta_schedule='squaredcos_cap_v2'
    )
    self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
    print("Model and processor loaded successfully on GPU worker.")
"""


def _leading_whitespace(line):
    """Return the run of whitespace that *line* starts with."""
    return line[: len(line) - len(line.lstrip())]


def _patch_demo_lines(lines):
    """Return a patched copy of the demo script's lines.

    Three edits are made in a single pass:

    * ``import spaces`` is prepended unless some line already contains it.
    * Every ``self.load_model()`` call inside ``__init__`` is commented out
      and replaced by ``self.model = None`` / ``self.processor = None``.
    * ``generate_podcast_streaming`` is decorated with
      ``@spaces.GPU(duration=120)`` and the lazy-load code is inserted
      directly after the line that closes its signature.

    Args:
        lines: The script's lines, each keeping its trailing newline.

    Returns:
        A ``(new_lines, init_patched, generate_patched)`` tuple; the two
        booleans report whether the respective edit was applied.
    """
    patched = []
    if not any("import spaces" in line for line in lines):
        patched.append("import spaces\n")

    inside_init = False            # currently inside the body of __init__
    awaiting_signature_end = False  # saw the target def, not yet its closing line
    method_indent = ""
    init_patched = False
    generate_patched = False

    for line in lines:
        stripped = line.strip()

        # Track which function we are in instead of searching backwards with
        # lines.index(), which is quadratic and finds the wrong position when
        # the same text occurs more than once.
        if stripped.startswith(("def ", "async def ")):
            inside_init = stripped.startswith("def __init__(")

        # Defer the initial model loading (the original author's note says
        # this avoids a PicklingError). Lines that are already commented out
        # do not start with the call and are therefore left alone.
        if inside_init and stripped.startswith("self.load_model()"):
            indent = _leading_whitespace(line)
            patched.append(f"{indent}# self.load_model()  # Patched: Defer model loading\n")
            patched.append(f"{indent}self.model = None\n")
            patched.append(f"{indent}self.processor = None\n")
            init_patched = True
            continue

        # Start of the target method: put the decorator in front of the def.
        if (
            not generate_patched
            and not awaiting_signature_end
            and "def generate_podcast_streaming(self," in line
        ):
            method_indent = _leading_whitespace(line)
            patched.append(f"{method_indent}@spaces.GPU(duration=120)\n")
            awaiting_signature_end = True

        patched.append(line)

        # End of the target method signature. Checked after appending so a
        # signature written on a single line is handled as well.
        if awaiting_signature_end and "-> Iterator[tuple]:" in line:
            body_indent = method_indent + " " * 4
            for code_line in _LAZY_LOAD_CODE.strip().splitlines():
                patched.append(body_indent + code_line + "\n")
            awaiting_signature_end = False
            generate_patched = True

    return patched, init_patched, generate_patched


print(f"Refactoring {demo_script_path} for ZeroGPU lazy loading...")
try:
    with open(demo_script_path, "r", encoding="utf-8") as f:
        lines = f.readlines()

    if any(_PATCH_MARKER in line for line in lines):
        # A rerun against an existing checkout would otherwise stack a second
        # decorator and a second lazy-load block onto the method.
        print("Demo script is already patched. Skipping.")
    else:
        new_lines, init_patched, generate_patched = _patch_demo_lines(lines)

        if init_patched:
            print("Successfully patched __init__ to prevent startup model load.")
        else:
            print("Warning: no self.load_model() call found in __init__; startup model load was not deferred.")

        if not generate_patched:
            print("\033[91mError: Failed to apply the lazy-loading patch. The target method signature may have changed.\033[0m")
            sys.exit(1)
        print("Successfully patched generation method for lazy loading.")

        # --- Write the modified content back to the file ---
        with open(demo_script_path, "w", encoding="utf-8") as f:
            f.writelines(new_lines)
        print("Script patching complete.")
except Exception as e:
    # Top-level boundary: report any read/patch/write failure and abort.
    # sys.exit() above raises SystemExit, which is not caught here.
    print(f"An error occurred while modifying the script: {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)
# --- 4. Launch the Gradio Demo ---
model_id = "microsoft/VibeVoice-1.5B"
# Launch with the interpreter running this script: the package was installed
# with `sys.executable -m pip`, and a bare "python" on PATH may resolve to a
# different environment that does not have it.
command = [sys.executable, str(demo_script_path), "--model_path", model_id, "--share"]
print(f"Launching Gradio demo with command: {' '.join(command)}")
launch_result = subprocess.run(command)
# Propagate a failing demo's exit status instead of always exiting with 0.
if launch_result.returncode != 0:
    sys.exit(launch_result.returncode)