"""Bootstrap the VibeVoice Gradio demo.

Steps, in order: clone the VibeVoice repository if it is not already present,
install it in editable mode, patch ``demo/gradio_demo.py`` for the selected
execution mode (see ``USE_ZEROGPU``), and launch the demo.
"""

import os
import re
import subprocess
import sys
from pathlib import Path

# True: decorate the generation method with ``@spaces.GPU`` and leave the
# model-loading block untouched.  False: rewrite the model-loading block so the
# model is loaded on CPU in float32.
USE_ZEROGPU = True

REPO_URL = "https://github.com/microsoft/VibeVoice.git"
repo_dir = "VibeVoice"
# Relative to ``repo_dir``; only resolvable after ``os.chdir(repo_dir)``.
demo_script_path = Path("demo/gradio_demo.py")
model_id = "microsoft/VibeVoice-1.5B"

_GPU_DECORATOR = "@spaces.GPU(duration=120)"

# Matches the generation method's ``def`` line at any indentation.  The optional
# ``decorated`` group captures a directly preceding ``@spaces.GPU`` line so an
# already-patched method can be recognised and skipped.
_GENERATE_METHOD_RE = re.compile(
    r"^(?P<decorated>[ \t]*@spaces\.GPU\b[^\n]*\n)?"
    r"(?P<indent>[ \t]*)def generate_podcast_streaming\(self,",
    re.MULTILINE,
)

_MODEL_CALL = "self.model = VibeVoiceForConditionalGenerationInference.from_pretrained("

# Model-loading block this script expects to find in the demo (lines stripped
# of indentation; matching is indentation-agnostic).
_CUDA_MODEL_LINES = (
    _MODEL_CALL,
    "self.model_path,",
    "torch_dtype=torch.bfloat16,",
    "device_map='cuda',",
    'attn_implementation="flash_attention_2",',
    ")",
)

# Arguments written in place of the block above when patching for CPU.
_CPU_MODEL_ARGS = (
    "self.model_path,",
    "torch_dtype=torch.float32,  # Use float32 for CPU",
    'device_map="cpu",',
)


class PatchError(ValueError):
    """Raised when the text to patch does not contain the expected code."""


def _block_pattern(lines):
    """Compile a regex matching ``lines`` consecutively, ignoring indentation.

    The first line's leading whitespace is captured as the ``indent`` group;
    leading and trailing blanks on every line are tolerated.
    """
    first, *rest = lines
    tail = "".join(rf"[ \t]*\n[ \t]*{re.escape(line)}" for line in rest)
    return re.compile(rf"^(?P<indent>[ \t]*){re.escape(first)}{tail}", re.MULTILINE)


_CUDA_MODEL_RE = _block_pattern(_CUDA_MODEL_LINES)
_CPU_MODEL_RE = _block_pattern((_MODEL_CALL, *_CPU_MODEL_ARGS, ")"))


def apply_zerogpu_patch(content: str) -> str:
    """Return ``content`` with ``import spaces`` and the GPU decorator added.

    ``import spaces`` is prepended unless that text already occurs in
    ``content``.  Every ``def generate_podcast_streaming(self,`` line gets
    ``@spaces.GPU(duration=120)`` inserted above it at the same indentation,
    unless the line directly above is already an ``@spaces.GPU`` decorator, so
    applying the patch twice yields the same text.

    Raises:
        PatchError: if no ``generate_podcast_streaming`` method is found.
    """
    if _GENERATE_METHOD_RE.search(content) is None:
        raise PatchError(
            "Could not find the generation method signature to apply the GPU decorator."
        )

    def _decorate(match):
        if match.group("decorated") is not None:
            return match.group(0)  # patched on a previous run
        return f"{match.group('indent')}{_GPU_DECORATOR}\n{match.group(0)}"

    patched = _GENERATE_METHOD_RE.sub(_decorate, content)
    if "import spaces" not in patched:
        patched = "import spaces\n" + patched
    return patched


