Spaces: Sleeping
Update inferencer.py
Browse files - inferencer.py +5 -5
inferencer.py
CHANGED
|
@@ -49,20 +49,20 @@ class UniPicV2Inferencer:
|
|
| 49 |
else:
|
| 50 |
transformer = SD3Transformer2DKontextModel.from_pretrained(
|
| 51 |
self.model_path, subfolder="transformer",
|
| 52 |
-
torch_dtype=torch.
|
| 53 |
)
|
| 54 |
|
| 55 |
# ===== 3. Load VAE =====
|
| 56 |
vae = AutoencoderKL.from_pretrained(
|
| 57 |
self.model_path, subfolder="vae",
|
| 58 |
-
torch_dtype=torch.
|
| 59 |
).to(self.device)
|
| 60 |
|
| 61 |
# ===== 4. Load Qwen2.5-VL (LMM) =====
|
| 62 |
try:
|
| 63 |
self.lmm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 64 |
self.qwen_vl_path,
|
| 65 |
-
torch_dtype=torch.
|
| 66 |
attn_implementation="flash_attention_2",
|
| 67 |
device_map="auto",
|
| 68 |
).to(self.device)
|
|
@@ -70,7 +70,7 @@ class UniPicV2Inferencer:
|
|
| 70 |
except Exception:
|
| 71 |
self.lmm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 72 |
self.qwen_vl_path,
|
| 73 |
-
torch_dtype=torch.
|
| 74 |
attn_implementation="sdpa",
|
| 75 |
device_map="auto",
|
| 76 |
).to(self.device)
|
|
@@ -87,7 +87,7 @@ class UniPicV2Inferencer:
|
|
| 87 |
# ===== 6. Load Conditioner =====
|
| 88 |
self.conditioner = StableDiffusion3Conditioner.from_pretrained(
|
| 89 |
self.model_path, subfolder="conditioner",
|
| 90 |
-
torch_dtype=torch.
|
| 91 |
).to(self.device)
|
| 92 |
|
| 93 |
# ===== 7. Load Scheduler =====
|
|
|
|
| 49 |
else:
|
| 50 |
transformer = SD3Transformer2DKontextModel.from_pretrained(
|
| 51 |
self.model_path, subfolder="transformer",
|
| 52 |
+
torch_dtype=torch.bfloat16, device_map="auto", low_cpu_mem_usage=True
|
| 53 |
)
|
| 54 |
|
| 55 |
# ===== 3. Load VAE =====
|
| 56 |
vae = AutoencoderKL.from_pretrained(
|
| 57 |
self.model_path, subfolder="vae",
|
| 58 |
+
torch_dtype=torch.bfloat16, device_map="auto", low_cpu_mem_usage=True
|
| 59 |
).to(self.device)
|
| 60 |
|
| 61 |
# ===== 4. Load Qwen2.5-VL (LMM) =====
|
| 62 |
try:
|
| 63 |
self.lmm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 64 |
self.qwen_vl_path,
|
| 65 |
+
torch_dtype=torch.bfloat16,
|
| 66 |
attn_implementation="flash_attention_2",
|
| 67 |
device_map="auto",
|
| 68 |
).to(self.device)
|
|
|
|
| 70 |
except Exception:
|
| 71 |
self.lmm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 72 |
self.qwen_vl_path,
|
| 73 |
+
torch_dtype=torch.bfloat16,
|
| 74 |
attn_implementation="sdpa",
|
| 75 |
device_map="auto",
|
| 76 |
).to(self.device)
|
|
|
|
| 87 |
# ===== 6. Load Conditioner =====
|
| 88 |
self.conditioner = StableDiffusion3Conditioner.from_pretrained(
|
| 89 |
self.model_path, subfolder="conditioner",
|
| 90 |
+
torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
|
| 91 |
).to(self.device)
|
| 92 |
|
| 93 |
# ===== 7. Load Scheduler =====
|