OrlandoHugBot committed on
Commit
1e77965
·
verified ·
1 Parent(s): 24ace1c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -11
app.py CHANGED
@@ -116,14 +116,8 @@ def generate_image(
116
  transformer=transformer
117
  )
118
 
119
- # 关键修复:手动设置 pipeline 使用的设备
120
- # 这确保 _execution_device 返回正确的设备
121
- pipe._execution_device = device
122
-
123
- # 同时确保 processor 也在正确设备上处理
124
- # 修改 pipe 的 device 属性(如果存在)
125
- if hasattr(pipe, 'device'):
126
- pipe.device = device
127
 
128
  print(f"✅ Model loaded!")
129
  print(f" text_encoder device: {next(text_encoder.parameters()).device}")
@@ -178,7 +172,7 @@ def load_transformer(device, dtype):
178
  return QwenImageTransformer2DModel.from_pretrained(
179
  TRANSFORMER_PATH,
180
  torch_dtype=dtype,
181
- use_safetensors=False
182
  ).to(device).eval()
183
  else:
184
  return QwenImageTransformer2DModel.from_pretrained(
@@ -192,18 +186,21 @@ def load_transformer(device, dtype):
192
  # HuggingFace repo path
193
  path_parts = TRANSFORMER_PATH.split('/')
194
  if len(path_parts) >= 3:
195
- repo_id = '/'.join(path_parts[:2])
196
- subfolder = path_parts[2]
 
197
  return QwenImageTransformer2DModel.from_pretrained(
198
  repo_id,
199
  subfolder=subfolder,
200
  torch_dtype=dtype,
 
201
  ).to(device).eval()
202
  else:
203
  return QwenImageTransformer2DModel.from_pretrained(
204
  TRANSFORMER_PATH,
205
  subfolder='transformer',
206
  torch_dtype=dtype,
 
207
  ).to(device).eval()
208
 
209
 
 
116
  transformer=transformer
117
  )
118
 
119
+ # 注意:不需要手动设置 _execution_device
120
+ # 修复后的 pipeline_qwenimage_edit.py 会直接从 text_encoder 获取设备
 
 
 
 
 
 
121
 
122
  print(f"✅ Model loaded!")
123
  print(f" text_encoder device: {next(text_encoder.parameters()).device}")
 
172
  return QwenImageTransformer2DModel.from_pretrained(
173
  TRANSFORMER_PATH,
174
  torch_dtype=dtype,
175
+ use_safetensors=False # 使用 .bin 文件
176
  ).to(device).eval()
177
  else:
178
  return QwenImageTransformer2DModel.from_pretrained(
 
186
  # HuggingFace repo path
187
  path_parts = TRANSFORMER_PATH.split('/')
188
  if len(path_parts) >= 3:
189
+ # 路径格式: "Skywork/Unipic3-DMD/ema_transformer"
190
+ repo_id = '/'.join(path_parts[:2]) # "Skywork/Unipic3-DMD"
191
+ subfolder = '/'.join(path_parts[2:]) # "ema_transformer"
192
  return QwenImageTransformer2DModel.from_pretrained(
193
  repo_id,
194
  subfolder=subfolder,
195
  torch_dtype=dtype,
196
+ use_safetensors=False # 使用 .bin 文件
197
  ).to(device).eval()
198
  else:
199
  return QwenImageTransformer2DModel.from_pretrained(
200
  TRANSFORMER_PATH,
201
  subfolder='transformer',
202
  torch_dtype=dtype,
203
+ use_safetensors=False
204
  ).to(device).eval()
205
 
206