def apply_cpu_patch(content: str) -> str:
    """Return ``content`` with the CUDA model-loading block rewritten for CPU.

    The bfloat16 / ``device_map='cuda'`` / flash-attention ``from_pretrained``
    call is replaced by a float32 / ``device_map="cpu"`` call that keeps the
    original statement's indentation.  If only the CPU form is present (the
    file was patched earlier), ``content`` is returned unchanged.

    Raises:
        PatchError: if neither the CUDA block nor the CPU block is found.
    """

    def _cpu_block(match):
        indent = match.group("indent")
        inner = indent + "    "
        rewritten = [indent + _MODEL_CALL]
        rewritten.extend(inner + arg for arg in _CPU_MODEL_ARGS)
        rewritten.append(indent + ")")
        return "\n".join(rewritten)

    patched, replaced = _CUDA_MODEL_RE.subn(_cpu_block, content)
    if replaced:
        return patched
    if _CPU_MODEL_RE.search(content) is not None:
        return content
    raise PatchError("The original model loading block was not found for CPU patching.")


def _run_or_exit(args, failure_prefix: str) -> None:
    """Run ``args`` quietly; on failure print ``failure_prefix`` + detail and exit 1."""
    try:
        subprocess.run(args, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"{failure_prefix}: {e.stderr}")
        sys.exit(1)
    except OSError as e:
        # E.g. the executable itself is missing; there is no stderr to show.
        print(f"{failure_prefix}: {e}")
        sys.exit(1)


def ensure_repository() -> None:
    """Clone the VibeVoice repository into ``repo_dir`` unless that path exists."""
    if os.path.exists(repo_dir):
        print("Repository already exists. Skipping clone.")
        return
    print("Cloning the VibeVoice repository...")
    # The target directory is passed explicitly so it always equals ``repo_dir``.
    _run_or_exit(["git", "clone", REPO_URL, repo_dir], "Error cloning repository")
    print("Repository cloned successfully.")


def install_package() -> None:
    """Install the package in the current directory in editable mode via pip."""
    print("Installing the VibeVoice package in editable mode...")
    _run_or_exit(
        [sys.executable, "-m", "pip", "install", "-e", "."],
        "Error installing package",
    )
    print("Package installed successfully.")


def patch_demo_script(path: Path, use_zerogpu: bool) -> None:
    """Patch the demo script at ``path`` in place for the chosen execution mode.

    Exits the process with status 1 if the expected code is not found or the
    file cannot be read or written.  The file is rewritten only when the patch
    actually changed its text.
    """
    print(f"Reading {path} to apply environment-specific modifications...")
    try:
        original = path.read_text(encoding="utf-8")
        if use_zerogpu:
            print("Configuring for ZeroGPU execution while keeping Flash Attention...")
            patched = apply_zerogpu_patch(original)
            if patched != original:
                print("Successfully applied GPU decorator to the generation method.")
            else:
                print("GPU decorator already present. Nothing to change.")
            print("Model loading block remains unchanged to explicitly use Flash Attention.")
        else:
            print("Modifying for pure CPU execution...")
            patched = apply_cpu_patch(original)
            if patched != original:
                print("Script modified for CPU successfully.")
            else:
                print("Script already modified for CPU. Nothing to change.")
        if patched != original:
            path.write_text(patched, encoding="utf-8")
    except PatchError as e:
        print(f"\033[91mError: {e}\033[0m")
        sys.exit(1)
    except (OSError, UnicodeError) as e:
        print(f"An error occurred while modifying the script: {e}")
        sys.exit(1)


def launch_demo() -> int:
    """Run the patched demo and return its exit status.

    The demo is started with ``sys.executable`` — the same interpreter pip
    installed the package into — rather than whatever ``python`` is on PATH.
    """
    command = [
        sys.executable,
        str(demo_script_path),
        "--model_path",
        model_id,
        "--share",
    ]
    print(f"Launching Gradio demo with command: {' '.join(command)}")
    return subprocess.run(command).returncode


def main() -> int:
    """Clone, install, patch and launch; return the demo's exit status."""
    ensure_repository()

    os.chdir(repo_dir)
    print(f"Changed directory to: {os.getcwd()}")

    install_package()
    patch_demo_script(demo_script_path, use_zerogpu=USE_ZEROGPU)
    return launch_demo()


if __name__ == "__main__":
    sys.exit(main())