sharfikeg commited on
Commit
3d7e35b
·
verified ·
1 Parent(s): cf9ebfa

Bagel instructions from bagel's repo.

Browse files
Files changed (1) hide show
  1. README.md +183 -16
README.md CHANGED
@@ -13,25 +13,192 @@ library_name: diffusers
13
 
14
  # BAGEL-7B-MoT Alchemist 👨‍🔬
15
 
16
- [BAGEL-7B-MoT Alchemist](https://huggingface.co/yandex/bagel-alchemist) is T2I-finetuned version of [BAGEL-7B-MoT](https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT) on [Alchemist](https://huggingface.co/datasets/yandex/alchemist) dataset, proposed in the research paper "Alchemist: Turning Public Text-to-Image Data into Generative Gold". Model generates images with improved aesthetics and complexity. Find more details about dataset and training details in the paper
17
 
18
- ## Using with Diffusers
19
- Upgrade to the latest version of the [🧨 diffusers library](https://github.com/huggingface/diffusers)
 
 
 
 
 
 
 
 
 
 
20
  ```
21
- pip install -U diffusers
 
 
22
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- and then you can run
25
- ```py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  import torch
27
- from diffusers import StableDiffusion3Pipeline
28
-
29
- pipe = StableDiffusion3Pipeline.from_pretrained("yandex/stable-diffusion-3.5-large-alchemist", torch_dtype=torch.bfloat16)
30
- pipe = pipe.to("cuda")
31
- image = pipe(
32
- "a man standing under a tree",
33
- num_inference_steps=28,
34
- guidance_scale=3.5,
35
- ).images[0]
36
- image.save("man.png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  ```
 
13
 
14
  # BAGEL-7B-MoT Alchemist 👨‍🔬
15
 
16
+ [BAGEL-7B-MoT Alchemist](https://huggingface.co/yandex/bagel-alchemist) is a T2I-finetuned version of [BAGEL-7B-MoT](https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT) on the [Alchemist](https://huggingface.co/datasets/yandex/alchemist) dataset, proposed in the research paper "Alchemist: Turning Public Text-to-Image Data into Generative Gold". The model generates images with improved aesthetics and complexity. Find more details about the dataset and the training procedure in the paper.
17
 
18
+ ## Model usage
19
+ For installation and usage instructions, follow **BAGEL**'s official [GitHub repository](https://github.com/bytedance-seed/BAGEL):
20
+
21
+ 1️⃣ Set up environment
22
+
23
+ ```
24
+ git clone https://github.com/bytedance-seed/BAGEL.git
25
+ cd BAGEL
26
+ conda create -n bagel python=3.10 -y
27
+ conda activate bagel
28
+ pip install -r requirements.txt
29
+ pip install flash_attn==2.5.8 --no-build-isolation
30
  ```
31
+
32
+ 2️⃣ Download pretrained checkpoint
33
+
34
  ```
35
+ from huggingface_hub import snapshot_download
36
+
37
+ save_dir = "models/BAGEL-7B-MoT-alchemist"
38
+ repo_id = "yandex/BAGEL-7B-MoT-alchemist"
39
+ cache_dir = save_dir + "/cache"
40
+
41
+ snapshot_download(cache_dir=cache_dir,
42
+ local_dir=save_dir,
43
+ repo_id=repo_id,
44
+ local_dir_use_symlinks=False,
45
+ resume_download=True,
46
+ allow_patterns=["*.json", "*.safetensors", "*.bin", "*.py", "*.md", "*.txt"],
47
+ )
48
+ ```
49
+
50
+ 3️⃣ Load BAGEL-Alchemist. Note that it was trained on images with a maximum side of 1408 px!
51
 
52
+ ```
53
+ import os
54
+ from copy import deepcopy
55
+ from typing import (
56
+ Any,
57
+ AsyncIterable,
58
+ Callable,
59
+ Dict,
60
+ Generator,
61
+ List,
62
+ NamedTuple,
63
+ Optional,
64
+ Tuple,
65
+ Union,
66
+ )
67
+ import requests
68
+ from io import BytesIO
69
+
70
+ from PIL import Image
71
  import torch
72
+ from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
73
+
74
+ from data.transforms import ImageTransform
75
+ from data.data_utils import pil_img2rgb, add_special_tokens
76
+ from modeling.bagel import (
77
+ BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM, SiglipVisionConfig, SiglipVisionModel
78
+ )
79
+ from modeling.qwen2 import Qwen2Tokenizer
80
+ from modeling.bagel.qwen2_navit import NaiveCache
81
+ from modeling.autoencoder import load_ae
82
+ from safetensors.torch import load_file
83
+
84
+ model_path = "/path/to/BAGEL-7B-MoT-alchemist/weights"
85
+
86
+ # LLM config preparing
87
+ llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
88
+ llm_config.qk_norm = True
89
+ llm_config.tie_word_embeddings = False
90
+ llm_config.layer_module = "Qwen2MoTDecoderLayer"
91
+
92
+ # ViT config preparing
93
+ vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
94
+ vit_config.rope = False
95
+ vit_config.num_hidden_layers = vit_config.num_hidden_layers - 1
96
+
97
+ # VAE loading
98
+ vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))
99
+
100
+ # Bagel config preparing
101
+ config = BagelConfig(
102
+ visual_gen=True,
103
+ visual_und=True,
104
+ llm_config=llm_config,
105
+ vit_config=vit_config,
106
+ vae_config=vae_config,
107
+ vit_max_num_patch_per_side=70,
108
+ connector_act='gelu_pytorch_tanh',
109
+ latent_patch_size=2,
110
+ max_latent_size=88, # max_latent_size is 88 for BAGEL-alchemist!
111
+ )
112
+
113
+ with init_empty_weights():
114
+ language_model = Qwen2ForCausalLM(llm_config)
115
+ vit_model = SiglipVisionModel(vit_config)
116
+ model = Bagel(language_model, vit_model, config)
117
+ model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)
118
+
119
+ # Tokenizer Preparing
120
+ tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
121
+ tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)
122
+
123
+ # Image Transform Preparing
124
+ vae_transform = ImageTransform(1408, 512, 16) # maximum image side is 1408 for BAGEL-alchemist!
125
+ vit_transform = ImageTransform(980, 224, 14)
126
+
127
+ max_mem_per_gpu = "40GiB" # Modify it according to your GPU setting. On an A100, 80 GiB is sufficient to load on a single GPU.
128
+
129
+ device_map = infer_auto_device_map(
130
+ model,
131
+ max_memory={i: max_mem_per_gpu for i in range(torch.cuda.device_count())},
132
+ no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
133
+ )
134
+ print(device_map)
135
+
136
+ same_device_modules = [
137
+ 'language_model.model.embed_tokens',
138
+ 'time_embedder',
139
+ 'latent_pos_embed',
140
+ 'vae2llm',
141
+ 'llm2vae',
142
+ 'connector',
143
+ 'vit_pos_embed'
144
+ ]
145
+
146
+ if torch.cuda.device_count() == 1:
147
+ first_device = device_map.get(same_device_modules[0], "cuda:0")
148
+ for k in same_device_modules:
149
+ if k in device_map:
150
+ device_map[k] = first_device
151
+ else:
152
+ device_map[k] = "cuda:0"
153
+ else:
154
+ first_device = device_map.get(same_device_modules[0])
155
+ for k in same_device_modules:
156
+ if k in device_map:
157
+ device_map[k] = first_device
158
+
159
+ # Thanks @onion-liu: https://github.com/ByteDance-Seed/Bagel/pull/8
160
+ model = load_checkpoint_and_dispatch(
161
+ model,
162
+ checkpoint=os.path.join(model_path, "ema.safetensors"),
163
+ device_map=device_map,
164
+ offload_buffers=True,
165
+ dtype=torch.bfloat16,
166
+ force_hooks=True,
167
+ offload_folder="/tmp/offload"
168
+ )
169
+
170
+ model = model.eval()
171
+ print('Model loaded')
172
+ ```
173
+
174
+ 4️⃣ Follow the final instructions for inference, e.g., T2I inference:
175
+
176
+ ```
177
+ from inferencer import InterleaveInferencer
178
+
179
+ inferencer = InterleaveInferencer(
180
+ model=model,
181
+ vae_model=vae_model,
182
+ tokenizer=tokenizer,
183
+ vae_transform=vae_transform,
184
+ vit_transform=vit_transform,
185
+ new_token_ids=new_token_ids
186
+ )
187
+
188
+ inference_hyper=dict(
189
+ cfg_text_scale=6.0,
190
+ cfg_img_scale=1.0,
191
+ cfg_interval=[0.0, 1.0],
192
+ timestep_shift=3.0,
193
+ num_timesteps=50,
194
+ cfg_renorm_min=0.0,
195
+ cfg_renorm_type="global",
196
+ )
197
+
198
+ prompt = "A female cosplayer portraying an ethereal fairy or elf, wearing a flowing dress made of delicate fabrics in soft, mystical colors like emerald green and silver. She has pointed ears, a gentle, enchanting expression, and her outfit is adorned with sparkling jewels and intricate patterns. The background is a magical forest with glowing plants, mystical creatures, and a serene atmosphere."
199
+
200
+ print(prompt)
201
+ print('-' * 10)
202
+ output_dict = inferencer(text=prompt, **inference_hyper)
203
+ display(output_dict['image'])
204
  ```