| | """ |
| | RunPod Serverless Handler - Wrapper for AI-Toolkit |
| | Does NOT modify ai-toolkit code, only wraps it |
| | |
| | Supports RunPod model caching via HuggingFace integration. |
| | """ |
| |
|
| | import os |
| | import sys |
| | import subprocess |
| | import traceback |
| | import logging |
| | import uuid |
| | from pathlib import Path |
| |
|
| | |
| | |
| | |
| |
|
| | |
import importlib.util  # used below to detect the optional hf_transfer package

# HuggingFace cache locations on a RunPod network volume.
RUNPOD_CACHE_BASE = "/runpod-volume/huggingface-cache"
RUNPOD_HF_CACHE = "/runpod-volume/huggingface-cache/hub"

# True when a network volume is mounted at /runpod-volume; gates the cache
# redirection below and the cached-model lookups later in this module.
IS_RUNPOD_CACHE = os.path.exists("/runpod-volume")

if IS_RUNPOD_CACHE:
    # Redirect all HuggingFace caches to the persistent volume. These are set
    # before the third-party imports further down, and the training
    # subprocess inherits them.
    os.environ["HF_HOME"] = RUNPOD_CACHE_BASE
    os.environ["HUGGINGFACE_HUB_CACHE"] = RUNPOD_HF_CACHE
    # Deprecated by newer transformers releases in favour of HF_HOME; kept so
    # older versions also use the volume.
    os.environ["TRANSFORMERS_CACHE"] = RUNPOD_HF_CACHE
    os.environ["HF_DATASETS_CACHE"] = f"{RUNPOD_CACHE_BASE}/datasets"

# huggingface_hub errors out on download when HF_HUB_ENABLE_HF_TRANSFER=1 but
# the optional hf_transfer package is not installed, so only opt in when it is
# importable. A value already present in the environment is left alone.
if importlib.util.find_spec("hf_transfer") is not None:
    os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
os.environ["DISABLE_TELEMETRY"] = "YES"

# Expose HF_TOKEN under the legacy HUGGING_FACE_HUB_TOKEN name as well.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
if HF_TOKEN:
    os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN

# The ai-toolkit checkout is expected next to this file.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
AI_TOOLKIT_DIR = os.path.join(SCRIPT_DIR, "ai-toolkit")
| |
|
| | import runpod |
| | import torch |
| | import yaml |
| | import gc |
| | import shutil |
| |
|
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Model key of the most recent training run in this worker process. It is None
# before the first run and after the "cleanup" action. cleanup_before_training
# compares against it to decide how much cleanup is needed.
CURRENT_MODEL = None
| |
|
| | |
| | |
| | |
| |
|
| | |
# Model key -> example config filename under ai-toolkit/config/examples/.
MODEL_PRESETS = {
    "wan21_1b": "train_lora_wan21_1b_24gb.yaml",
    "wan21_14b": "train_lora_wan21_14b_24gb.yaml",
    "wan22_14b": "train_lora_wan22_14b_24gb.yaml",
    "qwen_image": "train_lora_qwen_image_24gb.yaml",
    "qwen_image_edit": "train_lora_qwen_image_edit_32gb.yaml",
    "qwen_image_edit_2509": "train_lora_qwen_image_edit_2509_32gb.yaml",
    "flux_dev": "train_lora_flux_24gb.yaml",
    "flux_schnell": "train_lora_flux_schnell_24gb.yaml",
}

# Model key -> HuggingFace repo IDs whose presence in the RunPod HF cache is
# reported by check_model_cache_status (the "check_cache" action).
MODEL_HF_REPOS = {
    "wan21_1b": ["Wan-AI/Wan2.1-T2V-1.3B-Diffusers"],
    "wan21_14b": ["Wan-AI/Wan2.1-T2V-14B-Diffusers"],
    "wan22_14b": ["ai-toolkit/Wan2.2-T2V-A14B-Diffusers-bf16"],
    "qwen_image": ["Qwen/Qwen-Image"],
    "qwen_image_edit": ["Qwen/Qwen-Image-Edit"],
    # The 2509 preset trains on its own release, not the original Edit model.
    "qwen_image_edit_2509": ["Qwen/Qwen-Image-Edit-2509"],
    "flux_dev": ["black-forest-labs/FLUX.1-dev"],
    "flux_schnell": ["black-forest-labs/FLUX.1-schnell"],
}

