Tianshuo-Xu committed on
Commit
e51b773
·
1 Parent(s): 39d3dc3

optimize cold start with local cache paths and font resolution

Browse files
Files changed (3) hide show
  1. app.py +9 -5
  2. inference.py +3 -1
  3. src/flux/util.py +4 -2
app.py CHANGED
@@ -86,10 +86,12 @@ _cached_model_dir = None
86
  # ============================================================
87
  _preloaded_embedding = None
88
  _preloaded_tokenizer = None
 
 
89
 
90
  def preload_model_files():
91
  """Pre-download model files to cache at startup (no GPU needed)"""
92
- global _preloaded_embedding, _preloaded_tokenizer
93
  from huggingface_hub import snapshot_download, hf_hub_download
94
 
95
  hf_token = os.environ.get("HF_TOKEN", None)
@@ -108,21 +110,23 @@ def preload_model_files():
108
 
109
  # 2. T5 text encoder
110
  try:
111
- snapshot_download(
112
  "xlabs-ai/xflux_text_encoders",
113
  token=hf_token
114
  )
115
- print("✓ T5 text encoder cached")
 
116
  except Exception as e:
117
  print(f"Warning: Could not pre-download T5: {e}")
118
 
119
  # 3. CLIP text encoder
120
  try:
121
- snapshot_download(
122
  "openai/clip-vit-large-patch14",
123
  token=hf_token
124
  )
125
- print("✓ CLIP text encoder cached")
 
126
  except Exception as e:
127
  print(f"Warning: Could not pre-download CLIP: {e}")
128
 
 
86
  # ============================================================
87
  _preloaded_embedding = None
88
  _preloaded_tokenizer = None
89
+ _cached_t5_dir = None
90
+ _cached_clip_dir = None
91
 
92
  def preload_model_files():
93
  """Pre-download model files to cache at startup (no GPU needed)"""
94
+ global _preloaded_embedding, _preloaded_tokenizer, _cached_t5_dir, _cached_clip_dir
95
  from huggingface_hub import snapshot_download, hf_hub_download
96
 
97
  hf_token = os.environ.get("HF_TOKEN", None)
 
110
 
111
  # 2. T5 text encoder
112
  try:
113
+ _cached_t5_dir = snapshot_download(
114
  "xlabs-ai/xflux_text_encoders",
115
  token=hf_token
116
  )
117
+ os.environ["XFLUX_TEXT_ENCODER_PATH"] = _cached_t5_dir
118
+ print(f"✓ T5 text encoder cached at: {_cached_t5_dir}")
119
  except Exception as e:
120
  print(f"Warning: Could not pre-download T5: {e}")
121
 
122
  # 3. CLIP text encoder
123
  try:
124
+ _cached_clip_dir = snapshot_download(
125
  "openai/clip-vit-large-patch14",
126
  token=hf_token
127
  )
128
+ os.environ["XFLUX_CLIP_ENCODER_PATH"] = _cached_clip_dir
129
+ print(f"✓ CLIP text encoder cached at: {_cached_clip_dir}")
130
  except Exception as e:
131
  print(f"Warning: Could not pre-download CLIP: {e}")
132
 
inference.py CHANGED
@@ -288,7 +288,9 @@ class CalligraphyGenerator:
288
  )
289
 
290
  # Font for generating condition images
291
- self.font_path = self._ensure_font_exists("./FangZhengKaiTiFanTi-1.ttf")
 
 
292
  self.default_font_size = 102 # 128 * 0.8
293
 
294
  def _ensure_font_exists(self, font_path: str) -> str:
 
288
  )
289
 
290
  # Font for generating condition images
291
+ project_root = os.path.dirname(os.path.abspath(__file__))
292
+ local_font_path = os.path.join(project_root, "FangZhengKaiTiFanTi-1.ttf")
293
+ self.font_path = self._ensure_font_exists(local_font_path)
294
  self.default_font_size = 102 # 128 * 0.8
295
 
296
  def _ensure_font_exists(self, font_path: str) -> str:
src/flux/util.py CHANGED
@@ -365,11 +365,13 @@ def load_controlnet(name, device, transformer=None):
365
 
366
  def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
367
  # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
368
- return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.float32).to(device)
 
369
  # return HFEmbedder("google/mt5-base", max_length=max_length, torch_dtype=torch.float32).to(device)
370
 
371
  def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
372
- return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.float32).to(device)
 
373
 
374
 
375
  def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
 
365
 
366
  def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
367
  # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
368
+ t5_source = os.environ.get("XFLUX_TEXT_ENCODER_PATH", "xlabs-ai/xflux_text_encoders")
369
+ return HFEmbedder(t5_source, max_length=max_length, torch_dtype=torch.float32).to(device)
370
  # return HFEmbedder("google/mt5-base", max_length=max_length, torch_dtype=torch.float32).to(device)
371
 
372
  def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
373
+ clip_source = os.environ.get("XFLUX_CLIP_ENCODER_PATH", "openai/clip-vit-large-patch14")
374
+ return HFEmbedder(clip_source, max_length=77, torch_dtype=torch.float32).to(device)
375
 
376
 
377
  def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder: