Update app.py
Browse files
app.py
CHANGED
|
@@ -34,35 +34,15 @@ except subprocess.CalledProcessError as e:
|
|
| 34 |
print(f"Error installing package: {e.stderr}")
|
| 35 |
sys.exit(1)
|
| 36 |
|
| 37 |
-
# --- 3. Refactor the demo script
|
| 38 |
demo_script_path = Path("demo/gradio_demo.py")
|
| 39 |
print(f"Refactoring {demo_script_path} for ZeroGPU lazy loading...")
|
| 40 |
|
| 41 |
try:
|
| 42 |
-
|
|
|
|
| 43 |
|
| 44 |
-
# ---
|
| 45 |
-
if "import spaces" not in modified_content:
|
| 46 |
-
modified_content = "import spaces\n" + modified_content
|
| 47 |
-
|
| 48 |
-
# --- Patch 1: Prevent model loading at startup ---
|
| 49 |
-
# Comment out self.load_model() in __init__ to avoid loading on the main CPU process.
|
| 50 |
-
original_init_line = " self.load_model()"
|
| 51 |
-
replacement_init_line = " # self.load_model() # Patched: Defer model loading\n self.model = None\n self.processor = None"
|
| 52 |
-
|
| 53 |
-
if original_init_line in modified_content:
|
| 54 |
-
modified_content = modified_content.replace(original_init_line, replacement_init_line)
|
| 55 |
-
print("Successfully patched __init__ to prevent model loading on startup.")
|
| 56 |
-
else:
|
| 57 |
-
print(f"\033[91mError: Could not find '{original_init_line}' to patch.\033[0m")
|
| 58 |
-
sys.exit(1)
|
| 59 |
-
|
| 60 |
-
# --- Patch 2: Move model loading inside the generation function and add decorator ---
|
| 61 |
-
# This ensures the model is loaded "just-in-time" on the GPU worker with proper precision.
|
| 62 |
-
original_method_signature = " def generate_podcast_streaming(self,"
|
| 63 |
-
|
| 64 |
-
# Define the model loading code to be inserted.
|
| 65 |
-
# We use torch.bfloat16 for a balance of performance and quality.
|
| 66 |
lazy_load_code = """
|
| 67 |
# Patched: Lazy-load model and processor on the GPU worker
|
| 68 |
if self.model is None or self.processor is None:
|
|
@@ -83,28 +63,42 @@ try:
|
|
| 83 |
print("Model and processor loaded successfully on GPU worker.")
|
| 84 |
"""
|
| 85 |
|
| 86 |
-
#
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
#
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
"
|
| 98 |
-
"
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
sys.exit(1)
|
| 106 |
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
| 108 |
print("Script patching complete.")
|
| 109 |
|
| 110 |
except Exception as e:
|
|
|
|
| 34 |
print(f"Error installing package: {e.stderr}")
|
| 35 |
sys.exit(1)
|
| 36 |
|
| 37 |
+
# --- 3. Refactor the demo script using a robust line-by-line patch ---
|
| 38 |
demo_script_path = Path("demo/gradio_demo.py")
|
| 39 |
print(f"Refactoring {demo_script_path} for ZeroGPU lazy loading...")
|
| 40 |
|
| 41 |
try:
|
| 42 |
+
with open(demo_script_path, 'r') as f:
|
| 43 |
+
lines = f.readlines()
|
| 44 |
|
| 45 |
+
# --- Prepare the code blocks to be inserted ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
lazy_load_code = """
|
| 47 |
# Patched: Lazy-load model and processor on the GPU worker
|
| 48 |
if self.model is None or self.processor is None:
|
|
|
|
| 63 |
print("Model and processor loaded successfully on GPU worker.")
|
| 64 |
"""
|
| 65 |
|
| 66 |
+
# --- Perform the line-by-line modifications ---
|
| 67 |
+
new_lines = []
|
| 68 |
+
# Add 'import spaces' at the top if it doesn't exist
|
| 69 |
+
if not any("import spaces" in line for line in lines):
|
| 70 |
+
new_lines.append("import spaces\n")
|
| 71 |
+
|
| 72 |
+
patched = False
|
| 73 |
+
for line in lines:
|
| 74 |
+
# Defer the initial model loading to prevent PicklingError
|
| 75 |
+
if "self.load_model()" in line and "def __init__" in "".join(lines[lines.index(line)-2:lines.index(line)]):
|
| 76 |
+
new_lines.append(" # self.load_model() # Patched: Defer model loading\n")
|
| 77 |
+
new_lines.append(" self.model = None\n")
|
| 78 |
+
new_lines.append(" self.processor = None\n")
|
| 79 |
+
print("Successfully patched __init__ to prevent startup model load.")
|
| 80 |
+
# Find the generation method to add the decorator and lazy-loading logic
|
| 81 |
+
elif "def generate_podcast_streaming(self," in line:
|
| 82 |
+
new_lines.append(" @spaces.GPU(duration=120)\n")
|
| 83 |
+
new_lines.append(line)
|
| 84 |
+
elif "-> Iterator[tuple]:" in line and "generate_podcast_streaming" in new_lines[-1]:
|
| 85 |
+
new_lines.append(line)
|
| 86 |
+
# Prefix each line of the lazy-load snippet with 8 spaces before inserting it
|
| 87 |
+
for code_line in lazy_load_code.strip().split('\n'):
|
| 88 |
+
new_lines.append(' ' * 8 + code_line + '\n')
|
| 89 |
+
patched = True
|
| 90 |
+
print("Successfully patched generation method for lazy loading.")
|
| 91 |
+
else:
|
| 92 |
+
new_lines.append(line)
|
| 93 |
+
|
| 94 |
+
if not patched:
|
| 95 |
+
print("\033[91mError: Failed to apply the lazy-loading patch. The target method signature may have changed.\033[0m")
|
| 96 |
sys.exit(1)
|
| 97 |
|
| 98 |
+
# --- Write the modified content back to the file ---
|
| 99 |
+
with open(demo_script_path, 'w') as f:
|
| 100 |
+
f.writelines(new_lines)
|
| 101 |
+
|
| 102 |
print("Script patching complete.")
|
| 103 |
|
| 104 |
except Exception as e:
|