File size: 4,948 Bytes
b7658fb
6b64262
3f97053
4324db0
138e306
4324db0
6b64262
 
 
3f97053
6b64262
 
b338394
6b64262
 
 
 
 
 
 
 
d082ce1
6b64262
 
4324db0
8a1f431
6b64262
 
 
b338394
6b64262
 
 
 
 
 
88cdeea
4324db0
b338394
21a831f
4324db0
2877f43
 
4324db0
2877f43
b338394
 
 
 
 
 
 
138e306
b338394
 
 
 
 
 
 
 
 
 
 
 
88cdeea
2877f43
 
 
 
 
88cdeea
 
 
 
2877f43
 
 
 
 
 
 
88cdeea
 
 
2877f43
 
88cdeea
 
 
 
2877f43
88cdeea
2877f43
 
88cdeea
 
 
 
2877f43
88cdeea
 
2877f43
 
 
88cdeea
2877f43
b338394
21a831f
2877f43
 
 
 
b338394
4324db0
 
 
b338394
 
4324db0
21a831f
4324db0
6d6254a
d082ce1
6b64262
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import subprocess
import sys
import textwrap
from pathlib import Path

# --- 1. Clone the VibeVoice Repository ---
repo_dir = "VibeVoice"
repo_url = "https://github.com/microsoft/VibeVoice.git"
# Skip the clone when a previous run already left a checkout behind.
if not os.path.exists(repo_dir):
    print("Cloning the VibeVoice repository...")
    try:
        # Pass repo_dir explicitly so the clone lands exactly at the path the
        # existence check above looks for, instead of relying on git deriving
        # the same directory name from the URL.
        subprocess.run(
            ["git", "clone", repo_url, repo_dir],
            check=True, capture_output=True, text=True,
        )
        print("Repository cloned successfully.")
    except FileNotFoundError:
        # subprocess raises this (not CalledProcessError) when the git
        # executable itself cannot be found.
        print("Error cloning repository: 'git' executable not found on PATH.")
        sys.exit(1)
    except subprocess.CalledProcessError as e:
        print(f"Error cloning repository: {e.stderr}")
        sys.exit(1)
else:
    print("Repository already exists. Skipping clone.")

# --- 2. Install Dependencies ---
# Work from inside the checkout so the relative "." passed to pip below
# resolves to the repository root.
os.chdir(repo_dir)
print(f"Changed directory to: {os.getcwd()}")

print("Installing the VibeVoice package in editable mode...")
# Install with the running interpreter's pip so the package lands in this
# interpreter's environment.
pip_install_cmd = [sys.executable, "-m", "pip", "install", "-e", "."]
try:
    subprocess.run(pip_install_cmd, check=True, capture_output=True, text=True)
except subprocess.CalledProcessError as err:
    # Output was captured, so surface pip's stderr before aborting.
    print(f"Error installing package: {err.stderr}")
    sys.exit(1)
else:
    print("Package installed successfully.")

# --- 3. Refactor the demo script using a robust state-machine patcher ---
demo_script_path = Path("demo/gradio_demo.py")
print(f"Refactoring {demo_script_path} for ZeroGPU lazy loading...")

# Code inserted as the first statements of generate_podcast_streaming. It is
# written at its own indentation (columns 0/4) and shifted to the method-body
# depth on insertion; the old strip()+8-spaces approach indented only the
# first line correctly and left the `if` at column 16, which does not line up
# with the rest of the method body.
# NOTE(review): assumes the demo script already has torch,
# VibeVoiceProcessor and VibeVoiceForConditionalGenerationInference in scope
# and sets self.model_path / self.inference_steps -- confirm upstream.
lazy_load_code = textwrap.dedent("""\
    # Patched: Lazy-load model and processor on the GPU worker
    if self.model is None or self.processor is None:
        print("Loading processor & model for the first time on GPU worker...")
        self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)
        self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            self.model_path,
            torch_dtype=torch.bfloat16, # Use 16-bit precision for quality
            device_map="auto",
        )
        self.model.eval()
        self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
            self.model.model.noise_scheduler.config,
            algorithm_type='sde-dpmsolver++',
            beta_schedule='squaredcos_cap_v2'
        )
        self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
        print("Model and processor loaded successfully on GPU worker.")
""")
# The first inserted line doubles as an "already patched" marker.
lazy_load_marker = lazy_load_code.splitlines()[0]