# Model key -> accuracy-recovery-adapter file ("<repo>/<filename>").
# NOTE(review): not referenced by the visible code; confirm it is still used.
ARA_FILES = {
    "wan22_14b": "ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint4.safetensors",
    "qwen_image": "ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors",
}
| |
|
| |
|
| | |
| | |
| | |
| |
|
def cleanup_gpu_memory():
    """Release GPU memory cached by this process and log the resulting state.

    Runs the garbage collector first so unreachable tensors are actually freed.
    Then, when CUDA is available, it waits for outstanding GPU work and returns
    the caching allocator's unused blocks to the driver. Without CUDA only the
    garbage collection runs.
    """
    logger.info("Cleaning up GPU memory...")

    # Collect first: empty_cache() can only release blocks whose tensors are
    # already gone, so emptying the cache before gc.collect() frees little.
    gc.collect()

    if torch.cuda.is_available():
        # Ensure no queued kernels still use memory before releasing it.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    logger.info(f"GPU memory after cleanup: {get_gpu_info()}")
| |
|
| |
|
def cleanup_temp_files(workspace_dirs=None):
    """Delete leftover generated configs and dataset/output caches.

    Removes ``*.yaml`` files in ``ai-toolkit/config`` whose names start with
    ``lora_``, ``test_`` or ``my_``. Also removes top-level entries starting
    with ``_latent_cache``, ``_t_e_cache`` or ``.aitk`` inside each workspace
    directory.

    Cleanup is best-effort: missing directories are skipped, and each failed
    removal is logged as a warning without aborting the rest.

    Args:
        workspace_dirs: Directories to scan for cache entries. Defaults to
            ``/workspace/dataset`` and ``/workspace/output``.
    """
    logger.info("Cleaning up temporary files...")

    # Configs written by run_training land in this directory; only the
    # prefixes below are treated as temporary.
    config_dir = os.path.join(AI_TOOLKIT_DIR, "config")
    if os.path.isdir(config_dir):  # listdir would raise if it is missing
        for name in os.listdir(config_dir):
            if not (name.endswith('.yaml') and name.startswith(('lora_', 'test_', 'my_'))):
                continue
            try:
                os.remove(os.path.join(config_dir, name))
                logger.info(f"Removed temp config: {name}")
            except OSError as e:
                logger.warning(f"Failed to remove {name}: {e}")

    if workspace_dirs is None:
        workspace_dirs = ["/workspace/dataset", "/workspace/output"]

    for ws_dir in workspace_dirs:
        # isdir (not exists): a plain file at this path would make listdir fail.
        if not os.path.isdir(ws_dir):
            continue
        for item in os.listdir(ws_dir):
            if not item.startswith(('_latent_cache', '_t_e_cache', '.aitk')):
                continue
            item_path = os.path.join(ws_dir, item)
            try:
                # rmtree refuses symlinks, so only recurse into real directories;
                # a symlink (to a file or dir) is unlinked instead.
                if os.path.isdir(item_path) and not os.path.islink(item_path):
                    shutil.rmtree(item_path)
                else:
                    os.remove(item_path)
                logger.info(f"Removed cache: {item_path}")
            except OSError as e:
                logger.warning(f"Failed to remove {item_path}: {e}")
| |
|
| |
|
def cleanup_before_training(new_model: str):
    """Free what the previous run left behind before training ``new_model``.

    Behaviour depends on the previous model:

    - A different model: free GPU memory and delete temporary files.
    - The same model: free GPU memory only.
    - No previous run in this worker: nothing to clean.

    Afterwards ``CURRENT_MODEL`` is set to ``new_model`` and the GPU state is
    logged.
    """
    global CURRENT_MODEL

    if CURRENT_MODEL and CURRENT_MODEL != new_model:
        logger.info(f"Switching from {CURRENT_MODEL} to {new_model} - performing full cleanup")
        cleanup_gpu_memory()
        cleanup_temp_files()
    elif CURRENT_MODEL == new_model:
        logger.info(f"Same model {new_model} - light cleanup only")
        cleanup_gpu_memory()
    else:
        logger.info(f"First training run with {new_model}")

    CURRENT_MODEL = new_model

    gpu_info = get_gpu_info()
    # Without CUDA, get_gpu_info() returns only {"available": False}, so
    # "name"/"free_gb" must not be read unconditionally (previously a KeyError).
    if gpu_info.get("available"):
        logger.info(f"Ready for training. GPU: {gpu_info['name']}, Free: {gpu_info['free_gb']}GB")
    else:
        logger.warning("Ready for training, but no CUDA GPU is available")
