chahui commited on
Commit
35c5f3f
·
verified ·
1 Parent(s): 98dc780

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -8
app.py CHANGED
@@ -5,25 +5,33 @@ import threading
5
  from pathlib import Path
6
 
7
  # === Runtime environment tweaks (must be set before importing heavy libs) ===
8
- os.environ.setdefault("CUDA_VISIBLE_DEVICES", "") # force CPU
9
- os.environ.setdefault("WORLDGEN_DISABLE_EVAL", "1") # avoid KNN/eval paths if supported
10
  os.environ.setdefault("GRADIO_TEMP_DIR", "/tmp/gradio") # centralize gradio cache
11
 
12
  import nunchaku.utils
13
- def safe_get_precision():
14
- import torch
15
- if not torch.cuda.is_available():
16
- return 'fp32'
17
- return nunchaku.utils.get_precision()
 
18
  nunchaku.utils.get_precision = safe_get_precision
19
 
 
 
 
 
 
 
 
 
20
  import gradio as gr
21
  import torch
22
  from worldgen import WorldGen
23
 
24
  # Lazy init to avoid long cold start
25
  _wg = None
26
-
27
  def get_worldgen():
28
  global _wg
29
  if _wg is None:
 
5
  from pathlib import Path
6
 
7
  # === Runtime environment tweaks (must be set before importing heavy libs) ===
8
+ os.environ.setdefault("CUDA_VISIBLE_DEVICES", "") # force CPU
9
+ os.environ.setdefault("WORLDGEN_DISABLE_EVAL", "1") # avoid KNN/eval paths if supported
10
  os.environ.setdefault("GRADIO_TEMP_DIR", "/tmp/gradio") # centralize gradio cache
11
 
12
  import nunchaku.utils
13
+
14
+ def safe_get_precision(*args, **kwargs):
15
+ # Always return fp32 on CPU to avoid any torch.cuda references
16
+ return "fp32"
17
+
18
+ # Patch module attribute
19
  nunchaku.utils.get_precision = safe_get_precision
20
 
21
+ # ALSO patch the module that imported the symbol name directly
22
+ try:
23
+ import nunchaku.models.transformers.transformer_flux as tf
24
+ tf.get_precision = safe_get_precision
25
+ except Exception:
26
+ # If the module isn't imported yet, it will still use our utils patch when it imports later.
27
+ pass
28
+
29
  import gradio as gr
30
  import torch
31
  from worldgen import WorldGen
32
 
33
  # Lazy init to avoid long cold start
34
  _wg = None
 
35
  def get_worldgen():
36
  global _wg
37
  if _wg is None: