Aloukik21 commited on
Commit
4994fd0
·
verified ·
1 Parent(s): 42e02b3

Add RunPod handler with cleanup support

Browse files
Files changed (1) hide show
  1. rp_handler.py +417 -0
rp_handler.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RunPod Serverless Handler - Wrapper for AI-Toolkit
3
+ Does NOT modify ai-toolkit code, only wraps it
4
+
5
+ Supports RunPod model caching via HuggingFace integration.
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import subprocess
11
+ import traceback
12
+ import logging
13
+ import uuid
14
+ from pathlib import Path
15
+
16
+ # =============================================================================
17
+ # Environment Setup (must be before other imports)
18
+ # =============================================================================
19
+
20
+ # RunPod cache paths
21
+ RUNPOD_CACHE_BASE = "/runpod-volume/huggingface-cache"
22
+ RUNPOD_HF_CACHE = "/runpod-volume/huggingface-cache/hub"
23
+
24
+ # Check if running on RunPod with cache available
25
+ IS_RUNPOD_CACHE = os.path.exists("/runpod-volume")
26
+
27
+ if IS_RUNPOD_CACHE:
28
+ # Use RunPod's cache directory for HuggingFace downloads
29
+ os.environ["HF_HOME"] = RUNPOD_CACHE_BASE
30
+ os.environ["HUGGINGFACE_HUB_CACHE"] = RUNPOD_HF_CACHE
31
+ os.environ["TRANSFORMERS_CACHE"] = RUNPOD_HF_CACHE
32
+ os.environ["HF_DATASETS_CACHE"] = f"{RUNPOD_CACHE_BASE}/datasets"
33
+
34
+ # Performance and telemetry settings
35
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
36
+ os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
37
+ os.environ["DISABLE_TELEMETRY"] = "YES"
38
+
39
+ # Get HF token from environment
40
+ HF_TOKEN = os.environ.get("HF_TOKEN", "")
41
+ if HF_TOKEN:
42
+ os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
43
+
44
+ SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
45
+ AI_TOOLKIT_DIR = os.path.join(SCRIPT_DIR, "ai-toolkit")
46
+
47
+ import runpod
48
+ import torch
49
+ import yaml
50
+ import gc
51
+ import shutil
52
+
53
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
54
+ logger = logging.getLogger(__name__)
55
+
56
+ # Track current loaded model for cleanup
57
+ CURRENT_MODEL = None
58
+
59
# =============================================================================
# Model Configuration
# =============================================================================

# Preset key -> example YAML shipped in ai-toolkit/config/examples (unmodified).
MODEL_PRESETS = dict(
    wan21_1b="train_lora_wan21_1b_24gb.yaml",
    wan21_14b="train_lora_wan21_14b_24gb.yaml",
    wan22_14b="train_lora_wan22_14b_24gb.yaml",
    qwen_image="train_lora_qwen_image_24gb.yaml",
    qwen_image_edit="train_lora_qwen_image_edit_32gb.yaml",
    qwen_image_edit_2509="train_lora_qwen_image_edit_2509_32gb.yaml",
    flux_dev="train_lora_flux_24gb.yaml",
    flux_schnell="train_lora_flux_schnell_24gb.yaml",
)

# Preset key -> HuggingFace repos each model pulls (used for pre-warming checks).
MODEL_HF_REPOS = dict(
    wan21_1b=["Wan-AI/Wan2.1-T2V-1.3B-Diffusers"],
    wan21_14b=["Wan-AI/Wan2.1-T2V-14B-Diffusers"],
    wan22_14b=["ai-toolkit/Wan2.2-T2V-A14B-Diffusers-bf16"],
    qwen_image=["Qwen/Qwen-Image"],
    qwen_image_edit=["Qwen/Qwen-Image-Edit"],
    qwen_image_edit_2509=["Qwen/Qwen-Image-Edit"],
    flux_dev=["black-forest-labs/FLUX.1-dev"],
    flux_schnell=["black-forest-labs/FLUX.1-schnell"],
)

# Accuracy Recovery Adapters (small files, safe to pre-download).
ARA_FILES = dict(
    wan22_14b="ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint4.safetensors",
    qwen_image="ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors",
)
92
+
93
+
94
# =============================================================================
# Cleanup Functions
# =============================================================================

