garvitsachdeva committed on
Commit
82cc86d
·
1 Parent(s): a5c7dd0

fix: force-reinstall CUDA torch before gradio import; populate logs on page load

Browse files
Files changed (1) hide show
  1. app.py +20 -1
app.py CHANGED
@@ -14,9 +14,27 @@ Training starts automatically when the Space boots.
14
  Refresh the page or click "Refresh" to see live progress.
15
  """
16
 
17
- import sys, os
18
  print("=== PYTHON STARTED ===", flush=True)
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  import gradio as gr
21
  print("=== GRADIO IMPORTED ===", flush=True)
22
 
@@ -589,6 +607,7 @@ with gr.Blocks(title="SpindleFlow RL Training", css=CSS) as demo:
589
  )
590
 
591
  refresh_btn.click(fn=_get_state, outputs=[status_box, log_box])
 
592
  timer = gr.Timer(value=10)
593
  timer.tick(fn=_get_state, outputs=[status_box, log_box])
594
 
 
14
  Refresh the page or click "Refresh" to see live progress.
15
  """
16
 
17
+ import sys, os, subprocess
18
  print("=== PYTHON STARTED ===", flush=True)
19
 
20
+ # ── Force CUDA torch before any `import torch` happens in this process ─────────
21
+ # requirements.txt installs CPU torch as a transitive dep of sentence-transformers.
22
+ # --force-reinstall overrides "already satisfied"; --no-deps only touches torch.
23
+ # This subprocess runs before gradio (and therefore before any torch import).
24
+ print("Installing CUDA torch (force)...", flush=True)
25
+ _cuda_r = subprocess.run(
26
+ [sys.executable, "-m", "pip", "install", "-q",
27
+ "--force-reinstall", "--no-deps",
28
+ "--index-url", "https://download.pytorch.org/whl/cu121",
29
+ "torch"],
30
+ capture_output=True, text=True,
31
+ timeout=300,
32
+ )
33
+ if _cuda_r.returncode == 0:
34
+ print("CUDA torch installed OK.", flush=True)
35
+ else:
36
+ print("CUDA torch install FAILED:", _cuda_r.stderr[-400:], flush=True)
37
+
38
  import gradio as gr
39
  print("=== GRADIO IMPORTED ===", flush=True)
40
 
 
607
  )
608
 
609
  refresh_btn.click(fn=_get_state, outputs=[status_box, log_box])
610
+ demo.load(fn=_get_state, outputs=[status_box, log_box])
611
  timer = gr.Timer(value=10)
612
  timer.tick(fn=_get_state, outputs=[status_box, log_box])
613