| |
|
| |
|
| | |
| | |
| | |
| |
|
def get_gpu_info():
    """Describe CUDA device 0.

    Returns ``{"available": False}`` when CUDA is unavailable. Otherwise it
    returns the device name plus total and free memory in GiB, rounded to two
    decimals.
    """
    if torch.cuda.is_available():
        device_name = torch.cuda.get_device_properties(0).name
        free_bytes, total_bytes = torch.cuda.mem_get_info(0)
        bytes_per_gib = 1024 ** 3
        return {
            "available": True,
            "name": device_name,
            "total_gb": round(total_bytes / bytes_per_gib, 2),
            "free_gb": round(free_bytes / bytes_per_gib, 2),
        }
    return {"available": False}
| |
|
| |
|
def get_environment_info():
    """Return a JSON-serialisable snapshot of runtime settings for debugging.

    It covers cache mode, HF home, whether a token is configured, GPU info, the
    ai-toolkit location, and whether the hub cache directory exists. The last
    is always False when no RunPod volume is mounted.
    """
    return dict(
        is_runpod_cache=IS_RUNPOD_CACHE,
        hf_home=os.environ.get("HF_HOME", "not set"),
        hf_token_set=bool(HF_TOKEN),
        gpu=get_gpu_info(),
        ai_toolkit_dir=AI_TOOLKIT_DIR,
        cache_exists=IS_RUNPOD_CACHE and os.path.exists(RUNPOD_HF_CACHE),
    )
| |
|
| |
|
def find_cached_model(hf_repo: str) -> str:
    """Resolve a HuggingFace repo ID to its snapshot directory in the RunPod cache.

    Uses the hub cache layout under ``RUNPOD_HF_CACHE``:
    ``models--{org}--{name}/refs/main`` and
    ``models--{org}--{name}/snapshots/<revision>/``. The snapshot named in
    ``refs/main`` is preferred. If there is none, the most recently modified
    snapshot directory is used. Previously the first ``iterdir()`` entry was
    taken, and its order is arbitrary.

    Args:
        hf_repo: HuggingFace repo ID (e.g., 'black-forest-labs/FLUX.1-dev')

    Returns:
        Path to the cached snapshot, or ``hf_repo`` unchanged when no RunPod
        volume is mounted or nothing usable is cached.
    """
    if not IS_RUNPOD_CACHE:
        return hf_repo

    # Hub cache naming: "org/name" -> "models--org--name".
    repo_dir = Path(RUNPOD_HF_CACHE) / f"models--{hf_repo.replace('/', '--')}"
    snapshots_dir = repo_dir / "snapshots"

    if snapshots_dir.is_dir():
        snapshot = None

        # refs/main stores the commit hash downloaded for the default branch.
        ref_file = repo_dir / "refs" / "main"
        if ref_file.is_file():
            revision = ref_file.read_text().strip()
            # Guard against an empty ref, which would resolve to snapshots_dir itself.
            if revision and (snapshots_dir / revision).is_dir():
                snapshot = snapshots_dir / revision

        if snapshot is None:
            candidates = [p for p in snapshots_dir.iterdir() if p.is_dir()]
            if candidates:
                snapshot = max(candidates, key=lambda p: p.stat().st_mtime)

        if snapshot is not None:
            cached_path = str(snapshot)
            logger.info(f"Using cached model: {hf_repo} -> {cached_path}")
            return cached_path

    logger.info(f"Model not cached, will download: {hf_repo}")
    return hf_repo
| |
|
| |
|
def check_model_cache_status(model_key: str) -> dict:
    """Report which HuggingFace repos of ``model_key`` are present in the RunPod cache.

    A repo counts as cached when its ``snapshots`` directory exists and is
    non-empty.

    Returns:
        For a known key: ``{"repos": {repo_id: "cached" | "not cached"},
        "all_cached": bool}``.
        For an unknown key: ``{"cached": False, "reason": "unknown model",
        "all_cached": False}``. The ``all_cached`` key is included so clients
        can read it regardless of which branch ran.
    """
    if model_key not in MODEL_HF_REPOS:
        return {"cached": False, "reason": "unknown model", "all_cached": False}

    repo_states = {}
    for repo in MODEL_HF_REPOS[model_key]:
        snapshots_dir = Path(RUNPOD_HF_CACHE) / f"models--{repo.replace('/', '--')}" / "snapshots"
        # is_dir() keeps iterdir() from failing on a stray file.
        # any() stops at the first entry instead of listing everything.
        is_cached = snapshots_dir.is_dir() and any(snapshots_dir.iterdir())
        repo_states[repo] = "cached" if is_cached else "not cached"

    return {
        "repos": repo_states,
        "all_cached": all(s == "cached" for s in repo_states.values()),
    }