def cleanup_gpu_memory():
    """Free as much GPU memory as possible between jobs."""
    logger.info("Cleaning up GPU memory...")

    cuda_ready = torch.cuda.is_available()

    # Drop cached allocations held by PyTorch's allocator and wait for
    # outstanding kernels so the numbers below are accurate.
    if cuda_ready:
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    # Collect unreachable Python objects that may still pin tensors.
    gc.collect()

    # Second pass releases blocks freed by the collector above.
    if cuda_ready:
        torch.cuda.empty_cache()

    logger.info(f"GPU memory after cleanup: {get_gpu_info()}")
115
+
116
+
117
def cleanup_temp_files():
    """Clean up temporary training files.

    Removes generated job configs from ai-toolkit/config (the shipped
    examples live in config/examples and are untouched) and deletes the
    latent / text-encoder cache entries ai-toolkit leaves in the
    workspace dataset and output directories. Best-effort: individual
    failures are logged, never raised.
    """
    logger.info("Cleaning up temporary files...")

    # Clean up generated configs (keep example configs).
    # Guard against a missing config dir so cleanup never aborts on a
    # partial install (os.listdir would raise FileNotFoundError).
    config_dir = os.path.join(AI_TOOLKIT_DIR, "config")
    if os.path.isdir(config_dir):
        for f in os.listdir(config_dir):
            if f.endswith('.yaml') and f.startswith(('lora_', 'test_', 'my_')):
                try:
                    os.remove(os.path.join(config_dir, f))
                    logger.info(f"Removed temp config: {f}")
                except Exception as e:
                    logger.warning(f"Failed to remove {f}: {e}")

    # Clean up latent cache directories in workspace.
    workspace_dirs = ["/workspace/dataset", "/workspace/output"]
    for ws_dir in workspace_dirs:
        if not os.path.exists(ws_dir):
            continue
        for item in os.listdir(ws_dir):
            item_path = os.path.join(ws_dir, item)
            # ai-toolkit cache entries carry these name prefixes
            if item.startswith(('_latent_cache', '_t_e_cache', '.aitk')):
                try:
                    if os.path.isdir(item_path):
                        shutil.rmtree(item_path)
                    else:
                        os.remove(item_path)
                    logger.info(f"Removed cache: {item_path}")
                except Exception as e:
                    logger.warning(f"Failed to remove {item_path}: {e}")
146
+
147
+
148
def cleanup_before_training(new_model: str):
    """Run the appropriate cleanup before training *new_model*.

    Full cleanup (GPU + temp files) when switching presets, light GPU-only
    cleanup when retraining the same preset, nothing extra on the first run.
    Updates the module-level CURRENT_MODEL tracker.
    """
    global CURRENT_MODEL

    previous = CURRENT_MODEL
    if previous and previous != new_model:
        logger.info(f"Switching from {previous} to {new_model} - performing full cleanup")
        cleanup_gpu_memory()
        cleanup_temp_files()
    elif previous == new_model:
        logger.info(f"Same model {new_model} - light cleanup only")
        cleanup_gpu_memory()
    else:
        logger.info(f"First training run with {new_model}")

    CURRENT_MODEL = new_model

    # Final memory check. Only index 'name'/'free_gb' when a GPU is present:
    # get_gpu_info() returns {"available": False} on CPU-only hosts, and the
    # original unconditional lookup raised KeyError there.
    gpu_info = get_gpu_info()
    if gpu_info.get("available"):
        logger.info(f"Ready for training. GPU: {gpu_info['name']}, Free: {gpu_info['free_gb']}GB")
    else:
        logger.info("Ready for training. No CUDA GPU detected.")
167
+
168
+
169
# =============================================================================
# Utility Functions
# =============================================================================

def get_gpu_info():
    """Return basic stats for GPU 0, or {"available": False} without CUDA."""
    if not torch.cuda.is_available():
        return {"available": False}
    device_props = torch.cuda.get_device_properties(0)
    free_bytes, total_bytes = torch.cuda.mem_get_info(0)
    gib = 1024 ** 3
    return {
        "available": True,
        "name": device_props.name,
        "total_gb": round(total_bytes / gib, 2),
        "free_gb": round(free_bytes / gib, 2),
    }
185
+
186
+
187
def get_environment_info():
    """Collect a snapshot of the runtime environment for debugging."""
    cache_exists = os.path.exists(RUNPOD_HF_CACHE) if IS_RUNPOD_CACHE else False
    return {
        "is_runpod_cache": IS_RUNPOD_CACHE,
        "hf_home": os.environ.get("HF_HOME", "not set"),
        "hf_token_set": bool(HF_TOKEN),
        "gpu": get_gpu_info(),
        "ai_toolkit_dir": AI_TOOLKIT_DIR,
        "cache_exists": cache_exists,
    }
