OrlandoHugBot committed on
Commit
fc3e76d
·
verified ·
1 Parent(s): 1c02ce0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -28
app.py CHANGED
@@ -69,38 +69,42 @@ def generate_image(
69
  )
70
  from transformers import AutoModel, AutoTokenizer, Qwen2VLProcessor
71
 
72
- device = get_device()
73
- dtype = get_dtype()
 
74
 
75
  print(f"🚀 Loading model on {device}...")
 
 
76
 
77
- # Load scheduler
78
  scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
79
  MODEL_NAME, subfolder='scheduler'
80
  )
81
 
82
- # Load text encoder - use device_map="cuda" for ZeroGPU compatibility
83
- # This ensures all submodules are properly placed on the GPU
 
 
 
 
84
  text_encoder = AutoModel.from_pretrained(
85
  MODEL_NAME,
86
  subfolder='text_encoder',
87
  torch_dtype=dtype,
88
- device_map="cuda" # Let transformers handle device placement for ZeroGPU
89
- )
90
-
91
- # Load tokenizer & processor
92
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, subfolder='tokenizer')
93
- processor = Qwen2VLProcessor.from_pretrained(MODEL_NAME, subfolder='processor')
94
 
95
- # Load transformer - also use device_map for consistency
96
- transformer = load_transformer(dtype)
 
97
 
98
  # Load VAE
 
99
  vae = AutoencoderKLQwenImage.from_pretrained(
100
  MODEL_NAME,
101
  subfolder='vae',
102
  torch_dtype=dtype,
103
- ).to(device)
104
 
105
  # Create pipeline
106
  pipe = QwenImageEditPipeline(
@@ -112,7 +116,20 @@ def generate_image(
112
  transformer=transformer
113
  )
114
 
115
- print(f"✅ Model loaded! Generating with {len(images)} image(s)...")
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
  # Generate
118
  with torch.no_grad():
@@ -143,19 +160,18 @@ def generate_image(
143
 
144
  # Cleanup to free VRAM
145
  del pipe, transformer, vae, text_encoder
146
- torch.cuda.empty_cache()
 
147
 
148
  return result
149
 
150
 
151
- def load_transformer(dtype):
152
  """Load transformer with proper path handling for ZeroGPU"""
153
  from diffusers import QwenImageTransformer2DModel
154
 
155
- device = get_device()
156
-
157
  if os.path.exists(TRANSFORMER_PATH):
158
- # Local path - for ZeroGPU, still use .to(device) for local files
159
  if os.path.isdir(TRANSFORMER_PATH):
160
  config_path = os.path.join(TRANSFORMER_PATH, "config.json")
161
  if os.path.exists(config_path):
@@ -163,17 +179,17 @@ def load_transformer(dtype):
163
  TRANSFORMER_PATH,
164
  torch_dtype=dtype,
165
  use_safetensors=False
166
- ).to(device)
167
  else:
168
  return QwenImageTransformer2DModel.from_pretrained(
169
  TRANSFORMER_PATH,
170
  subfolder='transformer',
171
  torch_dtype=dtype,
172
  use_safetensors=False
173
- ).to(device)
174
  raise ValueError(f"Invalid transformer path: {TRANSFORMER_PATH}")
175
  else:
176
- # HuggingFace repo path - use device_map for ZeroGPU
177
  path_parts = TRANSFORMER_PATH.split('/')
178
  if len(path_parts) >= 3:
179
  repo_id = '/'.join(path_parts[:2])
@@ -182,15 +198,13 @@ def load_transformer(dtype):
182
  repo_id,
183
  subfolder=subfolder,
184
  torch_dtype=dtype,
185
- device_map="cuda"
186
- )
187
  else:
188
  return QwenImageTransformer2DModel.from_pretrained(
189
  TRANSFORMER_PATH,
190
  subfolder='transformer',
191
  torch_dtype=dtype,
192
- device_map="cuda"
193
- )
194
 
195
 
196
  # ============================================================
@@ -632,4 +646,4 @@ def create_demo():
632
  demo = create_demo()
633
 
634
  if __name__ == "__main__":
635
- demo.launch()
 
69
  )
70
  from transformers import AutoModel, AutoTokenizer, Qwen2VLProcessor
71
 