def _leading_ws(text):
    """Return the leading whitespace (indentation) of *text*."""
    return text[: len(text) - len(text.lstrip())]


def _patch_demo_lines(src_lines):
    """Apply the ZeroGPU lazy-loading edits to the demo script's lines.

    - Prepends ``import spaces`` unless some line already contains it.
    - Replaces a ``self.load_model()`` call found within the two lines after
      ``def __init__`` with ``self.model = None`` / ``self.processor = None``
      so nothing is loaded when the object is constructed.
    - Decorates ``generate_podcast_streaming`` with
      ``@spaces.GPU(duration=120)`` and inserts the lazy-load block right
      after the line ending its signature (``-> Iterator[tuple]:``).

    Args:
        src_lines: The demo script as returned by ``readlines()``; each line
            keeps its trailing newline.

    Returns:
        ``(new_lines, generate_patched, init_patched)``: the patched lines and
        whether the generation method and ``__init__`` were found and patched.
    """
    new_lines = []
    if not any("import spaces" in line for line in src_lines):
        new_lines.append("import spaces\n")

    in_generate_method = False  # between the target def and its signature end
    generate_patched = False
    init_patched = False
    body_indent = ""

    # enumerate gives each line's real position; the old lines.index(line)
    # returned the first *identical* line, so duplicated lines were checked
    # against the wrong context (and the lookup was O(n) per line).
    for i, line in enumerate(src_lines):
        # Defer the initial model loading to prevent PicklingError.
        if "self.load_model()" in line and "def __init__" in "".join(src_lines[max(0, i - 2):i]):
            indent = _leading_ws(line)
            new_lines.append(f"{indent}# self.load_model() # Patched: Defer model loading\n")
            new_lines.append(f"{indent}self.model = None\n")
            new_lines.append(f"{indent}self.processor = None\n")
            init_patched = True
            print("Successfully patched __init__ to prevent startup model load.")
            continue

        # Start of the target method: decorate it at the def's own indentation.
        if "def generate_podcast_streaming(self," in line and not generate_patched:
            def_indent = _leading_ws(line)
            new_lines.append(f"{def_indent}@spaces.GPU(duration=120)\n")
            body_indent = def_indent + "    "
            in_generate_method = True

        new_lines.append(line)

        # End of the target signature -- checked on the def line too, so a
        # one-line signature is also handled.
        if in_generate_method and "-> Iterator[tuple]:" in line:
            new_lines.extend(
                f"{body_indent}{code_line}\n" for code_line in lazy_load_code.splitlines()
            )
            in_generate_method = False
            generate_patched = True
            print("Successfully patched generation method for lazy loading.")

    return new_lines, generate_patched, init_patched


# sys.exit() raises SystemExit, which `except Exception` below does not catch,
# so the explicit exits inside the try still terminate the script.
try:
    with open(demo_script_path, "r", encoding="utf-8") as f:
        lines = f.readlines()

    if any(lazy_load_marker in line for line in lines):
        # The clone step is skipped when the checkout exists, so this script
        # can meet a file it already patched; patching again would stack a
        # second decorator and lazy-load block onto the method.
        print("Script already patched. Skipping refactor.")
    else:
        new_lines, patched_generate_method, patched_init = _patch_demo_lines(lines)

        if not patched_generate_method:
            print("\033[91mError: Failed to apply the lazy-loading patch. The target method signature may have changed.\033[0m")
            sys.exit(1)
        if not patched_init:
            # Kept non-fatal: only the generation-method patch is required.
            # The model is presumably still loaded at startup in this case.
            print("Warning: no self.load_model() call found right after __init__; startup model load was not deferred.")

        # --- Write the modified content back to the file ---
        with open(demo_script_path, "w", encoding="utf-8") as f:
            f.writelines(new_lines)

        print("Script patching complete.")

except Exception as e:
    print(f"An error occurred while modifying the script: {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)

# --- 4. Launch the Gradio Demo ---
model_id = "microsoft/VibeVoice-1.5B"
# Use the running interpreter rather than whatever "python" resolves to on
# PATH: the VibeVoice package was pip-installed via sys.executable, so another
# interpreter might not see it.
command = [sys.executable, str(demo_script_path), "--model_path", model_id, "--share"]
print(f"Launching Gradio demo with command: {' '.join(command)}")
result = subprocess.run(command)
# Propagate the demo's exit status so a crash is not reported as success.
sys.exit(result.returncode)