| |
|
| |
|
| | |
| | |
| | |
| |
|
def load_example_config(model_key):
    """Load and parse ai-toolkit's example YAML config for ``model_key``.

    Args:
        model_key: A key of ``MODEL_PRESETS``.

    Returns:
        The parsed YAML document (``yaml.safe_load`` result).

    Raises:
        ValueError: ``model_key`` is not a known preset.
        FileNotFoundError: The preset's file is missing from
            ``ai-toolkit/config/examples``.
    """
    if model_key not in MODEL_PRESETS:
        raise ValueError(f"Unknown model: {model_key}. Available: {list(MODEL_PRESETS.keys())}")

    config_path = os.path.join(AI_TOOLKIT_DIR, "config", "examples", MODEL_PRESETS[model_key])

    try:
        # Explicit encoding so parsing does not depend on the container locale.
        with open(config_path, 'r', encoding='utf-8') as f:
            return yaml.safe_load(f)
    except FileNotFoundError as e:
        raise FileNotFoundError(
            f"Example config for model '{model_key}' not found: {config_path}"
        ) from e
| |
|
| |
|
def _safe_job_name(name, model_key):
    """Return ``name`` as a filesystem-safe single path component.

    The name comes from the request and becomes both the config filename and
    ai-toolkit's job name. Characters outside ``[A-Za-z0-9._-]`` are replaced
    with ``_`` and leading dots are stripped, so ``../`` traversal and hidden
    files are impossible. Falls back to ``lora_<model_key>_<6 hex>`` when
    ``name`` is empty/None or nothing usable remains.
    """
    import re

    if name:
        cleaned = re.sub(r"[^A-Za-z0-9._-]+", "_", str(name)).lstrip(".")
        if cleaned:
            return cleaned
    return f"lora_{model_key}_{uuid.uuid4().hex[:6]}"


def _apply_param_overrides(process, params):
    """Copy optional training settings from ``params`` into an ai-toolkit process config.

    Only keys present in ``params`` are applied; everything else keeps the
    example config's value. ``lora_alpha`` defaults to ``lora_rank``.
    """
    if "steps" in params:
        process["train"]["steps"] = params["steps"]
    if "batch_size" in params:
        process["train"]["batch_size"] = params["batch_size"]
    if "learning_rate" in params:
        process["train"]["lr"] = params["learning_rate"]
    if "lora_rank" in params:
        process["network"]["linear"] = params["lora_rank"]
        process["network"]["linear_alpha"] = params.get("lora_alpha", params["lora_rank"])
    if "save_every" in params:
        process["save"]["save_every"] = params["save_every"]
    if "sample_every" in params:
        process["sample"]["sample_every"] = params["sample_every"]
    if "resolution" in params:
        process["datasets"][0]["resolution"] = params["resolution"]
    if "num_frames" in params:
        process["datasets"][0]["num_frames"] = params["num_frames"]
    if "sample_prompts" in params:
        process["sample"]["prompts"] = params["sample_prompts"]
    if "trigger_word" in params:
        process["trigger_word"] = params["trigger_word"]


def _run_ai_toolkit(config_path):
    """Run ``ai-toolkit/run.py`` on ``config_path``, streaming its output to the log.

    Returns the child's exit code.
    """
    cmd = [sys.executable, os.path.join(AI_TOOLKIT_DIR, "run.py"), config_path]
    logger.info(f"Command: {' '.join(cmd)}")

    # The context manager closes the pipe and waits for the child even if
    # streaming is interrupted. errors="replace" keeps an undecodable byte in
    # the child's output from aborting the stream with UnicodeDecodeError.
    with subprocess.Popen(
        cmd,
        cwd=AI_TOOLKIT_DIR,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        encoding="utf-8",
        errors="replace",
        bufsize=1,
    ) as proc:
        for line in proc.stdout:
            logger.info(line.rstrip())
    return proc.returncode