197
+
198
+
199
def find_cached_model(hf_repo: str) -> str:
    """
    Find cached model path on RunPod.

    Args:
        hf_repo: HuggingFace repo ID (e.g., 'black-forest-labs/FLUX.1-dev')

    Returns:
        Path to the cached snapshot directory, or the original repo ID if
        the model is not cached (or no RunPod cache volume is mounted).
    """
    if not IS_RUNPOD_CACHE:
        return hf_repo

    # HF hub layout: "Org/Repo" is stored as "models--Org--Repo/snapshots/<rev>"
    cache_name = hf_repo.replace("/", "--")
    snapshots_dir = Path(RUNPOD_HF_CACHE) / f"models--{cache_name}" / "snapshots"

    if snapshots_dir.exists():
        # Several revisions may be cached and iterdir() order is arbitrary,
        # so pick the most recently modified snapshot directory (and skip
        # stray non-directory entries) for a deterministic choice.
        snapshots = [p for p in snapshots_dir.iterdir() if p.is_dir()]
        if snapshots:
            cached_path = str(max(snapshots, key=lambda p: p.stat().st_mtime))
            logger.info(f"Using cached model: {hf_repo} -> {cached_path}")
            return cached_path

    logger.info(f"Model not cached, will download: {hf_repo}")
    return hf_repo
225
+
226
+
227
def check_model_cache_status(model_key: str) -> dict:
    """Report per-repo cache state for the given model preset key."""
    if model_key not in MODEL_HF_REPOS:
        return {"cached": False, "reason": "unknown model"}

    status = {"repos": {}}
    for repo in MODEL_HF_REPOS[model_key]:
        # Same "models--Org--Repo/snapshots" layout probed by find_cached_model.
        repo_dir = Path(RUNPOD_HF_CACHE) / f"models--{repo.replace('/', '--')}" / "snapshots"
        has_snapshot = repo_dir.exists() and any(repo_dir.iterdir())
        status["repos"][repo] = "cached" if has_snapshot else "not cached"

    status["all_cached"] = all(s == "cached" for s in status["repos"].values())
    return status
246
+
247
+
248
# =============================================================================
# Config Loading and Training
# =============================================================================

def load_example_config(model_key):
    """Load the example YAML config bundled with ai-toolkit for *model_key*.

    Raises ValueError for an unknown preset key.
    """
    if model_key not in MODEL_PRESETS:
        raise ValueError(f"Unknown model: {model_key}. Available: {list(MODEL_PRESETS.keys())}")

    config_path = os.path.join(
        AI_TOOLKIT_DIR, "config", "examples", MODEL_PRESETS[model_key]
    )
    with open(config_path, 'r') as f:
        return yaml.safe_load(f)
262
+
263
+
264
def _apply_user_overrides(process: dict, params: dict) -> None:
    """Apply user-supplied training parameters onto the first process entry
    of the example config. Only keys present in *params* are overridden;
    everything else keeps the example's defaults."""
    # Dataset and output locations (always set, with workspace defaults).
    process["datasets"][0]["folder_path"] = params.get("dataset_path", "/workspace/dataset")
    process["training_folder"] = params.get("output_path", "/workspace/output")

    if "steps" in params:
        process["train"]["steps"] = params["steps"]
    if "batch_size" in params:
        process["train"]["batch_size"] = params["batch_size"]
    if "learning_rate" in params:
        process["train"]["lr"] = params["learning_rate"]
    if "lora_rank" in params:
        process["network"]["linear"] = params["lora_rank"]
        # LoRA alpha defaults to the rank when not given explicitly.
        process["network"]["linear_alpha"] = params.get("lora_alpha", params["lora_rank"])
    if "save_every" in params:
        process["save"]["save_every"] = params["save_every"]
    if "sample_every" in params:
        process["sample"]["sample_every"] = params["sample_every"]
    if "resolution" in params:
        process["datasets"][0]["resolution"] = params["resolution"]
    if "num_frames" in params:
        process["datasets"][0]["num_frames"] = params["num_frames"]
    if "sample_prompts" in params:
        process["sample"]["prompts"] = params["sample_prompts"]
    if "trigger_word" in params:
        process["trigger_word"] = params["trigger_word"]


