---
base_model:
- deepseek-ai/Janus-Pro-7B
datasets:
- FreedomIntelligence/ShareGPT-4o-Image
language:
- en
library_name: transformers
license: mit
license_name: deepseek
license_link: LICENSE
pipeline_tag: any-to-any
tags:
- text-to-image
- text-and-image-to-image
- multimodal
- unified-model
---

<div align="center">
<h1>
  Janus-4o-7B
</h1>
</div>

<div align="center">
<a href="https://github.com/FreedomIntelligence/ShareGPT-4o-Image" target="_blank">🧰GitHub</a>  |  <a href="https://arxiv.org/abs/2506.18095" target="_blank">📃Paper</a>  |  <a href="https://huggingface.co/datasets/FreedomIntelligence/ShareGPT-4o-Image" target="_blank">📚ShareGPT-4o-Image</a>
</div>

## 1. Introduction

Janus-4o is a multimodal large language model (MLLM) capable of both **text-to-image** and **text-and-image-to-image** generation. It is fine-tuned from [Janus-Pro-7B](https://huggingface.co/deepseek-ai/Janus-Pro-7B) on the [ShareGPT-4o-Image](https://huggingface.co/datasets/FreedomIntelligence/ShareGPT-4o-Image) dataset to align Janus-Pro with the image generation capabilities of GPT-4o. Compared to Janus-Pro, Janus-4o adds support for text-and-image-to-image generation and delivers notable improvements on text-to-image tasks.

> ⚠️ **Statement**: **ShareGPT-4o-Image** is a dataset distilled from GPT-4o-Image, offering 4o-level data quality (_this refers to the data, not to model capability_). **Janus-4o** is a fine-tuned version of Janus-Pro on this dataset, with added image-editing support. Fine-tuning brings noticeable gains in image generation, but **Janus-4o still lags behind GPT-4o-Image in overall performance**.


## 2. Quick Start
### Step 1: Install the [Janus](https://github.com/deepseek-ai/Janus) Library
```Bash
git clone https://github.com/deepseek-ai/Janus.git
cd Janus
pip install -e .
```
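
To verify the installation, a quick import check can be run (a minimal sketch; it assumes the editable install above completed without errors):

```Python
# Sanity check: these are the same imports used by the inference examples below.
from janus.models import MultiModalityCausalLM, VLChatProcessor
print("Janus library imported successfully")
```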

### Step 2: Inference
- **1. Text-to-Image Generation**
```Python
import os
import PIL.Image
import torch
import numpy as np
from transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor

# Load model and processor
model_path = "FreedomIntelligence/Janus-4o-7B"
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True, torch_dtype=torch.bfloat16
)
vl_gpt = vl_gpt.cuda().eval()

# Define text-to-image generation function
def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt, temperature=1.0, parallel_size=2, cfg_weight=5):
    torch.cuda.empty_cache()

    conversation = [
        {
            "role": "<|User|>",
            "content": input_prompt,
        },
        {"role": "<|Assistant|>", "content": ""},
    ]

    sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
        conversations=conversation,
        sft_format=vl_chat_processor.sft_format,
        system_prompt="",
    )

    prompt = sft_format + vl_chat_processor.image_start_tag

    mmgpt = vl_gpt

    # 384x384 output image; with 16x16 patches this yields (384/16)^2 = 576 image tokens
    image_token_num_per_image = 576
    img_size = 384
    patch_size = 16

    with torch.inference_mode():
        input_ids = vl_chat_processor.tokenizer.encode(prompt)
        input_ids = torch.LongTensor(input_ids)

        # Duplicate the prompt for classifier-free guidance: even rows keep the full
        # prompt (conditional); odd rows mask the text with pad tokens (unconditional).
        tokens = torch.zeros((parallel_size*2, len(input_ids)), dtype=torch.int).cuda()
        for i in range(parallel_size*2):
            tokens[i, :] = input_ids
            if i % 2 != 0:
                tokens[i, 1:-1] = vl_chat_processor.pad_id

        inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)

        generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()

        # Autoregressively sample one image token at a time, reusing the KV cache.
        for i in range(image_token_num_per_image):
            outputs = mmgpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=outputs.past_key_values if i != 0 else None)
            hidden_states = outputs.last_hidden_state

            # Classifier-free guidance: blend conditional and unconditional logits.
            logits = mmgpt.gen_head(hidden_states[:, -1, :])
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]

            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            probs = torch.softmax(logits / temperature, dim=-1)

            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)

            # Feed the sampled token back in for both the conditional and unconditional rows.
            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
            img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)

        # Decode the image tokens back to pixels and map from [-1, 1] to [0, 255].
        dec = mmgpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int), shape=[parallel_size, 8, img_size//patch_size, img_size//patch_size])
        dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)

        dec = np.clip((dec + 1) / 2 * 255, 0, 255)

        visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
        visual_img[:, :, :] = dec

        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        output_images = []
        for i in range(parallel_size):
            save_path = output_path.replace('.png','') + f'_{i}.png'
            PIL.Image.fromarray(visual_img[i]).save(save_path)
            output_images.append(save_path)
        return output_images

# Run
prompt = "A stunning princess from Kabul in red, white traditional clothing, blue eyes, brown hair"
image_output_path = "./test.png"
text_to_image_generate(prompt, image_output_path, vl_chat_processor, vl_gpt, parallel_size=2)
```
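
The call above writes one image per sample next to `output_path` (here `test_0.png` and `test_1.png`). `cfg_weight` is the classifier-free guidance scale used in the `logit_uncond + cfg_weight * (logit_cond - logit_uncond)` blend, and `temperature` scales the logits before sampling. A hypothetical variation (values are illustrative, not tuned defaults):

```Python
# Hypothetical settings, shown only to illustrate the knobs exposed above:
# stronger guidance and a lower sampling temperature, with four samples per call.
text_to_image_generate(
    prompt,
    "./test_strict.png",
    vl_chat_processor,
    vl_gpt,
    temperature=0.7,
    cfg_weight=7,
    parallel_size=4,
)
```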


- **2. Text-and-Image-to-Image Generation**

```Python
import os
import PIL.Image
import torch
import numpy as np
from transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from dataclasses import dataclass

# Lightweight container matching the fields expected by vl_chat_processor.batchify
@dataclass
class VLChatProcessorOutput:
    sft_format: str
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    num_image_tokens: torch.IntTensor

    def __len__(self):
        return len(self.input_ids)

def process_image(image_paths, vl_chat_processor):
    images = [PIL.Image.open(image_path).convert("RGB") for image_path in image_paths]
    images_outputs = vl_chat_processor.image_processor(images, return_tensors="pt")
    return images_outputs['pixel_values']

# Load model and processor
model_path = "FreedomIntelligence/Janus-4o-7B"
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True, torch_dtype=torch.bfloat16
)
vl_gpt = vl_gpt.cuda().eval()

# Define text+image-to-image generation function
def text_and_image_to_image_generate(input_prompt, input_image_path, output_path, vl_chat_processor, vl_gpt, temperature=1.0, parallel_size=2, cfg_weight=5, cfg_weight2=5):
    torch.cuda.empty_cache()

    # Placeholder sequence for one input image: an image-token span (filled with
    # understanding-encoder features) followed by a pad-token span (later overwritten
    # with generation-encoder embeddings).
    input_img_tokens = (
        vl_chat_processor.image_start_tag
        + vl_chat_processor.image_tag * vl_chat_processor.num_image_tokens
        + vl_chat_processor.image_end_tag
        + vl_chat_processor.image_start_tag
        + vl_chat_processor.pad_tag * vl_chat_processor.num_image_tokens
        + vl_chat_processor.image_end_tag
    )
    output_img_tokens = vl_chat_processor.image_start_tag

    pre_data = []
    input_images = [input_image_path]
    img_len = len(input_images)
    prompts = input_img_tokens * img_len + input_prompt
    conversation = [
        {"role": "<|User|>", "content": prompts},
        {"role": "<|Assistant|>", "content": ""},
    ]
    sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
        conversations=conversation,
        sft_format=vl_chat_processor.sft_format,
        system_prompt="",
    )

    sft_format = sft_format + output_img_tokens

    mmgpt = vl_gpt

    # 384x384 output image; with 16x16 patches this yields (384/16)^2 = 576 image tokens
    image_token_num_per_image = 576
    img_size = 384
    patch_size = 16

    with torch.inference_mode():
        # Encode the input image with the generation (VQ) encoder to get its image tokens.
        input_image_pixel_values = process_image(input_images, vl_chat_processor).to(torch.bfloat16).cuda()
        quant_input, emb_loss_input, info_input = mmgpt.gen_vision_model.encode(input_image_pixel_values)
        image_tokens_input = info_input[2].detach().reshape(input_image_pixel_values.shape[0], -1)
        image_embeds_input = mmgpt.prepare_gen_img_embeds(image_tokens_input)

        input_ids = torch.LongTensor(vl_chat_processor.tokenizer.encode(sft_format))
        
        encoder_pixel_values = process_image(input_images, vl_chat_processor).cuda()
        # Triple the prompt for dual guidance: in each triplet, row 0 is the fully
        # conditioned stream, row 1 keeps the text and understanding-encoder image
        # features only, and row 2 is unconditional (text masked with pad tokens).
        tokens = torch.zeros((parallel_size*3, len(input_ids)), dtype=torch.long)
        for i in range(parallel_size*3):
            tokens[i, :] = input_ids
            if i % 3 == 2:
                tokens[i, 1:-1] = vl_chat_processor.pad_id
                pre_data.append(VLChatProcessorOutput(sft_format=sft_format, pixel_values=encoder_pixel_values, input_ids=tokens[i-2], num_image_tokens=[vl_chat_processor.num_image_tokens] * img_len))
                pre_data.append(VLChatProcessorOutput(sft_format=sft_format, pixel_values=encoder_pixel_values, input_ids=tokens[i-1], num_image_tokens=[vl_chat_processor.num_image_tokens] * img_len))
                pre_data.append(VLChatProcessorOutput(sft_format=sft_format, pixel_values=None, input_ids=tokens[i], num_image_tokens=[]))

        prepare_inputs = vl_chat_processor.batchify(pre_data)

        inputs_embeds = mmgpt.prepare_inputs_embeds(
                    input_ids=tokens.cuda(),
                    pixel_values=prepare_inputs['pixel_values'].to(torch.bfloat16).cuda(),
                    images_emb_mask=prepare_inputs['images_emb_mask'].cuda(),
                    images_seq_mask=prepare_inputs['images_seq_mask'].cuda()
                )

        image_gen_indices = (tokens == vl_chat_processor.image_end_id).nonzero()

        # Overwrite the pad-token span of the fully conditioned stream with the
        # generation-encoder embeddings of the input image.
        for ii, ind in enumerate(image_gen_indices):
            if ii % 4 == 0:
                offset = ind[1] + 2
                inputs_embeds[ind[0], offset: offset+image_embeds_input.shape[1], :] = image_embeds_input[(ii // 2) % img_len]

        generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()

        # Autoregressively sample one image token at a time, reusing the KV cache.
        for i in range(image_token_num_per_image):
            outputs = mmgpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=outputs.past_key_values if i != 0 else None)
            hidden_states = outputs.last_hidden_state

            logits = mmgpt.gen_head(hidden_states[:, -1, :])
            logit_cond_full = logits[0::3, :]
            logit_cond_part = logits[1::3, :]
            logit_uncond = logits[2::3, :]

            # Blend the two conditional streams, then apply classifier-free guidance.
            logit_cond = (logit_cond_full + cfg_weight2 * logit_cond_part) / (1 + cfg_weight2)
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            probs = torch.softmax(logits / temperature, dim=-1)

            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)

            # Feed the sampled token back in for all three streams.
            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
            img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)

        # Decode the image tokens back to pixels and map from [-1, 1] to [0, 255].
        dec = mmgpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int), shape=[parallel_size, 8, img_size//patch_size, img_size//patch_size])
        dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)

        dec = np.clip((dec + 1) / 2 * 255, 0, 255)

        visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
        visual_img[:, :, :] = dec

        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        output_images = []
        for i in range(parallel_size):
            save_path = output_path.replace('.png','') + f'_{i}.png'
            PIL.Image.fromarray(visual_img[i]).save(save_path)
            output_images.append(save_path)
        return output_images

# Run
prompt = "Turn the image into a nighttime scene."
input_image_path = "./test_input.png"
image_output_path = "./test_output.png"
text_and_image_to_image_generate(prompt, input_image_path, image_output_path, vl_chat_processor, vl_gpt, parallel_size=2)
```
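
Editing uses two guidance weights: `cfg_weight2` blends the fully conditioned logits with the partially conditioned ones, and `cfg_weight` then applies standard classifier-free guidance against the unconditional stream. A hypothetical call showing how the two weights can be varied (values are illustrative, not tuned defaults):

```Python
# Hypothetical settings, shown only to illustrate the two guidance knobs;
# the defaults above (cfg_weight=5, cfg_weight2=5) are the reference configuration.
text_and_image_to_image_generate(
    "Turn the image into a nighttime scene.",
    "./test_input.png",
    "./test_night.png",
    vl_chat_processor,
    vl_gpt,
    cfg_weight=6,
    cfg_weight2=3,
    parallel_size=2,
)
```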

## Citation

If you find our model or dataset helpful, please consider citing our work:

```bibtex
@misc{chen2025sharegpt4oimagealigningmultimodalmodels,
      title={ShareGPT-4o-Image: Aligning Multimodal Models with GPT-4o-Level Image Generation}, 
      author={Junying Chen and Zhenyang Cai and Pengcheng Chen and Shunian Chen and Ke Ji and Xidong Wang and Yunjin Yang and Benyou Wang},
      year={2025},
      eprint={2506.18095},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2506.18095}, 
}
```