def run_training(params):
    """Train a LoRA with ai-toolkit and return a summary of the run.

    Steps:

    1. Load the example config for ``params["model"]`` (default ``"wan22_14b"``)
       and apply request overrides.
    2. When a RunPod volume is present, point the base model at a cached
       snapshot if one exists.
    3. Write the config to ``ai-toolkit/config/<job_name>.yaml``.
    4. Run ``ai-toolkit/run.py`` on it.

    Args:
        params: Request parameters. Recognised keys: ``model``, ``name``,
            ``dataset_path``, ``output_path``, ``steps``, ``batch_size``,
            ``learning_rate``, ``lora_rank``, ``lora_alpha``, ``save_every``,
            ``sample_every``, ``resolution``, ``num_frames``,
            ``sample_prompts``, ``trigger_word``.

    Returns:
        ``{"status": "success", "job_name": ..., "output_path": ..., "model": ...}``

    Raises:
        ValueError: Unknown model key.
        RuntimeError: ai-toolkit exited with a non-zero code.
    """
    model_key = params.get("model", "wan22_14b")

    # Validate the model (by loading its template) before touching worker
    # state. An unknown key then neither triggers cleanup nor is recorded as
    # CURRENT_MODEL.
    config = load_example_config(model_key)
    cleanup_before_training(model_key)

    job_name = _safe_job_name(params.get("name"), model_key)
    config["config"]["name"] = job_name

    process = config["config"]["process"][0]
    process["datasets"][0]["folder_path"] = params.get("dataset_path", "/workspace/dataset")
    process["training_folder"] = params.get("output_path", "/workspace/output")
    _apply_param_overrides(process, params)

    if IS_RUNPOD_CACHE and "model" in process:
        original_path = process["model"].get("name_or_path", "")
        if original_path:
            cached_path = find_cached_model(original_path)
            if cached_path != original_path:
                process["model"]["name_or_path"] = cached_path
                logger.info(f"Using cached model path: {cached_path}")

    config_dir = os.path.join(AI_TOOLKIT_DIR, "config")
    config_path = os.path.join(config_dir, f"{job_name}.yaml")
    with open(config_path, 'w', encoding='utf-8') as f:
        yaml.dump(config, f, default_flow_style=False)

    logger.info(f"Config saved: {config_path}")
    logger.info(f"Starting: {job_name}")

    returncode = _run_ai_toolkit(config_path)

    cleanup_gpu_memory()

    if returncode != 0:
        raise RuntimeError(f"Training failed with code {returncode}")

    return {
        "status": "success",
        "job_name": job_name,
        "output_path": process["training_folder"],
        "model": model_key,
    }
| |
|
| |
|
| | |
| | |
| | |
| |
|
def handler(job):
    """RunPod serverless entry point.

    Dispatches on ``job["input"]["action"]`` (default ``"train"``):

    - ``list_models``: the available model keys.
    - ``status``: environment/GPU diagnostics.
    - ``check_cache``: cache status for ``input["model"]``, or for every preset
      when no model is given.
    - ``cleanup``: free GPU memory, delete temp files and reset ``CURRENT_MODEL``.
    - ``train``: run training with ``input["params"]``; ``input["model"]``
      takes precedence over ``params["model"]``.

    Returns a dict whose ``"status"`` is ``"success"`` or ``"error"``. Any
    exception, including a malformed payload, is logged and returned as
    ``{"status": "error", "error": str(e)}`` instead of crashing the worker.
    """
    global CURRENT_MODEL

    try:
        # "input" may be absent or explicitly null in the request payload;
        # parsing it inside the try keeps malformed requests from escaping.
        job_input = job.get("input") or {}
        action = job_input.get("action", "train")

        logger.info(f"Action: {action}, GPU: {get_gpu_info()}")

        if action == "list_models":
            return {"status": "success", "models": list(MODEL_PRESETS.keys())}

        if action == "status":
            return {
                "status": "success",
                "environment": get_environment_info(),
            }

        if action == "check_cache":
            model_key = job_input.get("model")
            if model_key:
                cache_status = check_model_cache_status(model_key)
            else:
                cache_status = {m: check_model_cache_status(m) for m in MODEL_PRESETS}
            return {"status": "success", "cache": cache_status}

        if action == "cleanup":
            cleanup_gpu_memory()
            cleanup_temp_files()
            CURRENT_MODEL = None
            return {
                "status": "success",
                "message": "Cleanup complete",
                "gpu": get_gpu_info(),
            }

        if action == "train":
            # Copy so the caller's payload is not mutated; tolerate null params.
            params = dict(job_input.get("params") or {})
            params["model"] = job_input.get("model", params.get("model", "wan22_14b"))
            return run_training(params)

        return {"status": "error", "error": f"Unknown action: {action}"}

    except Exception as e:
        # Top-level boundary: report the failure to the caller with a full
        # traceback in the worker log.
        logger.error(traceback.format_exc())
        return {"status": "error", "error": str(e)}
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: log the startup environment, then hand control to
    # the RunPod serverless loop with `handler` as the job callback.
    logger.info("Starting AI-Toolkit RunPod Handler")
    startup_env = get_environment_info()
    logger.info(f"Environment: {startup_env}")
    runpod.serverless.start({"handler": handler})
| |
|