Spaces:
Running
Running
Add pod generation with FLUX.2, persistent state, training improvements
Browse files- Pod management: network volume support, HTTPS proxy URLs, fire-and-forget SSH
- FLUX.2 workflow: separate UNETLoader + CLIPLoader + VAELoader with Comfy-Org text encoder
- Persist pod state to disk (survives server restarts)
- Training: persistent job tracking in DB, live log streaming
- Remove tracked __pycache__ files, add .gitignore
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- .gitignore +3 -0
- CLAUDE.md +219 -0
- config/models.yaml +20 -14
- src/content_engine/__pycache__/__init__.cpython-311.pyc +0 -0
- src/content_engine/__pycache__/config.cpython-311.pyc +0 -0
- src/content_engine/__pycache__/main.cpython-311.pyc +0 -0
- src/content_engine/api/__pycache__/__init__.cpython-311.pyc +0 -0
- src/content_engine/api/__pycache__/routes_catalog.cpython-311.pyc +0 -0
- src/content_engine/api/__pycache__/routes_generation.cpython-311.pyc +0 -0
- src/content_engine/api/__pycache__/routes_pod.cpython-311.pyc +0 -0
- src/content_engine/api/__pycache__/routes_system.cpython-311.pyc +0 -0
- src/content_engine/api/__pycache__/routes_training.cpython-311.pyc +0 -0
- src/content_engine/api/__pycache__/routes_ui.cpython-311.pyc +0 -0
- src/content_engine/api/__pycache__/routes_video.cpython-311.pyc +0 -0
- src/content_engine/api/routes_pod.py +441 -112
- src/content_engine/api/routes_training.py +51 -21
- src/content_engine/api/ui.html +156 -20
- src/content_engine/models/__pycache__/__init__.cpython-311.pyc +0 -0
- src/content_engine/models/__pycache__/database.cpython-311.pyc +0 -0
- src/content_engine/models/__pycache__/schemas.cpython-311.pyc +0 -0
- src/content_engine/models/database.py +29 -7
- src/content_engine/services/__pycache__/__init__.cpython-311.pyc +0 -0
- src/content_engine/services/__pycache__/catalog.cpython-311.pyc +0 -0
- src/content_engine/services/__pycache__/comfyui_client.cpython-311.pyc +0 -0
- src/content_engine/services/__pycache__/lora_trainer.cpython-311.pyc +0 -0
- src/content_engine/services/__pycache__/runpod_trainer.cpython-311.pyc +0 -0
- src/content_engine/services/__pycache__/template_engine.cpython-311.pyc +0 -0
- src/content_engine/services/__pycache__/variation_engine.cpython-311.pyc +0 -0
- src/content_engine/services/__pycache__/workflow_builder.cpython-311.pyc +0 -0
- src/content_engine/services/runpod_trainer.py +494 -68
- src/content_engine/workers/__pycache__/__init__.cpython-311.pyc +0 -0
- src/content_engine/workers/__pycache__/local_worker.cpython-311.pyc +0 -0
.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
pod_state.json
|
CLAUDE.md
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Content Engine
|
| 2 |
+
|
| 3 |
+
Automated AI content generation system with cloud APIs, LoRA training, and multi-backend support.
|
| 4 |
+
|
| 5 |
+
## Repositories
|
| 6 |
+
|
| 7 |
+
- **Local Development**: `D:\AI automation\content_engine\`
|
| 8 |
+
- **HuggingFace Deployment**: `D:\AI automation\content-engine\` (deployed to https://huggingface.co/spaces/dippoo/content-engine)
|
| 9 |
+
|
| 10 |
+
Always sync changes between both directories when modifying code.
|
| 11 |
+
|
| 12 |
+
## Architecture
|
| 13 |
+
|
| 14 |
+
```
|
| 15 |
+
┌─────────────────────────────────────────────────────────────┐
|
| 16 |
+
│ Frontend (ui.html) │
|
| 17 |
+
│ Generate | Batch | Gallery | Train LoRA | Status | Settings │
|
| 18 |
+
└─────────────────────────────────────────────────────────────┘
|
| 19 |
+
│
|
| 20 |
+
▼
|
| 21 |
+
┌─────────────────────────────────────────────────────────────┐
|
| 22 |
+
│ FastAPI Backend (main.py) │
|
| 23 |
+
│ routes_generation | routes_video | routes_training | etc │
|
| 24 |
+
└─────────────────────────────────────────────────────────────┘
|
| 25 |
+
│
|
| 26 |
+
┌─────────────────────┼─────────────────────┐
|
| 27 |
+
▼ ▼ ▼
|
| 28 |
+
┌───────────────┐ ┌───────────────┐ ┌───────────────┐
|
| 29 |
+
│ Local GPU │ │ RunPod │ │ Cloud APIs │
|
| 30 |
+
│ (ComfyUI) │ │ (Serverless) │ │ (WaveSpeed) │
|
| 31 |
+
└───────────────┘ └───────────────┘ └───────────────┘
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
## Cloud Providers
|
| 35 |
+
|
| 36 |
+
### WaveSpeed (wavespeed_provider.py)
|
| 37 |
+
Primary cloud API for image/video generation. Uses direct HTTP API (SDK optional).
|
| 38 |
+
|
| 39 |
+
**Text-to-Image Models:**
|
| 40 |
+
- `seedream-4.5` - Best quality, NSFW OK (ByteDance)
|
| 41 |
+
- `seedream-4`, `seedream-3.1` - NSFW friendly
|
| 42 |
+
- `gpt-image-1.5`, `gpt-image-1-mini` - OpenAI models
|
| 43 |
+
- `nano-banana-pro`, `nano-banana` - Google models
|
| 44 |
+
- `wan-2.6`, `wan-2.5` - Alibaba models
|
| 45 |
+
- `kling-image-o3` - Kuaishou
|
| 46 |
+
|
| 47 |
+
**Image-to-Image (Edit) Models:**
|
| 48 |
+
- `seedream-4.5-edit` - Best for face preservation
|
| 49 |
+
- `seedream-4.5-multi`, `seedream-4-multi` - Multi-reference (up to 3 images)
|
| 50 |
+
- `kling-o1-multi` - Multi-reference (up to 10 images)
|
| 51 |
+
- `wan-2.6-edit`, `wan-2.5-edit` - NSFW friendly
|
| 52 |
+
|
| 53 |
+
**Image-to-Video Models:**
|
| 54 |
+
- `wan-2.6-i2v-pro` - Best quality ($0.05/s)
|
| 55 |
+
- `wan-2.6-i2v-flash` - Fast
|
| 56 |
+
- `kling-o3-pro`, `kling-o3` - Kuaishou
|
| 57 |
+
- `higgsfield-dop` - Cinematic 5s clips
|
| 58 |
+
- `veo-3.1`, `sora-2` - Premium models
|
| 59 |
+
|
| 60 |
+
**API Pattern:**
|
| 61 |
+
```python
|
| 62 |
+
# WaveSpeed returns async jobs - must poll for result
|
| 63 |
+
response = {"data": {"outputs": [], "urls": {"get": "poll_url"}}}
|
| 64 |
+
# Poll urls.get until outputs[] is populated
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
### RunPod
|
| 68 |
+
- **Training**: Cloud GPU for LoRA training (runpod_trainer.py)
|
| 69 |
+
- **Generation**: Serverless endpoint for inference (runpod_provider.py)
|
| 70 |
+
|
| 71 |
+
## Character System
|
| 72 |
+
|
| 73 |
+
Characters link a trained LoRA to generation:
|
| 74 |
+
|
| 75 |
+
**Config file** (`config/characters/alice.yaml`):
|
| 76 |
+
```yaml
|
| 77 |
+
id: alice
|
| 78 |
+
name: "Alice"
|
| 79 |
+
trigger_word: "alicechar" # Activates the LoRA
|
| 80 |
+
lora_filename: "alice_v1.safetensors" # In D:\ComfyUI\Models\Lora\
|
| 81 |
+
lora_strength: 0.85
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
**Generation flow:**
|
| 85 |
+
1. User selects character from dropdown
|
| 86 |
+
2. System prepends trigger word: `"alicechar, a woman in red dress"`
|
| 87 |
+
3. LoRA is loaded into workflow (local/RunPod only)
|
| 88 |
+
4. Character identity is preserved in output
|
| 89 |
+
|
| 90 |
+
**For cloud-only (no local GPU):**
|
| 91 |
+
- Use img2img with reference photo
|
| 92 |
+
- Or deploy LoRA to RunPod serverless endpoint
|
| 93 |
+
|
| 94 |
+
## Templates
|
| 95 |
+
|
| 96 |
+
Prompt recipes with variables (`config/templates/*.yaml`):
|
| 97 |
+
```yaml
|
| 98 |
+
id: portrait_glamour
|
| 99 |
+
name: "Glamour Portrait"
|
| 100 |
+
positive: "{{character}}, {{pose}}, {{lighting}}, professional photo"
|
| 101 |
+
variables:
|
| 102 |
+
- name: pose
|
| 103 |
+
options: ["standing", "sitting", "leaning"]
|
| 104 |
+
- name: lighting
|
| 105 |
+
options: ["studio", "natural", "dramatic"]
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
## Key Files
|
| 109 |
+
|
| 110 |
+
### API Routes
|
| 111 |
+
- `routes_generation.py` - txt2img, img2img endpoints
|
| 112 |
+
- `routes_video.py` - img2video, WaveSpeed/Higgsfield video
|
| 113 |
+
- `routes_training.py` - LoRA training jobs
|
| 114 |
+
- `routes_catalog.py` - Gallery/image management
|
| 115 |
+
- `routes_system.py` - Health checks, character list
|
| 116 |
+
|
| 117 |
+
### Services
|
| 118 |
+
- `wavespeed_provider.py` - WaveSpeed API client (SDK optional, uses httpx)
|
| 119 |
+
- `runpod_trainer.py` - Cloud LoRA training
|
| 120 |
+
- `runpod_provider.py` - Cloud generation endpoint
|
| 121 |
+
- `comfyui_client.py` - Local ComfyUI integration
|
| 122 |
+
- `workflow_builder.py` - ComfyUI workflow JSON builder
|
| 123 |
+
- `template_engine.py` - Prompt template rendering
|
| 124 |
+
- `variation_engine.py` - Batch variation generation
|
| 125 |
+
|
| 126 |
+
### Frontend
|
| 127 |
+
- `ui.html` - Single-page app with all UI
|
| 128 |
+
|
| 129 |
+
## Environment Variables
|
| 130 |
+
|
| 131 |
+
```env
|
| 132 |
+
# Cloud APIs
|
| 133 |
+
WAVESPEED_API_KEY=ws_xxx # WaveSpeed.ai API key
|
| 134 |
+
RUNPOD_API_KEY=xxx # RunPod API key
|
| 135 |
+
RUNPOD_ENDPOINT_ID=xxx # RunPod serverless endpoint (for generation)
|
| 136 |
+
|
| 137 |
+
# Optional
|
| 138 |
+
HIGGSFIELD_API_KEY=xxx # Higgsfield (Kling 3.0, etc.)
|
| 139 |
+
COMFYUI_URL=http://127.0.0.1:8188 # Local ComfyUI
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
## Database
|
| 143 |
+
|
| 144 |
+
SQLite with async (aiosqlite):
|
| 145 |
+
- `images` - Generated image catalog
|
| 146 |
+
- `characters` - Character profiles
|
| 147 |
+
- `generation_jobs` - Job tracking
|
| 148 |
+
- `scheduled_posts` - Publishing queue
|
| 149 |
+
|
| 150 |
+
## UI Structure
|
| 151 |
+
|
| 152 |
+
**Generate Page:**
|
| 153 |
+
- Mode chips: Text to Image | Image to Image | Image to Video
|
| 154 |
+
- Backend chips: Local GPU | RunPod GPU | Cloud API
|
| 155 |
+
- Model dropdowns (conditional on mode/backend)
|
| 156 |
+
- Character/Template selectors (2-column grid)
|
| 157 |
+
- Prompt textareas
|
| 158 |
+
- Output settings (aspect ratio, seed)
|
| 159 |
+
|
| 160 |
+
**Controls Panel:** 340px width, compact styling
|
| 161 |
+
**Drop Zones:** For reference images (character + pose)
|
| 162 |
+
|
| 163 |
+
## Common Issues
|
| 164 |
+
|
| 165 |
+
### "Product not found" from WaveSpeed
|
| 166 |
+
Model ID doesn't exist. Check `MODEL_MAP`, `EDIT_MODEL_MAP`, `VIDEO_MODEL_MAP` in wavespeed_provider.py against https://wavespeed.ai/models
|
| 167 |
+
|
| 168 |
+
### "No image URL in output"
|
| 169 |
+
WaveSpeed returned async job. Check `outputs` is empty and `urls.get` exists, then poll that URL.
|
| 170 |
+
|
| 171 |
+
### HuggingFace Space startup hang
|
| 172 |
+
Check requirements.txt for missing packages. Common: `python-dotenv`, `runpod`, `wavespeed` (optional).
|
| 173 |
+
|
| 174 |
+
### Import errors on HF Spaces
|
| 175 |
+
Make optional imports with try/except:
|
| 176 |
+
```python
|
| 177 |
+
try:
|
| 178 |
+
from wavespeed import Client
|
| 179 |
+
SDK_AVAILABLE = True
|
| 180 |
+
except ImportError:
|
| 181 |
+
SDK_AVAILABLE = False
|
| 182 |
+
```
|
| 183 |
+
|
| 184 |
+
## Development Commands
|
| 185 |
+
|
| 186 |
+
```bash
|
| 187 |
+
# Run locally
|
| 188 |
+
cd content_engine
|
| 189 |
+
python -m uvicorn content_engine.main:app --port 8000 --reload
|
| 190 |
+
|
| 191 |
+
# Push to HuggingFace
|
| 192 |
+
cd content-engine
|
| 193 |
+
git add . && git commit -m "message" && git push origin main
|
| 194 |
+
|
| 195 |
+
# Sync local ↔ HF
|
| 196 |
+
cp content_engine/src/content_engine/file.py content-engine/src/content_engine/file.py
|
| 197 |
+
```
|
| 198 |
+
|
| 199 |
+
## Multi-Reference Image Support
|
| 200 |
+
|
| 201 |
+
For img2img with 2 reference images (character + pose):
|
| 202 |
+
|
| 203 |
+
1. **UI**: Two drop zones side-by-side
|
| 204 |
+
2. **API**: `image` (required) + `image2` (optional) in FormData
|
| 205 |
+
3. **Backend**: Both uploaded to temp URLs, sent to WaveSpeed
|
| 206 |
+
4. **Models**: SeeDream Sequential, Kling O1 support multi-ref
|
| 207 |
+
|
| 208 |
+
## Pricing Notes
|
| 209 |
+
|
| 210 |
+
- **WaveSpeed**: ~$0.003-0.01 per image, $0.01-0.05/s for video
|
| 211 |
+
- **RunPod**: ~$0.0002/s for GPU time (training/generation)
|
| 212 |
+
- Cloud API cheaper for light use; RunPod better for volume
|
| 213 |
+
|
| 214 |
+
## Future Improvements
|
| 215 |
+
|
| 216 |
+
- [ ] RunPod serverless endpoint for LoRA-based generation
|
| 217 |
+
- [ ] Auto-captioning for training images
|
| 218 |
+
- [ ] Batch video generation
|
| 219 |
+
- [ ] Publishing integrations (social media APIs)
|
config/models.yaml
CHANGED
|
@@ -5,25 +5,31 @@ training_models:
|
|
| 5 |
# FLUX - Best for photorealistic images (recommended for realistic person)
|
| 6 |
flux2_dev:
|
| 7 |
name: "FLUX.2 Dev (Recommended)"
|
| 8 |
-
description: "Latest FLUX model, 32B params, best quality for realistic person.
|
| 9 |
hf_repo: "black-forest-labs/FLUX.2-dev"
|
| 10 |
-
hf_filename: "
|
| 11 |
-
model_type: "
|
|
|
|
| 12 |
resolution: 1024
|
| 13 |
-
learning_rate:
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
max_train_steps: 1200
|
| 22 |
fp8_base: true
|
|
|
|
| 23 |
use_case: "images"
|
| 24 |
-
vram_required_gb:
|
|
|
|
| 25 |
recommended_images: "15-30 high quality photos with detailed captions"
|
| 26 |
-
training_script: "
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
flux1_dev:
|
| 29 |
name: "FLUX.1 Dev"
|
|
|
|
| 5 |
# FLUX - Best for photorealistic images (recommended for realistic person)
|
| 6 |
flux2_dev:
|
| 7 |
name: "FLUX.2 Dev (Recommended)"
|
| 8 |
+
description: "Latest FLUX model, 32B params, best quality for realistic person. Uses Mistral text encoder."
|
| 9 |
hf_repo: "black-forest-labs/FLUX.2-dev"
|
| 10 |
+
hf_filename: "flux2-dev.safetensors"
|
| 11 |
+
model_type: "flux2"
|
| 12 |
+
training_framework: "musubi-tuner"
|
| 13 |
resolution: 1024
|
| 14 |
+
learning_rate: 1.0
|
| 15 |
+
network_rank: 64
|
| 16 |
+
network_alpha: 32
|
| 17 |
+
optimizer: "prodigy"
|
| 18 |
+
lr_scheduler: "constant"
|
| 19 |
+
timestep_sampling: "flux2_shift"
|
| 20 |
+
network_module: "networks.lora_flux_2"
|
| 21 |
+
max_train_steps: 50
|
|
|
|
| 22 |
fp8_base: true
|
| 23 |
+
gradient_checkpointing: true
|
| 24 |
use_case: "images"
|
| 25 |
+
vram_required_gb: 48
|
| 26 |
+
recommended_gpu: "NVIDIA RTX A6000"
|
| 27 |
recommended_images: "15-30 high quality photos with detailed captions"
|
| 28 |
+
training_script: "flux_2_train_network.py"
|
| 29 |
+
# Model paths on network volume:
|
| 30 |
+
# DiT: /workspace/models/FLUX.2-dev/flux2-dev.safetensors
|
| 31 |
+
# VAE: /workspace/models/FLUX.2-dev/vae/diffusion_pytorch_model.safetensors
|
| 32 |
+
# Text encoder: /workspace/models/FLUX.2-dev/text_encoder/model-00001-of-00010.safetensors
|
| 33 |
|
| 34 |
flux1_dev:
|
| 35 |
name: "FLUX.1 Dev"
|
src/content_engine/__pycache__/__init__.cpython-311.pyc
DELETED
|
Binary file (279 Bytes)
|
|
|
src/content_engine/__pycache__/config.cpython-311.pyc
DELETED
|
Binary file (6.33 kB)
|
|
|
src/content_engine/__pycache__/main.cpython-311.pyc
DELETED
|
Binary file (10.6 kB)
|
|
|
src/content_engine/api/__pycache__/__init__.cpython-311.pyc
DELETED
|
Binary file (211 Bytes)
|
|
|
src/content_engine/api/__pycache__/routes_catalog.cpython-311.pyc
DELETED
|
Binary file (7.2 kB)
|
|
|
src/content_engine/api/__pycache__/routes_generation.cpython-311.pyc
DELETED
|
Binary file (28 kB)
|
|
|
src/content_engine/api/__pycache__/routes_pod.cpython-311.pyc
DELETED
|
Binary file (25.1 kB)
|
|
|
src/content_engine/api/__pycache__/routes_system.cpython-311.pyc
DELETED
|
Binary file (10.7 kB)
|
|
|
src/content_engine/api/__pycache__/routes_training.cpython-311.pyc
DELETED
|
Binary file (12.6 kB)
|
|
|
src/content_engine/api/__pycache__/routes_ui.cpython-311.pyc
DELETED
|
Binary file (1.23 kB)
|
|
|
src/content_engine/api/__pycache__/routes_video.cpython-311.pyc
DELETED
|
Binary file (13.3 kB)
|
|
|
src/content_engine/api/routes_pod.py
CHANGED
|
@@ -1,12 +1,18 @@
|
|
| 1 |
-
"""RunPod Pod management routes — start/stop GPU pods for generation
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import asyncio
|
|
|
|
| 6 |
import logging
|
| 7 |
import os
|
| 8 |
import time
|
| 9 |
import uuid
|
|
|
|
| 10 |
from typing import Any
|
| 11 |
|
| 12 |
import runpod
|
|
@@ -17,28 +23,85 @@ logger = logging.getLogger(__name__)
|
|
| 17 |
|
| 18 |
router = APIRouter(prefix="/api/pod", tags=["pod"])
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
# Pod state
|
| 21 |
_pod_state = {
|
| 22 |
"pod_id": None,
|
| 23 |
-
"status": "stopped", # stopped, starting, running, stopping
|
| 24 |
"ip": None,
|
| 25 |
-
"
|
| 26 |
-
"
|
|
|
|
|
|
|
| 27 |
"started_at": None,
|
| 28 |
-
"cost_per_hour": 0.
|
|
|
|
| 29 |
}
|
| 30 |
|
| 31 |
-
|
| 32 |
-
COMFYUI_IMAGE = "timpietruskyblibla/runpod-worker-comfy:3.4.0-flux1-dev"
|
| 33 |
|
| 34 |
-
# GPU options
|
| 35 |
GPU_OPTIONS = {
|
| 36 |
-
"NVIDIA
|
| 37 |
-
"NVIDIA RTX A6000": {"name": "RTX A6000", "vram": 48, "cost": 0.76},
|
| 38 |
-
"NVIDIA
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
}
|
| 40 |
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
def _get_api_key() -> str:
|
| 43 |
key = os.environ.get("RUNPOD_API_KEY")
|
| 44 |
if not key:
|
|
@@ -48,7 +111,8 @@ def _get_api_key() -> str:
|
|
| 48 |
|
| 49 |
|
| 50 |
class StartPodRequest(BaseModel):
|
| 51 |
-
gpu_type: str = "NVIDIA
|
|
|
|
| 52 |
|
| 53 |
|
| 54 |
class PodStatus(BaseModel):
|
|
@@ -57,7 +121,9 @@ class PodStatus(BaseModel):
|
|
| 57 |
ip: str | None = None
|
| 58 |
port: int | None = None
|
| 59 |
gpu_type: str | None = None
|
|
|
|
| 60 |
cost_per_hour: float | None = None
|
|
|
|
| 61 |
uptime_minutes: float | None = None
|
| 62 |
comfyui_url: str | None = None
|
| 63 |
|
|
@@ -67,44 +133,52 @@ async def get_pod_status():
|
|
| 67 |
"""Get current pod status."""
|
| 68 |
_get_api_key()
|
| 69 |
|
| 70 |
-
# If we have a pod_id, check its actual status
|
| 71 |
if _pod_state["pod_id"]:
|
| 72 |
try:
|
| 73 |
-
pod =
|
|
|
|
|
|
|
|
|
|
| 74 |
if pod:
|
| 75 |
desired = pod.get("desiredStatus", "")
|
| 76 |
if desired == "RUNNING":
|
| 77 |
-
runtime = pod.get("runtime"
|
| 78 |
-
ports = runtime.get("ports"
|
| 79 |
for p in ports:
|
|
|
|
|
|
|
|
|
|
| 80 |
if p.get("privatePort") == 8188:
|
| 81 |
-
_pod_state["
|
| 82 |
-
_pod_state["
|
| 83 |
-
|
|
|
|
| 84 |
elif desired == "EXITED":
|
| 85 |
_pod_state["status"] = "stopped"
|
| 86 |
_pod_state["pod_id"] = None
|
| 87 |
else:
|
| 88 |
_pod_state["status"] = "stopped"
|
| 89 |
_pod_state["pod_id"] = None
|
|
|
|
|
|
|
| 90 |
except Exception as e:
|
| 91 |
logger.warning("Failed to check pod: %s", e)
|
| 92 |
|
| 93 |
uptime = None
|
| 94 |
-
if _pod_state["started_at"] and _pod_state["status"]
|
| 95 |
uptime = (time.time() - _pod_state["started_at"]) / 60
|
| 96 |
|
| 97 |
-
comfyui_url =
|
| 98 |
-
if _pod_state["ip"] and _pod_state["port"]:
|
| 99 |
-
comfyui_url = f"http://{_pod_state['ip']}:{_pod_state['port']}"
|
| 100 |
|
| 101 |
return PodStatus(
|
| 102 |
status=_pod_state["status"],
|
| 103 |
pod_id=_pod_state["pod_id"],
|
| 104 |
ip=_pod_state["ip"],
|
| 105 |
-
port=_pod_state
|
| 106 |
gpu_type=_pod_state["gpu_type"],
|
|
|
|
| 107 |
cost_per_hour=_pod_state["cost_per_hour"],
|
|
|
|
| 108 |
uptime_minutes=uptime,
|
| 109 |
comfyui_url=comfyui_url,
|
| 110 |
)
|
|
@@ -116,12 +190,24 @@ async def list_gpu_options():
|
|
| 116 |
return {"gpus": GPU_OPTIONS}
|
| 117 |
|
| 118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
@router.post("/start")
|
| 120 |
async def start_pod(request: StartPodRequest):
|
| 121 |
-
"""Start a GPU pod for generation
|
| 122 |
_get_api_key()
|
| 123 |
|
| 124 |
-
if _pod_state["status"]
|
| 125 |
return {"status": "already_running", "pod_id": _pod_state["pod_id"]}
|
| 126 |
|
| 127 |
if _pod_state["status"] == "starting":
|
|
@@ -134,74 +220,284 @@ async def start_pod(request: StartPodRequest):
|
|
| 134 |
_pod_state["status"] = "starting"
|
| 135 |
_pod_state["gpu_type"] = request.gpu_type
|
| 136 |
_pod_state["cost_per_hour"] = gpu_info["cost"]
|
|
|
|
|
|
|
| 137 |
|
| 138 |
try:
|
| 139 |
-
logger.info("Starting RunPod with %s...", request.gpu_type)
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
)
|
| 153 |
|
| 154 |
_pod_state["pod_id"] = pod["id"]
|
| 155 |
_pod_state["started_at"] = time.time()
|
|
|
|
| 156 |
|
| 157 |
logger.info("Pod created: %s", pod["id"])
|
| 158 |
|
| 159 |
-
|
| 160 |
-
asyncio.create_task(_wait_for_pod_ready(pod["id"]))
|
| 161 |
|
| 162 |
return {
|
| 163 |
"status": "starting",
|
| 164 |
"pod_id": pod["id"],
|
| 165 |
-
"message": f"Starting {gpu_info['name']} pod (~
|
| 166 |
}
|
| 167 |
|
| 168 |
except Exception as e:
|
| 169 |
_pod_state["status"] = "stopped"
|
|
|
|
| 170 |
logger.error("Failed to start pod: %s", e)
|
| 171 |
raise HTTPException(500, f"Failed to start pod: {e}")
|
| 172 |
|
| 173 |
|
| 174 |
-
async def
|
| 175 |
-
"""
|
| 176 |
start = time.time()
|
|
|
|
|
|
|
| 177 |
|
|
|
|
|
|
|
| 178 |
while time.time() - start < timeout:
|
| 179 |
try:
|
| 180 |
-
pod = runpod.get_pod
|
| 181 |
-
|
| 182 |
if pod and pod.get("desiredStatus") == "RUNNING":
|
| 183 |
-
runtime = pod.get("runtime"
|
| 184 |
-
ports = runtime.get("ports"
|
| 185 |
-
|
| 186 |
for p in ports:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
if p.get("privatePort") == 8188:
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
_pod_state["ip"] = ip
|
| 193 |
-
_pod_state["port"] = int(port)
|
| 194 |
-
_pod_state["status"] = "running"
|
| 195 |
-
logger.info("Pod ready at %s:%s", ip, port)
|
| 196 |
-
return
|
| 197 |
-
|
| 198 |
except Exception as e:
|
| 199 |
logger.debug("Waiting for pod: %s", e)
|
| 200 |
-
|
| 201 |
await asyncio.sleep(5)
|
| 202 |
|
| 203 |
-
|
| 204 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
|
| 207 |
@router.post("/stop")
|
|
@@ -221,20 +517,23 @@ async def stop_pod():
|
|
| 221 |
pod_id = _pod_state["pod_id"]
|
| 222 |
logger.info("Stopping pod: %s", pod_id)
|
| 223 |
|
| 224 |
-
runpod.terminate_pod
|
| 225 |
|
| 226 |
_pod_state["pod_id"] = None
|
| 227 |
_pod_state["ip"] = None
|
| 228 |
-
_pod_state["
|
|
|
|
| 229 |
_pod_state["status"] = "stopped"
|
| 230 |
_pod_state["started_at"] = None
|
|
|
|
|
|
|
| 231 |
|
| 232 |
logger.info("Pod stopped")
|
| 233 |
return {"status": "stopped", "message": "Pod terminated"}
|
| 234 |
|
| 235 |
except Exception as e:
|
| 236 |
logger.error("Failed to stop pod: %s", e)
|
| 237 |
-
_pod_state["status"] = "running"
|
| 238 |
raise HTTPException(500, f"Failed to stop pod: {e}")
|
| 239 |
|
| 240 |
|
|
@@ -244,10 +543,11 @@ async def list_pod_loras():
|
|
| 244 |
if _pod_state["status"] != "running" or not _pod_state["ip"]:
|
| 245 |
return {"loras": [], "message": "Pod not running"}
|
| 246 |
|
|
|
|
| 247 |
try:
|
| 248 |
import httpx
|
| 249 |
async with httpx.AsyncClient(timeout=30) as client:
|
| 250 |
-
url = f"
|
| 251 |
resp = await client.get(url)
|
| 252 |
if resp.status_code == 200:
|
| 253 |
data = resp.json()
|
|
@@ -256,15 +556,12 @@ async def list_pod_loras():
|
|
| 256 |
except Exception as e:
|
| 257 |
logger.warning("Failed to list pod LoRAs: %s", e)
|
| 258 |
|
| 259 |
-
return {"loras": [], "comfyui_url":
|
| 260 |
|
| 261 |
|
| 262 |
@router.post("/upload-lora")
|
| 263 |
-
async def upload_lora_to_pod(
|
| 264 |
-
file: UploadFile = File(...),
|
| 265 |
-
):
|
| 266 |
"""Upload a LoRA file to the running pod."""
|
| 267 |
-
from fastapi import UploadFile, File
|
| 268 |
import httpx
|
| 269 |
|
| 270 |
if _pod_state["status"] != "running":
|
|
@@ -275,13 +572,12 @@ async def upload_lora_to_pod(
|
|
| 275 |
|
| 276 |
try:
|
| 277 |
content = await file.read()
|
|
|
|
| 278 |
|
| 279 |
async with httpx.AsyncClient(timeout=120) as client:
|
| 280 |
-
|
| 281 |
-
url = f"http://{_pod_state['ip']}:{_pod_state['port']}/upload/image"
|
| 282 |
files = {"image": (file.filename, content, "application/octet-stream")}
|
| 283 |
data = {"subfolder": "loras", "type": "input"}
|
| 284 |
-
|
| 285 |
resp = await client.post(url, files=files, data=data)
|
| 286 |
|
| 287 |
if resp.status_code == 200:
|
|
@@ -326,7 +622,7 @@ async def generate_on_pod(request: PodGenerateRequest):
|
|
| 326 |
job_id = str(uuid.uuid4())[:8]
|
| 327 |
seed = request.seed if request.seed >= 0 else random.randint(0, 2**32 - 1)
|
| 328 |
|
| 329 |
-
|
| 330 |
workflow = _build_flux_workflow(
|
| 331 |
prompt=request.prompt,
|
| 332 |
negative_prompt=request.negative_prompt,
|
|
@@ -337,12 +633,14 @@ async def generate_on_pod(request: PodGenerateRequest):
|
|
| 337 |
seed=seed,
|
| 338 |
lora_name=request.lora_name,
|
| 339 |
lora_strength=request.lora_strength,
|
|
|
|
| 340 |
)
|
| 341 |
|
|
|
|
|
|
|
| 342 |
try:
|
| 343 |
async with httpx.AsyncClient(timeout=30) as client:
|
| 344 |
-
|
| 345 |
-
resp = await client.post(url, json={"prompt": workflow})
|
| 346 |
resp.raise_for_status()
|
| 347 |
|
| 348 |
data = resp.json()
|
|
@@ -356,15 +654,9 @@ async def generate_on_pod(request: PodGenerateRequest):
|
|
| 356 |
}
|
| 357 |
|
| 358 |
logger.info("Pod generation started: %s -> %s", job_id, prompt_id)
|
| 359 |
-
|
| 360 |
-
# Start background task to poll for completion
|
| 361 |
asyncio.create_task(_poll_pod_job(job_id, prompt_id, request.content_rating))
|
| 362 |
|
| 363 |
-
return {
|
| 364 |
-
"job_id": job_id,
|
| 365 |
-
"status": "running",
|
| 366 |
-
"seed": seed,
|
| 367 |
-
}
|
| 368 |
|
| 369 |
except Exception as e:
|
| 370 |
logger.error("Pod generation failed: %s", e)
|
|
@@ -374,38 +666,33 @@ async def generate_on_pod(request: PodGenerateRequest):
|
|
| 374 |
async def _poll_pod_job(job_id: str, prompt_id: str, content_rating: str):
|
| 375 |
"""Poll ComfyUI for job completion and save the result."""
|
| 376 |
import httpx
|
| 377 |
-
from pathlib import Path
|
| 378 |
|
| 379 |
start = time.time()
|
| 380 |
-
timeout =
|
|
|
|
| 381 |
|
| 382 |
async with httpx.AsyncClient(timeout=60) as client:
|
| 383 |
while time.time() - start < timeout:
|
| 384 |
try:
|
| 385 |
-
|
| 386 |
-
resp = await client.get(url)
|
| 387 |
|
| 388 |
if resp.status_code == 200:
|
| 389 |
data = resp.json()
|
| 390 |
if prompt_id in data:
|
| 391 |
outputs = data[prompt_id].get("outputs", {})
|
| 392 |
|
| 393 |
-
# Find SaveImage output
|
| 394 |
for node_id, node_output in outputs.items():
|
| 395 |
if "images" in node_output:
|
| 396 |
image_info = node_output["images"][0]
|
| 397 |
filename = image_info["filename"]
|
| 398 |
subfolder = image_info.get("subfolder", "")
|
| 399 |
|
| 400 |
-
# Download the image
|
| 401 |
-
img_url = f"http://{_pod_state['ip']}:{_pod_state['port']}/view"
|
| 402 |
params = {"filename": filename}
|
| 403 |
if subfolder:
|
| 404 |
params["subfolder"] = subfolder
|
| 405 |
|
| 406 |
-
img_resp = await client.get(
|
| 407 |
if img_resp.status_code == 200:
|
| 408 |
-
# Save to local output directory
|
| 409 |
from content_engine.config import settings
|
| 410 |
output_dir = settings.paths.output_dir / "pod" / content_rating / "raw"
|
| 411 |
output_dir.mkdir(parents=True, exist_ok=True)
|
|
@@ -419,7 +706,6 @@ async def _poll_pod_job(job_id: str, prompt_id: str, content_rating: str):
|
|
| 419 |
|
| 420 |
logger.info("Pod generation completed: %s -> %s", job_id, local_path)
|
| 421 |
|
| 422 |
-
# Catalog the image
|
| 423 |
try:
|
| 424 |
from content_engine.services.catalog import CatalogService
|
| 425 |
catalog = CatalogService()
|
|
@@ -463,29 +749,69 @@ def _build_flux_workflow(
|
|
| 463 |
seed: int,
|
| 464 |
lora_name: str | None,
|
| 465 |
lora_strength: float,
|
|
|
|
| 466 |
) -> dict:
|
| 467 |
-
"""Build a ComfyUI workflow for FLUX generation.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 468 |
|
| 469 |
-
# Basic FLUX workflow - compatible with ComfyUI FLUX setup
|
| 470 |
workflow = {
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
},
|
|
|
|
| 475 |
"6": {
|
| 476 |
"class_type": "CLIPTextEncode",
|
| 477 |
"inputs": {
|
| 478 |
"text": prompt,
|
| 479 |
-
"clip":
|
| 480 |
},
|
| 481 |
},
|
|
|
|
| 482 |
"7": {
|
| 483 |
"class_type": "CLIPTextEncode",
|
| 484 |
"inputs": {
|
| 485 |
"text": negative_prompt or "",
|
| 486 |
-
"clip":
|
| 487 |
},
|
| 488 |
},
|
|
|
|
| 489 |
"5": {
|
| 490 |
"class_type": "EmptyLatentImage",
|
| 491 |
"inputs": {
|
|
@@ -494,7 +820,8 @@ def _build_flux_workflow(
|
|
| 494 |
"batch_size": 1,
|
| 495 |
},
|
| 496 |
},
|
| 497 |
-
|
|
|
|
| 498 |
"class_type": "KSampler",
|
| 499 |
"inputs": {
|
| 500 |
"seed": seed,
|
|
@@ -503,19 +830,21 @@ def _build_flux_workflow(
|
|
| 503 |
"sampler_name": "euler",
|
| 504 |
"scheduler": "simple",
|
| 505 |
"denoise": 1.0,
|
| 506 |
-
"model":
|
| 507 |
"positive": ["6", 0],
|
| 508 |
"negative": ["7", 0],
|
| 509 |
"latent_image": ["5", 0],
|
| 510 |
},
|
| 511 |
},
|
|
|
|
| 512 |
"8": {
|
| 513 |
"class_type": "VAEDecode",
|
| 514 |
"inputs": {
|
| 515 |
-
"samples": ["
|
| 516 |
-
"vae":
|
| 517 |
},
|
| 518 |
},
|
|
|
|
| 519 |
"9": {
|
| 520 |
"class_type": "SaveImage",
|
| 521 |
"inputs": {
|
|
@@ -527,19 +856,19 @@ def _build_flux_workflow(
|
|
| 527 |
|
| 528 |
# Add LoRA if specified
|
| 529 |
if lora_name:
|
| 530 |
-
workflow["
|
| 531 |
"class_type": "LoraLoader",
|
| 532 |
"inputs": {
|
| 533 |
"lora_name": lora_name,
|
| 534 |
"strength_model": lora_strength,
|
| 535 |
"strength_clip": lora_strength,
|
| 536 |
-
"model":
|
| 537 |
-
"clip":
|
| 538 |
},
|
| 539 |
}
|
| 540 |
-
# Rewire sampler to use LoRA output
|
| 541 |
-
workflow["
|
| 542 |
-
workflow["6"]["inputs"]["clip"] = ["
|
| 543 |
-
workflow["7"]["inputs"]["clip"] = ["
|
| 544 |
|
| 545 |
return workflow
|
|
|
|
| 1 |
+
"""RunPod Pod management routes — start/stop GPU pods for generation.
|
| 2 |
+
|
| 3 |
+
Starts a persistent ComfyUI pod with network volume access.
|
| 4 |
+
Models and LoRAs are loaded from the shared network volume.
|
| 5 |
+
"""
|
| 6 |
|
| 7 |
from __future__ import annotations
|
| 8 |
|
| 9 |
import asyncio
|
| 10 |
+
import json
|
| 11 |
import logging
|
| 12 |
import os
|
| 13 |
import time
|
| 14 |
import uuid
|
| 15 |
+
from pathlib import Path
|
| 16 |
from typing import Any
|
| 17 |
|
| 18 |
import runpod
|
|
|
|
| 23 |
|
| 24 |
router = APIRouter(prefix="/api/pod", tags=["pod"])
|
| 25 |
|
| 26 |
+
# Persist pod state to disk so it survives server restarts
|
| 27 |
+
_POD_STATE_FILE = Path(__file__).parent.parent.parent.parent / "pod_state.json"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _save_pod_state():
|
| 31 |
+
"""Save pod state to disk."""
|
| 32 |
+
try:
|
| 33 |
+
data = {k: v for k, v in _pod_state.items() if k != "setup_status"}
|
| 34 |
+
_POD_STATE_FILE.write_text(json.dumps(data))
|
| 35 |
+
except Exception as e:
|
| 36 |
+
logger.warning("Failed to save pod state: %s", e)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _load_pod_state():
|
| 40 |
+
"""Load pod state from disk on startup."""
|
| 41 |
+
try:
|
| 42 |
+
if _POD_STATE_FILE.exists():
|
| 43 |
+
data = json.loads(_POD_STATE_FILE.read_text())
|
| 44 |
+
for k, v in data.items():
|
| 45 |
+
if k in _pod_state:
|
| 46 |
+
_pod_state[k] = v
|
| 47 |
+
logger.info("Restored pod state: pod_id=%s status=%s", _pod_state.get("pod_id"), _pod_state.get("status"))
|
| 48 |
+
except Exception as e:
|
| 49 |
+
logger.warning("Failed to load pod state: %s", e)
|
| 50 |
+
|
| 51 |
+
def _get_volume_config() -> tuple[str, str]:
|
| 52 |
+
"""Get network volume config at runtime (after dotenv loads)."""
|
| 53 |
+
return (
|
| 54 |
+
os.environ.get("RUNPOD_VOLUME_ID", ""),
|
| 55 |
+
os.environ.get("RUNPOD_VOLUME_DC", ""),
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
# Docker image — PyTorch base with CUDA, we install ComfyUI ourselves
|
| 59 |
+
DOCKER_IMAGE = "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04"
|
| 60 |
+
|
| 61 |
# Pod state
|
| 62 |
_pod_state = {
|
| 63 |
"pod_id": None,
|
| 64 |
+
"status": "stopped", # stopped, starting, setting_up, running, stopping
|
| 65 |
"ip": None,
|
| 66 |
+
"ssh_port": None,
|
| 67 |
+
"comfyui_port": None,
|
| 68 |
+
"gpu_type": "NVIDIA RTX A6000",
|
| 69 |
+
"model_type": "flux2",
|
| 70 |
"started_at": None,
|
| 71 |
+
"cost_per_hour": 0.76,
|
| 72 |
+
"setup_status": None,
|
| 73 |
}
|
| 74 |
|
| 75 |
+
_load_pod_state()
|
|
|
|
| 76 |
|
| 77 |
+
# GPU options (same as training)
|
| 78 |
GPU_OPTIONS = {
|
| 79 |
+
"NVIDIA A40": {"name": "A40 48GB", "vram": 48, "cost": 0.64},
|
| 80 |
+
"NVIDIA RTX A6000": {"name": "RTX A6000 48GB", "vram": 48, "cost": 0.76},
|
| 81 |
+
"NVIDIA L40": {"name": "L40 48GB", "vram": 48, "cost": 0.89},
|
| 82 |
+
"NVIDIA L40S": {"name": "L40S 48GB", "vram": 48, "cost": 1.09},
|
| 83 |
+
"NVIDIA A100-SXM4-80GB": {"name": "A100 SXM 80GB", "vram": 80, "cost": 1.64},
|
| 84 |
+
"NVIDIA A100 80GB PCIe": {"name": "A100 PCIe 80GB", "vram": 80, "cost": 1.89},
|
| 85 |
+
"NVIDIA H100 80GB HBM3": {"name": "H100 80GB", "vram": 80, "cost": 3.89},
|
| 86 |
+
"NVIDIA GeForce RTX 5090": {"name": "RTX 5090 32GB", "vram": 32, "cost": 0.69},
|
| 87 |
+
"NVIDIA GeForce RTX 4090": {"name": "RTX 4090 24GB", "vram": 24, "cost": 0.44},
|
| 88 |
+
"NVIDIA GeForce RTX 3090": {"name": "RTX 3090 24GB", "vram": 24, "cost": 0.22},
|
| 89 |
}
|
| 90 |
|
| 91 |
|
| 92 |
+
def _get_comfyui_url() -> str | None:
|
| 93 |
+
"""Get the ComfyUI URL via RunPod's HTTPS proxy.
|
| 94 |
+
|
| 95 |
+
RunPod HTTP ports are only accessible through their proxy at
|
| 96 |
+
https://{pod_id}-{private_port}.proxy.runpod.net
|
| 97 |
+
The raw IP:port from the API is an internal address, not publicly routable.
|
| 98 |
+
"""
|
| 99 |
+
pod_id = _pod_state.get("pod_id")
|
| 100 |
+
if pod_id:
|
| 101 |
+
return f"https://{pod_id}-8188.proxy.runpod.net"
|
| 102 |
+
return None
|
| 103 |
+
|
| 104 |
+
|
| 105 |
def _get_api_key() -> str:
|
| 106 |
key = os.environ.get("RUNPOD_API_KEY")
|
| 107 |
if not key:
|
|
|
|
| 111 |
|
| 112 |
|
| 113 |
class StartPodRequest(BaseModel):
|
| 114 |
+
gpu_type: str = "NVIDIA RTX A6000"
|
| 115 |
+
model_type: str = "flux2"
|
| 116 |
|
| 117 |
|
| 118 |
class PodStatus(BaseModel):
|
|
|
|
| 121 |
ip: str | None = None
|
| 122 |
port: int | None = None
|
| 123 |
gpu_type: str | None = None
|
| 124 |
+
model_type: str | None = None
|
| 125 |
cost_per_hour: float | None = None
|
| 126 |
+
setup_status: str | None = None
|
| 127 |
uptime_minutes: float | None = None
|
| 128 |
comfyui_url: str | None = None
|
| 129 |
|
|
|
|
| 133 |
"""Get current pod status."""
|
| 134 |
_get_api_key()
|
| 135 |
|
|
|
|
| 136 |
if _pod_state["pod_id"]:
|
| 137 |
try:
|
| 138 |
+
pod = await asyncio.wait_for(
|
| 139 |
+
asyncio.to_thread(runpod.get_pod, _pod_state["pod_id"]),
|
| 140 |
+
timeout=10,
|
| 141 |
+
)
|
| 142 |
if pod:
|
| 143 |
desired = pod.get("desiredStatus", "")
|
| 144 |
if desired == "RUNNING":
|
| 145 |
+
runtime = pod.get("runtime") or {}
|
| 146 |
+
ports = runtime.get("ports") or []
|
| 147 |
for p in ports:
|
| 148 |
+
if p.get("privatePort") == 22:
|
| 149 |
+
_pod_state["ssh_ip"] = p.get("ip")
|
| 150 |
+
_pod_state["ssh_port"] = p.get("publicPort")
|
| 151 |
if p.get("privatePort") == 8188:
|
| 152 |
+
_pod_state["comfyui_ip"] = p.get("ip")
|
| 153 |
+
_pod_state["comfyui_port"] = p.get("publicPort")
|
| 154 |
+
# Use SSH IP as the main IP for display
|
| 155 |
+
_pod_state["ip"] = _pod_state.get("ssh_ip") or _pod_state.get("comfyui_ip")
|
| 156 |
elif desired == "EXITED":
|
| 157 |
_pod_state["status"] = "stopped"
|
| 158 |
_pod_state["pod_id"] = None
|
| 159 |
else:
|
| 160 |
_pod_state["status"] = "stopped"
|
| 161 |
_pod_state["pod_id"] = None
|
| 162 |
+
except asyncio.TimeoutError:
|
| 163 |
+
logger.warning("RunPod API timeout checking pod status")
|
| 164 |
except Exception as e:
|
| 165 |
logger.warning("Failed to check pod: %s", e)
|
| 166 |
|
| 167 |
uptime = None
|
| 168 |
+
if _pod_state["started_at"] and _pod_state["status"] in ("running", "setting_up"):
|
| 169 |
uptime = (time.time() - _pod_state["started_at"]) / 60
|
| 170 |
|
| 171 |
+
comfyui_url = _get_comfyui_url()
|
|
|
|
|
|
|
| 172 |
|
| 173 |
return PodStatus(
|
| 174 |
status=_pod_state["status"],
|
| 175 |
pod_id=_pod_state["pod_id"],
|
| 176 |
ip=_pod_state["ip"],
|
| 177 |
+
port=_pod_state.get("comfyui_port"),
|
| 178 |
gpu_type=_pod_state["gpu_type"],
|
| 179 |
+
model_type=_pod_state.get("model_type", "flux2"),
|
| 180 |
cost_per_hour=_pod_state["cost_per_hour"],
|
| 181 |
+
setup_status=_pod_state.get("setup_status"),
|
| 182 |
uptime_minutes=uptime,
|
| 183 |
comfyui_url=comfyui_url,
|
| 184 |
)
|
|
|
|
| 190 |
return {"gpus": GPU_OPTIONS}
|
| 191 |
|
| 192 |
|
| 193 |
+
@router.get("/model-options")
|
| 194 |
+
async def list_model_options():
|
| 195 |
+
"""List available model types for the pod."""
|
| 196 |
+
return {
|
| 197 |
+
"models": {
|
| 198 |
+
"flux2": {"name": "FLUX.2 Dev", "description": "Best for realistic txt2img (requires 48GB+ VRAM)", "use_case": "txt2img"},
|
| 199 |
+
"flux1": {"name": "FLUX.1 Dev", "description": "Previous gen FLUX txt2img", "use_case": "txt2img"},
|
| 200 |
+
"wan22": {"name": "WAN 2.2", "description": "Image-to-video and general generation", "use_case": "img2video"},
|
| 201 |
+
}
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
|
| 205 |
@router.post("/start")
|
| 206 |
async def start_pod(request: StartPodRequest):
|
| 207 |
+
"""Start a GPU pod with ComfyUI for generation."""
|
| 208 |
_get_api_key()
|
| 209 |
|
| 210 |
+
if _pod_state["status"] in ("running", "setting_up"):
|
| 211 |
return {"status": "already_running", "pod_id": _pod_state["pod_id"]}
|
| 212 |
|
| 213 |
if _pod_state["status"] == "starting":
|
|
|
|
| 220 |
_pod_state["status"] = "starting"
|
| 221 |
_pod_state["gpu_type"] = request.gpu_type
|
| 222 |
_pod_state["cost_per_hour"] = gpu_info["cost"]
|
| 223 |
+
_pod_state["model_type"] = request.model_type
|
| 224 |
+
_pod_state["setup_status"] = "Creating pod..."
|
| 225 |
|
| 226 |
try:
|
| 227 |
+
logger.info("Starting RunPod with %s for %s...", request.gpu_type, request.model_type)
|
| 228 |
+
|
| 229 |
+
pod_kwargs = {
|
| 230 |
+
"container_disk_in_gb": 30,
|
| 231 |
+
"ports": "22/tcp,8188/http",
|
| 232 |
+
"docker_args": "bash -c 'apt-get update && apt-get install -y openssh-server && mkdir -p /run/sshd && echo root:runpod | chpasswd && /usr/sbin/sshd -o PermitRootLogin=yes && sleep infinity'",
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
volume_id, volume_dc = _get_volume_config()
|
| 236 |
+
if volume_id:
|
| 237 |
+
pod_kwargs["network_volume_id"] = volume_id
|
| 238 |
+
if volume_dc:
|
| 239 |
+
pod_kwargs["data_center_id"] = volume_dc
|
| 240 |
+
logger.info("Using network volume: %s (DC: %s)", volume_id, volume_dc)
|
| 241 |
+
else:
|
| 242 |
+
pod_kwargs["volume_in_gb"] = 75
|
| 243 |
+
logger.warning("No network volume configured — using temporary volume")
|
| 244 |
+
|
| 245 |
+
pod = await asyncio.to_thread(
|
| 246 |
+
runpod.create_pod,
|
| 247 |
+
f"comfyui-gen-{request.model_type}",
|
| 248 |
+
DOCKER_IMAGE,
|
| 249 |
+
request.gpu_type,
|
| 250 |
+
**pod_kwargs,
|
| 251 |
)
|
| 252 |
|
| 253 |
_pod_state["pod_id"] = pod["id"]
|
| 254 |
_pod_state["started_at"] = time.time()
|
| 255 |
+
_save_pod_state()
|
| 256 |
|
| 257 |
logger.info("Pod created: %s", pod["id"])
|
| 258 |
|
| 259 |
+
asyncio.create_task(_wait_and_setup_pod(pod["id"], request.model_type))
|
|
|
|
| 260 |
|
| 261 |
return {
|
| 262 |
"status": "starting",
|
| 263 |
"pod_id": pod["id"],
|
| 264 |
+
"message": f"Starting {gpu_info['name']} pod (~5-8 min for setup)",
|
| 265 |
}
|
| 266 |
|
| 267 |
except Exception as e:
|
| 268 |
_pod_state["status"] = "stopped"
|
| 269 |
+
_pod_state["setup_status"] = None
|
| 270 |
logger.error("Failed to start pod: %s", e)
|
| 271 |
raise HTTPException(500, f"Failed to start pod: {e}")
|
| 272 |
|
| 273 |
|
| 274 |
+
async def _wait_and_setup_pod(pod_id: str, model_type: str, timeout: int = 600):
|
| 275 |
+
"""Wait for pod to be ready, then install ComfyUI and link models via SSH."""
|
| 276 |
start = time.time()
|
| 277 |
+
ssh_host = None
|
| 278 |
+
ssh_port = None
|
| 279 |
|
| 280 |
+
# Phase 1: Wait for SSH to be available
|
| 281 |
+
_pod_state["setup_status"] = "Waiting for pod to start..."
|
| 282 |
while time.time() - start < timeout:
|
| 283 |
try:
|
| 284 |
+
pod = await asyncio.to_thread(runpod.get_pod, pod_id)
|
|
|
|
| 285 |
if pod and pod.get("desiredStatus") == "RUNNING":
|
| 286 |
+
runtime = pod.get("runtime") or {}
|
| 287 |
+
ports = runtime.get("ports") or []
|
|
|
|
| 288 |
for p in ports:
|
| 289 |
+
if p.get("privatePort") == 22:
|
| 290 |
+
ssh_host = p.get("ip")
|
| 291 |
+
ssh_port = p.get("publicPort")
|
| 292 |
+
_pod_state["ssh_ip"] = ssh_host
|
| 293 |
+
_pod_state["ssh_port"] = ssh_port
|
| 294 |
+
_pod_state["ip"] = ssh_host
|
| 295 |
if p.get("privatePort") == 8188:
|
| 296 |
+
_pod_state["comfyui_ip"] = p.get("ip")
|
| 297 |
+
_pod_state["comfyui_port"] = p.get("publicPort")
|
| 298 |
+
if ssh_host and ssh_port:
|
| 299 |
+
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
except Exception as e:
|
| 301 |
logger.debug("Waiting for pod: %s", e)
|
|
|
|
| 302 |
await asyncio.sleep(5)
|
| 303 |
|
| 304 |
+
if not ssh_host or not ssh_port:
|
| 305 |
+
logger.error("Pod did not become ready within %ds", timeout)
|
| 306 |
+
_pod_state["status"] = "stopped"
|
| 307 |
+
_pod_state["setup_status"] = "Failed: pod did not start"
|
| 308 |
+
return
|
| 309 |
+
|
| 310 |
+
# Phase 2: SSH in and set up ComfyUI
|
| 311 |
+
_pod_state["status"] = "setting_up"
|
| 312 |
+
_pod_state["setup_status"] = "Connecting via SSH..."
|
| 313 |
+
|
| 314 |
+
import paramiko
|
| 315 |
+
ssh = paramiko.SSHClient()
|
| 316 |
+
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
| 317 |
+
|
| 318 |
+
for attempt in range(30):
|
| 319 |
+
try:
|
| 320 |
+
await asyncio.to_thread(
|
| 321 |
+
ssh.connect, ssh_host, port=int(ssh_port),
|
| 322 |
+
username="root", password="runpod", timeout=10,
|
| 323 |
+
)
|
| 324 |
+
break
|
| 325 |
+
except Exception:
|
| 326 |
+
if attempt == 29:
|
| 327 |
+
_pod_state["setup_status"] = "Failed: SSH connection error"
|
| 328 |
+
_pod_state["status"] = "stopped"
|
| 329 |
+
return
|
| 330 |
+
await asyncio.sleep(5)
|
| 331 |
+
|
| 332 |
+
transport = ssh.get_transport()
|
| 333 |
+
transport.set_keepalive(30)
|
| 334 |
+
|
| 335 |
+
try:
|
| 336 |
+
# Symlink network volume
|
| 337 |
+
volume_id, _ = _get_volume_config()
|
| 338 |
+
if volume_id:
|
| 339 |
+
await _ssh_exec_async(ssh, "mkdir -p /runpod-volume/models /runpod-volume/loras")
|
| 340 |
+
await _ssh_exec_async(ssh, "rm -rf /workspace/models 2>/dev/null; ln -sf /runpod-volume/models /workspace/models")
|
| 341 |
+
|
| 342 |
+
# Install ComfyUI (cache on volume for reuse)
|
| 343 |
+
comfy_dir = "/workspace/ComfyUI"
|
| 344 |
+
_pod_state["setup_status"] = "Installing ComfyUI..."
|
| 345 |
+
|
| 346 |
+
comfy_exists = (await _ssh_exec_async(ssh, f"test -f {comfy_dir}/main.py && echo EXISTS || echo MISSING")).strip()
|
| 347 |
+
if comfy_exists == "EXISTS":
|
| 348 |
+
logger.info("ComfyUI already installed")
|
| 349 |
+
_pod_state["setup_status"] = "ComfyUI found, updating..."
|
| 350 |
+
await _ssh_exec_async(ssh, f"cd {comfy_dir} && git pull 2>&1 | tail -3", timeout=120)
|
| 351 |
+
else:
|
| 352 |
+
# Check volume cache
|
| 353 |
+
vol_comfy = (await _ssh_exec_async(ssh, "test -f /runpod-volume/ComfyUI/main.py && echo EXISTS || echo MISSING")).strip()
|
| 354 |
+
if vol_comfy == "EXISTS":
|
| 355 |
+
_pod_state["setup_status"] = "Restoring ComfyUI from volume..."
|
| 356 |
+
await _ssh_exec_async(ssh, f"cp -r /runpod-volume/ComfyUI {comfy_dir}", timeout=300)
|
| 357 |
+
else:
|
| 358 |
+
_pod_state["setup_status"] = "Cloning ComfyUI (first time, ~2 min)..."
|
| 359 |
+
await _ssh_exec_async(ssh, f"cd /workspace && git clone --depth 1 https://github.com/comfyanonymous/ComfyUI.git", timeout=300)
|
| 360 |
+
await _ssh_exec_async(ssh, f"cd {comfy_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=600)
|
| 361 |
+
# Cache to volume
|
| 362 |
+
volume_id, _ = _get_volume_config()
|
| 363 |
+
if volume_id:
|
| 364 |
+
await _ssh_exec_async(ssh, f"cp -r {comfy_dir} /runpod-volume/ComfyUI", timeout=300)
|
| 365 |
+
|
| 366 |
+
# Install pip deps that aren't in ComfyUI requirements
|
| 367 |
+
_pod_state["setup_status"] = "Installing dependencies..."
|
| 368 |
+
await _ssh_exec_async(ssh, f"cd {comfy_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=600)
|
| 369 |
+
await _ssh_exec_async(ssh, "pip install aiohttp einops sqlalchemy 2>&1 | tail -3", timeout=120)
|
| 370 |
+
|
| 371 |
+
# Symlink models into ComfyUI directories
|
| 372 |
+
_pod_state["setup_status"] = "Linking models..."
|
| 373 |
+
await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/checkpoints {comfy_dir}/models/vae {comfy_dir}/models/loras {comfy_dir}/models/text_encoders")
|
| 374 |
+
|
| 375 |
+
if model_type == "flux2":
|
| 376 |
+
# FLUX.2 Dev — separate UNet, text encoder, and VAE
|
| 377 |
+
await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/diffusion_models")
|
| 378 |
+
await _ssh_exec_async(ssh, f"ln -sf /workspace/models/FLUX.2-dev/flux2-dev.safetensors {comfy_dir}/models/diffusion_models/flux2-dev.safetensors")
|
| 379 |
+
await _ssh_exec_async(ssh, f"ln -sf /workspace/models/FLUX.2-dev/ae.safetensors {comfy_dir}/models/vae/ae.safetensors")
|
| 380 |
+
|
| 381 |
+
# Text encoder — use Comfy-Org's pre-converted single-file version
|
| 382 |
+
# (HF sharded format is incompatible with ComfyUI's CLIPLoader)
|
| 383 |
+
te_file = "/runpod-volume/models/mistral_3_small_flux2_fp8.safetensors"
|
| 384 |
+
te_exists = (await _ssh_exec_async(ssh, f"test -f {te_file} && echo EXISTS || echo MISSING")).strip()
|
| 385 |
+
if te_exists != "EXISTS":
|
| 386 |
+
_pod_state["setup_status"] = "Downloading FLUX.2 text encoder (~12GB, first time only)..."
|
| 387 |
+
await _ssh_exec_async(ssh, "pip install huggingface_hub 2>&1 | tail -1", timeout=60)
|
| 388 |
+
await _ssh_exec_async(ssh, f"""python -c "
|
| 389 |
+
from huggingface_hub import hf_hub_download
|
| 390 |
+
hf_hub_download(
|
| 391 |
+
repo_id='Comfy-Org/flux2-dev',
|
| 392 |
+
filename='split_files/text_encoders/mistral_3_small_flux2_fp8.safetensors',
|
| 393 |
+
local_dir='/tmp/flux2_te',
|
| 394 |
+
)
|
| 395 |
+
import shutil
|
| 396 |
+
shutil.move('/tmp/flux2_te/split_files/text_encoders/mistral_3_small_flux2_fp8.safetensors', '{te_file}')
|
| 397 |
+
print('Text encoder downloaded')
|
| 398 |
+
" 2>&1 | tail -5""", timeout=1800)
|
| 399 |
+
await _ssh_exec_async(ssh, f"ln -sf {te_file} {comfy_dir}/models/text_encoders/mistral_3_small_flux2_fp8.safetensors")
|
| 400 |
+
# Remove old sharded loader patch if present
|
| 401 |
+
await _ssh_exec_async(ssh, f"rm -f {comfy_dir}/custom_nodes/sharded_loader.py")
|
| 402 |
+
elif model_type == "flux1":
|
| 403 |
+
await _ssh_exec_async(ssh, f"ln -sf /workspace/models/flux1-dev.safetensors {comfy_dir}/models/checkpoints/flux1-dev.safetensors")
|
| 404 |
+
await _ssh_exec_async(ssh, f"ln -sf /workspace/models/ae.safetensors {comfy_dir}/models/vae/ae.safetensors")
|
| 405 |
+
await _ssh_exec_async(ssh, f"ln -sf /workspace/models/clip_l.safetensors {comfy_dir}/models/text_encoders/clip_l.safetensors")
|
| 406 |
+
await _ssh_exec_async(ssh, f"ln -sf /workspace/models/t5xxl_fp16.safetensors {comfy_dir}/models/text_encoders/t5xxl_fp16.safetensors")
|
| 407 |
+
elif model_type == "wan22":
|
| 408 |
+
# WAN 2.2 Image-to-Video (14B params)
|
| 409 |
+
wan_dir = "/workspace/models/Wan2.2-I2V-A14B"
|
| 410 |
+
wan_exists = (await _ssh_exec_async(ssh, f"test -d {wan_dir} && echo EXISTS || echo MISSING")).strip()
|
| 411 |
+
if wan_exists != "EXISTS":
|
| 412 |
+
_pod_state["setup_status"] = "Downloading WAN 2.2 model (~28GB, first time only)..."
|
| 413 |
+
await _ssh_exec_async(ssh, f"pip install huggingface_hub 2>&1 | tail -1", timeout=60)
|
| 414 |
+
await _ssh_exec_async(ssh, f"""python -c "
|
| 415 |
+
from huggingface_hub import snapshot_download
|
| 416 |
+
snapshot_download('Wan-AI/Wan2.2-I2V-A14B', local_dir='{wan_dir}', ignore_patterns=['*.md', '*.txt'])
|
| 417 |
+
print('WAN 2.2 downloaded')
|
| 418 |
+
" 2>&1 | tail -10""", timeout=3600)
|
| 419 |
+
# Symlink WAN model to ComfyUI diffusion_models dir
|
| 420 |
+
await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/diffusion_models")
|
| 421 |
+
await _ssh_exec_async(ssh, f"ln -sf {wan_dir} {comfy_dir}/models/diffusion_models/Wan2.2-I2V-A14B")
|
| 422 |
+
# Also need a VAE and text encoder for WAN — they use their own
|
| 423 |
+
await _ssh_exec_async(ssh, f"ln -sf {wan_dir} {comfy_dir}/models/checkpoints/Wan2.2-I2V-A14B")
|
| 424 |
+
|
| 425 |
+
# Install ComfyUI-WanVideoWrapper custom nodes
|
| 426 |
+
_pod_state["setup_status"] = "Installing WAN 2.2 ComfyUI nodes..."
|
| 427 |
+
wan_nodes_dir = f"{comfy_dir}/custom_nodes/ComfyUI-WanVideoWrapper"
|
| 428 |
+
wan_nodes_exist = (await _ssh_exec_async(ssh, f"test -d {wan_nodes_dir} && echo EXISTS || echo MISSING")).strip()
|
| 429 |
+
if wan_nodes_exist != "EXISTS":
|
| 430 |
+
await _ssh_exec_async(ssh, f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/kijai/ComfyUI-WanVideoWrapper.git", timeout=120)
|
| 431 |
+
await _ssh_exec_async(ssh, f"cd {wan_nodes_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=300)
|
| 432 |
+
|
| 433 |
+
# Symlink all LoRAs from volume
|
| 434 |
+
await _ssh_exec_async(ssh, f"ls /runpod-volume/loras/*.safetensors 2>/dev/null | while read f; do ln -sf \"$f\" {comfy_dir}/models/loras/; done")
|
| 435 |
+
|
| 436 |
+
# Start ComfyUI in background (fire-and-forget — don't wait for output)
|
| 437 |
+
_pod_state["setup_status"] = "Starting ComfyUI..."
|
| 438 |
+
await asyncio.to_thread(
|
| 439 |
+
_ssh_exec_fire_and_forget,
|
| 440 |
+
ssh,
|
| 441 |
+
f"cd {comfy_dir} && python main.py --listen 0.0.0.0 --port 8188 --fp8_e4m3fn-unet > /tmp/comfyui.log 2>&1",
|
| 442 |
+
)
|
| 443 |
+
await asyncio.sleep(2) # Give it a moment to start
|
| 444 |
+
|
| 445 |
+
# Wait for ComfyUI HTTP to respond
|
| 446 |
+
_pod_state["setup_status"] = "Waiting for ComfyUI to load model..."
|
| 447 |
+
import httpx
|
| 448 |
+
comfyui_url = _get_comfyui_url()
|
| 449 |
+
for attempt in range(120): # Up to 10 minutes
|
| 450 |
+
try:
|
| 451 |
+
async with httpx.AsyncClient(timeout=5) as client:
|
| 452 |
+
resp = await client.get(f"{comfyui_url}/system_stats")
|
| 453 |
+
if resp.status_code == 200:
|
| 454 |
+
_pod_state["status"] = "running"
|
| 455 |
+
_pod_state["setup_status"] = "Ready"
|
| 456 |
+
_save_pod_state()
|
| 457 |
+
logger.info("ComfyUI ready at %s", comfyui_url)
|
| 458 |
+
return
|
| 459 |
+
except Exception:
|
| 460 |
+
pass
|
| 461 |
+
await asyncio.sleep(5)
|
| 462 |
+
|
| 463 |
+
# If we get here, ComfyUI didn't start
|
| 464 |
+
# Check the log for errors
|
| 465 |
+
log_tail = await _ssh_exec_async(ssh, "tail -20 /tmp/comfyui.log")
|
| 466 |
+
logger.error("ComfyUI didn't start. Log: %s", log_tail)
|
| 467 |
+
_pod_state["setup_status"] = f"ComfyUI failed to start. Check logs."
|
| 468 |
+
_pod_state["status"] = "setting_up" # Keep pod running so user can debug
|
| 469 |
+
|
| 470 |
+
except Exception as e:
|
| 471 |
+
import traceback
|
| 472 |
+
err_msg = f"{type(e).__name__}: {e}"
|
| 473 |
+
logger.error("Pod setup failed: %s\n%s", err_msg, traceback.format_exc())
|
| 474 |
+
_pod_state["setup_status"] = f"Setup failed: {err_msg}"
|
| 475 |
+
_pod_state["status"] = "setting_up" # Keep pod running so user can debug
|
| 476 |
+
finally:
|
| 477 |
+
try:
|
| 478 |
+
ssh.close()
|
| 479 |
+
except Exception:
|
| 480 |
+
pass
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
def _ssh_exec(ssh, cmd: str, timeout: int = 120) -> str:
|
| 484 |
+
"""Execute a command over SSH and return stdout (blocking — call from async via to_thread or background task)."""
|
| 485 |
+
_, stdout, stderr = ssh.exec_command(cmd, timeout=timeout)
|
| 486 |
+
out = stdout.read().decode("utf-8", errors="replace")
|
| 487 |
+
return out.strip()
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
async def _ssh_exec_async(ssh, cmd: str, timeout: int = 120) -> str:
|
| 491 |
+
"""Async wrapper for SSH exec that doesn't block the event loop."""
|
| 492 |
+
return await asyncio.to_thread(_ssh_exec, ssh, cmd, timeout)
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
def _ssh_exec_fire_and_forget(ssh, cmd: str):
|
| 496 |
+
"""Start a command over SSH without waiting for output (for background processes)."""
|
| 497 |
+
transport = ssh.get_transport()
|
| 498 |
+
channel = transport.open_session()
|
| 499 |
+
channel.exec_command(cmd)
|
| 500 |
+
# Don't read stdout/stderr — just let it run
|
| 501 |
|
| 502 |
|
| 503 |
@router.post("/stop")
|
|
|
|
| 517 |
pod_id = _pod_state["pod_id"]
|
| 518 |
logger.info("Stopping pod: %s", pod_id)
|
| 519 |
|
| 520 |
+
await asyncio.to_thread(runpod.terminate_pod, pod_id)
|
| 521 |
|
| 522 |
_pod_state["pod_id"] = None
|
| 523 |
_pod_state["ip"] = None
|
| 524 |
+
_pod_state["ssh_port"] = None
|
| 525 |
+
_pod_state["comfyui_port"] = None
|
| 526 |
_pod_state["status"] = "stopped"
|
| 527 |
_pod_state["started_at"] = None
|
| 528 |
+
_pod_state["setup_status"] = None
|
| 529 |
+
_save_pod_state()
|
| 530 |
|
| 531 |
logger.info("Pod stopped")
|
| 532 |
return {"status": "stopped", "message": "Pod terminated"}
|
| 533 |
|
| 534 |
except Exception as e:
|
| 535 |
logger.error("Failed to stop pod: %s", e)
|
| 536 |
+
_pod_state["status"] = "running"
|
| 537 |
raise HTTPException(500, f"Failed to stop pod: {e}")
|
| 538 |
|
| 539 |
|
|
|
|
| 543 |
if _pod_state["status"] != "running" or not _pod_state["ip"]:
|
| 544 |
return {"loras": [], "message": "Pod not running"}
|
| 545 |
|
| 546 |
+
comfyui_url = _get_comfyui_url()
|
| 547 |
try:
|
| 548 |
import httpx
|
| 549 |
async with httpx.AsyncClient(timeout=30) as client:
|
| 550 |
+
url = f"{comfyui_url}/object_info/LoraLoader"
|
| 551 |
resp = await client.get(url)
|
| 552 |
if resp.status_code == 200:
|
| 553 |
data = resp.json()
|
|
|
|
| 556 |
except Exception as e:
|
| 557 |
logger.warning("Failed to list pod LoRAs: %s", e)
|
| 558 |
|
| 559 |
+
return {"loras": [], "comfyui_url": comfyui_url}
|
| 560 |
|
| 561 |
|
| 562 |
@router.post("/upload-lora")
|
| 563 |
+
async def upload_lora_to_pod(file: UploadFile = File(...)):
|
|
|
|
|
|
|
| 564 |
"""Upload a LoRA file to the running pod."""
|
|
|
|
| 565 |
import httpx
|
| 566 |
|
| 567 |
if _pod_state["status"] != "running":
|
|
|
|
| 572 |
|
| 573 |
try:
|
| 574 |
content = await file.read()
|
| 575 |
+
comfyui_url = _get_comfyui_url()
|
| 576 |
|
| 577 |
async with httpx.AsyncClient(timeout=120) as client:
|
| 578 |
+
url = f"{comfyui_url}/upload/image"
|
|
|
|
| 579 |
files = {"image": (file.filename, content, "application/octet-stream")}
|
| 580 |
data = {"subfolder": "loras", "type": "input"}
|
|
|
|
| 581 |
resp = await client.post(url, files=files, data=data)
|
| 582 |
|
| 583 |
if resp.status_code == 200:
|
|
|
|
| 622 |
job_id = str(uuid.uuid4())[:8]
|
| 623 |
seed = request.seed if request.seed >= 0 else random.randint(0, 2**32 - 1)
|
| 624 |
|
| 625 |
+
model_type = _pod_state.get("model_type", "flux2")
|
| 626 |
workflow = _build_flux_workflow(
|
| 627 |
prompt=request.prompt,
|
| 628 |
negative_prompt=request.negative_prompt,
|
|
|
|
| 633 |
seed=seed,
|
| 634 |
lora_name=request.lora_name,
|
| 635 |
lora_strength=request.lora_strength,
|
| 636 |
+
model_type=model_type,
|
| 637 |
)
|
| 638 |
|
| 639 |
+
comfyui_url = _get_comfyui_url()
|
| 640 |
+
|
| 641 |
try:
|
| 642 |
async with httpx.AsyncClient(timeout=30) as client:
|
| 643 |
+
resp = await client.post(f"{comfyui_url}/prompt", json={"prompt": workflow})
|
|
|
|
| 644 |
resp.raise_for_status()
|
| 645 |
|
| 646 |
data = resp.json()
|
|
|
|
| 654 |
}
|
| 655 |
|
| 656 |
logger.info("Pod generation started: %s -> %s", job_id, prompt_id)
|
|
|
|
|
|
|
| 657 |
asyncio.create_task(_poll_pod_job(job_id, prompt_id, request.content_rating))
|
| 658 |
|
| 659 |
+
return {"job_id": job_id, "status": "running", "seed": seed}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 660 |
|
| 661 |
except Exception as e:
|
| 662 |
logger.error("Pod generation failed: %s", e)
|
|
|
|
| 666 |
async def _poll_pod_job(job_id: str, prompt_id: str, content_rating: str):
|
| 667 |
"""Poll ComfyUI for job completion and save the result."""
|
| 668 |
import httpx
|
|
|
|
| 669 |
|
| 670 |
start = time.time()
|
| 671 |
+
timeout = 600 # 10 min — first gen can take 5+ min for model loading
|
| 672 |
+
comfyui_url = _get_comfyui_url()
|
| 673 |
|
| 674 |
async with httpx.AsyncClient(timeout=60) as client:
|
| 675 |
while time.time() - start < timeout:
|
| 676 |
try:
|
| 677 |
+
resp = await client.get(f"{comfyui_url}/history/{prompt_id}")
|
|
|
|
| 678 |
|
| 679 |
if resp.status_code == 200:
|
| 680 |
data = resp.json()
|
| 681 |
if prompt_id in data:
|
| 682 |
outputs = data[prompt_id].get("outputs", {})
|
| 683 |
|
|
|
|
| 684 |
for node_id, node_output in outputs.items():
|
| 685 |
if "images" in node_output:
|
| 686 |
image_info = node_output["images"][0]
|
| 687 |
filename = image_info["filename"]
|
| 688 |
subfolder = image_info.get("subfolder", "")
|
| 689 |
|
|
|
|
|
|
|
| 690 |
params = {"filename": filename}
|
| 691 |
if subfolder:
|
| 692 |
params["subfolder"] = subfolder
|
| 693 |
|
| 694 |
+
img_resp = await client.get(f"{comfyui_url}/view", params=params)
|
| 695 |
if img_resp.status_code == 200:
|
|
|
|
| 696 |
from content_engine.config import settings
|
| 697 |
output_dir = settings.paths.output_dir / "pod" / content_rating / "raw"
|
| 698 |
output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 706 |
|
| 707 |
logger.info("Pod generation completed: %s -> %s", job_id, local_path)
|
| 708 |
|
|
|
|
| 709 |
try:
|
| 710 |
from content_engine.services.catalog import CatalogService
|
| 711 |
catalog = CatalogService()
|
|
|
|
| 749 |
seed: int,
|
| 750 |
lora_name: str | None,
|
| 751 |
lora_strength: float,
|
| 752 |
+
model_type: str = "flux2",
|
| 753 |
) -> dict:
|
| 754 |
+
"""Build a ComfyUI workflow for FLUX generation.
|
| 755 |
+
|
| 756 |
+
FLUX.2 Dev uses separate model components (not a single checkpoint):
|
| 757 |
+
- UNETLoader for the diffusion model
|
| 758 |
+
- CLIPLoader (type=flux2) for the Mistral text encoder
|
| 759 |
+
- VAELoader for the autoencoder
|
| 760 |
+
"""
|
| 761 |
+
|
| 762 |
+
if model_type == "flux2":
|
| 763 |
+
unet_name = "flux2-dev.safetensors"
|
| 764 |
+
clip_type = "flux2"
|
| 765 |
+
clip_name = "mistral_3_small_flux2_fp8.safetensors"
|
| 766 |
+
else:
|
| 767 |
+
unet_name = "flux1-dev.safetensors"
|
| 768 |
+
clip_type = "flux"
|
| 769 |
+
clip_name = "t5xxl_fp16.safetensors"
|
| 770 |
+
|
| 771 |
+
# Model node ID references
|
| 772 |
+
model_out = ["1", 0] # UNETLoader -> MODEL
|
| 773 |
+
clip_out = ["2", 0] # CLIPLoader -> CLIP
|
| 774 |
+
vae_out = ["3", 0] # VAELoader -> VAE
|
| 775 |
|
|
|
|
| 776 |
workflow = {
|
| 777 |
+
# Load diffusion model (UNet)
|
| 778 |
+
"1": {
|
| 779 |
+
"class_type": "UNETLoader",
|
| 780 |
+
"inputs": {
|
| 781 |
+
"unet_name": unet_name,
|
| 782 |
+
"weight_dtype": "fp8_e4m3fn",
|
| 783 |
+
},
|
| 784 |
+
},
|
| 785 |
+
# Load text encoder
|
| 786 |
+
"2": {
|
| 787 |
+
"class_type": "CLIPLoader",
|
| 788 |
+
"inputs": {
|
| 789 |
+
"clip_name": clip_name,
|
| 790 |
+
"type": clip_type,
|
| 791 |
+
},
|
| 792 |
+
},
|
| 793 |
+
# Load VAE
|
| 794 |
+
"3": {
|
| 795 |
+
"class_type": "VAELoader",
|
| 796 |
+
"inputs": {"vae_name": "ae.safetensors"},
|
| 797 |
},
|
| 798 |
+
# Positive prompt
|
| 799 |
"6": {
|
| 800 |
"class_type": "CLIPTextEncode",
|
| 801 |
"inputs": {
|
| 802 |
"text": prompt,
|
| 803 |
+
"clip": clip_out,
|
| 804 |
},
|
| 805 |
},
|
| 806 |
+
# Negative prompt
|
| 807 |
"7": {
|
| 808 |
"class_type": "CLIPTextEncode",
|
| 809 |
"inputs": {
|
| 810 |
"text": negative_prompt or "",
|
| 811 |
+
"clip": clip_out,
|
| 812 |
},
|
| 813 |
},
|
| 814 |
+
# Empty latent
|
| 815 |
"5": {
|
| 816 |
"class_type": "EmptyLatentImage",
|
| 817 |
"inputs": {
|
|
|
|
| 820 |
"batch_size": 1,
|
| 821 |
},
|
| 822 |
},
|
| 823 |
+
# Sampler
|
| 824 |
+
"10": {
|
| 825 |
"class_type": "KSampler",
|
| 826 |
"inputs": {
|
| 827 |
"seed": seed,
|
|
|
|
| 830 |
"sampler_name": "euler",
|
| 831 |
"scheduler": "simple",
|
| 832 |
"denoise": 1.0,
|
| 833 |
+
"model": model_out,
|
| 834 |
"positive": ["6", 0],
|
| 835 |
"negative": ["7", 0],
|
| 836 |
"latent_image": ["5", 0],
|
| 837 |
},
|
| 838 |
},
|
| 839 |
+
# Decode
|
| 840 |
"8": {
|
| 841 |
"class_type": "VAEDecode",
|
| 842 |
"inputs": {
|
| 843 |
+
"samples": ["10", 0],
|
| 844 |
+
"vae": vae_out,
|
| 845 |
},
|
| 846 |
},
|
| 847 |
+
# Save
|
| 848 |
"9": {
|
| 849 |
"class_type": "SaveImage",
|
| 850 |
"inputs": {
|
|
|
|
| 856 |
|
| 857 |
# Add LoRA if specified
|
| 858 |
if lora_name:
|
| 859 |
+
workflow["20"] = {
|
| 860 |
"class_type": "LoraLoader",
|
| 861 |
"inputs": {
|
| 862 |
"lora_name": lora_name,
|
| 863 |
"strength_model": lora_strength,
|
| 864 |
"strength_clip": lora_strength,
|
| 865 |
+
"model": model_out,
|
| 866 |
+
"clip": clip_out,
|
| 867 |
},
|
| 868 |
}
|
| 869 |
+
# Rewire sampler and text encoders to use LoRA output
|
| 870 |
+
workflow["10"]["inputs"]["model"] = ["20", 0]
|
| 871 |
+
workflow["6"]["inputs"]["clip"] = ["20", 1]
|
| 872 |
+
workflow["7"]["inputs"]["clip"] = ["20", 1]
|
| 873 |
|
| 874 |
return workflow
|
src/content_engine/api/routes_training.py
CHANGED
|
@@ -213,8 +213,10 @@ async def list_training_jobs():
|
|
| 213 |
"loss": j.loss, "started_at": j.started_at,
|
| 214 |
"completed_at": j.completed_at, "output_path": j.output_path,
|
| 215 |
"error": j.error, "backend": "local",
|
|
|
|
| 216 |
})
|
| 217 |
if _runpod_trainer:
|
|
|
|
| 218 |
for j in _runpod_trainer.list_jobs():
|
| 219 |
jobs.append({
|
| 220 |
"id": j.id, "name": j.name, "status": j.status,
|
|
@@ -225,6 +227,7 @@ async def list_training_jobs():
|
|
| 225 |
"completed_at": j.completed_at, "output_path": j.output_path,
|
| 226 |
"error": j.error, "backend": "runpod",
|
| 227 |
"base_model": j.base_model, "model_type": j.model_type,
|
|
|
|
| 228 |
})
|
| 229 |
return jobs
|
| 230 |
|
|
@@ -232,27 +235,35 @@ async def list_training_jobs():
|
|
| 232 |
@router.get("/jobs/{job_id}")
|
| 233 |
async def get_training_job(job_id: str):
|
| 234 |
"""Get details of a specific training job including logs."""
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
|
| 257 |
|
| 258 |
@router.post("/jobs/{job_id}/cancel")
|
|
@@ -267,3 +278,22 @@ async def cancel_training_job(job_id: str):
|
|
| 267 |
if cancelled:
|
| 268 |
return {"status": "cancelled", "job_id": job_id}
|
| 269 |
raise HTTPException(404, "Job not found or not running")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
"loss": j.loss, "started_at": j.started_at,
|
| 214 |
"completed_at": j.completed_at, "output_path": j.output_path,
|
| 215 |
"error": j.error, "backend": "local",
|
| 216 |
+
"log_lines": j.log_lines[-50:] if hasattr(j, 'log_lines') else [],
|
| 217 |
})
|
| 218 |
if _runpod_trainer:
|
| 219 |
+
await _runpod_trainer.ensure_loaded()
|
| 220 |
for j in _runpod_trainer.list_jobs():
|
| 221 |
jobs.append({
|
| 222 |
"id": j.id, "name": j.name, "status": j.status,
|
|
|
|
| 227 |
"completed_at": j.completed_at, "output_path": j.output_path,
|
| 228 |
"error": j.error, "backend": "runpod",
|
| 229 |
"base_model": j.base_model, "model_type": j.model_type,
|
| 230 |
+
"log_lines": j.log_lines[-50:],
|
| 231 |
})
|
| 232 |
return jobs
|
| 233 |
|
|
|
|
| 235 |
def _job_detail(job, backend: str) -> dict:
    """Serialize a trainer job object into the detail-endpoint payload.

    Shared by the RunPod and local branches of ``get_training_job`` so the
    two responses cannot drift apart.  ``log_lines`` is truncated to the
    last 50 entries, matching the list endpoint.
    """
    data = {
        "id": job.id, "name": job.name, "status": job.status,
        "progress": round(job.progress, 3),
        "current_epoch": job.current_epoch, "total_epochs": job.total_epochs,
        "current_step": job.current_step, "total_steps": job.total_steps,
        "loss": job.loss, "started_at": job.started_at,
        "completed_at": job.completed_at, "output_path": job.output_path,
        "error": job.error,
        # Guard with getattr like the list endpoint does: local jobs may
        # predate the log_lines attribute.
        "log_lines": getattr(job, "log_lines", [])[-50:],
        "backend": backend,
    }
    if backend == "runpod":
        # Only cloud jobs carry base-model metadata on the detail payload.
        data["base_model"] = job.base_model
    return data


@router.get("/jobs/{job_id}")
async def get_training_job(job_id: str):
    """Get details of a specific training job including logs.

    Checks the RunPod trainer first (reloading persisted state), then the
    local trainer.  Raises HTTP 404 if neither backend knows the job.
    """
    # Check RunPod jobs first
    if _runpod_trainer:
        await _runpod_trainer.ensure_loaded()
        job = _runpod_trainer.get_job(job_id)
        if job:
            return _job_detail(job, "runpod")
    # Then check local trainer
    if _trainer:
        job = _trainer.get_job(job_id)
        if job:
            # "backend": "local" added for consistency with the list endpoint.
            return _job_detail(job, "local")
    raise HTTPException(404, f"Training job not found: {job_id}")
|
| 267 |
|
| 268 |
|
| 269 |
@router.post("/jobs/{job_id}/cancel")
|
|
|
|
| 278 |
if cancelled:
|
| 279 |
return {"status": "cancelled", "job_id": job_id}
|
| 280 |
raise HTTPException(404, "Job not found or not running")
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
@router.delete("/jobs/{job_id}")
async def delete_training_job(job_id: str):
    """Delete a training job from history.

    Returns a confirmation payload on success; raises HTTP 404 when the
    RunPod trainer is unavailable or does not know the job.
    """
    if not _runpod_trainer:
        raise HTTPException(404, f"Training job not found: {job_id}")
    removed = await _runpod_trainer.delete_job(job_id)
    if not removed:
        raise HTTPException(404, f"Training job not found: {job_id}")
    return {"status": "deleted", "job_id": job_id}
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
@router.delete("/jobs")
async def delete_failed_jobs():
    """Delete all failed training jobs.

    When no RunPod trainer is configured, reports zero deletions rather
    than failing.
    """
    removed = await _runpod_trainer.delete_failed_jobs() if _runpod_trainer else 0
    return {"status": "ok", "deleted": removed}
|
src/content_engine/api/ui.html
CHANGED
|
@@ -636,6 +636,25 @@ select { cursor: pointer; }
|
|
| 636 |
.job-status-completed { background: rgba(34,197,94,0.15); color: var(--green); }
|
| 637 |
.job-status-failed { background: rgba(239,68,68,0.15); color: var(--red); }
|
| 638 |
.job-status-pending { background: rgba(136,136,136,0.15); color: var(--text-secondary); }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 639 |
|
| 640 |
/* --- Lightbox --- */
|
| 641 |
.lightbox {
|
|
@@ -1245,10 +1264,22 @@ select { cursor: pointer; }
|
|
| 1245 |
<div style="margin-top:8px">
|
| 1246 |
<label style="margin:0">GPU Type</label>
|
| 1247 |
<select id="train-gpu-type" style="margin-top:4px">
|
| 1248 |
-
<
|
| 1249 |
-
|
| 1250 |
-
|
| 1251 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1252 |
</select>
|
| 1253 |
</div>
|
| 1254 |
</div>
|
|
@@ -1379,13 +1410,27 @@ select { cursor: pointer; }
|
|
| 1379 |
</div>
|
| 1380 |
<div id="pod-controls" style="display:flex; gap:8px; align-items:center; flex-wrap:wrap">
|
| 1381 |
<select id="pod-model-type" style="padding:8px 12px; border-radius:6px; background:var(--bg-primary); border:1px solid var(--border); color:var(--text-primary)">
|
| 1382 |
-
<option value="
|
| 1383 |
-
<option value="
|
|
|
|
| 1384 |
</select>
|
| 1385 |
<select id="pod-gpu-select" style="padding:8px 12px; border-radius:6px; background:var(--bg-primary); border:1px solid var(--border); color:var(--text-primary)">
|
| 1386 |
-
<
|
| 1387 |
-
|
| 1388 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1389 |
</select>
|
| 1390 |
<button id="pod-start-btn" class="btn" onclick="startPod()">Start Pod</button>
|
| 1391 |
<button id="pod-stop-btn" class="btn btn-danger" onclick="stopPod()" style="display:none">Stop Pod</button>
|
|
@@ -2761,9 +2806,10 @@ function updateModelDefaults() {
|
|
| 2761 |
<span style="color:var(--accent)">Resolution: ${model.resolution}px | LR: ${model.learning_rate} | Rank: ${model.network_rank} | VRAM: ${model.vram_required_gb}GB</span>
|
| 2762 |
`;
|
| 2763 |
|
| 2764 |
-
// Update placeholder hints
|
| 2765 |
document.getElementById('train-lr').placeholder = `Default: ${model.learning_rate}`;
|
| 2766 |
document.getElementById('lr-default').textContent = `(default: ${model.learning_rate})`;
|
|
|
|
| 2767 |
|
| 2768 |
// Update resolution default
|
| 2769 |
const resSelect = document.getElementById('train-resolution');
|
|
@@ -2782,6 +2828,16 @@ function updateModelDefaults() {
|
|
| 2782 |
break;
|
| 2783 |
}
|
| 2784 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2785 |
}
|
| 2786 |
|
| 2787 |
function selectTrainBackend(chip, backend) {
|
|
@@ -2899,7 +2955,7 @@ async function pollTrainingJobs() {
|
|
| 2899 |
renderTrainingJobs(jobs);
|
| 2900 |
|
| 2901 |
// Stop polling if no active jobs
|
| 2902 |
-
const active = jobs.filter(j =>
|
| 2903 |
if (active.length === 0 && trainingPollInterval) {
|
| 2904 |
clearInterval(trainingPollInterval);
|
| 2905 |
trainingPollInterval = null;
|
|
@@ -2913,35 +2969,93 @@ function renderTrainingJobs(jobs) {
|
|
| 2913 |
container.innerHTML = `<div class="empty-state" style="padding:30px"><p>No training jobs yet</p><p style="font-size:12px;margin-top:4px">Upload images and configure settings to start training</p></div>`;
|
| 2914 |
return;
|
| 2915 |
}
|
| 2916 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2917 |
const pct = (j.progress * 100).toFixed(1);
|
| 2918 |
const elapsed = j.started_at ? ((Date.now()/1000 - j.started_at) / 60).toFixed(0) : '?';
|
|
|
|
|
|
|
| 2919 |
return `
|
| 2920 |
-
<div class="job-card">
|
| 2921 |
<div class="job-header">
|
| 2922 |
<span class="job-name">${j.name} ${j.backend === 'runpod' ? '<span style="font-size:10px;color:var(--blue);font-weight:400">☁ RunPod</span>' : ''}</span>
|
| 2923 |
<span class="job-status job-status-${j.status}">${j.status}</span>
|
| 2924 |
</div>
|
| 2925 |
-
${
|
| 2926 |
<div class="progress-bar-container" style="margin-top:0">
|
| 2927 |
<div class="progress-bar-fill" style="width:${pct}%"></div>
|
| 2928 |
</div>
|
| 2929 |
<div style="display:flex;gap:16px;margin-top:8px;font-size:12px;color:var(--text-secondary)">
|
| 2930 |
<span>Progress: <strong style="color:var(--text-primary)">${pct}%</strong></span>
|
| 2931 |
${j.current_step ? `<span>Step: <strong style="color:var(--text-primary)">${j.current_step}/${j.total_steps}</strong></span>` : ''}
|
| 2932 |
-
${j.loss !== null ? `<span>Loss: <strong style="color:var(--text-primary)">${j.loss.toFixed(4)}</strong></span>` : ''}
|
| 2933 |
<span>Time: <strong style="color:var(--text-primary)">${elapsed}m</strong></span>
|
| 2934 |
</div>
|
| 2935 |
-
<
|
|
|
|
|
|
|
|
|
|
| 2936 |
` : ''}
|
| 2937 |
${j.status === 'completed' ? `
|
| 2938 |
<div style="font-size:12px;color:var(--green);margin-top:4px">LoRA saved to ComfyUI models folder</div>
|
| 2939 |
${j.output_path ? `<div style="font-size:11px;color:var(--text-secondary);margin-top:2px;word-break:break-all">${j.output_path}</div>` : ''}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2940 |
` : ''}
|
| 2941 |
-
|
|
|
|
|
|
|
| 2942 |
</div>
|
| 2943 |
`;
|
| 2944 |
}).join('');
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2945 |
}
|
| 2946 |
|
| 2947 |
async function cancelTraining(jobId) {
|
|
@@ -2954,6 +3068,27 @@ async function cancelTraining(jobId) {
|
|
| 2954 |
}
|
| 2955 |
}
|
| 2956 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2957 |
// --- Status ---
|
| 2958 |
async function checkStatus() {
|
| 2959 |
try {
|
|
@@ -3085,10 +3220,11 @@ function updatePodUI(pod) {
|
|
| 3085 |
if (!podPollInterval) {
|
| 3086 |
podPollInterval = setInterval(loadPodStatus, 30000);
|
| 3087 |
}
|
| 3088 |
-
} else if (pod.status === 'starting') {
|
| 3089 |
-
|
|
|
|
| 3090 |
startBtn.style.display = 'none';
|
| 3091 |
-
stopBtn.style.display = '
|
| 3092 |
gpuSelect.style.display = 'none';
|
| 3093 |
podInfo.style.display = 'none';
|
| 3094 |
// Poll more frequently while starting
|
|
|
|
| 636 |
.job-status-completed { background: rgba(34,197,94,0.15); color: var(--green); }
|
| 637 |
.job-status-failed { background: rgba(239,68,68,0.15); color: var(--red); }
|
| 638 |
.job-status-pending { background: rgba(136,136,136,0.15); color: var(--text-secondary); }
|
| 639 |
+
.job-logs-panel {
|
| 640 |
+
margin-top: 8px;
|
| 641 |
+
border-top: 1px solid var(--border);
|
| 642 |
+
padding-top: 8px;
|
| 643 |
+
}
|
| 644 |
+
.job-logs-content {
|
| 645 |
+
background: var(--bg-primary);
|
| 646 |
+
border: 1px solid var(--border);
|
| 647 |
+
border-radius: 6px;
|
| 648 |
+
padding: 8px 10px;
|
| 649 |
+
font-family: monospace;
|
| 650 |
+
font-size: 11px;
|
| 651 |
+
line-height: 1.5;
|
| 652 |
+
max-height: 300px;
|
| 653 |
+
overflow-y: auto;
|
| 654 |
+
white-space: pre-wrap;
|
| 655 |
+
word-break: break-all;
|
| 656 |
+
color: var(--text-secondary);
|
| 657 |
+
}
|
| 658 |
|
| 659 |
/* --- Lightbox --- */
|
| 660 |
.lightbox {
|
|
|
|
| 1264 |
<div style="margin-top:8px">
|
| 1265 |
<label style="margin:0">GPU Type</label>
|
| 1266 |
<select id="train-gpu-type" style="margin-top:4px">
|
| 1267 |
+
<optgroup label="48GB+ (Required for FLUX.2 Dev)">
|
| 1268 |
+
<option value="NVIDIA A40">A40 48GB (~$0.64/hr) - Cheapest for FLUX.2</option>
|
| 1269 |
+
<option value="NVIDIA RTX A6000" selected>RTX A6000 48GB (~$0.76/hr) - Recommended</option>
|
| 1270 |
+
<option value="NVIDIA L40">L40 48GB (~$0.89/hr)</option>
|
| 1271 |
+
<option value="NVIDIA L40S">L40S 48GB (~$1.09/hr)</option>
|
| 1272 |
+
<option value="NVIDIA A100-SXM4-80GB">A100 SXM 80GB (~$1.64/hr)</option>
|
| 1273 |
+
<option value="NVIDIA A100 80GB PCIe">A100 PCIe 80GB (~$1.89/hr)</option>
|
| 1274 |
+
<option value="NVIDIA H100 80GB HBM3">H100 80GB (~$3.89/hr) - Fastest</option>
|
| 1275 |
+
</optgroup>
|
| 1276 |
+
<optgroup label="24-32GB (SD 1.5, SDXL, FLUX.1 only)">
|
| 1277 |
+
<option value="NVIDIA GeForce RTX 5090">RTX 5090 32GB (~$0.69/hr)</option>
|
| 1278 |
+
<option value="NVIDIA GeForce RTX 4090">RTX 4090 24GB (~$0.44/hr)</option>
|
| 1279 |
+
<option value="NVIDIA GeForce RTX 3090">RTX 3090 24GB (~$0.22/hr)</option>
|
| 1280 |
+
<option value="NVIDIA RTX A5000">RTX A5000 24GB (~$0.28/hr)</option>
|
| 1281 |
+
<option value="NVIDIA RTX A4000">RTX A4000 16GB (~$0.20/hr)</option>
|
| 1282 |
+
</optgroup>
|
| 1283 |
</select>
|
| 1284 |
</div>
|
| 1285 |
</div>
|
|
|
|
| 1410 |
</div>
|
| 1411 |
<div id="pod-controls" style="display:flex; gap:8px; align-items:center; flex-wrap:wrap">
|
| 1412 |
<select id="pod-model-type" style="padding:8px 12px; border-radius:6px; background:var(--bg-primary); border:1px solid var(--border); color:var(--text-primary)">
|
| 1413 |
+
<option value="flux2">FLUX.2 Dev (Realistic txt2img)</option>
|
| 1414 |
+
<option value="flux1">FLUX.1 Dev (txt2img)</option>
|
| 1415 |
+
<option value="wan22">WAN 2.2 (img2video)</option>
|
| 1416 |
</select>
|
| 1417 |
<select id="pod-gpu-select" style="padding:8px 12px; border-radius:6px; background:var(--bg-primary); border:1px solid var(--border); color:var(--text-primary)">
|
| 1418 |
+
<optgroup label="48GB+ (FLUX.2 / Large models)">
|
| 1419 |
+
<option value="NVIDIA A40">A40 48GB - $0.64/hr</option>
|
| 1420 |
+
<option value="NVIDIA RTX A6000" selected>A6000 48GB - $0.76/hr</option>
|
| 1421 |
+
<option value="NVIDIA L40">L40 48GB - $0.89/hr</option>
|
| 1422 |
+
<option value="NVIDIA L40S">L40S 48GB - $1.09/hr</option>
|
| 1423 |
+
<option value="NVIDIA A100-SXM4-80GB">A100 SXM 80GB - $1.64/hr</option>
|
| 1424 |
+
<option value="NVIDIA A100 80GB PCIe">A100 PCIe 80GB - $1.89/hr</option>
|
| 1425 |
+
<option value="NVIDIA H100 80GB HBM3">H100 80GB - $3.89/hr</option>
|
| 1426 |
+
</optgroup>
|
| 1427 |
+
<optgroup label="24-32GB (SD 1.5 / SDXL / FLUX.1)">
|
| 1428 |
+
<option value="NVIDIA GeForce RTX 5090">RTX 5090 32GB - $0.69/hr</option>
|
| 1429 |
+
<option value="NVIDIA GeForce RTX 4090">RTX 4090 24GB - $0.44/hr</option>
|
| 1430 |
+
<option value="NVIDIA GeForce RTX 3090">RTX 3090 24GB - $0.22/hr</option>
|
| 1431 |
+
<option value="NVIDIA RTX A5000">A5000 24GB - $0.28/hr</option>
|
| 1432 |
+
<option value="NVIDIA RTX A4000">A4000 16GB - $0.20/hr</option>
|
| 1433 |
+
</optgroup>
|
| 1434 |
</select>
|
| 1435 |
<button id="pod-start-btn" class="btn" onclick="startPod()">Start Pod</button>
|
| 1436 |
<button id="pod-stop-btn" class="btn btn-danger" onclick="stopPod()" style="display:none">Stop Pod</button>
|
|
|
|
| 2806 |
<span style="color:var(--accent)">Resolution: ${model.resolution}px | LR: ${model.learning_rate} | Rank: ${model.network_rank} | VRAM: ${model.vram_required_gb}GB</span>
|
| 2807 |
`;
|
| 2808 |
|
| 2809 |
+
// Update placeholder hints and auto-fill LR
|
| 2810 |
document.getElementById('train-lr').placeholder = `Default: ${model.learning_rate}`;
|
| 2811 |
document.getElementById('lr-default').textContent = `(default: ${model.learning_rate})`;
|
| 2812 |
+
document.getElementById('train-lr').value = model.learning_rate;
|
| 2813 |
|
| 2814 |
// Update resolution default
|
| 2815 |
const resSelect = document.getElementById('train-resolution');
|
|
|
|
| 2828 |
break;
|
| 2829 |
}
|
| 2830 |
}
|
| 2831 |
+
|
| 2832 |
+
// Update optimizer default
|
| 2833 |
+
const optSelect = document.getElementById('train-optimizer');
|
| 2834 |
+
const optName = (model.optimizer || 'AdamW8bit').toLowerCase();
|
| 2835 |
+
for (let opt of optSelect.options) {
|
| 2836 |
+
if (opt.value.toLowerCase() === optName) {
|
| 2837 |
+
opt.selected = true;
|
| 2838 |
+
break;
|
| 2839 |
+
}
|
| 2840 |
+
}
|
| 2841 |
}
|
| 2842 |
|
| 2843 |
function selectTrainBackend(chip, backend) {
|
|
|
|
| 2955 |
renderTrainingJobs(jobs);
|
| 2956 |
|
| 2957 |
// Stop polling if no active jobs
|
| 2958 |
+
const active = jobs.filter(j => ['training','preparing','creating_pod','uploading','installing','downloading','pending'].includes(j.status));
|
| 2959 |
if (active.length === 0 && trainingPollInterval) {
|
| 2960 |
clearInterval(trainingPollInterval);
|
| 2961 |
trainingPollInterval = null;
|
|
|
|
| 2969 |
container.innerHTML = `<div class="empty-state" style="padding:30px"><p>No training jobs yet</p><p style="font-size:12px;margin-top:4px">Upload images and configure settings to start training</p></div>`;
|
| 2970 |
return;
|
| 2971 |
}
|
| 2972 |
+
|
| 2973 |
+
// Store latest jobs for log viewer
|
| 2974 |
+
window._trainingJobs = jobs;
|
| 2975 |
+
|
| 2976 |
+
// Show "Clear Failed" button if there are any failed jobs
|
| 2977 |
+
const failedCount = jobs.filter(j => j.status === 'failed' || j.status === 'error').length;
|
| 2978 |
+
let html = failedCount > 0 ? `<div style="text-align:right;margin-bottom:8px"><button class="btn btn-secondary btn-small" style="color:var(--red)" onclick="clearFailedJobs()">Clear ${failedCount} Failed</button></div>` : '';
|
| 2979 |
+
|
| 2980 |
+
html += jobs.map(j => {
|
| 2981 |
const pct = (j.progress * 100).toFixed(1);
|
| 2982 |
const elapsed = j.started_at ? ((Date.now()/1000 - j.started_at) / 60).toFixed(0) : '?';
|
| 2983 |
+
const isActive = ['training','preparing','creating_pod','uploading','installing','downloading'].includes(j.status);
|
| 2984 |
+
const hasLogs = j.log_lines && j.log_lines.length > 0;
|
| 2985 |
return `
|
| 2986 |
+
<div class="job-card" id="job-card-${j.id}">
|
| 2987 |
<div class="job-header">
|
| 2988 |
<span class="job-name">${j.name} ${j.backend === 'runpod' ? '<span style="font-size:10px;color:var(--blue);font-weight:400">☁ RunPod</span>' : ''}</span>
|
| 2989 |
<span class="job-status job-status-${j.status}">${j.status}</span>
|
| 2990 |
</div>
|
| 2991 |
+
${isActive ? `
|
| 2992 |
<div class="progress-bar-container" style="margin-top:0">
|
| 2993 |
<div class="progress-bar-fill" style="width:${pct}%"></div>
|
| 2994 |
</div>
|
| 2995 |
<div style="display:flex;gap:16px;margin-top:8px;font-size:12px;color:var(--text-secondary)">
|
| 2996 |
<span>Progress: <strong style="color:var(--text-primary)">${pct}%</strong></span>
|
| 2997 |
${j.current_step ? `<span>Step: <strong style="color:var(--text-primary)">${j.current_step}/${j.total_steps}</strong></span>` : ''}
|
| 2998 |
+
${j.loss !== null && j.loss !== undefined ? `<span>Loss: <strong style="color:var(--text-primary)">${j.loss.toFixed(4)}</strong></span>` : ''}
|
| 2999 |
<span>Time: <strong style="color:var(--text-primary)">${elapsed}m</strong></span>
|
| 3000 |
</div>
|
| 3001 |
+
<div style="display:flex;gap:6px;margin-top:8px">
|
| 3002 |
+
<button class="btn btn-secondary btn-small" onclick="toggleJobLogs('${j.id}')">View Logs</button>
|
| 3003 |
+
<button class="btn btn-secondary btn-small" style="color:var(--red)" onclick="cancelTraining('${j.id}')">Cancel</button>
|
| 3004 |
+
</div>
|
| 3005 |
` : ''}
|
| 3006 |
${j.status === 'completed' ? `
|
| 3007 |
<div style="font-size:12px;color:var(--green);margin-top:4px">LoRA saved to ComfyUI models folder</div>
|
| 3008 |
${j.output_path ? `<div style="font-size:11px;color:var(--text-secondary);margin-top:2px;word-break:break-all">${j.output_path}</div>` : ''}
|
| 3009 |
+
<button class="btn btn-secondary btn-small" style="margin-top:6px" onclick="toggleJobLogs('${j.id}')">View Logs</button>
|
| 3010 |
+
` : ''}
|
| 3011 |
+
${j.status === 'failed' ? `
|
| 3012 |
+
${j.error ? `<div style="font-size:12px;color:var(--red);margin-top:4px">${j.error}</div>` : ''}
|
| 3013 |
+
<div style="display:flex;gap:6px;margin-top:6px">
|
| 3014 |
+
<button class="btn btn-secondary btn-small" onclick="toggleJobLogs('${j.id}')">View Logs</button>
|
| 3015 |
+
<button class="btn btn-secondary btn-small" style="color:var(--red)" onclick="deleteJob('${j.id}')">Delete</button>
|
| 3016 |
+
</div>
|
| 3017 |
` : ''}
|
| 3018 |
+
<div id="job-logs-${j.id}" class="job-logs-panel" style="display:none">
|
| 3019 |
+
<div class="job-logs-content"></div>
|
| 3020 |
+
</div>
|
| 3021 |
</div>
|
| 3022 |
`;
|
| 3023 |
}).join('');
|
| 3024 |
+
|
| 3025 |
+
container.innerHTML = html;
|
| 3026 |
+
|
| 3027 |
+
// Auto-show logs for active jobs
|
| 3028 |
+
const activeJob = jobs.find(j => ['training','preparing','creating_pod','uploading','installing','downloading'].includes(j.status));
|
| 3029 |
+
if (activeJob && activeJob.log_lines && activeJob.log_lines.length > 0) {
|
| 3030 |
+
showJobLogs(activeJob.id);
|
| 3031 |
+
}
|
| 3032 |
+
}
|
| 3033 |
+
|
| 3034 |
+
// Show the log panel for a job if it is hidden, otherwise hide it.
function toggleJobLogs(jobId) {
  const panel = document.getElementById(`job-logs-${jobId}`);
  if (!panel) return;
  const hidden = panel.style.display === 'none';
  if (hidden) {
    showJobLogs(jobId);
  } else {
    panel.style.display = 'none';
  }
}
|
| 3043 |
+
|
| 3044 |
+
// Reveal a job's log panel and fill it with the latest cached log lines,
// scrolled to the bottom. Falls back to a placeholder when no logs exist.
function showJobLogs(jobId) {
  const panel = document.getElementById(`job-logs-${jobId}`);
  if (!panel) return;
  panel.style.display = 'block';

  const content = panel.querySelector('.job-logs-content');
  // Job data is cached on window by the polling/render path.
  const job = (window._trainingJobs || []).find(entry => entry.id === jobId);
  if (!job || !job.log_lines) {
    content.textContent = 'No logs available';
    return;
  }
  content.textContent = job.log_lines.join('\n');
  content.scrollTop = content.scrollHeight;
}
|
| 3060 |
|
| 3061 |
async function cancelTraining(jobId) {
|
|
|
|
| 3068 |
}
|
| 3069 |
}
|
| 3070 |
|
| 3071 |
+
// Remove one training job from history via the API, then refresh the list.
async function deleteJob(jobId) {
  try {
    await fetch(`${API}/api/training/jobs/${jobId}`, { method: 'DELETE' });
    toast('Job deleted', 'info');
    pollTrainingJobs();
  } catch (err) {
    toast('Failed to delete job', 'error');
  }
}
|
| 3080 |
+
|
| 3081 |
+
// Bulk-delete all failed training jobs, report how many were removed,
// and refresh the job list.
async function clearFailedJobs() {
  try {
    const response = await fetch(`${API}/api/training/jobs`, { method: 'DELETE' });
    const payload = await response.json();
    toast(`Cleared ${payload.deleted} failed jobs`, 'info');
    pollTrainingJobs();
  } catch (err) {
    toast('Failed to clear jobs', 'error');
  }
}
|
| 3091 |
+
|
| 3092 |
// --- Status ---
|
| 3093 |
async function checkStatus() {
|
| 3094 |
try {
|
|
|
|
| 3220 |
if (!podPollInterval) {
|
| 3221 |
podPollInterval = setInterval(loadPodStatus, 30000);
|
| 3222 |
}
|
| 3223 |
+
} else if (pod.status === 'starting' || pod.status === 'setting_up') {
|
| 3224 |
+
const setupMsg = pod.setup_status || 'Starting pod...';
|
| 3225 |
+
statusText.innerHTML = `<span style="color:var(--yellow)">● ${setupMsg}</span>`;
|
| 3226 |
startBtn.style.display = 'none';
|
| 3227 |
+
stopBtn.style.display = ''; // Allow stopping during setup
|
| 3228 |
gpuSelect.style.display = 'none';
|
| 3229 |
podInfo.style.display = 'none';
|
| 3230 |
// Poll more frequently while starting
|
src/content_engine/models/__pycache__/__init__.cpython-311.pyc
DELETED
|
Binary file (802 Bytes)
|
|
|
src/content_engine/models/__pycache__/database.cpython-311.pyc
DELETED
|
Binary file (10.7 kB)
|
|
|
src/content_engine/models/__pycache__/schemas.cpython-311.pyc
DELETED
|
Binary file (5.59 kB)
|
|
|
src/content_engine/models/database.py
CHANGED
|
@@ -2,9 +2,7 @@
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
-
import os
|
| 6 |
from datetime import datetime
|
| 7 |
-
from pathlib import Path
|
| 8 |
|
| 9 |
from sqlalchemy import (
|
| 10 |
Boolean,
|
|
@@ -149,12 +147,36 @@ class ScheduledPost(Base):
|
|
| 149 |
)
|
| 150 |
|
| 151 |
|
| 152 |
-
|
|
|
|
| 153 |
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
_catalog_engine = create_async_engine(
|
| 160 |
settings.database.url,
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
|
|
| 5 |
from datetime import datetime
|
|
|
|
| 6 |
|
| 7 |
from sqlalchemy import (
|
| 8 |
Boolean,
|
|
|
|
| 147 |
)
|
| 148 |
|
| 149 |
|
| 150 |
+
class TrainingJob(Base):
    """Persistent record of a LoRA training job.

    Stored in the database so job history survives server restarts and can
    be reloaded by the trainers.  Field names mirror the payloads served by
    the training API routes.
    """

    __tablename__ = "training_jobs"

    # UUID string primary key (36 chars).
    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    name: Mapped[str] = mapped_column(String(128), nullable=False)
    # Lifecycle state (e.g. "pending", "training", "failed", "completed");
    # indexed because job listings filter on it.
    status: Mapped[str] = mapped_column(String(32), default="pending", index=True)
    # Completion fraction in [0, 1].
    progress: Mapped[float] = mapped_column(Float, default=0.0)
    current_epoch: Mapped[int] = mapped_column(Integer, default=0)
    total_epochs: Mapped[int] = mapped_column(Integer, default=0)
    current_step: Mapped[int] = mapped_column(Integer, default=0)
    total_steps: Mapped[int] = mapped_column(Integer, default=0)
    # Most recent reported training loss, if any.
    loss: Mapped[float | None] = mapped_column(Float)
    # Unix timestamps in seconds (floats, not DateTime) — the UI computes
    # elapsed time as Date.now()/1000 - started_at.
    started_at: Mapped[float | None] = mapped_column(Float)
    completed_at: Mapped[float | None] = mapped_column(Float)
    # Filesystem path of the finished LoRA artifact.
    output_path: Mapped[str | None] = mapped_column(String(512))
    error: Mapped[str | None] = mapped_column(Text)
    log_text: Mapped[str | None] = mapped_column(Text)  # newline-separated log lines
    # RunPod-specific metadata; presumably None for local jobs — verify.
    pod_id: Mapped[str | None] = mapped_column(String(64))
    gpu_type: Mapped[str | None] = mapped_column(String(64))
    # Which trainer owns the job: "runpod" (default) or "local".
    backend: Mapped[str] = mapped_column(String(16), default="runpod")
    base_model: Mapped[str | None] = mapped_column(String(64))
    model_type: Mapped[str | None] = mapped_column(String(16))
    trigger_word: Mapped[str | None] = mapped_column(String(128))
    image_upload_dir: Mapped[str | None] = mapped_column(String(512))
    # Row creation time, assigned server-side by the database.
    created_at: Mapped[datetime] = mapped_column(
        DateTime, server_default=func.now()
    )
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# --- Engine / Session factories ---
|
| 180 |
|
| 181 |
_catalog_engine = create_async_engine(
|
| 182 |
settings.database.url,
|
src/content_engine/services/__pycache__/__init__.cpython-311.pyc
DELETED
|
Binary file (232 Bytes)
|
|
|
src/content_engine/services/__pycache__/catalog.cpython-311.pyc
DELETED
|
Binary file (13.1 kB)
|
|
|
src/content_engine/services/__pycache__/comfyui_client.cpython-311.pyc
DELETED
|
Binary file (17.5 kB)
|
|
|
src/content_engine/services/__pycache__/lora_trainer.cpython-311.pyc
DELETED
|
Binary file (19 kB)
|
|
|
src/content_engine/services/__pycache__/runpod_trainer.cpython-311.pyc
DELETED
|
Binary file (32 kB)
|
|
|
src/content_engine/services/__pycache__/template_engine.cpython-311.pyc
DELETED
|
Binary file (11 kB)
|
|
|
src/content_engine/services/__pycache__/variation_engine.cpython-311.pyc
DELETED
|
Binary file (8.52 kB)
|
|
|
src/content_engine/services/__pycache__/workflow_builder.cpython-311.pyc
DELETED
|
Binary file (7.16 kB)
|
|
|
src/content_engine/services/runpod_trainer.py
CHANGED
|
@@ -27,6 +27,7 @@ logger = logging.getLogger(__name__)
|
|
| 27 |
|
| 28 |
import os
|
| 29 |
from content_engine.config import settings, IS_HF_SPACES
|
|
|
|
| 30 |
|
| 31 |
LORA_OUTPUT_DIR = settings.paths.lora_dir
|
| 32 |
if IS_HF_SPACES:
|
|
@@ -36,16 +37,30 @@ else:
|
|
| 36 |
|
| 37 |
# RunPod GPU options (id -> display name, approx cost/hr)
|
| 38 |
GPU_OPTIONS = {
|
| 39 |
-
|
| 40 |
-
"NVIDIA GeForce RTX
|
| 41 |
-
"NVIDIA RTX
|
| 42 |
-
"NVIDIA RTX
|
| 43 |
-
"NVIDIA RTX
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
"NVIDIA A100 80GB PCIe": "A100 80GB (~$1.89/hr)",
|
|
|
|
|
|
|
| 45 |
}
|
| 46 |
|
| 47 |
DEFAULT_GPU = "NVIDIA GeForce RTX 4090"
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
# Docker image with PyTorch + CUDA pre-installed
|
| 50 |
DOCKER_IMAGE = "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04"
|
| 51 |
|
|
@@ -84,12 +99,15 @@ class CloudTrainingJob:
|
|
| 84 |
cost_estimate: str | None = None
|
| 85 |
base_model: str = "sd15_realistic"
|
| 86 |
model_type: str = "sd15"
|
|
|
|
| 87 |
|
| 88 |
def _log(self, msg: str):
|
| 89 |
self.log_lines.append(msg)
|
| 90 |
if len(self.log_lines) > 200:
|
| 91 |
self.log_lines = self.log_lines[-200:]
|
| 92 |
logger.info("[%s] %s", self.id, msg)
|
|
|
|
|
|
|
| 93 |
|
| 94 |
|
| 95 |
class RunPodTrainer:
|
|
@@ -100,6 +118,7 @@ class RunPodTrainer:
|
|
| 100 |
runpod.api_key = api_key
|
| 101 |
self._jobs: dict[str, CloudTrainingJob] = {}
|
| 102 |
self._model_registry = load_model_registry()
|
|
|
|
| 103 |
|
| 104 |
@property
|
| 105 |
def available(self) -> bool:
|
|
@@ -123,6 +142,8 @@ class RunPodTrainer:
|
|
| 123 |
"learning_rate": cfg.get("learning_rate", 1e-4),
|
| 124 |
"network_rank": cfg.get("network_rank", 32),
|
| 125 |
"network_alpha": cfg.get("network_alpha", 16),
|
|
|
|
|
|
|
| 126 |
"vram_required_gb": cfg.get("vram_required_gb", 8),
|
| 127 |
"recommended_images": cfg.get("recommended_images", "15-30 photos"),
|
| 128 |
}
|
|
@@ -179,6 +200,8 @@ class RunPodTrainer:
|
|
| 179 |
model_type=model_type,
|
| 180 |
)
|
| 181 |
self._jobs[job_id] = job
|
|
|
|
|
|
|
| 182 |
|
| 183 |
# Launch the full pipeline as a background task
|
| 184 |
asyncio.create_task(self._run_cloud_training(
|
|
@@ -224,15 +247,26 @@ class RunPodTrainer:
|
|
| 224 |
job.status = "creating_pod"
|
| 225 |
job._log(f"Creating RunPod with {job.gpu_type}...")
|
| 226 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
pod = await asyncio.to_thread(
|
| 228 |
runpod.create_pod,
|
| 229 |
f"lora-train-{job.id}",
|
| 230 |
DOCKER_IMAGE,
|
| 231 |
job.gpu_type,
|
| 232 |
-
|
| 233 |
-
container_disk_in_gb=30,
|
| 234 |
-
ports="22/tcp",
|
| 235 |
-
docker_args="bash -c 'apt-get update && apt-get install -y openssh-server && mkdir -p /run/sshd && echo \"root:runpod\" | chpasswd && echo \"PermitRootLogin yes\" >> /etc/ssh/sshd_config && /usr/sbin/sshd && sleep infinity'",
|
| 236 |
)
|
| 237 |
|
| 238 |
job.pod_id = pod["id"]
|
|
@@ -263,89 +297,273 @@ class RunPodTrainer:
|
|
| 263 |
await asyncio.sleep(5)
|
| 264 |
|
| 265 |
job._log("SSH connected")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
sftp = ssh.open_sftp()
|
|
|
|
| 267 |
|
| 268 |
-
# Step 3: Upload training images
|
| 269 |
job.status = "uploading"
|
| 270 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
|
| 272 |
folder_name = f"10_{trigger_word or 'character'}"
|
| 273 |
self._ssh_exec(ssh, f"mkdir -p /workspace/dataset/{folder_name}")
|
| 274 |
-
for img_path in image_paths:
|
| 275 |
p = Path(img_path)
|
| 276 |
if p.exists():
|
| 277 |
-
|
| 278 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
# Upload matching caption .txt file if it exists locally
|
| 280 |
local_caption = p.with_suffix(".txt")
|
| 281 |
if local_caption.exists():
|
| 282 |
-
remote_caption = f"/workspace/dataset/{folder_name}/{
|
| 283 |
sftp.put(str(local_caption), remote_caption)
|
| 284 |
else:
|
| 285 |
# Fallback: create caption from trigger word
|
| 286 |
-
remote_caption =
|
| 287 |
with sftp.open(remote_caption, "w") as f:
|
| 288 |
f.write(trigger_word or "")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
|
| 290 |
job._log("Images uploaded")
|
| 291 |
|
| 292 |
-
# Step 4: Install
|
| 293 |
job.status = "installing"
|
| 294 |
-
job._log("Installing Kohya sd-scripts (this takes a few minutes)...")
|
| 295 |
job.progress = 0.05
|
| 296 |
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
for cmd in install_cmds:
|
| 303 |
out = self._ssh_exec(ssh, cmd, timeout=600)
|
| 304 |
job._log(out[:200] if out else "done")
|
| 305 |
|
| 306 |
-
# Download base model from HuggingFace
|
| 307 |
hf_repo = model_cfg.get("hf_repo", "SG161222/Realistic_Vision_V5.1_noVAE")
|
| 308 |
hf_filename = model_cfg.get("hf_filename", "Realistic_Vision_V5.1_fp16-no-ema.safetensors")
|
| 309 |
model_name = model_cfg.get("name", job.base_model)
|
| 310 |
|
| 311 |
-
job._log(f"Downloading base model: {model_name}...")
|
| 312 |
job.progress = 0.1
|
| 313 |
-
|
| 314 |
self._ssh_exec(ssh, """pip install huggingface_hub 2>&1 | tail -1""", timeout=120)
|
| 315 |
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
from huggingface_hub import hf_hub_download
|
| 320 |
hf_hub_download('{hf_repo}', '{hf_filename}', local_dir='/workspace/models')
|
| 321 |
" 2>&1 | tail -5
|
| 322 |
-
|
| 323 |
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
|
|
|
|
|
|
|
|
|
| 331 |
from huggingface_hub import hf_hub_download
|
| 332 |
-
# CLIP text encoder
|
| 333 |
hf_hub_download('comfyanonymous/flux_text_encoders', 'clip_l.safetensors', local_dir='/workspace/models')
|
| 334 |
-
# T5 text encoder (fp16)
|
| 335 |
hf_hub_download('comfyanonymous/flux_text_encoders', 't5xxl_fp16.safetensors', local_dir='/workspace/models')
|
| 336 |
-
# VAE/AutoEncoder
|
| 337 |
hf_hub_download('black-forest-labs/FLUX.1-dev', 'ae.safetensors', local_dir='/workspace/models')
|
| 338 |
" 2>&1 | tail -5
|
| 339 |
-
|
| 340 |
|
| 341 |
-
job._log("Base model
|
| 342 |
job.progress = 0.15
|
| 343 |
|
| 344 |
# Step 5: Run training
|
| 345 |
job.status = "training"
|
| 346 |
job._log(f"Starting {model_type.upper()} LoRA training...")
|
| 347 |
|
| 348 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
|
| 350 |
# Build training command based on model type
|
| 351 |
train_cmd = self._build_training_command(
|
|
@@ -361,6 +579,7 @@ hf_hub_download('black-forest-labs/FLUX.1-dev', 'ae.safetensors', local_dir='/wo
|
|
| 361 |
optimizer=optimizer,
|
| 362 |
save_every_n_epochs=save_every_n_epochs,
|
| 363 |
model_cfg=model_cfg,
|
|
|
|
| 364 |
)
|
| 365 |
|
| 366 |
# Execute training and stream output
|
|
@@ -371,19 +590,37 @@ hf_hub_download('black-forest-labs/FLUX.1-dev', 'ae.safetensors', local_dir='/wo
|
|
| 371 |
|
| 372 |
# Read output progressively
|
| 373 |
buffer = ""
|
|
|
|
| 374 |
while not channel.exit_status_ready() or channel.recv_ready():
|
| 375 |
if channel.recv_ready():
|
| 376 |
chunk = channel.recv(4096).decode("utf-8", errors="replace")
|
| 377 |
buffer += chunk
|
| 378 |
-
# Process complete lines
|
| 379 |
-
while "\n" in buffer:
|
| 380 |
-
|
| 381 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 382 |
if not line:
|
| 383 |
continue
|
| 384 |
job._log(line)
|
| 385 |
self._parse_progress(job, line)
|
|
|
|
| 386 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 387 |
await asyncio.sleep(2)
|
| 388 |
|
| 389 |
exit_code = channel.recv_exit_status()
|
|
@@ -393,27 +630,33 @@ hf_hub_download('black-forest-labs/FLUX.1-dev', 'ae.safetensors', local_dir='/wo
|
|
| 393 |
job._log("Training completed on RunPod!")
|
| 394 |
job.progress = 0.9
|
| 395 |
|
| 396 |
-
# Step 6:
|
| 397 |
job.status = "downloading"
|
| 398 |
-
job._log("Downloading trained LoRA...")
|
| 399 |
-
|
| 400 |
-
LORA_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
| 401 |
-
local_path = LORA_OUTPUT_DIR / f"{name}.safetensors"
|
| 402 |
|
| 403 |
-
#
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
|
|
|
| 412 |
else:
|
| 413 |
raise RuntimeError("No .safetensors output found")
|
| 414 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 415 |
job.output_path = str(local_path)
|
| 416 |
-
job._log(f"LoRA saved to {local_path}")
|
| 417 |
|
| 418 |
# Done!
|
| 419 |
job.status = "completed"
|
|
@@ -449,6 +692,92 @@ hf_hub_download('black-forest-labs/FLUX.1-dev', 'ae.safetensors', local_dir='/wo
|
|
| 449 |
except Exception as e:
|
| 450 |
job._log(f"Warning: Failed to terminate pod {job.pod_id}: {e}")
|
| 451 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 452 |
def _build_training_command(
|
| 453 |
self,
|
| 454 |
*,
|
|
@@ -464,6 +793,7 @@ hf_hub_download('black-forest-labs/FLUX.1-dev', 'ae.safetensors', local_dir='/wo
|
|
| 464 |
optimizer: str,
|
| 465 |
save_every_n_epochs: int,
|
| 466 |
model_cfg: dict,
|
|
|
|
| 467 |
) -> str:
|
| 468 |
"""Build the training command based on model type."""
|
| 469 |
|
|
@@ -497,8 +827,70 @@ hf_hub_download('black-forest-labs/FLUX.1-dev', 'ae.safetensors', local_dir='/wo
|
|
| 497 |
lr_scheduler = model_cfg.get("lr_scheduler", "cosine_with_restarts")
|
| 498 |
base_args += f" \\\n --lr_scheduler={lr_scheduler}"
|
| 499 |
|
| 500 |
-
if model_type == "
|
| 501 |
-
# FLUX
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 502 |
script = "flux_train_network.py"
|
| 503 |
flux_args = f"""
|
| 504 |
--pretrained_model_name_or_path="{model_path}" \
|
|
@@ -539,24 +931,33 @@ hf_hub_download('black-forest-labs/FLUX.1-dev', 'ae.safetensors', local_dir='/wo
|
|
| 539 |
--clip_skip={clip_skip} \
|
| 540 |
--xformers {base_args} 2>&1"""
|
| 541 |
|
| 542 |
-
async def _wait_for_pod_ready(self, job: CloudTrainingJob, timeout: int =
|
| 543 |
"""Wait for pod to be running and return (ssh_host, ssh_port)."""
|
| 544 |
start = time.time()
|
| 545 |
while time.time() - start < timeout:
|
| 546 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
|
| 548 |
status = pod.get("desiredStatus", "")
|
| 549 |
runtime = pod.get("runtime")
|
| 550 |
|
| 551 |
if status == "RUNNING" and runtime:
|
| 552 |
-
ports = runtime.get("ports"
|
| 553 |
-
for port_info in ports:
|
| 554 |
if port_info.get("privatePort") == 22:
|
| 555 |
ip = port_info.get("ip")
|
| 556 |
public_port = port_info.get("publicPort")
|
| 557 |
if ip and public_port:
|
| 558 |
return ip, int(public_port)
|
| 559 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 560 |
await asyncio.sleep(5)
|
| 561 |
|
| 562 |
raise RuntimeError(f"Pod did not become ready within {timeout}s")
|
|
@@ -623,3 +1024,28 @@ hf_hub_download('black-forest-labs/FLUX.1-dev', 'ae.safetensors', local_dir='/wo
|
|
| 623 |
job.status = "failed"
|
| 624 |
job.error = "Cancelled by user"
|
| 625 |
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
import os
|
| 29 |
from content_engine.config import settings, IS_HF_SPACES
|
| 30 |
+
from content_engine.models.database import catalog_session_factory, TrainingJob as TrainingJobDB
|
| 31 |
|
| 32 |
LORA_OUTPUT_DIR = settings.paths.lora_dir
|
| 33 |
if IS_HF_SPACES:
|
|
|
|
| 37 |
|
| 38 |
# RunPod GPU options (id -> display name, approx cost/hr)
|
| 39 |
GPU_OPTIONS = {
|
| 40 |
+
# 24GB - SD 1.5, SDXL, FLUX.1 only (NOT enough for FLUX.2)
|
| 41 |
+
"NVIDIA GeForce RTX 3090": "RTX 3090 24GB (~$0.22/hr)",
|
| 42 |
+
"NVIDIA GeForce RTX 4090": "RTX 4090 24GB (~$0.44/hr)",
|
| 43 |
+
"NVIDIA GeForce RTX 5090": "RTX 5090 32GB (~$0.69/hr)",
|
| 44 |
+
"NVIDIA RTX A4000": "RTX A4000 16GB (~$0.20/hr)",
|
| 45 |
+
"NVIDIA RTX A5000": "RTX A5000 24GB (~$0.28/hr)",
|
| 46 |
+
# 48GB+ - Required for FLUX.2 Dev (Mistral text encoder needs ~48GB)
|
| 47 |
+
"NVIDIA RTX A6000": "RTX A6000 48GB (~$0.76/hr)",
|
| 48 |
+
"NVIDIA A40": "A40 48GB (~$0.64/hr)",
|
| 49 |
+
"NVIDIA L40": "L40 48GB (~$0.89/hr)",
|
| 50 |
+
"NVIDIA L40S": "L40S 48GB (~$1.09/hr)",
|
| 51 |
"NVIDIA A100 80GB PCIe": "A100 80GB (~$1.89/hr)",
|
| 52 |
+
"NVIDIA A100-SXM4-80GB": "A100 SXM 80GB (~$1.64/hr)",
|
| 53 |
+
"NVIDIA H100 80GB HBM3": "H100 80GB (~$3.89/hr)",
|
| 54 |
}
|
| 55 |
|
| 56 |
DEFAULT_GPU = "NVIDIA GeForce RTX 4090"
|
| 57 |
|
| 58 |
+
# Network volume for persistent model storage (avoids re-downloading models each run)
|
| 59 |
+
# Set RUNPOD_VOLUME_ID in .env to use a persistent volume
|
| 60 |
+
# Set RUNPOD_VOLUME_DC to the datacenter ID where the volume lives (e.g. "EU-RO-1")
|
| 61 |
+
NETWORK_VOLUME_ID = os.environ.get("RUNPOD_VOLUME_ID", "")
|
| 62 |
+
NETWORK_VOLUME_DC = os.environ.get("RUNPOD_VOLUME_DC", "")
|
| 63 |
+
|
| 64 |
# Docker image with PyTorch + CUDA pre-installed
|
| 65 |
DOCKER_IMAGE = "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04"
|
| 66 |
|
|
|
|
| 99 |
cost_estimate: str | None = None
|
| 100 |
base_model: str = "sd15_realistic"
|
| 101 |
model_type: str = "sd15"
|
| 102 |
+
_db_callback: Any = None # called on state changes to persist to DB
|
| 103 |
|
| 104 |
def _log(self, msg: str):
|
| 105 |
self.log_lines.append(msg)
|
| 106 |
if len(self.log_lines) > 200:
|
| 107 |
self.log_lines = self.log_lines[-200:]
|
| 108 |
logger.info("[%s] %s", self.id, msg)
|
| 109 |
+
if self._db_callback:
|
| 110 |
+
self._db_callback(self)
|
| 111 |
|
| 112 |
|
| 113 |
class RunPodTrainer:
|
|
|
|
| 118 |
runpod.api_key = api_key
|
| 119 |
self._jobs: dict[str, CloudTrainingJob] = {}
|
| 120 |
self._model_registry = load_model_registry()
|
| 121 |
+
self._loaded_from_db = False
|
| 122 |
|
| 123 |
@property
|
| 124 |
def available(self) -> bool:
|
|
|
|
| 142 |
"learning_rate": cfg.get("learning_rate", 1e-4),
|
| 143 |
"network_rank": cfg.get("network_rank", 32),
|
| 144 |
"network_alpha": cfg.get("network_alpha", 16),
|
| 145 |
+
"optimizer": cfg.get("optimizer", "AdamW8bit"),
|
| 146 |
+
"lr_scheduler": cfg.get("lr_scheduler", "cosine"),
|
| 147 |
"vram_required_gb": cfg.get("vram_required_gb", 8),
|
| 148 |
"recommended_images": cfg.get("recommended_images", "15-30 photos"),
|
| 149 |
}
|
|
|
|
| 200 |
model_type=model_type,
|
| 201 |
)
|
| 202 |
self._jobs[job_id] = job
|
| 203 |
+
job._db_callback = self._schedule_db_save
|
| 204 |
+
asyncio.ensure_future(self._save_to_db(job))
|
| 205 |
|
| 206 |
# Launch the full pipeline as a background task
|
| 207 |
asyncio.create_task(self._run_cloud_training(
|
|
|
|
| 247 |
job.status = "creating_pod"
|
| 248 |
job._log(f"Creating RunPod with {job.gpu_type}...")
|
| 249 |
|
| 250 |
+
# Use network volume if configured (persists models across runs)
|
| 251 |
+
pod_kwargs = {
|
| 252 |
+
"container_disk_in_gb": 30,
|
| 253 |
+
"ports": "22/tcp",
|
| 254 |
+
"docker_args": "bash -c 'apt-get update && apt-get install -y openssh-server && mkdir -p /run/sshd && echo root:runpod | chpasswd && /usr/sbin/sshd -o PermitRootLogin=yes && sleep infinity'",
|
| 255 |
+
}
|
| 256 |
+
if NETWORK_VOLUME_ID:
|
| 257 |
+
pod_kwargs["network_volume_id"] = NETWORK_VOLUME_ID
|
| 258 |
+
if NETWORK_VOLUME_DC:
|
| 259 |
+
pod_kwargs["data_center_id"] = NETWORK_VOLUME_DC
|
| 260 |
+
job._log(f"Using persistent network volume: {NETWORK_VOLUME_ID} (DC: {NETWORK_VOLUME_DC or 'auto'})")
|
| 261 |
+
else:
|
| 262 |
+
pod_kwargs["volume_in_gb"] = 75
|
| 263 |
+
|
| 264 |
pod = await asyncio.to_thread(
|
| 265 |
runpod.create_pod,
|
| 266 |
f"lora-train-{job.id}",
|
| 267 |
DOCKER_IMAGE,
|
| 268 |
job.gpu_type,
|
| 269 |
+
**pod_kwargs,
|
|
|
|
|
|
|
|
|
|
| 270 |
)
|
| 271 |
|
| 272 |
job.pod_id = pod["id"]
|
|
|
|
| 297 |
await asyncio.sleep(5)
|
| 298 |
|
| 299 |
job._log("SSH connected")
|
| 300 |
+
|
| 301 |
+
# If using network volume, symlink to /workspace so all paths work
|
| 302 |
+
if NETWORK_VOLUME_ID:
|
| 303 |
+
self._ssh_exec(ssh, "mkdir -p /runpod-volume/models && rm -rf /workspace/models 2>/dev/null; ln -sf /runpod-volume/models /workspace/models")
|
| 304 |
+
job._log("Network volume symlinked to /workspace")
|
| 305 |
+
|
| 306 |
+
# Enable keepalive to prevent SSH timeout during uploads
|
| 307 |
+
transport = ssh.get_transport()
|
| 308 |
+
transport.set_keepalive(30)
|
| 309 |
+
|
| 310 |
sftp = ssh.open_sftp()
|
| 311 |
+
sftp.get_channel().settimeout(300) # 5 min timeout per file
|
| 312 |
|
| 313 |
+
# Step 3: Upload training images (compress first to speed up transfer)
|
| 314 |
job.status = "uploading"
|
| 315 |
+
resolution = model_cfg.get("resolution", 1024)
|
| 316 |
+
job._log(f"Compressing and uploading {len(image_paths)} training images...")
|
| 317 |
+
|
| 318 |
+
import tempfile
|
| 319 |
+
from PIL import Image
|
| 320 |
+
tmp_dir = Path(tempfile.mkdtemp(prefix="lora_upload_"))
|
| 321 |
|
| 322 |
folder_name = f"10_{trigger_word or 'character'}"
|
| 323 |
self._ssh_exec(ssh, f"mkdir -p /workspace/dataset/{folder_name}")
|
| 324 |
+
for i, img_path in enumerate(image_paths):
|
| 325 |
p = Path(img_path)
|
| 326 |
if p.exists():
|
| 327 |
+
# Resize and convert to JPEG to reduce upload size
|
| 328 |
+
try:
|
| 329 |
+
img = Image.open(p)
|
| 330 |
+
img.thumbnail((resolution * 2, resolution * 2), Image.LANCZOS)
|
| 331 |
+
compressed = tmp_dir / f"{p.stem}.jpg"
|
| 332 |
+
img.save(compressed, "JPEG", quality=95)
|
| 333 |
+
upload_path = compressed
|
| 334 |
+
except Exception:
|
| 335 |
+
upload_path = p # fallback to original
|
| 336 |
+
|
| 337 |
+
remote_name = f"{p.stem}.jpg" if upload_path.suffix == ".jpg" else p.name
|
| 338 |
+
remote_path = f"/workspace/dataset/{folder_name}/{remote_name}"
|
| 339 |
+
for attempt in range(3):
|
| 340 |
+
try:
|
| 341 |
+
sftp.put(str(upload_path), remote_path)
|
| 342 |
+
break
|
| 343 |
+
except (EOFError, OSError):
|
| 344 |
+
if attempt == 2:
|
| 345 |
+
raise
|
| 346 |
+
job._log(f"Upload retry {attempt+1} for {p.name}")
|
| 347 |
+
sftp.close()
|
| 348 |
+
sftp = ssh.open_sftp()
|
| 349 |
+
sftp.get_channel().settimeout(300)
|
| 350 |
# Upload matching caption .txt file if it exists locally
|
| 351 |
local_caption = p.with_suffix(".txt")
|
| 352 |
if local_caption.exists():
|
| 353 |
+
remote_caption = f"/workspace/dataset/{folder_name}/{p.stem}.txt"
|
| 354 |
sftp.put(str(local_caption), remote_caption)
|
| 355 |
else:
|
| 356 |
# Fallback: create caption from trigger word
|
| 357 |
+
remote_caption = f"/workspace/dataset/{folder_name}/{p.stem}.txt"
|
| 358 |
with sftp.open(remote_caption, "w") as f:
|
| 359 |
f.write(trigger_word or "")
|
| 360 |
+
job._log(f"Uploaded {i+1}/{len(image_paths)}: {p.name}")
|
| 361 |
+
|
| 362 |
+
# Cleanup temp compressed images
|
| 363 |
+
import shutil
|
| 364 |
+
shutil.rmtree(tmp_dir, ignore_errors=True)
|
| 365 |
|
| 366 |
job._log("Images uploaded")
|
| 367 |
|
| 368 |
+
# Step 4: Install training framework on the pod (skip if cached on volume)
|
| 369 |
job.status = "installing"
|
|
|
|
| 370 |
job.progress = 0.05
|
| 371 |
|
| 372 |
+
training_framework = model_cfg.get("training_framework", "sd-scripts")
|
| 373 |
+
|
| 374 |
+
if training_framework == "musubi-tuner":
|
| 375 |
+
# FLUX.2 uses musubi-tuner (Kohya's newer framework)
|
| 376 |
+
tuner_dir = "/workspace/musubi-tuner"
|
| 377 |
+
install_cmds = []
|
| 378 |
+
|
| 379 |
+
# Check if already present in workspace
|
| 380 |
+
tuner_exist = self._ssh_exec(ssh, f"test -f {tuner_dir}/pyproject.toml && echo EXISTS || echo MISSING").strip()
|
| 381 |
+
if tuner_exist == "EXISTS":
|
| 382 |
+
job._log("musubi-tuner found in workspace")
|
| 383 |
+
else:
|
| 384 |
+
# Check volume cache
|
| 385 |
+
vol_exist = self._ssh_exec(ssh, "test -f /runpod-volume/musubi-tuner/pyproject.toml && echo EXISTS || echo MISSING").strip()
|
| 386 |
+
if vol_exist == "EXISTS":
|
| 387 |
+
job._log("Restoring musubi-tuner from volume cache...")
|
| 388 |
+
self._ssh_exec(ssh, f"rm -rf {tuner_dir} 2>/dev/null; cp -r /runpod-volume/musubi-tuner {tuner_dir}")
|
| 389 |
+
else:
|
| 390 |
+
job._log("Cloning musubi-tuner from GitHub...")
|
| 391 |
+
self._ssh_exec(ssh, f"rm -rf {tuner_dir} /runpod-volume/musubi-tuner 2>/dev/null; true")
|
| 392 |
+
install_cmds.append(f"cd /workspace && git clone --depth 1 https://github.com/kohya-ss/musubi-tuner.git")
|
| 393 |
+
# Save to volume for future pods
|
| 394 |
+
if NETWORK_VOLUME_ID:
|
| 395 |
+
install_cmds.append(f"cp -r {tuner_dir} /runpod-volume/musubi-tuner")
|
| 396 |
+
|
| 397 |
+
# Always install pip deps (they are pod-local, lost on every new pod)
|
| 398 |
+
job._log("Installing pip dependencies (accelerate, torch, etc.)...")
|
| 399 |
+
install_cmds.extend([
|
| 400 |
+
f"cd {tuner_dir} && pip install -e . 2>&1 | tail -5",
|
| 401 |
+
"pip install accelerate lion-pytorch prodigyopt safetensors bitsandbytes 2>&1 | tail -5",
|
| 402 |
+
])
|
| 403 |
+
else:
|
| 404 |
+
# SD 1.5 / SDXL / FLUX.1 use sd-scripts
|
| 405 |
+
scripts_exist = self._ssh_exec(ssh, "test -f /workspace/sd-scripts/setup.py && echo EXISTS || echo MISSING").strip()
|
| 406 |
+
if scripts_exist == "EXISTS":
|
| 407 |
+
job._log("Kohya sd-scripts already cached on volume, updating...")
|
| 408 |
+
install_cmds = [
|
| 409 |
+
"cd /workspace/sd-scripts && git pull 2>&1 | tail -1",
|
| 410 |
+
]
|
| 411 |
+
else:
|
| 412 |
+
job._log("Installing Kohya sd-scripts (this takes a few minutes)...")
|
| 413 |
+
install_cmds = [
|
| 414 |
+
"cd /workspace && git clone --depth 1 https://github.com/kohya-ss/sd-scripts.git",
|
| 415 |
+
]
|
| 416 |
+
# Always install pip deps (pod-local, lost on new pods)
|
| 417 |
+
install_cmds.extend([
|
| 418 |
+
"cd /workspace/sd-scripts && pip install -r requirements.txt 2>&1 | tail -1",
|
| 419 |
+
"pip install accelerate lion-pytorch prodigyopt safetensors bitsandbytes xformers 2>&1 | tail -1",
|
| 420 |
+
])
|
| 421 |
for cmd in install_cmds:
|
| 422 |
out = self._ssh_exec(ssh, cmd, timeout=600)
|
| 423 |
job._log(out[:200] if out else "done")
|
| 424 |
|
| 425 |
+
# Download base model from HuggingFace (skip if already on network volume)
|
| 426 |
hf_repo = model_cfg.get("hf_repo", "SG161222/Realistic_Vision_V5.1_noVAE")
|
| 427 |
hf_filename = model_cfg.get("hf_filename", "Realistic_Vision_V5.1_fp16-no-ema.safetensors")
|
| 428 |
model_name = model_cfg.get("name", job.base_model)
|
| 429 |
|
|
|
|
| 430 |
job.progress = 0.1
|
|
|
|
| 431 |
self._ssh_exec(ssh, """pip install huggingface_hub 2>&1 | tail -1""", timeout=120)
|
| 432 |
|
| 433 |
+
if model_type == "flux2":
|
| 434 |
+
# FLUX.2 models are stored in a directory structure on the volume
|
| 435 |
+
flux2_dir = "/workspace/models/FLUX.2-dev"
|
| 436 |
+
dit_path = f"{flux2_dir}/flux2-dev.safetensors"
|
| 437 |
+
vae_path = f"{flux2_dir}/ae.safetensors" # Original BFL format (not diffusers)
|
| 438 |
+
te_path = f"{flux2_dir}/text_encoder/model-00001-of-00010.safetensors"
|
| 439 |
+
|
| 440 |
+
dit_exists = self._ssh_exec(ssh, f"test -f {dit_path} && echo EXISTS || echo MISSING").strip()
|
| 441 |
+
vae_exists = self._ssh_exec(ssh, f"test -f {vae_path} && echo EXISTS || echo MISSING").strip()
|
| 442 |
+
te_exists = self._ssh_exec(ssh, f"test -f {te_path} && echo EXISTS || echo MISSING").strip()
|
| 443 |
+
|
| 444 |
+
if dit_exists != "EXISTS" or te_exists != "EXISTS":
|
| 445 |
+
missing = []
|
| 446 |
+
if dit_exists != "EXISTS":
|
| 447 |
+
missing.append("DiT")
|
| 448 |
+
if te_exists != "EXISTS":
|
| 449 |
+
missing.append("text encoder")
|
| 450 |
+
raise RuntimeError(f"FLUX.2 Dev missing on volume: {', '.join(missing)}. Please download models to the network volume first.")
|
| 451 |
+
|
| 452 |
+
# Download ae.safetensors (original format VAE) if not present
|
| 453 |
+
if vae_exists != "EXISTS":
|
| 454 |
+
job._log("Downloading FLUX.2 VAE (ae.safetensors, 336MB)...")
|
| 455 |
+
self._ssh_exec(ssh, """pip install huggingface_hub 2>&1 | tail -1""", timeout=120)
|
| 456 |
+
self._ssh_exec(ssh, f"""python -c "
|
| 457 |
+
from huggingface_hub import hf_hub_download
|
| 458 |
+
hf_hub_download('black-forest-labs/FLUX.2-dev', 'ae.safetensors', local_dir='{flux2_dir}')
|
| 459 |
+
print('Downloaded ae.safetensors')
|
| 460 |
+
" 2>&1 | tail -5""", timeout=600)
|
| 461 |
+
# Verify download
|
| 462 |
+
vae_check = self._ssh_exec(ssh, f"test -f {vae_path} && echo EXISTS || echo MISSING").strip()
|
| 463 |
+
if vae_check != "EXISTS":
|
| 464 |
+
raise RuntimeError("Failed to download ae.safetensors")
|
| 465 |
+
job._log("VAE downloaded")
|
| 466 |
+
|
| 467 |
+
job._log("FLUX.2 Dev models ready")
|
| 468 |
+
|
| 469 |
+
else:
|
| 470 |
+
# SD 1.5 / SDXL / FLUX.1 — download single model file
|
| 471 |
+
model_exists = self._ssh_exec(ssh, f"test -f /workspace/models/{hf_filename} && echo EXISTS || echo MISSING").strip()
|
| 472 |
+
if model_exists == "EXISTS":
|
| 473 |
+
job._log(f"Base model already cached on volume: {model_name}")
|
| 474 |
+
else:
|
| 475 |
+
job._log(f"Downloading base model: {model_name}...")
|
| 476 |
+
self._ssh_exec(ssh, f"""
|
| 477 |
+
python -c "
|
| 478 |
from huggingface_hub import hf_hub_download
|
| 479 |
hf_hub_download('{hf_repo}', '{hf_filename}', local_dir='/workspace/models')
|
| 480 |
" 2>&1 | tail -5
|
| 481 |
+
""", timeout=1200)
|
| 482 |
|
| 483 |
+
# For FLUX.1, download additional required models (CLIP, T5, VAE)
|
| 484 |
+
if model_type == "flux":
|
| 485 |
+
flux_files_check = self._ssh_exec(ssh, "test -f /workspace/models/clip_l.safetensors && test -f /workspace/models/t5xxl_fp16.safetensors && test -f /workspace/models/ae.safetensors && echo EXISTS || echo MISSING").strip()
|
| 486 |
+
if flux_files_check == "EXISTS":
|
| 487 |
+
job._log("FLUX.1 auxiliary models already cached on volume")
|
| 488 |
+
else:
|
| 489 |
+
job._log("Downloading FLUX.1 auxiliary models (CLIP, T5, VAE)...")
|
| 490 |
+
job.progress = 0.12
|
| 491 |
+
self._ssh_exec(ssh, """
|
| 492 |
+
python -c "
|
| 493 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 494 |
hf_hub_download('comfyanonymous/flux_text_encoders', 'clip_l.safetensors', local_dir='/workspace/models')
|
|
|
|
| 495 |
hf_hub_download('comfyanonymous/flux_text_encoders', 't5xxl_fp16.safetensors', local_dir='/workspace/models')
|
|
|
|
| 496 |
hf_hub_download('black-forest-labs/FLUX.1-dev', 'ae.safetensors', local_dir='/workspace/models')
|
| 497 |
" 2>&1 | tail -5
|
| 498 |
+
""", timeout=1200)
|
| 499 |
|
| 500 |
+
job._log("Base model ready")
|
| 501 |
job.progress = 0.15
|
| 502 |
|
| 503 |
# Step 5: Run training
|
| 504 |
job.status = "training"
|
| 505 |
job._log(f"Starting {model_type.upper()} LoRA training...")
|
| 506 |
|
| 507 |
+
if model_type == "flux2":
|
| 508 |
+
model_path = f"/workspace/models/FLUX.2-dev/flux2-dev.safetensors"
|
| 509 |
+
else:
|
| 510 |
+
model_path = f"/workspace/models/{hf_filename}"
|
| 511 |
+
|
| 512 |
+
# For musubi-tuner, create TOML dataset config
|
| 513 |
+
if training_framework == "musubi-tuner":
|
| 514 |
+
folder_name = f"10_{trigger_word or 'character'}"
|
| 515 |
+
toml_content = f"""[[datasets]]
|
| 516 |
+
image_directory = "/workspace/dataset/{folder_name}"
|
| 517 |
+
caption_extension = ".txt"
|
| 518 |
+
batch_size = 1
|
| 519 |
+
num_repeats = 10
|
| 520 |
+
resolution = [{resolution}, {resolution}]
|
| 521 |
+
"""
|
| 522 |
+
self._ssh_exec(ssh, f"cat > /workspace/dataset.toml << 'TOMLEOF'\n{toml_content}TOMLEOF")
|
| 523 |
+
job._log("Created dataset.toml config")
|
| 524 |
+
|
| 525 |
+
# musubi-tuner requires pre-caching latents and text encoder outputs
|
| 526 |
+
flux2_dir = "/workspace/models/FLUX.2-dev"
|
| 527 |
+
vae_path = f"{flux2_dir}/ae.safetensors"
|
| 528 |
+
te_path = f"{flux2_dir}/text_encoder/model-00001-of-00010.safetensors"
|
| 529 |
+
|
| 530 |
+
job._log("Caching latents (VAE encoding)...")
|
| 531 |
+
job.progress = 0.15
|
| 532 |
+
self._schedule_db_save(job)
|
| 533 |
+
cache_latents_cmd = (
|
| 534 |
+
f"cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python src/musubi_tuner/flux_2_cache_latents.py"
|
| 535 |
+
f" --dataset_config /workspace/dataset.toml"
|
| 536 |
+
f" --vae {vae_path}"
|
| 537 |
+
f" --model_version dev"
|
| 538 |
+
f" --vae_dtype bfloat16"
|
| 539 |
+
f" 2>&1 | tee /tmp/cache_latents.log; echo EXIT_CODE=${{PIPESTATUS[0]}}"
|
| 540 |
+
)
|
| 541 |
+
out = self._ssh_exec(ssh, cache_latents_cmd, timeout=600)
|
| 542 |
+
# Get last lines which have the real error
|
| 543 |
+
last_lines = out.split('\n')[-30:]
|
| 544 |
+
job._log('\n'.join(last_lines))
|
| 545 |
+
if "EXIT_CODE=0" not in out:
|
| 546 |
+
# Fetch the full error log
|
| 547 |
+
err_log = self._ssh_exec(ssh, "grep -i 'error\\|exception\\|traceback\\|failed' /tmp/cache_latents.log | tail -10")
|
| 548 |
+
job._log(f"Cache error details: {err_log}")
|
| 549 |
+
raise RuntimeError(f"Latent caching failed")
|
| 550 |
+
|
| 551 |
+
job._log("Caching text encoder outputs (bf16)...")
|
| 552 |
+
job.progress = 0.25
|
| 553 |
+
self._schedule_db_save(job)
|
| 554 |
+
cache_te_cmd = (
|
| 555 |
+
f"cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
|
| 556 |
+
f" python src/musubi_tuner/flux_2_cache_text_encoder_outputs.py"
|
| 557 |
+
f" --dataset_config /workspace/dataset.toml"
|
| 558 |
+
f" --text_encoder {te_path}"
|
| 559 |
+
f" --model_version dev"
|
| 560 |
+
f" --batch_size 1"
|
| 561 |
+
f" 2>&1; echo EXIT_CODE=$?"
|
| 562 |
+
)
|
| 563 |
+
out = self._ssh_exec(ssh, cache_te_cmd, timeout=600)
|
| 564 |
+
job._log(out[-500:] if out else "done")
|
| 565 |
+
if "EXIT_CODE=0" not in out:
|
| 566 |
+
raise RuntimeError(f"Text encoder caching failed: {out[-200:]}")
|
| 567 |
|
| 568 |
# Build training command based on model type
|
| 569 |
train_cmd = self._build_training_command(
|
|
|
|
| 579 |
optimizer=optimizer,
|
| 580 |
save_every_n_epochs=save_every_n_epochs,
|
| 581 |
model_cfg=model_cfg,
|
| 582 |
+
gpu_type=job.gpu_type,
|
| 583 |
)
|
| 584 |
|
| 585 |
# Execute training and stream output
|
|
|
|
| 590 |
|
| 591 |
# Read output progressively
|
| 592 |
buffer = ""
|
| 593 |
+
last_flush = time.time()
|
| 594 |
while not channel.exit_status_ready() or channel.recv_ready():
|
| 595 |
if channel.recv_ready():
|
| 596 |
chunk = channel.recv(4096).decode("utf-8", errors="replace")
|
| 597 |
buffer += chunk
|
| 598 |
+
# Process complete lines (handle both \n and \r for tqdm progress)
|
| 599 |
+
while "\n" in buffer or "\r" in buffer:
|
| 600 |
+
# Split on whichever comes first
|
| 601 |
+
n_pos = buffer.find("\n")
|
| 602 |
+
r_pos = buffer.find("\r")
|
| 603 |
+
if n_pos == -1:
|
| 604 |
+
split_pos = r_pos
|
| 605 |
+
elif r_pos == -1:
|
| 606 |
+
split_pos = n_pos
|
| 607 |
+
else:
|
| 608 |
+
split_pos = min(n_pos, r_pos)
|
| 609 |
+
line = buffer[:split_pos].strip()
|
| 610 |
+
buffer = buffer[split_pos + 1:]
|
| 611 |
if not line:
|
| 612 |
continue
|
| 613 |
job._log(line)
|
| 614 |
self._parse_progress(job, line)
|
| 615 |
+
self._schedule_db_save(job)
|
| 616 |
else:
|
| 617 |
+
# Periodically flush buffer for partial tqdm lines
|
| 618 |
+
if buffer.strip() and time.time() - last_flush > 10:
|
| 619 |
+
job._log(buffer.strip())
|
| 620 |
+
self._parse_progress(job, buffer.strip())
|
| 621 |
+
buffer = ""
|
| 622 |
+
last_flush = time.time()
|
| 623 |
+
self._schedule_db_save(job)
|
| 624 |
await asyncio.sleep(2)
|
| 625 |
|
| 626 |
exit_code = channel.recv_exit_status()
|
|
|
|
| 630 |
job._log("Training completed on RunPod!")
|
| 631 |
job.progress = 0.9
|
| 632 |
|
| 633 |
+
# Step 6: Save LoRA to network volume and download locally
|
| 634 |
job.status = "downloading"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 635 |
|
| 636 |
+
# First, copy to network volume for persistence
|
| 637 |
+
job._log("Saving LoRA to network volume...")
|
| 638 |
+
self._ssh_exec(ssh, "mkdir -p /runpod-volume/loras")
|
| 639 |
+
remote_output = f"/workspace/output/{name}.safetensors"
|
| 640 |
+
# Find the output file
|
| 641 |
+
check = self._ssh_exec(ssh, f"test -f {remote_output} && echo EXISTS || echo MISSING").strip()
|
| 642 |
+
if check == "MISSING":
|
| 643 |
+
remote_files = self._ssh_exec(ssh, "ls /workspace/output/*.safetensors 2>/dev/null").strip()
|
| 644 |
+
if remote_files:
|
| 645 |
+
remote_output = remote_files.split("\n")[-1].strip()
|
| 646 |
else:
|
| 647 |
raise RuntimeError("No .safetensors output found")
|
| 648 |
|
| 649 |
+
self._ssh_exec(ssh, f"cp {remote_output} /runpod-volume/loras/{name}.safetensors")
|
| 650 |
+
job._log(f"LoRA saved to volume: /runpod-volume/loras/{name}.safetensors")
|
| 651 |
+
|
| 652 |
+
# Then download locally
|
| 653 |
+
job._log("Downloading LoRA to local machine...")
|
| 654 |
+
LORA_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
| 655 |
+
local_path = LORA_OUTPUT_DIR / f"{name}.safetensors"
|
| 656 |
+
sftp.get(remote_output, str(local_path))
|
| 657 |
+
|
| 658 |
job.output_path = str(local_path)
|
| 659 |
+
job._log(f"LoRA saved locally to {local_path}")
|
| 660 |
|
| 661 |
# Done!
|
| 662 |
job.status = "completed"
|
|
|
|
| 692 |
except Exception as e:
|
| 693 |
job._log(f"Warning: Failed to terminate pod {job.pod_id}: {e}")
|
| 694 |
|
| 695 |
+
def _schedule_db_save(self, job: CloudTrainingJob):
|
| 696 |
+
"""Schedule a DB save (non-blocking)."""
|
| 697 |
+
try:
|
| 698 |
+
asyncio.get_event_loop().create_task(self._save_to_db(job))
|
| 699 |
+
except RuntimeError:
|
| 700 |
+
pass # no event loop
|
| 701 |
+
|
| 702 |
+
async def _save_to_db(self, job: CloudTrainingJob):
    """Persist job state to database.

    Upserts a full snapshot of the job (status, progress, log tail,
    pod/GPU info) into ``training_jobs`` so running jobs survive server
    restarts.  Best-effort: any failure is logged, never raised, so
    persistence problems cannot break an in-flight training run.
    """
    try:
        from sqlalchemy import text
        async with catalog_session_factory() as session:
            # Use raw INSERT OR REPLACE for SQLite upsert
            # (SQLAlchemy's ORM has no portable upsert for this schema).
            # The COALESCE subquery keeps the original created_at on
            # replace instead of resetting it to "now".
            await session.execute(
                text("""INSERT OR REPLACE INTO training_jobs
                        (id, name, status, progress, current_epoch, total_epochs,
                        current_step, total_steps, loss, started_at, completed_at,
                        output_path, error, log_text, pod_id, gpu_type, backend,
                        base_model, model_type, created_at)
                        VALUES (:id, :name, :status, :progress, :current_epoch, :total_epochs,
                        :current_step, :total_steps, :loss, :started_at, :completed_at,
                        :output_path, :error, :log_text, :pod_id, :gpu_type, :backend,
                        :base_model, :model_type, COALESCE((SELECT created_at FROM training_jobs WHERE id = :id), CURRENT_TIMESTAMP))
                """),
                {
                    "id": job.id, "name": job.name, "status": job.status,
                    "progress": job.progress, "current_epoch": job.current_epoch,
                    "total_epochs": job.total_epochs, "current_step": job.current_step,
                    "total_steps": job.total_steps, "loss": job.loss,
                    "started_at": job.started_at, "completed_at": job.completed_at,
                    "output_path": job.output_path, "error": job.error,
                    # Cap the persisted log at the last 200 lines to bound row size.
                    "log_text": "\n".join(job.log_lines[-200:]),
                    "pod_id": job.pod_id, "gpu_type": job.gpu_type,
                    "backend": "runpod", "base_model": job.base_model,
                    "model_type": job.model_type,
                }
            )
            await session.commit()
    except Exception as e:
        logger.warning("Failed to save training job to DB: %s", e)
|
| 735 |
+
|
| 736 |
+
async def _load_jobs_from_db(self):
    """Rehydrate recent training jobs from the database on startup.

    Loads the 20 most recent rows, skips any id already present in
    memory, and marks jobs that were still in-flight when the server
    died as failed.  Best-effort: failures are logged, never raised.
    """
    try:
        from sqlalchemy import select
        async with catalog_session_factory() as session:
            query = (
                select(TrainingJobDB)
                .order_by(TrainingJobDB.created_at.desc())
                .limit(20)
            )
            rows = (await session.execute(query)).scalars().all()
            for row in rows:
                if row.id in self._jobs:
                    continue  # in-memory state wins over the snapshot
                restored = CloudTrainingJob(
                    id=row.id,
                    name=row.name,
                    status=row.status,
                    progress=row.progress or 0.0,
                    current_epoch=row.current_epoch or 0,
                    total_epochs=row.total_epochs or 0,
                    current_step=row.current_step or 0,
                    total_steps=row.total_steps or 0,
                    loss=row.loss,
                    started_at=row.started_at,
                    completed_at=row.completed_at,
                    output_path=row.output_path,
                    error=row.error,
                    log_lines=row.log_text.split("\n") if row.log_text else [],
                    pod_id=row.pod_id,
                    gpu_type=row.gpu_type or DEFAULT_GPU,
                    base_model=row.base_model or "sd15_realistic",
                    model_type=row.model_type or "sd15",
                )
                # Anything not terminal was killed by the restart.
                if restored.status not in ("completed", "failed"):
                    restored.status = "failed"
                    restored.error = "Interrupted by server restart"
                self._jobs[row.id] = restored
    except Exception as e:
        logger.warning("Failed to load training jobs from DB: %s", e)
|
| 774 |
+
|
| 775 |
+
async def ensure_loaded(self):
    """Load persisted jobs from the database exactly once, lazily.

    Subsequent calls are no-ops; the flag is set *before* the load so a
    failed load is not retried on every access.
    """
    if self._loaded_from_db:
        return
    self._loaded_from_db = True
    await self._load_jobs_from_db()
|
| 780 |
+
|
| 781 |
def _build_training_command(
|
| 782 |
self,
|
| 783 |
*,
|
|
|
|
| 793 |
optimizer: str,
|
| 794 |
save_every_n_epochs: int,
|
| 795 |
model_cfg: dict,
|
| 796 |
+
gpu_type: str = "",
|
| 797 |
) -> str:
|
| 798 |
"""Build the training command based on model type."""
|
| 799 |
|
|
|
|
| 827 |
lr_scheduler = model_cfg.get("lr_scheduler", "cosine_with_restarts")
|
| 828 |
base_args += f" \\\n --lr_scheduler={lr_scheduler}"
|
| 829 |
|
| 830 |
+
if model_type == "flux2":
|
| 831 |
+
# FLUX.2 training via musubi-tuner
|
| 832 |
+
flux2_dir = "/workspace/models/FLUX.2-dev"
|
| 833 |
+
dit_path = f"{flux2_dir}/flux2-dev.safetensors"
|
| 834 |
+
vae_path = f"{flux2_dir}/ae.safetensors"
|
| 835 |
+
te_path = f"{flux2_dir}/text_encoder/model-00001-of-00010.safetensors"
|
| 836 |
+
|
| 837 |
+
network_mod = model_cfg.get("network_module", "networks.lora_flux_2")
|
| 838 |
+
ts_sampling = model_cfg.get("timestep_sampling", "flux2_shift")
|
| 839 |
+
lr_scheduler = model_cfg.get("lr_scheduler", "cosine")
|
| 840 |
+
|
| 841 |
+
# Build as list of args to avoid shell escaping issues
|
| 842 |
+
args = [
|
| 843 |
+
"cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True",
|
| 844 |
+
"accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16",
|
| 845 |
+
"src/musubi_tuner/flux_2_train_network.py",
|
| 846 |
+
"--model_version dev",
|
| 847 |
+
f"--dit {dit_path}",
|
| 848 |
+
f"--vae {vae_path}",
|
| 849 |
+
f"--text_encoder {te_path}",
|
| 850 |
+
"--dataset_config /workspace/dataset.toml",
|
| 851 |
+
"--sdpa --mixed_precision bf16",
|
| 852 |
+
f"--timestep_sampling {ts_sampling} --weighting_scheme none",
|
| 853 |
+
f"--network_module {network_mod}",
|
| 854 |
+
f"--network_dim={network_rank}",
|
| 855 |
+
f"--network_alpha={network_alpha}",
|
| 856 |
+
"--gradient_checkpointing",
|
| 857 |
+
]
|
| 858 |
+
|
| 859 |
+
# Only use fp8_base on GPUs with native fp8 support (RTX 4090, H100)
|
| 860 |
+
# A100 and A6000 don't support fp8 tensor ops, and have enough VRAM without it
|
| 861 |
+
if gpu_type and ("4090" in gpu_type or "5090" in gpu_type or "L40S" in gpu_type or "H100" in gpu_type):
|
| 862 |
+
args.append("--fp8_base")
|
| 863 |
+
|
| 864 |
+
# Handle Prodigy optimizer (needs special class path and args)
|
| 865 |
+
if optimizer.lower() == "prodigy":
|
| 866 |
+
args.extend([
|
| 867 |
+
"--optimizer_type=prodigyopt.Prodigy",
|
| 868 |
+
f"--learning_rate={learning_rate}",
|
| 869 |
+
'--optimizer_args "weight_decay=0.01" "decouple=True" "use_bias_correction=True" "safeguard_warmup=True" "d_coef=2"',
|
| 870 |
+
])
|
| 871 |
+
else:
|
| 872 |
+
args.extend([
|
| 873 |
+
f"--optimizer_type={optimizer}",
|
| 874 |
+
f"--learning_rate={learning_rate}",
|
| 875 |
+
])
|
| 876 |
+
|
| 877 |
+
args.extend([
|
| 878 |
+
f"--save_every_n_epochs={save_every_n_epochs}",
|
| 879 |
+
"--seed=42",
|
| 880 |
+
'--output_dir=/workspace/output',
|
| 881 |
+
f'--output_name={name}',
|
| 882 |
+
f"--lr_scheduler={lr_scheduler}",
|
| 883 |
+
])
|
| 884 |
+
|
| 885 |
+
if max_train_steps:
|
| 886 |
+
args.append(f"--max_train_steps={max_train_steps}")
|
| 887 |
+
else:
|
| 888 |
+
args.append(f"--max_train_epochs={num_epochs}")
|
| 889 |
+
|
| 890 |
+
return " ".join(args) + " 2>&1"
|
| 891 |
+
|
| 892 |
+
elif model_type == "flux":
|
| 893 |
+
# FLUX.1 training via sd-scripts
|
| 894 |
script = "flux_train_network.py"
|
| 895 |
flux_args = f"""
|
| 896 |
--pretrained_model_name_or_path="{model_path}" \
|
|
|
|
| 931 |
--clip_skip={clip_skip} \
|
| 932 |
--xformers {base_args} 2>&1"""
|
| 933 |
|
| 934 |
+
async def _wait_for_pod_ready(self, job: CloudTrainingJob, timeout: int = 600) -> tuple[str, int]:
    """Wait for pod to be running and return (ssh_host, ssh_port).

    Polls the RunPod API every 5s (backing off to 10s after an API
    error) until the pod reports RUNNING with a public mapping for SSH
    port 22, or *timeout* seconds elapse.

    Raises:
        RuntimeError: if the pod is not ready within *timeout* seconds.
    """
    start = time.time()
    while time.time() - start < timeout:
        try:
            pod = await asyncio.to_thread(runpod.get_pod, job.pod_id)
        except Exception as e:
            job._log(f" API error: {e}")
            await asyncio.sleep(10)
            continue

        # get_pod can return None/empty while the pod is still
        # provisioning — guard before .get() to avoid AttributeError.
        pod = pod or {}
        status = pod.get("desiredStatus", "")
        runtime = pod.get("runtime")

        if status == "RUNNING" and runtime:
            for port_info in runtime.get("ports") or []:
                if port_info.get("privatePort") == 22:
                    ip = port_info.get("ip")
                    public_port = port_info.get("publicPort")
                    if ip and public_port:
                        return ip, int(public_port)

        # Emit a progress line roughly once per 30s window
        # (poll interval is 5s, so one tick lands in [0, 6)).
        elapsed = int(time.time() - start)
        if elapsed % 30 < 6:
            job._log(f" Status: {status} | runtime: {'ports pending' if runtime else 'not ready yet'} ({elapsed}s)")

        await asyncio.sleep(5)

    raise RuntimeError(f"Pod did not become ready within {timeout}s")
|
|
|
|
| 1024 |
job.status = "failed"
|
| 1025 |
job.error = "Cancelled by user"
|
| 1026 |
return True
|
| 1027 |
+
|
| 1028 |
+
async def delete_job(self, job_id: str) -> bool:
    """Delete a training job from memory and database.

    Returns:
        True if the job existed in memory and was removed, else False.
        The DB delete is best-effort — a DB failure is logged, not
        raised, and the in-memory removal still stands.
    """
    if job_id not in self._jobs:
        return False
    del self._jobs[job_id]
    try:
        # Plain local import instead of the opaque
        # __import__('sqlalchemy').select(...) hack.
        from sqlalchemy import select
        async with catalog_session_factory() as session:
            result = await session.execute(
                select(TrainingJobDB).where(TrainingJobDB.id == job_id)
            )
            db_job = result.scalar_one_or_none()
            if db_job:
                await session.delete(db_job)
                await session.commit()
    except Exception as e:
        logger.warning("Failed to delete job from DB: %s", e)
    return True
|
| 1045 |
+
|
| 1046 |
+
async def delete_failed_jobs(self) -> int:
    """Remove every job whose status is 'failed' or 'error'.

    Returns:
        The number of jobs deleted.
    """
    # Snapshot the ids first — delete_job mutates self._jobs.
    doomed = [
        job_id
        for job_id, job in self._jobs.items()
        if job.status in {"failed", "error"}
    ]
    for job_id in doomed:
        await self.delete_job(job_id)
    return len(doomed)
|
src/content_engine/workers/__pycache__/__init__.cpython-311.pyc
DELETED
|
Binary file (239 Bytes)
|
|
|
src/content_engine/workers/__pycache__/local_worker.cpython-311.pyc
DELETED
|
Binary file (6.09 kB)
|
|
|