fantos commited on
Commit
210bfc3
·
verified ·
1 Parent(s): 7a2ba9b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -20
app.py CHANGED
@@ -1,19 +1,24 @@
1
  import spaces
2
- import argparse
3
  import os
4
  import time
5
  from os import path
6
- from safetensors.torch import load_file
7
- from huggingface_hub import hf_hub_download
8
 
 
9
  cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
10
- os.environ["TRANSFORMERS_CACHE"] = cache_path
11
- os.environ["HF_HUB_CACHE"] = cache_path
12
  os.environ["HF_HOME"] = cache_path
13
 
14
- import gradio as gr
15
- import torch
16
- from diffusers import FluxPipeline
 
 
 
 
 
 
 
17
 
18
  torch.backends.cuda.matmul.allow_tf32 = True
19
 
@@ -30,9 +35,34 @@ class timer:
30
  if not path.exists(cache_path):
31
  os.makedirs(cache_path, exist_ok=True)
32
 
33
- pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
34
- pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
35
- pipe.fuse_lora(lora_scale=0.125)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  pipe.to(device="cuda", dtype=torch.bfloat16)
37
 
38
  # Define example prompts
@@ -192,15 +222,29 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
192
  def process_image(height, width, steps, scales, prompt, seed):
193
  global pipe
194
  with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
195
- return pipe(
196
- prompt=[prompt],
197
- generator=torch.Generator().manual_seed(int(seed)),
198
- num_inference_steps=int(steps),
199
- guidance_scale=float(scales),
200
- height=int(height),
201
- width=int(width),
202
- max_sequence_length=256
203
- ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
  generate_btn.click(
206
  process_image,
 
1
  import spaces
 
2
  import os
3
  import time
4
  from os import path
5
+ import gradio as gr
6
+ import torch
7
 
8
+ # Set cache paths before importing transformers/diffusers
9
  cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
 
 
10
  os.environ["HF_HOME"] = cache_path
11
 
12
+ # Import with proper error handling
13
+ try:
14
+ from diffusers import DiffusionPipeline
15
+ from diffusers.models import FluxTransformer2DModel
16
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
17
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
18
+ except ImportError as e:
19
+ print(f"Import error: {e}")
20
+ # Fallback to DiffusionPipeline if FluxPipeline is not available
21
+ from diffusers import DiffusionPipeline
22
 
23
  torch.backends.cuda.matmul.allow_tf32 = True
24
 
 
35
  if not path.exists(cache_path):
36
  os.makedirs(cache_path, exist_ok=True)
37
 
38
+ # Use DiffusionPipeline as a more stable alternative
39
+ try:
40
+ # Try to load as FluxPipeline first
41
+ pipe = DiffusionPipeline.from_pretrained(
42
+ "black-forest-labs/FLUX.1-dev",
43
+ torch_dtype=torch.bfloat16,
44
+ use_safetensors=True
45
+ )
46
+ except Exception as e:
47
+ print(f"Error loading FLUX pipeline: {e}")
48
+ # Fallback to a simpler configuration
49
+ pipe = DiffusionPipeline.from_pretrained(
50
+ "black-forest-labs/FLUX.1-dev",
51
+ torch_dtype=torch.float16,
52
+ safety_checker=None,
53
+ requires_safety_checker=False
54
+ )
55
+
56
+ # Try to load LoRA weights with error handling
57
+ try:
58
+ from huggingface_hub import hf_hub_download
59
+ lora_path = hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
60
+ pipe.load_lora_weights(lora_path)
61
+ pipe.fuse_lora(lora_scale=0.125)
62
+ except Exception as e:
63
+ print(f"Warning: Could not load LoRA weights: {e}")
64
+ print("Continuing without LoRA acceleration...")
65
+
66
  pipe.to(device="cuda", dtype=torch.bfloat16)
67
 
68
  # Define example prompts
 
222
  def process_image(height, width, steps, scales, prompt, seed):
223
  global pipe
224
  with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
225
+ try:
226
+ # Try the standard call
227
+ result = pipe(
228
+ prompt=[prompt],
229
+ generator=torch.Generator().manual_seed(int(seed)),
230
+ num_inference_steps=int(steps),
231
+ guidance_scale=float(scales),
232
+ height=int(height),
233
+ width=int(width),
234
+ max_sequence_length=256
235
+ )
236
+ except TypeError:
237
+ # Fallback for different pipeline signatures
238
+ result = pipe(
239
+ prompt=prompt,
240
+ generator=torch.Generator().manual_seed(int(seed)),
241
+ num_inference_steps=int(steps),
242
+ guidance_scale=float(scales),
243
+ height=int(height),
244
+ width=int(width)
245
+ )
246
+
247
+ return result.images[0]
248
 
249
  generate_btn.click(
250
  process_image,