Txu647 committed on
Commit
414150e
·
1 Parent(s): 11050b2

Add NF4 4-bit inference with bitsandbytes

Browse files
Files changed (5) hide show
  1. README.md +2 -5
  2. app.py +21 -3
  3. inference.py +163 -19
  4. requirements.txt +1 -0
  5. src/flux/xflux_pipeline.py +91 -9
README.md CHANGED
@@ -11,14 +11,10 @@ license: cc-by-nc-nd-4.0
11
  short_description: Chinese Calligraphy Generator
12
  ---
13
 
14
- # 🖌️ UniCalli - Chinese Calligraphy Generator
15
 
16
  **A Unified Diffusion Framework for Column-Level Generation and Recognition of Chinese Calligraphy**
17
 
18
- Generate beautiful Chinese calligraphy in various styles and by different historical masters.
19
-
20
- 用不同历史书法大师的风格生成精美的中国书法。
21
-
22
  ## Links
23
 
24
  - 🌐 **Project Page**: [https://envision-research.github.io/UniCalli/](https://envision-research.github.io/UniCalli/)
@@ -33,6 +29,7 @@ Generate beautiful Chinese calligraphy in various styles and by different histor
33
  - **Historical Masters**: 90+ calligraphers including 王羲之, 颜真卿, 赵佶/宋徽宗, etc.
34
  - **Multiple Font Styles**: 楷 (Regular), 行 (Running), 草 (Cursive)
35
  - **Interactive Session**: Generate multiple images in one GPU session
 
36
 
37
  ## Usage
38
 
 
11
  short_description: Chinese Calligraphy Generator
12
  ---
13
 
14
+ # 🖌️ UniCalli-Dev - Chinese Calligraphy Generator
15
 
16
  **A Unified Diffusion Framework for Column-Level Generation and Recognition of Chinese Calligraphy**
17
 
 
 
 
 
18
  ## Links
19
 
20
  - 🌐 **Project Page**: [https://envision-research.github.io/UniCalli/](https://envision-research.github.io/UniCalli/)
 
29
  - **Historical Masters**: 90+ calligraphers including 王羲之, 颜真卿, 赵佶/宋徽宗, etc.
30
  - **Multiple Font Styles**: 楷 (Regular), 行 (Running), 草 (Cursive)
31
  - **Interactive Session**: Generate multiple images in one GPU session
32
+ - **4-bit Quantization**: Runtime quantization for efficient inference on limited GPU memory
33
 
34
  ## Usage
35
 
app.py CHANGED
@@ -68,16 +68,34 @@ def init_generator():
68
  if generator is None:
69
  # Lazy import to avoid CUDA initialization at module load time
70
  from inference import CalligraphyGenerator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  generator = CalligraphyGenerator(
72
  model_name="flux-dev",
73
  device="cuda",
74
  offload=True, # Enable offload to save GPU memory
75
- intern_vlm_path="OpenGVLab/InternVL3-1B",
76
- checkpoint_path="TSXu/Unicalli_Pro",
77
  font_descriptions_path='dataset/chirography.json',
78
  author_descriptions_path='dataset/calligraphy_styles_en.json',
79
  use_deepspeed=False,
80
- use_4bit_quantization=False, # Disabled - quantization overhead not worth it
81
  )
82
  return generator
83
 
 
68
  if generator is None:
69
  # Lazy import to avoid CUDA initialization at module load time
70
  from inference import CalligraphyGenerator
71
+ import os
72
+ from huggingface_hub import snapshot_download
73
+
74
+ # Download NF4 quantized model from HuggingFace (~6GB instead of 23GB)
75
+ hf_token = os.environ.get("HF_TOKEN", None)
76
+ print("Downloading NF4 quantized model from TSXu/Unicalli_Pro...")
77
+ local_dir = snapshot_download(
78
+ repo_id="TSXu/Unicalli_Pro",
79
+ allow_patterns=["unicalli_pro_chars7_nf4/*"],
80
+ token=hf_token
81
+ )
82
+ checkpoint_path = os.path.join(local_dir, "unicalli_pro_chars7_nf4")
83
+ intern_vlm_path = os.path.join(checkpoint_path, "internvl_embedding")
84
+
85
+ # Fallback to full InternVL3 if embedding not in NF4 folder
86
+ if not os.path.exists(intern_vlm_path):
87
+ intern_vlm_path = "OpenGVLab/InternVL3-1B"
88
+
89
  generator = CalligraphyGenerator(
90
  model_name="flux-dev",
91
  device="cuda",
92
  offload=True, # Enable offload to save GPU memory
93
+ intern_vlm_path=intern_vlm_path,
94
+ checkpoint_path=checkpoint_path, # NF4 quantized model
95
  font_descriptions_path='dataset/chirography.json',
96
  author_descriptions_path='dataset/calligraphy_styles_en.json',
97
  use_deepspeed=False,
98
+ use_4bit_quantization=True, # Use bitsandbytes for true 4-bit inference
99
  )
100
  return generator
101
 
inference.py CHANGED
@@ -284,10 +284,10 @@ class CalligraphyGenerator:
284
  def _load_model_from_checkpoint(self, checkpoint_path: str, model_name: str, offload: bool, use_deepspeed: bool = False):
285
  """
286
  Load model from checkpoint without loading flux pretrained weights.
287
- This creates an empty model, initializes module embeddings, then loads your checkpoint.
288
 
289
  Args:
290
- checkpoint_path: Path to your checkpoint file
291
  model_name: flux model name (for config)
292
  offload: whether to offload to CPU
293
  use_deepspeed: whether using DeepSpeed (keeps model on CPU)
@@ -296,8 +296,6 @@ class CalligraphyGenerator:
296
  model with loaded checkpoint
297
  """
298
  print(f"Creating empty flux model structure...")
299
- # Load checkpoint on CPU first to save memory
300
- # If using DeepSpeed, keep on CPU; otherwise move to GPU after loading
301
  load_device = "cpu"
302
 
303
  # Create model structure without loading pretrained weights (using "meta" device)
@@ -312,9 +310,16 @@ class CalligraphyGenerator:
312
  print(f"Moving model to {load_device} for loading...")
313
  model = model.to_empty(device=load_device)
314
 
 
 
 
315
  # Load checkpoint
316
  print(f"Loading checkpoint from {checkpoint_path}")
317
- checkpoint = self._load_checkpoint_file(checkpoint_path)
 
 
 
 
318
 
319
  # Determine dtype from checkpoint and convert to float32
320
  first_tensor = next(iter(checkpoint.values()))
@@ -326,36 +331,175 @@ class CalligraphyGenerator:
326
  print(f"Converting checkpoint from {checkpoint_dtype} to float32...")
327
  checkpoint = {k: v.float() for k, v in checkpoint.items()}
328
 
329
- # Load weights into model (assign=True to use checkpoint tensors directly, preserving dtype)
330
  model.load_state_dict(checkpoint, strict=False, assign=True)
331
  print(f"Model dtype after loading: {next(model.parameters()).dtype}")
 
 
 
332
 
333
- # Apply 4-bit quantization if requested
334
  if hasattr(self, 'use_4bit_quantization') and self.use_4bit_quantization:
335
- print("Applying 4-bit quantization...")
336
- model = model.float() # 先转为 float32
337
- quantize(model, weights=qint4)
338
- freeze(model)
339
- model._is_quantized = True # 添加标记供 xflux_pipeline 检查
340
- print("4-bit quantization complete!")
341
-
342
- # Move to GPU only if NOT using DeepSpeed (DeepSpeed will handle device placement)
 
 
 
 
 
 
 
343
  if not use_deepspeed:
344
- print(f"Moving model to {self.device} and converting to float32...")
345
- model = model.to(device=self.device, dtype=torch.float32)
346
 
347
  # Enable optimized attention backends
348
  try:
349
- # Prefer FlashAttention if available (fastest)
350
  torch.backends.cuda.enable_flash_sdp(True)
351
  torch.backends.cuda.enable_mem_efficient_sdp(True)
352
- torch.backends.cuda.enable_math_sdp(False) # Disable slowest fallback
353
  print("Enabled FlashAttention / Memory-Efficient SDPA backends")
354
  except Exception as e:
355
  print(f"Could not configure SDPA backends: {e}")
356
 
357
  return model
358
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359
  def _init_deepspeed(self, model):
360
  """
361
  Initialize DeepSpeed for the model with ZeRO-3 inference optimization.
 
284
  def _load_model_from_checkpoint(self, checkpoint_path: str, model_name: str, offload: bool, use_deepspeed: bool = False):
285
  """
286
  Load model from checkpoint without loading flux pretrained weights.
287
+ Supports both regular checkpoints and NF4 quantized checkpoints.
288
 
289
  Args:
290
+ checkpoint_path: Path to your checkpoint file or NF4 model directory
291
  model_name: flux model name (for config)
292
  offload: whether to offload to CPU
293
  use_deepspeed: whether using DeepSpeed (keeps model on CPU)
 
296
  model with loaded checkpoint
297
  """
298
  print(f"Creating empty flux model structure...")
 
 
299
  load_device = "cpu"
300
 
301
  # Create model structure without loading pretrained weights (using "meta" device)
 
310
  print(f"Moving model to {load_device} for loading...")
311
  model = model.to_empty(device=load_device)
312
 
313
+ # Check if this is an NF4 quantized model
314
+ is_nf4 = self._is_nf4_checkpoint(checkpoint_path)
315
+
316
  # Load checkpoint
317
  print(f"Loading checkpoint from {checkpoint_path}")
318
+ if is_nf4:
319
+ print("Detected NF4 quantized model, dequantizing...")
320
+ checkpoint = self._load_nf4_checkpoint(checkpoint_path)
321
+ else:
322
+ checkpoint = self._load_checkpoint_file(checkpoint_path)
323
 
324
  # Determine dtype from checkpoint and convert to float32
325
  first_tensor = next(iter(checkpoint.values()))
 
331
  print(f"Converting checkpoint from {checkpoint_dtype} to float32...")
332
  checkpoint = {k: v.float() for k, v in checkpoint.items()}
333
 
334
+ # Load weights into model
335
  model.load_state_dict(checkpoint, strict=False, assign=True)
336
  print(f"Model dtype after loading: {next(model.parameters()).dtype}")
337
+
338
+ # Free checkpoint memory
339
+ del checkpoint
340
 
341
+ # Apply bitsandbytes 4-bit quantization if requested
342
  if hasattr(self, 'use_4bit_quantization') and self.use_4bit_quantization:
343
+ try:
344
+ import bitsandbytes as bnb
345
+ print("Applying bitsandbytes NF4 quantization for 4-bit inference...")
346
+ model = self._quantize_model_bnb(model)
347
+ model._is_quantized = True
348
+ print("bitsandbytes NF4 quantization complete!")
349
+ except ImportError:
350
+ print("bitsandbytes not available, using quanto quantization...")
351
+ model = model.float()
352
+ quantize(model, weights=qint4)
353
+ freeze(model)
354
+ model._is_quantized = True
355
+ print("quanto 4-bit quantization complete!")
356
+
357
+ # Move to GPU only if NOT using DeepSpeed
358
  if not use_deepspeed:
359
+ print(f"Moving model to {self.device}...")
360
+ model = model.to(self.device)
361
 
362
  # Enable optimized attention backends
363
  try:
 
364
  torch.backends.cuda.enable_flash_sdp(True)
365
  torch.backends.cuda.enable_mem_efficient_sdp(True)
366
+ torch.backends.cuda.enable_math_sdp(False)
367
  print("Enabled FlashAttention / Memory-Efficient SDPA backends")
368
  except Exception as e:
369
  print(f"Could not configure SDPA backends: {e}")
370
 
371
  return model
372
 
373
+ def _is_nf4_checkpoint(self, path: str) -> bool:
374
+ """Check if path contains an NF4 quantized checkpoint"""
375
+ if os.path.isdir(path):
376
+ return os.path.exists(os.path.join(path, "quantization_config.json"))
377
+ return False
378
+
379
+ def _load_nf4_checkpoint(self, checkpoint_dir: str) -> dict:
380
+ """
381
+ Load NF4 quantized checkpoint and dequantize to float tensors.
382
+
383
+ Args:
384
+ checkpoint_dir: Directory containing NF4 model files
385
+
386
+ Returns:
387
+ Dequantized state dict
388
+ """
389
+ from safetensors.torch import load_file as load_safetensors
390
+
391
+ # Load quantization config
392
+ config_path = os.path.join(checkpoint_dir, "quantization_config.json")
393
+ with open(config_path, 'r') as f:
394
+ quant_config = json.load(f)
395
+
396
+ block_size = quant_config.get("block_size", 64)
397
+ quantized_keys = set(quant_config.get("quantized_keys", []))
398
+
399
+ # Load index
400
+ index_path = os.path.join(checkpoint_dir, "model_nf4.safetensors.index.json")
401
+ with open(index_path, 'r') as f:
402
+ index = json.load(f)
403
+
404
+ # Load all shards
405
+ shard_files = sorted(set(index['weight_map'].values()))
406
+ print(f"Loading {len(shard_files)} NF4 shards...")
407
+
408
+ raw_state = {}
409
+ for shard_file in shard_files:
410
+ shard_path = os.path.join(checkpoint_dir, shard_file)
411
+ print(f" Loading {shard_file}...")
412
+ shard_data = load_safetensors(shard_path)
413
+ raw_state.update(shard_data)
414
+
415
+ # NF4 lookup table for dequantization
416
+ nf4_values = torch.tensor([
417
+ -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
418
+ -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
419
+ 0.07958029955625534, 0.16093020141124725, 0.24611230850220, 0.33791524171829224,
420
+ 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0
421
+ ], dtype=torch.float32)
422
+
423
+ # Dequantize
424
+ state_dict = {}
425
+ dequant_count = 0
426
+
427
+ for key in list(raw_state.keys()):
428
+ if key.endswith('.quant_data'):
429
+ base_key = key.replace('.quant_data', '')
430
+ if base_key in quantized_keys:
431
+ # Dequantize this tensor
432
+ quant_data = raw_state[f"{base_key}.quant_data"]
433
+ scales = raw_state[f"{base_key}.scales"]
434
+ shape = raw_state[f"{base_key}.shape"].tolist()
435
+ pad_len = raw_state[f"{base_key}.pad_len"].item()
436
+
437
+ # Unpack 4-bit values
438
+ high = (quant_data >> 4) & 0x0F
439
+ low = quant_data & 0x0F
440
+ indices = torch.stack([high, low], dim=-1).flatten().long()
441
+
442
+ # Lookup and reshape
443
+ values = nf4_values[indices]
444
+
445
+ # Apply scales
446
+ num_blocks = len(scales)
447
+ values = values[:num_blocks * block_size].reshape(num_blocks, block_size)
448
+ values = values * scales.float().unsqueeze(1)
449
+ values = values.flatten()
450
+
451
+ # Remove padding and reshape
452
+ if pad_len > 0:
453
+ values = values[:-pad_len]
454
+
455
+ state_dict[base_key] = values.reshape(shape)
456
+ dequant_count += 1
457
+ elif not any(key.endswith(s) for s in ['.scales', '.shape', '.block_size', '.pad_len']):
458
+ # Non-quantized tensor, keep as-is
459
+ state_dict[key] = raw_state[key]
460
+
461
+ print(f"Dequantized {dequant_count} tensors")
462
+ return state_dict
463
+
464
+ def _quantize_model_bnb(self, model):
465
+ """
466
+ Quantize model using bitsandbytes NF4.
467
+ Replaces Linear layers with Linear4bit for true 4-bit inference.
468
+ """
469
+ import bitsandbytes as bnb
470
+ import torch.nn as nn
471
+
472
+ def replace_linear_with_4bit(module, name=''):
473
+ for child_name, child in list(module.named_children()):
474
+ full_name = f"{name}.{child_name}" if name else child_name
475
+
476
+ if isinstance(child, nn.Linear):
477
+ # Create 4-bit linear layer
478
+ new_layer = bnb.nn.Linear4bit(
479
+ child.in_features,
480
+ child.out_features,
481
+ bias=child.bias is not None,
482
+ compute_dtype=torch.bfloat16,
483
+ compress_statistics=True,
484
+ quant_type='nf4'
485
+ )
486
+ # Copy weights (will be quantized when moved to GPU)
487
+ new_layer.weight = bnb.nn.Params4bit(
488
+ child.weight.data,
489
+ requires_grad=False,
490
+ quant_type='nf4'
491
+ )
492
+ if child.bias is not None:
493
+ new_layer.bias = nn.Parameter(child.bias.data)
494
+
495
+ setattr(module, child_name, new_layer)
496
+ else:
497
+ replace_linear_with_4bit(child, full_name)
498
+
499
+ print("Replacing Linear layers with Linear4bit...")
500
+ replace_linear_with_4bit(model)
501
+ return model
502
+
503
  def _init_deepspeed(self, model):
504
  """
505
  Initialize DeepSpeed for the model with ZeRO-3 inference optimization.
requirements.txt CHANGED
@@ -7,6 +7,7 @@ safetensors>=0.4.0
7
 
8
  # Model and inference
9
  optimum-quanto
 
10
  torch
11
  torchvision
12
  timm
 
7
 
8
  # Model and inference
9
  optimum-quanto
10
+ bitsandbytes>=0.41.0
11
  torch
12
  torchvision
13
  timm
src/flux/xflux_pipeline.py CHANGED
@@ -460,12 +460,94 @@ class XFluxSampler(XFluxPipeline):
460
  self.offload = False
461
  self.ref_latent = ref_latent
462
 
463
- self.embed_tokens = AutoModel.from_pretrained(
464
- intern_vlm_path,
465
- torch_dtype=torch.float32,
466
- device_map="cpu",
467
- trust_remote_code=True
468
- ).language_model.model.embed_tokens.eval()
469
- self.embed_tokens.requires_grad_(False)
470
- self.tokenizer = AutoTokenizer.from_pretrained(
471
- intern_vlm_path, trust_remote_code=True, use_fast=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460
  self.offload = False
461
  self.ref_latent = ref_latent
462
 
463
+ # Load embedding - try lightweight extracted version first, fallback to full model
464
+ self.embed_tokens, self.tokenizer = self._load_embedding(intern_vlm_path)
465
+
466
+ def _load_embedding(self, intern_vlm_path):
467
+ """
468
+ Load embedding layer and tokenizer.
469
+ Supports three modes:
470
+ 1. HuggingFace repo with internvl_embedding subfolder (e.g., TSXu/Unicalli_Pro)
471
+ 2. Lightweight: Load from extracted embedding files (embedding.safetensors + tokenizer)
472
+ 3. Full: Load from complete InternVL3 model (fallback)
473
+ """
474
+ import os
475
+ from safetensors.torch import load_file as load_safetensors
476
+
477
+ # Check if this is a HuggingFace model ID (contains '/' but not a local path)
478
+ if '/' in intern_vlm_path and not os.path.exists(intern_vlm_path):
479
+ print(f"Downloading internvl_embedding from HuggingFace: {intern_vlm_path}")
480
+ from huggingface_hub import snapshot_download
481
+ hf_token = os.environ.get("HF_TOKEN", None)
482
+
483
+ # Download only the internvl_embedding subfolder
484
+ local_dir = snapshot_download(
485
+ repo_id=intern_vlm_path,
486
+ allow_patterns=["internvl_embedding/*", "unicalli_pro_chars7_nf4/internvl_embedding/*"],
487
+ token=hf_token
488
+ )
489
+
490
+ # Check for internvl_embedding in different locations
491
+ possible_paths = [
492
+ os.path.join(local_dir, "internvl_embedding"),
493
+ os.path.join(local_dir, "unicalli_pro_chars7_nf4", "internvl_embedding"),
494
+ ]
495
+
496
+ for path in possible_paths:
497
+ if os.path.exists(path):
498
+ intern_vlm_path = path
499
+ print(f"Found internvl_embedding at: {intern_vlm_path}")
500
+ break
501
+ else:
502
+ print(f"Warning: internvl_embedding not found, falling back to full model")
503
+
504
+ # Check if this is an extracted embedding directory
505
+ embedding_file = os.path.join(intern_vlm_path, "embedding.safetensors")
506
+ config_file = os.path.join(intern_vlm_path, "embedding_config.json")
507
+
508
+ if os.path.exists(embedding_file) and os.path.exists(config_file):
509
+ # Lightweight mode: Load extracted embedding
510
+ print(f"Loading lightweight embedding from: {intern_vlm_path}")
511
+
512
+ import json
513
+ with open(config_file, 'r') as f:
514
+ config = json.load(f)
515
+
516
+ # Create embedding layer
517
+ embed_tokens = torch.nn.Embedding(
518
+ num_embeddings=config["num_embeddings"],
519
+ embedding_dim=config["embedding_dim"],
520
+ padding_idx=config.get("padding_idx", None)
521
+ )
522
+
523
+ # Load weights
524
+ state_dict = load_safetensors(embedding_file)
525
+ embed_tokens.load_state_dict(state_dict)
526
+ embed_tokens.eval()
527
+ embed_tokens.requires_grad_(False)
528
+
529
+ # Load tokenizer
530
+ tokenizer = AutoTokenizer.from_pretrained(
531
+ intern_vlm_path, trust_remote_code=True, use_fast=False
532
+ )
533
+
534
+ print(f"Loaded lightweight embedding: {config['num_embeddings']} x {config['embedding_dim']}")
535
+ return embed_tokens, tokenizer
536
+
537
+ else:
538
+ # Full mode: Load from complete InternVL3 model
539
+ print(f"Loading full InternVL3 model from: {intern_vlm_path}")
540
+
541
+ embed_tokens = AutoModel.from_pretrained(
542
+ intern_vlm_path,
543
+ torch_dtype=torch.float32,
544
+ device_map="cpu",
545
+ trust_remote_code=True
546
+ ).language_model.model.embed_tokens.eval()
547
+ embed_tokens.requires_grad_(False)
548
+
549
+ tokenizer = AutoTokenizer.from_pretrained(
550
+ intern_vlm_path, trust_remote_code=True, use_fast=False
551
+ )
552
+
553
+ return embed_tokens, tokenizer