Nonnya committed on
Commit e203a48 · verified · 1 Parent(s): 16f6bba

Upload folder using huggingface_hub

Files changed (7)
  1. Janus_colab_demo.ipynb +0 -0
  2. README.md +2 -8
  3. app.py +224 -0
  4. app_janusflow.py +247 -0
  5. app_januspro.py +245 -0
  6. fastapi_app.py +178 -0
  7. fastapi_client.py +78 -0
Janus_colab_demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
README.md CHANGED
@@ -1,12 +1,6 @@
 ---
-title: Demo
-emoji: 🏃
-colorFrom: blue
-colorTo: blue
+title: demo
+app_file: app_januspro.py
 sdk: gradio
 sdk_version: 5.17.1
-app_file: app.py
-pinned: false
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,224 @@
+import gradio as gr
+import torch
+from transformers import AutoConfig, AutoModelForCausalLM
+from janus.models import MultiModalityCausalLM, VLChatProcessor
+from PIL import Image
+
+import numpy as np
+
+
+# Load model and processor
+model_path = "deepseek-ai/Janus-1.3B"
+config = AutoConfig.from_pretrained(model_path)
+language_config = config.language_config
+language_config._attn_implementation = 'eager'
+vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
+                                              language_config=language_config,
+                                              trust_remote_code=True)
+vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
+
+vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
+tokenizer = vl_chat_processor.tokenizer
+cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+# Multimodal understanding function
+@torch.inference_mode()
+def multimodal_understanding(image, question, seed, top_p, temperature):
+    # Clear CUDA cache before generating
+    torch.cuda.empty_cache()
+
+    # Set seeds for reproducibility
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    conversation = [
+        {
+            "role": "User",
+            "content": f"<image_placeholder>\n{question}",
+            "images": [image],
+        },
+        {"role": "Assistant", "content": ""},
+    ]
+
+    pil_images = [Image.fromarray(image)]
+    prepare_inputs = vl_chat_processor(
+        conversations=conversation, images=pil_images, force_batchify=True
+    ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
+
+    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
+
+    outputs = vl_gpt.language_model.generate(
+        inputs_embeds=inputs_embeds,
+        attention_mask=prepare_inputs.attention_mask,
+        pad_token_id=tokenizer.eos_token_id,
+        bos_token_id=tokenizer.bos_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        max_new_tokens=512,
+        do_sample=False if temperature == 0 else True,
+        use_cache=True,
+        temperature=temperature,
+        top_p=top_p,
+    )
+
+    answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
+    return answer
+
+
+def generate(input_ids,
+             width,
+             height,
+             temperature: float = 1,
+             parallel_size: int = 5,
+             cfg_weight: float = 5,
+             image_token_num_per_image: int = 576,
+             patch_size: int = 16):
+    # Clear CUDA cache before generating
+    torch.cuda.empty_cache()
+
+    # Duplicate each prompt: even rows keep the full prompt (conditional),
+    # odd rows are padded out (unconditional) for classifier-free guidance.
+    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
+    for i in range(parallel_size * 2):
+        tokens[i, :] = input_ids
+        if i % 2 != 0:
+            tokens[i, 1:-1] = vl_chat_processor.pad_id
+    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
+    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
+
+    pkv = None
+    for i in range(image_token_num_per_image):
+        outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
+                                              use_cache=True,
+                                              past_key_values=pkv)
+        pkv = outputs.past_key_values
+        hidden_states = outputs.last_hidden_state
+        logits = vl_gpt.gen_head(hidden_states[:, -1, :])
+        logit_cond = logits[0::2, :]
+        logit_uncond = logits[1::2, :]
+        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
+        probs = torch.softmax(logits / temperature, dim=-1)
+        next_token = torch.multinomial(probs, num_samples=1)
+        generated_tokens[:, i] = next_token.squeeze(dim=-1)
+        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
+        img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
+        inputs_embeds = img_embeds.unsqueeze(dim=1)
+    patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
+                                                  shape=[parallel_size, 8, width // patch_size, height // patch_size])
+
+    return generated_tokens.to(dtype=torch.int), patches
+
+
+def unpack(dec, width, height, parallel_size=5):
+    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
+    dec = np.clip((dec + 1) / 2 * 255, 0, 255)
+
+    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
+    visual_img[:, :, :] = dec
+
+    return visual_img
+
+
+@torch.inference_mode()
+def generate_image(prompt,
+                   seed=None,
+                   guidance=5):
+    # Clear CUDA cache and avoid tracking gradients
+    torch.cuda.empty_cache()
+    # Set the seed for reproducible results
+    if seed is not None:
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+        np.random.seed(seed)
+    width = 384
+    height = 384
+    parallel_size = 5
+
+    with torch.no_grad():
+        messages = [{'role': 'User', 'content': prompt},
+                    {'role': 'Assistant', 'content': ''}]
+        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
+                                                                           sft_format=vl_chat_processor.sft_format,
+                                                                           system_prompt='')
+        text = text + vl_chat_processor.image_start_tag
+        input_ids = torch.LongTensor(tokenizer.encode(text))
+        output, patches = generate(input_ids,
+                                   width // 16 * 16,
+                                   height // 16 * 16,
+                                   cfg_weight=guidance,
+                                   parallel_size=parallel_size)
+        images = unpack(patches,
+                        width // 16 * 16,
+                        height // 16 * 16)
+
+    return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(parallel_size)]
+
+
+# Gradio interface
+with gr.Blocks() as demo:
+    gr.Markdown(value="# Multimodal Understanding")
+    with gr.Row():
+        image_input = gr.Image()
+        with gr.Column():
+            question_input = gr.Textbox(label="Question")
+            und_seed_input = gr.Number(label="Seed", precision=0, value=42)
+            top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
+            temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
+
+    understanding_button = gr.Button("Chat")
+    understanding_output = gr.Textbox(label="Response")
+
+    examples_inpainting = gr.Examples(
+        label="Multimodal Understanding examples",
+        examples=[
+            [
+                "explain this meme",
+                "images/doge.png",
+            ],
+            [
+                "Convert the formula into latex code.",
+                "images/equation.png",
+            ],
+        ],
+        inputs=[question_input, image_input],
+    )
+
+    gr.Markdown(value="# Text-to-Image Generation")
+
+    with gr.Row():
+        cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
+
+    prompt_input = gr.Textbox(label="Prompt")
+    seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)
+
+    generation_button = gr.Button("Generate Images")
+
+    image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)
+
+    examples_t2i = gr.Examples(
+        label="Text to image generation examples. (Tip for designing prompts: adding a description like 'digital art' at the end of the prompt, or writing the prompt in more detail, can help produce better images!)",
+        examples=[
+            "Master shifu racoon wearing drip attire as a street gangster.",
+            "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting,immortal,fluffy, shiny mane,Petals,fairyism,unreal engine 5 and Octane Render,highly detailed, photorealistic, cinematic, natural colors.",
+            "The image features an intricately designed eye set against a circular backdrop adorned with ornate swirl patterns that evoke both realism and surrealism. At the center of attention is a strikingly vivid blue iris surrounded by delicate veins radiating outward from the pupil to create depth and intensity. The eyelashes are long and dark, casting subtle shadows on the skin around them which appears smooth yet slightly textured as if aged or weathered over time.\n\nAbove the eye, there's a stone-like structure resembling part of classical architecture, adding layers of mystery and timeless elegance to the composition. This architectural element contrasts sharply but harmoniously with the organic curves surrounding it. Below the eye lies another decorative motif reminiscent of baroque artistry, further enhancing the overall sense of eternity encapsulated within each meticulously crafted detail. \n\nOverall, the atmosphere exudes a mysterious aura intertwined seamlessly with elements suggesting timelessness, achieved through the juxtaposition of realistic textures and surreal artistic flourishes. Each component\u2014from the intricate designs framing the eye to the ancient-looking stone piece above\u2014contributes uniquely towards creating a visually captivating tableau imbued with enigmatic allure.",
+        ],
+        inputs=prompt_input,
+    )
+
+    understanding_button.click(
+        multimodal_understanding,
+        inputs=[image_input, question_input, und_seed_input, top_p, temperature],
+        outputs=understanding_output
+    )
+
+    generation_button.click(
+        fn=generate_image,
+        inputs=[prompt_input, seed_input, cfg_weight_input],
+        outputs=image_output
+    )
+
+demo.launch(share=True)
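
Note on `generate` in app.py: each prompt is duplicated so that even batch rows keep the full prompt (conditional) while odd rows are padded out (unconditional), and the two logit streams are recombined at every decoding step for classifier-free guidance. A minimal sketch of that recombination on dummy tensors (the batch and vocabulary sizes here are illustrative, not from this commit):

import torch

# Hypothetical shapes: 5 prompts duplicated -> batch of 10, vocab of 16384.
cfg_weight = 5.0
logits = torch.randn(10, 16384)

logit_cond = logits[0::2, :]     # even rows: prompt-conditioned branch
logit_uncond = logits[1::2, :]   # odd rows: padded (unconditional) branch

# Classifier-free guidance: push the conditional logits away from the
# unconditional ones by cfg_weight, as in `generate` above.
guided = logit_uncond + cfg_weight * (logit_cond - logit_uncond)

probs = torch.softmax(guided / 1.0, dim=-1)           # temperature = 1
next_token = torch.multinomial(probs, num_samples=1)  # one token per prompt
print(next_token.shape)  # torch.Size([5, 1])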
app_janusflow.py ADDED
@@ -0,0 +1,247 @@
+import gradio as gr
+import torch
+from janus.janusflow.models import MultiModalityCausalLM, VLChatProcessor
+from PIL import Image
+from diffusers.models import AutoencoderKL
+import numpy as np
+
+cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+# Load model and processor
+model_path = "deepseek-ai/JanusFlow-1.3B"
+vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
+tokenizer = vl_chat_processor.tokenizer
+
+vl_gpt = MultiModalityCausalLM.from_pretrained(model_path)
+vl_gpt = vl_gpt.to(torch.bfloat16).to(cuda_device).eval()
+
+# Remember to use the bfloat16 dtype; this VAE doesn't work with fp16.
+vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
+vae = vae.to(torch.bfloat16).to(cuda_device).eval()
+
+
+# Multimodal understanding function
+@torch.inference_mode()
+def multimodal_understanding(image, question, seed, top_p, temperature):
+    # Clear CUDA cache before generating
+    torch.cuda.empty_cache()
+
+    # Set seeds for reproducibility
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    conversation = [
+        {
+            "role": "User",
+            "content": f"<image_placeholder>\n{question}",
+            "images": [image],
+        },
+        {"role": "Assistant", "content": ""},
+    ]
+
+    pil_images = [Image.fromarray(image)]
+    prepare_inputs = vl_chat_processor(
+        conversations=conversation, images=pil_images, force_batchify=True
+    ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
+
+    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
+
+    outputs = vl_gpt.language_model.generate(
+        inputs_embeds=inputs_embeds,
+        attention_mask=prepare_inputs.attention_mask,
+        pad_token_id=tokenizer.eos_token_id,
+        bos_token_id=tokenizer.bos_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        max_new_tokens=512,
+        do_sample=False if temperature == 0 else True,
+        use_cache=True,
+        temperature=temperature,
+        top_p=top_p,
+    )
+
+    answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
+
+    return answer
+
+
+@torch.inference_mode()
+def generate(
+    input_ids,
+    cfg_weight: float = 2.0,
+    num_inference_steps: int = 30
+):
+    # We generate 5 images at a time, *2 for CFG.
+    tokens = torch.stack([input_ids] * 10).cuda()
+    tokens[5:, 1:] = vl_chat_processor.pad_id
+    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
+
+    # We remove the last <bog> token and replace it with t_emb later.
+    inputs_embeds = inputs_embeds[:, :-1, :]
+
+    # Generate with the rectified-flow ODE.
+    # Step 1: start from Gaussian noise in the VAE latent space.
+    z = torch.randn((5, 4, 48, 48), dtype=torch.bfloat16).cuda()
+
+    dt = 1.0 / num_inference_steps
+    dt = torch.zeros_like(z).cuda().to(torch.bfloat16) + dt
+
+    # Step 2: run the ODE.
+    # Attention mask for CFG: 1 for tokens that are not masked, 0 for tokens that are masked.
+    attention_mask = torch.ones((10, inputs_embeds.shape[1] + 577)).to(vl_gpt.device)
+    attention_mask[5:, 1:inputs_embeds.shape[1]] = 0
+    attention_mask = attention_mask.int()
+    for step in range(num_inference_steps):
+        # Prepare inputs for the LLM.
+        z_input = torch.cat([z, z], dim=0)  # for CFG
+        t = step / num_inference_steps * 1000.
+        t = torch.tensor([t] * z_input.shape[0]).to(dt)
+        z_enc = vl_gpt.vision_gen_enc_model(z_input, t)
+        z_emb, t_emb, hs = z_enc[0], z_enc[1], z_enc[2]
+        z_emb = z_emb.view(z_emb.shape[0], z_emb.shape[1], -1).permute(0, 2, 1)
+        z_emb = vl_gpt.vision_gen_enc_aligner(z_emb)
+        llm_emb = torch.cat([inputs_embeds, t_emb.unsqueeze(1), z_emb], dim=1)
+
+        # Input to the LLM. On the first step, run the full sequence and
+        # truncate the KV cache to the prompt prefix, so later steps only
+        # recompute the time/latent embeddings against the cached prompt.
+        if step == 0:
+            outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
+                                                  use_cache=True,
+                                                  attention_mask=attention_mask,
+                                                  past_key_values=None)
+            past_key_values = []
+            for kv_cache in outputs.past_key_values:
+                k, v = kv_cache[0], kv_cache[1]
+                past_key_values.append((k[:, :, :inputs_embeds.shape[1], :], v[:, :, :inputs_embeds.shape[1], :]))
+            past_key_values = tuple(past_key_values)
+        else:
+            outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
+                                                  use_cache=True,
+                                                  attention_mask=attention_mask,
+                                                  past_key_values=past_key_values)
+        hidden_states = outputs.last_hidden_state
+
+        # Transform hidden_states back to a velocity v.
+        hidden_states = vl_gpt.vision_gen_dec_aligner(vl_gpt.vision_gen_dec_aligner_norm(hidden_states[:, -576:, :]))
+        hidden_states = hidden_states.reshape(z_emb.shape[0], 24, 24, 768).permute(0, 3, 1, 2)
+        v = vl_gpt.vision_gen_dec_model(hidden_states, hs, t_emb)
+        v_cond, v_uncond = torch.chunk(v, 2)
+        v = cfg_weight * v_cond - (cfg_weight - 1.) * v_uncond
+        z = z + dt * v
+
+    # Step 3: decode with the SDXL VAE.
+    decoded_image = vae.decode(z / vae.config.scaling_factor).sample
+
+    images = decoded_image.float().clip_(-1., 1.).permute(0, 2, 3, 1).cpu().numpy()
+    images = ((images + 1) / 2. * 255).astype(np.uint8)
+
+    return images
+
+
+def unpack(dec, width, height, parallel_size=5):
+    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
+    dec = np.clip((dec + 1) / 2 * 255, 0, 255)
+
+    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
+    visual_img[:, :, :] = dec
+
+    return visual_img
+
+
+@torch.inference_mode()
+def generate_image(prompt,
+                   seed=None,
+                   guidance=5,
+                   num_inference_steps=30):
+    # Clear CUDA cache and avoid tracking gradients
+    torch.cuda.empty_cache()
+    # Set the seed for reproducible results
+    if seed is not None:
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+        np.random.seed(seed)
+
+    with torch.no_grad():
+        messages = [{'role': 'User', 'content': prompt},
+                    {'role': 'Assistant', 'content': ''}]
+        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
+                                                                           sft_format=vl_chat_processor.sft_format,
+                                                                           system_prompt='')
+        text = text + vl_chat_processor.image_start_tag
+        input_ids = torch.LongTensor(tokenizer.encode(text))
+        images = generate(input_ids,
+                          cfg_weight=guidance,
+                          num_inference_steps=num_inference_steps)
+        return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(images.shape[0])]
+
+
+# Gradio interface
+with gr.Blocks() as demo:
+    gr.Markdown(value="# Multimodal Understanding")
+    with gr.Row():
+        image_input = gr.Image()
+        with gr.Column():
+            question_input = gr.Textbox(label="Question")
+            und_seed_input = gr.Number(label="Seed", precision=0, value=42)
+            top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
+            temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
+
+    understanding_button = gr.Button("Chat")
+    understanding_output = gr.Textbox(label="Response")
+
+    examples_inpainting = gr.Examples(
+        label="Multimodal Understanding examples",
+        examples=[
+            [
+                "explain this meme",
+                "./images/doge.png",
+            ],
+            [
+                "Convert the formula into latex code.",
+                "./images/equation.png",
+            ],
+        ],
+        inputs=[question_input, image_input],
+    )
+
+    gr.Markdown(value="# Text-to-Image Generation")
+
+    with gr.Row():
+        cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=2, step=0.5, label="CFG Weight")
+        step_input = gr.Slider(minimum=1, maximum=50, value=30, step=1, label="Number of Inference Steps")
+
+    prompt_input = gr.Textbox(label="Prompt")
+    seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)
+
+    generation_button = gr.Button("Generate Images")
+
+    image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)
+
+    examples_t2i = gr.Examples(
+        label="Text to image generation examples.",
+        examples=[
+            "Master shifu racoon wearing drip attire as a street gangster.",
+            "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting,immortal,fluffy, shiny mane,Petals,fairyism,unreal engine 5 and Octane Render,highly detailed, photorealistic, cinematic, natural colors.",
+            "The image features an intricately designed eye set against a circular backdrop adorned with ornate swirl patterns that evoke both realism and surrealism. At the center of attention is a strikingly vivid blue iris surrounded by delicate veins radiating outward from the pupil to create depth and intensity. The eyelashes are long and dark, casting subtle shadows on the skin around them which appears smooth yet slightly textured as if aged or weathered over time.\n\nAbove the eye, there's a stone-like structure resembling part of classical architecture, adding layers of mystery and timeless elegance to the composition. This architectural element contrasts sharply but harmoniously with the organic curves surrounding it. Below the eye lies another decorative motif reminiscent of baroque artistry, further enhancing the overall sense of eternity encapsulated within each meticulously crafted detail. \n\nOverall, the atmosphere exudes a mysterious aura intertwined seamlessly with elements suggesting timelessness, achieved through the juxtaposition of realistic textures and surreal artistic flourishes. Each component\u2014from the intricate designs framing the eye to the ancient-looking stone piece above\u2014contributes uniquely towards creating a visually captivating tableau imbued with enigmatic allure.",
+        ],
+        inputs=prompt_input,
+    )
+
+    understanding_button.click(
+        multimodal_understanding,
+        inputs=[image_input, question_input, und_seed_input, top_p, temperature],
+        outputs=understanding_output
+    )
+
+    generation_button.click(
+        fn=generate_image,
+        inputs=[prompt_input, seed_input, cfg_weight_input, step_input],
+        outputs=image_output
+    )
+
+demo.launch(share=True)
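
Note on `generate` in app_janusflow.py: the loop integrates a rectified-flow ODE with fixed-step Euler, applying CFG to the predicted velocity rather than to logits. A self-contained sketch of just the update rule, with a stand-in velocity field in place of the model (the stand-in is illustrative, not the model's predictor):

import torch

num_inference_steps = 30
cfg_weight = 2.0
dt = 1.0 / num_inference_steps

# Latent shaped as in the app: 5 images in the SDXL-VAE latent space.
z = torch.randn(5, 4, 48, 48)

for step in range(num_inference_steps):
    # Stand-ins for the model's conditional/unconditional velocity predictions.
    v_cond = -z
    v_uncond = -0.5 * z

    # CFG on the velocity, then one Euler step: z <- z + dt * v.
    v = cfg_weight * v_cond - (cfg_weight - 1.0) * v_uncond
    z = z + dt * v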
app_januspro.py ADDED
@@ -0,0 +1,245 @@
+import gradio as gr
+import torch
+from transformers import AutoConfig, AutoModelForCausalLM
+from janus.models import MultiModalityCausalLM, VLChatProcessor
+from janus.utils.io import load_pil_images
+from PIL import Image
+
+import numpy as np
+import os
+import time
+# import spaces  # Import spaces for ZeroGPU compatibility
+
+
+# Load model and processor
+model_path = "deepseek-ai/Janus-1.3B"
+config = AutoConfig.from_pretrained(model_path)
+language_config = config.language_config
+language_config._attn_implementation = 'eager'
+vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
+                                              language_config=language_config,
+                                              trust_remote_code=True)
+if torch.cuda.is_available():
+    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
+else:
+    vl_gpt = vl_gpt.to(torch.float16)
+print(f"GPU: {torch.cuda.is_available()}, Device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
+vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
+tokenizer = vl_chat_processor.tokenizer
+cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+# Multimodal understanding function
+@torch.inference_mode()
+# @spaces.GPU(duration=120)
+def multimodal_understanding(image, question, seed, top_p, temperature):
+    # Clear CUDA cache before generating
+    torch.cuda.empty_cache()
+
+    # Set seeds for reproducibility
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    conversation = [
+        {
+            "role": "<|User|>",
+            "content": f"<image_placeholder>\n{question}",
+            "images": [image],
+        },
+        {"role": "<|Assistant|>", "content": ""},
+    ]
+
+    pil_images = [Image.fromarray(image)]
+    prepare_inputs = vl_chat_processor(
+        conversations=conversation, images=pil_images, force_batchify=True
+    ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
+
+    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
+
+    outputs = vl_gpt.language_model.generate(
+        inputs_embeds=inputs_embeds,
+        attention_mask=prepare_inputs.attention_mask,
+        pad_token_id=tokenizer.eos_token_id,
+        bos_token_id=tokenizer.bos_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        max_new_tokens=512,
+        do_sample=False if temperature == 0 else True,
+        use_cache=True,
+        temperature=temperature,
+        top_p=top_p,
+    )
+
+    answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
+    return answer
+
+
+def generate(input_ids,
+             width,
+             height,
+             temperature: float = 1,
+             parallel_size: int = 1,
+             cfg_weight: float = 5,
+             image_token_num_per_image: int = 256,
+             patch_size: int = 16):
+    # Clear CUDA cache before generating
+    torch.cuda.empty_cache()
+
+    # Duplicate each prompt: even rows keep the full prompt (conditional),
+    # odd rows are padded out (unconditional) for classifier-free guidance.
+    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
+    for i in range(parallel_size * 2):
+        tokens[i, :] = input_ids
+        if i % 2 != 0:
+            tokens[i, 1:-1] = vl_chat_processor.pad_id
+    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
+    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
+
+    pkv = None
+    for i in range(image_token_num_per_image):
+        with torch.no_grad():
+            outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
+                                                  use_cache=True,
+                                                  past_key_values=pkv)
+        pkv = outputs.past_key_values
+        hidden_states = outputs.last_hidden_state
+        logits = vl_gpt.gen_head(hidden_states[:, -1, :])
+        logit_cond = logits[0::2, :]
+        logit_uncond = logits[1::2, :]
+        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
+        probs = torch.softmax(logits / temperature, dim=-1)
+        next_token = torch.multinomial(probs, num_samples=1)
+        generated_tokens[:, i] = next_token.squeeze(dim=-1)
+        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
+
+        img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
+        inputs_embeds = img_embeds.unsqueeze(dim=1)
+
+    patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
+                                                  shape=[parallel_size, 8, width // patch_size, height // patch_size])
+
+    return generated_tokens.to(dtype=torch.int), patches
+
+
+def unpack(dec, width, height, parallel_size=5):
+    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
+    dec = np.clip((dec + 1) / 2 * 255, 0, 255)
+
+    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
+    visual_img[:, :, :] = dec
+
+    return visual_img
+
+
+@torch.inference_mode()
+# @spaces.GPU(duration=120)  # Specify a duration to avoid timeout
+def generate_image(prompt,
+                   seed=None,
+                   guidance=5,
+                   t2i_temperature=1.0):
+    # Clear CUDA cache and avoid tracking gradients
+    torch.cuda.empty_cache()
+    # Set the seed for reproducible results
+    if seed is not None:
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+        np.random.seed(seed)
+    width = 256
+    height = 256
+    parallel_size = 1
+
+    with torch.no_grad():
+        messages = [{'role': '<|User|>', 'content': prompt},
+                    {'role': '<|Assistant|>', 'content': ''}]
+        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
+                                                                           sft_format=vl_chat_processor.sft_format,
+                                                                           system_prompt='')
+        text = text + vl_chat_processor.image_start_tag
+
+        input_ids = torch.LongTensor(tokenizer.encode(text))
+        output, patches = generate(input_ids,
+                                   width // 16 * 16,
+                                   height // 16 * 16,
+                                   cfg_weight=guidance,
+                                   parallel_size=parallel_size,
+                                   temperature=t2i_temperature)
+        images = unpack(patches,
+                        width // 16 * 16,
+                        height // 16 * 16,
+                        parallel_size=parallel_size)
+
+    return [Image.fromarray(images[i]).resize((768, 768), Image.LANCZOS) for i in range(parallel_size)]
+
+
+# Gradio interface
+with gr.Blocks() as demo:
+    gr.Markdown(value="# Multimodal Understanding")
+    with gr.Row():
+        image_input = gr.Image()
+        with gr.Column():
+            question_input = gr.Textbox(label="Question")
+            und_seed_input = gr.Number(label="Seed", precision=0, value=42)
+            top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
+            temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
+
+    understanding_button = gr.Button("Chat")
+    understanding_output = gr.Textbox(label="Response")
+
+    examples_inpainting = gr.Examples(
+        label="Multimodal Understanding examples",
+        examples=[
+            [
+                "explain this meme",
+                "images/doge.png",
+            ],
+            [
+                "Convert the formula into latex code.",
+                "images/equation.png",
+            ],
+        ],
+        inputs=[question_input, image_input],
+    )
+
+    gr.Markdown(value="# Text-to-Image Generation")
+
+    with gr.Row():
+        cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
+        t2i_temperature = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="temperature")
+
+    prompt_input = gr.Textbox(label="Prompt (a more detailed prompt can help produce better images!)")
+    seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)
+
+    generation_button = gr.Button("Generate Images")
+
+    image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)
+
+    examples_t2i = gr.Examples(
+        label="Text to image generation examples.",
+        examples=[
+            "Master shifu racoon wearing drip attire as a street gangster.",
+            "The face of a beautiful girl",
+            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+            "A glass of red wine on a reflective surface.",
+            "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting,immortal,fluffy, shiny mane,Petals,fairyism,unreal engine 5 and Octane Render,highly detailed, photorealistic, cinematic, natural colors.",
+            "The image features an intricately designed eye set against a circular backdrop adorned with ornate swirl patterns that evoke both realism and surrealism. At the center of attention is a strikingly vivid blue iris surrounded by delicate veins radiating outward from the pupil to create depth and intensity. The eyelashes are long and dark, casting subtle shadows on the skin around them which appears smooth yet slightly textured as if aged or weathered over time.\n\nAbove the eye, there's a stone-like structure resembling part of classical architecture, adding layers of mystery and timeless elegance to the composition. This architectural element contrasts sharply but harmoniously with the organic curves surrounding it. Below the eye lies another decorative motif reminiscent of baroque artistry, further enhancing the overall sense of eternity encapsulated within each meticulously crafted detail. \n\nOverall, the atmosphere exudes a mysterious aura intertwined seamlessly with elements suggesting timelessness, achieved through the juxtaposition of realistic textures and surreal artistic flourishes. Each component\u2014from the intricate designs framing the eye to the ancient-looking stone piece above\u2014contributes uniquely towards creating a visually captivating tableau imbued with enigmatic allure.",
+        ],
+        inputs=prompt_input,
+    )
+
+    understanding_button.click(
+        multimodal_understanding,
+        inputs=[image_input, question_input, und_seed_input, top_p, temperature],
+        outputs=understanding_output
+    )
+
+    generation_button.click(
+        fn=generate_image,
+        inputs=[prompt_input, seed_input, cfg_weight_input, t2i_temperature],
+        outputs=image_output
+    )
+
+demo.launch(share=True)
+# demo.queue(concurrency_count=1, max_size=10).launch(server_name="0.0.0.0", server_port=37906, root_path="/path")
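
Note on the size settings in app_januspro.py: with `patch_size = 16`, `image_token_num_per_image` must equal `(width // patch_size) * (height // patch_size)`, which is why app.py pairs 384×384 with 576 tokens while this file pairs 256×256 with 256 tokens. A quick check:

def image_tokens(width: int, height: int, patch_size: int = 16) -> int:
    # One generated token per patch of the latent grid.
    return (width // patch_size) * (height // patch_size)

assert image_tokens(384, 384) == 576  # app.py defaults
assert image_tokens(256, 256) == 256  # app_januspro.py defaults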
fastapi_app.py ADDED
@@ -0,0 +1,178 @@
+from fastapi import FastAPI, File, Form, UploadFile, HTTPException
+from fastapi.responses import JSONResponse, StreamingResponse
+import torch
+from transformers import AutoConfig, AutoModelForCausalLM
+from janus.models import MultiModalityCausalLM, VLChatProcessor
+from PIL import Image
+import numpy as np
+import io
+
+app = FastAPI()
+
+# Load model and processor
+model_path = "deepseek-ai/Janus-1.3B"
+config = AutoConfig.from_pretrained(model_path)
+language_config = config.language_config
+language_config._attn_implementation = 'eager'
+vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
+                                              language_config=language_config,
+                                              trust_remote_code=True)
+vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
+
+vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
+tokenizer = vl_chat_processor.tokenizer
+cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+@torch.inference_mode()
+def multimodal_understanding(image_data, question, seed, top_p, temperature):
+    torch.cuda.empty_cache()
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    conversation = [
+        {
+            "role": "User",
+            "content": f"<image_placeholder>\n{question}",
+            "images": [image_data],
+        },
+        {"role": "Assistant", "content": ""},
+    ]
+
+    pil_images = [Image.open(io.BytesIO(image_data))]
+    prepare_inputs = vl_chat_processor(
+        conversations=conversation, images=pil_images, force_batchify=True
+    ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
+
+    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
+    outputs = vl_gpt.language_model.generate(
+        inputs_embeds=inputs_embeds,
+        attention_mask=prepare_inputs.attention_mask,
+        pad_token_id=tokenizer.eos_token_id,
+        bos_token_id=tokenizer.bos_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        max_new_tokens=512,
+        do_sample=False if temperature == 0 else True,
+        use_cache=True,
+        temperature=temperature,
+        top_p=top_p,
+    )
+
+    answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
+    return answer
+
+
+@app.post("/understand_image_and_question/")
+async def understand_image_and_question(
+    file: UploadFile = File(...),
+    question: str = Form(...),
+    seed: int = Form(42),
+    top_p: float = Form(0.95),
+    temperature: float = Form(0.1)
+):
+    image_data = await file.read()
+    response = multimodal_understanding(image_data, question, seed, top_p, temperature)
+    return JSONResponse({"response": response})
+
+
+def generate(input_ids,
+             width,
+             height,
+             temperature: float = 1,
+             parallel_size: int = 5,
+             cfg_weight: float = 5,
+             image_token_num_per_image: int = 576,
+             patch_size: int = 16):
+    torch.cuda.empty_cache()
+    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
+    for i in range(parallel_size * 2):
+        tokens[i, :] = input_ids
+        if i % 2 != 0:
+            tokens[i, 1:-1] = vl_chat_processor.pad_id
+    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
+    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
+
+    pkv = None
+    for i in range(image_token_num_per_image):
+        outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv)
+        pkv = outputs.past_key_values
+        hidden_states = outputs.last_hidden_state
+        logits = vl_gpt.gen_head(hidden_states[:, -1, :])
+        logit_cond = logits[0::2, :]
+        logit_uncond = logits[1::2, :]
+        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
+        probs = torch.softmax(logits / temperature, dim=-1)
+        next_token = torch.multinomial(probs, num_samples=1)
+        generated_tokens[:, i] = next_token.squeeze(dim=-1)
+        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
+        img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
+        inputs_embeds = img_embeds.unsqueeze(dim=1)
+    patches = vl_gpt.gen_vision_model.decode_code(
+        generated_tokens.to(dtype=torch.int),
+        shape=[parallel_size, 8, width // patch_size, height // patch_size]
+    )
+
+    return generated_tokens.to(dtype=torch.int), patches
+
+
+def unpack(dec, width, height, parallel_size=5):
+    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
+    dec = np.clip((dec + 1) / 2 * 255, 0, 255)
+
+    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
+    visual_img[:, :, :] = dec
+
+    return visual_img
+
+
+@torch.inference_mode()
+def generate_image(prompt, seed, guidance):
+    torch.cuda.empty_cache()
+    seed = seed if seed is not None else 12345
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    np.random.seed(seed)
+    width = 384
+    height = 384
+    parallel_size = 5
+
+    with torch.no_grad():
+        messages = [{'role': 'User', 'content': prompt}, {'role': 'Assistant', 'content': ''}]
+        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
+            conversations=messages,
+            sft_format=vl_chat_processor.sft_format,
+            system_prompt=''
+        )
+        text = text + vl_chat_processor.image_start_tag
+        input_ids = torch.LongTensor(tokenizer.encode(text))
+        _, patches = generate(input_ids, width // 16 * 16, height // 16 * 16, cfg_weight=guidance, parallel_size=parallel_size)
+        images = unpack(patches, width // 16 * 16, height // 16 * 16)
+
+    return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(parallel_size)]
+
+
+@app.post("/generate_images/")
+async def generate_images(
+    prompt: str = Form(...),
+    seed: int = Form(None),
+    guidance: float = Form(5.0),
+):
+    try:
+        images = generate_image(prompt, seed, guidance)
+
+        def image_stream():
+            for img in images:
+                buf = io.BytesIO()
+                img.save(buf, format='PNG')
+                buf.seek(0)
+                yield buf.read()
+
+        return StreamingResponse(image_stream(), media_type="multipart/related")
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Image generation failed: {str(e)}")
+
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
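
Note on `/generate_images/`: the endpoint streams the PNGs back to back under a `multipart/related` content type but writes no actual part boundaries, so a client must find image boundaries itself. One possible approach (an illustrative sketch, not part of this commit) is to split the byte stream on the constant 12-byte PNG IEND trailer:

# Every PNG file ends with the 12-byte IEND chunk:
# 4-byte length (0) + b"IEND" + 4-byte CRC.
PNG_IEND = b"\x00\x00\x00\x00IEND\xaeB`\x82"

def split_png_stream(data: bytes) -> list[bytes]:
    """Split a concatenation of PNG files into individual PNGs."""
    images, start = [], 0
    while True:
        end = data.find(PNG_IEND, start)
        if end == -1:
            break
        end += len(PNG_IEND)
        images.append(data[start:end])
        start = end
    return images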
fastapi_client.py ADDED
@@ -0,0 +1,78 @@
+import requests
+from PIL import Image
+import io
+
+# Endpoint URLs
+understand_image_url = "http://localhost:8000/understand_image_and_question/"
+generate_images_url = "http://localhost:8000/generate_images/"
+
+# Use your image file path here
+image_path = "images/equation.png"
+
+
+# Function to call the image understanding endpoint
+def understand_image_and_question(image_path, question, seed=42, top_p=0.95, temperature=0.1):
+    data = {
+        'question': question,
+        'seed': seed,
+        'top_p': top_p,
+        'temperature': temperature
+    }
+    # Keep the file open for the duration of the upload.
+    with open(image_path, 'rb') as f:
+        files = {'file': f}
+        response = requests.post(understand_image_url, files=files, data=data)
+    response_data = response.json()
+    print("Image Understanding Response:", response_data['response'])
+
+
+# Function to call the text-to-image generation endpoint
+def generate_images(prompt, seed=None, guidance=5.0):
+    data = {
+        'prompt': prompt,
+        'seed': seed,
+        'guidance': guidance
+    }
+    response = requests.post(generate_images_url, data=data, stream=True)
+
+    if response.ok:
+        img_idx = 1
+
+        # We create a new BytesIO for each image.
+        buffers = {}
+
+        try:
+            for chunk in response.iter_content(chunk_size=1024):
+                if chunk:
+                    # The server streams concatenated PNG bytes without real
+                    # multipart boundaries, so we accumulate chunks and probe
+                    # heuristically for a complete image.
+                    if img_idx not in buffers:
+                        buffers[img_idx] = io.BytesIO()
+
+                    buffers[img_idx].write(chunk)
+
+                    # Attempt to open the accumulated bytes as an image.
+                    try:
+                        buffer = buffers[img_idx]
+                        buffer.seek(0)
+                        image = Image.open(buffer)
+                        img_path = f"generated_image_{img_idx}.png"
+                        image.save(img_path)
+                        print(f"Saved: {img_path}")
+
+                        # Prepare the next image buffer.
+                        buffer.close()
+                        img_idx += 1
+
+                    except Exception:
+                        # Not a complete image yet; keep loading data into
+                        # the current buffer.
+                        continue
+
+        except Exception as e:
+            print("Error processing image:", e)
+    else:
+        print("Failed to generate images.")
+
+
+# Example usage
+if __name__ == "__main__":
+    # Call the image understanding API
+    understand_image_and_question(image_path, "What is this image about?")
+
+    # Call the image generation API
+    generate_images("A beautiful sunset over a mountain range, digital art.")
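
Assumed workflow for the two FastAPI files (not stated in this commit): start the server with `python fastapi_app.py`, which serves on 0.0.0.0:8000, then run `python fastapi_client.py` in another shell to exercise both endpoints.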