TSXu committed on
Commit
a895d85
·
1 Parent(s): d9765a4

Remove AOT compilation code, keep FA3 + FP8 only

Browse files

- Removed compile_model_first_time()
- Removed AOT cache functions (_check_compiled_graph_exists, _load_compiled_graph, _upload_compiled_graph)
- Removed INDUCTOR_CONFIGS
- Simplified logging setup
- FA3 + FP8 quantization is fast enough without pre-compilation

Files changed (2) hide show
  1. FLUX-Kontext-fp8 +0 -1
  2. app.py +6 -239
FLUX-Kontext-fp8 DELETED
@@ -1 +0,0 @@
1
- Subproject commit 1588a5618e83f18d291920de2b399d530edf8dbc
 
 
app.py CHANGED
@@ -1,45 +1,23 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
  Gradio Demo for Chinese Calligraphy Generation - HuggingFace Space Version
4
- With Float8 quantization and AOT compilation for faster inference
5
  """
6
 
7
- # Install compatible torch 2.8 + torchvision 0.23 + torchao + spaces (for AOT compilation)
8
- # spaces.aoti_capture requires PyTorch 2.8+
9
  import os
10
  import sys
11
  import logging
12
- import traceback
13
  from datetime import datetime
14
 
15
- # Setup logging to file
16
- LOG_FILE = "aot_compile.log"
17
  logging.basicConfig(
18
- level=logging.DEBUG,
19
  format='%(asctime)s [%(levelname)s] %(message)s',
20
- handlers=[
21
- logging.FileHandler(LOG_FILE, mode='w', encoding='utf-8'),
22
- logging.StreamHandler(sys.stdout)
23
- ]
24
  )
25
  logger = logging.getLogger(__name__)
26
 
27
- # Also redirect print to log
28
- class LoggingPrinter:
29
- def __init__(self, logger, original_stdout):
30
- self.logger = logger
31
- self.original_stdout = original_stdout
32
- def write(self, message):
33
- if message.strip():
34
- self.logger.info(message.strip())
35
- self.original_stdout.write(message)
36
- def flush(self):
37
- self.original_stdout.flush()
38
-
39
- # Keep original stdout for gradio
40
- _original_stdout = sys.stdout
41
-
42
- # Install compatible nightly versions - let pip resolve the exact matching versions
43
  os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 torch torchvision torchao spaces')
44
  logger.info("torch + torchvision + torchao + spaces (nightly) installation complete!")
45
 
@@ -100,7 +78,6 @@ except:
100
  # Global generator instance
101
  generator = None
102
  _cached_model_dir = None
103
- _is_optimized = False
104
 
105
  # ============================================================
106
  # Pre-download model files at startup (no GPU needed)
@@ -160,94 +137,10 @@ print("="*50)
160
 
161
 
162
  # ============================================================
163
- # AOT Optimization Configuration (from FLUX-Kontext-fp8)
164
  # ============================================================
165
- from torch.utils._pytree import tree_map_only
166
- # FP8 quantization for faster inference (works with FA3)
167
  from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig
168
 
169
- # Inductor configuration for optimal performance
170
- INDUCTOR_CONFIGS = {
171
- 'conv_1x1_as_mm': True,
172
- 'epilogue_fusion': False,
173
- 'coordinate_descent_tuning': True,
174
- 'coordinate_descent_check_all_directions': True,
175
- 'max_autotune': True,
176
- 'triton.cudagraphs': True,
177
- }
178
-
179
- # ============================================================
180
- # AOT Compiled Graph Caching (save to / load from HF Hub)
181
- # ============================================================
182
- HF_CACHE_REPO = "TSXu/Unicalli_Pro"
183
- HF_CACHE_FILENAME = "compiled_graph.pt2"
184
-
185
-
186
- def _check_compiled_graph_exists():
187
- """Check if compiled graph exists on HF Hub (fast check)"""
188
- from huggingface_hub import hf_hub_url, get_hf_file_metadata
189
- try:
190
- url = hf_hub_url(HF_CACHE_REPO, HF_CACHE_FILENAME)
191
- get_hf_file_metadata(url) # Raises if file doesn't exist
192
- return True
193
- except Exception:
194
- return False
195
-
196
-
197
- def _load_compiled_graph(model):
198
- """Load compiled graph from HF Hub using ZeroGPU internals"""
199
- from huggingface_hub import hf_hub_download
200
- from spaces.zero.torch.aoti import ZeroGPUCompiledModel, ZeroGPUWeights, drain_module_parameters
201
-
202
- logger.info(f"Downloading compiled graph from {HF_CACHE_REPO}/{HF_CACHE_FILENAME}...")
203
- compiled_graph_file = hf_hub_download(HF_CACHE_REPO, HF_CACHE_FILENAME)
204
- logger.info(f"✓ Downloaded to: {compiled_graph_file}")
205
-
206
- logger.info("Loading compiled graph into model...")
207
- state_dict = model.state_dict()
208
- zerogpu_weights = ZeroGPUWeights({name: weight for name, weight in state_dict.items()})
209
- compiled = ZeroGPUCompiledModel(compiled_graph_file, zerogpu_weights)
210
-
211
- # Replace forward method
212
- setattr(model, "forward", compiled)
213
- drain_module_parameters(model)
214
- logger.info("✓ Compiled graph loaded and applied!")
215
- return True
216
-
217
-
218
- def _upload_compiled_graph(compiled):
219
- """Upload compiled graph to HF Hub"""
220
- from huggingface_hub import upload_file
221
- import tempfile
222
-
223
- hf_token = os.environ.get("HF_TOKEN")
224
- if not hf_token:
225
- logger.warning("HF_TOKEN not set, cannot upload compiled graph")
226
- return False
227
-
228
- logger.info(f"Uploading compiled graph to {HF_CACHE_REPO}/{HF_CACHE_FILENAME}...")
229
-
230
- # Save archive to temp file
231
- with tempfile.NamedTemporaryFile(suffix=".pt2", delete=False) as f:
232
- f.write(compiled.archive_file.getvalue())
233
- temp_path = f.name
234
-
235
- try:
236
- upload_file(
237
- path_or_fileobj=temp_path,
238
- path_in_repo=HF_CACHE_FILENAME,
239
- repo_id=HF_CACHE_REPO,
240
- token=hf_token,
241
- commit_message="Upload AOT compiled graph",
242
- )
243
- logger.info("✓ Compiled graph uploaded to Hub!")
244
- return True
245
- except Exception as e:
246
- logger.error(f"Failed to upload compiled graph: {e}")
247
- return False
248
- finally:
249
- os.unlink(temp_path)
250
-
251
 
252
  def init_generator():
253
  """Initialize the generator (without optimization - that's done separately)"""
@@ -312,132 +205,6 @@ def parse_font_style(font_style: str) -> str:
312
  return None
313
 
314
 
315
- @spaces.GPU(duration=900) # 15 min for compilation (if needed)
316
- def compile_model_first_time():
317
- """
318
- First-time: Load model and either load cached compiled graph or compile from scratch.
319
- Compiled graph is cached on HF Hub for fast subsequent cold starts.
320
- """
321
- global _is_optimized, generator
322
-
323
- logger.info("="*50)
324
- logger.info("First-time run: Loading model...")
325
- logger.info("="*50)
326
-
327
- try:
328
- # Load model
329
- gen = init_generator()
330
- model = gen.model
331
-
332
- # Check if compiled graph exists on Hub
333
- logger.info("Checking for cached compiled graph on HF Hub...")
334
- if _check_compiled_graph_exists():
335
- logger.info("="*50)
336
- logger.info("Found cached compiled graph! Loading from Hub...")
337
- logger.info("="*50)
338
- _load_compiled_graph(model)
339
- _is_optimized = True
340
- logger.info("✓ Model loaded with cached compiled graph!")
341
- logger.info("="*50)
342
- return None
343
-
344
- # No cached graph, compile from scratch
345
- logger.info("="*50)
346
- logger.info("No cached graph found. Compiling from scratch...")
347
- logger.info("="*50)
348
-
349
- # Step 1: Capture model forward during a real inference
350
- logger.info("Step 1: Capturing model forward pass...")
351
- with spaces.aoti_capture(model) as call:
352
- gen.generate(
353
- text="测试长度等于七",
354
- font_style="楷",
355
- author=None,
356
- num_steps=1,
357
- seed=42,
358
- )
359
- logger.info("✓ Forward pass captured!")
360
-
361
- # Log call info
362
- logger.info(f" call.args types: {[type(a).__name__ for a in call.args]}")
363
- logger.info(f" call.kwargs keys: {list(call.kwargs.keys())}")
364
- for k, v in call.kwargs.items():
365
- if hasattr(v, 'shape'):
366
- logger.info(f" {k}: tensor shape={v.shape}, dtype={v.dtype}")
367
- else:
368
- logger.info(f" {k}: {type(v).__name__} = {v}")
369
-
370
- # Step 2: Build dynamic_shapes (all static)
371
- logger.info("Step 2: Building static shapes...")
372
- dynamic_shapes = {}
373
- for k, v in call.kwargs.items():
374
- dynamic_shapes[k] = None # Static shape for all
375
- logger.info(f" dynamic_shapes: {dynamic_shapes}")
376
- logger.info("✓ Static shapes configured!")
377
-
378
- # Step 3: Disable gradients on model
379
- logger.info("Step 3: Disabling gradients on model...")
380
- model.eval()
381
- model.requires_grad_(False)
382
- logger.info("✓ Model in eval mode with gradients disabled!")
383
-
384
- # Step 4: Detach inputs
385
- logger.info("Step 4: Detaching inputs...")
386
- detached_args = tuple(
387
- a.detach() if isinstance(a, torch.Tensor) else a for a in call.args
388
- )
389
- detached_kwargs = {
390
- k: v.detach() if isinstance(v, torch.Tensor) else v
391
- for k, v in call.kwargs.items()
392
- }
393
- logger.info("✓ Inputs detached!")
394
-
395
- # Step 5: Export model
396
- logger.info("Step 5: Exporting model with torch.export.export...")
397
- exported = torch.export.export(
398
- mod=model,
399
- args=detached_args,
400
- kwargs=detached_kwargs,
401
- dynamic_shapes=dynamic_shapes,
402
- )
403
- logger.info("✓ Model exported!")
404
-
405
- # Step 6: AOT compile
406
- logger.info("Step 6: AOT compiling with spaces.aoti_compile...")
407
- logger.info(f" Inductor configs: {INDUCTOR_CONFIGS}")
408
- compiled = spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
409
- logger.info("✓ AOT compilation complete!")
410
-
411
- # Step 7: Upload compiled graph to Hub
412
- logger.info("Step 7: Uploading compiled graph to Hub...")
413
- _upload_compiled_graph(compiled)
414
-
415
- # Step 8: Apply compiled model
416
- logger.info("Step 8: Applying compiled model...")
417
- spaces.aoti_apply(compiled, model)
418
- logger.info("✓ AOT compiled model applied!")
419
-
420
- _is_optimized = True
421
- logger.info("="*50)
422
- logger.info("✓ Model compiled and cached to Hub!")
423
- logger.info("="*50)
424
-
425
- except Exception as e:
426
- logger.error("="*50)
427
- logger.error("AOT COMPILATION FAILED!")
428
- logger.error("="*50)
429
- logger.error(f"Exception: {e}")
430
- logger.error("Full traceback:")
431
- logger.error(traceback.format_exc())
432
- with open("aot_error.log", "w") as f:
433
- f.write(f"Exception: {e}\n\n")
434
- f.write(traceback.format_exc())
435
- raise
436
-
437
- # NOTE: Don't return gen - causes pickle error in ZeroGPU multiprocessing
438
- return None
439
-
440
-
441
  def _get_generation_duration(text, font, author, num_steps, start_seed, num_images):
442
  """Calculate dynamic GPU duration: 20s loading + 1.5s per step per image"""
443
  return 20 + int(1.5 * num_steps * num_images)
 
1
  # -*- coding: utf-8 -*-
2
  """
3
  Gradio Demo for Chinese Calligraphy Generation - HuggingFace Space Version
4
+ With FA3 + FP8 quantization for faster inference
5
  """
6
 
 
 
7
  import os
8
  import sys
9
  import logging
 
10
  from datetime import datetime
11
 
12
+ # Setup logging
 
13
  logging.basicConfig(
14
+ level=logging.INFO,
15
  format='%(asctime)s [%(levelname)s] %(message)s',
16
+ handlers=[logging.StreamHandler(sys.stdout)]
 
 
 
17
  )
18
  logger = logging.getLogger(__name__)
19
 
20
+ # Install nightly versions for FA3 + FP8 support
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 torch torchvision torchao spaces')
22
  logger.info("torch + torchvision + torchao + spaces (nightly) installation complete!")
23
 
 
78
  # Global generator instance
79
  generator = None
80
  _cached_model_dir = None
 
81
 
82
  # ============================================================
83
  # Pre-download model files at startup (no GPU needed)
 
137
 
138
 
139
  # ============================================================
140
+ # FP8 Quantization (works with FA3)
141
  # ============================================================
 
 
142
  from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  def init_generator():
146
  """Initialize the generator (without optimization - that's done separately)"""
 
205
  return None
206
 
207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  def _get_generation_duration(text, font, author, num_steps, start_seed, num_images):
209
  """Calculate dynamic GPU duration: 20s loading + 1.5s per step per image"""
210
  return 20 + int(1.5 * num_steps * num_images)