broadfield-dev committed on
Commit
b338394
·
verified ·
1 Parent(s): 31adbe7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -78
app.py CHANGED
@@ -3,10 +3,6 @@ import subprocess
3
  import sys
4
  from pathlib import Path
5
 
6
- # --- 0. Hardcoded Toggle for Execution Environment ---
7
- # Ensure this is set to True to use the GPU with quantization
8
- USE_ZEROGPU = False
9
-
10
  # --- 1. Clone the VibeVoice Repository ---
11
  repo_dir = "VibeVoice"
12
  if not os.path.exists(repo_dir):
@@ -14,9 +10,7 @@ if not os.path.exists(repo_dir):
14
  try:
15
  subprocess.run(
16
  ["git", "clone", "https://github.com/microsoft/VibeVoice.git"],
17
- check=True,
18
- capture_output=True,
19
- text=True
20
  )
21
  print("Repository cloned successfully.")
22
  except subprocess.CalledProcessError as e:
@@ -29,8 +23,7 @@ else:
29
  os.chdir(repo_dir)
30
  print(f"Changed directory to: {os.getcwd()}")
31
 
32
- # Install bitsandbytes for quantization to reduce memory usage
33
- print("Installing bitsandbytes for quantization...")
34
  try:
35
  subprocess.run(
36
  [sys.executable, "-m", "pip", "install", "bitsandbytes"],
@@ -45,96 +38,103 @@ print("Installing the VibeVoice package in editable mode...")
45
  try:
46
  subprocess.run(
47
  [sys.executable, "-m", "pip", "install", "-e", "."],
48
- check=True,
49
- capture_output=True,
50
- text=True
51
  )
52
  print("Package installed successfully.")
53
  except subprocess.CalledProcessError as e:
54
  print(f"Error installing package: {e.stderr}")
55
  sys.exit(1)
56
 
57
- # --- 3. Modify the demo script for a memory-constrained environment ---
58
  demo_script_path = Path("demo/gradio_demo.py")
59
- print(f"Reading {demo_script_path} to apply environment-specific modifications...")
60
 
61
  try:
62
  modified_content = demo_script_path.read_text()
63
 
64
- # Define the original model loading block to be replaced.
65
- original_model_lines = [
66
- ' self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(',
67
- ' self.model_path,',
68
- ' torch_dtype=torch.bfloat16,',
69
- " device_map='cuda',",
70
- ' attn_implementation="flash_attention_2",',
71
- ' )'
72
- ]
73
- original_model_block = "\n".join(original_model_lines)
74
 
75
- # Define the generation method signature to add the decorator to.
 
 
 
 
 
 
 
 
76
  original_method_signature = " def generate_podcast_streaming(self,"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- if USE_ZEROGPU:
79
- print("Optimizing for ZeroGPU with 8-bit quantization...")
 
 
 
 
80
 
81
- # Add necessary imports if they are not already there.
82
- if "import spaces" not in modified_content:
83
- modified_content = "import spaces\n" + modified_content
84
-
85
- # New block for ZeroGPU with 8-bit quantization.
86
- # This is the key change to solve the memory issue.
87
- replacement_model_lines_gpu = [
88
- ' self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(',
89
- ' self.model_path,',
90
- ' load_in_8bit=True,',
91
- ' device_map="auto",',
92
- ' )'
93
- ]
94
- replacement_model_block_gpu = "\n".join(replacement_model_lines_gpu)
95
 
96
- # Add the @spaces.GPU decorator with correct indentation.
97
- replacement_method_signature_gpu = " @spaces.GPU(duration=120)\n" + original_method_signature
98
-
99
- # --- Apply Patches for GPU ---
100
-
101
- # Patch 1: Decorate the generation method
102
- if original_method_signature in modified_content:
103
- modified_content = modified_content.replace(original_method_signature, replacement_method_signature_gpu)
104
- print("Successfully applied GPU decorator to the generation method.")
105
- else:
106
- print("\033[91mError: Could not find the generation method signature to patch.\033[0m")
107
- sys.exit(1)
108
-
109
- # Patch 2: Modify the model loading to use 8-bit quantization
110
- if original_model_block in modified_content:
111
- modified_content = modified_content.replace(original_model_block, replacement_model_block_gpu)
112
- print("Successfully patched model loading for 8-bit quantization.")
113
- else:
114
- print("\033[91mError: The original model loading block was not found.\033[0m")
115
- sys.exit(1)
116
-
117
- else: # Pure CPU execution (not recommended on ZeroGPU hardware)
118
- # This block is unlikely to be used but kept for completeness
119
- print("Modifying for pure CPU execution...")
120
- replacement_model_lines_cpu = [
121
- ' self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(',
122
- ' self.model_path,',
123
- ' torch_dtype=torch.float32,',
124
- ' device_map="cpu",',
125
- ' )'
126
- ]
127
- replacement_model_block_cpu = "\n".join(replacement_model_lines_cpu)
128
- if original_model_block in modified_content:
129
- modified_content = modified_content.replace(original_model_block, replacement_model_block_cpu)
130
- else:
131
- print("\033[91mError: The original model loading block was not found for CPU patching.\033[0m")
132
- sys.exit(1)
133
 
134
  demo_script_path.write_text(modified_content)
 
135
 
136
  except Exception as e:
137
  print(f"An error occurred while modifying the script: {e}")
 
 
138
  sys.exit(1)
139
 
140
  # --- 4. Launch the Gradio Demo ---
 
3
  import sys
4
  from pathlib import Path
5
 
 
 
 
 
6
  # --- 1. Clone the VibeVoice Repository ---
7
  repo_dir = "VibeVoice"
8
  if not os.path.exists(repo_dir):
 
10
  try:
11
  subprocess.run(
12
  ["git", "clone", "https://github.com/microsoft/VibeVoice.git"],
13
+ check=True, capture_output=True, text=True
 
 
14
  )
15
  print("Repository cloned successfully.")
16
  except subprocess.CalledProcessError as e:
 
23
  os.chdir(repo_dir)
24
  print(f"Changed directory to: {os.getcwd()}")
25
 
26
+ print("Installing bitsandbytes for potential quantization...")
 
27
  try:
28
  subprocess.run(
29
  [sys.executable, "-m", "pip", "install", "bitsandbytes"],
 
38
  try:
39
  subprocess.run(
40
  [sys.executable, "-m", "pip", "install", "-e", "."],
41
+ check=True, capture_output=True, text=True
 
 
42
  )
43
  print("Package installed successfully.")
44
  except subprocess.CalledProcessError as e:
45
  print(f"Error installing package: {e.stderr}")
46
  sys.exit(1)
47
 
48
+ # --- 3. Refactor the demo script for ZeroGPU compatibility ---
49
  demo_script_path = Path("demo/gradio_demo.py")
50
+ print(f"Refactoring {demo_script_path} for ZeroGPU lazy loading...")
51
 
52
  try:
53
  modified_content = demo_script_path.read_text()
54
 
55
+ # --- Add necessary imports ---
56
+ if "import spaces" not in modified_content:
57
+ modified_content = "import spaces\n" + modified_content
58
+
59
+ # --- Patch 1: Prevent model loading at startup ---
60
+ # We comment out the self.load_model() call in the __init__ method.
61
+ # This stops the main CPU process from loading the heavyweight model.
62
+ original_init_line = " self.load_model()"
63
+ replacement_init_line = " # self.load_model() # Patched: Defer model loading to the GPU worker\n self.model = None\n self.processor = None"
 
64
 
65
+ if original_init_line in modified_content:
66
+ modified_content = modified_content.replace(original_init_line, replacement_init_line)
67
+ print("Successfully patched __init__ to prevent model loading on startup.")
68
+ else:
69
+ print(f"\033[91mError: Could not find '{original_init_line}' to patch.\033[0m")
70
+ sys.exit(1)
71
+
72
+ # --- Patch 2: Move model loading inside the generation function and add decorator ---
73
+ # This ensures the model is loaded "just-in-time" on the GPU worker.
74
  original_method_signature = " def generate_podcast_streaming(self,"
75
+
76
+ # Define the model loading code to be inserted.
77
+ # We will use 8-bit quantization to be safe with memory.
78
+ lazy_load_code = """
79
+ # Patched: Lazy-load model and processor on the GPU worker
80
+ if self.model is None or self.processor is None:
81
+ print("Loading processor & model for the first time on GPU worker...")
82
+ self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)
83
+ self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
84
+ self.model_path,
85
+ load_in_8bit=True,
86
+ device_map="auto",
87
+ )
88
+ self.model.eval()
89
+ self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
90
+ self.model.model.noise_scheduler.config,
91
+ algorithm_type='sde-dpmsolver++',
92
+ beta_schedule='squaredcos_cap_v2'
93
+ )
94
+ self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
95
+ print("Model and processor loaded successfully on GPU worker.")
96
+
97
+ """
98
+
99
+ # We add the decorator and the lazy loading code.
100
+ replacement_block = (
101
+ " @spaces.GPU(duration=120)\n" +
102
+ original_method_signature +
103
+ "\n" +
104
+ " " * 8 + lazy_load_code.strip().replace("\n", "\n" + " " * 8)
105
+ )
106
 
107
+ if original_method_signature in modified_content:
108
+ # Find the start of the method and insert our block right after the signature.
109
+ # We need to find the full method signature to insert code into it.
110
+ method_start_index = modified_content.find(original_method_signature)
111
+ # Find the end of the signature line
112
+ signature_end_index = modified_content.find("-> Iterator[tuple]:", method_start_index) + len("-> Iterator[tuple]:")
113
 
114
+ # Reconstruct the content
115
+ pre_method = modified_content[:method_start_index]
116
+ method_signature_and_body = modified_content[method_start_index:]
 
 
 
 
 
 
 
 
 
 
 
117
 
118
+ # Decorate the original signature
119
+ decorated_signature = " @spaces.GPU(duration=120)\n" + original_method_signature
120
+ method_signature_and_body = method_signature_and_body.replace(original_method_signature, decorated_signature)
121
+
122
+ # Insert the lazy loading code after the signature line
123
+ final_method = method_signature_and_body.replace("-> Iterator[tuple]:", "-> Iterator[tuple]:\n" + lazy_load_code, 1)
124
+
125
+ modified_content = pre_method + final_method
126
+ print("Successfully refactored generation method for lazy loading on GPU.")
127
+ else:
128
+ print(f"\033[91mError: Could not find '{original_method_signature}' to patch.\033[0m")
129
+ sys.exit(1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  demo_script_path.write_text(modified_content)
132
+ print("Script patching complete.")
133
 
134
  except Exception as e:
135
  print(f"An error occurred while modifying the script: {e}")
136
+ import traceback
137
+ traceback.print_exc()
138
  sys.exit(1)
139
 
140
  # --- 4. Launch the Gradio Demo ---