broadfield-dev committed on
Commit
e59066e
·
verified ·
1 Parent(s): a549db3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -11
app.py CHANGED
@@ -4,8 +4,8 @@ import sys
4
  from pathlib import Path
5
 
6
  # --- 0. Hardcoded Toggle for Execution Environment ---
7
- # Ensure this is set to True to use the GPU
8
- USE_ZEROGPU = False
9
 
10
  # --- 1. Clone the VibeVoice Repository ---
11
  repo_dir = "VibeVoice"
@@ -29,6 +29,18 @@ else:
29
  os.chdir(repo_dir)
30
  print(f"Changed directory to: {os.getcwd()}")
31
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  print("Installing the VibeVoice package in editable mode...")
33
  try:
34
  subprocess.run(
@@ -42,7 +54,7 @@ except subprocess.CalledProcessError as e:
42
  print(f"Error installing package: {e.stderr}")
43
  sys.exit(1)
44
 
45
- # --- 3. Modify the demo script to be environment-aware ---
46
  demo_script_path = Path("demo/gradio_demo.py")
47
  print(f"Reading {demo_script_path} to apply environment-specific modifications...")
48
 
@@ -64,18 +76,19 @@ try:
64
  original_method_signature = " def generate_podcast_streaming(self,"
65
 
66
  if USE_ZEROGPU:
67
- print("Optimizing for ZeroGPU execution with robust attention...")
68
 
69
- # Add 'import spaces' if it's not already there.
70
  if "import spaces" not in modified_content:
71
  modified_content = "import spaces\n" + modified_content
72
 
73
- # New block for ZeroGPU model loading: remove `attn_implementation` for auto-detection.
 
74
  replacement_model_lines_gpu = [
75
  ' self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(',
76
  ' self.model_path,',
77
- ' torch_dtype=torch.bfloat16,',
78
- " device_map='cuda',",
79
  ' )'
80
  ]
81
  replacement_model_block_gpu = "\n".join(replacement_model_lines_gpu)
@@ -93,15 +106,16 @@ try:
93
  print("\033[91mError: Could not find the generation method signature to patch.\033[0m")
94
  sys.exit(1)
95
 
96
- # Patch 2: Modify the model loading to allow auto-detection of attention
97
  if original_model_block in modified_content:
98
  modified_content = modified_content.replace(original_model_block, replacement_model_block_gpu)
99
- print("Successfully patched model loading to remove hardcoded Flash Attention.")
100
  else:
101
  print("\033[91mError: The original model loading block was not found.\033[0m")
102
  sys.exit(1)
103
 
104
- else: # Pure CPU execution
 
105
  print("Modifying for pure CPU execution...")
106
  replacement_model_lines_cpu = [
107
  ' self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(',
 
4
  from pathlib import Path
5
 
6
  # --- 0. Hardcoded Toggle for Execution Environment ---
7
+ # Ensure this is set to True to use the GPU with quantization
8
+ USE_ZEROGPU = True
9
 
10
  # --- 1. Clone the VibeVoice Repository ---
11
  repo_dir = "VibeVoice"
 
29
  os.chdir(repo_dir)
30
  print(f"Changed directory to: {os.getcwd()}")
31
 
32
+ # Install bitsandbytes for quantization to reduce memory usage
33
+ print("Installing bitsandbytes for quantization...")
34
+ try:
35
+ subprocess.run(
36
+ [sys.executable, "-m", "pip", "install", "bitsandbytes"],
37
+ check=True, capture_output=True, text=True
38
+ )
39
+ print("bitsandbytes installed successfully.")
40
+ except subprocess.CalledProcessError as e:
41
+ print(f"Error installing bitsandbytes: {e.stderr}")
42
+ sys.exit(1)
43
+
44
  print("Installing the VibeVoice package in editable mode...")
45
  try:
46
  subprocess.run(
 
54
  print(f"Error installing package: {e.stderr}")
55
  sys.exit(1)
56
 
57
+ # --- 3. Modify the demo script for a memory-constrained environment ---
58
  demo_script_path = Path("demo/gradio_demo.py")
59
  print(f"Reading {demo_script_path} to apply environment-specific modifications...")
60
 
 
76
  original_method_signature = " def generate_podcast_streaming(self,"
77
 
78
  if USE_ZEROGPU:
79
+ print("Optimizing for ZeroGPU with 8-bit quantization...")
80
 
81
+ # Add necessary imports if they are not already there.
82
  if "import spaces" not in modified_content:
83
  modified_content = "import spaces\n" + modified_content
84
 
85
+ # New block for ZeroGPU with 8-bit quantization.
86
+ # This is the key change to solve the memory issue.
87
  replacement_model_lines_gpu = [
88
  ' self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(',
89
  ' self.model_path,',
90
+ ' load_in_8bit=True,',
91
+ ' device_map="auto",',
92
  ' )'
93
  ]
94
  replacement_model_block_gpu = "\n".join(replacement_model_lines_gpu)
 
106
  print("\033[91mError: Could not find the generation method signature to patch.\033[0m")
107
  sys.exit(1)
108
 
109
+ # Patch 2: Modify the model loading to use 8-bit quantization
110
  if original_model_block in modified_content:
111
  modified_content = modified_content.replace(original_model_block, replacement_model_block_gpu)
112
+ print("Successfully patched model loading for 8-bit quantization.")
113
  else:
114
  print("\033[91mError: The original model loading block was not found.\033[0m")
115
  sys.exit(1)
116
 
117
+ else: # Pure CPU execution (not recommended on ZeroGPU hardware)
118
+ # This block is unlikely to be used but kept for completeness
119
  print("Modifying for pure CPU execution...")
120
  replacement_model_lines_cpu = [
121
  ' self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(',