broadfield-dev committed on
Commit
2877f43
·
verified ·
1 Parent(s): 138e306

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -44
app.py CHANGED
@@ -34,35 +34,15 @@ except subprocess.CalledProcessError as e:
34
  print(f"Error installing package: {e.stderr}")
35
  sys.exit(1)
36
 
37
- # --- 3. Refactor the demo script for ZeroGPU compatibility ---
38
  demo_script_path = Path("demo/gradio_demo.py")
39
  print(f"Refactoring {demo_script_path} for ZeroGPU lazy loading...")
40
 
41
  try:
42
- modified_content = demo_script_path.read_text()
 
43
 
44
- # --- Add necessary imports ---
45
- if "import spaces" not in modified_content:
46
- modified_content = "import spaces\n" + modified_content
47
-
48
- # --- Patch 1: Prevent model loading at startup ---
49
- # Comment out self.load_model() in __init__ to avoid loading on the main CPU process.
50
- original_init_line = " self.load_model()"
51
- replacement_init_line = " # self.load_model() # Patched: Defer model loading\n self.model = None\n self.processor = None"
52
-
53
- if original_init_line in modified_content:
54
- modified_content = modified_content.replace(original_init_line, replacement_init_line)
55
- print("Successfully patched __init__ to prevent model loading on startup.")
56
- else:
57
- print(f"\033[91mError: Could not find '{original_init_line}' to patch.\033[0m")
58
- sys.exit(1)
59
-
60
- # --- Patch 2: Move model loading inside the generation function and add decorator ---
61
- # This ensures the model is loaded "just-in-time" on the GPU worker with proper precision.
62
- original_method_signature = " def generate_podcast_streaming(self,"
63
-
64
- # Define the model loading code to be inserted.
65
- # We use torch.bfloat16 for a balance of performance and quality.
66
  lazy_load_code = """
67
  # Patched: Lazy-load model and processor on the GPU worker
68
  if self.model is None or self.processor is None:
@@ -83,28 +63,42 @@ try:
83
  print("Model and processor loaded successfully on GPU worker.")
84
  """
85
 
86
- # We need to find the full method signature to insert code into it.
87
- full_method_signature_line = None
88
- for line in modified_content.splitlines():
89
- if "def generate_podcast_streaming" in line:
90
- full_method_signature_line = line.strip()
91
- break
92
-
93
- if full_method_signature_line:
94
- # We find the end of the method signature to insert our code block.
95
- target_to_replace = full_method_signature_line + "\n"
96
- replacement_block = (
97
- " @spaces.GPU(duration=120)\n" +
98
- " " + full_method_signature_line + "\n" +
99
- lazy_load_code
100
- )
101
- modified_content = modified_content.replace(target_to_replace, replacement_block, 1)
102
- print("Successfully refactored generation method for lazy loading on GPU.")
103
- else:
104
- print(f"\033[91mError: Could not find full method signature for 'generate_podcast_streaming' to patch.\033[0m")
 
 
 
 
 
 
 
 
 
 
 
105
  sys.exit(1)
106
 
107
- demo_script_path.write_text(modified_content)
 
 
 
108
  print("Script patching complete.")
109
 
110
  except Exception as e:
 
34
  print(f"Error installing package: {e.stderr}")
35
  sys.exit(1)
36
 
37
+ # --- 3. Refactor the demo script using a robust line-by-line patch ---
38
  demo_script_path = Path("demo/gradio_demo.py")
39
  print(f"Refactoring {demo_script_path} for ZeroGPU lazy loading...")
40
 
41
  try:
42
+ with open(demo_script_path, 'r') as f:
43
+ lines = f.readlines()
44
 
45
+ # --- Prepare the code blocks to be inserted ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  lazy_load_code = """
47
  # Patched: Lazy-load model and processor on the GPU worker
48
  if self.model is None or self.processor is None:
 
63
  print("Model and processor loaded successfully on GPU worker.")
64
  """
65
 
66
+ # --- Perform the line-by-line modifications ---
67
+ new_lines = []
68
+ # Add 'import spaces' at the top if it doesn't exist
69
+ if not any("import spaces" in line for line in lines):
70
+ new_lines.append("import spaces\n")
71
+
72
+ patched = False
73
+ for line in lines:
74
+ # Defer the initial model loading to prevent PicklingError
75
+ if "self.load_model()" in line and "def __init__" in "".join(lines[lines.index(line)-2:lines.index(line)]):
76
+ new_lines.append(" # self.load_model() # Patched: Defer model loading\n")
77
+ new_lines.append(" self.model = None\n")
78
+ new_lines.append(" self.processor = None\n")
79
+ print("Successfully patched __init__ to prevent startup model load.")
80
+ # Find the generation method to add the decorator and lazy-loading logic
81
+ elif "def generate_podcast_streaming(self," in line:
82
+ new_lines.append(" @spaces.GPU(duration=120)\n")
83
+ new_lines.append(line)
84
+ elif "-> Iterator[tuple]:" in line and "generate_podcast_streaming" in new_lines[-1]:
85
+ new_lines.append(line)
86
+ # Indent the lazy load code correctly
87
+ for code_line in lazy_load_code.strip().split('\n'):
88
+ new_lines.append(' ' * 8 + code_line + '\n')
89
+ patched = True
90
+ print("Successfully patched generation method for lazy loading.")
91
+ else:
92
+ new_lines.append(line)
93
+
94
+ if not patched:
95
+ print("\033[91mError: Failed to apply the lazy-loading patch. The target method signature may have changed.\033[0m")
96
  sys.exit(1)
97
 
98
+ # --- Write the modified content back to the file ---
99
+ with open(demo_script_path, 'w') as f:
100
+ f.writelines(new_lines)
101
+
102
  print("Script patching complete.")
103
 
104
  except Exception as e: