dippoo Claude Opus 4.6 committed
Commit 27fea48 · 1 Parent(s): b051fff

Add pod generation with FLUX.2, persistent state, training improvements


- Pod management: network volume support, HTTPS proxy URLs, fire-and-forget SSH
- FLUX.2 workflow: separate UNETLoader + CLIPLoader + VAELoader with Comfy-Org text encoder
- Persist pod state to disk (survives server restarts)
- Training: persistent job tracking in DB, live log streaming
- Remove tracked __pycache__ files, add .gitignore

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (32)
  1. .gitignore +3 -0
  2. CLAUDE.md +219 -0
  3. config/models.yaml +20 -14
  4. src/content_engine/__pycache__/__init__.cpython-311.pyc +0 -0
  5. src/content_engine/__pycache__/config.cpython-311.pyc +0 -0
  6. src/content_engine/__pycache__/main.cpython-311.pyc +0 -0
  7. src/content_engine/api/__pycache__/__init__.cpython-311.pyc +0 -0
  8. src/content_engine/api/__pycache__/routes_catalog.cpython-311.pyc +0 -0
  9. src/content_engine/api/__pycache__/routes_generation.cpython-311.pyc +0 -0
  10. src/content_engine/api/__pycache__/routes_pod.cpython-311.pyc +0 -0
  11. src/content_engine/api/__pycache__/routes_system.cpython-311.pyc +0 -0
  12. src/content_engine/api/__pycache__/routes_training.cpython-311.pyc +0 -0
  13. src/content_engine/api/__pycache__/routes_ui.cpython-311.pyc +0 -0
  14. src/content_engine/api/__pycache__/routes_video.cpython-311.pyc +0 -0
  15. src/content_engine/api/routes_pod.py +441 -112
  16. src/content_engine/api/routes_training.py +51 -21
  17. src/content_engine/api/ui.html +156 -20
  18. src/content_engine/models/__pycache__/__init__.cpython-311.pyc +0 -0
  19. src/content_engine/models/__pycache__/database.cpython-311.pyc +0 -0
  20. src/content_engine/models/__pycache__/schemas.cpython-311.pyc +0 -0
  21. src/content_engine/models/database.py +29 -7
  22. src/content_engine/services/__pycache__/__init__.cpython-311.pyc +0 -0
  23. src/content_engine/services/__pycache__/catalog.cpython-311.pyc +0 -0
  24. src/content_engine/services/__pycache__/comfyui_client.cpython-311.pyc +0 -0
  25. src/content_engine/services/__pycache__/lora_trainer.cpython-311.pyc +0 -0
  26. src/content_engine/services/__pycache__/runpod_trainer.cpython-311.pyc +0 -0
  27. src/content_engine/services/__pycache__/template_engine.cpython-311.pyc +0 -0
  28. src/content_engine/services/__pycache__/variation_engine.cpython-311.pyc +0 -0
  29. src/content_engine/services/__pycache__/workflow_builder.cpython-311.pyc +0 -0
  30. src/content_engine/services/runpod_trainer.py +494 -68
  31. src/content_engine/workers/__pycache__/__init__.cpython-311.pyc +0 -0
  32. src/content_engine/workers/__pycache__/local_worker.cpython-311.pyc +0 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+ __pycache__/
+ *.pyc
+ pod_state.json
CLAUDE.md ADDED
@@ -0,0 +1,219 @@
# Content Engine

Automated AI content generation system with cloud APIs, LoRA training, and multi-backend support.

## Repositories

- **Local Development**: `D:\AI automation\content_engine\`
- **HuggingFace Deployment**: `D:\AI automation\content-engine\` (deployed to https://huggingface.co/spaces/dippoo/content-engine)

Always sync changes between both directories when modifying code.

## Architecture

```
┌─────────────────────────────────────────────────────────────┐
│ Frontend (ui.html)                                          │
│ Generate | Batch | Gallery | Train LoRA | Status | Settings │
└─────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌─────────────────────────────────────────────────────────────┐
│ FastAPI Backend (main.py)                                   │
│ routes_generation | routes_video | routes_training | etc    │
└─────────────────────────────────────────────────────────────┘
          ┌─────────────────────┼─────────────────────┐
          ▼                     ▼                     ▼
  ┌───────────────┐     ┌───────────────┐     ┌───────────────┐
  │   Local GPU   │     │    RunPod     │     │  Cloud APIs   │
  │   (ComfyUI)   │     │ (Serverless)  │     │  (WaveSpeed)  │
  └───────────────┘     └───────────────┘     └───────────────┘
```

## Cloud Providers

### WaveSpeed (wavespeed_provider.py)
Primary cloud API for image/video generation. Uses direct HTTP API (SDK optional).

**Text-to-Image Models:**
- `seedream-4.5` - Best quality, NSFW OK (ByteDance)
- `seedream-4`, `seedream-3.1` - NSFW friendly
- `gpt-image-1.5`, `gpt-image-1-mini` - OpenAI models
- `nano-banana-pro`, `nano-banana` - Google models
- `wan-2.6`, `wan-2.5` - Alibaba models
- `kling-image-o3` - Kuaishou

**Image-to-Image (Edit) Models:**
- `seedream-4.5-edit` - Best for face preservation
- `seedream-4.5-multi`, `seedream-4-multi` - Multi-reference (up to 3 images)
- `kling-o1-multi` - Multi-reference (up to 10 images)
- `wan-2.6-edit`, `wan-2.5-edit` - NSFW friendly

**Image-to-Video Models:**
- `wan-2.6-i2v-pro` - Best quality ($0.05/s)
- `wan-2.6-i2v-flash` - Fast
- `kling-o3-pro`, `kling-o3` - Kuaishou
- `higgsfield-dop` - Cinematic 5s clips
- `veo-3.1`, `sora-2` - Premium models

**API Pattern:**
```python
# WaveSpeed returns async jobs - must poll for result
response = {"data": {"outputs": [], "urls": {"get": "poll_url"}}}
# Poll urls.get until outputs[] is populated
```

### RunPod
- **Training**: Cloud GPU for LoRA training (runpod_trainer.py)
- **Generation**: Serverless endpoint for inference (runpod_provider.py)

## Character System

Characters link a trained LoRA to generation:

**Config file** (`config/characters/alice.yaml`):
```yaml
id: alice
name: "Alice"
trigger_word: "alicechar"  # Activates the LoRA
lora_filename: "alice_v1.safetensors"  # In D:\ComfyUI\Models\Lora\
lora_strength: 0.85
```

**Generation flow:**
1. User selects character from dropdown
2. System prepends trigger word: `"alicechar, a woman in red dress"`
3. LoRA is loaded into workflow (local/RunPod only)
4. Character identity is preserved in output

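Step 2 amounts to a one-line prompt transform; a minimal sketch (the helper name is illustrative, not the app's actual function):

```python
def apply_trigger_word(trigger_word: str, user_prompt: str) -> str:
    """Prepend the character's trigger word so the LoRA activates (step 2 above)."""
    if not trigger_word:
        return user_prompt
    return f"{trigger_word}, {user_prompt}"
```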
**For cloud-only (no local GPU):**
- Use img2img with reference photo
- Or deploy LoRA to RunPod serverless endpoint

## Templates

Prompt recipes with variables (`config/templates/*.yaml`):
```yaml
id: portrait_glamour
name: "Glamour Portrait"
positive: "{{character}}, {{pose}}, {{lighting}}, professional photo"
variables:
  - name: pose
    options: ["standing", "sitting", "leaning"]
  - name: lighting
    options: ["studio", "natural", "dramatic"]
```

## Key Files

### API Routes
- `routes_generation.py` - txt2img, img2img endpoints
- `routes_video.py` - img2video, WaveSpeed/Higgsfield video
- `routes_training.py` - LoRA training jobs
- `routes_catalog.py` - Gallery/image management
- `routes_system.py` - Health checks, character list

### Services
- `wavespeed_provider.py` - WaveSpeed API client (SDK optional, uses httpx)
- `runpod_trainer.py` - Cloud LoRA training
- `runpod_provider.py` - Cloud generation endpoint
- `comfyui_client.py` - Local ComfyUI integration
- `workflow_builder.py` - ComfyUI workflow JSON builder
- `template_engine.py` - Prompt template rendering
- `variation_engine.py` - Batch variation generation

### Frontend
- `ui.html` - Single-page app with all UI

## Environment Variables

```env
# Cloud APIs
WAVESPEED_API_KEY=ws_xxx           # WaveSpeed.ai API key
RUNPOD_API_KEY=xxx                 # RunPod API key
RUNPOD_ENDPOINT_ID=xxx             # RunPod serverless endpoint (for generation)

# Optional
HIGGSFIELD_API_KEY=xxx             # Higgsfield (Kling 3.0, etc.)
COMFYUI_URL=http://127.0.0.1:8188  # Local ComfyUI
```

## Database

SQLite with async (aiosqlite):
- `images` - Generated image catalog
- `characters` - Character profiles
- `generation_jobs` - Job tracking
- `scheduled_posts` - Publishing queue
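A minimal sketch of the `images` table; the column names and the sample path here are illustrative assumptions, not the app's actual schema. The app uses aiosqlite, but the SQL is identical under the synchronous stdlib driver:

```python
import sqlite3

# In-memory DB for illustration; the real catalog lives on disk.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE IF NOT EXISTS images ("
    "  id INTEGER PRIMARY KEY,"
    "  path TEXT NOT NULL,"
    "  character_id TEXT,"
    "  created_at TEXT DEFAULT CURRENT_TIMESTAMP"
    ")"
)
conn.execute(
    "INSERT INTO images (path, character_id) VALUES (?, ?)",
    ("output/pod/sfw/raw/img_001.png", "alice"),
)
rows = conn.execute(
    "SELECT path FROM images WHERE character_id = ?", ("alice",)
).fetchall()
```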

## UI Structure

**Generate Page:**
- Mode chips: Text to Image | Image to Image | Image to Video
- Backend chips: Local GPU | RunPod GPU | Cloud API
- Model dropdowns (conditional on mode/backend)
- Character/Template selectors (2-column grid)
- Prompt textareas
- Output settings (aspect ratio, seed)

**Controls Panel:** 340px width, compact styling
**Drop Zones:** For reference images (character + pose)

## Common Issues

### "Product not found" from WaveSpeed
The model ID doesn't exist. Check `MODEL_MAP`, `EDIT_MODEL_MAP`, `VIDEO_MODEL_MAP` in wavespeed_provider.py against https://wavespeed.ai/models

### "No image URL in output"
WaveSpeed returned an async job. Check that `outputs` is empty and `urls.get` exists, then poll that URL.

### HuggingFace Space startup hang
Check requirements.txt for missing packages. Common: `python-dotenv`, `runpod`, `wavespeed` (optional).

### Import errors on HF Spaces
Make imports optional with try/except:
```python
try:
    from wavespeed import Client
    SDK_AVAILABLE = True
except ImportError:
    SDK_AVAILABLE = False
```

## Development Commands

```bash
# Run locally
cd content_engine
python -m uvicorn content_engine.main:app --port 8000 --reload

# Push to HuggingFace
cd content-engine
git add . && git commit -m "message" && git push origin main

# Sync local ↔ HF
cp content_engine/src/content_engine/file.py content-engine/src/content_engine/file.py
```

## Multi-Reference Image Support

For img2img with 2 reference images (character + pose):

1. **UI**: Two drop zones side-by-side
2. **API**: `image` (required) + `image2` (optional) in FormData
3. **Backend**: Both uploaded to temp URLs, sent to WaveSpeed
4. **Models**: SeeDream Sequential, Kling O1 support multi-ref

## Pricing Notes

- **WaveSpeed**: ~$0.003-0.01 per image, $0.01-0.05/s for video
- **RunPod**: ~$0.0002/s for GPU time (training/generation)
- Cloud API cheaper for light use; RunPod better for volume

## Future Improvements

- [ ] RunPod serverless endpoint for LoRA-based generation
- [ ] Auto-captioning for training images
- [ ] Batch video generation
- [ ] Publishing integrations (social media APIs)
config/models.yaml CHANGED
@@ -5,25 +5,31 @@ training_models:
   # FLUX - Best for photorealistic images (recommended for realistic person)
   flux2_dev:
     name: "FLUX.2 Dev (Recommended)"
-    description: "Latest FLUX model, 32B params, best quality for realistic person. Also supports multi-reference without training."
+    description: "Latest FLUX model, 32B params, best quality for realistic person. Uses Mistral text encoder."
     hf_repo: "black-forest-labs/FLUX.2-dev"
-    hf_filename: "flux.2-dev.safetensors"
-    model_type: "flux"
+    hf_filename: "flux2-dev.safetensors"
+    model_type: "flux2"
+    training_framework: "musubi-tuner"
     resolution: 1024
-    learning_rate: 1e-3
-    text_encoder_lr: 1e-4
-    network_rank: 48
-    network_alpha: 24
-    clip_skip: 1
-    optimizer: "AdamW8bit"
-    lr_scheduler: "cosine"
-    min_snr_gamma: 5
-    max_train_steps: 1200
+    learning_rate: 1.0
+    network_rank: 64
+    network_alpha: 32
+    optimizer: "prodigy"
+    lr_scheduler: "constant"
+    timestep_sampling: "flux2_shift"
+    network_module: "networks.lora_flux_2"
+    max_train_steps: 50
     fp8_base: true
+    gradient_checkpointing: true
     use_case: "images"
-    vram_required_gb: 24
+    vram_required_gb: 48
+    recommended_gpu: "NVIDIA RTX A6000"
     recommended_images: "15-30 high quality photos with detailed captions"
-    training_script: "flux_train_network.py"
+    training_script: "flux_2_train_network.py"
+    # Model paths on network volume:
+    # DiT: /workspace/models/FLUX.2-dev/flux2-dev.safetensors
+    # VAE: /workspace/models/FLUX.2-dev/vae/diffusion_pytorch_model.safetensors
+    # Text encoder: /workspace/models/FLUX.2-dev/text_encoder/model-00001-of-00010.safetensors
 
   flux1_dev:
     name: "FLUX.1 Dev"
src/content_engine/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (279 Bytes)
 
src/content_engine/__pycache__/config.cpython-311.pyc DELETED
Binary file (6.33 kB)
 
src/content_engine/__pycache__/main.cpython-311.pyc DELETED
Binary file (10.6 kB)
 
src/content_engine/api/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (211 Bytes)
 
src/content_engine/api/__pycache__/routes_catalog.cpython-311.pyc DELETED
Binary file (7.2 kB)
 
src/content_engine/api/__pycache__/routes_generation.cpython-311.pyc DELETED
Binary file (28 kB)
 
src/content_engine/api/__pycache__/routes_pod.cpython-311.pyc DELETED
Binary file (25.1 kB)
 
src/content_engine/api/__pycache__/routes_system.cpython-311.pyc DELETED
Binary file (10.7 kB)
 
src/content_engine/api/__pycache__/routes_training.cpython-311.pyc DELETED
Binary file (12.6 kB)
 
src/content_engine/api/__pycache__/routes_ui.cpython-311.pyc DELETED
Binary file (1.23 kB)
 
src/content_engine/api/__pycache__/routes_video.cpython-311.pyc DELETED
Binary file (13.3 kB)
 
src/content_engine/api/routes_pod.py CHANGED
@@ -1,12 +1,18 @@
- """RunPod Pod management routes — start/stop GPU pods for generation and training."""
 
  from __future__ import annotations
 
  import asyncio
  import logging
  import os
  import time
  import uuid
  from typing import Any
 
  import runpod
@@ -17,28 +23,85 @@ logger = logging.getLogger(__name__)
 
  router = APIRouter(prefix="/api/pod", tags=["pod"])
 
  # Pod state
  _pod_state = {
      "pod_id": None,
-     "status": "stopped",  # stopped, starting, running, stopping
      "ip": None,
-     "port": None,
-     "gpu_type": "NVIDIA GeForce RTX 4090",
      "started_at": None,
-     "cost_per_hour": 0.44,
  }
 
- # Docker image with ComfyUI + FLUX
- COMFYUI_IMAGE = "timpietruskyblibla/runpod-worker-comfy:3.4.0-flux1-dev"
 
- # GPU options
  GPU_OPTIONS = {
-     "NVIDIA GeForce RTX 4090": {"name": "RTX 4090", "vram": 24, "cost": 0.44},
-     "NVIDIA RTX A6000": {"name": "RTX A6000", "vram": 48, "cost": 0.76},
-     "NVIDIA A100 80GB PCIe": {"name": "A100 80GB", "vram": 80, "cost": 1.89},
  }
 
 
  def _get_api_key() -> str:
      key = os.environ.get("RUNPOD_API_KEY")
      if not key:
@@ -48,7 +111,8 @@ def _get_api_key() -> str:
 
 
  class StartPodRequest(BaseModel):
-     gpu_type: str = "NVIDIA GeForce RTX 4090"
 
 
  class PodStatus(BaseModel):
@@ -57,7 +121,9 @@ class PodStatus(BaseModel):
      ip: str | None = None
      port: int | None = None
      gpu_type: str | None = None
      cost_per_hour: float | None = None
      uptime_minutes: float | None = None
      comfyui_url: str | None = None
 
@@ -67,44 +133,52 @@ async def get_pod_status():
      """Get current pod status."""
      _get_api_key()
 
-     # If we have a pod_id, check its actual status
      if _pod_state["pod_id"]:
          try:
-             pod = runpod.get_pod(_pod_state["pod_id"])
              if pod:
                  desired = pod.get("desiredStatus", "")
                  if desired == "RUNNING":
-                     runtime = pod.get("runtime", {})
-                     ports = runtime.get("ports", [])
                      for p in ports:
                          if p.get("privatePort") == 8188:
-                             _pod_state["ip"] = p.get("ip")
-                             _pod_state["port"] = p.get("publicPort")
-                             _pod_state["status"] = "running"
                  elif desired == "EXITED":
                      _pod_state["status"] = "stopped"
                      _pod_state["pod_id"] = None
                  else:
                      _pod_state["status"] = "stopped"
                      _pod_state["pod_id"] = None
          except Exception as e:
              logger.warning("Failed to check pod: %s", e)
 
      uptime = None
-     if _pod_state["started_at"] and _pod_state["status"] == "running":
          uptime = (time.time() - _pod_state["started_at"]) / 60
 
-     comfyui_url = None
-     if _pod_state["ip"] and _pod_state["port"]:
-         comfyui_url = f"http://{_pod_state['ip']}:{_pod_state['port']}"
 
      return PodStatus(
          status=_pod_state["status"],
          pod_id=_pod_state["pod_id"],
          ip=_pod_state["ip"],
-         port=_pod_state["port"],
          gpu_type=_pod_state["gpu_type"],
          cost_per_hour=_pod_state["cost_per_hour"],
          uptime_minutes=uptime,
          comfyui_url=comfyui_url,
      )
@@ -116,12 +190,24 @@ async def list_gpu_options():
      return {"gpus": GPU_OPTIONS}
 
 
  @router.post("/start")
  async def start_pod(request: StartPodRequest):
-     """Start a GPU pod for generation/training."""
      _get_api_key()
 
-     if _pod_state["status"] == "running":
          return {"status": "already_running", "pod_id": _pod_state["pod_id"]}
 
      if _pod_state["status"] == "starting":
@@ -134,74 +220,284 @@ async def start_pod(request: StartPodRequest):
      _pod_state["status"] = "starting"
      _pod_state["gpu_type"] = request.gpu_type
      _pod_state["cost_per_hour"] = gpu_info["cost"]
 
      try:
-         logger.info("Starting RunPod with %s...", request.gpu_type)
-
-         pod = runpod.create_pod(
-             name="content-engine-gpu",
-             image_name=COMFYUI_IMAGE,
-             gpu_type_id=request.gpu_type,
-             volume_in_gb=50,  # For models and LoRAs
-             container_disk_in_gb=20,
-             ports="8188/http",
-             env={
-                 # Pre-load FLUX model
-                 "MODEL_URL": "https://huggingface.co/black-forest-labs/FLUX.1-dev/resolve/main/flux1-dev.safetensors",
-             },
          )
 
          _pod_state["pod_id"] = pod["id"]
          _pod_state["started_at"] = time.time()
 
          logger.info("Pod created: %s", pod["id"])
 
-         # Start background task to wait for pod ready
-         asyncio.create_task(_wait_for_pod_ready(pod["id"]))
 
          return {
              "status": "starting",
              "pod_id": pod["id"],
-             "message": f"Starting {gpu_info['name']} pod (~2-3 min)",
          }
 
      except Exception as e:
          _pod_state["status"] = "stopped"
          logger.error("Failed to start pod: %s", e)
          raise HTTPException(500, f"Failed to start pod: {e}")
 
 
- async def _wait_for_pod_ready(pod_id: str, timeout: int = 300):
-     """Background task to wait for pod to be ready."""
      start = time.time()
 
      while time.time() - start < timeout:
          try:
-             pod = runpod.get_pod(pod_id)
-
              if pod and pod.get("desiredStatus") == "RUNNING":
-                 runtime = pod.get("runtime", {})
-                 ports = runtime.get("ports", [])
-
                  for p in ports:
                      if p.get("privatePort") == 8188:
-                         ip = p.get("ip")
-                         port = p.get("publicPort")
-
-                         if ip and port:
-                             _pod_state["ip"] = ip
-                             _pod_state["port"] = int(port)
-                             _pod_state["status"] = "running"
-                             logger.info("Pod ready at %s:%s", ip, port)
-                             return
-
          except Exception as e:
              logger.debug("Waiting for pod: %s", e)
-
          await asyncio.sleep(5)
 
-     logger.error("Pod did not become ready within %ds", timeout)
-     _pod_state["status"] = "stopped"
 
 
  @router.post("/stop")
@@ -221,20 +517,23 @@ async def stop_pod():
          pod_id = _pod_state["pod_id"]
          logger.info("Stopping pod: %s", pod_id)
 
-         runpod.terminate_pod(pod_id)
 
          _pod_state["pod_id"] = None
          _pod_state["ip"] = None
-         _pod_state["port"] = None
          _pod_state["status"] = "stopped"
          _pod_state["started_at"] = None
 
          logger.info("Pod stopped")
          return {"status": "stopped", "message": "Pod terminated"}
 
      except Exception as e:
          logger.error("Failed to stop pod: %s", e)
-         _pod_state["status"] = "running"  # Revert
          raise HTTPException(500, f"Failed to stop pod: {e}")
 
 
@@ -244,10 +543,11 @@ async def list_pod_loras():
      if _pod_state["status"] != "running" or not _pod_state["ip"]:
          return {"loras": [], "message": "Pod not running"}
 
      try:
          import httpx
          async with httpx.AsyncClient(timeout=30) as client:
-             url = f"http://{_pod_state['ip']}:{_pod_state['port']}/object_info/LoraLoader"
              resp = await client.get(url)
              if resp.status_code == 200:
                  data = resp.json()
@@ -256,15 +556,12 @@ async def list_pod_loras():
      except Exception as e:
          logger.warning("Failed to list pod LoRAs: %s", e)
 
-     return {"loras": [], "comfyui_url": f"http://{_pod_state['ip']}:{_pod_state['port']}"}
 
 
  @router.post("/upload-lora")
- async def upload_lora_to_pod(
-     file: UploadFile = File(...),
- ):
      """Upload a LoRA file to the running pod."""
-     from fastapi import UploadFile, File
      import httpx
 
      if _pod_state["status"] != "running":
@@ -275,13 +572,12 @@ async def upload_lora_to_pod(
 
      try:
          content = await file.read()
 
          async with httpx.AsyncClient(timeout=120) as client:
-             # Upload to ComfyUI's models/loras directory
-             url = f"http://{_pod_state['ip']}:{_pod_state['port']}/upload/image"
              files = {"image": (file.filename, content, "application/octet-stream")}
              data = {"subfolder": "loras", "type": "input"}
-
              resp = await client.post(url, files=files, data=data)
 
              if resp.status_code == 200:
@@ -326,7 +622,7 @@ async def generate_on_pod(request: PodGenerateRequest):
      job_id = str(uuid.uuid4())[:8]
      seed = request.seed if request.seed >= 0 else random.randint(0, 2**32 - 1)
 
-     # Build ComfyUI workflow
      workflow = _build_flux_workflow(
          prompt=request.prompt,
          negative_prompt=request.negative_prompt,
@@ -337,12 +633,14 @@ async def generate_on_pod(request: PodGenerateRequest):
          seed=seed,
          lora_name=request.lora_name,
          lora_strength=request.lora_strength,
      )
 
      try:
          async with httpx.AsyncClient(timeout=30) as client:
-             url = f"http://{_pod_state['ip']}:{_pod_state['port']}/prompt"
-             resp = await client.post(url, json={"prompt": workflow})
              resp.raise_for_status()
 
              data = resp.json()
@@ -356,15 +654,9 @@ async def generate_on_pod(request: PodGenerateRequest):
          }
 
          logger.info("Pod generation started: %s -> %s", job_id, prompt_id)
-
-         # Start background task to poll for completion
          asyncio.create_task(_poll_pod_job(job_id, prompt_id, request.content_rating))
 
-         return {
-             "job_id": job_id,
-             "status": "running",
-             "seed": seed,
-         }
 
      except Exception as e:
          logger.error("Pod generation failed: %s", e)
@@ -374,38 +666,33 @@ async def generate_on_pod(request: PodGenerateRequest):
  async def _poll_pod_job(job_id: str, prompt_id: str, content_rating: str):
      """Poll ComfyUI for job completion and save the result."""
      import httpx
-     from pathlib import Path
 
      start = time.time()
-     timeout = 300  # 5 minutes
 
      async with httpx.AsyncClient(timeout=60) as client:
          while time.time() - start < timeout:
              try:
-                 url = f"http://{_pod_state['ip']}:{_pod_state['port']}/history/{prompt_id}"
-                 resp = await client.get(url)
 
                  if resp.status_code == 200:
                      data = resp.json()
                      if prompt_id in data:
                          outputs = data[prompt_id].get("outputs", {})
 
-                         # Find SaveImage output
                          for node_id, node_output in outputs.items():
                              if "images" in node_output:
                                  image_info = node_output["images"][0]
                                  filename = image_info["filename"]
                                  subfolder = image_info.get("subfolder", "")
 
-                                 # Download the image
-                                 img_url = f"http://{_pod_state['ip']}:{_pod_state['port']}/view"
                                  params = {"filename": filename}
                                  if subfolder:
                                      params["subfolder"] = subfolder
 
-                                 img_resp = await client.get(img_url, params=params)
                                  if img_resp.status_code == 200:
-                                     # Save to local output directory
                                      from content_engine.config import settings
                                      output_dir = settings.paths.output_dir / "pod" / content_rating / "raw"
                                      output_dir.mkdir(parents=True, exist_ok=True)
@@ -419,7 +706,6 @@ async def _poll_pod_job(job_id: str, prompt_id: str, content_rating: str):
 
                                      logger.info("Pod generation completed: %s -> %s", job_id, local_path)
 
-                                     # Catalog the image
                                      try:
                                          from content_engine.services.catalog import CatalogService
                                          catalog = CatalogService()
@@ -463,29 +749,69 @@ def _build_flux_workflow(
      seed: int,
      lora_name: str | None,
      lora_strength: float,
  ) -> dict:
-     """Build a ComfyUI workflow for FLUX generation."""
 
-     # Basic FLUX workflow - compatible with ComfyUI FLUX setup
      workflow = {
-         "4": {
-             "class_type": "CheckpointLoaderSimple",
-             "inputs": {"ckpt_name": "flux1-dev.safetensors"},
          },
          "6": {
              "class_type": "CLIPTextEncode",
              "inputs": {
                  "text": prompt,
-                 "clip": ["4", 1],
              },
          },
          "7": {
              "class_type": "CLIPTextEncode",
              "inputs": {
                  "text": negative_prompt or "",
-                 "clip": ["4", 1],
              },
          },
          "5": {
              "class_type": "EmptyLatentImage",
              "inputs": {
@@ -494,7 +820,8 @@ def _build_flux_workflow(
                  "batch_size": 1,
              },
          },
-         "3": {
              "class_type": "KSampler",
              "inputs": {
                  "seed": seed,
@@ -503,19 +830,21 @@ def _build_flux_workflow(
                  "sampler_name": "euler",
                  "scheduler": "simple",
                  "denoise": 1.0,
-                 "model": ["4", 0],
                  "positive": ["6", 0],
                  "negative": ["7", 0],
                  "latent_image": ["5", 0],
              },
          },
          "8": {
              "class_type": "VAEDecode",
              "inputs": {
-                 "samples": ["3", 0],
-                 "vae": ["4", 2],
              },
          },
          "9": {
              "class_type": "SaveImage",
              "inputs": {
@@ -527,19 +856,19 @@ def _build_flux_workflow(
 
      # Add LoRA if specified
      if lora_name:
-         workflow["10"] = {
              "class_type": "LoraLoader",
              "inputs": {
                  "lora_name": lora_name,
                  "strength_model": lora_strength,
                  "strength_clip": lora_strength,
-                 "model": ["4", 0],
-                 "clip": ["4", 1],
              },
          }
-         # Rewire sampler to use LoRA output
-         workflow["3"]["inputs"]["model"] = ["10", 0]
-         workflow["6"]["inputs"]["clip"] = ["10", 1]
-         workflow["7"]["inputs"]["clip"] = ["10", 1]
 
      return workflow
+ """RunPod Pod management routes — start/stop GPU pods for generation.
+
+ Starts a persistent ComfyUI pod with network volume access.
+ Models and LoRAs are loaded from the shared network volume.
+ """
 
  from __future__ import annotations
 
  import asyncio
+ import json
  import logging
  import os
  import time
  import uuid
+ from pathlib import Path
  from typing import Any
 
  import runpod
 
  router = APIRouter(prefix="/api/pod", tags=["pod"])
 
+ # Persist pod state to disk so it survives server restarts
+ _POD_STATE_FILE = Path(__file__).parent.parent.parent.parent / "pod_state.json"
+
+
+ def _save_pod_state():
+     """Save pod state to disk."""
+     try:
+         data = {k: v for k, v in _pod_state.items() if k != "setup_status"}
+         _POD_STATE_FILE.write_text(json.dumps(data))
+     except Exception as e:
+         logger.warning("Failed to save pod state: %s", e)
+
+
+ def _load_pod_state():
+     """Load pod state from disk on startup."""
+     try:
+         if _POD_STATE_FILE.exists():
+             data = json.loads(_POD_STATE_FILE.read_text())
+             for k, v in data.items():
+                 if k in _pod_state:
+                     _pod_state[k] = v
+             logger.info("Restored pod state: pod_id=%s status=%s", _pod_state.get("pod_id"), _pod_state.get("status"))
+     except Exception as e:
+         logger.warning("Failed to load pod state: %s", e)
+
+ def _get_volume_config() -> tuple[str, str]:
+     """Get network volume config at runtime (after dotenv loads)."""
+     return (
+         os.environ.get("RUNPOD_VOLUME_ID", ""),
+         os.environ.get("RUNPOD_VOLUME_DC", ""),
+     )
+
+ # Docker image — PyTorch base with CUDA, we install ComfyUI ourselves
+ DOCKER_IMAGE = "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04"
+
  # Pod state
  _pod_state = {
      "pod_id": None,
+     "status": "stopped",  # stopped, starting, setting_up, running, stopping
      "ip": None,
+     "ssh_port": None,
+     "comfyui_port": None,
+     "gpu_type": "NVIDIA RTX A6000",
+     "model_type": "flux2",
      "started_at": None,
+     "cost_per_hour": 0.76,
+     "setup_status": None,
  }
 
+ _load_pod_state()
 
+ # GPU options (same as training)
  GPU_OPTIONS = {
+     "NVIDIA A40": {"name": "A40 48GB", "vram": 48, "cost": 0.64},
+     "NVIDIA RTX A6000": {"name": "RTX A6000 48GB", "vram": 48, "cost": 0.76},
+     "NVIDIA L40": {"name": "L40 48GB", "vram": 48, "cost": 0.89},
+     "NVIDIA L40S": {"name": "L40S 48GB", "vram": 48, "cost": 1.09},
+     "NVIDIA A100-SXM4-80GB": {"name": "A100 SXM 80GB", "vram": 80, "cost": 1.64},
+     "NVIDIA A100 80GB PCIe": {"name": "A100 PCIe 80GB", "vram": 80, "cost": 1.89},
+     "NVIDIA H100 80GB HBM3": {"name": "H100 80GB", "vram": 80, "cost": 3.89},
+     "NVIDIA GeForce RTX 5090": {"name": "RTX 5090 32GB", "vram": 32, "cost": 0.69},
+     "NVIDIA GeForce RTX 4090": {"name": "RTX 4090 24GB", "vram": 24, "cost": 0.44},
+     "NVIDIA GeForce RTX 3090": {"name": "RTX 3090 24GB", "vram": 24, "cost": 0.22},
  }
 
 
+ def _get_comfyui_url() -> str | None:
+     """Get the ComfyUI URL via RunPod's HTTPS proxy.
+
+     RunPod HTTP ports are only accessible through their proxy at
+     https://{pod_id}-{private_port}.proxy.runpod.net
+     The raw IP:port from the API is an internal address, not publicly routable.
+     """
+     pod_id = _pod_state.get("pod_id")
+     if pod_id:
+         return f"https://{pod_id}-8188.proxy.runpod.net"
+     return None
+
+
  def _get_api_key() -> str:
      key = os.environ.get("RUNPOD_API_KEY")
      if not key:
 
 
112
 
113
  class StartPodRequest(BaseModel):
114
+ gpu_type: str = "NVIDIA RTX A6000"
115
+ model_type: str = "flux2"
116
 
117
 
118
  class PodStatus(BaseModel):
 
121
  ip: str | None = None
122
  port: int | None = None
123
  gpu_type: str | None = None
124
+ model_type: str | None = None
125
  cost_per_hour: float | None = None
126
+ setup_status: str | None = None
127
  uptime_minutes: float | None = None
128
  comfyui_url: str | None = None
129
 
 
      """Get current pod status."""
      _get_api_key()

      if _pod_state["pod_id"]:
          try:
+             pod = await asyncio.wait_for(
+                 asyncio.to_thread(runpod.get_pod, _pod_state["pod_id"]),
+                 timeout=10,
+             )
              if pod:
                  desired = pod.get("desiredStatus", "")
                  if desired == "RUNNING":
+                     runtime = pod.get("runtime") or {}
+                     ports = runtime.get("ports") or []
                      for p in ports:
+                         if p.get("privatePort") == 22:
+                             _pod_state["ssh_ip"] = p.get("ip")
+                             _pod_state["ssh_port"] = p.get("publicPort")
                          if p.get("privatePort") == 8188:
+                             _pod_state["comfyui_ip"] = p.get("ip")
+                             _pod_state["comfyui_port"] = p.get("publicPort")
+                     # Use SSH IP as the main IP for display
+                     _pod_state["ip"] = _pod_state.get("ssh_ip") or _pod_state.get("comfyui_ip")
                  elif desired == "EXITED":
                      _pod_state["status"] = "stopped"
                      _pod_state["pod_id"] = None
              else:
                  _pod_state["status"] = "stopped"
                  _pod_state["pod_id"] = None
+         except asyncio.TimeoutError:
+             logger.warning("RunPod API timeout checking pod status")
          except Exception as e:
              logger.warning("Failed to check pod: %s", e)

      uptime = None
+     if _pod_state["started_at"] and _pod_state["status"] in ("running", "setting_up"):
          uptime = (time.time() - _pod_state["started_at"]) / 60

+     comfyui_url = _get_comfyui_url()

      return PodStatus(
          status=_pod_state["status"],
          pod_id=_pod_state["pod_id"],
          ip=_pod_state["ip"],
+         port=_pod_state.get("comfyui_port"),
          gpu_type=_pod_state["gpu_type"],
+         model_type=_pod_state.get("model_type", "flux2"),
          cost_per_hour=_pod_state["cost_per_hour"],
+         setup_status=_pod_state.get("setup_status"),
          uptime_minutes=uptime,
          comfyui_url=comfyui_url,
      )

      return {"gpus": GPU_OPTIONS}


+ @router.get("/model-options")
+ async def list_model_options():
+     """List available model types for the pod."""
+     return {
+         "models": {
+             "flux2": {"name": "FLUX.2 Dev", "description": "Best for realistic txt2img (requires 48GB+ VRAM)", "use_case": "txt2img"},
+             "flux1": {"name": "FLUX.1 Dev", "description": "Previous gen FLUX txt2img", "use_case": "txt2img"},
+             "wan22": {"name": "WAN 2.2", "description": "Image-to-video and general generation", "use_case": "img2video"},
+         }
+     }


  @router.post("/start")
  async def start_pod(request: StartPodRequest):
+     """Start a GPU pod with ComfyUI for generation."""
      _get_api_key()

+     if _pod_state["status"] in ("running", "setting_up"):
          return {"status": "already_running", "pod_id": _pod_state["pod_id"]}

      if _pod_state["status"] == "starting":

      _pod_state["status"] = "starting"
      _pod_state["gpu_type"] = request.gpu_type
      _pod_state["cost_per_hour"] = gpu_info["cost"]
+     _pod_state["model_type"] = request.model_type
+     _pod_state["setup_status"] = "Creating pod..."

      try:
+         logger.info("Starting RunPod with %s for %s...", request.gpu_type, request.model_type)
+
+         pod_kwargs = {
+             "container_disk_in_gb": 30,
+             "ports": "22/tcp,8188/http",
+             "docker_args": "bash -c 'apt-get update && apt-get install -y openssh-server && mkdir -p /run/sshd && echo root:runpod | chpasswd && /usr/sbin/sshd -o PermitRootLogin=yes && sleep infinity'",
+         }
+
+         volume_id, volume_dc = _get_volume_config()
+         if volume_id:
+             pod_kwargs["network_volume_id"] = volume_id
+             if volume_dc:
+                 pod_kwargs["data_center_id"] = volume_dc
+             logger.info("Using network volume: %s (DC: %s)", volume_id, volume_dc)
+         else:
+             pod_kwargs["volume_in_gb"] = 75
+             logger.warning("No network volume configured, using temporary volume")
+
+         pod = await asyncio.to_thread(
+             runpod.create_pod,
+             f"comfyui-gen-{request.model_type}",
+             DOCKER_IMAGE,
+             request.gpu_type,
+             **pod_kwargs,
          )

          _pod_state["pod_id"] = pod["id"]
          _pod_state["started_at"] = time.time()
+         _save_pod_state()

          logger.info("Pod created: %s", pod["id"])

+         asyncio.create_task(_wait_and_setup_pod(pod["id"], request.model_type))

          return {
              "status": "starting",
              "pod_id": pod["id"],
+             "message": f"Starting {gpu_info['name']} pod (~5-8 min for setup)",
          }

      except Exception as e:
          _pod_state["status"] = "stopped"
+         _pod_state["setup_status"] = None
          logger.error("Failed to start pod: %s", e)
          raise HTTPException(500, f"Failed to start pod: {e}")

+ async def _wait_and_setup_pod(pod_id: str, model_type: str, timeout: int = 600):
+     """Wait for pod to be ready, then install ComfyUI and link models via SSH."""
      start = time.time()
+     ssh_host = None
+     ssh_port = None

+     # Phase 1: Wait for SSH to be available
+     _pod_state["setup_status"] = "Waiting for pod to start..."
      while time.time() - start < timeout:
          try:
+             pod = await asyncio.to_thread(runpod.get_pod, pod_id)
              if pod and pod.get("desiredStatus") == "RUNNING":
+                 runtime = pod.get("runtime") or {}
+                 ports = runtime.get("ports") or []
                  for p in ports:
+                     if p.get("privatePort") == 22:
+                         ssh_host = p.get("ip")
+                         ssh_port = p.get("publicPort")
+                         _pod_state["ssh_ip"] = ssh_host
+                         _pod_state["ssh_port"] = ssh_port
+                         _pod_state["ip"] = ssh_host
                      if p.get("privatePort") == 8188:
+                         _pod_state["comfyui_ip"] = p.get("ip")
+                         _pod_state["comfyui_port"] = p.get("publicPort")
+                 if ssh_host and ssh_port:
+                     break
          except Exception as e:
              logger.debug("Waiting for pod: %s", e)
          await asyncio.sleep(5)

+     if not ssh_host or not ssh_port:
+         logger.error("Pod did not become ready within %ds", timeout)
+         _pod_state["status"] = "stopped"
+         _pod_state["setup_status"] = "Failed: pod did not start"
+         return
+
+     # Phase 2: SSH in and set up ComfyUI
+     _pod_state["status"] = "setting_up"
+     _pod_state["setup_status"] = "Connecting via SSH..."
+
+     import paramiko
+     ssh = paramiko.SSHClient()
+     ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+
+     for attempt in range(30):
+         try:
+             await asyncio.to_thread(
+                 ssh.connect, ssh_host, port=int(ssh_port),
+                 username="root", password="runpod", timeout=10,
+             )
+             break
+         except Exception:
+             if attempt == 29:
+                 _pod_state["setup_status"] = "Failed: SSH connection error"
+                 _pod_state["status"] = "stopped"
+                 return
+             await asyncio.sleep(5)
+
+     transport = ssh.get_transport()
+     transport.set_keepalive(30)
+
+     try:
+         # Symlink network volume
+         volume_id, _ = _get_volume_config()
+         if volume_id:
+             await _ssh_exec_async(ssh, "mkdir -p /runpod-volume/models /runpod-volume/loras")
+             await _ssh_exec_async(ssh, "rm -rf /workspace/models 2>/dev/null; ln -sf /runpod-volume/models /workspace/models")
+
+         # Install ComfyUI (cache on volume for reuse)
+         comfy_dir = "/workspace/ComfyUI"
+         _pod_state["setup_status"] = "Installing ComfyUI..."
+
+         comfy_exists = (await _ssh_exec_async(ssh, f"test -f {comfy_dir}/main.py && echo EXISTS || echo MISSING")).strip()
+         if comfy_exists == "EXISTS":
+             logger.info("ComfyUI already installed")
+             _pod_state["setup_status"] = "ComfyUI found, updating..."
+             await _ssh_exec_async(ssh, f"cd {comfy_dir} && git pull 2>&1 | tail -3", timeout=120)
+         else:
+             # Check volume cache
+             vol_comfy = (await _ssh_exec_async(ssh, "test -f /runpod-volume/ComfyUI/main.py && echo EXISTS || echo MISSING")).strip()
+             if vol_comfy == "EXISTS":
+                 _pod_state["setup_status"] = "Restoring ComfyUI from volume..."
+                 await _ssh_exec_async(ssh, f"cp -r /runpod-volume/ComfyUI {comfy_dir}", timeout=300)
+             else:
+                 _pod_state["setup_status"] = "Cloning ComfyUI (first time, ~2 min)..."
+                 await _ssh_exec_async(ssh, "cd /workspace && git clone --depth 1 https://github.com/comfyanonymous/ComfyUI.git", timeout=300)
+                 await _ssh_exec_async(ssh, f"cd {comfy_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=600)
+                 # Cache to volume for future pods (volume_id already fetched above)
+                 if volume_id:
+                     await _ssh_exec_async(ssh, f"cp -r {comfy_dir} /runpod-volume/ComfyUI", timeout=300)
+
+         # Install pip deps that aren't in ComfyUI requirements
+         _pod_state["setup_status"] = "Installing dependencies..."
+         await _ssh_exec_async(ssh, f"cd {comfy_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=600)
+         await _ssh_exec_async(ssh, "pip install aiohttp einops sqlalchemy 2>&1 | tail -3", timeout=120)
+
+         # Symlink models into ComfyUI directories
+         _pod_state["setup_status"] = "Linking models..."
+         await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/checkpoints {comfy_dir}/models/vae {comfy_dir}/models/loras {comfy_dir}/models/text_encoders")

+         if model_type == "flux2":
+             # FLUX.2 Dev ships as separate UNet, text encoder, and VAE files
+             await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/diffusion_models")
+             await _ssh_exec_async(ssh, f"ln -sf /workspace/models/FLUX.2-dev/flux2-dev.safetensors {comfy_dir}/models/diffusion_models/flux2-dev.safetensors")
+             await _ssh_exec_async(ssh, f"ln -sf /workspace/models/FLUX.2-dev/ae.safetensors {comfy_dir}/models/vae/ae.safetensors")
+
+             # Text encoder: use Comfy-Org's pre-converted single-file version
+             # (HF sharded format is incompatible with ComfyUI's CLIPLoader)
+             te_file = "/runpod-volume/models/mistral_3_small_flux2_fp8.safetensors"
+             te_exists = (await _ssh_exec_async(ssh, f"test -f {te_file} && echo EXISTS || echo MISSING")).strip()
+             if te_exists != "EXISTS":
+                 _pod_state["setup_status"] = "Downloading FLUX.2 text encoder (~12GB, first time only)..."
+                 await _ssh_exec_async(ssh, "pip install huggingface_hub 2>&1 | tail -1", timeout=60)
+                 await _ssh_exec_async(ssh, f"""python -c "
+ from huggingface_hub import hf_hub_download
+ hf_hub_download(
+     repo_id='Comfy-Org/flux2-dev',
+     filename='split_files/text_encoders/mistral_3_small_flux2_fp8.safetensors',
+     local_dir='/tmp/flux2_te',
+ )
+ import shutil
+ shutil.move('/tmp/flux2_te/split_files/text_encoders/mistral_3_small_flux2_fp8.safetensors', '{te_file}')
+ print('Text encoder downloaded')
+ " 2>&1 | tail -5""", timeout=1800)
+             await _ssh_exec_async(ssh, f"ln -sf {te_file} {comfy_dir}/models/text_encoders/mistral_3_small_flux2_fp8.safetensors")
+             # Remove old sharded loader patch if present
+             await _ssh_exec_async(ssh, f"rm -f {comfy_dir}/custom_nodes/sharded_loader.py")
+         elif model_type == "flux1":
+             await _ssh_exec_async(ssh, f"ln -sf /workspace/models/flux1-dev.safetensors {comfy_dir}/models/checkpoints/flux1-dev.safetensors")
+             await _ssh_exec_async(ssh, f"ln -sf /workspace/models/ae.safetensors {comfy_dir}/models/vae/ae.safetensors")
+             await _ssh_exec_async(ssh, f"ln -sf /workspace/models/clip_l.safetensors {comfy_dir}/models/text_encoders/clip_l.safetensors")
+             await _ssh_exec_async(ssh, f"ln -sf /workspace/models/t5xxl_fp16.safetensors {comfy_dir}/models/text_encoders/t5xxl_fp16.safetensors")
+         elif model_type == "wan22":
+             # WAN 2.2 Image-to-Video (14B params)
+             wan_dir = "/workspace/models/Wan2.2-I2V-A14B"
+             wan_exists = (await _ssh_exec_async(ssh, f"test -d {wan_dir} && echo EXISTS || echo MISSING")).strip()
+             if wan_exists != "EXISTS":
+                 _pod_state["setup_status"] = "Downloading WAN 2.2 model (~28GB, first time only)..."
+                 await _ssh_exec_async(ssh, "pip install huggingface_hub 2>&1 | tail -1", timeout=60)
+                 await _ssh_exec_async(ssh, f"""python -c "
+ from huggingface_hub import snapshot_download
+ snapshot_download('Wan-AI/Wan2.2-I2V-A14B', local_dir='{wan_dir}', ignore_patterns=['*.md', '*.txt'])
+ print('WAN 2.2 downloaded')
+ " 2>&1 | tail -10""", timeout=3600)
+             # Symlink WAN model into ComfyUI's diffusion_models dir
+             await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/diffusion_models")
+             await _ssh_exec_async(ssh, f"ln -sf {wan_dir} {comfy_dir}/models/diffusion_models/Wan2.2-I2V-A14B")
+             # WAN also needs its own VAE and text encoder
+             await _ssh_exec_async(ssh, f"ln -sf {wan_dir} {comfy_dir}/models/checkpoints/Wan2.2-I2V-A14B")
+
+             # Install ComfyUI-WanVideoWrapper custom nodes
+             _pod_state["setup_status"] = "Installing WAN 2.2 ComfyUI nodes..."
+             wan_nodes_dir = f"{comfy_dir}/custom_nodes/ComfyUI-WanVideoWrapper"
+             wan_nodes_exist = (await _ssh_exec_async(ssh, f"test -d {wan_nodes_dir} && echo EXISTS || echo MISSING")).strip()
+             if wan_nodes_exist != "EXISTS":
+                 await _ssh_exec_async(ssh, f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/kijai/ComfyUI-WanVideoWrapper.git", timeout=120)
+                 await _ssh_exec_async(ssh, f"cd {wan_nodes_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=300)
+
+         # Symlink all LoRAs from the volume
+         await _ssh_exec_async(ssh, f"ls /runpod-volume/loras/*.safetensors 2>/dev/null | while read f; do ln -sf \"$f\" {comfy_dir}/models/loras/; done")
+
+         # Start ComfyUI in background (fire-and-forget; don't wait for output)
+         _pod_state["setup_status"] = "Starting ComfyUI..."
+         await asyncio.to_thread(
+             _ssh_exec_fire_and_forget,
+             ssh,
+             f"cd {comfy_dir} && python main.py --listen 0.0.0.0 --port 8188 --fp8_e4m3fn-unet > /tmp/comfyui.log 2>&1",
+         )
+         await asyncio.sleep(2)  # Give it a moment to start
+
+         # Wait for ComfyUI HTTP to respond
+         _pod_state["setup_status"] = "Waiting for ComfyUI to load model..."
+         import httpx
+         comfyui_url = _get_comfyui_url()
+         for attempt in range(120):  # Up to 10 minutes
+             try:
+                 async with httpx.AsyncClient(timeout=5) as client:
+                     resp = await client.get(f"{comfyui_url}/system_stats")
+                     if resp.status_code == 200:
+                         _pod_state["status"] = "running"
+                         _pod_state["setup_status"] = "Ready"
+                         _save_pod_state()
+                         logger.info("ComfyUI ready at %s", comfyui_url)
+                         return
+             except Exception:
+                 pass
+             await asyncio.sleep(5)
+
+         # If we get here, ComfyUI didn't start; check the log for errors
+         log_tail = await _ssh_exec_async(ssh, "tail -20 /tmp/comfyui.log")
+         logger.error("ComfyUI didn't start. Log: %s", log_tail)
+         _pod_state["setup_status"] = "ComfyUI failed to start. Check logs."
+         _pod_state["status"] = "setting_up"  # Keep pod running so user can debug
+
+     except Exception as e:
+         import traceback
+         err_msg = f"{type(e).__name__}: {e}"
+         logger.error("Pod setup failed: %s\n%s", err_msg, traceback.format_exc())
+         _pod_state["setup_status"] = f"Setup failed: {err_msg}"
+         _pod_state["status"] = "setting_up"  # Keep pod running so user can debug
+     finally:
+         try:
+             ssh.close()
+         except Exception:
+             pass

+ def _ssh_exec(ssh, cmd: str, timeout: int = 120) -> str:
+     """Execute a command over SSH and return stdout (blocking; call from async code via to_thread)."""
+     _, stdout, _ = ssh.exec_command(cmd, timeout=timeout)
+     out = stdout.read().decode("utf-8", errors="replace")
+     return out.strip()
+
+
+ async def _ssh_exec_async(ssh, cmd: str, timeout: int = 120) -> str:
+     """Async wrapper for SSH exec that doesn't block the event loop."""
+     return await asyncio.to_thread(_ssh_exec, ssh, cmd, timeout)
+
+
+ def _ssh_exec_fire_and_forget(ssh, cmd: str):
+     """Start a command over SSH without waiting for output (for background processes)."""
+     transport = ssh.get_transport()
+     channel = transport.open_session()
+     channel.exec_command(cmd)
+     # Don't read stdout/stderr; just let it run


  @router.post("/stop")

      pod_id = _pod_state["pod_id"]
      logger.info("Stopping pod: %s", pod_id)

+         await asyncio.to_thread(runpod.terminate_pod, pod_id)

          _pod_state["pod_id"] = None
          _pod_state["ip"] = None
+         _pod_state["ssh_port"] = None
+         _pod_state["comfyui_port"] = None
          _pod_state["status"] = "stopped"
          _pod_state["started_at"] = None
+         _pod_state["setup_status"] = None
+         _save_pod_state()

          logger.info("Pod stopped")
          return {"status": "stopped", "message": "Pod terminated"}

      except Exception as e:
          logger.error("Failed to stop pod: %s", e)
+         _pod_state["status"] = "running"
          raise HTTPException(500, f"Failed to stop pod: {e}")

      if _pod_state["status"] != "running" or not _pod_state["ip"]:
          return {"loras": [], "message": "Pod not running"}

+     comfyui_url = _get_comfyui_url()
      try:
          import httpx
          async with httpx.AsyncClient(timeout=30) as client:
+             url = f"{comfyui_url}/object_info/LoraLoader"
              resp = await client.get(url)
              if resp.status_code == 200:
                  data = resp.json()

      except Exception as e:
          logger.warning("Failed to list pod LoRAs: %s", e)

+     return {"loras": [], "comfyui_url": comfyui_url}


  @router.post("/upload-lora")
+ async def upload_lora_to_pod(file: UploadFile = File(...)):
      """Upload a LoRA file to the running pod."""
      import httpx

      if _pod_state["status"] != "running":

      try:
          content = await file.read()
+         comfyui_url = _get_comfyui_url()

          async with httpx.AsyncClient(timeout=120) as client:
+             url = f"{comfyui_url}/upload/image"
              files = {"image": (file.filename, content, "application/octet-stream")}
              data = {"subfolder": "loras", "type": "input"}
              resp = await client.post(url, files=files, data=data)

          if resp.status_code == 200:

      job_id = str(uuid.uuid4())[:8]
      seed = request.seed if request.seed >= 0 else random.randint(0, 2**32 - 1)

+     model_type = _pod_state.get("model_type", "flux2")
      workflow = _build_flux_workflow(
          prompt=request.prompt,
          negative_prompt=request.negative_prompt,

          seed=seed,
          lora_name=request.lora_name,
          lora_strength=request.lora_strength,
+         model_type=model_type,
      )

+     comfyui_url = _get_comfyui_url()
+
      try:
          async with httpx.AsyncClient(timeout=30) as client:
+             resp = await client.post(f"{comfyui_url}/prompt", json={"prompt": workflow})
              resp.raise_for_status()
              data = resp.json()

          }

          logger.info("Pod generation started: %s -> %s", job_id, prompt_id)
          asyncio.create_task(_poll_pod_job(job_id, prompt_id, request.content_rating))

+         return {"job_id": job_id, "status": "running", "seed": seed}

      except Exception as e:
          logger.error("Pod generation failed: %s", e)

  async def _poll_pod_job(job_id: str, prompt_id: str, content_rating: str):
      """Poll ComfyUI for job completion and save the result."""
      import httpx

      start = time.time()
+     timeout = 600  # 10 min; the first generation can take 5+ min for model loading
+     comfyui_url = _get_comfyui_url()

      async with httpx.AsyncClient(timeout=60) as client:
          while time.time() - start < timeout:
              try:
+                 resp = await client.get(f"{comfyui_url}/history/{prompt_id}")
                  if resp.status_code == 200:
                      data = resp.json()
                      if prompt_id in data:
                          outputs = data[prompt_id].get("outputs", {})

                          for node_id, node_output in outputs.items():
                              if "images" in node_output:
                                  image_info = node_output["images"][0]
                                  filename = image_info["filename"]
                                  subfolder = image_info.get("subfolder", "")

                                  params = {"filename": filename}
                                  if subfolder:
                                      params["subfolder"] = subfolder

+                                 img_resp = await client.get(f"{comfyui_url}/view", params=params)
                                  if img_resp.status_code == 200:
                                      from content_engine.config import settings
                                      output_dir = settings.paths.output_dir / "pod" / content_rating / "raw"
                                      output_dir.mkdir(parents=True, exist_ok=True)

      logger.info("Pod generation completed: %s -> %s", job_id, local_path)

      try:
          from content_engine.services.catalog import CatalogService
          catalog = CatalogService()

      seed: int,
      lora_name: str | None,
      lora_strength: float,
+     model_type: str = "flux2",
  ) -> dict:
+     """Build a ComfyUI workflow for FLUX generation.
+
+     FLUX.2 Dev uses separate model components (not a single checkpoint):
+     - UNETLoader for the diffusion model
+     - CLIPLoader (type=flux2) for the Mistral text encoder
+     - VAELoader for the autoencoder
+     """
+
+     if model_type == "flux2":
+         unet_name = "flux2-dev.safetensors"
+         clip_type = "flux2"
+         clip_name = "mistral_3_small_flux2_fp8.safetensors"
+     else:
+         unet_name = "flux1-dev.safetensors"
+         clip_type = "flux"
+         clip_name = "t5xxl_fp16.safetensors"
+
+     # Model node ID references
+     model_out = ["1", 0]  # UNETLoader -> MODEL
+     clip_out = ["2", 0]   # CLIPLoader -> CLIP
+     vae_out = ["3", 0]    # VAELoader -> VAE

      workflow = {
+         # Load diffusion model (UNet)
+         "1": {
+             "class_type": "UNETLoader",
+             "inputs": {
+                 "unet_name": unet_name,
+                 "weight_dtype": "fp8_e4m3fn",
+             },
+         },
+         # Load text encoder
+         "2": {
+             "class_type": "CLIPLoader",
+             "inputs": {
+                 "clip_name": clip_name,
+                 "type": clip_type,
+             },
+         },
+         # Load VAE
+         "3": {
+             "class_type": "VAELoader",
+             "inputs": {"vae_name": "ae.safetensors"},
          },
+         # Positive prompt
          "6": {
              "class_type": "CLIPTextEncode",
              "inputs": {
                  "text": prompt,
+                 "clip": clip_out,
              },
          },
+         # Negative prompt
          "7": {
              "class_type": "CLIPTextEncode",
              "inputs": {
                  "text": negative_prompt or "",
+                 "clip": clip_out,
              },
          },
+         # Empty latent
          "5": {
              "class_type": "EmptyLatentImage",
              "inputs": {

                  "batch_size": 1,
              },
          },
+         # Sampler
+         "10": {
              "class_type": "KSampler",
              "inputs": {
                  "seed": seed,

                  "sampler_name": "euler",
                  "scheduler": "simple",
                  "denoise": 1.0,
+                 "model": model_out,
                  "positive": ["6", 0],
                  "negative": ["7", 0],
                  "latent_image": ["5", 0],
              },
          },
+         # Decode
          "8": {
              "class_type": "VAEDecode",
              "inputs": {
+                 "samples": ["10", 0],
+                 "vae": vae_out,
              },
          },
+         # Save
          "9": {
              "class_type": "SaveImage",
              "inputs": {

      # Add LoRA if specified
      if lora_name:
+         workflow["20"] = {
              "class_type": "LoraLoader",
              "inputs": {
                  "lora_name": lora_name,
                  "strength_model": lora_strength,
                  "strength_clip": lora_strength,
+                 "model": model_out,
+                 "clip": clip_out,
              },
          }
+         # Rewire sampler and text encoders to use the LoRA output
+         workflow["10"]["inputs"]["model"] = ["20", 0]
+         workflow["6"]["inputs"]["clip"] = ["20", 1]
+         workflow["7"]["inputs"]["clip"] = ["20", 1]

      return workflow
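The LoRA rewiring at the end of `_build_flux_workflow` can be checked in isolation. The sketch below mirrors that wiring on a minimal graph; the node IDs and input keys match the workflow above, but the `rewire_lora` helper itself is illustrative and not part of the commit:

```python
def rewire_lora(workflow: dict, lora_name: str, strength: float) -> dict:
    """Mirror of the LoRA wiring in _build_flux_workflow (illustrative).

    Inserts a LoraLoader node "20" fed by the base model ("1") and CLIP
    ("2"), then repoints the sampler ("10") at the LoRA's MODEL output
    and both text encoders ("6", "7") at its CLIP output.
    """
    workflow["20"] = {
        "class_type": "LoraLoader",
        "inputs": {
            "lora_name": lora_name,
            "strength_model": strength,
            "strength_clip": strength,
            "model": ["1", 0],
            "clip": ["2", 0],
        },
    }
    workflow["10"]["inputs"]["model"] = ["20", 0]  # sampler takes LoRA MODEL
    workflow["6"]["inputs"]["clip"] = ["20", 1]    # positive prompt takes LoRA CLIP
    workflow["7"]["inputs"]["clip"] = ["20", 1]    # negative prompt takes LoRA CLIP
    return workflow


# Minimal graph containing only the nodes the rewiring touches
graph = {
    "10": {"inputs": {"model": ["1", 0]}},
    "6": {"inputs": {"clip": ["2", 0]}},
    "7": {"inputs": {"clip": ["2", 0]}},
}
graph = rewire_lora(graph, "my_lora.safetensors", 0.8)
```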
src/content_engine/api/routes_training.py CHANGED
@@ -213,8 +213,10 @@ async def list_training_jobs():
              "loss": j.loss, "started_at": j.started_at,
              "completed_at": j.completed_at, "output_path": j.output_path,
              "error": j.error, "backend": "local",
+             "log_lines": j.log_lines[-50:] if hasattr(j, 'log_lines') else [],
          })
      if _runpod_trainer:
+         await _runpod_trainer.ensure_loaded()
          for j in _runpod_trainer.list_jobs():
              jobs.append({
                  "id": j.id, "name": j.name, "status": j.status,
@@ -225,6 +227,7 @@ async def list_training_jobs():
                  "completed_at": j.completed_at, "output_path": j.output_path,
                  "error": j.error, "backend": "runpod",
                  "base_model": j.base_model, "model_type": j.model_type,
+                 "log_lines": j.log_lines[-50:],
              })
      return jobs

@@ -232,27 +235,35 @@ async def list_training_jobs():
  @router.get("/jobs/{job_id}")
  async def get_training_job(job_id: str):
      """Get details of a specific training job including logs."""
-     if _trainer is None:
-         raise HTTPException(503, "Trainer not initialized")
-     job = _trainer.get_job(job_id)
-     if job is None:
-         raise HTTPException(404, f"Training job not found: {job_id}")
-     return {
-         "id": job.id,
-         "name": job.name,
-         "status": job.status,
-         "progress": round(job.progress, 3),
-         "current_epoch": job.current_epoch,
-         "total_epochs": job.total_epochs,
-         "current_step": job.current_step,
-         "total_steps": job.total_steps,
-         "loss": job.loss,
-         "started_at": job.started_at,
-         "completed_at": job.completed_at,
-         "output_path": job.output_path,
-         "error": job.error,
-         "log_lines": job.log_lines[-50:],
-     }
+     # Check RunPod jobs first
+     if _runpod_trainer:
+         await _runpod_trainer.ensure_loaded()
+         job = _runpod_trainer.get_job(job_id)
+         if job:
+             return {
+                 "id": job.id, "name": job.name, "status": job.status,
+                 "progress": round(job.progress, 3),
+                 "current_epoch": job.current_epoch, "total_epochs": job.total_epochs,
+                 "current_step": job.current_step, "total_steps": job.total_steps,
+                 "loss": job.loss, "started_at": job.started_at,
+                 "completed_at": job.completed_at, "output_path": job.output_path,
+                 "error": job.error, "log_lines": job.log_lines[-50:],
+                 "backend": "runpod", "base_model": job.base_model,
+             }
+     # Then check local trainer
+     if _trainer:
+         job = _trainer.get_job(job_id)
+         if job:
+             return {
+                 "id": job.id, "name": job.name, "status": job.status,
+                 "progress": round(job.progress, 3),
+                 "current_epoch": job.current_epoch, "total_epochs": job.total_epochs,
+                 "current_step": job.current_step, "total_steps": job.total_steps,
+                 "loss": job.loss, "started_at": job.started_at,
+                 "completed_at": job.completed_at, "output_path": job.output_path,
+                 "error": job.error, "log_lines": job.log_lines[-50:],
+             }
+     raise HTTPException(404, f"Training job not found: {job_id}")


  @router.post("/jobs/{job_id}/cancel")
@@ -267,3 +278,22 @@ async def cancel_training_job(job_id: str):
      if cancelled:
          return {"status": "cancelled", "job_id": job_id}
      raise HTTPException(404, "Job not found or not running")
+
+
+ @router.delete("/jobs/{job_id}")
+ async def delete_training_job(job_id: str):
+     """Delete a training job from history."""
+     if _runpod_trainer:
+         deleted = await _runpod_trainer.delete_job(job_id)
+         if deleted:
+             return {"status": "deleted", "job_id": job_id}
+     raise HTTPException(404, f"Training job not found: {job_id}")
+
+
+ @router.delete("/jobs")
+ async def delete_failed_jobs():
+     """Delete all failed training jobs."""
+     if _runpod_trainer:
+         count = await _runpod_trainer.delete_failed_jobs()
+         return {"status": "ok", "deleted": count}
+     return {"status": "ok", "deleted": 0}
src/content_engine/api/ui.html CHANGED
@@ -636,6 +636,25 @@ select { cursor: pointer; }
636
  .job-status-completed { background: rgba(34,197,94,0.15); color: var(--green); }
637
  .job-status-failed { background: rgba(239,68,68,0.15); color: var(--red); }
638
  .job-status-pending { background: rgba(136,136,136,0.15); color: var(--text-secondary); }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
639
 
640
  /* --- Lightbox --- */
641
  .lightbox {
@@ -1245,10 +1264,22 @@ select { cursor: pointer; }
1245
  <div style="margin-top:8px">
1246
  <label style="margin:0">GPU Type</label>
1247
  <select id="train-gpu-type" style="margin-top:4px">
1248
- <option value="NVIDIA GeForce RTX 4090">RTX 4090 (~$0.44/hr) - Fastest</option>
1249
- <option value="NVIDIA GeForce RTX 3090">RTX 3090 (~$0.22/hr) - Good value</option>
1250
- <option value="NVIDIA RTX A5000">RTX A5000 (~$0.28/hr)</option>
1251
- <option value="NVIDIA RTX A4000">RTX A4000 (~$0.20/hr) - Cheapest</option>
 
 
 
 
 
 
 
 
 
 
 
 
1252
  </select>
1253
  </div>
1254
  </div>
@@ -1379,13 +1410,27 @@ select { cursor: pointer; }
1379
  </div>
1380
  <div id="pod-controls" style="display:flex; gap:8px; align-items:center; flex-wrap:wrap">
1381
  <select id="pod-model-type" style="padding:8px 12px; border-radius:6px; background:var(--bg-primary); border:1px solid var(--border); color:var(--text-primary)">
1382
- <option value="flux">FLUX.2 (Realistic)</option>
1383
- <option value="wan">WAN 2.2 (General/Anime)</option>
 
1384
  </select>
1385
  <select id="pod-gpu-select" style="padding:8px 12px; border-radius:6px; background:var(--bg-primary); border:1px solid var(--border); color:var(--text-primary)">
1386
- <option value="NVIDIA GeForce RTX 4090">RTX 4090 - $0.44/hr (24GB)</option>
1387
- <option value="NVIDIA RTX A6000">RTX A6000 - $0.76/hr (48GB)</option>
1388
- <option value="NVIDIA A100 80GB PCIe">A100 80GB - $1.89/hr (80GB)</option>
 
 
 
 
 
 
 
 
 
 
 
 
 
1389
  </select>
1390
  <button id="pod-start-btn" class="btn" onclick="startPod()">Start Pod</button>
1391
  <button id="pod-stop-btn" class="btn btn-danger" onclick="stopPod()" style="display:none">Stop Pod</button>
@@ -2761,9 +2806,10 @@ function updateModelDefaults() {
2761
  <span style="color:var(--accent)">Resolution: ${model.resolution}px | LR: ${model.learning_rate} | Rank: ${model.network_rank} | VRAM: ${model.vram_required_gb}GB</span>
2762
  `;
2763
 
2764
- // Update placeholder hints
2765
  document.getElementById('train-lr').placeholder = `Default: ${model.learning_rate}`;
2766
  document.getElementById('lr-default').textContent = `(default: ${model.learning_rate})`;
 
2767
 
2768
  // Update resolution default
2769
  const resSelect = document.getElementById('train-resolution');
@@ -2782,6 +2828,16 @@ function updateModelDefaults() {
2782
  break;
2783
  }
2784
  }
 
 
 
 
 
 
 
 
 
 
2785
  }
2786
 
2787
  function selectTrainBackend(chip, backend) {
@@ -2899,7 +2955,7 @@ async function pollTrainingJobs() {
2899
  renderTrainingJobs(jobs);
2900
 
2901
  // Stop polling if no active jobs
2902
- const active = jobs.filter(j => j.status === 'training' || j.status === 'preparing');
2903
  if (active.length === 0 && trainingPollInterval) {
2904
  clearInterval(trainingPollInterval);
2905
  trainingPollInterval = null;
@@ -2913,35 +2969,93 @@ function renderTrainingJobs(jobs) {
2913
  container.innerHTML = `<div class="empty-state" style="padding:30px"><p>No training jobs yet</p><p style="font-size:12px;margin-top:4px">Upload images and configure settings to start training</p></div>`;
2914
  return;
2915
  }
2916
- container.innerHTML = jobs.map(j => {
 
2917
  const pct = (j.progress * 100).toFixed(1);
2918
  const elapsed = j.started_at ? ((Date.now()/1000 - j.started_at) / 60).toFixed(0) : '?';
 
 
2919
  return `
2920
- <div class="job-card">
2921
  <div class="job-header">
2922
  <span class="job-name">${j.name} ${j.backend === 'runpod' ? '<span style="font-size:10px;color:var(--blue);font-weight:400">☁ RunPod</span>' : ''}</span>
2923
  <span class="job-status job-status-${j.status}">${j.status}</span>
2924
  </div>
2925
- ${['training','preparing','creating_pod','uploading','installing','downloading'].includes(j.status) ? `
2926
  <div class="progress-bar-container" style="margin-top:0">
2927
  <div class="progress-bar-fill" style="width:${pct}%"></div>
2928
  </div>
2929
  <div style="display:flex;gap:16px;margin-top:8px;font-size:12px;color:var(--text-secondary)">
2930
  <span>Progress: <strong style="color:var(--text-primary)">${pct}%</strong></span>
2931
  ${j.current_step ? `<span>Step: <strong style="color:var(--text-primary)">${j.current_step}/${j.total_steps}</strong></span>` : ''}
2932
- ${j.loss !== null ? `<span>Loss: <strong style="color:var(--text-primary)">${j.loss.toFixed(4)}</strong></span>` : ''}
2933
  <span>Time: <strong style="color:var(--text-primary)">${elapsed}m</strong></span>
2934
  </div>
2935
- <button class="btn btn-secondary btn-small" style="margin-top:8px" onclick="cancelTraining('${j.id}')">Cancel</button>
 
 
 
2936
  ` : ''}
2937
  ${j.status === 'completed' ? `
2938
  <div style="font-size:12px;color:var(--green);margin-top:4px">LoRA saved to ComfyUI models folder</div>
2939
  ${j.output_path ? `<div style="font-size:11px;color:var(--text-secondary);margin-top:2px;word-break:break-all">${j.output_path}</div>` : ''}
 
2940
  ` : ''}
2941
- ${j.status === 'failed' && j.error ? `<div style="font-size:12px;color:var(--red);margin-top:4px">${j.error}</div>` : ''}
 
 
2942
  </div>
2943
  `;
2944
  }).join('');
 
2945
  }
2946
 
2947
  async function cancelTraining(jobId) {
@@ -2954,6 +3068,27 @@ async function cancelTraining(jobId) {
2954
  }
2955
  }
2956
 
 
2957
  // --- Status ---
2958
  async function checkStatus() {
2959
  try {
@@ -3085,10 +3220,11 @@ function updatePodUI(pod) {
3085
  if (!podPollInterval) {
3086
  podPollInterval = setInterval(loadPodStatus, 30000);
3087
  }
3088
- } else if (pod.status === 'starting') {
3089
- statusText.innerHTML = '<span style="color:var(--yellow)">● Starting...</span> <span style="color:var(--text-secondary)">(~2-3 min)</span>';
 
3090
  startBtn.style.display = 'none';
3091
- stopBtn.style.display = 'none';
3092
  gpuSelect.style.display = 'none';
3093
  podInfo.style.display = 'none';
3094
  // Poll more frequently while starting
 
636
  .job-status-completed { background: rgba(34,197,94,0.15); color: var(--green); }
637
  .job-status-failed { background: rgba(239,68,68,0.15); color: var(--red); }
638
  .job-status-pending { background: rgba(136,136,136,0.15); color: var(--text-secondary); }
639
+ .job-logs-panel {
640
+ margin-top: 8px;
641
+ border-top: 1px solid var(--border);
642
+ padding-top: 8px;
643
+ }
644
+ .job-logs-content {
645
+ background: var(--bg-primary);
646
+ border: 1px solid var(--border);
647
+ border-radius: 6px;
648
+ padding: 8px 10px;
649
+ font-family: monospace;
650
+ font-size: 11px;
651
+ line-height: 1.5;
652
+ max-height: 300px;
653
+ overflow-y: auto;
654
+ white-space: pre-wrap;
655
+ word-break: break-all;
656
+ color: var(--text-secondary);
657
+ }
658
 
659
  /* --- Lightbox --- */
660
  .lightbox {
 
1264
  <div style="margin-top:8px">
1265
  <label style="margin:0">GPU Type</label>
1266
  <select id="train-gpu-type" style="margin-top:4px">
1267
+ <optgroup label="48GB+ (Required for FLUX.2 Dev)">
1268
+ <option value="NVIDIA A40">A40 48GB (~$0.64/hr) - Cheapest for FLUX.2</option>
1269
+ <option value="NVIDIA RTX A6000" selected>RTX A6000 48GB (~$0.76/hr) - Recommended</option>
1270
+ <option value="NVIDIA L40">L40 48GB (~$0.89/hr)</option>
1271
+ <option value="NVIDIA L40S">L40S 48GB (~$1.09/hr)</option>
1272
+ <option value="NVIDIA A100-SXM4-80GB">A100 SXM 80GB (~$1.64/hr)</option>
1273
+ <option value="NVIDIA A100 80GB PCIe">A100 PCIe 80GB (~$1.89/hr)</option>
1274
+ <option value="NVIDIA H100 80GB HBM3">H100 80GB (~$3.89/hr) - Fastest</option>
1275
+ </optgroup>
1276
+ <optgroup label="16-32GB (SD 1.5, SDXL, FLUX.1 only)">
1277
+ <option value="NVIDIA GeForce RTX 5090">RTX 5090 32GB (~$0.69/hr)</option>
1278
+ <option value="NVIDIA GeForce RTX 4090">RTX 4090 24GB (~$0.44/hr)</option>
1279
+ <option value="NVIDIA GeForce RTX 3090">RTX 3090 24GB (~$0.22/hr)</option>
1280
+ <option value="NVIDIA RTX A5000">RTX A5000 24GB (~$0.28/hr)</option>
1281
+ <option value="NVIDIA RTX A4000">RTX A4000 16GB (~$0.20/hr)</option>
1282
+ </optgroup>
1283
  </select>
1284
  </div>
1285
  </div>
 
1410
  </div>
1411
  <div id="pod-controls" style="display:flex; gap:8px; align-items:center; flex-wrap:wrap">
1412
  <select id="pod-model-type" style="padding:8px 12px; border-radius:6px; background:var(--bg-primary); border:1px solid var(--border); color:var(--text-primary)">
1413
+ <option value="flux2">FLUX.2 Dev (Realistic txt2img)</option>
1414
+ <option value="flux1">FLUX.1 Dev (txt2img)</option>
1415
+ <option value="wan22">WAN 2.2 (img2video)</option>
1416
  </select>
1417
  <select id="pod-gpu-select" style="padding:8px 12px; border-radius:6px; background:var(--bg-primary); border:1px solid var(--border); color:var(--text-primary)">
1418
+ <optgroup label="48GB+ (FLUX.2 / Large models)">
1419
+ <option value="NVIDIA A40">A40 48GB - $0.64/hr</option>
1420
+ <option value="NVIDIA RTX A6000" selected>A6000 48GB - $0.76/hr</option>
1421
+ <option value="NVIDIA L40">L40 48GB - $0.89/hr</option>
1422
+ <option value="NVIDIA L40S">L40S 48GB - $1.09/hr</option>
1423
+ <option value="NVIDIA A100-SXM4-80GB">A100 SXM 80GB - $1.64/hr</option>
1424
+ <option value="NVIDIA A100 80GB PCIe">A100 PCIe 80GB - $1.89/hr</option>
1425
+ <option value="NVIDIA H100 80GB HBM3">H100 80GB - $3.89/hr</option>
1426
+ </optgroup>
1427
+ <optgroup label="16-32GB (SD 1.5 / SDXL / FLUX.1)">
1428
+ <option value="NVIDIA GeForce RTX 5090">RTX 5090 32GB - $0.69/hr</option>
1429
+ <option value="NVIDIA GeForce RTX 4090">RTX 4090 24GB - $0.44/hr</option>
1430
+ <option value="NVIDIA GeForce RTX 3090">RTX 3090 24GB - $0.22/hr</option>
1431
+ <option value="NVIDIA RTX A5000">A5000 24GB - $0.28/hr</option>
1432
+ <option value="NVIDIA RTX A4000">A4000 16GB - $0.20/hr</option>
1433
+ </optgroup>
1434
  </select>
1435
  <button id="pod-start-btn" class="btn" onclick="startPod()">Start Pod</button>
1436
  <button id="pod-stop-btn" class="btn btn-danger" onclick="stopPod()" style="display:none">Stop Pod</button>
 
2806
  <span style="color:var(--accent)">Resolution: ${model.resolution}px | LR: ${model.learning_rate} | Rank: ${model.network_rank} | VRAM: ${model.vram_required_gb}GB</span>
2807
  `;
2808
 
2809
+ // Update placeholder hints and auto-fill LR
2810
  document.getElementById('train-lr').placeholder = `Default: ${model.learning_rate}`;
2811
  document.getElementById('lr-default').textContent = `(default: ${model.learning_rate})`;
2812
+ document.getElementById('train-lr').value = model.learning_rate;
2813
 
2814
  // Update resolution default
2815
  const resSelect = document.getElementById('train-resolution');
 
2828
  break;
2829
  }
2830
  }
2831
+
2832
+ // Update optimizer default
2833
+ const optSelect = document.getElementById('train-optimizer');
2834
+ const optName = (model.optimizer || 'AdamW8bit').toLowerCase();
2835
+ for (let opt of optSelect.options) {
2836
+ if (opt.value.toLowerCase() === optName) {
2837
+ opt.selected = true;
2838
+ break;
2839
+ }
2840
+ }
2841
  }
2842
 
2843
  function selectTrainBackend(chip, backend) {
 
2955
  renderTrainingJobs(jobs);
2956
 
2957
  // Stop polling if no active jobs
2958
+ const active = jobs.filter(j => ['training','preparing','creating_pod','uploading','installing','downloading','pending'].includes(j.status));
2959
  if (active.length === 0 && trainingPollInterval) {
2960
  clearInterval(trainingPollInterval);
2961
  trainingPollInterval = null;
 
2969
  container.innerHTML = `<div class="empty-state" style="padding:30px"><p>No training jobs yet</p><p style="font-size:12px;margin-top:4px">Upload images and configure settings to start training</p></div>`;
2970
  return;
2971
  }
2972
+
2973
+ // Store latest jobs for log viewer
2974
+ window._trainingJobs = jobs;
2975
+
2976
+ // Show "Clear Failed" button if there are any failed jobs
2977
+ const failedCount = jobs.filter(j => j.status === 'failed' || j.status === 'error').length;
2978
+ let html = failedCount > 0 ? `<div style="text-align:right;margin-bottom:8px"><button class="btn btn-secondary btn-small" style="color:var(--red)" onclick="clearFailedJobs()">Clear ${failedCount} Failed</button></div>` : '';
2979
+
2980
+ html += jobs.map(j => {
2981
  const pct = (j.progress * 100).toFixed(1);
2982
  const elapsed = j.started_at ? ((Date.now()/1000 - j.started_at) / 60).toFixed(0) : '?';
2983
+ const isActive = ['training','preparing','creating_pod','uploading','installing','downloading'].includes(j.status);
2984
+ const hasLogs = j.log_lines && j.log_lines.length > 0;
2985
  return `
2986
+ <div class="job-card" id="job-card-${j.id}">
2987
  <div class="job-header">
2988
  <span class="job-name">${j.name} ${j.backend === 'runpod' ? '<span style="font-size:10px;color:var(--blue);font-weight:400">☁ RunPod</span>' : ''}</span>
2989
  <span class="job-status job-status-${j.status}">${j.status}</span>
2990
  </div>
2991
+ ${isActive ? `
2992
  <div class="progress-bar-container" style="margin-top:0">
2993
  <div class="progress-bar-fill" style="width:${pct}%"></div>
2994
  </div>
2995
  <div style="display:flex;gap:16px;margin-top:8px;font-size:12px;color:var(--text-secondary)">
2996
  <span>Progress: <strong style="color:var(--text-primary)">${pct}%</strong></span>
2997
  ${j.current_step ? `<span>Step: <strong style="color:var(--text-primary)">${j.current_step}/${j.total_steps}</strong></span>` : ''}
2998
+ ${j.loss !== null && j.loss !== undefined ? `<span>Loss: <strong style="color:var(--text-primary)">${j.loss.toFixed(4)}</strong></span>` : ''}
2999
  <span>Time: <strong style="color:var(--text-primary)">${elapsed}m</strong></span>
3000
  </div>
3001
+ <div style="display:flex;gap:6px;margin-top:8px">
3002
+ <button class="btn btn-secondary btn-small" onclick="toggleJobLogs('${j.id}')">View Logs</button>
3003
+ <button class="btn btn-secondary btn-small" style="color:var(--red)" onclick="cancelTraining('${j.id}')">Cancel</button>
3004
+ </div>
3005
  ` : ''}
3006
  ${j.status === 'completed' ? `
3007
  <div style="font-size:12px;color:var(--green);margin-top:4px">LoRA saved to ComfyUI models folder</div>
3008
  ${j.output_path ? `<div style="font-size:11px;color:var(--text-secondary);margin-top:2px;word-break:break-all">${j.output_path}</div>` : ''}
3009
+ <button class="btn btn-secondary btn-small" style="margin-top:6px" onclick="toggleJobLogs('${j.id}')">View Logs</button>
3010
+ ` : ''}
3011
+ ${j.status === 'failed' ? `
3012
+ ${j.error ? `<div style="font-size:12px;color:var(--red);margin-top:4px">${j.error}</div>` : ''}
3013
+ <div style="display:flex;gap:6px;margin-top:6px">
3014
+ <button class="btn btn-secondary btn-small" onclick="toggleJobLogs('${j.id}')">View Logs</button>
3015
+ <button class="btn btn-secondary btn-small" style="color:var(--red)" onclick="deleteJob('${j.id}')">Delete</button>
3016
+ </div>
3017
  ` : ''}
3018
+ <div id="job-logs-${j.id}" class="job-logs-panel" style="display:none">
3019
+ <div class="job-logs-content"></div>
3020
+ </div>
3021
  </div>
3022
  `;
3023
  }).join('');
3024
+
3025
+ container.innerHTML = html;
3026
+
3027
+ // Auto-show logs for active jobs
3028
+ const activeJob = jobs.find(j => ['training','preparing','creating_pod','uploading','installing','downloading'].includes(j.status));
3029
+ if (activeJob && activeJob.log_lines && activeJob.log_lines.length > 0) {
3030
+ showJobLogs(activeJob.id);
3031
+ }
3032
+ }
3033
+
3034
+ function toggleJobLogs(jobId) {
3035
+ const panel = document.getElementById('job-logs-' + jobId);
3036
+ if (!panel) return;
3037
+ if (panel.style.display === 'none') {
3038
+ showJobLogs(jobId);
3039
+ } else {
3040
+ panel.style.display = 'none';
3041
+ }
3042
+ }
3043
+
3044
+ function showJobLogs(jobId) {
3045
+ const panel = document.getElementById('job-logs-' + jobId);
3046
+ if (!panel) return;
3047
+ panel.style.display = 'block';
3048
+
3049
+ // Find job data
3050
+ const job = (window._trainingJobs || []).find(j => j.id === jobId);
3051
+ if (!job || !job.log_lines) {
3052
+ panel.querySelector('.job-logs-content').textContent = 'No logs available';
3053
+ return;
3054
+ }
3055
+
3056
+ const content = panel.querySelector('.job-logs-content');
3057
+ content.textContent = job.log_lines.join('\n');
3058
+ content.scrollTop = content.scrollHeight;
3059
  }
3060
 
3061
  async function cancelTraining(jobId) {
 
3068
  }
3069
  }
3070
 
3071
+ async function deleteJob(jobId) {
3072
+ try {
3073
+ await fetch(API + `/api/training/jobs/${jobId}`, { method: 'DELETE' });
3074
+ toast('Job deleted', 'info');
3075
+ pollTrainingJobs();
3076
+ } catch(e) {
3077
+ toast('Failed to delete job', 'error');
3078
+ }
3079
+ }
3080
+
3081
+ async function clearFailedJobs() {
3082
+ try {
3083
+ const res = await fetch(API + '/api/training/jobs', { method: 'DELETE' });
3084
+ const data = await res.json();
3085
+ toast(`Cleared ${data.deleted} failed jobs`, 'info');
3086
+ pollTrainingJobs();
3087
+ } catch(e) {
3088
+ toast('Failed to clear jobs', 'error');
3089
+ }
3090
+ }
3091
+
3092
  // --- Status ---
3093
  async function checkStatus() {
3094
  try {
 
3220
  if (!podPollInterval) {
3221
  podPollInterval = setInterval(loadPodStatus, 30000);
3222
  }
3223
+ } else if (pod.status === 'starting' || pod.status === 'setting_up') {
3224
+ const setupMsg = pod.setup_status || 'Starting pod...';
3225
+ statusText.innerHTML = `<span style="color:var(--yellow)">● ${setupMsg}</span>`;
3226
  startBtn.style.display = 'none';
3227
+ stopBtn.style.display = ''; // Allow stopping during setup
3228
  gpuSelect.style.display = 'none';
3229
  podInfo.style.display = 'none';
3230
  // Poll more frequently while starting
src/content_engine/models/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (802 Bytes)
 
src/content_engine/models/__pycache__/database.cpython-311.pyc DELETED
Binary file (10.7 kB)
 
src/content_engine/models/__pycache__/schemas.cpython-311.pyc DELETED
Binary file (5.59 kB)
 
src/content_engine/models/database.py CHANGED
@@ -2,9 +2,7 @@
2
 
3
  from __future__ import annotations
4
 
5
- import os
6
  from datetime import datetime
7
- from pathlib import Path
8
 
9
  from sqlalchemy import (
10
  Boolean,
@@ -149,12 +147,36 @@ class ScheduledPost(Base):
149
  )
150
 
151
 
152
- # --- Engine / Session factories ---
 
153
 
154
- # Ensure database directory exists before creating engine
155
- _db_path = settings.database.url.replace("sqlite+aiosqlite:///", "")
156
- _db_dir = Path(_db_path).parent
157
- _db_dir.mkdir(parents=True, exist_ok=True)
 
 
158
 
159
  _catalog_engine = create_async_engine(
160
  settings.database.url,
 
2
 
3
  from __future__ import annotations
4
 
 
5
  from datetime import datetime
 
6
 
7
  from sqlalchemy import (
8
  Boolean,
 
147
  )
148
 
149
 
150
+ class TrainingJob(Base):
151
+ __tablename__ = "training_jobs"
152
 
153
+ id: Mapped[str] = mapped_column(String(36), primary_key=True)
154
+ name: Mapped[str] = mapped_column(String(128), nullable=False)
155
+ status: Mapped[str] = mapped_column(String(32), default="pending", index=True)
156
+ progress: Mapped[float] = mapped_column(Float, default=0.0)
157
+ current_epoch: Mapped[int] = mapped_column(Integer, default=0)
158
+ total_epochs: Mapped[int] = mapped_column(Integer, default=0)
159
+ current_step: Mapped[int] = mapped_column(Integer, default=0)
160
+ total_steps: Mapped[int] = mapped_column(Integer, default=0)
161
+ loss: Mapped[float | None] = mapped_column(Float)
162
+ started_at: Mapped[float | None] = mapped_column(Float)
163
+ completed_at: Mapped[float | None] = mapped_column(Float)
164
+ output_path: Mapped[str | None] = mapped_column(String(512))
165
+ error: Mapped[str | None] = mapped_column(Text)
166
+ log_text: Mapped[str | None] = mapped_column(Text) # newline-separated log lines
167
+ pod_id: Mapped[str | None] = mapped_column(String(64))
168
+ gpu_type: Mapped[str | None] = mapped_column(String(64))
169
+ backend: Mapped[str] = mapped_column(String(16), default="runpod")
170
+ base_model: Mapped[str | None] = mapped_column(String(64))
171
+ model_type: Mapped[str | None] = mapped_column(String(16))
172
+ trigger_word: Mapped[str | None] = mapped_column(String(128))
173
+ image_upload_dir: Mapped[str | None] = mapped_column(String(512))
174
+ created_at: Mapped[datetime] = mapped_column(
175
+ DateTime, server_default=func.now()
176
+ )
177
+
178
+
179
+ # --- Engine / Session factories ---
180
 
181
  _catalog_engine = create_async_engine(
182
  settings.database.url,
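Review note: the new `TrainingJob` model stores logs in a single `log_text` column described as "newline-separated log lines", while the in-memory `CloudTrainingJob` keeps a `log_lines` list capped at 200 entries. A minimal sketch of the round-trip between the two representations (helper names are illustrative, not from this diff):

```python
from __future__ import annotations

# Mirrors the 200-line cap enforced by CloudTrainingJob._log().
MAX_LOG_LINES = 200


def logs_to_text(log_lines: list[str]) -> str:
    """Join the newest MAX_LOG_LINES entries into newline-separated text
    suitable for the TrainingJob.log_text column."""
    return "\n".join(log_lines[-MAX_LOG_LINES:])


def text_to_logs(log_text: str | None) -> list[str]:
    """Restore the list form the API exposes as `log_lines`."""
    return log_text.split("\n") if log_text else []
```

Keeping the cap on the write path bounds the column size, so restarted servers reload at most the last 200 lines per job.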
src/content_engine/services/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (232 Bytes)
 
src/content_engine/services/__pycache__/catalog.cpython-311.pyc DELETED
Binary file (13.1 kB)
 
src/content_engine/services/__pycache__/comfyui_client.cpython-311.pyc DELETED
Binary file (17.5 kB)
 
src/content_engine/services/__pycache__/lora_trainer.cpython-311.pyc DELETED
Binary file (19 kB)
 
src/content_engine/services/__pycache__/runpod_trainer.cpython-311.pyc DELETED
Binary file (32 kB)
 
src/content_engine/services/__pycache__/template_engine.cpython-311.pyc DELETED
Binary file (11 kB)
 
src/content_engine/services/__pycache__/variation_engine.cpython-311.pyc DELETED
Binary file (8.52 kB)
 
src/content_engine/services/__pycache__/workflow_builder.cpython-311.pyc DELETED
Binary file (7.16 kB)
 
src/content_engine/services/runpod_trainer.py CHANGED
@@ -27,6 +27,7 @@ logger = logging.getLogger(__name__)
27
 
28
  import os
29
  from content_engine.config import settings, IS_HF_SPACES
 
30
 
31
  LORA_OUTPUT_DIR = settings.paths.lora_dir
32
  if IS_HF_SPACES:
@@ -36,16 +37,30 @@ else:
36
 
37
  # RunPod GPU options (id -> display name, approx cost/hr)
38
  GPU_OPTIONS = {
39
- "NVIDIA GeForce RTX 3090": "RTX 3090 (~$0.22/hr)",
40
- "NVIDIA GeForce RTX 4090": "RTX 4090 (~$0.44/hr)",
41
- "NVIDIA RTX A4000": "RTX A4000 (~$0.20/hr)",
42
- "NVIDIA RTX A5000": "RTX A5000 (~$0.28/hr)",
43
- "NVIDIA RTX A6000": "RTX A6000 (~$0.76/hr)",
 
 
44
  "NVIDIA A100 80GB PCIe": "A100 80GB (~$1.89/hr)",
 
 
45
  }
46
 
47
  DEFAULT_GPU = "NVIDIA GeForce RTX 4090"
48
 
 
  # Docker image with PyTorch + CUDA pre-installed
50
  DOCKER_IMAGE = "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04"
51
 
@@ -84,12 +99,15 @@ class CloudTrainingJob:
84
  cost_estimate: str | None = None
85
  base_model: str = "sd15_realistic"
86
  model_type: str = "sd15"
 
87
 
88
  def _log(self, msg: str):
89
  self.log_lines.append(msg)
90
  if len(self.log_lines) > 200:
91
  self.log_lines = self.log_lines[-200:]
92
  logger.info("[%s] %s", self.id, msg)
 
 
93
 
94
 
95
  class RunPodTrainer:
@@ -100,6 +118,7 @@ class RunPodTrainer:
100
  runpod.api_key = api_key
101
  self._jobs: dict[str, CloudTrainingJob] = {}
102
  self._model_registry = load_model_registry()
 
103
 
104
  @property
105
  def available(self) -> bool:
@@ -123,6 +142,8 @@ class RunPodTrainer:
123
  "learning_rate": cfg.get("learning_rate", 1e-4),
124
  "network_rank": cfg.get("network_rank", 32),
125
  "network_alpha": cfg.get("network_alpha", 16),
 
 
126
  "vram_required_gb": cfg.get("vram_required_gb", 8),
127
  "recommended_images": cfg.get("recommended_images", "15-30 photos"),
128
  }
@@ -179,6 +200,8 @@ class RunPodTrainer:
179
  model_type=model_type,
180
  )
181
  self._jobs[job_id] = job
 
 
182
 
183
  # Launch the full pipeline as a background task
184
  asyncio.create_task(self._run_cloud_training(
@@ -224,15 +247,26 @@ class RunPodTrainer:
224
  job.status = "creating_pod"
225
  job._log(f"Creating RunPod with {job.gpu_type}...")
226
 
 
227
  pod = await asyncio.to_thread(
228
  runpod.create_pod,
229
  f"lora-train-{job.id}",
230
  DOCKER_IMAGE,
231
  job.gpu_type,
232
- volume_in_gb=75, # Increased for FLUX models
233
- container_disk_in_gb=30,
234
- ports="22/tcp",
235
- docker_args="bash -c 'apt-get update && apt-get install -y openssh-server && mkdir -p /run/sshd && echo \"root:runpod\" | chpasswd && echo \"PermitRootLogin yes\" >> /etc/ssh/sshd_config && /usr/sbin/sshd && sleep infinity'",
236
  )
237
 
238
  job.pod_id = pod["id"]
@@ -263,89 +297,273 @@ class RunPodTrainer:
263
  await asyncio.sleep(5)
264
 
265
  job._log("SSH connected")
 
266
  sftp = ssh.open_sftp()
 
267
 
268
- # Step 3: Upload training images
269
  job.status = "uploading"
270
- job._log(f"Uploading {len(image_paths)} training images...")
 
 
272
  folder_name = f"10_{trigger_word or 'character'}"
273
  self._ssh_exec(ssh, f"mkdir -p /workspace/dataset/{folder_name}")
274
- for img_path in image_paths:
275
  p = Path(img_path)
276
  if p.exists():
277
- remote_path = f"/workspace/dataset/{folder_name}/{p.name}"
278
- sftp.put(str(p), remote_path)
 
279
  # Upload matching caption .txt file if it exists locally
280
  local_caption = p.with_suffix(".txt")
281
  if local_caption.exists():
282
- remote_caption = f"/workspace/dataset/{folder_name}/{local_caption.name}"
283
  sftp.put(str(local_caption), remote_caption)
284
  else:
285
  # Fallback: create caption from trigger word
286
- remote_caption = remote_path.rsplit(".", 1)[0] + ".txt"
287
  with sftp.open(remote_caption, "w") as f:
288
  f.write(trigger_word or "")
 
 
 
290
  job._log("Images uploaded")
291
 
292
- # Step 4: Install Kohya sd-scripts on the pod
293
  job.status = "installing"
294
- job._log("Installing Kohya sd-scripts (this takes a few minutes)...")
295
  job.progress = 0.05
296
 
297
- install_cmds = [
298
- "cd /workspace && git clone --depth 1 https://github.com/kohya-ss/sd-scripts.git",
299
- "cd /workspace/sd-scripts && pip install -r requirements.txt 2>&1 | tail -1",
300
- "pip install accelerate lion-pytorch prodigy-optimizer safetensors bitsandbytes xformers 2>&1 | tail -1",
301
- ]
 
 
302
  for cmd in install_cmds:
303
  out = self._ssh_exec(ssh, cmd, timeout=600)
304
  job._log(out[:200] if out else "done")
305
 
306
- # Download base model from HuggingFace
307
  hf_repo = model_cfg.get("hf_repo", "SG161222/Realistic_Vision_V5.1_noVAE")
308
  hf_filename = model_cfg.get("hf_filename", "Realistic_Vision_V5.1_fp16-no-ema.safetensors")
309
  model_name = model_cfg.get("name", job.base_model)
310
 
311
- job._log(f"Downloading base model: {model_name}...")
312
  job.progress = 0.1
313
-
314
  self._ssh_exec(ssh, """pip install huggingface_hub 2>&1 | tail -1""", timeout=120)
315
 
316
- # Download main model
317
- self._ssh_exec(ssh, f"""
318
- python -c "
 
 
319
  from huggingface_hub import hf_hub_download
320
  hf_hub_download('{hf_repo}', '{hf_filename}', local_dir='/workspace/models')
321
  " 2>&1 | tail -5
322
- """, timeout=1200) # Longer timeout for large models
323
 
324
- # For FLUX, download additional required models (CLIP, T5, VAE)
325
- if model_type == "flux":
326
- job._log("Downloading FLUX auxiliary models (CLIP, T5, VAE)...")
327
- job.progress = 0.12
328
-
329
- self._ssh_exec(ssh, """
330
- python -c "
 
 
 
331
  from huggingface_hub import hf_hub_download
332
- # CLIP text encoder
333
  hf_hub_download('comfyanonymous/flux_text_encoders', 'clip_l.safetensors', local_dir='/workspace/models')
334
- # T5 text encoder (fp16)
335
  hf_hub_download('comfyanonymous/flux_text_encoders', 't5xxl_fp16.safetensors', local_dir='/workspace/models')
336
- # VAE/AutoEncoder
337
  hf_hub_download('black-forest-labs/FLUX.1-dev', 'ae.safetensors', local_dir='/workspace/models')
338
  " 2>&1 | tail -5
339
- """, timeout=1200)
340
 
341
- job._log("Base model downloaded")
342
  job.progress = 0.15
343
 
344
  # Step 5: Run training
345
  job.status = "training"
346
  job._log(f"Starting {model_type.upper()} LoRA training...")
347
 
348
- model_path = f"/workspace/models/{hf_filename}"
 
 
349
 
350
  # Build training command based on model type
351
  train_cmd = self._build_training_command(
@@ -361,6 +579,7 @@ hf_hub_download('black-forest-labs/FLUX.1-dev', 'ae.safetensors', local_dir='/wo
361
  optimizer=optimizer,
362
  save_every_n_epochs=save_every_n_epochs,
363
  model_cfg=model_cfg,
 
364
  )
365
 
366
  # Execute training and stream output
@@ -371,19 +590,37 @@ hf_hub_download('black-forest-labs/FLUX.1-dev', 'ae.safetensors', local_dir='/wo
371
 
372
  # Read output progressively
373
  buffer = ""
 
374
  while not channel.exit_status_ready() or channel.recv_ready():
375
  if channel.recv_ready():
376
  chunk = channel.recv(4096).decode("utf-8", errors="replace")
377
  buffer += chunk
378
- # Process complete lines
379
- while "\n" in buffer:
380
- line, buffer = buffer.split("\n", 1)
381
- line = line.strip()
 
 
382
  if not line:
383
  continue
384
  job._log(line)
385
  self._parse_progress(job, line)
 
386
  else:
 
387
  await asyncio.sleep(2)
388
 
389
  exit_code = channel.recv_exit_status()
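Review note: the streaming loop hands each complete line to `self._parse_progress(job, line)`, whose body is not shown in this diff. kohya-ss training output typically contains `steps: <n>/<total>` and `avr_loss=<x>` fragments, so a parser along those lines (a sketch under that assumption, not the actual implementation) could look like:

```python
from __future__ import annotations

import re

# Patterns assumed from typical kohya-ss tqdm output; adjust to the real logs.
STEP_RE = re.compile(r"steps:\s*(\d+)/(\d+)")
LOSS_RE = re.compile(r"(?:avr_loss|loss)[=:]\s*([0-9.]+)")


def parse_progress(line: str) -> tuple[int | None, int | None, float | None]:
    """Extract (current_step, total_steps, loss) from one training log line.

    Returns None for any field the line doesn't carry, so callers can
    update job state incrementally without clobbering known values.
    """
    step = total = loss = None
    if m := STEP_RE.search(line):
        step, total = int(m.group(1)), int(m.group(2))
    if m := LOSS_RE.search(line):
        loss = float(m.group(1))
    return step, total, loss
```

With `step`/`total` in hand, `job.progress` can be mapped onto the 0.15–0.9 band the surrounding code reserves for the training phase.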
@@ -393,27 +630,33 @@ hf_hub_download('black-forest-labs/FLUX.1-dev', 'ae.safetensors', local_dir='/wo
393
  job._log("Training completed on RunPod!")
394
  job.progress = 0.9
395
 
396
- # Step 6: Download the LoRA file
397
  job.status = "downloading"
398
- job._log("Downloading trained LoRA...")
399
-
400
- LORA_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
401
- local_path = LORA_OUTPUT_DIR / f"{name}.safetensors"
402
 
403
- # Try the main output file first, then look for any .safetensors
404
- try:
405
- sftp.get(f"/workspace/output/{name}.safetensors", str(local_path))
406
- except FileNotFoundError:
407
- # Find any safetensors file
408
- remote_files = sftp.listdir("/workspace/output")
409
- safetensors = [f for f in remote_files if f.endswith(".safetensors")]
410
- if safetensors:
411
- sftp.get(f"/workspace/output/{safetensors[-1]}", str(local_path))
 
412
  else:
413
  raise RuntimeError("No .safetensors output found")
414
 
  job.output_path = str(local_path)
416
- job._log(f"LoRA saved to {local_path}")
417
 
418
  # Done!
419
  job.status = "completed"
@@ -449,6 +692,92 @@ hf_hub_download('black-forest-labs/FLUX.1-dev', 'ae.safetensors', local_dir='/wo
449
  except Exception as e:
450
  job._log(f"Warning: Failed to terminate pod {job.pod_id}: {e}")
451
 
452
  def _build_training_command(
453
  self,
454
  *,
@@ -464,6 +793,7 @@ hf_hub_download('black-forest-labs/FLUX.1-dev', 'ae.safetensors', local_dir='/wo
464
  optimizer: str,
465
  save_every_n_epochs: int,
466
  model_cfg: dict,
 
467
  ) -> str:
468
  """Build the training command based on model type."""
469
 
@@ -497,8 +827,70 @@ hf_hub_download('black-forest-labs/FLUX.1-dev', 'ae.safetensors', local_dir='/wo
497
  lr_scheduler = model_cfg.get("lr_scheduler", "cosine_with_restarts")
498
  base_args += f" \\\n --lr_scheduler={lr_scheduler}"
499
 
500
- if model_type == "flux":
501
- # FLUX-specific training
 
502
  script = "flux_train_network.py"
503
  flux_args = f"""
504
  --pretrained_model_name_or_path="{model_path}" \
@@ -539,24 +931,33 @@ hf_hub_download('black-forest-labs/FLUX.1-dev', 'ae.safetensors', local_dir='/wo
539
  --clip_skip={clip_skip} \
540
  --xformers {base_args} 2>&1"""
541
 
542
- async def _wait_for_pod_ready(self, job: CloudTrainingJob, timeout: int = 300) -> tuple[str, int]:
543
  """Wait for pod to be running and return (ssh_host, ssh_port)."""
544
  start = time.time()
545
  while time.time() - start < timeout:
546
- pod = await asyncio.to_thread(runpod.get_pod, job.pod_id)
 
547
 
548
  status = pod.get("desiredStatus", "")
549
  runtime = pod.get("runtime")
550
 
551
  if status == "RUNNING" and runtime:
552
- ports = runtime.get("ports", [])
553
- for port_info in ports:
554
  if port_info.get("privatePort") == 22:
555
  ip = port_info.get("ip")
556
  public_port = port_info.get("publicPort")
557
  if ip and public_port:
558
  return ip, int(public_port)
559
 
 
 
 
 
560
  await asyncio.sleep(5)
561
 
562
  raise RuntimeError(f"Pod did not become ready within {timeout}s")
@@ -623,3 +1024,28 @@ hf_hub_download('black-forest-labs/FLUX.1-dev', 'ae.safetensors', local_dir='/wo
623
  job.status = "failed"
624
  job.error = "Cancelled by user"
625
  return True
 
 
27
 
28
  import os
29
  from content_engine.config import settings, IS_HF_SPACES
30
+ from content_engine.models.database import catalog_session_factory, TrainingJob as TrainingJobDB
31
 
32
  LORA_OUTPUT_DIR = settings.paths.lora_dir
33
  if IS_HF_SPACES:
 
37
 
38
  # RunPod GPU options (id -> display name, approx cost/hr)
39
  GPU_OPTIONS = {
40
+ # 16-32GB - SD 1.5, SDXL, FLUX.1 only (NOT enough for FLUX.2)
41
+ "NVIDIA GeForce RTX 3090": "RTX 3090 24GB (~$0.22/hr)",
42
+ "NVIDIA GeForce RTX 4090": "RTX 4090 24GB (~$0.44/hr)",
43
+ "NVIDIA GeForce RTX 5090": "RTX 5090 32GB (~$0.69/hr)",
44
+ "NVIDIA RTX A4000": "RTX A4000 16GB (~$0.20/hr)",
45
+ "NVIDIA RTX A5000": "RTX A5000 24GB (~$0.28/hr)",
46
+ # 48GB+ - Required for FLUX.2 Dev (Mistral text encoder needs ~48GB)
47
+ "NVIDIA RTX A6000": "RTX A6000 48GB (~$0.76/hr)",
48
+ "NVIDIA A40": "A40 48GB (~$0.64/hr)",
49
+ "NVIDIA L40": "L40 48GB (~$0.89/hr)",
50
+ "NVIDIA L40S": "L40S 48GB (~$1.09/hr)",
51
  "NVIDIA A100 80GB PCIe": "A100 80GB (~$1.89/hr)",
52
+ "NVIDIA A100-SXM4-80GB": "A100 SXM 80GB (~$1.64/hr)",
53
+ "NVIDIA H100 80GB HBM3": "H100 80GB (~$3.89/hr)",
54
  }
55
 
56
  DEFAULT_GPU = "NVIDIA GeForce RTX 4090"
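Review note: the `GPU_OPTIONS` display strings embed an approximate hourly price, which is enough to derive the `cost_estimate` field on `CloudTrainingJob`. A sketch of extracting and using that figure (the real helper, if any, isn't shown in this diff):

```python
from __future__ import annotations

import re


def hourly_cost(label: str) -> float | None:
    """Pull the ~$X.XX/hr figure out of a GPU_OPTIONS display string."""
    m = re.search(r"\$([0-9.]+)/hr", label)
    return float(m.group(1)) if m else None


def estimate_cost(label: str, hours: float) -> str:
    """Rough run-cost string for the UI; prices are approximate anyway."""
    rate = hourly_cost(label) or 0.0
    return f"~${rate * hours:.2f} for ~{hours:g}h"
```

Parsing the label keeps the prices in one place, at the cost of coupling the estimate to the display-string format.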
57
 
58
+ # Network volume for persistent model storage (avoids re-downloading models each run)
59
+ # Set RUNPOD_VOLUME_ID in .env to use a persistent volume
60
+ # Set RUNPOD_VOLUME_DC to the datacenter ID where the volume lives (e.g. "EU-RO-1")
61
+ NETWORK_VOLUME_ID = os.environ.get("RUNPOD_VOLUME_ID", "")
62
+ NETWORK_VOLUME_DC = os.environ.get("RUNPOD_VOLUME_DC", "")
63
+
64
  # Docker image with PyTorch + CUDA pre-installed
65
  DOCKER_IMAGE = "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04"
66
 
 
99
  cost_estimate: str | None = None
100
  base_model: str = "sd15_realistic"
101
  model_type: str = "sd15"
102
+ _db_callback: Any = None # called on state changes to persist to DB
103
 
104
  def _log(self, msg: str):
105
  self.log_lines.append(msg)
106
  if len(self.log_lines) > 200:
107
  self.log_lines = self.log_lines[-200:]
108
  logger.info("[%s] %s", self.id, msg)
109
+ if self._db_callback:
110
+ self._db_callback(self)
111
 
112
 
113
  class RunPodTrainer:
 
118
  runpod.api_key = api_key
119
  self._jobs: dict[str, CloudTrainingJob] = {}
120
  self._model_registry = load_model_registry()
121
+ self._loaded_from_db = False
122
 
123
  @property
124
  def available(self) -> bool:
 
142
  "learning_rate": cfg.get("learning_rate", 1e-4),
143
  "network_rank": cfg.get("network_rank", 32),
144
  "network_alpha": cfg.get("network_alpha", 16),
145
+ "optimizer": cfg.get("optimizer", "AdamW8bit"),
146
+ "lr_scheduler": cfg.get("lr_scheduler", "cosine"),
147
  "vram_required_gb": cfg.get("vram_required_gb", 8),
148
  "recommended_images": cfg.get("recommended_images", "15-30 photos"),
149
  }
 
200
  model_type=model_type,
201
  )
202
  self._jobs[job_id] = job
203
+ job._db_callback = self._schedule_db_save
204
+ asyncio.ensure_future(self._save_to_db(job))
205
 
206
  # Launch the full pipeline as a background task
207
  asyncio.create_task(self._run_cloud_training(
 
247
  job.status = "creating_pod"
248
  job._log(f"Creating RunPod with {job.gpu_type}...")
249
 
250
+ # Use network volume if configured (persists models across runs)
251
+ pod_kwargs = {
252
+ "container_disk_in_gb": 30,
253
+ "ports": "22/tcp",
254
+ "docker_args": "bash -c 'apt-get update && apt-get install -y openssh-server && mkdir -p /run/sshd && echo root:runpod | chpasswd && /usr/sbin/sshd -o PermitRootLogin=yes && sleep infinity'",
255
+ }
256
+ if NETWORK_VOLUME_ID:
257
+ pod_kwargs["network_volume_id"] = NETWORK_VOLUME_ID
258
+ if NETWORK_VOLUME_DC:
259
+ pod_kwargs["data_center_id"] = NETWORK_VOLUME_DC
260
+ job._log(f"Using persistent network volume: {NETWORK_VOLUME_ID} (DC: {NETWORK_VOLUME_DC or 'auto'})")
261
+ else:
262
+ pod_kwargs["volume_in_gb"] = 75
263
+
264
  pod = await asyncio.to_thread(
265
  runpod.create_pod,
266
  f"lora-train-{job.id}",
267
  DOCKER_IMAGE,
268
  job.gpu_type,
269
+ **pod_kwargs,
 
 
 
270
  )
271
 
272
  job.pod_id = pod["id"]
 
297
  await asyncio.sleep(5)
298
 
299
  job._log("SSH connected")
300
+
301
+ # If using network volume, symlink to /workspace so all paths work
302
+ if NETWORK_VOLUME_ID:
303
+ self._ssh_exec(ssh, "mkdir -p /runpod-volume/models && rm -rf /workspace/models 2>/dev/null; ln -sf /runpod-volume/models /workspace/models")
304
+ job._log("Network volume symlinked to /workspace")
305
+
306
+ # Enable keepalive to prevent SSH timeout during uploads
307
+ transport = ssh.get_transport()
308
+ transport.set_keepalive(30)
309
+
310
  sftp = ssh.open_sftp()
311
+ sftp.get_channel().settimeout(300) # 5 min timeout per file
312
 
313
+ # Step 3: Upload training images (compress first to speed up transfer)
314
  job.status = "uploading"
315
+ resolution = model_cfg.get("resolution", 1024)
316
+ job._log(f"Compressing and uploading {len(image_paths)} training images...")
317
+
318
+ import tempfile
319
+ from PIL import Image
320
+ tmp_dir = Path(tempfile.mkdtemp(prefix="lora_upload_"))
321
 
322
  folder_name = f"10_{trigger_word or 'character'}"
323
  self._ssh_exec(ssh, f"mkdir -p /workspace/dataset/{folder_name}")
324
+ for i, img_path in enumerate(image_paths):
325
  p = Path(img_path)
326
  if p.exists():
327
+ # Resize and convert to JPEG to reduce upload size
328
+ try:
329
+ img = Image.open(p)
+ img.thumbnail((resolution * 2, resolution * 2), Image.LANCZOS)
+ img = img.convert("RGB")  # JPEG cannot store alpha; convert RGBA/P first
+ compressed = tmp_dir / f"{p.stem}.jpg"
+ img.save(compressed, "JPEG", quality=95)
333
+ upload_path = compressed
334
+ except Exception:
335
+ upload_path = p # fallback to original
336
+
337
+ remote_name = f"{p.stem}.jpg" if upload_path.suffix == ".jpg" else p.name
338
+ remote_path = f"/workspace/dataset/{folder_name}/{remote_name}"
339
+ for attempt in range(3):
340
+ try:
341
+ sftp.put(str(upload_path), remote_path)
342
+ break
343
+ except (EOFError, OSError):
344
+ if attempt == 2:
345
+ raise
346
+ job._log(f"Upload retry {attempt+1} for {p.name}")
347
+ sftp.close()
348
+ sftp = ssh.open_sftp()
349
+ sftp.get_channel().settimeout(300)
350
  # Upload matching caption .txt file if it exists locally
351
  local_caption = p.with_suffix(".txt")
352
  if local_caption.exists():
353
+ remote_caption = f"/workspace/dataset/{folder_name}/{p.stem}.txt"
354
  sftp.put(str(local_caption), remote_caption)
355
  else:
356
  # Fallback: create caption from trigger word
357
+ remote_caption = f"/workspace/dataset/{folder_name}/{p.stem}.txt"
358
  with sftp.open(remote_caption, "w") as f:
359
  f.write(trigger_word or "")
360
+ job._log(f"Uploaded {i+1}/{len(image_paths)}: {p.name}")
361
+
362
+ # Cleanup temp compressed images
363
+ import shutil
364
+ shutil.rmtree(tmp_dir, ignore_errors=True)
365
 
366
  job._log("Images uploaded")
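The per-file loop above retries a failed transfer up to three times, rebuilding the SFTP channel between attempts; the bare retry skeleton with paramiko abstracted away (`put` and `reconnect` are stand-in callables):

```python
def put_with_retry(put, reconnect, attempts=3):
    """Call put(); on EOFError/OSError, run reconnect() and try again.
    The final failed attempt re-raises so the caller sees the real error."""
    for attempt in range(attempts):
        try:
            return put()
        except (EOFError, OSError):
            if attempt == attempts - 1:
                raise
            reconnect()
```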
367
 
368
+ # Step 4: Install training framework on the pod (skip if cached on volume)
369
  job.status = "installing"
 
370
  job.progress = 0.05
371
 
372
+ training_framework = model_cfg.get("training_framework", "sd-scripts")
373
+
374
+ if training_framework == "musubi-tuner":
375
+ # FLUX.2 uses musubi-tuner (Kohya's newer framework)
376
+ tuner_dir = "/workspace/musubi-tuner"
377
+ install_cmds = []
378
+
379
+ # Check if already present in workspace
380
+ tuner_exist = self._ssh_exec(ssh, f"test -f {tuner_dir}/pyproject.toml && echo EXISTS || echo MISSING").strip()
381
+ if tuner_exist == "EXISTS":
382
+ job._log("musubi-tuner found in workspace")
383
+ else:
384
+ # Check volume cache
385
+ vol_exist = self._ssh_exec(ssh, "test -f /runpod-volume/musubi-tuner/pyproject.toml && echo EXISTS || echo MISSING").strip()
386
+ if vol_exist == "EXISTS":
387
+ job._log("Restoring musubi-tuner from volume cache...")
388
+ self._ssh_exec(ssh, f"rm -rf {tuner_dir} 2>/dev/null; cp -r /runpod-volume/musubi-tuner {tuner_dir}")
389
+ else:
390
+ job._log("Cloning musubi-tuner from GitHub...")
391
+ self._ssh_exec(ssh, f"rm -rf {tuner_dir} /runpod-volume/musubi-tuner 2>/dev/null; true")
392
+ install_cmds.append("cd /workspace && git clone --depth 1 https://github.com/kohya-ss/musubi-tuner.git")
393
+ # Save to volume for future pods
394
+ if NETWORK_VOLUME_ID:
395
+ install_cmds.append(f"cp -r {tuner_dir} /runpod-volume/musubi-tuner")
396
+
397
+ # Always install pip deps (they are pod-local, lost on every new pod)
398
+ job._log("Installing pip dependencies (accelerate, torch, etc.)...")
399
+ install_cmds.extend([
400
+ f"cd {tuner_dir} && pip install -e . 2>&1 | tail -5",
401
+ "pip install accelerate lion-pytorch prodigyopt safetensors bitsandbytes 2>&1 | tail -5",
402
+ ])
403
+ else:
404
+ # SD 1.5 / SDXL / FLUX.1 use sd-scripts
405
+ scripts_exist = self._ssh_exec(ssh, "test -f /workspace/sd-scripts/setup.py && echo EXISTS || echo MISSING").strip()
406
+ if scripts_exist == "EXISTS":
407
+ job._log("Kohya sd-scripts already present in workspace, updating...")
408
+ install_cmds = [
409
+ "cd /workspace/sd-scripts && git pull 2>&1 | tail -1",
410
+ ]
411
+ else:
412
+ job._log("Installing Kohya sd-scripts (this takes a few minutes)...")
413
+ install_cmds = [
414
+ "cd /workspace && git clone --depth 1 https://github.com/kohya-ss/sd-scripts.git",
415
+ ]
416
+ # Always install pip deps (pod-local, lost on new pods)
417
+ install_cmds.extend([
418
+ "cd /workspace/sd-scripts && pip install -r requirements.txt 2>&1 | tail -1",
419
+ "pip install accelerate lion-pytorch prodigyopt safetensors bitsandbytes xformers 2>&1 | tail -1",
420
+ ])
421
  for cmd in install_cmds:
422
  out = self._ssh_exec(ssh, cmd, timeout=600)
423
  job._log(out[:200] if out else "done")
424
 
425
+ # Download base model from HuggingFace (skip if already on network volume)
426
  hf_repo = model_cfg.get("hf_repo", "SG161222/Realistic_Vision_V5.1_noVAE")
427
  hf_filename = model_cfg.get("hf_filename", "Realistic_Vision_V5.1_fp16-no-ema.safetensors")
428
  model_name = model_cfg.get("name", job.base_model)
429
 
 
430
  job.progress = 0.1
 
431
  self._ssh_exec(ssh, """pip install huggingface_hub 2>&1 | tail -1""", timeout=120)
432
 
433
+ if model_type == "flux2":
434
+ # FLUX.2 models are stored in a directory structure on the volume
435
+ flux2_dir = "/workspace/models/FLUX.2-dev"
436
+ dit_path = f"{flux2_dir}/flux2-dev.safetensors"
437
+ vae_path = f"{flux2_dir}/ae.safetensors" # Original BFL format (not diffusers)
438
+ te_path = f"{flux2_dir}/text_encoder/model-00001-of-00010.safetensors"
439
+
440
+ dit_exists = self._ssh_exec(ssh, f"test -f {dit_path} && echo EXISTS || echo MISSING").strip()
441
+ vae_exists = self._ssh_exec(ssh, f"test -f {vae_path} && echo EXISTS || echo MISSING").strip()
442
+ te_exists = self._ssh_exec(ssh, f"test -f {te_path} && echo EXISTS || echo MISSING").strip()
443
+
444
+ if dit_exists != "EXISTS" or te_exists != "EXISTS":
445
+ missing = []
446
+ if dit_exists != "EXISTS":
447
+ missing.append("DiT")
448
+ if te_exists != "EXISTS":
449
+ missing.append("text encoder")
450
+ raise RuntimeError(f"FLUX.2 Dev missing on volume: {', '.join(missing)}. Please download models to the network volume first.")
451
+
452
+ # Download ae.safetensors (original format VAE) if not present
453
+ if vae_exists != "EXISTS":
454
+ job._log("Downloading FLUX.2 VAE (ae.safetensors, 336MB)...")
455
+ self._ssh_exec(ssh, """pip install huggingface_hub 2>&1 | tail -1""", timeout=120)
456
+ self._ssh_exec(ssh, f"""python -c "
457
+ from huggingface_hub import hf_hub_download
458
+ hf_hub_download('black-forest-labs/FLUX.2-dev', 'ae.safetensors', local_dir='{flux2_dir}')
459
+ print('Downloaded ae.safetensors')
460
+ " 2>&1 | tail -5""", timeout=600)
461
+ # Verify download
462
+ vae_check = self._ssh_exec(ssh, f"test -f {vae_path} && echo EXISTS || echo MISSING").strip()
463
+ if vae_check != "EXISTS":
464
+ raise RuntimeError("Failed to download ae.safetensors")
465
+ job._log("VAE downloaded")
466
+
467
+ job._log("FLUX.2 Dev models ready")
468
+
469
+ else:
470
+ # SD 1.5 / SDXL / FLUX.1 — download single model file
471
+ model_exists = self._ssh_exec(ssh, f"test -f /workspace/models/{hf_filename} && echo EXISTS || echo MISSING").strip()
472
+ if model_exists == "EXISTS":
473
+ job._log(f"Base model already cached on volume: {model_name}")
474
+ else:
475
+ job._log(f"Downloading base model: {model_name}...")
476
+ self._ssh_exec(ssh, f"""
477
+ python -c "
478
  from huggingface_hub import hf_hub_download
479
  hf_hub_download('{hf_repo}', '{hf_filename}', local_dir='/workspace/models')
480
  " 2>&1 | tail -5
481
+ """, timeout=1200)
482
 
483
+ # For FLUX.1, download additional required models (CLIP, T5, VAE)
484
+ if model_type == "flux":
485
+ flux_files_check = self._ssh_exec(ssh, "test -f /workspace/models/clip_l.safetensors && test -f /workspace/models/t5xxl_fp16.safetensors && test -f /workspace/models/ae.safetensors && echo EXISTS || echo MISSING").strip()
486
+ if flux_files_check == "EXISTS":
487
+ job._log("FLUX.1 auxiliary models already cached on volume")
488
+ else:
489
+ job._log("Downloading FLUX.1 auxiliary models (CLIP, T5, VAE)...")
490
+ job.progress = 0.12
491
+ self._ssh_exec(ssh, """
492
+ python -c "
493
  from huggingface_hub import hf_hub_download
 
494
  hf_hub_download('comfyanonymous/flux_text_encoders', 'clip_l.safetensors', local_dir='/workspace/models')
 
495
  hf_hub_download('comfyanonymous/flux_text_encoders', 't5xxl_fp16.safetensors', local_dir='/workspace/models')
 
496
  hf_hub_download('black-forest-labs/FLUX.1-dev', 'ae.safetensors', local_dir='/workspace/models')
497
  " 2>&1 | tail -5
498
+ """, timeout=1200)
499
 
500
+ job._log("Base model ready")
501
  job.progress = 0.15
502
 
503
  # Step 5: Run training
504
  job.status = "training"
505
  job._log(f"Starting {model_type.upper()} LoRA training...")
506
 
507
+ if model_type == "flux2":
508
+ model_path = "/workspace/models/FLUX.2-dev/flux2-dev.safetensors"
509
+ else:
510
+ model_path = f"/workspace/models/{hf_filename}"
511
+
512
+ # For musubi-tuner, create TOML dataset config
513
+ if training_framework == "musubi-tuner":
514
+ folder_name = f"10_{trigger_word or 'character'}"
515
+ toml_content = f"""[[datasets]]
516
+ image_directory = "/workspace/dataset/{folder_name}"
517
+ caption_extension = ".txt"
518
+ batch_size = 1
519
+ num_repeats = 10
520
+ resolution = [{resolution}, {resolution}]
521
+ """
522
+ self._ssh_exec(ssh, f"cat > /workspace/dataset.toml << 'TOMLEOF'\n{toml_content}TOMLEOF")
523
+ job._log("Created dataset.toml config")
524
+
525
+ # musubi-tuner requires pre-caching latents and text encoder outputs
526
+ flux2_dir = "/workspace/models/FLUX.2-dev"
527
+ vae_path = f"{flux2_dir}/ae.safetensors"
528
+ te_path = f"{flux2_dir}/text_encoder/model-00001-of-00010.safetensors"
529
+
530
+ job._log("Caching latents (VAE encoding)...")
531
+ job.progress = 0.15
532
+ self._schedule_db_save(job)
533
+ cache_latents_cmd = (
534
+ f"cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python src/musubi_tuner/flux_2_cache_latents.py"
535
+ f" --dataset_config /workspace/dataset.toml"
536
+ f" --vae {vae_path}"
537
+ f" --model_version dev"
538
+ f" --vae_dtype bfloat16"
539
+ f" 2>&1 | tee /tmp/cache_latents.log; echo EXIT_CODE=${{PIPESTATUS[0]}}"
540
+ )
541
+ out = self._ssh_exec(ssh, cache_latents_cmd, timeout=600)
542
+ # Get last lines which have the real error
543
+ last_lines = out.split('\n')[-30:]
544
+ job._log('\n'.join(last_lines))
545
+ if "EXIT_CODE=0" not in out:
546
+ # Fetch the full error log
547
+ err_log = self._ssh_exec(ssh, "grep -i 'error\\|exception\\|traceback\\|failed' /tmp/cache_latents.log | tail -10")
548
+ job._log(f"Cache error details: {err_log}")
549
+ raise RuntimeError("Latent caching failed")
550
+
551
+ job._log("Caching text encoder outputs (bf16)...")
552
+ job.progress = 0.25
553
+ self._schedule_db_save(job)
554
+ cache_te_cmd = (
555
+ f"cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
556
+ f" python src/musubi_tuner/flux_2_cache_text_encoder_outputs.py"
557
+ f" --dataset_config /workspace/dataset.toml"
558
+ f" --text_encoder {te_path}"
559
+ f" --model_version dev"
560
+ f" --batch_size 1"
561
+ f" 2>&1; echo EXIT_CODE=$?"
562
+ )
563
+ out = self._ssh_exec(ssh, cache_te_cmd, timeout=600)
564
+ job._log(out[-500:] if out else "done")
565
+ if "EXIT_CODE=0" not in out:
566
+ raise RuntimeError(f"Text encoder caching failed: {out[-200:]}")
567
 
568
  # Build training command based on model type
569
  train_cmd = self._build_training_command(
 
579
  optimizer=optimizer,
580
  save_every_n_epochs=save_every_n_epochs,
581
  model_cfg=model_cfg,
582
+ gpu_type=job.gpu_type,
583
  )
584
 
585
  # Execute training and stream output
 
590
 
591
  # Read output progressively
592
  buffer = ""
593
+ last_flush = time.time()
594
  while not channel.exit_status_ready() or channel.recv_ready():
595
  if channel.recv_ready():
596
  chunk = channel.recv(4096).decode("utf-8", errors="replace")
597
  buffer += chunk
598
+ # Process complete lines (handle both \n and \r for tqdm progress)
599
+ while "\n" in buffer or "\r" in buffer:
600
+ # Split on whichever comes first
601
+ n_pos = buffer.find("\n")
602
+ r_pos = buffer.find("\r")
603
+ if n_pos == -1:
604
+ split_pos = r_pos
605
+ elif r_pos == -1:
606
+ split_pos = n_pos
607
+ else:
608
+ split_pos = min(n_pos, r_pos)
609
+ line = buffer[:split_pos].strip()
610
+ buffer = buffer[split_pos + 1:]
611
  if not line:
612
  continue
613
  job._log(line)
614
  self._parse_progress(job, line)
615
+ self._schedule_db_save(job)
616
  else:
617
+ # Periodically flush buffer for partial tqdm lines
618
+ if buffer.strip() and time.time() - last_flush > 10:
619
+ job._log(buffer.strip())
620
+ self._parse_progress(job, buffer.strip())
621
+ buffer = ""
622
+ last_flush = time.time()
623
+ self._schedule_db_save(job)
624
  await asyncio.sleep(2)
625
 
626
  exit_code = channel.recv_exit_status()
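The \n/\r handling in the read loop exists because tqdm redraws its progress bar with bare carriage returns; the same splitting logic as a standalone helper (the name is ours):

```python
def pop_lines(buffer: str):
    """Split complete lines out of a stream buffer on either '\\n' or '\\r'
    (tqdm uses '\\r' to redraw); returns (lines, leftover_partial)."""
    lines = []
    while "\n" in buffer or "\r" in buffer:
        n_pos = buffer.find("\n")
        r_pos = buffer.find("\r")
        if n_pos == -1:
            split_pos = r_pos
        elif r_pos == -1:
            split_pos = n_pos
        else:
            split_pos = min(n_pos, r_pos)
        line, buffer = buffer[:split_pos].strip(), buffer[split_pos + 1:]
        if line:
            lines.append(line)
    return lines, buffer
```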
 
630
  job._log("Training completed on RunPod!")
631
  job.progress = 0.9
632
 
633
+ # Step 6: Save LoRA to network volume and download locally
634
  job.status = "downloading"
 
 
 
 
635
 
636
+ # First, copy to network volume for persistence
637
+ job._log("Saving LoRA to network volume...")
638
+ self._ssh_exec(ssh, "mkdir -p /runpod-volume/loras")
639
+ remote_output = f"/workspace/output/{name}.safetensors"
640
+ # Find the output file
641
+ check = self._ssh_exec(ssh, f"test -f {remote_output} && echo EXISTS || echo MISSING").strip()
642
+ if check == "MISSING":
643
+ remote_files = self._ssh_exec(ssh, "ls /workspace/output/*.safetensors 2>/dev/null").strip()
644
+ if remote_files:
645
+ remote_output = remote_files.split("\n")[-1].strip()
646
  else:
647
  raise RuntimeError("No .safetensors output found")
648
 
649
+ self._ssh_exec(ssh, f"cp {remote_output} /runpod-volume/loras/{name}.safetensors")
650
+ job._log(f"LoRA saved to volume: /runpod-volume/loras/{name}.safetensors")
651
+
652
+ # Then download locally
653
+ job._log("Downloading LoRA to local machine...")
654
+ LORA_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
655
+ local_path = LORA_OUTPUT_DIR / f"{name}.safetensors"
656
+ sftp.get(remote_output, str(local_path))
657
+
658
  job.output_path = str(local_path)
659
+ job._log(f"LoRA saved locally to {local_path}")
660
 
661
  # Done!
662
  job.status = "completed"
 
692
  except Exception as e:
693
  job._log(f"Warning: Failed to terminate pod {job.pod_id}: {e}")
694
 
695
+ def _schedule_db_save(self, job: CloudTrainingJob):
696
+ """Schedule a DB save (non-blocking)."""
697
+ try:
698
+ asyncio.get_running_loop().create_task(self._save_to_db(job))
699
+ except RuntimeError:
700
+ pass # no event loop
701
+
702
+ async def _save_to_db(self, job: CloudTrainingJob):
703
+ """Persist job state to database."""
704
+ try:
705
+ from sqlalchemy import text
706
+ async with catalog_session_factory() as session:
707
+ # Use raw INSERT OR REPLACE for SQLite upsert
708
+ await session.execute(
709
+ text("""INSERT OR REPLACE INTO training_jobs
710
+ (id, name, status, progress, current_epoch, total_epochs,
711
+ current_step, total_steps, loss, started_at, completed_at,
712
+ output_path, error, log_text, pod_id, gpu_type, backend,
713
+ base_model, model_type, created_at)
714
+ VALUES (:id, :name, :status, :progress, :current_epoch, :total_epochs,
715
+ :current_step, :total_steps, :loss, :started_at, :completed_at,
716
+ :output_path, :error, :log_text, :pod_id, :gpu_type, :backend,
717
+ :base_model, :model_type, COALESCE((SELECT created_at FROM training_jobs WHERE id = :id), CURRENT_TIMESTAMP))
718
+ """),
719
+ {
720
+ "id": job.id, "name": job.name, "status": job.status,
721
+ "progress": job.progress, "current_epoch": job.current_epoch,
722
+ "total_epochs": job.total_epochs, "current_step": job.current_step,
723
+ "total_steps": job.total_steps, "loss": job.loss,
724
+ "started_at": job.started_at, "completed_at": job.completed_at,
725
+ "output_path": job.output_path, "error": job.error,
726
+ "log_text": "\n".join(job.log_lines[-200:]),
727
+ "pod_id": job.pod_id, "gpu_type": job.gpu_type,
728
+ "backend": "runpod", "base_model": job.base_model,
729
+ "model_type": job.model_type,
730
+ }
731
+ )
732
+ await session.commit()
733
+ except Exception as e:
734
+ logger.warning("Failed to save training job to DB: %s", e)
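The raw upsert above leans on a COALESCE subselect to preserve the first `created_at` across re-saves; a self-contained sqlite3 demonstration of the idiom:

```python
import sqlite3

# Demo of the INSERT OR REPLACE idiom used in _save_to_db: the COALESCE
# subselect re-reads the existing row so created_at survives every re-save.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE training_jobs (id TEXT PRIMARY KEY, status TEXT, created_at TEXT)")

def save(job_id, status, now):
    conn.execute(
        "INSERT OR REPLACE INTO training_jobs (id, status, created_at) "
        "VALUES (?, ?, COALESCE((SELECT created_at FROM training_jobs WHERE id = ?), ?))",
        (job_id, status, job_id, now),
    )

save("job1", "running", "t0")    # first save stamps created_at = t0
save("job1", "completed", "t1")  # re-save keeps t0, replaces status
row = conn.execute("SELECT status, created_at FROM training_jobs WHERE id = 'job1'").fetchone()
```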
735
+
736
+ async def _load_jobs_from_db(self):
737
+ """Load previously saved jobs from database on startup."""
738
+ try:
739
+ from sqlalchemy import select
740
+ async with catalog_session_factory() as session:
741
+ result = await session.execute(
742
+ select(TrainingJobDB).order_by(TrainingJobDB.created_at.desc()).limit(20)
743
+ )
744
+ db_jobs = result.scalars().all()
745
+ for db_job in db_jobs:
746
+ if db_job.id not in self._jobs:
747
+ job = CloudTrainingJob(
748
+ id=db_job.id,
749
+ name=db_job.name,
750
+ status=db_job.status,
751
+ progress=db_job.progress or 0.0,
752
+ current_epoch=db_job.current_epoch or 0,
753
+ total_epochs=db_job.total_epochs or 0,
754
+ current_step=db_job.current_step or 0,
755
+ total_steps=db_job.total_steps or 0,
756
+ loss=db_job.loss,
757
+ started_at=db_job.started_at,
758
+ completed_at=db_job.completed_at,
759
+ output_path=db_job.output_path,
760
+ error=db_job.error,
761
+ log_lines=(db_job.log_text or "").split("\n") if db_job.log_text else [],
762
+ pod_id=db_job.pod_id,
763
+ gpu_type=db_job.gpu_type or DEFAULT_GPU,
764
+ base_model=db_job.base_model or "sd15_realistic",
765
+ model_type=db_job.model_type or "sd15",
766
+ )
767
+ # Mark interrupted jobs as failed
768
+ if job.status not in ("completed", "failed"):
769
+ job.status = "failed"
770
+ job.error = "Interrupted by server restart"
771
+ self._jobs[db_job.id] = job
772
+ except Exception as e:
773
+ logger.warning("Failed to load training jobs from DB: %s", e)
774
+
775
+ async def ensure_loaded(self):
776
+ """Load jobs from DB on first access."""
777
+ if not self._loaded_from_db:
778
+ self._loaded_from_db = True
779
+ await self._load_jobs_from_db()
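`ensure_loaded` flips its flag before awaiting the load, so concurrent callers on the same event loop cannot both trigger a DB read; the guard in isolation (class name is illustrative):

```python
import asyncio

class OnceLoader:
    """Flag-before-await guard: the boolean is set synchronously, so a second
    caller scheduled on the same loop sees it before any await point."""
    def __init__(self):
        self.loaded = False
        self.load_calls = 0

    async def ensure_loaded(self):
        if not self.loaded:
            self.loaded = True  # set BEFORE awaiting, closing the race window
            await self._load()

    async def _load(self):
        self.load_calls += 1
```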
780
+
781
  def _build_training_command(
782
  self,
783
  *,
 
793
  optimizer: str,
794
  save_every_n_epochs: int,
795
  model_cfg: dict,
796
+ gpu_type: str = "",
797
  ) -> str:
798
  """Build the training command based on model type."""
799
 
 
827
  lr_scheduler = model_cfg.get("lr_scheduler", "cosine_with_restarts")
828
  base_args += f" \\\n --lr_scheduler={lr_scheduler}"
829
 
830
+ if model_type == "flux2":
831
+ # FLUX.2 training via musubi-tuner
832
+ flux2_dir = "/workspace/models/FLUX.2-dev"
833
+ dit_path = f"{flux2_dir}/flux2-dev.safetensors"
834
+ vae_path = f"{flux2_dir}/ae.safetensors"
835
+ te_path = f"{flux2_dir}/text_encoder/model-00001-of-00010.safetensors"
836
+
837
+ network_mod = model_cfg.get("network_module", "networks.lora_flux_2")
838
+ ts_sampling = model_cfg.get("timestep_sampling", "flux2_shift")
839
+ lr_scheduler = model_cfg.get("lr_scheduler", "cosine")
840
+
841
+ # Build as list of args to avoid shell escaping issues
842
+ args = [
843
+ "cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True",
844
+ "accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16",
845
+ "src/musubi_tuner/flux_2_train_network.py",
846
+ "--model_version dev",
847
+ f"--dit {dit_path}",
848
+ f"--vae {vae_path}",
849
+ f"--text_encoder {te_path}",
850
+ "--dataset_config /workspace/dataset.toml",
851
+ "--sdpa --mixed_precision bf16",
852
+ f"--timestep_sampling {ts_sampling} --weighting_scheme none",
853
+ f"--network_module {network_mod}",
854
+ f"--network_dim={network_rank}",
855
+ f"--network_alpha={network_alpha}",
856
+ "--gradient_checkpointing",
857
+ ]
858
+
859
+ # Only use fp8_base on GPUs with native fp8 support (RTX 4090/5090, L40S, H100)
860
+ # A100 and A6000 don't support fp8 tensor ops, and have enough VRAM without it
861
+ if gpu_type and ("4090" in gpu_type or "5090" in gpu_type or "L40S" in gpu_type or "H100" in gpu_type):
862
+ args.append("--fp8_base")
863
+
864
+ # Handle Prodigy optimizer (needs special class path and args)
865
+ if optimizer.lower() == "prodigy":
866
+ args.extend([
867
+ "--optimizer_type=prodigyopt.Prodigy",
868
+ f"--learning_rate={learning_rate}",
869
+ '--optimizer_args "weight_decay=0.01" "decouple=True" "use_bias_correction=True" "safeguard_warmup=True" "d_coef=2"',
870
+ ])
871
+ else:
872
+ args.extend([
873
+ f"--optimizer_type={optimizer}",
874
+ f"--learning_rate={learning_rate}",
875
+ ])
876
+
877
+ args.extend([
878
+ f"--save_every_n_epochs={save_every_n_epochs}",
879
+ "--seed=42",
880
+ '--output_dir=/workspace/output',
881
+ f'--output_name={name}',
882
+ f"--lr_scheduler={lr_scheduler}",
883
+ ])
884
+
885
+ if max_train_steps:
886
+ args.append(f"--max_train_steps={max_train_steps}")
887
+ else:
888
+ args.append(f"--max_train_epochs={num_epochs}")
889
+
890
+ return " ".join(args) + " 2>&1"
891
+
892
+ elif model_type == "flux":
893
+ # FLUX.1 training via sd-scripts
894
  script = "flux_train_network.py"
895
  flux_args = f"""
896
  --pretrained_model_name_or_path="{model_path}" \
 
931
  --clip_skip={clip_skip} \
932
  --xformers {base_args} 2>&1"""
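The `--fp8_base` gating above is a substring heuristic over the RunPod GPU display name; as a standalone predicate (helper name is ours, marker list copied from the condition):

```python
FP8_GPU_MARKERS = ("4090", "5090", "L40S", "H100")  # Ada/Hopper-class parts

def supports_fp8(gpu_type: str) -> bool:
    """Mirror of the inline check: enable fp8 weights only when the GPU name
    suggests native FP8 support; Ampere cards (A100/A6000) stay in bf16."""
    return any(marker in gpu_type for marker in FP8_GPU_MARKERS)
```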
933
 
934
+ async def _wait_for_pod_ready(self, job: CloudTrainingJob, timeout: int = 600) -> tuple[str, int]:
935
  """Wait for pod to be running and return (ssh_host, ssh_port)."""
936
  start = time.time()
937
  while time.time() - start < timeout:
938
+ try:
939
+ pod = await asyncio.to_thread(runpod.get_pod, job.pod_id)
940
+ except Exception as e:
941
+ job._log(f" API error: {e}")
942
+ await asyncio.sleep(10)
943
+ continue
944
 
945
  status = pod.get("desiredStatus", "")
946
  runtime = pod.get("runtime")
947
 
948
  if status == "RUNNING" and runtime:
949
+ ports = runtime.get("ports") or []
+ for port_info in ports:
951
  if port_info.get("privatePort") == 22:
952
  ip = port_info.get("ip")
953
  public_port = port_info.get("publicPort")
954
  if ip and public_port:
955
  return ip, int(public_port)
956
 
957
+ elapsed = int(time.time() - start)
958
+ if elapsed % 30 < 6:
959
+ job._log(f" Status: {status} | runtime: {'ports pending' if runtime else 'not ready yet'} ({elapsed}s)")
960
+
961
  await asyncio.sleep(5)
962
 
963
  raise RuntimeError(f"Pod did not become ready within {timeout}s")
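The port scan inside the polling loop can be factored out; a sketch of extracting the public SSH mapping from a pod payload (the dict shape is taken from the loop above):

```python
def find_ssh_endpoint(pod: dict):
    """Return (ip, port) for the public mapping of container port 22,
    or None while the proxy hasn't assigned one yet."""
    runtime = pod.get("runtime") or {}
    for port_info in runtime.get("ports") or []:
        if port_info.get("privatePort") == 22:
            ip = port_info.get("ip")
            public_port = port_info.get("publicPort")
            if ip and public_port:
                return ip, int(public_port)
    return None
```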
 
1024
  job.status = "failed"
1025
  job.error = "Cancelled by user"
1026
  return True
1027
+
1028
+ async def delete_job(self, job_id: str) -> bool:
1029
+ """Delete a training job from memory and database."""
1030
+ if job_id not in self._jobs:
1031
+ return False
1032
+ del self._jobs[job_id]
1033
+ try:
1034
+ async with catalog_session_factory() as session:
1035
+ from sqlalchemy import select
+ result = await session.execute(
+ select(TrainingJobDB).where(TrainingJobDB.id == job_id)
+ )
1038
+ db_job = result.scalar_one_or_none()
1039
+ if db_job:
1040
+ await session.delete(db_job)
1041
+ await session.commit()
1042
+ except Exception as e:
1043
+ logger.warning("Failed to delete job from DB: %s", e)
1044
+ return True
1045
+
1046
+ async def delete_failed_jobs(self) -> int:
1047
+ """Delete all failed/error training jobs."""
1048
+ failed_ids = [jid for jid, j in self._jobs.items() if j.status in ("failed", "error")]
1049
+ for jid in failed_ids:
1050
+ await self.delete_job(jid)
1051
+ return len(failed_ids)
src/content_engine/workers/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (239 Bytes)
 
src/content_engine/workers/__pycache__/local_worker.cpython-311.pyc DELETED
Binary file (6.09 kB)