OrlandoHugBot commited on
Commit
1c02ce0
·
verified ·
1 Parent(s): 0427012

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -15
app.py CHANGED
@@ -44,7 +44,7 @@ def get_dtype():
44
  """Get the appropriate dtype"""
45
  return torch.bfloat16 if torch.cuda.is_available() else torch.float32
46
 
47
- @GPU(duration=120)
48
  def generate_image(
49
  images: list[Image.Image],
50
  prompt: str,
@@ -79,25 +79,27 @@ def generate_image(
79
  MODEL_NAME, subfolder='scheduler'
80
  )
81
 
82
- # Load text encoder
 
83
  text_encoder = AutoModel.from_pretrained(
84
  MODEL_NAME,
85
  subfolder='text_encoder',
86
- torch_dtype=dtype
87
- ).to(device)
 
88
 
89
  # Load tokenizer & processor
90
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, subfolder='tokenizer')
91
  processor = Qwen2VLProcessor.from_pretrained(MODEL_NAME, subfolder='processor')
92
 
93
- # Load transformer
94
- transformer = load_transformer(device, dtype)
95
 
96
  # Load VAE
97
  vae = AutoencoderKLQwenImage.from_pretrained(
98
  MODEL_NAME,
99
  subfolder='vae',
100
- torch_dtype=dtype
101
  ).to(device)
102
 
103
  # Create pipeline
@@ -146,12 +148,14 @@ def generate_image(
146
  return result
147
 
148
 
149
- def load_transformer(device, dtype):
150
- """Load transformer with proper path handling"""
151
  from diffusers import QwenImageTransformer2DModel
152
 
 
 
153
  if os.path.exists(TRANSFORMER_PATH):
154
- # Local path
155
  if os.path.isdir(TRANSFORMER_PATH):
156
  config_path = os.path.join(TRANSFORMER_PATH, "config.json")
157
  if os.path.exists(config_path):
@@ -169,7 +173,7 @@ def load_transformer(device, dtype):
169
  ).to(device)
170
  raise ValueError(f"Invalid transformer path: {TRANSFORMER_PATH}")
171
  else:
172
- # HuggingFace repo path
173
  path_parts = TRANSFORMER_PATH.split('/')
174
  if len(path_parts) >= 3:
175
  repo_id = '/'.join(path_parts[:2])
@@ -177,14 +181,16 @@ def load_transformer(device, dtype):
177
  return QwenImageTransformer2DModel.from_pretrained(
178
  repo_id,
179
  subfolder=subfolder,
180
- torch_dtype=dtype
181
- ).to(device)
 
182
  else:
183
  return QwenImageTransformer2DModel.from_pretrained(
184
  TRANSFORMER_PATH,
185
  subfolder='transformer',
186
- torch_dtype=dtype
187
- ).to(device)
 
188
 
189
 
190
  # ============================================================
 
44
  """Get the appropriate dtype"""
45
  return torch.bfloat16 if torch.cuda.is_available() else torch.float32
46
 
47
+ @GPU(duration=180)
48
  def generate_image(
49
  images: list[Image.Image],
50
  prompt: str,
 
79
  MODEL_NAME, subfolder='scheduler'
80
  )
81
 
82
+ # Load text encoder - use device_map="cuda" for ZeroGPU compatibility
83
+ # This ensures all submodules are properly placed on the GPU
84
  text_encoder = AutoModel.from_pretrained(
85
  MODEL_NAME,
86
  subfolder='text_encoder',
87
+ torch_dtype=dtype,
88
+ device_map="cuda" # Let transformers handle device placement for ZeroGPU
89
+ )
90
 
91
  # Load tokenizer & processor
92
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, subfolder='tokenizer')
93
  processor = Qwen2VLProcessor.from_pretrained(MODEL_NAME, subfolder='processor')
94
 
95
+ # Load transformer - also use device_map for consistency
96
+ transformer = load_transformer(dtype)
97
 
98
  # Load VAE
99
  vae = AutoencoderKLQwenImage.from_pretrained(
100
  MODEL_NAME,
101
  subfolder='vae',
102
+ torch_dtype=dtype,
103
  ).to(device)
104
 
105
  # Create pipeline
 
148
  return result
149
 
150
 
151
+ def load_transformer(dtype):
152
+ """Load transformer with proper path handling for ZeroGPU"""
153
  from diffusers import QwenImageTransformer2DModel
154
 
155
+ device = get_device()
156
+
157
  if os.path.exists(TRANSFORMER_PATH):
158
+ # Local path - for ZeroGPU, still use .to(device) for local files
159
  if os.path.isdir(TRANSFORMER_PATH):
160
  config_path = os.path.join(TRANSFORMER_PATH, "config.json")
161
  if os.path.exists(config_path):
 
173
  ).to(device)
174
  raise ValueError(f"Invalid transformer path: {TRANSFORMER_PATH}")
175
  else:
176
+ # HuggingFace repo path - use device_map for ZeroGPU
177
  path_parts = TRANSFORMER_PATH.split('/')
178
  if len(path_parts) >= 3:
179
  repo_id = '/'.join(path_parts[:2])
 
181
  return QwenImageTransformer2DModel.from_pretrained(
182
  repo_id,
183
  subfolder=subfolder,
184
+ torch_dtype=dtype,
185
+ device_map="cuda"
186
+ )
187
  else:
188
  return QwenImageTransformer2DModel.from_pretrained(
189
  TRANSFORMER_PATH,
190
  subfolder='transformer',
191
+ torch_dtype=dtype,
192
+ device_map="cuda"
193
+ )
194
 
195
 
196
  # ============================================================