para-lost committed on
Commit 782d4ce · Parent: 4897a76

pipeline update

Files changed (1): pipeline.py (+81 −176)
pipeline.py CHANGED
@@ -6938,177 +6938,8 @@ class InterleaveInferencer:
         elif isinstance(i, str):
             output_dict['text'] = i
         return output_dict
-
- # class BagelPipeline(DiffusionPipeline):
- #     """
- #     A “naive” Bagel wrapper that replicates your notebook exactly.
- #     """
-
- #     model_cpu_offload_seq = "bagel_model"
-
- #     def __init__(
- #         self,
- #         torch_dtype: torch.dtype = torch.bfloat16,
- #     ):
- #         super().__init__()
- #         self._dtype = torch_dtype
- #         self._built = False
- #         self._inferencer = None
- #         self.new_token_ids: List[int] = []
- #         # Hard-code default weights path; overridden by from_pretrained
- #         self.weights_root: Optional[str] = None
- #         self.register_to_config(weights_root=self.weights_root, torch_dtype=torch_dtype)
- #         repo_id = "ByteDance-Seed/BAGEL-7B-MoT"
- #         model_path = snapshot_download(repo_id=repo_id)
- #         print("loaded from ", model_path)
- #         # LLM config preparing
- #         llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
- #         llm_config.qk_norm = True
- #         llm_config.tie_word_embeddings = False
- #         llm_config.layer_module = "Qwen2MoTDecoderLayer"
-
- #         # ViT config preparing
- #         vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
- #         vit_config.rope = False
- #         vit_config.num_hidden_layers = vit_config.num_hidden_layers - 1
-
- #         # VAE loading
- #         vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))
-
- #         # Bagel config preparing
- #         config = BagelConfig(
- #             visual_gen=True,
- #             visual_und=True,
- #             llm_config=llm_config,
- #             vit_config=vit_config,
- #             vae_config=vae_config,
- #             vit_max_num_patch_per_side=70,
- #             connector_act='gelu_pytorch_tanh',
- #             latent_patch_size=2,
- #             max_latent_size=64,
- #         )
-
- #         with init_empty_weights():
- #             language_model = Qwen2ForCausalLM(llm_config)
- #             vit_model = SiglipVisionModel(vit_config)
- #             model = Bagel(language_model, vit_model, config)
- #             model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)
-
- #         # Tokenizer Preparing
- #         tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
- #         tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)
-
- #         # Image Transform Preparing
- #         vae_transform = ImageTransform(1024, 512, 16)
- #         vit_transform = ImageTransform(980, 224, 14)
-
- #         # set cuda device to 4
-
- #         max_mem_per_gpu = "40GiB"  # Modify it according to your GPU setting. On an A100, 80 GiB is sufficient to load on a single GPU.
-
- #         device_map = infer_auto_device_map(
- #             model,
- #             max_memory={i: max_mem_per_gpu for i in range(torch.cuda.device_count())},
- #             no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
- #         )
- #         print(device_map)
-
- #         same_device_modules = [
- #             'language_model.model.embed_tokens',
- #             'time_embedder',
- #             'latent_pos_embed',
- #             'vae2llm',
- #             'llm2vae',
- #             'connector',
- #             'vit_pos_embed'
- #         ]
-
- #         if torch.cuda.device_count() == 1:
- #             first_device = device_map.get(same_device_modules[0], "cuda:0")
- #             for k in same_device_modules:
- #                 if k in device_map:
- #                     device_map[k] = first_device
- #                 else:
- #                     device_map[k] = "cuda:0"
- #         else:
- #             first_device = device_map.get(same_device_modules[0])
- #             for k in same_device_modules:
- #                 if k in device_map:
- #                     device_map[k] = first_device
-
- #         # Thanks @onion-liu: https://github.com/ByteDance-Seed/Bagel/pull/8
- #         model = load_checkpoint_and_dispatch(
- #             model,
- #             checkpoint=os.path.join(model_path, "ema.safetensors"),
- #             device_map=device_map,
- #             offload_buffers=True,
- #             dtype=torch.bfloat16,
- #             force_hooks=True,
- #             offload_folder="/tmp/offload"
- #         )
-
- #         model = model.eval()
- #         print('Model loaded')
-
- #         self._inferencer = InterleaveInferencer(
- #             model=model,
- #             vae_model=vae_model,
- #             tokenizer=tokenizer,
- #             vae_transform=vae_transform,
- #             vit_transform=vit_transform,
- #             new_token_ids=new_token_ids
- #         )
-
- #         seed = 42
- #         random.seed(seed)
- #         np.random.seed(seed)
- #         torch.manual_seed(seed)
- #         if torch.cuda.is_available():
- #             torch.cuda.manual_seed(seed)
- #             torch.cuda.manual_seed_all(seed)
- #             torch.backends.cudnn.deterministic = True
- #             torch.backends.cudnn.benchmark = False
-
-
- #     @torch.no_grad()
- #     def __call__(
- #         self,
- #         prompt: str,
- #         think=False,
- #         cfg_text_scale: float = 4.0,
- #         cfg_img_scale: float = 1.0,
- #         cfg_interval=(0.4, 1.0),
- #         timestep_shift: float = 3.0,
- #         num_timesteps: int = 50,
- #         cfg_renorm_min: float = 0.0,
- #         cfg_renorm_type: str = "global",
- #         seed: Optional[int] = None,
- #         output_type: str = "pil",
- #         return_dict: bool = True,
- #         **unused,
- #     ):
-
- #         if seed is not None:
- #             torch.manual_seed(seed)
- #             if torch.cuda.is_available():
- #                 torch.cuda.manual_seed_all(seed)
-
- #         inference_kwargs = dict(
- #             text=prompt,
- #             think=think,
- #             cfg_text_scale=cfg_text_scale,
- #             cfg_img_scale=cfg_img_scale,
- #             cfg_interval=list(cfg_interval),
- #             timestep_shift=timestep_shift,
- #             num_timesteps=num_timesteps,
- #             cfg_renorm_min=cfg_renorm_min,
- #             cfg_renorm_type=cfg_renorm_type,
- #         )
- #         result = self._inferencer(**inference_kwargs)
- #         image = result["image"] if isinstance(result, dict) else result
- #         if return_dict:
- #             return {"images": [image]}
- #         return [image]
+
+ from diffusers import DiffusionPipeline, PipelineOutput
 
 class BagelPipeline(DiffusionPipeline):
     model_cpu_offload_seq = "bagel_model"
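Note on the import added in this hunk: current diffusers releases do not appear to export a generic PipelineOutput from the package root (the public names are concrete classes such as ImagePipelineOutput), so this line may raise an ImportError depending on the installed version. A minimal stand-in with the keyword interface the new __call__ expects, sketched here on top of diffusers.utils.BaseOutput (which diffusers does provide), could look like:

    # Hypothetical fallback, not part of this commit: a dict-like output class
    # that accepts the images=/text= keywords the new __call__ passes.
    from dataclasses import dataclass
    from typing import List, Optional

    from diffusers.utils import BaseOutput
    from PIL import Image

    @dataclass
    class PipelineOutput(BaseOutput):
        images: Optional[List[Image.Image]] = None
        text: Optional[str] = None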
@@ -7130,11 +6961,85 @@ class BagelPipeline(DiffusionPipeline):
             new_token_ids= new_token_ids,
         )
 
-    def __call__(self, prompt: str, **infer_kwargs):
-        result = self._inferencer(text=prompt, **infer_kwargs)
-        img = result["image"] if isinstance(result, dict) else result
-        return {"images": [img]}
-
+    def __call__(
+        self,
+        *,
+        image: Optional[Image.Image] = None,
+        text: Optional[str] = None,
+        think: bool = False,
+        understanding_output: bool = False,
+        **infer_kwargs
+    ) -> PipelineOutput:
+        """
+        Supports:
+          - text→image                 (pass text=…)
+          - text→image + think         (add think=True)
+          - image→image edit           (pass image=…, text=…)
+          - image→image edit + think   (add think=True)
+          - image→understanding        (add understanding_output=True)
+        Any other kwargs (cfg_text_scale, num_timesteps, etc.) override the defaults below.
+        """
+
+        if text is not None and image is None:
+            defaults: Dict[str, Any] = {
+                "cfg_text_scale": 4.0,
+                "cfg_img_scale": 1.0,
+                "cfg_interval": (0.4, 1.0),
+                "timestep_shift": 3.0,
+                "num_timesteps": 50,
+                "cfg_renorm_min": 0.0,
+                "cfg_renorm_type": "global",
+            }
+            if think:
+                defaults.update({
+                    "max_think_token_n": 1000,
+                    "do_sample": False,
+                    "text_temperature": 0.3,
+                })
+
+        elif image is not None and text is not None and not understanding_output:
+            defaults = {
+                "cfg_text_scale": 4.0,
+                "cfg_img_scale": 2.0,
+                "cfg_interval": (0.0, 1.0),
+                "timestep_shift": 3.0,
+                "num_timesteps": 50,
+                "cfg_renorm_min": 0.0,
+                "cfg_renorm_type": "text_channel",
+            }
+            if think:
+                defaults.update({
+                    "max_think_token_n": 1000,
+                    "do_sample": False,
+                    "text_temperature": 0.3,
+                })
+
+        elif image is not None and understanding_output:
+            defaults = {
+                "max_think_token_n": 1000,
+                "do_sample": False,
+            }
+
+        else:
+            defaults = {}
+
+        for k, v in defaults.items():
+            infer_kwargs.setdefault(k, v)
+
+        result: Dict[str, Any] = self._inferencer(
+            image=image,
+            text=text,
+            think=think,
+            understanding_output=understanding_output,
+            **infer_kwargs,
+        )
+
+        out_kwargs: Dict[str, Any] = {}
+        if result.get("image") is not None:
+            out_kwargs["images"] = [result["image"]]
+        if result.get("text") is not None:
+            out_kwargs["text"] = result["text"]
+        return PipelineOutput(**out_kwargs)
 
     def to(self, device):
         super().to(device)  # moves registered modules
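For reference, the modes listed in the new docstring map onto calls like the following. This is an illustrative sketch, not part of the commit: pipe stands for an already constructed BagelPipeline, the prompts are invented, and it assumes PipelineOutput resolves (see the note after the first hunk).

    # Illustrative usage of the new multi-mode __call__ (names are made up).
    from PIL import Image

    # text→image: per-mode defaults (cfg_text_scale=4.0, cfg_renorm_type="global",
    # num_timesteps=50, ...) are applied via infer_kwargs.setdefault, so callers
    # pass only the values they want to override.
    out = pipe(text="a sesame bagel on a wooden table")
    out.images[0].save("bagel.png")

    # image→image edit + think: think=True layers max_think_token_n=1000,
    # do_sample=False and text_temperature=0.3 on top of the edit defaults;
    # num_timesteps=30 here overrides the default of 50.
    src = Image.open("bagel.png")
    out = pipe(image=src, text="make it a poppy-seed bagel", think=True, num_timesteps=30)

    # image→understanding: only text comes back, so out carries no images entry.
    out = pipe(image=src, text="How many bagels are visible?", understanding_output=True)
    print(out.text)

Because the new signature is keyword-only (the bare *), positional calls like pipe("prompt") now raise a TypeError; this is the main behavioral difference from the removed single-mode __call__, which took prompt positionally.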