hnpinq committed on
Commit
0125d29
·
verified ·
1 Parent(s): 7a6acce

Update test_flux.py

Browse files
Files changed (1) hide show
  1. test_flux.py +30 -3
test_flux.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import time
2
  import random
3
  import torch
@@ -6,6 +8,14 @@ from PIL import Image
6
  import sys
7
  import os
8
 
 
 
 
 
 
 
 
 
9
 
10
  import nodes
11
  from nodes import NODE_CLASS_MAPPINGS
@@ -25,6 +35,7 @@ def closestNumber(n, m):
25
 
26
  # Load models
27
  print("\n⏳ Loading models...")
 
28
 
29
  DualCLIPLoader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
30
  UNETLoader = NODE_CLASS_MAPPINGS["UNETLoader"]()
@@ -43,7 +54,8 @@ with torch.inference_mode():
43
  unet = UNETLoader.load_unet("flux1-dev-fp8.safetensors", "fp8_e4m3fn")[0]
44
  vae = VAELoader.load_vae("ae.sft")[0]
45
 
46
-
 
47
 
48
  # Inference test
49
  print("\n🚀 Starting inference...")
@@ -77,7 +89,22 @@ with torch.inference_mode():
77
  model_management.soft_empty_cache()
78
 
79
  decoded = VAEDecode.decode(vae, sample)[0].detach()
80
- image = Image.fromarray(np.array(decoded*255, dtype=np.uint8)[0])
 
 
81
  image.save("flux_output.png")
82
 
83
- inference_time = time.time() - inference_start
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
  import time
4
  import random
5
  import torch
 
8
  import sys
9
  import os
10
 
11
+ # Add TotoroUI to path
12
+ sys.path.append(os.getcwd())
13
+
14
+ print("=== FLUX Inference Timing Test ===")
15
+ print(f"CUDA available: {torch.cuda.is_available()}")
16
+ if torch.cuda.is_available():
17
+ print(f"GPU: {torch.cuda.get_device_name(0)}")
18
+ print(f"CUDA version: {torch.version.cuda}")
19
 
20
  import nodes
21
  from nodes import NODE_CLASS_MAPPINGS
 
35
 
36
  # Load models
37
  print("\n⏳ Loading models...")
38
+ start_time = time.time()
39
 
40
  DualCLIPLoader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
41
  UNETLoader = NODE_CLASS_MAPPINGS["UNETLoader"]()
 
54
  unet = UNETLoader.load_unet("flux1-dev-fp8.safetensors", "fp8_e4m3fn")[0]
55
  vae = VAELoader.load_vae("ae.sft")[0]
56
 
57
+ load_time = time.time() - load_start
58
+ print(f"✅ Models loaded in {load_time:.2f}s")
59
 
60
  # Inference test
61
  print("\n🚀 Starting inference...")
 
89
  model_management.soft_empty_cache()
90
 
91
  decoded = VAEDecode.decode(vae, sample)[0].detach()
92
+ # Fix deprecation warning
93
+ image_array = (decoded*255).clamp(0, 255).byte().cpu().numpy()[0]
94
+ image = Image.fromarray(image_array)
95
  image.save("flux_output.png")
96
 
97
+ inference_time = time.time() - inference_start
98
+ total_time = time.time() - start_time
99
+
100
+ print(f"\n⏱️ === TIMING RESULTS ===")
101
+ print(f"🔄 Model loading: {load_time:.2f}s")
102
+ print(f"🎨 Image generation: {inference_time:.2f}s")
103
+ print(f"⏱️ Total time: {total_time:.2f}s")
104
+ print(f"💾 Image saved: flux_output.png")
105
+
106
+ # Performance metrics
107
+ steps_per_sec = steps / inference_time
108
+ print(f"\n📊 Performance:")
109
+ print(f"Steps/second: {steps_per_sec:.2f}")
110
+ print(f"Time per step: {inference_time/steps:.2f}s")