pipeline update

pipeline.py  CHANGED  (+81 -176)

@@ -6938,177 +6938,8 @@ class InterleaveInferencer:
         elif isinstance(i, str):
             output_dict['text'] = i
         return output_dict
-
-
-# """
-# A “naive” Bagel wrapper that replicates your notebook exactly.
-# """
-
-# model_cpu_offload_seq = "bagel_model"
-
-# def __init__(
-#     self,
-#     torch_dtype: torch.dtype = torch.bfloat16,
-# ):
-#     super().__init__()
-#     self._dtype = torch_dtype
-#     self._built = False
-#     self._inferencer = None
-#     self.new_token_ids: List[int] = []
-#     # Hard-code default weights path; overridden by from_pretrained
-#     self.weights_root: Optional[str] = None
-#     self.register_to_config(weights_root=self.weights_root, torch_dtype=torch_dtype)
-#     repo_id = "ByteDance-Seed/BAGEL-7B-MoT"
-#     model_path = snapshot_download(repo_id=repo_id)
-#     print("loaded from ", model_path)
-#     # LLM config preparing
-#     llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
-#     llm_config.qk_norm = True
-#     llm_config.tie_word_embeddings = False
-#     llm_config.layer_module = "Qwen2MoTDecoderLayer"
-
-#     # ViT config preparing
-#     vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
-#     vit_config.rope = False
-#     vit_config.num_hidden_layers = vit_config.num_hidden_layers - 1
-
-#     # VAE loading
-#     vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))
-
-#     # Bagel config preparing
-#     config = BagelConfig(
-#         visual_gen=True,
-#         visual_und=True,
-#         llm_config=llm_config,
-#         vit_config=vit_config,
-#         vae_config=vae_config,
-#         vit_max_num_patch_per_side=70,
-#         connector_act='gelu_pytorch_tanh',
-#         latent_patch_size=2,
-#         max_latent_size=64,
-#     )
-
-#     with init_empty_weights():
-#         language_model = Qwen2ForCausalLM(llm_config)
-#         vit_model = SiglipVisionModel(vit_config)
-#         model = Bagel(language_model, vit_model, config)
-#         model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)
-
-#     # Tokenizer Preparing
-#     tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
-#     tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)
-
-#     # Image Transform Preparing
-#     vae_transform = ImageTransform(1024, 512, 16)
-#     vit_transform = ImageTransform(980, 224, 14)
-
-#     # set cuda device to 4
-
-#     max_mem_per_gpu = "40GiB"  # Modify it according to your GPU setting. On an A100, 80 GiB is sufficient to load on a single GPU.
-
-#     device_map = infer_auto_device_map(
-#         model,
-#         max_memory={i: max_mem_per_gpu for i in range(torch.cuda.device_count())},
-#         no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
-#     )
-#     print(device_map)
-
-#     same_device_modules = [
-#         'language_model.model.embed_tokens',
-#         'time_embedder',
-#         'latent_pos_embed',
-#         'vae2llm',
-#         'llm2vae',
-#         'connector',
-#         'vit_pos_embed'
-#     ]
-
-#     if torch.cuda.device_count() == 1:
-#         first_device = device_map.get(same_device_modules[0], "cuda:0")
-#         for k in same_device_modules:
-#             if k in device_map:
-#                 device_map[k] = first_device
-#             else:
-#                 device_map[k] = "cuda:0"
-#     else:
-#         first_device = device_map.get(same_device_modules[0])
-#         for k in same_device_modules:
-#             if k in device_map:
-#                 device_map[k] = first_device
-
-#     # Thanks @onion-liu: https://github.com/ByteDance-Seed/Bagel/pull/8
-#     model = load_checkpoint_and_dispatch(
-#         model,
-#         checkpoint=os.path.join(model_path, "ema.safetensors"),
-#         device_map=device_map,
-#         offload_buffers=True,
-#         dtype=torch.bfloat16,
-#         force_hooks=True,
-#         offload_folder="/tmp/offload"
-#     )
-
-#     model = model.eval()
-#     print('Model loaded')
-
-#     self._inferencer = InterleaveInferencer(
-#         model=model,
-#         vae_model=vae_model,
-#         tokenizer=tokenizer,
-#         vae_transform=vae_transform,
-#         vit_transform=vit_transform,
-#         new_token_ids=new_token_ids
-#     )
-
-#     seed = 42
-#     random.seed(seed)
-#     np.random.seed(seed)
-#     torch.manual_seed(seed)
-#     if torch.cuda.is_available():
-#         torch.cuda.manual_seed(seed)
-#         torch.cuda.manual_seed_all(seed)
-#         torch.backends.cudnn.deterministic = True
-#         torch.backends.cudnn.benchmark = False
-
-
-# @torch.no_grad()
-# def __call__(
-#     self,
-#     prompt: str,
-#     think=False,
-#     cfg_text_scale: float = 4.0,
-#     cfg_img_scale: float = 1.0,
-#     cfg_interval=(0.4, 1.0),
-#     timestep_shift: float = 3.0,
-#     num_timesteps: int = 50,
-#     cfg_renorm_min: float = 0.0,
-#     cfg_renorm_type: str = "global",
-#     seed: Optional[int] = None,
-#     output_type: str = "pil",
-#     return_dict: bool = True,
-#     **unused,
-# ):
-
-#     if seed is not None:
-#         torch.manual_seed(seed)
-#         if torch.cuda.is_available():
-#             torch.cuda.manual_seed_all(seed)
-
-#     inference_kwargs = dict(
-#         text=prompt,
-#         think=think,
-#         cfg_text_scale=cfg_text_scale,
-#         cfg_img_scale=cfg_img_scale,
-#         cfg_interval=list(cfg_interval),
-#         timestep_shift=timestep_shift,
-#         num_timesteps=num_timesteps,
-#         cfg_renorm_min=cfg_renorm_min,
-#         cfg_renorm_type=cfg_renorm_type,
-#     )
-#     result = self._inferencer(**inference_kwargs)
-#     image = result["image"] if isinstance(result, dict) else result
-#     if return_dict:
-#         return {"images": [image]}
-#     return [image]
+
+from diffusers import DiffusionPipeline, PipelineOutput
 
 class BagelPipeline(DiffusionPipeline):
     model_cpu_offload_seq = "bagel_model"
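Note on the added import: current diffusers releases do not export a top-level `PipelineOutput`; the library's return objects are built on `diffusers.utils.BaseOutput` (concrete classes such as `ImagePipelineOutput`). If the import above fails, a minimal stand-in with the two fields this pipeline actually returns could look like the sketch below; the class name `BagelPipelineOutput` and its field defaults are assumptions, not part of the commit.

# Hypothetical stand-in for the `PipelineOutput` imported above, built on
# diffusers' BaseOutput. BaseOutput subclasses are dataclasses with both
# dict-style and attribute-style access, which matches how __call__ builds
# its return value from `images` / `text` keyword arguments.
from dataclasses import dataclass
from typing import List, Optional

from diffusers.utils import BaseOutput
from PIL import Image


@dataclass
class BagelPipelineOutput(BaseOutput):
    images: Optional[List[Image.Image]] = None  # generated or edited images, if any
    text: Optional[str] = None                  # "think" trace or understanding answer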
@@ -7130,11 +6961,85 @@ class BagelPipeline(DiffusionPipeline):
             new_token_ids=new_token_ids,
         )
 
-    def __call__(
-
-
-
-
+    def __call__(
+        self,
+        *,
+        image: Optional[Image.Image] = None,
+        text: Optional[str] = None,
+        think: bool = False,
+        understanding_output: bool = False,
+        **infer_kwargs
+    ) -> PipelineOutput:
+        """
+        Supports:
+          - text→image (pass text=…)
+          - text→image + think (+ think=True)
+          - image→image edit (pass image=…, text=…)
+          - image→image + think (+ think=True)
+          - image→understanding (+ understanding_output=True)
+        Any other kwargs (cfg_text_scale, num_timesteps, etc.) override the defaults below.
+        """
+
+        if text is not None and image is None:
+            defaults: Dict[str, Any] = {
+                "cfg_text_scale": 4.0,
+                "cfg_img_scale": 1.0,
+                "cfg_interval": (0.4, 1.0),
+                "timestep_shift": 3.0,
+                "num_timesteps": 50,
+                "cfg_renorm_min": 0.0,
+                "cfg_renorm_type": "global",
+            }
+            if think:
+                defaults.update({
+                    "max_think_token_n": 1000,
+                    "do_sample": False,
+                    "text_temperature": 0.3,
+                })
+
+        elif image is not None and text is not None and not understanding_output:
+            defaults = {
+                "cfg_text_scale": 4.0,
+                "cfg_img_scale": 2.0,
+                "cfg_interval": (0.0, 1.0),
+                "timestep_shift": 3.0,
+                "num_timesteps": 50,
+                "cfg_renorm_min": 0.0,
+                "cfg_renorm_type": "text_channel",
+            }
+            if think:
+                defaults.update({
+                    "max_think_token_n": 1000,
+                    "do_sample": False,
+                    "text_temperature": 0.3,
+                })
+
+        elif image is not None and understanding_output:
+            defaults = {
+                "max_think_token_n": 1000,
+                "do_sample": False,
+            }
+
+        else:
+            defaults = {}
+
+        for k, v in defaults.items():
+            infer_kwargs.setdefault(k, v)
+
+        result: Dict[str, Any] = self._inferencer(
+            image=image,
+            text=text,
+            think=think,
+            understanding_output=understanding_output,
+            **infer_kwargs,
+        )
+
+        out_kwargs: Dict[str, Any] = {}
+        if result.get("image") is not None:
+            out_kwargs["images"] = [result["image"]]
+        if result.get("text") is not None:
+            out_kwargs["text"] = result["text"]
+        return PipelineOutput(**out_kwargs)
 
     def to(self, device):
         super().to(device)  # moves registered modules
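For reference, a minimal usage sketch of the reworked `__call__` (assumptions: an already-constructed pipeline instance named `pipe`, e.g. from `BagelPipeline.from_pretrained(...)`; an output object exposing `images`/`text` attributes as BaseOutput-style outputs do; file names and prompts are illustrative):

from PIL import Image

# text -> image: keyword-only call, picks the "global" cfg_renorm defaults
out = pipe(text="a photo of a bagel on a wooden table")
generated = out.images[0]

# image -> image edit: switches to cfg_img_scale=2.0 and "text_channel" renorm
edited = pipe(image=Image.open("input.png"), text="make it rainy").images[0]

# image -> understanding: returns text only, no image
answer = pipe(image=generated, text="What is shown here?", understanding_output=True).text

# any per-mode default can be overridden per call via **infer_kwargs
fast = pipe(text="a bagel", num_timesteps=30, cfg_text_scale=6.0)

Because every mode routes through `infer_kwargs.setdefault`, explicit keyword arguments always take precedence over the per-mode defaults.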