Commit
·
cb9cf0f
1
Parent(s):
c4ec650
Upload 2 files
Browse files- model_index.json +1 -1
- pipeline_emu2_gen.py +28 -12
model_index.json
CHANGED
|
@@ -6,7 +6,7 @@
|
|
| 6 |
"CLIPImageProcessor"
|
| 7 |
],
|
| 8 |
"multimodal_encoder": [
|
| 9 |
-
"transformers_modules.modeling_emu",
|
| 10 |
"EmuForCausalLM"
|
| 11 |
],
|
| 12 |
"safety_checker": [
|
|
|
|
| 6 |
"CLIPImageProcessor"
|
| 7 |
],
|
| 8 |
"multimodal_encoder": [
|
| 9 |
+
"transformers_modules.multimodal_encoder.modeling_emu",
|
| 10 |
"EmuForCausalLM"
|
| 11 |
],
|
| 12 |
"safety_checker": [
|
pipeline_emu2_gen.py
CHANGED
|
@@ -8,14 +8,14 @@
|
|
| 8 |
# Email : zhangfan@baai.ac.cn
|
| 9 |
# Institute : Beijing Academy of Artificial Intelligence (BAAI)
|
| 10 |
# Create On : 2023-12-19 10:45
|
| 11 |
-
# Last Modified : 2023-12-
|
| 12 |
-
# File Name :
|
| 13 |
# Description :
|
| 14 |
#
|
| 15 |
# ===========================================================================================
|
| 16 |
|
| 17 |
from dataclasses import dataclass
|
| 18 |
-
from typing import List, Optional
|
| 19 |
|
| 20 |
from PIL import Image
|
| 21 |
import numpy as np
|
|
@@ -38,8 +38,8 @@ DEFAULT_IMG_PLACEHOLDER = "[<IMG_PLH>]"
|
|
| 38 |
|
| 39 |
@dataclass
|
| 40 |
class EmuVisualGenerationPipelineOutput(BaseOutput):
|
| 41 |
-
|
| 42 |
-
nsfw_content_detected: Optional[
|
| 43 |
|
| 44 |
|
| 45 |
class EmuVisualGenerationPipeline(DiffusionPipeline):
|
|
@@ -76,7 +76,7 @@ class EmuVisualGenerationPipeline(DiffusionPipeline):
|
|
| 76 |
TF.Normalize(mean=eva_mean, std=eva_std),
|
| 77 |
])
|
| 78 |
|
| 79 |
-
self.negative_prompt =
|
| 80 |
|
| 81 |
def device(self, module):
|
| 82 |
return next(module.parameters()).device
|
|
@@ -166,7 +166,10 @@ class EmuVisualGenerationPipeline(DiffusionPipeline):
|
|
| 166 |
|
| 167 |
# 7. Convert to PIL
|
| 168 |
images = self.numpy_to_pil(images)
|
| 169 |
-
return EmuVisualGenerationPipelineOutput(
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
def _prepare_and_encode_inputs(
|
| 172 |
self,
|
|
@@ -177,11 +180,14 @@ class EmuVisualGenerationPipeline(DiffusionPipeline):
|
|
| 177 |
device = self.device(self.multimodal_encoder.model.visual)
|
| 178 |
dtype = self.dtype(self.multimodal_encoder.model.visual)
|
| 179 |
|
|
|
|
| 180 |
text_prompt, image_prompt = "", []
|
| 181 |
for x in inputs:
|
| 182 |
if isinstance(x, str):
|
|
|
|
| 183 |
text_prompt += x
|
| 184 |
else:
|
|
|
|
| 185 |
text_prompt += placeholder
|
| 186 |
image_prompt.append(self.transform(x))
|
| 187 |
|
|
@@ -191,11 +197,21 @@ class EmuVisualGenerationPipeline(DiffusionPipeline):
|
|
| 191 |
image_prompt = torch.stack(image_prompt)
|
| 192 |
image_prompt = image_prompt.type(dtype).to(device)
|
| 193 |
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
if
|
| 197 |
-
|
| 198 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
|
| 200 |
return prompt
|
| 201 |
|
|
|
|
| 8 |
# Email : zhangfan@baai.ac.cn
|
| 9 |
# Institute : Beijing Academy of Artificial Intelligence (BAAI)
|
| 10 |
# Create On : 2023-12-19 10:45
|
| 11 |
+
# Last Modified : 2023-12-25 07:59
|
| 12 |
+
# File Name : pipeline_emu2_gen.py
|
| 13 |
# Description :
|
| 14 |
#
|
| 15 |
# ===========================================================================================
|
| 16 |
|
| 17 |
from dataclasses import dataclass
|
| 18 |
+
from typing import List, Optional
|
| 19 |
|
| 20 |
from PIL import Image
|
| 21 |
import numpy as np
|
|
|
|
| 38 |
|
| 39 |
@dataclass
|
| 40 |
class EmuVisualGenerationPipelineOutput(BaseOutput):
|
| 41 |
+
image: Image.Image
|
| 42 |
+
nsfw_content_detected: Optional[bool]
|
| 43 |
|
| 44 |
|
| 45 |
class EmuVisualGenerationPipeline(DiffusionPipeline):
|
|
|
|
| 76 |
TF.Normalize(mean=eva_mean, std=eva_std),
|
| 77 |
])
|
| 78 |
|
| 79 |
+
self.negative_prompt = {}
|
| 80 |
|
| 81 |
def device(self, module):
|
| 82 |
return next(module.parameters()).device
|
|
|
|
| 166 |
|
| 167 |
# 7. Convert to PIL
|
| 168 |
images = self.numpy_to_pil(images)
|
| 169 |
+
return EmuVisualGenerationPipelineOutput(
|
| 170 |
+
image=images[0],
|
| 171 |
+
nsfw_content_detected=None if has_nsfw_concept is None else has_nsfw_concept[0],
|
| 172 |
+
)
|
| 173 |
|
| 174 |
def _prepare_and_encode_inputs(
|
| 175 |
self,
|
|
|
|
| 180 |
device = self.device(self.multimodal_encoder.model.visual)
|
| 181 |
dtype = self.dtype(self.multimodal_encoder.model.visual)
|
| 182 |
|
| 183 |
+
has_image, has_text = False, False
|
| 184 |
text_prompt, image_prompt = "", []
|
| 185 |
for x in inputs:
|
| 186 |
if isinstance(x, str):
|
| 187 |
+
has_text = True
|
| 188 |
text_prompt += x
|
| 189 |
else:
|
| 190 |
+
has_image = True
|
| 191 |
text_prompt += placeholder
|
| 192 |
image_prompt.append(self.transform(x))
|
| 193 |
|
|
|
|
| 197 |
image_prompt = torch.stack(image_prompt)
|
| 198 |
image_prompt = image_prompt.type(dtype).to(device)
|
| 199 |
|
| 200 |
+
if has_image and not has_text:
|
| 201 |
+
prompt = self.multimodal_encoder.model.encode_image(image=image_prompt)
|
| 202 |
+
if do_classifier_free_guidance:
|
| 203 |
+
key = "[NULL_IMAGE]"
|
| 204 |
+
if key not in self.negative_prompt:
|
| 205 |
+
negative_image = torch.zeros_like(image_prompt)
|
| 206 |
+
self.negative_prompt[key] = self.multimodal_encoder.model.encode_image(image=negative_image)
|
| 207 |
+
prompt = torch.cat([prompt, self.negative_prompt[key]], dim=0)
|
| 208 |
+
else:
|
| 209 |
+
prompt = self.multimodal_encoder.generate_image(text=[text_prompt], image=image_prompt, tokenizer=self.tokenizer)
|
| 210 |
+
if do_classifier_free_guidance:
|
| 211 |
+
key = ""
|
| 212 |
+
if key not in self.negative_prompt:
|
| 213 |
+
self.negative_prompt[key] = self.multimodal_encoder.generate_image(text=[""], tokenizer=self.tokenizer)
|
| 214 |
+
prompt = torch.cat([prompt, self.negative_prompt[key]], dim=0)
|
| 215 |
|
| 216 |
return prompt
|
| 217 |
|