dippoo Claude Opus 4.6 committed on
Commit
1c46fa4
·
1 Parent(s): b4ceba9

Auto-reconnect to running training pods after server restart

Browse files

On startup, checks DB for in-progress training jobs, verifies their
RunPod pods are still running, and reconnects SSH to resume log
monitoring. Handles both still-running and already-completed cases.
Copies LoRA to volume and terminates pod on completion.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

src/content_engine/services/runpod_trainer.py CHANGED
@@ -788,11 +788,27 @@ resolution = [{resolution}, {resolution}]
788
  base_model=db_job.base_model or "sd15_realistic",
789
  model_type=db_job.model_type or "sd15",
790
  )
791
- # Mark interrupted jobs as failed
792
- if job.status not in ("completed", "failed"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
793
  job.status = "failed"
794
  job.error = "Interrupted by server restart"
795
- self._jobs[db_job.id] = job
796
  except Exception as e:
797
  logger.warning("Failed to load training jobs from DB: %s", e)
798
 
@@ -802,6 +818,144 @@ resolution = [{resolution}, {resolution}]
802
  self._loaded_from_db = True
803
  await self._load_jobs_from_db()
804
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
805
  def _build_training_command(
806
  self,
807
  *,
 
788
  base_model=db_job.base_model or "sd15_realistic",
789
  model_type=db_job.model_type or "sd15",
790
  )
791
+ self._jobs[db_job.id] = job
792
+ # Try to reconnect to running training pods
793
+ if job.status not in ("completed", "failed") and job.pod_id:
794
+ try:
795
+ pod = await asyncio.to_thread(runpod.get_pod, job.pod_id)
796
+ if pod and pod.get("desiredStatus") == "RUNNING":
797
+ job.status = "training"
798
+ job.error = None
799
+ job._log("Reconnecting to running training pod after restart...")
800
+ asyncio.create_task(self._reconnect_training(job))
801
+ logger.info("Reconnecting to training pod %s for job %s", job.pod_id, job.id)
802
+ else:
803
+ job.status = "failed"
804
+ job.error = "Pod terminated during server restart"
805
+ except Exception as e:
806
+ logger.warning("Could not check pod %s: %s", job.pod_id, e)
807
+ job.status = "failed"
808
+ job.error = "Interrupted by server restart"
809
+ elif job.status not in ("completed", "failed"):
810
  job.status = "failed"
811
  job.error = "Interrupted by server restart"
 
812
  except Exception as e:
813
  logger.warning("Failed to load training jobs from DB: %s", e)
814
 
 
818
  self._loaded_from_db = True
819
  await self._load_jobs_from_db()
820
 
821
+ async def _reconnect_training(self, job: CloudTrainingJob):
822
+ """Reconnect to a training pod after server restart and resume log monitoring."""
823
+ import paramiko
824
+ ssh = None
825
+ try:
826
+ # Get SSH info from RunPod
827
+ pod = await asyncio.to_thread(runpod.get_pod, job.pod_id)
828
+ if not pod:
829
+ raise RuntimeError("Pod not found")
830
+
831
+ runtime = pod.get("runtime") or {}
832
+ ports = runtime.get("ports") or []
833
+ ssh_host = ssh_port = None
834
+ for p in ports:
835
+ if p.get("privatePort") == 22:
836
+ ssh_host = p.get("ip")
837
+ ssh_port = p.get("publicPort")
838
+
839
+ if not ssh_host or not ssh_port:
840
+ raise RuntimeError("SSH port not available")
841
+
842
+ # Connect SSH
843
+ ssh = paramiko.SSHClient()
844
+ ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
845
+ await asyncio.to_thread(
846
+ ssh.connect, ssh_host, port=int(ssh_port),
847
+ username="root", password="runpod", timeout=10,
848
+ )
849
+ transport = ssh.get_transport()
850
+ transport.set_keepalive(30)
851
+ job._log(f"Reconnected to pod {job.pod_id}")
852
+
853
+ # Check if training is still running
854
+ log_file = "/tmp/training.log"
855
+ exit_file = "/tmp/training.exit"
856
+ pid_file = "/tmp/training.pid"
857
+
858
+ exit_check = (await self._ssh_exec(ssh, f"cat {exit_file} 2>/dev/null")).strip()
859
+ if exit_check:
860
+ # Training already finished while we were disconnected
861
+ exit_code = int(exit_check)
862
+ log_tail = await self._ssh_exec(ssh, f"tail -50 {log_file} 2>/dev/null")
863
+ for line in log_tail.split("\n"):
864
+ line = line.strip()
865
+ if line:
866
+ job._log(line)
867
+ self._parse_progress(job, line)
868
+
869
+ if exit_code == 0:
870
+ job._log("Training completed while disconnected!")
871
+ # Copy LoRA to volume
872
+ name = job.name
873
+ await self._ssh_exec(ssh, "mkdir -p /runpod-volume/loras")
874
+ remote_files = (await self._ssh_exec(ssh, "ls /workspace/output/*.safetensors 2>/dev/null")).strip()
875
+ if remote_files:
876
+ remote_output = remote_files.split("\n")[-1].strip()
877
+ await self._ssh_exec(ssh, f"cp {remote_output} /runpod-volume/loras/{name}.safetensors")
878
+ job._log(f"LoRA saved to volume: /runpod-volume/loras/{name}.safetensors")
879
+ job.output_path = f"/runpod-volume/loras/{name}.safetensors"
880
+
881
+ job.status = "completed"
882
+ job.progress = 1.0
883
+ job.completed_at = time.time()
884
+ else:
885
+ raise RuntimeError(f"Training failed with exit code {exit_code}")
886
+ else:
887
+ # Training still running — resume log monitoring
888
+ pid = (await self._ssh_exec(ssh, f"cat {pid_file} 2>/dev/null")).strip()
889
+ job._log(f"Training still running (PID: {pid}), resuming monitoring...")
890
+
891
+ last_offset = 0
892
+ while True:
893
+ exit_check = (await self._ssh_exec(ssh, f"cat {exit_file} 2>/dev/null")).strip()
894
+ if exit_check:
895
+ exit_code = int(exit_check)
896
+ remaining = await self._ssh_exec(ssh, f"tail -c +{last_offset + 1} {log_file} 2>/dev/null", timeout=30)
897
+ if remaining:
898
+ for line in remaining.split("\n"):
899
+ line = line.strip()
900
+ if line:
901
+ job._log(line)
902
+ self._parse_progress(job, line)
903
+
904
+ if exit_code == 0:
905
+ # Copy LoRA to volume
906
+ name = job.name
907
+ await self._ssh_exec(ssh, "mkdir -p /runpod-volume/loras")
908
+ remote_files = (await self._ssh_exec(ssh, "ls /workspace/output/*.safetensors 2>/dev/null")).strip()
909
+ if remote_files:
910
+ remote_output = remote_files.split("\n")[-1].strip()
911
+ await self._ssh_exec(ssh, f"cp {remote_output} /runpod-volume/loras/{name}.safetensors")
912
+ job._log(f"LoRA saved to volume: /runpod-volume/loras/{name}.safetensors")
913
+ job.output_path = f"/runpod-volume/loras/{name}.safetensors"
914
+
915
+ job.status = "completed"
916
+ job.progress = 1.0
917
+ job.completed_at = time.time()
918
+ break
919
+ else:
920
+ raise RuntimeError(f"Training failed with exit code {exit_code}")
921
+
922
+ try:
923
+ new_output = await self._ssh_exec(ssh, f"tail -c +{last_offset + 1} {log_file} 2>/dev/null", timeout=30)
924
+ if new_output:
925
+ last_offset += len(new_output.encode("utf-8"))
926
+ for line in new_output.replace("\r", "\n").split("\n"):
927
+ line = line.strip()
928
+ if line:
929
+ job._log(line)
930
+ self._parse_progress(job, line)
931
+ self._schedule_db_save(job)
932
+ except Exception:
933
+ pass
934
+ await asyncio.sleep(5)
935
+
936
+ job._log("Training complete!")
937
+
938
+ except Exception as e:
939
+ job.status = "failed"
940
+ job.error = str(e)
941
+ job._log(f"Reconnect failed: {e}")
942
+ logger.error("Training reconnect failed for %s: %s", job.id, e)
943
+
944
+ finally:
945
+ if ssh:
946
+ try:
947
+ ssh.close()
948
+ except Exception:
949
+ pass
950
+ # Terminate pod
951
+ if job.pod_id:
952
+ try:
953
+ await asyncio.to_thread(runpod.terminate_pod, job.pod_id)
954
+ job._log("Pod terminated")
955
+ except Exception:
956
+ pass
957
+ self._schedule_db_save(job)
958
+
959
  def _build_training_command(
960
  self,
961
  *,