broadfield-dev committed on
Commit
1cb26d6
·
verified ·
1 Parent(s): 4324db0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -17
app.py CHANGED
@@ -3,6 +3,11 @@ import subprocess
3
  import sys
4
  from pathlib import Path
5
 
 
 
 
 
 
6
  # --- 1. Clone the VibeVoice Repository ---
7
  repo_dir = "VibeVoice"
8
  if not os.path.exists(repo_dir):
@@ -21,10 +26,11 @@ if not os.path.exists(repo_dir):
21
  else:
22
  print("Repository already exists. Skipping clone.")
23
 
24
- # --- 2. Install the Package ---
25
  os.chdir(repo_dir)
26
  print(f"Changed directory to: {os.getcwd()}")
27
 
 
28
  print("Installing the VibeVoice package...")
29
  try:
30
  subprocess.run(
@@ -38,38 +44,83 @@ except subprocess.CalledProcessError as e:
38
  print(f"Error installing package: {e.stderr}")
39
  sys.exit(1)
40
 
41
- # --- 3. Modify the demo script for CPU execution (Robust Method) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  demo_script_path = Path("demo/gradio_demo.py")
43
- print(f"Modifying {demo_script_path} for CPU execution...")
44
 
45
  try:
46
- # Read the entire file content
47
  file_content = demo_script_path.read_text()
 
 
 
48
 
49
- # Define the original GPU-specific model loading block
50
- original_block = """ self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
51
  self.model_path,
52
  torch_dtype=torch.bfloat16,
53
  device_map='cuda',
54
  attn_implementation="flash_attention_2",
55
  )"""
56
 
57
- # Define the new CPU-compatible block
58
- replacement_block = """ self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  self.model_path,
60
  torch_dtype=torch.float32, # Use float32 for CPU
61
  device_map="cpu",
62
  )"""
63
 
64
- # Replace the entire block
65
- if original_block in file_content:
66
- modified_content = file_content.replace(original_block, replacement_block)
67
-
68
- # Write the modified content back to the file
69
- demo_script_path.write_text(modified_content)
70
- print("Script modified successfully.")
71
- else:
72
- print("Warning: GPU-specific model loading block not found. The script might have been updated. Proceeding without modification.")
73
 
74
  except Exception as e:
75
  print(f"An error occurred while modifying the script: {e}")
 
3
  import sys
4
  from pathlib import Path
5
 
6
+ # --- 0. Hardcoded Toggle for Execution Environment ---
7
+ # Set this to True to use Hugging Face ZeroGPU
8
+ # Set this to False to use a pure CPU environment
9
+ USE_ZEROGPU = True
10
+
11
  # --- 1. Clone the VibeVoice Repository ---
12
  repo_dir = "VibeVoice"
13
  if not os.path.exists(repo_dir):
 
26
  else:
27
  print("Repository already exists. Skipping clone.")
28
 
29
+ # --- 2. Install Dependencies ---
30
  os.chdir(repo_dir)
31
  print(f"Changed directory to: {os.getcwd()}")
32
 
33
+ # Install the main package
34
  print("Installing the VibeVoice package...")
35
  try:
36
  subprocess.run(
 
44
  print(f"Error installing package: {e.stderr}")
45
  sys.exit(1)
46
 
47
+ # Install 'spaces' if using ZeroGPU
48
+ if USE_ZEROGPU:
49
+ print("Installing the 'spaces' library for ZeroGPU...")
50
+ try:
51
+ subprocess.run(
52
+ [sys.executable, "-m", "pip", "install", "huggingface-hub", "gradio", "spaces"],
53
+ check=True,
54
+ capture_output=True,
55
+ text=True
56
+ )
57
+ print("'spaces' library installed successfully.")
58
+ except subprocess.CalledProcessError as e:
59
+ print(f"Error installing 'spaces' library: {e.stderr}")
60
+ sys.exit(1)
61
+
62
+
63
+ # --- 3. Modify the demo script based on the toggle ---
64
  demo_script_path = Path("demo/gradio_demo.py")
65
+ print(f"Reading {demo_script_path}...")
66
 
67
  try:
 
68
  file_content = demo_script_path.read_text()
69
+
70
+ if USE_ZEROGPU:
71
+ print("Optimizing for ZeroGPU execution...")
72
 
73
+ # Ensure the original GPU block is present
74
+ original_block = """ self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
75
  self.model_path,
76
  torch_dtype=torch.bfloat16,
77
  device_map='cuda',
78
  attn_implementation="flash_attention_2",
79
  )"""
80
 
81
+ if original_block in file_content:
82
+ # Add 'import spaces' at the beginning of the file
83
+ modified_content = "import spaces\n" + file_content
84
+
85
+ # Decorate the model loading and generation functions with @spaces.GPU
86
+ # This is a robust way to ensure both setup and inference get GPU access
87
+ modified_content = modified_content.replace(
88
+ "class VibeVoiceGradioInterface:",
89
+ "@spaces.GPU\nclass VibeVoiceGradioInterface:"
90
+ )
91
+ print("Script modified for ZeroGPU successfully.")
92
+
93
+ # Write the modified content back to the file
94
+ demo_script_path.write_text(modified_content)
95
+ else:
96
+ print("Warning: Original GPU-specific model loading block not found. The script might have been updated. Proceeding with potential ZeroGPU compatibility.")
97
+
98
+ else:
99
+ print("Modifying for CPU execution...")
100
+ # Define the original GPU-specific model loading block
101
+ original_block = """ self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
102
+ self.model_path,
103
+ torch_dtype=torch.bfloat16,
104
+ device_map='cuda',
105
+ attn_implementation="flash_attention_2",
106
+ )"""
107
+
108
+ # Define the new CPU-compatible block
109
+ replacement_block = """ self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
110
  self.model_path,
111
  torch_dtype=torch.float32, # Use float32 for CPU
112
  device_map="cpu",
113
  )"""
114
 
115
+ # Replace the entire block
116
+ if original_block in file_content:
117
+ modified_content = file_content.replace(original_block, replacement_block)
118
+
119
+ # Write the modified content back to the file
120
+ demo_script_path.write_text(modified_content)
121
+ print("Script modified for CPU successfully.")
122
+ else:
123
+ print("Warning: GPU-specific model loading block not found. The script might have been updated. Proceeding without modification.")
124
 
125
  except Exception as e:
126
  print(f"An error occurred while modifying the script: {e}")