72
+ # ZeroGPU: 必须在 @GPU 函数内部获取设备
73
+ device = torch.device("cuda:0") # 明确指定 cuda:0
74
+ dtype = torch.bfloat16
75
 
76
  print(f"🚀 Loading model on {device}...")
77
+ print(f" CUDA available: {torch.cuda.is_available()}")
78
+ print(f" CUDA device count: {torch.cuda.device_count()}")
79
 
80
+ # Load scheduler (CPU, no device needed)
81
  scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
82
  MODEL_NAME, subfolder='scheduler'
83
  )
84
 
85
+ # Load tokenizer & processor (CPU, no device needed)
86
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, subfolder='tokenizer')
87
+ processor = Qwen2VLProcessor.from_pretrained(MODEL_NAME, subfolder='processor')
88
+
89
+ # Load text encoder - 直接加载到 CUDA
90
+ print(" Loading text_encoder...")
91
  text_encoder = AutoModel.from_pretrained(
92
  MODEL_NAME,
93
  subfolder='text_encoder',
94
  torch_dtype=dtype,
95
+ ).to(device).eval()
 
 
 
 
 
96
 
97
+ # Load transformer
98
+ print(" Loading transformer...")
99
+ transformer = load_transformer(device, dtype)
100
 
101
  # Load VAE
102
+ print(" Loading VAE...")
103
  vae = AutoencoderKLQwenImage.from_pretrained(
104
  MODEL_NAME,
105
  subfolder='vae',
106
  torch_dtype=dtype,
107
+ ).to(device).eval()
108
 
109
  # Create pipeline
110
  pipe = QwenImageEditPipeline(
 
116
  transformer=transformer
117
  )
118
 
119
+ # 关键修复:手动设置 pipeline 使用的设备
120
+ # 这确保 _execution_device 返回正确的设备
121
+ pipe._execution_device = device
122
+
123
+ # 同时确保 processor 也在正确设备上处理
124
+ # 修改 pipe 的 device 属性(如果存在)
125
+ if hasattr(pipe, 'device'):
126
+ pipe.device = device
127
+
128
+ print(f"✅ Model loaded!")
129
+ print(f" text_encoder device: {next(text_encoder.parameters()).device}")
130
+ print(f" transformer device: {next(transformer.parameters()).device}")
131
+ print(f" vae device: {next(vae.parameters()).device}")
132
+ print(f" Generating with {len(images)} image(s)...")
133
 
134
  # Generate
135
  with torch.no_grad():
 
160
 
161
  # Cleanup to free VRAM
162
  del pipe, transformer, vae, text_encoder
163
+ if torch.cuda.is_available():
164
+ torch.cuda.empty_cache()
165
 
166
  return result
167
 
168
 
169
+ def load_transformer(device, dtype):
170
  """Load transformer with proper path handling for ZeroGPU"""
171
  from diffusers import QwenImageTransformer2DModel
172
 
 
 
173
  if os.path.exists(TRANSFORMER_PATH):
174
+ # Local path
175
  if os.path.isdir(TRANSFORMER_PATH):
176
  config_path = os.path.join(TRANSFORMER_PATH, "config.json")
177
  if os.path.exists(config_path):
 
179
  TRANSFORMER_PATH,
180
  torch_dtype=dtype,
181
  use_safetensors=False
182
+ ).to(device).eval()
183
  else:
184
  return QwenImageTransformer2DModel.from_pretrained(
185
  TRANSFORMER_PATH,
186
  subfolder='transformer',
187
  torch_dtype=dtype,
188
  use_safetensors=False
189
+ ).to(device).eval()
190
  raise ValueError(f"Invalid transformer path: {TRANSFORMER_PATH}")
191
  else:
192
+ # HuggingFace repo path
193
  path_parts = TRANSFORMER_PATH.split('/')
194
  if len(path_parts) >= 3:
195
  repo_id = '/'.join(path_parts[:2])
 
198
  repo_id,
199
  subfolder=subfolder,
200
  torch_dtype=dtype,
201
+ ).to(device).eval()
 
202
  else:
203
  return QwenImageTransformer2DModel.from_pretrained(
204
  TRANSFORMER_PATH,
205
  subfolder='transformer',
206
  torch_dtype=dtype,
207
+ ).to(device).eval()
 
208
 
209
 
210
  # ============================================================
 
646
  demo = create_demo()
647
 
648
  if __name__ == "__main__":
649
+ demo.launch()