Spaces:
Running
Running
Auto-reconnect to running training pods after server restart
Browse files
On startup, checks the DB for in-progress training jobs, verifies that their
RunPod pods are still running, and reconnects over SSH to resume log
monitoring. Handles both still-running and already-completed cases.
Copies the LoRA to the volume and terminates the pod on completion.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
src/content_engine/services/runpod_trainer.py
CHANGED
|
@@ -788,11 +788,27 @@ resolution = [{resolution}, {resolution}]
|
|
| 788 |
base_model=db_job.base_model or "sd15_realistic",
|
| 789 |
model_type=db_job.model_type or "sd15",
|
| 790 |
)
|
| 791 |
-
|
| 792 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 793 |
job.status = "failed"
|
| 794 |
job.error = "Interrupted by server restart"
|
| 795 |
-
self._jobs[db_job.id] = job
|
| 796 |
except Exception as e:
|
| 797 |
logger.warning("Failed to load training jobs from DB: %s", e)
|
| 798 |
|
|
@@ -802,6 +818,144 @@ resolution = [{resolution}, {resolution}]
|
|
| 802 |
self._loaded_from_db = True
|
| 803 |
await self._load_jobs_from_db()
|
| 804 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 805 |
def _build_training_command(
|
| 806 |
self,
|
| 807 |
*,
|
|
|
|
| 788 |
base_model=db_job.base_model or "sd15_realistic",
|
| 789 |
model_type=db_job.model_type or "sd15",
|
| 790 |
)
|
| 791 |
+
self._jobs[db_job.id] = job
|
| 792 |
+
# Try to reconnect to running training pods
|
| 793 |
+
if job.status not in ("completed", "failed") and job.pod_id:
|
| 794 |
+
try:
|
| 795 |
+
pod = await asyncio.to_thread(runpod.get_pod, job.pod_id)
|
| 796 |
+
if pod and pod.get("desiredStatus") == "RUNNING":
|
| 797 |
+
job.status = "training"
|
| 798 |
+
job.error = None
|
| 799 |
+
job._log("Reconnecting to running training pod after restart...")
|
| 800 |
+
asyncio.create_task(self._reconnect_training(job))
|
| 801 |
+
logger.info("Reconnecting to training pod %s for job %s", job.pod_id, job.id)
|
| 802 |
+
else:
|
| 803 |
+
job.status = "failed"
|
| 804 |
+
job.error = "Pod terminated during server restart"
|
| 805 |
+
except Exception as e:
|
| 806 |
+
logger.warning("Could not check pod %s: %s", job.pod_id, e)
|
| 807 |
+
job.status = "failed"
|
| 808 |
+
job.error = "Interrupted by server restart"
|
| 809 |
+
elif job.status not in ("completed", "failed"):
|
| 810 |
job.status = "failed"
|
| 811 |
job.error = "Interrupted by server restart"
|
|
|
|
| 812 |
except Exception as e:
|
| 813 |
logger.warning("Failed to load training jobs from DB: %s", e)
|
| 814 |
|
|
|
|
| 818 |
self._loaded_from_db = True
|
| 819 |
await self._load_jobs_from_db()
|
| 820 |
|
| 821 |
+
async def _reconnect_training(self, job: CloudTrainingJob):
    """Reconnect to a training pod after server restart and resume log monitoring.

    Looks up the pod's public SSH endpoint via RunPod, opens an SSH session,
    and then handles one of two situations:

    * ``/tmp/training.exit`` already exists: training finished while the
      server was down. The last 50 log lines are replayed and the job is
      finalised from the recorded exit code.
    * Otherwise: training is still running. ``/tmp/training.log`` is polled
      every 5 seconds (replayed from the beginning, since the previous read
      offset was lost with the restart) until the exit file appears.

    On exit code 0 the newest ``.safetensors`` file under
    ``/workspace/output`` is copied to ``/runpod-volume/loras/<job.name>.safetensors``
    and the job is marked ``completed``. Any error marks the job ``failed``
    with the error text. In every case the SSH session is closed, the pod is
    terminated, and a DB save is scheduled.

    Args:
        job: The job to resume; ``job.pod_id`` must identify the pod.
    """
    import shlex

    import paramiko

    log_file = "/tmp/training.log"
    exit_file = "/tmp/training.exit"
    pid_file = "/tmp/training.pid"
    ssh = None

    def replay_log(text):
        """Log every non-empty line of *text* and feed it to the progress parser."""
        # Progress bars redraw with bare carriage returns; treat them as line
        # breaks so each update is parsed on its own instead of as one long line.
        for raw_line in text.replace("\r", "\n").split("\n"):
            line = raw_line.strip()
            if line:
                job._log(line)
                self._parse_progress(job, line)

    def parse_exit_code(raw):
        """Convert the exit-file contents to an int, with a readable error."""
        try:
            return int(raw)
        except ValueError:
            raise RuntimeError(f"Unreadable training exit code: {raw!r}") from None

    async def finalize_success():
        """Copy the trained LoRA to the persistent volume and mark the job completed."""
        await self._ssh_exec(ssh, "mkdir -p /runpod-volume/loras")
        remote_files = (
            await self._ssh_exec(ssh, "ls /workspace/output/*.safetensors 2>/dev/null")
        ).strip()
        if remote_files:
            # `ls` sorts by name; the last entry is taken as the final output.
            remote_output = remote_files.split("\n")[-1].strip()
            volume_path = f"/runpod-volume/loras/{job.name}.safetensors"
            # Quote both paths: the job name may contain spaces or shell
            # metacharacters. Plain names are left unchanged by shlex.quote.
            await self._ssh_exec(
                ssh, f"cp {shlex.quote(remote_output)} {shlex.quote(volume_path)}"
            )
            job._log(f"LoRA saved to volume: {volume_path}")
            job.output_path = volume_path
        else:
            job._log("Warning: no .safetensors output found in /workspace/output")

        job.status = "completed"
        job.progress = 1.0
        job.completed_at = time.time()

    try:
        # Get SSH info from RunPod
        pod = await asyncio.to_thread(runpod.get_pod, job.pod_id)
        if not pod:
            raise RuntimeError("Pod not found")

        runtime = pod.get("runtime") or {}
        ssh_host = ssh_port = None
        for port_info in runtime.get("ports") or []:
            if port_info.get("privatePort") == 22:
                ssh_host = port_info.get("ip")
                ssh_port = port_info.get("publicPort")

        if not ssh_host or not ssh_port:
            raise RuntimeError("SSH port not available")

        # Connect SSH
        # NOTE(review): host keys are auto-accepted and the password is
        # hard-coded; confirm this is acceptable for these pods.
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        await asyncio.to_thread(
            ssh.connect, ssh_host, port=int(ssh_port),
            username="root", password="runpod", timeout=10,
        )
        transport = ssh.get_transport()
        if transport is not None:
            transport.set_keepalive(30)
        job._log(f"Reconnected to pod {job.pod_id}")

        exit_check = (await self._ssh_exec(ssh, f"cat {exit_file} 2>/dev/null")).strip()
        if exit_check:
            # Training already finished while we were disconnected
            exit_code = parse_exit_code(exit_check)
            replay_log(await self._ssh_exec(ssh, f"tail -50 {log_file} 2>/dev/null"))

            if exit_code != 0:
                raise RuntimeError(f"Training failed with exit code {exit_code}")
            job._log("Training completed while disconnected!")
            await finalize_success()
        else:
            # Training still running — resume log monitoring
            pid = (await self._ssh_exec(ssh, f"cat {pid_file} 2>/dev/null")).strip()
            job._log(f"Training still running (PID: {pid}), resuming monitoring...")

            last_offset = 0  # byte offset into the remote log already consumed
            while True:
                exit_check = (await self._ssh_exec(ssh, f"cat {exit_file} 2>/dev/null")).strip()
                if exit_check:
                    exit_code = parse_exit_code(exit_check)
                    # Drain whatever was written since the last poll.
                    remaining = await self._ssh_exec(
                        ssh, f"tail -c +{last_offset + 1} {log_file} 2>/dev/null", timeout=30
                    )
                    if remaining:
                        replay_log(remaining)

                    if exit_code != 0:
                        raise RuntimeError(f"Training failed with exit code {exit_code}")
                    await finalize_success()
                    break

                try:
                    new_output = await self._ssh_exec(
                        ssh, f"tail -c +{last_offset + 1} {log_file} 2>/dev/null", timeout=30
                    )
                    if new_output:
                        last_offset += len(new_output.encode("utf-8"))
                        replay_log(new_output)
                        self._schedule_db_save(job)
                except Exception as poll_err:
                    # Best effort: a failed poll is retried on the next cycle,
                    # but leave a trace instead of swallowing it silently.
                    logger.debug("Log poll failed for job %s: %s", job.id, poll_err)
                await asyncio.sleep(5)

        job._log("Training complete!")

    except Exception as e:
        job.status = "failed"
        job.error = str(e)
        job._log(f"Reconnect failed: {e}")
        logger.error("Training reconnect failed for %s: %s", job.id, e)

    finally:
        # NOTE(review): this also runs if the task is cancelled (e.g. during a
        # graceful shutdown), which terminates a pod that may still be
        # training — confirm that is intended.
        if ssh:
            try:
                ssh.close()
            except Exception as close_err:
                logger.debug("Error closing SSH for job %s: %s", job.id, close_err)
        # Terminate pod
        if job.pod_id:
            try:
                await asyncio.to_thread(runpod.terminate_pod, job.pod_id)
                job._log("Pod terminated")
            except Exception as term_err:
                # A pod that fails to terminate keeps billing; make it visible.
                logger.warning("Failed to terminate pod %s: %s", job.pod_id, term_err)
        self._schedule_db_save(job)
|
| 958 |
+
|
| 959 |
def _build_training_command(
|
| 960 |
self,
|
| 961 |
*,
|