def run_training(params):
    """Run a LoRA training job with ai-toolkit.

    Builds a job config from the preset's example YAML plus user overrides,
    writes it to ai-toolkit/config, launches ai-toolkit's run.py as a
    subprocess, and streams its output into our log.

    Returns a success dict with job name, output path, and model key.
    Raises RuntimeError when the trainer exits non-zero.
    """
    model_key = params.get("model", "wan22_14b")

    # Cleanup before starting new training.
    cleanup_before_training(model_key)

    # Load base config from ai-toolkit examples and apply user params.
    config = load_example_config(model_key)

    job_name = params.get("name", f"lora_{model_key}_{uuid.uuid4().hex[:6]}")
    config["config"]["name"] = job_name

    process = config["config"]["process"][0]
    _apply_user_overrides(process, params)

    # Swap the HF repo ID for a local snapshot path when already cached.
    if IS_RUNPOD_CACHE and "model" in process:
        original_path = process["model"].get("name_or_path", "")
        if original_path:
            cached_path = find_cached_model(original_path)
            if cached_path != original_path:
                process["model"]["name_or_path"] = cached_path
                logger.info(f"Using cached model path: {cached_path}")

    # Save config next to (not inside) the examples dir.
    config_dir = os.path.join(AI_TOOLKIT_DIR, "config")
    config_path = os.path.join(config_dir, f"{job_name}.yaml")
    with open(config_path, 'w') as f:
        yaml.dump(config, f, default_flow_style=False)

    logger.info(f"Config saved: {config_path}")
    logger.info(f"Starting: {job_name}")

    # Run ai-toolkit in its own process, merging stderr into stdout.
    cmd = [sys.executable, os.path.join(AI_TOOLKIT_DIR, "run.py"), config_path]
    logger.info(f"Command: {' '.join(cmd)}")

    proc = subprocess.Popen(
        cmd,
        cwd=AI_TOOLKIT_DIR,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,  # line-buffered so logs stream in real time
    )

    try:
        # Relay trainer output line by line into our logger.
        for line in proc.stdout:
            logger.info(line.rstrip())
        proc.wait()
    finally:
        # Always release GPU memory, even if log streaming raised —
        # previously cleanup was skipped on an exception here.
        cleanup_gpu_memory()

    if proc.returncode != 0:
        raise RuntimeError(f"Training failed with code {proc.returncode}")

    return {
        "status": "success",
        "job_name": job_name,
        "output_path": process["training_folder"],
        "model": model_key,
    }
358
+
359
+
360
# =============================================================================
# Handler
# =============================================================================

def handler(job):
    """RunPod serverless entry point; dispatches on input['action'].

    Supported actions: list_models, status, check_cache, cleanup, train
    (the default). Always returns a dict with a "status" key; exceptions
    are caught and reported as {"status": "error", ...}.
    """
    global CURRENT_MODEL

    job_input = job.get("input", {})
    action = job_input.get("action", "train")

    logger.info(f"Action: {action}, GPU: {get_gpu_info()}")

    try:
        if action == "list_models":
            return {"status": "success", "models": list(MODEL_PRESETS.keys())}

        if action == "status":
            return {
                "status": "success",
                "environment": get_environment_info(),
            }

        if action == "check_cache":
            # Single model when given, otherwise report every preset.
            model_key = job_input.get("model")
            if model_key:
                cache_status = check_model_cache_status(model_key)
            else:
                cache_status = {m: check_model_cache_status(m) for m in MODEL_PRESETS.keys()}
            return {"status": "success", "cache": cache_status}

        if action == "cleanup":
            # Manual cleanup action
            cleanup_gpu_memory()
            cleanup_temp_files()
            CURRENT_MODEL = None
            return {
                "status": "success",
                "message": "Cleanup complete",
                "gpu": get_gpu_info(),
            }

        if action == "train":
            params = job_input.get("params", {})
            # Top-level "model" wins over params["model"]; default preset last.
            params["model"] = job_input.get("model", params.get("model", "wan22_14b"))
            return run_training(params)

        return {"status": "error", "error": f"Unknown action: {action}"}

    except Exception as e:
        logger.error(traceback.format_exc())
        return {"status": "error", "error": str(e)}
412
+
413
+
414
if __name__ == "__main__":
    logger.info("Starting AI-Toolkit RunPod Handler")
    # Log the resolved environment (cache paths, GPU, token presence) at boot.
    logger.info(f"Environment: {get_environment_info()}")
    # Hand control to the RunPod serverless loop; it calls handler() per job.
    runpod.serverless.start({"handler": handler})